diff --git a/blackjax/adaptation/chees_adaptation.py b/blackjax/adaptation/chees_adaptation.py
index 499c64561..1edbe9e58 100644
--- a/blackjax/adaptation/chees_adaptation.py
+++ b/blackjax/adaptation/chees_adaptation.py
@@ -16,7 +16,7 @@ from blackjax.util import pytree_size
 
 
 # optimal tuning for HMC, see https://arxiv.org/abs/1001.4460
-TARGET_ACCEPTANCE_RATE = 0.651
+OPTIMAL_TARGET_ACCEPTANCE_RATE = 0.651
 
 
 class ChEESAdaptationState(NamedTuple):
@@ -24,35 +24,40 @@ class ChEESAdaptationState(NamedTuple):
 
     step_size
         Value of the step_size parameter of the HMC algorithm.
-    step_size_moving_average
-        Running moving average of the step_size parameter.
+    log_step_size_moving_average
+        Running moving average of the log step_size parameter.
     trajectory_length
         Value of the num_integration_steps / step_size parameter of the
         HMC algorithm.
-    trajectory_length_moving_average
-        Running moving average of the num_integration_steps / step_size
+    log_trajectory_length_moving_average
+        Running moving average of the log num_integration_steps / step_size
         parameter.
     optim_state
         Optax optimizer state used to maximize the ChEES criterion.
     random_generator_arg
         Utility array for generating a pseudo or quasi-random sequence of
         numbers.
+    step
+        Current iteration number.
 
     """
 
     step_size: float
-    step_size_moving_average: float
+    log_step_size_moving_average: float
     trajectory_length: float
-    trajectory_length_moving_average: float
+    log_trajectory_length_moving_average: float
     da_state: dual_averaging.DualAveragingState
     optim_state: optax.OptState
     random_generator_arg: Array
+    step: int
 
 
 def base(
     jitter_generator: Callable,
     next_random_arg_fn: Callable,
     optim: optax.GradientTransformation,
+    target_acceptance_rate: float,
+    decay_rate: float,
 ) -> Tuple[Callable, Callable]:
     """Maximizing the Change in the Estimator of the Expected Square criterion
     (trajectory length) and dual averaging procedure (step size) for the jittered
@@ -75,6 +80,12 @@ def base(
         Function that generates the next `random_generator_arg` from its previous value.
     optim
         Optax compatible optimizer, which conforms to the `optax.GradientTransformation` protocol.
+    target_acceptance_rate
+        Average acceptance rate to target with dual averaging.
+    decay_rate
+        Float representing how much to favor recent iterations over earlier ones in the optimization
+        of step size and trajectory length.
+
     Returns
     -------
 
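To make the effect of the new `decay_rate` parameter concrete, here is a minimal sketch of the `update_weight = step ** (-decay_rate)` rule that `compute_parameters` applies below to the log step size and log trajectory length. The helper `decay_weighted_average` is hypothetical and purely illustrative, not part of blackjax:

```python
import jax.numpy as jnp


def decay_weighted_average(log_values, decay_rate):
    # Fold the per-iteration update from compute_parameters:
    # update_weight = step ** (-decay_rate), with step starting at 1.
    ma = 0.0
    for step, value in enumerate(log_values, start=1):
        update_weight = step ** (-decay_rate)
        ma = (1.0 - update_weight) * ma + update_weight * value
    return ma


log_xs = jnp.log(jnp.array([0.5, 1.0, 2.0, 4.0]))
# decay_rate=1.0 reproduces the plain arithmetic mean of the log values ...
assert jnp.allclose(decay_weighted_average(log_xs, 1.0), log_xs.mean())
# ... while decay_rate=0.0 keeps only the most recent value.
assert jnp.allclose(decay_weighted_average(log_xs, 0.0), log_xs[-1])
```

With the new default of `decay_rate=0.5`, the weight of the latest observation falls off as `1 / sqrt(step)`, sitting between these two extremes.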
@@ -121,26 +132,30 @@ def compute_parameters(
         """
         (
             step_size,
-            step_size_ma,
+            log_step_size_ma,
             trajectory_length,
-            trajectory_length_ma,
+            log_trajectory_length_ma,
             da_state,
             optim_state,
             random_generator_arg,
+            step,
         ) = initial_adaptation_state
 
         harmonic_mean = 1.0 / jnp.mean(
             1.0 / acceptance_probabilities, where=~is_divergent
         )
-        da_state_ = da_update(da_state, TARGET_ACCEPTANCE_RATE - harmonic_mean)
+        da_state_ = da_update(da_state, target_acceptance_rate - harmonic_mean)
         step_size_ = jnp.exp(da_state_.log_x)
-        new_step_size, new_da_state = jax.lax.cond(
+        new_step_size, new_da_state, new_log_step_size = jax.lax.cond(
             jnp.isfinite(step_size_),
-            lambda _: (step_size_, da_state_),
-            lambda _: (step_size, da_state),
+            lambda _: (step_size_, da_state_, da_state_.log_x),
+            lambda _: (step_size, da_state, da_state.log_x),
             None,
         )
-        new_step_size_ma = 0.9 * step_size_ma + 0.1 * new_step_size
+        update_weight = step ** (-decay_rate)
+        new_log_step_size_ma = (
+            1.0 - update_weight
+        ) * log_step_size_ma + update_weight * new_log_step_size
 
         proposals_mean = jax.tree_util.tree_map(
             lambda p: jnp.nanmean(p, axis=0), proposed_positions
@@ -186,30 +201,32 @@ def compute_parameters(
             lambda _: (log_trajectory_length, optim_state),
             None,
         )
-        new_trajectory_length = jnp.exp(new_log_trajectory_length)
-        new_trajectory_length_ma = (
-            0.9 * trajectory_length_ma + 0.1 * new_trajectory_length
-        )
+        new_log_trajectory_length_ma = (
+            1.0 - update_weight
+        ) * log_trajectory_length_ma + update_weight * new_log_trajectory_length
+        new_trajectory_length = jnp.exp(new_log_trajectory_length_ma)
 
         return ChEESAdaptationState(
             new_step_size,
-            new_step_size_ma,
+            new_log_step_size_ma,
             new_trajectory_length,
-            new_trajectory_length_ma,
+            new_log_trajectory_length_ma,
             new_da_state,
             new_optim_state,
             next_random_arg_fn(random_generator_arg),
+            step + 1,
         )
 
     def init(random_generator_arg: Array, step_size: float):
         return ChEESAdaptationState(
             step_size=step_size,
-            step_size_moving_average=0.0,
+            log_step_size_moving_average=0.0,
             trajectory_length=step_size,
-            trajectory_length_moving_average=0.0,
+            log_trajectory_length_moving_average=0.0,
             da_state=da_init(step_size),
             optim_state=optim.init(step_size),
             random_generator_arg=random_generator_arg,
+            step=1,
         )
 
     def update(
@@ -260,6 +277,9 @@ def chees_adaptation(
     num_chains: int,
     *,
     jitter_generator: Optional[Callable] = None,
+    jitter_amount: float = 1.0,
+    target_acceptance_rate: float = OPTIMAL_TARGET_ACCEPTANCE_RATE,
+    decay_rate: float = 0.5,
 ) -> AdaptationAlgorithm:
     """Adapt the step size and trajectory length (number of integration steps
     / step size) parameters of the jittered HMC algorithm.
@@ -311,6 +331,14 @@ def chees_adaptation(
         Optional function that generates a value in [0, 1] used to jitter the trajectory
         lengths given a PRNGKey, used to propose the number of integration steps. If None,
         then a quasi-random Halton is used to jitter the trajectory length.
+    jitter_amount
+        A fraction in [0, 1] representing how much of the computed trajectory length should be jittered.
+    target_acceptance_rate
+        Average acceptance rate to target with dual averaging. Defaults to optimal tuning for HMC.
+    decay_rate
+        Float representing how much to favor recent iterations over earlier ones in the optimization
+        of step size and trajectory length. A value of 1 gives equal weight to all history. A value
+        of 0 gives weight only to the most recent iteration.
 
     Returns
     -------
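As a rough illustration of the new `jitter_amount` parameter (the values below are assumed; the hypothetical snippet only mirrors the `jitter_gn` rescaling added in `run` further down): a raw jitter value in [0, 1] is mapped into [1 - jitter_amount, 1], so `jitter_amount=1.0` reproduces the previous behaviour and `jitter_amount=0.0` disables jittering:

```python
import jax
import jax.numpy as jnp

jitter_amount = 0.5  # assumed value, for illustration only

# A raw jitter value u in [0, 1] is rescaled into [1 - jitter_amount, 1].
u = jax.random.uniform(jax.random.PRNGKey(0))
jitter = u * jitter_amount + (1.0 - jitter_amount)  # now in [0.5, 1.0]

# The jitter then scales the adapted trajectory length before it is
# turned into a number of integration steps (values assumed):
trajectory_length_adjusted = 15.0  # trajectory_length / step_size
num_integration_steps = jnp.ceil(jitter * trajectory_length_adjusted)
```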
@@ -338,13 +366,15 @@ def run(
         key_init, key_step = jax.random.split(rng_key)
 
         if jitter_generator is not None:
-            jitter_gn = jitter_generator
+            jitter_gn = lambda key: jitter_generator(key) * jitter_amount + (
+                1.0 - jitter_amount
+            )
             next_random_arg_fn = lambda key: jax.random.split(key)[1]
             init_random_arg = key_init
         else:
             jitter_gn = lambda i: _halton_sequence(
                 i, np.ceil(np.log2(num_steps + max_sampling_steps))
-            )
+            ) * jitter_amount + (1.0 - jitter_amount)
             next_random_arg_fn = lambda i: i + 1
             init_random_arg = 0
 
@@ -359,7 +389,9 @@ def integration_steps_fn(random_generator_arg, trajectory_length_adjusted):
             integration_steps_fn=integration_steps_fn,
         )
 
-        init, update = base(jitter_gn, next_random_arg_fn, optim)
+        init, update = base(
+            jitter_gn, next_random_arg_fn, optim, target_acceptance_rate, decay_rate
+        )
 
         def one_step(carry, rng_key):
             states, adaptation_state = carry
@@ -403,12 +435,12 @@ def one_step(carry, rng_key):
             one_step, (init_states, init_adaptation_state), keys_step
         )
 
-        trajectory_length_adjusted = (
-            last_adaptation_state.trajectory_length_moving_average
-            / last_adaptation_state.step_size_moving_average
+        trajectory_length_adjusted = jnp.exp(
+            last_adaptation_state.log_trajectory_length_moving_average
+            - last_adaptation_state.log_step_size_moving_average
         )
         parameters = {
-            "step_size": last_adaptation_state.step_size_moving_average,
+            "step_size": jnp.exp(last_adaptation_state.log_step_size_moving_average),
             "inverse_mass_matrix": jnp.ones(num_dim),
             "next_random_arg_fn": next_random_arg_fn,
             "integration_steps_fn": lambda arg: integration_steps_fn(
diff --git a/tests/adaptation/test_adaptation.py b/tests/adaptation/test_adaptation.py
index 67e0ea0b6..1b95b0115 100644
--- a/tests/adaptation/test_adaptation.py
+++ b/tests/adaptation/test_adaptation.py
@@ -1,6 +1,10 @@
+import jax
+import jax.numpy as jnp
 import numpy as np
+import optax
 import pytest
 
+import blackjax
 from blackjax.adaptation import window_adaptation
 
 
@@ -27,3 +31,43 @@ def test_adaptation_schedule(num_steps, expected_schedule):
     adaptation_schedule = window_adaptation.build_schedule(num_steps)
     assert num_steps == len(adaptation_schedule)
     assert np.array_equal(adaptation_schedule, expected_schedule)
+
+
+def test_chees_adaptation():
+    logprob_fn = lambda x: jax.scipy.stats.norm.logpdf(
+        x, loc=0.0, scale=jnp.array([1.0, 10.0])
+    ).sum()
+
+    num_burnin_steps = 1000
+    num_results = 500
+    num_chains = 16
+    step_size = 0.1
+
+    init_key, warmup_key, inference_key = jax.random.split(jax.random.PRNGKey(0), 3)
+
+    warmup = blackjax.chees_adaptation(
+        logprob_fn, num_chains=num_chains, target_acceptance_rate=0.75
+    )
+
+    initial_positions = jax.random.normal(init_key, (num_chains, 2))
+    (last_states, parameters), warmup_info = warmup.run(
+        warmup_key,
+        initial_positions,
+        step_size=step_size,
+        optim=optax.adamw(learning_rate=0.5),
+        num_steps=num_burnin_steps,
+    )
+    kernel = blackjax.dynamic_hmc(logprob_fn, **parameters).step
+
+    def one_step(states, rng_key):
+        keys = jax.random.split(rng_key, num_chains)
+        states, infos = jax.vmap(kernel)(keys, states)
+        return states, infos
+
+    keys = jax.random.split(inference_key, num_results)
+    _, infos = jax.lax.scan(one_step, last_states, keys)
+
+    harmonic_mean = 1.0 / jnp.mean(1.0 / infos.acceptance_rate)
+    np.testing.assert_allclose(harmonic_mean, 0.75, rtol=1e-1)
+    np.testing.assert_allclose(parameters["step_size"], 1.5, rtol=2e-1)
+    np.testing.assert_allclose(infos.num_integration_steps.mean(), 15.0, rtol=3e-1)
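Since the moving averages are now kept in log space, exponentiating them at the end of warmup yields decay-weighted geometric rather than arithmetic means; a geometric mean of a positive quantity like the step size is less sensitive to occasional large values. A small sketch with assumed numbers, mirroring how `run` assembles the returned `parameters` above:

```python
import jax.numpy as jnp

# Assumed final values of the adaptation state, for illustration only.
log_step_size_moving_average = jnp.log(0.1)
log_trajectory_length_moving_average = jnp.log(1.5)

# Exponentiating recovers the decay-weighted geometric means:
step_size = jnp.exp(log_step_size_moving_average)  # -> 0.1
trajectory_length_adjusted = jnp.exp(
    log_trajectory_length_moving_average - log_step_size_moving_average
)  # -> 1.5 / 0.1 = 15.0
```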