Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Util] Fix an assert getting reached for certain nested loops in HoistIntoGlobals #19576

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ class ConstExprHoistingPolicy {
}
void enableHoist() {
assert(outcome == UNDECIDED &&
"can only disable hoisting of an undecided decision");
"can only enable hoisting of an undecided decision");
outcome = ENABLE_HOIST;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@ static void populateEscapingProducers(Operation *parentOp,
ConstExprOpInfo &info) {
SmallPtrSet<Operation *, 8> containedOps;
parentOp->walk<WalkOrder::PreOrder>([&](Operation *itOp) {
containedOps.insert(parentOp);

// For the outer-most op, consider that all operands escape.
if (itOp == parentOp) {
info.producers.insert(itOp->getOperands().begin(),
Expand All @@ -33,8 +31,9 @@ static void populateEscapingProducers(Operation *parentOp,
: WalkResult::advance();
}

// For nested operations, only consider that they escape if they are
// defined outside of the parent.
containedOps.insert(itOp->getParentOp());
// A nested operation escapes if every operand is defined outside contained
// ops.
for (Value operand : itOp->getOperands()) {
Block *block = operand.getParentBlock();
if (!containedOps.contains(block->getParentOp())) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,17 +103,15 @@ class HoistIntoGlobalsPass
continue;
auto walkRes = funcOp.walk<WalkOrder::PreOrder>([&](Operation *iterOp) {
// We only want to look at const-expr ops (non roots) since they may
// have interesting escapes. Early exit here for efficiency.
// have interesting escapes. Early exit here if the op has no
// ConstValueInfo or its first result cannot be hoisted.
auto *iterInfo = constExprs.lookup(iterOp);
if (!iterInfo)
if (!iterInfo || policy.getDecision(iterInfo)->getOutcome() !=
ConstExprHoistingPolicy::ENABLE_HOIST)
return WalkResult::advance();
for (Value constExprResult : iterOp->getResults()) {
auto *resultInfo = constExprs.lookup(constExprResult);
assert(resultInfo && "must have const-expr info");
if (policy.getDecision(resultInfo)->getOutcome() !=
ConstExprHoistingPolicy::ENABLE_HOIST) {
continue;
}
if (failed(hoistConstExpr(constExprResult, hoistedMap, moduleSymbols,
constExprs))) {
return WalkResult::interrupt();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -364,3 +364,30 @@ module @nested_program_const_expr {
}
}
}

// -----

// Prior to this patch, a bug caused %3#0 to be considered an escaping producer for %1.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use issue numbers when referencing things - "this patch" won't mean much after this lands :)

the description here is a bit too much for a test - it encodes assumptions/situations that are very specific to the current implementation of the code and will get out of date fast - it's useful to include comments as to what a test is verifying but it doesn't need the whole back story and prior behavior, just what's expected - if the test ever starts failing as someone is changing code they care about what situation they are trying to make work, not how it did/didn't work previously

we generally want to avoid "it doesn't crash" tests as it doesn't help anyone coming along working on the code - "not crashing" is a weak test that doesn't prove the behavior does anything but not crash - those are better for large bulk test corpuses or indirectly via e2e tests - if adding a test and fixing code then a test should be added for the behavior being modified/fixed/etc. here, for example, whatever ops or attributes caused the crash need to be CHECKed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, Ben. I'll create an issue to link with and look for a different way to test these changes. I found it difficult to add a test since I'm essentially only changing the analysis phase of this pass: the actual behavior on IR should not be affected.

Testing a batched aten.multinomial op e2e would be sufficient to cover this, but that might also need some changes from #19563 and #19556 to pass.

// This means that %3#0 got assigned a ConstValueInfo when expanding, but %3#1 did not (it is unused).
// HoistIntoGlobalsPass includes an assertion that, for each *op* succeeding ConstValueInfo lookup,
// every result of the op also has ConstValueInfo. This assert caused the compiler to abort when checking %3#1.
// This lit test simply verifies that the pass does not crash.

// CHECK-LABEL: @nested_bodies_unused_result_no_crash
module @nested_bodies_unused_result_no_crash {
util.func public @main() -> tensor<i32> {
%c0_i32 = arith.constant 0 : i32
%c1_i32 = arith.constant 1 : i32
%c10_i32 = arith.constant 10 : i32
%0 = tensor.empty() : tensor<i32>
%1 = scf.for %arg0 = %c0_i32 to %c10_i32 step %c1_i32 iter_args(%arg1 = %0) -> (tensor<i32>) : i32 {
%2 = scf.for %arg2 = %c0_i32 to %c10_i32 step %c1_i32 iter_args(%arg3 = %arg1) -> (tensor<i32>) : i32 {
%3:2 = "iree_unregistered.const_expr"(%arg0, %arg2) : (i32, i32) -> (i32, i32)
%inserted = tensor.insert %3#0 into %arg3[] : tensor<i32>
scf.yield %inserted : tensor<i32>
}
scf.yield %2 : tensor<i32>
}
util.return %1 : tensor<i32>
}
}
Loading