mlperf features gdn #867

Merged · 6 commits · Feb 25, 2025
1 change: 0 additions & 1 deletion requirements-hpu.txt
@@ -9,4 +9,3 @@ tabulate
setuptools>=61
setuptools-scm>=8
vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@bb47de4

35 changes: 34 additions & 1 deletion vllm/attention/ops/hpu_paged_attn.py
@@ -7,11 +7,42 @@

import torch
from vllm_hpu_extension import cache_ops, ops
import habana_frameworks.torch as htorch

# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE = 512


def _graphed(fn):
    class Graphed(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, *args, **kwargs):
            return fn(*args, **kwargs)
    graph = htorch.hpu.wrap_in_hpu_graph(Graphed(), disable_tensor_cache=True)

    def wrapper(*args, **kwargs):
        return graph.forward(*args, **kwargs)
    return wrapper


@_graphed
def _copy_blocks(key_caches, value_caches, block_mapping):
    if key_caches[0].device.type == 'hpu':
        htorch.core.mark_step()
    block_mapping = block_mapping.transpose(0, 1)
    src = block_mapping[0]
    dst = block_mapping[1]

    for key_cache, value_cache in zip(key_caches, value_caches):
        key_cache.index_copy_(0, dst, key_cache.index_select(0, src))
        value_cache.index_copy_(0, dst, value_cache.index_select(0, src))

    if key_caches[0].device.type == 'hpu':
        htorch.core.mark_step()


@dataclass
class HPUPagedAttentionMetadata:
"""Metadata for PagedAttention."""
Expand Down Expand Up @@ -85,6 +116,8 @@ def copy_blocks(
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
        src_to_dsts: torch.Tensor,
    ) -> None:
        if src_to_dsts.numel() == 0:
            return
        key_caches = [kv_cache[0] for kv_cache in kv_caches]
        value_caches = [kv_cache[1] for kv_cache in kv_caches]
        cache_ops.copy_blocks(key_caches, value_caches, src_to_dsts)
        _copy_blocks(key_caches, value_caches, src_to_dsts)
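
The new `_copy_blocks` helper replaces the extension's `cache_ops.copy_blocks` call with a gather/scatter over the block dimension of each cache, and `_graphed` wraps it once in an HPU graph via `htorch.hpu.wrap_in_hpu_graph` so the copy is captured and replayed rather than re-traced on every call. Below is a minimal CPU-only sketch of the copy semantics, with hypothetical cache shapes and a two-entry mapping; the HPU-graph wrapping and `mark_step` synchronization are omitted.

```python
import torch

# Hypothetical cache layout: one block per row of the leading dimension.
num_blocks, block_size, num_heads, head_size = 8, 16, 4, 32
key_cache = torch.randn(num_blocks, block_size, num_heads, head_size)
value_cache = torch.randn_like(key_cache)

# Each row of src_to_dsts is a (source_block, destination_block) pair.
src_to_dsts = torch.tensor([[0, 5], [2, 6]])

# Transposing yields one row of sources and one row of destinations;
# the source blocks are gathered and scattered into the destination slots.
src, dst = src_to_dsts.transpose(0, 1)
key_cache.index_copy_(0, dst, key_cache.index_select(0, src))
value_cache.index_copy_(0, dst, value_cache.index_select(0, src))

assert torch.equal(key_cache[5], key_cache[0])  # block 0 was copied into block 5
```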
21 changes: 21 additions & 0 deletions vllm/core/block/naive_block.py
@@ -108,6 +108,27 @@ def allocate_immutable_blocks(

        return blocks

    def _can_defragment_all(self, block_ids):
        block_ids = list(sorted(block_ids))
        num_blocks = len(block_ids)
        if len(self._free_block_indices) < num_blocks:
            return False
        free_block_ids = heapq.nsmallest(num_blocks, self._free_block_indices)
        return free_block_ids[-1] < block_ids[0]

    def _reassign_block_id(self, block):
        prev_block_id = block.block_id
        self._free_block_id(block)
        block.block_id = self._allocate_block_id()
        new_block_id = block.block_id
        return (prev_block_id, new_block_id)

    def defragment_all(self, blocks):
        if not self._can_defragment_all([b.block_id for b in blocks]):
            return []
        return [self._reassign_block_id(b) for b in blocks]


    def allocate_mutable_block(self,
                               prev_block: Optional[Block],
                               extra_hash: Optional[int] = None,
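
`defragment_all` only proceeds when relocating every candidate block is guaranteed to lower its id. A standalone sketch of that feasibility predicate, with hypothetical free-list contents (it mirrors `_can_defragment_all` above without touching allocator state):

```python
import heapq

def can_defragment_all(free_block_indices, block_ids):
    block_ids = sorted(block_ids)
    num_blocks = len(block_ids)
    if len(free_block_indices) < num_blocks:
        return False
    # Relocation only helps if the num_blocks lowest free slots all sit
    # below every candidate block, so each move strictly lowers a block id.
    free_block_ids = heapq.nsmallest(num_blocks, free_block_indices)
    return free_block_ids[-1] < block_ids[0]

print(can_defragment_all([1, 2, 3, 9], [10, 11, 12]))  # True
print(can_defragment_all([2, 11, 15], [10, 12]))       # False: free slot 11 is not below block 10
```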
29 changes: 29 additions & 0 deletions vllm/core/block_manager.py
@@ -2,6 +2,8 @@
from typing import Dict, List, Optional
from typing import Sequence as GenericSequence
from typing import Tuple
import heapq
import itertools

from vllm.core.block.block_table import BlockTable
from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator
@@ -105,6 +107,33 @@ def __init__(
        self._last_access_blocks_tracker = LastAccessBlocksTracker(
            self.block_allocator)

    def try_defragmenting(self, num_blocks):
        if len(self.block_tables) == 0:
            return []
        block_heap = []

        for seq_id, block_table in self.block_tables.items():
            for block in block_table.blocks:
                item = (block.block_id, block, seq_id)
                if len(block_heap) < num_blocks:
                    heapq.heappush(block_heap, item)
                else:
                    heapq.heappushpop(block_heap, item)
        if len(block_heap) < num_blocks:
            return []

        block_ids, blocks, block_seq_ids = zip(*block_heap)
        allocator = self.block_allocator._allocators[Device.GPU]
        if not allocator.defragment_all(blocks):
            return []

        new_block_ids = [b.block_id for b in blocks]
        for seq_id in set(block_seq_ids):
            bt = self.block_tables[seq_id]
            bt.update(bt.blocks)
        changed_block_ids = list(zip(block_ids, new_block_ids))
        return changed_block_ids

    def can_allocate(self,
                     seq_group: SequenceGroup,
                     num_lookahead_slots: int = 0) -> AllocStatus:
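
`try_defragmenting` picks which blocks to move with a bounded heap rather than sorting every allocated block. A hypothetical, self-contained sketch of that selection step (the real method also carries the owning block and sequence id in each heap item):

```python
import heapq

def select_highest_block_ids(block_ids, num_blocks):
    block_heap = []
    for block_id in block_ids:
        if len(block_heap) < num_blocks:
            heapq.heappush(block_heap, block_id)
        else:
            # Push the new id and pop the current minimum, so the heap
            # always holds the largest num_blocks ids seen so far.
            heapq.heappushpop(block_heap, block_id)
    return sorted(block_heap)

print(select_highest_block_ids([3, 41, 7, 56, 12, 40], 3))  # [40, 41, 56]
```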
8 changes: 6 additions & 2 deletions vllm/core/scheduler.py
@@ -27,7 +27,7 @@
    os.getenv("VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT", False)) # noqa
ARTIFICIAL_PREEMPTION_PROB = 0.5
ARTIFICIAL_PREEMPTION_MAX_CNT = 500

VLLM_DEFRAGMENT_BLOCK_IDS = int(os.getenv("VLLM_DEFRAGMENT_BLOCK_IDS", "0"))

class PreemptionMode(enum.Enum):
"""Preemption modes.
@@ -1397,8 +1397,8 @@ def schedule(
        # This function call changes the internal states of the scheduler
        # such as self.running, self.swapped, and self.waiting.
        scheduler_start_time = time.perf_counter()

        scheduler_outputs: SchedulerOutputs = self._schedule()

        now = time.time()

        if not self.cache_config.enable_prefix_caching:
@@ -1543,6 +1543,10 @@ def schedule(
            # Move to next cache (if exists)
            self.cache_id = self.next_cache_id

        if VLLM_DEFRAGMENT_BLOCK_IDS > 0:
            selected_blocks = self.block_manager.try_defragmenting(VLLM_DEFRAGMENT_BLOCK_IDS)
            scheduler_outputs.blocks_to_copy.extend(selected_blocks)

        # Return results
        return (seq_group_metadata_list, scheduler_outputs,
                allow_async_output_proc)
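
The feature is gated by the new `VLLM_DEFRAGMENT_BLOCK_IDS` environment variable (0, the default, disables it) and piggybacks on the existing block-copy path. A hedged sketch of how the scheduler's output is consumed downstream, with hypothetical block ids; the conversion to a tensor happens outside the files changed in this PR:

```python
import torch

# Hypothetical (old_block_id, new_block_id) pairs as returned by
# try_defragmenting when VLLM_DEFRAGMENT_BLOCK_IDS > 0.
changed_block_ids = [(57, 3), (56, 4)]

# The scheduler simply appends them to the regular copy list ...
blocks_to_copy = []  # stands in for scheduler_outputs.blocks_to_copy
blocks_to_copy.extend(changed_block_ids)

# ... so they arrive downstream as the same (src, dst) rows that
# HPUPagedAttention.copy_blocks / _copy_blocks already consume.
src_to_dsts = torch.tensor(blocks_to_copy, dtype=torch.int64)
print(src_to_dsts.tolist())  # [[57, 3], [56, 4]]
```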