[LINALG] Add complex tensor support for create[Zero|One]InitTensor utility (#3777)

Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
vivekkhandelwal1 authored Oct 9, 2024
1 parent d49eabb commit 94f5410
Showing 3 changed files with 39 additions and 6 deletions.
18 changes: 12 additions & 6 deletions lib/Conversion/Utils/Utils.cpp
@@ -132,19 +132,25 @@ Value createZeroInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
                            Type elemTy) {
   Value initTensor =
       b.create<tensor::EmptyOp>(loc, getAsOpFoldResult(sizes), elemTy);
-  RankedTensorType type = cast<RankedTensorType>(initTensor.getType());
-  Value c0 =
-      b.create<arith::ConstantOp>(loc, b.getZeroAttr(type.getElementType()));
+
+  Type fillValElemTy = elemTy;
+  if (auto dtypeComplex = dyn_cast<mlir::ComplexType>(elemTy))
+    fillValElemTy = cast<mlir::FloatType>(dtypeComplex.getElementType());
+
+  Value c0 = b.create<arith::ConstantOp>(loc, b.getZeroAttr(fillValElemTy));
   return b.create<linalg::FillOp>(loc, c0, initTensor).getResult(0);
 }
 
 Value createOneInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
                           Type elemTy) {
   Value initTensor =
       b.create<tensor::EmptyOp>(loc, getAsOpFoldResult(sizes), elemTy);
-  RankedTensorType type = cast<RankedTensorType>(initTensor.getType());
-  Value c1 =
-      b.create<arith::ConstantOp>(loc, b.getOneAttr(type.getElementType()));
+
+  Type fillValElemTy = elemTy;
+  if (auto dtypeComplex = dyn_cast<mlir::ComplexType>(elemTy))
+    fillValElemTy = cast<mlir::FloatType>(dtypeComplex.getElementType());
+
+  Value c1 = b.create<arith::ConstantOp>(loc, b.getOneAttr(fillValElemTy));
   return b.create<linalg::FillOp>(loc, c1, initTensor).getResult(0);
 }

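For context, here is a minimal, hypothetical call-site sketch of the updated helper. It is not part of this commit; it assumes a conversion pattern in which rewriter, loc, and a ValueRange of dimension sizes named sizes are already in scope, and it uses only the signature shown in the diff above.

// Hypothetical usage (illustrative names, not from this commit):
// zero-initialize an accumulator tensor with a complex<f32> element type.
Type complexTy = mlir::ComplexType::get(rewriter.getF32Type());
Value acc = createZeroInitTensor(rewriter, loc, sizes, complexTy);
// With this patch the fill constant is an f32 zero derived from the complex
// type's underlying float element type; previously getZeroAttr was applied to
// the tensor's element type directly, which is not supported for complex types.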
4 changes: 4 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -1423,6 +1423,7 @@
     "SliceSizeTwoStepModule_basic",
     "SliceStartEqEndModule_basic",
     "SliceStaticModule_basic",
+    "SliceStaticComplexInputModule_basic",
     "SliceWholeTensorModule_basic",
     "SortIntListReverse_basic",
     "SortIntList_basic",
@@ -2618,6 +2619,7 @@
     "SliceCopyNegative_Module_basic",
     "SliceCopyNonZeroDim_Module_basic",
     "SliceCopy_Module_basic",
+    "SliceStaticComplexInputModule_basic",
     "StdCorrectionLargeInputModule_basic",
     "TupleModule_basic",
     "VarCorrectionLargeInputModule_basic",
@@ -3778,6 +3780,7 @@
     "SignAndLogarithmOfDeterminantModule_F32",
     "SignAndLogarithmOfDeterminantBatchedModule_F32",
     "SignAndLogarithmOfDeterminantDynamicModule_F32",
+    "SliceStaticComplexInputModule_basic",
     "SliceCopyEndGreaterThanDimSize_Module_basic",
     "SliceCopyNegative_Module_basic",
     "SliceCopyNonZeroDim_Module_basic",
@@ -4714,6 +4717,7 @@
     "SliceCopy_Module_basic",
     "SliceEndSleStartModule_basic",
     "SliceModule_basic",
+    "SliceStaticComplexInputModule_basic",
     "SliceNegIdxModule_basic",
     "SliceOutOfLowerBoundEndIndexModule_basic",
     "SliceOutOfLowerBoundStartIndexModule_basic",
23 changes: 23 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/slice_like.py
@@ -58,6 +58,29 @@ def SliceStaticModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class SliceStaticComplexInputModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([6, 4, 7], torch.complex64, True),
+        ]
+    )
+    def forward(self, x):
+        return x[0:5:1, 1:3:1, 2:4:1]
+
+
+@register_test_case(module_factory=lambda: SliceStaticComplexInputModule())
+def SliceStaticComplexInputModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(6, 4, 7).to(torch.complex64))
+
+
+# ==============================================================================
+
+
 class SliceOutOfUpperBoundIndexModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
