Skip to content

Commit

Permalink
[MLIR][TORCH] Fix OnnxToLinalg lowering issue for sub and sum op
Browse files Browse the repository at this point in the history
This commit adds support for converting a scalar to the byte type.
It also fixes the OnnxToLinalg lowering issue for the
Onnx.Sub and Onnx.Sum ops.
Fixes nod-ai/SHARK-ModelDev#466
Fixes nod-ai/SHARK-ModelDev#467

Signed-off-by: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
  • Loading branch information
vivekkhandelwal1 committed Feb 29, 2024
1 parent 76b81e0 commit 9fe3986
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 14 deletions.
3 changes: 2 additions & 1 deletion include/torch-mlir/Conversion/Utils/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ mlir::RankedTensorType GetTypeFromTensorShape(llvm::ArrayRef<int64_t> shape,
// should be converted builtin types.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
std::optional<Type> srcOriginalDtype = std::nullopt,
std::optional<Type> dstOriginalDtype = std::nullopt);
std::optional<Type> dstOriginalDtype = std::nullopt,
std::optional<Value> originalScalar = std::nullopt);

Value toPositiveValidDim(ConversionPatternRewriter &rewriter, Location loc,
Value torchOptionalInt, Value builtinInt,
Expand Down
10 changes: 8 additions & 2 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -489,15 +489,21 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
return success();
}
// When binder.op->getNumOperands() > 2
auto baseType = Torch::ValueTensorType::getWithLeastStaticInformation(
binder.op->getContext());
Value curr = rewriter.create<Torch::AtenAddTensorOp>(
binder.getLoc(), resultType, valList[0], valList[1], const1);
for (int i = 2; i < numOperands; i++) {
if (i == numOperands - 1) {
curr = rewriter.create<Torch::AtenAddTensorOp>(
binder.getLoc(), resultType, curr, valList[i], const1);
} else {
SmallVector<int64_t> resultBroadcastShapeInt;
SmallVector<Value> resultBroadcastShapeValue;
Torch::computeBroadcastShape(rewriter, binder.getLoc(), curr,
valList[i], resultBroadcastShapeInt,
resultBroadcastShapeValue);
auto baseType = Torch::ValueTensorType::get(
binder.op->getContext(), resultBroadcastShapeInt,
resultType.getOptionalDtype());
curr = rewriter.create<Torch::AtenAddTensorOp>(
binder.getLoc(), baseType, curr, valList[i], const1);
}
Expand Down
3 changes: 2 additions & 1 deletion lib/Conversion/TorchToLinalg/Uncategorized.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -645,7 +645,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
/*dstOriginalDtype=*/resultElementType);
Value alpha = convertScalarToDtype(b, loc, adaptor.getAlpha(), dtype,
/*srcOriginalDtype=*/std::nullopt,
/*dstOriginalDtype=*/resultElementType);
/*dstOriginalDtype=*/resultElementType,
/*originalScalar=*/sub.getAlpha());
if (dtype.isa<mlir::FloatType>()) {
Value scaled = b.create<arith::MulFOp>(loc, rhs, alpha);
return b.create<arith::SubFOp>(loc, lhs, scaled);
Expand Down
33 changes: 27 additions & 6 deletions lib/Conversion/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -245,12 +245,20 @@ mlir::RankedTensorType GetTypeFromTensorShape(llvm::ArrayRef<int64_t> shape,
elementType, encoding);
}

// If `scalar` is produced by a Torch::ConstantIntOp, return its constant
// integer value; otherwise return an empty optional.
static std::optional<int64_t> getIntegerValue(Value scalar) {
  auto constOp = scalar.getDefiningOp<Torch::ConstantIntOp>();
  if (!constOp)
    return std::nullopt;
  return constOp.getValue();
}

// Convert a scalar value to the target type. The scalar value can be an element
// from a tensor or a scalar in the pytorch dialect. Both the scalar and dtype
// should be converted builtin types.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
std::optional<Type> srcOriginalDtype,
std::optional<Type> dstOriginalDtype) {
std::optional<Type> dstOriginalDtype,
std::optional<Value> originalScalar) {
Type scalarType = scalar.getType();
if (scalarType == dtype)
return scalar;
Expand All @@ -262,7 +270,8 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
return false;
};

// We don't support conversion to Byte dtype.
// We support conversion to Byte dtype only if the original scalar is an
// integer constant with value lying between 0 - 63.
if (isByteOrChar(dtype)) {
if (!dstOriginalDtype.has_value()) {
mlir::emitError(loc)
Expand All @@ -271,10 +280,22 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
return nullptr;
}
if (dstOriginalDtype->isUnsignedInteger()) {
mlir::emitError(loc)
<< "unsupported: conversion to byte type for convertScalarToDtype "
<< scalarType << "(scalar type) -> " << dtype << "(dtype)";
return nullptr;
if (originalScalar.has_value()) {
std::optional<int64_t> optConstVal =
getIntegerValue(originalScalar.value());
if (optConstVal.has_value()) {
int64_t constVal = optConstVal.value();
if (constVal < 0 || constVal > 63) {
// Do the conversion only if the original integer value is between
// 0 - 63.
mlir::emitError(loc)
<< "unsupported: conversion to byte type for "
"convertScalarToDtype "
<< scalarType << "(scalar type) -> " << dtype << "(dtype)";
return nullptr;
}
}
}
}
}

Expand Down
Loading

0 comments on commit 9fe3986

Please sign in to comment.