Pathfinder #157
Comments
Hi Michele!
Hi! @miclegr Yes, as Junpeng said, you can try to fit the same pattern.
@junpenglao Stochastic gradient MCMC is getting a slightly different API as well, so that wouldn't be a problem. I need to read the paper first to have an opinion.
Pathfinder is a variational method. Is it intended for BlackJAX to contain VI implementations at some point?
We had put off a general VI implementation because it requires a few more design decisions. In general we need:

```
target_density: callable = pytree -> scalar tensor
approx_density: callable = pytree -> scalar tensor
approx_sample: callable = PRNGKey -> pytree
KL_divergence: callable = pytree -> scalar tensor
```

which are much easier with a PPL. I think for "standard" approximations that are multivariate Gaussian (mean-field and full-rank ADVI, Pathfinder, Laplace approximation), we can start off with a standalone implementation and do the abstraction later on.
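To make those four callables concrete, here is a minimal, hypothetical sketch for a mean-field Gaussian approximation in JAX. It uses a flat parameter array instead of a general pytree, and the names (`make_meanfield`, the toy `target_density`, etc.) are illustrative assumptions, not part of any BlackJAX API:

```python
# Hypothetical sketch of the four callables for a mean-field Gaussian
# approximation; flat arrays are used instead of general pytrees for brevity.
import jax
import jax.numpy as jnp
import jax.scipy.stats as stats


def target_density(x):
    # Placeholder target: log-density of a standard normal over a flat vector.
    return jnp.sum(stats.norm.logpdf(x))


def make_meanfield(mu, log_sigma):
    sigma = jnp.exp(log_sigma)

    def approx_density(x):
        # Log-density of the factorized Gaussian approximation.
        return jnp.sum(stats.norm.logpdf(x, loc=mu, scale=sigma))

    def approx_sample(rng_key):
        # One draw from the approximation via the reparameterization trick.
        eps = jax.random.normal(rng_key, shape=mu.shape)
        return mu + sigma * eps

    def kl_divergence(rng_key, num_samples=32):
        # Monte Carlo estimate of KL(approx || target), up to the target's
        # unknown normalizing constant (i.e. the negative ELBO).
        keys = jax.random.split(rng_key, num_samples)
        samples = jax.vmap(approx_sample)(keys)
        return jnp.mean(
            jax.vmap(approx_density)(samples) - jax.vmap(target_density)(samples)
        )

    return approx_density, approx_sample, kl_divergence
```

With a PPL, the target density and the pytree structure of the parameters come for free, which is what makes the general case easier there.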
Hi guys, I've opened the pull request now; an introduction to the method and its implementation in blackjax is here. Some design decisions/open points:
But this won't work when running the full suite of tests. Happy to spend some time to get the multi-float-mode tests in place, if that's the solution.
Happy to discuss feedback. Michele
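On the multi-float-mode tests, one common approach (shown here only as a hedged sketch under the assumption that "float mode" means float32 vs. float64, not the project's actual test setup) is to parametrize a pytest fixture over JAX's x64 flag:

```python
# Hypothetical pytest sketch that runs a test under both float32 and float64
# by toggling JAX's x64 flag; not the actual BlackJAX test configuration.
import jax
import jax.numpy as jnp
import pytest


@pytest.fixture(params=[False, True], ids=["float32", "float64"])
def enable_x64(request):
    # Switch 64-bit mode on/off for the duration of one test, then restore it.
    previous = jax.config.read("jax_enable_x64")
    jax.config.update("jax_enable_x64", request.param)
    yield request.param
    jax.config.update("jax_enable_x64", previous)


def test_dtype_matches_mode(enable_x64):
    x = jnp.ones(3)
    expected = jnp.float64 if enable_x64 else jnp.float32
    assert x.dtype == expected
```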
Hi,
I built a JAX implementation of Pathfinder here; since it's in the list of #154 I can try to merge it.
Any requirements for the API? Shall it fit the pattern, even if it's not an MCMC kernel?
Michele
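For reference, the init/step pattern being asked about could look roughly like the sketch below for a non-MCMC method. The state fields and the update rule are purely illustrative assumptions, not BlackJAX's actual Pathfinder interface:

```python
# Rough, hypothetical illustration of an init/step-style interface for an
# approximation method; field names and the update rule are illustrative only.
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp


class ApproxState(NamedTuple):
    # Mean and log-scale of a diagonal Gaussian approximation (illustrative).
    position: jnp.ndarray
    log_scale: jnp.ndarray


def init(initial_position: jnp.ndarray) -> ApproxState:
    # Start from the supplied position with a unit-scale approximation.
    return ApproxState(initial_position, jnp.zeros_like(initial_position))


def step(rng_key, state: ApproxState, logdensity_fn: Callable,
         step_size: float = 0.1) -> ApproxState:
    # rng_key is kept to mirror an MCMC kernel signature; it is unused in
    # this toy update, which just moves the mean along the target's gradient.
    grad = jax.grad(logdensity_fn)(state.position)
    return ApproxState(state.position + step_size * grad, state.log_scale)
```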