diff --git a/blackjax/mcmc/random_walk.py b/blackjax/mcmc/random_walk.py index 9d7a0abee..68b550e1e 100644 --- a/blackjax/mcmc/random_walk.py +++ b/blackjax/mcmc/random_walk.py @@ -270,14 +270,18 @@ def kernel( state: RWState, logdensity_fn: Callable, proposal_distribution: Callable, + proposal_logdensity_fn: Optional[Callable] = None, ) -> tuple[RWState, RWInfo]: """ - Parameters ---------- proposal_distribution A function that, given a PRNGKey, is able to produce a sample in the same domain of the target distribution. + proposal_logdensity_fn: + For non-symmetric proposals, a function that returns the log-density + to obtain a given proposal knowing the current state. If it is not + provided we assume the proposal is symmetric. """ def proposal_generator(rng_key: PRNGKey, position: ArrayTree): @@ -285,7 +289,9 @@ def proposal_generator(rng_key: PRNGKey, position: ArrayTree): return proposal_distribution(rng_key) inner_kernel = build_rmh() - return inner_kernel(rng_key, state, logdensity_fn, proposal_generator) + return inner_kernel( + rng_key, state, logdensity_fn, proposal_generator, proposal_logdensity_fn + ) return kernel @@ -318,7 +324,10 @@ class irmh: proposal_distribution A Callable that takes a random number generator and produces a new proposal. The proposal is independent of the sampler's current state. - + proposal_logdensity_fn: + For non-symmetric proposals, a function that returns the log-density + to obtain a given proposal knowing the current state. If it is not + provided we assume the proposal is symmetric. Returns ------- A ``SamplingAlgorithm``. @@ -332,6 +341,7 @@ def __new__( # type: ignore[misc] cls, logdensity_fn: Callable, proposal_distribution: Callable, + proposal_logdensity_fn: Optional[Callable] = None, ) -> SamplingAlgorithm: kernel = cls.build_kernel() @@ -339,7 +349,13 @@ def init_fn(position: ArrayLikeTree): return cls.init(position, logdensity_fn) def step_fn(rng_key: PRNGKey, state): - return kernel(rng_key, state, logdensity_fn, proposal_distribution) + return kernel( + rng_key, + state, + logdensity_fn, + proposal_distribution, + proposal_logdensity_fn, + ) return SamplingAlgorithm(init_fn, step_fn) diff --git a/tests/mcmc/test_random_walk_without_chex.py b/tests/mcmc/test_random_walk_without_chex.py index 8bbcd578e..e5ce69dcc 100644 --- a/tests/mcmc/test_random_walk_without_chex.py +++ b/tests/mcmc/test_random_walk_without_chex.py @@ -49,31 +49,58 @@ def test_logdensity_accepts(position): class IRMHTest(unittest.TestCase): + def proposal_distribution(self, key): + return jnp.array([10.0]) + + def logdensity_accepts(self, position): + """ + a logdensity that gets maximized after the step + """ + return 0.0 if all(position - 10.0 < 1e-10) else 0.5 + def test_proposal_is_independent_of_position(self): - """New position does not depend on previous""" + """New position does not depend on previous position""" rng_key = jax.random.key(0) initial_position = jnp.array([50.0]) other_position = jnp.array([15000.0]) - def proposal_distribution(key): - return jnp.array([10.0]) - - def test_logdensity_accepts(position): - """ - a logdensity that gets maximized after the step - """ - return 0.0 if all(position - 10.0 < 1e-10) else 0.5 - step = build_irmh() for previous_position in [initial_position, other_position]: - new_state, _ = step( + new_state, state_info = step( rng_key, RWState(position=previous_position, logdensity=1.0), - test_logdensity_accepts, - proposal_distribution, + self.logdensity_accepts, + self.proposal_distribution, ) np.testing.assert_allclose(new_state.position, jnp.array([10.0])) + np.testing.assert_allclose(state_info.acceptance_rate, 0.367879, rtol=1e-5) + + def test_non_symmetric_proposal(self): + """ + Given that proposal_logdensity_fn is included, + thus the proposal is non-symmetric. + When computing the acceptance of the proposed state + Then proposal_logdensity_fn value is taken into account + """ + rng_key = jax.random.key(0) + initial_position = jnp.array([50.0]) + + def test_proposal_logdensity(new_state, prev_state): + return 0.1 if all(new_state.position - 10 < 1e-10) else 0.5 + + step = build_irmh() + + for previous_position in [initial_position]: + _, state_info = step( + rng_key, + RWState(position=previous_position, logdensity=1.0), + self.logdensity_accepts, + self.proposal_distribution, + test_proposal_logdensity, + ) + + np.testing.assert_allclose(state_info.acceptance_rate, 0.246597) class RMHProposalTest(unittest.TestCase):