From 70600a20241e749004df653762fe618ef3aace7d Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Mon, 3 Jun 2024 09:30:26 -0700 Subject: [PATCH 1/9] feat: Cudagraphs integration for Torch-TRT - Add option to enable Cudagraphs in the runtime for additional acceleration of TRT engines - Add C++ and Python toggles as well as full integration for C++ and Python runtimes - Add support for dynamic shape cases via shape keys with cache invalidation - Add test cases for cudagraphs support --- core/runtime/TRTEngine.cpp | 2 + core/runtime/TRTEngine.h | 8 + core/runtime/execute_engine.cpp | 278 +++++++++++------ core/runtime/register_jit_hooks.cpp | 2 + core/runtime/runtime.cpp | 9 + core/runtime/runtime.h | 6 + docsrc/index.rst | 1 + docsrc/user_guide/runtime.rst | 24 ++ examples/dynamo/README.rst | 1 + examples/dynamo/torch_export_cudagraphs.py | 69 ++++ .../runtime/_PythonTorchTensorRTModule.py | 294 ++++++++++++------ py/torch_tensorrt/runtime/__init__.py | 1 + py/torch_tensorrt/runtime/cudagraphs.py | 55 ++++ tests/py/dynamo/runtime/test_cudagraphs.py | 112 +++++++ 14 files changed, 662 insertions(+), 200 deletions(-) create mode 100644 examples/dynamo/torch_export_cudagraphs.py create mode 100644 py/torch_tensorrt/runtime/cudagraphs.py create mode 100644 tests/py/dynamo/runtime/test_cudagraphs.py diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp index 6e6080a353..8f63563c58 100644 --- a/core/runtime/TRTEngine.cpp +++ b/core/runtime/TRTEngine.cpp @@ -124,6 +124,7 @@ TRTEngine::TRTEngine( } else { uint64_t inputs_size = _in_binding_names.size(); in_binding_names.resize(inputs_size); + input_buffers.resize(inputs_size); for (uint64_t pyt_idx = 0; pyt_idx < inputs_size; pyt_idx++) { auto binding_name = _in_binding_names[pyt_idx]; // Check if the binding name provided is in the list of engine's bindings @@ -153,6 +154,7 @@ TRTEngine::TRTEngine( uint64_t outputs = _out_binding_names.size(); 
out_binding_names.resize(outputs); + output_buffers.resize(outputs); for (size_t pyt_idx = 0; pyt_idx < outputs; pyt_idx++) { auto binding_name = _out_binding_names[pyt_idx]; // Check if the binding name provided is in the list of engine's bindings diff --git a/core/runtime/TRTEngine.h b/core/runtime/TRTEngine.h index af6bdcec6f..0e76b63179 100644 --- a/core/runtime/TRTEngine.h +++ b/core/runtime/TRTEngine.h @@ -7,6 +7,7 @@ #include #include "ATen/core/function_schema.h" +#include "ATen/cuda/CUDAGraph.h" #include "NvInfer.h" #include "torch/custom_class.h" @@ -65,6 +66,13 @@ struct TRTEngine : torch::CustomClassHolder { void dump_engine_layer_info(); friend std::ostream& operator<<(std::ostream& os, const TRTEngine& engine); static const char BINDING_DELIM = '%'; + + // CUDAGraph-Related Functionality + at::cuda::CUDAGraph cudagraph = {}; + std::vector input_buffers = {}; + std::vector output_buffers = {}; + std::string shape_key; + // TODO: Implement a call method // c10::List Run(c10::List inputs); diff --git a/core/runtime/execute_engine.cpp b/core/runtime/execute_engine.cpp index 9e240fa60b..d5bb1388fa 100644 --- a/core/runtime/execute_engine.cpp +++ b/core/runtime/execute_engine.cpp @@ -1,3 +1,4 @@ +#include "c10/cuda/CUDAGuard.h" #include "c10/cuda/CUDAStream.h" #include "torch/csrc/jit/runtime/custom_operator.h" @@ -58,6 +59,29 @@ RTDevice select_rt_device(const RTDevice& engine_device, const RTDevice& curr_de return new_target_device_opt.value(); } +bool _cudagraphs_validate_shapes(std::vector inputs, c10::intrusive_ptr compiled_engine) { + // Validate whether the current input shapes to the engine + // invalidate the existing cudagraphs object + + // Populate the shape key for the inputs + std::stringstream new_shape_key_ss; + for (auto input : inputs) { + new_shape_key_ss << input.sizes(); + } + + auto new_shape_key = new_shape_key_ss.str(); + + // Compare the shape key to the original key and invalidate shapes if they do not match + if (new_shape_key != 
compiled_engine->shape_key) { + LOG_DEBUG("Resetting Cudagraph on New Shape Key " << new_shape_key); + compiled_engine->shape_key = new_shape_key; + compiled_engine->cudagraph.reset(); + return false; + } + + return true; +} + std::vector execute_engine(std::vector inputs, c10::intrusive_ptr compiled_engine) { LOG_DEBUG( "Attempting to run engine (ID: " << compiled_engine->name @@ -76,107 +100,124 @@ std::vector execute_engine(std::vector inputs, c10::intr LOG_INFO("" << log_info); } - if (MULTI_DEVICE_SAFE_MODE) { - std::unique_ptr device_profiler_guard; - if (compiled_engine->profile_execution) { - device_profiler_guard = - std::make_unique(compiled_engine->device_profile_path); - } + // Whether cudagraphs needs to record the graph on this pass + bool need_cudagraphs_record = (CUDAGRAPHS_MODE && !_cudagraphs_validate_shapes(inputs, compiled_engine)); + + // Intialize outputs to be available throughout the succeeding scopes + std::vector outputs(compiled_engine->num_io.second); + + // If not in cudagraphs mode or a new cudagraphs recording is needed + // proceed with input validation and assignment of new I/O pointers for TRT + if (!CUDAGRAPHS_MODE || need_cudagraphs_record) { + if (MULTI_DEVICE_SAFE_MODE) { + std::unique_ptr device_profiler_guard; + if (compiled_engine->profile_execution) { + device_profiler_guard = + std::make_unique(compiled_engine->device_profile_path); + } - RTDevice curr_device = get_current_device(); - LOG_DEBUG("Current Device: " << curr_device); + RTDevice curr_device = get_current_device(); + LOG_DEBUG("Current Device: " << curr_device); - // Generic Target Device Prefix - std::string target_device = "cuda:"; + // Generic Target Device Prefix + std::string target_device = "cuda:"; - if (is_switch_required(curr_device, compiled_engine->device_info)) { - // Scan through available CUDA devices and set the CUDA device context correctly - RTDevice device = - select_rt_device(compiled_engine->device_info, curr_device, 
compiled_engine->hardware_compatible); - set_rt_device(device); + if (is_switch_required(curr_device, compiled_engine->device_info)) { + // Scan through available CUDA devices and set the CUDA device context correctly + RTDevice device = + select_rt_device(compiled_engine->device_info, curr_device, compiled_engine->hardware_compatible); + set_rt_device(device); - // Target device is new device - target_device += std::to_string(device.id); + // Target device is new device + target_device += std::to_string(device.id); - for (auto& in : inputs) { - in = in.to(torch::Device(target_device)); + for (auto& in : inputs) { + in = in.to(torch::Device(target_device)); + } + } else { + // Target device is current device + target_device += std::to_string(curr_device.id); } - } else { - // Target device is current device - target_device += std::to_string(curr_device.id); - } - // For each input, ensure its current device is the desired target device - for (size_t i = 0; i < inputs.size(); i++) { - at::Tensor* in = &inputs[i]; - std::string current_tensor_device = in->device().str(); - - // If current device string does not match target device, display warning and move tensor accordingly - if (current_tensor_device != target_device) { - LOG_WARNING( - "Input " << i << " of engine " << compiled_engine->name << " was found to be on " << current_tensor_device - << " but should be on " << target_device << ". 
This tensor is being moved by the runtime but " - << "for performance considerations, ensure your inputs are all on GPU " - << "and open an issue here (https://github.com/pytorch/TensorRT/issues) if this " - << "warning persists."); - *in = in->to(torch::Device(target_device)); + // For each input, ensure its current device is the desired target device + for (size_t i = 0; i < inputs.size(); i++) { + at::Tensor* in = &inputs[i]; + std::string current_tensor_device = in->device().str(); + + // If current device string does not match target device, display warning and move tensor accordingly + if (current_tensor_device != target_device) { + LOG_WARNING( + "Input " << i << " of engine " << compiled_engine->name << " was found to be on " << current_tensor_device + << " but should be on " << target_device << ". This tensor is being moved by the runtime but " + << "for performance considerations, ensure your inputs are all on GPU " + << "and open an issue here (https://github.com/pytorch/TensorRT/issues) if this " + << "warning persists."); + *in = in->to(torch::Device(target_device)); + } } } - } - // this is a buffer to store shape tensor input addresses throughout the runtime scope - std::list> inputShapeTensorValues; - { - std::unique_ptr input_profiler_guard; - if (compiled_engine->profile_execution) { - input_profiler_guard = - std::make_unique(compiled_engine->input_profile_path); - } - for (size_t i = 0; i < inputs.size(); i++) { - std::string name = compiled_engine->in_binding_names[i]; - TORCHTRT_CHECK( - inputs[i].is_cuda(), "Expected input tensors to have device cuda, found device " << inputs[i].device()); - auto expected_type = - util::TRTDataTypeToScalarType(compiled_engine->exec_ctx->getEngine().getTensorDataType(name.c_str())); - TORCHTRT_CHECK( - inputs[i].dtype() == expected_type, - "Expected input tensors to have type " << expected_type << ", found type " << inputs[i].dtype()); - auto dims = core::util::toDims(inputs[i].sizes()); - auto shape = 
core::util::toVec(dims); - LOG_DEBUG("Input Name: " << name << " Shape: " << dims); - if (compiled_engine->cuda_engine->isShapeInferenceIO(name.c_str())) { - // Shape tensor inputs are casted to int32 explicitly. - // Refer to - // https://github.com/NVIDIA/TensorRT/blob/d2f4ef789a9a6ffdf37b55c3f81b486225f6b380/samples/common/sampleInference.cpp#L435 - auto input_cpu = inputs[i].clone().contiguous().cpu().to(torch::kInt32); - std::vector inputs_cpu_vec( - input_cpu.data_ptr(), input_cpu.data_ptr() + input_cpu.numel()); - inputShapeTensorValues.emplace_back(inputs_cpu_vec); - TORCHTRT_CHECK( - compiled_engine->exec_ctx->setTensorAddress(name.c_str(), inputShapeTensorValues.back().data()), - "Error while setting the tensor address for shape inputs"); - } else { + { + std::unique_ptr input_profiler_guard; + if (compiled_engine->profile_execution) { + input_profiler_guard = + std::make_unique(compiled_engine->input_profile_path); + } + for (size_t i = 0; i < inputs.size(); i++) { + std::string name = compiled_engine->in_binding_names[i]; TORCHTRT_CHECK( - compiled_engine->exec_ctx->setInputShape(name.c_str(), dims), "Error while setting the input shape"); + inputs[i].is_cuda(), "Expected input tensors to have device cuda, found device " << inputs[i].device()); + auto expected_type = + util::TRTDataTypeToScalarType(compiled_engine->exec_ctx->getEngine().getTensorDataType(name.c_str())); TORCHTRT_CHECK( - compiled_engine->exec_ctx->setTensorAddress(name.c_str(), inputs[i].view(shape).contiguous().data_ptr()), - "Error while setting the input tensor address for inputs"); + inputs[i].dtype() == expected_type, + "Expected input tensors to have type " << expected_type << ", found type " << inputs[i].dtype()); + auto dims = core::util::toDims(inputs[i].sizes()); + auto shape = core::util::toVec(dims); + LOG_DEBUG("Input Name: " << name << " Shape: " << dims); + at::Tensor contig_input; + + if (compiled_engine->cuda_engine->isShapeInferenceIO(name.c_str())) { + // Shape tensor 
inputs are casted to int32 explicitly. + // Refer to + // https://github.com/NVIDIA/TensorRT/blob/d2f4ef789a9a6ffdf37b55c3f81b486225f6b380/samples/common/sampleInference.cpp#L435 + auto input_cpu = inputs[i].clone().contiguous().cpu().to(torch::kInt32); + std::vector inputs_cpu_vec( + input_cpu.data_ptr(), input_cpu.data_ptr() + input_cpu.numel()); + inputShapeTensorValues.emplace_back(inputs_cpu_vec); + TORCHTRT_CHECK( + compiled_engine->exec_ctx->setTensorAddress(name.c_str(), inputShapeTensorValues.back().data()), + "Error while setting the tensor address for shape inputs"); + compiled_engine->input_buffers[i] = input_cpu; + } else { + // If in cudagraphs mode, the inputs must be cloned since the memory will be reused + // in subsequent replays of the graph + if (CUDAGRAPHS_MODE) { + contig_input = inputs[i].view(shape).contiguous().clone(); + compiled_engine->input_buffers[i] = contig_input; + } else { + contig_input = inputs[i].view(shape).contiguous(); + } + TORCHTRT_CHECK( + compiled_engine->exec_ctx->setInputShape(name.c_str(), dims), "Error while setting the input shape"); + TORCHTRT_CHECK( + compiled_engine->exec_ctx->setTensorAddress(name.c_str(), contig_input.data_ptr()), + "Error while setting the input tensor address for inputs"); + compiled_engine->input_buffers[i] = contig_input; + } } - } - // Check if input shapes can be inferred. - int32_t const io_size{compiled_engine->cuda_engine->getNbIOTensors()}; - std::vector names(io_size); - int32_t const nbNames = compiled_engine->exec_ctx->inferShapes(names.size(), names.data()); - TORCHTRT_CHECK( - nbNames == 0, - "The shapes of the inputs: " - << names - << " cannot be inferred. This could happen if the input tensor addresses/shapes haven't been configured correctly"); - } + // Check if input shapes can be inferred. 
+ int32_t const io_size{compiled_engine->cuda_engine->getNbIOTensors()}; + std::vector names(io_size); + int32_t const nbNames = compiled_engine->exec_ctx->inferShapes(names.size(), names.data()); + TORCHTRT_CHECK( + nbNames == 0, + "The shapes of the inputs: " + << names + << " cannot be inferred. This could happen if the input tensor addresses/shapes haven't been configured correctly"); + } - std::vector outputs(compiled_engine->num_io.second); - { std::unique_ptr output_profiler_guard; if (compiled_engine->profile_execution) { output_profiler_guard = @@ -193,29 +234,68 @@ std::vector execute_engine(std::vector inputs, c10::intr auto dims = core::util::toVec(out_shape); auto type = util::TRTDataTypeToScalarType(compiled_engine->exec_ctx->getEngine().getTensorDataType(name.c_str())); outputs[pyt_idx] = std::move(at::empty(dims, {at::kCUDA}).to(type).contiguous()); + + // In cudagraphs mode, the allocated output buffers are stored for reuse + if (CUDAGRAPHS_MODE) { + compiled_engine->output_buffers[pyt_idx] = outputs[pyt_idx]; + } TORCHTRT_CHECK( compiled_engine->exec_ctx->setTensorAddress(name.c_str(), outputs[pyt_idx].data_ptr()), "Error while setting the output tensor address"); } } - { - std::unique_ptr enqueue_profiler_guard; - if (compiled_engine->profile_execution) { - enqueue_profiler_guard = - std::make_unique(compiled_engine->enqueue_profile_path); - } + std::unique_ptr enqueue_profiler_guard; + + // nvinfer1::IExecutionContext::enqueue is not thread safe and we need a mutex for it. + std::unique_lock lock(compiled_engine->mu); + + if (!CUDAGRAPHS_MODE) { + // If not in cudagraphs mode, proceed with enqueueV3 as normal c10::cuda::CUDAStream stream = c10::cuda::getCurrentCUDAStream(inputs[0].device().index()); - // nvinfer1::IExecutionContext::enqueue is not thread safe and we need a mutex for it. 
- std::unique_lock lock(compiled_engine->mu); compiled_engine->exec_ctx->enqueueV3(stream); - if (compiled_engine->profile_execution) { - LOG_INFO(std::endl << *compiled_engine->trt_engine_profiler); - dump_trace(compiled_engine->trt_engine_profile_path, *compiled_engine->trt_engine_profiler); - compiled_engine->dump_engine_layer_info(); + } else if (need_cudagraphs_record) { + // If cudagraphs needs to record a graph, capture the enqueueV3 call in a graph + + // Cudagraphs cannot record on the default stream, so use an alternate + c10::cuda::CUDAStream stream = c10::cuda::getStreamFromPool(true, inputs[0].device().index()); + c10::cuda::CUDAStreamGuard guard(stream); + compiled_engine->exec_ctx->enqueueV3(stream); + + compiled_engine->cudagraph.capture_begin(); + compiled_engine->exec_ctx->enqueueV3(stream); + compiled_engine->cudagraph.capture_end(); + + // Reset the stream to its original setting + guard.reset_stream(guard.original_stream()); + + } else { + // If the cudagraph has already been recorded, copy the input buffers and replay it + for (auto i = 0; i < inputs.size(); i++) { + compiled_engine->input_buffers[i].copy_(inputs[i], true); } + compiled_engine->cudagraph.replay(); } - return outputs; + + std::vector model_outputs(compiled_engine->num_io.second); + + // In cudagraphs mode, the output buffers can be reused, so they must + // be cloned before providing them to the user to avoid data corruption + if (CUDAGRAPHS_MODE) { + for (auto i = 0; i < compiled_engine->output_buffers.size(); i++) { + model_outputs[i] = compiled_engine->output_buffers[i].clone(); + } + } else { + model_outputs = outputs; + } + + if (compiled_engine->profile_execution) { + LOG_INFO(std::endl << *compiled_engine->trt_engine_profiler); + dump_trace(compiled_engine->trt_engine_profile_path, *compiled_engine->trt_engine_profiler); + compiled_engine->dump_engine_layer_info(); + } + + return model_outputs; } } // namespace runtime diff --git a/core/runtime/register_jit_hooks.cpp 
b/core/runtime/register_jit_hooks.cpp index 9ac5af5d05..483f7f3a90 100644 --- a/core/runtime/register_jit_hooks.cpp +++ b/core/runtime/register_jit_hooks.cpp @@ -122,6 +122,8 @@ TORCH_LIBRARY(tensorrt, m) { m.def("set_multi_device_safe_mode", [](bool multi_device_safe_mode) -> void { MULTI_DEVICE_SAFE_MODE = multi_device_safe_mode; }); + m.def("get_cudagraphs_mode", []() -> bool { return CUDAGRAPHS_MODE; }); + m.def("set_cudagraphs_mode", [](bool cudagraphs_mode) -> void { CUDAGRAPHS_MODE = cudagraphs_mode; }); m.def("set_logging_level", [](int64_t level) -> void { util::logging::get_logger().set_reportable_log_level(util::logging::LogLevel(level)); }); diff --git a/core/runtime/runtime.cpp b/core/runtime/runtime.cpp index 24eba16cc3..b933e081c7 100644 --- a/core/runtime/runtime.cpp +++ b/core/runtime/runtime.cpp @@ -8,6 +8,7 @@ namespace core { namespace runtime { bool MULTI_DEVICE_SAFE_MODE = false; +bool CUDAGRAPHS_MODE = false; c10::optional get_most_compatible_device( const RTDevice& target_device, @@ -129,6 +130,14 @@ void set_multi_device_safe_mode(bool multi_device_safe_mode) { MULTI_DEVICE_SAFE_MODE = multi_device_safe_mode; } +bool get_cudagraphs_mode() { + return CUDAGRAPHS_MODE; +} + +void set_cudagraphs_mode(bool cudagraphs_mode) { + CUDAGRAPHS_MODE = cudagraphs_mode; +} + namespace { static DeviceList cuda_device_list; } diff --git a/core/runtime/runtime.h b/core/runtime/runtime.h index e48357503d..3e21b249a8 100644 --- a/core/runtime/runtime.h +++ b/core/runtime/runtime.h @@ -17,6 +17,8 @@ namespace runtime { using EngineID = int64_t; const std::string ABI_VERSION = "5"; extern bool MULTI_DEVICE_SAFE_MODE; +extern bool CUDAGRAPHS_MODE; + typedef enum { ABI_TARGET_IDX = 0, NAME_IDX, @@ -43,6 +45,10 @@ bool get_multi_device_safe_mode(); void set_multi_device_safe_mode(bool multi_device_safe_mode); +bool get_cudagraphs_mode(); + +void set_cudagraphs_mode(bool multi_device_safe_mode); + class DeviceList { using DeviceMap = std::unordered_map; DeviceMap 
device_list; diff --git a/docsrc/index.rst b/docsrc/index.rst index 0eadef1d14..0e13ef8abb 100644 --- a/docsrc/index.rst +++ b/docsrc/index.rst @@ -113,6 +113,7 @@ Tutorials tutorials/_rendered_examples/dynamo/torch_compile_transformers_example tutorials/_rendered_examples/dynamo/torch_compile_advanced_usage tutorials/_rendered_examples/dynamo/torch_compile_stable_diffusion + tutorials/_rendered_examples/dynamo/torch_export_cudagraphs tutorials/_rendered_examples/dynamo/custom_kernel_plugins tutorials/_rendered_examples/distributed_inference/data_parallel_gpt2 tutorials/_rendered_examples/distributed_inference/data_parallel_stable_diffusion diff --git a/docsrc/user_guide/runtime.rst b/docsrc/user_guide/runtime.rst index 4b9f3688a3..c897ea1f78 100644 --- a/docsrc/user_guide/runtime.rst +++ b/docsrc/user_guide/runtime.rst @@ -68,3 +68,27 @@ multi-device safe mode is to use Python threads. Each thread is responsible for on a single GPU, and the default CUDA device on each thread corresponds to the GPU for which it is responsible (can be set via ``torch.cuda.set_device(...)``). In this way, multiple threads can be used in the same Python script without needing to switch CUDA contexts and incur performance overhead. + +Cudagraphs Mode +--------------- + +Cudagraphs mode is a setting in Torch-TensorRT which allows the user to determine whether +the runtime uses cudagraphs to accelerate inference in certain cases. + +Cudagraphs can accelerate certain models by reducing kernel overheads, as documented further `here <https://pytorch.org/blog/accelerating-pytorch-with-cuda-graphs/>`_. + +.. code-block:: python + + # Enables Cudagraphs Mode + torch_tensorrt.runtime.set_cudagraphs_mode(True) + + # Disables Cudagraphs Mode [Default Behavior] + torch_tensorrt.runtime.set_cudagraphs_mode(False) + + # Enables Cudagraphs Mode, then resets the mode to its prior setting + with torch_tensorrt.runtime.enable_cudagraphs(): + ... 
+ +In the current implementation, use of a new input shape (for instance in dynamic shape +cases) will cause the cudagraph to be re-recorded. Cudagraph recording is generally +not latency intensive, and future improvements include caching cudagraphs for multiple input shapes. diff --git a/examples/dynamo/README.rst b/examples/dynamo/README.rst index 89c997abdb..dac0cf1161 100644 --- a/examples/dynamo/README.rst +++ b/examples/dynamo/README.rst @@ -10,6 +10,7 @@ a number of ways you can leverage this backend to accelerate inference. * :ref:`torch_compile_transformer`: Compiling a Transformer model using ``torch.compile`` * :ref:`torch_compile_advanced_usage`: Advanced usage including making a custom backend to use directly with the ``torch.compile`` API * :ref:`torch_compile_stable_diffusion`: Compiling a Stable Diffusion model using ``torch.compile`` +* :ref:`torch_export_cudagraphs`: Using the Cudagraphs integration with ``ir="dynamo"`` * :ref:`custom_kernel_plugins`: Creating a plugin to use a custom kernel inside TensorRT engines * :ref:`refit_engine_example`: Refitting a compiled TensorRT Graph Module with updated weights * :ref:`vgg16_fp8_ptq`: Compiling a VGG16 model with FP8 and PTQ using ``torch.compile`` diff --git a/examples/dynamo/torch_export_cudagraphs.py b/examples/dynamo/torch_export_cudagraphs.py new file mode 100644 index 0000000000..db7041b94d --- /dev/null +++ b/examples/dynamo/torch_export_cudagraphs.py @@ -0,0 +1,69 @@ +""" +.. _torch_export_cudagraphs: + +Torch Export with Cudagraphs +====================================================== + +This interactive script is intended as an overview of the process by which the Torch-TensorRT Cudagraphs integration can be used in the `ir="dynamo"` path. 
The functionality works similarly in the `torch.compile` path as well.""" + +# %% +# Imports and Model Definition +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +import torch +import torchvision.models as models + +import torch_tensorrt + +# %% +# Compilation with `torch_tensorrt.compile` Using Default Settings +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +# We begin by defining and initializing a model +model = models.resnet18(pretrained=True).eval().to("cuda") + +# Define sample inputs +inputs = torch.randn((16, 3, 224, 224)).cuda() + +# %% + +# Next, we compile the model using torch_tensorrt.compile +# We use the `ir="dynamo"` flag here, and `ir="torch_compile"` should +# work with cudagraphs as well. +opt = torch_tensorrt.compile( + model, + ir="dynamo", + inputs=torch_tensorrt.Input( + min_shape=(1, 3, 224, 224), + opt_shape=(8, 3, 224, 224), + max_shape=(16, 3, 224, 224), + dtype=torch.float, + name="x", + ), +) + +# %% +# Inference using the Cudagraphs Integration +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +# We can enable the cudagraphs API with a context manager +with torch_tensorrt.runtime.enable_cudagraphs(): + out_trt = opt(inputs) + +# Alternatively, we can set the cudagraphs mode for the session +torch_tensorrt.runtime.set_cudagraphs_mode(True) +out_trt = opt(inputs) + +# We can also turn off cudagraphs mode and perform inference as normal +torch_tensorrt.runtime.set_cudagraphs_mode(False) +out_trt = opt(inputs) + +# %% + +# If we provide new input shapes, cudagraphs will re-record the graph +inputs_2 = torch.randn((8, 3, 224, 224)).cuda() +inputs_3 = torch.randn((4, 3, 224, 224)).cuda() + +with torch_tensorrt.runtime.enable_cudagraphs(): + out_trt_2 = opt(inputs_2) + out_trt_3 = opt(inputs_3) diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index 78395b8943..f570d44a00 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ 
b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -2,10 +2,10 @@ import logging from contextlib import nullcontext -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Sequence, Tuple +import tensorrt as trt import torch -import torch_tensorrt from torch.nn import Module from torch_tensorrt._Device import Device from torch_tensorrt._enums import dtype @@ -18,7 +18,7 @@ from torch_tensorrt.dynamo.utils import DYNAMIC_DIM from torch_tensorrt.logging import TRT_LOGGER -import tensorrt as trt +import torch_tensorrt logger = logging.getLogger(__name__) @@ -43,6 +43,17 @@ def __init__( # Run multi-gpu device check to validate engine instantiation multi_gpu_device_check() + self.input_buffers: List[torch.Tensor] = [] + self.output_buffers: List[torch.Tensor] = [] + self.cudagraph: Optional[torch.cuda.CUDAGraph] = None + # {shape: cudagraph} + # limitation on CG + self.shape_key: Optional[str] = None + + # See https://github.com/pytorch/pytorch/blob/acfe237a71af609e837a34bb38048aa8acb8eb4d/torch/cuda/graphs.py#L92-L98 + # Unused currently - to be used by Dynamic Shape support implementation + self.memory_pool = None + self.engine = engine self.input_names = input_names if input_names is not None else [] self.output_names = output_names if output_names is not None else [] @@ -85,6 +96,10 @@ def _initialize(self) -> None: for output_name in self.output_names ] + if torch_tensorrt.runtime.get_cudagraphs_mode(): + self.cudagraph = torch.cuda.CUDAGraph() + self.graph_capturer = torch.cuda.graphs.graph(self.cudagraph) + def _check_initialized(self) -> None: if not self.initialized: raise RuntimeError("PythonTorchTensorRTModule is not initialized.") @@ -151,116 +166,146 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, . 
): self._check_initialized() - # If in safe mode, check at each iteration for for whether a switch is required - if ( - torch_tensorrt.runtime.multi_device_safe_mode._PY_RT_MULTI_DEVICE_SAFE_MODE - ): - curr_device_id = torch.cuda.current_device() - curr_device_properties = torch.cuda.get_device_properties( - curr_device_id - ) - logger.debug(f"Current Device: cuda:{curr_device_id}") - - # If a switch is required, move all inputs to new device and set as active device - if _is_switch_required( - curr_device_id, - self.target_device_id, - curr_device_properties, - self.target_device_properties, + cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode() + need_cudagraphs_record = ( + cudagraphs_enabled and not self.cudagraphs_validate_shapes(inputs) + ) + + # If cudagraphs is not enabled or the recorded graph shapes are either uninitialized or invalid + if not cudagraphs_enabled or need_cudagraphs_record: + # If in safe mode, check at each iteration for for whether a switch is required + if ( + torch_tensorrt.runtime.multi_device_safe_mode._PY_RT_MULTI_DEVICE_SAFE_MODE ): - device_id, _ = _select_rt_device( + curr_device_id = torch.cuda.current_device() + curr_device_properties = torch.cuda.get_device_properties( + curr_device_id + ) + logger.debug(f"Current Device: cuda:{curr_device_id}") + + # If a switch is required, move all inputs to new device and set as active device + if _is_switch_required( curr_device_id, self.target_device_id, + curr_device_properties, self.target_device_properties, - ) - device = torch.device(device_id) - torch.cuda.set_device(device_id) - - inputs = tuple([tensor.to(device) for tensor in inputs]) - logger.warning(f"Moved all input Tensors to cuda:{device_id}") - - with ( - torch.autograd.profiler.record_function( - "PythonTorchTensorRTModule:ProcessInputs" - ) - if self.profiling_enabled - else nullcontext() - ): - assert len(inputs) == len( - self.input_names - ), f"Wrong number of inputs, expect {len(self.input_names)} get 
{len(inputs)}." - - for i, input_name in enumerate(self.input_names): - if not contiguous_inputs[i].is_cuda: - logger.warning( - f"Detected input {input_name} of engine {self.engine.name} is not on a cuda device. " - "This tensor is being moved by the runtime but for performance considerations, " - "ensure your inputs are all on GPU and open an issue here " - "(https://github.com/pytorch/TensorRT/issues) if this warning persists." - ) - contiguous_inputs = ( - contiguous_inputs[:i] - + [contiguous_inputs[i].cuda()] - + contiguous_inputs[i + 1 :] + ): + device_id, _ = _select_rt_device( + curr_device_id, + self.target_device_id, + self.target_device_properties, ) + device = torch.device(device_id) + torch.cuda.set_device(device_id) - assert ( - contiguous_inputs[i].dtype == self.input_dtypes[i] - ), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {contiguous_inputs[i].dtype}." - - # For shape tensors, we use CPU pointers and for data tensors, we use GPU pointers - # as per TensorRT requirements - if self.engine.is_shape_inference_io(input_name): - # Shape tensor inputs are casted to int64 explicitly - # Currently Torch CPU pointers are not working; numpy pointers are used instead - # to refer to underlying memory - inputs_cpu = ( - contiguous_inputs[i].cpu().to(torch.int64).numpy().copy() - ) - self.context.set_tensor_address( - input_name, inputs_cpu.ctypes.data - ) + contiguous_inputs = [ + tensor.to(device) for tensor in contiguous_inputs + ] + logger.warning(f"Moved all input Tensors to cuda:{device_id}") + + with ( + torch.autograd.profiler.record_function( + "PythonTorchTensorRTModule:ProcessInputs" + ) + if self.profiling_enabled + else nullcontext() + ): + assert len(contiguous_inputs) == len( + self.input_names + ), f"Wrong number of inputs, expect {len(self.input_names)} get {len(contiguous_inputs)}." 
+ + if cudagraphs_enabled: + # If cudagraphs is enabled, this memory is reserved for future cudagraph runs + # Clone is required to avoid re-using user-provided GPU memory + contiguous_inputs = [ + i.contiguous().clone() for i in contiguous_inputs + ] else: - self.context.set_input_shape( - input_name, tuple(contiguous_inputs[i].shape) - ) - self.context.set_tensor_address( - input_name, contiguous_inputs[i].data_ptr() + contiguous_inputs = [i.contiguous() for i in contiguous_inputs] + bindings = [] + for i, input_name in enumerate(self.input_names): + if not contiguous_inputs[i].is_cuda: + logger.warning( + f"Detected input {input_name} of engine {self.engine.name} is not on a cuda device. " + "This tensor is being moved by the runtime but for performance considerations, " + "ensure your inputs are all on GPU and open an issue here " + "(https://github.com/pytorch/TensorRT/issues) if this warning persists." + ) + contiguous_inputs = ( + contiguous_inputs[:i] + + [contiguous_inputs[i].cuda()] + + contiguous_inputs[i + 1 :] + ) + + assert ( + contiguous_inputs[i].dtype == self.input_dtypes[i] + ), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {contiguous_inputs[i].dtype}." 
+ + # For shape tensors, we use CPU pointers and for data tensors, we use GPU pointers + # as per TensorRT requirements + if self.engine.is_shape_inference_io(input_name): + # Shape tensor inputs are casted to int64 explicitly + # Currently Torch CPU pointers are not working; numpy pointers are used instead + # to refer to underlying memory + inputs_cpu = ( + contiguous_inputs[i] + .cpu() + .to(torch.int64) + .numpy() + .copy() + ) + self.context.set_tensor_address( + input_name, inputs_cpu.ctypes.data + ) + bindings.append(inputs_cpu.ctypes.data) + else: + self.context.set_input_shape( + input_name, tuple(contiguous_inputs[i].shape) + ) + self.context.set_tensor_address( + input_name, contiguous_inputs[i].data_ptr() + ) + bindings.append(contiguous_inputs[i].data_ptr()) + + # Check if input shapes can be inferred. + uninferred_input_names = self.context.infer_shapes() + if uninferred_input_names: + logger.warning( + f"The shapes of the inputs: {uninferred_input_names} cannot be inferred and could lead to undefined behavior. \ + This could happen if the input tensor addresses/shapes haven't been configured correctly" ) - # Check if input shapes can be inferred. - uninferred_input_names = self.context.infer_shapes() - if uninferred_input_names: - logger.warning( - f"The shapes of the inputs: {uninferred_input_names} cannot be inferred and could lead to undefined behavior. 
\ - This could happen if the input tensor addresses/shapes haven't been configured correctly" + with ( + torch.autograd.profiler.record_function( + "PythonTorchTensorRTModule:ProcessOutputs" ) + if self.profiling_enabled + else nullcontext() + ): + # create output tensors + outputs: List[torch.Tensor] = [] - with ( - torch.autograd.profiler.record_function( - "PythonTorchTensorRTModule:ProcessOutputs" - ) - if self.profiling_enabled - else nullcontext() - ): - # create output tensors - outputs: List[torch.Tensor] = [] + for i, output_name in enumerate(self.output_names): + shape = tuple(self.context.get_tensor_shape(output_name)) - for i, output_name in enumerate(self.output_names): - shape = tuple(self.context.get_tensor_shape(output_name)) + if DYNAMIC_DIM in shape: + raise ValueError( + "Encountered dynamic output shapes during runtime. This could mean the network has data-dependent output shapes which is not currently supported." + ) - if DYNAMIC_DIM in shape: - raise ValueError( - "Encountered dynamic output shapes during runtime. This could mean the network has data-dependent output shapes which is not currently supported." + output = torch.empty( + size=shape, + dtype=self.output_dtypes[i].to(torch.dtype), + device=torch.cuda.current_device(), ) + bindings.append(output.data_ptr()) + outputs.append(output) - output = torch.empty( - size=shape, - dtype=self.output_dtypes[i].to(torch.dtype), - device=torch.cuda.current_device(), + # Assign tensor address appropriately + for idx in range(self.engine.num_io_tensors): + self.context.set_tensor_address( + self.engine.get_tensor_name(idx), bindings[idx] ) - self.context.set_tensor_address(output_name, output.data_ptr()) - outputs.append(output) with ( torch.autograd.profiler.record_function( @@ -269,12 +314,39 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, . 
if self.profiling_enabled else nullcontext() ): - self.context.execute_async_v3(torch.cuda.current_stream().cuda_stream) - if len(outputs) == 1: - return outputs[0] + if not cudagraphs_enabled: + self.context.execute_async_v3( + torch.cuda.current_stream().cuda_stream + ) + + elif need_cudagraphs_record: + self.input_buffers = list(contiguous_inputs) + self.output_buffers = list(outputs) + + current_stream = self.graph_capturer.capture_stream + + self.context.execute_async_v3(current_stream.cuda_stream) + current_stream.synchronize() + + with self.graph_capturer: + self.context.execute_async_v3(current_stream.cuda_stream) + + else: + for idx, input_tensor in enumerate(inputs): + self.input_buffers[idx].copy_(input_tensor, non_blocking=True) - return tuple(outputs) + self.cudagraph.replay() # type: ignore + + if cudagraphs_enabled: + model_outputs = tuple(output.clone() for output in self.output_buffers) + else: + model_outputs = tuple(outputs) + + if len(model_outputs) == 1: + return model_outputs[0] + + return model_outputs def enable_profiling(self, profiler: "trt.IProfiler" = None) -> None: """ @@ -307,3 +379,23 @@ def get_layer_info(self) -> str: trt.LayerInformationFormat.JSON ) return engine_json + + def cudagraphs_validate_shapes(self, inputs: Sequence[torch.Tensor]) -> bool: + """ + Validates the input shapes of the forward function + versus the version currently active for the recorded CUDA graph + """ + # Representation of input shapes to a given model + # Shapes are concatenated as so: + # x: (3, 4), y: (4, 5) --> Key: (3, 4)(4, 5) + new_shape_key = "".join(str(tuple(t.shape)) for t in inputs) + + # If the new shape key differs from the existing one, + # invalidate the old shape key and remove the CUDAGraph + if new_shape_key != self.shape_key: + logger.debug(f"Resetting Cudagraph on new shape key {new_shape_key}") + self.shape_key = new_shape_key + self.cudagraph.reset() # type: ignore + return False + + return True diff --git a/py/torch_tensorrt/runtime/__init__.py
b/py/torch_tensorrt/runtime/__init__.py index d202c897f6..29b1dad637 100644 --- a/py/torch_tensorrt/runtime/__init__.py +++ b/py/torch_tensorrt/runtime/__init__.py @@ -3,4 +3,5 @@ TorchTensorRTModule, ) +from .cudagraphs import enable_cudagraphs, get_cudagraphs_mode, set_cudagraphs_mode from .multi_device_safe_mode import set_multi_device_safe_mode diff --git a/py/torch_tensorrt/runtime/cudagraphs.py b/py/torch_tensorrt/runtime/cudagraphs.py new file mode 100644 index 0000000000..95c73a885d --- /dev/null +++ b/py/torch_tensorrt/runtime/cudagraphs.py @@ -0,0 +1,55 @@ +import logging +from importlib.util import find_spec +from typing import Any + +import torch + +if find_spec("torch_tensorrt._C") is not None: + _PY_RT_CUDAGRAPHS = torch.ops.tensorrt.get_cudagraphs_mode() +else: + _PY_RT_CUDAGRAPHS = False + + +logger = logging.getLogger(__name__) + + +def set_cudagraphs_mode(mode: bool) -> None: + # Set new cudagraphs mode for Python + global _PY_RT_CUDAGRAPHS + _PY_RT_CUDAGRAPHS = mode + + # Set new mode for C++ + if find_spec("torch_tensorrt._C") is not None: + torch.ops.tensorrt.set_cudagraphs_mode(mode) + + logger.info(f"Set Cudagraphs usage to {mode}") + + +def get_cudagraphs_mode() -> bool: + # Get cudagraphs mode for Python + global _PY_RT_CUDAGRAPHS + return _PY_RT_CUDAGRAPHS # type: ignore + + +class _CudagraphsContextManager(object): + """Helper class used in conjunction with `enable_cudagraphs` + + Used to enable cudagraphs as a context manager + """ + + def __init__(self) -> None: + global _PY_RT_CUDAGRAPHS + self.old_mode = _PY_RT_CUDAGRAPHS + + def __enter__(self) -> "_CudagraphsContextManager": + # Enable cudagraphs + set_cudagraphs_mode(True) + return self + + def __exit__(self, *args: Any) -> None: + # Set cudagraphs back to old mode + set_cudagraphs_mode(self.old_mode) + + +def enable_cudagraphs() -> _CudagraphsContextManager: + return _CudagraphsContextManager() diff --git a/tests/py/dynamo/runtime/test_cudagraphs.py 
b/tests/py/dynamo/runtime/test_cudagraphs.py new file mode 100644 index 0000000000..9b5e64ab9a --- /dev/null +++ b/tests/py/dynamo/runtime/test_cudagraphs.py @@ -0,0 +1,112 @@ +import unittest + +import torch +from torch.testing._internal.common_utils import TestCase, run_tests + +import torch_tensorrt + +from ..testing_utilities import DECIMALS_OF_AGREEMENT + + +@unittest.skipIf( + not torch_tensorrt.ENABLED_FEATURES.torch_tensorrt_runtime, + "Torch-TensorRT runtime is not available", +) +class TestCudagraphs(TestCase): + def test_cudagraphs_on(self): + torch_tensorrt.runtime.set_cudagraphs_mode(True) + self.assertTrue(torch.ops.tensorrt.get_cudagraphs_mode()) + + def test_cudagraphs_off(self): + torch_tensorrt.runtime.set_cudagraphs_mode(False) + self.assertFalse(torch.ops.tensorrt.get_cudagraphs_mode()) + + def test_cudagraphs_context(self): + with torch_tensorrt.runtime.enable_cudagraphs(): + self.assertTrue(torch.ops.tensorrt.get_cudagraphs_mode()) + self.assertFalse(torch.ops.tensorrt.get_cudagraphs_mode()) + + def test_cudagraphs_enabled_inference_python(self): + torch_tensorrt.runtime.set_cudagraphs_mode(True) + + class SampleModel(torch.nn.Module): + def forward(self, x): + return torch.softmax((x + 2) * 7, dim=0) + + inputs = [ + torch.randn( + 3, + 5, + 7, + ).cuda() + ] + + fx_graph = torch.fx.symbolic_trace(SampleModel()) + + # Validate that the results between Torch and Torch-TRT are similar + optimized_model = torch_tensorrt.compile( + fx_graph, + "torch_compile", + inputs, + min_block_size=1, + pass_through_build_failures=True, + use_python_runtime=True, + ) + optimized_model_results = optimized_model(*inputs).detach().cpu() + torch_model_results = fx_graph(*inputs).detach().cpu() + + max_diff = float( + torch.max(torch.abs(optimized_model_results - torch_model_results)) + ) + self.assertAlmostEqual( + max_diff, + 0, + DECIMALS_OF_AGREEMENT, + msg=f"Safe Mode Python TRT outputs don't match with the original model.", + ) + torch._dynamo.reset() + + 
def test_cudagraphs_enabled_inference_cpp(self): + class SampleModel(torch.nn.Module): + def forward(self, x): + return torch.softmax((x + 2) * 7, dim=0) + + inputs = [ + torch.randn( + 3, + 5, + 7, + ).cuda() + ] + + fx_graph = torch.fx.symbolic_trace(SampleModel()) + + # Validate that the results between Torch and Torch-TRT are similar + optimized_model = torch_tensorrt.compile( + fx_graph, + "torch_compile", + inputs, + min_block_size=1, + pass_through_build_failures=True, + use_python_runtime=False, + ) + + with torch_tensorrt.runtime.enable_cudagraphs(): + optimized_model_results = optimized_model(*inputs).detach().cpu() + + torch_model_results = fx_graph(*inputs).detach().cpu() + + max_diff = float( + torch.max(torch.abs(optimized_model_results - torch_model_results)) + ) + self.assertAlmostEqual( + max_diff, + 0, + DECIMALS_OF_AGREEMENT, + msg=f"Safe Mode C++ TRT outputs don't match with the original model.", + ) + torch._dynamo.reset() + + +if __name__ == "__main__": + run_tests() From 9e67471c13921da6c2fee5db48727a3cb9835df9 Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Fri, 28 Jun 2024 00:03:56 -0700 Subject: [PATCH 2/9] feat: Utilize non-default stream for runtimes - Add support for non-default streams --- core/runtime/TRTEngine.cpp | 5 +++ core/runtime/TRTEngine.h | 2 ++ core/runtime/execute_engine.cpp | 31 +++++++++++++------ .../runtime/_PythonTorchTensorRTModule.py | 31 ++++++++++++------- py/torch_tensorrt/runtime/cudagraphs.py | 4 ++- tests/py/dynamo/runtime/test_cudagraphs.py | 8 ++--- 6 files changed, 55 insertions(+), 26 deletions(-) diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp index 8f63563c58..4e042c1ca1 100644 --- a/core/runtime/TRTEngine.cpp +++ b/core/runtime/TRTEngine.cpp @@ -2,6 +2,7 @@ #include #include "NvInfer.h" +#include "c10/cuda/CUDAStream.h" #include "torch/csrc/jit/frontend/function_schema_parser.h" #include "torch/cuda.h" @@ -70,6 +71,10 @@ 
TRTEngine::TRTEngine( multi_gpu_device_check(); set_rt_device(device_info); + // Set active stream to high-priority, non-default stream + active_stream = c10::cuda::getStreamFromPool(true, device_info.id); + c10::cuda::setCurrentCUDAStream(active_stream); + rt = make_trt(nvinfer1::createInferRuntime(util::logging::get_logger())); name = slugify(mod_name); diff --git a/core/runtime/TRTEngine.h b/core/runtime/TRTEngine.h index 0e76b63179..1c900e3f34 100644 --- a/core/runtime/TRTEngine.h +++ b/core/runtime/TRTEngine.h @@ -9,6 +9,7 @@ #include "ATen/core/function_schema.h" #include "ATen/cuda/CUDAGraph.h" #include "NvInfer.h" +#include "c10/cuda/CUDAStream.h" #include "torch/custom_class.h" #include "core/runtime/TRTEngineProfiler.h" @@ -69,6 +70,7 @@ struct TRTEngine : torch::CustomClassHolder { // CUDAGraph-Related Functionality at::cuda::CUDAGraph cudagraph = {}; + at::cuda::CUDAStream active_stream = c10::cuda::getDefaultCUDAStream(); std::vector input_buffers = {}; std::vector output_buffers = {}; std::string shape_key; diff --git a/core/runtime/execute_engine.cpp b/core/runtime/execute_engine.cpp index d5bb1388fa..d28c5930da 100644 --- a/core/runtime/execute_engine.cpp +++ b/core/runtime/execute_engine.cpp @@ -1,4 +1,3 @@ -#include "c10/cuda/CUDAGuard.h" #include "c10/cuda/CUDAStream.h" #include "torch/csrc/jit/runtime/custom_operator.h" @@ -64,9 +63,20 @@ bool _cudagraphs_validate_shapes(std::vector inputs, c10::intrusive_ // invalidate the existing cudagraphs object // Populate the shape key for the inputs + // x: (3, 4), y: (4, 5) --> Key: (3,4)(4,5) std::stringstream new_shape_key_ss; for (auto input : inputs) { - new_shape_key_ss << input.sizes(); + new_shape_key_ss << "("; + auto sizes = input.sizes(); + auto rank = input.sizes().size(); + for (auto i = 0; i < rank; i++) { + new_shape_key_ss << sizes[i]; + // For all but the final dimension in the shape key, add comma separator + if (i < rank - 1) { + new_shape_key_ss << ","; + } + } + new_shape_key_ss << 
")"; } auto new_shape_key = new_shape_key_ss.str(); @@ -128,6 +138,10 @@ std::vector execute_engine(std::vector inputs, c10::intr select_rt_device(compiled_engine->device_info, curr_device, compiled_engine->hardware_compatible); set_rt_device(device); + // Update active stream based on new device + compiled_engine->active_stream = c10::cuda::getStreamFromPool(true, device.id); + c10::cuda::setCurrentCUDAStream(compiled_engine->active_stream); + // Target device is new device target_device += std::to_string(device.id); @@ -157,6 +171,8 @@ std::vector execute_engine(std::vector inputs, c10::intr } } + // this is a buffer to store shape tensor input addresses throughout the runtime scope + std::list> inputShapeTensorValues; { std::unique_ptr input_profiler_guard; if (compiled_engine->profile_execution) { @@ -252,23 +268,18 @@ std::vector execute_engine(std::vector inputs, c10::intr if (!CUDAGRAPHS_MODE) { // If not in cudagraphs mode, proceed with enqueueV3 as normal - c10::cuda::CUDAStream stream = c10::cuda::getCurrentCUDAStream(inputs[0].device().index()); - compiled_engine->exec_ctx->enqueueV3(stream); + compiled_engine->exec_ctx->enqueueV3(compiled_engine->active_stream); } else if (need_cudagraphs_record) { // If cudagraphs needs to record a graph, capture the enqueueV3 call in a graph // Cudagraphs cannot record on the default stream, so use an alternate c10::cuda::CUDAStream stream = c10::cuda::getStreamFromPool(true, inputs[0].device().index()); - c10::cuda::CUDAStreamGuard guard(stream); - compiled_engine->exec_ctx->enqueueV3(stream); + compiled_engine->exec_ctx->enqueueV3(compiled_engine->active_stream); compiled_engine->cudagraph.capture_begin(); - compiled_engine->exec_ctx->enqueueV3(stream); + compiled_engine->exec_ctx->enqueueV3(compiled_engine->active_stream); compiled_engine->cudagraph.capture_end(); - // Reset the stream to its original setting - guard.reset_stream(guard.original_stream()); - } else { // If the cudagraph has already been recorded, 
copy the input buffers and replay it for (auto i = 0; i < inputs.size(); i++) { diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index f570d44a00..cf8c5405c3 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -46,8 +46,9 @@ def __init__( self.input_buffers: List[torch.Tensor] = [] self.output_buffers: List[torch.Tensor] = [] self.cudagraph: Optional[torch.cuda.CUDAGraph] = None - # {shape: cudagraph} - # limitation on CG + self.active_stream: Optional[torch.cuda.Stream] = None + + # TODO: Make the below a Dictionary {shape: cudagraph} self.shape_key: Optional[str] = None # See https://github.com/pytorch/pytorch/blob/acfe237a71af609e837a34bb38048aa8acb8eb4d/torch/cuda/graphs.py#L92-L98 @@ -100,6 +101,10 @@ def _initialize(self) -> None: self.cudagraph = torch.cuda.CUDAGraph() self.graph_capturer = torch.cuda.graphs.graph(self.cudagraph) + # Set the active stream using the current device, with a high priority flag + self.active_stream = torch.cuda.Stream(torch.cuda.current_device(), priority=-1) + torch.cuda.set_stream(self.active_stream) + def _check_initialized(self) -> None: if not self.initialized: raise RuntimeError("PythonTorchTensorRTModule is not initialized.") @@ -195,9 +200,15 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, . self.target_device_id, self.target_device_properties, ) + + # Update current device device = torch.device(device_id) torch.cuda.set_device(device_id) + # Update current stream + self.active_stream = torch.cuda.Stream(device, priority=-1) + torch.cuda.set_stream(self.active_stream) + contiguous_inputs = [ tensor.to(device) for tensor in contiguous_inputs ] @@ -316,21 +327,19 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, . 
): if not cudagraphs_enabled: - self.context.execute_async_v3( - torch.cuda.current_stream().cuda_stream - ) + self.context.execute_async_v3(self.active_stream) elif need_cudagraphs_record: self.input_buffers = list(contiguous_inputs) self.output_buffers = list(outputs) - current_stream = self.graph_capturer.capture_stream + graph_capturer_stream = self.graph_capturer.capture_stream - self.context.execute_async_v3(current_stream.cuda_stream) - current_stream.synchronize() + self.context.execute_async_v3(graph_capturer_stream.cuda_stream) + graph_capturer_stream.synchronize() with self.graph_capturer: - self.context.execute_async_v3(current_stream.cuda_stream) + self.context.execute_async_v3(graph_capturer_stream.cuda_stream) else: for idx, input_tensor in enumerate(inputs): @@ -387,8 +396,8 @@ def cudagraphs_validate_shapes(self, inputs: Sequence[torch.Tensor]) -> bool: """ # Representation of input shapes to a given model # Shapes are concatenated as so: - # x: (3, 4), y: (4, 5) --> Key: (3, 4)(4, 5) - new_shape_key = "".join(str(tuple(t.shape)) for t in inputs) + # x: (3, 4), y: (4, 5) --> Key: (3,4)(4,5) + new_shape_key = "".join(str(tuple(t.shape)).replace(" ", "") for t in inputs) # If the new shape key differs from the existing one, # invalidate the old shape key and remove the CUDAGraph diff --git a/py/torch_tensorrt/runtime/cudagraphs.py b/py/torch_tensorrt/runtime/cudagraphs.py index 95c73a885d..56f8b82a73 100644 --- a/py/torch_tensorrt/runtime/cudagraphs.py +++ b/py/torch_tensorrt/runtime/cudagraphs.py @@ -4,7 +4,9 @@ import torch -if find_spec("torch_tensorrt._C") is not None: +import torch_tensorrt + +if torch_tensorrt.ENABLED_FEATURES.torch_tensorrt_runtime: _PY_RT_CUDAGRAPHS = torch.ops.tensorrt.get_cudagraphs_mode() else: _PY_RT_CUDAGRAPHS = False diff --git a/tests/py/dynamo/runtime/test_cudagraphs.py b/tests/py/dynamo/runtime/test_cudagraphs.py index 9b5e64ab9a..07e1bcf615 100644 --- a/tests/py/dynamo/runtime/test_cudagraphs.py +++ 
b/tests/py/dynamo/runtime/test_cudagraphs.py @@ -8,10 +8,6 @@ from ..testing_utilities import DECIMALS_OF_AGREEMENT -@unittest.skipIf( - not torch_tensorrt.ENABLED_FEATURES.torch_tensorrt_runtime, - "Torch-TensorRT runtime is not available", -) class TestCudagraphs(TestCase): def test_cudagraphs_on(self): torch_tensorrt.runtime.set_cudagraphs_mode(True) @@ -66,6 +62,10 @@ def forward(self, x): ) torch._dynamo.reset() + @unittest.skipIf( + not torch_tensorrt.ENABLED_FEATURES.torch_tensorrt_runtime, + "Torch-TensorRT runtime is not available", + ) def test_cudagraphs_enabled_inference_cpp(self): class SampleModel(torch.nn.Module): def forward(self, x): From df2a0b98dda9b2ba34cd38733ef5ac5a362e0e01 Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Fri, 28 Jun 2024 10:21:51 -0700 Subject: [PATCH 3/9] fix: Input stream number --- WORKSPACE | 2 -- .../dynamo/runtime/_PythonTorchTensorRTModule.py | 9 +++------ 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/WORKSPACE b/WORKSPACE index 734ce8c85f..225bc0688e 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -101,8 +101,6 @@ http_archive( ], ) - - #################################################################################### # Locally installed dependencies (use in cases of custom dependencies or aarch64) #################################################################################### diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index cf8c5405c3..af81fccbee 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -228,11 +228,8 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, . 
if cudagraphs_enabled: # If cudagraphs is enabled, this memory is reserved for future cudagraph runs # Clone is required to avoid re-using user-provided GPU memory - contiguous_inputs = [ - i.contiguous().clone() for i in contiguous_inputs - ] - else: - contiguous_inputs = [i.contiguous() for i in contiguous_inputs] + contiguous_inputs = [i.clone() for i in contiguous_inputs] + bindings = [] for i, input_name in enumerate(self.input_names): if not contiguous_inputs[i].is_cuda: @@ -327,7 +324,7 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, . ): if not cudagraphs_enabled: - self.context.execute_async_v3(self.active_stream) + self.context.execute_async_v3(self.active_stream.cuda_stream) # type: ignore elif need_cudagraphs_record: self.input_buffers = list(contiguous_inputs) From ea98b93b2d0e83bb06de5c25d51f1bfa6f583678 Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Fri, 28 Jun 2024 10:38:33 -0700 Subject: [PATCH 4/9] fix: Remove unnecessary stream generation --- core/runtime/execute_engine.cpp | 3 --- 1 file changed, 3 deletions(-) diff --git a/core/runtime/execute_engine.cpp b/core/runtime/execute_engine.cpp index d28c5930da..c22126e40f 100644 --- a/core/runtime/execute_engine.cpp +++ b/core/runtime/execute_engine.cpp @@ -271,9 +271,6 @@ std::vector execute_engine(std::vector inputs, c10::intr compiled_engine->exec_ctx->enqueueV3(compiled_engine->active_stream); } else if (need_cudagraphs_record) { // If cudagraphs needs to record a graph, capture the enqueueV3 call in a graph - - // Cudagraphs cannot record on the default stream, so use an alternate - c10::cuda::CUDAStream stream = c10::cuda::getStreamFromPool(true, inputs[0].device().index()); compiled_engine->exec_ctx->enqueueV3(compiled_engine->active_stream); compiled_engine->cudagraph.capture_begin(); From 39906b0348942648a211ffe773b9fde58628f789 Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> 
Date: Fri, 28 Jun 2024 13:08:50 -0700 Subject: [PATCH 5/9] fix: Test failure + new test cases --- core/runtime/execute_engine.cpp | 14 ++- .../dynamo/backend/test_backend_compiler.py | 9 +- tests/py/dynamo/runtime/test_cudagraphs.py | 87 +++++++++++++++++++ 3 files changed, 105 insertions(+), 5 deletions(-) diff --git a/core/runtime/execute_engine.cpp b/core/runtime/execute_engine.cpp index c22126e40f..1609af050b 100644 --- a/core/runtime/execute_engine.cpp +++ b/core/runtime/execute_engine.cpp @@ -1,3 +1,4 @@ +#include "c10/cuda/CUDAGuard.h" #include "c10/cuda/CUDAStream.h" #include "torch/csrc/jit/runtime/custom_operator.h" @@ -271,12 +272,21 @@ std::vector execute_engine(std::vector inputs, c10::intr compiled_engine->exec_ctx->enqueueV3(compiled_engine->active_stream); } else if (need_cudagraphs_record) { // If cudagraphs needs to record a graph, capture the enqueueV3 call in a graph - compiled_engine->exec_ctx->enqueueV3(compiled_engine->active_stream); + + // Cudagraphs cannot record on the current stream, so use an alternate + c10::cuda::CUDAStream recording_stream = c10::cuda::getStreamFromPool(true, inputs[0].device().index()); + c10::cuda::CUDAStreamGuard guard(recording_stream); + + compiled_engine->exec_ctx->enqueueV3(recording_stream); + recording_stream.synchronize(); compiled_engine->cudagraph.capture_begin(); - compiled_engine->exec_ctx->enqueueV3(compiled_engine->active_stream); + compiled_engine->exec_ctx->enqueueV3(recording_stream); compiled_engine->cudagraph.capture_end(); + // Reset the stream to its original setting + guard.reset_stream(guard.original_stream()); + } else { // If the cudagraph has already been recorded, copy the input buffers and replay it for (auto i = 0; i < inputs.size(); i++) { diff --git a/tests/py/dynamo/backend/test_backend_compiler.py b/tests/py/dynamo/backend/test_backend_compiler.py index 0f138c7100..4c65800f05 100644 --- a/tests/py/dynamo/backend/test_backend_compiler.py +++ 
b/tests/py/dynamo/backend/test_backend_compiler.py @@ -2,10 +2,11 @@ from copy import deepcopy import torch -import torch_tensorrt from torch.testing._internal.common_utils import TestCase, run_tests from torch_tensorrt.dynamo.partitioning import fast_partition +import torch_tensorrt + from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing @@ -120,9 +121,11 @@ def forward(self, x, y): torch._dynamo.reset() + model = PartiallySupportedMultiOp().eval().cuda() + # Validate that the results between Torch and Torch-TRT are similar optimized_model = torch_tensorrt.compile( - fx_graph, + model, "torch_compile", inputs, min_block_size=1, @@ -132,7 +135,7 @@ def forward(self, x, y): debug=True, ) optimized_model_results = optimized_model(*inputs).detach().cpu() - torch_model_results = fx_graph(*inputs).detach().cpu() + torch_model_results = model(*inputs).detach().cpu() max_diff = float( torch.max(torch.abs(optimized_model_results - torch_model_results)) diff --git a/tests/py/dynamo/runtime/test_cudagraphs.py b/tests/py/dynamo/runtime/test_cudagraphs.py index 07e1bcf615..448e5981c8 100644 --- a/tests/py/dynamo/runtime/test_cudagraphs.py +++ b/tests/py/dynamo/runtime/test_cudagraphs.py @@ -107,6 +107,93 @@ def forward(self, x): ) torch._dynamo.reset() + def test_cudagraphs_enabled_fallback_inference_python(self): + torch_tensorrt.runtime.set_cudagraphs_mode(True) + + class SampleModel(torch.nn.Module): + def forward(self, x): + return torch.softmax((x + 2) * 7, dim=0) + + inputs = [ + torch.randn( + 3, + 5, + 7, + ).cuda() + ] + + fx_graph = torch.fx.symbolic_trace(SampleModel()) + + # Validate that the results between Torch and Torch-TRT are similar + optimized_model = torch_tensorrt.compile( + fx_graph, + "torch_compile", + inputs, + min_block_size=1, + pass_through_build_failures=True, + torch_executed_ops={"torch.ops.aten.mul.Tensor"}, + use_python_runtime=True, + ) + optimized_model_results = optimized_model(*inputs).detach().cpu() + 
torch_model_results = fx_graph(*inputs).detach().cpu() + + max_diff = float( + torch.max(torch.abs(optimized_model_results - torch_model_results)) + ) + self.assertAlmostEqual( + max_diff, + 0, + DECIMALS_OF_AGREEMENT, + msg=f"Safe Mode Python TRT outputs don't match with the original model.", + ) + torch._dynamo.reset() + + @unittest.skipIf( + not torch_tensorrt.ENABLED_FEATURES.torch_tensorrt_runtime, + "Torch-TensorRT runtime is not available", + ) + def test_cudagraphs_enabled_fallback_inference_cpp(self): + class SampleModel(torch.nn.Module): + def forward(self, x): + return torch.softmax((x + 2) * 7, dim=0) + + inputs = [ + torch.randn( + 3, + 5, + 7, + ).cuda() + ] + + fx_graph = torch.fx.symbolic_trace(SampleModel()) + + # Validate that the results between Torch and Torch-TRT are similar + optimized_model = torch_tensorrt.compile( + fx_graph, + "torch_compile", + inputs, + min_block_size=1, + pass_through_build_failures=True, + torch_executed_ops={"torch.ops.aten.mul.Tensor"}, + use_python_runtime=False, + ) + + with torch_tensorrt.runtime.enable_cudagraphs(): + optimized_model_results = optimized_model(*inputs).detach().cpu() + + torch_model_results = fx_graph(*inputs).detach().cpu() + + max_diff = float( + torch.max(torch.abs(optimized_model_results - torch_model_results)) + ) + self.assertAlmostEqual( + max_diff, + 0, + DECIMALS_OF_AGREEMENT, + msg=f"Safe Mode C++ TRT outputs don't match with the original model.", + ) + torch._dynamo.reset() + if __name__ == "__main__": run_tests() From 85e2eaaa61481848d128cd3862a22b4687a8a382 Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Fri, 28 Jun 2024 14:38:22 -0700 Subject: [PATCH 6/9] fix: Only set current stream if default --- core/runtime/TRTEngine.cpp | 11 ++++++++--- core/runtime/execute_engine.cpp | 11 ++++++++--- .../runtime/_PythonTorchTensorRTModule.py | 18 +++++++++++++----- 3 files changed, 29 insertions(+), 11 deletions(-) diff --git 
a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp index 4e042c1ca1..341d8a10e0 100644 --- a/core/runtime/TRTEngine.cpp +++ b/core/runtime/TRTEngine.cpp @@ -71,9 +71,14 @@ TRTEngine::TRTEngine( multi_gpu_device_check(); set_rt_device(device_info); - // Set active stream to high-priority, non-default stream - active_stream = c10::cuda::getStreamFromPool(true, device_info.id); - c10::cuda::setCurrentCUDAStream(active_stream); + // Set active stream to non-default stream + auto current_stream = c10::cuda::getCurrentCUDAStream(device_info.id); + if (current_stream == c10::cuda::getDefaultCUDAStream(device_info.id)) { + active_stream = c10::cuda::getStreamFromPool(false, device_info.id); + c10::cuda::setCurrentCUDAStream(active_stream); + } else { + active_stream = current_stream; + } rt = make_trt(nvinfer1::createInferRuntime(util::logging::get_logger())); diff --git a/core/runtime/execute_engine.cpp b/core/runtime/execute_engine.cpp index 1609af050b..e44bdf7bd1 100644 --- a/core/runtime/execute_engine.cpp +++ b/core/runtime/execute_engine.cpp @@ -140,8 +140,13 @@ std::vector execute_engine(std::vector inputs, c10::intr set_rt_device(device); // Update active stream based on new device - compiled_engine->active_stream = c10::cuda::getStreamFromPool(true, device.id); - c10::cuda::setCurrentCUDAStream(compiled_engine->active_stream); + auto current_stream = c10::cuda::getCurrentCUDAStream(device.id); + if (current_stream == c10::cuda::getDefaultCUDAStream(device.id)) { + compiled_engine->active_stream = c10::cuda::getStreamFromPool(false, device.id); + c10::cuda::setCurrentCUDAStream(compiled_engine->active_stream); + } else { + compiled_engine->active_stream = current_stream; + } // Target device is new device target_device += std::to_string(device.id); @@ -274,7 +279,7 @@ std::vector execute_engine(std::vector inputs, c10::intr // If cudagraphs needs to record a graph, capture the enqueueV3 call in a graph // Cudagraphs cannot record on the current stream, so 
use an alternate - c10::cuda::CUDAStream recording_stream = c10::cuda::getStreamFromPool(true, inputs[0].device().index()); + c10::cuda::CUDAStream recording_stream = c10::cuda::getStreamFromPool(false, inputs[0].device().index()); c10::cuda::CUDAStreamGuard guard(recording_stream); compiled_engine->exec_ctx->enqueueV3(recording_stream); diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index af81fccbee..6c94b112a7 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -101,9 +101,13 @@ def _initialize(self) -> None: self.cudagraph = torch.cuda.CUDAGraph() self.graph_capturer = torch.cuda.graphs.graph(self.cudagraph) - # Set the active stream using the current device, with a high priority flag - self.active_stream = torch.cuda.Stream(torch.cuda.current_device(), priority=-1) - torch.cuda.set_stream(self.active_stream) + # Set the active stream using the current device + current_stream = torch.cuda.current_stream() + if current_stream == torch.cuda.default_stream(): + self.active_stream = torch.cuda.Stream() + torch.cuda.set_stream(self.active_stream) + else: + self.active_stream = current_stream def _check_initialized(self) -> None: if not self.initialized: @@ -206,8 +210,12 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, . 
torch.cuda.set_device(device_id) # Update current stream - self.active_stream = torch.cuda.Stream(device, priority=-1) - torch.cuda.set_stream(self.active_stream) + current_stream = torch.cuda.current_stream(device) + if current_stream == torch.cuda.default_stream(device): + self.active_stream = torch.cuda.Stream(device) + torch.cuda.set_stream(self.active_stream) + else: + self.active_stream = current_stream contiguous_inputs = [ tensor.to(device) for tensor in contiguous_inputs From 658daab89f5111311429bda16115dfb08df3e71c Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Fri, 28 Jun 2024 17:10:57 -0700 Subject: [PATCH 7/9] fix: Extra buffer assignment --- core/runtime/execute_engine.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/core/runtime/execute_engine.cpp b/core/runtime/execute_engine.cpp index e44bdf7bd1..6868bc47ce 100644 --- a/core/runtime/execute_engine.cpp +++ b/core/runtime/execute_engine.cpp @@ -225,7 +225,6 @@ std::vector execute_engine(std::vector inputs, c10::intr TORCHTRT_CHECK( compiled_engine->exec_ctx->setTensorAddress(name.c_str(), contig_input.data_ptr()), "Error while setting the input tensor address for inputs"); - compiled_engine->input_buffers[i] = contig_input; } } From 3f96963bfacf5b719dbe0eeea6e42f5247b88781 Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Fri, 28 Jun 2024 17:23:33 -0700 Subject: [PATCH 8/9] fix: Enable cudagraphs for TS path --- core/runtime/TRTEngine.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp index 341d8a10e0..27880ed302 100644 --- a/core/runtime/TRTEngine.cpp +++ b/core/runtime/TRTEngine.cpp @@ -122,7 +122,9 @@ TRTEngine::TRTEngine( num_io = std::make_pair(inputs, outputs); in_binding_names.resize(inputs); + input_buffers.resize(inputs); out_binding_names.resize(outputs); + output_buffers.resize(outputs); for (int64_t x = 0; x < 
cuda_engine->getNbIOTensors(); x++) { std::string bind_name = cuda_engine->getIOTensorName(x); if (cuda_engine->getTensorIOMode(bind_name.c_str()) == nvinfer1::TensorIOMode::kINPUT) { From b01a3790a83541a6a9e4840f97e80d8b273ef6d1 Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Mon, 1 Jul 2024 15:05:58 -0700 Subject: [PATCH 9/9] Enable CG tests Python + Fallback --- tests/py/dynamo/runtime/test_cudagraphs.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/py/dynamo/runtime/test_cudagraphs.py b/tests/py/dynamo/runtime/test_cudagraphs.py index 448e5981c8..4d922629c1 100644 --- a/tests/py/dynamo/runtime/test_cudagraphs.py +++ b/tests/py/dynamo/runtime/test_cudagraphs.py @@ -134,7 +134,10 @@ def forward(self, x): torch_executed_ops={"torch.ops.aten.mul.Tensor"}, use_python_runtime=True, ) - optimized_model_results = optimized_model(*inputs).detach().cpu() + + with torch_tensorrt.runtime.enable_cudagraphs(): + optimized_model_results = optimized_model(*inputs).detach().cpu() + torch_model_results = fx_graph(*inputs).detach().cpu() max_diff = float(