(Ready for review): Switch combinator #334
Conversation
My hope is that a few people can take a look at the trace and the sketch up for review.
My proposal will follow the outline of the Switch combinator expressed in Trace types and denotational semantics. Note that the following restrictions apply:
state::SwitchGenerateState{T1, T2, Tr}) where {T1, T2, Tr}
# create flip distribution
flip_d = bernoulli(branch_p)
@femtomc This samples from a Bernoulli distribution with probability branch_p and returns a Bool.
Apologies - that comment needs to be removed. I figured that bit out and corrected it elsewhere, but the comments indicate otherwise.
@femtomc This is exciting! A few notes below.
I thought more about the different combinators in this space. I think the combinator that takes a vector of generative functions and returns a generative function that accepts an integer argument that switches between them (I'll call this Switch here) is a natural starting point. Of course, there are many possibilities:
From the POPL paper, the main benefit of... In general, I think the number of combinators can grow, and it will take some experience using them to understand which combinators are most useful, and whether some tweaks would be useful. Some specific notes about the combinator you've started to implement:
By trace type you mean the set of possible choice maps it samples, right? I think this is overly restrictive. We want to be able to handle cases where the two branches make random choices at different addresses.
I think that's reasonable. Also, I think it also makes sense to require that the return types of the two generative functions are the same. Together those two restrictions will make things like AD more straightforward.
This combinator manifests some interesting design choices relating to the probabilistic semantics of generative functions. There are some design choices to be made for the addressing scheme. For example, the Bernoulli random choice needs to have an address of its own. Also, do addresses sampled in the two branches live in the same namespace or different namespaces? If they live in the same namespace, then the semantics of update and regenerate mean that these operations will need to automatically copy data from one branch's trace to another. The version where the two branches have different namespaces will be easier to implement, and I think that probably makes sense for a first version.
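(To make the two addressing schemes above concrete - a hypothetical sketch; the :which_branch address and the :branch_1 prefix are placeholders, not part of this PR:)
using Gen

# Shared namespace: both branches sample at :y, alongside the internal branch
# choice, so switching branches forces update/regenerate to copy or re-score :y.
shared_layout = choicemap((:which_branch, true), (:y, 1.3))

# Separate namespaces: each branch's choices live under a distinct prefix,
# so the two branches can never collide.
separate_layout = choicemap((:which_branch, true), (:branch_1 => :y, 1.3))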
No, I actually mean the subtype of Trace. Sorry, trace type is now overloaded - I should say inheritor of Trace. Specifically, all I'm saying is that you have to use the same generative function type in each branch. Do you think that is too conservative? Edit: generative function type.
Interesting -- why do we need a combinator to branch between applications of the same generative function? Can't we branch directly in the model in that case?
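(The snippet referenced above was not preserved; a minimal sketch of the kind of model-level branching being asked about, assuming the dynamic DSL - component and model are made-up names:)
using Gen

@gen function component(mu, sigma)
    y ~ normal(mu, sigma)
    return y
end

@gen function model(p)
    branch ~ bernoulli(p)
    # branch between applications of the same generative function directly in the model
    result ~ component(0.0, branch ? 1.0 : 10.0)
    return result
end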
@alex-lew Oops - generative function type. Corrected.
@femtomc Yes, I think it would be better to allow the subtraces to have any Julia type that is a subtype of Trace.
Presumably it's all right if this is technically a Union type :-) Or at least, we may need to handle the case where both GFs return union types, and it doesn't seem to be more challenging to handle the case where they return different things? But I might not be thinking through the subtleties with respect to AD.
@marcoct I've started thinking about the difference between (1 and 2) and (3 and 4) a bit. In practice, is there any difference when extracting the internal random choice made by the combinators described by (1, 3) into a choice preceding (2, 4)? Neither set of implementations appears to restrict the modeller. I also couldn't think of anything from an argdiff perspective. It seems mostly a matter of convenience (because we can isolate the "which branch" choice inside the namespace of the combinator, whereas the user might have to populate their toplevel model namespace with multiple branch choices). Right now, I'm planning on implementing 1 and 4.
Also just a comment - I don't think there's a large increase in complexity for deriving the trace type for (2, 3)? Because e.g. I would just type these as... etc. But I might be missing something.
@marcoct In terms of namespaces, the semantics as you mentioned above are really interesting. If we take the sharing route, what happens if the address spaces for either branch have non-empty set difference? When I switch - do I update with constraints on addresses in the intersection, even if I can't visit all the constraints in the switch? Conversely, do I partially constrain and then sample from the prior? I expect that both of these define a valid update if we implement the correct reverse move (e.g. re-score the constraints and subtract the previous branch weight), but this is different from the branch switching behavior in the dynamic DSL (iirc). But it's neat!
…onfirm with Marco/Alex.
SwitchUpdateState{T}(weight::Float64, score::Float64, noise::Float64, prev_trace::Trace) where T = new{T}(weight, score, noise, prev_trace)
end
function process!(gen_fn::Switch{C, N, K, T}, |
@marcoct This seems correct for the "namespace merging" semantics. In update, if I switch branches, I generate with a merge of the previous trace's choices and the choice map. I'm worried this will throw an error (however) if not all constraints are visited (e.g. if the previous trace has non-empty set difference with the new namespace).
I'm also dispatching on process! using the diff types. I'm not sure if this current implementation is correct - mostly because I'm not supporting Diffed dispatch yet (and I suspect I have to).
I'm also dispatching on process! using the diff types. I'm not sure if this current implementation is correct - mostly because I'm not supporting Diffed dispatch yet (and I suspect I have to).

@femtomc The Diffed types are only used for boxing Julia values for diff propagation of Julia code via operator overloading. GFI implementers only need to worry about Diff values. So what you have is right.
Given an addressing scheme, the semantics for update are well-defined -- the addresses that are shared get copied over, unless they are themselves constrained. New addresses that are introduced but not constrained get sampled from an internal proposal distribution. The discard contains all addresses that were removed, and any shared addresses that were constrained. The weight is the ratio of densities of the two traces, with an extra factor in the denominator for the density on the newly introduced choices that were not constrained. This behavior is described rather tersely in the docs here. An easier-to-read version of the semantics of update is given on page 71 of my thesis. (The definition in my thesis defines the behavior when there are newly introduced but unconstrained choices as an error, and the definition in the Gen documentation extends that definition to allow those choices to get sampled from the internal proposal.)
I recognize that the version in the Gen docs at the moment is rather terse, so here is what I think the behavior should be: In the version where the two branches use different address namespaces, when 'update' receives a change in the branch, it would call 'generate' -- with no constraints -- on the new branch to fill in the choices. In the version where the two branches use the same address namespace, you would extract the choices from the previous branch's trace and use them as constraints that you pass to a call to 'generate' on the new branch's generative function. I think in both cases the log weight returned by 'update' would be the log ratio of the two trace densities, with the density of any newly sampled unconstrained choices subtracted, as described above.
I think it would be good to implement the version where they use different namespaces first -- the more complicated version can always be a separate PR. Later on, I could also imagine a flag that you pass to the combinator that specifies which version gets generated -- whether they use a shared namespace or not -- and, for the different-namespace version, what the addresses to use for each branch would be.
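(A small runnable sketch of the shared-namespace behavior described above, using two plain generative functions rather than the combinator; the internal branch choice is ignored, so the weight shown is only the branch-switch contribution:)
using Gen

@gen function branch_a(x)
    y ~ normal(x, 1.0)
    return y
end

@gen function branch_b(x)
    y ~ normal(x, 10.0)
    return y
end

# Trace the old branch, then pass its choices as constraints to generate on the
# new branch, as in the shared-namespace version of update.
old_trace = simulate(branch_a, (0.0,))
(new_trace, gen_weight) = generate(branch_b, (0.0,), get_choices(old_trace))

# With every address constrained, gen_weight equals the full log density of the
# new branch, so the update-style weight for the switch is the log ratio of the
# two trace densities (no newly sampled unconstrained choices in this case).
switch_weight = gen_weight - get_score(old_trace)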
@marcoct I think I have the shared address version implemented. I tested to make sure the (score - weight) after update is the same as the score before (to verify that the implementation was consistent). I suspect when I formalize my... The last thing I need to check off is the discard behavior.
@femtomc Nice! Just read your comment above, and now I better understand your first question. Yes, I think we need to relax the requirement on generate so that it does not fail if there are unvisited choices. More specifically, there should be a flag that can be used to turn on or off this check. In my thesis I documented 'generate' as having this behavior, to support things like what we're trying to do here, and I think the Gen implementation should match this. (I think it already does, but it's not well documented.) I just made an issue for this: #335
@marcoct I had to construct some custom methods to correctly perform the merging and discard computation. One question I had:
function update_recurse_merge(prev_choices::ChoiceMap, choices::ChoiceMap)
    prev_choice_submap_iterator = get_submaps_shallow(prev_choices)
    prev_choice_value_iterator = get_values_shallow(prev_choices)
    choice_submap_iterator = get_submaps_shallow(choices)
    choice_value_iterator = get_values_shallow(choices)
    new_choices = DynamicChoiceMap()

    # Copy values from the previous choices unless the constraints overwrite them.
    for (key, value) in prev_choice_value_iterator
        has_value(choices, key) && continue
        set_value!(new_choices, key, value)
    end

    # Submaps present in both are merged recursively; otherwise copy from the previous choices.
    for (key, node1) in prev_choice_submap_iterator
        node2 = get_submap(choices, key)
        if isempty(node2)
            set_submap!(new_choices, key, node1)
        else
            set_submap!(new_choices, key, update_recurse_merge(node1, node2))
        end
    end

    # Values and submaps from the constraints always take precedence.
    for (key, value) in choice_value_iterator
        set_value!(new_choices, key, value)
    end
    for (key, node) in choice_submap_iterator
        isempty(get_submap(new_choices, key)) && set_submap!(new_choices, key, node)
    end
    return new_choices
end
Can I expect that every iterator which is returned from get_values_shallow and get_submaps_shallow supports the same interface?
This method above performs a merging which won't throw an error (unlike the merge method implemented for choice maps).
And a question for update_discard:
function update_discard(prev_trace::Trace, choices::ChoiceMap, new_trace::Trace)
    discard = choicemap()
    prev_choices = get_choices(prev_trace)
    for (k, v) in get_submaps_shallow(prev_choices)
        get_submap(get_choices(new_trace), k) isa EmptyChoiceMap && continue
        get_submap(choices, k) isa EmptyChoiceMap && continue
        set_submap!(discard, k, v)
    end
    for (k, v) in get_values_shallow(prev_choices)
        has_value(get_choices(new_trace), k) || continue
        has_value(choices, k) || continue
        set_value!(discard, k, v)
    end
    discard
end
Here I have to check using isa EmptyChoiceMap for the submaps.
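(A small usage sketch of update_recurse_merge as written above; the addresses and values are arbitrary:)
using Gen

prev = choicemap((:x, 1.0), (:y => :z, 2.0))   # choices from the previous branch's trace
constraints = choicemap((:x, 3.0))             # constraints passed to update

merged = update_recurse_merge(prev, constraints)
# merged holds :x => 3.0 (taken from the constraints) and :y => :z => 2.0
# (carried over from the previous trace), so it can fully constrain a call to
# generate on the new branch when the branches share a namespace.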
@marcoct I'm extracting the...
…propagating correctly now.
…1116_mrb_switch_combinator
…ous bug in semantics for regenerate - when switching branches, should generate with choice map constraints except those addresses which are in selection.
@femtomc @marcoct this looks great! To clarify -- it sounds like you have implemented the version where it is possible for branches to share addresses, right? If so, what performance impact does this have in the case where they cannot share addresses? In the case where, e.g., they are static functions that we can be sure don't share addresses, do we still have to do full choicemap scans when switching branches?
@georgematheos I think switching branches for now will always need to involve a full execution of the new branch generative function. I could imagine an optimization pass that is applied to the SML + combinators IR code that replaces the branch combinator call with a call to a generative function implementation that is specially compiled to know about both branches and switch more efficiently, but I think that would be a separate piece of code from this combinator implementation. Edit: @georgematheos Sorry, after thinking more I think I understand your question now -- you're right that there will be performance cost to doing the choice lookups. I think the answer for this is probably to (in a separate PR) add flag(s) that instruct the combinator to use separate address spaces, or informs the combinator that the address spaces are disjoint.
@georgematheos it's also possible to optimize this in a more fine-grained automatic way with update_recurse_merge and regenerate_recurse_merge. These two functions are responsible for traversing and identifying which addresses need to be constrained. But if we have access to static address information, it should be possible to specialize these on static information. Is this what you might have been thinking about? Do SML functions carry their support as type information (or otherwise allow me to maybe call a function which has cached the support information as dispatch on the unique model or trace type)? I think the answer is yes, but I haven't looked into the internals over there. I was actually going to reach out to you, because of some things I noticed about the implementation of... For this combinator, the notion of...
@georgematheos also, this combinator superficially reminds me of the Dispatch combinator.
As I thought about this more, I'm a little uncomfortable with providing a specialization which depends upon the modelling language. This doesn't seem to match the philosophy of combinators providing language-agnostic patterns. It becomes a bit tricky to think about where you would put this functionality - this is not something which you would implement to extend the existing interface. It reminds me of some of the new interfaces I'm designing for my modelling language. These interfaces AFAIK are not a part of the normal methods you implement now for inheritors of GenerativeFunction. So I've been rambling - but maybe my proposal is to introduce this function as part of the GFI. Is there an easier way?
@femtomc yes--I think that we could implement Dispatch using Switch. And I have a fork of Gen with lots of changes, including merging...
I think the idea of utilizing type information of generative functions to specialize is a pretty promising idea. I have not fully thought through it, but it seems like for static generative functions it would be able to help. For things like... I like the idea you are bringing up about having more static information about the generative function in the trace types; I haven't developed many concrete proposals/ideas beyond what you showed me last time we discussed this, but as I said I think it could be useful for performance, static analysis, etc.
In terms of a potential "easier way" to implement this: give each branch its own address namespace. This way, there can never be shared addresses between branches, so when switching branches, we just return the old choicemap as the discard, and generate a new choicemap (from scratch / the internal proposal). I could imagine implementing this version of the combinator first. The idea here is that this gives users a performance vs readability tradeoff. They can use... I could imagine (potentially in a future PR) making it so that...
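(A sketch of the separate-namespace layout proposed above; :which_branch, :branch_1, and :branch_2 are placeholder addresses, not a settled API:)
using Gen

# Each branch's choices live under a distinct prefix, so they can never collide.
choices_on_branch_1 = choicemap((:which_branch, 1), (:branch_1 => :y, 1.3))
choices_on_branch_2 = choicemap((:which_branch, 2), (:branch_2 => :y, -0.7))

# Switching from branch 1 to branch 2: the entire :branch_1 submap becomes the
# discard and :branch_2 is generated from scratch (or from user constraints),
# with no per-address scan of the old choices required.
discard_on_switch = get_submap(choices_on_branch_1, :branch_1)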
@femtomc This looks mostly really good! I have a few cosmetic comments, but I found what I think is a bug in the computation of the discard in update.
selection::Selection,
state::SwitchRegenerateState{T}) where {C, N, K, T}
branch_fn = getfield(gen_fn.branches, index)
merged = regenerate_recurse_merge(get_choices(state.prev_trace), selection)
I think it should be possible to use get_selected, which is already implemented, instead. It has the same behavior as regenerate_recurse_merge:
get_selected(get_choices(state.prev_trace), complement(selection))
function process!(gen_fn::Switch{C, N, K, T},
index::Int,
index_argdiff::UnknownChange, |
It might be better to use Diff instead of UnknownChange here. There might be other Diff types that could be passed that are intermediate between UnknownChange and NoChange (e.g. there is an IntDiff already, which tracks the arithmetic difference between two integers). This method should apply to anything that's not a NoChange.
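(A minimal sketch of the dispatch pattern being suggested; handle_index_argdiff is a made-up name used only to illustrate dispatching on Diff versus NoChange:)
using Gen

# One method for the no-change case, and one catch-all for any other Diff
# subtype (UnknownChange, IntDiff, ...), as suggested above.
handle_index_argdiff(index::Int, ::NoChange) = "index unchanged"
handle_index_argdiff(index::Int, ::Diff) = "index may have changed"

handle_index_argdiff(2, NoChange())        # "index unchanged"
handle_index_argdiff(2, UnknownChange())   # "index may have changed"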
sel, _ = zip(prev_choice_submap_iterator...)
comp = complement(select(sel...))
for (key, node) in get_submaps_shallow(get_selected(choices, comp))
set_submap!(new_choices, key, node)
end
I wonder if it would be simpler and more efficient to do:
for (key, node2) in choice_submap_iterator
if isempty(get_submap(new_choices, key))
set_submap!(new_choices, key, node2)
end
end
@marcoct This one is slightly interesting: the intent here is to directly identify what submap addresses in choices are not in prev_choices. The submaps associated with these addresses can be directly copied over. Addresses which do conflict are treated by the recursive call above these lines.
This suggestion is certainly simpler - is there an argument for efficiency?
Re efficiency, my thinking is that there might be some overhead in the select and complement operations that the second version doesn't have.
state::SwitchUpdateState{T}) where {C, N, K, T, DV}
# Generate new trace.
merged = update_recurse_merge(get_choices(state.prev_trace), choices)
A comment here or a docstring for update_recurse_merge that describes the operation this does would be helpful. For example, maybe "Returns choices that are in constraints, merged with all choices in the previous trace that do not have the same address as some choice in the constraints.".
@marcoct Added a @doc - is this appropriate for formatting documentation in the project? Or should I use a docstring directly?
I think it's fine to use @doc.
return new_choices
end

function update_discard(prev_trace::Trace, choices::ChoiceMap, new_trace::Trace)
Docstring or comment would be helpful here, e.g. "Returns choices from previous trace that (i) have an address that does not appear in the new trace, or (ii) have an address that does appear in the constraints (choices).".
Same as above - added an @doc.
has_value(get_choices(new_trace), k) || continue
has_value(choices, k) || continue
set_value!(discard, k, v)
I found this hard to read. I think that writing out the logic with explicit if-else would be clearer here.
if has_value(get_choices(new_trace), k) && has_value(choices, k)
set_value!(discard, k, v)
end
But also, I think this should be changed to:
if (!has_value(get_choices(new_trace), k)) || has_value(choices, k)
set_value!(discard, k, v)
end
Yes - this is a lot simpler and easier to understand.
There was actually an error here that you caught.
You should add a choice to the discard if:
1. The choice is in the old trace but not in the new trace.
2. The choice is in the old trace and (the choice is in the new trace and in the constraints).
Previously, I was checking for 2 but I wasn't covering 1.
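(A tiny worked example of the two rules above, with placeholder addresses and values:)
using Gen

old_trace_choices = choicemap((:x, 1.0), (:y, 2.0))   # previous branch sampled :x and :y
constraints       = choicemap((:x, 5.0))              # update constrains :x
new_trace_choices = choicemap((:x, 5.0))              # new branch no longer samples :y

# Rule 1: :y is in the old trace but not in the new trace, so :y => 2.0 is discarded.
# Rule 2: :x is in the old trace, the new trace, and the constraints, so its old value :x => 1.0 is discarded.
expected_discard = choicemap((:y, 2.0), (:x, 1.0))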
isempty(get_submap(get_choices(new_trace), k)) && continue
isempty(get_submap(choices, k)) && continue
set_submap!(discard, k, v)
I think this requires a recursive call.
Even if the new trace has some choice under this key, and the constraints (choices) also have some choice under this key, there could also be choices in the new trace that were copied directly from the previous trace and not overwritten by the constraints.
You're absolutely right - this was a mistake in my logic.
Excellent - will address these tomorrow.
…d update_discard. Added test to test the discard functionality in a hierarchical model example.
@marcoct Hi Marco - I wasn't sure what the protocol was for addressing review comments (e.g. if I should show code directly in comments to resolve conversations). I've pushed the necessary changes - so you might examine the changes and determine if they have resolved what you've suggested. I've also added a test case for the discard functionality in a hierarchical model example.
Beginning work on Switch combinator. Initial version of propose and generate. Beginning testing.