Skip to content

Commit

Permalink
Cleanup of krnl iterate loops (#2953)
Browse files Browse the repository at this point in the history
Signed-off-by: Alexandre Eichenberger <alexe@us.ibm.com>
Co-authored-by: Tung D. Le <tung@jp.ibm.com>
  • Loading branch information
AlexandreEichenberger and tungld authored Sep 25, 2024
1 parent bb179d7 commit 40b607d
Show file tree
Hide file tree
Showing 65 changed files with 291 additions and 316 deletions.
10 changes: 5 additions & 5 deletions docs/LoweringCode.md
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ struct KrnlBuilder : public DialectBuilder {

void iterate(ValueRange originalLoops, ValueRange optimizedLoops,
ValueRange lbs, ValueRange ubs,
function_ref<void(KrnlBuilder &createKrnl, ValueRange indices)>
function_ref<void(const KrnlBuilder &createKrnl, ValueRange indices)>
bodyBuilderFn);
};
```
Expand All @@ -128,7 +128,7 @@ ValueRange loopDef = createKrnl.defineLoops(2);
// Create the loop.
createKrnl.iterate(loopDef, loopDef, {zero, zero}, {ub0, ub1},
[&](KrnlBuilder &createKrnl, ValueRange loopInd){
[&](const KrnlBuilder &createKrnl, ValueRange loopInd){
// Loop body.
createKrnl.store(zero, array, loopInd);
});
Expand Down Expand Up @@ -183,7 +183,7 @@ ValueRange loopBlockDef = createKrnl.block(loopDef, 4);
createKrnl.permute({loopBlockDef[0], loopBlockDef[1]}, {0,1});
// Create the loop iterating over the blocks.
createKrnl.iterate(loopDef, {loopBlockDef[0], loopBlockDef[1]}, {zero}, {ub0},
[&](KrnlBuilder &createKrnl, ValueRange blockLoopInd){
[&](const KrnlBuilder &createKrnl, ValueRange blockLoopInd){
// Loop body.
  createKrnl.store(zero, array, blockLoopInd);
});
Expand All @@ -209,10 +209,10 @@ We now consider tiling our original 2-dimensional example below.
// Create the loop iterating over the blocks.
createKrnl.iterate(loopDef, {outerLoopBlockDef[0], innerLoopBlockDef[0]},
{zero, zero}, {ub0, ub1},
[&](KrnlBuilder &createKrnl, ValueRange blockLoopInd){
[&](const KrnlBuilder &createKrnl, ValueRange blockLoopInd){
// Create the loop iterating inside the blocks.
createKrnl.iterate({}, {outerLoopBlockDef[1], innerLoopBlockDef[1]},
{}, {}, [&](KrnlBuilder &createKrnl, ValueRange loopInd) {
{}, {}, [&](const KrnlBuilder &createKrnl, ValueRange loopInd) {
// Loop body.
createKrnl.store(zero, array, loopInd);
});
Expand Down
23 changes: 10 additions & 13 deletions src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1315,7 +1315,7 @@ struct ZHighToZLowFixGRUYOpLowering : public ConversionPattern {
Value iZero = create.math.constantIndex(0);
ValueRange batchLoop = create.krnl.defineLoops(1);
create.krnl.iterate(batchLoop, batchLoop, {iZero}, {create.mem.dim(Y, 2)},
[&](KrnlBuilder &createKrnl, ValueRange batchIndices) {
[&](const KrnlBuilder &createKrnl, ValueRange batchIndices) {
MathBuilder createMath(createKrnl);
IndexExprScope ieScope(createKrnl);
Value bs = batchIndices[0];
Expand All @@ -1338,7 +1338,7 @@ struct ZHighToZLowFixGRUYOpLowering : public ConversionPattern {
rewriter.setInsertionPointToStart(&regionOp.getBodyRegion().front());
ValueRange loops = create.krnl.defineLoops(yRank - 1);
create.krnl.iterate(loops, loops, yLbs, yUbs,
[&](KrnlBuilder &createKrnl, ValueRange indices) {
[&](const KrnlBuilder &createKrnl, ValueRange indices) {
Value sequenceIV(indices[0]);
Value directionIV(indices[1]);
Value hs(indices[2]);
Expand Down Expand Up @@ -1366,7 +1366,7 @@ struct ZHighToZLowFixGRUYOpLowering : public ConversionPattern {

ValueRange loops = create.krnl.defineLoops(yRank);
create.krnl.iterate(loops, loops, yLbs, yUbs,
[&](KrnlBuilder &createKrnl, ValueRange indices) {
[&](const KrnlBuilder &createKrnl, ValueRange indices) {
MathBuilder createMath(createKrnl);
IndexExprScope ieScope(createKrnl);
Value sequenceIV(indices[0]);
Expand Down Expand Up @@ -1435,7 +1435,7 @@ struct ZHighToZLowFixGRUYhOpLowering : public ConversionPattern {
Value seqSize = create.mem.dim(Y, 0);
ValueRange loops = create.krnl.defineLoops(htRank);
create.krnl.iterate(loops, loops, htLbs, htUbs,
[&](KrnlBuilder &createKrnl, ValueRange indices) {
[&](const KrnlBuilder &createKrnl, ValueRange indices) {
MathBuilder createMath(createKrnl);
IndexExprScope ieScope(createKrnl);
Value bs(indices[1]), hs(indices[2]);
Expand Down Expand Up @@ -1612,7 +1612,7 @@ struct ZHighToZLowStickifiedConstantOfShapeOpLowering
SmallVector<IndexExpr, 4> lbs(rank, LitIE(0));
SmallVector<IndexExpr, 4> ubs = shapeHelper.getOutputDims();
create.krnl.iterateIE(loopDef, loopDef, lbs, ubs,
[&](KrnlBuilder &createKrnl, ValueRange indices) {
[&](const KrnlBuilder &createKrnl, ValueRange indices) {
// Keep this load inside the loop to tweak LLVM.
Value valueF16 = createKrnl.load(memrefF16);
createKrnl.store(valueF16, res, indices);
Expand Down Expand Up @@ -1701,13 +1701,10 @@ struct ZHighToZLowDataConversionLowering
SmallVector<IndexExpr, 4> flattenedOutputDims;
Value flatOutput = create.mem.reshapeToFlatInnermost(
alloc, outputDims, flattenedOutputDims, collapsedInnermostLoops);
DimsExpr lbs(1, LitIE(0));

// Create loop iteration (flattened to 1D) and block it by totVL.
ValueRange loopDef = create.krnl.defineLoops(1);
ValueRange blockedLoopDef = create.krnl.block(loopDef[0], totVL);
SmallVector<Value, 1> optimizedLoopDef(1, blockedLoopDef[0]);

DimsExpr lbs = {LitIE(0)};
bool useParallel = false;
if (enableParallel) {
int64_t parId;
int64_t tripCount = flattenedOutputDims[0].isLiteral()
Expand All @@ -1716,7 +1713,7 @@ struct ZHighToZLowDataConversionLowering
: -1;
if (findSuitableParallelDimension(lbs, flattenedOutputDims, 0, 1, parId,
/*min iter for going parallel*/ 1024)) {
create.krnl.parallel(blockedLoopDef[0]);
useParallel = true;
onnxToKrnlParallelReport(op, /*successful*/ true, 0, tripCount,
"dlf16-f32 conversion fully parallelized");
} else {
Expand All @@ -1729,8 +1726,8 @@ struct ZHighToZLowDataConversionLowering
: -1,
"dlf16-f32 conversion fully flattened");

create.krnl.iterateIE(loopDef, optimizedLoopDef, lbs, flattenedOutputDims,
[&](KrnlBuilder &b, ValueRange loopInd) {
create.krnl.forLoopIE(lbs[0], flattenedOutputDims[0], totVL, useParallel,
[&](const KrnlBuilder &b, ValueRange loopInd) {
MDBuilder create(b);
// Manually unrolled loop, add archVL offset at each iterations.
for (int64_t u = 0; u < unrollVL; ++u) {
Expand Down
14 changes: 7 additions & 7 deletions src/Accelerators/NNPA/Transform/ZLow/ZLowStickExpansion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,8 @@ class UnstickExpansionPattern : public OpRewritePattern<ZLowUnstickOp> {
create.mem.reinterpretCast(input, litZero.getValue(), reallocTileDims);

// Outer loop (E4, E3, E2, E1 iterates over tiles of 64 elements)
create.krnl.iterateIE(
loopDefs, loopDefs, lbs, ubs, [&](KrnlBuilder &b, ValueRange loopInd) {
create.krnl.iterateIE(loopDefs, loopDefs, lbs, ubs,
[&](const KrnlBuilder &b, ValueRange loopInd) {
MDBuilder create(b);
IndexExprScope outerScope(create.krnl, &allocScope);
DimsExpr outerIndices = DimListIE(loopInd);
Expand Down Expand Up @@ -192,14 +192,14 @@ class UnstickExpansionPattern : public OpRewritePattern<ZLowUnstickOp> {
// Condition
isFullLogical.getValue(),
// Then (is full).
[&](SCFBuilder b) {
[&](const SCFBuilder b) {
MDBuilder create(b);
// Loop (tried unroll of 2 and 8, 4 was best).
const int64_t unrollVL = 4;
const int64_t totVL = unrollVL * archVL;
assert(totVL <= 64 && "bad unroll");
create.scf.forLoop(litZero.getValue(), lit64.getValue(), totVL,
[&](SCFBuilder b, ValueRange loopInd) {
[&](const SCFBuilder b, ValueRange loopInd) {
MDBuilder create(b);
IndexExprScope innerScope(b, &outerScope);
Value loopIndex = loopInd[0];
Expand Down Expand Up @@ -430,8 +430,8 @@ class StickExpansionPattern : public OpRewritePattern<ZLowStickOp> {
create.mem.reinterpretCast(alloc, litZero.getValue(), reallocTileDims);

// Outer loop (E1 iterates over tiles of 64 elements).
create.krnl.iterateIE(
loopDefs, loopDefs, lbs, ubs, [&](KrnlBuilder &b, ValueRange loopInd) {
create.krnl.iterateIE(loopDefs, loopDefs, lbs, ubs,
[&](const KrnlBuilder &b, ValueRange loopInd) {
MDBuilder create(b);
IndexExprScope outerScope(create.krnl, &allocScope);
DimsExpr outerIndices;
Expand All @@ -458,7 +458,7 @@ class StickExpansionPattern : public OpRewritePattern<ZLowStickOp> {
#endif

create.affine.forLoopIE(litZero, simdLoopUB, totVL,
[&](AffineBuilder &b, ValueRange loopInd) {
[&](const AffineBuilder &b, ValueRange loopInd) {
MDBuilder create(b);
DimsExpr inputAF;
IndexExprScope innerScope(create.krnl, &outerScope);
Expand Down
4 changes: 2 additions & 2 deletions src/Conversion/KrnlToAffine/KrnlCopyFromBuffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ class KrnlCopyFromBufferLowering : public ConversionPattern {
return success();
}

void genCopyLoops(AffineBuilderKrnlMem &createAffine,
void genCopyLoops(const AffineBuilderKrnlMem &createAffine,
IndexExprScope *enclosingScope, Value buffMemref, Value destMemref,
IndexExpr zeroIE, SmallVectorImpl<IndexExpr> &starts,
SmallVectorImpl<IndexExpr> &writeUBs, SmallVectorImpl<Value> &loopIndices,
Expand Down Expand Up @@ -125,7 +125,7 @@ class KrnlCopyFromBufferLowering : public ConversionPattern {
} else {
// Loop to copy the data.
createAffine.forLoopIE(zeroIE, writeUBs[i], 1, false /*parallel*/,
[&](AffineBuilderKrnlMem &createAffine, ValueRange loopInd) {
[&](const AffineBuilderKrnlMem &createAffine, ValueRange loopInd) {
loopIndices.emplace_back(loopInd[0]);
genCopyLoops(createAffine, enclosingScope, buffMemref, destMemref,
zeroIE, starts, writeUBs, loopIndices, i + 1, buffRank);
Expand Down
6 changes: 3 additions & 3 deletions src/Conversion/KrnlToAffine/KrnlCopyToBuffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ class KrnlCopyToBufferLowering : public ConversionPattern {
return success();
}

void genCopyLoops(AffineBuilderKrnlMem &createAffine,
void genCopyLoops(const AffineBuilderKrnlMem &createAffine,
IndexExprScope *enclosingScope, Value buffMemref, Value sourceMemref,
SmallVectorImpl<int64_t> &srcLoopMap, Value padVal, IndexExpr zeroIE,
SmallVectorImpl<IndexExpr> &starts, SmallVectorImpl<IndexExpr> &readUBs,
Expand Down Expand Up @@ -169,7 +169,7 @@ class KrnlCopyToBufferLowering : public ConversionPattern {
// Nothing to read, skip.
} else {
createAffine.forLoopIE(zeroIE, readUBs[i], 1,
[&](AffineBuilderKrnlMem &createAffine, ValueRange loopInd) {
[&](const AffineBuilderKrnlMem &createAffine, ValueRange loopInd) {
loopIndices.emplace_back(loopInd[0]);
genCopyLoops(createAffine, enclosingScope, buffMemref,
sourceMemref, srcLoopMap, padVal, zeroIE, starts, readUBs,
Expand All @@ -182,7 +182,7 @@ class KrnlCopyToBufferLowering : public ConversionPattern {
// No padding needed.
} else {
createAffine.forLoopIE(readUBs[i], padUBs[i], 1,
[&](AffineBuilderKrnlMem &createAffine, ValueRange loopInd) {
[&](const AffineBuilderKrnlMem &createAffine, ValueRange loopInd) {
loopIndices.emplace_back(loopInd[0]);
genCopyLoops(createAffine, enclosingScope, buffMemref,
sourceMemref, srcLoopMap, padVal, zeroIE, starts, readUBs,
Expand Down
38 changes: 21 additions & 17 deletions src/Conversion/KrnlToAffine/KrnlMatmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -223,30 +223,32 @@ class KrnlMatmulLowering : public ConversionPattern {
if (matVectorProduct) {
// clang-format off
create.affineKMem.ifThenElseIE(indexScope, allFullTiles,
/* then full tiles */ [&](AffineBuilderKrnlMem &createAffine) {
/* then full tiles */ [&](const AffineBuilderKrnlMem &createAffine) {
genSimdMatVect(createAffine, matmulOp, elementType, aStart, bStart,
cStart, iComputeTileSize, jComputeTileSize, kComputeTileSize,
vectorLen, fullUnrollAndJam);
}, /* else has partial tiles */ [&](AffineBuilderKrnlMem &createAffine) {
}, /* else has partial tiles */ [&](const AffineBuilderKrnlMem &createAffine) {
genScalar(createAffine, matmulOp, elementType, aStart, bStart, cStart,
iTrip, jTrip, kTrip, /*unroll*/ false);
});
// clang-format on
} else {
// clang-format off
create.affineKMem.ifThenElseIE(indexScope, allFullTiles,
/* then full tiles */ [&](AffineBuilderKrnlMem &createAffine) {
/* then full tiles */ [&](const AffineBuilderKrnlMem &createAffine) {
genSimdMatMat(createAffine, matmulOp, elementType, aStart, bStart,
cStart, iComputeTileSize, jComputeTileSize, kComputeTileSize,
vectorLen, fullUnrollAndJam);
}, /* has some partial tiles */ [&](AffineBuilderKrnlMem &createAffine) {
},
/* Else has some partial tiles */
[&](const AffineBuilderKrnlMem &createAffine) {
// Trip regardless of full/partial for N & K
// Test if SIMD dim (M) is full.
createAffine.ifThenElseIE(indexScope, jFullTiles,
/* full SIMD */ [&](AffineBuilderKrnlMem &createAffine) {
/* full SIMD */ [&](const AffineBuilderKrnlMem &createAffine) {
genSimdMatMat(createAffine, matmulOp, elementType, aStart, bStart,
cStart, iTrip, jComputeTileSize, kTrip, vectorLen, /*unroll*/ false);
}, /* else partial SIMD */ [&](AffineBuilderKrnlMem &createAffine) {
}, /* else partial SIMD */ [&](const AffineBuilderKrnlMem &createAffine) {
// TODO: evaluate if get performance from partial SIMD
if (false && jPartialTrip.isLiteral() && jPartialTrip.getLiteral() >=2) {
// has a known trip count along the simd dimension of at least 2
Expand All @@ -265,11 +267,11 @@ class KrnlMatmulLowering : public ConversionPattern {
// Scalar code generator.
// clang-format off
create.affineKMem.ifThenElseIE(indexScope, allFullTiles,
/* then full */ [&](AffineBuilderKrnlMem &createAffine) {
/* then full */ [&](const AffineBuilderKrnlMem &createAffine) {
genScalar(createAffine, matmulOp, elementType, aStart, bStart, cStart,
iComputeTileSize, jComputeTileSize, kComputeTileSize,
fullUnrollAndJam);
}, /* else partial */ [&](AffineBuilderKrnlMem &createAffine) {
}, /* else partial */ [&](const AffineBuilderKrnlMem &createAffine) {
genScalar(createAffine, matmulOp, elementType, aStart, bStart, cStart,
iTrip, jTrip, kTrip, false);
});
Expand All @@ -280,7 +282,7 @@ class KrnlMatmulLowering : public ConversionPattern {
}

private:
void genScalar(AffineBuilderKrnlMem &createAffine, KrnlMatMulOp op,
void genScalar(const AffineBuilderKrnlMem &createAffine, KrnlMatMulOp op,
Type elementType, ArrayRef<IndexExpr> aStart, ArrayRef<IndexExpr> bStart,
ArrayRef<IndexExpr> cStart, IndexExpr I, IndexExpr J, IndexExpr K,
bool unrollJam) const {
Expand All @@ -300,10 +302,11 @@ class KrnlMatmulLowering : public ConversionPattern {
LiteralIndexExpr zeroIE(0);
Value jSaved;
createAffine.forLoopIE(zeroIE, I, 1,
[&](AffineBuilderKrnlMem &createAffine, ValueRange loopInd) {
[&](const AffineBuilderKrnlMem &createAffine, ValueRange loopInd) {
Value i = loopInd[0];
createAffine.forLoopIE(zeroIE, J, 1,
[&](AffineBuilderKrnlMem &createAffine, ValueRange loopInd) {
[&](const AffineBuilderKrnlMem &createAffine,
ValueRange loopInd) {
MathBuilder createMath(createAffine);
Value j = loopInd[0];
// Defines induction variables, and possibly initialize C.
Expand All @@ -315,7 +318,7 @@ class KrnlMatmulLowering : public ConversionPattern {
createAffine.store(initVal, TmpC, tmpCAccess);
// Sum over k.
createAffine.forLoopIE(zeroIE, K, 1,
[&](AffineBuilderKrnlMem &createAffine,
[&](const AffineBuilderKrnlMem &createAffine,
ValueRange loopInd) {
MathBuilder createMath(createAffine);
Value k = loopInd[0];
Expand All @@ -340,7 +343,7 @@ class KrnlMatmulLowering : public ConversionPattern {
}

// Initially, simdize with full K vector length.
void genSimdMatVect(AffineBuilderKrnlMem &createAffine, KrnlMatMulOp op,
void genSimdMatVect(const AffineBuilderKrnlMem &createAffine, KrnlMatMulOp op,
Type elementType, ArrayRef<IndexExpr> aStart, ArrayRef<IndexExpr> bStart,
ArrayRef<IndexExpr> cStart, IndexExpr I, IndexExpr J, IndexExpr K,
IndexExpr vectorLen, bool unrollJam) const {
Expand Down Expand Up @@ -384,7 +387,7 @@ class KrnlMatmulLowering : public ConversionPattern {
Value iZero = create.math.constantIndex(0);

create.affineKMem.forLoopIE(zeroIE, K, VL,
[&](AffineBuilderKrnlMem &createAffine, ValueRange loopInd) {
[&](const AffineBuilderKrnlMem &createAffine, ValueRange loopInd) {
MultiDialectBuilder<MathBuilder, VectorBuilder> create(createAffine);
Value k = loopInd[0];
// Iterates over the I indices (K is SIMD dim).
Expand Down Expand Up @@ -431,7 +434,7 @@ class KrnlMatmulLowering : public ConversionPattern {
}

// Simdize along J / memory rows in B and C.
void genSimdMatMat(AffineBuilderKrnlMem &createAffine, KrnlMatMulOp op,
void genSimdMatMat(const AffineBuilderKrnlMem &createAffine, KrnlMatMulOp op,
Type elementType, ArrayRef<IndexExpr> aStart, ArrayRef<IndexExpr> bStart,
ArrayRef<IndexExpr> cStart, IndexExpr I, IndexExpr J, IndexExpr K,
IndexExpr vectorLen, bool unrollJam) const {
Expand Down Expand Up @@ -466,7 +469,7 @@ class KrnlMatmulLowering : public ConversionPattern {
Value iZero = create.math.constantIndex(0);

createAffine.forLoopIE(zeroIE, I, 1,
[&](AffineBuilderKrnlMem &createAffine, ValueRange loopInd) {
[&](const AffineBuilderKrnlMem &createAffine, ValueRange loopInd) {
MultiDialectBuilder<MathBuilder, VectorBuilder> create(createAffine);
Value i = loopInd[0];
iSaved = i; // Saved for unroll and jam.
Expand All @@ -476,7 +479,8 @@ class KrnlMatmulLowering : public ConversionPattern {
createAffine.store(initVal, TmpC, tmpCAccess);
// Sum over k.
createAffine.forLoopIE(zeroIE, K, 1,
[&](AffineBuilderKrnlMem &createAffine, ValueRange loopInd) {
[&](const AffineBuilderKrnlMem &createAffine,
ValueRange loopInd) {
MultiDialectBuilder<MathBuilder, VectorBuilder> create(
createAffine);
Value k = loopInd[0];
Expand Down
2 changes: 1 addition & 1 deletion src/Conversion/KrnlToAffine/KrnlMemset.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class KrnlMemsetLowering : public ConversionPattern {
SmallVector<int64_t, 4> steps(rank, 1);
// Copy data,
create.affineKMem.forLoopsIE(lbs, ubs, steps,
[&](AffineBuilderKrnlMem &createAffine, ValueRange indices) {
[&](const AffineBuilderKrnlMem &createAffine, ValueRange indices) {
createAffine.store(destVal, destMemRef, indices);
});
rewriter.eraseOp(op);
Expand Down
Loading

0 comments on commit 40b607d

Please sign in to comment.