Skip to content

Commit

Permalink
SMC Pretuning (#765)
Browse files Browse the repository at this point in the history
* extracting taking last

* test passing

* layering

* example

* more

* Adding another example

* tests in place

* rolling back changes

* Adding test for num_mcmc_steps

* format

* better test coverage

* linter

* Flake8

* black

* implementation[

* partial posteriors implementation

* rolling back some changes

* linter

* fixing test

* adding reference

* typo

* exposing in top level api

* reruning precommit

* up to now

* one step working

* fixes

* tests passing

* checkpoint tests passing

* more

* tests passing, implementation in place

* tests passing

* rounding

* adding to init

* rollbacks

* rollback

* rollback

* docs

* precommit

* removing extra parameter

* code review updates
  • Loading branch information
ciguaran authored Jan 16, 2025
1 parent df87345 commit fc539ca
Show file tree
Hide file tree
Showing 7 changed files with 615 additions and 14 deletions.
3 changes: 2 additions & 1 deletion blackjax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from .smc import adaptive_tempered
from .smc import inner_kernel_tuning as _inner_kernel_tuning
from .smc import partial_posteriors_path as _partial_posteriors_smc
from .smc import pretuning as _pretuning
from .smc import tempered
from .vi import meanfield_vi as _meanfield_vi
from .vi import pathfinder as _pathfinder
Expand Down Expand Up @@ -124,7 +125,7 @@ def generate_top_level_api_from(module):
tempered_smc = generate_top_level_api_from(tempered)
inner_kernel_tuning = generate_top_level_api_from(_inner_kernel_tuning)
partial_posteriors_smc = generate_top_level_api_from(_partial_posteriors_smc)

pretuning = generate_top_level_api_from(_pretuning)
smc_family = [tempered_smc, adaptive_tempered_smc, partial_posteriors_smc]
"Step_fn returning state has a .particles attribute"

Expand Down
29 changes: 20 additions & 9 deletions blackjax/smc/from_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,23 @@
from blackjax.types import PRNGKey


def unshared_parameters_and_step_fn(mcmc_parameters, mcmc_step_fn):
"""Splits MCMC parameters into two dictionaries. The shared dictionary
represents the parameters common to all chains, and the unshared are
different per chain.
Binds the step fn using the shared parameters.
"""
shared_mcmc_parameters = {}
unshared_mcmc_parameters = {}
for k, v in mcmc_parameters.items():
if v.shape[0] == 1:
shared_mcmc_parameters[k] = v[0, ...]
else:
unshared_mcmc_parameters[k] = v
shared_mcmc_step_fn = partial(mcmc_step_fn, **shared_mcmc_parameters)
return unshared_mcmc_parameters, shared_mcmc_step_fn


def build_kernel(
mcmc_step_fn: Callable,
mcmc_init_fn: Callable,
Expand All @@ -34,15 +51,9 @@ def step(
logposterior_fn: Callable,
log_weights_fn: Callable,
) -> tuple[smc.base.SMCState, smc.base.SMCInfo]:
shared_mcmc_parameters = {}
unshared_mcmc_parameters = {}
for k, v in mcmc_parameters.items():
if v.shape[0] == 1:
shared_mcmc_parameters[k] = v[0, ...]
else:
unshared_mcmc_parameters[k] = v

shared_mcmc_step_fn = partial(mcmc_step_fn, **shared_mcmc_parameters)
unshared_mcmc_parameters, shared_mcmc_step_fn = unshared_parameters_and_step_fn(
mcmc_parameters, mcmc_step_fn
)

update_fn, num_resampled = update_strategy(
mcmc_init_fn,
Expand Down
Loading

0 comments on commit fc539ca

Please sign in to comment.