chore: Bug fix / doc updates for tensor parallel examples #3378

Open
wants to merge 3 commits into base: release/2.6

47 changes: 22 additions & 25 deletions examples/distributed_inference/README.md
@@ -2,49 +2,46 @@

Examples in this folder demonstrate distributed inference on multiple devices with the Torch-TensorRT backend.

1. Data parallel distributed inference based on [Accelerate](https://huggingface.co/docs/accelerate/usage_guides/distributed_inference)
## Data parallel distributed inference based on [Accelerate](https://huggingface.co/docs/accelerate/usage_guides/distributed_inference)

Using Accelerate, users can achieve data parallel distributed inference with the Torch-TensorRT backend. In this case, the entire model
will be loaded onto each GPU and different chunks of the input batch are processed on each device.

See the examples started with `data_parallel` for more details.
See the examples [data_parallel_gpt2.py](https://github.com/pytorch/TensorRT/blob/main/examples/distributed_inference/data_parallel_gpt2.py) and [data_parallel_stable_diffusion.py](https://github.com/pytorch/TensorRT/blob/main/examples/distributed_inference/data_parallel_stable_diffusion.py) for more details.
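
For orientation, here is a rough sketch of the data parallel pattern those examples follow (not a verbatim copy of them). `MyModel`, the input shapes, and the compile settings are placeholders you would replace with your own model and configuration:

```py
import torch
from accelerate import PartialState

distributed_state = PartialState()  # one process per GPU (e.g. launched via `accelerate launch`)

# The full model is replicated on every device; MyModel is a placeholder nn.Module
model = MyModel().eval().to(distributed_state.device)

# Each replica is compiled with the Torch-TensorRT backend
model = torch.compile(model, backend="torch_tensorrt", dynamic=False)

# Each process receives a different chunk of the batch
batch = [torch.randn(1, 3, 224, 224) for _ in range(8)]
with distributed_state.split_between_processes(batch) as shard:
    with torch.no_grad():
        outputs = [model(x.to(distributed_state.device)) for x in shard]
```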

2. Tensor parallel distributed inference
## Tensor parallel distributed inference

Here we use torch.distributed as an example, but compilation with tensor parallelism is agnostic to the implementation framework as long as the module is properly sharded.

torchrun --nproc_per_node=2 tensor_parallel_llama2.py
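
As a rough sketch of that workflow (the tiny `ToyModel` below is a placeholder and the compile options are deliberately minimal; see `tensor_parallel_llama2.py` for the real setup), sharding with `torch.distributed` and then compiling looks roughly like this:

```py
import os

import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


class ToyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.in_proj = nn.Linear(16, 32)
        self.relu = nn.ReLU()
        self.out_proj = nn.Linear(32, 16)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.out_proj(self.relu(self.in_proj(x)))


# torchrun sets these variables; one GPU per process
torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0)))
mesh = init_device_mesh("cuda", (int(os.environ["WORLD_SIZE"]),))

# Shard the first projection column-wise and the second row-wise across the mesh
tp_model = parallelize_module(
    ToyModel().cuda(),
    mesh,
    {"in_proj": ColwiseParallel(), "out_proj": RowwiseParallel()},
)

# Compilation sees the already-sharded module, so the backend call itself is unchanged
tp_model = torch.compile(tp_model, backend="torch_tensorrt", dynamic=False)
out = tp_model(torch.randn(4, 16, device="cuda"))
```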

3. Tensor parallel distributed inference using nccl ops plugin
## Tensor parallel distributed inference on a simple model using nccl ops plugin

apt install libmpich-dev

We use the [torch.distributed](https://pytorch.org/docs/stable/distributed.html) package to shard the model with tensor parallelism. The distributed ops (`all_gather` and `all_reduce`) are then expressed as TensorRT-LLM plugins to avoid graph breaks during Torch-TensorRT compilation. The [converters for these operators](https://github.com/pytorch/TensorRT/blob/main/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py#L25-L55) are already available in Torch-TensorRT. The functional implementation of the ops is imported from the `tensorrt_llm` package (to be more specific, only `libnvinfer_plugin_tensorrt_llm.so` is required). So we have two options here:
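
In practice, "loading the plugin library" boils down to making `libnvinfer_plugin_tensorrt_llm.so` visible in the process. The snippet below is a simplification of what Torch-TensorRT's internal `load_tensorrt_llm()` helper does, not a copy of it; the path shown is a placeholder:

```py
import ctypes
import os

# Placeholder path; in practice it comes from the tensorrt_llm install or TRTLLM_PLUGINS_PATH
plugin_lib = os.environ.get(
    "TRTLLM_PLUGINS_PATH", "/path/to/tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so"
)
handle = ctypes.CDLL(plugin_lib)  # dlopen the TensorRT-LLM plugin library
# The real helper additionally invokes the library's plugin initialization entry point so
# that the NCCL plugin creators get registered with TensorRT's plugin registry; after that,
# the all_gather / all_reduce converters can resolve the plugins by name.
```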

apt install libopenmpi-dev
### Option 1: Install TensorRT-LLM

#For python3.10
Follow the instructions to [install TensorRT-LLM](https://nvidia.github.io/TensorRT-LLM/installation/linux.html)

pip install tensorrt-llm
If the default installation fails due to issues such as library version mismatches or Python incompatibility, it is recommended to use Option 2. After a successful installation, test that `import torch_tensorrt` works without errors. The import might fail if the `tensorrt_llm` installation overrides `torch_tensorrt` dependencies. Option 2 is also advisable if you prefer not to install `tensorrt_llm` and its associated dependencies.
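
As a quick sanity check after Option 1 (run inside the environment where `tensorrt-llm` was installed; this just mirrors the advice above and is not an official test):

```py
import tensorrt_llm  # noqa: F401  loads libnvinfer_plugin_tensorrt_llm.so as a side effect
import torch_tensorrt  # a failure here usually means tensorrt_llm overrode a torch/tensorrt dependency

print(torch_tensorrt.__version__)
```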

For other python versions, you need to load the libnvinfer_plugin_tensorrt_llm.so. Please set that in the environment variable export TRTLLM_PLUGINS_PATH={lib_path}. For example, we have already set the variable in initialize_distributed_env(). You can replace this with your TRTLLM_PLUGINS_PATH and unset it there
### Option 2: Link the TensorRT-LLM plugin library directly

#then pip install the tensorrt and torch version compatible with installed torchTRT
Another alternative is to load `libnvinfer_plugin_tensorrt_llm.so` directly (a scripted sketch of these steps follows the list). You can do this by:
* Downloading the [tensorrt_llm-0.16.0](https://pypi.nvidia.com/tensorrt-llm/tensorrt_llm-0.16.0-cp310-cp310-linux_x86_64.whl#sha256=f86c6b89647802f49b26b4f6e40824701da14c0f053dbda3e1e7a8709d6939c7) wheel file from the NVIDIA python index.
* Extracting the wheel file to a directory; the `libnvinfer_plugin_tensorrt_llm.so` library can be found under the `tensorrt_llm/libs` directory.
* Setting the environment variable `TRTLLM_PLUGINS_PATH` to the extracted library path in the [initialize_distributed_env()](https://github.com/pytorch/TensorRT/blob/54e36dbafe567c75f36b3edb22d6f49d4278c12a/examples/distributed_inference/tensor_parallel_initialize_dist.py#L45) call.
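
The following is a sketch of those steps in script form. It assumes Python 3.10 on linux_x86_64 and that `TRTLLM_PLUGINS_PATH` should point at the extracted `.so` file itself; double-check the expectation in `initialize_distributed_env()` before relying on it:

```py
import os
import urllib.request
import zipfile

wheel = "tensorrt_llm-0.16.0-cp310-cp310-linux_x86_64.whl"
urllib.request.urlretrieve(f"https://pypi.nvidia.com/tensorrt-llm/{wheel}", wheel)

# A wheel is a zip archive; extracting it exposes tensorrt_llm/libs/
with zipfile.ZipFile(wheel) as whl:
    whl.extractall("tensorrt_llm_wheel")

plugin_lib = os.path.abspath(
    "tensorrt_llm_wheel/tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so"
)
os.environ["TRTLLM_PLUGINS_PATH"] = plugin_lib  # or export it in your shell before running mpirun
```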

mpirun -n 2 --allow-run-as-root python tensor_parallel_simple_example.py

#For other python
After configuring TensorRT-LLM or the TensorRT-LLM plugin library path, run the following command, which illustrates tensor parallelism of a simple model and its compilation with Torch-TensorRT:

4. Tensor parallel distributed llama3 inference using nccl ops plugin
```py
mpirun -n 2 --allow-run-as-root python tensor_parallel_simple_example.py
```

apt install libmpich-dev
We also provide a tensor parallelism compilation example on a more advanced model like `Llama-3`. Here's the command to run it:

apt install libopenmpi-dev

#For python3.10

pip install tensorrt-llm

For other python versions, you need to load the libnvinfer_plugin_tensorrt_llm.so

#then pip install the tensorrt and torch version compatible with installed torchTRT

mpirun -n 2 --allow-run-as-root python tensor_parallel_llama3.py
```py
mpirun -n 2 --allow-run-as-root python tensor_parallel_llama3.py
```
4 changes: 2 additions & 2 deletions examples/distributed_inference/tensor_parallel_llama3.py
@@ -17,8 +17,8 @@
device_mesh, _world_size, _rank, logger = initialize_distributed_env(
"./tensor_parallel_llama3"
)
# Import should be after initialization of the TRT-LLM plugin .so path
import tensorrt_llm

import torch_tensorrt

logger.info(f"Starting PyTorch TP example on rank {_rank}.")
assert (
examples/distributed_inference/tensor_parallel_simple_example.py
@@ -15,7 +15,6 @@
device_mesh, _world_size, _rank, logger = initialize_distributed_env(
"./tensor_parallel_simple_example"
)
import tensorrt_llm

"""
This example copies some code from https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/tensor_parallel_example.py
Expand Down
15 changes: 8 additions & 7 deletions py/torch_tensorrt/dynamo/backend/backends.py
@@ -10,6 +10,7 @@
from torch._dynamo.backends.common import aot_autograd
from torch._dynamo.utils import detect_fake_mode
from torch._functorch.aot_autograd import aot_export_joint_simple
from torch._ops import OpOverload
from torch_tensorrt.dynamo import CompilationSettings
from torch_tensorrt.dynamo._compiler import compile_module
from torch_tensorrt.dynamo.lowering import (
@@ -59,17 +60,17 @@ def aot_torch_tensorrt_aten_backend(
_pretraced_backend, settings=settings, engine_cache=engine_cache
)
settings_aot_autograd = {}
settings_aot_autograd["decompostions"] = get_decompositions(
settings_aot_autograd["decompositions"] = get_decompositions(
settings.enable_experimental_decompositions
)
# This is added since detach lowering leads to alias nodes
# Error - View operation returned a tensor that is the same as the input base tensor
# torch nop_decompositions in torch/_decomp/decompositions.py
if aten.detach in settings_aot_autograd["decompositions"]:
del settings_aot_autograd["decompositions"][aten.detach]
# transpose key deleted since not desirable to lower it to permute
for key in settings_aot_autograd["decompositions"]:
if "transpose" in key._name:
to_delete = key
del settings_aot_autograd["decompositions"][to_delete]
return aot_autograd(
fw_compiler=_pretraced_backend_autograd,
decompositions=get_decompositions(settings.enable_experimental_decompositions),
decompositions=settings_aot_autograd["decompositions"],
)(gm, sample_inputs)


8 changes: 3 additions & 5 deletions py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py
@@ -3,6 +3,7 @@
import logging
from typing import Dict, Sequence, Tuple, Union

import tensorrt as trt
from torch.fx.node import Argument, Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion import impl
@@ -16,8 +17,6 @@
tensorrt_fused_nccl_reduce_scatter_op,
)

import tensorrt as trt

_LOGGER: logging.Logger = logging.getLogger(__name__)

if load_tensorrt_llm():
@@ -30,7 +29,7 @@ def fused_nccl_gather(
kwargs: Dict[str, Argument],
name: str,
) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
return impl.distributed.nccl_gather(
return impl.nccl_ops.nccl_gather(
ctx,
target,
SourceIR.ATEN,
@@ -46,15 +45,14 @@ def fused_nccl_reduce_scatter(
kwargs: Dict[str, Argument],
name: str,
) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
return impl.distributed.nccl_reduce_scatter(
return impl.nccl_ops.nccl_reduce_scatter(
ctx,
target,
SourceIR.ATEN,
name,
[args[0]],
)

breakpoint()
else:
_LOGGER.debug(
"Did not load torch.distributed converters since TensorRT-LLM is not available"
@@ -106,7 +106,12 @@ def update_node_meta(node: torch.fx.Node, fake_mode: FakeTensorMode) -> None:

if op_target in shape_inference_funcs:
new_shape = shape_inference_funcs[op_target](node)
real_tensor = torch.empty(new_shape, dtype=node.meta["val"].dtype)
new_node_dtype = None
if node.meta["val"].dtype == torch.complex64:
new_node_dtype = torch.float32
else:
new_node_dtype = torch.float64
real_tensor = torch.empty(new_shape, dtype=new_node_dtype)
node.meta["val"] = fake_mode.from_tensor(real_tensor)
else:
print("No shape for the inference function", {op_name})
12 changes: 6 additions & 6 deletions py/torch_tensorrt/dynamo/lowering/passes/fuse_distributed_ops.py
@@ -49,7 +49,6 @@ def fuse_distributed_ops(
== torch.ops._c10d_functional.wait_tensor.default
):
wait_tensor_node = list(node.users)[0]
fused_op = None
if node.target == torch.ops._c10d_functional.all_gather_into_tensor.default:
with gm.graph.inserting_after(wait_tensor_node):
fused_node = gm.graph.create_node(
@@ -58,11 +57,12 @@
args=(node.args[0], node.args[1], node.args[2]),
)
else:
fused_node = gm.graph.create_node(
op="call_function",
target=tensorrt_fused_nccl_reduce_scatter_op, # Define your custom fused function
args=(node.args[0], node.args[1], node.args[2], node.args[3]),
)
with gm.graph.inserting_after(wait_tensor_node):
fused_node = gm.graph.create_node(
op="call_function",
target=tensorrt_fused_nccl_reduce_scatter_op, # Define your custom fused function
args=(node.args[0], node.args[1], node.args[2], node.args[3]),
)

wait_tensor_node.replace_all_uses_with(fused_node)
fused_node.meta.update(node.meta)
@@ -364,6 +364,15 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
(i.contiguous() if isinstance(i, torch.Tensor) else torch.tensor(i).cuda())
for i in inputs
]

for i, contiguous_input in enumerate(contiguous_inputs):
if contiguous_input.dtype == torch.complex64:
contiguous_input_real = contiguous_input.real
contiguous_input_imag = contiguous_input.imag
contiguous_inputs[i] = torch.stack(
(contiguous_input_real, contiguous_input_imag), dim=-1
)

with (
torch.autograd.profiler.record_function("PythonTorchTensorRTModule:Forward")
if self.profiling_enabled
Expand Down