Skip to content

Commit

Permalink
Adjusts for the fact that IRMH proposal might not be symmetric (#581)
Browse files Browse the repository at this point in the history
* Adjusting for the fact that the IRMH proposal might not be symmetric

* Fixing type annotation

* Adding proposal_logdensity_fn to the interface exposed to users, with default parameter

* Applying pre-commit to all files

* Adding minimal test
  • Loading branch information
ciguaran authored Oct 31, 2023
1 parent f5c0822 commit f5a2a12
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 17 deletions.
24 changes: 20 additions & 4 deletions blackjax/mcmc/random_walk.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,22 +270,28 @@ def kernel(
state: RWState,
logdensity_fn: Callable,
proposal_distribution: Callable,
proposal_logdensity_fn: Optional[Callable] = None,
) -> tuple[RWState, RWInfo]:
"""
Parameters
----------
proposal_distribution
A function that, given a PRNGKey, is able to produce a sample in the same
domain of the target distribution.
proposal_logdensity_fn:
For non-symmetric proposals, a function that returns the log-density
to obtain a given proposal knowing the current state. If it is not
provided we assume the proposal is symmetric.
"""

def proposal_generator(rng_key: PRNGKey, position: ArrayTree):
del position
return proposal_distribution(rng_key)

inner_kernel = build_rmh()
return inner_kernel(rng_key, state, logdensity_fn, proposal_generator)
return inner_kernel(
rng_key, state, logdensity_fn, proposal_generator, proposal_logdensity_fn
)

return kernel

Expand Down Expand Up @@ -318,7 +324,10 @@ class irmh:
proposal_distribution
A Callable that takes a random number generator and produces a new proposal. The
proposal is independent of the sampler's current state.
proposal_logdensity_fn:
For non-symmetric proposals, a function that returns the log-density
to obtain a given proposal knowing the current state. If it is not
provided we assume the proposal is symmetric.
Returns
-------
A ``SamplingAlgorithm``.
Expand All @@ -332,14 +341,21 @@ def __new__( # type: ignore[misc]
cls,
logdensity_fn: Callable,
proposal_distribution: Callable,
proposal_logdensity_fn: Optional[Callable] = None,
) -> SamplingAlgorithm:
kernel = cls.build_kernel()

def init_fn(position: ArrayLikeTree):
return cls.init(position, logdensity_fn)

def step_fn(rng_key: PRNGKey, state):
return kernel(rng_key, state, logdensity_fn, proposal_distribution)
return kernel(
rng_key,
state,
logdensity_fn,
proposal_distribution,
proposal_logdensity_fn,
)

return SamplingAlgorithm(init_fn, step_fn)

Expand Down
53 changes: 40 additions & 13 deletions tests/mcmc/test_random_walk_without_chex.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,31 +49,58 @@ def test_logdensity_accepts(position):


class IRMHTest(unittest.TestCase):
def proposal_distribution(self, key):
return jnp.array([10.0])

def logdensity_accepts(self, position):
"""
a logdensity that gets maximized after the step
"""
return 0.0 if all(position - 10.0 < 1e-10) else 0.5

def test_proposal_is_independent_of_position(self):
"""New position does not depend on previous"""
"""New position does not depend on previous position"""
rng_key = jax.random.key(0)
initial_position = jnp.array([50.0])
other_position = jnp.array([15000.0])

def proposal_distribution(key):
return jnp.array([10.0])

def test_logdensity_accepts(position):
"""
a logdensity that gets maximized after the step
"""
return 0.0 if all(position - 10.0 < 1e-10) else 0.5

step = build_irmh()

for previous_position in [initial_position, other_position]:
new_state, _ = step(
new_state, state_info = step(
rng_key,
RWState(position=previous_position, logdensity=1.0),
test_logdensity_accepts,
proposal_distribution,
self.logdensity_accepts,
self.proposal_distribution,
)
np.testing.assert_allclose(new_state.position, jnp.array([10.0]))
np.testing.assert_allclose(state_info.acceptance_rate, 0.367879, rtol=1e-5)

def test_non_symmetric_proposal(self):
"""
Given that proposal_logdensity_fn is included,
thus the proposal is non-symmetric.
When computing the acceptance of the proposed state
Then proposal_logdensity_fn value is taken into account
"""
rng_key = jax.random.key(0)
initial_position = jnp.array([50.0])

def test_proposal_logdensity(new_state, prev_state):
return 0.1 if all(new_state.position - 10 < 1e-10) else 0.5

step = build_irmh()

for previous_position in [initial_position]:
_, state_info = step(
rng_key,
RWState(position=previous_position, logdensity=1.0),
self.logdensity_accepts,
self.proposal_distribution,
test_proposal_logdensity,
)

np.testing.assert_allclose(state_info.acceptance_rate, 0.246597)


class RMHProposalTest(unittest.TestCase):
Expand Down

0 comments on commit f5a2a12

Please sign in to comment.