Skip to content

Commit

Permalink
Align softmax accumulation types with Torch's CUDA implementation (ll…
Browse files Browse the repository at this point in the history
  • Loading branch information
nithinsubbiah authored Mar 12, 2024
1 parent ad6159c commit 5ecc1d5
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 7 deletions.
9 changes: 9 additions & 0 deletions include/torch-mlir/Dialect/Torch/Utils/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,15 @@ Value createRank0Tensor(PatternRewriter &rewriter, Location loc,
LogicalResult getTransposedType(BaseTensorType inType, int64_t dimA,
int64_t dimB, Type &transposedType);

// Approximates the heuristic in the torch `acc_type` template for kernels
// that are defined in terms of it. For now, this just returns accumulators
// as if for CUDA from that implementation. In the future, this could be
// extended to look at hints on the `forOp` or its container to better
// control the behavior. Such support would be done in coordination with
// the fx_importer and APIs, which could add hints to the IR (based on
// Torch flags, user options, etc).
Type getDefaultAccType(PatternRewriter &rewriter, Type inputType);

} // namespace Torch
} // namespace torch
} // namespace mlir
Expand Down
31 changes: 24 additions & 7 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1237,22 +1237,32 @@ class DecomposeAtenTraceOp : public OpRewritePattern<AtenTraceOp> {
// softmax = unnorm / sum(unnorm, dim, keepdim = True)
//
// Computes a numerically-stable softmax in an accumulator dtype:
//   1. cast `self` to `accumulatorType` (if it differs from `resultType`),
//   2. subtract the per-dim max for stability,
//   3. exponentiate, sum along `dim`, divide,
//   4. cast the quotient back to the dtype of `resultType`.
// Returns nullptr when the max/sum reductions cannot be created, so callers
// can signal match failure.
template <typename OpTy>
static Value getSoftmaxResult(OpTy op, Value self, Type resultType,
                              Type accumulatorType, PatternRewriter &rewriter) {
  Location loc = op.getLoc();
  Value dim = op.getDim();
  // NOTE(review): `resultType` is a tensor type while `accumulatorType` is an
  // element dtype, so this inequality looks always-true — the cast would then
  // be emitted unconditionally. Harmless if the dtypes already match, but
  // confirm whether the intent was to compare element dtypes.
  if (resultType != accumulatorType)
    self = convertTensorToDtype(rewriter, loc, self, accumulatorType);
  // Max along `dim` (kept) gives the stabilizer: exp(x - max) avoids overflow.
  Value xMax =
      createMaxAlongDimension(rewriter, loc, op, self, dim, /*keepDim=*/true);
  if (!xMax)
    return nullptr;
  // All intermediate math stays in the (possibly widened) type of `self`.
  Value unNormalized =
      createTensorSub(rewriter, loc, self.getType(), self, xMax);
  Value unNormalizedExp =
      rewriter.create<AtenExpOp>(loc, self.getType(), unNormalized);
  Value sum = createSumAlongDimension(rewriter, loc, op, unNormalizedExp, dim,
                                      /*keepDim=*/true);
  if (!sum)
    return nullptr;

  Value result = rewriter.create<AtenDivTensorOp>(loc, self.getType(),
                                                  unNormalizedExp, sum);
  // Cast back from the accumulator dtype to the caller-requested result dtype.
  if (resultType != accumulatorType)
    result = convertTensorToDtype(rewriter, loc, result,
                                  resultType.cast<BaseTensorType>().getDtype());

  return result;
}

// Decompose softmax into: exp(x) / sum(exp(x))
Expand Down Expand Up @@ -1284,7 +1294,10 @@ class DecomposeAtenSoftmaxIntOp : public OpRewritePattern<AtenSoftmaxIntOp> {
/*non_blocking=*/cstFalse, /*copy=*/cstFalse, /*memory_format=*/none);
}

Value result = getSoftmaxResult(op, self, resultTensorType, rewriter);
Type accumulatorTensorType = getDefaultAccType(rewriter, resultTensorDtype);

Value result = getSoftmaxResult(op, self, resultTensorType,
accumulatorTensorType, rewriter);
if (!result)
return failure();
rewriter.replaceOpWithNewOp<TensorStaticInfoCastOp>(op, op.getType(),
Expand Down Expand Up @@ -1329,7 +1342,11 @@ class DecomposeAten_SoftmaxOp : public OpRewritePattern<Aten_SoftmaxOp> {
getDtypeIntValueForType(rewriter, loc, resultTensorDtype),
/*non_blocking=*/cstFalse, /*copy=*/cstFalse, /*memory_format=*/none);
}
Value result = getSoftmaxResult(op, self, resultTensorType, rewriter);

Type accumulatorTensorType = getDefaultAccType(rewriter, resultTensorDtype);

Value result = getSoftmaxResult(op, self, resultTensorType,
accumulatorTensorType, rewriter);
if (!result)
return op.emitError("failed to get softmax result");
rewriter.replaceOpWithNewOp<TensorStaticInfoCastOp>(op, resultTensorType,
Expand Down
30 changes: 30 additions & 0 deletions lib/Dialect/Torch/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -525,3 +525,33 @@ LogicalResult Torch::getTransposedType(BaseTensorType inType, int64_t dimA,
inType.getOptionalDtype());
return success();
}

// Returns the accumulator type Torch's CUDA `acc_type` heuristic would pick
// for `inputType`:
//   - f64 accumulates in f64;
//   - every other supported float (f16, bf16, f32, and the fp8 variants)
//     accumulates in f32;
//   - every integral type accumulates in i64 (acc_type maps all integral
//     element types to int64_t, so unsigned widths are accepted too instead
//     of aborting on them).
// Aborts via report_fatal_error on types with no defined accumulator.
Type Torch::getDefaultAccType(PatternRewriter &rewriter, Type inputType) {
  if (inputType.isa<Float64Type>())
    return rewriter.getF64Type();
  if (inputType.isF16() || inputType.isBF16() || inputType.isa<Float32Type>() ||
      inputType.isFloat8E5M2() || inputType.isFloat8E4M3FN() ||
      inputType.isFloat8E5M2FNUZ() || inputType.isFloat8E4M3FNUZ())
    return rewriter.getF32Type();
  if (inputType.isSignedInteger() || inputType.isUnsignedInteger())
    return rewriter.getI64Type();
  llvm::report_fatal_error("unhandled type for getDefaultAccType");
}

0 comments on commit 5ecc1d5

Please sign in to comment.