Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolai256 authored Oct 29, 2022
1 parent eaec966 commit 2594acf
Showing 1 changed file with 97 additions and 8 deletions.
105 changes: 97 additions & 8 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
import torch.optim as optim
import yaml
from logger import Logger, ModelLogger
from trainers import Trainer
from trainers import *
import sys
import numpy as np

#from trainers import *

def build_model(model_type, args, device):
model = getattr(m, model_type)(**args)
Expand Down Expand Up @@ -42,10 +42,96 @@ def worker_init_fn(worker_id):
parser.add_argument('--log_interval', '-i', type=int, help='Log interval', required=True)
parser.add_argument('--resume', '-rs', type=str, help='resume', required=False)
parser.add_argument('--projectname', type=str, help='name of the project_', required=True)

parser.add_argument('--logpath', type=str, help='path where your training happens',default = 'logs')

parser.add_argument('--perception_loss_weight', help='reconstruction_weight', default=6.0)
parser.add_argument('--reconstruction_weight', help='reconstruction_weight', default=4.)
parser.add_argument('--adversarial_weight', help='adversarial_weight', default=0.5)
parser.add_argument('--append_smoothers', type=str, help='movement smoothing',default = 'True', choices=['True', 'False'])
parser.add_argument('--filters_layers', type=str, help='test',default = '326412812812864', choices=['326412812812864', '323232323232','326464646464' ])
parser.add_argument('--patch_size', type=str, help='test',default = '16', choices=[ '16','32','64','128'])
parser.add_argument('--use_normalization', type=str, help='test',default = 'False', choices=['True', 'False'])
parser.add_argument('--use_image_loss', type=str, help='test',default = 'True', choices=['True', 'False'])
parser.add_argument('--tanh', type=str, help='test',default = 'True', choices=['True', 'False'])
parser.add_argument('--use_bias', type=str, help='test',default = 'True', choices=['True', 'False'])
args = parser.parse_args()

perception_loss_weight = args.perception_loss_weight
adversarial_weight = args.adversarial_weight
reconstruction_weight = args.reconstruction_weight
perception_loss_weight2 = perception_loss_weight.replace( "'", "",2)
adversarial_weight2 = adversarial_weight.replace( "'", "",2)
reconstruction_weight2 = reconstruction_weight.replace( "'", "",2)

#yaml config changer
import yaml

fname = args.config

import pathlib
from ruamel.yaml import YAML
import ruamel
#import ruamel.yaml
def set_state(state):
#yaml = ruamel.yaml.YAML()
yaml = YAML()
mf = pathlib.Path(fname)
yaml.default_flow_style = True
doc = yaml.load(mf)#, Loader=ruamel.yaml.RoundTripLoader)#, ruamel.yaml.preserve_quotes=False)
if args.append_smoothers == 'False':
doc['generator']['args']['append_smoothers'] = False
elif args.append_smoothers == 'True':
doc['generator']['args']['append_smoothers'] = True
if args.filters_layers == '326412812812864':
doc['generator']['args']['filters']= [32, 64, 128, 128, 128, 64]
elif args.filters_layers == '323232323232':
doc['generator']['args']['filters']= [32, 32, 32, 32, 32, 32]
elif args.filters_layers == '323232323232':
doc['generator']['args']['filters']= [32, 32, 32, 32, 32, 32]
elif args.filters_layers == '326464646464' :
doc['generator']['args']['filters']= [32, 64, 64, 64, 64, 64]
if args.patch_size == '128':
doc['training_dataset']['patch_size']= 128
elif args.patch_size == '16':
doc['training_dataset']['patch_size']= 16
elif args.patch_size == '32':
doc['training_dataset']['patch_size']= 32
elif args.patch_size == '64':
doc['training_dataset']['patch_size']= 64
if args.use_normalization == 'False':
doc['perception_loss']['perception_model']['args']['use_normalization'] = False
elif args.use_normalization == 'True':
doc['perception_loss']['perception_model']['args']['use_normalization'] = True
if args.use_image_loss == 'False':
doc['trainer']['use_image_loss'] = False
elif args.use_image_loss == 'True':
doc['trainer']['use_image_loss'] = True
if args.tanh == 'False':
doc['generator']['args']['tanh'] = False
elif args.tanh == 'True':
doc['generator']['args']['tanh'] = True
if args.use_bias == 'False':
doc['generator']['args']['use_bias'] = False
elif args.use_bias == 'True':
doc['generator']['args']['use_bias'] = True
doc['trainer']['adversarial_weight'] = [adversarial_weight]
doc['trainer']['reconstruction_weight'] = [reconstruction_weight]
#doc['trainer']['adversarial_weight'] = 0.5
#doc['trainer']['reconstruction_weight'] = 40.
yaml.dump(doc, mf)
set_state(2)
####################################################

"""import sys
from ruamel.yaml import YAML
import pathlib
fname = args.config
mf = pathlib.Path(fname)
yaml1 = YAML()
data = yaml1.load(yaml_str)
yaml1.dump(data, mf)"""

if args.logpath:
data_path = args.logpath
Expand All @@ -57,8 +143,10 @@ def worker_init_fn(worker_id):

with open(args.config, 'r') as f:
job_description = yaml.load(f, Loader=yaml.FullLoader)



config = job_description['job']

scalar_logger, model_logger = build_loggers(args_log_folder)

model_logger.copy_file(args.config)
Expand Down Expand Up @@ -120,15 +208,16 @@ def worker_init_fn(worker_id):
else:
perception_loss_model = discriminator

perception_loss_weight = config['perception_loss']['weight']

perception_loss_weight = int(float(perception_loss_weight2))
adversarial_weight3 = int(float(adversarial_weight2))
reconstruction_weight3 = int(float(reconstruction_weight2))
trainer = Trainer(
train_loader=train_loader,
data_for_dataloader=d, # data for later dataloader creation, if needed
opt_generator=opt_generator, opt_discriminator=opt_discriminator,
adversarial_criterion=adversarial_criterion, reconstruction_criterion=reconstruction_criterion,
reconstruction_weight=config['trainer']['reconstruction_weight'],
adversarial_weight=config['trainer']['adversarial_weight'],
reconstruction_weight=adversarial_weight3,#config['trainer']['reconstruction_weight'],#args.reconstruction_weight,#config['trainer']['reconstruction_weight'],
adversarial_weight=reconstruction_weight3,#config['trainer']['adversarial_weight'],#args.adversarial_weight,#config['trainer']['adversarial_weight'],
log_interval=args.log_interval,
model_logger=model_logger, scalar_logger=scalar_logger,
perception_loss_model=perception_loss_model,
Expand Down

0 comments on commit 2594acf

Please sign in to comment.