Skip to content

Commit

Permalink
add static adjusted mclmc
Browse files Browse the repository at this point in the history
  • Loading branch information
reubenharry committed Jan 15, 2025
1 parent 8eed424 commit 9dd6bdb
Show file tree
Hide file tree
Showing 8 changed files with 64 additions and 52 deletions.
10 changes: 4 additions & 6 deletions blackjax/adaptation/adjusted_mclmc_adaptation.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def adjusted_mclmc_find_L_and_step_size(
dim = pytree_size(state.position)
if params is None:
params = MCLMCAdaptationState(
jnp.sqrt(dim), jnp.sqrt(dim) * 0.2, sqrt_diag_cov=jnp.ones((dim,))
jnp.sqrt(dim), jnp.sqrt(dim) * 0.2, inverse_mass_matrix=jnp.ones((dim,))
)

part1_key, part2_key = jax.random.split(rng_key, 2)
Expand Down Expand Up @@ -152,7 +152,7 @@ def step(iteration_state, weight_and_key):
state=previous_state,
avg_num_integration_steps=avg_num_integration_steps,
step_size=params.step_size,
sqrt_diag_cov=params.sqrt_diag_cov,
inverse_mass_matrix=params.inverse_mass_matrix,
)

# step updating
Expand Down Expand Up @@ -283,9 +283,7 @@ def L_step_size_adaptation(state, params, num_steps, rng_key):
L=params.L * change, step_size=params.step_size * change
)
if diagonal_preconditioning:
params = params._replace(
sqrt_diag_cov=jnp.sqrt(variances), L=jnp.sqrt(dim)
)
params = params._replace(inverse_mass_matrix=variances, L=jnp.sqrt(dim))

initial_da, update_da, final_da = dual_averaging_adaptation(target=target)
(
Expand Down Expand Up @@ -323,7 +321,7 @@ def step(state, key):
state=state,
step_size=params.step_size,
avg_num_integration_steps=params.L / params.step_size,
sqrt_diag_cov=params.sqrt_diag_cov,
inverse_mass_matrix=params.inverse_mass_matrix,
)
return next_state, next_state.position

Expand Down
22 changes: 11 additions & 11 deletions blackjax/adaptation/mclmc_adaptation.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,13 @@ class MCLMCAdaptationState(NamedTuple):
The momentum decoherent rate for the MCLMC algorithm.
step_size
The step size used for the MCLMC algorithm.
sqrt_diag_cov
inverse_mass_matrix
A matrix used for preconditioning.
"""

L: float
step_size: float
sqrt_diag_cov: float
inverse_mass_matrix: float


def mclmc_find_L_and_step_size(
Expand Down Expand Up @@ -87,10 +87,10 @@ def mclmc_find_L_and_step_size(
Example
-------
.. code::
kernel = lambda sqrt_diag_cov : blackjax.mcmc.mclmc.build_kernel(
kernel = lambda inverse_mass_matrix : blackjax.mcmc.mclmc.build_kernel(
logdensity_fn=logdensity_fn,
integrator=integrator,
sqrt_diag_cov=sqrt_diag_cov,
inverse_mass_matrix=inverse_mass_matrix,
)
(
Expand All @@ -106,7 +106,7 @@ def mclmc_find_L_and_step_size(
"""
dim = pytree_size(state.position)
params = MCLMCAdaptationState(
jnp.sqrt(dim), jnp.sqrt(dim) * 0.25, sqrt_diag_cov=jnp.ones((dim,))
jnp.sqrt(dim), jnp.sqrt(dim) * 0.25, inverse_mass_matrix=jnp.ones((dim,))
)
part1_key, part2_key = jax.random.split(rng_key, 2)

Expand All @@ -123,7 +123,7 @@ def mclmc_find_L_and_step_size(

if frac_tune3 != 0:
state, params = make_adaptation_L(
mclmc_kernel(params.sqrt_diag_cov), frac=frac_tune3, Lfactor=0.4
mclmc_kernel(params.inverse_mass_matrix), frac=frac_tune3, Lfactor=0.4
)(state, params, num_steps, part2_key)

return state, params
Expand Down Expand Up @@ -152,7 +152,7 @@ def predictor(previous_state, params, adaptive_state, rng_key):
rng_key, nan_key = jax.random.split(rng_key)

# dynamics
next_state, info = kernel(params.sqrt_diag_cov)(
next_state, info = kernel(params.inverse_mass_matrix)(
rng_key=rng_key,
state=previous_state,
L=params.L,
Expand Down Expand Up @@ -247,15 +247,15 @@ def L_step_size_adaptation(state, params, num_steps, rng_key):

L = params.L
# determine L
sqrt_diag_cov = params.sqrt_diag_cov
inverse_mass_matrix = params.inverse_mass_matrix
if num_steps2 > 1:
x_average, x_squared_average = average[0], average[1]
variances = x_squared_average - jnp.square(x_average)
L = jnp.sqrt(jnp.sum(variances))

if diagonal_preconditioning:
sqrt_diag_cov = jnp.sqrt(variances)
params = params._replace(sqrt_diag_cov=sqrt_diag_cov)
inverse_mass_matrix = variances
params = params._replace(inverse_mass_matrix=inverse_mass_matrix)
L = jnp.sqrt(dim)

# readjust the stepsize
Expand All @@ -265,7 +265,7 @@ def L_step_size_adaptation(state, params, num_steps, rng_key):
xs=(jnp.ones(steps), keys), state=state, params=params
)

return state, MCLMCAdaptationState(L, params.step_size, sqrt_diag_cov)
return state, MCLMCAdaptationState(L, params.step_size, inverse_mass_matrix)

return L_step_size_adaptation

Expand Down
10 changes: 6 additions & 4 deletions blackjax/mcmc/adjusted_mclmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def build_kernel(
num_integration_steps: int,
integrator: Callable = integrators.isokinetic_mclachlan,
divergence_threshold: float = 1000,
sqrt_diag_cov=1.0,
inverse_mass_matrix=1.0,
):
"""Build an MHMCHMC kernel where the number of integration steps is chosen randomly.
Expand Down Expand Up @@ -76,7 +76,9 @@ def kernel(
momentum = generate_unit_vector(key_momentum, state.position)
proposal, info, _ = adjusted_mclmc_proposal(
integrator=integrators.with_isokinetic_maruyama(
integrator(logdensity_fn=logdensity_fn, sqrt_diag_cov=sqrt_diag_cov)
integrator(
logdensity_fn=logdensity_fn, inverse_mass_matrix=inverse_mass_matrix
)
),
step_size=step_size,
L_proposal_factor=L_proposal_factor * (num_integration_steps * step_size),
Expand Down Expand Up @@ -105,7 +107,7 @@ def as_top_level_api(
logdensity_fn: Callable,
step_size: float,
L_proposal_factor: float = jnp.inf,
sqrt_diag_cov=1.0,
inverse_mass_matrix=1.0,
*,
divergence_threshold: int = 1000,
integrator: Callable = integrators.isokinetic_mclachlan,
Expand Down Expand Up @@ -140,7 +142,7 @@ def as_top_level_api(
kernel = build_kernel(
num_integration_steps,
integrator=integrator,
sqrt_diag_cov=sqrt_diag_cov,
inverse_mass_matrix=inverse_mass_matrix,
divergence_threshold=divergence_threshold,
)

Expand Down
10 changes: 6 additions & 4 deletions blackjax/mcmc/adjusted_mclmc_dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def build_kernel(
integrator: Callable = integrators.isokinetic_mclachlan,
divergence_threshold: float = 1000,
next_random_arg_fn: Callable = lambda key: jax.random.split(key)[1],
sqrt_diag_cov=1.0,
inverse_mass_matrix=1.0,
):
"""Build a Dynamic MHMCHMC kernel where the number of integration steps is chosen randomly.
Expand Down Expand Up @@ -76,7 +76,9 @@ def kernel(
momentum = generate_unit_vector(key_momentum, state.position)
proposal, info, _ = adjusted_mclmc_proposal(
integrator=integrators.with_isokinetic_maruyama(
integrator(logdensity_fn=logdensity_fn, sqrt_diag_cov=sqrt_diag_cov)
integrator(
logdensity_fn=logdensity_fn, inverse_mass_matrix=inverse_mass_matrix
)
),
step_size=step_size,
L_proposal_factor=L_proposal_factor * (num_integration_steps * step_size),
Expand Down Expand Up @@ -106,7 +108,7 @@ def as_top_level_api(
logdensity_fn: Callable,
step_size: float,
L_proposal_factor: float = jnp.inf,
sqrt_diag_cov=1.0,
inverse_mass_matrix=1.0,
*,
divergence_threshold: int = 1000,
integrator: Callable = integrators.isokinetic_mclachlan,
Expand Down Expand Up @@ -143,7 +145,7 @@ def as_top_level_api(
integration_steps_fn=integration_steps_fn,
integrator=integrator,
next_random_arg_fn=next_random_arg_fn,
sqrt_diag_cov=sqrt_diag_cov,
inverse_mass_matrix=inverse_mass_matrix,
divergence_threshold=divergence_threshold,
)

Expand Down
12 changes: 7 additions & 5 deletions blackjax/mcmc/integrators.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,9 @@ def _normalized_flatten_array(x, tol=1e-13):
return jnp.where(norm > tol, x / norm, x), norm


def esh_dynamics_momentum_update_one_step(sqrt_diag_cov=1.0):
def esh_dynamics_momentum_update_one_step(inverse_mass_matrix=1.0):
sqrt_inverse_mass_matrix = jax.tree_util.tree_map(jnp.sqrt, inverse_mass_matrix)

def update(
momentum: ArrayTree,
logdensity_grad: ArrayTree,
Expand All @@ -330,7 +332,7 @@ def update(

logdensity_grad = logdensity_grad
flatten_grads, unravel_fn = ravel_pytree(logdensity_grad)
flatten_grads = flatten_grads * sqrt_diag_cov
flatten_grads = flatten_grads * sqrt_inverse_mass_matrix
flatten_momentum, _ = ravel_pytree(momentum)
dims = flatten_momentum.shape[0]
normalized_gradient, gradient_norm = _normalized_flatten_array(flatten_grads)
Expand All @@ -342,7 +344,7 @@ def update(
+ 2 * zeta * flatten_momentum
)
new_momentum_normalized, _ = _normalized_flatten_array(new_momentum_raw)
gr = unravel_fn(new_momentum_normalized * sqrt_diag_cov)
gr = unravel_fn(new_momentum_normalized * sqrt_inverse_mass_matrix)
next_momentum = unravel_fn(new_momentum_normalized)
kinetic_energy_change = (
delta
Expand Down Expand Up @@ -374,11 +376,11 @@ def format_isokinetic_state_output(

def generate_isokinetic_integrator(coefficients):
def isokinetic_integrator(
logdensity_fn: Callable, sqrt_diag_cov: ArrayTree = 1.0
logdensity_fn: Callable, inverse_mass_matrix: ArrayTree = 1.0
) -> GeneralIntegrator:
position_update_fn = euclidean_position_update_fn(logdensity_fn)
one_step = generalized_two_stage_integrator(
esh_dynamics_momentum_update_one_step(sqrt_diag_cov),
esh_dynamics_momentum_update_one_step(inverse_mass_matrix),
position_update_fn,
coefficients,
format_output_fn=format_isokinetic_state_output,
Expand Down
8 changes: 4 additions & 4 deletions blackjax/mcmc/mclmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def init(position: ArrayLike, logdensity_fn, rng_key):
)


def build_kernel(logdensity_fn, sqrt_diag_cov, integrator):
def build_kernel(logdensity_fn, inverse_mass_matrix, integrator):
"""Build a HMC kernel.
Parameters
Expand All @@ -81,7 +81,7 @@ def build_kernel(logdensity_fn, sqrt_diag_cov, integrator):
"""

step = with_isokinetic_maruyama(
integrator(logdensity_fn=logdensity_fn, sqrt_diag_cov=sqrt_diag_cov)
integrator(logdensity_fn=logdensity_fn, inverse_mass_matrix=inverse_mass_matrix)
)

def kernel(
Expand All @@ -107,7 +107,7 @@ def as_top_level_api(
L,
step_size,
integrator=isokinetic_mclachlan,
sqrt_diag_cov=1.0,
inverse_mass_matrix=1.0,
) -> SamplingAlgorithm:
"""The general mclmc kernel builder (:meth:`blackjax.mcmc.mclmc.build_kernel`, alias `blackjax.mclmc.build_kernel`) can be
cumbersome to manipulate. Since most users only need to specify the kernel
Expand Down Expand Up @@ -155,7 +155,7 @@ def as_top_level_api(
A ``SamplingAlgorithm``.
"""

kernel = build_kernel(logdensity_fn, sqrt_diag_cov, integrator)
kernel = build_kernel(logdensity_fn, inverse_mass_matrix, integrator)

def init_fn(position: ArrayLike, rng_key: PRNGKey):
return init(position, logdensity_fn, rng_key)
Expand Down
4 changes: 2 additions & 2 deletions tests/mcmc/test_integrators.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def test_esh_momentum_update(self, dims):

# Efficient implementation
update_stable = self.variant(
esh_dynamics_momentum_update_one_step(sqrt_diag_cov=1.0)
esh_dynamics_momentum_update_one_step(inverse_mass_matrix=1.0)
)
next_momentum1, *_ = update_stable(momentum, gradient, step_size, 1.0)
np.testing.assert_array_almost_equal(next_momentum, next_momentum1)
Expand All @@ -263,7 +263,7 @@ def test_isokinetic_velocity_verlet(self):
next_state, kinetic_energy_change = step(initial_state, step_size)

# explicit integration
op1 = esh_dynamics_momentum_update_one_step(sqrt_diag_cov=1.0)
op1 = esh_dynamics_momentum_update_one_step(inverse_mass_matrix=1.0)
op2 = integrators.euclidean_position_update_fn(logdensity_fn)
position, momentum, _, logdensity_grad = initial_state
momentum, kinetic_grad, kinetic_energy_change0 = op1(
Expand Down
Loading

0 comments on commit 9dd6bdb

Please sign in to comment.