Add speculative decoding #1120

Merged
merged 25 commits into main from add-speculative-decoding on Jan 31, 2024
25 commits
cee8c0f
Add draft model param to llama class, implement basic prompt lookup d…
abetlen Jan 22, 2024
2ff7247
Merge branch 'main' into add-speculative-decoding
abetlen Jan 22, 2024
be688da
Use samplingcontext for sampling
abetlen Jan 22, 2024
8fe1c48
Use 1d array
abetlen Jan 23, 2024
92cf2c4
Use draft model for sampling
abetlen Jan 23, 2024
b4976da
Fix dumb mistake
abetlen Jan 23, 2024
fae83f2
Allow for later extensions to the LlamaDraftModel api
abetlen Jan 23, 2024
e4e029e
Merge branch 'main' into add-speculative-decoding
abetlen Jan 23, 2024
346a6c5
Cleanup
abetlen Jan 24, 2024
eae4286
Merge remote-tracking branch 'origin' into add-speculative-decoding
abetlen Jan 24, 2024
e2dccf2
Merge branch 'main' into add-speculative-decoding
abetlen Jan 24, 2024
9b46cb9
Merge branch 'main' of https://github.com/abetlen/llama-cpp-python in…
abetlen Jan 24, 2024
8415837
Adaptive candidate prediction
abetlen Jan 24, 2024
a9d1da2
Merge branch 'add-speculative-decoding' of github.com:abetlen/llama_c…
abetlen Jan 24, 2024
c363eee
Update implementation to match hf transformers
abetlen Jan 24, 2024
5ab5999
Tuning
abetlen Jan 24, 2024
6732261
Merge branch 'main' into add-speculative-decoding
abetlen Jan 26, 2024
f39690c
Fix bug where last token was not used for ngram prediction
abetlen Jan 26, 2024
c6013e2
Remove heuristic for num_pred_tokens (no benefit)
abetlen Jan 26, 2024
515483a
Merge branch 'main' into add-speculative-decoding
abetlen Jan 31, 2024
edc3390
Merge branch 'add-speculative-decoding' of https://github.com/abetlen…
abetlen Jan 31, 2024
4f946b0
fix: n_candidates bug.
abetlen Jan 31, 2024
df93d1d
Add draft_model_num_pred_tokens server setting
abetlen Jan 31, 2024
995d40c
Cleanup
abetlen Jan 31, 2024
291eadc
Update README
abetlen Jan 31, 2024
18 changes: 18 additions & 0 deletions README.md
@@ -378,6 +378,24 @@ Then you'll need to use a custom chat handler to load the clip model and process
)
```

### Speculative Decoding

`llama-cpp-python` supports speculative decoding, which lets the model generate completions with the help of a draft model.

The fastest way to use speculative decoding is through the `LlamaPromptLookupDecoding` class, which proposes draft tokens by matching n-grams from the prompt rather than running a second model.

Just pass it as the draft model to the `Llama` class during initialization.

```python
from llama_cpp import Llama
from llama_cpp.llama_speculative import LlamaPromptLookupDecoding

llama = Llama(
model_path="path/to/model.gguf",
draft_model=LlamaPromptLookupDecoding(num_pred_tokens=10)  # num_pred_tokens is the number of tokens to predict. 10 is the default and generally good for GPU; 2 performs better for CPU-only machines.
)
```

### Adjusting the Context Window

The context window of the Llama models determines the maximum number of tokens that can be processed at once. By default, this is set to 512 tokens, but can be adjusted based on your requirements.
181 changes: 91 additions & 90 deletions llama_cpp/llama.py
@@ -30,6 +30,8 @@
import llama_cpp.llama_cpp as llama_cpp
import llama_cpp.llama_chat_format as llama_chat_format

from llama_cpp.llama_speculative import LlamaDraftModel

import numpy as np
import numpy.typing as npt

@@ -39,6 +41,8 @@
_LlamaContext, # type: ignore
_LlamaBatch, # type: ignore
_LlamaTokenDataArray, # type: ignore
_LlamaSamplingParams, # type: ignore
_LlamaSamplingContext, # type: ignore
)


@@ -89,6 +93,8 @@ def __init__(
# Chat Format Params
chat_format: Optional[str] = None,
chat_handler: Optional[llama_chat_format.LlamaChatCompletionHandler] = None,
# Speculative Decoding
draft_model: Optional[LlamaDraftModel] = None,
# Misc
verbose: bool = True,
# Extra Params
@@ -152,6 +158,7 @@ def __init__(
numa: Enable NUMA support. (NOTE: The initial value of this parameter is used for the remainder of the program as this value is set in llama_backend_init)
chat_format: String specifying the chat format to use when calling create_chat_completion.
chat_handler: Optional chat handler to use when calling create_chat_completion.
draft_model: Optional draft model to use for speculative decoding.
verbose: Print verbose output to stderr.

Raises:
@@ -315,6 +322,8 @@ def __init__(
self.chat_format = chat_format
self.chat_handler = chat_handler

self.draft_model = draft_model

self._n_vocab = self.n_vocab()
self._n_ctx = self.n_ctx()

@@ -503,6 +512,7 @@ def sample(
penalize_nl: bool = True,
logits_processor: Optional[LogitsProcessorList] = None,
grammar: Optional[LlamaGrammar] = None,
idx: Optional[int] = None,
):
"""Sample a token from the model.

@@ -517,77 +527,46 @@
"""
assert self._ctx is not None
assert self.n_tokens > 0
last_n_tokens_data = [llama_cpp.llama_token(0)] * max(
0, self.last_n_tokens_size - self.n_tokens
) + self._input_ids[-self.last_n_tokens_size :].tolist()
last_n_tokens_size = len(last_n_tokens_data)
n_vocab = self._n_vocab
n_ctx = self._n_ctx
top_k = n_vocab if top_k <= 0 else top_k
last_n_tokens_size = n_ctx if last_n_tokens_size < 0 else last_n_tokens_size
last_n_tokens_data_c = (llama_cpp.llama_token * last_n_tokens_size)(
*last_n_tokens_data
)
logits: npt.NDArray[np.single] = self._scores[-1, :]

if idx is None:
logits: npt.NDArray[np.single] = self._scores[-1, :]
else:
logits = self._scores[idx, :]

if logits_processor is not None:
logits[:] = logits_processor(self._input_ids, logits)

nl_logit = logits[self._token_nl]
self._candidates.copy_logits(logits)
self._ctx.sample_repetition_penalties(
candidates=self._candidates,
last_tokens_data=last_n_tokens_data_c,
penalty_last_n=last_n_tokens_size,
logits[:] = (
logits_processor(self._input_ids, logits)
if idx is None
else logits_processor(self._input_ids[:idx], logits)
)

sampling_params = _LlamaSamplingParams(
top_k=top_k,
top_p=top_p,
min_p=min_p,
tfs_z=tfs_z,
typical_p=typical_p,
temp=temp,
penalty_last_n=self.last_n_tokens_size,
penalty_repeat=repeat_penalty,
penalty_freq=frequency_penalty,
penalty_present=presence_penalty,
mirostat=mirostat_mode,
mirostat_tau=mirostat_tau,
mirostat_eta=mirostat_eta,
penalize_nl=penalize_nl,
)
sampling_context = _LlamaSamplingContext(
params=sampling_params,
grammar=grammar,
)
sampling_context.prev = list(self.eval_tokens)
id = sampling_context.sample(ctx_main=self._ctx, logits_array=logits)
sampling_context.accept(
ctx_main=self._ctx,
id=id,
apply_grammar=grammar is not None,
)
if not penalize_nl:
self._candidates.candidates.data[self._token_nl].logit = llama_cpp.c_float(
nl_logit
)

if grammar is not None:
self._ctx.sample_grammar(
candidates=self._candidates,
grammar=grammar,
)

if temp < 0.0:
self._ctx.sample_softmax(candidates=self._candidates)
id = self._candidates.candidates.data[0].id
elif temp == 0.0:
id = self._ctx.sample_token_greedy(candidates=self._candidates)
elif mirostat_mode == 1:
self._ctx.sample_temp(candidates=self._candidates, temp=temp)
id = self._ctx.sample_token_mirostat(
candidates=self._candidates,
tau=mirostat_tau,
eta=mirostat_eta,
mu=ctypes.pointer(self._mirostat_mu),
m=100,
)
elif mirostat_mode == 2:
self._ctx.sample_temp(candidates=self._candidates, temp=temp)
id = self._ctx.sample_token_mirostat_v2(
candidates=self._candidates,
tau=mirostat_tau,
eta=mirostat_eta,
mu=ctypes.pointer(self._mirostat_mu),
)
else:
self._ctx.sample_top_k(candidates=self._candidates, k=top_k, min_keep=1)
self._ctx.sample_tail_free(candidates=self._candidates, z=tfs_z, min_keep=1)
self._ctx.sample_typical(
candidates=self._candidates, p=typical_p, min_keep=1
)
self._ctx.sample_top_p(candidates=self._candidates, p=top_p, min_keep=1)
self._ctx.sample_min_p(candidates=self._candidates, p=min_p, min_keep=1)
self._ctx.sample_temp(candidates=self._candidates, temp=temp)
id = self._ctx.sample_token(candidates=self._candidates)
if grammar is not None:
self._ctx.grammar_accept_token(grammar=grammar, token=id)
return id

def generate(
@@ -656,34 +635,56 @@ def generate(
if grammar is not None:
grammar.reset()

sample_idx = self.n_tokens + len(tokens) - 1
tokens = list(tokens)

# Eval and sample
while True:
self.eval(tokens)
token = self.sample(
top_k=top_k,
top_p=top_p,
min_p=min_p,
typical_p=typical_p,
temp=temp,
repeat_penalty=repeat_penalty,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
tfs_z=tfs_z,
mirostat_mode=mirostat_mode,
mirostat_tau=mirostat_tau,
mirostat_eta=mirostat_eta,
logits_processor=logits_processor,
grammar=grammar,
penalize_nl=penalize_nl,
)
if stopping_criteria is not None and stopping_criteria(
self._input_ids, self._scores[-1, :]
):
return
tokens_or_none = yield token
tokens = [token]
if tokens_or_none is not None:
tokens.extend(tokens_or_none)
while sample_idx < self.n_tokens:
token = self.sample(
top_k=top_k,
top_p=top_p,
min_p=min_p,
typical_p=typical_p,
temp=temp,
repeat_penalty=repeat_penalty,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
tfs_z=tfs_z,
mirostat_mode=mirostat_mode,
mirostat_tau=mirostat_tau,
mirostat_eta=mirostat_eta,
logits_processor=logits_processor,
grammar=grammar,
penalize_nl=penalize_nl,
idx=sample_idx,
)

sample_idx += 1
if stopping_criteria is not None and stopping_criteria(
self._input_ids, self._scores[-1, :]
):
return
tokens_or_none = yield token
tokens.clear()
tokens.append(token)
if tokens_or_none is not None:
tokens.extend(tokens_or_none)

if sample_idx < self.n_tokens and token != self._input_ids[sample_idx]:
self.n_tokens = sample_idx
self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1)
break

if self.draft_model is not None:
self.input_ids[self.n_tokens : self.n_tokens + len(tokens)] = tokens
draft_tokens = self.draft_model(self.input_ids[:self.n_tokens + len(tokens)])
tokens.extend(
draft_tokens.astype(int)[
: self._n_ctx - self.n_tokens - len(tokens)
]
)

def create_embedding(
self, input: Union[str, List[str]], model: Optional[str] = None
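For readers skimming the `generate()` hunk above: the new loop evaluates the last sampled token together with any drafted tokens in a single batch, samples every evaluated position through the new `idx` argument of `sample()`, and rolls back `n_tokens` and the KV cache as soon as a sampled token disagrees with the drafted one. As a rough mental model only (this is not the library code: the helper names are hypothetical and the real implementation batches verification through a single `eval` call), the control flow corresponds to:

```python
from typing import Callable, List, Sequence


def speculative_decode(
    target_next_token: Callable[[Sequence[int]], int],
    draft_model: Callable[[Sequence[int]], Sequence[int]],
    prompt: Sequence[int],
    n_predict: int,
) -> List[int]:
    """Illustrative draft-and-verify loop; hypothetical helper names."""
    ids = list(prompt)
    out: List[int] = []
    while len(out) < n_predict:
        draft = list(draft_model(ids))      # cheap candidate continuation
        all_accepted = True
        for drafted in draft:
            token = target_next_token(ids)  # what the target model samples here
            ids.append(token)
            out.append(token)
            if token != drafted or len(out) >= n_predict:
                all_accepted = False        # mismatch: drop the rest of the draft
                break
        if all_accepted:
            # All drafted tokens were verified; in the batched implementation the
            # logits at the last drafted position give one more token "for free".
            token = target_next_token(ids)
            ids.append(token)
            out.append(token)
    return out[:n_predict]
```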
64 changes: 64 additions & 0 deletions llama_cpp/llama_speculative.py
@@ -0,0 +1,64 @@
import abc

from typing import Any

import numpy as np
import numpy.typing as npt


class LlamaDraftModel(abc.ABC):
@abc.abstractmethod
def __call__(
self, input_ids: npt.NDArray[np.intc], /, **kwargs: Any
) -> npt.NDArray[np.intc]:
raise NotImplementedError()


class LlamaPromptLookupDecoding(LlamaDraftModel):
"""Based on https://github.com/apoorvumang/prompt-lookup-decoding"""

def __init__(self, max_ngram_size: int = 2, num_pred_tokens: int = 10):
self.max_ngram_size = max_ngram_size
self.num_pred_tokens = num_pred_tokens

@staticmethod
def find_candidate_pred_tokens(
input_ids: npt.NDArray[np.intc],
max_ngram_size: int,
num_pred_tokens: int,
):
input_length = input_ids.shape[0]

for ngram_size in range(min(max_ngram_size, input_length - 1), 0, -1):
# Create sliding windows of size ngram_size
windows = np.lib.stride_tricks.sliding_window_view(input_ids, (ngram_size,))

# Convert ngram to an array for comparison
ngram_array = input_ids[-ngram_size:]

# Find where the windows match the ngram
matches = np.all(windows == ngram_array, axis=1)

# Get the indices of matches
match_indices = np.nonzero(matches)[0]

# Iterate through match indices to find a valid continuation
for idx in match_indices:
start_idx = idx + ngram_size
end_idx = start_idx + num_pred_tokens
end_idx = min(end_idx, input_length)

if start_idx < end_idx:
return input_ids[start_idx:end_idx]

# If no match is found, return an empty array
return np.array([], dtype=np.intc)

def __call__(
self, input_ids: npt.NDArray[np.intc], /, **kwargs: Any
) -> npt.NDArray[np.intc]:
return self.find_candidate_pred_tokens(
input_ids=input_ids,
max_ngram_size=self.max_ngram_size,
num_pred_tokens=self.num_pred_tokens,
)
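
`LlamaDraftModel` is deliberately an abstract base class (see also the commit "Allow for later extensions to the LlamaDraftModel api"), so drafting strategies other than prompt lookup can be plugged in. A toy sketch of a custom draft model, with a made-up strategy purely for illustration:

```python
from typing import Any

import numpy as np
import numpy.typing as npt

from llama_cpp.llama_speculative import LlamaDraftModel


class RepeatLastTokenDraftModel(LlamaDraftModel):
    """Toy strategy: propose the last seen token repeated num_pred_tokens times."""

    def __init__(self, num_pred_tokens: int = 4):
        self.num_pred_tokens = num_pred_tokens

    def __call__(
        self, input_ids: npt.NDArray[np.intc], /, **kwargs: Any
    ) -> npt.NDArray[np.intc]:
        if input_ids.shape[0] == 0:
            return np.array([], dtype=np.intc)
        return np.full(self.num_pred_tokens, input_ids[-1], dtype=np.intc)
```

Any such object can be passed as `draft_model=` when constructing `Llama`, exactly like `LlamaPromptLookupDecoding` in the README example above.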
9 changes: 9 additions & 0 deletions llama_cpp/server/model.py
@@ -5,6 +5,7 @@
from typing import Dict, Optional, Union, List

import llama_cpp
import llama_cpp.llama_speculative as llama_speculative

from llama_cpp.server.settings import ModelSettings

@@ -92,6 +93,12 @@ def load_llama_from_model_settings(settings: ModelSettings) -> llama_cpp.Llama:
)
)

draft_model = None
if settings.draft_model is not None:
draft_model = llama_speculative.LlamaPromptLookupDecoding(
num_pred_tokens=settings.draft_model_num_pred_tokens
)

kv_overrides: Optional[Dict[str, Union[bool, int, float]]] = None
if settings.kv_overrides is not None:
assert isinstance(settings.kv_overrides, list)
@@ -147,6 +154,8 @@ def load_llama_from_model_settings(settings: ModelSettings) -> llama_cpp.Llama:
# Chat Format Params
chat_format=settings.chat_format,
chat_handler=chat_handler,
# Speculative Decoding
draft_model=draft_model,
# Misc
verbose=settings.verbose,
)
9 changes: 9 additions & 0 deletions llama_cpp/server/settings.py
@@ -143,6 +143,15 @@ class ModelSettings(BaseSettings):
default=None,
description="The model name or path to a pretrained HuggingFace tokenizer model. Same as you would pass to AutoTokenizer.from_pretrained().",
)
# Speculative Decoding
draft_model: Optional[str] = Field(
default=None,
description="Method to use for speculative decoding. One of (prompt-lookup-decoding).",
)
draft_model_num_pred_tokens: int = Field(
default=10,
description="Number of tokens to predict using the draft model.",
)
# Misc
verbose: bool = Field(
default=True, description="Whether to print debug information."
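The two new settings above feed the `draft_model` branch added to `load_llama_from_model_settings` in `llama_cpp/server/model.py`. A minimal sketch of driving that path from Python (assuming `model` is the only other required `ModelSettings` field; adjust to your configuration):

```python
from llama_cpp.server.settings import ModelSettings
from llama_cpp.server.model import load_llama_from_model_settings

settings = ModelSettings(
    model="path/to/model.gguf",            # hypothetical model path
    draft_model="prompt-lookup-decoding",  # the method named in the field description
    draft_model_num_pred_tokens=2,         # CPU-friendly value per the README note
)
llama = load_llama_from_model_settings(settings)
```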
16 changes: 16 additions & 0 deletions tests/test_llama_speculative.py
@@ -0,0 +1,16 @@
import numpy as np

from llama_cpp.llama_speculative import LlamaPromptLookupDecoding

def test_find_candidate_pred_tokens():
find_candidate_pred_tokens = LlamaPromptLookupDecoding.find_candidate_pred_tokens

# Test Case 1: Matching ngram is found
input_ids1 = np.array([1, 2, 3, 1, 2, 3, 1, 2, 3])
result1 = find_candidate_pred_tokens(input_ids1, max_ngram_size=3, num_pred_tokens=2)
assert np.array_equal(result1, np.array([1, 2]))

# Test Case 2: Matching ngram is not found
input_ids2 = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9])
result2 = find_candidate_pred_tokens(input_ids2, max_ngram_size=3, num_pred_tokens=2)
assert np.array_equal(result2, np.array([]))
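
A small worked example of the lookup the tests above exercise (illustration only, not part of the diff): when the trailing n-gram of the prompt also occurs earlier, the tokens that followed the earlier occurrence become the draft.

```python
import numpy as np

from llama_cpp.llama_speculative import LlamaPromptLookupDecoding

draft = LlamaPromptLookupDecoding(max_ngram_size=2, num_pred_tokens=3)
ids = np.array([10, 11, 12, 13, 10, 11], dtype=np.intc)

# The trailing 2-gram [10, 11] also appears at index 0, so the (up to) three
# tokens that followed it there are proposed as draft tokens.
print(draft(ids))  # [12 13 10]
```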