Pathfinder #157

Closed
miclegr opened this issue Jan 17, 2022 · 6 comments · Fixed by #194
Labels
sampler Issue related to samplers

Comments

@miclegr
Contributor

miclegr commented Jan 17, 2022

Hi,
I built a JAX implementation of pathfinder here; since it's in the list of #154 I can try to merge it.
Any requirement for the API? Shall it fit the

new_state, info = kernel(rng_key, state)

pattern, even if it's not an MCMC kernel?

Michele
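For readers unfamiliar with the call pattern mentioned above, here is a minimal sketch of a function that follows it without being an MCMC kernel. The names `State`, `Info`, and the body of `kernel` are hypothetical and chosen only for illustration; they are not blackjax's actual types.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class State(NamedTuple):
    # Hypothetical state container: only holds the current position
    position: jnp.ndarray


class Info(NamedTuple):
    # Hypothetical diagnostics container returned next to the state
    proposal: jnp.ndarray


def kernel(rng_key, state):
    """Illustrative non-MCMC kernel: an independent standard normal draw."""
    draw = jax.random.normal(rng_key, state.position.shape)
    return State(position=draw), Info(proposal=draw)


rng_key = jax.random.PRNGKey(0)
state = State(position=jnp.zeros(2))
new_state, info = kernel(rng_key, state)
```

The new state here does not depend on the old one, which is exactly the situation the question is about: the signature fits even though no Markov chain is involved.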

@junpenglao
Member

Hi Michele!
I think in the long term it is better to create a new API approx, which needs a bit more discussion about the high-level API. For now, if you can fit it into the kernel call pattern, feel free to go ahead and code it as a standalone function (some examples of using it, e.g. for initializing HMC, would be great).

@junpenglao
Member

junpenglao commented Jan 17, 2022

@rlouf we talked about approximation with an API similar to optimjax, but in this case the approximation is a bit different (similar to jax.scipy.optimize.minimize, which terminates on convergence).

rlouf added the sampler label Jan 17, 2022
@rlouf
Member

rlouf commented Jan 17, 2022

Hi!

@miclegr Yes as Junpeng said you can try to fit the kernel pattern as much as you can. If you have to deviate from it you should aim to make the API the least surprising for someone who already knows blackjax. It would be great if you could have a first look at the algorithm and start a discussion on the user-facing design here before writing too much code.

@junpenglao Stochastic gradient MCMC is getting a slightly different API as well, so that wouldn't be a problem. I need to read the paper first to have an opinion.

@sethaxen

Pathfinder is a variational method. Is it intended for BlackJAX to contain VI implementations at some point?

@junpenglao
Member

We had put off a general VI implementation because it requires a few more design decisions.

In general we need:

target_density: callable = pytree -> scalar tensor
approx_density: callable = pytree -> scalar tensor
approx_sample: callable = PRNGKey -> pytree
KL_divergence: callable = pytree -> scalar tensor

which are much easier with a PPL.

I think for "standard" approximations that are multivariate Gaussian (mean-field and full-rank ADVI, Pathfinder, Laplace approximation), we can start off with standalone implementations and do the abstraction later on.
Also, even among the "standard" approximations listed above, Pathfinder is still a bit different as it uses a minimizer, which means users don't need to worry about the number of approximation steps etc.
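To make the four ingredients listed above concrete, here is a rough sketch for a mean-field Gaussian approximation. Everything in it is an assumption made for illustration: the target, the helper `make_meanfield`, and the Monte Carlo KL estimator (which takes a key and the two approximation callables rather than a single pytree) are not part of blackjax. Densities are written as log-densities.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def target_density(position):
    # Hypothetical target log-density: independent N(1, 2^2) coordinates
    return jnp.sum(norm.logpdf(position, loc=1.0, scale=2.0))


def make_meanfield(mean, log_std):
    # Hypothetical helper building the approximation's two callables
    std = jnp.exp(log_std)

    def approx_density(position):
        return jnp.sum(norm.logpdf(position, loc=mean, scale=std))

    def approx_sample(rng_key):
        return mean + std * jax.random.normal(rng_key, mean.shape)

    return approx_density, approx_sample


def kl_divergence_estimate(rng_key, approx_density, approx_sample, num_samples=100):
    # Monte Carlo estimate of KL(q || p) = E_q[log q - log p]
    keys = jax.random.split(rng_key, num_samples)
    draws = jax.vmap(approx_sample)(keys)
    log_q = jax.vmap(approx_density)(draws)
    log_p = jax.vmap(target_density)(draws)
    return jnp.mean(log_q - log_p)


approx_density, approx_sample = make_meanfield(jnp.zeros(2), jnp.zeros(2))
kl = kl_divergence_estimate(jax.random.PRNGKey(0), approx_density, approx_sample)
```

With a PPL the target and the approximation would come from the model definition; here they have to be written by hand, which is the extra design burden mentioned above.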

miclegr mentioned this issue Apr 3, 2022
@miclegr
Contributor Author

miclegr commented Apr 3, 2022

Hi guys,
sorry it took this long 😄 got busy with interviewing then took some time off.

I've opened the pull request now; an introduction to the method and its implementation in blackjax is here.

Some design decisions/open points:

  • Pathfinder runs L-BFGS optimization and stores the steps of the optimization path. For each of those steps, the "implied" ELBO to the target distribution is estimated. Finally, samples are drawn from the multivariate normal distribution with the highest ELBO. To fit this scheme into the init/step interface, I have put the L-BFGS run and the estimation of the "implied" ELBOs in the init phase, and the sampling from the highest-ELBO multivariate normal in the step phase. This unnaturally puts all the computationally heavy processing in init, while the step function just draws a single sample from a multivariate normal distribution. The init function is jittable, and a helper function for drawing multiple samples from the multivariate normal is available as well.

  • Unfortunately, the L-BFGS optimizer does not work well when optimizing the model's (negative) log-likelihood in JAX's default float32 mode; convergence usually fails. I've noticed that simply dividing the model's likelihood by the number of observations (hence optimizing the model's average likelihood) makes the optimization converge. However, given blackjax's design, it's quite unnatural to ask for the average likelihood, so in the end the recommendation is to turn on double precision mode.

  • As a consequence, running the pathfinder tests requires double precision mode. Since double precision mode needs to be set at JAX initialization time (see here), the test suite should support some tests in float32 mode and some in float64 mode (e.g. by running them in separate processes). It's feasible (see here for example) but not implemented at the moment. Right now I've just set float64 mode for those tests:

https://github.com/miclegr/blackjax/blob/128ce00bd2b28e06f79d126c9d3c097b6378ccc8/tests/test_pathfinder.py#L2-L3

But this won't work when running the full test suite. Happy to spend some time getting the multi-float-mode tests in place, if that's the solution.

  • Here I've introduced an adaptation scheme where the inverse mass matrix estimate is taken from pathfinder and dual averaging adaptation is then run to estimate the step size. This is NOT discussed in the pathfinder paper (here, section IV), but it makes sense to me. So it's a bit experimental and definitely needs more validation.
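The init/step split described in the first point of the list can be sketched roughly as follows. This is a toy stand-in under loud assumptions, not the code in the pull request: plain gradient ascent replaces L-BFGS, an identity covariance replaces the inverse-Hessian estimate that L-BFGS would provide, and `PathState`, `logdensity`, `init`, and `step` are hypothetical names.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal


class PathState(NamedTuple):
    mean: jnp.ndarray
    cov: jnp.ndarray
    elbo: jnp.ndarray


def logdensity(x):
    # Hypothetical target: unnormalized unit-variance Gaussian centred at 3
    return -0.5 * jnp.sum((x - 3.0) ** 2)


def init(rng_key, initial_position, num_steps=20, step_size=0.3, num_elbo_draws=50):
    # Heavy phase 1: optimize and store the path (gradient ascent, NOT L-BFGS)
    def ascent_step(x, _):
        new_x = x + step_size * jax.grad(logdensity)(x)
        return new_x, new_x

    _, path = jax.lax.scan(ascent_step, initial_position, None, length=num_steps)
    cov = jnp.eye(initial_position.shape[0])  # stand-in covariance (assumption)

    # Heavy phase 2: Monte Carlo ELBO estimate for the Gaussian at each path point
    def estimate_elbo(key, mean):
        draws = jax.random.multivariate_normal(key, mean, cov, shape=(num_elbo_draws,))
        log_p = jax.vmap(logdensity)(draws)
        log_q = multivariate_normal.logpdf(draws, mean, cov)
        return jnp.mean(log_p - log_q)

    keys = jax.random.split(rng_key, num_steps)
    elbos = jax.vmap(estimate_elbo)(keys, path)
    best = jnp.argmax(elbos)
    return PathState(path[best], cov, elbos[best])


def step(rng_key, state):
    # Cheap phase: one draw from the Gaussian with the highest estimated ELBO
    draw = jax.random.multivariate_normal(rng_key, state.mean, state.cov)
    return state, draw


init_key, step_key = jax.random.split(jax.random.PRNGKey(0))
state = init(init_key, jnp.zeros(2))
state, draw = step(step_key, state)
```

The asymmetry noted in the list is visible here: `init` contains the optimization and every ELBO estimate, while `step` is a single Gaussian draw that leaves the state unchanged.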

Happy to discuss any feedback.

Michele
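For reference, double precision mode is typically switched on through JAX's `jax_enable_x64` configuration flag. The snippet below is a minimal sketch of doing so at the top of a script or test module; it may differ from the exact lines in the linked test file.

```python
import jax

# Enable float64 before any arrays are created
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

x = jnp.ones(3)
print(x.dtype)
```

Because the flag applies to the whole process, tests that expect float32 cannot share a process with tests that set it, which is why separate processes come up in the list above.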

rlouf linked a pull request Apr 28, 2022 that will close this issue
4 participants