diff --git a/aeon/synthesis_grammar/grammar.py b/aeon/synthesis_grammar/grammar.py index ff44e313..293dcbc9 100644 --- a/aeon/synthesis_grammar/grammar.py +++ b/aeon/synthesis_grammar/grammar.py @@ -336,7 +336,11 @@ def gen_grammar_nodes( Returns: list[type]: The list of generated grammar nodes. """ - vars_to_ignore = metadata[synth_func_name]["syn_ignore"] + vars_to_ignore = ( + metadata[synth_func_name]["syn_ignore"] + if synth_func_name in metadata and "syn_ignore" in metadata[synth_func_name].keys() + else [] + ) if grammar_nodes is None: grammar_nodes = [] for var in ctx.vars(): diff --git a/tests/hole_test.py b/tests/hole_test.py index e5179426..85971591 100644 --- a/tests/hole_test.py +++ b/tests/hole_test.py @@ -9,7 +9,7 @@ def extract_target_functions(source): prog = parse_program(source) prog = apply_decorators_in_program(prog) - core, ctx, _ = desugar(prog) + core, ctx, _, _ = desugar(prog) core_anf = ensure_anf(core) check_type_errors(ctx, core_anf, top) return incomplete_functions_and_holes(ctx, core_anf) diff --git a/tests/native_test.py b/tests/native_test.py index c0dfa038..c5ca5467 100644 --- a/tests/native_test.py +++ b/tests/native_test.py @@ -8,7 +8,7 @@ def check_compile(source, ty): - p, ctx, ectx = desugar(parse_program(source)) + p, ctx, ectx, _ = desugar(parse_program(source)) assert check_type(ctx, p, ty) assert eval(p, ectx) == 2 diff --git a/tests/optimization_decorators_test.py b/tests/optimization_decorators_test.py index 3d0ef0f9..cdce3169 100644 --- a/tests/optimization_decorators_test.py +++ b/tests/optimization_decorators_test.py @@ -10,7 +10,7 @@ def extract_core(source: str) -> Term: prog = parse_program(source) - core, ctx, _ = desugar(prog) + core, ctx, _, _ = desugar(prog) core_anf = ensure_anf(core) check_type_errors(ctx, core_anf, top) return core_anf @@ -41,6 +41,7 @@ def main(args:Int) : Unit { core_ast, typing_ctx, evaluation_ctx, + metadata, ) = desugar(prog) core_ast_anf = ensure_anf(core_ast) diff --git a/tests/pow_test.py b/tests/pow_test.py index a967f59e..3732c657 100644 --- a/tests/pow_test.py +++ b/tests/pow_test.py @@ -7,7 +7,7 @@ def check_compile(source, ty): - p, ctx, _ = desugar(parse_program(source)) + p, ctx, _, _ = desugar(parse_program(source)) assert check_type(ctx, p, ty) diff --git a/tests/recursion_test.py b/tests/recursion_test.py index fbc395b1..88553952 100644 --- a/tests/recursion_test.py +++ b/tests/recursion_test.py @@ -7,7 +7,7 @@ def check_compile(source, ty, res): - p, ctx, ectx = desugar(parse_program(source)) + p, ctx, ectx, _ = desugar(parse_program(source)) assert check_type(ctx, p, ty) # assert eval(p, ectx) == res diff --git a/tests/smt_test.py b/tests/smt_test.py index 6fd671c4..05006ea7 100644 --- a/tests/smt_test.py +++ b/tests/smt_test.py @@ -20,7 +20,7 @@ def extract_core(source: str) -> Term: prog = parse_program(source) - core, ctx, _ = desugar(prog) + core, ctx, _, _ = desugar(prog) core_anf = ensure_anf(core) check_type_errors(ctx, core_anf, top) return core_anf @@ -30,8 +30,7 @@ def extract_core(source: str) -> Term: "x", t_int, LiquidApp("==", [LiquidVar("x"), LiquidLiteralInt(3)]), - LiquidConstraint(LiquidApp( - "==", [LiquidVar("x"), LiquidLiteralInt(3)])), + LiquidConstraint(LiquidApp("==", [LiquidVar("x"), LiquidLiteralInt(3)])), ) @@ -47,8 +46,7 @@ def test_smt_example3(): "y", BaseType("a"), LiquidApp("==", [LiquidVar("x"), LiquidVar("y")]), - LiquidConstraint(LiquidApp( - "==", [LiquidVar("x"), LiquidVar("y")])), + LiquidConstraint(LiquidApp("==", [LiquidVar("x"), LiquidVar("y")])), ), ) @@ -80,6 +78,7 @@ def main (x:Int) : Unit { core_ast, typing_ctx, evaluation_ctx, + metadata, ) = desugar(prog) core_ast_anf = ensure_anf(core_ast) @@ -104,6 +103,7 @@ def main (x:Int) : Unit { core_ast, typing_ctx, evaluation_ctx, + metadata, ) = desugar(prog) core_ast_anf = ensure_anf(core_ast) diff --git a/tests/synth_fitness_test.py b/tests/synth_fitness_test.py index 6c136d50..6bf1bbbd 100644 --- a/tests/synth_fitness_test.py +++ b/tests/synth_fitness_test.py @@ -36,10 +36,10 @@ def synth (i: Int): Int { (?hole: Int) * i} def __internal__fitness_function_synth : Int = year - synth(7); """ prog = parse_program(code) - p, ctx, ectx = desugar(prog) + p, ctx, ectx, _ = desugar(prog) p = ensure_anf(p) check_type_errors(ctx, p, top) - term = synthesize(ctx, ectx, p, [("synth", ["hole"])]) + term = synthesize(ctx, ectx, p, [("synth", ["hole"])], {}) assert isinstance(term, Term) @@ -50,9 +50,9 @@ def test_fitness2(): def synth (i:Int) : Int {(?hole: Int) * i} """ prog = parse_program(code) - p, ctx, ectx = desugar(prog) + p, ctx, ectx, _ = desugar(prog) p = ensure_anf(p) check_type_errors(ctx, p, top) - term = synthesize(ctx, ectx, p, [("synth", ["hole"])]) + term = synthesize(ctx, ectx, p, [("synth", ["hole"])], {}) assert isinstance(term, Term)