diff --git a/core/conversion/converters/converter_util.cpp b/core/conversion/converters/converter_util.cpp
index 1e81c0a4b5..d66c4514e9 100644
--- a/core/conversion/converters/converter_util.cpp
+++ b/core/conversion/converters/converter_util.cpp
@@ -156,6 +156,38 @@ nvinfer1::ILayer* add_elementwise(
   return ele;
 }
 
+nvinfer1::ITensor* add_abs(
+    ConversionCtx* ctx,
+    const torch::jit::Node* n,
+    nvinfer1::ITensor* self,
+    const std::string& name) {
+  nvinfer1::ILayer* absolute_value_layer;
+
+  // Check if TRT Unary ops support the input type
+  bool unary_supported_input = (self->getType() == nvinfer1::DataType::kFLOAT) ||
+      (self->getType() == nvinfer1::DataType::kHALF) || (self->getType() == nvinfer1::DataType::kINT8);
+  if (unary_supported_input) {
+    absolute_value_layer = ctx->net->addUnary(*self, nvinfer1::UnaryOperation::kABS);
+    TORCHTRT_CHECK(absolute_value_layer, "Unable to create abs layer from node: " << *n);
+    absolute_value_layer->setName(name.c_str());
+  } else {
+    LOG_GRAPH(
+        "Tensor is of unsupported type "
+        << self->getType() << " for IUnaryLayer::kABS. Using backup implementation via IElementWise (max(x, -x)");
+    // For types not supported by kABS, use an elementwise implementation abs(x) = max(x, -1 * x)
+    at::Tensor neg_one = torch::full({1}, -1).to(util::TRTDataTypeToScalarType(self->getType()));
+    auto neg_one_const = tensor_to_const(ctx, neg_one);
+    auto neg_layer = add_elementwise(
+        ctx, nvinfer1::ElementWiseOperation::kPROD, self, neg_one_const, util::node_info(n) + std::string("_Negation"));
+    TORCHTRT_CHECK(neg_layer, "Unable to create prod layer from node: " << *n);
+    absolute_value_layer =
+        add_elementwise(ctx, nvinfer1::ElementWiseOperation::kMAX, self, neg_layer->getOutput(0), name);
+    TORCHTRT_CHECK(absolute_value_layer, "Unable to create max layer from node: " << *n);
+  }
+
+  return absolute_value_layer->getOutput(0);
+}
+
 nvinfer1::ITensor* applyIdentityOp(ConversionCtx* ctx, nvinfer1::ITensor* tensor, const std::string& tensor_name) {
   auto id_layer = ctx->net->addIdentity(*tensor);
   auto id_out_tensor = id_layer->getOutput(0);
diff --git a/core/conversion/converters/converter_util.h b/core/conversion/converters/converter_util.h
index 2f5d4b25a9..6c4e9d53f0 100644
--- a/core/conversion/converters/converter_util.h
+++ b/core/conversion/converters/converter_util.h
@@ -35,6 +35,8 @@ nvinfer1::ITensor* addUnpadding(
     bool trailing = true,
     bool use_zeros = true);
 
+// TODO: Change add_elementwise schema to output nvinfer1::ITensor* instead of nvinfer1::ILayer*,
+// for consistency with other utils. Need to change schema and usage in all calling contexts
 nvinfer1::ILayer* add_elementwise(
     ConversionCtx* ctx,
     nvinfer1::ElementWiseOperation op,
@@ -42,6 +44,12 @@ nvinfer1::ILayer* add_elementwise(
     nvinfer1::ITensor* other,
     const std::string& name);
 
+nvinfer1::ITensor* add_abs(
+    ConversionCtx* ctx,
+    const torch::jit::Node* n,
+    nvinfer1::ITensor* self,
+    const std::string& name);
+
 // Apply an identity operation on a tensor. Used in the case where an input is an output to a network.
 nvinfer1::ITensor* applyIdentityOp(ConversionCtx* ctx, nvinfer1::ITensor* tensor, const std::string& name);
 
diff --git a/core/conversion/converters/impl/element_wise.cpp b/core/conversion/converters/impl/element_wise.cpp
index 0b347c6647..4e1fab4929 100644
--- a/core/conversion/converters/impl/element_wise.cpp
+++ b/core/conversion/converters/impl/element_wise.cpp
@@ -326,15 +326,27 @@ auto element_wise_registrations TORCHTRT_UNUSED =
             } else if (rounding_mode == "trunc") {
               // trunc = floor(abs(div)) * sign(div)
               auto tmp_div = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kDIV, self, other, "tmp_div");
-              auto abs = ctx->net->addUnary(*tmp_div->getOutput(0), nvinfer1::UnaryOperation::kABS);
-              auto floor = ctx->net->addUnary(*abs->getOutput(0), nvinfer1::UnaryOperation::kFLOOR);
+              auto abs = add_abs(ctx, n, tmp_div->getOutput(0), util::node_info(n) + "_absolute_val");
+
+              // In this case, we allow the floor unary on non-TRT Unary types, as it is needed for this
+              // specific function. Floor applied to non-float types equates to identity
+              nvinfer1::ITensor* floor;
+
+              if ((abs->getType() == nvinfer1::DataType::kINT32) || (abs->getType() == nvinfer1::DataType::kBOOL)) {
+                LOG_DEBUG(
+                    "Tensor is of unsupported type " << abs->getType()
+                                                     << " for IUnaryLayer::kFLOOR. Using identity instead.");
+                floor = abs;
+              } else {
+                auto floor_layer = ctx->net->addUnary(*abs, nvinfer1::UnaryOperation::kFLOOR);
+                TORCHTRT_CHECK(floor_layer, "Unable to create floor layer from node: " << *n);
+                floor_layer->setName((util::node_info(n) + "_floor").c_str());
+                floor = floor_layer->getOutput(0);
+              }
+
               auto sign = ctx->net->addUnary(*tmp_div->getOutput(0), nvinfer1::UnaryOperation::kSIGN);
               div = add_elementwise(
-                  ctx,
-                  nvinfer1::ElementWiseOperation::kPROD,
-                  floor->getOutput(0),
-                  sign->getOutput(0),
-                  util::node_info(n));
+                  ctx, nvinfer1::ElementWiseOperation::kPROD, floor, sign->getOutput(0), util::node_info(n));
             } else {
               div = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kDIV, self, other, util::node_info(n));
             }
diff --git a/core/conversion/converters/impl/unary.cpp b/core/conversion/converters/impl/unary.cpp
index acac34cd7f..90fb8ef624 100644
--- a/core/conversion/converters/impl/unary.cpp
+++ b/core/conversion/converters/impl/unary.cpp
@@ -13,40 +13,10 @@ namespace {
 auto abs_registration TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern(
     {"aten::abs(Tensor self) -> Tensor", [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
        auto in = args[0].ITensorOrFreeze(ctx);
-       bool unary_supported_input = in->getType() == nvinfer1::DataType::kFLOAT ||
-           in->getType() == nvinfer1::DataType::kHALF || in->getType() == nvinfer1::DataType::kINT8;
-       if (unary_supported_input) {
-         auto unary_layer = ctx->net->addUnary(*in, nvinfer1::UnaryOperation::kABS);
-         TORCHTRT_CHECK(unary_layer, "Unable to create abs layer from node: " << *n);
-         unary_layer->setName(util::node_info(n).c_str());
-         auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], unary_layer->getOutput(0));
-         LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
-         return true;
-       } else {
-         LOG_GRAPH(
-             "Tensor is of unsupported type "
-             << in->getType() << " for IUnaryLayer::kABS. Using backup implementation via IElementWise (max(x, -x)");
-         // For types not supported by kABS, use an elementwise implementation abs(x) = max(x, -1 * x)
-         at::Tensor neg_one = torch::full({1}, -1).to(util::TRTDataTypeToScalarType(in->getType()));
-         auto neg_one_const = tensor_to_const(ctx, neg_one);
-         auto neg_layer = add_elementwise(
-             ctx,
-             nvinfer1::ElementWiseOperation::kPROD,
-             in,
-             neg_one_const,
-             util::node_info(n) + std::string("_Negation"));
-         TORCHTRT_CHECK(neg_layer, "Unable to create prod layer from node: " << *n);
-         auto max_layer = add_elementwise(
-             ctx,
-             nvinfer1::ElementWiseOperation::kMAX,
-             in,
-             neg_layer->getOutput(0),
-             util::node_info(n) + std::string("_Max"));
-         TORCHTRT_CHECK(max_layer, "Unable to create max layer from node: " << *n);
-         auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], max_layer->getOutput(0));
-         LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
-         return true;
-       }
+       auto abs_tensor = add_abs(ctx, n, in, util::node_info(n));
+       auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], abs_tensor);
+       LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
+       return true;
     }});
 
 auto reciprocal_registration TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern(
diff --git a/tests/core/conversion/converters/test_element_wise.cpp b/tests/core/conversion/converters/test_element_wise.cpp
index c4b170ed42..82e959dfcd 100644
--- a/tests/core/conversion/converters/test_element_wise.cpp
+++ b/tests/core/conversion/converters/test_element_wise.cpp
@@ -4,6 +4,7 @@
 #include "gtest/gtest.h"
 #include "tests/util/util.h"
 #include "torch/csrc/jit/ir/irparser.h"
+#include "torch/torch.h"
 
 void pointwise_test_helper(
     std::string graph_ir,
@@ -235,6 +236,29 @@ TEST(Converters, ATenDivRoundingNoneConvertsCorrectly) {
   pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3}, true);
 }
 
+TEST(Converters, ATenDivRoundingTruncWithIntsConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor, %1 : Tensor):
+        %trunc : str = prim::Constant[value="trunc"]()
+        %out : Tensor = aten::div(%0, %1, %trunc)
+        return (%out))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  // Avoid divide-by-zero issues by making denominator >= 1
+  auto in_0 = at::randint(-5, 5, {4, 1, 7, 8}, {at::kCUDA}).to(torch::kInt32);
+  auto in_1 = at::randint(1, 10, {4, 1, 7, 8}, {at::kCUDA}).to(torch::kInt32);
+
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in_0, in_1});
+
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in_0, in_1});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::exactlyEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0])));
+}
+
 TEST(Converters, ATenPowTensorConvertsCorrectly) {
   const auto graph = R"IR(
       graph(%x.1 : Tensor, %x2.1 : Tensor):
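Standalone note on the fallback used above (not part of the patch): add_abs relies on the identity abs(x) = max(x, -1 * x) for input types that IUnaryLayer::kABS does not support. The following minimal libtorch sketch, assuming a libtorch installation and a hypothetical file name abs_identity_check.cpp, simply sanity-checks that identity on integer inputs like those exercised by the new test; it is an illustration, not code from the patch.

// abs_identity_check.cpp (hypothetical, standalone; not part of the Torch-TensorRT diff)
// Sanity-checks the identity abs(x) == max(x, -1 * x) that the kABS fallback relies on,
// using integer tensors comparable to the kINT32 case handled in add_abs.
#include <torch/torch.h>
#include <iostream>

int main() {
  // Random integers, including negatives, mirroring the randint inputs used in the new test.
  auto x = torch::randint(-5, 5, {4, 4}, torch::kInt32);

  // Elementwise fallback: negate via multiplication by -1, then take the elementwise max.
  auto fallback_abs = torch::max(x, x * -1);

  // Compare against the reference abs.
  std::cout << (torch::equal(fallback_abs, torch::abs(x)) ? "identity holds" : "identity broken")
            << std::endl;
  return 0;
}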