Skip to content

Commit

Permalink
Merge pull request #1692 from mfeliz-cruise/michael.feliz/index_selec…
Browse files Browse the repository at this point in the history
…t_converter

[feat] Add converter support for index_select
  • Loading branch information
peri044 authored Feb 28, 2023
2 parents 617469f + 69b3d79 commit b388010
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 0 deletions.
23 changes: 23 additions & 0 deletions core/conversion/converters/impl/select.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,29 @@ auto select_registrations TORCHTRT_UNUSED =
return true;
}})
.pattern(
{"aten::index_select(Tensor self, int dim, Tensor index) -> Tensor",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
auto in = args[0].ITensorOrFreeze(ctx);
auto maxDim = static_cast<int64_t>(in->getDimensions().nbDims);
auto dim = args[1].unwrapToInt();
// Handle negative axis by refering to nbDims of input Tensor
dim = dim < 0 ? dim + maxDim : dim;
auto index = args[2].ITensorOrFreeze(ctx);

LOG_DEBUG("Gather input dimensions: " << in->getDimensions());
LOG_DEBUG("Dimension to select: " << dim);
LOG_DEBUG("Index dimensions: " << index->getDimensions());

auto gather_layer = ctx->net->addGather(*in, *index, dim);
TORCHTRT_CHECK(gather_layer, "Unable to create gather layer from node: " << *n);
auto out = gather_layer->getOutput(0);
LOG_DEBUG("Gather tensor shape: " << out->getDimensions());

out = ctx->AssociateValueAndTensor(n->outputs()[0], out);
LOG_DEBUG("Output tensor shape: " << out->getDimensions());
return true;
}})
.pattern(
{"aten::narrow(Tensor(a) self, int dim, int start, int length) -> Tensor(a)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
auto in = args[0].ITensor();
Expand Down
54 changes: 54 additions & 0 deletions tests/core/conversion/converters/test_select.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,60 @@ TEST(Converters, ATenSelectEmptyTensorConvertsCorrectly) {
ASSERT_TRUE(torch_tensorrt::tests::util::sameShape(jit_results[0], trt_results[0]));
}

TEST(Converters, ATenIndexSelectConvertsCorrectly) {
const auto graph = R"IR(
graph(%0 : Tensor, %index : Int (2)):
%2 : int = prim::Constant[value=0]()
%3 : Tensor = aten::index_select(%0, %2, %index)
return (%3))IR";
auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, g.get());
auto in = at::randint(1, 10, {4, 4, 4}, {at::kCUDA});
auto index = at::randint(0, 4, {2}, {at::kCUDA}).to(torch::kI32);

auto jit_in = at::clone(in);
auto jit_index = at::clone(index);
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {jit_index});
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});

auto trt_in = at::clone(in);
auto trt_index = at::clone(index);
auto trt_params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {trt_index});
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, trt_params, {trt_in});

auto trt = trt_results[0].reshape(jit_results[0].sizes());

ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

TEST(Converters, ATenIndexSelectNegativeDimConvertsCorrectly) {
const auto graph = R"IR(
graph(%0 : Tensor, %index : Int (5)):
%2 : int = prim::Constant[value=-1]()
%3 : Tensor = aten::index_select(%0, %2, %index)
return (%3))IR";
auto g = std::make_shared<torch::jit::Graph>();

torch::jit::parseIR(graph, g.get());

auto in = at::randint(1, 10, {5, 3, 9}, {at::kCUDA});
auto index = at::randint(0, 9, {5}, {at::kCUDA}).to(torch::kI32);

auto jit_in = at::clone(in);
auto jit_index = at::clone(index);
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {jit_index});
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});

auto trt_in = at::clone(in);
auto trt_index = at::clone(index);
auto trt_params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {trt_index});
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, trt_params, {trt_in});

auto trt = trt_results[0].reshape(jit_results[0].sizes());

ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

TEST(Converters, ATenNarrowStartScalarConvertsCorrectly) {
const auto graph = R"IR(
graph(%x.1 : Tensor):
Expand Down

0 comments on commit b388010

Please sign in to comment.