llama.vim : wip [no ci]
ggerganov committed Oct 8, 2024
1 parent 7b92156 commit 25f3b4d
Showing 1 changed file: examples/llama.vim (70 additions, 124 deletions)
@@ -1,135 +1,81 @@
" Requires an already running llama.cpp server
" To install either copy or symlink to ~/.vim/autoload/llama.vim
" Then start with either :call llama#doLlamaGen(),
" or add a keybind to your vimrc such as
" nnoremap Z :call llama#doLlamaGen()<CR>
" Similarly, you could add an insert mode keybind with
" inoremap <C-B> <Cmd>call llama#doLlamaGen()<CR>
" sample config:
"
" g:llama_api_url, g:llama_api_key and g:llama_overrides can be configured in your .vimrc
" let g:llama_api_url = "192.168.1.10:8080"
" llama_overrides can also be set through buffer/window scopes. For instance
" autocmd filetype python let b:llama_overrides = {"temp": 0.2}
" Could be added to your .vimrc to automatically set a lower temperature when
" editing a python script
" Additionally, an override dict can be stored at the top of a file
" !*{"stop": ["User:"]}
" Could be added to the start of your chatlog.txt to set the stopping token
" These parameter dicts are merged together from lowest to highest priority:
" server default -> g:llama_overrides -> w:llama_overrides ->
" b:llama_overrides -> in file (!*) overrides
" - Ctrl+F - trigger FIM completion
"
" copy paste this in your .vimrc:
"
"augroup llama_cpp
" autocmd!
" autocmd InsertEnter * inoremap <buffer> <silent> <C-F> <Esc>:call llama#fim()<CR>
"augroup END
"
" Sublists (like logit_bias and stop) are overridden, not merged
" Example override:
" !*{"logit_bias": [[13, -5], [2, false]], "temperature": 1, "top_k": 5, "top_p": 0.5, "n_predict": 256, "repeat_last_n": 256, "repeat_penalty": 1.17647}
if !exists("g:llama_api_url")
let g:llama_api_url= "127.0.0.1:8080"
endif
if !exists("g:llama_overrides")
let g:llama_overrides = {}
endif
const s:querydata = {"n_predict": 256, "stop": [ "\n" ], "stream": v:true }
const s:curlcommand = ['curl','--data-raw', "{\"prompt\":\"### System:\"}", '--silent', '--no-buffer', '--request', 'POST', '--url', g:llama_api_url .. '/completion', '--header', "Content-Type: application/json"]
let s:linedict = {}
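As a worked example of the (removed) override merging: each scope is applied with extend(), so later scopes win key-by-key. Values here are hypothetical:

    " global default, then a buffer-local override
    let g:llama_overrides = {"temperature": 0.8, "n_predict": 256}
    let b:llama_overrides = {"temperature": 0.2}
    " resulting request parameters: temperature 0.2, n_predict 256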

-func s:callbackHandler(bufn, channel, msg)
-    if len(a:msg) < 3
-        return
-    elseif a:msg[0] == "d"
-        " strip the streaming "data: " prefix
-        let l:msg = a:msg[6:-1]
-    else
-        let l:msg = a:msg
-    endif
-    let l:decoded_msg = json_decode(l:msg)
-    let l:newtext = split(l:decoded_msg['content'], "\n", 1)
-    if len(l:newtext) > 0
-        call setbufline(a:bufn, s:linedict[a:bufn], getbufline(a:bufn, s:linedict[a:bufn])[0] .. l:newtext[0])
-    else
-        echo "nothing genned"
-    endif
-    if len(l:newtext) > 1
-        let l:failed = appendbufline(a:bufn, s:linedict[a:bufn], l:newtext[1:-1])
-        let s:linedict[a:bufn] = s:linedict[a:bufn] + len(l:newtext) - 1
-    endif
-    if has_key(l:decoded_msg, "stop") && l:decoded_msg.stop
-        echo "Finished generation"
-    endif
-endfunction
+let s:default_config = {
+    \ 'prefix_lines': 32,
+    \ 'suffix_lines': 32,
+    \ 'endpoint':     'http://127.0.0.1:8012/infill',
+    \ 'stop':         ["\n"],
+    \ 'n_predict':    64,
+    \ 'n_probs':      3,
+    \ 'temperature':  0.1
+    \ }
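Note: the new code reads its settings via get(g:, 'llama_config', s:default_config) (below), so a g:llama_config defined in your .vimrc replaces the default dict wholesale rather than being merged key-by-key, and every key must be present. A minimal sketch, values illustrative:

    " in .vimrc (autoload scripts are sourced later, so this takes effect)
    let g:llama_config = {
        \ 'prefix_lines': 64,
        \ 'suffix_lines': 64,
        \ 'endpoint':     'http://127.0.0.1:8012/infill',
        \ 'stop':         ["\n"],
        \ 'n_predict':    128,
        \ 'n_probs':      3,
        \ 'temperature':  0.1
        \ }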

-func llama#doLlamaGen()
-    if exists("b:job")
-        if job_status(b:job) == "run"
-            call job_stop(b:job)
-            return
-        endif
-    endif
+let g:llama_config = get(g:, 'llama_config', s:default_config)

-    let l:cbuffer = bufnr("%")
-    let s:linedict[l:cbuffer] = line('$')
-    let l:buflines = getbufline(l:cbuffer, 1, 1000)
-    let l:querydata = copy(s:querydata)
-    call extend(l:querydata, g:llama_overrides)
-    if exists("w:llama_overrides")
-        call extend(l:querydata, w:llama_overrides)
-    endif
-    if exists("b:llama_overrides")
-        call extend(l:querydata, b:llama_overrides)
-    endif
-    if l:buflines[0][0:1] == '!*'
-        let l:userdata = json_decode(l:buflines[0][2:-1])
-        call extend(l:querydata, l:userdata)
-        let l:buflines = l:buflines[1:-1]
-    endif
-    let l:querydata.prompt = join(l:buflines, "\n")
-    let l:curlcommand = copy(s:curlcommand)
-    if exists("g:llama_api_key")
-        call extend(l:curlcommand, ['--header', 'Authorization: Bearer ' .. g:llama_api_key])
-    endif
-    let l:curlcommand[2] = json_encode(l:querydata)
-    let b:job = job_start(l:curlcommand, {"callback": function("s:callbackHandler", [l:cbuffer])})
-endfunction
+function! llama#fim() abort
+    " gather up to prefix_lines of context above and suffix_lines below the cursor
+    let l:lines_prefix = getline(max([1, line('.') - g:llama_config.prefix_lines]), line('.') - 1)
+    let l:lines_suffix = getline(line('.') + 1, min([line('$'), line('.') + g:llama_config.suffix_lines]))
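For example, with the 32/32 defaults above and the cursor on line 100 of a long buffer:

    " l:lines_prefix -> lines 68..99   (max() clamps at the first line)
    " l:lines_suffix -> lines 101..132 (min() clamps at the last line)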

" Echos the tokkenization of the provided string , or cursor to end of word
" Onus is placed on the user to include the preceding space
func llama#tokenizeWord(...)
if (a:0 > 0)
let l:input = a:1
else
exe "normal \"*ye"
let l:input = @*
endif
let l:querydata = {"content": l:input}
let l:curlcommand = copy(s:curlcommand)
let l:curlcommand[2] = json_encode(l:querydata)
let l:curlcommand[8] = g:llama_api_url .. "/tokenize"
let s:token_job = job_start(l:curlcommand, {"callback": function("s:tokenizeWordCallback", [l:input])})
endfunction
let l:cursor_col = col('.')

func s:tokenizeWordCallback(plaintext, channel, msg)
echo '"' .. a:plaintext ..'" - ' .. string(json_decode(a:msg).tokens)
endfunction
let l:line_cur = getline('.')
let l:line_cur_prefix = strpart(l:line_cur, 0, l:cursor_col)
let l:line_cur_suffix = strpart(l:line_cur, l:cursor_col)

let l:prefix = ""
\ . join(l:lines_prefix, "\n")
\ . "\n"
\ . l:line_cur_prefix

" Echos the token count of the entire buffer (or provided string)
" Example usage :echo llama#tokenCount()
func llama#tokenCount(...)
if (a:0 > 0)
let l:buflines = a:1
else
let l:buflines = getline(1,1000)
if l:buflines[0][0:1] == '!*'
let l:buflines = l:buflines[1:-1]
endif
let l:buflines = join(l:buflines, "\n")
endif
let l:querydata = {"content": l:buflines}
let l:curlcommand = copy(s:curlcommand)
let l:curlcommand[2] = json_encode(l:querydata)
let l:curlcommand[8] = g:llama_api_url .. "/tokenize"
let s:token_job = job_start(l:curlcommand, {"callback": "s:tokenCountCallback"})
endfunction
let l:suffix = ""
\ . l:line_cur_suffix
\ . join(l:lines_suffix, "\n")

+    let l:request = json_encode({
+        \ 'prompt':         "",
+        \ 'input_prefix':   l:prefix,
+        \ 'input_suffix':   l:suffix,
+        "\ 'stop':           g:llama_config.stop,
+        \ 'n_predict':      g:llama_config.n_predict,
+        "\ 'n_probs':        g:llama_config.n_probs,
+        \ 'penalty_last_n': 0,
+        \ 'temperature':    g:llama_config.temperature,
+        \ 'top_k':          10,
+        \ 'stream':         v:false,
+        \ 'samplers':       ["top_k"]
+        \ })

" request completion from the server
let l:curl_command = printf(
\ "curl --silent --no-buffer --request POST --url %s --header \"Content-Type: application/json\" --data %s",
\ g:llama_config.endpoint, shellescape(l:request)
\ )
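For reference, with the default config the printf() above expands to a command equivalent to the following (payload abbreviated, prefix/suffix values illustrative):

    curl --silent --no-buffer --request POST \
        --url http://127.0.0.1:8012/infill \
        --header "Content-Type: application/json" \
        --data '{"prompt": "", "input_prefix": "...", "input_suffix": "...", "n_predict": 64, "penalty_last_n": 0, "temperature": 0.1, "top_k": 10, "stream": false, "samplers": ["top_k"]}'

Only the "content" field of the JSON response is consumed below.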

+    let l:response = json_decode(system(l:curl_command))
+
+    " debug output of the raw response and the split content
+    echom string(l:response)
+
+    let l:content = []
+    for l:part in split(get(l:response, 'content', ''), "\n", 1)
+        call add(l:content, l:part)
+    endfor
+
+    echom string(l:content)
+
+    " insert the 'content' at the current cursor location
+    let l:content[0] = l:line_cur_prefix . l:content[0]
+    let l:content[-1] .= l:line_cur_suffix

-func s:tokenCountCallback(channel, msg)
-    let resp = json_decode(a:msg)
-    echo len(resp.tokens)
+    call setline('.', l:content[0])
+    call append(line('.'), l:content[1:-1])
 endfunction
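With the autocmd from the header installed, Ctrl+F in insert mode drops to normal mode and triggers the completion; it can also be run directly:

    :call llama#fim()

Note that this wip version blocks until the server responds: the request is sent with 'stream': v:false through a synchronous system() call, and the job/callback machinery of the removed code is gone.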
