diff --git a/requirements-hpu.txt b/requirements-hpu.txt
index 2204970080ac9..9096c8deecf5f 100644
--- a/requirements-hpu.txt
+++ b/requirements-hpu.txt
@@ -9,4 +9,3 @@ tabulate
 setuptools>=61
 setuptools-scm>=8
 vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@bb47de4
-
diff --git a/vllm/attention/ops/hpu_paged_attn.py b/vllm/attention/ops/hpu_paged_attn.py
index e55a4de11fd6c..57f4499eaa261 100644
--- a/vllm/attention/ops/hpu_paged_attn.py
+++ b/vllm/attention/ops/hpu_paged_attn.py
@@ -7,11 +7,42 @@

 import torch
 from vllm_hpu_extension import cache_ops, ops
+import habana_frameworks.torch as htorch

 # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
 _PARTITION_SIZE = 512


+def _graphed(fn):
+    class Graphed(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, *args, **kwargs):
+            return fn(*args, **kwargs)
+    graph = htorch.hpu.wrap_in_hpu_graph(Graphed(), disable_tensor_cache=True)
+
+    def wrapper(*args, **kwargs):
+        return graph.forward(*args, **kwargs)
+    return wrapper
+
+
+@_graphed
+def _copy_blocks(key_caches, value_caches, block_mapping):
+    if key_caches[0].device.type == 'hpu':
+        htorch.core.mark_step()
+    block_mapping = block_mapping.transpose(0, 1)
+    src = block_mapping[0]
+    dst = block_mapping[1]
+
+    for key_cache, value_cache in zip(key_caches, value_caches):
+        key_cache.index_copy_(0, dst, key_cache.index_select(0, src))
+        value_cache.index_copy_(0, dst, value_cache.index_select(0, src))
+
+    if key_caches[0].device.type == 'hpu':
+        htorch.core.mark_step()
+
+
 @dataclass
 class HPUPagedAttentionMetadata:
     """Metadata for PagedAttention."""
@@ -85,6 +116,8 @@ def copy_blocks(
         kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
         src_to_dsts: torch.Tensor,
     ) -> None:
+        if src_to_dsts.numel() == 0:
+            return
         key_caches = [kv_cache[0] for kv_cache in kv_caches]
         value_caches = [kv_cache[1] for kv_cache in kv_caches]
-        cache_ops.copy_blocks(key_caches, value_caches, src_to_dsts)
+        _copy_blocks(key_caches, value_caches, src_to_dsts)
diff --git a/vllm/core/block/naive_block.py b/vllm/core/block/naive_block.py
index 695870742da50..fb21946f9f683 100644
--- a/vllm/core/block/naive_block.py
+++ b/vllm/core/block/naive_block.py
@@ -108,6 +108,27 @@ def allocate_immutable_blocks(

         return blocks

+    def _can_defragment_all(self, block_ids):
+        block_ids = sorted(block_ids)
+        num_blocks = len(block_ids)
+        if len(self._free_block_indices) < num_blocks:
+            return False
+        free_block_ids = heapq.nsmallest(num_blocks, self._free_block_indices)
+        return free_block_ids[-1] < block_ids[0]
+
+    def _reassign_block_id(self, block):
+        prev_block_id = block.block_id
+        self._free_block_id(block)
+        block.block_id = self._allocate_block_id()
+        new_block_id = block.block_id
+        return (prev_block_id, new_block_id)
+
+    def defragment_all(self, blocks):
+        if not self._can_defragment_all([b.block_id for b in blocks]):
+            return []
+        return [self._reassign_block_id(b) for b in blocks]
+
+
     def allocate_mutable_block(self,
                                prev_block: Optional[Block],
                                extra_hash: Optional[int] = None,
diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py
index b41e848221882..000474961916f 100644
--- a/vllm/core/block_manager.py
+++ b/vllm/core/block_manager.py
@@ -2,6 +2,8 @@
 from typing import Dict, List, Optional
 from typing import Sequence as GenericSequence
 from typing import Tuple
+import heapq
+import itertools

 from vllm.core.block.block_table import BlockTable
 from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator
@@ -105,6 +107,33 @@ def __init__(
         self._last_access_blocks_tracker = LastAccessBlocksTracker(
             self.block_allocator)

+    def try_defragmenting(self, num_blocks):
+        if len(self.block_tables) == 0:
+            return []
+        block_heap = []
+
+        for seq_id, block_table in self.block_tables.items():
+            for block in block_table.blocks:
+                item = (block.block_id, block, seq_id)
+                if len(block_heap) < num_blocks:
+                    heapq.heappush(block_heap, item)
+                else:
+                    heapq.heappushpop(block_heap, item)
+        if len(block_heap) < num_blocks:
+            return []
+
+        block_ids, blocks, block_seq_ids = zip(*block_heap)
+        allocator = self.block_allocator._allocators[Device.GPU]
+        if not allocator.defragment_all(blocks):
+            return []
+
+        new_block_ids = [b.block_id for b in blocks]
+        for seq_id in set(block_seq_ids):
+            bt = self.block_tables[seq_id]
+            bt.update(bt.blocks)
+        changed_block_ids = list(zip(block_ids, new_block_ids))
+        return changed_block_ids
+
     def can_allocate(self,
                      seq_group: SequenceGroup,
                      num_lookahead_slots: int = 0) -> AllocStatus:
diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py
index 200098e3828da..c3da8c1c43985 100644
--- a/vllm/core/scheduler.py
+++ b/vllm/core/scheduler.py
@@ -27,7 +27,7 @@
     os.getenv("VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT", False))  # noqa
 ARTIFICIAL_PREEMPTION_PROB = 0.5
 ARTIFICIAL_PREEMPTION_MAX_CNT = 500
-
+VLLM_DEFRAGMENT_BLOCK_IDS = int(os.getenv("VLLM_DEFRAGMENT_BLOCK_IDS", "0"))

 class PreemptionMode(enum.Enum):
     """Preemption modes.
@@ -1397,8 +1397,8 @@ def schedule(
         # This function call changes the internal states of the scheduler
         # such as self.running, self.swapped, and self.waiting.
         scheduler_start_time = time.perf_counter()
-
         scheduler_outputs: SchedulerOutputs = self._schedule()
+
         now = time.time()

         if not self.cache_config.enable_prefix_caching:
@@ -1543,6 +1543,10 @@ def schedule(
         # Move to next cache (if exists)
         self.cache_id = self.next_cache_id

+        if VLLM_DEFRAGMENT_BLOCK_IDS > 0:
+            selected_blocks = self.block_manager.try_defragmenting(VLLM_DEFRAGMENT_BLOCK_IDS)
+            scheduler_outputs.blocks_to_copy.extend(selected_blocks)
+
         # Return results
         return (seq_group_metadata_list, scheduler_outputs,
                 allow_async_output_proc)
diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py
index 058d765b49aab..d6182f1dc755e 100755
--- a/vllm/worker/hpu_model_runner.py
+++ b/vllm/worker/hpu_model_runner.py
@@ -75,6 +75,8 @@
 LORA_WARMUP_RANK = 8
 VLLM_DELAYED_SAMPLING = os.environ.get('VLLM_DELAYED_SAMPLING',
                                        'false').lower() == 'true'
+VLLM_BUCKETING_SCHEME = os.environ.get('VLLM_BUCKETING_SCHEME', 'legacy')
+VLLM_STRICT_BUCKETS = os.environ.get('VLLM_STRICT_BUCKETS', 'false').lower() == 'true'
 DUMMY_TOKEN_ID = -1


@@ -242,6 +244,156 @@ def generate_prompt_buckets(self):
               f"{list(sorted(self.global_state.prompt_buckets))}")
        print(msg)

+
+class BucketingScheme:
+    def find_bucket(self, real_bs, real_blocks):
+        for bs, blocks in self.buckets:
+            if bs < real_bs:
+                continue
+            for b in blocks:
+                if b < real_blocks:
+                    continue
+                return (bs, b)
+        raise AssertionError("Couldn't find bucket for {} {}".format(real_bs, real_blocks))
+
+    def list_buckets(self):
+        buckets = [[(bs, b) for b in blocks] for bs, blocks in self.buckets]
+        buckets = list(itertools.chain(*buckets))
+        return buckets
+
+
+class FileBucketingScheme(BucketingScheme):
+    def __init__(self, filename, max_bs, max_blocks):
+        self.buckets = self._read_buckets(filename, max_bs, max_blocks)
+        logger.info('Decode buckets [file]: {}'.format(self.list_buckets()))
+
+    def _read_buckets(self, filename, max_bs, max_blocks):
+        buckets = {}
+        with open(filename) as f:
+            for line in f.readlines():
+                bs, blocks = line.strip().split(':')
+                bs = min(int(bs), max_bs)
+                blocks = set(min(int(b), max_blocks) for b in blocks.split())
+                buckets.setdefault(bs, set())
+                buckets[bs].update(blocks)
+        return [(bs, list(sorted(buckets[bs]))) for bs in sorted(buckets.keys())]
+
+
+class PolynomialBucketingScheme(BucketingScheme):
+    def __init__(self, max_bs, max_blocks):
+        min_blocks_per_seq = 2
+        max_blocks_per_seq = 8
+        max_block_steps = 16
+        block_beta = 0.1
+        block_rounding = 8
+        min_bs = 4
+        bs_div = 4
+        bs_range = self._gen_bs_range(min_bs, max_bs, bs_div)
+        self.buckets = self._gen_buckets(bs_range, max_block_steps, min_blocks_per_seq, max_blocks_per_seq, max_blocks, block_beta, block_rounding)
+        logger.info('Decode buckets [polynomial]: {}'.format(self.list_buckets()))
+
+    def _poly_fn(self, min_val, max_val, alpha, beta):
+        z = (max_val - min_val) * beta
+        a = (max_val - min_val - z) / (3 * alpha * alpha - 3 * alpha + 1)
+        b = -3 * a * alpha
+        c = 3 * a * alpha * alpha + z
+        d = min_val
+
+        def f(x):
+            return a * pow(x, 3) + b * pow(x, 2) + c * x + d
+        return f
+
+    def _round_up(self, x, k, max_val):
+        return min(math.ceil(x / k) * k, round(max_val))
+
+    def _apply(self, fn, max_steps, rounding, unique=True):
+        if max_steps > 1:
+            uniform_range = [i / (max_steps - 1) for i in range(max_steps)]
+        else:
+            uniform_range = [1.0]
+        values = [fn(x) for x in uniform_range]
+        max_val = values[-1]
+        values = [self._round_up(y, rounding, max_val) for y in values]
+        if unique:
+            values = list(sorted(set(values)))
+        return values
+
+    def _gen_bs_range(self, min_bs, max_bs, bs_div):
+        bs_range = []
+        while True:
+            bs_range.insert(0, max(max_bs, min_bs))
+            max_bs = math.ceil(max_bs / bs_div)
+            if bs_range[0] <= min_bs:
+                break
+        return bs_range
+
+    def _gen_buckets(self, bs_range, max_block_steps, min_blocks_per_seq, max_blocks_per_seq, max_blocks, block_beta, block_rounding):
+        max_bs = bs_range[-1]
+        buckets = []
+        for bs in bs_range:
+            block_min = bs * min_blocks_per_seq
+            block_max = min(max_blocks, bs * max_blocks_per_seq)
+            gen_fn = self._poly_fn(block_min, block_max, bs / max_bs, block_beta)
+            block_steps = round(max_block_steps * bs / max_bs)
+            block_range = self._apply(gen_fn, block_steps, block_rounding)
+            buckets.append((bs, block_range))
+        return buckets
+
+
+class DummyBucketingScheme:
+    def list_buckets(self):
+        return []
+
+    def find_bucket(self, real_batch_size, max_seq_len_or_blocks):
+        return (real_batch_size, max_seq_len_or_blocks)
+
+
+class LegacyPromptBucketingScheme:
+    def __init__(self, bucketing_ctx):
+        self.bucketing_ctx = bucketing_ctx
+        self.bucketing_ctx.generate_prompt_buckets()
+
+    def list_buckets(self):
+        return self.bucketing_ctx.prompt_buckets
+
+    def find_bucket(self, real_batch_size, max_seq_len):
+        bs = self.bucketing_ctx.get_padded_batch_size(real_batch_size, True)
+        seq_len = self.bucketing_ctx.get_padded_prompt_seq_len(max_seq_len)
+        return (bs, seq_len)
+
+class LegacyDecodeBucketingScheme:
+    def __init__(self, bucketing_ctx, max_blocks):
+        self.bucketing_ctx = bucketing_ctx
+        self.bucketing_ctx.num_hpu_blocks = max_blocks
+        self.bucketing_ctx.generate_decode_buckets(max_blocks)
+
+    def list_buckets(self):
+        return self.bucketing_ctx.decode_buckets
+
+    def find_bucket(self, real_batch_size, num_blocks):
+        bs = self.bucketing_ctx.get_padded_batch_size(real_batch_size, False)
+        num_blocks = self.bucketing_ctx.get_padded_decode_num_blocks(num_blocks)
+        return (bs, num_blocks)
+
+
+class BucketingContext:
+    def __init__(self, prompt_scheme, decode_scheme):
+        self.prompt_scheme = prompt_scheme
+        self.decode_scheme = decode_scheme
+
+    def list_prompt_buckets(self):
+        return self.prompt_scheme.list_buckets()
+
+    def list_decode_buckets(self):
+        return self.decode_scheme.list_buckets()
+
+    def find_prompt_bucket(self, real_batch_size, max_seq_len):
+        return self.prompt_scheme.find_bucket(real_batch_size, max_seq_len)
+
+    def find_decode_bucket(self, real_batch_size, num_blocks):
+        return self.decode_scheme.find_bucket(real_batch_size, num_blocks)
+
+
 class HpuModelAdapter:

     def __init__(self, model, vllm_config, layer_names):
@@ -730,14 +882,8 @@ def __init__(
         self._mem_margin: Optional[int] = None
         self.enable_merged_prefill = os.environ.get('VLLM_MERGED_PREFILL',
                                                     'false').lower() == 'true'
-        if self.enable_merged_prefill:
-            self.bucketing_ctx = HPUBucketingContextWithMergedPrefill(
-                self.max_num_seqs, self.max_num_prefill_seqs, self.block_size,
-                self.max_num_batched_tokens)
-        else:
-            self.bucketing_ctx = HPUBucketingContext(
-                self.max_num_seqs, self.max_num_prefill_seqs, self.block_size,
-                self.max_num_batched_tokens)
+        self.bucketing_ctx = BucketingContext(
+            DummyBucketingScheme(), DummyBucketingScheme())
         self.graphed_buckets: Set[Any] = set()

         self._set_gc_threshold()
@@ -753,6 +899,27 @@ def __init__(
         # For delayed sampling
         self.cached_step_inputs: List[ModelInputForHPUWithSamplingMetadata] = []

+    def _init_buckets(self, max_blocks):
+        if self.enable_merged_prefill:
+            legacy_ctx = HPUBucketingContextWithMergedPrefill(
+                self.max_num_seqs, self.max_num_prefill_seqs, self.block_size,
+                self.max_num_batched_tokens)
+        else:
+            legacy_ctx = HPUBucketingContext(
+                self.max_num_seqs, self.max_num_prefill_seqs, self.block_size,
+                self.max_num_batched_tokens)
+
+        if VLLM_BUCKETING_SCHEME == 'polynomial':
+            self.bucketing_ctx.prompt_scheme = LegacyPromptBucketingScheme(legacy_ctx)
+            self.bucketing_ctx.decode_scheme = PolynomialBucketingScheme(self.max_num_seqs, max_blocks)
+        elif VLLM_BUCKETING_SCHEME.startswith('file:'):
+            _, filename = VLLM_BUCKETING_SCHEME.split(':', 1)
+            self.bucketing_ctx.prompt_scheme = LegacyPromptBucketingScheme(legacy_ctx)
+            self.bucketing_ctx.decode_scheme = FileBucketingScheme(filename, self.max_num_seqs, max_blocks)
+        else:
+            self.bucketing_ctx.prompt_scheme = LegacyPromptBucketingScheme(legacy_ctx)
+            self.bucketing_ctx.decode_scheme = LegacyDecodeBucketingScheme(legacy_ctx, max_blocks)
+
     def _set_gc_threshold(self) -> None:
         # Read https://docs.python.org/3/library/gc.html#gc.set_threshold
         # for comprehensive description of gc generations.
@@ -874,21 +1041,40 @@ def load_model(self) -> None:
         msg = f"Loading model weights took in total {m.get_summary_string()}"
         logger.info(msg)

-    def _add_dummy_seq(self, seq_group_metadata_list, is_prompt):
+    def _get_bucketing_input(self, seq_group_metadata_list, is_prompt):
         real_batch_size = len(seq_group_metadata_list)
-        batch_size_padded = self.bucketing_ctx.get_padded_batch_size(
-            real_batch_size, is_prompt)
+        seq_data = [list(sgm.seq_data.values())[0] for sgm in seq_group_metadata_list]
+        if is_prompt:
+            if self.enable_merged_prefill:
+                max_seq_len = sum(len(sd.prompt_token_ids) + len(sd.output_token_ids) for sd in seq_data)
+            else:
+                max_seq_len = max(len(sd.prompt_token_ids) + len(sd.output_token_ids) for sd in seq_data)
+            return self.bucketing_ctx.find_prompt_bucket, real_batch_size, max_seq_len
+        block_tables = [list(sgm.block_tables.values())[0] for sgm in seq_group_metadata_list]
+        if self.use_contiguous_pa:
+            max_block_id = max(max(bt) if bt else 0 for bt in block_tables) + 1
+            return self.bucketing_ctx.find_decode_bucket, real_batch_size, max_block_id
+        else:
+            total_blocks = sum(len(bt) for bt in block_tables)
+            return self.bucketing_ctx.find_decode_bucket, real_batch_size, total_blocks
+
+    def find_bucket(self, seq_group_metadata_list, is_prompt):
+        bucketing_fn, real_batch_size, blocks_or_seq_len = self._get_bucketing_input(seq_group_metadata_list, is_prompt)
+        return bucketing_fn(real_batch_size, blocks_or_seq_len)
+
+    def _add_dummy_seq(self, seq_group_metadata_list, is_prompt):
+        bucketing_fn, real_batch_size, blocks_or_seq_len = self._get_bucketing_input(seq_group_metadata_list, is_prompt)
+        batch_size_padded, _ = bucketing_fn(real_batch_size, blocks_or_seq_len)
         batch_size_padding = batch_size_padded - real_batch_size

         seq_group_metadata_list = seq_group_metadata_list.copy()
-
         if batch_size_padding > 0:
             has_greedy_samples = any(
                 seq_group_metadata.sampling_params.temperature == 0.0
                 for seq_group_metadata in seq_group_metadata_list)
             temperature = 0.0 if has_greedy_samples else 1.0
             dummy_seq_group_metadata = self.create_dummy_seq_group_metadata(
-                -1, 0, is_prompt, temperature=temperature)
+                -1, 0, is_prompt, _PAD_BLOCK_ID, temperature=temperature)
             seq_group_metadata_list.extend(dummy_seq_group_metadata
                                            for _ in range(batch_size_padding))
         return seq_group_metadata_list, real_batch_size, batch_size_padded
@@ -1040,9 +1226,7 @@ def _prepare_prompt(
         real_num_seqs = len(query_lens)
         assert max_query_len > 0

-        max_prompt_len = max(
-            self.bucketing_ctx.get_padded_prompt_seq_len(max(seq_lens)),
-            self.block_size)
+        _, max_prompt_len = self.find_bucket(seq_group_metadata_list, True)

         lora_ids: List[int] = []
         for seq_group_metadata, context_len in zip(seq_group_metadata_list,
@@ -1282,17 +1466,13 @@ def _prepare_prompt_merged(
         input_positions_merged = list(
             itertools.chain.from_iterable(input_positions))
         input_positions_merged = [input_positions_merged]
-        total_seq_lens = [sum(seq_lens)]
         total_query_lens = [sum(query_lens)]

         max_query_len = max(total_query_lens)
         real_num_seqs = len(total_query_lens)
         assert max_query_len > 0

-
-        merged_prompt_len = max(
-            self.bucketing_ctx.get_padded_prompt_seq_len(max(total_seq_lens)),
-            self.block_size)
+        _, merged_prompt_len = self.find_bucket(seq_group_metadata_list, True)
         # get cumsum of seq_lens
         repeated_idx = list(itertools.accumulate(seq_lens))
         repeated_idx = [[idx - 1] * seq_len for idx, seq_len in zip(repeated_idx, seq_lens)]
@@ -1490,19 +1670,20 @@ def _prepare_decode(
         assert len(block_list) == len(block_usage)

         padding_fn = None
+        real_batch_size = len(seq_group_metadata_list)
+        padded_bs, block_bucket_size = self.find_bucket(seq_group_metadata_list, False)
         if self.use_contiguous_pa:
-            block_bucket_size = max(max(block_list) + 1, len(block_list))
-            block_bucket_size = self.bucketing_ctx.get_padded_decode_num_blocks(
-                block_bucket_size)
             indices: List[Any]
             indices = [None] * block_bucket_size
             for i, bid in enumerate(block_list):
                 indices[bid] = i
             padding_fn = lambda tensor, pad_value: gather_list(
                 tensor, indices, pad_value)
+            _, _, max_block_id = self._get_bucketing_input(seq_group_metadata_list, False)
+            self.profiler.record_counter(self.profiler.get_timestamp_us(), {
+                "cache_max_block_id": max_block_id,
+            })
         else:
-            block_bucket_size = self.bucketing_ctx.get_padded_decode_num_blocks(
-                len(block_list))
             padding_fn = lambda tensor, pad_value: pad_list(
                 tensor, block_bucket_size, pad_value)

@@ -1776,6 +1957,7 @@ def create_dummy_seq_group_metadata(self,
                                         group_id,
                                         seq_len,
                                         is_prompt,
+                                        block_id,
                                         lora_request=None,
                                         temperature=0):
         sampling_params = SamplingParams(temperature=temperature)
@@ -1788,7 +1970,7 @@
         else:
             input_len = seq_len - 1
             output_len = 1
-        block_tables = {group_id: [_PAD_BLOCK_ID] * num_blocks}
+        block_tables = {group_id: [block_id] * num_blocks}
         prompt_token_ids = [0] * input_len
         output_token_ids = [1] * output_len
         prompt_token_ids_array = array('l', prompt_token_ids)  # noqa: F821
@@ -1807,8 +1989,8 @@ def profile_run(self) -> None:
         bind_kv_cache(
             self.vllm_config.compilation_config.static_forward_context,
             [kv_caches])
-        _, max_seq_len = self.bucketing_ctx.get_max_prompt_shape()
-        max_batch_size = min(self.max_num_seqs,
+        max_seq_len = self.max_model_len
+        max_batch_size = min(self.max_num_prefill_seqs,
                              self.max_num_batched_tokens // max_seq_len)

         origin_enable_merged_prefill = self.enable_merged_prefill
@@ -1863,6 +2045,7 @@
                     i,
                     seq_len,
                     is_prompt,
+                    _PAD_BLOCK_ID,
                     lora_request=dummy_lora_requests_per_seq[i]
                     if dummy_lora_requests_per_seq else None,
                     temperature=temperature) for i in range(batch_size)
@@ -1876,6 +2059,7 @@
                     i,
                     b * self.block_size - 1,
                     is_prompt,
+                    seq_len - 1,
                     lora_request=dummy_lora_requests_per_seq[i]
                     if dummy_lora_requests_per_seq else None,
                     temperature=temperature) for i, b in enumerate(blocks)
@@ -2032,6 +2216,8 @@ def log_graph_warmup_summary(self, buckets, is_prompt, total_mem):

     @torch.inference_mode()
     def warmup_model(self, kv_caches: List[torch.Tensor]) -> None:
+        max_blocks = kv_caches[0][0].size(0)
+        self._init_buckets(max_blocks)
         if profile := os.environ.get('VLLM_PT_PROFILE', None):
             phase, bs, seq_len, graph = profile.split('_')
             is_prompt = phase == 'prompt'
@@ -2041,9 +2227,6 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None:
             self.warmup_scenario(int(bs), int(seq_len), is_prompt, kv_caches,
                                  True)
             raise AssertionError("Finished profiling")
-        max_blocks = kv_caches[0][0].size(0)
-        self.bucketing_ctx.generate_prompt_buckets()
-        self.bucketing_ctx.generate_decode_buckets(max_blocks)

         if not htorch.utils.internal.is_lazy() and not self.enforce_eager:
             multiplier = 3 if os.getenv('VLLM_REGIONAL_COMPILATION', 'true').lower() == 'true' else 1
@@ -2079,9 +2262,9 @@
                 'Please update Gaudi Software Suite.')
         with compile_only_mode_context(
         ) if can_use_compile_only_mode else contextlib.nullcontext():
-            self.warmup_all_buckets(self.bucketing_ctx.prompt_buckets, True,
+            self.warmup_all_buckets(self.bucketing_ctx.list_prompt_buckets(), True,
                                     kv_caches)
-            self.warmup_all_buckets(self.bucketing_ctx.decode_buckets, False,
+            self.warmup_all_buckets(self.bucketing_ctx.list_decode_buckets(), False,
                                     kv_caches)

         if not self.enforce_eager and htorch.utils.internal.is_lazy():
@@ -2098,6 +2281,9 @@
                 graph_free_mem)
             decode_available_memory = (graph_free_mem -
                                        prompt_available_memory)
+            if int(os.environ.get('VLLM_NUM_HPU_BLOCKS', '0')) > 0:
+                prompt_available_memory = graph_free_mem
+                decode_available_memory = graph_free_mem
             msg = (
                 f"Using {format_bytes(graph_free_mem)}"
                 f"/{format_bytes(free_mem)} "
@@ -2112,11 +2298,11 @@
                 'max_bs')
             mem_post_prompt, prompt_batch_seq, prompt_captured_all = \
                 self.warmup_graphs(
-                prompt_strategy, self.bucketing_ctx.prompt_buckets,
+                prompt_strategy, self.bucketing_ctx.list_prompt_buckets(),
                 True, kv_caches, prompt_available_memory)
             mem_post_decode, decode_batch_seq, decode_captured_all = \
                 self.warmup_graphs(
-                decode_strategy, self.bucketing_ctx.decode_buckets,
+                decode_strategy, self.bucketing_ctx.list_decode_buckets(),
                 False, kv_caches, decode_available_memory)

             # Not all prompt buckets were captured, but all decode buckets
@@ -2144,9 +2330,9 @@
                                         mem_post_decode, decode_batch_seq)

         self.log_graph_warmup_summary(
-            self.bucketing_ctx.prompt_buckets, True, mem_post_prompt)
+            self.bucketing_ctx.list_prompt_buckets(), True, mem_post_prompt)
         self.log_graph_warmup_summary(
-            self.bucketing_ctx.decode_buckets, False, mem_post_decode)
+            self.bucketing_ctx.list_decode_buckets(), False, mem_post_decode)

         end_time = time.perf_counter()
         end_mem = HabanaMemoryProfiler.current_device_memory_usage()
@@ -2330,6 +2516,7 @@ def _check_config(self, batch_size, seq_len, is_prompt, warmup_mode):
             phase = 'prompt' if is_prompt else 'decode'
             logger.warning("Configuration: (%s, %s, %s) was not warmed-up!",
                            phase, batch_size, seq_len)
+            assert not VLLM_STRICT_BUCKETS

     def create_lora_mask(self, input_tokens: torch.Tensor, lora_ids: List[int],
                          is_prompt: bool):
@@ -2797,4 +2984,4 @@ def _patch_prev_output(self):
         # This is a hack. Assigning output_token_ids triggers
         # a cache recomputation and we only need to update the last token
         seq_data.output_token_ids_array[-1] = real_out
-        seq_data._cached_all_token_ids[-1] = real_out
\ No newline at end of file
+        seq_data._cached_all_token_ids[-1] = real_out
diff --git a/vllm/worker/hpu_worker.py b/vllm/worker/hpu_worker.py
index 969971f2e25cd..59d8ea8faa37e 100644
--- a/vllm/worker/hpu_worker.py
+++ b/vllm/worker/hpu_worker.py
@@ -118,6 +118,8 @@ def __init__(
                 on_trace_ready=fn(torch_profiler_trace_dir, use_gzip=True))
         else:
             self.profiler = None
+        self.total_num_copied_blocks = 0
+        self.total_block_copies = 0

     def full_trace_handler(self, dir_name, use_gzip=False):

@@ -315,28 +317,36 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
         # At this point we should've allocated the maximum workspace for all
         # recipes we will use the extra memory for graphs/blocks
         free_hpu_memory = torch.hpu.mem_get_info()[0]
-
+        manual_num_blocks = int(os.environ.get('VLLM_NUM_HPU_BLOCKS', '0'))
         cache_block_size = self.get_cache_block_size_bytes()
-        graph_reserved_mem = (float(
-            os.environ.get('VLLM_GRAPH_RESERVED_MEM', '0.1'))
-                              if not self.model_config.enforce_eager else 0)
-        graph_headroom = 1 - graph_reserved_mem
-        available_hpu_memory = free_hpu_memory * \
-            self.cache_config.gpu_memory_utilization
-        hpu_memory_margin = free_hpu_memory * (
-            1 - self.cache_config.gpu_memory_utilization)
-        self.model_runner.mem_margin = hpu_memory_margin
-        cache_size_bytes = available_hpu_memory * graph_headroom
-        graph_headroom_bytes = available_hpu_memory * (1 - graph_headroom)
-        msg = (
-            f"Free device memory: {format_bytes(free_hpu_memory)}, "
-            f"{format_bytes(available_hpu_memory)} usable "
-            f"(gpu_memory_utilization={self.cache_config.gpu_memory_utilization}),"
-            f" {format_bytes(graph_headroom_bytes)} reserved for HPUGraphs "
-            f"(VLLM_GRAPH_RESERVED_MEM={graph_reserved_mem}), "
-            f"{format_bytes(cache_size_bytes)} reserved for KV cache")
-        logger.info(msg)
-        num_hpu_blocks = int(cache_size_bytes // cache_block_size)
+        if manual_num_blocks > 0:
+            num_hpu_blocks = manual_num_blocks
+            msg = (
+                f"Free device memory: {format_bytes(free_hpu_memory)}. "
No OOM protection!") + logger.info(msg) + self.model_runner.mem_margin = 0.0 + else: + graph_reserved_mem = (float( + os.environ.get('VLLM_GRAPH_RESERVED_MEM', '0.1')) + if not self.model_config.enforce_eager else 0) + graph_headroom = 1 - graph_reserved_mem + available_hpu_memory = free_hpu_memory * \ + self.cache_config.gpu_memory_utilization + hpu_memory_margin = free_hpu_memory * ( + 1 - self.cache_config.gpu_memory_utilization) + self.model_runner.mem_margin = hpu_memory_margin + cache_size_bytes = available_hpu_memory * graph_headroom + graph_headroom_bytes = available_hpu_memory * (1 - graph_headroom) + msg = ( + f"Free device memory: {format_bytes(free_hpu_memory)}, " + f"{format_bytes(available_hpu_memory)} usable " + f"(gpu_memory_utilization={self.cache_config.gpu_memory_utilization})," + f" {format_bytes(graph_headroom_bytes)} reserved for HPUGraphs " + f"(VLLM_GRAPH_RESERVED_MEM={graph_reserved_mem}), " + f"{format_bytes(cache_size_bytes)} reserved for KV cache") + logger.info(msg) + num_hpu_blocks = int(cache_size_bytes // cache_block_size) num_cpu_blocks = int(self.cache_config.swap_space_bytes // cache_block_size) num_hpu_blocks = max(num_hpu_blocks, 0) @@ -444,6 +454,13 @@ def execute_worker(self, worker_input: WorkerInput) -> None: worker_input.blocks_to_swap_out) if (worker_input.blocks_to_copy is not None and worker_input.blocks_to_copy.numel() > 0): + num_blocks = worker_input.blocks_to_copy.numel() / 2 + self.total_num_copied_blocks += num_blocks + self.total_block_copies += 1 + self.model_runner.profiler.record_counter(self.model_runner.profiler.get_timestamp_us(), { + "total_num_copied_blocks": self.total_num_copied_blocks, + "total_block_copies": self.total_block_copies, + }) self.cache_engine[virtual_engine].copy(worker_input.blocks_to_copy) def add_lora(self, lora_request: LoRARequest) -> bool: