Skip to content

Commit

Permalink
ghmc dev (#585)
Browse files Browse the repository at this point in the history
* Update .gitignore

* Remove redundant steps in ghmc

* Update ghmc.py

* Update ghmc.py
  • Loading branch information
williwilliams3 authored and junpenglao committed Mar 12, 2024
1 parent a8ccc00 commit 80e642d
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 10 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ ehthumbs.db
# Custom stuff
*.profraw
*.DS_Store
venv_blackjax/

# Apparently both patterns _have_ to be present on my system.
# Having both is a nice backup, I guess.
Expand Down
20 changes: 10 additions & 10 deletions blackjax/mcmc/ghmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import blackjax.mcmc.proposal as proposal
from blackjax.base import SamplingAlgorithm
from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey
from blackjax.util import generate_gaussian_noise, pytree_size
from blackjax.util import generate_gaussian_noise

__all__ = ["GHMCState", "init", "build_kernel", "ghmc"]

Expand Down Expand Up @@ -131,7 +131,9 @@ def kernel(
"""

flat_inverse_scale = jax.flatten_util.ravel_pytree(momentum_inverse_scale)[0]
_, kinetic_energy_fn, _ = metrics.gaussian_euclidean(flat_inverse_scale**2)
momentum_generator, kinetic_energy_fn, _ = metrics.gaussian_euclidean(
flat_inverse_scale**2
)

symplectic_integrator = integrators.velocity_verlet(
logdensity_fn, kinetic_energy_fn
Expand All @@ -147,8 +149,7 @@ def kernel(
key_momentum, key_noise = jax.random.split(rng_key)
position, momentum, logdensity, logdensity_grad, slice = state
# New momentum is persistent
momentum = update_momentum(key_momentum, state, alpha)
momentum = jax.tree_map(lambda m, s: m / s, momentum, momentum_inverse_scale)
momentum = update_momentum(key_momentum, state, alpha, momentum_generator)
# Slice is non-reversible
slice = ((slice + 1.0 + delta + noise_fn(key_noise)) % 2) - 1.0

Expand All @@ -159,7 +160,7 @@ def kernel(
proposal = hmc.flip_momentum(proposal)
state = GHMCState(
proposal.position,
jax.tree_map(lambda m, s: m * s, proposal.momentum, momentum_inverse_scale),
proposal.momentum,
proposal.logdensity,
proposal.logdensity_grad,
info.acceptance_rate,
Expand All @@ -170,23 +171,22 @@ def kernel(
return kernel


def update_momentum(rng_key, state, alpha):
def update_momentum(rng_key, state, alpha, momentum_generator):
"""Persistent update of the momentum variable.
Performs a persistent update of the momentum, taking as input the previous
momentum, a random number generating key and the parameter alpha. Outputs
momentum, a random number generating key, the parameter alpha and the
momentum generator function. Outputs
an updated momentum that is a mixture of the previous momentum a new sample
from a Gaussian density (dependent on alpha). The weights of the mixture of
these two components are a function of alpha.
"""
position, momentum, *_ = state

m_size = pytree_size(momentum)
momentum_generator, *_ = metrics.gaussian_euclidean(1 / alpha * jnp.ones((m_size,)))
momentum = jax.tree_map(
lambda prev_momentum, shifted_momentum: prev_momentum * jnp.sqrt(1.0 - alpha)
+ shifted_momentum,
+ jnp.sqrt(alpha) * shifted_momentum,
momentum,
momentum_generator(rng_key, position),
)
Expand Down

0 comments on commit 80e642d

Please sign in to comment.