Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add adaptation on HMC trajectory length to increase the expected Change in the Estimator of the Expected Square (ChEES-HMC) #421

Merged
merged 3 commits into from
Nov 3, 2023

Conversation

albcab
Copy link
Member

@albcab albcab commented Dec 11, 2022

Closes #382

@codecov
Copy link

codecov bot commented Dec 11, 2022

Codecov Report

Merging #421 (56eb275) into main (d056670) will decrease coverage by 0.01%.
The diff coverage is 99.04%.

@@            Coverage Diff             @@
##             main     #421      +/-   ##
==========================================
- Coverage   99.14%   99.14%   -0.01%     
==========================================
  Files          49       50       +1     
  Lines        2117     2218     +101     
==========================================
+ Hits         2099     2199     +100     
- Misses         18       19       +1     
Files Coverage Δ
blackjax/__init__.py 100.00% <100.00%> (ø)
blackjax/adaptation/__init__.py 100.00% <100.00%> (ø)
blackjax/adaptation/chees_adaptation.py 100.00% <100.00%> (ø)
blackjax/mcmc/hmc.py 98.92% <90.90%> (-1.08%) ⬇️

@rlouf
Copy link
Member

rlouf commented Dec 11, 2022

Isn't it called ChEES? 😁

@albcab
Copy link
Member Author

albcab commented Dec 11, 2022

Oh no 🤦

@albcab albcab changed the title Add adaptation on HMC trajectory length to increase the expected Change in the Estimator of the Expected Square (ChESS-HMC) Add adaptation on HMC trajectory length to increase the expected Change in the Estimator of the Expected Square (ChEES-HMC) Dec 11, 2022
@rlouf rlouf mentioned this pull request Dec 20, 2022
12 tasks
@junpenglao
Copy link
Member

@albcab time to pick this up again (no pressure lol)?

@twiecki
Copy link

twiecki commented Oct 23, 2023

What's missing here? Would be great to get this sampler in!

@albcab
Copy link
Member Author

albcab commented Oct 23, 2023

Need to rebase and resolve conflicts for all the remaining, restructuring, etc. that have been done during the past year. Other than that, the algorithm is done.

@twiecki
Copy link

twiecki commented Oct 30, 2023

@albcab Great, let us know if you need any help!

@twiecki
Copy link

twiecki commented Oct 30, 2023

Nice, seems like conflicts are solved, just some mypy errors:

blackjax/adaptation/chees_adaptation.py:257: error: Incompatible default for argument "jitter_seed" (default has type "None", argument has type "int")  [assignment]
blackjax/adaptation/chees_adaptation.py:257: note: PEP 484 prohibits implicit Optional. Accordingly, mypy has changed its default to no_implicit_optional=True
blackjax/adaptation/chees_adaptation.py:257: note: Use https://github.com/hauntsaninja/no_implicit_optional to automatically upgrade your codebase
tests/smc/test_tempered_smc.py:69: note: By default the bodies of untyped functions are not checked, consider using --check-untyped-defs  [annotation-unchecked]

Copy link
Member

@junpenglao junpenglao left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sending out first round of review, I will need a bit more time on chees_adaptation.py

blackjax/mcmc/hmc.py Outdated Show resolved Hide resolved
blackjax/mcmc/trajectory.py Outdated Show resolved Hide resolved
blackjax/mcmc/hmc.py Outdated Show resolved Hide resolved
@junpenglao
Copy link
Member

blackjax/adaptation/chees_adaptation.py Outdated Show resolved Hide resolved
blackjax/adaptation/chees_adaptation.py Show resolved Hide resolved
blackjax/adaptation/chees_adaptation.py Outdated Show resolved Hide resolved
blackjax/adaptation/chees_adaptation.py Show resolved Hide resolved
blackjax/adaptation/chees_adaptation.py Outdated Show resolved Hide resolved
@albcab
Copy link
Member Author

albcab commented Nov 1, 2023

Could you add a test to check that it learns a good step size and trajectory_length_moving_average? Similar to https://github.com/tensorflow/probability/blob/bc2e0e0d0969d2b5e7cb0a8fbaaf93ba1e0adfb2/tensorflow_probability/python/experimental/mcmc/gradient_based_trajectory_length_adaptation_test.py#L334

I don't get it; how do we know that an avg step size of 1.5 and an avg number of steps of 15 are the "correct" amounts? Also, they target an acceptance rate of .75 while we do .651, which would change the target step size and number of steps.

@albcab albcab requested a review from junpenglao November 1, 2023 12:11
@junpenglao
Copy link
Member

I don't get it; how do we know that an avg step size of 1.5 and an avg number of steps of 15 are the "correct" amounts? Also, they target an acceptance rate of .75 while we do .651, which would change the target step size and number of steps.

I think those values comes from empirical, we should get similar number with the similar log_density set up. Could you turn .651 into an kwarg?

@albcab
Copy link
Member Author

albcab commented Nov 2, 2023

I think those values comes from empirical, we should get similar number with the similar log_density set up. Could you turn .651 into an kwarg?

I've added the test, although the number of integration steps depends on the optimizer/learning rate used.

junpenglao
junpenglao previously approved these changes Nov 2, 2023
Copy link
Member

@junpenglao junpenglao left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great job, thanks for working on this!

@junpenglao
Copy link
Member

I've added the test, although the number of integration steps depends on the optimizer/learning rate used.

Interesting. Makes sense.

@albcab albcab enabled auto-merge (squash) November 2, 2023 20:24
@albcab albcab requested review from junpenglao and removed request for twiecki November 2, 2023 20:25
log_step_size_moving_average
Running moving average of the log step_size parameter.
trajectory_length
Value of the num_integration_steps / step_size parameter of
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be

Suggested change
Value of the num_integration_steps / step_size parameter of
Value of the num_integration_steps * step_size parameter of

right?

@albcab albcab merged commit dbdcfd2 into blackjax-devs:main Nov 3, 2023
5 of 7 checks passed
@twiecki
Copy link

twiecki commented Nov 3, 2023

Amazing! 🥳

@albcab albcab deleted the chess branch November 3, 2023 10:32
junpenglao added a commit that referenced this pull request Mar 12, 2024
…ge in the Estimator of the Expected Square (ChEES-HMC) (#421)

* ChEES-HMC

* geometric weighted moving average and tests for convergence of step size and number of steps

* Update blackjax/adaptation/chees_adaptation.py

Co-authored-by: Junpeng Lao <[email protected]>

---------

Co-authored-by: Junpeng Lao <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Implement ChEES-HMC
4 participants