-
-
Notifications
You must be signed in to change notification settings - Fork 5.9k
/
Copy pathtest_self_and_cross_attn.py
1699 lines (1433 loc) · 61.6 KB
/
test_self_and_cross_attn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import copy
import itertools
import random
from typing import List, Optional, Union
import pytest
import torch
from vllm.attention import Attention, AttentionMetadata
from vllm.attention.backends.abstract import AttentionBackend, AttentionType
from vllm.attention.backends.utils import (
STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL, STR_NOT_IMPL_ENC_DEC_ROCM_HIP)
from vllm.attention.backends.xformers import XFormersBackend
from vllm.logger import init_logger
from vllm.utils import is_hip, make_tensor_with_pad
logger = init_logger(__name__)
# If not is_hip(): supported head sizes are [64, 80, 96, 112, 128, 256]
#
# TODO: FlashAttention forward only supports head dimension at most 128
# https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d0
# 37782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62
HEAD_SIZES = [64, 256]
NUM_HEADS = [1, 16]
BATCH_SIZES = [1, 16]
BLOCK_SIZES = [16]
BACKEND_NAMES = ["xformers"]
CUDA_DEVICE = "cuda:0"
MAX_Q_SEQ_LENS = [128]
MAX_K_SEQ_LENS = [128]
def build_causal_mask(q_max_seq_len: int, kv_max_seq_len: int) \
-> torch.Tensor:
'''
Create a q_max_seq_len x kv_max_seq_len causal mask
Arguments:
* q_max_seq_len: query max seq len
* kv_max_seq_len: key/value max seq len
Returns:
* 2D tensor, q_max_seq_len x kv_max_seq_len
'''
# Create a matrix where entry (i, j) is True if i >= j
mask = torch.triu(torch.ones(q_max_seq_len, kv_max_seq_len), diagonal=1)
# Replace True with float('-inf') and False with 0
mask = mask.masked_fill(mask == 1,
float('-inf')).masked_fill(mask == 0, 0.0)
return mask
def ref_masked_attention(query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
scale: float,
custom_mask: Optional[torch.Tensor] = None,
q_seq_lens: Optional[List] = None,
kv_seq_lens: Optional[List] = None) -> torch.Tensor:
'''
"Golden" masked attention reference. Supports two types of masking:
* Basic attention mask, utilizing {q,kv}_seq_lens args to mask out
padding elements
* Custom attention mask, which can force an arbitrary mask tensor, i.e.
causal
Arguments:
* query: batch_size x q_padded_seq_len x num_heads x head_size
* key: batch_size x kv_padded_seq_len x num_heads x head_size
* value: batch_size x kv_padded_seq_len x num_heads x head_size
* scale: Attention scale factor
* Custom mask: custom attention mask; good place to inject a causal
attention mask
* q_seq_lens: list of unpadded query seq_lens for each batch index
* kv_seq_lens: list of unpadded key/value seq_lens for each batch index
Returns:
* Attention result, batch_size x q_padded_seq_len x num_heads x head_size
'''
batch_size = query.shape[0]
assert (len(q_seq_lens) == batch_size)
assert (len(kv_seq_lens) == batch_size)
attn_weights = scale * torch.einsum("bqhd,bkhd->bhqk", query, key).float()
# Basic attention mask, derived from seq lens
if (q_seq_lens is not None) or (kv_seq_lens is not None):
attn_mask = torch.zeros_like(attn_weights)
if q_seq_lens is not None:
for bdx, plen in enumerate(q_seq_lens):
attn_mask[bdx, :, plen:, :] = -torch.inf
if kv_seq_lens is not None:
for bdx, plen in enumerate(kv_seq_lens):
attn_mask[bdx, :, :, plen:] = -torch.inf
attn_weights = attn_weights + attn_mask.float()
# Custom attention mask
if custom_mask is not None:
attn_weights = attn_weights + custom_mask.float()
attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
out = torch.einsum("bhqk,bkhd->bqhd", attn_weights, value)
return out
def make_qkv(batch_size: int,
max_q_seq_len: int,
max_kv_seq_len: int,
num_heads: int,
head_size: int,
attn_type: AttentionType = AttentionType.ENCODER_DECODER,
force_max_len: bool = False,
device: Union[torch.device, str] = CUDA_DEVICE) -> tuple:
'''
Construct QKV test tensors for self- and cross-attention.
Generates three query/key/value triplets:
* "Baseline" query/key/value (for input to reference attention function)
* "Prefill" query/key/value (last sequence offset zero'd out, for use as
input to prefill kernel)
* "Decode" query/key/value (only the last sequence offset from baseline,
for use as input to decode kernel)
Each Q/K/V triplet is associated with a list of q seqlens and a list of k/v
seqlens
Arguments:
* batch_size
* max_q_seq_len: max query seq len
* max_kv_seq_len: max key/value seq len
* num_heads
* head_size
* is_encoder_decoder_attn: if True, query seqlen may differ from
key/value seqlen (as is often the case for cross-attention);
o/w, query/key/value seqlens match at each batch index
(max_kv_seq_len is unused)
* force_max_len: if True, all query seqlens are max_q_seq_len; o/w query
seqlens are random in [2,max_q_seq_lens]. Same for key/value seqlens
and max_kv_seq_len, unless forced by is_encoder_decoder_attn=False
* device: CPU or CUDA device
Returns:
* query: "baseline" query; batch_size x max_q_seq_len x num_heads x
head_size
* key: "baseline" key; batch_size x max_kv_seq_len x num_heads x
head_size
* value: "baseline" value; batch_size x max_kv_seq_len x num_heads x
head_size
* prefill_query: batch_size x (max_q_seq_len-1) x num_heads x head_size
* prefill_key: batch_size x (max_kv_seq_len-1) x num_heads x head_size
* prefill_value: batch_size x (max_kv_seq_len-1) x num_heads x head_size
* decode_query: batch_size x 1 x num_heads x head_size
* decode_key: batch_size x 1 x num_heads x head_size
* decode_value: batch_size x 1 x num_heads x head_size
* q_seq_lens: "baseline" query seqlen list
* kv_seq_lens: "baseline" key/value seqlen list
* actual_max_q_seq_len: actual "baseline" query max seq len (may be <=
max_q_seq_len due to randomness)
* actual_max_kv_seq_len: actual "baseline" key/value max seq len (may
be <= max_kv_seq_len due to randomness)
* prefill_q_seq_lens: "prefill" query seqlen list
* prefill_kv_seq_lens: "prefill" key/value seqlen list
* decode_q_seq_lens: "decode" query seqlen list (all ones)
* decode_kv_seq_lens: "decode" key/value seqlen list
'''
if force_max_len:
q_seq_lens = [max_q_seq_len for _ in range(batch_size)]
else:
q_seq_lens = [
random.randint(2, max_q_seq_len) for _ in range(batch_size)
]
kv_seq_lens = None
if attn_type != AttentionType.ENCODER_DECODER:
# K,V seq lens match Q for self-attention
kv_seq_lens = q_seq_lens
else:
# K,V seq lens are distinct from Q seq lens & random
if force_max_len:
kv_seq_lens = [max_kv_seq_len for _ in range(batch_size)]
else:
kv_seq_lens = [
random.randint(2, max_kv_seq_len) for _ in range(batch_size)
]
actual_max_q_seq_len = max(q_seq_lens)
actual_max_kv_seq_len = max(kv_seq_lens)
query = torch.rand(
(batch_size, max_q_seq_len, num_heads * head_size)).to(device)
key = torch.rand(
(batch_size, max_kv_seq_len, num_heads * head_size)).to(device)
value = torch.rand(
(batch_size, max_kv_seq_len, num_heads * head_size)).to(device)
prefill_query = torch.zeros(
(batch_size, max_q_seq_len, num_heads * head_size)).to(device)
prefill_key = torch.zeros(
(batch_size, max_kv_seq_len, num_heads * head_size)).to(device)
prefill_value = torch.zeros(
(batch_size, max_kv_seq_len, num_heads * head_size)).to(device)
decode_query = torch.zeros(
(batch_size, 1, num_heads * head_size)).to(device)
decode_key = torch.zeros((batch_size, 1, num_heads * head_size)).to(device)
decode_value = torch.zeros(
(batch_size, 1, num_heads * head_size)).to(device)
for bdx, (q_seq_len, kv_seq_len) in enumerate(zip(q_seq_lens,
kv_seq_lens)):
query[bdx, q_seq_len:, :] = 0
key[bdx, kv_seq_len:, :] = 0
value[bdx, kv_seq_len:, :] = 0
prefill_query[bdx, 0:(q_seq_len - 1), :] = query[bdx,
0:(q_seq_len - 1), :]
prefill_key[bdx, 0:(kv_seq_len - 1), :] = key[bdx,
0:(kv_seq_len - 1), :]
prefill_value[bdx,
0:(kv_seq_len - 1), :] = value[bdx,
0:(kv_seq_len - 1), :]
decode_query[bdx, :, :] = query[bdx, (q_seq_len - 1):q_seq_len, :]
decode_key[bdx, :, :] = key[bdx, (kv_seq_len - 1):kv_seq_len, :]
decode_value[bdx, :, :] = value[bdx, (kv_seq_len - 1):kv_seq_len, :]
prefill_q_seq_lens = [plen - 1 for plen in q_seq_lens]
prefill_kv_seq_lens = [plen - 1 for plen in kv_seq_lens]
decode_q_seq_lens = [1 for _ in q_seq_lens]
decode_kv_seq_lens = [1 for _ in kv_seq_lens]
query = query.view(batch_size, query.shape[1], num_heads, head_size)
key = key.view(batch_size, key.shape[1], num_heads, head_size)
value = value.view(batch_size, value.shape[1], num_heads, head_size)
prefill_query = prefill_query.view(batch_size, prefill_query.shape[1],
num_heads, head_size)
prefill_key = prefill_key.view(batch_size, prefill_key.shape[1], num_heads,
head_size)
prefill_value = prefill_value.view(batch_size, prefill_value.shape[1],
num_heads, head_size)
decode_query = decode_query.view(batch_size, decode_query.shape[1],
num_heads, head_size)
decode_key = decode_key.view(batch_size, decode_key.shape[1], num_heads,
head_size)
decode_value = decode_value.view(batch_size, decode_value.shape[1],
num_heads, head_size)
return query, \
key, \
value, \
prefill_query, \
prefill_key, \
prefill_value, \
decode_query, \
decode_key, \
decode_value, \
q_seq_lens, \
kv_seq_lens, \
actual_max_q_seq_len, \
actual_max_kv_seq_len, \
prefill_q_seq_lens, \
prefill_kv_seq_lens, \
decode_q_seq_lens, \
decode_kv_seq_lens
def pack_tensor(unpacked_tensor: torch.Tensor,
seq_lens: List[int],
device: Union[torch.device, str] = CUDA_DEVICE) -> tuple:
'''
Pack a batch_size x padded_seq_len x num_heads x head_size tensor into an
unpadded number_of_tokens x num_heads x head_size tensor, where
number_of_tokens = sum(seq_lens)
Arguments:
* unpacked_tensor: batch_size x padded_seq_len x num_heads x head_size
* seq_lens: list of token counts for each seq
* device: CPU or CUDA device
Returns
* packed_tensor: number_of_tokens x num_heads x head_size
* start_loc_list: start idx of each batch elt in packed_tensor; [0] +
list(itertools.accumulate(seq_lens))
'''
num_tok = sum(seq_lens)
num_heads = unpacked_tensor.shape[-2]
head_size = unpacked_tensor.shape[-1]
start_loc_list = [0] + list(itertools.accumulate(seq_lens))
packed_tensor = torch.zeros((num_tok, num_heads, head_size), device=device)
for bdx, (seq_len, start_loc) in enumerate(zip(seq_lens, start_loc_list)):
packed_tensor[start_loc:(
start_loc + seq_len), :, :] = unpacked_tensor[bdx, :seq_len, :, :]
return packed_tensor, start_loc_list
def pack_qkv(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
q_seq_lens: List[int], kv_seq_lens: List[int]) -> tuple:
'''
Individually pack each of Q, K and V, each with dimensions batch_size x
padded_seq_len x num_heads x head_size, into respective number_of_tokens x
num_heads x head_size tensors.
For Q, number_of_tokens = sum(q_seq_lens).
For K and V, number_of_tokens = sum(kv_seq_lens)
Arguments:
* query: batch_size x padded_seq_len x num_heads x head_size
* key: batch_size x padded_seq_len x num_heads x head_size
* value: batch_size x padded_seq_len x num_heads x head_size
* q_seq_lens: list of token counts for each query
* kv_seq_lens: list of token counts for each key/value
Returns
* packed_query: number_of_tokens x num_heads x head_size
* packed_key: number_of_tokens x num_heads x head_size
* packed_value: number_of_tokens x num_heads x head_size
* q_start_loc_list: start idx of each query in packed_query
* kv_start_loc_list: start idx of each {key,value} in packed_{key,value}
'''
if query is None:
packed_query = None
q_start_loc_list = None
else:
packed_query, q_start_loc_list = pack_tensor(query, q_seq_lens)
packed_key, kv_start_loc_list = pack_tensor(key, kv_seq_lens)
packed_value, _ = pack_tensor(value, kv_seq_lens)
if packed_query is not None:
packed_query = packed_query.view(
-1, packed_query.shape[-1] * packed_query.shape[-2])
packed_key = packed_key.view(-1,
packed_key.shape[-1] * packed_key.shape[-2])
packed_value = packed_value.view(
-1, packed_value.shape[-1] * packed_value.shape[-2])
return packed_query, \
packed_key, \
packed_value, \
q_start_loc_list, \
kv_start_loc_list
def make_backend(backend_name: str) -> AttentionBackend:
'''
Construct the backend instance determined by the backend_name string
argument.
"xformers" -> construct xformers backend
TODO: flash attention backend
Returns:
* Backend instance
'''
if backend_name == "xformers":
return XFormersBackend()
raise AssertionError(
f"Unrecognized backend_name {backend_name} for unit test")
def make_metadata_tensors(is_prompt: bool,
seq_lens: List[int],
context_lens: List[int],
device: Union[torch.device, str] = \
CUDA_DEVICE) -> tuple:
'''
Build scalar & tensor values required to build attention metadata structure.
Arguments:
* is_prompt: True -> Prefill, False -> Decode
* seq_lens: list of token-counts for each seq
* context_lens: list of context length values for each seq
* device: CPU or CUDA device
Returns:
* seq_lens_tensor: seq_lens list, as tensor
* context_lens_tensor: context_lens list, as tensor
* max_query_len: max(seq_lens) if is_seq, o/w 1
* max_context_len: max(context_lens)
* max_seq_len: max(seq_lens)
* seq_start_loc: start idx of each sequence
* query_start_loc: start idx of each query
'''
seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int, device=device)
context_lens_tensor = None if context_lens is None else torch.tensor(
context_lens, dtype=torch.int, device=device)
max_context_len = None if context_lens is None else max(context_lens)
max_seq_len = None if seq_lens is None else max(seq_lens)
seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=device)
torch.cumsum(seq_lens_tensor,
dim=0,
dtype=seq_start_loc.dtype,
out=seq_start_loc[1:])
if is_prompt:
# Prefill: query_start_loc matches seq_start_loc
query_start_loc = copy.deepcopy(seq_start_loc)
max_query_len = max_seq_len
else:
# Decode: one new query input token per batch element, thus
# query_start_loc is the cumsum of [1,1,1,...]
query_start_loc = list(range(len(seq_start_loc)))
max_query_len = 1
return seq_lens_tensor, \
context_lens_tensor, \
max_query_len, \
max_context_len, \
max_seq_len, \
seq_start_loc, \
query_start_loc
def make_kv_cache(num_blocks: int,
num_heads: int,
head_size: int,
block_size: int,
device: Union[torch.device, str] = \
CUDA_DEVICE,
default_val: float=0.0) -> torch.Tensor:
'''
Create a fake KV cache.
Arguments:
* num_blocks: number of blocks in the KV cache
* num_heads: number of attention heads
* head_size: head dimension
* block_size: number of offsets within a block
* device: CPU or CUDA device
* default_val: initialization value for KV cache elements
Returns:
* kv_cache: 2 x num_blocks x (block_size * num_heads * head_size)
'''
kv_cache = torch.rand(
(2, num_blocks, block_size * num_heads * head_size)).to(device)
if default_val is not None:
kv_cache[:, :, :] = default_val
return kv_cache
def num_tokens_to_min_blocks(num_tokens: int, block_size: int) -> int:
'''
Compute the minimum number of blocks required to hold num_tokens tokens,
given block_size
'''
return (num_tokens + block_size) // block_size
def make_block_tables_slot_mapping(block_size: int,
seq_lens: List,
block_base_addr: int=0,
device: Union[torch.device, str] = \
CUDA_DEVICE) -> tuple:
'''
Construct fake block tables & slot mappings.
For a sequence with num_tokens tokens the minimum number
of required KV cache blocks is
num_blocks = (num_tokens + block_size) // block_size
Then the minimum KV cache size in blocks is
total_cache_blocks = sum(num_blocks for all seqs)
Then, the blocktable mapping counts downward from
block_base_addr + total_cache_blocks
to
block_base_addr
Arguments:
* block_size: number of offsets per block
* seq_lens: list of token-counts for each sequence
* block_base_addr: the block table base address
* device: CPU or CUDA device
Return:
* decode_block_tables_tensor: fake the state of the block tables during
decode
* decode_slot_mapping_tensor: fake the state of the slot mapping during
decode
* prefill_slot_mapping_tensor: fake the state of the slot mapping during
prefill
* prefill_block_tables_tensor: fake the state of the block tables during
prefill
* slot_mapping_tensor: union of prefill and decode slot mappings
* empty_slot_mapping_tensor: empty slot mapping (useful for decode phase
cross attention)
* max_block_idx: the highest block address within this block table
'''
# Provision minimum number of KV cache blocks
num_blocks_list = [
num_tokens_to_min_blocks(num_tokens, block_size)
for num_tokens in seq_lens
]
max_block_table_len = max(num_blocks_list)
block_table_pad_tokens = 10
block_tables = []
prefill_slot_mapping = []
decode_slot_mapping = []
slot_mapping = []
# Compute uppermost address of block table
total_cache_blocks = sum(num_blocks_list)
block_base_idx = block_base_addr + total_cache_blocks
max_block_idx = block_base_idx
for sdx, num_tokens in enumerate(seq_lens):
num_blocks = num_blocks_list[sdx]
block_table = list(
range(block_base_idx, block_base_idx - num_blocks, -1))
for idx in range(num_tokens - 1):
prefill_slot_mapping.append((idx % block_size) +
block_table[idx // block_size] *
block_size)
slot_mapping.append((idx % block_size) +
block_table[idx // block_size] * block_size)
idx = num_tokens - 1
decode_slot_mapping.append((idx % block_size) +
block_table[idx // block_size] * block_size)
slot_mapping.append((idx % block_size) +
block_table[idx // block_size] * block_size)
block_base_idx -= num_blocks
block_tables.append(block_table)
prefill_block_tables_tensor = torch.tensor([], device=CUDA_DEVICE)
decode_block_tables_tensor = make_tensor_with_pad(
block_tables,
max_len=max_block_table_len + block_table_pad_tokens,
pad=0,
dtype=torch.int,
device=device,
)
prefill_slot_mapping_tensor = torch.tensor(prefill_slot_mapping,
dtype=torch.long,
device=device)
decode_slot_mapping_tensor = torch.tensor(decode_slot_mapping,
dtype=torch.long,
device=device)
slot_mapping_tensor = torch.tensor(slot_mapping,
dtype=torch.long,
device=device)
empty_slot_mapping_tensor = torch.tensor([],
dtype=torch.long,
device=device)
return decode_block_tables_tensor, \
decode_slot_mapping_tensor, \
prefill_slot_mapping_tensor, \
prefill_block_tables_tensor, \
slot_mapping_tensor, \
empty_slot_mapping_tensor, \
max_block_idx
def make_test_metadata(
attn_backend: AttentionBackend,
is_prompt: bool,
seq_lens: List[int],
context_lens: List[int],
block_tables: torch.Tensor,
slot_mapping: torch.Tensor,
is_encoder_only_test: bool,
device: Union[torch.device, str] = CUDA_DEVICE,
cross_seq_lens: Optional[List[int]] = None,
cross_block_tables: Optional[torch.Tensor] = None,
cross_slot_mapping: Optional[List[int]] = None,
) -> AttentionMetadata:
'''
Construct fake attention metadata for a combined self-/cross-attention
scenario i.e. an encoder/decoder model.
is_encoder_only_test=True causes the default attention metadata attention
type to be AttentionType.ENCODER. False causes the default to
be AttentionType.DECODER.
Assumptions:
* No chunked prefill -> a batch is 100% prefill or 100% decode, never both
Arguments:
* attn_backend: Backend for sourcing attention kernels
* is_prompt: prefill if True, o/w decode
* seq_lens: list of token counts for each sequence
* context_lens: list of context lengths for each sequence
* block_tables: self-attention block tables
* slot_mapping: self-attention slot_mapping
* is_encoder_only_test: True if testing encoder; False if testing
decoder self-attention or encoder/decoder cross-attention.
* device: CPU or CUDA device
* cross_seq_lens: list of token counts for each encoder sequence, if any
exist
* cross_block_tables: cross-attention block tables, if required
* cross_slot_mapping: cross-attention slot mapping, if required
Return:
* AttentionMetadata structure supporting self- and cross-attention
'''
default_attn_type = AttentionType.ENCODER if is_encoder_only_test \
else AttentionType.DECODER
if is_prompt:
num_prefills = len(seq_lens)
num_prefill_tokens = sum(seq_lens)
num_decode_tokens = 0
seq_lens_tensor, \
context_lens_tensor, \
max_query_len, \
_, \
_, \
seq_start_loc, \
query_start_loc = make_metadata_tensors(is_prompt,
seq_lens,
context_lens,
device=device)
return attn_backend.make_metadata(
num_prefills=num_prefills,
slot_mapping=slot_mapping,
num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_query_len=max_query_len,
max_prefill_seq_len=max(seq_lens),
max_decode_seq_len=0,
query_start_loc=query_start_loc,
seq_start_loc=seq_start_loc,
context_lens_tensor=context_lens_tensor,
block_tables=block_tables,
use_cuda_graph=False,
_attn_type=default_attn_type,
cross_seq_lens=cross_seq_lens,
cross_slot_mapping=cross_slot_mapping,
cross_block_tables=cross_block_tables)
else: # not is_prompt
num_prefills = 0
num_prefill_tokens = 0
num_decode_tokens = len(seq_lens)
seq_lens_tensor, \
context_lens_tensor, \
max_query_len, \
_, \
_, \
seq_start_loc, \
query_start_loc = make_metadata_tensors(is_prompt,
seq_lens,
context_lens,
device=device)
return attn_backend.make_metadata(
num_prefills=num_prefills,
slot_mapping=slot_mapping,
num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_query_len=max_query_len,
max_prefill_seq_len=0,
max_decode_seq_len=max(seq_lens),
query_start_loc=query_start_loc,
seq_start_loc=seq_start_loc,
context_lens_tensor=context_lens_tensor,
block_tables=block_tables,
use_cuda_graph=False,
_attn_type=default_attn_type,
cross_seq_lens=cross_seq_lens,
cross_slot_mapping=cross_slot_mapping,
cross_block_tables=cross_block_tables)
def basic_setup(num_heads: int, head_size: int, num_blocks: int,
block_size: int, backend_name: str) -> tuple:
'''
Compute & build entities required for the self-/cross-attention test.
Arguments:
* num_heads: Number of attention heads
* head_size: Head dimension
* num_blocks: Number of KV cache blocks (no KV cache if None)
* block_size: Number of offsets within a KV cache block
(no KV cache if None)
* backend_name: selection of backend
Returns:
* scale: 1/sqrt(head_size)
* attn_backend: backend instance
* attn: Attention wrapper instance
* kv_cache: fake KV cache, 2 x num_blocks x (block_size * num_heads *
head_size)
* None if num_blocks or block_size is None
'''
scale = float(1.0 / (head_size**0.5))
attn_backend = make_backend(backend_name)
attn = Attention(
num_heads,
head_size,
scale=scale,
)
if num_blocks is None or num_heads is None:
# Caller does not require a KV cache
return scale, attn_backend, attn, None
# Construct KV cache
kv_cache = make_kv_cache(num_blocks, num_heads, head_size, block_size)
return scale, attn_backend, attn, kv_cache
def encoder_attn_setup(batch_size: int,
num_heads: int,
head_size: int,
block_size: int,
scale: float,
max_q_seq_len: int,
block_base_addr: int = 0) -> tuple:
'''
Set up test vectors & data structures for encoder attention test.
A triplet of synthetic query/key/value tensors are constructed.
Given this is an encoder attention test, the key & value
sequences will have the same length as the corresponding queries.
The query/key/value tensors are passed to an ideal reference
self-attention implementation to generate an ideal output tensor.
This function also constructs the self-attention KV cache memory mapping
(slot mapping and block table), ensuring that the block table starts at
block_base_addr
Arguments:
* batch_size
* num_heads: Number of attention heads
* head_size: Head dimension
* block_size: Number of offsets per KV cache block
* scale: attention scale parameter
* max_q_seq_len: upper limit on query length for synthetic test vectors
* block_base_addr: self-attention block table base address
Returns:
* packed_query: number_of_tokens x num_heads x head_size
* packed_key: number_of_tokens x num_heads x head_size
* packed_value: number_of_tokens x num_heads x head_size
* packed_ideal_output: number_of_tokens x num_heads x head_size
* block_tables: fake self-attn decode-phase block table
* slot_mapping: fake self-attn decode-phase slot mapping
* q_seq_lens: list of query sequence lengths
'''
max_kv_seq_len = max_q_seq_len
query, \
key, \
value, \
_, \
_, \
_, \
_, \
_, \
_, \
q_seq_lens, \
kv_seq_lens, \
_, \
_, \
_, \
_, \
_, \
_ = make_qkv(batch_size,
max_q_seq_len,
max_kv_seq_len,
num_heads,
head_size,
attn_type=AttentionType.ENCODER)
# No causal attention mask
ideal_output = ref_masked_attention(query,
key,
value,
scale=scale,
q_seq_lens=q_seq_lens,
kv_seq_lens=kv_seq_lens)
packed_ideal_output, _ = pack_tensor(ideal_output, q_seq_lens)
block_tables, \
_, \
_, \
_, \
slot_mapping, \
_, \
_ = make_block_tables_slot_mapping(
block_size, q_seq_lens, block_base_addr=block_base_addr)
packed_query, \
packed_key, \
packed_value, _, _ = pack_qkv(
query, key, value, q_seq_lens,
kv_seq_lens)
return packed_query, \
packed_key, \
packed_value, \
packed_ideal_output, \
block_tables, \
slot_mapping, \
q_seq_lens
def decoder_attn_setup(batch_size: int,
num_heads: int,
head_size: int,
block_size: int,
scale: float,
max_q_seq_len: int,
block_base_addr: int = 0) -> tuple:
'''
Set up test vectors & data structures for self-attention test.
A triplet of synthetic query/key/value tensors are constructed ("baseline"
query/key/value). Given this is a self-attention test, the key & value
sequences will have the same length as the corresponding queries.
"Prefill" query/key/value tensors are derived by masking out the last value
in each baseline query/key/value. These tensors are used to test prefill &
populate KV cache for a subsequent decode test.
"Decode" query/key/value tensors are derived by extracting *only* the last
value from each baseline query/key/value (i.e. complement of the prefill
tensors.) These tensors are used to test decode, conditional on the kv cache
being populated during the prefill test.
The baseline query/key/value tensors are passed to an ideal reference
self-attention implementation to generate a "Baseline" ideal output tensor.
This tensor is split into the "Prefill" ideal output tensor (all but the
last element of each output sequence) and the "Decode" ideal output tensor
(*only* the last element of each output sequence); the "Prefill" and
"Decode" ideal output tensors can be used to validate the prefill and decode
test results, respectively.
This function also constructs the self-attention KV cache memory mapping
(slot mapping and block table), ensuring that the block table starts at
block_base_addr
Arguments:
* batch_size
* num_heads: Number of attention heads
* head_size: Head dimension
* block_size: Number of offsets per KV cache block
* scale: attention scale parameter
* max_q_seq_len: upper limit on query length for synthetic test vectors
* block_base_addr: self-attention block table base address
Returns:
* query: "baseline" query; batch_size x padded_seq_len x num_heads x
head_size
* prefill_packed_query: "prefill" query; number_of_tokens x num_heads x
head_size
* prefill_packed_key: self-attn "prefill" key; number_of_tokens x num_heads
x head_size
* prefill_packed_value: self-attn "prefill" value; number_of_tokens x
num_heads x head_size
* prefill_packed_ideal_output: self-attn "prefill" ideal output;
number_of_tokens x num_heads x head_size
* prefill_q_seq_lens: list of token counts for each *prefill query* (one
less than baseline query)
* prefill_kv_seq_lens: list of token counts for each self-attn *prefill
key/value* (should match prefill_q_seq_lens)
* decode_packed_query: "decode" query; number_of_tokens x num_heads x
head_size
* decode_packed_key: self-attn "decode" key; number_of_tokens x num_heads x
head_size
* decode_packed_value: self-attn "decode" key; number_of_tokens x num_heads
x head_size
* decode_packed_ideal_output: self-attn "decode" ideal output;
number_of_tokens x num_heads x head_size
* decode_q_seq_lens: list of token counts for each *decode query* (should
be 1)
* decode_kv_seq_lens: list of token counts for each self-attn *decode
key/value* (should match decode_q_seq_lens)
* q_seq_lens: "baseline" query seq lens; number_of_tokens x num_heads x
head_size
* kv_seq_lens: self-attn "baseline" key/value seq lens; number_of_tokens
x num_heads x head_size
* decode_block_tables: fake self-attn decode-phase block table
* decode_slot_mapping: fake self-attn decode-phase slot mapping
* prefill_slot_mapping: fake self-attn prefill-phase slot mapping
* prefill_block_tables: fake self-attn prefill-phase block table
* max_block_idx: highest block address in the self-attention block-table
'''
max_kv_seq_len = max_q_seq_len
query, \
key, \
value, \
prefill_query, \
prefill_key, \
prefill_value, \
decode_query, \
decode_key, \
decode_value, \
q_seq_lens, \
kv_seq_lens, \
_, \
_, \
prefill_q_seq_lens, \
prefill_kv_seq_lens, \
decode_q_seq_lens, \
decode_kv_seq_lens = make_qkv(batch_size,
max_q_seq_len,
max_kv_seq_len,
num_heads,
head_size,
attn_type=AttentionType.DECODER)
causal_mask = build_causal_mask(max_q_seq_len,
max_kv_seq_len).to(CUDA_DEVICE)
ideal_output = ref_masked_attention(query,
key,
value,
scale=scale,
custom_mask=causal_mask,
q_seq_lens=q_seq_lens,
kv_seq_lens=kv_seq_lens)
prefill_ideal_output = torch.zeros_like(ideal_output)
decode_ideal_output = torch.zeros_like(ideal_output[:, 0:1])
for bdx, prefill_q_seq_len in enumerate(prefill_q_seq_lens):
prefill_ideal_output[bdx, :prefill_q_seq_len] = ideal_output[
bdx, :prefill_q_seq_len]
decode_ideal_output[bdx, :] = ideal_output[bdx, prefill_q_seq_len:(
prefill_q_seq_len + 1)]
prefill_packed_ideal_output, _ = pack_tensor(prefill_ideal_output,
prefill_q_seq_lens)
decode_packed_ideal_output, _ = pack_tensor(decode_ideal_output,
[1 for _ in range(batch_size)])
decode_block_tables, \
decode_slot_mapping, \
prefill_slot_mapping, \
prefill_block_tables, \