Skip to content

Commit

Permalink
[Hardware] [Intel] Enable Multiprocessing and tensor parallel in CPU …
Browse files Browse the repository at this point in the history
…backend and update documentation (#6125)
  • Loading branch information
bigPYJ1151 authored Jul 26, 2024
1 parent aa48677 commit 3bbb493
Show file tree
Hide file tree
Showing 14 changed files with 404 additions and 90 deletions.
28 changes: 20 additions & 8 deletions .buildkite/run-cpu-test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,26 +3,38 @@
set -ex

# Try building the docker image
docker build -t cpu-test -f Dockerfile.cpu .
docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" -t cpu-test-avx2 -f Dockerfile.cpu .
numactl -C 48-95 -N 1 docker build -t cpu-test -f Dockerfile.cpu .
numactl -C 48-95 -N 1 docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" -t cpu-test-avx2 -f Dockerfile.cpu .

# Setup cleanup
remove_docker_container() { docker rm -f cpu-test cpu-test-avx2 || true; }
trap remove_docker_container EXIT
remove_docker_container

# Run the image
# Run the image, setting --shm-size=4g for tensor parallel.
docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --cpuset-cpus=48-95 \
--cpuset-mems=1 --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --name cpu-test cpu-test
--cpuset-mems=1 --privileged=true --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test cpu-test
docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --cpuset-cpus=48-95 \
--cpuset-mems=1 --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --name cpu-test-avx2 cpu-test-avx2
--cpuset-mems=1 --privileged=true --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test-avx2 cpu-test-avx2

# offline inference
docker exec cpu-test bash -c "python3 examples/offline_inference.py"
docker exec cpu-test-avx2 bash -c "python3 examples/offline_inference.py"

# Run basic model test
docker exec cpu-test bash -c "cd tests;
docker exec cpu-test bash -c "
pip install pytest Pillow protobuf
cd ../
pytest -v -s tests/models -m \"not vlm\" --ignore=tests/models/test_embedding.py --ignore=tests/models/test_registry.py --ignore=tests/models/test_jamba.py" # Mamba on CPU is not supported

# online inference
docker exec cpu-test bash -c "
export VLLM_CPU_KVCACHE_SPACE=10
export VLLM_CPU_OMP_THREADS_BIND=48-92
python3 -m vllm.entrypoints.openai.api_server --model facebook/opt-125m &
timeout 600 bash -c 'until curl localhost:8000/v1/models; do sleep 1; done' || exit 1
python3 benchmarks/benchmark_serving.py \
--backend vllm \
--dataset-name random \
--model facebook/opt-125m \
--num-prompts 20 \
--endpoint /v1/completions \
--tokenizer facebook/opt-125m"
9 changes: 5 additions & 4 deletions Dockerfile.cpu
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

FROM ubuntu:22.04 AS cpu-test-1

RUN apt-get update -y \
&& apt-get install -y git wget vim numactl gcc-12 g++-12 python3 python3-pip libtcmalloc-minimal4 \
RUN apt-get update -y \
&& apt-get install -y curl git wget vim numactl gcc-12 g++-12 python3 python3-pip libtcmalloc-minimal4 libnuma-dev \
&& update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12

# https://intel.github.io/intel-extension-for-pytorch/cpu/latest/tutorials/performance_tuning/tuning_guide.html
Expand All @@ -13,8 +13,9 @@ RUN pip install intel-openmp

ENV LD_PRELOAD="/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:/usr/local/lib/libiomp5.so:$LD_PRELOAD"

RUN echo 'ulimit -c 0' >> ~/.bashrc

RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/cpu/intel_extension_for_pytorch-2.3.100%2Bgit0eb3473-cp310-cp310-linux_x86_64.whl
RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/cpu/intel_extension_for_pytorch-2.4.0%2Bgitfbaa4bc-cp310-cp310-linux_x86_64.whl

RUN pip install --upgrade pip \
&& pip install wheel packaging ninja "setuptools>=49.4.0" numpy
Expand All @@ -25,7 +26,7 @@ COPY ./ /workspace/vllm

WORKDIR /workspace/vllm

RUN pip install -v -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu
RUN pip install -v -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/test/cpu

# Support for building with non-AVX512 vLLM: docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" ...
ARG VLLM_CPU_DISABLE_AVX512
Expand Down
4 changes: 4 additions & 0 deletions cmake/cpu_extension.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ endif()

message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}")

list(APPEND LIBS "numa")


#
# Define extension targets
Expand All @@ -95,6 +97,7 @@ set(VLLM_EXT_SRC
"csrc/cpu/activation.cpp"
"csrc/cpu/attention.cpp"
"csrc/cpu/cache.cpp"
"csrc/cpu/utils.cpp"
"csrc/cpu/layernorm.cpp"
"csrc/cpu/pos_encoding.cpp"
"csrc/cpu/torch_bindings.cpp")
Expand All @@ -104,6 +107,7 @@ define_gpu_extension_target(
DESTINATION vllm
LANGUAGE CXX
SOURCES ${VLLM_EXT_SRC}
LIBRARIES ${LIBS}
COMPILE_FLAGS ${CXX_COMPILE_FLAGS}
USE_SABI 3
WITH_SOABI
Expand Down
7 changes: 7 additions & 0 deletions csrc/cpu/torch_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

#include <torch/library.h>

void init_cpu_threads_env(const std::string& cpu_ids);

TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// vLLM custom ops

Expand Down Expand Up @@ -107,4 +109,9 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
cache_ops.impl("reshape_and_cache", torch::kCPU, &reshape_and_cache);
}

TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _utils), utils) {
// CPU utils
utils.def("init_cpu_threads_env(str cpu_ids) -> ()", &init_cpu_threads_env);
}

REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
65 changes: 65 additions & 0 deletions csrc/cpu/utils.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
#include <numa.h>
#include <unistd.h>
#include <string>
#include <sched.h>

#include "cpu_types.hpp"

void init_cpu_threads_env(const std::string& cpu_ids) {
bitmask* omp_cpu_mask = numa_parse_cpustring(cpu_ids.c_str());
TORCH_CHECK(omp_cpu_mask->size > 0);
std::vector<int> omp_cpu_ids;
omp_cpu_ids.reserve(omp_cpu_mask->size);

constexpr int group_size = 8 * sizeof(*omp_cpu_mask->maskp);

for (int offset = 0; offset < omp_cpu_mask->size; offset += group_size) {
unsigned long group_mask = omp_cpu_mask->maskp[offset / group_size];
int i = 0;
while (group_mask) {
if (group_mask & 1) {
omp_cpu_ids.emplace_back(offset + i);
}
++i;
group_mask >>= 1;
}
}

// Memory node binding
if (numa_available() != -1) {
int mem_node_id = numa_node_of_cpu(omp_cpu_ids.front());
bitmask* mask = numa_parse_nodestring(std::to_string(mem_node_id).c_str());
bitmask* src_mask = numa_get_membind();

int pid = getpid();

// move all existing pages to the specified numa node.
*(src_mask->maskp) = *(src_mask->maskp) ^ *(mask->maskp);
int page_num = numa_migrate_pages(pid, src_mask, mask);
if (page_num == -1) {
TORCH_CHECK(false,
"numa_migrate_pages failed. errno: " + std::to_string(errno));
}

// restrict memory allocation node.
numa_set_membind(mask);
numa_set_strict(1);
}

// OMP threads binding
omp_set_num_threads((int)omp_cpu_ids.size());
torch::set_num_threads((int)omp_cpu_ids.size());
TORCH_CHECK_EQ(omp_cpu_ids.size(), torch::get_num_threads());
TORCH_CHECK_EQ(omp_cpu_ids.size(), omp_get_max_threads());
#pragma omp parallel for schedule(static, 1)
for (size_t i = 0; i < omp_cpu_ids.size(); ++i) {
cpu_set_t* mask = CPU_ALLOC(omp_cpu_mask->size);
size_t size = CPU_ALLOC_SIZE(omp_cpu_mask->size);
CPU_ZERO_S(size, mask);
CPU_SET_S(omp_cpu_ids[i], size, mask);
sched_setaffinity(0, sizeof(cpu_set_t), mask);
CPU_FREE(mask);
}

numa_free_nodemask(omp_cpu_mask);
}
55 changes: 47 additions & 8 deletions docs/source/getting_started/cpu-installation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ Table of contents:
#. :ref:`Requirements <cpu_backend_requirements>`
#. :ref:`Quick start using Dockerfile <cpu_backend_quick_start_dockerfile>`
#. :ref:`Build from source <build_cpu_backend_from_source>`
#. :ref:`Related runtime environment variables <env_intro>`
#. :ref:`Intel Extension for PyTorch <ipex_guidance>`
#. :ref:`Performance tips <cpu_backend_performance_tips>`

Expand Down Expand Up @@ -47,7 +48,7 @@ Build from source
.. code-block:: console
$ sudo apt-get update -y
$ sudo apt-get install -y gcc-12 g++-12
$ sudo apt-get install -y gcc-12 g++-12 libnuma-dev
$ sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12
- Second, install Python packages for vLLM CPU backend building:
Expand All @@ -71,22 +72,27 @@ Build from source

- If you want to force enable AVX512_BF16 for the cross-compilation, please set environment variable VLLM_CPU_AVX512BF16=1 before the building.

.. _env_intro:

Related runtime environment variables
-------------------------------------

- ``VLLM_CPU_KVCACHE_SPACE``: specify the KV Cache size (e.g, ``VLLM_CPU_KVCACHE_SPACE=40`` means 40 GB space for KV cache), larger setting will allow vLLM running more requests in parallel. This parameter should be set based on the hardware configuration and memory management pattern of users.

- ``VLLM_CPU_OMP_THREADS_BIND``: specify the CPU cores dedicated to the OpenMP threads. For example, ``VLLM_CPU_OMP_THREADS_BIND=0-31`` means there will be 32 OpenMP threads bound on 0-31 CPU cores. ``VLLM_CPU_OMP_THREADS_BIND=0-31|32-63`` means there will be 2 tensor parallel processes, 32 OpenMP threads of rank0 are bound on 0-31 CPU cores, and the OpenMP threads of rank1 are bound on 32-63 CPU cores.

.. _ipex_guidance:

Intel Extension for PyTorch
---------------------------

- `Intel Extension for PyTorch (IPEX) <https://github.com/intel/intel-extension-for-pytorch>`_ extends PyTorch with up-to-date features optimizations for an extra performance boost on Intel hardware.

- IPEX after the ``2.3.0`` can be enabled in the CPU backend by default if it is installed.

.. _cpu_backend_performance_tips:

Performance tips
-----------------

- vLLM CPU backend uses environment variable ``VLLM_CPU_KVCACHE_SPACE`` to specify the KV Cache size (e.g, ``VLLM_CPU_KVCACHE_SPACE=40`` means 40 GB space for KV cache), larger setting will allow vLLM running more requests in parallel. This parameter should be set based on the hardware configuration and memory management pattern of users.

- We highly recommend to use TCMalloc for high performance memory allocation and better cache locality. For example, on Ubuntu 22.4, you can run:

.. code-block:: console
Expand All @@ -96,11 +102,44 @@ Performance tips
$ export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:$LD_PRELOAD # prepend the library to LD_PRELOAD
$ python examples/offline_inference.py # run vLLM
- vLLM CPU backend uses OpenMP for thread-parallel computation. If you want the best performance on CPU, it will be very critical to isolate CPU cores for OpenMP threads with other thread pools (like web-service event-loop), to avoid CPU oversubscription.
- When using the online serving, it is recommended to reserve 1-2 CPU cores for the serving framework to avoid CPU oversubscription. For example, on a platform with 32 physical CPU cores, reserving CPU 30 and 31 for the framework and using CPU 0-29 for OpenMP:

- If using vLLM CPU backend on a bare-metal machine, it is recommended to disable the hyper-threading.
.. code-block:: console
$ export VLLM_CPU_KVCACHE_SPACE=40
$ export VLLM_CPU_OMP_THREADS_BIND=0-29
$ vllm serve facebook/opt-125m
- If using vLLM CPU backend on a machine with hyper-threading, it is recommended to bind only one OpenMP thread on each physical CPU core using ``VLLM_CPU_OMP_THREADS_BIND``. On a hyper-threading enabled platform with 16 logical CPU cores / 8 physical CPU cores:

.. code-block:: console
- If using vLLM CPU backend on a multi-socket machine with NUMA, be aware to set CPU cores and memory nodes, to avoid the remote memory node access. ``numactl`` is an useful tool for CPU core and memory binding on NUMA platform. Besides, ``--cpuset-cpus`` and ``--cpuset-mems`` arguments of ``docker run`` are also useful.
$ lscpu -e # check the mapping between logical CPU cores and physical CPU cores
# The "CPU" column means the logical CPU core IDs, and the "CORE" column means the physical core IDs. On this platform, two logical cores are sharing one physical core.
CPU NODE SOCKET CORE L1d:L1i:L2:L3 ONLINE MAXMHZ MINMHZ MHZ
0 0 0 0 0:0:0:0 yes 2401.0000 800.0000 800.000
1 0 0 1 1:1:1:0 yes 2401.0000 800.0000 800.000
2 0 0 2 2:2:2:0 yes 2401.0000 800.0000 800.000
3 0 0 3 3:3:3:0 yes 2401.0000 800.0000 800.000
4 0 0 4 4:4:4:0 yes 2401.0000 800.0000 800.000
5 0 0 5 5:5:5:0 yes 2401.0000 800.0000 800.000
6 0 0 6 6:6:6:0 yes 2401.0000 800.0000 800.000
7 0 0 7 7:7:7:0 yes 2401.0000 800.0000 800.000
8 0 0 0 0:0:0:0 yes 2401.0000 800.0000 800.000
9 0 0 1 1:1:1:0 yes 2401.0000 800.0000 800.000
10 0 0 2 2:2:2:0 yes 2401.0000 800.0000 800.000
11 0 0 3 3:3:3:0 yes 2401.0000 800.0000 800.000
12 0 0 4 4:4:4:0 yes 2401.0000 800.0000 800.000
13 0 0 5 5:5:5:0 yes 2401.0000 800.0000 800.000
14 0 0 6 6:6:6:0 yes 2401.0000 800.0000 800.000
15 0 0 7 7:7:7:0 yes 2401.0000 800.0000 800.000
# On this platform, it is recommend to only bind openMP threads on logical CPU cores 0-7 or 8-15
$ export VLLM_CPU_OMP_THREADS_BIND=0-7
$ python examples/offline_inference.py
- If using vLLM CPU backend on a multi-socket machine with NUMA, be aware to set CPU cores using ``VLLM_CPU_OMP_THREADS_BIND`` to avoid cross NUMA node memory access.



4 changes: 2 additions & 2 deletions requirements-cpu.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
-r requirements-common.txt

# Dependencies for x86_64 CPUs
torch == 2.3.1+cpu; platform_machine != "ppc64le"
torchvision == 0.18.1+cpu; platform_machine != "ppc64le" # required for the image processor of phi3v, this must be updated alongside torch
torch == 2.4.0; platform_machine != "ppc64le"
torchvision; platform_machine != "ppc64le" # required for the image processor of phi3v, this must be updated alongside torch
triton >= 2.2.0 # FIXME(woosuk): This is a hack to avoid import error.
3 changes: 3 additions & 0 deletions vllm/distributed/parallel_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,9 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
pynccl_comm = self.pynccl_comm
if (pynccl_comm is not None and not pynccl_comm.disabled):
pynccl_comm.all_reduce(input_)
elif input_.is_cpu:
import intel_extension_for_pytorch as ipex
ipex.distributed.all_reduce(input_, group=self.device_group)
else:
torch.distributed.all_reduce(input_, group=self.device_group)
return input_
Expand Down
2 changes: 0 additions & 2 deletions vllm/engine/async_llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,8 +410,6 @@ def _get_executor_cls(
from vllm.executor.tpu_executor import TPUExecutorAsync
executor_class = TPUExecutorAsync
elif engine_config.device_config.device_type == "cpu":
assert distributed_executor_backend is None, (
"Distributed execution is not supported with the CPU backend.")
from vllm.executor.cpu_executor import CPUExecutorAsync
executor_class = CPUExecutorAsync
elif engine_config.device_config.device_type == "openvino":
Expand Down
8 changes: 7 additions & 1 deletion vllm/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
VLLM_TRACE_FUNCTION: int = 0
VLLM_ATTENTION_BACKEND: Optional[str] = None
VLLM_CPU_KVCACHE_SPACE: int = 0
VLLM_CPU_OMP_THREADS_BIND: str = ""
VLLM_OPENVINO_KVCACHE_SPACE: int = 0
VLLM_OPENVINO_CPU_KV_CACHE_PRECISION: Optional[str] = None
VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS: bool = False
Expand Down Expand Up @@ -241,11 +242,16 @@ def get_default_config_root():
"VLLM_ATTENTION_BACKEND":
lambda: os.getenv("VLLM_ATTENTION_BACKEND", None),

# CPU key-value cache space
# (CPU backend only) CPU key-value cache space.
# default is 4GB
"VLLM_CPU_KVCACHE_SPACE":
lambda: int(os.getenv("VLLM_CPU_KVCACHE_SPACE", "0")),

# (CPU backend only) CPU core ids bound by OpenMP threads, e.g., "0-31",
# "0,1,2", "0-31,33". CPU cores of different ranks are separated by '|'.
"VLLM_CPU_OMP_THREADS_BIND":
lambda: os.getenv("VLLM_CPU_OMP_THREADS_BIND", "all"),

# OpenVINO key-value cache space
# default is 4GB
"VLLM_OPENVINO_KVCACHE_SPACE":
Expand Down
Loading

0 comments on commit 3bbb493

Please sign in to comment.