diff --git a/blackjax/__init__.py b/blackjax/__init__.py
index 9016d2a0e..c6b48712a 100644
--- a/blackjax/__init__.py
+++ b/blackjax/__init__.py
@@ -1,5 +1,6 @@
 from blackjax._version import __version__
 
+from .adaptation.chees_adaptation import chees_adaptation
 from .adaptation.meads_adaptation import meads_adaptation
 from .adaptation.pathfinder_adaptation import pathfinder_adaptation
 from .adaptation.window_adaptation import window_adaptation
@@ -44,6 +45,7 @@
     "csgld",
     "window_adaptation",  # mcmc adaptation
     "meads_adaptation",
+    "chees_adaptation",
     "pathfinder_adaptation",
     "adaptive_tempered_smc",  # smc
     "tempered_smc",
diff --git a/blackjax/adaptation/__init__.py b/blackjax/adaptation/__init__.py
index edf6bcd13..91a491ed0 100644
--- a/blackjax/adaptation/__init__.py
+++ b/blackjax/adaptation/__init__.py
@@ -1,3 +1,13 @@
-from . import meads_adaptation, pathfinder_adaptation, window_adaptation
+from . import (
+    chees_adaptation,
+    meads_adaptation,
+    pathfinder_adaptation,
+    window_adaptation,
+)
 
-__all__ = ["meads_adaptation", "window_adaptation", "pathfinder_adaptation"]
+__all__ = [
+    "chees_adaptation",
+    "meads_adaptation",
+    "window_adaptation",
+    "pathfinder_adaptation",
+]
diff --git a/blackjax/adaptation/chees_adaptation.py b/blackjax/adaptation/chees_adaptation.py
new file mode 100644
index 000000000..ff7981cbf
--- /dev/null
+++ b/blackjax/adaptation/chees_adaptation.py
@@ -0,0 +1,429 @@
+"""Public API for ChEES-HMC"""
+
+from typing import Callable, NamedTuple, Optional
+
+import jax
+import jax.numpy as jnp
+import optax
+import scipy.stats.qmc as qmc
+
+import blackjax.mcmc.hmc as hmc
+import blackjax.optimizers.dual_averaging as dual_averaging
+from blackjax.adaptation.base import AdaptationInfo, AdaptationResults
+from blackjax.base import AdaptationAlgorithm
+from blackjax.types import Array, ArrayLikeTree, PRNGKey
+
+
+class ChEESState(NamedTuple):
+    """State of the jittered HMC algorithm.
+
+    The jittered HMC algorithm extends the HMC state by including the
+    current iteration of the sampler, used to select the pseudo-random
+    jittering value for the current step.
+    """
+
+    state: hmc.HMCState
+    current_iteration: int
+
+
+class ChEESAdaptationState(NamedTuple):
+    """State of the ChEES-HMC adaptation scheme.
+
+    step_size
+        Value of the step_size parameter of the HMC algorithm.
+    step_size_moving_average
+        Running moving average of the step_size parameter.
+    trajectory_length
+        Value of the num_integration_steps * step_size parameter of
+        the HMC algorithm.
+    trajectory_length_moving_average
+        Running moving average of the num_integration_steps * step_size
+        parameter.
+    da_state
+        Dual averaging state used to adapt the step size.
+    optim_state
+        Optax optimizer state used to maximize the ChEES criterion.
+    current_iteration
+        Current iteration of the warmup, used to find the pseudo-random
+        jittering value.
+
+    """
+
+    step_size: float
+    step_size_moving_average: float
+    trajectory_length: float
+    trajectory_length_moving_average: float
+    da_state: dual_averaging.DualAveragingState
+    optim_state: optax.OptState
+    current_iteration: int
+
+
+def base(halton_sequence: Array, optim: optax.GradientTransformation):
+    """Maximizing the Change in the Estimator of the Expected Square criterion
+    (trajectory length) and dual averaging procedure (step size) for the jittered
+    Hamiltonian Monte Carlo kernel [1]_.
+
+    This adaptation algorithm tunes the step size and trajectory length, i.e.
+    number of integration steps * step size, of the jittered HMC algorithm based
+    on statistics collected from a population of many chains. It maximizes the Change
+    in the Estimator of the Expected Square (ChEES) criterion to tune the trajectory
+    length and uses dual averaging targeting an acceptance rate of 0.651 of the harmonic
+    mean of the chain's acceptance probabilities to tune the step size.
+
+    Returns
+    -------
+    init
+        Function that initializes the warmup.
+    update
+        Function that moves the warmup one step.
+
+    References
+    ----------
+    .. [1]: Hoffman, M., Radul, A., & Sountsov, P. (2021). An adaptive-MCMC scheme
+        for setting trajectory lengths in Hamiltonian Monte Carlo. 130:3907-3915.
+    """
+
+    da_init, da_update, _ = dual_averaging.dual_averaging()
+
+    def compute_parameters(
+        proposed_positions: ArrayLikeTree,
+        proposed_momentums: ArrayLikeTree,
+        initial_positions: ArrayLikeTree,
+        acceptance_probabilities: Array,
+        is_divergent: Array,
+        initial_adaptation_state: ChEESAdaptationState,
+    ):
+        """Compute values for the parameters based on statistics collected from
+        multiple chains.
+
+        Parameters
+        ----------
+        proposed_positions:
+            A PyTree that contains the position proposed by the HMC algorithm of
+            every chain (proposal that is accepted or rejected using MH).
+        proposed_momentums:
+            A PyTree that contains the momentum variable proposed by the HMC algorithm
+            of every chain (proposal that is accepted or rejected using MH).
+        initial_positions:
+            A PyTree that contains the initial position at the start of the HMC
+            algorithm of every chain.
+        acceptance_probabilities:
+            Metropolis-Hastings acceptance probability of proposals of every chain.
+        is_divergent:
+            Divergence flag of every chain; flagged chains are masked out of the averages.
+        initial_adaptation_state:
+            ChEES adaptation step used to generate proposals and acceptance probabilities.
+
+        Returns
+        -------
+        New values of the step size and trajectory length of the jittered HMC algorithm.
+
+        """
+        (
+            step_size,
+            step_size_ma,
+            trajectory_length,
+            trajectory_length_ma,
+            da_state,
+            optim_state,
+            current_iteration,
+        ) = initial_adaptation_state
+
+        harmonic_mean = 1.0 / jnp.mean(
+            1.0 / acceptance_probabilities, where=~is_divergent
+        )
+        da_state_ = da_update(da_state, 0.651 - harmonic_mean)
+        step_size_ = jnp.exp(da_state_.log_x)
+        new_step_size, new_da_state = jax.lax.cond(
+            jnp.isfinite(step_size_),
+            lambda _: (step_size_, da_state_),
+            lambda _: (step_size, da_state),
+            None,
+        )
+        new_step_size_ma = 0.9 * step_size_ma + 0.1 * new_step_size
+
+        proposals_mean = jax.tree_util.tree_map(
+            lambda p: jnp.nanmean(p, axis=0), proposed_positions
+        )
+        initials_mean = jax.tree_util.tree_map(
+            lambda p: jnp.nanmean(p, axis=0), initial_positions
+        )
+        proposals_centered = jax.tree_util.tree_map(
+            lambda p, pm: p - pm, proposed_positions, proposals_mean
+        )
+        initials_centered = jax.tree_util.tree_map(
+            lambda p, pm: p - pm, initial_positions, initials_mean
+        )
+
+        proposals_matrix = jax.vmap(lambda p: jax.flatten_util.ravel_pytree(p)[0])(
+            proposals_centered
+        )
+        initials_matrix = jax.vmap(lambda p: jax.flatten_util.ravel_pytree(p)[0])(
+            initials_centered
+        )
+        momentums_matrix = jax.vmap(lambda m: jax.flatten_util.ravel_pytree(m)[0])(
+            proposed_momentums
+        )
+
+        trajectory_gradients = (
+            halton_sequence[current_iteration]
+            * trajectory_length
+            * (
+                jax.vmap(lambda p: jnp.dot(p, p))(proposals_matrix)
+                - jax.vmap(lambda p: jnp.dot(p, p))(initials_matrix)
+            )
+            * jax.vmap(lambda p, m: jnp.dot(p, m))(proposals_matrix, momentums_matrix)
+        )
+        trajectory_gradient = jnp.sum(
+            acceptance_probabilities * trajectory_gradients, where=~is_divergent
+        ) / jnp.sum(acceptance_probabilities, where=~is_divergent)
+
+        log_trajectory_length = jnp.log(trajectory_length)
+        updates, optim_state_ = optim.update(
+            trajectory_gradient, optim_state, log_trajectory_length
+        )
+        log_trajectory_length_ = optax.apply_updates(log_trajectory_length, updates)
+        new_log_trajectory_length, new_optim_state = jax.lax.cond(
+            jnp.isfinite(
+                jax.flatten_util.ravel_pytree(log_trajectory_length_)[0]
+            ).all(),
+            lambda _: (log_trajectory_length_, optim_state_),
+            lambda _: (log_trajectory_length, optim_state),
+            None,
+        )
+        new_trajectory_length = jnp.exp(new_log_trajectory_length)
+        new_trajectory_length_ma = (
+            0.9 * trajectory_length_ma + 0.1 * new_trajectory_length
+        )
+
+        return (
+            new_step_size,
+            new_step_size_ma,
+            new_trajectory_length,
+            new_trajectory_length_ma,
+            new_da_state,
+            new_optim_state,
+            current_iteration + 1,
+        )
+
+    def init(step_size: float):
+        return ChEESAdaptationState(
+            step_size, 0.0, step_size, 0.0, da_init(step_size), optim.init(step_size), 0
+        )
+
+    def update(
+        adaptation_state: ChEESAdaptationState,
+        proposed_positions: ArrayLikeTree,
+        proposed_momentums: ArrayLikeTree,
+        initial_positions: ArrayLikeTree,
+        acceptance_probabilities: Array,
+        is_divergent: Array,
+    ):
+        """Update the adaptation state and parameter values.
+
+        Parameters
+        ----------
+        adaptation_state
+            The current state of the adaptation algorithm
+        proposed_positions:
+            The position proposed by the HMC algorithm of every chain.
+        proposed_momentums:
+            The momentum variable proposed by the HMC algorithm of every chain.
+        initial_positions:
+            The initial position at the start of the HMC algorithm of every chain.
+        acceptance_probabilities:
+            Metropolis-Hastings acceptance probability of proposals of every chain.
+        is_divergent:
+            Divergence flag of every chain, forwarded to the parameter computation.
+
+        Returns
+        -------
+        New adaptation state that contains the step size and trajectory length of the
+        jittered HMC algorithm.
+
+        """
+        parameters = compute_parameters(
+            proposed_positions,
+            proposed_momentums,
+            initial_positions,
+            acceptance_probabilities,
+            is_divergent,
+            adaptation_state,
+        )
+
+        return ChEESAdaptationState(*parameters)
+
+    return init, update
+
+
+def chees_adaptation(
+    logprob_fn: Callable,
+    num_chains: int,
+    *,
+    jitter_seed: Optional[int] = None,
+) -> AdaptationAlgorithm:
+    """Adapt the step size and trajectory length (number of integration steps * step size)
+    parameters of the jittered HMC algorithm.
+
+    The jittered HMC algorithm depends on the value of a step size, controlling
+    the discretization step of the integrator, and a trajectory length, given by the
+    number of integration steps * step size, jittered by using only a random percentage
+    of this trajectory length.
+
+    This adaptation algorithm tunes the trajectory length by heuristically maximizing
+    the Change in the Estimator of the Expected Square (ChEES) criterion over
+    an ensemble of parallel chains. At equilibrium, the algorithm aims at eliminating
+    correlations between target dimensions, making the HMC algorithm efficient.
+
+    Jittering requires generating a random sequence of uniform variables in [0, 1].
+    However, this adds another source of variance to the sampling procedure,
+    which may slow adaptation or lead to suboptimal mixing. To alleviate this,
+    rather than use uniform random noise to jitter the trajectory lengths, we use a
+    quasi-random Halton sequence, which ensures a more even distribution of trajectory
+    lengths.
+
+    Examples
+    --------
+
+    An HMC adapted kernel can be learned and used with the following code:
+
+    .. code::
+
+        warmup = blackjax.chees_adaptation(logprob_fn, num_chains)
+        key_warmup, key_sample = jax.random.split(rng_key)
+        optim = optax.adam(learning_rate)
+        last_states, kernel, parameters = warmup.run(
+            key_warmup,
+            positions,  # PyTree where each leaf has shape (num_chains, ...)
+            initial_step_size,
+            optim,
+            num_warmup_steps,
+            max_sampling_steps=max_working_steps_after_warmup,
+        )
+        new_states, info = jax.vmap(kernel)(key_sample, last_states)
+
+    We can extract an `HMCState` from the `ChEESState` output object:
+
+    .. code::
+
+        hmc_states = new_states.state
+
+    Parameters
+    ----------
+    logprob_fn
+        The log density probability density function from which we wish to sample.
+    num_chains
+        Number of chains used for cross-chain warm-up training.
+    jitter_seed
+        Seed used to create a `numpy.random.Generator` instance to scramble the
+        Halton sequence created with `scipy.stats.qmc.Halton` used for jittering
+        the trajectory length.
+
+    Returns
+    -------
+    A function that returns the last cross-chain state, a sampling kernel with the
+    tuned parameter values, and all the warm-up states for diagnostics.
+
+    """
+
+    batch_init = jax.vmap(lambda p: ChEESState(hmc.init(p, logprob_fn), 0))
+    step_fn = hmc.build_kernel()
+
+    halton_sequence_gn = qmc.Halton(d=1, scramble=True, seed=jitter_seed)
+
+    def run(
+        rng_key: PRNGKey,
+        positions: ArrayLikeTree,
+        step_size: float,
+        optim,
+        num_steps: int = 1000,
+        *,
+        max_sampling_steps: int = 1000,
+    ):
+        num_dim = 0
+        for d in jax.tree_util.tree_leaves(
+            jax.tree_util.tree_map(
+                lambda p: p.reshape(num_chains, -1).shape[1], positions
+            )
+        ):
+            num_dim += d
+
+        halton_sequence = jnp.array(
+            halton_sequence_gn.random(n=num_steps + max_sampling_steps).squeeze()
+        )
+        init, update = base(halton_sequence, optim)
+
+        def one_step(carry, rng_key):
+            states, adaptation_state = carry
+
+            def kernel(rng_key, state):
+                num_integration_steps = jnp.ceil(
+                    halton_sequence[state.current_iteration]
+                    * adaptation_state.trajectory_length
+                    / adaptation_state.step_size
+                )
+                new_state, info = step_fn(
+                    rng_key,
+                    state.state,
+                    logprob_fn,
+                    step_size=adaptation_state.step_size,
+                    inverse_mass_matrix=jnp.ones(num_dim),
+                    num_integration_steps=num_integration_steps,
+                )
+                return (
+                    ChEESState(new_state, state.current_iteration + 1),
+                    info,
+                )
+
+            keys = jax.random.split(rng_key, num_chains)
+            new_states, info = jax.vmap(kernel)(keys, states)
+            new_adaptation_state = update(
+                adaptation_state,
+                info.proposal.state.position,
+                info.proposal.state.momentum,
+                states.state.position,
+                info.acceptance_rate,
+                info.is_divergent,
+            )
+
+            return (new_states, new_adaptation_state), AdaptationInfo(
+                new_states,
+                info,
+                new_adaptation_state,
+            )
+
+        init_states = batch_init(positions)
+        init_adaptation_state = init(step_size)
+
+        keys = jax.random.split(rng_key, num_steps)
+        (last_states, last_adaptation_state), info = jax.lax.scan(
+            one_step, (init_states, init_adaptation_state), keys
+        )
+
+        parameters = {
+            "step_size": last_adaptation_state.step_size_moving_average,
+            "trajectory_length": last_adaptation_state.trajectory_length_moving_average,
+        }
+
+        def kernel(rng_key, state):
+            num_integration_steps = jnp.ceil(
+                halton_sequence[state.current_iteration]
+                * parameters["trajectory_length"]
+                / parameters["step_size"]
+            )
+            new_state, info = step_fn(
+                rng_key,
+                state.state,
+                logprob_fn,
+                step_size=parameters["step_size"],
+                inverse_mass_matrix=jnp.ones(num_dim),
+                num_integration_steps=num_integration_steps,
+            )
+            return (
+                ChEESState(new_state, state.current_iteration + 1),
+                info,
+            )
+
+        return AdaptationResults(last_states, parameters), kernel, info
+
+    return AdaptationAlgorithm(run)  # type: ignore[arg-type]
diff --git a/blackjax/mcmc/trajectory.py b/blackjax/mcmc/trajectory.py
index 00f25989d..9665e26d9 100644
--- a/blackjax/mcmc/trajectory.py
+++ b/blackjax/mcmc/trajectory.py
@@ -119,10 +119,18 @@ def integrate(
             lambda step_size: direction * step_size, step_size
         )
 
-        def one_step(_, state):
-            return integrator(state, directed_step_size)
+        def one_step(state_iter):
+            state, step = state_iter
+            state = integrator(state, directed_step_size)
+            return (state, step + 1)
+
+        last_state, _ = jax.lax.while_loop(
+            lambda state_iter: state_iter[1] < num_integration_steps,
+            one_step,
+            (initial_state, 0),
+        )
 
-        return jax.lax.fori_loop(0, num_integration_steps, one_step, initial_state)
+        return last_state
 
     return integrate
 
diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py
index 8772d2a13..a0e920067 100644
--- a/tests/mcmc/test_sampling.py
+++ b/tests/mcmc/test_sampling.py
@@ -7,6 +7,7 @@
 import jax.numpy as jnp
 import jax.scipy.stats as stats
 import numpy as np
+import optax
 from absl.testing import absltest, parameterized
 
 import blackjax
@@ -223,6 +224,47 @@ def test_meads(self):
         np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-1)
         np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-1)
 
+    def test_chees(self):
+        """Test the ChEES adaptation w/ HMC kernel."""
+        rng_key, init_key0, init_key1 = jax.random.split(self.key, 3)
+        x_data = jax.random.normal(init_key0, shape=(1000, 1))
+        y_data = 3 * x_data + jax.random.normal(init_key1, shape=x_data.shape)
+
+        logposterior_fn_ = functools.partial(
+            self.regression_logprob, x=x_data, preds=y_data
+        )
+        logposterior_fn = lambda x: logposterior_fn_(**x)
+
+        init_key, warmup_key, inference_key = jax.random.split(rng_key, 3)
+
+        num_chains = 128
+        warmup = blackjax.chees_adaptation(
+            logposterior_fn,
+            num_chains=num_chains,
+        )
+        scale_key, coefs_key = jax.random.split(init_key, 2)
+        log_scales = 1.0 + jax.random.normal(scale_key, (num_chains,))
+        coefs = 4.0 + jax.random.normal(coefs_key, (num_chains,))
+        initial_positions = {"log_scale": log_scales, "coefs": coefs}
+        last_states, kernel, _ = warmup.run(
+            warmup_key,
+            initial_positions,
+            step_size=0.001,
+            optim=optax.adam(learning_rate=0.1),
+            num_steps=100,
+        )
+
+        chain_keys = jax.random.split(inference_key, num_chains)
+        states = jax.vmap(lambda key, state: inference_loop(kernel, 100, key, state))(
+            chain_keys, last_states
+        )
+
+        coefs_samples = states.state.position["coefs"]
+        scale_samples = np.exp(states.state.position["log_scale"])
+
+        np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-1)
+        np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-1)
+
 
 class SGMCMCTest(chex.TestCase):
     """Test sampling of a linear regression model."""