Overview | Setup | Reproducing key results | Contributing | Citation
This repository is the official implementation of A Unified Framework for U-Net Design and Analysis.
This repository contains four self-contained sub-repositories (in the folders `diff_mnist`, `diff_cifar`, `pdearena` and `wmh`). Their corresponding original code bases are:
- `diff_mnist`: Generative modelling with diffusion models on MNIST. Original code base: https://github.com/JTT94/torch_ddpm
- `diff_cifar`: Generative modelling with diffusion models on CIFAR10. Original code base: https://github.com/w86763777/pytorch-ddpm
- `pdearena`: PDE modelling on Navier-Stokes and Shallow water. Original code base: https://github.com/microsoft/pdearena
- `wmh`: Image segmentation on the White Matter Hyperintensity (WMH) Segmentation Challenge. Original code base: https://github.com/hongweilibran/wmh_ibbmTum
We refer to Appendix E in the paper for more details on the existing code and other assets we used and built on.
Accompanying each sub-repository, we provide a pip requirements file named `requirements_<sub-repository-name>.txt`.
It contains the dependencies that we used in our working setup for that repository.
We recommend setting up a separate virtual environment for each of the four code bases.
Once you set up the virtual environment, to install the requirements, run:
pip install -r requirements_<sub-repository-name>.txt
For example, to set up the diffusion models on MNIST code base, run
pip install -r requirements_diff_mnist.txt
after setting up your virtual environment. We refer to the original code bases for further instructions on installation.
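The per-repository setup described above can be sketched as a small script that lists, for each sub-repository, the command to create a virtual environment and the command to install its requirements file. This is only an illustration: the `.venvs` directory, the helper name `setup_commands`, and the assumption that the requirements files are reachable from the current working directory are ours, not part of the repository.

```python
from pathlib import Path

# The four sub-repositories named in this README.
SUB_REPOS = ["diff_mnist", "diff_cifar", "pdearena", "wmh"]


def setup_commands(repo, env_root=".venvs"):
    """Return the two commands that set up one sub-repository.

    `env_root` is a hypothetical location for the virtual environments.
    """
    env_dir = Path(env_root) / repo
    # POSIX layout; on Windows the interpreter lives in Scripts/python.exe.
    env_python = env_dir / "bin" / "python"
    return [
        ["python", "-m", "venv", str(env_dir)],
        [str(env_python), "-m", "pip", "install", "-r", f"requirements_{repo}.txt"],
    ]


if __name__ == "__main__":
    # Print the commands instead of running them, so nothing is installed by accident.
    for repo in SUB_REPOS:
        for command in setup_commands(repo):
            print(" ".join(command))
```

Printing the commands first makes it easy to check the paths before running them by hand or through `subprocess.run`.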
MNIST
This dataset will be downloaded automatically upon running the training script the first time.
CIFAR
Download the [CIFAR-10 python version](https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz) from https://www.cs.toronto.edu/~kriz/cifar.html. Extract the archive to obtain the folder `cifar-10-batches-py`, and move it to `diff_cifar/data`.
Navier Stokes dataset (PDEArena)
The data is [available on HuggingFace](https://huggingface.co/pdearena). You require the Navier Stokes - 2D Standard dataset. To download it with SSH:
# Make sure you have git-lfs installed (https://git-lfs.com)
git lfs install
git clone git@hf.co:datasets/pdearena/NavierStokes-2D
# if you want to clone without large files – just their pointers
# prepend your git clone with the following env var:
GIT_LFS_SKIP_SMUDGE=1
Once downloaded, pass the path of the dataset to the training command (see below) using the `--data.data_dir=<data-dir>` flag.
For instance, if your dataset has been downloaded to `data/NavierStokes2D_smoke`, use the command line argument `--data.data_dir=data/NavierStokes2D_smoke`.
We refer to https://microsoft.github.io/pdearena/datadownload/ for further details and the command for using HTTPS to download the data as an alternative.
Shallow water dataset (PDEArena)
The data is [available on HuggingFace](https://huggingface.co/pdearena). You require the Shallow water - 2D dataset. To download it with SSH:
# Make sure you have git-lfs installed (https://git-lfs.com)
git lfs install
git clone git@hf.co:datasets/pdearena/ShallowWater-2D
# if you want to clone without large files – just their pointers
# prepend your git clone with the following env var:
GIT_LFS_SKIP_SMUDGE=1
Once downloaded, pass the path of the dataset to the training command (see below) using the `--data.data_dir=<data-dir>` flag.
For instance, if your dataset has been downloaded to `data/ShallowWater2D`, use the command line argument `--data.data_dir=data/ShallowWater2D`.
We refer to https://microsoft.github.io/pdearena/datadownload/ for further details and the command for using HTTPS to download the data as an alternative.
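As a further alternative to the git-based download, both PDEArena datasets could in principle be fetched from Python. The sketch below assumes the third-party `huggingface_hub` package is installed and uses its `snapshot_download` function; this is not the method used in this repository. The repository IDs are taken from the clone URLs above, while the helper names and the `data/<name>` target directory are our own choices, so the resulting folder may need to be renamed or pointed to explicitly via `--data.data_dir`.

```python
from pathlib import Path

# Dataset names as they appear in the git clone URLs above.
PDEARENA_DATASETS = ("NavierStokes-2D", "ShallowWater-2D")


def hf_repo_id(name):
    """Map a dataset name to its HuggingFace repository ID."""
    if name not in PDEARENA_DATASETS:
        raise ValueError(f"unknown dataset: {name}")
    return f"pdearena/{name}"


def download(name, root="data"):
    """Download one dataset into <root>/<name> (hypothetical layout)."""
    # Imported lazily so the other helpers work without huggingface_hub.
    from huggingface_hub import snapshot_download

    local_dir = Path(root) / name
    snapshot_download(
        repo_id=hf_repo_id(name),
        repo_type="dataset",
        local_dir=str(local_dir),
    )
    return local_dir


def data_dir_flag(path):
    """Format the flag that tells the training script where the data is."""
    return f"--data.data_dir={path}"


if __name__ == "__main__":
    # Show the flags only; call download(name) explicitly to fetch the data.
    for name in PDEARENA_DATASETS:
        print(hf_repo_id(name), data_dir_flag(Path("data") / name))
```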
WMH dataset
Go to [https://dataverse.nl/dataset.xhtml?persistentId=doi:10.34894/AECRSD](https://dataverse.nl/dataset.xhtml?persistentId=doi:10.34894/AECRSD) -> Access Dataset -> Download ZIP. Accept the dataset terms when prompted; the download should start shortly thereafter. Unzip the downloaded folder, and place its contents in `wmh/data`. Now, run the preprocessing script `preprocessing.py`, which will save a preprocessed dataset in `data_preprocessed/`.
To log our experiments, we use Weights & Biases (wandb). To use the same logger, you are required to provide your wandb credentials as detailed below:
- `diff_mnist`: Provide your user ID, team name and project name in `setup/wandb.yml`.
- `diff_cifar`: Provide your user ID, team name and project name in `wandb.yml`.
- `pdearena`: In `config.yml`, provide your project name under the key `trainer.logger.init_args.project`, and your team name under the key `trainer.logger.init_args.entity`.
- `wmh`: Provide your user ID, team name and project name in `wandb.yml`.
In the following, we provide instructions to reproduce key experimental results in our main paper.
We refer to the `hyperparams.py` files in the respective repositories for an overview of the hyperparameters that can be set, and to the instructions in the original code bases on how to run the code.
To reproduce Table 1 and the FNO vs. U-Net comparison, run each of the runs below three times with different random seeds, using the command line argument `--seed` (in `wmh`) or `--seed_everything` (in `pdearena`).
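Repeating a run over several seeds can be scripted. The following is a minimal sketch, assuming the seeds 1, 2 and 3 (the README does not state which seeds were used) and a hypothetical helper `seeded_runs`; the base command shown is an abbreviated version of the WMH command and should be replaced by the full command for the run you want to reproduce.

```python
import os
import shlex
import subprocess


def seeded_runs(command, seed_flag, seeds=(1, 2, 3)):
    """Return one argument list per seed, with the seed flag replaced or appended."""
    base_args = shlex.split(command)
    runs = []
    for seed in seeds:
        args = list(base_args)
        if seed_flag in args:
            # Overwrite the value that follows an existing seed flag.
            args[args.index(seed_flag) + 1] = str(seed)
        else:
            args += [seed_flag, str(seed)]
        runs.append(args)
    return runs


if __name__ == "__main__":
    # Abbreviated for illustration; use the full flag list from this README.
    base = "python train_pt.py --device cuda --dwt_encoder False --n_extra_resnet_layers 0 --seed 1"
    env = dict(os.environ, CUDA_VISIBLE_DEVICES="0")
    for args in seeded_runs(base, "--seed"):
        print(shlex.join(args))
        # Uncomment to launch the run from within the sub-repository:
        # subprocess.run(args, env=env, check=True)
```

For the `pdearena` runs, pass `"--seed_everything"` as `seed_flag` instead.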
Navier-Stokes
Using repo `pdearena`.
- Residual U-Net:
CUDA_VISIBLE_DEVICES=0 python scripts/train.py -c configs/navierstokes2d.yaml --data.data_dir=<data-dir> --trainer.devices=1 --trainer.max_epochs=50 --data.batch_size=8 --data.time_gap=0 --data.time_history=4 --data.time_future=1 --model.name=Unetbase-64_G --model.lr=2e-4 --optimizer=AdamW --optimizer.lr=2e-4 --optimizer.weight_decay=1e-5 --lr_scheduler=LinearWarmupCosineAnnealingLR --lr_scheduler.warmup_epochs=5 --lr_scheduler.max_epochs=50 --lr_scheduler.eta_min=1e-7 --trainer.plugins DisabledSLURMEnvironment --trainer.accelerator gpu --model.dwt_encoder False --model.multi_res_loss False --model.freeze_lower_res False --model.up_fct interpolate_nearest --model.n_extra_resnet_layers 0 --seed_everything 1
- Multi-ResNet, no params. added in dec.
CUDA_VISIBLE_DEVICES=0 python scripts/train.py -c configs/navierstokes2d.yaml --data.data_dir=<data-dir> --trainer.devices=1 --trainer.max_epochs=50 --data.batch_size=8 --data.time_gap=0 --data.time_history=4 --data.time_future=1 --model.name=Unetbase-64_G --model.lr=2e-4 --optimizer=AdamW --optimizer.lr=2e-4 --optimizer.weight_decay=1e-5 --lr_scheduler=LinearWarmupCosineAnnealingLR --lr_scheduler.warmup_epochs=5 --lr_scheduler.max_epochs=50 --lr_scheduler.eta_min=1e-7 --trainer.plugins DisabledSLURMEnvironment --trainer.accelerator gpu --model.dwt_encoder True --model.multi_res_loss False --model.freeze_lower_res False --model.up_fct interpolate_nearest --model.n_extra_resnet_layers 0 --seed_everything 1
- Multi-ResNet, saved params. added in dec.
CUDA_VISIBLE_DEVICES=0 python scripts/train.py -c configs/navierstokes2d.yaml --data.data_dir=<data-dir> --trainer.devices=1 --trainer.max_epochs=50 --data.batch_size=8 --data.time_gap=0 --data.time_history=4 --data.time_future=1 --model.name=Unetbase-64_G --model.lr=2e-4 --optimizer=AdamW --optimizer.lr=2e-4 --optimizer.weight_decay=1e-5 --lr_scheduler=LinearWarmupCosineAnnealingLR --lr_scheduler.warmup_epochs=5 --lr_scheduler.max_epochs=50 --lr_scheduler.eta_min=1e-7 --trainer.plugins DisabledSLURMEnvironment --trainer.accelerator gpu --model.dwt_encoder True --model.multi_res_loss False --model.freeze_lower_res False --model.up_fct interpolate_nearest --model.n_extra_resnet_layers 3 --seed_everything 1
Shallow water
Using repo `pdearena`.
- Residual U-Net:
CUDA_VISIBLE_DEVICES=0 python scripts/train.py -c configs/shallowwater2d_2day.yaml --data.data_dir=<data-dir> --trainer.devices=1 --model.name=Unetbase-64_G --model.lr=2e-4 --optimizer=AdamW --optimizer.lr=2e-4 --optimizer.weight_decay=1e-5 --trainer.plugins DisabledSLURMEnvironment --trainer.accelerator gpu --model.dwt_encoder False --model.multi_res_loss False --model.freeze_lower_res False --model.up_fct interpolate_nearest --model.n_extra_resnet_layers 0 --seed_everything 1
- Multi-ResNet, no params. added in dec.
CUDA_VISIBLE_DEVICES=0 python scripts/train.py -c configs/shallowwater2d_2day.yaml --data.data_dir=<data-dir> --trainer.devices=1 --model.name=Unetbase-64_G --model.lr=2e-4 --optimizer=AdamW --optimizer.lr=2e-4 --optimizer.weight_decay=1e-5 --trainer.plugins DisabledSLURMEnvironment --trainer.accelerator gpu --model.dwt_encoder True --model.multi_res_loss False --model.freeze_lower_res False --model.up_fct interpolate_nearest --model.n_extra_resnet_layers 0 --seed_everything 1
- Multi-ResNet, saved params. added in dec.
CUDA_VISIBLE_DEVICES=0 python scripts/train.py -c configs/shallowwater2d_2day.yaml --data.data_dir=<data-dir> --trainer.devices=1 --model.name=Unetbase-64_G --model.lr=2e-4 --optimizer=AdamW --optimizer.lr=2e-4 --optimizer.weight_decay=1e-5 --trainer.plugins DisabledSLURMEnvironment --trainer.accelerator gpu --model.dwt_encoder True --model.multi_res_loss False --model.freeze_lower_res False --model.up_fct interpolate_nearest --model.n_extra_resnet_layers 3 --seed_everything 1
WMH
Using repo `wmh`.
- Residual U-Net:
CUDA_VISIBLE_DEVICES=0 python train_pt.py --device cuda --batch_size 32 --train_loss_every_iters 50 --train_hist_every_iters 500 --train_prec_recall_curve_every_iters 500 --hidden_channels 16 --data_augmentation none --early_stop_patience 3 --early_stop_min_improvement 0.02 --val_every_epochs 3 --num_epochs_list 25 --dwt_encoder False --n_extra_resnet_layers 0 --seed 1
- Multi-ResNet, no params. added in dec.
CUDA_VISIBLE_DEVICES=0 python train_pt.py --device cuda --batch_size 32 --train_loss_every_iters 50 --train_hist_every_iters 500 --train_prec_recall_curve_every_iters 500 --hidden_channels 16 --data_augmentation none --early_stop_patience 3 --early_stop_min_improvement 0.02 --val_every_epochs 3 --num_epochs_list 25 --dwt_encoder True --n_extra_resnet_layers 0 --seed 1
- Multi-ResNet, saved params. added in dec.
CUDA_VISIBLE_DEVICES=0 python train_pt.py --device cuda --batch_size 32 --train_loss_every_iters 50 --train_hist_every_iters 500 --train_prec_recall_curve_every_iters 500 --hidden_channels 16 --data_augmentation none --early_stop_patience 3 --early_stop_min_improvement 0.02 --val_every_epochs 3 --num_epochs_list 25 --dwt_encoder True --n_extra_resnet_layers 3 --seed 1
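Within each of the Navier-Stokes, Shallow water and WMH groups above, the three commands differ only in the values of the `dwt_encoder` and `n_extra_resnet_layers` flags (apart from the `model.` prefix used in `pdearena`). The sketch below collects these values in one place; the dictionary and the helper `variant_flags` are illustrative names of ours and do not exist in the repository.

```python
# Flag values read off the commands in this README.
VARIANTS = {
    "Residual U-Net": {"dwt_encoder": False, "n_extra_resnet_layers": 0},
    "Multi-ResNet, no params. added in dec.": {"dwt_encoder": True, "n_extra_resnet_layers": 0},
    "Multi-ResNet, saved params. added in dec.": {"dwt_encoder": True, "n_extra_resnet_layers": 3},
}


def variant_flags(name, prefix=""):
    """Return the variant-specific flags as an argument list.

    Use prefix="model." for pdearena and the default empty prefix for wmh.
    """
    flags = []
    for key, value in VARIANTS[name].items():
        flags += [f"--{prefix}{key}", str(value)]
    return flags


if __name__ == "__main__":
    for name in VARIANTS:
        print(name, "->", " ".join(variant_flags(name, prefix="model.")))
```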
CIFAR
Using repo `diff_cifar`.
CUDA_VISIBLE_DEVICES=0 python main.py --device cuda --sample_step 10000 --save_step 100000 --eval_step 200000 --NUM_ITERATIONS_LIST 50000 50000 50000 1500003 --DWT_ENCODER False --FREEZE_LOWER_RES False --MULTI_RES_LOSS False
MNIST-Triangular
Using repo `diff_mnist`.
CUDA_VISIBLE_DEVICES=0 python main.py --DEVICE cuda --BETA_MIN 0.1 --BETA_MAX 20 --N 30 --EPS 1e-3 --T 1.0 --NUM_ITERATIONS_LIST 10000 --DWT_ENCODER False --MULTI_RES_LOSS False --AVG_POOL_DOWN True --DATASET mnist_triangular --RESOLUTION 64 --to_square_preprocess True
FNO, Navier-Stokes
CUDA_VISIBLE_DEVICES=0 python scripts/train.py -c configs/navierstokes2d.yaml --data.data_dir=<data-dir> --trainer.max_epochs=30 --trainer.devices=1 --data.batch_size=8 --data.time_gap=0 --data.time_history=4 --data.time_future=1 --model.name=FNO-128-8m --model.lr=2e-4 --optimizer=AdamW --optimizer.lr=2e-4 --optimizer.weight_decay=1e-5 --lr_scheduler=LinearWarmupCosineAnnealingLR --lr_scheduler.warmup_epochs=5 --lr_scheduler.max_epochs=30 --lr_scheduler.eta_min=1e-7 --trainer.plugins DisabledSLURMEnvironment --trainer.accelerator gpu --seed_everything 1
FNO, Shallow water
CUDA_VISIBLE_DEVICES=0 python scripts/train.py -c configs/shallowwater2d_2day.yaml --data.data_dir=<data-dir> --trainer.devices=1 --model.name=FNO-128-8m --model.lr=2e-4 --optimizer=AdamW --optimizer.lr=2e-4 --optimizer.weight_decay=1e-5 --trainer.plugins DisabledSLURMEnvironment --trainer.accelerator gpu --seed_everything 1 --trainer.max_epochs=15 --lr_scheduler.max_epochs=15
U-Net results in this table: see 'The role of the encoder in a U-Net (Section 5.1, Table 1)', Navier-Stokes and Shallow water.
We welcome extensions of this repository! To ask a question, or report a bug, please leave an issue.
Any code which is not part of the original code bases is provided under MIT License. For all other code, we refer to the respective licenses in the original code bases (all also MIT licensed), which are part of the sub-repositories in this code base.
If you find this code repository or the accompanying paper useful, please cite our work as:
@article{williamsfalck2023unified,
title={A Unified Framework for U-Net Design and Analysis},
author={Williams, Christopher and Falck, Fabian and Deligiannidis, George and Holmes, Chris and Doucet, Arnaud and Syed, Saifuddin},
journal={arXiv preprint arXiv:2305.19638},
year={2023}
}