feat: second attempt to support DDS and NonZero op #3388

Open: wants to merge 4 commits into main
6 changes: 5 additions & 1 deletion py/torch_tensorrt/dynamo/_engine_cache.py
@@ -25,6 +25,7 @@
     Sequence[Input],
     CompilationSettings,
     Optional[Dict[str, Any]],
+    bool,
 ]


@@ -106,6 +107,7 @@ def pack(
         input_specs: Sequence[Input],
         compilation_settings: CompilationSettings,
         weight_name_map: Optional[Dict[Any, Any]],
+        engine_is_dds: bool,
     ) -> bytes:
         """Pack serialized engine, input names, output names, and weight map into a single blob

@@ -116,7 +118,7 @@
             input_specs (Sequence[Input]): input specs of TRT engine
             compilation_settings (CompilationSettings): compilation settings of TRT engine
             weight_name_map (Optional[Dict[Any, Any]]): weight name map for refitting
-
+            engine_is_dds (bool): whether the engine has data-dependent output shapes
         Returns:
             bytes: packed blob
         """
@@ -130,6 +132,7 @@
                 "input_specs": input_specs,
                 "compilation_settings": settings,
                 "weight_name_map": weight_name_map,
+                "engine_is_dds": engine_is_dds,
             }
         )

@@ -151,6 +154,7 @@ def unpack(packed_obj: bytes) -> UnpackedCacheHit:
             unpacked["input_specs"],
             unpacked["compilation_settings"],
             unpacked["weight_name_map"],
+            unpacked["engine_is_dds"],
         )

     def insert(
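For reference, a sketch of how the widened cache blob round-trips. The pack/unpack signatures follow the diff above; the class name BaseEngineCache and the artifact variables (serialized_engine, input_specs, settings, weight_map) are assumptions for illustration:

# Hypothetical round trip through the extended cache schema. Only the
# pack/unpack signatures come from the diff above; everything else is a stand-in.
blob = BaseEngineCache.pack(
    serialized_engine,  # bytes
    ["x"],              # input names
    ["output0"],        # output names
    input_specs,
    settings,
    weight_map,
    True,               # engine_is_dds: the new seventh field
)

(
    engine_bytes,
    input_names,
    output_names,
    specs,
    cached_settings,
    cached_weight_map,
    engine_is_dds,
) = BaseEngineCache.unpack(blob)
assert engine_is_dds  # the DDS flag survives a cache hit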
24 changes: 23 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -33,7 +33,9 @@
 from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
     DYNAMO_CONVERTERS as CONVERTERS,
 )
-from torch_tensorrt.dynamo.conversion._ConverterRegistry import CallingConvention
+from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
+    CallingConvention,
+)
 from torch_tensorrt.dynamo.conversion._TRTBuilderMonitor import TRTBulderMonitor
 from torch_tensorrt.dynamo.conversion.converter_utils import (
     get_node_io,
@@ -62,6 +64,7 @@ class TRTInterpreterResult(NamedTuple):
     input_names: Sequence[str]
     output_names: Sequence[str]
     weight_name_map: Optional[dict[Any, Any]]
+    engine_is_dds: bool


 class TRTInterpreter(torch.fx.Interpreter):  # type: ignore[misc]
@@ -136,6 +139,9 @@ def __init__(
         # Engine cache for storing and reusing TRT engines
         self.engine_cache = engine_cache

+        # Whether the engine has data-dependent shapes (DDS)
+        self.engine_is_dds: bool = False
+
     def validate_conversion(self) -> Set[str]:
         missing_converters: Set[str] = set()

@@ -575,6 +581,7 @@ def _insert_engine_to_cache(self, hash_val: str, serialized_engine: bytes) -> None:
                 self.input_specs,
                 self.compilation_settings,
                 self.weight_name_map,
+                self.engine_is_dds,
             ),
         )

@@ -589,6 +596,7 @@ def _pull_cached_engine(self, hash_val: str) -> Optional[TRTInterpreterResult]:
                 cached_engine_input_specs,
                 engine_compilation_settings,
                 self.weight_name_map,
+                self.engine_is_dds,
             ) = cached_data

             setting_compatiblity, incompattible_settings = settings_are_compatible(
@@ -650,9 +658,20 @@
                 self._input_names,
                 self._output_names,
                 self.weight_name_map,
+                self.engine_is_dds,
             )
         return None

+    def check_dds(self, serialized_engine: bytes, output_names: List[str]) -> bool:
+        runtime = trt.Runtime(TRT_LOGGER)
+        engine = runtime.deserialize_cuda_engine(serialized_engine)
+
+        for output_name in output_names:
+            output_shape = engine.get_tensor_shape(output_name)
+            if -1 in output_shape:
+                return True
+        return False
+
     def run(
         self,
         strict_type_constraints: bool = False,
@@ -709,6 +728,8 @@ def run(
             )
             assert serialized_engine

+            self.engine_is_dds = self.check_dds(serialized_engine, self._output_names)
+
             _LOGGER.info(
                 f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}"
             )
@@ -735,6 +756,7 @@
             self._input_names,
             self._output_names,
             self.weight_name_map,
+            self.engine_is_dds,
         )

     def run_node(self, n: torch.fx.Node) -> torch.fx.Node:
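check_dds relies on a TensorRT property: once an engine is built, any output dimension that is only determined at runtime is reported as -1 by get_tensor_shape. A self-contained sketch of the same check (the function name engine_has_dds_outputs is illustrative; serialized_engine and output_names are assumed to come from a prior build):

import tensorrt as trt

def engine_has_dds_outputs(serialized_engine: bytes, output_names: list[str]) -> bool:
    # Deserialize and scan each output: a -1 dimension means the shape is
    # data-dependent and only materializes at execution time.
    logger = trt.Logger(trt.Logger.WARNING)
    runtime = trt.Runtime(logger)
    engine = runtime.deserialize_cuda_engine(serialized_engine)
    return any(-1 in engine.get_tensor_shape(name) for name in output_names)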
3 changes: 2 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/_conversion.py
@@ -30,7 +30,7 @@ def infer_module_output_dtypes(
     """
     outputs = [node for node in module.graph.nodes if node.op == "output"]
     outputs = outputs[0].args
-    return get_output_dtypes(outputs, truncate_double)
+    return get_output_dtypes(outputs, truncate_double)  # type: ignore[no-any-return]


 def interpret_module_to_result(
@@ -112,4 +112,5 @@ def convert_module(
         name=name,
         settings=settings,
         weight_name_map=interpreter_result.weight_name_map,
+        engine_is_dds=interpreter_result.engine_is_dds,
     )
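Threading engine_is_dds through convert_module matters at runtime because a DDS engine's output buffers cannot be preallocated; the shapes only materialize during execution. TensorRT exposes an IOutputAllocator callback for exactly this case. The sketch below illustrates that mechanism and is not the runtime-module code from this PR:

from typing import Optional, Tuple

import tensorrt as trt
import torch

class TorchOutputAllocator(trt.IOutputAllocator):  # type: ignore[misc]
    """Grows a torch-owned buffer on demand for one data-dependent output."""

    def __init__(self) -> None:
        super().__init__()
        self.buffer: Optional[torch.Tensor] = None
        self.shape: Optional[Tuple[int, ...]] = None

    def reallocate_output(
        self, tensor_name: str, memory: int, size: int, alignment: int
    ) -> int:
        # TensorRT calls this once the real output byte size is known.
        self.buffer = torch.empty(size, dtype=torch.uint8, device="cuda")
        return self.buffer.data_ptr()

    def notify_shape(self, tensor_name: str, shape: trt.Dims) -> None:
        # TensorRT reports the concrete output shape here.
        self.shape = tuple(shape)

# Usage sketch: context.set_output_allocator("output0", allocator) before
# enqueueing, then reinterpret allocator.buffer using allocator.shape.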
17 changes: 17 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -3582,3 +3582,20 @@ def aten_ops_full(
         fill_value=args[1],
         dtype=kwargs.get("dtype", None),
     )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.nonzero.default)
+def aten_ops_nonzero(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.unary.nonzero(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+    )
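With the converter registered, a graph containing aten.nonzero can stay inside TensorRT instead of falling back to PyTorch. A minimal end-to-end example, assuming a Torch-TensorRT build that includes this PR:

import torch
import torch_tensorrt

class NonZeroModel(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Output shape is [num_nonzero, rank], known only at runtime,
        # so the resulting engine is data-dependent (DDS).
        return torch.nonzero(x)

model = NonZeroModel().eval().cuda()
inputs = [torch.randn(8, 8).cuda()]
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs, min_block_size=1)
print(trt_model(*inputs).shape)  # e.g. torch.Size([64, 2]); varies with input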
15 changes: 15 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py
@@ -624,3 +624,18 @@ def native_dropout(
     mask = np.ones(input_val.shape, dtype=bool)
     mask = get_trt_tensor(ctx, mask, f"{name}_mask")
     return identity_layer.get_output(0), mask
+
+
+def nonzero(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input_val: TRTTensor,
+) -> TRTTensor:
+    non_zero_layer = ctx.net.add_non_zero(input_val)
+    set_layer_name(non_zero_layer, target, f"{name}_non_zero", source_ir)
+    shuffle_layer = ctx.net.add_shuffle(non_zero_layer.get_output(0))
+    shuffle_layer.first_transpose = trt.Permutation([1, 0])
+    set_layer_name(shuffle_layer, target, f"{name}_transpose", source_ir)
+    return shuffle_layer.get_output(0)
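The shuffle layer compensates for a layout mismatch: TensorRT's INonZeroLayer emits indices with shape [rank, num_nonzero], while aten.nonzero returns [num_nonzero, rank]. A quick CPU illustration of the two layouts:

import torch

x = torch.tensor([[1, 0], [0, 3]])
aten_style = torch.nonzero(x)           # [[0, 0], [1, 1]]: one row per hit
trt_style = aten_style.transpose(0, 1)  # [[0, 1], [0, 1]]: one row per axis
# The trt.Permutation([1, 0]) in nonzero() above converts the second layout
# into the first, matching PyTorch's convention.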