Add support for mv. #1444

Closed
wants to merge 1 commit
49 changes: 24 additions & 25 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -3387,6 +3387,30 @@ def Torch_AtenMatmulOp : Torch_Op<"aten.matmul", [
}];
}

def Torch_AtenMvOp : Torch_Op<"aten.mv", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::mv : (Tensor, Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$vec
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenMvOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenMvOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}
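For context (from the PyTorch definition of torch.mv, not part of the generated file): for self of shape (n, m) and vec of shape (m,), aten::mv computes result[i] = sum_j self[i][j] * vec[j], producing a result of shape (n,).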

def Torch_AtenConv2dOp : Torch_Op<"aten.conv2d", [
AllowsTypeRefinement,
HasValueSemantics,
@@ -4528,31 +4552,6 @@ def Torch_AtenLinalgVectorNormOp : Torch_Op<"aten.linalg_vector_norm", [
}];
}

def Torch_AtenFrobeniusNormDimOp : Torch_Op<"aten.frobenius_norm.dim", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::frobenius_norm.dim : (Tensor, int[], bool) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchListOfTorchIntType:$dim,
Torch_BoolType:$keepdim
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenFrobeniusNormDimOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 1);
}
void AtenFrobeniusNormDimOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
}

def Torch_AtenConstantPadNdOp : Torch_Op<"aten.constant_pad_nd", [
AllowsTypeRefinement,
HasValueSemantics,
38 changes: 38 additions & 0 deletions lib/Conversion/TorchToLinalg/Linear.cpp
@@ -415,6 +415,42 @@ class ConvertAtenMatmulOp : public OpConversionPattern<AtenMatmulOp> {
};
} // namespace

namespace {
class ConvertAtenMvOp : public OpConversionPattern<AtenMvOp> {
Collaborator
I think it would make more sense to add a decomposition of this op into AtenMatmulOp, since that op already performs this same handling:

// Third Case: Matrix-Vec Multiplication.

The decomposition would happen in the DecomposeComplexOps.cpp file. (A sketch of such a decomposition follows the ConvertAtenMvOp listing below.)

public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenMvOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
Value lhs = adaptor.self();
Value rhs = adaptor.vec();

if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();

Type newResultType = getTypeConverter()->convertType(op.getType());
auto resultType = newResultType.cast<RankedTensorType>();
Type elementType = resultType.getElementType();

Value lhsDim0 = getDimOp(rewriter, loc, lhs, 0);
Value lhsDim1 = getDimOp(rewriter, loc, lhs, 1);
Value rhsDim0 = getDimOp(rewriter, loc, rhs, 0);
checkDimEqualHelper(rewriter, loc, lhsDim1, rhsDim0);

Value zeroTensor =
createZeroInitTensor(rewriter, loc, ValueRange{lhsDim0}, elementType);
Value matmul =
rewriter
.create<linalg::MatvecOp>(loc, zeroTensor.getType(),
ValueRange{lhs, rhs}, zeroTensor)
.getResult(0);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, matmul);
return success();
}
};
} // namespace
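
As a follow-up to the review comment above, here is a minimal sketch of the suggested alternative: decomposing aten.mv into aten.matmul in DecomposeComplexOps.cpp. This is illustrative only, modeled on the other decomposition patterns in that file; the class name DecomposeAtenMvOp is an assumption, not part of this PR:

namespace {
// Sketch (not part of this PR): rewrite aten.mv(self, vec) as
// aten.matmul(self, vec). AtenMatmulOp's lowering already handles the
// rank-2 x rank-1 ("Third Case: Matrix-Vec Multiplication") combination,
// so forwarding the operands unchanged is sufficient.
class DecomposeAtenMvOp : public OpRewritePattern<AtenMvOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenMvOp op,
                                PatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<AtenMatmulOp>(op, op.getType(), op.self(),
                                              op.vec());
    return success();
  }
};
} // namespace

Registration would presumably mirror the file's other patterns (patterns.add<DecomposeAtenMvOp>(context) plus target.addIllegalOp<AtenMvOp>()), and any backend that already lowers AtenMatmulOp would then handle aten.mv with no per-backend matvec pattern like ConvertAtenMvOp above.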

namespace {
class ConvertAtenBmmOp : public OpConversionPattern<AtenBmmOp> {
public:
@@ -839,6 +875,8 @@ void mlir::torch::torch_to_linalg::populateLinearPatternsAndLegality(
patterns.add<ConvertAtenFlipOp>(typeConverter, context);
target.addIllegalOp<AtenMatmulOp>();
patterns.add<ConvertAtenMatmulOp>(typeConverter, context);
target.addIllegalOp<AtenMvOp>();
patterns.add<ConvertAtenMvOp>(typeConverter, context);
target.addIllegalOp<AtenBmmOp>();
patterns.add<ConvertAtenBmmOp>(typeConverter, context);
target.addIllegalOp<AtenConvolutionOp>();
28 changes: 3 additions & 25 deletions lib/Conversion/TorchToLinalg/Reduction.cpp
@@ -196,7 +196,7 @@ static Value createInitElementForReduceOp(OpBuilder &b, Location loc,
elementType.getIntOrFloatBitWidth())));
}

if (isa<AtenLinalgVectorNormOp>(op) || isa<AtenFrobeniusNormDimOp>(op))
if (isa<AtenLinalgVectorNormOp>(op))
return b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));

op->emitError("unimplemented lowering in createInitElementForReduceOp");
@@ -244,16 +244,7 @@ static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc,
Value ord = convertScalarToDtype(b, loc, adaptor.ord(), resultElementType);
auto pow = b.create<math::PowFOp>(loc, abs, ord);
return b.create<arith::AddFOp>(loc, pow, result);
} else if (isa<AtenFrobeniusNormDimOp>(op)) {
Value elem = payloadArgs[0];
Value result = payloadArgs[1];
Value self = convertScalarToDtype(b, loc, elem, resultElementType);
auto abs = b.create<math::AbsFOp>(loc, self);
Attribute twoAttr = b.getFloatAttr(resultElementType, 2.0);
auto ord = b.create<arith::ConstantOp>(loc, twoAttr);
auto pow = b.create<math::PowFOp>(loc, abs, ord);
return b.create<arith::AddFOp>(loc, pow, result);
}
}
op->emitError("unimplemented lowering in createLinalgPayloadForReduceOp");
return nullptr;
}
@@ -330,9 +321,6 @@ class ConvertReductionOp : public ConversionPattern {
if (auto normOp = dyn_cast<AtenLinalgVectorNormOp>(op))
return computeReductionOpInfoForDimVariantOp(normOp, operands, rewriter);

if (auto normOp = dyn_cast<AtenFrobeniusNormDimOp>(op))
return computeReductionOpInfoForDimVariantOp(normOp, operands, rewriter);

return rewriter.notifyMatchFailure(op, "not a supported reduce op");
}

@@ -417,7 +405,7 @@ class ConvertReductionOp : public ConversionPattern {
LogicalResult
validateReductionElementType(Operation *op, Type elemType,
ConversionPatternRewriter &rewriter) const {
if ((isa<AtenLinalgVectorNormOp>(op) || isa<AtenFrobeniusNormDimOp>(op)) &&
if ((isa<AtenLinalgVectorNormOp>(op)) &&
!elemType.isa<mlir::FloatType>())
return rewriter.notifyMatchFailure(
op, "only float types are valid for vector norm ops");
@@ -468,15 +456,6 @@ class ConvertReductionOp : public ConversionPattern {
reduceOp = *secondReduceOp;
}

// If it is aten.frobenius_norm.dim op, take the square root of reduceOp as
// the final result
if (auto normOp = dyn_cast<AtenFrobeniusNormDimOp>(op)) {
auto halfAttr = rewriter.getFloatAttr(elemType, 0.5);
auto exp = rewriter.create<arith::ConstantOp>(loc, halfAttr);
reduceOp =
createElementwiseExp(loc, elemType, exp, reduceOp, *opInfo, rewriter);
}

rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, reduceOp);
return success();
}
@@ -493,6 +472,5 @@ void mlir::torch::torch_to_linalg::populateReductionPatternsAndLegality(
target.addIllegalOp<AtenSumDimIntListOp>();
target.addIllegalOp<AtenMaxOp>();
target.addIllegalOp<AtenLinalgVectorNormOp>();
target.addIllegalOp<AtenFrobeniusNormDimOp>();
patterns.add<ConvertReductionOp>(typeConverter, context);
}
110 changes: 1 addition & 109 deletions lib/Conversion/TorchToMhlo/Reduction.cpp
@@ -30,7 +30,7 @@ using namespace mlir::torch::torch_to_mhlo;
static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
PatternRewriter &rewriter) {
auto constType = RankedTensorType::get({}, elementTy);
if (isa<AtenSumOp, AtenSumDimIntListOp, AtenFrobeniusNormDimOp>(op)) {
if (isa<AtenSumOp, AtenSumDimIntListOp>(op)) {
if (elementTy.isa<mlir::FloatType>()) {
auto constAttr = DenseElementsAttr::get(
constType, {APFloat::getZero(
@@ -571,113 +571,6 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
}
} // namespace

// AtenFrobeniusNormDimOp
// aten.frobenius_norm.dim => mhlo.reduce(calculate square sum along given dims)
// + mhlo.sqrt
namespace {
template <>
LogicalResult ConvertAtenReductionOp<AtenFrobeniusNormDimOp>::matchAndRewrite(
AtenFrobeniusNormDimOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
const TorchToMhloOptions &options = getOptions();

Value input = adaptor.self();
auto inputType = input.getType().dyn_cast<RankedTensorType>();
if (!inputType) {
return op.emitError(
"only ranked tensor input supported in AtenFrobeniusNormDimOp");
}
auto inputRank = inputType.getRank();
auto inputElemType = inputType.getElementType();
if (!inputElemType.isa<mlir::FloatType>()) {
return op.emitError(
"only float dtype allowed in input tensor of AtenFrobeniusNormDimOp");
}

SmallVector<int64_t> dims;
if (!matchPattern(op.dim(), m_TorchConstantIntList(dims))) {
return rewriter.notifyMatchFailure(
op, "non-const integer `dim` is not supported");
}
for (auto &dim : dims) {
dim = toPositiveDim(dim, inputRank);
if (!isValidDim(dim, inputRank)) {
return rewriter.notifyMatchFailure(op,
"invalid dimension detected in `dim`");
}
}

// Sort the dims in ascending order, making the conversion
// stable with unordered dims.
std::sort(dims.begin(), dims.end());

bool keepDim = false;
if (!matchPattern(op.keepdim(), m_TorchConstantBool(&keepDim))) {
return rewriter.notifyMatchFailure(
op, "non-const bool `keepdim` is not supported");
}

auto initValue = createInitialValueForReduceOp(op, inputElemType, rewriter);
if (!initValue) {
return failure();
}

auto squareSumReduceOp = rewriter.create<mhlo::ReduceOp>(
op->getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));

Region &region = squareSumReduceOp.body();
Block &block = region.emplaceBlock();
auto blockArgumentTy = RankedTensorType::get({}, inputElemType);

block.addArgument(blockArgumentTy, op->getLoc());
block.addArgument(blockArgumentTy, op->getLoc());

auto *firstArgument = block.args_begin();
auto secondArgument = block.args_rbegin();

{
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&block);

auto constantOrd2 = rewriter.create<mhlo::ConstantOp>(
op->getLoc(), blockArgumentTy,
DenseElementsAttr::get(blockArgumentTy, llvm::ArrayRef<float>{2.0}));
auto abs = rewriter.create<mhlo::AbsOp>(op->getLoc(), *secondArgument);
auto squareResult = rewriter.create<mhlo::PowOp>(
op->getLoc(), abs, constantOrd2);
auto addResult = rewriter.create<mhlo::AddOp>(op->getLoc(), squareResult,
*firstArgument);
rewriter.create<mhlo::ReturnOp>(op->getLoc(), addResult.getResult());
}

auto output = rewriter.create<mhlo::SqrtOp>(op->getLoc(),
squareSumReduceOp.getResult(0));

if (keepDim) {
auto outShapeInfo = mhlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
if (failed(outShapeInfo)) {
return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input");
}
auto outShapeVec = *outShapeInfo;
auto one = rewriter.create<mlir::arith::ConstantOp>(
op->getLoc(), rewriter.getIntegerAttr(
rewriter.getIntegerType(options.dimSizeIndexBits), 1));
for (int64_t i : dims) {
outShapeVec[i] = one;
}
auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
op->getLoc(), outShapeVec);
rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(
op, getTypeConverter()->convertType(op.getType()), output,
outShapeTensor);
return success();
}
rewriter.replaceOp(op, output.getResult());
return success();
}
} // namespace

void mlir::torch::torch_to_mhlo::populateReductionOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const TorchToMhloOptions &options) {
@@ -690,6 +583,5 @@ void mlir::torch::torch_to_mhlo::populateReductionOpPatternsAndLegality(
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumDimIntListOp);
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumOp);
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenMaxOp);
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenFrobeniusNormDimOp);
#undef INSERT_ATEN_REDUCTION_OP_PATTERN
}
4 changes: 2 additions & 2 deletions lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -708,7 +708,7 @@ void TypeAnalysis::visitOperation(Operation *op,
// Dtype is always float32, except for bfloat16, float64 and nullptr.
if (isa<AtenTanhOp, AtenExpOp, AtenExpm1Op, AtenSinOp, AtenCosOp,
AtenSigmoidOp, AtenReciprocalOp, AtenLogOp, AtenSqrtOp, AtenLog2Op,
AtenLog1pOp, AtenRsqrtOp, AtenErfOp, AtenSoftplusOp, AtenFrobeniusNormDimOp>(op)) {
AtenLog1pOp, AtenRsqrtOp, AtenErfOp, AtenSoftplusOp>(op)) {
ValueKnowledge knowledge =
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
Type dtype = operands[0]->getValue().dtype;
@@ -754,7 +754,7 @@ void TypeAnalysis::visitOperation(Operation *op,

// Promote the two dtypes assuming non-zero rank.
if (isa<AtenMmOp, AtenBmmOp, AtenMatmulOp, AtenConv2dOp, AtenConvolutionOp,
Aten_ConvolutionOp, Aten_ConvolutionDeprecatedOp,
Aten_ConvolutionOp, Aten_ConvolutionDeprecatedOp, AtenMvOp,
AtenConvolutionOverrideableOp, AtenConvTranspose2dInputOp>(op)) {
auto knowledge =
ValueKnowledge::getTensorPessimisticValueState(op->getContext());