Merge branch 'main' into broadcast
apbose committed May 1, 2023
2 parents 3cac6b7 + c5cc6e3 commit ed7f63d
Showing 215 changed files with 4,556 additions and 982 deletions.
244 changes: 211 additions & 33 deletions .circleci/config.yml

Large diffs are not rendered by default.

11 changes: 9 additions & 2 deletions README.md
@@ -73,6 +73,7 @@ import torch_tensorrt
...
trt_ts_module = torch_tensorrt.compile(torch_script_module,
# If the inputs to the module are plain Tensors, specify them via the `inputs` argument:
inputs = [example_tensor, # Provide example tensor for input shape or...
torch_tensorrt.Input( # Specify input object with shape and dtype
min_shape=[1, 3, 224, 224],
@@ -81,6 +82,12 @@ trt_ts_module = torch_tensorrt.compile(torch_script_module,
# For static size shape=[1, 3, 224, 224]
dtype=torch.half) # Datatype of input tensor. Allowed options torch.(float|half|int8|int32|bool)
],
# For inputs containing tuples or lists of tensors, use the `input_signature` argument:
# Below, we have an input consisting of a Tuple of two Tensors (Tuple[Tensor, Tensor])
# input_signature = ( (torch_tensorrt.Input(shape=[1, 3, 224, 224], dtype=torch.half),
# torch_tensorrt.Input(shape=[1, 3, 224, 224], dtype=torch.half)), ),
enabled_precisions = {torch.half}, # Run with FP16
)
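
For reference, a runnable sketch of the `input_signature` path is shown below; the module, shapes, and dtype are hypothetical, chosen only to mirror the commented example above:

```
# Hypothetical module whose forward takes a Tuple of two Tensors
from typing import Tuple

import torch
import torch_tensorrt

class TupleInput(torch.nn.Module):
    def forward(self, xs: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        a, b = xs
        return a + b

mod = torch.jit.script(TupleInput().eval().cuda())

trt_mod = torch_tensorrt.compile(
    mod,
    # Nest one Input spec per tensor, mirroring the module's Tuple input
    input_signature=(
        (
            torch_tensorrt.Input(shape=[1, 3, 224, 224], dtype=torch.half),
            torch_tensorrt.Input(shape=[1, 3, 224, 224], dtype=torch.half),
        ),
    ),
    enabled_precisions={torch.half},
)
```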
@@ -114,7 +121,7 @@ torch.jit.save(trt_ts_module, "trt_torchscript_module.ts") # save the TRT embedded Torchscript
The following dependencies are used to verify the testcases. Torch-TensorRT can work with other versions, but the tests are not guaranteed to pass.

- Bazel 5.2.0
-- Libtorch 2.0.0.dev20230103 (built with CUDA 11.7)
+- Libtorch 2.1.0.dev20230314 (built with CUDA 11.7)
- CUDA 11.7
- cuDNN 8.5.0
- TensorRT 8.5.1.7
@@ -124,7 +131,7 @@ These are the following dependencies used to verify the testcases. Torch-TensorRT
Releases: https://github.com/pytorch/TensorRT/releases

```
-pip install torch-tensorrt==1.2.0 --find-links https://github.com/pytorch/TensorRT/releases/expanded_assets/v1.2.0
+pip install torch-tensorrt
```

## Compiling Torch-TensorRT
8 changes: 4 additions & 4 deletions WORKSPACE
@@ -56,17 +56,17 @@ new_local_repository(
http_archive(
name = "libtorch",
build_file = "@//third_party/libtorch:BUILD",
sha256 = "59b8b5e1954a86d50b79c13f06398d385b200da13e37a08ecf31d3c62e5ca127",
sha256 = "7c4b8754830fef23ec19c5eaf414794cee9597b435df055f5c1d0471d3e81568",
strip_prefix = "libtorch",
urls = ["https://download.pytorch.org/libtorch/nightly/cu117/libtorch-cxx11-abi-shared-with-deps-2.0.0.dev20230103%2Bcu117.zip"],
urls = ["https://download.pytorch.org/libtorch/nightly/cu117/libtorch-cxx11-abi-shared-with-deps-2.1.0.dev20230314%2Bcu117.zip"],
)

http_archive(
name = "libtorch_pre_cxx11_abi",
build_file = "@//third_party/libtorch:BUILD",
sha256 = "e260fc7476be89d1650953e8643e9f7363845f5a52de4bab87ac0e619c1f6ad4",
sha256 = "f1e64a75dd12d0ba4c8c1f61947299e0a9c50684dff64f0cfbf355aa7a13e8cf",
strip_prefix = "libtorch",
urls = ["https://download.pytorch.org/libtorch/nightly/cu117/libtorch-shared-with-deps-2.0.0.dev20230103%2Bcu117.zip"],
urls = ["https://download.pytorch.org/libtorch/nightly/cu117/libtorch-shared-with-deps-2.1.0.dev20230314%2Bcu117.zip"],
)

# Download these tarballs manually from the NVIDIA website
84 changes: 68 additions & 16 deletions core/compiler.cpp
@@ -138,7 +138,8 @@ partitioning::GraphAndMapping BuildHybridGraph(
torch::jit::Block* block,
CompileSpec cfg,
ir::StaticParams static_params,
-    ir::CollectionTypeMap first_use_types) {
+    ir::CollectionTypeMap first_use_types,
+    bool expect_full_compilation = false) {
auto convert_info = cfg.convert_info;
auto partitioning_info = cfg.partitioning_info;

@@ -149,17 +150,20 @@ partitioning::GraphAndMapping BuildHybridGraph(
// TODO: Combine this within partition call
partitioning::populateInputIValues(&partitioning_ctx);

partitioning::partition(&partitioning_ctx);
partitioning::partition(&partitioning_ctx, expect_full_compilation);

for (auto& partitioned_block : partitioning_ctx.partitioned_blocks) {
partitioning::PartitionedGraph& segmented_blocks = partitioned_block.second;
int num_torch_segments = 0;
int num_trt_segments = 0;

for (auto& seg_block : segmented_blocks) {
LOG_INFO("Block segment:" << seg_block);
std::ostringstream trt_engine_id;
trt_engine_id << reinterpret_cast<const int*>(&seg_block);

if (seg_block.target() == partitioning::SegmentedBlock::kTensorRT) {
num_trt_segments++;
auto inputs = seg_block.construct_inputs_spec();
// update the input ranges for each segments
convert_info.inputs = ir::associate_specs_with_inputs(seg_block.g(), inputs, static_params);
@@ -180,8 +184,32 @@
true);

seg_block.update_graph(temp_g);
} else {
num_torch_segments++;

// If full compilation is expected, ensure that all operators in Torch blocks are
// for collections processing
if (expect_full_compilation) {
for (auto torch_node : seg_block.block()->nodes()) {
if (partitioning::CollectionNodeKinds.find(torch_node->kind()) == partitioning::CollectionNodeKinds.end()) {
TORCHTRT_THROW_ERROR(
"Full compilation specified but node "
<< *torch_node
<< " is set to run in PyTorch due to either lack of support in TensorRT or graph partitioning rules."
<< " Try recompiling with require_full_compilation=False.");
}
}
}
}
}

// If full compilation is expected, cannot have more than 2 Torch segments
// (one for preprocessing inputs, one for post-processing outputs) and 1 TRT segment
if (expect_full_compilation && !(num_torch_segments <= 2 && num_trt_segments == 1)) {
TORCHTRT_THROW_ERROR(
"Full compilation was requested but unable to convert all operations to TensorRT."
<< " Try recompiling with require_full_compilation=False.");
}
}

return partitioning::stitch(&partitioning_ctx, block);
@@ -191,7 +219,8 @@ ir::TypeMap MapInputsAndDetermineDTypes(
CompileSpec& cfg,
std::shared_ptr<torch::jit::Graph>& g,
ir::StaticParams& static_params,
ir::CollectionTypeMap& first_use_type_map) {
ir::CollectionTypeMap& first_use_type_map,
bool requires_collection_handling = false) {
cfg.convert_info.collection_input_spec_map =
std::move(ir::associate_specs_with_collection_inputs(g, cfg.graph_inputs, static_params));
cfg.partitioning_info.collection_input_spec_map =
@@ -226,7 +255,7 @@ ir::TypeMap MapInputsAndDetermineDTypes(
"Cannot infer input type from calcuations in graph for input "
<< in->debugName() << ". Assuming it is Float32. If not, specify input type explicity");
spec[i].dtype = at::kFloat;
} else if (spec[i].dtype_is_user_defined && cfg.partitioning_info.enabled) {
} else if (spec[i].dtype_is_user_defined && (cfg.partitioning_info.enabled || requires_collection_handling)) {
if (!est_type_opt[i]) {
LOG_INFO("Cannot infer input tensor dtype in graph, compiler is going to use the user setting");
std::stringstream ss;
@@ -297,6 +326,11 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::
return engine;
}

bool userRequestedFallback(CompileSpec& cfg) {
return cfg.lower_info.forced_fallback_modules.size() != 0 ||
cfg.partitioning_info.forced_fallback_operators.size() != 0;
}

torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg) {
torch::jit::Module new_mod(mod._ivalue()->name() + "_trt");

@@ -315,8 +349,18 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
// Infer the type of an input from the weights of the calculation
auto first_use_types = ir::get_block_first_calc_dtypes_opt_collection(g->block());

// Determine if the block is convertible/has collection output, and based on the result,
// whether full compilation can be expected
auto isBlockConvertible = conversion::VerifyConverterSupportForBlock(g->block(), true);
auto inputIsCollection = conversion::InputIsCollection(g->block());
auto outputIsCollection = conversion::OutputIsCollection(g->block());
auto requires_collection_handling = (isBlockConvertible && (inputIsCollection || outputIsCollection));

// Determine whether user specifications necessitate partitioning
auto isFallbackRequested = userRequestedFallback(cfg);

// Extract map of IValue to DType
-  auto type_map = MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);
+  auto type_map = MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types, requires_collection_handling);

// Check whether any of the input types are Long
bool user_requested_long = false;
@@ -330,20 +374,28 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
user_requested_long &= (casts_inserted > 0);
}

auto isBlockConvertible = conversion::VerifyConverterSupportForBlock(g->block(), true);
auto outputIsCollection = conversion::OutputIsCollection(g->block());
if (cfg.partitioning_info.enabled && !user_requested_long &&
(cfg.lower_info.forced_fallback_modules.size() == 0 &&
cfg.partitioning_info.forced_fallback_operators.size() == 0 && isBlockConvertible) &&
!outputIsCollection) {
// Partitioning is required if:
// 1. User requested some modules/operators fallback
// 2. The block (graph) cannot be converted due to operator coverage
// 3. The output of the graph is a collection
// 4. The user requested a non-TRT data type input
auto isPartitioningRequired =
(isFallbackRequested || !isBlockConvertible || outputIsCollection || user_requested_long);

// The user did not require full compilation, but the model can be fully compiled
if (cfg.partitioning_info.enabled && !isPartitioningRequired) {
LOG_INFO("Skipping partitioning since model is fully supported");
}

if (cfg.partitioning_info.enabled &&
(!(cfg.lower_info.forced_fallback_modules.size() == 0 &&
cfg.partitioning_info.forced_fallback_operators.size() == 0 && isBlockConvertible) ||
outputIsCollection || user_requested_long)) {
auto graph_and_mapping = BuildHybridGraph(new_mod, g->block(), cfg, static_params, first_use_types);
// The user did not require full compilation, and the model can be fully compiled
// or, the user required full compilation but the I/O of the graph use collections
if ((cfg.partitioning_info.enabled && isPartitioningRequired) || requires_collection_handling) {
// If the model is fully-compilable and the user has specified full compilation, run partitioning
// to generate collection-processing code in Torch
auto expect_full_compilation = (requires_collection_handling && !cfg.partitioning_info.enabled);

auto graph_and_mapping =
BuildHybridGraph(new_mod, g->block(), cfg, static_params, first_use_types, expect_full_compilation);
new_g = graph_and_mapping.first;
// renaming the input name of graph after fallback to ensure pytorch deserialize it correctly
for (size_t i = 0; i < new_g->inputs().size(); ++i) {
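Taken together, the `CompileGraph` changes mean a fully convertible graph whose inputs or outputs are collections can now honor `require_full_compilation=True`: partitioning runs only to generate the Torch code that packs and unpacks the collections. A hedged sketch from the Python side (module and shapes are hypothetical):

```
# Hypothetical module with a tuple output; per the check above, at most two
# Torch segments (input packing / output unpacking) and one TRT engine are allowed.
from typing import Tuple

import torch
import torch_tensorrt

class TwoOutputs(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        return x + 1.0, x * 2.0

mod = torch.jit.script(TwoOutputs().eval().cuda())

trt_mod = torch_tensorrt.compile(
    mod,
    inputs=[torch_tensorrt.Input(shape=[8, 16])],
    require_full_compilation=True,  # no longer rejected for collection outputs
)
```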
14 changes: 12 additions & 2 deletions core/conversion/conversion.cpp
@@ -68,7 +68,7 @@ c10::optional<torch::jit::IValue> EvaluateNode(ConversionCtx* ctx, const torch::
return {};
}
}
-  auto eval = evaluators::EvalNode(n, eval_args);
+  auto eval = evaluators::EvalNode(ctx, n, eval_args);
return eval;
}

@@ -556,10 +556,20 @@ std::set<std::string> ConvertableOpsInBlock(const torch::jit::Block* b) {
return convertable_ops;
}

bool InputIsCollection(const torch::jit::Block* b) {
for (auto in : b->inputs()) {
if (in->type()->kind() == torch::jit::TypeKind::TupleType || in->type()->kind() == torch::jit::TypeKind::ListType) {
return true;
}
}
return false;
}

bool OutputIsCollection(const torch::jit::Block* b) {
for (auto out : b->outputs()) {
if (out->type()->kind() == torch::jit::TypeKind::TupleType ||
out->type()->kind() == torch::jit::TypeKind::ListType) {
out->type()->kind() == torch::jit::TypeKind::ListType ||
out->type()->kind() == torch::jit::TypeKind::DictType) {
return true;
}
}
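In TorchScript terms, the new `InputIsCollection` check fires on `TupleType`/`ListType` inputs, and `OutputIsCollection` now also fires on `DictType` outputs. Hypothetical modules that would trip each check:

```
# Illustrative modules only; the checks operate on the TorchScript graph's I/O types.
from typing import Dict, Tuple

import torch

class TupleIn(torch.nn.Module):  # InputIsCollection: TupleType input
    def forward(self, xs: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        return xs[0] + xs[1]

class DictOut(torch.nn.Module):  # OutputIsCollection: DictType output (newly counted)
    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        return {"doubled": x * 2}
```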
2 changes: 2 additions & 0 deletions core/conversion/conversion.h
@@ -26,6 +26,8 @@ std::string ConvertBlockToEngine(

bool OpSupported(const torch::jit::Node* n);

bool InputIsCollection(const torch::jit::Block* b);

bool OutputIsCollection(const torch::jit::Block* b);

bool VerifyConverterSupportForBlock(const torch::jit::Block* b, bool suppress_errors = false);
23 changes: 18 additions & 5 deletions core/conversion/converters/impl/batch_norm.cpp
@@ -20,15 +20,28 @@ void _batch_norm(
const torch::Tensor& mean,
const torch::Tensor& var,
const float eps) {
-  auto scale = gamma / torch::sqrt(var + eps);
-  auto bias = beta - mean * scale;
+  auto orig_dtype = var.dtype();
+  // perform compile-time weight calculations in float to improve accuracy
+  // resulting weights will be embedded as the original dtype
+  auto calculation_gamma = gamma;
+  auto calculation_beta = beta;
+  auto calculation_mean = mean;
+  auto calculation_var = var;
+  if (orig_dtype == torch::kHalf) {
+    calculation_gamma = calculation_gamma.to(torch::kFloat);
+    calculation_beta = calculation_beta.to(torch::kFloat);
+    calculation_mean = calculation_mean.to(torch::kFloat);
+    calculation_var = calculation_var.to(torch::kFloat);
+  }
+  auto scale = calculation_gamma / torch::sqrt(calculation_var + eps);
+  auto bias = calculation_beta - calculation_mean * scale;
LOG_DEBUG("_batch_norm Tensor Scale : " << scale.sizes());
LOG_DEBUG("_batch_norm Tensor bias : " << bias.sizes());

-  auto scale_weights = Weights(ctx, scale);
-  auto bias_weights = Weights(ctx, bias);
+  auto scale_weights = Weights(ctx, scale.to(orig_dtype));
+  auto bias_weights = Weights(ctx, bias.to(orig_dtype));

-  auto power = Weights(ctx, at::ones_like(scale));
+  auto power = Weights(ctx, at::ones_like(scale).to(orig_dtype));
auto bn =
ctx->net->addScaleNd(*input, nvinfer1::ScaleMode::kCHANNEL, bias_weights.data, scale_weights.data, power.data, 1);
bn->setName(util::node_info(n).c_str());
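A small numeric sketch of the precision issue this converter change addresses (values illustrative; exact errors depend on hardware rounding):

```
# Fusing batch-norm scale/bias entirely in half rounds at every step; doing the
# arithmetic in float and casting the result back rounds only once at the end.
import torch

gamma = torch.ones(4, dtype=torch.half)
var = torch.full((4,), 1e-3, dtype=torch.half)
eps = 1e-5

scale_half = gamma / torch.sqrt(var + eps)          # all-half arithmetic
scale_ref = gamma.float() / torch.sqrt(var.float() + eps)
scale_cast = scale_ref.half()                       # embedded as the original dtype

print((scale_half.float() - scale_ref).abs().max())  # error of the half-only path
print((scale_cast.float() - scale_ref).abs().max())  # error of the final cast alone
```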
23 changes: 23 additions & 0 deletions core/conversion/converters/impl/element_wise.cpp
@@ -810,6 +810,29 @@ auto element_wise_registrations TORCHTRT_UNUSED =
return true;
}})
.pattern(
{"aten::logical_and(Tensor self, Tensor other) -> (Tensor)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
// torch.logical_and autocasts inputs to bool
auto input_as_bool = [&](int idx) {
auto x = args[idx].ITensorOrFreeze(ctx);
if (x->getType() != nvinfer1::DataType::kBOOL) {
x = castITensor(
ctx, x, nvinfer1::DataType::kBOOL, (util::node_info(n) + "_bool_" + str(idx)).c_str());
}
return x;
};
auto self = input_as_bool(0);
auto other = input_as_bool(1);

auto and_layer =
add_elementwise(ctx, nvinfer1::ElementWiseOperation::kAND, self, other, util::node_info(n) + "_and");
TORCHTRT_CHECK(and_layer, "Unable to create and layer from node: " << *n);
auto out = ctx->AssociateValueAndTensor(n->outputs()[0], and_layer->getOutput(0));

LOG_DEBUG("Output tensor shape: " << out->getDimensions());
return true;
}})
.pattern(
{"aten::atan2(Tensor self, Tensor other) -> (Tensor)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
// Element-wise divide input Tensors, apply atan unary, apply quadrant correction
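The autocast noted in the `aten::logical_and` converter matches eager PyTorch, where `torch.logical_and` treats any non-zero element as true regardless of dtype; a quick illustration:

```
import torch

a = torch.tensor([0, 1, 2])  # int64 inputs are autocast to bool
b = torch.tensor([3, 0, 5])
print(torch.logical_and(a, b))  # tensor([False, False,  True])
```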