diff --git a/src/nagini_contracts/contracts.py b/src/nagini_contracts/contracts.py index ac27a13d..68aadb8b 100644 --- a/src/nagini_contracts/contracts.py +++ b/src/nagini_contracts/contracts.py @@ -441,6 +441,13 @@ def Predicate(func: T) -> T: return func +def Inline(func: T) -> T: + """ + Decorator to mark functions that should be inlined during verification, and not treated modularly. It's a no-op. + """ + return func + + def Ghost(func: T) -> T: """ Decorator for ghost functions. It's a no-op. @@ -555,6 +562,7 @@ def isNaN(f: float) -> bool: 'Unfolding', 'Pure', 'Predicate', + 'Inline', 'Ghost', 'ContractOnly', 'GhostReturns', diff --git a/src/nagini_translation/analyzer.py b/src/nagini_translation/analyzer.py index 23a822c6..bdef38e8 100644 --- a/src/nagini_translation/analyzer.py +++ b/src/nagini_translation/analyzer.py @@ -674,6 +674,13 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> None: func.method_type = MethodType.class_method self.current_class._has_classmethod = True func.predicate = self.is_predicate(node) + if self.is_inline_method(node): + if name == '__init__': + raise UnsupportedException(node, 'Inlining constructors is currently not supported.') + decorators = {d.id for d in node.decorator_list if isinstance(d, ast.Name)} + if len(decorators) > 1: + raise UnsupportedException(node, 'Unsupported decorator for inline function.') + func.inline = True if func.predicate: func.contract_only = self.is_declared_contract_only(node) if self.is_all_low(node): @@ -909,6 +916,11 @@ def track_access(self, node: ast.AST, var: Union[PythonVar, PythonField]) -> Non elif isinstance(node.ctx, ast.Store): var.writes.append(node) + def _check_not_in_inline_method(self, node: ast.Call) -> None: + if isinstance(self.stmt_container, PythonMethod) and self.stmt_container.inline: + raise InvalidProgramException(node, 'contract.in.inline.method', + 'Inline methods must not have specifications.') + def visit_Call(self, node: ast.Call) -> None: """ Collects preconditions, postconditions, raised exceptions and @@ -916,6 +928,9 @@ def visit_Call(self, node: ast.Call) -> None: """ if (isinstance(node.func, ast.Name) and node.func.id in CONTRACT_WRAPPER_FUNCS): + if node.func.id != 'Invariant': + self._check_not_in_inline_method(node) + if node.func.id == 'Requires': self.stmt_container.precondition.append( (node.args[0], self._aliases.copy())) @@ -1469,6 +1484,8 @@ def visit_Try(self, node: ast.Try) -> None: def _incompatible_decorators(self, decorators) -> bool: return ((('Predicate' in decorators) and ('Pure' in decorators)) or + (('Predicate' in decorators) and ('Inline' in decorators)) or + (('Inline' in decorators) and ('Pure' in decorators)) or (('IOOperation' in decorators) and (len(decorators) != 1)) or (('property' in decorators) and (len(decorators) != 1)) or (('AllLow' in decorators) and ('PreservesLow' in decorators)) or @@ -1521,6 +1538,9 @@ def is_pure(self, func: ast.FunctionDef) -> bool: def is_predicate(self, func: ast.FunctionDef) -> bool: return self.has_decorator(func, 'Predicate') + def is_inline_method(self, func: ast.FunctionDef) -> bool: + return self.has_decorator(func, 'Inline') + def is_static_method(self, func: ast.FunctionDef) -> bool: return self.has_decorator(func, 'staticmethod') diff --git a/src/nagini_translation/lib/context.py b/src/nagini_translation/lib/context.py index ecdd99d6..21f859e4 100644 --- a/src/nagini_translation/lib/context.py +++ b/src/nagini_translation/lib/context.py @@ -172,7 +172,7 @@ def set_alias(self, name: str, var: PythonVar, self.old_aliases[name] = [] self.old_aliases[name].append(self.var_aliases[name]) if replaces: - if replaces.alt_types: + if hasattr(replaces, 'alt_types') and replaces.alt_types: assert not var.alt_types var.alt_types = replaces.alt_types self.var_aliases[name] = var diff --git a/src/nagini_translation/lib/program_nodes.py b/src/nagini_translation/lib/program_nodes.py index c9eec4c2..e9633784 100644 --- a/src/nagini_translation/lib/program_nodes.py +++ b/src/nagini_translation/lib/program_nodes.py @@ -1018,6 +1018,7 @@ def __init__(self, name: str, node: ast.AST, cls: PythonClass, self.declared_exceptions = OrderedDict() # direct self.pure = pure self.predicate = False + self.inline = False self.all_low = False self.preserves_low = False self.contract_only = contract_only @@ -1080,6 +1081,13 @@ def process(self, sil_name: str, translator: 'Translator') -> None: # Could be overridden by anything, so we have to check if there's # anything with the same name. self.overrides = self.cls.superclass.get_contents(False)[self.name] + if self.overrides is not None: + if self.overrides.inline: + raise InvalidProgramException(self.node, 'overriding.inline.method', + 'Functions marked to be inlined cannot be overridden.') + if self.inline: + raise InvalidProgramException(self.node, 'overriding.inline.method', + 'Functions marked to be inlined cannot override other methods.') except KeyError: pass for local in self.locals: diff --git a/src/nagini_translation/sif/translators/method.py b/src/nagini_translation/sif/translators/method.py index 8bfa1612..88a2f5df 100644 --- a/src/nagini_translation/sif/translators/method.py +++ b/src/nagini_translation/sif/translators/method.py @@ -29,7 +29,7 @@ def _method_body_postamble(self, method: PythonMethod, ctx: Context) -> List[Stm def _create_method_epilog(self, method: PythonMethod, ctx: Context) -> List[Stmt]: # With the extended AST we don't need a label at the end of the method. # Check that no undeclared exceptions are raised. (but not for main method) - if not method.declared_exceptions: + if not (method.declared_exceptions or method.inline): no_info = self.no_info(ctx) error_string = '"method raises no exceptions"' error_pos = self.to_position(method.node, ctx, error_string) diff --git a/src/nagini_translation/sif/translators/statement.py b/src/nagini_translation/sif/translators/statement.py index 3e618539..3a1691c7 100644 --- a/src/nagini_translation/sif/translators/statement.py +++ b/src/nagini_translation/sif/translators/statement.py @@ -170,7 +170,7 @@ def _create_With_exception_handler(self, try_block: PythonTryBlock, err_type_arg = self.type_factory.typeof(error_var, ctx) tb_class = ctx.module.global_module.classes['traceback'] - traceback_var = ctx.actual_function.create_variable('tb', tb_class, + traceback_var = ctx.current_function.create_variable('tb', tb_class, self.translator) tb_type = self.type_check(traceback_var.ref(), tb_class, no_pos, ctx) inhale_types = self.viper.Inhale(tb_type, no_pos, no_info) diff --git a/src/nagini_translation/translators/call.py b/src/nagini_translation/translators/call.py index c6268bc0..51715695 100644 --- a/src/nagini_translation/translators/call.py +++ b/src/nagini_translation/translators/call.py @@ -347,7 +347,7 @@ def _translate_list(self, node: ast.Call, ctx: Context) -> StmtsAndExpr: if contents: sil_ref_seq = self.viper.SeqType(self.viper.Ref) ref_seq = SilverType(sil_ref_seq, ctx.module) - havoc_var = ctx.actual_function.create_variable('havoc_seq', ref_seq, + havoc_var = ctx.current_function.create_variable('havoc_seq', ref_seq, self.translator) seq_field = self.viper.Field('list_acc', sil_ref_seq, position, info) content_field = self.viper.FieldAccess(result_var, seq_field, position, info) @@ -391,7 +391,7 @@ def _translate_set(self, node: ast.Call, ctx: Context) -> StmtsAndExpr: if contents: sil_ref_set = self.viper.SetType(self.viper.Ref) ref_set = SilverType(sil_ref_set, ctx.module) - havoc_var = ctx.actual_function.create_variable('havoc_set', ref_set, + havoc_var = ctx.current_function.create_variable('havoc_set', ref_set, self.translator) set_field = self.viper.Field('set_acc', sil_ref_set, position, info) content_field = self.viper.FieldAccess(result_var, set_field, position, info) @@ -455,7 +455,7 @@ def _translate_enumerate(self, node: ast.Call, ctx: Context) -> StmtsAndExpr: arg_type = self.get_type(node.args[0], ctx) arg_stmt, arg = self.translate_expr(node.args[0], ctx) arg_contents = self.get_sequence(arg_type, arg, None, node.args[0], ctx) - new_list = ctx.actual_function.create_variable('enumerate_res', result_type, + new_list = ctx.current_function.create_variable('enumerate_res', result_type, self.translator) sil_ref_seq = self.viper.SeqType(self.viper.Ref) seq_field = self.viper.Field('list_acc', sil_ref_seq, pos, info) @@ -470,7 +470,7 @@ def _translate_enumerate(self, node: ast.Call, ctx: Context) -> StmtsAndExpr: type_inhale = self.viper.Inhale(self.viper.And(list_type_info, list_len_info, pos, info), pos, info) - i_var = ctx.actual_function.create_variable('i', prim_int_type, self.translator, + i_var = ctx.current_function.create_variable('i', prim_int_type, self.translator, False) orig_seq_i = self.viper.SeqIndex(arg_contents, i_var.ref(), pos, info) content_type = result_type.type_args[0].type_args[1] @@ -921,9 +921,12 @@ def inline_method(self, method: PythonMethod, args: List[PythonVar], # Create local var aliases locals_to_copy = method.locals.copy() for local_name, local in locals_to_copy.items(): + if type(local).__name__ == 'SilverVar': + continue local_var = ctx.current_function.create_variable(local_name, local.type, self.translator) + local_var.writes = local.writes ctx.set_alias(local_name, local_var, local) # Create label aliases @@ -939,6 +942,11 @@ def inline_method(self, method: PythonMethod, args: List[PythonVar], start, end = get_body_indices(method.node.body) stmts = [] + if error_var: + pos = self.no_position(ctx) + info = self.no_info(ctx) + stmts.append(self.viper.LocalVarAssign(error_var.ref(), self.viper.NullLit(pos, info), pos, info)) + for stmt in method.node.body[start:end]: stmts += self.translate_stmt(stmt, ctx) @@ -986,7 +994,7 @@ def _inline_call(self, method: PythonMethod, node: ast.Call, is_super: bool, self.translator) optional_error_var = None error_var = self.get_error_var(node, ctx) - if method.declared_exceptions: + if method.declared_exceptions or method.inline: var = PythonVar(ERROR_NAME, None, ctx.module.global_module.classes['Exception']) var._ref = error_var @@ -999,7 +1007,7 @@ def _inline_call(self, method: PythonMethod, node: ast.Call, is_super: bool, stmts += inline_stmts if end_lbl: stmts.append(end_lbl) - if method.declared_exceptions: + if method.declared_exceptions or method.inline: stmts += self.create_exception_catchers(error_var, ctx.actual_function.try_blocks, node, ctx) # Return result @@ -1201,6 +1209,9 @@ def translate_normal_call(self, target: PythonMethod, arg_stmts: List[Stmt], return self._translate_function_call(target, args, formal_args, arg_stmts, position, node, ctx) else: + if target.inline: + return self._inline_call(target, node, False, 'inlined call', + ctx) return self._translate_method_call(target, args, arg_stmts, position, node, ctx) @@ -1355,7 +1366,7 @@ def _translate_thread_creation(self, node: ast.Call, # Create thread object thread_class = ctx.module.global_module.classes['Thread'] - thread_var = ctx.actual_function.create_variable('threadingVar', thread_class, + thread_var = ctx.current_function.create_variable('threadingVar', thread_class, self.translator) thread = thread_var.ref(node, ctx) newstmt = self.viper.NewStmt(thread, [], pos, info) @@ -1501,7 +1512,7 @@ def _translate_thread_join(self, node: ast.Call, ctx: Context) -> StmtsAndExpr: ctx.perm_factor = post_perm object_class = ctx.module.global_module.classes[OBJECT_TYPE] - res_var = ctx.actual_function.create_variable('join_result', object_class, + res_var = ctx.current_function.create_variable('join_result', object_class, self.translator) # Resolve list of possible thread target methods. @@ -1544,7 +1555,7 @@ def _inhale_possible_thread_post(self, method: PythonMethod, thread: Expr, # Set arg aliases with types for index, arg in enumerate(method._args.values()): - arg_var = ctx.actual_function.create_variable('thread_arg', arg.type, + arg_var = ctx.current_function.create_variable('thread_arg', arg.type, self.translator) ctx.set_alias(arg.name, arg_var) id = self.viper.IntLit(index, pos, info) diff --git a/src/nagini_translation/translators/common.py b/src/nagini_translation/translators/common.py index 5fe442b8..fb9ce510 100644 --- a/src/nagini_translation/translators/common.py +++ b/src/nagini_translation/translators/common.py @@ -544,7 +544,7 @@ def get_func_or_method_call(self, receiver: PythonType, func_name: str, method = receiver.get_method(func_name) if method: assert method.type - target_var = ctx.actual_function.create_variable('target', method.type, + target_var = ctx.current_function.create_variable('target', method.type, self.translator) val = target_var.ref(node, ctx) call = self.get_method_call(receiver, func_name, args, arg_types, [val], node, @@ -815,7 +815,7 @@ def get_error_var(self, stmt: ast.AST, if err_var.sil_name in ctx.var_aliases: err_var = ctx.var_aliases[err_var.sil_name] return err_var.ref() - if ctx.actual_function.declared_exceptions: + if ctx.actual_function.declared_exceptions or ctx.actual_function.inline: return ctx.error_var.ref() else: new_var = ctx.current_function.create_variable('error', diff --git a/src/nagini_translation/translators/expression.py b/src/nagini_translation/translators/expression.py index c9214860..cc1a4f4c 100644 --- a/src/nagini_translation/translators/expression.py +++ b/src/nagini_translation/translators/expression.py @@ -177,7 +177,7 @@ def translate_ListComp(self, node: ast.ListComp, ctx: Context) -> StmtsAndExpr: if body_stmt: raise InvalidProgramException(node, 'impure.list.comprehension.body') ctx.remove_alias(target.id) - result_var = ctx.actual_function.create_variable('listcomp', list_type, + result_var = ctx.current_function.create_variable('listcomp', list_type, self.translator) stmt = self._create_list_comp_inhale(result_var, list_type, element_var, body, node, ctx) @@ -213,7 +213,7 @@ def _create_list_comp_inhale(self, result_var: PythonVar, list_type: PythonType, len_equal = self.viper.EqCmp(self.to_int(result_len, ctx), seq_len, position, info) int_class = ctx.module.global_module.classes[PRIMITIVE_INT_TYPE] - index_var = ctx.actual_function.create_variable('i', int_class, self.translator, + index_var = ctx.current_function.create_variable('i', int_class, self.translator, False) index_positive = self.viper.GeCmp(index_var.ref(), self.viper.IntLit(0, position, info), @@ -491,7 +491,7 @@ def _translate_slice_subscript(self, node: ast.Subscript, target: Expr, stmt = target_stmt + start_stmt + stop_stmt getitem = target_type.get_func_or_method('__getitem_slice__') if not getitem.pure: - result_var = ctx.actual_function.create_variable( + result_var = ctx.current_function.create_variable( 'slice_res', target_type, self.translator) call = self.get_method_call(target_type, '__getitem_slice__', args, [None, None], @@ -523,7 +523,7 @@ def create_exception_catchers(self, var: PythonVar, else: end_label = ctx.get_label_name(END_LABEL) goto_end = self.viper.Goto(end_label, position, self.no_info(ctx)) - if ctx.actual_function.declared_exceptions: + if ctx.actual_function.declared_exceptions or ctx.actual_function.inline: assignerror = self.viper.LocalVarAssign(err_var, var, position, self.no_info(ctx)) uncaught_option = self.translate_block([assignerror, goto_end], @@ -665,7 +665,7 @@ def translate_Name(self, node: ast.Name, ctx: Context) -> StmtsAndExpr: self.to_position(node, ctx), ctx) if node.id == '_': object_type = ctx.module.global_module.classes[OBJECT_TYPE] - temp_var = ctx.actual_function.create_variable('wildcard', object_type, self.translator) + temp_var = ctx.current_function.create_variable('wildcard', object_type, self.translator) return [], temp_var.ref(node, ctx) if node.id in ctx.var_aliases: var = ctx.var_aliases[node.id] diff --git a/src/nagini_translation/translators/method.py b/src/nagini_translation/translators/method.py index 4ad73490..87bec7a5 100644 --- a/src/nagini_translation/translators/method.py +++ b/src/nagini_translation/translators/method.py @@ -751,7 +751,7 @@ def translate_finally(self, block: PythonTryBlock, goto_continue = self.viper.Goto(loop.end_label, pos, info) break_block = [goto_break] continue_block = [goto_continue] - if ctx.actual_function.declared_exceptions: + if ctx.actual_function.declared_exceptions or ctx.actual_function.inline: # Assign error to error output var error_var = ctx.error_var.ref() block_error_var = block.get_error_var(self.translator) diff --git a/src/nagini_translation/translators/obligation/fork.py b/src/nagini_translation/translators/obligation/fork.py index 25d8a793..1ab46829 100644 --- a/src/nagini_translation/translators/obligation/fork.py +++ b/src/nagini_translation/translators/obligation/fork.py @@ -120,7 +120,7 @@ def _set_parameter_aliases(self, method: PythonMethod) -> List[Stmt]: arg_vars = [] stmts = [] for index, arg in enumerate(method._args.values()): - arg_var = self._ctx.actual_function.create_variable( + arg_var = self._ctx.current_function.create_variable( 'thread_arg', arg.type, self._translator.translator) arg_vars.append(arg_var) index_lit = self.viper.IntLit(index, self._position, self._info) @@ -219,7 +219,7 @@ def _add_waitlevel(self) -> None: def _create_level_below( self, expr: sil.PermExpression, ctx: Context) -> sil.BoolExpression: - residue_level_var = sil.PermVar(ctx.actual_function.obligation_info.residue_level) + residue_level_var = sil.PermVar(ctx.current_function.obligation_info.residue_level) obligation = self._obligation_manager.must_release_obligation fields = obligation.create_fields_untranslated() var = ctx.current_function.create_variable( diff --git a/src/nagini_translation/translators/obligation/inexhale.py b/src/nagini_translation/translators/obligation/inexhale.py index f767ab9d..3ad7e30f 100644 --- a/src/nagini_translation/translators/obligation/inexhale.py +++ b/src/nagini_translation/translators/obligation/inexhale.py @@ -179,7 +179,7 @@ def get_use_method( self, ctx: Context) -> List[Tuple[sil.Expression, Rules]]: """Default implementation for obligation use in method contract.""" inexhale = self._get_inexhale(ctx) - obligation_info = ctx.actual_function.obligation_info + obligation_info = ctx.current_function.obligation_info if self.is_fresh(): return [(inexhale.construct_use_method_unbounded(), None)] else: diff --git a/src/nagini_translation/translators/obligation/interface.py b/src/nagini_translation/translators/obligation/interface.py index b3b97fff..de488e22 100644 --- a/src/nagini_translation/translators/obligation/interface.py +++ b/src/nagini_translation/translators/obligation/interface.py @@ -129,7 +129,7 @@ def translate_must_invoke_ctoken( return self._loop_translator.translate_may_invoke( node, ctx) elif ctx.obligation_context.is_translating_posts: - if ctx.actual_function.name != 'Gap': + if ctx.current_function.name != 'Gap': raise InvalidProgramException( node, 'invalid.postcondition.ctoken_not_allowed') else: diff --git a/src/nagini_translation/translators/obligation/loop.py b/src/nagini_translation/translators/obligation/loop.py index b70541f6..8d6ca6a1 100644 --- a/src/nagini_translation/translators/obligation/loop.py +++ b/src/nagini_translation/translators/obligation/loop.py @@ -46,9 +46,9 @@ def enter_loop_translation( err_var: PythonVar = None) -> None: """Update context with info needed to translate loop.""" info = PythonLoopObligationInfo( - self._obligation_manager, node, self, ctx.actual_function, + self._obligation_manager, node, self, ctx.current_function, err_var) - info.traverse_invariants() + info.traverse_invariants(ctx.actual_function) ctx.obligation_context.push_loop_info(info) def leave_loop_translation(self, ctx: Context) -> None: diff --git a/src/nagini_translation/translators/obligation/loop_node.py b/src/nagini_translation/translators/obligation/loop_node.py index cdfbe288..f5290458 100644 --- a/src/nagini_translation/translators/obligation/loop_node.py +++ b/src/nagini_translation/translators/obligation/loop_node.py @@ -126,7 +126,7 @@ def _add_additional_invariants(self) -> None: def _add_leak_check(self) -> None: """Add leak checks to invariant.""" - reference_name = self._ctx.actual_function.get_fresh_name('_r') + reference_name = self._ctx.current_function.get_fresh_name('_r') leak_check = self._obligation_manager.create_leak_check(reference_name) loop_check_before = sil.BoolVar( self._loop_obligation_info.loop_check_before_var) diff --git a/src/nagini_translation/translators/obligation/obligation_info.py b/src/nagini_translation/translators/obligation/obligation_info.py index 69d0d6ce..d7f28547 100644 --- a/src/nagini_translation/translators/obligation/obligation_info.py +++ b/src/nagini_translation/translators/obligation/obligation_info.py @@ -428,11 +428,11 @@ def current_thread_var(self) -> PythonVar: """Return the variable that denotes current thread in method.""" return self._method.obligation_info.current_thread_var - def traverse_invariants(self) -> None: + def traverse_invariants(self, actual_method: PythonMethod) -> None: """Collect all needed information about obligations.""" assert self._current_instance_map is None self._current_instance_map = self._instances - for invariant, _ in self._method.loop_invariants[self.node]: + for invariant, _ in actual_method.loop_invariants[self.node]: if isinstance(invariant, ast.Expr): self.traverse(invariant.value.args[0]) else: diff --git a/src/nagini_translation/translators/obligation/waitlevel.py b/src/nagini_translation/translators/obligation/waitlevel.py index 5c98deb7..196a5067 100644 --- a/src/nagini_translation/translators/obligation/waitlevel.py +++ b/src/nagini_translation/translators/obligation/waitlevel.py @@ -190,8 +190,9 @@ def _translate_waitlevel_method( position = self.to_position(node, ctx) info = self.no_info(ctx) - obligation_info = ctx.actual_function.obligation_info - guard = obligation_info.get_wait_level_guard(node.left) + obligation_info = ctx.current_function.obligation_info + actual_obligation_info = ctx.actual_function.obligation_info + guard = actual_obligation_info.get_wait_level_guard(node.left) exhale = self._create_level_below_inex( guard, expr, obligation_info.residue_level, ctx) translated_exhale = exhale.translate(self, ctx, position, info) diff --git a/src/nagini_translation/translators/program.py b/src/nagini_translation/translators/program.py index 1834cabf..e02dfbae 100644 --- a/src/nagini_translation/translators/program.py +++ b/src/nagini_translation/translators/program.py @@ -257,7 +257,7 @@ def create_inherit_check(self, method: PythonMethod, cls: PythonClass, pres = [not_null, new_type] + pres stmts, end_lbl = self.inline_method(method, args, method.result, - error_var, ctx) + optional_error_var, ctx) self._create_inherit_check_postamble(stmts, end_lbl, ctx) @@ -1314,7 +1314,7 @@ def translate_program(self, modules: List[PythonModule], sil_progs: Program, for method in module.methods.values(): id_constant = self.translate_method_id_to_constant(method, ctx) threading_ids_constants.append(id_constant) - if method.interface: + if method.interface or method.inline: continue self.track_dependencies(selected_names, selected, method, ctx) methods.append(self.translate_method(method, ctx)) @@ -1351,7 +1351,7 @@ def translate_program(self, modules: List[PythonModule], sil_progs: Program, method = cls.methods[method_name] threading_ids_constants.append( self.translate_method_id_to_constant(method, ctx)) - if method.interface: + if method.interface or method.inline: continue self.track_dependencies(selected_names, selected, method, ctx) methods.append(self.translate_method(method, ctx)) diff --git a/src/nagini_translation/translators/statement.py b/src/nagini_translation/translators/statement.py index f629d7be..e58c7aba 100644 --- a/src/nagini_translation/translators/statement.py +++ b/src/nagini_translation/translators/statement.py @@ -601,7 +601,7 @@ def _get_iterator(self, iterable: Expr, iterable_type: PythonType, node: ast.AST, ctx: Context) -> Tuple[PythonVar, List[Stmt]]: iter_class = ctx.module.global_module.classes['Iterator'] - iter_var = ctx.actual_function.create_variable('iter', iter_class, + iter_var = ctx.current_function.create_variable('iter', iter_class, self.translator) assert not node in ctx.loop_iterators ctx.loop_iterators[node] = iter_var @@ -616,7 +616,7 @@ def _get_next_call(self, iter_var: PythonVar, target_var: PythonVar, node: ast.For, ctx: Context) -> Tuple[PythonVar, List[Stmt]]: exc_class = ctx.module.global_module.classes['Exception'] - err_var = ctx.actual_function.create_variable('iter_err', exc_class, + err_var = ctx.current_function.create_variable('iter_err', exc_class, self.translator) iter_class = ctx.module.global_module.classes['Iterator'] args = [iter_var.ref()] @@ -701,7 +701,7 @@ def translate_stmt_For(self, node: ast.For, ctx: Context) -> List[Stmt]: node.end_label = end_label iterable_type = self.get_type(node.iter, ctx) iterable_stmt, iterable = self.translate_expr(node.iter, ctx) - iterable_var = ctx.actual_function.create_variable('iterable', iterable_type, + iterable_var = ctx.current_function.create_variable('iterable', iterable_type, self.translator, True) iterable_assign = self.viper.LocalVarAssign(iterable_var.ref(), iterable, position, info) @@ -718,7 +718,7 @@ def translate_stmt_For(self, node: ast.For, ctx: Context) -> List[Stmt]: raise UnsupportedException(node, 'unknown.iterable') # Create artificial new variable to store current iteration content. - target_var = ctx.actual_function.create_variable('loop_target', + target_var = ctx.current_function.create_variable('loop_target', target_type, self.translator) @@ -1430,14 +1430,14 @@ def _set_result_none(self, ctx: Context) -> List[Stmt]: null = self.viper.NullLit(self.no_position(ctx), self.no_info(ctx)) if ctx.actual_function.type: result_none = self.viper.LocalVarAssign( - ctx.actual_function.result.ref(), + ctx.result_var.ref(), null, self.no_position(ctx), self.no_info(ctx)) result.append(result_none) # Do the same for the error variable - if ctx.actual_function.declared_exceptions: + if ctx.actual_function.declared_exceptions or ctx.actual_function.inline: error_none = self.viper.LocalVarAssign( - ctx.actual_function.error_var.ref(), + ctx.error_var.ref(), null, self.no_position(ctx), self.no_info(ctx)) result.append(error_none) return result diff --git a/tests/functional/translation/test_inline_1.py b/tests/functional/translation/test_inline_1.py new file mode 100644 index 00000000..d6ec41f2 --- /dev/null +++ b/tests/functional/translation/test_inline_1.py @@ -0,0 +1,10 @@ +# Any copyright is dedicated to the Public Domain. +# http://creativecommons.org/publicdomain/zero/1.0/ + +from nagini_contracts.contracts import * + + +@Predicate +@Inline #:: ExpectedOutput(invalid.program:decorators.incompatible) +def test1() -> int: + return 5 \ No newline at end of file diff --git a/tests/functional/translation/test_inline_2.py b/tests/functional/translation/test_inline_2.py new file mode 100644 index 00000000..683bc3d9 --- /dev/null +++ b/tests/functional/translation/test_inline_2.py @@ -0,0 +1,10 @@ +# Any copyright is dedicated to the Public Domain. +# http://creativecommons.org/publicdomain/zero/1.0/ + +from nagini_contracts.contracts import * + + +@Pure +@Inline #:: ExpectedOutput(invalid.program:decorators.incompatible) +def test1() -> int: + return 5 \ No newline at end of file diff --git a/tests/functional/translation/test_inline_3.py b/tests/functional/translation/test_inline_3.py new file mode 100644 index 00000000..36681e28 --- /dev/null +++ b/tests/functional/translation/test_inline_3.py @@ -0,0 +1,16 @@ +# Any copyright is dedicated to the Public Domain. +# http://creativecommons.org/publicdomain/zero/1.0/ + +from nagini_contracts.contracts import * + + +class A: + + def test1(self) -> int: + return 5 + +class B(A): + + @Inline #:: ExpectedOutput(invalid.program:overriding.inline.method) + def test1(self) -> int: + return 6 \ No newline at end of file diff --git a/tests/functional/translation/test_inline_4.py b/tests/functional/translation/test_inline_4.py new file mode 100644 index 00000000..e53d89c5 --- /dev/null +++ b/tests/functional/translation/test_inline_4.py @@ -0,0 +1,17 @@ +# Any copyright is dedicated to the Public Domain. +# http://creativecommons.org/publicdomain/zero/1.0/ + +from nagini_contracts.contracts import * + + +class A: + + @Inline + def test1(self) -> int: + return 5 + +class B(A): + + #:: ExpectedOutput(invalid.program:overriding.inline.method) + def test1(self) -> int: + return 6 \ No newline at end of file diff --git a/tests/functional/translation/test_inline_5.py b/tests/functional/translation/test_inline_5.py new file mode 100644 index 00000000..fe999649 --- /dev/null +++ b/tests/functional/translation/test_inline_5.py @@ -0,0 +1,11 @@ +# Any copyright is dedicated to the Public Domain. +# http://creativecommons.org/publicdomain/zero/1.0/ + +from nagini_contracts.contracts import * + + +@Inline +def test1() -> int: + #:: ExpectedOutput(invalid.program:contract.in.inline.method) + Ensures(Result() > 0) + return 6 \ No newline at end of file diff --git a/tests/functional/translation/test_inline_6.py b/tests/functional/translation/test_inline_6.py new file mode 100644 index 00000000..b6b70368 --- /dev/null +++ b/tests/functional/translation/test_inline_6.py @@ -0,0 +1,11 @@ +# Any copyright is dedicated to the Public Domain. +# http://creativecommons.org/publicdomain/zero/1.0/ + +from nagini_contracts.contracts import * + + +@Inline +def test1(i: int) -> int: + #:: ExpectedOutput(invalid.program:contract.in.inline.method) + Requires(i > 0) + return 6 \ No newline at end of file diff --git a/tests/functional/verification/test_inline.py b/tests/functional/verification/test_inline.py new file mode 100644 index 00000000..e3ddab8c --- /dev/null +++ b/tests/functional/verification/test_inline.py @@ -0,0 +1,158 @@ +# Any copyright is dedicated to the Public Domain. +# http://creativecommons.org/publicdomain/zero/1.0/ + +from nagini_contracts.obligations import MustTerminate +from nagini_contracts.contracts import * + +@Inline +def power3_A(input: int) -> int: + out_a = input + try: + + for i in range(2): + Invariant(out_a == input * (input ** len(Previous(i)))) + try: + out_a = out_a * input + out_a -= 5 + finally: + out_a = out_a + 5 + finally: + out_a = out_a - 5 + return out_a + 5 + + +@Inline +def power3_B(input: int) -> int: + i = 0 + out_b = input + while i < 2: + Invariant(0 <= i and i <= 2) + Invariant(out_b == input * (input ** i)) + out_b = out_b * input + i += 1 + return out_b + +@Inline +def not_always_correct(i: int, j: int) -> int: + if j < 0: + assert False # would lead to an error but is not reached by any call + + if i == 0: + res = 8 + #:: ExpectedOutput(assert.failed:assertion.false,L1) + assert False # would lead to an error + elif i > 0: + res = 14 + else: + res = 9 + return res + +def partly_correct_caller() -> None: + tst = not_always_correct(5, 8) + Assert(tst == 14) + #:: Label(L1) + tst = not_always_correct(0, 8) + +@Inline +def may_not_terminate(i: int) -> None: + will_terminate = i >= 0 + #:: ExpectedOutput(leak_check.failed:loop_context.has_unsatisfied_obligations,L2) + while i != 0: + Invariant(Implies(will_terminate, MustTerminate(i))) + i -= 1 + +def should_terminate() -> None: + Requires(MustTerminate(5)) + + may_not_terminate(8) + + #:: Label(L2) + may_not_terminate(-6) + + +class A: + @Inline + def foo(self) -> int: + return 1 + + def bar(self) -> int: + return 1 + +def test_calls(input: int, a: A) -> None: + Requires(input > 0) + Assert(a.foo() == 1) + + Assert(power3_A(input) == power3_B(input)) + + #:: ExpectedOutput(assert.failed:assertion.false) + Assert(a.bar() == 1) + +@Inline +def plus_two(i: int) -> int: + a = i + 2 + return a + +@Inline +def plus_seven(i: int) -> int: + a = plus_two(i) + b = plus_two(a) + return b + 3 + +def nested_caller() -> None: + Assert(plus_seven(9) == 16) + #:: ExpectedOutput(assert.failed:assertion.false) + Assert(plus_seven(5) == 13) + +@Inline +def raises(a: int) -> None: + if a > 0: + raise Exception + +@Inline +def raises_and_catches(b: int) -> int: + r = 9 + try: + if b == 9: + raise Exception + r = 7 + except: + r = 12 + return r + +def calls_raises_1() -> int: + Ensures(Result() == 8888) + #:: ExpectedOutput(postcondition.violated:assertion.false) + Ensures(False) + r = 43 + + try: + r = 12 + raises(-3) + r = 456 + raises(4) + r = 2 + except: + Assert(r == 456) + r = 8888 + return r + +#:: ExpectedOutput(exhale.failed:assertion.false,L3) +def calls_raises_2() -> None: + #:: Label(L3) + raises(8) # error bc raises exception + +def calls_raises_3(i: int) -> None: + Ensures(i <= 0) + Exsures(Exception, i > 0) + #:: ExpectedOutput(postcondition.violated:assertion.false) + Exsures(Exception, False) + raises(i) + +def calls_raises_and_catches_1() -> None: + raises_and_catches(12) + + raises_and_catches(9) + + #:: ExpectedOutput(assert.failed:assertion.false) + Assert(False) + diff --git a/tests/functional/verification/test_super.py b/tests/functional/verification/test_super.py index 094f04b7..419a4349 100644 --- a/tests/functional/verification/test_super.py +++ b/tests/functional/verification/test_super.py @@ -7,6 +7,14 @@ class Super: def some_method(self) -> int: Ensures(Result() >= 14) + + # some code that is just here to make sure loops can be properly inlined + input = 2 + out_a = input + for i in range(2): + Invariant(out_a == input * (input ** len(Previous(i)))) + out_a = out_a * input + return 14