diff --git a/pyro/ops/provenance.py b/pyro/ops/provenance.py index 02a515c7e5..5058e66414 100644 --- a/pyro/ops/provenance.py +++ b/pyro/ops/provenance.py @@ -46,14 +46,15 @@ def __new__(cls, data: torch.Tensor, provenance=frozenset(), **kwargs): assert not isinstance(data, ProvenanceTensor) if not provenance: return data - return super().__new__(cls) + ret = data.as_subclass(cls) + ret._t = data # this makes sure that detach_provenance always + # returns the same object. This is important when + # using the tensor as key in a dict, e.g. the global + # param store + return ret def __init__(self, data, provenance=frozenset()): assert isinstance(provenance, frozenset) - if isinstance(data, ProvenanceTensor): - provenance |= data._provenance - data = data._t - self._t = data self._provenance = provenance def __repr__(self): diff --git a/tests/ops/test_provenance.py b/tests/ops/test_provenance.py new file mode 100644 index 0000000000..818d9d83ca --- /dev/null +++ b/tests/ops/test_provenance.py @@ -0,0 +1,45 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch + +from pyro.ops.provenance import ProvenanceTensor +from tests.common import assert_equal, requires_cuda + + +@requires_cuda +@pytest.mark.parametrize( + "dtype1", + [ + torch.float16, + torch.float32, + torch.float64, + torch.int8, + torch.int16, + torch.int32, + torch.int64, + torch.uint8, + ], +) +@pytest.mark.parametrize( + "dtype2", + [ + torch.float16, + torch.float32, + torch.float64, + torch.int8, + torch.int16, + torch.int32, + torch.int64, + torch.uint8, + ], +) +def test_provenance_tensor(dtype1, dtype2): + device = torch.device("cuda") + x = torch.tensor([1, 2, 3], dtype=dtype1) + y = ProvenanceTensor(x, frozenset(["x"])) + z = torch.as_tensor(y, device=device, dtype=dtype2) + + assert x.shape == y.shape == z.shape + assert_equal(x, z.cpu())