Proposal to infer Torch's generator state from the Numpy one
twaclaw committed Dec 8, 2024
1 parent ff973be commit 1c8dc80
Showing 2 changed files with 71 additions and 27 deletions.
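Note (illustration, not part of the commit): a minimal standalone sketch of the idea being proposed, assuming only NumPy and PyTorch. NumPy's default Generator uses PCG64 while Torch's CPU generator uses Mersenne-Twister, so the bit-generator state cannot be copied across directly; instead, a seed is drawn from the NumPy stream and used to seed a torch.Generator.

    import numpy as np
    import torch

    np_rng = np.random.default_rng(123)
    seed = int(np_rng.integers(2**32))    # draw a 32-bit seed from the NumPy stream
    torch_gen = torch.manual_seed(seed)   # torch.Generator seeded from that draw

    # Draws made with torch_gen are now reproducible given the NumPy generator's state.
    print(torch.bernoulli(torch.full((5,), 0.5), generator=torch_gen))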
47 changes: 23 additions & 24 deletions pytensor/link/pytorch/dispatch/random.py
@@ -4,40 +4,27 @@
 from numpy.random import Generator
 
 import pytensor.tensor.random.basic as ptr
-from pytensor.graph import Constant
 from pytensor.link.pytorch.dispatch.basic import pytorch_funcify, pytorch_typify
-from pytensor.tensor.type_other import NoneTypeT
 
 
 @pytorch_typify.register(Generator)
 def pytorch_typify_Generator(rng, **kwargs):
-    # XXX: Check if there is a better way.
-    # Numpy uses PCG64 while Torch uses Mersenne-Twister (https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/CPUGeneratorImpl.cpp)
     state = rng.__getstate__()
-    state["pytorch_state"] = torch.manual_seed(123).get_state()  # XXX: replace
+    seed = torch.from_numpy(rng.integers([2**32]))
+    state["pytorch_gen"] = torch.manual_seed(seed)
    return state
 
 
 @pytorch_funcify.register(ptr.RandomVariable)
 def torch_funcify_RandomVariable(op: ptr.RandomVariable, node, **kwargs):
     rv = node.outputs[1]
     out_dtype = rv.type.dtype
-    static_shape = rv.type.shape
-    batch_ndim = op.batch_ndim(node)
-
-    # Try to pass static size directly to JAX
-    static_size = static_shape[:batch_ndim]
-    if None in static_size:
-        # Sometimes size can be constant folded during rewrites,
-        # without the RandomVariable node being updated with new static types
-        size_param = op.size_param(node)
-        if isinstance(size_param, Constant) and not isinstance(
-            size_param.type, NoneTypeT
-        ):
-            static_size = tuple(size_param.data)
+    shape = rv.type.shape
 
     def sample_fn(rng, size, *parameters):
-        return pytorch_sample_fn(op, node=node)(
-            rng, static_size, out_dtype, *parameters
-        )
+        return pytorch_sample_fn(op, node=node)(rng, shape, out_dtype, *parameters)
 
     return sample_fn

@@ -53,10 +40,22 @@ def pytorch_sample_fn(op, node):
 @pytorch_sample_fn.register(ptr.BernoulliRV)
 def pytorch_sample_fn_bernoulli(op, node):
     def sample_fn(rng, size, dtype, p):
-        # XXX replace
-        state_ = rng["pytorch_state"]
-        gen = torch.Generator().set_state(state_)
-        sample = torch.bernoulli(torch.expand_copy(p, size), generator=gen)
-        return (rng, sample)
+        gen = rng["pytorch_gen"]
+        sample = torch.bernoulli(torch.broadcast_to(p, size), generator=gen)
+        return (gen, sample)
 
     return sample_fn
+
+
+@pytorch_sample_fn.register(ptr.BinomialRV)
+def pytorch_sample_fn_binomial(op, node):
+    def sample_fn(rng, size, dtype, n, p):
+        gen = rng["pytorch_gen"]
+        sample = torch.binomial(
+            torch.broadcast_to(n.to(p.dtype), size),
+            torch.broadcast_to(p, size),
+            generator=gen,
+        )
+        return (gen, sample)
+
+    return sample_fn
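Note (illustration, not part of the commit): the sample functions above now return the torch.Generator itself rather than the original state dict. A small sketch in plain torch of why that matters -- reusing the returned generator advances its state, so successive draws differ, mirroring NumPy's stateful Generator:

    import torch

    gen = torch.manual_seed(42)
    p = torch.full((3,), 0.5)
    first = torch.bernoulli(p, generator=gen)
    second = torch.bernoulli(p, generator=gen)  # a different draw: gen's state has advanced
    print(first, second)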
51 changes: 48 additions & 3 deletions tests/link/pytorch/test_random.py
@@ -10,11 +10,56 @@
 torch = pytest.importorskip("torch")
 
 
-@pytest.mark.parametrize("size", [(), (4,)])
-def test_random_bernoulli(size):
+@pytest.mark.parametrize(
+    "size,p",
+    [
+        ((1000,), 0.5),
+        (
+            (
+                1000,
+                4,
+            ),
+            0.5,
+        ),
+        ((10, 2), np.array([0.5, 0.3])),
+        ((1000, 10, 2), np.array([0.5, 0.3])),
+    ],
+)
+def test_random_bernoulli(size, p):
     rng = shared(np.random.default_rng(123))
 
-    g = pt.random.bernoulli(0.5, size=(1000, *size), rng=rng)
+    g = pt.random.bernoulli(p, size=size, rng=rng)
     g_fn = function([], g, mode=pytorch_mode)
     samples = g_fn()
     np.testing.assert_allclose(samples.mean(axis=0), 0.5, 1)
+
+
+@pytest.mark.parametrize(
+    "size,n,p",
+    [
+        ((1000,), 10, 0.5),
+        (
+            (
+                1000,
+                4,
+            ),
+            10,
+            0.5,
+        ),
+        (
+            (
+                1000,
+                2,
+            ),
+            np.array([10, 40]),
+            np.array([0.5, 0.3]),
+        ),
+    ],
+)
+def test_binomial(n, p, size):
+    rng = shared(np.random.default_rng(123))
+    g = pt.random.binomial(n, p, size=size, rng=rng)
+    g_fn = function([], g, mode=pytorch_mode)
+    samples = g_fn()
+    np.testing.assert_allclose(samples.mean(axis=0), n * p, rtol=0.1)
+    np.testing.assert_allclose(samples.std(axis=0), np.sqrt(n * p * (1 - p)), rtol=0.1)
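Note (illustration, not part of the commit): the binomial test checks the sample moments against the analytic mean n*p and standard deviation sqrt(n*p*(1-p)). For the n=10, p=0.5 parametrization:

    import numpy as np

    n, p = 10, 0.5
    print(n * p)                     # expected mean: 5.0
    print(np.sqrt(n * p * (1 - p)))  # expected std: ~1.58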