Commit b1f3419

geometric weighted moving average and tests for convergence of step size and number of steps

albcab committed Nov 2, 2023
1 parent ce46f3b commit b1f3419
Showing 2 changed files with 105 additions and 29 deletions.
90 changes: 61 additions & 29 deletions blackjax/adaptation/chees_adaptation.py
@@ -16,43 +16,48 @@
from blackjax.util import pytree_size

# optimal tuning for HMC, see https://arxiv.org/abs/1001.4460
-TARGET_ACCEPTANCE_RATE = 0.651
+OPTIMAL_TARGET_ACCEPTANCE_RATE = 0.651


class ChEESAdaptationState(NamedTuple):
"""State of the ChEES-HMC adaptation scheme.
step_size
Value of the step_size parameter of the HMC algorithm.
-step_size_moving_average
-Running moving average of the step_size parameter.
+log_step_size_moving_average
+Running moving average of the log step_size parameter.
trajectory_length
Value of the num_integration_steps * step_size parameter of
the HMC algorithm.
-trajectory_length_moving_average
-Running moving average of the num_integration_steps * step_size
+log_trajectory_length_moving_average
+Running moving average of the log num_integration_steps * step_size
parameter.
optim_state
Optax optimizer state used to maximize the ChEES criterion.
random_generator_arg
Utility array for generating a pseudo or quasi-random sequence of
numbers.
step
Current iteration number.
"""

step_size: float
-step_size_moving_average: float
+log_step_size_moving_average: float
trajectory_length: float
-trajectory_length_moving_average: float
+log_trajectory_length_moving_average: float
da_state: dual_averaging.DualAveragingState
optim_state: optax.OptState
random_generator_arg: Array
step: int


def base(
jitter_generator: Callable,
next_random_arg_fn: Callable,
optim: optax.GradientTransformation,
+target_acceptance_rate: float,
+decay_rate: float,
) -> Tuple[Callable, Callable]:
"""Maximizing the Change in the Estimator of the Expected Square criterion
(trajectory length) and dual averaging procedure (step size) for the jittered
@@ -75,6 +80,12 @@ def base(
Function that generates the next `random_generator_arg` from its previous value.
optim
Optax compatible optimizer, which conforms to the `optax.GradientTransformation` protocol.
+target_acceptance_rate
+Average acceptance rate to target with dual averaging.
+decay_rate
+Float representing how much to favor recent iterations over earlier ones in the optimization
+of step size and trajectory length.
Returns
-------
@@ -121,26 +132,30 @@ def compute_parameters(
"""
(
step_size,
-step_size_ma,
+log_step_size_ma,
trajectory_length,
-trajectory_length_ma,
+log_trajectory_length_ma,
da_state,
optim_state,
random_generator_arg,
step,
) = initial_adaptation_state

harmonic_mean = 1.0 / jnp.mean(
1.0 / acceptance_probabilities, where=~is_divergent
)
-da_state_ = da_update(da_state, TARGET_ACCEPTANCE_RATE - harmonic_mean)
+da_state_ = da_update(da_state, target_acceptance_rate - harmonic_mean)
step_size_ = jnp.exp(da_state_.log_x)
-new_step_size, new_da_state = jax.lax.cond(
+new_step_size, new_da_state, new_log_step_size = jax.lax.cond(
jnp.isfinite(step_size_),
-lambda _: (step_size_, da_state_),
-lambda _: (step_size, da_state),
+lambda _: (step_size_, da_state_, da_state_.log_x),
+lambda _: (step_size, da_state, da_state.log_x),
None,
)
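The cond above now also threads the dual-averaging log_x through, so the log moving average sees the same guarded value: if the proposed step size is non-finite, everything falls back to the previous iteration's values. A toy illustration of the guard pattern (values invented, not part of the diff):

import jax
import jax.numpy as jnp

prev_step_size = jnp.asarray(0.5)
proposed_step_size = jnp.exp(jnp.asarray(1000.0))  # overflows to inf

step_size = jax.lax.cond(
    jnp.isfinite(proposed_step_size),
    lambda _: proposed_step_size,  # accept the proposal when finite
    lambda _: prev_step_size,      # otherwise keep the previous value
    None,
)
print(step_size)  # 0.5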
-new_step_size_ma = 0.9 * step_size_ma + 0.1 * new_step_size
+update_weight = step ** (-decay_rate)
+new_log_step_size_ma = (
+    1.0 - update_weight
+) * log_step_size_ma + update_weight * new_log_step_size

proposals_mean = jax.tree_util.tree_map(
lambda p: jnp.nanmean(p, axis=0), proposed_positions
@@ -186,30 +201,32 @@ def compute_parameters(
lambda _: (log_trajectory_length, optim_state),
None,
)
-new_trajectory_length = jnp.exp(new_log_trajectory_length)
-new_trajectory_length_ma = (
-    0.9 * trajectory_length_ma + 0.1 * new_trajectory_length
-)
+new_log_trajectory_length_ma = (
+    1.0 - update_weight
+) * log_trajectory_length_ma + update_weight * new_log_trajectory_length
+new_trajectory_length = jnp.exp(new_log_trajectory_length_ma)
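The two log-space updates above are the geometric weighted moving average of the commit title: with update_weight = step ** (-decay_rate), decay_rate = 1 reduces to an ordinary running mean of the logs (the log of the geometric mean), while smaller values weight recent iterations more heavily. Since the state starts at step=1 (see init below), the first update has weight 1 and immediately overwrites the 0.0 initialization of the log averages. A standalone sketch of the recursion (the helper name and toy sequence are illustrative, not part of the diff):

import jax.numpy as jnp

def geometric_moving_average(log_ma, new_value, step, decay_rate):
    # One step of the log-space recursion used above.
    update_weight = step ** (-decay_rate)
    return (1.0 - update_weight) * log_ma + update_weight * jnp.log(new_value)

# With decay_rate=1.0 the recursion recovers the geometric mean of the sequence.
log_ma = 0.0
for step, x in enumerate([0.5, 2.0, 1.0, 4.0], start=1):
    log_ma = geometric_moving_average(log_ma, x, step, decay_rate=1.0)
print(jnp.exp(log_ma))  # ~1.414, i.e. (0.5 * 2.0 * 1.0 * 4.0) ** (1 / 4)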

return ChEESAdaptationState(
new_step_size,
-new_step_size_ma,
+new_log_step_size_ma,
new_trajectory_length,
-new_trajectory_length_ma,
+new_log_trajectory_length_ma,
new_da_state,
new_optim_state,
next_random_arg_fn(random_generator_arg),
step + 1,
)

def init(random_generator_arg: Array, step_size: float):
return ChEESAdaptationState(
step_size=step_size,
-step_size_moving_average=0.0,
+log_step_size_moving_average=0.0,
trajectory_length=step_size,
-trajectory_length_moving_average=0.0,
+log_trajectory_length_moving_average=0.0,
da_state=da_init(step_size),
optim_state=optim.init(step_size),
random_generator_arg=random_generator_arg,
step=1,
)

def update(
@@ -260,6 +277,9 @@ def chees_adaptation(
num_chains: int,
*,
jitter_generator: Optional[Callable] = None,
+jitter_amount: float = 1.0,
+target_acceptance_rate: float = OPTIMAL_TARGET_ACCEPTANCE_RATE,
+decay_rate: float = 0.5,
) -> AdaptationAlgorithm:
"""Adapt the step size and trajectory length (number of integration steps / step size)
parameters of the jittered HMC algorthm.
@@ -311,6 +331,14 @@
Optional function that generates a value in [0, 1] used to jitter the trajectory
lengths given a PRNGKey, used to propose the number of integration steps. If None,
then a quasi-random Halton is used to jitter the trajectory length.
+jitter_amount
+A percentage in [0, 1] representing how much of the calculated trajectory length should be jittered.
+target_acceptance_rate
+Average acceptance rate to target with dual averaging. Defaults to optimal tuning for HMC.
+decay_rate
+Float representing how much to favor recent iterations over earlier ones in the optimization
+of step size and trajectory length. A value of 1 gives equal weight to all history. A value
+of 0 gives weight only to the most recent iteration.
Returns
-------
@@ -338,13 +366,15 @@ def run(
key_init, key_step = jax.random.split(rng_key)

if jitter_generator is not None:
-jitter_gn = jitter_generator
+jitter_gn = lambda key: jitter_generator(key) * jitter_amount + (
+    1.0 - jitter_amount
+)
next_random_arg_fn = lambda key: jax.random.split(key)[1]
init_random_arg = key_init
else:
jitter_gn = lambda i: _halton_sequence(
i, np.ceil(np.log2(num_steps + max_sampling_steps))
-)
+) * jitter_amount + (1.0 - jitter_amount)
next_random_arg_fn = lambda i: i + 1
init_random_arg = 0
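In both branches the raw jitter value u in [0, 1] is rescaled as u * jitter_amount + (1 - jitter_amount), which maps it into [1 - jitter_amount, 1]: jitter_amount=1.0 (the default) keeps full jitter, while jitter_amount=0.0 pins the multiplier at 1 and disables jittering entirely. A minimal, illustrative check (the helper name is invented):

import jax
import jax.numpy as jnp

def rescale_jitter(u, jitter_amount):
    # Maps u in [0, 1] into [1 - jitter_amount, 1].
    return u * jitter_amount + (1.0 - jitter_amount)

u = jax.random.uniform(jax.random.PRNGKey(0), (5,))
print(rescale_jitter(u, 1.0))  # full jitter, values span [0, 1]
print(rescale_jitter(u, 0.5))  # compressed into [0.5, 1]
print(rescale_jitter(u, 0.0))  # all ones: no jitter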

@@ -359,7 +389,9 @@ def integration_steps_fn(random_generator_arg, trajectory_length_adjusted):
integration_steps_fn=integration_steps_fn,
)

-init, update = base(jitter_gn, next_random_arg_fn, optim)
+init, update = base(
+    jitter_gn, next_random_arg_fn, optim, target_acceptance_rate, decay_rate
+)

def one_step(carry, rng_key):
states, adaptation_state = carry
@@ -403,12 +435,12 @@ def one_step(carry, rng_key):
one_step, (init_states, init_adaptation_state), keys_step
)

-trajectory_length_adjusted = (
-    last_adaptation_state.trajectory_length_moving_average
-    / last_adaptation_state.step_size_moving_average
+trajectory_length_adjusted = jnp.exp(
+    last_adaptation_state.log_trajectory_length_moving_average
+    - last_adaptation_state.log_step_size_moving_average
)
parameters = {
"step_size": last_adaptation_state.step_size_moving_average,
"step_size": jnp.exp(last_adaptation_state.log_step_size_moving_average),
"inverse_mass_matrix": jnp.ones(num_dim),
"next_random_arg_fn": next_random_arg_fn,
"integration_steps_fn": lambda arg: integration_steps_fn(
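Because the running averages are now stored in log space, the values exported after warmup are geometric means: exp of the log step size average recovers the step size, and exp(log_trajectory_length_ma - log_step_size_ma) is the trajectory length divided by the step size, i.e. the base number of integration steps that the jitter multiplies. A hedged numeric sketch, with made-up final averages chosen to match the magnitudes the test below expects:

import jax.numpy as jnp

log_step_size_ma = jnp.log(1.5)           # invented final moving average
log_trajectory_length_ma = jnp.log(22.5)  # invented final moving average

step_size = jnp.exp(log_step_size_ma)  # 1.5
trajectory_length_adjusted = jnp.exp(log_trajectory_length_ma - log_step_size_ma)  # 15.0
# dynamic HMC then draws ceil(jitter * trajectory_length_adjusted) integration steps.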
44 changes: 44 additions & 0 deletions tests/adaptation/test_adaptation.py
@@ -1,6 +1,10 @@
import jax
import jax.numpy as jnp
import numpy as np
+import optax
import pytest

+import blackjax
from blackjax.adaptation import window_adaptation


@@ -27,3 +31,43 @@ def test_adaptation_schedule(num_steps, expected_schedule):
adaptation_schedule = window_adaptation.build_schedule(num_steps)
assert num_steps == len(adaptation_schedule)
assert np.array_equal(adaptation_schedule, expected_schedule)


+def test_chees_adaptation():
+    logprob_fn = lambda x: jax.scipy.stats.norm.logpdf(
+        x, loc=0.0, scale=jnp.array([1.0, 10.0])
+    ).sum()
+
+    num_burnin_steps = 1000
+    num_results = 500
+    num_chains = 16
+    step_size = 0.1
+
+    init_key, warmup_key, inference_key = jax.random.split(jax.random.PRNGKey(0), 3)
+
+    warmup = blackjax.chees_adaptation(
+        logprob_fn, num_chains=num_chains, target_acceptance_rate=0.75
+    )
+
+    initial_positions = jax.random.normal(init_key, (num_chains, 2))
+    (last_states, parameters), warmup_info = warmup.run(
+        warmup_key,
+        initial_positions,
+        step_size=step_size,
+        optim=optax.adamw(learning_rate=0.5),
+        num_steps=num_burnin_steps,
+    )
+    kernel = blackjax.dynamic_hmc(logprob_fn, **parameters).step
+
+    def one_step(states, rng_key):
+        keys = jax.random.split(rng_key, num_chains)
+        states, infos = jax.vmap(kernel)(keys, states)
+        return states, infos
+
+    keys = jax.random.split(inference_key, num_results)
+    _, infos = jax.lax.scan(one_step, last_states, keys)
+
+    harmonic_mean = 1.0 / jnp.mean(1.0 / infos.acceptance_rate)
+    np.testing.assert_allclose(harmonic_mean, 0.75, rtol=1e-1)
+    np.testing.assert_allclose(parameters["step_size"], 1.5, rtol=2e-1)
+    np.testing.assert_allclose(infos.num_integration_steps.mean(), 15.0, rtol=3e-1)
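The convergence check compares the harmonic mean of the post-warmup acceptance rates to the 0.75 target, matching the statistic that da_update drives during adaptation. The harmonic mean is dragged down by occasional near-zero acceptances, so it is a stricter summary than the arithmetic mean. A quick illustration with invented rates:

import jax.numpy as jnp

acceptance = jnp.array([0.9, 0.9, 0.9, 0.05])
print(acceptance.mean())                 # arithmetic mean, ~0.69
print(1.0 / jnp.mean(1.0 / acceptance))  # harmonic mean, ~0.17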
