PyTorch implementation of 3D U-Net, based on:
3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation, by Özgün Çiçek, Ahmed Abdulkadir, Soeren S. Lienkamp, Thomas Brox and Olaf Ronneberger.
Prerequisites:
- Linux
- NVIDIA GPU
- CUDA + cuDNN
- pytorch (0.4.1+)
- torchvision (0.2.1+)
- tensorboardx (1.4+)
- h5py
- pytest
Set up a new conda environment with the required dependencies via:
conda create -n 3dunet pytorch torchvision tensorboardx h5py pytest -c conda-forge
Activate the newly created conda environment via:
source activate 3dunet
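After activating the environment, a quick sanity check that PyTorch imports and sees the GPU (a sketch; version numbers will vary on your machine):

```python
import torch

# Confirm the package imports and report whether CUDA is usable;
# on a correctly configured GPU machine the second line prints True.
print(torch.__version__)
print(torch.cuda.is_available())
```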
usage: train.py [-h] --checkpoint-dir CHECKPOINT_DIR --in-channels IN_CHANNELS
--out-channels OUT_CHANNELS [--interpolate]
[--layer-order LAYER_ORDER] [--loss LOSS] [--epochs EPOCHS]
[--iters ITERS] [--patience PATIENCE]
[--learning-rate LEARNING_RATE] [--weight-decay WEIGHT_DECAY]
[--validate-after-iters VALIDATE_AFTER_ITERS]
[--log-after-iters LOG_AFTER_ITERS] [--resume RESUME]
UNet3D training
optional arguments:
-h, --help show this help message and exit
--checkpoint-dir CHECKPOINT_DIR
checkpoint directory
--in-channels IN_CHANNELS
number of input channels
--out-channels OUT_CHANNELS
number of output channels
--interpolate use F.interpolate instead of ConvTranspose3d
--layer-order LAYER_ORDER
Conv layer ordering, e.g. 'brc' ->
BatchNorm3d+ReLU+Conv3D
--loss LOSS Which loss function to use. Possible values: [bce, ce,
dice]. Where bce - BinaryCrossEntropy (binary
classification only), ce - CrossEntropy (multi-class
classification), dice - DiceLoss (binary
classification only)
--epochs EPOCHS max number of epochs (default: 500)
--iters ITERS max number of iterations (default: 1e5)
--patience PATIENCE number of validation steps with no improvement after
which the training will be stopped (default: 20)
--learning-rate LEARNING_RATE
initial learning rate (default: 0.0002)
--weight-decay WEIGHT_DECAY
weight decay (default: 0.0001)
--validate-after-iters VALIDATE_AFTER_ITERS
how many iterations between validations (default: 100)
--log-after-iters LOG_AFTER_ITERS
how many iterations between tensorboard logging
(default: 100)
--resume RESUME path to latest checkpoint (default: none); if provided
the training will be resumed from that checkpoint
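The `--layer-order` string is a compact encoding of how each convolutional block is assembled. A minimal sketch of how such a string could be expanded into layer names; the actual parsing logic lives in the repository's model code, so the function and mapping names here are illustrative only:

```python
# Hypothetical mapping from layer-order characters to module names;
# the repository defines its own mapping internally.
LAYER_CODES = {
    'b': 'BatchNorm3d',
    'r': 'ReLU',
    'c': 'Conv3d',
}

def expand_layer_order(order):
    """Translate a layer-order string like 'brc' into an ordered list of layer names."""
    return [LAYER_CODES[ch] for ch in order]

print(expand_layer_order('brc'))  # ['BatchNorm3d', 'ReLU', 'Conv3d']
```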
E.g., to fit a randomly generated 3D volume and a random segmentation mask (see train.py):
python train.py --checkpoint-dir ~/3dunet --in-channels 1 --out-channels 2 --layer-order brc --validate-after-iters 10 --log-after-iters 10 --epochs 50 --learning-rate 0.0001 --weight-decay 0.0005 --interpolate
To resume training from the last checkpoint:
python train.py --resume ~/3dunet/last_checkpoint.pytorch --in-channels 1 --out-channels 2 --layer-order brc --validate-after-iters 10 --log-after-iters 10 --epochs 50 --learning-rate 0.0001 --weight-decay 0.0005 --interpolate
To train on your own data, replace the _get_loaders implementation in train.py so that it returns your own 'train' and 'valid' loaders.
Monitor progress with TensorBoard (tensorboard needs to be installed in your conda env):
tensorboard --logdir ~/3dunet/logs/ --port 8666
usage: predict.py [-h] --model-path MODEL_PATH --in-channels IN_CHANNELS
--out-channels OUT_CHANNELS [--interpolate] [--layer-order LAYER_ORDER]
3D U-Net predictions
optional arguments:
-h, --help show this help message and exit
--model-path MODEL_PATH
path to the model
--in-channels IN_CHANNELS
number of input channels
--out-channels OUT_CHANNELS
number of output channels
--interpolate use F.interpolate instead of ConvTranspose3d
--layer-order LAYER_ORDER
Conv layer ordering, e.g. 'brc' ->
BatchNorm3d+ReLU+Conv3D
Test on a randomly generated 3D volume (for demonstration purposes only; see predict.py for details):
python predict.py --model-path ~/3dunet/best_checkpoint.pytorch --in-channels 1 --out-channels 2 --interpolate --layer-order brc
Prediction masks will be saved to ~/3dunet/probabilities.h5.
To test the trained model on your own data, replace the _get_dataset implementation in predict.py.
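The output is a plain HDF5 file, so it can be inspected with h5py. A self-contained sketch that writes a toy per-class probability volume and reads it back; the dataset name 'predictions' is an assumption here — print(list(f.keys())) on your own file to find the actual name:

```python
import h5py
import numpy as np

# Toy (C, D, H, W) probability volume standing in for the real output
probs = np.random.rand(2, 8, 8, 8).astype('float32')
with h5py.File('probabilities.h5', 'w') as f:
    f.create_dataset('predictions', data=probs)

with h5py.File('probabilities.h5', 'r') as f:
    loaded = f['predictions'][...]

# Collapse per-class probabilities into a hard label mask
mask = np.argmax(loaded, axis=0)
print(mask.shape)  # (8, 8, 8)
```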