Worker init
Yard1 committed Jul 19, 2024
1 parent 7bdf67c commit 5f0b812
Showing 3 changed files with 46 additions and 33 deletions.
33 changes: 17 additions & 16 deletions vllm/executor/ray_gpu_executor.py
@@ -77,6 +77,20 @@ def _configure_ray_workers_use_nsight(self,
 
         return ray_remote_kwargs
 
+    def _get_worker_wrapper_args(self) -> Dict[str, Any]:
+        if self.speculative_config is not None:
+            worker_module_name = "vllm.spec_decode.spec_decode_worker"
+            worker_class_name = "create_spec_worker"
+        else:
+            worker_module_name = "vllm.worker.worker"
+            worker_class_name = "Worker"
+
+        return dict(
+            worker_module_name=worker_module_name,
+            worker_class_name=worker_class_name,
+            trust_remote_code=self.model_config.trust_remote_code,
+        )
+
     def _init_workers_ray(self, placement_group: "PlacementGroup",
                           **ray_remote_kwargs):
         if (self.parallel_config.tensor_parallel_size == 1
@@ -99,6 +113,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
 
         # Create the workers.
         driver_ip = get_ip()
+        worker_wrapper_kwargs = self._get_worker_wrapper_args()
         for bundle_id, bundle in enumerate(placement_group.bundle_specs):
             if not bundle.get("GPU", 0):
                 continue
@@ -108,23 +123,12 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
                 placement_group_bundle_index=bundle_id,
             )
 
-            if self.speculative_config is not None:
-                worker_module_name = "vllm.spec_decode.spec_decode_worker"
-                worker_class_name = "create_spec_worker"
-            else:
-                worker_module_name = "vllm.worker.worker"
-                worker_class_name = "Worker"
-
             worker = ray.remote(
                 num_cpus=0,
                 num_gpus=num_gpus,
                 scheduling_strategy=scheduling_strategy,
                 **ray_remote_kwargs,
-            )(RayWorkerWrapper).remote(
-                worker_module_name=worker_module_name,
-                worker_class_name=worker_class_name,
-                trust_remote_code=self.model_config.trust_remote_code,
-            )
+            )(RayWorkerWrapper).remote(**worker_wrapper_kwargs)
 
             if self.use_ray_spmd_worker:
                 self.workers.append(worker)
@@ -135,10 +139,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
                     # as the resource holder for the driver process.
                     self.driver_dummy_worker = worker
                     self.driver_worker = RayWorkerWrapper(
-                        worker_module_name=worker_module_name,
-                        worker_class_name=worker_class_name,
-                        trust_remote_code=self.model_config.trust_remote_code,
-                    )
+                        **worker_wrapper_kwargs)
                 else:
                     # Else, added to the list of workers.
                     self.workers.append(worker)
20 changes: 10 additions & 10 deletions vllm/executor/ray_xpu_executor.py
@@ -109,6 +109,13 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
 
         return num_gpu_blocks, num_cpu_blocks
 
+    def _get_worker_wrapper_args(self) -> Dict[str, Any]:
+        return dict(
+            worker_module_name="vllm.worker.xpu_worker",
+            worker_class_name="XPUWorker",
+            trust_remote_code=self.model_config.trust_remote_code,
+        )
+
     def _init_workers_ray(self, placement_group: "PlacementGroup",
                           **ray_remote_kwargs):
         if self.parallel_config.tensor_parallel_size == 1:
@@ -126,6 +133,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
 
         # Create the workers.
         driver_ip = get_ip()
+        worker_wrapper_kwargs = self._get_worker_wrapper_args()
         for bundle_id, bundle in enumerate(placement_group.bundle_specs):
             if not bundle.get("GPU", 0):
                 continue
@@ -139,22 +147,14 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
                 num_gpus=num_gpus,
                 scheduling_strategy=scheduling_strategy,
                 **ray_remote_kwargs,
-            )(RayWorkerWrapper).remote(
-                worker_module_name="vllm.worker.xpu_worker",
-                worker_class_name="XPUWorker",
-                trust_remote_code=self.model_config.trust_remote_code,
-            )
+            )(RayWorkerWrapper).remote(**worker_wrapper_kwargs)
 
             worker_ip = ray.get(worker.get_node_ip.remote())
             if worker_ip == driver_ip and self.driver_dummy_worker is None:
                 # If the worker is on the same node as the driver, we use it
                 # as the resource holder for the driver process.
                 self.driver_dummy_worker = worker
-                self.driver_worker = RayWorkerWrapper(
-                    worker_module_name="vllm.worker.xpu_worker",
-                    worker_class_name="XPUWorker",
-                    trust_remote_code=self.model_config.trust_remote_code,
-                )
+                self.driver_worker = RayWorkerWrapper(**worker_wrapper_kwargs)
             else:
                 # Else, added to the list of workers.
                 self.workers.append(worker)
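Both executor diffs above follow the same pattern: the wrapper keyword arguments are built once, before the worker-creation loop, and the same dict is reused for every remote Ray actor and for the in-process driver wrapper; each back end customizes the worker class only by overriding _get_worker_wrapper_args(). A minimal standalone sketch of that override pattern, with simplified stand-in classes rather than the actual vLLM executors:

from typing import Any, Dict


class BaseExecutorSketch:
    """Illustrative stand-in for the Ray GPU executor (not the vLLM class)."""

    def _get_worker_wrapper_args(self) -> Dict[str, Any]:
        # Default worker, used when no speculative decoding is configured.
        return dict(worker_module_name="vllm.worker.worker",
                    worker_class_name="Worker")

    def create_workers(self, num_workers: int) -> None:
        # Built once, outside the loop, then reused for every worker wrapper.
        wrapper_kwargs = self._get_worker_wrapper_args()
        self.workers = [dict(rank=rank, **wrapper_kwargs)
                        for rank in range(num_workers)]
        self.driver_worker = dict(rank=0, **wrapper_kwargs)


class XPUExecutorSketch(BaseExecutorSketch):
    """Mirrors ray_xpu_executor.py: same creation loop, different worker class."""

    def _get_worker_wrapper_args(self) -> Dict[str, Any]:
        return dict(worker_module_name="vllm.worker.xpu_worker",
                    worker_class_name="XPUWorker")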
26 changes: 19 additions & 7 deletions vllm/worker/worker_base.py
@@ -2,7 +2,7 @@
 import importlib
 import os
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
 
 import torch
 
@@ -315,14 +315,23 @@ class WorkerWrapperBase:
     We first instantiate the WorkerWrapper, which remembers the worker module
     and class name. Then, when we call `update_environment_variables`, and the
     real initialization happens in `init_worker`.
+
+    If worker_class_fn is specified, it will be executed to get the worker
+    class.
+    Otherwise, the worker class will be obtained by dynamically importing it
+    using worker_module_name and worker_class_name.
     """
 
-    def __init__(self,
-                 worker_module_name: str,
-                 worker_class_name: str,
-                 trust_remote_code: bool = False) -> None:
+    def __init__(
+            self,
+            worker_module_name: str,
+            worker_class_name: str,
+            trust_remote_code: bool = False,
+            worker_class_fn: Optional[Callable[[],
+                                               Type[WorkerBase]]] = None) -> None:
         self.worker_module_name = worker_module_name
         self.worker_class_name = worker_class_name
+        self.worker_class_fn = worker_class_fn
         self.worker: Optional[WorkerBase] = None
         if trust_remote_code:
             # note: lazy import to avoid importing torch before initializing
@@ -348,8 +357,11 @@ def init_worker(self, *args, **kwargs):
         # see https://github.com/NVIDIA/nccl/issues/1234
         os.environ['NCCL_CUMEM_ENABLE'] = '0'
 
-        mod = importlib.import_module(self.worker_module_name)
-        worker_class = getattr(mod, self.worker_class_name)
+        if self.worker_class_fn:
+            worker_class = self.worker_class_fn()
+        else:
+            mod = importlib.import_module(self.worker_module_name)
+            worker_class = getattr(mod, self.worker_class_name)
 
         self.worker = worker_class(*args, **kwargs)
         assert self.worker is not None
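With the new worker_class_fn parameter, WorkerWrapperBase can resolve the worker class in two ways: by dynamically importing worker_module_name/worker_class_name as before, or by calling a function that returns the class directly, which init_worker() now checks first. A hedged usage sketch based only on the signature and behavior shown in this diff (construction only; the kwargs later passed to init_worker are the worker's own constructor arguments and are omitted here):

from vllm.worker.worker_base import WorkerWrapperBase


# Path A: resolve the worker class by dotted path, which is what the Ray
# executors pass via _get_worker_wrapper_args().
wrapper_by_name = WorkerWrapperBase(
    worker_module_name="vllm.worker.worker",
    worker_class_name="Worker",
    trust_remote_code=False,
)


# Path B: hand over a callable that returns the class; init_worker() checks
# worker_class_fn before falling back to importlib. The module/class names
# are still required by the signature but go unused on this path.
def _pick_worker_class():
    from vllm.worker.worker import Worker
    return Worker


wrapper_by_fn = WorkerWrapperBase(
    worker_module_name="vllm.worker.worker",  # ignored when worker_class_fn is set
    worker_class_name="Worker",               # ignored when worker_class_fn is set
    worker_class_fn=_pick_worker_class,
)

# In either case the real worker is only constructed later:
# wrapper_by_name.init_worker(**worker_init_kwargs)  # kwargs = the worker's own args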
