From dbb29891b94e834d38776f6f62f08752722f37be Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Mon, 4 Oct 2021 12:24:11 -0400 Subject: [PATCH] Add a compiler and tracer, each creating OpPrograms (#557) * Add a FunsorProgram for running backend computations * Support program constants * Rename vm -> compiler * Start to implement lowering stage * Add docs * Reduce recursion * Support Tuple * Add a .as_code() method * Xfail test * Sketch tracer * Fix docs * Fix test import * Fix a test * Add tests, mark some tests xfail * lint * Ignore irrelevant ops --- docs/source/compiler.rst | 18 +++++ docs/source/index.rst | 1 + funsor/__init__.py | 2 + funsor/compiler.py | 142 +++++++++++++++++++++++++++++++++++++++ funsor/ops/__init__.py | 4 +- funsor/ops/op.py | 35 +++++++++- funsor/ops/program.py | 103 ++++++++++++++++++++++++++++ funsor/ops/tracer.py | 116 ++++++++++++++++++++++++++++++++ funsor/testing.py | 5 ++ funsor/util.py | 2 + test/test_compiler.py | 94 ++++++++++++++++++++++++++ test/test_tracer.py | 129 +++++++++++++++++++++++++++++++++++ 12 files changed, 649 insertions(+), 2 deletions(-) create mode 100644 docs/source/compiler.rst create mode 100644 funsor/compiler.py create mode 100644 funsor/ops/program.py create mode 100644 funsor/ops/tracer.py create mode 100644 test/test_compiler.py create mode 100644 test/test_tracer.py diff --git a/docs/source/compiler.rst b/docs/source/compiler.rst new file mode 100644 index 000000000..f8b214111 --- /dev/null +++ b/docs/source/compiler.rst @@ -0,0 +1,18 @@ +Compiler & Tracer +----------------- + +.. automodule:: funsor.compiler + :members: + :show-inheritance: + :member-order: bysource + +.. automodule:: funsor.ops.tracer + :members: + :show-inheritance: + :member-order: bysource + +.. 
automodule:: funsor.ops.program + :members: + :show-inheritance: + :member-order: bysource + :special-members: __call__ diff --git a/docs/source/index.rst b/docs/source/index.rst index 6d1f63b5d..048b15c23 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -33,6 +33,7 @@ Funsor is a tensor-like library for functions and distributions distributions minipyro einsum + compiler .. toctree:: :maxdepth: 1 diff --git a/funsor/__init__.py b/funsor/__init__.py index c8e04e084..dbed289af 100644 --- a/funsor/__init__.py +++ b/funsor/__init__.py @@ -31,6 +31,7 @@ affine, approximations, cnf, + compiler, constant, delta, distribution, @@ -78,6 +79,7 @@ "backward", "bint", "cnf", + "compiler", "constant", "delta", "distribution", diff --git a/funsor/compiler.py b/funsor/compiler.py new file mode 100644 index 000000000..e045e41ac --- /dev/null +++ b/funsor/compiler.py @@ -0,0 +1,142 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +import functools + +import funsor + +from .cnf import Contraction +from .ops.program import OpProgram, make_tuple +from .tensor import Tensor +from .terms import Binary, Funsor, Number, Tuple, Unary, Variable + + +def compile_funsor(expr: Funsor) -> OpProgram: + """ + Compiles a symbolic :class:`~funsor.terms.Funsor` to an + :class:`~funsor.ops.program.OpProgram` that runs on backend values. + + Example:: + + # Create a lazy expression. + a = Variable("a", Reals[3, 3]) + b = Variable("b", Reals[3]) + x = Variable("x", Reals[3]) + expr = a @ x + b + + # Evaluate via Funsor substitution. + data = dict(a=randn(3, 3), b=randn(3), x=randn(3)) + expected = expr(**data).data + + # Alternatively evaluate via a program. + program = compile_funsor(expr) + actual = program(**data) + assert (actual == expected).all() + + :param Funsor expr: A funsor expression to evaluate. + :returns: An op program.
+ :rtype: ~funsor.ops.program.OpProgram + """ + assert isinstance(expr, Funsor) + + # Lower and convert to A-normal form. + lowered_expr = lower(expr) + anf = list(funsor.interpreter.anf(lowered_expr)) + ids = {} + + # Collect constants (leaves). + constants = [] + for f in anf: + if isinstance(f, (Number, Tensor)): + ids[f] = len(ids) + constants.append(f.data) + + # Collect input variables (leaves). + inputs = [] + for k, d in expr.inputs.items(): + f = Variable(k, d) + ids[f] = len(ids) + inputs.append(k) + + # Collect operations to be computed (internal nodes). + operations = [] + for f in anf: + if f in ids: + continue # constant or free variable + ids[f] = len(ids) + if isinstance(f, Unary): + arg_ids = (ids[f.arg],) + operations.append((f.op, arg_ids)) + elif isinstance(f, Binary): + arg_ids = (ids[f.lhs], ids[f.rhs]) + operations.append((f.op, arg_ids)) + elif isinstance(f, Tuple): + arg_ids = tuple(ids[arg] for arg in f.args) + operations.append((make_tuple, arg_ids)) + elif isinstance(f, tuple): + continue # Skip from Tuple directly to its elements. + else: + raise NotImplementedError(type(f).__name__) + + return OpProgram(constants, inputs, operations) + + +def lower(expr: Funsor) -> Funsor: + """ + Lower a funsor expression: + - eliminate bound variables + - convert Contraction to Binary + + :param Funsor expr: An arbitrary funsor expression. + :returns: A lowered funsor expression. + :rtype: Funsor + """ + # FIXME should this be lazy? What about Lambda? 
+ with funsor.interpretations.reflect: + return _lower(expr) + + +@functools.singledispatch +def _lower(x): + raise NotImplementedError(type(x).__name__) + + +@_lower.register(Number) +@_lower.register(Tensor) +@_lower.register(Variable) +def _lower_atom(x): + return x + + +@_lower.register(Tuple) +def _lower_tuple(x): + args = tuple(_lower(arg) for arg in x.args) + return Tuple(args) + + +@_lower.register(Unary) +def _lower_unary(x): + arg = _lower(x.arg) + return Unary(x.op, arg) + + +@_lower.register(Binary) +def _lower_binary(x): + lhs = _lower(x.lhs) + rhs = _lower(x.rhs) + return Binary(x.op, lhs, rhs) + + +@_lower.register(Contraction) +def _lower_contraction(x): + if x.reduced_vars: + raise NotImplementedError("TODO") + + terms = [_lower(term) for term in x.terms] + bin_op = functools.partial(Binary, x.bin_op) + return functools.reduce(bin_op, terms) + + +__all__ = [ + "lower", +] diff --git a/funsor/ops/__init__.py b/funsor/ops/__init__.py index b4479fe0b..f2b62fd39 100644 --- a/funsor/ops/__init__.py +++ b/funsor/ops/__init__.py @@ -1,7 +1,9 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 -from . import array, builtin, op +from . import array, builtin, op, program, tracer from .array import * from .builtin import * from .op import * +from .program import * +from .tracer import * diff --git a/funsor/ops/op.py b/funsor/ops/op.py index 2ed10d75f..ed963206b 100644 --- a/funsor/ops/op.py +++ b/funsor/ops/op.py @@ -1,11 +1,13 @@ # Copyright Contributors to the Pyro project. 
# SPDX-License-Identifier: Apache-2.0 +import contextlib import functools import inspect import math import operator import weakref +from collections import OrderedDict from funsor.registry import PartialDispatcher from funsor.util import methodof @@ -60,6 +62,23 @@ def __call__(self, *args, **kwargs): return self.fn(arg, *args, **kwargs) +_TRACE = None +_TRACE_FILTER_ARGS = None + + +@contextlib.contextmanager +def trace_ops(filter_args): + global _TRACE, _TRACE_FILTER_ARGS + assert _TRACE is None, "not reentrant" + try: + _TRACE = OrderedDict() + _TRACE_FILTER_ARGS = filter_args + yield _TRACE + finally: + _TRACE = None + _TRACE_FILTER_ARGS = None + + class OpMeta(type): """ Metaclass for :class:`Op` classes. @@ -159,6 +178,9 @@ def __str__(self): return self.__name__ def __call__(self, *args, **kwargs): + global _TRACE, _TRACE_FILTER_ARGS + raw_args = args + # Normalize args, kwargs. cls = type(self) bound = cls.signature.bind_partial(*args, **kwargs) @@ -170,7 +192,18 @@ def __call__(self, *args, **kwargs): # Dispatch. fn = cls.dispatcher.partial_call(*args[: cls.arity]) - return fn(*args, **kwargs) + if _TRACE is None or not _TRACE_FILTER_ARGS(raw_args): + result = fn(*args, **kwargs) + else: + # Trace this op but avoid tracing internal ops. + try: + trace, _TRACE = _TRACE, None + result = fn(*args, **kwargs) + trace.setdefault(id(result), (result, self, raw_args)) + finally: + _TRACE = trace + + return result def register(self, *pattern): if len(pattern) != self.arity: diff --git a/funsor/ops/program.py b/funsor/ops/program.py new file mode 100644 index 000000000..659ed328e --- /dev/null +++ b/funsor/ops/program.py @@ -0,0 +1,103 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +from funsor.util import get_backend, set_backend + + +class OpProgram: + """ + Backend program for evaluating a symbolic funsor expression. 
+ + Programs depend on the funsor library only via ``funsor.ops`` and op + registrations; program evaluation does not involve funsor interpretation or + rewriting. Programs can be pickled and unpickled. + + :param iterable constants: A list of built-in constants (leaves). + :param iterable inputs: A list of string names of program inputs (leaves). + :param iterable operations: A list of program operations defining + non-leaf nodes in the program dag. Each operation is a tuple ``(op, + arg_ids)`` where op is a funsor op and ``arg_ids`` is a tuple of + positions of values, starting from zero and counting: constants, + inputs, and operation outputs. + """ + + def __init__(self, constants, inputs, operations): + super().__init__() + self.constants = tuple(constants) + self.inputs = tuple(inputs) + self.operations = tuple(operations) + self.backend = get_backend() + + def __call__(self, **kwargs): + set_backend(self.backend) + + # Initialize environment with constants. + env = list(self.constants) + + # Read inputs from kwargs. + for name in self.inputs: + value = kwargs.pop(name, None) + if value is None: + raise ValueError(f"Missing kwarg: {repr(name)}") + env.append(value) + if kwargs: + raise ValueError(f"Unrecognized kwargs: {set(kwargs)}") + + # Sequentially compute ops. + for op, arg_ids in self.operations: + args = tuple(env[i] for i in arg_ids) + value = op(*args) + env.append(value) + + result = env[-1] + return result + + def as_code(self, name="program"): + """ + Returns Python code text defining a straight-line function equivalent + to this program. + + :param str name: Optional name for the function, defaults to "program". + :returns: A string defining a Python function equivalent to this program.
+ :rtype: str + """ + lines = [ + "# Automatically generated by funsor.compiler.FunsorProgram.as_code().", + "def {}({}):".format(name, ", ".join(self.inputs)), + " from funsor import set_backend, ops", + f" set_backend({repr(self.backend)})", + ] + start = len(lines) + + def let(body): + i = len(lines) - start + lines.append(f" v{i} = {body}") + + for c in self.constants: + let(c) + for name in self.inputs: + let(name) + for op, arg_ids in self.operations: + op = _print_op(op) + args = ", ".join(f"v{arg_id}" for arg_id in arg_ids) + let(f"{op}({args},)") + lines.append(f" return v{len(lines) - start - 1}") + return "\n".join(lines) + + +def make_tuple(*args): + return args + + +def _print_op(op): + if op is make_tuple: + return "" + if op.defaults and op.defaults != type(op)().defaults: + args = ", ".join(map(str, op.defaults.values())) + return f"ops.{type(op).__name__}({args})" + return repr(op) + + +__all__ = [ + "OpProgram", +] diff --git a/funsor/ops/tracer.py b/funsor/ops/tracer.py new file mode 100644 index 000000000..03127655e --- /dev/null +++ b/funsor/ops/tracer.py @@ -0,0 +1,116 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +from collections import OrderedDict +from functools import singledispatch + +from .array import is_numeric_array +from .op import trace_ops +from .program import OpProgram + + +def _debug(x): + return f"{type(x).__module__.split('.')[0]}.{type(x).__name__}({hex(id(x))[2:]})" + + +def trace_function(fn, kwargs: dict, *, allow_constants=False): + """ + Traces function to an :class:`~funsor.ops.program.OpProgram` that runs on + backend values. + + Example:: + + # Create a function involving ops. + def fn(a, b, x): + return ops.add(ops.matmul(a, x), b) + + # Evaluate via Funsor substitution. + data = dict(a=randn(3, 3), b=randn(3), x=randn(3)) + expected = fn(**data) + + # Alternatively evaluate via a program. 
+ program = trace_function(fn, data) + actual = program(**data) + assert (actual == expected).all() + + :param callable fn: A function to trace; it is called as ``fn(**kwargs)``. + :returns: An op program. + :rtype: ~funsor.ops.program.OpProgram + """ + # Extract kwargs. + assert isinstance(kwargs, dict) + assert all(is_variable(v) for v in kwargs.values()) + kwarg_ids = {id(v) for v in kwargs.values()} + assert len(kwarg_ids) == len(kwargs), "repeated inputs" + + # Trace the function. + with trace_ops(is_variable) as trace: + root = fn(**kwargs) + assert is_variable(root) + + # Extract relevant portion of trace. + dag = OrderedDict({id(root): (root, None, None)}) + for result, op, args in reversed(trace.values()): # backward + if id(result) not in dag or not is_variable(result): + continue # not needed + for arg in args: + dag.setdefault(id(arg), (arg, None, None)) + dag[id(result)] = result, op, args + anf = list(reversed(dag.values())) # forward + + # Collect constants (leaves). + ids = {} + constants = [] + for result, op, args in anf: + if op is None and id(result) not in kwarg_ids: + ids[id(result)] = len(ids) + constants.append(result) + if not allow_constants and is_variable(result): + raise ValueError(f"Found constant: {repr(result)}") + + # Collect inputs (leaves). + inputs = [] + for name, value in kwargs.items(): + ids[id(value)] = len(ids) + inputs.append(name) + + # Collect operations to be computed (internal nodes). + operations = [] + for result, op, args in anf: + if id(result) in ids: + continue # constant or free variable + assert op is not None + ids[id(result)] = len(ids) + arg_ids = tuple(ids[id(arg)] for arg in args) + operations.append((op, arg_ids)) + + return OpProgram(constants, inputs, operations) + + +@singledispatch +def is_variable(x): + """ + An object is variable if it is either a backend array or a nested tuple + containing at least one backend array.
+ """ + return is_numeric_array(x) + + +@is_variable.register(int) +def _is_variable_int(x): + return type(x) is not int # allow numpy types + + +@is_variable.register(float) +def _is_variable_float(x): + return type(x) is not float # allow numpy types + + +@is_variable.register(tuple) +def _is_variable_tuple(x): + return any(map(is_variable, x)) + + +__all__ = [ + "trace_function", +] diff --git a/funsor/testing.py b/funsor/testing.py index acedea8f4..beea50a6f 100644 --- a/funsor/testing.py +++ b/funsor/testing.py @@ -197,6 +197,11 @@ def assert_close(actual, expected, atol=1e-6, rtol=1e-6): assert set(actual) == set(expected) for k, actual_v in actual.items(): assert_close(actual_v, expected[k], atol=atol, rtol=rtol) + elif isinstance(actual, tuple): + assert isinstance(expected, tuple) + assert len(actual) == len(expected) + for actual_v, expected_v in zip(actual, expected): + assert_close(actual_v, expected_v, atol=atol, rtol=rtol) else: raise ValueError("cannot compare objects of type {}".format(type(actual))) diff --git a/funsor/util.py b/funsor/util.py index a14ffb44d..043b2a808 100644 --- a/funsor/util.py +++ b/funsor/util.py @@ -167,6 +167,8 @@ def set_backend(backend): :param str backend: either "numpy", "torch", or "jax". """ global _FUNSOR_BACKEND, _JAX_LOADED + if _FUNSOR_BACKEND == backend: + return if backend == "numpy": if _JAX_LOADED: diff --git a/test/test_compiler.py b/test/test_compiler.py new file mode 100644 index 000000000..373029e58 --- /dev/null +++ b/test/test_compiler.py @@ -0,0 +1,94 @@ +# Copyright Contributors to the Pyro project. 
+# SPDX-License-Identifier: Apache-2.0 + +import functools + +import pytest + +import funsor.ops as ops +from funsor.compiler import compile_funsor +from funsor.domains import Real, Reals +from funsor.interpretations import reflect +from funsor.optimizer import apply_optimizer +from funsor.sum_product import sum_product +from funsor.tensor import Tensor +from funsor.terms import Number, Tuple, Variable +from funsor.testing import assert_close, randn + + +@functools.singledispatch +def extract_data(x): + raise TypeError(type(x).__name__) + + +@extract_data.register(Number) +@extract_data.register(Tensor) +def _(x): + return x.data + + +@extract_data.register(Tuple) +def _(x): + return tuple(extract_data(arg) for arg in x.args) + + +def check_compiler(expr): + # Create a random substitution. + subs = {k: randn(d.shape) for k, d in expr.inputs.items()} + expected = expr(**subs) + expected_data = extract_data(expected) + + # Execute a funsor program. + program = compile_funsor(expr) + actual = program(**subs) + assert_close(actual, expected_data) + + # Execute a printed program. 
+ code = program.as_code(name="program2") + print(code) + env = {} + exec(code, None, env) + actual = env["program2"](**subs) + assert_close(actual, expected_data) + + +def test_lowered_1(): + x = Variable("x", Reals[3]) + check_compiler(x) + + +def test_lowered_2(): + x = Variable("x", Reals[3]) + y = x * x + check_compiler(y) + + +def test_lowered_3(): + x = Variable("x", Reals[3]) + y = 1 + x * x + z = y[0] * y[1] + y[2] + check_compiler(z) + + +def test_lowered_4(): + x = Variable("x", Real) + y = Variable("y", Real) + z = Tuple((Number(1), x, y, x * y)) + check_compiler(z) + + +@pytest.mark.xfail(reason="Bound variable lowering is not yet supported") +def test_sum_product(): + factors = [ + Variable("f", Reals[5])["x"], + Variable("g", Reals[5, 4])["x", "y"], + Variable("h", Reals[4, 3, 2])["y", "z", "i"], + ] + eliminate = frozenset({"x", "y", "z", "i"}) + plates = frozenset({"i"}) + with reflect: + expr = sum_product(ops.logaddexp, ops.add, factors, eliminate, plates) + expr = apply_optimizer(expr) + assert set(expr.inputs) == {"f", "g", "h"} + + check_compiler(expr) diff --git a/test/test_tracer.py b/test/test_tracer.py new file mode 100644 index 000000000..992ce8928 --- /dev/null +++ b/test/test_tracer.py @@ -0,0 +1,129 @@ +# Copyright Contributors to the Pyro project. 
+# SPDX-License-Identifier: Apache-2.0 + +from collections import OrderedDict + +import pytest + +import funsor.ops as ops +from funsor.domains import Real, Reals +from funsor.gaussian import Gaussian +from funsor.interpretations import reflect +from funsor.interpreter import reinterpret +from funsor.op_factory import make_op +from funsor.ops.tracer import trace_function +from funsor.optimizer import apply_optimizer +from funsor.sum_product import sum_product +from funsor.tensor import Tensor +from funsor.terms import to_data, to_funsor +from funsor.testing import assert_close, randn + + +def check_tracer(fn, data): + expected = fn(**data) + + traced_fn = trace_function(fn, data) + actual = traced_fn(**data) + assert_close(actual, expected) + + +def test_id(): + def fn(x): + return x + + data = dict(x=randn(2, 3)) + check_tracer(fn, data) + + +def test_chain(): + def fn(x): + for i in range(10): + x = ops.mul(x, x) + return x + + data = dict(x=randn(4)) + check_tracer(fn, data) + + +def test_wrapped_op(): + @make_op + def wrapper(x: Real) -> Real: + # This should not be traced. + return ops.add(x, x) + + def fn(x): + y = ops.add(x, 1) + z = wrapper(y) + return ops.add(z, 1) + + data = dict(x=randn(())) + check_tracer(fn, data) + + +def test_use_funsor_interally_1(): + def fn(x, y, z): + # Convert backend arrays -> to funsors. + x = to_funsor(x) + y = to_funsor(y) + z = to_funsor(z) + + # Operate on funsors. + result = x @ y + z + 1 + + # Convert funsors -> to backend array. + return to_data(result) + + data = dict(x=randn(2, 3), y=randn(3, 4), z=randn(2, 4)) + check_tracer(fn, data) + + +@pytest.mark.xfail(reason="TODO Gaussian directly uses backend, bypassing ops") +def test_use_funsor_interally_2(): + def gaussian_log_prob(info_vec, precision, value): + # Convert backend arrays -> to funsors. + g = Gaussian(info_vec, precision, OrderedDict(x=Reals[3])) + value = to_funsor(value) + + # Operate on funsors. 
+ log_prob = g(x=value) - g.reduce(ops.logaddexp) + + # Convert funsors -> to backend array. + return to_data(log_prob) + + p = randn(3, 3) + precision = p @ p.T + data = dict(info_vec=randn(3), precision=precision, value=randn(3)) + check_tracer(gaussian_log_prob, data) + + +@pytest.mark.xfail(reason="TODO support tuples for multiple outputs") +def test_tuple(): + def fn(x, y): + return (1, x, y, ops.mul(x, y)) + + data = dict(x=randn(3), y=randn(2, 1)) + check_tracer(fn, data) + + +@pytest.mark.xfail( + reason="funsor.cnf._eager_contract_tensors directly calls opt_einsum, bypassing ops" +) +def test_sum_product(): + def fn(f, g, h): + # This function only uses Funsors internally. + factors = [Tensor(f)["x"], Tensor(g)["x", "y"], Tensor(h)["y", "z", "i"]] + eliminate = frozenset({"x", "y", "z", "i"}) + plates = frozenset({"i"}) + with reflect: + expr = sum_product(ops.logaddexp, ops.add, factors, eliminate, plates) + expr = apply_optimizer(expr) + print(f"DEBUG\n{expr.pretty()}") + expr = reinterpret(expr) + return expr.data + + data = dict(f=randn(5), g=randn(5, 4), h=randn(4, 3, 2)) + expected = fn(**data) + + traced_fn = trace_function(fn, data) + actual = traced_fn(**data) + assert_close(actual, expected)