Commit c2326e1

Match TileAndFuse Matmul Heuristics to VectorDistribute and raise limit of TileLargeTensorsPass

Signed-off-by: Nirvedh Meshram <[email protected]>
nirvedhmeshram committed Jan 10, 2025
1 parent 982856b commit c2326e1
Showing 7 changed files with 52 additions and 35 deletions.
22 changes: 22 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp
@@ -137,6 +137,15 @@ static bool isValidMMASchedule(const GPUMatmulShapeType &problem,
bool transposedLhs, bool transposedRhs) {
bool isAligned = isScheduleAligned(problem, schedule, mustBeAligned);

LLVM_DEBUG({
  llvm::dbgs() << "Checking schedule validity\n";
  llvm::dbgs() << "schedule: " << schedule << "\n";
  llvm::dbgs() << "mustBeAligned: " << mustBeAligned << "\n";
  llvm::dbgs() << "subgroupSize: " << subgroupSize << "\n";
  llvm::dbgs() << "transposedLhs: " << transposedLhs << "\n";
  llvm::dbgs() << "transposedRhs: " << transposedRhs << "\n";
});

// Constraint to ensure wgTileSize is distributable by wgSize,
// such that we can distribute to its corresponding vector.transfer_read.
const int64_t kMaxVectorLoadBitWidth = 128;
@@ -388,6 +397,14 @@ FailureOr<GPUMMASchedule> deduceMMASchedule(
const GPUMMAHeuristicSeeds &seeds, int64_t sharedMemLimitInBytes,
int64_t subgroupSize, bool transposedLhs, bool transposedRhs,
bool canUpcastAcc, bool mustBeAligned, bool doCPromotion) {

LLVM_DEBUG({
  llvm::dbgs() << "Deducing MMA schedule\n";
  llvm::dbgs() << "mustBeAligned: " << mustBeAligned << "\n";
  llvm::dbgs() << "subgroupSize: " << subgroupSize << "\n";
  llvm::dbgs() << "transposedLhs: " << transposedLhs << "\n";
  llvm::dbgs() << "transposedRhs: " << transposedRhs << "\n";
});
for (auto [index, intrinsic] : llvm::enumerate(intrinsics)) {
if (failed(canTargetIntrinsic(problem, intrinsic, canUpcastAcc,
mustBeAligned))) {
@@ -426,6 +443,11 @@ FailureOr<GPUMMASchedule> deduceMMASchedule(
llvm::dbgs() << sharedMemoryUsed << " bytes\n";
});

LLVM_DEBUG({ llvm::dbgs() << "Aligned: " << isAligned << "\n"; });

return isAligned && sharedMemoryUsed <= sharedMemLimitInBytes;
};
return fitScheduleInSharedMemory(intrinsic, schedule, isValidSchedule);
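The acceptance test at the end of this hunk combines alignment with a shared-memory budget. The following is a hedged sketch of the kind of accounting that can sit behind `sharedMemoryUsed`, not the actual IREE helper: the function names (`estimateSharedMemoryBytes`, `acceptSchedule`), the assumption that only the LHS and RHS workgroup tiles are staged in shared memory, and the handling of C promotion are all assumptions made here for illustration.

#include <cstdint>

// Hedged sketch: estimate shared memory for a candidate MMA schedule,
// assuming only the LHS (M x K) and RHS (K x N) workgroup tiles are staged
// in shared memory, plus the accumulator tile when C promotion is requested.
int64_t estimateSharedMemoryBytes(int64_t mTile, int64_t nTile, int64_t kTile,
                                  int64_t inBitWidth, int64_t accBitWidth,
                                  bool doCPromotion) {
  int64_t bits = (mTile * kTile + kTile * nTile) * inBitWidth;
  if (doCPromotion)
    bits += mTile * nTile * accBitWidth; // promoted accumulator tile
  return bits / 8;
}

// Mirrors the acceptance check in the hunk above: a schedule is kept only if
// it is aligned and its estimated usage fits the target's limit.
bool acceptSchedule(bool isAligned, int64_t sharedMemoryUsed,
                    int64_t sharedMemLimitInBytes) {
  return isAligned && sharedMemoryUsed <= sharedMemLimitInBytes;
}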
2 changes: 1 addition & 1 deletion compiler/src/iree/compiler/Codegen/Common/Passes.td
@@ -649,7 +649,7 @@ def TileLargeTensorsPass :
];
let options = [
Option<"maxVectorSize", "max-vector-size", "int64_t",
/*default=*/"64",
/*default=*/"256",
"Maximum static size to tile to (i.e. all remaining ops will be smaller)">,
];
}
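Raising the default from 64 to 256 lets TileLargeTensorsPass leave larger, but still bounded, ops behind. The updated test below makes the effect concrete: a 64x512 elementwise generic is now tiled to 1x256 slices instead of 1x64. Below is a minimal sketch of how such a cap bounds the innermost tile; the helper name `chooseInnerTileSize` and the min-based choice are assumptions for illustration, not the pass's actual tiling logic.

#include <algorithm>
#include <cstdint>

// Hypothetical sketch: cap the innermost tile so a slice holds at most
// maxVectorSize scalar elements; outer dimensions are tiled to 1, as the
// CHECK lines in the test below suggest.
int64_t chooseInnerTileSize(int64_t innerDimSize, int64_t maxVectorSize) {
  return std::min(innerDimSize, maxVectorSize);
}

// With the old default: chooseInnerTileSize(512, 64)  == 64  -> 1x64 slices.
// With the new default: chooseInnerTileSize(512, 256) == 256 -> 1x256 slices.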
@@ -3,22 +3,22 @@
// RUN: FileCheck %s

#map = affine_map<(d0, d1) -> (d0, d1)>
func.func @simple_generic(%3: tensor<64x256xf32>, %4: tensor<64x256xf32>, %5: tensor<64x256xf32>) -> tensor<64x256xf32> {
func.func @simple_generic(%3: tensor<64x512xf32>, %4: tensor<64x512xf32>, %5: tensor<64x512xf32>) -> tensor<64x512xf32> {
%6 = linalg.generic {
indexing_maps = [#map, #map, #map],
iterator_types = ["parallel", "parallel"]
} ins(%3, %4 : tensor<64x256xf32>, tensor<64x256xf32>) outs(%5 : tensor<64x256xf32>) {
} ins(%3, %4 : tensor<64x512xf32>, tensor<64x512xf32>) outs(%5 : tensor<64x512xf32>) {
^bb0(%in: f32, %in_0: f32, %out: f32):
%7 = arith.addf %in, %in_0 : f32
linalg.yield %7 : f32
} -> tensor<64x256xf32>
return %6 : tensor<64x256xf32>
} -> tensor<64x512xf32>
return %6 : tensor<64x512xf32>
}

// CHECK-LABEL: func.func @simple_generic
// CHECK: scf.for %{{.*}} = %c0 to %c64 step %c1
// CHECK: scf.for %{{.*}} = %c0 to %c256 step %c64
// CHECK: linalg.generic {{.*}} outs({{.*}}: tensor<1x64xf32>)
// CHECK: scf.for %{{.*}} = %c0 to %c512 step %c256
// CHECK: linalg.generic {{.*}} outs({{.*}}: tensor<1x256xf32>)

// -----

@@ -65,21 +65,21 @@ func.func @in_nested_region(%3: tensor<64x64xf32>, %4: tensor<64x64xf32>, %5: te

// -----

func.func @multiple_use_tilable_op(%3: tensor<64x256xf32>, %4: tensor<64x256xf32>) -> (tensor<64x256xf32>, tensor<256x64xf32>) {
%add_empty = tensor.empty() : tensor<64x256xf32>
func.func @multiple_use_tilable_op(%3: tensor<64x512xf32>, %4: tensor<64x512xf32>) -> (tensor<64x512xf32>, tensor<512x64xf32>) {
%add_empty = tensor.empty() : tensor<64x512xf32>
%6 = linalg.add
ins(%3, %4 : tensor<64x256xf32>, tensor<64x256xf32>)
outs(%add_empty : tensor<64x256xf32>) -> tensor<64x256xf32>
%transpose_empty = tensor.empty() : tensor<256x64xf32>
ins(%3, %4 : tensor<64x512xf32>, tensor<64x512xf32>)
outs(%add_empty : tensor<64x512xf32>) -> tensor<64x512xf32>
%transpose_empty = tensor.empty() : tensor<512x64xf32>
%7 = linalg.transpose
ins(%6 : tensor<64x256xf32>)
outs(%transpose_empty : tensor<256x64xf32>) permutation = [1, 0]
return %6, %7 : tensor<64x256xf32>, tensor<256x64xf32>
ins(%6 : tensor<64x512xf32>)
outs(%transpose_empty : tensor<512x64xf32>) permutation = [1, 0]
return %6, %7 : tensor<64x512xf32>, tensor<512x64xf32>
}

// CHECK-LABEL: func.func @multiple_use_tilable_op
// CHECK: %[[ADD_TILING:.+]] = scf.for
// CHECK: linalg.add {{.*}} -> tensor<1x64xf32>
// CHECK: linalg.add {{.*}} -> tensor<1x256xf32>
// CHECK: %[[T_TILING:.+]] = scf.for
// CHECK: %[[FUSED_ADD:.+]] = linalg.add {{.*}} -> tensor<64x1xf32>
// CHECK: linalg.transpose ins(%[[FUSED_ADD]]
@@ -135,7 +135,6 @@ static std::optional<GPUMMASchedule> getMmaScheduleFromProblemAndTarget(
GPUMMAHeuristicSeeds seeds;
assert(problem.aType == problem.bType &&
"expected the same aType and bType.");
int64_t inBitWidth = problem.aType.getIntOrFloatBitWidth();

// Note that the following heuristic seeds are just placeholder values.
// We need to clean them up and make them adjust to different targets.
@@ -148,26 +147,20 @@ static std::optional<GPUMMASchedule> getMmaScheduleFromProblemAndTarget(
// and a larger bestKTileCountPerSubgroup.
seeds = {/*bestSubgroupCountPerWorkgroup=*/4,
/*bestMNTileCountPerSubgroup=*/4,
/*bestKTileCountPerSubgroup=*/8,
/*bestKElementCountPerSubgroup*/ kCacheLineSizeBits / inBitWidth};
/*bestKTileCountPerSubgroup=*/8};
} else {
seeds = {/*bestSubgroupCountPerWorkgroup=*/4,
/*bestMNTileCountPerSubgroup=*/16,
/*bestKTileCountPerSubgroup=*/4,
/*bestKElementCountPerSubgroup*/ kCacheLineSizeBits / 2 /
inBitWidth};
/*bestMNTileCountPerSubgroup=*/8,
/*bestKTileCountPerSubgroup=*/4};
}

int64_t maxSharedMemoryBytes = target.getWgp().getMaxWorkgroupMemoryBytes();

// First try to find a schedule with an exactly matching intrinsic.
std::optional<GPUMMASchedule> schedule = deduceMMASchedule(
problem, intrinsics, seeds, maxSharedMemoryBytes, targetSubgroupSize,
transposedLhs, transposedRhs, /*canUpcastAcc=*/false,
/*transposedLhs=*/false, /*transposedRhs=*/false, /*canUpcastAcc=*/false,
/*mustBeAligned*/ mustBeAligned, doCPromotion);
// TODO(nirvedhmeshram): Add support for upcasting accumulator schedule.
// Currently we don't have this for the TileAndFuse path, see
// https://github.com/iree-org/iree/issues/19532
return schedule;
}

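The new seeds match the values used for the VectorDistribute path, and they line up with the updated lowering configs in the tests further below. The worked check that follows is only an illustration under stated assumptions: it assumes the MFMA_F32_16x16x16_F16 intrinsic (16x16x16 tiles), and which seed branch fires for a given matmul depends on a condition hidden in the collapsed context above, so the correspondence is an observation about these particular tests, not a guarantee of how deduceMMASchedule picks its values.

#include <cassert>
#include <cstdint>

// Worked arithmetic check against the updated test expectations below.
int main() {
  const int64_t intrinsicM = 16, intrinsicN = 16;

  // @mfma_matmul_1024x1024x1024: workgroup = [64, 128, 0],
  // subgroup = [2, 4, 0], reduction = [0, 0, 4].
  const int64_t workgroupM = 64, workgroupN = 128;
  const int64_t subgroupM = 2, subgroupN = 4, reductionK = 4;

  // M*N tiles per subgroup: 2 * 4 = 8 == bestMNTileCountPerSubgroup.
  assert(subgroupM * subgroupN == 8);
  // Subgroups per workgroup: (64 / (2*16)) * (128 / (4*16)) = 2 * 2 = 4
  // == bestSubgroupCountPerWorkgroup.
  assert((workgroupM / (subgroupM * intrinsicM)) *
             (workgroupN / (subgroupN * intrinsicN)) == 4);
  // K tiles per subgroup: 4 == bestKTileCountPerSubgroup (else branch).
  assert(reductionK == 4);

  // @expanded_matmul_transpose_b lines up with the other branch:
  // subgroup = [1, 1, 4, 1, 0] gives 4 M*N tiles and reduction =
  // [0, 0, 0, 0, 8] gives 8 K tiles, matching bestMNTileCountPerSubgroup = 4
  // and bestKTileCountPerSubgroup = 8.
  return 0;
}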
3 changes: 3 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -620,6 +620,9 @@ setMatmulVectorDistributionConfig(IREE::GPU::TargetAttr target,
/*canUpcastAcc=*/true);
}

LDBG("transposedLhs: " << transposedLhs);
LDBG("transposedRhs: " << transposedRhs);

// Only batch_matmul is supported in the LLVMGPUPadAndVectorDistribute
// pipeline.
// TODO(hanchung): Support cases where there are fused producers.
@@ -39,7 +39,7 @@ func.func @expanded_matmul_transpose_b(%lhs: tensor<2x64x2048xf16>, %rhs: tensor
// CHECK: linalg.generic {{.*}}lowering_config = #iree_gpu.lowering_config
// CHECK-SAME: mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>
// CHECK-SAME: promote_operands = [0, 1]
// CHECK-SAME: reduction = [0, 0, 0, 0, 4]
// CHECK-SAME: reduction = [0, 0, 0, 0, 8]
// CHECK-SAME: subgroup = [1, 1, 4, 1, 0]
// CHECK-SAME: workgroup = [1, 1, 64, 64, 0]

@@ -74,7 +74,7 @@ func.func @multi_dim_mma_schedule(%lhs: tensor<10x32x128x16xf16>, %rhs: tensor<4
// CHECK: linalg.generic {{.*}}lowering_config = #iree_gpu.lowering_config
// CHECK-SAME: mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>
// CHECK-SAME: promote_operands = [0, 1]
// CHECK-SAME: reduction = [0, 0, 0, 0, 4, 1]
// CHECK-SAME: reduction = [0, 0, 0, 0, 8, 1]
// CHECK-SAME: subgroup = [2, 2, 1, 1, 0, 0]
// CHECK-SAME: workgroup = [2, 2, 32, 32, 0, 0]

@@ -136,9 +136,9 @@ func.func @mfma_matmul_1024x1024x1024(%lhs: tensor<1024x1024xf16>, %rhs: tensor<
// CHECK: linalg.matmul {{.*}}lowering_config = #iree_gpu.lowering_config
// CHECK-SAME: mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>
// CHECK-SAME: promote_operands = [0, 1]
// CHECK-SAME: reduction = [0, 0, 2]
// CHECK-SAME: subgroup = [4, 4, 0]
// CHECK-SAME: workgroup = [128, 128, 0]
// CHECK-SAME: reduction = [0, 0, 4]
// CHECK-SAME: subgroup = [2, 4, 0]
// CHECK-SAME: workgroup = [64, 128, 0]

// -----

@@ -1013,9 +1013,8 @@ hal.executable public @main {
// CHECK: scf.yield %[[REDUCE]]

// CHECK: scf.for %{{.*}} = %{{.*}} to %c16 step %c1
// CHECK: scf.for
// CHECK-COUNT-4: arith.addf {{.*}} : vector<9xf32>
// CHECK: vector.transfer_write {{.*}} vector<9xi8>, memref<32x16x9x9xi8, #hal.descriptor_type<storage_buffer>>
// CHECK-COUNT-4: arith.addf {{.*}} : vector<9x9xf32>
// CHECK: vector.transfer_write {{.*}} vector<9x9xi8>, memref<32x16x9x9xi8, #hal.descriptor_type<storage_buffer>>

// -----

