
[CPU] Improve vector tile sizes for sub-byte matmuls on Aarch64 (#16143)
This PR introduces a simple heuristic to make sure that we fill at least
one vector register for the smallest data type used in the matmul. For
example, given a 128-bit vector register and an `i32 <- i4, i4` matmul, we
used a tile size of 16 for the main vector dimension (16x4 = 64 bits, half
a vector). With this PR we use 32 (32x4 = 128 bits, a full vector).
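As a rough sketch of the heuristic (the helper below is illustrative, not the actual IREE API), the minimum N tile size falls out of the register width and the narrowest element type:

#include <algorithm>
#include <cstdint>

// Minimal sketch of the heuristic, assuming a fixed-width vector register.
// `vectorSizeInBytes` would come from the target (16 for a 128-bit register),
// `minElementBits` is the narrowest matmul element type (4 for i4), and
// `defaultN` is the target's default tile size for the N dimension.
int64_t minTileSizeForN(int64_t vectorSizeInBytes, int64_t minElementBits,
                        int64_t defaultN) {
  constexpr int64_t kByteSizeInBits = 8;
  // Number of smallest-type elements needed to fill one full register.
  int64_t minNumElements =
      (vectorSizeInBytes * kByteSizeInBits) / minElementBits;
  return std::max(defaultN, minNumElements);
}

// minTileSizeForN(16, 4, 16) == 32  (i4 on a 128-bit register)
// minTileSizeForN(16, 8, 16) == 16  (i8 already fills the register)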
dcaballe authored Jan 19, 2024
1 parent f707927 commit 04d5849
Showing 2 changed files with 126 additions and 10 deletions.
82 changes: 72 additions & 10 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
@@ -944,9 +944,10 @@ setMatmulRootConfig(func::FuncOp entryPointFn,

/// Returns default hard-coded vector sizes for a given target. No smartness
/// should be introduced in this utility.
-static void getDefaultMatmulVectorSizes(
-    linalg::LinalgOp op, SmallVectorImpl<int64_t> &sizes,
-    SmallVectorImpl<bool> &scalableSizeFlags, int64_t vectorSize) {
+static void
+getDefaultMatmulVectorSizes(linalg::LinalgOp op, int64_t vectorSize,
+                            SmallVectorImpl<int64_t> &sizes,
+                            SmallVectorImpl<bool> &scalableSizeFlags) {
  auto targetAttr = IREE::HAL::ExecutableTargetAttr::lookup(op);
  if (isX86(targetAttr)) {
    sizes.append({8, 32, 16});
@@ -994,6 +995,57 @@ static FailureOr<Type> nonWideningLinalgElementType(linalg::LinalgOp op) {
  return inputAndOutputElementTypes[0];
}

/// Compute or adjust existing vector sizes using a generic heuristic that
/// aims to fill at least one full vector register for all the element types
/// of the matmul. For now, the heuristic only looks at the N dimension, but
/// we could introduce logic to also consider unrolling trade-offs between
/// the M, N and K dimensions.
///
/// Example: for an (i32 <- i8, i8) matmul and a 128-bit vector register,
/// vector size N would be at least 128/8=16.
///
/// NOTE: This function should not contain target-specific conditional code.
/// TODO: Currently it's only used on AArch64. We should generalize it to
/// other targets.
static void getMatmulVectorSizesUsingFullVectorHeuristics(
    func::FuncOp entryPointFn, linalg::LinalgOp op, int64_t vectorSize,
    SmallVectorImpl<int64_t> &sizes, SmallVectorImpl<bool> &scalableSizeFlags) {
  if (sizes.empty())
    getDefaultMatmulVectorSizes(op, vectorSize, sizes, scalableSizeFlags);

  // Find the smallest type size in the matmul.
  SmallVector<Type> matmulTypes;
  auto operandTypes = op->getOperandTypes();
  matmulTypes.append(operandTypes.begin(), operandTypes.end());
  auto resultTypes = op->getResultTypes();
  matmulTypes.append(resultTypes.begin(), resultTypes.end());

  int64_t minSize = std::numeric_limits<int64_t>::max();
  for (Type mmType : matmulTypes) {
    if (auto shType = dyn_cast<ShapedType>(mmType))
      mmType = shType.getElementType();

    if (mmType.isSignlessIntOrFloat())
      minSize = std::min<int64_t>(minSize, mmType.getIntOrFloatBitWidth());
  }

  LLVM_DEBUG(KD_DBGS() << "Smallest type found: " << minSize << " bits\n");
  assert(minSize > 0 && minSize < std::numeric_limits<int64_t>::max() &&
         "Min size couldn't be computed");

  // Make sure that the smallest type can at least fill a full vector register
  // given the tile size of the main vector dimension (N).
  constexpr int64_t byteSizeInBits = 8;
  int64_t minNumElements =
      (getNativeVectorSizeInBytes(entryPointFn) * byteSizeInBits) / minSize;
  sizes[1] = std::max<int64_t>(sizes[1], minNumElements);
}

/// Utility to compute the tile sizes for AArch64 SME. Unlike other targets, the
/// tile sizes picked here must exactly match the SME hardware virtual tiles, as
/// there is currently no support for lowering non-standard shapes.
@@ -1034,16 +1086,26 @@ static SizesAndScalableFlags getMatmulVectorSizes(func::FuncOp entryPointFn,

  // TODO: Compute vector tile sizes using heuristics.

-  if (isAArch64(targetAttr) && hasSMEFeature(targetAttr)) {
-    // Note: This may not pick any sizes (which will fall back to the default
-    // SVE sizes below).
-    getMatmulAArch64SMEVectorSizes(op, matmulTileSizes, matmulScalableFlags);
+  if (isAArch64(targetAttr)) {
+    if (hasSMEFeature(targetAttr)) {
+      // Note: This may not pick any sizes (which will fall back to the SVE
+      // heuristics below).
+      getMatmulAArch64SMEVectorSizes(op, matmulTileSizes, matmulScalableFlags);
+    }
+
+    // Try to maximize the vector register utilization for all the matmul
+    // element types.
+    if (matmulTileSizes.empty()) {
+      getMatmulVectorSizesUsingFullVectorHeuristics(
+          entryPointFn, op, vectorSize, matmulTileSizes, matmulScalableFlags);
+    }
  }

-  // Get default hard-coded tile sizes if we couldn't compute anything better.
+  // If tile sizes were not computed by previous heuristics, use default
+  // hard-coded tile sizes.
  if (matmulTileSizes.empty()) {
-    getDefaultMatmulVectorSizes(op, matmulTileSizes, matmulScalableFlags,
-                                vectorSize);
+    getDefaultMatmulVectorSizes(op, vectorSize, matmulTileSizes,
+                                matmulScalableFlags);
  }
  // Pad the scalable flags with false to match the tile sizes.
  matmulScalableFlags.resize(matmulTileSizes.size());
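As a small aside on the padding step (a sketch, not code from this change): llvm::SmallVector::resize value-initializes any new elements, so missing trailing flags become false:

#include "llvm/ADT/SmallVector.h"

void padScalableFlagsExample() {
  // Sketch: three tile sizes were picked but only two scalable flags set.
  llvm::SmallVector<int64_t> matmulTileSizes = {8, 32, 1};
  llvm::SmallVector<bool> matmulScalableFlags = {false, true};

  // resize() value-initializes the new tail element, i.e. it pads with false.
  matmulScalableFlags.resize(matmulTileSizes.size());
  // matmulScalableFlags is now {false, true, false}.
}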
@@ -54,6 +54,60 @@ hal.executable private @matmul_tensors_default {

// -----

#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
  #hal.descriptor_set.layout<0, bindings = [
    #hal.descriptor_set.binding<0, storage_buffer>,
    #hal.descriptor_set.binding<1, storage_buffer>,
    #hal.descriptor_set.binding<2, storage_buffer>,
    #hal.descriptor_set.binding<3, storage_buffer>
  ]>
]>
hal.executable private @i4_i4_i32_matmul {
  hal.executable.variant @llvm target(<"llvm-cpu", "embedded-elf-arm_64", {
    data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128",
    native_vector_size = 16 : index,
    target_triple = "aarch64-none-elf"
  }>) {
    hal.executable.export @i4_i4_i32_matmul layout(#pipeline_layout)
    builtin.module {
      func.func @i4_i4_i32_matmul() {
        %c0 = arith.constant 0 : index
        %c1 = arith.constant 1 : index
        %M = hal.interface.constant.load[0] : index
        %N = hal.interface.constant.load[1] : index
        %K = hal.interface.constant.load[2] : index
        %lhs_binding = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer)
            : !flow.dispatch.tensor<readonly:tensor<?x?xi4>>{%M, %K}
        %rhs_binding = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer)
            : !flow.dispatch.tensor<readonly:tensor<?x?xi4>>{%K, %N}
        %init_binding = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer)
            : !flow.dispatch.tensor<readonly:tensor<?x?xi32>>{%M, %N}
        %result_binding = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer)
            : !flow.dispatch.tensor<writeonly:tensor<?x?xi32>>{%M, %N}
        %lhs = flow.dispatch.tensor.load %lhs_binding, offsets = [0, 0], sizes = [%M, %K], strides = [1, 1]
            : !flow.dispatch.tensor<readonly:tensor<?x?xi4>>{%M, %K} -> tensor<?x?xi4>
        %rhs = flow.dispatch.tensor.load %rhs_binding, offsets = [0, 0], sizes = [%K, %N], strides = [1, 1]
            : !flow.dispatch.tensor<readonly:tensor<?x?xi4>>{%K, %N} -> tensor<?x?xi4>
        %init = flow.dispatch.tensor.load %init_binding, offsets = [0, 0], sizes = [%M, %N], strides = [1, 1]
            : !flow.dispatch.tensor<readonly:tensor<?x?xi32>>{%M, %N} -> tensor<?x?xi32>
        %gemm = linalg.matmul ins(%lhs, %rhs : tensor<?x?xi4>, tensor<?x?xi4>)
            outs(%init : tensor<?x?xi32>) -> tensor<?x?xi32>
        flow.dispatch.tensor.store %gemm, %result_binding, offsets = [0, 0], sizes = [%M, %N], strides = [1, 1]
            : tensor<?x?xi32> -> !flow.dispatch.tensor<writeonly:tensor<?x?xi32>>{%M, %N}
        return
      }
    }
  }
}

// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[64, 64, 0], [64, 64, 0], [0, 0, 0], [8, 32, 0], [0, 0, 1], [0, 0, 0]]>
// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPUDoubleTilingPeelingExpert>
// CHECK: hal.executable.export public @i4_i4_i32_matmul
// CHECK-SAME: translation_info = #[[TRANSLATION]]
// CHECK: linalg.matmul
// CHECK-SAME: lowering_config = #[[CONFIG]]
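// With native_vector_size = 16 bytes and i4 operands, the full-vector
// heuristic computes (16 * 8) / 4 = 32 elements, which is why the
// vector-level tile sizes above use N = 32 rather than the previous 16.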

// -----

#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
  #hal.descriptor_set.layout<0, bindings = [
    #hal.descriptor_set.binding<0, storage_buffer>,
