
Commit

Reformatted tensor parallelism
cehongwang committed Aug 22, 2024
1 parent dfc65d0 commit def25b8
Showing 4 changed files with 10 additions and 24 deletions.
10 changes: 5 additions & 5 deletions examples/distributed_inference/llama3_model.py
@@ -3,7 +3,7 @@
 
 
 from dataclasses import dataclass
-from typing import Optional, Tuple
+from typing import Any, Optional, Tuple
 
 import torch
 import torch.nn.functional as F
@@ -195,7 +195,7 @@ def __init__(self, model_args: ModelArgs):
             model_args.n_heads * self.head_dim, model_args.dim, bias=False
         )
 
-    def init_weights(self, init_std: float):
+    def init_weights(self, init_std: float) -> None:
         for linear in (self.wq, self.wk, self.wv):
             nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02)
         nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std)
@@ -204,7 +204,7 @@ def forward(
         self,
         x: torch.Tensor,
         freqs_cis: torch.Tensor,
-    ):
+    ) -> Any:
         """Forward pass of the attention module.
 
         Args:
@@ -275,10 +275,10 @@ def __init__(
         self.w2 = nn.Linear(hidden_dim, dim, bias=False)
         self.w3 = nn.Linear(dim, hidden_dim, bias=False)
 
-    def forward(self, x):
+    def forward(self, x) -> Any:
         return self.w2(F.silu(self.w1(x)) * self.w3(x))
 
-    def init_weights(self, init_std: float):
+    def init_weights(self, init_std: float) -> None:
         nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02)
         for linear in (self.w2, self.w3):
             nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std)
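For orientation, the projections named in this file (wq/wk/wv/wo and w1/w2/w3) are what the companion script shards across GPUs. The snippet below is a minimal sketch of such a tensor-parallel plan, assuming the transformer block exposes those layers under attention. and feed_forward. submodules; the plan actually used lives in tensor_parallel_llama3.py and may differ.

import torch.nn as nn
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


def shard_block(block: nn.Module, tp_mesh) -> nn.Module:
    # Shard input projections column-wise and output projections row-wise,
    # so each rank keeps 1/world_size of every weight matrix.
    plan = {
        "attention.wq": ColwiseParallel(),
        "attention.wk": ColwiseParallel(),
        "attention.wv": ColwiseParallel(),
        "attention.wo": RowwiseParallel(),
        "feed_forward.w1": ColwiseParallel(),
        "feed_forward.w2": RowwiseParallel(),
        "feed_forward.w3": ColwiseParallel(),
    }
    return parallelize_module(block, tp_mesh, plan)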
11 changes: 4 additions & 7 deletions examples/distributed_inference/tensor_parallel_llama3.py
@@ -1,3 +1,6 @@
+# Taken and modified pytorch lightening
+# https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning
+import logging
 import os
 import time
 
@@ -12,9 +15,6 @@
 )
 from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
 
-# Taken and modified pytorch lightening
-# https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning
-import logging
 _rank = int(os.environ["RANK"])
 _world_size = int(os.environ["WORLD_SIZE"])
 tp_size = 2
@@ -25,9 +25,6 @@
 fh.setLevel(logging.INFO)
 logger.addHandler(fh)
 
-# understand world topology
-
-
 tp_mesh = init_device_mesh(device_type="cuda", mesh_shape=(_world_size,))
 
 model_args = ModelArgs(
@@ -56,7 +53,7 @@
         "use_python_runtime": True,
         "workspace_size": 1 << 33,
         "debug": False,
-        "timing_cache_path":"/opt/file/cache/timing_cache_llama.bin"
+        "timing_cache_path": "/opt/file/cache/timing_cache_llama.bin",
     },
     dynamic=False,
 )
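For readers running the example: RANK and WORLD_SIZE in the hunks above are environment variables set by the launcher, and the whole world is folded into a single 1-D device mesh. A minimal sketch of that setup, assuming a single-node launch such as torchrun --nproc_per_node=2 tensor_parallel_llama3.py to match tp_size = 2 in the diff:

import os

import torch
from torch.distributed.device_mesh import init_device_mesh

# torchrun exports RANK and WORLD_SIZE for every process it spawns.
_rank = int(os.environ["RANK"])
_world_size = int(os.environ["WORLD_SIZE"])

# One 1-D mesh spanning all ranks, as in the diff above.
tp_mesh = init_device_mesh(device_type="cuda", mesh_shape=(_world_size,))

# Pin each process to its own GPU (a common convention, not shown in the diff).
torch.cuda.set_device(_rank)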
2 changes: 1 addition & 1 deletion
@@ -77,7 +77,7 @@ def forward(self, x):
         "truncate_long_and_double": True,
         "enabled_precisions": {torch.float32, torch.float16},
         "use_python_runtime": True,
-        "min_block_size": 1
+        "min_block_size": 1,
     },
     dynamic=False,
 )
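The hunk above only shows the backend options; for context, a sketch of the enclosing call, assuming the usual torch.compile entry point with the Torch-TensorRT dynamo backend (the surrounding names are illustrative, not taken from the diff):

import torch
import torch_tensorrt  # noqa: F401  (registers the "torch_tensorrt" backend)


def compile_with_trt(model: torch.nn.Module) -> torch.nn.Module:
    # The options mirror the dict in the hunk above; the trailing comma added
    # by this commit is purely a formatting fix.
    return torch.compile(
        model,
        backend="torch_tensorrt",
        options={
            "truncate_long_and_double": True,
            "enabled_precisions": {torch.float32, torch.float16},
            "use_python_runtime": True,
            "min_block_size": 1,
        },
        dynamic=False,
    )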
11 changes: 0 additions & 11 deletions py/torch_tensorrt/dynamo/_compiler.py
@@ -362,21 +362,10 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
     trt_modules = {}
     # Iterate over all components that can be accelerated
     # Generate the corresponding TRT Module for those
-    logger.info(f"-" * 100)
-    logger.info(f"There are {len(list(partitioned_module.named_children()))} submodules in total.")
-    i = 0
-    import os
     for name, _ in partitioned_module.named_children():
-        # Benchmark log utilities
-        i += 1
-        logger.info(f"-" * 100)
-        logger.info(f"Start compiling {i}th submodule")
-        total = torch.cuda.get_device_properties(0).total_memory
-
         submodule = getattr(partitioned_module, name)
         # Criteria for a module to be convertible to TRT
         if settings.use_fast_partitioner and "_run_on_acc" not in name:
-            # if (settings.use_fast_partitioner and "_run_on_acc" not in name) or int(os.environ["RANK"]) == 1:
             dryrun_tracker.to_run_in_torch.extend(parse_non_trt_nodes(submodule))
             continue
 
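The deleted lines were ad-hoc per-submodule progress prints. If similar visibility is needed while debugging, the same module-level logger they called can be turned up from application code instead; a sketch, assuming the standard Python logging hierarchy used under torch_tensorrt.dynamo:

import logging

# Raise verbosity for the dynamo compiler modules without editing library code.
logging.basicConfig(level=logging.INFO)
logging.getLogger("torch_tensorrt.dynamo").setLevel(logging.DEBUG)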
