Skip to content

Commit

Permalink
Merge branch 'annotations_refactor' of github.com:alcides/aeon into a…
Browse files Browse the repository at this point in the history
…nnotations_refactor
  • Loading branch information
alcides committed Nov 15, 2023
2 parents 575df59 + 783950d commit 8a03a05
Show file tree
Hide file tree
Showing 6 changed files with 168 additions and 237 deletions.
3 changes: 2 additions & 1 deletion aeon/sugar/aeon_sugar.lark
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ type_decls : type_decl* -> list

defs : def* -> list

type_decl : "type" ID ";" -> type_decl
type_decl : "type" TYPENAME ";" -> type_decl

def : "def" ID ":" type "=" expression ";" -> def_cons
| (soft_constraint)* "def" ID "(" args ")" ":" type "{" expression "}" -> def_fun
Expand Down Expand Up @@ -99,6 +99,7 @@ INTLIT : /[0-9][0-9]*/
FLOATLIT : SIGNED_FLOAT
STRINGLIT : ESCAPED_STRING

TYPENAME : /[a-zA-Z0-9]+/
ID.0 : CNAME | /\([\+=\>\<!\*\-&\|]{1,3}\)/
PATH : (".." | ID )* "/" ID

Expand Down
98 changes: 49 additions & 49 deletions aeon/synthesis_grammar/grammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,17 +46,12 @@
text_to_aeon_prelude_ops = {v: k for k, v in aeon_prelude_ops_to_text.items()}

grammar_base_types = ["t_Float", "t_Int", "t_String", "t_Bool"]
aeon_to_python_types = {
"Int": int,
"Bool": bool,
"String": str,
"Float": float
}

aeon_to_python_types = {"Int": int, "Bool": bool, "String": str, "Float": float}


# Protocol for classes that can have a get_core method
class HasGetCore(Protocol):

def get_core(self):
...

Expand All @@ -65,7 +60,6 @@ def get_core(self):


def mk_method_core(cls: classType) -> classType:

def get_core(self):
class_name = self.__class__.__name__
# the prefix is either "var_" or "app_"
Expand Down Expand Up @@ -103,7 +97,6 @@ def get_core(self):


def mk_method_core_literal(cls: classType) -> classType:

def get_core(self):
class_name = self.__class__.__name__
class_name_without_prefix = class_name[8:]
Expand All @@ -129,8 +122,7 @@ def get_core(self):
return cls


def find_class_by_name(class_name: str,
grammar_nodes: list[type]) -> tuple[list[type], type]:
def find_class_by_name(class_name: str, grammar_nodes: list[type]) -> tuple[list[type], type]:
"""This function iterates over the provided list of grammar nodes and
returns the node whose name matches the provided name. If no match is found
it creates a new abstract class and a new data class, adds them to the
Expand All @@ -147,40 +139,37 @@ def find_class_by_name(class_name: str,
if cls.__name__ in [class_name, "t_" + class_name]:
return grammar_nodes, cls
if class_name in list(aeon_to_python_types.keys()):
new_abs_class = make_dataclass("t_" + class_name, [], bases=(ABC, ))
new_abs_class = make_dataclass("t_" + class_name, [], bases=(ABC,))
# new_abs_class = type("t_"+class_name, (), {})
# new_abs_class = abstract(new_abs_class)
grammar_nodes.append(new_abs_class)
new_class = make_dataclass(
"literal_" + class_name,
[("value", aeon_to_python_types[class_name])],
bases=(new_abs_class, ),
bases=(new_abs_class,),
)

new_class = mk_method_core_literal(new_class)

grammar_nodes.append(new_class)

else:
class_name = class_name if class_name.startswith("t_") else (
"t_" + class_name)
new_abs_class = make_dataclass(class_name, [], bases=(ABC, ))
class_name = class_name if class_name.startswith("t_") else ("t_" + class_name)
new_abs_class = make_dataclass(class_name, [], bases=(ABC,))
grammar_nodes.append(new_abs_class)
return grammar_nodes, new_abs_class


def is_valid_class_name(class_name: str) -> bool:
return class_name not in prelude_ops and not class_name.startswith(
("_anf_", "target"))
return class_name not in prelude_ops and not class_name.startswith(("_anf_", "target"))


def get_attribute_type_name(attribute_type, parent_name=None):
parent_name = parent_name or ""
while isinstance(attribute_type, AbstractionType):
attribute_type = refined_to_unrefined_type(attribute_type.type)
parent_name += f"t_{get_attribute_type_name(attribute_type, parent_name)}_"
return parent_name + attribute_type.name if isinstance(
attribute_type, BaseType) else parent_name
return parent_name + attribute_type.name if isinstance(attribute_type, BaseType) else parent_name


def generate_class_components(
Expand All @@ -201,15 +190,15 @@ def generate_class_components(
fields = []
parent_name = ""
while isinstance(class_type, AbstractionType):
attribute_name = class_type.var_name.value if isinstance(
class_type.var_name, Token) else class_type.var_name
attribute_type = (refined_to_unrefined_type(class_type.var_type)
if isinstance(class_type.var_type, RefinedType) else
class_type.var_type)
attribute_name = class_type.var_name.value if isinstance(class_type.var_name, Token) else class_type.var_name
attribute_type = (
refined_to_unrefined_type(class_type.var_type)
if isinstance(class_type.var_type, RefinedType)
else class_type.var_type
)
attribute_type_name = get_attribute_type_name(attribute_type)

grammar_nodes, cls = find_class_by_name(attribute_type_name,
grammar_nodes)
grammar_nodes, cls = find_class_by_name(attribute_type_name, grammar_nodes)
fields.append((attribute_name, cls))

parent_name += f"t_{attribute_type_name}_"
Expand All @@ -236,14 +225,13 @@ def create_new_class(class_name: str, parent_class: type, fields=None) -> type:
"""Creates a new class with the given name, parent class, and fields."""
if fields is None:
fields = []
new_class = make_dataclass(class_name, fields, bases=(parent_class, ))
new_class = make_dataclass(class_name, fields, bases=(parent_class,))
new_class = mk_method_core(new_class)

return new_class


def create_class_from_ctx_var(var: tuple,
grammar_nodes: list[type]) -> list[type]:
def create_class_from_ctx_var(var: tuple, grammar_nodes: list[type]) -> list[type]:
"""Creates a new class based on a context variable and adds it to the list
of grammar nodes.
Expand Down Expand Up @@ -282,31 +270,26 @@ def create_class_from_ctx_var(var: tuple,
parent_class_name = parent_type.name
else:
raise Exception(f"parent class name not definied: {(parent_type)}")
grammar_nodes, parent_class = find_class_by_name(parent_class_name,
grammar_nodes)
grammar_nodes, parent_class = find_class_by_name(parent_class_name, grammar_nodes)

new_class_app = create_new_class(f"app_{class_name}", parent_class, fields)
grammar_nodes.append(new_class_app)

# class var_function_name
if isinstance(class_type, AbstractionType):
grammar_nodes, parent_class = find_class_by_name(
abstraction_type_class_name, grammar_nodes)
grammar_nodes, parent_class = find_class_by_name(abstraction_type_class_name, grammar_nodes)

new_class_var = create_new_class(f"var_{class_name}", parent_class)
grammar_nodes.append(new_class_var)

return grammar_nodes


def create_if_class(class_name: str, parent_class_name: str,
grammar_nodes: list[type]) -> list[type]:
def create_if_class(class_name: str, parent_class_name: str, grammar_nodes: list[type]) -> list[type]:
grammar_nodes, cond_class = find_class_by_name("Bool", grammar_nodes)
grammar_nodes, parent_class = find_class_by_name(parent_class_name,
grammar_nodes)
grammar_nodes, parent_class = find_class_by_name(parent_class_name, grammar_nodes)

fields = [("cond", cond_class), ("then", parent_class),
("otherwise", parent_class)]
fields = [("cond", cond_class), ("then", parent_class), ("otherwise", parent_class)]

if_class = create_new_class(class_name, parent_class, fields)
grammar_nodes.append(if_class)
Expand All @@ -315,17 +298,17 @@ def create_if_class(class_name: str, parent_class_name: str,


def build_control_flow_grammar_nodes(grammar_nodes: list[type]) -> list[type]:
grammar_nodes_names_set = {cls.__name__ for cls in grammar_nodes}
for base_type in grammar_base_types:
if base_type in grammar_nodes_names_set:
grammar_nodes = create_if_class(f"If_{base_type}", base_type,
grammar_nodes)
types_names_set = {
cls.__name__
for cls in grammar_nodes
if cls.__base__ is ABC and not any(issubclass(cls, other) and cls is not other for other in grammar_nodes)
}
for ty_name in types_names_set:
grammar_nodes = create_if_class(f"If_{ty_name}", ty_name, grammar_nodes)
return grammar_nodes


def gen_grammar_nodes(ctx: TypingContext,
synth_func_name: str,
grammar_nodes: list[type] | None = None) -> list[type]:
def gen_grammar_nodes(ctx: TypingContext, synth_func_name: str, grammar_nodes: list[type] | None = None) -> list[type]:
"""Generate grammar nodes from the variables in the given TypingContext.
This function iterates over the variables in the provided TypingContext. For each variable,
Expand All @@ -346,6 +329,8 @@ def gen_grammar_nodes(ctx: TypingContext,
if var[0] != synth_func_name:
grammar_nodes = create_class_from_ctx_var(var, grammar_nodes)
grammar_nodes = build_control_flow_grammar_nodes(grammar_nodes)

##print_grammar_nodes(grammar_nodes)
return grammar_nodes


Expand All @@ -360,7 +345,9 @@ def get_grammar_node(node_name: str, nodes: list[type]) -> type | None:
Returns:
type: The node with the matching name
"""
return next((n for n in nodes if n.__name__ == node_name), )
return next(
(n for n in nodes if n.__name__ == node_name),
)


def convert_to_term(inp):
Expand All @@ -373,3 +360,16 @@ def convert_to_term(inp):
elif isinstance(inp, float):
return Literal(inp, type=t_float)
raise Exception(f"unable to converto to term : {type(inp)}")


def print_grammar_nodes(grammar_nodes: list[type]):
for cls in grammar_nodes:
parents = [base.__name__ for base in cls.__bases__]
print(f"class {cls.__name__} ({', '.join(parents)}):")
class_vars = cls.__annotations__
if class_vars:
for var_name, var_type in class_vars.items():
print(f"\t {var_name}: {var_type.__name__}")
else:
print("\t pass")
print("---------------------------------------------------")
50 changes: 28 additions & 22 deletions aeon/synthesis_grammar/identification.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def get_holes_info(
ctx: TypingContext,
t: Term,
ty: Type,
targets: list[tuple[str, list[str]]],
) -> dict[str, tuple[Type, TypingContext]]:
"""Retrieve the Types of "holes" in a given Term and TypingContext.
Expand All @@ -36,9 +37,7 @@ def get_holes_info(
ctx (TypingContext): The current TypingContext.
t (Term): The term to analyze.
ty (Type): The current type.
holes (dict[str, tuple[Type, TypingContext, str]]: The current dictionary of hole types. Defaults to None.
Returns:
dict[str, tuple[Type, TypingContext, str]]: The updated dictionary of hole Types and their TypingContexts.
targets (list(tuple(str, list(str)))): List of tuples functions names that contains holes and the name holes
"""
match t:
case Annotation(expr=Hole(name=hname), type=ty):
Expand All @@ -52,46 +51,53 @@ def get_holes_info(
case Var(_):
return {}
case Annotation(expr=expr, type=ty):
return get_holes_info(ctx, expr, ty)
return get_holes_info(ctx, expr, ty, targets)
case Application(fun=fun, arg=arg):
hs1 = get_holes_info(ctx, fun, ty)
hs2 = get_holes_info(ctx, arg, ty)
hs1 = get_holes_info(ctx, fun, ty, targets)
hs2 = get_holes_info(ctx, arg, ty, targets)
return hs1 | hs2
case If(cond=cond, then=then, otherwise=otherwise):
hs1 = get_holes_info(ctx, cond, ty)
hs2 = get_holes_info(ctx, then, ty)
hs3 = get_holes_info(ctx, otherwise, ty)
hs1 = get_holes_info(ctx, cond, ty, targets)
hs2 = get_holes_info(ctx, then, ty, targets)
hs3 = get_holes_info(ctx, otherwise, ty, targets)
return hs1 | hs2 | hs3
case Abstraction(var_name=vname, body=body):
if isinstance(ty, AbstractionType):
ret = substitution_in_type(ty.type, Var(vname), ty.var_name)
ctx = ctx.with_var(vname, ty.var_type)
return get_holes_info(ctx, body, ret)
return get_holes_info(ctx, body, ret, targets)
else:
assert False, f"Synthesis cannot infer the type of {t}"
case Let(var_name=vname, var_value=value, body=body):
_, t1 = synth(ctx, value)
ctx = ctx.with_var(vname, t1)
hs1 = get_holes_info(ctx, t.var_value, ty)
hs2 = get_holes_info(ctx, t.body, ty)
if not isinstance(value, Hole) and not (isinstance(value, Annotation) and isinstance(value.expr, Hole)):
ctx = ctx.with_var(vname, t1)
hs1 = get_holes_info(ctx, t.var_value, ty, targets)
hs2 = get_holes_info(ctx, t.body, ty, targets)
else:
hs1 = get_holes_info(ctx, t.var_value, ty, targets)
ctx = ctx.with_var(vname, t1)
hs2 = get_holes_info(ctx, t.body, ty, targets)
return hs1 | hs2
case Rec(var_name=vname, var_type=vtype, var_value=value, body=body):
ctx = ctx.with_var(vname, vtype)
hs1 = get_holes_info(
ctx,
value,
vtype,
)
hs2 = get_holes_info(ctx, body, ty)
if any(tup[0] == vname for tup in targets):
hs1 = get_holes_info(ctx, value, vtype, targets)
ctx = ctx.with_var(vname, vtype)
hs2 = get_holes_info(ctx, body, ty, targets)
else:
ctx = ctx.with_var(vname, vtype)
hs1 = get_holes_info(ctx, value, vtype, targets)
hs2 = get_holes_info(ctx, body, ty, targets)

return hs1 | hs2
case TypeApplication(body=body, type=argty):
if isinstance(ty, TypePolymorphism):
ntype = substitute_vartype(ty.body, argty, ty.name)
return get_holes_info(ctx, body, ntype)
return get_holes_info(ctx, body, ntype, targets)
else:
assert False, f"Synthesis cannot infer the type of {t}"
case TypeAbstraction(name=n, kind=k, body=body):
return get_holes_info(ctx.with_typevar(n, k), body, ty)
return get_holes_info(ctx.with_typevar(n, k), body, ty, targets)
case _:
assert False, f"Could not infer the type of {t} for synthesis."

Expand Down
Loading

0 comments on commit 8a03a05

Please sign in to comment.