-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathdata_utils.py
executable file
·116 lines (91 loc) · 2.99 KB
/
data_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import numpy as np
from random import sample
import pickle
# Default whitelist used by encode(): digits, lowercase ASCII letters and the
# space character. Characters of the (lower-cased) input that are not in this
# string are dropped before tokenisation.
EN_WHITELIST = '0123456789abcdefghijklmnopqrstuvwxyz '
def get_metadata(path='datasets/twitter/metadata.pkl'):
    """Load the lookup tables stored in a pickled metadata file.

    Args:
        path: Location of the pickle file. Defaults to the twitter dataset
            metadata, so existing zero-argument calls behave as before.

    Returns:
        A tuple ``(idx2w, w2idx, limit)`` read from the ``'idx2w'``,
        ``'w2idx'`` and ``'limit'`` keys of the unpickled mapping. A key that
        is absent yields ``None`` in its position.

    Raises:
        OSError: If the file cannot be opened.
    """
    # NOTE: unpickling can execute arbitrary code. Only load metadata files
    # that come from a trusted source (e.g. produced by your own preprocessing).
    with open(path, 'rb') as f:
        metadata = pickle.load(f)
    return metadata.get('idx2w'), metadata.get('w2idx'), metadata.get('limit')
def split_dataset(x, y, ratio=(0.7, 0.15, 0.15)):
    """Split paired data into train, test and validation parts.

    The part sizes are ``int(len(x) * r)`` for the first, second and last
    entries of ``ratio``. Train is taken from the front, test directly after
    it, and validation from the tail. Because sizes are truncated, a few
    items between the test and validation parts may be left unused.

    Args:
        x: Sliceable sequence of inputs (list, NumPy array, ...).
        y: Sliceable sequence of targets, aligned with ``x``.
        ratio: Fractions for (train, test, valid); defaults to 70/15/15.

    Returns:
        ``((trainX, trainY), (testX, testY), (validX, validY))``.
    """
    data_len = len(x)
    lens = [int(data_len * part) for part in ratio]
    n_train, n_test, n_valid = lens[0], lens[1], lens[-1]
    test_end = n_train + n_test

    trainX, trainY = x[:n_train], y[:n_train]
    testX, testY = x[n_train:test_end], y[n_train:test_end]

    # Slice the tail from an explicit start index. The shorter x[-n_valid:]
    # returns the *entire* sequence when n_valid == 0, which would leak all of
    # the training data into the validation set for small datasets.
    validX = x[max(data_len - n_valid, 0):]
    validY = y[max(len(y) - n_valid, 0):]
    return (trainX, trainY), (testX, testY), (validX, validY)
def batch_gen(x, y, batch_size):
    """Yield consecutive full mini-batches of ``x`` and ``y``, forever.

    Each pass walks the data from the start in steps of ``batch_size``; a
    trailing remainder shorter than ``batch_size`` is skipped. After the last
    full batch the generator wraps around to the beginning.

    Args:
        x: Inputs; must support slicing and expose ``.T`` (e.g. a NumPy array).
        y: Targets, aligned with ``x``, with the same requirements.
        batch_size: Number of rows per batch.

    Yields:
        ``(x_batch.T, y_batch.T)`` -- the transposed slices.

    Raises:
        ValueError: If ``batch_size`` is not positive or exceeds ``len(x)``.
            Without this check the loop would spin forever without ever
            yielding a batch.
    """
    if batch_size <= 0:
        raise ValueError('batch_size must be a positive integer')
    if batch_size > len(x):
        raise ValueError('batch_size must not exceed the number of examples')

    # Largest offset at which a full batch still fits.
    last_start = len(x) - batch_size
    while True:
        for start in range(0, last_start + 1, batch_size):
            # `start` is already an element offset, so the batch ends exactly
            # batch_size rows later (not at (start + 1) * batch_size).
            stop = start + batch_size
            yield x[start:stop].T, y[start:stop].T
def rand_batch_gen(x, y, batch_size):
    """Yield randomly sampled mini-batches of ``x`` and ``y``, forever.

    Every batch draws ``batch_size`` distinct row indices (sampling without
    replacement within a batch) and applies the same indices to ``x`` and
    ``y`` so the pairs stay aligned.

    Args:
        x: Inputs; must support indexing with a list of ints and expose
            ``.T`` (e.g. a NumPy array).
        y: Targets, aligned with ``x``, with the same requirements.
        batch_size: Number of rows per batch.

    Yields:
        ``(x_batch.T, y_batch.T)`` -- the transposed selections.

    Raises:
        ValueError: Raised by ``random.sample`` when ``batch_size`` is larger
            than ``len(x)`` or negative.
    """
    # random.sample accepts a range directly, so the index population is
    # built once instead of materialising a fresh list for every batch.
    population = range(len(x))
    while True:
        sample_idx = sample(population, batch_size)
        yield x[sample_idx].T, y[sample_idx].T
#'''
# convert indices of alphabets into a string (word)
# return str(word)
#
#'''
#def decode_word(alpha_seq, idx2alpha):
# return ''.join([ idx2alpha[alpha] for alpha in alpha_seq if alpha ])
#
#
#'''
# convert indices of phonemes into list of phonemes (as string)
# return str(phoneme_list)
#
#'''
#def decode_phonemes(pho_seq, idx2pho):
# return ' '.join( [ idx2pho[pho] for pho in pho_seq if pho ])
def decode(sequence, lookup, separator=' '):
    """Turn a sequence of indices back into text.

    Args:
        sequence: Iterable of indices. Falsy entries are skipped; 0 is the
            padding value, so padding never reaches the output.
        lookup: Index-to-string table, indexed as ``lookup[element]``.
        separator: String placed between the decoded items.

    Returns:
        The decoded items joined with ``separator``.
    """
    return separator.join(lookup[element] for element in sequence if element)
def encode(sentence, lookup, maxlen, whitelist=EN_WHITELIST, separator=''):
    """Convert a raw sentence into a zero-padded column of word indices.

    The sentence is lower-cased, stripped of every character not in
    ``whitelist``, split into words on single spaces, clipped to its last
    ``maxlen`` words, mapped to indices via ``pad_seq`` and padded with zeros.

    Args:
        sentence: Input text.
        lookup: Word-to-index mapping, passed on to ``pad_seq``.
        maxlen: Number of rows in the result.
        whitelist: Characters to keep; defaults to ``EN_WHITELIST``.
        separator: Accepted for interface compatibility; currently unused.

    Returns:
        A NumPy array of shape ``(maxlen, 1)``.
    """
    sentence = sentence.lower()
    # Keep only whitelisted characters.
    sentence = ''.join(ch for ch in sentence if ch in whitelist)
    # split(' ') yields empty strings for repeated spaces and for an empty
    # sentence; drop them so they are not encoded as unknown words.
    tokens = [token for token in sentence.strip().split(' ') if token]
    # Keep the last `maxlen` words. An explicit start index (instead of
    # tokens[-maxlen:]) also behaves for maxlen == 0.
    tokens = tokens[max(len(tokens) - maxlen, 0):]
    # Words -> indices, zero padded up to maxlen.
    idx_x = np.array(pad_seq(tokens, lookup, maxlen))
    return idx_x.reshape([maxlen, 1])
def pad_seq(seq, lookup, maxlen):
    """Replace words with their indices and right-pad the result with zeros.

    Args:
        seq: Sized sequence of words.
        lookup: Word-to-index mapping. Words that are not in it are replaced
            by ``lookup['unk']``.
        maxlen: Target length. Shorter sequences are padded with 0; longer
            ones are returned unpadded and are *not* truncated.

    Returns:
        A list of indices.

    Raises:
        KeyError: If an unknown word is met and ``lookup`` has no ``'unk'``
            entry.
    """
    indices = [
        lookup[word] if word in lookup else lookup['unk']
        for word in seq
    ]
    # A negative multiplier yields an empty list, so over-long input simply
    # gets no padding.
    return indices + [0] * (maxlen - len(seq))