[MLIR][TORCH] Extend aten.div.Tensor support for integer dtype
Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
vivekkhandelwal1 committed Feb 20, 2024
1 parent 135c81a commit 124bd23
Showing 4 changed files with 66 additions and 5 deletions.
lib/Conversion/TorchToLinalg/Uncategorized.cpp (9 additions, 5 deletions)
@@ -725,13 +725,17 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     Type dtype = converter->convertType(div.getType())
                      .cast<RankedTensorType>()
                      .getElementType();
-    if (!dtype.isa<mlir::FloatType>()) {
-      div.emitError("unimplemented: non-floating point dtype");
-      return nullptr;
-    }
     Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
     Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
-    return b.create<arith::DivFOp>(loc, lhs, rhs);
+    if (dtype.isa<mlir::FloatType>())
+      return b.create<arith::DivFOp>(loc, lhs, rhs);
+    else if (dtype.isa<mlir::IntegerType>()) {
+      if (dtype.isUnsignedInteger())
+        return b.create<arith::DivUIOp>(loc, lhs, rhs);
+      return b.create<arith::DivSIOp>(loc, lhs, rhs);
+    }
+    div.emitError("unimplemented: non-floating point and non-integer dtype");
+    return nullptr;
   }
   if (auto divTensorMode = dyn_cast<AtenDivTensorModeOp>(op)) {
     AtenDivTensorModeOp::Adaptor adaptor(operands);
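The dispatch above must distinguish arith.DivUIOp from arith.DivSIOp because the same bit pattern divides to different results under unsigned and signed interpretation. A minimal Python sketch of the two semantics on 8-bit values (illustrative only, not part of the commit):

# Signed 8-bit division truncating toward zero, mirroring arith.divsi.
def divsi8(a, b):
    to_signed = lambda x: x - 256 if x >= 128 else x  # reinterpret bits as signed
    sa, sb = to_signed(a), to_signed(b)
    q = abs(sa) // abs(sb)
    if (sa < 0) != (sb < 0):
        q = -q
    return q & 0xFF

# Unsigned 8-bit division, mirroring arith.divui.
def divui8(a, b):
    return (a // b) & 0xFF

assert divsi8(0xF6, 0x03) == 0xFD  # -10 / 3 == -3 after truncation toward zero
assert divui8(0xF6, 0x03) == 0x52  # 246 / 3 == 82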
projects/pt1/e2e_testing/xfail_sets.py (2 additions, 0 deletions)
@@ -989,6 +989,8 @@
     "ElementwiseCloneContiguousModule_basic",
     "ElementwiseCloneModule_basic",
     "ElementwiseDivScalarModule_basic",
+    "ElementwiseDivTensorIntegerModule_basic",
+    "ElementwiseDivTensorUnsignedIntegerModule_basic",
     "ElementwiseEluModule_basic",
     "ElementwiseEluNonDefaultModule_basic",
     "ElementwiseEqBoolScalarModule_basic",
projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py (46 additions, 0 deletions)
@@ -2595,6 +2595,52 @@ def ElementwiseDivTensorFloatModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class ElementwiseDivTensorIntegerModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int64, True),
+        ([-1, -1], torch.int32, True),
+    ])
+    def forward(self, a, b):
+        return torch.div(a, b)
+
+
+@register_test_case(module_factory=lambda: ElementwiseDivTensorIntegerModule())
+def ElementwiseDivTensorIntegerModule_basic(module, tu: TestUtils):
+    module.forward(tu.randint(3, 4, low=-10, high=10), tu.randint(3, 4, low=-10, high=10).type(torch.int32))
+
+
+# ==============================================================================
+
+
+class ElementwiseDivTensorUnsignedIntegerModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.uint8, True),
+        ([-1, -1], torch.uint8, True),
+    ])
+    def forward(self, a, b):
+        return torch.div(a, b)
+
+
+@register_test_case(module_factory=lambda: ElementwiseDivTensorUnsignedIntegerModule())
+def ElementwiseDivTensorUnsignedIntegerModule_basic(module, tu: TestUtils):
+    module.forward(tu.randint(3, 4, low=0, high=10).to(torch.uint8), tu.randint(3, 4, low=0, high=10).type(torch.uint8))
+
+
+# ==============================================================================
+
+
 class ElementwiseDivRoundingModeTruncModule(torch.nn.Module):
 
     def __init__(self):
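For reference, a minimal eager-mode sketch of the arithmetic the new integer test performs (assuming only stock PyTorch; tu.randint in the e2e harness is analogous to torch.randint here). Note that torch.div with the default rounding mode applies true division, so integer inputs promote to a floating-point result in eager PyTorch:

import torch

a = torch.randint(-10, 10, (3, 4), dtype=torch.int64)
b = torch.randint(-10, 10, (3, 4), dtype=torch.int32)
out = torch.div(a, b)  # true division; entries where b == 0 become inf or nan
print(out.dtype)       # torch.float32 under default type promotion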
test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir (9 additions, 0 deletions)
@@ -706,6 +706,15 @@ func.func @test_div(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3
 
 // -----
 
+// CHECK-LABEL: @test_div_int32
+func.func @test_div_int32(%arg0: !torch.vtensor<[3,4,5],si32>, %arg1: !torch.vtensor<[3,4,5],si32>) -> !torch.vtensor<[3,4,5],si32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],si32>, !torch.vtensor<[3,4,5],si32> -> !torch.vtensor<[3,4,5],si32>
+  %0 = torch.operator "onnx.Div"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],si32>, !torch.vtensor<[3,4,5],si32>) -> !torch.vtensor<[3,4,5],si32>
+  return %0 : !torch.vtensor<[3,4,5],si32>
+}
+
+// -----
+
 // CHECK-LABEL: @test_div_uint8
 func.func @test_div_uint8(%arg0: !torch.vtensor<[3,4,5],ui8>, %arg1: !torch.vtensor<[3,4,5],ui8>) -> !torch.vtensor<[3,4,5],ui8> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],ui8>, !torch.vtensor<[3,4,5],ui8> -> !torch.vtensor<[3,4,5],ui8>
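To inspect a lowering like this locally, one option is the pt1-era torch_mlir.compile entry point; a sketch under the assumption that this API is available in the build (torch_mlir.compile and the "linalg-on-tensors" output_type reflect the pt1 Python API and are not part of this commit):

import torch
import torch_mlir  # pt1-era Python API; availability is an assumption

class Div(torch.nn.Module):
    def forward(self, a, b):
        return torch.div(a, b)

example_args = (torch.randint(1, 10, (3, 4), dtype=torch.int32),
                torch.randint(1, 10, (3, 4), dtype=torch.int32))
compiled = torch_mlir.compile(Div(), example_args, output_type="linalg-on-tensors")
print(compiled)  # dump the lowered IR for inspection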
