Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for user-supplied RNG state in all interfaces #520

Open
wants to merge 9 commits into
base: modular-rng
Choose a base branch
from
2 changes: 2 additions & 0 deletions src/Gen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

module Gen

using Random: AbstractRNG, default_rng

"""
load_generated_functions(__module__=Main)

Expand Down
9 changes: 6 additions & 3 deletions src/dynamic/dynamic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,13 @@ accepts_output_grad(gen_fn::DynamicDSLFunction) = gen_fn.accepts_output_grad

mutable struct GFUntracedState
params::Dict{Symbol,Any}
rng::AbstractRNG
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should change all these GF...State structs to be parametric in the type of the RNG, to avoid potential performance regressions due to type instability. To be specific, I would replace this with:

mutable struct GFUntracedState{R <: AbstractRNG}
    params::Dict{Symbol, Any}
    rng::R
end

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

end

function (gen_fn::DynamicDSLFunction)(args...)
state = GFUntracedState(gen_fn.params)
(gen_fn::DynamicDSLFunction)(args...) = gen_fn(default_rng(), args...)

function (gen_fn::DynamicDSLFunction)(rng::AbstractRNG, args...)
state = GFUntracedState(gen_fn.params, rng)
gen_fn.julia_function(state, args...)
end

Expand Down Expand Up @@ -85,7 +88,7 @@ end
gen_fn(args...)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should pass state.rng as the first argument to gen_fn.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


@inline traceat(state::GFUntracedState, dist::Distribution, args, key) =
random(dist, args...)
random(state.rng, dist, args...)

@inline splice(state::GFUntracedState, gen_fn::DynamicDSLFunction, args::Tuple) =
gen_fn(args...)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

splice should also pass state.rng as the first argument to gen_fn.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Expand Down
16 changes: 10 additions & 6 deletions src/dynamic/generate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@ mutable struct GFGenerateState
weight::Float64
visitor::AddressVisitor
params::Dict{Symbol,Any}
rng::AbstractRNG
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As with GFUntracedState, let's replace AbstractRNG with a type parameter. Same for all of the others.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

end

function GFGenerateState(gen_fn, args, constraints, params)
function GFGenerateState(gen_fn, args, constraints, params, rng::AbstractRNG)
trace = DynamicDSLTrace(gen_fn, args)
GFGenerateState(trace, constraints, 0., AddressVisitor(), params)
GFGenerateState(trace, constraints, 0., AddressVisitor(), params, rng)
end

function traceat(state::GFGenerateState, dist::Distribution{T},
Expand All @@ -26,7 +27,7 @@ function traceat(state::GFGenerateState, dist::Distribution{T},
if constrained
retval = get_value(state.constraints, key)
else
retval = random(dist, args...)
retval = random(state.rng, dist, args...)
end

# compute logpdf
Expand Down Expand Up @@ -78,9 +79,12 @@ function splice(state::GFGenerateState, gen_fn::DynamicDSLFunction,
retval
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On line 59, the recursive call to generate needs to pass state.rng to the callee function.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

end

function generate(gen_fn::DynamicDSLFunction, args::Tuple,
constraints::ChoiceMap)
state = GFGenerateState(gen_fn, args, constraints, gen_fn.params)
generate(gen_fn::DynamicDSLFunction, args::Tuple, constraints::ChoiceMap) =
generate(default_rng(), gen_fn, args, constraints)

function generate(rng::AbstractRNG, gen_fn::DynamicDSLFunction, args::Tuple,
constraints::ChoiceMap)
state = GFGenerateState(gen_fn, args, constraints, gen_fn.params, rng)
retval = exec(gen_fn, state, args)
set_retval!(state.trace, retval)
(state.trace, state.weight)
Expand Down
11 changes: 6 additions & 5 deletions src/dynamic/propose.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@ mutable struct GFProposeState
weight::Float64
visitor::AddressVisitor
params::Dict{Symbol,Any}
rng::AbstractRNG
end

function GFProposeState(params::Dict{Symbol,Any})
GFProposeState(choicemap(), 0., AddressVisitor(), params)
function GFProposeState(params::Dict{Symbol,Any}, rng::AbstractRNG)
GFProposeState(choicemap(), 0., AddressVisitor(), params, rng)
end

function traceat(state::GFProposeState, dist::Distribution{T},
Expand All @@ -17,7 +18,7 @@ function traceat(state::GFProposeState, dist::Distribution{T},
visit!(state.visitor, key)

# sample return value
retval = random(dist, args...)
retval = random(state.rng, dist, args...)

# update assignment
set_value!(state.choices, key, retval)
Expand Down Expand Up @@ -55,8 +56,8 @@ function splice(state::GFProposeState, gen_fn::DynamicDSLFunction, args::Tuple)
retval
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On line 40, state.rng needs to be passed to the recursive call to propose.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

end

function propose(gen_fn::DynamicDSLFunction, args::Tuple)
state = GFProposeState(gen_fn.params)
function propose(rng::AbstractRNG, gen_fn::DynamicDSLFunction, args::Tuple)
state = GFProposeState(gen_fn.params, rng)
retval = exec(gen_fn, state, args)
(state.choices, state.weight, retval)
end
15 changes: 8 additions & 7 deletions src/dynamic/regenerate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@ mutable struct GFRegenerateState
weight::Float64
visitor::AddressVisitor
params::Dict{Symbol,Any}
rng::AbstractRNG
end

function GFRegenerateState(gen_fn, args, prev_trace,
selection, params)
selection, params, rng::AbstractRNG)
visitor = AddressVisitor()
GFRegenerateState(prev_trace, DynamicDSLTrace(gen_fn, args), selection,
0., visitor, params)
0., visitor, params, rng)
end

function traceat(state::GFRegenerateState, dist::Distribution{T},
Expand All @@ -35,11 +36,11 @@ function traceat(state::GFRegenerateState, dist::Distribution{T},

# get return value
if has_previous && in_selection
retval = random(dist, args...)
retval = random(state.rng, dist, args...)
elseif has_previous
retval = prev_retval
else
retval = random(dist, args...)
retval = random(state.rng, dist, args...)
end

# compute logpdf
Expand Down Expand Up @@ -130,10 +131,10 @@ function regenerate_delete_recurse(prev_trie::Trie{Any,ChoiceOrCallRecord},
noise
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On lines 78 and 81, state.rng needs to be passed to the calls to regenerate and generate respectively.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

end

function regenerate(trace::DynamicDSLTrace, args::Tuple, argdiffs::Tuple,
selection::Selection)
function regenerate(rng::AbstractRNG, trace::DynamicDSLTrace, args::Tuple,
argdiffs::Tuple, selection::Selection)
gen_fn = trace.gen_fn
state = GFRegenerateState(gen_fn, args, trace, selection, gen_fn.params)
state = GFRegenerateState(gen_fn, args, trace, selection, gen_fn.params, rng)
retval = exec(gen_fn, state, args)
set_retval!(state.trace, retval)
visited = state.visitor.visited
Expand Down
11 changes: 6 additions & 5 deletions src/dynamic/simulate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@ mutable struct GFSimulateState
trace::DynamicDSLTrace
visitor::AddressVisitor
params::Dict{Symbol,Any}
rng::AbstractRNG
end

function GFSimulateState(gen_fn::GenerativeFunction, args::Tuple, params)
function GFSimulateState(gen_fn::GenerativeFunction, args::Tuple, params, rng::AbstractRNG)
trace = DynamicDSLTrace(gen_fn, args)
GFSimulateState(trace, AddressVisitor(), params)
GFSimulateState(trace, AddressVisitor(), params, rng)
end

function traceat(state::GFSimulateState, dist::Distribution{T},
Expand All @@ -16,7 +17,7 @@ function traceat(state::GFSimulateState, dist::Distribution{T},
# check that key was not already visited, and mark it as visited
visit!(state.visitor, key)

retval = random(dist, args...)
retval = random(state.rng, dist, args...)

# compute logpdf
score = logpdf(dist, retval, args...)
Expand Down Expand Up @@ -56,8 +57,8 @@ function splice(state::GFSimulateState, gen_fn::DynamicDSLFunction,
retval
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On line 40, state.rng needs to be passed to the call to simulate.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

end

function simulate(gen_fn::DynamicDSLFunction, args::Tuple)
state = GFSimulateState(gen_fn, args, gen_fn.params)
function simulate(rng::AbstractRNG, gen_fn::DynamicDSLFunction, args::Tuple)
state = GFSimulateState(gen_fn, args, gen_fn.params, rng)
retval = exec(gen_fn, state, args)
set_retval!(state.trace, retval)
state.trace
Expand Down
11 changes: 6 additions & 5 deletions src/dynamic/update.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@ mutable struct GFUpdateState
visitor::AddressVisitor
params::Dict{Symbol,Any}
discard::DynamicChoiceMap
rng::AbstractRNG
end

function GFUpdateState(gen_fn, args, prev_trace, constraints, params)
function GFUpdateState(gen_fn, args, prev_trace, constraints, params, rng::AbstractRNG)
visitor = AddressVisitor()
discard = choicemap()
trace = DynamicDSLTrace(gen_fn, args)
GFUpdateState(prev_trace, trace, constraints,
0., visitor, params, discard)
0., visitor, params, discard, rng)
end

function traceat(state::GFUpdateState, dist::Distribution{T},
Expand Down Expand Up @@ -48,7 +49,7 @@ function traceat(state::GFUpdateState, dist::Distribution{T},
elseif has_previous
retval = prev_retval
else
retval = random(dist, args...)
retval = random(state.rng, dist, args...)
end

# compute logpdf
Expand Down Expand Up @@ -184,10 +185,10 @@ function add_unvisited_to_discard!(discard::DynamicChoiceMap,
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On lines 91 and 94, state.rng needs to be passed to the recursive calls to update and generate respectively.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

end

function update(trace::DynamicDSLTrace, arg_values::Tuple, arg_diffs::Tuple,
function update(rng::AbstractRNG, trace::DynamicDSLTrace, arg_values::Tuple, arg_diffs::Tuple,
constraints::ChoiceMap)
gen_fn = trace.gen_fn
state = GFUpdateState(gen_fn, arg_values, trace, constraints, gen_fn.params)
state = GFUpdateState(gen_fn, arg_values, trace, constraints, gen_fn.params, rng)
retval = exec(gen_fn, state, arg_values)
set_retval!(state.trace, retval)
visited = get_visited(state.visitor)
Expand Down
60 changes: 39 additions & 21 deletions src/gen_fn_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ Return an iterable over the trainable parameters of the generative function.
get_params(::GenerativeFunction) = ()

"""
trace = simulate(gen_fn, args)
trace = simulate([rng::AbstractRNG], gen_fn, args)

Execute the generative function and return the trace.

Expand All @@ -145,16 +145,18 @@ If `gen_fn` has optional trailing arguments (i.e., default values are provided),
the optional arguments can be omitted from the `args` tuple. The generated trace
will have default values filled in.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Either the docstrings or the docs should be updated to explain that the user can specify a custom RNG, and that Random.default_rng() will be used by default.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"""
function simulate(::GenerativeFunction, ::Tuple)
function simulate(::AbstractRNG, ::GenerativeFunction, ::Tuple)
error("Not implemented")
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently, a fair number of projects that rely on Gen.jl assume that the signature of simulate etc. does not have an AbstractRNG as the first argument. However, as I've noted in my above comments, we should be recursively calling simulate etc. in by passing the RNG down to nested generative functions. This will break those other projects, some of which define custom generative functions which do not yet support specification of the RNG.

To prevent this breaking change, and propagate this change in the Generative Function Interface more gracefully to the rest of the Gen ecosystem, I would replace the above code with a fallback to simulate without the AbstractRNG argument, along with a warning that reminds library developers that they should add such an implementation to their library:

function simulate(rng::AbstractRNG, gen_fn::GenerativeFunction, args::Tuple)
    @warn "Missing concrete implementation of `simulate(::AbstractRNG, ::$(typeof(gen_fn)), ::Tuple), `" *
                "falling back to `simulate(::$(typeof(gen_fn)), ::Tuple)`."
    return simulate(gen_fn, args)
end

These warnings and fallbacks should also be implemented for all the other GFI functions that require an AbstractRNG.

This change creates the possibility of infinite recursion / stack overflows if someone forgets to define both versions of simulate etc., but I think it is the price we'll have to pay to maintain backwards compatibility. Once the GFI changes are propagated across enough of the Gen.jl ecosystem, then I think we can remove this fallback in the next breaking release.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


simulate(gen_fn::GenerativeFunction, args::Tuple) = simulate(default_rng(), gen_fn, args)

"""
(trace::U, weight) = generate(gen_fn::GenerativeFunction{T,U}, args::Tuple)
(trace::U, weight) = generate([rng::AbstractRNG], gen_fn::GenerativeFunction{T,U}, args::Tuple)

Return a trace of a generative function.

(trace::U, weight) = generate(gen_fn::GenerativeFunction{T,U}, args::Tuple,
(trace::U, weight) = generate(rng, gen_fn::GenerativeFunction{T,U}, args::Tuple,
constraints::ChoiceMap)

Return a trace of a generative function that is consistent with the given
Expand All @@ -181,14 +183,18 @@ Example with constraint that address `:z` takes value `true`.
(trace, weight) = generate(foo, (2, 4), choicemap((:z, true))
```
"""
function generate(::GenerativeFunction, ::Tuple, ::ChoiceMap)
function generate(::AbstractRNG, ::GenerativeFunction, ::Tuple, ::ChoiceMap)
error("Not implemented")
end

function generate(gen_fn::GenerativeFunction, args::Tuple)
generate(gen_fn, args, EmptyChoiceMap())
generate(gen_fn::GenerativeFunction, args::Tuple, choices::ChoiceMap) = generate(default_rng(), gen_fn, args, choices)

function generate(rng::AbstractRNG, gen_fn::GenerativeFunction, args::Tuple)
generate(rng, gen_fn, args, EmptyChoiceMap())
end

generate(gen_fn::GenerativeFunction, args::Tuple) = generate(default_rng(), gen_fn, args)

"""
weight = project(trace::U, selection::Selection)

Expand All @@ -207,7 +213,7 @@ function project(trace, selection::Selection)
end

"""
(choices, weight, retval) = propose(gen_fn::GenerativeFunction, args::Tuple)
(choices, weight, retval) = propose([rng::AbstractRNG], gen_fn::GenerativeFunction, args::Tuple)

Sample an assignment and compute the probability of proposing that assignment.

Expand All @@ -218,12 +224,14 @@ t)\$, and return \$t\$
\\log \\frac{p(r, t; x)}{q(r; x, t)}
```
"""
function propose(gen_fn::GenerativeFunction, args::Tuple)
trace = simulate(gen_fn, args)
function propose(rng::AbstractRNG, gen_fn::GenerativeFunction, args::Tuple)
trace = simulate(rng, gen_fn, args)
weight = get_score(trace)
(get_choices(trace), weight, get_retval(trace))
end

propose(gen_fn::GenerativeFunction, args::Tuple) = propose(default_rng(), gen_fn, args)

"""
(weight, retval) = assess(gen_fn::GenerativeFunction, args::Tuple, choices::ChoiceMap)

Expand All @@ -243,8 +251,8 @@ function assess(gen_fn::GenerativeFunction, args::Tuple, choices::ChoiceMap)
end

"""
(new_trace, weight, retdiff, discard) = update(trace, args::Tuple, argdiffs::Tuple,
constraints::ChoiceMap)
(new_trace, weight, retdiff, discard) = update([rng::AbstractRNG], trace, args::Tuple,
argdiffs::Tuple, constraints::ChoiceMap)

Update a trace by changing the arguments and/or providing new values for some
existing random choice(s) and values for some newly introduced random choice(s).
Expand Down Expand Up @@ -272,25 +280,30 @@ that if the original `trace` was generated using non-default argument values,
then for each optional argument that is omitted, the old value will be
over-written by the default argument value in the updated trace.
"""
function update(trace, args::Tuple, argdiffs::Tuple, ::ChoiceMap)
function update(::AbstractRNG, trace, ::Tuple, ::Tuple, ::ChoiceMap)
error("Not implemented")
end

update(trace, args::Tuple, argdiffs::Tuple, choices::ChoiceMap) =
update(default_rng(), trace, args, argdiffs, choices)

"""
(new_trace, weight, retdiff, discard) = update(trace, constraints::ChoiceMap)
(new_trace, weight, retdiff, discard) = update([rng::AbstractRNG], trace, constraints::ChoiceMap)

Shorthand variant of
[`update`](@ref update(::Any, ::Tuple, ::Tuple, ::ChoiceMap))
which assumes the arguments are unchanged.
"""
function update(trace, constraints::ChoiceMap)
function update(rng::AbstractRNG, trace, constraints::ChoiceMap)
args = get_args(trace)
argdiffs = Tuple(NoChange() for _ in args)
return update(trace, args, argdiffs, constraints)
return update(rng, trace, args, argdiffs, constraints)
end

update(trace, constraints::ChoiceMap) = update(default_rng(), trace, constraints)

"""
(new_trace, weight, retdiff) = regenerate(trace, args::Tuple, argdiffs::Tuple,
(new_trace, weight, retdiff) = regenerate([rng::AbstractRNG], trace, args::Tuple, argdiffs::Tuple,
selection::Selection)

Update a trace by changing the arguments and/or randomly sampling new values
Expand All @@ -317,23 +330,28 @@ that if the original `trace` was generated using non-default argument values,
then for each optional argument that is omitted, the old value will be
over-written by the default argument value in the regenerated trace.
"""
function regenerate(trace, args::Tuple, argdiffs::Tuple, selection::Selection)
function regenerate(::AbstractRNG, trace, ::Tuple, ::Tuple, ::Selection)
error("Not implemented")
end

regenerate(trace, args::Tuple, argdiffs::Tuple, selection::Selection) =
regenerate(default_rng(), trace, args, argdiffs, selection)

"""
(new_trace, weight, retdiff) = regenerate(trace, selection::Selection)
(new_trace, weight, retdiff) = regenerate([rng::AbstractRNG], trace, selection::Selection)

Shorthand variant of
[`regenerate`](@ref regenerate(::Any, ::Tuple, ::Tuple, ::Selection))
which assumes the arguments are unchanged.
"""
function regenerate(trace, selection::Selection)
function regenerate(rng::AbstractRNG, trace, selection::Selection)
args = get_args(trace)
argdiffs = Tuple(NoChange() for _ in args)
return regenerate(trace, args, argdiffs, selection)
return regenerate(rng, trace, args, argdiffs, selection)
end

regenerate(trace, selection::Selection) = regenerate(default_rng(), trace, selection)

"""
arg_grads = accumulate_param_gradients!(trace, retgrad=nothing, scale_factor=1.)

Expand Down
Loading
Loading