Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Saturation in compiler generated Stickify #2877

Merged
merged 11 commits into from
Jul 16, 2024
68 changes: 51 additions & 17 deletions src/Accelerators/NNPA/Transform/ZLow/ZLowStickExpansion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "src/Accelerators/NNPA/Dialect/ZLow/ZLowOps.hpp"
#include "src/Accelerators/NNPA/Pass/NNPAPasses.hpp"
#include "src/Accelerators/NNPA/Support/LayoutHelper.hpp"
#include "src/Accelerators/NNPA/Support/NNPALimit.hpp"
#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
#include "src/Dialect/Krnl/DialectBuilder.hpp"
#include "src/Dialect/Krnl/KrnlHelper.hpp"
Expand All @@ -43,6 +44,9 @@
#define PREFETCH_CSU_DIST 0
#define PREFETCH_CSU 1

// TODO, integrate.
#define SATURATION_ON 0

using namespace mlir;

namespace onnx_mlir {
Expand Down Expand Up @@ -71,14 +75,14 @@ class UnstickExpansionPattern : public OpRewritePattern<ZLowUnstickOp> {
layout.getValue().equals_insensitive("3D") ||
layout.getValue().equals_insensitive("2D") ||
layout.getValue().equals_insensitive("3DS")) {
return generateUnstickCodeNoBuffer(rewriter, unstickOp, layout);
return generateUnstickCodeNoBuffer(rewriter, unstickOp);
}
// Otherwise, we don't replace and keep the zdnn call.
return failure();
}

LogicalResult generateUnstickCodeNoBuffer(PatternRewriter &rewriter,
ZLowUnstickOp unstickOp, StringAttr layout) const {
LogicalResult generateUnstickCodeNoBuffer(
PatternRewriter &rewriter, ZLowUnstickOp unstickOp) const {
Operation *op = unstickOp.getOperation();
Location loc = unstickOp.getLoc();
MDBuilder create(rewriter, loc);
Expand Down Expand Up @@ -187,7 +191,7 @@ class UnstickExpansionPattern : public OpRewritePattern<ZLowUnstickOp> {
// Then (is full).
[&](SCFBuilder b) {
MDBuilder create(b);
// Loop
// Loop (tried unroll of 2 and 8, 4 was best).
const int64_t U = 4;
assert(U * VL <= 64 && "bad unroll");
create.scf.forLoop(litZero.getValue(), lit64.getValue(), U * VL,
Expand Down Expand Up @@ -309,15 +313,15 @@ class StickExpansionPattern : public OpRewritePattern<ZLowStickOp> {
layout.getValue().equals_insensitive("3D") ||
layout.getValue().equals_insensitive("2D") ||
layout.getValue().equals_insensitive("3DS")) {
return generateStickCodeNoBuffer(rewriter, stickOp, layout);
return generateStickCodeNoBuffer(rewriter, stickOp);
}
// Otherwise, we don't replace and keep the zdnn call.
return failure();
}

/* Version without buffer, more like zdnn */
LogicalResult generateStickCodeNoBuffer(
PatternRewriter &rewriter, ZLowStickOp stickOp, StringAttr layout) const {
PatternRewriter &rewriter, ZLowStickOp stickOp) const {
Operation *op = stickOp.getOperation();
Location loc = stickOp.getLoc();
MDBuilder create(rewriter, loc);
Expand All @@ -327,6 +331,12 @@ class StickExpansionPattern : public OpRewritePattern<ZLowStickOp> {
Value input = stickOp.getX();
Value alloc = stickOp.getOut();

bool saturation = false;
#if SATURATION_ON
// TODO: hook to operation's attribute.
saturation = true;
#endif

DimsExpr outputDims;
create.krnlIE.getShapeAsSymbols(alloc, outputDims);
int64_t rank = outputDims.size();
Expand All @@ -344,6 +354,15 @@ class StickExpansionPattern : public OpRewritePattern<ZLowStickOp> {
IndexExpr litVLHalf = LiteralIndexExpr(VLHalf);
IndexExpr lit64 = LiteralIndexExpr(64);

// Values for saturation.
Value vecDlf16Min, vecDlf16Max;
if (saturation) {
Value dlf16Min = create.math.constant(f32Type, DLF16_MIN);
vecDlf16Min = create.vec.splat(vecF32Type, dlf16Min);
Value dlf16Max = create.math.constant(f32Type, DLF16_MAX);
vecDlf16Max = create.vec.splat(vecF32Type, dlf16Max);
}

// Useful references for indexing dimensions (neg val are not used).
int64_t E1 = rank - 1;

Expand Down Expand Up @@ -406,7 +425,7 @@ class StickExpansionPattern : public OpRewritePattern<ZLowStickOp> {
#endif
#endif

const int64_t U = 4;
const int64_t U = 4; // Tried 2 and 8, 4 was best.
assert(U * VL <= 64 && "bad unroll");
create.affine.forIE(litZero, lit64, U * VL,
[&](AffineBuilder &b, ValueRange loopInd) {
Expand All @@ -417,21 +436,36 @@ class StickExpansionPattern : public OpRewritePattern<ZLowStickOp> {
getIndexExprList<SymbolIndexExpr>(memAF, inputAF);
// E1: add the "l" local E1 offset.
inputAF[E1] = inputAF[E1] + l;
// Load the f32.
Value vecF32H[U], vecF32L[U], vecF16[U];
for (int64_t i = 0; i < U; ++i) {
LiteralIndexExpr iH(i * VL), iL(i * VL + VL / 2);
vecF32H[i] = create.vec.loadIE(
for (int64_t u = 0; u < U; ++u) {
LiteralIndexExpr iH(u * VL), iL(u * VL + VL / 2);
vecF32H[u] = create.vec.loadIE(
vecF32Type, input, inputAF, {iH.getValue()});
vecF32L[i] = create.vec.loadIE(
vecF32L[u] = create.vec.loadIE(
vecF32Type, input, inputAF, {iL.getValue()});
}
for (int64_t i = 0; i < U; ++i) {
vecF16[i] = rewriter.create<ZLowConvertF32ToDLF16VectorOp>(
loc, vecF32H[i], vecF32L[i]);
if (saturation) {
// Get rid of too-high values.
for (int64_t u = 0; u < U; ++u) {
vecF32H[u] = create.math.min(vecF32H[u], vecDlf16Max);
vecF32L[u] = create.math.min(vecF32L[u], vecDlf16Max);
}
// Get rid of too-low values.
for (int64_t u = 0; u < U; ++u) {
vecF32H[u] = create.math.max(vecF32H[u], vecDlf16Min);
vecF32L[u] = create.math.max(vecF32L[u], vecDlf16Min);
}
}
// Convert f32 to dlfloat16.
for (int64_t u = 0; u < U; ++u) {
vecF16[u] = rewriter.create<ZLowConvertF32ToDLF16VectorOp>(
loc, vecF32H[u], vecF32L[u]);
}
for (int64_t i = 0; i < U; ++i) {
create.vec.storeIE(vecF16[i], allocAsTx64,
{SymIE(allocTileIndex), l + (i * VL)}, {});
// Store the dlfloat16.
for (int64_t u = 0; u < U; ++u) {
create.vec.storeIE(vecF16[u], allocAsTx64,
{SymIE(allocTileIndex), l + (u * VL)}, {});
}
});
});
Expand Down
6 changes: 3 additions & 3 deletions src/Dialect/Mlir/DialectBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -414,9 +414,9 @@ Value MathBuilder::neq(Value lhs, Value rhs) const {
llvm_unreachable("expected int or float");
}

Value MathBuilder::select(Value cmp, Value lhs, Value rhs) const {
assert(lhs.getType() == rhs.getType() && "expected same type");
return b().create<arith::SelectOp>(loc(), cmp, lhs, rhs);
// Emit an arith select: the result is trueVal when cmp holds, falseVal
// otherwise. Both candidate values are required to have the same type.
Value MathBuilder::select(Value cmp, Value trueVal, Value falseVal) const {
  assert(trueVal.getType() == falseVal.getType() && "expected same type");
  Value selected = b().create<arith::SelectOp>(loc(), cmp, trueVal, falseVal);
  return selected;
}

Value MathBuilder::constant(Type type, double val) const {
Expand Down
3 changes: 2 additions & 1 deletion src/Dialect/Mlir/DialectBuilder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,8 @@ struct MathBuilder final : DialectBuilder {
mlir::Value tanh(mlir::Value val) const; // Float only.
mlir::Value xori(mlir::Value lhs, mlir::Value rhs) const; // Int only.

mlir::Value select(mlir::Value cmp, mlir::Value lhs, mlir::Value rhs) const;
// Select between two values of the same type based on cmp: trueVal is picked
// when cmp holds, falseVal otherwise (the definition asserts matching types).
mlir::Value select(
mlir::Value cmp, mlir::Value trueVal, mlir::Value falseVal) const;
mlir::Value gt(mlir::Value lhs, mlir::Value rhs) const;
mlir::Value ge(mlir::Value lhs, mlir::Value rhs) const;
mlir::Value lt(mlir::Value lhs, mlir::Value rhs) const;
Expand Down
Loading