
Add recipes.forward_filter_backward_rsample() #549

Merged: 21 commits, Sep 22, 2021

Conversation

@fritzo fritzo commented Sep 17, 2021

Addresses pyro-ppl/pyro#2929

This implements a multi-sample forward_filter_backward_rsample() for use in Pyro's AutoGaussian guide.

Changes

  • Modifies the semantics of MonteCarlo. These changes preserve the semantics of single-sample MonteCarlo() but change MonteCarlo(particles=Bint[n]) from mean-reducing over particles to introducing a new batch dimension over particles.
  • Changes the semantics of Funsor.sample() to avoid scaling by numel(sampled_inputs). Correspondingly, .unscaled_sample() is renamed to ._sample().
  • Supports .reduce(ops.mean, ...). This breaks the convention that .reduce(op, ...) accepts only associative ops, but it does seem like the cleanest syntax for ReductionOps over discrete input variables, which will be an important pattern now that the 1/numel scaling is no longer performed by .sample().
  • Adds a new funsor.recipes module with high-level algorithms intended for use in both Pyro and NumPyro. The idea is to maximize test sharing of these recipes by testing all backends in the funsor repo (see the usage sketch after this list).
  • Adds a batch_vars arg to AdjointTape to support batched backward sampling (this might be simplified by Constant Funsor #548).
  • Factors out a new forward_sample() function from adjoint().
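
For orientation, here is a rough sketch of how the new recipe might be invoked. The call signature (positional factors/eliminate/plates plus a sample_inputs keyword), the "particle" variable name, and the use of funsor.testing.random_gaussian as stand-in factors are assumptions for illustration, not taken verbatim from this PR's diff:

from collections import OrderedDict

import funsor
from funsor import Bint, Real
from funsor.recipes import forward_filter_backward_rsample
from funsor.testing import random_gaussian

funsor.set_backend("torch")

# Two latent variables x and y with a chain dependency, written as
# unnormalized Gaussian factors (random test factors stand in for a model).
factors = {
    "x": random_gaussian(OrderedDict(x=Real)),
    "y": random_gaussian(OrderedDict(x=Real, y=Real)),
}
eliminate = frozenset(["x", "y"])
plates = frozenset()

# Draw a batch of 8 reparametrized samples. Under the new MonteCarlo
# semantics, "particle" becomes an extra batch dimension on the results
# rather than being mean-reduced away.
# NOTE: argument names below are assumed for illustration.
samples, log_prob = forward_filter_backward_rsample(
    factors, eliminate, plates, sample_inputs={"particle": Bint[8]}
)
# samples maps "x" and "y" to funsors with a "particle" input;
# log_prob is the joint log density of each sampled particle.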

Tested

@fritzo fritzo added the WIP label Sep 17, 2021
@fritzo fritzo mentioned this pull request Sep 21, 2021
@fritzo fritzo changed the title Change semantics of multi-sample MonteCarlo to support AutoGaussian Add recipes.forward_filter_backward_rsample() Sep 22, 2021
@fritzo fritzo requested a review from eb8680 September 22, 2021 16:00
@fritzo fritzo marked this pull request as ready for review September 22, 2021 16:00
@fritzo fritzo added the enhancement New feature or request label Sep 22, 2021

eb8680 commented Sep 22, 2021

All of the relevant downstream tests in NumPyro pass under these changes (on my local machine, at least), but the changes to the behavior of .sample() seem to have broken a bunch of downstream tests in pyro.contrib.funsor.

Since we don't pin Pyro dev to Funsor master, one way to proceed would be for me to merge this and then put up a Pyro pull request with fixes that we can merge on the next Funsor or Pyro release. Does that sound reasonable?

@eb8680 eb8680 left a comment

Looks good per Zoom reviews, just a few nits.

funsor/terms.py (Outdated)
# Eagerly convert reduced_vars to appropriate things.
if reduced_vars is None:
    # Empty reduced_vars means "reduce over everything".
    reduced_vars = frozenset(Variable(k, v) for k, v in self.inputs.items())
else:
    reduced_vars = _convert_reduced_vars(reduced_vars, self.inputs)
assert isinstance(reduced_vars, frozenset), reduced_vars

# Attempt to convert ReductionOp to AssociativeOp.
if isinstance(op, ops.ReductionOp):

Member

Do we want this standardization logic to run under all interpretations?

@fritzo fritzo (Member Author) Sep 22, 2021

Yes, I see .reduce() as lightweight syntax around deeper reinterpretable syntax, including Reduce. Note that it is nonsensical to create a lazy Reduce of a ReductionOp, the way we have defined Reduce.
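
As a concrete illustration of that desugaring (a minimal sketch assuming the torch backend; the sum-then-divide rewrite is my reading of the standardization, not a quote of this PR's implementation):

from collections import OrderedDict

import torch

import funsor
import funsor.ops as ops
from funsor import Bint, Tensor
from funsor.testing import assert_close

funsor.set_backend("torch")

x = Tensor(torch.randn(3, 4), OrderedDict(i=Bint[3], j=Bint[4]))

# Assumed behavior: the ReductionOp is standardized away eagerly, so no lazy
# Reduce node over ops.mean is ever built. Numerically the result agrees with
# an explicit sum over "i" divided by the size of the reduced input.
assert_close(x.reduce(ops.mean, "i"), x.reduce(ops.add, "i") / x.inputs["i"].size)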

@@ -87,7 +108,9 @@ def test_tensor_linear(approximate):
    q1 = m1.approximate(ops.logaddexp, guide, "x")
    q2 = m2.approximate(ops.logaddexp, guide, "x")
    actual = q1 + s * q2
    assert_close(actual, expected)

    if approximate not in (monte_carlo, monte_carlo_10):

Member

Nit: this will report the test as passing under monte_carlo; would it be better to keep the xfail status?

Member Author

I've just removed these tests, as I don't understand what they do.

@@ -110,7 +134,9 @@ def test_gaussian_linear(approximate):
    q1 = m1.approximate(ops.logaddexp, guide, "x")
    q2 = m2.approximate(ops.logaddexp, guide, "x")
    actual = q1 + s * q2
    assert_close(actual, expected)

    if approximate not in (monte_carlo, monte_carlo_10):

Member

ditto: should this xfail rather than pass?

Member Author

oh right, I was going to strengthen these tests once we fixed adjoint, will do...

Member Author

removed.

def check_ffbr(factors, eliminate, plates, actual_samples, actual_log_prob):
    """
    This can be seen as performing naive tensor variable elimination by
    breaking all plates and creating a single flat joint distribution.

Member

If this actually turns out to be true, at some point we should move this logic into funsor.sum_product and test it against existing implementations: it is somewhat complex for test logic, and it is used to check the correctness of what will presumably end up as one of our most important code paths if/when AutoGaussian is finished.

Member Author

SGTM, it seems reasonable to eventually factor this out as a naive_sum_product or something similar. Before we refactor, we should fix the intractable tests below, as that will involve more lambdas etc.

fritzo commented Sep 22, 2021

one way to proceed would be for me to merge this and then put up a Pyro pull request with fixes

I think we can safely make the following sequence of changes:

  1. merge this PR
  2. fix Pyro dev and pin to a particular Funsor commit
  3. ...add AutoGaussian features via many commits to Funsor and Pyro...
  4. release Funsor
  5. pin Pyro to the Funsor release (required for Pyro releases)
  6. release Pyro

eb8680 commented Sep 22, 2021

That sounds reasonable, provided the fixes necessary in pyro.contrib.funsor aren't too onerous; I don't want to block AutoGaussian on that if it can be avoided.
