From 512c414874313f64fbdb725ece96c78e75e0c8cb Mon Sep 17 00:00:00 2001 From: Konrad Zawora Date: Mon, 19 Feb 2024 11:39:17 +0100 Subject: [PATCH] Enable HPU support in vLLM (#1) * Porting vllm to HPU * add hpu cache allocate * move slot_mapping to cpu and add is_prompt in cache_ops.reshape_and_cache * add bucket to input metadata * 1. limit max block number for lazy mode (TODO) 2. set some inpu metadata from cuda to cpu * remove bucket for block tables * add run bash script and change benchmark config * 1. modify kv cache structure to tensors 2. update hpu paged attention API (for hpu graph compatibility) * add attention mask for generation * add multi_query_kv_attention attn_bias * Temp commit * Integrate fused kernels for RMSNorm and RoPE * Resolve merge conflicts * Minor Gaudi workarounds, add debugging to stock vLLM API server * Fix post-merge pinned memory segfaults * Re-enable sequence decode * Maintain GPU compatibility in cache_engine * Adjust HPU RoPE for non-query runs * Integrate HPU primitive implementations * Add xops bindings * Cast paged attention inputs to bfloat16 * Remove leftover debug calls * Update comments on HPU ops * Restoring NVIDIA compatibility in setup.py * vllm.hpu cleanup * Added HPU-specific requirements * Restored full functionality on NVIDIA * vllm.core cleanup * vllm init cleanup * vllm.hpu cleanup * vllm.benchmarks cleanup * vllm.entrypoint cleanup * Changed is_hpu logic * vllm.benchmark cleanup * Fixed importing condition * tests cleanup * removed dummy printings * Update test_api_server.py * restored attention and logprobs tests functionality on Nvidia * throughput benchmark cleanup * Changed Habana copyright header * Restored alibi in bloom * Added BSD license header --------- Co-authored-by: Xiaotong Chen Co-authored-by: Jinyan Chen Co-authored-by: Mikhail Dvoretckii Co-authored-by: Sebastian Urwan --- benchmarks/benchmark_latency.py | 3 + benchmarks/benchmark_serving.py | 4 + benchmarks/benchmark_throughput.py | 10 +- .../kernels/benchmark_paged_attention.py | 6 +- requirements-hpu.txt | 14 + setup.py | 128 +-- tests/conftest.py | 35 +- tests/kernels/conftest.py | 6 +- tests/kernels/test_attention.py | 145 ++-- tests/kernels/test_cache.py | 6 +- vllm/engine/llm_engine.py | 2 +- vllm/entrypoints/api_server.py | 6 +- vllm/entrypoints/llm.py | 22 +- vllm/entrypoints/openai/api_server.py | 6 +- vllm/hpu/__init__.py | 6 + vllm/hpu/attn_bias.py | 764 ++++++++++++++++++ vllm/hpu/cache_ops.py | 41 + vllm/hpu/cuda_utils.py | 9 + vllm/hpu/ops.py | 137 ++++ vllm/hpu/rotary_embed.py | 117 +++ vllm/hpu/xops.py | 67 ++ vllm/model_executor/layers/activation.py | 6 +- vllm/model_executor/layers/attention.py | 141 ++-- vllm/model_executor/layers/layernorm.py | 23 +- .../model_executor/layers/quantization/awq.py | 6 +- .../layers/quantization/gptq.py | 6 +- .../layers/quantization/squeezellm.py | 7 +- .../model_executor/layers/rotary_embedding.py | 15 +- vllm/model_executor/sampling_metadata.py | 2 +- vllm/utils.py | 13 +- vllm/worker/cache_engine.py | 38 +- vllm/worker/model_runner.py | 8 +- vllm/worker/worker.py | 201 +++++ 33 files changed, 1796 insertions(+), 204 deletions(-) create mode 100644 requirements-hpu.txt create mode 100644 vllm/hpu/__init__.py create mode 100644 vllm/hpu/attn_bias.py create mode 100644 vllm/hpu/cache_ops.py create mode 100644 vllm/hpu/cuda_utils.py create mode 100644 vllm/hpu/ops.py create mode 100644 vllm/hpu/rotary_embed.py create mode 100644 vllm/hpu/xops.py diff --git a/benchmarks/benchmark_latency.py 
b/benchmarks/benchmark_latency.py index e33d5fb2dc247..f550aba060e38 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -6,6 +6,9 @@ import numpy as np import torch +from vllm.utils import is_hpu +if is_hpu(): + import habana_frameworks.torch as htorch from tqdm import tqdm from vllm import LLM, SamplingParams diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index 3a80e679191e3..bb28d700fc321 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -25,6 +25,10 @@ import aiohttp import numpy as np from transformers import PreTrainedTokenizerBase +import torch +from vllm.utils import is_hpu +if is_hpu(): + import habana_frameworks.torch as htorch from vllm.transformers_utils.tokenizer import get_tokenizer # (prompt len, output len, latency) diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index 3aac479c01bd2..9afb4721dd01c 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -6,6 +6,9 @@ from typing import List, Optional, Tuple import torch +from vllm.utils import is_hpu +if is_hpu(): + import habana_frameworks.torch as htorch from transformers import (AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase) from tqdm import tqdm @@ -71,6 +74,7 @@ def run_vllm( dtype: str, max_model_len: Optional[int], enforce_eager: bool, + profiling: bool = False, # For Gaudi2 ) -> float: from vllm import LLM, SamplingParams llm = LLM( @@ -104,7 +108,7 @@ def run_vllm( start = time.perf_counter() # FIXME(woosuk): Do not use internal method. - llm._run_engine(use_tqdm=True) + llm._run_engine(use_tqdm=True, profiling=profiling) end = time.perf_counter() return end - start @@ -206,7 +210,8 @@ def main(args: argparse.Namespace): args.quantization, args.tensor_parallel_size, args.seed, args.n, args.use_beam_search, args.trust_remote_code, args.dtype, - args.max_model_len, args.enforce_eager) + args.max_model_len, args.enforce_eager, + args.profiling) elif args.backend == "hf": assert args.tensor_parallel_size == 1 elapsed_time = run_hf(requests, args.model, tokenizer, args.n, @@ -284,6 +289,7 @@ def main(args: argparse.Namespace): parser.add_argument("--enforce-eager", action="store_true", help="enforce eager execution") + parser.add_argument("--profiling", action='store_true', help='Profiling first 4 steps') args = parser.parse_args() if args.tokenizer is None: args.tokenizer = args.model diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index 935393e9942ce..e47a5313c444c 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -4,7 +4,11 @@ import torch -from vllm._C import ops +from vllm.utils import is_hpu +if is_hpu(): + from vllm.hpu import ops +else: + from vllm._C import ops NUM_BLOCKS = 1024 PARTITION_SIZE = 512 diff --git a/requirements-hpu.txt b/requirements-hpu.txt new file mode 100644 index 0000000000000..73a64a94391f0 --- /dev/null +++ b/requirements-hpu.txt @@ -0,0 +1,14 @@ +ninja # For faster builds. +psutil +ray >= 2.5.1 +pandas # Required for Ray data. +pyarrow # Required for Ray data. +sentencepiece # Required for LLaMA tokenizer. +numpy +#torch == 2.1.2 +transformers >= 4.36.0 # Required for Mixtral. +#xformers == 0.0.23.post1 # Required for CUDA 12.1. +fastapi +uvicorn[standard] +pydantic == 1.10.13 # Required for OpenAI server. 
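# Illustrative aside (not part of the patch): the benchmark and entrypoint hunks
# above all gate the Habana imports on vllm.utils.is_hpu(), whose implementation
# is not shown in this section. A minimal sketch of such a probe, assuming
# detection by the presence of the Habana PyTorch bridge, could be:
#
#   import importlib.util
#
#   def is_hpu() -> bool:
#       # Treat an installed habana_frameworks package as "running on HPU".
#       return importlib.util.find_spec("habana_frameworks") is not None
#
#   if is_hpu():
#       import habana_frameworks.torch as htorch  # noqa: F401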
+aioprometheus[starlette] diff --git a/setup.py b/setup.py index 811d494e7a01f..f182a0084fae1 100644 --- a/setup.py +++ b/setup.py @@ -19,19 +19,26 @@ ROCM_SUPPORTED_ARCHS = {"gfx90a", "gfx908", "gfx906", "gfx1030", "gfx1100"} # SUPPORTED_ARCHS = NVIDIA_SUPPORTED_ARCHS.union(ROCM_SUPPORTED_ARCHS) - def _is_hip() -> bool: return torch.version.hip is not None def _is_cuda() -> bool: - return torch.version.cuda is not None + return torch.version.cuda is not None and torch.cuda.is_available() # Compiler flags. -CXX_FLAGS = ["-g", "-O2", "-std=c++17"] +CXX_FLAGS = [] # TODO(woosuk): Should we use -O3? -NVCC_FLAGS = ["-O2", "-std=c++17"] +NVCC_FLAGS = [] + +if _is_cuda() or _is_hip(): + CXX_FLAGS = ["-g", "-O2", "-std=c++17"] + NVCC_FLAGS = ["-O2", "-std=c++17"] + + ABI = 1 if torch._C._GLIBCXX_USE_CXX11_ABI else 0 + CXX_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"] + NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"] if _is_hip(): if ROCM_HOME is None: @@ -44,11 +51,6 @@ def _is_cuda() -> bool: raise RuntimeError( "Cannot find CUDA_HOME. CUDA must be available to build the package.") -ABI = 1 if torch._C._GLIBCXX_USE_CXX11_ABI else 0 -CXX_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"] -NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"] - - def get_amdgpu_offload_arch(): command = "/opt/rocm/llvm/bin/amdgpu-offload-arch" try: @@ -101,43 +103,45 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version: def get_torch_arch_list() -> Set[str]: - # TORCH_CUDA_ARCH_LIST can have one or more architectures, - # e.g. "8.0" or "7.5,8.0,8.6+PTX". Here, the "8.6+PTX" option asks the - # compiler to additionally include PTX code that can be runtime-compiled - # and executed on the 8.6 or newer architectures. While the PTX code will - # not give the best performance on the newer architectures, it provides - # forward compatibility. - env_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST", None) - if env_arch_list is None: - return set() - - # List are separated by ; or space. - torch_arch_list = set(env_arch_list.replace(" ", ";").split(";")) - if not torch_arch_list: + if _is_cuda() or _is_hip(): + # TORCH_CUDA_ARCH_LIST can have one or more architectures, + # e.g. "8.0" or "7.5,8.0,8.6+PTX". Here, the "8.6+PTX" option asks the + # compiler to additionally include PTX code that can be runtime-compiled + # and executed on the 8.6 or newer architectures. While the PTX code will + # not give the best performance on the newer architectures, it provides + # forward compatibility. + env_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST", None) + if env_arch_list is None: + return set() + + # List are separated by ; or space. + torch_arch_list = set(env_arch_list.replace(" ", ";").split(";")) + if not torch_arch_list: + return set() + + # Filter out the invalid architectures and print a warning. + valid_archs = NVIDIA_SUPPORTED_ARCHS.union( + {s + "+PTX" + for s in NVIDIA_SUPPORTED_ARCHS}) + arch_list = torch_arch_list.intersection(valid_archs) + # If none of the specified architectures are valid, raise an error. + if not arch_list: + raise RuntimeError( + "None of the CUDA/ROCM architectures in `TORCH_CUDA_ARCH_LIST` env " + f"variable ({env_arch_list}) is supported. " + f"Supported CUDA/ROCM architectures are: {valid_archs}.") + invalid_arch_list = torch_arch_list - valid_archs + if invalid_arch_list: + warnings.warn( + f"Unsupported CUDA/ROCM architectures ({invalid_arch_list}) are " + "excluded from the `TORCH_CUDA_ARCH_LIST` env variable " + f"({env_arch_list}). 
Supported CUDA/ROCM architectures are: " + f"{valid_archs}.", + stacklevel=2) + return arch_list + else: return set() - - # Filter out the invalid architectures and print a warning. - valid_archs = NVIDIA_SUPPORTED_ARCHS.union( - {s + "+PTX" - for s in NVIDIA_SUPPORTED_ARCHS}) - arch_list = torch_arch_list.intersection(valid_archs) - # If none of the specified architectures are valid, raise an error. - if not arch_list: - raise RuntimeError( - "None of the CUDA/ROCM architectures in `TORCH_CUDA_ARCH_LIST` env " - f"variable ({env_arch_list}) is supported. " - f"Supported CUDA/ROCM architectures are: {valid_archs}.") - invalid_arch_list = torch_arch_list - valid_archs - if invalid_arch_list: - warnings.warn( - f"Unsupported CUDA/ROCM architectures ({invalid_arch_list}) are " - "excluded from the `TORCH_CUDA_ARCH_LIST` env variable " - f"({env_arch_list}). Supported CUDA/ROCM architectures are: " - f"{valid_archs}.", - stacklevel=2) - return arch_list - - + # First, check the TORCH_CUDA_ARCH_LIST environment variable. compute_capabilities = get_torch_arch_list() if _is_cuda() and not compute_capabilities: @@ -227,15 +231,15 @@ def get_torch_arch_list() -> Set[str]: if _is_cuda(): vllm_extension_sources.append("csrc/quantization/awq/gemm_kernels.cu") -vllm_extension = CUDAExtension( - name="vllm._C", - sources=vllm_extension_sources, - extra_compile_args={ - "cxx": CXX_FLAGS, - "nvcc": NVCC_FLAGS, - }, -) -ext_modules.append(vllm_extension) + vllm_extension = CUDAExtension( + name="vllm._C", + sources=vllm_extension_sources, + extra_compile_args={ + "cxx": CXX_FLAGS, + "nvcc": NVCC_FLAGS, + }, + ) + ext_modules.append(vllm_extension) def get_path(*filepath) -> str: @@ -264,7 +268,7 @@ def get_vllm_version() -> str: if hipcc_version != MAIN_CUDA_VERSION: rocm_version_str = hipcc_version.replace(".", "")[:3] version += f"+rocm{rocm_version_str}" - else: + elif _is_cuda(): cuda_version = str(nvcc_cuda_version) if cuda_version != MAIN_CUDA_VERSION: cuda_version_str = cuda_version.replace(".", "")[:3] @@ -274,12 +278,8 @@ def get_vllm_version() -> str: def read_readme() -> str: - """Read the README file if present.""" - p = get_path("README.md") - if os.path.isfile(p): - return io.open(get_path("README.md"), "r", encoding="utf-8").read() - else: - return "" + """Read the README file.""" + return io.open(get_path("README.md"), "r", encoding="utf-8").read() def get_requirements() -> List[str]: @@ -287,6 +287,9 @@ def get_requirements() -> List[str]: if _is_hip(): with open(get_path("requirements-rocm.txt")) as f: requirements = f.read().strip().split("\n") + elif not _is_cuda() and not _is_hip(): + with open(get_path("requirements-hpu.txt")) as f: + requirements = f.read().strip().split("\n") else: with open(get_path("requirements.txt")) as f: requirements = f.read().strip().split("\n") @@ -320,6 +323,5 @@ def get_requirements() -> List[str]: python_requires=">=3.8", install_requires=get_requirements(), ext_modules=ext_modules, - cmdclass={"build_ext": BuildExtension}, - package_data={"vllm": ["py.typed"]}, + cmdclass={"build_ext": BuildExtension} if _is_cuda() or _is_hip() else {}, ) diff --git a/tests/conftest.py b/tests/conftest.py index 16c04e01d703c..fa24c667f93d3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,6 +3,10 @@ import pytest import torch +from vllm.utils import is_hpu +if is_hpu(): + import habana_frameworks.torch.core as htcore + import habana_frameworks.torch.gpu_migration from transformers import AutoModelForCausalLM from vllm import LLM, SamplingParams @@ -53,11 
+57,18 @@ def __init__( ) -> None: assert dtype in _STR_DTYPE_TO_TORCH_DTYPE torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype] - self.model = AutoModelForCausalLM.from_pretrained( - model_name, - torch_dtype=torch_dtype, - trust_remote_code=True, - ).cuda() + if is_hpu(): + self.model = AutoModelForCausalLM.from_pretrained( + model_name, + torch_dtype=torch_dtype, + trust_remote_code=True, + ) + else: + self.model = AutoModelForCausalLM.from_pretrained( + model_name, + torch_dtype=torch_dtype, + trust_remote_code=True, + ).cuda() if tokenizer_name is None: tokenizer_name = model_name self.tokenizer = get_tokenizer(tokenizer_name, trust_remote_code=True) @@ -69,9 +80,12 @@ def generate( ) -> List[Tuple[List[int], str]]: outputs: List[Tuple[List[int], str]] = [] for prompt in prompts: - input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids + if is_hpu(): + input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids + else: + input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.cuda output_ids = self.model.generate( - input_ids.cuda(), + input_ids, use_cache=True, **kwargs, ) @@ -125,9 +139,12 @@ def generate_greedy_logprobs( ) -> List[List[torch.Tensor]]: all_logprobs = [] for prompt in prompts: - input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids + if is_hpu(): + input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids + else: + input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.cuda() output = self.model.generate( - input_ids.cuda(), + input_ids, use_cache=True, do_sample=False, max_new_tokens=max_tokens, diff --git a/tests/kernels/conftest.py b/tests/kernels/conftest.py index 97516bd3052cf..17af2f5c3868d 100644 --- a/tests/kernels/conftest.py +++ b/tests/kernels/conftest.py @@ -2,6 +2,7 @@ import pytest import torch +from vllm.utils import is_hpu def create_kv_caches( @@ -18,7 +19,10 @@ def create_kv_caches( scale = head_size**-0.5 x = 16 // torch.tensor([], dtype=dtype).element_size() - key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x) + if is_hpu(): + key_cache_shape = (num_blocks, num_heads, head_size, block_size) + else: + key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x) key_caches = [] for _ in range(num_layers): key_cache = torch.empty(size=key_cache_shape, diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index 614b65f82ccbd..e7f2f5bb395ef 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -3,11 +3,18 @@ import pytest import torch -from xformers import ops as xops -from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask -from vllm._C import ops -from vllm.utils import get_max_shared_memory_bytes +from vllm.utils import get_max_shared_memory_bytes, is_hpu +if is_hpu(): + import habana_frameworks.torch.core as htcore + import habana_frameworks.torch.gpu_migration + from vllm.hpu import ops + from vllm.hpu import xops + from vllm.hpu.attn_bias import BlockDiagonalCausalMask +else: + from vllm._C import ops + from xformers import ops as xops + from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 # This will change depending on the compute capability. 
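As a side note on the fixture change above: the HPU and CUDA backends expect differently shaped key caches. A small self-contained sketch of the two layouts, mirroring the logic in create_kv_caches (illustrative only):

    import torch

    def key_cache_shape(num_blocks, num_heads, head_size, block_size,
                        dtype=torch.bfloat16, hpu=False):
        if hpu:
            # HPU keeps keys as (blocks, heads, head_size, block_size).
            return (num_blocks, num_heads, head_size, block_size)
        # The CUDA kernels split head_size by x = 16 // element_size for vectorized loads.
        x = 16 // torch.tensor([], dtype=dtype).element_size()
        return (num_blocks, num_heads, head_size // x, block_size, x)

    assert key_cache_shape(1024, 12, 64, 16, hpu=False) == (1024, 12, 8, 16, 8)
    assert key_cache_shape(1024, 12, 64, 16, hpu=True) == (1024, 12, 64, 16)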
@@ -16,6 +23,9 @@ NUM_BLOCKS = 40000 # Arbitrary values for testing PARTITION_SIZE = 512 +VERSION = ["v1", "v2"] +if is_hpu(): + VERSION.pop() DTYPES = [torch.half, torch.bfloat16, torch.float] NUM_GEN_SEQS = [7] # Arbitrary values for testing NUM_PREFILL_SEQS = [3] # Arbitrary values for testing @@ -71,8 +81,11 @@ def ref_single_query_cached_kv_attention( block_number = int(block_table[j // block_size]) block_offset = j % block_size - k = key_cache[block_number, :, :, block_offset, :] - k = k.reshape(num_kv_heads, head_size) + if is_hpu(): + k = key_cache[block_number, :, :, block_offset] + else: + k = key_cache[block_number, :, :, block_offset, :] + k = k.reshape(num_kv_heads, head_size) keys.append(k) v = value_cache[block_number, :, :, block_offset] @@ -97,7 +110,7 @@ def ref_single_query_cached_kv_attention( output[i].copy_(out, non_blocking=True) -@pytest.mark.parametrize("version", ["v1", "v2"]) +@pytest.mark.parametrize("version", VERSION) @pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @@ -161,41 +174,8 @@ def test_paged_attention( # Call the paged attention kernel. output = torch.empty_like(query) - if version == "v1": - ops.paged_attention_v1( - output, - query, - key_cache, - value_cache, - num_kv_heads, - scale, - block_tables, - context_lens, - block_size, - max_context_len, - alibi_slopes, - ) - elif version == "v2": - num_partitions = ((max_context_len + PARTITION_SIZE - 1) // - PARTITION_SIZE) - assert PARTITION_SIZE % block_size == 0 - num_seqs, num_heads, head_size = output.shape - tmp_output = torch.empty( - size=(num_seqs, num_heads, num_partitions, head_size), - dtype=output.dtype, - device=output.device, - ) - exp_sums = torch.empty( - size=(num_seqs, num_heads, num_partitions), - dtype=torch.float32, - device=output.device, - ) - max_logits = torch.empty_like(exp_sums) - ops.paged_attention_v2( - output, - exp_sums, - max_logits, - tmp_output, + if is_hpu(): + output = ops.paged_attention_v1( query, key_cache, value_cache, @@ -208,7 +188,54 @@ def test_paged_attention( alibi_slopes, ) else: - raise AssertionError(f"Unknown version: {version}") + if version == "v1": + ops.paged_attention_v1( + output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + context_lens, + block_size, + max_context_len, + alibi_slopes, + ) + elif version == "v2": + num_partitions = ((max_context_len + PARTITION_SIZE - 1) // + PARTITION_SIZE) + assert PARTITION_SIZE % block_size == 0 + num_seqs, num_heads, head_size = output.shape + tmp_output = torch.empty( + size=(num_seqs, num_heads, num_partitions, head_size), + dtype=output.dtype, + device=output.device, + ) + exp_sums = torch.empty( + size=(num_seqs, num_heads, num_partitions), + dtype=torch.float32, + device=output.device, + ) + max_logits = torch.empty_like(exp_sums) + ops.paged_attention_v2( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + context_lens, + block_size, + max_context_len, + alibi_slopes, + ) + else: + raise AssertionError(f"Unknown version: {version}") # Run the reference implementation. 
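One detail worth noting before the reference check below: on HPU, ops.paged_attention_v1 returns a fresh output tensor instead of filling the preallocated one in place, while the CUDA v2 path sizes its scratch buffers by ceiling division over 512-token partitions. An illustrative sketch of that bookkeeping (not part of the test):

    PARTITION_SIZE = 512

    def num_partitions(max_context_len: int) -> int:
        # One slice of exp_sums / max_logits / tmp_output per (seq, head, partition).
        return (max_context_len + PARTITION_SIZE - 1) // PARTITION_SIZE

    assert num_partitions(1) == 1
    assert num_partitions(512) == 1
    assert num_partitions(513) == 2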
ref_output = torch.empty_like(query) @@ -305,19 +332,31 @@ def test_multi_query_kv_attention( key = torch.repeat_interleave(key, num_queries_per_kv, dim=1) value = torch.repeat_interleave(value, num_queries_per_kv, dim=1) attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens) - output = xops.memory_efficient_attention_forward( - query.unsqueeze(0), - key.unsqueeze(0), - value.unsqueeze(0), - attn_bias=attn_bias, - p=0.0, - scale=scale, - ) - output = output.squeeze(0) - cu_seq_lens = [0] for seq_len in seq_lens: cu_seq_lens.append(cu_seq_lens[-1] + seq_len) + + if is_hpu(): + output = xops.memory_efficient_attention_forward( + query.unsqueeze(0), + key.unsqueeze(0), + value.unsqueeze(0), + cu_seq_lens, + attn_bias=attn_bias, + p=0.0, + scale=scale, + ) + else: + output = xops.memory_efficient_attention_forward( + query.unsqueeze(0), + key.unsqueeze(0), + value.unsqueeze(0), + attn_bias=attn_bias, + p=0.0, + scale=scale, + ) + output = output.squeeze(0) + ref_output = ref_multi_query_kv_attention( cu_seq_lens, query, diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index 9b5d7687a3fec..bdef59b3b86b1 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -3,7 +3,11 @@ import pytest import torch -from vllm._C import cache_ops +from vllm.utils import is_hpu +if is_hpu(): + from vllm.hpu import cache_ops +else: + from vllm._C import cache_ops DTYPES = [torch.half, torch.bfloat16, torch.float] NUM_TOKENS = [83] # Arbitrary values for testing diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index d6e388bf135b2..481ba1a17c808 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -217,7 +217,7 @@ def _init_cache(self) -> None: # Since we use a shared centralized controller, we take the minimum # number of blocks across all workers to make sure all the memory # operators can be applied to all workers. - num_gpu_blocks = min(b[0] for b in num_blocks) + num_gpu_blocks = min(10500, min(b[0] for b in num_blocks)) num_cpu_blocks = min(b[1] for b in num_blocks) # FIXME(woosuk): Change to debug log. logger.info(f"# GPU blocks: {num_gpu_blocks}, " diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py index 6910b3265dfd2..629f329a568c4 100644 --- a/vllm/entrypoints/api_server.py +++ b/vllm/entrypoints/api_server.py @@ -1,7 +1,11 @@ import argparse import json from typing import AsyncGenerator - +import torch +from vllm.utils import is_hpu +if is_hpu(): + import habana_frameworks.torch.core as htcore + import habana_frameworks.torch.gpu_migration from fastapi import FastAPI, Request from fastapi.responses import JSONResponse, Response, StreamingResponse import uvicorn diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 0700298b03a3d..8220ccd406f03 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -8,7 +8,9 @@ from vllm.outputs import RequestOutput from vllm.sampling_params import SamplingParams from vllm.utils import Counter +from vllm.utils import is_hpu +import torch class LLM: """An LLM for generating texts from given prompts and sampling parameters. @@ -174,11 +176,23 @@ def _add_request( self.llm_engine.add_request(request_id, prompt, sampling_params, prompt_token_ids) - def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]: + def _run_engine(self, use_tqdm: bool, profiling: bool = False) -> List[RequestOutput]: # Initialize tqdm. 
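For context on the profiling hook added below: torch.profiler.schedule(wait=6, warmup=0, active=2, repeat=1) skips the first six engine steps, records the next two, and then stops, with each prof.step() call advancing the schedule. A standalone sketch of the same pattern (CPU activity only here; the patch additionally registers the HPU activity, which requires the Habana bridge):

    import torch

    prof = torch.profiler.profile(
        schedule=torch.profiler.schedule(wait=6, warmup=0, active=2, repeat=1),
        activities=[torch.profiler.ProfilerActivity.CPU],
        on_trace_ready=torch.profiler.tensorboard_trace_handler("./", use_gzip=True),
    )
    prof.start()
    for _ in range(8):
        pass         # one engine step per iteration
        prof.step()  # advances the schedule; iterations 7 and 8 are the ones traced
    prof.stop()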
if use_tqdm: num_requests = self.llm_engine.get_num_unfinished_requests() pbar = tqdm(total=num_requests, desc="Processed prompts") + + if profiling and is_hpu(): + prof = torch.profiler.profile( + schedule = torch.profiler.schedule(wait=6, warmup=0, active=2, repeat=1), + activities = [torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.HPU], + with_stack = True, + record_shapes = False, + on_trace_ready = torch.profiler.tensorboard_trace_handler("./", use_gzip = True) + ) + prof.start() + count = 0 + # Run the engine. outputs: List[RequestOutput] = [] while self.llm_engine.has_unfinished_requests(): @@ -188,6 +202,12 @@ def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]: outputs.append(output) if use_tqdm: pbar.update(1) + if profiling and is_hpu(): + htorch.core.mark_step() + prof.step() + if profiling and is_hpu(): + htorch.hpu.synchronize() + prof.stop() if use_tqdm: pbar.close() # Sort the outputs by request ID. diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index be5f4190e633f..d3062e9220dd8 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -8,7 +8,11 @@ import time from http import HTTPStatus from typing import AsyncGenerator, Dict, List, Optional, Tuple, Union - +import torch +from vllm.utils import is_hpu +if is_hpu(): + import habana_frameworks.torch.core as htcore + import habana_frameworks.torch.gpu_migration from aioprometheus import MetricsMiddleware from aioprometheus.asgi.starlette import metrics import fastapi diff --git a/vllm/hpu/__init__.py b/vllm/hpu/__init__.py new file mode 100644 index 0000000000000..b8e4d3aac98a7 --- /dev/null +++ b/vllm/hpu/__init__.py @@ -0,0 +1,6 @@ +############################################################################### +# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +############################################################################### diff --git a/vllm/hpu/attn_bias.py b/vllm/hpu/attn_bias.py new file mode 100644 index 0000000000000..ff508a59cc56a --- /dev/null +++ b/vllm/hpu/attn_bias.py @@ -0,0 +1,764 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +import math +from dataclasses import dataclass +from typing import Any, Iterable, List, Optional, Sequence, Tuple, Union + +import torch + + +class AttentionBias: + """Base class for a custom bias that can be applied \ + as the attn_bias argument in + :attr:`xformers.ops.memory_efficient_attention`. + + That function has the ability to add a tensor, the + attention bias, to the QK^T matrix before it is used + in the softmax part of the attention calculation. + The attention bias tensor with shape + (B or 1, n_queries, number of keys) + can be given as the attn_bias input. + The most common use case is for an attention bias is + to contain only zeros and negative infinities, which forms + a mask so that some queries only attend to some keys. + + Children of this class define alternative things which can + be used as the attn_bias input to define an attention bias which + forms such a mask, for some common cases. 
+ + When using an :attr:`xformers.ops.AttentionBias` + instead of a :attr:`torch.Tensor`, the mask matrix does + not need to be materialized, and can be + hardcoded into some kernels for better performance. + + See: + + - :attr:`xformers.ops.fmha.attn_bias.LowerTriangularMask` + - :attr:`xformers.ops.fmha.attn_bias.LowerTriangularMaskWithTensorBias` + - :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalMask` + - :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask` + + """ + + def materialize( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + """ + Materializes the bias as a `torch.Tensor`. This is very slow + and we don't attempt to make it fast. Only use for debugging/testing. + + Shape should be like `[*, q_seqlen, k_seqlen]` + """ + raise NotImplementedError() + + +class LowerTriangularMask(AttentionBias): + """ + A lower-triangular (aka causal) mask + + A query Q cannot attend to a key which is farther from the + initial key than Q is from the initial query. + """ + + def __init__(self, *tensor_args, **tensor_kwargs) -> None: + # NOTE: Unused arguments, we keep them for backward compatibility + super().__init__() + + def materialize( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + create_as = dtype if dtype is not torch.bfloat16 else torch.float32 + tensor = torch.full( # type: ignore + shape, + dtype=create_as, + fill_value=float("-inf"), + device=device, + ) + return torch.triu(tensor, diagonal=1).to(dtype) # type: ignore + + def add_bias(self, bias: torch.Tensor) -> "LowerTriangularMaskWithTensorBias": + return LowerTriangularMaskWithTensorBias(bias) + + +class LowerTriangularMaskWithTensorBias(LowerTriangularMask): + """A lower-triangular (aka causal) mask with an additive bias""" + + def __init__(self, bias: torch.Tensor) -> None: + self._bias = bias + + def materialize( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + return super().materialize(shape, dtype=dtype, device=device) + self._bias + + +@dataclass +class _SeqLenInfo: + """ + (Internal) Represents the division of a dimension into blocks. + + For example, to represents a dimension of length 7 divided into + three blocks of lengths 2, 3 and 2, use `from_seqlength([2, 3, 2])`. 
+ The members will be: + max_seqlen: 3 + min_seqlen: 2 + seqstart_py: [0, 2, 5, 7] + seqstart: torch.IntTensor([0, 2, 5, 7]) + """ + + seqstart: torch.Tensor + max_seqlen: int + min_seqlen: int + seqstart_py: List[int] + + def to(self, device: torch.device) -> None: + self.seqstart = self.seqstart.to(device, non_blocking=True) + + def intervals(self) -> Iterable[Tuple[int, int]]: + yield from zip(self.seqstart_py, self.seqstart_py[1:]) + + @classmethod + def from_seqlens(cls, seqlens: Iterable[int]) -> "_SeqLenInfo": + """ + Input tensors are assumed to be in shape [B, M, *] + """ + assert not isinstance(seqlens, torch.Tensor) + seqstart_py = [0] + max_seqlen = -1 + min_seqlen = -1 + for seqlen in seqlens: + min_seqlen = min(min_seqlen, seqlen) if min_seqlen != -1 else seqlen + max_seqlen = max(max_seqlen, seqlen) + seqstart_py.append(seqstart_py[len(seqstart_py) - 1] + seqlen) + seqstart = torch.tensor(seqstart_py, dtype=torch.int32) + return cls( + max_seqlen=max_seqlen, + min_seqlen=min_seqlen, + seqstart=seqstart, + seqstart_py=seqstart_py, + ) + + def split( + self, x: torch.Tensor, batch_sizes: Optional[Sequence[int]] = None + ) -> List[torch.Tensor]: + if self.seqstart_py[-1] != x.shape[1] or x.shape[0] != 1: + raise ValueError( + f"Invalid `torch.Tensor` of shape {x.shape}, expected format " + f"(B, M, *) with B=1 and M={self.seqstart_py[-1]}\n" + f" seqstart: {self.seqstart_py}" + ) + if batch_sizes is None: + batch_sizes = [1] * (len(self.seqstart_py) - 1) + split_chunks = [] + it = 0 + for batch_size in batch_sizes: + split_chunks.append( + self.seqstart_py[it + batch_size] - self.seqstart_py[it] + ) + it += batch_size + return [ + tensor.reshape([bs, -1, *tensor.shape[2:]]) + for bs, tensor in zip(batch_sizes, x.split(split_chunks, dim=1)) + ] + + +@dataclass +class _PaddedSeqLenInfo(_SeqLenInfo): + """ + (Internal) Represents the division of a dimension into blocks which are + padded out to the same total length. + + For example, to represent a dimension of length 12 with space for + three blocks of length 4, but where the occupied lengths are + 2, 3 and 2, use `from_seqlens_padded([2, 3, 2], 4)`. 
+ + The layout along the dimension is + + 0 ─► block 0 + block 0 + + + 4 ─► block 1 + block 1 + block 1 + + 8 ─► block 2 + block 2 + + + 12 ─► + + The members will be: + max_seqlen: 3 + min_seqlen: 2 + seqstart_py: [0, 4, 8, 12] + seqstart: torch.IntTensor([0, 4, 8, 12]) + seqlen_py: [2, 3, 2] + seqlen: torch.IntTensor([2, 3, 2]) + padding: 4 + """ + + seqlen: torch.Tensor + seqlen_py: Sequence[int] + padding: int + # From parent: seqstart[i] contains the start position + # of the i-th sequence + # seqstart: torch.Tensor + + def __post_init__(self) -> None: + assert len(self.seqstart_py) == len(self.seqlen_py) + 1 + + def to(self, device: torch.device) -> None: + self.seqlen = self.seqlen.to(device, non_blocking=True) + super().to(device) + + def intervals(self) -> Iterable[Tuple[int, int]]: + for (start, _), length in zip(super().intervals(), self.seqlen_py): + yield start, start + length + + @classmethod + def from_seqlens(cls, seqlens: Iterable[int]) -> "_SeqLenInfo": + raise RuntimeError( + "Use either `_SeqLenInfo.from_seqlens` or `_PaddedSeqLenInfo.from_seqlens_padded`" + ) + + @classmethod + def from_seqlens_padded( + cls, seqlens: Sequence[int], padding: int + ) -> "_PaddedSeqLenInfo": + """ + Input tensors are assumed to be in shape [B, M, *] + seqstart = padding * torch.arange(batch_size) + """ + assert not isinstance(seqlens, torch.Tensor) + assert all(seqlen <= padding for seqlen in seqlens) + seqstart_py = list(range(0, len(seqlens) * padding + 1, padding)) + return cls( + seqlen=torch.tensor(seqlens, dtype=torch.int32), + seqlen_py=seqlens, + max_seqlen=max(seqlens), + min_seqlen=min(seqlens), + seqstart=torch.tensor(seqstart_py, dtype=torch.int32), + seqstart_py=seqstart_py, + padding=padding, + ) + + def split( + self, x: torch.Tensor, batch_sizes: Optional[Sequence[int]] = None + ) -> List[torch.Tensor]: + raise NotImplementedError("_PaddedSeqLenInfo.split") + + +@dataclass +class BlockDiagonalMask(AttentionBias): + """ + A block-diagonal mask that can be passed as ``attn_bias`` + argument to :attr:`xformers.ops.memory_efficient_attention`. + + Queries and Keys are each divided into the same number of blocks. + Queries in block i only attend to keys in block i. + + .. figure:: /_static/block_diag_bias.png + + This bias can be used to handle a batch of sequences of + different lengths, via :attr:`BlockDiagonalMask.from_tensor_list` + + :Example: + + .. 
code-block:: python + + import torch + from xformers.ops import fmha + + K = 16 + dtype = torch.float16 + device = "cuda" + list_x = [ + torch.randn([1, 3, 1, K], dtype=dtype, device=device), + torch.randn([1, 6, 1, K], dtype=dtype, device=device), + torch.randn([1, 2, 1, K], dtype=dtype, device=device), + ] + attn_bias, x = fmha.BlockDiagonalMask.from_tensor_list(list_x) + linear = torch.nn.Linear(K, K * 3).to(device=device, dtype=dtype) + + q, k, v = linear(x).reshape([1, -1, 1, 3, K]).unbind(-2) + out = fmha.memory_efficient_attention(q, k, v, attn_bias=attn_bias) + list_out = attn_bias.split(out) + print(list_out[0].shape) # [1, 3, 1, K] + assert tuple(list_out[0].shape) == (1, 3, 1, K) + + """ + + q_seqinfo: _SeqLenInfo + k_seqinfo: _SeqLenInfo + _batch_sizes: Optional[Sequence[int]] = None + + def _create_block_mask( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + return torch.zeros( + shape, + dtype=dtype, + device=device, + ) + + def materialize( + self, + shape: Optional[Tuple[int, ...]] = None, + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + """Materialize the attention bias - for debugging & testing""" + if shape is None: + shape = (self.q_seqinfo.seqstart_py[-1], + self.k_seqinfo.seqstart_py[-1]) + assert shape[-1] == self.k_seqinfo.seqstart_py[-1], ( + shape[-1], + self.k_seqinfo.seqstart_py[-1], + ) + assert shape[-2] == self.q_seqinfo.seqstart_py[-1], ( + shape[-2], + self.q_seqinfo.seqstart_py[-1], + ) + mask = torch.empty(shape[-2:], dtype=dtype, device=device) + mask.fill_(-math.inf) + for i, ((q_start, q_end), (k_start, k_end)) in enumerate( + zip( + self.q_seqinfo.intervals(), + self.k_seqinfo.intervals(), + ) + ): + mask[q_start:q_end, k_start:k_end] = self._create_block_mask( + (q_end - q_start, k_end - k_start), + dtype=dtype, + device=device, + ) + for _ in range(len(shape) - 2): + mask = mask.unsqueeze(0) + return mask.expand(shape) + + @classmethod + def from_seqlens( + cls, + q_seqlen: Sequence[int], + kv_seqlen: Optional[Sequence[int]] = None, + ) -> "BlockDiagonalMask": + """Creates a :attr:`BlockDiagonalMask` from a list of tensors lengths for query and key/value. + + Args: + q_seqlen (Union[Sequence[int], torch.Tensor]): List or tensor of sequence lengths for query tensors + kv_seqlen (Union[Sequence[int], torch.Tensor], optional): List or tensor of sequence lengths for key/value. + (Defaults to ``q_seqlen``.) + Returns: + BlockDiagonalMask + """ + assert kv_seqlen is None or len(q_seqlen) == len(kv_seqlen) + q_seqinfo = _SeqLenInfo.from_seqlens(q_seqlen) + if kv_seqlen is None or q_seqlen == kv_seqlen: + k_seqinfo = q_seqinfo + else: + k_seqinfo = _SeqLenInfo.from_seqlens(kv_seqlen) + return cls(q_seqinfo=q_seqinfo, k_seqinfo=k_seqinfo) + + @classmethod + def from_tensor_list( + cls, + tensors: Sequence[torch.Tensor], + ) -> Tuple["BlockDiagonalMask", torch.Tensor]: + """Creates a :attr:`BlockDiagonalMask` from a list of tensors, and returns the tensors + concatenated on the sequence length dimension + + .. figure:: /_static/block_diag_cat_split.png + + See also :attr:`BlockDiagonalMask.split` to split the returned + :attr:`torch.Tensor` back to a list of tensors of varying sequence length + + Args: + tensors (Sequence[torch.Tensor]): A list of tensors of shape ``[B, M_i, *]``. + All tensors should have the same dimension and the same batch size ``B``, but + they can have different sequence length ``M``. 
+ + Returns: + Tuple[BlockDiagonalMask, torch.Tensor]: The corresponding bias for the attention + along with `tensors` concatenated on the sequence length dimension, with shape ``[1, sum_i{M_i}, *]`` + """ + batch_sizes = [tensor.shape[0] for tensor in tensors] + seqlens = [] + for x in tensors: + for _ in range(x.shape[0]): + seqlens.append(x.shape[1]) + block_diag = cls.from_seqlens(seqlens) + block_diag._batch_sizes = batch_sizes + tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in tensors) + concat_tensors = torch.cat(tensors_bs1, dim=1) + return block_diag, concat_tensors + + @classmethod + def from_tensor_lists_qkv( + cls, + tensors_q: Sequence[torch.Tensor], + tensors_k: Sequence[torch.Tensor], + tensors_v: Optional[Sequence[torch.Tensor]] = None, + ) -> Tuple["BlockDiagonalMask", torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + assert len(tensors_q) == len(tensors_k) + assert tensors_v is None or len(tensors_v) == len(tensors_q) + batch_sizes = [tensor.shape[0] for tensor in tensors_q] + q_seqlens, kv_seqlens = [], [] + for i, (q, k) in enumerate(zip(tensors_q, tensors_k)): + assert q.shape[0] == k.shape[0] + q_seqlens += [q.shape[1]] * q.shape[0] + kv_seqlens += [k.shape[1]] * k.shape[0] + assert tensors_v is None or tensors_v[i].shape[:2] == k.shape[:2] + block_diag = cls.from_seqlens(q_seqlens, kv_seqlens) + block_diag._batch_sizes = batch_sizes + return ( + block_diag, + torch.cat([x.reshape([1, -1, *x.shape[2:]]) for x in tensors_q], dim=1), + torch.cat([x.reshape([1, -1, *x.shape[2:]]) for x in tensors_k], dim=1), + torch.cat([x.reshape([1, -1, *x.shape[2:]]) for x in tensors_v], dim=1) + if tensors_v is not None + else None, + ) + + def split_queries(self, tensor: torch.Tensor) -> Sequence[torch.Tensor]: + return self.q_seqinfo.split(tensor, self._batch_sizes) + + def split_kv(self, tensor: torch.Tensor) -> Sequence[torch.Tensor]: + return self.k_seqinfo.split(tensor, self._batch_sizes) + + def split(self, tensor: torch.Tensor) -> Sequence[torch.Tensor]: + """The inverse operation of :attr:`BlockDiagonalCausalMask.from_tensor_list` + + Args: + tensor (torch.Tensor): Tensor of tokens of shape ``[1, sum_i{M_i}, *]`` + + Returns: + Sequence[torch.Tensor]: A list of tokens with possibly different sequence lengths + """ + assert self.q_seqinfo is self.k_seqinfo + return self.q_seqinfo.split(tensor, self._batch_sizes) + + def make_causal(self) -> "BlockDiagonalCausalMask": + """Makes each block causal""" + return BlockDiagonalCausalMask( + q_seqinfo=self.q_seqinfo, + k_seqinfo=self.k_seqinfo, + _batch_sizes=self._batch_sizes, + ) + + def make_causal_from_bottomright(self) -> "BlockDiagonalCausalFromBottomRightMask": + """Makes each block causal with a possible non-causal prefix""" + return BlockDiagonalCausalFromBottomRightMask( + q_seqinfo=self.q_seqinfo, + k_seqinfo=self.k_seqinfo, + _batch_sizes=self._batch_sizes, + ) + + def make_local_attention( + self, window_size: int + ) -> "BlockDiagonalCausalLocalAttentionMask": + """Experimental: Makes each block causal with local attention""" + return BlockDiagonalCausalLocalAttentionMask( + q_seqinfo=self.q_seqinfo, + k_seqinfo=self.k_seqinfo, + _batch_sizes=self._batch_sizes, + _window_size=window_size, + ) + + def make_local_attention_from_bottomright( + self, window_size: int + ) -> "BlockDiagonalCausalLocalAttentionFromBottomRightMask": + """Experimental: Makes each block causal with local attention, start from bottom right""" + return BlockDiagonalCausalLocalAttentionFromBottomRightMask( + 
q_seqinfo=self.q_seqinfo, + k_seqinfo=self.k_seqinfo, + _batch_sizes=self._batch_sizes, + _window_size=window_size, + ) + + +@dataclass +class BlockDiagonalCausalMask(BlockDiagonalMask): + """ + Same as :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalMask`, except that each block is causal. + + Queries and Keys are each divided into the same number of blocks. + A query Q in block i cannot attend to a key which is not in block i, + nor one which is farther from the initial key in block i than Q + is from the initial query in block i. + """ + + def _create_block_mask( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + return LowerTriangularMask().materialize( + shape, + dtype=dtype, + device=device, + ) + + +@dataclass +class BlockDiagonalCausalFromBottomRightMask(BlockDiagonalMask): + """ + Same as :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalMask`, except that each block is causal. + This mask allows for a non-causal prefix + NOTE: Each block should have `num_keys >= num_queries` otherwise the forward pass is not + defined (softmax of vector of `-inf` in the attention) + + Queries and keys are each divided into the same number of blocks. + A query Q in block i cannot attend to a key which is not in block i, + nor one which nearer the final key in block i than Q is to the + final query in block i. + """ + + def __post_init__(self) -> None: + for i, ((q_start, q_end), (k_start, k_end)) in enumerate( + zip( + self.q_seqinfo.intervals(), + self.k_seqinfo.intervals(), + ) + ): + num_queries = q_end - q_start + num_keys = k_end - k_start + if num_keys < num_queries: + raise ValueError( + f"Block #{i} has num_keys={num_keys} and num_queries={num_queries}." + " Expected `num_keys >= num_queries`" + ) + + def _create_block_mask( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + create_as = dtype if dtype is not torch.bfloat16 else torch.float32 + tensor = torch.full( # type: ignore + shape, + dtype=create_as, + fill_value=float("-inf"), + device=device, + ) + num_queries, num_keys = shape[-2:] + return torch.triu(tensor, diagonal=num_keys - num_queries + 1).to(dtype) # type: ignore + + +@dataclass +class BlockDiagonalCausalWithOffsetPaddedKeysMask(AttentionBias): + """ + Same as :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask`, + except an offset on causality is allowed for each block and we support padding for k/v + + The keys and values are divided into blocks which are padded out to + the same total length. + For example, if there is space for 12 keys, for three blocks of + max length 4, but we only want to use the first 2, 3 and 2 + of each block, use `kv_padding=4` and `kv_seqlens=[2, 3, 2]`. + The queries are divided into blocks, without padding, of lengths given by + q_seqlen. + + A query Q in block i cannot attend to a key which is not in block i, + nor one which is not in use (i.e. in the padded area), + nor one which is nearer to the final key in block i + than Q is to the final query in block i. + """ + + q_seqinfo: _SeqLenInfo + k_seqinfo: _PaddedSeqLenInfo + causal_diagonal: Any = None # unused. Exists for BC only. 
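A worked example of the padded layout this mask relies on, mirroring the _PaddedSeqLenInfo docstring above (illustrative; runnable within this module):

    info = _PaddedSeqLenInfo.from_seqlens_padded([2, 3, 2], padding=4)
    assert info.seqstart_py == [0, 4, 8, 12]          # each block starts at a multiple of the padding
    assert list(info.intervals()) == [(0, 2), (4, 7), (8, 10)]  # only the occupied prefix of each block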
+ + def _create_block_mask( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + create_as = dtype if dtype is not torch.bfloat16 else torch.float32 + tensor = torch.full( # type: ignore + shape, + dtype=create_as, + fill_value=float("-inf"), + device=device, + ) + num_queries, num_keys = shape[-2:] + return torch.triu(tensor, diagonal=1 + num_keys - num_queries).to(dtype) # type: ignore + + def materialize( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + """Materialize the attention bias - for debugging & testing""" + if shape[-1] != self.k_seqinfo.seqstart_py[-1]: + raise ValueError("k shapes wrong") + if shape[-2] != self.q_seqinfo.seqstart_py[-1]: + raise ValueError("q shapes wrong") + mask = torch.empty(shape[-2:], dtype=dtype, device=device) + mask.fill_(-math.inf) + for i, ((q_start, q_end), (k_start, k_end)) in enumerate( + zip( + self.q_seqinfo.intervals(), + self.k_seqinfo.intervals(), + ) + ): + mask[q_start:q_end, k_start:k_end] = self._create_block_mask( + (q_end - q_start, k_end - k_start), + dtype=dtype, + device=device, + ) + for _ in range(len(shape) - 2): + mask = mask.unsqueeze(0) + return mask.expand(shape) + + @classmethod + def from_seqlens( + cls, + q_seqlen: Sequence[int], + kv_padding: int, + kv_seqlen: Sequence[int], + causal_diagonal: Any = None, + ) -> "BlockDiagonalCausalWithOffsetPaddedKeysMask": + """Creates a :attr:`BlockDiagonalCausalWithOffsetPaddedKeysMask` from a list of tensor + lengths for query and key/value. + + Args: + q_seqlen (Sequence[int]): List or tensor of sequence lengths for query tensors + kv_padding (int): Padding for k/v - also an upperbound on each individual key length + kv_seqlen (Sequence[int]): List or tensor of sequence lengths for key/value. + causal_diagonal: unused, for BC only + Returns: + BlockDiagonalCausalWithOffsetPaddedKeysMask + """ + assert kv_seqlen is None or len(q_seqlen) == len(kv_seqlen), ( + q_seqlen, + kv_seqlen, + ) + q_seqinfo = _SeqLenInfo.from_seqlens(q_seqlen) + k_seqinfo = _PaddedSeqLenInfo.from_seqlens_padded(kv_seqlen, kv_padding) + return cls(q_seqinfo=q_seqinfo, k_seqinfo=k_seqinfo) + + +@dataclass +class BlockDiagonalCausalLocalAttentionMask(BlockDiagonalCausalMask): + """ + (Experimental feature) + Same as :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask`. + This makes the mask "local" and the attention pattern banded. + + Query i only attends to keys in its block and cannot attend keys further than "window_size" + from it. + """ + + _window_size: int = 0 # forced due to inheritance and default arguments + + def __post_init__(self): + if self._window_size <= 0: + raise ValueError( + f"Expected `window_size > 0`, but window_size={self._window_size}" + ) + q_seqlen = [ + y - x + for x, y in zip( + self.q_seqinfo.seqstart_py[:-1], self.q_seqinfo.seqstart_py[1:] + ) + ] + kv_seqlen = [ + y - x + for x, y in zip( + self.k_seqinfo.seqstart_py[:-1], self.k_seqinfo.seqstart_py[1:] + ) + ] + for q, k in zip(q_seqlen, kv_seqlen): + if q - self._window_size >= k: + # Each query only attends to keys no further than window_size back. + # When q > k + window_size, there will be a query for which the window doesn't reach any key. 
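                # Illustrative check: with window_size=2, q_seqlen=5 and kv_seqlen=2, query 4
                # could only attend keys {3, 4}, none of which exist (only keys 0-1 do), and
                # indeed 5 - 2 >= 2 triggers the error below.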
+ raise RuntimeError( + f"No keys are attended in q_seqlen {q} k_seqlen {k} with sliding window {self._window_size}" + ) + + def _create_block_mask( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + create_as = dtype if dtype is not torch.bfloat16 else torch.float32 + tensor = torch.full( # type: ignore + shape, + dtype=create_as, + fill_value=1, + device=device, + ) + + num_queries, num_keys = shape[-2:] + mask = torch.tril(tensor, diagonal=0).to(dtype) # type: ignore + if self._window_size is not None and self._window_size > 0: + mask = torch.triu(mask, diagonal=-self._window_size + 1) + mask = torch.log(mask) + return mask.to(dtype) + + +@dataclass +class BlockDiagonalCausalLocalAttentionFromBottomRightMask( + BlockDiagonalCausalFromBottomRightMask +): + """ + (Experimental feature) + Same as :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask`. + This makes the mask "local" and the attention pattern banded. + + Query i only attends to keys in its block and cannot attend keys further than "window_size" + from it. + """ + + _window_size: int = 0 # forced due to inheritance and default arguments + + def __post_init__(self): + super().__post_init__() + if self._window_size <= 0: + raise ValueError( + f"Expected `window_size > 0`, but window_size={self._window_size}" + ) + + def _create_block_mask( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + create_as = dtype if dtype is not torch.bfloat16 else torch.float32 + tensor = torch.full( # type: ignore + shape, + dtype=create_as, + fill_value=1, + device=device, + ) + num_queries, num_keys = shape[-2:] + mask = torch.tril(tensor, diagonal=num_keys - num_queries).to(dtype) # type: ignore + if self._window_size is not None: + mask = torch.triu( + mask, diagonal=num_keys - num_queries - self._window_size + 1 + ) + mask = torch.log(mask) + return mask.to(dtype) diff --git a/vllm/hpu/cache_ops.py b/vllm/hpu/cache_ops.py new file mode 100644 index 0000000000000..fb08e4167a10a --- /dev/null +++ b/vllm/hpu/cache_ops.py @@ -0,0 +1,41 @@ +############################################################################### +# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
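# Illustrative note on reshape_and_cache below: with block_size = key_cache.shape[-1],
# a slot_mapping entry s addresses block s // block_size at offset s % block_size;
# prompt tokens are copied one whole block at a time, decode tokens one slot at a time.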
+############################################################################### + +from typing import Tuple +import torch +import habana_frameworks.torch as htorch + + +def reshape_and_cache(key, value, key_cache, value_cache, slot_mapping, is_prompt=False): + """ + key: [num_tokens, num_heads, head_size] + value: [num_tokens, num_heads, head_size] + key_cache: [num_heads, head_size, block_size] * num_blocks + value_cache: [num_heads, head_size, block_size] * num_blocks + slot_mapping: [num_tokens] + """ + num_tokens = key.shape[0] + block_size = key_cache.shape[-1] + slot_mapping = slot_mapping.to(key.device) + block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor") + if is_prompt: + for i in range(0, num_tokens, block_size): + key_cache.index_put_([block_indices[i]], key[i:i+block_size].transpose(0,1).transpose(1,2)) + value_cache.index_put_([block_indices[i]], value[i:i+block_size].transpose(0,1).transpose(1,2)) + else: + key_cache = key_cache.permute(0, 3, 1, 2) + value_cache = value_cache.permute(0, 3, 1, 2) + block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor") + block_offsets = torch.fmod(slot_mapping, block_size) + slot_indices = torch.stack([block_indices, block_offsets], dim=-1) + index = torch.tensor(0, device=key.device) + for i in range(num_tokens): + key_cache[slot_indices[i][0], slot_indices[i][1], :, :] = key[i] + value_cache[slot_indices[i][0], slot_indices[i][1], :, :] = value[i] + index.add_(1) + key_cache = key_cache.permute(0, 2, 3, 1) + value_cache = value_cache.permute(0, 2, 3, 1) diff --git a/vllm/hpu/cuda_utils.py b/vllm/hpu/cuda_utils.py new file mode 100644 index 0000000000000..bec242cf985c2 --- /dev/null +++ b/vllm/hpu/cuda_utils.py @@ -0,0 +1,9 @@ +############################################################################### +# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +############################################################################### + +def get_device_attribute(attribute, device_id): + return 10240 # TODO: fake value now diff --git a/vllm/hpu/ops.py b/vllm/hpu/ops.py new file mode 100644 index 0000000000000..79f8f186a2b21 --- /dev/null +++ b/vllm/hpu/ops.py @@ -0,0 +1,137 @@ +############################################################################### +# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+############################################################################### + +import torch +import torch.nn as nn +import torch.nn.functional as F +import habana_frameworks.torch as htorch +from typing import List, Optional, Tuple + +def silu_and_mul(output, input): + htorch.core.mark_step() + d = input.shape[-1] // 2 + silu = torch.nn.SiLU().to(input.device) + x, y = torch.split(input, d, dim=-1) + output.copy_(silu(x) * y) + htorch.core.mark_step() + +def gelu_new(output, input): + raise NotImplementedError + +def gelu_fast(output, input): + raise NotImplementedError + +def paged_attention_v1(query_in, key_cache_in, value_cache_in, head_mapping, scale, block_tables, context_lens, block_size, max_context_len, alibi_slopes, attn_masks=None) -> None: + query = query_in.bfloat16() + key_cache = key_cache_in.bfloat16() + value_cache = value_cache_in.bfloat16() + num_kv_heads = value_cache[0].shape[0] + head_size = value_cache[0].shape[1] + block_size = value_cache[0].shape[2] + num_seqs = query.shape[0] + num_query_heads = query.shape[1] + max_num_blocks_per_seq = block_tables.shape[1] + + if alibi_slopes or num_query_heads != num_kv_heads: #or attn_masks is None: + raise NotImplementedError + + attn_weights_blocks = [] + value_blocks = [] + seq_index = torch.tensor([0], dtype=torch.int64, device="hpu") + + for i in range(0, max_num_blocks_per_seq): + # FIXME: dynamic hard override for filler. These blocks would contribute nothing to the output due to zero attention_probs and + # will clog up compute resources. The override itself makes the code unsuitable for graph precompilation + if (i - 2) * block_size > torch.max(context_lens): + break + attn_weights = torch.full((num_seqs, num_query_heads, 1, block_size), torch.finfo(query.dtype).min, dtype=query.dtype, device="hpu") + values = torch.zeros((num_seqs, num_query_heads, head_size, block_size), dtype=query.dtype, device="hpu") + for seq_id in range(num_seqs): + seq_index.fill_(seq_id) + if i * block_size < context_lens[seq_id]: + + q = torch.index_select(query, 0, seq_index).transpose(0, 1) + key = torch.index_select(key_cache, 0, block_tables[seq_id][i]).squeeze(0) + attn_weight = scale * torch.matmul(q, key) + + if attn_masks is not None: + attn_mask = torch.index_select(attn_masks[i], 0, seq_index) + attn_weight = torch.masked_fill(attn_weight, ~(attn_mask.unsqueeze(0).to(torch.bool)), torch.finfo(attn_weight.dtype).min) + + # FIXME: these dynamic checks serve to ensure the -inf default value is not overwritten with fillers that would cause errors + # in logsoftmax computation. 
A change to custom block multiplication code is required to avoid incurring extra costs here + if context_lens[seq_id] < (i + 1) * block_size: + if context_lens[seq_id] - i*block_size < 0: + attn_weight = torch.finfo(query.dtype).min + else: + attn_weight[:, :, context_lens[seq_id] - i*block_size:] = torch.finfo(query.dtype).min + attn_weights.index_copy_(0, seq_index, attn_weight.unsqueeze(0)) + value = torch.index_select(value_cache, 0, block_tables[seq_id][i]) + # FIXME: these checks concern filler values in the V cache and should be removed once the underlying issue is addressed + value = torch.nan_to_num(value) + value[value < -1.0e+30] = 0.0 + values.index_copy_(0, seq_index, value) + torch.hpu.synchronize() + + attn_weights_blocks.append(attn_weights.reshape(num_seqs * num_query_heads, 1, block_size)) + value_blocks.append(values.reshape(num_seqs * num_query_heads, head_size, block_size).transpose(1, 2)) + + exp_sum = torch.zeros((*attn_weights_blocks[0].shape[:2], 1), dtype=attn_weights_blocks[0].dtype, device="hpu") + for x in attn_weights_blocks: + exp_sum.add_(torch.exp(x).sum(dim=-1, keepdim=True)) + + output = torch.zeros_like(query) + for i in range(len(attn_weights_blocks)): + attention_probs = torch.exp(attn_weights_blocks[i]) / exp_sum + value = value_blocks[i] + out = torch.matmul(attention_probs.to(value.dtype), value).reshape(num_seqs, num_query_heads, head_size) + output.add_(out) + htorch.core.mark_step() + return output.to(dtype=query_in.dtype) + +def rms_norm(out, hidden_states, weight, eps): + htorch.core.mark_step() + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + eps) + out.copy_(weight * hidden_states.to(input_dtype)) + htorch.core.mark_step() + +def rotate_neox(x: torch.Tensor) -> torch.Tensor: + x1 = x[..., :x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=-1) + + +def rotate_gptj(x: torch.Tensor) -> torch.Tensor: + x1 = x[..., ::2] + x2 = x[..., 1::2] + x = torch.stack((-x2, x1), dim=-1) + return x.flatten(-2) + + +def apply_rope( + q: torch.Tensor, + k: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + is_neox_style: bool, +) -> Tuple[torch.Tensor, torch.Tensor]: + rotate_fn = rotate_neox if is_neox_style else rotate_gptj + q_embed = (q * cos) + (rotate_fn(q) * sin) + k_embed = (k * cos) + (rotate_fn(k) * sin) + return q_embed, k_embed + + +def rotary_embedding(positions, query, key, head_size, cos_sin_cache, is_neox_style): + # FIXME: the below code is unused legacy code not meant to be used. Use FusedRoPE + # on HPU and delete this once coverage is verified + raise NotImplementedError + +def awq_gemm(*args): + raise NotImplementedError diff --git a/vllm/hpu/rotary_embed.py b/vllm/hpu/rotary_embed.py new file mode 100644 index 0000000000000..3def58b11feb6 --- /dev/null +++ b/vllm/hpu/rotary_embed.py @@ -0,0 +1,117 @@ +############################################################################### +# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +############################################################################### + +import torch +import torch.nn as nn + +def get_device_name(): + """ + Returns the name of the current device: Gaudi or Gaudi2. 
+ + Inspired by: https://github.com/HabanaAI/Model-References/blob/a87c21f14f13b70ffc77617b9e80d1ec989a3442/PyTorch/computer_vision/classification/torchvision/utils.py#L274 + """ + import habana_frameworks.torch.utils.experimental as htexp + + device_type = htexp._get_device_type() + + if device_type == htexp.synDeviceType.synDeviceGaudi: + return "gaudi" + elif device_type == htexp.synDeviceType.synDeviceGaudi2: + return "gaudi2" + else: + raise ValueError(f"Unsupported device: the device type is {device_type}.") + +# TODO: remove this workaround when FusedRoPE properly works on Gaudi +if get_device_name() == "gaudi2": + try: + from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV1 as FusedRoPE + except ImportError: + print("Not using HPU fused kernel for apply_rotary_pos_emb") + FusedRoPE = None +else: + FusedRoPE = None + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos[position_ids]#.unsqueeze(unsqueeze_dim) + sin = sin[position_ids]#.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class HpuRotaryEmbedding(nn.Module): + def __init__(self, head_size, rotary_dim, max_position_embeddings=2048, base=10000, is_neox_style=None, device='cuda'): + super().__init__() + + self.head_size = head_size + self.dim = rotary_dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. 
+ self._set_cos_sin_cache( + seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, positions: torch.Tensor, query: torch.Tensor, key: torch.Tensor): + seq_len = key.shape[-2] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=query.device, dtype=query.dtype) + + cos, sin = self.cos_cached[:seq_len].to(dtype=query.dtype), self.sin_cached[:seq_len].to(dtype=query.dtype) + query = query.reshape((query.shape[0], query.shape[1], query.shape[2] // self.head_size, self.head_size)) + key = key.reshape((key.shape[0], key.shape[1], key.shape[2] // self.head_size, self.head_size)) + if query.device.type == "hpu" and FusedRoPE: + if len(positions[0]) == 1: + cos = self.cos_cached[positions].unsqueeze(2).to(dtype=query.dtype) + sin = self.sin_cached[positions].unsqueeze(2).to(dtype=query.dtype) + else: + cos = cos[positions].unsqueeze(2) + sin = sin[positions].unsqueeze(2) + query, key = FusedRoPE.apply(query, cos, sin, 0), FusedRoPE.apply(key, cos, sin, 0) + else: + query, key = apply_rotary_pos_emb(query, key, cos, sin, positions) + return query.reshape((query.shape[0], query.shape[1], query.shape[2] * query.shape[3])), key.reshape((key.shape[0], key.shape[1], key.shape[2] * key.shape[3])) diff --git a/vllm/hpu/xops.py b/vllm/hpu/xops.py new file mode 100644 index 0000000000000..6460cb6ac4f33 --- /dev/null +++ b/vllm/hpu/xops.py @@ -0,0 +1,67 @@ +############################################################################### +# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+############################################################################### + +import habana_frameworks.torch as htorch +import torch +import torch.nn.functional as F +from typing import List, Optional, Tuple, Union +from .attn_bias import AttentionBias + + +def block_masked_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + scale: float, + attn_mask: Optional[torch.Tensor] = None, +) -> torch.Tensor: + query = query * scale + attn = query.transpose(0,1) @ key.transpose(0, 1).transpose(1, 2) + if attn_mask is not None: + attn = attn + attn_mask.to(attn.dtype) + attn = attn.softmax(-1) + out = attn @ value.transpose(0, 1) + out = out.transpose(0, 1) + return out + + +def memory_efficient_attention_forward( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + cu_seq_lens: List[int], + attn_bias: Optional[torch.Tensor] = None, + p: float = 0.0, + scale: Optional[float] = None, +) -> torch.Tensor: + dim = query.dim() + if dim == 4: + query, key, value = query.squeeze(0), key.squeeze(0), value.squeeze(0) + num_seqs = len(cu_seq_lens) - 1 + outputs = [] + for i in range(num_seqs): + start_idx = cu_seq_lens[i] + end_idx = cu_seq_lens[i + 1] + seq_len = end_idx - start_idx + mask_start_idx = i * seq_len + mask_end_idx = (i + 1) * seq_len + + # Create attention mask. + attn_mask = attn_bias.materialize(device=query.device) + output = block_masked_attention( + query[start_idx:end_idx], + key[start_idx:end_idx], + value[start_idx:end_idx], + scale, + attn_mask=attn_mask[mask_start_idx:mask_end_idx, + mask_start_idx:mask_end_idx], + ) + outputs.append(output) + out = torch.cat(outputs, dim=0) + if dim == 4: + out = out.unsqueeze(0) + return out diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index 1af120d13cd4b..2bdd3a62b8b39 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -6,7 +6,11 @@ import torch.nn as nn import torch.nn.functional as F -from vllm._C import ops +from vllm.utils import is_hpu +if is_hpu(): + from vllm.hpu import ops +else: + from vllm._C import ops from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index 6482875d1c55b..a6f6cca70480d 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -3,14 +3,22 @@ import torch import torch.nn as nn -from xformers import ops as xops -from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask, - LowerTriangularMaskWithTensorBias) -from vllm._C import ops -from vllm._C import cache_ops from vllm.model_executor.input_metadata import InputMetadata -from vllm.utils import is_hip +from vllm.utils import is_hip, is_hpu + +if is_hpu(): + from vllm.hpu import ops + from vllm.hpu import cache_ops + from vllm.hpu import xops + from vllm.hpu.attn_bias import (BlockDiagonalCausalMask, + LowerTriangularMaskWithTensorBias) +else: + from vllm._C import ops + from vllm._C import cache_ops + from xformers import ops as xops + from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask, + LowerTriangularMaskWithTensorBias) _SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256] # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`. 
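# Illustrative sketch (not part of the patch): the forward() hunk below packs all prompt tokens
# into one (total_tokens, num_heads, head_dim) tensor, derives cu_seq_lens from
# input_metadata.prompt_lens, and passes both to vllm.hpu.xops.memory_efficient_attention_forward,
# which runs plain masked attention once per sequence. The standalone PyTorch sketch mirrors that
# per-sequence slicing, using an explicit causal mask in place of attn_bias.materialize(); the
# helper names packed_causal_attention and causal_mask are assumptions made for this example,
# not identifiers introduced by the patch.
import torch
from typing import List


def causal_mask(seq_len: int, dtype: torch.dtype) -> torch.Tensor:
    # 0 on and below the diagonal (attention allowed), -inf above it (future tokens masked).
    mask = torch.full((seq_len, seq_len), torch.finfo(dtype).min, dtype=dtype)
    return torch.triu(mask, diagonal=1)


def packed_causal_attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                            cu_seq_lens: List[int], scale: float) -> torch.Tensor:
    # query/key/value: (total_tokens, num_heads, head_dim), all prompts concatenated.
    # cu_seq_lens: cumulative prompt lengths, e.g. prompts of 3 and 5 tokens -> [0, 3, 8].
    outputs = []
    for i in range(len(cu_seq_lens) - 1):
        start, end = cu_seq_lens[i], cu_seq_lens[i + 1]
        q = (query[start:end] * scale).transpose(0, 1)   # (heads, seq, head_dim)
        k = key[start:end].transpose(0, 1)               # (heads, seq, head_dim)
        v = value[start:end].transpose(0, 1)             # (heads, seq, head_dim)
        attn = q @ k.transpose(1, 2)                     # (heads, seq, seq)
        attn = attn + causal_mask(end - start, attn.dtype)
        out = attn.softmax(dim=-1) @ v                   # (heads, seq, head_dim)
        outputs.append(out.transpose(0, 1))              # back to (seq, heads, head_dim)
    return torch.cat(outputs, dim=0)


# Usage: two prompts of lengths 3 and 5, 4 heads, head_dim 8 -- the same bookkeeping the
# forward() hunk performs with input_metadata.prompt_lens before calling the HPU xops path.
prompt_lens = [3, 5]
cu_seq_lens = [0]
for n in prompt_lens:
    cu_seq_lens.append(cu_seq_lens[-1] + n)              # -> [0, 3, 8]
q = torch.randn(sum(prompt_lens), 4, 8)
k = torch.randn(sum(prompt_lens), 4, 8)
v = torch.randn(sum(prompt_lens), 4, 8)
out = packed_causal_attention(q, k, v, cu_seq_lens, scale=8 ** -0.5)
assert out.shape == (sum(prompt_lens), 4, 8)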
@@ -143,17 +151,35 @@ def forward( key = key.unflatten(0, (batch_size, seq_len)) value = value.unflatten(0, (batch_size, seq_len)) - out = xops.memory_efficient_attention_forward( - query, - key, - value, - attn_bias=input_metadata.attn_bias, - p=0.0, - scale=self.scale, - op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if - (is_hip()) else None, - ) - output = out.view_as(query) + if is_hpu(): + cu_seq_lens = [0] + for i in range(len(input_metadata.prompt_lens)): + cu_seq_lens.append(cu_seq_lens[-1] + input_metadata.prompt_lens[i]) + input_metadata.cu_seq_lens = cu_seq_lens + out = xops.memory_efficient_attention_forward( + query, + key, + value, + cu_seq_lens, + attn_bias=input_metadata.attn_bias, + p=0.0, + scale=self.scale, + ) + output = torch.zeros_like(query) + output[:, :out.shape[1], :, :] = out + output = output.view_as(query) + else: + out = xops.memory_efficient_attention_forward( + query, + key, + value, + attn_bias=input_metadata.attn_bias, + p=0.0, + scale=self.scale, + op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if + (is_hip()) else None, + ) + output = out.view_as(query) else: # Decoding run. if key_cache is not None and value_cache is not None: @@ -234,10 +260,9 @@ def _paged_attention( # For context len > 8192, use V2 kernel to avoid shared memory shortage. use_v1 = input_metadata.max_context_len <= 8192 and ( max_num_partitions == 1 or num_seqs * num_heads > 512) - if use_v1: - # Run PagedAttention V1. - ops.paged_attention_v1( - output, + + if is_hpu(): + output = ops.paged_attention_v1( query, key_cache, value_cache, @@ -250,33 +275,49 @@ def _paged_attention( alibi_slopes, ) else: - # Run PagedAttention V2. - assert _PARTITION_SIZE % block_size == 0 - tmp_output = torch.empty( - size=(num_seqs, num_heads, max_num_partitions, head_size), - dtype=output.dtype, - device=output.device, - ) - exp_sums = torch.empty( - size=(num_seqs, num_heads, max_num_partitions), - dtype=torch.float32, - device=output.device, - ) - max_logits = torch.empty_like(exp_sums) - ops.paged_attention_v2( - output, - exp_sums, - max_logits, - tmp_output, - query, - key_cache, - value_cache, - num_kv_heads, - scale, - input_metadata.block_tables, - input_metadata.context_lens, - block_size, - input_metadata.max_context_len, - alibi_slopes, - ) + if use_v1: + # Run PagedAttention V1. + ops.paged_attention_v1( + output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + input_metadata.block_tables, + input_metadata.context_lens, + block_size, + input_metadata.max_context_len, + alibi_slopes, + ) + else: + # Run PagedAttention V2. 
+ assert _PARTITION_SIZE % block_size == 0 + tmp_output = torch.empty( + size=(num_seqs, num_heads, max_num_partitions, head_size), + dtype=output.dtype, + device=output.device, + ) + exp_sums = torch.empty( + size=(num_seqs, num_heads, max_num_partitions), + dtype=torch.float32, + device=output.device, + ) + max_logits = torch.empty_like(exp_sums) + ops.paged_attention_v2( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + input_metadata.block_tables, + input_metadata.context_lens, + block_size, + input_metadata.max_context_len, + alibi_slopes, + ) return output diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index cb3cee2bad5ad..57e65c04e4019 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -4,8 +4,17 @@ import torch import torch.nn as nn -from vllm._C import ops +from vllm.utils import is_hpu +if is_hpu(): + from vllm.hpu import ops +else: + from vllm._C import ops +try: + from habana_frameworks.torch.hpex.normalization import FusedRMSNorm as FusedRMSNorm +except ImportError: + print("Not using HPU fused kernel for RMSNorm") + FusedRMSNorm = None class RMSNorm(nn.Module): """Root mean square normalization. @@ -49,6 +58,12 @@ def forward( residual: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: if residual is not None: + if x.device.type == "hpu" and FusedRMSNorm: + orig_dtype = x.dtype + residual += x + x = FusedRMSNorm.apply(residual.float(), self.weight.float(), self.variance_epsilon) + return x.to(orig_dtype), residual + ops.fused_add_rms_norm( x, residual, @@ -56,6 +71,12 @@ def forward( self.variance_epsilon, ) return x, residual + + if x.device.type == "hpu" and FusedRMSNorm: + orig_dtype = x.dtype + x = FusedRMSNorm.apply(x.float(), self.weight.float(), self.variance_epsilon) + return x.to(orig_dtype) + out = torch.empty_like(x) ops.rms_norm( out, diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index 831576b1d7cd7..4e0a0ec51beb6 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -3,7 +3,11 @@ import torch from torch.nn.parameter import Parameter -from vllm._C import ops +from vllm.utils import is_hpu +if is_hpu(): + from vllm.hpu import ops +else: + from vllm._C import ops from vllm.model_executor.layers.linear import (LinearMethodBase, set_weight_attrs) from vllm.model_executor.layers.quantization.base_config import QuantizationConfig diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py index 8fe96e7ddb98d..716c6f88d9e62 100644 --- a/vllm/model_executor/layers/quantization/gptq.py +++ b/vllm/model_executor/layers/quantization/gptq.py @@ -5,7 +5,11 @@ import torch from torch.nn.parameter import Parameter -from vllm._C import ops +from vllm.utils import is_hpu +if is_hpu(): + from vllm.hpu import ops +else: + from vllm._C import ops from vllm.model_executor.layers.linear import (LinearMethodBase, set_weight_attrs) from vllm.model_executor.layers.quantization.base_config import ( diff --git a/vllm/model_executor/layers/quantization/squeezellm.py b/vllm/model_executor/layers/quantization/squeezellm.py index 1932bd145076b..c3d71e4309dbb 100644 --- a/vllm/model_executor/layers/quantization/squeezellm.py +++ b/vllm/model_executor/layers/quantization/squeezellm.py @@ -3,11 +3,14 @@ import torch from 
torch.nn.parameter import Parameter -from vllm._C import ops from vllm.model_executor.layers.linear import (LinearMethodBase, set_weight_attrs) from vllm.model_executor.layers.quantization.base_config import QuantizationConfig -from vllm.utils import is_hip +from vllm.utils import is_hip, is_hpu +if is_hpu(): + from vllm.hpu import ops +else: + from vllm._C import ops class SqueezeLLMConfig(QuantizationConfig): diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index 91c093e33e3c9..201a5142e6466 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -27,7 +27,12 @@ import torch import torch.nn as nn -from vllm._C import ops +from vllm.utils import is_hpu +if is_hpu(): + from vllm.hpu import ops + from vllm.hpu.rotary_embed import HpuRotaryEmbedding +else: + from vllm._C import ops def _rotate_neox(x: torch.Tensor) -> torch.Tensor: @@ -343,8 +348,12 @@ def get_rope( return _ROPE_DICT[key] if rope_scaling is None: - rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base, - is_neox_style) + if is_hpu(): + rotary_emb = HpuRotaryEmbedding(head_size, rotary_dim, max_position, base, + is_neox_style) + else: + rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base, + is_neox_style) else: scaling_type = rope_scaling["type"] scaling_factor = rope_scaling["factor"] diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index 49013ec273787..388e55ba92e67 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -146,7 +146,7 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float], dtype: torch.dtype) -> "SamplingTensors": # Note that the performance will be very bad without # pinned memory. - pin_memory = not in_wsl() + pin_memory = not in_wsl() and not device.type == "hpu" prompt_max_len = max(len(tokens) for tokens in prompt_tokens) prompt_padded_tokens = [ tokens + [vocab_size] * (prompt_max_len - len(tokens)) diff --git a/vllm/utils.py b/vllm/utils.py index eff5d10fd4ee0..29bb24a5f8b56 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1,13 +1,12 @@ import enum import socket import uuid +import importlib from platform import uname import psutil import torch -from vllm._C import cuda_utils - class Device(enum.Enum): GPU = enum.auto() @@ -32,6 +31,16 @@ def is_hip() -> bool: return torch.version.hip is not None +def is_hpu() -> bool: + return importlib.util.find_spec('habana_frameworks') is not None + + +if is_hpu(): + from vllm.hpu import cuda_utils +else: + from vllm._C import cuda_utils + + def get_max_shared_memory_bytes(gpu: int = 0) -> int: """Returns the maximum shared memory per thread block in bytes.""" # https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index 1dd0243f8f3a3..64680ac5f8fc5 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -3,10 +3,14 @@ import torch -from vllm._C import cache_ops from vllm.config import CacheConfig, ModelConfig, ParallelConfig from vllm.logger import init_logger -from vllm.utils import in_wsl +from vllm.utils import in_wsl, is_hpu + +if is_hpu(): + from vllm.hpu import cache_ops +else: + from vllm._C import cache_ops logger = init_logger(__name__) @@ -41,7 +45,10 @@ def __init__( self.num_cpu_blocks = cache_config.num_cpu_blocks # Initialize the cache. 
- self.gpu_cache = self.allocate_gpu_cache() + if is_hpu(): + self.gpu_cache = self.allocate_hpu_cache() + else: + self.gpu_cache = self.allocate_gpu_cache() self.cpu_cache = self.allocate_cpu_cache() # Initialize the stream for caching operations. @@ -67,6 +74,28 @@ def get_value_block_shape(self) -> Tuple[int, int, int]: self.block_size, ) + def allocate_hpu_cache(self) -> List[KVCache]: + hpu_cache: List[KVCache] = [] + kv_block_shape = ( + self.num_heads, + self.head_size, + self.block_size) + for _ in range(self.num_layers): + key_blocks = [] + value_blocks = [] + key_blocks = torch.empty( + size=(self.num_gpu_blocks, *kv_block_shape), + dtype=self.dtype, + device="hpu", + ) + value_blocks = torch.empty( + size=(self.num_gpu_blocks, *kv_block_shape), + dtype=self.dtype, + device="hpu", + ) + hpu_cache.append((key_blocks, value_blocks)) + return hpu_cache + def allocate_gpu_cache(self) -> List[KVCache]: gpu_cache: List[KVCache] = [] key_block_shape = self.get_key_block_shape() @@ -89,7 +118,8 @@ def allocate_cpu_cache(self) -> List[KVCache]: cpu_cache: List[KVCache] = [] key_block_shape = self.get_key_block_shape() value_block_shape = self.get_value_block_shape() - pin_memory = not in_wsl() + # pin_memory = not in_wsl() + pin_memory = not in_wsl() and not is_hpu() if not pin_memory: # Pinning memory in WSL is not supported. # https://docs.nvidia.com/cuda/wsl-user-guide/index.html#known-limitations-for-linux-cuda-applications diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index fb7a0c17d6f9f..89dd1f054a258 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -10,7 +10,7 @@ from vllm.model_executor import get_model, InputMetadata, SamplingMetadata from vllm.sampling_params import SamplingParams, SamplingType from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata -from vllm.utils import in_wsl +from vllm.utils import in_wsl, is_hpu logger = init_logger(__name__) @@ -307,9 +307,9 @@ def _prepare_sample( selected_token_indices = _async_h2d(selected_token_indices, dtype=torch.long, - pin_memory=not self.in_wsl) + pin_memory=not is_hpu() and not self.in_wsl) categorized_sample_indices = { - t: _async_h2d(seq_ids, dtype=torch.int, pin_memory=not self.in_wsl) + t: _async_h2d(seq_ids, dtype=torch.int, pin_memory=not is_hpu() and not self.in_wsl) for t, seq_ids in categorized_sample_indices.items() } @@ -414,7 +414,7 @@ def capture_model(self, kv_caches: List[KVCache]) -> None: input_tokens = torch.zeros(max_batch_size, 1, dtype=torch.long).cuda() input_positions = torch.zeros(max_batch_size, 1, dtype=torch.long).cuda() - slot_mapping = torch.empty(max_batch_size, 1, dtype=torch.long).cuda() + slot_mapping = torch.zeros(max_batch_size, 1, dtype=torch.long).cuda() # FIXME (kzawora): revert this to torch.empty after bridge bug is fixed slot_mapping.fill_(_PAD_SLOT_ID) context_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda() block_tables = torch.from_numpy(self.graph_block_tables).cuda() diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 8698b15721507..94cf44e5f6d6b 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -8,6 +8,7 @@ from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, SchedulerConfig) from vllm.model_executor import set_random_seed +from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.parallel_utils.parallel_state import ( initialize_model_parallel) from vllm.sequence import SamplerOutput, SequenceGroupMetadata @@ -125,6 +126,206 
@@ def warm_up_model(self) -> None: # the model initialization and profiling. set_random_seed(self.model_config.seed) + def _prepare_inputs( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata]: + seq_groups: List[Tuple[List[int], SamplingParams]] = [] + input_tokens: List[List[int]] = [] + input_positions: List[List[int]] = [] + slot_mapping: List[List[int]] = [] + selected_token_indices: List[int] = [] + selected_token_start_idx = 0 + categorized_sample_indices = {t: [] for t in SamplingType} + categorized_sample_indices_start_idx = 0 + + # Add prompt tokens. + prompt_lens: List[int] = [] + for seq_group_metadata in seq_group_metadata_list: + if not seq_group_metadata.is_prompt: + continue + + seq_ids = list(seq_group_metadata.seq_data.keys()) + sampling_params = seq_group_metadata.sampling_params + seq_groups.append((seq_ids, sampling_params)) + + # Use any sequence in the group. + seq_id = seq_ids[0] + + seq_data = seq_group_metadata.seq_data[seq_id] + prompt_tokens = seq_data.get_token_ids() + prompt_len = len(prompt_tokens) + prompt_lens.append(prompt_len) + + if sampling_params.prompt_logprobs is not None: + # NOTE: prompt token positions do not need sample, skip + categorized_sample_indices_start_idx += prompt_len - 1 + + categorized_sample_indices[sampling_params.sampling_type].append( + categorized_sample_indices_start_idx) + categorized_sample_indices_start_idx += 1 + + input_tokens.append(prompt_tokens) + # NOTE(woosuk): Here we assume that the first token in the prompt + # is always the first token in the sequence. + input_positions.append(list(range(prompt_len))) + + if seq_group_metadata.block_tables is None: + # During memory profiling, the block tables are not initialized + # yet. In this case, we just use a dummy slot mapping. + slot_mapping.append([0] * prompt_len) + continue + + # Compute the slot mapping. + slot_mapping.append([]) + block_table = seq_group_metadata.block_tables[seq_id] + for i in range(prompt_len): + block_number = block_table[i // self.block_size] + block_offset = i % self.block_size + slot = block_number * self.block_size + block_offset + slot_mapping[-1].append(slot) + + # Add generation tokens. + max_context_len = 0 + max_num_blocks_per_seq = 0 + context_lens: List[int] = [] + generation_block_tables: List[List[int]] = [] + max_seq_len = max(prompt_lens) if prompt_lens else 1 + for seq_group_metadata in seq_group_metadata_list: + if seq_group_metadata.is_prompt: + # We need to do this in this loop as we need to know max_seq_len + assert len( + seq_ids) == 1, "Prompt input should have only one seq." 
+ sampling_params = seq_group_metadata.sampling_params + if sampling_params.prompt_logprobs is not None: + selected_token_indices.extend( + range(selected_token_start_idx, + selected_token_start_idx + prompt_len - 1)) + selected_token_indices.append(selected_token_start_idx + + prompt_len - 1) + selected_token_start_idx += max_seq_len + continue + + seq_ids = list(seq_group_metadata.seq_data.keys()) + sampling_params = seq_group_metadata.sampling_params + seq_groups.append((seq_ids, sampling_params)) + + num_seqs = len(seq_ids) + selected_token_indices.extend( + range(selected_token_start_idx, + selected_token_start_idx + num_seqs)) + selected_token_start_idx += num_seqs + + categorized_sample_indices[sampling_params.sampling_type].extend( + range(categorized_sample_indices_start_idx, + categorized_sample_indices_start_idx + num_seqs)) + categorized_sample_indices_start_idx += num_seqs + + for seq_id in seq_ids: + seq_data = seq_group_metadata.seq_data[seq_id] + generation_token = seq_data.get_last_token_id() + input_tokens.append([generation_token]) + + context_len = seq_data.get_len() + position = context_len - 1 + if self.sliding_window is not None: + context_len = min(context_len, self.sliding_window) + input_positions.append([position]) + + block_table = seq_group_metadata.block_tables[seq_id] + + max_context_len = max(max_context_len, context_len) + max_num_blocks_per_seq = max(max_num_blocks_per_seq, + len(block_table)) + context_lens.append(context_len) + + block_number = block_table[position // self.block_size] + block_offset = position % self.block_size + slot = block_number * self.block_size + block_offset + slot_mapping.append([slot]) + + if self.sliding_window is not None: + sliding_window_blocks = (self.sliding_window // + self.block_size) + block_table = block_table[-sliding_window_blocks:] + generation_block_tables.append(block_table) + + def round_up(n, multiple): + return (n + multiple - 1) // multiple * multiple + + if self.block_size is not None: + if max_seq_len != 1: + max_seq_len = round_up(max_seq_len, self.block_size) + + padded_input_tokens = [ + _pad_to_max(tokens, max_seq_len, pad=0) for tokens in input_tokens + ] + padded_input_positions = [ + _pad_to_max(positions, max_seq_len, pad=0) + for positions in input_positions + ] + padded_slot_mapping = [ + _pad_to_max(mapping, max_seq_len, pad=-1) + for mapping in slot_mapping + ] + padded_block_tables = [ + _pad_to_max(block_table, max_num_blocks_per_seq, pad=0) + for block_table in generation_block_tables + ] + + # Convert to tensors. 
+ tokens_tensor = torch.tensor(padded_input_tokens, + dtype=torch.long, + device="cuda") + positions_tensor = torch.tensor(padded_input_positions, + dtype=torch.long, + device="cuda") + slot_mapping_tensor = torch.tensor(padded_slot_mapping, + dtype=torch.long, + device="cpu") + context_lens_tensor = torch.tensor(context_lens, + dtype=torch.int, + device="cpu") + selected_token_indices = torch.tensor(selected_token_indices, + dtype=torch.long, + device="cuda") + categorized_sample_indices = { + t: torch.tensor(seq_ids, dtype=torch.int, device="cuda") + for t, seq_ids in categorized_sample_indices.items() + } + block_tables_tensor = torch.tensor(padded_block_tables, + dtype=torch.int, + device="cpu") + + seq_data: Dict[int, SequenceData] = {} + for seq_group_metadata in seq_group_metadata_list: + seq_data.update(seq_group_metadata.seq_data) + + input_metadata = InputMetadata( + seq_groups=seq_groups, + seq_data=seq_data, + prompt_lens=prompt_lens, + slot_mapping=slot_mapping_tensor, + context_lens=context_lens_tensor, + max_context_len=max_context_len, + block_tables=block_tables_tensor, + selected_token_indices=selected_token_indices, + categorized_sample_indices=categorized_sample_indices, + sliding_window=self.sliding_window, + ) + + # Create attention mask + if max_num_blocks_per_seq != 0: + attn_masks = torch.zeros((max_num_blocks_per_seq, len(input_tokens), self.block_size), dtype=torch.int64) + for i in range(0, max_num_blocks_per_seq): + for seq_id in range(len(input_tokens)): + if (i * self.block_size) < context_lens[seq_id] and (i + 1) * self.block_size > context_lens[seq_id]: + attn_masks[i][seq_id, :context_lens[seq_id] % self.block_size] = 1 + elif (i + 1) * self.block_size <= context_lens[seq_id]: + attn_masks[i][seq_id, :] = 1 + input_metadata.attention_masks = attn_masks.to(device="cuda") + return tokens_tensor, positions_tensor, input_metadata + @torch.inference_mode() def execute_model( self,