diff --git a/1 install guide.docx b/1 install guide.docx
new file mode 100644
index 0000000..fbc8052
Binary files /dev/null and b/1 install guide.docx differ
diff --git a/deflicker.py b/deflicker.py
new file mode 100644
index 0000000..705d280
--- /dev/null
+++ b/deflicker.py
@@ -0,0 +1,116 @@
+"""
+deflicker.py
+------------
+Remove flicker from a series of images.
+
+This script reads images from a specified directory to determine an RGB
+"timeseries", smooths the RGB timeseries with a square filter of specified
+width, and either outputs plots of the smoothed and unsmoothed RGB timeseries
+or adjusts the RGB values of each image so that they match the smoothed
+values.
+
+To use this script, run ``python deflicker.py <dir> <w> [options]``. ``<dir>``
+should specify a path to a folder that contains the image files that are to be
+deflickered. The image names must contain numbers somewhere, and the images
+will be included in the timeseries in ascending numerical order. ``<w>``
+specifies the width (in images) of the square filter used to smooth the image
+values. Other options include
+    ``--plot <file>``:
+        do not output images with adjusted means; instead, print a plot
+        of the RGB timeseries before and after smoothing to a PNG image in
+        ``<file>``. If ``<file>`` already exists, it may be overwritten.
+    ``--outdir <dir>``:
+        output images with adjusted means in the directory specified by
+        ``<dir>``. If the directory is the same as the input directory, the
+        smoothing is done in-place and the input files are overwritten.
+
+.. moduleauthor:: Tristan Abbott
+"""
+
+from libdeflicker import meanRGB, squareFilter, relaxToMean, toIntColor
+import os
+import re
+import sys
+from PIL import Image
+from matplotlib import pyplot as plt
+import numpy as np
+
+if __name__ == "__main__":
+
+    # Process input arguments
+    if len(sys.argv) < 3:
+        print('Usage: python deflicker.py <dir> <w> [options]')
+        exit(0)
+    loc = sys.argv[1]
+    w = int(sys.argv[2])
+    __plot = False
+    __outdir = False
+
+    for ii in range(3, len(sys.argv)):
+        a = sys.argv[ii]
+        if a == '--plot':
+            __plot = True
+            __file = sys.argv[ii+1]
+        elif a == '--outdir':
+            __outdir = True
+            __output = sys.argv[ii+1]
+
+    # Just stop if not told to do anything
+    if not (__plot or __outdir):
+        print('Exiting without doing anything')
+        exit(0)
+
+    # Get list of image names in order
+    f = os.listdir(loc)
+    n = []
+    ii = 0
+    while ii < len(f):
+        match = re.search(r'\d+', f[ii])
+        if match is not None:
+            n.append(int(match.group(0)))
+            ii += 1
+        else:
+            f.pop(ii)
+    n = np.array(n)
+    i = np.argsort(n)
+    f = [f[ii] for ii in i]
+
+    # Load images and calculate smoothed RGB curves
+    print('Calculating smoothed sequence')
+    n = len(f)
+    rgb = np.zeros((n, 3))
+    ii = 0
+    for ff in f:
+        img = np.asarray(Image.open('%s/%s' % (loc, ff))) / 255.
+        rgb[ii,:] = meanRGB(img)
+        ii += 1
+
+    # Filter series
+    rgbi = np.zeros(rgb.shape)
+    for ii in range(0,3):
+        rgbi[:,ii] = squareFilter(rgb[:,ii], w)
+
+    # Print initial and filtered series
+    if __plot:
+        print('Plotting smoothed and unsmoothed sequences in %s' % __file)
+        plt.subplot(1, 2, 1)
+        plt.plot(rgb[:,0], 'r', rgb[:,1], 'g', rgb[:,2], 'b')
+        plt.title('Unfiltered RGB sequence')
+        plt.subplot(1, 2, 2)
+        plt.plot(rgbi[:,0], 'r', rgbi[:,1], 'g', rgbi[:,2], 'b')
+        plt.title('Filtered RGB sequence (w = %d)' % w)
+        plt.savefig(__file)
+
+    # Process images sequentially
+    if __outdir:
+        print('Processing images')
+        ii = 0
+        for ff in f:
+            img = np.asarray(Image.open('%s/%s' % (loc, ff))) / 255.
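+            # relaxToMean (libdeflicker.py) rescales each channel of img in
+            # place until its image-mean matches the smoothed target rgbi[ii,:]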
+            relaxToMean(img, rgbi[ii,:])
+            jpg = Image.fromarray(toIntColor(img))
+            jpg.save('%s/%s' % (__output, ff))
+            ii += 1
+
+    print('Finished')
diff --git a/generate.py b/generate.py
index 6762856..ce8e1e4 100644
--- a/generate.py
+++ b/generate.py
@@ -13,17 +13,25 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--checkpoint", help="checkpoint location", required=True)
-    parser.add_argument("--data_root", help="data root", required=True)
-    parser.add_argument("--dir_input", help="dir input", required=True)
+    #parser.add_argument("--data_root", help="data root", required=False)
+    #parser.add_argument("--dir_input", help="dir input", required=False)
     parser.add_argument("--dir_x1", help="dir extra 1", required=False)
     parser.add_argument("--dir_x2", help="dir extra 2", required=False)
     parser.add_argument("--dir_x3", help="dir extra 3", required=False)
     parser.add_argument("--outdir", help="output directory", required=True)
     parser.add_argument("--device", help="device", required=True)
+    parser.add_argument("--channels", help="use --channels 2 if the data was prepared with tools_all.py, otherwise use --channels 1", required=True)
+    parser.add_argument('--projectname', type=str, help='name of the project', required=True)
     args = parser.parse_args()
-
-    generator = (torch.load(args.checkpoint, map_location=lambda storage, loc: storage))
+
+    data_path = os.path.expanduser('~\\Documents\\visionsofchaos\\fewshot\\data')
+    data_root = data_path + "\\" + args.projectname + "_gen"
+    dir_input = "input_filtered"
+    checkpoint = data_path + "\\" + args.projectname + "_train" + "\\" + "logs_reference_P" + "\\" + args.checkpoint
+
+    generator = (torch.load(checkpoint, map_location=lambda storage, loc: storage))
     generator.eval()
+
     if not os.path.exists(args.outdir):
         os.mkdir(args.outdir)
@@ -35,10 +43,10 @@ if device.lower() != "cpu":
         generator = generator.type(torch.half)
     transform = build_transform()
-    dataset = DatasetFullImages(args.data_root + "/" + args.dir_input, "ignore", "ignore", device,
-                                dir_x1=args.data_root + "/" + args.dir_x1 if args.dir_x1 is not None else None,
-                                dir_x2=args.data_root + "/" + args.dir_x2 if args.dir_x2 is not None else None,
-                                dir_x3=args.data_root + "/" + args.dir_x3 if args.dir_x3 is not None else None,
+    dataset = DatasetFullImages(data_root + "/" + dir_input, "ignore", "ignore", device,
+                                dir_x1=data_root + "/" + args.dir_x1 if args.dir_x1 is not None else None,
+                                dir_x2=data_root + "/" + args.dir_x2 if args.dir_x2 is not None else None,
+                                dir_x3=data_root + "/" + args.dir_x3 if args.dir_x3 is not None else None,
                                 dir_x4=None, dir_x5=None, dir_x6=None, dir_x7=None, dir_x8=None, dir_x9=None)
     imloader = torch.utils.data.DataLoader(dataset, 1, shuffle=False, num_workers=1, drop_last=False)  # num_workers=4
@@ -56,7 +64,7 @@
             #image_space_in = to_image_space(batch['image'].cpu().data.numpy())
             #image_space = to_image_space(net_out.cpu().data.numpy())
-            image_space = ((net_out.clamp(-1, 1) + 1) * 127.5).permute((0, 2, 3, 1))
+            image_space = ((net_out.clamp(-1, 1) + 1) * 127.5).permute((0, int(args.channels), 3, 1))
             image_space = image_space.cpu().data.numpy().astype(np.uint8)
 
     for k in range(0, len(image_space)):
diff --git a/libdeflicker.py b/libdeflicker.py
new file mode 100644
index 0000000..77563d5
--- /dev/null
+++ b/libdeflicker.py
@@ -0,0 +1,155 @@
+"""
+libdeflicker.py
+---------------
+Library routines for image deflickering.
+
+.. moduleauthor:: Tristan Abbott
+"""
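+# Typical use of these helpers (see deflicker.py for the full driver):
+#     rgb = meanRGB(img)                  # per-image channel means
+#     smooth = squareFilter(series, w)    # boxcar-smoothed 1D series
+#     relaxToMean(img, targets)           # in-place per-channel adjustment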
+
+import numpy as np
+from scipy import signal
+
+def squareFilter(sig, w):
+    """
+    squareFilter(sig, w)
+    --------------------
+    Smooth a signal with a square filter.
+
+    This function is just a wrapper for scipy.signal.convolve with a kernel
+    given by ``np.ones(w)/w``.
+
+    Parameters:
+        sig: np.array
+            Unsmoothed signal
+        w: int
+            Width of the filter
+
+    Returns:
+        np.array
+            Smoothed signal
+    """
+    # Create filter
+    win = np.ones(w)
+    # Pad input with edge values so the filtered series keeps its length
+    sigp = np.concatenate((np.tile(sig[0], w//2), sig, np.tile(sig[-1], w//2)))
+    # Filter and normalize
+    return signal.convolve(sigp, win, mode='same')[w//2:-w//2+1] / np.sum(win)
+
+# Compute image-mean RGB values
+def meanRGB(img, ii = -1):
+    """
+    meanRGB(img, ii = -1)
+    ---------------------
+    Compute image-mean RGB values.
+
+    This function takes an np.array representation of an image (x and y in the
+    first two dimensions and RGB values along the third dimension) and computes
+    the image-average R, G, and B values.
+
+    Parameters:
+        img: np.array
+            Array image representation. The first two dimensions should
+            represent pixel positions, and each position in the third dimension
+            can represent a particular pixel attribute, e.g. an R, G, or B
+            value; an H, S, or V value, etc.
+        ii: int, optional
+            Specify a slice of the third dimension to average over. If a
+            particular slice is specified, the function returns a scalar;
+            otherwise, it returns an average over each slice in the third
+            dimension of the input image. ``ii`` must be at least ``0`` and
+            less than ``img.shape[2]``.
+
+    Returns:
+        np.array
+            Average over the specified slice, if ``ii`` is given, or a 1D array
+            of averages over the first two dimensions for each slice in the
+            third dimension.
+    """
+    if ii < 0:
+        return np.array([np.mean(img[:,:,i]) for i in range(0,img.shape[2])])
+    else:
+        return np.mean(img[:,:,ii])
+
+# Adjust pixel-by-pixel RGB values to converge to correct mean
+# by multiplying them by a uniform value.
+def relaxToMean(img, rgb):
+    """
+    relaxToMean(img, rgb)
+    ---------------------
+    Uniformly adjust pixel-by-pixel attributes so their mean becomes a
+    specified value.
+
+    The adjustment is done by multiplying pixel attributes by a scaling factor
+    that is unique to the attribute but uniform over all the pixels in the
+    image. This function assumes that each attribute is described by a
+    floating point number between 0 and 1, inclusive, and it will stop
+    individual pixels from moving outside this range while others are being
+    scaled.
+
+    Parameters:
+        img: np.array
+            Array image representation. The first two dimensions should
+            represent pixel positions, and each position in the third dimension
+            can represent a particular pixel attribute, e.g. an R, G, or B
+            value; an H, S, or V value, etc.
+        rgb: np.array
+            Desired image-mean values for each attribute included in ``img``.
+            The linear indices of the values in this array should map in order
+            to the attributes in the third dimension of ``img``.
+
+    Returns:
+        None. ``img`` is modified in place: each attribute is scaled (by a
+        factor unique to the attribute but uniform across pixels) until the
+        image mean of that attribute matches the value specified in ``rgb``.
+    """
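+    # Multiplicative relaxation: scale channel ii by r = rgb[ii] / rgbi[ii],
+    # capping pixels at 1.0; because capped pixels stop responding, the mean
+    # is re-measured and the scaling repeated until it converges on rgb[ii].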
+    rgbi = meanRGB(img)
+
+    # Relax toward mean
+    for ii in range(0,3):
+
+        # Repeat until converged to mean
+        while not np.isclose(rgbi[ii], rgb[ii]):
+
+            # Compute ratio
+            r = rgb[ii] / rgbi[ii]
+            # Relax image
+            img[:,:,ii] = np.minimum(1., img[:,:,ii] * r)
+            # Update average
+            rgbi[ii] = meanRGB(img, ii)
+
+# Convert floating point colors to integer colors
+def toIntColor(img, t = np.uint8):
+    """
+    toIntColor(img, t = np.uint8)
+    -----------------------------
+    Convert floating-point attributes to other types.
+
+    This function takes an image with floating-point [0,1] representations of
+    attributes and returns a near-equivalent image with attributes represented
+    by a different type. It does so by scaling the floating point attributes by
+    the maximum value representable by the new type and then converting the
+    scaled floating point value to the new type (with rounding, if required).
+
+    Parameters:
+        img: np.array
+            Array image representation. The first two dimensions should
+            represent pixel positions, and each position in the third dimension
+            can represent a particular pixel attribute, e.g. an R, G, or B
+            value; an H, S, or V value, etc. The attributes must be represented
+            as [0,1] floating point values.
+        t: type, optional
+            Type used to represent attributes in the new image. By default, the
+            type is an unsigned 8 bit integer (``np.uint8``).
+    Returns:
+        np.array(dtype = t)
+            Representation of the attributes of ``img`` using the type
+            specified by ``t``.
+    """
+    scale = np.iinfo(t).max
+    return np.round(img * scale).astype(t)
diff --git a/logger1.py b/logger1.py
new file mode 100644
index 0000000..29833b4
--- /dev/null
+++ b/logger1.py
@@ -0,0 +1,36 @@
+import tensorflow as tf
+import os
+import shutil
+
+
+class Logger(object):
+    def __init__(self, log_dir, suffix=None):
+        """Create a TF2 summary writer logging to log_dir."""
+        self.writer = tf.summary.create_file_writer(log_dir, filename_suffix=suffix)
+
+    def scalar_summary(self, tag, value, step):
+        """Log a scalar variable."""
+        with self.writer.as_default():
+            tf.summary.scalar(tag, value, step=step)
+        self.writer.flush()
+
+
+class ModelLogger(object):
+    def __init__(self, log_dir, save_func):
+        self.log_dir = log_dir
+        self.save_func = save_func
+
+    def save(self, model, epoch, isGenerator):
+        if isGenerator:
+            new_path = os.path.join(self.log_dir, "model_%05d.pth" % epoch)
+        else:
+            new_path = os.path.join(self.log_dir, "disc_%05d.pth" % epoch)
+        self.save_func(model, new_path)
+
+    def copy_file(self, source):
+        shutil.copy(source, self.log_dir)
+
diff --git a/train.py b/train.py
index dc62b7e..ac7176d 100644
--- a/train.py
+++ b/train.py
@@ -37,12 +37,18 @@ def worker_init_fn(worker_id):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument('--config', '-c', help='Yaml config with training parameters', required=True)
-    parser.add_argument('--log_folder', '-l', help='Log folder', required=True)
-    parser.add_argument('--data_root', '-r', help='Data root folder', required=True)
+    parser.add_argument('--log_folder', '-l', help='Log folder', default="logs_reference_P")
+    parser.add_argument('--data_root', '-r', help='Data root folder', required=False)
     parser.add_argument('--log_interval', '-i', type=int, help='Log interval', required=True)
+    parser.add_argument('--resume', '-rs', type=str, help='resume from a saved checkpoint', required=False)
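+    # --resume takes the checkpoint file name saved under logs_reference_P,
+    # without the .pth extension (checkpoints are written as model_%05d.pth)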
+    parser.add_argument('--projectname', type=str, help='name of the project', required=True)
     args = parser.parse_args()
-    args_log_folder = args.data_root + "/" + args.log_folder
+
+    data_path = os.path.expanduser('~\\Documents\\visionsofchaos\\fewshot\\data')
+    data_root = data_path + "\\" + args.projectname + "_train"
+    args_log_folder = data_root + "/" + args.log_folder
 
     with open(args.config, 'r') as f:
         job_description = yaml.load(f, Loader=yaml.FullLoader)
@@ -60,24 +66,25 @@ def worker_init_fn(worker_id):
         raise RuntimeError("Got unexpected parameter in training_dataset: " + str(training_dataset_parameters))
 
     d = dict(config['training_dataset'])
-    d['dir_pre'] = args.data_root + "/" + d['dir_pre']
-    d['dir_post'] = args.data_root + "/" + d['dir_post']
+    d['dir_pre'] = data_root + "/" + d['dir_pre']
+    d['dir_post'] = data_root + "/" + d['dir_post']
     d['device'] = config['device']
     if 'dir_mask' in d:
-        d['dir_mask'] = args.data_root + "/" + d['dir_mask']
+        d['dir_mask'] = data_root + "/" + d['dir_mask']
 
     # complete dir_x paths and set a correct number of channels
     channels = 3
     for dir_x_index in range(1, 10):
         dir_x_name = f"dir_x{dir_x_index}"
-        d[dir_x_name] = args.data_root + "/" + d[dir_x_name] if dir_x_name in d else None
+        d[dir_x_name] = data_root + "/" + d[dir_x_name] if dir_x_name in d else None
         channels = channels + 3 if d[dir_x_name] is not None else channels
     config['generator']['args']['input_channels'] = channels
 
     print(d)
-
+    resumedata = str(args.resume)
     generator = build_model(config['generator']['type'], config['generator']['args'], device)
-    #generator = (torch.load(args.data_root + "/model_00300_style2.pth", map_location=lambda storage, loc: storage)).to(device)
+    if args.resume:
+        generator = (torch.load(data_root + "/logs_reference_P/" + resumedata + ".pth", map_location=lambda storage, loc: storage)).to(device)
     opt_generator = build_optimizer(config['opt_generator']['type'], generator, config['opt_generator']['args'])
 
     discriminator, opt_discriminator = None, None
@@ -127,6 +134,6 @@ def worker_init_fn(worker_id):
     args_config = args.config.replace('\\', '/')
     args_config = args_config[args_config.rfind('/') + 1:]
-    trainer.train(generator, discriminator, int(config['trainer']['epochs']), args.data_root, args_config, 0)
+    trainer.train(generator, discriminator, int(config['trainer']['epochs']), data_root, args_config, 0)
     print("Training finished", flush=True)
     sys.exit(0)
diff --git a/train1.py b/train1.py
new file mode 100644
index 0000000..dc62b7e
--- /dev/null
+++ b/train1.py
@@ -0,0 +1,132 @@
+import argparse
+import os
+import data
+import models as m
+import torch
+import torch.optim as optim
+import yaml
+from logger import Logger, ModelLogger
+from trainers import Trainer
+import sys
+import numpy as np
+
+
+def build_model(model_type, args, device):
+    model = getattr(m, model_type)(**args)
+    return model.to(device)
+
+
+def build_optimizer(opt_type, model, args):
+    args['params'] = model.parameters()
+    opt_class = getattr(optim, opt_type)
+    return opt_class(**args)
+
+
+def build_loggers(log_folder):
+    if not os.path.exists(log_folder):
+        os.makedirs(log_folder)
+    model_logger = ModelLogger(log_folder, torch.save)
+    scalar_logger = Logger(log_folder)
+    return scalar_logger, model_logger
+
+
+def worker_init_fn(worker_id):
+    np.random.seed(np.random.get_state()[1][0] + worker_id)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--config', '-c', help='Yaml config with training parameters', required=True)
+    
parser.add_argument('--log_folder', '-l', help='Log folder', required=True) + parser.add_argument('--data_root', '-r', help='Data root folder', required=True) + parser.add_argument('--log_interval', '-i', type=int, help='Log interval', required=True) + args = parser.parse_args() + + args_log_folder = args.data_root + "/" + args.log_folder + + with open(args.config, 'r') as f: + job_description = yaml.load(f, Loader=yaml.FullLoader) + + config = job_description['job'] + scalar_logger, model_logger = build_loggers(args_log_folder) + + model_logger.copy_file(args.config) + device = config.get('device') or 'cpu' + + # Check 'training_dataset' parameters + training_dataset_parameters = set(config['training_dataset'].keys()) - \ + {"type", "dir_pre", "dir_post", "dir_mask", "patch_size", "dir_x1", "dir_x2", "dir_x3", "dir_x4", "dir_x5", "dir_x6", "dir_x7", "dir_x8", "dir_x9", } + if len(training_dataset_parameters) > 0: + raise RuntimeError("Got unexpected parameter in training_dataset: " + str(training_dataset_parameters)) + + d = dict(config['training_dataset']) + d['dir_pre'] = args.data_root + "/" + d['dir_pre'] + d['dir_post'] = args.data_root + "/" + d['dir_post'] + d['device'] = config['device'] + if 'dir_mask' in d: + d['dir_mask'] = args.data_root + "/" + d['dir_mask'] + + # complete dir_x paths and set a correct number of channels + channels = 3 + for dir_x_index in range(1, 10): + dir_x_name = f"dir_x{dir_x_index}" + d[dir_x_name] = args.data_root + "/" + d[dir_x_name] if dir_x_name in d else None + channels = channels + 3 if d[dir_x_name] is not None else channels + config['generator']['args']['input_channels'] = channels + + print(d) + + generator = build_model(config['generator']['type'], config['generator']['args'], device) + #generator = (torch.load(args.data_root + "/model_00300_style2.pth", map_location=lambda storage, loc: storage)).to(device) + opt_generator = build_optimizer(config['opt_generator']['type'], generator, config['opt_generator']['args']) + + discriminator, opt_discriminator = None, None + if 'discriminator' in config: + discriminator = build_model(config['discriminator']['type'], config['discriminator']['args'], device) + #discriminator = (torch.load(args.data_root + "/disc_00300_style2.pth", map_location=lambda storage, loc: storage)).to(device) + opt_discriminator = build_optimizer(config['opt_discriminator']['type'], discriminator, config['opt_discriminator']['args']) + + if 'type' not in d: + raise RuntimeError("Type of training_dataset must be specified!") + + dataset_type = getattr(data, d.pop('type')) + training_dataset = dataset_type(**d) + + train_loader = torch.utils.data.DataLoader(training_dataset, config['trainer']['batch_size'], shuffle=False, + num_workers=config['num_workers'], drop_last=True)#, worker_init_fn=worker_init_fn) + + reconstruction_criterion = getattr(torch.nn, config['trainer']['reconstruction_criterion'])() + adversarial_criterion = getattr(torch.nn, config['trainer']['adversarial_criterion'])() + + perception_loss_model = None + perception_loss_weight = 1 + if 'perception_loss' in config: + if 'perception_model' in config['perception_loss']: + perception_loss_model = build_model(config['perception_loss']['perception_model']['type'], + config['perception_loss']['perception_model']['args'], + device) + else: + perception_loss_model = discriminator + + perception_loss_weight = config['perception_loss']['weight'] + + trainer = Trainer( + train_loader=train_loader, + data_for_dataloader=d, # data for later dataloader creation, if 
needed + opt_generator=opt_generator, opt_discriminator=opt_discriminator, + adversarial_criterion=adversarial_criterion, reconstruction_criterion=reconstruction_criterion, + reconstruction_weight=config['trainer']['reconstruction_weight'], + adversarial_weight=config['trainer']['adversarial_weight'], + log_interval=args.log_interval, + model_logger=model_logger, scalar_logger=scalar_logger, + perception_loss_model=perception_loss_model, + perception_loss_weight=perception_loss_weight, + use_image_loss=config['trainer']['use_image_loss'], + device=device + ) + + args_config = args.config.replace('\\', '/') + args_config = args_config[args_config.rfind('/') + 1:] + trainer.train(generator, discriminator, int(config['trainer']['epochs']), args.data_root, args_config, 0) + print("Training finished", flush=True) + sys.exit(0) diff --git a/train2.py b/train2.py new file mode 100644 index 0000000..dc62b7e --- /dev/null +++ b/train2.py @@ -0,0 +1,132 @@ +import argparse +import os +import data +import models as m +import torch +import torch.optim as optim +import yaml +from logger import Logger, ModelLogger +from trainers import Trainer +import sys +import numpy as np + + +def build_model(model_type, args, device): + model = getattr(m, model_type)(**args) + return model.to(device) + + +def build_optimizer(opt_type, model, args): + args['params'] = model.parameters() + opt_class = getattr(optim, opt_type) + return opt_class(**args) + + +def build_loggers(log_folder): + if not os.path.exists(log_folder): + os.makedirs(log_folder) + model_logger = ModelLogger(log_folder, torch.save) + scalar_logger = Logger(log_folder) + return scalar_logger, model_logger + + +def worker_init_fn(worker_id): + np.random.seed(np.random.get_state()[1][0] + worker_id) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--config', '-c', help='Yaml config with training parameters', required=True) + parser.add_argument('--log_folder', '-l', help='Log folder', required=True) + parser.add_argument('--data_root', '-r', help='Data root folder', required=True) + parser.add_argument('--log_interval', '-i', type=int, help='Log interval', required=True) + args = parser.parse_args() + + args_log_folder = args.data_root + "/" + args.log_folder + + with open(args.config, 'r') as f: + job_description = yaml.load(f, Loader=yaml.FullLoader) + + config = job_description['job'] + scalar_logger, model_logger = build_loggers(args_log_folder) + + model_logger.copy_file(args.config) + device = config.get('device') or 'cpu' + + # Check 'training_dataset' parameters + training_dataset_parameters = set(config['training_dataset'].keys()) - \ + {"type", "dir_pre", "dir_post", "dir_mask", "patch_size", "dir_x1", "dir_x2", "dir_x3", "dir_x4", "dir_x5", "dir_x6", "dir_x7", "dir_x8", "dir_x9", } + if len(training_dataset_parameters) > 0: + raise RuntimeError("Got unexpected parameter in training_dataset: " + str(training_dataset_parameters)) + + d = dict(config['training_dataset']) + d['dir_pre'] = args.data_root + "/" + d['dir_pre'] + d['dir_post'] = args.data_root + "/" + d['dir_post'] + d['device'] = config['device'] + if 'dir_mask' in d: + d['dir_mask'] = args.data_root + "/" + d['dir_mask'] + + # complete dir_x paths and set a correct number of channels + channels = 3 + for dir_x_index in range(1, 10): + dir_x_name = f"dir_x{dir_x_index}" + d[dir_x_name] = args.data_root + "/" + d[dir_x_name] if dir_x_name in d else None + channels = channels + 3 if d[dir_x_name] is not None else channels + 
config['generator']['args']['input_channels'] = channels + + print(d) + + generator = build_model(config['generator']['type'], config['generator']['args'], device) + #generator = (torch.load(args.data_root + "/model_00300_style2.pth", map_location=lambda storage, loc: storage)).to(device) + opt_generator = build_optimizer(config['opt_generator']['type'], generator, config['opt_generator']['args']) + + discriminator, opt_discriminator = None, None + if 'discriminator' in config: + discriminator = build_model(config['discriminator']['type'], config['discriminator']['args'], device) + #discriminator = (torch.load(args.data_root + "/disc_00300_style2.pth", map_location=lambda storage, loc: storage)).to(device) + opt_discriminator = build_optimizer(config['opt_discriminator']['type'], discriminator, config['opt_discriminator']['args']) + + if 'type' not in d: + raise RuntimeError("Type of training_dataset must be specified!") + + dataset_type = getattr(data, d.pop('type')) + training_dataset = dataset_type(**d) + + train_loader = torch.utils.data.DataLoader(training_dataset, config['trainer']['batch_size'], shuffle=False, + num_workers=config['num_workers'], drop_last=True)#, worker_init_fn=worker_init_fn) + + reconstruction_criterion = getattr(torch.nn, config['trainer']['reconstruction_criterion'])() + adversarial_criterion = getattr(torch.nn, config['trainer']['adversarial_criterion'])() + + perception_loss_model = None + perception_loss_weight = 1 + if 'perception_loss' in config: + if 'perception_model' in config['perception_loss']: + perception_loss_model = build_model(config['perception_loss']['perception_model']['type'], + config['perception_loss']['perception_model']['args'], + device) + else: + perception_loss_model = discriminator + + perception_loss_weight = config['perception_loss']['weight'] + + trainer = Trainer( + train_loader=train_loader, + data_for_dataloader=d, # data for later dataloader creation, if needed + opt_generator=opt_generator, opt_discriminator=opt_discriminator, + adversarial_criterion=adversarial_criterion, reconstruction_criterion=reconstruction_criterion, + reconstruction_weight=config['trainer']['reconstruction_weight'], + adversarial_weight=config['trainer']['adversarial_weight'], + log_interval=args.log_interval, + model_logger=model_logger, scalar_logger=scalar_logger, + perception_loss_model=perception_loss_model, + perception_loss_weight=perception_loss_weight, + use_image_loss=config['trainer']['use_image_loss'], + device=device + ) + + args_config = args.config.replace('\\', '/') + args_config = args_config[args_config.rfind('/') + 1:] + trainer.train(generator, discriminator, int(config['trainer']['epochs']), args.data_root, args_config, 0) + print("Training finished", flush=True) + sys.exit(0) diff --git a/trainers.py b/trainers.py index 9be770b..319231d 100644 --- a/trainers.py +++ b/trainers.py @@ -1,260 +1,267 @@ -import time -import models -import numpy as np -import six -import torch -import torch.nn as nn -from torch.autograd import Variable -from PIL import Image -from custom_transforms import * -from data import DatasetFullImages -import os - - -class Trainer(object): - def __init__(self, - train_loader, data_for_dataloader, opt_discriminator, opt_generator, - reconstruction_criterion, adversarial_criterion, reconstruction_weight, - adversarial_weight, log_interval, scalar_logger, model_logger, - perception_loss_model, perception_loss_weight, use_image_loss, device - ): - - self.train_loader = train_loader - self.data_for_dataloader = 
data_for_dataloader - - self.opt_discriminator = opt_discriminator - self.opt_generator = opt_generator - - self.reconstruction_criterion = reconstruction_criterion - self.adversarial_criterion = adversarial_criterion - - self.reconstruction_weight = reconstruction_weight - self.adversarial_weight = adversarial_weight - - self.scalar_logger = scalar_logger - self.model_logger = model_logger - - self.training_log = {} - self.log_interval = log_interval - - self.perception_loss_weight = perception_loss_weight - self.perception_loss_model = perception_loss_model - - self.use_adversarial_loss = False - self.use_image_loss = use_image_loss - self.device = device - - self.dataset = None - self.imloader = None - - - def run_discriminator(self, discriminator, images): - return discriminator(images) - - def compute_discriminator_loss(self, generator, discriminator, batch): - generated = generator(batch['pre']) - fake = self.apply_mask(generated, batch, 'pre_mask') - fake_labels, _ = self.run_discriminator(discriminator, fake.detach()) - - true = self.apply_mask(batch['already'], batch, 'already_mask') - true_labels, _ = self.run_discriminator(discriminator, true) - - discriminator_loss = self.adversarial_criterion(fake_labels, self.zeros_like(fake_labels)) + \ - self.adversarial_criterion(true_labels, self.ones_like(true_labels)) - - return discriminator_loss - - def compute_generator_loss(self, generator, discriminator, batch, use_gan, use_mask): - image_loss = 0 - perception_loss = 0 - adversarial_loss = 0 - - generated = generator(batch['pre']) - - if use_mask: - generated = generated * batch['mask'] - batch['post'] = batch['post'] * batch['mask'] - - if self.use_image_loss: - if generated[0][0].shape != batch['post'][0][0].shape: - if ((batch['post'][0][0].shape[0] - generated[0][0].shape[0]) % 2) != 0: - raise RuntimeError("batch['post'][0][0].shape[0] - generated[0][0].shape[0] must be even number") - if generated[0][0].shape[0] != generated[0][0].shape[1] or batch['post'][0][0].shape[0] != batch['post'][0][0].shape[1]: - raise RuntimeError("And also it is expected to be exact square ... 
fix it if you want") - boundary_size = int((batch['post'][0][0].shape[0] - generated[0][0].shape[0]) / 2) - cropped_batch_post = batch['post'][:, :, boundary_size: -1*boundary_size, boundary_size: -1*boundary_size] - image_loss = self.reconstruction_criterion(generated, cropped_batch_post) - else: - image_loss = self.reconstruction_criterion(generated, batch['post']) - - if self.perception_loss_model is not None: - _, fake_features = self.perception_loss_model(generated) - _, target_features = self.perception_loss_model(Variable(batch['post'], requires_grad=False)) - perception_loss = ((fake_features - target_features) ** 2).mean() - - - if self.use_adversarial_loss and use_gan: - fake = self.apply_mask(generated, batch, 'pre_mask') - fake_smiling_labels, _ = self.run_discriminator(discriminator, fake) - adversarial_loss = self.adversarial_criterion(fake_smiling_labels, self.ones_like(fake_smiling_labels)) - - return image_loss, perception_loss, adversarial_loss, generated - - - def train(self, generator, discriminator, epochs, data_root, config_yaml_name, starting_batch_num): - self.use_adversarial_loss = discriminator is not None - batch_num = starting_batch_num - save_num = 0 - - start = time.time() - for epoch in range(epochs): - np.random.seed() - for i, batch in enumerate(self.train_loader): - # just sets the models into training mode (enable BN and DO) - [m.train() for m in [generator, discriminator] if m is not None] - batch = {k: batch[k].to(self.device) if isinstance(batch[k], torch.Tensor) else batch[k] - for k in batch.keys()} - - # train discriminator - if self.use_adversarial_loss: - self.opt_discriminator.zero_grad() - discriminator_loss = self.compute_discriminator_loss(generator, discriminator, batch) - discriminator_loss.backward() - self.opt_discriminator.step() - - # train generator - self.opt_generator.zero_grad() - - g_image_loss, g_perc_loss, g_adv_loss, _ = self.compute_generator_loss(generator, discriminator, batch, use_gan=True, use_mask=False) - - generator_loss = self.reconstruction_weight * g_image_loss + \ - self.perception_loss_weight * g_perc_loss + \ - self.adversarial_weight * g_adv_loss - - generator_loss.backward() - - self.opt_generator.step() - - # log losses - current_log = {key: value.item() for key, value in six.iteritems(locals()) if - 'loss' in key and isinstance(value, Variable)} - - self.add_log(current_log) - - batch_num += 1 - - if batch_num % 100 == 0: - print(f"Batch num: {batch_num}, totally elapsed {(time.time() - start)}", flush=True) - - #if batch_num % self.log_interval == 0 or batch_num == 1: - if batch_num % self.log_interval == 0 or batch_num == 1: # (time.time() - start) > 16: - eval_start = time.time() - generator.eval() - self.test_on_full_image(generator, batch_num, data_root, config_yaml_name) - self.flush_scalar_log(batch_num, time.time() - start) - self.model_logger.save(generator, save_num, True) - #self.model_logger.save(discriminator, save_num, False) - save_num += 1 - print(f"Eval of batch: {batch_num} took {(time.time() - eval_start)}", flush=True) - - #if batch_num > 5000: - # sys.exit(0) - - self.model_logger.save(generator, 99999) - - # Accumulates the losses - def add_log(self, log): - for k, v in log.items(): - if k in self.training_log: - self.training_log[k] += v - else: - self.training_log[k] = v - - # Divide the losses by log_interval and print'em - def flush_scalar_log(self, batch_num, took): - for key in self.training_log.keys(): - self.scalar_logger.scalar_summary(key, self.training_log[key] / 
self.log_interval, batch_num) - - log = "[%d]" % batch_num - for key in sorted(self.training_log.keys()): - log += " [%s] % 7.4f" % (key, self.training_log[key] / self.log_interval) - - log += ". Took {}".format(took) - print(log, flush=True) - self.training_log = {} - - # Test the intermediate model on data from _gen folder - def test_on_full_image(self, generator, batch_num, data_root, config_yaml_name): - config_yaml_name = config_yaml_name.replace("reference", "").replace(".yaml", "") - - data_root = data_root.replace("_train", "_gen") - if self.dataset is None: - self.dataset = DatasetFullImages(data_root + "/" + self.data_for_dataloader['dir_pre'].split("/")[-1], - "ignore", # data_root + "/" + "ebsynth", - "ignore", # data_root + "/" + "mask", - self.device, - dir_x1=data_root + "/" + self.data_for_dataloader['dir_x1'].split("/")[-1] if self.data_for_dataloader['dir_x1'] is not None else None, - dir_x2=data_root + "/" + self.data_for_dataloader['dir_x2'].split("/")[-1] if self.data_for_dataloader['dir_x2'] is not None else None, - dir_x3=data_root + "/" + self.data_for_dataloader['dir_x3'].split("/")[-1] if self.data_for_dataloader['dir_x3'] is not None else None, - dir_x4=data_root + "/" + self.data_for_dataloader['dir_x4'].split("/")[-1] if self.data_for_dataloader['dir_x4'] is not None else None, - dir_x5=data_root + "/" + self.data_for_dataloader['dir_x5'].split("/")[-1] if self.data_for_dataloader['dir_x5'] is not None else None, - dir_x6=data_root + "/" + self.data_for_dataloader['dir_x6'].split("/")[-1] if self.data_for_dataloader['dir_x6'] is not None else None, - dir_x7=data_root + "/" + self.data_for_dataloader['dir_x7'].split("/")[-1] if self.data_for_dataloader['dir_x7'] is not None else None, - dir_x8=data_root + "/" + self.data_for_dataloader['dir_x8'].split("/")[-1] if self.data_for_dataloader['dir_x8'] is not None else None, - dir_x9=data_root + "/" + self.data_for_dataloader['dir_x9'].split("/")[-1] if self.data_for_dataloader['dir_x9'] is not None else None) - self.imloader = torch.utils.data.DataLoader(self.dataset, 1, shuffle=False, num_workers=1, drop_last=False) # num_workers=4 - - with torch.no_grad(): - log = "### \n" - log = log + "[%d]" % batch_num + " " - generator_loss_on_ebsynth = 0 - for i, batch in enumerate(self.imloader): - batch = {k: batch[k].to(self.device) if isinstance(batch[k], torch.Tensor) else batch[k] - for k in batch.keys()} - g_image_loss, g_perc_loss, g_adv_loss, e_cls_loss, e_smiling_loss, gan_output =\ - 0, 0, 0, 0, 0, generator(batch['pre']) - - generator_loss = self.reconstruction_weight * g_image_loss + \ - self.perception_loss_weight * g_perc_loss + \ - self.adversarial_weight * g_adv_loss - - if True or batch['file_name'][0] != "111.png": # do not accumulate loss in train frame - generator_loss_on_ebsynth = generator_loss_on_ebsynth + generator_loss - - if True or batch['file_name'][0] in ["111.png", "101.png", "106.png", "116.png", "121.png"]: - #log = log + batch['file_name'][0] - #log = log + ": %7.4f" % generator_loss + ", " - - image_space = to_image_space(gan_output.cpu().data.numpy()) - - gt_test_ganoutput_path = data_root + "/" + "res_" + config_yaml_name - if not os.path.exists(gt_test_ganoutput_path): - os.mkdir(gt_test_ganoutput_path) - gt_test_ganoutput_path_batch_num = gt_test_ganoutput_path + "/" + str("%07d" % batch_num) - if not os.path.exists(gt_test_ganoutput_path_batch_num): - os.mkdir(gt_test_ganoutput_path_batch_num) - for k in range(0, len(image_space)): - im = image_space[k].transpose(1, 2, 0) - 
Image.fromarray(im).save(os.path.join(gt_test_ganoutput_path_batch_num, batch['file_name'][k])) - if i == 0: - Image.fromarray(im).save(os.path.join(gt_test_ganoutput_path, str("%07d" % batch_num) + ".png")) - - log = log + " totalLossOnEbsynth: %7.4f" % (generator_loss_on_ebsynth/(len(self.imloader))) - print(log, flush=True) - - - def apply_mask(self, x, batch, mask_key): - if mask_key in batch: - mask = Variable(batch[mask_key].expand(x.size()), requires_grad=False) - return x * (mask / 2 + 0.5) - return x - - def ones_like(self, x): - return torch.ones_like(x).to(self.device) - - def zeros_like(self, x): - return torch.zeros_like(x).to(self.device) - - @staticmethod - def to_image_space(x): - return ((np.clip(x, -1, 1) + 1) / 2 * 255).astype(np.uint8) +import time +import models +import numpy as np +import six +import torch +import torch.nn as nn +from torch.autograd import Variable +from PIL import Image +from custom_transforms import * +from data import DatasetFullImages +import os +import gc + +import tensorflow as tf +config = tf.compat.v1.ConfigProto() +config.gpu_options.allow_growth = True +sess = tf.compat.v1.Session(config=config) + +torch.backends.cudnn.benchmark = False + +class Trainer(object): + def __init__(self, + train_loader, data_for_dataloader, opt_discriminator, opt_generator, + reconstruction_criterion, adversarial_criterion, reconstruction_weight, + adversarial_weight, log_interval, scalar_logger, model_logger, + perception_loss_model, perception_loss_weight, use_image_loss, device + ): + + self.train_loader = train_loader + self.data_for_dataloader = data_for_dataloader + + self.opt_discriminator = opt_discriminator + self.opt_generator = opt_generator + + self.reconstruction_criterion = reconstruction_criterion + self.adversarial_criterion = adversarial_criterion + + self.reconstruction_weight = reconstruction_weight + self.adversarial_weight = adversarial_weight + + self.scalar_logger = scalar_logger + self.model_logger = model_logger + + self.training_log = {} + self.log_interval = log_interval + + self.perception_loss_weight = perception_loss_weight + self.perception_loss_model = perception_loss_model + + self.use_adversarial_loss = False + self.use_image_loss = use_image_loss + self.device = device + + self.dataset = None + self.imloader = None + + + def run_discriminator(self, discriminator, images): + return discriminator(images) + + def compute_discriminator_loss(self, generator, discriminator, batch): + generated = generator(batch['pre']) + fake = self.apply_mask(generated, batch, 'pre_mask') + fake_labels, _ = self.run_discriminator(discriminator, fake.detach()) + + true = self.apply_mask(batch['already'], batch, 'already_mask') + true_labels, _ = self.run_discriminator(discriminator, true) + + discriminator_loss = self.adversarial_criterion(fake_labels, self.zeros_like(fake_labels)) + \ + self.adversarial_criterion(true_labels, self.ones_like(true_labels)) + + return discriminator_loss + + def compute_generator_loss(self, generator, discriminator, batch, use_gan, use_mask): + image_loss = 0 + perception_loss = 0 + adversarial_loss = 0 + + generated = generator(batch['pre']) + + if use_mask: + generated = generated * batch['mask'] + batch['post'] = batch['post'] * batch['mask'] + + if self.use_image_loss: + if generated[0][0].shape != batch['post'][0][0].shape: + if ((batch['post'][0][0].shape[0] - generated[0][0].shape[0]) % 2) != 0: + raise RuntimeError("batch['post'][0][0].shape[0] - generated[0][0].shape[0] must be even number") + if 
generated[0][0].shape[0] != generated[0][0].shape[1] or batch['post'][0][0].shape[0] != batch['post'][0][0].shape[1]:
+                    raise RuntimeError("And also it is expected to be exact square ... fix it if you want")
+                boundary_size = int((batch['post'][0][0].shape[0] - generated[0][0].shape[0]) / 2)
+                cropped_batch_post = batch['post'][:, :, boundary_size: -1*boundary_size, boundary_size: -1*boundary_size]
+                image_loss = self.reconstruction_criterion(generated, cropped_batch_post)
+            else:
+                image_loss = self.reconstruction_criterion(generated, batch['post'])
+
+        if self.perception_loss_model is not None:
+            _, fake_features = self.perception_loss_model(generated)
+            _, target_features = self.perception_loss_model(Variable(batch['post'], requires_grad=False))
+            perception_loss = ((fake_features - target_features) ** 2).mean()
+
+
+        if self.use_adversarial_loss and use_gan:
+            fake = self.apply_mask(generated, batch, 'pre_mask')
+            fake_smiling_labels, _ = self.run_discriminator(discriminator, fake)
+            adversarial_loss = self.adversarial_criterion(fake_smiling_labels, self.ones_like(fake_smiling_labels))
+
+        return image_loss, perception_loss, adversarial_loss, generated
+
+
+    def train(self, generator, discriminator, epochs, data_root, config_yaml_name, starting_batch_num, step=None):
+        # step defaults to None so callers that predate the extra argument
+        # (e.g. the 6-argument call in train.py) keep working
+        self.use_adversarial_loss = discriminator is not None
+        batch_num = starting_batch_num
+        save_num = 0
+
+        start = time.time()
+        for epoch in range(epochs):
+            np.random.seed()
+            for i, batch in enumerate(self.train_loader):
+                # just sets the models into training mode (enable BN and DO)
+                [m.train() for m in [generator, discriminator] if m is not None]
+                batch = {k: batch[k].to(self.device) if isinstance(batch[k], torch.Tensor) else batch[k]
+                         for k in batch.keys()}
+
+                # train discriminator
+                if self.use_adversarial_loss:
+                    self.opt_discriminator.zero_grad()
+                    discriminator_loss = self.compute_discriminator_loss(generator, discriminator, batch)
+                    discriminator_loss.backward()
+                    self.opt_discriminator.step()
+
+                # train generator
+                self.opt_generator.zero_grad()
+
+                g_image_loss, g_perc_loss, g_adv_loss, _ = self.compute_generator_loss(generator, discriminator, batch, use_gan=True, use_mask=False)
+
+                generator_loss = self.reconstruction_weight * g_image_loss + \
+                                 self.perception_loss_weight * g_perc_loss + \
+                                 self.adversarial_weight * g_adv_loss
+
+                generator_loss.backward()
+
+                self.opt_generator.step()
+
+                # log losses
+                current_log = {key: value.item() for key, value in six.iteritems(locals()) if
+                               'loss' in key and isinstance(value, Variable)}
+
+                self.add_log(current_log)
+
+                batch_num += 1
+
+                if batch_num % 100 == 0:
+                    print(f"Batch num: {batch_num}, totally elapsed {(time.time() - start)}", flush=True)
+
+                #if batch_num % self.log_interval == 0 or batch_num == 1:
+                if batch_num % self.log_interval == 0 or batch_num == 1:  # (time.time() - start) > 16:
+                    eval_start = time.time()
+                    generator.eval()
+                    self.test_on_full_image(generator, batch_num, data_root, config_yaml_name)
+                    self.flush_scalar_log(batch_num, time.time() - start, step=step)
+                    self.model_logger.save(generator, save_num, True)
+                    #self.model_logger.save(discriminator, save_num, False)
+                    save_num += 1
+                    print(f"Eval of batch: {batch_num} took {(time.time() - eval_start)}", flush=True)
+
+                #if batch_num > 5000:
+                #    sys.exit(0)
+
+        self.model_logger.save(generator, 99999, True)
+
+    # Accumulates the losses
+    def add_log(self, log):
+        for k, v in log.items():
+            if k in self.training_log:
+                self.training_log[k] += v
+            else:
+                self.training_log[k] = v
+
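+    # add_log accumulates raw per-batch loss values; flush_scalar_log averages
+    # them over log_interval batches, reports them, and resets the accumulator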
+    # Divide the losses by log_interval and print them
+    def flush_scalar_log(self, batch_num, took, step=None):
+        # batch_num doubles as the TensorBoard step unless an explicit step is given
+        tb_step = batch_num if step is None else step
+        for key in self.training_log.keys():
+            self.scalar_logger.scalar_summary(key, self.training_log[key] / self.log_interval, tb_step)
+
+        log = "[%d]" % batch_num
+        for key in sorted(self.training_log.keys()):
+            log += " [%s] % 7.4f" % (key, self.training_log[key] / self.log_interval)
+
+        log += ". Took {}".format(took)
+        print(log, flush=True)
+        self.training_log = {}
+
+    # Test the intermediate model on data from _gen folder
+    def test_on_full_image(self, generator, batch_num, data_root, config_yaml_name):
+        config_yaml_name = config_yaml_name.replace("reference", "").replace(".yaml", "")
+
+        data_root = data_root.replace("_train", "_gen")
+        if self.dataset is None:
+            self.dataset = DatasetFullImages(data_root + "/" + self.data_for_dataloader['dir_pre'].split("/")[-1],
+                                             "ignore",  # data_root + "/" + "ebsynth",
+                                             "ignore",  # data_root + "/" + "mask",
+                                             self.device,
+                                             dir_x1=data_root + "/" + self.data_for_dataloader['dir_x1'].split("/")[-1] if self.data_for_dataloader['dir_x1'] is not None else None,
+                                             dir_x2=data_root + "/" + self.data_for_dataloader['dir_x2'].split("/")[-1] if self.data_for_dataloader['dir_x2'] is not None else None,
+                                             dir_x3=data_root + "/" + self.data_for_dataloader['dir_x3'].split("/")[-1] if self.data_for_dataloader['dir_x3'] is not None else None,
+                                             dir_x4=data_root + "/" + self.data_for_dataloader['dir_x4'].split("/")[-1] if self.data_for_dataloader['dir_x4'] is not None else None,
+                                             dir_x5=data_root + "/" + self.data_for_dataloader['dir_x5'].split("/")[-1] if self.data_for_dataloader['dir_x5'] is not None else None,
+                                             dir_x6=data_root + "/" + self.data_for_dataloader['dir_x6'].split("/")[-1] if self.data_for_dataloader['dir_x6'] is not None else None,
+                                             dir_x7=data_root + "/" + self.data_for_dataloader['dir_x7'].split("/")[-1] if self.data_for_dataloader['dir_x7'] is not None else None,
+                                             dir_x8=data_root + "/" + self.data_for_dataloader['dir_x8'].split("/")[-1] if self.data_for_dataloader['dir_x8'] is not None else None,
+                                             dir_x9=data_root + "/" + self.data_for_dataloader['dir_x9'].split("/")[-1] if self.data_for_dataloader['dir_x9'] is not None else None)
+            self.imloader = torch.utils.data.DataLoader(self.dataset, 1, shuffle=False, num_workers=1, drop_last=False)  # num_workers=4
+
+        with torch.no_grad():
+            log = "### \n"
+            log = log + "[%d]" % batch_num + " "
+            generator_loss_on_ebsynth = 0
+            for i, batch in enumerate(self.imloader):
+                batch = {k: batch[k].to(self.device) if isinstance(batch[k], torch.Tensor) else batch[k]
+                         for k in batch.keys()}
+                g_image_loss, g_perc_loss, g_adv_loss, e_cls_loss, e_smiling_loss, gan_output =\
+                    0, 0, 0, 0, 0, generator(batch['pre'])
+
+                generator_loss = self.reconstruction_weight * g_image_loss + \
+                                 self.perception_loss_weight * g_perc_loss + \
+                                 self.adversarial_weight * g_adv_loss
+
+                if True or batch['file_name'][0] != "111.png":  # do not accumulate loss in train frame
+                    generator_loss_on_ebsynth = generator_loss_on_ebsynth + generator_loss
+
+                if True or batch['file_name'][0] in ["111.png", "101.png", "106.png", "116.png", "121.png"]:
+                    #log = log + batch['file_name'][0]
+                    #log = log + ": %7.4f" % generator_loss + ", "
+
+                    image_space = to_image_space(gan_output.cpu().data.numpy())
+
+                    gt_test_ganoutput_path = data_root + "/" + "res_" + config_yaml_name
+                    if not os.path.exists(gt_test_ganoutput_path):
+                        os.mkdir(gt_test_ganoutput_path)
+                    gt_test_ganoutput_path_batch_num = gt_test_ganoutput_path + "/" + str("%07d" % batch_num)
+                    if not os.path.exists(gt_test_ganoutput_path_batch_num):
+                        os.mkdir(gt_test_ganoutput_path_batch_num)
+                    for k in range(0, len(image_space)):
+                        im = image_space[k].transpose(1, 2, 0)
+                        Image.fromarray(im).save(os.path.join(gt_test_ganoutput_path_batch_num, batch['file_name'][k]))
+                        if i == 0:
+                            Image.fromarray(im).save(os.path.join(gt_test_ganoutput_path, str("%07d" % batch_num) + ".png"))
+
+            log = log + " totalLossOnEbsynth: %7.4f" % (generator_loss_on_ebsynth/(len(self.imloader)))
+            print(log, flush=True)
+
+
+    def apply_mask(self, x, batch, mask_key):
+        if mask_key in batch:
+            mask = Variable(batch[mask_key].expand(x.size()), requires_grad=False)
+            return x * (mask / 2 + 0.5)
+        return x
+
+    def ones_like(self, x):
+        return torch.ones_like(x).to(self.device)
+
+    def zeros_like(self, x):
+        return torch.zeros_like(x).to(self.device)
+
+    @staticmethod
+    def to_image_space(x):
+        return ((np.clip(x, -1, 1) + 1) / 2 * 255).astype(np.uint8)
diff --git a/trainers1.py b/trainers1.py
new file mode 100644
index 0000000..b99fd1f
--- /dev/null
+++ b/trainers1.py
@@ -0,0 +1,267 @@
+import time
+import models
+import numpy as np
+import six
+import torch
+import torch.nn as nn
+from torch.autograd import Variable
+from PIL import Image
+from custom_transforms import *
+from data import DatasetFullImages
+import os
+import gc
+
+import tensorflow as tf
+config = tf.compat.v1.ConfigProto()
+config.gpu_options.allow_growth = True
+sess = tf.compat.v1.Session(config=config)
+
+torch.backends.cudnn.benchmark = False
+
+class Trainer(object):
+    def __init__(self,
+                 train_loader, data_for_dataloader, opt_discriminator, opt_generator,
+                 reconstruction_criterion, adversarial_criterion, reconstruction_weight,
+                 adversarial_weight, log_interval, scalar_logger, model_logger,
+                 perception_loss_model, perception_loss_weight, use_image_loss, device
+                 ):
+
+        self.train_loader = train_loader
+        self.data_for_dataloader = data_for_dataloader
+
+        self.opt_discriminator = opt_discriminator
+        self.opt_generator = opt_generator
+
+        self.reconstruction_criterion = reconstruction_criterion
+        self.adversarial_criterion = adversarial_criterion
+
+        self.reconstruction_weight = reconstruction_weight
+        self.adversarial_weight = adversarial_weight
+
+        self.scalar_logger = scalar_logger
+        self.model_logger = model_logger
+
+        self.training_log = {}
+        self.log_interval = log_interval
+
+        self.perception_loss_weight = perception_loss_weight
+        self.perception_loss_model = perception_loss_model
+
+        self.use_adversarial_loss = False
+        self.use_image_loss = use_image_loss
+        self.device = device
+
+        self.dataset = None
+        self.imloader = None
+
+
+    def run_discriminator(self, discriminator, images):
+        return discriminator(images)
+
+    def compute_discriminator_loss(self, generator, discriminator, batch):
+        generated = generator(batch['pre'])
+        fake = self.apply_mask(generated, batch, 'pre_mask')
+        fake_labels, _ = self.run_discriminator(discriminator, fake.detach())
+
+        true = self.apply_mask(batch['already'], batch, 'already_mask')
+        true_labels, _ = self.run_discriminator(discriminator, true)
+
+        discriminator_loss = self.adversarial_criterion(fake_labels, self.zeros_like(fake_labels)) + \
+                             self.adversarial_criterion(true_labels, self.ones_like(true_labels))
+
+        return discriminator_loss
+
+    def compute_generator_loss(self, generator, discriminator, batch, use_gan, use_mask):
+        image_loss = 0
+        perception_loss = 0
+        adversarial_loss = 0
+
+        generated = generator(batch['pre'])
+
+        if use_mask:
+            generated = generated * batch['mask']
+            
batch['post'] = batch['post'] * batch['mask'] + + if self.use_image_loss: + if generated[0][0].shape != batch['post'][0][0].shape: + if ((batch['post'][0][0].shape[0] - generated[0][0].shape[0]) % 2) != 0: + raise RuntimeError("batch['post'][0][0].shape[0] - generated[0][0].shape[0] must be even number") + if generated[0][0].shape[0] != generated[0][0].shape[1] or batch['post'][0][0].shape[0] != batch['post'][0][0].shape[1]: + raise RuntimeError("And also it is expected to be exact square ... fix it if you want") + boundary_size = int((batch['post'][0][0].shape[0] - generated[0][0].shape[0]) / 2) + cropped_batch_post = batch['post'][:, :, boundary_size: -1*boundary_size, boundary_size: -1*boundary_size] + image_loss = self.reconstruction_criterion(generated, cropped_batch_post) + else: + image_loss = self.reconstruction_criterion(generated, batch['post']) + + if self.perception_loss_model is not None: + _, fake_features = self.perception_loss_model(generated) + _, target_features = self.perception_loss_model(Variable(batch['post'], requires_grad=False)) + perception_loss = ((fake_features - target_features) ** 2).mean() + + + if self.use_adversarial_loss and use_gan: + fake = self.apply_mask(generated, batch, 'pre_mask') + fake_smiling_labels, _ = self.run_discriminator(discriminator, fake) + adversarial_loss = self.adversarial_criterion(fake_smiling_labels, self.ones_like(fake_smiling_labels)) + + return image_loss, perception_loss, adversarial_loss, generated + + + def train(self, generator, discriminator, epochs, data_root, config_yaml_name, starting_batch_num): + self.use_adversarial_loss = discriminator is not None + batch_num = starting_batch_num + save_num = 0 + + start = time.time() + for epoch in range(epochs): + np.random.seed() + for i, batch in enumerate(self.train_loader): + # just sets the models into training mode (enable BN and DO) + [m.train() for m in [generator, discriminator] if m is not None] + batch = {k: batch[k].to(self.device) if isinstance(batch[k], torch.Tensor) else batch[k] + for k in batch.keys()} + + # train discriminator + if self.use_adversarial_loss: + self.opt_discriminator.zero_grad() + discriminator_loss = self.compute_discriminator_loss(generator, discriminator, batch) + discriminator_loss.backward() + self.opt_discriminator.step() + + # train generator + self.opt_generator.zero_grad() + + g_image_loss, g_perc_loss, g_adv_loss, _ = self.compute_generator_loss(generator, discriminator, batch, use_gan=True, use_mask=False) + + generator_loss = self.reconstruction_weight * g_image_loss + \ + self.perception_loss_weight * g_perc_loss + \ + self.adversarial_weight * g_adv_loss + + generator_loss.backward() + + self.opt_generator.step() + + # log losses + current_log = {key: value.item() for key, value in six.iteritems(locals()) if + 'loss' in key and isinstance(value, Variable)} + + self.add_log(current_log) + + batch_num += 1 + + if batch_num % 100 == 0: + print(f"Batch num: {batch_num}, totally elapsed {(time.time() - start)}", flush=True) + + #if batch_num % self.log_interval == 0 or batch_num == 1: + if batch_num % self.log_interval == 0 or batch_num == 1: # (time.time() - start) > 16: + eval_start = time.time() + generator.eval() + self.test_on_full_image(generator, batch_num, data_root, config_yaml_name) + self.flush_scalar_log(batch_num, time.time() - start) + self.model_logger.save(generator, save_num, True) + #self.model_logger.save(discriminator, save_num, False) + save_num += 1 + print(f"Eval of batch: {batch_num} took {(time.time() - 
eval_start)}", flush=True) + + #if batch_num > 5000: + # sys.exit(0) + + self.model_logger.save(generator, 99999) + + # Accumulates the losses + def add_log(self, log): + for k, v in log.items(): + if k in self.training_log: + self.training_log[k] += v + else: + self.training_log[k] = v + + # Divide the losses by log_interval and print'em + def flush_scalar_log(self, batch_num, took): + for key in self.training_log.keys(): + self.scalar_logger.scalar_summary(key, self.training_log[key] / self.log_interval, batch_num) + + log = "[%d]" % batch_num + for key in sorted(self.training_log.keys()): + log += " [%s] % 7.4f" % (key, self.training_log[key] / self.log_interval) + + log += ". Took {}".format(took) + print(log, flush=True) + self.training_log = {} + + # Test the intermediate model on data from _gen folder + def test_on_full_image(self, generator, batch_num, data_root, config_yaml_name): + config_yaml_name = config_yaml_name.replace("reference", "").replace(".yaml", "") + + data_root = data_root.replace("_train", "_gen") + if self.dataset is None: + self.dataset = DatasetFullImages(data_root + "/" + self.data_for_dataloader['dir_pre'].split("/")[-1], + "ignore", # data_root + "/" + "ebsynth", + "ignore", # data_root + "/" + "mask", + self.device, + dir_x1=data_root + "/" + self.data_for_dataloader['dir_x1'].split("/")[-1] if self.data_for_dataloader['dir_x1'] is not None else None, + dir_x2=data_root + "/" + self.data_for_dataloader['dir_x2'].split("/")[-1] if self.data_for_dataloader['dir_x2'] is not None else None, + dir_x3=data_root + "/" + self.data_for_dataloader['dir_x3'].split("/")[-1] if self.data_for_dataloader['dir_x3'] is not None else None, + dir_x4=data_root + "/" + self.data_for_dataloader['dir_x4'].split("/")[-1] if self.data_for_dataloader['dir_x4'] is not None else None, + dir_x5=data_root + "/" + self.data_for_dataloader['dir_x5'].split("/")[-1] if self.data_for_dataloader['dir_x5'] is not None else None, + dir_x6=data_root + "/" + self.data_for_dataloader['dir_x6'].split("/")[-1] if self.data_for_dataloader['dir_x6'] is not None else None, + dir_x7=data_root + "/" + self.data_for_dataloader['dir_x7'].split("/")[-1] if self.data_for_dataloader['dir_x7'] is not None else None, + dir_x8=data_root + "/" + self.data_for_dataloader['dir_x8'].split("/")[-1] if self.data_for_dataloader['dir_x8'] is not None else None, + dir_x9=data_root + "/" + self.data_for_dataloader['dir_x9'].split("/")[-1] if self.data_for_dataloader['dir_x9'] is not None else None) + self.imloader = torch.utils.data.DataLoader(self.dataset, 1, shuffle=False, num_workers=1, drop_last=False) # num_workers=4 + + with torch.no_grad(): + log = "### \n" + log = log + "[%d]" % batch_num + " " + generator_loss_on_ebsynth = 0 + for i, batch in enumerate(self.imloader): + batch = {k: batch[k].to(self.device) if isinstance(batch[k], torch.Tensor) else batch[k] + for k in batch.keys()} + g_image_loss, g_perc_loss, g_adv_loss, e_cls_loss, e_smiling_loss, gan_output =\ + 0, 0, 0, 0, 0, generator(batch['pre']) + + generator_loss = self.reconstruction_weight * g_image_loss + \ + self.perception_loss_weight * g_perc_loss + \ + self.adversarial_weight * g_adv_loss + + if True or batch['file_name'][0] != "111.png": # do not accumulate loss in train frame + generator_loss_on_ebsynth = generator_loss_on_ebsynth + generator_loss + + if True or batch['file_name'][0] in ["111.png", "101.png", "106.png", "116.png", "121.png"]: + #log = log + batch['file_name'][0] + #log = log + ": %7.4f" % generator_loss + ", " + + image_space = 
to_image_space(gan_output.cpu().data.numpy()) + + gt_test_ganoutput_path = data_root + "/" + "res_" + config_yaml_name + if not os.path.exists(gt_test_ganoutput_path): + os.mkdir(gt_test_ganoutput_path) + gt_test_ganoutput_path_batch_num = gt_test_ganoutput_path + "/" + str("%07d" % batch_num) + if not os.path.exists(gt_test_ganoutput_path_batch_num): + os.mkdir(gt_test_ganoutput_path_batch_num) + for k in range(0, len(image_space)): + im = image_space[k].transpose(1, 2, 0) + Image.fromarray(im).save(os.path.join(gt_test_ganoutput_path_batch_num, batch['file_name'][k])) + if i == 0: + Image.fromarray(im).save(os.path.join(gt_test_ganoutput_path, str("%07d" % batch_num) + ".png")) + + log = log + " totalLossOnEbsynth: %7.4f" % (generator_loss_on_ebsynth/(len(self.imloader))) + print(log, flush=True) + + + def apply_mask(self, x, batch, mask_key): + if mask_key in batch: + mask = Variable(batch[mask_key].expand(x.size()), requires_grad=False) + return x * (mask / 2 + 0.5) + return x + + def ones_like(self, x): + return torch.ones_like(x).to(self.device) + + def zeros_like(self, x): + return torch.zeros_like(x).to(self.device) + + @staticmethod + def to_image_space(x): + return ((np.clip(x, -1, 1) + 1) / 2 * 255).astype(np.uint8)
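
Note on the logging path (a minimal sketch, not part of the patch): with logger1.Logger rewritten against the TF2 summary API as above, scalars can be written and then inspected in TensorBoard; "logs_dir" is a placeholder directory.

    from logger1 import Logger

    logger = Logger("logs_dir")                # wraps tf.summary.create_file_writer
    for step in range(3):
        logger.scalar_summary("g_image_loss", 0.1 * step, step)   # one point per step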