From 783950d1f22784e0e3f0b48b20a93f2d5ac7a24e Mon Sep 17 00:00:00 2001 From: eduardomadeira98 Date: Tue, 14 Nov 2023 22:20:45 +0000 Subject: [PATCH] removing hole's function from the ctx of that hole --- aeon/synthesis_grammar/identification.py | 50 +++++++++++++----------- aeon/synthesis_grammar/synthesizer.py | 3 +- 2 files changed, 30 insertions(+), 23 deletions(-) diff --git a/aeon/synthesis_grammar/identification.py b/aeon/synthesis_grammar/identification.py index c334cacd..766c0c79 100644 --- a/aeon/synthesis_grammar/identification.py +++ b/aeon/synthesis_grammar/identification.py @@ -26,6 +26,7 @@ def get_holes_info( ctx: TypingContext, t: Term, ty: Type, + targets: list[tuple[str, list[str]]], ) -> dict[str, tuple[Type, TypingContext]]: """Retrieve the Types of "holes" in a given Term and TypingContext. @@ -36,9 +37,7 @@ def get_holes_info( ctx (TypingContext): The current TypingContext. t (Term): The term to analyze. ty (Type): The current type. - holes (dict[str, tuple[Type, TypingContext, str]]: The current dictionary of hole types. Defaults to None. - Returns: - dict[str, tuple[Type, TypingContext, str]]: The updated dictionary of hole Types and their TypingContexts. + targets (list(tuple(str, list(str)))): List of tuples functions names that contains holes and the name holes """ match t: case Annotation(expr=Hole(name=hname), type=ty): @@ -52,46 +51,53 @@ def get_holes_info( case Var(_): return {} case Annotation(expr=expr, type=ty): - return get_holes_info(ctx, expr, ty) + return get_holes_info(ctx, expr, ty, targets) case Application(fun=fun, arg=arg): - hs1 = get_holes_info(ctx, fun, ty) - hs2 = get_holes_info(ctx, arg, ty) + hs1 = get_holes_info(ctx, fun, ty, targets) + hs2 = get_holes_info(ctx, arg, ty, targets) return hs1 | hs2 case If(cond=cond, then=then, otherwise=otherwise): - hs1 = get_holes_info(ctx, cond, ty) - hs2 = get_holes_info(ctx, then, ty) - hs3 = get_holes_info(ctx, otherwise, ty) + hs1 = get_holes_info(ctx, cond, ty, targets) + hs2 = get_holes_info(ctx, then, ty, targets) + hs3 = get_holes_info(ctx, otherwise, ty, targets) return hs1 | hs2 | hs3 case Abstraction(var_name=vname, body=body): if isinstance(ty, AbstractionType): ret = substitution_in_type(ty.type, Var(vname), ty.var_name) ctx = ctx.with_var(vname, ty.var_type) - return get_holes_info(ctx, body, ret) + return get_holes_info(ctx, body, ret, targets) else: assert False, f"Synthesis cannot infer the type of {t}" case Let(var_name=vname, var_value=value, body=body): _, t1 = synth(ctx, value) - ctx = ctx.with_var(vname, t1) - hs1 = get_holes_info(ctx, t.var_value, ty) - hs2 = get_holes_info(ctx, t.body, ty) + if not isinstance(value, Hole) and not (isinstance(value, Annotation) and isinstance(value.expr, Hole)): + ctx = ctx.with_var(vname, t1) + hs1 = get_holes_info(ctx, t.var_value, ty, targets) + hs2 = get_holes_info(ctx, t.body, ty, targets) + else: + hs1 = get_holes_info(ctx, t.var_value, ty, targets) + ctx = ctx.with_var(vname, t1) + hs2 = get_holes_info(ctx, t.body, ty, targets) return hs1 | hs2 case Rec(var_name=vname, var_type=vtype, var_value=value, body=body): - ctx = ctx.with_var(vname, vtype) - hs1 = get_holes_info( - ctx, - value, - vtype, - ) - hs2 = get_holes_info(ctx, body, ty) + if any(tup[0] == vname for tup in targets): + hs1 = get_holes_info(ctx, value, vtype, targets) + ctx = ctx.with_var(vname, vtype) + hs2 = get_holes_info(ctx, body, ty, targets) + else: + ctx = ctx.with_var(vname, vtype) + hs1 = get_holes_info(ctx, value, vtype, targets) + hs2 = get_holes_info(ctx, body, ty, targets) + return hs1 | hs2 case TypeApplication(body=body, type=argty): if isinstance(ty, TypePolymorphism): ntype = substitute_vartype(ty.body, argty, ty.name) - return get_holes_info(ctx, body, ntype) + return get_holes_info(ctx, body, ntype, targets) else: assert False, f"Synthesis cannot infer the type of {t}" case TypeAbstraction(name=n, kind=k, body=body): - return get_holes_info(ctx.with_typevar(n, k), body, ty) + return get_holes_info(ctx.with_typevar(n, k), body, ty, targets) case _: assert False, f"Could not infer the type of {t} for synthesis." diff --git a/aeon/synthesis_grammar/synthesizer.py b/aeon/synthesis_grammar/synthesizer.py index 8052fd0e..7a9643c1 100644 --- a/aeon/synthesis_grammar/synthesizer.py +++ b/aeon/synthesis_grammar/synthesizer.py @@ -210,7 +210,7 @@ def synthesize( ctx: TypingContext, ectx: EvaluationContext, term: Term, - targets=list[tuple[str, list[str]]], + targets: list[tuple[str, list[str]]], filename: str | None = None, ) -> Term: """Synthesizes code for multiple functions, each with multiple holes.""" @@ -219,6 +219,7 @@ def synthesize( ctx, term, top, + targets, ) for name, holes_names in targets: