
scatter: adding test cases for scatter.value and scatter.src
apbose committed Mar 6, 2024
1 parent f48d9f7 commit cc07bb5
Showing 4 changed files with 210 additions and 39 deletions.
4 changes: 2 additions & 2 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -706,7 +706,7 @@ def aten_ops_scatter_value(
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.select.scatter_value(
ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2]
ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2], args[3]
)


@@ -719,7 +719,7 @@ def aten_ops_scatter_src(
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.select.scatter_src(
ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2]
ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2], args[3]
)


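For context, the extra args[3] these registrations now forward is the fourth positional argument of the aten overloads being lowered: the scalar fill value for scatter.value and the source tensor for scatter.src. A minimal PyTorch-level sketch (illustrative only, not part of the commit):

import torch

x = torch.zeros(3, 5)
index = torch.tensor([[0, 1, 2, 0, 1]])

# scatter.value: (input, dim, index, value) with a Python scalar
out_value = torch.ops.aten.scatter.value(x, 0, index, 7.0)

# scatter.src: (input, dim, index, src) with a tensor of source values
src = torch.arange(1.0, 6.0).reshape(1, 5)
out_src = torch.ops.aten.scatter.src(x, 0, index, src)
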
49 changes: 35 additions & 14 deletions py/torch_tensorrt/dynamo/conversion/impl/select.py
@@ -9,6 +9,7 @@
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import (
broadcastable,
cast_trt_tensor,
get_positive_dim,
get_trt_tensor,
to_numpy,
@@ -20,6 +21,7 @@
set_layer_name,
)
from torch_tensorrt.fx.types import Shape, TRTTensor
from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter

_LOGGER: logging.Logger = logging.getLogger(__name__)

@@ -378,8 +380,8 @@ def scatter_value(
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
dim: Shape,
index: Shape,
dim: int,
index: Union[TRTTensor, np.ndarray, torch.Tensor],
value: float,
) -> TRTTensor:
if not isinstance(input, TRTTensor):
@@ -389,26 +391,34 @@
)
input_shape = input.shape
index_shape = index.shape
index_shape_list = list(index.shape)
if not (isinstance(index, TRTTensor)):
index = get_trt_tensor(ctx, index, name + "_index_tensor")
if len(input_shape) != len(index_shape):
raise RuntimeError(f"The no of dimensions of input and index should be equal")
ranks = len(input_shape)
dim = get_positive_dim(cast(int, dim), ranks)
dim = get_positive_dim(dim, len(input_shape))
dynamic_shape = has_dynamic_shape(input.shape)
if dynamic_shape:
# Check whether slice target dim is dynamic shape dim
assert input.shape[dim] != -1, "Can't scatter on negative shape dimension!"

input_dims = len(input.shape)
input_dims = len(input_shape)
for i in range(0, input_dims):
if index[i] >= input.shape[i]:
if i != dim and (index_shape[i] > input.shape[i]):
raise RuntimeError(
f"cannot have index greater than the dimension length! {input.shape[dim]}"
f"cannot have index size greater than the input size along dimension {dim}"
)
value_tensor = value * torch.ones(index.shape)

value_tensor = get_trt_tensor(
ctx, value * torch.ones(index_shape_list), name + "_value_tensor"
)
value_tensor = cast_trt_tensor(
ctx, value_tensor, input.dtype, name + "_cast_value_tensor"
)
scatter_layer = ctx.net.add_scatter(
input, index, value_tensor, trt.tensorrt.ScatterModekELEMENT
input, index, value_tensor, trt.ScatterMode.ELEMENT
)
scatter_layer.set_axis(dim)
scatter_layer.axis = dim
set_layer_name(scatter_layer, target, name + "_scatter_layer", source_ir)
out = scatter_layer.get_output(0)
return out
@@ -432,6 +442,8 @@ def scatter_src(
input_shape = input.shape
index_shape = index.shape
src_shape = src.shape
if not (isinstance(index, TRTTensor)):
index = get_trt_tensor(ctx, index, name + "_index_tensor")
if len(input_shape) != len(index_shape):
raise RuntimeError(f"The no of dimensions of input and index should be equal")
if len(index_shape) != len(src_shape):
@@ -445,14 +457,23 @@
assert input.shape[dim] != -1, "Can't scatter on negative shape dimension!"

for i in range(0, input_dims):
if index[i] >= input.shape[i]:
if i != dim and (index_shape[i] > input.shape[i]):
raise RuntimeError(
f"cannot have index greater than the dimension length! {input.shape[dim]}"
f"cannot have index size greater than the input size along dimension {dim}"
)
input_dtype = input.dtype
# required for cases where src is a constant
src_dtype = unified_dtype_converter(src.dtype, Frameworks.TRT)
if input_dtype != src_dtype:
raise RuntimeError(f"The type of input and src should be made")
src_tensor = src
if not (isinstance(src, TRTTensor)):
src_tensor = get_trt_tensor(ctx, src, name + "_src_tensor")

scatter_layer = ctx.net.add_scatter(
input, index, src, trt.tensorrt.ScatterModekELEMENT
input, index, src_tensor, trt.ScatterMode.ELEMENT
)
scatter_layer.set_axis(dim)
scatter_layer.axis = dim
set_layer_name(scatter_layer, target, name + "_scatter_layer", source_ir)
out = scatter_layer.get_output(0)
return out
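
Both converters map to TensorRT's element-wise scatter (trt.ScatterMode.ELEMENT with the layer axis set to dim); for dim == 0 the semantics are out[index[i][j]][j] = updates[i][j], where updates is either value * ones(index.shape) cast to the input dtype (scatter_value) or the src tensor (scatter_src). A small worked example in plain PyTorch, illustrative only:

import torch

x = torch.zeros(3, 4)
index = torch.tensor([[2, 0, 1, 0]])

# scatter_value path: the scalar is broadcast over index's shape
print(x.scatter(0, index, 5.0))
# tensor([[0., 5., 0., 5.],
#         [0., 0., 5., 0.],
#         [5., 0., 0., 0.]])

# scatter_src path: updates come from an explicit src tensor
src = torch.tensor([[10., 20., 30., 40.]])
print(x.scatter(0, index, src))
# tensor([[ 0., 20.,  0., 40.],
#         [ 0.,  0., 30.,  0.],
#         [10.,  0.,  0.,  0.]])
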
34 changes: 30 additions & 4 deletions tests/py/dynamo/conversion/harness.py
@@ -1,3 +1,4 @@
import copy
import logging
import time
import unittest
@@ -10,6 +11,9 @@

# Use interpreter, input spec, and test case from fx_ts_compat to test Dynamo Converter Registry
from torch_tensorrt.dynamo.conversion import TRTInterpreter
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
DYNAMO_CONVERTERS as CONVERTERS,
)
from torch_tensorrt.dynamo.lowering import apply_lowering_passes
from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule

@@ -46,15 +50,17 @@ def setUp(self):
def run_test(
self,
mod,
inputs,
fx_inputs,
trt_interpreter_inputs,
interpreter,
rtol,
atol,
check_dtype=True,
):
with torch.no_grad():
cuda_inputs = []
for i in inputs:
cuda_fx_inputs = []
for i in trt_interpreter_inputs:
cuda_inputs.append(i.cuda())

mod.eval()
@@ -68,7 +74,7 @@ def run_test(
interpreter_result.output_names,
)

ref_outputs = mod(*inputs)
ref_outputs = mod(*fx_inputs)

torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
@@ -237,15 +243,35 @@ def run_test(
precision=precision, truncate_long_and_double=True
)

num_inputs = len(inputs)
trt_inputs = inputs
for num_input in range(num_inputs):
input = inputs[num_input]
if input.dtype in (torch.int64, torch.float64):
dtype_32bit = (
torch.int32 if (input.dtype == torch.int64) else torch.float32
)
# should we modify graph here to insert clone nodes?
# ideally not required
trt_inputs = (
list(trt_inputs[:num_input])
+ [
input.to(dtype_32bit),
]
+ list(trt_inputs[num_input + 1 :])
)

interp = TRTInterpreter(
mod,
Input.from_tensors(inputs),
Input.from_tensors(trt_inputs),
output_dtypes=output_dtypes,
compilation_settings=compilation_settings,
)

super().run_test(
mod,
inputs,
trt_inputs,
interp,
rtol,
atol,
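
The new trt_inputs remapping keeps the FX reference module running on the original tensors while the TRT interpreter receives 32-bit equivalents, consistent with the interpreter being built with truncate_long_and_double=True. A standalone sketch of that remapping (hypothetical helper name, not the committed code):

import torch

def downcast_64bit_inputs(inputs):
    # int64 -> int32 and float64 -> float32; everything else passes through
    remapped = []
    for t in inputs:
        if t.dtype == torch.int64:
            remapped.append(t.to(torch.int32))
        elif t.dtype == torch.float64:
            remapped.append(t.to(torch.float32))
        else:
            remapped.append(t)
    return remapped

fx_inputs = [torch.zeros(3, 5), torch.randint(0, 3, (1, 5), dtype=torch.int64)]
trt_inputs = downcast_64bit_inputs(fx_inputs)  # the index tensor becomes int32
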
(The remaining changed file is not expanded in this view.)
