Dispersed dummy slots (HabanaAI#243)
Use all possible slot values for dummy blocks to avoid caching issues.
madamczykhabana authored and zhouyu5 committed Sep 20, 2024
1 parent a93d597 commit eeb764c
Showing 1 changed file with 12 additions and 1 deletion.
vllm/worker/habana_model_runner.py
@@ -48,7 +48,11 @@
 
 logger = init_logger(__name__)
 
+# These values are assumed to be zero in several places.
+# Use caution when updating them!
 _PAD_SLOT_ID = 0
+_PAD_BLOCK_ID = 0
+
 LORA_WARMUP_RANK = 8
 _TYPE_CACHE = {}
 
@@ -937,6 +941,13 @@ def _prepare_decode(
         input_positions = torch.tensor(input_positions,
                                        dtype=torch.long,
                                        device=self.device)
+
+        dummy_slots = itertools.cycle(
+            range(_PAD_SLOT_ID, _PAD_SLOT_ID + self.block_size))
+        slot_mapping = [[
+            s if s != _PAD_SLOT_ID else next(dummy_slots) for s in sl
+        ] for sl in slot_mapping]
+
         slot_mapping = torch.tensor(slot_mapping,
                                     dtype=torch.long,
                                     device=self.device)
@@ -1193,7 +1204,7 @@ def create_dummy_seq_group_metadata(self,
         else:
             input_len = seq_len - 1
             output_len = 1
-        block_tables = {group_id: [0] * num_blocks}
+        block_tables = {group_id: [_PAD_BLOCK_ID] * num_blocks}
         prompt_token_ids = [0] * input_len
         output_token_ids = [1] * output_len
         seq_data = SequenceData(prompt_token_ids)
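
Below is a minimal, runnable sketch of the dummy-slot dispersal added in _prepare_decode above. It is illustrative only: block_size = 4 and the sample slot_mapping are assumed values for the example, not taken from this repository.

import itertools

_PAD_SLOT_ID = 0
block_size = 4

# Two padded decode sequences; _PAD_SLOT_ID marks the padded slots.
slot_mapping = [[17, _PAD_SLOT_ID, _PAD_SLOT_ID],
                [42, _PAD_SLOT_ID, _PAD_SLOT_ID]]

# Cycle through every slot index of the pad block instead of reusing slot 0.
dummy_slots = itertools.cycle(
    range(_PAD_SLOT_ID, _PAD_SLOT_ID + block_size))
slot_mapping = [[
    s if s != _PAD_SLOT_ID else next(dummy_slots) for s in sl
] for sl in slot_mapping]

print(slot_mapping)  # [[17, 0, 1], [42, 2, 3]]

Rather than every padded entry pointing at slot 0, the pads are now spread across all block_size slots of the pad block. That is what "dispersed" refers to, and, per the commit message, it is what avoids the caching issues, presumably caused by every dummy write landing in the same slot.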