Skip to content

Commit

Permalink
Enable shared mcmc parameters with tempered smc (#694)
Browse files Browse the repository at this point in the history
* add parameter filtering

* fix parameter split + docstring

* change extend_paramss
  • Loading branch information
andrewdipper authored Jun 15, 2024
1 parent dd9ba03 commit 3353209
Show file tree
Hide file tree
Showing 7 changed files with 59 additions and 45 deletions.
3 changes: 2 additions & 1 deletion blackjax/smc/adaptive_tempered.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,8 @@ def as_top_level_api(
mcmc_init_fn
The MCMC init function used to build a MCMC state from a particle position.
mcmc_parameters
The parameters of the MCMC step function.
The parameters of the MCMC step function. Parameters with leading dimension
length of 1 are shared amongst the particles.
resampling_fn
The function used to resample the particles.
target_ess
Expand Down
7 changes: 2 additions & 5 deletions blackjax/smc/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,12 +150,9 @@ def step(
)


def extend_params(n_particles, params):
def extend_params(params):
"""Given a dictionary of params, repeats them for every single particle. The expected
usage is in cases where the aim is to repeat the same parameters for all chains within SMC.
"""

def extend(param):
return jnp.repeat(jnp.asarray(param)[None, ...], n_particles, axis=0)

return jax.tree.map(extend, params)
return jax.tree.map(lambda x: jnp.asarray(x)[None, ...], params)
21 changes: 18 additions & 3 deletions blackjax/smc/tempered.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import Callable, NamedTuple

import jax
Expand Down Expand Up @@ -108,6 +109,9 @@ def kernel(
Current state of the tempered SMC algorithm
lmbda
Current value of the tempering parameter
mcmc_parameters
The parameters of the MCMC step function. Parameters with leading dimension
length of 1 are shared amongst the particles.
Returns
-------
Expand All @@ -119,6 +123,14 @@ def kernel(
"""
delta = lmbda - state.lmbda

shared_mcmc_parameters = {}
unshared_mcmc_parameters = {}
for k, v in mcmc_parameters.items():
if v.shape[0] == 1:
shared_mcmc_parameters[k] = v[0, ...]
else:
unshared_mcmc_parameters[k] = v

def log_weights_fn(position: ArrayLikeTree) -> float:
return delta * loglikelihood_fn(position)

Expand All @@ -127,11 +139,13 @@ def tempered_logposterior_fn(position: ArrayLikeTree) -> float:
tempered_loglikelihood = state.lmbda * loglikelihood_fn(position)
return logprior + tempered_loglikelihood

shared_mcmc_step_fn = partial(mcmc_step_fn, **shared_mcmc_parameters)

def mcmc_kernel(rng_key, position, step_parameters):
state = mcmc_init_fn(position, tempered_logposterior_fn)

def body_fn(state, rng_key):
new_state, info = mcmc_step_fn(
new_state, info = shared_mcmc_step_fn(
rng_key, state, tempered_logposterior_fn, **step_parameters
)
return new_state, info
Expand All @@ -142,7 +156,7 @@ def body_fn(state, rng_key):

smc_state, info = smc.base.step(
rng_key,
SMCState(state.particles, state.weights, mcmc_parameters),
SMCState(state.particles, state.weights, unshared_mcmc_parameters),
jax.vmap(mcmc_kernel),
jax.vmap(log_weights_fn),
resampling_fn,
Expand Down Expand Up @@ -178,7 +192,8 @@ def as_top_level_api(
mcmc_init_fn
The MCMC init function used to build a MCMC state from a particle position.
mcmc_parameters
The parameters of the MCMC step function.
The parameters of the MCMC step function. Parameters with leading dimension
length of 1 are shared amongst the particles.
resampling_fn
The function used to resample the particles.
num_mcmc_steps
Expand Down
10 changes: 3 additions & 7 deletions tests/smc/test_inner_kernel_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def smc_inner_kernel_tuning_test_case(
proposal_factory.return_value = 100

def mcmc_parameter_update_fn(state, info):
return extend_params(1000, {"mean": 100})
return extend_params({"mean": 100})

prior = lambda x: stats.norm.logpdf(x)

Expand All @@ -114,7 +114,7 @@ def wrapped_kernel(rng_key, state, logdensity, mean):
resampling_fn=resampling.systematic,
smc_algorithm=smc_algorithm,
mcmc_parameter_update_fn=mcmc_parameter_update_fn,
initial_parameter_value=extend_params(1000, {"mean": 1.0}),
initial_parameter_value=extend_params({"mean": 1.0}),
**smc_parameters,
)

Expand Down Expand Up @@ -281,7 +281,6 @@ def test_with_adaptive_tempered(self):

def parameter_update(state, info):
return extend_params(
100,
{
"inverse_mass_matrix": mass_matrix_from_particles(state.particles),
"step_size": 10e-2,
Expand All @@ -298,7 +297,6 @@ def parameter_update(state, info):
resampling.systematic,
mcmc_parameter_update_fn=parameter_update,
initial_parameter_value=extend_params(
100,
dict(
inverse_mass_matrix=jnp.eye(2),
step_size=10e-2,
Expand Down Expand Up @@ -326,7 +324,7 @@ def body(carry):

_, state = inference_loop(smc_kernel, self.key, init_state)

assert state.parameter_override["inverse_mass_matrix"].shape == (100, 2, 2)
assert state.parameter_override["inverse_mass_matrix"].shape == (1, 2, 2)
self.assert_linear_regression_test_case(state.sampler_state)

@chex.all_variants(with_pmap=False)
Expand All @@ -340,7 +338,6 @@ def test_with_tempered_smc(self):

def parameter_update(state, info):
return extend_params(
100,
{
"inverse_mass_matrix": mass_matrix_from_particles(state.particles),
"step_size": 10e-2,
Expand All @@ -357,7 +354,6 @@ def parameter_update(state, info):
resampling.systematic,
mcmc_parameter_update_fn=parameter_update,
initial_parameter_value=extend_params(
100,
dict(
inverse_mass_matrix=jnp.eye(2),
step_size=10e-2,
Expand Down
10 changes: 4 additions & 6 deletions tests/smc/test_kernel_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def kernel(rng_key, state, logdensity_fn, proposal_mean):
self.check_compatible(
kernel,
blackjax.additive_step_random_walk.init,
extend_params(self.n_particles, {"proposal_mean": 1.0}),
extend_params({"proposal_mean": 1.0}),
)

def test_compatible_with_rmh(self):
Expand All @@ -70,15 +70,14 @@ def kernel(
self.check_compatible(
kernel,
blackjax.rmh.init,
extend_params(self.n_particles, {"proposal_mean": 1.0}),
extend_params({"proposal_mean": 1.0}),
)

def test_compatible_with_hmc(self):
self.check_compatible(
blackjax.hmc.build_kernel(),
blackjax.hmc.init,
extend_params(
self.n_particles,
{
"step_size": 0.3,
"inverse_mass_matrix": jnp.array([1.0]),
Expand All @@ -100,15 +99,14 @@ def kernel(rng_key, state, logdensity_fn, mean, proposal_logdensity_fn=None):
self.check_compatible(
kernel,
blackjax.irmh.init,
extend_params(self.n_particles, {"mean": jnp.array([1.0, 1.0])}),
extend_params({"mean": jnp.array([1.0, 1.0])}),
)

def test_compatible_with_nuts(self):
self.check_compatible(
blackjax.nuts.build_kernel(),
blackjax.nuts.init,
extend_params(
self.n_particles,
{"step_size": 1e-10, "inverse_mass_matrix": jnp.eye(2)},
),
)
Expand All @@ -117,7 +115,7 @@ def test_compatible_with_mala(self):
self.check_compatible(
blackjax.mala.build_kernel(),
blackjax.mala.init,
extend_params(self.n_particles, {"step_size": 1e-10}),
extend_params({"step_size": 1e-10}),
)

@staticmethod
Expand Down
31 changes: 14 additions & 17 deletions tests/smc/test_smc.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,17 @@ def body_fn(state, rng_key):
same_for_all_params = dict(
step_size=1e-2, inverse_mass_matrix=jnp.eye(1), num_integration_steps=50
)

state = init(
init_particles,
extend_params(num_particles, same_for_all_params),
same_for_all_params,
)

# Run the SMC sampler once
new_state, info = self.variant(step, static_argnums=(2, 3, 4))(
sample_key,
state,
jax.vmap(update_fn),
jax.vmap(update_fn, in_axes=(0, 0, None)),
jax.vmap(logdensity_fn),
resampling.systematic,
)
Expand Down Expand Up @@ -87,7 +88,9 @@ def body_fn(state, rng_key):
_, (states, info) = jax.lax.scan(body_fn, state, keys)
return states.position, info

particles, info = jax.vmap(one_particle_fn)(keys, particles, update_params)
particles, info = jax.vmap(one_particle_fn, in_axes=(0, 0, None))(
keys, particles, update_params
)
particles = particles.reshape((num_particles,))
return particles, info

Expand All @@ -97,13 +100,10 @@ def body_fn(state, rng_key):
init_particles = 0.25 + jax.random.normal(init_key, shape=(num_particles,))
state = init(
init_particles,
extend_params(
num_resampled,
dict(
step_size=1e-2,
inverse_mass_matrix=jnp.eye(1),
num_integration_steps=100,
),
dict(
step_size=1e-2,
inverse_mass_matrix=jnp.eye(1),
num_integration_steps=100,
),
)

Expand All @@ -125,22 +125,19 @@ def body_fn(state, rng_key):
class ExtendParamsTest(chex.TestCase):
def test_extend_params(self):
extended = extend_params(
3,
{
"a": 50,
"b": np.array([50]),
"c": np.array([50, 60]),
"d": np.array([[1, 2], [3, 4]]),
},
)
np.testing.assert_allclose(extended["a"], np.ones((3,)) * 50)
np.testing.assert_allclose(extended["b"], np.array([[50], [50], [50]]))
np.testing.assert_allclose(
extended["c"], np.array([[50, 60], [50, 60], [50, 60]])
)
np.testing.assert_allclose(extended["a"], np.ones((1,)) * 50)
np.testing.assert_allclose(extended["b"], np.array([[50]]))
np.testing.assert_allclose(extended["c"], np.array([[50, 60]]))
np.testing.assert_allclose(
extended["d"],
np.array([[[1, 2], [3, 4]], [[1, 2], [3, 4]], [[1, 2], [3, 4]]]),
np.array([[[1, 2], [3, 4]]]),
)


Expand Down
22 changes: 16 additions & 6 deletions tests/smc/test_tempered_smc.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,16 +65,28 @@ def logprior_fn(x):

hmc_kernel = blackjax.hmc.build_kernel()
hmc_init = blackjax.hmc.init
hmc_parameters = extend_params(
num_particles,

base_params = extend_params(
{
"step_size": 10e-2,
"inverse_mass_matrix": jnp.eye(2),
"num_integration_steps": 50,
},
}
)

for target_ess in [0.5, 0.75]:
# verify results are equivalent with all shared, all unshared, and mixed params
hmc_parameters_list = [
base_params,
jax.tree.map(lambda x: jnp.repeat(x, num_particles, axis=0), base_params),
jax.tree_util.tree_map_with_path(
lambda path, x: jnp.repeat(x, num_particles, axis=0)
if path[0].key == "step_size"
else x,
base_params,
),
]

for target_ess, hmc_parameters in zip([0.5, 0.5, 0.75], hmc_parameters_list):
tempering = adaptive_tempered_smc(
logprior_fn,
loglikelihood_fn,
Expand Down Expand Up @@ -115,7 +127,6 @@ def test_fixed_schedule_tempered_smc(self):
hmc_init = blackjax.hmc.init
hmc_kernel = blackjax.hmc.build_kernel()
hmc_parameters = extend_params(
100,
{
"step_size": 10e-2,
"inverse_mass_matrix": jnp.eye(2),
Expand Down Expand Up @@ -182,7 +193,6 @@ def test_normalizing_constant(self):
hmc_init = blackjax.hmc.init
hmc_kernel = blackjax.hmc.build_kernel()
hmc_parameters = extend_params(
num_particles,
{
"step_size": 10e-2,
"inverse_mass_matrix": jnp.eye(num_dim),
Expand Down

0 comments on commit 3353209

Please sign in to comment.