Skip to content

Commit

Permalink
return tuning steps
Browse files Browse the repository at this point in the history
  • Loading branch information
reubenharry committed Jan 19, 2025
1 parent fd7f7fa commit 98d4873
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 16 deletions.
46 changes: 35 additions & 11 deletions blackjax/adaptation/adjusted_mclmc_adaptation.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,12 @@ def adjusted_mclmc_find_L_and_step_size(
total_num_tuning_integrator_steps = 0
for i in range(num_windows):
window_key = jax.random.fold_in(part1_key, i)
(state, params, eigenvector, num_tuning_integrator_steps) = adjusted_mclmc_make_L_step_size_adaptation(
(
state,
params,
eigenvector,
num_tuning_integrator_steps,
) = adjusted_mclmc_make_L_step_size_adaptation(
kernel=mclmc_kernel,
dim=dim,
frac_tune1=frac_tune1,
Expand All @@ -91,25 +96,38 @@ def adjusted_mclmc_find_L_and_step_size(
diagonal_preconditioning=diagonal_preconditioning,
max=max,
tuning_factor=tuning_factor,
)(state, params, num_steps, window_key)
)(
state, params, num_steps, window_key
)
total_num_tuning_integrator_steps += num_tuning_integrator_steps

if frac_tune3 != 0:
for i in range(num_windows):
part2_key = jax.random.fold_in(part2_key, i)
part2_key1, part2_key2 = jax.random.split(part2_key, 2)

state, params, num_tuning_integrator_steps = adjusted_mclmc_make_adaptation_L(
(
state,
params,
num_tuning_integrator_steps,
) = adjusted_mclmc_make_adaptation_L(
mclmc_kernel,
frac=frac_tune3,
Lfactor=0.5,
max=max,
eigenvector=eigenvector,
)(state, params, num_steps, part2_key1)
)(
state, params, num_steps, part2_key1
)

total_num_tuning_integrator_steps += num_tuning_integrator_steps

(state, params, _, num_tuning_integrator_steps) = adjusted_mclmc_make_L_step_size_adaptation(
(
state,
params,
_,
num_tuning_integrator_steps,
) = adjusted_mclmc_make_L_step_size_adaptation(
kernel=mclmc_kernel,
dim=dim,
frac_tune1=frac_tune1,
Expand All @@ -119,7 +137,9 @@ def adjusted_mclmc_find_L_and_step_size(
diagonal_preconditioning=diagonal_preconditioning,
max=max,
tuning_factor=tuning_factor,
)(state, params, num_steps, part2_key2)
)(
state, params, num_steps, part2_key2
)

total_num_tuning_integrator_steps += num_tuning_integrator_steps

Expand Down Expand Up @@ -355,11 +375,15 @@ def step(state, key):
# number of effective samples per 1 actual sample
ess = contract(effective_sample_size(flat_samples[None, ...])) / num_steps

return state, params._replace(
L=jnp.clip(
Lfactor * params.L / jnp.mean(ess), max=params.L * Lratio_upperbound
)
), info.num_integration_steps.sum()
return (
state,
params._replace(
L=jnp.clip(
Lfactor * params.L / jnp.mean(ess), max=params.L * Lratio_upperbound
)
),
info.num_integration_steps.sum(),
)

return adaptation_L

Expand Down
6 changes: 1 addition & 5 deletions tests/mcmc/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,11 +518,7 @@ def get_inverse_mass_matrix():
inverse_mass_matrix=inverse_mass_matrix,
)

(
_,
blackjax_mclmc_sampler_params,
_
) = blackjax.mclmc_find_L_and_step_size(
(_, blackjax_mclmc_sampler_params, _) = blackjax.mclmc_find_L_and_step_size(
mclmc_kernel=kernel,
num_steps=num_steps,
state=initial_state,
Expand Down

0 comments on commit 98d4873

Please sign in to comment.