Skip to content

Commit

Permalink
Avoid layer name conflicts in aten::index (#1377)
Browse files Browse the repository at this point in the history
  • Loading branch information
mfeliz-cruise authored Oct 4, 2022
1 parent 9d89f6c commit dd88afc
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 11 deletions.
23 changes: 12 additions & 11 deletions core/conversion/converters/impl/select.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ auto select_registrations TORCHTRT_UNUSED =
nvinfer1::ElementWiseOperation::kPROD,
d0,
dim_tensor,
std::string("compute_dim0_") + std::to_string(i))
util::node_info(n) + std::string("_compute_dim0_") + std::to_string(i))
->getOutput(0);
}

Expand All @@ -378,7 +378,7 @@ auto select_registrations TORCHTRT_UNUSED =
nvinfer1::ElementWiseOperation::kPROD,
d1,
dim_tensor,
std::string("compute_dim1_") + std::to_string(i))
util::node_info(n) + std::string("_compute_dim1_") + std::to_string(i))
->getOutput(0);
}

Expand All @@ -398,26 +398,27 @@ auto select_registrations TORCHTRT_UNUSED =
nvinfer1::ITensor* multiplier = dim_tensor_list[adv_idx_indices[adv_idx_count - 1]];
nvinfer1::ITensor* cum_adv_index = tensors[adv_idx_count - 1];
for (int i = adv_idx_count - 2; i >= 0; i--) {
nvinfer1::ITensor* adv_index = add_elementwise(
ctx,
nvinfer1::ElementWiseOperation::kPROD,
tensors[i],
multiplier,
std::string("adv_index_") + std::to_string(i))
->getOutput(0);
nvinfer1::ITensor* adv_index =
add_elementwise(
ctx,
nvinfer1::ElementWiseOperation::kPROD,
tensors[i],
multiplier,
util::node_info(n) + std::string("_adv_index_") + std::to_string(i))
->getOutput(0);
cum_adv_index = add_elementwise(
ctx,
nvinfer1::ElementWiseOperation::kSUM,
cum_adv_index,
adv_index,
std::string("cum_adv_index_") + std::to_string(i))
util::node_info(n) + std::string("_cum_adv_index_") + std::to_string(i))
->getOutput(0);
multiplier = add_elementwise(
ctx,
nvinfer1::ElementWiseOperation::kPROD,
multiplier,
dim_tensor_list[adv_idx_indices[i]],
std::string("multiplier_") + std::to_string(i))
util::node_info(n) + std::string("_multiplier_") + std::to_string(i))
->getOutput(0);
}

Expand Down
32 changes: 32 additions & 0 deletions tests/core/conversion/converters/test_select.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -833,6 +833,38 @@ TEST(Converters, ATenIndexTensorFullIndicesConvertsCorrectly) {
torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
}

TEST(Converters, ATenIndexTensorRepeatedFullIndicesConvertsCorrectly) {
  // Two aten::index nodes consume the same input and index list; converting
  // both in one graph exercises the converter's unique-layer-name path
  // (duplicate TRT layer names would fail the network build).
  const auto source_ir = R"IR(
graph(%x.1 : Tensor,
%index0 : Tensor,
%index1 : Tensor,
%index2 : Tensor):
%18 : Tensor?[] = prim::ListConstruct(%index0, %index1, %index2)
%19 : Tensor = aten::index(%x.1, %18)
%20 : Tensor = aten::index(%x.1, %18)
return (%19, %20))IR";

  auto parsed_graph = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source_ir, parsed_graph.get());

  // A 5x10x4 input and one full set of gather indices per dimension,
  // all in-bounds for the corresponding dimension sizes.
  auto input = at::randint(1, 10, {5, 10, 4}, {at::kCUDA});
  auto idx_dim0 = at::tensor({0, 1, 2, 3}, {at::kCUDA}).to(torch::kLong);
  auto idx_dim1 = at::tensor({1, 3, 4, 6}, {at::kCUDA}).to(torch::kLong);
  auto idx_dim2 = at::tensor({3, 2, 1, 0}, {at::kCUDA}).to(torch::kLong);

  // The TRT engine path takes Int32 index tensors instead of Long.
  auto idx_dim0_trt = idx_dim0.to(torch::kInt32);
  auto idx_dim1_trt = idx_dim1.to(torch::kInt32);
  auto idx_dim2_trt = idx_dim2.to(torch::kInt32);

  // Reference run through TorchScript.
  auto static_params = torch_tensorrt::core::ir::get_static_params(parsed_graph->inputs(), {});
  auto jit_out = torch_tensorrt::tests::util::RunGraph(parsed_graph, static_params, {input, idx_dim0, idx_dim1, idx_dim2});

  // Converted run through the TensorRT engine.
  static_params = torch_tensorrt::core::ir::get_static_params(parsed_graph->inputs(), {});
  auto trt_out = torch_tensorrt::tests::util::RunGraphEngine(
      parsed_graph, static_params, {input, idx_dim0_trt, idx_dim1_trt, idx_dim2_trt});

  // Both aten::index outputs must match the JIT reference.
  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_out[0], trt_out[0], 2e-6));
  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_out[1], trt_out[1], 2e-6));
}

TEST(Converters, ATenIndexTensorIdx0Idx1NoneConvertsCorrectly) {
const auto graph = R"IR(
graph(%x.1 : Tensor,
Expand Down

0 comments on commit dd88afc

Please sign in to comment.