Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support aten.dot dynamo converter #3043

Merged
merged 1 commit into from
Aug 8, 2024
Merged

feat: Support aten.dot dynamo converter #3043

merged 1 commit into from
Aug 8, 2024

Conversation

HolyWu
Copy link
Contributor

@HolyWu HolyWu commented Jul 26, 2024

Description

The existing code already supports it. It's just not registered and get lowered.

Type of change

  • New feature (non-breaking change which adds functionality)

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR in so that relevant reviewers are notified

@github-actions github-actions bot added component: tests Issues re: Tests component: lowering Issues re: The lowering / preprocessing passes component: conversion Issues re: Conversion stage component: converters Issues re: Specific op converters component: api [Python] Issues re: Python API component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths labels Jul 26, 2024
@github-actions github-actions bot requested a review from gs-olive July 26, 2024 04:16
@peri044
Copy link
Collaborator

peri044 commented Jul 31, 2024

@HolyWu Is there any issue with aten.dot being lowered ? Generally, we intend to write converters for aten core ops https://pytorch.org/docs/main/torch.compiler_ir.html and the other ops are expected to be decomposed.

@HolyWu
Copy link
Contributor Author

HolyWu commented Aug 1, 2024

No issue for me. I just happened to find that the matrix layer in TRT has support for this operation and the code in impl.matmul already uses that operation. Then why not?

@peri044
Copy link
Collaborator

peri044 commented Aug 2, 2024

No issue for me. I just happened to find that the matrix layer in TRT has support for this operation and the code in impl.matmul already uses that operation. Then why not?

I see. We prefer Pytorch lowering whenever it is feasible (especially for non aten core ops) since this would ensure our converter library to be light weighted. We can close this PR. If there are any performance concerns, let us know.

@HolyWu HolyWu closed this Aug 2, 2024
@HolyWu HolyWu deleted the dot branch August 2, 2024 15:45
@HolyWu HolyWu restored the dot branch August 2, 2024 16:03
@HolyWu HolyWu reopened this Aug 2, 2024
@HolyWu
Copy link
Contributor Author

HolyWu commented Aug 2, 2024

@peri044 One more question. If an operator was once in Core ATen IR but later was removed from Core ATen IR, for example aten.roll and aten.pixel_shuffle, then should the converters for those operators be removed so as to use PyTorch decomposition?

@HolyWu
Copy link
Contributor Author

HolyWu commented Aug 4, 2024

Regarding performance, the decomposed path is actually 2-4 times slower than the converter path.

import timeit

import numpy as np
import torch
import torch_tensorrt


class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        return torch.ops.aten.dot.default(x, y)


@torch.inference_mode()
def benchmark(model, inputs):
    # Warm up
    for _ in range(3):
        model(*inputs)

    torch.cuda.synchronize()

    timings = []
    for _ in range(100):
        start_time = timeit.default_timer()
        model(*inputs)
        torch.cuda.synchronize()
        end_time = timeit.default_timer()
        timings.append(end_time - start_time)
    return np.array(timings)


torch.manual_seed(12345)
device = torch.device("cuda", 0)
model = MyModule().eval().to(device).half()

inputs = (
    torch.randn((10000000,), dtype=torch.half, device=device),
    torch.randn((10000000,), dtype=torch.half, device=device),
)

trt_model = torch_tensorrt.compile(
    model,
    ir="dynamo",
    inputs=inputs,
    enabled_precisions={torch.half},
    debug=True,
    min_block_size=1,
    device=device,
)

torch_timings = benchmark(model, inputs)
trt_timings = benchmark(trt_model, inputs)

print("")
print("Torch:")
print(f"\tMin={torch_timings.min()}, Mean={torch_timings.mean()}, Max={torch_timings.max()}")

print("")
print("TRT:")
print(f"\tMin={trt_timings.min()}, Mean={trt_timings.mean()}, Max={trt_timings.max()}")

torch.testing.assert_close(model(*inputs), trt_model(*inputs), rtol=1e-3, atol=1e-3)

Before patch

DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_detach:Removed 0 detach nodes:
graph():
    %x : [num_users=1] = placeholder[target=x]
    %y : [num_users=1] = placeholder[target=y]
    %dot : [num_users=1] = call_function[target=torch.ops.aten.dot.default](args = (%x, %y), kwargs = {})
    return (dot,)
DEBUG:torch_tensorrt.dynamo._compiler:Input graph: graph():
    %x : [num_users=1] = placeholder[target=x]
    %y : [num_users=1] = placeholder[target=y]
    %_to_copy : [num_users=1] = call_function[target=torch.ops.aten._to_copy.default](args = (%x,), kwargs = {dtype: torch.float32})
    %_to_copy_1 : [num_users=1] = call_function[target=torch.ops.aten._to_copy.default](args = (%y,), kwargs = {dtype: torch.float32})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%_to_copy, %_to_copy_1), kwargs = {})
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul,), kwargs = {})
    %_to_copy_2 : [num_users=1] = call_function[target=torch.ops.aten._to_copy.default](args = (%sum_1,), kwargs = {dtype: torch.float16})
    return (_to_copy_2,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.constant_folding:Graph after constant folding:
graph():
    %x : [num_users=1] = placeholder[target=x]
    %y : [num_users=1] = placeholder[target=y]
    %_to_copy : [num_users=1] = call_function[target=torch.ops.aten._to_copy.default](args = (%x,), kwargs = {dtype: torch.float32})
    %_to_copy_1 : [num_users=1] = call_function[target=torch.ops.aten._to_copy.default](args = (%y,), kwargs = {dtype: torch.float32})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%_to_copy, %_to_copy_1), kwargs = {})
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul,), kwargs = {})
    %_to_copy_2 : [num_users=1] = call_function[target=torch.ops.aten._to_copy.default](args = (%sum_1,), kwargs = {dtype: torch.float16})
    return (_to_copy_2,)
DEBUG:torch_tensorrt.dynamo._compiler:Lowered Input graph: graph():
    %x : [num_users=1] = placeholder[target=x]
    %y : [num_users=1] = placeholder[target=y]
    %_to_copy : [num_users=1] = call_function[target=torch.ops.aten._to_copy.default](args = (%x,), kwargs = {dtype: torch.float32})
    %_to_copy_1 : [num_users=1] = call_function[target=torch.ops.aten._to_copy.default](args = (%y,), kwargs = {dtype: torch.float32})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%_to_copy, %_to_copy_1), kwargs = {})
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul,), kwargs = {})
    %_to_copy_2 : [num_users=1] = call_function[target=torch.ops.aten._to_copy.default](args = (%sum_1,), kwargs = {dtype: torch.float16})
    return (_to_copy_2,)
INFO:torch_tensorrt.dynamo._compiler:Compilation Settings: CompilationSettings(enabled_precisions={<dtype.f16: 6>}, debug=True, workspace_size=0, min_block_size=1, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, assume_dynamic_shape_support=False, sparse_weights=False, make_refitable=False, engine_capability=<EngineCapability.STANDARD: 1>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False, timing_cache_path='C:\\Users\\HolyWu\\AppData\\Local\\Temp\\timing_cache.bin')

DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
Supported Nodes:
- torch.ops.aten._to_copy.default + Operator Count: 3
- torch.ops.aten.mul.Tensor + Operator Count: 1
- torch.ops.aten.sum.default + Operator Count: 1

DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
All Nodes Supported

DEBUG:torch_tensorrt.dynamo._compiler:Detected support for 5 operators out of 5 in subgraph.
INFO:torch_tensorrt.dynamo._compiler:Partitioning the graph via the fast partitioner
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Number of TensorRT-Accelerated Engines Generated: 1
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Supported Nodes:
- torch.ops.aten._to_copy.default + Operator Count: 3
- torch.ops.aten.mul.Tensor + Operator Count: 1
- torch.ops.aten.sum.default + Operator Count: 1

DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
All Nodes Supported

DEBUG:torch_tensorrt.dynamo._compiler:Converting submodule: _run_on_acc_0
 Input shapes: [(10000000,), (10000000,)]
 graph():
    %x : [num_users=1] = placeholder[target=x]
    %_to_copy : [num_users=1] = call_function[target=torch.ops.aten._to_copy.default](args = (%x,), kwargs = {dtype: torch.float32})
    %y : [num_users=1] = placeholder[target=y]
    %_to_copy_1 : [num_users=1] = call_function[target=torch.ops.aten._to_copy.default](args = (%y,), kwargs = {dtype: torch.float32})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%_to_copy, %_to_copy_1), kwargs = {})
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul,), kwargs = {})
    %_to_copy_2 : [num_users=1] = call_function[target=torch.ops.aten._to_copy.default](args = (%sum_1,), kwargs = {dtype: torch.float16})
    return _to_copy_2
INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageChange] Init CUDA: CPU +1, GPU +0, now: CPU 13448, GPU 1229 (MiB)
INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageChange] Init builder kernel library: CPU +2445, GPU +288, now: CPU 16177, GPU 1517 (MiB)
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Running node x, a placeholder node with target x in the TensorRT Interpreter
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Adding input to in-progress INetwork: x [shape=[10000000], dtype=DataType.HALF]
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Ran node x with properties: Inputs: () | Outputs: (x: (10000000,)@torch.float16)
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Running node __/_to_copy, a call_function node with target aten._to_copy.default in the TensorRT Interpreter
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node __/_to_copy (kind: aten._to_copy.default, args: ('x <tensorrt.ITensor [shape=(10000000,), dtype=DataType.HALF]>',))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Ran node __/_to_copy with properties: Inputs: (x: (10000000,)@torch.float16) | Outputs: (_to_copy: (10000000,)@torch.float32)
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Running node y, a placeholder node with target y in the TensorRT Interpreter
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Adding input to in-progress INetwork: y [shape=[10000000], dtype=DataType.HALF]
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Ran node y with properties: Inputs: () | Outputs: (y: (10000000,)@torch.float16)
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Running node __/_to_copy_1, a call_function node with target aten._to_copy.default in the TensorRT Interpreter
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node __/_to_copy_1 (kind: aten._to_copy.default, args: ('y <tensorrt.ITensor [shape=(10000000,), dtype=DataType.HALF]>',))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Ran node __/_to_copy_1 with properties: Inputs: (y: (10000000,)@torch.float16) | Outputs: (_to_copy_1: (10000000,)@torch.float32)
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Running node __/mul, a call_function node with target aten.mul.Tensor in the TensorRT Interpreter
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node __/mul (kind: aten.mul.Tensor, args: ('Forced Cast ITensor x from DataType.HALF to DataType.FLOAT - [aten_ops.torch.ops.aten._to_copy.default]-[__/_to_copy]_output <tensorrt.ITensor [shape=(10000000,), dtype=DataType.FLOAT]>', 'Forced Cast ITensor y from DataType.HALF to DataType.FLOAT - [aten_ops.torch.ops.aten._to_copy.default]-[__/_to_copy_1]_output <tensorrt.ITensor [shape=(10000000,), dtype=DataType.FLOAT]>'))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Ran node __/mul with properties: Inputs: (_to_copy: (10000000,)@torch.float32, _to_copy_1: (10000000,)@torch.float32) | Outputs: (mul: (10000000,)@torch.float32)
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Running node __/sum_1, a call_function node with target aten.sum.default in the TensorRT Interpreter
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node __/sum_1 (kind: aten.sum.default, args: ('[ELEMENTWISE]-[aten_ops.mul.Tensor]-[__/mul]_output_mul.Tensor <tensorrt.ITensor [shape=(10000000,), dtype=DataType.FLOAT]>',))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Ran node __/sum_1 with properties: Inputs: (mul: (10000000,)@torch.float32) | Outputs: (sum_1: ()@torch.float32)
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Running node __/_to_copy_2, a call_function node with target aten._to_copy.default in the TensorRT Interpreter
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node __/_to_copy_2 (kind: aten._to_copy.default, args: ('[REDUCE]-[aten_ops.sum.default]-[__/sum_1]_output <tensorrt.ITensor [shape=(), dtype=DataType.FLOAT]>',))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Ran node __/_to_copy_2 with properties: Inputs: (sum_1: ()@torch.float32) | Outputs: (_to_copy_2: ()@torch.float16)
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Running node output, a output node with target output in the TensorRT Interpreter
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Marking output output0 [shape=(), dtype=DataType.HALF]
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Ran node output with properties: Inputs: (_to_copy_2: ()@torch.float16) | Outputs: (output: )
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT INetwork construction elapsed time: 0:00:00.003905
INFO:torch_tensorrt [TensorRT Conversion Context]:Global timing cache in use. Profiling results in this builder pass will be stored.
INFO:torch_tensorrt [TensorRT Conversion Context]:Detected 2 inputs and 1 output network tensors.
INFO:torch_tensorrt [TensorRT Conversion Context]:Total Host Persistent Memory: 352
INFO:torch_tensorrt [TensorRT Conversion Context]:Total Device Persistent Memory: 0
INFO:torch_tensorrt [TensorRT Conversion Context]:Total Scratch Memory: 0
INFO:torch_tensorrt [TensorRT Conversion Context]:[BlockAssignment] Started assigning block shifts. This will take 4 steps to complete.
INFO:torch_tensorrt [TensorRT Conversion Context]:[BlockAssignment] Algorithm ShiftNTopDown took 0.1151ms to assign 3 blocks to 4 nodes requiring 20001280 bytes.
INFO:torch_tensorrt [TensorRT Conversion Context]:Total Activation Memory: 20001280
INFO:torch_tensorrt [TensorRT Conversion Context]:Total Weights Memory: 0
INFO:torch_tensorrt [TensorRT Conversion Context]:Engine generation completed in 0.144 seconds.
INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageStats] Peak memory usage of TRT CPU/GPU memory allocators: CPU 0 MiB, GPU 0 MiB
INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageStats] Peak memory usage during Engine building and serialization: CPU: 3472 MiB
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Build TRT engine elapsed time: 0:00:00.152334
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT Engine uses: 19748 bytes of Memory
INFO:torch_tensorrt [TensorRT Conversion Context]:Serialized 26 bytes of code generator cache.
INFO:torch_tensorrt [TensorRT Conversion Context]:Serialized 117423 bytes of compilation cache.
INFO:torch_tensorrt [TensorRT Conversion Context]:Serialized 265 timing cache entries
DEBUG:torch_tensorrt.dynamo._DryRunTracker:
++++++++++++++++++++++++++++++++++++++++++++++++++ Dry-Run Results for Graph ++++++++++++++++++++++++++++++++++++++++++++++++++

The graph consists of 5 Total Operators, of which 5 operators are supported, 100.0% coverage

Compiled with: CompilationSettings(enabled_precisions={<dtype.f16: 6>}, debug=True, workspace_size=0, min_block_size=1, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, assume_dynamic_shape_support=False, sparse_weights=False, make_refitable=False, engine_capability=<EngineCapability.STANDARD: 1>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False, timing_cache_path='C:\\Users\\HolyWu\\AppData\\Local\\Temp\\timing_cache.bin')

  Graph Structure:

   Inputs: Tuple(Tensor: (10000000,)@float16, Tensor: (10000000,)@float16)
    ...
    TRT Engine #1 - Submodule name: _run_on_acc_0
     Engine Inputs: List[Tensor: (10000000,)@float16, Tensor: (10000000,)@float16]
     Number of Operators in Engine: 5
     Engine Outputs: Tensor: ()@float16
    ...
   Outputs: List[Tensor: ()@float16]

  ------------------------- Aggregate Stats -------------------------

   Average Number of Operators per TRT Engine: 5.0
   Most Operators in a TRT Engine: 5

  ********** Recommendations **********

   - For minimal graph segmentation, select min_block_size=5 which would generate 1 TRT engine(s)
   - The current level of graph segmentation is equivalent to selecting min_block_size=5 which generates 1 TRT engine(s)

Torch:
        Min=0.0002153999994334299, Mean=0.00022290499997325243, Max=0.0004743000026792288

TRT:
        Min=0.0012330999998084735, Mean=0.0012919229999897653, Max=0.0017789999983506277

After patch

DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_detach:Removed 0 detach nodes:
graph():
    %x : [num_users=1] = placeholder[target=x]
    %y : [num_users=1] = placeholder[target=y]
    %dot : [num_users=1] = call_function[target=torch.ops.aten.dot.default](args = (%x, %y), kwargs = {})
    return (dot,)
DEBUG:torch_tensorrt.dynamo._compiler:Input graph: graph():
    %x : [num_users=1] = placeholder[target=x]
    %y : [num_users=1] = placeholder[target=y]
    %dot : [num_users=1] = call_function[target=torch.ops.aten.dot.default](args = (%x, %y), kwargs = {})
    return (dot,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.constant_folding:Graph after constant folding:
graph():
    %x : [num_users=1] = placeholder[target=x]
    %y : [num_users=1] = placeholder[target=y]
    %dot : [num_users=1] = call_function[target=torch.ops.aten.dot.default](args = (%x, %y), kwargs = {})
    return (dot,)
DEBUG:torch_tensorrt.dynamo._compiler:Lowered Input graph: graph():
    %x : [num_users=1] = placeholder[target=x]
    %y : [num_users=1] = placeholder[target=y]
    %dot : [num_users=1] = call_function[target=torch.ops.aten.dot.default](args = (%x, %y), kwargs = {})
    return (dot,)
INFO:torch_tensorrt.dynamo._compiler:Compilation Settings: CompilationSettings(enabled_precisions={<dtype.f16: 6>}, debug=True, workspace_size=0, min_block_size=1, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, assume_dynamic_shape_support=False, sparse_weights=False, make_refitable=False, engine_capability=<EngineCapability.STANDARD: 1>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False, timing_cache_path='C:\\Users\\HolyWu\\AppData\\Local\\Temp\\timing_cache.bin')

DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
Supported Nodes:
- torch.ops.aten.dot.default + Operator Count: 1

DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
All Nodes Supported

DEBUG:torch_tensorrt.dynamo._compiler:Detected support for 1 operators out of 1 in subgraph.
INFO:torch_tensorrt.dynamo._compiler:Partitioning the graph via the fast partitioner
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Number of TensorRT-Accelerated Engines Generated: 1
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Supported Nodes:
- torch.ops.aten.dot.default + Operator Count: 1

DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
All Nodes Supported

DEBUG:torch_tensorrt.dynamo._compiler:Converting submodule: _run_on_acc_0
 Input shapes: [(10000000,), (10000000,)]
 graph():
    %x : [num_users=1] = placeholder[target=x]
    %y : [num_users=1] = placeholder[target=y]
    %dot : [num_users=1] = call_function[target=torch.ops.aten.dot.default](args = (%x, %y), kwargs = {})
    return dot
INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageChange] Init CUDA: CPU +1, GPU +0, now: CPU 13463, GPU 1109 (MiB)
INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageChange] Init builder kernel library: CPU +2432, GPU +288, now: CPU 16180, GPU 1397 (MiB)
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Running node x, a placeholder node with target x in the TensorRT Interpreter
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Adding input to in-progress INetwork: x [shape=[10000000], dtype=DataType.HALF]
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Ran node x with properties: Inputs: () | Outputs: (x: (10000000,)@torch.float16)
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Running node y, a placeholder node with target y in the TensorRT Interpreter
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Adding input to in-progress INetwork: y [shape=[10000000], dtype=DataType.HALF]
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Ran node y with properties: Inputs: () | Outputs: (y: (10000000,)@torch.float16)
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Running node __/dot, a call_function node with target aten.dot.default in the TensorRT Interpreter
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node __/dot (kind: aten.dot.default, args: ('x <tensorrt.ITensor [shape=(10000000,), dtype=DataType.HALF]>', 'y <tensorrt.ITensor [shape=(10000000,), dtype=DataType.HALF]>'))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Ran node __/dot with properties: Inputs: (x: (10000000,)@torch.float16, y: (10000000,)@torch.float16) | Outputs: (dot: ()@torch.float16)
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Running node output, a output node with target output in the TensorRT Interpreter
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Marking output output0 [shape=(), dtype=DataType.HALF]
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Ran node output with properties: Inputs: (dot: ()@torch.float16) | Outputs: (output: )
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT INetwork construction elapsed time: 0:00:00.001952
INFO:torch_tensorrt [TensorRT Conversion Context]:Global timing cache in use. Profiling results in this builder pass will be stored.
INFO:torch_tensorrt [TensorRT Conversion Context]:Detected 2 inputs and 1 output network tensors.
INFO:torch_tensorrt [TensorRT Conversion Context]:Total Host Persistent Memory: 7328
INFO:torch_tensorrt [TensorRT Conversion Context]:Total Device Persistent Memory: 0
INFO:torch_tensorrt [TensorRT Conversion Context]:Total Scratch Memory: 3584
INFO:torch_tensorrt [TensorRT Conversion Context]:[BlockAssignment] Started assigning block shifts. This will take 1 steps to complete.
INFO:torch_tensorrt [TensorRT Conversion Context]:[BlockAssignment] Algorithm ShiftNTopDown took 0.0563ms to assign 1 blocks to 1 nodes requiring 3584 bytes.
INFO:torch_tensorrt [TensorRT Conversion Context]:Total Activation Memory: 3584
INFO:torch_tensorrt [TensorRT Conversion Context]:Total Weights Memory: 0
INFO:torch_tensorrt [TensorRT Conversion Context]:Engine generation completed in 0.0122673 seconds.
INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageStats] Peak memory usage of TRT CPU/GPU memory allocators: CPU 0 MiB, GPU 0 MiB
INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageStats] Peak memory usage during Engine building and serialization: CPU: 3407 MiB
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Build TRT engine elapsed time: 0:00:00.018553
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT Engine uses: 80372 bytes of Memory
INFO:torch_tensorrt [TensorRT Conversion Context]:Serialized 26 bytes of code generator cache.
INFO:torch_tensorrt [TensorRT Conversion Context]:Serialized 117423 bytes of compilation cache.
INFO:torch_tensorrt [TensorRT Conversion Context]:Serialized 265 timing cache entries
DEBUG:torch_tensorrt.dynamo._DryRunTracker:
++++++++++++++++++++++++++++++++++++++++++++++++++ Dry-Run Results for Graph ++++++++++++++++++++++++++++++++++++++++++++++++++

The graph consists of 1 Total Operators, of which 1 operators are supported, 100.0% coverage

Compiled with: CompilationSettings(enabled_precisions={<dtype.f16: 6>}, debug=True, workspace_size=0, min_block_size=1, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, assume_dynamic_shape_support=False, sparse_weights=False, make_refitable=False, engine_capability=<EngineCapability.STANDARD: 1>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False, timing_cache_path='C:\\Users\\HolyWu\\AppData\\Local\\Temp\\timing_cache.bin')

  Graph Structure:

   Inputs: Tuple(Tensor: (10000000,)@float16, Tensor: (10000000,)@float16)
    ...
    TRT Engine #1 - Submodule name: _run_on_acc_0
     Engine Inputs: List[Tensor: (10000000,)@float16, Tensor: (10000000,)@float16]
     Number of Operators in Engine: 1
     Engine Outputs: Tensor: ()@float16
    ...
   Outputs: List[Tensor: ()@float16]

  ------------------------- Aggregate Stats -------------------------

   Average Number of Operators per TRT Engine: 1.0
   Most Operators in a TRT Engine: 1

  ********** Recommendations **********

   - For minimal graph segmentation, select min_block_size=1 which would generate 1 TRT engine(s)
   - The current level of graph segmentation is equivalent to selecting min_block_size=1 which generates 1 TRT engine(s)

Torch:
        Min=0.0002167999991797842, Mean=0.00022596699975110823, Max=0.0004327000024204608

TRT:
        Min=0.00031380000291392207, Mean=0.0003244600002653897, Max=0.0004501000003074296

@peri044
Copy link
Collaborator

peri044 commented Aug 7, 2024

Thanks for the performance analysis and proving the converter is better option here.

One more question. If an operator was once in Core ATen IR but later was removed from Core ATen IR, for example aten.roll and aten.pixel_shuffle, then should the converters for those operators be removed so as to use PyTorch decomposition?
I think it's fine to leave them as long as it doesn't hurt performance

Copy link
Collaborator

@peri044 peri044 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@peri044 peri044 merged commit fdaba9a into pytorch:main Aug 8, 2024
36 of 61 checks passed
@HolyWu HolyWu deleted the dot branch August 8, 2024 23:39
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
cla signed component: api [Python] Issues re: Python API component: conversion Issues re: Conversion stage component: converters Issues re: Specific op converters component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths component: lowering Issues re: The lowering / preprocessing passes component: tests Issues re: Tests
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants