Commit 7ffd877: initial commit
katjaschwarz committed Nov 30, 2020
1 parent f053288
Showing 119 changed files with 18,445 additions and 2 deletions.
168 changes: 166 additions & 2 deletions README.md
@@ -1,4 +1,168 @@
# GRAF
Official code release for "GRAF: Generative Radiance Fields for 3D-Aware Image Synthesis"

<div style="text-align: center">
<img src="animations/carla_256.gif" width="512"/><br>
</div>

This repository contains official code for the paper
[GRAF: Generative Radiance Fields for 3D-Aware Image Synthesis](https://avg.is.tuebingen.mpg.de/publications/schwarz2020neurips).

You can find detailed usage instructions for training your own models and using pre-trained models below.


If you find our code or paper useful, please consider citing

    @inproceedings{Schwarz2020NEURIPS,
        title = {GRAF: Generative Radiance Fields for 3D-Aware Image Synthesis},
        author = {Schwarz, Katja and Liao, Yiyi and Niemeyer, Michael and Geiger, Andreas},
        booktitle = {Advances in Neural Information Processing Systems (NeurIPS)},
        year = {2020}
    }

## Installation
First you have to make sure that you have all dependencies in place.
The simplest way to do so is to use [anaconda](https://www.anaconda.com/).

You can create an anaconda environment called `graf` using
```
conda env create -f environment.yaml
conda activate graf
```

Next, install `torchsearchsorted` for nerf-pytorch. Note that this requires `torch>=1.4.0` and `CUDA >= v10.1`.
You can install torchsearchsorted via
```
cd submodules/nerf_pytorch
pip install -r requirements.txt
cd torchsearchsorted
pip install .
cd ../../../
```
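To verify the install, a quick import check (a minimal sanity check, assuming the packages above went into the active `graf` environment) should run without errors:
```
python -c "import torch, torchsearchsorted; print(torch.__version__, torch.cuda.is_available())"
```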

## Demo

You can now test our code via:
```
python eval.py configs/carla.yaml --pretrained --rotation_elevation
```
This script should create a folder `results/carla_128_from_pretrained/eval/` where you can find generated videos with varying camera pose for the Cars dataset.

## Datasets

If you only want to generate images using our pretrained models, you do not need to download the datasets.
The datasets are only needed if you want to train a model from scratch.

### Cars

To download the Cars dataset used in the paper, simply run
```
cd data
./download_carla.sh
cd ..
```
This creates a folder `data/carla/` and downloads the images as a zip file.
Next, extract the images to `data/carla/`.
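For example, if the downloaded archive is called `carla.zip` (the exact filename may differ), extraction could look like:
```
cd data
unzip carla.zip -d carla/
cd ..
```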

### Faces

Download [celebA](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html).
Then replace `data/celebA` in `configs/celebA.yaml` with `*PATH/TO/CELEBA*/Img/img_align_celebA`.

Download [celebA_hq](https://github.com/tkarras/progressive_growing_of_gans).
Then replace `data/celebA_hq` in `configs/celebAHQ.yaml` with `*PATH/TO/CELEBA_HQ*`.
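As a sketch, you can patch the `datadir` entry of both config files in place with `sed`; the paths below are placeholders for your actual dataset locations:
```
sed -i 's|^  datadir: .*|  datadir: /home/me/datasets/CelebA/Img/img_align_celebA|' configs/celebA.yaml
sed -i 's|^  datadir: .*|  datadir: /home/me/datasets/celebA_hq|' configs/celebAHQ.yaml
```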

### Cats
Download the [CatDataset](https://www.kaggle.com/crawford/cat-dataset).
Run
```
cd data
python preprocess_cats.py PATH/TO/CATS/DATASET
cd ..
```
to preprocess the data and save it to `data/cats`.
If successful, this script should print: `Preprocessed 9407 images.`

### Birds
Download [CUB-200-2011](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html) and the corresponding [Segmentation Masks](https://drive.google.com/file/d/1EamOKGLoTuZdtcVYbHMWNpkn3iAVj8TP/view).
Run
```
cd data
python preprocess_cub.py PATH/TO/CUB-200-2011 PATH/TO/SEGMENTATION/MASKS
cd ..
```
to preprocess the data and save it to `data/cub`.
If successful, this script should print: `Preprocessed 8444 images.`

## Usage

When you have installed all dependencies, you are ready to run our pre-trained models for 3D-aware image synthesis.

### Generate images using a pretrained model

To evaluate a pretrained model, run
```
python eval.py CONFIG.yaml --pretrained --fid_kid --rotation_elevation --shape_appearance
```
where you replace `CONFIG.yaml` with one of the config files in `./configs`.

This script should create a folder `results/EXPNAME/eval` with FID and KID scores in `fid_kid.csv`, videos for rotation and elevation in the respective folders, and an interpolation between shape and appearance in `shape_appearance.png`.

Note that some pretrained models are available for different image sizes, which you can choose by setting `data:imsize` in the config file to one of the following values:
```
configs/carla.yaml:
    data:imsize 64 or 128 or 256 or 512
configs/celebA.yaml:
    data:imsize 64 or 128
configs/celebAHQ.yaml:
    data:imsize 256 or 512
```
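For instance, to render the Cars model at 256x256 you could switch `imsize` before evaluating; the `sed` call below is just a shortcut for editing the file by hand:
```
sed -i 's/imsize: 128/imsize: 256/' configs/carla.yaml
python eval.py configs/carla.yaml --pretrained --rotation_elevation
```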

### Train a model from scratch

To train a 3D-aware generative model from scratch, run
```
python train.py CONFIG.yaml
```
where you replace `CONFIG.yaml` with your config file.
The easiest way is to use one of the existing config files in the `./configs` directory,
which correspond to the experiments presented in the paper, or to copy one as a starting point, as sketched below.
Note that this always trains the model from scratch and will not resume training from a pretrained model.
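For a custom experiment, a reasonable workflow (the file name `my_carla.yaml` is just an example) is:
```
cp configs/carla.yaml configs/my_carla.yaml
# edit expname and the data settings in configs/my_carla.yaml, then:
python train.py configs/my_carla.yaml
```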

You can monitor the training process on <http://localhost:6006> using [tensorboard](https://www.tensorflow.org/guide/summaries_and_tensorboard):
```
cd OUTPUT_DIR
tensorboard --logdir ./monitoring --port 6006
```
where you replace `OUTPUT_DIR` with the respective output directory.

For available training options, please take a look at `configs/default.yaml`.

### Evaluation of a new model

To evaluate a trained model, run
```
python eval.py CONFIG.yaml --fid_kid --rotation_elevation --shape_appearance
```
where you replace `CONFIG.yaml` with your config file.

## Multi-View Consistency Check

You can evaluate the multi-view consistency of the generated images by running a Multi-View Stereo (MVS) algorithm on them. This evaluation uses [COLMAP](https://colmap.github.io/); make sure that you have COLMAP installed, then run
```
python eval.py CONFIG.yaml --reconstruction
```
where you replace `CONFIG.yaml` with your config file. You can also evaluate our pretrained models via:
```
python eval.py configs/carla.yaml --pretrained --reconstruction
```
This script should create a folder `results/EXPNAME/eval/reconstruction/` where you can find generated multi-view images in `images/` and the corresponding 3D reconstructions in `models/`.
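COLMAP needs to be available on your `PATH` for this step; assuming a standard command-line install, a quick check is:
```
which colmap && colmap help
```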

## Further Information

### GAN training

This repository uses Lars Mescheder's awesome framework for [GAN training](https://github.com/LMescheder/GAN_stability).

### NeRF

We base our code for the generator on this great [PyTorch reimplementation](https://github.com/yenchenlin/nerf-pytorch) of Neural Radiance Fields.
Binary file added animations/carla_256.gif
9 changes: 9 additions & 0 deletions configs/carla.yaml
@@ -0,0 +1,9 @@
expname: carla_128
data:
  imsize: 128
  datadir: data/carla
  type: carla
  radius: 10.
  near: 7.5
  far: 12.5
  fov: 30.0
15 changes: 15 additions & 0 deletions configs/cats.yaml
@@ -0,0 +1,15 @@
expname: cats_64
data:
  datadir: data/cats
  type: cats
  imsize: 64
  white_bkgd: False
  radius: 10
  near: 7.5
  far: 12.5
  fov: 10
  umin: 0
  umax: 0.19444444444444445 # 70 deg
  vmin: 0.32898992833716556 # 70 deg
  vmax: 0.45642212862617093 # 85 deg

15 changes: 15 additions & 0 deletions configs/celebA.yaml
@@ -0,0 +1,15 @@
expname: celebA_64
data:
  datadir: /PATH/TO/CELEBA/Img/img_align_celebA
  type: celebA
  imsize: 64
  white_bkgd: False
  radius: 9.5,10.5
  near: 7.5
  far: 12.5
  fov: 10.
  umin: 0
  umax: 0.25
  vmin: 0.32898992833716556 # 70 deg
  vmax: 0.45642212862617093 # 85 deg

19 changes: 19 additions & 0 deletions configs/celebAHQ.yaml
@@ -0,0 +1,19 @@
expname: celebAHQ_256
data:
  datadir: /PATH/TO/CELEBA_HQ
  type: celebA_hq
  imsize: 256
  white_bkgd: False
  radius: 9.5,10.5
  near: 7.5
  far: 12.5
  fov: 10
  umin: 0
  umax: 0.25
  vmin: 0.32898992833716556 # 70 deg
  vmax: 0.45642212862617093 # 85 deg
ray_sampler:
  min_scale: 0.125
  scale_anneal: 0.0019
training:
  fid_every: 10000
15 changes: 15 additions & 0 deletions configs/cub.yaml
@@ -0,0 +1,15 @@
expname: cub_64
data:
  imsize: 64
  datadir: data/cub
  type: cub
  radius: 9,11
  near: 7.5
  far: 12.5
  fov: 30
  vmin: 0.24999999999999994 # 60 deg
  vmax: 0.5435778713738291 # 95 deg
discriminator:
  hflip: True
nerf:
  use_viewdirs: False
13 changes: 13 additions & 0 deletions configs/debug.yaml
@@ -0,0 +1,13 @@
expname: debug
data:
  imsize: 64
  datadir: data/cub
  type: cub
  radius: 10.
  near: 7.5
  far: 12.5
  fov: 30.0
training:
  batch_size: 2
  nworkers: 0
  fid_every: -1
70 changes: 70 additions & 0 deletions configs/default.yaml
@@ -0,0 +1,70 @@
expname: default
data:
  datadir: data/carla
  type: carla
  imsize: 64
  white_bkgd: True
  near: 1.
  far: 6.
  radius: 3.4 # set according to near and far plane
  fov: 90.
  orthographic: False
  umin: 0. # 0 deg, convert to degrees via 360. * u
  umax: 1. # 360 deg, convert to degrees via 360. * u
  vmin: 0. # 0 deg, convert to degrees via arccos(1 - 2 * v) * 180. / pi
  vmax: 0.45642212862617093 # 85 deg, convert to degrees via arccos(1 - 2 * v) * 180. / pi
nerf:
  i_embed: 0
  use_viewdirs: True
  multires: 10
  multires_views: 4
  N_samples: 64
  N_importance: 0
  netdepth: 8
  netwidth: 256
  netdepth_fine: 8
  netwidth_fine: 256
  perturb: 1.
  raw_noise_std: 1.
  decrease_noise: True
z_dist:
  type: gauss
  dim: 256
  dim_appearance: 128 # This dimension is subtracted from "dim"
ray_sampler:
  min_scale: 0.25
  max_scale: 1.
  scale_anneal: 0.0025 # no effect if scale_anneal < 0; otherwise the minimum scale decreases exponentially until it converges to min_scale
  N_samples: 1024 # 32*32, patch size
discriminator:
  ndf: 64
  hflip: False # Randomly flip discriminator input horizontally
training:
  outdir: ./results
  model_file: model.pt
  monitoring: tensorboard
  nworkers: 6
  batch_size: 8
  chunk: 32768 # 1024*32
  netchunk: 65536 # 1024*64
  lr_g: 0.0005
  lr_d: 0.0001
  lr_anneal: 0.5
  lr_anneal_every: 50000,100000,200000
  equalize_lr: False
  gan_type: standard
  reg_type: real
  reg_param: 10.
  optimizer: rmsprop
  n_test_samples_with_same_shape_code: 4
  take_model_average: true
  model_average_beta: 0.999
  model_average_reinit: false
  restart_every: -1
  save_best: fid
  fid_every: 5000 # Valid for FID and KID
  print_every: 10
  sample_every: 500
  save_every: 900
  backup_every: 50000
  video_every: 10000
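The `umin/umax/vmin/vmax` comments above give the mapping from the uniformly sampled `u`, `v` to camera angles. As a quick plain-Python check of those formulas, the default `vmax` indeed corresponds to a polar angle of 85 degrees:
```
python -c "
import math
u, v = 1.0, 0.45642212862617093
print(360. * u)                               # azimuth: 360.0 degrees
print(math.acos(1 - 2 * v) * 180. / math.pi)  # polar angle: ~85.0 degrees
"
```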
15 changes: 15 additions & 0 deletions configs/pretrained_models.yaml
@@ -0,0 +1,15 @@
carla:
  64: https://s3.eu-central-1.amazonaws.com/avg-projects/graf/models/carla/carla_64.pt
  128: https://s3.eu-central-1.amazonaws.com/avg-projects/graf/models/carla/carla_128.pt
  256: https://s3.eu-central-1.amazonaws.com/avg-projects/graf/models/carla/carla_256.pt
  512: https://s3.eu-central-1.amazonaws.com/avg-projects/graf/models/carla/carla_512.pt
celebA:
  64: https://s3.eu-central-1.amazonaws.com/avg-projects/graf/models/faces/celebA_64.pt
  128: https://s3.eu-central-1.amazonaws.com/avg-projects/graf/models/faces/celebA_128.pt
celebA_hq:
  256: https://s3.eu-central-1.amazonaws.com/avg-projects/graf/models/faces/celebA_hq_256.pt
  512: https://s3.eu-central-1.amazonaws.com/avg-projects/graf/models/faces/celebA_hq_512.pt
cats:
  64: https://s3.eu-central-1.amazonaws.com/avg-projects/graf/models/cats/cats_64.pt
cub:
  64: https://s3.eu-central-1.amazonaws.com/avg-projects/graf/models/birds/cub_64.pt
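If you want a local copy of a checkpoint, e.g. for offline use, the files can also be downloaded directly. Where `eval.py --pretrained` expects them is not specified here, so the target directory below is just an example:
```
mkdir -p pretrained
wget -P pretrained https://s3.eu-central-1.amazonaws.com/avg-projects/graf/models/carla/carla_128.pt
```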