[Codegen] Allow multiple reduction dimensions in VectorDistribution (#18868)

This PR adds support for multiple k dimensions in VectorDistribution
contract codegen.
Groverkss authored Oct 22, 2024
1 parent b922a70 commit 81c8b25
Showing 4 changed files with 66 additions and 9 deletions.
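
For concreteness: the test added below contracts over two reduction dimensions (d4 and d5 in its indexing maps), which previously hit the "Unimplemented: > 1 k dims" error removed in the first hunk. As plain loops, that contraction computes roughly the following (a C++ reference sketch with illustrative shapes and hypothetical names, not anything the compiler emits):

#include <vector>

// out(d0,d1,d2,d3) += lhs(d0,d4,d2,d5) * rhs(d1,d4,d3,d5), reducing over the
// two K dims d4 and d5. Dimension sizes are hypothetical parameters.
void multiKContract(const std::vector<float> &lhs, // [D0][D4][D2][D5]
                    const std::vector<float> &rhs, // [D1][D4][D3][D5]
                    std::vector<float> &out,       // [D0][D1][D2][D3]
                    int D0, int D1, int D2, int D3, int D4, int D5) {
  for (int d0 = 0; d0 < D0; ++d0)
    for (int d1 = 0; d1 < D1; ++d1)
      for (int d2 = 0; d2 < D2; ++d2)
        for (int d3 = 0; d3 < D3; ++d3)
          for (int d4 = 0; d4 < D4; ++d4)   // reduction dim #1
            for (int d5 = 0; d5 < D5; ++d5) // reduction dim #2
              out[((d0 * D1 + d1) * D2 + d2) * D3 + d3] +=
                  lhs[((d0 * D4 + d4) * D2 + d2) * D5 + d5] *
                  rhs[((d1 * D4 + d4) * D3 + d3) * D5 + d5];
}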
@@ -1280,9 +1280,6 @@ MMAScheduleAttr::getContractionLayout(VectorContractOpInfo &opInfo,
llvm::errs() << "Getting mma layouts for:\n" << contractOp << "\n";
llvm::errs() << "For schedule: " << *this << "\n";
});
-  if (opInfo.getKDims().size() != 1) {
-    return contractOp->emitError("Unimplemented: > 1 k dims");
-  }

int64_t rank = contractOp.getIteratorTypesArray().size();
auto mmaAttr = llvm::cast<MMAAttr>(getIntrinsic());
@@ -1450,6 +1447,10 @@ MMAScheduleAttr::getContractionLayout(VectorContractOpInfo &opInfo,
aSubgroupSizes[dim] = subgroupMBasis[i];
aSubgroupStrides[dim] = subgroupMStrides[i];
}
+  for (auto [kDim, lhsKDim] :
+       llvm::zip_equal(opInfo.getKDims(), opInfo.lhsKDim)) {
+    aBatchSizes[lhsKDim] = bounds[kDim];
+  }
aBatchSizes[afk] = bounds[opInfo.getKDims().back()] / intrinsicK;

auto aLayout = createNestedLayout(context, aRank, afm, afk,
@@ -1470,6 +1471,10 @@ MMAScheduleAttr::getContractionLayout(VectorContractOpInfo &opInfo,
bSubgroupSizes[dim] = subgroupNBasis[i];
bSubgroupStrides[dim] = subgroupNStrides[i];
}
+  for (auto [kDim, rhsKDim] :
+       llvm::zip_equal(opInfo.getKDims(), opInfo.rhsKDim)) {
+    bBatchSizes[rhsKDim] = bounds[kDim];
+  }
bBatchSizes[bfk] = bounds[opInfo.getKDims().back()] / intrinsicK;

auto bLayout = createNestedLayout(context, bRank, bfk, bfn,
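The two hunks above are symmetric for the LHS and RHS operands: every K iteration dim now carries its full loop bound into the batch sizes of the matching operand dimension, after which the innermost K dim is overridden to bound / intrinsicK, since one MMA intrinsic consumes intrinsicK elements along K. A minimal standalone sketch of that computation, using std::vector in place of the MLIR containers and assuming afk/bfk equal the last operand K position (all names here are hypothetical):

#include <cassert>
#include <cstdint>
#include <vector>

// `kDims` are iteration-space K dims, `operandKDims` their positions on the
// operand (opInfo.lhsKDim / opInfo.rhsKDim in the real code).
std::vector<int64_t> operandBatchSizes(const std::vector<int64_t> &bounds,
                                       const std::vector<int64_t> &kDims,
                                       const std::vector<int64_t> &operandKDims,
                                       int64_t operandRank, int64_t intrinsicK) {
  std::vector<int64_t> batch(operandRank, 1);
  assert(kDims.size() == operandKDims.size()); // zip_equal's invariant
  // Outer K dims are carried wholesale as batch dimensions.
  for (int64_t i = 0, e = kDims.size(); i < e; ++i)
    batch[operandKDims[i]] = bounds[kDims[i]];
  // The innermost K dim is tiled by the intrinsic's K extent.
  batch[operandKDims.back()] = bounds[kDims.back()] / intrinsicK;
  return batch;
}

For example, post-tiling K bounds of {1, 128} with intrinsicK = 16 give batch sizes {1, 8} along the two K dims.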
@@ -166,6 +166,55 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) {

// -----

#pipeline_layout = #hal.pipeline.layout<bindings = [
#hal.pipeline.binding<storage_buffer>,
#hal.pipeline.binding<storage_buffer>,
#hal.pipeline.binding<storage_buffer>
]>

hal.executable @matmul_multiple_k {
hal.executable.variant public @rocm target(<"rocm", "rocm-hsaco-fb">) {
hal.executable.export public @matmul_multiple_k layout(#pipeline_layout) {
^bb0(%arg0: !hal.device):
%x, %y, %z = flow.dispatch.workgroup_count_from_slice
hal.return %x, %y, %z : index, index, index
}
builtin.module {
func.func @matmul_multiple_k() attributes {translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64, {mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, subgroup_m_count = 1, subgroup_n_count = 4>}>} {
%cst = arith.constant 0.000000e+00 : f16
%c0 = arith.constant 0 : index
%0 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<2x128x64x2048xf16>>
%1 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(1) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<10x128x64x2048xf16>>
%2 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(2) alignment(64) offset(%c0) flags(Indirect) : !flow.dispatch.tensor<writeonly:tensor<2x10x64x64xf16>>
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [2, 128, 64, 2048], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<2x128x64x2048xf16>> -> tensor<2x128x64x2048xf16>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0, 0], sizes = [10, 128, 64, 2048], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<10x128x64x2048xf16>> -> tensor<10x128x64x2048xf16>
%5 = tensor.empty() : tensor<2x10x64x64xf16>
%6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<2x10x64x64xf16>) -> tensor<2x10x64x64xf16>
%7 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d2, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d4, d3, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%3, %4 : tensor<2x128x64x2048xf16>, tensor<10x128x64x2048xf16>) outs(%6 : tensor<2x10x64x64xf16>) attrs = {lowering_config = #iree_gpu.lowering_config<{reduction = [0, 0, 0, 0, 1, 128], workgroup = [1, 1, 64, 64, 0, 0]}>} {
^bb0(%in: f16, %in_0: f16, %out: f16):
%8 = arith.mulf %in, %in_0 : f16
%9 = arith.addf %8, %out : f16
linalg.yield %9 : f16
} -> tensor<2x10x64x64xf16>
flow.dispatch.tensor.store %7, %2, offsets = [0, 0, 0, 0], sizes = [2, 10, 64, 64], strides = [1, 1, 1, 1] : tensor<2x10x64x64xf16> -> !flow.dispatch.tensor<writeonly:tensor<2x10x64x64xf16>>
return
}
}
}
}

// Check that multiple reduction dimensions are handled and that they are
// tiled into a single coalesced loop (the index math is sketched after the
// CHECK lines).

// CHECK-LABEL: func.func @matmul_multiple_k
// CHECK: scf.for %[[IV:.+]] = %c0 to %c2048 step %c1
// CHECK: affine.delinearize_index %[[IV]] into (%c128, %c16)
// CHECK-COUNT-32: amdgpu.mfma
// CHECK: scf.yield
// CHECK-COUNT-4: vector.transfer_write {{.+}} {in_bounds = [true, true]} : vector<4x1xf16>, memref<2x10x64x64xf16, #hal.descriptor_type<storage_buffer>>
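
The CHECK lines above expect the two reduction loops to be fused into one scf.for of 2048 = 128 * 16 iterations, with affine.delinearize_index recovering the two K tile indices from the linear induction variable. That is a mixed-radix decomposition; a minimal C++ sketch (hypothetical helper, assuming the test's (128, 16) basis):

#include <array>
#include <cassert>
#include <cstdint>

// Matches `affine.delinearize_index %iv into (%c128, %c16)`:
// iv in [0, outer * inner) -> (iv / inner, iv % inner).
std::array<int64_t, 2> delinearize2(int64_t iv, int64_t outer, int64_t inner) {
  assert(iv >= 0 && iv < outer * inner); // outer only bounds the linear index
  return {iv / inner, iv % inner};
}

// The coalesced loop then visits each (k0, k1) pair exactly once:
//   for (int64_t iv = 0; iv < 128 * 16; ++iv) {
//     auto [k0, k1] = delinearize2(iv, /*outer=*/128, /*inner=*/16);
//     // ... one step of the two fused K loops ...
//   }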

// -----

// Basic f8, f8 -> f32 matmul.

#config = #iree_gpu.lowering_config<{workgroup = [64, 64, 0], reduction = [0, 0, 256]}>
11 changes: 7 additions & 4 deletions compiler/src/iree/compiler/Codegen/Utils/VectorOpUtils.cpp
@@ -18,7 +18,7 @@ std::pair<int, int> VectorContractOpInfo::getOperandMNIndex() const {

// Returns the (LHS K, RHS K) dimension index pair.
std::pair<int, int> VectorContractOpInfo::getOperandKIndex() const {
-  return std::make_pair(lhsKDim, rhsKDim);
+  return std::make_pair(lhsKDim.back(), rhsKDim.back());
}

// Returns the result (M, N) dimension index pair.
@@ -55,9 +55,12 @@ VectorContractOpInfo::inferFromIndexingMaps(ArrayRef<AffineMap> maps) {
opInfo.outNDims.push_back(
*maps[2].getResultPosition(getAffineDimExpr(n, ctx)));
}
-  int64_t k = contractionDims.k.back();
-  opInfo.lhsKDim = *maps[0].getResultPosition(getAffineDimExpr(k, ctx));
-  opInfo.rhsKDim = *maps[1].getResultPosition(getAffineDimExpr(k, ctx));
+  for (auto k : contractionDims.k) {
+    opInfo.lhsKDim.push_back(
+        *maps[0].getResultPosition(getAffineDimExpr(k, ctx)));
+    opInfo.rhsKDim.push_back(
+        *maps[1].getResultPosition(getAffineDimExpr(k, ctx)));
+  }

opInfo.lhsUnitDims = maps[0].getBroadcastDims();
opInfo.rhsUnitDims = maps[1].getBroadcastDims();
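With this change, inferFromIndexingMaps records the operand result position of every K dim instead of only the last; getResultPosition simply locates a given iteration dim among a map's results. A dependency-free sketch of the new loop, with each affine map reduced to the list of dims it projects (all names here are hypothetical):

#include <algorithm>
#include <cstdint>
#include <optional>
#include <vector>

// A map's results as the iteration dims they project; e.g. the test's LHS map
// (d0, d4, d2, d5) becomes {0, 4, 2, 5}.
using DimList = std::vector<int64_t>;

// Stand-in for maps[i].getResultPosition(getAffineDimExpr(k, ctx)).
std::optional<int64_t> resultPosition(const DimList &map, int64_t dim) {
  auto it = std::find(map.begin(), map.end(), dim);
  if (it == map.end())
    return std::nullopt;
  return it - map.begin();
}

// Mirrors the new loop: collect each K dim's position on the LHS and RHS.
void inferKDims(const DimList &lhsMap, const DimList &rhsMap,
                const std::vector<int64_t> &kDims,
                std::vector<int64_t> &lhsKDim, std::vector<int64_t> &rhsKDim) {
  for (int64_t k : kDims) {
    lhsKDim.push_back(*resultPosition(lhsMap, k));
    rhsKDim.push_back(*resultPosition(rhsMap, k));
  }
}

With the test's maps — LHS (d0, d4, d2, d5) and RHS (d1, d4, d3, d5) — kDims = {4, 5} yields lhsKDim = {1, 3} and rhsKDim = {1, 3}.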
4 changes: 2 additions & 2 deletions compiler/src/iree/compiler/Codegen/Utils/VectorOpUtils.h
@@ -49,9 +49,9 @@ class VectorContractOpInfo {
int64_t getBatchCount() const { return contractionDims.batch.size(); }

SmallVector<int64_t> lhsMDims;
-  int64_t lhsKDim;
+  SmallVector<int64_t> lhsKDim;
  SmallVector<int64_t> rhsNDims;
-  int64_t rhsKDim;
+  SmallVector<int64_t> rhsKDim;
SmallVector<int64_t> outMDims;
SmallVector<int64_t> outNDims;

