[Torch] Add decomposition for 1d torch.nonzero (#3876)
2d static nonzero also works, but the 2d dynamic case still needs to be fixed next.
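
For reference, a minimal standalone PyTorch sketch of the decomposition, mirroring the algorithm documented in DecomposeAtenNonzeroOp below (the helper name decomposed_nonzero and the sample inputs are illustrative, not part of this change):

import torch

def decomposed_nonzero(t):
    # Flatten to 1d and mark the nonzero positions.
    t_flat = t.flatten()
    nonzero_mask = (t_flat != 0).long()
    # Running count of nonzeros assigns each nonzero element its output slot.
    destination_indices = torch.clamp(torch.cumsum(nonzero_mask, 0) - 1, min=0)
    # Flat index of every element, zeroed where the element itself is zero.
    iota = torch.arange(t_flat.size(0)) * nonzero_mask
    # Compact the surviving flat indices to the front of a zero buffer.
    compacted = torch.scatter_add(
        torch.zeros_like(t_flat, dtype=torch.int64),
        0, destination_indices, iota)
    result_flat = compacted[: torch.sum(nonzero_mask)]
    # Convert flat indices back to per-dimension indices via row-major strides.
    shape = torch.tensor(t.shape)
    strides = torch.cumprod(shape.flip(0), 0).flip(0)
    one = torch.tensor([1])
    strides = torch.cat([strides[1:], one]) if t.dim() > 1 else one
    return (result_flat.unsqueeze(1) // strides.unsqueeze(0)) % shape

print(decomposed_nonzero(torch.tensor([0, 0, 1, 1, 0, 0])))  # tensor([[2], [3]])
print(decomposed_nonzero(torch.tensor([[0, 1], [1, 0]])))    # tensor([[0, 1], [1, 0]])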
AmosLewis authored Dec 19, 2024
1 parent 061bbc5 commit 51da49c
Showing 3 changed files with 260 additions and 1 deletion.
235 changes: 235 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -5705,6 +5705,240 @@ class DecomposeAtenConvolutionBackwardOp
};
} // namespace

/**
* # one dim input
* t = torch.tensor([0, 0, 1, 1, 0, 0])
* # t_flat:[0, 0, 1, 1, 0, 0]
* t_flat = t.flatten(0, 0)
* nonzero_mask = t_flat != 0
* # nonzero_mask:[0, 0, 1, 1, 0, 0]
* nonzero_mask = nonzero_mask.long()
* # destination_indices:[-1, -1, 0, 1, 1, 1]
* destination_indices = torch.cumsum(nonzero_mask, 0) - 1
* # destination_indices_clamp:[0, 0, 0, 1, 1, 1]
* destination_indices_clamp = torch.clamp(destination_indices, min=0)
* # iota:[0, 0, 2, 3, 0, 0]
* iota = torch.arange(t_flat.size(0)) * nonzero_mask
* # scatter_self:[0, 0, 0, 0, 0, 0]
* scatter_self = torch.zeros_like(t_flat, dtype=torch.int64)
* # compacted:[2, 3, 0, 0, 0, 0]
* compacted = torch.scatter_add(
* scatter_self, dim=0, index=destination_indices_clamp, src=iota
* )
* # result_flat:[2, 3]
* result_flat = compacted[: torch.sum(nonzero_mask)]
*
* # multi dim support
* original_shape = t.shape
* # input_shape_tensor:[6]
* input_shape_tensor = torch.tensor(original_shape)
* strides = torch.cumprod(torch.flip(input_shape_tensor, [0]), 0).flip(0)
*
* one = torch.tensor([1])
* if t.dim() > 1:
*     slicedStrides = strides[1:]
*     strides = torch.cat([slicedStrides, one])
* else:
*     strides = one
* # a: tensor([[2], [3]]) torch.Size([2, 1])
* a = result_flat.unsqueeze(1) # tensor([[2], [3]]) torch.Size([2, 1])
* # b: tensor([[1]]) torch.Size([1, 1])
* b = strides.unsqueeze(0)
* # c: tensor([[2], [3]]) torch.Size([2, 1])
* c = a // b
* # result: tensor([[2], [3]]) torch.Size([2, 1])
* result = c % input_shape_tensor
*/
class DecomposeAtenNonzeroOp : public OpRewritePattern<AtenNonzeroOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenNonzeroOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
auto resultType = cast<BaseTensorType>(op.getType());
auto intType = resultType.getDtype();
Value intTypeValue = getDtypeIntValueForType(rewriter, loc, intType);
auto constantZero =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
auto constantOne =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
std::function<Value(Value)> makeOneElementList = [&](Value element) {
auto listType = Torch::ListType::get(element.getType());
return rewriter.create<PrimListConstructOp>(loc, listType,
ArrayRef<Value>{element});
};

Value input = op.getSelf();
auto inputType = dyn_cast<BaseTensorType>(input.getType());
int64_t inputRank = inputType.getSizes().size();

// t_flat = t.flatten() # torch.flatten(t, 0, 0)
int64_t flattenedSize = 1;
if (inputType.hasSizes()) {
for (auto size : inputType.getSizes()) {
flattenedSize *= size;
}
} else {
flattenedSize = kUnknownSize;
}

auto flattenedInputShape = SmallVector<int64_t>{flattenedSize};
auto flattenedInputType = rewriter.getType<Torch::ValueTensorType>(
flattenedInputShape, inputType.getOptionalDtype());

// %1 = torch.aten.flatten.using_ints %arg0, %int0, %int0_0 :
auto inputDimsEnd = rewriter.create<ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(inputRank - 1));
Value flattenedInput = rewriter.create<AtenFlattenUsingIntsOp>(
loc, flattenedInputType, input, constantZero /*inputDimsStart*/,
inputDimsEnd /*inputDimsEnd*/);

// nonzero_mask = (t_flat != 0)
auto boolMaskType = inputType.getWithSizesAndDtype(
flattenedInputType.getOptionalSizes(), rewriter.getI1Type());
Value boolMask = rewriter.create<AtenNeScalarOp>(
loc, boolMaskType, flattenedInput, constantZero);

// nonzero_mask = nonzero_mask.int()
Value falseCst = rewriter.create<ConstantBoolOp>(loc, false);
Value noneCst = rewriter.create<ConstantNoneOp>(loc);
auto intMaskType = flattenedInputType.getWithSizesAndDtype(
flattenedInputType.getOptionalSizes(), intType);
Value intMask = rewriter.create<AtenToDtypeOp>(
loc, intMaskType, boolMask, intTypeValue, falseCst, falseCst, noneCst);

// destination_indices = torch.cumsum(nonzero_mask, 0) - 1
Value cumulativeSum = rewriter.create<AtenCumsumOp>(
loc, intMaskType, intMask, constantZero, noneCst);
Value subtracted = rewriter.create<AtenSubScalarOp>(
loc, intMaskType, cumulativeSum, constantOne, /*alpha=*/constantOne);

// destination_indices = torch.clamp(destination_indices, min=0)
Value indices = rewriter.create<AtenClampMinOp>(loc, intMaskType,
subtracted, constantZero);

// iota = torch.arange(len(t_flat)) * nonzero_mask
Value end = rewriter.create<AtenSizeIntOp>(loc, flattenedInput,
/*dim=*/constantZero);
Value rangeTensor = rewriter.create<AtenArangeStartStepOp>(
loc, intMaskType, /*start*/ constantZero, /*end*/ end,
/*step*/ constantOne, noneCst, noneCst, noneCst, noneCst);
Value multiplied = rewriter.create<AtenMulTensorOp>(loc, intMaskType,
rangeTensor, intMask);

// scatter_self = torch.zeros_like(t, dtype=torch.int64)
// AtenFullLike doesn't support index type so we have to use int.
Value zerosTensor = rewriter.create<AtenZerosLikeOp>(
loc, intMaskType, flattenedInput, intTypeValue, noneCst, noneCst,
noneCst, noneCst);

// compacted = torch.scatter_add(
// scatter_self, dim=0, index=destination_indices_clamp, src=iota)
Value scatteredTensor = rewriter.create<AtenScatterAddOp>(
loc, intMaskType, /*self*/ zerosTensor, /*dim=*/constantZero,
/*index=*/indices, /*src=*/multiplied);

// result_flat = compacted[:torch.sum(nonzero_mask)]
auto scalarType = ValueTensorType::get(rewriter.getContext(),
ArrayRef<int64_t>{}, intType);
Value sumMask =
rewriter.create<AtenSumOp>(loc, scalarType, intMask, noneCst);
Value numNonzero = rewriter.create<AtenIntTensorOp>(loc, sumMask);

auto slicedResultType = Torch::ValueTensorType::get(
rewriter.getContext(), SmallVector<int64_t>{kUnknownSize}, intType);
Value slicedResult =
rewriter.create<AtenSliceTensorOp>(loc, slicedResultType,
/*self=*/scatteredTensor,
/*dim=*/constantZero,
/*start=*/noneCst,
/*end=*/numNonzero,
/*step=*/constantOne);

// TODO: Fix multidim dynamic support. The following code only works for
// the static multidim case.
// Convert flattened indices back to multi-dimensional indices:
//   original_shape = t.shape
//   input_shape_tensor = torch.tensor(original_shape)
auto shapeType = Torch::ValueTensorType::get(
rewriter.getContext(), SmallVector<int64_t>{inputRank}, intType);
SmallVector<Value> shapeValues;
for (int i = 0; i < inputRank; i++) {
auto constantI =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i));
Value shape = rewriter.create<AtenSizeIntOp>(loc, input,
/*dim=*/constantI);
shapeValues.push_back(shape);
}
Value shapeTensorList = rewriter.create<Torch::PrimListConstructOp>(
loc, Torch::ListType::get(shapeValues[0].getType()), shapeValues);
Value inputShapeTensor = rewriter.create<Torch::AtenTensorOp>(
loc, shapeType, shapeTensorList, noneCst, noneCst, falseCst);

// strides = torch.cumprod(torch.flip(input_shape_tensor,[0]),0).flip(0)
Value flippedShape = rewriter.create<AtenFlipOp>(
loc, shapeType, inputShapeTensor, makeOneElementList(constantZero));
Value cumulativeProduct = rewriter.create<AtenCumprodOp>(
loc, shapeType, flippedShape, constantZero, noneCst);
Value flippedCumulativeProduct = rewriter.create<AtenFlipOp>(
loc, shapeType, cumulativeProduct, makeOneElementList(constantZero));

// strides = torch.cat([strides[1:], torch.tensor([1])])
auto oneTensorType = ValueTensorType::get(rewriter.getContext(),
SmallVector<int64_t>{1}, intType);
Value oneTensor = rewriter.create<AtenScalarTensorOp>(
loc, oneTensorType, constantOne, intTypeValue, noneCst, noneCst,
noneCst);

Value strides;
if (inputRank > 1) {
// strides[1:]
auto slicedStrideType = Torch::ValueTensorType::get(
rewriter.getContext(), SmallVector<int64_t>{inputRank - 1}, // sizes
intType);
Value strideSliceEnd = rewriter.create<ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(inputRank));
Value slicedStrides = rewriter.create<AtenSliceTensorOp>(
loc, slicedStrideType, /*self*/ flippedCumulativeProduct,
/*dim*/ constantZero,
/*start=*/constantOne, /*end=*/strideSliceEnd, /*step=*/constantOne);
// torch.cat
auto tensorListElementType = Torch::ValueTensorType::get(
rewriter.getContext(), SmallVector<int64_t>{kUnknownSize}, intType);
Value tensorList = rewriter.create<Torch::PrimListConstructOp>(
loc, Torch::ListType::get(tensorListElementType),
SmallVector<Value>{slicedStrides, oneTensor});
strides = rewriter.create<Torch::AtenCatOp>(loc, shapeType, tensorList,
constantZero);
} else {
// strides[1:] is empty for 1-d input, so strides is just [1]
strides = oneTensor;
}

// multi_indices = (result_flat.unsqueeze(1) // strides.unsqueeze(0)) %
// input_shape_tensor
auto unsqueezedResultType = ValueTensorType::get(
rewriter.getContext(), SmallVector<int64_t>{kUnknownSize, 1}, intType);
Value unsqueezedResult = rewriter.create<AtenUnsqueezeOp>(
loc, unsqueezedResultType, slicedResult, constantOne);

auto unsqueezedStridesType = ValueTensorType::get(
rewriter.getContext(), SmallVector<int64_t>{1, inputRank}, intType);
Value unsqueezedStrides = rewriter.create<AtenUnsqueezeOp>(
loc, unsqueezedStridesType, strides, constantZero);

auto dividedBroadcastType = ValueTensorType::get(
rewriter.getContext(), SmallVector<int64_t>{kUnknownSize, inputRank},
intType);
Value divided = rewriter.create<AtenFloorDivideOp>(
loc, dividedBroadcastType, unsqueezedResult, unsqueezedStrides);

Value modded = rewriter.create<AtenRemainderTensorOp>(
loc, resultType, divided, inputShapeTensor);

rewriter.replaceOp(op, modded);
return success();
}
};

// Decompose aten.addmm into aten.mm and aten.add.Tensor op.
namespace {
class DecomposeAtenAddmmOp : public OpRewritePattern<AtenAddmmOp> {
@@ -11263,6 +11497,7 @@ class DecomposeComplexOpsPass
addPatternIfTargetOpIsIllegal<DecomposeAten_SoftmaxBackwardDataOp>(
patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenTanhBackwardOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenNonzeroOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenAddmmOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenMeanOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenMeanDimOp>(patterns);
3 changes: 2 additions & 1 deletion projects/pt1/e2e_testing/xfail_sets.py
@@ -399,6 +399,7 @@
"AtenIntBoolOpModule_basic",
"AtenIntMM_basic",
"AtenItemFpOpModule_basic",
"AtenNonzero1DDynamicModule_basic", # no lowering for torch.aten.sym_constrain_range_for_size
"Aten_TrilinearModuleVaryingRanks_basic",
"Aten_TrilinearModuleZerodDimBug_basic",
"QuantizedReluInt32_basic",
@@ -628,6 +629,7 @@
"AtenMmQMixedSigni8_basic",
"AtenMmQint8_basic",
"AtenMmQuint8_basic",
"AtenNonzero1DDynamicModule_basic",
"AtenRealView128Module_basic",
"AtenRealView64Module_basic",
"AtenTopKModule_basic",
@@ -3018,7 +3020,6 @@
"LinalgNormKeepDimComplexModule_basic",
"LinalgVectorNormComplexModule_basic",
"LogSoftmaxBackwardModule_basic",
"MaskedScatterStaticBasic_basic",
"MaxPool1dCeilModeTrueModule_basic",
"MaxPool1dModule_basic",
"MaxPool2dCeilModeTrueModule_basic",
23 changes: 23 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
@@ -6430,3 +6430,26 @@ def AtenPolarDoubleModule_basic(module, tu: TestUtils):
module.forward(
tu.rand(2, 5, 3, 4).to(torch.float64), tu.rand(2, 5, 3, 4).to(torch.float64)
)


# ==============================================================================


class AtenNonzero1DDynamicModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([-1], torch.bool, True),
]
)
def forward(self, x):
return torch.ops.aten.nonzero(x)


@register_test_case(module_factory=lambda: AtenNonzero1DDynamicModule())
def AtenNonzero1DDynamicModule_basic(module, tu: TestUtils):
module.forward(torch.tensor([0, 0, 1, 1, 0, 0], dtype=torch.bool))
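
For completeness, a small eager-mode check of the 2d static case mentioned in the commit message (illustrative values, not part of this change), exercising the same strides-based reconstruction of multi-dimensional indices:

import torch

t = torch.tensor([[0, 5], [7, 0], [0, 9]])
flat = torch.nonzero(t.flatten()).squeeze(1)            # flat indices: [1, 2, 5]
shape = torch.tensor(t.shape)                            # [3, 2]
strides = torch.cumprod(shape.flip(0), 0).flip(0)        # [6, 2]
strides = torch.cat([strides[1:], torch.tensor([1])])    # [2, 1]
multi = (flat.unsqueeze(1) // strides.unsqueeze(0)) % shape
assert torch.equal(multi, torch.nonzero(t))               # [[0, 1], [1, 0], [2, 1]]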
