diff --git a/blackjax/__init__.py b/blackjax/__init__.py index 6a0de3809..81f8ebd2e 100644 --- a/blackjax/__init__.py +++ b/blackjax/__init__.py @@ -38,6 +38,7 @@ from .smc import adaptive_tempered from .smc import inner_kernel_tuning as _inner_kernel_tuning from .smc import partial_posteriors_path as _partial_posteriors_smc +from .smc import pretuning as _pretuning from .smc import tempered from .vi import meanfield_vi as _meanfield_vi from .vi import pathfinder as _pathfinder @@ -124,7 +125,7 @@ def generate_top_level_api_from(module): tempered_smc = generate_top_level_api_from(tempered) inner_kernel_tuning = generate_top_level_api_from(_inner_kernel_tuning) partial_posteriors_smc = generate_top_level_api_from(_partial_posteriors_smc) - +pretuning = generate_top_level_api_from(_pretuning) smc_family = [tempered_smc, adaptive_tempered_smc, partial_posteriors_smc] "Step_fn returning state has a .particles attribute" diff --git a/blackjax/smc/from_mcmc.py b/blackjax/smc/from_mcmc.py index 0e60b5968..75e5c34a6 100644 --- a/blackjax/smc/from_mcmc.py +++ b/blackjax/smc/from_mcmc.py @@ -8,6 +8,23 @@ from blackjax.types import PRNGKey +def unshared_parameters_and_step_fn(mcmc_parameters, mcmc_step_fn): + """Splits MCMC parameters into two dictionaries. The shared dictionary + represents the parameters common to all chains, and the unshared are + different per chain. + Binds the step fn using the shared parameters. + """ + shared_mcmc_parameters = {} + unshared_mcmc_parameters = {} + for k, v in mcmc_parameters.items(): + if v.shape[0] == 1: + shared_mcmc_parameters[k] = v[0, ...] + else: + unshared_mcmc_parameters[k] = v + shared_mcmc_step_fn = partial(mcmc_step_fn, **shared_mcmc_parameters) + return unshared_mcmc_parameters, shared_mcmc_step_fn + + def build_kernel( mcmc_step_fn: Callable, mcmc_init_fn: Callable, @@ -34,15 +51,9 @@ def step( logposterior_fn: Callable, log_weights_fn: Callable, ) -> tuple[smc.base.SMCState, smc.base.SMCInfo]: - shared_mcmc_parameters = {} - unshared_mcmc_parameters = {} - for k, v in mcmc_parameters.items(): - if v.shape[0] == 1: - shared_mcmc_parameters[k] = v[0, ...] - else: - unshared_mcmc_parameters[k] = v - - shared_mcmc_step_fn = partial(mcmc_step_fn, **shared_mcmc_parameters) + unshared_mcmc_parameters, shared_mcmc_step_fn = unshared_parameters_and_step_fn( + mcmc_parameters, mcmc_step_fn + ) update_fn, num_resampled = update_strategy( mcmc_init_fn, diff --git a/blackjax/smc/pretuning.py b/blackjax/smc/pretuning.py new file mode 100644 index 000000000..f489a0dc2 --- /dev/null +++ b/blackjax/smc/pretuning.py @@ -0,0 +1,346 @@ +from typing import Callable, Dict, List, NamedTuple, Optional, Tuple + +import jax +import jax.numpy as jnp +import jax.random +from jax._src.flatten_util import ravel_pytree + +from blackjax import SamplingAlgorithm, smc +from blackjax.smc.base import SMCInfo, update_and_take_last +from blackjax.smc.from_mcmc import build_kernel as smc_from_mcmc +from blackjax.smc.from_mcmc import unshared_parameters_and_step_fn +from blackjax.smc.inner_kernel_tuning import StateWithParameterOverride +from blackjax.smc.resampling import stratified +from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey +from blackjax.util import generate_gaussian_noise + + +class SMCInfoWithParameterDistribution(NamedTuple): + """Stores both the sampling status and also a dictionary + with parameter names as keys and (n_particles, *) arrays as values. + The latter represents a parameter per chain for the next mutation step. 
+    """
+
+    smc_info: SMCInfo
+    parameter_override: Dict[str, ArrayTree]
+
+
+def esjd(m):
+    """Implements ESJD (expected squared jumping distance). The inner Mahalanobis
+    distance is computed using the Cholesky decomposition of M = LLt.
+    Whenever M is symmetric positive definite, a Cholesky decomposition exists:
+    for example, when M is the covariance matrix of Metropolis-Hastings or
+    the inverse mass matrix of Hamiltonian Monte Carlo.
+    """
+    L = jnp.linalg.cholesky(m)
+
+    def measure(previous_position, next_position, acceptance_probability):
+        difference = ravel_pytree(previous_position)[0] - ravel_pytree(next_position)[0]
+        difference_by_matrix = jnp.matmul(L, difference)
+        norm = jnp.linalg.norm(difference_by_matrix, 2)
+        return acceptance_probability * jnp.power(norm, 2)
+
+    return jax.vmap(measure)
+
+
+def update_parameter_distribution(
+    key: PRNGKey,
+    previous_param_samples: ArrayLikeTree,
+    previous_particles: ArrayLikeTree,
+    latest_particles: ArrayLikeTree,
+    measure_of_chain_mixing: Callable,
+    alpha: float,
+    sigma_parameters: ArrayLikeTree,
+    acceptance_probability: Array,
+):
+    """Given an existing parameter distribution that was used to mutate previous_particles
+    into latest_particles, updates that parameter distribution by resampling from
+    previous_param_samples after adding noise to those samples. The weights used are a
+    linear function of the measure of chain mixing. Only works with float parameters,
+    not integers. See Equation 4 in https://arxiv.org/pdf/1005.1193.pdf
+
+    Parameters
+    ----------
+    previous_param_samples:
+        Samples of the parameters of the SMC inner MCMC chains, to be updated.
+    previous_particles:
+        Particles from which the kernel step started.
+    latest_particles:
+        Particles after the step was performed.
+    measure_of_chain_mixing: Callable
+        A callable that computes a performance measure per chain.
+    alpha:
+        A scalar to add to the weighting. See the paper for details.
+    sigma_parameters:
+        Noise to add to the population of parameters in order to mutate them. Must have
+        the same shape as previous_param_samples.
+    acceptance_probability:
+        The acceptance probability of each chain when taking a step from previous_particles
+        into latest_particles.
+    """
+    noise_key, resampling_key = jax.random.split(key, 2)
+
+    noises = jax.tree.map(
+        lambda x, s: generate_gaussian_noise(noise_key, x.astype("float32"), sigma=s),
+        previous_param_samples,
+        sigma_parameters,
+    )
+    new_samples = jax.tree.map(lambda x, y: x + y, noises, previous_param_samples)
+
+    chain_mixing_measurement = measure_of_chain_mixing(
+        previous_particles, latest_particles, acceptance_probability
+    )
+    weights = alpha + chain_mixing_measurement
+    weights = weights / jnp.sum(weights)
+    resampling_idx = stratified(resampling_key, weights, len(chain_mixing_measurement))
+    return (
+        jax.tree.map(lambda x: x[resampling_idx], new_samples),
+        chain_mixing_measurement,
+    )
+
+
+def build_pretune(
+    mcmc_init_fn: Callable,
+    mcmc_step_fn: Callable,
+    alpha: float,
+    sigma_parameters: ArrayLikeTree,
+    n_particles: int,
+    performance_of_chain_measure_factory: Callable = lambda state: esjd(
+        state.parameter_override["inverse_mass_matrix"]
+    ),
+    natural_parameters: Optional[List[str]] = None,
+    positive_parameters: Optional[List[str]] = None,
+):
+    """Implements the pretuning procedure of Buchholz et al. (https://arxiv.org/pdf/1808.07730).
+    The goal is to maintain a probability distribution of parameters, in order
+    to assign different values to each inner MCMC chain.
+    To obtain performant parameters for the distribution at step t, it takes a single step,
+    measures the chain mixing, and reweights the probability distribution of parameters
+    accordingly.
+    Note that, although similar, this strategy is different from inner_kernel_tuning: the
+    latter updates the parameters based on the particles and transition information after
+    the SMC step has been executed, whereas this implementation runs a single MCMC step
+    (which is then discarded) before proceeding with the SMC step.
+    """
+    if natural_parameters is None:
+        round_to_integer_fn = lambda x: x
+    else:
+
+        def round_to_integer_fn(x):
+            for k in natural_parameters:
+                x[k] = jax.tree.map(lambda a: jnp.abs(jnp.round(a).astype(int)), x[k])
+            return x
+
+    if positive_parameters is None:
+        make_positive_fn = lambda x: x
+    else:
+
+        def make_positive_fn(x):
+            for k in positive_parameters:
+                x[k] = jax.tree.map(jnp.abs, x[k])
+            return x
+
+    def pretune(key, state, logposterior):
+        unshared_mcmc_parameters, shared_mcmc_step_fn = unshared_parameters_and_step_fn(
+            state.parameter_override, mcmc_step_fn
+        )
+
+        one_step_fn, _ = update_and_take_last(
+            mcmc_init_fn, logposterior, shared_mcmc_step_fn, 1, n_particles
+        )
+
+        new_state, info = one_step_fn(
+            jax.random.split(key, n_particles),
+            state.sampler_state.particles,
+            unshared_mcmc_parameters,
+        )
+
+        performance_of_chain_measure = performance_of_chain_measure_factory(state)
+
+        (
+            new_parameter_distribution,
+            chain_mixing_measurement,
+        ) = update_parameter_distribution(
+            key,
+            previous_param_samples={
+                key: state.parameter_override[key] for key in sigma_parameters
+            },
+            previous_particles=state.sampler_state.particles,
+            latest_particles=new_state,
+            measure_of_chain_mixing=performance_of_chain_measure,
+            alpha=alpha,
+            sigma_parameters=sigma_parameters,
+            acceptance_probability=info.acceptance_rate,
+        )
+
+        return (
+            make_positive_fn(round_to_integer_fn(new_parameter_distribution)),
+            chain_mixing_measurement,
+        )
+
+    def pretune_and_update(key, state: StateWithParameterOverride, logposterior):
+        """
+        Updates the parameters that need to be pretuned and leaves the rest unchanged.
+        """
+        new_parameter_distribution, chain_mixing_measurement = pretune(
+            key, state, logposterior
+        )
+        old_parameter_distribution = state.parameter_override
+        updated_parameter_distribution = old_parameter_distribution
+        for k in new_parameter_distribution:
+            updated_parameter_distribution[k] = new_parameter_distribution[k]
+
+        return updated_parameter_distribution
+
+    return pretune_and_update
+
+
+def build_kernel(
+    smc_algorithm,
+    logprior_fn: Callable,
+    loglikelihood_fn: Callable,
+    mcmc_step_fn: Callable,
+    mcmc_init_fn: Callable,
+    resampling_fn: Callable,
+    pretune_fn: Callable,
+    num_mcmc_steps: int = 10,
+    update_strategy=update_and_take_last,
+    **extra_parameters,
+) -> Callable:
+    """In the context of an SMC sampler (whose step_fn returns a state with a .particles
+    attribute), there is an inner MCMC that is used to perturb/update each of the particles.
+    This adaptation tunes some parameter of that MCMC, based on particles.
+    The parameter type must be a valid JAX type.
+
+    Parameters
+    ----------
+    smc_algorithm
+        Either blackjax.adaptive_tempered_smc or blackjax.tempered_smc (or any other implementation of
+        a sampling algorithm that returns an SMCState and SMCInfo pair).
+    logprior_fn
+        A function that computes the log density of the prior distribution.
+    loglikelihood_fn
+        A function that returns the log-likelihood at a given position.
+    mcmc_step_fn:
+        The transition kernel; it should take as parameters the dictionary output of mcmc_parameter_update_fn:
+        mcmc_step_fn(rng_key, state, tempered_logposterior_fn, **mcmc_parameter_update_fn())
+    mcmc_init_fn
+        A callable that initializes the inner kernel.
+    pretune_fn:
+        A callable that can update the probability distribution of parameters.
+    extra_parameters:
+        Parameters to be used for the creation of the smc_algorithm.
+    """
+    delegate = smc_from_mcmc(mcmc_step_fn, mcmc_init_fn, resampling_fn, update_strategy)
+
+    def pretuned_step(
+        rng_key: PRNGKey,
+        state,
+        num_mcmc_steps: int,
+        mcmc_parameters: dict,
+        logposterior_fn: Callable,
+        log_weights_fn: Callable,
+    ) -> tuple[smc.base.SMCState, SMCInfoWithParameterDistribution]:
+        """Wraps the output of smc.from_mcmc.build_kernel into a pretuning
+        step method.
+        It is a subtype of the former, in the sense that any usage of the former
+        can be replaced with an instance of this one.
+        """
+        pretune_key, step_key = jax.random.split(rng_key, 2)
+        pretuned_parameters = pretune_fn(
+            pretune_key,
+            StateWithParameterOverride(state, mcmc_parameters),
+            logposterior_fn,
+        )
+        state, info = delegate(
+            step_key,
+            state,
+            num_mcmc_steps,
+            pretuned_parameters,
+            logposterior_fn,
+            log_weights_fn,
+        )
+        return state, SMCInfoWithParameterDistribution(info, pretuned_parameters)
+
+    def kernel(
+        rng_key: PRNGKey, state: StateWithParameterOverride, **extra_step_parameters
+    ) -> Tuple[StateWithParameterOverride, SMCInfo]:
+        extra_parameters["update_particles_fn"] = pretuned_step
+        step_fn = smc_algorithm(
+            logprior_fn=logprior_fn,
+            loglikelihood_fn=loglikelihood_fn,
+            mcmc_step_fn=mcmc_step_fn,
+            mcmc_init_fn=mcmc_init_fn,
+            mcmc_parameters=state.parameter_override,
+            resampling_fn=resampling_fn,
+            num_mcmc_steps=num_mcmc_steps,
+            **extra_parameters,
+        ).step
+        new_state, info = step_fn(rng_key, state.sampler_state, **extra_step_parameters)
+        return (
+            StateWithParameterOverride(new_state, info.parameter_override),
+            info.smc_info,
+        )
+
+    return kernel
+
+
+def init(alg_init_fn, position, initial_parameter_value):
+    return StateWithParameterOverride(alg_init_fn(position), initial_parameter_value)
+
+
+def as_top_level_api(
+    smc_algorithm,
+    logprior_fn: Callable,
+    loglikelihood_fn: Callable,
+    mcmc_step_fn: Callable,
+    mcmc_init_fn: Callable,
+    resampling_fn: Callable,
+    num_mcmc_steps: int,
+    initial_parameter_value: ArrayLikeTree,
+    pretune_fn: Callable,
+    **extra_parameters,
+):
+    """In the context of an SMC sampler (whose step_fn returns a state with a .particles
+    attribute), there is an inner MCMC that is used to perturb/update each of the particles.
+    This adaptation tunes some parameter of that MCMC, based on particles.
+    The parameter type must be a valid JAX type.
+
+    Parameters
+    ----------
+    smc_algorithm
+        Either blackjax.adaptive_tempered_smc or blackjax.tempered_smc (or any other implementation of
+        a sampling algorithm that returns an SMCState and SMCInfo pair).
+    logprior_fn
+        A function that computes the log density of the prior distribution.
+    loglikelihood_fn
+        A function that returns the log-likelihood at a given position.
+    mcmc_step_fn:
+        The transition kernel; it should take as parameters the dictionary output of mcmc_parameter_update_fn:
+        mcmc_step_fn(rng_key, state, tempered_logposterior_fn, **mcmc_parameter_update_fn())
+    mcmc_init_fn
+        A callable that initializes the inner kernel.
+    pretune_fn:
+        A callable that can update the probability distribution of parameters.
+ extra_parameters: + parameters to be used for the creation of the smc_algorithm. + """ + + kernel = build_kernel( + smc_algorithm, + logprior_fn, + loglikelihood_fn, + mcmc_step_fn, + mcmc_init_fn, + resampling_fn, + pretune_fn, + num_mcmc_steps, + **extra_parameters, + ) + + def init_fn(position, rng_key=None): + del rng_key + return init(smc_algorithm.init, position, initial_parameter_value) + + def step_fn( + rng_key: PRNGKey, state, **extra_step_parameters + ) -> Tuple[StateWithParameterOverride, SMCInfo]: + return kernel(rng_key, state, **extra_step_parameters) + + return SamplingAlgorithm(init_fn, step_fn) diff --git a/blackjax/smc/tempered.py b/blackjax/smc/tempered.py index 88539deaa..350037f9c 100644 --- a/blackjax/smc/tempered.py +++ b/blackjax/smc/tempered.py @@ -55,6 +55,7 @@ def build_kernel( mcmc_init_fn: Callable, resampling_fn: Callable, update_strategy: Callable = update_and_take_last, + update_particles_fn: Optional[Callable] = None, ) -> Callable: """Build the base Tempered SMC kernel. @@ -92,8 +93,12 @@ def build_kernel( information about the transition. """ - delegate = smc_from_mcmc.build_kernel( - mcmc_step_fn, mcmc_init_fn, resampling_fn, update_strategy + update_particles = ( + smc_from_mcmc.build_kernel( + mcmc_step_fn, mcmc_init_fn, resampling_fn, update_strategy + ) + if update_particles_fn is None + else update_particles_fn ) def kernel( @@ -135,7 +140,7 @@ def tempered_logposterior_fn(position: ArrayLikeTree) -> float: tempered_loglikelihood = state.lmbda * loglikelihood_fn(position) return logprior + tempered_loglikelihood - smc_state, info = delegate( + smc_state, info = update_particles( rng_key, state, num_mcmc_steps, @@ -162,6 +167,7 @@ def as_top_level_api( resampling_fn: Callable, num_mcmc_steps: Optional[int] = 10, update_strategy=update_and_take_last, + update_particles_fn=None, ) -> SamplingAlgorithm: """Implements the (basic) user interface for the Adaptive Tempered SMC kernel. @@ -196,6 +202,7 @@ def as_top_level_api( mcmc_init_fn, resampling_fn, update_strategy, + update_particles_fn, ) def init_fn(position: ArrayLikeTree, rng_key=None): diff --git a/blackjax/smc/tuning/from_kernel_info.py b/blackjax/smc/tuning/from_kernel_info.py index a039e66c1..fa2c7054c 100644 --- a/blackjax/smc/tuning/from_kernel_info.py +++ b/blackjax/smc/tuning/from_kernel_info.py @@ -1,4 +1,5 @@ """ +static (all particles get the same value) strategies to tune the parameters of mcmc kernels used within smc, based on MCMC states """ diff --git a/blackjax/smc/tuning/from_particles.py b/blackjax/smc/tuning/from_particles.py index 4c8ca98da..279a718cb 100755 --- a/blackjax/smc/tuning/from_particles.py +++ b/blackjax/smc/tuning/from_particles.py @@ -1,5 +1,5 @@ """ -strategies to tune the parameters of mcmc kernels +static (all particles get the same value) strategies to tune the parameters of mcmc kernels used within SMC, based on particles. 
""" import jax diff --git a/tests/smc/test_pretuning.py b/tests/smc/test_pretuning.py new file mode 100644 index 000000000..a677c99ae --- /dev/null +++ b/tests/smc/test_pretuning.py @@ -0,0 +1,235 @@ +import unittest + +import chex +import jax +import jax.numpy as jnp +import numpy as np +from absl.testing import absltest + +import blackjax +from blackjax.smc import extend_params, resampling +from blackjax.smc.pretuning import ( + build_pretune, + esjd, + init, + update_parameter_distribution, +) +from tests.smc import SMCLinearRegressionTestCase + + +class TestMeasureOfChainMixing(unittest.TestCase): + previous_position = np.array([jnp.array([10.0, 15.0]), jnp.array([3.0, 4.0])]) + + next_position = np.array([jnp.array([20.0, 30.0]), jnp.array([9.0, 12.0])]) + + def test_measure_of_chain_mixing_identity(self): + """ + Given identity matrix and 1. acceptance probability + then the mixing is the square of norm 2. + """ + m = np.eye(2) + + acceptance_probabilities = np.array([1.0, 1.0]) + chain_mixing = esjd(m)( + self.previous_position, self.next_position, acceptance_probabilities + ) + np.testing.assert_allclose(chain_mixing[0], 325) + np.testing.assert_allclose(chain_mixing[1], 100) + + def test_measure_of_chain_mixing_with_non_1_acceptance_rate(self): + """ + Given identity matrix + then the mixing is the square of norm 2. multiplied by the acceptance rate + """ + m = np.eye(2) + + acceptance_probabilities = np.array([0.5, 0.2]) + chain_mixing = esjd(m)( + self.previous_position, self.next_position, acceptance_probabilities + ) + np.testing.assert_allclose(chain_mixing[0], 162.5) + np.testing.assert_allclose(chain_mixing[1], 20) + + def test_measure_of_chain_mixing(self): + m = np.array([[3, 0], [0, 5]]) + + previous_position = np.array([jnp.array([10.0, 15.0]), jnp.array([3.0, 4.0])]) + + next_position = np.array([jnp.array([20.0, 30.0]), jnp.array([9.0, 12.0])]) + + acceptance_probabilities = np.array([1.0, 1.0]) + + chain_mixing = esjd(m)( + previous_position, next_position, acceptance_probabilities + ) + + assert chain_mixing.shape == (2,) + np.testing.assert_allclose(chain_mixing[0], 10 * 10 * 3 + 15 * 15 * 5) + np.testing.assert_allclose(chain_mixing[1], 6 * 6 * 3 + 8 * 8 * 5) + + +class TestUpdateParameterDistribution(chex.TestCase): + def setUp(self): + super().setUp() + self.key = jax.random.key(42) + self.previous_position = np.array( + [jnp.array([10.0, 15.0]), jnp.array([10.0, 15.0]), jnp.array([3.0, 4.0])] + ) + self.next_position = np.array( + [jnp.array([20.0, 30.0]), jnp.array([10.0, 15.0]), jnp.array([9.0, 12.0])] + ) + + def test_update_param_distribution(self): + """ + Given an extremely good mixing on one chain, + and that the alpha parameter is 0, then the parameters + of that chain with a slight mutation due to noise are reused. + """ + + ( + new_parameter_distribution, + chain_mixing_measurement, + ) = update_parameter_distribution( + self.key, + jnp.array([1.0, 2.0, 3.0]), + self.previous_position, + self.next_position, + measure_of_chain_mixing=lambda x, y, z: jnp.array([1.0, 0.0, 0.0]), + alpha=0, + sigma_parameters=0.0001, + acceptance_probability=None, + ) + + np.testing.assert_allclose( + new_parameter_distribution, + np.array([1, 1, 1], dtype="float32"), + rtol=1e-3, + ) + np.testing.assert_allclose( + chain_mixing_measurement, + np.array([1, 0, 0], dtype="float32"), + rtol=1e-6, + ) + + def test_update_multi_sigmas(self): + """ + When we have multiple parameters, the performance is attached to its combination + so sampling must work accordingly. 
+ """ + ( + new_parameter_distribution, + chain_mixing_measurement, + ) = update_parameter_distribution( + self.key, + { + "param_a": jnp.array([1.0, 2.0, 3.0]), + "param_b": jnp.array([[5.0, 6.0], [6.0, 7.0], [4.0, 5.0]]), + }, + self.previous_position, + self.next_position, + measure_of_chain_mixing=lambda x, y, z: jnp.array([1.0, 0.0, 0.0]), + alpha=0, + sigma_parameters={"param_a": 0.0001, "param_b": 0.00001}, + acceptance_probability=None, + ) + print(chain_mixing_measurement) + np.testing.assert_allclose(chain_mixing_measurement, np.array([1.0, 0, 0])) + + np.testing.assert_allclose( + new_parameter_distribution["param_a"], jnp.array([1.0, 1.0, 1.0]), atol=0.1 + ) + np.testing.assert_allclose( + new_parameter_distribution["param_b"], + jnp.array([[5.0, 6.0], [5.0, 6.0], [5.0, 6.0]]), + atol=0.1, + ) + + +class PretuningSMCTest(SMCLinearRegressionTestCase): + def setUp(self): + super().setUp() + self.key = jax.random.key(42) + + @chex.variants(with_jit=True) + def test_linear_regression(self): + ( + init_particles, + logprior_fn, + loglikelihood_fn, + ) = self.particles_prior_loglikelihood() + + num_particles = 100 + sampling_key, step_size_key, integration_steps_key = jax.random.split( + self.key, 3 + ) + integration_steps_distribution = jnp.round( + jax.random.uniform( + integration_steps_key, (num_particles,), minval=1, maxval=100 + ) + ).astype(int) + + step_sizes_distribution = jax.random.uniform( + step_size_key, (num_particles,), minval=0, maxval=0.1 + ) + + # Fixes inverse_mass_matrix and distribution for the other two parameters. + initial_parameters = dict( + inverse_mass_matrix=extend_params(jnp.eye(2)), + step_size=step_sizes_distribution, + num_integration_steps=integration_steps_distribution, + ) + assert initial_parameters["step_size"].shape == (num_particles,) + assert initial_parameters["num_integration_steps"].shape == (num_particles,) + + pretune = build_pretune( + blackjax.hmc.init, + blackjax.hmc.build_kernel(), + alpha=1, + n_particles=num_particles, + sigma_parameters={"step_size": 0.01, "num_integration_steps": 2}, + natural_parameters=["num_integration_steps"], + positive_parameters=["step_size"], + ) + + step = blackjax.smc.pretuning.build_kernel( + blackjax.tempered_smc, + logprior_fn, + loglikelihood_fn, + blackjax.hmc.build_kernel(), + blackjax.hmc.init, + resampling.systematic, + num_mcmc_steps=10, + pretune_fn=pretune, + ) + + initial_state = init( + blackjax.tempered_smc.init, init_particles, initial_parameters + ) + smc_kernel = self.variant(step) + + def body_fn(carry, lmbda): + i, state = carry + subkey = jax.random.fold_in(self.key, i) + new_state, info = smc_kernel(subkey, state, lmbda=lmbda) + return (i + 1, new_state), (new_state, info) + + num_tempering_steps = 10 + lambda_schedule = np.logspace(-5, 0, num_tempering_steps) + + (_, result), _ = jax.lax.scan(body_fn, (0, initial_state), lambda_schedule) + self.assert_linear_regression_test_case(result.sampler_state) + assert set(result.parameter_override.keys()) == { + "step_size", + "num_integration_steps", + "inverse_mass_matrix", + } + assert result.parameter_override["step_size"].shape == (num_particles,) + assert result.parameter_override["num_integration_steps"].shape == ( + num_particles, + ) + assert all(result.parameter_override["step_size"] > 0) + assert all(result.parameter_override["num_integration_steps"] > 0) + + +if __name__ == "__main__": + absltest.main()