feat: Add option to specify int64 as an Input dtype
- Rework the `Input` paradigm to be based on `at::ScalarType` rather than
the previous `nvinfer1::DataType`, allowing a larger representable
space of data types
- When paired with `truncate_long_and_double`, insert casts to ensure
Torch engines using Int64 tensors receive the correct types, and
TensorRT engines operating on those tensors receive downcasted Int32
versions thereof
- Add a Torch block at the beginning of the model graph to prepare the types
of input tensors for the engines that follow in sequence
- Automatically follow internal tensor types to abstract the different
internal engines used (Torch/TensorRT) away from the user
- Provide a framework for the streamlined addition of other data types,
including `torch.double`, as valid input types
- Improve error checking to ensure model compilation and behavior are as
documented; for example, disallow specifying a Long-typed input if the
engine is required to be converted entirely to TRT

- Known Limitations:
- Specifying `dtype=torch.long` on an `Input` in an `input_signature` is
currently not supported and will throw an error before model compilation
when used with the Python API
- While Torch may output Int64 tensors from the overall model, Torch-TRT
currently can only output Int32 tensors for models using TRT, as there
is no mechanism in place for differentiating intermediate blocks from
final/beginning blocks in the graph
- Torch-TRT will almost certainly alter the data type of the input tensor
in place if `dtype=torch.long` is specified, and the returned result
will be of type `torch.int32` (see the usage sketch after this list)
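
As an illustrative sketch of the new option (not part of this commit; the model file and shapes are hypothetical), compiling with an Int64 input via the Python API looks roughly like:

```python
import torch
import torch_tensorrt

# Hypothetical TorchScript module taking an integer index tensor
model = torch.jit.load("indexing_model.ts").eval().cuda()

# Specify an Int64 input; truncate_long_and_double permits the
# Int64 -> Int32 downcast required by TensorRT-executed segments
trt_model = torch_tensorrt.compile(
    model,
    inputs=[torch_tensorrt.Input((4, 64), dtype=torch.long)],
    truncate_long_and_double=True,
)

x = torch.randint(0, 100, (4, 64), dtype=torch.long, device="cuda")
out = trt_model(x)
# Per the limitations above: `out` is Int32, and `x` may have been
# cast in place by the autocast block prepended to the graph
```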
gs-olive committed Dec 21, 2022
1 parent f43be5b commit 14ed6dd
Showing 17 changed files with 234 additions and 53 deletions.
84 changes: 53 additions & 31 deletions core/compiler.cpp
@@ -187,7 +187,7 @@ partitioning::GraphAndMapping BuildHybridGraph(
return partitioning::stitch(&partitioning_ctx, block);
}

void MapInputsAndDetermineDTypes(
ir::TypeMap MapInputsAndDetermineDTypes(
CompileSpec& cfg,
std::shared_ptr<torch::jit::Graph>& g,
ir::StaticParams& static_params,
@@ -197,6 +197,7 @@ void MapInputsAndDetermineDTypes(
cfg.partitioning_info.collection_input_spec_map =
ir::CollectionInputSpecMap(cfg.convert_info.collection_input_spec_map);

ir::TypeMap inferred_dtypes;
auto collection_inputs = ir::get_collection_inputs(g, static_params);
LOG_DEBUG(
"In MapInputsAndDetermineDTypes, the g->inputs() size is "
@@ -218,13 +219,13 @@
LOG_INFO(
"Since input type is not explicitly defined, infering using first tensor calculation\n Inferred input "
<< in->debugName() << " has type " << est_type_opt[i].value());
spec[i].dtype = util::ScalarTypeToTRTDataType(est_type_opt[i].value());
spec[i].dtype = est_type_opt[i].value();
} else if (!est_type_opt[i] && !spec[i].dtype_is_user_defined) {
// If we cannot calculate the type and the user did not define the type, then default to FP32
LOG_WARNING(
"Cannot infer input type from calcuations in graph for input "
<< in->debugName() << ". Assuming it is Float32. If not, specify input type explicity");
spec[i].dtype = nvinfer1::DataType::kFLOAT;
spec[i].dtype = at::kFloat;
} else if (spec[i].dtype_is_user_defined && cfg.partitioning_info.enabled) {
if (!est_type_opt[i]) {
LOG_INFO("Cannot infer input tensor dtype in graph, compiler is going to use the user setting");
@@ -236,37 +237,35 @@
auto warn_str = ss.str();
LOG_WARNING(warn_str);
// Overwrite type map with user settings
first_use_type_map[in][i] = {
util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype)};

} else {
if (util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype) !=
est_type_opt[i].value()) {
std::stringstream ss;
ss << "For input " << in->debugName() << ", found user specified input dtype as ";
ss << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
ss << ", however when inspecting the graph, the input type expected was inferred to be ";
ss << est_type_opt[i].value() << std::endl;
ss << "The compiler is going to use the user setting "
<< cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
ss << "\nThis conflict may cause an error at runtime due to partial compilation being enabled and therefore\n";
ss << "compatibility with PyTorch's data type convention is required.\n";
ss << "If you do indeed see errors at runtime either:\n";
ss << "- Remove the dtype spec for " << in->debugName() << std::endl;
ss << "- Disable partial compilation by setting require_full_compilation to True";
auto warn_str = ss.str();
LOG_WARNING(warn_str);
// Overwrite type map with user settings
first_use_type_map[in][i] = {
util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype)};
}
first_use_type_map[in][i] = {cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype};

} else if (cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype != est_type_opt[i].value()) {
std::stringstream ss;
ss << "For input " << in->debugName() << ", found user specified input dtype as ";
ss << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
ss << ", however when inspecting the graph, the input type expected was inferred to be ";
ss << est_type_opt[i].value() << std::endl;
ss << "The compiler is going to use the user setting "
<< cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
ss << "\nThis conflict may cause an error at runtime due to partial compilation being enabled and therefore\n";
ss << "compatibility with PyTorch's data type convention is required.\n";
ss << "If you do indeed see errors at runtime either:\n";
ss << "- Remove the dtype spec for " << in->debugName() << std::endl;
ss << "- Disable partial compilation by setting require_full_compilation to True";
auto warn_str = ss.str();
LOG_WARNING(warn_str);
// Overwrite type map with user settings
first_use_type_map[in][i] = {cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype};
}
} else {
// The user defined the type so no changes are necessary
}

// Insert entry for Value pointer and determined ScalarType
inferred_dtypes.insert({in, {spec[i].dtype}});
}
}
// }
return inferred_dtypes;
}

std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::string method_name, CompileSpec cfg) {
@@ -284,6 +283,15 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::

MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);

// Ensure none of the specified types is an acceptable input type that is incompatible with TRT
// Currently, at::kLong is the only acceptable yet TRT-incompatible type
for (auto value_to_dtypes : first_use_types) {
for (auto dtype : value_to_dtypes.second) {
TORCHTRT_CHECK(
!dtype || dtype.value() != at::kLong, "Cannot specify Int64 input for a model fully compiled in TRT");
}
}

auto engine = conversion::ConvertBlockToEngine(g->block(), cfg.convert_info, static_params);

return engine;
@@ -307,10 +315,24 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
// Infer the type of an input from the weights of the calculation
auto first_use_types = ir::get_block_first_calc_dtypes_opt_collection(g->block());

MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);
// Extract map of IValue to DType
auto type_map = MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);

// Check whether any of the input types are Long
bool user_requested_long = false;
for (auto dtype : type_map) {
user_requested_long |= dtype.second && (dtype.second.value() == at::kLong);
}

// Use dtype map to autocast Tensor-type inputs to Long dtype as necessary
if (cfg.partitioning_info.enabled && cfg.partitioning_info.truncate_long_and_double && user_requested_long) {
auto casts_inserted = lowering::AutocastLongInputs(g, type_map, cfg.lower_info.getGPUDeviceString());
user_requested_long &= (casts_inserted > 0);
}

auto isBlockConvertible = conversion::VerifyConverterSupportForBlock(g->block(), true);
auto outputIsCollection = conversion::OutputIsCollection(g->block());
if (cfg.partitioning_info.enabled &&
if (cfg.partitioning_info.enabled && !user_requested_long &&
(cfg.lower_info.forced_fallback_modules.size() == 0 &&
cfg.partitioning_info.forced_fallback_operators.size() == 0 && isBlockConvertible) &&
!outputIsCollection) {
@@ -320,7 +342,7 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
if (cfg.partitioning_info.enabled &&
(!(cfg.lower_info.forced_fallback_modules.size() == 0 &&
cfg.partitioning_info.forced_fallback_operators.size() == 0 && isBlockConvertible) ||
outputIsCollection)) {
outputIsCollection || user_requested_long)) {
auto graph_and_mapping = BuildHybridGraph(new_mod, g->block(), cfg, static_params, first_use_types);
new_g = graph_and_mapping.first;
// renaming the input name of graph after fallback to ensure pytorch deserialize it correctly
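
In pseudocode, the updated fallback decision in `CompileGraph` reduces to roughly the following (a simplified sketch with illustrative names, not the literal implementation):

```python
# Simplified sketch of the decision logic above (names are illustrative)
def build_hybrid_graph(partitioning_enabled: bool,
                       forced_fallback: bool,
                       block_convertible: bool,
                       output_is_collection: bool,
                       user_requested_long: bool) -> bool:
    # A fully-TRT path is taken only when partitioning is enabled,
    # nothing forces fallback, every op converts, outputs are not
    # collections, and no Int64 input needs a Torch autocast block
    fully_trt = (partitioning_enabled
                 and not user_requested_long
                 and not forced_fallback
                 and block_convertible
                 and not output_is_collection)
    return partitioning_enabled and not fully_trt
```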
2 changes: 1 addition & 1 deletion core/conversion/conversion.cpp
@@ -183,7 +183,7 @@ void AddInputs(ConversionCtx* ctx, c10::ArrayRef<const torch::jit::Value*> input
"Adding Input " << in->debugName() << " (named: " << name << "): " << spec
<< " in engine (conversion.AddInputs)");

auto trt_in = ctx->net->addInput(name.c_str(), spec.dtype, spec.input_shape);
auto trt_in = ctx->net->addInput(name.c_str(), util::ScalarTypeToTRTDataType(spec.dtype), spec.input_shape);
TORCHTRT_CHECK(trt_in, "Failed to add input node: " << in->debugName() << " (conversion.AddInputs)");
trt_in->setAllowedFormats(1U << static_cast<int>(spec.format));

12 changes: 6 additions & 6 deletions core/ir/Input.cpp
@@ -71,7 +71,7 @@ bool valid_input_dtype(nvinfer1::DataType dtype) {

Input::Input(
std::vector<int64_t> shape,
nvinfer1::DataType dtype,
at::ScalarType dtype,
nvinfer1::TensorFormat format,
bool dtype_is_user_defined) {
if (shape.size() > 5) {
@@ -84,10 +84,10 @@ Input::Input(
input_shape = util::toDims(shape);
input_is_dynamic = false;

TORCHTRT_CHECK(valid_input_dtype(dtype), "Unsupported input data type: " << dtype);
TORCHTRT_CHECK(valid_input_dtype(util::ScalarTypeToTRTDataType(dtype)), "Unsupported input data type: " << dtype);
this->dtype = dtype;
TORCHTRT_CHECK(
valid_dtype_format_combo(dtype, format),
valid_dtype_format_combo(util::ScalarTypeToTRTDataType(dtype), format),
"Unsupported combination of dtype and tensor format: ("
<< dtype << ", " << format
<< "), Torch-TensorRT only supports contiguous format (NCHW) except with input type Float32 where channel last (NHWC) is also supported");
@@ -99,7 +99,7 @@ Input::Input(
std::vector<int64_t> min_shape,
std::vector<int64_t> opt_shape,
std::vector<int64_t> max_shape,
nvinfer1::DataType dtype,
at::ScalarType dtype,
nvinfer1::TensorFormat format,
bool dtype_is_user_defined) {
if (min_shape.size() > 5 || opt_shape.size() > 5 || max_shape.size() > 5) {
@@ -137,10 +137,10 @@ Input::Input(

input_shape = util::toDims(dyn_shape);

TORCHTRT_CHECK(valid_input_dtype(dtype), "Unsupported input data type: " << dtype);
TORCHTRT_CHECK(valid_input_dtype(util::ScalarTypeToTRTDataType(dtype)), "Unsupported input data type: " << dtype);
this->dtype = dtype;
TORCHTRT_CHECK(
valid_dtype_format_combo(dtype, format),
valid_dtype_format_combo(util::ScalarTypeToTRTDataType(dtype), format),
"Unsupported combination of dtype and tensor format: ("
<< dtype << ", " << format
<< "), Torch-TensorRT only supports contiguous format (NCHW) except with input type Float32 where channel last (NHWC) is also supported");
7 changes: 4 additions & 3 deletions core/ir/ir.h
@@ -29,16 +29,17 @@ struct Input : torch::CustomClassHolder {
Input(){};
Input(
std::vector<int64_t> shape,
nvinfer1::DataType dtype = nvinfer1::DataType::kFLOAT,
at::ScalarType dtype = at::kFloat,
nvinfer1::TensorFormat format = nvinfer1::TensorFormat::kLINEAR,
bool dtype_is_user_defined = false);
Input(
std::vector<int64_t> min_shape,
std::vector<int64_t> opt_shape,
std::vector<int64_t> max_shape,
nvinfer1::DataType dtype = nvinfer1::DataType::kFLOAT,
at::ScalarType dtype = at::kFloat,
nvinfer1::TensorFormat format = nvinfer1::TensorFormat::kLINEAR,
bool dtype_is_user_defined = false);

friend std::ostream& operator<<(std::ostream& os, const Input& input);

bool input_is_dynamic = false;
Expand All @@ -47,7 +48,7 @@ struct Input : torch::CustomClassHolder {
nvinfer1::Dims min;
nvinfer1::Dims max;
nvinfer1::Dims opt;
nvinfer1::DataType dtype;
at::ScalarType dtype;
nvinfer1::TensorFormat format;
int id;
};
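
For reference, the dynamic-shape `Input` constructor declared here is mirrored by the Python API; a small sketch with illustrative shapes:

```python
import torch
import torch_tensorrt

# Dynamic-shape input with an explicit dtype, mirroring the C++
# Input(min_shape, opt_shape, max_shape, dtype, ...) constructor
inp = torch_tensorrt.Input(
    min_shape=(1, 3, 224, 224),
    opt_shape=(8, 3, 224, 224),
    max_shape=(16, 3, 224, 224),
    dtype=torch.int32,
)
```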
57 changes: 57 additions & 0 deletions core/lowering/lowering.cpp
@@ -26,6 +26,63 @@ void LowerBlock(torch::jit::Block* b) {
DropUnusedNodes(b);
}

int AutocastLongInputs(
std::shared_ptr<torch::jit::Graph>& g,
ir::TypeMap input_type_map,
std::string target_device_name) {
int num_autocasts = 0;
// For each graph input, determine if it can be autocasted
for (int i = 0; i < g->inputs().size(); i++) {
auto input = g->inputs()[i];

// Autocasted inputs must be Tensor-type
if (input->type()->isSubtypeOf(c10::TensorType::get())) {
auto dtype_input = input_type_map.find(input);

// Ensure the data type to be casted to exists in the type map
if (dtype_input == input_type_map.end() || !dtype_input->second) {
LOG_DEBUG("No inferred input dtype for tensor " << input->debugName() << ", skipping autocast");
continue;
}

auto dtype = dtype_input->second.value();
// Currently, we do not autocast inputs for which the determined type is not long
if (dtype != at::kLong) {
continue;
}

LOG_DEBUG("Inserting aten::to casting " << input->debugName() << " to dtype " << dtype);

// Generate cast node sending input tensors to the inferred or specified datatype (long)
auto const_type = g->insertConstant(dtype);
auto const_false = g->insertConstant(0);
const_false->setType(torch::jit::BoolType::get());
auto cuda = g->insertConstant(target_device_name);
cuda->setType(torch::jit::DeviceObjType::get());
auto none_val = g->insertNode(g->createNone())->output();
auto cast_node = g->create(torch::jit::aten::to, {input, cuda, const_type, const_false, const_false, none_val});

// Replace all uses of the original tensor with that of the casted tensor
g->prependNode(cast_node);
input->replaceAllUsesAfterNodeWith(cast_node, cast_node->outputs()[0]);

// Mark the cast node to run in PyTorch for ease of casting
LOG_GRAPH("Marking autocast node " << util::node_info(cast_node) << " to run in PyTorch");
cast_node->i_(c10::Symbol::attr("to_compile"), (int64_t) false);
num_autocasts++;
}
}

LOG_WARNING(
"Input tensors to this Torch-TRT engine may have their data types in-place modified "
<< "if the type does not match the determined required type for TRT. To disable this "
<< "automatic casting, specify an Input dtype other than Long");

LOG_GRAPH("Graph after Autocast: " << *g);

return num_autocasts;
}

void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, std::vector<torch::jit::IValue>& params, LowerInfo lower_info) {
torch::jit::EliminateRedundantGuards(g);
torch::jit::RemoveListMutation(g);
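
The cast node assembled in `AutocastLongInputs` uses the `aten::to` overload taking a device, dtype, and `non_blocking`/`copy` flags; in eager PyTorch the inserted operation is roughly equivalent to this sketch (the device string is illustrative):

```python
import torch

def autocast_long_input(x: torch.Tensor) -> torch.Tensor:
    # Mirrors aten::to(input, device, dtype, non_blocking=False, copy=False)
    # as built above; this step runs in Torch, not TensorRT
    return x.to(torch.device("cuda:0"), torch.long, non_blocking=False, copy=False)

x = torch.zeros(2, 2, dtype=torch.int32, device="cuda")
print(autocast_long_input(x).dtype)  # torch.int64
```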
4 changes: 4 additions & 0 deletions core/lowering/lowering.h
@@ -27,6 +27,10 @@ struct LowerInfo {

void LowerBlock(torch::jit::Block* b);
void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info);
int AutocastLongInputs(
std::shared_ptr<torch::jit::Graph>& g,
ir::TypeMap input_type_map,
std::string target_device_name);
torch::jit::Module LowerModule(
const torch::jit::Module& mod,
std::string method_name,
4 changes: 2 additions & 2 deletions core/partitioning/segmentedblock/SegmentedBlock.cpp
@@ -62,13 +62,13 @@ std::vector<ir::Input> SegmentedBlock::construct_inputs_spec() const {
if (min_shapes_.size() == opt_shapes_.size() && opt_shapes_.size() == max_shapes_.size()) {
for (uint64_t i = 0; i < opt_shapes_.size(); i++) {
auto in = ir::Input(min_shapes_[i], opt_shapes_[i], max_shapes_[i]);
in.dtype = util::ScalarTypeToTRTDataType(in_types_[i]);
in.dtype = in_types_[i];
inputs.push_back(in);
}
} else {
for (uint64_t i = 0; i < opt_shapes_.size(); i++) {
auto in = ir::Input(opt_shapes_[i]);
in.dtype = util::ScalarTypeToTRTDataType(in_types_[i]);
in.dtype = in_types_[i];
inputs.push_back(in);
}
}
4 changes: 2 additions & 2 deletions core/partitioning/shape_analysis.cpp
@@ -266,10 +266,10 @@ void getSegmentsOutputByRunning(
"Unable to process subgraph input type of at::kLong/at::kDouble, try to compile model with truncate_long_and_double enabled");
} else if (partitioning_info.truncate_long_and_double && t == at::kLong) {
cur_ivalue = cur_ivalue.toTensor().to(at::kInt);
LOG_WARNING("Truncating graph input type from at::kLong to at::kInt");
LOG_WARNING("Truncating intermediate graph input type from at::kLong to at::kInt");
} else if (partitioning_info.truncate_long_and_double && t == at::kDouble) {
cur_ivalue = cur_ivalue.toTensor().to(at::kFloat);
LOG_WARNING("Truncating graph input type from at::kDouble to at::kFloat");
LOG_WARNING("Truncating intermediate graph input type from at::kDouble to at::kFloat");
}
c10::optional<nvinfer1::DataType> dtype = util::optTypeMetaToTRTDataType(cur_ivalue.toTensor().dtype());
if (dtype == c10::nullopt) {
1 change: 1 addition & 0 deletions core/util/trt_util.cpp
@@ -251,6 +251,7 @@ const std::unordered_map<at::ScalarType, nvinfer1::DataType>& get_at_trt_type_ma
{at::kFloat, nvinfer1::DataType::kFLOAT},
{at::kHalf, nvinfer1::DataType::kHALF},
{at::kInt, nvinfer1::DataType::kINT32},
{at::kLong, nvinfer1::DataType::kINT32},
{at::kChar, nvinfer1::DataType::kINT8},
{at::kBool, nvinfer1::DataType::kBOOL}};
return at_trt_type_map;
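
Mapping `at::kLong` to `nvinfer1::DataType::kINT32` makes the truncation explicit; values outside the Int32 range wrap around, as this small sketch shows:

```python
import torch

x = torch.tensor([2**31, 7], dtype=torch.long)
print(x.to(torch.int32))
# tensor([-2147483648, 7], dtype=torch.int32) -- out-of-range values
# overflow, which is why truncate_long_and_double is opt-in
```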
2 changes: 2 additions & 0 deletions cpp/include/torch_tensorrt/torch_tensorrt.h
@@ -58,6 +58,8 @@ class DataType {
* ex. torch_tensorrt::DataType type = DataType::kFloat;
*/
enum Value : int8_t {
/// INT64
kLong,
/// FP32
kFloat,
/// FP16
21 changes: 20 additions & 1 deletion cpp/src/types.cpp
@@ -87,6 +87,25 @@ nvinfer1::DataType toTRTDataType(DataType value) {
}
}

at::ScalarType toAtDataType(DataType value) {
switch (value) {
case DataType::kChar:
return at::kChar;
case DataType::kHalf:
return at::kHalf;
case DataType::kInt:
return at::kInt;
case DataType::kLong:
return at::kLong;
case DataType::kBool:
return at::kBool;
case DataType::kFloat:
case DataType::kUnknown:
default:
return at::kFloat;
}
}

nvinfer1::TensorFormat toTRTTensorFormat(TensorFormat value) {
TORCHTRT_CHECK(!(value == TensorFormat::kUnknown), "Tensor format is unknown");
switch (value) {
@@ -267,7 +286,7 @@ torch_tensorrt::core::ir::Input to_internal_input(Input& i) {
i.min_shape,
i.opt_shape,
i.max_shape,
toTRTDataType(i.dtype),
toAtDataType(i.dtype),
toTRTTensorFormat(i.format),
!(i.dtype == DataType::kUnknown));
}