Skip to content

Commit

Permalink
Merge pull request #215 from marcoeilers/inline
Browse files Browse the repository at this point in the history
@inline decorator
  • Loading branch information
marcoeilers authored Dec 1, 2024
2 parents 17047a7 + 540c1ee commit f0add81
Show file tree
Hide file tree
Showing 27 changed files with 330 additions and 41 deletions.
8 changes: 8 additions & 0 deletions src/nagini_contracts/contracts.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,13 @@ def Predicate(func: T) -> T:
return func


def Inline(func: T) -> T:
"""
Decorator to mark functions that should be inlined during verification, and not treated modularly. It's a no-op.
"""
return func


def Ghost(func: T) -> T:
"""
Decorator for ghost functions. It's a no-op.
Expand Down Expand Up @@ -555,6 +562,7 @@ def isNaN(f: float) -> bool:
'Unfolding',
'Pure',
'Predicate',
'Inline',
'Ghost',
'ContractOnly',
'GhostReturns',
Expand Down
20 changes: 20 additions & 0 deletions src/nagini_translation/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,6 +674,13 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
func.method_type = MethodType.class_method
self.current_class._has_classmethod = True
func.predicate = self.is_predicate(node)
if self.is_inline_method(node):
if name == '__init__':
raise UnsupportedException(node, 'Inlining constructors is currently not supported.')
decorators = {d.id for d in node.decorator_list if isinstance(d, ast.Name)}
if len(decorators) > 1:
raise UnsupportedException(node, 'Unsupported decorator for inline function.')
func.inline = True
if func.predicate:
func.contract_only = self.is_declared_contract_only(node)
if self.is_all_low(node):
Expand Down Expand Up @@ -909,13 +916,21 @@ def track_access(self, node: ast.AST, var: Union[PythonVar, PythonField]) -> Non
elif isinstance(node.ctx, ast.Store):
var.writes.append(node)

def _check_not_in_inline_method(self, node: ast.Call) -> None:
if isinstance(self.stmt_container, PythonMethod) and self.stmt_container.inline:
raise InvalidProgramException(node, 'contract.in.inline.method',
'Inline methods must not have specifications.')

def visit_Call(self, node: ast.Call) -> None:
"""
Collects preconditions, postconditions, raised exceptions and
invariants.
"""
if (isinstance(node.func, ast.Name) and
node.func.id in CONTRACT_WRAPPER_FUNCS):
if node.func.id != 'Invariant':
self._check_not_in_inline_method(node)

if node.func.id == 'Requires':
self.stmt_container.precondition.append(
(node.args[0], self._aliases.copy()))
Expand Down Expand Up @@ -1469,6 +1484,8 @@ def visit_Try(self, node: ast.Try) -> None:

def _incompatible_decorators(self, decorators) -> bool:
return ((('Predicate' in decorators) and ('Pure' in decorators)) or
(('Predicate' in decorators) and ('Inline' in decorators)) or
(('Inline' in decorators) and ('Pure' in decorators)) or
(('IOOperation' in decorators) and (len(decorators) != 1)) or
(('property' in decorators) and (len(decorators) != 1)) or
(('AllLow' in decorators) and ('PreservesLow' in decorators)) or
Expand Down Expand Up @@ -1521,6 +1538,9 @@ def is_pure(self, func: ast.FunctionDef) -> bool:
def is_predicate(self, func: ast.FunctionDef) -> bool:
return self.has_decorator(func, 'Predicate')

def is_inline_method(self, func: ast.FunctionDef) -> bool:
return self.has_decorator(func, 'Inline')

def is_static_method(self, func: ast.FunctionDef) -> bool:
return self.has_decorator(func, 'staticmethod')

Expand Down
2 changes: 1 addition & 1 deletion src/nagini_translation/lib/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def set_alias(self, name: str, var: PythonVar,
self.old_aliases[name] = []
self.old_aliases[name].append(self.var_aliases[name])
if replaces:
if replaces.alt_types:
if hasattr(replaces, 'alt_types') and replaces.alt_types:
assert not var.alt_types
var.alt_types = replaces.alt_types
self.var_aliases[name] = var
Expand Down
8 changes: 8 additions & 0 deletions src/nagini_translation/lib/program_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1018,6 +1018,7 @@ def __init__(self, name: str, node: ast.AST, cls: PythonClass,
self.declared_exceptions = OrderedDict() # direct
self.pure = pure
self.predicate = False
self.inline = False
self.all_low = False
self.preserves_low = False
self.contract_only = contract_only
Expand Down Expand Up @@ -1080,6 +1081,13 @@ def process(self, sil_name: str, translator: 'Translator') -> None:
# Could be overridden by anything, so we have to check if there's
# anything with the same name.
self.overrides = self.cls.superclass.get_contents(False)[self.name]
if self.overrides is not None:
if self.overrides.inline:
raise InvalidProgramException(self.node, 'overriding.inline.method',
'Functions marked to be inlined cannot be overridden.')
if self.inline:
raise InvalidProgramException(self.node, 'overriding.inline.method',
'Functions marked to be inlined cannot override other methods.')
except KeyError:
pass
for local in self.locals:
Expand Down
2 changes: 1 addition & 1 deletion src/nagini_translation/sif/translators/method.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def _method_body_postamble(self, method: PythonMethod, ctx: Context) -> List[Stm
def _create_method_epilog(self, method: PythonMethod, ctx: Context) -> List[Stmt]:
# With the extended AST we don't need a label at the end of the method.
# Check that no undeclared exceptions are raised. (but not for main method)
if not method.declared_exceptions:
if not (method.declared_exceptions or method.inline):
no_info = self.no_info(ctx)
error_string = '"method raises no exceptions"'
error_pos = self.to_position(method.node, ctx, error_string)
Expand Down
2 changes: 1 addition & 1 deletion src/nagini_translation/sif/translators/statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def _create_With_exception_handler(self, try_block: PythonTryBlock,
err_type_arg = self.type_factory.typeof(error_var, ctx)

tb_class = ctx.module.global_module.classes['traceback']
traceback_var = ctx.actual_function.create_variable('tb', tb_class,
traceback_var = ctx.current_function.create_variable('tb', tb_class,
self.translator)
tb_type = self.type_check(traceback_var.ref(), tb_class, no_pos, ctx)
inhale_types = self.viper.Inhale(tb_type, no_pos, no_info)
Expand Down
29 changes: 20 additions & 9 deletions src/nagini_translation/translators/call.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ def _translate_list(self, node: ast.Call, ctx: Context) -> StmtsAndExpr:
if contents:
sil_ref_seq = self.viper.SeqType(self.viper.Ref)
ref_seq = SilverType(sil_ref_seq, ctx.module)
havoc_var = ctx.actual_function.create_variable('havoc_seq', ref_seq,
havoc_var = ctx.current_function.create_variable('havoc_seq', ref_seq,
self.translator)
seq_field = self.viper.Field('list_acc', sil_ref_seq, position, info)
content_field = self.viper.FieldAccess(result_var, seq_field, position, info)
Expand Down Expand Up @@ -391,7 +391,7 @@ def _translate_set(self, node: ast.Call, ctx: Context) -> StmtsAndExpr:
if contents:
sil_ref_set = self.viper.SetType(self.viper.Ref)
ref_set = SilverType(sil_ref_set, ctx.module)
havoc_var = ctx.actual_function.create_variable('havoc_set', ref_set,
havoc_var = ctx.current_function.create_variable('havoc_set', ref_set,
self.translator)
set_field = self.viper.Field('set_acc', sil_ref_set, position, info)
content_field = self.viper.FieldAccess(result_var, set_field, position, info)
Expand Down Expand Up @@ -455,7 +455,7 @@ def _translate_enumerate(self, node: ast.Call, ctx: Context) -> StmtsAndExpr:
arg_type = self.get_type(node.args[0], ctx)
arg_stmt, arg = self.translate_expr(node.args[0], ctx)
arg_contents = self.get_sequence(arg_type, arg, None, node.args[0], ctx)
new_list = ctx.actual_function.create_variable('enumerate_res', result_type,
new_list = ctx.current_function.create_variable('enumerate_res', result_type,
self.translator)
sil_ref_seq = self.viper.SeqType(self.viper.Ref)
seq_field = self.viper.Field('list_acc', sil_ref_seq, pos, info)
Expand All @@ -470,7 +470,7 @@ def _translate_enumerate(self, node: ast.Call, ctx: Context) -> StmtsAndExpr:
type_inhale = self.viper.Inhale(self.viper.And(list_type_info, list_len_info,
pos, info),
pos, info)
i_var = ctx.actual_function.create_variable('i', prim_int_type, self.translator,
i_var = ctx.current_function.create_variable('i', prim_int_type, self.translator,
False)
orig_seq_i = self.viper.SeqIndex(arg_contents, i_var.ref(), pos, info)
content_type = result_type.type_args[0].type_args[1]
Expand Down Expand Up @@ -921,9 +921,12 @@ def inline_method(self, method: PythonMethod, args: List[PythonVar],
# Create local var aliases
locals_to_copy = method.locals.copy()
for local_name, local in locals_to_copy.items():
if type(local).__name__ == 'SilverVar':
continue
local_var = ctx.current_function.create_variable(local_name,
local.type,
self.translator)
local_var.writes = local.writes
ctx.set_alias(local_name, local_var, local)

# Create label aliases
Expand All @@ -939,6 +942,11 @@ def inline_method(self, method: PythonMethod, args: List[PythonVar],
start, end = get_body_indices(method.node.body)
stmts = []

if error_var:
pos = self.no_position(ctx)
info = self.no_info(ctx)
stmts.append(self.viper.LocalVarAssign(error_var.ref(), self.viper.NullLit(pos, info), pos, info))

for stmt in method.node.body[start:end]:
stmts += self.translate_stmt(stmt, ctx)

Expand Down Expand Up @@ -986,7 +994,7 @@ def _inline_call(self, method: PythonMethod, node: ast.Call, is_super: bool,
self.translator)
optional_error_var = None
error_var = self.get_error_var(node, ctx)
if method.declared_exceptions:
if method.declared_exceptions or method.inline:
var = PythonVar(ERROR_NAME, None,
ctx.module.global_module.classes['Exception'])
var._ref = error_var
Expand All @@ -999,7 +1007,7 @@ def _inline_call(self, method: PythonMethod, node: ast.Call, is_super: bool,
stmts += inline_stmts
if end_lbl:
stmts.append(end_lbl)
if method.declared_exceptions:
if method.declared_exceptions or method.inline:
stmts += self.create_exception_catchers(error_var,
ctx.actual_function.try_blocks, node, ctx)
# Return result
Expand Down Expand Up @@ -1201,6 +1209,9 @@ def translate_normal_call(self, target: PythonMethod, arg_stmts: List[Stmt],
return self._translate_function_call(target, args, formal_args,
arg_stmts, position, node, ctx)
else:
if target.inline:
return self._inline_call(target, node, False, 'inlined call',
ctx)
return self._translate_method_call(target, args, arg_stmts,
position, node, ctx)

Expand Down Expand Up @@ -1355,7 +1366,7 @@ def _translate_thread_creation(self, node: ast.Call,

# Create thread object
thread_class = ctx.module.global_module.classes['Thread']
thread_var = ctx.actual_function.create_variable('threadingVar', thread_class,
thread_var = ctx.current_function.create_variable('threadingVar', thread_class,
self.translator)
thread = thread_var.ref(node, ctx)
newstmt = self.viper.NewStmt(thread, [], pos, info)
Expand Down Expand Up @@ -1501,7 +1512,7 @@ def _translate_thread_join(self, node: ast.Call, ctx: Context) -> StmtsAndExpr:
ctx.perm_factor = post_perm

object_class = ctx.module.global_module.classes[OBJECT_TYPE]
res_var = ctx.actual_function.create_variable('join_result', object_class,
res_var = ctx.current_function.create_variable('join_result', object_class,
self.translator)

# Resolve list of possible thread target methods.
Expand Down Expand Up @@ -1544,7 +1555,7 @@ def _inhale_possible_thread_post(self, method: PythonMethod, thread: Expr,

# Set arg aliases with types
for index, arg in enumerate(method._args.values()):
arg_var = ctx.actual_function.create_variable('thread_arg', arg.type,
arg_var = ctx.current_function.create_variable('thread_arg', arg.type,
self.translator)
ctx.set_alias(arg.name, arg_var)
id = self.viper.IntLit(index, pos, info)
Expand Down
4 changes: 2 additions & 2 deletions src/nagini_translation/translators/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,7 @@ def get_func_or_method_call(self, receiver: PythonType, func_name: str,
method = receiver.get_method(func_name)
if method:
assert method.type
target_var = ctx.actual_function.create_variable('target', method.type,
target_var = ctx.current_function.create_variable('target', method.type,
self.translator)
val = target_var.ref(node, ctx)
call = self.get_method_call(receiver, func_name, args, arg_types, [val], node,
Expand Down Expand Up @@ -815,7 +815,7 @@ def get_error_var(self, stmt: ast.AST,
if err_var.sil_name in ctx.var_aliases:
err_var = ctx.var_aliases[err_var.sil_name]
return err_var.ref()
if ctx.actual_function.declared_exceptions:
if ctx.actual_function.declared_exceptions or ctx.actual_function.inline:
return ctx.error_var.ref()
else:
new_var = ctx.current_function.create_variable('error',
Expand Down
10 changes: 5 additions & 5 deletions src/nagini_translation/translators/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def translate_ListComp(self, node: ast.ListComp, ctx: Context) -> StmtsAndExpr:
if body_stmt:
raise InvalidProgramException(node, 'impure.list.comprehension.body')
ctx.remove_alias(target.id)
result_var = ctx.actual_function.create_variable('listcomp', list_type,
result_var = ctx.current_function.create_variable('listcomp', list_type,
self.translator)
stmt = self._create_list_comp_inhale(result_var, list_type, element_var,
body, node, ctx)
Expand Down Expand Up @@ -213,7 +213,7 @@ def _create_list_comp_inhale(self, result_var: PythonVar, list_type: PythonType,
len_equal = self.viper.EqCmp(self.to_int(result_len, ctx), seq_len, position,
info)
int_class = ctx.module.global_module.classes[PRIMITIVE_INT_TYPE]
index_var = ctx.actual_function.create_variable('i', int_class, self.translator,
index_var = ctx.current_function.create_variable('i', int_class, self.translator,
False)
index_positive = self.viper.GeCmp(index_var.ref(),
self.viper.IntLit(0, position, info),
Expand Down Expand Up @@ -491,7 +491,7 @@ def _translate_slice_subscript(self, node: ast.Subscript, target: Expr,
stmt = target_stmt + start_stmt + stop_stmt
getitem = target_type.get_func_or_method('__getitem_slice__')
if not getitem.pure:
result_var = ctx.actual_function.create_variable(
result_var = ctx.current_function.create_variable(
'slice_res', target_type, self.translator)
call = self.get_method_call(target_type, '__getitem_slice__',
args, [None, None],
Expand Down Expand Up @@ -523,7 +523,7 @@ def create_exception_catchers(self, var: PythonVar,
else:
end_label = ctx.get_label_name(END_LABEL)
goto_end = self.viper.Goto(end_label, position, self.no_info(ctx))
if ctx.actual_function.declared_exceptions:
if ctx.actual_function.declared_exceptions or ctx.actual_function.inline:
assignerror = self.viper.LocalVarAssign(err_var, var, position,
self.no_info(ctx))
uncaught_option = self.translate_block([assignerror, goto_end],
Expand Down Expand Up @@ -665,7 +665,7 @@ def translate_Name(self, node: ast.Name, ctx: Context) -> StmtsAndExpr:
self.to_position(node, ctx), ctx)
if node.id == '_':
object_type = ctx.module.global_module.classes[OBJECT_TYPE]
temp_var = ctx.actual_function.create_variable('wildcard', object_type, self.translator)
temp_var = ctx.current_function.create_variable('wildcard', object_type, self.translator)
return [], temp_var.ref(node, ctx)
if node.id in ctx.var_aliases:
var = ctx.var_aliases[node.id]
Expand Down
2 changes: 1 addition & 1 deletion src/nagini_translation/translators/method.py
Original file line number Diff line number Diff line change
Expand Up @@ -751,7 +751,7 @@ def translate_finally(self, block: PythonTryBlock,
goto_continue = self.viper.Goto(loop.end_label, pos, info)
break_block = [goto_break]
continue_block = [goto_continue]
if ctx.actual_function.declared_exceptions:
if ctx.actual_function.declared_exceptions or ctx.actual_function.inline:
# Assign error to error output var
error_var = ctx.error_var.ref()
block_error_var = block.get_error_var(self.translator)
Expand Down
4 changes: 2 additions & 2 deletions src/nagini_translation/translators/obligation/fork.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def _set_parameter_aliases(self, method: PythonMethod) -> List[Stmt]:
arg_vars = []
stmts = []
for index, arg in enumerate(method._args.values()):
arg_var = self._ctx.actual_function.create_variable(
arg_var = self._ctx.current_function.create_variable(
'thread_arg', arg.type, self._translator.translator)
arg_vars.append(arg_var)
index_lit = self.viper.IntLit(index, self._position, self._info)
Expand Down Expand Up @@ -219,7 +219,7 @@ def _add_waitlevel(self) -> None:
def _create_level_below(
self, expr: sil.PermExpression,
ctx: Context) -> sil.BoolExpression:
residue_level_var = sil.PermVar(ctx.actual_function.obligation_info.residue_level)
residue_level_var = sil.PermVar(ctx.current_function.obligation_info.residue_level)
obligation = self._obligation_manager.must_release_obligation
fields = obligation.create_fields_untranslated()
var = ctx.current_function.create_variable(
Expand Down
2 changes: 1 addition & 1 deletion src/nagini_translation/translators/obligation/inexhale.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def get_use_method(
self, ctx: Context) -> List[Tuple[sil.Expression, Rules]]:
"""Default implementation for obligation use in method contract."""
inexhale = self._get_inexhale(ctx)
obligation_info = ctx.actual_function.obligation_info
obligation_info = ctx.current_function.obligation_info
if self.is_fresh():
return [(inexhale.construct_use_method_unbounded(), None)]
else:
Expand Down
2 changes: 1 addition & 1 deletion src/nagini_translation/translators/obligation/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def translate_must_invoke_ctoken(
return self._loop_translator.translate_may_invoke(
node, ctx)
elif ctx.obligation_context.is_translating_posts:
if ctx.actual_function.name != 'Gap':
if ctx.current_function.name != 'Gap':
raise InvalidProgramException(
node, 'invalid.postcondition.ctoken_not_allowed')
else:
Expand Down
Loading

0 comments on commit f0add81

Please sign in to comment.