use IOutputAllocator
zewenli98 committed Feb 4, 2025
1 parent 54e36db commit e7a0faf
Showing 4 changed files with 134 additions and 48 deletions.
17 changes: 17 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -3582,3 +3582,20 @@ def aten_ops_full(
fill_value=args[1],
dtype=kwargs.get("dtype", None),
)


@dynamo_tensorrt_converter(torch.ops.aten.nonzero.default)
def aten_ops_nonzero(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.unary.nonzero(
ctx,
target,
SourceIR.ATEN,
name,
args[0],
)
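
For context: the number of rows returned by nonzero depends on the values in the input tensor, so the engine cannot size this converter's output at build time. That data dependence is what the IOutputAllocator changes below address. A minimal PyTorch-only sketch (not part of this diff) of the behavior:

import torch

x = torch.tensor([[0, 3], [5, 0]])
idx = torch.ops.aten.nonzero.default(x)
print(idx.shape)  # torch.Size([2, 2]); the row count is data-dependent,
                  # so the output buffer cannot be allocated ahead of execution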
15 changes: 15 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py
@@ -624,3 +624,18 @@ def native_dropout(
mask = np.ones(input_val.shape, dtype=bool)
mask = get_trt_tensor(ctx, mask, f"{name}_mask")
return identity_layer.get_output(0), mask


def nonzero(
ctx: ConversionContext,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input_val: TRTTensor,
) -> TRTTensor:
non_zero_layer = ctx.net.add_non_zero(input_val)
set_layer_name(non_zero_layer, target, f"{name}_non_zero", source_ir)
shuffle_layer = ctx.net.add_shuffle(non_zero_layer.get_output(0))
shuffle_layer.first_transpose = trt.Permutation([1, 0])
set_layer_name(shuffle_layer, target, f"{name}_transpose", source_ir)
return shuffle_layer.get_output(0)
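
Note on the shuffle above: per TensorRT semantics, INonZeroLayer emits indices with shape (rank, num_nonzero), while aten.nonzero returns (num_nonzero, rank), hence the (1, 0) permutation. A standalone PyTorch illustration of the two layouts (not part of this diff):

import torch

x = torch.tensor([[0, 3], [5, 0]])
trt_layout = torch.stack(torch.where(x != 0))    # (rank, num_nonzero), like INonZeroLayer
aten_layout = torch.ops.aten.nonzero.default(x)  # (num_nonzero, rank)
assert torch.equal(trt_layout.t(), aten_layout)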
118 changes: 70 additions & 48 deletions py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
@@ -12,7 +12,6 @@
from torch_tensorrt._Device import Device
from torch_tensorrt._enums import Platform, dtype
from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
from torch_tensorrt.logging import TRT_LOGGER
from torch_tensorrt.runtime._utils import (
_is_switch_required,
@@ -23,6 +22,41 @@
logger = logging.getLogger(__name__)


class DynamicOutputAllocator(trt.IOutputAllocator): # type: ignore[misc]
def __init__(self, output_dtypes: Dict[str, torch.dtype]) -> None:
trt.IOutputAllocator.__init__(self)
self.buffers: Dict[str, torch.Tensor] = {}
self.shapes: Dict[str, Tuple[int, ...]] = {}
self.dtypes: Dict[str, torch.dtype] = output_dtypes

def reallocate_output_async(
self,
tensor_name: str,
memory: int,
size: int,
alignment: int,
stream: torch.cuda.Stream,
) -> Any:
shape = (size,)
if tensor_name not in self.buffers:
self.buffers[tensor_name] = torch.empty(
shape,
dtype=self.dtypes[tensor_name],
device=torch.cuda.current_device(),
)
else:
if self.buffers[tensor_name].shape != shape:
self.buffers[tensor_name] = torch.empty(
shape,
dtype=self.dtypes[tensor_name],
device=torch.cuda.current_device(),
)
return self.buffers[tensor_name].data_ptr()

def notify_shape(self, tensor_name: str, shape: Tuple[int, ...]) -> None:
self.shapes[tensor_name] = tuple(shape)


class TorchTRTRuntimeStates:
def __init__(self, new_cudagraphs: bool):
# Indicates whether CUDAGraphs were enabled in the previous execute_engine
@@ -128,7 +162,6 @@ def __init__(

self.name = name
self._input_buffers: List[torch.Tensor] = []
self._output_buffers: List[torch.Tensor] = []
self.cudagraph: Optional[torch.cuda.CUDAGraph] = None
self._caller_stream: Optional[torch.cuda.Stream] = None
self._engine_stream: Optional[torch.cuda.Stream] = None
@@ -147,6 +180,8 @@ def __init__(
self.output_names = (
output_binding_names if output_binding_names is not None else []
)
self.output_allocator: Optional[DynamicOutputAllocator] = None

self.initialized = False
self.target_device_id = (
settings.device.gpu_id
@@ -320,7 +355,7 @@ def setup_input_tensors(
# Clone is required to avoid re-using user-provided GPU memory
self._input_buffers[i] = contiguous_inputs[i].clone()

# For shape tensors, we use CPU pointers and for data tensors, we use GPU pointers
# For shape tensors, we use CPU pointers; for data tensors, we use GPU pointers
# as per TensorRT requirements
if self.engine.is_shape_inference_io(input_name):
# Shape tensor inputs are casted to int64 explicitly
@@ -342,18 +377,18 @@ def setup_input_tensors(
input_name, contiguous_inputs[i].data_ptr()
)

def create_output_tensors(self) -> List[torch.Tensor]:
# create output tensors
outputs: List[torch.Tensor] = []
def setup_output_allocator(self) -> None:
if self.output_allocator is None:
output_dtypes_dict = {}
for o, output_name in enumerate(self.output_names):
output_dtypes_dict[output_name] = self.output_dtypes[o]
self.output_allocator = DynamicOutputAllocator(output_dtypes_dict)

for o, _ in enumerate(self.output_names):
output = torch.empty(
size=self.output_shapes[o],
dtype=self.output_dtypes[o],
device=torch.cuda.current_device(),
)
outputs.append(output)
return outputs
for output_name in self.output_names:
if not self.context.set_output_allocator(
output_name, self.output_allocator
):
raise RuntimeError(f"Failed to set output allocator for {output_name}")

def set_pre_allocated_outputs(self, enable: bool) -> None:
self.use_pre_allocated_outputs = enable
@@ -387,7 +422,6 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .

if need_cudagraphs_record:
self._input_buffers = [None] * len(self.input_names)
self._output_buffers = [None] * len(self.output_names)

# If in safe mode, check at each iteration for whether a switch is required
if (
@@ -447,36 +481,12 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .

with (
torch.autograd.profiler.record_function(
"PythonTorchTensorRTModule:ProcessOutputs"
"PythonTorchTensorRTModule:ProcessOutputAllocators"
)
if self.profiling_enabled
else nullcontext()
):
if can_use_pre_allocated_outputs:
outputs = self.pre_allocated_outputs
else:
self.output_shapes = [
tuple(self.context.get_tensor_shape(output_name))
for output_name in self.output_names
]
if DYNAMIC_DIM in self.output_shapes:
raise ValueError(
"Encountered dynamic output shapes during runtime. This could mean the network has data-dependent output shapes which is not currently supported."
)
outputs = self.create_output_tensors()

for o, output_name in enumerate(self.output_names):
if need_cudagraphs_record:
self._output_buffers[o] = outputs[o].clone()

if cudagraphs_enabled:
self.context.set_tensor_address(
output_name, self._output_buffers[o].data_ptr()
)
else:
self.context.set_tensor_address(
output_name, outputs[o].data_ptr()
)
self.setup_output_allocator()

with (
torch.autograd.profiler.record_function(
@@ -495,6 +505,7 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
self._engine_stream.wait_stream(self._caller_stream)

with torch.cuda.stream(self._engine_stream):

if cudagraphs_enabled:
if need_cudagraphs_record:
self.cudagraph = torch.cuda.CUDAGraph()
@@ -507,7 +518,7 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
):
self.context.execute_async_v3(
self._engine_stream.cuda_stream
)
) # The OutputAllocator is called by execute_async_v3()

if self.profiling_enabled:
import tempfile
@@ -524,12 +535,23 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .

self._caller_stream.wait_stream(self._engine_stream)

if self.use_pre_allocated_outputs:
self.pre_allocated_outputs = self.create_output_tensors()

if cudagraphs_enabled:
for idx, o in enumerate(outputs):
o.copy_(self._output_buffers[idx])
with (
torch.autograd.profiler.record_function(
"PythonTorchTensorRTModule:ProcessOutputs"
)
if self.profiling_enabled
else nullcontext()
):
outputs = []
for o, output_name in enumerate(self.output_names):
assert self.output_allocator is not None
shape = self.output_allocator.shapes.get(output_name, None)
self.output_shapes[o] = shape
dtype = self.output_dtypes[o]
output = self.output_allocator.buffers.get(output_name, None).clone().detach()
prod = int(torch.prod(torch.tensor(shape)))
output = output.reshape(-1).view(dtype)[:prod].reshape(shape)
outputs.append(output)

if len(outputs) == 1:
return outputs[0]
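
For clarity on the output handling added above: the allocator hands back a flat, possibly over-allocated buffer, and the shape reported through notify_shape determines how many elements are real. A standalone sketch of the trim-and-reshape step with made-up sizes (assumes the buffer layout used by DynamicOutputAllocator):

import torch

raw = torch.arange(64, dtype=torch.float32)         # pretend: over-allocated flat buffer
true_shape = (3, 4)                                 # pretend: shape reported via notify_shape
numel = int(torch.prod(torch.tensor(true_shape)))
out = raw.reshape(-1)[:numel].reshape(true_shape)   # trim the padding, then view as the real shape
print(out.shape)                                    # torch.Size([3, 4])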
32 changes: 32 additions & 0 deletions tests/py/dynamo/conversion/test_nonzero_aten.py
@@ -0,0 +1,32 @@
import torch
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests

from .harness import DispatchTestCase


class TestNonZeroConverter(DispatchTestCase):
@parameterized.expand(
[
((10,), torch.int),
((1, 20), torch.int32),
((2, 3), torch.int64),
((2, 3, 4), torch.float),
((2, 3, 4, 5), torch.float),
]
)
    def test_non_zero(self, input_shape, dtype):
class NonZero(nn.Module):
def forward(self, input):
return torch.ops.aten.nonzero.default(input)

inputs = [torch.randint(low=0, high=3, size=input_shape, dtype=dtype)]
self.run_test(
NonZero(),
inputs,
)


if __name__ == "__main__":
run_tests()
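
A hedged end-to-end sketch of what this commit enables; it uses the public torch_tensorrt.compile frontend, and the exact compile options (e.g. min_block_size=1 to force conversion of the single-op graph) are assumptions:

import torch
import torch_tensorrt

class NonZero(torch.nn.Module):
    def forward(self, x):
        return torch.ops.aten.nonzero.default(x)

x = torch.randint(0, 3, (2, 3, 4), device="cuda")
trt_mod = torch_tensorrt.compile(NonZero().eval().cuda(), ir="dynamo", inputs=[x], min_block_size=1)
print(trt_mod(x).shape)  # row count depends on the values in x; DynamicOutputAllocator sizes it at run time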
