


This is the official implementation of ViewFusion: Learning Composable Diffusion Models for Novel View Synthesis.
@misc{spiegl2024viewfusion,
title={ViewFusion: Learning Composable Diffusion Models for Novel View Synthesis},
author={Bernard Spiegl and Andrea Perin and Stéphane Deny and Alexander Ilin},
year={2024},
eprint={2402.02906},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
You can install and activate the conda environment by simply running:
conda env create -f environment.yml
conda activate view-fusion
For ARM-based macOS (not tested extensively), run:
conda env create -f environment_osx.yml
conda activate view-fusion
The version of the NMR ShapeNet dataset we use is hosted by Niemeyer et al. and can be downloaded here.
Please note that our current setup is optimized for use in a cluster computing environment and requires the dataset to be sharded.
To ensure correct placement (in data/nmr/), you can download the dataset using fetch_dataset.sh.
Afterwards, to shard the dataset, run:
python data/dataset_prep.py
The default sharding splits the dataset into four shards. To enable parallelization, the number of shards has to be divisible by the number of GPUs you use.
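To catch a mismatch before launching a run, the divisibility requirement can be checked up front. The sketch below is a minimal illustration only, assuming a simple round-robin assignment of shards to GPU ranks. `assign_shards` is a hypothetical helper and is not part of the repository. The actual node splitting lives in data/nmr_dataset.py and may distribute shards differently.

```python
from __future__ import annotations


def assign_shards(num_shards: int, num_gpus: int) -> dict[int, list[int]]:
    """Hypothetical helper: split shard indices evenly across GPU ranks.

    Assumes round-robin assignment; the repository's own nodesplitter
    in data/nmr_dataset.py may distribute shards differently.
    """
    if num_gpus <= 0:
        raise ValueError("num_gpus must be positive")
    if num_shards % num_gpus != 0:
        raise ValueError(
            f"{num_shards} shards cannot be split evenly across {num_gpus} GPUs"
        )
    # Rank r receives shards r, r + num_gpus, r + 2 * num_gpus, ...
    return {rank: list(range(rank, num_shards, num_gpus)) for rank in range(num_gpus)}


if __name__ == "__main__":
    # Default setup: four shards on four GPUs -> one shard per GPU.
    print(assign_shards(4, 4))  # {0: [0], 1: [1], 2: [2], 3: [3]}
    # Four shards on two GPUs -> two shards per GPU.
    print(assign_shards(4, 2))  # {0: [0, 2], 1: [1, 3]}
```

With three GPUs, the default four shards would raise an error. In that case, re-shard the dataset with a shard count that is a multiple of three.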
Configurations for various experiments can be found in configs/.
To launch training on a single GPU run:
python main.py -c configs/small-v100.yaml -g -t --wandb
For a distributed training setup run:
torchrun --nnodes=$NUM_NODES --nproc_per_node=$NUM_GPUS main.py -c configs/small-v100-4.yaml -g -t --wandb
where $NUM_NODES and $NUM_GPUS can, for instance, be replaced by 1 and 4, respectively. This corresponds to a single-node setup with four V100 GPUs. Please note that all of the experiments were run on a single-node setup; multi-node environments have not been tested thoroughly.
(If you are using Slurm, some example scripts are available in slurm/.)
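torchrun starts NUM_NODES × NUM_GPUS worker processes in total. It exposes each worker's position through the environment variables RANK, LOCAL_RANK and WORLD_SIZE. The sketch below only illustrates how such values can be read. When the variables are absent, as with the single-GPU `python main.py` command, it falls back to a single process. `read_dist_env` and `DistEnv` are hypothetical names. The repository's own distributed helpers are in utils/dist.py.

```python
import os
from typing import Mapping, NamedTuple


class DistEnv(NamedTuple):
    rank: int        # global process index across all nodes
    local_rank: int  # process index on this node (typically the GPU index)
    world_size: int  # total number of processes = NUM_NODES * NUM_GPUS


def read_dist_env(env: Mapping[str, str] = os.environ) -> DistEnv:
    """Read the variables torchrun sets for each worker process.

    Falls back to a single-process setup when they are missing,
    e.g. when the script is launched with plain `python`.
    """
    return DistEnv(
        rank=int(env.get("RANK", "0")),
        local_rank=int(env.get("LOCAL_RANK", "0")),
        world_size=int(env.get("WORLD_SIZE", "1")),
    )


if __name__ == "__main__":
    # Simulated environment of the last worker in a 1-node x 4-GPU launch.
    fake_env = {"RANK": "3", "LOCAL_RANK": "3", "WORLD_SIZE": "4"}
    print(read_dist_env(fake_env))  # DistEnv(rank=3, local_rank=3, world_size=4)
    print(read_dist_env({}))        # DistEnv(rank=0, local_rank=0, world_size=1)
```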
Inference mode supports a variety of visualization options, each enabled by its corresponding flag:
-gif produces an animated generation around the axis, along with the weights, as shown in Figure 2.
-ar produces an animated autoregressive generation, as shown in Figure 3.
-ex performs extrapolation beyond the six input views given at training time.
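If you are wiring up a similar script, the sketch below shows how single-dash, multi-letter switches like the ones listed above can be declared with Python's argparse. It is a hypothetical illustration only. It declares just the three visualization flags, and neither the parser nor the help strings are taken from the repository's main.py.

```python
import argparse


def build_visualization_parser() -> argparse.ArgumentParser:
    """Hypothetical parser covering only the three visualization switches."""
    parser = argparse.ArgumentParser(description="Visualization flags (sketch)")
    # argparse accepts multi-letter option strings with a single leading dash;
    # the attribute name is derived by stripping the dash ("-gif" -> args.gif).
    parser.add_argument("-gif", action="store_true",
                        help="animated generation around the axis with weights")
    parser.add_argument("-ar", action="store_true",
                        help="animated autoregressive generation")
    parser.add_argument("-ex", action="store_true",
                        help="extrapolation beyond six input views")
    return parser


if __name__ == "__main__":
    args = build_visualization_parser().parse_args(["-gif", "-ar"])
    print(args.gif, args.ar, args.ex)  # True True False
```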
Pretrained model weights are available here via Hugging Face. To run the model with the provided inference script, fetch the weights by running fetch_checkpoint.sh.
Inference can be performed on a saved checkpoint by running:
python main.py -g -i -s ./logs/pretrained --wandb -gif -ar
which produces GIFs like those shown in Figures 2 and 3. The outputs are logged to Weights & Biases.
Samples are drawn at random from the validation visualization dataloader.
If you want to implement your own data pipeline or training procedure, all architecture details are available in model/.
At training time, the model receives:
y_0, the target (ground truth), of shape (B, C, H, W);
y_cond, which contains all the input views and is of shape (B, N, C, H, W), where N denotes the total number of views (24 in our case);
view_count, of shape (B,), which contains the number of views used as conditioning for each sample in the batch;
angle, also of shape (B,), indicating the target angle for each sample.
At inference time, y_0 is omitted; everything else remains the same as during training.
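The sketch below builds dummy inputs with the shapes listed above, for example to smoke-test a custom data pipeline. It uses NumPy arrays as stand-ins for the tensors the model actually consumes. Several values are illustrative assumptions, not values read from the repository: the batch size, the 64×64 RGB image size, and the 1 to 6 range for view_count (chosen to match the six input views mentioned above). Encoding angle as radians is also an assumption. `make_dummy_batch` is a hypothetical helper.

```python
from __future__ import annotations

import numpy as np


def make_dummy_batch(
    batch_size: int = 2,
    num_views: int = 24,      # N, the total number of views
    channels: int = 3,        # assumption: RGB images
    height: int = 64,         # assumption: 64x64 renders
    width: int = 64,
    max_cond_views: int = 6,  # assumption: 1 to 6 conditioning views
    training: bool = True,
    seed: int = 0,
) -> dict[str, np.ndarray]:
    """Hypothetical helper: random inputs with the documented shapes."""
    rng = np.random.default_rng(seed)
    batch = {
        # All N input views for every sample: (B, N, C, H, W).
        "y_cond": rng.standard_normal(
            (batch_size, num_views, channels, height, width)
        ).astype(np.float32),
        # Number of conditioning views per sample: (B,).
        "view_count": rng.integers(1, max_cond_views + 1, size=batch_size),
        # Target angle per sample: (B,); radians in [0, 2*pi) is an assumption.
        "angle": rng.uniform(0.0, 2 * np.pi, size=batch_size).astype(np.float32),
    }
    if training:
        # Ground-truth target view: (B, C, H, W); omitted at inference time.
        batch["y_0"] = rng.standard_normal(
            (batch_size, channels, height, width)
        ).astype(np.float32)
    return batch


if __name__ == "__main__":
    for training in (True, False):
        batch = make_dummy_batch(training=training)
        mode = "training" if training else "inference"
        print(mode, {name: array.shape for name, array in batch.items()})
```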
See the paper for full implementation details.
NB: The training configurations require a significant amount of VRAM.
The model referenced in the paper was trained using the configs/multi-view-composable-variable-small-v100-4.yaml configuration for 710k steps (approx. 6.5 days) on 4x V100 GPUs, each with 32 GB of VRAM.
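For planning purposes, these figures imply an average throughput of roughly 1.26 training steps per second, or about 0.79 s per step, on that 4x V100 setup. The calculation below only rearranges the numbers quoted above. Because 6.5 days is itself approximate, the result is approximate too.

```python
STEPS = 710_000
DAYS = 6.5
SECONDS_PER_DAY = 24 * 60 * 60

total_seconds = DAYS * SECONDS_PER_DAY    # 561,600 s
steps_per_second = STEPS / total_seconds  # ~1.264 steps/s
seconds_per_step = total_seconds / STEPS  # ~0.791 s/step

print(f"{steps_per_second:.2f} steps/s, {seconds_per_step:.2f} s/step")
# prints: 1.26 steps/s, 0.79 s/step
```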
view-fusion
├── configs                 # various experiment configurations
├── data                    # everything related to data preparation and loading
│   ├── __init__.py
│   ├── dataset_prep.py     # script to shard the dataset
│   └── nmr_dataset.py      # sample processing, dataloaders, nodesplitters
├── logs                    # default logging directory
│   └── pretrained          # default pretrained model directory
│       └── config.yaml     # pretrained model configuration
├── model                   # everything model related
│   ├── unet.py             # U-Net architecture (used for denoising)
│   └── view_fusion.py      # DDPM and composable weighting logic
├── slurm                   # example Slurm scripts
├── utils                   # various utilities
│   ├── __init__.py
│   ├── checkpoint.py       # checkpointing logic
│   ├── compute_metrics.py  # computes metrics on a directory containing all generated test samples
│   ├── dist.py             # distributed training helpers
│   ├── metrics.py          # SSIM and PSNR functions
│   └── schedulers.py       # learning rate scheduler
├── .gitignore
├── LICENSE
├── README.md
├── environment.yml
├── environment_osx.yml
├── experiment.py           # full experiment logic, including training, validation and inference
├── fetch_dataset.sh        # script for downloading the dataset
├── fetch_pretrained.sh     # script for downloading pretrained model weights
└── main.py                 # entry point for training and inference