Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Port flat PA from habana_next to habana_main #169

Merged
merged 15 commits into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions README_GAUDI.md
Original file line number Diff line number Diff line change
Expand Up @@ -455,33 +455,33 @@ Environment variables
- `VLLM_{phase}_{dim}_BUCKET_{param}` - collection of 12 environment
variables configuring ranges of bucketing mechanism
- `{phase}` is either `PROMPT` or `DECODE`
- `{dim}` is either `BS` or `SEQ`
- `{dim}` is either `BS`, `SEQ` or `BLOCK`
- `{param}` is either `MIN`, `STEP` or `MAX`
- Default values:
- Prompt:
- batch size min (`VLLM_PROMPT_BS_BUCKET_MIN`): `1`
- batch size step (`VLLM_PROMPT_BS_BUCKET_STEP`): `32`
- batch size step (`VLLM_PROMPT_BS_BUCKET_STEP`): `min(max_num_seqs, 32)`
- batch size max (`VLLM_PROMPT_BS_BUCKET_MAX`):
`min(max_num_seqs, 64)`
- sequence length min (`VLLM_PROMPT_SEQ_BUCKET_MIN`):
`block_size`
- sequence length step
(`VLLM_PROMPT_SEQ_BUCKET_STEP`): `block_size`
- sequence length max (`VLLM_PROMPT_SEQ_BUCKET_MAX`):
`1024`
`max_model_len`

- Decode:
- batch size min (`VLLM_DECODE_BS_BUCKET_MIN`): `1`
- batch size min (`VLLM_DECODE_BS_BUCKET_MIN`): `min(max_num_seqs, 32)`
- batch size step (`VLLM_DECODE_BS_BUCKET_STEP`):
`128`
`min(max_num_seqs, 32)`
- batch size max (`VLLM_DECODE_BS_BUCKET_MAX`):
`max_num_seqs`
- sequence length min (`VLLM_DECODE_SEQ_BUCKET_MIN`):
`block_size`
- sequence length step
(`VLLM_DECODE_SEQ_BUCKET_STEP`): `block_size`
- sequence length max (`VLLM_DECODE_SEQ_BUCKET_MAX`):
`2048`
- block size min (`VLLM_DECODE_BLOCK_BUCKET_MIN`):
`128`
- block size step
(`VLLM_DECODE_BLOCK_BUCKET_STEP`): `128`
- block size max (`VLLM_DECODE_BLOCK_BUCKET_MAX`):
`max(128, (max_num_seqs*max_model_len)/block_size)`

Additionally, there are HPU PyTorch Bridge environment variables
impacting vLLM execution:
Expand Down
14 changes: 7 additions & 7 deletions docs/source/getting_started/gaudi-installation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -335,19 +335,19 @@ Environment variables

- Prompt:
- batch size min (``VLLM_PROMPT_BS_BUCKET_MIN``): ``1``
- batch size step (``VLLM_PROMPT_BS_BUCKET_STEP``): ``32``
- batch size step (``VLLM_PROMPT_BS_BUCKET_STEP``): ``min(max_num_seqs, 32)``
- batch size max (``VLLM_PROMPT_BS_BUCKET_MAX``): ``min(max_num_seqs, 64)``
- sequence length min (``VLLM_PROMPT_SEQ_BUCKET_MIN``): ``block_size``
- sequence length step (``VLLM_PROMPT_SEQ_BUCKET_STEP``): ``block_size``
- sequence length max (``VLLM_PROMPT_SEQ_BUCKET_MAX``): ``1024``
- sequence length max (``VLLM_PROMPT_SEQ_BUCKET_MAX``): ``max_model_len``

- Decode:
- batch size min (``VLLM_DECODE_BS_BUCKET_MIN``): ``1``
- batch size step (``VLLM_DECODE_BS_BUCKET_STEP``): ``128``
- batch size min (``VLLM_DECODE_BS_BUCKET_MIN``): ``min(max_num_seqs, 32)``
- batch size step (``VLLM_DECODE_BS_BUCKET_STEP``): ``min(max_num_seqs, 32)``
- batch size max (``VLLM_DECODE_BS_BUCKET_MAX``): ``max_num_seqs``
- sequence length min (``VLLM_DECODE_SEQ_BUCKET_MIN``): ``block_size``
- sequence length step (``VLLM_DECODE_SEQ_BUCKET_STEP``): ``block_size``
- sequence length max (``VLLM_DECODE_SEQ_BUCKET_MAX``): ``2048``
- sequence length min (``VLLM_DECODE_SEQ_BUCKET_MIN``): ``128``
- sequence length step (``VLLM_DECODE_SEQ_BUCKET_STEP``): ``128``
- sequence length max (``VLLM_DECODE_SEQ_BUCKET_MAX``): ``max(128, (max_num_seqs*max_model_len)/block_size)``


Additionally, there are HPU PyTorch Bridge environment variables impacting vLLM execution:
Expand Down
136 changes: 40 additions & 96 deletions vllm/attention/backends/habana_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,58 +58,14 @@ def copy_blocks(


@dataclass
class HabanaAttentionMetadata(AttentionMetadata, HabanaPagedAttentionMetadata):
"""Metadata for HabanaAttentionbackend.

NOTE: Any python object stored here is not updated when it is
cuda-graph replayed. If you have values that need to be changed
dynamically, it should be stored in tensor. The tensor has to be
updated from `CUDAGraphRunner.forward` API.
"""
class HabanaAttentionMetadata(HabanaPagedAttentionMetadata, AttentionMetadata):
"""Metadata for HabanaAttentionbackend."""
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
is_prompt: bool
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens None if it is a decoding.
seq_lens: Optional[List[int]]
# seq_lens stored as a tensor.
attn_bias: Optional[torch.Tensor]
seq_lens_tensor: Optional[torch.Tensor]

# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ----------------------|
# |-- query_len ---|

# Maximum query length in the batch.
max_query_len: Optional[int]
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
subquery_start_loc: Optional[torch.Tensor]
# FIXME: It is for flash attn.
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
seq_start_loc: Optional[torch.Tensor]
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
context_lens_tensor: Optional[torch.Tensor]

# Whether or not if cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph: bool

def __post_init__(self):
# Set during the execution of the first attention op.
# It is a list because it is needed to set per prompt
# when alibi slopes is used. It is because of the limitation
# from xformer API.
# will not appear in the __repr__ and __init__
self.attn_bias: Optional[torch.Tensor] = None


class HabanaAttentionImpl(AttentionImpl, torch.nn.Module):
"""
Expand Down Expand Up @@ -229,60 +185,48 @@ def forward(

if attn_metadata.is_prompt:
# Prompt run.
if kv_cache is None or attn_metadata.block_tables.numel() == 0:
if not self.prefill_usefusedsdpa:
# TODO: move this outside of model
assert attn_metadata.attn_bias is not None, \
if not self.prefill_usefusedsdpa:
# TODO: move this outside of model
assert attn_metadata.attn_bias is not None, \
'attn_bias must be set before calling model.forward!'
attn_bias = attn_metadata.attn_bias
if self.alibi_slopes is not None and \
self.position_bias is not None:
attn_bias.add_(self.position_bias[:, :,
-attn_bias.size(2):,
-attn_bias.size(3):])
else:
attn_bias = None

query_shape = (batch_size, seq_len, self.num_heads,
self.head_size)
kv_shape = (batch_size, seq_len_kv, self.num_kv_heads,
self.head_size)
out = ops.prompt_attention(
query.view(query_shape),
key.view(kv_shape),
value.view(kv_shape),
attn_bias=attn_bias,
p=0.0,
scale=self.scale,
matmul_qk_op=self.matmul_qk,
softmax_op=self.softmax,
matmul_av_op=self.matmul_av,
valid_seq_lengths=attn_metadata.seq_lens_tensor,
)
output = out.reshape(batch_size, seq_len, hidden_size)
attn_bias = attn_metadata.attn_bias
if self.alibi_slopes is not None and \
self.position_bias is not None:
attn_bias.add_(self.position_bias[:, :,
-attn_bias.size(2):,
-attn_bias.size(3):])
else:
# prefix-enabled attention
output = HabanaPagedAttention.forward_prefix(
query,
key,
value,
key_cache,
value_cache,
attn_metadata.block_tables,
attn_metadata.subquery_start_loc,
attn_metadata.seq_lens_tensor,
attn_metadata.context_lens_tensor,
attn_metadata.max_query_len,
self.alibi_slopes,
)
attn_bias = None

query_shape = (batch_size, seq_len, self.num_heads, self.head_size)
kv_shape = (batch_size, seq_len_kv, self.num_kv_heads,
self.head_size)
out = ops.prompt_attention(
query.view(query_shape),
key.view(kv_shape),
value.view(kv_shape),
attn_bias=attn_bias,
p=0.0,
scale=self.scale,
matmul_qk_op=self.matmul_qk,
softmax_op=self.softmax,
matmul_av_op=self.matmul_av,
)
output = out.reshape(batch_size, seq_len, hidden_size)
else:
# Decoding run.
output = HabanaPagedAttention.forward_decode(
query, key_cache, value_cache, attn_metadata.block_tables,
attn_metadata.seq_lens_tensor, self.kv_cache_dtype,
self.num_kv_heads, self.scale, self.position_bias, k_scale,
v_scale, self.matmul_qk, self.softmax, self.matmul_av,
self.k_cache, self.v_cache)
query=query,
key_cache=key_cache,
value_cache=value_cache,
block_list=attn_metadata.block_list,
block_mapping=attn_metadata.block_mapping,
block_bias=attn_metadata.attn_bias,
scale=self.scale,
matmul_qk_op=self.matmul_qk,
matmul_av_op=self.matmul_av,
keys_fetch_func=self.k_cache.fetch_from_cache,
values_fetch_func=self.v_cache.fetch_from_cache)
# Reshape the output tensor.
return output.view(batch_size, seq_len, hidden_size)

Expand Down
51 changes: 5 additions & 46 deletions vllm/attention/ops/habana_paged_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,9 @@
@dataclass
class HabanaPagedAttentionMetadata:
"""Metadata for PagedAttention."""
# (batch_size,). The length of sequences (entire tokens seen so far) per
# sequence.
seq_lens_tensor: Optional[torch.Tensor]
# (batch_size, max_blocks_per_seq).
# Block addresses per sequence. (Seq id -> list of physical block)
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
# in the kv cache. Each block can contain up to block_size tokens.
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
# captured.
block_tables: Optional[torch.Tensor]
block_list: Optional[torch.Tensor]
block_mapping: Optional[torch.Tensor]
block_usage: Optional[torch.Tensor]


class HabanaPagedAttention:
Expand Down Expand Up @@ -63,42 +56,8 @@ def write_to_paged_cache(key: torch.Tensor, value: torch.Tensor,
slot_mapping, kv_cache_dtype, is_prompt)

@staticmethod
def forward_decode(
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
block_tables: torch.Tensor,
seq_lens: torch.Tensor,
kv_cache_dtype: str,
num_kv_heads: int,
scale: float,
alibi_slopes: Optional[torch.Tensor],
k_scale: float,
v_scale: float,
matmul_qk_op,
softmax_op,
matmul_av_op,
k_cache_cls,
v_cache_cls,
) -> torch.Tensor:
block_size = value_cache.shape[1]
return ops.paged_attention_v1(
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
alibi_slopes,
kv_cache_dtype,
matmul_qk_op,
softmax_op,
matmul_av_op,
k_cache_cls,
v_cache_cls,
)
def forward_decode(**kwargs) -> torch.Tensor:
return ops.flat_pa(**kwargs)

@staticmethod
def forward_prefix(
Expand Down
Loading
Loading