From 52b14518b1b668b825a3edaec3edf5720046e73d Mon Sep 17 00:00:00 2001 From: Michal Adamczyk Date: Wed, 10 Jul 2024 13:46:59 +0200 Subject: [PATCH 01/14] Use flat block layout for PA (#92) * Cleanup AttentionMetadata on HPU * Flat PA - POC * Decode warmup overhaul * Fix input_hash calculation * Block bucket size 32 -> 16 * Improve host time * Skip UTs * Add GQA/MQA * Add mask instead of filling * 2d block mapping * Optional flipping in PA * Runner updated for 2d block mapping * Eliminate physical transposes * POC: build block_bias on device * Cleanup * Fix seq_len calculation * Experimental profiling * Add missing call to kv_matmul_op * Fix block_usage calculation * Change default block bucket step for decode to 128 * Fix max decode block bucket calculation * Fix block_usage calculations * Cleanup * Print values for bucketing vars * Pass block size do HpuModelAdapter --------- Co-authored-by: barak goldberg <149692267+bgoldberg-habana@users.noreply.github.com> --- vllm/attention/backends/habana_attn.py | 139 ++++------- vllm/attention/ops/habana_paged_attn.py | 53 +--- vllm/hpu/ops.py | 119 +++++---- vllm/hpu/utils.py | 7 +- vllm/worker/habana_model_runner.py | 310 +++++++++++++----------- 5 files changed, 269 insertions(+), 359 deletions(-) diff --git a/vllm/attention/backends/habana_attn.py b/vllm/attention/backends/habana_attn.py index 2259630fa10b7..16922bb034335 100644 --- a/vllm/attention/backends/habana_attn.py +++ b/vllm/attention/backends/habana_attn.py @@ -57,59 +57,15 @@ def copy_blocks( HabanaPagedAttention.copy_blocks(kv_caches, src_to_dists) -@dataclass -class HabanaAttentionMetadata(AttentionMetadata, HabanaPagedAttentionMetadata): - """Metadata for HabanaAttentionbackend. - - NOTE: Any python object stored here is not updated when it is - cuda-graph replayed. If you have values that need to be changed - dynamically, it should be stored in tensor. The tensor has to be - updated from `CUDAGraphRunner.forward` API. - """ +@dataclass(frozen=True) +class HabanaAttentionMetadata(HabanaPagedAttentionMetadata, AttentionMetadata): + """Metadata for HabanaAttentionbackend.""" # Currently, input sequences can only contain all prompts # or all decoding. True if all sequences are prompts. is_prompt: bool - # (batch_size,). The sequence length per sequence. Sequence length means - # the computed tokens + new tokens None if it is a decoding. - seq_lens: Optional[List[int]] - # seq_lens stored as a tensor. + attn_bias: Optional[torch.Tensor] seq_lens_tensor: Optional[torch.Tensor] - # |---------- N-1 iteration --------| - # |---------------- N iteration ---------------------| - # |- tokenA -|......................|-- newTokens ---| - # |---------- context_len ----------| - # |-------------------- seq_len ----------------------| - # |-- query_len ---| - - # Maximum query length in the batch. - max_query_len: Optional[int] - # (batch_size + 1,). The cumulative subquery lengths of the sequences in - # the batch, used to index into subquery. E.g., if the subquery length - # is [4, 6], it is [0, 4, 10]. - subquery_start_loc: Optional[torch.Tensor] - # FIXME: It is for flash attn. - # (batch_size + 1,). The cumulative sequence lengths of the sequences in - # the batch, used to index into sequence. E.g., if the sequence length is - # [4, 6], it is [0, 4, 10]. - seq_start_loc: Optional[torch.Tensor] - # (batch_size,) A tensor of context lengths (tokens that are computed - # so far). - context_lens_tensor: Optional[torch.Tensor] - - # Whether or not if cuda graph is enabled. - # Cuda-graph is currently enabled for decoding only. - # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. - use_cuda_graph: bool - - def __post_init__(self): - # Set during the execution of the first attention op. - # It is a list because it is needed to set per prompt - # when alibi slopes is used. It is because of the limitation - # from xformer API. - # will not appear in the __repr__ and __init__ - self.attn_bias: Optional[torch.Tensor] = None - class HabanaAttentionImpl(AttentionImpl, torch.nn.Module): """ @@ -229,60 +185,49 @@ def forward( if attn_metadata.is_prompt: # Prompt run. - if kv_cache is None or attn_metadata.block_tables.numel() == 0: - if not self.prefill_usefusedsdpa: - # TODO: move this outside of model - assert attn_metadata.attn_bias is not None, \ + if not self.prefill_usefusedsdpa: + # TODO: move this outside of model + assert attn_metadata.attn_bias is not None, \ 'attn_bias must be set before calling model.forward!' - attn_bias = attn_metadata.attn_bias - if self.alibi_slopes is not None and \ - self.position_bias is not None: - attn_bias.add_(self.position_bias[:, :, - -attn_bias.size(2):, - -attn_bias.size(3):]) - else: - attn_bias = None - - query_shape = (batch_size, seq_len, self.num_heads, - self.head_size) - kv_shape = (batch_size, seq_len_kv, self.num_kv_heads, - self.head_size) - out = ops.prompt_attention( - query.view(query_shape), - key.view(kv_shape), - value.view(kv_shape), - attn_bias=attn_bias, - p=0.0, - scale=self.scale, - matmul_qk_op=self.matmul_qk, - softmax_op=self.softmax, - matmul_av_op=self.matmul_av, - valid_seq_lengths=attn_metadata.seq_lens_tensor, - ) - output = out.reshape(batch_size, seq_len, hidden_size) + attn_bias = attn_metadata.attn_bias + if self.alibi_slopes is not None and \ + self.position_bias is not None: + attn_bias.add_(self.position_bias[:, :, + -attn_bias.size(2):, + -attn_bias.size(3):]) else: - # prefix-enabled attention - output = HabanaPagedAttention.forward_prefix( - query, - key, - value, - key_cache, - value_cache, - attn_metadata.block_tables, - attn_metadata.subquery_start_loc, - attn_metadata.seq_lens_tensor, - attn_metadata.context_lens_tensor, - attn_metadata.max_query_len, - self.alibi_slopes, - ) + attn_bias = None + + query_shape = (batch_size, seq_len, self.num_heads, + self.head_size) + kv_shape = (batch_size, seq_len_kv, self.num_kv_heads, + self.head_size) + out = ops.prompt_attention( + query.view(query_shape), + key.view(kv_shape), + value.view(kv_shape), + attn_bias=attn_bias, + p=0.0, + scale=self.scale, + matmul_qk_op=self.matmul_qk, + softmax_op=self.softmax, + matmul_av_op=self.matmul_av, + ) + output = out.reshape(batch_size, seq_len, hidden_size) else: # Decoding run. output = HabanaPagedAttention.forward_decode( - query, key_cache, value_cache, attn_metadata.block_tables, - attn_metadata.seq_lens_tensor, self.kv_cache_dtype, - self.num_kv_heads, self.scale, self.position_bias, k_scale, - v_scale, self.matmul_qk, self.softmax, self.matmul_av, - self.k_cache, self.v_cache) + query=query, + key_cache=key_cache, + value_cache=value_cache, + block_list=attn_metadata.block_list, + block_mapping=attn_metadata.block_mapping, + block_bias=attn_metadata.attn_bias, + scale=self.scale, + matmul_qk_op=self.matmul_qk, + matmul_av_op=self.matmul_av, + keys_fetch_func=self.k_cache.fetch_from_cache, + values_fetch_func=self.v_cache.fetch_from_cache) # Reshape the output tensor. return output.view(batch_size, seq_len, hidden_size) diff --git a/vllm/attention/ops/habana_paged_attn.py b/vllm/attention/ops/habana_paged_attn.py index 9602886299c47..b5e74b74109a4 100644 --- a/vllm/attention/ops/habana_paged_attn.py +++ b/vllm/attention/ops/habana_paged_attn.py @@ -13,19 +13,12 @@ _PARTITION_SIZE = 512 -@dataclass +@dataclass(frozen=True) class HabanaPagedAttentionMetadata: """Metadata for PagedAttention.""" - # (batch_size,). The length of sequences (entire tokens seen so far) per - # sequence. - seq_lens_tensor: Optional[torch.Tensor] - # (batch_size, max_blocks_per_seq). - # Block addresses per sequence. (Seq id -> list of physical block) - # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks - # in the kv cache. Each block can contain up to block_size tokens. - # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph - # captured. - block_tables: Optional[torch.Tensor] + block_list: Optional[torch.Tensor] + block_mapping: Optional[torch.Tensor] + block_usage: Optional[torch.Tensor] class HabanaPagedAttention: @@ -63,42 +56,8 @@ def write_to_paged_cache(key: torch.Tensor, value: torch.Tensor, slot_mapping, kv_cache_dtype, is_prompt) @staticmethod - def forward_decode( - query: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - block_tables: torch.Tensor, - seq_lens: torch.Tensor, - kv_cache_dtype: str, - num_kv_heads: int, - scale: float, - alibi_slopes: Optional[torch.Tensor], - k_scale: float, - v_scale: float, - matmul_qk_op, - softmax_op, - matmul_av_op, - k_cache_cls, - v_cache_cls, - ) -> torch.Tensor: - block_size = value_cache.shape[1] - return ops.paged_attention_v1( - query, - key_cache, - value_cache, - num_kv_heads, - scale, - block_tables, - seq_lens, - block_size, - alibi_slopes, - kv_cache_dtype, - matmul_qk_op, - softmax_op, - matmul_av_op, - k_cache_cls, - v_cache_cls, - ) + def forward_decode(**kwargs) -> torch.Tensor: + return ops.flat_pa(**kwargs) @staticmethod def forward_prefix( diff --git a/vllm/hpu/ops.py b/vllm/hpu/ops.py index 746e87dad4aea..c5457e2b3d2dd 100644 --- a/vllm/hpu/ops.py +++ b/vllm/hpu/ops.py @@ -29,72 +29,63 @@ logger.warning("Could not import HPU FusedSDPA kernel. " "vLLM will use native implementation.") -PA_SPLIT_VALUE = (os.environ.get('PA_SPLIT_VALUE', '1') == '1') - - -def fetch_from_cache(cache, blocks, permutations): - return [ - cache.index_select(0, blocks[:, i]).permute(permutations) - for i in range(blocks.size(1)) - ] - - -def paged_attention_v1(query, - key_cache, - value_cache, - head_mapping, - scale, - block_tables, - context_lens, - block_size, - alibi_slopes=None, - kv_cache_dtype=None, - matmul_qk_op=torch.matmul, - softmax_op=torch.softmax, - matmul_av_op=torch.matmul, - k_cache_cls=None, - v_cache_cls=None) -> None: - seq_len = block_tables.size(1) - batch_size, query_heads, _ = query.shape - _, _, kv_heads, _ = key_cache.shape - min_inf = torch.finfo(query.dtype).min - mask = (torch.arange(0, - seq_len * block_size, - dtype=torch.int32, - device=key_cache.device).view(1, -1).expand( - batch_size, -1).ge(context_lens.view(-1, 1)).view( - batch_size, 1, 1, -1)) - query.mul_(scale) - query = query.unsqueeze(-2) - fetch_keys = fetch_from_cache if k_cache_cls is None else \ - k_cache_cls.fetch_from_cache - keys = fetch_keys(key_cache, block_tables, (0, 2, 3, 1)) - if query_heads != kv_heads: + +def batch2block(tensor, block_mapping): + shape = tuple(tensor.shape) + return (block_mapping @ tensor.view(shape[0], -1)).view(-1, *shape[1:]) + + +def block2batch(tensor, block_mapping): + shape = tuple(tensor.shape) + return (block_mapping.t() @ tensor.view(shape[0], -1)).view(-1, *shape[1:]) + + +def block_softmax(batch_size, attn, block_mapping): + attn = attn.exp_() + sums = attn.sum(dim=-1).unsqueeze(-1) + sums = block2batch(sums, block_mapping) + sums = batch2block(sums, block_mapping) + attn.div_(sums) + return attn + + +def flat_pa(query, + key_cache, + value_cache, + block_list, + block_mapping, + block_bias, + scale, + matmul_qk_op, + matmul_av_op, + keys_fetch_func, + values_fetch_func): + batch_size = query.size(0) + q_heads = query.size(1) + kv_heads = key_cache.size(2) + + query = batch2block(scale * query, block_mapping).unsqueeze(-2) + key = keys_fetch_func(key_cache, block_list).transpose(1, 2) + value = values_fetch_func(value_cache, block_list).transpose(1, 2) + block_bias = block_bias.view(key.size(0), 1, 1, -1) + + if kv_heads != q_heads: + block_bias = block_bias.unsqueeze(1) query = query.unflatten(1, (kv_heads, -1)) - keys = [k.unflatten(1, (kv_heads, 1)) for k in keys] - mask = mask.unsqueeze(2) - - attn_weights = torch.cat([matmul_qk_op(query, k) for k in keys], dim=-1) - if alibi_slopes is not None: - attn_weights.add_(alibi_slopes[:, :, -attn_weights.size(2):, - -attn_weights.size(3):]) - attn_weights = softmax_op(attn_weights.masked_fill(mask, min_inf), dim=-1) - - fetch_values = fetch_from_cache if v_cache_cls is None else \ - v_cache_cls.fetch_from_cache - values = fetch_values(value_cache, block_tables, (0, 2, 1, 3)) - if PA_SPLIT_VALUE: - attn_weights = attn_weights.split(block_size, dim=-1) + key = key.unflatten(1, (kv_heads, 1)) + value = value.unflatten(1, (kv_heads, 1)) + key = key.transpose(3, 4) else: - values = [torch.cat(values, dim=-2)] - attn_weights = [attn_weights] - if query_heads != kv_heads: - values = [v.unflatten(1, (kv_heads, 1)) for v in values] - attn_weights = [matmul_av_op(a, v) for a, v in zip(attn_weights, values)] - if query_heads != kv_heads: - attn_weights = [a.flatten(1, 2) for a in attn_weights] - attn_weights = sum(attn_weights) - return attn_weights.squeeze(-2) + key = key.transpose(2, 3) + + attn = matmul_qk_op(query, key) + block_bias + attn = block_softmax(batch_size, attn, block_mapping) + attn = matmul_av_op(attn, value) + attn = block2batch(attn, block_mapping) + attn = attn.squeeze(-2) + if kv_heads != q_heads: + attn = attn.flatten(1, 2) + return attn def silu_and_mul(x: torch.Tensor) -> torch.Tensor: diff --git a/vllm/hpu/utils.py b/vllm/hpu/utils.py index 3d9c7cb1c4c22..13204b83d5742 100644 --- a/vllm/hpu/utils.py +++ b/vllm/hpu/utils.py @@ -57,8 +57,5 @@ def forward(self, input, cache, num_kv_cache_passes, num_slots_available, block_offset) return cache - def fetch_from_cache(self, cache, blocks, permutations): - return [ - cache.index_select(0, blocks[:, i]).permute(permutations) - for i in range(blocks.size(1)) - ] + def fetch_from_cache(self, cache, blocks): + return cache.index_select(0, blocks) diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py index a9a3f35d3934b..90f2ad1bb528d 100644 --- a/vllm/worker/habana_model_runner.py +++ b/vllm/worker/habana_model_runner.py @@ -10,6 +10,7 @@ import operator import os import time +import sys from enum import IntEnum from typing import (TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple, Optional, Set, Tuple, Type, TypeVar, Union) @@ -48,9 +49,19 @@ logger = init_logger(__name__) +_TYPE_CACHE = {} _PAD_SLOT_ID = 0 LORA_WARMUP_RANK = 8 -_TYPE_CACHE = {} + + +def subtuple(obj: object, typename: str, to_copy: List[str], to_override: Dict[str, object] = {}): + if obj is None: + return None + fields = set(to_copy) | set(to_override.keys()) + values = {f: to_override.get(f, getattr(obj, f)) for f in fields} + if typename not in _TYPE_CACHE: + _TYPE_CACHE[typename] = collections.namedtuple(typename, ' '.join(fields)) + return _TYPE_CACHE[typename](**values) def read_bucket_settings(phase: str, dim: str, **defaults): @@ -62,11 +73,11 @@ def read_bucket_settings(phase: str, dim: str, **defaults): example env variable: VLLM_DECODE_BS_BUCKET_STEP=128 """ params = ['min', 'step', 'max'] - values = [ - int( - os.environ.get(f'VLLM_{phase}_{dim}_BUCKET_{p}'.upper(), - defaults[p])) for p in params - ] + env_vars = [f'VLLM_{phase}_{dim}_BUCKET_{p}'.upper() for p in params] + defaults = [defaults[p] for p in params] + values = [int(os.environ.get(e, d)) for e, d in zip(env_vars, defaults)] + for e, v, d in zip(env_vars, values, defaults): + logger.info(f'{e}={v} (default:{d})') return values @@ -96,9 +107,7 @@ def warmup_range(config: Tuple[int, int, int]): return list(filter(lambda bucket: bucket >= bmin, buckets)) -def warmup_buckets(bs_bucket_config, - seq_bucket_config, - max_num_batched_tokens=None): +def generate_prompt_buckets(bs_bucket_config, seq_bucket_config, max_num_batched_tokens=None): buckets = list( itertools.product(warmup_range(bs_bucket_config), warmup_range(seq_bucket_config))) @@ -108,6 +117,7 @@ def warmup_buckets(bs_bucket_config, f"bs:{bs_bucket_config}, " f"seq:{seq_bucket_config}") raise ValueError(msg) + return list(sorted(buckets, key=lambda b: (b[0] * b[1], b[1], b[0]))) filtered_buckets = buckets if max_num_batched_tokens is not None: @@ -143,6 +153,18 @@ def warmup_buckets(bs_bucket_config, return captured_buckets, omitted_buckets +def generate_decode_buckets(bs_bucket_config, blocks_bucket_config, max_blocks): + buckets = [] + for bs in warmup_range(bs_bucket_config): + for blocks in warmup_range(blocks_bucket_config): + if blocks < bs: + continue + if blocks > max_blocks: + break + buckets.append((bs, blocks)) + return list(sorted(buckets, key=lambda b: (b[0] * b[1], b[1], b[0]))) + + def next_pow2(value: int, base: int): res = base while value > 1: @@ -162,22 +184,6 @@ def find_bucket(value: int, config: Tuple[int, int, int]): return max(bmin, min(next_step, next_pow)) -def subtuple(obj: object, - typename: str, - to_copy: List[str], - to_override: Optional[Dict[str, object]] = None): - if to_override is None: - to_override = {} - if obj is None: - return None - fields = set(to_copy) | set(to_override.keys()) - values = {f: to_override.get(f, getattr(obj, f)) for f in fields} - if typename not in _TYPE_CACHE: - _TYPE_CACHE[typename] = collections.namedtuple(typename, - ' '.join(fields)) - return _TYPE_CACHE[typename](**values) - - def align_workers(value, op): group = get_world_group().cpu_group world_size = torch.distributed.get_world_size() @@ -188,13 +194,19 @@ def align_workers(value, op): return value_t.item() +def pad_list(l, k, v): + target_len = round_up(len(l), k) + padding = target_len - len(l) + return l + [v] * padding + + class HpuModelAdapter(): - def __init__(self, model, enforce_eager): + def __init__(self, model, block_size, enforce_eager): self.model = model self.prefill_use_fusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA', '0').lower() in ['1', 'true'] - + self.block_size = block_size if not htorch.utils.internal.is_lazy() and not enforce_eager: self.model = torch.compile(self.model, backend='hpu_backend', @@ -216,20 +228,36 @@ def _set_attn_bias(self, attn_metadata, batch_size, seq_len, device, dtype=torch.bool), diagonal=1) mask = causal_mask.logical_or(len_mask) - attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_( - mask, -math.inf)) - #FIXME: Restore sliding window support - #if self.sliding_window is not None: + attn_bias = (torch.zeros_like(mask, dtype=dtype) + .masked_fill_(mask, -math.inf)) attn_metadata = prefill_metadata._replace(attn_bias=attn_bias) return attn_metadata + def _set_block_mapping(self, metadata, batch_size, device, dtype): + mask = torch.arange(0, self.block_size, device=device, dtype=torch.int32).unsqueeze(0) + mask = mask >= metadata.block_usage.unsqueeze(-1) + attn_bias = (torch.zeros_like(mask, dtype=dtype) + .masked_fill_(mask, -math.inf)) + block_mapping = torch.nn.functional.one_hot(metadata.block_mapping, num_classes=batch_size).to(dtype) + metadata = metadata._replace(block_mapping=block_mapping, attn_bias=attn_bias) + return metadata + + def _update_metadata(self, attn_metadata, batch_size, seq_len, device, dtype): + if attn_metadata.is_prompt: + meta=attn_metadata + attn_metadata=self._set_attn_bias(meta, batch_size, seq_len, device, dtype) + else: + meta=attn_metadata + attn_metadata=self._set_block_mapping(meta, batch_size, device, dtype) + return attn_metadata + def forward(self, *args, **kwargs): kwargs = kwargs.copy() selected_token_indices = kwargs.pop('selected_token_indices') if 'warmup_mode' in kwargs: kwargs.pop('warmup_mode') input_ids = kwargs['input_ids'] - kwargs['attn_metadata'] = self._set_attn_bias(kwargs['attn_metadata'], + kwargs['attn_metadata'] = self._update_metadata(kwargs['attn_metadata'], input_ids.size(0), input_ids.size(1), input_ids.device, @@ -529,7 +557,7 @@ def load_model(self) -> None: # RuntimeErrors. This needs to be debugged with HabanaMemoryProfiler() as m_wrap: self.model = _maybe_wrap_in_hpu_graph( - self.model, enforce_eager=self.enforce_eager) + self.model, self.block_size, enforce_eager=self.enforce_eager) msg = f"Wrapping in HPU Graph took {m_wrap.get_summary_string()}" logger.info(msg) @@ -546,75 +574,47 @@ def _is_valid_bucket(self, bucket): return bucket[0] * bucket[1] <= self.max_num_batched_tokens def _setup_buckets(self) -> None: + align_bs = lambda x: min(self.max_num_seqs, x) max_bucket_cfg = 64 if self.lora_config and \ max_bucket_cfg > self.max_num_batched_tokens // self.block_size: max_bucket_cfg = self.max_num_batched_tokens // self.block_size + blocks_step = 128 + max_prompt_seq = 1024 + max_decode_seq = 2048 self.prompt_bs_bucket_cfg = read_bucket_settings('prompt', 'bs', min=1, - step=32, - max=min( - self.max_num_seqs, - max_bucket_cfg)) + step=align_bs(32), + max=align_bs(max_bucket_cfg)) self.decode_bs_bucket_cfg = read_bucket_settings('decode', 'bs', - min=1, - step=128, + min=align_bs(32), + step=align_bs(32), max=self.max_num_seqs) self.prompt_seq_bucket_cfg = read_bucket_settings('prompt', 'seq', min=self.block_size, step=self.block_size, - max=1024) - self.decode_seq_bucket_cfg = read_bucket_settings('decode', - 'seq', - min=self.block_size, - step=self.block_size, - max=2048) + max=max_prompt_seq) + self.decode_block_bucket_cfg = read_bucket_settings('decode', + 'block', + min=blocks_step, + step=blocks_step, + max=max(blocks_step, self.max_num_seqs * max_decode_seq // self.block_size)) self.graphed_buckets: Set[Any] = set() msg = ("Prompt bucket config (min, step, max_warmup) " f"bs:{self.prompt_bs_bucket_cfg}, " f"seq:{self.prompt_seq_bucket_cfg}") logger.info(msg) - self.prompt_buckets, prompt_omitted_buckets = warmup_buckets( - self.prompt_bs_bucket_cfg, self.prompt_seq_bucket_cfg, - self.max_num_batched_tokens) - - if self.lora_config: - self.prompt_buckets[:] = [ - bucket for bucket in self.prompt_buckets - if self._is_valid_bucket(bucket) - ] - - msg = (f"Generated {len(self.prompt_buckets)} " - f"prompt buckets: {list(sorted(self.prompt_buckets))}") - logger.info(msg) - - msg = (f"Omitted {len(prompt_omitted_buckets)} " - "prompt buckets due to exceeded token budget " - f"(max_num_batched_tokens={self.max_num_batched_tokens})") - logger.info(msg) - - msg = f"Omitted prompt buckets: {list(sorted(prompt_omitted_buckets))}" - logger.debug(msg) msg = ("Decode bucket config (min, step, max_warmup) " f"bs:{self.decode_bs_bucket_cfg}, " - f"seq:{self.decode_seq_bucket_cfg}") - logger.info(msg) - self.decode_buckets, _ = warmup_buckets(self.decode_bs_bucket_cfg, - self.decode_seq_bucket_cfg) - if self.lora_config: - self.decode_buckets[:] = [ - bucket for bucket in self.decode_buckets - if self._is_valid_bucket(bucket) - ] - msg = (f"Generated {len(self.decode_buckets)} decode buckets: " - f"{list(sorted(self.decode_buckets))}") + f"seq:{self.decode_block_bucket_cfg}") logger.info(msg) + def _prepare_prompt( self, seq_group_metadata_list: List[SequenceGroupMetadata], @@ -728,10 +728,6 @@ def _prepare_prompt( real_num_seqs = len(query_lens) assert max_query_len > 0 - context_lens_tensor = torch.tensor(context_lens, - dtype=torch.int, - device=self.device) - if multi_modal_input_list: assert self.multimodal_config, ( "Multi-modal inputs are only supported by " @@ -741,7 +737,6 @@ def _prepare_prompt( else: multi_modal_input = None - max_prompt_block_table_len = max(len(t) for t in prefix_block_tables) max_prompt_len = max( find_bucket(max(seq_lens), self.prompt_seq_bucket_cfg), self.block_size) @@ -808,37 +803,17 @@ def _prepare_prompt( dtype=torch.long, device=self.device) - block_tables = make_tensor_with_pad(prefix_block_tables, - max_len=max_prompt_block_table_len, - pad=0, - dtype=torch.int, - device=self.device) - - # Query length can be shorter than key (i.e., prompt) when prefill - # is chunked or prefix cached. - query_lens_tensor = torch.tensor(query_lens, - dtype=torch.long, - device=self.device) - subquery_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, - dtype=torch.int32, - device=self.device) seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.long, device=self.device) - seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, - dtype=torch.int32, - device=self.device) attn_metadata = self.attn_backend.make_metadata( is_prompt=True, - seq_lens=seq_lens, + block_list=None, + block_mapping=None, + block_usage=None, + attn_bias=None, seq_lens_tensor=seq_lens_tensor, - max_query_len=max_query_len, - subquery_start_loc=subquery_start_loc, - seq_start_loc=seq_start_loc, - context_lens_tensor=context_lens_tensor, - block_tables=block_tables, - use_cuda_graph=False, num_prefills=real_num_seqs, num_prefill_tokens=sum_query_len, num_decode_tokens=0, @@ -892,7 +867,7 @@ def _prepare_decode( assert seq_group_metadata.token_chunk_size == 1 seq_ids = list(seq_group_metadata.seq_data.keys()) - lora_id = seq_group_metadata.lora_int_id + #lora_id = seq_group_metadata.lora_int_id if lora_id > 0: lora_requests.add(seq_group_metadata.lora_request) @@ -904,7 +879,7 @@ def _prepare_decode( for seq_id in seq_ids: seq_data = seq_group_metadata.seq_data[seq_id] generation_token = seq_data.get_last_token_id() - input_tokens.append([generation_token]) + input_tokens.append(generation_token) seq_len = seq_data.get_len() position = seq_len - 1 @@ -919,8 +894,8 @@ def _prepare_decode( block_offset = position % self.block_size slot = block_number * self.block_size + block_offset slot_mapping.append([slot]) - lora_index_mapping.append(lora_id) - lora_prompt_mapping.append(lora_id) + #lora_index_mapping.append(lora_id) + #lora_prompt_mapping.append(lora_id) if self.sliding_window is not None: sliding_window_blocks = (self.sliding_window // @@ -933,36 +908,41 @@ def _prepare_decode( lora_logits_mask = lora_mask input_tokens = torch.tensor(input_tokens, dtype=torch.long, - device=self.device) + device=self.device).unsqueeze(-1) input_positions = torch.tensor(input_positions, - dtype=torch.long, - device=self.device) - slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=self.device) - seq_lens_tensor = torch.tensor(seq_lens, - dtype=torch.int, - device=self.device) num_decode_tokens = sum(seq_lens) - max_block_table_len = max( - len(block_table) for block_table in block_tables) - block_tables = make_tensor_with_pad( - block_tables, - max_len=max_block_table_len, - pad=0, - dtype=torch.int, - device=self.device, - ) + + blocks_used = [len(bt) for bt in block_tables] + block_list = list(itertools.chain(*block_tables)) + block_mapping = [[i] * bu for i, bu in enumerate(blocks_used)] + block_mapping = list(itertools.chain(*block_mapping)) + + last_block = [sl % self.block_size for sl in itertools.chain(*slot_mapping)] + block_usage = [[self.block_size] * (bu - 1) + [lb] for bu, lb in zip(blocks_used, last_block)] + block_usage = list(itertools.chain(*block_usage)) + + block_bucket_size = self.decode_block_bucket_cfg[1] + block_list = pad_list(block_list, block_bucket_size, _PAD_SLOT_ID) + block_mapping = pad_list(block_mapping, block_bucket_size, 0) + block_usage = pad_list(block_usage, block_bucket_size, 0) + + block_list = torch.tensor(block_list, dtype=torch.int, device=self.device) + block_mapping = torch.tensor(block_mapping, dtype=torch.int, device=self.device) + block_usage = torch.tensor(block_usage, dtype=torch.bfloat16, device=self.device) + + slot_mapping = torch.tensor(slot_mapping, + dtype=torch.long, + device=self.device) + attn_metadata = self.attn_backend.make_metadata( is_prompt=False, - seq_lens=None, - seq_lens_tensor=seq_lens_tensor, - max_query_len=None, - subquery_start_loc=None, - seq_start_loc=None, - context_lens_tensor=None, - block_tables=block_tables, - use_cuda_graph=False, + block_list=block_list, + block_mapping=block_mapping, + block_usage=block_usage, + attn_bias=None, + seq_lens_tensor=None, num_prefills=0, num_prefill_tokens=0, num_decode_tokens=num_decode_tokens, @@ -1150,7 +1130,7 @@ def _seq_len(self, attn_metadata): if attn_metadata.num_prefills != 0: return attn_metadata.slot_mapping.size(1) else: - return attn_metadata.block_tables.size(1) * self.block_size + return attn_metadata.block_list.numel() def trim_attn_metadata(self, metadata: AttentionMetadata) -> object: # NOTE(kzawora): To anyone working on this in the future: @@ -1174,7 +1154,7 @@ def trim_attn_metadata(self, metadata: AttentionMetadata) -> object: # input_hash(123) != input_hash(321) # input_hash("abc") != input_hash("cba") attention_metadata = subtuple(metadata, 'TrimmedAttentionMetadata', [ - 'block_tables', 'seq_lens_tensor', 'attn_bias', 'slot_mapping', + 'attn_bias', 'seq_lens_tensor', 'block_list', 'block_mapping', 'block_usage', 'slot_mapping', 'is_prompt' ]) return attention_metadata @@ -1264,21 +1244,28 @@ def warmup_scenario(self, [0] * batch_size * seq_len, ) self.set_active_loras(set(), lora_mapping) - seqs = [ - self.create_dummy_seq_group_metadata( - i, - seq_len, - is_prompt, - lora_request=dummy_lora_requests_per_seq[i] - if dummy_lora_requests_per_seq else None) - for i in range(batch_size) - ] + if is_prompt: + seqs = [ + self.create_dummy_seq_group_metadata( + i, + seq_len, + is_prompt, + lora_request=dummy_lora_requests_per_seq[i] + if dummy_lora_requests_per_seq else None) + for i in range(batch_size) + ] + else: + # FIXME: seq_len is actually number of blocks + blocks = [seq_len // batch_size for _ in range(batch_size)] + blocks[0] += seq_len % batch_size + seqs = [self.create_dummy_seq_group_metadata(i, b * self.block_size - 1, is_prompt, + lora_request=dummy_lora_requests_per_seq[i] + if dummy_lora_requests_per_seq else None) for i, b in enumerate(blocks)] torch.hpu.synchronize() for _ in range(times): inputs = self.prepare_model_input(seqs) self.execute_model(inputs, kv_caches, warmup_mode=True) torch.hpu.synchronize() - self.profiler.end() gc.collect() def remove_all_loras(self): @@ -1377,6 +1364,8 @@ def log_graph_warmup_summary(self, buckets, is_prompt, total_mem): phase = f'Graph/{"Prompt" if is_prompt else "Decode"}' graphed = list(c[:2] for c in self.graphed_buckets if c[2] == is_prompt) + if num_candidates == 0: + num_candidates = 1 msg = (f'{phase} captured:{len(graphed)} ' f'({100 * len(graphed) / num_candidates:.1f}%) ' f'used_mem:{format_bytes(total_mem)} ' @@ -1389,6 +1378,35 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: logger.info("Skipping warmup...") return self.profiler.start('internal', 'warmup') + max_blocks = kv_caches[0][0].size(0) + + self.prompt_buckets, prompt_omitted_buckets = generate_prompt_buckets(self.prompt_bs_bucket_cfg, self.prompt_seq_bucket_cfg, self.max_num_batched_tokens) + if self.lora_config: + self.prompt_buckets[:] = [ + bucket for bucket in self.prompt_buckets + if self._is_valid_bucket(bucket) + ] + + msg = (f"Generated {len(self.prompt_buckets)} " + f"prompt buckets: {list(sorted(self.prompt_buckets))}") + logger.info(msg) + + msg = (f"Omitted {len(prompt_omitted_buckets)} " + "prompt buckets due to exceeded token budget " + f"(max_num_batched_tokens={self.max_num_batched_tokens})") + logger.info(msg) + + msg = f"Omitted prompt buckets: {list(sorted(prompt_omitted_buckets))}" + logger.debug(msg) + + self.decode_buckets = generate_decode_buckets(self.decode_bs_bucket_cfg, self.decode_block_bucket_cfg, max_blocks) + if self.lora_config: + self.decode_buckets[:] = [ + bucket for bucket in self.decode_buckets + if self._is_valid_bucket(bucket) + ] + logger.info(f"Generated {len(self.decode_buckets)} decode buckets: {list(sorted(self.decode_buckets))}") + start_mem = HabanaMemoryProfiler.current_device_memory_usage() start_time = time.perf_counter() self.warmup_all_buckets(self.prompt_buckets, True, kv_caches) From 0112aa3e4b83c0b67bdcf519ab7919c9237e9e01 Mon Sep 17 00:00:00 2001 From: Michal Adamczyk Date: Thu, 11 Jul 2024 14:39:51 +0200 Subject: [PATCH 02/14] Fix block_usage calculation (#96) --- vllm/worker/habana_model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py index 90f2ad1bb528d..278ad646f968a 100644 --- a/vllm/worker/habana_model_runner.py +++ b/vllm/worker/habana_model_runner.py @@ -919,7 +919,7 @@ def _prepare_decode( block_mapping = [[i] * bu for i, bu in enumerate(blocks_used)] block_mapping = list(itertools.chain(*block_mapping)) - last_block = [sl % self.block_size for sl in itertools.chain(*slot_mapping)] + last_block = [sl % self.block_size + 1 for sl in itertools.chain(*slot_mapping)] block_usage = [[self.block_size] * (bu - 1) + [lb] for bu, lb in zip(blocks_used, last_block)] block_usage = list(itertools.chain(*block_usage)) From 965f25eb98293cc16bc94900e2847790643462cd Mon Sep 17 00:00:00 2001 From: Michal Szutenberg <37601244+szutenberg@users.noreply.github.com> Date: Wed, 17 Jul 2024 07:28:17 +0200 Subject: [PATCH 03/14] WA for numerically unstable block_softmax (#104) --- vllm/hpu/ops.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/hpu/ops.py b/vllm/hpu/ops.py index c5457e2b3d2dd..6f29b31b4c910 100644 --- a/vllm/hpu/ops.py +++ b/vllm/hpu/ops.py @@ -41,10 +41,12 @@ def block2batch(tensor, block_mapping): def block_softmax(batch_size, attn, block_mapping): + attn.sub_(10.0) attn = attn.exp_() sums = attn.sum(dim=-1).unsqueeze(-1) sums = block2batch(sums, block_mapping) sums = batch2block(sums, block_mapping) + sums.add_(1.0e-12) attn.div_(sums) return attn From e66fc0b36d890a1ea29d8ad55f342d1d431e7885 Mon Sep 17 00:00:00 2001 From: Jan Kaniecki Date: Tue, 23 Jul 2024 16:10:05 +0200 Subject: [PATCH 04/14] Fix finding proper block buckets (#119) --- vllm/worker/habana_model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py index 278ad646f968a..816be395302ca 100644 --- a/vllm/worker/habana_model_runner.py +++ b/vllm/worker/habana_model_runner.py @@ -923,7 +923,7 @@ def _prepare_decode( block_usage = [[self.block_size] * (bu - 1) + [lb] for bu, lb in zip(blocks_used, last_block)] block_usage = list(itertools.chain(*block_usage)) - block_bucket_size = self.decode_block_bucket_cfg[1] + block_bucket_size = find_bucket(len(block_list), self.decode_block_bucket_cfg) block_list = pad_list(block_list, block_bucket_size, _PAD_SLOT_ID) block_mapping = pad_list(block_mapping, block_bucket_size, 0) block_usage = pad_list(block_usage, block_bucket_size, 0) From 8525b223881f55206761b7038e0a2b55ac53328a Mon Sep 17 00:00:00 2001 From: Dominika Olszewska Date: Mon, 19 Aug 2024 23:41:10 +0300 Subject: [PATCH 05/14] Apply formatting --- vllm/attention/backends/habana_attn.py | 9 +- vllm/attention/ops/habana_paged_attn.py | 2 +- vllm/hpu/ops.py | 13 +- vllm/worker/habana_model_runner.py | 162 +++++++++++++++--------- 4 files changed, 110 insertions(+), 76 deletions(-) diff --git a/vllm/attention/backends/habana_attn.py b/vllm/attention/backends/habana_attn.py index 16922bb034335..20b0f2bc7630b 100644 --- a/vllm/attention/backends/habana_attn.py +++ b/vllm/attention/backends/habana_attn.py @@ -57,7 +57,7 @@ def copy_blocks( HabanaPagedAttention.copy_blocks(kv_caches, src_to_dists) -@dataclass(frozen=True) +@dataclass class HabanaAttentionMetadata(HabanaPagedAttentionMetadata, AttentionMetadata): """Metadata for HabanaAttentionbackend.""" # Currently, input sequences can only contain all prompts @@ -193,13 +193,12 @@ def forward( if self.alibi_slopes is not None and \ self.position_bias is not None: attn_bias.add_(self.position_bias[:, :, - -attn_bias.size(2):, - -attn_bias.size(3):]) + -attn_bias.size(2):, + -attn_bias.size(3):]) else: attn_bias = None - query_shape = (batch_size, seq_len, self.num_heads, - self.head_size) + query_shape = (batch_size, seq_len, self.num_heads, self.head_size) kv_shape = (batch_size, seq_len_kv, self.num_kv_heads, self.head_size) out = ops.prompt_attention( diff --git a/vllm/attention/ops/habana_paged_attn.py b/vllm/attention/ops/habana_paged_attn.py index b5e74b74109a4..cab8d7abe95fd 100644 --- a/vllm/attention/ops/habana_paged_attn.py +++ b/vllm/attention/ops/habana_paged_attn.py @@ -13,7 +13,7 @@ _PARTITION_SIZE = 512 -@dataclass(frozen=True) +@dataclass class HabanaPagedAttentionMetadata: """Metadata for PagedAttention.""" block_list: Optional[torch.Tensor] diff --git a/vllm/hpu/ops.py b/vllm/hpu/ops.py index 6f29b31b4c910..6e47e1cd66a73 100644 --- a/vllm/hpu/ops.py +++ b/vllm/hpu/ops.py @@ -4,7 +4,6 @@ # This source code is licensed under the BSD license found in the # LICENSE file in the root directory of this source tree. ############################################################################### -import os from typing import Optional import habana_frameworks.torch as htorch @@ -51,16 +50,8 @@ def block_softmax(batch_size, attn, block_mapping): return attn -def flat_pa(query, - key_cache, - value_cache, - block_list, - block_mapping, - block_bias, - scale, - matmul_qk_op, - matmul_av_op, - keys_fetch_func, +def flat_pa(query, key_cache, value_cache, block_list, block_mapping, + block_bias, scale, matmul_qk_op, matmul_av_op, keys_fetch_func, values_fetch_func): batch_size = query.size(0) q_heads = query.size(1) diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py index 816be395302ca..aae2403a3ec8a 100644 --- a/vllm/worker/habana_model_runner.py +++ b/vllm/worker/habana_model_runner.py @@ -10,7 +10,6 @@ import operator import os import time -import sys from enum import IntEnum from typing import (TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple, Optional, Set, Tuple, Type, TypeVar, Union) @@ -54,13 +53,19 @@ LORA_WARMUP_RANK = 8 -def subtuple(obj: object, typename: str, to_copy: List[str], to_override: Dict[str, object] = {}): +def subtuple(obj: object, + typename: str, + to_copy: List[str], + to_override: Optional[Dict[str, object]] = None): if obj is None: return None + if to_override is None: + to_override = {} fields = set(to_copy) | set(to_override.keys()) values = {f: to_override.get(f, getattr(obj, f)) for f in fields} if typename not in _TYPE_CACHE: - _TYPE_CACHE[typename] = collections.namedtuple(typename, ' '.join(fields)) + _TYPE_CACHE[typename] = collections.namedtuple(typename, + ' '.join(fields)) return _TYPE_CACHE[typename](**values) @@ -74,10 +79,12 @@ def read_bucket_settings(phase: str, dim: str, **defaults): """ params = ['min', 'step', 'max'] env_vars = [f'VLLM_{phase}_{dim}_BUCKET_{p}'.upper() for p in params] - defaults = [defaults[p] for p in params] - values = [int(os.environ.get(e, d)) for e, d in zip(env_vars, defaults)] + default_values = [defaults[p] for p in params] + values = [ + int(os.environ.get(e, d)) for e, d in zip(env_vars, default_values) + ] for e, v, d in zip(env_vars, values, defaults): - logger.info(f'{e}={v} (default:{d})') + logger.info('%s=%s (default:%s)', e, v, d) return values @@ -107,7 +114,9 @@ def warmup_range(config: Tuple[int, int, int]): return list(filter(lambda bucket: bucket >= bmin, buckets)) -def generate_prompt_buckets(bs_bucket_config, seq_bucket_config, max_num_batched_tokens=None): +def generate_prompt_buckets(bs_bucket_config, + seq_bucket_config, + max_num_batched_tokens=None): buckets = list( itertools.product(warmup_range(bs_bucket_config), warmup_range(seq_bucket_config))) @@ -117,7 +126,6 @@ def generate_prompt_buckets(bs_bucket_config, seq_bucket_config, max_num_batched f"bs:{bs_bucket_config}, " f"seq:{seq_bucket_config}") raise ValueError(msg) - return list(sorted(buckets, key=lambda b: (b[0] * b[1], b[1], b[0]))) filtered_buckets = buckets if max_num_batched_tokens is not None: @@ -153,7 +161,8 @@ def generate_prompt_buckets(bs_bucket_config, seq_bucket_config, max_num_batched return captured_buckets, omitted_buckets -def generate_decode_buckets(bs_bucket_config, blocks_bucket_config, max_blocks): +def generate_decode_buckets(bs_bucket_config, blocks_bucket_config, + max_blocks): buckets = [] for bs in warmup_range(bs_bucket_config): for blocks in warmup_range(blocks_bucket_config): @@ -194,10 +203,10 @@ def align_workers(value, op): return value_t.item() -def pad_list(l, k, v): - target_len = round_up(len(l), k) - padding = target_len - len(l) - return l + [v] * padding +def pad_list(list, k, v): + target_len = round_up(len(list), k) + padding = target_len - len(list) + return list + [v] * padding class HpuModelAdapter(): @@ -228,27 +237,35 @@ def _set_attn_bias(self, attn_metadata, batch_size, seq_len, device, dtype=torch.bool), diagonal=1) mask = causal_mask.logical_or(len_mask) - attn_bias = (torch.zeros_like(mask, dtype=dtype) - .masked_fill_(mask, -math.inf)) + attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_( + mask, -math.inf)) attn_metadata = prefill_metadata._replace(attn_bias=attn_bias) return attn_metadata def _set_block_mapping(self, metadata, batch_size, device, dtype): - mask = torch.arange(0, self.block_size, device=device, dtype=torch.int32).unsqueeze(0) + mask = torch.arange(0, + self.block_size, + device=device, + dtype=torch.int32).unsqueeze(0) mask = mask >= metadata.block_usage.unsqueeze(-1) - attn_bias = (torch.zeros_like(mask, dtype=dtype) - .masked_fill_(mask, -math.inf)) - block_mapping = torch.nn.functional.one_hot(metadata.block_mapping, num_classes=batch_size).to(dtype) - metadata = metadata._replace(block_mapping=block_mapping, attn_bias=attn_bias) + attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_( + mask, -math.inf)) + block_mapping = torch.nn.functional.one_hot( + metadata.block_mapping, num_classes=batch_size).to(dtype) + metadata = metadata._replace(block_mapping=block_mapping, + attn_bias=attn_bias) return metadata - def _update_metadata(self, attn_metadata, batch_size, seq_len, device, dtype): + def _update_metadata(self, attn_metadata, batch_size, seq_len, device, + dtype): if attn_metadata.is_prompt: - meta=attn_metadata - attn_metadata=self._set_attn_bias(meta, batch_size, seq_len, device, dtype) + meta = attn_metadata + attn_metadata = self._set_attn_bias(meta, batch_size, seq_len, + device, dtype) else: - meta=attn_metadata - attn_metadata=self._set_block_mapping(meta, batch_size, device, dtype) + meta = attn_metadata + attn_metadata = self._set_block_mapping(meta, batch_size, device, + dtype) return attn_metadata def forward(self, *args, **kwargs): @@ -257,11 +274,9 @@ def forward(self, *args, **kwargs): if 'warmup_mode' in kwargs: kwargs.pop('warmup_mode') input_ids = kwargs['input_ids'] - kwargs['attn_metadata'] = self._update_metadata(kwargs['attn_metadata'], - input_ids.size(0), - input_ids.size(1), - input_ids.device, - torch.bfloat16) + kwargs['attn_metadata'] = self._update_metadata( + kwargs['attn_metadata'], input_ids.size(0), input_ids.size(1), + input_ids.device, torch.bfloat16) LoraMask.setLoraMask(kwargs.pop('lora_mask')) hidden_states = self.model(*args, **kwargs) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) @@ -557,7 +572,9 @@ def load_model(self) -> None: # RuntimeErrors. This needs to be debugged with HabanaMemoryProfiler() as m_wrap: self.model = _maybe_wrap_in_hpu_graph( - self.model, self.block_size, enforce_eager=self.enforce_eager) + self.model, + self.block_size, + enforce_eager=self.enforce_eager) msg = f"Wrapping in HPU Graph took {m_wrap.get_summary_string()}" logger.info(msg) @@ -582,11 +599,12 @@ def _setup_buckets(self) -> None: blocks_step = 128 max_prompt_seq = 1024 max_decode_seq = 2048 - self.prompt_bs_bucket_cfg = read_bucket_settings('prompt', - 'bs', - min=1, - step=align_bs(32), - max=align_bs(max_bucket_cfg)) + self.prompt_bs_bucket_cfg = read_bucket_settings( + 'prompt', + 'bs', + min=1, + step=align_bs(32), + max=align_bs(max_bucket_cfg)) self.decode_bs_bucket_cfg = read_bucket_settings('decode', 'bs', min=align_bs(32), @@ -597,11 +615,13 @@ def _setup_buckets(self) -> None: min=self.block_size, step=self.block_size, max=max_prompt_seq) - self.decode_block_bucket_cfg = read_bucket_settings('decode', - 'block', - min=blocks_step, - step=blocks_step, - max=max(blocks_step, self.max_num_seqs * max_decode_seq // self.block_size)) + self.decode_block_bucket_cfg = read_bucket_settings( + 'decode', + 'block', + min=blocks_step, + step=blocks_step, + max=max(blocks_step, + self.max_num_seqs * max_decode_seq // self.block_size)) self.graphed_buckets: Set[Any] = set() msg = ("Prompt bucket config (min, step, max_warmup) " @@ -614,7 +634,6 @@ def _setup_buckets(self) -> None: f"seq:{self.decode_block_bucket_cfg}") logger.info(msg) - def _prepare_prompt( self, seq_group_metadata_list: List[SequenceGroupMetadata], @@ -910,27 +929,40 @@ def _prepare_decode( dtype=torch.long, device=self.device).unsqueeze(-1) input_positions = torch.tensor(input_positions, - dtype=torch.long, - device=self.device) + dtype=torch.long, + device=self.device) num_decode_tokens = sum(seq_lens) blocks_used = [len(bt) for bt in block_tables] block_list = list(itertools.chain(*block_tables)) - block_mapping = [[i] * bu for i, bu in enumerate(blocks_used)] - block_mapping = list(itertools.chain(*block_mapping)) + block_mapping_nested: List[List[int]] = [ + [i] * b_u for i, b_u in enumerate(blocks_used) + ] + block_mapping: List[int] = list( + itertools.chain.from_iterable(block_mapping_nested)) - last_block = [sl % self.block_size + 1 for sl in itertools.chain(*slot_mapping)] - block_usage = [[self.block_size] * (bu - 1) + [lb] for bu, lb in zip(blocks_used, last_block)] + last_block = [ + sl % self.block_size + 1 for sl in itertools.chain(*slot_mapping) + ] + block_usage = [[self.block_size] * (b_u - 1) + [lb] + for b_u, lb in zip(blocks_used, last_block)] block_usage = list(itertools.chain(*block_usage)) - block_bucket_size = find_bucket(len(block_list), self.decode_block_bucket_cfg) + block_bucket_size = find_bucket(len(block_list), + self.decode_block_bucket_cfg) block_list = pad_list(block_list, block_bucket_size, _PAD_SLOT_ID) block_mapping = pad_list(block_mapping, block_bucket_size, 0) block_usage = pad_list(block_usage, block_bucket_size, 0) - block_list = torch.tensor(block_list, dtype=torch.int, device=self.device) - block_mapping = torch.tensor(block_mapping, dtype=torch.int, device=self.device) - block_usage = torch.tensor(block_usage, dtype=torch.bfloat16, device=self.device) + block_list = torch.tensor(block_list, + dtype=torch.int, + device=self.device) + block_mapping = torch.tensor(block_mapping, + dtype=torch.int, + device=self.device) + block_usage = torch.tensor(block_usage, + dtype=torch.bfloat16, + device=self.device) slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, @@ -1154,8 +1186,8 @@ def trim_attn_metadata(self, metadata: AttentionMetadata) -> object: # input_hash(123) != input_hash(321) # input_hash("abc") != input_hash("cba") attention_metadata = subtuple(metadata, 'TrimmedAttentionMetadata', [ - 'attn_bias', 'seq_lens_tensor', 'block_list', 'block_mapping', 'block_usage', 'slot_mapping', - 'is_prompt' + 'attn_bias', 'seq_lens_tensor', 'block_list', 'block_mapping', + 'block_usage', 'slot_mapping', 'is_prompt' ]) return attention_metadata @@ -1258,9 +1290,15 @@ def warmup_scenario(self, # FIXME: seq_len is actually number of blocks blocks = [seq_len // batch_size for _ in range(batch_size)] blocks[0] += seq_len % batch_size - seqs = [self.create_dummy_seq_group_metadata(i, b * self.block_size - 1, is_prompt, + seqs = [ + self.create_dummy_seq_group_metadata( + i, + b * self.block_size - 1, + is_prompt, lora_request=dummy_lora_requests_per_seq[i] - if dummy_lora_requests_per_seq else None) for i, b in enumerate(blocks)] + if dummy_lora_requests_per_seq else None) + for i, b in enumerate(blocks) + ] torch.hpu.synchronize() for _ in range(times): inputs = self.prepare_model_input(seqs) @@ -1380,7 +1418,9 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: self.profiler.start('internal', 'warmup') max_blocks = kv_caches[0][0].size(0) - self.prompt_buckets, prompt_omitted_buckets = generate_prompt_buckets(self.prompt_bs_bucket_cfg, self.prompt_seq_bucket_cfg, self.max_num_batched_tokens) + self.prompt_buckets, prompt_omitted_buckets = generate_prompt_buckets( + self.prompt_bs_bucket_cfg, self.prompt_seq_bucket_cfg, + self.max_num_batched_tokens) if self.lora_config: self.prompt_buckets[:] = [ bucket for bucket in self.prompt_buckets @@ -1399,13 +1439,17 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: msg = f"Omitted prompt buckets: {list(sorted(prompt_omitted_buckets))}" logger.debug(msg) - self.decode_buckets = generate_decode_buckets(self.decode_bs_bucket_cfg, self.decode_block_bucket_cfg, max_blocks) + self.decode_buckets = generate_decode_buckets( + self.decode_bs_bucket_cfg, self.decode_block_bucket_cfg, + max_blocks) if self.lora_config: self.decode_buckets[:] = [ bucket for bucket in self.decode_buckets if self._is_valid_bucket(bucket) ] - logger.info(f"Generated {len(self.decode_buckets)} decode buckets: {list(sorted(self.decode_buckets))}") + logger.info("Generated %d decode buckets: %s", + len(self.decode_buckets), + list(sorted(self.decode_buckets))) start_mem = HabanaMemoryProfiler.current_device_memory_usage() start_time = time.perf_counter() From b97d8448b532ef05e8e81e2656f98e1f5e55a25c Mon Sep 17 00:00:00 2001 From: Dominika Olszewska Date: Thu, 22 Aug 2024 01:14:47 +0300 Subject: [PATCH 06/14] Uncomment LoRA lines --- vllm/worker/habana_model_runner.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py index aae2403a3ec8a..59c28cad1f6f7 100644 --- a/vllm/worker/habana_model_runner.py +++ b/vllm/worker/habana_model_runner.py @@ -886,7 +886,7 @@ def _prepare_decode( assert seq_group_metadata.token_chunk_size == 1 seq_ids = list(seq_group_metadata.seq_data.keys()) - #lora_id = seq_group_metadata.lora_int_id + lora_id = seq_group_metadata.lora_int_id if lora_id > 0: lora_requests.add(seq_group_metadata.lora_request) @@ -913,8 +913,8 @@ def _prepare_decode( block_offset = position % self.block_size slot = block_number * self.block_size + block_offset slot_mapping.append([slot]) - #lora_index_mapping.append(lora_id) - #lora_prompt_mapping.append(lora_id) + lora_index_mapping.append(lora_id) + lora_prompt_mapping.append(lora_id) if self.sliding_window is not None: sliding_window_blocks = (self.sliding_window // From f8d9048112dcbdbca5d98b7c747c681858e6dd8e Mon Sep 17 00:00:00 2001 From: Dominika Olszewska Date: Thu, 22 Aug 2024 01:15:59 +0300 Subject: [PATCH 07/14] Cast block_mapping to long for one_hot --- vllm/worker/habana_model_runner.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py index 59c28cad1f6f7..313e00c0735e2 100644 --- a/vllm/worker/habana_model_runner.py +++ b/vllm/worker/habana_model_runner.py @@ -251,7 +251,8 @@ def _set_block_mapping(self, metadata, batch_size, device, dtype): attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_( mask, -math.inf)) block_mapping = torch.nn.functional.one_hot( - metadata.block_mapping, num_classes=batch_size).to(dtype) + metadata.block_mapping.to(torch.long), + num_classes=batch_size).to(dtype) metadata = metadata._replace(block_mapping=block_mapping, attn_bias=attn_bias) return metadata From 13979801090e220930d4ef0ab7511bbbac95cecd Mon Sep 17 00:00:00 2001 From: Dominika Olszewska Date: Fri, 6 Sep 2024 14:01:37 +0300 Subject: [PATCH 08/14] Adjust logs messages and README to new flat-PA --- README_GAUDI.md | 22 +++++++++---------- .../getting_started/gaudi-installation.rst | 14 ++++++------ vllm/worker/habana_model_runner.py | 15 +++++++------ 3 files changed, 26 insertions(+), 25 deletions(-) diff --git a/README_GAUDI.md b/README_GAUDI.md index 91bcbe49405eb..5109f7ddf9927 100644 --- a/README_GAUDI.md +++ b/README_GAUDI.md @@ -455,12 +455,12 @@ Environment variables - `VLLM_{phase}_{dim}_BUCKET_{param}` - collection of 12 environment variables configuring ranges of bucketing mechanism - `{phase}` is either `PROMPT` or `DECODE` - - `{dim}` is either `BS` or `SEQ` + - `{dim}` is either `BS`, `SEQ` or `BLOCK` - `{param}` is either `MIN`, `STEP` or `MAX` - Default values: - Prompt: - batch size min (`VLLM_PROMPT_BS_BUCKET_MIN`): `1` - - batch size step (`VLLM_PROMPT_BS_BUCKET_STEP`): `32` + - batch size step (`VLLM_PROMPT_BS_BUCKET_STEP`): `min(max_num_seqs, 32)` - batch size max (`VLLM_PROMPT_BS_BUCKET_MAX`): `min(max_num_seqs, 64)` - sequence length min (`VLLM_PROMPT_SEQ_BUCKET_MIN`): @@ -468,20 +468,20 @@ Environment variables - sequence length step (`VLLM_PROMPT_SEQ_BUCKET_STEP`): `block_size` - sequence length max (`VLLM_PROMPT_SEQ_BUCKET_MAX`): - `1024` + `max_model_len` - Decode: - - batch size min (`VLLM_DECODE_BS_BUCKET_MIN`): `1` + - batch size min (`VLLM_DECODE_BS_BUCKET_MIN`): `min(max_num_seqs, 32)` - batch size step (`VLLM_DECODE_BS_BUCKET_STEP`): - `128` + `min(max_num_seqs, 32)` - batch size max (`VLLM_DECODE_BS_BUCKET_MAX`): `max_num_seqs` - - sequence length min (`VLLM_DECODE_SEQ_BUCKET_MIN`): - `block_size` - - sequence length step - (`VLLM_DECODE_SEQ_BUCKET_STEP`): `block_size` - - sequence length max (`VLLM_DECODE_SEQ_BUCKET_MAX`): - `2048` + - block size min (`VLLM_DECODE_BLOCK_BUCKET_MIN`): + `128` + - block size step + (`VLLM_DECODE_BLOCK_BUCKET_STEP`): `128` + - block size max (`VLLM_DECODE_BLOCK_BUCKET_MAX`): + `max(128, (max_num_seqs*max_model_len)/block_size)` Additionally, there are HPU PyTorch Bridge environment variables impacting vLLM execution: diff --git a/docs/source/getting_started/gaudi-installation.rst b/docs/source/getting_started/gaudi-installation.rst index b3234d10b3115..ed3beabb2c8aa 100644 --- a/docs/source/getting_started/gaudi-installation.rst +++ b/docs/source/getting_started/gaudi-installation.rst @@ -335,19 +335,19 @@ Environment variables - Prompt: - batch size min (``VLLM_PROMPT_BS_BUCKET_MIN``): ``1`` - - batch size step (``VLLM_PROMPT_BS_BUCKET_STEP``): ``32`` + - batch size step (``VLLM_PROMPT_BS_BUCKET_STEP``): ``min(max_num_seqs, 32)`` - batch size max (``VLLM_PROMPT_BS_BUCKET_MAX``): ``min(max_num_seqs, 64)`` - sequence length min (``VLLM_PROMPT_SEQ_BUCKET_MIN``): ``block_size`` - sequence length step (``VLLM_PROMPT_SEQ_BUCKET_STEP``): ``block_size`` - - sequence length max (``VLLM_PROMPT_SEQ_BUCKET_MAX``): ``1024`` + - sequence length max (``VLLM_PROMPT_SEQ_BUCKET_MAX``): ``max_model_len`` - Decode: - - batch size min (``VLLM_DECODE_BS_BUCKET_MIN``): ``1`` - - batch size step (``VLLM_DECODE_BS_BUCKET_STEP``): ``128`` + - batch size min (``VLLM_DECODE_BS_BUCKET_MIN``): ``min(max_num_seqs, 32)`` + - batch size step (``VLLM_DECODE_BS_BUCKET_STEP``): ``min(max_num_seqs, 32)`` - batch size max (``VLLM_DECODE_BS_BUCKET_MAX``): ``max_num_seqs`` - - sequence length min (``VLLM_DECODE_SEQ_BUCKET_MIN``): ``block_size`` - - sequence length step (``VLLM_DECODE_SEQ_BUCKET_STEP``): ``block_size`` - - sequence length max (``VLLM_DECODE_SEQ_BUCKET_MAX``): ``2048`` + - sequence length min (``VLLM_DECODE_SEQ_BUCKET_MIN``): ``128`` + - sequence length step (``VLLM_DECODE_SEQ_BUCKET_STEP``): ``128`` + - sequence length max (``VLLM_DECODE_SEQ_BUCKET_MAX``): ``max(128, (max_num_seqs*max_model_len)/block_size)`` Additionally, there are HPU PyTorch Bridge environment variables impacting vLLM execution: diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py index 313e00c0735e2..1b2b3ed4ffe53 100644 --- a/vllm/worker/habana_model_runner.py +++ b/vllm/worker/habana_model_runner.py @@ -73,7 +73,7 @@ def read_bucket_settings(phase: str, dim: str, **defaults): """Read bucketing configuration from env variables. phase is either 'prompt' or 'decode' - dim is either 'bs' or 'block' + dim is either 'bs', 'seq' or 'block' param is either 'min', 'step' or 'max' example env variable: VLLM_DECODE_BS_BUCKET_STEP=128 """ @@ -598,8 +598,8 @@ def _setup_buckets(self) -> None: max_bucket_cfg > self.max_num_batched_tokens // self.block_size: max_bucket_cfg = self.max_num_batched_tokens // self.block_size blocks_step = 128 - max_prompt_seq = 1024 - max_decode_seq = 2048 + max_prompt_seq = self.max_model_len + max_decode_seq = self.max_model_len self.prompt_bs_bucket_cfg = read_bucket_settings( 'prompt', 'bs', @@ -632,7 +632,7 @@ def _setup_buckets(self) -> None: msg = ("Decode bucket config (min, step, max_warmup) " f"bs:{self.decode_bs_bucket_cfg}, " - f"seq:{self.decode_block_bucket_cfg}") + f"block:{self.decode_block_bucket_cfg}") logger.info(msg) def _prepare_prompt( @@ -1428,8 +1428,9 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: if self._is_valid_bucket(bucket) ] - msg = (f"Generated {len(self.prompt_buckets)} " - f"prompt buckets: {list(sorted(self.prompt_buckets))}") + msg = ( + f"Generated {len(self.prompt_buckets)} " + f"prompt buckets [bs, seq]: {list(sorted(self.prompt_buckets))}") logger.info(msg) msg = (f"Omitted {len(prompt_omitted_buckets)} " @@ -1448,7 +1449,7 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: bucket for bucket in self.decode_buckets if self._is_valid_bucket(bucket) ] - logger.info("Generated %d decode buckets: %s", + logger.info("Generated %d decode buckets [bs, total_blocks]: %s", len(self.decode_buckets), list(sorted(self.decode_buckets))) From 0440fb2ef3daf32c6bf9a236c468bf45573bddd4 Mon Sep 17 00:00:00 2001 From: Dominika Olszewska Date: Fri, 6 Sep 2024 14:40:29 +0300 Subject: [PATCH 09/14] Set warmup_mode to False --- vllm/worker/habana_model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py index 1b2b3ed4ffe53..c66ebeba9ffb1 100644 --- a/vllm/worker/habana_model_runner.py +++ b/vllm/worker/habana_model_runner.py @@ -1303,7 +1303,7 @@ def warmup_scenario(self, torch.hpu.synchronize() for _ in range(times): inputs = self.prepare_model_input(seqs) - self.execute_model(inputs, kv_caches, warmup_mode=True) + self.execute_model(inputs, kv_caches, warmup_mode=False) torch.hpu.synchronize() gc.collect() From 9916b6b8ff833652911a6d9cae15579dae0b16c2 Mon Sep 17 00:00:00 2001 From: Dominika Olszewska Date: Fri, 6 Sep 2024 17:09:01 +0300 Subject: [PATCH 10/14] Remove unsqueeze --- vllm/worker/habana_model_runner.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py index c66ebeba9ffb1..7763eb95b7b52 100644 --- a/vllm/worker/habana_model_runner.py +++ b/vllm/worker/habana_model_runner.py @@ -899,7 +899,7 @@ def _prepare_decode( for seq_id in seq_ids: seq_data = seq_group_metadata.seq_data[seq_id] generation_token = seq_data.get_last_token_id() - input_tokens.append(generation_token) + input_tokens.append([generation_token]) seq_len = seq_data.get_len() position = seq_len - 1 @@ -928,7 +928,7 @@ def _prepare_decode( lora_logits_mask = lora_mask input_tokens = torch.tensor(input_tokens, dtype=torch.long, - device=self.device).unsqueeze(-1) + device=self.device) input_positions = torch.tensor(input_positions, dtype=torch.long, device=self.device) From 063284670bb67cbe8dd94ae9f17fafb86922ff06 Mon Sep 17 00:00:00 2001 From: Dominika Olszewska Date: Mon, 9 Sep 2024 17:19:28 +0300 Subject: [PATCH 11/14] Fix formatting, re-add comment --- vllm/worker/habana_model_runner.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py index a2848880fed89..c22366f9eb957 100644 --- a/vllm/worker/habana_model_runner.py +++ b/vllm/worker/habana_model_runner.py @@ -52,6 +52,8 @@ logger = init_logger(__name__) _TYPE_CACHE = {} +# These values are assumed to be zero in several places. +# Use caution when updating them! _PAD_SLOT_ID = 0 _PAD_BLOCK_ID = 0 @@ -937,13 +939,13 @@ def _prepare_decode( input_positions = torch.tensor(input_positions, dtype=torch.long, device=self.device) - + dummy_slots = itertools.cycle( range(_PAD_SLOT_ID, _PAD_SLOT_ID + self.block_size)) slot_mapping = [[ s if s != _PAD_SLOT_ID else next(dummy_slots) for s in sl ] for sl in slot_mapping] - + num_decode_tokens = sum(seq_lens) blocks_used = [len(bt) for bt in block_tables] From 793f54b2c4fb07529f690155ea7052f53f5f9ff6 Mon Sep 17 00:00:00 2001 From: Dominika Olszewska Date: Tue, 10 Sep 2024 12:38:15 +0300 Subject: [PATCH 12/14] Use max_num_batched_tokens in profile_run --- vllm/worker/habana_model_runner.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py index c22366f9eb957..29530a5313c9e 100644 --- a/vllm/worker/habana_model_runner.py +++ b/vllm/worker/habana_model_runner.py @@ -1236,9 +1236,8 @@ def profile_run(self) -> None: num_layers = self.model_config.get_num_layers(self.parallel_config) kv_caches = [None] * num_layers max_batch_size = self.prompt_bs_bucket_cfg[-1] - max_seq_len = self.prompt_seq_bucket_cfg[-1] - if self.lora_config: - max_seq_len = self.max_num_batched_tokens // max_batch_size + max_seq_len = min(self.prompt_seq_bucket_cfg[-1], + self.max_num_batched_tokens // max_batch_size) self.warmup_scenario(max_batch_size, max_seq_len, From 7468aab00444b8bff7fa76a6031bd48cac0c72ef Mon Sep 17 00:00:00 2001 From: Dominika Olszewska Date: Tue, 10 Sep 2024 12:47:24 +0300 Subject: [PATCH 13/14] Fix logging warmup --- vllm/worker/habana_model_runner.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py index 29530a5313c9e..b1490abba32d5 100644 --- a/vllm/worker/habana_model_runner.py +++ b/vllm/worker/habana_model_runner.py @@ -1354,9 +1354,12 @@ def list_loras(self) -> Set[int]: def log_warmup(self, phase, i, max_i, batch_size, seq_len): free_mem = format_bytes( HabanaMemoryProfiler.current_free_device_memory()) + dim = "num_blocks" + if phase == "Prompt": + dim = "seq_len" msg = (f"[Warmup][{phase}][{i+1}/{max_i}] " f"batch_size:{batch_size} " - f"seq_len:{seq_len} " + f"{dim}:{seq_len} " f"free_mem:{free_mem}") logger.info(msg) From 36fc84edd157d99b24ea6f21c53e2021cd5ffd66 Mon Sep 17 00:00:00 2001 From: Dominika Olszewska Date: Tue, 10 Sep 2024 12:55:54 +0300 Subject: [PATCH 14/14] Hardcode default values for max prompt and decode seq The default value for both max prompt and decode seq should be max model len, but it causes graph compilation error for longer seqs - to be fixed --- vllm/worker/habana_model_runner.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py index b1490abba32d5..d7fd9331e8e36 100644 --- a/vllm/worker/habana_model_runner.py +++ b/vllm/worker/habana_model_runner.py @@ -605,8 +605,9 @@ def _setup_buckets(self) -> None: max_bucket_cfg > self.max_num_batched_tokens // self.block_size: max_bucket_cfg = self.max_num_batched_tokens // self.block_size blocks_step = 128 - max_prompt_seq = self.max_model_len - max_decode_seq = self.max_model_len + #FIXME: The default values should be max_model_len + max_prompt_seq = 1024 + max_decode_seq = 2048 self.prompt_bs_bucket_cfg = read_bucket_settings( 'prompt', 'bs',