Skip to content

Commit

Permalink
Fixing a bunch of tests
Browse files Browse the repository at this point in the history
  • Loading branch information
AdrienCorenflos committed Sep 9, 2024
1 parent fd70221 commit 871713f
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 119 deletions.
115 changes: 50 additions & 65 deletions blackjax/mcmc/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@
We can also generate a relativistic dynamic :cite:p:`lu2017relativistic`.
"""
from typing import Callable, NamedTuple, Optional, Protocol, Tuple, Union
from typing import Callable, NamedTuple, Optional, Protocol, Union

import jax
import jax.numpy as jnp
import jax.scipy as jscipy
from chex import Numeric
Expand Down Expand Up @@ -64,7 +65,7 @@ class Metric(NamedTuple):
sample_momentum: Callable[[PRNGKey, ArrayLikeTree], ArrayLikeTree]
kinetic_energy: KineticEnergy
check_turning: CheckTurning
scale: Callable[[ArrayLikeTree, Tuple[Tuple[ArrayLikeTree, bool]]], ArrayLikeTree]
scale: Callable[[ArrayLikeTree, ArrayLikeTree, bool], ArrayLikeTree]


MetricTypes = Union[Metric, Array, Callable[[ArrayLikeTree], Array]]
Expand Down Expand Up @@ -129,8 +130,8 @@ def gaussian_euclidean(
itself given the values of the momentum along the trajectory.
"""
inv_mass_matrix_sqrt, mass_matrix_sqrt, diag = _format_covariance(
inverse_mass_matrix, get_inv=True
mass_matrix_sqrt, inv_mass_matrix_sqrt, diag = _format_covariance(
inverse_mass_matrix, is_inv=True
)

def momentum_generator(rng_key: PRNGKey, position: ArrayLikeTree) -> ArrayTree:
Expand Down Expand Up @@ -180,32 +181,34 @@ def is_turning(
return turning_at_left | turning_at_right

def scale(
position: ArrayLikeTree, elements: Tuple[Tuple[ArrayLikeTree, bool]]
) -> Tuple[ArrayLikeTree]:
position: ArrayLikeTree, element: ArrayLikeTree, inv: ArrayLikeTree
) -> ArrayLikeTree:
"""Scale elements by the mass matrix.
Parameters
----------
position
The current position. Not used in this metric.
elements
A tuple of (element, inv) pairs to scale.
If inv is True, the element is scaled by the inverse square root mass matrix, i.e., elem <- M^{-1/2} elem.
Elements to scale
inv
Whether to scale the element by the inverse square root of the mass matrix or by the square root of the mass matrix.
If True, the element is scaled by the inverse square root mass matrix, i.e., elem <- (M^{1/2})^{-1} elem.
Must be a scalar boolean, since it is used as the predicate of `jax.lax.cond`.
Returns
-------
scaled_elements
The scaled elements.
"""
scaled_elements = []
for element, inv in elements:
ravelled_element, unravel_fn = ravel_pytree(element)
if inv:
ravelled_element = linear_map(inv_mass_matrix_sqrt, ravelled_element)
else:
ravelled_element = linear_map(mass_matrix_sqrt, ravelled_element)
scaled_elements.append(unravel_fn(ravelled_element))
return tuple(scaled_elements) # type: ignore

ravelled_element, unravel_fn = ravel_pytree(element)
scaled = jax.lax.cond(
inv,
lambda: linear_map(inv_mass_matrix_sqrt, ravelled_element),
lambda: linear_map(mass_matrix_sqrt, ravelled_element),
)
return unravel_fn(scaled)

return Metric(momentum_generator, kinetic_energy, is_turning, scale)

Expand All @@ -215,7 +218,7 @@ def gaussian_riemannian(
) -> Metric:
def momentum_generator(rng_key: PRNGKey, position: ArrayLikeTree) -> ArrayLikeTree:
mass_matrix = mass_matrix_fn(position)
mass_matrix_sqrt, *_ = _format_covariance(mass_matrix, get_inv=False)
mass_matrix_sqrt, *_ = _format_covariance(mass_matrix, is_inv=False)

return generate_gaussian_noise(rng_key, position, sigma=mass_matrix_sqrt)

Expand All @@ -232,10 +235,10 @@ def kinetic_energy(
momentum, _ = ravel_pytree(momentum)
mass_matrix = mass_matrix_fn(position)
sqrt_mass_matrix, inv_sqrt_mass_matrix, diag = _format_covariance(
mass_matrix, get_inv=True
mass_matrix, is_inv=False
)

return _energy(momentum, 0, sqrt_mass_matrix, inv_sqrt_mass_matrix, diag)
return _energy(momentum, 0, sqrt_mass_matrix, inv_sqrt_mass_matrix.T, diag)

def is_turning(
momentum_left: ArrayLikeTree,
Expand Down Expand Up @@ -270,76 +273,58 @@ def is_turning(
# return turning_at_left | turning_at_right

def scale(
position: ArrayLikeTree, elements: Tuple[Tuple[ArrayLikeTree, bool]]
) -> Tuple[ArrayLikeTree]:
position: ArrayLikeTree, element: ArrayLikeTree, inv: ArrayLikeTree
) -> ArrayLikeTree:
"""Scale elements by the mass matrix.
Parameters
----------
position
The current position.
elements
A tuple of (element, inv) pairs to scale.
If inv is True, the element is scaled by the inverse square root mass matrix, i.e., elem <- M^{-1/2} elem.
Returns
-------
scaled_elements
The scaled elements.
"""
scaled_elements = []
mass_matrix = mass_matrix_fn(position)
# some small performance improvement: group by inv and only compute the inverse Cholesky if needed

inv_elements = [
(k, element) for k, (element, inv) in enumerate(elements) if inv
]
non_inv_elements = [
(k, element) for k, (element, inv) in enumerate(elements) if not inv
]
argsort = [k for k, _ in non_inv_elements] + [k for k, _ in inv_elements]

mass_matrix_sqrt, inv_sqrt_mass_matrix, diag = _format_covariance(
mass_matrix, get_inv=bool(inv_elements)
mass_matrix_sqrt, inv_mass_matrix_sqrt, diag = _format_covariance(
mass_matrix, is_inv=False
)

for _, element in non_inv_elements:
rav_element, unravel_fn = ravel_pytree(element)
rav_element = linear_map(mass_matrix_sqrt, rav_element)
scaled_elements.append(unravel_fn(rav_element))

if inv_elements:
for _, element in inv_elements:
rav_element, unravel_fn = ravel_pytree(element)
rav_element = linear_map(inv_sqrt_mass_matrix, rav_element)
scaled_elements.append(unravel_fn(rav_element))

scaled_elements = [scaled_elements[k] for k in argsort]

return tuple(scaled_elements) # type: ignore
ravelled_element, unravel_fn = ravel_pytree(element)
scaled = jax.lax.cond(
inv,
lambda: linear_map(inv_mass_matrix_sqrt, ravelled_element),
lambda: linear_map(mass_matrix_sqrt, ravelled_element),
)
return unravel_fn(scaled)

return Metric(momentum_generator, kinetic_energy, is_turning, scale)


def _format_covariance(cov: Array, get_inv):
def _format_covariance(cov: Array, is_inv):
ndim = jnp.ndim(cov)
if ndim == 1:
cov_sqrt = jnp.sqrt(cov)
inv_cov_sqrt = 1 / cov_sqrt
diag = lambda x: x
if get_inv:
inv_cov_sqrt = jnp.reciprocal(cov_sqrt)
else:
inv_cov_sqrt = None
if is_inv:
inv_cov_sqrt, cov_sqrt = cov_sqrt, inv_cov_sqrt
elif ndim == 2:
cov_sqrt = jscipy.linalg.cholesky(cov, lower=False)
diag = lambda x: jnp.diag(x)
if get_inv:
identity = jnp.identity(cov.shape[0])
inv_cov_sqrt = jscipy.linalg.solve_triangular(
cov_sqrt, identity, lower=False
identity = jnp.identity(cov.shape[0])
if is_inv:
inv_cov_sqrt = jscipy.linalg.cholesky(cov, lower=True)
cov_sqrt = jscipy.linalg.solve_triangular(
inv_cov_sqrt, identity, lower=True, trans=True
)
else:
inv_cov_sqrt = None
cov_sqrt = jscipy.linalg.cholesky(cov, lower=False).T
inv_cov_sqrt = jscipy.linalg.solve_triangular(
cov_sqrt, identity, lower=True, trans=True
)

diag = lambda x: jnp.diag(x)

else:
raise ValueError(
"The mass matrix has the wrong number of dimensions:"
Expand Down
Loading

0 comments on commit 871713f

Please sign in to comment.