Add recipes.forward_filter_backward_rsample() #549

Merged · 21 commits · Sep 22, 2021
Changes from all commits
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -28,6 +28,7 @@ Funsor is a tensor-like library for functions and distributions
:maxdepth: 2
:caption: Interfaces:

recipes
pyro
distributions
minipyro
4 changes: 4 additions & 0 deletions docs/source/recipes.rst
@@ -0,0 +1,4 @@
.. automodule:: funsor.recipes
:members:
:show-inheritance:
:member-order: bysource
2 changes: 2 additions & 0 deletions funsor/__init__.py
@@ -42,6 +42,7 @@
joint,
montecarlo,
ops,
recipes,
sum_product,
terms,
testing,
@@ -96,6 +97,7 @@
"pretty",
"quote",
"reals",
"recipes",
"reinterpret",
"set_backend",
# 'minipyro', # TODO: enable when minipyro is backend-agnostic
20 changes: 15 additions & 5 deletions funsor/adjoint.py
@@ -67,7 +67,9 @@ def __enter__(self):
self._old_interpretation = interpreter.get_interpretation()
return super().__enter__()

def adjoint(self, sum_op, bin_op, root, targets=None):
def adjoint(self, sum_op, bin_op, root, targets=None, *, batch_vars=frozenset()):
# TODO Replace this with root + Constant(...) after #548 merges.
root_vars = root.input_vars | batch_vars

zero = to_funsor(ops.UNITS[sum_op])
one = to_funsor(ops.UNITS[bin_op])
@@ -115,7 +117,9 @@ def adjoint(self, sum_op, bin_op, root, targets=None):

in_adjs = adjoint_ops(fn, sum_op, bin_op, adjoint_values[output], *inputs)
for v, adjv in in_adjs:
agg_vars = adjv.input_vars - v.input_vars - root.input_vars
# Marginalize out message variables that don't appear in recipients.
agg_vars = adjv.input_vars - v.input_vars - root_vars
assert "particle" not in {var.name for var in agg_vars} # DEBUG FIXME
old_value = adjoint_values[v]
adjoint_values[v] = sum_op(old_value, adjv.reduce(sum_op, agg_vars))

@@ -129,11 +133,17 @@
return {target: result[target] for target in targets}


def adjoint(sum_op, bin_op, expr):
def forward_backward(sum_op, bin_op, expr, *, batch_vars=frozenset()):
with AdjointTape() as tape:
# TODO fix traversal order in AdjointTape instead of using stack_reinterpret
root = stack_reinterpret(expr)
return tape.adjoint(sum_op, bin_op, root)
forward = stack_reinterpret(expr)
backward = tape.adjoint(sum_op, bin_op, forward, batch_vars=batch_vars)
return forward, backward


def adjoint(sum_op, bin_op, expr):
forward, backward = forward_backward(sum_op, bin_op, expr)
return backward


# logaddexp/add
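The net effect of this refactor: forward_backward() runs the adjoint tape once and returns both the forward value and the dict of adjoints, while adjoint() keeps its old behavior by discarding the forward value. A minimal usage sketch, mirroring how funsor/recipes.py below calls it (not part of the diff; the toy factors and the torch backend are illustrative assumptions):

import torch
from collections import OrderedDict

import funsor
from funsor import ops
from funsor.adjoint import forward_backward

funsor.set_backend("torch")

# Two toy factors over a discrete variable "z", built lazily so the
# adjoint tape can trace them.
f = funsor.Tensor(torch.randn(3), OrderedDict(z=funsor.Bint[3]))
g = funsor.Tensor(torch.randn(3), OrderedDict(z=funsor.Bint[3]))
with funsor.interpretations.reflect:
    log_Z = funsor.sum_product.sum_product(
        ops.logaddexp, ops.add, [f, g], frozenset(["z"]), frozenset()
    )
log_Z = funsor.optimizer.apply_optimizer(log_Z)

# forward is the evaluated log normalizer; backward maps each traced
# factor to its (logaddexp, add) adjoint, i.e. its marginal.
forward, backward = forward_backward(ops.logaddexp, ops.add, log_Z)
marginal_f = backward[f]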
2 changes: 1 addition & 1 deletion funsor/approximations.py
@@ -66,7 +66,7 @@ def laplace_approximate_logaddexp(op, model, guide, approx_vars):
################################################################################
# Computations.
# TODO Consider either making these Funsor methods or making .sample() and
# .unscaled_sample() singledispatch functions.
# ._sample() singledispatch functions.


@singledispatch
43 changes: 17 additions & 26 deletions funsor/cnf.py
@@ -102,38 +102,31 @@ def __str__(self):
)
return super().__str__()

def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None):
def _sample(self, sampled_vars, sample_inputs, rng_key):
sampled_vars = sampled_vars.intersection(self.inputs)
if not sampled_vars:
return self

if self.red_op in (ops.null, ops.logaddexp):
if self.bin_op in (ops.null, ops.logaddexp):
if rng_key is not None and get_backend() == "jax":
import jax
if rng_key is not None and get_backend() == "jax":
import jax

rng_keys = jax.random.split(rng_key, len(self.terms))
else:
rng_keys = [None] * len(self.terms)
rng_keys = jax.random.split(rng_key, len(self.terms))
else:
rng_keys = [None] * len(self.terms)

if self.bin_op in (ops.null, ops.logaddexp):
# Design choice: we sample over logaddexp reductions, but leave logaddexp
# binary choices symbolic.
terms = [
term.unscaled_sample(
sampled_vars.intersection(term.inputs), sample_inputs
term._sample(
sampled_vars.intersection(term.inputs), sample_inputs, rng_key
)
for term, rng_key in zip(self.terms, rng_keys)
]
return Contraction(self.red_op, self.bin_op, self.reduced_vars, *terms)

if self.bin_op is ops.add:
if rng_key is not None and get_backend() == "jax":
import jax

rng_keys = jax.random.split(rng_key)
else:
rng_keys = [None] * 2

# Sample variables greedily in order of the terms in which they appear.
for term in self.terms:
greedy_vars = sampled_vars.intersection(term.inputs)
@@ -146,9 +139,7 @@ def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None):
).append(term)
if len(greedy_terms) == 1:
term = greedy_terms[0]
terms.append(
term.unscaled_sample(greedy_vars, sample_inputs, rng_keys[0])
)
terms.append(term._sample(greedy_vars, sample_inputs, rng_keys[0]))
result = Contraction(
self.red_op, self.bin_op, self.reduced_vars, *terms
)
@@ -161,9 +152,7 @@ def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None):
term = discrete + gaussian.log_normalizer
terms.append(gaussian)
terms.append(-gaussian.log_normalizer)
terms.append(
term.unscaled_sample(greedy_vars, sample_inputs, rng_keys[0])
)
terms.append(term._sample(greedy_vars, sample_inputs, rng_keys[0]))
result = Contraction(
self.red_op, self.bin_op, self.reduced_vars, *terms
)
@@ -173,10 +162,12 @@ def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None):
for term in greedy_terms
):
sampled_terms = [
term.unscaled_sample(
greedy_vars.intersection(term.value.inputs), sample_inputs
term._sample(
greedy_vars.intersection(term.value.inputs),
sample_inputs,
rng_key,
)
for term in greedy_terms
for term, rng_key in zip(greedy_terms, rng_keys)
if isinstance(term, funsor.distribution.Distribution)
and not greedy_vars.isdisjoint(term.value.inputs)
]
@@ -192,7 +183,7 @@ def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None):
", ".join(str(type(t)) for t in greedy_terms)
)
)
return result.unscaled_sample(
return result._sample(
sampled_vars - greedy_vars, sample_inputs, rng_keys[1]
)

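Note the rename throughout: unscaled_sample(..., rng_key=None) becomes _sample(..., rng_key) with the key now a required positional argument, and external callers are expected to go through the public Funsor.sample() wrapper. A small sketch (not part of the diff; the Gaussian factor, the torch backend, and the "particle" batch name are illustrative assumptions):

from collections import OrderedDict

import funsor
from funsor.testing import random_gaussian

funsor.set_backend("torch")

g = random_gaussian(OrderedDict(z=funsor.Real))
# Draw 5 batched samples of "z"; the result is typically a Contraction
# pairing a Delta (the sample values) with an importance-weight term.
s = g.sample(frozenset(["z"]), OrderedDict(particle=funsor.Bint[5]))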
2 changes: 1 addition & 1 deletion funsor/delta.py
@@ -200,7 +200,7 @@ def eager_reduce(self, op, reduced_vars):

return None # defer to default implementation

def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None):
def _sample(self, sampled_vars, sample_inputs, rng_key):
return self


2 changes: 1 addition & 1 deletion funsor/distribution.py
@@ -206,7 +206,7 @@ def eager_log_prob(cls, *params):
inputs.update(x.inputs)
return log_prob.align(tuple(inputs))

def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None):
def _sample(self, sampled_vars, sample_inputs, rng_key):

# note this should handle transforms correctly via distribution_to_data
raw_dist, value_name, value_output, dim_to_name = self._get_raw_dist()
2 changes: 1 addition & 1 deletion funsor/gaussian.py
@@ -688,7 +688,7 @@ def eager_reduce(self, op, reduced_vars):

return None # defer to default implementation

def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None):
def _sample(self, sampled_vars, sample_inputs, rng_key):
sampled_vars = sampled_vars.intersection(self.inputs)
if not sampled_vars:
return self
46 changes: 39 additions & 7 deletions funsor/montecarlo.py
@@ -1,11 +1,15 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import functools
from collections import OrderedDict

from funsor.cnf import Contraction
from funsor.delta import Delta
from funsor.integrate import Integrate
from funsor.interpretations import StatefulInterpretation
from funsor.terms import Approximate, Funsor
from funsor.tensor import Tensor
from funsor.terms import Approximate, Funsor, Number
from funsor.util import get_backend

from . import ops
@@ -36,9 +40,7 @@ def monte_carlo_integrate(state, log_measure, integrand, reduced_vars):
sample = log_measure.sample(reduced_vars, state.sample_inputs, **sample_options)
if sample is log_measure:
return None # cannot progress
reduced_vars |= frozenset(
v for v in sample.input_vars if v.name in state.sample_inputs
)

return Integrate(sample, integrand, reduced_vars)


@@ -53,13 +55,43 @@ def monte_carlo_approximate(state, op, model, guide, approx_vars):
sample = guide.sample(approx_vars, state.sample_inputs, **sample_options)
if sample is guide:
return model # cannot progress
reduced_vars = frozenset(
v for v in sample.input_vars if v.name in state.sample_inputs
)
result = (sample + model - guide).reduce(op, reduced_vars)
result = sample + model - guide

return result


@functools.singledispatch
def extract_samples(discrete_density):
"""
Extract sample values out of a funsor Delta, possibly scaled by Tensors.
This is useful for extracting sample tensors from a Monte Carlo
computation.
"""
raise ValueError(
f"Could not extract support from {type(discrete_density).__name__}"
)


@extract_samples.register(Delta)
def _extract_samples_delta(discrete_density):
return {name: point for name, (point, log_density) in discrete_density.terms}


@extract_samples.register(Contraction)
def _extract_samples_contraction(discrete_density):
assert not discrete_density.reduced_vars
result = {}
for term in discrete_density.terms:
result.update(extract_samples(term))
return result


@extract_samples.register(Number)
@extract_samples.register(Tensor)
def _extract_samples_scale(discrete_density):
return {}


__all__ = [
"MonteCarlo",
]
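A quick sketch of extract_samples() applied to the kind of Delta that sampling produces (not part of the diff; the names and the torch backend are illustrative assumptions):

import torch

import funsor
from funsor.delta import Delta
from funsor.montecarlo import extract_samples

funsor.set_backend("torch")

point = funsor.Tensor(torch.randn(()))
d = Delta("z", point)  # unit mass at point; log_density defaults to 0
samples = extract_samples(d)  # {"z": point}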
84 changes: 84 additions & 0 deletions funsor/recipes.py
@@ -0,0 +1,84 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

"""
Recipes using Funsor
--------------------
This module provides a number of high-level algorithms using Funsor.

"""

from typing import Dict, FrozenSet

import funsor # Let's use fully qualified names in this file.


def forward_filter_backward_rsample(
factors: Dict[str, funsor.Funsor],
eliminate: FrozenSet[str],
plates: FrozenSet[str],
sample_inputs: Dict[str, funsor.domains.Domain] = {},
rng_key=None,
):
"""
A forward-filter backward-batched-reparametrized-sample algorithm for use
in variational inference. The motivating use case is performing Gaussian
tensor variable elimination over structured variational posteriors.

:param dict factors: A dictionary mapping sample site name to a Funsor
factor created at that sample site.
:param frozenset eliminate: A set of names of latent variables to
marginalize and plates to aggregate.
:param frozenset plates: A set of names of plates to aggregate.
:param dict sample_inputs: An optional dict of enclosing sample indices
over which samples will be drawn in batch.
:param rng_key: A random number key for the JAX backend.
:returns: A pair ``samples:Dict[str, Tensor], log_prob: Tensor`` of samples
and log density evaluated at each of those samples. If ``sample_inputs``
is nonempty, both outputs will be batched.
:rtype: tuple
"""
assert isinstance(factors, dict)
assert all(isinstance(k, str) for k in factors)
assert all(isinstance(v, funsor.Funsor) for v in factors.values())
assert isinstance(eliminate, frozenset)
assert all(isinstance(v, str) for v in eliminate)
assert isinstance(plates, frozenset)
assert all(isinstance(v, str) for v in plates)
assert isinstance(sample_inputs, dict)
assert all(isinstance(k, str) for k in sample_inputs)
assert all(isinstance(v, funsor.domains.Domain) for v in sample_inputs.values())

# Perform tensor variable elimination.
with funsor.interpretations.reflect:
log_Z = funsor.sum_product.sum_product(
funsor.ops.logaddexp,
funsor.ops.add,
list(factors.values()),
eliminate,
plates,
)
log_Z = funsor.optimizer.apply_optimizer(log_Z)
batch_vars = frozenset(funsor.Variable(k, v) for k, v in sample_inputs.items())
with funsor.montecarlo.MonteCarlo(**sample_inputs, rng_key=rng_key):
log_Z, marginals = funsor.adjoint.forward_backward(
funsor.ops.logaddexp, funsor.ops.add, log_Z, batch_vars=batch_vars
)

# Extract sample tensors.
samples = {}
for name, factor in factors.items():
if name in eliminate:
samples.update(funsor.montecarlo.extract_samples(marginals[factor]))
assert frozenset(samples) == eliminate - plates

# Compute log density at each sample.
log_prob = -log_Z
for f in factors.values():
term = f(**samples)
plates = eliminate.intersection(term.inputs)
term = term.reduce(funsor.ops.add, plates)
log_prob += term
assert set(log_prob.inputs) == set(sample_inputs)

return samples, log_prob
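A usage sketch for the new recipe (not part of the diff; the random_gaussian factors and the torch backend are illustrative assumptions, and rng_key is needed only on the JAX backend):

from collections import OrderedDict

import funsor
from funsor.testing import random_gaussian

funsor.set_backend("torch")

# A tiny chain z -> x, both real-valued and both eliminated.
factors = {
    "z": random_gaussian(OrderedDict(z=funsor.Real)),
    "x": random_gaussian(OrderedDict(z=funsor.Real, x=funsor.Real)),
}
samples, log_prob = funsor.recipes.forward_filter_backward_rsample(
    factors,
    eliminate=frozenset(["z", "x"]),
    plates=frozenset(),
    sample_inputs={"particle": funsor.Bint[10]},
)
# samples["z"] and samples["x"] are reparametrized sample tensors batched
# over "particle"; log_prob is the joint log density at those samples,
# batched the same way.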
2 changes: 1 addition & 1 deletion funsor/tensor.py
@@ -329,7 +329,7 @@ def eager_reduce(self, op, reduced_vars):
return Tensor(data, inputs, dtype)
return super(Tensor, self).eager_reduce(op, reduced_vars)

def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None):
def _sample(self, sampled_vars, sample_inputs, rng_key):
assert self.output == Real
sampled_vars = sampled_vars.intersection(self.inputs)
if not sampled_vars: