[MLIR][TORCH] Fix Onnx.ReduceSum lowering for failing e2e tests (llvm#3095)

Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
vivekkhandelwal1 authored Apr 3, 2024
1 parent f97cd48 commit ce7d4f1
Showing 5 changed files with 136 additions and 137 deletions.
2 changes: 2 additions & 0 deletions include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h
@@ -20,6 +20,8 @@ Value createConstantIntList(OpBinder binder,
 
 Type getQTorchTypeFromTorchIntType(Type ty);
 
+bool areAllElementsDistinct(SmallVector<int64_t> array);
+
 } // namespace mlir::torch::onnx_c
 
 #endif // TORCHMLIR_CONVERSION_TORCHONNXTOTORCH_UTILS_H
179 changes: 99 additions & 80 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
@@ -759,87 +759,118 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
         return success();
       });
   patterns.onOp(
-      "ReduceSum", 13,
-      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+      "ReduceSum", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
         Torch::ValueTensorType resultType;
         Value data;
-        Value axes;
-        int64_t keepDims;
-        int64_t noop_with_empty_axes;
-        if (binder.tensorOperands(data, axes) ||
+        int64_t keepDims, noop_with_empty_axes;
+        if (binder.tensorOperandAtIndex(data, 0) ||
             binder.tensorResultType(resultType) ||
             binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
             binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes",
                                   0))
           return failure();
-        Torch::BaseTensorType axesType =
-            axes.getType().cast<Torch::BaseTensorType>();
-        SmallVector<Value> dimList;
-        SmallVector<int64_t> selectSizes;
-        selectSizes.push_back(1);
-        Type selectResultType = axesType.getWithSizesAndDtype(
-            llvm::ArrayRef(selectSizes), axesType.getOptionalDtype());
-        auto sizes =
-            dyn_cast<Torch::ValueTensorType>(axes.getType()).getSizes();
-        Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
-        // Deal with the case when axes is empty.
-        if (sizes.size() == 1 && sizes[0] == 0) {
-          if (noop_with_empty_axes == 0) {
-            Value keepDimsConstInt = rewriter.create<Torch::ConstantIntOp>(
-                binder.getLoc(), rewriter.getType<Torch::IntType>(),
-                rewriter.getIntegerAttr(rewriter.getIntegerType(64), keepDims));
-            Value keepDimsBool = rewriter.create<Torch::AtenBoolIntOp>(
-                binder.getLoc(), keepDimsConstInt);
-            rewriter.replaceOpWithNewOp<Torch::AtenSumDimIntListOp>(
-                binder.op, resultType, data, /*dim=*/noneVal,
-                /*keepdim=*/keepDimsBool, /*dtype=*/noneVal);
-          } else {
-            rewriter.replaceOp(binder.op, data);
-          }
-          return success();
-        }
-        Value zero = rewriter.create<Torch::ConstantIntOp>(
-            binder.getLoc(), rewriter.getType<Torch::IntType>(),
-            rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
-        int64_t adjustmentInt =
-            cast<Torch::ValueTensorType>(data.getType()).getSizes().size();
-        Value adjustment = rewriter.create<Torch::ConstantIntOp>(
-            binder.getLoc(), rewriter.getType<Torch::IntType>(),
-            rewriter.getIntegerAttr(rewriter.getIntegerType(64),
-                                    adjustmentInt));
-        // Convert the axes tensor into a torch int list, handling neg axes.
-        for (int i = 0; i < sizes[0]; i++) {
-          // Go through the axes list and get each dim in the list.
-          Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
-              binder.getLoc(), rewriter.getType<Torch::IntType>(),
-              rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
-          Value extract = rewriter.create<Torch::AtenSelectIntOp>(
-              binder.getLoc(), selectResultType, axes, zero, selectIndex);
-          Value dim = rewriter.create<Torch::AtenItemOp>(
-              binder.getLoc(), rewriter.getType<Torch::IntType>(), extract);
-          // Deal with a negative axis: if (axis < 0) axis += rank.
-          Value isNegative =
-              rewriter.create<Torch::AtenLtIntOp>(binder.getLoc(), dim, zero);
-          isNegative = rewriter.create<Torch::AtenIntBoolOp>(binder.getLoc(),
-                                                             isNegative);
-          Value finalOffset = rewriter.create<Torch::AtenMulIntOp>(
-              binder.getLoc(), isNegative, adjustment);
-          Value finalDim = rewriter.create<Torch::AtenAddIntOp>(
-              binder.getLoc(), dim, finalOffset);
-          dimList.push_back(finalDim);
-        }
+
+        SmallVector<Value> axesList;
+
+        Value axesVal;
+        if (!binder.tensorOperandAtIndex(axesVal, 1)) {
+          auto inputType = data.getType().dyn_cast<Torch::ValueTensorType>();
+          if (!inputType.hasSizes() || !resultType.hasSizes()) {
+            return rewriter.notifyMatchFailure(
+                binder.op,
+                "unimplemented: expected input and result to have shapes");
+          }
+
+          // If the input shape and result shape are statically known, then
+          // the list of dims to be squeezed can be derived from those shapes.
+          // As a result, we don't have to wait for the dim values to be known
+          // at runtime, which is also expected by the downstream pipeline.
+          if (inputType.areAllSizesKnown() && resultType.areAllSizesKnown()) {
+            SmallVector<int64_t> inputShape{inputType.getSizes()};
+            SmallVector<int64_t> resultShape{resultType.getSizes()};
+            if (llvm::equal(inputShape, resultShape)) {
+              // Case: none of the dimensions is reduced.
+              rewriter.replaceOp(binder.op, data);
+              return success();
+            }
+            if (areAllElementsDistinct(inputShape)) {
+              // The check that the input shape elements are distinct guards
+              // against cases like:
+              //   Input: [3, 2, 2] -> Output: [3, 2]
+              // where the shapes alone can't tell whether dim 1 or dim 2 was
+              // reduced. Such cases fall through to the dynamic path below.
+              SmallVector<int64_t> reduceDims;
+              unsigned resultShapeCounter = 0;
+              for (unsigned i = 0; i < inputShape.size(); i++) {
+                if (resultShapeCounter < resultShape.size() &&
+                    inputShape[i] == resultShape[resultShapeCounter]) {
+                  resultShapeCounter++;
+                } else {
+                  reduceDims.push_back(i);
+                  if (resultShapeCounter < resultShape.size() &&
+                      resultShape[resultShapeCounter] == 1)
+                    resultShapeCounter++;
+                }
+              }
+              for (auto i : reduceDims) {
+                axesList.push_back(rewriter.create<Torch::ConstantIntOp>(
+                    binder.getLoc(), rewriter.getI64IntegerAttr(i)));
+              }
+            }
+          }
+
+          if (axesList.empty()) {
+            Torch::BaseTensorType axesType =
+                axesVal.getType().cast<Torch::BaseTensorType>();
+            auto axesTy = dyn_cast<Torch::ValueTensorType>(axesVal.getType());
+            auto axesShape = axesTy.getSizes();
+            if (axesShape.size() != 1 || axesShape[0] == Torch::kUnknownSize)
+              return failure();
+
+            Value zero = rewriter.create<Torch::ConstantIntOp>(
+                binder.getLoc(), rewriter.getType<Torch::IntType>(),
+                rewriter.getI64IntegerAttr(0));
+            SmallVector<int64_t> selectSizes{1};
+            auto selType = rewriter.getType<Torch::ValueTensorType>(
+                selectSizes, axesType.getOptionalDtype());
+            int64_t numAxes = axesShape[0];
+            for (int64_t i = 0; i < numAxes; ++i) {
+              Value iv = rewriter.create<Torch::ConstantIntOp>(
+                  binder.getLoc(), rewriter.getType<Torch::IntType>(),
+                  rewriter.getI64IntegerAttr(i));
+              Value extract = rewriter.create<Torch::AtenSelectIntOp>(
+                  binder.getLoc(), selType, axesVal, zero, iv);
+              Value dim = rewriter.create<Torch::AtenItemOp>(
+                  binder.getLoc(), rewriter.getType<Torch::IntType>(), extract);
+              axesList.push_back(dim);
+            }
+          }
+        }
+
+        SmallVector<int64_t> axesInts;
+        if (!binder.s64IntegerArrayAttr(axesInts, "axes", {})) {
+          for (int64_t i = 0, s = axesInts.size(); i < s; ++i) {
+            Value iv = rewriter.create<Torch::ConstantIntOp>(
+                binder.getLoc(), rewriter.getType<Torch::IntType>(),
+                rewriter.getI64IntegerAttr(axesInts[i]));
+            axesList.push_back(iv);
+          }
+        }
+
+        // Deal with the case when axes is empty.
+        if (axesList.empty() && noop_with_empty_axes) {
+          rewriter.replaceOp(binder.op, data);
+          return success();
+        }
+
         Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(
             binder.getLoc(),
             Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
-            dimList);
-        Value keepDimBool;
-        if (keepDims == 1) {
-          keepDimBool =
-              rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), true);
-        } else {
-          keepDimBool =
-              rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
-        }
+            axesList);
+        Value keepDimBool =
+            rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), keepDims);
+        Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
         rewriter.replaceOpWithNewOp<Torch::AtenSumDimIntListOp>(
             binder.op, resultType, data, dimValueList, keepDimBool,
             /*dtype=*/noneVal);
@@ -869,18 +900,6 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
                 "unimplemented: expected input and result to have shapes");
           }
 
-          auto areDistinct = ([](SmallVector<int64_t> array) -> bool {
-            int n = array.size();
-            llvm::SetVector<int64_t> set;
-            for (int i = 0; i < n; i++) {
-              set.insert(array[i]);
-            }
-
-            // If all elements are distinct, then the size of set should be
-            // the same as array's size.
-            return (set.size() == array.size());
-          });
-
           // If the input shape and result shape is statically known then the
           // list of dims to be squeezed can be derived from those shapes. As a
           // result, we don't have to wait for the dim values to be known at
@@ -893,7 +912,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
               rewriter.replaceOp(binder.op, data);
               return success();
             }
-            if (areDistinct(inputShape)) {
+            if (areAllElementsDistinct(inputShape)) {
              // The check for the input shape elements to be distinct is added
              // for the cases like:
              // Input: [3, 2, 2] -> Output: [3, 2]
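The core of the new lowering is the static-shape walk in the first hunk: when every dim of the input and result is known, the reduced axes can be read off by matching the two shapes, so the lowering no longer depends on the axes tensor being materialized at runtime. A minimal standalone sketch of that walk follows (not part of the commit; `inferReduceDims` and the demo harness are hypothetical names used for illustration):

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

// Walk the input shape against the result shape: any input dim that does not
// match the next unconsumed result dim must have been reduced. When
// keepdims=1 the reduced dim survives as a 1 in the result, so that unit dim
// is skipped as well.
static std::vector<int64_t>
inferReduceDims(const std::vector<int64_t> &input,
                const std::vector<int64_t> &result) {
  std::vector<int64_t> reduceDims;
  std::size_t r = 0;
  for (std::size_t i = 0; i < input.size(); ++i) {
    if (r < result.size() && input[i] == result[r]) {
      ++r;
    } else {
      reduceDims.push_back(static_cast<int64_t>(i));
      if (r < result.size() && result[r] == 1)
        ++r; // keepdims=1 leaves a unit dim behind
    }
  }
  return reduceDims;
}

int main() {
  // Input [2, 3, 4] summed over axis 1 with keepdims=0 yields [2, 4].
  for (int64_t d : inferReduceDims({2, 3, 4}, {2, 4}))
    std::printf("reduced dim: %lld\n", static_cast<long long>(d));
  return 0;
}

This prints "reduced dim: 1". The walk is only trusted when areAllElementsDistinct(inputShape) holds; for input [3, 2, 2] and result [3, 2] it cannot tell whether dim 1 or dim 2 was reduced, which is exactly why the guard exists.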
12 changes: 12 additions & 0 deletions lib/Conversion/TorchOnnxToTorch/Utils.cpp
@@ -47,3 +47,15 @@ Type mlir::torch::onnx_c::getQTorchTypeFromTorchIntType(Type ty) {
     return nullptr;
   return Torch::ValueTensorType::get(ctx, tty.getOptionalSizes(), dty);
 }
+
+bool mlir::torch::onnx_c::areAllElementsDistinct(SmallVector<int64_t> array) {
+  int n = array.size();
+  llvm::SetVector<int64_t> set;
+  for (int i = 0; i < n; i++) {
+    set.insert(array[i]);
+  }
+
+  // If all elements are distinct, then the size of set should be the same
+  // as array's size.
+  return (set.size() == array.size());
+}
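For reference, a quick standalone demo of the helper's contract (hypothetical, not from the commit; std::set stands in for llvm::SetVector so it builds without LLVM):

#include <cstdint>
#include <cstdio>
#include <set>
#include <vector>

// Mirrors areAllElementsDistinct: a shape has no duplicate extents exactly
// when inserting every extent into a set loses nothing.
static bool allDistinctDemo(const std::vector<int64_t> &array) {
  std::set<int64_t> set(array.begin(), array.end());
  return set.size() == array.size();
}

int main() {
  std::printf("%d\n", allDistinctDemo({3, 4, 5})); // 1: shape walk is safe
  std::printf("%d\n", allDistinctDemo({3, 2, 2})); // 0: ambiguous, fall back
  return 0;
}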
10 changes: 0 additions & 10 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -1983,7 +1983,6 @@
     "ReduceL2NormComplexModule_basic",
 
     # Failure - onnx_lowering: onnx.ReduceL3
-    "ReduceL3NormAllDimsModule_basic",
     "ReduceL3NormKeepDimModule_basic",
     "ReduceL3NormKeepDimComplexModule_basic",
 
@@ -1998,15 +1997,6 @@
     "StdCorrectionLargeInputModule_basic",
     "VarCorrectionLargeInputModule_basic",
 
-    # Failure - onnx_lowering: onnx.ReduceSum
-    "MseLossSumReductionWithDifferentElemTypeModule_basic",
-    "ReduceSumDtypeFloatModule_basic",
-    "ReduceSumDtypeIntModule_basic",
-    "ReduceSumElementTypeBoolModule_basic",
-    "ReduceSumFloatModule_basic",
-    "ReduceSumSignedIntModule_basic",
-    "ReduceSumUnsignedIntModule_basic",
-
     # Failure - onnx_lowering: onnx.Resize
     "UpSampleNearest2dDynamicSize_basic",
     "UpSampleNearest2dStaticSize_basic",
