diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index 5d093bef9..40fb5f891 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -392,6 +392,7 @@ def __init__(
         self.scores: npt.NDArray[np.single] = np.ndarray(
             (n_ctx, self._n_vocab), dtype=np.single
         )
+        self.cancel_stream = False
 
     @property
     def _input_ids(self) -> npt.NDArray[np.intc]:
@@ -975,6 +976,11 @@ def _create_completion(
                 finish_reason = "stop"
                 break
 
+            if self.cancel_stream:
+                finish_reason = "cancel"
+                self.cancel_stream = False
+                break
+
             completion_tokens.append(token)
 
             all_text = self.detokenize(completion_tokens)
@@ -1702,6 +1708,11 @@ def load_state(self, state: LlamaState) -> None:
         if llama_cpp.llama_set_state_data(self.ctx, llama_state) != state_size:
             raise RuntimeError("Failed to set llama state data")
 
+    def cancel(self) -> None:
+        """Cancel a streaming completion."""
+        if self.n_tokens > 0:
+            self.cancel_stream = True
+
     def n_ctx(self) -> int:
         """Return the context window size."""
         assert self.ctx is not None
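
Usage sketch (illustrative only, not part of the patch): assuming a Llama instance loaded from a hypothetical model path, cancel() could be called from another thread while a streaming completion is being consumed; the stream would then end with a chunk whose finish_reason is "cancel".

    import threading

    from llama_cpp import Llama

    llm = Llama(model_path="./model.gguf")  # hypothetical model path

    # Request cancellation of the in-flight stream after two seconds.
    threading.Timer(2.0, llm.cancel).start()

    for chunk in llm.create_completion("Write a very long story.", stream=True):
        choice = chunk["choices"][0]
        print(choice["text"], end="", flush=True)
        if choice["finish_reason"] == "cancel":
            print("\n[stream cancelled]")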