Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Syntax: Function arguments #45

Merged
merged 10 commits into from
Mar 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/python-app.yml
Original file line number Diff line number Diff line change
Expand Up @@ -50,17 +50,20 @@ jobs:
run: |
files=$(find examples/core -type f -name '*_main.ae')
for file in $files; do
echo "Running $file"
python aeon $file --core --log INFO ERROR TYPECHECKER CONSTRAINT
done
- name: Run Aeon Sugar Files
run: |
files=$(find examples/sugar -type f -name '*_main.ae')
for file in $files; do
echo "Running $file"
python aeon $file --log INFO ERROR TYPECHECKER CONSTRAINT
done
- name: Run PSB2 Solved Problems
run: |
files=$(find examples/PSB2/solved -type f -name '*.ae')
for file in $files; do
echo "Running $file"
python aeon $file --log INFO ERROR TYPECHECKER CONSTRAINT
done
2 changes: 1 addition & 1 deletion Readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def abs (i:Int) : Int {
if i > 0 then i else -i
}

def midpoint(a:Int, b:Int) : Int {
def midpoint (a:In) (b:Int) : Int {
(a + b) / 2
}

Expand Down
20 changes: 12 additions & 8 deletions aeon/sugar/aeon_sugar.lark
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ defs : def* -> list
type_decl : "type" TYPENAME ";" -> type_decl

def : (soft_constraint)* "def" ID ":" type "=" expression ";" -> def_cons
| (soft_constraint)* "def" ID "(" args ")" ":" type "{" expression "}" -> def_fun
| (soft_constraint)* "def" ID args ":" type "{" expression "}" -> def_fun

soft_constraint : "@" ID "(" macro_args ")" -> macro

Expand All @@ -21,9 +21,18 @@ macro_args : -> empty_list


args : -> empty_list
| arg ("," arg)* -> args
| arg (arg)* -> args

arg : ID ":" type -> arg
arg : "(" ID ":" type ")" -> arg
| "(" ID ":" type _PIPE expression ")" -> refined_arg


type : "{" ID ":" type _PIPE expression "}" -> refined_t // TODO delete later
| "(" ID ":" type ")" "->" type -> abstraction_t
| "(" ID ":" type _PIPE expression ")" "->" type -> abstraction_refined_t
| "forall" ID ":" kind "," type -> polymorphism_t // TODO: desugar
| ID -> simple_t
| "(" type ")" -> same


expression : "-" expression_i -> minus
Expand Down Expand Up @@ -84,11 +93,6 @@ expression_simple : "(" expression ")" -> same
| ID -> var
| "?" ID -> hole

type : "{" ID ":" type _PIPE expression "}" -> refined_t
| "(" ID ":" type ")" "->" type -> abstraction_t
| "forall" ID ":" kind "," type -> polymorphism_t // TODO: desugar
| ID -> simple_t
| "(" type ")" -> same

kind : "B" -> base_kind
| "*" -> star_kind
Expand Down
11 changes: 10 additions & 1 deletion aeon/sugar/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
from lark import Lark
from lark import Tree

from aeon.core.substitutions import liquefy
from aeon.core.terms import Abstraction
from aeon.core.terms import Annotation
from aeon.core.types import AbstractionType
from aeon.core.types import AbstractionType, RefinedType
from aeon.core.types import TypeVar
from aeon.frontend.parser import TreeToCore
from aeon.sugar.program import Decorator
Expand All @@ -19,6 +20,7 @@


class TreeToSugar(TreeToCore):

def list(self, args):
return args

Expand Down Expand Up @@ -67,6 +69,13 @@ def args(self, args):
def arg(self, args):
return (args[0], args[1])

def refined_arg(self, args):
return args[0], RefinedType(args[0], args[1], liquefy(args[2]))

def abstraction_refined_t(self, args):
type = RefinedType(args[0], args[1], liquefy(args[2]))
return AbstractionType(str(args[0]), type, args[3])

def abstraction_et(self, args):
return Annotation(
Abstraction(args[0], args[2]),
Expand Down
2 changes: 1 addition & 1 deletion aeon/sugar/program.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __repr__(self):
return f"def {self.name} : {self.type} = {self.body};"
else:
args = ", ".join([f"{n}:{t}" for (n, t) in self.args])
return f"def {self.name} ({args}) -> {self.type} {{\n {self.body} \n}}"
return f"def {self.name} {args} -> {self.type} {{\n {self.body} \n}}"


@dataclass
Expand Down
80 changes: 64 additions & 16 deletions aeon/typechecking/typeinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from loguru import logger

from aeon.core.instantiation import type_substitution
from aeon.core.liquid import LiquidApp
from aeon.core.liquid import LiquidApp, LiquidHole
from aeon.core.liquid import LiquidLiteralBool
from aeon.core.liquid import LiquidLiteralFloat
from aeon.core.liquid import LiquidLiteralInt
Expand All @@ -28,16 +28,16 @@
from aeon.core.types import BaseKind
from aeon.core.types import BaseType
from aeon.core.types import RefinedType
from aeon.core.types import t_bool
from aeon.core.types import t_float
from aeon.core.types import t_int
from aeon.core.types import t_unit
from aeon.core.types import Type
from aeon.core.types import TypePolymorphism
from aeon.core.types import TypeVar
from aeon.core.types import args_size_of_type
from aeon.core.types import bottom
from aeon.core.types import extract_parts
from aeon.core.types import t_bool
from aeon.core.types import t_float
from aeon.core.types import t_int
from aeon.core.types import t_unit
from aeon.core.types import type_free_term_vars
from aeon.prelude.prelude import (
INTEGER_ARITHMETIC_OPERATORS,
Expand Down Expand Up @@ -77,10 +77,14 @@ def __str__(self):


def argument_is_typevar(ty: Type):
return (isinstance(ty, TypeVar) or isinstance(
ty,
RefinedType,
) and isinstance(ty.type, TypeVar))
return (
isinstance(ty, TypeVar)
or isinstance(
ty,
RefinedType,
)
and isinstance(ty.type, TypeVar)
)


def prim_litbool(t: bool) -> RefinedType:
Expand Down Expand Up @@ -154,6 +158,46 @@ def prim_op(t: str) -> Type:
)


def rename_liquid_term(refinement, old_name, new_name):
if isinstance(refinement, LiquidVar):
if refinement.name == old_name:
return LiquidVar(new_name)
else:
return refinement
elif isinstance(refinement, LiquidLiteralBool):
return refinement
elif isinstance(refinement, LiquidLiteralInt):
return refinement
elif isinstance(refinement, LiquidLiteralFloat):
return refinement
elif isinstance(refinement, LiquidApp):
return LiquidApp(
refinement.fun,
[rename_liquid_term(x, old_name, new_name) for x in refinement.args],
)
elif isinstance(refinement, LiquidHole):
if refinement.name == old_name:
return LiquidHole(
new_name,
[(rename_liquid_term(x, old_name, new_name), t) for (x, t) in refinement.argtypes],
)
else:
return LiquidHole(
refinement.name,
[(rename_liquid_term(x, old_name, new_name), t) for (x, t) in refinement.argtypes],
)
else:
assert False


def renamed_refined_type(ty: RefinedType) -> RefinedType:
old_name = ty.name
new_name = "_inner_" + old_name

refinement = rename_liquid_term(ty.refinement, old_name, new_name)
return RefinedType(new_name, ty.type, refinement)


# patterm matching term
def synth(ctx: TypingContext, t: Term) -> tuple[Constraint, Type]:
if isinstance(t, Literal) and t.type == t_unit:
Expand All @@ -178,8 +222,9 @@ def synth(ctx: TypingContext, t: Term) -> tuple[Constraint, Type]:
ty = ctx.type_of(t.name)
if isinstance(ty, BaseType) or isinstance(ty, RefinedType):
ty = ensure_refined(ty)
assert ty.name != t.name
# TODO if the names are equal , we must replace it for another variable
# assert ty.name != t.name
if ty.name == t.name:
ty = renamed_refined_type(ty)
# Self
ty = RefinedType(
ty.name,
Expand All @@ -200,7 +245,8 @@ def synth(ctx: TypingContext, t: Term) -> tuple[Constraint, Type]:
)
if not ty:
raise CouldNotGenerateConstraintException(
f"Variable {t.name} not in context", )
f"Variable {t.name} not in context",
)
return (ctrue, ty)
elif isinstance(t, Application):
(c, ty) = synth(ctx, t.fun)
Expand All @@ -222,7 +268,8 @@ def synth(ctx: TypingContext, t: Term) -> tuple[Constraint, Type]:
return (c0, t_subs)
else:
raise CouldNotGenerateConstraintException(
f"Application {t} is not a function.", )
f"Application {t} is not a function.",
)
elif isinstance(t, Let):
(c1, t1) = synth(ctx, t.var_value)
nctx: TypingContext = ctx.with_var(t.var_name, t1)
Expand Down Expand Up @@ -298,8 +345,8 @@ def check_(ctx: TypingContext, t: Term, ty: Type) -> Constraint:
@wrap_checks # DEMO1
def check(ctx: TypingContext, t: Term, ty: Type) -> Constraint:
if isinstance(t, Abstraction) and isinstance(
ty,
AbstractionType,
ty,
AbstractionType,
): # ??? (\__equal_1__ -> (let _anf_1 = (== _anf_1) in(_anf_1 __equal_1__))) , basetype INT
ret = substitution_in_type(ty.type, Var(t.var_name), ty.var_name)
c = check(ctx.with_var(t.var_name, ty.var_type), t.body, ret)
Expand All @@ -325,7 +372,8 @@ def check(ctx: TypingContext, t: Term, ty: Type) -> Constraint:
assert liq_cond is not None
if not check_type(ctx, t.cond, t_bool):
raise CouldNotGenerateConstraintException(
"If condition not boolean", )
"If condition not boolean",
)
c0 = check(ctx, t.cond, t_bool)
c1 = implication_constraint(
y,
Expand Down
6 changes: 3 additions & 3 deletions examples/PSB2/annotations/multi_objective/gcd.ae
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@
# output : integer
import PSB2;

def gcd ( n:Int, z:Int) : Int {
if z == 0 then n else (gcd(z)(n % z))
def gcd (n:Int) (z:Int) : Int {
if z == 0 then n else gcd z (n % z)
}

def train: TrainData = extract_train_data (load_dataset "gcd" 200 200);
def input_list : List = get_input_list (unpack_train_data train);
def expected_values : List = get_output_list (unpack_train_data train);

@multi_minimize_float( calculate_list_errors (get_gcd_synth_values input_list synth) (expected_values))
def synth ( n:Int, z:Int) : Int {
def synth (n:Int) (z:Int) : Int {
(?hole:Int)
}
# def largest_common_divisor ( n :{x:Int | 1 <= x && x <= 1000000}, m :{y:Int | 1 <= y && y <= 1000000}) : Int { gcd n m }
4 changes: 2 additions & 2 deletions examples/PSB2/annotations/single_objective/gcd.ae
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# output : integer
import PSB2;

def gcd ( n:Int, z:Int) : Int {
def gcd (n:Int) (z:Int) : Int {
if z == 0 then n else (gcd(z)(n % z))
}

Expand All @@ -13,7 +13,7 @@ def input_list : List = get_input_list ( unpack_train_data train);
def expected_values : List = get_output_list ( unpack_train_data train);

@minimize_float( mean_absolute_error ( get_gcd_synth_values input_list synth) ( expected_values))
def synth ( n:Int, z:Int) : Int {
def synth (n:Int) (z:Int) : Int {
(?hole:Int)
}
# def largest_common_divisor ( n :{x:Int | 1 <= x && x <= 1000000}, m :{y:Int | 1 <= y && y <= 1000000}) : Int { gcd n m }
6 changes: 3 additions & 3 deletions examples/PSB2/non_working/gcd.ae
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
# --not- working - recursion problem--
# RecursionError: maximum recursion depth exceeded in comparison

def gcd ( n:Int, z:Int) : Int {
if z == 0 then n else (gcd(z)(n % z))
def gcd (n:Int) (z:Int) : Int {
if z == 0 then n else gcd z (n % z)
}

def synth_gcd ( n:Int, z:Int) : Int {
def synth_gcd (n:Int) (z:Int) : Int {
(?hole:Int)
}
# def largest_common_divisor ( n :{x:Int | 1 <= x && x <= 1000000}, m :{y:Int | 1 <= y && y <= 1000000}) : Int { gcd n m }
Expand Down
4 changes: 2 additions & 2 deletions examples/PSB2/solved/bouncing_balls.ae
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@

def isZero : (n: Int) -> Bool = \n -> n == 0;

def calculate_distance_helper (b_index: Float, h: Float, n: Int, distance: Float) : Float {
def calculate_distance_helper (b_index: Float) (h: Float) (n: Int) (distance: Float) : Float {
if isZero n then
distance
else
Expand All @@ -11,7 +11,7 @@ else
calculate_distance_helper b_index a n1 d1
}

def bouncing_balls (s_height: Float, b_height:Float , b:Int) : Float{
def bouncing_balls (s_height: Float) (b_height:Float) (b:Int) : Float{
bounciness_index : Float = s_height /. b_height;
calculate_distance_helper bounciness_index s_height b 0.0
}
Expand Down
2 changes: 1 addition & 1 deletion examples/PSB2/solved/bowling.ae
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def String_eq : (s:String) -> (s2:String) -> Bool = native "lambda s: lambda s2:

def List_range_step : (start:Int) -> (end:Int) -> (step:Int) -> List = native "lambda s: lambda e: lambda st: list(range(s, e, st))";

def create_mapper (scores:String, i:Int) : Int {
def create_mapper (scores:String) (i:Int) : Int {
current : String = String_get scores i;
next : String = String_get scores (i+1);
if String_eq current "X" then
Expand Down
4 changes: 2 additions & 2 deletions examples/PSB2/solved/dice_game.ae
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ def math : Unit = native_import "math";
def Math_toFloat : (i:Int) -> Float = native "lambda i: float(i)" ;
def Math_max : (i:Int) -> (j:Int) -> Int = native "lambda i: lambda j: max(i,j)" ;

def peter_wins (n:Int, m:Int) : Int {
def peter_wins (n:Int) (m:Int) : Int {
if m == 0 then 0 else Math_max 0 (n - m) + peter_wins n (m - 1)
}

def dice_game (n:{x:Int | 1 <= x && x <= 10000}, m:{y:Int | 1 <= y && y <= 10000}) : Float {
def dice_game (n:Int | 1 <= n && n <= 10000) (m:Int | 1 <= m && m <= 10000) : Float {
a : Float = Math_toFloat (peter_wins n m);
b : Float = Math_toFloat (n * m);
c : Float = a /. b;
Expand Down
2 changes: 1 addition & 1 deletion examples/PSB2/solved/find_pair.ae
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def List_new : {x:List | List_size x == 0} = native "[]" ;
def List_append : (l:List) -> (i: Int) -> {l2:List | List_size l2 == List_size l + 1} = native "lambda xs: lambda x: xs + [x]";

#assuming that the list always has at least 2 elements that sum to the target
def find_pair (numbers: List, target: Int) : List {
def find_pair (numbers: List) (target: Int) : List {
pairs : List = combinations numbers 2;
pred : (s:List) -> Bool = \pair -> List_sum pair == target;
matching_pairs : List = List_filter pred pairs;
Expand Down
2 changes: 1 addition & 1 deletion examples/PSB2/solved/gcd.ae
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
def gcd (n:Int, z:Int) : Int {
def gcd (n:Int) (z:Int) : Int {
if z == 0 then n else gcd z (n % z)
}

Expand Down
2 changes: 1 addition & 1 deletion examples/PSB2/solved/indices_of_substring.ae
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def String_equal : (i:String) -> (j:String) -> Bool = native "lambda i: lambda j

def String_slice : (i:String) -> (j:Int) -> (l:Int)-> String = native "lambda i: lambda j: lambda l: i[j:l]" ;

def indices_of_substring ( text :String, target: String) : List {
def indices_of_substring ( text :String) (target: String) : List {
start: Int = 0;
end: Int = (String_len text) - (String_len target) + 1 ;
step: Int = 1;
Expand Down
2 changes: 1 addition & 1 deletion examples/PSB2/solved/luhn.ae
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def Range_new : (a:Int) -> (b:Int) -> Range = native "lambda a: lambda b: range(

def even : (n:Int) -> Bool = \n -> n % 2 == 0;

def map_digit (x:Int, i:Int) : Int {
def map_digit (x:Int) (i:Int) : Int {
if i % 2 != 0 then ( if x * 2 > 9 then (x*2 - 9) else (x*2) : Int) else x
}

Expand Down
2 changes: 1 addition & 1 deletion examples/PSB2/solved/mastermind.ae
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def Counter_intersection : (c1:Counter) -> (c2:Counter) -> Counter = native "lam
def Counter_values : (c:Counter) -> List = native "lambda c: list(c.values())";
def Tuple : (a:Int) -> (b:Int) -> List = native "lambda a: lambda b: [a,b]";

def mastermind ( code : String, guess: String ) : List {
def mastermind ( code : String) (guess: String ) : List {
items_eq : (x:List) -> Int = \x -> if String_eq (List_get_fst x) (List_get_snd x) then 1 else 0;
black_map : List = Map_bool items_eq (String_Zip code guess);
black_pegs : Int = List_sum black_map;
Expand Down
Loading