Skip to content

Commit

Permalink
revision
Browse files Browse the repository at this point in the history
  • Loading branch information
albcab committed Nov 1, 2023
1 parent 072c245 commit 79231e7
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 46 deletions.
88 changes: 44 additions & 44 deletions blackjax/adaptation/chees_adaptation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Public API for ChEES-HMC"""

from typing import Callable, NamedTuple, Optional
from functools import partial
from typing import Callable, NamedTuple, Optional, Tuple

import jax
import jax.numpy as jnp
Expand All @@ -12,6 +13,10 @@
from blackjax.adaptation.base import AdaptationInfo, AdaptationResults
from blackjax.base import AdaptationAlgorithm
from blackjax.types import Array, ArrayLikeTree, PRNGKey
from blackjax.util import pytree_size

# optimal tuning for HMC, see https://arxiv.org/abs/1001.4460
TARGET_ACCEPTANCE_RATE = 0.651


class ChEESAdaptationState(NamedTuple):
Expand Down Expand Up @@ -48,7 +53,7 @@ def base(
jitter_generator: Callable,
next_random_arg_fn: Callable,
optim: optax.GradientTransformation,
):
) -> Tuple[Callable, Callable]:
"""Maximizing the Change in the Estimator of the Expected Square criterion
(trajectory length) and dual averaging procedure (step size) for the jittered
Hamiltonian Monte Carlo kernel :cite:p:`hoffman2021adaptive`.
Expand Down Expand Up @@ -89,7 +94,7 @@ def compute_parameters(
acceptance_probabilities: Array,
is_divergent: Array,
initial_adaptation_state: ChEESAdaptationState,
):
) -> ChEESAdaptationState:
"""Compute values for the parameters based on statistics collected from
multiple chains.
Expand Down Expand Up @@ -127,7 +132,7 @@ def compute_parameters(
harmonic_mean = 1.0 / jnp.mean(
1.0 / acceptance_probabilities, where=~is_divergent
)
da_state_ = da_update(da_state, 0.651 - harmonic_mean)
da_state_ = da_update(da_state, TARGET_ACCEPTANCE_RATE - harmonic_mean)
step_size_ = jnp.exp(da_state_.log_x)
new_step_size, new_da_state = jax.lax.cond(
jnp.isfinite(step_size_),
Expand All @@ -150,15 +155,10 @@ def compute_parameters(
lambda p, pm: p - pm, initial_positions, initials_mean
)

proposals_matrix = jax.vmap(lambda p: jax.flatten_util.ravel_pytree(p)[0])(
proposals_centered
)
initials_matrix = jax.vmap(lambda p: jax.flatten_util.ravel_pytree(p)[0])(
initials_centered
)
momentums_matrix = jax.vmap(lambda m: jax.flatten_util.ravel_pytree(m)[0])(
proposed_momentums
)
vmap_flatten_op = jax.vmap(lambda p: jax.flatten_util.ravel_pytree(p)[0])
proposals_matrix = vmap_flatten_op(proposals_centered)
initials_matrix = vmap_flatten_op(initials_centered)
momentums_matrix = vmap_flatten_op(proposed_momentums)

trajectory_gradients = (
jitter_generator(random_generator_arg)
Expand Down Expand Up @@ -191,7 +191,7 @@ def compute_parameters(
0.9 * trajectory_length_ma + 0.1 * new_trajectory_length
)

return (
return ChEESAdaptationState(
new_step_size,
new_step_size_ma,
new_trajectory_length,
Expand All @@ -203,13 +203,13 @@ def compute_parameters(

def init(random_generator_arg: Array, step_size: float):
return ChEESAdaptationState(
step_size,
0.0,
step_size,
0.0,
da_init(step_size),
optim.init(step_size),
random_generator_arg,
step_size=step_size,
step_size_moving_average=0.0,
trajectory_length=step_size,
trajectory_length_moving_average=0.0,
da_state=da_init(step_size),
optim_state=optim.init(step_size),
random_generator_arg=random_generator_arg,
)

def update(
Expand Down Expand Up @@ -241,7 +241,7 @@ def update(
jittered HMC algorithm.
"""
parameters = compute_parameters(
new_state = compute_parameters(
proposed_positions,
proposed_momentums,
initial_positions,
Expand All @@ -250,7 +250,7 @@ def update(
adaptation_state,
)

return ChEESAdaptationState(*parameters)
return new_state

return init, update

Expand Down Expand Up @@ -328,13 +328,12 @@ def run(
*,
max_sampling_steps: int = 1000,
):
num_dim = 0
for d in jax.tree_util.tree_leaves(
jax.tree_util.tree_map(
lambda p: p.reshape(num_chains, -1).shape[1], positions
)
):
num_dim += d
assert all(
jax.tree_util.tree_flatten(
jax.tree_util.tree_map(lambda p: p.shape[0] == num_chains, positions)
)[0]
), "initial `positions` leading dimension must be equal to the `num_chains`"
num_dim = pytree_size(positions) // num_chains

key_init, key_step = jax.random.split(rng_key)

Expand All @@ -350,8 +349,9 @@ def run(
init_random_arg = 0

def integration_steps_fn(random_generator_arg, trajectory_length_adjusted):
return jnp.ceil(
jitter_gn(random_generator_arg) * trajectory_length_adjusted
return jnp.asarray(
jnp.ceil(jitter_gn(random_generator_arg) * trajectory_length_adjusted),
dtype=int,
)

step_fn = hmc.build_dynamic_kernel(
Expand All @@ -364,19 +364,19 @@ def integration_steps_fn(random_generator_arg, trajectory_length_adjusted):
def one_step(carry, rng_key):
states, adaptation_state = carry

def one_step(rng_key, state):
return step_fn(
rng_key,
state,
logprob_fn,
step_size=adaptation_state.step_size,
inverse_mass_matrix=jnp.ones(num_dim),
trajectory_length_adjusted=adaptation_state.trajectory_length
/ adaptation_state.step_size,
)

keys = jax.random.split(rng_key, num_chains)
new_states, info = jax.vmap(one_step)(keys, states)
_step_fn = partial(
step_fn,
trajectory_length_adjusted=adaptation_state.trajectory_length
/ adaptation_state.step_size,
)
new_states, info = jax.vmap(_step_fn, (0, 0, None, None, None))(
keys,
states,
logprob_fn,
adaptation_state.step_size,
jnp.ones(num_dim),
)
new_adaptation_state = update(
adaptation_state,
info.proposal.state.position,
Expand Down
4 changes: 2 additions & 2 deletions blackjax/mcmc/hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def build_dynamic_kernel(
Function that generates the next `random_generator_arg` from its previous value.
integration_steps_fn
Function that generates the next pseudo or quasi-random number of integration steps in the
sequence, given the current `random_generator_arg`.
sequence, given the current `random_generator_arg`. Needs to return an `int`.
Returns
-------
Expand Down Expand Up @@ -224,7 +224,7 @@ def kernel(
logdensity_fn,
step_size,
inverse_mass_matrix,
jax.numpy.asarray(num_integration_steps, dtype=int),
num_integration_steps,
)
next_random_arg = next_random_arg_fn(state.random_generator_arg)
return (
Expand Down

0 comments on commit 79231e7

Please sign in to comment.