Draft: Another attempt at v1 HPU integration #831

Draft · wants to merge 17 commits into base: habana_main
1 change: 1 addition & 0 deletions .jenkins/lm-eval-harness/test_lm_eval_correctness.py
@@ -51,6 +51,7 @@ def launch_lm_eval(eval_config):
f"dtype={dtype}," \
f"max_model_len=4096," \
f"max_num_seqs={max_num_seqs}," \
f"enable_prefix_caching=False," \
f"trust_remote_code={trust_remote_code}"
if eval_config.get("fp8"):
model_args += ",quantization=inc," \
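For context, the only change to launch_lm_eval() is the extra enable_prefix_caching=False entry. A minimal sketch of the resulting model_args string, with placeholder values standing in for whatever the CI config actually supplies:

# Illustrative reconstruction only; dtype, max_num_seqs and trust_remote_code
# are assumed example values, not taken from the Jenkins configs.
dtype = "bfloat16"
max_num_seqs = 128
trust_remote_code = False

model_args = f"dtype={dtype}," \
             f"max_model_len=4096," \
             f"max_num_seqs={max_num_seqs}," \
             f"enable_prefix_caching=False," \
             f"trust_remote_code={trust_remote_code}"
print(model_args)
# dtype=bfloat16,max_model_len=4096,max_num_seqs=128,enable_prefix_caching=False,trust_remote_code=False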
30 changes: 24 additions & 6 deletions .jenkins/test_config.yaml
@@ -2,26 +2,44 @@
stages:
- name: test_gsm8k_small_models
steps:
- name: gsm8k_small_g3_tp1
- name: v0_gsm8k_small_g3_tp1
flavor: g3
command: cd .jenkins/lm-eval-harness && bash run-tests.sh -c configs/models-small.txt -t 1
- name: gsm8k_small_g3_tp2
- name: v0_gsm8k_small_g3_tp2
flavor: g3.s
command: cd .jenkins/lm-eval-harness && bash run-tests.sh -c configs/models-small.txt -t 2
- name: gsm8k_small_g2_tp1
- name: v0_gsm8k_small_g2_tp1
flavor: g2
command: cd .jenkins/lm-eval-harness && bash run-tests.sh -c configs/models-small.txt -t 1
- name: gsm8k_small_g2_tp2
- name: v0_gsm8k_small_g2_tp2
flavor: g2.s
command: cd .jenkins/lm-eval-harness && bash run-tests.sh -c configs/models-small.txt -t 2
- name: v1_gsm8k_small_g3_tp1
flavor: g3
command: export VLLM_USE_V1=1 && export VLLM_CONTIGUOUS_PA=false && cd .jenkins/lm-eval-harness && bash run-tests.sh -c configs/models-small.txt -t 1
- name: v1_gsm8k_small_g3_tp2
flavor: g3.s
command: export VLLM_USE_V1=1 && export VLLM_CONTIGUOUS_PA=false && cd .jenkins/lm-eval-harness && bash run-tests.sh -c configs/models-small.txt -t 2
- name: v1_gsm8k_small_g2_tp1
flavor: g2
command: export VLLM_USE_V1=1 && export VLLM_CONTIGUOUS_PA=false && cd .jenkins/lm-eval-harness && bash run-tests.sh -c configs/models-small.txt -t 1
- name: v1_gsm8k_small_g2_tp2
flavor: g2.s
command: export VLLM_USE_V1=1 && export VLLM_CONTIGUOUS_PA=false && cd .jenkins/lm-eval-harness && bash run-tests.sh -c configs/models-small.txt -t 2
- name: test_gsm8k_large_models
steps:
- name: gsm8k_large_g3_tp2
- name: v0_gsm8k_large_g3_tp2
flavor: g3.s
command: cd .jenkins/lm-eval-harness && bash run-tests.sh -c configs/models-large.txt -t 2
- name: gsm8k_large_g2_tp4
- name: v0_gsm8k_large_g2_tp4
flavor: g2.m
command: cd .jenkins/lm-eval-harness && bash run-tests.sh -c configs/models-large.txt -t 4
- name: v1_gsm8k_large_g3_tp2
flavor: g3.s
command: export VLLM_USE_V1=1 && export VLLM_CONTIGUOUS_PA=false && cd .jenkins/lm-eval-harness && bash run-tests.sh -c configs/models-large.txt -t 2
- name: v1_gsm8k_large_g2_tp4
flavor: g2.m
command: export VLLM_USE_V1=1 && export VLLM_CONTIGUOUS_PA=false && cd .jenkins/lm-eval-harness && bash run-tests.sh -c configs/models-large.txt -t 4
- name: test_gsm8k_fp8
steps:
- name: gsm8k_small_g3_tp1_fp8
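The new v1_* jobs differ from their v0_* counterparts only in the two exported environment variables. A sketch of exercising the same configuration from Python instead of the shell command (the model name is an assumed example; the CI drives this through run-tests.sh):

import os

# Mirror the `export` calls in the v1 CI commands; these must be set before
# vllm is imported.
os.environ["VLLM_USE_V1"] = "1"
os.environ["VLLM_CONTIGUOUS_PA"] = "false"

from vllm import LLM, SamplingParams  # noqa: E402

llm = LLM(model="facebook/opt-125m",   # assumed example model
          enable_prefix_caching=False,  # matches the lm-eval change above
          max_model_len=4096)
out = llm.generate(["The capital of France is"],
                   SamplingParams(max_tokens=16))
print(out[0].outputs[0].text)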
2 changes: 1 addition & 1 deletion vllm/attention/backends/hpu_attn.py
@@ -227,7 +227,7 @@ def forward(
block_offsets)
value_cache = self.v_cache(value, value_cache, block_indices,
block_offsets)

if attn_metadata.is_prompt:
# Prompt run.
query_shape = (batch_size, seq_len, self.num_heads, self.head_size)
2 changes: 1 addition & 1 deletion vllm/config.py
@@ -3231,7 +3231,7 @@
if self.compilation_config is None:
self.compilation_config = CompilationConfig()
if envs.VLLM_USE_V1 and self.model_config is not None and \
not self.model_config.enforce_eager:
not self.model_config.enforce_eager and not current_platform.is_hpu():

Check failure (GitHub Actions / pre-commit, Ruff E501) on line 3234: vllm/config.py:3234:81: Line too long (82 > 80)
# NOTE(woosuk): Currently, we use inductor because the piecewise
# CUDA graphs do not work properly with the custom CUDA kernels.
# FIXME(woosuk): Disable inductor to reduce the compilation time
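The pre-commit failure above is only about line length. One way to keep the new is_hpu() guard while satisfying Ruff's 80-column limit (a suggestion, not part of this diff):

if envs.VLLM_USE_V1 and self.model_config is not None and \
        not self.model_config.enforce_eager and \
        not current_platform.is_hpu():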
3 changes: 2 additions & 1 deletion vllm/engine/arg_utils.py
@@ -118,7 +118,8 @@ class EngineArgs:
enable_prefix_caching: Optional[bool] = None
disable_sliding_window: bool = False
use_v2_block_manager: bool = True
use_padding_aware_scheduling: bool = current_platform.is_hpu()
use_padding_aware_scheduling: bool = current_platform.is_hpu(
) and not bool(envs.VLLM_USE_V1)
swap_space: float = 4 # GiB
cpu_offload_gb: float = 0 # GiB
gpu_memory_utilization: float = 0.90
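The new default simply turns padding-aware scheduling off whenever the v1 engine is selected; a small sketch of the equivalent expression and its outcomes (assuming envs.VLLM_USE_V1 reflects the exported environment variable):

import vllm.envs as envs
from vllm.platforms import current_platform

use_padding_aware_scheduling = (current_platform.is_hpu()
                                and not bool(envs.VLLM_USE_V1))
# HPU, v0 engine     -> True
# HPU, v1 engine     -> False
# any other platform -> False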
4 changes: 2 additions & 2 deletions vllm/entrypoints/llm.py
@@ -1416,8 +1416,8 @@
if use_tqdm:
pbar.close()

# Make sure that all workers are finished.
self.llm_engine.stop_remote_worker_execution_loop()
# Make sure that all workers are finished - NOTE(kzawora): this crashes on v1, why?.

Check failure (GitHub Actions / pre-commit, Ruff E501) on line 1419: vllm/entrypoints/llm.py:1419:81: Line too long (92 > 80)
#self.llm_engine.stop_remote_worker_execution_loop()

# Sort the outputs by request ID.
# This is necessary because some requests may be finished earlier than
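Commenting the call out drops the v0 worker-drain step as well. A guarded variant is one possible alternative; a sketch only, assuming vllm.envs is importable from this module and that the crash is specific to the v1 engine:

import vllm.envs as envs  # assumption: may not be imported in llm.py today

# Drain remote workers only on the v0 engine; on v1 this call reportedly crashes.
if not envs.VLLM_USE_V1:
    self.llm_engine.stop_remote_worker_execution_loop()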
28 changes: 20 additions & 8 deletions vllm/platforms/hpu.py
@@ -32,30 +32,42 @@ def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
dtype: torch.dtype, kv_cache_dtype: Optional[str],
block_size: int, use_v1: bool,
use_mla: bool) -> str:
if use_v1:
logger.info("Using HPUAttentionV1 backend.")
return "vllm.v1.attention.backends.hpu_attn.HPUAttentionBackendV1"
logger.info("Using HPUAttention backend.")
return "vllm.attention.backends.hpu_attn.HPUAttentionBackend"

@classmethod
def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
return True

@classmethod
def get_device_name(cls, device_id: int = 0) -> str:
return cls.device_name

@classmethod
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:

scheduler_config = vllm_config.scheduler_config

parallel_config = vllm_config.parallel_config
if parallel_config.worker_cls == "auto":
if scheduler_config.is_multi_step:
parallel_config.worker_cls = \
"vllm.worker.multi_step_hpu_worker.MultiStepHPUWorker"
elif vllm_config.speculative_config:
if envs.VLLM_USE_V1:
parallel_config.worker_cls = \
"vllm.spec_decode.spec_decode_worker.create_spec_worker"
parallel_config.sd_worker_cls = \
"vllm.worker.hpu_worker.HPUWorker"
"vllm.v1.worker.hpu_worker.HPUWorker"
else:
parallel_config.worker_cls = "vllm.worker.hpu_worker.HPUWorker"
if scheduler_config.is_multi_step:
parallel_config.worker_cls = \
"vllm.worker.multi_step_hpu_worker.MultiStepHPUWorker"
elif vllm_config.speculative_config:
parallel_config.worker_cls = \
"vllm.spec_decode.spec_decode_worker.create_spec_worker"
parallel_config.sd_worker_cls = \
"vllm.worker.hpu_worker.HPUWorker"
else:
parallel_config.worker_cls = \
"vllm.worker.hpu_worker.HPUWorker"

# NOTE(kzawora): default block size for Gaudi should be 128
# smaller sizes still work, but very inefficiently
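Condensed, the worker selection on HPU after this change looks roughly as follows; this is a readability paraphrase under the assumption that the v1 branch always resolves to the new v1 worker, not the literal diff:

def _hpu_worker_cls(use_v1: bool, is_multi_step: bool,
                    has_spec_config: bool) -> str:
    # Sketch of check_and_update_config()'s worker_cls choice on HPU.
    if use_v1:
        return "vllm.v1.worker.hpu_worker.HPUWorker"
    if is_multi_step:
        return "vllm.worker.multi_step_hpu_worker.MultiStepHPUWorker"
    if has_spec_config:
        return "vllm.spec_decode.spec_decode_worker.create_spec_worker"
    return "vllm.worker.hpu_worker.HPUWorker"

(In the speculative-decoding case the v0 path also sets sd_worker_cls to the HPU worker, as in the diff above.)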
103 changes: 103 additions & 0 deletions vllm/v1/attention/backends/hpu_attn.py
@@ -0,0 +1,103 @@
# SPDX-License-Identifier: Apache-2.0

###############################################################################
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company
###############################################################################

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type

import torch
from vllm.attention.backends.hpu_attn import HPUAttentionBackend, HPUAttentionMetadata

Check failure (GitHub Actions / pre-commit, Ruff E501) on line 11: vllm/v1/attention/backends/hpu_attn.py:11:81: Line too long (86 > 80)
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.ops.hpu_paged_attn import (HPUPagedAttention,
HPUPagedAttentionMetadata)
from vllm.logger import init_logger

logger = init_logger(__name__)


class HPUAttentionBackendV1(HPUAttentionBackend):

@staticmethod
def get_name() -> str:
return "HPU_ATTN_V1"

@staticmethod
def get_metadata_cls() -> Type["AttentionMetadata"]:
return HPUAttentionMetadataV1



@dataclass
class HPUAttentionMetadataV1(HPUAttentionMetadata):
"""Metadata for HPUAttentionbackend."""
is_prompt: bool
attn_bias: Optional[torch.Tensor]

seq_lens_tensor: Optional[torch.Tensor]
context_lens_tensor: Optional[torch.Tensor]

@classmethod
def make_prefill_metadata(cls, seq_lens_tensor, num_prefills,
num_prefill_tokens, slot_mapping):
return cls(is_prompt=True,
block_list=None,
block_mapping=None,
block_usage=None,
block_indices=None,
block_offsets=None,
block_scales=None,
block_groups=None,
attn_bias=None,
num_decode_tokens=0,
context_lens_tensor=None,
multi_modal_placeholder_index_maps=None,
seq_lens_tensor=seq_lens_tensor,
num_prefills=num_prefills,
num_prefill_tokens=num_prefill_tokens,
slot_mapping=slot_mapping,
enable_kv_scales_calculation=False)

@classmethod
def make_cached_prefill_metadata(cls, seq_lens_tensor, context_lens_tensor,
num_prefills, num_prefill_tokens,
slot_mapping, block_list):
return cls(is_prompt=True,
block_list=block_list,
block_mapping=None,
block_usage=None,
block_indices=None,
block_offsets=None,
block_scales=None,
block_groups=None,
attn_bias=None,
num_decode_tokens=0,
context_lens_tensor=context_lens_tensor,
multi_modal_placeholder_index_maps=None,
seq_lens_tensor=seq_lens_tensor,
num_prefills=num_prefills,
num_prefill_tokens=num_prefill_tokens,
slot_mapping=slot_mapping,
enable_kv_scales_calculation=False)

@classmethod
def make_decode_metadata(cls, block_list, block_usage, block_groups,
num_decode_tokens, slot_mapping):
return cls(is_prompt=False,
block_mapping=None,
block_indices=None,
block_offsets=None,
block_scales=None,
attn_bias=None,
seq_lens_tensor=None,
context_lens_tensor=None,
num_prefills=0,
num_prefill_tokens=0,
multi_modal_placeholder_index_maps=None,
block_list=block_list,
block_usage=block_usage,
block_groups=block_groups,
num_decode_tokens=num_decode_tokens,
slot_mapping=slot_mapping,
enable_kv_scales_calculation=False)
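A minimal usage sketch for the prefill factory defined above; the tensors are placeholder values for a single hypothetical 8-token prompt:

import torch
from vllm.v1.attention.backends.hpu_attn import HPUAttentionMetadataV1

meta = HPUAttentionMetadataV1.make_prefill_metadata(
    seq_lens_tensor=torch.tensor([8], dtype=torch.int32),
    num_prefills=1,
    num_prefill_tokens=8,
    slot_mapping=torch.arange(8, dtype=torch.int64),  # arbitrary cache slots
)
assert meta.is_prompt and meta.num_decode_tokens == 0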
4 changes: 3 additions & 1 deletion vllm/v1/core/kv_cache_manager.py
@@ -4,6 +4,7 @@
from typing import DefaultDict, Dict, Iterable, List, Optional, Tuple

from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import cdiv
from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue,
KVCacheBlock,
@@ -50,8 +51,9 @@ def __init__(
self.num_preallocate_blocks = cdiv(num_preallocate_tokens, block_size)

# A Block pool of all kv-cache blocks.
start_block_id = 0 if not current_platform.is_hpu() else 1
self.block_pool: List[KVCacheBlock] = [
KVCacheBlock(idx) for idx in range(num_gpu_blocks)
KVCacheBlock(idx) for idx in range(start_block_id, num_gpu_blocks)
]
# Free block queue that constructs and manipulates a doubly linked
# list of free blocks (including eviction candidates when caching is
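The net effect is that block id 0 is never handed out on HPU (it stays reserved), while other platforms keep the full range. A pure-Python illustration of the start_block_id logic, with no vllm imports needed:

def build_block_ids(num_gpu_blocks: int, is_hpu: bool) -> list:
    # Mirrors: start_block_id = 0 if not current_platform.is_hpu() else 1
    start_block_id = 0 if not is_hpu else 1
    return list(range(start_block_id, num_gpu_blocks))

assert build_block_ids(4, is_hpu=False) == [0, 1, 2, 3]
assert build_block_ids(4, is_hpu=True) == [1, 2, 3]  # block 0 held back on HPU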
2 changes: 1 addition & 1 deletion vllm/v1/sample/sampler.py
@@ -50,7 +50,7 @@ def forward(
self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=sampled)

# Use int32 to reduce the tensor size.
sampled = sampled.to(torch.int32)
sampled = sampled # .to(torch.int32)

# These are GPU tensors.
sampler_output = SamplerOutput(
64 changes: 64 additions & 0 deletions vllm/v1/worker/gpu_input_batch.py
@@ -423,6 +423,70 @@ def make_sampling_metadata(
no_penalties=self.no_penalties,
)

def make_selective_sampling_metadata(
self,
req_id_output_token_ids: Tuple[str, List[int]],
skip_copy: bool = False,
) -> SamplingMetadata:
req_indices = [self.req_id_to_index[req_id[0]] for req_id in req_id_output_token_ids]
if not skip_copy:
self.temperature[req_indices].copy_(
self.temperature_cpu_tensor[req_indices], non_blocking=True)
self.top_p[req_indices].copy_(
self.top_p_cpu_tensor[req_indices], non_blocking=True)
self.top_k[req_indices].copy_(
self.top_k_cpu_tensor[req_indices], non_blocking=True)
if not self.no_penalties:
# Since syncing these tensors is expensive only copy them
# if necessary i.e. if there are requests which require
# penalties to be applied during sampling.
self.frequency_penalties[req_indices].copy_(
self.frequency_penalties_cpu_tensor[req_indices],
non_blocking=True)
self.presence_penalties[req_indices].copy_(
self.presence_penalties_cpu_tensor[req_indices],
non_blocking=True)
self.repetition_penalties[req_indices].copy_(
self.repetition_penalties_cpu_tensor[req_indices],
non_blocking=True)
# The prompt tokens are used only for applying penalties during
# the sampling process. Hence copy these tensors only when
# there are requests which need penalties to be applied.
self.prompt_token_ids = self._make_prompt_token_ids_tensor()

output_token_ids: List[List[int]] = []

for req_id, output_tokens in req_id_output_token_ids:
assert req_id is not None
# Currently we create a tensor for output_token_ids from scratch
# at each step. However, for the penalties computation what we
# need is stats about the token ids present in the output. This
# stats can be maintained incrementally instead of computing it
# from scratch at each step.
# TODO - Replace this with incremental update to output token
# statistics.
output_token_ids.append(output_tokens)

return SamplingMetadata(
temperature=self.temperature[req_indices],
all_greedy=self.all_greedy,
all_random=self.all_random,
top_p=self.top_p[req_indices],
top_k=self.top_k[req_indices],
no_top_p=self.no_top_p,
no_top_k=self.no_top_k,
generators=self.generators,
max_num_logprobs=self.max_num_logprobs,
prompt_token_ids=self.prompt_token_ids,
frequency_penalties=self.frequency_penalties[req_indices],
presence_penalties=self.presence_penalties[req_indices],
repetition_penalties=self.repetition_penalties[req_indices],
output_token_ids=output_token_ids,
min_tokens=[self.min_tokens[req_idx] for req_idx in req_indices],
stop_token_ids=[self.stop_token_ids[req_idx] for req_idx in req_indices],
no_penalties=self.no_penalties,
)

def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max()
prompt_token_ids_cpu_tensor = torch.empty(
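A usage sketch for the new helper; the request ids and token ids are made up, and input_batch stands for an already-populated InputBatch instance (an assumption for illustration):

# Build sampling metadata for just two of the batched requests.
req_id_output_token_ids = [
    ("req-0", [101, 37, 4096]),  # output tokens generated so far for "req-0"
    ("req-3", [7, 7, 29]),
]
metadata = input_batch.make_selective_sampling_metadata(
    req_id_output_token_ids, skip_copy=False)
# metadata.temperature, top_p, top_k, etc. are now sliced to these two rows.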