From 45e3bd473b9a963bfd4fa0e8d5f8b2c235269320 Mon Sep 17 00:00:00 2001 From: Naren Dasan Date: Thu, 11 Aug 2022 14:29:37 -0700 Subject: [PATCH 1/3] feat(aten::__derive_index): Implement derive index evaluator Fixes: #834 Signed-off-by: Naren Dasan Signed-off-by: Naren Dasan --- core/conversion/evaluators/aten.cpp | 11 ++++++++++- .../evaluators/test_aten_evaluators.cpp | 18 ++++++++++++++++++ 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/core/conversion/evaluators/aten.cpp b/core/conversion/evaluators/aten.cpp index 518219b361..8136cd66c9 100644 --- a/core/conversion/evaluators/aten.cpp +++ b/core/conversion/evaluators/aten.cpp @@ -806,7 +806,16 @@ auto aten_registrations TORCHTRT_UNUSED = return 0; } }, - EvalOptions().validSchemas({"aten::__range_length(int lo, int hi, int step) -> int"})}); + EvalOptions().validSchemas({"aten::__range_length(int lo, int hi, int step) -> int"})}) + .evaluator( + {c10::Symbol::fromQualString("aten::__derive_index"), + [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> { + auto idx = args.at(n->input(0)).unwrapToInt(); + auto start = args.at(n->input(1)).unwrapToInt(); + auto step = args.at(n->input(2)).unwrapToInt(); + return start + idx * step; + }}); + } // namespace } // namespace evaluators } // namespace conversion diff --git a/tests/core/conversion/evaluators/test_aten_evaluators.cpp b/tests/core/conversion/evaluators/test_aten_evaluators.cpp index a379f15060..aef7fa5729 100644 --- a/tests/core/conversion/evaluators/test_aten_evaluators.cpp +++ b/tests/core/conversion/evaluators/test_aten_evaluators.cpp @@ -797,3 +797,21 @@ TEST(Evaluators, PowFloatIntEvaluatesCorrectly) { ASSERT_TRUE(jit_results[0] == trt_results[0]); } + +TEST(Evaluators, DeriveIndexEvaluatesCorrectly) { + const auto graph = R"IR( + graph(): + %1 : int = prim::Constant[value=9]() + %2 : int = prim::Constant[value=4]() + %3 : int = prim::Constant[value=2]() + %4 : int = aten::__derive_index(%1, %2, %3) + return (%4))IR"; + + auto g
= std::make_shared<torch::jit::Graph>(); torch::jit::parseIR(graph, g.get()); auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {}); auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {}); ASSERT_TRUE(jit_results[0] == trt_results[0]); } From 371f247ce923a5ab3a11d1394d850b3df0935688 Mon Sep 17 00:00:00 2001 From: Naren Dasan Date: Thu, 11 Aug 2022 21:35:13 -0700 Subject: [PATCH 2/3] feat: Allow tuples to carry ITensors Signed-off-by: Naren Dasan Signed-off-by: Naren Dasan --- core/conversion/converters/impl/unary.cpp | 9 ++-- core/conversion/evaluators/aten.cpp | 9 ++-- core/conversion/evaluators/prim.cpp | 37 ++++---------- .../evaluators/test_aten_evaluators.cpp | 51 +++++++++++++++++++ 4 files changed, 72 insertions(+), 34 deletions(-) diff --git a/core/conversion/converters/impl/unary.cpp b/core/conversion/converters/impl/unary.cpp index 6b0ee2bd6a..fa4e88fa5e 100644 --- a/core/conversion/converters/impl/unary.cpp +++ b/core/conversion/converters/impl/unary.cpp @@ -11,8 +11,8 @@ namespace impl { namespace { auto abs_registration TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern( - {"aten::abs (Tensor self) -> Tensor", [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { - auto in = args[0].ITensor(); + {"aten::abs(Tensor self) -> Tensor", [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { + auto in = args[0].ITensorOrFreeze(ctx); bool unary_supported_input = in->getType() == nvinfer1::DataType::kFLOAT || in->getType() == nvinfer1::DataType::kHALF || in->getType() == nvinfer1::DataType::kINT8; if (unary_supported_input) { @@ -23,6 +23,9 @@ auto abs_registration TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions()); return true; } else { + LOG_GRAPH( + "Tensor is of unsupported type " + << in->getType() << " for IUnaryLayer::kABS. 
Using backup implementation via IElementWise (max(x, -x)"); // For types not supported by kABS, use an elementwise implementation abs(x) = max(x, -1 * x) at::Tensor neg_one = torch::full({1}, -1).to(util::TRTDataTypeToScalarType(in->getType())); auto neg_one_const = tensor_to_const(ctx, neg_one); @@ -50,7 +53,7 @@ auto abs_registration TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern auto unary##_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern( \ {"aten::" #unary "(Tensor self) -> Tensor", \ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { \ - auto in = args[0].ITensor(); \ + auto in = args[0].ITensorOrFreeze(ctx); \ auto unary = ctx->net->addUnary(*in, nvinfer1::UnaryOperation::trt_type); \ \ TORCHTRT_CHECK(unary, "Unable to create " #unary " layer from node: " << *n); \ diff --git a/core/conversion/evaluators/aten.cpp b/core/conversion/evaluators/aten.cpp index 8136cd66c9..b24222be26 100644 --- a/core/conversion/evaluators/aten.cpp +++ b/core/conversion/evaluators/aten.cpp @@ -516,7 +516,7 @@ auto aten_registrations TORCHTRT_UNUSED = auto self = args.at(n->input(0)).IValue(); auto obj = args.at(n->input(1)).IValue(); - return self->isSameIdentity(*obj); + return self->is(*obj); }, EvalOptions().validSchemas({ "aten::__is__(t1 self, t2 obj) -> bool", @@ -527,7 +527,7 @@ auto aten_registrations TORCHTRT_UNUSED = auto self = args.at(n->input(0)).IValue(); auto obj = args.at(n->input(1)).IValue(); - return !self->isSameIdentity(*obj); + return !self->is(*obj); }, EvalOptions().validSchemas({ "aten::__isnot__(t1 self, t2 obj) -> bool", @@ -814,10 +814,11 @@ auto aten_registrations TORCHTRT_UNUSED = auto start = args.at(n->input(1)).unwrapToInt(); auto step = args.at(n->input(2)).unwrapToInt(); return start + idx * step; - }}); + }, + EvalOptions().validSchemas({"aten::__derive_index(int idx, int start, int step) -> int"})}); } // namespace } // namespace evaluators } // namespace conversion } // namespace core 
-} // namespace torch_tensorrt +} // namespace torch_tensorrt \ No newline at end of file diff --git a/core/conversion/evaluators/prim.cpp b/core/conversion/evaluators/prim.cpp index 59984edacd..2245ca05dc 100644 --- a/core/conversion/evaluators/prim.cpp +++ b/core/conversion/evaluators/prim.cpp @@ -270,36 +270,19 @@ auto prim_registrations = .evaluator( {torch::jit::prim::TupleConstruct, [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> { - auto num_inputs = n->inputs().size(); c10::IValue tuple = c10::ivalue::Tuple::create(); - switch (num_inputs) { - case 0: - tuple = c10::ivalue::Tuple::create(); - break; - case 1: - tuple = c10::ivalue::Tuple::create(std::move((*args.at(n->input(0)).IValue()))); - break; - case 2: { - tuple = c10::ivalue::Tuple::create( - std::move(*(args.at(n->input(0)).IValue())), std::move(*(args.at(n->input(1)).IValue()))); - break; - } - case 3: { - tuple = c10::ivalue::Tuple::create( - std::move(*(args.at(n->input(0)).IValue())), - std::move(*(args.at(n->input(1)).IValue())), - std::move(*(args.at(n->input(2)).IValue()))); - break; - } - default: { - std::vector<c10::IValue> elems; - for (size_t i = 0; i < num_inputs; i++) { - elems.push_back(*(args.at(n->input(i)).IValue())); - } - tuple = c10::ivalue::Tuple::create(std::move(elems)); - break; + std::vector<c10::IValue> elems; + for (auto in : n->inputs()) { + if (args.at(in).isITensor()) { + auto tensor_holder = TensorContainer(); + tensor_holder.hold_tensor(args.at(in).ITensor()); + auto ival = c10::IValue(std::move(c10::make_intrusive<TensorContainer>(tensor_holder))); + elems.push_back(std::move(ival)); + } else { + elems.push_back(*(args.at(in).IValue())); } } + tuple = c10::ivalue::Tuple::create(std::move(elems)); return c10::optional<torch::jit::IValue>(std::move(tuple)); }}) .evaluator( diff --git a/tests/core/conversion/evaluators/test_aten_evaluators.cpp b/tests/core/conversion/evaluators/test_aten_evaluators.cpp index aef7fa5729..78c60db859 100644 --- a/tests/core/conversion/evaluators/test_aten_evaluators.cpp +++ 
b/tests/core/conversion/evaluators/test_aten_evaluators.cpp @@ -815,3 +815,54 @@ TEST(Evaluators, DeriveIndexEvaluatesCorrectly) { ASSERT_TRUE(jit_results[0] == trt_results[0]); } + +TEST(Evaluators, IsTrueEvaluatesCorrectly) { + const auto graph = R"IR( + graph(): + %1 : int = prim::Constant[value=1]() + %2 : int = prim::Constant[value=1]() + %4 : bool = aten::__is__(%1, %2) + return (%4))IR"; + + auto g = std::make_shared<torch::jit::Graph>(); + torch::jit::parseIR(graph, g.get()); + + auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {}); + auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {}); + + ASSERT_TRUE(jit_results[0] == trt_results[0]); +} + +TEST(Evaluators, IsFalseEvaluatesCorrectly) { + const auto graph = R"IR( + graph(): + %1 : int = prim::Constant[value=9]() + %2 : None = prim::Constant() + %4 : bool = aten::__is__(%1, %2) + return (%4))IR"; + + auto g = std::make_shared<torch::jit::Graph>(); + torch::jit::parseIR(graph, g.get()); + + auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {}); + auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {}); + + ASSERT_TRUE(jit_results[0] == trt_results[0]); +} + +TEST(Evaluators, IsNotTrueEvaluatesCorrectly) { + const auto graph = R"IR( + graph(): + %1 : int = prim::Constant[value=1]() + %2 : None = prim::Constant() + %4 : bool = aten::__isnot__(%1, %2) + return (%4))IR"; + + auto g = std::make_shared<torch::jit::Graph>(); + torch::jit::parseIR(graph, g.get()); + + auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {}); + auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {}); + + ASSERT_TRUE(jit_results[0] == trt_results[0]); +} From 460fc9b686eeafbded7fe49a1319f4b62453feaf Mon Sep 17 00:00:00 2001 From: Naren Dasan Date: Fri, 12 Aug 2022 09:23:24 -0700 Subject: [PATCH 3/3] chore: linting Signed-off-by: Naren Dasan Signed-off-by: Naren Dasan --- core/conversion/converters/impl/element_wise.cpp | 0 1 file changed, 0 insertions(+), 0 deletions(-)
mode change 100755 => 100644 core/conversion/converters/impl/element_wise.cpp diff --git a/core/conversion/converters/impl/element_wise.cpp b/core/conversion/converters/impl/element_wise.cpp old mode 100755 new mode 100644