From 1e343324e7a8be56afc2d31fcacbc1895cfc80ad Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Tue, 20 Dec 2022 20:01:54 -0800 Subject: [PATCH 1/2] feat: Add option to specify int64 as an Input dtype - Rework `Input` paradigm to be based on `at::ScalarType` as opposed to the previous `nvinfer1::DataType`, allowing a larger representation space of data types - When paired with `truncate_long_and_double`, insert casts to ensure Torch engines using Int64 tensors receive the correct types, and TensorRT engines operating on those tensors receive downcasted Int32 versions thereof - Add Torch block at the beginning of model graph to prepare types of input tensors for forthcoming engines in sequence - Automatically follow internal tensor types to abstract away the different internal engines used (Torch/TensorRT) from the user - Provide a framework for streamlined addition of other data types, including `torch.double` as valid input types - Improve error checking to ensure model compilation and behavior is as documented. For example, disallow specification of Long type input if the engine is required to be converted entirely to TRT - Known Limitations: - Specifying `dtype=torch.long` on an `Input` in an `input_signature` is not supported currently and will throw an error before model compilation when used with the Python API - While Torch may output Int64 tensors from the overall model, Torch-TRT currently can only output Int32 tensors for models using TRT, as there is not a mechanism in place for differentiating intermediate blocks from final/beginning blocks in the graph - Torch-TRT will almost definitely alter the data type of the input tensor, in-place, if `dtype=torch.long` is specified, and the returned result will be of type `torch.int32` --- core/compiler.cpp | 84 ++++++++++++------- core/conversion/conversion.cpp | 2 +- core/ir/Input.cpp | 12 +-- core/ir/ir.h | 7 +- core/lowering/lowering.cpp | 57 +++++++++++++ core/lowering/lowering.h | 4 + .../segmentedblock/SegmentedBlock.cpp | 4 +- core/partitioning/shape_analysis.cpp | 4 +- core/util/trt_util.cpp | 1 + cpp/include/torch_tensorrt/torch_tensorrt.h | 2 + cpp/src/types.cpp | 21 ++++- py/torch_tensorrt/_Input.py | 9 +- py/torch_tensorrt/csrc/tensorrt_classes.cpp | 29 ++++++- py/torch_tensorrt/csrc/tensorrt_classes.h | 3 +- py/torch_tensorrt/csrc/torch_tensorrt_py.cpp | 2 + py/torch_tensorrt/ts/_compile_spec.py | 8 ++ tests/py/api/test_collections.py | 30 +++++++ tests/util/run_graph_engine.cpp | 4 +- 18 files changed, 230 insertions(+), 53 deletions(-) diff --git a/core/compiler.cpp b/core/compiler.cpp index 0ef53f3105..92809affc8 100644 --- a/core/compiler.cpp +++ b/core/compiler.cpp @@ -187,7 +187,7 @@ partitioning::GraphAndMapping BuildHybridGraph( return partitioning::stitch(&partitioning_ctx, block); } -void MapInputsAndDetermineDTypes( +ir::TypeMap MapInputsAndDetermineDTypes( CompileSpec& cfg, std::shared_ptr& g, ir::StaticParams& static_params, @@ -197,6 +197,7 @@ void MapInputsAndDetermineDTypes( cfg.partitioning_info.collection_input_spec_map = ir::CollectionInputSpecMap(cfg.convert_info.collection_input_spec_map); + ir::TypeMap inferred_dtypes; auto collection_inputs = ir::get_collection_inputs(g, static_params); LOG_DEBUG( "In MapInputsAndDetermineDTypes, the g->inputs() size is " @@ -218,13 +219,13 @@ void MapInputsAndDetermineDTypes( LOG_INFO( "Since input type is not explicitly defined, infering using first tensor calculation\n Inferred input " << in->debugName() << " has type " << est_type_opt[i].value()); - spec[i].dtype = util::ScalarTypeToTRTDataType(est_type_opt[i].value()); + spec[i].dtype = est_type_opt[i].value(); } else if (!est_type_opt[i] && !spec[i].dtype_is_user_defined) { // If we cannot calculate the type and the user did not define the type, then default to FP32 LOG_WARNING( "Cannot infer input type from calcuations in graph for input " << in->debugName() << ". Assuming it is Float32. If not, specify input type explicity"); - spec[i].dtype = nvinfer1::DataType::kFLOAT; + spec[i].dtype = at::kFloat; } else if (spec[i].dtype_is_user_defined && cfg.partitioning_info.enabled) { if (!est_type_opt[i]) { LOG_INFO("Cannot infer input tensor dtype in graph, compiler is going to use the user setting"); @@ -236,37 +237,35 @@ void MapInputsAndDetermineDTypes( auto warn_str = ss.str(); LOG_WARNING(warn_str); // Overwrite type map with user settings - first_use_type_map[in][i] = { - util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype)}; - - } else { - if (util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype) != - est_type_opt[i].value()) { - std::stringstream ss; - ss << "For input " << in->debugName() << ", found user specified input dtype as "; - ss << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype; - ss << ", however when inspecting the graph, the input type expected was inferred to be "; - ss << est_type_opt[i].value() << std::endl; - ss << "The compiler is going to use the user setting " - << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype; - ss << "\nThis conflict may cause an error at runtime due to partial compilation being enabled and therefore\n"; - ss << "compatibility with PyTorch's data type convention is required.\n"; - ss << "If you do indeed see errors at runtime either:\n"; - ss << "- Remove the dtype spec for " << in->debugName() << std::endl; - ss << "- Disable partial compilation by setting require_full_compilation to True"; - auto warn_str = ss.str(); - LOG_WARNING(warn_str); - // Overwrite type map with user settings - first_use_type_map[in][i] = { - util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype)}; - } + first_use_type_map[in][i] = {cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype}; + + } else if (cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype != est_type_opt[i].value()) { + std::stringstream ss; + ss << "For input " << in->debugName() << ", found user specified input dtype as "; + ss << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype; + ss << ", however when inspecting the graph, the input type expected was inferred to be "; + ss << est_type_opt[i].value() << std::endl; + ss << "The compiler is going to use the user setting " + << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype; + ss << "\nThis conflict may cause an error at runtime due to partial compilation being enabled and therefore\n"; + ss << "compatibility with PyTorch's data type convention is required.\n"; + ss << "If you do indeed see errors at runtime either:\n"; + ss << "- Remove the dtype spec for " << in->debugName() << std::endl; + ss << "- Disable partial compilation by setting require_full_compilation to True"; + auto warn_str = ss.str(); + LOG_WARNING(warn_str); + // Overwrite type map with user settings + first_use_type_map[in][i] = {cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype}; } } else { // The user defined the type so no changes are necessary } + + // Insert entry for Value pointer and determined ScalarType + inferred_dtypes.insert({in, {spec[i].dtype}}); } } - // } + return inferred_dtypes; } std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::string method_name, CompileSpec cfg) { @@ -284,6 +283,15 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std:: MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types); + // Ensure none of the specified types are of acceptable input types incompatible with TRT + // Currently, only at::kLong is an acceptable, though TRT-incompatible type + for (auto value_to_dtypes : first_use_types) { + for (auto dtype : value_to_dtypes.second) { + TORCHTRT_CHECK( + !dtype || dtype.value() != at::kLong, "Cannot specify Int64 input for a model fully compiled in TRT"); + } + } + auto engine = conversion::ConvertBlockToEngine(g->block(), cfg.convert_info, static_params); return engine; @@ -307,10 +315,24 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg) // Infer the type of an input from the weights of the calculation auto first_use_types = ir::get_block_first_calc_dtypes_opt_collection(g->block()); - MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types); + // Extract map of IValue to DType + auto type_map = MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types); + + // Check whether any of the input types are Long + bool user_requested_long = false; + for (auto dtype : type_map) { + user_requested_long |= dtype.second && (dtype.second.value() == at::kLong); + } + + // Use dtype map to autocast Tensor-type inputs to Long dtype as necessary + if (cfg.partitioning_info.enabled && cfg.partitioning_info.truncate_long_and_double && user_requested_long) { + auto casts_inserted = lowering::AutocastLongInputs(g, type_map, cfg.lower_info.getGPUDeviceString()); + user_requested_long &= (casts_inserted > 0); + } + auto isBlockConvertible = conversion::VerifyConverterSupportForBlock(g->block(), true); auto outputIsCollection = conversion::OutputIsCollection(g->block()); - if (cfg.partitioning_info.enabled && + if (cfg.partitioning_info.enabled && !user_requested_long && (cfg.lower_info.forced_fallback_modules.size() == 0 && cfg.partitioning_info.forced_fallback_operators.size() == 0 && isBlockConvertible) && !outputIsCollection) { @@ -320,7 +342,7 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg) if (cfg.partitioning_info.enabled && (!(cfg.lower_info.forced_fallback_modules.size() == 0 && cfg.partitioning_info.forced_fallback_operators.size() == 0 && isBlockConvertible) || - outputIsCollection)) { + outputIsCollection || user_requested_long)) { auto graph_and_mapping = BuildHybridGraph(new_mod, g->block(), cfg, static_params, first_use_types); new_g = graph_and_mapping.first; // renaming the input name of graph after fallback to ensure pytorch deserialize it correctly diff --git a/core/conversion/conversion.cpp b/core/conversion/conversion.cpp index 5f4b20e1b3..940e178850 100644 --- a/core/conversion/conversion.cpp +++ b/core/conversion/conversion.cpp @@ -183,7 +183,7 @@ void AddInputs(ConversionCtx* ctx, c10::ArrayRef input "Adding Input " << in->debugName() << " (named: " << name << "): " << spec << " in engine (conversion.AddInputs)"); - auto trt_in = ctx->net->addInput(name.c_str(), spec.dtype, spec.input_shape); + auto trt_in = ctx->net->addInput(name.c_str(), util::ScalarTypeToTRTDataType(spec.dtype), spec.input_shape); TORCHTRT_CHECK(trt_in, "Failed to add input node: " << in->debugName() << " (conversion.AddInputs)"); trt_in->setAllowedFormats(1U << static_cast(spec.format)); diff --git a/core/ir/Input.cpp b/core/ir/Input.cpp index 852453574e..8c0ccbe90a 100644 --- a/core/ir/Input.cpp +++ b/core/ir/Input.cpp @@ -71,7 +71,7 @@ bool valid_input_dtype(nvinfer1::DataType dtype) { Input::Input( std::vector shape, - nvinfer1::DataType dtype, + at::ScalarType dtype, nvinfer1::TensorFormat format, bool dtype_is_user_defined) { if (shape.size() > 5) { @@ -84,10 +84,10 @@ Input::Input( input_shape = util::toDims(shape); input_is_dynamic = false; - TORCHTRT_CHECK(valid_input_dtype(dtype), "Unsupported input data type: " << dtype); + TORCHTRT_CHECK(valid_input_dtype(util::ScalarTypeToTRTDataType(dtype)), "Unsupported input data type: " << dtype); this->dtype = dtype; TORCHTRT_CHECK( - valid_dtype_format_combo(dtype, format), + valid_dtype_format_combo(util::ScalarTypeToTRTDataType(dtype), format), "Unsupported combination of dtype and tensor format: (" << dtype << ", " << format << "), Torch-TensorRT only supports contiguous format (NCHW) except with input type Float32 where channel last (NHWC) is also supported"); @@ -99,7 +99,7 @@ Input::Input( std::vector min_shape, std::vector opt_shape, std::vector max_shape, - nvinfer1::DataType dtype, + at::ScalarType dtype, nvinfer1::TensorFormat format, bool dtype_is_user_defined) { if (min_shape.size() > 5 || opt_shape.size() > 5 || max_shape.size() > 5) { @@ -137,10 +137,10 @@ Input::Input( input_shape = util::toDims(dyn_shape); - TORCHTRT_CHECK(valid_input_dtype(dtype), "Unsupported input data type: " << dtype); + TORCHTRT_CHECK(valid_input_dtype(util::ScalarTypeToTRTDataType(dtype)), "Unsupported input data type: " << dtype); this->dtype = dtype; TORCHTRT_CHECK( - valid_dtype_format_combo(dtype, format), + valid_dtype_format_combo(util::ScalarTypeToTRTDataType(dtype), format), "Unsupported combination of dtype and tensor format: (" << dtype << ", " << format << "), Torch-TensorRT only supports contiguous format (NCHW) except with input type Float32 where channel last (NHWC) is also supported"); diff --git a/core/ir/ir.h b/core/ir/ir.h index 141dd24aa0..cb5a157a87 100644 --- a/core/ir/ir.h +++ b/core/ir/ir.h @@ -29,16 +29,17 @@ struct Input : torch::CustomClassHolder { Input(){}; Input( std::vector shape, - nvinfer1::DataType dtype = nvinfer1::DataType::kFLOAT, + at::ScalarType dtype = at::kFloat, nvinfer1::TensorFormat format = nvinfer1::TensorFormat::kLINEAR, bool dtype_is_user_defined = false); Input( std::vector min_shape, std::vector opt_shape, std::vector max_shape, - nvinfer1::DataType dtype = nvinfer1::DataType::kFLOAT, + at::ScalarType dtype = at::kFloat, nvinfer1::TensorFormat format = nvinfer1::TensorFormat::kLINEAR, bool dtype_is_used_defined = false); + friend std::ostream& operator<<(std::ostream& os, const Input& input); bool input_is_dynamic = false; @@ -47,7 +48,7 @@ struct Input : torch::CustomClassHolder { nvinfer1::Dims min; nvinfer1::Dims max; nvinfer1::Dims opt; - nvinfer1::DataType dtype; + at::ScalarType dtype; nvinfer1::TensorFormat format; int id; }; diff --git a/core/lowering/lowering.cpp b/core/lowering/lowering.cpp index 4d665b390a..e88b1c7f57 100644 --- a/core/lowering/lowering.cpp +++ b/core/lowering/lowering.cpp @@ -26,6 +26,63 @@ void LowerBlock(torch::jit::Block* b) { DropUnusedNodes(b); } +int AutocastLongInputs( + std::shared_ptr& g, + ir::TypeMap input_type_map, + std::string target_device_name) { + int num_autocasts = 0; + // For each graph input, determine if it can be autocasted + for (int i = 0; i < g->inputs().size(); i++) { + auto input = g->inputs()[i]; + + // Autocasted inputs must be Tensor-type + if (input->type()->isSubtypeOf(c10::TensorType::get())) { + auto dtype_input = input_type_map.find(input); + + // Ensure the data type to be casted to exists in the type map + if (dtype_input == input_type_map.end() || !dtype_input->second) { + LOG_DEBUG("No inferred input dtype for tensor " << input->debugName() << ", skipping autocast"); + continue; + } + + auto dtype = dtype_input->second.value(); + // Currently, we do not autocast inputs for which the determined type is not long + if (dtype != at::kLong) { + continue; + } + + LOG_DEBUG("Inserting aten::to casting " << input->debugName() << " to dtype " << dtype); + + // Generate cast node sending input tensors to the inferred or specified datatype (long) + auto const_type = g->insertConstant(dtype); + auto const_false = g->insertConstant(0); + const_false->setType(torch::jit::BoolType::get()); + auto cuda = g->insertConstant(target_device_name); + cuda->setType(torch::jit::DeviceObjType::get()); + auto none_val = g->insertNode(g->createNone())->output(); + auto cast_node = g->create(torch::jit::aten::to, {input, cuda, const_type, const_false, const_false, none_val}); + + // Replace all uses of the original tensor with that of the casted tensor + g->prependNode(cast_node); + input->replaceAllUsesAfterNodeWith(cast_node, cast_node->outputs()[0]); + + // Mark the cast node to run in PyTorch for ease of casting + LOG_GRAPH("Marking autocast node " << util::node_info(cast_node) << " to run in PyTorch"); + cast_node->i_(c10::Symbol::attr("to_compile"), (int64_t) false); + num_autocasts++; + } + } + + LOG_WARNING( + "Input tensors to this Torch-TRT engine may have their data types in-place modified " + << "if the type does not match the determined required type for TRT. To disable this " + << "automatic casting, specify an Input dtype other than Long"); + + LOG_GRAPH("Graph after Autocast: " << *g); + + return num_autocasts; +} + void LowerGraph(std::shared_ptr& g, std::vector& params, LowerInfo lower_info) { torch::jit::EliminateRedundantGuards(g); torch::jit::RemoveListMutation(g); diff --git a/core/lowering/lowering.h b/core/lowering/lowering.h index ed448b1bbc..d89c1651a3 100644 --- a/core/lowering/lowering.h +++ b/core/lowering/lowering.h @@ -27,6 +27,10 @@ struct LowerInfo { void LowerBlock(torch::jit::Block* b); void LowerGraph(std::shared_ptr& g, LowerInfo lower_info); +int AutocastLongInputs( + std::shared_ptr& g, + ir::TypeMap input_type_map, + std::string target_device_name); torch::jit::Module LowerModule( const torch::jit::Module& mod, std::string method_name, diff --git a/core/partitioning/segmentedblock/SegmentedBlock.cpp b/core/partitioning/segmentedblock/SegmentedBlock.cpp index 583e67ca4d..249a293bc3 100644 --- a/core/partitioning/segmentedblock/SegmentedBlock.cpp +++ b/core/partitioning/segmentedblock/SegmentedBlock.cpp @@ -62,13 +62,13 @@ std::vector SegmentedBlock::construct_inputs_spec() const { if (min_shapes_.size() == opt_shapes_.size() && opt_shapes_.size() == max_shapes_.size()) { for (uint64_t i = 0; i < opt_shapes_.size(); i++) { auto in = ir::Input(min_shapes_[i], opt_shapes_[i], max_shapes_[i]); - in.dtype = util::ScalarTypeToTRTDataType(in_types_[i]); + in.dtype = in_types_[i]; inputs.push_back(in); } } else { for (uint64_t i = 0; i < opt_shapes_.size(); i++) { auto in = ir::Input(opt_shapes_[i]); - in.dtype = util::ScalarTypeToTRTDataType(in_types_[i]); + in.dtype = in_types_[i]; inputs.push_back(in); } } diff --git a/core/partitioning/shape_analysis.cpp b/core/partitioning/shape_analysis.cpp index 4220764dd6..80c609f7b7 100644 --- a/core/partitioning/shape_analysis.cpp +++ b/core/partitioning/shape_analysis.cpp @@ -266,10 +266,10 @@ void getSegmentsOutputByRunning( "Unable to process subgraph input type of at::kLong/at::kDouble, try to compile model with truncate_long_and_double enabled"); } else if (partitioning_info.truncate_long_and_double && t == at::kLong) { cur_ivalue = cur_ivalue.toTensor().to(at::kInt); - LOG_WARNING("Truncating graph input type from at::kLong to at::kInt"); + LOG_WARNING("Truncating intermediate graph input type from at::kLong to at::kInt"); } else if (partitioning_info.truncate_long_and_double && t == at::kDouble) { cur_ivalue = cur_ivalue.toTensor().to(at::kFloat); - LOG_WARNING("Truncating graph input type from at::kDouble to at::kFloat"); + LOG_WARNING("Truncating intermediate graph input type from at::kDouble to at::kFloat"); } c10::optional dtype = util::optTypeMetaToTRTDataType(cur_ivalue.toTensor().dtype()); if (dtype == c10::nullopt) { diff --git a/core/util/trt_util.cpp b/core/util/trt_util.cpp index d320992a70..b97eb91184 100644 --- a/core/util/trt_util.cpp +++ b/core/util/trt_util.cpp @@ -251,6 +251,7 @@ const std::unordered_map& get_at_trt_type_ma {at::kFloat, nvinfer1::DataType::kFLOAT}, {at::kHalf, nvinfer1::DataType::kHALF}, {at::kInt, nvinfer1::DataType::kINT32}, + {at::kLong, nvinfer1::DataType::kINT32}, {at::kChar, nvinfer1::DataType::kINT8}, {at::kBool, nvinfer1::DataType::kBOOL}}; return at_trt_type_map; diff --git a/cpp/include/torch_tensorrt/torch_tensorrt.h b/cpp/include/torch_tensorrt/torch_tensorrt.h index ddc29f8a07..fb0b945012 100644 --- a/cpp/include/torch_tensorrt/torch_tensorrt.h +++ b/cpp/include/torch_tensorrt/torch_tensorrt.h @@ -58,6 +58,8 @@ class DataType { * ex. torch_tensorrt::DataType type = DataType::kFloat; */ enum Value : int8_t { + /// INT64 + kLong, /// FP32 kFloat, /// FP16 diff --git a/cpp/src/types.cpp b/cpp/src/types.cpp index 45ae34c3da..7a5e203836 100644 --- a/cpp/src/types.cpp +++ b/cpp/src/types.cpp @@ -87,6 +87,25 @@ nvinfer1::DataType toTRTDataType(DataType value) { } } +at::ScalarType toAtDataType(DataType value) { + switch (value) { + case DataType::kChar: + return at::kChar; + case DataType::kHalf: + return at::kHalf; + case DataType::kInt: + return at::kInt; + case DataType::kLong: + return at::kLong; + case DataType::kBool: + return at::kBool; + case DataType::kFloat: + case DataType::kUnknown: + default: + return at::kFloat; + } +} + nvinfer1::TensorFormat toTRTTensorFormat(TensorFormat value) { TORCHTRT_CHECK(!(value == TensorFormat::kUnknown), "Tensor format is unknown"); switch (value) { @@ -267,7 +286,7 @@ torch_tensorrt::core::ir::Input to_internal_input(Input& i) { i.min_shape, i.opt_shape, i.max_shape, - toTRTDataType(i.dtype), + toAtDataType(i.dtype), toTRTTensorFormat(i.format), !(i.dtype == DataType::kUnknown)); } diff --git a/py/torch_tensorrt/_Input.py b/py/torch_tensorrt/_Input.py index 4062edb6e6..e9b5e392df 100644 --- a/py/torch_tensorrt/_Input.py +++ b/py/torch_tensorrt/_Input.py @@ -219,7 +219,9 @@ def _supported_input_size_type(input_size: Any) -> bool: @staticmethod def _parse_dtype(dtype: Any) -> _enums.dtype: if isinstance(dtype, torch.dtype): - if dtype == torch.int32: + if dtype == torch.long: + return _enums.dtype.long + elif dtype == torch.int32: return _enums.dtype.int32 elif dtype == torch.half: return _enums.dtype.half @@ -229,7 +231,7 @@ def _parse_dtype(dtype: Any) -> _enums.dtype: return _enums.dtype.bool else: raise TypeError( - "Provided an unsupported data type as an input data type (support: bool, int32, half, float), got: " + "Provided an unsupported data type as an input data type (support: bool, int32, long, half, float), got: " + str(dtype) ) @@ -242,6 +244,9 @@ def _parse_dtype(dtype: Any) -> _enums.dtype: + str(type(dtype)) ) + def is_TRT_dtype(self) -> bool: + return self.dtype != _enums.dtype.long + @staticmethod def _parse_format(format: Any) -> _enums.TensorFormat: if isinstance(format, torch.memory_format): diff --git a/py/torch_tensorrt/csrc/tensorrt_classes.cpp b/py/torch_tensorrt/csrc/tensorrt_classes.cpp index 489da576e2..ec14aba316 100644 --- a/py/torch_tensorrt/csrc/tensorrt_classes.cpp +++ b/py/torch_tensorrt/csrc/tensorrt_classes.cpp @@ -16,6 +16,8 @@ std::string to_str(DataType value) { return "Bool"; case DataType::kFloat: return "Float"; + case DataType::kLong: + return "Long"; default: return "Unknown data type"; } @@ -29,6 +31,8 @@ nvinfer1::DataType toTRTDataType(DataType value) { return nvinfer1::DataType::kHALF; case DataType::kInt32: return nvinfer1::DataType::kINT32; + case DataType::kLong: + return nvinfer1::DataType::kINT32; case DataType::kBool: return nvinfer1::DataType::kBOOL; case DataType::kFloat: @@ -40,6 +44,27 @@ nvinfer1::DataType toTRTDataType(DataType value) { } } +at::ScalarType toAtDataType(DataType value) { + switch (value) { + case DataType::kChar: + return at::kChar; + case DataType::kHalf: + return at::kHalf; + case DataType::kInt32: + return at::kInt; + case DataType::kLong: + return at::kLong; + case DataType::kBool: + return at::kBool; + case DataType::kFloat: + return at::kFloat; + case DataType::kUnknown: + return at::kFloat; + default: + TORCHTRT_THROW_ERROR("Unknown data type: " << to_str(value)); + } +} + Device::Device(const core::runtime::RTDevice& internal_dev) { device_type = DeviceType::kGPU; gpu_id = internal_dev.id; @@ -70,9 +95,9 @@ std::string to_str(TensorFormat value) { core::ir::Input Input::toInternalInput() { if (!input_is_dynamic) { - return core::ir::Input(opt, toTRTDataType(dtype), toTRTTensorFormat(format), explicit_set_dtype); + return core::ir::Input(opt, toAtDataType(dtype), toTRTTensorFormat(format), explicit_set_dtype); } else { - return core::ir::Input(min, opt, max, toTRTDataType(dtype), toTRTTensorFormat(format), explicit_set_dtype); + return core::ir::Input(min, opt, max, toAtDataType(dtype), toTRTTensorFormat(format), explicit_set_dtype); } } diff --git a/py/torch_tensorrt/csrc/tensorrt_classes.h b/py/torch_tensorrt/csrc/tensorrt_classes.h index 6762d078a1..4507bb3285 100644 --- a/py/torch_tensorrt/csrc/tensorrt_classes.h +++ b/py/torch_tensorrt/csrc/tensorrt_classes.h @@ -27,9 +27,10 @@ namespace pyapi { return static_cast(field_name); \ } -enum class DataType : int8_t { kFloat, kHalf, kChar, kInt32, kBool, kUnknown }; +enum class DataType : int8_t { kLong, kFloat, kHalf, kChar, kInt32, kBool, kUnknown }; std::string to_str(DataType value); nvinfer1::DataType toTRTDataType(DataType value); +at::ScalarType toAtDataType(DataType value); enum class TensorFormat : int8_t { kContiguous, kChannelsLast }; std::string to_str(TensorFormat value); diff --git a/py/torch_tensorrt/csrc/torch_tensorrt_py.cpp b/py/torch_tensorrt/csrc/torch_tensorrt_py.cpp index 868dbb21fa..7341ba9281 100644 --- a/py/torch_tensorrt/csrc/torch_tensorrt_py.cpp +++ b/py/torch_tensorrt/csrc/torch_tensorrt_py.cpp @@ -242,6 +242,8 @@ PYBIND11_MODULE(_C, m) { .value("float16", DataType::kHalf, "16 bit floating point number") .value("int8", DataType::kChar, "8 bit integer number") .value("int32", DataType::kInt32, "32 bit integer number") + .value("long", DataType::kLong, "64 bit integer number") + .value("int64", DataType::kLong, "64 bit integer number") .value("bool", DataType::kBool, "Boolean value") .value("unknown", DataType::kUnknown, "Unknown data type") .export_values(); diff --git a/py/torch_tensorrt/ts/_compile_spec.py b/py/torch_tensorrt/ts/_compile_spec.py index 5ffe0471f4..9be1bcab0e 100644 --- a/py/torch_tensorrt/ts/_compile_spec.py +++ b/py/torch_tensorrt/ts/_compile_spec.py @@ -214,6 +214,14 @@ def _parse_input_signature(input_signature: Any): if isinstance(input_signature, torch.Tensor) else input_signature ) + + if not i.is_TRT_dtype(): + raise TypeError( + "Using non-TRT input types with input_signature is not currently " + + "supported. Please specify inputs individually to use " + + "non-TRT types." + ) + clone = _internal_input_to_torch_class_input(i._to_internal()) return clone else: diff --git a/tests/py/api/test_collections.py b/tests/py/api/test_collections.py index 936a4d5c73..33b4608194 100644 --- a/tests/py/api/test_collections.py +++ b/tests/py/api/test_collections.py @@ -50,6 +50,36 @@ def test_compile(self): ) +class TestStandardTensorInputLong(unittest.TestCase): + def test_compile(self): + + self.input = torch.randn((1, 3, 224, 224)).to("cuda") + self.model = ( + torch.jit.load(MODULE_DIR + "/standard_tensor_input_scripted.jit.pt") + .eval() + .to("cuda") + ) + + compile_spec = { + "inputs": [ + torchtrt.Input(self.input.shape, dtype=torch.long), + torchtrt.Input(self.input.shape, dtype=torch.long), + ], + "device": torchtrt.Device("gpu:0"), + "enabled_precisions": {torch.float}, + "truncate_long_and_double": True, + } + + trt_mod = torchtrt.ts.compile(self.model, **compile_spec) + cos_sim = cosine_similarity( + self.model(self.input, self.input), trt_mod(self.input, self.input) + ) + self.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"standard_tensor_input_long_scripted TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + class TestTupleInput(unittest.TestCase): def test_compile(self): diff --git a/tests/util/run_graph_engine.cpp b/tests/util/run_graph_engine.cpp index ce8435a34c..8358a3c570 100644 --- a/tests/util/run_graph_engine.cpp +++ b/tests/util/run_graph_engine.cpp @@ -20,7 +20,7 @@ namespace util { std::vector toInputs(std::vector ten) { std::vector a; for (auto i : ten) { - a.push_back(core::ir::Input(core::util::toVec(i.sizes()), core::util::ScalarTypeToTRTDataType(i.scalar_type()))); + a.push_back(core::ir::Input(core::util::toVec(i.sizes()), i.scalar_type())); } return a; } @@ -30,7 +30,7 @@ std::vector toInputsDynamic(std::vector ten, bool d for (auto i : ten) { auto opt = core::util::toVec(i.sizes()); - auto dtype = core::util::ScalarTypeToTRTDataType(i.scalar_type()); + auto dtype = i.scalar_type(); if (dynamic_batch) { std::vector min_range(opt); From 4282c06f717fb34afb2b20eb46246151bbcd66c0 Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Wed, 21 Dec 2022 18:14:18 -0800 Subject: [PATCH 2/2] fix: Improve autocast graph, provide CPP support - Address review comments - Add cpp API testing and support - Improve length and efficiency of autocast graph - Improve messages displayed to user --- core/lowering/lowering.cpp | 32 ++++++++++----- cpp/src/types.cpp | 9 +++-- py/torch_tensorrt/_Input.py | 2 +- py/torch_tensorrt/csrc/tensorrt_classes.cpp | 6 +-- py/torch_tensorrt/csrc/tensorrt_classes.h | 2 +- py/torch_tensorrt/ts/_compile_spec.py | 2 +- tests/cpp/test_collections.cpp | 44 +++++++++++++++++++++ 7 files changed, 78 insertions(+), 19 deletions(-) diff --git a/core/lowering/lowering.cpp b/core/lowering/lowering.cpp index e88b1c7f57..cf57e7c83c 100644 --- a/core/lowering/lowering.cpp +++ b/core/lowering/lowering.cpp @@ -48,18 +48,26 @@ int AutocastLongInputs( auto dtype = dtype_input->second.value(); // Currently, we do not autocast inputs for which the determined type is not long if (dtype != at::kLong) { + LOG_DEBUG( + "Skipping autocast for tensor " << input->debugName() << ", since its dtype is " << dtype + << " and not at::kLong"); continue; } LOG_DEBUG("Inserting aten::to casting " << input->debugName() << " to dtype " << dtype); // Generate cast node sending input tensors to the inferred or specified datatype (long) + torch::jit::Value *const_false, *cuda, *none_val; + if (num_autocasts == 0) { + // Only generate constants once and reuse for all autocasts + const_false = g->insertConstant(0); + const_false->setType(torch::jit::BoolType::get()); + cuda = g->insertConstant(target_device_name); + cuda->setType(torch::jit::DeviceObjType::get()); + none_val = g->insertNode(g->createNone())->output(); + } + auto const_type = g->insertConstant(dtype); - auto const_false = g->insertConstant(0); - const_false->setType(torch::jit::BoolType::get()); - auto cuda = g->insertConstant(target_device_name); - cuda->setType(torch::jit::DeviceObjType::get()); - auto none_val = g->insertNode(g->createNone())->output(); auto cast_node = g->create(torch::jit::aten::to, {input, cuda, const_type, const_false, const_false, none_val}); // Replace all uses of the original tensor with that of the casted tensor @@ -73,12 +81,16 @@ int AutocastLongInputs( } } - LOG_WARNING( - "Input tensors to this Torch-TRT engine may have their data types in-place modified " - << "if the type does not match the determined required type for TRT. To disable this " - << "automatic casting, specify an Input dtype other than Long"); + LOG_GRAPH("Inserted " << num_autocasts << " autocasts"); - LOG_GRAPH("Graph after Autocast: " << *g); + if (num_autocasts > 0) { + LOG_WARNING( + "Data types for input tensors have been modified by inserting " + << "aten::to operations which cast INT64 inputs to INT32. " + << "To disable this, please recompile using INT32 inputs"); + + LOG_GRAPH("Graph after Autocast: " << *g); + } return num_autocasts; } diff --git a/cpp/src/types.cpp b/cpp/src/types.cpp index 7a5e203836..2d3c271694 100644 --- a/cpp/src/types.cpp +++ b/cpp/src/types.cpp @@ -87,7 +87,7 @@ nvinfer1::DataType toTRTDataType(DataType value) { } } -at::ScalarType toAtDataType(DataType value) { +at::ScalarType toAtenDataType(DataType value) { switch (value) { case DataType::kChar: return at::kChar; @@ -119,7 +119,7 @@ nvinfer1::TensorFormat toTRTTensorFormat(TensorFormat value) { DataType::DataType(c10::ScalarType t) { TORCHTRT_CHECK( - t == at::kHalf || t == at::kFloat || t == at::kChar || t == at::kInt || t == at::kBool, + t == at::kHalf || t == at::kFloat || t == at::kChar || t == at::kLong || t == at::kInt || t == at::kBool, "Data type is unsupported (" << t << ")"); switch (t) { case at::kHalf: @@ -131,6 +131,9 @@ DataType::DataType(c10::ScalarType t) { case at::kInt: value = DataType::kInt; break; + case at::kLong: + value = DataType::kLong; + break; case at::kBool: value = DataType::kBool; break; @@ -286,7 +289,7 @@ torch_tensorrt::core::ir::Input to_internal_input(Input& i) { i.min_shape, i.opt_shape, i.max_shape, - toAtDataType(i.dtype), + toAtenDataType(i.dtype), toTRTTensorFormat(i.format), !(i.dtype == DataType::kUnknown)); } diff --git a/py/torch_tensorrt/_Input.py b/py/torch_tensorrt/_Input.py index e9b5e392df..8780d4db91 100644 --- a/py/torch_tensorrt/_Input.py +++ b/py/torch_tensorrt/_Input.py @@ -244,7 +244,7 @@ def _parse_dtype(dtype: Any) -> _enums.dtype: + str(type(dtype)) ) - def is_TRT_dtype(self) -> bool: + def is_trt_dtype(self) -> bool: return self.dtype != _enums.dtype.long @staticmethod diff --git a/py/torch_tensorrt/csrc/tensorrt_classes.cpp b/py/torch_tensorrt/csrc/tensorrt_classes.cpp index ec14aba316..6fca17fdd1 100644 --- a/py/torch_tensorrt/csrc/tensorrt_classes.cpp +++ b/py/torch_tensorrt/csrc/tensorrt_classes.cpp @@ -44,7 +44,7 @@ nvinfer1::DataType toTRTDataType(DataType value) { } } -at::ScalarType toAtDataType(DataType value) { +at::ScalarType toAtenDataType(DataType value) { switch (value) { case DataType::kChar: return at::kChar; @@ -95,9 +95,9 @@ std::string to_str(TensorFormat value) { core::ir::Input Input::toInternalInput() { if (!input_is_dynamic) { - return core::ir::Input(opt, toAtDataType(dtype), toTRTTensorFormat(format), explicit_set_dtype); + return core::ir::Input(opt, toAtenDataType(dtype), toTRTTensorFormat(format), explicit_set_dtype); } else { - return core::ir::Input(min, opt, max, toAtDataType(dtype), toTRTTensorFormat(format), explicit_set_dtype); + return core::ir::Input(min, opt, max, toAtenDataType(dtype), toTRTTensorFormat(format), explicit_set_dtype); } } diff --git a/py/torch_tensorrt/csrc/tensorrt_classes.h b/py/torch_tensorrt/csrc/tensorrt_classes.h index 4507bb3285..3470944c72 100644 --- a/py/torch_tensorrt/csrc/tensorrt_classes.h +++ b/py/torch_tensorrt/csrc/tensorrt_classes.h @@ -30,7 +30,7 @@ namespace pyapi { enum class DataType : int8_t { kLong, kFloat, kHalf, kChar, kInt32, kBool, kUnknown }; std::string to_str(DataType value); nvinfer1::DataType toTRTDataType(DataType value); -at::ScalarType toAtDataType(DataType value); +at::ScalarType toAtenDataType(DataType value); enum class TensorFormat : int8_t { kContiguous, kChannelsLast }; std::string to_str(TensorFormat value); diff --git a/py/torch_tensorrt/ts/_compile_spec.py b/py/torch_tensorrt/ts/_compile_spec.py index 9be1bcab0e..d76d259e29 100644 --- a/py/torch_tensorrt/ts/_compile_spec.py +++ b/py/torch_tensorrt/ts/_compile_spec.py @@ -215,7 +215,7 @@ def _parse_input_signature(input_signature: Any): else input_signature ) - if not i.is_TRT_dtype(): + if not i.is_trt_dtype(): raise TypeError( "Using non-TRT input types with input_signature is not currently " + "supported. Please specify inputs individually to use " diff --git a/tests/cpp/test_collections.cpp b/tests/cpp/test_collections.cpp index 7fcc006980..982562923d 100644 --- a/tests/cpp/test_collections.cpp +++ b/tests/cpp/test_collections.cpp @@ -45,6 +45,50 @@ TEST(CppAPITests, TestCollectionStandardTensorInput) { ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(out.toTensor(), trt_out.toTensor())); } +TEST(CppAPITests, TestCollectionStandardTensorInputLongDtype) { + std::string path = "tests/modules/standard_tensor_input_scripted.jit.pt"; + torch::Tensor in0 = torch::randn({1, 3, 512, 512}, torch::kCUDA).to(torch::kLong); + std::vector inputs; + inputs.push_back(in0); + inputs.push_back(in0); + + torch::jit::Module mod; + try { + // Deserialize the ScriptModule from a file using torch::jit::load(). + mod = torch::jit::load(path); + } catch (const c10::Error& e) { + std::cerr << "error loading the model\n"; + } + mod.eval(); + mod.to(torch::kCUDA); + + std::vector inputs_; + + for (auto in : inputs) { + inputs_.push_back(torch::jit::IValue(in.clone())); + } + + auto out = mod.forward(inputs_); + + std::vector input_range; + + // Specify Long input tensor type + input_range.push_back({in0.sizes(), torch::kLong}); + input_range.push_back({in0.sizes(), torch::kLong}); + torch_tensorrt::ts::CompileSpec compile_settings(input_range); + compile_settings.min_block_size = 1; + + // // FP32 execution with long and double truncation + compile_settings.enabled_precisions = {torch::kFloat}; + compile_settings.truncate_long_and_double = true; + // // Compile module + auto trt_mod = torch_tensorrt::torchscript::compile(mod, compile_settings); + auto trt_out = trt_mod.forward(inputs_); + + ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual( + out.toTensor().to(torch::kFloat), trt_out.toTensor().to(torch::kFloat))); +} + TEST(CppAPITests, TestCollectionTupleInput) { std::string path = "tests/modules/tuple_input_scripted.jit.pt"; torch::Tensor in0 = torch::randn({1, 3, 512, 512}, torch::kCUDA).to(torch::kHalf);