Skip to content

Commit

Permalink
bug fixes and support cpp runtime
Browse files Browse the repository at this point in the history
  • Loading branch information
zewenli98 committed Feb 4, 2025
1 parent e7a0faf commit 605bba0
Show file tree
Hide file tree
Showing 4 changed files with 133 additions and 58 deletions.
23 changes: 21 additions & 2 deletions core/runtime/TRTEngine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,27 @@ std::vector<std::string> split(const std::string& str, char delim) {
return strings;
}

// Stores a copy of the tensor-name -> ATen dtype table. The table is consulted
// by reallocateOutputAsync() whenever a buffer for an output has to be created.
DynamicOutputAllocator::DynamicOutputAllocator(
    const std::unordered_map<std::string, at::ScalarType>& output_dtypes)
    : dtypes(output_dtypes) {
}

void* DynamicOutputAllocator::reallocateOutputAsync(
char const* tensorName,
void* currentMemory,
uint64_t size,
uint64_t alignment,
cudaStream_t stream) {
std::vector<int64_t> shape = {static_cast<int64_t>(size)};
auto it = buffers.find(tensorName);
if (it == buffers.end() || it->second.sizes() != shape) {
buffers[tensorName] = at::empty(shape, at::TensorOptions().dtype(dtypes.at(tensorName)).device(c10::kCUDA));
}
return buffers[tensorName].data_ptr();
}

// TensorRT callback: remembers the dims reported for output `tensorName`,
// overwriting whatever was stored for that name by an earlier call.
void DynamicOutputAllocator::notifyShape(char const* tensorName, nvinfer1::Dims const& dims) noexcept {
  const std::string key(tensorName);
  shapes[key] = dims;
}

TRTEngine::TRTEngine(
const std::string& serialized_engine,
const RTDevice& cuda_device,
Expand Down Expand Up @@ -137,7 +158,6 @@ TRTEngine::TRTEngine(
in_binding_names.resize(inputs);
input_buffers.resize(inputs);
out_binding_names.resize(outputs);
output_buffers.resize(outputs);
for (int64_t x = 0; x < cuda_engine->getNbIOTensors(); x++) {
std::string bind_name = cuda_engine->getIOTensorName(x);
if (cuda_engine->getTensorIOMode(bind_name.c_str()) == nvinfer1::TensorIOMode::kINPUT) {
Expand Down Expand Up @@ -179,7 +199,6 @@ TRTEngine::TRTEngine(

uint64_t outputs = _out_binding_names.size();
out_binding_names.resize(outputs);
output_buffers.resize(outputs);
for (size_t pyt_idx = 0; pyt_idx < outputs; pyt_idx++) {
auto binding_name = _out_binding_names[pyt_idx];
// Check if the binding name provided is in the list of engine's bindings
Expand Down
29 changes: 28 additions & 1 deletion core/runtime/TRTEngine.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,39 @@ struct TorchTRTRuntimeStates {
}
};

// Output allocator that can be registered with a TensorRT execution context (it implements
// nvinfer1::IOutputAllocator). It keeps one ATen tensor per output name as the backing buffer
// and records the most recent dims TensorRT reported for each output.
class DynamicOutputAllocator : public nvinfer1::IOutputAllocator {
 public:
  // output_dtypes maps each output tensor name to the ATen dtype used when its buffer is created.
  // `explicit` prevents a dtype map from being silently converted into an allocator.
  explicit DynamicOutputAllocator(const std::unordered_map<std::string, at::ScalarType>& output_dtypes);

  // Called by TensorRT to obtain memory for `tensorName`; returns the data pointer of the
  // buffer stored for that output (creating or replacing the buffer when needed).
  void* reallocateOutputAsync(
      char const* tensorName,
      void* currentMemory,
      uint64_t size,
      uint64_t alignment,
      cudaStream_t stream) override;

  // Called by TensorRT to report the dims of `tensorName`; the dims are stored in `shapes`.
  void notifyShape(char const* tensorName, nvinfer1::Dims const& dims) noexcept override;

  // Read-only view of the per-output backing tensors, keyed by tensor name.
  const std::unordered_map<std::string, at::Tensor>& getBuffers() const {
    return buffers;
  }

  // Read-only view of the last dims recorded by notifyShape(), keyed by tensor name.
  const std::unordered_map<std::string, nvinfer1::Dims>& getShapes() const {
    return shapes;
  }

 private:
  // Output name -> dtype used when allocating that output's buffer.
  std::unordered_map<std::string, at::ScalarType> dtypes;
  // Output name -> tensor whose storage is handed to TensorRT.
  std::unordered_map<std::string, at::Tensor> buffers;
  // Output name -> dims most recently passed to notifyShape().
  std::unordered_map<std::string, nvinfer1::Dims> shapes;
};

struct TRTEngine : torch::CustomClassHolder {
// Each engine needs its own runtime object
std::shared_ptr<nvinfer1::IRuntime> rt;
std::shared_ptr<nvinfer1::ICudaEngine> cuda_engine;
std::shared_ptr<nvinfer1::IExecutionContext> exec_ctx;
std::shared_ptr<DynamicOutputAllocator> output_allocator;
std::pair<uint64_t, uint64_t> num_io;
std::string name;
RTDevice device_info;
Expand Down Expand Up @@ -141,7 +169,6 @@ struct TRTEngine : torch::CustomClassHolder {
at::cuda::CUDAStream engine_stream = c10::cuda::getDefaultCUDAStream();
at::cuda::CUDAStream caller_stream = c10::cuda::getDefaultCUDAStream();
std::vector<at::Tensor> input_buffers = {};
std::vector<at::Tensor> output_buffers = {};
std::string shape_key = "None";
bool use_pre_allocated_outputs = false;
std::vector<at::Tensor> pre_allocated_outputs;
Expand Down
95 changes: 41 additions & 54 deletions core/runtime/execute_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,22 +163,23 @@ void setup_input_tensors(
}
}
}
std::vector<at::Tensor> create_output_tensors(c10::intrusive_ptr<TRTEngine> compiled_engine) {
std::vector<at::Tensor> outputs(compiled_engine->num_io.second);
for (auto output_indices : compiled_engine->out_binding_map) {
// out_binding_map stores TRT_IDX: PYT_IDX
auto pyt_idx = output_indices.second;

std::string name = compiled_engine->out_binding_names[pyt_idx];
auto out_shape = compiled_engine->exec_ctx->getTensorShape(name.c_str());
LOG_DEBUG("Output Name: " << name << " Shape: " << out_shape);

auto dims = core::util::toVec(out_shape);
auto type = util::TRTDataTypeToScalarType(compiled_engine->exec_ctx->getEngine().getTensorDataType(name.c_str()));
outputs[pyt_idx] = std::move(at::empty(dims, {at::kCUDA}).to(type).contiguous());

// Ensures the engine has a DynamicOutputAllocator and registers it with the execution context
// for every output binding.
//
// The allocator is created only when `compiled_engine->output_allocator` is still null; its dtype
// table is built from the engine's data type for each name in `out_binding_names`. Registration
// with the execution context is repeated on every call.
//
// Throws std::runtime_error if the execution context rejects the allocator for any output.
void setup_output_allocator(c10::intrusive_ptr<TRTEngine> compiled_engine) {
  if (compiled_engine->output_allocator == nullptr) {
    std::unordered_map<std::string, at::ScalarType> output_dtypes_dict;
    // Iterate by const reference so binding names are not copied on each pass.
    for (const auto& name : compiled_engine->out_binding_names) {
      output_dtypes_dict[name] =
          util::TRTDataTypeToScalarType(compiled_engine->exec_ctx->getEngine().getTensorDataType(name.c_str()));
    }
    compiled_engine->output_allocator = std::make_shared<DynamicOutputAllocator>(output_dtypes_dict);
  }

  for (const auto& output_name : compiled_engine->out_binding_names) {
    if (!compiled_engine->exec_ctx->setOutputAllocator(output_name.c_str(), compiled_engine->output_allocator.get())) {
      throw std::runtime_error("Failed to set output allocator for " + output_name);
    }
  }
}

std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine) {
Expand Down Expand Up @@ -218,7 +219,6 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
}

// Initialize inputs and outputs to be available throughout the succeeding scopes
std::vector<at::Tensor> outputs(compiled_engine->num_io.second);

if (MULTI_DEVICE_SAFE_MODE) {
std::unique_ptr<torch::autograd::profiler::RecordProfile> device_profiler_guard;
Expand Down Expand Up @@ -287,44 +287,20 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
<< " cannot be inferred. This could happen if the input tensor addresses/shapes haven't been configured correctly");
}

{ // Output Setup
std::unique_ptr<torch::autograd::profiler::RecordProfile> output_profiler_guard;
{ // OutputAllocator Setup
std::unique_ptr<torch::autograd::profiler::RecordProfile> output_allocator_profiler_guard;
if (compiled_engine->profile_execution) {
output_profiler_guard =
output_allocator_profiler_guard =
std::make_unique<torch::autograd::profiler::RecordProfile>(compiled_engine->output_profile_path);
}
if (can_use_pre_allocated_outputs) {
outputs = compiled_engine->pre_allocated_outputs;
} else {
outputs = create_output_tensors(compiled_engine);
}

for (auto output_indices : compiled_engine->out_binding_map) {
auto pyt_idx = output_indices.second;
std::string name = compiled_engine->out_binding_names[pyt_idx];
if (need_cudagraphs_record) {
// If we are recording the cuda graph then we need to update the persistent output buffer
compiled_engine->output_buffers[pyt_idx] = std::move(outputs[pyt_idx].clone());
}

if (cudagraphs_enabled) {
TORCHTRT_CHECK(
compiled_engine->exec_ctx->setTensorAddress(
name.c_str(), compiled_engine->output_buffers[pyt_idx].data_ptr()),
"Error while setting the output tensor address");
} else {
TORCHTRT_CHECK(
compiled_engine->exec_ctx->setTensorAddress(name.c_str(), outputs[pyt_idx].data_ptr()),
"Error while setting the output tensor address");
}
}
setup_output_allocator(compiled_engine);
}

auto current_device_id = -1;
if (inputs.size() > 0) {
current_device_id = inputs[0].device().index(); // Done this way to avoid a call to cudart
} else if (outputs.size() > 0) {
current_device_id = outputs[0].device().index(); // Done this way to avoid a call to cudart
} else {
current_device_id = c10::cuda::current_device();
}

compiled_engine->caller_stream = c10::cuda::getCurrentCUDAStream(current_device_id);
Expand Down Expand Up @@ -368,21 +344,32 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
}
} // End engine execution (resets to caller stream)

// Create output buffer for next execution of graph or trt context.
if (compiled_engine->use_pre_allocated_outputs) {
compiled_engine->pre_allocated_outputs = create_output_tensors(compiled_engine);
}

// Block caller stream until engine execution is complete
at::cuda::CUDAEvent trt_exec_complete;
trt_exec_complete.record(compiled_engine->engine_stream);
trt_exec_complete.block(compiled_engine->caller_stream);

if (cudagraphs_enabled) {
// If in CUDAGraph mode, results need to be copied to the result buffers (on caller stream)
for (size_t o = 0; o < compiled_engine->output_buffers.size(); o++) {
outputs[o].copy_(compiled_engine->output_buffers[o], false);
std::unique_ptr<torch::autograd::profiler::RecordProfile> output_profiler_guard;
if (compiled_engine->profile_execution) {
output_profiler_guard =
std::make_unique<torch::autograd::profiler::RecordProfile>(compiled_engine->output_profile_path);
}
std::vector<at::Tensor> outputs;
for (size_t i = 0; i < compiled_engine->out_binding_names.size(); i++) {
auto name = compiled_engine->out_binding_names[i];
auto dims = compiled_engine->output_allocator->getShapes().at(name);
auto dtype = util::TRTDataTypeToScalarType(compiled_engine->exec_ctx->getEngine().getTensorDataType(name.c_str()));
at::Tensor output = compiled_engine->output_allocator->getBuffers().at(name).clone().detach();
int64_t prod = 1;
for (int i = 0; i < dims.nbDims; ++i) {
prod *= dims.d[i];
}
std::vector<int64_t> dims_vec(dims.nbDims);
for (int i = 0; i < dims.nbDims; ++i) {
dims_vec[i] = dims.d[i];
}
output = output.reshape(-1).view(dtype).slice(0, 0, prod).reshape(dims_vec);
outputs.push_back(output);
}

if (compiled_engine->profile_execution) {
Expand Down
44 changes: 43 additions & 1 deletion tests/py/dynamo/conversion/test_nonzero_aten.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input

from .harness import DispatchTestCase

Expand All @@ -16,7 +17,7 @@ class TestNonZeroConverter(DispatchTestCase):
((2, 3, 4, 5), torch.float),
]
)
def test_non_zero_float(self, input_shape, dtype):
def test_non_zero(self, input_shape, dtype):
class NonZero(nn.Module):
def forward(self, input):
return torch.ops.aten.nonzero.default(input)
Expand All @@ -27,6 +28,47 @@ def forward(self, input):
inputs,
)

@parameterized.expand(
    [
        # (case name, min_shape, opt_shape, max_shape, dtype)
        ("1d", (1,), (10,), (100,), torch.int32),
        ("2d", (1, 2), (5, 10), (20, 40), torch.float16),
        ("3d", (1, 2, 3), (5, 10, 20), (30, 40, 50), torch.float),
    ]
)
def test_nonzero_dynamic_shape(self, _, min_shape, opt_shape, max_shape, dtype):
    # Runs aten.nonzero through the dynamic-shape harness with a single input
    # whose shape range is described by (min_shape, opt_shape, max_shape).
    class NonZero(nn.Module):
        def forward(self, input):
            return torch.ops.aten.nonzero.default(input)

    dynamic_input = Input(
        min_shape=min_shape,
        opt_shape=opt_shape,
        max_shape=max_shape,
        dtype=dtype,
    )

    self.run_test_with_dynamic_shape(NonZero(), [dynamic_input])


if __name__ == "__main__":
run_tests()

0 comments on commit 605bba0

Please sign in to comment.