Skip to content

Commit

Permalink
[onnx] Fix ReduceMean lowering to torch (llvm#2956)
Browse files Browse the repository at this point in the history
Torch lowering only supported the most recent version. Refactored the
lowering so more easily handle default values and optional operands /
attributes.
  • Loading branch information
rsuderman authored Feb 28, 2024
1 parent d541779 commit 4a7a7d7
Show file tree
Hide file tree
Showing 6 changed files with 608 additions and 449 deletions.
228 changes: 122 additions & 106 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1104,129 +1104,145 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
Value axes;
int64_t keepDims;
int64_t noop_with_empty_axes;
// Deal with case when no axes arg is passed
if (binder.op->getNumOperands() == 1) {
if (binder.tensorOperand(data) ||
binder.tensorResultType(resultType) ||
binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
binder.s64IntegerAttr(noop_with_empty_axes,
"noop_with_empty_axes", 0))
return failure();
if (noop_with_empty_axes == 0) {
Value keepDimsConstInt = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), keepDims));
Value keepDimsBool = rewriter.create<Torch::AtenBoolIntOp>(
binder.getLoc(), keepDimsConstInt);
int64_t numDims = dyn_cast<Torch::ValueTensorType>(data.getType())
.getSizes()
.size();
SmallVector<Value> axesList;
for (int i = 0; i < numDims; i++) {
Value curr = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
axesList.push_back(curr);
}
Value axesValueList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
Torch::ListType::get(
Torch::IntType::get(binder.op->getContext())),
axesList);
rewriter.replaceOpWithNewOp<Torch::AtenAminOp>(
binder.op, resultType, data, axesValueList, keepDimsBool);
} else {
rewriter.replaceOp(binder.op, data);
}
return success();
}
if (binder.tensorOperands(data, axes) ||
if (binder.tensorOperandAtIndex(data, 0) ||
binder.tensorResultType(resultType) ||
binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes",
0))
return failure();
Torch::BaseTensorType axesType =
axes.getType().cast<Torch::BaseTensorType>();
SmallVector<Value> dimList;
SmallVector<int64_t> selectSizes;
selectSizes.push_back(1);
Type selectResultType = axesType.getWithSizesAndDtype(
llvm::ArrayRef(selectSizes), axesType.getOptionalDtype());
auto sizes =
dyn_cast<Torch::ValueTensorType>(axes.getType()).getSizes();
// deal with case when axes is empty
if (sizes.size() == 1 && sizes[0] == 0) {
if (noop_with_empty_axes == 0) {
// create dims list with all dims [0, data.getSizes().size())
Value keepDimsConstInt = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), keepDims));
Value keepDimsBool = rewriter.create<Torch::AtenBoolIntOp>(
binder.getLoc(), keepDimsConstInt);
int64_t numDims = dyn_cast<Torch::ValueTensorType>(data.getType())
.getSizes()
.size();
for (int i = 0; i < numDims; i++) {
Value curr = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
dimList.push_back(curr);

auto dataTy = cast<Torch::BaseTensorType>(data.getType());
Torch::IntType torchIntTy = rewriter.getType<Torch::IntType>();

// If any of the input dims are 0 we set to the upper limit:
if (llvm::any_of(dataTy.getSizes(), [](int64_t d) { return d == 0; }) &&
(llvm::any_of(dataTy.getSizes(),
[](int64_t d) { return d == Torch::kUnknownSize; }) ||
keepDims)) {
auto dty = dataTy.getDtype();
Value scalar;
if (FloatType fpTy = dyn_cast<FloatType>(dty)) {
auto inf = APFloat::getInf(fpTy.getFloatSemantics());
scalar = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
rewriter.getFloatAttr(rewriter.getF64Type(),
inf.convertToDouble()));
}

if (IntegerType intTy = dyn_cast<IntegerType>(dty)) {
auto mx =
intTy.isSigned()
? APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
: APInt::getMaxValue(intTy.getIntOrFloatBitWidth());
scalar = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), torchIntTy,
rewriter.getIntegerAttr(rewriter.getIntegerType(64),
mx.getSExtValue()));
}

llvm::SmallVector<Value> fillDims;
for (int i = 0, s = resultType.getSizes().size(); i < s; ++i) {
auto staticDim = resultType.getSizes()[i];
if (staticDim != Torch::kUnknownSize) {
fillDims.push_back(rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), torchIntTy,
rewriter.getI64IntegerAttr(staticDim)));
continue;
}
Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
Torch::ListType::get(
Torch::IntType::get(binder.op->getContext())),
dimList);
rewriter.replaceOpWithNewOp<Torch::AtenAminOp>(
binder.op, resultType, data, dimValueList, keepDimsBool);
} else {
rewriter.replaceOp(binder.op, data);

Value iv = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), torchIntTy, rewriter.getI64IntegerAttr(i));
fillDims.push_back(rewriter.create<Torch::AtenSizeIntOp>(
binder.getLoc(), torchIntTy, data, iv));
}

Value none = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
Value fillDimsList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(), Torch::ListType::get(torchIntTy), fillDims);
rewriter.replaceOpWithNewOp<Torch::AtenFullOp>(
binder.op, resultType, fillDimsList, scalar, none, none, none,
none);
return success();
}

// Previous version of the operation had the axes as an attribute:
SmallVector<Value> axesList;
llvm::SmallVector<int64_t> axesAttr;
if (!binder.s64IntegerArrayAttr(axesAttr, "axes", {})) {
for (int i = 0, s = axesAttr.size(); i < s; ++i) {
axesList.push_back(rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), torchIntTy,
rewriter.getI64IntegerAttr(axesAttr[i])));
}
}

// Extract the axes values from the axes operand:
if (!binder.tensorOperandAtIndex(axes, 1)) {
Torch::BaseTensorType axesType =
axes.getType().cast<Torch::BaseTensorType>();
SmallVector<int64_t> selectSizes{1};
Type selectResultType = axesType.getWithSizesAndDtype(
selectSizes, axesType.getOptionalDtype());
auto sizes = axesType.getSizes();

Value zero = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));

// Extract the value of each axes:
for (int i = 0; i < sizes[0]; i++) {
// Go through the axes list and get each dim in the list
Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
Value extract = rewriter.create<Torch::AtenSelectIntOp>(
binder.getLoc(), selectResultType, axes, zero, selectIndex);
Value dim = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(), extract);
axesList.push_back(dim);
}
}

// Handle the noop case:
if (axesList.empty() && noop_with_empty_axes) {
rewriter.replaceOp(binder.op, data);
return success();
}

// Deal with case when no axes arg is passed but not a noop:
if (axesList.empty()) {
int64_t numDims = dyn_cast<Torch::ValueTensorType>(data.getType())
.getSizes()
.size();
for (int i = 0; i < numDims; i++) {
Value curr = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
axesList.push_back(curr);
}
}

// Handle negative axis:
Value rankVal = rewriter.create<Torch::AtenDimOp>(binder.getLoc(),
torchIntTy, data);
Value zero = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
int64_t adjustmentInt =
cast<Torch::ValueTensorType>(data.getType()).getSizes().size();
Value adjustment = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64),
adjustmentInt));
// convert axes (tensor) into torch int list while dealing with neg axis
for (int i = 0; i < sizes[0]; i++) {
// Go through the axes list and get each dim in the list
Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
Value extract = rewriter.create<Torch::AtenSelectIntOp>(
binder.getLoc(), selectResultType, axes, zero, selectIndex);
Value dim = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(), extract);
// deal with neg axis: if (axis < 0) axis += rank
rewriter.getI64IntegerAttr(0));
for (Value &axes : axesList) {
Value isNegative =
rewriter.create<Torch::AtenLtIntOp>(binder.getLoc(), dim, zero);
rewriter.create<Torch::AtenLtIntOp>(binder.getLoc(), axes, zero);
isNegative = rewriter.create<Torch::AtenIntBoolOp>(binder.getLoc(),
isNegative);
Value finalOffset = rewriter.create<Torch::AtenMulIntOp>(
binder.getLoc(), isNegative, adjustment);
Value finalDim = rewriter.create<Torch::AtenAddIntOp>(
binder.getLoc(), dim, finalOffset);
dimList.push_back(finalDim);
binder.getLoc(), isNegative, rankVal);
axes = rewriter.create<Torch::AtenAddIntOp>(binder.getLoc(), axes,
finalOffset);
}

Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
dimList);
Value keepDimBool;
if (keepDims == 1) {
keepDimBool =
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), true);
} else {
keepDimBool =
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
}
binder.getLoc(), Torch::ListType::get(torchIntTy), axesList);
Value keepDimBool =
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), keepDims);
rewriter.replaceOpWithNewOp<Torch::AtenAminOp>(
binder.op, resultType, data, dimValueList, keepDimBool);
return success();
Expand Down
97 changes: 67 additions & 30 deletions lib/Conversion/TorchToLinalg/Reduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,18 +60,15 @@ class ConvertAtenMinMaxDimOp : public OpConversionPattern<OpTy> {

Location loc = op.getLoc();
Value input = adaptor.getSelf();
RankedTensorType valResultType =
getTypeConverter()
->convertType(op.getResult(0).getType())
.template cast<RankedTensorType>();

RankedTensorType idxResultType =
this->getTypeConverter()
->convertType(op.getResult(1).getType())
.template cast<RankedTensorType>();
auto typec = this->getTypeConverter();
auto valResultType =
cast<RankedTensorType>(typec->convertType(op.getResult(0).getType()));
auto idxResultType =
cast<RankedTensorType>(typec->convertType(op.getResult(1).getType()));
RankedTensorType inputType =
input.getType().template cast<RankedTensorType>();
Type idxElementType = idxResultType.getElementType();
Type idxElementType =
getElementTypeOrSelf(typec->convertType(idxResultType));
if (!idxElementType.isa<IntegerType>())
return rewriter.notifyMatchFailure(
op, opName + " to linalg.* requires integer-like result type");
Expand Down Expand Up @@ -109,14 +106,12 @@ class ConvertAtenMinMaxDimOp : public OpConversionPattern<OpTy> {
}

// Constant op to account for the reduction along dim.
auto c1 = rewriter.create<arith::ConstantIndexOp>(loc, /*value=*/1);
SmallVector<Value> resultShape;
for (int64_t i = 0; i < inputType.getRank(); i++) {
if (dim != i) {
auto currentDimSize = rewriter.create<tensor::DimOp>(loc, input, i);
resultShape.push_back(currentDimSize);
} else if (keepDim)
resultShape.push_back(c1);
}
}
// First fill the output buffer for the index.
Value filledTensorIdx =
Expand Down Expand Up @@ -146,27 +141,23 @@ class ConvertAtenMinMaxDimOp : public OpConversionPattern<OpTy> {
Value filledTensorVal =
rewriter.create<linalg::FillOp>(loc, fillValue, initTensorVal).result();

SmallVector<utils::IteratorType> iteratorTypes(
inputType.getRank(), utils::IteratorType::parallel);
iteratorTypes[dim] = utils::IteratorType::reduction;

// Create the affine expressions that will be used to
// iterate over the input and output tensors.
// Here we also set the type of iterator: parallel or reduction.

SmallVector<AffineExpr> exprs;
SmallVector<utils::IteratorType> iteratorTypes;
SmallVector<AffineExpr> resultExprs;
for (auto size :
llvm::enumerate(makeShapeTorchCompatible(inputType.getShape()))) {
exprs.push_back(rewriter.getAffineDimExpr(size.index()));

if (unsigned(dim) == size.index()) {
iteratorTypes.push_back(utils::IteratorType::reduction);
// If `keepDim`, create affine map to the first element
// in the current dimension.
if (keepDim)
resultExprs.push_back(rewriter.getAffineConstantExpr(0));
} else {
iteratorTypes.push_back(utils::IteratorType::parallel);
if (unsigned(dim) != size.index())
resultExprs.push_back(rewriter.getAffineDimExpr(size.index()));
}
}

auto maps = AffineMap::inferFromExprList({exprs, resultExprs, resultExprs},
rewriter.getContext());
auto linalgOp = rewriter.create<linalg::GenericOp>(
Expand Down Expand Up @@ -219,12 +210,58 @@ class ConvertAtenMinMaxDimOp : public OpConversionPattern<OpTy> {
nestedLoc, ValueRange({resultVal, resultIndex}));
});

// This cast is required to fix the shape in the case of keepDim=True
Value valuesCast = rewriter.create<tensor::CastOp>(loc, valResultType,
linalgOp.getResult(0));
Value idxCast = rewriter.create<tensor::CastOp>(loc, idxResultType,
linalgOp.getResult(1));
rewriter.replaceOp(op, {valuesCast, idxCast});
if (!keepDim) {
Value rVal = rewriter.create<tensor::CastOp>(loc, valResultType,
linalgOp.getResult(0));
Value rIdx = rewriter.create<tensor::CastOp>(loc, idxResultType,
linalgOp.getResult(1));
llvm::SmallVector<Value> res{rVal, rIdx};
rewriter.replaceOp(op, res);
return success();
}

llvm::SmallVector<int64_t> valShape(valResultType.getShape());
llvm::SmallVector<int64_t> idxShape(idxResultType.getShape());
for (int i = dim, s = valShape.size() - 1; i < s; ++i) {
valShape[i] = valShape[i + 1];
idxShape[i] = idxShape[i + 1];
}

valShape.resize(valShape.size() - 1);
idxShape.resize(idxShape.size() - 1);

Value rVal = rewriter.create<tensor::CastOp>(
loc, valResultType.clone(valShape), linalgOp.getResult(0));
Value rIdx = rewriter.create<tensor::CastOp>(
loc, idxResultType.clone(idxShape), linalgOp.getResult(1));

SmallVector<ReassociationIndices> reassociation(valShape.size());
if (reassociation.size() > 0) {
for (int i = 0; i < dim; ++i)
reassociation[i].push_back(i);
reassociation[std::max<int64_t>(0, dim - 1)].push_back(dim);
for (int i = dim, s = reassociation.size(); i < s; ++i)
reassociation[i].push_back(i + 1);
}

valShape.push_back(0);
idxShape.push_back(0);
for (int i = dim, s = valShape.size() - 1; i < s; ++i) {
valShape[i + 1] = valShape[i];
idxShape[i + 1] = idxShape[i];
}

valShape[dim] = 1;
idxShape[dim] = 1;

Value unsqueezeVal = rewriter.create<tensor::ExpandShapeOp>(
loc, valResultType, rVal, reassociation);

Value unsqueezeIdx = rewriter.create<tensor::ExpandShapeOp>(
loc, idxResultType, rIdx, reassociation);

llvm::SmallVector<Value> unsqueezes = {unsqueezeVal, unsqueezeIdx};
rewriter.replaceOp(op, unsqueezes);
return success();
}
};
Expand Down
Loading

0 comments on commit 4a7a7d7

Please sign in to comment.