Aten::Index converter #2277
Conversation
Code conforms to C++ style guidelines
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/select.py 2023-08-29 21:00:52.894681+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/select.py 2023-08-29 21:03:31.334736+00:00
@@ -69,28 +69,28 @@
network: TRTNetwork,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
- index: Union[TRTTensor, Sequence[TRTTensor]]
+ index: Union[TRTTensor, Sequence[TRTTensor]],
) -> TRTTensor:
adv_indx_indices = []
tensor_indices = []
for i in len(index):
ind = index[i]
- #FIXME: check if the datatype for the indices needs to be casted to INT32
- #TRTInterpretor should take care
+ # FIXME: check if the datatype for the indices needs to be casted to INT32
+ # TRTInterpretor should take care
adv_indx_indices.append(i)
tensor_indices.append(ind)
if not tensor_indices:
identity_layer = network.add_identity(input)
identity_layer.set_output_type(0, trt.int32)
set_layer_name(identity_layer, target, name + "_index_identity", source_ir)
return identity_layer.get_output(0)
- elif (len(tensor_indices) == 1):
+ elif len(tensor_indices) == 1:
indices_tensor = tensor_indices[0]
gather_layer = network.add_gather(input, indices_tensor, adv_indx_indices[0])
set_layer_name(gather_layer, target, name + "_index_gather", source_ir)
return gather_layer.get_output(0)
else:
@@ -99,7 +99,5 @@
adv_indx_count = len(adv_indx_indices)
input_shape_layer = network.add_shape(input)
set_layer_name(input_shape_layer, target, name + "_index_shape", source_ir)
input_shape_tensor = input_shape_layer.get_output(0)
return input_shape_tensor.get_output(0)
-
-
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py 2023-08-29 21:00:52.894681+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py 2023-08-29 21:03:31.387359+00:00
@@ -169,11 +169,11 @@
)
@dynamo_tensorrt_converter(torch.ops.aten.index.Tensor)
def aten_ops_index(
-network: TRTNetwork,
+ network: TRTNetwork,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
71ca151 to 302b962 (Compare)
Code conforms to C++ style guidelines
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py 2023-09-01 17:57:45.620889+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py 2023-09-01 18:00:04.480860+00:00
@@ -169,11 +169,11 @@
)
@dynamo_tensorrt_converter(torch.ops.aten.index.Tensor)
def aten_ops_index(
-network: TRTNetwork,
+ network: TRTNetwork,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/select.py 2023-09-01 17:57:45.620889+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/select.py 2023-09-01 18:00:04.530477+00:00
@@ -71,28 +71,28 @@
network: TRTNetwork,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
- index: Union[TRTTensor, Sequence[TRTTensor]]
+ index: Union[TRTTensor, Sequence[TRTTensor]],
) -> TRTTensor:
adv_indx_indices = []
tensor_indices = []
for i in len(index):
ind = index[i]
- #FIXME: check if the datatype for the indices needs to be casted to INT32
- #TRTInterpretor should take care
+ # FIXME: check if the datatype for the indices needs to be casted to INT32
+ # TRTInterpretor should take care
adv_indx_indices.append(i)
tensor_indices.append(ind)
if not tensor_indices:
identity_layer = network.add_identity(input)
identity_layer.set_output_type(0, trt.int32)
set_layer_name(identity_layer, target, name + "_index_identity", source_ir)
return identity_layer.get_output(0)
- elif (len(tensor_indices) == 1):
+ elif len(tensor_indices) == 1:
indices_tensor = tensor_indices[0]
gather_layer = network.add_gather(input, indices_tensor, adv_indx_indices[0])
set_layer_name(gather_layer, target, name + "_index_gather", source_ir)
return gather_layer.get_output(0)
else:
@@ -102,24 +102,26 @@
input_shape_layer = network.add_shape(input)
set_layer_name(input_shape_layer, target, name + "_index_shape", source_ir)
input_shape_tensor = input_shape_layer.get_output(0)
dim_tensor_list = []
for i in range(rank):
- #check this
- dim_tensor_layer = network.add_gather(input_shape_tensor, i ,0)
- set_layer_name(input_shape_layer, target, name + "_index_gather_rank", source_ir)
+ # check this
+ dim_tensor_layer = network.add_gather(input_shape_tensor, i, 0)
+ set_layer_name(
+ input_shape_layer, target, name + "_index_gather_rank", source_ir
+ )
dim_tensor = dim_tensor_layer.get_output(0)
dim_tensor_list.append(dim_tensor)
- #for cases like
- #t: [x_1, y_1, y_2, ..., x_m, ..., y_n] -> t: [x_1, x_2, ..., x_m, y_1, y_2, ..., y_n],
- #where t is a tensor of rank m+n, {x_i} are axes where tensor index is provided, and {y_i} are axes
- #for ":"
- #Examples: x.shape = (10,20,30,40,50)
- #ind_1, ind_2 broadcasted to (2,3,4)
- #x[:, ind_1, ind_2] = 10, 2, 3, 4, 40, 50
- #x[:,ind_1, :, ind_2] = 2, 3, 4, 10, 30, 50
+ # for cases like
+ # t: [x_1, y_1, y_2, ..., x_m, ..., y_n] -> t: [x_1, x_2, ..., x_m, y_1, y_2, ..., y_n],
+ # where t is a tensor of rank m+n, {x_i} are axes where tensor index is provided, and {y_i} are axes
+ # for ":"
+ # Examples: x.shape = (10,20,30,40,50)
+ # ind_1, ind_2 broadcasted to (2,3,4)
+ # x[:, ind_1, ind_2] = 10, 2, 3, 4, 40, 50
+ # x[:,ind_1, :, ind_2] = 2, 3, 4, 10, 30, 50
transpose_layer = network.add_shuffle(input)
new_order = []
for i in range(adv_indx_count):
new_order.append(adv_indx_indices[i])
for i in range(rank):
@@ -130,166 +132,194 @@
permute_order(new_order)
transpose_layer.set_second_transpose(permute_order)
set_layer_name(transpose_layer, target, name + "_index_transpose", source_ir)
transpose_tensor = transpose_layer.get_output(0)
- #Flatten [x_1, x_2,.......x_m, y_1, y_2,.....y_m]
+ # Flatten [x_1, x_2,.......x_m, y_1, y_2,.....y_m]
transpose_tensor_shape = network.add_shape(transpose_tensor)
d0 = 1
d0 = get_trt_tensor(network, d0, "d0_initial")
for i in range(adv_indx_count):
dim_tensor_layer = network.add_gather(transpose_tensor_shape, i, 0)
- set_layer_name(dim_tensor_layer, target, name + "_index_gather_concatOne", source_ir)
+ set_layer_name(
+ dim_tensor_layer, target, name + "_index_gather_concatOne", source_ir
+ )
d0_gather = gather_layer.get_output(0)
mult_d0 = convert_binary_elementwise(
- network,
- target,
- source_ir,
- name + "index_concatOne_shape",
- trt.ElementWisePROD,
- mult_d0,
- d0_gather,
- )
-
+ network,
+ target,
+ source_ir,
+ name + "index_concatOne_shape",
+ trt.ElementWisePROD,
+ mult_d0,
+ d0_gather,
+ )
+
d1 = 1
d1 = get_trt_tensor(network, d0, "d0_initial")
for i in range(adv_indx_count, rank):
dim_tensor_layer = network.add_gather(transpose_tensor_shape, i, 0)
- set_layer_name(dim_tensor_layer, target, name + "_index_gather_concatTwo", source_ir)
+ set_layer_name(
+ dim_tensor_layer, target, name + "_index_gather_concatTwo", source_ir
+ )
d1_gather = gather_layer.get_output(0)
mult_d1 = convert_binary_elementwise(
- network,
- target,
- source_ir,
- name + "index_concatTwo_shape",
- trt.ElementWisePROD,
+ network,
+ target,
+ source_ir,
+ name + "index_concatTwo_shape",
+ trt.ElementWisePROD,
mult_d1,
d1_gather,
)
concat_tensor_layer = network.add_concatenation([mult_d0, mult_d1])
set_layer_name(concat_tensor_layer, target, name + "_index_Concat", source_ir)
concat_tensor = concat_tensor_layer.get_output(0)
reshape_layer = network.add_shuffle(transpose_tensor)
- #check this
+ # check this
reshape_layer.set_input(1, concat_tensor)
flatten_tensor = reshape_layer.get_output(0)
- #tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j)), ind_i is input indices[i], x_j is the
- #// j dimension of input x.
- multiplier = get_trt_tensor(network, dim_tensor_list[adv_indx_indices[adv_indx_count - 1]], "dim_last")
+ # tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j)), ind_i is input indices[i], x_j is the
+ # // j dimension of input x.
+ multiplier = get_trt_tensor(
+ network, dim_tensor_list[adv_indx_indices[adv_indx_count - 1]], "dim_last"
+ )
cum_adv_index = tensor_indices[adv_indx_count - 1]
- for i in range(adv_indx_count-2, 0):
+ for i in range(adv_indx_count - 2, 0):
adv_index = convert_binary_elementwise(
- network,
- target,
- source_ir,
- name + "index_intermediate",
- trt.ElementWisePROD,
+ network,
+ target,
+ source_ir,
+ name + "index_intermediate",
+ trt.ElementWisePROD,
multiplier,
tensor_indices[i],
)
cum_adv_index = convert_binary_elementwise(
- network,
- target,
- source_ir,
- name + "index_sum_intermediate",
- trt.ElementWiseSUM,
+ network,
+ target,
+ source_ir,
+ name + "index_sum_intermediate",
+ trt.ElementWiseSUM,
cum_adv_index,
adv_index,
)
multiplier = convert_binary_elementwise(
- network,
- target,
- source_ir,
- name + "index_intermediate",
- trt.ElementWisePROD,
+ network,
+ target,
+ source_ir,
+ name + "index_intermediate",
+ trt.ElementWisePROD,
multiplier,
dim_tensor_list[adv_indx_count[i]],
)
gather_layer_element = network.add_gather(flatten_tensor, cum_adv_index, 0)
- set_layer_name(gather_layer_element, target, name + "_index_gather_element", source_ir)
+ set_layer_name(
+ gather_layer_element, target, name + "_index_gather_element", source_ir
+ )
gather_out = gather_layer.get_output(0)
cum_adv_index_shape_tensor = cum_adv_index.add_shape(cum_adv_index_shape_tensor)
- #check if all advanced indices are consecutive
+ # check if all advanced indices are consecutive
concat_tensor_reshape = []
- if(adv_indx_count == adv_indx_indices[adv_indx_count - 1] - adv_indx_indices[0] + 1):
- #concat_tensor_reshape_initial = -1
- #concat_tensor_reshape_initial_tensor = get_trt_tensor(network, concat_tensor_reshape_initial, "concat_tensor_reshape_initial")
+ if (
+ adv_indx_count
+ == adv_indx_indices[adv_indx_count - 1] - adv_indx_indices[0] + 1
+ ):
+ # concat_tensor_reshape_initial = -1
+ # concat_tensor_reshape_initial_tensor = get_trt_tensor(network, concat_tensor_reshape_initial, "concat_tensor_reshape_initial")
concat_tensor_reshape.append(-1)
for i in range(0, rank):
if i not in adv_indx_indices:
curr_dim = dim_tensor_list[i]
concat_tensor_reshape.append(curr_dim)
-
+
concat_tensor_layer = network.add_concatenation(concat_tensor_reshape)
- set_layer_name(concat_tensor_layer, target, name + "_index_Concat_reshape", source_ir)
+ set_layer_name(
+ concat_tensor_layer, target, name + "_index_Concat_reshape", source_ir
+ )
concat_tensor = concat_tensor_layer.get_output(0)
regular_index_shuffle_layer = network.add_shuffle(gather_out)
- set_layer_name(regular_index_shuffle_layer, target, name + "_index_regular_index", source_ir)
+ set_layer_name(
+ regular_index_shuffle_layer,
+ target,
+ name + "_index_regular_index",
+ source_ir,
+ )
unfold_tensor = regular_index_shuffle_layer.get_output(0)
transpose_advanced_shuffle_layer = network.add_shuffle(unfold_tensor)
new_order = []
- for i in range(1, adv_indx_count[0]+1):
+ for i in range(1, adv_indx_count[0] + 1):
new_order.append(i)
new_order.append(0)
- for i in range(adv_indx_indices[0]+1, rank - adv_indx_count):
+ for i in range(adv_indx_indices[0] + 1, rank - adv_indx_count):
new_order.append(i)
permute_order = trt.Permutation()
permute_order(new_order)
transpose_advanced_shuffle_layer.set_second_transpose(permute_order)
- set_layer_name(transpose_advanced_shuffle_layer, target, name + "_index_advanced_shuffle_transpose", source_ir)
+ set_layer_name(
+ transpose_advanced_shuffle_layer,
+ target,
+ name + "_index_advanced_shuffle_transpose",
+ source_ir,
+ )
transpose_tensor = transpose_advanced_shuffle_layer.get_output(0)
- #unfold advanced layer
+ # unfold advanced layer
concat_final_tensor = []
for i in range(0, adv_indx_indices[0]):
current_dim = dim_tensor_list[i]
concat_final_tensor.push_back(curr_dim)
concat_final_tensor.push_back(cum_adv_index_shape_tensor)
for i in range(adv_indx_indices[0], rank):
- if(i not in (adv_indx_indices)):
+ if i not in (adv_indx_indices):
current_dim = dim_tensor_list[i]
concat_final_tensor.append(current_dim)
-
+
concat_final_shape_layer = network.add_concatenation(concat_final_tensor)
- set_layer_name(concat_final_shape_layer, target, name + "_index_concat_final_shape_layer", source_ir)
+ set_layer_name(
+ concat_final_shape_layer,
+ target,
+ name + "_index_concat_final_shape_layer",
+ source_ir,
+ )
concat_final_tensor = concat_final_shape_layer.get_output(0)
unfold_advanced_shuffle_layer = network.add_shuffle(transpose_tensor)
- #check this
+ # check this
reshape_layer.set_input(1, concat_final_tensor)
reshape_output = reshape_layer.get_output(0)
-
+
else:
- concat_tensor= []
+ concat_tensor = []
for i in range(0, rank):
if i not in adv_indx_indices:
curr_dim = dim_tensor_list[i]
concat_tensor.append(curr_dim)
-
+
concat_layer = network.add_concatenation(concat_tensor)
- set_layer_name(concat_layer, target, name + "_index_concat_final_shape_layer", source_ir)
+ set_layer_name(
+ concat_layer,
+ target,
+ name + "_index_concat_final_shape_layer",
+ source_ir,
+ )
concat_final_tensor = concat_final_shape_layer.get_output(0)
reshape_layer = network.add_shuffle(gather_out)
reshape_layer.setInput(1, concat_final_tensor)
- set_layer_name(reshape_layer, target, name + "_index_shuffle_final_shape_layer", source_ir)
+ set_layer_name(
+ reshape_layer,
+ target,
+ name + "_index_shuffle_final_shape_layer",
+ source_ir,
+ )
reshape_output = reshape_layer.get_output(0)
return reshape_output
-
-
-
-
-
-
-
-
-
-
4f2a738 to 42798cc (Compare)
Code conforms to C++ style guidelines
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_index_aten.py 2023-09-07 03:07:48.137309+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_index_aten.py 2023-09-07 03:10:36.917639+00:00
@@ -2,10 +2,11 @@
import torch.nn as nn
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input
from .harness import DispatchTestCase
+
class TestIndexConverter(DispatchTestCase):
def test_index(self):
class TestModule(nn.Module):
def forward(self, x):
@@ -13,6 +14,6 @@
index0 = torch.randint(0, 16, (1, 16))
index1 = torch.randint(0, 16, (1, 16))
out = torch.ops.aten.index(None, None, index0, index1)
inputs = [torch.randn(1, 10)]
- self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.index.Tensor})
\ No newline at end of file
+ self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.index.Tensor})
42798cc to 6b186b6 (Compare)
Code conforms to C++ style guidelines
Code conforms to Python style guidelines
Overall looks good! Added some style/naming comments, and will run tests on a model that uses this layer to verify.
for i in range(0, len(index)):
    ind = index[i]
Consider rewriting as: for i, ind in enumerate(index)
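As a minimal sketch of that suggestion (names taken from the diff above; assumes index is a plain Python sequence of index tensors):

# Hypothetical rewrite of the loop from the diff; enumerate() avoids the manual
# range()/len() bookkeeping (and the broken `for i in len(index)` variant above).
for i, ind in enumerate(index):
    adv_indx_indices.append(i)
    tensor_indices.append(ind)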
permute_order = trt.Permutation(new_order)
transpose_layer.second_transpose = permute_order
Can shorten to transpose_layer.second_transpose = tuple(new_order)
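For illustration, the shortened form might look like this (a sketch; transpose_layer is the shuffle layer created earlier in the diff):

transpose_layer = network.add_shuffle(input)
# second_transpose accepts a plain tuple, so the intermediate trt.Permutation
# object is not needed.
transpose_layer.second_transpose = tuple(new_order)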
target,
source_ir,
name + "index_intermediate",
trt.ElementWisePROD,
Should be trt.ElementWiseOperation.PROD
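For reference, a hedged sketch of the corrected call (operands follow the surrounding diff; PROD and SUM are members of trt.ElementWiseOperation rather than attributes of the trt module, and the SUM case below follows the same pattern):

import tensorrt as trt

# Same call as in the diff, with the elementwise op spelled as an
# ElementWiseOperation member.
adv_index = convert_binary_elementwise(
    network,
    target,
    source_ir,
    name + "index_intermediate",
    trt.ElementWiseOperation.PROD,
    multiplier,
    tensor_indices[i],
)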
target,
source_ir,
name + "index_sum_intermediate",
trt.ElementWiseSUM,
trt.ElementWiseOperation.SUM
target,
source_ir,
name + "index_intermediate",
trt.ElementWisePROD,
trt.ElementWiseOperation.PROD
permute_order = trt.Permutation(new_order)
transpose_advanced_shuffle_layer.second_transpose = permute_order
permute_order could be replaced with tuple(new_order)
) -> TRTTensor:
    adv_indx_indices = []
    tensor_indices = []
    _LOGGER.debug(f"The index shape is", index.shape)
index.shape is not valid, since index could be a Python list
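One way to address this, sketched under the assumption that index can arrive either as a single tensor or as a Python list of tensors (hypothetical, not the final code):

# A plain list has no .shape, so branch on the type before logging.
# (The original f-string also never interpolated index.shape.)
if isinstance(index, (list, tuple)):
    _LOGGER.debug("Received %d index tensors", len(index))
else:
    _LOGGER.debug("The index shape is %s", index.shape)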
Code conforms to Python style guidelines
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/select.py 2023-09-25 20:21:10.951934+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/select.py 2023-09-25 20:24:07.893237+00:00
@@ -94,11 +94,13 @@
_LOGGER.debug(f"Shape of {i} index is {ind.shape}")
adv_indx_indices.append(i)
# torch.nn.parameter.Parameter=> torch.Tensor
ind = get_trt_tensor(network, ind, name + f"_parameter_to_fp32_tensor_{i}")
if last_index is not None:
- assert broadcastable(ind, last_index), "The indices should be broadcastable!"
+ assert broadcastable(
+ ind, last_index
+ ), "The indices should be broadcastable!"
last_index = ind
tensor_indices.append(ind)
if not tensor_indices:
identity_layer = network.add_identity(input)
@@ -177,11 +179,13 @@
_LOGGER.debug(f"The flatten tensor shape is {flatten_tensor.shape}")
# tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j)), ind_i is input indices[i], x_j is the
# // j dimension of input x.
multiplier = get_trt_tensor(
- network, dim_tensor_list[adv_indx_indices[adv_indx_count - 1]], name + "dim_last"
+ network,
+ dim_tensor_list[adv_indx_indices[adv_indx_count - 1]],
+ name + "dim_last",
)
cum_adv_index = tensor_indices[adv_indx_count - 1]
for i in range(adv_indx_count - 2, -1, -1):
adv_index = convert_binary_elementwise(
network,
@@ -231,11 +235,13 @@
if (
adv_indx_count
== adv_indx_indices[adv_indx_count - 1] - adv_indx_indices[0] + 1
):
_LOGGER.debug(f"The indices are continuous in this case")
- concat_tensor_reshape.append(get_trt_tensor(network, -1, name + "dynamic_concat"))
+ concat_tensor_reshape.append(
+ get_trt_tensor(network, -1, name + "dynamic_concat")
+ )
for i in range(0, rank):
if i not in adv_indx_indices:
curr_dim = dim_tensor_list[i]
concat_tensor_reshape.append(curr_dim)
7d215af to fcbd767 (Compare)
Code conforms to Python style guidelines
Code conforms to C++ style guidelines
import torch
import torch.nn as nn
from harness import DispatchTestCase
Switch to .harness
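That is, use the package-relative import so the test resolves harness from tests/py/dynamo/conversion rather than the working directory (a sketch, mirroring the snippet above):

import torch
import torch.nn as nn

# Relative import, matching the other dynamo conversion tests.
from .harness import DispatchTestCase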
…adcast and broadcasting cases
… for non continuous indices
742c7c2 to 709d626 (Compare)
Code conforms to C++ style guidelines
Code conforms to Python style guidelines
Added a few suggestions for input data types and imports
from torch_tensorrt.fx.converters.converter_utils import (
    get_positive_dim,
    has_dynamic_shape,
    set_layer_name,
    to_numpy,
)
Switch get_positive_dim and to_numpy to the torch_tensorrt.dynamo.conversion.converter_utils version
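A sketch of the adjusted imports, assuming get_positive_dim and to_numpy are exposed by torch_tensorrt.dynamo.conversion.converter_utils as the comment indicates, while the remaining helpers stay on the fx path:

# Dynamo-native helpers (per this review comment).
from torch_tensorrt.dynamo.conversion.converter_utils import (
    get_positive_dim,
    to_numpy,
)

# Helpers still imported from the fx converter utilities.
from torch_tensorrt.fx.converters.converter_utils import (
    has_dynamic_shape,
    set_layer_name,
)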
@@ -137,6 +137,24 @@ def aten_ops_sigmoid(
)


@dynamo_tensorrt_converter(torch.ops.aten.index.Tensor)
Consider adding @enforce_tensor_types({0: (TRTTensor,)}), to ensure the input is a TRTTensor
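A hedged sketch of what the decorated registration might look like (signature copied from the diff above; the enforce_tensor_types usage is only what this comment suggests, not necessarily the final code):

@dynamo_tensorrt_converter(torch.ops.aten.index.Tensor)
@enforce_tensor_types({0: (TRTTensor,)})  # require the first positional arg to be a TRTTensor
def aten_ops_index(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    ...  # converter body unchanged from the diff above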
Code conforms to C++ style guidelines
Code conforms to Python style guidelines
… to_numpy to dynamo converter_utils
Looks good to me! Approved, pending CI
Aten::index converter #2231