Skip to content

Commit

Permalink
Merge branch 'annotations_refactor' of github.com:alcides/aeon into a…
Browse files Browse the repository at this point in the history
…nnotations_refactor
  • Loading branch information
alcides committed Feb 7, 2024
2 parents 8a03a05 + 984332c commit fd80d9c
Show file tree
Hide file tree
Showing 38 changed files with 760 additions and 299 deletions.
55 changes: 30 additions & 25 deletions aeon/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import argparse
import sys

from aeon.backend.evaluator import eval
from aeon.backend.evaluator import EvaluationContext
from aeon.backend.evaluator import eval
from aeon.core.types import top
from aeon.decorators import apply_decorators
from aeon.frontend.anf_converter import ensure_anf
Expand All @@ -17,35 +17,30 @@
from aeon.sugar.parser import parse_program
from aeon.sugar.program import Program
from aeon.synthesis_grammar.identification import incomplete_functions_and_holes
from aeon.synthesis_grammar.synthesizer import synthesize
from aeon.synthesis_grammar.synthesizer import synthesize, parse_config
from aeon.typechecking.typeinfer import check_type_errors
from aeon.utils.ctx_helpers import build_context


def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument("filename",
help="name of the aeon files to be synthesized")
parser.add_argument("--core",
action="store_true",
help="synthesize a aeon core file")
parser.add_argument("filename", help="name of the aeon files to be synthesized")
parser.add_argument("--core", action="store_true", help="synthesize a aeon core file")
parser.add_argument(
"-l",
"--log",
nargs="+",
default="",
help="set log level: \nTRACE \nDEBUG \nINFO \nTYPECHECKER \nCONSTRAINT "
"\nWARNINGS \nERROR \nCRITICAL",
help="""set log level: \nTRACE \nDEBUG \nINFO \nWARNINGS \nTYPECHECKER \nSYNTH_TYPE \nCONSTRAINT \nSYNTHESIZER
\nERROR \nCRITICAL""",
)
parser.add_argument("-f",
"--logfile",
action="store_true",
help="export log file")

parser.add_argument("-csv",
"--csv-synth",
action="store_true",
help="export synthesis csv file")
parser.add_argument("-f", "--logfile", action="store_true", help="export log file")

parser.add_argument("-csv", "--csv-synth", action="store_true", help="export synthesis csv file")

parser.add_argument("-gp", "--gp-config", help="path to the GP configuration file")

parser.add_argument("-csec", "--config-section", help="section name in the GP configuration file")
return parser.parse_args()


Expand Down Expand Up @@ -100,15 +95,25 @@ def log_type_errors(errors: list[Exception | str]):
log_type_errors(type_errors)
sys.exit(1)

incomplete_functions: list[tuple[
str,
list[str]]] = incomplete_functions_and_holes(typing_ctx, core_ast_anf)
incomplete_functions: list[tuple[str, list[str]]] = incomplete_functions_and_holes(typing_ctx, core_ast_anf)

if incomplete_functions:
file_name = args.filename if args.csv_synth else None
synthesis_result = synthesize(typing_ctx, evaluation_ctx, core_ast_anf,
incomplete_functions, file_name)
print(f"Best solution: {synthesis_result}")
filename = args.filename if args.csv_synth else None
synth_config = (
parse_config(args.gp_config, args.config_section) if args.gp_config and args.config_section else None
)

synthesis_result = synthesize(
typing_ctx,
evaluation_ctx,
core_ast_anf,
incomplete_functions,
filename,
synth_config,
)
print(f"Best solution:{synthesis_result}")
# print()
# pretty_print_term(ensure_anf(synthesis_result, 200))
sys.exit(1)

eval(core_ast, evaluation_ctx)
98 changes: 97 additions & 1 deletion aeon/core/pprint.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,27 @@
from __future__ import annotations

from aeon.core.liquid import LiquidLiteralBool
from aeon.core.terms import (
Abstraction,
Annotation,
Application,
Hole,
If,
Let,
Literal,
Rec,
Term,
TypeAbstraction,
TypeApplication,
Var,
)
from aeon.core.types import AbstractionType
from aeon.core.types import BaseType
from aeon.core.types import RefinedType
from aeon.core.types import Type
from aeon.core.types import type_free_term_vars
from aeon.core.types import TypeVar
from aeon.core.types import type_free_term_vars
from aeon.synthesis_grammar.grammar import aeon_prelude_ops_to_text


def pretty_print(t: Type) -> str:
Expand All @@ -28,3 +43,84 @@ def pretty_print(t: Type) -> str:
else:
return f"{{{t.name}:{it} | {t.refinement}}}"
assert False


def pretty_print_term(term: Term):
term_str: str = custom_preludes_ops_representation(term)[0]
print(term_str)


ops_to_abstraction: dict[str, str] = {
"%": "Int) -> Int",
"/": "Int) -> Int",
"*": "Int) -> Int",
"-": "Int) -> Int",
"+": "Int) -> Int",
"%.": "Float) -> Float",
"/.": "Float) -> Float",
"*.": "Float) -> Float",
"-.": "Float) -> Float",
"+.": "Float) -> Float",
">=": "Int) -> Bool",
">": "Int) -> Bool",
"<=": "Int) -> Bool",
"<": "Int) -> Bool",
"!=": "Int) -> Bool",
"==": "Int) -> Bool",
}


def custom_preludes_ops_representation(term: Term, counter: int = 0) -> tuple[str, int]:
prelude_operations: dict[str, str] = aeon_prelude_ops_to_text
match term:
case Application(fun=Var(name=var_name), arg=arg) if var_name in prelude_operations.keys():
op = var_name
arg_str, counter = custom_preludes_ops_representation(arg, counter)
counter += 1
new_var_name = f"__{prelude_operations[op]}_{counter}__"
abstraction_type_str = f"({new_var_name}:{ops_to_abstraction[op]}"
personalized_op = f": {abstraction_type_str} = (\\{new_var_name} -> {arg_str} {op} {new_var_name})"
return personalized_op, counter

case Application(fun=fun, arg=arg):
fun_str, counter = custom_preludes_ops_representation(fun, counter)
arg_str, counter = custom_preludes_ops_representation(arg, counter)
return f"= ({fun_str} {arg_str})", counter

case Annotation(expr=expr, type=type):
expr_str, counter = custom_preludes_ops_representation(expr, counter)
return f"({expr_str} : {type})", counter

case Abstraction(var_name=var_name, body=body):
body_str, counter = custom_preludes_ops_representation(body, counter)
return f"(\\{var_name} -> {body_str})", counter

case Let(var_name=var_name, var_value=var_value, body=body):
var_value_prefix = "= " if not isinstance(var_value, Application) else ""
var_value_str, counter = custom_preludes_ops_representation(var_value, counter)
body_str, counter = custom_preludes_ops_representation(body, counter)
return f"(let {var_name} {var_value_prefix}{var_value_str} in\n {body_str})", counter

case Rec(var_name=var_name, var_type=var_type, var_value=var_value, body=body):
var_value_str, counter = custom_preludes_ops_representation(var_value, counter)
body_str, counter = custom_preludes_ops_representation(body, counter)
return f"(let {var_name} : {var_type} = {var_value_str} in\n {body_str})", counter

case If(cond=cond, then=then, otherwise=otherwise):
cond_str, counter = custom_preludes_ops_representation(cond, counter)
then_str, counter = custom_preludes_ops_representation(then, counter)
otherwise_str, counter = custom_preludes_ops_representation(otherwise, counter)
return f"(if {cond_str} then {then_str} else {otherwise_str})", counter

case TypeAbstraction(name=name, kind=kind, body=body):
body_str, counter = custom_preludes_ops_representation(body, counter)
return f"ƛ{name}:{kind}.({body_str})", counter

case TypeApplication(body=body, type=type):
body_str, counter = custom_preludes_ops_representation(body, counter)
return f"({body_str})[{type}]", counter

case Literal(_, _) | Var(_) | Hole(_):
return str(term), counter

return str(term), counter
48 changes: 24 additions & 24 deletions aeon/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,18 @@
from abc import ABC
from dataclasses import dataclass

from aeon.core.liquid import liquid_free_vars
from aeon.core.liquid import LiquidHole
from aeon.core.liquid import LiquidLiteralBool
from aeon.core.liquid import LiquidTerm
from aeon.core.liquid import liquid_free_vars


class Kind(ABC):

def __repr__(self):
return str(self)


class BaseKind(Kind):

def __eq__(self, o):
return self.__class__ == o.__class__

Expand All @@ -25,7 +23,6 @@ def __str__(self):


class StarKind(Kind):

def __eq__(self, o):
return self.__class__ == o.__class__

Expand Down Expand Up @@ -70,7 +67,6 @@ def __hash__(self) -> int:


class Top(Type):

def __repr__(self):
return "⊤"

Expand All @@ -85,7 +81,6 @@ def __hash__(self) -> int:


class Bottom(Type):

def __repr__(self):
return "⊥"

Expand Down Expand Up @@ -123,10 +118,12 @@ def __repr__(self):
return f"({self.var_name}:{self.var_type}) -> {self.type}"

def __eq__(self, other):
return (isinstance(other, AbstractionType)
and self.var_name == other.var_name
and self.var_type == other.var_type
and self.type == other.type)
return (
isinstance(other, AbstractionType)
and self.var_name == other.var_name
and self.var_type == other.var_type
and self.type == other.type
)

def __hash__(self) -> int:
return hash(self.var_name) + hash(self.var_type) + hash(self.type)
Expand All @@ -137,8 +134,7 @@ class RefinedType(Type):
type: BaseType | TypeVar
refinement: LiquidTerm

def __init__(self, name: str, ty: BaseType | TypeVar,
refinement: LiquidTerm):
def __init__(self, name: str, ty: BaseType | TypeVar, refinement: LiquidTerm):
self.name = name
self.type = ty
self.refinement = refinement
Expand All @@ -147,9 +143,12 @@ def __repr__(self):
return f"{{ {self.name}:{self.type} | {self.refinement} }}"

def __eq__(self, other):
return (isinstance(other, RefinedType) and self.name == other.name
and self.type == other.type
and self.refinement == other.refinement)
return (
isinstance(other, RefinedType)
and self.name == other.name
and self.type == other.type
and self.refinement == other.refinement
)

def __hash__(self) -> int:
return hash(self.name) + hash(self.type) + hash(self.refinement)
Expand All @@ -162,9 +161,10 @@ class TypePolymorphism(Type):
body: Type


def extract_parts(t: Type, ) -> tuple[str, BaseType | TypeVar, LiquidTerm]:
assert isinstance(t, BaseType) or isinstance(t, RefinedType) or isinstance(
t, TypeVar)
def extract_parts(
t: Type,
) -> tuple[str, BaseType | TypeVar, LiquidTerm]:
assert isinstance(t, BaseType) or isinstance(t, RefinedType) or isinstance(t, TypeVar)
if isinstance(t, RefinedType):
return (t.name, t.type, t.refinement)
else:
Expand All @@ -177,10 +177,8 @@ def extract_parts(t: Type, ) -> tuple[str, BaseType | TypeVar, LiquidTerm]:

def is_bare(ty: Type) -> bool:
"""Returns whether a type is bare or not."""
bare_base = isinstance(ty, RefinedType) and isinstance(
ty.refinement, LiquidHole)
dependent_function = isinstance(ty, AbstractionType) and is_bare(
ty.var_type) and is_bare(ty.type)
bare_base = isinstance(ty, RefinedType) and isinstance(ty.refinement, LiquidHole)
dependent_function = isinstance(ty, AbstractionType) and is_bare(ty.var_type) and is_bare(ty.type)
type_polymorphism = isinstance(ty, TypePolymorphism) and is_bare(ty.body)
return bare_base or dependent_function or type_polymorphism

Expand All @@ -192,18 +190,20 @@ def base(ty: Type) -> Type:


def type_free_term_vars(t: Type) -> list[str]:
from aeon.prelude.prelude import ALL_OPS

if isinstance(t, BaseType):
return []
elif isinstance(t, TypeVar):
return []
elif isinstance(t, AbstractionType):
afv = type_free_term_vars(t.var_type)
rfv = type_free_term_vars(t.type)
return [x for x in afv + rfv if x != t.var_name]
return [x for x in afv + rfv if x != t.var_name and x not in ALL_OPS]
elif isinstance(t, RefinedType):
ifv = type_free_term_vars(t.type)
rfv = liquid_free_vars(t.refinement)
return [x for x in ifv + rfv if x != t.name]
return [x for x in ifv + rfv if x != t.name and x not in ALL_OPS]
return []


Expand Down
14 changes: 7 additions & 7 deletions aeon/decorators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,17 @@ def fun(...) { ... }
"""

from typing import Callable

from aeon.core.terms import Term
from aeon.sugar.program import Definition
from aeon.synthesis_grammar.decorators import minimize_int
from aeon.synthesis_grammar.decorators import minimize_int, minimize_float, multi_minimize_float

DecoratorType = Callable[[list[Term], Definition], tuple[Definition,
list[Definition]]]
DecoratorType = Callable[[list[Term], Definition], tuple[Definition, list[Definition]]]

decorators_environment: dict[str, DecoratorType] = {
"minimize_int": minimize_int
"minimize_int": minimize_int,
"minimize_float": minimize_float,
"multi_minimize_float": multi_minimize_float,
}


Expand All @@ -27,9 +29,7 @@ def apply_decorators(fun: Definition) -> tuple[Definition, list[Definition]]:
total_extra = []
for decorator in fun.decorators:
if decorator.name not in decorators_environment:
raise Exception(
f"Unknown decorator named {decorator.name}, in function {fun.name}."
)
raise Exception(f"Unknown decorator named {decorator.name}, in function {fun.name}.")
decorator_processor = decorators_environment[decorator.name]
(fun, extra) = decorator_processor(decorator.macro_args, fun)
total_extra.extend(extra)
Expand Down
Loading

0 comments on commit fd80d9c

Please sign in to comment.