Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Keep track of normalized vs un-normalized entries in the monomorphiza… #520

Merged
merged 1 commit into from
Jan 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions lib/Ast.ml
Original file line number Diff line number Diff line change
Expand Up @@ -720,3 +720,7 @@ let flags_of_decl = function
flags

let tuple_lid = [ "K" ], ""

let fst3 (x, _, _) = x
let snd3 (_, y, _) = y
let thd3 (_, _, z) = z
6 changes: 3 additions & 3 deletions lib/Inlining.ml
Original file line number Diff line number Diff line change
Expand Up @@ -401,22 +401,22 @@ let cross_call_analysis files =

method! visit_TApp () name ts =
match Hashtbl.find_opt MonomorphizationState.state (name, ts, []) with
| Some (_, lid) ->
| Some (_, lid, _) ->
self#visit_TQualified () lid
| None ->
self#visit_TQualified () name;
List.iter (self#visit_typ ()) ts

method! visit_TTuple () ts =
match Hashtbl.find_opt MonomorphizationState.state (tuple_lid, ts, []) with
| Some (_, lid) ->
| Some (_, lid, _) ->
self#visit_TQualified () lid
| None ->
super#visit_TTuple () ts

method! visit_TCgApp () name ts =
match Hashtbl.find_opt MonomorphizationState.state (flatten_tapp (TCgApp (name, ts))) with
| Some (_, lid) ->
| Some (_, lid, _) ->
self#visit_TQualified () lid
| None ->
super#visit_TCgApp () name ts
Expand Down
28 changes: 14 additions & 14 deletions lib/Monomorphization.ml
Original file line number Diff line number Diff line change
Expand Up @@ -286,15 +286,15 @@ let monomorphize_data_types map = object(self)
| exception Not_found ->
let args = List.map (self#visit_typ false) args in
TTuple args
| _, chosen_lid ->
| _, chosen_lid, _ ->
TQualified chosen_lid

method! visit_TApp () lid args =
match Hashtbl.find state (lid, args, []) with
| exception Not_found ->
let args = List.map (self#visit_typ false) args in
TApp (lid, args)
| _, chosen_lid ->
| _, chosen_lid, _ ->
TQualified chosen_lid
end)#visit_typ () t

Expand All @@ -307,7 +307,7 @@ let monomorphize_data_types map = object(self)
themselves been renormalized. *)
let ts' = List.map resolve_deep ts in
if not (Hashtbl.mem state (n, ts', cgs)) then
Hashtbl.add state (n, ts', cgs) (Black, chosen_lid)
Hashtbl.add state (n, ts', cgs) (Black, chosen_lid, true)

(* Compute the name of a given node in the graph. *)
method private lid_of (n: node) =
Expand Down Expand Up @@ -346,16 +346,16 @@ let monomorphize_data_types map = object(self)
KPrint.bprintf "visiting %a: Not_found\n" ptyp (fold_tapp n);
let chosen_lid, flag = self#lid_of n in
if lid = tuple_lid then begin
Hashtbl.add state n (Gray, chosen_lid);
Hashtbl.add state n (Gray, chosen_lid, false);
let args = List.map (self#visit_typ under_ref) args in
(* For tuples, we immediately know how to generate a definition. *)
let fields = List.mapi (fun i arg -> Some (self#field_at i), (arg, false)) args in
self#record (DType (chosen_lid, [ Common.Private ] @ flag, 0, 0, Flat fields));
self#renormalize_entry n chosen_lid;
Hashtbl.replace state n (Black, chosen_lid)
Hashtbl.replace state n (Black, chosen_lid, false)
end else begin
(* This specific node has not been visited yet. *)
Hashtbl.add state n (Gray, chosen_lid);
Hashtbl.add state n (Gray, chosen_lid, false);

let subst fields = List.map (fun (field, (t, m)) ->
field, (DeBruijn.subst_ctn' cgs (DeBruijn.subst_tn args t), m)
Expand All @@ -365,7 +365,7 @@ let monomorphize_data_types map = object(self)
| exception Not_found ->
(* Unknown, external non-polymorphic lid, e.g. Prims.int *)
self#renormalize_entry n chosen_lid;
Hashtbl.replace state n (Black, chosen_lid)
Hashtbl.replace state n (Black, chosen_lid, false)
| flags, ((Variant _ | Flat _ | Union _) as def) when under_ref && not (Hashtbl.mem seen_declarations lid) ->
(* Because this looks up a definition in the global map, the
definitions are reordered according to the traversal order, which
Expand Down Expand Up @@ -396,12 +396,12 @@ let monomorphize_data_types map = object(self)
let branches = self#visit_branches_t under_ref branches in
self#record (DType (chosen_lid, flag @ flags, 0, 0, Variant branches));
self#renormalize_entry n chosen_lid;
Hashtbl.replace state n (Black, chosen_lid)
Hashtbl.replace state n (Black, chosen_lid, false)
| flags, Flat fields ->
let fields = self#visit_fields_t_opt under_ref (subst fields) in
self#record (DType (chosen_lid, flag @ flags, 0, 0, Flat fields));
self#renormalize_entry n chosen_lid;
Hashtbl.replace state n (Black, chosen_lid)
Hashtbl.replace state n (Black, chosen_lid, false)
| flags, Union fields ->
let fields = List.map (fun (f, t) ->
let t = DeBruijn.subst_tn args t in
Expand All @@ -410,20 +410,20 @@ let monomorphize_data_types map = object(self)
) fields in
self#record (DType (chosen_lid, flag @ flags, 0, 0, Union fields));
self#renormalize_entry n chosen_lid;
Hashtbl.replace state n (Black, chosen_lid)
Hashtbl.replace state n (Black, chosen_lid, false)
| flags, Abbrev t ->
let t = DeBruijn.subst_tn args t in
let t = self#visit_typ under_ref t in
self#record (DType (chosen_lid, flag @ flags, 0, 0, Abbrev t));
self#renormalize_entry n chosen_lid;
Hashtbl.replace state n (Black, chosen_lid)
Hashtbl.replace state n (Black, chosen_lid, false)
| _ ->
self#renormalize_entry n chosen_lid;
Hashtbl.replace state n (Black, chosen_lid)
Hashtbl.replace state n (Black, chosen_lid, false)
end
end;
chosen_lid
| Gray, chosen_lid ->
| Gray, chosen_lid, _ ->
if Options.debug "data-types-traversal" then
KPrint.bprintf "visiting %a: Gray\n" ptyp (fold_tapp n);
begin match Hashtbl.find map lid with
Expand All @@ -435,7 +435,7 @@ let monomorphize_data_types map = object(self)
self#record (DType (chosen_lid, flags, 0, 0, Forward FStruct))
end;
chosen_lid
| Black, chosen_lid ->
| Black, chosen_lid, _ ->
if Options.debug "data-types-traversal" then
KPrint.bprintf "visiting %a: Black\n" ptyp (fold_tapp n);
chosen_lid
Expand Down
14 changes: 8 additions & 6 deletions lib/MonomorphizationState.ml
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,13 @@ open PrintAst.Ops
type node = lident * typ list * cg list
type color = Gray | Black

(* Is this a normalized entry or not? *)
type normalized = bool

(* Each polymorphic type `lid` applied to types `ts` and const generics `ts`
appears in `state`, and maps to `monomorphized_lid`, the name of its
monomorphized instance. *)
let state: (node, color * lident) Hashtbl.t = Hashtbl.create 41
let state: (node, color * lident * normalized) Hashtbl.t = Hashtbl.create 41

(* Because of polymorphic externals, one still encounters,
post-monomorphizations, application nodes in types (e.g. after instantiating
Expand All @@ -21,9 +24,9 @@ let state: (node, color * lident) Hashtbl.t = Hashtbl.create 41
let resolve t: typ =
match t with
| TApp _ | TCgApp _ when Hashtbl.mem state (flatten_tapp t) ->
TQualified (snd (Hashtbl.find state (flatten_tapp t)))
TQualified (snd3 (Hashtbl.find state (flatten_tapp t)))
| TTuple ts when Hashtbl.mem state (tuple_lid, ts, []) ->
TQualified (snd (Hashtbl.find state (tuple_lid, ts, [])))
TQualified (snd3 (Hashtbl.find state (tuple_lid, ts, [])))
| _ ->
t

Expand All @@ -48,9 +51,8 @@ type reverse_mapping = (lident * expr list * typ list, lident) Hashtbl.t
let generated_lids: reverse_mapping = Hashtbl.create 41

let debug () =
Hashtbl.iter (fun (lid, ts, cgs) (_, monomorphized_lid) ->
KPrint.bprintf "%a <%a> <%a> ~~> %a\n" plid lid pcgs cgs ptyps ts plid
monomorphized_lid
Hashtbl.iter (fun (lid, ts, cgs) (_, monomorphized_lid, normalized) ->
KPrint.bprintf "%a <%a> <%a> %b ~~> %a\n" plid lid pcgs cgs ptyps ts normalized plid monomorphized_lid
) state;
Hashtbl.iter (fun (lid, es, ts) monomorphized_lid ->
KPrint.bprintf "%a <%a> <%a> ~~> %a\n" plid lid pexprs es ptyps ts plid
Expand Down