Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix get_dependencies(), add poutine.get_plates() #2933

Merged
merged 4 commits into from
Sep 21, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,13 @@ jobs:
pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip
pip install torch==1.9.0+cpu torchvision==0.10.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
pip install .[test]
pip install coveralls
pip install --upgrade coveralls
pip freeze
- name: Run unit tests
run: |
pytest -vs --cov=pyro --cov-config .coveragerc --stage unit --durations 20
- name: Submit to coveralls
run: coveralls --service=github
run: coveralls --service=github || true
env:
COVERALLS_PARALLEL: true
COVERALLS_FLAG_NAME: ${{ matrix.test-name }}
Expand All @@ -116,15 +116,15 @@ jobs:
pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip
pip install torch==1.9.0+cpu torchvision==0.10.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
pip install .[test]
pip install coveralls
pip install --upgrade coveralls
pip freeze
- name: Run examples
run: |
CI=1 pytest -vs --cov=pyro --cov-config .coveragerc --stage test_examples --durations 10
grep -l smoke_test tutorial/source/*.ipynb | xargs grep -L 'smoke_test = False' \
| CI=1 xargs pytest -vx --nbval-lax --current-env
- name: Submit to coveralls
run: coveralls --service=github
run: coveralls --service=github || true
env:
COVERALLS_PARALLEL: true
COVERALLS_FLAG_NAME: ${{ matrix.test-name }}
Expand All @@ -150,13 +150,13 @@ jobs:
pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip
pip install torch==1.9.0+cpu torchvision==0.10.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
pip install .[test]
pip install coveralls
pip install --upgrade coveralls
pip freeze
- name: Run integration test (batch 1)
run: |
pytest -vs --cov=pyro --cov-config .coveragerc --stage integration_batch_1 --durations 10
- name: Submit to coveralls
run: coveralls --service=github
run: coveralls --service=github || true
env:
COVERALLS_PARALLEL: true
COVERALLS_FLAG_NAME: ${{ matrix.test-name }}
Expand All @@ -182,13 +182,13 @@ jobs:
pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip
pip install torch==1.9.0+cpu torchvision==0.10.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
pip install .[test]
pip install coveralls
pip install --upgrade coveralls
pip freeze
- name: Run integration test (batch 2)
run: |
pytest -vs --cov=pyro --cov-config .coveragerc --stage integration_batch_2 --durations 10
- name: Submit to coveralls
run: coveralls --service=github
run: coveralls --service=github || true
env:
COVERALLS_PARALLEL: true
COVERALLS_FLAG_NAME: ${{ matrix.test-name }}
Expand All @@ -215,14 +215,14 @@ jobs:
pip install torch==1.9.0+cpu torchvision==0.10.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
pip install .[test]
pip install -e .[funsor]
pip install coveralls
pip install --upgrade coveralls
pip freeze
- name: Run funsor tests
run: |
pytest -vs --cov=pyro --cov-config .coveragerc --stage funsor --durations 10
CI=1 pytest -vs --cov=pyro --cov-config .coveragerc --stage test_examples --durations 10 -k funsor
- name: Submit to coveralls
run: coveralls --service=github
run: coveralls --service=github || true
env:
COVERALLS_PARALLEL: true
COVERALLS_FLAG_NAME: ${{ matrix.test-name }}
Expand Down
2 changes: 1 addition & 1 deletion pyro/distributions/logistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def __init__(self, loc, scale, *, validate_args=None):
super().__init__(self.loc.shape, validate_args=validate_args)

def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(SkewLogistic, _instance)
new = self._get_checked_instance(Logistic, _instance)
batch_shape = torch.Size(batch_shape)
new.loc = self.loc.expand(batch_shape)
new.scale = self.scale.expand(batch_shape)
Expand Down
33 changes: 17 additions & 16 deletions pyro/infer/inspect.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from typing import Callable, Dict, Optional
from typing import Callable, Dict, List, Optional

import torch

Expand Down Expand Up @@ -37,13 +37,14 @@ def __init__(self, predicate=lambda msg: True):
super().__init__()

def _pyro_post_sample(self, msg):
if is_sample_site(msg):
if is_sample_site(msg) and msg["value"].dtype.is_floating_point:
if self.predicate(msg):
msg["value"].requires_grad_()
elif not msg["is_observed"] and msg["value"].requires_grad:
msg["value"] = msg["value"].detach()


@torch.enable_grad()
def get_dependencies(
model: Callable,
model_args: Optional[tuple] = None,
Expand Down Expand Up @@ -171,7 +172,7 @@ def model_3():
model_kwargs = {}

def get_sample_sites(predicate=lambda msg: True):
with torch.enable_grad(), torch.random.fork_rng():
with torch.random.fork_rng():
with pyro.validation_enabled(False), RequiresGradMessenger(predicate):
trace = poutine.trace(model).get_trace(*model_args, **model_kwargs)
return [msg for msg in trace.nodes.values() if is_sample_site(msg)]
Expand All @@ -187,15 +188,13 @@ def get_sample_sites(predicate=lambda msg: True):
# First find transitive dependencies among latent and observed sites
prior_dependencies = {n: {n: set()} for n in plates} # no deps yet
for i, downstream in enumerate(sample_sites):
upstreams = [u for u in sample_sites[:i] if not u["is_observed"]]
upstreams = [
u for u in sample_sites[:i] if not u["is_observed"] if u["value"].numel()
]
if not upstreams:
continue
grads = torch.autograd.grad(
downstream["fn"].log_prob(downstream["value"]).sum(),
[u["value"] for u in upstreams],
allow_unused=True,
retain_graph=True,
)
log_prob = downstream["fn"].log_prob(downstream["value"]).sum()
grads = _safe_grad(log_prob, [u["value"] for u in upstreams])
for upstream, grad in zip(upstreams, grads):
if grad is not None:
d = downstream["name"]
Expand All @@ -211,12 +210,8 @@ def get_sample_sites(predicate=lambda msg: True):
sample_sites_ij = get_sample_sites(lambda msg: msg["name"] in names)
d = sample_sites_ij[i]
u = sample_sites_ij[j]
grad = torch.autograd.grad(
d["fn"].log_prob(d["value"]).sum(),
[u["value"]],
allow_unused=True,
retain_graph=True,
)[0]
log_prob = d["fn"].log_prob(d["value"]).sum()
grad = _safe_grad(log_prob, [u["value"]])[0]
if grad is None:
prior_dependencies[d["name"]].pop(u["name"])

Expand Down Expand Up @@ -248,6 +243,12 @@ def get_sample_sites(predicate=lambda msg: True):
}


def _safe_grad(root: torch.Tensor, args: List[torch.Tensor]):
if not root.requires_grad:
return [None] * len(args)
return torch.autograd.grad(root, args, allow_unused=True, retain_graph=True)


__all__ = [
"get_dependencies",
]
14 changes: 13 additions & 1 deletion pyro/poutine/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

import functools
from typing import Dict

from pyro.params.param_store import ( # noqa: F401
_MODULE_NAMESPACE_DIVIDER,
Expand Down Expand Up @@ -286,7 +287,7 @@ def _fn(*args, **kwargs):
return _fn


def _inspect():
def _inspect() -> Dict[str, object]:
"""
EXPERIMENTAL Inspect the Pyro stack.

Expand Down Expand Up @@ -334,3 +335,14 @@ def model():
:rtype: None, bool, or torch.Tensor
"""
return _inspect()["mask"]


def get_plates() -> tuple:
"""
Records the effects of enclosing ``pyro.plate`` contexts.

:returns: A tuple of
:class:`pyro.poutine.indep_messenger.CondIndepStackFrame` objects.
:rtype: tuple
"""
return _inspect()["cond_indep_stack"]
2 changes: 2 additions & 0 deletions tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,8 @@ def assert_tensors_equal(a, b, prec=0.0, msg=""):
return
b = b.type_as(a)
b = b.cuda(device=a.get_device()) if a.is_cuda else b.cpu()
if not a.dtype.is_floating_point:
return (a == b).all()
# check that NaNs are in the same locations
nan_mask = a != a
assert torch.equal(nan_mask, b != b), msg
Expand Down
58 changes: 56 additions & 2 deletions tests/infer/test_inspect.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import pytest
import torch

import pyro
Expand All @@ -9,7 +10,8 @@
from pyro.infer.inspect import get_dependencies


def test_get_dependencies():
@pytest.mark.parametrize("grad_enabled", [True, False])
def test_get_dependencies(grad_enabled):
def model(data):
a = pyro.sample("a", dist.Normal(0, 1))
b = pyro.sample("b", NonreparameterizedNormal(a, 0))
Expand All @@ -30,7 +32,8 @@ def model(data):
return [a, b, c, d, e, f, g, h, i, j, k]

data = torch.randn(3)
actual = get_dependencies(model, (data,))
with torch.set_grad_enabled(grad_enabled):
actual = get_dependencies(model, (data,))
_ = set()
expected = {
"prior_dependencies": {
Expand Down Expand Up @@ -118,6 +121,57 @@ def model_3():
assert actual == expected


def test_factor():
def model():
a = pyro.sample("a", dist.Normal(0, 1))
pyro.factor("b", torch.tensor(0.0))
pyro.factor("c", a)

actual = get_dependencies(model)
expected = {
"prior_dependencies": {
"a": {"a": set()},
"b": {"b": set()},
"c": {"c": set(), "a": set()},
},
"posterior_dependencies": {
"a": {"a": set(), "c": set()},
},
}
assert actual == expected


def test_discrete_obs():
def model():
a = pyro.sample("a", dist.Normal(0, 1))
b = pyro.sample("b", dist.Normal(a[..., None], torch.ones(3)).to_event(1))
c = pyro.sample(
"c", dist.MultivariateNormal(torch.zeros(3) + a[..., None], torch.eye(3))
)
with pyro.plate("i", 2):
d = pyro.sample("d", dist.Dirichlet((b + c).exp()))
pyro.sample("e", dist.Categorical(logits=d), obs=torch.tensor([0, 0]))
return a, b, c, d

actual = get_dependencies(model)
expected = {
"prior_dependencies": {
"a": {"a": set()},
"b": {"a": set(), "b": set()},
"c": {"a": set(), "c": set()},
"d": {"b": set(), "c": set(), "d": set()},
"e": {"d": set(), "e": set()},
},
"posterior_dependencies": {
"a": {"a": set(), "b": set(), "c": set()},
"b": {"b": set(), "c": set(), "d": set()},
"c": {"c": set(), "d": set()},
"d": {"d": set(), "e": set()},
},
}
assert actual == expected


def test_plate_coupling():
# x x
# ||
Expand Down
39 changes: 39 additions & 0 deletions tests/poutine/test_runtime.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import torch

import pyro
import pyro.poutine as poutine
from pyro.poutine.runtime import get_mask, get_plates
from tests.common import assert_equal


def test_get_mask():
assert get_mask() is None

with poutine.mask(mask=True):
assert get_mask() is True
with poutine.mask(mask=False):
assert get_mask() is False

with pyro.plate("i", 2, dim=-1):
mask1 = torch.tensor([False, True, True])
mask2 = torch.tensor([True, True, False])
with poutine.mask(mask=mask1):
assert_equal(get_mask(), mask1)
with poutine.mask(mask=mask2):
assert_equal(get_mask(), mask1 & mask2)


def test_get_plates():
def get_plate_names():
plates = get_plates()
assert isinstance(plates, tuple)
return {f.name for f in plates}

assert get_plate_names() == set()
with pyro.plate("foo", 5):
assert get_plate_names() == {"foo"}
with pyro.plate("bar", 3):
assert get_plate_names() == {"foo", "bar"}