Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

20200904 marcoct thesischanges #305

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/gen_fn_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,13 @@ When there is no non-addressed randomness, this simplifies to the log probabilit
"""
function get_score end

"""
logpdf(trace)

Synonym for [`get_score`](@ref).
"""
logpdf(trace::Trace) = get_score(trace)

"""
gen_fn::GenerativeFunction = get_gen_fn(trace)

Expand Down Expand Up @@ -394,3 +401,4 @@ export update
export regenerate
export accumulate_param_gradients!
export choice_gradients
export logpdf
2 changes: 2 additions & 0 deletions src/modeling_library/modeling_library.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ include("map/map.jl")
include("unfold/unfold.jl")
include("recurse/recurse.jl")

include("override_internal_proposal.jl")

#############################################################
# abstractions for constructing custom generative functions #
#############################################################
Expand Down
76 changes: 76 additions & 0 deletions src/modeling_library/override_internal_proposal.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@

# Generative function combiantor that overrides internal proposal with another
# generative function

# Not yet implemented:
# - update
# - project
# - choice_gradients
# - accumulate_param_gradients!

struct ReplaceProposalGFTrace{U} <: Trace
model_trace::U
gen_fn::GenerativeFunction
end

get_args(tr::ReplaceProposalGFTrace) = get_args(tr.model_trace)
get_retval(tr::ReplaceProposalGFTrace) = get_retval(tr.model_trace)
get_choices(tr::ReplaceProposalGFTrace) = get_choices(tr.model_trace)
get_score(tr::ReplaceProposalGFTrace) = get_score(tr.model_trace)

struct ReplaceProposalGF{T,U} <: GenerativeFunction{T,ReplaceProposalGFTrace{U}}
model::GenerativeFunction{T,U}
proposal::GenerativeFunction
end

get_gen_fn(tr::ReplaceProposalGFTrace) = tr.gen_fn

# gradient ops not implemented yet
has_argument_grads(f::ReplaceProposalGF) = map(_->false,has_argument_grads(f.model))
accepts_output_grad(f::ReplaceProposalGF) = false

function project(tr::ReplaceProposalGFTrace, ::EmptySelection)
return project(tr.model_trace, EmptySelection())
end

function simulate(gen_fn::ReplaceProposalGF, args::Tuple)
tr = simulate(gen_fn.model, args)
return ReplaceProposalGFTrace(tr, gen_fn)
end

function generate(gen_fn::ReplaceProposalGF, args::Tuple, constraints::ChoiceMap)
(proposed_choices, proposal_weight, _) = propose(gen_fn.proposal, (constraints, args...))
all_constraints = merge(proposed_choices, constraints)
new_tr, model_weight = generate(gen_fn.model, args, all_constraints)
@assert isapprox(model_weight, get_score(new_tr))
weight = model_weight - proposal_weight
return (ReplaceProposalGFTrace(new_tr, gen_fn), weight)
end

function regenerate(trace::ReplaceProposalGFTrace, args::Tuple, argdiffs::Tuple, selection::Selection)
gen_fn = get_gen_fn(trace)
prev_args = get_args(trace)

# u <- create choice map u containing addresses from trace, except for those in selection
u = get_selected(get_choices(trace), complement(selection))

# then, run generate with that u to obtain new-trace t', and weight w = p(t'; x') / q(t; x, u')
(new_trace, p_weight) = generate(gen_fn, args, u)

# then, create choice map u' containing addresses from new-trace, except for those in selection
u_backward = get_selected(get_choices(new_trace), complement(selection))

# then, run generate on custom_q to obtain q(t; x, u')
(_, q_weight) = generate(gen_fn.proposal, (u_backward, prev_args...), get_choices(trace)) # NOTE there will be extra choices

# then, use get_score(trace) and subtracct it from the weight
weight = p_weight + q_weight - get_score(trace)

return (new_trace, weight, UnknownChange())
end

function override_internal_proposal(p, q)
return ReplaceProposalGF(p, q)
end

export override_internal_proposal