Blackjax already has an example where we use SGLD to sample from a 3-layer MLP, with very decent accuracy when the uncertainties are used to discard ambiguous predictions. We can use the CNN architecture from the Flax documentation:
```python
from flax import linen as nn


class CNN(nn.Module):
    """A simple CNN model."""

    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)
        return x
```
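For reference, here is a minimal sketch of how the parameters could be initialized, following the usual Flax pattern (the 28×28×1 NHWC input shape is an assumption for MNIST-style grayscale images, not something specified above):

```python
import jax
import jax.numpy as jnp

model = CNN()
# Initialize parameters from a dummy batch; assumed MNIST-style 28x28 grayscale inputs
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 28, 28, 1)))
logits = model.apply(params, jnp.ones((8, 28, 28, 1)))  # shape (8, 10)
```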
And the logprob function could be written as (not tested):
```python
from jax.flatten_util import ravel_pytree
import distrax


def logpdf(params, images, categories, model):
    # Class logits, shape (batch, 10)
    logits = model.apply(params, images)

    # Standard normal prior over all network weights
    flat_params, _ = ravel_pytree(params)
    log_prior = distrax.Normal(0.0, 1.0).log_prob(flat_params).sum()

    # Categorical likelihood over the 10 digit classes
    log_likelihood = distrax.Categorical(logits=logits).log_prob(categories).sum()

    return log_prior + log_likelihood
```
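To connect this to SGLD, here is a rough, untested sketch of a single Langevin step on a minibatch, written in plain JAX rather than against the Blackjax API (the `data_size` rescaling of the minibatch likelihood, the helper names, and the confidence threshold below are all assumptions for illustration):

```python
import jax
import jax.numpy as jnp


def minibatch_logpdf(params, images, categories, model, data_size):
    """Same as logpdf above, but with the minibatch likelihood rescaled to the full dataset."""
    logits = model.apply(params, images)
    flat_params, _ = ravel_pytree(params)
    log_prior = distrax.Normal(0.0, 1.0).log_prob(flat_params).sum()
    log_likelihood = distrax.Categorical(logits=logits).log_prob(categories).sum()
    return log_prior + (data_size / images.shape[0]) * log_likelihood


def sgld_step(rng_key, params, images, categories, model, data_size, step_size):
    """One Langevin update: gradient step on the log-posterior plus Gaussian noise."""
    grads = jax.grad(minibatch_logpdf)(params, images, categories, model, data_size)
    flat_params, unravel = ravel_pytree(params)
    flat_grads, _ = ravel_pytree(grads)
    noise = jax.random.normal(rng_key, flat_params.shape)
    new_flat = flat_params + 0.5 * step_size * flat_grads + jnp.sqrt(step_size) * noise
    return unravel(new_flat)


def keep_confident(posterior_samples, images, model, threshold=0.9):
    """Average class probabilities over posterior samples and flag confident predictions."""
    probs = jnp.stack(
        [jax.nn.softmax(model.apply(p, images), axis=-1) for p in posterior_samples]
    )
    mean_probs = probs.mean(axis=0)
    confident = mean_probs.max(axis=-1) > threshold  # threshold is a made-up value
    return mean_probs.argmax(axis=-1), confident
```

Blackjax's `sgld` kernel implements essentially this update; the exact kernel API should be taken from the existing SGLD example in the repository rather than from this sketch.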
We should look at:
Hey @rlouf, love the example! Inside the logpdf function the y variable doesn't exist; I'm guessing it should be categories instead?
Yes, made the change, thank you! I have no guarantee that this will work, though.
Hi @rlouf, I’ll work on this issue!
Hey @gerdm, do you still intend to work on this?
Hey @rlouf. Yes, still planning to work on it. Expect updates in September.