diff --git a/README_GAUDI.md b/README_GAUDI.md
index 91bcbe49405eb..5109f7ddf9927 100644
--- a/README_GAUDI.md
+++ b/README_GAUDI.md
@@ -455,12 +455,12 @@ Environment variables
 - `VLLM_{phase}_{dim}_BUCKET_{param}` - collection of 12 environment variables
   configuring ranges of bucketing mechanism
   - `{phase}` is either `PROMPT` or `DECODE`
-  - `{dim}` is either `BS` or `SEQ`
+  - `{dim}` is either `BS`, `SEQ` or `BLOCK`
   - `{param}` is either `MIN`, `STEP` or `MAX`
 - Default values:
   - Prompt:
     - batch size min (`VLLM_PROMPT_BS_BUCKET_MIN`): `1`
-    - batch size step (`VLLM_PROMPT_BS_BUCKET_STEP`): `32`
+    - batch size step (`VLLM_PROMPT_BS_BUCKET_STEP`): `min(max_num_seqs, 32)`
    - batch size max (`VLLM_PROMPT_BS_BUCKET_MAX`):
       `min(max_num_seqs, 64)`
     - sequence length min (`VLLM_PROMPT_SEQ_BUCKET_MIN`):
@@ -468,20 +468,20 @@ Environment variables
     - sequence length step (`VLLM_PROMPT_SEQ_BUCKET_STEP`):
       `block_size`
     - sequence length max (`VLLM_PROMPT_SEQ_BUCKET_MAX`):
-      `1024`
+      `max_model_len`
   - Decode:
-    - batch size min (`VLLM_DECODE_BS_BUCKET_MIN`): `1`
+    - batch size min (`VLLM_DECODE_BS_BUCKET_MIN`): `min(max_num_seqs, 32)`
     - batch size step (`VLLM_DECODE_BS_BUCKET_STEP`):
-      `128`
+      `min(max_num_seqs, 32)`
     - batch size max (`VLLM_DECODE_BS_BUCKET_MAX`):
       `max_num_seqs`
-    - sequence length min (`VLLM_DECODE_SEQ_BUCKET_MIN`):
-      `block_size`
-    - sequence length step
-      (`VLLM_DECODE_SEQ_BUCKET_STEP`): `block_size`
-    - sequence length max (`VLLM_DECODE_SEQ_BUCKET_MAX`):
-      `2048`
+    - block size min (`VLLM_DECODE_BLOCK_BUCKET_MIN`):
+      `128`
+    - block size step
+      (`VLLM_DECODE_BLOCK_BUCKET_STEP`): `128`
+    - block size max (`VLLM_DECODE_BLOCK_BUCKET_MAX`):
+      `max(128, (max_num_seqs*max_model_len)/block_size)`
 
 Additionally, there are HPU PyTorch Bridge environment variables impacting
 vLLM execution:
diff --git a/docs/source/getting_started/gaudi-installation.rst b/docs/source/getting_started/gaudi-installation.rst
index b3234d10b3115..ed3beabb2c8aa 100644
--- a/docs/source/getting_started/gaudi-installation.rst
+++ b/docs/source/getting_started/gaudi-installation.rst
@@ -335,19 +335,19 @@ Environment variables
   - Prompt:
     - batch size min (``VLLM_PROMPT_BS_BUCKET_MIN``): ``1``
-    - batch size step (``VLLM_PROMPT_BS_BUCKET_STEP``): ``32``
+    - batch size step (``VLLM_PROMPT_BS_BUCKET_STEP``): ``min(max_num_seqs, 32)``
     - batch size max (``VLLM_PROMPT_BS_BUCKET_MAX``): ``min(max_num_seqs, 64)``
     - sequence length min (``VLLM_PROMPT_SEQ_BUCKET_MIN``): ``block_size``
     - sequence length step (``VLLM_PROMPT_SEQ_BUCKET_STEP``): ``block_size``
-    - sequence length max (``VLLM_PROMPT_SEQ_BUCKET_MAX``): ``1024``
+    - sequence length max (``VLLM_PROMPT_SEQ_BUCKET_MAX``): ``max_model_len``
   - Decode:
-    - batch size min (``VLLM_DECODE_BS_BUCKET_MIN``): ``1``
-    - batch size step (``VLLM_DECODE_BS_BUCKET_STEP``): ``128``
+    - batch size min (``VLLM_DECODE_BS_BUCKET_MIN``): ``min(max_num_seqs, 32)``
+    - batch size step (``VLLM_DECODE_BS_BUCKET_STEP``): ``min(max_num_seqs, 32)``
     - batch size max (``VLLM_DECODE_BS_BUCKET_MAX``): ``max_num_seqs``
-    - sequence length min (``VLLM_DECODE_SEQ_BUCKET_MIN``): ``block_size``
-    - sequence length step (``VLLM_DECODE_SEQ_BUCKET_STEP``): ``block_size``
-    - sequence length max (``VLLM_DECODE_SEQ_BUCKET_MAX``): ``2048``
+    - block size min (``VLLM_DECODE_BLOCK_BUCKET_MIN``): ``128``
+    - block size step (``VLLM_DECODE_BLOCK_BUCKET_STEP``): ``128``
+    - block size max (``VLLM_DECODE_BLOCK_BUCKET_MAX``): ``max(128, (max_num_seqs*max_model_len)/block_size)``
 
 Additionally, there are HPU PyTorch Bridge environment variables impacting vLLM execution:
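Reviewer note: to make the `(min, step, max)` triples above concrete, here is a rough sketch of how such a range could expand into warmup buckets. The authoritative expansion is `warmup_range` in `habana_model_runner.py` (only partially visible in this patch), so the power-of-two ramp below the step size is an assumption for illustration only.

```python
# Hypothetical illustration of expanding a (min, step, max) bucket config.
# The real logic lives in warmup_range(); the power-of-two ramp-up from
# `min` to `step` is an assumption here, not the patch's exact code.
def expand_bucket_range(bmin: int, bstep: int, bmax: int) -> list:
    buckets = set()
    value = bmin
    while value < bstep:          # assumed ramp-up phase below one step
        buckets.add(value)
        value *= 2
    value = bstep
    while value <= bmax:          # linear phase: multiples of the step
        buckets.add(value)
        value += bstep
    return sorted(b for b in buckets if b >= bmin)

# e.g. a batch-size range with MIN=1, STEP=32, MAX=64 would yield
# [1, 2, 4, 8, 16, 32, 64]
print(expand_bucket_range(1, 32, 64))
```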
diff --git a/vllm/attention/backends/habana_attn.py b/vllm/attention/backends/habana_attn.py
index 2259630fa10b7..20b0f2bc7630b 100644
--- a/vllm/attention/backends/habana_attn.py
+++ b/vllm/attention/backends/habana_attn.py
@@ -58,58 +58,14 @@ def copy_blocks(
 
 @dataclass
-class HabanaAttentionMetadata(AttentionMetadata, HabanaPagedAttentionMetadata):
-    """Metadata for HabanaAttentionbackend.
-
-    NOTE: Any python object stored here is not updated when it is
-    cuda-graph replayed. If you have values that need to be changed
-    dynamically, it should be stored in tensor. The tensor has to be
-    updated from `CUDAGraphRunner.forward` API.
-    """
+class HabanaAttentionMetadata(HabanaPagedAttentionMetadata, AttentionMetadata):
+    """Metadata for HabanaAttention backend."""
     # Currently, input sequences can only contain all prompts
     # or all decoding. True if all sequences are prompts.
     is_prompt: bool
-    # (batch_size,). The sequence length per sequence. Sequence length means
-    # the computed tokens + new tokens None if it is a decoding.
-    seq_lens: Optional[List[int]]
-    # seq_lens stored as a tensor.
+    attn_bias: Optional[torch.Tensor]
     seq_lens_tensor: Optional[torch.Tensor]
 
-    # |---------- N-1 iteration --------|
-    # |---------------- N iteration ---------------------|
-    # |- tokenA -|......................|-- newTokens ---|
-    # |---------- context_len ----------|
-    # |-------------------- seq_len ----------------------|
-    #                                   |-- query_len ---|
-
-    # Maximum query length in the batch.
-    max_query_len: Optional[int]
-    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
-    # the batch, used to index into subquery. E.g., if the subquery length
-    # is [4, 6], it is [0, 4, 10].
-    subquery_start_loc: Optional[torch.Tensor]
-    # FIXME: It is for flash attn.
-    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
-    # the batch, used to index into sequence. E.g., if the sequence length is
-    # [4, 6], it is [0, 4, 10].
-    seq_start_loc: Optional[torch.Tensor]
-    # (batch_size,) A tensor of context lengths (tokens that are computed
-    # so far).
-    context_lens_tensor: Optional[torch.Tensor]
-
-    # Whether or not if cuda graph is enabled.
-    # Cuda-graph is currently enabled for decoding only.
-    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
-    use_cuda_graph: bool
-
-    def __post_init__(self):
-        # Set during the execution of the first attention op.
-        # It is a list because it is needed to set per prompt
-        # when alibi slopes is used. It is because of the limitation
-        # from xformer API.
-        # will not appear in the __repr__ and __init__
-        self.attn_bias: Optional[torch.Tensor] = None
-
 
 class HabanaAttentionImpl(AttentionImpl, torch.nn.Module):
     """
@@ -229,60 +185,48 @@ def forward(
 
         if attn_metadata.is_prompt:
             # Prompt run.
-            if kv_cache is None or attn_metadata.block_tables.numel() == 0:
-                if not self.prefill_usefusedsdpa:
-                    # TODO: move this outside of model
-                    assert attn_metadata.attn_bias is not None, \
+            if not self.prefill_usefusedsdpa:
+                # TODO: move this outside of model
+                assert attn_metadata.attn_bias is not None, \
                         'attn_bias must be set before calling model.forward!'
-                    attn_bias = attn_metadata.attn_bias
-                    if self.alibi_slopes is not None and \
-                       self.position_bias is not None:
-                        attn_bias.add_(self.position_bias[:, :,
-                                                          -attn_bias.size(2):,
-                                                          -attn_bias.size(3):])
-                else:
-                    attn_bias = None
-
-                query_shape = (batch_size, seq_len, self.num_heads,
-                               self.head_size)
-                kv_shape = (batch_size, seq_len_kv, self.num_kv_heads,
-                            self.head_size)
-                out = ops.prompt_attention(
-                    query.view(query_shape),
-                    key.view(kv_shape),
-                    value.view(kv_shape),
-                    attn_bias=attn_bias,
-                    p=0.0,
-                    scale=self.scale,
-                    matmul_qk_op=self.matmul_qk,
-                    softmax_op=self.softmax,
-                    matmul_av_op=self.matmul_av,
-                    valid_seq_lengths=attn_metadata.seq_lens_tensor,
-                )
-                output = out.reshape(batch_size, seq_len, hidden_size)
+                attn_bias = attn_metadata.attn_bias
+                if self.alibi_slopes is not None and \
+                   self.position_bias is not None:
+                    attn_bias.add_(self.position_bias[:, :,
+                                                      -attn_bias.size(2):,
+                                                      -attn_bias.size(3):])
             else:
-                # prefix-enabled attention
-                output = HabanaPagedAttention.forward_prefix(
-                    query,
-                    key,
-                    value,
-                    key_cache,
-                    value_cache,
-                    attn_metadata.block_tables,
-                    attn_metadata.subquery_start_loc,
-                    attn_metadata.seq_lens_tensor,
-                    attn_metadata.context_lens_tensor,
-                    attn_metadata.max_query_len,
-                    self.alibi_slopes,
-                )
+                attn_bias = None
+
+            query_shape = (batch_size, seq_len, self.num_heads, self.head_size)
+            kv_shape = (batch_size, seq_len_kv, self.num_kv_heads,
+                        self.head_size)
+            out = ops.prompt_attention(
+                query.view(query_shape),
+                key.view(kv_shape),
+                value.view(kv_shape),
+                attn_bias=attn_bias,
+                p=0.0,
+                scale=self.scale,
+                matmul_qk_op=self.matmul_qk,
+                softmax_op=self.softmax,
+                matmul_av_op=self.matmul_av,
+            )
+            output = out.reshape(batch_size, seq_len, hidden_size)
         else:
             # Decoding run.
             output = HabanaPagedAttention.forward_decode(
-                query, key_cache, value_cache, attn_metadata.block_tables,
-                attn_metadata.seq_lens_tensor, self.kv_cache_dtype,
-                self.num_kv_heads, self.scale, self.position_bias, k_scale,
-                v_scale, self.matmul_qk, self.softmax, self.matmul_av,
-                self.k_cache, self.v_cache)
+                query=query,
+                key_cache=key_cache,
+                value_cache=value_cache,
+                block_list=attn_metadata.block_list,
+                block_mapping=attn_metadata.block_mapping,
+                block_bias=attn_metadata.attn_bias,
+                scale=self.scale,
+                matmul_qk_op=self.matmul_qk,
+                matmul_av_op=self.matmul_av,
+                keys_fetch_func=self.k_cache.fetch_from_cache,
+                values_fetch_func=self.v_cache.fetch_from_cache)
 
         # Reshape the output tensor.
         return output.view(batch_size, seq_len, hidden_size)
diff --git a/vllm/attention/ops/habana_paged_attn.py b/vllm/attention/ops/habana_paged_attn.py
index 9602886299c47..cab8d7abe95fd 100644
--- a/vllm/attention/ops/habana_paged_attn.py
+++ b/vllm/attention/ops/habana_paged_attn.py
@@ -16,16 +16,9 @@
 @dataclass
 class HabanaPagedAttentionMetadata:
     """Metadata for PagedAttention."""
-    # (batch_size,). The length of sequences (entire tokens seen so far) per
-    # sequence.
-    seq_lens_tensor: Optional[torch.Tensor]
-    # (batch_size, max_blocks_per_seq).
-    # Block addresses per sequence. (Seq id -> list of physical block)
-    # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
-    # in the kv cache. Each block can contain up to block_size tokens.
-    # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
-    # captured.
-    block_tables: Optional[torch.Tensor]
+    block_list: Optional[torch.Tensor]
+    block_mapping: Optional[torch.Tensor]
+    block_usage: Optional[torch.Tensor]
 
 
 class HabanaPagedAttention:
@@ -63,42 +56,8 @@ def write_to_paged_cache(key: torch.Tensor, value: torch.Tensor,
                                             slot_mapping, kv_cache_dtype,
                                             is_prompt)
 
     @staticmethod
-    def forward_decode(
-        query: torch.Tensor,
-        key_cache: torch.Tensor,
-        value_cache: torch.Tensor,
-        block_tables: torch.Tensor,
-        seq_lens: torch.Tensor,
-        kv_cache_dtype: str,
-        num_kv_heads: int,
-        scale: float,
-        alibi_slopes: Optional[torch.Tensor],
-        k_scale: float,
-        v_scale: float,
-        matmul_qk_op,
-        softmax_op,
-        matmul_av_op,
-        k_cache_cls,
-        v_cache_cls,
-    ) -> torch.Tensor:
-        block_size = value_cache.shape[1]
-        return ops.paged_attention_v1(
-            query,
-            key_cache,
-            value_cache,
-            num_kv_heads,
-            scale,
-            block_tables,
-            seq_lens,
-            block_size,
-            alibi_slopes,
-            kv_cache_dtype,
-            matmul_qk_op,
-            softmax_op,
-            matmul_av_op,
-            k_cache_cls,
-            v_cache_cls,
-        )
+    def forward_decode(**kwargs) -> torch.Tensor:
+        return ops.flat_pa(**kwargs)
 
     @staticmethod
     def forward_prefix(
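Reviewer note: the three new metadata fields replace the per-sequence `block_tables`/`seq_lens_tensor` pair with a flat, per-block view of the whole batch. A toy example consistent with how `_prepare_decode` builds them later in this patch (values are made up for readability):

```python
# Toy illustration of the new decode metadata layout (block_size = 4).
# Two sequences: seq 0 owns cache blocks [7, 2] and has 6 tokens,
# seq 1 owns block [5] and has 3 tokens.
block_tables = [[7, 2], [5]]
seq_lens = [6, 3]
block_size = 4

block_list = [7, 2, 5]        # all blocks, flattened across the batch
block_mapping = [0, 0, 1]     # owning sequence index for each block
block_usage = [4, 2, 3]       # tokens actually used in each block

# flat_pa then treats every block as an independent attention "row" and
# folds the per-block results back into per-sequence outputs via
# block_mapping.
```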
diff --git a/vllm/hpu/ops.py b/vllm/hpu/ops.py
index bacb755b39393..b2705429906c4 100644
--- a/vllm/hpu/ops.py
+++ b/vllm/hpu/ops.py
@@ -4,7 +4,6 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.
 ###############################################################################
-import os
 from typing import Optional
 
 import habana_frameworks.torch as htorch
@@ -29,72 +28,57 @@
     logger.warning("Could not import HPU FusedSDPA kernel. "
                    "vLLM will use native implementation.")
 
-PA_SPLIT_VALUE = (os.environ.get('PA_SPLIT_VALUE', '1') == '1')
-
-
-def fetch_from_cache(cache, blocks, permutations):
-    return [
-        cache.index_select(0, blocks[:, i]).permute(permutations)
-        for i in range(blocks.size(1))
-    ]
-
-
-def paged_attention_v1(query,
-                       key_cache,
-                       value_cache,
-                       head_mapping,
-                       scale,
-                       block_tables,
-                       context_lens,
-                       block_size,
-                       alibi_slopes=None,
-                       kv_cache_dtype=None,
-                       matmul_qk_op=torch.matmul,
-                       softmax_op=torch.softmax,
-                       matmul_av_op=torch.matmul,
-                       k_cache_cls=None,
-                       v_cache_cls=None) -> None:
-    seq_len = block_tables.size(1)
-    batch_size, query_heads, _ = query.shape
-    _, _, kv_heads, _ = key_cache.shape
-    min_inf = torch.finfo(query.dtype).min
-    mask = (torch.arange(0,
-                         seq_len * block_size,
-                         dtype=torch.int32,
-                         device=key_cache.device).view(1, -1).expand(
-                             batch_size, -1).ge(context_lens.view(-1, 1)).view(
-                                 batch_size, 1, 1, -1))
-    query.mul_(scale)
-    query = query.unsqueeze(-2)
-    fetch_keys = fetch_from_cache if k_cache_cls is None else \
-        k_cache_cls.fetch_from_cache
-    keys = fetch_keys(key_cache, block_tables, (0, 2, 3, 1))
-    if query_heads != kv_heads:
+
+def batch2block(tensor, block_mapping):
+    shape = tuple(tensor.shape)
+    return (block_mapping @ tensor.view(shape[0], -1)).view(-1, *shape[1:])
+
+
+def block2batch(tensor, block_mapping):
+    shape = tuple(tensor.shape)
+    return (block_mapping.t() @ tensor.view(shape[0], -1)).view(-1, *shape[1:])
+
+
+def block_softmax(batch_size, attn, block_mapping):
+    attn.sub_(10.0)
+    attn = attn.exp_()
+    sums = attn.sum(dim=-1).unsqueeze(-1)
+    sums = block2batch(sums, block_mapping)
+    sums = batch2block(sums, block_mapping)
+    sums.add_(1.0e-12)
+    attn.div_(sums)
+    return attn
+
+
+def flat_pa(query, key_cache, value_cache, block_list, block_mapping,
+            block_bias, scale, matmul_qk_op, matmul_av_op, keys_fetch_func,
+            values_fetch_func):
+    batch_size = query.size(0)
+    q_heads = query.size(1)
+    kv_heads = key_cache.size(2)
+
+    query = batch2block(scale * query, block_mapping).unsqueeze(-2)
+    key = keys_fetch_func(key_cache, block_list).transpose(1, 2)
+    value = values_fetch_func(value_cache, block_list).transpose(1, 2)
+    block_bias = block_bias.view(key.size(0), 1, 1, -1)
+
+    if kv_heads != q_heads:
+        block_bias = block_bias.unsqueeze(1)
         query = query.unflatten(1, (kv_heads, -1))
-        keys = [k.unflatten(1, (kv_heads, 1)) for k in keys]
-        mask = mask.unsqueeze(2)
-
-    attn_weights = torch.cat([matmul_qk_op(query, k) for k in keys], dim=-1)
-    if alibi_slopes is not None:
-        attn_weights.add_(alibi_slopes[:, :, -attn_weights.size(2):,
-                                       -attn_weights.size(3):])
-    attn_weights = softmax_op(attn_weights.masked_fill(mask, min_inf), dim=-1)
-
-    fetch_values = fetch_from_cache if v_cache_cls is None else \
-        v_cache_cls.fetch_from_cache
-    values = fetch_values(value_cache, block_tables, (0, 2, 1, 3))
-    if PA_SPLIT_VALUE:
-        attn_weights = attn_weights.split(block_size, dim=-1)
+        key = key.unflatten(1, (kv_heads, 1))
+        value = value.unflatten(1, (kv_heads, 1))
+        key = key.transpose(3, 4)
     else:
-        values = [torch.cat(values, dim=-2)]
-        attn_weights = [attn_weights]
-    if query_heads != kv_heads:
-        values = [v.unflatten(1, (kv_heads, 1)) for v in values]
-    attn_weights = [matmul_av_op(a, v) for a, v in zip(attn_weights, values)]
-    if query_heads != kv_heads:
-        attn_weights = [a.flatten(1, 2) for a in attn_weights]
-    attn_weights = sum(attn_weights)
-    return attn_weights.squeeze(-2)
+        key = key.transpose(2, 3)
+
+    attn = matmul_qk_op(query, key) + block_bias
+    attn = block_softmax(batch_size, attn, block_mapping)
+    attn = matmul_av_op(attn, value)
+    attn = block2batch(attn, block_mapping)
+    attn = attn.squeeze(-2)
+    if kv_heads != q_heads:
+        attn = attn.flatten(1, 2)
+    return attn
 
 
 def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
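Reviewer note: `batch2block`/`block2batch` above are plain matrix multiplications with the one-hot `block_mapping` that `_set_block_mapping` builds later in this patch; `block_softmax` uses the same pair to normalize per-block exponentials over all blocks owned by one sequence. A small standalone shape sketch with toy sizes (2-D tensors only; the real code reshapes 4-D attention tensors through the same helpers):

```python
import torch

# Toy shapes: 2 sequences in the batch, 3 KV-cache blocks in total
# (seq 0 owns 2 blocks, seq 1 owns 1), hidden size 8.
batch_size, num_blocks, hidden = 2, 3, 8
block_mapping_idx = torch.tensor([0, 0, 1])                  # block -> sequence
block_mapping = torch.nn.functional.one_hot(
    block_mapping_idx, num_classes=batch_size).float()       # [3, 2]

queries = torch.randn(batch_size, hidden)                    # one row per sequence
per_block = block_mapping @ queries                          # [3, 8]: batch2block
assert torch.allclose(per_block[0], per_block[1])            # seq 0 copied to both its blocks

per_block_partial = torch.randn(num_blocks, hidden)          # e.g. per-block attention output
per_seq = block_mapping.t() @ per_block_partial              # [2, 8]: block2batch
# row 0 is the sum over the two blocks owned by sequence 0
assert torch.allclose(per_seq[0], per_block_partial[0] + per_block_partial[1])
```

Scattering with `batch2block`, reducing with `block2batch`, and dividing by the broadcast per-sequence sum is what makes the per-block softmax in `block_softmax` equivalent to a softmax over a sequence's full (flattened) KV history.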
diff --git a/vllm/hpu/utils.py b/vllm/hpu/utils.py
index 3d9c7cb1c4c22..13204b83d5742 100644
--- a/vllm/hpu/utils.py
+++ b/vllm/hpu/utils.py
@@ -57,8 +57,5 @@ def forward(self, input, cache, num_kv_cache_passes, num_slots_available,
                                   block_offset)
         return cache
 
-    def fetch_from_cache(self, cache, blocks, permutations):
-        return [
-            cache.index_select(0, blocks[:, i]).permute(permutations)
-            for i in range(blocks.size(1))
-        ]
+    def fetch_from_cache(self, cache, blocks):
+        return cache.index_select(0, blocks)
diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py
index a4ade587db089..a6bd5e5f68745 100644
--- a/vllm/worker/habana_model_runner.py
+++ b/vllm/worker/habana_model_runner.py
@@ -51,29 +51,47 @@
 logger = init_logger(__name__)
 
+_TYPE_CACHE = {}
 # These values are assumed to be zero in several places.
 # Use caution when updating them!
 _PAD_SLOT_ID = 0
 _PAD_BLOCK_ID = 0
 
 LORA_WARMUP_RANK = 8
-_TYPE_CACHE = {}
+
+
+def subtuple(obj: object,
+             typename: str,
+             to_copy: List[str],
+             to_override: Optional[Dict[str, object]] = None):
+    if obj is None:
+        return None
+    if to_override is None:
+        to_override = {}
+    fields = set(to_copy) | set(to_override.keys())
+    values = {f: to_override.get(f, getattr(obj, f)) for f in fields}
+    if typename not in _TYPE_CACHE:
+        _TYPE_CACHE[typename] = collections.namedtuple(typename,
+                                                       ' '.join(fields))
+    return _TYPE_CACHE[typename](**values)
 
 
 def read_bucket_settings(phase: str, dim: str, **defaults):
     """Read bucketing configuration from env variables.
 
     phase is either 'prompt' or 'decode'
-    dim is either 'bs' or 'block'
+    dim is either 'bs', 'seq' or 'block'
     param is either 'min', 'step' or 'max'
    example env variable: VLLM_DECODE_BS_BUCKET_STEP=128
    """
    params = ['min', 'step', 'max']
+    env_vars = [f'VLLM_{phase}_{dim}_BUCKET_{p}'.upper() for p in params]
+    default_values = [defaults[p] for p in params]
    values = [
-        int(
-            os.environ.get(f'VLLM_{phase}_{dim}_BUCKET_{p}'.upper(),
-                           defaults[p])) for p in params
+        int(os.environ.get(e, d)) for e, d in zip(env_vars, default_values)
    ]
+    for e, v, d in zip(env_vars, values, default_values):
+        logger.info('%s=%s (default:%s)', e, v, d)
    return values
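Reviewer note: `subtuple` (moved up here so it is defined before use) projects an arbitrary object onto a cached `namedtuple` type; `trim_attn_metadata` later in this patch uses it to strip attention metadata down to hashable fields. A minimal standalone illustration with a toy class (not the real vLLM metadata types):

```python
import collections
from typing import Dict, List, Optional

_TYPE_CACHE = {}

def subtuple(obj: object,
             typename: str,
             to_copy: List[str],
             to_override: Optional[Dict[str, object]] = None):
    # Same logic as the helper above: copy selected attributes into a
    # namedtuple whose type is created once and cached per `typename`.
    if obj is None:
        return None
    if to_override is None:
        to_override = {}
    fields = set(to_copy) | set(to_override.keys())
    values = {f: to_override.get(f, getattr(obj, f)) for f in fields}
    if typename not in _TYPE_CACHE:
        _TYPE_CACHE[typename] = collections.namedtuple(typename,
                                                       ' '.join(fields))
    return _TYPE_CACHE[typename](**values)

class Meta:  # toy stand-in for an AttentionMetadata object
    def __init__(self):
        self.is_prompt = False
        self.block_list = [7, 2, 5]
        self.extra_field = "dropped"

trimmed = subtuple(Meta(), 'TrimmedMeta', ['is_prompt', 'block_list'])
print(trimmed.block_list)   # [7, 2, 5]; extra_field is not carried over
```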
@@ -103,9 +121,9 @@ def warmup_range(config: Tuple[int, int, int]):
     return list(filter(lambda bucket: bucket >= bmin, buckets))
 
 
-def warmup_buckets(bs_bucket_config,
-                   seq_bucket_config,
-                   max_num_batched_tokens=None):
+def generate_prompt_buckets(bs_bucket_config,
+                            seq_bucket_config,
+                            max_num_batched_tokens=None):
     buckets = list(
         itertools.product(warmup_range(bs_bucket_config),
                           warmup_range(seq_bucket_config)))
@@ -150,6 +168,19 @@ def warmup_buckets(bs_bucket_config,
     return captured_buckets, omitted_buckets
 
 
+def generate_decode_buckets(bs_bucket_config, blocks_bucket_config,
+                            max_blocks):
+    buckets = []
+    for bs in warmup_range(bs_bucket_config):
+        for blocks in warmup_range(blocks_bucket_config):
+            if blocks < bs:
+                continue
+            if blocks > max_blocks:
+                break
+            buckets.append((bs, blocks))
+    return list(sorted(buckets, key=lambda b: (b[0] * b[1], b[1], b[0])))
+
+
 def next_pow2(value: int, base: int):
     res = base
     while value > 1:
@@ -169,22 +200,6 @@ def find_bucket(value: int, config: Tuple[int, int, int]):
     return max(bmin, min(next_step, next_pow))
 
 
-def subtuple(obj: object,
-             typename: str,
-             to_copy: List[str],
-             to_override: Optional[Dict[str, object]] = None):
-    if to_override is None:
-        to_override = {}
-    if obj is None:
-        return None
-    fields = set(to_copy) | set(to_override.keys())
-    values = {f: to_override.get(f, getattr(obj, f)) for f in fields}
-    if typename not in _TYPE_CACHE:
-        _TYPE_CACHE[typename] = collections.namedtuple(typename,
-                                                       ' '.join(fields))
-    return _TYPE_CACHE[typename](**values)
-
-
 def align_workers(value, op):
     group = get_world_group().cpu_group
     world_size = torch.distributed.get_world_size()
@@ -195,13 +210,19 @@ def align_workers(value, op):
     return value_t.item()
 
 
+def pad_list(list, k, v):
+    target_len = round_up(len(list), k)
+    padding = target_len - len(list)
+    return list + [v] * padding
+
+
 class HpuModelAdapter():
 
-    def __init__(self, model, enforce_eager):
+    def __init__(self, model, block_size, enforce_eager):
         self.model = model
         self.prefill_use_fusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA',
                                                '0').lower() in ['1', 'true']
-
+        self.block_size = block_size
         if not htorch.utils.internal.is_lazy() and not enforce_eager:
             self.model = torch.compile(self.model,
                                        backend='hpu_backend',
@@ -225,22 +246,45 @@ def _set_attn_bias(self, attn_metadata, batch_size, seq_len, device,
        mask = causal_mask.logical_or(len_mask)
        attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_(
            mask, -math.inf))
-        #FIXME: Restore sliding window support
-        #if self.sliding_window is not None:
        attn_metadata = prefill_metadata._replace(attn_bias=attn_bias)
        return attn_metadata
 
+    def _set_block_mapping(self, metadata, batch_size, device, dtype):
+        mask = torch.arange(0,
+                            self.block_size,
+                            device=device,
+                            dtype=torch.int32).unsqueeze(0)
+        mask = mask >= metadata.block_usage.unsqueeze(-1)
+        attn_bias = (torch.zeros_like(mask,
+                                      dtype=dtype).masked_fill_(
+            mask, -math.inf))
+        block_mapping = torch.nn.functional.one_hot(
+            metadata.block_mapping.to(torch.long),
+            num_classes=batch_size).to(dtype)
+        metadata = metadata._replace(block_mapping=block_mapping,
+                                     attn_bias=attn_bias)
+        return metadata
+
+    def _update_metadata(self, attn_metadata, batch_size, seq_len, device,
+                         dtype):
+        if attn_metadata.is_prompt:
+            meta = attn_metadata
+            attn_metadata = self._set_attn_bias(meta, batch_size, seq_len,
+                                                device, dtype)
+        else:
+            meta = attn_metadata
+            attn_metadata = self._set_block_mapping(meta, batch_size, device,
+                                                    dtype)
+        return attn_metadata
+
     def forward(self, *args, **kwargs):
         kwargs = kwargs.copy()
         selected_token_indices = kwargs.pop('selected_token_indices')
         if 'warmup_mode' in kwargs:
             kwargs.pop('warmup_mode')
         input_ids = kwargs['input_ids']
-        kwargs['attn_metadata'] = self._set_attn_bias(kwargs['attn_metadata'],
-                                                      input_ids.size(0),
-                                                      input_ids.size(1),
-                                                      input_ids.device,
-                                                      torch.bfloat16)
+        kwargs['attn_metadata'] = self._update_metadata(
+            kwargs['attn_metadata'], input_ids.size(0), input_ids.size(1),
+            input_ids.device, torch.bfloat16)
         LoraMask.setLoraMask(kwargs.pop('lora_mask'))
         hidden_states = self.model(*args, **kwargs)
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
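Reviewer note: a small numeric sketch of what `_set_block_mapping` computes for the decode path, continuing the toy example from earlier (block_size 4, three live blocks; dtype/device handling omitted):

```python
import math
import torch

block_size = 4
block_usage = torch.tensor([4, 2, 3])              # tokens used per block

positions = torch.arange(block_size).unsqueeze(0)  # [1, 4]
mask = positions >= block_usage.unsqueeze(-1)      # [3, 4], True = unused slot
attn_bias = torch.zeros_like(mask, dtype=torch.float).masked_fill_(
    mask, -math.inf)
print(attn_bias)
# tensor([[0., 0., 0., 0.],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf]])
# Adding this bias before the (block-wise) softmax removes attention to the
# padded slots inside each cache block.
```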
@@ -536,7 +580,9 @@ def load_model(self) -> None:
             # RuntimeErrors. This needs to be debugged
             with HabanaMemoryProfiler() as m_wrap:
                 self.model = _maybe_wrap_in_hpu_graph(
-                    self.model, enforce_eager=self.enforce_eager)
+                    self.model,
+                    self.block_size,
+                    enforce_eager=self.enforce_eager)
             msg = f"Wrapping in HPU Graph took {m_wrap.get_summary_string()}"
             logger.info(msg)
 
@@ -553,73 +599,48 @@ def _is_valid_bucket(self, bucket):
         return bucket[0] * bucket[1] <= self.max_num_batched_tokens
 
     def _setup_buckets(self) -> None:
+        align_bs = lambda x: min(self.max_num_seqs, x)
         max_bucket_cfg = 64
         if self.lora_config and \
             max_bucket_cfg > self.max_num_batched_tokens // self.block_size:
             max_bucket_cfg = self.max_num_batched_tokens // self.block_size
-        self.prompt_bs_bucket_cfg = read_bucket_settings('prompt',
-                                                         'bs',
-                                                         min=1,
-                                                         step=32,
-                                                         max=min(
-                                                             self.max_num_seqs,
-                                                             max_bucket_cfg))
+        blocks_step = 128
+        #FIXME: The default values should be max_model_len
+        max_prompt_seq = 1024
+        max_decode_seq = 2048
+        self.prompt_bs_bucket_cfg = read_bucket_settings(
+            'prompt',
+            'bs',
+            min=1,
+            step=align_bs(32),
+            max=align_bs(max_bucket_cfg))
         self.decode_bs_bucket_cfg = read_bucket_settings('decode',
                                                          'bs',
-                                                         min=1,
-                                                         step=128,
+                                                         min=align_bs(32),
+                                                         step=align_bs(32),
                                                          max=self.max_num_seqs)
         self.prompt_seq_bucket_cfg = read_bucket_settings('prompt',
                                                           'seq',
                                                           min=self.block_size,
                                                           step=self.block_size,
-                                                          max=1024)
-        self.decode_seq_bucket_cfg = read_bucket_settings('decode',
-                                                          'seq',
-                                                          min=self.block_size,
-                                                          step=self.block_size,
-                                                          max=2048)
+                                                          max=max_prompt_seq)
+        self.decode_block_bucket_cfg = read_bucket_settings(
+            'decode',
+            'block',
+            min=blocks_step,
+            step=blocks_step,
+            max=max(blocks_step,
+                    self.max_num_seqs * max_decode_seq // self.block_size))
         self.graphed_buckets: Set[Any] = set()
 
         msg = ("Prompt bucket config (min, step, max_warmup) "
                f"bs:{self.prompt_bs_bucket_cfg}, "
               f"seq:{self.prompt_seq_bucket_cfg}")
        logger.info(msg)
-        self.prompt_buckets, prompt_omitted_buckets = warmup_buckets(
-            self.prompt_bs_bucket_cfg, self.prompt_seq_bucket_cfg,
-            self.max_num_batched_tokens)
-
-        if self.lora_config:
-            self.prompt_buckets[:] = [
-                bucket for bucket in self.prompt_buckets
-                if self._is_valid_bucket(bucket)
-            ]
-
-        msg = (f"Generated {len(self.prompt_buckets)} "
-               f"prompt buckets: {list(sorted(self.prompt_buckets))}")
-        logger.info(msg)
-
-        msg = (f"Omitted {len(prompt_omitted_buckets)} "
-               "prompt buckets due to exceeded token budget "
-               f"(max_num_batched_tokens={self.max_num_batched_tokens})")
-        logger.info(msg)
-
-        msg = f"Omitted prompt buckets: {list(sorted(prompt_omitted_buckets))}"
-        logger.debug(msg)
 
         msg = ("Decode bucket config (min, step, max_warmup) "
                f"bs:{self.decode_bs_bucket_cfg}, "
-               f"seq:{self.decode_seq_bucket_cfg}")
-        logger.info(msg)
-        self.decode_buckets, _ = warmup_buckets(self.decode_bs_bucket_cfg,
-                                                self.decode_seq_bucket_cfg)
-        if self.lora_config:
-            self.decode_buckets[:] = [
-                bucket for bucket in self.decode_buckets
-                if self._is_valid_bucket(bucket)
-            ]
-        msg = (f"Generated {len(self.decode_buckets)} decode buckets: "
-               f"{list(sorted(self.decode_buckets))}")
+               f"block:{self.decode_block_bucket_cfg}")
         logger.info(msg)
 
     def _prepare_prompt(
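Reviewer note: to make the new defaults concrete, here is a hypothetical engine configuration plugged into the formulas above (the numbers are illustrative, not values from the patch):

```python
# Hypothetical configuration, evaluated with the defaults from _setup_buckets.
max_num_seqs = 64
block_size = 128
max_decode_seq = 2048     # FIXME placeholder in the patch, not max_model_len
blocks_step = 128

decode_bs_cfg = (min(max_num_seqs, 32), min(max_num_seqs, 32), max_num_seqs)
decode_block_cfg = (blocks_step, blocks_step,
                    max(blocks_step,
                        max_num_seqs * max_decode_seq // block_size))
print(decode_bs_cfg)      # (32, 32, 64)
print(decode_block_cfg)   # (128, 128, 1024)
```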
@@ -735,10 +756,6 @@ def _prepare_prompt(
             real_num_seqs = len(query_lens)
         assert max_query_len > 0
 
-        context_lens_tensor = torch.tensor(context_lens,
-                                           dtype=torch.int,
-                                           device=self.device)
-
         if multi_modal_input_list:
             assert self.multimodal_config, (
                 "Multi-modal inputs are only supported by "
@@ -748,7 +765,6 @@ def _prepare_prompt(
         else:
             multi_modal_input = None
 
-        max_prompt_block_table_len = max(len(t) for t in prefix_block_tables)
         max_prompt_len = max(
             find_bucket(max(seq_lens), self.prompt_seq_bucket_cfg),
             self.block_size)
@@ -814,37 +830,17 @@ def _prepare_prompt(
                                            dtype=torch.long,
                                            device=self.device)
 
-        block_tables = make_tensor_with_pad(prefix_block_tables,
-                                            max_len=max_prompt_block_table_len,
-                                            pad=0,
-                                            dtype=torch.int,
-                                            device=self.device)
-
-        # Query length can be shorter than key (i.e., prompt) when prefill
-        # is chunked or prefix cached.
-        query_lens_tensor = torch.tensor(query_lens,
-                                         dtype=torch.long,
-                                         device=self.device)
-        subquery_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
-                                         dtype=torch.int32,
-                                         device=self.device)
         seq_lens_tensor = torch.tensor(seq_lens,
                                        dtype=torch.long,
                                        device=self.device)
-        seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
-                                    dtype=torch.int32,
-                                    device=self.device)
 
         attn_metadata = self.attn_backend.make_metadata(
             is_prompt=True,
-            seq_lens=seq_lens,
+            block_list=None,
+            block_mapping=None,
+            block_usage=None,
+            attn_bias=None,
             seq_lens_tensor=seq_lens_tensor,
-            max_query_len=max_query_len,
-            subquery_start_loc=subquery_start_loc,
-            seq_start_loc=seq_start_loc,
-            context_lens_tensor=context_lens_tensor,
-            block_tables=block_tables,
-            use_cuda_graph=False,
             num_prefills=real_num_seqs,
             num_prefill_tokens=sum_query_len,
             num_decode_tokens=0,
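Reviewer note: the prompt path above pads the batch's longest prompt with `find_bucket`. A quick numeric sketch; `next_pow2` is copied from this file, while the `round_up`/`find_bucket` bodies are reproduced as I understand the surrounding code (only their tails are visible in this diff), so treat the exact internals as an assumption:

```python
# Sketch of prompt-length bucketing with assumed helper internals.
def next_pow2(value, base):
    res = base
    while value > 1:
        value = (value + 1) // 2
        res *= 2
    return res

def round_up(value, k):
    return (value + k - 1) // k * k

def find_bucket(value, config):
    bmin, bstep, _ = config
    next_step = round_up(value, bstep)
    next_pow = next_pow2(value, bmin)
    return max(bmin, min(next_step, next_pow))

# For the prompt sequence dimension min == step == block_size, so this is
# effectively "round up to a multiple of block_size, floor at block_size".
prompt_seq_cfg = (128, 128, 1024)
print(find_bucket(257, prompt_seq_cfg))  # 384
print(find_bucket(100, prompt_seq_cfg))  # 128
```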
@@ -950,32 +946,50 @@ def _prepare_decode(
             s if s != _PAD_SLOT_ID else next(dummy_slots) for s in sl
         ] for sl in slot_mapping]
 
+        num_decode_tokens = sum(seq_lens)
+
+        blocks_used = [len(bt) for bt in block_tables]
+        block_list = list(itertools.chain(*block_tables))
+        block_mapping_nested: List[List[int]] = [
+            [i] * b_u for i, b_u in enumerate(blocks_used)
+        ]
+        block_mapping: List[int] = list(
+            itertools.chain.from_iterable(block_mapping_nested))
+
+        last_block = [
+            sl % self.block_size + 1 for sl in itertools.chain(*slot_mapping)
+        ]
+        block_usage = [[self.block_size] * (b_u - 1) + [lb]
+                       for b_u, lb in zip(blocks_used, last_block)]
+        block_usage = list(itertools.chain(*block_usage))
+
+        block_bucket_size = find_bucket(len(block_list),
+                                        self.decode_block_bucket_cfg)
+        block_list = pad_list(block_list, block_bucket_size, _PAD_SLOT_ID)
+        block_mapping = pad_list(block_mapping, block_bucket_size, 0)
+        block_usage = pad_list(block_usage, block_bucket_size, 0)
+
+        block_list = torch.tensor(block_list,
+                                  dtype=torch.int,
+                                  device=self.device)
+        block_mapping = torch.tensor(block_mapping,
+                                     dtype=torch.int,
+                                     device=self.device)
+        block_usage = torch.tensor(block_usage,
+                                   dtype=torch.bfloat16,
+                                   device=self.device)
+
         slot_mapping = torch.tensor(slot_mapping,
                                     dtype=torch.long,
                                     device=self.device)
-        seq_lens_tensor = torch.tensor(seq_lens,
-                                       dtype=torch.int,
-                                       device=self.device)
-        num_decode_tokens = sum(seq_lens)
-        max_block_table_len = max(
-            len(block_table) for block_table in block_tables)
-        block_tables = make_tensor_with_pad(
-            block_tables,
-            max_len=max_block_table_len,
-            pad=0,
-            dtype=torch.int,
-            device=self.device,
-        )
+
         attn_metadata = self.attn_backend.make_metadata(
             is_prompt=False,
-            seq_lens=None,
-            seq_lens_tensor=seq_lens_tensor,
-            max_query_len=None,
-            subquery_start_loc=None,
-            seq_start_loc=None,
-            context_lens_tensor=None,
-            block_tables=block_tables,
-            use_cuda_graph=False,
+            block_list=block_list,
+            block_mapping=block_mapping,
+            block_usage=block_usage,
+            attn_bias=None,
+            seq_lens_tensor=None,
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=num_decode_tokens,
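Reviewer note: continuing the earlier toy example (block_size 4, three live blocks), the flat block tensors are padded up to the nearest decode block bucket before being handed to the model. The bucket config below is deliberately tiny so the padding is visible (the real default step is 128), and the `round_up`/`pad_list` helpers are re-stated here only as a sketch:

```python
# Toy padding sketch mirroring the pad_list call sites above.
def round_up(n, k):
    return (n + k - 1) // k * k

def pad_list(lst, k, v):
    target_len = round_up(len(lst), k)
    return lst + [v] * (target_len - len(lst))

block_list = [7, 2, 5]
block_mapping = [0, 0, 1]
block_usage = [4, 2, 3]

bucket = 4                 # e.g. what find_bucket(3, <tiny config>) might return
_PAD_SLOT_ID = 0
print(pad_list(block_list, bucket, _PAD_SLOT_ID))  # [7, 2, 5, 0]
print(pad_list(block_mapping, bucket, 0))          # [0, 0, 1, 0]
print(pad_list(block_usage, bucket, 0))            # [4, 2, 3, 0]
# Padded blocks get block_usage 0, so their bias masks every slot and they
# contribute (almost) nothing to the sequence they are mapped to.
```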
@@ -1163,7 +1177,7 @@ def _seq_len(self, attn_metadata):
         if attn_metadata.num_prefills != 0:
             return attn_metadata.slot_mapping.size(1)
         else:
-            return attn_metadata.block_tables.size(1) * self.block_size
+            return attn_metadata.block_list.numel()
 
     def trim_attn_metadata(self, metadata: AttentionMetadata) -> object:
         # NOTE(kzawora): To anyone working on this in the future:
@@ -1187,8 +1201,8 @@ def trim_attn_metadata(self, metadata: AttentionMetadata) -> object:
         # input_hash(123) != input_hash(321)
         # input_hash("abc") != input_hash("cba")
         attention_metadata = subtuple(metadata, 'TrimmedAttentionMetadata', [
-            'block_tables', 'seq_lens_tensor', 'attn_bias', 'slot_mapping',
-            'is_prompt'
+            'attn_bias', 'seq_lens_tensor', 'block_list', 'block_mapping',
+            'block_usage', 'slot_mapping', 'is_prompt'
         ])
         return attention_metadata
 
@@ -1222,9 +1236,8 @@ def profile_run(self) -> None:
         num_layers = self.model_config.get_num_layers(self.parallel_config)
         kv_caches = [None] * num_layers
         max_batch_size = self.prompt_bs_bucket_cfg[-1]
-        max_seq_len = self.prompt_seq_bucket_cfg[-1]
-        if self.lora_config:
-            max_seq_len = self.max_num_batched_tokens // max_batch_size
+        max_seq_len = min(self.prompt_seq_bucket_cfg[-1],
+                          self.max_num_batched_tokens // max_batch_size)
 
         self.warmup_scenario(max_batch_size,
                              max_seq_len,
@@ -1277,21 +1290,34 @@ def warmup_scenario(self,
             [0] * batch_size * seq_len,
         )
         self.set_active_loras(set(), lora_mapping)
-        seqs = [
-            self.create_dummy_seq_group_metadata(
-                i,
-                seq_len,
-                is_prompt,
-                lora_request=dummy_lora_requests_per_seq[i]
-                if dummy_lora_requests_per_seq else None)
-            for i in range(batch_size)
-        ]
+        if is_prompt:
+            seqs = [
+                self.create_dummy_seq_group_metadata(
+                    i,
+                    seq_len,
+                    is_prompt,
+                    lora_request=dummy_lora_requests_per_seq[i]
+                    if dummy_lora_requests_per_seq else None)
+                for i in range(batch_size)
+            ]
+        else:
+            # FIXME: seq_len is actually number of blocks
+            blocks = [seq_len // batch_size for _ in range(batch_size)]
+            blocks[0] += seq_len % batch_size
+            seqs = [
+                self.create_dummy_seq_group_metadata(
+                    i,
+                    b * self.block_size - 1,
+                    is_prompt,
+                    lora_request=dummy_lora_requests_per_seq[i]
+                    if dummy_lora_requests_per_seq else None)
+                for i, b in enumerate(blocks)
+            ]
         torch.hpu.synchronize()
         for _ in range(times):
             inputs = self.prepare_model_input(seqs)
-            self.execute_model(inputs, kv_caches, warmup_mode=True)
+            self.execute_model(inputs, kv_caches, warmup_mode=False)
             torch.hpu.synchronize()
-        self.profiler.end()
         gc.collect()
 
     def remove_all_loras(self):
@@ -1328,9 +1354,12 @@ def list_loras(self) -> Set[int]:
     def log_warmup(self, phase, i, max_i, batch_size, seq_len):
         free_mem = format_bytes(
             HabanaMemoryProfiler.current_free_device_memory())
+        dim = "num_blocks"
+        if phase == "Prompt":
+            dim = "seq_len"
         msg = (f"[Warmup][{phase}][{i+1}/{max_i}] "
                f"batch_size:{batch_size} "
-               f"seq_len:{seq_len} "
+               f"{dim}:{seq_len} "
                f"free_mem:{free_mem}")
         logger.info(msg)
 
@@ -1390,6 +1419,8 @@ def log_graph_warmup_summary(self, buckets, is_prompt, total_mem):
         phase = f'Graph/{"Prompt" if is_prompt else "Decode"}'
         graphed = list(c[:2] for c in self.graphed_buckets
                        if c[2] == is_prompt)
+        if num_candidates == 0:
+            num_candidates = 1
         msg = (f'{phase} captured:{len(graphed)} '
               f'({100 * len(graphed) / num_candidates:.1f}%) '
               f'used_mem:{format_bytes(total_mem)} '
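Reviewer note: in the decode warmup branch above, `seq_len` actually carries a total block count (see the FIXME). A quick sketch of how those blocks are distributed across the dummy batch, with made-up numbers:

```python
# Toy numbers: warm up a decode bucket of batch_size=4 with 10 total blocks.
batch_size, total_blocks, block_size = 4, 10, 128

blocks = [total_blocks // batch_size for _ in range(batch_size)]
blocks[0] += total_blocks % batch_size
print(blocks)                                 # [4, 2, 2, 2] -> sums to 10

# Each dummy sequence is then created with b * block_size - 1 tokens so it
# fills its blocks without spilling into an extra one.
print([b * block_size - 1 for b in blocks])   # [511, 255, 255, 255]
```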
f"(max_num_batched_tokens={self.max_num_batched_tokens})") + logger.info(msg) + + msg = f"Omitted prompt buckets: {list(sorted(prompt_omitted_buckets))}" + logger.debug(msg) + + self.decode_buckets = generate_decode_buckets( + self.decode_bs_bucket_cfg, self.decode_block_bucket_cfg, + max_blocks) + if self.lora_config: + self.decode_buckets[:] = [ + bucket for bucket in self.decode_buckets + if self._is_valid_bucket(bucket) + ] + logger.info("Generated %d decode buckets [bs, total_blocks]: %s", + len(self.decode_buckets), + list(sorted(self.decode_buckets))) + start_mem = HabanaMemoryProfiler.current_device_memory_usage() start_time = time.perf_counter()