From 6fd0fd07c0c5a4ad432969f08c3b16ce51ad9c37 Mon Sep 17 00:00:00 2001
From: Kunwar Grover
Date: Tue, 14 Jan 2025 16:29:04 +0000
Subject: [PATCH] [LinalgExt] Implement PartialReductionOpInterface for
 OnlineAttentionOp (#19684)

---
 .../Dialect/LinalgExt/IR/LinalgExtOps.td      |   7 +-
 .../LinalgExt/IR/TilingInterfaceImpl.cpp      | 346 +++++++++++++++++-
 .../LinalgExt/Transforms/test/tiling.mlir     | 124 +++++++
 3 files changed, 462 insertions(+), 15 deletions(-)

diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
index 067c54d412e8..7d43c09137c2 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
@@ -649,7 +649,12 @@ def IREELinalgExt_OnlineAttentionOp : IREELinalgExt_PureOp<"online_attention",
                                ["getIterationDomain",
                                 "getLoopIteratorTypes",
                                 "getResultTilePosition",
-                                "getTiledImplementation"]>]> {
+                                "getTiledImplementation"]>,
+       DeclareOpInterfaceMethods<PartialReductionOpInterface,
+                                 ["generateInitialTensorForPartialReduction",
+                                  "tileToPartialReduction",
+                                  "mergeReductions",
+                                  "getPartialResultTilePosition"]>]> {
   let summary = "Online Attention operator";
   let description = [{
     Traditional scaled dot product attention computes:
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/TilingInterfaceImpl.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/TilingInterfaceImpl.cpp
index 4b21177e3287..ee7d053267a2 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/TilingInterfaceImpl.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/TilingInterfaceImpl.cpp
@@ -1844,7 +1844,7 @@ getAttentionIteratorTypes(int64_t domainRank, AffineMap qMap, AffineMap kMap,
   return iteratorTypes;
 }
 
-static SmallVector<Range> getPermutedSlice(AffineMap permutation,
+static SmallVector<Range> getPermutedRange(AffineMap permutation,
                                            ArrayRef<OpFoldResult> offsets,
                                            ArrayRef<OpFoldResult> sizes) {
   auto one = IntegerAttr::get(IndexType::get(permutation.getContext()), 1);
@@ -1862,6 +1862,15 @@ static SmallVector<Range> getPermutedSlice(AffineMap permutation,
   return output;
 }
 
+static Operation *getPermutedSlice(OpBuilder &b, Location loc, Value val,
+                                   AffineMap permutation,
+                                   ArrayRef<OpFoldResult> offsets,
+                                   ArrayRef<OpFoldResult> sizes) {
+  SmallVector<Range> slice = getPermutedRange(permutation, offsets, sizes);
+  Operation *querySliceOp = getSlice(b, loc, val, slice);
+  return querySliceOp;
+}
+
 //===----------------------------------------------------------------------===//
 // AttentionOp
 //===----------------------------------------------------------------------===//
@@ -1890,12 +1899,12 @@ AttentionOp::getTiledImplementation(OpBuilder &builder,
   Location loc = getLoc();
 
   SmallVector<Range> querySlice =
-      getPermutedSlice(getQueryMap(), offsets, sizes);
-  SmallVector<Range> keySlice = getPermutedSlice(getKeyMap(), offsets, sizes);
+      getPermutedRange(getQueryMap(), offsets, sizes);
+  SmallVector<Range> keySlice = getPermutedRange(getKeyMap(), offsets, sizes);
   SmallVector<Range> valueSlice =
-      getPermutedSlice(getValueMap(), offsets, sizes);
+      getPermutedRange(getValueMap(), offsets, sizes);
   SmallVector<Range> outputSlice =
-      getPermutedSlice(getOutputMap(), offsets, sizes);
+      getPermutedRange(getOutputMap(), offsets, sizes);
 
   Value scale = getScale();
 
@@ -1939,7 +1948,7 @@ AttentionOp::getTiledImplementation(OpBuilder &builder,
   Value attnMask = getMask();
   if (attnMask) {
     SmallVector<Range> maskSlice =
-        getPermutedSlice(*getMaskMap(), offsets, sizes);
+        getPermutedRange(*getMaskMap(), offsets, sizes);
     Operation *maskSliceOp = getSlice(builder, loc, attnMask, maskSlice);
     tiledOperands.emplace_back(maskSliceOp->getResult(0));
     slices.push_back(maskSliceOp);
@@ -2043,19 +2052,19 @@ OnlineAttentionOp::getTiledImplementation(OpBuilder &builder,
   Location loc = getLoc();
 
   SmallVector<Range> querySlice =
-      getPermutedSlice(getQueryMap(), offsets, sizes);
-  SmallVector<Range> keySlice = getPermutedSlice(getKeyMap(), offsets, sizes);
+      getPermutedRange(getQueryMap(), offsets, sizes);
+  SmallVector<Range> keySlice = getPermutedRange(getKeyMap(), offsets, sizes);
   SmallVector<Range> valueSlice =
-      getPermutedSlice(getValueMap(), offsets, sizes);
+      getPermutedRange(getValueMap(), offsets, sizes);
   std::optional<SmallVector<Range>> maskSlice;
   if (auto maskMap = getMaskMap()) {
-    maskSlice = getPermutedSlice(*maskMap, offsets, sizes);
+    maskSlice = getPermutedRange(*maskMap, offsets, sizes);
   }
   SmallVector<Range> outputSlice =
-      getPermutedSlice(getOutputMap(), offsets, sizes);
-  SmallVector<Range> maxSlice = getPermutedSlice(getMaxMap(), offsets, sizes);
-  SmallVector<Range> sumSlice = getPermutedSlice(getSumMap(), offsets, sizes);
+      getPermutedRange(getOutputMap(), offsets, sizes);
+  SmallVector<Range> maxSlice = getPermutedRange(getMaxMap(), offsets, sizes);
+  SmallVector<Range> sumSlice = getPermutedRange(getSumMap(), offsets, sizes);
 
   Value scale = getScale();
 
@@ -2097,7 +2106,7 @@ OnlineAttentionOp::getTiledImplementation(OpBuilder &builder,
   Value attnMask = getMask();
   if (attnMask) {
     SmallVector<Range> maskSlice =
-        getPermutedSlice(*getMaskMap(), offsets, sizes);
+        getPermutedRange(*getMaskMap(), offsets, sizes);
     Operation *maskSliceOp = getSlice(builder, loc, attnMask, maskSlice);
     tiledOperands.emplace_back(maskSliceOp->getResult(0));
     slices.push_back(maskSliceOp);
@@ -2175,6 +2184,315 @@ LogicalResult OnlineAttentionOp::getResultTilePosition(
   return success();
 }
 
+static AffineMap getPartialResultMap(AffineMap map, AttentionOpDetail &opInfo) {
+  // Append K2 dimensions at the end.
+  for (int dim : opInfo.getK2Dims()) {
+    map = map.insertResult(getAffineDimExpr(dim, map.getContext()),
+                           map.getNumResults());
+  }
+  return map;
+}
+
+FailureOr<SmallVector<Value>>
+OnlineAttentionOp::generateInitialTensorForPartialReduction(
+    OpBuilder &b, Location loc, ArrayRef<OpFoldResult> sizes,
+    ArrayRef<int> reductionDim) {
+  FailureOr<AttentionOpDetail> maybeOpInfo = AttentionOpDetail::get(
+      getQueryMap(), getKeyMap(), getValueMap(), getOutputMap());
+  if (failed(maybeOpInfo)) {
+    return emitOpError("failed to verify op's indexing maps");
+  }
+  AttentionOpDetail &opInfo = maybeOpInfo.value();
+
+  SmallVector<OpFoldResult> shape = llvm::map_to_vector(
+      getIterationDomain(b), [](Range x) { return x.size; });
+
+  SmallVector<OpFoldResult> tiledShape;
+  for (auto [tileSize, dimSize] : llvm::zip_equal(sizes, shape)) {
+    if (isZeroIndex(tileSize)) {
+      tiledShape.push_back(dimSize);
+    } else {
+      tiledShape.push_back(tileSize);
+    }
+  }
+
+  SmallVector<OpFoldResult> accSize = applyPermutationMap<OpFoldResult>(
+      getPartialResultMap(getOutputMap(), opInfo), tiledShape);
+  SmallVector<OpFoldResult> maxSize = applyPermutationMap<OpFoldResult>(
+      getPartialResultMap(getMaxMap(), opInfo), tiledShape);
+  SmallVector<OpFoldResult> sumSize = applyPermutationMap<OpFoldResult>(
+      getPartialResultMap(getSumMap(), opInfo), tiledShape);
+
+  Type accElTy = getElementTypeOrSelf(getOutput().getType());
+  Type maxElTy = getElementTypeOrSelf(getMax().getType());
+  Type sumElTy = getElementTypeOrSelf(getSum().getType());
+
+  Value partialAcc = b.create<tensor::EmptyOp>(loc, accSize, accElTy);
+  Value partialMax = b.create<tensor::EmptyOp>(loc, maxSize, maxElTy);
+  Value partialSum = b.create<tensor::EmptyOp>(loc, sumSize, sumElTy);
+
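+  // The fill values below are the reduction identities: 0 for the accumulator
+  // and the sum (addf), and the largest negative finite value for the running
+  // max (maximumf with useOnlyFiniteValue set).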
+  Value accInit = arith::getIdentityValue(arith::AtomicRMWKind::addf, accElTy,
+                                          b, loc, /*useOnlyFiniteValue=*/true);
+  Value maxInit =
+      arith::getIdentityValue(arith::AtomicRMWKind::maximumf, maxElTy, b, loc,
+                              /*useOnlyFiniteValue=*/true);
+  Value sumInit =
+      arith::getIdentityValue(arith::AtomicRMWKind::addf, sumElTy, b, loc);
+
+  Value accFill =
+      b.create<linalg::FillOp>(loc, ValueRange{accInit}, partialAcc)
+          .getResult(0);
+  Value maxFill =
+      b.create<linalg::FillOp>(loc, ValueRange{maxInit}, partialMax)
+          .getResult(0);
+  Value sumFill =
+      b.create<linalg::FillOp>(loc, ValueRange{sumInit}, partialSum)
+          .getResult(0);
+
+  return SmallVector<Value>{accFill, maxFill, sumFill};
+}
+
+FailureOr<TilingResult> OnlineAttentionOp::tileToPartialReduction(
+    OpBuilder &b, Location loc, ValueRange init, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes, ArrayRef<int> reductionDims) {
+  FailureOr<AttentionOpDetail> maybeOpInfo = AttentionOpDetail::get(
+      getQueryMap(), getKeyMap(), getValueMap(), getOutputMap());
+  if (failed(maybeOpInfo)) {
+    return emitOpError("failed to verify op's indexing maps");
+  }
+  AttentionOpDetail &opInfo = maybeOpInfo.value();
+
+  // Extend result maps, keeping everything else the same.
+  AffineMap partialAccMap = getPartialResultMap(getOutputMap(), opInfo);
+  AffineMap partialMaxMap = getPartialResultMap(getMaxMap(), opInfo);
+  AffineMap partialSumMap = getPartialResultMap(getSumMap(), opInfo);
+
+  SmallVector<AffineMap> indexingMaps = getIndexingMapsArray();
+  indexingMaps[getNumDpsInputs()] = partialAccMap;
+  indexingMaps[getNumDpsInputs() + 1] = partialMaxMap;
+  indexingMaps[getNumDpsInputs() + 2] = partialSumMap;
+
+  SmallVector<Value> tiledOperands;
+  SmallVector<Operation *> slices;
+
+  auto appendSlice = [&](Value val, AffineMap map,
+                         ArrayRef<OpFoldResult> offsets) -> LogicalResult {
+    Operation *sliceOp = getPermutedSlice(b, loc, val, map, offsets, sizes);
+    if (!sliceOp) {
+      return emitOpError("failed to get slice");
+    }
+    tiledOperands.emplace_back(sliceOp->getResult(0));
+    slices.push_back(sliceOp);
+    return success();
+  };
+
+  if (failed(appendSlice(getQuery(), getQueryMap(), offsets))) {
+    return failure();
+  }
+  if (failed(appendSlice(getKey(), getKeyMap(), offsets))) {
+    return failure();
+  }
+  if (failed(appendSlice(getValue(), getValueMap(), offsets))) {
+    return failure();
+  }
+
+  tiledOperands.emplace_back(getScale());
+
+  if (Value mask = getMask()) {
+    if (failed(appendSlice(mask, *getMaskMap(), offsets))) {
+      return failure();
+    }
+  }
+
+  // For results, set the offset of the iterated reduction dims to 0.
+  SmallVector<OpFoldResult> initOffsets(offsets);
+  for (int dim : opInfo.getK2Dims()) {
+    initOffsets[dim] = b.getIndexAttr(0);
+  }
+
+  if (failed(appendSlice(init[0], partialAccMap, initOffsets))) {
+    return failure();
+  }
+  if (failed(appendSlice(init[1], partialMaxMap, initOffsets))) {
+    return failure();
+  }
+  if (failed(appendSlice(init[2], partialSumMap, initOffsets))) {
+    return failure();
+  }
+
+  // Get the initial values.
+  ValueRange slicedInits = ArrayRef(tiledOperands).take_back(3);
+
+  auto tiledOp = cast<OnlineAttentionOp>(
+      mlir::clone(b, getOperation(), slicedInits.getTypes(), tiledOperands));
+  tiledOp.setIndexingMapsAttr(b.getAffineMapArrayAttr(indexingMaps));
+
+  return TilingResult{
+      {tiledOp}, SmallVector<Value>(tiledOp->getResults()), slices};
+}
+
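+// The tiled op above produces partial (acc, max, sum) results that carry the
+// K2 tile as an extra dimension. The helpers below implement the merge that
+// folds it away; schematically, for each K2 slice i:
+//   newMax = max(init, max_i)
+//   newSum = init + exp2(max_i - newMax) * sum_i
+//   newAcc = init + exp2(max_i - newMax) * acc_i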
+template <typename T>
+static linalg::ReduceOp reduceOnK2(OnlineAttentionOp attn, AffineMap partialMap,
+                                   AttentionOpDetail &opInfo, OpBuilder &b,
+                                   Location loc, Value partialResult,
+                                   Value init) {
+  // linalg.reduce's iteration space is the result's iteration space (and
+  // not the operation's iteration space). To account for this, permute the
+  // reduction dimensions based on the partial result map.
+  SmallVector<int64_t> partialReductionDims;
+  for (auto [resultNum, dimExpr] : llvm::enumerate(partialMap.getResults())) {
+    unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
+    if (llvm::find(opInfo.getK2Dims(), dim) != opInfo.getK2Dims().end()) {
+      partialReductionDims.push_back(resultNum);
+    }
+  }
+
+  return b.create<linalg::ReduceOp>(
+      loc, partialResult, init, partialReductionDims,
+      [&](OpBuilder &b, Location loc, ValueRange inputs) {
+        Value reduced = b.create<T>(loc, inputs[0], inputs[1]);
+        b.create<linalg::YieldOp>(loc, reduced);
+      });
+}
+
+template <typename T>
+static Value elementwiseValueInPlace(OpBuilder &builder, Location loc,
+                                     AffineMap inputMap, AffineMap scaleMap,
+                                     Value value, Value scale) {
+  SmallVector<AffineMap> compressedMaps =
+      compressUnusedDims(SmallVector<AffineMap>{inputMap, scaleMap});
+  inputMap = compressedMaps[0];
+  scaleMap = compressedMaps[1];
+
+  SmallVector<utils::IteratorType> iteratorTypes(
+      inputMap.getNumDims(), utils::IteratorType::parallel);
+
+  auto genericOp = builder.create<linalg::GenericOp>(
+      loc, value.getType(), scale, value,
+      SmallVector<AffineMap>{scaleMap, inputMap}, iteratorTypes,
+      [&](OpBuilder &b, Location loc, ValueRange args) {
+        // Convert scale to the same datatype as input.
+        Value scale = convertScalarToDtype(b, loc, args[0], args[1].getType(),
+                                           /*isUnsignedCast=*/false);
+        Value result = b.create<T>(loc, scale, args[1]);
+        b.create<linalg::YieldOp>(loc, result);
+      });
+  return genericOp.getResult(0);
+}
+
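+// Note: the normalization below works in base-2 exponent space (math.exp2
+// rather than exp); the op's definition folds the log2(e) factor into the
+// scale, so this matches exp in the reference attention formula.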
+// Compute output = exp2(output - input)
+static Value computeSubAndExp2(OpBuilder &builder, Location loc,
+                               AffineMap inputMap, AffineMap outputMap,
+                               Value input, Value output) {
+  SmallVector<AffineMap> compressedMaps =
+      compressUnusedDims(SmallVector<AffineMap>{inputMap, outputMap});
+  inputMap = compressedMaps[0];
+  outputMap = compressedMaps[1];
+
+  SmallVector<utils::IteratorType> iteratorTypes(
+      inputMap.getNumDims(), utils::IteratorType::parallel);
+  auto genericOp = builder.create<linalg::GenericOp>(
+      loc, output.getType(), input, output,
+      SmallVector<AffineMap>{inputMap, outputMap}, iteratorTypes,
+      [&](OpBuilder &b, Location loc, ValueRange args) {
+        // Convert input to the same datatype as output.
+        Value in = convertScalarToDtype(b, loc, args[0], args[1].getType(),
+                                        /*isUnsignedCast=*/false);
+        Value diff = b.create<arith::SubFOp>(loc, args[1], in);
+        Value weight = b.create<math::Exp2Op>(loc, diff);
+        b.create<linalg::YieldOp>(loc, weight);
+      });
+  return genericOp.getResult(0);
+}
+
+FailureOr<MergeResult>
+OnlineAttentionOp::mergeReductions(OpBuilder &b, Location loc,
+                                   ValueRange partialReduce,
+                                   ArrayRef<int> reductionDim) {
+  FailureOr<AttentionOpDetail> maybeOpInfo = AttentionOpDetail::get(
+      getQueryMap(), getKeyMap(), getValueMap(), getOutputMap());
+  if (failed(maybeOpInfo)) {
+    return emitOpError("failed to verify op's indexing maps");
+  }
+  AttentionOpDetail &opInfo = maybeOpInfo.value();
+
+  AffineMap partialAccMap = getPartialResultMap(getOutputMap(), opInfo);
+  AffineMap partialMaxMap = getPartialResultMap(getMaxMap(), opInfo);
+  AffineMap partialSumMap = getPartialResultMap(getSumMap(), opInfo);
+
+  // newMax = max(maxInit, rowMax(partialMax))
+  linalg::ReduceOp reducedMax = reduceOnK2<arith::MaximumFOp>(
+      *this, partialMaxMap, opInfo, b, loc, partialReduce[1], getMax());
+
+  // norm = exp2(partialMax - newMax)
+  Value norm = computeSubAndExp2(b, loc, getMaxMap(), partialMaxMap,
+                                 reducedMax.getResult(0), partialReduce[1]);
+
+  // normSum = norm * partialSum
+  Value normSum = elementwiseValueInPlace<arith::MulFOp>(
+      b, loc, partialSumMap, partialMaxMap, partialReduce[2], norm);
+
+  // newSum = sumInit + rowSum(normSum)
+  linalg::ReduceOp reducedSum = reduceOnK2<arith::AddFOp>(
+      *this, partialSumMap, opInfo, b, loc, normSum, getSum());
+
+  // normAcc = norm * partialAcc
+  Value normAcc = elementwiseValueInPlace<arith::MulFOp>(
+      b, loc, partialAccMap, partialMaxMap, partialReduce[0], norm);
+
+  // newAcc = accInit + rowSum(normAcc)
+  linalg::ReduceOp reducedAcc = reduceOnK2<arith::AddFOp>(
+      *this, partialAccMap, opInfo, b, loc, normAcc, getOutput());
+
+  return MergeResult{{reducedAcc, reducedMax, reducedSum},
+                     {reducedAcc.getResult(0), reducedMax.getResult(0),
+                      reducedSum.getResult(0)}};
+}
+
+LogicalResult OnlineAttentionOp::getPartialResultTilePosition(
+    OpBuilder &b, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
+    SmallVector<OpFoldResult> &resultSizes, ArrayRef<int> reductionDims) {
+
+  FailureOr<AttentionOpDetail> maybeOpInfo = AttentionOpDetail::get(
+      getQueryMap(), getKeyMap(), getValueMap(), getOutputMap());
+  if (failed(maybeOpInfo)) {
+    return emitOpError("failed to verify op's indexing maps");
+  }
+  AttentionOpDetail &opInfo = maybeOpInfo.value();
+
+  resultOffsets.clear();
+  resultSizes.clear();
+
+  AffineMap resultIndexingMap;
+  switch (resultNumber) {
+  case 0:
+    resultIndexingMap = getOutputMap();
+    break;
+  case 1:
+    resultIndexingMap = getMaxMap();
+    break;
+  case 2:
+    resultIndexingMap = getSumMap();
+    break;
+  default:
+    return failure();
+  }
+
+  AffineMap partialMap = getPartialResultMap(resultIndexingMap, opInfo);
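+  // For example: with outputMap (b, m, k1, k2, n) -> (b, m, n) and K2 = {k2},
+  // partialMap is (b, m, n, k2). Parallel dims keep the tile's offset/size;
+  // the K2 dims of the partial result always start at offset 0.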
+
+  for (AffineExpr dimExpr : partialMap.getResults()) {
+    int dim = cast<AffineDimExpr>(dimExpr).getPosition();
+    resultSizes.push_back(sizes[dim]);
+
+    if (llvm::find(opInfo.getK2Dims(), dim) != opInfo.getK2Dims().end()) {
+      // Reduction dims are reduced, and are always output in the same
+      // place, so use offset 0 for them.
+      resultOffsets.push_back(b.getIndexAttr(0));
+    } else {
+      resultOffsets.push_back(offsets[dim]);
+    }
+  }
+  return success();
+}
+
 //===---------------------------------------------------------------------===//
 // CustomOp
 //===---------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir
index a6baca15a877..6d2f05793302 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir
@@ -2057,6 +2057,130 @@ module attributes { transform.with_named_sequence } {
   }
 }
 
+
+// -----
+
+#mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)>
+#mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)>
+#mapV = affine_map<(batch, m, k1, k2, n) -> (batch, k2, n)>
+#mapS = affine_map<(batch, m, k1, k2, n) -> ()>
+#mapO = affine_map<(batch, m, k1, k2, n) -> (batch, m, n)>
+#mapR = affine_map<(batch, m, k1, k2, n) -> (batch, m)>
+
+func.func @online_attention_partial_reduction(%query: tensor<192x?x64xf32>, %key: tensor<192x?x64xf32>, %value: tensor<192x?x64xf32>) -> (tensor<192x?x64xf32>, tensor<192x?xf32>) {
+  %scale = arith.constant 1.0 : f32
+
+  %c1 = arith.constant 1 : index
+
+  %m = tensor.dim %query, %c1 : tensor<192x?x64xf32>
+  %k2 = tensor.dim %key, %c1 : tensor<192x?x64xf32>
+
+  %output_empty = tensor.empty(%m) : tensor<192x?x64xf32>
+  %row_red_empty = tensor.empty(%m) : tensor<192x?xf32>
+
+  %sum_ident = arith.constant 0.000000e+00 : f32
+  %max_ident = arith.constant -3.40282347E+38 : f32
+
+  %output_fill = linalg.fill ins(%sum_ident : f32) outs(%output_empty : tensor<192x?x64xf32>) -> tensor<192x?x64xf32>
+  %acc_fill = linalg.fill ins(%max_ident : f32) outs(%row_red_empty : tensor<192x?xf32>) -> tensor<192x?xf32>
+  %sum_fill = linalg.fill ins(%sum_ident : f32) outs(%row_red_empty : tensor<192x?xf32>) -> tensor<192x?xf32>
+
+  %out:3 = iree_linalg_ext.online_attention
+        { indexing_maps = [#mapQ, #mapK, #mapV, #mapS, #mapO, #mapR, #mapR] }
+        ins(%query, %key, %value, %scale : tensor<192x?x64xf32>, tensor<192x?x64xf32>, tensor<192x?x64xf32>, f32)
+        outs(%output_fill, %acc_fill, %sum_fill : tensor<192x?x64xf32>, tensor<192x?xf32>, tensor<192x?xf32>) {
+    ^bb0(%score: f32):
+      iree_linalg_ext.yield %score: f32
+  }
+        -> tensor<192x?x64xf32>, tensor<192x?xf32>, tensor<192x?xf32>
+
+  return %out#0, %out#2 : tensor<192x?x64xf32>, tensor<192x?xf32>
+}
+
+// CHECK-LABEL: func.func @online_attention_partial_reduction
+// CHECK-SAME: (%[[Q:.+]]: tensor<192x?x64xf32>, %[[K:.+]]: tensor<192x?x64xf32>, %[[V:.+]]: tensor<192x?x64xf32>)
+
+// CHECK-DAG: %[[M:.+]] = tensor.dim %[[Q]], %c1 : tensor<192x?x64xf32>
+// CHECK-DAG: %[[K2:.+]] = tensor.dim %[[K]], %c1 : tensor<192x?x64xf32>
+
+// CHECK-DAG: %[[OUT_E:.+]] = tensor.empty(%[[M]]) : tensor<192x?x64xf32>
+// CHECK-DAG: %[[RED_E:.+]] = tensor.empty(%[[M]]) : tensor<192x?xf32>
+
+// CHECK-DAG: %[[SUM_INIT:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[MAX_INIT:.+]] = arith.constant -3.40282347E+38 : f32
+
+// CHECK-DAG: %[[OUT:.+]] = linalg.fill ins(%[[SUM_INIT]] : f32) outs(%[[OUT_E]] : tensor<192x?x64xf32>) -> tensor<192x?x64xf32>
+// CHECK-DAG: %[[MAX:.+]] = linalg.fill ins(%[[MAX_INIT]] : f32) outs(%[[RED_E]] : tensor<192x?xf32>) -> tensor<192x?xf32>
+// CHECK-DAG: %[[SUM:.+]] = linalg.fill ins(%[[SUM_INIT]] : f32) outs(%[[RED_E]] : tensor<192x?xf32>) -> tensor<192x?xf32>
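+
+// The merged results keep the original shapes; the scf.for below iterates
+// partial tensors that carry the K2 tile (32) as an extra trailing dimension:
+// tensor<192x?x64x32xf32> for the accumulator and tensor<192x?x32xf32> for
+// the running max and sum.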
+
+// CHECK-DAG: %[[OUT_PART_E:.+]] = tensor.empty(%[[M]]) : tensor<192x?x64x32xf32>
+// CHECK-DAG: %[[RED_PART_E:.+]] = tensor.empty(%[[M]]) : tensor<192x?x32xf32>
+
+// CHECK-DAG: %[[OUT_PART:.+]] = linalg.fill ins(%[[SUM_INIT]] : f32) outs(%[[OUT_PART_E]] : tensor<192x?x64x32xf32>) -> tensor<192x?x64x32xf32>
+// CHECK-DAG: %[[MAX_PART:.+]] = linalg.fill ins(%[[MAX_INIT]] : f32) outs(%[[RED_PART_E]] : tensor<192x?x32xf32>) -> tensor<192x?x32xf32>
+// CHECK-DAG: %[[SUM_PART:.+]] = linalg.fill ins(%[[SUM_INIT]] : f32) outs(%[[RED_PART_E]] : tensor<192x?x32xf32>) -> tensor<192x?x32xf32>
+
+// CHECK: %[[ITER:.+]]:3 = scf.for %[[IV:.+]] = %c0 to %[[K2]] step %c32
+// CHECK-SAME: iter_args(%[[OUT_ITER:.+]] = %[[OUT_PART]], %[[MAX_ITER:.+]] = %[[MAX_PART]], %[[SUM_ITER:.+]] = %[[SUM_PART]])
+// CHECK: %[[MIN:.+]] = affine.min
+// CHECK: %[[Q_SLICE:.+]] = tensor.extract_slice %[[Q]][0, 0, 0] [192, %[[M]], 64] [1, 1, 1] : tensor<192x?x64xf32> to tensor<192x?x64xf32>
+// CHECK: %[[K_SLICE:.+]] = tensor.extract_slice %[[K]][0, %[[IV]], 0] [192, %[[MIN]], 64] [1, 1, 1] : tensor<192x?x64xf32> to tensor<192x?x64xf32>
+// CHECK: %[[V_SLICE:.+]] = tensor.extract_slice %[[V]][0, %[[IV]], 0] [192, %[[MIN]], 64] [1, 1, 1] : tensor<192x?x64xf32> to tensor<192x?x64xf32>
+
+// CHECK: %[[OUT_SLICE:.+]] = tensor.extract_slice %[[OUT_ITER]][0, 0, 0, 0] [192, %[[M]], 64, %[[MIN]]] [1, 1, 1, 1] : tensor<192x?x64x32xf32> to tensor<192x?x64x?xf32>
+// CHECK: %[[MAX_SLICE:.+]] = tensor.extract_slice %[[MAX_ITER]][0, 0, 0] [192, %[[M]], %[[MIN]]] [1, 1, 1] : tensor<192x?x32xf32> to tensor<192x?x?xf32>
+// CHECK: %[[SUM_SLICE:.+]] = tensor.extract_slice %[[SUM_ITER]][0, 0, 0] [192, %[[M]], %[[MIN]]] [1, 1, 1] : tensor<192x?x32xf32> to tensor<192x?x?xf32>
+
+// CHECK: %[[OATT:.+]]:3 = iree_linalg_ext.online_attention
+// CHECK-SAME: ins(%[[Q_SLICE]], %[[K_SLICE]], %[[V_SLICE]]
+// CHECK-SAME: outs(%[[OUT_SLICE]], %[[MAX_SLICE]], %[[SUM_SLICE]]
+
+// CHECK: %[[OUT_NEXT:.+]] = tensor.insert_slice %[[OATT]]#0 into %[[OUT_ITER]][0, 0, 0, 0] [192, %[[M]], 64, %[[MIN]]] [1, 1, 1, 1] : tensor<192x?x64x?xf32> into tensor<192x?x64x32xf32>
+// CHECK: %[[MAX_NEXT:.+]] = tensor.insert_slice %[[OATT]]#1 into %[[MAX_ITER]][0, 0, 0] [192, %[[M]], %[[MIN]]] [1, 1, 1] : tensor<192x?x?xf32> into tensor<192x?x32xf32>
+// CHECK: %[[SUM_NEXT:.+]] = tensor.insert_slice %[[OATT]]#2 into %[[SUM_ITER]][0, 0, 0] [192, %[[M]], %[[MIN]]] [1, 1, 1] : tensor<192x?x?xf32> into tensor<192x?x32xf32>
+
+// CHECK: scf.yield %[[OUT_NEXT]], %[[MAX_NEXT]], %[[SUM_NEXT]]
+
+// CHECK: %[[MAX_RED:.+]] = linalg.reduce ins(%[[ITER]]#1
+// CHECK-SAME: dimensions = [2]
+// CHECK: arith.maximumf
+// CHECK: linalg.yield
+
+// CHECK: %[[NORM:.+]] = linalg.generic
+// CHECK: arith.subf
+// CHECK: math.exp2
+// CHECK: linalg.yield
+
+// CHECK: %[[NORM_SUM:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[NORM]]
+// CHECK-SAME: outs(%[[ITER]]#2
+// CHECK: arith.mulf
+// CHECK: linalg.yield
+
+// CHECK: %[[SUM_RED:.+]] = linalg.reduce ins(%[[NORM_SUM]]
+// CHECK-SAME: dimensions = [2]
+// CHECK: arith.addf
+// CHECK: linalg.yield
+
+// CHECK: %[[NORM_ACC:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[NORM]]
+// CHECK-SAME: outs(%[[ITER]]#0
+// CHECK: arith.mulf
+// CHECK: linalg.yield
+
+// CHECK: %[[ACC_RED:.+]] = linalg.reduce ins(%[[NORM_ACC]]
+// CHECK-SAME: dimensions = [3]
+// CHECK: arith.addf
+// CHECK: linalg.yield
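+
+// The transform script below tiles only the k2 dimension (tile size 32, the
+// fourth entry of tile_sizes), producing the partial-reduction scf.for loop
+// and the merge ops checked above.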
+
+module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["iree_linalg_ext.online_attention"]} in %module_op : (!transform.any_op) -> !transform.any_op
+    %fill_op:3, %split, %merge:3, %forop = transform.structured.tile_reduction_using_for %0 by tile_sizes = [0, 0, 0, 32, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
 // -----
 
 #mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)>