HPU: Change KV-cache layout (vllm-project#56)
* HPU: Change KV-cache layout to (num_blocks, block_size, num_heads, head_size)

* Fix UTs

* Fix UTs - part 2
madamczykhabana authored Jun 11, 2024
1 parent 45fb692 commit 2825dde
Showing 7 changed files with 45 additions and 76 deletions.
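For context: this commit moves the HPU KV cache from a (num_blocks, num_kv_heads, head_size, block_size) layout to (num_blocks, block_size, num_kv_heads, head_size), so a cached token is addressed by its block index and in-block offset in the two leading dimensions, and each access yields a (num_kv_heads, head_size) tile. A minimal sketch of the two indexing schemes (illustrative only, not code from this commit; sizes are made up):

import torch

num_blocks, block_size, num_kv_heads, head_size = 4, 16, 2, 8

# Old HPU layout: the token position is the last dimension.
old_key_cache = torch.zeros(num_blocks, num_kv_heads, head_size, block_size)
# New HPU layout: the token position follows the block index.
new_key_cache = torch.zeros(num_blocks, block_size, num_kv_heads, head_size)

# Reading one cached token (block 1, offset 3) as a (num_kv_heads, head_size) tile:
old_k = old_key_cache[1, :, :, 3]   # old indexing
new_k = new_key_cache[1, 3]         # new indexing, same (2, 8) result

assert old_k.shape == new_k.shape == (num_kv_heads, head_size)

The per-file diffs below apply this same re-indexing to the attention kernel, the cache-write path, the tests, and the test utilities.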
18 changes: 13 additions & 5 deletions tests/kernels/test_attention.py
@@ -74,9 +74,14 @@ def ref_single_query_cached_kv_attention(
     alibi_slopes: Optional[torch.Tensor],
 ) -> None:
     num_query_heads = query.shape[1]
-    num_kv_heads = value_cache.shape[1]
-    head_size = value_cache.shape[2]
-    block_size = value_cache.shape[3]
+    if not is_hpu():
+        num_kv_heads = value_cache.shape[1]
+        head_size = value_cache.shape[2]
+        block_size = value_cache.shape[3]
+    else:
+        block_size = value_cache.shape[1]
+        num_kv_heads = value_cache.shape[2]
+        head_size = value_cache.shape[3]
     num_seqs = query.shape[0]
 
     block_tables = block_tables.cpu().tolist()
@@ -93,13 +98,16 @@ def ref_single_query_cached_kv_attention(
             block_offset = j % block_size
 
             if is_hpu():
-                k = key_cache[block_number, :, :, block_offset]
+                k = key_cache[block_number, block_offset, :, :]
             else:
                 k = key_cache[block_number, :, :, block_offset, :]
                 k = k.reshape(num_kv_heads, head_size)
             keys.append(k)
 
-            v = value_cache[block_number, :, :, block_offset]
+            if is_hpu():
+                v = value_cache[block_number, block_offset, :, :]
+            else:
+                v = value_cache[block_number, :, :, block_offset]
             values.append(v)
         keys = torch.stack(keys, dim=0)
         values = torch.stack(values, dim=0)
20 changes: 8 additions & 12 deletions tests/kernels/test_cache.py
@@ -102,7 +102,6 @@ def test_copy_blocks(
     if is_hpu():
         tmp_block_mapping_dict = {}
         for src, dst in block_mapping:
-            print(src, dst, tmp_block_mapping_dict)
             if not tmp_block_mapping_dict.get(src):
                 tmp_block_mapping_dict[src] = [dst]
                 continue
@@ -191,17 +190,11 @@ def test_reshape_and_cache(
     kv_scale = 1.0
 
     # Call the reshape_and_cache kernel.
-    if is_hpu():
-        cache_ops.reshape_and_cache(key, value, key_cache, value_cache,
-                                    slot_mapping.view((1, -1)), "auto", False)
-    else:
-        cache_ops.reshape_and_cache(key, value, key_cache, value_cache,
-                                    slot_mapping, "auto")
+    cache_ops.reshape_and_cache(key, value, key_cache, value_cache,
+                                slot_mapping, "auto")
 
     # Run the reference implementation.
-    if is_hpu():
-        reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0].shape)
-    else:
+    if not is_hpu():
         reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
     block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
     block_indices = block_indices.cpu().tolist()
@@ -211,10 +204,13 @@
         block_idx = block_indices[i]
         block_offset = block_offsets[i]
         if is_hpu():
-            cloned_key_cache[block_idx, :, :, block_offset] = reshaped_key[i]
+            cloned_key_cache[block_idx, block_offset, :, :] = key[i]
         else:
             cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i]
-        cloned_value_cache[block_idx, :, :, block_offset] = value[i]
+        if is_hpu():
+            cloned_value_cache[block_idx, block_offset, :, :] = value[i]
+        else:
+            cloned_value_cache[block_idx, :, :, block_offset] = value[i]
 
     if kv_cache_dtype == "fp8":
         assert torch.allclose(result_key_cache,
4 changes: 2 additions & 2 deletions vllm/attention/ops/habana_paged_attn.py
@@ -41,7 +41,7 @@ def get_kv_cache_shape(
         num_kv_heads: int,
         head_size: int,
     ) -> Tuple[int, ...]:
-        return (num_blocks, num_kv_heads, head_size, block_size)
+        return (num_blocks, block_size, num_kv_heads, head_size)
 
     @staticmethod
     def split_kv_cache(
@@ -86,7 +86,7 @@ def forward_decode(
         alibi_slopes: Optional[torch.Tensor],
         kv_scale: float,
     ) -> torch.Tensor:
-        block_size = value_cache.shape[3]
+        block_size = value_cache.shape[1]
         return ops.paged_attention_v1(
             query,
             key_cache,
50 changes: 7 additions & 43 deletions vllm/hpu/cache_ops.py
@@ -10,49 +10,13 @@
 import habana_frameworks.torch as htorch
 
 
-def pad_to_full_block(data, block_size, pad_value):
-    seq_dim = 1
-    pad_shape = list(data.shape)
-    remainder = pad_shape[seq_dim] % block_size
-    if remainder == 0:
-        return data
-    pad_shape[seq_dim] = block_size - remainder
-    pad = torch.full(pad_shape, pad_value, dtype=data.dtype, device=data.device)
-    return torch.cat([data, pad], dim=seq_dim)
-
-
-def initialize_cache(data, indices, cache):
-    block_size = cache.size(-1)
-    data = data.unflatten(0, (-1, block_size)).permute(0, 2, 3, 1)
-    indices = indices.unflatten(0, (-1, block_size))[:,0]
-    cache.index_copy_(0, indices, data)
-
-
-def update_cache(data, indices, offsets, cache):
-    prev = cache.index_select(0, indices)
-    idx = offsets.view(-1, 1, 1, 1).expand(-1, data.size(1), data.size(2), -1)
-    prev.scatter_(-1, idx, data.unsqueeze(-1))
-    cache.index_copy_(0, indices, prev)
-
-
-def reshape_and_cache(key, value, key_cache, value_cache, slot_mapping, dtype, is_prompt):
-    block_size = key_cache.size(-1)
-    assert slot_mapping.dim() == 2, 'This implementation requires unflattened slot_mapping!'
-
-    if is_prompt:
-        block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
-        batch_size, seq_length = block_indices.shape
-        key = pad_to_full_block(key.unflatten(0, (batch_size, seq_length)), block_size, 0).flatten(0, 1)
-        value = pad_to_full_block(value.unflatten(0, (batch_size, seq_length)), block_size, 0).flatten(0, 1)
-        block_indices = pad_to_full_block(block_indices, block_size, -1).flatten(0, 1)
-        initialize_cache(key, block_indices, key_cache)
-        initialize_cache(value, block_indices, value_cache)
-    else:
-        slot_mapping = slot_mapping.flatten()
-        block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
-        block_offsets = torch.fmod(slot_mapping, block_size)
-        update_cache(key, block_indices, block_offsets, key_cache)
-        update_cache(value, block_indices, block_offsets, value_cache)
+def reshape_and_cache(key, value, key_cache, value_cache, slot_mapping, dtype, is_prompt=False):
+    block_size = key_cache.size(1)
+    slot_mapping = slot_mapping.flatten()
+    indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
+    offsets = torch.fmod(slot_mapping, block_size)
+    key_cache.index_put_((indices, offsets), key)
+    value_cache.index_put_((indices, offsets), value)
 
 
 def swap_blocks(src, dst, block_mapping):
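For context: with the new layout the write path no longer needs a prompt/decode split or the padding helpers, because a flat slot_mapping decomposes into block indices and in-block offsets, and one index_put_ writes every token. A minimal sketch of that write path (illustrative only, with made-up sizes):

import torch

num_blocks, block_size, num_kv_heads, head_size = 8, 4, 2, 16
num_tokens = 5

key = torch.randn(num_tokens, num_kv_heads, head_size)
key_cache = torch.zeros(num_blocks, block_size, num_kv_heads, head_size)
# Each token's flat slot id encodes block_index * block_size + offset.
slot_mapping = torch.tensor([0, 1, 2, 7, 9])

indices = torch.div(slot_mapping, block_size, rounding_mode="floor")  # block index
offsets = torch.fmod(slot_mapping, block_size)                        # offset inside the block
key_cache.index_put_((indices, offsets), key)

# Token 3 (slot 7) landed in block 1, offset 3.
assert torch.equal(key_cache[1, 3], key[3])

This is why the commit can drop pad_to_full_block, initialize_cache, and update_cache: both prompt and decode writes reduce to the same scatter.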
19 changes: 9 additions & 10 deletions vllm/hpu/ops.py
@@ -31,48 +31,47 @@ def gelu_fast(output, input):
     raise NotImplementedError
 
 
-def fetch_from_cache(cache, blocks):
-    return [cache.index_select(0, blocks[:, i]) for i in range(blocks.size(1))]
+def fetch_from_cache(cache, blocks, permutations):
+    return [cache.index_select(0, blocks[:, i]).permute(permutations) for i in range(blocks.size(1))]
 
 
 @hpu_utils.with_mark_steps
 def paged_attention_v1(query, key_cache, value_cache, head_mapping, scale, block_tables, context_lens, block_size, alibi_slopes, kv_cache_dtype=None) -> None:
     seq_len = block_tables.size(1)
     batch_size, query_heads, _ = query.shape
-    _, kv_heads, _, _ = key_cache.shape
+    _, _, kv_heads, _ = key_cache.shape
     min_inf = torch.finfo(query.dtype).min
     mask = (torch.arange(0, seq_len * block_size, dtype=torch.int32, device=key_cache.device)
             .view(1, -1)
             .expand(batch_size, -1)
             .ge(context_lens.view(-1, 1))
             .view(batch_size, 1, 1, -1))
-    query.mul_(scale)
     query = query.unsqueeze(-2)
-    keys = fetch_from_cache(key_cache, block_tables)
+    keys = fetch_from_cache(key_cache, block_tables, (0, 2, 3, 1))
     if query_heads != kv_heads:
         query = query.unflatten(1, (kv_heads, -1))
         keys = [k.unflatten(1, (kv_heads, 1)) for k in keys]
         mask = mask.unsqueeze(2)
 
     attn_weights = [torch.matmul(query, k) for k in keys]
     attn_weights = (torch.cat(attn_weights, dim=-1)
+                    .mul_(scale)
                     .masked_fill(mask, min_inf)
                     .softmax(dim=-1))
 
-    values = fetch_from_cache(value_cache, block_tables)
+    values = fetch_from_cache(value_cache, block_tables, (0, 2, 1, 3))
     if PA_SPLIT_VALUE:
         attn_weights = attn_weights.split(block_size, dim=-1)
     else:
-        values = [torch.cat(values, dim=-1)]
+        values = [torch.cat(values, dim=-2)]
         attn_weights = [attn_weights]
     if query_heads != kv_heads:
         values = [v.unflatten(1, (kv_heads, 1)) for v in values]
-    attn_weights = [torch.matmul(a, v.transpose(-1, -2)).squeeze(-2) for a, v in zip(attn_weights, values)]
+    attn_weights = [torch.matmul(a, v) for a, v in zip(attn_weights, values)]
     if query_heads != kv_heads:
         attn_weights = [a.flatten(1, 2) for a in attn_weights]
     attn_weights = sum(attn_weights)
-
-    return attn_weights
+    return attn_weights.squeeze(-2)
 
 
 def rms_norm(out, hidden_states, weight, eps):
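For context: fetch_from_cache now gathers blocks in the new layout and permutes them back to the orientation the matmuls expect — keys to (batch, kv_heads, head_size, block_size) via (0, 2, 3, 1), and values to (batch, kv_heads, block_size, head_size) via (0, 2, 1, 3). A standalone sketch of the shape bookkeeping (illustrative only; it assumes query heads equal kv heads and skips masking and scaling):

import torch

batch, seq_blocks, block_size, kv_heads, head_size = 2, 3, 4, 5, 8
# Caches in the new layout: (num_blocks, block_size, kv_heads, head_size).
key_cache = torch.randn(16, block_size, kv_heads, head_size)
value_cache = torch.randn(16, block_size, kv_heads, head_size)
block_tables = torch.randint(0, 16, (batch, seq_blocks))

def fetch_from_cache(cache, blocks, permutations):
    return [cache.index_select(0, blocks[:, i]).permute(permutations)
            for i in range(blocks.size(1))]

query = torch.randn(batch, kv_heads, 1, head_size)                  # query heads == kv heads here
keys = fetch_from_cache(key_cache, block_tables, (0, 2, 3, 1))      # each: (batch, kv_heads, head_size, block_size)
values = fetch_from_cache(value_cache, block_tables, (0, 2, 1, 3))  # each: (batch, kv_heads, block_size, head_size)

attn = torch.cat([torch.matmul(query, k) for k in keys], dim=-1).softmax(dim=-1)
out = sum(torch.matmul(a, v) for a, v in zip(attn.split(block_size, dim=-1), values))
print(out.squeeze(-2).shape)  # torch.Size([2, 5, 8])

The trailing squeeze(-2) plays the same role as in the kernel above: it drops the singleton query-position dimension added by unsqueeze(-2).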
7 changes: 5 additions & 2 deletions vllm/utils.py
@@ -410,7 +410,7 @@ def create_kv_caches_with_random(
 
     scale = head_size**-0.5
     if is_hpu():
-        key_cache_shape = (num_blocks, num_heads, head_size, block_size)
+        key_cache_shape = (num_blocks, block_size, num_heads, head_size)
     else:
         x = 16 // torch.tensor([], dtype=torch_dtype).element_size()
         key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
@@ -429,7 +429,10 @@
                 f"Does not support key cache of type {cache_dtype}")
         key_caches.append(key_cache)
 
-    value_cache_shape = (num_blocks, num_heads, head_size, block_size)
+    if is_hpu():
+        value_cache_shape = (num_blocks, block_size, num_heads, head_size)
+    else:
+        value_cache_shape = (num_blocks, num_heads, head_size, block_size)
     value_caches = []
     for _ in range(num_layers):
         value_cache = torch.empty(size=value_cache_shape,
3 changes: 1 addition & 2 deletions vllm/worker/habana_model_runner.py
@@ -976,8 +976,7 @@ def warmup_graphs(self, strategy, buckets, is_prompt, kv_caches, available_mem):
             total_batch_seq += batch_seq
         graphed = list(c[:2] for c in self.graphed_buckets if c[2] == is_prompt)
         logger.info(f'{phase} captured:{len(graphed)} ({100 * len(graphed) / num_candidates:.1f}%) used_mem:{format_bytes(total_mem)} buckets:{sorted(list(graphed))}')
-
-
+
     @torch.inference_mode()
     def warmup_model(self, kv_caches: List[torch.Tensor]) -> None:
         if os.environ.get('VLLM_SKIP_WARMUP', 'false').lower() == 'true':
