llama_cpp_model.py
from typing import Tuple
import sys

import numpy as np

import lmql.utils.nputil as nputil
from lmql.models.lmtp.backends.lmtp_model import (LMTPModel, LMTPModelResult,
                                                  TokenStreamer)

class LlamaCppModel(LMTPModel):
    def __init__(self, model_identifier, **kwargs):
        from llama_cpp import Llama

        self.model_identifier = model_identifier
        self.kwargs = kwargs

        # the llama.cpp backend processes one sequence at a time
        self.max_batch_size = 1

        print("[Loading llama.cpp model from", self.model_identifier, " with ", kwargs, "]", flush=True)

        if "verbose" not in kwargs:
            kwargs["verbose"] = False

        # strip the "llama.cpp:" scheme prefix to obtain the model path;
        # logits_all=True is needed so that score() can read per-token logits
        self.llm = Llama(model_path=model_identifier[len("llama.cpp:"):], logits_all=True, **kwargs)

    def model_info(self):
        import llama_cpp
        return {
            "model_identifier": self.model_identifier[len("llama.cpp:"):],
            "model_type": "llama.cpp",
            "constructor": "Llama(model_path='{}'{})".format(
                self.model_identifier[len("llama.cpp:"):],
                ", " + ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items()) if len(self.kwargs) > 0 else ""
            ),
            "llama-cpp-python": llama_cpp.__version__,
        }

    def eos_token_id(self):
        # LLaMA-family models use token id 2 as the </s> (end-of-sequence) token
        return 2

    def score(self, input_ids, attention_mask, **model_kwargs):
        tokens = input_ids[0]

        # single forward pass (use generate() in favor of eval() to handle the kv cache automatically)
        for _ in self.llm.generate(tokens, temp=0.0):
            break

        # per-token log-probabilities of the provided sequence; the first token has
        # no preceding context, so its score is fixed to 0.0
        logits = np.array(self.llm.scores[:self.llm.n_tokens])
        logits = nputil.log_softmax(logits, axis=-1)
        scores = np.array([0.0] + [logits[j][i] for j, i in enumerate(input_ids[0][1:])])

        return scores.reshape(1, -1)

    def generate(self, input_ids, attention_mask,
                 temperature: float, max_new_tokens: int,
                 bias_tensor, streamer: TokenStreamer, **kwargs) -> LMTPModelResult:
        token_scores = []
        sequence = []

        input_ids = input_ids.reshape(-1).tolist()

        # the stopping_criteria callback is used here only to capture the per-step
        # scores; it never requests stopping (always returns False)
        def llama_streamer(tokens, scores):
            nonlocal token_scores
            scores = np.array(scores)
            token_scores += [scores]
            return False

        logits_processor = self.logits_processors(bias_tensor) if bias_tensor is not None else None

        for i, token in zip(range(max_new_tokens), self.llm.generate(input_ids,
                                                                     temp=temperature,
                                                                     stopping_criteria=llama_streamer,
                                                                     logits_processor=logits_processor,
                                                                     **kwargs)):
            assert i + len(input_ids) < self.llm.n_ctx(), \
                f"The requested number of tokens exceeds the llama.cpp model's specified context size {self.llm.n_ctx()}. Please specify a higher n_ctx value or use a shorter prompt."

            sequence += [token]

            sq_ar = np.array(sequence)
            ts_ar = np.stack(token_scores, axis=0)

            if i + 1 >= max_new_tokens:
                break
            else:
                # stream intermediate results with a leading batch dimension of 1
                streamer(sq_ar.reshape(1, *sq_ar.shape), ts_ar.reshape(-1, 1, *ts_ar.shape[1:]))

        ts_ar = np.stack(token_scores, axis=0)
        sq_ar = np.array(sequence)

        return LMTPModelResult(
            sequences=sq_ar.reshape(1, *sq_ar.shape),
            scores=ts_ar.reshape(-1, 1, *ts_ar.shape[1:])
        )

    def logits_processors(self, logit_biases):
        bias_tensors = None
        make_bias_tensor = self.make_bias_tensor

        if len(logit_biases) == 0:
            return []

        class BatchLogitsProcessor:
            def __call__(self, input_ids, scores):
                nonlocal bias_tensors

                scores = np.array(scores)

                # lazily build the bias tensor once the vocabulary size is known
                if bias_tensors is None:
                    bias_tensors = np.array(make_bias_tensor(logit_biases, scores.shape[-1]))

                return nputil.log_softmax(scores + bias_tensors, axis=-1).reshape(-1)

        return BatchLogitsProcessor()

LMTPModel.registry["llama.cpp"] = LlamaCppModel
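
# Usage sketch (not part of the original module; paths and flags below are
# illustrative assumptions): via the registry entry above, model identifiers of
# the form "llama.cpp:<path-to-model-file>" resolve to this backend, e.g. when
# serving an LMTP endpoint or constructing a model handle in-process:
#
#   lmql serve-model "llama.cpp:/path/to/model.gguf" --n_ctx 2048
#
#   import lmql
#   m = lmql.model("llama.cpp:/path/to/model.gguf", tokenizer="huggyllama/llama-7b")
#
# Since llama.cpp model files do not ship a tokenizer, a matching Hugging Face
# tokenizer (as in the __main__ block below) is typically specified alongside
# the model.
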

if __name__ == "__main__":
    from llama_cpp import Llama
    from transformers import AutoTokenizer

    llm = Llama("/Users/luca/Developer/llama.cpp/models/7B/ggml-model-q4_0.bin")

    s = "Say this is a test:"
    tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
    ids = tokenizer(s)["input_ids"]
    print(ids)

    # simple smoke test: decode and print the growing sequence as tokens are generated
    for token in llm.generate(ids, 120, temp=0.0):
        ids += [token]
        print(tokenizer.decode(ids))