Skip to content

Commit

Permalink
Dynamic HMC (#580)
Browse files Browse the repository at this point in the history
* dynamic HMC kernel builder and state

* tests and docstring
  • Loading branch information
albcab authored Oct 30, 2023
1 parent 29dc2eb commit f5c0822
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 1 deletion.
93 changes: 92 additions & 1 deletion blackjax/mcmc/hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,16 @@
from blackjax.mcmc.trajectory import hmc_energy
from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey

__all__ = ["HMCState", "HMCInfo", "init", "build_kernel", "hmc"]
__all__ = [
"HMCState",
"DynamicHMCState",
"HMCInfo",
"init",
"init_dynamic",
"build_kernel",
"build_dynamic_kernel",
"hmc",
]


class HMCState(NamedTuple):
Expand All @@ -41,6 +50,20 @@ class HMCState(NamedTuple):
logdensity_grad: ArrayTree


class DynamicHMCState(NamedTuple):
"""State of the dynamic HMC algorithm.
Adds a utility array for generating a pseudo or quasi-random sequence of
number of integration steps.
"""

position: ArrayTree
logdensity: float
logdensity_grad: ArrayTree
random_generator_arg: Array


class HMCInfo(NamedTuple):
"""Additional information on the HMC transition.
Expand Down Expand Up @@ -84,6 +107,13 @@ def init(position: ArrayLikeTree, logdensity_fn: Callable):
return HMCState(position, logdensity, logdensity_grad)


def init_dynamic(
position: ArrayLikeTree, logdensity_fn: Callable, random_generator_arg: Array
):
logdensity, logdensity_grad = jax.value_and_grad(logdensity_fn)(position)
return DynamicHMCState(position, logdensity, logdensity_grad, random_generator_arg)


def build_kernel(
integrator: Callable = integrators.velocity_verlet,
divergence_threshold: float = 1000,
Expand Down Expand Up @@ -145,6 +175,67 @@ def kernel(
return kernel


def build_dynamic_kernel(
integrator: Callable = integrators.velocity_verlet,
divergence_threshold: float = 1000,
next_random_arg_fn: Callable = lambda key: jax.random.split(key)[1],
integration_steps_fn: Callable = lambda key: jax.random.randint(key, (), 1, 10),
):
"""Build a Dynamic HMC kernel where the number of integration steps is chosen randomly.
Parameters
----------
integrator
The symplectic integrator to use to integrate the Hamiltonian dynamics.
divergence_threshold
Value of the difference in energy above which we consider that the transition is divergent.
next_random_arg_fn
Function that generates the next `random_generator_arg` from its previous value.
integration_steps_fn
Function that generates the next pseudo or quasi-random number of integration steps in the
sequence, given the current `random_generator_arg`.
Returns
-------
A kernel that takes a rng_key and a Pytree that contains the current state
of the chain and that returns a new state of the chain along with
information about the transition.
"""
hmc_base = build_kernel(integrator, divergence_threshold)

def kernel(
rng_key: PRNGKey,
state: DynamicHMCState,
logdensity_fn: Callable,
step_size: float,
inverse_mass_matrix: Array,
) -> tuple[DynamicHMCState, HMCInfo]:
"""Generate a new sample with the HMC kernel."""
num_integration_steps = integration_steps_fn(state.random_generator_arg)
hmc_state = HMCState(state.position, state.logdensity, state.logdensity_grad)
hmc_proposal, info = hmc_base(
rng_key,
hmc_state,
logdensity_fn,
step_size,
inverse_mass_matrix,
num_integration_steps,
)
next_random_arg = next_random_arg_fn(state.random_generator_arg)
return (
DynamicHMCState(
hmc_proposal.position,
hmc_proposal.logdensity,
hmc_proposal.logdensity_grad,
next_random_arg,
),
info,
)

return kernel


class hmc:
"""Implements the (basic) user interface for the HMC kernel.
Expand Down
31 changes: 31 additions & 0 deletions tests/mcmc/test_trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import numpy as np
from absl.testing import absltest, parameterized

import blackjax.mcmc.hmc as hmc
import blackjax.mcmc.integrators as integrators
import blackjax.mcmc.metrics as metrics
import blackjax.mcmc.proposal as proposal
Expand Down Expand Up @@ -292,6 +293,36 @@ def test_static_integration_variable_num_steps(self):
scan_state,
)

def test_dynamic_hmc_integration_steps(self):
rng_key = jax.random.key(0)
num_step_key, sample_key = jax.random.split(rng_key)
initial_position = jnp.array(3.0)
parameters = {"step_size": 3.9, "inverse_mass_matrix": jnp.array([1.0])}

unique_integration_steps = jnp.asarray([5, 10, 20])
unique_probs = jnp.asarray([0.1, 0.8, 0.1])
num_step_fn = lambda key: jax.random.choice(
key, unique_integration_steps, p=unique_probs
)
kernel_factory = hmc.build_dynamic_kernel(integration_steps_fn=num_step_fn)

logprob = jax.scipy.stats.norm.logpdf
hmc_kernel = lambda key, state: kernel_factory(
key, state, logprob, **parameters
)
init_state = hmc.init_dynamic(initial_position, logprob, num_step_key)

def one_step(state, rng_key):
state, info = hmc_kernel(rng_key, state)
return state, info

num_iter = 1000
keys = jax.random.split(sample_key, num_iter)
_, infos = jax.lax.scan(one_step, init_state, keys)
_, unique_counts = np.unique(infos.num_integration_steps, return_counts=True)

np.testing.assert_allclose(unique_counts / num_iter, unique_probs, rtol=1e-1)


if __name__ == "__main__":
absltest.main()

0 comments on commit f5c0822

Please sign in to comment.