From 9dd5ae8239e2d45fd8959b8b014bf0c76be16d31 Mon Sep 17 00:00:00 2001
From: Vivek Khandelwal
Date: Fri, 30 Sep 2022 20:03:41 +0530
Subject: [PATCH] [tosa] Add TorchToTosa lowering for aten.arange.start_step
 op (#1442)

---
 e2e_testing/xfail_sets.py                  |  7 +++
 lib/Conversion/TorchToTosa/TorchToTosa.cpp | 51 ++++++++++++++++++++++
 test/Conversion/TorchToTosa/basic.mlir     | 19 ++++++++
 3 files changed, 77 insertions(+)

diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py
index eee0c011837e..231b3cbe6f20 100644
--- a/e2e_testing/xfail_sets.py
+++ b/e2e_testing/xfail_sets.py
@@ -456,6 +456,13 @@
     "BroadcastToSameRankStaticModule_basic",
     "BroadcastZeroRankInputStaticModule_basic",
     "SliceStaticModule_basic",
+    "ArangeStartStepIntModule_basic",
+    "ArangeDtypeFloatModule_basic",
+    "ArangeIntModule_basic",
+    "ArangeNegativeStartIntModule_basic",
+    "ArangeStartIntModule_basic",
+    "ArangeStartNegativeStepIntModule_basic",
+    "ArangeZeroElementOutputModule_basic",
 }
 
 LTC_XFAIL_SET = {
diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index 1285cc98650a..1e7407da3d95 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -3005,6 +3005,56 @@ LogicalResult ConvertAtenOp<AtenBroadcastToOp>::matchAndRewrite(
       "unimplemented: broadcasts other than same rank or zero ranked tensor.");
 }
 
+template <>
+LogicalResult ConvertAtenOp<AtenArangeStartStepOp>::matchAndRewrite(
+    AtenArangeStartStepOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+
+  TypeConverter *typeConverter = this->getTypeConverter();
+  RankedTensorType resultType =
+      typeConverter->convertType(op->getResult(0).getType())
+          .cast<RankedTensorType>();
+
+  // At this point all tensors should have value semantics, and hence the
+  // `layout` check can be ignored.
+
+  // TODO: Add support for pin_memory features.
+  // The pin_memory should be either `False` or `none`.
+  bool pinMemory;
+  if (!op.pin_memory().getType().isa<Torch::NoneType>() &&
+      (!matchPattern(op.pin_memory(), m_TorchConstantBool(&pinMemory)) ||
+       pinMemory)) {
+    return rewriter.notifyMatchFailure(
+        op, "unimplemented: pin_memory must be either None or false");
+  }
+
+  int64_t start, step, end;
+  if (!matchPattern(op.start(), m_TorchConstantInt(&start)))
+    return rewriter.notifyMatchFailure(
+        op, "unimplemented: value `start` should be a torch constant int");
+
+  if (!matchPattern(op.end(), m_TorchConstantInt(&end)))
+    return rewriter.notifyMatchFailure(
+        op, "unimplemented: value `end` should be a torch constant int");
+
+  if (!matchPattern(op.step(), m_TorchConstantInt(&step)))
+    return rewriter.notifyMatchFailure(
+        op, "unimplemented: value `step` should be a torch constant int");
+
+  // The result will always be a 1-d tensor.
+  // The size of the result is calculated as follows:
+  //    ceil((end - start)/step)
+  int64_t resultShape = ceil((float)(end - start) / (float)step);
+  SmallVector<int64_t> values(resultShape, start);
+  for (unsigned i = 1; i < resultShape; i++)
+    values[i] += i * step;
+  Value result =
+      tosa::getConstTensor<int64_t>(rewriter, op, values, resultShape).value();
+
+  rewriter.replaceOpWithNewOp<tosa::CastOp>(op, resultType, result);
+  return success();
+}
+
 template <typename AtenOpT, typename TosaOpT>
 class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
 public:
@@ -3653,6 +3703,7 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
     INSERT_ATENOP_PATTERN(AtenMaxDimOp);
     INSERT_ATENOP_PATTERN(AtenSliceTensorOp);
     INSERT_ATENOP_PATTERN(AtenBroadcastToOp);
+    INSERT_ATENOP_PATTERN(AtenArangeStartStepOp);
 #undef INSERT_ATENOP_PATTERN
 
 #define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \
diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir
index ea84779980a1..2edcb9b82d07 100644
--- a/test/Conversion/TorchToTosa/basic.mlir
+++ b/test/Conversion/TorchToTosa/basic.mlir
@@ -828,3 +828,22 @@ func.func @torch.vtensor.literal_si64$basic() -> !torch.vtensor<[1,512],si64> {
   %0 = torch.vtensor.literal(dense<-1> : tensor<1x512xsi64>) : !torch.vtensor<[1,512],si64>
   return %0 : !torch.vtensor<[1,512],si64>
 }
+
+// -----
+// CHECK-LABEL: func.func @torch.aten.arange.start_step() -> !torch.vtensor<[5],si64> {
+// CHECK: %[[NONE:.*]] = torch.constant.none
+// CHECK: %[[CST0:.*]] = torch.constant.int 0
+// CHECK: %[[CST5:.*]] = torch.constant.int 5
+// CHECK: %[[CST1:.*]] = torch.constant.int 1
+// CHECK: %[[VAL_0:.*]] = "tosa.const"() {value = dense<[0, 1, 2, 3, 4]> : tensor<5xi64>} : () -> tensor<5xi64>
+// CHECK: %[[VAL_1:.*]] = "tosa.cast"(%[[VAL_0]]) : (tensor<5xi64>) -> tensor<5xi64>
+// CHECK: %[[VAL_2:.*]] = torch_c.from_builtin_tensor %[[VAL_1]] : tensor<5xi64> -> !torch.vtensor<[5],si64>
+// CHECK: return %[[VAL_2]] : !torch.vtensor<[5],si64>
func.func @torch.aten.arange.start_step() -> !torch.vtensor<[5],si64> {
+  %none = torch.constant.none
+  %int0 = torch.constant.int 0
+  %int5 = torch.constant.int 5
+  %int1 = torch.constant.int 1
+  %0 = torch.aten.arange.start_step %int0, %int5, %int1, %none, %none, %none, %none : !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[5],si64>
+  return %0 : !torch.vtensor<[5],si64>
+}
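As a sanity check on the size formula used above, here is a minimal
standalone sketch of the arithmetic the new lowering performs when it
materializes the constant tensor (plain C++, independent of MLIR; the helper
name arangeValues is hypothetical and used only for illustration): the
element count is ceil((end - start) / step), and element i is
start + i * step.

  #include <cmath>
  #include <cstdint>
  #include <vector>

  // Hypothetical mirror of the lowering's computation:
  //   length    = ceil((end - start) / step)
  //   values[i] = start + i * step
  std::vector<int64_t> arangeValues(int64_t start, int64_t end, int64_t step) {
    int64_t numElements =
        static_cast<int64_t>(std::ceil((float)(end - start) / (float)step));
    // Clamp added in this sketch only, for safety with a degenerate range;
    // the lowering itself assumes the operands yield a non-negative count.
    if (numElements < 0)
      numElements = 0;
    std::vector<int64_t> values(numElements, start);
    for (int64_t i = 1; i < numElements; ++i)
      values[i] += i * step;
    return values;
  }

For the test above, arangeValues(0, 5, 1) yields {0, 1, 2, 3, 4}, matching
the dense<[0, 1, 2, 3, 4]> constant, and arangeValues(0, 0, 1) yields an
empty vector, which corresponds to the ArangeZeroElementOutputModule_basic
case added to the pass set.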