-
Notifications
You must be signed in to change notification settings - Fork 1k
/
Copy pathllama.py
2418 lines (2167 loc) · 93.8 KB
/
llama.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
from __future__ import annotations
import os
import sys
import uuid
import time
import json
import ctypes
import typing
import random
import fnmatch
import warnings
import contextlib
import multiprocessing
from typing import (
Any,
List,
Literal,
Optional,
Union,
Generator,
Sequence,
Iterator,
Deque,
Callable,
Dict,
)
from collections import deque
from pathlib import Path
from .llama_types import *
from .llama_grammar import LlamaGrammar
from .llama_cache import (
BaseLlamaCache,
LlamaCache, # type: ignore
LlamaDiskCache, # type: ignore
LlamaRAMCache, # type: ignore
)
from .llama_tokenizer import BaseLlamaTokenizer, LlamaTokenizer
import llama_cpp.llama_cpp as llama_cpp
import llama_cpp.llama_chat_format as llama_chat_format
from llama_cpp.llama_speculative import LlamaDraftModel
import numpy as np
import numpy.typing as npt
import llama_cpp._internals as internals
from ._logger import set_verbose
from ._utils import suppress_stdout_stderr
class Llama:
"""High-level Python wrapper for a llama.cpp model."""
__backend_initialized = False
def __init__(
    self,
    model_path: str,
    *,
    # Model Params
    n_gpu_layers: int = 0,
    split_mode: int = llama_cpp.LLAMA_SPLIT_MODE_LAYER,
    main_gpu: int = 0,
    tensor_split: Optional[List[float]] = None,
    rpc_servers: Optional[str] = None,
    vocab_only: bool = False,
    use_mmap: bool = True,
    use_mlock: bool = False,
    kv_overrides: Optional[Dict[str, Union[bool, int, float, str]]] = None,
    # Context Params
    seed: int = llama_cpp.LLAMA_DEFAULT_SEED,
    n_ctx: int = 512,
    n_batch: int = 512,
    n_ubatch: int = 512,
    n_threads: Optional[int] = None,
    n_threads_batch: Optional[int] = None,
    rope_scaling_type: Optional[
        int
    ] = llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
    pooling_type: int = llama_cpp.LLAMA_POOLING_TYPE_UNSPECIFIED,
    rope_freq_base: float = 0.0,
    rope_freq_scale: float = 0.0,
    yarn_ext_factor: float = -1.0,
    yarn_attn_factor: float = 1.0,
    yarn_beta_fast: float = 32.0,
    yarn_beta_slow: float = 1.0,
    yarn_orig_ctx: int = 0,
    logits_all: bool = False,
    embedding: bool = False,
    offload_kqv: bool = True,
    flash_attn: bool = False,
    # Sampling Params
    no_perf: bool = False,
    last_n_tokens_size: int = 64,
    # LoRA Params
    lora_base: Optional[str] = None,
    lora_scale: float = 1.0,
    lora_path: Optional[str] = None,
    # Backend Params
    numa: Union[bool, int] = False,
    # Chat Format Params
    chat_format: Optional[str] = None,
    chat_handler: Optional[llama_chat_format.LlamaChatCompletionHandler] = None,
    # Speculative Decoding
    draft_model: Optional[LlamaDraftModel] = None,
    # Tokenizer Override
    tokenizer: Optional[BaseLlamaTokenizer] = None,
    # KV cache quantization
    type_k: Optional[int] = None,
    type_v: Optional[int] = None,
    # Misc
    spm_infill: bool = False,
    verbose: bool = True,
    # Extra Params
    **kwargs,  # type: ignore
):
    """Load a llama.cpp model from `model_path`.

    Examples:
        Basic usage

        >>> import llama_cpp
        >>> model = llama_cpp.Llama(
        ...     model_path="path/to/model",
        ... )
        >>> print(model("The quick brown fox jumps ", stop=["."])["choices"][0]["text"])
        the lazy dog

        Loading a chat model

        >>> import llama_cpp
        >>> model = llama_cpp.Llama(
        ...     model_path="path/to/model",
        ...     chat_format="llama-2",
        ... )
        >>> print(model.create_chat_completion(
        ...     messages=[{
        ...         "role": "user",
        ...         "content": "what is the meaning of life?"
        ...     }]
        ... ))

    Args:
        model_path: Path to the model.
        n_gpu_layers: Number of layers to offload to GPU (-ngl). If -1, all layers are offloaded.
        split_mode: How to split the model across GPUs. See llama_cpp.LLAMA_SPLIT_* for options.
        main_gpu: main_gpu interpretation depends on split_mode: LLAMA_SPLIT_MODE_NONE: the GPU that is used for the entire model. LLAMA_SPLIT_MODE_ROW: the GPU that is used for small tensors and intermediate results. LLAMA_SPLIT_MODE_LAYER: ignored
        tensor_split: How split tensors should be distributed across GPUs. If None, the model is not split.
        rpc_servers: Comma separated list of RPC servers to use for offloading
        vocab_only: Only load the vocabulary no weights.
        use_mmap: Use mmap if possible. Forced off when `lora_path` is given.
        use_mlock: Force the system to keep the model in RAM.
        kv_overrides: Key-value overrides for the model. Values may be bool, int, float or str; str values must encode to at most 127 UTF-8 bytes.
        seed: RNG seed, -1 for random
        n_ctx: Text context, 0 = from model
        n_batch: Prompt processing maximum batch size
        n_ubatch: Physical batch size
        n_threads: Number of threads to use for generation
        n_threads_batch: Number of threads to use for batch processing
        rope_scaling_type: RoPE scaling type, from `enum llama_rope_scaling_type`. ref: https://github.com/ggerganov/llama.cpp/pull/2054
        pooling_type: Pooling type, from `enum llama_pooling_type`.
        rope_freq_base: RoPE base frequency, 0 = from model
        rope_freq_scale: RoPE frequency scaling factor, 0 = from model
        yarn_ext_factor: YaRN extrapolation mix factor, negative = from model
        yarn_attn_factor: YaRN magnitude scaling factor
        yarn_beta_fast: YaRN low correction dim
        yarn_beta_slow: YaRN high correction dim
        yarn_orig_ctx: YaRN original context size
        logits_all: Return logits for all tokens, not just the last token. Must be True for completion to return logprobs. Always enabled when `draft_model` is set.
        embedding: Embedding mode only.
        offload_kqv: Offload K, Q, V to GPU.
        flash_attn: Use flash attention.
        no_perf: If True, disable performance timing measurements.
        last_n_tokens_size: Maximum number of tokens to keep in the last_n_tokens deque.
        lora_base: Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.
        lora_scale: Scale passed to llama.cpp when applying the LoRA adapter loaded from `lora_path`.
        lora_path: Path to a LoRA file to apply to the model.
        numa: numa policy
        chat_format: String specifying the chat format to use when calling create_chat_completion.
        chat_handler: Optional chat handler to use when calling create_chat_completion.
        draft_model: Optional draft model to use for speculative decoding.
        tokenizer: Optional tokenizer to override the default tokenizer from llama.cpp.
        verbose: Print verbose output to stderr.
        type_k: KV cache data type for K (default: f16)
        type_v: KV cache data type for V (default: f16)
        spm_infill: Use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this.

    Raises:
        ValueError: If the model path does not exist, `tensor_split` has more
            entries than `LLAMA_MAX_DEVICES`, or a `kv_overrides` value is too
            long or of an unsupported type.
        RuntimeError: If the LoRA adapter cannot be initialized or applied.
    """
    self.verbose = verbose
    self._stack = contextlib.ExitStack()

    set_verbose(verbose)

    # Backend initialization is process-wide; do it only once per class.
    if not Llama.__backend_initialized:
        with suppress_stdout_stderr(disable=verbose):
            llama_cpp.llama_backend_init()
        Llama.__backend_initialized = True

    if isinstance(numa, bool):
        self.numa = (
            llama_cpp.GGML_NUMA_STRATEGY_DISTRIBUTE
            if numa
            else llama_cpp.GGML_NUMA_STRATEGY_DISABLED
        )
    else:
        self.numa = numa

    if self.numa != llama_cpp.GGML_NUMA_STRATEGY_DISABLED:
        with suppress_stdout_stderr(disable=verbose):
            llama_cpp.llama_numa_init(self.numa)

    self.model_path = model_path

    # Model Params
    self.model_params = llama_cpp.llama_model_default_params()
    self.model_params.n_gpu_layers = (
        0x7FFFFFFF if n_gpu_layers == -1 else n_gpu_layers
    )  # 0x7FFFFFFF is INT32 max, will be auto set to all layers
    self.model_params.split_mode = split_mode
    self.model_params.main_gpu = main_gpu
    if rpc_servers is not None:
        self.model_params.rpc_servers = rpc_servers.encode("utf-8")
        self._rpc_servers = rpc_servers
    else:
        self._rpc_servers = None
    self.tensor_split = tensor_split
    self._c_tensor_split = None
    if self.tensor_split is not None:
        if len(self.tensor_split) > llama_cpp.LLAMA_MAX_DEVICES:
            raise ValueError(
                f"Attempt to split tensors that exceed maximum supported devices. Current LLAMA_MAX_DEVICES={llama_cpp.LLAMA_MAX_DEVICES}"
            )
        # Type conversion and expand the list to the length of LLAMA_MAX_DEVICES
        FloatArray = ctypes.c_float * llama_cpp.LLAMA_MAX_DEVICES
        self._c_tensor_split = FloatArray(
            *tensor_split  # type: ignore
        )  # keep a reference to the array so it is not gc'd
        self.model_params.tensor_split = self._c_tensor_split
    self.model_params.vocab_only = vocab_only
    self.model_params.use_mmap = use_mmap if lora_path is None else False
    self.model_params.use_mlock = use_mlock

    # kv_overrides is the original python dict
    self.kv_overrides = kv_overrides
    if kv_overrides is not None:
        # _kv_overrides_array is a ctypes.Array of llama_model_kv_override Structs
        kvo_array_len = len(kv_overrides) + 1  # for sentinel element
        self._kv_overrides_array = (
            llama_cpp.llama_model_kv_override * kvo_array_len
        )()

        for i, (k, v) in enumerate(kv_overrides.items()):
            self._kv_overrides_array[i].key = k.encode("utf-8")
            # bool is checked before int because bool is a subclass of int.
            if isinstance(v, bool):
                self._kv_overrides_array[
                    i
                ].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_BOOL
                self._kv_overrides_array[i].value.val_bool = v
            elif isinstance(v, int):
                self._kv_overrides_array[
                    i
                ].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_INT
                self._kv_overrides_array[i].value.val_i64 = v
            elif isinstance(v, float):
                self._kv_overrides_array[
                    i
                ].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_FLOAT
                self._kv_overrides_array[i].value.val_f64 = v
            elif isinstance(v, str):  # type: ignore
                v_bytes = v.encode("utf-8")
                # val_str is a 128-byte buffer that is zero-padded below; keep
                # room for at least one trailing NUL so the value stays
                # terminated (presumably read as a C string by llama.cpp).
                if len(v_bytes) > 127:  # TODO: Make this a constant
                    raise ValueError(f"Value for {k} is too long: {v}")
                v_bytes = v_bytes.ljust(128, b"\0")
                self._kv_overrides_array[
                    i
                ].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_STR
                # copy the padded 128 bytes into str_value
                address = typing.cast(
                    int,
                    ctypes.addressof(self._kv_overrides_array[i].value)
                    + llama_cpp.llama_model_kv_override_value.val_str.offset,
                )
                buffer_start = ctypes.cast(address, ctypes.POINTER(ctypes.c_char))
                ctypes.memmove(
                    buffer_start,
                    v_bytes,
                    128,
                )
            else:
                raise ValueError(f"Unknown value type for {k}: {v}")

        self._kv_overrides_array[
            -1
        ].key = b"\0"  # ensure sentinel element is zeroed
        self.model_params.kv_overrides = self._kv_overrides_array

    self.n_batch = min(n_ctx, n_batch)  # ???
    self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)
    self.n_threads_batch = n_threads_batch or multiprocessing.cpu_count()

    # Used by the sampler
    self._seed = seed or llama_cpp.LLAMA_DEFAULT_SEED

    # Context Params
    self.context_params = llama_cpp.llama_context_default_params()
    self.context_params.n_ctx = n_ctx
    self.context_params.n_batch = self.n_batch
    self.context_params.n_ubatch = min(self.n_batch, n_ubatch)
    self.context_params.n_threads = self.n_threads
    self.context_params.n_threads_batch = self.n_threads_batch
    self.context_params.rope_scaling_type = (
        rope_scaling_type
        if rope_scaling_type is not None
        else llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED
    )
    self.context_params.pooling_type = pooling_type
    self.context_params.rope_freq_base = (
        rope_freq_base if rope_freq_base != 0.0 else 0
    )
    self.context_params.rope_freq_scale = (
        rope_freq_scale if rope_freq_scale != 0.0 else 0
    )
    self.context_params.yarn_ext_factor = (
        yarn_ext_factor if yarn_ext_factor != 0.0 else 0
    )
    self.context_params.yarn_attn_factor = (
        yarn_attn_factor if yarn_attn_factor != 0.0 else 0
    )
    self.context_params.yarn_beta_fast = (
        yarn_beta_fast if yarn_beta_fast != 0.0 else 0
    )
    self.context_params.yarn_beta_slow = (
        yarn_beta_slow if yarn_beta_slow != 0.0 else 0
    )
    self.context_params.yarn_orig_ctx = yarn_orig_ctx if yarn_orig_ctx != 0 else 0
    self.context_params.logits_all = (
        logits_all if draft_model is None else True
    )  # Must be set to True for speculative decoding
    self.context_params.embeddings = embedding  # TODO: Rename to embeddings
    self.context_params.offload_kqv = offload_kqv
    self.context_params.flash_attn = flash_attn
    # KV cache quantization
    if type_k is not None:
        self.context_params.type_k = type_k
    if type_v is not None:
        self.context_params.type_v = type_v
    # Sampling Params
    self.context_params.no_perf = no_perf
    self.last_n_tokens_size = last_n_tokens_size

    self.cache: Optional[BaseLlamaCache] = None

    self.lora_base = lora_base
    self.lora_scale = lora_scale
    self.lora_path = lora_path

    self.spm_infill = spm_infill

    if not os.path.exists(model_path):
        raise ValueError(f"Model path does not exist: {model_path}")

    # Native resources are registered on the ExitStack so they are released
    # in reverse order of acquisition.
    self._model = self._stack.enter_context(
        contextlib.closing(
            internals.LlamaModel(
                path_model=self.model_path,
                params=self.model_params,
                verbose=self.verbose,
            )
        )
    )

    # Override tokenizer
    self.tokenizer_ = tokenizer or LlamaTokenizer(self)

    # Set the default value for the context and correct the batch
    if n_ctx == 0:
        n_ctx = self._model.n_ctx_train()
        self.n_batch = min(n_ctx, n_batch)
        self.context_params.n_ctx = self._model.n_ctx_train()
        self.context_params.n_batch = self.n_batch
        self.context_params.n_ubatch = min(self.n_batch, n_ubatch)

    self._ctx = self._stack.enter_context(
        contextlib.closing(
            internals.LlamaContext(
                model=self._model,
                params=self.context_params,
                verbose=self.verbose,
            )
        )
    )

    self._batch = self._stack.enter_context(
        contextlib.closing(
            internals.LlamaBatch(
                n_tokens=self.n_batch,
                embd=0,
                n_seq_max=self.context_params.n_ctx,
                verbose=self.verbose,
            )
        )
    )

    self._lora_adapter: Optional[llama_cpp.llama_adapter_lora_p] = None

    if self.lora_path:
        self._lora_adapter = llama_cpp.llama_adapter_lora_init(
            self._model.model,
            self.lora_path.encode("utf-8"),
        )
        if self._lora_adapter is None:
            raise RuntimeError(
                f"Failed to initialize LoRA adapter from lora path: {self.lora_path}"
            )

        def free_lora_adapter():
            if self._lora_adapter is None:
                return
            llama_cpp.llama_adapter_lora_free(self._lora_adapter)
            self._lora_adapter = None

        self._stack.callback(free_lora_adapter)

        # A non-zero return value is treated as failure.
        if llama_cpp.llama_set_adapter_lora(
            self._ctx.ctx, self._lora_adapter, self.lora_scale
        ):
            raise RuntimeError(
                f"Failed to set LoRA adapter from lora path: {self.lora_path}"
            )

    if self.verbose:
        print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)

    self.chat_format = chat_format
    self.chat_handler = chat_handler
    self._chat_handlers: Dict[
        str, llama_chat_format.LlamaChatCompletionHandler
    ] = {}

    self.draft_model = draft_model

    self._n_vocab = self.n_vocab()
    self._n_ctx = self.n_ctx()

    self._token_nl = self.token_nl()
    self._token_eos = self.token_eos()

    self._candidates = internals.LlamaTokenDataArray(n_vocab=self._n_vocab)

    self.n_tokens = 0
    self.input_ids: npt.NDArray[np.intc] = np.ndarray((n_ctx,), dtype=np.intc)
    # Size the score buffer from the *effective* logits_all setting: a draft
    # model forces logits_all on above, and eval() then stores one row per
    # evaluated token, which would not fit in an n_batch-row buffer.
    self.scores: npt.NDArray[np.single] = np.ndarray(
        (n_ctx if self.context_params.logits_all else n_batch, self._n_vocab),
        dtype=np.single,
    )

    self._mirostat_mu = ctypes.c_float(
        2.0 * 5.0
    )  # TODO: Move this to sampling context

    # Metadata is best-effort: failures fall back to an empty dict.
    try:
        self.metadata = self._model.metadata()
    except Exception as e:
        self.metadata = {}
        if self.verbose:
            print(f"Failed to load metadata: {e}", file=sys.stderr)

    if self.verbose:
        print(f"Model metadata: {self.metadata}", file=sys.stderr)

    eos_token_id = self.token_eos()
    bos_token_id = self.token_bos()

    eos_token = (
        self._model.token_get_text(eos_token_id) if eos_token_id != -1 else ""
    )
    bos_token = (
        self._model.token_get_text(bos_token_id) if bos_token_id != -1 else ""
    )

    # Unfortunately the llama.cpp API does not return metadata arrays, so we can't get template names from tokenizer.chat_templates
    # name[10:] strips the "tokenizer." prefix, leaving "chat_template.<name>".
    template_choices = dict(
        (name[10:], template)
        for name, template in self.metadata.items()
        if name.startswith("tokenizer.chat_template.")
    )

    if "tokenizer.chat_template" in self.metadata:
        template_choices["chat_template.default"] = self.metadata[
            "tokenizer.chat_template"
        ]

    if self.verbose and template_choices:
        print(
            f"Available chat formats from metadata: {', '.join(template_choices.keys())}",
            file=sys.stderr,
        )

    for name, template in template_choices.items():
        self._chat_handlers[name] = llama_chat_format.Jinja2ChatFormatter(
            template=template,
            eos_token=eos_token,
            bos_token=bos_token,
            stop_token_ids=[eos_token_id],
        ).to_chat_handler()

    if (
        self.chat_format is None
        and self.chat_handler is None
        and "chat_template.default" in template_choices
    ):
        chat_format = llama_chat_format.guess_chat_format_from_gguf_metadata(
            self.metadata
        )

        if chat_format is not None:
            self.chat_format = chat_format
            if self.verbose:
                print(f"Guessed chat format: {chat_format}", file=sys.stderr)
        else:
            if self.verbose:
                print(
                    f"Using gguf chat template: {template_choices['chat_template.default']}",
                    file=sys.stderr,
                )
                print(f"Using chat eos_token: {eos_token}", file=sys.stderr)
                print(f"Using chat bos_token: {bos_token}", file=sys.stderr)

            self.chat_format = "chat_template.default"

    if self.chat_format is None and self.chat_handler is None:
        self.chat_format = "llama-2"
        if self.verbose:
            print(
                f"Using fallback chat format: {self.chat_format}", file=sys.stderr
            )

    self._sampler = None
@property
def ctx(self) -> llama_cpp.llama_context_p:
    """The raw ``llama_context`` handle owned by the internal context wrapper."""
    context_wrapper = self._ctx
    return context_wrapper.ctx
@property
def model(self) -> llama_cpp.llama_model_p:
    """The raw ``llama_model`` handle owned by the internal model wrapper."""
    model_wrapper = self._model
    return model_wrapper.model
@property
def _input_ids(self) -> npt.NDArray[np.intc]:
    """View of ``input_ids`` restricted to the tokens evaluated so far."""
    evaluated_count = self.n_tokens
    return self.input_ids[:evaluated_count]
@property
def _scores(self) -> npt.NDArray[np.single]:
    """View of the stored logit rows for the tokens evaluated so far."""
    evaluated_rows = self.n_tokens
    return self.scores[:evaluated_rows, :]
@property
def eval_tokens(self) -> Deque[int]:
    """Evaluated token ids as a fresh deque bounded by the context size."""
    evaluated = self.input_ids[: self.n_tokens]
    as_list = evaluated.tolist()
    return deque(as_list, maxlen=self._n_ctx)
@property
def eval_logits(self) -> Deque[List[float]]:
    """Stored logit rows as a fresh deque.

    The deque holds up to ``_n_ctx`` rows when the context keeps logits for
    every token, and only the most recent row otherwise.
    """
    rows = self.scores[: self.n_tokens, :].tolist()
    capacity = self._n_ctx if self.context_params.logits_all else 1
    return deque(rows, maxlen=capacity)
def tokenize(
    self, text: bytes, add_bos: bool = True, special: bool = False
) -> List[int]:
    """Convert UTF-8 encoded bytes into token ids.

    Delegates to the configured tokenizer (``self.tokenizer_``).

    Args:
        text: The utf-8 encoded string to tokenize.
        add_bos: Whether to prepend a beginning-of-sequence token.
        special: Whether special-token text should be parsed as special tokens.

    Raises:
        RuntimeError: If the tokenization failed.

    Returns:
        A list of tokens.
    """
    active_tokenizer = self.tokenizer_
    token_ids = active_tokenizer.tokenize(text, add_bos, special)
    return token_ids
def detokenize(
    self,
    tokens: List[int],
    prev_tokens: Optional[List[int]] = None,
    special: bool = False,
) -> bytes:
    """Convert token ids back into UTF-8 encoded bytes.

    Delegates to the configured tokenizer (``self.tokenizer_``).

    Args:
        tokens: The token ids to convert.
        prev_tokens: Tokens preceding ``tokens``; when given, offset mapping is performed.
        special: Whether special tokens are rendered in the output.

    Returns:
        The detokenized bytes.
    """
    active_tokenizer = self.tokenizer_
    text_bytes = active_tokenizer.detokenize(
        tokens,
        prev_tokens=prev_tokens,
        special=special,
    )
    return text_bytes
def set_cache(self, cache: Optional[BaseLlamaCache]):
    """Install a cache object on this instance.

    Args:
        cache: The cache to use, or None to remove the current one.
    """
    new_cache = cache
    self.cache = new_cache
def set_seed(self, seed: int):
    """Replace the seed that subsequently built samplers will use.

    Args:
        seed: New seed value, stored as given.
    """
    new_seed = seed
    self._seed = new_seed
def reset(self):
    """Rewind the evaluated-token counter to zero.

    Only ``n_tokens`` is changed; cached KV entries beyond the new position
    are removed by the next ``eval`` call.
    """
    # Nothing else is cleared here: input_ids/scores are simply overwritten later.
    self.n_tokens = 0
def eval(self, tokens: Sequence[int]):
    """Feed ``tokens`` through the model after the already-evaluated ones.

    KV-cache entries at or after position ``self.n_tokens`` are dropped first.
    The tokens are then decoded in chunks of at most ``self.n_batch``. Each
    chunk is recorded in ``input_ids``. When the context keeps logits for every
    token (``logits_all``), the chunk's logits are also copied into ``scores``.

    Args:
        tokens: The list of tokens to evaluate.
    """
    # Discard anything cached past the current position (e.g. after a rewind).
    self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1)
    step = self.n_batch
    total = len(tokens)
    for start in range(0, total, step):
        chunk = tokens[start : min(total, start + step)]
        offset = self.n_tokens
        count = len(chunk)
        end = offset + count

        self._batch.set_batch(
            batch=chunk, n_past=offset, logits_all=self.context_params.logits_all
        )
        self._ctx.decode(self._batch)

        self.input_ids[offset:end] = chunk

        # Sampling happens inside the native sampler, so a Python-side copy
        # of the logits is only kept when logits_all is on (needed for
        # logprobs); otherwise nothing is stored.
        if self.context_params.logits_all:
            flat_logits = np.ctypeslib.as_array(
                self._ctx.get_logits(), shape=(count * self._n_vocab,)
            )
            self.scores[offset:end, :].reshape(-1)[::] = flat_logits

        self.n_tokens += count
def _init_sampler(
    self,
    top_k: int = 40,
    top_p: float = 0.95,
    min_p: float = 0.05,
    typical_p: float = 1.0,
    temp: float = 0.80,
    repeat_penalty: float = 1.0,
    frequency_penalty: float = 0.0,
    presence_penalty: float = 0.0,
    tfs_z: float = 1.0,
    mirostat_mode: int = 0,
    mirostat_eta: float = 0.1,
    mirostat_tau: float = 5.0,
    penalize_nl: bool = True,
    logits_processor: Optional[LogitsProcessorList] = None,
    grammar: Optional[LlamaGrammar] = None,
):
    """Build a sampler chain for the given sampling parameters.

    Stages are added in this order:

    1. An optional custom stage that runs ``logits_processor``.
    2. Repetition, frequency and presence penalties.
    3. An optional grammar constraint.
    4. The token-selection stage chosen by ``temp`` and ``mirostat_mode``:

       - ``temp < 0``: softmax, then random draw.
       - ``temp == 0``: greedy.
       - otherwise, mirostat v1/v2 (after temperature scaling).
       - otherwise, top-k / typical / top-p / min-p, then temperature, then a
         random draw.

    ``tfs_z`` is accepted for backward compatibility but is not used.

    Returns:
        The configured ``internals.LlamaSampler``.
    """
    sampler = internals.LlamaSampler()

    if logits_processor is not None:
        # Create and add a custom sampler
        def apply_func(token_data_array: llama_cpp.llama_token_data_array_p):
            size = token_data_array.contents.size
            data_soa = token_data_array.contents.data
            data_soa_address = ctypes.addressof(data_soa.contents)
            # NOTE: This is probably broken
            recarray = np.recarray(
                shape=(size,),
                dtype=np.dtype(
                    [("id", np.intc), ("logit", np.single), ("p", np.single)],
                    align=True,
                ),
                buf=(llama_cpp.llama_token_data * size).from_address(
                    data_soa_address
                ),
            )
            for logit_processor in logits_processor:
                recarray.logit[:] = logit_processor(self._input_ids, recarray.logit)

        sampler.add_custom(apply_func)

    sampler.add_penalties(
        n_vocab=self._n_vocab,
        special_eos_id=self._token_eos,
        linefeed_id=self._token_nl,
        penalty_last_n=self.last_n_tokens_size,
        penalty_repeat=repeat_penalty,
        penalty_freq=frequency_penalty,
        penalty_present=presence_penalty,
        penalize_nl=penalize_nl,
        ignore_eos=False,
    )

    if grammar is not None:
        sampler.add_grammar(self._model, grammar)

    if temp < 0.0:
        sampler.add_softmax()
        sampler.add_dist(self._seed)
    elif temp == 0.0:
        sampler.add_greedy()
    else:
        if mirostat_mode == 1:
            # Apply temperature before mirostat, as llama.cpp's common
            # sampler does; otherwise `temp` would be silently ignored here.
            sampler.add_temp(temp)
            mirostat_m = 100
            sampler.add_mirostat(
                self._n_vocab,
                self._seed,
                mirostat_tau,
                mirostat_eta,
                mirostat_m,
            )
        elif mirostat_mode == 2:
            # Same temperature stage as mirostat v1 above.
            sampler.add_temp(temp)
            sampler.add_mirostat_v2(
                self._seed,
                mirostat_tau,
                mirostat_eta,
            )
        else:
            n_probs = 0
            min_keep = max(1, n_probs)
            sampler.add_top_k(top_k)
            sampler.add_typical(typical_p, min_keep)
            sampler.add_top_p(top_p, min_keep)
            sampler.add_min_p(min_p, min_keep)
            sampler.add_temp(temp)
            sampler.add_dist(self._seed)
    return sampler
def sample(
    self,
    top_k: int = 40,
    top_p: float = 0.95,
    min_p: float = 0.05,
    typical_p: float = 1.0,
    temp: float = 0.80,
    repeat_penalty: float = 1.0,
    frequency_penalty: float = 0.0,
    presence_penalty: float = 0.0,
    tfs_z: float = 1.0,
    mirostat_mode: int = 0,
    mirostat_eta: float = 0.1,
    mirostat_tau: float = 5.0,
    penalize_nl: bool = True,
    logits_processor: Optional[LogitsProcessorList] = None,
    grammar: Optional[LlamaGrammar] = None,
    idx: Optional[int] = None,
):
    """Sample a token from the model.

    If a sampler is already installed in ``self._sampler`` (as ``generate``
    does), it is used and the sampling arguments below are ignored. Otherwise
    a temporary sampler is built from them. It is always removed again
    afterwards, even when sampling raises.

    Args:
        top_k: The top-k sampling parameter.
        top_p: The top-p sampling parameter.
        min_p: The min-p sampling parameter.
        typical_p: The typical-p sampling parameter.
        temp: The temperature parameter.
        repeat_penalty: The repeat penalty parameter.
        frequency_penalty: The frequency penalty parameter.
        presence_penalty: The presence penalty parameter.
        tfs_z: Accepted for compatibility; forwarded to ``_init_sampler``.
        mirostat_mode: 0 = disabled, 1 = mirostat, 2 = mirostat v2.
        mirostat_eta: Mirostat learning rate.
        mirostat_tau: Mirostat target entropy.
        penalize_nl: Whether the newline token is subject to penalties.
        logits_processor: Optional logits processors applied before sampling.
        grammar: Optional grammar constraining the sampled token.
        idx: Absolute token position whose logits are sampled. Defaults to the
            last evaluated token.

    Returns:
        The sampled token.
    """
    assert self.n_tokens > 0

    tmp_sampler = False

    if self._sampler is None:
        tmp_sampler = True
        self._sampler = self._init_sampler(
            top_k=top_k,
            top_p=top_p,
            min_p=min_p,
            typical_p=typical_p,
            temp=temp,
            repeat_penalty=repeat_penalty,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
            tfs_z=tfs_z,
            mirostat_mode=mirostat_mode,
            mirostat_tau=mirostat_tau,
            mirostat_eta=mirostat_eta,
            penalize_nl=penalize_nl,
            logits_processor=logits_processor,
            grammar=grammar,
        )

    try:
        # The sampler takes an index relative to the end of the evaluated
        # tokens (-1 = last).
        ridx = idx - self.n_tokens if idx is not None else -1

        assert self.ctx is not None
        token = self._sampler.sample(self._ctx, ridx)
    finally:
        # Never leave a one-off sampler installed: later calls would
        # otherwise silently reuse these parameters.
        if tmp_sampler:
            self._sampler = None
    return token
def generate(
    self,
    tokens: Sequence[int],
    top_k: int = 40,
    top_p: float = 0.95,
    min_p: float = 0.05,
    typical_p: float = 1.0,
    temp: float = 0.80,
    repeat_penalty: float = 1.0,
    reset: bool = True,
    frequency_penalty: float = 0.0,
    presence_penalty: float = 0.0,
    tfs_z: float = 1.0,
    mirostat_mode: int = 0,
    mirostat_tau: float = 5.0,
    mirostat_eta: float = 0.1,
    penalize_nl: bool = True,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    grammar: Optional[LlamaGrammar] = None,
) -> Generator[int, Optional[Sequence[int]], None]:
    """Create a generator of tokens from a prompt.

    Examples:
        >>> llama = Llama("models/ggml-7b.bin")
        >>> tokens = llama.tokenize(b"Hello, world!")
        >>> for token in llama.generate(tokens, top_k=40, top_p=0.95, temp=1.0, repeat_penalty=1.0):
        ...     print(llama.detokenize([token]))

    Args:
        tokens: The prompt tokens.
        top_k: The top-k sampling parameter.
        top_p: The top-p sampling parameter.
        min_p: The min-p sampling parameter.
        typical_p: The typical-p sampling parameter.
        temp: The temperature parameter.
        repeat_penalty: The repeat penalty parameter.
        reset: Whether to reset the model state. If the already-evaluated
            tokens share a prefix with ``tokens``, that prefix is reused and
            only the remainder is evaluated.
        frequency_penalty: The frequency penalty parameter.
        presence_penalty: The presence penalty parameter.
        tfs_z: Accepted for compatibility; forwarded to the sampler builder.
        mirostat_mode: 0 = disabled, 1 = mirostat, 2 = mirostat v2.
        mirostat_tau: Mirostat target entropy.
        mirostat_eta: Mirostat learning rate.
        penalize_nl: Whether the newline token is subject to penalties.
        logits_processor: Optional logits processors applied before sampling.
        stopping_criteria: Called after each sampled token with the token ids
            up to and including it, and the ``_scores`` row at the position
            the token was sampled from. That row is only populated when the
            context keeps all logits. Returning True ends generation.
        grammar: Optional grammar constraining the sampled tokens.

    Yields:
        The generated tokens. A sequence sent into the generator is appended
        after the yielded token to the tokens evaluated next.
    """
    # Reset mirostat sampling
    self._mirostat_mu = ctypes.c_float(2.0 * mirostat_tau)
    # Install one sampler for the whole generation; sample() reuses it.
    self._sampler = self._init_sampler(
        top_k=top_k,
        top_p=top_p,
        min_p=min_p,
        typical_p=typical_p,
        temp=temp,
        repeat_penalty=repeat_penalty,
        frequency_penalty=frequency_penalty,
        presence_penalty=presence_penalty,
        tfs_z=tfs_z,
        mirostat_mode=mirostat_mode,
        mirostat_tau=mirostat_tau,
        mirostat_eta=mirostat_eta,
        penalize_nl=penalize_nl,
        logits_processor=logits_processor,
        grammar=grammar,
    )

    # Check for kv cache prefix match
    if reset and self.n_tokens > 0:
        longest_prefix = 0
        # tokens[:-1] keeps at least one prompt token to evaluate, so there
        # are always fresh logits to sample from.
        for a, b in zip(self._input_ids, tokens[:-1]):
            if a == b:
                longest_prefix += 1
            else:
                break
        if longest_prefix > 0:
            reset = False
            tokens = tokens[longest_prefix:]
            self.n_tokens = longest_prefix
            if self.verbose:
                print(
                    f"Llama.generate: {longest_prefix} prefix-match hit, "
                    f"remaining {len(tokens)} prompt tokens to eval",
                    file=sys.stderr,
                )

    # Reset the model state
    if reset:
        self.reset()

    # # Reset the grammar
    # if grammar is not None:
    #     grammar.reset()

    # Absolute position of the first logits row to sample from: the last
    # prompt token.
    sample_idx = self.n_tokens + len(tokens) - 1
    tokens = list(tokens)

    # Eval and sample
    while True:
        self.eval(tokens)
        # More than one position may be pending when draft tokens were
        # evaluated speculatively.
        while sample_idx < self.n_tokens:
            token = self.sample(
                top_k=top_k,
                top_p=top_p,
                min_p=min_p,
                typical_p=typical_p,
                temp=temp,
                repeat_penalty=repeat_penalty,
                frequency_penalty=frequency_penalty,
                presence_penalty=presence_penalty,
                tfs_z=tfs_z,
                mirostat_mode=mirostat_mode,
                mirostat_tau=mirostat_tau,
                mirostat_eta=mirostat_eta,
                logits_processor=logits_processor,
                grammar=grammar,
                penalize_nl=penalize_nl,
                idx=sample_idx,
            )

            sample_idx += 1
            # `token` was sampled from the logits at position sample_idx - 1.
            # Index that row relative to the end of _scores (-1 when no draft
            # tokens are pending), not row sample_idx - n_tokens, which is 0.
            if stopping_criteria is not None and stopping_criteria(
                self._input_ids[:sample_idx],
                self._scores[sample_idx - 1 - self.n_tokens, :],
            ):
                return
            tokens_or_none = yield token
            tokens.clear()
            tokens.append(token)
            if tokens_or_none is not None:
                tokens.extend(tokens_or_none)

            # A draft token disagreed with the sampled one: discard the
            # remaining speculative positions from both state and KV cache.
            if sample_idx < self.n_tokens and token != self._input_ids[sample_idx]:
                self.n_tokens = sample_idx
                self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1)
                break

        if self.draft_model is not None:
            self.input_ids[self.n_tokens : self.n_tokens + len(tokens)] = tokens
            draft_tokens = self.draft_model(
                self.input_ids[: self.n_tokens + len(tokens)]
            )
            # Never propose more draft tokens than fit in the context.
            tokens.extend(
                draft_tokens.astype(int)[
                    : self._n_ctx - self.n_tokens - len(tokens)
                ]
            )
def create_embedding(
    self, input: Union[str, List[str]], model: Optional[str] = None
) -> CreateEmbeddingResponse:
    """Embed one string or a list of strings, OpenAI-response style.

    Args:
        input: A single string or a list of strings to embed.
        model: Name reported in the response; defaults to ``self.model_path``.

    Returns:
        A dict with ``object="list"``, one ``data`` entry per input (in input
        order), the model name, and token usage as reported by ``embed``.
    """
    model_name: str = self.model_path if model is None else model
    batch_inputs = input if isinstance(input, list) else [input]

    embeds, total_tokens = self.embed(batch_inputs, return_count=True)  # type: ignore

    entries = [
        {"object": "embedding", "embedding": vector, "index": position}
        for position, vector in enumerate(embeds)
    ]
    usage = {"prompt_tokens": total_tokens, "total_tokens": total_tokens}
    return {
        "object": "list",
        "data": entries,
        "model": model_name,
        "usage": usage,
    }
def embed(