-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathapp.py
127 lines (100 loc) · 4.11 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import streamlit as st
import torch
import torch.nn.functional as F
from torch import nn
import numpy as np
import time
# Use the GPU when PyTorch reports one as available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Page heading (rendered as a markdown H1).
st.write("# Text Generator")

# Sidebar: a title followed by three explanatory paragraphs.
# NOTE(review): "intension" in the second paragraph is presumably a typo for
# "intention"; it is user-facing text and is left exactly as it was.
with st.sidebar:
    st.title("Model Information")
    st.write("Made using a simple 2 hidden layered Neural Network, this text generator can predict next characters of the input text provided.")
    st.write("Here, the intension is not to generate meaningful sentences, we require a lot of compute for that. This app aims at showing how a vanilla neural network is also capable of capturing the format of English language, and generate words that are (very close to) valid words. Notice that the model uses capital letters (including capital I), punctuation marks and fullstops nearly correct. The text is generated paragraph wise, because the model learnt this from the text corpus.")
    st.write("This model was trained on a simple 600 KB text corpus titled: 'Gulliver's Travels'")

# How many characters to sample: 100 to 2000, defaulting to 1000.
no_of_chars = st.slider(
    "Number of characters to be generated",
    min_value=100,
    max_value=2000,
    value=1000,
)
# Read the whole corpus. The last 2000 characters are split off into `test`;
# the vocabulary below is built only from the remainder, `content`.
# NOTE(review): no explicit encoding is given, so decoding follows the platform
# default -- confirm this matches how the checkpoints' vocabulary was built.
with open('gt.txt', 'r') as file:
    thefile = file.read()
content = thefile[:-2000]
test = thefile[-2000:]

# Character -> index table. Index 0 is reserved for '@' (the value used to
# pre-fill the context window in generate_text); every other distinct character
# of `content` is numbered 1, 2, ... in sorted order.
# Only the distinct characters are sorted, not the full corpus: the resulting
# table is the same, but the work no longer scales with the corpus length on
# every Streamlit rerun.
stoi = {'@': 0}
stoi.update(
    {char: index for index, char in enumerate(sorted(set(content) - {'@'}), start=1)}
)

# Inverse table: index -> character.
itos = {value: key for key, value in stoi.items()}
def generate_text(model, inp, itos, stoi, block_size, max_len):
    """Sample ``max_len`` characters from ``model``, seeded with ``inp``.

    The model is called on a ``(1, block_size)`` tensor of character indices
    and its output is used as the logits of a categorical distribution from
    which the next index is drawn. After every draw the window slides one
    position to the left and the new index is appended on the right.

    The seed is placed at the right-hand end of the window, so the most recent
    seed character sits directly before the first prediction; any unused
    positions on the left hold index 0. When the seed is longer than
    ``block_size`` only its last ``block_size`` characters are used.

    Args:
        model: Callable mapping the index tensor to a ``(1, n)`` logits tensor.
        inp: Seed text; may be empty.
        itos: Mapping from index to character.
        stoi: Mapping from character to index.
        block_size: Length of the context window.
        max_len: Number of characters to generate.

    Returns:
        The generated characters as one string (the seed is not included).

    Raises:
        KeyError: If a used seed character is missing from ``stoi``, or a
            sampled index is missing from ``itos``.
    """
    # Only the trailing block_size seed characters fit in the window.
    # (Guarded because inp[-0:] would be the whole string.)
    seed_tail = inp[-block_size:] if block_size > 0 else ""
    seed_ids = [stoi[ch] for ch in seed_tail]
    # Left-pad with 0 so the seed is adjacent to the position being predicted,
    # matching the `context[1:] + [ix]` slide used below.
    context = [0] * (block_size - len(seed_ids)) + seed_ids

    # Put the input on whatever device the model's parameters live on. A model
    # that exposes no parameters gets the tensor on the default (CPU) device.
    params = getattr(model, "parameters", None)
    first_param = next(iter(params()), None) if callable(params) else None

    pieces = []
    # Pure inference: no autograd graph is needed for sampling.
    with torch.no_grad():
        for _ in range(max_len):
            x = torch.tensor(context).view(1, -1)
            if first_param is not None:
                x = x.to(first_param.device)
            y_pred = model(x)
            ix = torch.distributions.categorical.Categorical(logits=y_pred).sample().item()
            pieces.append(itos[ix])
            context = context[1:] + [ix]
    return "".join(pieces)
# Typewriter-style rendering of a string in the Streamlit page
def type_text(text):
    """Reveal ``text`` one character at a time in a single page element.

    A placeholder is created with ``st.empty()`` and rewritten after each
    character with the text revealed so far followed by the marker ``$ꕯ$``
    (presumably displayed as a cursor). Once every character has been shown
    the placeholder is written one final time without the marker.

    Args:
        text: The string to display; an empty string produces a single
            empty write.
    """
    placeholder = st.empty()
    revealed = ""
    for ch in text:
        revealed += ch
        # Intermediate frame: revealed prefix plus the trailing marker.
        placeholder.write(revealed + '$ꕯ$')
        # Pause between frames; this sets the typing speed.
        time.sleep(0.004)
    # Final frame without the marker.
    placeholder.write(revealed)
class NextChar(nn.Module):
    """Feed-forward next-character network over a fixed-length index window.

    Layers created: an embedding table (``emb``) and three linear layers
    (``lin1``, ``lin2``, ``lin3``). The attribute names double as the
    ``state_dict`` keys and therefore must not be renamed.

    NOTE(review): ``forward`` never calls ``lin3``; what it returns is the
    output of ``lin2``, whose width is ``hidden_size2`` rather than
    ``vocab_size``. This looks unintended, but existing checkpoints may have
    been trained with exactly this forward pass -- confirm against the
    training code before changing it.
    """

    def __init__(self, block_size, vocab_size, emb_dim, hidden_size1, hidden_size2):
        super().__init__()
        # Width of one example after the per-position embeddings are flattened.
        flat_width = block_size * emb_dim
        self.emb = nn.Embedding(vocab_size, emb_dim)
        self.lin1 = nn.Linear(flat_width, hidden_size1)
        self.lin2 = nn.Linear(hidden_size1, hidden_size2)
        self.lin3 = nn.Linear(hidden_size2, vocab_size)

    def forward(self, x):
        """Map an index tensor ``x`` to the ``lin2`` output.

        ``x`` is embedded, flattened to one row per leading-dimension entry,
        passed through ``lin1`` with a sine activation, then through ``lin2``.
        """
        embedded = self.emb(x)
        flat = embedded.view(embedded.shape[0], -1)
        hidden = torch.sin(self.lin1(flat))  # Activation function : change this
        return self.lin2(hidden)
# Hyper-parameters picked by the user. Together they select which checkpoint
# file is loaded below, so the offered values must match the files on disk.
emb_dim = st.selectbox(
    'Select embedding size',
    (1, 2, 5, 10, 15, 30, 50, 100), index=4)
block_size = st.selectbox(
    'Select block size',
    (15, 50), index=0)

# Build the network for the chosen sizes (hidden widths fixed at 500 and 300)
# and wrap it with torch.compile. The two stand-alone `emb` Embedding objects
# that used to be created here were never used and have been removed.
model = NextChar(block_size, len(stoi), emb_dim, 500, 300).to(device)
model = torch.compile(model)

inp = st.text_input("Enter text", placeholder="Enter valid English text. You can also leave this blank.")
btn = st.button("Generate")
if btn:
    st.subheader("Seed Text")
    type_text(inp)

    # Weights are loaded only when generation is requested.
    # NOTE(review): the state dict is loaded into the compiled wrapper, so the
    # checkpoint keys presumably carry the wrapper's prefix -- confirm against
    # how the .pth files were saved.
    checkpoint_path = f"gt_eng_model_upper_two_hid_layer_emb{emb_dim}_block_size_{block_size}.pth"
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))

    gen_txt = generate_text(model, inp, itos, stoi, block_size, no_of_chars)
    st.subheader("Generated Text")
    # Echo the full text (seed + generated) to stdout as well as the page.
    print(inp+gen_txt)
    type_text(inp+gen_txt)