-
Notifications
You must be signed in to change notification settings - Fork 1k
/
Copy pathllama.py
1679 lines (1508 loc) · 61.5 KB
/
llama.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import os
import sys
import uuid
import time
import math
import multiprocessing
from abc import ABC, abstractmethod
from typing import (
List,
Optional,
Union,
Generator,
Sequence,
Iterator,
Deque,
Tuple,
Callable,
)
from collections import deque, OrderedDict
import diskcache
import ctypes
from . import llama_cpp
from .llama_types import *
import numpy as np
import numpy.typing as npt
class BaseLlamaCache(ABC):
    """Abstract interface for caches mapping token sequences to saved model states.

    Concrete subclasses store ``LlamaState`` snapshots keyed by token
    sequences and use ``capacity_bytes`` as their size budget.
    """

    def __init__(self, capacity_bytes: int = (2 << 30)):
        # Default budget: 2 << 30 bytes (2 GiB).
        self.capacity_bytes = capacity_bytes

    @property
    @abstractmethod
    def cache_size(self) -> int:
        """Current size of the cache in bytes, as measured by the subclass."""
        raise NotImplementedError

    def _find_longest_prefix_key(
        self,
        key: Tuple[int, ...],
    ) -> Optional[Tuple[int, ...]]:
        """Return the stored key sharing the longest prefix with ``key``.

        The base class has no storage, so it always answers ``None``;
        concrete caches override this.
        """
        return None

    @abstractmethod
    def __getitem__(self, key: Sequence[int]) -> "LlamaState":
        """Return the cached state best matching ``key`` (raise KeyError if none)."""
        raise NotImplementedError

    @abstractmethod
    def __contains__(self, key: Sequence[int]) -> bool:
        """Report whether some cached state matches ``key``."""
        raise NotImplementedError

    @abstractmethod
    def __setitem__(self, key: Sequence[int], value: "LlamaState") -> None:
        """Store ``value`` under ``key``."""
        raise NotImplementedError
class LlamaRAMCache(BaseLlamaCache):
    """In-memory cache of ``LlamaState`` snapshots with least-recently-used eviction.

    Entries live in an ``OrderedDict`` whose order doubles as recency order:
    the front holds the least recently used entry, the back the most recent.
    """

    def __init__(self, capacity_bytes: int = (2 << 30)):
        super().__init__(capacity_bytes)
        self.cache_state: OrderedDict[Tuple[int, ...], "LlamaState"] = OrderedDict()

    @property
    def cache_size(self):
        """Sum of ``llama_state_size`` over all cached states.

        Only the serialized llama state size is counted; the ``input_ids`` and
        ``scores`` arrays carried by each ``LlamaState`` are not included.
        """
        total = 0
        for state in self.cache_state.values():
            total += state.llama_state_size
        return total

    def _find_longest_prefix_key(
        self,
        key: Tuple[int, ...],
    ) -> Optional[Tuple[int, ...]]:
        """Return the stored key with the longest non-empty common prefix with ``key``.

        Ties keep the earliest key in iteration order; ``None`` when no stored
        key shares at least one leading token.
        """
        best_key = None
        best_len = 0
        for candidate in self.cache_state:
            shared = Llama.longest_token_prefix(candidate, key)
            if shared > best_len:
                best_len, best_key = shared, candidate
        return best_key

    def __getitem__(self, key: Sequence[int]) -> "LlamaState":
        lookup = tuple(key)
        match = self._find_longest_prefix_key(lookup)
        if match is None:
            raise KeyError("Key not found")
        # A hit becomes the most recently used entry.
        self.cache_state.move_to_end(match)
        return self.cache_state[match]

    def __contains__(self, key: Sequence[int]) -> bool:
        return self._find_longest_prefix_key(tuple(key)) is not None

    def __setitem__(self, key: Sequence[int], value: "LlamaState"):
        lookup = tuple(key)
        # Drop any previous entry so the new one lands at the most-recent end.
        self.cache_state.pop(lookup, None)
        self.cache_state[lookup] = value
        # Evict least recently used entries until back within budget.
        while self.cache_state and self.cache_size > self.capacity_bytes:
            self.cache_state.popitem(last=False)


# Alias for backwards compatibility
LlamaCache = LlamaRAMCache
class LlamaDiskCache(BaseLlamaCache):
    """Cache of ``LlamaState`` snapshots persisted on disk through ``diskcache``."""

    def __init__(
        self, cache_dir: str = ".cache/llama_cache", capacity_bytes: int = (2 << 30)
    ):
        super().__init__(capacity_bytes)
        self.cache = diskcache.Cache(cache_dir)

    @property
    def cache_size(self):
        """On-disk size of the cache in bytes, as reported by ``Cache.volume()``."""
        return int(self.cache.volume())  # type: ignore

    def _find_longest_prefix_key(
        self,
        key: Tuple[int, ...],
    ) -> Optional[Tuple[int, ...]]:
        """Return the stored key with the longest non-empty common prefix with ``key``."""
        best_len = 0
        best_key: Optional[Tuple[int, ...]] = None
        for stored in self.cache.iterkeys():  # type: ignore
            shared = Llama.longest_token_prefix(stored, key)
            if shared > best_len:
                best_len = shared
                best_key = stored  # type: ignore
        return best_key

    def __getitem__(self, key: Sequence[int]) -> "LlamaState":
        lookup = tuple(key)
        match = self._find_longest_prefix_key(lookup)
        if match is None:
            raise KeyError("Key not found")
        # A hit REMOVES the entry from disk (pop, not get). Re-inserting it at
        # the front with ``push`` was tried, but that stores an integer key,
        # which breaks Llama.longest_token_prefix in the prefix scan above.
        return self.cache.pop(match)  # type: ignore

    def __contains__(self, key: Sequence[int]) -> bool:
        return self._find_longest_prefix_key(tuple(key)) is not None

    def __setitem__(self, key: Sequence[int], value: "LlamaState"):
        print("LlamaDiskCache.__setitem__: called", file=sys.stderr)
        lookup = tuple(key)
        if lookup in self.cache:
            print("LlamaDiskCache.__setitem__: delete", file=sys.stderr)
            del self.cache[lookup]
        self.cache[lookup] = value
        print("LlamaDiskCache.__setitem__: set", file=sys.stderr)
        # Evict entries in the cache's own iteration order until within budget.
        while self.cache_size > self.capacity_bytes and len(self.cache) > 0:
            del self.cache[next(iter(self.cache))]
        print("LlamaDiskCache.__setitem__: trim", file=sys.stderr)
class LlamaState:
    """Snapshot of a model's evaluation state, as stored by the llama caches.

    Bundles the token-id and score arrays with the opaque llama.cpp state
    bytes and their size, which the RAM cache uses for capacity accounting.
    """

    def __init__(
        self,
        input_ids: npt.NDArray[np.intc],
        scores: npt.NDArray[np.single],
        n_tokens: int,
        llama_state: bytes,
        llama_state_size: int,
    ):
        # Token bookkeeping.
        self.input_ids = input_ids
        self.scores = scores
        self.n_tokens = n_tokens
        # Raw llama.cpp state and the size counted against cache capacity.
        self.llama_state = llama_state
        self.llama_state_size = llama_state_size
# A logits processor receives the token ids so far plus the current scores and
# returns the (possibly modified) scores.
LogitsProcessor = Callable[[List[int], List[float]], List[float]]


class LogitsProcessorList(List[LogitsProcessor]):
    """Ordered pipeline of logits processors."""

    def __call__(self, input_ids: List[int], scores: List[float]) -> List[float]:
        """Pass ``scores`` through every processor in order and return the result."""
        current = scores
        for step in self:
            current = step(input_ids, current)
        return current
# A stopping criterion inspects token ids and logits; a truthy result requests a stop.
StoppingCriteria = Callable[[List[int], List[float]], bool]


class StoppingCriteriaList(List[StoppingCriteria]):
    """Collection of stopping criteria; stop when any one of them fires."""

    def __call__(self, input_ids: List[int], logits: List[float]) -> bool:
        """Return True if at least one criterion returns a truthy value.

        Every criterion is called, even after an earlier one has fired.
        """
        verdicts = [criterion(input_ids, logits) for criterion in self]
        return any(verdicts)
class Llama:
"""High-level Python wrapper for a llama.cpp model."""
def __init__(
    self,
    model_path: str,
    # NOTE: These parameters are likely to change in the future.
    n_ctx: int = 512,
    n_parts: int = -1,
    n_gpu_layers: int = 0,
    seed: int = 1337,
    f16_kv: bool = True,
    logits_all: bool = False,
    vocab_only: bool = False,
    use_mmap: bool = True,
    use_mlock: bool = False,
    embedding: bool = False,
    n_threads: Optional[int] = None,
    n_batch: int = 512,
    last_n_tokens_size: int = 64,
    lora_base: Optional[str] = None,
    lora_path: Optional[str] = None,
    low_vram: bool = False,
    tensor_split: Optional[List[float]] = None,
    rope_freq_base: float = 10000.0,
    rope_freq_scale: float = 1.0,
    n_gqa: Optional[int] = None,  # (TEMPORARY) must be 8 for llama2 70b
    rms_norm_eps: Optional[float] = None,  # (TEMPORARY)
    verbose: bool = True,
):
    """Load a llama.cpp model from `model_path`.

    Args:
        model_path: Path to the model file; it must exist on disk.
        n_ctx: Maximum context size. Also sizes the ``input_ids`` and ``scores`` buffers.
        n_parts: Deprecated. Only stored as ``self.n_parts``; this constructor does not pass it to llama.cpp.
        n_gpu_layers: Passed through to ``params.n_gpu_layers`` (presumably the number of layers offloaded to the GPU).
        seed: Random seed. -1 for random.
        f16_kv: Use half-precision for key/value cache.
        logits_all: Return logits for all tokens, not just the last token.
        vocab_only: Only load the vocabulary no weights.
        use_mmap: Use mmap if possible. Forced to False whenever ``lora_path`` is given.
        use_mlock: Force the system to keep the model in RAM.
        embedding: Embedding mode only.
        n_threads: Number of threads to use. If None (or 0), defaults to half of ``multiprocessing.cpu_count()``, at least 1.
        n_batch: Maximum number of prompt tokens to batch together when calling llama_eval. Clamped to ``n_ctx``.
        last_n_tokens_size: Number of most recent tokens handed to the repetition and frequency/presence penalties when sampling.
        lora_base: Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.
        lora_path: Path to a LoRA file to apply to the model.
        low_vram: Passed through to ``params.low_vram``.
        tensor_split: List of floats to split the model across multiple GPUs. If None, the model is not split.
        rope_freq_base: Passed through to ``params.rope_freq_base``.
        rope_freq_scale: Passed through to ``params.rope_freq_scale``.
        n_gqa: (TEMPORARY) If not None, passed to ``params.n_gqa``; must be 8 for llama2 70b.
        rms_norm_eps: (TEMPORARY) If not None, passed to ``params.rms_norm_eps``.
        verbose: Print verbose output to stderr.

    Raises:
        ValueError: If the model path does not exist.
        RuntimeError: If applying the LoRA adapter reports failure.
        AssertionError: If model loading or context creation returns None
            (checked with ``assert``, so skipped under ``python -O``).
    """
    self.verbose = verbose

    self.model_path = model_path

    self.params = llama_cpp.llama_context_default_params()
    self.params.n_ctx = n_ctx
    self.params.n_gpu_layers = n_gpu_layers
    self.params.seed = seed
    self.params.f16_kv = f16_kv
    self.params.logits_all = logits_all
    self.params.vocab_only = vocab_only
    # mmap is turned off whenever a LoRA adapter will be applied.
    self.params.use_mmap = use_mmap if lora_path is None else False
    self.params.use_mlock = use_mlock
    self.params.embedding = embedding

    self.params.low_vram = low_vram

    self.tensor_split = tensor_split
    self._p_tensor_split = None

    if self.tensor_split is not None:
        # Copy the Python list into a C float array and point params.tensor_split at it.
        FloatArray = (ctypes.c_float * len(self.tensor_split))(*self.tensor_split)
        self._p_tensor_split = ctypes.POINTER(ctypes.c_float)(FloatArray) # keep a reference to the array so it is not gc'd
        self.params.tensor_split = self._p_tensor_split

    self.params.rope_freq_base = rope_freq_base
    self.params.rope_freq_scale = rope_freq_scale

    if n_gqa is not None:
        self.params.n_gqa = n_gqa

    if rms_norm_eps is not None:
        self.params.rms_norm_eps = rms_norm_eps

    self.last_n_tokens_size = last_n_tokens_size
    # A batch can never exceed the context size.
    self.n_batch = min(n_ctx, n_batch)

    self.cache: Optional[BaseLlamaCache] = None

    self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)

    self.lora_base = lora_base
    self.lora_path = lora_path

    ### DEPRECATED ###
    self.n_parts = n_parts
    ### DEPRECATED ###

    if not os.path.exists(model_path):
        raise ValueError(f"Model path does not exist: {model_path}")

    self.model = llama_cpp.llama_load_model_from_file(
        self.model_path.encode("utf-8"), self.params
    )
    assert self.model is not None

    self.ctx = llama_cpp.llama_new_context_with_model(self.model, self.params)

    assert self.ctx is not None

    if self.lora_path:
        # A truthy (non-zero) return value is treated as failure.
        if llama_cpp.llama_model_apply_lora_from_file(
            self.model,
            llama_cpp.c_char_p(self.lora_path.encode("utf-8")),
            llama_cpp.c_char_p(self.lora_base.encode("utf-8"))
            if self.lora_base is not None
            else llama_cpp.c_char_p(0),
            llama_cpp.c_int(self.n_threads),
        ):
            raise RuntimeError(
                f"Failed to apply LoRA from lora path: {self.lora_path} to base path: {self.lora_base}"
            )

    if self.verbose:
        print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)

    self._n_vocab = self.n_vocab()
    self._n_ctx = self.n_ctx()
    size = llama_cpp.c_size_t(self._n_vocab)
    # NOTE: this local shadows the builtin ``sorted`` for the rest of __init__.
    sorted = llama_cpp.c_bool(False)
    # Reusable structured (id, logit, p) buffer, C-aligned; it is handed to
    # llama.cpp as llama_token_data_p, so presumably mirrors llama_token_data.
    self._candidates_data = np.array(
        [],
        dtype=np.dtype(
            [("id", np.intc), ("logit", np.single), ("p", np.single)], align=True
        ),
    )
    # NOTE(review): resize(3, n_vocab) yields shape (3, n_vocab), while the C
    # array below is given size n_vocab, so only the first row appears to be
    # seen by llama.cpp — confirm the extra rows are intended.
    self._candidates_data.resize(3, self._n_vocab, refcheck=False)
    candidates = llama_cpp.llama_token_data_array(
        data=self._candidates_data.ctypes.data_as(llama_cpp.llama_token_data_p),
        size=size,
        sorted=sorted,
    )
    self._candidates = candidates
    self._token_nl = Llama.token_nl()
    self._token_eos = Llama.token_eos()
    # Fill values copied into the candidates buffer on every _sample call.
    self._candidates_data_id = np.arange(self._n_vocab, dtype=np.intc)  # type: ignore
    self._candidates_data_p = np.zeros(self._n_vocab, dtype=np.single)

    # Evaluation state: only the first n_tokens entries of input_ids/scores are
    # valid; np.ndarray leaves the buffers uninitialized until eval writes them.
    self.n_tokens = 0
    self.input_ids: npt.NDArray[np.intc] = np.ndarray((n_ctx,), dtype=np.intc)
    self.scores: npt.NDArray[np.single] = np.ndarray(
        (n_ctx, self._n_vocab), dtype=np.single
    )
@property
def _input_ids(self) -> npt.NDArray[np.intc]:
    """View of ``input_ids`` limited to the ``n_tokens`` tokens evaluated so far."""
    count = self.n_tokens
    return self.input_ids[:count]
@property
def _scores(self) -> npt.NDArray[np.single]:
    """View of the ``scores`` rows belonging to the ``n_tokens`` evaluated tokens."""
    count = self.n_tokens
    return self.scores[0:count]
@property
def eval_tokens(self) -> Deque[int]:
    """Evaluated token ids as a deque bounded by the context size."""
    evaluated = self.input_ids[: self.n_tokens].tolist()
    return deque(evaluated, maxlen=self._n_ctx)
@property
def eval_logits(self) -> Deque[List[float]]:
    """Stored logit rows of the evaluated tokens as a deque.

    Holds up to ``n_ctx`` rows when ``logits_all`` is enabled; otherwise only
    the most recent row is retained.
    """
    limit = self._n_ctx if self.params.logits_all else 1
    rows = self.scores[: self.n_tokens, :].tolist()
    return deque(rows, maxlen=limit)
def tokenize(self, text: bytes, add_bos: bool = True) -> List[int]:
    """Convert UTF-8 encoded ``text`` into llama token ids.

    Args:
        text: The utf-8 encoded bytes to tokenize.
        add_bos: Forwarded to ``llama_tokenize`` (presumably controls prepending
            the BOS token).

    Raises:
        RuntimeError: If tokenization still fails after resizing the buffer.

    Returns:
        The token ids as a list of ints.
    """
    assert self.ctx is not None
    capacity = self._n_ctx
    buffer = (llama_cpp.llama_token * capacity)()
    count = llama_cpp.llama_tokenize(
        self.ctx,
        text,
        buffer,
        llama_cpp.c_int(capacity),
        llama_cpp.c_bool(add_bos),
    )
    if count < 0:
        # A negative result is treated as the (negated) number of tokens
        # required; retry once with a buffer of exactly that size.
        capacity = -count
        buffer = (llama_cpp.llama_token * capacity)()
        count = llama_cpp.llama_tokenize(
            self.ctx,
            text,
            buffer,
            llama_cpp.c_int(capacity),
            llama_cpp.c_bool(add_bos),
        )
        if count < 0:
            raise RuntimeError(
                f'Failed to tokenize: text="{text}" n_tokens={count}'
            )
    return list(buffer[:count])
def detokenize(self, tokens: List[int]) -> bytes:
    """Convert token ids back into UTF-8 encoded bytes.

    Args:
        tokens: The token ids to detokenize.

    Returns:
        The concatenated byte pieces of all tokens. This is ``bytes``, not
        ``str``; callers decode as needed.
    """
    assert self.ctx is not None
    # Collect the pieces and join once: ``bytes +=`` in a loop copies the whole
    # accumulated buffer on every iteration, which is quadratic in output size.
    return b"".join(
        llama_cpp.llama_token_to_str(self.ctx, llama_cpp.llama_token(token))
        for token in tokens
    )
def set_cache(self, cache: Optional[BaseLlamaCache]):
    """Attach a state cache, or detach it by passing ``None``.

    Args:
        cache: The cache used to look up and restore previously saved states.
    """
    self.cache = cache
def reset(self):
    """Forget all evaluated tokens by setting ``n_tokens`` back to zero.

    The ``input_ids`` and ``scores`` buffers are not cleared; their contents
    are simply no longer treated as valid.
    """
    self.n_tokens = 0
def eval(self, tokens: Sequence[int]):
    """Evaluate a list of tokens.

    Tokens are sent to ``llama_eval`` in chunks of at most ``n_batch``. After
    each chunk its token ids are appended to ``input_ids``, the returned logits
    are copied into ``scores``, and ``n_tokens`` is advanced.

    Args:
        tokens: The list of tokens to evaluate.

    Raises:
        RuntimeError: If ``llama_eval`` returns a non-zero code.
    """
    assert self.ctx is not None
    n_ctx = self._n_ctx
    for i in range(0, len(tokens), self.n_batch):
        batch = tokens[i : min(len(tokens), i + self.n_batch)]
        # Already-evaluated tokens to condition on, capped so that
        # n_past + len(batch) does not exceed n_ctx.
        n_past = min(n_ctx - len(batch), len(self._input_ids))
        n_tokens = len(batch)
        return_code = llama_cpp.llama_eval(
            ctx=self.ctx,
            tokens=(llama_cpp.llama_token * len(batch))(*batch),
            n_tokens=llama_cpp.c_int(n_tokens),
            n_past=llama_cpp.c_int(n_past),
            n_threads=llama_cpp.c_int(self.n_threads),
        )
        if return_code != 0:
            raise RuntimeError(f"llama_eval returned {return_code}")
        # Save tokens
        # NOTE(review): this write starts at self.n_tokens regardless of the
        # n_past cap above; input_ids has length n_ctx, so if
        # self.n_tokens + n_tokens > n_ctx the assignment raises a numpy shape
        # error — confirm callers always stay within n_ctx.
        self.input_ids[self.n_tokens : self.n_tokens + n_tokens] = batch
        # Save logits
        # With logits_all, llama_get_logits is read as rows * cols floats (one
        # row per batch token); otherwise only the last token's row is copied
        # and earlier rows of this batch in self.scores are left unwritten.
        rows = n_tokens if self.params.logits_all else 1
        cols = self._n_vocab
        offset = (
            0 if self.params.logits_all else n_tokens - 1
        )  # NOTE: Only save the last token logits if logits_all is False
        # The row slice is contiguous, so reshape(-1) is a view and the
        # assignment writes straight into self.scores.
        self.scores[self.n_tokens + offset : self.n_tokens + n_tokens, :].reshape(
            -1
        )[:] = llama_cpp.llama_get_logits(self.ctx)[: rows * cols]
        # Update n_tokens
        self.n_tokens += n_tokens
def _sample(
    self,
    last_n_tokens_data,  # type: llama_cpp.Array[llama_cpp.llama_token]
    last_n_tokens_size: llama_cpp.c_int,
    top_k: llama_cpp.c_int,
    top_p: llama_cpp.c_float,
    temp: llama_cpp.c_float,
    tfs_z: llama_cpp.c_float,
    repeat_penalty: llama_cpp.c_float,
    frequency_penalty: llama_cpp.c_float,
    presence_penalty: llama_cpp.c_float,
    mirostat_mode: llama_cpp.c_int,
    mirostat_tau: llama_cpp.c_float,
    mirostat_eta: llama_cpp.c_float,
    penalize_nl: bool = True,
    logits_processor: Optional[LogitsProcessorList] = None,
):
    """Sample one token from the logits of the most recently evaluated token.

    Pipeline: optional ``logits_processor`` -> repetition penalty ->
    frequency/presence penalties -> then exactly one of:

    * greedy, when ``temp == 0`` (takes precedence over mirostat);
    * mirostat, when ``mirostat_mode == 1``;
    * mirostat v2, when ``mirostat_mode == 2``;
    * otherwise top-k -> tail-free -> typical (p=1.0) -> top-p ->
      temperature -> sample.

    Numeric arguments are ctypes values (``sample`` prepares them).

    Args:
        last_n_tokens_data: ctypes array of recent token ids fed to the penalties.
        last_n_tokens_size: Length of that window; a negative value is replaced by n_ctx.
        top_k: Top-k cutoff; values <= 0 are replaced by n_vocab.
        top_p: Top-p threshold.
        temp: Temperature; 0.0 selects greedy sampling.
        tfs_z: Tail-free sampling parameter.
        repeat_penalty: Repetition penalty.
        frequency_penalty: Frequency penalty.
        presence_penalty: Presence penalty.
        mirostat_mode: 0 (off), 1 (mirostat) or 2 (mirostat v2).
        mirostat_tau: Mirostat tau; mu is initialised to 2 * tau.
        mirostat_eta: Mirostat eta.
        penalize_nl: If False, the newline token's logit is restored to its
            value from before the penalties.
        logits_processor: Optional processor applied to the last score row
            first; its output is also written back into ``scores``.

    Returns:
        The token id returned by the selected llama.cpp sampling function.
    """
    assert self.ctx is not None
    assert self.n_tokens > 0
    n_vocab = self._n_vocab
    n_ctx = self._n_ctx
    # top_k <= 0 means no top-k cutoff: keep the whole vocabulary.
    top_k = llama_cpp.c_int(n_vocab) if top_k.value <= 0 else top_k
    # A negative penalty window is replaced by the full context size.
    last_n_tokens_size = (
        llama_cpp.c_int(n_ctx)
        if last_n_tokens_size.value < 0
        else last_n_tokens_size
    )
    # Logits of the most recently evaluated token (a view into self.scores).
    logits: npt.NDArray[np.single] = self._scores[-1, :]

    if logits_processor is not None:
        logits = np.array(
            logits_processor(self._input_ids.tolist(), logits.tolist()),
            dtype=np.single,
        )
        # Persist the processed logits back into self.scores.
        self._scores[-1, :] = logits

    # Remembered before the penalties so it can be restored when penalize_nl is False.
    nl_logit = logits[self._token_nl]
    candidates = self._candidates
    candidates_data = self._candidates_data
    # Refill the reusable candidates buffer: ids 0..n_vocab-1, current logits, zeroed p.
    candidates_data["id"][:] = self._candidates_data_id  # type: ignore
    candidates_data["logit"][:] = logits
    candidates_data["p"][:] = self._candidates_data_p  # type: ignore
    candidates.data = candidates_data.ctypes.data_as(llama_cpp.llama_token_data_p)
    candidates.sorted = llama_cpp.c_bool(False)
    candidates.size = llama_cpp.c_size_t(n_vocab)
    llama_cpp.llama_sample_repetition_penalty(
        ctx=self.ctx,
        last_tokens_data=last_n_tokens_data,
        last_tokens_size=last_n_tokens_size,
        candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
        penalty=repeat_penalty,
    )
    llama_cpp.llama_sample_frequency_and_presence_penalties(
        ctx=self.ctx,
        candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
        last_tokens_data=last_n_tokens_data,
        last_tokens_size=last_n_tokens_size,
        alpha_frequency=frequency_penalty,
        alpha_presence=presence_penalty,
    )
    if not penalize_nl:
        # NOTE(review): indexing by token id assumes the penalty calls above
        # keep the candidates in id order — confirm.
        candidates.data[self._token_nl].logit = llama_cpp.c_float(nl_logit)
    if temp.value == 0.0:
        # Greedy decoding.
        return llama_cpp.llama_sample_token_greedy(
            ctx=self.ctx,
            candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
        )
    elif mirostat_mode.value == 1:
        # NOTE(review): mu is re-created as 2 * tau on every call, so any
        # update the sampler makes through this reference is discarded before
        # the next token — confirm whether that is intended.
        mirostat_mu = llama_cpp.c_float(2.0 * mirostat_tau.value)
        mirostat_m = llama_cpp.c_int(100)
        llama_cpp.llama_sample_temperature(
            ctx=self.ctx,
            candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
            temp=temp,
        )
        return llama_cpp.llama_sample_token_mirostat(
            ctx=self.ctx,
            candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
            tau=mirostat_tau,
            eta=mirostat_eta,
            mu=llama_cpp.ctypes.byref(mirostat_mu),  # type: ignore
            m=mirostat_m,
        )
    elif mirostat_mode.value == 2:
        # Same per-call mu re-initialisation as mode 1 (see note above).
        mirostat_mu = llama_cpp.c_float(2.0 * mirostat_tau.value)
        llama_cpp.llama_sample_temperature(
            ctx=self.ctx,
            candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
            temp=temp,
        )
        return llama_cpp.llama_sample_token_mirostat_v2(
            ctx=self.ctx,
            candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
            tau=mirostat_tau,
            eta=mirostat_eta,
            mu=llama_cpp.ctypes.byref(mirostat_mu),  # type: ignore
        )
    else:
        # Standard chain: top-k -> tail-free -> typical (p=1.0) -> top-p ->
        # temperature -> sample, each keeping at least one candidate.
        llama_cpp.llama_sample_top_k(
            ctx=self.ctx,
            candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
            k=top_k,
            min_keep=llama_cpp.c_size_t(1),
        )
        llama_cpp.llama_sample_tail_free(
            ctx=self.ctx,
            candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
            z=tfs_z,
            min_keep=llama_cpp.c_size_t(1),
        )
        llama_cpp.llama_sample_typical(
            ctx=self.ctx,
            candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
            p=llama_cpp.c_float(1.0),
            min_keep=llama_cpp.c_size_t(1),
        )
        llama_cpp.llama_sample_top_p(
            ctx=self.ctx,
            candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
            p=top_p,
            min_keep=llama_cpp.c_size_t(1),
        )
        llama_cpp.llama_sample_temperature(
            ctx=self.ctx,
            candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
            temp=temp,
        )
        return llama_cpp.llama_sample_token(
            ctx=self.ctx,
            candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
        )
def sample(
    self,
    top_k: int = 40,
    top_p: float = 0.95,
    temp: float = 0.80,
    repeat_penalty: float = 1.1,
    frequency_penalty: float = 0.0,
    presence_penalty: float = 0.0,
    tfs_z: float = 1.0,
    mirostat_mode: int = 0,
    mirostat_eta: float = 0.1,
    mirostat_tau: float = 5.0,
    penalize_nl: bool = True,
    logits_processor: Optional[LogitsProcessorList] = None,
):
    """Sample a token from the logits of the last evaluated token.

    The last ``self.last_n_tokens_size`` evaluated tokens are passed to the
    repetition and frequency/presence penalties. When fewer have been
    evaluated, the window is left-padded with token 0.

    Args:
        top_k: The top-k sampling parameter; <= 0 disables the cutoff.
        top_p: The top-p sampling parameter.
        temp: The temperature parameter; 0.0 selects greedy sampling.
        repeat_penalty: The repeat penalty parameter.
        frequency_penalty: Frequency penalty over the recent-token window.
        presence_penalty: Presence penalty over the recent-token window.
        tfs_z: Tail-free sampling parameter.
        mirostat_mode: 0 for the top-k/tail-free/top-p chain, 1 for mirostat,
            2 for mirostat v2.
        mirostat_eta: Mirostat eta parameter.
        mirostat_tau: Mirostat tau parameter.
        penalize_nl: If False, the newline token's logit is restored after the
            penalties are applied.
        logits_processor: Optional processor applied to the logits first.

    Returns:
        The sampled token.
    """
    assert self.ctx is not None
    window = self.last_n_tokens_size
    evaluated = self._input_ids
    # Take the last `window` tokens. Slicing from max(0, len - window) rather
    # than using [-window:] keeps window == 0 from selecting the entire history
    # (-0 == 0), which would overflow the zero-length ctypes array below.
    recent = evaluated[max(0, len(evaluated) - window) :].tolist()
    last_n_tokens_data = [llama_cpp.llama_token(0)] * max(
        0, window - len(recent)
    ) + recent
    return self._sample(
        last_n_tokens_data=(llama_cpp.llama_token * window)(*last_n_tokens_data),
        last_n_tokens_size=llama_cpp.c_int(window),
        top_k=llama_cpp.c_int(top_k),
        top_p=llama_cpp.c_float(top_p),
        temp=llama_cpp.c_float(temp),
        tfs_z=llama_cpp.c_float(tfs_z),
        repeat_penalty=llama_cpp.c_float(repeat_penalty),
        frequency_penalty=llama_cpp.c_float(frequency_penalty),
        presence_penalty=llama_cpp.c_float(presence_penalty),
        mirostat_mode=llama_cpp.c_int(mirostat_mode),
        mirostat_tau=llama_cpp.c_float(mirostat_tau),
        mirostat_eta=llama_cpp.c_float(mirostat_eta),
        penalize_nl=penalize_nl,
        logits_processor=logits_processor,
    )
def generate(
    self,
    tokens: Sequence[int],
    top_k: int = 40,
    top_p: float = 0.95,
    temp: float = 0.80,
    repeat_penalty: float = 1.1,
    reset: bool = True,
    frequency_penalty: float = 0.0,
    presence_penalty: float = 0.0,
    tfs_z: float = 1.0,
    mirostat_mode: int = 0,
    mirostat_tau: float = 5.0,
    mirostat_eta: float = 0.1,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
) -> Generator[int, Optional[Sequence[int]], None]:
    """Yield sampled tokens one at a time, starting from a prompt.

    Examples:
        >>> llama = Llama("models/ggml-7b.bin")
        >>> tokens = llama.tokenize(b"Hello, world!")
        >>> for token in llama.generate(tokens, top_k=40, top_p=0.95, temp=1.0, repeat_penalty=1.1):
        ...     print(llama.detokenize([token]))

    Args:
        tokens: The prompt tokens.
        top_k: The top-k sampling parameter.
        top_p: The top-p sampling parameter.
        temp: The temperature parameter.
        repeat_penalty: The repeat penalty parameter.
        reset: Whether to reset the model state. When leading prompt tokens
            match already-evaluated tokens, they are reused instead of reset.
        frequency_penalty, presence_penalty, tfs_z, mirostat_mode,
        mirostat_tau, mirostat_eta, logits_processor: Forwarded to ``sample``.
        stopping_criteria: Checked after each sample; when it returns True the
            generator finishes without yielding that token.

    Yields:
        The generated tokens. A sequence passed back with ``send()`` is
        evaluated together with the yielded token before the next sample.
    """
    assert self.ctx is not None

    if reset and len(self._input_ids) > 0:
        # Count leading prompt tokens already evaluated. The last prompt token
        # is excluded, so a non-empty prompt always leaves something to eval.
        reused = 0
        for seen, wanted in zip(self._input_ids, tokens[:-1]):
            if seen != wanted:
                break
            reused += 1
        if reused > 0:
            if self.verbose:
                print("Llama.generate: prefix-match hit", file=sys.stderr)
            reset = False
            tokens = tokens[reused:]
            self.n_tokens = reused

    if reset:
        self.reset()

    while True:
        self.eval(tokens)
        token = self.sample(
            top_k=top_k,
            top_p=top_p,
            temp=temp,
            repeat_penalty=repeat_penalty,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
            tfs_z=tfs_z,
            mirostat_mode=mirostat_mode,
            mirostat_tau=mirostat_tau,
            mirostat_eta=mirostat_eta,
            logits_processor=logits_processor,
        )
        if stopping_criteria is not None and stopping_criteria(
            self._input_ids.tolist(), self._scores[-1, :].tolist()
        ):
            return
        extra = yield token
        tokens = [token] if extra is None else [token, *extra]
def create_embedding(
    self, input: Union[str, List[str]], model: Optional[str] = None
) -> Embedding:
    """Embed one string or a list of strings.

    Args:
        input: A string, or a list of strings, to embed.
        model: Name reported in the response's ``model`` field; defaults to
            ``model_path``.

    Raises:
        RuntimeError: If the model was not created with ``embedding=True``.

    Returns:
        An embedding object with one ``data`` entry per input string (in input
        order), plus prompt/total token counts under ``usage``.
    """
    assert self.ctx is not None
    model_name: str = model if model is not None else self.model_path

    if not self.params.embedding:
        raise RuntimeError(
            "Llama model must be created with embedding=True to call this method"
        )

    if self.verbose:
        llama_cpp.llama_reset_timings(self.ctx)

    # Normalise to a list under a separate name so the loop below does not
    # rebind the ``input`` parameter.
    texts: List[str] = [input] if isinstance(input, str) else input

    data: List[EmbeddingData] = []
    total_tokens = 0
    n_embd = llama_cpp.llama_n_embd(self.ctx)
    for index, text in enumerate(texts):
        tokens = self.tokenize(text.encode("utf-8"))
        # Every input is evaluated from an empty context.
        self.reset()
        self.eval(tokens)
        total_tokens += len(tokens)
        embedding = llama_cpp.llama_get_embeddings(self.ctx)[:n_embd]
        data.append(
            {
                "object": "embedding",
                "embedding": embedding,
                "index": index,
            }
        )
    if self.verbose:
        llama_cpp.llama_print_timings(self.ctx)

    return {
        "object": "list",
        "data": data,
        "model": model_name,
        "usage": {
            "prompt_tokens": total_tokens,
            "total_tokens": total_tokens,
        },
    }
def embed(self, input: str) -> List[float]:
    """Return the embedding vector for ``input`` as plain Python floats.

    Args:
        input: The string to embed.

    Returns:
        The embedding from the first (and only) data entry of ``create_embedding``.
    """
    first_entry = self.create_embedding(input)["data"][0]
    return [float(value) for value in first_entry["embedding"]]
def _create_completion(
self,
prompt: str,
suffix: Optional[str] = None,
max_tokens: int = 16,
temperature: float = 0.8,
top_p: float = 0.95,
logprobs: Optional[int] = None,
echo: bool = False,
stop: Optional[Union[str, List[str]]] = [],
frequency_penalty: float = 0.0,
presence_penalty: float = 0.0,
repeat_penalty: float = 1.1,
top_k: int = 40,
stream: bool = False,
tfs_z: float = 1.0,
mirostat_mode: int = 0,
mirostat_tau: float = 5.0,
mirostat_eta: float = 0.1,
model: Optional[str] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
logits_processor: Optional[LogitsProcessorList] = None,
) -> Union[Iterator[Completion], Iterator[CompletionChunk]]:
assert self.ctx is not None
completion_id: str = f"cmpl-{str(uuid.uuid4())}"
created: int = int(time.time())
completion_tokens: List[int] = []
# Add blank space to start of prompt to match OG llama tokenizer
prompt_tokens: List[int] = self.tokenize(b" " + prompt.encode("utf-8"))
text: bytes = b""
returned_tokens: int = 0
stop = (
stop if isinstance(stop, list) else [stop] if isinstance(stop, str) else []
)
model_name: str = model if model is not None else self.model_path
if self.verbose:
llama_cpp.llama_reset_timings(self.ctx)
if len(prompt_tokens) >= llama_cpp.llama_n_ctx(self.ctx):
raise ValueError(
f"Requested tokens ({len(prompt_tokens)}) exceed context window of {llama_cpp.llama_n_ctx(self.ctx)}"
)
if max_tokens <= 0:
# Unlimited, depending on n_ctx.
max_tokens = llama_cpp.llama_n_ctx(self.ctx) - len(prompt_tokens)
# Truncate max_tokens if requested tokens would exceed the context window
max_tokens = (
max_tokens
if max_tokens + len(prompt_tokens) < self._n_ctx
else (self._n_ctx - len(prompt_tokens))
)
if stop != []:
stop_sequences = [s.encode("utf-8") for s in stop]
else:
stop_sequences = []
if logprobs is not None and self.params.logits_all is False:
raise ValueError(
"logprobs is not supported for models created with logits_all=False"
)
if self.cache:
try:
cache_item = self.cache[prompt_tokens]
cache_prefix_len = Llama.longest_token_prefix(
cache_item.input_ids.tolist(), prompt_tokens
)
eval_prefix_len = Llama.longest_token_prefix(
self._input_ids.tolist(), prompt_tokens
)
if cache_prefix_len > eval_prefix_len:
self.load_state(cache_item)
if self.verbose:
print("Llama._create_completion: cache hit", file=sys.stderr)
except KeyError:
if self.verbose:
print("Llama._create_completion: cache miss", file=sys.stderr)
finish_reason = "length"
multibyte_fix = 0
for token in self.generate(
prompt_tokens,
top_k=top_k,
top_p=top_p,
temp=temperature,
tfs_z=tfs_z,
mirostat_mode=mirostat_mode,
mirostat_tau=mirostat_tau,
mirostat_eta=mirostat_eta,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
repeat_penalty=repeat_penalty,
stopping_criteria=stopping_criteria,
logits_processor=logits_processor,
):
if token == self._token_eos:
text = self.detokenize(completion_tokens)
finish_reason = "stop"
break
completion_tokens.append(token)
all_text = self.detokenize(completion_tokens)
# Contains multi-byte UTF8
for k, char in enumerate(all_text[-3:]):
k = 3 - k
for num, pattern in [(2, 192), (3, 224), (4, 240)]:
# Bitwise AND check
if num > k and pattern & char == pattern:
multibyte_fix = num - k
# Stop incomplete bytes from passing
if multibyte_fix > 0:
multibyte_fix -= 1
continue
any_stop = [s for s in stop_sequences if s in all_text]
if len(any_stop) > 0:
first_stop = any_stop[0]
text = all_text[: all_text.index(first_stop)]
finish_reason = "stop"
break
if stream:
remaining_tokens = completion_tokens[returned_tokens:]
remaining_text = self.detokenize(remaining_tokens)
remaining_length = len(remaining_text)
# We want to avoid yielding any characters from
# the generated text if they are part of a stop
# sequence.
first_stop_position = 0
for s in stop_sequences:
for i in range(min(len(s), remaining_length), 0, -1):
if remaining_text.endswith(s[:i]):
if i > first_stop_position:
first_stop_position = i
break
token_end_position = 0
for token in remaining_tokens:
token_end_position += len(self.detokenize([token]))
# Check if stop sequence is in the token
if token_end_position >= (
remaining_length - first_stop_position
):
break
logprobs_or_none: Optional[CompletionLogprobs] = None
if logprobs is not None:
token_str = self.detokenize([token]).decode(
"utf-8", errors="ignore"
)
text_offset = len(prompt) + len(
self.detokenize(completion_tokens[:returned_tokens])
)
token_offset = len(prompt_tokens) + returned_tokens
logits = self._scores[token_offset - 1, :].tolist()
current_logprobs = Llama.logits_to_logprobs(logits)
sorted_logprobs = list(
sorted(
zip(current_logprobs, range(len(current_logprobs))),
reverse=True,
)
)
top_logprob = {
self.detokenize([i]).decode(
"utf-8", errors="ignore"
): logprob
for logprob, i in sorted_logprobs[:logprobs]
}
top_logprob.update({token_str: current_logprobs[int(token)]})
logprobs_or_none = {
"tokens": [
self.detokenize([token]).decode(
"utf-8", errors="ignore"
)
],
"text_offset": [text_offset],
"token_logprobs": [current_logprobs[int(token)]],
"top_logprobs": [top_logprob],
}
returned_tokens += 1