Remove Falcon style ROPE #35

Merged: 4 commits, Sep 29, 2023

Changes from all commits
README.md (1 change: 0 additions & 1 deletion)

@@ -168,7 +168,6 @@ Then, just run the Mojo
 
 ```bash
 mojo llama2.mojo tl-chat.bin \
-    -r falcon \
     -z tok_tl-chat.bin \
     -n 256 -t 0 -s 100 -i "<|im_start|>user\nGive me a python function to generate Fibonacci sequence<|im_end|>\n<|im_start|>assistant\n"
 ```
llama2.mojo (49 changes: 7 additions & 42 deletions)

@@ -599,31 +599,6 @@ fn matmul(inout C: Matrix, A: Matrix, B: Matrix, rt: Runtime) -> None:
     matmul_parallelized(C, A, B, rt)
 
 
-# Apply RoPE rotation to the q and k vectors for each head
-# roate the first and second half
-fn rope_rotation_falcon(inout state: RunState, freq_cis_real_row: BufferPtrFloat32,
-        freq_cis_imag_row: BufferPtrFloat32, config: Config) -> None:
-    # tinyllama-1.1, llama model
-    let q = state.q.data
-    let k = state.k.data
-    let head_size = config.head_size
-    let off_rot = head_size // 2
-    for i in range(config.n_heads):
-        for j in range(config.head_size // 2):
-            let fcr = freq_cis_real_row.offset(j).load(0)
-            let fci = freq_cis_imag_row.offset(j).load(0)
-            let q0 = q.offset(i * head_size + j).load(0)
-            let q1 = q.offset(i * head_size + j + off_rot).load(0)
-            q.offset(i * head_size + j).store(0, q0 * fcr - q1 * fci)
-            q.offset(i * head_size + j + off_rot).store(0, q0 * fci + q1 * fcr)
-            if i < config.n_kv_heads:
-                let k0 = k.offset(i * head_size + j).load(0)
-                let k1 = k.offset(i * head_size + j + off_rot).load(0)
-                k.offset(i * head_size + j).store(0, k0 * fcr - k1 * fci)
-                k.offset(i * head_size + j + off_rot).store(
-                    0, k0 * fci + k1 * fcr
-                )
-
 # Apply RoPE rotation to the q and k vectors for each head
 # rotate odd and even dim
 fn rope_rotation_llama(inout state: RunState, freq_cis_real_row: BufferPtrFloat32,
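Note on what the deleted variant did: Falcon-style RoPE treats dims j and j + head_size // 2 of each head as one complex pair and rotates them together. A minimal NumPy sketch of that pairing, for intuition only (the helper name and array layout are illustrative, not code from this repo):

```python
import numpy as np

def rope_falcon_style(q, fcr, fci):
    """Falcon-style RoPE: pair dim j with dim j + head_size // 2.

    q:        (n_heads, head_size) query or key rows for one position
    fcr, fci: (head_size // 2,) cos/sin of the per-frequency angles
    """
    half = q.shape[1] // 2
    q0, q1 = q[:, :half], q[:, half:]   # first half / second half of each head
    return np.concatenate(
        [q0 * fcr - q1 * fci,           # rotated "real" components
         q0 * fci + q1 * fcr],          # rotated "imaginary" components
        axis=1,
    )
```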
@@ -632,28 +607,24 @@ fn rope_rotation_llama(inout state: RunState, freq_cis_real_row: BufferPtrFloat32,
     let q = state.q.data
     let k = state.k.data
     let head_size = config.head_size
-    let off_rot = 1
     for i in range(config.n_heads):
         for j in range(0, config.head_size, 2):
             let fcr = freq_cis_real_row.offset(j // 2).load(0)
             let fci = freq_cis_imag_row.offset(j // 2).load(0)
             let q0 = q.offset(i * head_size + j).load(0)
-            let q1 = q.offset(i * head_size + j + off_rot).load(0)
+            let q1 = q.offset(i * head_size + j + 1).load(0)
             q.offset(i * head_size + j).store(0, q0 * fcr - q1 * fci)
-            q.offset(i * head_size + j + off_rot).store(0, q0 * fci + q1 * fcr)
+            q.offset(i * head_size + j + 1).store(0, q0 * fci + q1 * fcr)
             if i < config.n_kv_heads:
                 let k0 = k.offset(i * head_size + j).load(0)
-                let k1 = k.offset(i * head_size + j + off_rot).load(0)
+                let k1 = k.offset(i * head_size + j + 1).load(0)
                 k.offset(i * head_size + j).store(0, k0 * fcr - k1 * fci)
-                k.offset(i * head_size + j + off_rot).store(
+                k.offset(i * head_size + j + 1).store(
                     0, k0 * fci + k1 * fcr
                 )
 
 @always_inline
-fn transformer[
-    rope_rotation: fn (inout state: RunState, freq_cis_real_row: BufferPtrFloat32,
-        freq_cis_imag_row: BufferPtrFloat32, config: Config) -> None
-](
+fn transformer(
     token: Int,
     pos: Int,
     config: Config,
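The surviving llama-style rotation pairs adjacent dims 2j and 2j + 1 instead. The same math in a NumPy sketch (again an illustrative helper, not repo code):

```python
import numpy as np

def rope_llama_style(q, fcr, fci):
    """Llama-style RoPE: pair dim 2j with dim 2j + 1.

    q:        (n_heads, head_size) query or key rows for one position
    fcr, fci: (head_size // 2,) cos/sin of the per-frequency angles
    """
    q0, q1 = q[:, 0::2], q[:, 1::2]     # even dims / odd dims of each head
    out = np.empty_like(q)
    out[:, 0::2] = q0 * fcr - q1 * fci
    out[:, 1::2] = q0 * fci + q1 * fcr
    return out
```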
@@ -700,7 +671,7 @@ fn transformer[
     matmul(state.v, state.xb, tmpw, state.rt)
 
     # Apply RoPE rotation to the q and k vectors for each head
-    rope_rotation(state, freq_cis_real_row, freq_cis_imag_row, config)
+    rope_rotation_llama(state, freq_cis_real_row, freq_cis_imag_row, config)
 
     # Multihead attention. Iterate over all heads
     for h in range(config.n_heads):
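The two conventions apply the same per-pair rotation to differently laid-out pairs, so a checkpoint exported for one layout produces garbage under the other; hard-wiring `rope_rotation_llama` here assumes every supported model uses the llama layout. A quick self-check, assuming the two sketch helpers above are in scope (hypothetical, for intuition only):

```python
import numpy as np

head_size = 8
q = np.random.randn(4, head_size)
theta = np.random.randn(head_size // 2)
fcr, fci = np.cos(theta), np.sin(theta)

# Permutation taking llama's interleaved layout to falcon's split-half layout.
perm = np.r_[np.arange(0, head_size, 2), np.arange(1, head_size, 2)]

# Identical rotation once the dims are re-laid-out...
assert np.allclose(rope_llama_style(q, fcr, fci)[:, perm],
                   rope_falcon_style(q[:, perm], fcr, fci))
# ...but not on identical inputs.
assert not np.allclose(rope_llama_style(q, fcr, fci),
                       rope_falcon_style(q, fcr, fci))
```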
@@ -885,7 +856,6 @@ fn print_usage():
     print(" -n <int> number of steps to run for, default 256. 0 = max_seq_len")
     print(" -i <string> input prompt")
     print(" -z tokenizer path")
-    print(" -r <string> rope architecture, default 'llama' for llama rope, 'falcon' for falcon rope")
 
 
 fn main() raises:
@@ -897,7 +867,6 @@ fn main() raises:
     var steps = 256
     var prompt = String("")
     var rng_seed: Int = time.now()
-    var rope_arch = String("llama") # llama | falcon
 
     @parameter
     fn argparse() raises -> Int:
@@ -916,8 +885,6 @@ fn main() raises:
            rng_seed = atol(args[i + 1])
        if args[i] == "-i":
            prompt = args[i + 1]
-       if args[i] == "-r":
-           rope_arch = args[i + 1]
        if args[i] == "-t":
            let val = args[i + 1]
            temperature = 0.0
@@ -979,14 +946,12 @@ fn main() raises:
     var next_token = 0 # Will store the next token in the sequence
     # Initialize with token 1 (=BOS), as done in Llama-2 sentencepiece tokenizer
     var token = 1
-    let _transformer = transformer[rope_rotation_llama] if rope_arch == 'llama'
-        else transformer[rope_rotation_falcon]
 
     # Position in the sequence
     var pos = 0
     while pos < steps:
         # Forward the transformer to get logits for the next token
-        _transformer(token, pos, config, state, weights)
+        transformer(token, pos, config, state, weights)
 
         if pos < len(prompt_tokens):
             next_token = prompt_tokens[pos]