-
Notifications
You must be signed in to change notification settings - Fork 10.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
examples : iOS example with swift ui (#4159)
* copy to llama.cpp as subdir * attempt enabling metal, fails * ggml metal compiles! * Update README.md * initial conversion to new format, utf8 errors? * bug fixes, but now has an invalid memory access :( * added O3, now has insufficient memory access * begin sync with master * update to match latest code, new errors * fixed it! * fix for loop conditionals, increase result size * fix current workflow errors * attempt a llama.swiftui workflow * Update .github/workflows/build.yml Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
- Loading branch information
Showing
16 changed files
with
829 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
xcuserdata |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
# llama.swiftui | ||
|
||
Local inference of llama.cpp on an iPhone. | ||
So far I only tested with starcoder 1B model, but it can most likely handle 7B models as well. | ||
|
||
https://github.com/bachittle/llama.cpp/assets/39804642/e290827a-4edb-4093-9642-2a5e399ec545 | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,176 @@ | ||
import Foundation | ||
|
||
// import llama | ||
|
||
enum LlamaError: Error { | ||
case couldNotInitializeContext | ||
} | ||
|
||
actor LlamaContext { | ||
private var model: OpaquePointer | ||
private var context: OpaquePointer | ||
private var batch: llama_batch | ||
private var tokens_list: [llama_token] | ||
|
||
var n_len: Int32 = 512 | ||
var n_cur: Int32 = 0 | ||
var n_decode: Int32 = 0 | ||
|
||
init(model: OpaquePointer, context: OpaquePointer) { | ||
self.model = model | ||
self.context = context | ||
self.tokens_list = [] | ||
self.batch = llama_batch_init(512, 0, 1) | ||
} | ||
|
||
deinit { | ||
llama_free(context) | ||
llama_free_model(model) | ||
llama_backend_free() | ||
} | ||
|
||
static func createContext(path: String) throws -> LlamaContext { | ||
llama_backend_init(false) | ||
let model_params = llama_model_default_params() | ||
|
||
let model = llama_load_model_from_file(path, model_params) | ||
guard let model else { | ||
print("Could not load model at \(path)") | ||
throw LlamaError.couldNotInitializeContext | ||
} | ||
var ctx_params = llama_context_default_params() | ||
ctx_params.seed = 1234 | ||
ctx_params.n_ctx = 2048 | ||
ctx_params.n_threads = 8 | ||
ctx_params.n_threads_batch = 8 | ||
|
||
let context = llama_new_context_with_model(model, ctx_params) | ||
guard let context else { | ||
print("Could not load context!") | ||
throw LlamaError.couldNotInitializeContext | ||
} | ||
|
||
return LlamaContext(model: model, context: context) | ||
} | ||
|
||
func get_n_tokens() -> Int32 { | ||
return batch.n_tokens; | ||
} | ||
|
||
func completion_init(text: String) { | ||
print("attempting to complete \"\(text)\"") | ||
|
||
tokens_list = tokenize(text: text, add_bos: true) | ||
|
||
let n_ctx = llama_n_ctx(context) | ||
let n_kv_req = tokens_list.count + (Int(n_len) - tokens_list.count) | ||
This comment has been minimized.
Sorry, something went wrong. |
||
|
||
print("\n n_len = \(n_len), n_ctx = \(n_ctx), n_kv_req = \(n_kv_req)") | ||
|
||
if n_kv_req > n_ctx { | ||
print("error: n_kv_req > n_ctx, the required KV cache size is not big enough") | ||
} | ||
|
||
for id in tokens_list { | ||
print(token_to_piece(token: id)) | ||
} | ||
|
||
// batch = llama_batch_init(512, 0) // done in init() | ||
batch.n_tokens = Int32(tokens_list.count) | ||
|
||
for i1 in 0..<batch.n_tokens { | ||
let i = Int(i1) | ||
batch.token[i] = tokens_list[i] | ||
batch.pos[i] = i1 | ||
batch.n_seq_id[Int(i)] = 1 | ||
batch.seq_id[Int(i)]![0] = 0 | ||
batch.logits[i] = 0 | ||
} | ||
batch.logits[Int(batch.n_tokens) - 1] = 1 // true | ||
|
||
if llama_decode(context, batch) != 0 { | ||
print("llama_decode() failed") | ||
} | ||
|
||
n_cur = batch.n_tokens | ||
} | ||
|
||
func completion_loop() -> String { | ||
var new_token_id: llama_token = 0 | ||
|
||
let n_vocab = llama_n_vocab(model) | ||
let logits = llama_get_logits_ith(context, batch.n_tokens - 1) | ||
|
||
var candidates = Array<llama_token_data>() | ||
candidates.reserveCapacity(Int(n_vocab)) | ||
|
||
for token_id in 0..<n_vocab { | ||
candidates.append(llama_token_data(id: token_id, logit: logits![Int(token_id)], p: 0.0)) | ||
} | ||
candidates.withUnsafeMutableBufferPointer() { buffer in | ||
var candidates_p = llama_token_data_array(data: buffer.baseAddress, size: buffer.count, sorted: false) | ||
|
||
new_token_id = llama_sample_token_greedy(context, &candidates_p) | ||
} | ||
|
||
if new_token_id == llama_token_eos(context) || n_cur == n_len { | ||
print("\n") | ||
return "" | ||
} | ||
|
||
let new_token_str = token_to_piece(token: new_token_id) | ||
print(new_token_str) | ||
// tokens_list.append(new_token_id) | ||
|
||
batch.n_tokens = 0 | ||
|
||
batch.token[Int(batch.n_tokens)] = new_token_id | ||
batch.pos[Int(batch.n_tokens)] = n_cur | ||
batch.n_seq_id[Int(batch.n_tokens)] = 1 | ||
batch.seq_id[Int(batch.n_tokens)]![0] = 0 | ||
batch.logits[Int(batch.n_tokens)] = 1 // true | ||
batch.n_tokens += 1 | ||
|
||
n_decode += 1 | ||
|
||
n_cur += 1 | ||
|
||
if llama_decode(context, batch) != 0 { | ||
print("failed to evaluate llama!") | ||
} | ||
|
||
return new_token_str | ||
} | ||
|
||
func clear() { | ||
tokens_list.removeAll() | ||
} | ||
|
||
private func tokenize(text: String, add_bos: Bool) -> [llama_token] { | ||
let n_tokens = text.count + (add_bos ? 1 : 0) | ||
let tokens = UnsafeMutablePointer<llama_token>.allocate(capacity: n_tokens) | ||
let tokenCount = llama_tokenize(model, text, Int32(text.count), tokens, Int32(n_tokens), add_bos, false) | ||
|
||
var swiftTokens: [llama_token] = [] | ||
for i in 0..<tokenCount { | ||
swiftTokens.append(tokens[Int(i)]) | ||
} | ||
|
||
tokens.deallocate() | ||
|
||
return swiftTokens | ||
} | ||
|
||
private func token_to_piece(token: llama_token) -> String { | ||
let result = UnsafeMutablePointer<Int8>.allocate(capacity: 8) | ||
result.initialize(repeating: Int8(0), count: 8) | ||
|
||
let _ = llama_token_to_piece(model, token, result, 8) | ||
|
||
let resultStr = String(cString: result) | ||
|
||
result.deallocate() | ||
|
||
return resultStr | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
// | ||
// Use this file to import your target's public headers that you would like to expose to Swift. | ||
// | ||
|
||
#import "llama.h" |
Oops, something went wrong.
Isn't this:
the same as
and pretty much 1024?