Fix dispersed slots (HabanaAI#261)
On habana_main the slots are calculated by adding an offset to the block number, which breaks the check against _PAD_SLOT_ID. Reworked it so that in the case of _PAD_BLOCK_ID we automatically insert the right dummy slot value.
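For illustration only, below is a minimal, self-contained sketch of the slot-mapping logic this commit introduces (see the diff further down). The constant values and the `slot_for` helper are placeholders made up for the example, not the names or values used in vLLM; only the control flow mirrors the change.

```python
import itertools

# Illustrative placeholder values -- not the constants used in vLLM.
_PAD_BLOCK_ID = 0
_PAD_SLOT_ID = 0
BLOCK_SIZE = 4

# Cycle through a small range of dummy slots so padded positions are spread
# across the padding block instead of all mapping to a single slot.
dummy_slots = itertools.cycle(
    range(_PAD_SLOT_ID, _PAD_SLOT_ID + BLOCK_SIZE))


def slot_for(block_number: int, position: int) -> int:
    """Hypothetical helper: compute the slot for one decode token."""
    if block_number == _PAD_BLOCK_ID:
        # Padding block: hand out a dummy slot directly, instead of computing
        # block_number * BLOCK_SIZE + offset and fixing it up afterwards.
        return next(dummy_slots)
    block_offset = position % BLOCK_SIZE
    return block_number * BLOCK_SIZE + block_offset


print(slot_for(block_number=3, position=5))              # real block: 3 * 4 + 1 = 13
print(slot_for(block_number=_PAD_BLOCK_ID, position=5))  # padded block: a dummy slot
```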
madamczykhabana authored and zhouyu5 committed Sep 13, 2024
1 parent 089a105 commit c0c204d
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions vllm/worker/habana_model_runner.py
@@ -887,6 +887,9 @@ def _prepare_decode(
                                     self.lora_config.max_lora_rank,
                                     dtype=self.lora_config.lora_dtype)
 
+        dummy_slots = itertools.cycle(
+            range(_PAD_SLOT_ID, _PAD_SLOT_ID + self.block_size))
+
         for seq_group_metadata in seq_group_metadata_list:
             assert not seq_group_metadata.is_prompt
             assert seq_group_metadata.token_chunk_size == 1
@@ -916,8 +919,11 @@ def _prepare_decode(
 
                 block_table = seq_group_metadata.block_tables[seq_id]
                 block_number = block_table[position // self.block_size]
-                block_offset = position % self.block_size
-                slot = block_number * self.block_size + block_offset
+                if block_number == _PAD_BLOCK_ID:
+                    slot = next(dummy_slots)
+                else:
+                    block_offset = position % self.block_size
+                    slot = block_number * self.block_size + block_offset
                 slot_mapping.append([slot])
                 lora_index_mapping.append(lora_id)
                 lora_prompt_mapping.append(lora_id)
@@ -938,12 +944,6 @@ def _prepare_decode(
                                            dtype=torch.long,
                                            device=self.device)
 
-        dummy_slots = itertools.cycle(
-            range(_PAD_SLOT_ID, _PAD_SLOT_ID + self.block_size))
-        slot_mapping = [[
-            s if s != _PAD_SLOT_ID else next(dummy_slots) for s in sl
-        ] for sl in slot_mapping]
-
         num_decode_tokens = sum(seq_lens)
 
         blocks_used = [len(bt) for bt in block_tables]