From eb2591db4f6dd56b37a0282b088bab656e3f19fe Mon Sep 17 00:00:00 2001 From: Ondrej Date: Thu, 13 Aug 2020 22:26:34 +0200 Subject: [PATCH] Init --- README.md | 80 +++++++ _config/reference_F.yaml | 80 +++++++ _config/reference_P.yaml | 81 +++++++ _config/reference_P_disco1010.yaml | 82 +++++++ _config/reference_P_disco1015.yaml | 82 +++++++ common_utils.py | 59 +++++ custom_transforms.py | 56 +++++ data.py | 341 +++++++++++++++++++++++++++++ generate.py | 64 ++++++ logger.py | 30 +++ models.py | 319 +++++++++++++++++++++++++++ train.py | 132 +++++++++++ trainers.py | 260 ++++++++++++++++++++++ 13 files changed, 1666 insertions(+) create mode 100644 README.md create mode 100644 _config/reference_F.yaml create mode 100644 _config/reference_P.yaml create mode 100644 _config/reference_P_disco1010.yaml create mode 100644 _config/reference_P_disco1015.yaml create mode 100644 common_utils.py create mode 100644 custom_transforms.py create mode 100644 data.py create mode 100644 generate.py create mode 100644 logger.py create mode 100644 models.py create mode 100644 train.py create mode 100644 trainers.py diff --git a/README.md b/README.md new file mode 100644 index 0000000..18bdf14 --- /dev/null +++ b/README.md @@ -0,0 +1,80 @@ +# Interactive Video Stylization Using Few-Shot Patch-Based Training + +The official implementation of + +> **Interactive Video Stylization Using Few-Shot Patch-Based Training**
+_[O. Texler](https://ondrejtexler.github.io/), [D. Futschik](https://dcgi.fel.cvut.cz/people/futscdav), +[M. Kučera](https://www.linkedin.com/in/kuceram/), [O. Jamriška](https://dcgi.fel.cvut.cz/people/jamriond), +[Š. Sochorová](https://dcgi.fel.cvut.cz/people/sochosar), [M. Chai](http://www.mlchai.com), +[S. Tulyakov](http://www.stulyakov.com), and [D. Sýkora](https://dcgi.fel.cvut.cz/home/sykorad/)_
+[[`WebPage`](https://ondrejtexler.github.io/patch-based_training)], +[[`Paper`](https://ondrejtexler.github.io/res/Texler20-SIG_patch-based_training_main.pdf)], +[[`BiBTeX`](#CitingFewShotPatchBasedTraining)] + + +## Run + +Download the [TESTING DATA](https://drive.google.com/file/d/1EscSNFg4ILpB7dxr-zYw_UdOILLmDlRj/view?usp=sharing), and unzip. +The _train folder is expected to be next to the _gen folder. + +To train the network, run the `train.py`. +To generate the results, run `generate.py`. +See example commands below: + +``` +train.py --config "_config/reference_P.yaml" + --data_root "Maruska640_train" + --log_interval 1000 + --log_folder logs_reference_P +``` + +Every 1000 (log_interval) epochs, `train.py` saves the current generator to logs_reference_P (log_folder), and it validates/runs the generator on _gen data - the result is saved in Maruska640_gen/res__P + + +``` +generate.py --checkpoint "Maruska640_train/logs_reference_P/model_00010.pth" + --data_root "Maruska_gen" + --dir_input "input_filtered" + --outdir "Maruska_gen/res_00010" + --device "cuda:0" +``` + + +## Installation +Tested on Windows 10, `Python 3.7.8`, `CUDA 10.2`. +With the following python packages: +``` +numpy 1.19.1 +opencv-python 4.4.0.40 +Pillow 7.2.0 +PyYAML 5.3.1 +scikit-image 0.17.2 +scipy 1.5.2 +torch 1.6.0 +torchvision 0.7.0 +``` + + +## Credits +* This project started when [Ondrej Texler](https://ondrejtexler.github.io/) was an intern at [Snap Inc.](https://www.snap.com/), and it was funded by [Snap Inc.](https://www.snap.com/) and [Czech Technical University in Prague](https://www.cvut.cz/en) + + +## License +The code is released for research purposes only. + + +## Citing +If you find Interactive Video Stylization Using Few-Shot Patch-Based Training useful for your research or work, please use the following BibTeX entry. + +``` +@Article{Texler20-SIG, + author = "Ond\v{r}ej Texler and David Futschik and Michal Ku\v{c}era and Ond\v{r}ej Jamri\v{s}ka and \v{S}\'{a}rka Sochorov\'{a} and Menglei Chai and Sergey Tulyakov and Daniel S\'{y}kora", + title = "Interactive Video Stylization Using Few-Shot Patch-Based Training", + journal = "ACM Transactions on Graphics", + volume = "39", + number = "4", + pages = "73", + year = "2020", +} +``` + diff --git a/_config/reference_F.yaml b/_config/reference_F.yaml new file mode 100644 index 0000000..1adc94b --- /dev/null +++ b/_config/reference_F.yaml @@ -0,0 +1,80 @@ +# Generator +generator: &generator_j + type: GeneratorJ + args: + use_bias: True + tanh: True + append_smoothers: True + resnet_blocks: 4 + filters: [32, 64, 128, 128, 128, 64] + input_channels: 3 + + +# Optimizer of Generator +opt_generator: &opt_generator + type: Adam + args: + lr: 0.0002 + betas: [0.9, 0.999] + weight_decay: 0.00001 + + +# Discriminator +discriminator: &discriminatorn + type: DiscriminatorN_IN + args: + num_filters: 12 + n_layers: 2 + + +# Optimizer of Discriminator +opt_discriminator: &opt_discriminator + type: Adam + args: + lr: 0.0002 + betas: [0.9, 0.999] + weight_decay: 0.00001 + + +# Parameters of Perception Loss (VGG-Loss) +perception_loss: &perception_loss + weight: 6.0 + perception_model: + type: PerceptualVGG19 + args: + feature_layers: [0, 3, 5] + use_normalization: False + + +# Training Parameters +trainer: &trainer_1 + batch_size: 1 + epochs: 555555555 + reconstruction_weight: 4. 
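+  # The weights in this block are combined in trainers.py as:
+  #   generator_loss = reconstruction_weight * L1Loss(generated, post)
+  #                  + perception_loss.weight * mean squared VGG19 feature difference
+  #                  + adversarial_weight * MSELoss(D(generated), ones)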
+ adversarial_weight: 0.5 + use_image_loss: True + reconstruction_criterion: L1Loss + adversarial_criterion: MSELoss + + +# Training Dataset Parameters +training_dataset: &training_dataset + type: DatasetFullImages + dir_pre: input + dir_post: output + dir_mask: mask + + +# "Main" of this YAML file +job: + training_dataset: *training_dataset + generator: *generator_j + opt_generator: *opt_generator + discriminator: *discriminatorn + opt_discriminator: *opt_discriminator + perception_loss: *perception_loss + trainer: *trainer_1 + + num_workers: 1 + device: "cuda:0" + diff --git a/_config/reference_P.yaml b/_config/reference_P.yaml new file mode 100644 index 0000000..862f3fc --- /dev/null +++ b/_config/reference_P.yaml @@ -0,0 +1,81 @@ +# Generator +generator: &generator_j + type: GeneratorJ + args: + use_bias: True + tanh: True + append_smoothers: True + resnet_blocks: 7 + filters: [32, 64, 128, 128, 128, 64] + input_channels: 3 + + +# Optimizer of Generator +opt_generator: &opt_generator + type: Adam + args: + lr: 0.0004 + betas: [0.9, 0.999] + weight_decay: 0.00001 + + +# Discriminator +discriminator: &discriminatorn + type: DiscriminatorN_IN + args: + num_filters: 12 + n_layers: 2 + + +# Optimizer of Discriminator +opt_discriminator: &opt_discriminator + type: Adam + args: + lr: 0.0004 + betas: [0.9, 0.999] + weight_decay: 0.00001 + + +# Parameters of Perception Loss (VGG-Loss) +perception_loss: &perception_loss + weight: 6.0 + perception_model: + type: PerceptualVGG19 + args: + feature_layers: [0, 3, 5] + use_normalization: False + + +# Training Parameters +trainer: &trainer_1 + batch_size: 40 + epochs: 50000000 + reconstruction_weight: 4. + adversarial_weight: 0.5 + use_image_loss: True + reconstruction_criterion: L1Loss + adversarial_criterion: MSELoss + + +# Training Dataset Parameters +training_dataset: &training_dataset + type: DatasetPatches_M + dir_pre: input_filtered + dir_post: output + dir_mask: mask + patch_size: 32 + + +# "Main" of this YAML file +job: + training_dataset: *training_dataset + generator: *generator_j + opt_generator: *opt_generator + discriminator: *discriminatorn + opt_discriminator: *opt_discriminator + perception_loss: *perception_loss + trainer: *trainer_1 + + num_workers: 1 + device: "cuda:0" + diff --git a/_config/reference_P_disco1010.yaml b/_config/reference_P_disco1010.yaml new file mode 100644 index 0000000..53daf9e --- /dev/null +++ b/_config/reference_P_disco1010.yaml @@ -0,0 +1,82 @@ +# Generator +generator: &generator_j + type: GeneratorJ + args: + use_bias: True + tanh: True + append_smoothers: True + resnet_blocks: 7 + filters: [32, 64, 128, 128, 128, 64] + input_channels: 6 + + +# Optimizer of Generator +opt_generator: &opt_generator + type: Adam + args: + lr: 0.0004 + betas: [0.9, 0.999] + weight_decay: 0.00001 + + +# Discriminator +discriminator: &discriminatorn + type: DiscriminatorN_IN + args: + num_filters: 12 + n_layers: 2 + + +# Optimizer of Discriminator +opt_discriminator: &opt_discriminator + type: Adam + args: + lr: 0.0004 + betas: [0.9, 0.999] + weight_decay: 0.00001 + + +# Parameters of Perception Loss (VGG-Loss) +perception_loss: &perception_loss + weight: 6.0 + perception_model: + type: PerceptualVGG19 + args: + feature_layers: [0, 3, 5] + use_normalization: False + + +# Training Parameters +trainer: &trainer_1 + batch_size: 40 + epochs: 50000000 + reconstruction_weight: 4. 
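+  # This variant feeds the gauss guide from dir_x1 (see training_dataset below)
+  # alongside the RGB input, hence input_channels: 6 in the generator section;
+  # train.py recomputes the channel count as 3 plus 3 per dir_x* guide that is set.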
+ adversarial_weight: 0.5 + use_image_loss: True + reconstruction_criterion: L1Loss + adversarial_criterion: MSELoss + + +# Training Dataset Parameters +training_dataset: &training_dataset + type: DatasetPatches_M + dir_pre: input_filtered + dir_post: output + dir_mask: mask + patch_size: 32 + dir_x1: input_gdisko_gauss_r10_s10 + + +# "Main" of this YAML file +job: + training_dataset: *training_dataset + generator: *generator_j + opt_generator: *opt_generator + discriminator: *discriminatorn + opt_discriminator: *opt_discriminator + perception_loss: *perception_loss + trainer: *trainer_1 + + num_workers: 1 + device: "cuda:0" + diff --git a/_config/reference_P_disco1015.yaml b/_config/reference_P_disco1015.yaml new file mode 100644 index 0000000..4ce0214 --- /dev/null +++ b/_config/reference_P_disco1015.yaml @@ -0,0 +1,82 @@ +# Generator +generator: &generator_j + type: GeneratorJ + args: + use_bias: True + tanh: True + append_smoothers: True + resnet_blocks: 7 + filters: [32, 64, 128, 128, 128, 64] + input_channels: 6 + + +# Optimizer of Generator +opt_generator: &opt_generator + type: Adam + args: + lr: 0.0004 + betas: [0.9, 0.999] + weight_decay: 0.00001 + + +# Discriminator +discriminator: &discriminatorn + type: DiscriminatorN_IN + args: + num_filters: 12 + n_layers: 2 + + +# Optimizer of Discriminator +opt_discriminator: &opt_discriminator + type: Adam + args: + lr: 0.0004 + betas: [0.9, 0.999] + weight_decay: 0.00001 + + +# Parameters of Perception Loss (VGG-Loss) +perception_loss: &perception_loss + weight: 6.0 + perception_model: + type: PerceptualVGG19 + args: + feature_layers: [0, 3, 5] + use_normalization: False + + +# Training Parameters +trainer: &trainer_1 + batch_size: 40 + epochs: 50000000 + reconstruction_weight: 4. + adversarial_weight: 0.5 + use_image_loss: True + reconstruction_criterion: L1Loss + adversarial_criterion: MSELoss + + +# Training Dataset Parameters +training_dataset: &training_dataset + type: DatasetPatches_M + dir_pre: input_filtered + dir_post: output + dir_mask: mask + patch_size: 32 + dir_x1: input_gdisko_gauss_r10_s15 + + +# "Main" of this YAML file +job: + training_dataset: *training_dataset + generator: *generator_j + opt_generator: *opt_generator + discriminator: *discriminatorn + opt_discriminator: *opt_discriminator + perception_loss: *perception_loss + trainer: *trainer_1 + + num_workers: 1 + device: "cuda:0" + diff --git a/common_utils.py b/common_utils.py new file mode 100644 index 0000000..193f594 --- /dev/null +++ b/common_utils.py @@ -0,0 +1,59 @@ +import numpy as np +import os +import cv2 + + +def make_image_noisy(image, noise_typ): + if noise_typ == "gauss": + row, col, ch = image.shape + mean = 0 + var = 40 + sigma = var**0.5 + gauss = np.random.normal(mean, sigma, (row, col, ch)) + gauss = gauss.reshape((row, col, ch)) + noisy_image = image + gauss + return noisy_image.clip(0, 255) + elif noise_typ == "zero": + amount = 0.05 # percentage of zero pixels + out = np.copy(image) + num_zeros = np.ceil(amount * image.shape[0]*image.shape[1]) + coords = [np.random.randint(0, i - 1, int(num_zeros)) + for i in image.shape[:2]] + out[:, :, 0][coords] = 0 + out[:, :, 1][coords] = 0 + out[:, :, 2][coords] = 0 + return out.astype(np.uint8) + elif noise_typ == "s&p": + raise RuntimeError("Test it properly before using!") + row, col, ch = image.shape + s_vs_p = 0.5 + amount = 0.004 + out = np.copy(image) + # Salt mode + num_salt = np.ceil(amount * image.size * s_vs_p) + coords = [np.random.randint(0, i - 1, int(num_salt)) + for i in image.shape] + 
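+        # NOTE: this branch is unreachable until the RuntimeError above is removed.
+        # Indexing with the per-axis coordinate arrays selects random pixels for salt;
+        # the salt value of 1 (rather than 255) assumes an image scaled to [0, 1].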
out[coords] = 1 + + # Pepper mode + num_pepper = np.ceil(amount* image.size * (1. - s_vs_p)) + coords = [np.random.randint(0, i - 1, int(num_pepper)) + for i in image.shape] + out[coords] = 0 + return out + elif noise_typ == "poisson": + raise RuntimeError("Test it properly before using!") + vals = len(np.unique(image)) + vals = 2 ** np.ceil(np.log2(vals)) + noisy_image = np.random.poisson(image * vals) / float(vals) + return noisy_image + elif noise_typ == "speckle": + raise RuntimeError("Test it properly before using!") + row, col, ch = image.shape + gauss = np.random.randn(row, col, ch) + gauss = gauss.reshape((row, col, ch)) + noisy_image = image + image * gauss + return noisy_image + else: + raise RuntimeError(f"Unknown noisy_type: {noise_typ}") + diff --git a/custom_transforms.py b/custom_transforms.py new file mode 100644 index 0000000..9065c7a --- /dev/null +++ b/custom_transforms.py @@ -0,0 +1,56 @@ +import numpy as np +from torchvision import transforms +from scipy import ndimage +import torch + + +def to_image_space(x): + return ((np.clip(x, -1, 1) + 1) / 2 * 255).astype(np.uint8) + + +def to_rgb(x): + return x if x.mode == 'RGB' else x.convert('RGB') + + +def to_l(x): + return x if x.mode == 'L' else x.convert('L') + + +def blur_mask(tensor): + np_tensor = tensor.numpy() + smoothed = ndimage.gaussian_filter(np_tensor, sigma=20) + return torch.FloatTensor(smoothed) + + +def build_transform(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), mask=False): + #if type(image_size) != tuple: + #image_size = (image_size, image_size) + t = [#transforms.Resize((image_size[0], image_size[1])), + to_rgb, + transforms.ToTensor(), + transforms.Normalize(mean, std)] + if mask: + t.append(blur_mask) + return transforms.Compose(t) + + +def build_mask_transform(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)): + t = [#transforms.Resize((image_size, image_size)), + to_l, + transforms.ToTensor()] + return transforms.Compose(t) + + +def to_pil(tensor): + t = transforms.ToPILImage() + return t(tensor) + + +def tensor_mb(tensor): + return (tensor.element_size() * tensor.nelement()) / 1024 / 1024 + + + + + + diff --git a/data.py b/data.py new file mode 100644 index 0000000..18c0779 --- /dev/null +++ b/data.py @@ -0,0 +1,341 @@ +import os +import torch.utils.data +import PIL +from PIL import Image +from torch.nn import functional as F +from custom_transforms import * + + +def get_geometric_blur_patch(tensor_small, midpoint, patchsize, coeff): + midpoint = midpoint // coeff + hs = patchsize // 2 + hn = max(0, midpoint[0] - hs) + hx = min(midpoint[0] + hs, tensor_small.size()[1] - 1) + xn = max(0, midpoint[1] - hs) + xx = min(midpoint[1] + hs, tensor_small.size()[2] - 1) + + p = tensor_small[:, hn:hx, xn:xx] + if p.size()[1] != patchsize or p.size()[2] != patchsize: + r = torch.zeros((3, patchsize, patchsize)) + r[:, 0:p.size()[1], 0:p.size()[2]] = p + p = r + return p + + +################################ +# Dataset full-images +################################ +class DatasetFullImages(torch.utils.data.Dataset): + def __init__(self, dir_pre, dir_post, dir_mask, device, dir_x1, dir_x2, dir_x3, dir_x4, dir_x5, dir_x6, dir_x7, dir_x8, dir_x9): + super(DatasetFullImages, self).__init__() + + self.dir_pre = dir_pre + self.fnames = sorted(os.listdir(self.dir_pre)) + self.dir_post = dir_post + self.dir_mask = dir_mask + self.transform = build_transform() + self.mask_transform = build_mask_transform() + #self.temporal_frames = 3 + + self.dir_pre_x1 = dir_x1 + self.dir_pre_x2 = dir_x2 + self.dir_pre_x3 = dir_x3 + 
self.dir_pre_x4 = dir_x4 + self.dir_pre_x5 = dir_x5 + self.dir_pre_x6 = dir_x6 + self.dir_pre_x7 = dir_x7 + self.dir_pre_x8 = dir_x8 + self.dir_pre_x9 = dir_x9 + #print('DatasetFullImages: number of training examples %d' % len(self.fnames)) + + #def getitem_inner(self, item): + def __getitem__(self, item): + # get an image that is NOT stylized and its stylized counterpart + fileName = self.fnames[item] + pre = PIL.Image.open(os.path.join(self.dir_pre, fileName)) + pre_tensor = self.transform(pre) + if self.dir_pre_x1 is not None and self.dir_pre_x1 != "": + pre_x1 = PIL.Image.open(os.path.join(self.dir_pre_x1, fileName)) + pre_tensor = torch.cat((pre_tensor, self.transform(pre_x1)), dim=0) + if self.dir_pre_x2 is not None and self.dir_pre_x2 != "": + pre_x2 = PIL.Image.open(os.path.join(self.dir_pre_x2, fileName)) + pre_tensor = torch.cat((pre_tensor, self.transform(pre_x2)), dim=0) + if self.dir_pre_x3 is not None and self.dir_pre_x3 != "": + pre_x3 = PIL.Image.open(os.path.join(self.dir_pre_x3, fileName)) + pre_tensor = torch.cat((pre_tensor, self.transform(pre_x3)), dim=0) + if self.dir_pre_x4 is not None and self.dir_pre_x4 != "": + pre_x4 = PIL.Image.open(os.path.join(self.dir_pre_x4, fileName)) + pre_tensor = torch.cat((pre_tensor, self.transform(pre_x4)), dim=0) + if self.dir_pre_x5 is not None and self.dir_pre_x5 != "": + pre_x5 = PIL.Image.open(os.path.join(self.dir_pre_x5, fileName)) + pre_tensor = torch.cat((pre_tensor, self.transform(pre_x5)), dim=0) + if self.dir_pre_x6 is not None and self.dir_pre_x6 != "": + pre_x6 = PIL.Image.open(os.path.join(self.dir_pre_x6, fileName)) + pre_tensor = torch.cat((pre_tensor, self.transform(pre_x6)), dim=0) + if self.dir_pre_x7 is not None and self.dir_pre_x7 != "": + pre_x7 = PIL.Image.open(os.path.join(self.dir_pre_x7, fileName)) + pre_tensor = torch.cat((pre_tensor, self.transform(pre_x7)), dim=0) + if self.dir_pre_x8 is not None and self.dir_pre_x8 != "": + pre_x8 = PIL.Image.open(os.path.join(self.dir_pre_x8, fileName)) + pre_tensor = torch.cat((pre_tensor, self.transform(pre_x8)), dim=0) + if self.dir_pre_x9 is not None and self.dir_pre_x9 != "": + pre_x9 = PIL.Image.open(os.path.join(self.dir_pre_x9, fileName)) + pre_tensor = torch.cat((pre_tensor, self.transform(pre_x9)), dim=0) + + result = {'pre': pre_tensor, + 'file_name': self.fnames[item]} + + if not self.dir_post.endswith("ignore"): + post = PIL.Image.open(os.path.join(self.dir_post, fileName)) + post_tensor = self.transform(post) + result['post'] = post_tensor + + # get a random already stylized image + already_path = os.path.join(self.dir_post, self.fnames[np.random.randint(0, len(self.fnames))]) + im_s = PIL.Image.open(already_path) + im_s_tensor = self.transform(im_s) + result['already'] = im_s_tensor + + if not self.dir_mask.endswith("ignore"): + mask = PIL.Image.open(os.path.join(self.dir_mask, fileName)) + mask = mask.point(lambda p: p > 128 and 255) # !!! 
thresholding the mask fixes possible float and int conversion errors + mask_tensor = self.mask_transform(mask).int().float() + result['mask'] = mask_tensor + + return result + + def XXX__getitem__(self, item): + result = {'pre': None, + 'file_name': self.fnames[item]} + + for i in range(item - self.temporal_frames, item + self.temporal_frames + 1): + is_curr_item = True if i == item else False + i = max(0, i) + i = min(len(self.fnames)-1, i) + result_i = self.getitem_inner(i) + + if result['pre'] is None: + result['pre'] = result_i['pre'] + else: + result['pre'] = torch.cat((result['pre'], result_i['pre']), dim=0) + + if is_curr_item and "post" in result_i: + result['post'] = result_i['post'] + if is_curr_item and "already" in result_i: + result['already'] = result_i['already'] + if is_curr_item and "mask" in result_i: + result['mask'] = result_i['mask'] + + return result + + + def __len__(self): + return int(len(self.fnames)) + + +##### +# Default "patch" dataset, used for training +##### +class DatasetPatches_M(torch.utils.data.Dataset): + def __init__(self, dir_pre, dir_post, dir_mask, patch_size, device, dir_x1, dir_x2, dir_x3, dir_x4, dir_x5, dir_x6, dir_x7, dir_x8, dir_x9): + super(DatasetPatches_M, self).__init__() + self.dir_pre = dir_pre + self.dir_post = dir_post + self.dir_mask = dir_mask + self.patch_size = patch_size + + self.geom_blur_coeff = 0.0 + self.device = "cpu" + self.real_device = device + #self.temporal_frames = 3 + + self.paths_pre = sorted(os.listdir(dir_pre)) + self.paths_post = sorted(os.listdir(dir_post)) + self.paths_masks = sorted(os.listdir(dir_mask)) + + self.transform = build_transform() + self.mask_transform = build_mask_transform() + + self.images_pre = [] + self.images_pre_geom = [] + self.images_post = [] + images_mask = [] + + # additional guides + self.images_x1 = [] + self.images_x2 = [] + self.images_x3 = [] + self.images_x4 = [] + self.images_x5 = [] + self.images_x6 = [] + self.images_x7 = [] + self.images_x8 = [] + self.images_x9 = [] + + i = 0 + for p in self.paths_pre: + if p == "Thumbs.db": + continue + + p_png = os.path.splitext(p)[0] + '.png' + preim = PIL.Image.open(os.path.join(self.dir_pre, p)) + postim = PIL.Image.open(os.path.join(self.dir_post, p_png)) + maskim = PIL.Image.open(os.path.join(self.dir_mask, p_png)) + + maskim = maskim.point(lambda p: p > 128 and 255) # !!! 
thresholding the mask fixes possible float and int conversion errors + + pre_tensor = self.transform(preim) + if self.geom_blur_coeff != 0.0: + self.images_pre_geom.append(torch.nn.functional.interpolate(pre_tensor.unsqueeze(0), scale_factor=1.0 / self.geom_blur_coeff).squeeze(0)) + self.images_pre.append(pre_tensor) # .to(self.device)) + + if dir_x1 is not None and dir_x1 != "": + x1_im = PIL.Image.open(os.path.join(dir_x1, p)) + self.images_x1.append(self.transform(x1_im)) + if dir_x2 is not None and dir_x2 != "": + x2_im = PIL.Image.open(os.path.join(dir_x2, p)) + self.images_x2.append(self.transform(x2_im)) + if dir_x3 is not None and dir_x3 != "": + x3_im = PIL.Image.open(os.path.join(dir_x3, p)) + self.images_x3.append(self.transform(x3_im)) + if dir_x4 is not None and dir_x4 != "": + x4_im = PIL.Image.open(os.path.join(dir_x4, p)) + self.images_x4.append(self.transform(x4_im)) + if dir_x5 is not None and dir_x5 != "": + x5_im = PIL.Image.open(os.path.join(dir_x5, p)) + self.images_x5.append(self.transform(x5_im)) + if dir_x6 is not None and dir_x6 != "": + x6_im = PIL.Image.open(os.path.join(dir_x6, p)) + self.images_x6.append(self.transform(x6_im)) + if dir_x7 is not None and dir_x7 != "": + x7_im = PIL.Image.open(os.path.join(dir_x7, p)) + self.images_x7.append(self.transform(x7_im)) + if dir_x8 is not None and dir_x8 != "": + x8_im = PIL.Image.open(os.path.join(dir_x8, p)) + self.images_x8.append(self.transform(x8_im)) + if dir_x9 is not None and dir_x9 != "": + x9_im = PIL.Image.open(os.path.join(dir_x9, p)) + self.images_x9.append(self.transform(x9_im)) + + self.images_post.append(self.transform(postim)) # .to(self.device)) + images_mask.append(self.mask_transform(maskim).int().float().to(device)) + i += 1 + + + self.valid_indices = [] + self.valid_indices_left = [] + i = 0 + erosion_weights = torch.ones((1, 1, 7, 7)).to(device) + for m in images_mask: + m[m < 0.4] = 0 + m = F.conv2d(m.unsqueeze(0), erosion_weights, stride=1, padding=3) + m[m < erosion_weights.numel()] = 0 + m /= erosion_weights.numel() + + self.valid_indices.append(m.squeeze().nonzero(as_tuple=False).to(self.device)) + self.valid_indices_left.append(list(range(0, len(self.valid_indices[i])))) + i += 1 + + + def cut_patch(self, im, midpoint, size): + hs = size // 2 + hn = max(0, midpoint[0] - hs) + hx = min(midpoint[0] + hs, im.size()[1] - 1) + xn = max(0, midpoint[1] - hs) + xx = min(midpoint[1] + hs, im.size()[2] - 1) + + p = im[:, hn:hx, xn:xx] + if p.size()[1] != size or p.size()[2] != size: + r = torch.zeros((3, size, size)) + r[:, 0:p.size()[1], 0:p.size()[2]] = p + p = r + + return p + + # CURRENTLY NOT IN USE + def patch_diff(self, im, patch1_mid, patch2_mid, size): + patch1 = self.cut_patch(im, patch1_mid, size) + patch2 = self.cut_patch(im, patch2_mid, size) + + patch = patch1 - patch2 + patch = patch ** 2 + + sum = patch.sum() + + return sum + + def cut_patches(self, im_index, midpoint, midpoint_r, size): + patch_pre = self.cut_patch(self.images_pre[im_index], midpoint, size) + if self.geom_blur_coeff != 0.0: + geom_blur_patch = get_geometric_blur_patch(self.images_pre_geom[im_index], midpoint, size, self.geom_blur_coeff) + patch_pre = torch.cat((patch_pre, geom_blur_patch), dim=0) + + if len(self.images_x1) > 0: + patch_x1 = self.cut_patch(self.images_x1[im_index], midpoint, size) + patch_pre = torch.cat((patch_pre, patch_x1), dim=0) + if len(self.images_x2) > 0: + patch_x2 = self.cut_patch(self.images_x2[im_index], midpoint, size) + patch_pre = torch.cat((patch_pre, patch_x2), dim=0) + if 
len(self.images_x3) > 0: + patch_x3 = self.cut_patch(self.images_x3[im_index], midpoint, size) + patch_pre = torch.cat((patch_pre, patch_x3), dim=0) + if len(self.images_x4) > 0: + patch_x4 = self.cut_patch(self.images_x4[im_index], midpoint, size) + patch_pre = torch.cat((patch_pre, patch_x4), dim=0) + if len(self.images_x5) > 0: + patch_x5 = self.cut_patch(self.images_x5[im_index], midpoint, size) + patch_pre = torch.cat((patch_pre, patch_x5), dim=0) + if len(self.images_x6) > 0: + patch_x6 = self.cut_patch(self.images_x6[im_index], midpoint, size) + patch_pre = torch.cat((patch_pre, patch_x6), dim=0) + if len(self.images_x7) > 0: + patch_x7 = self.cut_patch(self.images_x7[im_index], midpoint, size) + patch_pre = torch.cat((patch_pre, patch_x7), dim=0) + if len(self.images_x8) > 0: + patch_x8 = self.cut_patch(self.images_x8[im_index], midpoint, size) + patch_pre = torch.cat((patch_pre, patch_x8), dim=0) + if len(self.images_x9) > 0: + patch_x9 = self.cut_patch(self.images_x9[im_index], midpoint, size) + patch_pre = torch.cat((patch_pre, patch_x9), dim=0) + + patch_post = self.cut_patch(self.images_post[im_index], midpoint, size) + patch_random = self.cut_patch(self.images_post[im_index], midpoint_r, size) + + return patch_pre, patch_post, patch_random + + def __getitem__(self, item): + im_index = item % len(self.images_pre) + midpoint_id = np.random.randint(0, len(self.valid_indices_left[im_index])) + midpoint_r_id = np.random.randint(0, len(self.valid_indices[im_index])) + midpoint = self.valid_indices[im_index][self.valid_indices_left[im_index][midpoint_id], :].squeeze() + midpoint_r = self.valid_indices[im_index][midpoint_r_id, :].squeeze() + + del self.valid_indices_left[im_index][midpoint_id] + if len(self.valid_indices_left[im_index]) < 1: + self.valid_indices_left[im_index] = list(range(0, len(self.valid_indices[im_index]))) + + result = {} + + for i in range(0, 1): #range(im_index - self.temporal_frames, im_index + self.temporal_frames + 1): + is_curr_item = True # if i == im_index else False + #i = max(0, i) + #i = min(len(self.images_pre)-1, i) + + patch_pre, patch_post, patch_random = self.cut_patches(im_index, midpoint, midpoint_r, self.patch_size) + + if "pre" not in result: + result['pre'] = patch_pre + else: + result['pre'] = torch.cat((result['pre'], patch_pre), dim=0) + + if is_curr_item: + result['post'] = patch_post + if is_curr_item: + result['already'] = patch_random + + return result + + def __len__(self): + return sum([(n.numel() // 2) for n in self.valid_indices]) * 5 # dont need to restart + + diff --git a/generate.py b/generate.py new file mode 100644 index 0000000..ea9dd05 --- /dev/null +++ b/generate.py @@ -0,0 +1,64 @@ +import argparse +import os +from PIL import Image +from custom_transforms import * +import numpy as np +import torch.utils.data +import time +from data import DatasetFullImages + + + +# Main to generate images +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--checkpoint", help="checkpoint location", required=True) + parser.add_argument("--data_root", help="data root", required=True) + parser.add_argument("--dir_input", help="dir input", required=True) + parser.add_argument("--dir_x1", help="dir extra 1", required=False) + parser.add_argument("--dir_x2", help="dir extra 2", required=False) + parser.add_argument("--dir_x3", help="dir extra 3", required=False) + parser.add_argument("--outdir", help="output directory", required=True) + parser.add_argument("--device", help="device", required=True) + args = 
parser.parse_args() + + generator = (torch.load(args.checkpoint, map_location=lambda storage, loc: storage)) + generator.eval() + + if not os.path.exists(args.outdir): + os.mkdir(args.outdir) + + device = args.device + print("device: " + device, flush=True) + + generator = generator.to(device).type(torch.half) + transform = build_transform() + dataset = DatasetFullImages(args.data_root + "/" + args.dir_input, "ignore", "ignore", device, + dir_x1=args.data_root + "/" + args.dir_x1 if args.dir_x1 is not None else None, + dir_x2=args.data_root + "/" + args.dir_x2 if args.dir_x2 is not None else None, + dir_x3=args.data_root + "/" + args.dir_x3 if args.dir_x3 is not None else None, + dir_x4=None, dir_x5=None, dir_x6=None, dir_x7=None, dir_x8=None, dir_x9=None) + + imloader = torch.utils.data.DataLoader(dataset, 1, shuffle=False, num_workers=1, drop_last=False) # num_workers=4 + + generate_start_time = time.time() + with torch.no_grad(): + for i, batch in enumerate(imloader): + print('Batch %d / %d' % (i, len(imloader))) + + net_in = batch['pre'].to(args.device).type(torch.half) + net_out = generator(net_in) + + #image_space_in = to_image_space(batch['image'].cpu().data.numpy()) + + #image_space = to_image_space(net_out.cpu().data.numpy()) + image_space = ((net_out.clamp(-1, 1) + 1) * 127.5).permute((0, 2, 3, 1)) + image_space = image_space.cpu().data.numpy().astype(np.uint8) + + for k in range(0, len(image_space)): + im = image_space[k] #image_space[k].transpose(1, 2, 0) + Image.fromarray(im).save(os.path.join(args.outdir, batch['file_name'][k])) + + + print(f"Generating took {(time.time() - generate_start_time)}", flush=True) + diff --git a/logger.py b/logger.py new file mode 100644 index 0000000..aefa57d --- /dev/null +++ b/logger.py @@ -0,0 +1,30 @@ +import tensorflow as tf +import os +import shutil + + +class Logger(object): + def __init__(self, log_dir, suffix=None): + """Create a summary writer logging to log_dir.""" + self.writer = tf.compat.v1.summary.FileWriter(log_dir, filename_suffix=suffix) + + def scalar_summary(self, tag, value, step): + """Log a scalar variable.""" + summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)]) + self.writer.add_summary(summary, step) + + +class ModelLogger(object): + def __init__(self, log_dir, save_func): + self.log_dir = log_dir + self.save_func = save_func + + def save(self, model, epoch, isGenerator): + if isGenerator: + new_path = os.path.join(self.log_dir, "model_%05d.pth" % epoch) + else: + new_path = os.path.join(self.log_dir, "disc_%05d.pth" % epoch) + self.save_func(model, new_path) + + def copy_file(self, source): + shutil.copy(source, self.log_dir) diff --git a/models.py b/models.py new file mode 100644 index 0000000..5e518ee --- /dev/null +++ b/models.py @@ -0,0 +1,319 @@ +import torch +import torch.nn as nn +from torch.autograd import Variable +from torchvision import models + + +class UpsamplingLayer(nn.Module): + def __init__(self, channels): + super(UpsamplingLayer, self).__init__() + self.layer = nn.Upsample(scale_factor=2) + + def forward(self, x): + return self.layer(x) + +##### +# Currently default generator we use +# conv0 -> conv1 -> conv2 -> resnet_blocks -> upconv2 -> upconv1 -> conv_11 -> (conv_11_a)* -> conv_12 -> (Tanh)* +# there are 2 conv layers inside conv_11_a +# * means is optional, model uses skip-connections +class GeneratorJ(nn.Module): + def __init__(self, input_size=256, norm_layer='batch_norm', + gpu_ids=None, use_bias=False, resnet_blocks=9, tanh=False, + filters=(64, 128, 128, 128, 128, 
64), input_channels=3, append_smoothers=False): + super(GeneratorJ, self).__init__() + self.input_size = input_size + assert norm_layer in [None, 'batch_norm', 'instance_norm'], \ + "norm_layer should be None, 'batch_norm' or 'instance_norm', not {}".format(norm_layer) + self.norm_layer = None + if norm_layer == 'batch_norm': + self.norm_layer = nn.BatchNorm2d + elif norm_layer == 'instance_norm': + self.norm_layer = nn.InstanceNorm2d + self.gpu_ids = gpu_ids + self.use_bias = use_bias + self.resnet_blocks = resnet_blocks + self.append_smoothers = append_smoothers + + self.conv0 = self.relu_layer(in_filters=input_channels, out_filters=filters[0], + size=7, stride=1, padding=3, + bias=self.use_bias, + norm_layer=self.norm_layer, + nonlinearity=nn.LeakyReLU(.2)) + + self.conv1 = self.relu_layer(in_filters=filters[0], + out_filters=filters[1], + size=3, stride=2, padding=1, + bias=self.use_bias, + norm_layer=self.norm_layer, + nonlinearity=nn.LeakyReLU(.2)) + + self.conv2 = self.relu_layer(in_filters=filters[1], + out_filters=filters[2], + size=3, stride=2, padding=1, + bias=self.use_bias, + norm_layer=self.norm_layer, + nonlinearity=nn.LeakyReLU(.2)) + + self.resnets = nn.ModuleList() + for i in range(self.resnet_blocks): + self.resnets.append( + self.resnet_block(in_filters=filters[2], + out_filters=filters[2], + size=3, stride=1, padding=1, + bias=self.use_bias, + norm_layer=self.norm_layer, + nonlinearity=nn.ReLU())) + + self.upconv2 = self.upconv_layer_upsample_and_conv(in_filters=filters[3] + filters[2], + # in_filters=filters[3], # disable skip-connections + out_filters=filters[4], + size=4, stride=2, padding=1, + bias=self.use_bias, + norm_layer=self.norm_layer, + nonlinearity=nn.ReLU()) + + self.upconv1 = self.upconv_layer_upsample_and_conv(in_filters=filters[4] + filters[1], + # in_filters=filters[4], # disable skip-connections + out_filters=filters[4], + size=4, stride=2, padding=1, + bias=self.use_bias, + norm_layer=self.norm_layer, + nonlinearity=nn.ReLU()) + + self.conv_11 = nn.Sequential( + nn.Conv2d(in_channels=filters[0] + filters[4] + input_channels, + # in_channels=filters[4], # disable skip-connections + out_channels=filters[5], + kernel_size=7, stride=1, padding=3, bias=self.use_bias), + nn.ReLU() + ) + + if self.append_smoothers: + self.conv_11_a = nn.Sequential( + nn.Conv2d(filters[5], filters[5], kernel_size=3, bias=self.use_bias, padding=1), + nn.ReLU(), + nn.BatchNorm2d(num_features=filters[5]), # replace with variable + nn.Conv2d(filters[5], filters[5], kernel_size=3, bias=self.use_bias, padding=1), + nn.ReLU() + ) + + if tanh: + self.conv_12 = nn.Sequential(nn.Conv2d(filters[5], 3, + kernel_size=1, stride=1, + padding=0, bias=True), + nn.Tanh()) + else: + self.conv_12 = nn.Conv2d(filters[5], 3, kernel_size=1, stride=1, + padding=0, bias=True) + + def forward(self, x): + output_0 = self.conv0(x) + output_1 = self.conv1(output_0) + output = self.conv2(output_1) + output_2 = self.conv2(output_1) # comment to disable skip-connections + for layer in self.resnets: + output = layer(output) + output + + # output = self.upconv2(output) # disable skip-connections + # output = self.upconv1(output) # disable skip-connections + # output = self.conv_11(output) # disable skip-connections + output = self.upconv2(torch.cat((output, output_2), dim=1)) + output = self.upconv1(torch.cat((output, output_1), dim=1)) + output = self.conv_11(torch.cat((output, output_0, x), dim=1)) + + if self.append_smoothers: + output = self.conv_11_a(output) + output = self.conv_12(output) + return 
output + + def relu_layer(self, in_filters, out_filters, size, stride, padding, bias, + norm_layer, nonlinearity): + out = nn.Sequential() + out.add_module('conv', nn.Conv2d(in_channels=in_filters, + out_channels=out_filters, + kernel_size=size, stride=stride, + padding=padding, bias=bias)) + if norm_layer: + out.add_module('normalization', + norm_layer(num_features=out_filters)) + if nonlinearity: + out.add_module('nonlinearity', nonlinearity) + return out + + def resnet_block(self, in_filters, out_filters, size, stride, padding, bias, + norm_layer, nonlinearity): + out = nn.Sequential() + if nonlinearity: + out.add_module('nonlinearity_0', nonlinearity) + out.add_module('conv_0', nn.Conv2d(in_channels=in_filters, + out_channels=out_filters, + kernel_size=size, stride=stride, + padding=padding, bias=bias)) + if norm_layer: + out.add_module('normalization', + norm_layer(num_features=out_filters)) + if nonlinearity: + out.add_module('nonlinearity_1', nonlinearity) + out.add_module('conv_1', nn.Conv2d(in_channels=in_filters, + out_channels=out_filters, + kernel_size=size, stride=stride, + padding=padding, bias=bias)) + return out + + def upconv_layer(self, in_filters, out_filters, size, stride, padding, bias, + norm_layer, nonlinearity): + out = nn.Sequential() + out.add_module('upconv', nn.ConvTranspose2d(in_channels=in_filters, + out_channels=out_filters, + kernel_size=size, # 4 + stride=stride, # 2 + padding=padding, bias=bias)) + if norm_layer: + out.add_module('normalization', + norm_layer(num_features=out_filters)) + if nonlinearity: + out.add_module('nonlinearity', nonlinearity) + return out + + def upconv_layer_upsample_and_conv(self, in_filters, out_filters, size, stride, padding, bias, + norm_layer, nonlinearity): + + parts = [UpsamplingLayer(in_filters), + nn.Conv2d(in_filters, out_filters, 3, 1, 1, bias=False)] + + if norm_layer: + parts.append(norm_layer(num_features=out_filters)) + + if nonlinearity: + parts.append(nonlinearity) + + return nn.Sequential(*parts) + + +##### +# Default discriminator +##### +class DiscriminatorN_IN(nn.Module): + def __init__(self, num_filters=64, input_channels=3, n_layers=3, + use_noise=False, noise_sigma=0.2, norm_layer='instance_norm', use_bias=True): + super(DiscriminatorN_IN, self).__init__() + + self.num_filters = num_filters + self.use_noise = use_noise + self.noise_sigma = noise_sigma + self.input_channels = input_channels + self.use_bias = use_bias + + if norm_layer == 'batch_norm': + self.norm_layer = nn.BatchNorm2d + else: + self.norm_layer = nn.InstanceNorm2d + self.net = self.make_net(n_layers, self.input_channels, 1, 4, 2, self.use_bias) + + def make_net(self, n, flt_in, flt_out=1, k=4, stride=2, bias=True): + padding = 1 + model = nn.Sequential() + + model.add_module('conv0', self.make_block(flt_in, self.num_filters, k, stride, padding, bias, None, nn.LeakyReLU)) + + flt_mult, flt_mult_prev = 1, 1 + # n - 1 blocks + for l in range(1, n): + flt_mult_prev = flt_mult + flt_mult = min(2**(l), 8) + model.add_module('conv_%d'%(l), self.make_block(self.num_filters * flt_mult_prev, self.num_filters * flt_mult, + k, stride, padding, bias, self.norm_layer, nn.LeakyReLU)) + + flt_mult_prev = flt_mult + flt_mult = min(2**n, 8) + model.add_module('conv_%d'%(n), self.make_block(self.num_filters * flt_mult_prev, self.num_filters * flt_mult, + k, 1, padding, bias, self.norm_layer, nn.LeakyReLU)) + model.add_module('conv_out', self.make_block(self.num_filters * flt_mult, 1, k, 1, padding, bias, None, None)) + return model + + def make_block(self, 
flt_in, flt_out, k, stride, padding, bias, norm, relu): + m = nn.Sequential() + m.add_module('conv', nn.Conv2d(flt_in, flt_out, k, stride=stride, padding=padding, bias=bias)) + if norm is not None: + m.add_module('norm', norm(flt_out)) + if relu is not None: + m.add_module('relu', relu(0.2, True)) + return m + + def forward(self, x): + return self.net(x), None # 2nd is class? + + +##### +# Perception VGG19 loss +##### +class PerceptualVGG19(nn.Module): + def __init__(self, feature_layers, use_normalization=True, path=None): + super(PerceptualVGG19, self).__init__() + if path is not None: + print(f'Loading pre-trained VGG19 model from {path}') + model = models.vgg19(pretrained=False) + model.classifier = nn.Sequential( + nn.Linear(512 * 8 * 8, 4096), + nn.ReLU(True), + nn.Dropout(), + nn.Linear(4096, 4096), + nn.ReLU(True), + nn.Dropout(), + nn.Linear(4096, 40), + ) + model.load_state_dict(torch.load(path)) + else: + model = models.vgg19(pretrained=True) + model.float() + model.eval() + + self.model = model + self.feature_layers = feature_layers + + self.mean = torch.FloatTensor([0.485, 0.456, 0.406]) + self.mean_tensor = None + + self.std = torch.FloatTensor([0.229, 0.224, 0.225]) + self.std_tensor = None + + self.use_normalization = use_normalization + + if torch.cuda.is_available(): + self.mean = self.mean.cuda() + self.std = self.std.cuda() + + for param in self.parameters(): + param.requires_grad = False + + def normalize(self, x): + if not self.use_normalization: + return x + + if self.mean_tensor is None: + self.mean_tensor = Variable( + self.mean.view(1, 3, 1, 1).expand(x.size()), + requires_grad=False) + self.std_tensor = Variable( + self.std.view(1, 3, 1, 1).expand(x.size()), requires_grad=False) + + x = (x + 1) / 2 + return (x - self.mean_tensor) / self.std_tensor + + def run(self, x): + features = [] + + h = x + + for f in range(max(self.feature_layers) + 1): + h = self.model.features[f](h) + if f in self.feature_layers: + not_normed_features = h.clone().view(h.size(0), -1) + features.append(not_normed_features) + + return None, torch.cat(features, dim=1) + + def forward(self, x): + h = self.normalize(x) + return self.run(h) diff --git a/train.py b/train.py new file mode 100644 index 0000000..dc62b7e --- /dev/null +++ b/train.py @@ -0,0 +1,132 @@ +import argparse +import os +import data +import models as m +import torch +import torch.optim as optim +import yaml +from logger import Logger, ModelLogger +from trainers import Trainer +import sys +import numpy as np + + +def build_model(model_type, args, device): + model = getattr(m, model_type)(**args) + return model.to(device) + + +def build_optimizer(opt_type, model, args): + args['params'] = model.parameters() + opt_class = getattr(optim, opt_type) + return opt_class(**args) + + +def build_loggers(log_folder): + if not os.path.exists(log_folder): + os.makedirs(log_folder) + model_logger = ModelLogger(log_folder, torch.save) + scalar_logger = Logger(log_folder) + return scalar_logger, model_logger + + +def worker_init_fn(worker_id): + np.random.seed(np.random.get_state()[1][0] + worker_id) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--config', '-c', help='Yaml config with training parameters', required=True) + parser.add_argument('--log_folder', '-l', help='Log folder', required=True) + parser.add_argument('--data_root', '-r', help='Data root folder', required=True) + parser.add_argument('--log_interval', '-i', type=int, help='Log interval', required=True) + args = 
parser.parse_args() + + args_log_folder = args.data_root + "/" + args.log_folder + + with open(args.config, 'r') as f: + job_description = yaml.load(f, Loader=yaml.FullLoader) + + config = job_description['job'] + scalar_logger, model_logger = build_loggers(args_log_folder) + + model_logger.copy_file(args.config) + device = config.get('device') or 'cpu' + + # Check 'training_dataset' parameters + training_dataset_parameters = set(config['training_dataset'].keys()) - \ + {"type", "dir_pre", "dir_post", "dir_mask", "patch_size", "dir_x1", "dir_x2", "dir_x3", "dir_x4", "dir_x5", "dir_x6", "dir_x7", "dir_x8", "dir_x9", } + if len(training_dataset_parameters) > 0: + raise RuntimeError("Got unexpected parameter in training_dataset: " + str(training_dataset_parameters)) + + d = dict(config['training_dataset']) + d['dir_pre'] = args.data_root + "/" + d['dir_pre'] + d['dir_post'] = args.data_root + "/" + d['dir_post'] + d['device'] = config['device'] + if 'dir_mask' in d: + d['dir_mask'] = args.data_root + "/" + d['dir_mask'] + + # complete dir_x paths and set a correct number of channels + channels = 3 + for dir_x_index in range(1, 10): + dir_x_name = f"dir_x{dir_x_index}" + d[dir_x_name] = args.data_root + "/" + d[dir_x_name] if dir_x_name in d else None + channels = channels + 3 if d[dir_x_name] is not None else channels + config['generator']['args']['input_channels'] = channels + + print(d) + + generator = build_model(config['generator']['type'], config['generator']['args'], device) + #generator = (torch.load(args.data_root + "/model_00300_style2.pth", map_location=lambda storage, loc: storage)).to(device) + opt_generator = build_optimizer(config['opt_generator']['type'], generator, config['opt_generator']['args']) + + discriminator, opt_discriminator = None, None + if 'discriminator' in config: + discriminator = build_model(config['discriminator']['type'], config['discriminator']['args'], device) + #discriminator = (torch.load(args.data_root + "/disc_00300_style2.pth", map_location=lambda storage, loc: storage)).to(device) + opt_discriminator = build_optimizer(config['opt_discriminator']['type'], discriminator, config['opt_discriminator']['args']) + + if 'type' not in d: + raise RuntimeError("Type of training_dataset must be specified!") + + dataset_type = getattr(data, d.pop('type')) + training_dataset = dataset_type(**d) + + train_loader = torch.utils.data.DataLoader(training_dataset, config['trainer']['batch_size'], shuffle=False, + num_workers=config['num_workers'], drop_last=True)#, worker_init_fn=worker_init_fn) + + reconstruction_criterion = getattr(torch.nn, config['trainer']['reconstruction_criterion'])() + adversarial_criterion = getattr(torch.nn, config['trainer']['adversarial_criterion'])() + + perception_loss_model = None + perception_loss_weight = 1 + if 'perception_loss' in config: + if 'perception_model' in config['perception_loss']: + perception_loss_model = build_model(config['perception_loss']['perception_model']['type'], + config['perception_loss']['perception_model']['args'], + device) + else: + perception_loss_model = discriminator + + perception_loss_weight = config['perception_loss']['weight'] + + trainer = Trainer( + train_loader=train_loader, + data_for_dataloader=d, # data for later dataloader creation, if needed + opt_generator=opt_generator, opt_discriminator=opt_discriminator, + adversarial_criterion=adversarial_criterion, reconstruction_criterion=reconstruction_criterion, + reconstruction_weight=config['trainer']['reconstruction_weight'], + 
adversarial_weight=config['trainer']['adversarial_weight'], + log_interval=args.log_interval, + model_logger=model_logger, scalar_logger=scalar_logger, + perception_loss_model=perception_loss_model, + perception_loss_weight=perception_loss_weight, + use_image_loss=config['trainer']['use_image_loss'], + device=device + ) + + args_config = args.config.replace('\\', '/') + args_config = args_config[args_config.rfind('/') + 1:] + trainer.train(generator, discriminator, int(config['trainer']['epochs']), args.data_root, args_config, 0) + print("Training finished", flush=True) + sys.exit(0) diff --git a/trainers.py b/trainers.py new file mode 100644 index 0000000..9be770b --- /dev/null +++ b/trainers.py @@ -0,0 +1,260 @@ +import time +import models +import numpy as np +import six +import torch +import torch.nn as nn +from torch.autograd import Variable +from PIL import Image +from custom_transforms import * +from data import DatasetFullImages +import os + + +class Trainer(object): + def __init__(self, + train_loader, data_for_dataloader, opt_discriminator, opt_generator, + reconstruction_criterion, adversarial_criterion, reconstruction_weight, + adversarial_weight, log_interval, scalar_logger, model_logger, + perception_loss_model, perception_loss_weight, use_image_loss, device + ): + + self.train_loader = train_loader + self.data_for_dataloader = data_for_dataloader + + self.opt_discriminator = opt_discriminator + self.opt_generator = opt_generator + + self.reconstruction_criterion = reconstruction_criterion + self.adversarial_criterion = adversarial_criterion + + self.reconstruction_weight = reconstruction_weight + self.adversarial_weight = adversarial_weight + + self.scalar_logger = scalar_logger + self.model_logger = model_logger + + self.training_log = {} + self.log_interval = log_interval + + self.perception_loss_weight = perception_loss_weight + self.perception_loss_model = perception_loss_model + + self.use_adversarial_loss = False + self.use_image_loss = use_image_loss + self.device = device + + self.dataset = None + self.imloader = None + + + def run_discriminator(self, discriminator, images): + return discriminator(images) + + def compute_discriminator_loss(self, generator, discriminator, batch): + generated = generator(batch['pre']) + fake = self.apply_mask(generated, batch, 'pre_mask') + fake_labels, _ = self.run_discriminator(discriminator, fake.detach()) + + true = self.apply_mask(batch['already'], batch, 'already_mask') + true_labels, _ = self.run_discriminator(discriminator, true) + + discriminator_loss = self.adversarial_criterion(fake_labels, self.zeros_like(fake_labels)) + \ + self.adversarial_criterion(true_labels, self.ones_like(true_labels)) + + return discriminator_loss + + def compute_generator_loss(self, generator, discriminator, batch, use_gan, use_mask): + image_loss = 0 + perception_loss = 0 + adversarial_loss = 0 + + generated = generator(batch['pre']) + + if use_mask: + generated = generated * batch['mask'] + batch['post'] = batch['post'] * batch['mask'] + + if self.use_image_loss: + if generated[0][0].shape != batch['post'][0][0].shape: + if ((batch['post'][0][0].shape[0] - generated[0][0].shape[0]) % 2) != 0: + raise RuntimeError("batch['post'][0][0].shape[0] - generated[0][0].shape[0] must be even number") + if generated[0][0].shape[0] != generated[0][0].shape[1] or batch['post'][0][0].shape[0] != batch['post'][0][0].shape[1]: + raise RuntimeError("And also it is expected to be exact square ... 
fix it if you want") + boundary_size = int((batch['post'][0][0].shape[0] - generated[0][0].shape[0]) / 2) + cropped_batch_post = batch['post'][:, :, boundary_size: -1*boundary_size, boundary_size: -1*boundary_size] + image_loss = self.reconstruction_criterion(generated, cropped_batch_post) + else: + image_loss = self.reconstruction_criterion(generated, batch['post']) + + if self.perception_loss_model is not None: + _, fake_features = self.perception_loss_model(generated) + _, target_features = self.perception_loss_model(Variable(batch['post'], requires_grad=False)) + perception_loss = ((fake_features - target_features) ** 2).mean() + + + if self.use_adversarial_loss and use_gan: + fake = self.apply_mask(generated, batch, 'pre_mask') + fake_smiling_labels, _ = self.run_discriminator(discriminator, fake) + adversarial_loss = self.adversarial_criterion(fake_smiling_labels, self.ones_like(fake_smiling_labels)) + + return image_loss, perception_loss, adversarial_loss, generated + + + def train(self, generator, discriminator, epochs, data_root, config_yaml_name, starting_batch_num): + self.use_adversarial_loss = discriminator is not None + batch_num = starting_batch_num + save_num = 0 + + start = time.time() + for epoch in range(epochs): + np.random.seed() + for i, batch in enumerate(self.train_loader): + # just sets the models into training mode (enable BN and DO) + [m.train() for m in [generator, discriminator] if m is not None] + batch = {k: batch[k].to(self.device) if isinstance(batch[k], torch.Tensor) else batch[k] + for k in batch.keys()} + + # train discriminator + if self.use_adversarial_loss: + self.opt_discriminator.zero_grad() + discriminator_loss = self.compute_discriminator_loss(generator, discriminator, batch) + discriminator_loss.backward() + self.opt_discriminator.step() + + # train generator + self.opt_generator.zero_grad() + + g_image_loss, g_perc_loss, g_adv_loss, _ = self.compute_generator_loss(generator, discriminator, batch, use_gan=True, use_mask=False) + + generator_loss = self.reconstruction_weight * g_image_loss + \ + self.perception_loss_weight * g_perc_loss + \ + self.adversarial_weight * g_adv_loss + + generator_loss.backward() + + self.opt_generator.step() + + # log losses + current_log = {key: value.item() for key, value in six.iteritems(locals()) if + 'loss' in key and isinstance(value, Variable)} + + self.add_log(current_log) + + batch_num += 1 + + if batch_num % 100 == 0: + print(f"Batch num: {batch_num}, totally elapsed {(time.time() - start)}", flush=True) + + #if batch_num % self.log_interval == 0 or batch_num == 1: + if batch_num % self.log_interval == 0 or batch_num == 1: # (time.time() - start) > 16: + eval_start = time.time() + generator.eval() + self.test_on_full_image(generator, batch_num, data_root, config_yaml_name) + self.flush_scalar_log(batch_num, time.time() - start) + self.model_logger.save(generator, save_num, True) + #self.model_logger.save(discriminator, save_num, False) + save_num += 1 + print(f"Eval of batch: {batch_num} took {(time.time() - eval_start)}", flush=True) + + #if batch_num > 5000: + # sys.exit(0) + + self.model_logger.save(generator, 99999) + + # Accumulates the losses + def add_log(self, log): + for k, v in log.items(): + if k in self.training_log: + self.training_log[k] += v + else: + self.training_log[k] = v + + # Divide the losses by log_interval and print'em + def flush_scalar_log(self, batch_num, took): + for key in self.training_log.keys(): + self.scalar_logger.scalar_summary(key, self.training_log[key] / 
self.log_interval, batch_num) + + log = "[%d]" % batch_num + for key in sorted(self.training_log.keys()): + log += " [%s] % 7.4f" % (key, self.training_log[key] / self.log_interval) + + log += ". Took {}".format(took) + print(log, flush=True) + self.training_log = {} + + # Test the intermediate model on data from _gen folder + def test_on_full_image(self, generator, batch_num, data_root, config_yaml_name): + config_yaml_name = config_yaml_name.replace("reference", "").replace(".yaml", "") + + data_root = data_root.replace("_train", "_gen") + if self.dataset is None: + self.dataset = DatasetFullImages(data_root + "/" + self.data_for_dataloader['dir_pre'].split("/")[-1], + "ignore", # data_root + "/" + "ebsynth", + "ignore", # data_root + "/" + "mask", + self.device, + dir_x1=data_root + "/" + self.data_for_dataloader['dir_x1'].split("/")[-1] if self.data_for_dataloader['dir_x1'] is not None else None, + dir_x2=data_root + "/" + self.data_for_dataloader['dir_x2'].split("/")[-1] if self.data_for_dataloader['dir_x2'] is not None else None, + dir_x3=data_root + "/" + self.data_for_dataloader['dir_x3'].split("/")[-1] if self.data_for_dataloader['dir_x3'] is not None else None, + dir_x4=data_root + "/" + self.data_for_dataloader['dir_x4'].split("/")[-1] if self.data_for_dataloader['dir_x4'] is not None else None, + dir_x5=data_root + "/" + self.data_for_dataloader['dir_x5'].split("/")[-1] if self.data_for_dataloader['dir_x5'] is not None else None, + dir_x6=data_root + "/" + self.data_for_dataloader['dir_x6'].split("/")[-1] if self.data_for_dataloader['dir_x6'] is not None else None, + dir_x7=data_root + "/" + self.data_for_dataloader['dir_x7'].split("/")[-1] if self.data_for_dataloader['dir_x7'] is not None else None, + dir_x8=data_root + "/" + self.data_for_dataloader['dir_x8'].split("/")[-1] if self.data_for_dataloader['dir_x8'] is not None else None, + dir_x9=data_root + "/" + self.data_for_dataloader['dir_x9'].split("/")[-1] if self.data_for_dataloader['dir_x9'] is not None else None) + self.imloader = torch.utils.data.DataLoader(self.dataset, 1, shuffle=False, num_workers=1, drop_last=False) # num_workers=4 + + with torch.no_grad(): + log = "### \n" + log = log + "[%d]" % batch_num + " " + generator_loss_on_ebsynth = 0 + for i, batch in enumerate(self.imloader): + batch = {k: batch[k].to(self.device) if isinstance(batch[k], torch.Tensor) else batch[k] + for k in batch.keys()} + g_image_loss, g_perc_loss, g_adv_loss, e_cls_loss, e_smiling_loss, gan_output =\ + 0, 0, 0, 0, 0, generator(batch['pre']) + + generator_loss = self.reconstruction_weight * g_image_loss + \ + self.perception_loss_weight * g_perc_loss + \ + self.adversarial_weight * g_adv_loss + + if True or batch['file_name'][0] != "111.png": # do not accumulate loss in train frame + generator_loss_on_ebsynth = generator_loss_on_ebsynth + generator_loss + + if True or batch['file_name'][0] in ["111.png", "101.png", "106.png", "116.png", "121.png"]: + #log = log + batch['file_name'][0] + #log = log + ": %7.4f" % generator_loss + ", " + + image_space = to_image_space(gan_output.cpu().data.numpy()) + + gt_test_ganoutput_path = data_root + "/" + "res_" + config_yaml_name + if not os.path.exists(gt_test_ganoutput_path): + os.mkdir(gt_test_ganoutput_path) + gt_test_ganoutput_path_batch_num = gt_test_ganoutput_path + "/" + str("%07d" % batch_num) + if not os.path.exists(gt_test_ganoutput_path_batch_num): + os.mkdir(gt_test_ganoutput_path_batch_num) + for k in range(0, len(image_space)): + im = image_space[k].transpose(1, 2, 0) + 
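+                        # each generated validation frame is written into res_<config-suffix>/<batch_num>/;
+                        # the first frame (i == 0) is also copied next to it as <batch_num>.png for a quick preview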
Image.fromarray(im).save(os.path.join(gt_test_ganoutput_path_batch_num, batch['file_name'][k])) + if i == 0: + Image.fromarray(im).save(os.path.join(gt_test_ganoutput_path, str("%07d" % batch_num) + ".png")) + + log = log + " totalLossOnEbsynth: %7.4f" % (generator_loss_on_ebsynth/(len(self.imloader))) + print(log, flush=True) + + + def apply_mask(self, x, batch, mask_key): + if mask_key in batch: + mask = Variable(batch[mask_key].expand(x.size()), requires_grad=False) + return x * (mask / 2 + 0.5) + return x + + def ones_like(self, x): + return torch.ones_like(x).to(self.device) + + def zeros_like(self, x): + return torch.zeros_like(x).to(self.device) + + @staticmethod + def to_image_space(x): + return ((np.clip(x, -1, 1) + 1) / 2 * 255).astype(np.uint8)