diff --git a/examples/dict_comprehensions.py b/examples/dict_comprehensions.py new file mode 100644 index 00000000..fbb7e971 --- /dev/null +++ b/examples/dict_comprehensions.py @@ -0,0 +1,12 @@ +#!opshin +from opshin.prelude import * + + +def validator(n: int, even: bool) -> Dict[int, int]: + if even: + # generate even squares + res = {k: k * k for k in range(n) if k % 2 == 0} + else: + # generate all squares + res = {k: k * k for k in range(n)} + return res diff --git a/opshin/compiler.py b/opshin/compiler.py index 89756629..6ff02d6c 100644 --- a/opshin/compiler.py +++ b/opshin/compiler.py @@ -977,6 +977,57 @@ def visit_ListComp(self, node: TypedListComp) -> plt.AST: empty_list_con, ) + def visit_DictComp(self, node: TypedDictComp) -> plt.AST: + assert len(node.generators) == 1, "Currently only one generator supported" + gen = node.generators[0] + assert isinstance(gen.iter.typ, InstanceType), "Only lists are valid generators" + assert isinstance(gen.iter.typ.typ, ListType), "Only lists are valid generators" + assert isinstance( + gen.target, Name + ), "Can only assign value to singleton element" + lst = self.visit(gen.iter) + ifs = None + for ifexpr in gen.ifs: + if ifs is None: + ifs = self.visit(ifexpr) + else: + ifs = plt.And(ifs, self.visit(ifexpr)) + map_fun = OLambda( + ["x"], + plt.Let( + [(gen.target.id, plt.Delay(OVar("x")))], + plt.MkPairData( + transform_output_map(node.key.typ)( + self.visit(node.key), + ), + transform_output_map(node.value.typ)( + self.visit(node.value), + ), + ), + ), + ) + empty_list_con = plt.EmptyDataPairList() + if ifs is not None: + filter_fun = OLambda( + ["x"], + plt.Let( + [(gen.target.id, plt.Delay(OVar("x")))], + ifs, + ), + ) + return plt.MapFilterList( + lst, + filter_fun, + map_fun, + empty_list_con, + ) + else: + return plt.MapList( + lst, + map_fun, + empty_list_con, + ) + def visit_FormattedValue(self, node: TypedFormattedValue) -> plt.AST: return plt.Apply( node.value.typ.stringify(), diff --git a/opshin/tests/test_misc.py b/opshin/tests/test_misc.py index 55a02371..1b79f816 100644 --- a/opshin/tests/test_misc.py +++ b/opshin/tests/test_misc.py @@ -658,6 +658,30 @@ def test_list_comprehension_all(self): "List comprehension incorrectly evaluated", ) + def test_dict_comprehension_even(self): + input_file = "examples/dict_comprehensions.py" + with open(input_file) as fp: + source_code = fp.read() + ret = eval_uplc_value(source_code, 8, 1) + ret = {x.value: y.value for x, y in ret.items()} + self.assertEqual( + ret, + {x: x * x for x in range(8) if x % 2 == 0}, + "Dict comprehension incorrectly evaluated", + ) + + def test_dict_comprehension_all(self): + input_file = "examples/dict_comprehensions.py" + with open(input_file) as fp: + source_code = fp.read() + ret = eval_uplc_value(source_code, 8, 0) + ret = {x.value: y.value for x, y in ret.items()} + self.assertEqual( + ret, + {x: x * x for x in range(8)}, + "Dict comprehension incorrectly evaluated", + ) + @hypothesis.given(some_output) def test_union_type_attr_access_all_records(self, x): source_code = """ diff --git a/opshin/type_inference.py b/opshin/type_inference.py index 451442b0..4a27a9f4 100644 --- a/opshin/type_inference.py +++ b/opshin/type_inference.py @@ -904,6 +904,21 @@ def visit_ListComp(self, node: ListComp) -> TypedListComp: typed_listcomp.typ = InstanceType(ListType(typed_listcomp.elt.typ)) return typed_listcomp + def visit_DictComp(self, node: DictComp) -> TypedDictComp: + typed_dictcomp = copy(node) + # inside the comprehension is a seperate scope + self.enter_scope() + # first evaluate generators for assigned variables + typed_dictcomp.generators = [self.visit(s) for s in node.generators] + # then evaluate elements + typed_dictcomp.key = self.visit(node.key) + typed_dictcomp.value = self.visit(node.value) + self.exit_scope() + typed_dictcomp.typ = InstanceType( + DictType(typed_dictcomp.key.typ, typed_dictcomp.value.typ) + ) + return typed_dictcomp + def visit_FormattedValue(self, node: FormattedValue) -> TypedFormattedValue: typed_node = copy(node) typed_node.value = self.visit(node.value) diff --git a/opshin/typed_ast.py b/opshin/typed_ast.py index 2f8b14b3..134b349c 100644 --- a/opshin/typed_ast.py +++ b/opshin/typed_ast.py @@ -118,8 +118,14 @@ class typedcomprehension(typedexpr, comprehension): class TypedListComp(typedexpr, ListComp): - generators: typing.List[typedcomprehension] elt: typedexpr + generators: typing.List[typedcomprehension] + + +class TypedDictComp(typedexpr, DictComp): + key: typedexpr + value: typedexpr + generators: typing.List[typedcomprehension] class TypedFormattedValue(typedexpr, FormattedValue):