From a61af36ab77dc9926ce2b54cdb3d25f6ecb05dd4 Mon Sep 17 00:00:00 2001 From: Ismael Mendoza Date: Sun, 1 Sep 2024 16:01:02 -0500 Subject: [PATCH] Add missing mass matrix in missing tests. --- tests/mcmc/test_sampling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index 78421ba83..20348f579 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -889,7 +889,7 @@ def test_mala(self): @chex.all_variants(with_pmap=False) def test_barker(self): inference_algorithm = blackjax.barker_proposal( - self.normal_logprob, step_size=1.5 + self.normal_logprob, step_size=1.5, inverse_mass_matrix=jnp.eye(1) ) initial_state = inference_algorithm.init(jnp.array(1.0)) self.univariate_normal_test_case( @@ -926,7 +926,7 @@ def test_barker(self): }, { "algorithm": blackjax.barker_proposal, - "parameters": {"step_size": 0.5}, + "parameters": {"step_size": 0.5, "inverse_mass_matrix": jnp.eye(2)}, "is_mass_matrix_diagonal": None, }, ]