Skip to content

Commit

Permalink
try to favor vec_min and vec_max
Browse files Browse the repository at this point in the history
Signed-off-by: Alexandre Eichenberger <alexe@us.ibm.com>
  • Loading branch information
AlexandreEichenberger committed Jul 12, 2024
1 parent c131337 commit ef26878
Showing 1 changed file with 30 additions and 0 deletions.
30 changes: 30 additions & 0 deletions src/Accelerators/NNPA/Transform/ZLow/ZLowStickExpansion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,35 @@ class StickExpansionPattern : public OpRewritePattern<ZLowStickOp> {
// Compare against the min/max dlfloat16 values (represented in fp32 format).
Value vecF32LeMinH[U], vecF32LeMinL[U];
Value vecF32GeMaxH[U], vecF32GeMaxL[U];
#if 1
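// This solution saturates in two separate stages: first against the
// dlfloat16 min, then against the dlfloat16 max using the already
// min-saturated values. Each stage is a compare followed by a select on the
// same operands, a pattern that can be more easily picked up as vec_min
// and vec_max.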

// Compare for mins.
for (int64_t u = 0; u < U; ++u) {
vecF32LeMinH[u] = create.math.le(vecF32H[u], vecDlf16Min);
vecF32LeMinL[u] = create.math.le(vecF32L[u], vecDlf16Min);
}
// Select for mins.
for (int64_t u = 0; u < U; ++u) {
vecF32H[u] = create.math.select(
vecF32LeMinH[u], vecDlf16Min, vecF32H[u]);
vecF32L[u] = create.math.select(
vecF32LeMinL[u], vecDlf16Min, vecF32L[u]);
}
// Compare for maxes (using the values produced by the min selects above).
for (int64_t u = 0; u < U; ++u) {
vecF32GeMaxH[u] = create.math.ge(vecF32H[u], vecDlf16Max);
vecF32GeMaxL[u] = create.math.ge(vecF32L[u], vecDlf16Max);
}
// Select for maxes.
for (int64_t u = 0; u < U; ++u) {
vecF32H[u] = create.math.select(
vecF32GeMaxH[u], vecDlf16Max, vecF32H[u]);
vecF32L[u] = create.math.select(
vecF32GeMaxL[u], vecDlf16Max, vecF32L[u]);
}
#else
for (int64_t u = 0; u < U; ++u) {
vecF32LeMinH[u] = create.math.le(vecF32H[u], vecDlf16Min);
vecF32LeMinL[u] = create.math.le(vecF32L[u], vecDlf16Min);
Expand All @@ -469,6 +498,7 @@ class StickExpansionPattern : public OpRewritePattern<ZLowStickOp> {
vecF32L[u] = create.math.select(
vecF32GeMaxL[u], vecDlf16Max, vecF32L[u]);
}
#endif
} // End saturation special case.
// Convert f32 to dlfloat16.
for (int64_t u = 0; u < U; ++u) {
Expand Down

0 comments on commit ef26878

Please sign in to comment.