Enable HPU support in vLLM (vllm-project#1)
* Porting vllm to HPU

* add hpu cache allocate

* move slot_mapping to cpu and add is_prompt in cache_ops.reshape_and_cache

* add bucket to input metadata

* 1. limit max block number for lazy mode (TODO)
2. set some input metadata from cuda to cpu

* remove bucket for block tables

* add run bash script and change benchmark config

* 1. modify kv cache structure to tensors
2. update hpu paged attention API (for hpu graph compatibility)

* add attention mask for generation

* add multi_query_kv_attention attn_bias

* Temp commit

* Integrate fused kernels for RMSNorm and RoPE

* Resolve merge conflicts

* Minor Gaudi workarounds, add debugging to stock vLLM API server

* Fix post-merge pinned memory segfaults

* Re-enable sequence decode

* Maintain GPU compatibility in cache_engine

* Adjust HPU RoPE for non-query runs

* Integrate HPU primitive implementations

* Add xops bindings

* Cast paged attention inputs to bfloat16

* Remove leftover debug calls

* Update comments on HPU ops

* Restoring NVIDIA compatibility in setup.py

* vllm.hpu cleanup

* Added HPU-specific requirements

* Restored full functionality on NVIDIA

* vllm.core cleanup

* vllm init cleanup

* vllm.hpu cleanup

* vllm.benchmarks cleanup

* vllm.entrypoint cleanup

* Changed is_hpu logic

* vllm.benchmark cleanup

* Fixed importing condition

* tests cleanup

* removed dummy printings

* Update test_api_server.py

* restored attention and logprobs tests functionality on Nvidia

* throughput benchmark cleanup

* Changed Habana copyright header

* Restored alibi in bloom

* Added BSD license header

---------

Co-authored-by: Xiaotong Chen <xchen@habana.ai>
Co-authored-by: Jinyan Chen <jychen@habana.ai>
Co-authored-by: Mikhail Dvoretckii <mdvoretckii@habana.ai>
Co-authored-by: Sebastian Urwan <surwan@habana.ai>
5 people authored Feb 19, 2024
1 parent bd29cf3 commit 512c414
Showing 33 changed files with 1,796 additions and 204 deletions.
3 changes: 3 additions & 0 deletions benchmarks/benchmark_latency.py
@@ -6,6 +6,9 @@

import numpy as np
import torch
from vllm.utils import is_hpu
if is_hpu():
    import habana_frameworks.torch as htorch
from tqdm import tqdm

from vllm import LLM, SamplingParams
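The guarded imports above hinge on vllm.utils.is_hpu(). As a rough sketch (an assumption, not the helper's actual source in this commit), it can be as simple as probing for the Habana PyTorch bridge:

# Sketch only: assumes is_hpu() reduces to a package probe; the real
# vllm.utils.is_hpu() may use different logic.
import importlib.util


def is_hpu() -> bool:
    # True when the Habana PyTorch bridge is installed, i.e. an HPU build.
    return importlib.util.find_spec("habana_frameworks") is not None
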
4 changes: 4 additions & 0 deletions benchmarks/benchmark_serving.py
@@ -25,6 +25,10 @@
import aiohttp
import numpy as np
from transformers import PreTrainedTokenizerBase
import torch
from vllm.utils import is_hpu
if is_hpu():
    import habana_frameworks.torch as htorch
from vllm.transformers_utils.tokenizer import get_tokenizer

# (prompt len, output len, latency)
10 changes: 8 additions & 2 deletions benchmarks/benchmark_throughput.py
@@ -6,6 +6,9 @@
from typing import List, Optional, Tuple

import torch
from vllm.utils import is_hpu
if is_hpu():
    import habana_frameworks.torch as htorch
from transformers import (AutoModelForCausalLM, AutoTokenizer,
PreTrainedTokenizerBase)
from tqdm import tqdm
@@ -71,6 +74,7 @@ def run_vllm(
dtype: str,
max_model_len: Optional[int],
enforce_eager: bool,
profiling: bool = False, # For Gaudi2
) -> float:
from vllm import LLM, SamplingParams
llm = LLM(
@@ -104,7 +108,7 @@

start = time.perf_counter()
# FIXME(woosuk): Do not use internal method.
llm._run_engine(use_tqdm=True)
llm._run_engine(use_tqdm=True, profiling=profiling)
end = time.perf_counter()
return end - start

@@ -206,7 +210,8 @@ def main(args: argparse.Namespace):
args.quantization, args.tensor_parallel_size,
args.seed, args.n, args.use_beam_search,
args.trust_remote_code, args.dtype,
args.max_model_len, args.enforce_eager)
args.max_model_len, args.enforce_eager,
args.profiling)
elif args.backend == "hf":
assert args.tensor_parallel_size == 1
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@@ -284,6 +289,7 @@ def main(args: argparse.Namespace):
parser.add_argument("--enforce-eager",
action="store_true",
help="enforce eager execution")
parser.add_argument("--profiling", action='store_true', help='Profiling first 4 steps')
args = parser.parse_args()
if args.tokenizer is None:
args.tokenizer = args.model
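The new --profiling flag is only threaded through run_vllm() into llm._run_engine(profiling=...) here; per the help text it covers the first 4 steps. A hypothetical sketch of such a first-N-steps gate using the standard torch.profiler API (the helper name and the choice of profiler are assumptions, not the fork's internals):

# Hypothetical sketch: not the fork's _run_engine() implementation.
from torch.profiler import ProfilerActivity, profile


def profile_first_n_steps(run_step, num_steps: int, profiling: bool, n: int = 4) -> None:
    """Run num_steps engine steps, profiling only the first n when asked."""
    profiled = min(n, num_steps) if profiling else 0
    if profiled:
        with profile(activities=[ProfilerActivity.CPU]) as prof:
            for _ in range(profiled):
                run_step()
        print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))
    # Remaining steps run unprofiled.
    for _ in range(num_steps - profiled):
        run_step()
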
6 changes: 5 additions & 1 deletion benchmarks/kernels/benchmark_paged_attention.py
@@ -4,7 +4,11 @@

import torch

from vllm._C import ops
from vllm.utils import is_hpu
if is_hpu():
    from vllm.hpu import ops
else:
    from vllm._C import ops

NUM_BLOCKS = 1024
PARTITION_SIZE = 512
14 changes: 14 additions & 0 deletions requirements-hpu.txt
@@ -0,0 +1,14 @@
ninja # For faster builds.
psutil
ray >= 2.5.1
pandas # Required for Ray data.
pyarrow # Required for Ray data.
sentencepiece # Required for LLaMA tokenizer.
numpy
#torch == 2.1.2
transformers >= 4.36.0 # Required for Mixtral.
#xformers == 0.0.23.post1 # Required for CUDA 12.1.
fastapi
uvicorn[standard]
pydantic == 1.10.13 # Required for OpenAI server.
aioprometheus[starlette]
128 changes: 65 additions & 63 deletions setup.py
@@ -19,19 +19,26 @@
ROCM_SUPPORTED_ARCHS = {"gfx90a", "gfx908", "gfx906", "gfx1030", "gfx1100"}
# SUPPORTED_ARCHS = NVIDIA_SUPPORTED_ARCHS.union(ROCM_SUPPORTED_ARCHS)


def _is_hip() -> bool:
return torch.version.hip is not None


def _is_cuda() -> bool:
return torch.version.cuda is not None
return torch.version.cuda is not None and torch.cuda.is_available()


# Compiler flags.
CXX_FLAGS = ["-g", "-O2", "-std=c++17"]
CXX_FLAGS = []
# TODO(woosuk): Should we use -O3?
NVCC_FLAGS = ["-O2", "-std=c++17"]
NVCC_FLAGS = []

if _is_cuda() or _is_hip():
CXX_FLAGS = ["-g", "-O2", "-std=c++17"]
NVCC_FLAGS = ["-O2", "-std=c++17"]

ABI = 1 if torch._C._GLIBCXX_USE_CXX11_ABI else 0
CXX_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]

if _is_hip():
if ROCM_HOME is None:
@@ -44,11 +51,6 @@ def _is_cuda() -> bool:
raise RuntimeError(
"Cannot find CUDA_HOME. CUDA must be available to build the package.")

ABI = 1 if torch._C._GLIBCXX_USE_CXX11_ABI else 0
CXX_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]


def get_amdgpu_offload_arch():
command = "/opt/rocm/llvm/bin/amdgpu-offload-arch"
try:
@@ -101,43 +103,45 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:


def get_torch_arch_list() -> Set[str]:
# TORCH_CUDA_ARCH_LIST can have one or more architectures,
# e.g. "8.0" or "7.5,8.0,8.6+PTX". Here, the "8.6+PTX" option asks the
# compiler to additionally include PTX code that can be runtime-compiled
# and executed on the 8.6 or newer architectures. While the PTX code will
# not give the best performance on the newer architectures, it provides
# forward compatibility.
env_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST", None)
if env_arch_list is None:
return set()

# List are separated by ; or space.
torch_arch_list = set(env_arch_list.replace(" ", ";").split(";"))
if not torch_arch_list:
if _is_cuda() or _is_hip():
    # TORCH_CUDA_ARCH_LIST can have one or more architectures,
    # e.g. "8.0" or "7.5,8.0,8.6+PTX". Here, the "8.6+PTX" option asks the
    # compiler to additionally include PTX code that can be runtime-compiled
    # and executed on the 8.6 or newer architectures. While the PTX code will
    # not give the best performance on the newer architectures, it provides
    # forward compatibility.
    env_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST", None)
    if env_arch_list is None:
        return set()

    # List are separated by ; or space.
    torch_arch_list = set(env_arch_list.replace(" ", ";").split(";"))
    if not torch_arch_list:
        return set()

    # Filter out the invalid architectures and print a warning.
    valid_archs = NVIDIA_SUPPORTED_ARCHS.union(
        {s + "+PTX"
         for s in NVIDIA_SUPPORTED_ARCHS})
    arch_list = torch_arch_list.intersection(valid_archs)
    # If none of the specified architectures are valid, raise an error.
    if not arch_list:
        raise RuntimeError(
            "None of the CUDA/ROCM architectures in `TORCH_CUDA_ARCH_LIST` env "
            f"variable ({env_arch_list}) is supported. "
            f"Supported CUDA/ROCM architectures are: {valid_archs}.")
    invalid_arch_list = torch_arch_list - valid_archs
    if invalid_arch_list:
        warnings.warn(
            f"Unsupported CUDA/ROCM architectures ({invalid_arch_list}) are "
            "excluded from the `TORCH_CUDA_ARCH_LIST` env variable "
            f"({env_arch_list}). Supported CUDA/ROCM architectures are: "
            f"{valid_archs}.",
            stacklevel=2)
    return arch_list
else:
    return set()

# Filter out the invalid architectures and print a warning.
valid_archs = NVIDIA_SUPPORTED_ARCHS.union(
{s + "+PTX"
for s in NVIDIA_SUPPORTED_ARCHS})
arch_list = torch_arch_list.intersection(valid_archs)
# If none of the specified architectures are valid, raise an error.
if not arch_list:
raise RuntimeError(
"None of the CUDA/ROCM architectures in `TORCH_CUDA_ARCH_LIST` env "
f"variable ({env_arch_list}) is supported. "
f"Supported CUDA/ROCM architectures are: {valid_archs}.")
invalid_arch_list = torch_arch_list - valid_archs
if invalid_arch_list:
warnings.warn(
f"Unsupported CUDA/ROCM architectures ({invalid_arch_list}) are "
"excluded from the `TORCH_CUDA_ARCH_LIST` env variable "
f"({env_arch_list}). Supported CUDA/ROCM architectures are: "
f"{valid_archs}.",
stacklevel=2)
return arch_list



# First, check the TORCH_CUDA_ARCH_LIST environment variable.
compute_capabilities = get_torch_arch_list()
if _is_cuda() and not compute_capabilities:
@@ -227,15 +231,15 @@ def get_torch_arch_list() -> Set[str]:
if _is_cuda():
vllm_extension_sources.append("csrc/quantization/awq/gemm_kernels.cu")

vllm_extension = CUDAExtension(
name="vllm._C",
sources=vllm_extension_sources,
extra_compile_args={
"cxx": CXX_FLAGS,
"nvcc": NVCC_FLAGS,
},
)
ext_modules.append(vllm_extension)
    vllm_extension = CUDAExtension(
        name="vllm._C",
        sources=vllm_extension_sources,
        extra_compile_args={
            "cxx": CXX_FLAGS,
            "nvcc": NVCC_FLAGS,
        },
    )
    ext_modules.append(vllm_extension)


def get_path(*filepath) -> str:
@@ -264,7 +268,7 @@ def get_vllm_version() -> str:
if hipcc_version != MAIN_CUDA_VERSION:
rocm_version_str = hipcc_version.replace(".", "")[:3]
version += f"+rocm{rocm_version_str}"
else:
elif _is_cuda():
cuda_version = str(nvcc_cuda_version)
if cuda_version != MAIN_CUDA_VERSION:
cuda_version_str = cuda_version.replace(".", "")[:3]
@@ -274,19 +278,18 @@


def read_readme() -> str:
"""Read the README file if present."""
p = get_path("README.md")
if os.path.isfile(p):
return io.open(get_path("README.md"), "r", encoding="utf-8").read()
else:
return ""
"""Read the README file."""
return io.open(get_path("README.md"), "r", encoding="utf-8").read()


def get_requirements() -> List[str]:
"""Get Python package dependencies from requirements.txt."""
if _is_hip():
with open(get_path("requirements-rocm.txt")) as f:
requirements = f.read().strip().split("\n")
elif not _is_cuda() and not _is_hip():
with open(get_path("requirements-hpu.txt")) as f:
requirements = f.read().strip().split("\n")
else:
with open(get_path("requirements.txt")) as f:
requirements = f.read().strip().split("\n")
@@ -320,6 +323,5 @@ def get_requirements() -> List[str]:
python_requires=">=3.8",
install_requires=get_requirements(),
ext_modules=ext_modules,
cmdclass={"build_ext": BuildExtension},
package_data={"vllm": ["py.typed"]},
cmdclass={"build_ext": BuildExtension} if _is_cuda() or _is_hip() else {},
)
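Taken together, the setup.py changes treat any build where neither CUDA nor ROCm is detected as an HPU build: the CUDAExtension and the build_ext cmdclass are skipped, and requirements-hpu.txt is read instead of requirements.txt. A condensed sketch of the dependency selection, reusing the _is_hip()/_is_cuda() helpers from setup.py above (the function name is illustrative):

def pick_requirements_file() -> str:
    # Mirrors the branching in get_requirements(): ROCm and CUDA keep their
    # existing files; anything else falls through to the HPU requirements.
    if _is_hip():
        return "requirements-rocm.txt"
    if _is_cuda():
        return "requirements.txt"
    return "requirements-hpu.txt"
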
35 changes: 26 additions & 9 deletions tests/conftest.py
@@ -3,6 +3,10 @@

import pytest
import torch
from vllm.utils import is_hpu
if is_hpu():
    import habana_frameworks.torch.core as htcore
    import habana_frameworks.torch.gpu_migration
from transformers import AutoModelForCausalLM

from vllm import LLM, SamplingParams
@@ -53,11 +57,18 @@ def __init__(
) -> None:
assert dtype in _STR_DTYPE_TO_TORCH_DTYPE
torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch_dtype,
trust_remote_code=True,
).cuda()
if is_hpu():
    self.model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch_dtype,
        trust_remote_code=True,
    )
else:
    self.model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch_dtype,
        trust_remote_code=True,
    ).cuda()
if tokenizer_name is None:
tokenizer_name = model_name
self.tokenizer = get_tokenizer(tokenizer_name, trust_remote_code=True)
@@ -69,9 +80,12 @@ def generate(
) -> List[Tuple[List[int], str]]:
outputs: List[Tuple[List[int], str]] = []
for prompt in prompts:
input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
if is_hpu():
    input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
else:
    input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.cuda()
output_ids = self.model.generate(
input_ids.cuda(),
input_ids,
use_cache=True,
**kwargs,
)
@@ -125,9 +139,12 @@ def generate_greedy_logprobs(
) -> List[List[torch.Tensor]]:
all_logprobs = []
for prompt in prompts:
input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
if is_hpu():
    input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
else:
    input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.cuda()
output = self.model.generate(
input_ids.cuda(),
input_ids,
use_cache=True,
do_sample=False,
max_new_tokens=max_tokens,
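The test fixture changes repeat one convention: on NVIDIA, models and input ids are moved to CUDA explicitly, while on HPU they are left as-is and the habana_frameworks.torch.gpu_migration import handles placement. A small hypothetical helper capturing that convention (not part of the commit):

import torch

from vllm.utils import is_hpu


def place(t: torch.Tensor) -> torch.Tensor:
    # On HPU builds the gpu_migration shim takes care of device placement,
    # so tensors are returned unchanged; on NVIDIA they are moved to CUDA.
    return t if is_hpu() else t.cuda()
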
6 changes: 5 additions & 1 deletion tests/kernels/conftest.py
@@ -2,6 +2,7 @@

import pytest
import torch
from vllm.utils import is_hpu


def create_kv_caches(
@@ -18,7 +19,10 @@ def create_kv_caches(

scale = head_size**-0.5
x = 16 // torch.tensor([], dtype=dtype).element_size()
key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
if is_hpu():
    key_cache_shape = (num_blocks, num_heads, head_size, block_size)
else:
    key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
key_caches = []
for _ in range(num_layers):
key_cache = torch.empty(size=key_cache_shape,
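The two key-cache layouts above differ only in how the head dimension is packed: the CUDA kernels split head_size into head_size // x chunks of x elements, where x is 16 bytes' worth of the dtype, while the HPU path keeps the head dimension contiguous. For float16 and typical sizes (values chosen for illustration) the shapes work out as:

import torch

num_blocks, num_heads, head_size, block_size = 1024, 12, 128, 16
dtype = torch.float16

x = 16 // torch.tensor([], dtype=dtype).element_size()   # 16 / 2 bytes -> 8

cuda_key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
hpu_key_cache_shape = (num_blocks, num_heads, head_size, block_size)

print(cuda_key_cache_shape)  # (1024, 12, 16, 16, 8)
print(hpu_key_cache_shape)   # (1024, 12, 128, 16)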