Add n_keep parameter to Llama constructor to enable StreamingLLM #954

Open
twoletters opened this issue Nov 29, 2023 · 1 comment
Labels: enhancement (New feature or request)

Comments

@twoletters

A recent paper by Meta/MIT/CMU proposed StreamingLLM, a simple yet efficient technique for enabling "infinite" context. Better yet, in llama.cpp the implementation is as trivial as setting the n_keep value via the --keep option, as discussed in this issue. Unfortunately, the high-level API of llama-cpp-python does not expose a keep/n_keep parameter.

It should be simple to add the parameter to the high-level API, ideally as an argument to the Llama class constructor, and to pass it along to llama_cpp.llama_load_model_from_file as part of the lparams parameter here.
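
A minimal sketch of what the proposed high-level API could look like (hypothetical: the n_keep keyword below is the parameter this issue is requesting, not something the current Llama constructor accepts):

    from llama_cpp import Llama

    # Hypothetical n_keep argument: number of initial prompt tokens ("attention
    # sinks") to preserve whenever the context window has to be trimmed.
    llm = Llama(
        model_path="./models/model.gguf",  # illustrative path
        n_ctx=4096,
        n_keep=4,  # proposed parameter, not part of the current API
    )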

@abetlen added the enhancement (New feature or request) label on Dec 21, 2023
@Limour-dev (Contributor) commented Feb 4, 2024

Maybe like this?

    def kv_cache_seq_ltrim(self, n_keep, n_discard=256, n_past=-1, im_start=None):
        # StreamingLLM-style trim: keep the first n_keep tokens ("attention sinks"),
        # drop the next n_discard tokens, and shift the remaining KV cache left.
        if n_past < 0:
            n_past = self.n_tokens
        if im_start is not None:  # im_start is the token pattern [<|im_start|>, name, nl]
            # compute_lps_array/kmp_search are KMP helpers that look for the im_start
            # pattern inside input_ids so the cut lands on a message boundary.
            lps = compute_lps_array(im_start)
            _idx = kmp_search(self.input_ids, im_start, n_keep + n_discard, n_past, lps)
            if _idx >= n_keep:  # in practice this means _idx >= n_keep + n_discard
                n_discard = _idx - n_keep  # cut at the nearest im_start boundary
            else:
                _idx = kmp_search(self.input_ids, im_start, n_keep, n_past, lps)
                if _idx >= n_keep:
                    n_keep = _idx + len(im_start)  # keep at least one im_start structure
        # Remove the discarded span from the KV cache, then shift the tail left.
        self._ctx.kv_cache_seq_rm(-1, n_keep, n_keep + n_discard)
        self._ctx.kv_cache_seq_shift(0, n_keep + n_discard, n_past, -n_discard)
        self.input_ids[n_keep:n_past - n_discard] = self.input_ids[n_keep + n_discard:n_past]
        self.n_tokens = n_past - n_discard

    def eval_t(self, tokens, n_keep=4, n_discard=256, im_start=None):
        # Trim the cache first if the incoming tokens would overflow the context window.
        if self._n_ctx < self.n_tokens + len(tokens):
            tmp_n_discard = max(n_discard, self.n_tokens + len(tokens) - self._n_ctx)
            self.kv_cache_seq_ltrim(n_keep, tmp_n_discard, im_start=im_start)
        for i in range(0, len(tokens), self.n_batch):
            pass  # batch decode of tokens[i : i + self.n_batch] (elided in the original comment)
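
The snippet above assumes two helpers, compute_lps_array and kmp_search, that are not shown in the comment. They appear to be standard KMP (Knuth-Morris-Pratt) string matching over the token buffer; the sketch below is an assumption based on how they are called, not code from this issue:

    def compute_lps_array(pattern):
        # Standard KMP failure function: lps[i] is the length of the longest proper
        # prefix of pattern[:i + 1] that is also a suffix of it.
        lps = [0] * len(pattern)
        length = 0
        i = 1
        while i < len(pattern):
            if pattern[i] == pattern[length]:
                length += 1
                lps[i] = length
                i += 1
            elif length > 0:
                length = lps[length - 1]
            else:
                lps[i] = 0
                i += 1
        return lps

    def kmp_search(tokens, pattern, start, end, lps):
        # Return the index of the first occurrence of pattern in tokens[start:end],
        # or -1 if there is none (signature inferred from the calls above).
        j = 0
        for i in range(start, end):
            while j > 0 and tokens[i] != pattern[j]:
                j = lps[j - 1]
            if tokens[i] == pattern[j]:
                j += 1
            if j == len(pattern):
                return i - len(pattern) + 1
        return -1

With helpers along these lines, kv_cache_seq_ltrim can align the discarded span with a chat-template boundary (the im_start token pattern) instead of cutting in the middle of a message.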
