diff --git a/core/conversion/converters/impl/select.cpp b/core/conversion/converters/impl/select.cpp index 102f0d294f..79f061ad5d 100644 --- a/core/conversion/converters/impl/select.cpp +++ b/core/conversion/converters/impl/select.cpp @@ -17,17 +17,23 @@ namespace { bool add_split(ConversionCtx* ctx, const torch::jit::Node* n, args& args, bool split_list, bool unbind) { auto in = args[0].ITensor(); - auto numOutputs = 1, numRemainder = 0, axis = 0; + auto numOutputs = 1, numRemainder = 0; std::vector sizes; + // Precompute axis along which to apply split, ensuring negative dimensions are re-indexed + auto maxDim = static_cast(in->getDimensions().nbDims); + auto input_axis = unbind ? args[1].unwrapToInt() : args[2].unwrapToInt(); + auto axis = input_axis < 0 ? input_axis + maxDim : input_axis; + + // Ensure input axis is valid for input tensor + TORCHTRT_CHECK( + (axis >= 0) && (axis < maxDim), + "Expected input axis to fall in range [-" << maxDim << ", " << (maxDim - 1) << "], got " << input_axis); + if (unbind) { - axis = args[1].unwrapToInt(); - auto maxDim = static_cast(in->getDimensions().nbDims); - axis = axis < 0 ? axis + maxDim : axis; numOutputs = in->getDimensions().d[axis]; sizes.insert(sizes.end(), numOutputs, 1); } else { - axis = args[2].unwrapToInt(); auto inDimSize = in->getDimensions().d[axis]; if (split_list) { sizes = args[1].unwrapToIntList().vec(); @@ -274,7 +280,8 @@ auto select_registrations TORCHTRT_UNUSED = .pattern( {"aten::index.Tensor(Tensor self, Tensor?[] indices) -> (Tensor)", [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { - // refer to https://github.com/pytorch/pytorch/blob/master/torch/onnx/symbolic_opset9.py#L4627 + // refer to + // https://github.com/pytorch/pytorch/blob/974ad8fa6cc63b89234beb5ebff54c2d42711932/torch/onnx/symbolic_opset9.py#L4627 auto in = args[0].ITensorOrFreeze(ctx); auto ts = args[1].IValue()->toListRef(); diff --git a/tests/core/conversion/converters/test_select.cpp b/tests/core/conversion/converters/test_select.cpp index d6f4996580..c04036e9ba 100644 --- a/tests/core/conversion/converters/test_select.cpp +++ b/tests/core/conversion/converters/test_select.cpp @@ -739,6 +739,34 @@ TEST(Converters, ATenSplitAndAddConvertsCorrectly) { } } +TEST(Converters, ATenSplitNegativeDimsConvertsCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor): + %2 : int = prim::Constant[value=1]() + %n1 : int = prim::Constant[value=-1]() + %3 : Tensor[] = aten::split(%x.1, %2, %n1) + %4 : Tensor, %5 : Tensor, %6 : Tensor, %7 : Tensor = prim::ListUnpack(%3) + return (%4, %5, %6, %7))IR"; + + auto g = std::make_shared(); + + torch::jit::parseIR(graph, g.get()); + + auto in = at::randint(1, 10, {1, 3, 4, 4}, {at::kCUDA}); + + auto jit_in = at::clone(in); + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); + + auto trt_in = at::clone(in); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); + + for (size_t i = 0; i < jit_results.size(); i++) { + auto trt = trt_results[i].reshape(jit_results[i].sizes()); + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6)); + } +} + TEST(Converters, ATenMaskedFillZerosConvertsCorrectly) { const auto graph = R"IR( graph(%x.1 : Tensor): @@ -1109,4 +1137,4 @@ TEST(Converters, ScatterSrcConvertsCorrectly) { auto trt = trt_results[i].reshape(jit_results[i].sizes()); ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6)); } -} \ No newline at end of file +}