From d59a960f89873667d6190489ff0e975091e57d10 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 7 Aug 2023 10:26:07 +0200 Subject: [PATCH] Do not include seeded_test fixture in exported BaseTestDistributionRandom --- pymc/testing.py | 2 +- tests/distributions/test_continuous.py | 2 +- tests/distributions/test_multivariate.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pymc/testing.py b/pymc/testing.py index 12200db69be..3eb1b7ba819 100644 --- a/pymc/testing.py +++ b/pymc/testing.py @@ -848,7 +848,7 @@ class BaseTestDistributionRandom: repeated_params_shape = 5 random_state = None - def test_distribution(self, seeded_test): + def test_distribution(self): self.validate_tests_list() if self.pymc_dist == pm.Wishart: with pytest.warns(UserWarning, match="can currently not be used for MCMC sampling"): diff --git a/tests/distributions/test_continuous.py b/tests/distributions/test_continuous.py index 0e21e81d2fa..bab4e281c16 100644 --- a/tests/distributions/test_continuous.py +++ b/tests/distributions/test_continuous.py @@ -1825,7 +1825,7 @@ class TestStudentT(BaseTestDistributionRandom): class TestHalfStudentT(BaseTestDistributionRandom): def halfstudentt_rng_fn(self, df, loc, scale, size, rng): - return np.abs(st.t.rvs(df=df, loc=loc, scale=scale, size=size)) + return np.abs(st.t.rvs(df=df, loc=loc, scale=scale, size=size, random_state=rng)) pymc_dist = pm.HalfStudentT pymc_dist_params = {"nu": 5.0, "sigma": 2.0} diff --git a/tests/distributions/test_multivariate.py b/tests/distributions/test_multivariate.py index 99707fbe870..752f2502914 100644 --- a/tests/distributions/test_multivariate.py +++ b/tests/distributions/test_multivariate.py @@ -1972,7 +1972,7 @@ class TestKroneckerNormal(BaseTestDistributionRandom): def kronecker_rng_fn(self, size, mu, covs=None, sigma=None, rng=None): cov = pm.math.kronecker(covs[0], covs[1]).eval() cov += sigma**2 * np.identity(cov.shape[0]) - return st.multivariate_normal.rvs(mean=mu, cov=cov, size=size) + return st.multivariate_normal.rvs(mean=mu, cov=cov, size=size, random_state=rng) pymc_dist = pm.KroneckerNormal