Started implementation of random variables with PyTorch backend.
twaclaw committed Nov 10, 2024
1 parent a570dbf commit 85d6080
Showing 5 changed files with 97 additions and 2 deletions.
1 change: 1 addition & 0 deletions pytensor/link/pytorch/dispatch/__init__.py
@@ -12,4 +12,5 @@
 import pytensor.link.pytorch.dispatch.sort
 import pytensor.link.pytorch.dispatch.subtensor
 import pytensor.link.pytorch.dispatch.blockwise
+import pytensor.link.pytorch.dispatch.random
 # isort: on
7 changes: 5 additions & 2 deletions pytensor/link/pytorch/dispatch/basic.py
@@ -24,8 +24,11 @@


 @singledispatch
-def pytorch_typify(data, **kwargs):
-    raise NotImplementedError(f"pytorch_typify is not implemented for {type(data)}")
+def pytorch_typify(data, dtype=None, **kwargs):
+    if dtype is None:
+        return data
+    else:
+        return torch.tensor(data, dtype=dtype)


 @pytorch_typify.register(np.ndarray)
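
The new fallback turns pytorch_typify from a hard NotImplementedError into a generic coercion point: values with no explicit dtype pass through unchanged, while anything else is wrapped in a torch tensor. A minimal sketch of the two paths, using only the fallback defined above:

import torch

from pytensor.link.pytorch.dispatch.basic import pytorch_typify

x = pytorch_typify(3.0)  # no dtype: the value passes through unchanged
t = pytorch_typify(3.0, dtype=torch.float32)  # dtype given: coerced to a tensor
assert x == 3.0 and isinstance(t, torch.Tensor)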
62 changes: 62 additions & 0 deletions pytensor/link/pytorch/dispatch/random.py
@@ -0,0 +1,62 @@
+from functools import singledispatch
+
+import torch
+from numpy.random import Generator
+
+import pytensor.tensor.random.basic as ptr
+from pytensor.graph import Constant
+from pytensor.link.pytorch.dispatch.basic import pytorch_funcify, pytorch_typify
+from pytensor.tensor.type_other import NoneTypeT
+
+
+@pytorch_typify.register(Generator)
+def pytorch_typify_Generator(rng, **kwargs):
+    state = rng.__getstate__()
+    state["pytorch_state"] = torch.manual_seed(123).get_state()  # XXX: replace
+    return state
+
+
+@pytorch_funcify.register(ptr.RandomVariable)
+def torch_funcify_RandomVariable(op: ptr.RandomVariable, node, **kwargs):
+    rv = node.outputs[1]
+    out_dtype = rv.type.dtype
+    static_shape = rv.type.shape
+    batch_ndim = op.batch_ndim(node)
+
+    # Try to pass static size directly to PyTorch
+    static_size = static_shape[:batch_ndim]
+    if None in static_size:
+        # Sometimes size can be constant folded during rewrites,
+        # without the RandomVariable node being updated with new static types
+        size_param = op.size_param(node)
+        if isinstance(size_param, Constant) and not isinstance(
+            size_param.type, NoneTypeT
+        ):
+            static_size = tuple(size_param.data)
+
+    def sample_fn(rng, size, *parameters):
+        return pytorch_sample_fn(op, node=node)(
+            rng, static_size, out_dtype, *parameters
+        )
+
+    return sample_fn
+
+
+@singledispatch
+def pytorch_sample_fn(op, node):
+    name = op.name
+    raise NotImplementedError(
+        f"No PyTorch implementation for the given distribution: {name}"
+    )
+
+
+@pytorch_sample_fn.register(ptr.BernoulliRV)
+def pytorch_sample_fn_bernoulli(op, node):
+    def sample_fn(rng, size, dtype, p):
+        # XXX replace
+        state_ = rng["pytorch_state"]
+        gen = torch.Generator().set_state(state_)
+        sample = torch.bernoulli(torch.expand_copy(p, size), generator=gen)
+        return (rng, sample)
+
+    return sample_fn
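
pytorch_sample_fn is the extension point for further distributions: each registered sampler receives the typified rng state dict, the static size, the output dtype, and the distribution parameters, and returns the rng alongside the draw. As a hypothetical illustration (not part of this commit, and inheriting the placeholder state handling marked XXX above), a normal sampler could follow the same pattern; ptr.NormalRV is assumed to be the matching Op in pytensor.tensor.random.basic:

import torch

import pytensor.tensor.random.basic as ptr


@pytorch_sample_fn.register(ptr.NormalRV)
def pytorch_sample_fn_normal(op, node):
    def sample_fn(rng, size, dtype, loc, scale):
        # Rebuild a torch generator from the stored state, mirroring the
        # Bernoulli sampler above (placeholder state handling).
        gen = torch.Generator().set_state(rng["pytorch_state"])
        sample = torch.normal(
            torch.expand_copy(loc, size),
            torch.expand_copy(scale, size),
            generator=gen,
        )
        return (rng, sample)

    return sample_fn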
9 changes: 9 additions & 0 deletions pytensor/link/pytorch/linker.py
@@ -1,5 +1,7 @@
 from typing import Any

+from numpy.random import Generator, RandomState
+
 from pytensor.graph.basic import Variable
 from pytensor.link.basic import JITLinker

@@ -28,9 +30,16 @@ def jit_compile(self, fn):
         return torch.compile(fn)

     def create_thunk_inputs(self, storage_map):
+        from pytensor.link.pytorch.dispatch import pytorch_typify
+
         thunk_inputs = []
         for n in self.fgraph.inputs:
             sinput = storage_map[n]
+            if isinstance(sinput[0], RandomState | Generator):
+                new_value = pytorch_typify(
+                    sinput[0], dtype=getattr(sinput[0], "dtype", None)
+                )
+                sinput[0] = new_value
             thunk_inputs.append(sinput)

         return thunk_inputs
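
The hook above converts shared NumPy RandomState/Generator inputs into the state dict built by pytorch_typify_Generator before thunks are created, since torch ops cannot consume NumPy generators directly. A small sketch of what the conversion yields under this commit's placeholder seeding:

import numpy as np

from pytensor.link.pytorch.dispatch import pytorch_typify

rng = np.random.default_rng(123)
state = pytorch_typify(rng)  # dispatches to pytorch_typify_Generator

# The result is the NumPy generator's state dict plus a torch ByteTensor
# that can be restored with torch.Generator().set_state(...).
assert "pytorch_state" in state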
20 changes: 20 additions & 0 deletions tests/link/pytorch/test_random.py
@@ -0,0 +1,20 @@
+import numpy as np
+import pytest
+
+import pytensor.tensor as pt
+from pytensor.compile.function import function
+from pytensor.compile.sharedvalue import shared
+from tests.link.pytorch.test_basic import pytorch_mode
+
+
+torch = pytest.importorskip("torch")
+
+
+@pytest.mark.parametrize("size", [(), (4,)])
+def test_random_bernoulli(size):
+    rng = shared(np.random.default_rng(123))
+
+    g = pt.random.bernoulli(0.5, size=(1000, *size), rng=rng)
+    g_fn = function([], g, mode=pytorch_mode)
+    samples = g_fn()
+    np.testing.assert_allclose(samples.mean(axis=0), 0.5, 1)
