Skip to content

Commit

Permalink
Inline decorator, WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
marcoeilers committed Nov 29, 2024
1 parent 5818baf commit 19a5fbb
Show file tree
Hide file tree
Showing 9 changed files with 47 additions and 18 deletions.
8 changes: 8 additions & 0 deletions src/nagini_contracts/contracts.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,13 @@ def Predicate(func: T) -> T:
return func


def Inline(func: T) -> T:
"""
Decorator to mark functions that should be inlined during verification, and not treated modularly. It's a no-op.
"""
return func


def Ghost(func: T) -> T:
"""
Decorator for ghost functions. It's a no-op.
Expand Down Expand Up @@ -555,6 +562,7 @@ def isNaN(f: float) -> bool:
'Unfolding',
'Pure',
'Predicate',
'Inline',
'Ghost',
'ContractOnly',
'GhostReturns',
Expand Down
7 changes: 7 additions & 0 deletions src/nagini_translation/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,6 +674,8 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
func.method_type = MethodType.class_method
self.current_class._has_classmethod = True
func.predicate = self.is_predicate(node)
if self.is_inline_method(node):
func.inline = True
if func.predicate:
func.contract_only = self.is_declared_contract_only(node)
if self.is_all_low(node):
Expand Down Expand Up @@ -1469,6 +1471,8 @@ def visit_Try(self, node: ast.Try) -> None:

def _incompatible_decorators(self, decorators) -> bool:
return ((('Predicate' in decorators) and ('Pure' in decorators)) or
(('Predicate' in decorators) and ('Inline' in decorators)) or
(('Inline' in decorators) and ('Pure' in decorators)) or
(('IOOperation' in decorators) and (len(decorators) != 1)) or
(('property' in decorators) and (len(decorators) != 1)) or
(('AllLow' in decorators) and ('PreservesLow' in decorators)) or
Expand Down Expand Up @@ -1521,6 +1525,9 @@ def is_pure(self, func: ast.FunctionDef) -> bool:
def is_predicate(self, func: ast.FunctionDef) -> bool:
return self.has_decorator(func, 'Predicate')

def is_inline_method(self, func: ast.FunctionDef) -> bool:
return self.has_decorator(func, 'Inline')

def is_static_method(self, func: ast.FunctionDef) -> bool:
return self.has_decorator(func, 'staticmethod')

Expand Down
7 changes: 7 additions & 0 deletions src/nagini_translation/lib/program_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1018,6 +1018,7 @@ def __init__(self, name: str, node: ast.AST, cls: PythonClass,
self.declared_exceptions = OrderedDict() # direct
self.pure = pure
self.predicate = False
self.inline = False
self.all_low = False
self.preserves_low = False
self.contract_only = contract_only
Expand Down Expand Up @@ -1080,6 +1081,10 @@ def process(self, sil_name: str, translator: 'Translator') -> None:
# Could be overridden by anything, so we have to check if there's
# anything with the same name.
self.overrides = self.cls.superclass.get_contents(False)[self.name]
if self.overrides is not None:
if self.overrides.inline:
raise InvalidProgramException(self.node, 'overriding.inline.method',
'Functions marked to be inlined cannot be overridden.')
except KeyError:
pass
for local in self.locals:
Expand Down Expand Up @@ -1577,6 +1582,8 @@ def process(self, sil_name: str, translator: 'Translator') -> None:
this Python variable.
"""
super().process(sil_name, translator)
if sil_name == "iterable_1":
print("++")
self._translator = translator
module = self.type.module
self.decl = translator.translate_pythonvar_decl(self, module)
Expand Down
2 changes: 1 addition & 1 deletion src/nagini_translation/sif/translators/statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def _create_With_exception_handler(self, try_block: PythonTryBlock,
err_type_arg = self.type_factory.typeof(error_var, ctx)

tb_class = ctx.module.global_module.classes['traceback']
traceback_var = ctx.actual_function.create_variable('tb', tb_class,
traceback_var = ctx.current_function.create_variable('tb', tb_class,
self.translator)
tb_type = self.type_check(traceback_var.ref(), tb_class, no_pos, ctx)
inhale_types = self.viper.Inhale(tb_type, no_pos, no_info)
Expand Down
21 changes: 14 additions & 7 deletions src/nagini_translation/translators/call.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ def _translate_list(self, node: ast.Call, ctx: Context) -> StmtsAndExpr:
if contents:
sil_ref_seq = self.viper.SeqType(self.viper.Ref)
ref_seq = SilverType(sil_ref_seq, ctx.module)
havoc_var = ctx.actual_function.create_variable('havoc_seq', ref_seq,
havoc_var = ctx.current_function.create_variable('havoc_seq', ref_seq,
self.translator)
seq_field = self.viper.Field('list_acc', sil_ref_seq, position, info)
content_field = self.viper.FieldAccess(result_var, seq_field, position, info)
Expand Down Expand Up @@ -391,7 +391,7 @@ def _translate_set(self, node: ast.Call, ctx: Context) -> StmtsAndExpr:
if contents:
sil_ref_set = self.viper.SetType(self.viper.Ref)
ref_set = SilverType(sil_ref_set, ctx.module)
havoc_var = ctx.actual_function.create_variable('havoc_set', ref_set,
havoc_var = ctx.current_function.create_variable('havoc_set', ref_set,
self.translator)
set_field = self.viper.Field('set_acc', sil_ref_set, position, info)
content_field = self.viper.FieldAccess(result_var, set_field, position, info)
Expand Down Expand Up @@ -455,7 +455,7 @@ def _translate_enumerate(self, node: ast.Call, ctx: Context) -> StmtsAndExpr:
arg_type = self.get_type(node.args[0], ctx)
arg_stmt, arg = self.translate_expr(node.args[0], ctx)
arg_contents = self.get_sequence(arg_type, arg, None, node.args[0], ctx)
new_list = ctx.actual_function.create_variable('enumerate_res', result_type,
new_list = ctx.current_function.create_variable('enumerate_res', result_type,
self.translator)
sil_ref_seq = self.viper.SeqType(self.viper.Ref)
seq_field = self.viper.Field('list_acc', sil_ref_seq, pos, info)
Expand All @@ -470,7 +470,7 @@ def _translate_enumerate(self, node: ast.Call, ctx: Context) -> StmtsAndExpr:
type_inhale = self.viper.Inhale(self.viper.And(list_type_info, list_len_info,
pos, info),
pos, info)
i_var = ctx.actual_function.create_variable('i', prim_int_type, self.translator,
i_var = ctx.current_function.create_variable('i', prim_int_type, self.translator,
False)
orig_seq_i = self.viper.SeqIndex(arg_contents, i_var.ref(), pos, info)
content_type = result_type.type_args[0].type_args[1]
Expand Down Expand Up @@ -921,6 +921,10 @@ def inline_method(self, method: PythonMethod, args: List[PythonVar],
# Create local var aliases
locals_to_copy = method.locals.copy()
for local_name, local in locals_to_copy.items():
if type(local).__name__ == 'SilverVar':
newName = ctx.current_function.get_fresh_name(local_name)
print(local_name)
continue
local_var = ctx.current_function.create_variable(local_name,
local.type,
self.translator)
Expand Down Expand Up @@ -1201,6 +1205,9 @@ def translate_normal_call(self, target: PythonMethod, arg_stmts: List[Stmt],
return self._translate_function_call(target, args, formal_args,
arg_stmts, position, node, ctx)
else:
if target.inline:
return self._inline_call(target, node, False, 'inlined call',
ctx)
return self._translate_method_call(target, args, arg_stmts,
position, node, ctx)

Expand Down Expand Up @@ -1355,7 +1362,7 @@ def _translate_thread_creation(self, node: ast.Call,

# Create thread object
thread_class = ctx.module.global_module.classes['Thread']
thread_var = ctx.actual_function.create_variable('threadingVar', thread_class,
thread_var = ctx.current_function.create_variable('threadingVar', thread_class,
self.translator)
thread = thread_var.ref(node, ctx)
newstmt = self.viper.NewStmt(thread, [], pos, info)
Expand Down Expand Up @@ -1501,7 +1508,7 @@ def _translate_thread_join(self, node: ast.Call, ctx: Context) -> StmtsAndExpr:
ctx.perm_factor = post_perm

object_class = ctx.module.global_module.classes[OBJECT_TYPE]
res_var = ctx.actual_function.create_variable('join_result', object_class,
res_var = ctx.current_function.create_variable('join_result', object_class,
self.translator)

# Resolve list of possible thread target methods.
Expand Down Expand Up @@ -1544,7 +1551,7 @@ def _inhale_possible_thread_post(self, method: PythonMethod, thread: Expr,

# Set arg aliases with types
for index, arg in enumerate(method._args.values()):
arg_var = ctx.actual_function.create_variable('thread_arg', arg.type,
arg_var = ctx.current_function.create_variable('thread_arg', arg.type,
self.translator)
ctx.set_alias(arg.name, arg_var)
id = self.viper.IntLit(index, pos, info)
Expand Down
2 changes: 1 addition & 1 deletion src/nagini_translation/translators/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,7 @@ def get_func_or_method_call(self, receiver: PythonType, func_name: str,
method = receiver.get_method(func_name)
if method:
assert method.type
target_var = ctx.actual_function.create_variable('target', method.type,
target_var = ctx.current_function.create_variable('target', method.type,
self.translator)
val = target_var.ref(node, ctx)
call = self.get_method_call(receiver, func_name, args, arg_types, [val], node,
Expand Down
8 changes: 4 additions & 4 deletions src/nagini_translation/translators/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def translate_ListComp(self, node: ast.ListComp, ctx: Context) -> StmtsAndExpr:
if body_stmt:
raise InvalidProgramException(node, 'impure.list.comprehension.body')
ctx.remove_alias(target.id)
result_var = ctx.actual_function.create_variable('listcomp', list_type,
result_var = ctx.current_function.create_variable('listcomp', list_type,
self.translator)
stmt = self._create_list_comp_inhale(result_var, list_type, element_var,
body, node, ctx)
Expand Down Expand Up @@ -213,7 +213,7 @@ def _create_list_comp_inhale(self, result_var: PythonVar, list_type: PythonType,
len_equal = self.viper.EqCmp(self.to_int(result_len, ctx), seq_len, position,
info)
int_class = ctx.module.global_module.classes[PRIMITIVE_INT_TYPE]
index_var = ctx.actual_function.create_variable('i', int_class, self.translator,
index_var = ctx.current_function.create_variable('i', int_class, self.translator,
False)
index_positive = self.viper.GeCmp(index_var.ref(),
self.viper.IntLit(0, position, info),
Expand Down Expand Up @@ -491,7 +491,7 @@ def _translate_slice_subscript(self, node: ast.Subscript, target: Expr,
stmt = target_stmt + start_stmt + stop_stmt
getitem = target_type.get_func_or_method('__getitem_slice__')
if not getitem.pure:
result_var = ctx.actual_function.create_variable(
result_var = ctx.current_function.create_variable(
'slice_res', target_type, self.translator)
call = self.get_method_call(target_type, '__getitem_slice__',
args, [None, None],
Expand Down Expand Up @@ -665,7 +665,7 @@ def translate_Name(self, node: ast.Name, ctx: Context) -> StmtsAndExpr:
self.to_position(node, ctx), ctx)
if node.id == '_':
object_type = ctx.module.global_module.classes[OBJECT_TYPE]
temp_var = ctx.actual_function.create_variable('wildcard', object_type, self.translator)
temp_var = ctx.current_function.create_variable('wildcard', object_type, self.translator)
return [], temp_var.ref(node, ctx)
if node.id in ctx.var_aliases:
var = ctx.var_aliases[node.id]
Expand Down
2 changes: 1 addition & 1 deletion src/nagini_translation/translators/obligation/fork.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def _set_parameter_aliases(self, method: PythonMethod) -> List[Stmt]:
arg_vars = []
stmts = []
for index, arg in enumerate(method._args.values()):
arg_var = self._ctx.actual_function.create_variable(
arg_var = self._ctx.current_function.create_variable(
'thread_arg', arg.type, self._translator.translator)
arg_vars.append(arg_var)
index_lit = self.viper.IntLit(index, self._position, self._info)
Expand Down
8 changes: 4 additions & 4 deletions src/nagini_translation/translators/statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,7 +601,7 @@ def _get_iterator(self, iterable: Expr, iterable_type: PythonType,
node: ast.AST, ctx: Context) -> Tuple[PythonVar,
List[Stmt]]:
iter_class = ctx.module.global_module.classes['Iterator']
iter_var = ctx.actual_function.create_variable('iter', iter_class,
iter_var = ctx.current_function.create_variable('iter', iter_class,
self.translator)
assert not node in ctx.loop_iterators
ctx.loop_iterators[node] = iter_var
Expand All @@ -616,7 +616,7 @@ def _get_next_call(self, iter_var: PythonVar, target_var: PythonVar,
node: ast.For,
ctx: Context) -> Tuple[PythonVar, List[Stmt]]:
exc_class = ctx.module.global_module.classes['Exception']
err_var = ctx.actual_function.create_variable('iter_err', exc_class,
err_var = ctx.current_function.create_variable('iter_err', exc_class,
self.translator)
iter_class = ctx.module.global_module.classes['Iterator']
args = [iter_var.ref()]
Expand Down Expand Up @@ -701,7 +701,7 @@ def translate_stmt_For(self, node: ast.For, ctx: Context) -> List[Stmt]:
node.end_label = end_label
iterable_type = self.get_type(node.iter, ctx)
iterable_stmt, iterable = self.translate_expr(node.iter, ctx)
iterable_var = ctx.actual_function.create_variable('iterable', iterable_type,
iterable_var = ctx.current_function.create_variable('iterable', iterable_type,
self.translator, True)
iterable_assign = self.viper.LocalVarAssign(iterable_var.ref(), iterable,
position, info)
Expand All @@ -718,7 +718,7 @@ def translate_stmt_For(self, node: ast.For, ctx: Context) -> List[Stmt]:
raise UnsupportedException(node, 'unknown.iterable')

# Create artificial new variable to store current iteration content.
target_var = ctx.actual_function.create_variable('loop_target',
target_var = ctx.current_function.create_variable('loop_target',
target_type,
self.translator)

Expand Down

0 comments on commit 19a5fbb

Please sign in to comment.