Skip to content

Commit

Permalink
Adding test for num_mcmc_steps
Browse files Browse the repository at this point in the history
  • Loading branch information
ciguaran committed Aug 15, 2024
1 parent 3576690 commit 5c34177
Showing 1 changed file with 13 additions and 1 deletion.
14 changes: 13 additions & 1 deletion tests/smc/test_waste_free_smc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@
import jax
import jax.numpy as jnp
import numpy as np
import pytest
from absl.testing import absltest

import blackjax
import blackjax.smc.resampling as resampling
from blackjax import adaptive_tempered_smc, tempered_smc
from blackjax.smc import extend_params
from blackjax.smc.waste_free import waste_free_smc
from blackjax.smc.waste_free import waste_free_smc, update_waste_free
from tests.smc import SMCLinearRegressionTestCase
from tests.smc.test_tempered_smc import inference_loop

Expand Down Expand Up @@ -104,5 +105,16 @@ def test_adaptive_tempered_smc(self):
self.assert_linear_regression_test_case(result)


def test_waste_free_set_num_mcmc_steps():
with pytest.raises(ValueError) as exc_info:
update_waste_free(lambda x:x,
lambda x:1,
lambda x:1,
100,
10,
3,
num_mcmc_steps=50)
assert str(exc_info.value).startswith("Can't use waste free SMC with a num_mcmc_steps parameter")

if __name__ == "__main__":
absltest.main()

0 comments on commit 5c34177

Please sign in to comment.