Skip to content

Commit

Permalink
Merge pull request #309 from OpShin/feat/dictcomp
Browse files Browse the repository at this point in the history
Add support for dictionary comprehensions
  • Loading branch information
nielstron authored Jan 17, 2024
2 parents 833ce51 + aa9daaa commit aa75f24
Show file tree
Hide file tree
Showing 5 changed files with 109 additions and 1 deletion.
12 changes: 12 additions & 0 deletions examples/dict_comprehensions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#!opshin
from opshin.prelude import *


def validator(n: int, even: bool) -> Dict[int, int]:
if even:
# generate even squares
res = {k: k * k for k in range(n) if k % 2 == 0}
else:
# generate all squares
res = {k: k * k for k in range(n)}
return res
51 changes: 51 additions & 0 deletions opshin/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -977,6 +977,57 @@ def visit_ListComp(self, node: TypedListComp) -> plt.AST:
empty_list_con,
)

def visit_DictComp(self, node: TypedDictComp) -> plt.AST:
assert len(node.generators) == 1, "Currently only one generator supported"
gen = node.generators[0]
assert isinstance(gen.iter.typ, InstanceType), "Only lists are valid generators"
assert isinstance(gen.iter.typ.typ, ListType), "Only lists are valid generators"
assert isinstance(
gen.target, Name
), "Can only assign value to singleton element"
lst = self.visit(gen.iter)
ifs = None
for ifexpr in gen.ifs:
if ifs is None:
ifs = self.visit(ifexpr)
else:
ifs = plt.And(ifs, self.visit(ifexpr))
map_fun = OLambda(
["x"],
plt.Let(
[(gen.target.id, plt.Delay(OVar("x")))],
plt.MkPairData(
transform_output_map(node.key.typ)(
self.visit(node.key),
),
transform_output_map(node.value.typ)(
self.visit(node.value),
),
),
),
)
empty_list_con = plt.EmptyDataPairList()
if ifs is not None:
filter_fun = OLambda(
["x"],
plt.Let(
[(gen.target.id, plt.Delay(OVar("x")))],
ifs,
),
)
return plt.MapFilterList(
lst,
filter_fun,
map_fun,
empty_list_con,
)
else:
return plt.MapList(
lst,
map_fun,
empty_list_con,
)

def visit_FormattedValue(self, node: TypedFormattedValue) -> plt.AST:
return plt.Apply(
node.value.typ.stringify(),
Expand Down
24 changes: 24 additions & 0 deletions opshin/tests/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,6 +658,30 @@ def test_list_comprehension_all(self):
"List comprehension incorrectly evaluated",
)

def test_dict_comprehension_even(self):
input_file = "examples/dict_comprehensions.py"
with open(input_file) as fp:
source_code = fp.read()
ret = eval_uplc_value(source_code, 8, 1)
ret = {x.value: y.value for x, y in ret.items()}
self.assertEqual(
ret,
{x: x * x for x in range(8) if x % 2 == 0},
"Dict comprehension incorrectly evaluated",
)

def test_dict_comprehension_all(self):
input_file = "examples/dict_comprehensions.py"
with open(input_file) as fp:
source_code = fp.read()
ret = eval_uplc_value(source_code, 8, 0)
ret = {x.value: y.value for x, y in ret.items()}
self.assertEqual(
ret,
{x: x * x for x in range(8)},
"Dict comprehension incorrectly evaluated",
)

@hypothesis.given(some_output)
def test_union_type_attr_access_all_records(self, x):
source_code = """
Expand Down
15 changes: 15 additions & 0 deletions opshin/type_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -904,6 +904,21 @@ def visit_ListComp(self, node: ListComp) -> TypedListComp:
typed_listcomp.typ = InstanceType(ListType(typed_listcomp.elt.typ))
return typed_listcomp

def visit_DictComp(self, node: DictComp) -> TypedDictComp:
typed_dictcomp = copy(node)
# inside the comprehension is a seperate scope
self.enter_scope()
# first evaluate generators for assigned variables
typed_dictcomp.generators = [self.visit(s) for s in node.generators]
# then evaluate elements
typed_dictcomp.key = self.visit(node.key)
typed_dictcomp.value = self.visit(node.value)
self.exit_scope()
typed_dictcomp.typ = InstanceType(
DictType(typed_dictcomp.key.typ, typed_dictcomp.value.typ)
)
return typed_dictcomp

def visit_FormattedValue(self, node: FormattedValue) -> TypedFormattedValue:
typed_node = copy(node)
typed_node.value = self.visit(node.value)
Expand Down
8 changes: 7 additions & 1 deletion opshin/typed_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,14 @@ class typedcomprehension(typedexpr, comprehension):


class TypedListComp(typedexpr, ListComp):
generators: typing.List[typedcomprehension]
elt: typedexpr
generators: typing.List[typedcomprehension]


class TypedDictComp(typedexpr, DictComp):
key: typedexpr
value: typedexpr
generators: typing.List[typedcomprehension]


class TypedFormattedValue(typedexpr, FormattedValue):
Expand Down

0 comments on commit aa75f24

Please sign in to comment.