From eca35abc1fed16c3d1174482b0cbf16e084c72ae Mon Sep 17 00:00:00 2001 From: andrewdipper Date: Wed, 19 Jun 2024 22:18:50 -0700 Subject: [PATCH 01/13] convert to bit twiddling (#696) --- blackjax/mcmc/termination.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/blackjax/mcmc/termination.py b/blackjax/mcmc/termination.py index 24e17c3a5..eb1276da3 100644 --- a/blackjax/mcmc/termination.py +++ b/blackjax/mcmc/termination.py @@ -64,16 +64,10 @@ def _leaf_idx_to_ckpt_idxs(n): """Find the checkpoint id from a step number.""" # computes the number of non-zero bits except the last bit # e.g. 6 -> 2, 7 -> 2, 13 -> 2 - _, idx_max = jax.lax.while_loop( - lambda nc: nc[0] > 0, - lambda nc: (nc[0] >> 1, nc[1] + (nc[0] & 1)), - (n >> 1, 0), - ) + idx_max = jnp.bitwise_count(n >> 1).astype(jnp.int32) # computes the number of contiguous last non-zero bits # e.g. 6 -> 0, 7 -> 3, 13 -> 1 - _, num_subtrees = jax.lax.while_loop( - lambda nc: (nc[0] & 1) != 0, lambda nc: (nc[0] >> 1, nc[1] + 1), (n, 0) - ) + num_subtrees = jnp.bitwise_count((~n & (n + 1)) - 1).astype(jnp.int32) idx_min = idx_max - num_subtrees + 1 return idx_min, idx_max From 5764a2b4aff803dff45d903370cc70f2000aa7ef Mon Sep 17 00:00:00 2001 From: Junpeng Lao Date: Mon, 24 Jun 2024 07:01:50 +0200 Subject: [PATCH 02/13] Remove nightly release (#699) --- .github/workflows/nightly.yml | 48 ----------------------------------- README.md | 6 ----- docs/index.md | 6 ----- 3 files changed, 60 deletions(-) delete mode 100644 .github/workflows/nightly.yml diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml deleted file mode 100644 index 6472e4421..000000000 --- a/.github/workflows/nightly.yml +++ /dev/null @@ -1,48 +0,0 @@ -name: Nightly - -on: - push: - branches: [main] - -jobs: - build_and_publish: - name: Build and publish on PyPi - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 0 - - uses: actions/setup-python@v4 - with: - python-version: 3.11 - - name: Update pyproject.toml - # Taken from https://github.com/aesara-devs/aesara/pull/1375 - run: | - curl -sSLf https://github.com/TomWright/dasel/releases/download/v2.0.2/dasel_linux_amd64 \ - -L -o /tmp/dasel && chmod +x /tmp/dasel - /tmp/dasel put -f pyproject.toml project.name -v blackjax-nightly - /tmp/dasel put -f pyproject.toml tool.setuptools_scm.version_scheme -v post-release - /tmp/dasel put -f pyproject.toml tool.setuptools_scm.local_scheme -v no-local-version - - name: Build the sdist and wheel - run: | - python -m pip install -U pip - python -m pip install build - python -m build - - name: Check sdist install and imports - run: | - mkdir -p test-sdist - cd test-sdist - python -m venv venv-sdist - venv-sdist/bin/python -m pip install ../dist/blackjax-nightly-*.tar.gz - venv-sdist/bin/python -c "import blackjax" - - name: Check wheel install and imports - run: | - mkdir -p test-wheel - cd test-wheel - python -m venv venv-wheel - venv-wheel/bin/python -m pip install ../dist/blackjax_nightly-*.whl - - name: Publish to PyPi - uses: pypa/gh-action-pypi-publish@v1.4.2 - with: - user: __token__ - password: ${{ secrets.PYPI_NIGHTLY_TOKEN }} diff --git a/README.md b/README.md index d7d78b15f..a8d847cf9 100644 --- a/README.md +++ b/README.md @@ -41,12 +41,6 @@ or via conda-forge: conda install -c conda-forge blackjax ``` -Nightly builds (bleeding edge) of Blackjax can also be installed using `pip`: - -```bash -pip install blackjax-nightly -``` - BlackJAX is written in pure Python but depends on XLA via JAX. By default, the version of JAX that will be installed along with BlackJAX will make your code run on CPU only. **If you want to use BlackJAX on GPU/TPU** we recommend you follow diff --git a/docs/index.md b/docs/index.md index 0fd84d860..edc02631c 100644 --- a/docs/index.md +++ b/docs/index.md @@ -57,13 +57,7 @@ If you want to use Blackjax with a model implemented with a PPL, go to the relat ```{code-block} bash pip install blackjax ``` -::: -:::{tab-item} Nightly -```{code-block} bash -pip install blackjax-nightly -``` -::: :::{tab-item} Conda ```{code-block} bash From f8db9aa04d83fed4b3bde0f37ddf6194b229308f Mon Sep 17 00:00:00 2001 From: Gilad Turok <36947659+gil2rok@users.noreply.github.com> Date: Mon, 24 Jun 2024 01:07:30 -0400 Subject: [PATCH 03/13] Fix doc mistakes (#701) * Fix equation formatting * Clarify JAX gradient error * Fix punctuation + capitalization * Fix grammar Should not begin sentence with "i.e." in English. * Fix math formatting error * Fix typo Change parallel _ensample_ chain adaptation to parallel _ensemble_ chain adaptation. * Add SVGD citation to appear in doc Currently the SVGD paper is only cited in the `kernel` function, which is defined _within_ the `build_kernel` function. Because of this nested function format, the SVGD paper is _not_ cited in the documentation. To fix this, I added a citation to the SVGD paper in the `as_top_level_api` docstring. * Fix grammar + clarify doc * Fix typo --------- Co-authored-by: Junpeng Lao --- blackjax/adaptation/meads_adaptation.py | 4 ++-- blackjax/base.py | 2 +- blackjax/vi/svgd.py | 2 +- docs/examples/howto_custom_gradients.md | 13 ++++++------- 4 files changed, 10 insertions(+), 11 deletions(-) diff --git a/blackjax/adaptation/meads_adaptation.py b/blackjax/adaptation/meads_adaptation.py index a431a591d..b383653e8 100644 --- a/blackjax/adaptation/meads_adaptation.py +++ b/blackjax/adaptation/meads_adaptation.py @@ -36,7 +36,7 @@ class MEADSAdaptationState(NamedTuple): alpha Value of the alpha parameter of the generalized HMC algorithm. delta - Value of the alpha parameter of the generalized HMC algorithm. + Value of the delta parameter of the generalized HMC algorithm. """ @@ -60,7 +60,7 @@ def base(): with shape. This is an implementation of Algorithm 3 of :cite:p:`hoffman2022tuning` using cross-chain - adaptation instead of parallel ensample chain adaptation. + adaptation instead of parallel ensemble chain adaptation. Returns ------- diff --git a/blackjax/base.py b/blackjax/base.py index f766e98b5..8ea24cd70 100644 --- a/blackjax/base.py +++ b/blackjax/base.py @@ -89,7 +89,7 @@ class SamplingAlgorithm(NamedTuple): """A pair of functions that represents a MCMC sampling algorithm. Blackjax sampling algorithms are implemented as a pair of pure functions: a - kernel, that takes a new samples starting from the current state, and an + kernel, that generates a new sample from the current state, and an initialization function that creates a kernel state from a chain position. As they represent Markov kernels, the kernel functions are pure functions diff --git a/blackjax/vi/svgd.py b/blackjax/vi/svgd.py index 881de77e6..e287b813f 100644 --- a/blackjax/vi/svgd.py +++ b/blackjax/vi/svgd.py @@ -135,7 +135,7 @@ def as_top_level_api( kernel: Callable = rbf_kernel, update_kernel_parameters: Callable = update_median_heuristic, ): - """Implements the (basic) user interface for the svgd algorithm. + """Implements the (basic) user interface for the svgd algorithm :cite:p:`liu2016stein`. Parameters ---------- diff --git a/docs/examples/howto_custom_gradients.md b/docs/examples/howto_custom_gradients.md index 731de0eea..653e6d393 100644 --- a/docs/examples/howto_custom_gradients.md +++ b/docs/examples/howto_custom_gradients.md @@ -29,10 +29,9 @@ Functions can be defined as the minimum of another one, $f(x) = min_{y} g(x,y)$. Our example is taken from the theory of [convex conjugates](https://en.wikipedia.org/wiki/Convex_conjugate), used for example in optimal transport. Let's consider the following function: $$ -\begin{align*} -g(x, y) &= h(y) - \langle x, y\rangle\\ -h(x) &= \frac{1}{p}|x|^p,\qquad p > 1.\\ -\end{align*} +\begin{equation*} +g(x, y) = h(y) - \langle x, y\rangle,\qquad h(x) = \frac{1}{p}|x|^p,\qquad p > 1. +\end{equation*} $$ And define the function $f$ as $f(x) = -min_y g(x, y)$ which we can be implemented as: @@ -69,7 +68,7 @@ Note the we also return the value of $y$ where the minimum of $g$ is achieved (t ### Trying to differentate the function with `jax.grad` -The gradient of the function $f$ is undefined for JAX, which cannot differentiate through `while` loops, and trying to compute it directly raises an error: +The gradient of the function $f$ is undefined for JAX, which cannot differentiate through `while` loops used in BFGS, and trying to compute it directly raises an error: ```{code-cell} ipython3 # We only want the gradient with respect to `x` @@ -97,7 +96,7 @@ The first order optimality criterion \end{equation*} ``` -Ensures that: +ensures that ```{math} \begin{equation*} @@ -105,7 +104,7 @@ Ensures that: \end{equation*} ``` -i.e. the value of the derivative at $x$ is the value $y(x)$ at which the minimum of the function $g$ is achieved. +In other words, the value of the derivative at $x$ is the value $y(x)$ at which the minimum of the function $g$ is achieved. ### Telling JAX to use a custom gradient From 441412a09e39f514189be84813f812d95709365c Mon Sep 17 00:00:00 2001 From: johannahaffner <38662446+johannahaffner@users.noreply.github.com> Date: Wed, 31 Jul 2024 17:32:54 +0200 Subject: [PATCH 04/13] Update index.md (#711) The jitted step remained unused, leading to the example running with an uncompiled nuts.step. Changing this reduces the execution time by a factor of 30 on my system and showcases blackjax' speed. --- docs/index.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/index.md b/docs/index.md index edc02631c..fca4787c4 100644 --- a/docs/index.md +++ b/docs/index.md @@ -41,7 +41,7 @@ rng_key = jax.random.key(0) step = jax.jit(nuts.step) for i in range(1_000): nuts_key = jax.random.fold_in(rng_key, i) - state, _ = nuts.step(nuts_key, state) + state, _ = step(nuts_key, state) ``` :::{note} From 27dfc9e30dd5b8c8f0771f5f6f3cbf7ec3f4c7ec Mon Sep 17 00:00:00 2001 From: andrewdipper Date: Wed, 7 Aug 2024 06:22:43 -0700 Subject: [PATCH 05/13] Enable progress bar under pmap (#712) * enable pmap progbar * fix bar creation * add locking * fix formatting * switch to using chain state --- blackjax/adaptation/window_adaptation.py | 8 +++- blackjax/progress_bar.py | 55 ++++++++++++++---------- blackjax/util.py | 14 ++++-- 3 files changed, 50 insertions(+), 27 deletions(-) diff --git a/blackjax/adaptation/window_adaptation.py b/blackjax/adaptation/window_adaptation.py index 63c54bad0..cb02eb2c4 100644 --- a/blackjax/adaptation/window_adaptation.py +++ b/blackjax/adaptation/window_adaptation.py @@ -334,16 +334,22 @@ def run(rng_key: PRNGKey, position: ArrayLikeTree, num_steps: int = 1000): if progress_bar: print("Running window adaptation") one_step_ = jax.jit(progress_bar_scan(num_steps)(one_step)) + start_state = ((init_state, init_adaptation_state), -1) else: one_step_ = jax.jit(one_step) + start_state = (init_state, init_adaptation_state) keys = jax.random.split(rng_key, num_steps) schedule = build_schedule(num_steps) last_state, info = jax.lax.scan( one_step_, - (init_state, init_adaptation_state), + start_state, (jnp.arange(num_steps), keys, schedule), ) + + if progress_bar: + last_state, _ = last_state + last_chain_state, last_warmup_state, *_ = last_state step_size, inverse_mass_matrix = adapt_final(last_warmup_state) diff --git a/blackjax/progress_bar.py b/blackjax/progress_bar.py index ac509b9b6..188ab7dba 100644 --- a/blackjax/progress_bar.py +++ b/blackjax/progress_bar.py @@ -14,14 +14,19 @@ """Progress bar decorators for use with step functions. Adapted from Jeremie Coullon's blog post :cite:p:`progress_bar`. """ +from threading import Lock + from fastprogress.fastprogress import progress_bar from jax import lax from jax.experimental import io_callback +from jax.numpy import array def progress_bar_scan(num_samples, print_rate=None): "Progress bar for a JAX scan" progress_bars = {} + idx_counter = 0 + lock = Lock() if print_rate is None: if num_samples > 20: @@ -29,41 +34,44 @@ def progress_bar_scan(num_samples, print_rate=None): else: print_rate = 1 # if you run the sampler for less than 20 iterations - def _define_bar(arg): - del arg - progress_bars[0] = progress_bar(range(num_samples)) - progress_bars[0].update(0) + def _calc_chain_idx(iter_num): + nonlocal idx_counter + with lock: + idx = idx_counter + idx_counter += 1 + return idx + + def _update_bar(arg, chain_id): + chain_id = int(chain_id) + if arg == 0: + chain_id = _calc_chain_idx(arg) + progress_bars[chain_id] = progress_bar(range(num_samples)) + progress_bars[chain_id].update(0) - def _update_bar(arg): - progress_bars[0].update_bar(arg + 1) + progress_bars[chain_id].update_bar(arg + 1) + return chain_id - def _close_bar(arg): - del arg - progress_bars[0].on_iter_end() + def _close_bar(arg, chain_id): + progress_bars[int(chain_id)].on_iter_end() - def _update_progress_bar(iter_num): + def _update_progress_bar(iter_num, chain_id): "Updates progress bar of a JAX scan or loop" - _ = lax.cond( - iter_num == 0, - lambda _: io_callback(_define_bar, None, iter_num), - lambda _: None, - operand=None, - ) - _ = lax.cond( + chain_id = lax.cond( # update every multiple of `print_rate` except at the end (iter_num % print_rate == 0) | (iter_num == (num_samples - 1)), - lambda _: io_callback(_update_bar, None, iter_num), - lambda _: None, + lambda _: io_callback(_update_bar, array(0), iter_num, chain_id), + lambda _: chain_id, operand=None, ) _ = lax.cond( iter_num == num_samples - 1, - lambda _: io_callback(_close_bar, None, None), + lambda _: io_callback(_close_bar, None, iter_num + 1, chain_id), lambda _: None, operand=None, ) + return chain_id def _progress_bar_scan(func): """Decorator that adds a progress bar to `body_fun` used in `lax.scan`. @@ -77,8 +85,11 @@ def wrapper_progress_bar(carry, x): iter_num, *_ = x else: iter_num = x - _update_progress_bar(iter_num) - return func(carry, x) + subcarry, chain_id = carry + chain_id = _update_progress_bar(iter_num, chain_id) + subcarry, y = func(subcarry, x) + + return (subcarry, chain_id), y return wrapper_progress_bar diff --git a/blackjax/util.py b/blackjax/util.py index cdb9f4c91..78a7c0633 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -224,13 +224,19 @@ def one_step(average_and_state, xs, return_state): one_step = jax.jit(partial(one_step, return_state=return_state_history)) + xs = (jnp.arange(num_steps), keys) if progress_bar: one_step = progress_bar_scan(num_steps)(one_step) + (((_, average), final_state), _), history = lax.scan( + one_step, + (((0, expectation(transform(initial_state))), initial_state), -1), + xs, + ) - xs = (jnp.arange(num_steps), keys) - ((_, average), final_state), history = lax.scan( - one_step, ((0, expectation(transform(initial_state))), initial_state), xs - ) + else: + ((_, average), final_state), history = lax.scan( + one_step, ((0, expectation(transform(initial_state))), initial_state), xs + ) if not return_state_history: return average, transform(final_state) From 148c02880faae6b272565ff3ea4f423268b67010 Mon Sep 17 00:00:00 2001 From: andrewdipper Date: Thu, 8 Aug 2024 23:36:46 -0700 Subject: [PATCH 06/13] remove labels (#716) --- blackjax/adaptation/window_adaptation.py | 17 +++++------------ blackjax/progress_bar.py | 14 ++++++++++++++ blackjax/util.py | 20 +++++++------------- 3 files changed, 26 insertions(+), 25 deletions(-) diff --git a/blackjax/adaptation/window_adaptation.py b/blackjax/adaptation/window_adaptation.py index cb02eb2c4..69a098325 100644 --- a/blackjax/adaptation/window_adaptation.py +++ b/blackjax/adaptation/window_adaptation.py @@ -28,7 +28,7 @@ dual_averaging_adaptation, ) from blackjax.base import AdaptationAlgorithm -from blackjax.progress_bar import progress_bar_scan +from blackjax.progress_bar import gen_scan_fn from blackjax.types import Array, ArrayLikeTree, PRNGKey from blackjax.util import pytree_size @@ -333,23 +333,16 @@ def run(rng_key: PRNGKey, position: ArrayLikeTree, num_steps: int = 1000): if progress_bar: print("Running window adaptation") - one_step_ = jax.jit(progress_bar_scan(num_steps)(one_step)) - start_state = ((init_state, init_adaptation_state), -1) - else: - one_step_ = jax.jit(one_step) - start_state = (init_state, init_adaptation_state) - + scan_fn = gen_scan_fn(num_steps, progress_bar=progress_bar) + start_state = (init_state, init_adaptation_state) keys = jax.random.split(rng_key, num_steps) schedule = build_schedule(num_steps) - last_state, info = jax.lax.scan( - one_step_, + last_state, info = scan_fn( + one_step, start_state, (jnp.arange(num_steps), keys, schedule), ) - if progress_bar: - last_state, _ = last_state - last_chain_state, last_warmup_state, *_ = last_state step_size, inverse_mass_matrix = adapt_final(last_warmup_state) diff --git a/blackjax/progress_bar.py b/blackjax/progress_bar.py index 188ab7dba..a1425df88 100644 --- a/blackjax/progress_bar.py +++ b/blackjax/progress_bar.py @@ -94,3 +94,17 @@ def wrapper_progress_bar(carry, x): return wrapper_progress_bar return _progress_bar_scan + + +def gen_scan_fn(num_samples, progress_bar, print_rate=None): + if progress_bar: + + def scan_wrap(f, init, *args, **kwargs): + func = progress_bar_scan(num_samples, print_rate)(f) + carry = (init, -1) + (last_state, _), output = lax.scan(func, carry, *args, **kwargs) + return last_state, output + + return scan_wrap + else: + return lax.scan diff --git a/blackjax/util.py b/blackjax/util.py index 78a7c0633..9f4d6f9c7 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -11,7 +11,7 @@ from jax.tree_util import tree_leaves from blackjax.base import SamplingAlgorithm, VIAlgorithm -from blackjax.progress_bar import progress_bar_scan +from blackjax.progress_bar import gen_scan_fn from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey @@ -225,18 +225,12 @@ def one_step(average_and_state, xs, return_state): one_step = jax.jit(partial(one_step, return_state=return_state_history)) xs = (jnp.arange(num_steps), keys) - if progress_bar: - one_step = progress_bar_scan(num_steps)(one_step) - (((_, average), final_state), _), history = lax.scan( - one_step, - (((0, expectation(transform(initial_state))), initial_state), -1), - xs, - ) - - else: - ((_, average), final_state), history = lax.scan( - one_step, ((0, expectation(transform(initial_state))), initial_state), xs - ) + scan_fn = gen_scan_fn(num_steps, progress_bar) + ((_, average), final_state), history = scan_fn( + one_step, + ((0, expectation(transform(initial_state))), initial_state), + xs, + ) if not return_state_history: return average, transform(final_state) From 7135fd766e6908f345cafbdeb7a9ae798d69f355 Mon Sep 17 00:00:00 2001 From: Reuben Date: Mon, 12 Aug 2024 09:33:43 -0400 Subject: [PATCH 07/13] Simplify `run_inference_algorithm` (#714) * fix minor type errors * storing only expectation values * fixed memory efficient sampling * clean up * renaming vars * precommit fixes * fixing tests * fixing tests * fixing tests * fixing tests * fixing tests * merge main * burn in and fix tests * burn in and fix tests * minor fixes * minor fixes * minor fixes --------- Co-authored-by: jakob.robnik@gmail.com --- blackjax/adaptation/mclmc_adaptation.py | 8 +- blackjax/util.py | 156 +++++++++++++++--------- tests/adaptation/test_adaptation.py | 2 +- tests/mcmc/test_sampling.py | 57 +++++---- tests/test_benchmarks.py | 2 +- tests/test_util.py | 56 +++++---- 6 files changed, 171 insertions(+), 110 deletions(-) diff --git a/blackjax/adaptation/mclmc_adaptation.py b/blackjax/adaptation/mclmc_adaptation.py index 76a016242..7645a890b 100644 --- a/blackjax/adaptation/mclmc_adaptation.py +++ b/blackjax/adaptation/mclmc_adaptation.py @@ -20,7 +20,7 @@ from jax.flatten_util import ravel_pytree from blackjax.diagnostics import effective_sample_size -from blackjax.util import pytree_size, streaming_average_update +from blackjax.util import incremental_value_update, pytree_size class MCLMCAdaptationState(NamedTuple): @@ -199,9 +199,9 @@ def step(iteration_state, weight_and_key): x = ravel_pytree(state.position)[0] # update the running average of x, x^2 - streaming_avg = streaming_average_update( - current_value=jnp.array([x, jnp.square(x)]), - previous_weight_and_average=streaming_avg, + streaming_avg = incremental_value_update( + expectation=jnp.array([x, jnp.square(x)]), + incremental_val=streaming_avg, weight=(1 - mask) * success * params.step_size, zero_prevention=mask, ) diff --git a/blackjax/util.py b/blackjax/util.py index 9f4d6f9c7..b6c5367b5 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -3,12 +3,11 @@ from functools import partial from typing import Callable, Union -import jax import jax.numpy as jnp from jax import jit, lax from jax.flatten_util import ravel_pytree from jax.random import normal, split -from jax.tree_util import tree_leaves +from jax.tree_util import tree_leaves, tree_map from blackjax.base import SamplingAlgorithm, VIAlgorithm from blackjax.progress_bar import gen_scan_fn @@ -149,9 +148,7 @@ def run_inference_algorithm( initial_state: ArrayLikeTree = None, initial_position: ArrayLikeTree = None, progress_bar: bool = False, - transform: Callable = lambda x: x, - return_state_history=True, - expectation: Callable = lambda x: x, + transform: Callable = lambda state, info: (state, info), ) -> tuple: """Wrapper to run an inference algorithm. @@ -166,8 +163,7 @@ def run_inference_algorithm( initial_state The initial state of the inference algorithm. initial_position - The initial position of the inference algorithm. This is used when the initial - state is not provided. + The initial position of the inference algorithm. This is used when the initial state is not provided. inference_algorithm One of blackjax's sampling algorithms or variational inference algorithms. num_steps @@ -175,26 +171,14 @@ def run_inference_algorithm( progress_bar Whether to display a progress bar. transform - A transformation of the trace of states to be returned. This is useful for + A transformation of the trace of states (and info) to be returned. This is useful for computing determinstic variables, or returning a subset of the states. By default, the states are returned as is. - expectation - A function that computes the expectation of the state. This is done - incrementally, so doesn't require storing all the states. - return_state_history - if False, `run_inference_algorithm` will only return an expectation of the value - of transform, and return that average instead of the full set of samples. This - is useful when memory is a bottleneck. Returns ------- - If return_state_history is True: 1. The final state. - 2. The trace of the state. - 3. The trace of the info of the inference algorithm for diagnostics. - If return_state_history is False: - 1. This is the expectation of state over the chain. Otherwise the final state. - 2. The final state of the inference algorithm. + 2. The history of states. """ if initial_state is None and initial_position is None: @@ -212,58 +196,116 @@ def run_inference_algorithm( keys = split(rng_key, num_steps) - def one_step(average_and_state, xs, return_state): + def one_step(state, xs): _, rng_key = xs - average, state = average_and_state state, info = inference_algorithm.step(rng_key, state) - average = streaming_average_update(expectation(transform(state)), average) - if return_state: - return (average, state), (transform(state), info) - else: - return (average, state), None + return state, transform(state, info) - one_step = jax.jit(partial(one_step, return_state=return_state_history)) - - xs = (jnp.arange(num_steps), keys) scan_fn = gen_scan_fn(num_steps, progress_bar) - ((_, average), final_state), history = scan_fn( - one_step, - ((0, expectation(transform(initial_state))), initial_state), - xs, - ) - if not return_state_history: - return average, transform(final_state) - else: - state_history, info_history = history - return transform(final_state), state_history, info_history + xs = jnp.arange(num_steps), keys + final_state, history = scan_fn(one_step, initial_state, xs) + + return final_state, history -def streaming_average_update( - current_value, previous_weight_and_average, weight=1.0, zero_prevention=0.0 +def store_only_expectation_values( + sampling_algorithm, + state_transform=lambda x: x, + incremental_value_transform=lambda x: x, + burn_in=0, +): + """Takes a sampling algorithm and constructs from it a new sampling algorithm object. The new sampling algorithm has the same + kernel but only stores the streaming expectation values of some observables, not the full states; to save memory. + + It saves incremental_value_transform(E[state_transform(x)]) at each step i, where expectation is computed with samples up to i-th sample. + + Example: + + .. code:: + + init_key, state_key, run_key = jax.random.split(jax.random.PRNGKey(0),3) + model = StandardNormal(2) + initial_position = model.sample_init(init_key) + initial_state = blackjax.mcmc.mclmc.init( + position=initial_position, logdensity_fn=model.logdensity_fn, rng_key=state_key + ) + integrator_type = "mclachlan" + L = 1.0 + step_size = 0.1 + num_steps = 4 + + integrator = map_integrator_type_to_integrator['mclmc'][integrator_type] + state_transform = lambda state: state.position + memory_efficient_sampling_alg, transform = store_only_expectation_values( + sampling_algorithm=sampling_alg, + state_transform=state_transform) + + initial_state = memory_efficient_sampling_alg.init(initial_state) + + final_state, trace_at_every_step = run_inference_algorithm( + + rng_key=run_key, + initial_state=initial_state, + inference_algorithm=memory_efficient_sampling_alg, + num_steps=num_steps, + transform=transform, + progress_bar=True, + ) + """ + + def init_fn(state): + averaging_state = (0.0, state_transform(state)) + return (state, averaging_state) + + def update_fn(rng_key, state_and_incremental_val): + state, averaging_state = state_and_incremental_val + state, info = sampling_algorithm.step( + rng_key, state + ) # update the state with the sampling algorithm + averaging_state = incremental_value_update( + state_transform(state), + averaging_state, + weight=( + averaging_state[0] >= burn_in + ), # If we want to eliminate some number of steps as a burn-in + zero_prevention=1e-10 * (burn_in > 0), + ) + # update the expectation value with the running average + return (state, averaging_state), info + + def transform(state_and_incremental_val, info): + (state, (_, incremental_value)) = state_and_incremental_val + return incremental_value_transform(incremental_value), info + + return SamplingAlgorithm(init_fn, update_fn), transform + + +def incremental_value_update( + expectation, incremental_val, weight=1.0, zero_prevention=0.0 ): """Compute the streaming average of a function O(x) using a weight. Parameters: ---------- - current_value - the current value of the function that we want to take average of - previous_weight_and_average - tuple of (previous_weight, previous_average) where previous_weight is the - sum of weights and average is the current estimated average + expectation + the value of the expectation at the current timestep + incremental_val + tuple of (total, average) where total is the sum of weights and average is the current average weight weight of the current state zero_prevention small value to prevent division by zero Returns: ---------- - new total weight and streaming average + new streaming average """ - previous_weight, previous_average = previous_weight_and_average - current_weight = previous_weight + weight - current_average = jax.tree.map( - lambda x, avg: (previous_weight * avg + weight * x) - / (current_weight + zero_prevention), - current_value, - previous_average, + + total, average = incremental_val + average = tree_map( + lambda exp, av: (total * av + weight * exp) + / (total + weight + zero_prevention), + expectation, + average, ) - return current_weight, current_average + total += weight + return total, average diff --git a/tests/adaptation/test_adaptation.py b/tests/adaptation/test_adaptation.py index 68751bee8..4b34511be 100644 --- a/tests/adaptation/test_adaptation.py +++ b/tests/adaptation/test_adaptation.py @@ -90,7 +90,7 @@ def test_chees_adaptation(adaptation_filters): algorithm = blackjax.dynamic_hmc(logprob_fn, **parameters) chain_keys = jax.random.split(inference_key, num_chains) - _, _, infos = jax.vmap( + _, (_, infos) = jax.vmap( lambda key, state: run_inference_algorithm( rng_key=key, initial_state=state, diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index 18a07625b..c399929da 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -135,12 +135,12 @@ def run_mclmc( sqrt_diag_cov=blackjax_mclmc_sampler_params.sqrt_diag_cov, ) - _, samples, _ = run_inference_algorithm( + _, samples = run_inference_algorithm( rng_key=run_key, initial_state=blackjax_state_after_tuning, inference_algorithm=sampling_alg, num_steps=num_steps, - transform=lambda x: x.position, + transform=lambda state, info: state.position, ) return samples @@ -197,7 +197,7 @@ def check_attrs(attribute, keyset): for i, attribute in enumerate(["state", "info", "adaptation_state"]): check_attrs(attribute, keysets[i]) - _, states, _ = run_inference_algorithm( + _, (states, _) = run_inference_algorithm( rng_key=inference_key, initial_state=state, inference_algorithm=inference_algorithm, @@ -223,15 +223,16 @@ def test_mala(self): mala = blackjax.mala(logposterior_fn, 1e-5) state = mala.init({"coefs": 1.0, "log_scale": 1.0}) - _, states, _ = run_inference_algorithm( + _, states = run_inference_algorithm( rng_key=inference_key, initial_state=state, inference_algorithm=mala, + transform=lambda state, info: state.position, num_steps=10_000, ) - coefs_samples = states.position["coefs"][3000:] - scale_samples = np.exp(states.position["log_scale"][3000:]) + coefs_samples = states["coefs"][3000:] + scale_samples = np.exp(states["log_scale"][3000:]) np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-1) np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-1) @@ -375,15 +376,16 @@ def test_pathfinder_adaptation( ) inference_algorithm = algorithm(logposterior_fn, **parameters) - _, states, _ = run_inference_algorithm( + _, states = run_inference_algorithm( rng_key=inference_key, initial_state=state, inference_algorithm=inference_algorithm, num_steps=num_sampling_steps, + transform=lambda state, info: state.position, ) - coefs_samples = states.position["coefs"] - scale_samples = np.exp(states.position["log_scale"]) + coefs_samples = states["coefs"] + scale_samples = np.exp(states["log_scale"]) np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-1) np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-1) @@ -418,17 +420,18 @@ def test_meads(self): inference_algorithm = blackjax.ghmc(logposterior_fn, **parameters) chain_keys = jax.random.split(inference_key, num_chains) - _, states, _ = jax.vmap( + _, states = jax.vmap( lambda key, state: run_inference_algorithm( rng_key=key, initial_state=state, inference_algorithm=inference_algorithm, + transform=lambda state, info: state.position, num_steps=100, ) )(chain_keys, last_states) - coefs_samples = states.position["coefs"] - scale_samples = np.exp(states.position["log_scale"]) + coefs_samples = states["coefs"] + scale_samples = np.exp(states["log_scale"]) np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-1) np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-1) @@ -465,17 +468,18 @@ def test_chees(self, jitter_generator): inference_algorithm = blackjax.dynamic_hmc(logposterior_fn, **parameters) chain_keys = jax.random.split(inference_key, num_chains) - _, states, _ = jax.vmap( + _, states = jax.vmap( lambda key, state: run_inference_algorithm( rng_key=key, initial_state=state, inference_algorithm=inference_algorithm, + transform=lambda state, info: state.position, num_steps=100, ) )(chain_keys, last_states) - coefs_samples = states.position["coefs"] - scale_samples = np.exp(states.position["log_scale"]) + coefs_samples = states["coefs"] + scale_samples = np.exp(states["log_scale"]) np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-1) np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-1) @@ -494,15 +498,16 @@ def test_barker(self): barker = blackjax.barker_proposal(logposterior_fn, 1e-1) state = barker.init({"coefs": 1.0, "log_scale": 1.0}) - _, states, _ = run_inference_algorithm( + _, states = run_inference_algorithm( rng_key=inference_key, initial_state=state, inference_algorithm=barker, + transform=lambda state, info: state.position, num_steps=10_000, ) - coefs_samples = states.position["coefs"][3000:] - scale_samples = np.exp(states.position["log_scale"][3000:]) + coefs_samples = states["coefs"][3000:] + scale_samples = np.exp(states["log_scale"][3000:]) np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-2) np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-2) @@ -679,19 +684,20 @@ def test_latent_gaussian(self): initial_state = inference_algorithm.init(jnp.zeros((1,))) - _, states, _ = self.variant( + _, states = self.variant( functools.partial( run_inference_algorithm, inference_algorithm=inference_algorithm, + transform=lambda state, info: state.position, num_steps=self.sampling_steps, ), )(rng_key=self.key, initial_state=initial_state) np.testing.assert_allclose( - np.var(states.position[self.burnin :]), 1 / (1 + 0.5), rtol=1e-2, atol=1e-2 + np.var(states[self.burnin :]), 1 / (1 + 0.5), rtol=1e-2, atol=1e-2 ) np.testing.assert_allclose( - np.mean(states.position[self.burnin :]), 2 / 3, rtol=1e-2, atol=1e-2 + np.mean(states[self.burnin :]), 2 / 3, rtol=1e-2, atol=1e-2 ) @@ -724,7 +730,7 @@ def univariate_normal_test_case( **kwargs, ): inference_key, orbit_key = jax.random.split(rng_key) - _, states, _ = self.variant( + _, (states, info) = self.variant( functools.partial( run_inference_algorithm, inference_algorithm=inference_algorithm, @@ -855,7 +861,7 @@ def postprocess_samples(states, key): 20_000, burnin, postprocess_samples, - transform=lambda x: (x.positions, x.weights), + transform=lambda state, info: ((state.positions, state.weights), info), ) @chex.all_variants(with_pmap=False) @@ -997,14 +1003,15 @@ def test_mcse(self, algorithm, parameters, is_mass_matrix_diagonal): functools.partial( run_inference_algorithm, inference_algorithm=inference_algorithm, + transform=lambda state, info: state.position, num_steps=2_000, ) ) - _, states, _ = inference_loop_multiple_chains( + _, states = inference_loop_multiple_chains( rng_key=multi_chain_sample_key, initial_state=initial_states ) - posterior_samples = states.position[:, -1000:] + posterior_samples = states[:, -1000:] posterior_delta = posterior_samples - true_loc posterior_variance = posterior_delta**2.0 posterior_correlation = jnp.prod(posterior_delta, axis=-1, keepdims=True) / ( diff --git a/tests/test_benchmarks.py b/tests/test_benchmarks.py index c2295e7e2..2d108a48d 100644 --- a/tests/test_benchmarks.py +++ b/tests/test_benchmarks.py @@ -48,7 +48,7 @@ def run_regression(algorithm, **parameters): ) inference_algorithm = algorithm(logdensity_fn, **parameters) - _, states, _ = run_inference_algorithm( + _, (states, _) = run_inference_algorithm( rng_key=inference_key, initial_state=state, inference_algorithm=inference_algorithm, diff --git a/tests/test_util.py b/tests/test_util.py index 1f03498dd..78198f013 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -4,7 +4,7 @@ from absl.testing import absltest, parameterized import blackjax -from blackjax.util import run_inference_algorithm +from blackjax.util import run_inference_algorithm, store_only_expectation_values class RunInferenceAlgorithmTest(chex.TestCase): @@ -30,7 +30,7 @@ def check_compatible(self, initial_state, progress_bar): inference_algorithm=self.algorithm, num_steps=self.num_steps, progress_bar=progress_bar, - transform=lambda x: x.position, + transform=lambda state, info: state.position, ) def test_streaming(self): @@ -41,37 +41,49 @@ def logdensity_fn(x): 10, ) - init_key, run_key = jax.random.split(self.key, 2) - + init_key, state_key, run_key = jax.random.split(self.key, 3) initial_state = blackjax.mcmc.mclmc.init( - position=initial_position, logdensity_fn=logdensity_fn, rng_key=init_key + position=initial_position, logdensity_fn=logdensity_fn, rng_key=state_key + ) + L = 1.0 + step_size = 0.1 + num_steps = 4 + + sampling_alg = blackjax.mclmc( + logdensity_fn, + L=L, + step_size=step_size, ) - alg = blackjax.mclmc(logdensity_fn=logdensity_fn, L=0.5, step_size=0.1) + state_transform = lambda x: x.position - _, states, info = run_inference_algorithm( + _, samples = run_inference_algorithm( rng_key=run_key, initial_state=initial_state, - inference_algorithm=alg, - num_steps=50, - progress_bar=False, - expectation=lambda x: x, - transform=lambda x: x.position, - return_state_history=True, + inference_algorithm=sampling_alg, + num_steps=num_steps, + transform=lambda state, info: state_transform(state), + progress_bar=True, + ) + + print("average of steps (slow way):", samples.mean(axis=0)) + + memory_efficient_sampling_alg, transform = store_only_expectation_values( + sampling_algorithm=sampling_alg, state_transform=state_transform ) - average, _ = run_inference_algorithm( + initial_state = memory_efficient_sampling_alg.init(initial_state) + + final_state, trace_at_every_step = run_inference_algorithm( rng_key=run_key, initial_state=initial_state, - inference_algorithm=alg, - num_steps=50, - progress_bar=False, - expectation=lambda x: x, - transform=lambda x: x.position, - return_state_history=False, + inference_algorithm=memory_efficient_sampling_alg, + num_steps=num_steps, + transform=transform, + progress_bar=True, ) - assert jnp.allclose(states.mean(axis=0), average) + assert jnp.allclose(trace_at_every_step[0][-1], samples.mean(axis=0)) @parameterized.parameters([True, False]) def test_compatible_with_initial_pos(self, progress_bar): @@ -81,7 +93,7 @@ def test_compatible_with_initial_pos(self, progress_bar): inference_algorithm=self.algorithm, num_steps=self.num_steps, progress_bar=progress_bar, - transform=lambda x: x.position, + transform=lambda state, info: state.position, ) @parameterized.parameters([True, False]) From 834f55d5fa6d5f76c78d31f3ac2c90b3fe2d4e25 Mon Sep 17 00:00:00 2001 From: Gilad Turok <36947659+gil2rok@users.noreply.github.com> Date: Tue, 13 Aug 2024 02:26:05 -0400 Subject: [PATCH 08/13] Harmonize Quickstart example (#717) --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index a8d847cf9..9590b4cd6 100644 --- a/README.md +++ b/README.md @@ -75,9 +75,10 @@ state = nuts.init(initial_position) # Iterate rng_key = jax.random.key(0) +step = jax.jit(nuts.step) for step in range(100): nuts_key = jax.random.fold_in(rng_key, step) - state, _ = nuts.step(nuts_key, state) + state, _ = step(nuts_key, state) ``` See [the documentation](https://blackjax-devs.github.io/blackjax/index.html) for more examples of how to use the library: how to write inference loops for one or several chains, how to use the Stan warmup, etc. From 4a11236930a49cf06ebb41e530002a90c6ecec21 Mon Sep 17 00:00:00 2001 From: Gilad Turok <36947659+gil2rok@users.noreply.github.com> Date: Tue, 13 Aug 2024 16:04:20 -0400 Subject: [PATCH 09/13] Update README.md (#719) --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 9590b4cd6..06d5b46cf 100644 --- a/README.md +++ b/README.md @@ -76,8 +76,8 @@ state = nuts.init(initial_position) # Iterate rng_key = jax.random.key(0) step = jax.jit(nuts.step) -for step in range(100): - nuts_key = jax.random.fold_in(rng_key, step) +for i in range(100): + nuts_key = jax.random.fold_in(rng_key, i) state, _ = step(nuts_key, state) ``` From 072cc81a67154c3bc4601b75ecec66da25da0899 Mon Sep 17 00:00:00 2001 From: Reuben Date: Sat, 24 Aug 2024 19:21:05 -0400 Subject: [PATCH 10/13] Bug fix (#724) * bug fix; first part * bug fix; first part * further debug * remove print statements --- blackjax/adaptation/mclmc_adaptation.py | 7 ++++--- blackjax/mcmc/integrators.py | 1 + blackjax/util.py | 9 +++++++-- 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/blackjax/adaptation/mclmc_adaptation.py b/blackjax/adaptation/mclmc_adaptation.py index 7645a890b..3365526b3 100644 --- a/blackjax/adaptation/mclmc_adaptation.py +++ b/blackjax/adaptation/mclmc_adaptation.py @@ -154,6 +154,7 @@ def predictor(previous_state, params, adaptive_state, rng_key): L=params.L, step_size=params.step_size, ) + # step updating success, state, step_size_max, energy_change = handle_nans( previous_state, @@ -203,7 +204,6 @@ def step(iteration_state, weight_and_key): expectation=jnp.array([x, jnp.square(x)]), incremental_val=streaming_avg, weight=(1 - mask) * success * params.step_size, - zero_prevention=mask, ) return (state, params, adaptive_state, streaming_avg), None @@ -243,7 +243,7 @@ def L_step_size_adaptation(state, params, num_steps, rng_key): L = params.L # determine L sqrt_diag_cov = params.sqrt_diag_cov - if num_steps2 != 0.0: + if num_steps2 > 1: x_average, x_squared_average = average[0], average[1] variances = x_squared_average - jnp.square(x_average) L = jnp.sqrt(jnp.sum(variances)) @@ -304,7 +304,8 @@ def handle_nans(previous_state, next_state, step_size, step_size_max, kinetic_ch reduced_step_size = 0.8 p, unravel_fn = ravel_pytree(next_state.position) - nonans = jnp.all(jnp.isfinite(p)) + q, unravel_fn = ravel_pytree(next_state.momentum) + nonans = jnp.logical_and(jnp.all(jnp.isfinite(p)), jnp.all(jnp.isfinite(q))) state, step_size, kinetic_change = jax.tree_util.tree_map( lambda new, old: jax.lax.select(nonans, jnp.nan_to_num(new), old), (next_state, step_size_max, kinetic_change), diff --git a/blackjax/mcmc/integrators.py b/blackjax/mcmc/integrators.py index 1d4b95a09..e9d19e3dc 100644 --- a/blackjax/mcmc/integrators.py +++ b/blackjax/mcmc/integrators.py @@ -435,6 +435,7 @@ def stochastic_integrator(init_state, step_size, L_proposal, rng_key): ) # one step of the deterministic dynamics state, info = integrator(state, step_size) + # partial refreshment state = state._replace( momentum=partially_refresh_momentum( diff --git a/blackjax/util.py b/blackjax/util.py index b6c5367b5..8cdcd45ee 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -281,6 +281,10 @@ def transform(state_and_incremental_val, info): return SamplingAlgorithm(init_fn, update_fn), transform +def safediv(x, y): + return jnp.where(x == 0.0, 0.0, x / y) + + def incremental_value_update( expectation, incremental_val, weight=1.0, zero_prevention=0.0 ): @@ -302,8 +306,9 @@ def incremental_value_update( total, average = incremental_val average = tree_map( - lambda exp, av: (total * av + weight * exp) - / (total + weight + zero_prevention), + lambda exp, av: safediv( + total * av + weight * exp, (total + weight + zero_prevention) + ), expectation, average, ) From b02b60b16eb9078967586e1a1e613382d4be8565 Mon Sep 17 00:00:00 2001 From: Carlos Iguaran Date: Mon, 26 Aug 2024 11:26:51 -0300 Subject: [PATCH 11/13] Waste Free SMC available for adaptive tempered and tempered SMC. (#721) * extracting taking last * test passing * layering * example * more * Adding another example * tests in place * rolling back changes * Adding test for num_mcmc_steps * format * better test coverage * linter * Flake8 * black * Update blackjax/smc/waste_free.py Co-authored-by: Junpeng Lao * fixing linter --------- Co-authored-by: Junpeng Lao --- blackjax/smc/adaptive_tempered.py | 4 + blackjax/smc/tempered.py | 60 ++++++++--- blackjax/smc/waste_free.py | 70 +++++++++++++ tests/smc/test_smc.py | 94 ++++++++--------- tests/smc/test_waste_free_smc.py | 163 ++++++++++++++++++++++++++++++ 5 files changed, 323 insertions(+), 68 deletions(-) create mode 100644 blackjax/smc/waste_free.py create mode 100644 tests/smc/test_waste_free_smc.py diff --git a/blackjax/smc/adaptive_tempered.py b/blackjax/smc/adaptive_tempered.py index 10fb194fa..9e773e9b6 100644 --- a/blackjax/smc/adaptive_tempered.py +++ b/blackjax/smc/adaptive_tempered.py @@ -34,6 +34,7 @@ def build_kernel( resampling_fn: Callable, target_ess: float, root_solver: Callable = solver.dichotomy, + **extra_parameters, ) -> Callable: r"""Build a Tempered SMC step using an adaptive schedule. @@ -88,6 +89,7 @@ def compute_delta(state: tempered.TemperedSMCState) -> float: mcmc_step_fn, mcmc_init_fn, resampling_fn, + **extra_parameters, ) def kernel( @@ -116,6 +118,7 @@ def as_top_level_api( target_ess: float, root_solver: Callable = solver.dichotomy, num_mcmc_steps: int = 10, + **extra_parameters, ) -> SamplingAlgorithm: """Implements the (basic) user interface for the Adaptive Tempered SMC kernel. @@ -155,6 +158,7 @@ def as_top_level_api( resampling_fn, target_ess, root_solver, + **extra_parameters, ) def init_fn(position: ArrayLikeTree, rng_key=None): diff --git a/blackjax/smc/tempered.py b/blackjax/smc/tempered.py index 43b83d034..19de8afb7 100644 --- a/blackjax/smc/tempered.py +++ b/blackjax/smc/tempered.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from functools import partial -from typing import Callable, NamedTuple +from typing import Callable, NamedTuple, Optional import jax import jax.numpy as jnp @@ -48,12 +48,42 @@ def init(particles: ArrayLikeTree): return TemperedSMCState(particles, weights, 0.0) +def update_and_take_last( + mcmc_init_fn, + tempered_logposterior_fn, + shared_mcmc_step_fn, + num_mcmc_steps, + n_particles, +): + """ + Given N particles, runs num_mcmc_steps of a kernel starting at each particle, and + returns the last values, waisting the previous num_mcmc_steps-1 + samples per chain. + """ + + def mcmc_kernel(rng_key, position, step_parameters): + state = mcmc_init_fn(position, tempered_logposterior_fn) + + def body_fn(state, rng_key): + new_state, info = shared_mcmc_step_fn( + rng_key, state, tempered_logposterior_fn, **step_parameters + ) + return new_state, info + + keys = jax.random.split(rng_key, num_mcmc_steps) + last_state, info = jax.lax.scan(body_fn, state, keys) + return last_state.position, info + + return jax.vmap(mcmc_kernel), n_particles + + def build_kernel( logprior_fn: Callable, loglikelihood_fn: Callable, mcmc_step_fn: Callable, mcmc_init_fn: Callable, resampling_fn: Callable, + update_strategy: Callable = update_and_take_last, ) -> Callable: """Build the base Tempered SMC kernel. @@ -141,26 +171,23 @@ def tempered_logposterior_fn(position: ArrayLikeTree) -> float: shared_mcmc_step_fn = partial(mcmc_step_fn, **shared_mcmc_parameters) - def mcmc_kernel(rng_key, position, step_parameters): - state = mcmc_init_fn(position, tempered_logposterior_fn) - - def body_fn(state, rng_key): - new_state, info = shared_mcmc_step_fn( - rng_key, state, tempered_logposterior_fn, **step_parameters - ) - return new_state, info - - keys = jax.random.split(rng_key, num_mcmc_steps) - last_state, info = jax.lax.scan(body_fn, state, keys) - return last_state.position, info + update_fn, num_resampled = update_strategy( + mcmc_init_fn, + tempered_logposterior_fn, + shared_mcmc_step_fn, + n_particles=state.weights.shape[0], + num_mcmc_steps=num_mcmc_steps, + ) smc_state, info = smc.base.step( rng_key, SMCState(state.particles, state.weights, unshared_mcmc_parameters), - jax.vmap(mcmc_kernel), + update_fn, jax.vmap(log_weights_fn), resampling_fn, + num_resampled, ) + tempered_state = TemperedSMCState( smc_state.particles, smc_state.weights, state.lmbda + delta ) @@ -177,7 +204,8 @@ def as_top_level_api( mcmc_init_fn: Callable, mcmc_parameters: dict, resampling_fn: Callable, - num_mcmc_steps: int = 10, + num_mcmc_steps: Optional[int] = 10, + update_strategy=update_and_take_last, ) -> SamplingAlgorithm: """Implements the (basic) user interface for the Adaptive Tempered SMC kernel. @@ -204,12 +232,14 @@ def as_top_level_api( A ``SamplingAlgorithm``. """ + kernel = build_kernel( logprior_fn, loglikelihood_fn, mcmc_step_fn, mcmc_init_fn, resampling_fn, + update_strategy, ) def init_fn(position: ArrayLikeTree, rng_key=None): diff --git a/blackjax/smc/waste_free.py b/blackjax/smc/waste_free.py new file mode 100644 index 000000000..629cdd725 --- /dev/null +++ b/blackjax/smc/waste_free.py @@ -0,0 +1,70 @@ +import functools + +import jax +import jax.lax +import jax.numpy as jnp + + +def update_waste_free( + mcmc_init_fn, + logposterior_fn, + mcmc_step_fn, + n_particles: int, + p: int, + num_resampled, + num_mcmc_steps=None, +): + """ + Given M particles, mutates them using p-1 steps. Returns M*P-1 particles, + consistent of the initial plus all the intermediate steps, thus implementing a + waste-free update function + See Algorithm 2: https://arxiv.org/abs/2011.02328 + """ + if num_mcmc_steps is not None: + raise ValueError( + "Can't use waste free SMC with a num_mcmc_steps parameter, set num_mcmc_steps = None" + ) + + num_mcmc_steps = p - 1 + + def mcmc_kernel(rng_key, position, step_parameters): + state = mcmc_init_fn(position, logposterior_fn) + + def body_fn(state, rng_key): + new_state, info = mcmc_step_fn( + rng_key, state, logposterior_fn, **step_parameters + ) + return new_state, (new_state, info) + + _, (states, infos) = jax.lax.scan( + body_fn, state, jax.random.split(rng_key, num_mcmc_steps) + ) + return states, infos + + def update(rng_key, position, step_parameters): + """ + Given the initial particles, runs a chain starting at each. + The combines the initial particles with all the particles generated + at each step of each chain. + """ + states, infos = jax.vmap(mcmc_kernel)(rng_key, position, step_parameters) + + # step particles is num_resmapled, num_mcmc_steps, dimension_of_variable + # want to transformed into num_resampled * num_mcmc_steps, dimension of variable + def reshape_step_particles(x): + _num_resampled, num_mcmc_steps, *dimension_of_variable = x.shape + return x.reshape((_num_resampled * num_mcmc_steps, *dimension_of_variable)) + + step_particles = jax.tree.map(reshape_step_particles, states.position) + new_particles = jax.tree.map( + lambda x, y: jnp.concatenate([x, y]), position, step_particles + ) + return new_particles, infos + + return update, num_resampled + + +def waste_free_smc(n_particles, p): + if not n_particles % p == 0: + raise ValueError("p must be a divider of n_particles ") + return functools.partial(update_waste_free, num_resampled=int(n_particles / p), p=p) diff --git a/tests/smc/test_smc.py b/tests/smc/test_smc.py index 6366182a8..b0e86e0b0 100644 --- a/tests/smc/test_smc.py +++ b/tests/smc/test_smc.py @@ -1,4 +1,6 @@ """Test the generic SMC sampler""" +import functools + import chex import jax import jax.numpy as jnp @@ -9,6 +11,8 @@ import blackjax import blackjax.smc.resampling as resampling from blackjax.smc.base import extend_params, init, step +from blackjax.smc.tempered import update_and_take_last +from blackjax.smc.waste_free import update_waste_free def logdensity_fn(position): @@ -29,82 +33,66 @@ def setUp(self): @chex.variants(with_jit=True) def test_smc(self): num_mcmc_steps = 20 - num_particles = 1000 - - def update_fn(rng_key, position, update_params): - hmc = blackjax.hmc(logdensity_fn, **update_params) - state = hmc.init(position) - - def body_fn(state, rng_key): - new_state, info = hmc.step(rng_key, state) - return new_state, info - - keys = jax.random.split(rng_key, num_mcmc_steps) - last_state, info = jax.lax.scan(body_fn, state, keys) - return last_state.position, info - - init_key, sample_key = jax.random.split(self.key) + num_particles = 5000 - # Initialize the state of the SMC sampler - init_particles = 0.25 + jax.random.normal(init_key, shape=(num_particles,)) same_for_all_params = dict( step_size=1e-2, inverse_mass_matrix=jnp.eye(1), num_integration_steps=50 ) + hmc_kernel = functools.partial( + blackjax.hmc.build_kernel(), **same_for_all_params + ) + hmc_init = blackjax.hmc.init - state = init( - init_particles, - same_for_all_params, + update_fn, _ = update_and_take_last( + hmc_init, logdensity_fn, hmc_kernel, num_mcmc_steps, num_particles ) + init_key, sample_key = jax.random.split(self.key) + # Initialize the state of the SMC sampler + init_particles = 0.25 + jax.random.normal(init_key, shape=(num_particles,)) + state = init(init_particles, {}) # Run the SMC sampler once new_state, info = self.variant(step, static_argnums=(2, 3, 4))( sample_key, state, - jax.vmap(update_fn, in_axes=(0, 0, None)), + update_fn, jax.vmap(logdensity_fn), resampling.systematic, ) + assert new_state.particles.shape == (num_particles,) mean, std = _weighted_avg_and_std(new_state.particles, state.weights) - np.testing.assert_allclose(0.0, mean, atol=1e-1) - np.testing.assert_allclose(1.0, std, atol=1e-1) + np.testing.assert_allclose(mean, 0.0, atol=1e-1) + np.testing.assert_allclose(std, 1.0, atol=1e-1) @chex.variants(with_jit=True) def test_smc_waste_free(self): - num_mcmc_steps = 10 + p = 500 num_particles = 1000 - num_resampled = num_particles // num_mcmc_steps - - def waste_free_update_fn(keys, particles, update_params): - def one_particle_fn(rng_key, position, particle_update_params): - hmc = blackjax.hmc(logdensity_fn, **particle_update_params) - state = hmc.init(position) - - def body_fn(state, rng_key): - new_state, info = hmc.step(rng_key, state) - return new_state, (state, info) - - keys = jax.random.split(rng_key, num_mcmc_steps) - _, (states, info) = jax.lax.scan(body_fn, state, keys) - return states.position, info - - particles, info = jax.vmap(one_particle_fn, in_axes=(0, 0, None))( - keys, particles, update_params - ) - particles = particles.reshape((num_particles,)) - return particles, info - + num_resampled = num_particles // p init_key, sample_key = jax.random.split(self.key) # Initialize the state of the SMC sampler init_particles = 0.25 + jax.random.normal(init_key, shape=(num_particles,)) state = init( init_particles, - dict( - step_size=1e-2, - inverse_mass_matrix=jnp.eye(1), - num_integration_steps=100, - ), + {}, + ) + same_for_all_params = dict( + step_size=1e-2, inverse_mass_matrix=jnp.eye(1), num_integration_steps=50 + ) + hmc_kernel = functools.partial( + blackjax.hmc.build_kernel(), **same_for_all_params + ) + hmc_init = blackjax.hmc.init + + waste_free_update_fn, _ = update_waste_free( + hmc_init, + logdensity_fn, + hmc_kernel, + num_particles, + p=p, + num_resampled=num_resampled, ) # Run the SMC sampler once @@ -116,10 +104,10 @@ def body_fn(state, rng_key): resampling.systematic, num_resampled, ) - + assert new_state.particles.shape == (num_particles,) mean, std = _weighted_avg_and_std(new_state.particles, state.weights) - np.testing.assert_allclose(0.0, mean, atol=1e-1) - np.testing.assert_allclose(1.0, std, atol=1e-1) + np.testing.assert_allclose(mean, 0.0, atol=1e-1) + np.testing.assert_allclose(std, 1.0, atol=1e-1) class ExtendParamsTest(chex.TestCase): diff --git a/tests/smc/test_waste_free_smc.py b/tests/smc/test_waste_free_smc.py new file mode 100644 index 000000000..a5eeef135 --- /dev/null +++ b/tests/smc/test_waste_free_smc.py @@ -0,0 +1,163 @@ +"""Test the tempered SMC steps and routine""" + +import functools + +import chex +import jax +import jax.numpy as jnp +import numpy as np +import pytest +from absl.testing import absltest + +import blackjax +import blackjax.smc.resampling as resampling +from blackjax import adaptive_tempered_smc, tempered_smc +from blackjax.smc import extend_params +from blackjax.smc.waste_free import update_waste_free, waste_free_smc +from tests.smc import SMCLinearRegressionTestCase +from tests.smc.test_tempered_smc import inference_loop + + +class WasteFreeSMCTest(SMCLinearRegressionTestCase): + """Test posterior mean estimate.""" + + def setUp(self): + super().setUp() + self.key = jax.random.key(42) + + @chex.variants(with_jit=True) + def test_fixed_schedule_tempered_smc(self): + ( + init_particles, + logprior_fn, + loglikelihood_fn, + ) = self.particles_prior_loglikelihood() + + num_tempering_steps = 10 + + lambda_schedule = np.logspace(-5, 0, num_tempering_steps) + hmc_init = blackjax.hmc.init + hmc_kernel = blackjax.hmc.build_kernel() + hmc_parameters = extend_params( + { + "step_size": 10e-2, + "inverse_mass_matrix": jnp.eye(2), + "num_integration_steps": 50, + }, + ) + + tempering = tempered_smc( + logprior_fn, + loglikelihood_fn, + hmc_kernel, + hmc_init, + hmc_parameters, + resampling.systematic, + None, + waste_free_smc(100, 4), + ) + init_state = tempering.init(init_particles) + smc_kernel = self.variant(tempering.step) + + def body_fn(carry, lmbda): + i, state = carry + subkey = jax.random.fold_in(self.key, i) + new_state, info = smc_kernel(subkey, state, lmbda) + return (i + 1, new_state), (new_state, info) + + (_, result), _ = jax.lax.scan(body_fn, (0, init_state), lambda_schedule) + self.assert_linear_regression_test_case(result) + + @chex.variants(with_jit=True) + def test_adaptive_tempered_smc(self): + ( + init_particles, + logprior_fn, + loglikelihood_fn, + ) = self.particles_prior_loglikelihood() + + hmc_init = blackjax.hmc.init + hmc_kernel = blackjax.hmc.build_kernel() + hmc_parameters = extend_params( + { + "step_size": 10e-2, + "inverse_mass_matrix": jnp.eye(2), + "num_integration_steps": 50, + }, + ) + + tempering = adaptive_tempered_smc( + logprior_fn, + loglikelihood_fn, + hmc_kernel, + hmc_init, + hmc_parameters, + resampling.systematic, + 0.5, + update_strategy=waste_free_smc(100, 4), + num_mcmc_steps=None, + ) + init_state = tempering.init(init_particles) + + n_iter, result, log_likelihood = self.variant( + functools.partial(inference_loop, tempering.step) + )(self.key, init_state) + + self.assert_linear_regression_test_case(result) + + +class Update_waste_free_multivariate_particles(chex.TestCase): + @chex.variants(with_jit=True) + def test_update_waste_free_multivariate_particles(self): + """ + Given resampled multivariate particles, + when updating with waste free, they are joined + by the result of iterating the MCMC chain to + get a bigger set of particles. + """ + resampled_particles = np.ones((50, 3)) + n_particles = 100 + + def normal_logdensity(x): + return jnp.log( + jax.scipy.stats.multivariate_normal.pdf( + x, mean=np.zeros(3), cov=np.diag(np.ones(3)) + ) + ) + + def rmh_proposal_distribution(rng_key, position): + return position + jax.random.normal(rng_key, (3,)) * 25.0 + + kernel = functools.partial( + blackjax.rmh.build_kernel(), transition_generator=rmh_proposal_distribution + ) + init = blackjax.rmh.init + update, _ = waste_free_smc(n_particles, 2)( + init, normal_logdensity, kernel, n_particles + ) + + updated_particles, infos = self.variant(update)( + jax.random.split(jax.random.PRNGKey(10), 50), resampled_particles, {} + ) + + assert updated_particles.shape == (n_particles, 3) + + +def test_waste_free_set_num_mcmc_steps(): + with pytest.raises(ValueError) as exc_info: + update_waste_free( + lambda x: x, lambda x: 1, lambda x: 1, 100, 10, 3, num_mcmc_steps=50 + ) + assert str(exc_info.value).startswith( + "Can't use waste free SMC with a num_mcmc_steps parameter" + ) + + +def test_waste_free_p_non_divier(): + with pytest.raises(ValueError) as exc_info: + waste_free_smc(100, 3) + assert str(exc_info.value).startswith("p must be a divider") + + +if __name__ == "__main__": + absltest.main() From 8a9b5466866cb75b8e2a61f3ac8b7d46b8cc8dca Mon Sep 17 00:00:00 2001 From: Reuben Date: Mon, 26 Aug 2024 10:27:09 -0400 Subject: [PATCH 12/13] NaN Handling (#727) * bug fix; first part * bug fix; first part * further debug * remove print statements * handle logdensity nans. mask -> 1 - mask. --- blackjax/adaptation/mclmc_adaptation.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/blackjax/adaptation/mclmc_adaptation.py b/blackjax/adaptation/mclmc_adaptation.py index 3365526b3..831586201 100644 --- a/blackjax/adaptation/mclmc_adaptation.py +++ b/blackjax/adaptation/mclmc_adaptation.py @@ -20,7 +20,7 @@ from jax.flatten_util import ravel_pytree from blackjax.diagnostics import effective_sample_size -from blackjax.util import incremental_value_update, pytree_size +from blackjax.util import generate_unit_vector, incremental_value_update, pytree_size class MCLMCAdaptationState(NamedTuple): @@ -147,6 +147,8 @@ def predictor(previous_state, params, adaptive_state, rng_key): time, x_average, step_size_max = adaptive_state + rng_key, nan_key = jax.random.split(rng_key) + # dynamics next_state, info = kernel(params.sqrt_diag_cov)( rng_key=rng_key, @@ -162,6 +164,7 @@ def predictor(previous_state, params, adaptive_state, rng_key): params.step_size, step_size_max, info.energy_change, + nan_key, ) # Warning: var = 0 if there were nans, but we will give it a very small weight @@ -203,7 +206,7 @@ def step(iteration_state, weight_and_key): streaming_avg = incremental_value_update( expectation=jnp.array([x, jnp.square(x)]), incremental_val=streaming_avg, - weight=(1 - mask) * success * params.step_size, + weight=mask * success * params.step_size, ) return (state, params, adaptive_state, streaming_avg), None @@ -233,7 +236,7 @@ def L_step_size_adaptation(state, params, num_steps, rng_key): ) # we use the last num_steps2 to compute the diagonal preconditioner - mask = 1 - jnp.concatenate((jnp.zeros(num_steps1), jnp.ones(num_steps2))) + mask = jnp.concatenate((jnp.zeros(num_steps1), jnp.ones(num_steps2))) # run the steps state, params, _, (_, average) = run_steps( @@ -298,7 +301,9 @@ def step(state, key): return adaptation_L -def handle_nans(previous_state, next_state, step_size, step_size_max, kinetic_change): +def handle_nans( + previous_state, next_state, step_size, step_size_max, kinetic_change, key +): """if there are nans, let's reduce the stepsize, and not update the state. The function returns the old state in this case.""" @@ -311,4 +316,13 @@ def handle_nans(previous_state, next_state, step_size, step_size_max, kinetic_ch (next_state, step_size_max, kinetic_change), (previous_state, step_size * reduced_step_size, 0.0), ) + + state = jax.lax.cond( + jnp.isnan(next_state.logdensity), + lambda: state._replace( + momentum=generate_unit_vector(key, previous_state.position) + ), + lambda: state, + ) + return nonans, state, step_size, kinetic_change From e1d816af3b81384f36359b979e4ecdef855a36da Mon Sep 17 00:00:00 2001 From: Adrien Corenflos Date: Mon, 16 Sep 2024 20:52:27 +0100 Subject: [PATCH 13/13] Implement metric scaling (#733) * Plotting BlackJAX with BlackJAX * Plotting BlackJAX with BlackJAX * Proposed implementation for metric scaling * Add tests and fix some small typing issues raised by pre-commit. * Fix remaining failing tests * pre-commit run * The original implementation was using upper cholesky, I was using lower. * Fixing a bunch of tests * Update blackjax/mcmc/metrics.py Co-authored-by: Junpeng Lao * Update blackjax/mcmc/metrics.py Co-authored-by: Junpeng Lao * Merged comments from Junpeng * Merged comments from Junpeng --------- Co-authored-by: Junpeng Lao --- blackjax/mcmc/ghmc.py | 9 +- blackjax/mcmc/metrics.py | 186 ++++++++++++++++++++---------- blackjax/mcmc/periodic_orbital.py | 2 +- blackjax/types.py | 4 + tests/mcmc/test_metrics.py | 151 ++++++++++++++++++++++-- tests/mcmc/test_trajectory.py | 5 +- tests/mcmc/test_uturn.py | 2 +- 7 files changed, 281 insertions(+), 78 deletions(-) diff --git a/blackjax/mcmc/ghmc.py b/blackjax/mcmc/ghmc.py index a04ce0641..5f8ab89a7 100644 --- a/blackjax/mcmc/ghmc.py +++ b/blackjax/mcmc/ghmc.py @@ -16,6 +16,7 @@ import jax import jax.numpy as jnp +from jax.flatten_util import ravel_pytree import blackjax.mcmc.hmc as hmc import blackjax.mcmc.integrators as integrators @@ -129,8 +130,8 @@ def kernel( """ - flat_inverse_scale = jax.flatten_util.ravel_pytree(momentum_inverse_scale)[0] - momentum_generator, kinetic_energy_fn, _ = metrics.gaussian_euclidean( + flat_inverse_scale = ravel_pytree(momentum_inverse_scale)[0] + momentum_generator, kinetic_energy_fn, *_ = metrics.gaussian_euclidean( flat_inverse_scale**2 ) @@ -248,6 +249,10 @@ def as_top_level_api( A PyTree of the same structure as the target PyTree (position) with the values used for as a step size for each dimension of the target space in the velocity verlet integrator. + momentum_inverse_scale + Pytree with the same structure as the targeted position variable + specifying the per dimension inverse scaling transformation applied + to the persistent momentum variable prior to the integration step. alpha The value defining the persistence of the momentum variable. delta diff --git a/blackjax/mcmc/metrics.py b/blackjax/mcmc/metrics.py index 1368a8441..4e079714b 100644 --- a/blackjax/mcmc/metrics.py +++ b/blackjax/mcmc/metrics.py @@ -30,13 +30,13 @@ """ from typing import Callable, NamedTuple, Optional, Protocol, Union +import jax import jax.numpy as jnp import jax.scipy as jscipy from jax.flatten_util import ravel_pytree -from jax.scipy import stats as sp_stats -from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey -from blackjax.util import generate_gaussian_noise +from blackjax.types import Array, ArrayLikeTree, ArrayTree, Numeric, PRNGKey +from blackjax.util import generate_gaussian_noise, linear_map __all__ = ["default_metric", "gaussian_euclidean", "gaussian_riemannian"] @@ -44,7 +44,7 @@ class KineticEnergy(Protocol): def __call__( self, momentum: ArrayLikeTree, position: Optional[ArrayLikeTree] = None - ) -> float: + ) -> Numeric: ... @@ -60,10 +60,18 @@ def __call__( ... +class Scale(Protocol): + def __call__( + self, position: ArrayLikeTree, element: ArrayLikeTree, inv: ArrayLikeTree + ) -> ArrayLikeTree: + ... + + class Metric(NamedTuple): sample_momentum: Callable[[PRNGKey, ArrayLikeTree], ArrayLikeTree] kinetic_energy: KineticEnergy check_turning: CheckTurning + scale: Scale MetricTypes = Union[Metric, Array, Callable[[ArrayLikeTree], Array]] @@ -128,46 +136,19 @@ def gaussian_euclidean( itself given the values of the momentum along the trajectory. """ - ndim = jnp.ndim(inverse_mass_matrix) # type: ignore[arg-type] - shape = jnp.shape(inverse_mass_matrix)[:1] # type: ignore[arg-type] - - if ndim == 1: # diagonal mass matrix - mass_matrix_sqrt = jnp.sqrt(jnp.reciprocal(inverse_mass_matrix)) - matmul = jnp.multiply - - elif ndim == 2: - # inverse mass matrix can be factored into L*L.T. We want the cholesky - # factor (inverse of L.T) of the mass matrix. - L = jscipy.linalg.cholesky(inverse_mass_matrix, lower=True) - identity = jnp.identity(shape[0]) - mass_matrix_sqrt = jscipy.linalg.solve_triangular( - L, identity, lower=True, trans=True - ) - # Note that mass_matrix_sqrt is a upper triangular matrix here, with - # jscipy.linalg.inv(mass_matrix_sqrt @ mass_matrix_sqrt.T) - # == inverse_mass_matrix - # An alternative is to compute directly the cholesky factor of the inverse mass - # matrix - # mass_matrix_sqrt = jscipy.linalg.cholesky( - # jscipy.linalg.inv(inverse_mass_matrix), lower=True) - # which the result would instead be a lower triangular matrix. - matmul = jnp.matmul - - else: - raise ValueError( - "The mass matrix has the wrong number of dimensions:" - f" expected 1 or 2, got {ndim}." - ) + mass_matrix_sqrt, inv_mass_matrix_sqrt, diag = _format_covariance( + inverse_mass_matrix, is_inv=True + ) def momentum_generator(rng_key: PRNGKey, position: ArrayLikeTree) -> ArrayTree: return generate_gaussian_noise(rng_key, position, sigma=mass_matrix_sqrt) def kinetic_energy( momentum: ArrayLikeTree, position: Optional[ArrayLikeTree] = None - ) -> float: + ) -> Numeric: del position momentum, _ = ravel_pytree(momentum) - velocity = matmul(inverse_mass_matrix, momentum) + velocity = linear_map(inverse_mass_matrix, momentum) kinetic_energy_val = 0.5 * jnp.dot(velocity, momentum) return kinetic_energy_val @@ -196,8 +177,8 @@ def is_turning( m_right, _ = ravel_pytree(momentum_right) m_sum, _ = ravel_pytree(momentum_sum) - velocity_left = matmul(inverse_mass_matrix, m_left) - velocity_right = matmul(inverse_mass_matrix, m_right) + velocity_left = linear_map(inverse_mass_matrix, m_left) + velocity_right = linear_map(inverse_mass_matrix, m_right) # rho = m_sum rho = m_sum - (m_right + m_left) / 2 @@ -205,7 +186,37 @@ def is_turning( turning_at_right = jnp.dot(velocity_right, rho) <= 0 return turning_at_left | turning_at_right - return Metric(momentum_generator, kinetic_energy, is_turning) + def scale( + position: ArrayLikeTree, element: ArrayLikeTree, inv: ArrayLikeTree + ) -> ArrayLikeTree: + """Scale elements by the mass matrix. + + Parameters + ---------- + position + The current position. Not used in this metric. + elements + Elements to scale + invs + Whether to scale the elements by the inverse mass matrix or the mass matrix. + If True, the element is scaled by the inverse square root mass matrix, i.e., elem <- (M^{1/2})^{-1} elem. + Same pytree structure as `elements`. + + Returns + ------- + scaled_elements + The scaled elements. + """ + + ravelled_element, unravel_fn = ravel_pytree(element) + scaled = jax.lax.cond( + inv, + lambda: linear_map(inv_mass_matrix_sqrt, ravelled_element), + lambda: linear_map(mass_matrix_sqrt, ravelled_element), + ) + return unravel_fn(scaled) + + return Metric(momentum_generator, kinetic_energy, is_turning, scale) def gaussian_riemannian( @@ -213,22 +224,13 @@ def gaussian_riemannian( ) -> Metric: def momentum_generator(rng_key: PRNGKey, position: ArrayLikeTree) -> ArrayLikeTree: mass_matrix = mass_matrix_fn(position) - ndim = jnp.ndim(mass_matrix) - if ndim == 1: - mass_matrix_sqrt = jnp.sqrt(mass_matrix) - elif ndim == 2: - mass_matrix_sqrt = jscipy.linalg.cholesky(mass_matrix, lower=True) - else: - raise ValueError( - "The mass matrix has the wrong number of dimensions:" - f" expected 1 or 2, got {jnp.ndim(mass_matrix)}." - ) + mass_matrix_sqrt, *_ = _format_covariance(mass_matrix, is_inv=False) return generate_gaussian_noise(rng_key, position, sigma=mass_matrix_sqrt) def kinetic_energy( momentum: ArrayLikeTree, position: Optional[ArrayLikeTree] = None - ) -> float: + ) -> Numeric: if position is None: raise ValueError( "A Reinmannian kinetic energy function must be called with the " @@ -238,18 +240,11 @@ def kinetic_energy( momentum, _ = ravel_pytree(momentum) mass_matrix = mass_matrix_fn(position) - ndim = jnp.ndim(mass_matrix) - if ndim == 1: - return -jnp.sum(sp_stats.norm.logpdf(momentum, 0.0, jnp.sqrt(mass_matrix))) - elif ndim == 2: - return -sp_stats.multivariate_normal.logpdf( - momentum, jnp.zeros_like(momentum), mass_matrix - ) - else: - raise ValueError( - "The mass matrix has the wrong number of dimensions:" - f" expected 1 or 2, got {jnp.ndim(mass_matrix)}." - ) + sqrt_mass_matrix, inv_sqrt_mass_matrix, diag = _format_covariance( + mass_matrix, is_inv=False + ) + + return _energy(momentum, 0, sqrt_mass_matrix, inv_sqrt_mass_matrix.T, diag) def is_turning( momentum_left: ArrayLikeTree, @@ -283,4 +278,69 @@ def is_turning( # turning_at_right = jnp.dot(velocity_right, rho) <= 0 # return turning_at_left | turning_at_right - return Metric(momentum_generator, kinetic_energy, is_turning) + def scale( + position: ArrayLikeTree, element: ArrayLikeTree, inv: ArrayLikeTree + ) -> ArrayLikeTree: + """Scale elements by the mass matrix. + + Parameters + ---------- + position + The current position. + + Returns + ------- + scaled_elements + The scaled elements. + """ + mass_matrix = mass_matrix_fn(position) + mass_matrix_sqrt, inv_mass_matrix_sqrt, diag = _format_covariance( + mass_matrix, is_inv=False + ) + ravelled_element, unravel_fn = ravel_pytree(element) + scaled = jax.lax.cond( + inv, + lambda: linear_map(inv_mass_matrix_sqrt, ravelled_element), + lambda: linear_map(mass_matrix_sqrt, ravelled_element), + ) + return unravel_fn(scaled) + + return Metric(momentum_generator, kinetic_energy, is_turning, scale) + + +def _format_covariance(cov: Array, is_inv): + ndim = jnp.ndim(cov) + if ndim == 1: + cov_sqrt = jnp.sqrt(cov) + inv_cov_sqrt = 1 / cov_sqrt + diag = lambda x: x + if is_inv: + inv_cov_sqrt, cov_sqrt = cov_sqrt, inv_cov_sqrt + elif ndim == 2: + identity = jnp.identity(cov.shape[0]) + if is_inv: + inv_cov_sqrt = jscipy.linalg.cholesky(cov, lower=True) + cov_sqrt = jscipy.linalg.solve_triangular( + inv_cov_sqrt, identity, lower=True, trans=True + ) + else: + cov_sqrt = jscipy.linalg.cholesky(cov, lower=False).T + inv_cov_sqrt = jscipy.linalg.solve_triangular( + cov_sqrt, identity, lower=True, trans=True + ) + + diag = lambda x: jnp.diag(x) + + else: + raise ValueError( + "The mass matrix has the wrong number of dimensions:" + f" expected 1 or 2, got {jnp.ndim(cov)}." + ) + return cov_sqrt, inv_cov_sqrt, diag + + +def _energy(x, mean, cov_sqrt, inv_cov_sqrt, diag): + d = x.shape[0] + z = linear_map(inv_cov_sqrt, x - mean) + const = jnp.sum(jnp.log(diag(cov_sqrt))) + d / 2 * jnp.log(2 * jnp.pi) + return 0.5 * jnp.sum(z**2) + const diff --git a/blackjax/mcmc/periodic_orbital.py b/blackjax/mcmc/periodic_orbital.py index 61625a0b8..b996205e8 100644 --- a/blackjax/mcmc/periodic_orbital.py +++ b/blackjax/mcmc/periodic_orbital.py @@ -172,7 +172,7 @@ def kernel( """ - momentum_generator, kinetic_energy_fn, _ = metrics.gaussian_euclidean( + momentum_generator, kinetic_energy_fn, *_ = metrics.gaussian_euclidean( inverse_mass_matrix ) bijection_fn = bijection(logdensity_fn, kinetic_energy_fn) diff --git a/blackjax/types.py b/blackjax/types.py index 5a3b59f07..4b23fcd22 100644 --- a/blackjax/types.py +++ b/blackjax/types.py @@ -43,3 +43,7 @@ class WelfordAlgorithmState(NamedTuple): #: JAX PRNGKey PRNGKey = jax.Array + +#: JAX Scalar types +Scalar = Union[float, int] +Numeric = Union[jax.Array, Scalar] diff --git a/tests/mcmc/test_metrics.py b/tests/mcmc/test_metrics.py index f806a375c..0791f3cb1 100644 --- a/tests/mcmc/test_metrics.py +++ b/tests/mcmc/test_metrics.py @@ -8,6 +8,90 @@ from blackjax.mcmc import metrics +class CovarianceFormattingTest(chex.TestCase): + def setUp(self): + super().setUp() + self.key = random.key(0) + self.dtype = "float32" + + @parameterized.named_parameters( + {"testcase_name": "0d", "shape": (), "is_inv": False}, + {"testcase_name": "0d_inv", "shape": (), "is_inv": True}, + {"testcase_name": "3d", "shape": (1, 2, 3), "is_inv": False}, + {"testcase_name": "3d_inv", "shape": (1, 2, 3), "is_inv": True}, + ) + def test_invalid(self, shape, is_inv): + """Test formatting raises error for invalid shapes""" + mass_matrix = jnp.zeros(shape=shape) + with self.assertRaisesRegex( + ValueError, "The mass matrix has the wrong number of dimensions" + ): + metrics._format_covariance(mass_matrix, is_inv) + + @parameterized.named_parameters( + {"testcase_name": "inv", "is_inv": True}, + {"testcase_name": "no_inv", "is_inv": False}, + ) + def test_dim_1(self, is_inv): + """Test formatting for 1D mass matrix""" + mass_matrix = jnp.asarray([1 / 4], dtype=self.dtype) + mass_matrix_sqrt, inv_mass_matrix_sqrt, diag = metrics._format_covariance( + mass_matrix, is_inv + ) + if is_inv: + chex.assert_trees_all_close(inv_mass_matrix_sqrt, mass_matrix**0.5) + chex.assert_trees_all_close(mass_matrix_sqrt, mass_matrix**-0.5) + else: + chex.assert_trees_all_close(mass_matrix_sqrt, mass_matrix**0.5) + chex.assert_trees_all_close(inv_mass_matrix_sqrt, mass_matrix**-0.5) + + chex.assert_trees_all_close(diag(mass_matrix), mass_matrix) + + @parameterized.named_parameters( + {"testcase_name": "inv", "is_inv": True}, + {"testcase_name": "no_inv", "is_inv": False}, + ) + def test_dim_2(self, is_inv): + """Test formatting for 2D mass matrix""" + mass_matrix = jnp.asarray([[2 / 3, 0.5], [0.5, 3 / 4]], dtype=self.dtype) + mass_matrix_sqrt, inv_mass_matrix_sqrt, diag = metrics._format_covariance( + mass_matrix, is_inv + ) + if is_inv: + chex.assert_trees_all_close( + mass_matrix_sqrt @ mass_matrix_sqrt.T, linalg.inv(mass_matrix) + ) + chex.assert_trees_all_close( + inv_mass_matrix_sqrt @ inv_mass_matrix_sqrt.T, mass_matrix + ) + + else: + chex.assert_trees_all_close( + mass_matrix_sqrt @ mass_matrix_sqrt.T, mass_matrix + ) + chex.assert_trees_all_close( + inv_mass_matrix_sqrt @ inv_mass_matrix_sqrt.T, linalg.inv(mass_matrix) + ) + + def test_dim2_inv_and_not_inv_agree(self): + mass_matrix = jnp.asarray([[2 / 3, 0.5], [0.5, 3 / 4]], dtype=self.dtype) + mass_matrix_sqrt, inv_mass_matrix_sqrt, _ = metrics._format_covariance( + mass_matrix, False + ) + mass_matrix_sqrt_inv, inv_mass_matrix_sqrt_inv, _ = metrics._format_covariance( + linalg.inv(mass_matrix), True + ) + + chex.assert_trees_all_close( + mass_matrix_sqrt @ mass_matrix_sqrt.T, + mass_matrix_sqrt_inv @ mass_matrix_sqrt_inv.T, + ) + chex.assert_trees_all_close( + inv_mass_matrix_sqrt @ inv_mass_matrix_sqrt.T, + inv_mass_matrix_sqrt_inv @ inv_mass_matrix_sqrt_inv.T, + ) + + class GaussianEuclideanMetricsTest(chex.TestCase): def setUp(self): super().setUp() @@ -30,7 +114,9 @@ def test_gaussian_euclidean_ndim_invalid(self, shape): def test_gaussian_euclidean_dim_1(self): """Test Gaussian Euclidean Function with ndim 1""" inverse_mass_matrix = jnp.asarray([1 / 4], dtype=self.dtype) - momentum, kinetic_energy, _ = metrics.gaussian_euclidean(inverse_mass_matrix) + momentum, kinetic_energy, _, scale = metrics.gaussian_euclidean( + inverse_mass_matrix + ) arbitrary_position = jnp.asarray([12345], dtype=self.dtype) momentum_val = self.variant(momentum)(self.key, arbitrary_position) @@ -45,18 +131,30 @@ def test_gaussian_euclidean_dim_1(self): assert momentum_val == expected_momentum_val assert kinetic_energy_val == expected_kinetic_energy_val + inv_scaled_momentum = scale(arbitrary_position, momentum_val, True) + scaled_momentum = scale(arbitrary_position, momentum_val, False) + + expected_scaled_momentum = momentum_val / jnp.sqrt(inverse_mass_matrix) + expected_inv_scaled_momentum = momentum_val * jnp.sqrt(inverse_mass_matrix) + + chex.assert_trees_all_close(inv_scaled_momentum, expected_inv_scaled_momentum) + chex.assert_trees_all_close(scaled_momentum, expected_scaled_momentum) + @chex.all_variants(with_pmap=False) def test_gaussian_euclidean_dim_2(self): """Test Gaussian Euclidean Function with ndim 2""" inverse_mass_matrix = jnp.asarray( - [[1 / 9, 0.5], [0.5, 1 / 4]], dtype=self.dtype + [[2 / 3, 0.5], [0.5, 3 / 4]], dtype=self.dtype + ) + momentum, kinetic_energy, _, scale = metrics.gaussian_euclidean( + inverse_mass_matrix ) - momentum, kinetic_energy, _ = metrics.gaussian_euclidean(inverse_mass_matrix) arbitrary_position = jnp.asarray([12345, 23456], dtype=self.dtype) momentum_val = self.variant(momentum)(self.key, arbitrary_position) - L_inv = linalg.cholesky(linalg.inv(inverse_mass_matrix), lower=True) + L_inv = linalg.inv(linalg.cholesky(inverse_mass_matrix, lower=False)) + expected_momentum_val = L_inv @ random.normal(self.key, shape=(2,)) kinetic_energy_val = self.variant(kinetic_energy)(momentum_val) @@ -66,6 +164,15 @@ def test_gaussian_euclidean_dim_2(self): np.testing.assert_allclose(expected_momentum_val, momentum_val) np.testing.assert_allclose(kinetic_energy_val, expected_kinetic_energy_val) + inv_scaled_momentum = scale(arbitrary_position, momentum_val, True) + scaled_momentum = scale(arbitrary_position, momentum_val, False) + + expected_inv_scaled_momentum = jnp.linalg.inv(L_inv).T @ momentum_val + expected_scaled_momentum = L_inv @ momentum_val + + chex.assert_trees_all_close(inv_scaled_momentum, expected_inv_scaled_momentum) + chex.assert_trees_all_close(scaled_momentum, expected_scaled_momentum) + class GaussianRiemannianMetricsTest(chex.TestCase): def setUp(self): @@ -99,7 +206,9 @@ def test_gaussian_riemannian_value_errors(self, shape): def test_gaussian_riemannian_dim_1(self): inverse_mass_matrix = jnp.asarray([1 / 4], dtype=self.dtype) mass_matrix = jnp.asarray([4.0], dtype=self.dtype) - momentum, kinetic_energy, _ = metrics.gaussian_riemannian(lambda _: mass_matrix) + momentum, kinetic_energy, _, scale = metrics.gaussian_riemannian( + lambda _: mass_matrix + ) arbitrary_position = jnp.asarray([12345], dtype=self.dtype) momentum_val = self.variant(momentum)(self.key, arbitrary_position) @@ -114,16 +223,26 @@ def test_gaussian_riemannian_dim_1(self): expected_kinetic_energy_val = 0.5 * velocity * momentum_val expected_kinetic_energy_val += 0.5 * jnp.sum(jnp.log(2 * jnp.pi * mass_matrix)) - assert momentum_val == expected_momentum_val - assert kinetic_energy_val == expected_kinetic_energy_val + np.testing.assert_allclose(expected_momentum_val, momentum_val) + np.testing.assert_allclose(kinetic_energy_val, expected_kinetic_energy_val) + + inv_scaled_momentum = scale(arbitrary_position, momentum_val, True) + scaled_momentum = scale(arbitrary_position, momentum_val, False) + expected_scaled_momentum = momentum_val / jnp.sqrt(inverse_mass_matrix) + expected_inv_scaled_momentum = momentum_val * jnp.sqrt(inverse_mass_matrix) + + chex.assert_trees_all_close(inv_scaled_momentum, expected_inv_scaled_momentum) + chex.assert_trees_all_close(scaled_momentum, expected_scaled_momentum) @chex.all_variants(with_pmap=False) - def test_gaussian_euclidean_dim_2(self): + def test_gaussian_riemannian_dim_2(self): inverse_mass_matrix = jnp.asarray( - [[1 / 9, 0.5], [0.5, 1 / 4]], dtype=self.dtype + [[2 / 3, 0.5], [0.5, 3 / 4]], dtype=self.dtype ) mass_matrix = jnp.linalg.inv(inverse_mass_matrix) - momentum, kinetic_energy, _ = metrics.gaussian_riemannian(lambda _: mass_matrix) + momentum, kinetic_energy, _, scale = metrics.gaussian_riemannian( + lambda _: mass_matrix + ) arbitrary_position = jnp.asarray([12345, 23456], dtype=self.dtype) momentum_val = self.variant(momentum)(self.key, arbitrary_position) @@ -131,6 +250,10 @@ def test_gaussian_euclidean_dim_2(self): L_inv = linalg.cholesky(linalg.inv(inverse_mass_matrix), lower=True) expected_momentum_val = L_inv @ random.normal(self.key, shape=(2,)) + sqrt_mass_matrix, inv_sqrt_mass_matrix, _ = metrics._format_covariance( + inverse_mass_matrix, True + ) + kinetic_energy_val = self.variant(kinetic_energy)( momentum_val, position=arbitrary_position ) @@ -142,6 +265,14 @@ def test_gaussian_euclidean_dim_2(self): np.testing.assert_allclose(expected_momentum_val, momentum_val) np.testing.assert_allclose(kinetic_energy_val, expected_kinetic_energy_val) + inv_scaled_momentum = scale(arbitrary_position, momentum_val, True) + scaled_momentum = scale(arbitrary_position, momentum_val, False) + expected_inv_scaled_momentum = jnp.linalg.inv(L_inv).T @ momentum_val + expected_scaled_momentum = L_inv @ momentum_val + + chex.assert_trees_all_close(inv_scaled_momentum, expected_inv_scaled_momentum) + chex.assert_trees_all_close(scaled_momentum, expected_scaled_momentum) + if __name__ == "__main__": absltest.main() diff --git a/tests/mcmc/test_trajectory.py b/tests/mcmc/test_trajectory.py index c8a5aa908..e93280400 100644 --- a/tests/mcmc/test_trajectory.py +++ b/tests/mcmc/test_trajectory.py @@ -32,6 +32,7 @@ def test_dynamic_progressive_integration_divergence( momentum_generator, kinetic_energy_fn, uturn_check_fn, + _, ) = metrics.gaussian_euclidean(inverse_mass_matrix) integrator = integrators.velocity_verlet(logdensity_fn, kinetic_energy_fn) @@ -83,6 +84,7 @@ def logdensity_fn(x): momentum_generator, kinetic_energy_fn, uturn_check_fn, + _, ) = metrics.gaussian_euclidean(inverse_mass_matrix) integrator = integrators.velocity_verlet(logdensity_fn, kinetic_energy_fn) @@ -211,6 +213,7 @@ def logdensity_fn(x): momentum_generator, kinetic_energy_fn, uturn_check_fn, + _, ) = metrics.gaussian_euclidean(inverse_mass_matrix) integrator = integrators.velocity_verlet(logdensity_fn, kinetic_energy_fn) @@ -266,7 +269,7 @@ def test_static_integration_variable_num_steps(self): ( momentum_generator, kinetic_energy_fn, - _, + *_, ) = metrics.gaussian_euclidean(inverse_mass_matrix) initial_state = integrators.new_integrator_state( logdensity_fn, position, momentum_generator(rng_key, position) diff --git a/tests/mcmc/test_uturn.py b/tests/mcmc/test_uturn.py index 3dc730565..7f9f597d6 100644 --- a/tests/mcmc/test_uturn.py +++ b/tests/mcmc/test_uturn.py @@ -20,7 +20,7 @@ class UTurnTest(chex.TestCase): ) def test_is_iterative_turning(self, checkpoint_idxs, expected_turning): inverse_mass_matrix = jnp.ones(1) - _, _, is_turning = gaussian_euclidean(inverse_mass_matrix) + _, _, is_turning, _ = gaussian_euclidean(inverse_mass_matrix) _, _, is_iterative_turning = iterative_uturn_numpyro(is_turning) momentum = 1.0