Skip to content

Commit

Permalink
Move everything back into rewriter, refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
nielstron committed Mar 20, 2023
1 parent 63bff19 commit 8956f40
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 16 deletions.
3 changes: 0 additions & 3 deletions ARCHITECTURE.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@ Arguments are fully evaluated, they do not require another application of the st
Note that this means the function has access to all variables defined in the surrounding code _at the time of the function being called_.
This is consistent with the way it is done in python.

Also note that functions that take 0 arguments are expected to actually take a single unit argument.
This is due to the fact that UPLC does not feature 0-ary functions, which is emulated by this behavior.
Calls with 0 arguments are transformed by the compiler into calls that pass a single unit argument to the callee.

The python atomic types map to the UPLC builtin equivalents.
They are cast from and to plutus equivalents when passed into the validator and returned from it.
Expand Down
4 changes: 0 additions & 4 deletions opshin/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,10 +402,6 @@ def visit_Call(self, node: TypedCall) -> plt.AST:
# if the function expects input of generic type data, wrap data before passing it inside
a_int = transform_output_map(a.typ)(a_int)
args.append(a_int)
if node.func.typ.typ.argtyps == [UnitInstanceType] and node.args == []:
# this would not pass the type check normally, only possible due to the zero-arg rewrite
# 0-ary functions expect another parameter
args.append(plt.Unit())
return plt.Lambda(
[STATEMONAD],
plt.Apply(
Expand Down
17 changes: 14 additions & 3 deletions opshin/rewrite/rewrite_zero_ary.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
FunctionType,
NoneInstanceType,
TypedConstant,
RawPlutoExpr,
InstanceType,
TypedCall,
UnitInstanceType,
)

"""
Expand All @@ -21,7 +21,7 @@


class RewriteZeroAry(CompilingNodeTransformer):
step = "Rewriting augmenting assignments"
step = "Rewriting zero-ary functions"

def visit_FunctionDef(self, node: TypedFunctionDef) -> TypedFunctionDef:
if len(node.args.args) == 0:
Expand All @@ -30,3 +30,14 @@ def visit_FunctionDef(self, node: TypedFunctionDef) -> TypedFunctionDef:
node.typ.typ.argtyps.append(NoneInstanceType)
self.generic_visit(node)
return node

def visit_Call(self, node: TypedCall) -> TypedCall:
if isinstance(node.func, Name) and node.func.id == "dataclass":
# special case for the dataclass function
return node
if node.func.typ.typ.argtyps == [UnitInstanceType] and node.args == []:
# this would not pass the type check normally, only possible due to the zero-arg rewrite
# 0-ary functions expect another parameter
node.args.append(TypedConstant(None, typ=UnitInstanceType))
self.generic_visit(node)
return node
20 changes: 14 additions & 6 deletions opshin/tests/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -759,6 +759,7 @@ def validator(x: None) -> None:
code = compiler.compile(ast).compile()
res = uplc_eval(uplc.Apply(code, uplc.PlutusInteger(0)))

@unittest.expectedFailure
def test_zero_ary_exec(self):
source_code = """
def a() -> None:
Expand All @@ -771,19 +772,26 @@ def validator(x: None) -> None:
"""
ast = compiler.parse(source_code)
code = compiler.compile(ast).compile()
try:
res = uplc_eval(uplc.Apply(code, uplc.PlutusInteger(0)))
failed = False
except RuntimeError:
failed = True
self.assertTrue(failed, "Machine validated contract")
res = uplc_eval(uplc.Apply(code, uplc.PlutusInteger(0)))

def test_zero_ary_method(self):
source_code = """
def validator(x: None) -> None:
b = b"\\xFF".decode
if False:
b()
"""
ast = compiler.parse(source_code)
code = compiler.compile(ast).compile()
res = uplc_eval(uplc.Apply(code, uplc.PlutusInteger(0)))

@unittest.expectedFailure
def test_zero_ary_method_exec(self):
source_code = """
def validator(x: None) -> None:
b = b"\\xFF".decode
if True:
b()
"""
ast = compiler.parse(source_code)
code = compiler.compile(ast).compile()
Expand Down

0 comments on commit 8956f40

Please sign in to comment.