Skip to content

Commit

Permalink
optimize memory requirements
Browse files Browse the repository at this point in the history
  • Loading branch information
katjaschwarz committed Dec 10, 2020
1 parent cf73337 commit 19663a3
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 5 deletions.
4 changes: 2 additions & 2 deletions graf/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def build_models(config, disc=True):

config_nerf = Namespace(**config['nerf'])
# Update config for NERF
config_nerf.chunk = config['training']['chunk']
config_nerf.chunk = min(config['training']['chunk'], 1024*config['training']['batch_size']) # let batch size for training with patches limit the maximal memory
config_nerf.netchunk = config['training']['netchunk']
config_nerf.white_bkgd = config['data']['white_bkgd']
config_nerf.feat_dim = config['z_dist']['dim']
Expand All @@ -143,7 +143,7 @@ def build_models(config, disc=True):
ray_sampler=ray_sampler,
render_kwargs_train=render_kwargs_train, render_kwargs_test=render_kwargs_test,
parameters=params, named_parameters=named_parameters,
chunk=config['training']['chunk'],
chunk=config_nerf.chunk,
range_u=(float(config['data']['umin']), float(config['data']['umax'])),
range_v=(float(config['data']['vmin']), float(config['data']['vmax'])),
orthographic=config['data']['orthographic'],
Expand Down
6 changes: 4 additions & 2 deletions submodules/nerf_pytorch/run_nerf_helpers_mod.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from functools import partial

# TODO: remove this dependency
from torchsearchsorted import searchsorted
Expand All @@ -12,6 +13,7 @@
img2mse = lambda x, y : torch.mean((x - y) ** 2)
mse2psnr = lambda x : -10. * torch.log(x) / torch.log(torch.Tensor([10.]))
to8b = lambda x : (255*np.clip(x,0,1)).astype(np.uint8)
relu = partial(F.relu, inplace=True) # saves a lot of memory


# Positional encoding (section 5.1)
Expand Down Expand Up @@ -101,7 +103,7 @@ def forward(self, x):
h = input_pts
for i, l in enumerate(self.pts_linears):
h = self.pts_linears[i](h)
h = F.relu(h)
h = relu(h)
if i in self.skips:
h = torch.cat([input_pts, h], -1)

Expand All @@ -112,7 +114,7 @@ def forward(self, x):

for i, l in enumerate(self.views_linears):
h = self.views_linears[i](h)
h = F.relu(h)
h = relu(h)

rgb = self.rgb_linear(h)
outputs = torch.cat([rgb, alpha], -1)
Expand Down
5 changes: 4 additions & 1 deletion submodules/nerf_pytorch/run_nerf_mod.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import torch.nn.functional as F
# from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from functools import partial

import matplotlib.pyplot as plt

Expand All @@ -18,6 +19,8 @@
np.random.seed(0)
DEBUG = False

relu = partial(F.relu, inplace=True) # saves a lot of memory


def batchify(fn, chunk):
if chunk is None:
Expand Down Expand Up @@ -232,7 +235,7 @@ def create_nerf(args):
def raw2outputs(raw, z_vals, rays_d, raw_noise_std=0, white_bkgd=False, pytest=False):
""" A helper function for `render_rays`.
"""
raw2alpha = lambda raw, dists, act_fn=F.relu: 1.-torch.exp(-act_fn(raw)*dists)
raw2alpha = lambda raw, dists, act_fn=relu: 1.-torch.exp(-act_fn(raw)*dists)

dists = z_vals[...,1:] - z_vals[...,:-1]
dists = torch.cat([dists, torch.Tensor([1e10]).expand(dists[...,:1].shape)], -1) # [N_rays, N_samples]
Expand Down

0 comments on commit 19663a3

Please sign in to comment.