refactor: training config updated
bagxi committed Feb 4, 2022
1 parent 02615d3 commit ae74542
Showing 1 changed file with 89 additions and 84 deletions.

config.yml
@@ -6,9 +6,9 @@ model:
   _key_value: true
 
   &generator_model generator:
-    _target_: esrgan.model.EncoderDecoderNet
+    _target_: esrgan.models.EncoderDecoderNet
     encoder:
-      _target_: esrgan.model.module.ESREncoder
+      _target_: esrgan.models.ESREncoder
       in_channels: &num_channels 3
       out_channels: &latent_channels 64
       num_basic_blocks: 16
@@ -20,27 +20,27 @@ model:
         inplace: true
       residual_scaling: 0.2
     decoder:
-      _target_: esrgan.model.module.SRResNetDecoder
+      _target_: esrgan.models.ESRNetDecoder
       in_channels: *latent_channels
       out_channels: *num_channels
       scale_factor: *upscale
       activation: *activation
 
   &discriminator_model discriminator:
-    _target_: esrgan.model.VGGConv
+    _target_: esrgan.models.VGGConv
     encoder:
-      _target_: esrgan.model.module.StridedConvEncoder
+      _target_: esrgan.models.StridedConvEncoder
     pool:
       _target_: catalyst.contrib.layers.AdaptiveAvgPool2d
       output_size: [7,7]
     head:
-      _target_: esrgan.model.module.LinearHead
+      _target_: esrgan.models.LinearHead
       in_channels: 25088 # 512 * (7x7)
       out_channels: 1
       latent_channels: [1024]
 
 args:
-  logdir: ./logs/cata21.06/14-esrgan_x4_192ps
+  logdir: logs
 
 runner:
   _target_: esrgan.runner.GANConfigRunner
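Every `_target_` entry above names an importable class that the config runner builds, passing the sibling keys as constructor arguments; nested `_target_` blocks (such as the encoder's activation) are built first. A minimal sketch of that idea, assuming nothing about Catalyst's actual registry code:

```python
# Minimal sketch (not Catalyst's real implementation) of how a
# `_target_`-style config entry can be turned into a live object.
import importlib

def instantiate(config):
    """Recursively build the object described by a `_target_` mapping."""
    if not isinstance(config, dict) or "_target_" not in config:
        return config  # plain values (ints, lists, anchors) pass through
    module_path, _, class_name = config["_target_"].rpartition(".")
    cls = getattr(importlib.import_module(module_path), class_name)
    # Every other key becomes a constructor argument, built recursively.
    kwargs = {k: instantiate(v) for k, v in config.items() if k != "_target_"}
    return cls(**kwargs)

# e.g. instantiate({"_target_": "torch.nn.LeakyReLU",
#                   "negative_slope": 0.2, "inplace": True})
```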
@@ -49,59 +49,71 @@ runner:
 
 stages:
   stage1_supervised:
-    num_epochs: 40
+    num_epochs: 10000
 
     loaders: &loaders
       train: &train_loader
         _target_: torch.utils.data.DataLoader
-        dataset: &train_dataset
-          _target_: esrgan.dataset.DIV2KDataset
-          root: data
-          train: true
-          target_type: bicubic_X4
-          patch_size: [*patch_size,*patch_size]
-          transform:
-            _target_: albumentations.Compose
-            transforms:
-              - &spatial_transforms
-                _target_: albumentations.Compose
-                transforms:
-                  - _target_: albumentations.HorizontalFlip
-                    p: 0.5
-                additional_targets:
-                  real_image: image
-              - &hard_transforms
-                _target_: albumentations.Compose
-                transforms:
-                  - _target_: albumentations.CoarseDropout
-                    max_holes: 2
-                    max_height: 2
-                    max_width: 2
-                  - _target_: albumentations.ImageCompression
-                    quality_lower: 65
-                    p: 0.25
-              - &post_transforms
+        dataset:
+          _target_: torch.utils.data.ConcatDataset
+          datasets:
+            - &div2k_dataset
+              _target_: esrgan.datasets.DIV2KDataset
+              root: data
+              train: true
+              target_type: bicubic_X4
+              patch_size: [*patch_size,*patch_size]
+              transform:
                 _target_: albumentations.Compose
                 transforms:
-                  - _target_: albumentations.Normalize
-                    mean: 0
-                    std: 1
-                  - _target_: albumentations.ToTensorV2
-                additional_targets:
-                  real_image: image
-          low_resolution_image_key: image
-          high_resolution_image_key: real_image
-          download: true
-        batch_size: 48
+                  - &spatial_transforms
+                    _target_: albumentations.Compose
+                    transforms:
+                      - _target_: albumentations.OneOf
+                        transforms:
+                          - _target_: albumentations.Flip
+                            p: 0.75 # p = 1/4 (vflip) + 1/4 (hflip) + 1/4 (flip)
+                          - _target_: albumentations.Transpose
+                            p: 0.25 # p = 1/4
+                        p: 0.5
+                    additional_targets:
+                      real_image: image
+                  - &hard_transforms
+                    _target_: albumentations.Compose
+                    transforms:
+                      - _target_: albumentations.CoarseDropout
+                        max_holes: 8
+                        max_height: 2
+                        max_width: 2
+                      - _target_: albumentations.ImageCompression
+                        quality_lower: 65
+                        p: 0.25
+                  - &post_transforms
+                    _target_: albumentations.Compose
+                    transforms:
+                      - _target_: albumentations.Normalize
+                        mean: 0
+                        std: 1
+                      - _target_: albumentations.ToTensorV2
+                    additional_targets:
+                      real_image: image
+              low_resolution_image_key: image
+              high_resolution_image_key: real_image
+              download: true
+
+            - &flickr2k_dataset
+              << : [*div2k_dataset] # Flickr2K with the same params as in `DIV2KDataset`
+              _target_: esrgan.datasets.Flickr2K
+        batch_size: 16
         shuffle: true
         num_workers: 8
         pin_memory: true
         drop_last: true
 
       valid:
         << : [*train_loader]
-        dataset:
-          << : [*train_dataset]
+        dataset: # redefine dataset to use only DIV2K
+          << : [*div2k_dataset]
           train: false
           transform: *post_transforms
         batch_size: 1
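The reworked `dataset:` block concatenates DIV2K and Flickr2K so one loader samples from both. A self-contained sketch, with dummy tensor datasets standing in for the `esrgan.datasets` classes:

```python
# torch.utils.data.ConcatDataset behavior behind the new `dataset:` block.
import torch
from torch.utils.data import ConcatDataset, DataLoader, TensorDataset

div2k = TensorDataset(torch.randn(800, 3, 32, 32))      # stand-in for DIV2KDataset
flickr2k = TensorDataset(torch.randn(2650, 3, 32, 32))  # stand-in for Flickr2K

train_dataset = ConcatDataset([div2k, flickr2k])
# Lengths add up; indices are mapped to the underlying dataset transparently.
assert len(train_dataset) == len(div2k) + len(flickr2k)

# Mirrors the loader arguments from the config above.
loader = DataLoader(train_dataset, batch_size=16, shuffle=True,
                    num_workers=8, pin_memory=True, drop_last=True)
```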
@@ -115,22 +127,20 @@ stages:
     optimizer:
       _key_value: true
 
-      generator_optimizer:
-        _target_: torch.optim.AdamW
-        lr_linear_scaling:
-          lr: 0.0002
-          base_batch_size: &base_batch_size 16
+      generator:
+        _target_: torch.optim.Adam
+        lr: 0.0002
         weight_decay: 0.0
         _model: *generator_model
 
     scheduler:
       _key_value: true
 
-      generator_scheduler:
-        _target_: torch.optim.lr_scheduler.MultiStepLR
-        milestones: [8,20,28]
+      generator:
+        _target_: torch.optim.lr_scheduler.StepLR
+        step_size: 500
         gamma: 0.5
-        _optimizer: generator_optimizer
+        _optimizer: generator
 
     callbacks: &callbacks
       psnr_metric:
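Stage 1 swaps hand-picked MultiStepLR milestones for a StepLR that halves the Adam learning rate every 500 epochs. In plain PyTorch (stand-in model, hypothetical loop):

```python
# Plain-PyTorch sketch of the stage-1 schedule above: Adam at lr=2e-4,
# halved every 500 epochs by StepLR.
import torch

model = torch.nn.Linear(8, 8)  # stand-in for the generator
optimizer = torch.optim.Adam(model.parameters(), lr=0.0002, weight_decay=0.0)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=500, gamma=0.5)

for epoch in range(1, 1501):
    optimizer.step()   # training steps for the epoch would go here
    scheduler.step()   # lr: 2e-4 -> 1e-4 (epoch 500) -> 5e-5 (1000) -> 2.5e-5 (1500)
```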
@@ -168,46 +178,45 @@ stages:
         _target_: catalyst.callbacks.OptimizerCallback
         metric_key: loss_content
         model_key: *generator_model
-        optimizer_key: generator_optimizer
+        optimizer_key: generator
         grad_clip_fn: &grad_clip_fn
           _mode_: partial
           _target_: torch.nn.utils.clip_grad_value_
           clip_value: 5.0
 
       scheduler_generator:
         _target_: catalyst.callbacks.SchedulerCallback
-        scheduler_key: generator_scheduler
+        scheduler_key: generator
         loader_key: valid
         metric_key: loss_content
 
   stage2_gan:
-    num_epochs: 16
+    num_epochs: 8000
 
     loaders:
       << : [*loaders]
       train:
         << : [*train_loader]
         dataset:
-          << : [*train_dataset]
+          << : [*div2k_dataset]
          transform:
            _target_: albumentations.Compose
            transforms:
              - *spatial_transforms
              - *post_transforms
-        batch_size: 32
+        batch_size: 16
 
     criterion:
       << : [*criterions]
 
       perceptual_loss:
-        _target_: esrgan.criterions.PerceptualLoss
+        _target_: esrgan.nn.PerceptualLoss
         layers:
           conv5_4: 1.0
 
-      # TODO: fix GAN losses
       adversarial_generator_loss:
-        # `esrgan.criterions.RelativisticAdversarialLoss` | `esrgan.criterions.AdversarialLoss`
-        _target_: &adversarial_criterion esrgan.criterions.RelativisticAdversarialLoss
+        # `esrgan.nn.RelativisticAdversarialLoss` | `esrgan.nn.AdversarialLoss`
+        _target_: &adversarial_criterion esrgan.nn.RelativisticAdversarialLoss
         mode: generator
       adversarial_discriminator_loss:
         _target_: *adversarial_criterion
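`esrgan.nn.RelativisticAdversarialLoss` points at the relativistic average GAN formulation used by ESRGAN (Jolicoeur-Martineau, 2018), where each side is scored relative to the average score of the other side. A hedged sketch of that formulation, not the library's exact code:

```python
# Relativistic average GAN loss sketch: the discriminator asks "is the real
# image more realistic than the average fake?", and vice versa.
import torch
import torch.nn.functional as F

def relativistic_loss(real_logits, fake_logits, mode="generator"):
    real_vs_fake = real_logits - fake_logits.mean()
    fake_vs_real = fake_logits - real_logits.mean()
    ones = torch.ones_like(real_logits)
    zeros = torch.zeros_like(real_logits)
    if mode == "generator":
        # The generator wants fakes rated above the average real.
        return (F.binary_cross_entropy_with_logits(real_vs_fake, zeros)
                + F.binary_cross_entropy_with_logits(fake_vs_real, ones)) / 2
    # The discriminator wants the opposite ordering.
    return (F.binary_cross_entropy_with_logits(real_vs_fake, ones)
            + F.binary_cross_entropy_with_logits(fake_vs_real, zeros)) / 2
```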
@@ -216,36 +225,32 @@ stages:
     optimizer:
       _key_value: true
 
-      generator_optimizer:
+      generator:
         _target_: torch.optim.AdamW
-        lr_linear_scaling:
-          lr: 0.00003
-          base_batch_size: *base_batch_size
+        lr: 0.0001
         weight_decay: 0.0
         _model: *generator_model
 
-      discriminator_optimizer:
+      discriminator:
         _target_: torch.optim.AdamW
-        lr_linear_scaling:
-          lr: 0.0001
-          base_batch_size: *base_batch_size
+        lr: 0.0001
         weight_decay: 0.0
         _model: *discriminator_model
 
     scheduler:
       _key_value: true
 
-      generator_scheduler:
+      generator:
         _target_: torch.optim.lr_scheduler.MultiStepLR
-        milestones: [16,24,32]
+        milestones: &scheduler_milestones [1000,2000,4000,6000]
         gamma: 0.5
-        _optimizer: generator_optimizer
+        _optimizer: generator
 
-      discriminator_scheduler:
+      discriminator:
         _target_: torch.optim.lr_scheduler.MultiStepLR
-        milestones: [8,16,24,32]
+        milestones: *scheduler_milestones
        gamma: 0.5
-        _optimizer: discriminator_optimizer
+        _optimizer: discriminator
 
     callbacks:
       # re-use `psnr_metric`, `ssim_metric`, and `loss_content` callbacks
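Thanks to the `&scheduler_milestones` anchor, generator and discriminator now decay in lockstep: each halves at the same epochs, so after epoch 6000 both run at 0.5**4 = 1/16 of the initial 1e-4. Roughly, per optimizer:

```python
# Sketch of the shared stage-2 schedule: AdamW at 1e-4, halved at the
# anchored milestones; stand-in parameters instead of the real models.
import torch

params = [torch.nn.Parameter(torch.zeros(1))]  # stand-in parameters
optimizer = torch.optim.AdamW(params, lr=0.0001, weight_decay=0.0)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=[1000, 2000, 4000, 6000], gamma=0.5)
```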
@@ -269,7 +274,7 @@ stages:
         metrics:
           loss_content: 0.01
           loss_perceptual: 1.0
-          loss_adversarial: 0.05
+          loss_adversarial: 0.005
         mode: weighted_sum
 
     loss_discriminator:
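With `mode: weighted_sum`, the generator objective is the weighted sum of the three metrics above; the new 0.005 weight makes the adversarial term ten times weaker than before. As a sketch (dummy scalars in place of the computed losses):

```python
# What `mode: weighted_sum` amounts to for the generator objective
# (a sketch of the aggregation, not Catalyst's callback code).
import torch

loss_content, loss_perceptual, loss_adversarial = (torch.rand(()) for _ in range(3))
loss_generator = (0.01 * loss_content
                  + 1.0 * loss_perceptual
                  + 0.005 * loss_adversarial)
```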
@@ -283,22 +288,22 @@ stages:
         _target_: catalyst.callbacks.OptimizerCallback
         metric_key: *generator_loss
         model_key: *generator_model
-        optimizer_key: generator_optimizer
+        optimizer_key: generator
         grad_clip_fn: *grad_clip_fn
       optimizer_discriminator:
         _target_: catalyst.callbacks.OptimizerCallback
         metric_key: *discriminator_loss
         model_key: *discriminator_model
-        optimizer_key: discriminator_optimizer
+        optimizer_key: discriminator
         grad_clip_fn: *grad_clip_fn
 
       scheduler_generator:
         _target_: catalyst.callbacks.SchedulerCallback
-        scheduler_key: generator_scheduler
+        scheduler_key: generator
         loader_key: valid
         metric_key: *generator_loss
       scheduler_discriminator:
         _target_: catalyst.callbacks.SchedulerCallback
-        scheduler_key: discriminator_scheduler
+        scheduler_key: discriminator
         loader_key: valid
         metric_key: *discriminator_loss
