Fix dispersed slots
madamczykhabana committed Sep 10, 2024
1 parent 5cf8441 · commit f392046
Showing 1 changed file with 8 additions and 8 deletions.
vllm/worker/habana_model_runner.py (16 changes: 8 additions & 8 deletions)
@@ -889,6 +889,9 @@ def _prepare_decode(
                 self.lora_config.max_lora_rank,
                 dtype=self.lora_config.lora_dtype)
 
+        dummy_slots = itertools.cycle(
+            range(_PAD_SLOT_ID, _PAD_SLOT_ID + self.block_size))
+
         for seq_group_metadata in seq_group_metadata_list:
             assert not seq_group_metadata.is_prompt
             assert seq_group_metadata.token_chunk_size == 1
@@ -918,8 +921,11 @@ def _prepare_decode(

             block_table = seq_group_metadata.block_tables[seq_id]
             block_number = block_table[position // self.block_size]
-            block_offset = position % self.block_size
-            slot = block_number * self.block_size + block_offset
+            if block_number == _PAD_BLOCK_ID:
+                slot = next(dummy_slots)
+            else:
+                block_offset = position % self.block_size
+                slot = block_number * self.block_size + block_offset
             slot_mapping.append([slot])
             lora_index_mapping.append(lora_id)
             lora_prompt_mapping.append(lora_id)
@@ -940,12 +946,6 @@ def _prepare_decode(
                                     dtype=torch.long,
                                     device=self.device)
 
-        dummy_slots = itertools.cycle(
-            range(_PAD_SLOT_ID, _PAD_SLOT_ID + self.block_size))
-        slot_mapping = [[
-            s if s != _PAD_SLOT_ID else next(dummy_slots) for s in sl
-        ] for sl in slot_mapping]
-
         num_decode_tokens = sum(seq_lens)
 
         blocks_used = [len(bt) for bt in block_tables]
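In short, the commit moves the dummy-slot handling inline: instead of first writing _PAD_SLOT_ID into every padded entry and then rewriting those entries in a second pass over slot_mapping, each padded entry (one whose block is _PAD_BLOCK_ID) now draws its slot directly from a cycle over block_size consecutive dummy slots. The result is the same dispersion of pad writes across distinct slots, computed in one pass instead of two. Below is a minimal standalone sketch of that mapping; the constant values and the slot_for helper are hypothetical placeholders for illustration, not the real vLLM definitions.

import itertools

# Hypothetical stand-ins for the module-level constants in
# habana_model_runner.py; the real values differ.
_PAD_BLOCK_ID = -1
_PAD_SLOT_ID = 1_000_000
block_size = 4

# Cycle over block_size consecutive dummy slot ids, as in the patch.
dummy_slots = itertools.cycle(
    range(_PAD_SLOT_ID, _PAD_SLOT_ID + block_size))

def slot_for(block_number: int, position: int) -> int:
    # Padded blocks take the next dummy slot from the cycle; real
    # blocks compute their true slot from block number and offset.
    if block_number == _PAD_BLOCK_ID:
        return next(dummy_slots)
    return block_number * block_size + position % block_size

# Six padded positions land on four distinct dummy slots, then wrap:
print([slot_for(_PAD_BLOCK_ID, p) for p in range(6)])
# [1000000, 1000001, 1000002, 1000003, 1000000, 1000001]

Dispersing the pad writes this way means padded tokens no longer all target a single slot, which is presumably what the commit title "Fix dispersed slots" refers to.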
