forked from huggingface/optimum-habana
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodeling_llama.py
executable file
·1565 lines (1375 loc) · 68.4 KB
/
modeling_llama.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import copy
import math
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch.distributed.distributed_c10d import ProcessGroup
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaDecoderLayer,
LlamaForCausalLM,
LlamaMLP,
LlamaModel,
LlamaRMSNorm,
apply_rotary_pos_emb,
logger,
)
from transformers.utils import is_torchdynamo_compiling
from .... import distributed, parallel_state
from ....distributed.strategy import DistributedStrategy, NoOpStrategy
from ....distributed.tensorparallel import (
reduce_from_tensor_model_parallel_region,
)
from ....distributed.tp import TPModule
from ...modeling_attn_mask_utils import (
_gaudi_prepare_4d_causal_attention_mask,
)
from ..modeling_all_models import Matmul, apply_customized_rope_module
from .configuration_llama import LlamaConfig
# Optional Habana fused kernels. Each kernel is imported independently so that a
# missing kernel raises no ImportError from these guarded imports; a message is
# printed instead.

# Fused rotary position embedding kernel; `has_fused_rope` records whether the import worked.
try:
    from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE  # noqa

    has_fused_rope = True
except ImportError:
    has_fused_rope = False
    print("Not using HPU fused kernel for apply_rotary_pos_emb")

# Fused RMSNorm kernel; `has_fused_rms_norm` is checked by `gaudi_llama_rmsnorm_forward`.
try:
    from habana_frameworks.torch.hpex.normalization import FusedRMSNorm as FusedRMSNorm

    has_fused_rms_norm = True
except ImportError:
    has_fused_rms_norm = False
    print("Not using HPU fused kernel for RMSNorm")

# Fused scaled dot-product attention kernel; `FusedSDPA` is set to None (instead of a
# boolean flag) when the kernel is unavailable, and callers test it against None.
try:
    from habana_frameworks.torch.hpex.kernels import FusedSDPA
except ImportError:
    print("Not using HPU fused scaled dot-product attention kernel.")
    FusedSDPA = None
import habana_frameworks.torch.core as htcore
def gaudi_llama_rmsnorm_forward(self, hidden_states):
    """
    Copied from LlamaRMSNorm.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
    The only differences are:
        - override RMSNorm with Habana fused RMSNorm

    Args:
        self: object exposing `weight` and `variance_epsilon`.
        hidden_states: tensor to normalize over its last dimension.

    Returns:
        The normalized tensor. On the fused path the result is cast back to the
        dtype of `hidden_states`; on the reference path the dtype is whatever
        `self.weight * hidden_states` promotes to.
    """
    # The device check comes first so the kernel flag is only consulted on HPU.
    use_fused_kernel = hidden_states.device.type == "hpu" and has_fused_rms_norm

    if not use_fused_kernel:
        # Reference path: compute the statistics in fp32, then cast back before scaling.
        caller_dtype = hidden_states.dtype
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        normalized = upcast * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalized.to(caller_dtype)

    weight_dtype = self.weight.dtype
    caller_dtype = hidden_states.dtype
    if caller_dtype == weight_dtype:
        return FusedRMSNorm.apply(hidden_states, self.weight, self.variance_epsilon)

    # mixed dtypes are not good for FusedRMSNorm, both inputs need to have same dtype
    fused_out = FusedRMSNorm.apply(hidden_states.to(weight_dtype), self.weight, self.variance_epsilon)
    return fused_out.to(caller_dtype)
class GaudiLlamaRotaryEmbedding(torch.nn.Module):
    """Rotary position embedding that keeps precomputed cos/sin tables.

    The tables (`_cos_cached` / `_sin_cached`) are built once at construction for
    `max_seq_len_cached` positions and rebuilt by `forward` when a longer sequence
    is requested or when a dynamic RoPE type changes `inv_freq`.
    """

    def __init__(
        self,
        dim=None,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
        rope_type="default",
        config: Optional[LlamaConfig] = None,
    ):
        """Initialize the inverse frequencies and the cos/sin tables.

        Args:
            dim, max_position_embeddings, base, scaling_factor, rope_type: legacy
                parameters, only used when `config` is None; they are forwarded to
                the RoPE init function through `self.rope_kwargs`.
            device: device handed to the RoPE init function.
            config: model config; when given, the RoPE type and the maximum
                position count are read from it and the legacy arguments are ignored.
        """
        super().__init__()

        # TODO (joao): remove the `if` below, only used for BC
        self.rope_kwargs = {}
        if config is None:
            logger.warning_once(
                "`LlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the "
                "`config` argument. All other arguments will be removed in v4.46"
            )
            self.rope_kwargs = {
                "rope_type": rope_type,
                "factor": scaling_factor,
                "dim": dim,
                "base": base,
                "max_position_embeddings": max_position_embeddings,
            }
            self.rope_type = rope_type
            self.max_seq_len_cached = max_position_embeddings
            self.original_max_seq_len = max_position_embeddings
        else:
            # BC: "rope_type" was originally "type"
            if config.rope_scaling is not None:
                self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
            else:
                self.rope_type = "default"
            self.max_seq_len_cached = config.max_position_embeddings
            self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Kept so that dynamic RoPE can fall back to the unscaled frequencies.
        self.original_inv_freq = self.inv_freq

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """Rebuild the cos/sin tables for positions `0..seq_len-1` from `self.inv_freq`.

        Also records `seq_len` as the new `max_seq_len_cached`.
        """
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

        freqs = torch.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("_cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("_sin_cached", emb.sin().to(dtype), persistent=False)

    def _dynamic_frequency_update(self, seq_len, device):
        """
        dynamic RoPE layers should recompute `inv_freq` in the following situations:
        1 - growing beyond the cached sequence length (allow scaling)
        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)

        This only updates `inv_freq` and `max_seq_len_cached`; the cos/sin tables
        are refreshed by `forward`.
        """
        # seq_len = torch.max(position_ids) + 1
        if seq_len > self.max_seq_len_cached:  # growth
            inv_freq, self.attention_scaling = self.rope_init_fn(
                self.config, device, seq_len=seq_len, **self.rope_kwargs
            )
            self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilation
            self.max_seq_len_cached = seq_len

        if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:  # reset
            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
            self.max_seq_len_cached = self.original_max_seq_len

    @torch.no_grad()
    def forward(self, x, seq_len=None):
        """Return the `(cos, sin)` tables for the first `seq_len` positions.

        Args:
            x: tensor used only for its device and dtype
                (noted in the original as [bs, num_attention_heads, seq_len, head_size]).
            seq_len: number of positions required; must be an integer (the `None`
                default cannot be compared against the cached length).

        Returns:
            Tuple `(cos, sin)` cast to `x.dtype`, multiplied by
            `attention_scaling` when that factor is not 1.0.
        """
        # x: [bs, num_attention_heads, seq_len, head_size]
        if "dynamic" in self.rope_type:
            inv_freq_before = self.inv_freq
            self._dynamic_frequency_update(seq_len, device=x.device)
            if self.inv_freq is not inv_freq_before:
                # The update above already moved `max_seq_len_cached`, so the length
                # check below would not fire; rebuild explicitly, otherwise the
                # tables stay stale (and too short after a growth).
                self._set_cos_sin_cache(seq_len=self.max_seq_len_cached, device=x.device, dtype=x.dtype)

        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        if self.attention_scaling == 1.0:
            return (
                self._cos_cached[:seq_len].to(dtype=x.dtype),
                self._sin_cached[:seq_len].to(dtype=x.dtype),
            )
        else:
            return (
                self._cos_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling,
                self._sin_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling,
            )
class GaudiLlamaLinearScalingRotaryEmbedding(GaudiLlamaRotaryEmbedding):
    """Deprecated alias for `GaudiLlamaRotaryEmbedding` with `rope_type="linear"`.

    No `_set_cos_sin_cache` override is needed. The previous override divided the
    positions by `self.scaling_factor`, an attribute the base class never sets, so
    construction failed with `AttributeError` (the base `__init__` builds the
    cache). In the transformers `modeling_rope_utils` implementation the "linear"
    init function already divides `inv_freq` by the factor, so the inherited cache
    builder yields the linearly scaled tables; dividing the positions as well
    would apply the factor twice.
    """

    def __init__(self, *args, **kwargs):
        """Force `rope_type="linear"` and defer everything else to the base class."""
        logger.warning_once(
            "`LlamaLinearScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
            "`LlamaRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)."
        )
        kwargs["rope_type"] = "linear"
        super().__init__(*args, **kwargs)
class GaudiLlamaDynamicNTKScalingRotaryEmbedding(GaudiLlamaRotaryEmbedding):
    """Deprecated alias for `GaudiLlamaRotaryEmbedding` with `rope_type="dynamic"`."""

    def __init__(self, *args, **kwargs):
        """Force `rope_type="dynamic"` and defer everything else to the base class."""
        logger.warning_once(
            "`LlamaDynamicNTKScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
            "`LlamaRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to "
            "__init__)."
        )
        kwargs["rope_type"] = "dynamic"
        super().__init__(*args, **kwargs)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """Rebuild the cos/sin tables, rescaling `inv_freq` for long sequences.

        The previous implementation read `self.max_position_embeddings`,
        `self.base`, `self.scaling_factor` and `self.dim`, none of which the base
        class defines, so construction raised `AttributeError`. The rescaling is
        now delegated to `self.rope_init_fn`, called exactly the way the base
        class calls it in `_dynamic_frequency_update`.
        """
        if seq_len > self.original_max_seq_len:
            # Beyond the original context window: let the RoPE init function
            # recompute the frequencies for the requested length.
            inv_freq, self.attention_scaling = self.rope_init_fn(
                self.config, device, seq_len=seq_len, **self.rope_kwargs
            )
            self.register_buffer("inv_freq", inv_freq, persistent=False)

        super()._set_cos_sin_cache(seq_len, device, dtype)
class GaudiLlamaMLP(LlamaMLP):
    """Gated MLP with the forward pass split into pre / all-reduce / post stages."""

    def __init__(self, config):
        """Create the gate, up and down projections from `config`.

        `LlamaMLP.__init__` is deliberately skipped (`super(LlamaMLP, self)`); the
        layers are created here instead.
        """
        super(LlamaMLP, self).__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        use_bias = config.mlp_bias
        narrow, wide = self.hidden_size, self.intermediate_size
        self.gate_proj = torch.nn.Linear(narrow, wide, bias=use_bias)
        self.up_proj = torch.nn.Linear(narrow, wide, bias=use_bias)
        self.down_proj = torch.nn.Linear(wide, narrow, bias=use_bias)
        self.act_fn = ACT2FN[config.hidden_act]

    def pre_mlp_forward(self, x):
        """Compute `down_proj(act_fn(gate_proj(x)) * up_proj(x))`.

        When `config.pretraining_tp > 1` the same product is evaluated shard by
        shard on slices of the projection weights (weights only) and the partial
        down projections are summed.
        """
        tp_degree = self.config.pretraining_tp

        if not tp_degree > 1:
            gated = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
            return self.down_proj(gated)

        shard_width = self.intermediate_size // tp_degree
        gate_shards = self.gate_proj.weight.split(shard_width, dim=0)
        up_shards = self.up_proj.weight.split(shard_width, dim=0)
        down_shards = self.down_proj.weight.split(shard_width, dim=1)

        gate_out = torch.cat([F.linear(x, gate_shards[rank]) for rank in range(tp_degree)], dim=-1)
        up_out = torch.cat([F.linear(x, up_shards[rank]) for rank in range(tp_degree)], dim=-1)

        hidden_shards = (self.act_fn(gate_out) * up_out).split(shard_width, dim=2)
        partial_outputs = [F.linear(hidden_shards[rank], down_shards[rank]) for rank in range(tp_degree)]
        return sum(partial_outputs)

    def mlp_all_reduce(self, x):
        """Call `down_proj.all_reduce(x)` when the down projection provides it."""
        if not hasattr(self.down_proj, "all_reduce"):
            return
        self.down_proj.all_reduce(x)

    def post_mlp_forward(self, x):
        """Apply `down_proj.post_all_reduce` when available and `pretraining_tp <= 1`."""
        if self.config.pretraining_tp > 1:
            return x
        if not hasattr(self.down_proj, "post_all_reduce"):
            return x
        return self.down_proj.post_all_reduce(x)
class TPGaudiLlamaMLP(GaudiLlamaMLP, TPModule):
    """Tensor-parallel `GaudiLlamaMLP` whose intermediate size is divided by the world size."""

    def __init__(
        self,
        config,
        group: Optional[ProcessGroup] = None,
    ):
        """Build the per-rank MLP.

        Args:
            config: config providing `hidden_grow_factor` and `hidden_size`; a deep
                copy with a reduced `intermediate_size` is stored on the module.
            group: process group used to look up rank and world size.
        """
        assert torch.distributed.is_initialized()
        rank, world_size = distributed.rank_and_world(group)

        grow_factor = config.hidden_grow_factor
        model_width = config.hidden_size
        hidden_dim = int(grow_factor * model_width)
        assert hidden_dim % world_size == 0, "Hidden dim must be divisible by world size"

        # The caller's config is left untouched; only the private copy is resized.
        sharded_config = copy.deepcopy(config)
        sharded_config.intermediate_size = int((grow_factor / world_size) * model_width)
        self.config = sharded_config

        GaudiLlamaMLP.__init__(self, self.config)
        self.setup_tp(rank, world_size)

    def colwise_param_names(self) -> List[str]:
        """Names of the projections reported as column-wise parameters."""
        names = ["up_proj", "gate_proj"]
        return names

    def rowwise_param_names(self) -> List[str]:
        """Names of the projections reported as row-wise parameters."""
        names = ["down_proj"]
        return names

    @staticmethod
    def import_module(glu: GaudiLlamaMLP, group: ProcessGroup) -> "TPGaudiLlamaMLP":
        """Create a tensor-parallel MLP configured from an existing `GaudiLlamaMLP`.

        `hidden_grow_factor` is derived from the source module's
        `intermediate_size / hidden_size` on a copy of its config.
        """
        tp_config = copy.deepcopy(glu.config)
        tp_config.hidden_grow_factor = glu.config.intermediate_size / glu.config.hidden_size
        return TPGaudiLlamaMLP(config=tp_config, group=group)

    def pre_mlp_forward(self, x):
        """Run the local MLP, then reduce the partial result across the tensor-parallel region."""
        local_partial = GaudiLlamaMLP.pre_mlp_forward(self, x)
        return reduce_from_tensor_model_parallel_region(local_partial)
def gaudi_llama_repeat_kv(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    attention_mask: torch.Tensor,
    n_rep: int,
):
    """
    Copied from repeat_kv: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
    The only differences are:
        - Append num_key_value_heads == 1 check as kv states can be broadcasted during matmuls so need to expand and reshape them.
        - Add new args query_states, key_states, value_states and attention_mask and update the logic for expansion.
    The query states go from (batch, num_heads, seqlen, head_dim) to (batch, num_key_value_heads, n_rep, seqlen, head_dim)
    The key/value states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_key_value_heads, 1, seqlen, head_dim)

    Returns:
        `(query_states, key_states, value_states, attention_mask)`. The inputs are
        returned unchanged when `n_rep == 1` or there is a single key/value head;
        otherwise they are reshaped as described above and a non-None mask gets an
        extra dimension at index 1.
    """
    bsz, kv_heads, kv_len, head_dim = key_states.shape

    # Nothing to regroup: plain multi-head attention, or a single KV head that broadcasts.
    if n_rep == 1 or kv_heads == 1:
        return query_states, key_states, value_states, attention_mask

    grouped_kv_shape = (bsz, kv_heads, 1, kv_len, head_dim)
    grouped_keys = key_states.reshape(grouped_kv_shape)
    grouped_values = value_states.reshape(grouped_kv_shape)

    bsz, _, q_len, head_dim = query_states.shape
    grouped_queries = query_states.reshape((bsz, kv_heads, n_rep, q_len, head_dim))

    # Add groups dim and set to 1
    grouped_mask = None if attention_mask is None else attention_mask.unsqueeze(1)

    return grouped_queries, grouped_keys, grouped_values, grouped_mask
# FusedScaledDotProductAttention
class ModuleFusedSDPA(torch.nn.Module):
    """`torch.nn.Module` wrapper that forwards its inputs to a fused SDPA kernel's `apply`."""

    def __init__(self, fusedSDPA, scale, attention_dropout, enable_recompute, flash_attention_fp8):
        """Store the kernel object and the attention settings as attributes.

        Only `fusedSDPA` is used by `forward`; the other values are merely stored.
        """
        super().__init__()
        self._hpu_kernel_fsdpa = fusedSDPA
        self.scale = scale
        self.attention_dropout = attention_dropout
        self.enable_recompute = enable_recompute
        self.flash_attention_fp8 = flash_attention_fp8

    def forward(
        self,
        query,
        key,
        value,
        attn_mask,
        dropout_p,
        is_causal,
        scale,
        softmax_mode,
        recompute_mode,
        valid_sequence_lengths,
        padding_side="left",
    ):
        """Call the kernel's `apply` with all arguments, in order, and return its result."""
        kernel_args = (
            query,
            key,
            value,
            attn_mask,
            dropout_p,
            is_causal,
            scale,
            softmax_mode,
            recompute_mode,
            valid_sequence_lengths,
            padding_side,
        )
        return self._hpu_kernel_fsdpa.apply(*kernel_args)
class KVCache(torch.nn.Module):
    """Holds one cache tensor (`self.cache`) and writes new states into it."""

    def __init__(self):
        super().__init__()
        self.cache = None
        self.inp_seq_len = -1

    def allocate(self, inp_seq_len, dtype, device, shape):
        """Create a zero-filled cache of `shape`, or zero the existing one if the shape matches.

        When the existing cache is reused, `inp_seq_len` must equal the value
        recorded at allocation time.
        """
        must_create = self.cache is None or self.cache.shape != shape
        if must_create:
            self.inp_seq_len = inp_seq_len
            self.cache = torch.zeros(shape, dtype=dtype, device=device)
            return

        assert (
            self.inp_seq_len == inp_seq_len
        ), f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}"
        self.cache.fill_(0)

    @staticmethod
    def update(prev, cur, dim, idx, inp_seq_len):
        """Merge `cur` into `prev` and return the tensor to attend over.

        Cases, in order:
          * same shape: `cur` overwrites `prev`; `cur` is returned;
          * `idx` is None: `prev` and `cur` are concatenated along `dim`;
          * `idx` given and `cur` has more than one (but at most `prev.shape[2]`)
            entries along dim 2: `cur` is copied into the first `inp_seq_len`
            slots of `prev`; `cur` is returned;
          * otherwise: `cur` is written at index `idx - 1` along `dim`; `prev` is returned.
        """
        if prev.shape == cur.shape:
            prev.copy_(cur)
            return cur

        if idx is None:
            return torch.cat((prev, cur), dim=dim)

        new_len = cur.shape[2]
        if 1 < new_len <= prev.shape[2]:
            # Initialize
            prev[:, :, :inp_seq_len, :].copy_(cur)
            return cur

        prev.index_copy_(dim, idx - 1, cur)
        return prev

    def get_shape(self):
        """Return the cache shape, or None when nothing has been allocated."""
        return None if self.cache is None else self.cache.shape

    def forward(self, cur, dim, idx):
        """Apply `update` to this module's own cache."""
        return self.update(self.cache, cur, dim, idx, self.inp_seq_len)
def GaudiDistributedAttention(fused_scaled_dot_product_attention, fused_scaled_dot_product_attention_distributed):
    """Pick the attention callable to use.

    Returns the distributed variant when sequence parallelism is initialized with
    a world size greater than 1, and the local variant otherwise.
    """
    sequence_parallel_active = (
        parallel_state.sequence_parallel_is_initialized() and parallel_state.get_sequence_parallel_world_size() > 1
    )
    if not sequence_parallel_active:
        return fused_scaled_dot_product_attention
    return fused_scaled_dot_product_attention_distributed
class GaudiLlamaAttention(LlamaAttention):
def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
super().__init__(config, layer_idx)
self.matmul_qk = Matmul()
self.matmul_av = Matmul()
self.k_cache = KVCache()
self.v_cache = KVCache()
if hasattr(config, "fused_qkv") and config.fused_qkv:
self.num_heads = config.num_attention_heads
self.head_dim = config.hidden_size // self.num_heads
self.dim1 = self.num_heads * self.head_dim
self.dim2 = config.num_key_value_heads * self.head_dim
self.qkv_proj = torch.nn.Linear(
self.hidden_size,
self.dim1 + 2 * self.dim2,
bias=config.attention_bias,
)
self.q_proj = None
self.k_proj = None
self.v_proj = None
self.inp_seq_len = -1
self.norm_factor = 1.0 / math.sqrt(self.head_dim)
self.fused_scaled_dot_product_attention = (
ModuleFusedSDPA(
FusedSDPA,
scale=self.norm_factor,
attention_dropout=self.attention_dropout,
enable_recompute=False,
flash_attention_fp8=getattr(config, "flash_attention_fp8", False),
)
if FusedSDPA
else None
)
# https://github.com/microsoft/DeepSpeed/issues/4359
# for all2all comm, Distributed Attention cares about sequence (s) and number of heads (h) dimensions. In HPU, they are at 1 and 2 indices
self.fused_scaled_dot_product_attention_distributed = None
if parallel_state.sequence_parallel_is_initialized() and parallel_state.get_sequence_parallel_world_size() > 1:
from deepspeed.sequence.layer import DistributedAttention
self.fused_scaled_dot_product_attention_distributed = DistributedAttention(
self.fused_scaled_dot_product_attention, parallel_state.get_sequence_parallel_group(), 1, 2
)
def get_k_proj_weight(self):
"""4bit quantization in GPTQ replaces the k_proj.weight with qweight."""
if hasattr(self.k_proj, "qweight"):
return self.k_proj.qweight
return self.k_proj.weight
def get_k_proj_weight_dtype(self):
"""4bit quantization in GPTQ replaces the k_proj.weight with qweight.
Scales tensor gets the weight dtype."""
if hasattr(self.k_proj, "qweight"):
return self.k_proj.scales.dtype
return self.k_proj.weight.dtype
def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len):
cache_shape = (batch_size, self.num_key_value_heads, max_seq_len, self.head_dim)
device = self.get_k_proj_weight().device
dtype = self.config.torch_dtype
self.k_cache.allocate(inp_seq_len, dtype, device, cache_shape)
self.v_cache.allocate(inp_seq_len, dtype, device, cache_shape)
def update_sincos_cache(self, seq_len):
# Call rotary emb forward() to update cos/sin cache when infering more than self.max_position_embeddings
# This helps in avoiding creation of these caches during actual model forward pass and
# reduce memory consumption and improve performance.
if seq_len > self.max_position_embeddings:
self.max_position_embeddings = seq_len
_, _ = self.rotary_emb(self.get_k_proj_weight(), seq_len=seq_len)
def reorder(self, tensor, beam_idx, dim_a, dim_b):
updated = tensor.index_select(0, beam_idx)
tensor.copy_(updated)
def reorder_kv_cache(self, beam_idx: torch.LongTensor):
if self.k_cache.cache is None:
return (None, None)
head_dim = self.k_cache.cache.size(-1)
seq_length = self.k_cache.cache.size(-2)
self.reorder(self.k_cache.cache, beam_idx, seq_length, head_dim)
self.reorder(self.v_cache.cache, beam_idx, seq_length, head_dim)
return (self.k_cache.cache.shape, self.v_cache.cache.shape)
def pre_attn_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
token_idx: Optional[torch.Tensor] = None,
attn_softmax_bf16: Optional[bool] = False,
reuse_cache: Optional[bool] = False,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
flash_attention_fast_softmax: Optional[bool] = False,
valid_sequence_lengths: Optional[torch.Tensor] = None,
cache_idx: int = None,
num_virtual_tokens: int = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""
Copied from LlamaAttention.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
The only differences are:
- add new args token_idx
- optimize KV cache
- add new args attn_softmax_bf16
- add new args reuse_cache
- add new args use_flash_attention
- add new arg flash_attention_recompute
- add new arg flash_attention_causal_mask
- add new arg flash_attention_fast_softmax
- add new arg num_virtual_tokens
"""
bsz, q_len, _ = hidden_states.size()
if self.config.pretraining_tp > 1:
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
query_slices = self.q_proj.weight.split(
(self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
)
key_slices = self.get_k_proj_weight().split(key_value_slicing, dim=0)
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
query_states = torch.cat(query_states, dim=-1)
key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
key_states = torch.cat(key_states, dim=-1)
value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
value_states = torch.cat(value_states, dim=-1)
else:
if hasattr(self.config, "fused_qkv") and self.config.fused_qkv:
qkv_states = self.qkv_proj(hidden_states)
query_states, key_states, value_states = torch.split(
qkv_states, [self.dim1, self.dim2, self.dim2], dim=-1
)
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
# TODO: update when auto mp params is enabled in DeepSpeed (cf. https://github.com/HabanaAI/DeepSpeed/blob/94309c7b5dfc1a69858f5c9f25737b2f81a332a5/deepspeed/module_inject/replace_module.py#L440)
key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
if token_idx is None:
if hasattr(past_key_value, "get_usable_length"):
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
else:
kv_seq_len += past_key_value[0].shape[-2]
else:
if reuse_cache and not isinstance(past_key_value[0], torch.Tensor):
kv_seq_len = past_key_value[0][-2]
else:
if num_virtual_tokens is not None and num_virtual_tokens == past_key_value[0].shape[-2]:
kv_seq_len = past_key_value[0].shape[-2] + kv_seq_len
else:
kv_seq_len = past_key_value[0].shape[-2]
# TODO: the following section cause torch.compile performance issue with graph recompilation
# as we are not using position_embeddings, disable it for now
# if position_embeddings is None:
# logger.warning_once(
# "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
# "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
# "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
# "removed and `position_embeddings` will be mandatory."
# )
# cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
# else:
# cos, sin = position_embeddings
seq_len = kv_seq_len
if parallel_state.sequence_parallel_is_initialized():
seq_len = kv_seq_len * parallel_state.get_sequence_parallel_world_size()
cos, sin = self.rotary_emb(value_states, seq_len=seq_len)
# If sequence parallel in enabled, position_ids should be based on which part of the sequence is present in the rank
# As we divide the inputs based on ranks, position_ids are generated to suit that part of the sequence
if parallel_state.sequence_parallel_is_initialized() and parallel_state.get_sequence_parallel_rank() > 0:
position_ids = torch.arange(
kv_seq_len * parallel_state.get_sequence_parallel_rank(),
kv_seq_len * (parallel_state.get_sequence_parallel_rank() + 1),
dtype=torch.long,
device=query_states.device,
)
position_ids = position_ids.unsqueeze(0)
query_states, key_states = apply_customized_rope(query_states, key_states, cos, sin, position_ids)
if use_cache:
# reuse k, v, self_attention
if reuse_cache:
if past_key_value is not None and isinstance(past_key_value[0], torch.Tensor):
# prefix tuning case. attach past_key_value to generate first token.
key_states = torch.cat((past_key_value[0], key_states), -2)
value_states = torch.cat((past_key_value[1], value_states), -2)
key_states = self.k_cache(key_states, 2, token_idx)
value_states = self.v_cache(value_states, 2, token_idx)
past_key_value = (self.k_cache.get_shape(), self.v_cache.get_shape())
else:
if past_key_value is None:
past_key = torch.zeros(
key_states.shape, dtype=self.get_k_proj_weight_dtype(), device=key_states.device
)
past_value = torch.zeros(
key_states.shape, dtype=self.get_k_proj_weight_dtype(), device=key_states.device
)
# Return list instead of tuple
past_key_value = [past_key, past_value]
if (
token_idx is not None
and num_virtual_tokens is not None
and num_virtual_tokens == past_key_value[0].shape[-2]
):
# prefix tuning case. attach past_key_value to generate first token.
key_states = torch.cat((past_key_value[0], key_states), -2)
value_states = torch.cat((past_key_value[1], value_states), -2)
past_key_value = (key_states, value_states)
else:
key_states = self.k_cache.update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len)
value_states = self.v_cache.update(past_key_value[1], value_states, 2, token_idx, self.inp_seq_len)
if token_idx is None:
past_key_value = (key_states, value_states)
if cache_idx is not None and q_len == 1:
key_states = key_states[:, :, :cache_idx, :]
value_states = value_states[:, :, :cache_idx, :]
if attention_mask is not None:
attention_mask = attention_mask[:, :, :, :cache_idx]
kv_seq_len = key_states.shape[-2]
else:
past_key_value = None
fused_scaled_dot_product_attention = GaudiDistributedAttention(
self.fused_scaled_dot_product_attention, self.fused_scaled_dot_product_attention_distributed
)
if use_flash_attention and FusedSDPA is not None:
if q_len == 1:
# next token
attn_output = fused_scaled_dot_product_attention(
query_states,
key_states,
value_states,
attention_mask,
0.0,
False,
None,
"None",
False,
None,
"None",
)
else:
# first token
softmax_mode = "fast" if flash_attention_fast_softmax else "None"
if flash_attention_causal_mask:
# causal masking on first token requires inputs to be of the same length
attn_output = fused_scaled_dot_product_attention(
query_states,
key_states,
value_states,
None,
0.0,
True,
None,
softmax_mode,
flash_attention_recompute,
valid_sequence_lengths,
"left",
)
else:
attn_output = fused_scaled_dot_product_attention(
query_states,
key_states,
value_states,
attention_mask,
0.0,
False,
None,
softmax_mode,
flash_attention_recompute,
None,
"None",
)
else:
query_states, key_states, value_states, attention_mask = gaudi_llama_repeat_kv(
query_states, key_states, value_states, attention_mask, self.num_key_value_groups
)
attn_weights = self.matmul_qk(query_states, key_states.transpose(-2, -1)) * self.norm_factor
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask
if cache_position is not None:
causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
if attn_softmax_bf16:
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=query_states.dtype)
else:
# upcast attention to fp32
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
query_states.dtype
)
attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
attn_output = self.matmul_av(attn_weights, value_states)
attn_output = attn_output.reshape(bsz, -1, q_len, self.head_dim)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, -1)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
if not reuse_cache and token_idx is not None and cache_idx is not None and q_len == 1:
# Return only past key value shapes and not the tensors during decode phase (q len is 1)
# to avoid making past key values as persistent output tensors of HPU graphs.
past_key_value = (past_key_value[0].shape, past_key_value[1].shape)
return attn_output, attn_weights, past_key_value
def attention_all_reduce(self, attn_output):
if hasattr(self.o_proj, "all_reduce"):
self.o_proj.all_reduce(attn_output)
def post_attn_forward(self, attn_output):
    """Run the output projection's ``post_all_reduce`` hook, if it has one.

    Returns the hook's result when ``self.o_proj`` defines
    ``post_all_reduce``; otherwise ``attn_output`` is handed back untouched.
    """
    out_proj = self.o_proj
    return out_proj.post_all_reduce(attn_output) if hasattr(out_proj, "post_all_reduce") else attn_output
class TPGaudiLlamaAttention(GaudiLlamaAttention, TPModule):
    """``GaudiLlamaAttention`` whose projections are sized for one rank of a process group.

    The attention-head count (and the key/value-head count when it is greater
    than one) is divided by the world size, and the q/k/v/o ``Linear`` layers
    are re-created with the per-rank dimensions. The output of the parent's
    ``pre_attn_forward`` is passed through
    ``reduce_from_tensor_model_parallel_region``.
    """

    def __init__(
        self,
        config: LlamaConfig,
        layer_idx: Optional[int] = None,
        group: Optional[ProcessGroup] = None,
    ):
        """Build the per-rank attention module.

        Args:
            config: Model configuration. It is deep-copied before the head
                counts are divided, so the caller's object is not modified here.
            layer_idx: Passed through to ``GaudiLlamaAttention.__init__``.
            group: Process group handed to ``distributed.rank_and_world``.

        Raises:
            AssertionError: If ``torch.distributed`` is not initialized, or if
                ``config.num_attention_heads`` is not a multiple of the world size.
        """
        super().__init__(config, layer_idx)

        assert torch.distributed.is_initialized()
        rank, world_size = distributed.rank_and_world(group)
        assert config.num_attention_heads % world_size == 0, "The number of heads must be divisible by world size"

        # Keep a private copy of the config; the head counts on it are rewritten below.
        self.config = copy.deepcopy(config)
        # Remember the un-sharded KV head count; colwise_param_names() consults it.
        self.pre_tp_kvheads = config.num_key_value_heads
        GaudiLlamaAttention.__init__(self, self.config, layer_idx)

        shard_config = self.config
        shard_config.num_attention_heads = shard_config.num_attention_heads // world_size
        # A single KV head is left as-is; anything larger is split across ranks.
        kv_heads = shard_config.num_key_value_heads
        if kv_heads > 1:
            kv_heads = kv_heads // world_size
        shard_config.num_key_value_heads = kv_heads

        # head_dim is derived from the *original* config (un-divided head count).
        head_dim = config.hidden_size // config.num_attention_heads
        self.head_dim = head_dim
        self.hidden_size = shard_config.hidden_size // world_size
        self.num_heads = shard_config.num_attention_heads

        full_width = config.hidden_size
        with_bias = config.attention_bias
        local_q_width = shard_config.num_attention_heads * head_dim
        local_kv_width = shard_config.num_key_value_heads * head_dim

        # Creation order (q, k, v, o) is kept as in the original code.
        self.q_proj = torch.nn.Linear(full_width, local_q_width, bias=with_bias)
        self.k_proj = torch.nn.Linear(full_width, local_kv_width, bias=with_bias)
        self.v_proj = torch.nn.Linear(full_width, local_kv_width, bias=with_bias)
        self.o_proj = torch.nn.Linear(local_q_width, full_width, bias=with_bias)

        self.norm_factor = 1.0 / math.sqrt(head_dim)
        self.setup_tp(rank, world_size)

    def colwise_param_names(self) -> List[str]:
        """Return the projection names reported as column-wise.

        ``q_proj`` is always listed; ``k_proj`` and ``v_proj`` are added unless
        the original config had exactly one key/value head.
        """
        kv_names = ["k_proj", "v_proj"] if self.pre_tp_kvheads != 1 else []
        return ["q_proj", *kv_names]

    def rowwise_param_names(self) -> List[str]:
        """Return the projection names reported as row-wise (only ``o_proj``)."""
        return ["o_proj"]

    @staticmethod
    def import_module(mha: GaudiLlamaAttention, layer_idx, group: ProcessGroup) -> "TPGaudiLlamaAttention":
        """Create a ``TPGaudiLlamaAttention`` from ``mha.config`` for the given layer and group.

        Only the config of ``mha`` is read here; no weights are copied by this method.
        """
        return TPGaudiLlamaAttention(config=mha.config, layer_idx=layer_idx, group=group)

    def pre_attn_forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.45
        token_idx: Optional[torch.Tensor] = None,
        attn_softmax_bf16: Optional[bool] = False,
        reuse_cache: Optional[bool] = False,
        use_flash_attention: Optional[bool] = False,
        flash_attention_recompute: Optional[bool] = False,
        flash_attention_causal_mask: Optional[bool] = False,
        flash_attention_fast_softmax: Optional[bool] = False,
        cache_idx: int = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Run the parent attention forward, then reduce its output.

        Every argument is forwarded unchanged to
        ``GaudiLlamaAttention.pre_attn_forward``. The first element of its
        result is passed through ``reduce_from_tensor_model_parallel_region``;
        the attention weights and the present key/value are returned as received.
        """
        attn_output, attn_weights, present_key_value = GaudiLlamaAttention.pre_attn_forward(
            self,
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            token_idx=token_idx,
            attn_softmax_bf16=attn_softmax_bf16,
            reuse_cache=reuse_cache,
            use_flash_attention=use_flash_attention,
            flash_attention_recompute=flash_attention_recompute,
            flash_attention_causal_mask=flash_attention_causal_mask,
            flash_attention_fast_softmax=flash_attention_fast_softmax,
            cache_idx=cache_idx,
            **kwargs,
        )
        return reduce_from_tensor_model_parallel_region(attn_output), attn_weights, present_key_value
class GaudiLlamaDecoderLayer(LlamaDecoderLayer):
def __init__(self, config: LlamaConfig, layer_idx: int):
    """Assemble the decoder layer from Gaudi-specific attention and MLP modules.

    ``LlamaDecoderLayer.__init__`` is deliberately bypassed: ``super`` is
    invoked starting *after* ``LlamaDecoderLayer`` in the MRO, and the
    sub-modules are then created here.
    """
    super(LlamaDecoderLayer, self).__init__()
    width = config.hidden_size
    norm_eps = config.rms_norm_eps

    self.hidden_size = width
    self.self_attn = GaudiLlamaAttention(config=config, layer_idx=layer_idx)
    self.mlp = GaudiLlamaMLP(config)
    self.input_layernorm = LlamaRMSNorm(width, eps=norm_eps)
    self.post_attention_layernorm = LlamaRMSNorm(width, eps=norm_eps)
def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len):
    """Forward a KV-cache allocation request to the attention sub-module.

    The three arguments are passed positionally, unchanged, to
    ``self.self_attn.allocate_kv_cache``. Nothing is returned.
    """
    attention = self.self_attn
    attention.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len)
def reorder_kv_cache(self, beam_idx: torch.LongTensor):
    """Delegate KV-cache reordering to the attention sub-module.

    Returns whatever ``self.self_attn.reorder_kv_cache(beam_idx)`` returns.
    """
    attention = self.self_attn
    reordered = attention.reorder_kv_cache(beam_idx)
    return reordered
def update_sincos_cache(self, seq_len):
    """Ask the attention sub-module to update its sin/cos cache for ``seq_len``.

    Pure delegation to ``self.self_attn.update_sincos_cache``; nothing is returned.
    """
    attention = self.self_attn
    attention.update_sincos_cache(seq_len)
def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: Optional[bool] = False,
    use_cache: Optional[bool] = False,
    cache_position: Optional[torch.LongTensor] = None,
    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
    token_idx: Optional[torch.Tensor] = None,
    attn_softmax_bf16: Optional[bool] = False,
    reuse_cache: Optional[bool] = False,
    use_flash_attention: Optional[bool] = False,
    flash_attention_recompute: Optional[bool] = False,
    flash_attention_causal_mask: Optional[bool] = False,
    flash_attention_fast_softmax: Optional[bool] = False,
    valid_sequence_lengths: Optional[torch.Tensor] = None,
    cache_idx: int = None,
    num_virtual_tokens: int = None,
    **kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
    """
    Adapted from LlamaDecoderLayer.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py

    Compared with the upstream signature, this version additionally accepts
    and forwards to ``self.pre_attn``:
    - token_idx
    - attn_softmax_bf16
    - reuse_cache
    - use_flash_attention
    - flash_attention_recompute
    - flash_attention_causal_mask
    - flash_attention_fast_softmax
    - valid_sequence_lengths
    - cache_idx
    - num_virtual_tokens

    The layer runs in stages: ``pre_attn`` -> attention ``attention_all_reduce``
    -> ``post_attn_pre_mlp`` -> MLP ``mlp_all_reduce`` -> ``post_mlp``.

    Returns a tuple whose first item is the layer output, followed by the
    attention weights when ``output_attentions`` is truthy and by the present
    key/value when ``use_cache`` is truthy (in that order).
    """
    skip_connection = hidden_states

    hidden_states, self_attn_weights, present_key_value = self.pre_attn(
        hidden_states=hidden_states,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_value=past_key_value,
        output_attentions=output_attentions,
        use_cache=use_cache,
        cache_position=cache_position,
        position_embeddings=position_embeddings,
        token_idx=token_idx,
        attn_softmax_bf16=attn_softmax_bf16,
        reuse_cache=reuse_cache,
        use_flash_attention=use_flash_attention,
        flash_attention_recompute=flash_attention_recompute,
        flash_attention_causal_mask=flash_attention_causal_mask,
        flash_attention_fast_softmax=flash_attention_fast_softmax,
        valid_sequence_lengths=valid_sequence_lengths,
        cache_idx=cache_idx,
        num_virtual_tokens=num_virtual_tokens,
        **kwargs,
    )

    # The all-reduce hooks return nothing that is used here; they are called
    # only for their effect on the tensor passed in.
    self.self_attn.attention_all_reduce(hidden_states)
    hidden_states, skip_connection = self.post_attn_pre_mlp(hidden_states, skip_connection)
    self.mlp.mlp_all_reduce(hidden_states)
    hidden_states = self.post_mlp(hidden_states, skip_connection)

    # Optional items are appended in a fixed order: attention weights, then cache.
    collected = [hidden_states]
    if output_attentions:
        collected.append(self_attn_weights)
    if use_cache:
        collected.append(present_key_value)
    return tuple(collected)
def pre_attn(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
token_idx: Optional[torch.Tensor] = None,
attn_softmax_bf16: Optional[bool] = False,
reuse_cache: Optional[bool] = False,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
flash_attention_fast_softmax: Optional[bool] = False,
valid_sequence_lengths: Optional[torch.Tensor] = None,
cache_idx: int = None,