Commit b6d7217
Merge remote-tracking branch 'origin/main' into main
katjaschwarz committed Jan 22, 2021
2 parents e53ff20 + 743c1f7
Showing 8 changed files with 49 additions and 12 deletions.
1 change: 1 addition & 0 deletions configs/default.yaml
@@ -43,6 +43,7 @@ training:
   outdir: ./results
   model_file: model.pt
   monitoring: tensorboard
+  use_amp: False # Use automated mixed precision
   nworkers: 6
   batch_size: 8
   chunk: 32768 # 1024*32
4 changes: 1 addition & 3 deletions data/download_carla.sh
@@ -1,5 +1,3 @@
mkdir -p ./carla
cd ./carla
wget https://s3.eu-central-1.amazonaws.com/avg-projects/graf/data/carla.zip
unzip carla.zip
cd ..
cd ..
4 changes: 2 additions & 2 deletions graf/config.py
@@ -120,7 +120,7 @@ def build_models(config, disc=True):

     config_nerf = Namespace(**config['nerf'])
     # Update config for NERF
-    config_nerf.chunk = config['training']['chunk']
+    config_nerf.chunk = min(config['training']['chunk'], 1024*config['training']['batch_size']) # let batch size for training with patches limit the maximal memory
     config_nerf.netchunk = config['training']['netchunk']
     config_nerf.white_bkgd = config['data']['white_bkgd']
     config_nerf.feat_dim = config['z_dist']['dim']
@@ -143,7 +143,7 @@ def build_models(config, disc=True):
         ray_sampler=ray_sampler,
         render_kwargs_train=render_kwargs_train, render_kwargs_test=render_kwargs_test,
         parameters=params, named_parameters=named_parameters,
-        chunk=config['training']['chunk'],
+        chunk=config_nerf.chunk,
         range_u=(float(config['data']['umin']), float(config['data']['umax'])),
         range_v=(float(config['data']['vmin']), float(config['data']['vmax'])),
         orthographic=config['data']['orthographic'],
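With the defaults shown in configs/default.yaml above (batch_size: 8, chunk: 32768), the new cap shrinks the per-forward ray budget; a quick check of the arithmetic:

batch_size = 8                        # training.batch_size in configs/default.yaml
chunk = 32768                         # training.chunk (1024*32)
print(min(chunk, 1024 * batch_size))  # -> 8192 rays per forward pass instead of 32768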
33 changes: 33 additions & 0 deletions graf/gan_training.py
@@ -3,12 +3,45 @@
 import os
 from tqdm import tqdm

+from submodules.GAN_stability.gan_training.train import toggle_grad, Trainer as TrainerBase
 from submodules.GAN_stability.gan_training.eval import Evaluator as EvaluatorBase
 from submodules.GAN_stability.gan_training.metrics import FIDEvaluator, KIDEvaluator

 from .utils import save_video, color_depth_map


+class Trainer(TrainerBase):
+    def __init__(self, *args, use_amp=False, **kwargs):
+        super(Trainer, self).__init__(*args, **kwargs)
+        self.use_amp = use_amp
+        if self.use_amp:
+            self.scaler = torch.cuda.amp.GradScaler()
+
+    def generator_trainstep(self, y, z):
+        if not self.use_amp:
+            return super(Trainer, self).generator_trainstep(y, z)
+        assert (y.size(0) == z.size(0))
+        toggle_grad(self.generator, True)
+        toggle_grad(self.discriminator, False)
+        self.generator.train()
+        self.discriminator.train()
+        self.g_optimizer.zero_grad()
+
+        with torch.cuda.amp.autocast():
+            x_fake = self.generator(z, y)
+            d_fake = self.discriminator(x_fake, y)
+            gloss = self.compute_loss(d_fake, 1)
+        self.scaler.scale(gloss).backward()
+
+        self.scaler.step(self.g_optimizer)
+        self.scaler.update()
+
+        return gloss.item()
+
+    def discriminator_trainstep(self, x_real, y, z):
+        return super(Trainer, self).discriminator_trainstep(x_real, y, z)  # spectral norm raises an error when using amp
+
+
 class Evaluator(EvaluatorBase):
     def __init__(self, eval_fid_kid, *args, **kwargs):
         super(Evaluator, self).__init__(*args, **kwargs)
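For reference, the scaler/autocast pattern in the new generator step is the stock torch.cuda.amp recipe; a minimal self-contained sketch with a toy model and optimizer standing in for the GAN components (names here are illustrative, not from the repo):

import torch
import torch.nn as nn

model = nn.Linear(64, 1).cuda()                   # toy stand-in for the generator
opt = torch.optim.Adam(model.parameters(), lr=1e-4)
scaler = torch.cuda.amp.GradScaler()

for _ in range(10):
    x = torch.randn(8, 64, device='cuda')
    opt.zero_grad()
    with torch.cuda.amp.autocast():               # forward pass runs in fp16 where safe
        loss = model(x).pow(2).mean()
    scaler.scale(loss).backward()                 # scale the loss so fp16 grads don't underflow
    scaler.step(opt)                              # unscales grads; skips the step on inf/nan
    scaler.update()                               # adapts the scale factor for the next step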
4 changes: 2 additions & 2 deletions submodules/nerf_pytorch/run_nerf.py
@@ -365,7 +365,7 @@ def config_parser():
     parser.add_argument("--netwidth", type=int, default=256, help='channels per layer')
     parser.add_argument("--netdepth_fine", type=int, default=8, help='layers in fine network')
     parser.add_argument("--netwidth_fine", type=int, default=256, help='channels per layer in fine network')
-    parser.add_argument("--N_samples", type=int, default=32*32*4, help='batch size (number of random rays per gradient step)')
+    parser.add_argument("--N_rand", type=int, default=32*32*4, help='batch size (number of random rays per gradient step)')
     parser.add_argument("--lrate", type=float, default=5e-4, help='learning rate')
     parser.add_argument("--lrate_decay", type=int, default=250, help='exponential learning rate decay (in 1000 steps)')
     parser.add_argument("--chunk", type=int, default=1024*32, help='number of rays processed in parallel, decrease if running out of memory')
@@ -545,7 +545,7 @@ def train():
         return

     # Prepare raybatch tensor if batching random rays
-    N_rand = args.N_samples
+    N_rand = args.N_rand
     use_batching = not args.no_batching
     if use_batching:
         # For random ray batching
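The rename is more than cosmetic: in NeRF codebases --N_samples conventionally means samples per ray, so reusing it for the ray batch size invites a collision. If both meanings ever land on the same parser, argparse rejects the duplicate flag outright — a tiny standalone illustration, not repo code:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--N_rand", type=int, default=32*32*4)   # rays per gradient step
parser.add_argument("--N_samples", type=int, default=64)     # samples per ray
# registering "--N_samples" a second time would raise:
# argparse.ArgumentError: argument --N_samples: conflicting option string: --N_samples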
6 changes: 4 additions & 2 deletions submodules/nerf_pytorch/run_nerf_helpers_mod.py
@@ -3,6 +3,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 import numpy as np
+from functools import partial

 # TODO: remove this dependency
 from torchsearchsorted import searchsorted
@@ -12,6 +13,7 @@
 img2mse = lambda x, y : torch.mean((x - y) ** 2)
 mse2psnr = lambda x : -10. * torch.log(x) / torch.log(torch.Tensor([10.]))
 to8b = lambda x : (255*np.clip(x,0,1)).astype(np.uint8)
+relu = partial(F.relu, inplace=True) # saves a lot of memory


 # Positional encoding (section 5.1)
@@ -101,7 +103,7 @@ def forward(self, x):
         h = input_pts
         for i, l in enumerate(self.pts_linears):
             h = self.pts_linears[i](h)
-            h = F.relu(h)
+            h = relu(h)
             if i in self.skips:
                 h = torch.cat([input_pts, h], -1)

@@ -112,7 +114,7 @@ def forward(self, x):

         for i, l in enumerate(self.views_linears):
             h = self.views_linears[i](h)
-            h = F.relu(h)
+            h = relu(h)

         rgb = self.rgb_linear(h)
         outputs = torch.cat([rgb, alpha], -1)
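What the partial buys: F.relu(x, inplace=True) overwrites its input instead of allocating a fresh tensor, which matters when the activations are large batches of ray samples. A small standalone check (not repo code):

import torch
import torch.nn.functional as F
from functools import partial

relu = partial(F.relu, inplace=True)

x = torch.randn(1024, 256)
y = relu(x)                           # returns x itself, mutated in place
assert y.data_ptr() == x.data_ptr()   # no second activation buffer was allocated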
5 changes: 4 additions & 1 deletion submodules/nerf_pytorch/run_nerf_mod.py
@@ -9,6 +9,7 @@
 import torch.nn.functional as F
 # from torch.utils.tensorboard import SummaryWriter
 from tqdm import tqdm
+from functools import partial

 import matplotlib.pyplot as plt

@@ -18,6 +19,8 @@
 np.random.seed(0)
 DEBUG = False

+relu = partial(F.relu, inplace=True) # saves a lot of memory
+

 def batchify(fn, chunk):
     if chunk is None:
@@ -232,7 +235,7 @@ def raw2outputs(raw, z_vals, rays_d, raw_noise_std=0, white_bkgd=False, pytest=False):
 def raw2outputs(raw, z_vals, rays_d, raw_noise_std=0, white_bkgd=False, pytest=False):
     """ A helper function for `render_rays`.
     """
-    raw2alpha = lambda raw, dists, act_fn=F.relu: 1.-torch.exp(-act_fn(raw)*dists)
+    raw2alpha = lambda raw, dists, act_fn=relu: 1.-torch.exp(-act_fn(raw)*dists)

     dists = z_vals[...,1:] - z_vals[...,:-1]
     dists = torch.cat([dists, torch.Tensor([1e10]).expand(dists[...,:1].shape)], -1) # [N_rays, N_samples]
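The swapped-in relu leaves the compositing math untouched: alpha = 1 - exp(-act_fn(raw) * dists). A toy evaluation of that formula (standalone sketch, values made up):

import torch

raw = torch.tensor([0.5, 2.0, -1.0])    # per-sample raw densities from the network
dists = torch.tensor([0.1, 0.1, 1e10])  # spacing between samples; the last is the 1e10 pad above
alpha = 1. - torch.exp(-torch.relu(raw) * dists)
# relu zeroes the negative density, so the padded last sample stays transparent (alpha == 0)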
4 changes: 2 additions & 2 deletions train.py
@@ -13,13 +13,12 @@
 import sys
 sys.path.append('submodules') # needed to make imports work in GAN_stability

-from graf.gan_training import Evaluator as Evaluator
+from graf.gan_training import Trainer, Evaluator
 from graf.config import get_data, build_models, save_config, update_config, build_lr_scheduler
 from graf.utils import count_trainable_parameters, get_nsamples
 from graf.transforms import ImgToPatch

 from GAN_stability.gan_training import utils
-from GAN_stability.gan_training.train import Trainer
 from GAN_stability.gan_training.train import update_average
 from GAN_stability.gan_training.logger import Logger
 from GAN_stability.gan_training.checkpoints import CheckpointIO
@@ -208,6 +207,7 @@
     # Trainer
     trainer = Trainer(
         generator, discriminator, g_optimizer, d_optimizer,
+        use_amp=config['training']['use_amp'],
         gan_type=config['training']['gan_type'],
         reg_type=config['training']['reg_type'],
         reg_param=config['training']['reg_param']
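Taken together, mixed-precision training becomes a one-line config switch; presumably enabling it looks like this, with the key name taken from the default.yaml diff above:

training:
  use_amp: True

The generator step then runs under autocast with gradient scaling, while the discriminator step deliberately stays in full precision because spectral norm raises an error under amp.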
