Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable cache ops for beam search #3

Merged
merged 1 commit into from
Feb 20, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions vllm/hpu/cache_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,30 @@ def reshape_and_cache(key, value, key_cache, value_cache, slot_mapping, is_promp
index.add_(1)
key_cache = key_cache.permute(0, 2, 3, 1)
value_cache = value_cache.permute(0, 2, 3, 1)


def swap_blocks(src, dst, block_mapping):
index_src = torch.zeros((1,), dtype=torch.int32, device=key_caches[0].device)
index_dst = torch.zeros((1,), dtype=torch.int32, device=key_caches[0].device)
for src_idx, dst_idx in block_mapping.items():
index_src[0] = src_idx
index_dst[0] = dst_idx
dst.index_put_([index_dst], src.index_select(0, index_src))
if dst.device.type == 'hpu':
htorch.core.mark_step()
torch.hpu.synchronize()


def copy_blocks(key_caches, value_caches, block_mapping):
index_src = torch.zeros((1,), dtype=torch.int32, device=key_caches[0].device)
index_dst = torch.zeros((1,), dtype=torch.int32, device=key_caches[0].device)
for src, dsts in block_mapping.items():
index_src[0] = src
for dst in dsts:
index_dst[0] = dst
for key_cache in key_caches:
key_cache.index_copy_(0, index_dst, key_cache.index_select(0, index_src))
for value_cache in value_caches:
value_cache.index_copy_(0, index_dst, value_cache.index_select(0, index_src))
if key_caches[0].device.type == 'hpu':
htorch.core.mark_step()