diff --git a/tests/sampling/test_parallel.py b/tests/sampling/test_parallel.py index 8c71bcac00..c16489610f 100644 --- a/tests/sampling/test_parallel.py +++ b/tests/sampling/test_parallel.py @@ -228,3 +228,23 @@ def logp(x, mu): with warnings.catch_warnings(): warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning) pm.sample(draws=10, tune=10, step=pm.Metropolis(), cores=2, mp_ctx="spawn") + + +@pytest.mark.parametrize("cores", (1, 2)) +def test_sampling_with_random_generator_matches(cores): + # Regression test for https://github.com/pymc-devs/pymc/issues/7612 + kwargs = { + "chains": 2, + "cores": cores, + "tune": 10, + "draws": 10, + "compute_convergence_checks": False, + "progress_bar": False, + } + with pm.Model() as m: + x = pm.Normal("x") + + post1 = pm.sample(random_seed=np.random.default_rng(42), **kwargs).posterior + post2 = pm.sample(random_seed=np.random.default_rng(42), **kwargs).posterior + + assert post1.equals(post2), (post1["x"].mean().item(), post2["x"].mean().item())