diff --git a/src/nagini_translation/analyzer.py b/src/nagini_translation/analyzer.py index 5dd313f3..bdef38e8 100644 --- a/src/nagini_translation/analyzer.py +++ b/src/nagini_translation/analyzer.py @@ -675,6 +675,11 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> None: self.current_class._has_classmethod = True func.predicate = self.is_predicate(node) if self.is_inline_method(node): + if name == '__init__': + raise UnsupportedException(node, 'Inlining constructors is currently not supported.') + decorators = {d.id for d in node.decorator_list if isinstance(d, ast.Name)} + if len(decorators) > 1: + raise UnsupportedException(node, 'Unsupported decorator for inline function.') func.inline = True if func.predicate: func.contract_only = self.is_declared_contract_only(node) @@ -911,6 +916,11 @@ def track_access(self, node: ast.AST, var: Union[PythonVar, PythonField]) -> Non elif isinstance(node.ctx, ast.Store): var.writes.append(node) + def _check_not_in_inline_method(self, node: ast.Call) -> None: + if isinstance(self.stmt_container, PythonMethod) and self.stmt_container.inline: + raise InvalidProgramException(node, 'contract.in.inline.method', + 'Inline methods must not have specifications.') + def visit_Call(self, node: ast.Call) -> None: """ Collects preconditions, postconditions, raised exceptions and @@ -918,6 +928,9 @@ def visit_Call(self, node: ast.Call) -> None: """ if (isinstance(node.func, ast.Name) and node.func.id in CONTRACT_WRAPPER_FUNCS): + if node.func.id != 'Invariant': + self._check_not_in_inline_method(node) + if node.func.id == 'Requires': self.stmt_container.precondition.append( (node.args[0], self._aliases.copy())) diff --git a/src/nagini_translation/lib/program_nodes.py b/src/nagini_translation/lib/program_nodes.py index a1a2d770..e9633784 100644 --- a/src/nagini_translation/lib/program_nodes.py +++ b/src/nagini_translation/lib/program_nodes.py @@ -1085,6 +1085,9 @@ def process(self, sil_name: str, translator: 'Translator') -> None: if self.overrides.inline: raise InvalidProgramException(self.node, 'overriding.inline.method', 'Functions marked to be inlined cannot be overridden.') + if self.inline: + raise InvalidProgramException(self.node, 'overriding.inline.method', + 'Functions marked to be inlined cannot override other methods.') except KeyError: pass for local in self.locals: @@ -1582,8 +1585,6 @@ def process(self, sil_name: str, translator: 'Translator') -> None: this Python variable. """ super().process(sil_name, translator) - if sil_name == "iterable_1": - print("++") self._translator = translator module = self.type.module self.decl = translator.translate_pythonvar_decl(self, module) diff --git a/src/nagini_translation/sif/translators/method.py b/src/nagini_translation/sif/translators/method.py index 8bfa1612..88a2f5df 100644 --- a/src/nagini_translation/sif/translators/method.py +++ b/src/nagini_translation/sif/translators/method.py @@ -29,7 +29,7 @@ def _method_body_postamble(self, method: PythonMethod, ctx: Context) -> List[Stm def _create_method_epilog(self, method: PythonMethod, ctx: Context) -> List[Stmt]: # With the extended AST we don't need a label at the end of the method. # Check that no undeclared exceptions are raised. (but not for main method) - if not method.declared_exceptions: + if not (method.declared_exceptions or method.inline): no_info = self.no_info(ctx) error_string = '"method raises no exceptions"' error_pos = self.to_position(method.node, ctx, error_string) diff --git a/src/nagini_translation/translators/call.py b/src/nagini_translation/translators/call.py index c165eb85..51715695 100644 --- a/src/nagini_translation/translators/call.py +++ b/src/nagini_translation/translators/call.py @@ -942,6 +942,11 @@ def inline_method(self, method: PythonMethod, args: List[PythonVar], start, end = get_body_indices(method.node.body) stmts = [] + if error_var: + pos = self.no_position(ctx) + info = self.no_info(ctx) + stmts.append(self.viper.LocalVarAssign(error_var.ref(), self.viper.NullLit(pos, info), pos, info)) + for stmt in method.node.body[start:end]: stmts += self.translate_stmt(stmt, ctx) @@ -989,7 +994,7 @@ def _inline_call(self, method: PythonMethod, node: ast.Call, is_super: bool, self.translator) optional_error_var = None error_var = self.get_error_var(node, ctx) - if method.declared_exceptions: + if method.declared_exceptions or method.inline: var = PythonVar(ERROR_NAME, None, ctx.module.global_module.classes['Exception']) var._ref = error_var @@ -1002,7 +1007,7 @@ def _inline_call(self, method: PythonMethod, node: ast.Call, is_super: bool, stmts += inline_stmts if end_lbl: stmts.append(end_lbl) - if method.declared_exceptions: + if method.declared_exceptions or method.inline: stmts += self.create_exception_catchers(error_var, ctx.actual_function.try_blocks, node, ctx) # Return result diff --git a/src/nagini_translation/translators/common.py b/src/nagini_translation/translators/common.py index 8a9057bf..fb9ce510 100644 --- a/src/nagini_translation/translators/common.py +++ b/src/nagini_translation/translators/common.py @@ -815,7 +815,7 @@ def get_error_var(self, stmt: ast.AST, if err_var.sil_name in ctx.var_aliases: err_var = ctx.var_aliases[err_var.sil_name] return err_var.ref() - if ctx.actual_function.declared_exceptions: + if ctx.actual_function.declared_exceptions or ctx.actual_function.inline: return ctx.error_var.ref() else: new_var = ctx.current_function.create_variable('error', diff --git a/src/nagini_translation/translators/expression.py b/src/nagini_translation/translators/expression.py index 8867a244..cc1a4f4c 100644 --- a/src/nagini_translation/translators/expression.py +++ b/src/nagini_translation/translators/expression.py @@ -523,7 +523,7 @@ def create_exception_catchers(self, var: PythonVar, else: end_label = ctx.get_label_name(END_LABEL) goto_end = self.viper.Goto(end_label, position, self.no_info(ctx)) - if ctx.actual_function.declared_exceptions: + if ctx.actual_function.declared_exceptions or ctx.actual_function.inline: assignerror = self.viper.LocalVarAssign(err_var, var, position, self.no_info(ctx)) uncaught_option = self.translate_block([assignerror, goto_end], diff --git a/src/nagini_translation/translators/method.py b/src/nagini_translation/translators/method.py index 4ad73490..87bec7a5 100644 --- a/src/nagini_translation/translators/method.py +++ b/src/nagini_translation/translators/method.py @@ -751,7 +751,7 @@ def translate_finally(self, block: PythonTryBlock, goto_continue = self.viper.Goto(loop.end_label, pos, info) break_block = [goto_break] continue_block = [goto_continue] - if ctx.actual_function.declared_exceptions: + if ctx.actual_function.declared_exceptions or ctx.actual_function.inline: # Assign error to error output var error_var = ctx.error_var.ref() block_error_var = block.get_error_var(self.translator) diff --git a/src/nagini_translation/translators/obligation/method.py b/src/nagini_translation/translators/obligation/method.py index 5cb2c8d5..e130498f 100644 --- a/src/nagini_translation/translators/obligation/method.py +++ b/src/nagini_translation/translators/obligation/method.py @@ -55,7 +55,7 @@ class MethodObligationTranslator(CommonObligationTranslator): """Class for translating obligations in methods.""" def _get_obligation_info(self, ctx: Context) -> BaseObligationInfo: - return ctx.current_function.obligation_info + return ctx.actual_function.obligation_info def _create_obligation_instance_use( self, obligation_instance: ObligationInstance, diff --git a/src/nagini_translation/translators/obligation/waitlevel.py b/src/nagini_translation/translators/obligation/waitlevel.py index ef99282f..196a5067 100644 --- a/src/nagini_translation/translators/obligation/waitlevel.py +++ b/src/nagini_translation/translators/obligation/waitlevel.py @@ -191,7 +191,8 @@ def _translate_waitlevel_method( info = self.no_info(ctx) obligation_info = ctx.current_function.obligation_info - guard = obligation_info.get_wait_level_guard(node.left) + actual_obligation_info = ctx.actual_function.obligation_info + guard = actual_obligation_info.get_wait_level_guard(node.left) exhale = self._create_level_below_inex( guard, expr, obligation_info.residue_level, ctx) translated_exhale = exhale.translate(self, ctx, position, info) diff --git a/src/nagini_translation/translators/statement.py b/src/nagini_translation/translators/statement.py index 4467ecc4..e58c7aba 100644 --- a/src/nagini_translation/translators/statement.py +++ b/src/nagini_translation/translators/statement.py @@ -1430,14 +1430,14 @@ def _set_result_none(self, ctx: Context) -> List[Stmt]: null = self.viper.NullLit(self.no_position(ctx), self.no_info(ctx)) if ctx.actual_function.type: result_none = self.viper.LocalVarAssign( - ctx.actual_function.result.ref(), + ctx.result_var.ref(), null, self.no_position(ctx), self.no_info(ctx)) result.append(result_none) # Do the same for the error variable - if ctx.actual_function.declared_exceptions: + if ctx.actual_function.declared_exceptions or ctx.actual_function.inline: error_none = self.viper.LocalVarAssign( - ctx.actual_function.error_var.ref(), + ctx.error_var.ref(), null, self.no_position(ctx), self.no_info(ctx)) result.append(error_none) return result diff --git a/tests/functional/translation/test_inline_1.py b/tests/functional/translation/test_inline_1.py new file mode 100644 index 00000000..738e35a1 --- /dev/null +++ b/tests/functional/translation/test_inline_1.py @@ -0,0 +1,10 @@ +# Any copyright is dedicated to the Public Domain. +# http://creativecommons.org/publicdomain/zero/1.0/ + +from nagini_contracts.contracts import * + + +@Predicate +@Inline +def test1() -> int: #:: ExpectedOutput(invalid.program:decorators.incompatible) + return 5 \ No newline at end of file diff --git a/tests/functional/translation/test_inline_2.py b/tests/functional/translation/test_inline_2.py new file mode 100644 index 00000000..076578f6 --- /dev/null +++ b/tests/functional/translation/test_inline_2.py @@ -0,0 +1,10 @@ +# Any copyright is dedicated to the Public Domain. +# http://creativecommons.org/publicdomain/zero/1.0/ + +from nagini_contracts.contracts import * + + +@Pure +@Inline +def test1() -> int: #:: ExpectedOutput(invalid.program:decorators.incompatible) + return 5 \ No newline at end of file diff --git a/tests/functional/translation/test_inline_3.py b/tests/functional/translation/test_inline_3.py new file mode 100644 index 00000000..8df02768 --- /dev/null +++ b/tests/functional/translation/test_inline_3.py @@ -0,0 +1,16 @@ +# Any copyright is dedicated to the Public Domain. +# http://creativecommons.org/publicdomain/zero/1.0/ + +from nagini_contracts.contracts import * + + +class A: + + def test1(self) -> int: + return 5 + +class B(A): + + @Inline + def test1(self) -> int: #:: ExpectedOutput(invalid.program:overriding.inline.method) + return 6 \ No newline at end of file diff --git a/tests/functional/translation/test_inline_4.py b/tests/functional/translation/test_inline_4.py new file mode 100644 index 00000000..242011bc --- /dev/null +++ b/tests/functional/translation/test_inline_4.py @@ -0,0 +1,16 @@ +# Any copyright is dedicated to the Public Domain. +# http://creativecommons.org/publicdomain/zero/1.0/ + +from nagini_contracts.contracts import * + + +class A: + + @Inline + def test1(self) -> int: + return 5 + +class B(A): + + def test1(self) -> int: #:: ExpectedOutput(invalid.program:overriding.inline.method) + return 6 \ No newline at end of file diff --git a/tests/functional/translation/test_inline_5.py b/tests/functional/translation/test_inline_5.py new file mode 100644 index 00000000..c3ee2af6 --- /dev/null +++ b/tests/functional/translation/test_inline_5.py @@ -0,0 +1,10 @@ +# Any copyright is dedicated to the Public Domain. +# http://creativecommons.org/publicdomain/zero/1.0/ + +from nagini_contracts.contracts import * + + +@Inline +def test1() -> int: + Ensures(Result() > 0) #:: ExpectedOutput(invalid.program:contract.in.inline.method) + return 6 \ No newline at end of file diff --git a/tests/functional/translation/test_inline_6.py b/tests/functional/translation/test_inline_6.py new file mode 100644 index 00000000..d01981ff --- /dev/null +++ b/tests/functional/translation/test_inline_6.py @@ -0,0 +1,10 @@ +# Any copyright is dedicated to the Public Domain. +# http://creativecommons.org/publicdomain/zero/1.0/ + +from nagini_contracts.contracts import * + + +@Inline +def test1(i: int) -> int: + Requires(i > 0) #:: ExpectedOutput(invalid.program:contract.in.inline.method) + return 6 \ No newline at end of file diff --git a/tests/functional/verification/test_inline.py b/tests/functional/verification/test_inline.py new file mode 100644 index 00000000..e3ddab8c --- /dev/null +++ b/tests/functional/verification/test_inline.py @@ -0,0 +1,158 @@ +# Any copyright is dedicated to the Public Domain. +# http://creativecommons.org/publicdomain/zero/1.0/ + +from nagini_contracts.obligations import MustTerminate +from nagini_contracts.contracts import * + +@Inline +def power3_A(input: int) -> int: + out_a = input + try: + + for i in range(2): + Invariant(out_a == input * (input ** len(Previous(i)))) + try: + out_a = out_a * input + out_a -= 5 + finally: + out_a = out_a + 5 + finally: + out_a = out_a - 5 + return out_a + 5 + + +@Inline +def power3_B(input: int) -> int: + i = 0 + out_b = input + while i < 2: + Invariant(0 <= i and i <= 2) + Invariant(out_b == input * (input ** i)) + out_b = out_b * input + i += 1 + return out_b + +@Inline +def not_always_correct(i: int, j: int) -> int: + if j < 0: + assert False # would lead to an error but is not reached by any call + + if i == 0: + res = 8 + #:: ExpectedOutput(assert.failed:assertion.false,L1) + assert False # would lead to an error + elif i > 0: + res = 14 + else: + res = 9 + return res + +def partly_correct_caller() -> None: + tst = not_always_correct(5, 8) + Assert(tst == 14) + #:: Label(L1) + tst = not_always_correct(0, 8) + +@Inline +def may_not_terminate(i: int) -> None: + will_terminate = i >= 0 + #:: ExpectedOutput(leak_check.failed:loop_context.has_unsatisfied_obligations,L2) + while i != 0: + Invariant(Implies(will_terminate, MustTerminate(i))) + i -= 1 + +def should_terminate() -> None: + Requires(MustTerminate(5)) + + may_not_terminate(8) + + #:: Label(L2) + may_not_terminate(-6) + + +class A: + @Inline + def foo(self) -> int: + return 1 + + def bar(self) -> int: + return 1 + +def test_calls(input: int, a: A) -> None: + Requires(input > 0) + Assert(a.foo() == 1) + + Assert(power3_A(input) == power3_B(input)) + + #:: ExpectedOutput(assert.failed:assertion.false) + Assert(a.bar() == 1) + +@Inline +def plus_two(i: int) -> int: + a = i + 2 + return a + +@Inline +def plus_seven(i: int) -> int: + a = plus_two(i) + b = plus_two(a) + return b + 3 + +def nested_caller() -> None: + Assert(plus_seven(9) == 16) + #:: ExpectedOutput(assert.failed:assertion.false) + Assert(plus_seven(5) == 13) + +@Inline +def raises(a: int) -> None: + if a > 0: + raise Exception + +@Inline +def raises_and_catches(b: int) -> int: + r = 9 + try: + if b == 9: + raise Exception + r = 7 + except: + r = 12 + return r + +def calls_raises_1() -> int: + Ensures(Result() == 8888) + #:: ExpectedOutput(postcondition.violated:assertion.false) + Ensures(False) + r = 43 + + try: + r = 12 + raises(-3) + r = 456 + raises(4) + r = 2 + except: + Assert(r == 456) + r = 8888 + return r + +#:: ExpectedOutput(exhale.failed:assertion.false,L3) +def calls_raises_2() -> None: + #:: Label(L3) + raises(8) # error bc raises exception + +def calls_raises_3(i: int) -> None: + Ensures(i <= 0) + Exsures(Exception, i > 0) + #:: ExpectedOutput(postcondition.violated:assertion.false) + Exsures(Exception, False) + raises(i) + +def calls_raises_and_catches_1() -> None: + raises_and_catches(12) + + raises_and_catches(9) + + #:: ExpectedOutput(assert.failed:assertion.false) + Assert(False) + diff --git a/tests/functional/verification/test_super.py b/tests/functional/verification/test_super.py index 094f04b7..419a4349 100644 --- a/tests/functional/verification/test_super.py +++ b/tests/functional/verification/test_super.py @@ -7,6 +7,14 @@ class Super: def some_method(self) -> int: Ensures(Result() >= 14) + + # some code that is just here to make sure loops can be properly inlined + input = 2 + out_a = input + for i in range(2): + Invariant(out_a == input * (input ** len(Previous(i)))) + out_a = out_a * input + return 14