Skip to content

Commit

Permalink
mnist
Browse files Browse the repository at this point in the history
  • Loading branch information
LorenzoValente3 committed Nov 29, 2021
1 parent edd5a1f commit d9fdee0
Show file tree
Hide file tree
Showing 7 changed files with 140 additions and 97 deletions.
186 changes: 114 additions & 72 deletions AE.ipynb

Large diffs are not rendered by default.

51 changes: 26 additions & 25 deletions MNIST_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import matplotlib.pyplot as plt
import sys
import tensorflow as tf
import scipy


# +------------------+
Expand All @@ -14,42 +15,46 @@ class MNISTData:
"""MNIST data class. You can adjust the data_fraction to use when creating
the data, according to your system capabilities."""

def __init__(self, gan=False, data_fraction=1.):
def __init__(self, data_fraction=1., zoom_factor = None):
data = mnist
(self.x_train, self.y_train), (self.x_test, self.y_test) = data.load_data()

self.get_subset_of_data(data_fraction)
if zoom_factor is not None:
self.interpolate(zoom_factor)

self.convert_label_to_categorical()

self.normalize_mnist_images(gan)
self.normalize_mnist_images()

self.reshape_to_color_channel(gan)
self.reshape_to_color_channel()



def interpolate(self, zoom_factor):
#self.x_train = scipy.ndimage.zoom(self.x_train,
# (1, zoom_factor, zoom_factor, 1))
#shape_train = self.x_train.shape[1]*self.x_train.shape[2]
#self.x_train.reshape([196,])


if gan is False:
self.flatten_pictures()
self.x_test = scipy.ndimage.zoom(self.x_test,
(1, zoom_factor, zoom_factor))

#shape_test = self.x_test.shape[1]*self.x_test.shape[2]
#self.x_test.reshape([196,])

def convert_label_to_categorical(self):
self.y_train = to_categorical(self.y_train)
self.y_test = to_categorical(self.y_test)

def normalize_mnist_images(self, gan):
if gan:
"""normalize the images to [-1, 1]"""
self.x_train = (self.x_train - 127.5) / 127.5
self.x_test = (self.x_test - 127.5) / 127.5
else:
self.x_train = self.x_train / 255.0
self.x_test = self.x_test / 255.0

def normalize_mnist_images(self):
self.x_train = self.x_train / 255.0
self.x_test = self.x_test / 255.0

def reshape_to_color_channel(self, gan):
if gan:
self.x_train = self.x_train.reshape(self.x_train.shape[0], 28, 28, 1).astype('float32')
self.x_test = self.x_test.reshape(self.x_test.shape[0], 28, 28, 1).astype('float32')
else:
self.x_train = self.x_train[:, :, :, np.newaxis]
self.x_test = self.x_test[:, :, :, np.newaxis]
def reshape_to_color_channel(self):
self.x_train = self.x_train[:, :, :, np.newaxis]
self.x_test = self.x_test[:, :, :, np.newaxis]

def get_subset_of_data(self, data_fraction):
"""Choosing a fraction of data according to the machine capabilities"""
Expand All @@ -59,7 +64,3 @@ def get_subset_of_data(self, data_fraction):
index = int(len(self.x_test) * data_fraction)
self.x_test = self.x_test[:index]
self.y_test = self.y_test[:index]

def flatten_pictures(self):
self.x_train = self.x_train.reshape(self.x_train.shape[0], -1)
self.x_test = self.x_test.reshape(self.x_test.shape[0], -1)
Binary file modified __pycache__/MNIST_dataset.cpython-38.pyc
Binary file not shown.
Binary file modified images/AE/Accuracy of Autoencoder without classifier.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified images/AE/reconstructed images.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified model_1/KERAS_check_best_model.h5
Binary file not shown.

0 comments on commit d9fdee0

Please sign in to comment.