-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata_input_pipeline.py
45 lines (35 loc) · 1.38 KB
/
data_input_pipeline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
# -*- coding: utf-8 -*-
"""Data_Input_Pipeline.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1GdZf6XcAK7K3OyqE2AYFZAflWsDoExIx
"""
import tensorflow as tf
from tensorflow.data import Dataset
from tensorflow.image import decode_png
from tensorflow.io import read_file
import os
def image_generator(test_size=0.2, testing=False,
                    path='/content/drive/My Drive/Flickr_Faces'):
    """Yield normalized RGB images from one split of an image directory.

    The directory listing is sorted before splitting. The train and test
    generators each list the directory independently, so sorting ensures
    they cut the same ordering at the same index and never share a file.

    Args:
        test_size: Fraction (0 to 1) of the files assigned to the test split.
        testing: If true, yield the test split (the first
            ``int(n * test_size)`` sorted files); otherwise yield the
            training split (all remaining files).
        path: Directory of images. Every entry is read and decoded with
            ``decode_png``, so it should contain only PNG image files.

    Yields:
        Image tensors decoded with 3 channels and divided by 255, so pixel
        values lie in [0, 1].

    Raises:
        ValueError: If ``test_size`` is outside [0, 1].
    """
    if not 0 <= test_size <= 1:
        raise ValueError('test_size must be between 0 and 1, got %r' % (test_size,))
    # os.listdir order is arbitrary; sort so both splits see the same order.
    filenames = sorted(os.listdir(path))
    split_index = int(len(filenames) * test_size)
    # The test split is the leading ``test_size`` fraction; training gets the rest.
    selected = filenames[:split_index] if testing else filenames[split_index:]
    for filename in selected:
        raw = read_file(os.path.join(path, filename))
        img = decode_png(raw, channels=3)
        # Scale 8-bit pixel values into [0, 1].
        yield img / 255
def get_train(batch_size, shuffle=False, test_size=0.2, shuffle_buffer=10):
    """Build a repeating, batched ``tf.data.Dataset`` of training images.

    Args:
        batch_size: Number of images per batch.
        shuffle: If true, shuffle images before repeating and batching.
        test_size: Fraction of images held out for testing; forwarded to
            ``image_generator`` so that train and test agree on the split.
        shuffle_buffer: Buffer size used by ``Dataset.shuffle`` when
            ``shuffle`` is true. A buffer much smaller than the dataset only
            mixes nearby elements. Each buffered element is a full
            512x512x3 float32 image, so larger buffers cost memory.

    Returns:
        A dataset of batches of ``batch_size`` images, each of shape
        (512, 512, 3) and dtype float32, repeating indefinitely.
    """
    train = Dataset.from_generator(
        image_generator,
        output_types=tf.float32,
        output_shapes=tf.TensorShape((512, 512, 3)),
        args=[test_size, False],
    )
    if shuffle:
        train = train.shuffle(shuffle_buffer)
    return train.repeat().batch(batch_size)
def get_test(batch_size, shuffle=False, test_size=0.2, shuffle_buffer=10):
    """Build a repeating, batched ``tf.data.Dataset`` of test images.

    Because the dataset is repeated indefinitely, evaluation loops must
    bound the number of steps themselves.

    Args:
        batch_size: Number of images per batch.
        shuffle: If true, shuffle images before repeating and batching.
        test_size: Fraction of images held out for testing; forwarded to
            ``image_generator`` so that train and test agree on the split.
        shuffle_buffer: Buffer size used by ``Dataset.shuffle`` when
            ``shuffle`` is true. A buffer much smaller than the dataset only
            mixes nearby elements.

    Returns:
        A dataset of batches of ``batch_size`` images, each of shape
        (512, 512, 3) and dtype float32, repeating indefinitely.
    """
    test = Dataset.from_generator(
        image_generator,
        output_types=tf.float32,
        output_shapes=tf.TensorShape((512, 512, 3)),
        # Second arg selects the test split inside image_generator.
        args=[test_size, True],
    )
    if shuffle:
        test = test.shuffle(shuffle_buffer)
    return test.repeat().batch(batch_size)