[core] LLM.collective_rpc interface and RLHF example #12084

Merged Jan 16, 2025 · 28 commits · changes shown from 18 commits
2 changes: 2 additions & 0 deletions .buildkite/test-pipeline.yaml
@@ -126,7 +126,9 @@ steps:
- tests/distributed
- tests/spec_decode/e2e/test_integration_dist_tp4
- tests/compile
- examples/offline_inference/rlhf.py
commands:
- pytest -v -s ../examples/offline_inference/rlhf.py
- pytest -v -s distributed/test_utils.py
- pytest -v -s compile/test_basic_correctness.py
- pytest -v -s distributed/test_pynccl.py
152 changes: 152 additions & 0 deletions examples/offline_inference/rlhf.py
@@ -0,0 +1,152 @@
# A simple demonstration of RLHF with vLLM: a training process broadcasts its
# updated weights to a Ray-hosted vLLM inference engine over NCCL.
import os

import ray
import torch
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from transformers import AutoModelForCausalLM

from vllm import LLM, SamplingParams, configure_as_vllm_process
from vllm.utils import get_ip, get_open_port
from vllm.worker.worker import Worker


# Recommended way to create data-plane communication
# between external training processes and vLLM workers.
def stateless_init_process_group(master_address, master_port, rank, world_size,
device):
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.utils import StatelessProcessGroup
pg = StatelessProcessGroup.create(host=master_address,
port=master_port,
rank=rank,
world_size=world_size)
pynccl = PyNcclCommunicator(pg, device=device)
return pynccl


# Inference code: subclass Worker to add methods callable via collective_rpc.
class MyWorker(Worker):

def init_weight_update_group(self, master_address, master_port,
rank_offset, world_size):
from vllm.distributed.parallel_state import get_world_group
rank = get_world_group().rank + rank_offset
self.model_update_group = stateless_init_process_group(
master_address,
master_port,
rank,
world_size,
self.device,
)

def update_weight(self, name, dtype, shape):
weight = torch.empty(shape, dtype=dtype, device="cuda")
self.model_update_group.broadcast(weight,
src=0,
stream=torch.cuda.current_stream())

self.model_runner.model.load_weights(weights=[(name, weight)])

del weight

def get_weight_square_sum(self):
sum_value = 0.0
for name, p in self.model_runner.model.named_parameters():
sum_value += p.square().sum().item()
return sum_value


class MyLLM(LLM):

def __init__(self, *args, **kwargs):
        # Stop Ray from manipulating CUDA_VISIBLE_DEVICES at the top level.
del os.environ["CUDA_VISIBLE_DEVICES"]
super().__init__(*args, **kwargs)


# The current process is the training process and takes 1 GPU.
# Important: set some common environment variables to match the vLLM workers.
configure_as_vllm_process()

train_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
train_model.to("cuda:0")

# start ray with 2 GPUs
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2"
ray.init()

pg_inference = placement_group([{"GPU": 1, "CPU": 0}] * 2)
ray.get(pg_inference.ready())
scheduling_inference = PlacementGroupSchedulingStrategy(
placement_group=pg_inference,
placement_group_capture_child_tasks=True,
placement_group_bundle_index=0,
)

# The inference engine takes 2 GPUs.
# For simplicity, the MyWorker class is defined in this self-contained script;
# normally it should live in a separate file, with the qualified class name
# passed to the worker_cls parameter.
# `enforce_eager` is used here to reduce test time.
llm = ray.remote(
num_cpus=0,
num_gpus=0,
scheduling_strategy=scheduling_inference,
)(MyLLM).remote(
model="facebook/opt-125m",
enforce_eager=True,
worker_cls=MyWorker,
tensor_parallel_size=2,
distributed_executor_backend="ray",
)

# Generate texts from the prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]

sampling_params = SamplingParams(temperature=0)

outputs_original = ray.get(llm.generate.remote(prompts, sampling_params))

master_address = get_ip()
master_port = get_open_port()

# set up the connection between the training process and the inference engine.
handle = llm.collective_rpc.remote("init_weight_update_group",
args=(master_address, master_port, 1, 3))
model_update_group = stateless_init_process_group(master_address, master_port,
0, 3, torch.device("cuda:0"))
ray.get(handle)

# Simulate a training step by modifying the model weights.
for name, p in train_model.named_parameters():
p.data.zero_()

# Sync the updated weights from the training process to the inference engine.
for name, p in train_model.named_parameters():
handle = llm.collective_rpc.remote("update_weight",
args=(name, p.dtype, p.shape))
model_update_group.broadcast(p, src=0, stream=torch.cuda.current_stream())
ray.get(handle)

# Check that the weights have been updated.
weight_square_sum_values = ray.get(
llm.collective_rpc.remote("get_weight_square_sum"))
for x in weight_square_sum_values:
assert x == 0.0

# use the updated model to generate texts.
outputs_updated = ray.get(llm.generate.remote(prompts, sampling_params))

# The outputs should differ, since the inference weights were zeroed.
for output_original, output_updated in zip(outputs_original, outputs_updated):
generated_text_original = output_original.outputs[0].text
generated_text_updated = output_updated.outputs[0].text
assert generated_text_original != generated_text_updated
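
A note on the rank bookkeeping in this example: the training process joins the weight-update group as rank 0, and each vLLM worker joins at `rank_offset` plus its rank in the engine's own world group, so one trainer plus `tensor_parallel_size=2` gives `world_size=3` with workers at ranks 1 and 2. A minimal sketch of that arithmetic; the helper name below is purely illustrative and not part of this PR:

```python
# Illustrative sketch of the rank layout used by the RLHF example.
# Assumption: `num_trainers` training processes take ranks 0..num_trainers-1,
# and the vLLM workers are appended after them (rank_offset = num_trainers).
def weight_update_group_layout(num_trainers: int, tensor_parallel_size: int):
    world_size = num_trainers + tensor_parallel_size
    trainer_ranks = list(range(num_trainers))
    # Matches MyWorker.init_weight_update_group:
    #   rank = get_world_group().rank + rank_offset
    worker_ranks = [num_trainers + r for r in range(tensor_parallel_size)]
    return world_size, trainer_ranks, worker_ranks


# The script above is the num_trainers=1, tensor_parallel_size=2 case:
# world_size=3, trainer rank 0, worker ranks [1, 2], which is why
# collective_rpc is called with args=(master_address, master_port, 1, 3)
# and the trainer joins the group with rank 0 and world_size 3.
assert weight_update_group_layout(1, 2) == (3, [0], [1, 2])
```
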
38 changes: 38 additions & 0 deletions vllm/__init__.py
@@ -17,6 +17,43 @@

from .version import __version__, __version_tuple__


def configure_as_vllm_process():
    """
    Set some common environment variables and configuration
    for all processes created by vLLM, and for external processes
    that need to interact with vLLM workers.
    """

import os

import torch

# see https://github.com/NVIDIA/nccl/issues/1234
os.environ['NCCL_CUMEM_ENABLE'] = '0'

# see https://github.com/vllm-project/vllm/issues/10480
os.environ['TORCHINDUCTOR_COMPILE_THREADS'] = '1'
# see https://github.com/vllm-project/vllm/issues/10619
torch._inductor.config.compile_threads = 1

from vllm.platforms import current_platform

if current_platform.is_xpu():
# see https://github.com/pytorch/pytorch/blob/43c5f59/torch/_dynamo/config.py#L158
torch._dynamo.config.disable = True
if current_platform.is_hpu():
# NOTE(kzawora): PT HPU lazy backend (PT_HPU_LAZY_MODE = 1)
# does not support torch.compile
# Eager backend (PT_HPU_LAZY_MODE = 0) must be selected for
# torch.compile support
is_lazy = os.environ.get('PT_HPU_LAZY_MODE', '1') == '1'
if is_lazy:
torch._dynamo.config.disable = True
# NOTE(kzawora) multi-HPU inference with HPUGraphs (lazy-only)
# requires enabling lazy collectives
# see https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_HPU_Graphs.html # noqa: E501
os.environ['PT_HPU_ENABLE_LAZY_COLLECTIVES'] = 'true'


__all__ = [
"__version__",
"__version_tuple__",
@@ -42,4 +79,5 @@
"AsyncEngineArgs",
"initialize_ray_cluster",
"PoolingParams",
"configure_as_vllm_process",
]
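
For external processes (such as the trainer in the RLHF example), the new `configure_as_vllm_process` is meant to be called once, before any CUDA/NCCL communication with vLLM workers is set up. A minimal sketch of the intended call site, assuming a single-GPU trainer as in the example; the model name and device index are placeholders, not requirements:

```python
# Minimal sketch: an external training process aligning its environment with
# vLLM workers before creating any communicators.
from transformers import AutoModelForCausalLM

from vllm import configure_as_vllm_process

configure_as_vllm_process()  # e.g. NCCL_CUMEM_ENABLE=0, single inductor thread

train_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
train_model.to("cuda:0")
# ...now it is safe to create the stateless process group and broadcast weights.
```
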
25 changes: 25 additions & 0 deletions vllm/entrypoints/llm.py
@@ -4,6 +4,7 @@
from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple, Type,
Union, cast, overload)

import cloudpickle
from tqdm import tqdm
from typing_extensions import deprecated

@@ -186,6 +187,13 @@ def __init__(
if "disable_log_stats" not in kwargs:
kwargs["disable_log_stats"] = True

if "worker_cls" in kwargs:
worker_cls = kwargs["worker_cls"]
            # if worker_cls is not a qualified string name,
            # serialize it with cloudpickle to avoid pickling issues
if isinstance(worker_cls, type):
kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)

if compilation_config is not None:
if isinstance(compilation_config, (int, dict)):
compilation_config_instance = CompilationConfig.from_cli(
@@ -455,6 +463,23 @@ def generate(
outputs = self._run_engine(use_tqdm=use_tqdm)
return self.engine_class.validate_outputs(outputs, RequestOutput)

def collective_rpc(self,
method: str,
timeout: Optional[float] = None,
args: Tuple = (),
kwargs: Optional[Dict] = None) -> List[Any]:
"""
Run a method on all workers, with homogeneous arguments.
The main extension point for the LLM entrypoint.
Users can provide custom worker class through `worker_cls`
argument, and implement new methods in the worker class.
Then, users can call the new methods through this API.
It is recommended to use this API to only pass control messages,
and set up data-plane communication to pass data.
"""
return self.llm_engine.model_executor.collective_rpc(
method, timeout, args, kwargs)

def beam_search(
self,
prompts: List[Union[TokensPrompt, TextPrompt]],
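
Putting the two changes in this file together: `worker_cls` may now be a class object (serialized with cloudpickle) rather than a qualified string name, and any method defined on that worker becomes callable from the driver through `collective_rpc`, which returns one result per worker. A usage sketch under those assumptions; `EchoWorker` and its `echo_rank` method are hypothetical, only the `worker_cls`/`collective_rpc` plumbing comes from this PR:

```python
# Hypothetical custom worker to illustrate the new collective_rpc API.
from vllm import LLM
from vllm.worker.worker import Worker


class EchoWorker(Worker):

    def echo_rank(self, tag: str) -> str:
        # Runs on every worker; the driver gathers the return values.
        return f"{tag} from worker rank {self.rank}"


llm = LLM(
    model="facebook/opt-125m",
    enforce_eager=True,
    worker_cls=EchoWorker,  # a class object is serialized via cloudpickle
)
# One entry per worker, e.g. ["ping from worker rank 0"].
print(llm.collective_rpc("echo_rank", args=("ping",)))
```
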
31 changes: 0 additions & 31 deletions vllm/plugins/__init__.py
@@ -1,9 +1,6 @@
import logging
import os
from typing import Callable, Dict

import torch

import vllm.envs as envs

logger = logging.getLogger(__name__)
@@ -50,34 +47,6 @@ def load_general_plugins():
processes. They should be designed in a way that they can be loaded
multiple times without causing issues.
"""

# all processes created by vllm will load plugins,
# and here we can inject some common environment variables
# for all processes.

# see https://github.com/vllm-project/vllm/issues/10480
os.environ['TORCHINDUCTOR_COMPILE_THREADS'] = '1'
# see https://github.com/vllm-project/vllm/issues/10619
torch._inductor.config.compile_threads = 1

from vllm.platforms import current_platform

if current_platform.is_xpu():
# see https://github.com/pytorch/pytorch/blob/43c5f59/torch/_dynamo/config.py#L158
torch._dynamo.config.disable = True
if current_platform.is_hpu():
# NOTE(kzawora): PT HPU lazy backend (PT_HPU_LAZY_MODE = 1)
# does not support torch.compile
# Eager backend (PT_HPU_LAZY_MODE = 0) must be selected for
# torch.compile support
is_lazy = os.environ.get('PT_HPU_LAZY_MODE', '1') == '1'
if is_lazy:
torch._dynamo.config.disable = True
# NOTE(kzawora) multi-HPU inference with HPUGraphs (lazy-only)
# requires enabling lazy collectives
# see https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_HPU_Graphs.html # noqa: E501
os.environ['PT_HPU_ENABLE_LAZY_COLLECTIVES'] = 'true'

global plugins_loaded
if plugins_loaded:
return
15 changes: 11 additions & 4 deletions vllm/worker/worker_base.py
@@ -4,6 +4,7 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union

import cloudpickle
import torch

from vllm.config import ObservabilityConfig, VllmConfig
@@ -510,14 +511,20 @@ def init_worker(self, all_kwargs: List[Dict[str, Any]]) -> None:
kwargs = all_kwargs[self.rank]
enable_trace_function_call_for_thread(self.vllm_config)

# see https://github.com/NVIDIA/nccl/issues/1234
os.environ['NCCL_CUMEM_ENABLE'] = '0'
from vllm import configure_as_vllm_process
configure_as_vllm_process()

from vllm.plugins import load_general_plugins
load_general_plugins()

worker_class = resolve_obj_by_qualname(
self.vllm_config.parallel_config.worker_cls)
if isinstance(self.vllm_config.parallel_config.worker_cls, str):
worker_class = resolve_obj_by_qualname(
self.vllm_config.parallel_config.worker_cls)
else:
assert isinstance(self.vllm_config.parallel_config.worker_cls,
bytes)
worker_class = cloudpickle.loads(
self.vllm_config.parallel_config.worker_cls)
self.worker = worker_class(**kwargs)
assert self.worker is not None

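
The dispatch added to `init_worker` accepts the two forms that `parallel_config.worker_cls` can now take: a qualified string name, or the cloudpickled bytes produced when a class object is passed to `LLM(...)`. A standalone sketch of that logic, for illustration only, assuming `resolve_obj_by_qualname` keeps its usual `module.ClassName` semantics from `vllm.utils`:

```python
# Standalone sketch of the worker_cls dispatch above, for illustration only.
from typing import Union

import cloudpickle

from vllm.utils import resolve_obj_by_qualname


def resolve_worker_cls(worker_cls: Union[str, bytes]) -> type:
    if isinstance(worker_cls, str):
        # The usual case: a qualified name such as "vllm.worker.worker.Worker".
        return resolve_obj_by_qualname(worker_cls)
    # Otherwise LLM.__init__ has already serialized a class object with
    # cloudpickle (see the vllm/entrypoints/llm.py change in this PR).
    assert isinstance(worker_cls, bytes)
    return cloudpickle.loads(worker_cls)


assert resolve_worker_cls("vllm.worker.worker.Worker").__name__ == "Worker"
```
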