diff --git a/blackjax/__init__.py b/blackjax/__init__.py
index 9016d2a0e..be4ee924b 100644
--- a/blackjax/__init__.py
+++ b/blackjax/__init__.py
@@ -1,5 +1,6 @@
 from blackjax._version import __version__
 
+from .adaptation.chees_adaptation import chees_adaptation
 from .adaptation.meads_adaptation import meads_adaptation
 from .adaptation.pathfinder_adaptation import pathfinder_adaptation
 from .adaptation.window_adaptation import window_adaptation
@@ -7,7 +8,7 @@
 from .diagnostics import potential_scale_reduction as rhat
 from .mcmc.elliptical_slice import elliptical_slice
 from .mcmc.ghmc import ghmc
-from .mcmc.hmc import hmc
+from .mcmc.hmc import dynamic_hmc, hmc
 from .mcmc.mala import mala
 from .mcmc.marginal_latent_gaussian import mgrad_gaussian
 from .mcmc.nuts import nuts
@@ -29,6 +30,7 @@
     "dual_averaging",  # optimizers
     "lbfgs",
     "hmc",  # mcmc
+    "dynamic_hmc",
     "mala",
     "mgrad_gaussian",
     "nuts",
@@ -44,6 +46,7 @@
     "csgld",
     "window_adaptation",  # mcmc adaptation
     "meads_adaptation",
+    "chees_adaptation",
     "pathfinder_adaptation",
     "adaptive_tempered_smc",  # smc
     "tempered_smc",
diff --git a/blackjax/adaptation/__init__.py b/blackjax/adaptation/__init__.py
index edf6bcd13..91a491ed0 100644
--- a/blackjax/adaptation/__init__.py
+++ b/blackjax/adaptation/__init__.py
@@ -1,3 +1,13 @@
-from . import meads_adaptation, pathfinder_adaptation, window_adaptation
+from . import (
+    chees_adaptation,
+    meads_adaptation,
+    pathfinder_adaptation,
+    window_adaptation,
+)
 
-__all__ = ["meads_adaptation", "window_adaptation", "pathfinder_adaptation"]
+__all__ = [
+    "chees_adaptation",
+    "meads_adaptation",
+    "window_adaptation",
+    "pathfinder_adaptation",
+]
diff --git a/blackjax/adaptation/chees_adaptation.py b/blackjax/adaptation/chees_adaptation.py
new file mode 100644
index 000000000..a43a0d9cd
--- /dev/null
+++ b/blackjax/adaptation/chees_adaptation.py
@@ -0,0 +1,426 @@
+"""Public API for ChEES-HMC"""
+
+from typing import Callable, NamedTuple, Optional
+
+import jax
+import jax.flatten_util
+import jax.numpy as jnp
+import numpy as np
+import optax
+
+import blackjax.mcmc.hmc as hmc
+import blackjax.optimizers.dual_averaging as dual_averaging
+from blackjax.adaptation.base import AdaptationInfo, AdaptationResults
+from blackjax.base import AdaptationAlgorithm
+from blackjax.types import Array, ArrayLikeTree, PRNGKey
+
+
+class ChEESAdaptationState(NamedTuple):
+    """State of the ChEES-HMC adaptation scheme.
+
+    step_size
+        Value of the step_size parameter of the HMC algorithm.
+    step_size_moving_average
+        Running moving average of the step_size parameter.
+    trajectory_length
+        Value of the num_integration_steps * step_size parameter of
+        the HMC algorithm.
+    trajectory_length_moving_average
+        Running moving average of the num_integration_steps * step_size
+        parameter.
+    da_state
+        State of the dual averaging scheme used to tune the step size.
+    optim_state
+        Optax optimizer state used to maximize the ChEES criterion.
+    random_generator_arg
+        Utility array for generating a pseudo or quasi-random sequence of
+        numbers.
+
+    """
+
+    step_size: float
+    step_size_moving_average: float
+    trajectory_length: float
+    trajectory_length_moving_average: float
+    da_state: dual_averaging.DualAveragingState
+    optim_state: optax.OptState
+    random_generator_arg: Array
+
+
+def base(
+    jitter_generator: Callable,
+    next_random_arg_fn: Callable,
+    optim: optax.GradientTransformation,
+):
+    """Maximize the Change in the Estimator of the Expected Square (ChEES)
+    criterion (for the trajectory length) and run a dual averaging procedure
+    (for the step size) for the jittered Hamiltonian Monte Carlo kernel
+    :cite:p:`hoffman2021adaptive`.
+
+    This adaptation algorithm tunes the step size and trajectory length, i.e.
+    number of integration steps * step size, of the jittered HMC algorithm based
+    on statistics collected from a population of many chains. It tunes the
+    trajectory length by maximizing the Change in the Estimator of the Expected
+    Square (ChEES) criterion, i.e. the expected squared change, over a proposed
+    trajectory, of the squared distance between a chain's position and the
+    estimated posterior mean. It tunes the step size by dual averaging on the
+    harmonic mean of the chains' acceptance probabilities, targeting an
+    acceptance rate of 0.651.
+
+    Parameters
+    ----------
+    jitter_generator
+        Function that, given the current `random_generator_arg`, generates a value
+        in [0, 1] used to jitter the trajectory length, and hence the proposed
+        number of integration steps. When used through :func:`chees_adaptation`,
+        a quasi-random Halton sequence is passed by default.
+    next_random_arg_fn
+        Function that generates the next `random_generator_arg` from its previous value.
+    optim
+        Optax compatible optimizer, which conforms to the `optax.GradientTransformation` protocol.
+
+    Returns
+    -------
+    init
+        Function that initializes the warmup.
+    update
+        Function that moves the warmup one step.
+
+    """
+
+    da_init, da_update, _ = dual_averaging.dual_averaging()
+
+    def compute_parameters(
+        proposed_positions: ArrayLikeTree,
+        proposed_momentums: ArrayLikeTree,
+        initial_positions: ArrayLikeTree,
+        acceptance_probabilities: Array,
+        is_divergent: Array,
+        initial_adaptation_state: ChEESAdaptationState,
+    ):
+        """Compute values for the parameters based on statistics collected from
+        multiple chains.
+
+        Parameters
+        ----------
+        proposed_positions:
+            A PyTree that contains the position proposed by the HMC algorithm of
+            every chain (proposal that is accepted or rejected using MH).
+        proposed_momentums:
+            A PyTree that contains the momentum variable proposed by the HMC algorithm
+            of every chain (proposal that is accepted or rejected using MH).
+        initial_positions:
+            A PyTree that contains the initial position at the start of the HMC
+            algorithm of every chain.
+        acceptance_probabilities:
+            Metropolis-Hastings acceptance probability of the proposal of every chain.
+        is_divergent:
+            Whether the transition of every chain was divergent.
+        initial_adaptation_state:
+            ChEES adaptation state used to generate the proposals and acceptance
+            probabilities.
+
+        Returns
+        -------
+        New values of the step size and trajectory length of the jittered HMC algorithm.
+
+        """
+        (
+            step_size,
+            step_size_ma,
+            trajectory_length,
+            trajectory_length_ma,
+            da_state,
+            optim_state,
+            random_generator_arg,
+        ) = initial_adaptation_state
+
+        # Tune the step size by dual averaging on the harmonic mean of the
+        # chains' acceptance probabilities, targeting an acceptance rate of 0.651.
+        harmonic_mean = 1.0 / jnp.mean(
+            1.0 / acceptance_probabilities, where=~is_divergent
+        )
+        da_state_ = da_update(da_state, 0.651 - harmonic_mean)
+        step_size_ = jnp.exp(da_state_.log_x)
+        # Keep the previous step size if the dual averaging update is not finite.
+        new_step_size, new_da_state = jax.lax.cond(
+            jnp.isfinite(step_size_),
+            lambda _: (step_size_, da_state_),
+            lambda _: (step_size, da_state),
+            None,
+        )
+        new_step_size_ma = 0.9 * step_size_ma + 0.1 * new_step_size
+
+        # Center the proposed and initial positions with their across-chain means.
+        proposals_mean = jax.tree_util.tree_map(
+            lambda p: jnp.nanmean(p, axis=0), proposed_positions
+        )
+        initials_mean = jax.tree_util.tree_map(
+            lambda p: jnp.nanmean(p, axis=0), initial_positions
+        )
+        proposals_centered = jax.tree_util.tree_map(
+            lambda p, pm: p - pm, proposed_positions, proposals_mean
+        )
+        initials_centered = jax.tree_util.tree_map(
+            lambda p, pm: p - pm, initial_positions, initials_mean
+        )
+
+        proposals_matrix = jax.vmap(lambda p: jax.flatten_util.ravel_pytree(p)[0])(
+            proposals_centered
+        )
+        initials_matrix = jax.vmap(lambda p: jax.flatten_util.ravel_pytree(p)[0])(
+            initials_centered
+        )
+        momentums_matrix = jax.vmap(lambda m: jax.flatten_util.ravel_pytree(m)[0])(
+            proposed_momentums
+        )
+
+        # Per-chain statistic entering the gradient of the ChEES criterion with
+        # respect to the (log) trajectory length.
+        trajectory_gradients = (
+            jitter_generator(random_generator_arg)
+            * trajectory_length
+            * (
+                jax.vmap(lambda p: jnp.dot(p, p))(proposals_matrix)
+                - jax.vmap(lambda p: jnp.dot(p, p))(initials_matrix)
+            )
+            * jax.vmap(lambda p, m: jnp.dot(p, m))(proposals_matrix, momentums_matrix)
+        )
+        trajectory_gradient = jnp.sum(
+            acceptance_probabilities * trajectory_gradients, where=~is_divergent
+        ) / jnp.sum(acceptance_probabilities, where=~is_divergent)
+
+        log_trajectory_length = jnp.log(trajectory_length)
+        updates, optim_state_ = optim.update(
+            trajectory_gradient, optim_state, log_trajectory_length
+        )
+        log_trajectory_length_ = optax.apply_updates(log_trajectory_length, updates)
+        # Keep the previous trajectory length if the optimizer update is not finite.
+        new_log_trajectory_length, new_optim_state = jax.lax.cond(
+            jnp.isfinite(
+                jax.flatten_util.ravel_pytree(log_trajectory_length_)[0]
+            ).all(),
+            lambda _: (log_trajectory_length_, optim_state_),
+            lambda _: (log_trajectory_length, optim_state),
+            None,
+        )
+        new_trajectory_length = jnp.exp(new_log_trajectory_length)
+        new_trajectory_length_ma = (
+            0.9 * trajectory_length_ma + 0.1 * new_trajectory_length
+        )
+
+        return (
+            new_step_size,
+            new_step_size_ma,
+            new_trajectory_length,
+            new_trajectory_length_ma,
+            new_da_state,
+            new_optim_state,
+            next_random_arg_fn(random_generator_arg),
+        )
+
+    def init(random_generator_arg: Array, step_size: float):
+        return ChEESAdaptationState(
+            step_size,
+            0.0,
+            step_size,
+            0.0,
+            da_init(step_size),
+            optim.init(step_size),
+            random_generator_arg,
+        )
+
+    def update(
+        adaptation_state: ChEESAdaptationState,
+        proposed_positions: ArrayLikeTree,
+        proposed_momentums: ArrayLikeTree,
+        initial_positions: ArrayLikeTree,
+        acceptance_probabilities: Array,
+        is_divergent: Array,
+    ):
+        """Update the adaptation state and parameter values.
+
+        Parameters
+        ----------
+        adaptation_state
+            The current state of the adaptation algorithm.
+        proposed_positions:
+            The position proposed by the HMC algorithm of every chain.
+        proposed_momentums:
+            The momentum variable proposed by the HMC algorithm of every chain.
+        initial_positions:
+            The initial position at the start of the HMC algorithm of every chain.
+        acceptance_probabilities:
+            Metropolis-Hastings acceptance probability of the proposal of every chain.
+        is_divergent:
+            Whether the transition of every chain was divergent.
+
+        Returns
+        -------
+        New adaptation state that contains the step size and trajectory length of the
+        jittered HMC algorithm.
+
+        """
+        parameters = compute_parameters(
+            proposed_positions,
+            proposed_momentums,
+            initial_positions,
+            acceptance_probabilities,
+            is_divergent,
+            adaptation_state,
+        )
+
+        return ChEESAdaptationState(*parameters)
+
+    return init, update
+
+
+def chees_adaptation(
+    logprob_fn: Callable,
+    num_chains: int,
+    *,
+    jitter_generator: Optional[Callable] = None,
+) -> AdaptationAlgorithm:
+    """Adapt the step size and trajectory length (number of integration steps *
+    step size) parameters of the jittered HMC algorithm.
+
+    The jittered HMC algorithm depends on the value of a step size, controlling
+    the discretization step of the integrator, and a trajectory length, given by
+    the number of integration steps * step size, of which only a random fraction
+    is used at each iteration (the jitter).
+
+    This adaptation algorithm tunes the trajectory length by heuristically maximizing
+    the Change in the Estimator of the Expected Square (ChEES) criterion over
+    an ensemble of parallel chains. At equilibrium, the algorithm aims at eliminating
+    correlations between the target dimensions, making the HMC algorithm efficient.
+
+    Jittering requires generating a random sequence of uniform variables in [0, 1].
+    However, this adds another source of variance to the sampling procedure,
+    which may slow adaptation or lead to suboptimal mixing. To alleviate this,
+    rather than use uniform random noise to jitter the trajectory lengths, we use a
+    quasi-random Halton sequence, which ensures a more even distribution of trajectory
+    lengths.
+
+    Examples
+    --------
+
+    An HMC adapted kernel can be learned and used with the following code:
+
+    .. code::
+
+        warmup = blackjax.chees_adaptation(logprob_fn, num_chains)
+        key_warmup, key_sample = jax.random.split(rng_key)
+        optim = optax.adam(learning_rate)
+        (last_states, parameters), _ = warmup.run(
+            key_warmup,
+            positions,  # PyTree where each leaf has shape (num_chains, ...)
+            initial_step_size,
+            optim,
+            num_warmup_steps,
+        )
+        kernel = blackjax.dynamic_hmc(logprob_fn, **parameters).step
+        keys_sample = jax.random.split(key_sample, num_chains)
+        new_states, info = jax.vmap(kernel)(keys_sample, last_states)
+
+    Parameters
+    ----------
+    logprob_fn
+        The log probability density function from which we wish to sample.
+    num_chains
+        Number of chains used for cross-chain warm-up training.
+    jitter_generator
+        Optional function that, given a PRNGKey, generates a value in [0, 1] used to
+        jitter the trajectory lengths, and hence the proposed number of integration
+        steps. If None, a quasi-random Halton sequence is used to jitter the
+        trajectory length.
+
+    Returns
+    -------
+    An `AdaptationAlgorithm` whose `run` function returns the last states of the
+    chains, the tuned parameter values for a `dynamic_hmc` kernel, and the
+    warm-up information for diagnostics.
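+
+    As an illustrative sketch, any callable mapping a ``PRNGKey`` to a value in
+    [0, 1] can be supplied instead of the default Halton sequence, for instance
+    uniform jitter:
+
+    .. code::
+
+        warmup = blackjax.chees_adaptation(
+            logprob_fn, num_chains, jitter_generator=jax.random.uniform
+        )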
+
+    """
+
+    def run(
+        rng_key: PRNGKey,
+        positions: ArrayLikeTree,
+        step_size: float,
+        optim: optax.GradientTransformation,
+        num_steps: int = 1000,
+        *,
+        max_sampling_steps: int = 1000,
+    ):
+        # Total number of dimensions of the flattened position PyTree.
+        num_dim = 0
+        for d in jax.tree_util.tree_leaves(
+            jax.tree_util.tree_map(
+                lambda p: p.reshape(num_chains, -1).shape[1], positions
+            )
+        ):
+            num_dim += d
+
+        key_init, key_step = jax.random.split(rng_key)
+
+        if jitter_generator is not None:
+            jitter_gn = jitter_generator
+            next_random_arg_fn = lambda key: jax.random.split(key)[1]
+            init_random_arg = key_init
+        else:
+            jitter_gn = lambda i: _halton_sequence(
+                i, np.ceil(np.log2(num_steps + max_sampling_steps))
+            )
+            next_random_arg_fn = lambda i: i + 1
+            init_random_arg = 0
+
+        def integration_steps_fn(random_generator_arg, trajectory_length_adjusted):
+            return jnp.ceil(
+                jitter_gn(random_generator_arg) * trajectory_length_adjusted
+            )
+
+        step_fn = hmc.build_dynamic_kernel(
+            next_random_arg_fn=next_random_arg_fn,
+            integration_steps_fn=integration_steps_fn,
+        )
+
+        init, update = base(jitter_gn, next_random_arg_fn, optim)
+
+        def one_step(carry, rng_key):
+            states, adaptation_state = carry
+
+            def kernel_step(rng_key, state):
+                return step_fn(
+                    rng_key,
+                    state,
+                    logprob_fn,
+                    step_size=adaptation_state.step_size,
+                    inverse_mass_matrix=jnp.ones(num_dim),
+                    trajectory_length_adjusted=adaptation_state.trajectory_length
+                    / adaptation_state.step_size,
+                )
+
+            keys = jax.random.split(rng_key, num_chains)
+            new_states, info = jax.vmap(kernel_step)(keys, states)
+            new_adaptation_state = update(
+                adaptation_state,
+                info.proposal.state.position,
+                info.proposal.state.momentum,
+                states.position,
+                info.acceptance_rate,
+                info.is_divergent,
+            )
+
+            return (new_states, new_adaptation_state), AdaptationInfo(
+                new_states,
+                info,
+                new_adaptation_state,
+            )
+
+        batch_init = jax.vmap(
+            lambda p: hmc.init_dynamic(p, logprob_fn, init_random_arg)
+        )
+        init_states = batch_init(positions)
+        init_adaptation_state = init(init_random_arg, step_size)
+
+        keys_step = jax.random.split(key_step, num_steps)
+        (last_states, last_adaptation_state), info = jax.lax.scan(
+            one_step, (init_states, init_adaptation_state), keys_step
+        )
+
+        # trajectory_length tracks num_integration_steps * step_size; dividing by
+        # the step size recovers the (pre-jitter) number of integration steps.
+        trajectory_length_adjusted = (
+            last_adaptation_state.trajectory_length_moving_average
+            / last_adaptation_state.step_size_moving_average
+        )
+        parameters = {
+            "step_size": last_adaptation_state.step_size_moving_average,
+            "inverse_mass_matrix": jnp.ones(num_dim),
+            "next_random_arg_fn": next_random_arg_fn,
+            "integration_steps_fn": lambda arg: integration_steps_fn(
+                arg, trajectory_length_adjusted
+            ),
+        }
+
+        return AdaptationResults(last_states, parameters), info
+
+    return AdaptationAlgorithm(run)  # type: ignore[arg-type]
+
+
+def _halton_sequence(i, max_bits=10):
+    """Return the i-th element of the base-2 Halton (van der Corput) sequence.
+
+    The sequence starts 0.5, 0.25, 0.75, 0.125, 0.625, ... and provides
+    low-discrepancy jitter values in (0, 1) for the trajectory lengths.
+    """
+    bit_masks = 2 ** jnp.arange(max_bits, dtype=i.dtype)
+    return jnp.einsum("i,i->", jnp.mod((i + 1) // bit_masks, 2), 0.5 / bit_masks)
diff --git a/blackjax/mcmc/hmc.py b/blackjax/mcmc/hmc.py
index 826844ade..8aaf272bf 100644
--- a/blackjax/mcmc/hmc.py
+++ b/blackjax/mcmc/hmc.py
@@ -33,6 +33,7 @@
     "build_kernel",
     "build_dynamic_kernel",
     "hmc",
+    "dynamic_hmc",
 ]
 
 
@@ -210,9 +211,12 @@ def kernel(
         logdensity_fn: Callable,
         step_size: float,
         inverse_mass_matrix: Array,
+        **integration_steps_kwargs,
     ) -> tuple[DynamicHMCState, HMCInfo]:
         """Generate a new sample with the HMC kernel."""
-        num_integration_steps = integration_steps_fn(state.random_generator_arg)
+        num_integration_steps = integration_steps_fn(
+            state.random_generator_arg, **integration_steps_kwargs
+        )
         hmc_state = HMCState(state.position, state.logdensity, state.logdensity_grad)
         hmc_proposal, info = hmc_base(
             rng_key,
@@ -220,7 +224,7 @@
             logdensity_fn,
             step_size,
             inverse_mass_matrix,
-            num_integration_steps,
+            jax.numpy.asarray(num_integration_steps, dtype=int),
         )
         next_random_arg = next_random_arg_fn(state.random_generator_arg)
         return (
@@ -330,6 +334,69 @@ def step_fn(rng_key: PRNGKey, state):
 
         return SamplingAlgorithm(init_fn, step_fn)
 
 
+class dynamic_hmc:
+    """Implements the (basic) user interface for the dynamic HMC kernel.
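+
+    Examples
+    --------
+
+    A minimal usage sketch; the log-density, position and key names below are
+    placeholders, and the step size and inverse mass matrix would typically come
+    from an adaptation procedure such as :func:`chees_adaptation`:
+
+    .. code::
+
+        dhmc = blackjax.dynamic_hmc(logdensity_fn, step_size, inverse_mass_matrix)
+        state = dhmc.init(position, init_key)
+        new_state, info = dhmc.step(step_key, state)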
+
+    Parameters
+    ----------
+    logdensity_fn
+        The log-density function we wish to draw samples from.
+    step_size
+        The value to use for the step size in the symplectic integrator.
+    inverse_mass_matrix
+        The value to use for the inverse mass matrix when drawing a value for
+        the momentum and computing the kinetic energy.
+    divergence_threshold
+        The absolute value of the difference in energy between two states above
+        which we say that the transition is divergent. The default value is
+        commonly found in other libraries, and yet is arbitrary.
+    integrator
+        (algorithm parameter) The symplectic integrator to use to integrate the trajectory.
+    next_random_arg_fn
+        Function that generates the next `random_generator_arg` from its previous value.
+    integration_steps_fn
+        Function that generates the next pseudo or quasi-random number of integration steps in the
+        sequence, given the current `random_generator_arg`.
+
+    Returns
+    -------
+    A ``SamplingAlgorithm``.
+    """
+
+    init = staticmethod(init_dynamic)
+    build_kernel = staticmethod(build_dynamic_kernel)
+
+    def __new__(  # type: ignore[misc]
+        cls,
+        logdensity_fn: Callable,
+        step_size: float,
+        inverse_mass_matrix: Array,
+        *,
+        divergence_threshold: int = 1000,
+        integrator: Callable = integrators.velocity_verlet,
+        next_random_arg_fn: Callable = lambda key: jax.random.split(key)[1],
+        integration_steps_fn: Callable = lambda key: jax.random.randint(key, (), 1, 10),
+    ) -> SamplingAlgorithm:
+        kernel = cls.build_kernel(
+            integrator, divergence_threshold, next_random_arg_fn, integration_steps_fn
+        )
+
+        def init_fn(position: ArrayLikeTree, random_generator_arg: Array):
+            return cls.init(position, logdensity_fn, random_generator_arg)
+
+        def step_fn(rng_key: PRNGKey, state):
+            return kernel(
+                rng_key,
+                state,
+                logdensity_fn,
+                step_size,
+                inverse_mass_matrix,
+            )
+
+        return SamplingAlgorithm(init_fn, step_fn)  # type: ignore[arg-type]
+
+
 def hmc_proposal(
     integrator: Callable,
     kinetic_energy: Callable,
diff --git a/docs/refs.bib b/docs/refs.bib
index 378c451ff..f5015ccb9 100644
--- a/docs/refs.bib
+++ b/docs/refs.bib
@@ -351,3 +351,12 @@ @article{liu2016stein
   volume={29},
   year={2016}
 }
+
+@inproceedings{hoffman2021adaptive,
+  title={An adaptive-{MCMC} scheme for setting trajectory lengths in {H}amiltonian {M}onte {C}arlo},
+  author={Hoffman, Matthew and Radul, Alexey and Sountsov, Pavel},
+  booktitle={International Conference on Artificial Intelligence and Statistics},
+  pages={3907--3915},
+  year={2021},
+  organization={PMLR}
+}
diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py
index 8772d2a13..7770b55a1 100644
--- a/tests/mcmc/test_sampling.py
+++ b/tests/mcmc/test_sampling.py
@@ -7,6 +7,7 @@
 import jax.numpy as jnp
 import jax.scipy.stats as stats
 import numpy as np
+import optax
 from absl.testing import absltest, parameterized
 
 import blackjax
@@ -223,6 +224,48 @@ def test_meads(self):
         np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-1)
         np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-1)
 
+    @parameterized.parameters([None, jax.random.uniform])
+    def test_chees(self, jitter_generator):
+        """Test the ChEES adaptation with the dynamic HMC kernel."""
+        rng_key, init_key0, init_key1 = jax.random.split(self.key, 3)
+        x_data = jax.random.normal(init_key0, shape=(1000, 1))
+        y_data = 3 * x_data + jax.random.normal(init_key1, shape=x_data.shape)
+
+        logposterior_fn_ = functools.partial(
+            self.regression_logprob, x=x_data, preds=y_data
+        )
+        logposterior_fn = lambda x: logposterior_fn_(**x)
+
+        init_key, warmup_key, inference_key = jax.random.split(rng_key, 3)
+
+        num_chains = 128
+        warmup = blackjax.chees_adaptation(
+            logposterior_fn, num_chains=num_chains, jitter_generator=jitter_generator
+        )
+        scale_key, coefs_key = jax.random.split(init_key, 2)
+        log_scales = 1.0 + jax.random.normal(scale_key, (num_chains,))
+        coefs = 4.0 + jax.random.normal(coefs_key, (num_chains,))
+        initial_positions = {"log_scale": log_scales, "coefs": coefs}
+        (last_states, parameters), _ = warmup.run(
+            warmup_key,
+            initial_positions,
+            step_size=0.001,
+            optim=optax.adam(learning_rate=0.1),
+            num_steps=1000,
+        )
+        kernel = blackjax.dynamic_hmc(logposterior_fn, **parameters).step
+
+        chain_keys = jax.random.split(inference_key, num_chains)
+        states = jax.vmap(lambda key, state: inference_loop(kernel, 100, key, state))(
+            chain_keys, last_states
+        )
+
+        coefs_samples = states.position["coefs"]
+        scale_samples = np.exp(states.position["log_scale"])
+
+        np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-1)
+        np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-1)
+
 
 class SGMCMCTest(chex.TestCase):
     """Test sampling of a linear regression model."""