Skip to content

Commit

Permalink
[TorchToStablehlo] support l1_loss, deg2rad, logit (#3865)
Browse files Browse the repository at this point in the history
  • Loading branch information
yyp0 authored Nov 18, 2024
1 parent 896f66c commit bdbc64a
Show file tree
Hide file tree
Showing 10 changed files with 361 additions and 1 deletion.
48 changes: 48 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -9383,6 +9383,31 @@ def Torch_AtenMseLossBackwardOp : Torch_Op<"aten.mse_loss_backward", [
}];
}

def Torch_AtenL1LossOp : Torch_Op<"aten.l1_loss", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::l1_loss : (Tensor, Tensor, int) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$target,
Torch_IntType:$reduction
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenL1LossOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 1);
}
void AtenL1LossOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
}

def Torch_AtenUpsampleNearest2dBackwardOp : Torch_Op<"aten.upsample_nearest2d_backward", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down Expand Up @@ -16923,6 +16948,29 @@ def Torch_AtenTrilIndicesOp : Torch_Op<"aten.tril_indices", [
let hasVerifier = 1;
}

def Torch_AtenDeg2radOp : Torch_Op<"aten.deg2rad", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::deg2rad : (Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenDeg2radOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenDeg2radOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
}

def Torch_Aten_SoftmaxBackwardDataOp : Torch_Op<"aten._softmax_backward_data", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down
44 changes: 44 additions & 0 deletions lib/Conversion/TorchToStablehlo/Basic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1143,6 +1143,49 @@ LogicalResult ConvertAtenOp<AtenLog10Op>::matchAndRewrite(
return success();
}

// AtenLogitOp
template <>
LogicalResult ConvertAtenOp<AtenLogitOp>::matchAndRewrite(
AtenLogitOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto loc = op.getLoc();

Value self = adaptor.getSelf();
auto selfTy = dyn_cast<RankedTensorType>(self.getType());
if (!selfTy) {
return op.emitError("only ranked tensor type is supported.");
}

auto outTy = cast<TensorType>(getTypeConverter()->convertType(op.getType()));
self = hlo::promoteType(rewriter, op.getLoc(), self, outTy.getElementType());

selfTy = dyn_cast<RankedTensorType>(self.getType());

Value eps = adaptor.getEps();
auto epsTy = eps.getType();
Value newSelf;
if (!isa<Torch::NoneType>(epsTy)) {
auto epsTensor = hlo::scalarToStablehloTensor(rewriter, op, eps,
selfTy.getElementType());
Value oneEpsTensor = hlo::getConstantLike(rewriter, loc, 1.0, epsTensor);
auto max =
rewriter.create<stablehlo::SubtractOp>(loc, oneEpsTensor, epsTensor);
newSelf = rewriter.create<stablehlo::ClampOp>(loc, epsTensor, self, max);
} else {
newSelf = self;
}

Value one = hlo::getConstantLike(rewriter, loc, 1.0, self);
Value zi1 = rewriter.create<stablehlo::SubtractOp>(loc, one, newSelf);
Value newZi = rewriter.create<stablehlo::DivOp>(loc, newSelf, zi1);

Value log = rewriter.create<stablehlo::LogOp>(loc, outTy, newZi);

rewriter.replaceOp(op, log);

return success();
}

// AtenErfOp
template <>
LogicalResult ConvertAtenOp<AtenErfOp>::matchAndRewrite(
Expand Down Expand Up @@ -2248,6 +2291,7 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
INSERT_ATENOP_PATTERN(AtenGeluOp);
INSERT_ATENOP_PATTERN(AtenLog2Op);
INSERT_ATENOP_PATTERN(AtenLog10Op);
INSERT_ATENOP_PATTERN(AtenLogitOp);
INSERT_ATENOP_PATTERN(AtenErfOp);
INSERT_ATENOP_PATTERN(AtenGeluBackwardOp);

Expand Down
39 changes: 39 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10465,6 +10465,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %2 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.deg2rad\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.nll_loss_forward\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.int, %arg4: !torch.int) -> !torch.tuple<list<int>, list<int>> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.nll_loss_forward(%arg0, %arg1, %arg2, %arg3) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.int) -> !torch.tuple<list<int>, list<int>>\n"
" return %0 : !torch.tuple<list<int>, list<int>>\n"
Expand All @@ -10485,6 +10489,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %1 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.l1_loss\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.int) -> !torch.list<int> {\n"
" %int0 = torch.constant.int 0\n"
" %0 = torch.aten.eq.int %arg2, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %1 = torch.prim.If %0 -> (!torch.list<int>) {\n"
" %2 = func.call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" torch.prim.If.yield %2 : !torch.list<int>\n"
" } else {\n"
" %2 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
" torch.prim.If.yield %2 : !torch.list<int>\n"
" }\n"
" return %1 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.cross_entropy_loss\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.int, %arg4: !torch.int, %arg5: !torch.float) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.cross_entropy_loss(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.int, !torch.int, !torch.float) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
Expand Down Expand Up @@ -13864,6 +13880,24 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.l1_loss\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.int) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
" %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%4) : (!torch.int) -> !torch.bool\n"
" %6 = torch.aten.__not__ %5 : !torch.bool -> !torch.bool\n"
" torch.prim.If %6 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" return %4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.mul.Tensor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
Expand Down Expand Up @@ -15918,6 +15952,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.deg2rad\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n"
" return %1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.int_repr\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
" %int3 = torch.constant.int 3\n"
" %int1 = torch.constant.int 1\n"
Expand Down
105 changes: 105 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1334,6 +1334,44 @@ class DecomposeAtenTrilIndicesOp : public OpRewritePattern<AtenTrilIndicesOp> {
};
} // namespace

namespace {
class DecomposeAtenDeg2radOp : public OpRewritePattern<AtenDeg2radOp> {
public:
using OpRewritePattern<AtenDeg2radOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AtenDeg2radOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value self = op.getSelf();
auto selfTy = dyn_cast<BaseTensorType>(self.getType());
if (!selfTy || !selfTy.getDtype()) {
return rewriter.notifyMatchFailure(op, "requires tensor types input.");
}

auto outTy = dyn_cast<BaseTensorType>(op.getType());
if (!outTy || !outTy.getDtype()) {
return rewriter.notifyMatchFailure(
op, "requires output is a tensor with dtype.");
}

if (selfTy.getDtype() != outTy.getDtype()) {
self = convertTensorToDtype(rewriter, loc, self, outTy.getDtype());
}

Value pi =
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(M_PI));
Value basic =
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(180.0));
Value rad =
rewriter.create<AtenDivScalarOp>(loc, op.getType(), self, basic);
Value result = rewriter.create<AtenMulScalarOp>(loc, op.getType(), rad, pi);

rewriter.replaceOp(op, result);

return success();
}
};
} // namespace

namespace {
class DecomposeAtenSizeOp : public OpRewritePattern<AtenSizeOp> {
public:
Expand Down Expand Up @@ -8640,6 +8678,71 @@ class DecomposeAtenMseLossOp : public OpRewritePattern<AtenMseLossOp> {
};
} // namespace

namespace {
class DecomposeAtenL1LossOp : public OpRewritePattern<AtenL1LossOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenL1LossOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value self = op.getSelf();
auto selfTy = dyn_cast<BaseTensorType>(self.getType());
if (!selfTy || !selfTy.hasSizes() || !selfTy.hasDtype()) {
return rewriter.notifyMatchFailure(
op, "Expected self to be a tensor with sizes and a dtype");
}

Value target = op.getTarget();
auto targetTy = dyn_cast<BaseTensorType>(target.getType());
if (!targetTy || !targetTy.hasDtype()) {
return rewriter.notifyMatchFailure(
op, "Expected target to be a tensor with sizes and a dtype");
}

auto outTy = dyn_cast<BaseTensorType>(op.getType());
if (!outTy || !outTy.hasDtype()) {
return rewriter.notifyMatchFailure(
op, "Expected output type to be a tensor with a dtype");
}

auto outDtype = outTy.getDtype();
if (selfTy.getDtype() != outDtype) {
self = convertTensorToDtype(rewriter, loc, self, outDtype);
}
if (targetTy.getDtype() != outDtype) {
target = convertTensorToDtype(rewriter, loc, target, outDtype);
}

Value reduction = op.getReduction();
int64_t reductionInt;
if (!matchPattern(reduction, m_TorchConstantInt(&reductionInt))) {
return rewriter.notifyMatchFailure(
op, "Expected reduction to be a constant int");
}

auto subTy = outTy.getWithSizesAndDtype(selfTy.getSizes(), outDtype);
Value sub = createTensorSub(rewriter, loc, subTy, self, target);
Value abs = rewriter.create<AtenAbsOp>(loc, subTy, sub);

if (reductionInt == 0) {
rewriter.replaceOp(op, abs);
} else if (reductionInt == 1) {
Value none = rewriter.create<ConstantNoneOp>(loc);
Value sum = rewriter.create<AtenSumOp>(loc, outTy, abs, none);
Value numel = rewriter.create<AtenNumelOp>(loc, abs);
Value mean = rewriter.create<AtenDivScalarOp>(loc, outTy, sum, numel);
rewriter.replaceOp(op, mean);
} else {
Value none = rewriter.create<ConstantNoneOp>(loc);
Value sum = rewriter.create<AtenSumOp>(loc, outTy, abs, none);
rewriter.replaceOp(op, sum);
}

return success();
}
};
} // namespace

namespace {
// Decompose `aten.norm.ScalarOpt_dim` op to `aten.linalg_vector_norm` op
class DecomposeAtenNormScalarOptDimOp
Expand Down Expand Up @@ -10776,6 +10879,7 @@ class DecomposeComplexOpsPass
addPatternIfTargetOpIsIllegal<DecomposeAten_EmbeddingBagOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenLiftFreshCopyOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenMseLossOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenL1LossOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenNormScalarOptDimOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRandintOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRandintLowOp>(patterns);
Expand Down Expand Up @@ -10821,6 +10925,7 @@ class DecomposeComplexOpsPass
addPatternIfTargetOpIsIllegal<DecomposeAtenTriuOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenTriuIndicesOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenTrilIndicesOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenDeg2radOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenLinalgNormOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAten_LinalgDetOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenLinalgSlogdetOp>(patterns);
Expand Down
2 changes: 2 additions & 0 deletions lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenLerpScalarOp>();
target.addIllegalOp<AtenLerpTensorOp>();
target.addIllegalOp<AtenMseLossOp>();
target.addIllegalOp<AtenL1LossOp>();
target.addIllegalOp<AtenRandintLowOp>();
target.addIllegalOp<AtenRandintOp>();
target.addIllegalOp<AtenVarMeanCorrectionOp>();
Expand Down Expand Up @@ -564,6 +565,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenTriuOp>();
target.addIllegalOp<AtenTriuIndicesOp>();
target.addIllegalOp<AtenTrilIndicesOp>();
target.addIllegalOp<AtenDeg2radOp>();
target.addIllegalOp<AtenLinalgNormOp>();
target.addIllegalOp<AtenFminOp>();
target.addIllegalOp<AtenFmaxOp>();
Expand Down
5 changes: 4 additions & 1 deletion projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,7 +701,6 @@
"ElementwiseDequantizePerChannelModule_basic",
"ElementwiseDequantizePerTensorModule_basic",
"ElementwiseErfIntModule_basic",
"ElementwiseLogitModule_basic",
"ElementwiseMulTensorComplexModule_basic",
"ElementwiseMulTensorComplexDiffModule_basic",
"ElementwiseQuantizePerTensorModule_basic",
Expand Down Expand Up @@ -2899,6 +2898,7 @@
"ConvolutionModule2DTransposeNonUnitOutputPadding_basic",
"ConvolutionModule2DTransposeStrided_basic",
"ConvolutionModule2DTranspose_basic",
"Deg2radModule_basic",
"DivFloatModule_basic",
"DivIntModule_basic",
"ElementwiseAcoshIntModule_basic",
Expand Down Expand Up @@ -2986,6 +2986,9 @@
"IsFloatingPointInt_False",
"IscloseStaticModuleTrue_basic",
"IscloseStaticModule_basic",
"L1LossNoReductionModule_basic",
"L1LossMeanReductionModule_basic",
"L1LossSumReductionModule_basic",
"LeakyReluBackwardModule_basic",
"LeakyReluBackwardStaticModule_basic",
"LenStrModule_basic",
Expand Down
Loading

0 comments on commit bdbc64a

Please sign in to comment.