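Motivation.
TL;DR: Introducing a fully SPMD-style LLMEngine execution pattern to improve offline inference throughput.
The RFC draft was initiated by @PeterSH6.
Background and Motivation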
Inherent dispatch overhead in single-controller paradigm
For distributed offline inference, vLLM relies on a centralized controller process (e.g., the Ray driver) to broadcast the scheduler output to the workers. After the workers finish execution, their outputs are gathered back to the centralized controller, which then performs scheduling for the next iteration. While this single-controller paradigm offers a better user experience, it limits throughput.
Therefore, to launch a generation call, vLLM follows the procedure below:
python3 offline_inference.py  # launch the centralized controller process (i.e., LLMEngine)

# inside the LLMEngine
llm_engine.distributed_gpu_executor._run_workers('start_worker_execution_loop', ...)  # execute the model

# inside _run_workers
worker_outputs = [
    worker.execute_method(method, *args, **kwargs)
    for worker in self.workers
]
As the code above shows, the _run_workers call introduces non-negligible dispatch overhead.
Recent proposals mitigate part of this overhead with Ray DAG (#6556) and multi-step scheduling (#6854). However, as discussed in Pathways [2] and HybridFlow [1], the dispatch overhead of a single controller is hard to eliminate compared to a fully SPMD paradigm.
Our proposed Fully SPMD approach for vLLM entails:
An independent LLMEngine per GPU, each containing a vLLM scheduler and an SPMDGPUExecutor with its own cache engine and model
No centralized control: each LLMEngine schedules its own data while exhibiting identical behavior to the LLMEngines on other GPUs (why this holds in the offline inference setting is discussed later).
This approach aligns with other high-performance frameworks like TensorRT-LLM that leverage SPMD-style execution for maximum system throughput.
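A minimal sketch of the intended per-process structure is shown below. Aside from LLMEngine and SPMDGPUExecutor, the names and class bodies are illustrative placeholders, not the actual vLLM implementation:

import os

class Worker:
    # Placeholder for the vLLM worker holding this GPU's model shard and cache engine.
    def __init__(self, local_rank: int):
        self.local_rank = local_rank

    def execute_model(self, scheduler_output):
        # Run one forward pass for the scheduled requests on this GPU.
        return f"rank {self.local_rank} executed {scheduler_output}"

class SPMDGPUExecutor:
    # Owns exactly one worker that lives in the current process; no RPC to remote workers.
    def __init__(self):
        self.worker = Worker(local_rank=int(os.environ.get("LOCAL_RANK", "0")))

    def execute_model(self, scheduler_output):
        return self.worker.execute_model(scheduler_output)

class LLMEngine:
    # Each GPU process builds its own engine: a local scheduler plus a local executor.
    def __init__(self):
        self.executor = SPMDGPUExecutor()

    def step(self, scheduled_requests):
        # Stand-in for the real scheduler: every rank receives the identical batch
        # and therefore produces the identical schedule.
        return self.executor.execute_model(scheduled_requests)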
Inflexible support for RL/RLHF post-training workloads, especially when vLLM works with other LLM training infrastructures (HybridEngine design [1, 3])
Mainstream LLM training infrastructures, such as Megatron-LM, DeepSpeed, and PyTorch FSDP, all follow the SPMD programming model.
In RL/RLHF post-training, actor models (in PPO and GRPO) must perform both training and autoregressive generation. This requires deploying actor models across both vLLM and training frameworks (e.g., Megatron-LM) with different parallelization strategies, necessitating weight synchronization at each iteration.
Current RL/RLHF post-training frameworks offer two main deployment strategies:
Co-located Actor/Rollout using HybridEngine (HybridFlow (veRL)[1], Nemo-Aligner[4], DeepSpeed-Chat [3])
Separated Actor/Rollout deployment (OpenRLHF)[5]
When the actor and rollout are placed on different devices, vLLM can simply be deployed as a service, and the single-controller paradigm works fine in this scenario. However, such placement leaves some GPUs idle: the rollout GPUs sit idle while the actor performs training, because of the stage dependency in PPO and GRPO.
The HybridEngine design can eliminate this GPU idle problem [1]. However, implementing HybridEngine with vLLM and Megatron-LM reveals significant challenges in merging the training processes with the vLLM worker processes. Weight synchronization requires inter-process weight resharding, which is complex to implement across the single-controller and SPMD paradigms.
Therefore, adopting a fully SPMD-style LLMEngine would facilitate HybridEngine implementation and enable more efficient weight resharding in RL/RLHF post-training scenarios.
Multi-Node Offline Inference Complexity
The current vLLM implementation requires setting up a Ray cluster before running multi-node offline inference, which adds deployment complexity.
A Fully SPMD paradigm would simplify distributed offline inference to a single command:
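For example, something along these lines (the launcher flags and script name are illustrative, not a finalized CLI):

torchrun --nnodes=2 --nproc-per-node=8 --rdzv-backend=c10d --rdzv-endpoint=$HEAD_NODE_IP:29500 offline_inference_spmd.py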
This streamlined approach not only simplifies deployment but also facilitates straightforward implementation of data parallelism on top of the SPMD architecture. The resulting system would be more maintainable and scalable across multiple nodes.
Major Benefits
Based on the discussion above, the Fully SPMD execution pattern can provide the following benefits in offline settings:
Higher offline inference throughput
Faster weight resharding in RL/RLHF training and easier implementation
Better support for multi-node offline inference and data parallelism
[Planning] Easier implementation of pipeline parallelism in LLMEngine under the SPMD paradigm.
Proposed Change.
We have already implemented the SPMD version of vLLM in HybridFlow (veRL) for v0.6.3, v0.5.4, v0.4.2, and v0.3.1 (github.com). The architecture is described above in the background section.
We made the following major changes:
SPMD pattern:
Introduced SPMDGPUExecutor class within LLMEngine
Each GPU runs its own LLMEngine and SPMDGPUExecutor instance
Single Worker per SPMDGPUExecutor, with scheduling and execution occurring within the same GPU process
Deterministic behavior:
Modified determine_num_available_blocks() to ensure the same GPU/CPU cache size on every worker, and therefore consistent scheduler behavior across workers. Implemented an all_reduce(MIN) synchronization of the block availability counts between workers (see the sketch after this list).
In an offline inference setting, a batch of prompts is processed simultaneously. Consequently, the different LLMEngines within the same DP group receive identical inputs at the same time. Given the identically sized GPU/CPU caches and vLLM's inherent FIFO scheduling, the schedulers' behavior remains deterministic and identical across workers.
Model initialization/synchronization: In RL/RLHF post-training workloads, the model weights in vLLM must be synchronized every iteration. Therefore, we added an API to sync model weights under the SPMD paradigm.
KVCache offload: In RL/RLHF post-training, the KVCache should be offloaded when the actor is performing a training stage to reduce memory overhead.
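A sketch of the block-count synchronization mentioned above, assuming torch.distributed is already initialized with the NCCL backend; the wrapper function is illustrative, and only determine_num_available_blocks() and the all_reduce(MIN) step come from this proposal:

import torch
import torch.distributed as dist

def sync_num_available_blocks(num_gpu_blocks: int, num_cpu_blocks: int) -> tuple[int, int]:
    # Each worker profiles its own free memory, so the raw block counts can differ slightly.
    # Reducing with MIN guarantees every scheduler sees exactly the same cache size.
    counts = torch.tensor([num_gpu_blocks, num_cpu_blocks], dtype=torch.int64, device="cuda")
    dist.all_reduce(counts, op=dist.ReduceOp.MIN)
    return int(counts[0].item()), int(counts[1].item())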
I think the first two features are highly related to SPMD and would benefit all offline settings using vLLM.
The last two features will be a good supplement for vLLM to support RL/RLHF post-training workloads.
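A rough sketch of how these two RLHF-oriented hooks could look from the training loop's side; the method names (sync_model_weights, offload_kv_cache, generate, train_step) are hypothetical placeholders, not an agreed-upon vLLM API:

def rlhf_iteration(actor, rollout_engine, prompts):
    # 1. Push the freshly updated actor weights into the co-located vLLM engine
    #    (inter-process weight resharding happens inside the engine).
    rollout_engine.sync_model_weights(actor.state_dict())
    # 2. Generate rollouts for this batch of prompts.
    responses = rollout_engine.generate(prompts)
    # 3. Offload the KV cache so the training stage can reuse the GPU memory.
    rollout_engine.offload_kv_cache()
    # 4. Run the PPO/GRPO update on the actor with the generated responses.
    actor.train_step(prompts, responses)
    return responses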
Roadmap
Fully SPMD functionality and optimizations:
Transfer the fully SPMD functionality from the HybridFlow (veRL) implementation into the vLLM LLMEngine
Support various offline inference workloads with the SPMD functionality (e.g., multi-modal, speculative decoding)
Implement and optimize pipeline parallelism on top of the SPMD-based LLMEngine.
Feedback Period.
No response
CC List.
@PeterSH6
@vermouth1992
@ZSL98
Any Other Things.
Reference
[1] Guangming Sheng, Chi Zhang, Zilingfeng Ye, Xibin Wu, Wang Zhang, Ru Zhang, Yanghua Peng, Haibin Lin, and Chuan Wu. 2024. HybridFlow: A Flexible and Efficient RLHF Framework. EuroSys 2025.
[2] Paul Barham, Aakanksha Chowdhery, Jeff Dean, Sanjay Ghemawat, Steven Hand, Daniel Hurt, Michael Isard, Hyeontaek Lim, Ruoming Pang, Sudip Roy, et al. 2022. Pathways: Asynchronous Distributed Dataflow for ML. Proceedings of Machine Learning and Systems 4 (2022), 430–449.
[3] Zhewei Yao, Reza Yazdani Aminabadi, Olatunji Ruwase, Samyam Rajbhandari, Xiaoxia Wu, Ammar Ahmad Awan, Jeff Rasley, Minjia Zhang, Conglong Li, Connor Holmes, et al. 2023. DeepSpeed-Chat: Easy, Fast and Affordable RLHF Training of ChatGPT-like Models at All Scales. arXiv preprint arXiv:2308.01320 (2023).
[4] Gerald Shen, Zhilin Wang, Olivier Delalleau, Jiaqi Zeng, Yi Dong, Daniel Egert, et al. 2024. NeMo-Aligner: Scalable Toolkit for Efficient Model Alignment. arXiv preprint arXiv:2405.01481 (2024).
[5] Jian Hu, Xibin Wu, Zilin Zhu, Xianyu, Weixun Wang, Dehao Zhang, and Yu Cao. 2024. OpenRLHF: An Easy-to-use, Scalable and High-performance RLHF Framework. arXiv preprint arXiv:2405.11143 (2024).