From cec95ad48d7bd26311ae212f683974f4514480df Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Mon, 20 Sep 2021 21:22:21 -0400 Subject: [PATCH 1/4] Fix get_inspect() --- pyro/infer/inspect.py | 33 +++++++++++---------- tests/infer/test_inspect.py | 59 +++++++++++++++++++++++++++++++++++-- 2 files changed, 74 insertions(+), 18 deletions(-) diff --git a/pyro/infer/inspect.py b/pyro/infer/inspect.py index ccb642ebe2..f8a91f65d7 100644 --- a/pyro/infer/inspect.py +++ b/pyro/infer/inspect.py @@ -1,7 +1,7 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 -from typing import Callable, Dict, Optional +from typing import Callable, Dict, List, Optional import torch @@ -37,13 +37,14 @@ def __init__(self, predicate=lambda msg: True): super().__init__() def _pyro_post_sample(self, msg): - if is_sample_site(msg): + if is_sample_site(msg) and msg["value"].dtype.is_floating_point: if self.predicate(msg): msg["value"].requires_grad_() elif not msg["is_observed"] and msg["value"].requires_grad: msg["value"] = msg["value"].detach() +@torch.enable_grad() def get_dependencies( model: Callable, model_args: Optional[tuple] = None, @@ -171,7 +172,7 @@ def model_3(): model_kwargs = {} def get_sample_sites(predicate=lambda msg: True): - with torch.enable_grad(), torch.random.fork_rng(): + with torch.random.fork_rng(): with pyro.validation_enabled(False), RequiresGradMessenger(predicate): trace = poutine.trace(model).get_trace(*model_args, **model_kwargs) return [msg for msg in trace.nodes.values() if is_sample_site(msg)] @@ -187,15 +188,13 @@ def get_sample_sites(predicate=lambda msg: True): # First find transitive dependencies among latent and observed sites prior_dependencies = {n: {n: set()} for n in plates} # no deps yet for i, downstream in enumerate(sample_sites): - upstreams = [u for u in sample_sites[:i] if not u["is_observed"]] + upstreams = [ + u for u in sample_sites[:i] if not u["is_observed"] if u["value"].numel() + ] if not upstreams: continue - grads = torch.autograd.grad( - downstream["fn"].log_prob(downstream["value"]).sum(), - [u["value"] for u in upstreams], - allow_unused=True, - retain_graph=True, - ) + log_prob = downstream["fn"].log_prob(downstream["value"]).sum() + grads = _safe_grad(log_prob, [u["value"] for u in upstreams]) for upstream, grad in zip(upstreams, grads): if grad is not None: d = downstream["name"] @@ -211,12 +210,8 @@ def get_sample_sites(predicate=lambda msg: True): sample_sites_ij = get_sample_sites(lambda msg: msg["name"] in names) d = sample_sites_ij[i] u = sample_sites_ij[j] - grad = torch.autograd.grad( - d["fn"].log_prob(d["value"]).sum(), - [u["value"]], - allow_unused=True, - retain_graph=True, - )[0] + log_prob = d["fn"].log_prob(d["value"]).sum() + grad = _safe_grad(log_prob, [u["value"]])[0] if grad is None: prior_dependencies[d["name"]].pop(u["name"]) @@ -248,6 +243,12 @@ def get_sample_sites(predicate=lambda msg: True): } +def _safe_grad(root: torch.Tensor, args: List[torch.Tensor]): + if not root.requires_grad: + return [None] * len(args) + return torch.autograd.grad(root, args, allow_unused=True, retain_graph=True) + + __all__ = [ "get_dependencies", ] diff --git a/tests/infer/test_inspect.py b/tests/infer/test_inspect.py index a0916ef3ac..eb4b82ac77 100644 --- a/tests/infer/test_inspect.py +++ b/tests/infer/test_inspect.py @@ -8,8 +8,11 @@ from pyro.distributions.testing.fakes import NonreparameterizedNormal from pyro.infer.inspect import get_dependencies +import pytest -def test_get_dependencies(): + +@pytest.mark.parametrize("grad_enabled", [True, False]) +def test_get_dependencies(grad_enabled): def model(data): a = pyro.sample("a", dist.Normal(0, 1)) b = pyro.sample("b", NonreparameterizedNormal(a, 0)) @@ -30,7 +33,8 @@ def model(data): return [a, b, c, d, e, f, g, h, i, j, k] data = torch.randn(3) - actual = get_dependencies(model, (data,)) + with torch.set_grad_enabled(grad_enabled): + actual = get_dependencies(model, (data,)) _ = set() expected = { "prior_dependencies": { @@ -118,6 +122,57 @@ def model_3(): assert actual == expected +def test_factor(): + def model(): + a = pyro.sample("a", dist.Normal(0, 1)) + pyro.factor("b", torch.tensor(0.0)) + pyro.factor("c", a) + + actual = get_dependencies(model) + expected = { + "prior_dependencies": { + "a": {"a": set()}, + "b": {"b": set()}, + "c": {"c": set(), "a": set()}, + }, + "posterior_dependencies": { + "a": {"a": set(), "c": set()}, + }, + } + assert actual == expected + + +def test_discrete_obs(): + def model(): + a = pyro.sample("a", dist.Normal(0, 1)) + b = pyro.sample("b", dist.Normal(a[..., None], torch.ones(3)).to_event(1)) + c = pyro.sample( + "c", dist.MultivariateNormal(torch.zeros(3) + a[..., None], torch.eye(3)) + ) + with pyro.plate("i", 2): + d = pyro.sample("d", dist.Dirichlet((b + c).exp())) + pyro.sample("e", dist.Categorical(logits=d), obs=torch.tensor([0, 0])) + return a, b, c, d + + actual = get_dependencies(model) + expected = { + "prior_dependencies": { + "a": {"a": set()}, + "b": {"a": set(), "b": set()}, + "c": {"a": set(), "c": set()}, + "d": {"b": set(), "c": set(), "d": set()}, + "e": {"d": set(), "e": set()}, + }, + "posterior_dependencies": { + "a": {"a": set(), "b": set(), "c": set()}, + "b": {"b": set(), "c": set(), "d": set()}, + "c": {"c": set(), "d": set()}, + "d": {"d": set(), "e": set()}, + }, + } + assert actual == expected + + def test_plate_coupling(): # x x # || From 2d03add978f54c93009e1c01519612e32fb8eca2 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Mon, 20 Sep 2021 21:27:28 -0400 Subject: [PATCH 2/4] Add more tests --- pyro/distributions/logistic.py | 2 +- pyro/poutine/runtime.py | 14 +++++++++++- tests/common.py | 2 ++ tests/poutine/test_runtime.py | 39 ++++++++++++++++++++++++++++++++++ 4 files changed, 55 insertions(+), 2 deletions(-) create mode 100644 tests/poutine/test_runtime.py diff --git a/pyro/distributions/logistic.py b/pyro/distributions/logistic.py index 354aaef933..77be18f700 100644 --- a/pyro/distributions/logistic.py +++ b/pyro/distributions/logistic.py @@ -42,7 +42,7 @@ def __init__(self, loc, scale, *, validate_args=None): super().__init__(self.loc.shape, validate_args=validate_args) def expand(self, batch_shape, _instance=None): - new = self._get_checked_instance(SkewLogistic, _instance) + new = self._get_checked_instance(Logistic, _instance) batch_shape = torch.Size(batch_shape) new.loc = self.loc.expand(batch_shape) new.scale = self.scale.expand(batch_shape) diff --git a/pyro/poutine/runtime.py b/pyro/poutine/runtime.py index e5a980c895..59b27c8911 100644 --- a/pyro/poutine/runtime.py +++ b/pyro/poutine/runtime.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import functools +from typing import Dict from pyro.params.param_store import ( # noqa: F401 _MODULE_NAMESPACE_DIVIDER, @@ -286,7 +287,7 @@ def _fn(*args, **kwargs): return _fn -def _inspect(): +def _inspect() -> Dict[str, object]: """ EXPERIMENTAL Inspect the Pyro stack. @@ -334,3 +335,14 @@ def model(): :rtype: None, bool, or torch.Tensor """ return _inspect()["mask"] + + +def get_plates() -> tuple: + """ + Records the effects of enclosing ``pyro.plate`` contexts. + + :returns: A tuple of + :class:`pyro.poutine.indep_messenger.CondIndepStackFrame` objects. + :rtype: tuple + """ + return _inspect()["cond_indep_stack"] diff --git a/tests/common.py b/tests/common.py index 100bc83cf6..28708ba8b4 100644 --- a/tests/common.py +++ b/tests/common.py @@ -145,6 +145,8 @@ def assert_tensors_equal(a, b, prec=0.0, msg=""): return b = b.type_as(a) b = b.cuda(device=a.get_device()) if a.is_cuda else b.cpu() + if not a.dtype.is_floating_point: + return (a == b).all() # check that NaNs are in the same locations nan_mask = a != a assert torch.equal(nan_mask, b != b), msg diff --git a/tests/poutine/test_runtime.py b/tests/poutine/test_runtime.py new file mode 100644 index 0000000000..1cd4395287 --- /dev/null +++ b/tests/poutine/test_runtime.py @@ -0,0 +1,39 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +import torch + +import pyro +import pyro.poutine as poutine +from pyro.poutine.runtime import get_mask, get_plates +from tests.common import assert_equal + + +def test_get_mask(): + assert get_mask() is None + + with poutine.mask(mask=True): + assert get_mask() is True + with poutine.mask(mask=False): + assert get_mask() is False + + with pyro.plate("i", 2, dim=-1): + mask1 = torch.tensor([False, True, True]) + mask2 = torch.tensor([True, True, False]) + with poutine.mask(mask=mask1): + assert_equal(get_mask(), mask1) + with poutine.mask(mask=mask2): + assert_equal(get_mask(), mask1 & mask2) + + +def test_get_plates(): + def get_plate_names(): + plates = get_plates() + assert isinstance(plates, tuple) + return {f.name for f in plates} + + assert get_plate_names() == set() + with pyro.plate("foo", 5): + assert get_plate_names() == {"foo"} + with pyro.plate("bar", 3): + assert get_plate_names() == {"foo", "bar"} From d7d4edf6ae82387fa5e36ba77cb72c2abea03612 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Mon, 20 Sep 2021 21:28:03 -0400 Subject: [PATCH 3/4] lint --- tests/infer/test_inspect.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/infer/test_inspect.py b/tests/infer/test_inspect.py index eb4b82ac77..2f600c28a2 100644 --- a/tests/infer/test_inspect.py +++ b/tests/infer/test_inspect.py @@ -1,6 +1,7 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 +import pytest import torch import pyro @@ -8,8 +9,6 @@ from pyro.distributions.testing.fakes import NonreparameterizedNormal from pyro.infer.inspect import get_dependencies -import pytest - @pytest.mark.parametrize("grad_enabled", [True, False]) def test_get_dependencies(grad_enabled): From 3c8b694e1d2d41903165d5182ab85efff0797edb Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Tue, 21 Sep 2021 07:36:39 -0400 Subject: [PATCH 4/4] Ignore failure of coveralls stage --- .github/workflows/ci.yml | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1d194d574a..ed7254553c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -84,13 +84,13 @@ jobs: pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip pip install torch==1.9.0+cpu torchvision==0.10.0+cpu -f https://download.pytorch.org/whl/torch_stable.html pip install .[test] - pip install coveralls + pip install --upgrade coveralls pip freeze - name: Run unit tests run: | pytest -vs --cov=pyro --cov-config .coveragerc --stage unit --durations 20 - name: Submit to coveralls - run: coveralls --service=github + run: coveralls --service=github || true env: COVERALLS_PARALLEL: true COVERALLS_FLAG_NAME: ${{ matrix.test-name }} @@ -116,7 +116,7 @@ jobs: pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip pip install torch==1.9.0+cpu torchvision==0.10.0+cpu -f https://download.pytorch.org/whl/torch_stable.html pip install .[test] - pip install coveralls + pip install --upgrade coveralls pip freeze - name: Run examples run: | @@ -124,7 +124,7 @@ jobs: grep -l smoke_test tutorial/source/*.ipynb | xargs grep -L 'smoke_test = False' \ | CI=1 xargs pytest -vx --nbval-lax --current-env - name: Submit to coveralls - run: coveralls --service=github + run: coveralls --service=github || true env: COVERALLS_PARALLEL: true COVERALLS_FLAG_NAME: ${{ matrix.test-name }} @@ -150,13 +150,13 @@ jobs: pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip pip install torch==1.9.0+cpu torchvision==0.10.0+cpu -f https://download.pytorch.org/whl/torch_stable.html pip install .[test] - pip install coveralls + pip install --upgrade coveralls pip freeze - name: Run integration test (batch 1) run: | pytest -vs --cov=pyro --cov-config .coveragerc --stage integration_batch_1 --durations 10 - name: Submit to coveralls - run: coveralls --service=github + run: coveralls --service=github || true env: COVERALLS_PARALLEL: true COVERALLS_FLAG_NAME: ${{ matrix.test-name }} @@ -182,13 +182,13 @@ jobs: pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip pip install torch==1.9.0+cpu torchvision==0.10.0+cpu -f https://download.pytorch.org/whl/torch_stable.html pip install .[test] - pip install coveralls + pip install --upgrade coveralls pip freeze - name: Run integration test (batch 2) run: | pytest -vs --cov=pyro --cov-config .coveragerc --stage integration_batch_2 --durations 10 - name: Submit to coveralls - run: coveralls --service=github + run: coveralls --service=github || true env: COVERALLS_PARALLEL: true COVERALLS_FLAG_NAME: ${{ matrix.test-name }} @@ -215,14 +215,14 @@ jobs: pip install torch==1.9.0+cpu torchvision==0.10.0+cpu -f https://download.pytorch.org/whl/torch_stable.html pip install .[test] pip install -e .[funsor] - pip install coveralls + pip install --upgrade coveralls pip freeze - name: Run funsor tests run: | pytest -vs --cov=pyro --cov-config .coveragerc --stage funsor --durations 10 CI=1 pytest -vs --cov=pyro --cov-config .coveragerc --stage test_examples --durations 10 -k funsor - name: Submit to coveralls - run: coveralls --service=github + run: coveralls --service=github || true env: COVERALLS_PARALLEL: true COVERALLS_FLAG_NAME: ${{ matrix.test-name }}