[torch] torch.aten.complex operation with lowering (#3738)
Add the operation with lowering to linalg. Includes a test for
end-to-end correctness.
rsuderman authored Oct 3, 2024
1 parent f0b7ca7 commit 9ab0db5
Showing 5 changed files with 88 additions and 17 deletions.
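
For context, a minimal sketch of the PyTorch-level semantics this commit lowers, using only stock PyTorch (not part of the diff below):

import torch

real = torch.tensor([1.0, 2.0], dtype=torch.float32)
imag = torch.tensor([3.0, 4.0], dtype=torch.float32)

# torch.complex pairs the two float tensors elementwise:
# out[i] = real[i] + imag[i]*1j; float32 inputs produce complex64.
out = torch.complex(real, imag)
assert out.dtype == torch.complex64
print(out)  # tensor([1.+3.j, 2.+4.j])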
24 changes: 24 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -5122,6 +5122,30 @@ def Torch_AtenRad2degOp : Torch_Op<"aten.rad2deg", [
}];
}

def Torch_AtenComplexOp : Torch_Op<"aten.complex", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::complex : (Tensor, Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$real,
AnyTorchTensorType:$imag
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenComplexOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenComplexOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}

def Torch_AtenRealOp : Torch_Op<"aten.real", [
AllowsTypeRefinement,
ReadOnly
44 changes: 27 additions & 17 deletions lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -575,6 +575,16 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
b.create<arith::ConstantOp>(loc, b.getFloatAttr(floatDtype, 0));
return createEqual(b, loc, floatDtype, self, zero);
}
if (auto complex = dyn_cast<AtenComplexOp>(op)) {
auto ctype = cast<ComplexType>(
cast<RankedTensorType>(converter->convertType(complex.getType()))
.getElementType());
Type stype = ctype.getElementType();

Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], stype);
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], stype);
return b.create<complex::CreateOp>(loc, ctype, lhs, rhs);
}
if (isa<AtenAbsOp>(op)) {
if (isa<IntegerType>(payloadArgs[0].getType()))
return b.create<math::AbsIOp>(loc, payloadArgs[0]);
@@ -1590,22 +1600,22 @@ class ConvertElementwiseOp : public ConversionPattern {
AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op,
AtenLog10Op, AtenLog1pOp, AtenRsqrtOp, AtenDivScalarOp,
AtenRemainderScalarOp, AtenRemainderTensorOp, AtenAbsOp,
AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenBitwiseAndScalarOp,
AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp,
AtenBitwiseLeftShiftTensorOp, AtenBitwiseRightShiftTensorOp,
Aten__Lshift__ScalarOp, Aten__Rshift__ScalarOp, AtenGtScalarOp,
AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp,
AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp, AtenGeTensorOp,
AtenEqTensorOp, AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp,
AtenSubScalarOp, AtenAddScalarOp, AtenThresholdOp,
AtenThresholdBackwardOp, AtenHardtanhBackwardOp, AtenCloneOp,
AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenNegOp,
AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp,
AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, AtenTriuOp,
AtenTrilOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp,
AtenFillTensorOp, AtenAtanOp, AtenAcosOp, AtenAtanhOp, AtenAcoshOp,
AtenAsinOp, AtenAsinhOp, AtenRealOp, AtenImagOp,
AtenDequantizeSelfOp, AtenDequantizeTensorOp,
AtenComplexOp, AtenReciprocalOp, AtenBitwiseAndTensorOp,
AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp,
AtenBitwiseXorTensorOp, AtenBitwiseLeftShiftTensorOp,
AtenBitwiseRightShiftTensorOp, Aten__Lshift__ScalarOp,
Aten__Rshift__ScalarOp, AtenGtScalarOp, AtenGeScalarOp,
AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp,
AtenCeilOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp,
AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp, AtenSubScalarOp,
AtenAddScalarOp, AtenThresholdOp, AtenThresholdBackwardOp,
AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp,
AtenNeScalarOp, AtenNegOp, AtenMaskedFillTensorOp, AtenLogicalOrOp,
AtenLogicalAndOp, AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp,
AtenTriuOp, AtenTrilOp, AtenBitwiseNotOp, AtenRoundOp,
AtenFillScalarOp, AtenFillTensorOp, AtenAtanOp, AtenAcosOp,
AtenAtanhOp, AtenAcoshOp, AtenAsinOp, AtenAsinhOp, AtenRealOp,
AtenImagOp, AtenDequantizeSelfOp, AtenDequantizeTensorOp,
AtenQuantizePerTensorOp, AtenIscloseOp>(op))
return rewriter.notifyMatchFailure(op, "not a supported elementwise op");

@@ -3351,7 +3361,7 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
AtenClampTensorOp, AtenRsubScalarOp, AtenLogOp, AtenErfOp, AtenSqrtOp,
AtenFloorOp, AtenCeilOp, AtenPreluOp, AtenPowScalarOp,
AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op, AtenLog10Op,
AtenLog1pOp, AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp,
AtenLog1pOp, AtenRsqrtOp, AtenAbsOp, AtenComplexOp, AtenReciprocalOp,
AtenBitwiseAndTensorOp, AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp,
AtenBitwiseXorTensorOp, AtenBitwiseLeftShiftTensorOp,
AtenBitwiseRightShiftTensorOp, Aten__Lshift__ScalarOp,
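A hedged sketch of exercising the new lowering from Python; the fx.export_and_import entry point and the "linalg-on-tensors" output type are assumptions about the torch-mlir fx importer at this revision, not something this commit adds:

import torch
from torch_mlir import fx  # assumed import path

class CreateComplex(torch.nn.Module):
    def forward(self, a, b):
        return torch.complex(a, b)

# With the payload case added above, the lowered linalg.generic builds each
# output element with complex.create from the two converted float scalars.
m = fx.export_and_import(
    CreateComplex(), torch.rand(4), torch.rand(4),
    output_type="linalg-on-tensors",
)
print(m)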
9 changes: 9 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -2751,6 +2751,7 @@
"ElementwiseExpm1IntModule_basic",
"ElementwiseExpm1Module_basic",
"ElementwiseFmodTensor_Int_basic",
"ElementwiseCreateComplexModule_basic",
"ElementwiseMulTensorComplexModule_basic",
"ElementwiseMulTensorComplexDiffModule_basic",
"ElementwiseOrTensorModule_basic",
@@ -3165,6 +3166,14 @@
"AtenIntMM_basic",
}

if torch_version_for_comparison() > version.parse("2.4.0.dev"):
STABLEHLO_PASS_SET = STABLEHLO_PASS_SET - {
"ElementwiseCreateComplexModule_basic",
}
FX_IMPORTER_STABLEHLO_XFAIL_SET = FX_IMPORTER_STABLEHLO_XFAIL_SET | {
"ElementwiseCreateComplexModule_basic",
}
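
The gate relies on packaging-style version ordering, under which any 2.5 dev build compares greater than "2.4.0.dev" while a dev tag still sorts below its own release; a standalone illustration (the version strings are examples only):

from packaging import version

assert version.parse("2.5.0.dev20240902") > version.parse("2.4.0.dev")
assert version.parse("2.4.0.dev") < version.parse("2.4.0")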


ONNX_CRASHING_SET = LINALG_CRASHING_SET | {
"FakeQuantizePerTensorAffineModule_basic",
1 change: 1 addition & 0 deletions projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -492,6 +492,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::softplus : (Tensor, Scalar, Scalar) -> (Tensor)")
emit("aten::prelu : (Tensor, Tensor) -> (Tensor)")
emit("aten::rad2deg : (Tensor) -> (Tensor)")
emit("aten::complex : (Tensor, Tensor) -> (Tensor)")
emit("aten::real : (Tensor) -> (Tensor)")
emit("aten::imag : (Tensor) -> (Tensor)")
emit("aten::view_as_complex : (Tensor) -> (Tensor)")
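This emit line is the registry entry from which the Torch_AtenComplexOp definition shown above in GeneratedTorchOps.td is produced; in this repository the .td file is regenerated (typically via build_tools/update_torch_ods.sh) rather than edited by hand. The signature string's two Tensor operands and single Tensor result are what yield the (2, 1) arguments to parseDefaultTorchOp and printDefaultTorchOp in the generated op.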
27 changes: 27 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
@@ -2012,6 +2012,33 @@ def ElementwiseMulTensorIntModule_basic(module, tu: TestUtils):
# ==============================================================================


class ElementwiseCreateComplexModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([-1], torch.float32, True),
([-1], torch.float32, True),
]
)
def forward(self, a, b):
return torch.complex(a, b)


@register_test_case(module_factory=lambda: ElementwiseCreateComplexModule())
def ElementwiseCreateComplexModule_basic(module, tu: TestUtils):
module.forward(
tu.randint(4, high=10).type(torch.float32),
tu.randint(4, high=10).type(torch.float32),
)
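
A quick standalone sanity check of the module outside the e2e harness, assuming the imports already present in this test file; the integer-valued float inputs used by the test keep the expected complex outputs exact:

m = ElementwiseCreateComplexModule()
print(m.forward(torch.tensor([5.0, 6.0]), torch.tensor([7.0, 8.0])))
# tensor([5.+7.j, 6.+8.j])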


# ==============================================================================


class ElementwiseMulTensorComplexModule(torch.nn.Module):
def __init__(self):
super().__init__()
