Skip to content

Commit

Permalink
Add a compiler and tracer, each creating OpPrograms (#557)
Browse files Browse the repository at this point in the history
* Add a FunsorProgram for running backend computations

* Support program constants

* Rename vm -> compiler

* Start to implement lowering stage

* Add docs

* Reduce recursion

* Support Tuple

* Add a .as_code() method

* Xfail test

* Sketch tracer

* Fix docs

* Fix test import

* Fix a test

* Add tests, mark some tests xfail

* lint

* Ignore irrelevant ops
  • Loading branch information
fritzo authored Oct 4, 2021
1 parent ca1557b commit dbb2989
Show file tree
Hide file tree
Showing 12 changed files with 649 additions and 2 deletions.
18 changes: 18 additions & 0 deletions docs/source/compiler.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
Compiler & Tracer
-----------------

.. automodule:: funsor.compiler
:members:
:show-inheritance:
:member-order: bysource

.. automodule:: funsor.ops.tracer
:members:
:show-inheritance:
:member-order: bysource

.. automodule:: funsor.ops.program
:members:
:show-inheritance:
:member-order: bysource
:special-members: __call__
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ Funsor is a tensor-like library for functions and distributions
distributions
minipyro
einsum
compiler

.. toctree::
:maxdepth: 1
Expand Down
2 changes: 2 additions & 0 deletions funsor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
affine,
approximations,
cnf,
compiler,
constant,
delta,
distribution,
Expand Down Expand Up @@ -78,6 +79,7 @@
"backward",
"bint",
"cnf",
"compiler",
"constant",
"delta",
"distribution",
Expand Down
142 changes: 142 additions & 0 deletions funsor/compiler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import functools

import funsor

from .cnf import Contraction
from .ops.program import OpProgram, make_tuple
from .tensor import Tensor
from .terms import Binary, Funsor, Number, Tuple, Unary, Variable


def compile_funsor(expr: Funsor) -> OpProgram:
"""
Compiles a symbolic :class:`~funsor.terms.Funsor` to an
:class:`~funsor.ops.program.OpProgram` that runs on backend values.
Example::
# Create a lazy expression.
a = Variable("a", Reals[3, 3])
b = Variable("b", Reals[3])
x = Variable("x", Reals[3])
expr = a @ x + b
# Evaluate via Funsor substitution.
data = dict(a=randn(3, 3), b=randn(3), x=randn(3))
expected = expr(**data).data
# Alternatively evaluate via a program.
program = compile_funsor(expr)
actual = program(**data)
assert (acutal == expected).all()
:param Funsor expr: A funsor expression to evaluate.
:returns: An op program.
:rtype: ~funsor.ops.program.OpProgram
"""
assert isinstance(expr, Funsor)

# Lower and convert to A-normal form.
lowered_expr = lower(expr)
anf = list(funsor.interpreter.anf(lowered_expr))
ids = {}

# Collect constants (leaves).
constants = []
for f in anf:
if isinstance(f, (Number, Tensor)):
ids[f] = len(ids)
constants.append(f.data)

# Collect input variables (leaves).
inputs = []
for k, d in expr.inputs.items():
f = Variable(k, d)
ids[f] = len(ids)
inputs.append(k)

# Collect operations to be computed (internal nodes).
operations = []
for f in anf:
if f in ids:
continue # constant or free variable
ids[f] = len(ids)
if isinstance(f, Unary):
arg_ids = (ids[f.arg],)
operations.append((f.op, arg_ids))
elif isinstance(f, Binary):
arg_ids = (ids[f.lhs], ids[f.rhs])
operations.append((f.op, arg_ids))
elif isinstance(f, Tuple):
arg_ids = tuple(ids[arg] for arg in f.args)
operations.append((make_tuple, arg_ids))
elif isinstance(f, tuple):
continue # Skip from Tuple directly to its elements.
else:
raise NotImplementedError(type(f).__name__)

return OpProgram(constants, inputs, operations)


def lower(expr: Funsor) -> Funsor:
"""
Lower a funsor expression:
- eliminate bound variables
- convert Contraction to Binary
:param Funsor expr: An arbitrary funsor expression.
:returns: A lowered funsor expression.
:rtype: Funsor
"""
# FIXME should this be lazy? What about Lambda?
with funsor.interpretations.reflect:
return _lower(expr)


@functools.singledispatch
def _lower(x):
raise NotImplementedError(type(x).__name__)


@_lower.register(Number)
@_lower.register(Tensor)
@_lower.register(Variable)
def _lower_atom(x):
return x


@_lower.register(Tuple)
def _lower_tuple(x):
args = tuple(_lower(arg) for arg in x.args)
return Tuple(args)


@_lower.register(Unary)
def _lower_unary(x):
arg = _lower(x.arg)
return Unary(x.op, arg)


@_lower.register(Binary)
def _lower_binary(x):
lhs = _lower(x.lhs)
rhs = _lower(x.rhs)
return Binary(x.op, lhs, rhs)


@_lower.register(Contraction)
def _lower_contraction(x):
if x.reduced_vars:
raise NotImplementedError("TODO")

terms = [_lower(term) for term in x.terms]
bin_op = functools.partial(Binary, x.bin_op)
return functools.reduce(bin_op, terms)


__all__ = [
"lower",
]
4 changes: 3 additions & 1 deletion funsor/ops/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from . import array, builtin, op
from . import array, builtin, op, program, tracer
from .array import *
from .builtin import *
from .op import *
from .program import *
from .tracer import *
35 changes: 34 additions & 1 deletion funsor/ops/op.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import contextlib
import functools
import inspect
import math
import operator
import weakref
from collections import OrderedDict

from funsor.registry import PartialDispatcher
from funsor.util import methodof
Expand Down Expand Up @@ -60,6 +62,23 @@ def __call__(self, *args, **kwargs):
return self.fn(arg, *args, **kwargs)


_TRACE = None
_TRACE_FILTER_ARGS = None


@contextlib.contextmanager
def trace_ops(filter_args):
global _TRACE, _TRACE_FILTER_ARGS
assert _TRACE is None, "not reentrant"
try:
_TRACE = OrderedDict()
_TRACE_FILTER_ARGS = filter_args
yield _TRACE
finally:
_TRACE = None
_TRACE_FILTER_ARGS = None


class OpMeta(type):
"""
Metaclass for :class:`Op` classes.
Expand Down Expand Up @@ -159,6 +178,9 @@ def __str__(self):
return self.__name__

def __call__(self, *args, **kwargs):
global _TRACE, _TRACE_FILTER_ARGS
raw_args = args

# Normalize args, kwargs.
cls = type(self)
bound = cls.signature.bind_partial(*args, **kwargs)
Expand All @@ -170,7 +192,18 @@ def __call__(self, *args, **kwargs):

# Dispatch.
fn = cls.dispatcher.partial_call(*args[: cls.arity])
return fn(*args, **kwargs)
if _TRACE is None or not _TRACE_FILTER_ARGS(raw_args):
result = fn(*args, **kwargs)
else:
# Trace this op but avoid tracing internal ops.
try:
trace, _TRACE = _TRACE, None
result = fn(*args, **kwargs)
trace.setdefault(id(result), (result, self, raw_args))
finally:
_TRACE = trace

return result

def register(self, *pattern):
if len(pattern) != self.arity:
Expand Down
103 changes: 103 additions & 0 deletions funsor/ops/program.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from funsor.util import get_backend, set_backend


class OpProgram:
"""
Backend program for evaluating a symbolic funsor expression.
Programs depend on the funsor library only via ``funsor.ops`` and op
registrations; program evaluation does not involve funsor interpretation or
rewriting. Programs can be pickled and unpickled.
:param iterable expr: A list of built-in constants (leaves).
:param iterable inputs: A list of string names of program inputs (leaves).
:param iterable operations: A list of program operations defining
non-leaf nodes in the program dag. Each operations is a tuple ``(op,
arg_ids)`` where op is a funsor op and ``arg_ids`` is a tuple of
positions of values, starting from zero and counting: constants,
inputs, and operation outputs.
"""

def __init__(self, constants, inputs, operations):
super().__init__()
self.constants = tuple(constants)
self.inputs = tuple(inputs)
self.operations = tuple(operations)
self.backend = get_backend()

def __call__(self, **kwargs):
set_backend(self.backend)

# Initialize environment with constants.
env = list(self.constants)

# Read inputs from kwargs.
for name in self.inputs:
value = kwargs.pop(name, None)
if value is None:
raise ValueError(f"Missing kwarg: {repr(name)}")
env.append(value)
if kwargs:
raise ValueError(f"Unrecognized kwargs: {set(kwargs)}")

# Sequentially compute ops.
for op, arg_ids in self.operations:
args = tuple(env[i] for i in arg_ids)
value = op(*args)
env.append(value)

result = env[-1]
return result

def as_code(self, name="program"):
"""
Returns Python code text defining a straight-line function equivalent
to this program.
:param str name: Optional name for the function, defaults to "program".
:returns: A string defining a python function equivalent to this program.
:rtype: str
"""
lines = [
"# Automatically generated by funsor.compiler.FunsorProgram.as_code().",
"def {}({}):".format(name, ", ".join(self.inputs)),
" from funsor import set_backend, ops",
f" set_backend({repr(self.backend)})",
]
start = len(lines)

def let(body):
i = len(lines) - start
lines.append(f" v{i} = {body}")

for c in self.constants:
let(c)
for name in self.inputs:
let(name)
for op, arg_ids in self.operations:
op = _print_op(op)
args = ", ".join(f"v{arg_id}" for arg_id in arg_ids)
let(f"{op}({args},)")
lines.append(f" return v{len(lines) - start - 1}")
return "\n".join(lines)


def make_tuple(*args):
return args


def _print_op(op):
if op is make_tuple:
return ""
if op.defaults and op.defaults != type(op)().defaults:
args = ", ".join(map(str, op.defaults.values()))
return f"ops.{type(op).__name__}({args})"
return repr(op)


__all__ = [
"OpProgram",
]
Loading

0 comments on commit dbb2989

Please sign in to comment.