Implementation of random variables with PyTorch backend #1075
base: main
Conversation
Codecov Report
Attention: Patch coverage is
Additional details and impacted files

@@           Coverage Diff            @@
##             main    #1075    +/-  ##
========================================
  Coverage   82.10%   82.11%
========================================
  Files         185      186        +1
  Lines       48089    48184       +95
  Branches     8659     8673       +14
========================================
+ Hits        39485    39564       +79
- Misses       6439     6452       +13
- Partials     2165     2168        +3
static_shape = rv.type.shape
batch_ndim = op.batch_ndim(node)

# Try to pass static size directly to JAX
nit: pytorch (this comment still says JAX)
# XXX replace
state_ = rng["pytorch_state"]
gen = torch.Generator().set_state(state_)
sample = torch.bernoulli(torch.expand_copy(p, size), generator=gen)
I actually don't mind this approach! Torch has a lot of wrapping and abstraction on top of its random generation, so if we just keep a little bit of state around it feels a bit simpler.
thunk_inputs = []
for n in self.fgraph.inputs:
    sinput = storage_map[n]
    if isinstance(sinput[0], RandomState | Generator):
        new_value = pytorch_typify(
            sinput[0], dtype=getattr(sinput[0], "dtype", None)
Why is this needed?
static_shape = rv.type.shape
batch_ndim = op.batch_ndim(node)

# Try to pass static size directly to JAX
This static size is a JAX limitation that shouldn't exist in PyTorch
state_ = rng["pytorch_state"]
gen = torch.Generator().set_state(state_)
sample = torch.bernoulli(torch.expand_copy(p, size), generator=gen)
return (rng, sample)
It should return a new state, otherwise the draws will be the same the next time it's evaluated
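A minimal sketch of the fix, assuming the dict-of-state representation used above:

state_ = rng["pytorch_state"]
gen = torch.Generator().set_state(state_)
sample = torch.bernoulli(torch.expand_copy(p, size), generator=gen)
# gen has advanced past this draw; hand the new state back to the caller
rng["pytorch_state"] = gen.get_state()
return (rng, sample)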
# XXX replace
state_ = rng["pytorch_state"]
gen = torch.Generator().set_state(state_)
sample = torch.bernoulli(torch.expand_copy(p, size), generator=gen)
Shouldn't it just broadcast? Why copy?
Suggested change:
- sample = torch.bernoulli(torch.expand_copy(p, size), generator=gen)
+ sample = torch.bernoulli(torch.broadcast_to(p, size), generator=gen)
Force-pushed from 85d6080 to 1c8dc80
@pytorch_typify.register(Generator)
def pytorch_typify_Generator(rng, **kwargs):
    # XXX: Check if there is a better way.
    # Numpy uses PCG64 while Torch uses Mersenne-Twister (https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/CPUGeneratorImpl.cpp)
fwiw, this depends on the device iiuc: https://open.spotify.com/episode/13oJCmQ2JWbk7t6sLRWlDz?si=1f43e67353284cc7
def pytorch_typify(data, dtype=None, **kwargs):
    if dtype is None:
        return data
    else:
        return torch.tensor(data, dtype=dtype)
We need to change this approach. You should dispatch on the RNG type and decide what to do with it; the base case is to raise.
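A rough sketch of that dispatch structure, with the base case raising (names and registrations are illustrative, not this PR's final code):

from functools import singledispatch

import torch
from numpy.random import Generator


@singledispatch
def pytorch_typify(data, **kwargs):
    # base case: refuse types we don't know how to convert
    raise NotImplementedError(f"pytorch_typify not implemented for {type(data)}")


@pytorch_typify.register(Generator)
def pytorch_typify_Generator(rng, **kwargs):
    # decide here how a NumPy Generator maps to torch-side RNG state
    ...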
# XXX: Check if there is a better way.
# Numpy uses PCG64 while Torch uses Mersenne-Twister (https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/CPUGeneratorImpl.cpp)
state = rng.__getstate__()
seed = torch.from_numpy(rng.integers([2**32]))
You have to copy the rng before calling rng.integers; we don't want to modify the original one.
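Something like the following; copy.deepcopy is one safe way to duplicate a NumPy Generator without advancing the original:

import copy

rng = copy.deepcopy(rng)  # draw the seed from a copy, not the caller's rng
seed = torch.from_numpy(rng.integers([2**32]))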
def sample_fn(rng, size, *parameters):
    return pytorch_sample_fn(op, node=node)(rng, shape, out_dtype, *parameters)


return sample_fn
Call pytorch_sample_fn outside of sample_fn.
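That is, hoist the dispatch so it runs once when the thunk is built rather than on every call, roughly:

inner_sample_fn = pytorch_sample_fn(op, node=node)  # dispatched once

def sample_fn(rng, size, *parameters):
    return inner_sample_fn(rng, shape, out_dtype, *parameters)

return sample_fn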
def pytorch_sample_fn_bernoulli(op, node):
    def sample_fn(rng, size, dtype, p):
        gen = rng["pytorch_gen"]
        sample = torch.bernoulli(torch.broadcast_to(p, size), generator=gen)
Size may be None
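A sketch of guarding for that: when size is None, sample at p's own shape instead of broadcasting.

if size is None:
    sample = torch.bernoulli(p, generator=gen)
else:
    sample = torch.bernoulli(torch.broadcast_to(p, size), generator=gen)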
sample = torch.binomial(
    torch.broadcast_to(n.to(p.dtype), size),
    torch.broadcast_to(p, size),
    generator=gen,
)
return (gen, sample)
size may be None, in which case you should do n, p = torch.broadcast_arrays(n, p), or whatever it's called (in torch it's broadcast_tensors).
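A sketch of the suggested branch, using torch.broadcast_tensors and the dict-of-state rng representation from this PR:

def sample_fn(rng, size, dtype, n, p):
    gen = rng["pytorch_gen"]
    if size is None:
        # broadcast the parameters against each other
        n, p = torch.broadcast_tensors(n, p)
    else:
        n = torch.broadcast_to(n, size)
        p = torch.broadcast_to(p, size)
    sample = torch.binomial(n.to(p.dtype), p, generator=gen)
    return (gen, sample)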
def sample_fn(rng, size, dtype, n, p):
    gen = rng["pytorch_gen"]
    sample = torch.binomial(
        torch.broadcast_to(n.to(p.dtype), size),
Why are you converting n to the type of p?
@@ -84,9 +86,16 @@ def fn(*inputs, inner_fn=inner_fn):
     return fn


 def create_thunk_inputs(self, storage_map):
     from pytensor.link.pytorch.dispatch import pytorch_typify
You'll need to copy the logic with SharedVariables in JAX to emit a warning and use different variables. You can refactor the logic so it's not duplicated.
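For reference, the JAX-linker pattern being referred to looks roughly like this (a sketch, not the exact JAX linker code):

import warnings
from copy import deepcopy


def create_thunk_inputs(self, storage_map):
    thunk_inputs = []
    for n in self.fgraph.inputs:
        sinput = storage_map[n]
        if isinstance(sinput[0], Generator):
            warnings.warn(
                f"The RandomType SharedVariable {n} will not be used "
                "in the compiled graph; a copy of it will be used instead.",
                UserWarning,
            )
            sinput = [deepcopy(sinput[0])]
        thunk_inputs.append(sinput)
    return thunk_inputs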
tests/link/pytorch/test_random.py (outdated)
    4,
),
10,
0.5,
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If you take out some of these trailing commas, pre-commit won't force it to be multi-line, which is very unreadable here.
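For illustration only (the values and parameter order here are hypothetical): without the inner trailing commas, formatters keep each case on a single line.

import pytest

@pytest.mark.parametrize(
    "n, p, size",
    [(10, 0.5, None), (10, 0.5, (2, 4))],  # one case per line, not one value per line
)
def test_binomial(n, p, size): ...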
    ],
)
def test_binomial(n, p, size):
    rng = shared(np.random.default_rng(123))
We need tests that confirm the original rng was not affected
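Sketch of such a check, reusing this test's setup; comparing bit_generator.state against a fresh Generator with the same seed is one way to assert nothing advanced:

rng = shared(np.random.default_rng(123))
g = pt.random.binomial(n, p, size=size, rng=rng)
g_fn = function([], g, mode=pytorch_mode)
g_fn()
# the shared NumPy rng must not have been mutated by compilation or evaluation
assert rng.get_value().bit_generator.state == np.random.default_rng(123).bit_generator.state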
rng = shared(np.random.default_rng(123))
g = pt.random.binomial(n, p, size=size, rng=rng)
g_fn = function([], g, mode=pytorch_mode)
samples = g_fn()
You should call it twice. In this case, because you did not set updates, you should get the same draws back. See https://pytensor.readthedocs.io/en/latest/tutorial/prng.html for details.
You should also test with updates separately.
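Roughly, following the linked PRNG tutorial (the next_rng/updates spelling is the standard pytensor pattern, not this PR's code):

# no updates: the compiled function is pure, so both calls return the same draws
np.testing.assert_allclose(g_fn(), g_fn())

# with updates: the shared rng advances, so consecutive calls differ
next_rng, x = pt.random.binomial(n, p, size=size, rng=rng).owner.outputs
fn = function([], x, updates={rng: next_rng}, mode=pytorch_mode)
assert not np.all(fn() == fn())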
- Copied generator before sampling from it
Description
Related Issue
Checklist
Type of change
📚 Documentation preview 📚: https://pytensor--1075.org.readthedocs.build/en/1075/