From 554979d17454810acdf66db6bf8312b9d1e76820 Mon Sep 17 00:00:00 2001 From: Bernardo Williams Date: Thu, 9 Nov 2023 23:00:36 +0200 Subject: [PATCH 1/4] Update .gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 7e38e71ec..25b11a123 100644 --- a/.gitignore +++ b/.gitignore @@ -98,6 +98,7 @@ ehthumbs.db # Custom stuff *.profraw *.DS_Store +venv_blackjax/ # Apparently both patterns _have_ to be present on my system. # Having both is a nice backup, I guess. From a53ef06a9bff6132e8a48f456f141f3fcc7a7e8c Mon Sep 17 00:00:00 2001 From: Bernardo Williams Date: Thu, 9 Nov 2023 23:10:25 +0200 Subject: [PATCH 2/4] Remove redundant steps in ghmc --- blackjax/mcmc/ghmc.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/blackjax/mcmc/ghmc.py b/blackjax/mcmc/ghmc.py index a068acee7..4a2b5f62b 100644 --- a/blackjax/mcmc/ghmc.py +++ b/blackjax/mcmc/ghmc.py @@ -131,7 +131,7 @@ def kernel( """ flat_inverse_scale = jax.flatten_util.ravel_pytree(momentum_inverse_scale)[0] - _, kinetic_energy_fn, _ = metrics.gaussian_euclidean(flat_inverse_scale**2) + momentum_generator, kinetic_energy_fn, _ = metrics.gaussian_euclidean(flat_inverse_scale**2) symplectic_integrator = integrators.velocity_verlet( logdensity_fn, kinetic_energy_fn @@ -147,8 +147,7 @@ def kernel( key_momentum, key_noise = jax.random.split(rng_key) position, momentum, logdensity, logdensity_grad, slice = state # New momentum is persistent - momentum = update_momentum(key_momentum, state, alpha) - momentum = jax.tree_map(lambda m, s: m / s, momentum, momentum_inverse_scale) + momentum = update_momentum(key_momentum, state, alpha, momentum_generator) # Slice is non-reversible slice = ((slice + 1.0 + delta + noise_fn(key_noise)) % 2) - 1.0 @@ -159,7 +158,7 @@ def kernel( proposal = hmc.flip_momentum(proposal) state = GHMCState( proposal.position, - jax.tree_map(lambda m, s: m * s, proposal.momentum, momentum_inverse_scale), + proposal.momentum, proposal.logdensity, proposal.logdensity_grad, info.acceptance_rate, @@ -170,7 +169,7 @@ def kernel( return kernel -def update_momentum(rng_key, state, alpha): +def update_momentum(rng_key, state, alpha, momentum_generator): """Persistent update of the momentum variable. Performs a persistent update of the momentum, taking as input the previous @@ -182,11 +181,9 @@ def update_momentum(rng_key, state, alpha): """ position, momentum, *_ = state - m_size = pytree_size(momentum) - momentum_generator, *_ = metrics.gaussian_euclidean(1 / alpha * jnp.ones((m_size,))) momentum = jax.tree_map( lambda prev_momentum, shifted_momentum: prev_momentum * jnp.sqrt(1.0 - alpha) - + shifted_momentum, + + jnp.sqrt(alpha)*shifted_momentum, momentum, momentum_generator(rng_key, position), ) From c3846aa85aa507637297d5e224ea35aeb2447916 Mon Sep 17 00:00:00 2001 From: Bernardo Williams Date: Fri, 10 Nov 2023 12:03:53 +0200 Subject: [PATCH 3/4] Update ghmc.py --- blackjax/mcmc/ghmc.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/blackjax/mcmc/ghmc.py b/blackjax/mcmc/ghmc.py index 4a2b5f62b..ce31b0d25 100644 --- a/blackjax/mcmc/ghmc.py +++ b/blackjax/mcmc/ghmc.py @@ -131,7 +131,9 @@ def kernel( """ flat_inverse_scale = jax.flatten_util.ravel_pytree(momentum_inverse_scale)[0] - momentum_generator, kinetic_energy_fn, _ = metrics.gaussian_euclidean(flat_inverse_scale**2) + momentum_generator, kinetic_energy_fn, _ = metrics.gaussian_euclidean( + flat_inverse_scale**2 + ) symplectic_integrator = integrators.velocity_verlet( logdensity_fn, kinetic_energy_fn @@ -173,7 +175,8 @@ def update_momentum(rng_key, state, alpha, momentum_generator): """Persistent update of the momentum variable. Performs a persistent update of the momentum, taking as input the previous - momentum, a random number generating key and the parameter alpha. Outputs + momentum, a random number generating key, the parameter alpha and the + momentum generator function. Outputs an updated momentum that is a mixture of the previous momentum a new sample from a Gaussian density (dependent on alpha). The weights of the mixture of these two components are a function of alpha. @@ -183,7 +186,7 @@ def update_momentum(rng_key, state, alpha, momentum_generator): momentum = jax.tree_map( lambda prev_momentum, shifted_momentum: prev_momentum * jnp.sqrt(1.0 - alpha) - + jnp.sqrt(alpha)*shifted_momentum, + + jnp.sqrt(alpha) * shifted_momentum, momentum, momentum_generator(rng_key, position), ) From 042d484a16224b5369d3cac777f0bfd867ef9dfd Mon Sep 17 00:00:00 2001 From: Bernardo Williams Date: Fri, 10 Nov 2023 12:37:12 +0200 Subject: [PATCH 4/4] Update ghmc.py --- blackjax/mcmc/ghmc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/blackjax/mcmc/ghmc.py b/blackjax/mcmc/ghmc.py index ce31b0d25..5c71df451 100644 --- a/blackjax/mcmc/ghmc.py +++ b/blackjax/mcmc/ghmc.py @@ -23,7 +23,7 @@ import blackjax.mcmc.proposal as proposal from blackjax.base import SamplingAlgorithm from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey -from blackjax.util import generate_gaussian_noise, pytree_size +from blackjax.util import generate_gaussian_noise __all__ = ["GHMCState", "init", "build_kernel", "ghmc"]