Skip to content

Commit

Permalink
Implement lowering of torch.aten.lerp.Scalar (llvm#2773)
Browse files Browse the repository at this point in the history
  • Loading branch information
ikalinic authored Jan 31, 2024
1 parent 7301aa8 commit 54ef18c
Show file tree
Hide file tree
Showing 8 changed files with 165 additions and 0 deletions.
49 changes: 49 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1620,6 +1620,55 @@ def Torch_AtenLerp_TensorOp : Torch_Op<"aten.lerp_.Tensor", [
}];
}

def Torch_AtenLerpScalarOp : Torch_Op<"aten.lerp.Scalar", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::lerp.Scalar : (Tensor, Tensor, Scalar) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$end,
AnyTorchScalarType:$weight
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenLerpScalarOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 1);
}
void AtenLerpScalarOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
}

def Torch_AtenLerp_ScalarOp : Torch_Op<"aten.lerp_.Scalar", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::lerp_.Scalar : (Tensor, Tensor, Scalar) -> (Tensor)`";
let arguments = (ins
Torch_NonValueTensorType:$self,
Torch_NonValueTensorType:$end,
AnyTorchScalarType:$weight
);
let results = (outs
Torch_NonValueTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenLerp_ScalarOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 1);
}
void AtenLerp_ScalarOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
}

def Torch_AtenEqTensorOp : Torch_Op<"aten.eq.Tensor", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down
14 changes: 14 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8438,6 +8438,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %1 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %0) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %1 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.lerp.Scalar\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.float) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.addcmul\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.float) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg1, %arg2) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" %1 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %0) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
Expand Down Expand Up @@ -11198,6 +11202,16 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%3, %4) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %5 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.lerp.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.number) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %2 = torch.prim.ListConstruct %0#0, %1#0, %none : (!torch.int, !torch.int, !torch.none) -> !torch.list<optional<int>>\n"
" %3 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.get_dtype_of_scalar(%arg2) : (!torch.number) -> !torch.int\n"
" %4 = torch.prim.ListConstruct %0#1, %1#1, %3 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
" %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %4) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %5 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.addcmul\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.number) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
Expand Down
30 changes: 30 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1895,6 +1895,35 @@ class DecomposeAtenLeakyReluBackwardOp
};
} // namespace

namespace {
class DecomposeAtenLerpScalarOp : public OpRewritePattern<AtenLerpScalarOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenLerpScalarOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
auto resType = op.getType().cast<BaseTensorType>();
if (!resType.hasDtype()) {
return rewriter.notifyMatchFailure(op, "result should have dtype");
}
Value cstOne =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
auto start = op.getSelf();
auto inputType = start.getType().cast<BaseTensorType>();

auto delta = rewriter.create<AtenSubTensorOp>(loc, inputType, op.getEnd(),
start, cstOne);

auto weightedDelta =
rewriter.create<AtenMulScalarOp>(loc, inputType, delta, op.getWeight());
auto lerp = rewriter.create<AtenAddTensorOp>(loc, inputType, start,
weightedDelta, cstOne);
rewriter.replaceOp(op, lerp);
return success();
}
};
} // namespace

// Elu = scale * max(0,x) + alpha * scale * (exp(min(0,x) * input_scale) - 1)
namespace {
class DecomposeAtenEluOp : public OpRewritePattern<AtenEluOp> {
Expand Down Expand Up @@ -6763,6 +6792,7 @@ class DecomposeComplexOpsPass
addPatternIfTargetOpIsIllegal<DecomposeAtenSeluOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenLeakyReluOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenLeakyReluBackwardOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenLerpScalarOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenNewEmptyStridedOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenEmptyStridedOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenBucketizeTensorOp>(patterns);
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenNarrowTensorOp>();
target.addIllegalOp<Aten_EmbeddingBagOp>();
target.addIllegalOp<AtenLiftFreshCopyOp>();
target.addIllegalOp<AtenLerpScalarOp>();
target.addIllegalOp<AtenIndexTensorOp>();
target.addIllegalOp<AtenMseLossOp>();
target.addIllegalOp<AtenRandintLowOp>();
Expand Down
4 changes: 4 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1116,6 +1116,8 @@
"ElementwiseLeakyReluModule_basic",
"ElementwiseLeakyReluModule_basic",
"ElementwiseLeakyReluStaticModule_basic",
"ElementwiseLerpScalarIntModule_basic",
"ElementwiseLerpScalarFloatModule_basic",
"ElementwiseLog2Module_basic",
"ElementwiseLogModule_basic",
"ElementwiseLtDiffWidthScalarModule_basic",
Expand Down Expand Up @@ -1496,6 +1498,8 @@
"ElementwiseLogitModule_basic",
"ElementwiseRemainderScalarModule_Int_Float_basic",
"ElementwiseRemainderScalarModule_Bool_basic",
"ElementwiseLerpScalarIntModule_basic",
"ElementwiseLerpScalarFloatModule_basic",
"AtenIntTensorByteDtypeModule_basic",
"AtenIntTensorCharDtypeModule_basic",
"UpSampleNearest2dBackwardVec_basic",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1245,6 +1245,9 @@ def aten〇nan_to_num〡shape(self: List[int], nan: Optional[float] = None, posi
def aten〇lerp〇Tensor〡shape(self: List[int], end: List[int], weight: List[int]) -> List[int]:
return upstream_shape_functions.broadcast(self, upstream_shape_functions.broadcast(end, weight))

def aten〇lerp〇Scalar〡shape(self: List[int], end: List[int], weight: float) -> List[int]:
return upstream_shape_functions.broadcast(self, end)

def aten〇addcmul〡shape(self: List[int], tensor1: List[int], tensor2: List[int], value: float = 1) -> List[int]:
return upstream_shape_functions.broadcast(self, upstream_shape_functions.broadcast(tensor1, tensor2))

Expand Down Expand Up @@ -3313,6 +3316,27 @@ def aten〇lerp〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], end_rank_dtyp
dtypes = [self_dtype, end_dtype, weight_dtype]
return promote_dtypes(ranks, dtypes)

@check_dtype_function(
_check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1), (1, 1)], weight=0.5) +
# Different width
[Invocation(TensorOfShape(4, 3, dtype=torch.float32),
TensorOfShape(4, 3, dtype=torch.float64),
weight=0.5),
# Different type
Invocation(TensorOfShape(4, 3, dtype=torch.int32),
TensorOfShape(4, 3, dtype=torch.float32),
weight=0.5),
Invocation(TensorOfShape(4, 3, dtype=torch.float32),
TensorOfShape(4, 3, dtype=torch.float32),
weight=2)])
def aten〇lerp〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], end_rank_dtype: Tuple[int, int], weight: Union[int, float, complex]) -> int:
self_rank, self_dtype = self_rank_dtype
end_rank, end_dtype = end_rank_dtype

ranks: List[Optional[int]] = [self_rank, end_rank, None]
dtypes = [self_dtype, end_dtype, get_dtype_of_scalar(weight)]
return promote_dtypes(ranks, dtypes)

@check_dtype_function(
_check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1), (1, 1), (1, 1)], error_types={torch.bool}) +
# Different width
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,7 @@ def emit_with_mutating_variants(key, **kwargs):
"aten::logical_xor : (Tensor, Tensor) -> (Tensor)",
"aten::logical_not : (Tensor) -> (Tensor)",
"aten::lerp.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)",
"aten::lerp.Scalar : (Tensor, Tensor, Scalar) -> (Tensor)",
"aten::eq.Tensor : (Tensor, Tensor) -> (Tensor)",
"aten::gt.Tensor : (Tensor, Tensor) -> (Tensor)",
"aten::ge.Tensor : (Tensor, Tensor) -> (Tensor)",
Expand Down
42 changes: 42 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,6 +545,48 @@ def forward(self, x):
def ElementwiseLeakyReluStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(4, 5, 6, low=-1))


# ==============================================================================


class ElementwiseLerpScalarIntModule(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
([-1, -1], torch.float32, True),
])
def forward(self, a, b):
return torch.ops.aten.lerp(a, b, weight=2)

@register_test_case(module_factory=lambda: ElementwiseLerpScalarIntModule())
def ElementwiseLerpScalarIntModule_basic(module, tu: TestUtils):
module.forward(tu.rand(5,3), tu.rand(5,3))


class ElementwiseLerpScalarFloatModule(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
([-1, -1], torch.float32, True),
])
def forward(self, a, b):
return torch.ops.aten.lerp(a, b, weight=0.5)

@register_test_case(module_factory=lambda: ElementwiseLerpScalarFloatModule())
def ElementwiseLerpScalarFloatModule_basic(module, tu: TestUtils):
module.forward(tu.rand(5,3), tu.rand(5,3))


# ==============================================================================


Expand Down

0 comments on commit 54ef18c

Please sign in to comment.