forked from tairov/llama2.mojo
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgradio_app.py
91 lines (86 loc) · 2.76 KB
/
gradio_app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import asyncio
import codecs
import subprocess
import sys
from pathlib import Path

import gradio as gr
async def generate(prompt, model_name, seed=0, temperature=0.5, num_tokens=256):
    """Run llama2.mojo on *prompt* and stream the accumulated stdout text.

    Spawns ``mojo llama2.mojo <model> -s <seed> -n <num_tokens>
    -t <temperature> -i <prompt> -z <tokenizer>`` and, every time new output
    arrives, yields the whole text produced so far, so a streaming output
    component can re-render it incrementally.

    Args:
        prompt: Prompt text, passed to the child via ``-i``.
        model_name: Checkpoint file name, prefixed with ``base`` (currently
            empty, so it is resolved against the working directory).
        seed: RNG seed for ``-s``. Coerced to ``int`` in case a UI slider
            delivers a whole number as a float such as ``7.0``.
        temperature: Sampling temperature for ``-t``.
        num_tokens: Number of tokens for ``-n``. Coerced to ``int`` like seed.

    Yields:
        str: The cumulative decoded stdout of the child process.
    """
    base = ""  # alternative location: "../model/"
    tokenizer_name = "tokenizer.bin"
    if model_name == "tl-chat.bin":
        # tl-chat.bin is paired with its own tokenizer file.
        tokenizer_name = "tok_tl-chat.bin"

    # asyncio's subprocess API lets this coroutine await output without
    # blocking the event loop (a plain Popen read would stall every request).
    process = await asyncio.create_subprocess_exec(
        "mojo",
        "llama2.mojo",
        str(Path(base + model_name)),
        "-s",
        str(int(seed)),
        "-n",
        str(int(num_tokens)),
        "-t",
        str(temperature),
        "-i",
        prompt,
        "-z",
        str(Path(base + tokenizer_name)),
        stdout=asyncio.subprocess.PIPE,
        # stderr is never consumed. Piping it risked the child blocking once
        # the OS pipe buffer filled, so discard it instead.
        stderr=asyncio.subprocess.DEVNULL,
    )

    # An incremental decoder buffers a multi-byte UTF-8 character that is
    # split across reads instead of dropping its bytes one at a time.
    decoder = codecs.getincrementaldecoder("utf-8")(errors="ignore")
    text = ""
    try:
        while True:
            # Returns as soon as any bytes are available (up to 4096),
            # or b"" at EOF.
            chunk = await process.stdout.read(4096)
            if not chunk:
                break
            piece = decoder.decode(chunk)
            if piece:
                text += piece
                yield text
        await process.wait()
    finally:
        # Still running here only if the consumer stopped iterating early
        # (e.g. the event was cancelled): kill and reap the child so it does
        # not keep generating in the background.
        if process.returncode is None:
            try:
                process.kill()
            except ProcessLookupError:
                pass  # exited between the check and kill()
            await process.wait()
# Markdown banner rendered at the top of the page.
_HEADER_MARKDOWN = """
# llama2.🔥
## [Mojo](https://docs.modular.com/mojo/) implementation of [llama2.c](https://github.com/karpathy/llama2.c) by [@tairov](https://github.com/tairov)
Source: https://github.com/tairov/llama2.mojo
"""

# Checkpoints offered in the model dropdown, in display order.
_MODEL_CHOICES = ["stories15M.bin", "stories42M.bin", "stories110M.bin", "tl-chat.bin"]

# Checkpoints for which the token slider is widened to 1024 (others cap at 256).
_MODELS_ALLOWING_1024_TOKENS = ("stories110M.bin", "stories42M.bin", "tl-chat.bin")

with gr.Blocks() as demo:
    gr.Markdown(_HEADER_MARKDOWN)

    with gr.Row():
        # Left column: generation controls and action buttons.
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", placeholder="Add your prompt here...")
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=2**53,
                step=1,
                value=0,
                randomize=True,
            )
            temperature = gr.Slider(
                label="Temperature",
                minimum=0.0,
                maximum=2.0,
                step=0.01,
                value=0.0,
            )
            num_tokens = gr.Slider(
                label="Number of tokens",
                minimum=1,
                maximum=256,
                value=256,
            )
            model_name = gr.Dropdown(
                _MODEL_CHOICES,
                label="Model Size",
                value="stories15M.bin",
            )
            with gr.Row():
                stop = gr.Button("Stop")
                run = gr.Button("Run")
        # Right column: streamed model output.
        with gr.Column(scale=2):
            output_text = gr.Textbox(label="Generated Text")

    # Re-cap the token slider whenever a different checkpoint is selected.
    model_name.change(
        lambda choice: gr.update(
            maximum=1024 if choice in _MODELS_ALLOWING_1024_TOKENS else 256
        ),
        model_name,
        num_tokens,
        queue=False,
    )

    click_event = run.click(
        fn=generate,
        inputs=[prompt, model_name, seed, temperature, num_tokens],
        outputs=output_text,
    )
    # "Stop" runs no function of its own; it only cancels the pending Run event.
    stop.click(fn=None, inputs=None, outputs=None, cancels=[click_event])

# NOTE(review): presumably the queue is what enables the streaming generator
# and `cancels=` above — confirm against the Gradio version in use.
demo.queue()
# Bind on all network interfaces, not just localhost.
demo.launch(server_name="0.0.0.0")