Skip to content

Commit

Permalink
Fixed a lot of issues with inlining in general, added tests
Browse files Browse the repository at this point in the history
  • Loading branch information
marcoeilers committed Nov 30, 2024
1 parent 00a5a40 commit 681ea6a
Show file tree
Hide file tree
Showing 18 changed files with 271 additions and 13 deletions.
13 changes: 13 additions & 0 deletions src/nagini_translation/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -675,6 +675,11 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
self.current_class._has_classmethod = True
func.predicate = self.is_predicate(node)
if self.is_inline_method(node):
if name == '__init__':
raise UnsupportedException(node, 'Inlining constructors is currently not supported.')
decorators = {d.id for d in node.decorator_list if isinstance(d, ast.Name)}
if len(decorators) > 1:
raise UnsupportedException(node, 'Unsupported decorator for inline function.')
func.inline = True
if func.predicate:
func.contract_only = self.is_declared_contract_only(node)
Expand Down Expand Up @@ -911,13 +916,21 @@ def track_access(self, node: ast.AST, var: Union[PythonVar, PythonField]) -> Non
elif isinstance(node.ctx, ast.Store):
var.writes.append(node)

def _check_not_in_inline_method(self, node: ast.Call) -> None:
if isinstance(self.stmt_container, PythonMethod) and self.stmt_container.inline:
raise InvalidProgramException(node, 'contract.in.inline.method',
'Inline methods must not have specifications.')

def visit_Call(self, node: ast.Call) -> None:
"""
Collects preconditions, postconditions, raised exceptions and
invariants.
"""
if (isinstance(node.func, ast.Name) and
node.func.id in CONTRACT_WRAPPER_FUNCS):
if node.func.id != 'Invariant':
self._check_not_in_inline_method(node)

if node.func.id == 'Requires':
self.stmt_container.precondition.append(
(node.args[0], self._aliases.copy()))
Expand Down
5 changes: 3 additions & 2 deletions src/nagini_translation/lib/program_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1085,6 +1085,9 @@ def process(self, sil_name: str, translator: 'Translator') -> None:
if self.overrides.inline:
raise InvalidProgramException(self.node, 'overriding.inline.method',
'Functions marked to be inlined cannot be overridden.')
if self.inline:
raise InvalidProgramException(self.node, 'overriding.inline.method',
'Functions marked to be inlined cannot override other methods.')
except KeyError:
pass
for local in self.locals:
Expand Down Expand Up @@ -1582,8 +1585,6 @@ def process(self, sil_name: str, translator: 'Translator') -> None:
this Python variable.
"""
super().process(sil_name, translator)
if sil_name == "iterable_1":
print("++")
self._translator = translator
module = self.type.module
self.decl = translator.translate_pythonvar_decl(self, module)
Expand Down
2 changes: 1 addition & 1 deletion src/nagini_translation/sif/translators/method.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def _method_body_postamble(self, method: PythonMethod, ctx: Context) -> List[Stm
def _create_method_epilog(self, method: PythonMethod, ctx: Context) -> List[Stmt]:
# With the extended AST we don't need a label at the end of the method.
# Check that no undeclared exceptions are raised. (but not for main method)
if not method.declared_exceptions:
if not (method.declared_exceptions or method.inline):
no_info = self.no_info(ctx)
error_string = '"method raises no exceptions"'
error_pos = self.to_position(method.node, ctx, error_string)
Expand Down
9 changes: 7 additions & 2 deletions src/nagini_translation/translators/call.py
Original file line number Diff line number Diff line change
Expand Up @@ -942,6 +942,11 @@ def inline_method(self, method: PythonMethod, args: List[PythonVar],
start, end = get_body_indices(method.node.body)
stmts = []

if error_var:
pos = self.no_position(ctx)
info = self.no_info(ctx)
stmts.append(self.viper.LocalVarAssign(error_var.ref(), self.viper.NullLit(pos, info), pos, info))

for stmt in method.node.body[start:end]:
stmts += self.translate_stmt(stmt, ctx)

Expand Down Expand Up @@ -989,7 +994,7 @@ def _inline_call(self, method: PythonMethod, node: ast.Call, is_super: bool,
self.translator)
optional_error_var = None
error_var = self.get_error_var(node, ctx)
if method.declared_exceptions:
if method.declared_exceptions or method.inline:
var = PythonVar(ERROR_NAME, None,
ctx.module.global_module.classes['Exception'])
var._ref = error_var
Expand All @@ -1002,7 +1007,7 @@ def _inline_call(self, method: PythonMethod, node: ast.Call, is_super: bool,
stmts += inline_stmts
if end_lbl:
stmts.append(end_lbl)
if method.declared_exceptions:
if method.declared_exceptions or method.inline:
stmts += self.create_exception_catchers(error_var,
ctx.actual_function.try_blocks, node, ctx)
# Return result
Expand Down
2 changes: 1 addition & 1 deletion src/nagini_translation/translators/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -815,7 +815,7 @@ def get_error_var(self, stmt: ast.AST,
if err_var.sil_name in ctx.var_aliases:
err_var = ctx.var_aliases[err_var.sil_name]
return err_var.ref()
if ctx.actual_function.declared_exceptions:
if ctx.actual_function.declared_exceptions or ctx.actual_function.inline:
return ctx.error_var.ref()
else:
new_var = ctx.current_function.create_variable('error',
Expand Down
2 changes: 1 addition & 1 deletion src/nagini_translation/translators/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,7 @@ def create_exception_catchers(self, var: PythonVar,
else:
end_label = ctx.get_label_name(END_LABEL)
goto_end = self.viper.Goto(end_label, position, self.no_info(ctx))
if ctx.actual_function.declared_exceptions:
if ctx.actual_function.declared_exceptions or ctx.actual_function.inline:
assignerror = self.viper.LocalVarAssign(err_var, var, position,
self.no_info(ctx))
uncaught_option = self.translate_block([assignerror, goto_end],
Expand Down
2 changes: 1 addition & 1 deletion src/nagini_translation/translators/method.py
Original file line number Diff line number Diff line change
Expand Up @@ -751,7 +751,7 @@ def translate_finally(self, block: PythonTryBlock,
goto_continue = self.viper.Goto(loop.end_label, pos, info)
break_block = [goto_break]
continue_block = [goto_continue]
if ctx.actual_function.declared_exceptions:
if ctx.actual_function.declared_exceptions or ctx.actual_function.inline:
# Assign error to error output var
error_var = ctx.error_var.ref()
block_error_var = block.get_error_var(self.translator)
Expand Down
2 changes: 1 addition & 1 deletion src/nagini_translation/translators/obligation/method.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class MethodObligationTranslator(CommonObligationTranslator):
"""Class for translating obligations in methods."""

def _get_obligation_info(self, ctx: Context) -> BaseObligationInfo:
return ctx.current_function.obligation_info
return ctx.actual_function.obligation_info

def _create_obligation_instance_use(
self, obligation_instance: ObligationInstance,
Expand Down
3 changes: 2 additions & 1 deletion src/nagini_translation/translators/obligation/waitlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,8 @@ def _translate_waitlevel_method(
info = self.no_info(ctx)

obligation_info = ctx.current_function.obligation_info
guard = obligation_info.get_wait_level_guard(node.left)
actual_obligation_info = ctx.actual_function.obligation_info
guard = actual_obligation_info.get_wait_level_guard(node.left)
exhale = self._create_level_below_inex(
guard, expr, obligation_info.residue_level, ctx)
translated_exhale = exhale.translate(self, ctx, position, info)
Expand Down
6 changes: 3 additions & 3 deletions src/nagini_translation/translators/statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -1430,14 +1430,14 @@ def _set_result_none(self, ctx: Context) -> List[Stmt]:
null = self.viper.NullLit(self.no_position(ctx), self.no_info(ctx))
if ctx.actual_function.type:
result_none = self.viper.LocalVarAssign(
ctx.actual_function.result.ref(),
ctx.result_var.ref(),
null, self.no_position(ctx),
self.no_info(ctx))
result.append(result_none)
# Do the same for the error variable
if ctx.actual_function.declared_exceptions:
if ctx.actual_function.declared_exceptions or ctx.actual_function.inline:
error_none = self.viper.LocalVarAssign(
ctx.actual_function.error_var.ref(),
ctx.error_var.ref(),
null, self.no_position(ctx), self.no_info(ctx))
result.append(error_none)
return result
Expand Down
10 changes: 10 additions & 0 deletions tests/functional/translation/test_inline_1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Any copyright is dedicated to the Public Domain.
# http://creativecommons.org/publicdomain/zero/1.0/

from nagini_contracts.contracts import *


@Predicate
@Inline
def test1() -> int: #:: ExpectedOutput(invalid.program:decorators.incompatible)
return 5
10 changes: 10 additions & 0 deletions tests/functional/translation/test_inline_2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Any copyright is dedicated to the Public Domain.
# http://creativecommons.org/publicdomain/zero/1.0/

from nagini_contracts.contracts import *


@Pure
@Inline
def test1() -> int: #:: ExpectedOutput(invalid.program:decorators.incompatible)
return 5
16 changes: 16 additions & 0 deletions tests/functional/translation/test_inline_3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Any copyright is dedicated to the Public Domain.
# http://creativecommons.org/publicdomain/zero/1.0/

from nagini_contracts.contracts import *


class A:

def test1(self) -> int:
return 5

class B(A):

@Inline
def test1(self) -> int: #:: ExpectedOutput(invalid.program:overriding.inline.method)
return 6
16 changes: 16 additions & 0 deletions tests/functional/translation/test_inline_4.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Any copyright is dedicated to the Public Domain.
# http://creativecommons.org/publicdomain/zero/1.0/

from nagini_contracts.contracts import *


class A:

@Inline
def test1(self) -> int:
return 5

class B(A):

def test1(self) -> int: #:: ExpectedOutput(invalid.program:overriding.inline.method)
return 6
10 changes: 10 additions & 0 deletions tests/functional/translation/test_inline_5.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Any copyright is dedicated to the Public Domain.
# http://creativecommons.org/publicdomain/zero/1.0/

from nagini_contracts.contracts import *


@Inline
def test1() -> int:
Ensures(Result() > 0) #:: ExpectedOutput(invalid.program:contract.in.inline.method)
return 6
10 changes: 10 additions & 0 deletions tests/functional/translation/test_inline_6.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Any copyright is dedicated to the Public Domain.
# http://creativecommons.org/publicdomain/zero/1.0/

from nagini_contracts.contracts import *


@Inline
def test1(i: int) -> int:
Requires(i > 0) #:: ExpectedOutput(invalid.program:contract.in.inline.method)
return 6
158 changes: 158 additions & 0 deletions tests/functional/verification/test_inline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
# Any copyright is dedicated to the Public Domain.
# http://creativecommons.org/publicdomain/zero/1.0/

from nagini_contracts.obligations import MustTerminate
from nagini_contracts.contracts import *

@Inline
def power3_A(input: int) -> int:
out_a = input
try:

for i in range(2):
Invariant(out_a == input * (input ** len(Previous(i))))
try:
out_a = out_a * input
out_a -= 5
finally:
out_a = out_a + 5
finally:
out_a = out_a - 5
return out_a + 5


@Inline
def power3_B(input: int) -> int:
i = 0
out_b = input
while i < 2:
Invariant(0 <= i and i <= 2)
Invariant(out_b == input * (input ** i))
out_b = out_b * input
i += 1
return out_b

@Inline
def not_always_correct(i: int, j: int) -> int:
if j < 0:
assert False # would lead to an error but is not reached by any call

if i == 0:
res = 8
#:: ExpectedOutput(assert.failed:assertion.false,L1)
assert False # would lead to an error
elif i > 0:
res = 14
else:
res = 9
return res

def partly_correct_caller() -> None:
tst = not_always_correct(5, 8)
Assert(tst == 14)
#:: Label(L1)
tst = not_always_correct(0, 8)

@Inline
def may_not_terminate(i: int) -> None:
will_terminate = i >= 0
#:: ExpectedOutput(leak_check.failed:loop_context.has_unsatisfied_obligations,L2)
while i != 0:
Invariant(Implies(will_terminate, MustTerminate(i)))
i -= 1

def should_terminate() -> None:
Requires(MustTerminate(5))

may_not_terminate(8)

#:: Label(L2)
may_not_terminate(-6)


class A:
@Inline
def foo(self) -> int:
return 1

def bar(self) -> int:
return 1

def test_calls(input: int, a: A) -> None:
Requires(input > 0)
Assert(a.foo() == 1)

Assert(power3_A(input) == power3_B(input))

#:: ExpectedOutput(assert.failed:assertion.false)
Assert(a.bar() == 1)

@Inline
def plus_two(i: int) -> int:
a = i + 2
return a

@Inline
def plus_seven(i: int) -> int:
a = plus_two(i)
b = plus_two(a)
return b + 3

def nested_caller() -> None:
Assert(plus_seven(9) == 16)
#:: ExpectedOutput(assert.failed:assertion.false)
Assert(plus_seven(5) == 13)

@Inline
def raises(a: int) -> None:
if a > 0:
raise Exception

@Inline
def raises_and_catches(b: int) -> int:
r = 9
try:
if b == 9:
raise Exception
r = 7
except:
r = 12
return r

def calls_raises_1() -> int:
Ensures(Result() == 8888)
#:: ExpectedOutput(postcondition.violated:assertion.false)
Ensures(False)
r = 43

try:
r = 12
raises(-3)
r = 456
raises(4)
r = 2
except:
Assert(r == 456)
r = 8888
return r

#:: ExpectedOutput(exhale.failed:assertion.false,L3)
def calls_raises_2() -> None:
#:: Label(L3)
raises(8) # error bc raises exception

def calls_raises_3(i: int) -> None:
Ensures(i <= 0)
Exsures(Exception, i > 0)
#:: ExpectedOutput(postcondition.violated:assertion.false)
Exsures(Exception, False)
raises(i)

def calls_raises_and_catches_1() -> None:
raises_and_catches(12)

raises_and_catches(9)

#:: ExpectedOutput(assert.failed:assertion.false)
Assert(False)

Loading

0 comments on commit 681ea6a

Please sign in to comment.