feat: Add option to specify int64 as an Input dtype #1551

Merged · 2 commits · Jan 9, 2023
84 changes: 53 additions & 31 deletions core/compiler.cpp
@@ -187,7 +187,7 @@ partitioning::GraphAndMapping BuildHybridGraph(
return partitioning::stitch(&partitioning_ctx, block);
}

void MapInputsAndDetermineDTypes(
ir::TypeMap MapInputsAndDetermineDTypes(
CompileSpec& cfg,
std::shared_ptr<torch::jit::Graph>& g,
ir::StaticParams& static_params,
@@ -197,6 +197,7 @@ void MapInputsAndDetermineDTypes(
cfg.partitioning_info.collection_input_spec_map =
ir::CollectionInputSpecMap(cfg.convert_info.collection_input_spec_map);

ir::TypeMap inferred_dtypes;
auto collection_inputs = ir::get_collection_inputs(g, static_params);
LOG_DEBUG(
"In MapInputsAndDetermineDTypes, the g->inputs() size is "
@@ -218,13 +219,13 @@ void MapInputsAndDetermineDTypes(
LOG_INFO(
"Since input type is not explicitly defined, infering using first tensor calculation\n Inferred input "
<< in->debugName() << " has type " << est_type_opt[i].value());
spec[i].dtype = util::ScalarTypeToTRTDataType(est_type_opt[i].value());
spec[i].dtype = est_type_opt[i].value();
} else if (!est_type_opt[i] && !spec[i].dtype_is_user_defined) {
// If we cannot calculate the type and the user did not define the type, then default to FP32
LOG_WARNING(
"Cannot infer input type from calcuations in graph for input "
<< in->debugName() << ". Assuming it is Float32. If not, specify input type explicity");
spec[i].dtype = nvinfer1::DataType::kFLOAT;
spec[i].dtype = at::kFloat;
} else if (spec[i].dtype_is_user_defined && cfg.partitioning_info.enabled) {
if (!est_type_opt[i]) {
LOG_INFO("Cannot infer input tensor dtype in graph, compiler is going to use the user setting");
@@ -236,37 +237,35 @@ void MapInputsAndDetermineDTypes(
auto warn_str = ss.str();
LOG_WARNING(warn_str);
// Overwrite type map with user settings
first_use_type_map[in][i] = {
util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype)};

} else {
if (util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype) !=
est_type_opt[i].value()) {
std::stringstream ss;
ss << "For input " << in->debugName() << ", found user specified input dtype as ";
ss << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
ss << ", however when inspecting the graph, the input type expected was inferred to be ";
ss << est_type_opt[i].value() << std::endl;
ss << "The compiler is going to use the user setting "
<< cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
ss << "\nThis conflict may cause an error at runtime due to partial compilation being enabled and therefore\n";
ss << "compatibility with PyTorch's data type convention is required.\n";
ss << "If you do indeed see errors at runtime either:\n";
ss << "- Remove the dtype spec for " << in->debugName() << std::endl;
ss << "- Disable partial compilation by setting require_full_compilation to True";
auto warn_str = ss.str();
LOG_WARNING(warn_str);
// Overwrite type map with user settings
first_use_type_map[in][i] = {
util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype)};
}
first_use_type_map[in][i] = {cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype};

} else if (cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype != est_type_opt[i].value()) {
std::stringstream ss;
ss << "For input " << in->debugName() << ", found user specified input dtype as ";
ss << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
ss << ", however when inspecting the graph, the input type expected was inferred to be ";
ss << est_type_opt[i].value() << std::endl;
ss << "The compiler is going to use the user setting "
<< cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
ss << "\nThis conflict may cause an error at runtime due to partial compilation being enabled and therefore\n";
ss << "compatibility with PyTorch's data type convention is required.\n";
ss << "If you do indeed see errors at runtime either:\n";
ss << "- Remove the dtype spec for " << in->debugName() << std::endl;
ss << "- Disable partial compilation by setting require_full_compilation to True";
auto warn_str = ss.str();
LOG_WARNING(warn_str);
// Overwrite type map with user settings
first_use_type_map[in][i] = {cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype};
}
} else {
// The user defined the type so no changes are necessary
}

// Insert entry for Value pointer and determined ScalarType
inferred_dtypes.insert({in, {spec[i].dtype}});
}
}
// }
return inferred_dtypes;
}

std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::string method_name, CompileSpec cfg) {
@@ -284,6 +283,15 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::

MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);

// Ensure none of the specified types is an accepted input type that is incompatible with TRT
// Currently, at::kLong is the only accepted type that TRT cannot ingest directly
for (auto value_to_dtypes : first_use_types) {
for (auto dtype : value_to_dtypes.second) {
TORCHTRT_CHECK(
!dtype || dtype.value() != at::kLong, "Cannot specify Int64 input for a model fully compiled in TRT");
}
}

auto engine = conversion::ConvertBlockToEngine(g->block(), cfg.convert_info, static_params);

return engine;
@@ -307,10 +315,24 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
// Infer the type of an input from the weights of the calculation
auto first_use_types = ir::get_block_first_calc_dtypes_opt_collection(g->block());

MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);
// Extract map of IValue to DType
auto type_map = MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);

// Check whether any of the input types are Long
bool user_requested_long = false;
for (auto dtype : type_map) {
user_requested_long |= dtype.second && (dtype.second.value() == at::kLong);
}

// Use dtype map to autocast Tensor-type inputs to Long dtype as necessary
if (cfg.partitioning_info.enabled && cfg.partitioning_info.truncate_long_and_double && user_requested_long) {
auto casts_inserted = lowering::AutocastLongInputs(g, type_map, cfg.lower_info.getGPUDeviceString());
user_requested_long &= (casts_inserted > 0);
}

auto isBlockConvertible = conversion::VerifyConverterSupportForBlock(g->block(), true);
auto outputIsCollection = conversion::OutputIsCollection(g->block());
if (cfg.partitioning_info.enabled &&
if (cfg.partitioning_info.enabled && !user_requested_long &&
(cfg.lower_info.forced_fallback_modules.size() == 0 &&
cfg.partitioning_info.forced_fallback_operators.size() == 0 && isBlockConvertible) &&
!outputIsCollection) {
@@ -320,7 +342,7 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
if (cfg.partitioning_info.enabled &&
(!(cfg.lower_info.forced_fallback_modules.size() == 0 &&
cfg.partitioning_info.forced_fallback_operators.size() == 0 && isBlockConvertible) ||
outputIsCollection)) {
outputIsCollection || user_requested_long)) {
auto graph_and_mapping = BuildHybridGraph(new_mod, g->block(), cfg, static_params, first_use_types);
new_g = graph_and_mapping.first;
// renaming the input name of graph after fallback to ensure pytorch deserialize it correctly
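Taken together, the `core/compiler.cpp` changes give Long inputs two distinct paths: `ConvertGraphToTRTEngine` still rejects `at::kLong` outright (the new `TORCHTRT_CHECK`), while `CompileGraph` records the request and forces the hybrid path so the casts can live in Torch-executed segments. A minimal sketch of the resulting gating, with simplified signatures (the real checks also consult the forced-fallback lists), assuming nothing beyond what the hunks above show:

```cpp
// Condenses the branch conditions added to CompileGraph above.
// user_requested_long is true when some input dtype resolved to at::kLong,
// and it stays true only if autocast insertion actually took place.
bool build_single_trt_engine(
    bool partitioning_enabled,
    bool block_convertible,
    bool output_is_collection,
    bool user_requested_long) {
  // A single fully-converted engine is only possible when nothing forces
  // fallback and no Long inputs remain in play.
  return partitioning_enabled && !user_requested_long && block_convertible &&
      !output_is_collection;
}
```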
2 changes: 1 addition & 1 deletion core/conversion/conversion.cpp
@@ -183,7 +183,7 @@ void AddInputs(ConversionCtx* ctx, c10::ArrayRef<const torch::jit::Value*> input
"Adding Input " << in->debugName() << " (named: " << name << "): " << spec
<< " in engine (conversion.AddInputs)");

auto trt_in = ctx->net->addInput(name.c_str(), spec.dtype, spec.input_shape);
auto trt_in = ctx->net->addInput(name.c_str(), util::ScalarTypeToTRTDataType(spec.dtype), spec.input_shape);
TORCHTRT_CHECK(trt_in, "Failed to add input node: " << in->debugName() << " (conversion.AddInputs)");
trt_in->setAllowedFormats(1U << static_cast<int>(spec.format));

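With `ir::Input::dtype` now an `at::ScalarType`, the Torch-to-TRT type translation happens exactly once, at engine-input registration. A hedged sketch of that boundary (repo-relative include paths and namespaces are assumptions):

```cpp
#include <string>

#include "NvInfer.h"
#include "core/ir/ir.h"
#include "core/util/trt_util.h"

// spec.dtype is a torch ScalarType; TRT only ever sees the mapped
// nvinfer1::DataType at addInput time, mirroring the hunk above.
nvinfer1::ITensor* register_input(
    nvinfer1::INetworkDefinition* net,
    const std::string& name,
    const torch_tensorrt::core::ir::Input& spec) {
  return net->addInput(
      name.c_str(),
      torch_tensorrt::core::util::ScalarTypeToTRTDataType(spec.dtype),
      spec.input_shape);
}
```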
12 changes: 6 additions & 6 deletions core/ir/Input.cpp
@@ -71,7 +71,7 @@ bool valid_input_dtype(nvinfer1::DataType dtype) {

Input::Input(
std::vector<int64_t> shape,
nvinfer1::DataType dtype,
at::ScalarType dtype,
nvinfer1::TensorFormat format,
bool dtype_is_user_defined) {
if (shape.size() > 5) {
@@ -84,10 +84,10 @@ Input::Input(
input_shape = util::toDims(shape);
input_is_dynamic = false;

TORCHTRT_CHECK(valid_input_dtype(dtype), "Unsupported input data type: " << dtype);
TORCHTRT_CHECK(valid_input_dtype(util::ScalarTypeToTRTDataType(dtype)), "Unsupported input data type: " << dtype);
this->dtype = dtype;
TORCHTRT_CHECK(
valid_dtype_format_combo(dtype, format),
valid_dtype_format_combo(util::ScalarTypeToTRTDataType(dtype), format),
"Unsupported combination of dtype and tensor format: ("
<< dtype << ", " << format
<< "), Torch-TensorRT only supports contiguous format (NCHW) except with input type Float32 where channel last (NHWC) is also supported");
@@ -99,7 +99,7 @@ Input::Input(
std::vector<int64_t> min_shape,
std::vector<int64_t> opt_shape,
std::vector<int64_t> max_shape,
nvinfer1::DataType dtype,
at::ScalarType dtype,
nvinfer1::TensorFormat format,
bool dtype_is_user_defined) {
if (min_shape.size() > 5 || opt_shape.size() > 5 || max_shape.size() > 5) {
@@ -137,10 +137,10 @@ Input::Input(

input_shape = util::toDims(dyn_shape);

TORCHTRT_CHECK(valid_input_dtype(dtype), "Unsupported input data type: " << dtype);
TORCHTRT_CHECK(valid_input_dtype(util::ScalarTypeToTRTDataType(dtype)), "Unsupported input data type: " << dtype);
this->dtype = dtype;
TORCHTRT_CHECK(
valid_dtype_format_combo(dtype, format),
valid_dtype_format_combo(util::ScalarTypeToTRTDataType(dtype), format),
"Unsupported combination of dtype and tensor format: ("
<< dtype << ", " << format
<< "), Torch-TensorRT only supports contiguous format (NCHW) except with input type Float32 where channel last (NHWC) is also supported");
7 changes: 4 additions & 3 deletions core/ir/ir.h
@@ -29,16 +29,17 @@ struct Input : torch::CustomClassHolder {
Input(){};
Input(
std::vector<int64_t> shape,
nvinfer1::DataType dtype = nvinfer1::DataType::kFLOAT,
at::ScalarType dtype = at::kFloat,
nvinfer1::TensorFormat format = nvinfer1::TensorFormat::kLINEAR,
bool dtype_is_user_defined = false);
Input(
std::vector<int64_t> min_shape,
std::vector<int64_t> opt_shape,
std::vector<int64_t> max_shape,
nvinfer1::DataType dtype = nvinfer1::DataType::kFLOAT,
at::ScalarType dtype = at::kFloat,
nvinfer1::TensorFormat format = nvinfer1::TensorFormat::kLINEAR,
bool dtype_is_user_defined = false);

friend std::ostream& operator<<(std::ostream& os, const Input& input);

bool input_is_dynamic = false;
@@ -47,7 +48,7 @@ struct Input : torch::CustomClassHolder {
nvinfer1::Dims min;
nvinfer1::Dims max;
nvinfer1::Dims opt;
nvinfer1::DataType dtype;
at::ScalarType dtype;
nvinfer1::TensorFormat format;
int id;
};
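Since `ir::Input` now stores an `at::ScalarType` (defaulting to `at::kFloat`), internal call sites build specs in Torch's type vocabulary and defer TRT validation to the mapping. A construction sketch under the same namespace assumptions as above:

```cpp
// A Long-typed, static-shape internal input spec. The constructor validates
// by mapping through util::ScalarTypeToTRTDataType, so at::kLong is accepted
// (it aliases to INT32 in the type map changed later in this PR).
auto spec = torch_tensorrt::core::ir::Input(
    /*shape=*/{1, 16},
    /*dtype=*/at::kLong,
    /*format=*/nvinfer1::TensorFormat::kLINEAR,
    /*dtype_is_user_defined=*/true);
```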
69 changes: 69 additions & 0 deletions core/lowering/lowering.cpp
@@ -26,6 +26,75 @@ void LowerBlock(torch::jit::Block* b) {
DropUnusedNodes(b);
}

int AutocastLongInputs(
std::shared_ptr<torch::jit::Graph>& g,
ir::TypeMap input_type_map,
std::string target_device_name) {
int num_autocasts = 0;
// For each graph input, determine if it can be autocasted
for (int i = 0; i < g->inputs().size(); i++) {
auto input = g->inputs()[i];

// Autocasted inputs must be Tensor-type
if (input->type()->isSubtypeOf(c10::TensorType::get())) {
auto dtype_input = input_type_map.find(input);

// Ensure the data type to be casted to exists in the type map
if (dtype_input == input_type_map.end() || !dtype_input->second) {
LOG_DEBUG("No inferred input dtype for tensor " << input->debugName() << ", skipping autocast");
continue;
}

auto dtype = dtype_input->second.value();
// Currently, we do not autocast inputs for which the determined type is not long
if (dtype != at::kLong) {
LOG_DEBUG(
"Skipping autocast for tensor " << input->debugName() << ", since its dtype is " << dtype
<< " and not at::kLong");
continue;
}

LOG_DEBUG("Inserting aten::to casting " << input->debugName() << " to dtype " << dtype);

// Generate cast node sending input tensors to the inferred or specified datatype (long)
torch::jit::Value *const_false, *cuda, *none_val;
if (num_autocasts == 0) {
// Only generate constants once and reuse for all autocasts
const_false = g->insertConstant(0);
const_false->setType(torch::jit::BoolType::get());
cuda = g->insertConstant(target_device_name);
cuda->setType(torch::jit::DeviceObjType::get());
none_val = g->insertNode(g->createNone())->output();
}

auto const_type = g->insertConstant(dtype);
auto cast_node = g->create(torch::jit::aten::to, {input, cuda, const_type, const_false, const_false, none_val});

// Replace all uses of the original tensor with that of the casted tensor
g->prependNode(cast_node);
input->replaceAllUsesAfterNodeWith(cast_node, cast_node->outputs()[0]);

// Mark the cast node to run in PyTorch for ease of casting
LOG_GRAPH("Marking autocast node " << util::node_info(cast_node) << " to run in PyTorch");
cast_node->i_(c10::Symbol::attr("to_compile"), (int64_t) false);
num_autocasts++;
}
}

LOG_GRAPH("Inserted " << num_autocasts << " autocasts");

if (num_autocasts > 0) {
LOG_WARNING(
"Data types for input tensors have been modified by inserting "
<< "aten::to operations which cast INT64 inputs to INT32. "
<< "To disable this, please recompile using INT32 inputs");

LOG_GRAPH("Graph after Autocast: " << *g);
}

return num_autocasts;
}

void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, std::vector<torch::jit::IValue>& params, LowerInfo lower_info) {
torch::jit::EliminateRedundantGuards(g);
torch::jit::RemoveListMutation(g);
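`AutocastLongInputs` is the heart of the feature: for each graph input whose recorded dtype is `at::kLong`, it prepends an `aten::to` cast pinned to the target device and tags the node `to_compile = false`, so it always executes in PyTorch rather than TRT. A minimal invocation sketch, assuming the signatures added in this PR and that `ir::TypeMap` maps `Value*` to `c10::optional<at::ScalarType>`:

```cpp
#include <torch/csrc/jit/ir/irparser.h>

#include "core/ir/ir.h"
#include "core/lowering/lowering.h"

// Build a toy graph, mark its input as Long, and let the pass insert the cast.
void autocast_demo() {
  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(R"IR(
    graph(%x : Tensor):
      %r : Tensor = aten::relu(%x)
      return (%r))IR", g.get());

  torch_tensorrt::core::ir::TypeMap types;
  types.insert({g->inputs()[0], {at::kLong}});

  int n = torch_tensorrt::core::lowering::AutocastLongInputs(g, types, "cuda:0");
  // n == 1: the graph now begins with an aten::to node feeding aten::relu,
  // and all former uses of %x consume the casted value instead.
}
```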
4 changes: 4 additions & 0 deletions core/lowering/lowering.h
@@ -27,6 +27,10 @@ struct LowerInfo {

void LowerBlock(torch::jit::Block* b);
void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info);
int AutocastLongInputs(
std::shared_ptr<torch::jit::Graph>& g,
ir::TypeMap input_type_map,
std::string target_device_name);
torch::jit::Module LowerModule(
const torch::jit::Module& mod,
std::string method_name,
4 changes: 2 additions & 2 deletions core/partitioning/segmentedblock/SegmentedBlock.cpp
@@ -62,13 +62,13 @@ std::vector<ir::Input> SegmentedBlock::construct_inputs_spec() const {
if (min_shapes_.size() == opt_shapes_.size() && opt_shapes_.size() == max_shapes_.size()) {
for (uint64_t i = 0; i < opt_shapes_.size(); i++) {
auto in = ir::Input(min_shapes_[i], opt_shapes_[i], max_shapes_[i]);
in.dtype = util::ScalarTypeToTRTDataType(in_types_[i]);
in.dtype = in_types_[i];
inputs.push_back(in);
}
} else {
for (uint64_t i = 0; i < opt_shapes_.size(); i++) {
auto in = ir::Input(opt_shapes_[i]);
in.dtype = util::ScalarTypeToTRTDataType(in_types_[i]);
in.dtype = in_types_[i];
inputs.push_back(in);
}
}
4 changes: 2 additions & 2 deletions core/partitioning/shape_analysis.cpp
@@ -266,10 +266,10 @@ void getSegmentsOutputByRunning(
"Unable to process subgraph input type of at::kLong/at::kDouble, try to compile model with truncate_long_and_double enabled");
} else if (partitioning_info.truncate_long_and_double && t == at::kLong) {
cur_ivalue = cur_ivalue.toTensor().to(at::kInt);
LOG_WARNING("Truncating graph input type from at::kLong to at::kInt");
LOG_WARNING("Truncating intermediate graph input type from at::kLong to at::kInt");
} else if (partitioning_info.truncate_long_and_double && t == at::kDouble) {
cur_ivalue = cur_ivalue.toTensor().to(at::kFloat);
LOG_WARNING("Truncating graph input type from at::kDouble to at::kFloat");
LOG_WARNING("Truncating intermediate graph input type from at::kDouble to at::kFloat");
}
c10::optional<nvinfer1::DataType> dtype = util::optTypeMetaToTRTDataType(cur_ivalue.toTensor().dtype());
if (dtype == c10::nullopt) {
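The reworded warnings separate these truncations, which apply to intermediate segment inputs during shape analysis, from the graph-input autocasts added in lowering. The operation itself is an ordinary tensor cast; a small illustration in plain ATen:

```cpp
#include <ATen/ATen.h>

// What truncate_long_and_double does to an intermediate value, per the hunk above.
void truncate_demo() {
  at::Tensor v = at::ones({2, 2}, at::kLong);
  v = v.to(at::kInt); // at::kLong -> at::kInt ("Truncating intermediate graph input type...")

  at::Tensor w = at::ones({2, 2}, at::kDouble);
  w = w.to(at::kFloat); // at::kDouble -> at::kFloat
}
```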
1 change: 1 addition & 0 deletions core/util/trt_util.cpp
@@ -251,6 +251,7 @@ const std::unordered_map<at::ScalarType, nvinfer1::DataType>& get_at_trt_type_ma
{at::kFloat, nvinfer1::DataType::kFLOAT},
{at::kHalf, nvinfer1::DataType::kHALF},
{at::kInt, nvinfer1::DataType::kINT32},
{at::kLong, nvinfer1::DataType::kINT32},
{at::kChar, nvinfer1::DataType::kINT8},
{at::kBool, nvinfer1::DataType::kBOOL}};
return at_trt_type_map;
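This single map entry is what makes a Long spec representable at all: TRT has no INT64 tensor type, so `at::kLong` is deliberately aliased to `kINT32`. The mapping is therefore no longer invertible, which is why correctness rests on the inserted autocasts rather than on this table. A sketch (namespaces assumed as above):

```cpp
// Int and Long now resolve to the same TRT binding type.
auto a = torch_tensorrt::core::util::ScalarTypeToTRTDataType(at::kInt);  // kINT32
auto b = torch_tensorrt::core::util::ScalarTypeToTRTDataType(at::kLong); // kINT32, via the new entry
```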
2 changes: 2 additions & 0 deletions cpp/include/torch_tensorrt/torch_tensorrt.h
@@ -58,6 +58,8 @@ class DataType {
* ex. torch_tensorrt::DataType type = DataType::kFloat;
*/
enum Value : int8_t {
/// INT64
kLong,
/// FP32
kFloat,
/// FP16
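The new `DataType::kLong` enum value surfaces the feature in the C++ frontend. A hedged end-to-end usage sketch; `mod` stands for a loaded TorchScript module, and the option names are taken from the public header (partial compilation plus `truncate_long_and_double` lets the inserted casts handle the INT64-to-INT32 narrowing):

```cpp
#include <torch/script.h>

#include "torch_tensorrt/torch_tensorrt.h"

torch::jit::Module compile_with_long_input(torch::jit::Module& mod) {
  auto in = torch_tensorrt::Input({1, 3, 224, 224}, torch_tensorrt::DataType::kLong);
  torch_tensorrt::ts::CompileSpec spec({in});
  spec.truncate_long_and_double = true;
  spec.require_full_compilation = false; // fully-compiled models reject Int64 inputs
  return torch_tensorrt::ts::compile(mod, spec);
}
```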