Skip to content

Commit

Permalink
Resolve issue with TP=1
Browse files — Browse the repository at this point in the history
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
• Loading branch information
tdoublep committed Jan 17, 2025
1 parent 320ac3c commit 6550bcb
Showing 1 changed file with 13 additions and 11 deletions.
24 changes: 13 additions & 11 deletions vllm/worker/spyre_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,6 @@ def __init__(
def init_distributed_environment(self) -> None:
"""Initialize the distributed environment."""

init_distributed_environment(
world_size=self.parallel_config.world_size,
rank=self.rank,
distributed_init_method="env://",
backend="gloo",
)

torch._C._distributed_c10d._register_process_group(
"default", dist.group.WORLD)

Expand All @@ -77,10 +70,6 @@ def init_distributed_environment(self) -> None:
# A small all_reduce for warmup.
torch.distributed.all_reduce(torch.zeros(1).cpu())

ensure_model_parallel_initialized(
self.parallel_config.tensor_parallel_size,
self.parallel_config.pipeline_parallel_size,
)

def init_device(self) -> None:

Expand All @@ -90,13 +79,26 @@ def init_device(self) -> None:
LoadEndianness.LITTLE)

if not self._env_initialized:

init_distributed_environment(
world_size=self.parallel_config.world_size,
rank=self.rank,
distributed_init_method="env://",
backend="gloo",
)

if self.parallel_config.world_size > 1:
self.init_distributed_environment()
elif envs.VLLM_SPYRE_DYNAMO_BACKEND in [
"sendnn", "sendnn_decoder"
]:
spyre_setup.spyre_setup(rank=0, world_size=1, verbose=True)

ensure_model_parallel_initialized(
self.parallel_config.tensor_parallel_size,
self.parallel_config.pipeline_parallel_size,
)

self._env_initialized = True

# Set random seed.
Expand Down

0 comments on commit 6550bcb

Please sign in to comment.