-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathword_prediction.py
51 lines (47 loc) · 1.25 KB
/
word_prediction.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import numpy as np
# Network parameters, one array per CSV file; values are separated by ', '.
wew = np.genfromtxt('word_embedding_weights.csv', delimiter=', ')
ethw = np.genfromtxt('embed_to_hid_weights.csv', delimiter=', ')
htow = np.genfromtxt('hid_to_output_weights.csv', delimiter=', ')
hb = np.genfromtxt('hidden_bias.csv', delimiter=', ')
ob = np.genfromtxt('output_bias.csv', delimiter=', ')

# Vocabulary: the ', '-separated tokens of a line with the final token
# dropped.  Every line read replaces the previous result, so the last line
# of the file determines the vocabulary.
vocab = []
with open('vocab.csv', mode='r') as vocab_file:
    for row in vocab_file:
        vocab = row.split(', ')[:-1]
# Ask the user for the three context words; each prompt is repeated until
# the answer is a word found in the vocabulary.
words = []
for position in range(3):
    candidate = ''
    while candidate not in vocab:
        candidate = input('What is word ' + str(position) + ' ')
        if candidate in vocab:
            break
        print('that word is not in dictionary')
    words.append(candidate)
def predict():
    """Predict the next word from the three-word context held in ``words``.

    Performs one forward pass using the module-level arrays: the embedding
    rows (``wew``) of the three context words are concatenated, fed through
    a logistic hidden layer (``ethw``, ``hb``) and then a softmax output
    layer (``htow``, ``ob``).  The index of the most probable output is
    looked up in ``vocab``.

    Side effect: the oldest entry of ``words`` is removed and the predicted
    word is appended, so successive calls slide the context window forward.

    Returns:
        The predicted word, an element of ``vocab``.

    Raises:
        ValueError: if a context word is not present in ``vocab``.
    """
    # Embedded layer state: embedding rows of the context words, flattened
    # into a single vector.
    embedded_state = np.array(
        [wew[vocab.index(word)] for word in words[:3]]
    ).ravel()

    # Hidden layer: affine transform followed by the logistic function.
    hidden_input = embedded_state.dot(ethw) + hb
    hidden_state = 1.0 / (1.0 + np.exp(-hidden_input))

    # Output layer: softmax.  Subtracting the maximum first keeps np.exp
    # from overflowing and does not change the resulting probabilities.
    output_input = hidden_state.dot(htow) + ob
    output_input -= np.max(output_input)
    probabilities = np.exp(output_input)
    probabilities = probabilities / np.sum(probabilities)

    # Pick the most probable word.  Only outputs that have a vocabulary entry
    # are considered, instead of assuming exactly 250 of them; np.argmax
    # returns the first maximum, like the strict '>' scan it replaces.
    best_index = int(np.argmax(probabilities[:len(vocab)]))
    predictedWord = vocab[best_index]

    # Slide the context window: drop the oldest word, add the prediction.
    words.pop(0)
    words.append(predictedWord)
    return predictedWord
# Begin the output with the three context words, each followed by a space.
strToWrite = ''.join(word + ' ' for word in words[:3])

# Generate ten more words.  Each round waits for the user to press Enter
# (the typed text is ignored), appends the next prediction and prints the
# text accumulated so far.
for _ in range(10):
    input()
    strToWrite = strToWrite + predict() + ' '
    print(strToWrite)