-
Notifications
You must be signed in to change notification settings - Fork 162
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for user-supplied RNG state in all interfaces #520
base: modular-rng
Are you sure you want to change the base?
Changes from 4 commits
9da5a4d
db61b28
e39459b
1e9d81d
8ad71dd
cde7265
5720acb
0132f5f
0539d93
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,6 +2,8 @@ | |
|
||
module Gen | ||
|
||
using Random: AbstractRNG, default_rng | ||
|
||
""" | ||
load_generated_functions(__module__=Main) | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -47,10 +47,13 @@ accepts_output_grad(gen_fn::DynamicDSLFunction) = gen_fn.accepts_output_grad | |
|
||
mutable struct GFUntracedState | ||
params::Dict{Symbol,Any} | ||
rng::AbstractRNG | ||
end | ||
|
||
function (gen_fn::DynamicDSLFunction)(args...) | ||
state = GFUntracedState(gen_fn.params) | ||
(gen_fn::DynamicDSLFunction)(args...) = gen_fn(default_rng(), args...) | ||
|
||
function (gen_fn::DynamicDSLFunction)(rng::AbstractRNG, args...) | ||
state = GFUntracedState(gen_fn.params, rng) | ||
gen_fn.julia_function(state, args...) | ||
end | ||
|
||
|
@@ -85,7 +88,7 @@ end | |
gen_fn(args...) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This should pass There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
||
@inline traceat(state::GFUntracedState, dist::Distribution, args, key) = | ||
random(dist, args...) | ||
random(state.rng, dist, args...) | ||
|
||
@inline splice(state::GFUntracedState, gen_fn::DynamicDSLFunction, args::Tuple) = | ||
gen_fn(args...) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,11 +4,12 @@ mutable struct GFGenerateState | |
weight::Float64 | ||
visitor::AddressVisitor | ||
params::Dict{Symbol,Any} | ||
rng::AbstractRNG | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
end | ||
|
||
function GFGenerateState(gen_fn, args, constraints, params) | ||
function GFGenerateState(gen_fn, args, constraints, params, rng::AbstractRNG) | ||
trace = DynamicDSLTrace(gen_fn, args) | ||
GFGenerateState(trace, constraints, 0., AddressVisitor(), params) | ||
GFGenerateState(trace, constraints, 0., AddressVisitor(), params, rng) | ||
end | ||
|
||
function traceat(state::GFGenerateState, dist::Distribution{T}, | ||
|
@@ -26,7 +27,7 @@ function traceat(state::GFGenerateState, dist::Distribution{T}, | |
if constrained | ||
retval = get_value(state.constraints, key) | ||
else | ||
retval = random(dist, args...) | ||
retval = random(state.rng, dist, args...) | ||
end | ||
|
||
# compute logpdf | ||
|
@@ -78,9 +79,12 @@ function splice(state::GFGenerateState, gen_fn::DynamicDSLFunction, | |
retval | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. On line 59, the recursive call to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
end | ||
|
||
function generate(gen_fn::DynamicDSLFunction, args::Tuple, | ||
constraints::ChoiceMap) | ||
state = GFGenerateState(gen_fn, args, constraints, gen_fn.params) | ||
generate(gen_fn::DynamicDSLFunction, args::Tuple, constraints::ChoiceMap) = | ||
generate(default_rng(), gen_fn, args, constraints) | ||
|
||
function generate(rng::AbstractRNG, gen_fn::DynamicDSLFunction, args::Tuple, | ||
constraints::ChoiceMap) | ||
state = GFGenerateState(gen_fn, args, constraints, gen_fn.params, rng) | ||
retval = exec(gen_fn, state, args) | ||
set_retval!(state.trace, retval) | ||
(state.trace, state.weight) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,10 +3,11 @@ mutable struct GFProposeState | |
weight::Float64 | ||
visitor::AddressVisitor | ||
params::Dict{Symbol,Any} | ||
rng::AbstractRNG | ||
end | ||
|
||
function GFProposeState(params::Dict{Symbol,Any}) | ||
GFProposeState(choicemap(), 0., AddressVisitor(), params) | ||
function GFProposeState(params::Dict{Symbol,Any}, rng::AbstractRNG) | ||
GFProposeState(choicemap(), 0., AddressVisitor(), params, rng) | ||
end | ||
|
||
function traceat(state::GFProposeState, dist::Distribution{T}, | ||
|
@@ -17,7 +18,7 @@ function traceat(state::GFProposeState, dist::Distribution{T}, | |
visit!(state.visitor, key) | ||
|
||
# sample return value | ||
retval = random(dist, args...) | ||
retval = random(state.rng, dist, args...) | ||
|
||
# update assignment | ||
set_value!(state.choices, key, retval) | ||
|
@@ -55,8 +56,8 @@ function splice(state::GFProposeState, gen_fn::DynamicDSLFunction, args::Tuple) | |
retval | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. On line 40, There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
end | ||
|
||
function propose(gen_fn::DynamicDSLFunction, args::Tuple) | ||
state = GFProposeState(gen_fn.params) | ||
function propose(rng::AbstractRNG, gen_fn::DynamicDSLFunction, args::Tuple) | ||
state = GFProposeState(gen_fn.params, rng) | ||
retval = exec(gen_fn, state, args) | ||
(state.choices, state.weight, retval) | ||
end |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,13 +5,14 @@ mutable struct GFRegenerateState | |
weight::Float64 | ||
visitor::AddressVisitor | ||
params::Dict{Symbol,Any} | ||
rng::AbstractRNG | ||
end | ||
|
||
function GFRegenerateState(gen_fn, args, prev_trace, | ||
selection, params) | ||
selection, params, rng::AbstractRNG) | ||
visitor = AddressVisitor() | ||
GFRegenerateState(prev_trace, DynamicDSLTrace(gen_fn, args), selection, | ||
0., visitor, params) | ||
0., visitor, params, rng) | ||
end | ||
|
||
function traceat(state::GFRegenerateState, dist::Distribution{T}, | ||
|
@@ -35,11 +36,11 @@ function traceat(state::GFRegenerateState, dist::Distribution{T}, | |
|
||
# get return value | ||
if has_previous && in_selection | ||
retval = random(dist, args...) | ||
retval = random(state.rng, dist, args...) | ||
elseif has_previous | ||
retval = prev_retval | ||
else | ||
retval = random(dist, args...) | ||
retval = random(state.rng, dist, args...) | ||
end | ||
|
||
# compute logpdf | ||
|
@@ -130,10 +131,10 @@ function regenerate_delete_recurse(prev_trie::Trie{Any,ChoiceOrCallRecord}, | |
noise | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. On lines 78 and 81, There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
end | ||
|
||
function regenerate(trace::DynamicDSLTrace, args::Tuple, argdiffs::Tuple, | ||
selection::Selection) | ||
function regenerate(rng::AbstractRNG, trace::DynamicDSLTrace, args::Tuple, | ||
argdiffs::Tuple, selection::Selection) | ||
gen_fn = trace.gen_fn | ||
state = GFRegenerateState(gen_fn, args, trace, selection, gen_fn.params) | ||
state = GFRegenerateState(gen_fn, args, trace, selection, gen_fn.params, rng) | ||
retval = exec(gen_fn, state, args) | ||
set_retval!(state.trace, retval) | ||
visited = state.visitor.visited | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,11 +2,12 @@ mutable struct GFSimulateState | |
trace::DynamicDSLTrace | ||
visitor::AddressVisitor | ||
params::Dict{Symbol,Any} | ||
rng::AbstractRNG | ||
end | ||
|
||
function GFSimulateState(gen_fn::GenerativeFunction, args::Tuple, params) | ||
function GFSimulateState(gen_fn::GenerativeFunction, args::Tuple, params, rng::AbstractRNG) | ||
trace = DynamicDSLTrace(gen_fn, args) | ||
GFSimulateState(trace, AddressVisitor(), params) | ||
GFSimulateState(trace, AddressVisitor(), params, rng) | ||
end | ||
|
||
function traceat(state::GFSimulateState, dist::Distribution{T}, | ||
|
@@ -16,7 +17,7 @@ function traceat(state::GFSimulateState, dist::Distribution{T}, | |
# check that key was not already visited, and mark it as visited | ||
visit!(state.visitor, key) | ||
|
||
retval = random(dist, args...) | ||
retval = random(state.rng, dist, args...) | ||
|
||
# compute logpdf | ||
score = logpdf(dist, retval, args...) | ||
|
@@ -56,8 +57,8 @@ function splice(state::GFSimulateState, gen_fn::DynamicDSLFunction, | |
retval | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. On line 40, There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
end | ||
|
||
function simulate(gen_fn::DynamicDSLFunction, args::Tuple) | ||
state = GFSimulateState(gen_fn, args, gen_fn.params) | ||
function simulate(rng::AbstractRNG, gen_fn::DynamicDSLFunction, args::Tuple) | ||
state = GFSimulateState(gen_fn, args, gen_fn.params, rng) | ||
retval = exec(gen_fn, state, args) | ||
set_retval!(state.trace, retval) | ||
state.trace | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,14 +6,15 @@ mutable struct GFUpdateState | |
visitor::AddressVisitor | ||
params::Dict{Symbol,Any} | ||
discard::DynamicChoiceMap | ||
rng::AbstractRNG | ||
end | ||
|
||
function GFUpdateState(gen_fn, args, prev_trace, constraints, params) | ||
function GFUpdateState(gen_fn, args, prev_trace, constraints, params, rng::AbstractRNG) | ||
visitor = AddressVisitor() | ||
discard = choicemap() | ||
trace = DynamicDSLTrace(gen_fn, args) | ||
GFUpdateState(prev_trace, trace, constraints, | ||
0., visitor, params, discard) | ||
0., visitor, params, discard, rng) | ||
end | ||
|
||
function traceat(state::GFUpdateState, dist::Distribution{T}, | ||
|
@@ -48,7 +49,7 @@ function traceat(state::GFUpdateState, dist::Distribution{T}, | |
elseif has_previous | ||
retval = prev_retval | ||
else | ||
retval = random(dist, args...) | ||
retval = random(state.rng, dist, args...) | ||
end | ||
|
||
# compute logpdf | ||
|
@@ -184,10 +185,10 @@ function add_unvisited_to_discard!(discard::DynamicChoiceMap, | |
end | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. On lines 91 and 94, There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
end | ||
|
||
function update(trace::DynamicDSLTrace, arg_values::Tuple, arg_diffs::Tuple, | ||
function update(rng::AbstractRNG, trace::DynamicDSLTrace, arg_values::Tuple, arg_diffs::Tuple, | ||
constraints::ChoiceMap) | ||
gen_fn = trace.gen_fn | ||
state = GFUpdateState(gen_fn, arg_values, trace, constraints, gen_fn.params) | ||
state = GFUpdateState(gen_fn, arg_values, trace, constraints, gen_fn.params, rng) | ||
retval = exec(gen_fn, state, arg_values) | ||
set_retval!(state.trace, retval) | ||
visited = get_visited(state.visitor) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -135,7 +135,7 @@ Return an iterable over the trainable parameters of the generative function. | |
get_params(::GenerativeFunction) = () | ||
|
||
""" | ||
trace = simulate(gen_fn, args) | ||
trace = simulate([rng::AbstractRNG], gen_fn, args) | ||
|
||
Execute the generative function and return the trace. | ||
|
||
|
@@ -145,16 +145,18 @@ If `gen_fn` has optional trailing arguments (i.e., default values are provided), | |
the optional arguments can be omitted from the `args` tuple. The generated trace | ||
will have default values filled in. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Either the docstrings or the docs should be updated to explain that the user can specify a custom There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
""" | ||
function simulate(::GenerativeFunction, ::Tuple) | ||
function simulate(::AbstractRNG, ::GenerativeFunction, ::Tuple) | ||
error("Not implemented") | ||
end | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Currently, a fair number of projects that rely on Gen.jl assume that the signature of To prevent this breaking change, and propagate this change in the Generative Function Interface more gracefully to the rest of the Gen ecosystem, I would replace the above code with a fallback to function simulate(rng::AbstractRNG, gen_fn::GenerativeFunction, args::Tuple)
@warn "Missing concrete implementation of `simulate(::AbstractRNG, ::$(typeof(gen_fn)), ::Tuple), `" *
"falling back to `simulate(::$(typeof(gen_fn)), ::Tuple)`."
return simulate(gen_fn, args)
end These warnings and fallbacks should also be implemented for all the other GFI functions that require an This change creates the possibility of infinite recursion / stack overflows if someone forgets to define both versions of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
||
simulate(gen_fn::GenerativeFunction, args::Tuple) = simulate(default_rng(), gen_fn, args) | ||
|
||
""" | ||
(trace::U, weight) = generate(gen_fn::GenerativeFunction{T,U}, args::Tuple) | ||
(trace::U, weight) = generate([rng::AbstractRNG], gen_fn::GenerativeFunction{T,U}, args::Tuple) | ||
|
||
Return a trace of a generative function. | ||
|
||
(trace::U, weight) = generate(gen_fn::GenerativeFunction{T,U}, args::Tuple, | ||
(trace::U, weight) = generate(rng, gen_fn::GenerativeFunction{T,U}, args::Tuple, | ||
constraints::ChoiceMap) | ||
|
||
Return a trace of a generative function that is consistent with the given | ||
|
@@ -181,14 +183,18 @@ Example with constraint that address `:z` takes value `true`. | |
(trace, weight) = generate(foo, (2, 4), choicemap((:z, true)) | ||
``` | ||
""" | ||
function generate(::GenerativeFunction, ::Tuple, ::ChoiceMap) | ||
function generate(::AbstractRNG, ::GenerativeFunction, ::Tuple, ::ChoiceMap) | ||
error("Not implemented") | ||
end | ||
|
||
function generate(gen_fn::GenerativeFunction, args::Tuple) | ||
generate(gen_fn, args, EmptyChoiceMap()) | ||
generate(gen_fn::GenerativeFunction, args::Tuple, choices::ChoiceMap) = generate(default_rng(), gen_fn, args, choices) | ||
|
||
function generate(rng::AbstractRNG, gen_fn::GenerativeFunction, args::Tuple) | ||
generate(rng, gen_fn, args, EmptyChoiceMap()) | ||
end | ||
|
||
generate(gen_fn::GenerativeFunction, args::Tuple) = generate(default_rng(), gen_fn, args) | ||
|
||
""" | ||
weight = project(trace::U, selection::Selection) | ||
|
||
|
@@ -207,7 +213,7 @@ function project(trace, selection::Selection) | |
end | ||
|
||
""" | ||
(choices, weight, retval) = propose(gen_fn::GenerativeFunction, args::Tuple) | ||
(choices, weight, retval) = propose([rng::AbstractRNG], gen_fn::GenerativeFunction, args::Tuple) | ||
|
||
Sample an assignment and compute the probability of proposing that assignment. | ||
|
||
|
@@ -218,12 +224,14 @@ t)\$, and return \$t\$ | |
\\log \\frac{p(r, t; x)}{q(r; x, t)} | ||
``` | ||
""" | ||
function propose(gen_fn::GenerativeFunction, args::Tuple) | ||
trace = simulate(gen_fn, args) | ||
function propose(rng::AbstractRNG, gen_fn::GenerativeFunction, args::Tuple) | ||
trace = simulate(rng, gen_fn, args) | ||
weight = get_score(trace) | ||
(get_choices(trace), weight, get_retval(trace)) | ||
end | ||
|
||
propose(gen_fn::GenerativeFunction, args::Tuple) = propose(default_rng(), gen_fn, args) | ||
|
||
""" | ||
(weight, retval) = assess(gen_fn::GenerativeFunction, args::Tuple, choices::ChoiceMap) | ||
|
||
|
@@ -243,8 +251,8 @@ function assess(gen_fn::GenerativeFunction, args::Tuple, choices::ChoiceMap) | |
end | ||
|
||
""" | ||
(new_trace, weight, retdiff, discard) = update(trace, args::Tuple, argdiffs::Tuple, | ||
constraints::ChoiceMap) | ||
(new_trace, weight, retdiff, discard) = update([rng::AbstractRNG], trace, args::Tuple, | ||
argdiffs::Tuple, constraints::ChoiceMap) | ||
|
||
Update a trace by changing the arguments and/or providing new values for some | ||
existing random choice(s) and values for some newly introduced random choice(s). | ||
|
@@ -272,25 +280,30 @@ that if the original `trace` was generated using non-default argument values, | |
then for each optional argument that is omitted, the old value will be | ||
over-written by the default argument value in the updated trace. | ||
""" | ||
function update(trace, args::Tuple, argdiffs::Tuple, ::ChoiceMap) | ||
function update(::AbstractRNG, trace, ::Tuple, ::Tuple, ::ChoiceMap) | ||
error("Not implemented") | ||
end | ||
|
||
update(trace, args::Tuple, argdiffs::Tuple, choices::ChoiceMap) = | ||
update(default_rng(), trace, args, argdiffs, choices) | ||
|
||
""" | ||
(new_trace, weight, retdiff, discard) = update(trace, constraints::ChoiceMap) | ||
(new_trace, weight, retdiff, discard) = update([rng::AbstractRNG], trace, constraints::ChoiceMap) | ||
|
||
Shorthand variant of | ||
[`update`](@ref update(::Any, ::Tuple, ::Tuple, ::ChoiceMap)) | ||
which assumes the arguments are unchanged. | ||
""" | ||
function update(trace, constraints::ChoiceMap) | ||
function update(rng::AbstractRNG, trace, constraints::ChoiceMap) | ||
args = get_args(trace) | ||
argdiffs = Tuple(NoChange() for _ in args) | ||
return update(trace, args, argdiffs, constraints) | ||
return update(rng, trace, args, argdiffs, constraints) | ||
end | ||
|
||
update(trace, constraints::ChoiceMap) = update(default_rng(), trace, constraints) | ||
|
||
""" | ||
(new_trace, weight, retdiff) = regenerate(trace, args::Tuple, argdiffs::Tuple, | ||
(new_trace, weight, retdiff) = regenerate([rng::AbstractRNG], trace, args::Tuple, argdiffs::Tuple, | ||
selection::Selection) | ||
|
||
Update a trace by changing the arguments and/or randomly sampling new values | ||
|
@@ -317,23 +330,28 @@ that if the original `trace` was generated using non-default argument values, | |
then for each optional argument that is omitted, the old value will be | ||
over-written by the default argument value in the regenerated trace. | ||
""" | ||
function regenerate(trace, args::Tuple, argdiffs::Tuple, selection::Selection) | ||
function regenerate(::AbstractRNG, trace, ::Tuple, ::Tuple, ::Selection) | ||
error("Not implemented") | ||
end | ||
|
||
regenerate(trace, args::Tuple, argdiffs::Tuple, selection::Selection) = | ||
regenerate(default_rng(), trace, args, argdiffs, selection) | ||
|
||
""" | ||
(new_trace, weight, retdiff) = regenerate(trace, selection::Selection) | ||
(new_trace, weight, retdiff) = regenerate([rng::AbstractRNG], trace, selection::Selection) | ||
|
||
Shorthand variant of | ||
[`regenerate`](@ref regenerate(::Any, ::Tuple, ::Tuple, ::Selection)) | ||
which assumes the arguments are unchanged. | ||
""" | ||
function regenerate(trace, selection::Selection) | ||
function regenerate(rng::AbstractRNG, trace, selection::Selection) | ||
args = get_args(trace) | ||
argdiffs = Tuple(NoChange() for _ in args) | ||
return regenerate(trace, args, argdiffs, selection) | ||
return regenerate(rng, trace, args, argdiffs, selection) | ||
end | ||
|
||
regenerate(trace, selection::Selection) = regenerate(default_rng(), trace, selection) | ||
|
||
""" | ||
arg_grads = accumulate_param_gradients!(trace, retgrad=nothing, scale_factor=1.) | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we should change all these
GF...State
structs to be parametric in the type of the RNG, to avoid potential performance regressions due to type instability. To be specific, I would replace this with:There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed in Add RNG type parameter to GF state types