Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add chance of reverse chain assignment in perm_temp_for_expr #138

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 47 additions & 9 deletions src/randomizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@

# Instead of emitting an assignment statement, assign the temporary within the
# first expression it's used in with this probability.
PROB_TEMP_ASSIGN_AT_FIRST_USE = 0.1
PROB_TEMP_ASSIGN_AT_FIRST_USE = 0.2

# When creating a temporary for an expression, use the temporary for all equal
# expressions with this probability.
Expand Down Expand Up @@ -385,6 +385,27 @@ def expr_filter(node: ca.Node, is_expr: bool) -> Any:
visit_replace(top_node, expr_filter)


def find_assignment_stmt_by_rvalue(
block: Block, expr: Expression
) -> Optional[Statement]:

ret_stmt: Optional[Statement] = None

def rec(block: Block) -> None:
nonlocal ret_stmt
statements = ast_util.get_block_stmts(block, False)
for stmt in statements:
if isinstance(stmt, ca.Assignment):
if stmt.rvalue is expr:
ret_stmt = stmt
return None
ast_util.for_nested_blocks(stmt, rec)
return None

rec(block)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

simpler to use a visitor for this rather than explicit recursion

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm I think it was not possible to use a visitor here. I'll explain what the issue was when I get back to this

return ret_stmt


def replace_node(top_node: ca.Node, old: ca.Node, new: ca.Node) -> None:
visit_replace(top_node, lambda node, _: new if node is old else None)

Expand Down Expand Up @@ -661,6 +682,17 @@ def visitor(expr: Expression) -> None:
expr = ca.UnaryOp("&", expr)
type = decayed_expr_type(expr, typemap)

reverse_chain_case = False

# if ASSIGN_AT_FIRST_USE and expr is rvalue of an assignment
# then decide whether to flip the order of assignments in the chain assignment statement
if place is None and not should_make_ptr:
stmt = find_assignment_stmt_by_rvalue(fn.body, expr)
if stmt and random_bool(random, 0.5):
assert isinstance(stmt, ca.Assignment)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this assert shouldn't be needed

expr = stmt.lvalue
reverse_chain_case = True

if should_make_ptr:
assert isinstance(expr, ca.UnaryOp)
assert not isinstance(expr.expr, ca.Typename)
Expand Down Expand Up @@ -712,18 +744,19 @@ def find_duplicates(e: Expression) -> None:
else:
replace_subexprs(fn.body, find_duplicates)

assert orig_expr in replace_cands
replace_cand_set: Set[Expression] = set()
if random_bool(random, PROB_TEMP_REPLACE_ALL):
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(I changed this part of the code in a recent commit, will need to be updated)

replace_cand_set.update(replace_cands)
elif random_bool(random, PROB_TEMP_REPLACE_MOST):
index = replace_cands.index(orig_expr)
if random_bool(random, 0.5):
replace_cand_set.update(replace_cands[: index + 1])
elif not reverse_chain_case:
if random_bool(random, PROB_TEMP_REPLACE_MOST):
assert orig_expr in replace_cands
index = replace_cands.index(orig_expr)
if random_bool(random, 0.5):
replace_cand_set.update(replace_cands[: index + 1])
else:
replace_cand_set.update(replace_cands[index:])
else:
replace_cand_set.update(replace_cands[index:])
else:
replace_cand_set.add(orig_expr)
replace_cand_set.add(orig_expr)

if random_bool(random, 0.5):
for cand in replace_cands:
Expand All @@ -743,6 +776,11 @@ def replacer(e: Expression) -> Optional[Expression]:

replace_subexprs(fn.body, replacer)

if reverse_chain_case:
assert isinstance(stmt, ca.Assignment)
stmt.rvalue = ca.Assignment("=", stmt.lvalue, stmt.rvalue)
stmt.lvalue = ca.ID(var)

# Step 6: insert the assignment and any new variable declaration
if place is not None:
block, index, _ = place
Expand Down