Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolai256 authored Oct 17, 2022
1 parent 3ab0fb6 commit de72a3c
Show file tree
Hide file tree
Showing 10 changed files with 1,139 additions and 279 deletions.
Binary file added 1 install guide.docx
Binary file not shown.
116 changes: 116 additions & 0 deletions deflicker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
"""
deflicker.py
------------
Remove flicker from a series of images.
This scripts reads images from a specified directory to determine an RGB
"timeseries", smooths the RGB timeseries with a square filter of specified
width, and either outputs plots of the smoothed and unsmoothed RGB timeseries
or adjusts the RGB values of each image such that their RGB values match the
smoothed values.
To use this script, run ``python deflicker.py <directory> <width>
[options]``. ``<directory>`` should specify a path to a folder than contains
the image files that are to the deflickered. The image names must contain
numbers somewhere, and the images will included in the timeseries in ascending
numerical order. <width> specified the width (in images) of the square filter
used the smooth the image values. Other options include
``--plot <file>``:
do not output images with adjusted means; instead, print a plot
of the RGB timeseries before and after smoothing to a PNG image in
``<file>``. If ``<file>`` already exists, it may be overwritten.
``--outdir <output>``:
output images with adjusted means in the directory specified by
``<output>``. If the directory is the same as ``<directory>``, the
smoothing is done in-place and the input files are overwritten.
.. moduleauthor Tristan Abbott
"""

from libdeflicker import meanRGB, squareFilter, relaxToMean, toIntColor
import os
import re
import sys
from PIL import Image
from matplotlib import pyplot as plt
import numpy as np

if __name__ == "__main__":

# Process input arguments
if len(sys.argv) < 3:
print ('Usage: python deflicker.py <directory> <width> [..]')
exit(0)
loc = sys.argv[1]
w = int(sys.argv[2])
__plot = False
__outdir = False

for ii in range(3, len(sys.argv)):
a = sys.argv[ii]
if a == '--plot':
__plot = True
__file = sys.argv[ii+1]
elif a == '--outdir':
__outdir = True
__output = sys.argv[ii+1]

# Just stop if not told to do anything
if not (__plot or __outdir):
print ('Exiting without doing anything')
exit(0)

# Get list of image names in order
loc = sys.argv[1]
f = os.listdir(loc)
n = []
ii = 0
while ii < len(f):
match = re.search('\d+', f[ii])
if match is not None:
n.append(int(match.group(0)))
ii += 1
else:
f.pop(ii)
n = np.array(n)
i = np.argsort(n)
f = [f[ii] for ii in i]

# Load images and calculate smoothed RGB curves
print ('Calculating smoothed sequence')
n = len(f)
rgb = np.zeros((n, 3))
ii = 0
for ff in f:
img = np.asarray(Image.open('%s/%s' % (loc, ff))) / 255.
rgb[ii,:] = meanRGB(img)
ii += 1

# Filter series
rgbi = np.zeros(rgb.shape)
for ii in range(0,3):
rgbi[:,ii] = squareFilter(rgb[:,ii], w)

# Print initial and filtered series
if __plot:
print ('Plotting smoothed and unsmoothed sequences in %s') % __file
plt.subplot(1, 2, 1)
plt.plot(rgb[:,0], 'r', rgb[:,1], 'g', rgb[:,2], 'b')
plt.title('Unfiltered RGB sequence')
plt.subplot(1, 2, 2)
plt.plot(rgbi[:,0], 'r', rgbi[:,1], 'g', rgbi[:,2], 'b')
plt.title('Filtered RGB sequence (w = %d)' % w)
plt.savefig(__file)

# Process images sequentially
if __outdir:
print ('Processing images')
ii = 0
for ff in f:
img = np.asarray(Image.open('%s/%s' % (loc, ff))) / 255.
relaxToMean(img, rgbi[ii,:])
jpg = Image.fromarray(toIntColor(img))
jpg.save('%s/%s' % (__output, ff))
ii += 1

print ('Finished')
26 changes: 17 additions & 9 deletions generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,25 @@
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint", help="checkpoint location", required=True)
parser.add_argument("--data_root", help="data root", required=True)
parser.add_argument("--dir_input", help="dir input", required=True)
#parser.add_argument("--data_root", help="data root", required=False)
#parser.add_argument("--dir_input", help="dir input", required=False)
parser.add_argument("--dir_x1", help="dir extra 1", required=False)
parser.add_argument("--dir_x2", help="dir extra 2", required=False)
parser.add_argument("--dir_x3", help="dir extra 3", required=False)
parser.add_argument("--outdir", help="output directory", required=True)
parser.add_argument("--device", help="device", required=True)
parser.add_argument("--channels", help="if you didn't use tools_all.py u can just use --channels 1, if you did use it use --channels 2", required=True)
parser.add_argument('--projectname', type=str, help='name of the project_', required=True)
args = parser.parse_args()

generator = (torch.load(args.checkpoint, map_location=lambda storage, loc: storage))

data_path = os.path.expanduser('~\Documents\\visionsofchaos\\fewshot\\data')
data_root = data_path + "\\" + args.projectname+"_gen"
dir_input = "input_filtered"
checkpoint = data_path + "\\" + "\\"+ args.projectname+"_train"+"\\"+"logs_reference_P"+"\\"+args.checkpoint

generator = (torch.load(checkpoint, map_location=lambda storage, loc: storage))
generator.eval()


if not os.path.exists(args.outdir):
os.mkdir(args.outdir)
Expand All @@ -35,10 +43,10 @@
if device.lower() != "cpu":
generator = generator.type(torch.half)
transform = build_transform()
dataset = DatasetFullImages(args.data_root + "/" + args.dir_input, "ignore", "ignore", device,
dir_x1=args.data_root + "/" + args.dir_x1 if args.dir_x1 is not None else None,
dir_x2=args.data_root + "/" + args.dir_x2 if args.dir_x2 is not None else None,
dir_x3=args.data_root + "/" + args.dir_x3 if args.dir_x3 is not None else None,
dataset = DatasetFullImages(data_root + "/" + dir_input, "ignore", "ignore", device,
dir_x1=data_root + "/" + args.dir_x1 if args.dir_x1 is not None else None,
dir_x2=data_root + "/" + args.dir_x2 if args.dir_x2 is not None else None,
dir_x3=data_root + "/" + args.dir_x3 if args.dir_x3 is not None else None,
dir_x4=None, dir_x5=None, dir_x6=None, dir_x7=None, dir_x8=None, dir_x9=None)

imloader = torch.utils.data.DataLoader(dataset, 1, shuffle=False, num_workers=1, drop_last=False) # num_workers=4
Expand All @@ -56,7 +64,7 @@
#image_space_in = to_image_space(batch['image'].cpu().data.numpy())

#image_space = to_image_space(net_out.cpu().data.numpy())
image_space = ((net_out.clamp(-1, 1) + 1) * 127.5).permute((0, 2, 3, 1))
image_space = ((net_out.clamp(-1, 1) + 1) * 127.5).permute((0, int(args.channels), 3, 1))
image_space = image_space.cpu().data.numpy().astype(np.uint8)

for k in range(0, len(image_space)):
Expand Down
155 changes: 155 additions & 0 deletions libdeflicker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
"""
libdeflicker.py
---------------
Library routines for image deflickering.
.. moduleauthor Tristan Abbott
"""

import numpy as np
from scipy import signal

def squareFilter(sig, w):
"""
squareFilter(sig, w)
--------------------
Smooth a signal with a square filter.
This function is just a wrapper for scipy.signal.convolve with a kernel
given by ``np.ones(w)/w``.
Parameters:
sig: np.array
Unsmoothed signal
w: int
Width of the filter
Returns:
np.array
Smoothed signal
"""
# Create filter
win = np.ones(w)
# Pad input
sigp = np.concatenate(([np.tile(sig[0], w//2), sig,
np.tile(sig[-1], w//2)]))
# Filter
return signal.convolve(sigp, win, mode =
'same')[w//2:-w//2+1] / np.sum(win)

# Compute image-mean RGB values
def meanRGB(img, ii = -1):
"""
meanRGB(img, ii = -1)
---------------------
Compute image-mean RGB values.
This function takes an np.array representation of an image (x and y in the
first two dimensions and RGB values along the third dimension) and computes
the image-average R,G, and B values.
Parameters:
img: np.array
Array image representation. The first two dimensions should
represent pixel positions, and each position in the third dimension
can represent a particular pixel attribute, e.g. an R, G, or B
value; an H, S, or V value, etc.
ii: int, optional
Specify a slice of the third dimension to average over. If a
particular slice is specified, the function returns a scalar;
otherwise, it returns an average over each slice in the third
dimension of the input image. ``ii`` must be between ``0`` and
``img.shape[2]``, inclusive.
Returns:
np.array
Average over the specified slice, if ``ii`` is given, or a 1D array
of average over the first two dimensions for each slice in the
third dimension.
"""
if ii < 0:
return np.array([np.mean(img[:,:,i]) for i in range(0,img.shape[2])])
else:
return np.mean(img[:,:,ii])

# Adjust pixel-by-pixel RGB values to converge to correct mean
# by multiplying them by a uniform value.
def relaxToMean(img, rgb):
"""
relaxToMean(img, rgb)
---------------------
Uniformly adjust pixel-by-pixel attributes so their mean becomes a
specified value.
The adjustment is done by multiplying pixel attributes by a scaling factor
that is unique to the attribute but uniform over all the pixels in the
image. This function assumes that each
attribute is described by a floating point number between 0 and 1,
inclusive, and it will stop individual pixels from moving outside this range
while others are being scaled.
Parameters:
img: np.array
Array image representation. The first two dimensions should
represent pixel positions, and each position in the third dimension
can represent a particular pixel attribute, e.g. an R, G, or B
value; an H, S, or V value, etc.
rgb: np.array
Desired image-mean values for each attribute included in ``img``.
The linear indices of the values in this array should map in order
to the attributes in the third dimension of ``img``.
Returns:
np.array
``img`` with each attribute multiplied by a factor (unique to the
attribute but the same for that attribute in every pixel in the
image) such that the image mean of that attribute is as specified
in ``rgb``.
"""
rgbi = meanRGB(img)
fac = np.array([2. if i else 0.5 for i in rgbi < rgb])

# Relax toward mean
for ii in range(0,3):

# Repeat until converged to mean
while not np.isclose(rgbi[ii], rgb[ii]):

# Compute ratio
r = rgb[ii] / rgbi[ii]
# Relax image
img[:,:,ii] = np.minimum(1., img[:,:,ii] * r)
# Update average
rgbi[ii] = meanRGB(img, ii)

# Convert floating point colors to integer colors
def toIntColor(img, t = np.uint8):
"""
toIntColor(img, t = np.uint8)
-----------------------------
Convert floating-point attributes to other types.
This function takes an image with floating-point [0,1] representations of
attributes and returns an near-equivalent image with attributes represented
by a different type. It does so by scaling the floating point attributes by
the maximum value representable by the new type and then converting the
scaled floating point value to the new type (with rounding, if required).
Parameters:
img: np.array
Array image representation. The first two dimensions should
represent pixel positions, and each position in the third dimension
can represent a particular pixel attribute, e.g. an R, G, or B
value; an H, S, or V value, etc. The attributes must be represented
as [0,1] floating point values.
t: type, optional
Type used to represent attributes in the new image. By default, the
type is an unsigned 8 bit integer (``np.uint8``).
Returns:
np.array(dtype = t)
Representation of the attributes of ``img`` using the type
specified by ``t``.
"""
scale = np.iinfo(t).max
return np.round(img * scale).astype(t)
36 changes: 36 additions & 0 deletions logger1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import tensorflow as tf
import os
import shutil


class Logger(object):
def __init__(self, log_dir, suffix=None):
"""Create a summary writer logging to log_dir."""
writer = tf.summary.create_file_writer(log_dir, filename_suffix=suffix)
with writer.as_default():
for step in range(100):
# other model code would go here
tf.summary.scalar("my_metric", 0.5, step=step)
writer.flush()

def scalar_summary(self, tag, value, step):
"""Log a scalar variable."""
summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])
self.writer.add_summary(summary, step)


class ModelLogger(object):
def __init__(self, log_dir, save_func):
self.log_dir = log_dir
self.save_func = save_func

def save(self, model, epoch, isGenerator):
if isGenerator:
new_path = os.path.join(self.log_dir, "model_%05d.pth" % epoch)
else:
new_path = os.path.join(self.log_dir, "disc_%05d.pth" % epoch)
self.save_func(model, new_path)

def copy_file(self, source):
shutil.copy(source, self.log_dir)

Loading

0 comments on commit de72a3c

Please sign in to comment.