Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/blackjax-devs/blackjax into…
Browse files Browse the repository at this point in the history
… adjusted_mclmc
  • Loading branch information
reubenharry committed Dec 27, 2024
2 parents 677dea7 + 65ae00e commit 4e2e091
Show file tree
Hide file tree
Showing 16 changed files with 626 additions and 158 deletions.
4 changes: 3 additions & 1 deletion blackjax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from .sgmcmc import sgnht as _sgnht
from .smc import adaptive_tempered
from .smc import inner_kernel_tuning as _inner_kernel_tuning
from .smc import partial_posteriors_path as _partial_posteriors_smc
from .smc import tempered
from .vi import meanfield_vi as _meanfield_vi
from .vi import pathfinder as _pathfinder
Expand Down Expand Up @@ -122,8 +123,9 @@ def generate_top_level_api_from(module):
adaptive_tempered_smc = generate_top_level_api_from(adaptive_tempered)
tempered_smc = generate_top_level_api_from(tempered)
inner_kernel_tuning = generate_top_level_api_from(_inner_kernel_tuning)
partial_posteriors_smc = generate_top_level_api_from(_partial_posteriors_smc)

smc_family = [tempered_smc, adaptive_tempered_smc]
smc_family = [tempered_smc, adaptive_tempered_smc, partial_posteriors_smc]
"Step_fn returning state has a .particles attribute"

# stochastic gradient mcmc
Expand Down
146 changes: 79 additions & 67 deletions blackjax/mcmc/barker.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree
from jax.scipy import stats
from jax.tree_util import tree_leaves, tree_map

import blackjax.mcmc.metrics as metrics
from blackjax.base import SamplingAlgorithm
from blackjax.mcmc.metrics import Metric
from blackjax.mcmc.proposal import static_binomial_sampling
from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey
from blackjax.types import ArrayLikeTree, ArrayTree, Numeric, PRNGKey
from blackjax.util import generate_gaussian_noise

__all__ = ["BarkerState", "BarkerInfo", "init", "build_kernel", "as_top_level_api"]

Expand Down Expand Up @@ -81,44 +83,70 @@ def build_kernel():
"""

def _compute_acceptance_probability(
state: BarkerState,
proposal: BarkerState,
) -> float:
state: BarkerState, proposal: BarkerState, metric: Metric
) -> Numeric:
"""Compute the acceptance probability of the Barker's proposal kernel."""

def ratio_proposal_nd(y, x, log_y, log_x):
num = -_log1pexp(-log_y * (x - y))
den = -_log1pexp(-log_x * (y - x))
x = state.position
y = proposal.position
log_x = state.logdensity_grad
log_y = proposal.logdensity_grad

return jnp.sum(num - den)
y_minus_x = jax.tree_util.tree_map(lambda a, b: a - b, y, x)
x_minus_y = jax.tree_util.tree_map(lambda a: -a, y_minus_x)
z_tilde_x_to_y = metric.scale(x, y_minus_x, inv=True, trans=True)
z_tilde_y_to_x = metric.scale(y, x_minus_y, inv=True, trans=True)

ratios_proposals = tree_map(
ratio_proposal_nd,
proposal.position,
state.position,
proposal.logdensity_grad,
state.logdensity_grad,
c_x_to_y = metric.scale(x, log_x, inv=False, trans=True)
c_y_to_x = metric.scale(y, log_y, inv=False, trans=True)

z_tilde_x_to_y_flat, _ = ravel_pytree(z_tilde_x_to_y)
z_tilde_y_to_x_flat, _ = ravel_pytree(z_tilde_y_to_x)

c_x_to_y_flat, _ = ravel_pytree(c_x_to_y)
c_y_to_x_flat, _ = ravel_pytree(c_y_to_x)

num = metric.kinetic_energy(x_minus_y, y) - _log1pexp(
-z_tilde_y_to_x_flat * c_y_to_x_flat
)
ratio_proposal = sum(tree_leaves(ratios_proposals))
denom = metric.kinetic_energy(y_minus_x, x) - _log1pexp(
-z_tilde_x_to_y_flat * c_x_to_y_flat
)

ratio_proposal = jnp.sum(num - denom)

return proposal.logdensity - state.logdensity + ratio_proposal

def kernel(
rng_key: PRNGKey, state: BarkerState, logdensity_fn: Callable, step_size: float
rng_key: PRNGKey,
state: BarkerState,
logdensity_fn: Callable,
step_size: float,
inverse_mass_matrix: metrics.MetricTypes | None = None,
) -> tuple[BarkerState, BarkerInfo]:
"""Generate a new sample with the MALA kernel."""
"""Generate a new sample with the Barker kernel."""
if inverse_mass_matrix is None:
p, _ = ravel_pytree(state.position)
(m,) = p.shape
inverse_mass_matrix = jnp.ones((m,))
metric = metrics.default_metric(inverse_mass_matrix)
grad_fn = jax.value_and_grad(logdensity_fn)

key_sample, key_rmh = jax.random.split(rng_key)

proposed_pos = _barker_sample(
key_sample, state.position, state.logdensity_grad, step_size
key_sample,
state.position,
state.logdensity_grad,
step_size,
metric,
)

proposed_logdensity, proposed_logdensity_grad = grad_fn(proposed_pos)
proposed_state = BarkerState(
proposed_pos, proposed_logdensity, proposed_logdensity_grad
)

log_p_accept = _compute_acceptance_probability(state, proposed_state)
log_p_accept = _compute_acceptance_probability(state, proposed_state, metric)
accepted_state, info = static_binomial_sampling(
key_rmh, log_p_accept, state, proposed_state
)
Expand All @@ -131,6 +159,7 @@ def kernel(
def as_top_level_api(
logdensity_fn: Callable,
step_size: float,
inverse_mass_matrix: metrics.MetricTypes | None = None,
) -> SamplingAlgorithm:
"""Implements the (basic) user interface for the Barker's proposal :cite:p:`Livingstone2022Barker` kernel with a
Gaussian base kernel.
Expand Down Expand Up @@ -174,7 +203,9 @@ def as_top_level_api(
logdensity_fn
The log-density function we wish to draw samples from.
step_size
The value to use for the step size in the symplectic integrator.
The value of the step_size correspnoding to the global scale of the proposal distribution.
inverse_mass_matrix
The inverse mass matrix to use for pre-conditioning (see Appendix G of :cite:p:`Livingstone2022Barker`).
Returns
-------
Expand All @@ -189,74 +220,55 @@ def init_fn(position: ArrayLikeTree, rng_key=None):
return init(position, logdensity_fn)

def step_fn(rng_key: PRNGKey, state):
return kernel(rng_key, state, logdensity_fn, step_size)
return kernel(rng_key, state, logdensity_fn, step_size, inverse_mass_matrix)

return SamplingAlgorithm(init_fn, step_fn)


def _barker_sample_nd(key, mean, a, scale):
"""
Sample from a multivariate Barker's proposal distribution. In 1D, this has the following probability density function:
.. math::
p(x; \\mu, a, \\sigma) = 2 \frac{N(x; \\mu, \\sigma^2)}{1 + \\exp(-a (x - \\mu)}
def _generate_bernoulli(
rng_key: PRNGKey, position: ArrayLikeTree, p: ArrayLikeTree
) -> ArrayTree:
pos, unravel_fn = ravel_pytree(position)
p_flat, _ = ravel_pytree(p)
sample = jax.random.bernoulli(rng_key, p=p_flat, shape=pos.shape)
return unravel_fn(sample)

where :math:`N(x; \\mu, \\sigma^2)` is the normal distribution with mean :math:`\\mu` and standard deviation :math:`\\sigma`.
The multivariate Barker's proposal distribution is the product of one-dimensional Barker's proposal distributions.

def _barker_sample(key, mean, a, scale, metric):
r"""
Sample from a multivariate Barker's proposal distribution for PyTrees.
Parameters
----------
key
A PRNG key.
mean
The mean of the normal distribution, an Array. This corresponds to :math:`\\mu` in the equation above.
The mean of the normal distribution, a PyTree. This corresponds to :math:`\mu` in the equation above.
a
The parameter :math:`a` in the equation above, an Array. This is a skewness parameter.
The parameter :math:`a` in the equation above, the same PyTree as `mean`. This is a skewness parameter.
scale
The standard deviation of the normal distribution, a scalar. This corresponds to :math:`\\sigma` in the equation above.
The global scale, a scalar. This corresponds to :math:`\\sigma` in the equation above.
It encodes the step size of the proposal.
Returns
-------
A sample from the Barker's multidimensional proposal distribution.
metric
A `metrics.MetricTypes` object encoding the mass matrix information.
"""

key1, key2 = jax.random.split(key)
z = scale * jax.random.normal(key1, shape=mean.shape)

z = generate_gaussian_noise(key1, mean, sigma=scale)
c = metric.scale(mean, a, inv=False, trans=True)

# Sample b=1 with probability p and 0 with probability 1 - p where
# p = 1 / (1 + exp(-a * (z - mean)))
log_p = -_log1pexp(-a * z)
b = jax.random.bernoulli(key2, p=jnp.exp(log_p), shape=mean.shape)

# return mean + z if b == 1 else mean - z
return mean + b * z - (1 - b) * z

log_p = jax.tree_util.tree_map(lambda x, y: -_log1pexp(-x * y), c, z)
p = jax.tree_util.tree_map(lambda x: jnp.exp(x), log_p)
b = _generate_bernoulli(key2, mean, p=p)

def _barker_sample(key, mean, a, scale):
r"""
Sample from a multivariate Barker's proposal distribution for PyTrees.
Parameters
----------
key
A PRNG key.
mean
The mean of the normal distribution, a PyTree. This corresponds to :math:`\mu` in the equation above.
a
The parameter :math:`a` in the equation above, the same PyTree as `mean`. This is a skewness parameter.
scale
The standard deviation of the normal distribution, a scalar. This corresponds to :math:`\sigma` in the equation above.
It encodes the step size of the proposal.
"""
bz = jax.tree_util.tree_map(lambda x, y: x * y - (1 - x) * y, b, z)

flat_mean, unravel_fn = ravel_pytree(mean)
flat_a, _ = ravel_pytree(a)
flat_sample = _barker_sample_nd(key, flat_mean, flat_a, scale)
return unravel_fn(flat_sample)
return jax.tree_util.tree_map(
lambda a, b: a + b, mean, metric.scale(mean, bz, inv=False, trans=False)
)


def _log1pexp(a):
Expand Down
55 changes: 39 additions & 16 deletions blackjax/mcmc/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
"""
from typing import Callable, NamedTuple, Optional, Protocol, Union

import jax
import jax.numpy as jnp
import jax.scipy as jscipy
from jax.flatten_util import ravel_pytree
Expand Down Expand Up @@ -62,7 +61,12 @@ def __call__(

class Scale(Protocol):
def __call__(
self, position: ArrayLikeTree, element: ArrayLikeTree, inv: ArrayLikeTree
self,
position: ArrayLikeTree,
element: ArrayLikeTree,
*,
inv: bool,
trans: bool,
) -> ArrayLikeTree:
...

Expand Down Expand Up @@ -187,7 +191,11 @@ def is_turning(
return turning_at_left | turning_at_right

def scale(
position: ArrayLikeTree, element: ArrayLikeTree, inv: ArrayLikeTree
position: ArrayLikeTree,
element: ArrayLikeTree,
*,
inv: bool,
trans: bool,
) -> ArrayLikeTree:
"""Scale elements by the mass matrix.
Expand All @@ -197,10 +205,11 @@ def scale(
The current position. Not used in this metric.
elements
Elements to scale
invs
inv
Whether to scale the elements by the inverse mass matrix or the mass matrix.
If True, the element is scaled by the inverse square root mass matrix, i.e., elem <- (M^{1/2})^{-1} elem.
Same pytree structure as `elements`.
trans
whether to transpose mass matrix when scaling
Returns
-------
Expand All @@ -209,11 +218,16 @@ def scale(
"""

ravelled_element, unravel_fn = ravel_pytree(element)
scaled = jax.lax.cond(
inv,
lambda: linear_map(inv_mass_matrix_sqrt, ravelled_element),
lambda: linear_map(mass_matrix_sqrt, ravelled_element),
)

if inv:
left_hand_side_matrix = inv_mass_matrix_sqrt
else:
left_hand_side_matrix = mass_matrix_sqrt
if trans:
left_hand_side_matrix = left_hand_side_matrix.T

scaled = linear_map(left_hand_side_matrix, ravelled_element)

return unravel_fn(scaled)

return Metric(momentum_generator, kinetic_energy, is_turning, scale)
Expand Down Expand Up @@ -279,7 +293,11 @@ def is_turning(
# return turning_at_left | turning_at_right

def scale(
position: ArrayLikeTree, element: ArrayLikeTree, inv: ArrayLikeTree
position: ArrayLikeTree,
element: ArrayLikeTree,
*,
inv: bool,
trans: bool,
) -> ArrayLikeTree:
"""Scale elements by the mass matrix.
Expand All @@ -298,11 +316,16 @@ def scale(
mass_matrix, is_inv=False
)
ravelled_element, unravel_fn = ravel_pytree(element)
scaled = jax.lax.cond(
inv,
lambda: linear_map(inv_mass_matrix_sqrt, ravelled_element),
lambda: linear_map(mass_matrix_sqrt, ravelled_element),
)

if inv:
left_hand_side_matrix = inv_mass_matrix_sqrt
else:
left_hand_side_matrix = mass_matrix_sqrt
if trans:
left_hand_side_matrix = left_hand_side_matrix.T

scaled = linear_map(left_hand_side_matrix, ravelled_element)

return unravel_fn(scaled)

return Metric(momentum_generator, kinetic_energy, is_turning, scale)
Expand Down
1 change: 1 addition & 0 deletions blackjax/smc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@
"tempered",
"inner_kernel_tuning",
"extend_params",
"partial_posteriors_path",
]
28 changes: 28 additions & 0 deletions blackjax/smc/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,3 +156,31 @@ def extend_params(params):
"""

return jax.tree.map(lambda x: jnp.asarray(x)[None, ...], params)


def update_and_take_last(
mcmc_init_fn,
tempered_logposterior_fn,
shared_mcmc_step_fn,
num_mcmc_steps,
n_particles,
):
"""Given N particles, runs num_mcmc_steps of a kernel starting at each particle, and
returns the last values, waisting the previous num_mcmc_steps-1
samples per chain.
"""

def mcmc_kernel(rng_key, position, step_parameters):
state = mcmc_init_fn(position, tempered_logposterior_fn)

def body_fn(state, rng_key):
new_state, info = shared_mcmc_step_fn(
rng_key, state, tempered_logposterior_fn, **step_parameters
)
return new_state, info

keys = jax.random.split(rng_key, num_mcmc_steps)
last_state, info = jax.lax.scan(body_fn, state, keys)
return last_state.position, info

return jax.vmap(mcmc_kernel), n_particles
Loading

0 comments on commit 4e2e091

Please sign in to comment.