[WIP] Replace batch matmul code generation for unaligned shapes to use tile and fuse
nirvedhmeshram committed Oct 16, 2024
1 parent 029fd7b commit 03deea8
Showing 5 changed files with 62 additions and 4 deletions.
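For context: "unaligned" here means the GEMM problem sizes are not exact multiples of the chosen MMA intrinsic tile, so the aligned vector-distribute path cannot be used directly. A minimal sketch of that divisibility check, using a hypothetical helper and struct (not IREE's actual API):

#include <cstdint>

// Hypothetical sketch: a GEMM "aligns" with an MMA intrinsic of shape MxNxK
// when every problem dimension is an exact multiple of the intrinsic tile.
struct MmaIntrinsicShape {
  int64_t m, n, k;
};

static bool isAlignedToIntrinsic(int64_t problemM, int64_t problemN,
                                 int64_t problemK,
                                 const MmaIntrinsicShape &mma) {
  return problemM % mma.m == 0 && problemN % mma.n == 0 &&
         problemK % mma.k == 0;
}

For example, a 100x100x100 problem is unaligned for a 16x16x16 intrinsic. With this change, unaligned batch matmuls are routed to the tile-and-fuse configuration by default, and the pad-and-vector-distribute path becomes opt-in via a new flag.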
17 changes: 14 additions & 3 deletions compiler/src/iree/compiler/Codegen/Common/PadDynamicAlloc.cpp
@@ -37,7 +37,8 @@ static Value skipAffineMaxZero(Value dim) {
return *affineMax.getSymbolOperands().begin();
}

-static LogicalResult padAlloc(MLIRContext *context, memref::AllocOp allocOp) {
+template <typename OpTy>
+static LogicalResult padAlloc(MLIRContext *context, OpTy allocOp) {
IRRewriter rewriter(context);
rewriter.setInsertionPoint(allocOp);
SmallVector<int64_t> shape = llvm::to_vector(allocOp.getType().getShape());
@@ -66,7 +67,7 @@ static LogicalResult padAlloc(MLIRContext *context, memref::AllocOp allocOp) {
MemRefType allocType = MemRefType::get(shape, elType, AffineMap(),
allocOp.getType().getMemorySpace());
Location loc = allocOp.getLoc();
-Value paddedAlloc = rewriter.create<memref::AllocOp>(loc, allocType);
+Value paddedAlloc = rewriter.create<OpTy>(loc, allocType);
SmallVector<OpFoldResult> offsets(shape.size(), rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> strides(shape.size(), rewriter.getIndexAttr(1));
Value subview = rewriter.create<memref::SubViewOp>(loc, paddedAlloc, offsets,
@@ -88,7 +89,17 @@ struct PadDynamicAllocPass final
funcOp.walk(
[&](memref::AllocOp allocOp) { sharedMemAllocs.push_back(allocOp); });
for (memref::AllocOp alloc : sharedMemAllocs) {
-if (failed(padAlloc(context, alloc)))
+if (failed(padAlloc<memref::AllocOp>(context, alloc)))
return signalPassFailure();
}

+SmallVector<memref::AllocaOp> privateMemAllocas;
+// Collect all the private memory alloca operations.
+funcOp.walk([&](memref::AllocaOp allocaOp) {
+privateMemAllocas.push_back(allocaOp);
+});
+for (memref::AllocaOp alloca : privateMemAllocas) {
+if (failed(padAlloc<memref::AllocaOp>(context, alloca)))
+return signalPassFailure();
+}
}
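For readers unfamiliar with this pass: padAlloc rewrites a dynamically sized allocation whose extent has a known constant upper bound into a statically sized one, then carves out a subview of the original size so users keep their dynamic shape. A minimal sketch of the shape side of that rewrite, assuming per-dimension upper bounds have already been computed (hypothetical helper, not the pass's actual code):

#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical sketch: replace every dynamic dimension (-1, standing in for
// MLIR's ShapedType::kDynamic) with its known constant upper bound so that a
// static memref.alloc/alloca can be emitted; the original dynamic extent is
// then recovered through a memref.subview over the padded buffer.
static std::vector<int64_t>
padShapeToUpperBounds(const std::vector<int64_t> &shape,
                      const std::vector<int64_t> &upperBounds) {
  assert(shape.size() == upperBounds.size());
  std::vector<int64_t> padded = shape;
  for (size_t i = 0; i < padded.size(); ++i)
    if (padded[i] < 0)
      padded[i] = upperBounds[i];
  return padded;
}

The templating in the diff above simply lets one rewrite serve both memref.alloc (shared/workgroup memory) and memref.alloca (private memory); handling allocas is the new part.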
@@ -207,6 +207,26 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target,
transposedLhs, transposedRhs, /*canUpcastAcc=*/true);
}

+// For unaligned shapes, only batch_matmul is currently supported on this
+// path.
+// TODO(hanchung): Support cases where there are fused producers.
+if (!schedule && !contractionDims.batch.empty() &&
+!hasFusedLeadingOp(linalgOp)) {
+LDBG("Tile and Fuse with pack/unpack");
+bool mustBeAligned = false;
+schedule =
+deduceMMASchedule(problem, intrinsics, seeds, maxSharedMemoryBytes,
+targetSubgroupSize, transposedLhs, transposedRhs,
+/*canUpcastAcc=*/false, mustBeAligned);
+if (!schedule) {
+// Then try again by allowing upcasting accumulator.
+schedule =
+deduceMMASchedule(problem, intrinsics, seeds, maxSharedMemoryBytes,
+targetSubgroupSize, transposedLhs, transposedRhs,
+/*canUpcastAcc=*/true, mustBeAligned);
+}
+}

if (!schedule) {
LDBG("Failed to deduce TileAndFuse MMA schedule");
return failure();
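A note on the retry structure in the configuration above: schedule deduction is attempted first without accumulator upcasting, and only retried with it on failure, since upcasting (e.g. accumulating f16 operands in f32) widens the set of usable intrinsics at the cost of typically higher register pressure. A self-contained restatement of that policy (deduce() is a hypothetical stand-in for deduceMMASchedule with every argument except canUpcastAcc held fixed):

#include <optional>

struct MMASchedule { /* details elided */ };

// Hypothetical stand-in for deduceMMASchedule; returns std::nullopt when no
// schedule fits within the shared memory and subgroup constraints.
std::optional<MMASchedule> deduce(bool canUpcastAcc);

std::optional<MMASchedule> deduceWithFallback() {
  // Prefer a schedule that keeps the accumulator element type unchanged.
  std::optional<MMASchedule> schedule = deduce(/*canUpcastAcc=*/false);
  if (!schedule) {
    // Fall back to upcasting the accumulator (e.g. f16 inputs accumulated
    // in f32) to widen the intrinsic search space.
    schedule = deduce(/*canUpcastAcc=*/true);
  }
  return schedule;
}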
9 changes: 8 additions & 1 deletion compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -63,6 +63,12 @@ llvm::cl::opt<bool> clGPUEnableVectorDistribution(
llvm::cl::desc("enable the usage of the vector distribution pipeline"),
llvm::cl::init(true));

+llvm::cl::opt<bool> clGPUUnalignedGEMMVectorDistribution(
+"iree-codegen-llvmgpu-use-unaligned-gemm-vector-distribution",
+llvm::cl::desc("enable the usage of the vector distribution pipeline for "
+"unaligned GEMMs when supported"),
+llvm::cl::init(false));

llvm::cl::opt<bool> clGPUEnableTransformDialectJit(
"iree-codegen-llvmgpu-enable-transform-dialect-jit",
llvm::cl::desc("enable the usage of the transform dialect JIT"),
@@ -562,7 +568,8 @@ setMatmulVectorDistributionConfig(IREE::GPU::TargetAttr target,
// Only batch_matmul is supported in the LLVMGPUPadAndVectorDistribute
// pipeline.
// TODO(hanchung): Support cases that there are fused producers.
-if (!schedule && !contractionDims->batch.empty() && !hasFusedLeadingOp(op)) {
+if (!schedule && !contractionDims->batch.empty() && !hasFusedLeadingOp(op) &&
+clGPUUnalignedGEMMVectorDistribution) {
LDBG("Matmul Pad and Vector Distribute");
pipeline = CodeGenPipeline::LLVMGPUPadAndVectorDistribute;
bool mustBeAligned = false;
@@ -1,4 +1,5 @@
// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx940 --iree-codegen-llvmgpu-use-vector-distribution \
+// RUN: --iree-codegen-llvmgpu-use-unaligned-gemm-vector-distribution \
// RUN: --pass-pipeline="builtin.module(iree-llvmgpu-select-lowering-strategy)" %s | FileCheck %s

// TODO: This test is still using the legacy LLVMGPU kernel config. This needs
19 changes: 19 additions & 0 deletions compiler/src/iree/compiler/Codegen/Utils/Utils.cpp
@@ -1305,6 +1305,25 @@ replaceNonTrivialUse(RewriterBase &rewriter, Location loc, OpOperand &use,
});
return llvm::to_vector_of<Value>(newExpandOp->getResults());
}
+if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(user)) {
+auto newSourceType = llvm::cast<MemRefType>(replacement.getType());
+
+FailureOr<MemRefType> newResultType =
+memref::CollapseShapeOp::computeCollapsedType(
+newSourceType, collapseOp.getReassociationIndices());
+if (failed(newResultType)) {
+return std::nullopt;
+}
+
+auto newCollapseOp = rewriter.create<memref::CollapseShapeOp>(
+loc, *newResultType, replacement, collapseOp.getReassociation());
+LLVM_DEBUG({
+llvm::dbgs() << "\t\tNew user : ";
+newCollapseOp->print(llvm::dbgs(), OpPrintingFlags().assumeVerified());
+llvm::dbgs() << "\n";
+});
+return llvm::to_vector_of<Value>(newCollapseOp->getResults());
+}
return std::nullopt;
}

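Why the collapsed result type is recomputed rather than reused: once the source has been padded, each reassociation group collapses to a different size, and groups that contained a dynamic dimension may become static. A hypothetical simplification of the per-group computation behind memref::CollapseShapeOp::computeCollapsedType (which can also legitimately fail for non-collapsible layouts, hence the bail-out to std::nullopt above):

#include <cstdint>
#include <vector>

// Hypothetical simplification: each reassociation group collapses to the
// product of its source dimensions, and a dynamic dimension (-1) makes the
// whole group dynamic. Padding turns dynamic source dims static, so the
// collapsed type generally changes and must be derived from the new source.
static int64_t collapsedGroupSize(const std::vector<int64_t> &sourceShape,
                                  const std::vector<int64_t> &group) {
  int64_t size = 1;
  for (int64_t dimIndex : group) {
    if (sourceShape[dimIndex] < 0)
      return -1; // dynamic propagates to the collapsed dimension
    size *= sourceShape[dimIndex];
  }
  return size;
}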
