feat: Automatically cast user inputs to inferred data type
- Add post-lowering pass to insert `aten::to` operators for Tensor
inputs determined to require float or int values
- Specifically, if the user provides a non-float input to a
float-dtype input field and has `truncate_long_and_double=True`, a
Torch-executed graph block is inserted which casts that input to a
float in place (see the sketch below)
- This operation modifies user-provided tensors in place and logs a
warning to that effect
- Currently, the feature is only functional for Tensor inputs (not input
signatures) and only casts to int and float types; if the input is
specified as any other type, no cast is inserted
- Modify compiler to extract inferred data types for each input
- Add testing to ensure casts are inserted correctly and run in Torch
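
Schematically, the inserted pattern for a float-dtype Tensor input looks like the
following (an illustrative sketch, not literal compiler output; the device constant
comes from the configured GPU device string, e.g. "cuda:0", and value names are
made up):

    graph(%x.1 : Tensor):
      %device : Device = prim::Constant[value="cuda:0"]()
      %dtype : int = prim::Constant[value=6]() # 6 == at::kFloat
      %false : bool = prim::Constant[value=0]()
      %none : NoneType = prim::Constant()
      %x.cast : Tensor = aten::to(%x.1, %device, %dtype, %false, %false, %none)
      # ... all downstream uses of %x.1 are rewritten to consume %x.cast ...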
gs-olive committed Dec 15, 2022
1 parent 2ef6c3a commit 3677fd3
Showing 4 changed files with 144 additions and 24 deletions.
58 changes: 34 additions & 24 deletions core/compiler.cpp
@@ -187,7 +187,7 @@ partitioning::GraphAndMapping BuildHybridGraph(
return partitioning::stitch(&partitioning_ctx, block);
}

void MapInputsAndDetermineDTypes(
ir::TypeMap MapInputsAndDetermineDTypes(
CompileSpec& cfg,
std::shared_ptr<torch::jit::Graph>& g,
ir::StaticParams& static_params,
@@ -197,6 +197,7 @@ void MapInputsAndDetermineDTypes(
cfg.partitioning_info.collection_input_spec_map =
ir::CollectionInputSpecMap(cfg.convert_info.collection_input_spec_map);

ir::TypeMap inferred_dtypes;
auto collection_inputs = ir::get_collection_inputs(g, static_params);
LOG_DEBUG(
"In MapInputsAndDetermineDTypes, the g->inputs() size is "
@@ -239,34 +240,36 @@ void MapInputsAndDetermineDTypes(
first_use_type_map[in][i] = {
util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype)};

} else {
if (util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype) !=
est_type_opt[i].value()) {
std::stringstream ss;
ss << "For input " << in->debugName() << ", found user specified input dtype as ";
ss << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
ss << ", however when inspecting the graph, the input type expected was inferred to be ";
ss << est_type_opt[i].value() << std::endl;
ss << "The compiler is going to use the user setting "
<< cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
ss << "\nThis conflict may cause an error at runtime due to partial compilation being enabled and therefore\n";
ss << "compatibility with PyTorch's data type convention is required.\n";
ss << "If you do indeed see errors at runtime either:\n";
ss << "- Remove the dtype spec for " << in->debugName() << std::endl;
ss << "- Disable partial compilation by setting require_full_compilation to True";
auto warn_str = ss.str();
LOG_WARNING(warn_str);
// Overwrite type map with user settings
first_use_type_map[in][i] = {
util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype)};
}
} else if (
util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype) !=
est_type_opt[i].value()) {
std::stringstream ss;
ss << "For input " << in->debugName() << ", found user specified input dtype as ";
ss << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
ss << ", however when inspecting the graph, the input type expected was inferred to be ";
ss << est_type_opt[i].value() << std::endl;
ss << "The compiler is going to use the user setting "
<< cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
ss << "\nThis conflict may cause an error at runtime due to partial compilation being enabled and therefore\n";
ss << "compatibility with PyTorch's data type convention is required.\n";
ss << "If you do indeed see errors at runtime either:\n";
ss << "- Remove the dtype spec for " << in->debugName() << std::endl;
ss << "- Disable partial compilation by setting require_full_compilation to True";
auto warn_str = ss.str();
LOG_WARNING(warn_str);
// Overwrite type map with user settings
first_use_type_map[in][i] = {
util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype)};
}
} else {
// The user defined the type so no changes are necessary
}

// Insert entry for Value pointer and determined ScalarType
inferred_dtypes.insert({in, c10::optional<c10::ScalarType>(util::TRTDataTypeToScalarType(spec[i].dtype))});
}
}
// }
return inferred_dtypes;
}

std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::string method_name, CompileSpec cfg) {
@@ -307,7 +310,14 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
// Infer the type of an input from the weights of the calculation
auto first_use_types = ir::get_block_first_calc_dtypes_opt_collection(g->block());

MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);
// Extract map of IValue to DType
auto type_map = MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);

// Use dtype map to autocast inputs to the correct type
if (cfg.partitioning_info.enabled && cfg.partitioning_info.truncate_long_and_double) {
lowering::AutocastInputs(g, type_map, cfg.lower_info.getGPUDeviceString());
}

auto isBlockConvertible = conversion::VerifyConverterSupportForBlock(g->block(), true);
auto outputIsCollection = conversion::OutputIsCollection(g->block());
if (cfg.partitioning_info.enabled &&
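
For reference, the ir::TypeMap returned by MapInputsAndDetermineDTypes above maps
each graph input Value* to its (optional) inferred scalar type. A minimal sketch of
the assumed alias, based on how the map is populated here and in the test below (the
real definition lives in core/ir/ir.h):

    // Assumed shape of ir::TypeMap, inferred from usage such as
    //   inferred_dtypes.insert({in, c10::optional<c10::ScalarType>(...)});
    using TypeMap =
        std::unordered_map<const torch::jit::Value*, c10::optional<c10::ScalarType>>;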
50 changes: 50 additions & 0 deletions core/lowering/lowering.cpp
@@ -26,6 +26,56 @@ void LowerBlock(torch::jit::Block* b) {
DropUnusedNodes(b);
}

void AutocastInputs(std::shared_ptr<torch::jit::Graph>& g, ir::TypeMap input_type_map, std::string target_device_name) {
// For each graph input, determine if it can be autocasted
for (int i = 0; i < g->inputs().size(); i++) {
auto input = g->inputs()[i];

// Autocasted inputs must be Tensor-type
if (input->type()->isSubtypeOf(c10::TensorType::get())) {
auto dtype_input = input_type_map.find(input);

// Ensure the data type to be casted to exists in the type map
if (dtype_input == input_type_map.end() || !dtype_input->second) {
LOG_DEBUG("No inferred input dtype for tensor " << input->debugName() << ", skipping autocast");
continue;
}

auto dtype = dtype_input->second.value();
// Currently, we do not autocast inputs for which the determined type is not int or float
if (!(dtype == at::kFloat || dtype == at::kInt)) {
continue;
}

LOG_DEBUG("Inserting aten::to casting " << input->debugName() << " to dtype " << dtype);

// Generate cast node sending input tensors to the inferred or specified datatype (float or int)
auto const_type = g->insertConstant(dtype);
auto const_false = g->insertConstant(0);
const_false->setType(torch::jit::BoolType::get());
auto cuda = g->insertConstant(target_device_name);
cuda->setType(torch::jit::DeviceObjType::get());
auto none_val = g->insertNode(g->createNone())->output();
auto cast_node = g->create(torch::jit::aten::to, {input, cuda, const_type, const_false, const_false, none_val});

// Replace all uses of the original tensor with that of the casted tensor
g->prependNode(cast_node);
input->replaceAllUsesAfterNodeWith(cast_node, cast_node->outputs()[0]);

// Mark the cast node to run in PyTorch for ease of casting
LOG_GRAPH("Marking autocast node " << util::node_info(cast_node) << " to run in PyTorch");
cast_node->i_(c10::Symbol::attr("to_compile"), (int64_t) false);
}
}

LOG_WARNING(
"Input tensors to this Torch-TRT engine may have their data types in-place modified "
<< "if the type does not match the determined required type for the model. To disable this "
<< "automatic casting, specify truncate_long_and_double=False");

LOG_GRAPH("Graph after Autocast: " << *g);
}

void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, std::vector<torch::jit::IValue>& params, LowerInfo lower_info) {
torch::jit::EliminateRedundantGuards(g);
torch::jit::RemoveListMutation(g);
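
The six-argument aten::to node that AutocastInputs builds above corresponds to the
Tensor::to(Device, ScalarType, non_blocking, copy) eager-mode overload. A minimal
standalone sketch of the equivalent cast (assuming a CUDA-enabled libtorch build;
this snippet is illustrative and not part of the commit):

    #include <torch/torch.h>

    int main() {
      // Stand-in for a user-provided int tensor bound to a float-dtype input
      auto input = torch::ones({5, 5}, torch::kInt);

      // Eager equivalent of the inserted graph node:
      //   aten::to(%input, %device, %dtype, %false, %false, %none)
      auto casted = input.to(
          torch::Device("cuda"),
          torch::kFloat,
          /*non_blocking=*/false,
          /*copy=*/false);

      // The graph passes None for memory_format, i.e. it is left unset here
      return casted.scalar_type() == torch::kFloat ? 0 : 1;
    }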
1 change: 1 addition & 0 deletions core/lowering/lowering.h
@@ -27,6 +27,7 @@ struct LowerInfo {

void LowerBlock(torch::jit::Block* b);
void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info);
void AutocastInputs(std::shared_ptr<torch::jit::Graph>& g, ir::TypeMap input_type_map, std::string target_device_name);
torch::jit::Module LowerModule(
const torch::jit::Module& mod,
std::string method_name,
59 changes: 59 additions & 0 deletions tests/core/partitioning/test_type_auto_conversion.cpp
@@ -1,4 +1,6 @@
#include <string>
#include "core/ir/ir.h"
#include "core/lowering/lowering.h"
#include "core/partitioning/partitioning.h"
#include "core/util/trt_util.h"
#include "gtest/gtest.h"
@@ -107,3 +109,60 @@ TEST(Partitioning, ImplicitAutoConversionCorrectly) {
}
ASSERT_TRUE(checkInsertedCastNodeNumber(segmented_blocks[1], 2));
}

TEST(Partitioning, AutoCastingInputIntsFloatsCorrectly) {
const auto graph = R"IR(
graph(%x.1 : Tensor,
%y.1 : Tensor):
%k.1 : int = prim::Constant[value=1]() # examples/custom_converters/toy_model.py:38:12
%3 : int = prim::Constant[value=2]() # examples/custom_converters/toy_model.py:40:13
%x.5 : Tensor = aten::add_(%x.1, %y.1, %k.1) # examples/custom_converters/toy_model.py:39:8
%23 : Tensor = aten::mul(%y.1, %3) # <string>:3:9
%x.9 : Tensor = aten::add(%x.5, %23, %k.1) # examples/custom_converters/toy_model.py:40:8
%x.13 : Tensor = aten::add(%x.9, %k.1, %k.1) # examples/custom_converters/toy_model.py:41:8
%x.17 : Tensor = aten::sub(%x.13, %k.1, %k.1) # examples/custom_converters/toy_model.py:42:8
%x.21 : Tensor = aten::add(%x.17, %k.1, %k.1) # examples/custom_converters/toy_model.py:43:8
%x.25 : Tensor = aten::sub(%x.21, %k.1, %k.1) # examples/custom_converters/toy_model.py:44:8
return (%x.25))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, g.get(), true);

torch_tensorrt::core::partitioning::PartitioningInfo partitioning_info;
partitioning_info.enabled = true;
partitioning_info.forced_fallback_operators = {"aten::expand"};
partitioning_info.truncate_long_and_double = true;
std::vector<torch_tensorrt::core::ir::Input> inputs;

inputs.push_back(torch_tensorrt::core::ir::Input({5, 5}));
inputs.push_back(torch_tensorrt::core::ir::Input({5, 5}));

std::unordered_map<const torch::jit::Value*, std::vector<torch_tensorrt::core::ir::Input>> inputs_map;
std::unordered_map<const torch::jit::Value*, std::vector<c10::optional<at::ScalarType>>> input_types;
inputs_map.insert({g->inputs()[0], {inputs[0]}});
input_types.insert({g->inputs()[0], {{at::kFloat}}});
inputs_map.insert({g->inputs()[1], {inputs[1]}});
input_types.insert({g->inputs()[1], {{at::kInt}}});

partitioning_info.collection_input_spec_map = inputs_map;
torch_tensorrt::core::partitioning::PartitioningCtx ctx(g->block(), partitioning_info);
ctx.input_types_map = input_types;

// Generate map of input Value * to dtype
torch_tensorrt::core::partitioning::populateInputIValues(&ctx);
torch_tensorrt::core::ir::TypeMap dtype_map;
dtype_map.insert({g->inputs()[0], c10::optional<c10::ScalarType>(at::kFloat)});
dtype_map.insert({g->inputs()[1], c10::optional<c10::ScalarType>(at::kInt)});

torch_tensorrt::core::lowering::AutocastInputs(g, dtype_map, "cuda");
torch_tensorrt::core::partitioning::partition(&ctx);
auto segmented_blocks = ctx.partitioned_blocks.begin()->second;

for (auto& seg_block : segmented_blocks) {
LOG_DEBUG(seg_block << " cur seg block");
}

// Ensure the first segmented block is a Torch block containing 2 casts
ASSERT_TRUE(segmented_blocks[0].target() == torch_tensorrt::core::partitioning::SegmentedBlock::kTorch);
ASSERT_TRUE(checkInsertedCastNodeNumber(segmented_blocks[0], 2));
}
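
checkInsertedCastNodeNumber is defined earlier in this test file, outside the lines
shown in the diff. A plausible reconstruction, assuming it simply counts aten::to
nodes in a segmented block (hypothetical; see the full test file for the actual
helper):

    // Hypothetical sketch of the helper asserted on above
    bool checkInsertedCastNodeNumber(
        torch_tensorrt::core::partitioning::SegmentedBlock& seg_block,
        int target_count) {
      int cast_count = 0;
      for (auto node : seg_block.nodes()) {
        if (node->kind() == torch::jit::aten::to) {
          cast_count++;
        }
      }
      return cast_count == target_count;
    }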
