Remove Falcon style ROPE (#35)
* remove falcon style rope

* convert hf model to llama2.c format

* update readme

* .

---------

Co-authored-by: kirp <kirp2199.com>
magician-blue authored Sep 29, 2023
1 parent e7c9344 commit e37bb87
Showing 2 changed files with 7 additions and 43 deletions.
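
For context on the change: llama-style RoPE rotates adjacent dimension pairs (2j, 2j+1) within each head, while falcon/GPT-NeoX-style RoPE pairs dimension j with dimension j + head_size/2, i.e. the first and second halves. Below is a minimal sketch of the two pairings in plain Python, for illustration only; the function and variable names are hypothetical, not code from this repository.

```python
import math

def rope_llama(x, freqs):
    # llama2.c convention: rotate adjacent dims (2j, 2j+1) as one complex pair
    out = list(x)
    for j in range(len(x) // 2):
        c, s = math.cos(freqs[j]), math.sin(freqs[j])
        x0, x1 = x[2 * j], x[2 * j + 1]
        out[2 * j] = x0 * c - x1 * s      # mirrors q0 * fcr - q1 * fci in the diff
        out[2 * j + 1] = x0 * s + x1 * c  # mirrors q0 * fci + q1 * fcr in the diff
    return out

def rope_falcon(x, freqs):
    # falcon / GPT-NeoX convention: pair dim j with dim j + head_size // 2
    half = len(x) // 2
    out = list(x)
    for j in range(half):
        c, s = math.cos(freqs[j]), math.sin(freqs[j])
        x0, x1 = x[j], x[j + half]
        out[j] = x0 * c - x1 * s
        out[j + half] = x0 * s + x1 * c
    return out
```

Both apply the same 2-D rotation; only the index pairing differs. A checkpoint's q/k weights must therefore be laid out to match the pairing the runtime uses, which is why this commit converts the HF model to llama2.c format rather than keeping a second rotation path.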
README.md (1 change: 0 additions & 1 deletion)
@@ -168,7 +168,6 @@ Then, just run the Mojo

 ```bash
 mojo llama2.mojo tl-chat.bin \
-    -r falcon \
     -z tok_tl-chat.bin \
     -n 256 -t 0 -s 100 -i "<|im_start|>user\nGive me a python function to generate Fibonacci sequence<|im_end|>\n<|im_start|>assistant\n"
 ```
llama2.mojo (49 changes: 7 additions & 42 deletions)
@@ -599,31 +599,6 @@ fn matmul(inout C: Matrix, A: Matrix, B: Matrix, rt: Runtime) -> None:
     matmul_parallelized(C, A, B, rt)
 
-
-# Apply RoPE rotation to the q and k vectors for each head
-# rotate the first and second half
-fn rope_rotation_falcon(inout state: RunState, freq_cis_real_row: BufferPtrFloat32,
-                        freq_cis_imag_row: BufferPtrFloat32, config: Config) -> None:
-    # tinyllama-1.1, llama model
-    let q = state.q.data
-    let k = state.k.data
-    let head_size = config.head_size
-    let off_rot = head_size // 2
-    for i in range(config.n_heads):
-        for j in range(config.head_size // 2):
-            let fcr = freq_cis_real_row.offset(j).load(0)
-            let fci = freq_cis_imag_row.offset(j).load(0)
-            let q0 = q.offset(i * head_size + j).load(0)
-            let q1 = q.offset(i * head_size + j + off_rot).load(0)
-            q.offset(i * head_size + j).store(0, q0 * fcr - q1 * fci)
-            q.offset(i * head_size + j + off_rot).store(0, q0 * fci + q1 * fcr)
-            if i < config.n_kv_heads:
-                let k0 = k.offset(i * head_size + j).load(0)
-                let k1 = k.offset(i * head_size + j + off_rot).load(0)
-                k.offset(i * head_size + j).store(0, k0 * fcr - k1 * fci)
-                k.offset(i * head_size + j + off_rot).store(
-                    0, k0 * fci + k1 * fcr
-                )
 
 # Apply RoPE rotation to the q and k vectors for each head
 # rotate odd and even dim
 fn rope_rotation_llama(inout state: RunState, freq_cis_real_row: BufferPtrFloat32,
@@ -632,28 +607,24 @@ fn rope_rotation_llama(inout state: RunState, freq_cis_real_row: BufferPtrFloat32,
     let q = state.q.data
     let k = state.k.data
     let head_size = config.head_size
-    let off_rot = 1
     for i in range(config.n_heads):
        for j in range(0, config.head_size, 2):
             let fcr = freq_cis_real_row.offset(j // 2).load(0)
             let fci = freq_cis_imag_row.offset(j // 2).load(0)
             let q0 = q.offset(i * head_size + j).load(0)
-            let q1 = q.offset(i * head_size + j + off_rot).load(0)
+            let q1 = q.offset(i * head_size + j + 1).load(0)
             q.offset(i * head_size + j).store(0, q0 * fcr - q1 * fci)
-            q.offset(i * head_size + j + off_rot).store(0, q0 * fci + q1 * fcr)
+            q.offset(i * head_size + j + 1).store(0, q0 * fci + q1 * fcr)
             if i < config.n_kv_heads:
                 let k0 = k.offset(i * head_size + j).load(0)
-                let k1 = k.offset(i * head_size + j + off_rot).load(0)
+                let k1 = k.offset(i * head_size + j + 1).load(0)
                 k.offset(i * head_size + j).store(0, k0 * fcr - k1 * fci)
-                k.offset(i * head_size + j + off_rot).store(
+                k.offset(i * head_size + j + 1).store(
                     0, k0 * fci + k1 * fcr
                 )
 
 @always_inline
-fn transformer[
-    rope_rotation: fn (inout state: RunState, freq_cis_real_row: BufferPtrFloat32,
-        freq_cis_imag_row: BufferPtrFloat32, config: Config) -> None
-](
+fn transformer(
     token: Int,
     pos: Int,
     config: Config,
@@ -700,7 +671,7 @@ fn transformer[
     matmul(state.v, state.xb, tmpw, state.rt)
 
     # Apply RoPE rotation to the q and k vectors for each head
-    rope_rotation(state, freq_cis_real_row, freq_cis_imag_row, config)
+    rope_rotation_llama(state, freq_cis_real_row, freq_cis_imag_row, config)
 
     # Multihead attention. Iterate over all heads
     for h in range(config.n_heads):
@@ -885,7 +856,6 @@ fn print_usage():
     print(" -n <int> number of steps to run for, default 256. 0 = max_seq_len")
     print(" -i <string> input prompt")
     print(" -z tokenizer path")
-    print(" -r <string> rope architecture, default 'llama' for llama rope, 'falcon' for falcon rope")
 
 
 fn main() raises:
@@ -897,7 +867,6 @@
     var steps = 256
     var prompt = String("")
     var rng_seed: Int = time.now()
-    var rope_arch = String("llama") # llama | falcon
 
     @parameter
     fn argparse() raises -> Int:
@@ -916,8 +885,6 @@
             rng_seed = atol(args[i + 1])
         if args[i] == "-i":
             prompt = args[i + 1]
-        if args[i] == "-r":
-            rope_arch = args[i + 1]
         if args[i] == "-t":
             let val = args[i + 1]
             temperature = 0.0
@@ -979,14 +946,12 @@
     var next_token = 0 # Will store the next token in the sequence
     # Initialize with token 1 (=BOS), as done in Llama-2 sentencepiece tokenizer
     var token = 1
-    let _transformer = transformer[rope_rotation_llama] if rope_arch == 'llama'
-        else transformer[rope_rotation_falcon]
 
     # Position in the sequence
     var pos = 0
     while pos < steps:
         # Forward the transformer to get logits for the next token
-        _transformer(token, pos, config, state, weights)
+        transformer(token, pos, config, state, weights)
 
         if pos < len(prompt_tokens):
             next_token = prompt_tokens[pos]
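
On the "convert hf model to llama2.c format" bullet: Hugging Face Llama checkpoints store the q/k projection weights pre-permuted for half-split rotation, so an export to llama2.c's interleaved convention has to undo that permutation. Here is a hedged sketch of the usual unpermute step, modeled on llama2.c-style export scripts; the helper name and shapes are illustrative and not taken from this commit.

```python
import torch

def unpermute(w: torch.Tensor, n_heads: int, dim1: int, dim2: int) -> torch.Tensor:
    # Undo the per-head (first-half, second-half) row grouping used by HF
    # Llama checkpoints, restoring the interleaved (even, odd) order that
    # an adjacent-pair rotation like rope_rotation_llama expects.
    return w.view(n_heads, 2, dim1 // n_heads // 2, dim2).transpose(1, 2).reshape(dim1, dim2)

# Usage sketch (key names follow HF's Llama layout; for GQA models such as
# TinyLlama the k projection has n_kv_heads * head_size rows):
# wq = unpermute(sd[f"model.layers.{i}.self_attn.q_proj.weight"], n_heads, dim, dim)
# wk = unpermute(sd[f"model.layers.{i}.self_attn.k_proj.weight"], n_kv_heads, kv_dim, dim)
```

With the weights unpermuted at export time, the runtime only ever needs the llama-style rotation, so the falcon path and its `-r` flag can go.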
