diff --git a/docs/source/index.rst b/docs/source/index.rst index 614cd34f..6d1f63b5 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -28,6 +28,7 @@ Funsor is a tensor-like library for functions and distributions :maxdepth: 2 :caption: Interfaces: + recipes pyro distributions minipyro diff --git a/docs/source/recipes.rst b/docs/source/recipes.rst new file mode 100644 index 00000000..05c5e361 --- /dev/null +++ b/docs/source/recipes.rst @@ -0,0 +1,4 @@ +.. automodule:: funsor.recipes + :members: + :show-inheritance: + :member-order: bysource diff --git a/funsor/__init__.py b/funsor/__init__.py index 6bb08650..4f7a848c 100644 --- a/funsor/__init__.py +++ b/funsor/__init__.py @@ -42,6 +42,7 @@ joint, montecarlo, ops, + recipes, sum_product, terms, testing, @@ -96,6 +97,7 @@ "pretty", "quote", "reals", + "recipes", "reinterpret", "set_backend", # 'minipyro', # TODO: enable when minipyro is backend-agnostic diff --git a/funsor/adjoint.py b/funsor/adjoint.py index bdae045a..a573e448 100644 --- a/funsor/adjoint.py +++ b/funsor/adjoint.py @@ -67,7 +67,9 @@ def __enter__(self): self._old_interpretation = interpreter.get_interpretation() return super().__enter__() - def adjoint(self, sum_op, bin_op, root, targets=None): + def adjoint(self, sum_op, bin_op, root, targets=None, *, batch_vars=frozenset()): + # TODO Replace this with root + Constant(...) after #548 merges. + root_vars = root.input_vars | batch_vars zero = to_funsor(ops.UNITS[sum_op]) one = to_funsor(ops.UNITS[bin_op]) @@ -115,7 +117,9 @@ def adjoint(self, sum_op, bin_op, root, targets=None): in_adjs = adjoint_ops(fn, sum_op, bin_op, adjoint_values[output], *inputs) for v, adjv in in_adjs: - agg_vars = adjv.input_vars - v.input_vars - root.input_vars + # Marginalize out message variables that don't appear in recipients. + agg_vars = adjv.input_vars - v.input_vars - root_vars + assert "particle" not in {var.name for var in agg_vars} # DEBUG FIXME old_value = adjoint_values[v] adjoint_values[v] = sum_op(old_value, adjv.reduce(sum_op, agg_vars)) @@ -129,11 +133,17 @@ def adjoint(self, sum_op, bin_op, root, targets=None): return {target: result[target] for target in targets} -def adjoint(sum_op, bin_op, expr): +def forward_backward(sum_op, bin_op, expr, *, batch_vars=frozenset()): with AdjointTape() as tape: # TODO fix traversal order in AdjointTape instead of using stack_reinterpret - root = stack_reinterpret(expr) - return tape.adjoint(sum_op, bin_op, root) + forward = stack_reinterpret(expr) + backward = tape.adjoint(sum_op, bin_op, forward, batch_vars=batch_vars) + return forward, backward + + +def adjoint(sum_op, bin_op, expr): + forward, backward = forward_backward(sum_op, bin_op, expr) + return backward # logaddexp/add diff --git a/funsor/approximations.py b/funsor/approximations.py index bc2d8134..9ae1caed 100644 --- a/funsor/approximations.py +++ b/funsor/approximations.py @@ -66,7 +66,7 @@ def laplace_approximate_logaddexp(op, model, guide, approx_vars): ################################################################################ # Computations. # TODO Consider either making these Funsor methods or making .sample() and -# .unscaled_sample() singledispatch functions. +# ._sample() singledispatch functions. 
@singledispatch diff --git a/funsor/cnf.py b/funsor/cnf.py index 39e82e1e..de47d0df 100644 --- a/funsor/cnf.py +++ b/funsor/cnf.py @@ -102,38 +102,31 @@ def __str__(self): ) return super().__str__() - def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None): + def _sample(self, sampled_vars, sample_inputs, rng_key): sampled_vars = sampled_vars.intersection(self.inputs) if not sampled_vars: return self if self.red_op in (ops.null, ops.logaddexp): - if self.bin_op in (ops.null, ops.logaddexp): - if rng_key is not None and get_backend() == "jax": - import jax + if rng_key is not None and get_backend() == "jax": + import jax - rng_keys = jax.random.split(rng_key, len(self.terms)) - else: - rng_keys = [None] * len(self.terms) + rng_keys = jax.random.split(rng_key, len(self.terms)) + else: + rng_keys = [None] * len(self.terms) + if self.bin_op in (ops.null, ops.logaddexp): # Design choice: we sample over logaddexp reductions, but leave logaddexp # binary choices symbolic. terms = [ - term.unscaled_sample( - sampled_vars.intersection(term.inputs), sample_inputs + term._sample( + sampled_vars.intersection(term.inputs), sample_inputs, rng_key ) for term, rng_key in zip(self.terms, rng_keys) ] return Contraction(self.red_op, self.bin_op, self.reduced_vars, *terms) if self.bin_op is ops.add: - if rng_key is not None and get_backend() == "jax": - import jax - - rng_keys = jax.random.split(rng_key) - else: - rng_keys = [None] * 2 - # Sample variables greedily in order of the terms in which they appear. for term in self.terms: greedy_vars = sampled_vars.intersection(term.inputs) @@ -146,9 +139,7 @@ def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None): ).append(term) if len(greedy_terms) == 1: term = greedy_terms[0] - terms.append( - term.unscaled_sample(greedy_vars, sample_inputs, rng_keys[0]) - ) + terms.append(term._sample(greedy_vars, sample_inputs, rng_keys[0])) result = Contraction( self.red_op, self.bin_op, self.reduced_vars, *terms ) @@ -161,9 +152,7 @@ def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None): term = discrete + gaussian.log_normalizer terms.append(gaussian) terms.append(-gaussian.log_normalizer) - terms.append( - term.unscaled_sample(greedy_vars, sample_inputs, rng_keys[0]) - ) + terms.append(term._sample(greedy_vars, sample_inputs, rng_keys[0])) result = Contraction( self.red_op, self.bin_op, self.reduced_vars, *terms ) @@ -173,10 +162,12 @@ def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None): for term in greedy_terms ): sampled_terms = [ - term.unscaled_sample( - greedy_vars.intersection(term.value.inputs), sample_inputs + term._sample( + greedy_vars.intersection(term.value.inputs), + sample_inputs, + rng_key, ) - for term in greedy_terms + for term, rng_key in zip(greedy_terms, rng_keys) if isinstance(term, funsor.distribution.Distribution) and not greedy_vars.isdisjoint(term.value.inputs) ] @@ -192,7 +183,7 @@ def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None): ", ".join(str(type(t)) for t in greedy_terms) ) ) - return result.unscaled_sample( + return result._sample( sampled_vars - greedy_vars, sample_inputs, rng_keys[1] ) diff --git a/funsor/delta.py b/funsor/delta.py index ddc3d296..aa236ba4 100644 --- a/funsor/delta.py +++ b/funsor/delta.py @@ -200,7 +200,7 @@ def eager_reduce(self, op, reduced_vars): return None # defer to default implementation - def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None): + def _sample(self, sampled_vars, sample_inputs, rng_key): return self diff --git 
a/funsor/distribution.py b/funsor/distribution.py index 6a1e1939..a61538e5 100644 --- a/funsor/distribution.py +++ b/funsor/distribution.py @@ -206,7 +206,7 @@ def eager_log_prob(cls, *params): inputs.update(x.inputs) return log_prob.align(tuple(inputs)) - def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None): + def _sample(self, sampled_vars, sample_inputs, rng_key): # note this should handle transforms correctly via distribution_to_data raw_dist, value_name, value_output, dim_to_name = self._get_raw_dist() diff --git a/funsor/gaussian.py b/funsor/gaussian.py index 4cfb7b25..1fa6f583 100644 --- a/funsor/gaussian.py +++ b/funsor/gaussian.py @@ -688,7 +688,7 @@ def eager_reduce(self, op, reduced_vars): return None # defer to default implementation - def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None): + def _sample(self, sampled_vars, sample_inputs, rng_key): sampled_vars = sampled_vars.intersection(self.inputs) if not sampled_vars: return self diff --git a/funsor/montecarlo.py b/funsor/montecarlo.py index e0669640..06d66961 100644 --- a/funsor/montecarlo.py +++ b/funsor/montecarlo.py @@ -1,11 +1,15 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 +import functools from collections import OrderedDict +from funsor.cnf import Contraction +from funsor.delta import Delta from funsor.integrate import Integrate from funsor.interpretations import StatefulInterpretation -from funsor.terms import Approximate, Funsor +from funsor.tensor import Tensor +from funsor.terms import Approximate, Funsor, Number from funsor.util import get_backend from . import ops @@ -36,9 +40,7 @@ def monte_carlo_integrate(state, log_measure, integrand, reduced_vars): sample = log_measure.sample(reduced_vars, state.sample_inputs, **sample_options) if sample is log_measure: return None # cannot progress - reduced_vars |= frozenset( - v for v in sample.input_vars if v.name in state.sample_inputs - ) + return Integrate(sample, integrand, reduced_vars) @@ -53,13 +55,43 @@ def monte_carlo_approximate(state, op, model, guide, approx_vars): sample = guide.sample(approx_vars, state.sample_inputs, **sample_options) if sample is guide: return model # cannot progress - reduced_vars = frozenset( - v for v in sample.input_vars if v.name in state.sample_inputs + result = sample + model - guide + + return result + + +@functools.singledispatch +def extract_samples(discrete_density): + """ + Extract sample values out of a funsor Delta, possibly scaled by Tensors. + This is useful for extracting sample tensors from a Monte Carlo + computation. + """ + raise ValueError( + f"Could not extract support from {type(discrete_density).__name__}" ) - result = (sample + model - guide).reduce(op, reduced_vars) + + +@extract_samples.register(Delta) +def _extract_samples_delta(discrete_density): + return {name: point for name, (point, log_density) in discrete_density.terms} + + +@extract_samples.register(Contraction) +def _extract_samples_contraction(discrete_density): + assert not discrete_density.reduced_vars + result = {} + for term in discrete_density.terms: + result.update(extract_samples(term)) return result +@extract_samples.register(Number) +@extract_samples.register(Tensor) +def _extract_samples_scale(discrete_density): + return {} + + __all__ = [ "MonteCarlo", ] diff --git a/funsor/recipes.py b/funsor/recipes.py new file mode 100644 index 00000000..1e52b2ee --- /dev/null +++ b/funsor/recipes.py @@ -0,0 +1,84 @@ +# Copyright Contributors to the Pyro project. 
+# SPDX-License-Identifier: Apache-2.0
+
+"""
+Recipes using Funsor
+--------------------
+This module provides a number of high-level algorithms using Funsor.
+
+"""
+
+from typing import Dict, FrozenSet
+
+import funsor  # Let's use fully qualified names in this file.
+
+
+def forward_filter_backward_rsample(
+    factors: Dict[str, funsor.Funsor],
+    eliminate: FrozenSet[str],
+    plates: FrozenSet[str],
+    sample_inputs: Dict[str, funsor.domains.Domain] = {},
+    rng_key=None,
+):
+    """
+    A forward-filter backward-batched-reparametrized-sample algorithm for use
+    in variational inference. The motivating use case is performing Gaussian
+    tensor variable elimination over structured variational posteriors.
+
+    :param dict factors: A dictionary mapping sample site name to a Funsor
+        factor created at that sample site.
+    :param frozenset eliminate: A set of names of latent variables to
+        marginalize and plates to aggregate.
+    :param frozenset plates: A set of names of plates to aggregate.
+    :param dict sample_inputs: An optional dict of enclosing sample indices
+        over which samples will be drawn in batch.
+    :param rng_key: A random number key for the JAX backend.
+    :returns: A pair ``samples: Dict[str, Tensor], log_prob: Tensor`` of
+        samples and log density evaluated at each of those samples. If
+        ``sample_inputs`` is nonempty, both outputs will be batched.
+    :rtype: tuple
+    """
+    assert isinstance(factors, dict)
+    assert all(isinstance(k, str) for k in factors)
+    assert all(isinstance(v, funsor.Funsor) for v in factors.values())
+    assert isinstance(eliminate, frozenset)
+    assert all(isinstance(v, str) for v in eliminate)
+    assert isinstance(plates, frozenset)
+    assert all(isinstance(v, str) for v in plates)
+    assert isinstance(sample_inputs, dict)
+    assert all(isinstance(k, str) for k in sample_inputs)
+    assert all(isinstance(v, funsor.domains.Domain) for v in sample_inputs.values())
+
+    # Perform tensor variable elimination.
+    with funsor.interpretations.reflect:
+        log_Z = funsor.sum_product.sum_product(
+            funsor.ops.logaddexp,
+            funsor.ops.add,
+            list(factors.values()),
+            eliminate,
+            plates,
+        )
+        log_Z = funsor.optimizer.apply_optimizer(log_Z)
+    batch_vars = frozenset(funsor.Variable(k, v) for k, v in sample_inputs.items())
+    with funsor.montecarlo.MonteCarlo(**sample_inputs, rng_key=rng_key):
+        log_Z, marginals = funsor.adjoint.forward_backward(
+            funsor.ops.logaddexp, funsor.ops.add, log_Z, batch_vars=batch_vars
+        )
+
+    # Extract sample tensors.
+    samples = {}
+    for name, factor in factors.items():
+        if name in eliminate:
+            samples.update(funsor.montecarlo.extract_samples(marginals[factor]))
+    assert frozenset(samples) == eliminate - plates
+
+    # Compute log density at each sample.
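+    # The joint density of a drawn sample under the unnormalized model is the
+    # sum of all factors evaluated at that sample, with any remaining plate
+    # dimensions summed out; subtracting log_Z converts this to a normalized
+    # log density. When sample_inputs is nonempty, log_prob is batched over
+    # those inputs.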
+ log_prob = -log_Z + for f in factors.values(): + term = f(**samples) + plates = eliminate.intersection(term.inputs) + term = term.reduce(funsor.ops.add, plates) + log_prob += term + assert set(log_prob.inputs) == set(sample_inputs) + + return samples, log_prob diff --git a/funsor/tensor.py b/funsor/tensor.py index 1c5853d3..204c0aaf 100644 --- a/funsor/tensor.py +++ b/funsor/tensor.py @@ -329,7 +329,7 @@ def eager_reduce(self, op, reduced_vars): return Tensor(data, inputs, dtype) return super(Tensor, self).eager_reduce(op, reduced_vars) - def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None): + def _sample(self, sampled_vars, sample_inputs, rng_key): assert self.output == Real sampled_vars = sampled_vars.intersection(self.inputs) if not sampled_vars: diff --git a/funsor/terms.py b/funsor/terms.py index 1041a56f..354921ba 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -3,7 +3,6 @@ import functools import itertools -import math import numbers import typing import warnings @@ -376,12 +375,13 @@ def reduce(self, op, reduced_vars=None): Reduce along all or a subset of inputs. :param op: A reduction operation. - :type op: ~funsor.ops.AssociativeOp + :type op: ~funsor.ops.AssociativeOp or ~funsor.ops.ReductionOp :param reduced_vars: An optional input name or set of names to reduce. If unspecified, all inputs will be reduced. :type reduced_vars: str, Variable, or set or frozenset thereof. """ - assert isinstance(op, AssociativeOp) + assert isinstance(op, (AssociativeOp, ops.ReductionOp)) + # Eagerly convert reduced_vars to appropriate things. if reduced_vars is None: # Empty reduced_vars means "reduce over everything". @@ -389,6 +389,23 @@ def reduce(self, op, reduced_vars=None): else: reduced_vars = _convert_reduced_vars(reduced_vars, self.inputs) assert isinstance(reduced_vars, frozenset), reduced_vars + + # Attempt to convert ReductionOp to AssociativeOp. + if isinstance(op, ops.ReductionOp): + if isinstance(op, ops.MeanOp): + reduced_vars &= self.input_vars + if not reduced_vars: + return self + scale = 1 / reduce(ops.mul, [v.output.size for v in reduced_vars], 1) + return self.reduce(ops.add, reduced_vars) * scale + if isinstance(op, ops.VarOp): + diff = self - self.reduce(ops.mean, reduced_vars) + return (diff * diff).reduce(ops.mean, reduced_vars) + if isinstance(op, ops.StdOp): + return self.reduce(ops.var, reduced_vars).sqrt() + raise NotImplementedError(f"Unsupported reduction op: {op}") + assert isinstance(op, AssociativeOp) + if not reduced_vars: return self return Reduce(op, self, reduced_vars) @@ -436,15 +453,15 @@ def sample(self, sampled_vars, sample_inputs=None, rng_key=None): exact = (x.exp() * integrand).reduce(ops.add) approx = (y.exp() * integrand).reduce(ops.add) - If ``sample_inputs`` is provided, this creates a batch of samples - scaled samples. + If ``sample_inputs`` is provided, this creates a batch of samples. :param sampled_vars: A set of input variables to sample. :type sampled_vars: str, Variable, or set or frozenset thereof. :param OrderedDict sample_inputs: An optional mapping from variable name to :class:`~funsor.domains.Domain` over which samples will be batched. 
- :param rng_key: a PRNG state to be used by JAX backend to generate random samples + :param rng_key: a PRNG state to be used by JAX backend to generate + random samples :type rng_key: None or JAX's random.PRNGKey """ assert self.output == Real @@ -457,21 +474,14 @@ def sample(self, sampled_vars, sample_inputs=None, rng_key=None): if sampled_vars.isdisjoint(self.inputs): return self - result = instrument.debug_logged(self.unscaled_sample)( + result = instrument.debug_logged(self._sample)( sampled_vars, sample_inputs, rng_key ) - if sample_inputs is not None: - log_scale = 0 - for var, domain in sample_inputs.items(): - if var in result.inputs and var not in self.inputs: - log_scale -= math.log(domain.dtype) - if log_scale != 0: - result += log_scale return result - def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None): + def _sample(self, sampled_vars, sample_inputs, rng_key): """ - Internal method to draw an unscaled sample. + Internal method to draw samples. This should be overridden by subclasses. """ assert self.output == Real @@ -926,7 +936,7 @@ def _alpha_convert(self, alpha_subs): subs = tuple((str(alpha_subs.get(k, k)), v) for k, v in subs) return arg, subs - def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None): + def _sample(self, sampled_vars, sample_inputs, rng_key=None): if any(k in sample_inputs for k, v in self.subs.items()): raise NotImplementedError("TODO alpha-convert") subs_sampled_vars = set() @@ -940,7 +950,7 @@ def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None): if name in v.inputs: subs_sampled_vars.add(k) subs_sampled_vars = frozenset(subs_sampled_vars) - arg = self.arg.unscaled_sample(subs_sampled_vars, sample_inputs, rng_key) + arg = self.arg._sample(subs_sampled_vars, sample_inputs, rng_key) return Subs(arg, tuple(self.subs.items())) @@ -1796,13 +1806,13 @@ def _alpha_convert(self, alpha_subs): diag_var = str(alpha_subs.get(diag_var, diag_var)) return fn, reals_var, bint_var, diag_var - def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None): + def _sample(self, sampled_vars, sample_inputs, rng_key=None): if self.bint_var in sampled_vars or self.bint_var in sample_inputs: raise NotImplementedError("TODO alpha-convert") sampled_vars = frozenset( self.diag_var if v == self.reals_var else v for v in sampled_vars ) - fn = self.fn.unscaled_sample(sampled_vars, sample_inputs, rng_key) + fn = self.fn._sample(sampled_vars, sample_inputs, rng_key) return Independent(fn, self.reals_var, self.bint_var, self.diag_var) def eager_subs(self, subs): diff --git a/funsor/testing.py b/funsor/testing.py index 3dc45ff0..c1c39e2d 100644 --- a/funsor/testing.py +++ b/funsor/testing.py @@ -192,6 +192,11 @@ def assert_close(actual, expected, atol=1e-6, rtol=1e-6): assert diff < (atol + abs(expected)) * rtol, msg elif atol is not None: assert diff < atol, msg + elif isinstance(actual, dict): + assert isinstance(expected, dict) + assert set(actual) == set(expected) + for k, actual_v in actual.items(): + assert_close(actual_v, expected[k], atol=atol, rtol=rtol) else: raise ValueError("cannot compare objects of type {}".format(type(actual))) diff --git a/test/test_approximations.py b/test/test_approximations.py index b61c0189..2bfc88c1 100644 --- a/test/test_approximations.py +++ b/test/test_approximations.py @@ -22,7 +22,7 @@ from funsor.interpreter import reinterpret from funsor.montecarlo import MonteCarlo from funsor.tensor import Tensor -from funsor.terms import Approximate +from funsor.terms import Approximate, Variable from 
funsor.testing import ( assert_close, make_einsum_example, @@ -34,9 +34,22 @@ from funsor.util import get_backend monte_carlo = MonteCarlo(rng_key=np.array([0, 0], dtype=np.uint32)) +monte_carlo_10 = MonteCarlo( + rng_key=np.array([0, 0], dtype=np.uint32), + particle=Bint[10], +) +particles_10 = frozenset([Variable("particle", Bint[10])]) -@pytest.mark.parametrize("approximate", [eager, argmax_approximate, monte_carlo]) +@pytest.mark.parametrize( + "approximate", + [ + eager, + argmax_approximate, + monte_carlo, + monte_carlo_10, + ], +) def test_tensor_smoke(approximate): with normalize: model = random_tensor(OrderedDict(i=Bint[2], j=Bint[3])) @@ -45,7 +58,10 @@ def test_tensor_smoke(approximate): with approximate, xfail_if_not_implemented(): q = reinterpret(p) assert q.output == p.output - assert q.input_vars.issubset(p.input_vars) + if approximate == monte_carlo_10: + assert q.input_vars.issubset(p.input_vars | particles_10) + else: + assert q.input_vars.issubset(p.input_vars) @pytest.mark.parametrize( @@ -56,6 +72,7 @@ def test_tensor_smoke(approximate): laplace_approximate, xfail_param(mean_approximate, reason="alpha conversion bug"), monte_carlo, + monte_carlo_10, ], ) def test_gaussian_smoke(approximate): @@ -66,7 +83,10 @@ def test_gaussian_smoke(approximate): with approximate, xfail_if_not_implemented(): q = reinterpret(p) assert q.output == p.output - assert q.input_vars.issubset(p.input_vars) + if approximate == monte_carlo_10: + assert q.input_vars.issubset(p.input_vars | particles_10) + else: + assert q.input_vars.issubset(p.input_vars) @pytest.mark.parametrize( @@ -74,8 +94,9 @@ def test_gaussian_smoke(approximate): [ eager, argmax_approximate, - xfail_param(monte_carlo, reason="only true in expectation"), + monte_carlo, ], + ids=str, ) def test_tensor_linear(approximate): m1 = random_tensor(OrderedDict(i=Bint[2], x=Bint[4])) @@ -87,7 +108,11 @@ def test_tensor_linear(approximate): q1 = m1.approximate(ops.logaddexp, guide, "x") q2 = m2.approximate(ops.logaddexp, guide, "x") actual = q1 + s * q2 - assert_close(actual, expected) + + if approximate == monte_carlo: + assert actual.inputs == expected.inputs + else: + assert_close(actual, expected) @pytest.mark.parametrize( @@ -97,8 +122,9 @@ def test_tensor_linear(approximate): argmax_approximate, laplace_approximate, mean_approximate, - xfail_param(monte_carlo, reason="only true in expectation"), + monte_carlo, ], + ids=str, ) def test_gaussian_linear(approximate): m1 = random_gaussian(OrderedDict(i=Bint[2], x=Real)) @@ -110,7 +136,11 @@ def test_gaussian_linear(approximate): q1 = m1.approximate(ops.logaddexp, guide, "x") q2 = m2.approximate(ops.logaddexp, guide, "x") actual = q1 + s * q2 - assert_close(actual, expected) + + if approximate == monte_carlo: + assert actual.inputs == expected.inputs + else: + assert_close(actual, expected) def test_backward_argmax_simple_reduce(): diff --git a/test/test_distribution.py b/test/test_distribution.py index 5fa75b46..59eeeaaf 100644 --- a/test/test_distribution.py +++ b/test/test_distribution.py @@ -771,24 +771,24 @@ def _get_stat_diff( sample_value, Variable("value", funsor_dist.inputs["value"]), frozenset(["value"]), - ).reduce(ops.add, frozenset(sample_inputs)) + ).reduce(ops.mean, frozenset(sample_inputs)) expected_stat = funsor_dist.mean() elif statistic == "variance": actual_mean = Integrate( sample_value, Variable("value", funsor_dist.inputs["value"]), frozenset(["value"]), - ).reduce(ops.add, frozenset(sample_inputs)) + ).reduce(ops.mean, frozenset(sample_inputs)) actual_stat = 
Integrate( sample_value, (Variable("value", funsor_dist.inputs["value"]) - actual_mean) ** 2, frozenset(["value"]), - ).reduce(ops.add, frozenset(sample_inputs)) + ).reduce(ops.mean, frozenset(sample_inputs)) expected_stat = funsor_dist.variance() elif statistic == "entropy": actual_stat = -Integrate( sample_value, funsor_dist, frozenset(["value"]) - ).reduce(ops.add, frozenset(sample_inputs)) + ).reduce(ops.mean, frozenset(sample_inputs)) expected_stat = funsor_dist.entropy() else: raise ValueError("invalid test statistic") @@ -1187,7 +1187,7 @@ def _assert_conjugate_density_ok( expected = Integrate( latent_samples, conditional(value=obs).exp(), frozenset(["prior"]) ) - expected = expected.reduce(ops.add, frozenset(sample_inputs)) + expected = expected.reduce(ops.mean, frozenset(sample_inputs)) actual = ( (latent + conditional).reduce(ops.logaddexp, set(["prior"]))(value=obs).exp() ) diff --git a/test/test_gaussian.py b/test/test_gaussian.py index 0a05a465..b81b346c 100644 --- a/test/test_gaussian.py +++ b/test/test_gaussian.py @@ -596,7 +596,8 @@ def test_integrate_variable(int_inputs, real_inputs): sampled_log_measure = log_measure.sample( reduced_vars, OrderedDict(particle=Bint[100000]), rng_key=rng_key ) - approx = Integrate(sampled_log_measure, integrand, reduced_vars | {"particle"}) + approx = Integrate(sampled_log_measure, integrand, reduced_vars) + approx = approx.reduce(ops.mean, "particle") assert isinstance(approx, Tensor) exact = Integrate(log_measure, integrand, reduced_vars) @@ -604,6 +605,7 @@ def test_integrate_variable(int_inputs, real_inputs): assert_close(approx, exact, atol=0.1, rtol=0.1) +@pytest.mark.xfail(get_backend() == "jax", reason="numerically unstable in jax backend") @pytest.mark.parametrize( "int_inputs", [ @@ -638,7 +640,8 @@ def test_integrate_gaussian(int_inputs, real_inputs): sampled_log_measure = log_measure.sample( reduced_vars, OrderedDict(particle=Bint[100000]), rng_key=rng_key ) - approx = Integrate(sampled_log_measure, integrand, reduced_vars | {"particle"}) + approx = Integrate(sampled_log_measure, integrand, reduced_vars) + approx = approx.reduce(ops.mean, "particle") assert isinstance(approx, Tensor) exact = Integrate(log_measure, integrand, reduced_vars) @@ -646,15 +649,16 @@ def test_integrate_gaussian(int_inputs, real_inputs): assert_close(approx, exact, atol=0.1, rtol=0.1) -@pytest.mark.xfail( - get_backend() == "torch", reason="numerically unstable in torch backend" -) def test_mc_plate_gaussian(): log_measure = Gaussian( numeric_array([0.0]), numeric_array([[1.0]]), (("loc", Real),) ) + numeric_array(-0.9189) + + plate_size = 10 integrand = Gaussian( - randn((100, 1)) + 3.0, ones((100, 1, 1)), (("data", Bint[100]), ("loc", Real)) + randn((plate_size, 1)) + 3.0, + ones((plate_size, 1, 1)), + (("data", Bint[plate_size]), ("loc", Real)), ) rng_key = None if get_backend() != "jax" else np.array([0, 0], dtype=np.uint32) diff --git a/test/test_integrate.py b/test/test_integrate.py index 627efbe2..ac111a6d 100644 --- a/test/test_integrate.py +++ b/test/test_integrate.py @@ -24,9 +24,10 @@ eager, moment_matching, MonteCarlo(rng_key=np.array([0, 0], dtype=np.uint32)), + MonteCarlo(rng_key=np.array([0, 0], dtype=np.uint32), particle=Bint[10]), ], ) -def test_integrate(interp): +def test_integrate_smoke(interp): log_measure = random_tensor(OrderedDict([("i", Bint[2]), ("j", Bint[3])])) integrand = random_tensor(OrderedDict([("j", Bint[3]), ("k", Bint[4])])) with interp: diff --git a/test/test_joint.py b/test/test_joint.py index f61692dd..88f3267d 
100644 --- a/test/test_joint.py +++ b/test/test_joint.py @@ -347,16 +347,16 @@ def test_reduce_moment_matching_moments(): with moment_matching: approx = gaussian.reduce(ops.logaddexp, "j") with MonteCarlo(s=Bint[100000]): - actual = Integrate(approx, Number(1.0), "x") - expected = Integrate(gaussian, Number(1.0), {"j", "x"}) + actual = Integrate(approx, Number(1.0), "x").reduce(ops.mean, "s") + expected = Integrate(gaussian, Number(1.0), {"j", "x"}).reduce(ops.mean, "s") assert_close(actual, expected, atol=1e-3, rtol=1e-3) - actual = Integrate(approx, x, "x") - expected = Integrate(gaussian, x, {"j", "x"}) + actual = Integrate(approx, x, "x").reduce(ops.mean, "s") + expected = Integrate(gaussian, x, {"j", "x"}).reduce(ops.mean, "s") assert_close(actual, expected, atol=1e-2, rtol=1e-2) - actual = Integrate(approx, x * x, "x") - expected = Integrate(gaussian, x * x, {"j", "x"}) + actual = Integrate(approx, x * x, "x").reduce(ops.mean, "s") + expected = Integrate(gaussian, x * x, {"j", "x"}).reduce(ops.mean, "s") assert_close(actual, expected, atol=1e-2, rtol=1e-2) diff --git a/test/test_recipes.py b/test/test_recipes.py new file mode 100644 index 00000000..8eb249a7 --- /dev/null +++ b/test/test_recipes.py @@ -0,0 +1,331 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +from collections import OrderedDict +from typing import Dict, Tuple + +import numpy as np +import pytest + +import funsor.ops as ops +from funsor.domains import Bint, Real, Reals +from funsor.montecarlo import extract_samples +from funsor.recipes import forward_filter_backward_rsample +from funsor.terms import Lambda, Variable +from funsor.testing import assert_close, random_gaussian +from funsor.util import get_backend + + +def get_moments(samples): + reduced_vars = frozenset(["particle"]) + moments = OrderedDict() + + # Compute first moments. + diffs = OrderedDict() + for name, value in samples.items(): + mean = value.reduce(ops.mean, reduced_vars) + moments[name] = mean + diffs[name] = value - mean + + # Compute centered second moments. + for i, (name1, diff1) in enumerate(diffs.items()): + diff1_ = diff1.reshape((diff1.output.num_elements, 1)) + for name2, diff2 in list(diffs.items())[:i]: + diff_2 = diff2.reshape((1, diff2.output.num_elements)) + diff12 = diff1_ * diff_2 + moments[name1, name2] = diff12.reduce(ops.mean, reduced_vars) + + return moments + + +def check_ffbr(factors, eliminate, plates, actual_samples, actual_log_prob): + """ + This can be seen as performing naive tensor variable elimination by + breaking all plates and creating a single flat joint distribution. + """ + assert "particle" not in plates + flat_vars: Dict[str, Variable] = {} + plate_vars: Dict[str, Variable] = {} + broken_plates: Dict[str, Tuple[Variable]] = {} + for name, factor in factors.items(): + for k, d in factor.inputs.items(): + if k in plates: + plate_vars[k] = Variable(k, d) + if name in factor.inputs: # i.e. if is latent + broken_plates[name] = tuple( + plate_vars[p] for p in sorted(plates.intersection(factor.inputs)) + ) + # I guess we could use Lambda here? + broken_shape = tuple(p.output.size for p in broken_plates[name]) + domain = Reals[broken_shape + factor.inputs[name].shape] + flat_vars[name] = Variable("flat_" + name, domain)[broken_plates[name]] + + flat_factors = [] + for factor in factors.values(): + f = factor(**flat_vars) + f = f.reduce(ops.add, plates.intersection(f.inputs)) + flat_factors.append(f) + + # Check log prob. 
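+    # With every plate broken into an event dimension, the flat joint is an
+    # ordinary unnormalized density, so its normalizer is a single logaddexp
+    # reduction over all inputs. The drawn samples are wrapped by Lambda to
+    # match the flat variables, and the normalized log density at those
+    # samples must agree with the log_prob returned by the recipe.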
+ flat_joint = sum(flat_factors) + log_Z = flat_joint.reduce(ops.logaddexp) + flat_samples = {} + for k, v in actual_samples.items(): + for p in reversed(broken_plates[k]): + v = Lambda(p, v) + flat_samples["flat_" + k] = v + expected_log_prob = flat_joint(**flat_samples) - log_Z + assert_close(actual_log_prob, expected_log_prob, atol=1e-4, rtol=None) + + # Check sample moments. + sample_inputs = OrderedDict(particle=actual_log_prob.inputs["particle"]) + rng_key = None if get_backend() != "jax" else np.array([0, 0], dtype=np.uint32) + flat_deltas = flat_joint.sample( + {"flat_" + k for k in flat_vars}, sample_inputs, rng_key + ) + flat_samples = extract_samples(flat_deltas) + expected_samples = { + k: flat_samples["flat_" + k][broken_plates[k]] for k in flat_vars + } + expected_moments = get_moments(expected_samples) + actual_moments = get_moments(actual_samples) + assert_close(actual_moments, expected_moments, atol=0.02, rtol=None) + + +def test_ffbr_1(): + """ + def model(data): + a = pyro.sample("a", dist.Normal(0, 1)) + pyro.sample("b", dist.Normal(a, 1), obs=data) + """ + num_samples = int(1e5) + + factors = { + "a": random_gaussian(OrderedDict({"a": Real})), + "b": random_gaussian(OrderedDict({"a": Real})), + } + eliminate = frozenset(["a"]) + plates = frozenset() + sample_inputs = OrderedDict(particle=Bint[num_samples]) + + rng_key = None if get_backend() != "jax" else np.array([0, 0], dtype=np.uint32) + actual_samples, actual_log_prob = forward_filter_backward_rsample( + factors, eliminate, plates, sample_inputs, rng_key + ) + assert set(actual_samples) == {"a"} + assert actual_samples["a"].output == Real + assert set(actual_samples["a"].inputs) == {"particle"} + + check_ffbr(factors, eliminate, plates, actual_samples, actual_log_prob) + + +def test_ffbr_2(): + """ + def model(data): + a = pyro.sample("a", dist.Normal(0, 1)) + b = pyro.sample("b", dist.Normal(0, 1)) + pyro.sample("c", dist.Normal(a, b.exp()), obs=data) + """ + num_samples = int(1e5) + + factors = { + "a": random_gaussian(OrderedDict({"a": Real})), + "b": random_gaussian(OrderedDict({"b": Real})), + "c": random_gaussian(OrderedDict({"a": Real, "b": Real})), + } + eliminate = frozenset(["a", "b"]) + plates = frozenset() + sample_inputs = {"particle": Bint[num_samples]} + + rng_key = None if get_backend() != "jax" else np.array([0, 0], dtype=np.uint32) + actual_samples, actual_log_prob = forward_filter_backward_rsample( + factors, eliminate, plates, sample_inputs, rng_key + ) + assert set(actual_samples) == {"a", "b"} + assert actual_samples["a"].output == Real + assert actual_samples["b"].output == Real + assert set(actual_samples["a"].inputs) == {"particle"} + assert set(actual_samples["b"].inputs) == {"particle"} + + check_ffbr(factors, eliminate, plates, actual_samples, actual_log_prob) + + +def test_ffbr_3(): + """ + def model(data): + a = pyro.sample("a", dist.Normal(0, 1)) + with pyro.plate("i", 2): + b = pyro.sample("b", dist.Normal(0, 1)) + pyro.sample("c", dist.Normal(a, b.exp()), obs=data) + """ + num_samples = int(1e5) + + factors = { + "a": random_gaussian(OrderedDict({"a": Real})), + "b": random_gaussian(OrderedDict({"i": Bint[2], "b": Real})), + "c": random_gaussian(OrderedDict({"i": Bint[2], "a": Real, "b": Real})), + } + eliminate = frozenset(["a", "b", "i"]) + plates = frozenset(["i"]) + sample_inputs = {"particle": Bint[num_samples]} + + rng_key = None if get_backend() != "jax" else np.array([0, 0], dtype=np.uint32) + actual_samples, actual_log_prob = forward_filter_backward_rsample( + factors, 
eliminate, plates, sample_inputs, rng_key + ) + assert set(actual_samples) == {"a", "b"} + assert actual_samples["a"].output == Real + assert actual_samples["b"].output == Real + assert set(actual_samples["a"].inputs) == {"particle"} + assert set(actual_samples["b"].inputs) == {"particle", "i"} + + check_ffbr(factors, eliminate, plates, actual_samples, actual_log_prob) + + +def test_ffbr_4(): + """ + def model(data): + a = pyro.sample("a", dist.Normal(0, 1)) + b = pyro.sample("b", dist.Normal(0, 1)) + with pyro.plate("i", 2): + c = pyro.sample("c", dist.Normal(a, 1)) + d = pyro.sample("d", dist.Normal(b, 1)) + with pyro.plate("j", 3): + pyro.sample("e", dist.Normal(c, d.exp()), obs=data) + """ + num_samples = int(1e5) + + factors = { + "a": random_gaussian(OrderedDict({"a": Real})), + "b": random_gaussian(OrderedDict({"b": Real})), + "c": random_gaussian(OrderedDict({"i": Bint[2], "a": Real, "c": Real})), + "d": random_gaussian(OrderedDict({"i": Bint[2], "b": Real, "d": Real})), + "e": random_gaussian( + OrderedDict({"i": Bint[2], "j": Bint[3], "c": Real, "d": Real}) + ), + } + eliminate = frozenset(["a", "b", "c", "d", "i", "j"]) + plates = frozenset(["i", "j"]) + sample_inputs = {"particle": Bint[num_samples]} + + rng_key = None if get_backend() != "jax" else np.array([0, 0], dtype=np.uint32) + actual_samples, actual_log_prob = forward_filter_backward_rsample( + factors, eliminate, plates, sample_inputs, rng_key + ) + assert set(actual_samples) == {"a", "b", "c", "d"} + assert actual_samples["a"].output == Real + assert actual_samples["b"].output == Real + assert actual_samples["c"].output == Real + assert actual_samples["d"].output == Real + assert set(actual_samples["a"].inputs) == {"particle"} + assert set(actual_samples["b"].inputs) == {"particle"} + assert set(actual_samples["c"].inputs) == {"particle", "i"} + assert set(actual_samples["d"].inputs) == {"particle", "i"} + + check_ffbr(factors, eliminate, plates, actual_samples, actual_log_prob) + + +def test_ffbr_5(): + """ + def model(data): + a = pyro.sample("a", dist.MultivariateNormal(zeros(2), eye(2))) + b = pyro.sample("b", dist.MultivariateNormal(a, eye(2))) + c = pyro.sample("c", dist.MultivariateNormal(b, eye(2))) + d = pyro.sample("d", dist.MultivariateNormal(c, eye(2))) + pyro.sample("e", dist.MultivariateNormal(d, eye(2)), obs=data) + """ + num_samples = int(1e5) + + factors = { + "a": random_gaussian(OrderedDict({"a": Reals[2]})), + "b": random_gaussian(OrderedDict({"b": Reals[2], "a": Reals[2]})), + "c": random_gaussian(OrderedDict({"c": Reals[2], "b": Reals[2]})), + "d": random_gaussian(OrderedDict({"d": Reals[2], "c": Reals[2]})), + "e": random_gaussian(OrderedDict({"d": Reals[2]})), + } + eliminate = frozenset(["a", "b", "c", "d"]) + plates = frozenset() + sample_inputs = {"particle": Bint[num_samples]} + + rng_key = None if get_backend() != "jax" else np.array([0, 0], dtype=np.uint32) + actual_samples, actual_log_prob = forward_filter_backward_rsample( + factors, eliminate, plates, sample_inputs, rng_key + ) + assert set(actual_samples) == {"a", "b", "c", "d"} + assert actual_samples["a"].output == Reals[2] + assert actual_samples["b"].output == Reals[2] + assert actual_samples["c"].output == Reals[2] + assert actual_samples["d"].output == Reals[2] + assert set(actual_samples["a"].inputs) == {"particle"} + assert set(actual_samples["b"].inputs) == {"particle"} + assert set(actual_samples["c"].inputs) == {"particle"} + assert set(actual_samples["d"].inputs) == {"particle"} + + check_ffbr(factors, eliminate, plates, 
actual_samples, actual_log_prob) + + +@pytest.mark.xfail(reason="TODO handle intractable case") +def test_ffbr_intractable_1(): + """ + def model(data): + i_plate = pyro.plate("i", 2, dim=-2) + j_plate = pyro.plate("j", 3, dim=-1) + with i_plate: + a = pyro.sample("a", dist.Normal(0, 1)) + with i_plate: + b = pyro.sample("b", dist.Normal(0, 1)) + with i_plate, j_plate: + pyro.sample("c", dist.Normal(a, b), obs=data) + """ + num_samples = int(1e5) + + factors = { + "a": random_gaussian(OrderedDict({"i": Bint[2], "a": Real})), + "b": random_gaussian(OrderedDict({"j": Bint[2], "b": Real})), + "c": random_gaussian( + OrderedDict({"i": Bint[2], "j": Bint[2], "a": Real, "b": Real}) + ), + } + eliminate = frozenset(["a", "b", "i", "j"]) + plates = frozenset(["i", "j"]) + sample_inputs = {"particle": Bint[num_samples]} + + rng_key = None if get_backend() != "jax" else np.array([0, 0], dtype=np.uint32) + actual_samples, actual_log_prob = forward_filter_backward_rsample( + factors, eliminate, plates, sample_inputs, rng_key + ) + assert set(actual_samples) == {"a", "b"} + assert actual_samples["a"].output == Real + assert actual_samples["b"].output == Real + assert set(actual_samples["a"].inputs) == {"particle", "i"} + assert set(actual_samples["b"].inputs) == {"particle", "j"} + + check_ffbr(factors, eliminate, plates, actual_samples, actual_log_prob) + + +@pytest.mark.xfail(reason="TODO handle colliders via Lambda") +def test_ffbr_intractable_2(): + """ + def model(data): + with pyro.plate("i", 2): + a = pyro.sample("a", dist.Normal(0, 1)) + b = pyro.sample("b", dist.Normal(a.sum(), 1), obs=data) + """ + num_samples = int(1e5) + + factors = { + "a": random_gaussian(OrderedDict({"i": Bint[2], "a": Real})), + "b": random_gaussian(OrderedDict({"a_i": Reals[2]})), + } + eliminate = frozenset(["a", "i"]) + plates = frozenset(["i"]) + sample_inputs = {"particle": Bint[num_samples]} + + rng_key = None if get_backend() != "jax" else np.array([0, 0], dtype=np.uint32) + actual_samples, actual_log_prob = forward_filter_backward_rsample( + factors, eliminate, plates, sample_inputs, rng_key + ) + assert set(actual_samples) == {"a"} + assert set(actual_samples["a"].inputs) == {"particle", "i"} + + check_ffbr(factors, eliminate, plates, actual_samples, actual_log_prob) diff --git a/test/test_samplers.py b/test/test_samplers.py index 69cf37c0..18d94d24 100644 --- a/test/test_samplers.py +++ b/test/test_samplers.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import itertools +import math from collections import OrderedDict from importlib import import_module @@ -262,6 +263,7 @@ def test_tensor_distribution(event_inputs, batch_inputs, test_grad): def diff_fn(p_data): p = Tensor(p_data, be_inputs) q = p.sample(sampled_vars, sample_inputs, rng_key=rng_key) + q -= math.log(num_samples) mq = p.materialize(q).reduce(ops.logaddexp, "n") mq = mq.align(tuple(p.inputs)) @@ -307,6 +309,7 @@ def test_gaussian_distribution(event_inputs, batch_inputs): rng_key = None if get_backend() == "torch" else np.array([0, 0], dtype=np.uint32) q = p.sample(sampled_vars, sample_inputs, rng_key=rng_key) + q -= math.log(num_samples) p_vars = sampled_vars q_vars = sampled_vars | frozenset(["particle"]) # Check zeroth moment. 
@@ -352,6 +355,7 @@ def test_gaussian_mixture_distribution(batch_inputs, event_inputs): rng_key = None if get_backend() == "torch" else np.array([0, 1], dtype=np.uint32) q = p.sample(sampled_vars, sample_inputs, rng_key=rng_key) + q -= math.log(num_samples) q_marginal = q.reduce(ops.logaddexp, "e") q_marginal = p_marginal.materialize(q_marginal).reduce(ops.logaddexp, "particle") assert isinstance(q_marginal, Tensor) @@ -372,6 +376,7 @@ def test_lognormal_distribution(moment): with MonteCarlo(particle=Bint[num_samples]): with xfail_if_not_implemented(): actual = Integrate(log_measure, probe, frozenset(["x"])) + actual = actual.reduce(ops.mean, "particle") _, (loc_data, scale_data) = align_tensors(loc, scale) samples = backend_dist.LogNormal(loc_data, scale_data).sample((num_samples,)) diff --git a/test/test_tensor.py b/test/test_tensor.py index 07353666..daee33f2 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -1383,6 +1383,18 @@ def test_reduction(op, event_shape): ) +@pytest.mark.parametrize("batch_shape", [(), (4,), (3, 2)], ids=str) +@pytest.mark.parametrize("event_shape", [(), (4,), (3, 2)], ids=str) +def test_reduce_reduction(batch_shape, event_shape): + x = Tensor(randn(*batch_shape, 5, *event_shape)) + for name in "abc"[: len(batch_shape)]: + x = x[name] + + assert_close(x["i"].reduce(ops.mean, "i"), x.mean(0)) + assert_close(x["i"].reduce(ops.var, "i"), x.var(0)) + assert_close(x["i"].reduce(ops.std, "i"), x.std(0)) + + @pytest.mark.parametrize( "op", [