Add recipes.forward_filter_backward_rsample() #549
Conversation
All of the relevant downstream tests in NumPyro pass under these changes (on my local machine, at least), but the changes to the behavior of … Since we don't pin Pyro …
Looks good per Zoom reviews, just a few nits.
# Eagerly convert reduced_vars to appropriate things.
if reduced_vars is None:
    # Empty reduced_vars means "reduce over everything".
    reduced_vars = frozenset(Variable(k, v) for k, v in self.inputs.items())
else:
    reduced_vars = _convert_reduced_vars(reduced_vars, self.inputs)
assert isinstance(reduced_vars, frozenset), reduced_vars

# Attempt to convert ReductionOp to AssociativeOp.
if isinstance(op, ops.ReductionOp):
Do we want this standardization logic to run under all interpretations?
Yes, I see .reduce() as lightweight syntax around deeper reinterpretable syntax including Reduce. Note it is nonsensical to create a lazy Reduce of a ReductionOp, the way we have defined Reduce.
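For concreteness, here is a minimal sketch of the equivalence this standardization provides, assuming the .reduce(ops.mean, ...) support added in this PR; the exact internal code path may differ:

import torch

import funsor
from funsor import ops
from funsor.testing import assert_close

funsor.set_backend("torch")

# A funsor with a single discrete input i: Bint[4].
x = funsor.Tensor(torch.randn(4))["i"]

# The non-associative mean reduction standardizes to an associative add
# reduction followed by division by the size of the eliminated domain.
assert_close(x.reduce(ops.mean, "i"), x.reduce(ops.add, "i") / 4)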
test/test_approximations.py (outdated)
@@ -87,7 +108,9 @@ def test_tensor_linear(approximate):
    q1 = m1.approximate(ops.logaddexp, guide, "x")
    q2 = m2.approximate(ops.logaddexp, guide, "x")
    actual = q1 + s * q2
    assert_close(actual, expected)

    if approximate not in (monte_carlo, monte_carlo_10):
Nit: this will report the test as passing under monte_carlo; would it be better to keep the xfail status?
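For reference, a hedged sketch of the suggested alternative; the helper name and arguments are illustrative, not this test file's actual code:

import pytest
from funsor.testing import assert_close

def check_approximation(approximate, actual, expected, stochastic_interps):
    # Report an expected failure rather than a silent pass, so the case
    # flips to XPASS (and gets noticed) once the Monte Carlo path is fixed.
    if approximate in stochastic_interps:
        pytest.xfail("Monte Carlo approximations are only correct in expectation")
    assert_close(actual, expected)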
I've just removed these tests as I don't understand what they do.
test/test_approximations.py (outdated)
@@ -110,7 +134,9 @@ def test_gaussian_linear(approximate):
    q1 = m1.approximate(ops.logaddexp, guide, "x")
    q2 = m2.approximate(ops.logaddexp, guide, "x")
    actual = q1 + s * q2
    assert_close(actual, expected)

    if approximate not in (monte_carlo, monte_carlo_10):
ditto: should this xfail rather than pass?
oh right, I was going to strengthen these tests once we fixed adjoint, will do...
removed.
def check_ffbr(factors, eliminate, plates, actual_samples, actual_log_prob):
    """
    This can be seen as performing naive tensor variable elimination by
    breaking all plates and creating a single flat joint distribution.
If this actually turns out to be true, at some point we should move this logic into funsor.sum_product and test it against existing implementations, since it's somewhat complex for testing logic and it is used to check the correctness of what will presumably end up as one of our most important code paths if/when AutoGaussian is finished.
SGTM, it seems reasonable to eventually factor this out as a naive_sum_product or something. Before we refactor, we should fix the intractable tests below, as that will involve more lambdas etc. I think we can safely make the following sequence of changes: …
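A hedged sketch of what such a factored-out helper might look like; the name naive_sum_product comes from this thread, and the signature is an assumption:

import functools

def naive_sum_product(sum_op, prod_op, factors, eliminate):
    # Combine all factors into one flat joint funsor, ignoring plate
    # structure entirely, then reduce out the eliminated variables.
    joint = functools.reduce(prod_op, factors)
    return joint.reduce(sum_op, eliminate.intersection(joint.inputs))

Tests could then compare this against funsor.sum_product.sum_product(sum_op, prod_op, factors, eliminate, plates) on small examples.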
That sounds reasonable, provided the fixes necessary in …
Addresses pyro-ppl/pyro#2929

This implements a multi-sample forward_filter_backward_rsample() for use in Pyro's AutoGaussian guide.
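A hedged usage sketch follows; the exact signature (including the sample_inputs argument for batched sampling) is inferred from this PR's tests and may differ:

from collections import OrderedDict

import funsor
import funsor.recipes
from funsor import Bint, Real
from funsor.testing import random_gaussian

funsor.set_backend("torch")

# A tiny two-factor chain p(x) p(y | x) with no plates.
factors = {
    "x": random_gaussian(OrderedDict(x=Real)),
    "y": random_gaussian(OrderedDict(x=Real, y=Real)),
}

# Draw a batch of 8 reparametrized samples plus their joint log density.
samples, log_prob = funsor.recipes.forward_filter_backward_rsample(
    factors=factors,
    eliminate=frozenset(["x", "y"]),
    plates=frozenset(),
    sample_inputs={"particle": Bint[8]},
)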
Changes

- Changes the semantics of MonteCarlo. These changes preserve the semantics of single-sample MonteCarlo() but change the semantics of MonteCarlo(particles=Bint[n]) from mean-reducing over particles to introducing a new batch dimension over particles.
- Changes Funsor.sample() to avoid scaling by numel(sampled_inputs); correspondingly, .unscaled_sample() is renamed to ._sample() (see the sketch after this list).
- Adds support for .reduce(ops.mean, ...). This breaks from .reduce(op, ...) supporting only associative ops, but this does seem like the cleanest syntax to support ReductionOps over discrete input variables, which will be an important pattern now that the 1/numel scaling is no longer performed by .sample().
- Adds a funsor.recipes module with high-level algorithms intended for use in both Pyro and NumPyro. The idea is to maximize test sharing of these recipes by testing all backends in the funsor repo.
- Adds a batch_vars arg to AdjointTape to support batched backward sample (this might be simplified by Constant Funsor #548).
- Splits out a forward_sample() function from adjoint().
.Tested