From 1c8dc803a177217142168e6b3bcd7350bdd39c34 Mon Sep 17 00:00:00 2001 From: Diego Sandoval Date: Sun, 8 Dec 2024 13:16:44 +0100 Subject: [PATCH] Proposal to infer Torch's generator state from the Numpy one --- pytensor/link/pytorch/dispatch/random.py | 47 +++++++++++----------- tests/link/pytorch/test_random.py | 51 ++++++++++++++++++++++-- 2 files changed, 71 insertions(+), 27 deletions(-) diff --git a/pytensor/link/pytorch/dispatch/random.py b/pytensor/link/pytorch/dispatch/random.py index 4e03fc4338..57eb8275cf 100644 --- a/pytensor/link/pytorch/dispatch/random.py +++ b/pytensor/link/pytorch/dispatch/random.py @@ -4,15 +4,16 @@ from numpy.random import Generator import pytensor.tensor.random.basic as ptr -from pytensor.graph import Constant from pytensor.link.pytorch.dispatch.basic import pytorch_funcify, pytorch_typify -from pytensor.tensor.type_other import NoneTypeT @pytorch_typify.register(Generator) def pytorch_typify_Generator(rng, **kwargs): + # XXX: Check if there is a better way. + # Numpy uses PCG64 while Torch uses Mersenne-Twister (https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/CPUGeneratorImpl.cpp) state = rng.__getstate__() - state["pytorch_state"] = torch.manual_seed(123).get_state() # XXX: replace + seed = torch.from_numpy(rng.integers([2**32])) + state["pytorch_gen"] = torch.manual_seed(seed) return state @@ -20,24 +21,10 @@ def pytorch_typify_Generator(rng, **kwargs): def torch_funcify_RandomVariable(op: ptr.RandomVariable, node, **kwargs): rv = node.outputs[1] out_dtype = rv.type.dtype - static_shape = rv.type.shape - batch_ndim = op.batch_ndim(node) - - # Try to pass static size directly to JAX - static_size = static_shape[:batch_ndim] - if None in static_size: - # Sometimes size can be constant folded during rewrites, - # without the RandomVariable node being updated with new static types - size_param = op.size_param(node) - if isinstance(size_param, Constant) and not isinstance( - size_param.type, NoneTypeT - ): - static_size = tuple(size_param.data) + shape = rv.type.shape def sample_fn(rng, size, *parameters): - return pytorch_sample_fn(op, node=node)( - rng, static_size, out_dtype, *parameters - ) + return pytorch_sample_fn(op, node=node)(rng, shape, out_dtype, *parameters) return sample_fn @@ -53,10 +40,22 @@ def pytorch_sample_fn(op, node): @pytorch_sample_fn.register(ptr.BernoulliRV) def pytorch_sample_fn_bernoulli(op, node): def sample_fn(rng, size, dtype, p): - # XXX replace - state_ = rng["pytorch_state"] - gen = torch.Generator().set_state(state_) - sample = torch.bernoulli(torch.expand_copy(p, size), generator=gen) - return (rng, sample) + gen = rng["pytorch_gen"] + sample = torch.bernoulli(torch.broadcast_to(p, size), generator=gen) + return (gen, sample) + + return sample_fn + + +@pytorch_sample_fn.register(ptr.BinomialRV) +def pytorch_sample_fn_binomial(op, node): + def sample_fn(rng, size, dtype, n, p): + gen = rng["pytorch_gen"] + sample = torch.binomial( + torch.broadcast_to(n.to(p.dtype), size), + torch.broadcast_to(p, size), + generator=gen, + ) + return (gen, sample) return sample_fn diff --git a/tests/link/pytorch/test_random.py b/tests/link/pytorch/test_random.py index ad2c8897a3..70b0f72102 100644 --- a/tests/link/pytorch/test_random.py +++ b/tests/link/pytorch/test_random.py @@ -10,11 +10,56 @@ torch = pytest.importorskip("torch") -@pytest.mark.parametrize("size", [(), (4,)]) -def test_random_bernoulli(size): +@pytest.mark.parametrize( + "size,p", + [ + ((1000,), 0.5), + ( + ( + 1000, + 4, + ), + 0.5, + ), + ((10, 2), np.array([0.5, 0.3])), + ((1000, 10, 2), np.array([0.5, 0.3])), + ], +) +def test_random_bernoulli(size, p): rng = shared(np.random.default_rng(123)) - g = pt.random.bernoulli(0.5, size=(1000, *size), rng=rng) + g = pt.random.bernoulli(p, size=size, rng=rng) g_fn = function([], g, mode=pytorch_mode) samples = g_fn() np.testing.assert_allclose(samples.mean(axis=0), 0.5, 1) + + +@pytest.mark.parametrize( + "size,n,p", + [ + ((1000,), 10, 0.5), + ( + ( + 1000, + 4, + ), + 10, + 0.5, + ), + ( + ( + 1000, + 2, + ), + np.array([10, 40]), + np.array([0.5, 0.3]), + ), + ], +) +def test_binomial(n, p, size): + rng = shared(np.random.default_rng(123)) + g = pt.random.binomial(n, p, size=size, rng=rng) + g_fn = function([], g, mode=pytorch_mode) + samples = g_fn() + np.testing.assert_allclose(samples.mean(axis=0), n * p, rtol=0.1) + np.testing.assert_allclose(samples.std(axis=0), np.sqrt(n * p * (1 - p)), rtol=0.1)