# utils.py
# Forked from udacity/deep-learning.
import numpy as np
def split_data(chars, batch_size, num_steps, split_frac=0.9):
    """Split a sequence into training and validation input/target arrays.

    The sequence is trimmed to a whole number of batches, where one batch
    covers ``batch_size * num_steps`` elements. Targets are the inputs
    shifted forward by one position. Both are arranged into ``batch_size``
    rows and then cut column-wise into a training part and a validation part.

    Args:
        chars: Sequence to split. Presumably a 1-D NumPy array of encoded
            characters -- confirm against the caller.
        batch_size: Number of rows in each returned array.
        num_steps: Number of columns consumed per batch.
        split_frac: Fraction of the batches assigned to the training set.

    Returns:
        A tuple ``(train_x, train_y, val_x, val_y)`` of 2-D arrays, each with
        ``batch_size`` rows.
    """
    slice_size = batch_size * num_steps
    # Reserve one trailing element so the shifted targets always have the
    # same length as the inputs. Without this, a sequence whose length is an
    # exact multiple of slice_size yields targets that are one element short
    # and np.split raises ValueError.
    n_batches = max((len(chars) - 1) // slice_size, 0)

    # Drop the trailing elements that do not fill a whole batch.
    usable = n_batches * slice_size
    x = chars[:usable]
    y = chars[1:usable + 1]

    # One row per parallel sequence: shape (batch_size, n_batches * num_steps).
    x = np.stack(np.split(x, batch_size))
    y = np.stack(np.split(y, batch_size))

    # Cut on a batch boundary so both parts hold whole batches.
    split_col = int(n_batches * split_frac) * num_steps
    train_x, train_y = x[:, :split_col], y[:, :split_col]
    val_x, val_y = x[:, split_col:], y[:, split_col:]
    return train_x, train_y, val_x, val_y
def get_batch(arrs, num_steps):
    """Yield consecutive column windows taken from every array in ``arrs``.

    The number of windows is derived from the column count of ``arrs[0]``;
    trailing columns that do not fill a whole window are never yielded.

    Args:
        arrs: Sequence of 2-D arrays. The first one determines how many
            windows are produced; the same column range is taken from each.
        num_steps: Width, in columns, of every window.

    Yields:
        A list with one ``[:, start:start + num_steps]`` slice per array in
        ``arrs``, in the same order.
    """
    # Unpacking also enforces that the first array is two-dimensional.
    _, width = arrs[0].shape
    full_windows = width // num_steps
    for start in range(0, full_windows * num_steps, num_steps):
        stop = start + num_steps
        yield [arr[:, start:stop] for arr in arrs]