Skip to content

Commit

Permalink
names fix
Browse files Browse the repository at this point in the history
  • Loading branch information
albcab committed Oct 30, 2023
1 parent ba86244 commit 8b540bc
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions blackjax/adaptation/chees_adaptation.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def base(halton_sequence: Array, optim: optax.GradientTransformation):
Function that initializes the warmup.
update
Function that moves the warmup one step.
"""

da_init, da_update, _ = dual_averaging.dual_averaging()
Expand Down Expand Up @@ -279,7 +279,7 @@ def chees_adaptation(
.. code::
warmup = blackjax.chees(logprob_fn, num_chains)
warmup = blackjax.chees_adaptation(logprob_fn, num_chains)
key_warmup, key_sample = jax.random.split(rng_key)
optim = optax.adam(learning_rate)
last_states, kernel, parameters = warmup.run(
Expand Down
4 changes: 2 additions & 2 deletions tests/mcmc/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,15 +238,15 @@ def test_chees(self):
init_key, warmup_key, inference_key = jax.random.split(rng_key, 3)

num_chains = 128
warmup = blackjax.chees(
warmup = blackjax.chees_adaptation(
logposterior_fn,
num_chains=num_chains,
)
scale_key, coefs_key = jax.random.split(init_key, 2)
log_scales = 1.0 + jax.random.normal(scale_key, (num_chains,))
coefs = 4.0 + jax.random.normal(coefs_key, (num_chains,))
initial_positions = {"log_scale": log_scales, "coefs": coefs}
last_states, kernel, _ = warmup.run(
(last_states, _), kernel, _ = warmup.run(
warmup_key,
initial_positions,
step_size=0.001,
Expand Down

0 comments on commit 8b540bc

Please sign in to comment.