From def25b8629b08697680533000d925ee705458b38 Mon Sep 17 00:00:00 2001
From: cehongwang
Date: Thu, 22 Aug 2024 14:42:12 -0700
Subject: [PATCH] Reformatted tensor parallelism

---
 examples/distributed_inference/llama3_model.py        | 10 +++++-----
 .../distributed_inference/tensor_parallel_llama3.py   | 11 ++++-------
 .../tensor_parallel_simple_example.py                 |  2 +-
 py/torch_tensorrt/dynamo/_compiler.py                 | 11 -----------
 4 files changed, 10 insertions(+), 24 deletions(-)

diff --git a/examples/distributed_inference/llama3_model.py b/examples/distributed_inference/llama3_model.py
index 11bd51fe3b..9fa59b5c49 100644
--- a/examples/distributed_inference/llama3_model.py
+++ b/examples/distributed_inference/llama3_model.py
@@ -3,7 +3,7 @@
 
 from dataclasses import dataclass
-from typing import Optional, Tuple
+from typing import Any, Optional, Tuple
 
 import torch
 import torch.nn.functional as F
@@ -195,7 +195,7 @@ def __init__(self, model_args: ModelArgs):
             model_args.n_heads * self.head_dim, model_args.dim, bias=False
         )
 
-    def init_weights(self, init_std: float):
+    def init_weights(self, init_std: float) -> None:
         for linear in (self.wq, self.wk, self.wv):
             nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02)
         nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std)
@@ -204,7 +204,7 @@ def forward(
         self,
         x: torch.Tensor,
         freqs_cis: torch.Tensor,
-    ):
+    ) -> Any:
         """Forward pass of the attention module.
 
         Args:
@@ -275,10 +275,10 @@ def __init__(
         self.w2 = nn.Linear(hidden_dim, dim, bias=False)
         self.w3 = nn.Linear(dim, hidden_dim, bias=False)
 
-    def forward(self, x):
+    def forward(self, x) -> Any:
         return self.w2(F.silu(self.w1(x)) * self.w3(x))
 
-    def init_weights(self, init_std: float):
+    def init_weights(self, init_std: float) -> None:
         nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02)
         for linear in (self.w2, self.w3):
             nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std)
diff --git a/examples/distributed_inference/tensor_parallel_llama3.py b/examples/distributed_inference/tensor_parallel_llama3.py
index 180d8c0724..fc03a64386 100644
--- a/examples/distributed_inference/tensor_parallel_llama3.py
+++ b/examples/distributed_inference/tensor_parallel_llama3.py
@@ -1,3 +1,6 @@
+# Taken and modified pytorch lightening
+# https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning
+import logging
 import os
 import time
 
@@ -12,9 +15,6 @@
 )
 from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
 
-# Taken and modified pytorch lightening
-# https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning
-import logging
 _rank = int(os.environ["RANK"])
 _world_size = int(os.environ["WORLD_SIZE"])
 tp_size = 2
@@ -25,9 +25,6 @@
 fh.setLevel(logging.INFO)
 logger.addHandler(fh)
 
-# understand world topology
-
-
 tp_mesh = init_device_mesh(device_type="cuda", mesh_shape=(_world_size,))
 
 model_args = ModelArgs(
@@ -56,7 +53,7 @@
         "use_python_runtime": True,
         "workspace_size": 1 << 33,
         "debug": False,
-        "timing_cache_path":"/opt/file/cache/timing_cache_llama.bin"
+        "timing_cache_path": "/opt/file/cache/timing_cache_llama.bin",
     },
     dynamic=False,
 )
diff --git a/examples/distributed_inference/tensor_parallel_simple_example.py b/examples/distributed_inference/tensor_parallel_simple_example.py
index a228a7f636..470487a751 100755
--- a/examples/distributed_inference/tensor_parallel_simple_example.py
+++ b/examples/distributed_inference/tensor_parallel_simple_example.py
@@ -77,7 +77,7 @@ def forward(self, x):
         "truncate_long_and_double": True,
         "enabled_precisions": {torch.float32, torch.float16},
         "use_python_runtime": True,
-        "min_block_size": 1
+        "min_block_size": 1,
     },
     dynamic=False,
 )
diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py
index 27d46390a2..0c29bd378e 100644
--- a/py/torch_tensorrt/dynamo/_compiler.py
+++ b/py/torch_tensorrt/dynamo/_compiler.py
@@ -362,21 +362,10 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
     trt_modules = {}
     # Iterate over all components that can be accelerated
    # Generate the corresponding TRT Module for those
-    logger.info(f"-" * 100)
-    logger.info(f"There are {len(list(partitioned_module.named_children()))} submodules in total.")
-    i = 0
-    import os
     for name, _ in partitioned_module.named_children():
-        # Benchmark log utilities
-        i += 1
-        logger.info(f"-" * 100)
-        logger.info(f"Start compiling {i}th submodule")
-        total = torch.cuda.get_device_properties(0).total_memory
-
         submodule = getattr(partitioned_module, name)
         # Criteria for a module to be convertible to TRT
         if settings.use_fast_partitioner and "_run_on_acc" not in name:
-            # if (settings.use_fast_partitioner and "_run_on_acc" not in name) or int(os.environ["RANK"]) == 1:
             dryrun_tracker.to_run_in_torch.extend(parse_non_trt_nodes(submodule))
             continue