From 6550bcb82fd96bdae8189cd47f4f71cfd79faa1c Mon Sep 17 00:00:00 2001
From: Thomas Parnell
Date: Fri, 17 Jan 2025 13:27:00 +0000
Subject: [PATCH] Resolve issue with TP=1

Signed-off-by: Thomas Parnell
---
 vllm/worker/spyre_worker.py | 24 +++++++++++++-----------
 1 file changed, 13 insertions(+), 11 deletions(-)

diff --git a/vllm/worker/spyre_worker.py b/vllm/worker/spyre_worker.py
index 9fd5f4b60..0a89871e4 100644
--- a/vllm/worker/spyre_worker.py
+++ b/vllm/worker/spyre_worker.py
@@ -58,13 +58,6 @@ def __init__(
 
     def init_distributed_environment(self) -> None:
         """Initialize the distributed environment."""
 
-        init_distributed_environment(
-            world_size=self.parallel_config.world_size,
-            rank=self.rank,
-            distributed_init_method="env://",
-            backend="gloo",
-        )
-
         torch._C._distributed_c10d._register_process_group(
             "default", dist.group.WORLD)
@@ -77,10 +70,6 @@ def init_distributed_environment(self) -> None:
         # A small all_reduce for warmup.
         torch.distributed.all_reduce(torch.zeros(1).cpu())
 
-        ensure_model_parallel_initialized(
-            self.parallel_config.tensor_parallel_size,
-            self.parallel_config.pipeline_parallel_size,
-        )
 
 
     def init_device(self) -> None:
@@ -90,6 +79,14 @@ def init_device(self) -> None:
                 LoadEndianness.LITTLE)
 
         if not self._env_initialized:
+
+            init_distributed_environment(
+                world_size=self.parallel_config.world_size,
+                rank=self.rank,
+                distributed_init_method="env://",
+                backend="gloo",
+            )
+
             if self.parallel_config.world_size > 1:
                 self.init_distributed_environment()
             elif envs.VLLM_SPYRE_DYNAMO_BACKEND in [
@@ -97,6 +94,11 @@ def init_device(self) -> None:
             ]:
                 spyre_setup.spyre_setup(rank=0, world_size=1, verbose=True)
 
+            ensure_model_parallel_initialized(
+                self.parallel_config.tensor_parallel_size,
+                self.parallel_config.pipeline_parallel_size,
+            )
+
             self._env_initialized = True
 
         # Set random seed.
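
Reviewer note: the sketch below is not part of the patch. It is a minimal,
standalone illustration (plain torch.distributed, no vLLM imports) of why the
moved calls matter: before this change, init_distributed_environment() and
ensure_model_parallel_initialized() ran only inside the worker's own
init_distributed_environment() method, which is reached only on the
world_size > 1 path, so a TP=1 worker never set up the default process group
or the model-parallel groups. Moving both calls into init_device() makes them
run unconditionally. The demo function name and the MASTER_ADDR/MASTER_PORT
defaults below are illustrative assumptions, not taken from the vLLM code.

import os

import torch
import torch.distributed as dist


def init_device(world_size: int = 1, rank: int = 0) -> None:
    # Always initialize the distributed environment, even for TP=1, so that
    # collectives and process-group lookups work on the single-rank path.
    # This mirrors the patch, which hoists init_distributed_environment()
    # out of the world_size > 1 branch.
    if not dist.is_initialized():
        os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
        os.environ.setdefault("MASTER_PORT", "29500")
        dist.init_process_group(backend="gloo",
                                world_size=world_size,
                                rank=rank)

    if world_size > 1:
        # Multi-rank-only setup stays behind the guard, e.g. the small
        # warmup all_reduce from the original init_distributed_environment().
        dist.all_reduce(torch.zeros(1))

    # In the patch, ensure_model_parallel_initialized(tp_size, pp_size) now
    # runs here on *both* paths, so a TP=1 worker also gets its
    # model-parallel groups.


if __name__ == "__main__":
    init_device()
    print("initialized:", dist.is_initialized(),
          "rank:", dist.get_rank(), "world:", dist.get_world_size())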