diff --git a/core/conversion/converters/impl/select.cpp b/core/conversion/converters/impl/select.cpp
index c569a6088e..4fead67513 100644
--- a/core/conversion/converters/impl/select.cpp
+++ b/core/conversion/converters/impl/select.cpp
@@ -180,6 +180,29 @@ auto select_registrations TORCHTRT_UNUSED =
                return true;
              }})
         .pattern(
+            {"aten::index_select(Tensor self, int dim, Tensor index) -> Tensor",
+             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+               auto in = args[0].ITensorOrFreeze(ctx);
+               auto maxDim = static_cast<int64_t>(in->getDimensions().nbDims);
+               auto dim = args[1].unwrapToInt();
+               // Handle negative axis by referring to nbDims of input Tensor
+               dim = dim < 0 ? dim + maxDim : dim;
+               auto index = args[2].ITensorOrFreeze(ctx);
+
+               LOG_DEBUG("Gather input dimensions: " << in->getDimensions());
+               LOG_DEBUG("Dimension to select: " << dim);
+               LOG_DEBUG("Index dimensions: " << index->getDimensions());
+
+               auto gather_layer = ctx->net->addGather(*in, *index, dim);
+               TORCHTRT_CHECK(gather_layer, "Unable to create gather layer from node: " << *n);
+               auto out = gather_layer->getOutput(0);
+               LOG_DEBUG("Gather tensor shape: " << out->getDimensions());
+
+               out = ctx->AssociateValueAndTensor(n->outputs()[0], out);
+               LOG_DEBUG("Output tensor shape: " << out->getDimensions());
+               return true;
+             }})
+        .pattern(
             {"aten::narrow(Tensor(a) self, int dim, int start, int length) -> Tensor(a)",
              [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
               auto in = args[0].ITensor();
diff --git a/tests/core/conversion/converters/test_select.cpp b/tests/core/conversion/converters/test_select.cpp
index 991a1b792c..a1f7485825 100644
--- a/tests/core/conversion/converters/test_select.cpp
+++ b/tests/core/conversion/converters/test_select.cpp
@@ -165,6 +165,60 @@ TEST(Converters, ATenSelectEmptyTensorConvertsCorrectly) {
   ASSERT_TRUE(torch_tensorrt::tests::util::sameShape(jit_results[0], trt_results[0]));
 }
 
+TEST(Converters, ATenIndexSelectConvertsCorrectly) {
+  const auto graph = R"IR(
+    graph(%0 : Tensor, %index : Int (2)):
+      %2 : int = prim::Constant[value=0]()
+      %3 : Tensor = aten::index_select(%0, %2, %index)
+      return (%3))IR";
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+  auto in = at::randint(1, 10, {4, 4, 4}, {at::kCUDA});
+  auto index = at::randint(0, 4, {2}, {at::kCUDA}).to(torch::kI32);
+
+  auto jit_in = at::clone(in);
+  auto jit_index = at::clone(index);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {jit_index});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(in);
+  auto trt_index = at::clone(index);
+  auto trt_params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {trt_index});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, trt_params, {trt_in});
+
+  auto trt = trt_results[0].reshape(jit_results[0].sizes());
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
+}
+
+TEST(Converters, ATenIndexSelectNegativeDimConvertsCorrectly) {
+  const auto graph = R"IR(
+    graph(%0 : Tensor, %index : Int (5)):
+      %2 : int = prim::Constant[value=-1]()
+      %3 : Tensor = aten::index_select(%0, %2, %index)
+      return (%3))IR";
+  auto g = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randint(1, 10, {5, 3, 9}, {at::kCUDA});
+  auto index = at::randint(0, 9, {5}, {at::kCUDA}).to(torch::kI32);
+
+  auto jit_in = at::clone(in);
+  auto jit_index = at::clone(index);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {jit_index});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(in);
+  auto trt_index = at::clone(index);
+  auto trt_params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {trt_index});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, trt_params, {trt_in});
+
+  auto trt = trt_results[0].reshape(jit_results[0].sizes());
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
+}
+
 TEST(Converters, ATenNarrowStartScalarConvertsCorrectly) {
   const auto graph = R"IR(
     graph(%x.1 : Tensor):