Enable HPU support in vLLM #1

Merged: 44 commits, Feb 19, 2024

Commits
e528d06  Porting vllm to HPU (Nov 8, 2023)
d8da01f  add hpu cache allocate (Nov 8, 2023)
4d1538f  move slot_mapping to cpu and add is_prompt in cache_ops.reshape_and_c… (Nov 8, 2023)
c336824  add bucket to input metadata (Nov 8, 2023)
068c748  1. limit max block number for lazy mode (TODO) (Nov 10, 2023)
9a042f7  remove bucket for block tables (Nov 10, 2023)
1e7e16d  add run bash script and change benchmark config (Nov 11, 2023)
153eb71  1. modify kv cache structure to tensors (Nov 14, 2023)
9b7e0a7  add attention mask for generation (Nov 16, 2023)
c99eefc  add multi_query_kv_attention attn_bias (Nov 19, 2023)
1327be8  Temp commit (Dec 8, 2023)
de7799f  Integrate fused kernels for RMSNorm and RoPE (Dec 18, 2023)
b839181  Resolve merge conflicts (Dec 21, 2023)
00df486  Minor Gaudi workarounds, add debugging to stock vLLM API server (kzawora-intel, Dec 21, 2023)
8b20664  Merge remote-tracking branch 'origin/main' into mdvoretc/prototype (kzawora-intel, Dec 21, 2023)
16b5557  Fix post-merge pinned memory segfaults (kzawora-intel, Dec 21, 2023)
2b6ec4e  Re-enable sequence decode (kzawora-intel, Dec 21, 2023)
9d4bd9f  Maintain GPU compatibility in cache_engine (kzawora-intel, Dec 22, 2023)
7a0337a  Adjust HPU RoPE for non-query runs (Jan 10, 2024)
6351d41  Integrate HPU primitive implementations (Jan 23, 2024)
c0d3c69  Add xops bindings (Jan 23, 2024)
48b26d1  Cast paged attention inputs to bfloat16 (Jan 23, 2024)
aefa573  Remove leftover debug calls (Jan 26, 2024)
c49b68e  Update comments on HPU ops (Jan 26, 2024)
c5c2a99  Restoring NVIDIA compatibility in setup.py (Feb 2, 2024)
1c66908  vllm.hpu cleanup (kzawora-intel, Feb 5, 2024)
5725b31  Added HPU-specific requirements (Feb 7, 2024)
97d31b0  Restored full functionality on NVIDIA (Feb 7, 2024)
07671d7  vllm.core cleanup (Feb 8, 2024)
413fb60  vllm init cleanup (Feb 8, 2024)
a38686e  vllm.hpu cleanup (Feb 9, 2024)
bed7da6  vllm.benchmarks cleanup (Feb 9, 2024)
0baa2ef  vllm.entrypoint cleanup (Feb 9, 2024)
1f22aa1  Changed is_hpu logic (Feb 13, 2024)
eb2c22a  vllm.benchmark cleanup (Feb 13, 2024)
e69fca6  Fixed importing condition (Feb 13, 2024)
38cc53b  tests cleanup (Feb 13, 2024)
54d499a  removed dummy printings (Feb 13, 2024)
c0ea99c  Update test_api_server.py (Feb 13, 2024)
ea3ea44  restored attention and logprobs tests functionality on Nvidia (Feb 14, 2024)
5543642  throughput benchmark cleanup (Feb 16, 2024)
a2acb86  Changed Habana copyright header (Feb 16, 2024)
956bab7  Restored alibi in bloom (Feb 16, 2024)
702d8a7  Added BSD license header (Feb 16, 2024)
3 changes: 3 additions & 0 deletions benchmarks/benchmark_latency.py
@@ -6,6 +6,9 @@

import numpy as np
import torch
from vllm.utils import is_hpu
if is_hpu():
import habana_frameworks.torch as htorch
from tqdm import tqdm

from vllm import LLM, SamplingParams
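The is_hpu() helper used throughout these hunks is imported from vllm.utils but not defined in this diff. A minimal sketch of what such a check could look like (an assumption for illustration, not the PR's actual implementation) is:

import importlib.util

def is_hpu() -> bool:
    # Assume HPU support whenever the Habana PyTorch bridge
    # (habana_frameworks) is importable; the real vllm.utils.is_hpu
    # may use a different check.
    return importlib.util.find_spec("habana_frameworks") is not None
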
4 changes: 4 additions & 0 deletions benchmarks/benchmark_serving.py
@@ -25,6 +25,10 @@
import aiohttp
import numpy as np
from transformers import PreTrainedTokenizerBase
import torch
from vllm.utils import is_hpu
if is_hpu():
import habana_frameworks.torch as htorch
from vllm.transformers_utils.tokenizer import get_tokenizer

# (prompt len, output len, latency)
10 changes: 8 additions & 2 deletions benchmarks/benchmark_throughput.py
@@ -6,6 +6,9 @@
from typing import List, Optional, Tuple

import torch
from vllm.utils import is_hpu
if is_hpu():
import habana_frameworks.torch as htorch
from transformers import (AutoModelForCausalLM, AutoTokenizer,
PreTrainedTokenizerBase)
from tqdm import tqdm
@@ -71,6 +74,7 @@ def run_vllm(
dtype: str,
max_model_len: Optional[int],
enforce_eager: bool,
profiling: bool = False, # For Gaudi2
) -> float:
from vllm import LLM, SamplingParams
llm = LLM(
@@ -104,7 +108,7 @@

start = time.perf_counter()
# FIXME(woosuk): Do not use internal method.
llm._run_engine(use_tqdm=True)
llm._run_engine(use_tqdm=True, profiling=profiling)
end = time.perf_counter()
return end - start

@@ -206,7 +210,8 @@ def main(args: argparse.Namespace):
args.quantization, args.tensor_parallel_size,
args.seed, args.n, args.use_beam_search,
args.trust_remote_code, args.dtype,
args.max_model_len, args.enforce_eager)
args.max_model_len, args.enforce_eager,
args.profiling)
elif args.backend == "hf":
assert args.tensor_parallel_size == 1
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@@ -284,6 +289,7 @@ def main(args: argparse.Namespace):
parser.add_argument("--enforce-eager",
action="store_true",
help="enforce eager execution")
parser.add_argument("--profiling", action='store_true', help='Profiling first 4 steps')
args = parser.parse_args()
if args.tokenizer is None:
args.tokenizer = args.model
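The new --profiling flag is threaded into llm._run_engine(), whose body is not part of this diff. A sketch of how profiling the first few engine steps could be wired up with the standard torch profiler, assuming requests have already been queued on the LLM instance (an assumption; the actual hook used on Gaudi may differ):

import torch
from vllm import LLM

def run_with_profiling(llm: LLM, num_profiled_steps: int = 4) -> None:
    # Profile only the first few engine iterations, then drain the remaining
    # requests without profiling overhead.
    engine = llm.llm_engine
    with torch.profiler.profile(record_shapes=True) as prof:
        for _ in range(num_profiled_steps):
            if not engine.has_unfinished_requests():
                break
            engine.step()
    print(prof.key_averages().table(sort_by="self_cpu_time_total"))
    while engine.has_unfinished_requests():
        engine.step()
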
6 changes: 5 additions & 1 deletion benchmarks/kernels/benchmark_paged_attention.py
@@ -4,7 +4,11 @@

import torch

from vllm._C import ops
from vllm.utils import is_hpu
if is_hpu():
from vllm.hpu import ops
else:
from vllm._C import ops

NUM_BLOCKS = 1024
PARTITION_SIZE = 512
14 changes: 14 additions & 0 deletions requirements-hpu.txt
@@ -0,0 +1,14 @@
ninja # For faster builds.
psutil
ray >= 2.5.1
pandas # Required for Ray data.
pyarrow # Required for Ray data.
sentencepiece # Required for LLaMA tokenizer.
numpy
#torch == 2.1.2
transformers >= 4.36.0 # Required for Mixtral.
#xformers == 0.0.23.post1 # Required for CUDA 12.1.

Review comment: nit: May want to rephrase the comment here to mention that the required functionality is integrated for HPU.

Reply: Yes, that's true. I have some local changes because I'm still testing compatibility in all possible places (tests, benchmarks).

Reply: Can the comment thread resolutions be withheld until the changes land on the PR? The current state makes it harder to track which issues are known, since comments may be resolved without a visible change.

fastapi
uvicorn[standard]
pydantic == 1.10.13 # Required for OpenAI server.
aioprometheus[starlette]
128 changes: 65 additions & 63 deletions setup.py
@@ -19,19 +19,26 @@
ROCM_SUPPORTED_ARCHS = {"gfx90a", "gfx908", "gfx906", "gfx1030", "gfx1100"}
# SUPPORTED_ARCHS = NVIDIA_SUPPORTED_ARCHS.union(ROCM_SUPPORTED_ARCHS)


def _is_hip() -> bool:
return torch.version.hip is not None


def _is_cuda() -> bool:
return torch.version.cuda is not None
return torch.version.cuda is not None and torch.cuda.is_available()


# Compiler flags.
CXX_FLAGS = ["-g", "-O2", "-std=c++17"]
CXX_FLAGS = []
# TODO(woosuk): Should we use -O3?
NVCC_FLAGS = ["-O2", "-std=c++17"]
NVCC_FLAGS = []

if _is_cuda() or _is_hip():
CXX_FLAGS = ["-g", "-O2", "-std=c++17"]
NVCC_FLAGS = ["-O2", "-std=c++17"]

ABI = 1 if torch._C._GLIBCXX_USE_CXX11_ABI else 0
CXX_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]

if _is_hip():
if ROCM_HOME is None:
@@ -44,11 +51,6 @@ def _is_cuda() -> bool:
raise RuntimeError(
"Cannot find CUDA_HOME. CUDA must be available to build the package.")

ABI = 1 if torch._C._GLIBCXX_USE_CXX11_ABI else 0
CXX_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]


def get_amdgpu_offload_arch():
command = "/opt/rocm/llvm/bin/amdgpu-offload-arch"
try:
@@ -101,43 +103,45 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:


def get_torch_arch_list() -> Set[str]:
# TORCH_CUDA_ARCH_LIST can have one or more architectures,
# e.g. "8.0" or "7.5,8.0,8.6+PTX". Here, the "8.6+PTX" option asks the
# compiler to additionally include PTX code that can be runtime-compiled
# and executed on the 8.6 or newer architectures. While the PTX code will
# not give the best performance on the newer architectures, it provides
# forward compatibility.
env_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST", None)
if env_arch_list is None:
return set()

# List are separated by ; or space.
torch_arch_list = set(env_arch_list.replace(" ", ";").split(";"))
if not torch_arch_list:
if _is_cuda() or _is_hip():
# TORCH_CUDA_ARCH_LIST can have one or more architectures,
# e.g. "8.0" or "7.5,8.0,8.6+PTX". Here, the "8.6+PTX" option asks the
# compiler to additionally include PTX code that can be runtime-compiled
# and executed on the 8.6 or newer architectures. While the PTX code will
# not give the best performance on the newer architectures, it provides
# forward compatibility.
env_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST", None)
if env_arch_list is None:
return set()

# List are separated by ; or space.
torch_arch_list = set(env_arch_list.replace(" ", ";").split(";"))
if not torch_arch_list:
return set()

# Filter out the invalid architectures and print a warning.
valid_archs = NVIDIA_SUPPORTED_ARCHS.union(
{s + "+PTX"
for s in NVIDIA_SUPPORTED_ARCHS})
arch_list = torch_arch_list.intersection(valid_archs)
# If none of the specified architectures are valid, raise an error.
if not arch_list:
raise RuntimeError(
"None of the CUDA/ROCM architectures in `TORCH_CUDA_ARCH_LIST` env "
f"variable ({env_arch_list}) is supported. "
f"Supported CUDA/ROCM architectures are: {valid_archs}.")
invalid_arch_list = torch_arch_list - valid_archs
if invalid_arch_list:
warnings.warn(
f"Unsupported CUDA/ROCM architectures ({invalid_arch_list}) are "
"excluded from the `TORCH_CUDA_ARCH_LIST` env variable "
f"({env_arch_list}). Supported CUDA/ROCM architectures are: "
f"{valid_archs}.",
stacklevel=2)
return arch_list
else:
return set()

# Filter out the invalid architectures and print a warning.
valid_archs = NVIDIA_SUPPORTED_ARCHS.union(
{s + "+PTX"
for s in NVIDIA_SUPPORTED_ARCHS})
arch_list = torch_arch_list.intersection(valid_archs)
# If none of the specified architectures are valid, raise an error.
if not arch_list:
raise RuntimeError(
"None of the CUDA/ROCM architectures in `TORCH_CUDA_ARCH_LIST` env "
f"variable ({env_arch_list}) is supported. "
f"Supported CUDA/ROCM architectures are: {valid_archs}.")
invalid_arch_list = torch_arch_list - valid_archs
if invalid_arch_list:
warnings.warn(
f"Unsupported CUDA/ROCM architectures ({invalid_arch_list}) are "
"excluded from the `TORCH_CUDA_ARCH_LIST` env variable "
f"({env_arch_list}). Supported CUDA/ROCM architectures are: "
f"{valid_archs}.",
stacklevel=2)
return arch_list



# First, check the TORCH_CUDA_ARCH_LIST environment variable.
compute_capabilities = get_torch_arch_list()
if _is_cuda() and not compute_capabilities:
@@ -227,15 +231,15 @@ def get_torch_arch_list() -> Set[str]:
if _is_cuda():
vllm_extension_sources.append("csrc/quantization/awq/gemm_kernels.cu")

vllm_extension = CUDAExtension(
name="vllm._C",
sources=vllm_extension_sources,
extra_compile_args={
"cxx": CXX_FLAGS,
"nvcc": NVCC_FLAGS,
},
)
ext_modules.append(vllm_extension)
vllm_extension = CUDAExtension(
name="vllm._C",
sources=vllm_extension_sources,
extra_compile_args={
"cxx": CXX_FLAGS,
"nvcc": NVCC_FLAGS,
},
)
ext_modules.append(vllm_extension)


def get_path(*filepath) -> str:
@@ -264,7 +268,7 @@ def get_vllm_version() -> str:
if hipcc_version != MAIN_CUDA_VERSION:
rocm_version_str = hipcc_version.replace(".", "")[:3]
version += f"+rocm{rocm_version_str}"
else:
elif _is_cuda():
cuda_version = str(nvcc_cuda_version)
if cuda_version != MAIN_CUDA_VERSION:
cuda_version_str = cuda_version.replace(".", "")[:3]
@@ -274,19 +278,18 @@


def read_readme() -> str:
"""Read the README file if present."""
p = get_path("README.md")
if os.path.isfile(p):
return io.open(get_path("README.md"), "r", encoding="utf-8").read()
else:
return ""
"""Read the README file."""
return io.open(get_path("README.md"), "r", encoding="utf-8").read()


def get_requirements() -> List[str]:
"""Get Python package dependencies from requirements.txt."""
if _is_hip():
with open(get_path("requirements-rocm.txt")) as f:
requirements = f.read().strip().split("\n")
elif not _is_cuda() and not _is_hip():
with open(get_path("requirements-hpu.txt")) as f:
requirements = f.read().strip().split("\n")
else:
with open(get_path("requirements.txt")) as f:
requirements = f.read().strip().split("\n")
@@ -320,6 +323,5 @@ def get_requirements() -> List[str]:
python_requires=">=3.8",
install_requires=get_requirements(),
ext_modules=ext_modules,
cmdclass={"build_ext": BuildExtension},
package_data={"vllm": ["py.typed"]},
cmdclass={"build_ext": BuildExtension} if _is_cuda() or _is_hip() else {},
)
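Taken together, the setup.py changes gate compilation and dependency selection on the detected toolchain: the C++/CUDA extension and its compiler flags are only configured when CUDA or HIP is present, and a build with neither falls through to the HPU requirements. A condensed sketch of that dispatch (mirroring the diff above; the helper name select_requirements_file is illustrative, not PR code):

import torch

def _is_hip() -> bool:
    return torch.version.hip is not None

def _is_cuda() -> bool:
    return torch.version.cuda is not None and torch.cuda.is_available()

def select_requirements_file() -> str:
    # Mirrors get_requirements() above: ROCm and CUDA keep their existing
    # files; a build with neither toolchain is assumed to target Gaudi/HPU.
    if _is_hip():
        return "requirements-rocm.txt"
    if _is_cuda():
        return "requirements.txt"
    return "requirements-hpu.txt"
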
35 changes: 26 additions & 9 deletions tests/conftest.py
@@ -3,6 +3,10 @@

import pytest
import torch
from vllm.utils import is_hpu
if is_hpu():
import habana_frameworks.torch.core as htcore
import habana_frameworks.torch.gpu_migration
from transformers import AutoModelForCausalLM

from vllm import LLM, SamplingParams
@@ -53,11 +57,18 @@ def __init__(
) -> None:
assert dtype in _STR_DTYPE_TO_TORCH_DTYPE
torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch_dtype,
trust_remote_code=True,
).cuda()
if is_hpu():
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch_dtype,
trust_remote_code=True,
)
else:
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch_dtype,
trust_remote_code=True,
).cuda()
if tokenizer_name is None:
tokenizer_name = model_name
self.tokenizer = get_tokenizer(tokenizer_name, trust_remote_code=True)
@@ -69,9 +80,12 @@ def generate(
) -> List[Tuple[List[int], str]]:
outputs: List[Tuple[List[int], str]] = []
for prompt in prompts:
input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
if is_hpu():
input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
else:
input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.cuda
output_ids = self.model.generate(
input_ids.cuda(),
input_ids,
use_cache=True,
**kwargs,
)
@@ -125,9 +139,12 @@ def generate_greedy_logprobs(
) -> List[List[torch.Tensor]]:
all_logprobs = []
for prompt in prompts:
input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
if is_hpu():
input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
else:
input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.cuda()
output = self.model.generate(
input_ids.cuda(),
input_ids,
use_cache=True,
do_sample=False,
max_new_tokens=max_tokens,
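The repeated is_hpu() branches in this file all encode the same rule: only call .cuda() when not running on Gaudi, where the Habana bridge and gpu_migration handle placement. A small helper like the one below (illustrative only, not part of the PR) would express that rule once:

from vllm.utils import is_hpu

def maybe_cuda(x):
    # On HPU, tensors and modules are left as-is; on NVIDIA they are moved
    # to the GPU explicitly.
    return x if is_hpu() else x.cuda()

# e.g. input_ids = maybe_cuda(self.tokenizer(prompt, return_tensors="pt").input_ids)
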
6 changes: 5 additions & 1 deletion tests/kernels/conftest.py
@@ -2,6 +2,7 @@

import pytest
import torch
from vllm.utils import is_hpu


def create_kv_caches(
@@ -18,7 +19,10 @@ def create_kv_caches(

scale = head_size**-0.5
x = 16 // torch.tensor([], dtype=dtype).element_size()
key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
if is_hpu():
key_cache_shape = (num_blocks, num_heads, head_size, block_size)
else:
key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
key_caches = []
for _ in range(num_layers):
key_cache = torch.empty(size=key_cache_shape,
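For reference, a worked example of the two key-cache layouts above, using assumed values (num_blocks=1024, num_heads=12, head_size=64, block_size=16, float16) that are illustrative and not taken from the PR:

import torch

dtype = torch.float16
num_blocks, num_heads, head_size, block_size = 1024, 12, 64, 16

x = 16 // torch.tensor([], dtype=dtype).element_size()               # 16 // 2 = 8
cuda_shape = (num_blocks, num_heads, head_size // x, block_size, x)  # (1024, 12, 8, 16, 8)
hpu_shape = (num_blocks, num_heads, head_size, block_size)           # (1024, 12, 64, 16)

# Both layouts hold the same number of key elements per cache; the CUDA layout
# splits head_size into (head_size // x, x) so that each innermost load covers
# 16 bytes, while the HPU layout keeps head_size contiguous.
assert num_heads * (head_size // x) * block_size * x == num_heads * head_size * block_size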