-
Notifications
You must be signed in to change notification settings - Fork 21
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add a compiler and tracer, each creating OpPrograms (#557)
* Add a FunsorProgram for running backend computations * Support program constants * Rename vm -> compiler * Start to implement lowering stage * Add docs * Reduce recursion * Support Tuple * Add a .as_code() method * Xfail test * Sketch tracer * Fix docs * Fix test import * Fix a test * Add tests, mark some tests xfail * lint * Ignore irrelevant ops
- Loading branch information
Showing
12 changed files
with
649 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
Compiler & Tracer | ||
----------------- | ||
|
||
.. automodule:: funsor.compiler | ||
:members: | ||
:show-inheritance: | ||
:member-order: bysource | ||
|
||
.. automodule:: funsor.ops.tracer | ||
:members: | ||
:show-inheritance: | ||
:member-order: bysource | ||
|
||
.. automodule:: funsor.ops.program | ||
:members: | ||
:show-inheritance: | ||
:member-order: bysource | ||
:special-members: __call__ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
# Copyright Contributors to the Pyro project. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import functools | ||
|
||
import funsor | ||
|
||
from .cnf import Contraction | ||
from .ops.program import OpProgram, make_tuple | ||
from .tensor import Tensor | ||
from .terms import Binary, Funsor, Number, Tuple, Unary, Variable | ||
|
||
|
||
def compile_funsor(expr: Funsor) -> OpProgram: | ||
""" | ||
Compiles a symbolic :class:`~funsor.terms.Funsor` to an | ||
:class:`~funsor.ops.program.OpProgram` that runs on backend values. | ||
Example:: | ||
# Create a lazy expression. | ||
a = Variable("a", Reals[3, 3]) | ||
b = Variable("b", Reals[3]) | ||
x = Variable("x", Reals[3]) | ||
expr = a @ x + b | ||
# Evaluate via Funsor substitution. | ||
data = dict(a=randn(3, 3), b=randn(3), x=randn(3)) | ||
expected = expr(**data).data | ||
# Alternatively evaluate via a program. | ||
program = compile_funsor(expr) | ||
actual = program(**data) | ||
assert (acutal == expected).all() | ||
:param Funsor expr: A funsor expression to evaluate. | ||
:returns: An op program. | ||
:rtype: ~funsor.ops.program.OpProgram | ||
""" | ||
assert isinstance(expr, Funsor) | ||
|
||
# Lower and convert to A-normal form. | ||
lowered_expr = lower(expr) | ||
anf = list(funsor.interpreter.anf(lowered_expr)) | ||
ids = {} | ||
|
||
# Collect constants (leaves). | ||
constants = [] | ||
for f in anf: | ||
if isinstance(f, (Number, Tensor)): | ||
ids[f] = len(ids) | ||
constants.append(f.data) | ||
|
||
# Collect input variables (leaves). | ||
inputs = [] | ||
for k, d in expr.inputs.items(): | ||
f = Variable(k, d) | ||
ids[f] = len(ids) | ||
inputs.append(k) | ||
|
||
# Collect operations to be computed (internal nodes). | ||
operations = [] | ||
for f in anf: | ||
if f in ids: | ||
continue # constant or free variable | ||
ids[f] = len(ids) | ||
if isinstance(f, Unary): | ||
arg_ids = (ids[f.arg],) | ||
operations.append((f.op, arg_ids)) | ||
elif isinstance(f, Binary): | ||
arg_ids = (ids[f.lhs], ids[f.rhs]) | ||
operations.append((f.op, arg_ids)) | ||
elif isinstance(f, Tuple): | ||
arg_ids = tuple(ids[arg] for arg in f.args) | ||
operations.append((make_tuple, arg_ids)) | ||
elif isinstance(f, tuple): | ||
continue # Skip from Tuple directly to its elements. | ||
else: | ||
raise NotImplementedError(type(f).__name__) | ||
|
||
return OpProgram(constants, inputs, operations) | ||
|
||
|
||
def lower(expr: Funsor) -> Funsor: | ||
""" | ||
Lower a funsor expression: | ||
- eliminate bound variables | ||
- convert Contraction to Binary | ||
:param Funsor expr: An arbitrary funsor expression. | ||
:returns: A lowered funsor expression. | ||
:rtype: Funsor | ||
""" | ||
# FIXME should this be lazy? What about Lambda? | ||
with funsor.interpretations.reflect: | ||
return _lower(expr) | ||
|
||
|
||
@functools.singledispatch | ||
def _lower(x): | ||
raise NotImplementedError(type(x).__name__) | ||
|
||
|
||
@_lower.register(Number) | ||
@_lower.register(Tensor) | ||
@_lower.register(Variable) | ||
def _lower_atom(x): | ||
return x | ||
|
||
|
||
@_lower.register(Tuple) | ||
def _lower_tuple(x): | ||
args = tuple(_lower(arg) for arg in x.args) | ||
return Tuple(args) | ||
|
||
|
||
@_lower.register(Unary) | ||
def _lower_unary(x): | ||
arg = _lower(x.arg) | ||
return Unary(x.op, arg) | ||
|
||
|
||
@_lower.register(Binary) | ||
def _lower_binary(x): | ||
lhs = _lower(x.lhs) | ||
rhs = _lower(x.rhs) | ||
return Binary(x.op, lhs, rhs) | ||
|
||
|
||
@_lower.register(Contraction) | ||
def _lower_contraction(x): | ||
if x.reduced_vars: | ||
raise NotImplementedError("TODO") | ||
|
||
terms = [_lower(term) for term in x.terms] | ||
bin_op = functools.partial(Binary, x.bin_op) | ||
return functools.reduce(bin_op, terms) | ||
|
||
|
||
__all__ = [ | ||
"lower", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,9 @@ | ||
# Copyright Contributors to the Pyro project. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from . import array, builtin, op | ||
from . import array, builtin, op, program, tracer | ||
from .array import * | ||
from .builtin import * | ||
from .op import * | ||
from .program import * | ||
from .tracer import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
# Copyright Contributors to the Pyro project. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from funsor.util import get_backend, set_backend | ||
|
||
|
||
class OpProgram: | ||
""" | ||
Backend program for evaluating a symbolic funsor expression. | ||
Programs depend on the funsor library only via ``funsor.ops`` and op | ||
registrations; program evaluation does not involve funsor interpretation or | ||
rewriting. Programs can be pickled and unpickled. | ||
:param iterable expr: A list of built-in constants (leaves). | ||
:param iterable inputs: A list of string names of program inputs (leaves). | ||
:param iterable operations: A list of program operations defining | ||
non-leaf nodes in the program dag. Each operations is a tuple ``(op, | ||
arg_ids)`` where op is a funsor op and ``arg_ids`` is a tuple of | ||
positions of values, starting from zero and counting: constants, | ||
inputs, and operation outputs. | ||
""" | ||
|
||
def __init__(self, constants, inputs, operations): | ||
super().__init__() | ||
self.constants = tuple(constants) | ||
self.inputs = tuple(inputs) | ||
self.operations = tuple(operations) | ||
self.backend = get_backend() | ||
|
||
def __call__(self, **kwargs): | ||
set_backend(self.backend) | ||
|
||
# Initialize environment with constants. | ||
env = list(self.constants) | ||
|
||
# Read inputs from kwargs. | ||
for name in self.inputs: | ||
value = kwargs.pop(name, None) | ||
if value is None: | ||
raise ValueError(f"Missing kwarg: {repr(name)}") | ||
env.append(value) | ||
if kwargs: | ||
raise ValueError(f"Unrecognized kwargs: {set(kwargs)}") | ||
|
||
# Sequentially compute ops. | ||
for op, arg_ids in self.operations: | ||
args = tuple(env[i] for i in arg_ids) | ||
value = op(*args) | ||
env.append(value) | ||
|
||
result = env[-1] | ||
return result | ||
|
||
def as_code(self, name="program"): | ||
""" | ||
Returns Python code text defining a straight-line function equivalent | ||
to this program. | ||
:param str name: Optional name for the function, defaults to "program". | ||
:returns: A string defining a python function equivalent to this program. | ||
:rtype: str | ||
""" | ||
lines = [ | ||
"# Automatically generated by funsor.compiler.FunsorProgram.as_code().", | ||
"def {}({}):".format(name, ", ".join(self.inputs)), | ||
" from funsor import set_backend, ops", | ||
f" set_backend({repr(self.backend)})", | ||
] | ||
start = len(lines) | ||
|
||
def let(body): | ||
i = len(lines) - start | ||
lines.append(f" v{i} = {body}") | ||
|
||
for c in self.constants: | ||
let(c) | ||
for name in self.inputs: | ||
let(name) | ||
for op, arg_ids in self.operations: | ||
op = _print_op(op) | ||
args = ", ".join(f"v{arg_id}" for arg_id in arg_ids) | ||
let(f"{op}({args},)") | ||
lines.append(f" return v{len(lines) - start - 1}") | ||
return "\n".join(lines) | ||
|
||
|
||
def make_tuple(*args): | ||
return args | ||
|
||
|
||
def _print_op(op): | ||
if op is make_tuple: | ||
return "" | ||
if op.defaults and op.defaults != type(op)().defaults: | ||
args = ", ".join(map(str, op.defaults.values())) | ||
return f"ops.{type(op).__name__}({args})" | ||
return repr(op) | ||
|
||
|
||
__all__ = [ | ||
"OpProgram", | ||
] |
Oops, something went wrong.