diff --git a/.gitignore b/.gitignore index 7e38e71ec..25b11a123 100644 --- a/.gitignore +++ b/.gitignore @@ -98,6 +98,7 @@ ehthumbs.db # Custom stuff *.profraw *.DS_Store +venv_blackjax/ # Apparently both patterns _have_ to be present on my system. # Having both is a nice backup, I guess. diff --git a/blackjax/mcmc/ghmc.py b/blackjax/mcmc/ghmc.py index a068acee7..5c71df451 100644 --- a/blackjax/mcmc/ghmc.py +++ b/blackjax/mcmc/ghmc.py @@ -23,7 +23,7 @@ import blackjax.mcmc.proposal as proposal from blackjax.base import SamplingAlgorithm from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey -from blackjax.util import generate_gaussian_noise, pytree_size +from blackjax.util import generate_gaussian_noise __all__ = ["GHMCState", "init", "build_kernel", "ghmc"] @@ -131,7 +131,9 @@ def kernel( """ flat_inverse_scale = jax.flatten_util.ravel_pytree(momentum_inverse_scale)[0] - _, kinetic_energy_fn, _ = metrics.gaussian_euclidean(flat_inverse_scale**2) + momentum_generator, kinetic_energy_fn, _ = metrics.gaussian_euclidean( + flat_inverse_scale**2 + ) symplectic_integrator = integrators.velocity_verlet( logdensity_fn, kinetic_energy_fn @@ -147,8 +149,7 @@ def kernel( key_momentum, key_noise = jax.random.split(rng_key) position, momentum, logdensity, logdensity_grad, slice = state # New momentum is persistent - momentum = update_momentum(key_momentum, state, alpha) - momentum = jax.tree_map(lambda m, s: m / s, momentum, momentum_inverse_scale) + momentum = update_momentum(key_momentum, state, alpha, momentum_generator) # Slice is non-reversible slice = ((slice + 1.0 + delta + noise_fn(key_noise)) % 2) - 1.0 @@ -159,7 +160,7 @@ def kernel( proposal = hmc.flip_momentum(proposal) state = GHMCState( proposal.position, - jax.tree_map(lambda m, s: m * s, proposal.momentum, momentum_inverse_scale), + proposal.momentum, proposal.logdensity, proposal.logdensity_grad, info.acceptance_rate, @@ -170,11 +171,12 @@ def kernel( return kernel -def update_momentum(rng_key, state, alpha): +def update_momentum(rng_key, state, alpha, momentum_generator): """Persistent update of the momentum variable. Performs a persistent update of the momentum, taking as input the previous - momentum, a random number generating key and the parameter alpha. Outputs + momentum, a random number generating key, the parameter alpha and the + momentum generator function. Outputs an updated momentum that is a mixture of the previous momentum a new sample from a Gaussian density (dependent on alpha). The weights of the mixture of these two components are a function of alpha. @@ -182,11 +184,9 @@ def update_momentum(rng_key, state, alpha): """ position, momentum, *_ = state - m_size = pytree_size(momentum) - momentum_generator, *_ = metrics.gaussian_euclidean(1 / alpha * jnp.ones((m_size,))) momentum = jax.tree_map( lambda prev_momentum, shifted_momentum: prev_momentum * jnp.sqrt(1.0 - alpha) - + shifted_momentum, + + jnp.sqrt(alpha) * shifted_momentum, momentum, momentum_generator(rng_key, position), )