-
Notifications
You must be signed in to change notification settings - Fork 10.6k
/
Copy pathgguf.py
749 lines (599 loc) · 27.9 KB
/
gguf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
#!/usr/bin/env python3
import shutil
import sys
import struct
import tempfile
import numpy as np
from enum import IntEnum, auto
from typing import Any, BinaryIO, IO, Dict, List, Optional, Tuple
#
# constants
#
# File magic. GGUFWriter.write_header_to_file packs it as a little-endian
# uint32, which puts the bytes b"GGUF" at the very start of the file.
GGUF_MAGIC = 0x46554747
# Format version, written right after the magic.
GGUF_VERSION = 2
# Initial value of GGUFWriter.data_alignment: tensor byte sizes are rounded
# up to a multiple of this when offsets and padding are computed.
GGUF_DEFAULT_ALIGNMENT = 32
# general
KEY_GENERAL_ARCHITECTURE = "general.architecture"
KEY_GENERAL_QUANTIZATION_VERSION = "general.quantization_version"
KEY_GENERAL_ALIGNMENT = "general.alignment"
KEY_GENERAL_NAME = "general.name"
KEY_GENERAL_AUTHOR = "general.author"
KEY_GENERAL_URL = "general.url"
KEY_GENERAL_DESCRIPTION = "general.description"
KEY_GENERAL_LICENSE = "general.license"
KEY_GENERAL_SOURCE_URL = "general.source.url"
# NOTE(review): "hugginface" is spelled this way in the key that gets written
# to the file; confirm with readers of the format before correcting it.
KEY_GENERAL_SOURCE_HF_REPO = "general.source.hugginface.repository"
KEY_GENERAL_FILE_TYPE = "general.file_type"
# LLM
# Keys containing "{arch}" are templates; GGUFWriter fills them in with
# .format(arch=self.arch) before writing.
KEY_CONTEXT_LENGTH = "{arch}.context_length"
KEY_EMBEDDING_LENGTH = "{arch}.embedding_length"
KEY_BLOCK_COUNT = "{arch}.block_count"
KEY_FEED_FORWARD_LENGTH = "{arch}.feed_forward_length"
KEY_USE_PARALLEL_RESIDUAL = "{arch}.use_parallel_residual"
KEY_TENSOR_DATA_LAYOUT = "{arch}.tensor_data_layout"
# attention
KEY_ATTENTION_HEAD_COUNT = "{arch}.attention.head_count"
KEY_ATTENTION_HEAD_COUNT_KV = "{arch}.attention.head_count_kv"
KEY_ATTENTION_MAX_ALIBI_BIAS = "{arch}.attention.max_alibi_bias"
KEY_ATTENTION_CLAMP_KQV = "{arch}.attention.clamp_kqv"
KEY_ATTENTION_LAYERNORM_EPS = "{arch}.attention.layer_norm_epsilon"
KEY_ATTENTION_LAYERNORM_RMS_EPS = "{arch}.attention.layer_norm_rms_epsilon"
# RoPE
KEY_ROPE_DIMENSION_COUNT = "{arch}.rope.dimension_count"
KEY_ROPE_FREQ_BASE = "{arch}.rope.freq_base"
KEY_ROPE_SCALE_LINEAR = "{arch}.rope.scale_linear"
# tokenization
KEY_TOKENIZER_MODEL = "tokenizer.ggml.model"
KEY_TOKENIZER_LIST = "tokenizer.ggml.tokens"
KEY_TOKENIZER_TOKEN_TYPE = "tokenizer.ggml.token_type"
KEY_TOKENIZER_SCORES = "tokenizer.ggml.scores"
KEY_TOKENIZER_MERGES = "tokenizer.ggml.merges"
KEY_TOKENIZER_BOS_ID = "tokenizer.ggml.bos_token_id"
KEY_TOKENIZER_EOS_ID = "tokenizer.ggml.eos_token_id"
KEY_TOKENIZER_UNK_ID = "tokenizer.ggml.unknown_token_id"
# NOTE(review): "seperator" is spelled this way in the key that gets written
# to the file; confirm with readers of the format before correcting it.
KEY_TOKENIZER_SEP_ID = "tokenizer.ggml.seperator_token_id"
KEY_TOKENIZER_PAD_ID = "tokenizer.ggml.padding_token_id"
KEY_TOKENIZER_HF_JSON = "tokenizer.huggingface.json"
KEY_TOKENIZER_RWKV = "tokenizer.rwkv.world"
#
# recommended mapping of model tensor names for storage in gguf
#
class MODEL_ARCH(IntEnum):
    """Model architectures this module knows about.

    Members are used as keys of the MODEL_ARCH_NAMES, MODEL_TENSOR_NAMES and
    MODEL_TENSOR_SKIP tables. Integer values are assigned by ``auto()``.
    """
    LLAMA : int = auto()
    FALCON : int = auto()
    GPT2 : int = auto()
    GPTJ : int = auto()
    GPTNEOX: int = auto()
    MPT : int = auto()
class MODEL_TENSOR(IntEnum):
    """Logical tensor roles inside a model.

    Members are used as keys of the per-architecture name tables in
    MODEL_TENSOR_NAMES and as entries of MODEL_TENSOR_SKIP. Integer values
    are assigned by ``auto()``.
    """
    TOKEN_EMBD : int = auto()
    POS_EMBD : int = auto()
    OUTPUT : int = auto()
    OUTPUT_NORM : int = auto()
    ROPE_FREQS : int = auto()
    ATTN_Q : int = auto()
    ATTN_K : int = auto()
    ATTN_V : int = auto()
    ATTN_QKV : int = auto()
    ATTN_OUT : int = auto()
    ATTN_NORM : int = auto()
    ATTN_NORM_2 : int = auto()
    ATTN_ROT_EMBD: int = auto()
    FFN_GATE : int = auto()
    FFN_DOWN : int = auto()
    FFN_UP : int = auto()
    FFN_NORM : int = auto()
# Short lowercase name for each architecture.
# NOTE(review): presumably these are the strings passed as GGUFWriter's
# `arch` argument (the example at the bottom of the file passes "llama");
# confirm against callers.
MODEL_ARCH_NAMES: Dict[MODEL_ARCH, str] = {
    MODEL_ARCH.LLAMA: "llama",
    MODEL_ARCH.FALCON: "falcon",
    MODEL_ARCH.GPT2: "gpt2",
    MODEL_ARCH.GPTJ: "gptj",
    MODEL_ARCH.GPTNEOX: "gptneox",
    MODEL_ARCH.MPT: "mpt",
}
# Per-architecture table of the tensor names used in the GGUF file.
# Names containing "{bid}" are templates: get_tensor_name_map and
# should_skip_tensor_TMP fill them in with .format(bid=<block index>).
# A tensor role that is missing from an architecture's table is mapped to
# None by get_tensor_name_map.
# Architectures without an entry here (GPTJ, MPT) raise KeyError when looked
# up with MODEL_TENSOR_NAMES[arch].
MODEL_TENSOR_NAMES: Dict[MODEL_ARCH, Dict[MODEL_TENSOR, str]] = {
    MODEL_ARCH.LLAMA: {
        MODEL_TENSOR.TOKEN_EMBD: "token_embd",
        MODEL_TENSOR.OUTPUT_NORM: "output_norm",
        MODEL_TENSOR.OUTPUT: "output",
        MODEL_TENSOR.ROPE_FREQS: "rope_freqs",
        MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm",
        MODEL_TENSOR.ATTN_Q: "blk.{bid}.attn_q",
        MODEL_TENSOR.ATTN_K: "blk.{bid}.attn_k",
        MODEL_TENSOR.ATTN_V: "blk.{bid}.attn_v",
        MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output",
        MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd",
        MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm",
        MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate",
        MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down",
        MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up",
    },
    MODEL_ARCH.GPTNEOX: {
        MODEL_TENSOR.TOKEN_EMBD: "token_embd",
        MODEL_TENSOR.OUTPUT_NORM: "output_norm",
        MODEL_TENSOR.OUTPUT: "output",
        MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm",
        MODEL_TENSOR.ATTN_QKV: "blk.{bid}.attn_qkv",
        MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output",
        MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm",
        MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down",
        MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up",
    },
    MODEL_ARCH.FALCON: {
        MODEL_TENSOR.TOKEN_EMBD: "token_embd",
        MODEL_TENSOR.OUTPUT_NORM: "output_norm",
        MODEL_TENSOR.OUTPUT: "output",
        MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm",
        MODEL_TENSOR.ATTN_NORM_2: "blk.{bid}.attn_norm_2",
        MODEL_TENSOR.ATTN_QKV: "blk.{bid}.attn_qkv",
        MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output",
        MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down",
        MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up",
    },
    MODEL_ARCH.GPT2: {
        # TODO
    },
    # TODO
}
# tensors that will not be serialized
# Consulted by should_skip_tensor_TMP; architectures without an entry skip
# nothing.
MODEL_TENSOR_SKIP: Dict[MODEL_ARCH, List[MODEL_TENSOR]] = {
    MODEL_ARCH.LLAMA: [
        MODEL_TENSOR.ROPE_FREQS,
        MODEL_TENSOR.ATTN_ROT_EMBD,
    ],
}
# TODO: the following helper functions should be removed
# instead, get_tensor_name_map should return tuples of (name, MODEL_TENSOR)
# however, my Python is very bad, and I couldn't figure out how to do this, hence these functions
# REMOVE
def should_skip_tensor_TMP(arch: MODEL_ARCH, n_blocks: int, name: str) -> bool:
    """Return True when `name` is a tensor that must not be serialized.

    `name` is compared against every skipped tensor's GGUF name for `arch`,
    formatted once per block index in ``range(n_blocks)``. Architectures
    without a MODEL_TENSOR_SKIP entry never skip anything.
    """
    skipped_tensors = MODEL_TENSOR_SKIP.get(arch, [])
    return any(
        name == MODEL_TENSOR_NAMES[arch][tensor].format(bid=block_id)
        for tensor in skipped_tensors
        for block_id in range(n_blocks)
    )
# Source-checkpoint tensor names that exist once per model, grouped by the
# tensor role they map to. Order matters: it is the insertion order of the
# dict returned by get_tensor_name_map.
_GLOBAL_TENSOR_SOURCES = (
    (MODEL_TENSOR.TOKEN_EMBD, (
        "gpt_neox.embed_in",            # gptneox
        "transformer.wte",              # gpt2 mpt
        "transformer.word_embeddings",  # falcon
        "model.embed_tokens",           # llama-hf
        "tok_embeddings",               # llama-pth
    )),
    (MODEL_TENSOR.POS_EMBD, (
        "transformer.wpe",              # gpt2
    )),
    (MODEL_TENSOR.OUTPUT, (
        "embed_out",                    # gptneox
        "lm_head",                      # gpt2 mpt falcon llama-hf
        "output",                       # llama-pth
    )),
    (MODEL_TENSOR.OUTPUT_NORM, (
        "gpt_neox.final_layer_norm",    # gptneox
        "transformer.ln_f",             # gpt2 falcon
        "transformer.norm_f",           # mpt
        "model.norm",                   # llama-hf
        "norm",                         # llama-pth
    )),
    (MODEL_TENSOR.ROPE_FREQS, (
        "rope.freqs",                   # llama-pth
    )),
)

# Source-checkpoint tensor names that exist once per block; "{bid}" is
# replaced by the block index. Same ordering contract as above.
_BLOCK_TENSOR_SOURCES = (
    (MODEL_TENSOR.ATTN_NORM, (
        "gpt_neox.layers.{bid}.input_layernorm",  # gptneox
        "transformer.h.{bid}.ln_1",               # gpt2
        "transformer.blocks.{bid}.norm_1",        # mpt
        "transformer.h.{bid}.input_layernorm",    # falcon7b
        "transformer.h.{bid}.ln_mlp",             # falcon40b
        "model.layers.{bid}.input_layernorm",     # llama-hf
        "layers.{bid}.attention_norm",            # llama-pth
    )),
    (MODEL_TENSOR.ATTN_NORM_2, (
        "transformer.h.{bid}.ln_attn",            # falcon40b
    )),
    (MODEL_TENSOR.ATTN_QKV, (
        "gpt_neox.layers.{bid}.attention.query_key_value",    # gptneox
        "transformer.h.{bid}.attn.c_attn",                    # gpt2
        "transformer.blocks.{bid}.attn.Wqkv",                 # mpt
        "transformer.h.{bid}.self_attention.query_key_value", # falcon
    )),
    (MODEL_TENSOR.ATTN_Q, (
        "model.layers.{bid}.self_attn.q_proj",    # llama-hf
        "layers.{bid}.attention.wq",              # llama-pth
    )),
    (MODEL_TENSOR.ATTN_K, (
        "model.layers.{bid}.self_attn.k_proj",    # llama-hf
        "layers.{bid}.attention.wk",              # llama-pth
    )),
    (MODEL_TENSOR.ATTN_V, (
        "model.layers.{bid}.self_attn.v_proj",    # llama-hf
        "layers.{bid}.attention.wv",              # llama-pth
    )),
    (MODEL_TENSOR.ATTN_OUT, (
        "gpt_neox.layers.{bid}.attention.dense",      # gptneox
        "transformer.h.{bid}.attn.c_proj",            # gpt2
        "transformer.blocks.{bid}.attn.out_proj",     # mpt
        "transformer.h.{bid}.self_attention.dense",   # falcon
        "model.layers.{bid}.self_attn.o_proj",        # llama-hf
        "layers.{bid}.attention.wo",                  # llama-pth
    )),
    (MODEL_TENSOR.ATTN_ROT_EMBD, (
        "model.layers.{bid}.self_attn.rotary_emb.inv_freq",        # llama-hf
        "layers.{bid}.attention.inner_attention.rope.freqs",       # llama-pth
    )),
    (MODEL_TENSOR.FFN_NORM, (
        "gpt_neox.layers.{bid}.post_attention_layernorm",  # gptneox
        "transformer.h.{bid}.ln_2",                        # gpt2
        "transformer.blocks.{bid}.norm_2",                 # mpt
        "model.layers.{bid}.post_attention_layernorm",     # llama-hf
        "layers.{bid}.ffn_norm",                           # llama-pth
    )),
    (MODEL_TENSOR.FFN_UP, (
        "gpt_neox.layers.{bid}.mlp.dense_h_to_4h",  # gptneox
        "transformer.h.{bid}.mlp.c_fc",             # gpt2
        "transformer.blocks.{bid}.ffn.up_proj",     # mpt
        "transformer.h.{bid}.mlp.dense_h_to_4h",    # falcon
        "model.layers.{bid}.mlp.up_proj",           # llama-hf
        "layers.{bid}.feed_forward.w3",             # llama-pth
    )),
    (MODEL_TENSOR.FFN_GATE, (
        "model.layers.{bid}.mlp.gate_proj",         # llama-hf
        "layers.{bid}.feed_forward.w1",             # llama-pth
    )),
    (MODEL_TENSOR.FFN_DOWN, (
        "gpt_neox.layers.{bid}.mlp.dense_4h_to_h",  # gptneox
        "transformer.h.{bid}.mlp.c_proj",           # gpt2
        "transformer.blocks.{bid}.ffn.down_proj",   # mpt
        "transformer.h.{bid}.mlp.dense_4h_to_h",    # falcon
        "model.layers.{bid}.mlp.down_proj",         # llama-hf
        "layers.{bid}.feed_forward.w2",             # llama-pth
    )),
)


def get_tensor_name_map(arch: MODEL_ARCH, n_blocks: int) -> dict:
    """Build the mapping from source-checkpoint tensor names to GGUF names.

    Every known source name (for all supported checkpoint layouts) becomes a
    key. Its value is the GGUF tensor name that `arch` uses for the
    corresponding role, or None when `arch` has no name for that role.
    Per-block names are generated for block indices ``0 .. n_blocks - 1``.

    Raises KeyError if `arch` has no entry in MODEL_TENSOR_NAMES.
    """
    arch_names = MODEL_TENSOR_NAMES[arch]
    tensor_map = {}

    for tensor, sources in _GLOBAL_TENSOR_SOURCES:
        mapped_to = arch_names.get(tensor)
        for source in sources:
            tensor_map[source] = mapped_to

    # The target templates do not depend on the block index: resolve them once.
    block_targets = [
        (arch_names.get(tensor), sources)
        for tensor, sources in _BLOCK_TENSOR_SOURCES
    ]
    for bid in range(n_blocks):
        for template, sources in block_targets:
            mapped_to = template.format(bid=bid) if template is not None else None
            for source in sources:
                tensor_map[source.format(bid=bid)] = mapped_to

    return tensor_map
class TokenType(IntEnum):
    """Integer codes for token categories.

    NOTE(review): presumably these are the values stored in the
    KEY_TOKENIZER_TOKEN_TYPE array via GGUFWriter.add_token_types; nothing in
    this file references the enum directly, so confirm against callers.
    """
    NORMAL = 1
    UNKNOWN = 2
    CONTROL = 3
    USER_DEFINED = 4
    UNUSED = 5
    BYTE = 6
#
# implementation
#
class GGMLQuantizationType(IntEnum):
    """Tensor data type codes written into each tensor-info record.

    GGUFWriter.add_tensor_info packs the value as a little-endian uint32.
    Values 4 and 5 are not assigned here.
    """
    F32 = 0
    F16 = 1
    Q4_0 = 2
    Q4_1 = 3
    Q5_0 = 6
    Q5_1 = 7
    Q8_0 = 8
    Q8_1 = 9
    Q2_K = 10
    Q3_K = 11
    Q4_K = 12
    Q5_K = 13
    Q6_K = 14
    Q8_K = 15
class GGUFValueType(IntEnum):
    """Type tags for GGUF metadata values."""
    UINT8 = 0
    INT8 = 1
    UINT16 = 2
    INT16 = 3
    UINT32 = 4
    INT32 = 5
    FLOAT32 = 6
    BOOL = 7
    STRING = 8
    ARRAY = 9
    UINT64 = 10
    INT64 = 11
    FLOAT64 = 12

    @staticmethod
    def get_type(val: Any) -> "GGUFValueType":
        """Infer the value type tag for a Python value.

        str/bytes/bytearray -> STRING, list -> ARRAY, float -> FLOAT32,
        bool -> BOOL, int -> INT32.

        Any other type terminates the process through ``sys.exit`` with an
        "Unknown type" message on stderr and a non-zero exit status.
        """
        if isinstance(val, (str, bytes, bytearray)):
            return GGUFValueType.STRING
        if isinstance(val, list):
            return GGUFValueType.ARRAY
        if isinstance(val, float):
            return GGUFValueType.FLOAT32
        # bool is a subclass of int, so it has to be tested first.
        if isinstance(val, bool):
            return GGUFValueType.BOOL
        if isinstance(val, int):
            return GGUFValueType.INT32
        # TODO: need help with 64-bit types in Python
        # A string argument makes sys.exit report the message on stderr and
        # exit with status 1; a bare sys.exit() would report success.
        sys.exit(f"Unknown type: {type(val)}")
class GGUFWriter:
    """Builds a GGUF file: header, key/value metadata, tensor infos, tensor data.

    Metadata and tensor infos are accumulated in memory by the ``add_*``
    methods. Tensor data is either spooled to a temporary file
    (``use_temp_file=True``, the default) or kept as in-memory arrays.
    Nothing is written to the output file until the ``write_*_to_file``
    methods are called; the example at the bottom of this file calls them in
    the order header, kv data, tensors, and then ``close``.
    """
    fout: BinaryIO
    arch: str
    # Running offset of the next tensor inside the tensor-data section.
    offset_tensor = 0
    data_alignment = GGUF_DEFAULT_ALIGNMENT
    # Serialized key/value pairs and how many pairs they contain.
    kv_data: bytearray
    kv_data_count = 0
    # Serialized tensor-info records and how many records they contain.
    ti_data: bytearray
    ti_data_count = 0
    use_temp_file: bool
    # Quoted so the subscripted SpooledTemporaryFile is never evaluated when
    # the class is created (the class is not subscriptable on older Pythons).
    temp_file: "Optional[tempfile.SpooledTemporaryFile[bytes]]" = None
    # (tensor, trailing padding in bytes) pairs; only used without a temp file.
    tensors: List[Tuple[np.ndarray, int]]

    # struct format used for each fixed-size scalar value type.
    _SCALAR_FORMATS = {
        GGUFValueType.UINT8: "<B",
        GGUFValueType.INT8: "<b",
        GGUFValueType.UINT16: "<H",
        GGUFValueType.INT16: "<h",
        GGUFValueType.UINT32: "<I",
        GGUFValueType.INT32: "<i",
        GGUFValueType.FLOAT32: "<f",
        GGUFValueType.UINT64: "<Q",
        GGUFValueType.INT64: "<q",
        GGUFValueType.FLOAT64: "<d",
        GGUFValueType.BOOL: "?",
    }

    def __init__(self, path: str, arch: str, use_temp_file = True):
        """Open `path` for binary writing and record the architecture key.

        The "general.architecture" key is added here, so callers do not need
        to call ``add_architecture`` themselves.
        """
        self.fout = open(path, "wb")
        self.arch = arch
        # bytearray grows in place; bytes += would copy the whole buffer on
        # every appended value (quadratic for large token arrays).
        self.kv_data = bytearray()
        self.ti_data = bytearray()
        self.use_temp_file = use_temp_file
        self.tensors = []
        self.add_architecture()

    def write_header_to_file(self):
        """Write magic, version, tensor count and kv count (little-endian)."""
        self.fout.write(struct.pack("<I", GGUF_MAGIC))
        self.fout.write(struct.pack("<I", GGUF_VERSION))
        self.fout.write(struct.pack("<Q", self.ti_data_count))
        self.fout.write(struct.pack("<Q", self.kv_data_count))
        self.flush()
        # print("tensors " + str(self.ti_data_count) + " kv " + str(self.kv_data_count))

    def write_kv_data_to_file(self):
        """Write the accumulated key/value metadata to the output file."""
        self.fout.write(self.kv_data)
        self.flush()

    def write_ti_data_to_file(self):
        """Write the accumulated tensor-info records to the output file."""
        self.fout.write(self.ti_data)
        self.flush()

    def add_key(self, key: str):
        """Append a key: a length-prefixed string without a type tag."""
        self.add_val(key, GGUFValueType.STRING, add_vtype=False)

    def add_uint8(self, key: str, val: int):
        """Add `key` with an unsigned 8-bit value."""
        self.add_key(key)
        self.add_val(val, GGUFValueType.UINT8)

    def add_int8(self, key: str, val: int):
        """Add `key` with a signed 8-bit value."""
        self.add_key(key)
        self.add_val(val, GGUFValueType.INT8)

    def add_uint16(self, key: str, val: int):
        """Add `key` with an unsigned 16-bit value."""
        self.add_key(key)
        self.add_val(val, GGUFValueType.UINT16)

    def add_int16(self, key: str, val: int):
        """Add `key` with a signed 16-bit value."""
        self.add_key(key)
        self.add_val(val, GGUFValueType.INT16)

    def add_uint32(self, key: str, val: int):
        """Add `key` with an unsigned 32-bit value."""
        self.add_key(key)
        self.add_val(val, GGUFValueType.UINT32)

    def add_int32(self, key: str, val: int):
        """Add `key` with a signed 32-bit value."""
        self.add_key(key)
        self.add_val(val, GGUFValueType.INT32)

    def add_float32(self, key: str, val: float):
        """Add `key` with a 32-bit float value."""
        self.add_key(key)
        self.add_val(val, GGUFValueType.FLOAT32)

    def add_uint64(self, key: str, val: int):
        """Add `key` with an unsigned 64-bit value."""
        self.add_key(key)
        self.add_val(val, GGUFValueType.UINT64)

    def add_int64(self, key: str, val: int):
        """Add `key` with a signed 64-bit value."""
        self.add_key(key)
        self.add_val(val, GGUFValueType.INT64)

    def add_float64(self, key: str, val: float):
        """Add `key` with a 64-bit float value."""
        self.add_key(key)
        self.add_val(val, GGUFValueType.FLOAT64)

    def add_bool(self, key: str, val: bool):
        """Add `key` with a boolean value."""
        self.add_key(key)
        self.add_val(val, GGUFValueType.BOOL)

    def add_string(self, key: str, val: str):
        """Add `key` with a string value; empty strings are skipped entirely."""
        if len(val) == 0:
            return
        self.add_key(key)
        self.add_val(val, GGUFValueType.STRING)

    def add_array(self, key: str, val: list):
        """Add `key` with an array value.

        Raises ValueError if `val` is not a list.
        """
        if not isinstance(val, list):
            raise ValueError("Value must be a list for array type")
        self.add_key(key)
        self.add_val(val, GGUFValueType.ARRAY)

    def add_val(self, val: Any, vtype: Optional[GGUFValueType] = None, add_vtype: bool = True):
        """Serialize one value into the key/value buffer.

        `vtype` defaults to the type inferred by GGUFValueType.get_type.
        When `add_vtype` is true the uint32 type tag is written first and the
        value counts as one key/value pair; keys and array items pass False.

        Raises ValueError for a type tag that is not a GGUFValueType.
        """
        if vtype is None:
            vtype = GGUFValueType.get_type(val)
        if add_vtype:
            self.kv_data += struct.pack("<I", vtype)
            self.kv_data_count += 1

        scalar_format = self._SCALAR_FORMATS.get(vtype)
        if scalar_format is not None:
            self.kv_data += struct.pack(scalar_format, val)
        elif vtype == GGUFValueType.STRING:
            # Strings are a uint64 byte length followed by the raw bytes.
            encoded_val = val.encode("utf8") if isinstance(val, str) else val
            self.kv_data += struct.pack("<Q", len(encoded_val))
            self.kv_data += encoded_val
        elif vtype == GGUFValueType.ARRAY:
            # Arrays are: item type tag, uint64 item count, then the items
            # without per-item type tags.
            item_types = {GGUFValueType.get_type(item) for item in val}
            assert len(item_types) == 1, "All items in a GGUF array should be of the same type"
            self.kv_data += struct.pack("<I", next(iter(item_types)))
            self.kv_data += struct.pack("<Q", len(val))
            for item in val:
                self.add_val(item, add_vtype=False)
        else:
            raise ValueError("Invalid GGUF metadata value type")

    @staticmethod
    def ggml_pad(x: int, n: int) -> int:
        """Round `x` up to the next multiple of `n`."""
        return ((x + n - 1) // n) * n

    def add_tensor_info(self, name: str, tensor_shape: np.ndarray, tensor_dtype: np.dtype, tensor_nbytes: int, raw_dtype: Optional[GGMLQuantizationType] = None):
        """Append one tensor-info record and advance the data offset.

        The record holds the name, the number of dimensions, the dimensions
        in reverse order of `tensor_shape`, the data type code and the
        tensor's offset inside the data section. Without `raw_dtype` only
        float32 and float16 tensors are accepted.
        """
        assert raw_dtype is not None or tensor_dtype in (np.float32, np.float16), "Only F32 and F16 tensors are supported for now"
        encoded_name = name.encode("utf8")
        self.ti_data += struct.pack("<Q", len(encoded_name))
        self.ti_data += encoded_name
        n_dims = len(tensor_shape)
        self.ti_data += struct.pack("<I", n_dims)
        for i in range(n_dims):
            self.ti_data += struct.pack("<Q", tensor_shape[n_dims - 1 - i])
        if raw_dtype is None:
            dtype = GGMLQuantizationType.F32 if tensor_dtype == np.float32 else GGMLQuantizationType.F16
        else:
            dtype = raw_dtype
        self.ti_data += struct.pack("<I", dtype)
        self.ti_data += struct.pack("<Q", self.offset_tensor)
        self.offset_tensor += GGUFWriter.ggml_pad(tensor_nbytes, self.data_alignment)
        self.ti_data_count += 1

    def add_tensor(self, name: str, tensor: np.ndarray, raw_shape: Optional[np.ndarray] = None, raw_dtype: Optional[GGMLQuantizationType] = None):
        """Register a tensor and stage its data for write_tensors_to_file.

        With a temp file the data (plus alignment padding) is written to the
        spool immediately; otherwise the array itself is kept in memory.
        `raw_shape` / `raw_dtype` override what is recorded in the tensor
        info.
        """
        if self.use_temp_file and self.temp_file is None:
            # Created lazily on the first tensor; spills to disk past 256 MiB.
            fp = tempfile.SpooledTemporaryFile(mode="w+b", max_size=256*1024*1024)
            fp.seek(0)
            self.temp_file = fp

        self.add_tensor_info(name, raw_shape if raw_shape is not None else tensor.shape, tensor.dtype, tensor.nbytes, raw_dtype = raw_dtype)

        pad = GGUFWriter.ggml_pad(tensor.nbytes, self.data_alignment) - tensor.nbytes

        if self.temp_file is None:
            self.tensors.append((tensor, pad))
            return

        tensor.tofile(self.temp_file)

        if pad != 0:
            self.temp_file.write(bytes(pad))

    def write_padding(self, fp: BinaryIO, n: int, align: Optional[int] = None):
        """Write the zero bytes needed to round `n` up to the alignment.

        `align` defaults to the writer's current data alignment.
        """
        pad = GGUFWriter.ggml_pad(n, align if align is not None else self.data_alignment) - n
        if pad != 0:
            fp.write(bytes(pad))

    def write_tensor_data(self, tensor: np.ndarray):
        """Write one tensor directly to the output file, aligned and padded."""
        self.write_padding(self.fout, self.fout.tell())
        tensor.tofile(self.fout)
        self.write_padding(self.fout, tensor.nbytes)

    def write_tensors_to_file(self):
        """Write the tensor infos, alignment padding and all staged tensor data."""
        self.write_ti_data_to_file()

        self.write_padding(self.fout, self.fout.tell())

        # No temp file means the tensors (if any) are held in memory. This
        # also covers use_temp_file=True when no tensor was ever added, where
        # the spool was never created.
        if self.temp_file is None:
            for (currtensor, currpad) in self.tensors:
                currtensor.tofile(self.fout)
                if currpad != 0:
                    self.fout.write(bytes(currpad))
            self.flush()
            return

        self.temp_file.seek(0)

        shutil.copyfileobj(self.temp_file, self.fout)
        self.flush()
        self.temp_file.close()

    def flush(self):
        """Flush the output file's buffer."""
        self.fout.flush()

    def close(self):
        """Close the output file and, if one was created, the temp file."""
        self.fout.close()
        # Closing an already-closed file is a no-op, so this is safe after
        # write_tensors_to_file and releases the spool if it was never run.
        if self.temp_file is not None:
            self.temp_file.close()

    def add_architecture(self):
        """Add the "general.architecture" key from `self.arch`."""
        self.add_string(KEY_GENERAL_ARCHITECTURE, self.arch)

    def add_author(self, author: str):
        """Add the "general.author" key."""
        self.add_string(KEY_GENERAL_AUTHOR, author)

    def add_tensor_data_layout(self, layout: str):
        """Add the "{arch}.tensor_data_layout" key."""
        self.add_string(KEY_TENSOR_DATA_LAYOUT.format(arch=self.arch), layout)

    def add_url(self, url: str):
        """Add the "general.url" key."""
        self.add_string(KEY_GENERAL_URL, url)

    def add_description(self, description: str):
        """Add the "general.description" key."""
        self.add_string(KEY_GENERAL_DESCRIPTION, description)

    def add_source_url(self, url: str):
        """Add the "general.source.url" key."""
        self.add_string(KEY_GENERAL_SOURCE_URL, url)

    def add_source_hf_repo(self, repo: str):
        """Add the source repository key (KEY_GENERAL_SOURCE_HF_REPO)."""
        self.add_string(KEY_GENERAL_SOURCE_HF_REPO, repo)

    def add_file_type(self, ftype: int):
        """Add the "general.file_type" key as uint32."""
        self.add_uint32(KEY_GENERAL_FILE_TYPE, ftype)

    def add_name(self, name: str):
        """Add the "general.name" key."""
        self.add_string(KEY_GENERAL_NAME, name)

    def add_quantization_version(self, quantization_version: GGMLQuantizationType):
        """Add the "general.quantization_version" key as uint32."""
        self.add_uint32(
            KEY_GENERAL_QUANTIZATION_VERSION, quantization_version)

    def add_custom_alignment(self, alignment: int):
        """Set the data alignment used from now on and record it as metadata.

        Tensors added earlier keep the offsets and padding computed with the
        previous alignment, so call this before adding tensors.
        """
        self.data_alignment = alignment
        self.add_uint32(KEY_GENERAL_ALIGNMENT, alignment)

    def add_context_length(self, length: int):
        """Add the "{arch}.context_length" key as uint32."""
        self.add_uint32(
            KEY_CONTEXT_LENGTH.format(arch=self.arch), length)

    def add_embedding_length(self, length: int):
        """Add the "{arch}.embedding_length" key as uint32."""
        self.add_uint32(
            KEY_EMBEDDING_LENGTH.format(arch=self.arch), length)

    def add_block_count(self, length: int):
        """Add the "{arch}.block_count" key as uint32."""
        self.add_uint32(
            KEY_BLOCK_COUNT.format(arch=self.arch), length)

    def add_feed_forward_length(self, length: int):
        """Add the "{arch}.feed_forward_length" key as uint32."""
        self.add_uint32(
            KEY_FEED_FORWARD_LENGTH.format(arch=self.arch), length)

    def add_parallel_residual(self, use: bool):
        """Add the "{arch}.use_parallel_residual" key as bool."""
        self.add_bool(
            KEY_USE_PARALLEL_RESIDUAL.format(arch=self.arch), use)

    def add_head_count(self, count: int):
        """Add the "{arch}.attention.head_count" key as uint32."""
        self.add_uint32(
            KEY_ATTENTION_HEAD_COUNT.format(arch=self.arch), count)

    def add_head_count_kv(self, count: int):
        """Add the "{arch}.attention.head_count_kv" key as uint32."""
        self.add_uint32(
            KEY_ATTENTION_HEAD_COUNT_KV.format(arch=self.arch), count)

    def add_max_alibi_bias(self, bias: float):
        """Add the "{arch}.attention.max_alibi_bias" key as float32."""
        self.add_float32(
            KEY_ATTENTION_MAX_ALIBI_BIAS.format(arch=self.arch), bias)

    def add_clamp_kqv(self, value: float):
        """Add the "{arch}.attention.clamp_kqv" key as float32."""
        self.add_float32(
            KEY_ATTENTION_CLAMP_KQV.format(arch=self.arch), value)

    def add_layer_norm_eps(self, value: float):
        """Add the "{arch}.attention.layer_norm_epsilon" key as float32."""
        self.add_float32(
            KEY_ATTENTION_LAYERNORM_EPS.format(arch=self.arch), value)

    def add_layer_norm_rms_eps(self, value: float):
        """Add the "{arch}.attention.layer_norm_rms_epsilon" key as float32."""
        self.add_float32(
            KEY_ATTENTION_LAYERNORM_RMS_EPS.format(arch=self.arch), value)

    def add_rope_dimension_count(self, count: int):
        """Add the "{arch}.rope.dimension_count" key as uint32."""
        self.add_uint32(
            KEY_ROPE_DIMENSION_COUNT.format(arch=self.arch), count)

    def add_rope_freq_base(self, value: float):
        """Add the "{arch}.rope.freq_base" key as float32."""
        self.add_float32(KEY_ROPE_FREQ_BASE.format(arch=self.arch), value)

    def add_rope_scale_linear(self, value: float):
        """Add the "{arch}.rope.scale_linear" key as float32."""
        self.add_float32(KEY_ROPE_SCALE_LINEAR.format(arch=self.arch), value)

    def add_tokenizer_model(self, model: str):
        """Add the "tokenizer.ggml.model" key."""
        self.add_string(KEY_TOKENIZER_MODEL, model)

    def add_token_list(self, tokens: List):
        """Add the "tokenizer.ggml.tokens" array."""
        self.add_array(KEY_TOKENIZER_LIST, tokens)

    def add_token_merges(self, merges: List):
        """Add the "tokenizer.ggml.merges" array."""
        self.add_array(KEY_TOKENIZER_MERGES, merges)

    def add_token_types(self, types: List[int]):
        """Add the "tokenizer.ggml.token_type" array."""
        self.add_array(KEY_TOKENIZER_TOKEN_TYPE, types)

    def add_token_scores(self, scores: List[float]):
        """Add the "tokenizer.ggml.scores" array."""
        self.add_array(KEY_TOKENIZER_SCORES, scores)

    def add_bos_token_id(self, id: int):
        """Add the "tokenizer.ggml.bos_token_id" key as uint32."""
        self.add_uint32(KEY_TOKENIZER_BOS_ID, id)

    def add_eos_token_id(self, id: int):
        """Add the "tokenizer.ggml.eos_token_id" key as uint32."""
        self.add_uint32(KEY_TOKENIZER_EOS_ID, id)

    def add_unk_token_id(self, id: int):
        """Add the "tokenizer.ggml.unknown_token_id" key as uint32."""
        self.add_uint32(KEY_TOKENIZER_UNK_ID, id)

    def add_sep_token_id(self, id: int):
        """Add the separator token id key (KEY_TOKENIZER_SEP_ID) as uint32."""
        self.add_uint32(KEY_TOKENIZER_SEP_ID, id)

    def add_pad_token_id(self, id: int):
        """Add the "tokenizer.ggml.padding_token_id" key as uint32."""
        self.add_uint32(KEY_TOKENIZER_PAD_ID, id)
# Example usage:
if __name__ == "__main__":
    # Example usage with a file
    # The constructor already writes the "general.architecture" key, so
    # add_architecture() is not called again here (that would store the key
    # twice in the metadata).
    gguf_writer = GGUFWriter("example.gguf", "llama")

    gguf_writer.add_block_count(12)
    gguf_writer.add_uint32("answer", 42)  # Write a 32-bit integer
    gguf_writer.add_float32("answer_in_float", 42.0)  # Write a 32-bit float
    # Set the alignment before adding tensors so every offset uses it.
    gguf_writer.add_custom_alignment(64)

    tensor1 = np.ones((32,), dtype=np.float32) * 100.0
    tensor2 = np.ones((64,), dtype=np.float32) * 101.0
    tensor3 = np.ones((96,), dtype=np.float32) * 102.0

    gguf_writer.add_tensor("tensor1", tensor1)
    gguf_writer.add_tensor("tensor2", tensor2)
    gguf_writer.add_tensor("tensor3", tensor3)

    # Sections must be written in file order: header, metadata, tensors.
    gguf_writer.write_header_to_file()
    gguf_writer.write_kv_data_to_file()
    gguf_writer.write_tensors_to_file()

    gguf_writer.close()