diff --git a/core/compiler.cpp b/core/compiler.cpp
index 0ef53f3105..92809affc8 100644
--- a/core/compiler.cpp
+++ b/core/compiler.cpp
@@ -187,7 +187,7 @@ partitioning::GraphAndMapping BuildHybridGraph(
   return partitioning::stitch(&partitioning_ctx, block);
 }
 
-void MapInputsAndDetermineDTypes(
+ir::TypeMap MapInputsAndDetermineDTypes(
     CompileSpec& cfg,
     std::shared_ptr<torch::jit::Graph>& g,
     ir::StaticParams& static_params,
@@ -197,6 +197,7 @@ void MapInputsAndDetermineDTypes(
   cfg.partitioning_info.collection_input_spec_map =
       ir::CollectionInputSpecMap(cfg.convert_info.collection_input_spec_map);
 
+  ir::TypeMap inferred_dtypes;
   auto collection_inputs = ir::get_collection_inputs(g, static_params);
   LOG_DEBUG(
       "In MapInputsAndDetermineDTypes, the g->inputs() size is "
@@ -218,13 +219,13 @@
         LOG_INFO(
             "Since input type is not explicitly defined, inferring using first tensor calculation\n  Inferred input "
             << in->debugName() << " has type " << est_type_opt[i].value());
-        spec[i].dtype = util::ScalarTypeToTRTDataType(est_type_opt[i].value());
+        spec[i].dtype = est_type_opt[i].value();
       } else if (!est_type_opt[i] && !spec[i].dtype_is_user_defined) {
         // If we cannot calculate the type and the user did not define the type, then default to FP32
         LOG_WARNING(
             "Cannot infer input type from calculations in graph for input "
             << in->debugName() << ". Assuming it is Float32. If not, specify input type explicitly");
-        spec[i].dtype = nvinfer1::DataType::kFLOAT;
+        spec[i].dtype = at::kFloat;
       } else if (spec[i].dtype_is_user_defined && cfg.partitioning_info.enabled) {
         if (!est_type_opt[i]) {
           LOG_INFO("Cannot infer input tensor dtype in graph, compiler is going to use the user setting");
@@ -236,37 +237,35 @@
           auto warn_str = ss.str();
           LOG_WARNING(warn_str);
           // Overwrite type map with user settings
-          first_use_type_map[in][i] = {
-              util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype)};
-
-        } else {
-          if (util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype) !=
-              est_type_opt[i].value()) {
-            std::stringstream ss;
-            ss << "For input " << in->debugName() << ", found user specified input dtype as ";
-            ss << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
-            ss << ", however when inspecting the graph, the input type expected was inferred to be ";
-            ss << est_type_opt[i].value() << std::endl;
-            ss << "The compiler is going to use the user setting "
-               << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
-            ss << "\nThis conflict may cause an error at runtime due to partial compilation being enabled and therefore\n";
-            ss << "compatibility with PyTorch's data type convention is required.\n";
-            ss << "If you do indeed see errors at runtime either:\n";
-            ss << "- Remove the dtype spec for " << in->debugName() << std::endl;
-            ss << "- Disable partial compilation by setting require_full_compilation to True";
-            auto warn_str = ss.str();
-            LOG_WARNING(warn_str);
-            // Overwrite type map with user settings
-            first_use_type_map[in][i] = {
-                util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype)};
-          }
+          first_use_type_map[in][i] = {cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype};
+
+        } else if (cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype != est_type_opt[i].value()) {
+          std::stringstream ss;
+          ss << "For input " << in->debugName() << ", found user specified input dtype as ";
+          ss << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
+          ss << ", however when inspecting the graph, the input type expected was inferred to be ";
+          ss << est_type_opt[i].value() << std::endl;
+          ss << "The compiler is going to use the user setting "
+             << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
+          ss << "\nThis conflict may cause an error at runtime due to partial compilation being enabled and therefore\n";
+          ss << "compatibility with PyTorch's data type convention is required.\n";
+          ss << "If you do indeed see errors at runtime either:\n";
+          ss << "- Remove the dtype spec for " << in->debugName() << std::endl;
+          ss << "- Disable partial compilation by setting require_full_compilation to True";
+          auto warn_str = ss.str();
+          LOG_WARNING(warn_str);
+          // Overwrite type map with user settings
+          first_use_type_map[in][i] = {cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype};
         }
       } else {
         // The user defined the type so no changes are necessary
       }
+
+      // Insert entry for Value pointer and determined ScalarType
+      inferred_dtypes.insert({in, {spec[i].dtype}});
     }
   }
-  // }
+  return inferred_dtypes;
 }
 
 std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::string method_name, CompileSpec cfg) {
@@ -284,6 +283,15 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::string method_name, CompileSpec cfg) {
 
   MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);
 
+  // Ensure none of the specified types are incompatible with TRT
+  // Currently, at::kLong is the only acceptable, though TRT-incompatible, type
+  for (auto value_to_dtypes : first_use_types) {
+    for (auto dtype : value_to_dtypes.second) {
+      TORCHTRT_CHECK(
+          !dtype || dtype.value() != at::kLong, "Cannot specify Int64 input for a model fully compiled in TRT");
+    }
+  }
+
   auto engine = conversion::ConvertBlockToEngine(g->block(), cfg.convert_info, static_params);
 
   return engine;
@@ -307,10 +315,24 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
     // Infer the type of an input from the weights of the calculation
     auto first_use_types = ir::get_block_first_calc_dtypes_opt_collection(g->block());
 
-    MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);
+    // Extract map of IValue to DType
+    auto type_map = MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);
+
+    // Check whether any of the input types are Long
+    bool user_requested_long = false;
+    for (auto dtype : type_map) {
+      user_requested_long |= dtype.second && (dtype.second.value() == at::kLong);
+    }
+
+    // Use dtype map to autocast Tensor-type inputs to Long dtype as necessary
+    if (cfg.partitioning_info.enabled && cfg.partitioning_info.truncate_long_and_double && user_requested_long) {
+      auto casts_inserted = lowering::AutocastLongInputs(g, type_map, cfg.lower_info.getGPUDeviceString());
+      user_requested_long &= (casts_inserted > 0);
+    }
+
     auto isBlockConvertible = conversion::VerifyConverterSupportForBlock(g->block(), true);
     auto outputIsCollection = conversion::OutputIsCollection(g->block());
-    if (cfg.partitioning_info.enabled &&
+    if (cfg.partitioning_info.enabled && !user_requested_long &&
         (cfg.lower_info.forced_fallback_modules.size() == 0 &&
          cfg.partitioning_info.forced_fallback_operators.size() == 0 && isBlockConvertible) &&
         !outputIsCollection) {
@@ -320,7 +342,7 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
     if (cfg.partitioning_info.enabled &&
         (!(cfg.lower_info.forced_fallback_modules.size() == 0 &&
            cfg.partitioning_info.forced_fallback_operators.size() == 0 && isBlockConvertible) ||
-         outputIsCollection)) {
+         outputIsCollection || user_requested_long)) {
       auto graph_and_mapping = BuildHybridGraph(new_mod, g->block(), cfg, static_params, first_use_types);
       new_g = graph_and_mapping.first;
       // renaming the input names of the graph after fallback to ensure PyTorch deserializes it correctly
diff --git a/core/conversion/conversion.cpp b/core/conversion/conversion.cpp
index 5f4b20e1b3..940e178850 100644
--- a/core/conversion/conversion.cpp
+++ b/core/conversion/conversion.cpp
@@ -183,7 +183,7 @@ void AddInputs(ConversionCtx* ctx, c10::ArrayRef<const torch::jit::Value*> input
         "Adding Input " << in->debugName() << " (named: " << name << "): " << spec
                         << " in engine (conversion.AddInputs)");
-    auto trt_in = ctx->net->addInput(name.c_str(), spec.dtype, spec.input_shape);
+    auto trt_in = ctx->net->addInput(name.c_str(), util::ScalarTypeToTRTDataType(spec.dtype), spec.input_shape);
     TORCHTRT_CHECK(trt_in, "Failed to add input node: " << in->debugName() << " (conversion.AddInputs)");
     trt_in->setAllowedFormats(1U << static_cast<int>(spec.format));
diff --git a/core/ir/Input.cpp b/core/ir/Input.cpp
index 852453574e..8c0ccbe90a 100644
--- a/core/ir/Input.cpp
+++ b/core/ir/Input.cpp
@@ -71,7 +71,7 @@ bool valid_input_dtype(nvinfer1::DataType dtype) {
 
 Input::Input(
     std::vector<int64_t> shape,
-    nvinfer1::DataType dtype,
+    at::ScalarType dtype,
     nvinfer1::TensorFormat format,
     bool dtype_is_user_defined) {
   if (shape.size() > 5) {
@@ -84,10 +84,10 @@ Input::Input(
   input_shape = util::toDims(shape);
   input_is_dynamic = false;
 
-  TORCHTRT_CHECK(valid_input_dtype(dtype), "Unsupported input data type: " << dtype);
+  TORCHTRT_CHECK(valid_input_dtype(util::ScalarTypeToTRTDataType(dtype)), "Unsupported input data type: " << dtype);
   this->dtype = dtype;
   TORCHTRT_CHECK(
-      valid_dtype_format_combo(dtype, format),
+      valid_dtype_format_combo(util::ScalarTypeToTRTDataType(dtype), format),
       "Unsupported combination of dtype and tensor format: (" << dtype << ", " << format
          << "), Torch-TensorRT only supports contiguous format (NCHW) except with input type Float32 where channel last (NHWC) is also supported");
 
@@ -99,7 +99,7 @@ Input::Input(
     std::vector<int64_t> min_shape,
     std::vector<int64_t> opt_shape,
     std::vector<int64_t> max_shape,
-    nvinfer1::DataType dtype,
+    at::ScalarType dtype,
     nvinfer1::TensorFormat format,
     bool dtype_is_user_defined) {
   if (min_shape.size() > 5 || opt_shape.size() > 5 || max_shape.size() > 5) {
@@ -137,10 +137,10 @@ Input::Input(
 
   input_shape = util::toDims(dyn_shape);
 
-  TORCHTRT_CHECK(valid_input_dtype(dtype), "Unsupported input data type: " << dtype);
+  TORCHTRT_CHECK(valid_input_dtype(util::ScalarTypeToTRTDataType(dtype)), "Unsupported input data type: " << dtype);
   this->dtype = dtype;
   TORCHTRT_CHECK(
-      valid_dtype_format_combo(dtype, format),
+      valid_dtype_format_combo(util::ScalarTypeToTRTDataType(dtype), format),
       "Unsupported combination of dtype and tensor format: (" << dtype << ", " << format
          << "), Torch-TensorRT only supports contiguous format (NCHW) except with input type Float32 where channel last (NHWC) is also supported");
diff --git a/core/ir/ir.h b/core/ir/ir.h
index 141dd24aa0..cb5a157a87 100644
--- a/core/ir/ir.h
+++ b/core/ir/ir.h
@@ -29,16 +29,17 @@ struct Input : torch::CustomClassHolder {
   Input(){};
   Input(
       std::vector<int64_t> shape,
-      nvinfer1::DataType dtype = nvinfer1::DataType::kFLOAT,
+      at::ScalarType dtype = at::kFloat,
       nvinfer1::TensorFormat format = nvinfer1::TensorFormat::kLINEAR,
      bool dtype_is_user_defined = false);
   Input(
       std::vector<int64_t> min_shape,
       std::vector<int64_t> opt_shape,
       std::vector<int64_t> max_shape,
-      nvinfer1::DataType dtype = nvinfer1::DataType::kFLOAT,
+      at::ScalarType dtype = at::kFloat,
       nvinfer1::TensorFormat format = nvinfer1::TensorFormat::kLINEAR,
       bool dtype_is_user_defined = false);
+
   friend std::ostream& operator<<(std::ostream& os, const Input& input);
 
   bool input_is_dynamic = false;
@@ -47,7 +48,7 @@ struct Input : torch::CustomClassHolder {
   nvinfer1::Dims min;
   nvinfer1::Dims max;
   nvinfer1::Dims opt;
-  nvinfer1::DataType dtype;
+  at::ScalarType dtype;
   nvinfer1::TensorFormat format;
   int id;
 };
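Since ir::Input now stores an at::ScalarType, internal callers can build specs directly from Torch types. A brief sketch under the assumed namespace layout (shapes and dtypes here are arbitrary illustration values, not part of the diff):

    // Sketch: constructing the internal input specs changed above.
    #include "core/ir/ir.h"

    void example_specs() {
      namespace ir = torch_tensorrt::core::ir;
      // Static-shape Int64 input; validity is still checked via the TRT mapping internally.
      auto static_in = ir::Input({1, 3, 224, 224}, at::kLong);
      // Dynamic-shape FP32 input with min/opt/max profiles.
      auto dynamic_in = ir::Input(
          {1, 3, 224, 224}, // min_shape
          {4, 3, 224, 224}, // opt_shape
          {8, 3, 224, 224}, // max_shape
          at::kFloat);
    }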
diff --git a/core/lowering/lowering.cpp b/core/lowering/lowering.cpp
index 4d665b390a..cf57e7c83c 100644
--- a/core/lowering/lowering.cpp
+++ b/core/lowering/lowering.cpp
@@ -26,6 +26,75 @@ void LowerBlock(torch::jit::Block* b) {
   DropUnusedNodes(b);
 }
 
+int AutocastLongInputs(
+    std::shared_ptr<torch::jit::Graph>& g,
+    ir::TypeMap input_type_map,
+    std::string target_device_name) {
+  int num_autocasts = 0;
+  // For each graph input, determine if it can be autocasted
+  for (int i = 0; i < g->inputs().size(); i++) {
+    auto input = g->inputs()[i];
+
+    // Autocasted inputs must be Tensor-type
+    if (input->type()->isSubtypeOf(c10::TensorType::get())) {
+      auto dtype_input = input_type_map.find(input);
+
+      // Ensure the data type to cast to exists in the type map
+      if (dtype_input == input_type_map.end() || !dtype_input->second) {
+        LOG_DEBUG("No inferred input dtype for tensor " << input->debugName() << ", skipping autocast");
+        continue;
+      }
+
+      auto dtype = dtype_input->second.value();
+      // Currently, we do not autocast inputs for which the determined type is not long
+      if (dtype != at::kLong) {
+        LOG_DEBUG(
+            "Skipping autocast for tensor " << input->debugName() << ", since its dtype is " << dtype
+                                            << " and not at::kLong");
+        continue;
+      }
+
+      LOG_DEBUG("Inserting aten::to casting " << input->debugName() << " to dtype " << dtype);
+
+      // Generate cast node sending input tensors to the inferred or specified datatype (long)
+      torch::jit::Value *const_false, *cuda, *none_val;
+      if (num_autocasts == 0) {
+        // Only generate constants once and reuse for all autocasts
+        const_false = g->insertConstant(0);
+        const_false->setType(torch::jit::BoolType::get());
+        cuda = g->insertConstant(target_device_name);
+        cuda->setType(torch::jit::DeviceObjType::get());
+        none_val = g->insertNode(g->createNone())->output();
+      }
+
+      auto const_type = g->insertConstant(dtype);
+      auto cast_node = g->create(torch::jit::aten::to, {input, cuda, const_type, const_false, const_false, none_val});
+
+      // Replace all uses of the original tensor with that of the casted tensor
+      g->prependNode(cast_node);
+      input->replaceAllUsesAfterNodeWith(cast_node, cast_node->outputs()[0]);
+
+      // Mark the cast node to run in PyTorch for ease of casting
+      LOG_GRAPH("Marking autocast node " << util::node_info(cast_node) << " to run in PyTorch");
+      cast_node->i_(c10::Symbol::attr("to_compile"), (int64_t) false);
+      num_autocasts++;
+    }
+  }
+
+  LOG_GRAPH("Inserted " << num_autocasts << " autocasts");
+
+  if (num_autocasts > 0) {
+    LOG_WARNING(
+        "Data types for input tensors have been modified by inserting "
+        << "aten::to operations which cast INT64 inputs to INT32. "
+        << "To disable this, please recompile using INT32 inputs");
+
+    LOG_GRAPH("Graph after Autocast: " << *g);
+  }
+
+  return num_autocasts;
+}
+
 void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, std::vector<torch::jit::IValue>& params, LowerInfo lower_info) {
   torch::jit::EliminateRedundantGuards(g);
   torch::jit::RemoveListMutation(g);
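The six operands passed to g->create correspond to the aten::to.device overload, aten::to(Tensor self, Device device, ScalarType dtype, bool non_blocking, bool copy, MemoryFormat? memory_format). In eager terms, the inserted node behaves roughly like the sketch below; the device string "cuda:0" is an illustrative stand-in for whatever cfg.lower_info.getGPUDeviceString() returns:

    // Eager-mode analogue of the inserted cast node (illustrative only).
    #include <ATen/ATen.h>

    at::Tensor autocast_like_inserted_node(const at::Tensor& input) {
      return input.to(
          at::Device("cuda:0"),   // `cuda` constant (target_device_name)
          at::kLong,              // `const_type` constant
          /*non_blocking=*/false, // first `const_false`
          /*copy=*/false);        // second `const_false`; memory_format stays None (`none_val`)
    }

Reusing the boolean, device, and None constants across all autocasts (the num_autocasts == 0 branch) keeps the graph from accumulating duplicate constant nodes when several inputs need the cast.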
" + << "To disable this, please recompile using INT32 inputs"); + + LOG_GRAPH("Graph after Autocast: " << *g); + } + + return num_autocasts; +} + void LowerGraph(std::shared_ptr& g, std::vector& params, LowerInfo lower_info) { torch::jit::EliminateRedundantGuards(g); torch::jit::RemoveListMutation(g); diff --git a/core/lowering/lowering.h b/core/lowering/lowering.h index ed448b1bbc..d89c1651a3 100644 --- a/core/lowering/lowering.h +++ b/core/lowering/lowering.h @@ -27,6 +27,10 @@ struct LowerInfo { void LowerBlock(torch::jit::Block* b); void LowerGraph(std::shared_ptr& g, LowerInfo lower_info); +int AutocastLongInputs( + std::shared_ptr& g, + ir::TypeMap input_type_map, + std::string target_device_name); torch::jit::Module LowerModule( const torch::jit::Module& mod, std::string method_name, diff --git a/core/partitioning/segmentedblock/SegmentedBlock.cpp b/core/partitioning/segmentedblock/SegmentedBlock.cpp index 583e67ca4d..249a293bc3 100644 --- a/core/partitioning/segmentedblock/SegmentedBlock.cpp +++ b/core/partitioning/segmentedblock/SegmentedBlock.cpp @@ -62,13 +62,13 @@ std::vector SegmentedBlock::construct_inputs_spec() const { if (min_shapes_.size() == opt_shapes_.size() && opt_shapes_.size() == max_shapes_.size()) { for (uint64_t i = 0; i < opt_shapes_.size(); i++) { auto in = ir::Input(min_shapes_[i], opt_shapes_[i], max_shapes_[i]); - in.dtype = util::ScalarTypeToTRTDataType(in_types_[i]); + in.dtype = in_types_[i]; inputs.push_back(in); } } else { for (uint64_t i = 0; i < opt_shapes_.size(); i++) { auto in = ir::Input(opt_shapes_[i]); - in.dtype = util::ScalarTypeToTRTDataType(in_types_[i]); + in.dtype = in_types_[i]; inputs.push_back(in); } } diff --git a/core/partitioning/shape_analysis.cpp b/core/partitioning/shape_analysis.cpp index 4220764dd6..80c609f7b7 100644 --- a/core/partitioning/shape_analysis.cpp +++ b/core/partitioning/shape_analysis.cpp @@ -266,10 +266,10 @@ void getSegmentsOutputByRunning( "Unable to process subgraph input type of at::kLong/at::kDouble, try to compile model with truncate_long_and_double enabled"); } else if (partitioning_info.truncate_long_and_double && t == at::kLong) { cur_ivalue = cur_ivalue.toTensor().to(at::kInt); - LOG_WARNING("Truncating graph input type from at::kLong to at::kInt"); + LOG_WARNING("Truncating intermediate graph input type from at::kLong to at::kInt"); } else if (partitioning_info.truncate_long_and_double && t == at::kDouble) { cur_ivalue = cur_ivalue.toTensor().to(at::kFloat); - LOG_WARNING("Truncating graph input type from at::kDouble to at::kFloat"); + LOG_WARNING("Truncating intermediate graph input type from at::kDouble to at::kFloat"); } c10::optional dtype = util::optTypeMetaToTRTDataType(cur_ivalue.toTensor().dtype()); if (dtype == c10::nullopt) { diff --git a/core/util/trt_util.cpp b/core/util/trt_util.cpp index d320992a70..b97eb91184 100644 --- a/core/util/trt_util.cpp +++ b/core/util/trt_util.cpp @@ -251,6 +251,7 @@ const std::unordered_map& get_at_trt_type_ma {at::kFloat, nvinfer1::DataType::kFLOAT}, {at::kHalf, nvinfer1::DataType::kHALF}, {at::kInt, nvinfer1::DataType::kINT32}, + {at::kLong, nvinfer1::DataType::kINT32}, {at::kChar, nvinfer1::DataType::kINT8}, {at::kBool, nvinfer1::DataType::kBOOL}}; return at_trt_type_map; diff --git a/cpp/include/torch_tensorrt/torch_tensorrt.h b/cpp/include/torch_tensorrt/torch_tensorrt.h index ddc29f8a07..fb0b945012 100644 --- a/cpp/include/torch_tensorrt/torch_tensorrt.h +++ b/cpp/include/torch_tensorrt/torch_tensorrt.h @@ -58,6 +58,8 @@ class DataType { * ex. 
diff --git a/cpp/src/types.cpp b/cpp/src/types.cpp
index 45ae34c3da..2d3c271694 100644
--- a/cpp/src/types.cpp
+++ b/cpp/src/types.cpp
@@ -87,6 +87,25 @@ nvinfer1::DataType toTRTDataType(DataType value) {
   }
 }
 
+at::ScalarType toAtenDataType(DataType value) {
+  switch (value) {
+    case DataType::kChar:
+      return at::kChar;
+    case DataType::kHalf:
+      return at::kHalf;
+    case DataType::kInt:
+      return at::kInt;
+    case DataType::kLong:
+      return at::kLong;
+    case DataType::kBool:
+      return at::kBool;
+    case DataType::kFloat:
+    case DataType::kUnknown:
+    default:
+      return at::kFloat;
+  }
+}
+
 nvinfer1::TensorFormat toTRTTensorFormat(TensorFormat value) {
   TORCHTRT_CHECK(!(value == TensorFormat::kUnknown), "Tensor format is unknown");
   switch (value) {
@@ -100,7 +119,7 @@ nvinfer1::TensorFormat toTRTTensorFormat(TensorFormat value) {
 
 DataType::DataType(c10::ScalarType t) {
   TORCHTRT_CHECK(
-      t == at::kHalf || t == at::kFloat || t == at::kChar || t == at::kInt || t == at::kBool,
+      t == at::kHalf || t == at::kFloat || t == at::kChar || t == at::kLong || t == at::kInt || t == at::kBool,
       "Data type is unsupported (" << t << ")");
   switch (t) {
     case at::kHalf:
@@ -112,6 +131,9 @@ DataType::DataType(c10::ScalarType t) {
     case at::kInt:
       value = DataType::kInt;
       break;
+    case at::kLong:
+      value = DataType::kLong;
+      break;
     case at::kBool:
       value = DataType::kBool;
       break;
@@ -267,7 +289,7 @@ torch_tensorrt::core::ir::Input to_internal_input(Input& i) {
       i.min_shape,
       i.opt_shape,
       i.max_shape,
-      toTRTDataType(i.dtype),
+      toAtenDataType(i.dtype),
       toTRTTensorFormat(i.format),
       !(i.dtype == DataType::kUnknown));
 }
diff --git a/py/torch_tensorrt/_Input.py b/py/torch_tensorrt/_Input.py
index 4062edb6e6..8780d4db91 100644
--- a/py/torch_tensorrt/_Input.py
+++ b/py/torch_tensorrt/_Input.py
@@ -219,7 +219,9 @@ def _supported_input_size_type(input_size: Any) -> bool:
     @staticmethod
     def _parse_dtype(dtype: Any) -> _enums.dtype:
         if isinstance(dtype, torch.dtype):
-            if dtype == torch.int32:
+            if dtype == torch.long:
+                return _enums.dtype.long
+            elif dtype == torch.int32:
                 return _enums.dtype.int32
             elif dtype == torch.half:
                 return _enums.dtype.half
@@ -229,7 +231,7 @@ def _parse_dtype(dtype: Any) -> _enums.dtype:
                 return _enums.dtype.bool
             else:
                 raise TypeError(
-                    "Provided an unsupported data type as an input data type (support: bool, int32, half, float), got: "
+                    "Provided an unsupported data type as an input data type (support: bool, int32, long, half, float), got: "
                     + str(dtype)
                 )
@@ -242,6 +244,9 @@ def _parse_dtype(dtype: Any) -> _enums.dtype:
                 + str(type(dtype))
             )
 
+    def is_trt_dtype(self) -> bool:
+        return self.dtype != _enums.dtype.long
+
     @staticmethod
     def _parse_format(format: Any) -> _enums.TensorFormat:
         if isinstance(format, torch.memory_format):
diff --git a/py/torch_tensorrt/csrc/tensorrt_classes.cpp b/py/torch_tensorrt/csrc/tensorrt_classes.cpp
index 489da576e2..6fca17fdd1 100644
--- a/py/torch_tensorrt/csrc/tensorrt_classes.cpp
+++ b/py/torch_tensorrt/csrc/tensorrt_classes.cpp
@@ -16,6 +16,8 @@ std::string to_str(DataType value) {
       return "Bool";
     case DataType::kFloat:
       return "Float";
+    case DataType::kLong:
+      return "Long";
     default:
       return "Unknown data type";
   }
@@ -29,6 +31,8 @@ nvinfer1::DataType toTRTDataType(DataType value) {
       return nvinfer1::DataType::kHALF;
     case DataType::kInt32:
       return nvinfer1::DataType::kINT32;
+    case DataType::kLong:
+      return nvinfer1::DataType::kINT32;
     case DataType::kBool:
       return nvinfer1::DataType::kBOOL;
     case DataType::kFloat:
@@ -40,6 +44,27 @@ nvinfer1::DataType toTRTDataType(DataType value) {
   }
 }
 
+at::ScalarType toAtenDataType(DataType value) {
+  switch (value) {
+    case DataType::kChar:
+      return at::kChar;
+    case DataType::kHalf:
+      return at::kHalf;
+    case DataType::kInt32:
+      return at::kInt;
+    case DataType::kLong:
+      return at::kLong;
+    case DataType::kBool:
+      return at::kBool;
+    case DataType::kFloat:
+      return at::kFloat;
+    case DataType::kUnknown:
+      return at::kFloat;
+    default:
+      TORCHTRT_THROW_ERROR("Unknown data type: " << to_str(value));
+  }
+}
+
 Device::Device(const core::runtime::RTDevice& internal_dev) {
   device_type = DeviceType::kGPU;
   gpu_id = internal_dev.id;
@@ -70,9 +95,9 @@ std::string to_str(TensorFormat value) {
 
 core::ir::Input Input::toInternalInput() {
   if (!input_is_dynamic) {
-    return core::ir::Input(opt, toTRTDataType(dtype), toTRTTensorFormat(format), explicit_set_dtype);
+    return core::ir::Input(opt, toAtenDataType(dtype), toTRTTensorFormat(format), explicit_set_dtype);
   } else {
-    return core::ir::Input(min, opt, max, toTRTDataType(dtype), toTRTTensorFormat(format), explicit_set_dtype);
+    return core::ir::Input(min, opt, max, toAtenDataType(dtype), toTRTTensorFormat(format), explicit_set_dtype);
   }
 }
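Taken together, the two pyapi conversion paths are intentionally asymmetric for Long: the Torch-side view preserves Int64 while the TRT-side view narrows to Int32. A small sketch of the resulting invariant, using only the functions defined above:

    // Invariant sketch for the pyapi conversions above (illustrative).
    #include <cassert>

    void check_long_mapping() {
      using torch_tensorrt::pyapi::DataType;
      auto d = DataType::kLong;
      assert(torch_tensorrt::pyapi::toAtenDataType(d) == at::kLong);                 // Torch side keeps Int64
      assert(torch_tensorrt::pyapi::toTRTDataType(d) == nvinfer1::DataType::kINT32); // TRT side narrows to Int32
    }

This is what makes the graph-level autocast in lowering.cpp necessary: the engine consumes Int32, while PyTorch-executed segments still expect Int64.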
diff --git a/py/torch_tensorrt/csrc/tensorrt_classes.h b/py/torch_tensorrt/csrc/tensorrt_classes.h
index 6762d078a1..3470944c72 100644
--- a/py/torch_tensorrt/csrc/tensorrt_classes.h
+++ b/py/torch_tensorrt/csrc/tensorrt_classes.h
@@ -27,9 +27,10 @@ namespace pyapi {
     return static_cast<int64_t>(field_name); \
   }
 
-enum class DataType : int8_t { kFloat, kHalf, kChar, kInt32, kBool, kUnknown };
+enum class DataType : int8_t { kLong, kFloat, kHalf, kChar, kInt32, kBool, kUnknown };
 std::string to_str(DataType value);
 nvinfer1::DataType toTRTDataType(DataType value);
+at::ScalarType toAtenDataType(DataType value);
 
 enum class TensorFormat : int8_t { kContiguous, kChannelsLast };
 std::string to_str(TensorFormat value);
diff --git a/py/torch_tensorrt/csrc/torch_tensorrt_py.cpp b/py/torch_tensorrt/csrc/torch_tensorrt_py.cpp
index 868dbb21fa..7341ba9281 100644
--- a/py/torch_tensorrt/csrc/torch_tensorrt_py.cpp
+++ b/py/torch_tensorrt/csrc/torch_tensorrt_py.cpp
@@ -242,6 +242,8 @@ PYBIND11_MODULE(_C, m) {
       .value("float16", DataType::kHalf, "16 bit floating point number")
       .value("int8", DataType::kChar, "8 bit integer number")
       .value("int32", DataType::kInt32, "32 bit integer number")
+      .value("long", DataType::kLong, "64 bit integer number")
+      .value("int64", DataType::kLong, "64 bit integer number")
       .value("bool", DataType::kBool, "Boolean value")
       .value("unknown", DataType::kUnknown, "Unknown data type")
       .export_values();
diff --git a/py/torch_tensorrt/ts/_compile_spec.py b/py/torch_tensorrt/ts/_compile_spec.py
index 5ffe0471f4..d76d259e29 100644
--- a/py/torch_tensorrt/ts/_compile_spec.py
+++ b/py/torch_tensorrt/ts/_compile_spec.py
@@ -214,6 +214,14 @@ def _parse_input_signature(input_signature: Any):
             if isinstance(input_signature, torch.Tensor)
             else input_signature
         )
+
+        if not i.is_trt_dtype():
+            raise TypeError(
+                "Using non-TRT input types with input_signature is not currently "
+                + "supported. Please specify inputs individually to use "
+                + "non-TRT types."
+            )
+
         clone = _internal_input_to_torch_class_input(i._to_internal())
         return clone
     else:
diff --git a/tests/cpp/test_collections.cpp b/tests/cpp/test_collections.cpp
index 7fcc006980..982562923d 100644
--- a/tests/cpp/test_collections.cpp
+++ b/tests/cpp/test_collections.cpp
@@ -45,6 +45,50 @@ TEST(CppAPITests, TestCollectionStandardTensorInput) {
   ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(out.toTensor(), trt_out.toTensor()));
 }
 
+TEST(CppAPITests, TestCollectionStandardTensorInputLongDtype) {
+  std::string path = "tests/modules/standard_tensor_input_scripted.jit.pt";
+  torch::Tensor in0 = torch::randn({1, 3, 512, 512}, torch::kCUDA).to(torch::kLong);
+  std::vector<at::Tensor> inputs;
+  inputs.push_back(in0);
+  inputs.push_back(in0);
+
+  torch::jit::Module mod;
+  try {
+    // Deserialize the ScriptModule from a file using torch::jit::load().
+    mod = torch::jit::load(path);
+  } catch (const c10::Error& e) {
+    std::cerr << "error loading the model\n";
+  }
+  mod.eval();
+  mod.to(torch::kCUDA);
+
+  std::vector<torch::jit::IValue> inputs_;
+
+  for (auto in : inputs) {
+    inputs_.push_back(torch::jit::IValue(in.clone()));
+  }
+
+  auto out = mod.forward(inputs_);
+
+  std::vector<torch_tensorrt::Input> input_range;
+
+  // Specify Long input tensor type
+  input_range.push_back({in0.sizes(), torch::kLong});
+  input_range.push_back({in0.sizes(), torch::kLong});
+  torch_tensorrt::ts::CompileSpec compile_settings(input_range);
+  compile_settings.min_block_size = 1;
+
+  // FP32 execution with long and double truncation
+  compile_settings.enabled_precisions = {torch::kFloat};
+  compile_settings.truncate_long_and_double = true;
+
+  // Compile module
+  auto trt_mod = torch_tensorrt::torchscript::compile(mod, compile_settings);
+  auto trt_out = trt_mod.forward(inputs_);
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(
+      out.toTensor().to(torch::kFloat), trt_out.toTensor().to(torch::kFloat)));
+}
+
 TEST(CppAPITests, TestCollectionTupleInput) {
   std::string path = "tests/modules/tuple_input_scripted.jit.pt";
   torch::Tensor in0 = torch::randn({1, 3, 512, 512}, torch::kCUDA).to(torch::kHalf);
diff --git a/tests/py/api/test_collections.py b/tests/py/api/test_collections.py
index 936a4d5c73..33b4608194 100644
--- a/tests/py/api/test_collections.py
+++ b/tests/py/api/test_collections.py
@@ -50,6 +50,36 @@ def test_compile(self):
         )
 
 
+class TestStandardTensorInputLong(unittest.TestCase):
+    def test_compile(self):
+
+        self.input = torch.randn((1, 3, 224, 224)).to("cuda")
+        self.model = (
+            torch.jit.load(MODULE_DIR + "/standard_tensor_input_scripted.jit.pt")
+            .eval()
+            .to("cuda")
+        )
+
+        compile_spec = {
+            "inputs": [
+                torchtrt.Input(self.input.shape, dtype=torch.long),
+                torchtrt.Input(self.input.shape, dtype=torch.long),
+            ],
+            "device": torchtrt.Device("gpu:0"),
+            "enabled_precisions": {torch.float},
+            "truncate_long_and_double": True,
+        }
+
+        trt_mod = torchtrt.ts.compile(self.model, **compile_spec)
+        cos_sim = cosine_similarity(
+            self.model(self.input, self.input), trt_mod(self.input, self.input)
+        )
+        self.assertTrue(
+            cos_sim > COSINE_THRESHOLD,
+            msg=f"standard_tensor_input_long_scripted TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+        )
+
+
 class TestTupleInput(unittest.TestCase):
     def test_compile(self):
diff --git a/tests/util/run_graph_engine.cpp b/tests/util/run_graph_engine.cpp
index ce8435a34c..8358a3c570 100644
--- a/tests/util/run_graph_engine.cpp
+++ b/tests/util/run_graph_engine.cpp
@@ -20,7 +20,7 @@ namespace util {
 std::vector<core::ir::Input> toInputs(std::vector<at::Tensor> ten) {
   std::vector<core::ir::Input> a;
   for (auto i : ten) {
-    a.push_back(core::ir::Input(core::util::toVec(i.sizes()), core::util::ScalarTypeToTRTDataType(i.scalar_type())));
+    a.push_back(core::ir::Input(core::util::toVec(i.sizes()), i.scalar_type()));
   }
   return a;
 }
@@ -30,7 +30,7 @@ std::vector<core::ir::Input> toInputsDynamic(std::vector<at::Tensor> ten, bool dynamic_batch) {
   std::vector<core::ir::Input> a;
 
   for (auto i : ten) {
     auto opt = core::util::toVec(i.sizes());
-    auto dtype = core::util::ScalarTypeToTRTDataType(i.scalar_type());
+    auto dtype = i.scalar_type();
 
     if (dynamic_batch) {
       std::vector<int64_t> min_range(opt);