forked from vllm-project/vllm
-
Notifications
You must be signed in to change notification settings - Fork 76
/
Copy pathhpu_model_runner.py
executable file
·2494 lines (2245 loc) · 111 KB
/
hpu_model_runner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
###############################################################################
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company
###############################################################################
import collections
import contextlib
import dataclasses
import functools
import gc
import itertools
import math
import os
import time
from array import array
from enum import IntEnum
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple,
Optional, Set, Tuple, Type, TypeVar, Union)
import habana_frameworks.torch as htorch
import habana_frameworks.torch.internal.bridge_config as bc
import torch
import torch.nn as nn
import vllm_hpu_extension.environment as environment
from vllm_hpu_extension.bucketing import HPUBucketingContext
from vllm_hpu_extension.flags import enabled_flags
from vllm_hpu_extension.ops import LoraMask as LoraMask
from vllm_hpu_extension.ops import batch2block, block2batch
from vllm_hpu_extension.profiler import (HabanaHighLevelProfiler,
HabanaMemoryProfiler, format_bytes)
from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import DeviceConfig, VllmConfig
from vllm.distributed import broadcast_tensor_dict
from vllm.distributed.parallel_state import get_world_group
from vllm.forward_context import set_forward_context
from vllm.inputs import INPUT_REGISTRY, InputRegistry
from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import supports_multimodal
from vllm.model_executor.sampling_metadata import SequenceGroupToSample
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalKwargs, MultiModalPlaceholderMap,
MultiModalRegistry)
from vllm.sampling_params import SamplingParams
from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
Logprob, SequenceData, SequenceGroupMetadata,
SequenceOutput)
from vllm.utils import (bind_kv_cache, is_fake_hpu, is_pin_memory_available,
make_tensor_with_pad)
from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase,
_add_attn_metadata_broadcastable_dict,
_add_sampling_metadata_broadcastable_dict,
_init_attn_metadata_from_tensor_dict,
_init_sampling_metadata_from_tensor_dict)
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
logger = init_logger(__name__)

# Cache of namedtuple classes created by `subtuple`, keyed by type name.
_TYPE_CACHE: Dict[str, Any] = {}

# Padding identifiers.
# These values are assumed to be zero in several places.
# Use caution when updating them!
_PAD_SLOT_ID = 0
_PAD_BLOCK_ID = 0

LORA_WARMUP_RANK = 8
def subtuple(obj: object,
             typename: str,
             to_copy: List[str],
             to_override: Optional[Dict[str, object]] = None):
    """Build a lightweight namedtuple holding a subset of ``obj``'s fields.

    Args:
        obj: Source object (values read with ``getattr``) or a plain ``dict``
            (values read by key). ``None`` is passed straight through.
        typename: Name of the generated namedtuple class. Classes are cached
            in ``_TYPE_CACHE`` by this name, so every call using the same
            ``typename`` must request the same fields.
        to_copy: Names of fields whose values are copied from ``obj``.
        to_override: Optional mapping ``field -> value``. These fields are
            added to the tuple and take precedence over values in ``obj``;
            they do not need to exist on ``obj``.

    Returns:
        An instance of the cached namedtuple type, or ``None`` when ``obj``
        is ``None``.
    """
    if obj is None:
        return None
    if to_override is None:
        to_override = {}
    # Ordered, de-duplicated field list: gives the generated class a
    # deterministic field order (iteration order of a set is arbitrary).
    fields = list(dict.fromkeys([*to_copy, *to_override]))
    if type(obj) is dict:
        # Overrides win; keys absent from both sources are skipped.
        values = {
            f: to_override[f] if f in to_override else obj[f]
            for f in fields if f in to_override or f in obj
        }
    else:
        # Check the override first: `to_override.get(f, getattr(obj, f))`
        # evaluates getattr eagerly and fails for override-only fields.
        values = {
            f: to_override[f] if f in to_override else getattr(obj, f)
            for f in fields
        }
    if typename not in _TYPE_CACHE:
        _TYPE_CACHE[typename] = collections.namedtuple(typename, fields)
    return _TYPE_CACHE[typename](**values)
def align_workers(value, op):
    """All-reduce a scalar ``value`` over the world group's CPU group.

    With a world size of 1 (or less) ``value`` is returned unchanged;
    otherwise the reduced result is returned via ``Tensor.item()``.

    Args:
        value: Scalar to reduce.
        op: ``torch.distributed.ReduceOp`` to apply.
    """
    cpu_group = get_world_group().cpu_group
    if torch.distributed.get_world_size() > 1:
        reduced = torch.tensor(value, device='cpu')
        torch.distributed.all_reduce(reduced, op=op, group=cpu_group)
        return reduced.item()
    return value
def setup_profiler():
    """Create a ``torch.profiler.profile`` recording CPU and HPU activity.

    Schedule: no wait steps, 2 warm-up steps, 1 active step, run once.
    Traces go to the current directory through the TensorBoard trace handler
    (gzip-compressed). Shapes are not recorded; Python stacks are.
    """
    return torch.profiler.profile(
        schedule=torch.profiler.schedule(wait=0,
                                         warmup=2,
                                         active=1,
                                         repeat=1),
        activities=[
            torch.profiler.ProfilerActivity.CPU,
            torch.profiler.ProfilerActivity.HPU,
        ],
        on_trace_ready=torch.profiler.tensorboard_trace_handler(
            '.', use_gzip=True),
        record_shapes=False,
        with_stack=True,
    )
def round_up(value: int, k: int) -> int:
    """Round ``value`` up to the next multiple of ``k`` (for positive ``k``)."""
    num_chunks = (value + k - 1) // k
    return num_chunks * k
def pad_list(input, k, v):
    """Return a new list: ``input`` followed by copies of ``v``.

    Enough copies of ``v`` are appended that the length becomes
    ``round_up(len(input), k)``. ``input`` itself is not modified.
    """
    missing = round_up(len(input), k) - len(input)
    return input + [v] * missing
def gather_list(input, indices, v):
    """Return ``[input[i] for i in indices]``; ``None`` indices yield ``v``."""
    gathered = []
    for idx in indices:
        gathered.append(v if idx is None else input[idx])
    return gathered
def flatten(in_list):
    """Concatenate an iterable of iterables into one flat list."""
    return list(itertools.chain.from_iterable(in_list))
def get_target_layer_suffix_list(model_type) -> list[str]:
    """Return class-name suffixes of the layers that get a mark_step hook.

    The decoder suffix defaults to "DecoderLayer", which matches most
    language models (e.g. LLaMA, Qwen, BART). Models whose decoder layer
    class is named differently need an entry in the override table below.
    The hook interval is controlled by VLLM_CONFIG_HIDDEN_LAYERS.

    Args:
        model_type: HF ``config.model_type`` string, or ``None``.

    Returns:
        ``[decoder_suffix, "EncoderLayer"]``.
    """
    decoder_overrides = {
        "gpt_bigcode": "BigCodeBlock",
    }
    if model_type in decoder_overrides:
        decoder_suffix = decoder_overrides[model_type]
    else:
        decoder_suffix = "DecoderLayer"
    return [decoder_suffix, "EncoderLayer"]
def modify_model_layers(module: torch.nn.Module,
                        suffix_list: list[str],
                        n=1,
                        counter=None):
    """Register a ``mark_step`` forward hook on every n-th matching layer.

    Children of ``module`` are walked depth-first. A child whose class name
    ends with any entry in ``suffix_list`` is counted. Every ``n``-th such
    child gets a forward hook that calls ``htorch.core.mark_step()`` after
    its forward pass. Matching children are not descended into.

    Args:
        module: Root module to walk.
        suffix_list: Class-name suffixes identifying target layers.
        n: Hook interval; ``1`` hooks every match.
        counter: One-element list carrying the running match count across
            recursive calls; leave as ``None``.
    """

    def forward_hook(_module, _args, output):
        htorch.core.mark_step()
        return output

    if counter is None:
        counter = [0]
    for _, child in module.named_children():
        class_name = child.__class__.__name__
        is_target = any(class_name.endswith(sfx) for sfx in suffix_list)
        if not is_target:
            modify_model_layers(child, suffix_list, n, counter)
            continue
        counter[0] += 1
        if counter[0] % n == 0:
            child.register_forward_hook(forward_hook)
def get_path_to_rope(model: torch.nn.Module):
    """Find the first RotaryEmbedding submodule and return its path.

    The module tree is searched depth-first in ``named_children`` order. A
    child counts as a match when its class name ends with "RotaryEmbedding";
    ``model`` itself is never tested.

    Returns:
        The list of child names from ``model`` down to the match (container
        indices appear as digit strings, as yielded by ``named_children``),
        or ``None`` if no such layer exists.
    """

    def _is_rope(mod):
        return mod.__class__.__name__.endswith("RotaryEmbedding")

    def _search(node, prefix):
        if node is None or not hasattr(node, 'named_children'):
            return None
        for name, child in node.named_children():
            here = prefix + [name]
            if _is_rope(child):
                return here
            found = _search(child, here)
            if found is not None:
                return found
        return None

    return _search(model, [])
class HpuModelAdapter:
    """Wrap a vLLM model for execution on HPU.

    The adapter can compile the model with ``torch.compile`` (regionally or
    as a whole) when not running in lazy mode. Before each forward pass it
    completes the HPU attention metadata: attention bias for prompts; block
    mapping and block scales for decodes; and block indices/offsets for
    both. After the pass it selects the hidden states of the requested
    tokens.
    """

    def __init__(self, model, vllm_config, layer_names):
        """
        Args:
            model: The underlying ``torch.nn.Module``.
            vllm_config: vLLM config; the cache block size, model dtype and
                ``enforce_eager`` flag are read from it.
            layer_names: Path of child names to the RotaryEmbedding layer
                (digit strings index containers), or ``None`` to skip
                cos/sin preparation in ``forward``.
        """
        self.model = model
        self.prefill_use_fusedsdpa = "fsdpa" in enabled_flags()
        self.recompute_cos_sin = os.getenv('VLLM_COS_SIN_RECOMPUTE',
                                           'false').lower() in ['1', 'true']
        self.vllm_config = vllm_config
        self.block_size = vllm_config.cache_config.block_size
        self.dtype = vllm_config.model_config.dtype
        self.layer_names = layer_names
        enforce_eager = vllm_config.model_config.enforce_eager
        # torch.compile is used only on real HPU, outside lazy mode, and when
        # eager execution is not enforced.
        if not is_fake_hpu() and not htorch.utils.internal.is_lazy(
        ) and not enforce_eager:
            if os.getenv('VLLM_REGIONAL_COMPILATION',
                         'true').lower() == 'true':
                # Compile selected regions instead of the whole model.
                self.regional_compilation_layers_list = [
                    RMSNorm, VocabParallelEmbedding
                ]
                self._regional_compilation(self.model)
            else:
                self.model = torch.compile(self.model,
                                           backend='hpu_backend',
                                           dynamic=False)

    def _regional_compilation(self,
                              module,
                              parent_module=None,
                              module_name=None):
        """Recursively replace compilable regions of ``module`` in place.

        Every child of a ``ModuleList`` is compiled. Any module whose type is
        in ``regional_compilation_layers_list`` is compiled and re-attached
        to its parent. Everything else is searched recursively.
        """
        if isinstance(module, torch.nn.ModuleList):
            for children_name, children_module in module.named_children():
                self._compile_region(module, children_name, children_module)
        elif any(
                isinstance(module, layer)
                for layer in self.regional_compilation_layers_list):
            self._compile_region(parent_module, module_name, module)
        else:
            for children_name, children_module in module.named_children():
                self._regional_compilation(children_module, module,
                                           children_name)

    def _compile_region(self, model, name, module):
        """Compile ``module`` and store it as attribute ``name`` of ``model``."""
        module = torch.compile(module, backend='hpu_backend', dynamic=False)
        setattr(model, name, module)

    def _set_attn_bias(self, attn_metadata, batch_size, seq_len, device,
                       dtype):
        """Attach an additive attention bias to prompt metadata.

        Skipped (metadata returned unchanged) when there is no metadata,
        when fused SDPA is enabled, or for non-prompt batches. Otherwise the
        bias has shape ``(batch_size, 1, seq_len, past_len + seq_len)``:
        ``0`` where attention is allowed and ``-inf`` where masked.
        """
        if (attn_metadata is None or self.prefill_use_fusedsdpa
                or not attn_metadata.is_prompt):
            return attn_metadata

        prefill_metadata = attn_metadata

        seq_lens_t = prefill_metadata.seq_lens_tensor
        context_lens_t = prefill_metadata.context_lens_tensor
        query_lens_t = seq_lens_t - context_lens_t

        # Past (cached-context) length in tokens: blocks per sequence times
        # block size. Assumes block_list holds the same number of blocks for
        # every sequence -- TODO confirm.
        block_list = attn_metadata.block_list
        max_context_len = (block_list.size(-1) //
                           batch_size if block_list is not None else 0)
        max_context_len = max_context_len * self.block_size
        # Mask past positions at or beyond each sequence's context length.
        past_mask = torch.arange(0,
                                 max_context_len,
                                 dtype=torch.int32,
                                 device=device)
        past_mask = (past_mask.view(1, -1).expand(batch_size, -1).ge(
            context_lens_t.view(-1, 1)).view(batch_size, 1, -1).expand(
                batch_size, seq_len, -1).view(batch_size, 1, seq_len, -1))

        # Mask key positions beyond each sequence's query length (padding).
        len_mask = (torch.arange(0, seq_len, device=device,
                                 dtype=torch.int32).view(1, seq_len).ge(
                                     query_lens_t.unsqueeze(-1)).view(
                                         batch_size, 1, 1, seq_len))
        # Mask future positions (strictly above the diagonal).
        causal_mask = torch.triu(torch.ones((batch_size, 1, seq_len, seq_len),
                                            device=device,
                                            dtype=torch.bool),
                                 diagonal=1)
        mask = causal_mask.logical_or(len_mask)
        mask = torch.concat((past_mask, mask), dim=-1)
        attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_(
            mask, -math.inf))
        attn_metadata = prefill_metadata._replace(attn_bias=attn_bias)
        return attn_metadata

    def _set_block_mapping(self, metadata, batch_size, device, dtype):
        """Attach decode ``attn_bias`` and one-hot ``block_mapping``.

        ``attn_bias`` is ``-inf`` for in-block offsets at or beyond each
        block's ``block_usage``. ``block_mapping`` one-hot encodes
        ``block_groups`` over ``batch_size`` classes and is cast to ``dtype``.
        """
        mask = torch.arange(0,
                            self.block_size,
                            device=device,
                            dtype=torch.int32).unsqueeze(0)
        mask = mask >= metadata.block_usage.unsqueeze(-1)
        attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_(
            mask, -math.inf))

        if not is_fake_hpu():
            block_mapping = torch.nn.functional.one_hot(metadata.block_groups,
                                                        num_classes=batch_size)
        else:
            # Unfortunately one_hot on CPU
            # doesn't handle out of bounds classes so we need to convert
            # all negative values to 0 (block_mapping) or bs (block_groups)
            block_groups = metadata.block_groups.to(torch.long)
            block_mapping = torch.nn.functional.relu(block_groups)
            block_mapping = torch.nn.functional.one_hot(block_mapping,
                                                        num_classes=batch_size)
            oob_values = block_groups.lt(0)
            block_mapping.masked_fill_(oob_values.unsqueeze(-1), 0)
            block_groups.masked_fill_(oob_values, batch_size)
            metadata = metadata._replace(block_groups=block_groups)
        block_mapping = block_mapping.to(dtype)
        metadata = metadata._replace(block_mapping=block_mapping,
                                     attn_bias=attn_bias)
        return metadata

    def _set_block_scales(self, metadata, device):
        """Attach ``block_scales`` = 1 / max(1, blocks sharing the same group).

        The per-group block count is computed by mapping ones to the batch
        and back with ``block2batch``/``batch2block``.
        """
        block_mapping = metadata.block_mapping
        ones = torch.ones((block_mapping.size(0), ),
                          device=device,
                          dtype=block_mapping.dtype)
        sums = batch2block(block2batch(ones, block_mapping), block_mapping)
        block_scales = torch.reciprocal(torch.maximum(ones, sums))
        metadata = metadata._replace(block_scales=block_scales)
        return metadata

    def _set_indices_and_offsets(self, metadata, block_size, is_prompt):
        """Derive KV-cache block indices (and offsets) from ``slot_mapping``.

        For prompts, the block index of the first slot in every group of
        ``block_size`` slots is kept and offsets are ``None``. For decodes,
        every slot gets its block index and in-block offset.
        """
        slot_mapping = metadata.slot_mapping.flatten()
        indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
        if is_prompt:
            indices = indices.unflatten(0, (-1, block_size))[:, 0]
            offsets = None
        else:
            offsets = torch.fmod(slot_mapping, block_size)
        metadata = metadata._replace(block_offsets=offsets,
                                     block_indices=indices)
        return metadata

    def _update_metadata(self, attn_metadata, batch_size, seq_len, device,
                         dtype):
        """Complete ``attn_metadata`` for the current prompt or decode batch."""
        if attn_metadata.is_prompt:
            attn_metadata = self._set_attn_bias(attn_metadata, batch_size,
                                                seq_len, device, dtype)
        else:
            attn_metadata = self._set_block_mapping(attn_metadata, batch_size,
                                                    device, dtype)
            attn_metadata = self._set_block_scales(attn_metadata, device)
        attn_metadata = self._set_indices_and_offsets(attn_metadata,
                                                      self.block_size,
                                                      attn_metadata.is_prompt)
        return attn_metadata

    def _prepare_cos_sin(self, positions):
        """Follow ``self.layer_names`` to the RotaryEmbedding layer and call
        its ``prepare_cos_sin``.

        Digit path components index into the module's children in
        registration order (e.g. a ``ModuleList``); other components are
        attribute names.

        Raises:
            AttributeError: If the module at the end of the path has no
                ``prepare_cos_sin`` method.
        """
        current_module = self.model
        for layer in self.layer_names:
            if layer.isdigit():
                current_module = list(
                    current_module._modules.values())[int(layer)]
            else:
                current_module = getattr(current_module, layer)
        # At the end, we should be at the RotaryEmbedding layer.
        if not hasattr(current_module, 'prepare_cos_sin'):
            # Implicit concatenation: a backslash continuation inside the
            # literal would embed the next line's indentation in the message.
            raise AttributeError(
                "The module at the end of the path does not have "
                "a 'prepare_cos_sin' method.")
        current_module.prepare_cos_sin(
            positions, recompute_cos_sin=self.recompute_cos_sin)

    def forward(self, *args, **kwargs):
        """Run the model and return hidden states of the selected tokens.

        Consumes the adapter-only kwargs ``selected_token_indices`` and
        ``lora_mask`` (both required) and the optional ``warmup_mode``
        (discarded) and ``virtual_engine`` (default 0). The remaining kwargs,
        including the completed ``attn_metadata``, are passed to the model.
        """
        kwargs = kwargs.copy()
        selected_token_indices = kwargs.pop('selected_token_indices')
        kwargs.pop('warmup_mode', None)
        virtual_engine = kwargs.pop('virtual_engine', 0)
        input_ids = kwargs['input_ids']
        kwargs['attn_metadata'] = self._update_metadata(
            kwargs['attn_metadata'], input_ids.size(0), input_ids.size(1),
            input_ids.device, self.dtype)
        LoraMask.setLoraMask(kwargs.pop('lora_mask'))
        if self.layer_names is not None:
            self._prepare_cos_sin(kwargs['positions'])
        with set_forward_context(kwargs['attn_metadata'], self.vllm_config,
                                 virtual_engine):
            hidden_states = self.model(*args, **kwargs)
            # Flatten batch/sequence dims, then keep only requested tokens.
            hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
            hidden_states = hidden_states.index_select(0,
                                                       selected_token_indices)
        return hidden_states

    def compute_logits(self, *args, **kwargs):
        return self.model.compute_logits(*args, **kwargs)

    def sample(self, *args, **kwargs):
        return self.model.sample(*args, **kwargs)

    def generate_proposals(self, *args, **kwargs):
        return self.model.generate_proposals(*args, **kwargs)

    # sampler property will be used by spec_decode_worker
    # don't rename
    @property
    def sampler(self):
        return self.model.sampler
class PreparePromptMetadata(NamedTuple):
    """Inputs assembled by the prompt (prefill) preparation step."""
    input_tokens: torch.Tensor
    input_positions: List[List[int]]
    attn_metadata: Optional[AttentionMetadata]
    seq_lens: List[int]
    query_lens: List[int]
    lora_index_mapping: List[List[int]]
    lora_prompt_mapping: List[List[int]]
    lora_requests: Set[LoRARequest]
    multi_modal_kwargs: Optional[Dict[str, BatchedTensorInputs]]
    slot_mapping: List[List[int]]
    lora_ids: List[int]

    @classmethod
    def empty(cls):
        """Return the value used when there are no prompt sequences.

        Optional fields are ``None``, ``lora_requests`` is an empty set and
        every other field is a fresh empty list (including ``input_tokens``,
        despite its tensor annotation).
        """
        return PreparePromptMetadata(
            attn_metadata=None,
            multi_modal_kwargs=None,
            lora_requests=set(),
            input_tokens=[],
            input_positions=[],
            seq_lens=[],
            query_lens=[],
            slot_mapping=[],
            lora_index_mapping=[],
            lora_prompt_mapping=[],
            lora_ids=[],
        )
class PrepareDecodeMetadata(NamedTuple):
    """Inputs assembled by the decode preparation step."""
    input_tokens: torch.Tensor
    input_positions: List[List[int]]
    attn_metadata: Optional[AttentionMetadata]
    lora_index_mapping: List[List[int]]
    lora_prompt_mapping: List[List[int]]
    lora_requests: Set[LoRARequest]
    slot_mapping: List[List[int]]
    lora_ids: List[int]

    @classmethod
    def empty(cls):
        """Return the value used when there are no decode sequences.

        ``attn_metadata`` is ``None``, ``lora_requests`` is an empty set and
        every other field is a fresh empty list.
        """
        return PrepareDecodeMetadata(
            attn_metadata=None,
            lora_requests=set(),
            input_tokens=[],
            input_positions=[],
            slot_mapping=[],
            lora_index_mapping=[],
            lora_prompt_mapping=[],
            lora_ids=[],
        )
class BatchType(IntEnum):
    """How batches are constructed."""
    PREFILL = 0  # Every batch is prefill.
    DECODE = 1  # Every batch is decode.
    MIXED = 2  # Batch is a mixture of prefill and decode.
# Type variable for model-input classes, bound to ModelInputForHPU; used to
# parameterise HPUModelRunnerBase and typed alternate constructors.
TModelInputForHPU = TypeVar('TModelInputForHPU', bound="ModelInputForHPU")
@dataclasses.dataclass(frozen=True)
class ModelInputForHPU(ModelRunnerInputBase):
    """
    Metadata required for the base model forward pass on HPU.

    Only forward-pass inputs live here. Runners that perform extra steps
    (for example sampling) subclass this to add their own fields.
    """
    input_tokens: Optional[torch.Tensor] = None
    input_positions: Optional[torch.Tensor] = None
    seq_lens: Optional[List[int]] = None
    query_lens: Optional[List[int]] = None
    lora_mapping: Optional["LoRAMapping"] = None
    lora_requests: Optional[Set[LoRARequest]] = None
    attn_metadata: Optional["AttentionMetadata"] = None
    multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None
    real_batch_size: Optional[int] = None
    batch_size_padded: Optional[int] = None
    virtual_engine: int = 0
    lora_ids: Optional[List[int]] = None
    async_callback: Optional[Callable] = None
    is_first_multi_step: bool = True
    is_last_step: bool = True

    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        """Collect the fields sent to other workers, plus attention metadata.

        ``seq_lens``, ``query_lens`` and ``async_callback`` are not included.
        """
        broadcast_keys = (
            "input_tokens",
            "input_positions",
            "lora_requests",
            "lora_mapping",
            "multi_modal_kwargs",
            "real_batch_size",
            "batch_size_padded",
            "virtual_engine",
            "lora_ids",
            "is_first_multi_step",
            "is_last_step",
        )
        tensor_dict = {key: getattr(self, key) for key in broadcast_keys}
        _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
        return tensor_dict

    @classmethod
    def from_broadcasted_tensor_dict(
        cls: Type[TModelInputForHPU],
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> TModelInputForHPU:
        """Rebuild an instance from ``as_broadcastable_tensor_dict`` output.

        When ``attn_backend`` is given, the attention metadata entries are
        first converted back into an attention-metadata object.
        """
        if attn_backend is None:
            return cls(**tensor_dict)
        restored = _init_attn_metadata_from_tensor_dict(
            attn_backend, tensor_dict)
        return cls(**restored)
@dataclasses.dataclass(frozen=True)
class ModelInputForHPUWithSamplingMetadata(ModelInputForHPU):
    """
    Model input used by the ModelRunner: forward-pass metadata plus
    sampling metadata.
    """
    sampling_metadata: Optional["SamplingMetadata"] = None
    # Used for speculative decoding. We do not broadcast it because it is only
    # used by the driver worker.
    is_prompt: Optional[bool] = None

    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        """Collect the fields sent to other workers.

        Includes attention and sampling metadata. Unlike the base class,
        ``real_batch_size``, ``batch_size_padded``, ``virtual_engine`` and the
        multi-step flags are not sent, so receivers get the dataclass
        defaults. NOTE(review): confirm this omission is intended.
        """
        broadcast_keys = (
            "input_tokens",
            "input_positions",
            "lora_requests",
            "lora_mapping",
            "multi_modal_kwargs",
            "lora_ids",
        )
        tensor_dict = {key: getattr(self, key) for key in broadcast_keys}
        _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
        _add_sampling_metadata_broadcastable_dict(tensor_dict,
                                                  self.sampling_metadata)
        return tensor_dict

    @classmethod
    def from_broadcasted_tensor_dict(
        cls,
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> "ModelInputForHPUWithSamplingMetadata":
        """Rebuild an instance: restore sampling metadata, then (when
        ``attn_backend`` is given) attention metadata."""
        restored = _init_sampling_metadata_from_tensor_dict(tensor_dict)
        # FIXME(kzawora): this fails for whatever reason - why?
        if attn_backend is not None:
            restored = _init_attn_metadata_from_tensor_dict(
                attn_backend, restored)
        return cls(**restored)
class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
"""
Helper class for shared methods between HPU model runners.
"""
_model_input_cls: Type[TModelInputForHPU]
def __init__(
    self,
    vllm_config: VllmConfig,
    kv_cache_dtype: Optional[str] = "auto",
    is_driver_worker: bool = False,
    return_hidden_states: bool = False,
    input_registry: InputRegistry = INPUT_REGISTRY,
    mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
):
    """Set up state shared by HPU model runners.

    Args:
        vllm_config: Engine configuration. Sub-configs such as
            ``self.model_config`` and ``self.scheduler_config`` are read
            after ``ModelRunnerBase.__init__``.
        kv_cache_dtype: Accepted for interface compatibility; the KV-cache
            dtype is taken from ``cache_config.cache_dtype``.
        is_driver_worker: Whether this runner belongs to the driver worker.
        return_hidden_states: Stored on the instance.
        input_registry: Input registry, stored on the instance.
        mm_registry: Multi-modal registry used to build the input mapper and
            initialise per-prompt multi-modal limits.

    Raises:
        ValueError: If speculative decoding is configured while contiguous
            PA (``VLLM_CONTIGUOUS_PA``, default ``true``) is enabled.
    """
    ModelRunnerBase.__init__(self, vllm_config=vllm_config)
    environment.set_model_config(self.model_config)
    self.is_driver_worker = is_driver_worker
    self.return_hidden_states = return_hidden_states

    self.sliding_window = (self.model_config.get_sliding_window()
                           if self.model_config is not None else None)
    self.device_config = (self.device_config if self.device_config
                          is not None else DeviceConfig())
    if is_fake_hpu():
        # Emulated HPU: run on CPU.
        self.device_config.device = torch.device('cpu')
        self.device_config.device_type = 'cpu'
        self.load_config.device = None
    self.device = self.device_config.device
    self.enforce_eager = self.model_config.enforce_eager
    self.max_num_seqs = self.scheduler_config.max_num_seqs
    # Fall back to max_num_seqs when no prefill-specific limit is set.
    self.max_num_prefill_seqs = self.scheduler_config.max_num_prefill_seqs \
        if self.scheduler_config.max_num_prefill_seqs is not None \
        else self.max_num_seqs
    self.max_model_len = self.scheduler_config.max_model_len
    self.max_num_batched_tokens = \
        self.scheduler_config.max_num_batched_tokens
    self.block_size = self.cache_config.block_size

    self.pin_memory = is_pin_memory_available()
    self.kv_cache_dtype = self.cache_config.cache_dtype

    num_attn_heads = self.model_config.get_num_attention_heads(
        self.parallel_config)
    needs_attn_backend = (num_attn_heads != 0
                          or self.model_config.is_attention_free)
    self.attn_backend = get_attn_backend(
        self.model_config.get_head_size(),
        self.model_config.dtype,
        self.kv_cache_dtype,
        self.block_size,
        self.model_config.is_attention_free,
    ) if needs_attn_backend else None

    # Lazy initialization
    self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None
    self.model: Optional[torch.nn.Module] = None
    self.inc_initialized_successfully = False

    # Profiler stats
    self.profiler = HabanaHighLevelProfiler()
    self.profiler_counter_helper = HabanaProfilerCounterHelper()
    self.seen_configs: set = set()
    self._mem_margin: Optional[int] = None

    self.bucketing_ctx = HPUBucketingContext(self.max_num_seqs,
                                             self.max_num_prefill_seqs,
                                             self.block_size,
                                             self.max_num_batched_tokens)
    self.graphed_buckets: Set[Any] = set()

    self._set_gc_threshold()

    # Multi-modal data support. Honour the caller's registry (it used to be
    # unconditionally replaced by the global MULTIMODAL_REGISTRY). This runs
    # after _set_gc_threshold() so the mapper built here is the one kept.
    self.input_registry = input_registry
    self.mm_registry = mm_registry
    self.multi_modal_input_mapper = self.mm_registry \
        .create_input_mapper(self.model_config)
    self.mm_registry.init_mm_limits_per_prompt(self.model_config)

    self.use_contiguous_pa = os.environ.get('VLLM_CONTIGUOUS_PA',
                                            'true').lower() == 'true'
    if vllm_config.speculative_config is not None \
            and self.use_contiguous_pa:
        raise ValueError(
            "Speculative decoding is not supported with "
            "contiguous PA, please set VLLM_CONTIGUOUS_PA=false")
    # For multi-step scheduling
    self.cached_step_outputs: List[torch.Tensor] = []
def _set_gc_threshold(self) -> None:
    """Configure Python GC thresholds and read ``VLLM_SKIP_WARMUP``.

    Each threshold returned by ``gc.get_threshold()`` can be set with
    ``VLLM_GC_THR_GEN{i}``; these take priority. If the resulting thresholds
    equal the defaults (e.g. none of those variables is set), every default
    is multiplied by ``VLLM_GC_THR_MULTIPLIER`` (default 2) instead. See
    https://docs.python.org/3/library/gc.html#gc.set_threshold for a
    description of GC generations.

    Also sets ``self.skip_warmup`` from ``VLLM_SKIP_WARMUP`` (default
    ``false``). The multi-modal input mapper is no longer rebuilt here: it
    came from the global registry and clobbered the instance's mapper.
    """
    default_gc_thrs = list(gc.get_threshold())
    requested_gc_thrs = [
        int(os.environ.get(f'VLLM_GC_THR_GEN{i}', default))
        for i, default in enumerate(default_gc_thrs)
    ]
    if requested_gc_thrs == default_gc_thrs:
        gc_thr_multiplier = int(os.environ.get('VLLM_GC_THR_MULTIPLIER', 2))
        requested_gc_thrs = [t * gc_thr_multiplier for t in default_gc_thrs]
    gc.set_threshold(*requested_gc_thrs)

    self.skip_warmup = os.environ.get('VLLM_SKIP_WARMUP',
                                      'false').lower() == 'true'
def load_model(self) -> None:
    """Load the model and prepare it for execution on HPU.

    Steps, in order: load weights; attach a LoRA manager if LoRA is
    configured; prepare with INC if quantization is 'inc' (otherwise move to
    "hpu" unless running on a fake HPU); register mark_step hooks on decoder
    and encoder layers; wrap in HpuModelAdapter (inside an HPU graph in lazy
    mode). Memory use is logged and stored in ``self.model_memory_usage``.

    Raises:
        AssertionError: If LoRA is configured but the model lacks the
            required LoRA attributes, or LoRA bias / fully sharded LoRAs
            are requested.
    """
    import habana_frameworks.torch.core as htcore
    if self.model_config.quantization == 'inc' or \
            self.model_config.quantization == 'fp8':
        htcore.hpu_set_env()
    with HabanaMemoryProfiler() as m:
        with HabanaMemoryProfiler() as m_getmodel:
            self.model = get_model(vllm_config=self.vllm_config)
        msg = ("Pre-loading model weights on "
               f"{next(self.model.parameters()).device} "
               f"took {m_getmodel.get_summary_string()}")
        logger.info(msg)

        if self.lora_config:
            # Validate LoRA support before building the LoRA manager.
            assert hasattr(self.model, "supported_lora_modules"
                           ) and self.model.supported_lora_modules, (
                               "Model does not support LoRA")
            assert hasattr(self.model, "embedding_modules"
                           ), "Model does not have embedding_modules"
            assert hasattr(
                self.model, "embedding_padding_modules"
            ), "Model does not have embedding_padding_modules"
            assert not self.lora_config.bias_enabled, \
                "Bias support in LoRA is not enabled in HPU yet."
            assert not self.lora_config.fully_sharded_loras, \
                "Fully sharded LoRAs is not enabled in HPU yet."
            if supports_multimodal(self.model):
                logger.warning(
                    "Regarding multimodal models, vLLM currently "
                    "only supports adding LoRA to language model.")
            # It's necessary to distinguish between the
            # max_position_embeddings of VLMs and LLMs.
            if hasattr(self.model.config, "max_position_embeddings"):
                max_pos_embeddings = (
                    self.model.config.max_position_embeddings)
            else:
                max_pos_embeddings = (
                    self.model.config.text_config.max_position_embeddings)

            self.lora_manager = LRUCacheWorkerLoRAManager(
                self.scheduler_config.max_num_seqs,
                self.scheduler_config.max_num_batched_tokens,
                self.vocab_size,
                self.lora_config,
                self.device,
                self.model.embedding_modules,
                self.model.embedding_padding_modules,
                max_position_embeddings=max_pos_embeddings,
            )
            self.model = self.lora_manager.create_lora_manager(self.model)

        if self.model_config.quantization == 'inc':
            logger.info("Preparing model with INC..")
            with HabanaMemoryProfiler() as m_inc:
                from neural_compressor.torch.quantization import (
                    FP8Config, convert, prepare)
                # The FP8 config file path comes from QUANT_CONFIG.
                config = FP8Config.from_json_file(
                    os.getenv("QUANT_CONFIG", ""))
                # `measure` prepares the model; otherwise `quantize` converts it.
                if config.measure:
                    self.model = prepare(self.model, config)
                elif config.quantize:
                    self.model = convert(self.model, config)
                htcore.hpu_initialize(self.model,
                                      mark_only_scales_as_const=True)
            self.inc_initialized_successfully = True
            logger.info("Preparing model with INC took %s",
                        m_inc.get_summary_string())
        elif not is_fake_hpu():
            self.model = self.model.to("hpu")
            htcore.mark_step()

        # Register mark_step hooks every VLLM_CONFIG_HIDDEN_LAYERS matching
        # decoder/encoder layers (default: every layer).
        hidden_layer_markstep_interval = int(
            os.getenv('VLLM_CONFIG_HIDDEN_LAYERS', '1'))
        model_config = getattr(self.model, "config", None)
        modify_model_layers(
            self.model,
            get_target_layer_suffix_list(
                model_config.
                model_type if model_config is not None else None),
            hidden_layer_markstep_interval)
        # Path to the RotaryEmbedding layer (None if absent), used by the
        # adapter to prepare cos/sin before each forward.
        path_to_rope = get_path_to_rope(self.model)
        torch.hpu.synchronize()

        with HabanaMemoryProfiler() as m_wrap:
            self.model = self._maybe_wrap_in_hpu_graph(
                self.model,
                vllm_config=self.vllm_config,
                layer_names=path_to_rope)
        msg = f"Wrapping in HPU Graph took {m_wrap.get_summary_string()}"
        logger.info(msg)

    self.model_memory_usage = m.consumed_device_memory
    msg = f"Loading model weights took in total {m.get_summary_string()}"
    logger.info(msg)
def _add_dummy_seq(self, seq_group_metadata_list, is_prompt):
    """Pad a batch of sequence groups up to its bucketed batch size.

    The caller's list is copied, not modified. Padding entries are one
    dummy sequence group repeated. Its temperature is 0.0 if any real group
    is greedy (temperature 0.0), otherwise 1.0.

    Returns:
        ``(padded_list, real_batch_size, batch_size_padded)``.
    """
    real_batch_size = len(seq_group_metadata_list)
    batch_size_padded = self.bucketing_ctx.get_padded_batch_size(
        real_batch_size, is_prompt)
    padded = seq_group_metadata_list.copy()
    num_dummies = batch_size_padded - real_batch_size
    if num_dummies <= 0:
        return padded, real_batch_size, batch_size_padded

    any_greedy = any(sg.sampling_params.temperature == 0.0 for sg in padded)
    dummy = self.create_dummy_seq_group_metadata(
        0, 0, is_prompt, temperature=0.0 if any_greedy else 1.0)
    padded.extend([dummy] * num_dummies)
    return padded, real_batch_size, batch_size_padded
def _maybe_wrap_in_hpu_graph(self, *args, **kwargs):
    """Build an HpuModelAdapter from the given arguments.

    In lazy mode the adapter is additionally wrapped with
    ``htorch.hpu.wrap_in_hpu_graph`` (tensor cache disabled).
    """
    if htorch.utils.internal.is_lazy():
        return htorch.hpu.wrap_in_hpu_graph(HpuModelAdapter(*args, **kwargs),
                                            disable_tensor_cache=True)
    return HpuModelAdapter(*args, **kwargs)
def get_model(self) -> nn.Module:
    """Return the model, unwrapped if it is an HpuModelAdapter instance."""
    wrapped = self.model
    return wrapped.model if isinstance(wrapped, HpuModelAdapter) else wrapped
def _use_graphs(self, batch_size, seq_len, is_prompt):
    """Decide whether to run this shape with HPU graphs.

    Never when eager execution is enforced; always when warmup is skipped;
    otherwise only for buckets recorded in ``self.graphed_buckets``.
    """
    if self.enforce_eager:
        return False
    bucket = (batch_size, seq_len, is_prompt)
    return True if self.skip_warmup else bucket in self.graphed_buckets
def _is_valid_bucket(self, bucket):
    """Return True if the product of the first two bucket entries
    (presumably batch size and sequence length) fits the token budget."""
    first, second = bucket[0], bucket[1]
    return first * second <= self.max_num_batched_tokens
def _prepare_prompt(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> PreparePromptMetadata:
input_tokens: List[List[int]] = []
input_positions: List[List[int]] = []
slot_mapping: List[List[int]] = []
lora_index_mapping: List[List[int]] = []
lora_prompt_mapping: List[List[int]] = []
lora_requests: Set[LoRARequest] = set()
seq_lens: List[int] = []
context_lens: List[int] = []
query_lens: List[int] = []
prefix_block_tables: List[List[int]] = []
multi_modal_kwargs_list: List[MultiModalKwargs] = []
multi_modal_placeholder_maps: Dict[
str, MultiModalPlaceholderMap] = collections.defaultdict(
MultiModalPlaceholderMap)
if len(seq_group_metadata_list) == 0:
return PreparePromptMetadata.empty()
for seq_group_metadata in seq_group_metadata_list:
assert seq_group_metadata.is_prompt
seq_ids = list(seq_group_metadata.seq_data.keys())
assert len(seq_ids) == 1
seq_id = seq_ids[0]
computed_block_nums = seq_group_metadata.computed_block_nums
if (self.scheduler_config is not None
and self.scheduler_config.chunked_prefill_enabled
and not (computed_block_nums is None
or computed_block_nums == [])):
raise RuntimeError(
"chunked prefill cannot be used with prefix caching "
"now.")
token_chunk_size = seq_group_metadata.token_chunk_size
seq_data = seq_group_metadata.seq_data[seq_id]
context_len = seq_data.get_num_computed_tokens()
# We should use get_len here because in case of preemption
# it contains output tokens.
seq_len = min(seq_data.get_len(), context_len + token_chunk_size)
prompt_tokens = seq_data.get_token_ids()[context_len:seq_len]
seq_lens.append(seq_len)
# NOTE: This only works for oooooooxxx style attention.
if computed_block_nums is not None and len(
computed_block_nums) > 0 and self.sliding_window is None:
# Prefix is not supported with sliding_window
context_len = len(computed_block_nums) * self.block_size
prompt_tokens = prompt_tokens[context_len:]
prefix_block_tables.append(computed_block_nums)
elif self.scheduler_config.chunked_prefill_enabled:
if seq_group_metadata.block_tables is not None:
# Prefill has chunked before.
block_table = seq_group_metadata.block_tables[seq_id]
prefix_block_tables.append(block_table)
else:
# The first prefill.
prefix_block_tables.append([])
else:
prefix_block_tables.append([])
# Right now, prefill start is always 0. However, this
# assumption can be changed once chunked prefill is introduced.
assert context_len == 0
# actual prompt lens
context_lens.append(context_len)
query_lens.append(seq_len - context_len)
input_tokens.append(prompt_tokens)
# NOTE(woosuk): Here we assume that the first token in the prompt
# is always the first token in the sequence.
input_positions.append(list(range(context_len, seq_len)))
if seq_group_metadata.multi_modal_data:
positions = input_positions[0]
mm_data, placeholder_maps = MultiModalPlaceholderMap \
.from_seq_group(seq_group_metadata,
range(positions[0], positions[0] + len(positions)))
if self.mm_registry.has_processor(self.model_config):
mm_kwargs = mm_data
else:
mm_kwargs = self.multi_modal_input_mapper(
mm_data,
seq_group_metadata.mm_processor_kwargs,
)
multi_modal_kwargs_list.append(mm_kwargs)
for modality, placeholder_map in placeholder_maps.items():
multi_modal_placeholder_maps[modality].extend(
placeholder_map)
if seq_group_metadata.block_tables is None:
# During memory profiling, the block tables are not initialized
# yet. In this case, we just use a dummy slot mapping.
slot_mapping.append([_PAD_SLOT_ID] * seq_len)
continue
# Compute the slot mapping.
slot_mapping.append([])
block_table = seq_group_metadata.block_tables[seq_id]
# Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
# where start_idx is max(0, seq_len - sliding_window).
# For example, if the prompt len is 10, sliding window is 8, and
# block size is 4, the first two tokens are masked and the slot
# mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
start_idx = 0
if self.sliding_window is not None:
assert context_len == 0, (
"Prefix caching is currently not supported with "
"sliding window attention")
start_idx = max(0, seq_len - self.sliding_window)
for i in range(context_len, seq_len):
if i < start_idx:
slot_mapping[-1].append(_PAD_SLOT_ID)
continue
block_number = block_table[i // self.block_size]
block_offset = i % self.block_size
slot = block_number * self.block_size + block_offset
slot_mapping[-1].append(slot)
max_query_len = max(query_lens)
real_num_seqs = len(query_lens)
assert max_query_len > 0
max_prompt_len = max(
self.bucketing_ctx.get_padded_prompt_seq_len(max(seq_lens)),
self.block_size)
lora_ids: List[int] = []
for seq_group_metadata, context_len in zip(seq_group_metadata_list,
context_lens):
lora_id = seq_group_metadata.lora_int_id
lora_ids.append(lora_id)
if lora_id > 0:
lora_requests.add(seq_group_metadata.lora_request)
lora_index_mapping += [lora_id] * max_prompt_len
lora_prompt_mapping.extend(
[lora_id] *
(max_prompt_len
if seq_group_metadata.sampling_params.prompt_logprobs else 1))
if any(context_lens):
assert not self.scheduler_config.chunked_prefill_enabled
# prefix caching
max_num_block = max(len(bt) for bt in prefix_block_tables)
prefix_block_list = list(
itertools.chain.from_iterable(
bt if len(bt) == max_num_block else bt +
([_PAD_BLOCK_ID] * (max_num_block - len(bt)))
for bt in prefix_block_tables))
# TODO: pad to proper len
pad_len = len(prefix_block_list)
prefix_block_list = pad_list(prefix_block_list, pad_len,
_PAD_BLOCK_ID)