Skip to content

Commit

Permalink
removing hole's function from the ctx of that hole
Browse files Browse the repository at this point in the history
  • Loading branch information
eduardo-imadeira committed Nov 14, 2023
1 parent 90a86cb commit 783950d
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 23 deletions.
50 changes: 28 additions & 22 deletions aeon/synthesis_grammar/identification.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def get_holes_info(
ctx: TypingContext,
t: Term,
ty: Type,
targets: list[tuple[str, list[str]]],
) -> dict[str, tuple[Type, TypingContext]]:
"""Retrieve the Types of "holes" in a given Term and TypingContext.
Expand All @@ -36,9 +37,7 @@ def get_holes_info(
ctx (TypingContext): The current TypingContext.
t (Term): The term to analyze.
ty (Type): The current type.
holes (dict[str, tuple[Type, TypingContext, str]]: The current dictionary of hole types. Defaults to None.
Returns:
dict[str, tuple[Type, TypingContext, str]]: The updated dictionary of hole Types and their TypingContexts.
targets (list(tuple(str, list(str)))): List of tuples functions names that contains holes and the name holes
"""
match t:
case Annotation(expr=Hole(name=hname), type=ty):
Expand All @@ -52,46 +51,53 @@ def get_holes_info(
case Var(_):
return {}
case Annotation(expr=expr, type=ty):
return get_holes_info(ctx, expr, ty)
return get_holes_info(ctx, expr, ty, targets)
case Application(fun=fun, arg=arg):
hs1 = get_holes_info(ctx, fun, ty)
hs2 = get_holes_info(ctx, arg, ty)
hs1 = get_holes_info(ctx, fun, ty, targets)
hs2 = get_holes_info(ctx, arg, ty, targets)
return hs1 | hs2
case If(cond=cond, then=then, otherwise=otherwise):
hs1 = get_holes_info(ctx, cond, ty)
hs2 = get_holes_info(ctx, then, ty)
hs3 = get_holes_info(ctx, otherwise, ty)
hs1 = get_holes_info(ctx, cond, ty, targets)
hs2 = get_holes_info(ctx, then, ty, targets)
hs3 = get_holes_info(ctx, otherwise, ty, targets)
return hs1 | hs2 | hs3
case Abstraction(var_name=vname, body=body):
if isinstance(ty, AbstractionType):
ret = substitution_in_type(ty.type, Var(vname), ty.var_name)
ctx = ctx.with_var(vname, ty.var_type)
return get_holes_info(ctx, body, ret)
return get_holes_info(ctx, body, ret, targets)
else:
assert False, f"Synthesis cannot infer the type of {t}"
case Let(var_name=vname, var_value=value, body=body):
_, t1 = synth(ctx, value)
ctx = ctx.with_var(vname, t1)
hs1 = get_holes_info(ctx, t.var_value, ty)
hs2 = get_holes_info(ctx, t.body, ty)
if not isinstance(value, Hole) and not (isinstance(value, Annotation) and isinstance(value.expr, Hole)):
ctx = ctx.with_var(vname, t1)
hs1 = get_holes_info(ctx, t.var_value, ty, targets)
hs2 = get_holes_info(ctx, t.body, ty, targets)
else:
hs1 = get_holes_info(ctx, t.var_value, ty, targets)
ctx = ctx.with_var(vname, t1)
hs2 = get_holes_info(ctx, t.body, ty, targets)
return hs1 | hs2
case Rec(var_name=vname, var_type=vtype, var_value=value, body=body):
ctx = ctx.with_var(vname, vtype)
hs1 = get_holes_info(
ctx,
value,
vtype,
)
hs2 = get_holes_info(ctx, body, ty)
if any(tup[0] == vname for tup in targets):
hs1 = get_holes_info(ctx, value, vtype, targets)
ctx = ctx.with_var(vname, vtype)
hs2 = get_holes_info(ctx, body, ty, targets)
else:
ctx = ctx.with_var(vname, vtype)
hs1 = get_holes_info(ctx, value, vtype, targets)
hs2 = get_holes_info(ctx, body, ty, targets)

return hs1 | hs2
case TypeApplication(body=body, type=argty):
if isinstance(ty, TypePolymorphism):
ntype = substitute_vartype(ty.body, argty, ty.name)
return get_holes_info(ctx, body, ntype)
return get_holes_info(ctx, body, ntype, targets)
else:
assert False, f"Synthesis cannot infer the type of {t}"
case TypeAbstraction(name=n, kind=k, body=body):
return get_holes_info(ctx.with_typevar(n, k), body, ty)
return get_holes_info(ctx.with_typevar(n, k), body, ty, targets)
case _:
assert False, f"Could not infer the type of {t} for synthesis."

Expand Down
3 changes: 2 additions & 1 deletion aeon/synthesis_grammar/synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def synthesize(
ctx: TypingContext,
ectx: EvaluationContext,
term: Term,
targets=list[tuple[str, list[str]]],
targets: list[tuple[str, list[str]]],
filename: str | None = None,
) -> Term:
"""Synthesizes code for multiple functions, each with multiple holes."""
Expand All @@ -219,6 +219,7 @@ def synthesize(
ctx,
term,
top,
targets,
)

for name, holes_names in targets:
Expand Down

0 comments on commit 783950d

Please sign in to comment.