From 03053b122fed1c60cb47cde4284d6ae9287bf740 Mon Sep 17 00:00:00 2001 From: Mikhail Dvoretckii Date: Thu, 8 Feb 2024 18:32:21 +0200 Subject: [PATCH] Enable cache ops for beam search --- vllm/hpu/cache_ops.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/vllm/hpu/cache_ops.py b/vllm/hpu/cache_ops.py index fb08e4167a10a..b734ff62cbd21 100644 --- a/vllm/hpu/cache_ops.py +++ b/vllm/hpu/cache_ops.py @@ -39,3 +39,30 @@ def reshape_and_cache(key, value, key_cache, value_cache, slot_mapping, is_promp index.add_(1) key_cache = key_cache.permute(0, 2, 3, 1) value_cache = value_cache.permute(0, 2, 3, 1) + + +def swap_blocks(src, dst, block_mapping): + index_src = torch.zeros((1,), dtype=torch.int32, device=key_caches[0].device) + index_dst = torch.zeros((1,), dtype=torch.int32, device=key_caches[0].device) + for src_idx, dst_idx in block_mapping.items(): + index_src[0] = src_idx + index_dst[0] = dst_idx + dst.index_put_([index_dst], src.index_select(0, index_src)) + if dst.device.type == 'hpu': + htorch.core.mark_step() + torch.hpu.synchronize() + + +def copy_blocks(key_caches, value_caches, block_mapping): + index_src = torch.zeros((1,), dtype=torch.int32, device=key_caches[0].device) + index_dst = torch.zeros((1,), dtype=torch.int32, device=key_caches[0].device) + for src, dsts in block_mapping.items(): + index_src[0] = src + for dst in dsts: + index_dst[0] = dst + for key_cache in key_caches: + key_cache.index_copy_(0, index_dst, key_cache.index_select(0, index_src)) + for value_cache in value_caches: + value_cache.index_copy_(0, index_dst, value_cache.index_select(0, index_src)) + if key_caches[0].device.type == 'hpu': + htorch.core.mark_step()