202009130 marcoct ais #309

Draft: wants to merge 2 commits into master
110 changes: 110 additions & 0 deletions src/inference/ais.jl
@@ -0,0 +1,110 @@
"""
(lml_est, trace, weights) = ais(
model::GenerativeFunction, constraints::ChoiceMap,
args_seq::Vector{<:Tuple}, argdiffs::Tuple,
mcmc_kernel::Function)

Run annealed importance sampling, returning the log marginal likelihood estimate (`lml_est`).

The `mcmc_kernel` must satisfy detailed balance with respect to each intermediate distribution in the annealing sequence.
"""
function ais(
model::GenerativeFunction, constraints::ChoiceMap,
args_seq::Vector{<:Tuple}, argdiffs::Tuple, mcmc_kernel::Function)
init_trace, init_weight = generate(model, args_seq[1], constraints)
_ais(init_trace, init_weight, args_seq, argdiffs, mcmc_kernel)
end

function ais(
trace::Trace, selection::Selection,
args_seq::Vector{<:Tuple}, argdiffs::Tuple, mcmc_kernel::Function)
init_trace, = update(trace, args_seq[1], argdiffs, EmptyChoiceMap())
init_weight = project(trace, ComplementSelection(selection))
_ais(init_trace, init_weight, args_seq, argdiffs, mcmc_kernel)
end

function _ais(
trace::Trace, init_weight::Float64, args_seq::Vector{<:Tuple},
argdiffs::Tuple, mcmc_kernel::Function)
@assert get_args(trace) == args_seq[1]

# run forward AIS
weights = Float64[]
lml_est = init_weight
push!(weights, init_weight)
for intermediate_args in args_seq[2:end-1]
trace = mcmc_kernel(trace)
(trace, weight, _, discard) = update(trace, intermediate_args, argdiffs, EmptyChoiceMap())
if !isempty(discard)
error("Change to arguments cannot cause random choices to be removed from trace")
end
lml_est += weight
push!(weights, weight)
end
trace = mcmc_kernel(trace)
(trace, weight, _, discard) = update(
trace, args_seq[end], argdiffs, EmptyChoiceMap())
if !isempty(discard)
error("Change to arguments cannot cause random choices to be removed from trace")
end
lml_est += weight
push!(weights, weight)

# do MCMC at the very end
trace = mcmc_kernel(trace)

return (lml_est, trace, weights)
end

"""
(lml_est, weights) = reverse_ais(
model::GenerativeFunction, constraints::ChoiceMap,
args_seq::Vector, argdiffs::Tuple,
mh_rev::Function, output_addrs::Selection; safe=true)

Run reverse annealed importance sampling, returning the log marginal likelihood estimate (`lml_est`).

`constraints` must be a choice map that uniquely determines a trace of the model for the final arguments in the argument sequence.
The `mh_rev` kernel must satisfy detailed balance with respect to each intermediate distribution in the annealing sequence.
"""
function reverse_ais(
model::GenerativeFunction, constraints::ChoiceMap,
args_seq::Vector, argdiffs::Tuple,
mh_rev::Function, output_addrs::Selection; safe=true)

# construct final model trace from the inferred choices and all the fixed choices
(trace, should_be_score) = generate(model, args_seq[end], constraints)
init_score = get_score(trace)
if safe && !isapprox(should_be_score, init_score) # check it's deterministic
error("Some random choices may have been unconstrained")
end

# do mh at the very beginning
trace = mh_rev(trace)

# run backward AIS
lml_est = 0.
weights = Float64[]
for model_args in reverse(args_seq[1:end-1])
(trace, weight, _, _) = update(trace, model_args, argdiffs, EmptyChoiceMap())
safe && isnan(weight) && error("NaN weight")
lml_est -= weight
push!(weights, -weight)
trace = mh_rev(trace)
end

# get pi_1(z_0) / q(z_0) -- the weight that would be returned by the initial 'generate' call
# select the addresses that would be constrained by the call to generate inside AIS.simulate()
@assert get_args(trace) == args_seq[1]
#score_from_project = project(trace, ComplementSelection(output_addrs))
score_from_project = project(trace, output_addrs)
lml_est += score_from_project
push!(weights, score_from_project)
if isnan(score_from_project)
error("NaN score_from_project")
end

return (lml_est, reverse(weights))
end

export ais, reverse_ais
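
A hypothetical usage sketch (not part of this diff) of the forward `ais` entry point. The tempered model, the latent address `:x`, the observation address `:y`, the annealing schedule, and the MH kernel are all illustrative assumptions, not code from the PR:

    using Gen

    # Toy model whose likelihood is tempered by an inverse temperature `beta`:
    # beta near zero gives a nearly flat likelihood, beta = 1 gives the target.
    @gen function tempered_model(beta::Float64)
        x ~ normal(0.0, 1.0)
        y ~ normal(x, 1.0 / sqrt(beta))
    end

    constraints = choicemap((:y, 1.2))

    # Annealing schedule: argument tuples moving from the flattest distribution
    # to the distribution of interest (beta = 1).
    args_seq = Tuple[(beta,) for beta in range(0.01, 1.0, length=50)]
    argdiffs = (UnknownChange(),)

    # An MH move over the latent leaves every intermediate distribution
    # invariant, satisfying the detailed-balance requirement in the docstring.
    kernel(trace) = first(mh(trace, select(:x)))

    (lml_est, final_trace, weights) = ais(
        tempered_model, constraints, args_seq, argdiffs, kernel)

When `reverse_ais` is instead run from a trace distributed according to the conditional, the two estimators can be used together to sandwich the true log marginal likelihood (forward AIS is a stochastic lower bound, reverse AIS a stochastic upper bound).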
24 changes: 23 additions & 1 deletion src/inference/importance.jl
@@ -107,4 +107,26 @@ function importance_resampling(model::GenerativeFunction{T,U}, model_args::Tuple
return (model_trace::U, log_ml_estimate::Float64)
end

export importance_sampling, importance_resampling
"""
log_ml_estimate = conditional_is_estimator(
trace::Trace, observed::Selection, num_samples::Int)

Given a trace sampled from the model's conditional distribution given the observed choices,
return an estimate of the log marginal likelihood of the observations that is a
stochastic upper bound on the true log marginal likelihood.
"""
function conditional_is_estimator(trace::Trace, observed::Selection, num_samples::Int)
model = get_gen_fn(trace)
model_args = get_args(trace)
observations = get_selected(get_choices(trace), observed)
log_weights = Vector{Float64}(undef, num_samples)
log_weights[1] = project(trace, observed)
for i=2:num_samples
(_, log_weights[i]) = generate(model, model_args, observations)
end
log_total_weight = logsumexp(log_weights)
log_ml_estimate = log_total_weight - log(num_samples)
return log_ml_estimate
end

export importance_sampling, importance_resampling, conditional_is_estimator
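
A hypothetical usage sketch (not part of this diff) of `conditional_is_estimator`, reusing the illustrative `tempered_model`, `:x`, and `:y` names assumed in the sketch above; the MH burn-in used to move the trace toward the conditional distribution is likewise an assumption:

    # Start from a trace consistent with the observation at :y.
    observations = choicemap((:y, 1.2))
    trace, = generate(tempered_model, (1.0,), observations)

    # Move the trace toward the conditional distribution given :y; the estimator
    # is only meaningful when the input trace is (approximately) distributed
    # according to that conditional.
    function burn_in(trace, num_steps)
        for _ in 1:num_steps
            trace, = mh(trace, select(:x))
        end
        return trace
    end
    trace = burn_in(trace, 1000)

    # Stochastic upper bound on the log marginal likelihood of the observation.
    log_ml_upper = conditional_is_estimator(trace, select(:y), 1000)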
1 change: 1 addition & 0 deletions src/inference/inference.jl
@@ -25,3 +25,4 @@ include("particle_filter.jl")
include("map_optimize.jl")
include("train.jl")
include("variational.jl")
include("ais.jl")