From 1445b140699c7abeb1a2de28160ad8209df3e6e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Basile=20Cl=C3=A9ment?= <129742207+bclement-ocp@users.noreply.github.com> Date: Tue, 20 Aug 2024 12:20:54 +0200 Subject: [PATCH] chore(BV, CP): Refactor propagation mechanism (#1185) This patch simplifies the propagation mechanism (currently used in the bit-vector relations only) in order to accomodate different types of propagators more easily. In particular, there is now a single generic (and configurable) loop that runs the propagators instead of a spaghetti of different loops for each kind of propagators. --- src/lib/reasoners/bitv_rel.ml | 532 ++++++++++++++++++++------------- src/lib/reasoners/rel_utils.ml | 2 +- src/lib/util/compat.ml | 26 ++ src/lib/util/compat.mli | 56 ++++ 4 files changed, 411 insertions(+), 205 deletions(-) diff --git a/src/lib/reasoners/bitv_rel.ml b/src/lib/reasoners/bitv_rel.ml index e3ee1c5f1..9537eb49c 100644 --- a/src/lib/reasoners/bitv_rel.ml +++ b/src/lib/reasoners/bitv_rel.ml @@ -646,6 +646,23 @@ struct end) end +module ExplainedComparable(V : Domains_intf.ComparableType) : + Domains_intf.ComparableType with type t = V.t explained = +struct + include ExplainedOrdered(V) + + let hash { value = v ; _ } = V.hash v + + let equal { value = v1 ; _ } { value = v2 ; _ } = V.equal v1 v2 + + module Table = Hashtbl.Make(struct + type nonrec t = t + + let equal = equal + let hash = hash + end) +end + module Constraint : sig type binop = (* Bitwise operations *) @@ -927,11 +944,12 @@ end = struct end -module EC = ExplainedOrdered(struct +module EC = ExplainedComparable(struct include Constraint module Set = Set.Make(Constraint) module Map = Map.Make(Constraint) + module Table = Hashtbl.Make(Constraint) end) module Interval_domains = @@ -1575,43 +1593,82 @@ let add_eqs = fun acc x x_sz bl -> aux x x_sz acc x_sz bl -module Any_constraint = struct - type t = - | Constraint of Constraint.t explained - | Structural of X.r - (** Structural constraint associated with [X.r]. See - {!Rel_utils.Bitlist_domains.structural_propagation}. *) - - let equal a b = - match a, b with - | Constraint ca, Constraint cb -> Constraint.equal ca.value cb.value - | Constraint _, Structural _ | Structural _, Constraint _ -> false - | Structural xa, Structural xb -> X.equal xa xb - - let hash = function - | Constraint c -> 2 * Constraint.hash c.value - | Structural r -> 2 * X.hash r + 1 - - let propagate constraint_propagate structural_propagation c = - Steps.incr CP; - match c with - | Constraint { value; explanation = ex } -> - constraint_propagate ~ex value - | Structural r -> - structural_propagation r -end +(* Propagation state + + TODO(bclement): Generalize this to arbitrary domains. *) +type state = + { intervals_uf : Interval_domains.Ephemeral.Canon.t + ; bitlists_uf : Bitlist_domains.Ephemeral.Canon.t + ; intervals : Interval_domains.Ephemeral.t + ; bitlists : Bitlist_domains.Ephemeral.t + ; interval_variables : SX.t MX.t + ; bitlist_variables : SX.t MX.t + ; interval_changed : unit HX.t + ; bitlist_changed : unit HX.t + } + +type 'a hashable = (module Hashtbl.HashedType with type t = 'a) + +module Any_propagator : sig + type 'a t + + val make : 'a hashable -> (state -> 'a -> unit) -> 'a t + + module Queue : sig + type 'a propagator = 'a t + + type t -module QC = Uqueue.Make(Any_constraint) + val create : int -> t -let propagate_queue queue constraint_propagate structural_propagation = - try - while true do - Any_constraint.propagate - constraint_propagate - structural_propagation - (QC.pop queue) - done - with QC.Empty -> () + val push : t -> 'a propagator -> 'a -> unit + + val run1 : state -> t -> bool + end +end = struct + type 'a t = + { id : 'a Compat.Type.Id.t + ; hashable : 'a hashable + ; run : state -> 'a -> unit } + + let make hashable run = + { id = Compat.Type.Id.make () ; hashable ; run } + + type binding = B : 'k t * 'k -> binding + + module Binding = struct + type t = binding + + let equal (B (k1, v1)) (B (k2, v2)) = + match Compat.Type.Id.provably_equal k1.id k2.id with + | Some Equal -> + let module T = (val k1.hashable) in + T.equal v1 v2 + | None -> false + + let hash (B (k, v)) = + let module T = (val k.hashable) in + Hashtbl.hash (Compat.Type.Id.uid k.id, T.hash v) + end + + module Q = Uqueue.Make(Binding) + + module Queue = struct + type 'a propagator = 'a t + + type t = Q.t + + let create = Q.create + + let push q k p = Q.push q (B (k, p)) + + (* Returns [false] if no propagation was performed (empty queue). *) + let run1 st q = + match Q.pop q with + | B ({ run ; _ }, p) -> run st p; true + | exception Q.Empty -> false + end +end let finite_lower_bound = function | Intervals_intf.Unbounded -> Z.zero @@ -1725,103 +1782,121 @@ let iter_parents a f t = | cs -> Rel_utils.XComparable.Set.iter f cs | exception Not_found -> () -let propagate_bitlist queue vars dom = - let structural_propagation r = - let open Bitlist_domains.Ephemeral.Canon in - let get r = !!(entry dom r) in - let update r d = update ~ex:Explanation.empty (entry dom r) d in - if X.is_a_leaf r then - iter_parents r (fun p -> - if X.is_a_leaf p then - assert (X.equal r p) - else - update p (Bitlist_domain.map_domain get p) - ) vars - else - let iter_signed sz f { Bitv.value; negated } bl = - let bl = if negated then Bitlist.(extract (lognot bl)) 0 sz else bl in - f value bl - in - ignore @@ List.fold_left (fun (bl, w) { Bitv.bv; sz } -> - (* Extract the bitlist associated with the current component *) - let mid = w - sz in - let bl_tail = Bitlist.extract bl 0 mid in - let bl = Bitlist.extract bl mid (w - mid) in - - match bv with - | Bitv.Cte z -> - assert (Z.numbits z <= sz); - (* Nothing to update, but still check for consistency! *) - ignore @@ Bitlist.intersect bl (Bitlist.exact z Ex.empty); - bl_tail, mid - | Other r -> - iter_signed sz update r bl; - (bl_tail, mid) - | Ext (r, r_size, i, j) -> - (* r = bl -> r = ?^(r_size - j - 1) @ bl @ ?^i *) - assert (i + r_size - j - 1 + w - mid = r_size); - let hi = Bitlist.(extract unknown 0 (r_size - j - 1)) in - let lo = Bitlist.(extract unknown 0 i) in - let bl_hd = Bitlist.((hi lsl (j + 1)) lor (bl lsl i) lor lo) in - iter_signed r_size update r bl_hd; - (bl_tail, mid) - ) ((get r), (bitwidth r)) (Shostak.Bitv.embed r) - in - propagate_queue - queue - (Propagator.propagate_bitlist dom) - structural_propagation +let bitlist_structural_propagator = + Any_propagator.make (module Rel_utils.XComparable) @@ + fun ({ bitlists_uf = dom ; bitlist_variables = vars ; _ } as st) r -> + HX.replace st.bitlist_changed r (); + let open Bitlist_domains.Ephemeral.Canon in + let get r = !!(entry dom r) in + let update r d = update ~ex:Explanation.empty (entry dom r) d in + if X.is_a_leaf r then + iter_parents r (fun p -> + if X.is_a_leaf p then + assert (X.equal r p) + else + update p (Bitlist_domain.map_domain get p) + ) vars + else + let iter_signed sz f { Bitv.value; negated } bl = + let bl = if negated then Bitlist.(extract (lognot bl)) 0 sz else bl in + f value bl + in + ignore @@ List.fold_left (fun (bl, w) { Bitv.bv; sz } -> + (* Extract the bitlist associated with the current component *) + let mid = w - sz in + let bl_tail = Bitlist.extract bl 0 mid in + let bl = Bitlist.extract bl mid (w - mid) in -let propagate_intervals queue vars dom = - let structural_propagation r = - let open Interval_domains.Ephemeral.Canon in - let get r = !!(entry dom r) in - let update r d = update ~ex:Explanation.empty (entry dom r) d in - if X.is_a_leaf r then - iter_parents r (fun p -> - if X.is_a_leaf p then - assert (X.equal r p) - else - update p (Interval_domain.map_domain get p) - ) vars - else - let iter_signed f { Bitv.value; negated } sz int = - f value (if negated then Interval_domain.lognot sz int else int) - in - let int = get r in - let width = bitwidth r in - let j = - List.fold_left (fun j { Bitv.bv; sz } -> - (* sz = j - i + 1 => i = j - sz + 1 *) - let int = Intervals.Int.extract int ~ofs:(j - sz + 1) ~len:sz in - - begin match bv with - | Bitv.Cte z -> - (* Nothing to update, but still check for consistency *) - ignore @@ - Interval_domain.intersect int (Interval_domain.constant z) - | Other s -> iter_signed update s sz int - | Ext (r, r_size, i, j) -> - (* r = bl -> r = ?^(r_size - j - 1) @ bl @ ?^i *) - assert (i + r_size - j - 1 + sz = r_size); - let lo = Interval_domain.unknown (Tbitv i) in - let int = Intervals.Int.scale Z.(~$1 lsl i) int in - let hi = Interval_domain.unknown (Tbitv (r_size - j - 1)) in - let hi = - Intervals.Int.scale Z.(~$1 lsl Stdlib.(i + sz)) hi - in - iter_signed update r r_size Intervals.Int.(add hi (add int lo)) - end; - - (j - sz) - ) (width - 1) (Shostak.Bitv.embed r) - in - assert (j = -1) - in - propagate_queue - queue - (Propagator.propagate_interval dom) - structural_propagation + match bv with + | Bitv.Cte z -> + assert (Z.numbits z <= sz); + (* Nothing to update, but still check for consistency! *) + ignore @@ Bitlist.intersect bl (Bitlist.exact z Ex.empty); + bl_tail, mid + | Other r -> + iter_signed sz update r bl; + (bl_tail, mid) + | Ext (r, r_size, i, j) -> + (* r = bl -> r = ?^(r_size - j - 1) @ bl @ ?^i *) + assert (i + r_size - j - 1 + w - mid = r_size); + let hi = Bitlist.(extract unknown 0 (r_size - j - 1)) in + let lo = Bitlist.(extract unknown 0 i) in + let bl_hd = Bitlist.((hi lsl (j + 1)) lor (bl lsl i) lor lo) in + iter_signed r_size update r bl_hd; + (bl_tail, mid) + ) ((get r), (bitwidth r)) (Shostak.Bitv.embed r) + +let bitlist_constraint_propagator = + Any_propagator.make (module EC) @@ + fun { bitlists_uf = dom ; _ } { explanation = ex ; value = c } -> + Propagator.propagate_bitlist ~ex dom c + +let interval_structural_propagator = + Any_propagator.make (module Rel_utils.XComparable) @@ + fun ({ intervals_uf = dom ; interval_variables = vars ; _ } as st) r -> + HX.replace st.interval_changed r (); + let open Interval_domains.Ephemeral.Canon in + let get r = !!(entry dom r) in + let update r d = update ~ex:Explanation.empty (entry dom r) d in + if X.is_a_leaf r then + iter_parents r (fun p -> + if X.is_a_leaf p then + assert (X.equal r p) + else + update p (Interval_domain.map_domain get p) + ) vars + else + let iter_signed f { Bitv.value; negated } sz int = + f value (if negated then Interval_domain.lognot sz int else int) + in + let int = get r in + let width = bitwidth r in + let j = + List.fold_left (fun j { Bitv.bv; sz } -> + (* sz = j - i + 1 => i = j - sz + 1 *) + let int = Intervals.Int.extract int ~ofs:(j - sz + 1) ~len:sz in + + begin match bv with + | Bitv.Cte z -> + (* Nothing to update, but still check for consistency *) + ignore @@ + Interval_domain.intersect int (Interval_domain.constant z) + | Other s -> iter_signed update s sz int + | Ext (r, r_size, i, j) -> + (* r = bl -> r = ?^(r_size - j - 1) @ bl @ ?^i *) + assert (i + r_size - j - 1 + sz = r_size); + let lo = Interval_domain.unknown (Tbitv i) in + let int = Intervals.Int.scale Z.(~$1 lsl i) int in + let hi = Interval_domain.unknown (Tbitv (r_size - j - 1)) in + let hi = + Intervals.Int.scale Z.(~$1 lsl Stdlib.(i + sz)) hi + in + iter_signed update r r_size Intervals.Int.(add hi (add int lo)) + end; + + (j - sz) + ) (width - 1) (Shostak.Bitv.embed r) + in + assert (j = -1) + +let interval_constraint_propagator = + Any_propagator.make (module EC) @@ + fun { intervals_uf = dom ; _ } { explanation = ex ; value = c } -> + Propagator.propagate_interval ~ex dom c + +let interval_from_bitlist_propagator = + Any_propagator.make (module Rel_utils.XComparable) @@ + fun { intervals = idom ; bitlists = bdom ; _ } r -> + constrain_interval_from_bitlist ~size:(bitwidth r) + Interval_domains.Ephemeral.(entry idom r) + Bitlist_domains.Ephemeral.(Entry.domain (entry bdom r)) + +let bitlist_from_interval_propagator = + Any_propagator.make (module Rel_utils.XComparable) @@ + fun { intervals = idom ; bitlists = bdom ; _ } r -> + constrain_bitlist_from_interval ~size:(bitwidth r) + Bitlist_domains.Ephemeral.(entry bdom r) + Interval_domains.Ephemeral.(Entry.domain (entry idom r)) module HC = Hashtbl.Make(Constraint) @@ -1860,94 +1935,143 @@ let simplify_all uf eqs touched (dom, idom) = ) c.value (dom, idom) ) to_add (dom, idom) +(* Holds the queues for propagation. + + Currently we use two separate queues for single-domain propagators (i.e. + propagators that only use one domain) and for multi-domain propagators + (i.e. propagators that propagate information across different domains, + e.g. between bitlists and intervals). + + TODO(bclement): generalize to an arbitrary queue combination (and possibly + infer the queue configuration from the schedule?). *) +type queues = + { to_simplify : Ex.t HC.t + ; propagation_queue : Any_propagator.Queue.t + ; consistency_queue : Any_propagator.Queue.t } + +let state uf idom bdom = + let to_simplify = HC.create 17 in + let propagation_queue = Any_propagator.Queue.create 17 in + let consistency_queue = Any_propagator.Queue.create 17 in + let events + constraint_propagator + structural_propagator + consistency_propagator + = + let evt_atomic_change r = + Any_propagator.Queue.push propagation_queue structural_propagator r; + Any_propagator.Queue.push consistency_queue consistency_propagator r + and evt_composite_change r = + Any_propagator.Queue.push propagation_queue structural_propagator r; + Any_propagator.Queue.push consistency_queue consistency_propagator r + and evt_watch_trigger c = + HC.replace to_simplify c.value c.explanation; + Any_propagator.Queue.push propagation_queue constraint_propagator c + in + { Domains_intf.evt_atomic_change + ; evt_composite_change + ; evt_watch_trigger + } + in + + let bitlist_events = + events + bitlist_constraint_propagator + bitlist_structural_propagator + interval_from_bitlist_propagator + and interval_events = + events + interval_constraint_propagator + interval_structural_propagator + bitlist_from_interval_propagator + in + let bvars = Bitlist_domains.parents bdom in + let ivars = Interval_domains.parents idom in + + let bdom = Bitlist_domains.edit ~events:bitlist_events bdom in + let idom = Interval_domains.edit ~events:interval_events idom in + + let uf_bdom = Bitlist_domains.Ephemeral.canon uf bdom in + let uf_idom = Interval_domains.Ephemeral.canon uf idom in + { to_simplify ; propagation_queue ; consistency_queue }, + { intervals = idom + ; intervals_uf = uf_idom + ; interval_variables = ivars + ; interval_changed = HX.create 17 + ; bitlists = bdom + ; bitlists_uf = uf_bdom + ; bitlist_variables = bvars + ; bitlist_changed = HX.create 17 } + +module Schedule = struct + (* Schedule specifications *) + type t = + | Single of Any_propagator.Queue.t + (** Propagate a single constraint from the given queue. *) + | Sequence of t array + (** Performs propagations from the provided schedules in sequence. *) + | Repeat of t + (** Repeat the given schedule to completion. *) + + (* Returns [false] if no propagation was performed. *) + let rec run state schedule = + match schedule with + | Single queue -> + Any_propagator.Queue.run1 state queue + | Sequence schedules -> + Array.fold_left (fun did_run schedule -> + run state schedule || did_run + ) false schedules + | Repeat schedule -> + run_repeatedly state schedule false + + and run_repeatedly state schedule did_run = + if run state schedule then + run_repeatedly state schedule true + else + did_run + + let queue q = Single q + + let sequence scheds = Sequence scheds + + let repeat s = Repeat s +end + let rec propagate_all uf eqs bdom idom = (* Optimization to avoid unnecessary allocations *) let do_bitlist = Bitlist_domains.needs_propagation bdom in let do_intervals = Interval_domains.needs_propagation idom in let do_any = do_bitlist || do_intervals in if do_any then - let shostak_candidates = HX.create 17 in - let seen_constraints = HC.create 17 in - let bitlist_queue = QC.create 17 in - let interval_queue = QC.create 17 in - let touch_c queue c = - HC.replace seen_constraints c.value c.explanation; - QC.push queue (Constraint c) - in - let touch tbl queue r = - HX.replace tbl r (); - QC.push queue (Structural r) - in - let bitlist_changed = HX.create 17 in - let interval_changed = HX.create 17 in - let bitlist_events = - { Domains_intf.evt_atomic_change = touch bitlist_changed bitlist_queue - ; evt_composite_change = touch bitlist_changed bitlist_queue - ; evt_watch_trigger = touch_c bitlist_queue } - and interval_events = - { Domains_intf.evt_atomic_change = touch interval_changed interval_queue - ; evt_composite_change = touch interval_changed interval_queue - ; evt_watch_trigger = touch_c interval_queue } - in - let bvars = Bitlist_domains.parents bdom in - let ivars = Interval_domains.parents idom in - - let bdom = Bitlist_domains.edit ~events:bitlist_events bdom in - let idom = Interval_domains.edit ~events:interval_events idom in - - let uf_bdom = Bitlist_domains.Ephemeral.canon uf bdom in - let uf_idom = Interval_domains.Ephemeral.canon uf idom in - - (* First, we propagate the pending constraints to both domains. Changes in - the bitlist domain are used to shrink the interval domains. *) - propagate_bitlist bitlist_queue bvars uf_bdom; - assert (QC.is_empty bitlist_queue); - - HX.iter (fun r () -> - HX.replace shostak_candidates r (); - constrain_interval_from_bitlist ~size:(bitwidth r) - Interval_domains.Ephemeral.(entry idom r) - Bitlist_domains.Ephemeral.(!!(entry bdom r)) - ) bitlist_changed; - HX.clear bitlist_changed; - propagate_intervals interval_queue ivars uf_idom; - assert (QC.is_empty interval_queue); - - (* Now the interval domain is stable, but the new intervals may have an - impact on the bitlist domains, so we must shrink them again when - applicable. We repeat until a fixpoint is reached. *) - while HX.length interval_changed > 0 do - HX.iter (fun r () -> - constrain_bitlist_from_interval ~size:(bitwidth r) - Bitlist_domains.Ephemeral.(entry bdom r) - Interval_domains.Ephemeral.(!!(entry idom r)) - ) interval_changed; - HX.clear interval_changed; - propagate_bitlist bitlist_queue bvars uf_bdom; - assert (QC.is_empty bitlist_queue); - - HX.iter (fun r () -> - HX.replace shostak_candidates r (); - constrain_interval_from_bitlist ~size:(bitwidth r) - Interval_domains.Ephemeral.(entry idom r) - Bitlist_domains.Ephemeral.(!!(entry bdom r)) - ) bitlist_changed; - HX.clear bitlist_changed; - propagate_intervals interval_queue ivars uf_idom; - assert (QC.is_empty interval_queue); - done; + let queues, state = state uf idom bdom in + + (* We run all propagations over a single domain to completion, then run + the consistency propagators to perform cross-domain propagations. *) + ignore Schedule.( + run state @@ repeat @@ + sequence + [| repeat @@ queue queues.propagation_queue + ; repeat @@ queue queues.consistency_queue + |] + ); + + let bdom = state.bitlists in + let idom = state.intervals in let eqs = HX.fold (fun r () acc -> let d = Bitlist_domains.Ephemeral.(!!(entry bdom r)) in add_eqs acc (Shostak.Bitv.embed r) (bitwidth r) d - ) shostak_candidates eqs + ) state.bitlist_changed eqs in let bdom, idom = Bitlist_domains.snapshot bdom, Interval_domains.snapshot idom in - let eqs, (bdom, idom) = simplify_all uf eqs seen_constraints (bdom, idom) in + let eqs, (bdom, idom) = + simplify_all uf eqs queues.to_simplify (bdom, idom) + in (* Propagate again in case constraints were simplified. *) propagate_all uf eqs bdom idom diff --git a/src/lib/reasoners/rel_utils.ml b/src/lib/reasoners/rel_utils.ml index 9998db475..d55b7201e 100644 --- a/src/lib/reasoners/rel_utils.ml +++ b/src/lib/reasoners/rel_utils.ml @@ -212,7 +212,7 @@ end = struct end (** Implementation of the [ComparableType] interface for semantic values. *) -module XComparable : Domains_intf.ComparableType with type t = X.r = struct +module XComparable = struct type t = X.r let pp = X.print diff --git a/src/lib/util/compat.ml b/src/lib/util/compat.ml index 97c50f0ca..b01f79e84 100644 --- a/src/lib/util/compat.ml +++ b/src/lib/util/compat.ml @@ -127,3 +127,29 @@ module Seq = struct include Seq end + +module Type = struct + type (_, _) eq = Equal : ('a, 'a) eq + + module Id = struct + type _ id = .. + + module type ID = sig + type t + type _ id += Id : t id + end + + type 'a t = (module ID with type t = 'a) + + let make (type a) () : a t = + (module struct type t = a type _ id += Id : t id end) + + let[@inline] uid (type a) ((module A) : a t) = + Obj.Extension_constructor.id (Obj.Extension_constructor.of_val A.Id) + + let provably_equal + (type a b) ((module A) : a t) ((module B) : b t) : (a, b) eq option + = + match A.Id with B.Id -> Some Equal | _ -> None + end +end diff --git a/src/lib/util/compat.mli b/src/lib/util/compat.mli index cc0b1c187..feeac4b06 100644 --- a/src/lib/util/compat.mli +++ b/src/lib/util/compat.mli @@ -112,3 +112,59 @@ module Seq : sig @since OCaml 4.14 *) end + +module Type : sig + (** Type introspection. + + @since OCaml 5.1 *) + + type (_, _) eq = Equal: ('a, 'a) eq (** *) + (** The purpose of [eq] is to represent type equalities that may not + otherwise + be known by the type checker (e.g. because they may depend on dynamic + data). + + A value of type [(a, b) eq] represents the fact that types [a] and [b] + are equal. + + If one has a value [eq : (a, b) eq] that proves types [a] and [b] are + equal, one can use it to convert a value of type [a] to a value of + type [b] by pattern matching on [Equal]: + {[ + let cast (type a) (type b) (Equal : (a, b) Type.eq) (a : a) : b = a + ]} + + At runtime, this function simply returns its second argument + unchanged. + *) + + (** {1:identifiers Type identifiers} *) + + (** Type identifiers. + + A type identifier is a value that denotes a type. Given two type + identifiers, they can be tested for {{!Id.provably_equal}equality} to + prove they denote the same type. Note that: + + - Unequal identifiers do not imply unequal types: a given type can be + denoted by more than one identifier. + - Type identifiers can be marshalled, but they get a new, distinct, + identity on unmarshalling, so the equalities are lost. *) + module Id : sig + + (** {1:ids Type identifiers} *) + + type 'a t + (** The type for identifiers for type ['a]. *) + + val make : unit -> 'a t + (** [make ()] is a new type identifier. *) + + val uid : 'a t -> int + (** [uid id] is a runtime unique identifier for [id]. *) + + val provably_equal : 'a t -> 'b t -> ('a, 'b) eq option + (** [provably_equal i0 i1] is [Some Equal] if identifier [i0] is equal + to [i1] and [None] otherwise. *) + end +end