diff --git a/src/nagini_contracts/contracts.py b/src/nagini_contracts/contracts.py index ac27a13d..68aadb8b 100644 --- a/src/nagini_contracts/contracts.py +++ b/src/nagini_contracts/contracts.py @@ -441,6 +441,13 @@ def Predicate(func: T) -> T: return func +def Inline(func: T) -> T: + """ + Decorator to mark functions that should be inlined during verification, and not treated modularly. It's a no-op. + """ + return func + + def Ghost(func: T) -> T: """ Decorator for ghost functions. It's a no-op. @@ -555,6 +562,7 @@ def isNaN(f: float) -> bool: 'Unfolding', 'Pure', 'Predicate', + 'Inline', 'Ghost', 'ContractOnly', 'GhostReturns', diff --git a/src/nagini_translation/analyzer.py b/src/nagini_translation/analyzer.py index 23a822c6..5dd313f3 100644 --- a/src/nagini_translation/analyzer.py +++ b/src/nagini_translation/analyzer.py @@ -674,6 +674,8 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> None: func.method_type = MethodType.class_method self.current_class._has_classmethod = True func.predicate = self.is_predicate(node) + if self.is_inline_method(node): + func.inline = True if func.predicate: func.contract_only = self.is_declared_contract_only(node) if self.is_all_low(node): @@ -1469,6 +1471,8 @@ def visit_Try(self, node: ast.Try) -> None: def _incompatible_decorators(self, decorators) -> bool: return ((('Predicate' in decorators) and ('Pure' in decorators)) or + (('Predicate' in decorators) and ('Inline' in decorators)) or + (('Inline' in decorators) and ('Pure' in decorators)) or (('IOOperation' in decorators) and (len(decorators) != 1)) or (('property' in decorators) and (len(decorators) != 1)) or (('AllLow' in decorators) and ('PreservesLow' in decorators)) or @@ -1521,6 +1525,9 @@ def is_pure(self, func: ast.FunctionDef) -> bool: def is_predicate(self, func: ast.FunctionDef) -> bool: return self.has_decorator(func, 'Predicate') + def is_inline_method(self, func: ast.FunctionDef) -> bool: + return self.has_decorator(func, 'Inline') + def is_static_method(self, func: ast.FunctionDef) -> bool: return self.has_decorator(func, 'staticmethod') diff --git a/src/nagini_translation/lib/program_nodes.py b/src/nagini_translation/lib/program_nodes.py index c9eec4c2..a1a2d770 100644 --- a/src/nagini_translation/lib/program_nodes.py +++ b/src/nagini_translation/lib/program_nodes.py @@ -1018,6 +1018,7 @@ def __init__(self, name: str, node: ast.AST, cls: PythonClass, self.declared_exceptions = OrderedDict() # direct self.pure = pure self.predicate = False + self.inline = False self.all_low = False self.preserves_low = False self.contract_only = contract_only @@ -1080,6 +1081,10 @@ def process(self, sil_name: str, translator: 'Translator') -> None: # Could be overridden by anything, so we have to check if there's # anything with the same name. self.overrides = self.cls.superclass.get_contents(False)[self.name] + if self.overrides is not None: + if self.overrides.inline: + raise InvalidProgramException(self.node, 'overriding.inline.method', + 'Functions marked to be inlined cannot be overridden.') except KeyError: pass for local in self.locals: @@ -1577,6 +1582,8 @@ def process(self, sil_name: str, translator: 'Translator') -> None: this Python variable. """ super().process(sil_name, translator) + if sil_name == "iterable_1": + print("++") self._translator = translator module = self.type.module self.decl = translator.translate_pythonvar_decl(self, module) diff --git a/src/nagini_translation/sif/translators/statement.py b/src/nagini_translation/sif/translators/statement.py index 3e618539..3a1691c7 100644 --- a/src/nagini_translation/sif/translators/statement.py +++ b/src/nagini_translation/sif/translators/statement.py @@ -170,7 +170,7 @@ def _create_With_exception_handler(self, try_block: PythonTryBlock, err_type_arg = self.type_factory.typeof(error_var, ctx) tb_class = ctx.module.global_module.classes['traceback'] - traceback_var = ctx.actual_function.create_variable('tb', tb_class, + traceback_var = ctx.current_function.create_variable('tb', tb_class, self.translator) tb_type = self.type_check(traceback_var.ref(), tb_class, no_pos, ctx) inhale_types = self.viper.Inhale(tb_type, no_pos, no_info) diff --git a/src/nagini_translation/translators/call.py b/src/nagini_translation/translators/call.py index c6268bc0..2286be1e 100644 --- a/src/nagini_translation/translators/call.py +++ b/src/nagini_translation/translators/call.py @@ -347,7 +347,7 @@ def _translate_list(self, node: ast.Call, ctx: Context) -> StmtsAndExpr: if contents: sil_ref_seq = self.viper.SeqType(self.viper.Ref) ref_seq = SilverType(sil_ref_seq, ctx.module) - havoc_var = ctx.actual_function.create_variable('havoc_seq', ref_seq, + havoc_var = ctx.current_function.create_variable('havoc_seq', ref_seq, self.translator) seq_field = self.viper.Field('list_acc', sil_ref_seq, position, info) content_field = self.viper.FieldAccess(result_var, seq_field, position, info) @@ -391,7 +391,7 @@ def _translate_set(self, node: ast.Call, ctx: Context) -> StmtsAndExpr: if contents: sil_ref_set = self.viper.SetType(self.viper.Ref) ref_set = SilverType(sil_ref_set, ctx.module) - havoc_var = ctx.actual_function.create_variable('havoc_set', ref_set, + havoc_var = ctx.current_function.create_variable('havoc_set', ref_set, self.translator) set_field = self.viper.Field('set_acc', sil_ref_set, position, info) content_field = self.viper.FieldAccess(result_var, set_field, position, info) @@ -455,7 +455,7 @@ def _translate_enumerate(self, node: ast.Call, ctx: Context) -> StmtsAndExpr: arg_type = self.get_type(node.args[0], ctx) arg_stmt, arg = self.translate_expr(node.args[0], ctx) arg_contents = self.get_sequence(arg_type, arg, None, node.args[0], ctx) - new_list = ctx.actual_function.create_variable('enumerate_res', result_type, + new_list = ctx.current_function.create_variable('enumerate_res', result_type, self.translator) sil_ref_seq = self.viper.SeqType(self.viper.Ref) seq_field = self.viper.Field('list_acc', sil_ref_seq, pos, info) @@ -470,7 +470,7 @@ def _translate_enumerate(self, node: ast.Call, ctx: Context) -> StmtsAndExpr: type_inhale = self.viper.Inhale(self.viper.And(list_type_info, list_len_info, pos, info), pos, info) - i_var = ctx.actual_function.create_variable('i', prim_int_type, self.translator, + i_var = ctx.current_function.create_variable('i', prim_int_type, self.translator, False) orig_seq_i = self.viper.SeqIndex(arg_contents, i_var.ref(), pos, info) content_type = result_type.type_args[0].type_args[1] @@ -921,6 +921,10 @@ def inline_method(self, method: PythonMethod, args: List[PythonVar], # Create local var aliases locals_to_copy = method.locals.copy() for local_name, local in locals_to_copy.items(): + if type(local).__name__ == 'SilverVar': + newName = ctx.current_function.get_fresh_name(local_name) + print(local_name) + continue local_var = ctx.current_function.create_variable(local_name, local.type, self.translator) @@ -1201,6 +1205,9 @@ def translate_normal_call(self, target: PythonMethod, arg_stmts: List[Stmt], return self._translate_function_call(target, args, formal_args, arg_stmts, position, node, ctx) else: + if target.inline: + return self._inline_call(target, node, False, 'inlined call', + ctx) return self._translate_method_call(target, args, arg_stmts, position, node, ctx) @@ -1355,7 +1362,7 @@ def _translate_thread_creation(self, node: ast.Call, # Create thread object thread_class = ctx.module.global_module.classes['Thread'] - thread_var = ctx.actual_function.create_variable('threadingVar', thread_class, + thread_var = ctx.current_function.create_variable('threadingVar', thread_class, self.translator) thread = thread_var.ref(node, ctx) newstmt = self.viper.NewStmt(thread, [], pos, info) @@ -1501,7 +1508,7 @@ def _translate_thread_join(self, node: ast.Call, ctx: Context) -> StmtsAndExpr: ctx.perm_factor = post_perm object_class = ctx.module.global_module.classes[OBJECT_TYPE] - res_var = ctx.actual_function.create_variable('join_result', object_class, + res_var = ctx.current_function.create_variable('join_result', object_class, self.translator) # Resolve list of possible thread target methods. @@ -1544,7 +1551,7 @@ def _inhale_possible_thread_post(self, method: PythonMethod, thread: Expr, # Set arg aliases with types for index, arg in enumerate(method._args.values()): - arg_var = ctx.actual_function.create_variable('thread_arg', arg.type, + arg_var = ctx.current_function.create_variable('thread_arg', arg.type, self.translator) ctx.set_alias(arg.name, arg_var) id = self.viper.IntLit(index, pos, info) diff --git a/src/nagini_translation/translators/common.py b/src/nagini_translation/translators/common.py index 5fe442b8..8a9057bf 100644 --- a/src/nagini_translation/translators/common.py +++ b/src/nagini_translation/translators/common.py @@ -544,7 +544,7 @@ def get_func_or_method_call(self, receiver: PythonType, func_name: str, method = receiver.get_method(func_name) if method: assert method.type - target_var = ctx.actual_function.create_variable('target', method.type, + target_var = ctx.current_function.create_variable('target', method.type, self.translator) val = target_var.ref(node, ctx) call = self.get_method_call(receiver, func_name, args, arg_types, [val], node, diff --git a/src/nagini_translation/translators/expression.py b/src/nagini_translation/translators/expression.py index c9214860..8867a244 100644 --- a/src/nagini_translation/translators/expression.py +++ b/src/nagini_translation/translators/expression.py @@ -177,7 +177,7 @@ def translate_ListComp(self, node: ast.ListComp, ctx: Context) -> StmtsAndExpr: if body_stmt: raise InvalidProgramException(node, 'impure.list.comprehension.body') ctx.remove_alias(target.id) - result_var = ctx.actual_function.create_variable('listcomp', list_type, + result_var = ctx.current_function.create_variable('listcomp', list_type, self.translator) stmt = self._create_list_comp_inhale(result_var, list_type, element_var, body, node, ctx) @@ -213,7 +213,7 @@ def _create_list_comp_inhale(self, result_var: PythonVar, list_type: PythonType, len_equal = self.viper.EqCmp(self.to_int(result_len, ctx), seq_len, position, info) int_class = ctx.module.global_module.classes[PRIMITIVE_INT_TYPE] - index_var = ctx.actual_function.create_variable('i', int_class, self.translator, + index_var = ctx.current_function.create_variable('i', int_class, self.translator, False) index_positive = self.viper.GeCmp(index_var.ref(), self.viper.IntLit(0, position, info), @@ -491,7 +491,7 @@ def _translate_slice_subscript(self, node: ast.Subscript, target: Expr, stmt = target_stmt + start_stmt + stop_stmt getitem = target_type.get_func_or_method('__getitem_slice__') if not getitem.pure: - result_var = ctx.actual_function.create_variable( + result_var = ctx.current_function.create_variable( 'slice_res', target_type, self.translator) call = self.get_method_call(target_type, '__getitem_slice__', args, [None, None], @@ -665,7 +665,7 @@ def translate_Name(self, node: ast.Name, ctx: Context) -> StmtsAndExpr: self.to_position(node, ctx), ctx) if node.id == '_': object_type = ctx.module.global_module.classes[OBJECT_TYPE] - temp_var = ctx.actual_function.create_variable('wildcard', object_type, self.translator) + temp_var = ctx.current_function.create_variable('wildcard', object_type, self.translator) return [], temp_var.ref(node, ctx) if node.id in ctx.var_aliases: var = ctx.var_aliases[node.id] diff --git a/src/nagini_translation/translators/obligation/fork.py b/src/nagini_translation/translators/obligation/fork.py index 25d8a793..47cf6196 100644 --- a/src/nagini_translation/translators/obligation/fork.py +++ b/src/nagini_translation/translators/obligation/fork.py @@ -120,7 +120,7 @@ def _set_parameter_aliases(self, method: PythonMethod) -> List[Stmt]: arg_vars = [] stmts = [] for index, arg in enumerate(method._args.values()): - arg_var = self._ctx.actual_function.create_variable( + arg_var = self._ctx.current_function.create_variable( 'thread_arg', arg.type, self._translator.translator) arg_vars.append(arg_var) index_lit = self.viper.IntLit(index, self._position, self._info) diff --git a/src/nagini_translation/translators/statement.py b/src/nagini_translation/translators/statement.py index f629d7be..4467ecc4 100644 --- a/src/nagini_translation/translators/statement.py +++ b/src/nagini_translation/translators/statement.py @@ -601,7 +601,7 @@ def _get_iterator(self, iterable: Expr, iterable_type: PythonType, node: ast.AST, ctx: Context) -> Tuple[PythonVar, List[Stmt]]: iter_class = ctx.module.global_module.classes['Iterator'] - iter_var = ctx.actual_function.create_variable('iter', iter_class, + iter_var = ctx.current_function.create_variable('iter', iter_class, self.translator) assert not node in ctx.loop_iterators ctx.loop_iterators[node] = iter_var @@ -616,7 +616,7 @@ def _get_next_call(self, iter_var: PythonVar, target_var: PythonVar, node: ast.For, ctx: Context) -> Tuple[PythonVar, List[Stmt]]: exc_class = ctx.module.global_module.classes['Exception'] - err_var = ctx.actual_function.create_variable('iter_err', exc_class, + err_var = ctx.current_function.create_variable('iter_err', exc_class, self.translator) iter_class = ctx.module.global_module.classes['Iterator'] args = [iter_var.ref()] @@ -701,7 +701,7 @@ def translate_stmt_For(self, node: ast.For, ctx: Context) -> List[Stmt]: node.end_label = end_label iterable_type = self.get_type(node.iter, ctx) iterable_stmt, iterable = self.translate_expr(node.iter, ctx) - iterable_var = ctx.actual_function.create_variable('iterable', iterable_type, + iterable_var = ctx.current_function.create_variable('iterable', iterable_type, self.translator, True) iterable_assign = self.viper.LocalVarAssign(iterable_var.ref(), iterable, position, info) @@ -718,7 +718,7 @@ def translate_stmt_For(self, node: ast.For, ctx: Context) -> List[Stmt]: raise UnsupportedException(node, 'unknown.iterable') # Create artificial new variable to store current iteration content. - target_var = ctx.actual_function.create_variable('loop_target', + target_var = ctx.current_function.create_variable('loop_target', target_type, self.translator)