Aten::Index converter #2277

Merged · 9 commits merged into main on Oct 4, 2023
Conversation

@apbose (Collaborator) commented Aug 29, 2023

Aten::index converter
#2231
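
For context, a minimal, hypothetical sketch (not part of this PR) of the op this converter targets, torch.ops.aten.index.Tensor, which implements advanced tensor indexing with None entries standing in for ":" slices:

import torch

x = torch.randn(10, 20, 30, 40, 50)
ind_1 = torch.randint(0, 20, (2, 3, 4))
ind_2 = torch.randint(0, 30, (2, 3, 4))

# Advanced indices on consecutive dims: equivalent to x[:, ind_1, ind_2],
# so the result shape is (10, 2, 3, 4, 40, 50), as in the example in the
# converter's comments.
out = torch.ops.aten.index.Tensor(x, [None, ind_1, ind_2])
print(out.shape)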

@github-actions bot added the component: api [Python], component: conversion, component: converters, and component: dynamo labels on Aug 29, 2023
@apbose apbose marked this pull request as draft August 29, 2023 21:01
@github-actions bot left a comment:

Code conforms to C++ style guidelines

@github-actions bot left a comment:

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/select.py	2023-08-29 21:00:52.894681+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/select.py	2023-08-29 21:03:31.334736+00:00
@@ -69,28 +69,28 @@
    network: TRTNetwork,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input: TRTTensor,
-    index: Union[TRTTensor, Sequence[TRTTensor]]
+    index: Union[TRTTensor, Sequence[TRTTensor]],
) -> TRTTensor:
    adv_indx_indices = []
    tensor_indices = []

    for i in len(index):
        ind = index[i]
-        #FIXME: check if the datatype for the indices needs to be casted to INT32
-        #TRTInterpretor should take care
+        # FIXME: check if the datatype for the indices needs to be casted to INT32
+        # TRTInterpretor should take care
        adv_indx_indices.append(i)
        tensor_indices.append(ind)

    if not tensor_indices:
        identity_layer = network.add_identity(input)
        identity_layer.set_output_type(0, trt.int32)
        set_layer_name(identity_layer, target, name + "_index_identity", source_ir)
        return identity_layer.get_output(0)
-    elif (len(tensor_indices) == 1):
+    elif len(tensor_indices) == 1:
        indices_tensor = tensor_indices[0]
        gather_layer = network.add_gather(input, indices_tensor, adv_indx_indices[0])
        set_layer_name(gather_layer, target, name + "_index_gather", source_ir)
        return gather_layer.get_output(0)
    else:
@@ -99,7 +99,5 @@
        adv_indx_count = len(adv_indx_indices)
        input_shape_layer = network.add_shape(input)
        set_layer_name(input_shape_layer, target, name + "_index_shape", source_ir)
        input_shape_tensor = input_shape_layer.get_output(0)
        return input_shape_tensor.get_output(0)
-    
-        
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py	2023-08-29 21:00:52.894681+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py	2023-08-29 21:03:31.387359+00:00
@@ -169,11 +169,11 @@
    )


@dynamo_tensorrt_converter(torch.ops.aten.index.Tensor)
def aten_ops_index(
-network: TRTNetwork,
+    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:

@apbose force-pushed the dynamo_converter_index branch from 71ca151 to 302b962 on September 1, 2023
@github-actions bot left a comment:

Code conforms to C++ style guidelines

@github-actions bot left a comment:

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py	2023-09-01 17:57:45.620889+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py	2023-09-01 18:00:04.480860+00:00
@@ -169,11 +169,11 @@
    )


@dynamo_tensorrt_converter(torch.ops.aten.index.Tensor)
def aten_ops_index(
-network: TRTNetwork,
+    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/select.py	2023-09-01 17:57:45.620889+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/select.py	2023-09-01 18:00:04.530477+00:00
@@ -71,28 +71,28 @@
    network: TRTNetwork,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input: TRTTensor,
-    index: Union[TRTTensor, Sequence[TRTTensor]]
+    index: Union[TRTTensor, Sequence[TRTTensor]],
) -> TRTTensor:
    adv_indx_indices = []
    tensor_indices = []

    for i in len(index):
        ind = index[i]
-        #FIXME: check if the datatype for the indices needs to be casted to INT32
-        #TRTInterpretor should take care
+        # FIXME: check if the datatype for the indices needs to be casted to INT32
+        # TRTInterpretor should take care
        adv_indx_indices.append(i)
        tensor_indices.append(ind)

    if not tensor_indices:
        identity_layer = network.add_identity(input)
        identity_layer.set_output_type(0, trt.int32)
        set_layer_name(identity_layer, target, name + "_index_identity", source_ir)
        return identity_layer.get_output(0)
-    elif (len(tensor_indices) == 1):
+    elif len(tensor_indices) == 1:
        indices_tensor = tensor_indices[0]
        gather_layer = network.add_gather(input, indices_tensor, adv_indx_indices[0])
        set_layer_name(gather_layer, target, name + "_index_gather", source_ir)
        return gather_layer.get_output(0)
    else:
@@ -102,24 +102,26 @@
        input_shape_layer = network.add_shape(input)
        set_layer_name(input_shape_layer, target, name + "_index_shape", source_ir)
        input_shape_tensor = input_shape_layer.get_output(0)
        dim_tensor_list = []
        for i in range(rank):
-            #check this
-            dim_tensor_layer = network.add_gather(input_shape_tensor, i ,0)
-            set_layer_name(input_shape_layer, target, name + "_index_gather_rank", source_ir)
+            # check this
+            dim_tensor_layer = network.add_gather(input_shape_tensor, i, 0)
+            set_layer_name(
+                input_shape_layer, target, name + "_index_gather_rank", source_ir
+            )
            dim_tensor = dim_tensor_layer.get_output(0)
            dim_tensor_list.append(dim_tensor)

-        #for cases like 
-        #t: [x_1, y_1, y_2, ..., x_m, ..., y_n] -> t: [x_1, x_2, ..., x_m, y_1, y_2, ..., y_n],
-        #where t is a tensor of rank m+n, {x_i} are axes where tensor index is provided, and {y_i} are axes
-        #for ":"
-        #Examples: x.shape = (10,20,30,40,50)
-        #ind_1, ind_2 broadcasted to (2,3,4)
-        #x[:, ind_1, ind_2] = 10, 2, 3, 4, 40, 50
-        #x[:,ind_1, :, ind_2] = 2, 3, 4, 10, 30, 50
+        # for cases like
+        # t: [x_1, y_1, y_2, ..., x_m, ..., y_n] -> t: [x_1, x_2, ..., x_m, y_1, y_2, ..., y_n],
+        # where t is a tensor of rank m+n, {x_i} are axes where tensor index is provided, and {y_i} are axes
+        # for ":"
+        # Examples: x.shape = (10,20,30,40,50)
+        # ind_1, ind_2 broadcasted to (2,3,4)
+        # x[:, ind_1, ind_2] = 10, 2, 3, 4, 40, 50
+        # x[:,ind_1, :, ind_2] = 2, 3, 4, 10, 30, 50
        transpose_layer = network.add_shuffle(input)
        new_order = []
        for i in range(adv_indx_count):
            new_order.append(adv_indx_indices[i])
        for i in range(rank):
@@ -130,166 +132,194 @@
        permute_order(new_order)
        transpose_layer.set_second_transpose(permute_order)
        set_layer_name(transpose_layer, target, name + "_index_transpose", source_ir)
        transpose_tensor = transpose_layer.get_output(0)

-        #Flatten [x_1, x_2,.......x_m, y_1, y_2,.....y_m]
+        # Flatten [x_1, x_2,.......x_m, y_1, y_2,.....y_m]
        transpose_tensor_shape = network.add_shape(transpose_tensor)
        d0 = 1
        d0 = get_trt_tensor(network, d0, "d0_initial")
        for i in range(adv_indx_count):
            dim_tensor_layer = network.add_gather(transpose_tensor_shape, i, 0)
-            set_layer_name(dim_tensor_layer, target, name + "_index_gather_concatOne", source_ir)
+            set_layer_name(
+                dim_tensor_layer, target, name + "_index_gather_concatOne", source_ir
+            )
            d0_gather = gather_layer.get_output(0)
            mult_d0 = convert_binary_elementwise(
-                            network, 
-                            target, 
-                            source_ir,
-                            name + "index_concatOne_shape", 
-                            trt.ElementWisePROD, 
-                            mult_d0,
-                            d0_gather,
-                    )
-            
+                network,
+                target,
+                source_ir,
+                name + "index_concatOne_shape",
+                trt.ElementWisePROD,
+                mult_d0,
+                d0_gather,
+            )
+
        d1 = 1
        d1 = get_trt_tensor(network, d0, "d0_initial")
        for i in range(adv_indx_count, rank):
            dim_tensor_layer = network.add_gather(transpose_tensor_shape, i, 0)
-            set_layer_name(dim_tensor_layer, target, name + "_index_gather_concatTwo", source_ir)
+            set_layer_name(
+                dim_tensor_layer, target, name + "_index_gather_concatTwo", source_ir
+            )
            d1_gather = gather_layer.get_output(0)
            mult_d1 = convert_binary_elementwise(
-                network, 
-                target, 
-                source_ir,
-                name + "index_concatTwo_shape", 
-                trt.ElementWisePROD, 
+                network,
+                target,
+                source_ir,
+                name + "index_concatTwo_shape",
+                trt.ElementWisePROD,
                mult_d1,
                d1_gather,
            )
        concat_tensor_layer = network.add_concatenation([mult_d0, mult_d1])
        set_layer_name(concat_tensor_layer, target, name + "_index_Concat", source_ir)
        concat_tensor = concat_tensor_layer.get_output(0)

        reshape_layer = network.add_shuffle(transpose_tensor)
-        #check this
+        # check this
        reshape_layer.set_input(1, concat_tensor)
        flatten_tensor = reshape_layer.get_output(0)

-        #tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j)),  ind_i is input indices[i], x_j is the
-        #// j dimension of input x.
-        multiplier = get_trt_tensor(network, dim_tensor_list[adv_indx_indices[adv_indx_count - 1]], "dim_last")
+        # tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j)),  ind_i is input indices[i], x_j is the
+        # // j dimension of input x.
+        multiplier = get_trt_tensor(
+            network, dim_tensor_list[adv_indx_indices[adv_indx_count - 1]], "dim_last"
+        )
        cum_adv_index = tensor_indices[adv_indx_count - 1]
-        for i in range(adv_indx_count-2, 0):
+        for i in range(adv_indx_count - 2, 0):
            adv_index = convert_binary_elementwise(
-                network, 
-                target, 
-                source_ir,
-                name + "index_intermediate", 
-                trt.ElementWisePROD, 
+                network,
+                target,
+                source_ir,
+                name + "index_intermediate",
+                trt.ElementWisePROD,
                multiplier,
                tensor_indices[i],
            )
            cum_adv_index = convert_binary_elementwise(
-                network, 
-                target, 
-                source_ir,
-                name + "index_sum_intermediate", 
-                trt.ElementWiseSUM, 
+                network,
+                target,
+                source_ir,
+                name + "index_sum_intermediate",
+                trt.ElementWiseSUM,
                cum_adv_index,
                adv_index,
            )
            multiplier = convert_binary_elementwise(
-                network, 
-                target, 
-                source_ir,
-                name + "index_intermediate", 
-                trt.ElementWisePROD, 
+                network,
+                target,
+                source_ir,
+                name + "index_intermediate",
+                trt.ElementWisePROD,
                multiplier,
                dim_tensor_list[adv_indx_count[i]],
            )

        gather_layer_element = network.add_gather(flatten_tensor, cum_adv_index, 0)
-        set_layer_name(gather_layer_element, target, name + "_index_gather_element", source_ir)
+        set_layer_name(
+            gather_layer_element, target, name + "_index_gather_element", source_ir
+        )
        gather_out = gather_layer.get_output(0)

        cum_adv_index_shape_tensor = cum_adv_index.add_shape(cum_adv_index_shape_tensor)
-        #check if all advanced indices are consecutive
+        # check if all advanced indices are consecutive
        concat_tensor_reshape = []
-        if(adv_indx_count == adv_indx_indices[adv_indx_count - 1] - adv_indx_indices[0] + 1):
-            #concat_tensor_reshape_initial = -1
-            #concat_tensor_reshape_initial_tensor = get_trt_tensor(network, concat_tensor_reshape_initial, "concat_tensor_reshape_initial")
+        if (
+            adv_indx_count
+            == adv_indx_indices[adv_indx_count - 1] - adv_indx_indices[0] + 1
+        ):
+            # concat_tensor_reshape_initial = -1
+            # concat_tensor_reshape_initial_tensor = get_trt_tensor(network, concat_tensor_reshape_initial, "concat_tensor_reshape_initial")
            concat_tensor_reshape.append(-1)
            for i in range(0, rank):
                if i not in adv_indx_indices:
                    curr_dim = dim_tensor_list[i]
                    concat_tensor_reshape.append(curr_dim)
-            
+
            concat_tensor_layer = network.add_concatenation(concat_tensor_reshape)
-            set_layer_name(concat_tensor_layer, target, name + "_index_Concat_reshape", source_ir)
+            set_layer_name(
+                concat_tensor_layer, target, name + "_index_Concat_reshape", source_ir
+            )
            concat_tensor = concat_tensor_layer.get_output(0)

            regular_index_shuffle_layer = network.add_shuffle(gather_out)
-            set_layer_name(regular_index_shuffle_layer, target, name + "_index_regular_index", source_ir)
+            set_layer_name(
+                regular_index_shuffle_layer,
+                target,
+                name + "_index_regular_index",
+                source_ir,
+            )
            unfold_tensor = regular_index_shuffle_layer.get_output(0)

            transpose_advanced_shuffle_layer = network.add_shuffle(unfold_tensor)
            new_order = []
-            for i in range(1, adv_indx_count[0]+1):
+            for i in range(1, adv_indx_count[0] + 1):
                new_order.append(i)
            new_order.append(0)
-            for i in range(adv_indx_indices[0]+1, rank - adv_indx_count):
+            for i in range(adv_indx_indices[0] + 1, rank - adv_indx_count):
                new_order.append(i)

            permute_order = trt.Permutation()
            permute_order(new_order)
            transpose_advanced_shuffle_layer.set_second_transpose(permute_order)
-            set_layer_name(transpose_advanced_shuffle_layer, target, name + "_index_advanced_shuffle_transpose", source_ir)
+            set_layer_name(
+                transpose_advanced_shuffle_layer,
+                target,
+                name + "_index_advanced_shuffle_transpose",
+                source_ir,
+            )
            transpose_tensor = transpose_advanced_shuffle_layer.get_output(0)

-            #unfold advanced layer
+            # unfold advanced layer
            concat_final_tensor = []
            for i in range(0, adv_indx_indices[0]):
                current_dim = dim_tensor_list[i]
                concat_final_tensor.push_back(curr_dim)

            concat_final_tensor.push_back(cum_adv_index_shape_tensor)
            for i in range(adv_indx_indices[0], rank):
-                if(i not in (adv_indx_indices)):
+                if i not in (adv_indx_indices):
                    current_dim = dim_tensor_list[i]
                    concat_final_tensor.append(current_dim)
-            
+
            concat_final_shape_layer = network.add_concatenation(concat_final_tensor)
-            set_layer_name(concat_final_shape_layer, target, name + "_index_concat_final_shape_layer", source_ir)
+            set_layer_name(
+                concat_final_shape_layer,
+                target,
+                name + "_index_concat_final_shape_layer",
+                source_ir,
+            )
            concat_final_tensor = concat_final_shape_layer.get_output(0)

            unfold_advanced_shuffle_layer = network.add_shuffle(transpose_tensor)
-            #check this
+            # check this
            reshape_layer.set_input(1, concat_final_tensor)
            reshape_output = reshape_layer.get_output(0)
-        
+
        else:
-            concat_tensor= []
+            concat_tensor = []
            for i in range(0, rank):
                if i not in adv_indx_indices:
                    curr_dim = dim_tensor_list[i]
                    concat_tensor.append(curr_dim)
-            
+
            concat_layer = network.add_concatenation(concat_tensor)
-            set_layer_name(concat_layer, target, name + "_index_concat_final_shape_layer", source_ir)
+            set_layer_name(
+                concat_layer,
+                target,
+                name + "_index_concat_final_shape_layer",
+                source_ir,
+            )
            concat_final_tensor = concat_final_shape_layer.get_output(0)

            reshape_layer = network.add_shuffle(gather_out)
            reshape_layer.setInput(1, concat_final_tensor)
-            set_layer_name(reshape_layer, target, name + "_index_shuffle_final_shape_layer", source_ir)
+            set_layer_name(
+                reshape_layer,
+                target,
+                name + "_index_shuffle_final_shape_layer",
+                source_ir,
+            )
            reshape_output = reshape_layer.get_output(0)

    return reshape_output
-
-
-                                    
-
-
-            
-            
-
-        
-        

@github-actions bot added the component: tests label on Sep 5, 2023
@apbose force-pushed the dynamo_converter_index branch from 4f2a738 to 42798cc on September 7, 2023
@github-actions bot left a comment:

Code conforms to C++ style guidelines

@github-actions bot left a comment:

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/select.py	2023-09-07 03:07:48.113307+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/select.py	2023-09-07 03:10:32.769440+00:00
@@ -71,28 +71,28 @@
    network: TRTNetwork,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input: TRTTensor,
-    index: Union[TRTTensor, Sequence[TRTTensor]]
+    index: Union[TRTTensor, Sequence[TRTTensor]],
) -> TRTTensor:
    adv_indx_indices = []
    tensor_indices = []

    for i in len(index):
        ind = index[i]
-        #FIXME: check if the datatype for the indices needs to be casted to INT32
-        #TRTInterpretor should take care
+        # FIXME: check if the datatype for the indices needs to be casted to INT32
+        # TRTInterpretor should take care
        adv_indx_indices.append(i)
        tensor_indices.append(ind)

    if not tensor_indices:
        identity_layer = network.add_identity(input)
        identity_layer.set_output_type(0, trt.int32)
        set_layer_name(identity_layer, target, name + "_index_identity", source_ir)
        return identity_layer.get_output(0)
-    elif (len(tensor_indices) == 1):
+    elif len(tensor_indices) == 1:
        indices_tensor = tensor_indices[0]
        gather_layer = network.add_gather(input, indices_tensor, adv_indx_indices[0])
        set_layer_name(gather_layer, target, name + "_index_gather", source_ir)
        return gather_layer.get_output(0)
    else:
@@ -102,24 +102,26 @@
        input_shape_layer = network.add_shape(input)
        set_layer_name(input_shape_layer, target, name + "_index_shape", source_ir)
        input_shape_tensor = input_shape_layer.get_output(0)
        dim_tensor_list = []
        for i in range(rank):
-            #check this
-            dim_tensor_layer = network.add_gather(input_shape_tensor, i ,0)
-            set_layer_name(input_shape_layer, target, name + "_index_gather_rank", source_ir)
+            # check this
+            dim_tensor_layer = network.add_gather(input_shape_tensor, i, 0)
+            set_layer_name(
+                input_shape_layer, target, name + "_index_gather_rank", source_ir
+            )
            dim_tensor = dim_tensor_layer.get_output(0)
            dim_tensor_list.append(dim_tensor)

-        #for cases like 
-        #t: [x_1, y_1, y_2, ..., x_m, ..., y_n] -> t: [x_1, x_2, ..., x_m, y_1, y_2, ..., y_n],
-        #where t is a tensor of rank m+n, {x_i} are axes where tensor index is provided, and {y_i} are axes
-        #for ":"
-        #Examples: x.shape = (10,20,30,40,50)
-        #ind_1, ind_2 broadcasted to (2,3,4)
-        #x[:, ind_1, ind_2] = 10, 2, 3, 4, 40, 50
-        #x[:,ind_1, :, ind_2] = 2, 3, 4, 10, 30, 50
+        # for cases like
+        # t: [x_1, y_1, y_2, ..., x_m, ..., y_n] -> t: [x_1, x_2, ..., x_m, y_1, y_2, ..., y_n],
+        # where t is a tensor of rank m+n, {x_i} are axes where tensor index is provided, and {y_i} are axes
+        # for ":"
+        # Examples: x.shape = (10,20,30,40,50)
+        # ind_1, ind_2 broadcasted to (2,3,4)
+        # x[:, ind_1, ind_2] = 10, 2, 3, 4, 40, 50
+        # x[:,ind_1, :, ind_2] = 2, 3, 4, 10, 30, 50
        transpose_layer = network.add_shuffle(input)
        new_order = []
        for i in range(adv_indx_count):
            new_order.append(adv_indx_indices[i])
        for i in range(rank):
@@ -130,166 +132,194 @@
        permute_order(new_order)
        transpose_layer.set_second_transpose(permute_order)
        set_layer_name(transpose_layer, target, name + "_index_transpose", source_ir)
        transpose_tensor = transpose_layer.get_output(0)

-        #Flatten [x_1, x_2,.......x_m, y_1, y_2,.....y_m]
+        # Flatten [x_1, x_2,.......x_m, y_1, y_2,.....y_m]
        transpose_tensor_shape = network.add_shape(transpose_tensor)
        d0 = 1
        d0 = get_trt_tensor(network, d0, "d0_initial")
        for i in range(adv_indx_count):
            dim_tensor_layer = network.add_gather(transpose_tensor_shape, i, 0)
-            set_layer_name(dim_tensor_layer, target, name + "_index_gather_concatOne", source_ir)
+            set_layer_name(
+                dim_tensor_layer, target, name + "_index_gather_concatOne", source_ir
+            )
            d0_gather = gather_layer.get_output(0)
            mult_d0 = convert_binary_elementwise(
-                            network, 
-                            target, 
-                            source_ir,
-                            name + "index_concatOne_shape", 
-                            trt.ElementWisePROD, 
-                            mult_d0,
-                            d0_gather,
-                    )
-            
+                network,
+                target,
+                source_ir,
+                name + "index_concatOne_shape",
+                trt.ElementWisePROD,
+                mult_d0,
+                d0_gather,
+            )
+
        d1 = 1
        d1 = get_trt_tensor(network, d0, "d0_initial")
        for i in range(adv_indx_count, rank):
            dim_tensor_layer = network.add_gather(transpose_tensor_shape, i, 0)
-            set_layer_name(dim_tensor_layer, target, name + "_index_gather_concatTwo", source_ir)
+            set_layer_name(
+                dim_tensor_layer, target, name + "_index_gather_concatTwo", source_ir
+            )
            d1_gather = gather_layer.get_output(0)
            mult_d1 = convert_binary_elementwise(
-                network, 
-                target, 
-                source_ir,
-                name + "index_concatTwo_shape", 
-                trt.ElementWisePROD, 
+                network,
+                target,
+                source_ir,
+                name + "index_concatTwo_shape",
+                trt.ElementWisePROD,
                mult_d1,
                d1_gather,
            )
        concat_tensor_layer = network.add_concatenation([mult_d0, mult_d1])
        set_layer_name(concat_tensor_layer, target, name + "_index_Concat", source_ir)
        concat_tensor = concat_tensor_layer.get_output(0)

        reshape_layer = network.add_shuffle(transpose_tensor)
-        #check this
+        # check this
        reshape_layer.set_input(1, concat_tensor)
        flatten_tensor = reshape_layer.get_output(0)

-        #tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j)),  ind_i is input indices[i], x_j is the
-        #// j dimension of input x.
-        multiplier = get_trt_tensor(network, dim_tensor_list[adv_indx_indices[adv_indx_count - 1]], "dim_last")
+        # tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j)),  ind_i is input indices[i], x_j is the
+        # // j dimension of input x.
+        multiplier = get_trt_tensor(
+            network, dim_tensor_list[adv_indx_indices[adv_indx_count - 1]], "dim_last"
+        )
        cum_adv_index = tensor_indices[adv_indx_count - 1]
-        for i in range(adv_indx_count-2, 0):
+        for i in range(adv_indx_count - 2, 0):
            adv_index = convert_binary_elementwise(
-                network, 
-                target, 
-                source_ir,
-                name + "index_intermediate", 
-                trt.ElementWisePROD, 
+                network,
+                target,
+                source_ir,
+                name + "index_intermediate",
+                trt.ElementWisePROD,
                multiplier,
                tensor_indices[i],
            )
            cum_adv_index = convert_binary_elementwise(
-                network, 
-                target, 
-                source_ir,
-                name + "index_sum_intermediate", 
-                trt.ElementWiseSUM, 
+                network,
+                target,
+                source_ir,
+                name + "index_sum_intermediate",
+                trt.ElementWiseSUM,
                cum_adv_index,
                adv_index,
            )
            multiplier = convert_binary_elementwise(
-                network, 
-                target, 
-                source_ir,
-                name + "index_intermediate", 
-                trt.ElementWisePROD, 
+                network,
+                target,
+                source_ir,
+                name + "index_intermediate",
+                trt.ElementWisePROD,
                multiplier,
                dim_tensor_list[adv_indx_count[i]],
            )

        gather_layer_element = network.add_gather(flatten_tensor, cum_adv_index, 0)
-        set_layer_name(gather_layer_element, target, name + "_index_gather_element", source_ir)
+        set_layer_name(
+            gather_layer_element, target, name + "_index_gather_element", source_ir
+        )
        gather_out = gather_layer.get_output(0)

        cum_adv_index_shape_tensor = cum_adv_index.add_shape(cum_adv_index_shape_tensor)
-        #check if all advanced indices are consecutive
+        # check if all advanced indices are consecutive
        concat_tensor_reshape = []
-        if(adv_indx_count == adv_indx_indices[adv_indx_count - 1] - adv_indx_indices[0] + 1):
-            #concat_tensor_reshape_initial = -1
-            #concat_tensor_reshape_initial_tensor = get_trt_tensor(network, concat_tensor_reshape_initial, "concat_tensor_reshape_initial")
+        if (
+            adv_indx_count
+            == adv_indx_indices[adv_indx_count - 1] - adv_indx_indices[0] + 1
+        ):
+            # concat_tensor_reshape_initial = -1
+            # concat_tensor_reshape_initial_tensor = get_trt_tensor(network, concat_tensor_reshape_initial, "concat_tensor_reshape_initial")
            concat_tensor_reshape.append(-1)
            for i in range(0, rank):
                if i not in adv_indx_indices:
                    curr_dim = dim_tensor_list[i]
                    concat_tensor_reshape.append(curr_dim)
-            
+
            concat_tensor_layer = network.add_concatenation(concat_tensor_reshape)
-            set_layer_name(concat_tensor_layer, target, name + "_index_Concat_reshape", source_ir)
+            set_layer_name(
+                concat_tensor_layer, target, name + "_index_Concat_reshape", source_ir
+            )
            concat_tensor = concat_tensor_layer.get_output(0)

            regular_index_shuffle_layer = network.add_shuffle(gather_out)
-            set_layer_name(regular_index_shuffle_layer, target, name + "_index_regular_index", source_ir)
+            set_layer_name(
+                regular_index_shuffle_layer,
+                target,
+                name + "_index_regular_index",
+                source_ir,
+            )
            unfold_tensor = regular_index_shuffle_layer.get_output(0)

            transpose_advanced_shuffle_layer = network.add_shuffle(unfold_tensor)
            new_order = []
-            for i in range(1, adv_indx_count[0]+1):
+            for i in range(1, adv_indx_count[0] + 1):
                new_order.append(i)
            new_order.append(0)
-            for i in range(adv_indx_indices[0]+1, rank - adv_indx_count):
+            for i in range(adv_indx_indices[0] + 1, rank - adv_indx_count):
                new_order.append(i)

            permute_order = trt.Permutation()
            permute_order(new_order)
            transpose_advanced_shuffle_layer.set_second_transpose(permute_order)
-            set_layer_name(transpose_advanced_shuffle_layer, target, name + "_index_advanced_shuffle_transpose", source_ir)
+            set_layer_name(
+                transpose_advanced_shuffle_layer,
+                target,
+                name + "_index_advanced_shuffle_transpose",
+                source_ir,
+            )
            transpose_tensor = transpose_advanced_shuffle_layer.get_output(0)

-            #unfold advanced layer
+            # unfold advanced layer
            concat_final_tensor = []
            for i in range(0, adv_indx_indices[0]):
                current_dim = dim_tensor_list[i]
                concat_final_tensor.push_back(curr_dim)

            concat_final_tensor.push_back(cum_adv_index_shape_tensor)
            for i in range(adv_indx_indices[0], rank):
-                if(i not in (adv_indx_indices)):
+                if i not in (adv_indx_indices):
                    current_dim = dim_tensor_list[i]
                    concat_final_tensor.append(current_dim)
-            
+
            concat_final_shape_layer = network.add_concatenation(concat_final_tensor)
-            set_layer_name(concat_final_shape_layer, target, name + "_index_concat_final_shape_layer", source_ir)
+            set_layer_name(
+                concat_final_shape_layer,
+                target,
+                name + "_index_concat_final_shape_layer",
+                source_ir,
+            )
            concat_final_tensor = concat_final_shape_layer.get_output(0)

            unfold_advanced_shuffle_layer = network.add_shuffle(transpose_tensor)
-            #check this
+            # check this
            reshape_layer.set_input(1, concat_final_tensor)
            reshape_output = reshape_layer.get_output(0)
-        
+
        else:
-            concat_tensor= []
+            concat_tensor = []
            for i in range(0, rank):
                if i not in adv_indx_indices:
                    curr_dim = dim_tensor_list[i]
                    concat_tensor.append(curr_dim)
-            
+
            concat_layer = network.add_concatenation(concat_tensor)
-            set_layer_name(concat_layer, target, name + "_index_concat_final_shape_layer", source_ir)
+            set_layer_name(
+                concat_layer,
+                target,
+                name + "_index_concat_final_shape_layer",
+                source_ir,
+            )
            concat_final_tensor = concat_final_shape_layer.get_output(0)

            reshape_layer = network.add_shuffle(gather_out)
            reshape_layer.setInput(1, concat_final_tensor)
-            set_layer_name(reshape_layer, target, name + "_index_shuffle_final_shape_layer", source_ir)
+            set_layer_name(
+                reshape_layer,
+                target,
+                name + "_index_shuffle_final_shape_layer",
+                source_ir,
+            )
            reshape_output = reshape_layer.get_output(0)

    return reshape_output
-
-
-                                    
-
-
-            
-            
-
-        
-        
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_index_aten.py	2023-09-07 03:07:48.137309+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_index_aten.py	2023-09-07 03:10:36.917639+00:00
@@ -2,10 +2,11 @@
import torch.nn as nn
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input

from .harness import DispatchTestCase
+

class TestIndexConverter(DispatchTestCase):
    def test_index(self):
        class TestModule(nn.Module):
            def forward(self, x):
@@ -13,6 +14,6 @@
                index0 = torch.randint(0, 16, (1, 16))
                index1 = torch.randint(0, 16, (1, 16))
                out = torch.ops.aten.index(None, None, index0, index1)

        inputs = [torch.randn(1, 10)]
-        self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.index.Tensor})
\ No newline at end of file
+        self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.index.Tensor})

@apbose force-pushed the dynamo_converter_index branch from 42798cc to 6b186b6 on September 7, 2023
@github-actions bot left a comment:

Code conforms to C++ style guidelines

@github-actions bot left a comment:

Code conforms to Python style guidelines

@apbose apbose requested a review from gs-olive September 8, 2023 16:28
@apbose apbose self-assigned this Sep 8, 2023
@apbose apbose marked this pull request as ready for review September 8, 2023 16:29
@github-actions bot left a comment:

Code conforms to Python style guidelines

@github-actions bot left a comment:

Code conforms to C++ style guidelines

@github-actions bot left a comment:

Code conforms to C++ style guidelines

@github-actions bot left a comment:

Code conforms to Python style guidelines

@github-actions bot left a comment:

Code conforms to C++ style guidelines

@github-actions bot left a comment:

Code conforms to Python style guidelines

@gs-olive (Collaborator) left a comment:
Overall looks good! I added some style/naming comments and will run tests on a model that uses this layer to verify.

Comment on lines 87 to 88
for i in range(0, len(index)):
ind = index[i]
Collaborator:
Consider rewriting as: for i, ind in enumerate(index)
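
Applied to that loop, the suggestion would read roughly as follows (a sketch; the surrounding variable names follow the PR):

# Same behavior as the range/len loop, written with enumerate.
for i, ind in enumerate(index):
    adv_indx_indices.append(i)
    tensor_indices.append(ind)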

Comment on lines 141 to 142
permute_order = trt.Permutation(new_order)
transpose_layer.second_transpose = permute_order
Collaborator:
Can shorten to transpose_layer.second_transpose = tuple(new_order)
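
In other words (sketch, assuming transpose_layer is a TensorRT IShuffleLayer as in the PR):

# trt.Permutation() is unnecessary; a plain tuple can be assigned directly.
transpose_layer.second_transpose = tuple(new_order)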

target,
source_ir,
name + "index_intermediate",
trt.ElementWisePROD,
Collaborator:
Should be trt.ElementWiseOperation.PROD

target,
source_ir,
name + "index_sum_intermediate",
trt.ElementWiseSUM,
Collaborator:
trt.ElementWiseOperation.SUM

target,
source_ir,
name + "index_intermediate",
trt.ElementWisePROD,
Collaborator:
trt.ElementWiseOperation.PROD
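
For reference, the enum members the preceding three comments point to live under trt.ElementWiseOperation in the TensorRT Python API (sketch):

import tensorrt as trt

# Element-wise ops are members of the ElementWiseOperation enum,
# not attributes of the trt module itself.
prod_op = trt.ElementWiseOperation.PROD  # instead of trt.ElementWisePROD
sum_op = trt.ElementWiseOperation.SUM    # instead of trt.ElementWiseSUM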

Comment on lines 265 to 266
permute_order = trt.Permutation(new_order)
transpose_advanced_shuffle_layer.second_transpose = permute_order
Collaborator:
permute_order could be replaced with tuple(new_order)

) -> TRTTensor:
adv_indx_indices = []
tensor_indices = []
_LOGGER.debug(f"The index shape is", index.shape)
Collaborator:
index.shape is not valid, since index could be a python list
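
One way to address this (a sketch, not necessarily the exact fix adopted in the PR) is to log each index inside the loop, where every entry is known to be a tensor:

# Log per-index shapes rather than index.shape, since `index` itself may be
# a plain Python list of tensors (or contain None entries).
for i, ind in enumerate(index):
    if ind is not None:
        _LOGGER.debug(f"Shape of index {i} is {ind.shape}")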

@github-actions bot left a comment:

Code conforms to Python style guidelines

@github-actions bot left a comment:

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/select.py	2023-09-25 20:21:10.951934+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/select.py	2023-09-25 20:24:07.893237+00:00
@@ -94,11 +94,13 @@
            _LOGGER.debug(f"Shape of {i} index is {ind.shape}")
            adv_indx_indices.append(i)
            # torch.nn.parameter.Parameter=> torch.Tensor
            ind = get_trt_tensor(network, ind, name + f"_parameter_to_fp32_tensor_{i}")
            if last_index is not None:
-                assert broadcastable(ind, last_index), "The indices should be broadcastable!"
+                assert broadcastable(
+                    ind, last_index
+                ), "The indices should be broadcastable!"
            last_index = ind
            tensor_indices.append(ind)

    if not tensor_indices:
        identity_layer = network.add_identity(input)
@@ -177,11 +179,13 @@
        _LOGGER.debug(f"The flatten tensor shape is {flatten_tensor.shape}")

        # tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j)),  ind_i is input indices[i], x_j is the
        # // j dimension of input x.
        multiplier = get_trt_tensor(
-            network, dim_tensor_list[adv_indx_indices[adv_indx_count - 1]], name + "dim_last"
+            network,
+            dim_tensor_list[adv_indx_indices[adv_indx_count - 1]],
+            name + "dim_last",
        )
        cum_adv_index = tensor_indices[adv_indx_count - 1]
        for i in range(adv_indx_count - 2, -1, -1):
            adv_index = convert_binary_elementwise(
                network,
@@ -231,11 +235,13 @@
        if (
            adv_indx_count
            == adv_indx_indices[adv_indx_count - 1] - adv_indx_indices[0] + 1
        ):
            _LOGGER.debug(f"The indices are continuous in this case")
-            concat_tensor_reshape.append(get_trt_tensor(network, -1, name + "dynamic_concat"))
+            concat_tensor_reshape.append(
+                get_trt_tensor(network, -1, name + "dynamic_concat")
+            )
            for i in range(0, rank):
                if i not in adv_indx_indices:
                    curr_dim = dim_tensor_list[i]
                    concat_tensor_reshape.append(curr_dim)

@apbose force-pushed the dynamo_converter_index branch from 7d215af to fcbd767 on September 25, 2023
@github-actions bot left a comment:

Code conforms to Python style guidelines

@github-actions bot left a comment:

Code conforms to C++ style guidelines


import torch
import torch.nn as nn
from harness import DispatchTestCase
Collaborator:
Switch to .harness
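
That is (sketch), using a package-relative import so the test module resolves inside the test package:

from .harness import DispatchTestCase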

@apbose force-pushed the dynamo_converter_index branch from 742c7c2 to 709d626 on October 3, 2023
@github-actions bot left a comment:

Code conforms to C++ style guidelines

@github-actions bot left a comment:

Code conforms to Python style guidelines

@apbose apbose requested a review from gs-olive October 3, 2023 23:23
@github-actions bot left a comment:

Code conforms to C++ style guidelines

@github-actions bot left a comment:

Code conforms to Python style guidelines

@github-actions bot left a comment:

Code conforms to Python style guidelines

@github-actions bot left a comment:

Code conforms to C++ style guidelines

@gs-olive (Collaborator) left a comment:
Added a few suggestions for input data types and imports

Comment on lines 15 to 20
from torch_tensorrt.fx.converters.converter_utils import (
get_positive_dim,
has_dynamic_shape,
set_layer_name,
to_numpy,
)
Collaborator:
Switch get_positive_dim and to_numpy to the torch_tensorrt.dynamo.conversion.converter_utils version
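
A sketch of the suggested import split (module paths as named in the comment; exact symbol availability in that module is assumed):

# Take get_positive_dim and to_numpy from the dynamo conversion utils,
# keeping the remaining TRT-layer helpers from the fx utils.
from torch_tensorrt.dynamo.conversion.converter_utils import (
    get_positive_dim,
    to_numpy,
)
from torch_tensorrt.fx.converters.converter_utils import (
    has_dynamic_shape,
    set_layer_name,
)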

@@ -137,6 +137,24 @@ def aten_ops_sigmoid(
)


@dynamo_tensorrt_converter(torch.ops.aten.index.Tensor)
@gs-olive (Collaborator) commented on Oct 3, 2023:
Consider adding @enforce_tensor_types( {0: (TRTTensor,)} ), to ensure the input is a TRTTensor
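
The suggested decorator stacking would look roughly like this (sketch; the decorator name and argument form are taken from the comment, and the converter signature follows this PR):

# Validate that argument 0 (the input tensor) is a TRTTensor before the
# converter body runs.
@dynamo_tensorrt_converter(torch.ops.aten.index.Tensor)
@enforce_tensor_types({0: (TRTTensor,)})
def aten_ops_index(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    ...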

@github-actions bot left a comment:

Code conforms to C++ style guidelines

@github-actions bot left a comment:

Code conforms to Python style guidelines

@gs-olive (Collaborator) left a comment:
Looks good to me! Approved, pending CI

@gs-olive gs-olive merged commit e432bf2 into main Oct 4, 2023
Labels
cla signed, component: api [Python], component: conversion, component: converters, component: dynamo, component: tests, priority: high