diff --git a/README.md b/README.md
index 8716b44..24af561 100644
--- a/README.md
+++ b/README.md
@@ -1,7 +1,132 @@
 # Semantic Segmentation with Noisy Boundary Annotations
-Implemented boundary detection based on "Devil is in the Edges: Learning Semantic Boundaries from Noisy Annotations" (see [link](http://openaccess.thecvf.com/content_CVPR_2019/papers/Acuna_Devil_Is_in_the_Edges_Learning_Semantic_Boundaries_From_Noisy_CVPR_2019_paper.pdf))
+Implemented boundary detection based on "Devil is in the Edges: Learning Semantic Boundaries from Noisy Annotations" (see [link](http://openaccess.thecvf.com/content_CVPR_2019/papers/Acuna_Devil_Is_in_the_Edges_Learning_Semantic_Boundaries_From_Noisy_CVPR_2019_paper.pdf)) and generalized it to the 3D case.
+
+### Implementation
+
+- [x] Basic 2D/3D CASENet network with weighted multilabel BCE loss
+- [x] 2D/3D geodesic active contour inference
+- [x] Iterative update between network training and level-set refinement
+- [x] 3D UNet (obsolete code)
+- [ ] NMS loss and direction loss
+
+### Configuration example
+
+1. Configuration for 2D CASENet with 2D level set: training ([link](./resources/train_config_case2d.yaml)), testing ([link](./resources/test_config_case2D.yaml))
+2. Configuration for 3D CASENet with 3D level set: training ([link](./resources/train_config_case3D.yaml))
+3. Obsolete configuration: 3D UNet training ([link](./resources/train_config_unet3D.yaml)), testing ([link](./resources/test_config_unet3D.yaml))
+
+### Usage
+
+**Clone this repo**
+
+```bash
+git clone http://gitlab.bj.sensetime.com/shenrui/edgeDL.git
+cd edgeDL
+```
+
+**Install dependencies**
+
+Requires Python 3.6+ and PyTorch 1.0+. Install the dependencies with
+
+```bash
+conda env create -f environment.yml
+```
+
+**Preprocessing**
+
+Resample the data to a common resolution. This step requires FreeSurfer's mri_convert (see [link](https://surfer.nmr.mgh.harvard.edu/fswiki/mri_convert); FreeSurfer installation [guide](https://surfer.nmr.mgh.harvard.edu/fswiki/DownloadAndInstall)).
+
+```bash
+./utils/resample.sh
+```
+
+Generate file lists for the training, validation and testing sets.
+
+```bash
+python data2txt.py
+```
+
+**Training**
+
+Set up the configuration file and run
+
+```bash
+python train_casenet.py --config PATH_TO_CONFIG_FILE
+```
+
+**Testing**
+
+Set up the configuration file and run
+
+```bash
+python predict_casenet.py --config PATH_TO_CONFIG_FILE
+```
+
+### Loss function
+
+1. Weighted multilabel BCE loss (a PyTorch sketch of this loss follows this list)
+
+   $\mathcal{L}_{BCE}(\theta) = - \sum_k\sum_m\{\beta y_k^m\log f_k(m|x,\theta) + (1-\beta) (1-y_k^m)\log(1 - f_k(m|x,\theta))\}$
+
+   where
+
+   $\beta$ : ratio of non-edge pixels/voxels, $\beta = \frac{|Y^-|}{|Y|}$
+
+   $k$ : class index
+
+   $m$ : pixel/voxel index
+
+2. NMS loss (edge thinning layers, to be implemented)
+
+   $\mathcal{L}_{NMS}(\theta) = -\sum_k\sum_p \log h_k(p|x,\theta)$
+
+   where
+
+   $h_k(p|x,\theta) = \frac{\exp(f_k(p|x,\theta)/\tau)}{\sum_{t=-L}^L \exp(f_k(p_t|x,\theta)/\tau)}$ for normalization
+
+   $x(p_t) = x(p) + t \cdot \cos \vec{d_p}$ , $y(p_t) = y(p) + t \cdot \sin \vec{d_p}$
+
+   $p$ : ground-truth boundary pixel/voxel
+
+   $\vec{d_p}$ : normal direction at $p$ computed from the ground-truth boundary map
+
+   $t \in \{-L, -L+1, \dots, L\}$
+
+   ##### Notes for implementation
+
+   - Normal direction: use a fixed convolutional layer to estimate second derivatives, then use trigonometric functions to compute the normal direction from the ground-truth boundary map
+   - Code reference: edgesNMS ([link](https://github.com/pdollar/edges/blob/master/private/edgesNmsMex.cpp))
+
+3. Direction loss (to be implemented)
+
+   $\mathcal{L}_{Dir}(\theta) = \sum_k\sum_p ||\cos^{-1} \langle\vec{d_p}, \vec{e_p}(\theta)\rangle||$
+
+   where
+
+   $\vec{e_p}(\theta)$ : normal direction at $p$ computed from the prediction map
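+
+For reference, a minimal PyTorch sketch of the weighted multilabel BCE loss above. The function name `weighted_multilabel_bce` and the NxKxHxW shape convention are illustrative only, not the exact implementation used in this repo:
+
+```python
+import torch
+
+def weighted_multilabel_bce(logits, target, eps=1e-6):
+    """Weighted multilabel BCE; logits and target are NxKxHxW tensors."""
+    # f_k(m|x, theta): per-class edge probabilities, clamped for numerical stability
+    prob = torch.sigmoid(logits).clamp(eps, 1 - eps)
+    target = target.float()  # y_k^m in {0, 1}, one binary edge map per class
+    # beta = |Y^-| / |Y|: fraction of non-edge pixels, computed per class map
+    beta = 1.0 - target.sum(dim=(2, 3), keepdim=True) / (target.size(2) * target.size(3))
+    # weighted positive/negative terms, summed over classes k and pixels m
+    loss = -(beta * target * prob.log()
+             + (1.0 - beta) * (1.0 - target) * (1.0 - prob).log())
+    return loss.sum()
+```
+
+Since edges are sparse, $\beta$ is close to 1, so the rare edge pixels dominate the loss instead of being washed out by the background.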
+
+### Level Set
+
+1. Level set evolution
+
+   $\frac{\partial \phi}{\partial t} = g_k(\kappa + c)|\nabla\phi| + \nabla g_k \cdot \nabla \phi$
+
+   solved with a morphological approach (see [link](http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.307.3979&rep=rep1&type=pdf))
+
+2. Energy (edge) map for level set alignment
+
+   $g_k = \frac{1}{\sqrt{1+\alpha f_k}}+\frac{\lambda}{\sqrt{1+\alpha \sigma(y_k)}}$
+
+   where
+
+   $f_k$ : probability map predicted by the neural network
+
+   $\sigma(y_k)$ : (previous) ground-truth annotation smoothed by a Gaussian filter with $\sigma$
+
 ### Github reference:
+
 1. STEAL ([link](https://github.com/nv-tlabs/STEAL))
-2. edges ([link](https://github.com/pdollar/edges))
\ No newline at end of file
+2. edges ([link](https://github.com/pdollar/edges))
+3. Morphsnakes ([link](https://github.com/pmneila/morphsnakes))
\ No newline at end of file
diff --git a/models/casenet2d/metrics.py b/models/casenet2d/metrics.py
index 95a3158..d91726b 100755
--- a/models/casenet2d/metrics.py
+++ b/models/casenet2d/metrics.py
@@ -197,7 +197,7 @@ def __call__(self, input, target):
         if target.dim() < input.dim():
             target = expand_as_one_hot(target, C=n_classes, ignore_index=self.ignore_index)
         weight_sum = target.sum(dim=1).sum(dim=1).sum(dim=1)
-        edge_weight = weight_sum / (target.size()[2] * target.size()[3])
+        edge_weight = weight_sum.float() / (target.size()[2] * target.size()[3])
         edge_weight = edge_weight.unsqueeze(1).unsqueeze(2).unsqueeze(3)
         non_edge_weight = 1 - edge_weight
diff --git a/models/casenet2d/model.py b/models/casenet2d/model.py
index ce0d6f7..fee109b 100755
--- a/models/casenet2d/model.py
+++ b/models/casenet2d/model.py
@@ -1,3 +1,6 @@
+# modified by Rui Shen, Aug 2019
+# ------------------------------------------------------------------
+
 # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
 #
 # Redistribution and use in source and binary forms, with or without
diff --git a/models/casenet3d/model.py b/models/casenet3d/model.py
index a8e2357..e94d43b 100755
--- a/models/casenet3d/model.py
+++ b/models/casenet3d/model.py
@@ -1,3 +1,6 @@
+# modified by Rui Shen, Aug 2019
+# ------------------------------------------------------------------
+
 # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# # Redistribution and use in source and binary forms, with or without diff --git a/predict_casenet2d.py b/predict_casenet.py similarity index 100% rename from predict_casenet2d.py rename to predict_casenet.py diff --git a/resources/test_config_backup.yaml b/resources/test_config_backup.yaml deleted file mode 100755 index f4c466f..0000000 --- a/resources/test_config_backup.yaml +++ /dev/null @@ -1,48 +0,0 @@ -# path to the checkpoint file containing the model -model_path: /home/SENSETIME/shenrui/Dropbox/SenseTime/edgeDL/saved_weights/pelvis/casenet2d/best_checkpoint.pytorch -# path to the folder of the predictions -save_path: /home/SENSETIME/shenrui/data/pelvis_predict -prediction_channel: null -# model configuration -model: - # model class - name: ResNet - # number of input channels to the model - in_channels: 1 - # number of output classes - out_channels: 5 - # set layers - layers: [3, 4, 23, 3] - # apply element-wise nn.Sigmoid after the final 1x1 convolution - final_sigmoid: True -# evaluation metric configuration -eval_metric: - name: PrecisionStats - # a target label that is ignored during metric evaluation - ignore_index: null - # number of points in PR curve - nthresh: 19 -# specify the test datasets -loaders: - # test patch size given to the network (adapt to fit in your GPU mem) - test_patch: [null, null] - # test stride between patches (make sure the the patches overlap in order to get smoother prediction maps) - test_stride: [null, null] - # clip value within the range - clip_val: [-1000, 2000] - # how many subprocesses to use for data loading - num_workers: 0 - # paths to the datasets - test_path: - - '/home/SENSETIME/shenrui/data/pelvis_resampled/dataset_train_temp.txt' - transformer: - test: - raw: - - name: ClipNormalize - - name: ToTensor - expand_dims: true - label: - - name: SegToEdge - - name: ToTensor - expand_dims: false - dtype: 'long' \ No newline at end of file diff --git a/resources/test_config_case.yaml b/resources/test_config_case2D.yaml similarity index 89% rename from resources/test_config_case.yaml rename to resources/test_config_case2D.yaml index eebff48..a17a18a 100755 --- a/resources/test_config_case.yaml +++ b/resources/test_config_case2D.yaml @@ -5,7 +5,7 @@ save_path: /mnt/lustre/shenrui/data/pelvis_predict_BCE20000 prediction_channel: null # model configuration model: - # model class + # model class, CASENet is based on ResNet name: ResNet # number of input channels to the model in_channels: 1 @@ -13,7 +13,7 @@ model: out_channels: 5 # set layers layers: [3, 4, 23, 3] - # apply element-wise nn.Sigmoid after the final 1x1 convolution + # whether to apply the sigmoid function, set false for training, true for testing final_sigmoid: True # evaluation metric configuration eval_metric: @@ -43,6 +43,8 @@ loaders: expand_dims: true label: - name: SegToEdge + out_channels: 5 + radius: 1 - name: ToTensor expand_dims: false dtype: 'long' diff --git a/resources/test_config_unet.yaml b/resources/test_config_unet3D.yaml similarity index 100% rename from resources/test_config_unet.yaml rename to resources/test_config_unet3D.yaml diff --git a/resources/train_config_3Dbackup.yaml b/resources/train_config_3Dbackup.yaml deleted file mode 100755 index f3c2bd9..0000000 --- a/resources/train_config_3Dbackup.yaml +++ /dev/null @@ -1,137 +0,0 @@ -# use a fixed random seed to guarantee that when you run the code twice you will get the same outcome -manual_seed: null -# Network input sample dimension -dim: 3 -# model configuration -model: - # model class - name: ResNet - # 
number of input channels to the model - in_channels: 1 - # number of output classes - out_channels: 5 - # set layers - layers: [3, 4, 6, 3] - # apply element-wise nn.Sigmoid after the final 1x1 convolution, otherwise apply nn.Softmax - final_sigmoid: False -# trainer configuration -trainer: - # path to the checkpoint directory - checkpoint_dir: /home/SENSETIME/shenrui/Dropbox/SenseTime/edgeDL/checkpoints/pelvis/casenet3d - # path to latest checkpoint; if provided the training will be resumed from that checkpoint - resume: null - # how many iterations between validations - validate_after_iters: 20 - # how many iterations between tensorboard logging - log_after_iters: 3 - # how many iterations evaluated in validations - validate_iters: null - # how many iterations before start level set alignment - align_start_iters: 20000 - # how many iterations between level set alignment - align_after_iters: 10 - # max number of epochs - epochs: 50 - # max number of iterations - iters: null - # model with lower eval score is considered better - eval_score_higher_is_better: False -# optimizer configuration -optimizer: - # initial learning rate - learning_rate: 0.0001 - # weight decay - weight_decay: 0.001 -# loss function configuration -loss: - # loss function to be used during training - name: STEALEdgeLoss - # A manual rescaling weight given to each class. - loss_weight: null - # a target value that is ignored and does not contribute to the input gradient - ignore_index: null -# evaluation metric configuration -eval_metric: - name: STEALEdgeLoss - # a target label that is ignored during metric evaluation - ignore_index: null -lr_scheduler: - name: MultiStepLR - milestones: [10, 30, 60] - gamma: 0.2 -# configuration for level set alignment -level_set: - dim: 3 - lambda_: 0.1 - alpha: 1 - sigma: 1 - smoothing: 1 - render_radius: 1 - is_gt_semantic: True - method: 'MLS' - balloon: 0 - threshold: 0.95 - step_ckpts: 50 - dz: 16 - prefix: '/home/SENSETIME/shenrui/data/3D_pelvis_predict_BCE' - n_workers: 0 -# data loaders configuration -loaders: - # train patch size given to the network (adapt to fit in your GPU mem, generally the bigger patch the better) - train_patch: [4, 350, 350] - # train stride between patches - train_stride: [2, 200, 200] - # validation patch (can be bigger than train patch since there is no backprop) - val_patch: [4, 350, 350] - # validation stride (validation patches doesn't need to overlap) - val_stride: [4, 350, 350] - # clip value within the range - clip_val: [-1000, 2000] - # paths to the training datasets - train_path: - - '/home/SENSETIME/shenrui/data/pelvis_resampled/dataset_train_temp.txt' - # paths to the validation datasets - val_path: - - '/home/SENSETIME/shenrui/data/pelvis_resampled/dataset_val_temp.txt' - # how many subprocesses to use for data loading - num_workers: 0 - # batch size in training process - batch_size: 1 - # data transformations/augmentations - transformer: - train: - raw: - - name: ClipNormalize - - name: RandomRotate - axes: [[2, 1]] - angle_spectrum: 30 - mode: reflect - - name: ElasticDeformation - spline_order: 3 - - name: ToTensor - expand_dims: true - label: - - name: RandomRotate - axes: [[2, 1]] - angle_spectrum: 30 - mode: reflect - - name: ElasticDeformation - spline_order: 0 - - name: SegToEdge - out_channels: 5 - radius: 1 - - name: ToTensor - expand_dims: false - dtype: 'long' - test: - raw: - - name: ClipNormalize - - name: ToTensor - expand_dims: true - label: - - name: SegToEdge - out_channels: 5 - radius: 1 - - name: ToTensor - 
expand_dims: false - dtype: 'long' \ No newline at end of file diff --git a/resources/train_config_backup.yaml b/resources/train_config_backup.yaml deleted file mode 100755 index c89e5f8..0000000 --- a/resources/train_config_backup.yaml +++ /dev/null @@ -1,137 +0,0 @@ -# use a fixed random seed to guarantee that when you run the code twice you will get the same outcome -manual_seed: null -# Network input sample dimension -dim: 2 -# model configuration -model: - # model class - name: ResNet - # number of input channels to the model - in_channels: 1 - # number of output classes - out_channels: 5 - # set layers - layers: [3, 4, 23, 3] - # apply element-wise nn.Sigmoid after the final 1x1 convolution, otherwise apply nn.Softmax - final_sigmoid: False -# trainer configuration -trainer: - # path to the checkpoint directory - checkpoint_dir: /home/SENSETIME/shenrui/Dropbox/SenseTime/edgeDL/checkpoints/pelvis/casenet2d - # path to latest checkpoint; if provided the training will be resumed from that checkpoint - resume: /home/SENSETIME/shenrui/Dropbox/SenseTime/edgeDL/checkpoints/pelvis/casenet2d/last_checkpoint_iter20000.pytorch - # how many iterations between validations - validate_after_iters: 20 - # how many iterations between tensorboard logging - log_after_iters: 3 - # how many iterations evaluated in validations - validate_iters: null - # how many iterations before start level set alignment - align_start_iters: 20000 - # how many iterations between level set alignment - align_after_iters: 10 - # max number of epochs - epochs: 100 - # max number of iterations - iters: null - # model with lower eval score is considered better - eval_score_higher_is_better: False -# optimizer configuration -optimizer: - # initial learning rate - learning_rate: 0.0001 - # weight decay - weight_decay: 0.001 -# loss function configuration -loss: - # loss function to be used during training - name: STEALEdgeLoss - # A manual rescaling weight given to each class. 
- loss_weight: null - # a target value that is ignored and does not contribute to the input gradient - ignore_index: null -# evaluation metric configuration -eval_metric: - name: STEALEdgeLoss - # a target label that is ignored during metric evaluation - ignore_index: null -lr_scheduler: - name: MultiStepLR - milestones: [10, 30, 60] - gamma: 0.2 -# configuration for level set alignment -level_set: - dim: 2 - lambda_: 0.1 - alpha: 1 - sigma: 1 - smoothing: 1 - render_radius: 1 - is_gt_semantic: True - method: 'MLS' - balloon: 0 - threshold: 0.95 - step_ckpts: 50 - dz: 16 - prefix: '/home/SENSETIME/shenrui/data/pelvis_predict_BCE' - n_workers: 0 -# data loaders configuration -loaders: - # train patch size given to the network (adapt to fit in your GPU mem, generally the bigger patch the better) - train_patch: [350, 350] - # train stride between patches - train_stride: [100, 100] - # validation patch (can be bigger than train patch since there is no backprop) - val_patch: [350, 350] - # validation stride (validation patches doesn't need to overlap) - val_stride: [350, 350] - # clip value within the range - clip_val: [-1000, 2000] - # paths to the training datasets - train_path: - - '/home/SENSETIME/shenrui/data/pelvis_resampled/dataset_train_temp.txt' - # paths to the validation datasets - val_path: - - '/home/SENSETIME/shenrui/data/pelvis_resampled/dataset_val_temp.txt' - # how many subprocesses to use for data loading - num_workers: 0 - # batch size in training process - batch_size: 1 - # data transformations/augmentations - transformer: - train: - raw: - - name: ClipNormalize - - name: RandomRotate - axes: [[0, 1]] - angle_spectrum: 30 - mode: reflect - - name: ElasticDeformation - spline_order: 3 - - name: ToTensor - expand_dims: true - label: - - name: RandomRotate - axes: [[0, 1]] - angle_spectrum: 30 - mode: reflect - - name: ElasticDeformation - spline_order: 0 - - name: SegToEdge - out_channels: 5 - radius: 1 - - name: ToTensor - expand_dims: false - dtype: 'long' - test: - raw: - - name: ClipNormalize - - name: ToTensor - expand_dims: true - label: - - name: SegToEdge - out_channels: 5 - radius: 1 - - name: ToTensor - expand_dims: false - dtype: 'long' \ No newline at end of file diff --git a/resources/train_config_case.yaml b/resources/train_config_case2d.yaml similarity index 69% rename from resources/train_config_case.yaml rename to resources/train_config_case2d.yaml index 9e9ec0c..c1924a1 100755 --- a/resources/train_config_case.yaml +++ b/resources/train_config_case2d.yaml @@ -1,8 +1,10 @@ -# use a fixed random seed to guarantee that when you run the code twice you will get the same outcome +# Set a fixed random seed for reproducibility, or null manual_seed: null +# Network dimension, determine whether to use 2D network or 3D network +dim: 2 # model configuration model: - # model class + # model class, CASENet is based on ResNet name: ResNet # number of input channels to the model in_channels: 1 @@ -10,19 +12,19 @@ model: out_channels: 5 # set layers layers: [3, 4, 23, 3] - # apply element-wise nn.Sigmoid after the final 1x1 convolution, otherwise apply nn.Softmax + # whether to apply the sigmoid function, set false for training, true for testing final_sigmoid: False # trainer configuration trainer: # path to the checkpoint directory checkpoint_dir: /mnt/lustre/shenrui/project/edgeDL/checkpoints/pelvis/casenet2d - # path to latest checkpoint; if provided the training will be resumed from that checkpoint + # path to latest checkpoint or null; if provided the training will be 
resumed resume: /mnt/lustre/shenrui/project/edgeDL/checkpoints/pelvis/casenet2d/last_checkpoint_iter20000.pytorch # how many iterations between validations validate_after_iters: 5000 # how many iterations between tensorboard logging log_after_iters: 50 - # how many iterations evaluated in validations + # how many iterations evaluated in validations, set null for evaluating whole val set validate_iters: null # how many iterations before start level set alignment align_start_iters: 20000 @@ -30,7 +32,7 @@ trainer: align_after_iters: 5000 # max number of epochs epochs: 100 - # max number of iterations + # max number of iterations, set null for finishing all epochs iters: null # model with lower eval score is considered better eval_score_higher_is_better: False @@ -59,31 +61,47 @@ lr_scheduler: gamma: 0.2 # configuration for level set alignment level_set: + # Apply 2D or 3D level set dim: 2 + # lambda value in level set equation (see README.md) lambda_: 0.1 + # alpha value in level set equation (see README.md) alpha: 1 + # sigma value applied to gt map sigma: 1 + # smoothing iteration in each step smoothing: 1 + # radius value in generating boundary (SegToEdge) render_radius: 1 + # set true for semantic gt is_gt_semantic: True + # method used in level set ('MLS' for morphological level set) method: 'MLS' + # balloon value in level set balloon: 0 + # threshold value in level set threshold: 0.95 + # level set iterations step_ckpts: 50 - dz: 16 + # input shape for level set, dz x Y x X, set dz as 1 for 2D level set + dz: 1 + # batch size used in evaluation + batch_size: 16 + # path prefix for saving alignment results, full folder path is prefix + num of iter prefix: '/mnt/lustre/shenrui/data/pelvis_predict_BCE' - n_workers: 8 + # number of subprocesses used for level set calculation + n_workers: 16 # data loaders configuration loaders: - # train patch size given to the network (adapt to fit in your GPU mem, generally the bigger patch the better) + # train patch size given to the network, can be 2D [W, H] or 3D [D, W, H] train_patch: [350, 350] - # train stride between patches + # train stride between patches, can be 2D or 3D train_stride: [100, 100] # validation patch (can be bigger than train patch since there is no backprop) val_patch: [350, 350] # validation stride (validation patches doesn't need to overlap) val_stride: [350, 350] - # clip value within the range + # clip image value within the range clip_val: [-1000, 2000] # paths to the training datasets train_path: @@ -116,6 +134,8 @@ loaders: - name: ElasticDeformation spline_order: 0 - name: SegToEdge + out_channels: 5 + radius: 1 - name: ToTensor expand_dims: false dtype: 'long' @@ -126,6 +146,8 @@ loaders: expand_dims: true label: - name: SegToEdge + out_channels: 5 + radius: 1 - name: ToTensor expand_dims: false dtype: 'long' \ No newline at end of file diff --git a/resources/train_config_case3D.yaml b/resources/train_config_case3D.yaml index 57dd2d5..26c056c 100755 --- a/resources/train_config_case3D.yaml +++ b/resources/train_config_case3D.yaml @@ -1,10 +1,10 @@ -# use a fixed random seed to guarantee that when you run the code twice you will get the same outcome +# Set a fixed random seed for reproducibility, or null manual_seed: null -# Network input sample dimension +# Network dimension, determine whether to use 2D network or 3D network dim: 3 # model configuration model: - # model class + # model class, CASENet is based on ResNet name: ResNet # number of input channels to the model in_channels: 1 @@ -12,19 +12,19 @@ model: 
out_channels: 5 # set layers layers: [3, 4, 6, 3] - # apply element-wise nn.Sigmoid after the final 1x1 convolution, otherwise apply nn.Softmax + # whether to apply the sigmoid function, set false for training, true for testing final_sigmoid: False # trainer configuration trainer: # path to the checkpoint directory checkpoint_dir: /mnt/lustre/shenrui/project/edgeDL/checkpoints/pelvis/casenet3d - # path to latest checkpoint; if provided the training will be resumed from that checkpoint + # path to latest checkpoint or null; if provided the training will be resumed resume: null # how many iterations between validations validate_after_iters: 500 # how many iterations between tensorboard logging log_after_iters: 1 - # how many iterations evaluated in validations + # how many iterations evaluated in validations, set null for evaluating whole val set validate_iters: null # how many iterations before start level set alignment align_start_iters: 5000 @@ -32,7 +32,7 @@ trainer: align_after_iters: 2000 # max number of epochs epochs: 100 - # max number of iterations + # max number of iterations, set null for finishing all epochs iters: null # model with lower eval score is considered better eval_score_higher_is_better: False @@ -46,7 +46,7 @@ optimizer: loss: # loss function to be used during training name: STEALEdgeLoss - # A manual rescaling weight given to each class. + # A manual rescaling weight given to each class loss_weight: null # a target value that is ignored and does not contribute to the input gradient ignore_index: null @@ -61,31 +61,47 @@ lr_scheduler: gamma: 0.2 # configuration for level set alignment level_set: + # Apply 2D or 3D level set dim: 3 + # lambda value in level set equation (see README.md) lambda_: 0.1 + # alpha value in level set equation (see README.md) alpha: 1 + # sigma value applied to gt map sigma: 1 + # smoothing iteration in each step smoothing: 1 + # radius value in generating boundary (SegToEdge) render_radius: 1 + # set true for semantic gt is_gt_semantic: True + # method used in level set ('MLS' for morphological level set) method: 'MLS' + # balloon value in level set balloon: 0 + # threshold value in level set threshold: 0.95 + # level set iterations step_ckpts: 50 - dz: 32 + # input shape for level set, dz x Y x X, set dz as 1 for 2D level set + dz: 8 + # batch size used in evaluation + batch_size: 8 + # path prefix for saving alignment results, full folder path is prefix + num of iter prefix: '/mnt/lustre/shenrui/data/3D_pelvis_predict_BCE' + # number of subprocesses used for level set calculation n_workers: 16 # data loaders configuration loaders: - # train patch size given to the network (adapt to fit in your GPU mem, generally the bigger patch the better) + # train patch size given to the network, can be 2D [W, H] or 3D [D, W, H] train_patch: [8, 350, 350] - # train stride between patches + # train stride between patches, can be 2D or 3D train_stride: [4, 200, 200] # validation patch (can be bigger than train patch since there is no backprop) val_patch: [8, 350, 350] # validation stride (validation patches doesn't need to overlap) val_stride: [8, 350, 350] - # clip value within the range + # clip image value within the range clip_val: [-1000, 2000] # paths to the training datasets train_path: diff --git a/resources/train_config_unet.yaml b/resources/train_config_unet3D.yaml similarity index 100% rename from resources/train_config_unet.yaml rename to resources/train_config_unet3D.yaml diff --git a/tests/temp.py b/tests/temp.py deleted file mode 100644 
index 0f328aa..0000000 --- a/tests/temp.py +++ /dev/null @@ -1,2 +0,0 @@ -import h5py -f = h5py.File('random_label3D.h5', 'r') \ No newline at end of file diff --git a/train_casenet.py b/train_casenet.py index fe060d3..f215616 100755 --- a/train_casenet.py +++ b/train_casenet.py @@ -94,7 +94,7 @@ def main(): logger = get_logger('CASENetTrainer') parser = argparse.ArgumentParser(description='CASENet training') - parser.add_argument('--config', type=str, help='Path to the YAML config file', default='/home/SENSETIME/shenrui/Dropbox/SenseTime/edgeDL/resources/train_config_3Dbackup.yaml') + parser.add_argument('--config', type=str, help='Path to the YAML config file', default='/home/SENSETIME/shenrui/Dropbox/SenseTime/edgeDL/resources/train_config_backup.yaml') args = parser.parse_args() # Load and log experiment configuration diff --git a/utils/.resample.sh.swp b/utils/.resample.sh.swp new file mode 100644 index 0000000..62ee869 Binary files /dev/null and b/utils/.resample.sh.swp differ diff --git a/utils/contours/ContourBox.py b/utils/contours/ContourBox.py index ab11867..ee7781c 100755 --- a/utils/contours/ContourBox.py +++ b/utils/contours/ContourBox.py @@ -1,3 +1,6 @@ +# modified by Rui Shen, Aug 2019 +# ------------------------------------------------------------------ + # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. # # Redistribution and use in source and binary forms, with or without diff --git a/utils/contours/ContourBox_MLS.py b/utils/contours/ContourBox_MLS.py index 7b64d2b..106a5a6 100755 --- a/utils/contours/ContourBox_MLS.py +++ b/utils/contours/ContourBox_MLS.py @@ -1,3 +1,6 @@ +# modified by Rui Shen, Aug 2019 +# ------------------------------------------------------------------ + # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. # # Redistribution and use in source and binary forms, with or without @@ -50,7 +53,7 @@ def _fill_inside(self, bdry, method='fill_holes'): raise ValueError('_fill_inside wrong method:%s' % method) def _eval_singleK(self, gt_K, pK_Image, step_ckpts, lambda_, alpha, sigma, smoothing, - render_radius, is_gt_semantic, **kwargs): + render_radius, is_gt_semantic, dim, **kwargs): # This way of checking mostly cares about speed. as I am assuming the whole GT is very sparse. 
@@ -59,6 +62,10 @@ def _eval_singleK(self, gt_K, pK_Image, step_ckpts, lambda_, alpha, sigma, smoot if all_zeros: return gt_K + if dim == 2: + gt_K = gt_K[0] + pK_Image = pK_Image[0] + if is_gt_semantic is False: # filling inside_ to represent the curve # this may have problem with boundaries that are not closed(corners of the image dimension) @@ -66,7 +73,7 @@ def _eval_singleK(self, gt_K, pK_Image, step_ckpts, lambda_, alpha, sigma, smoot init_ls = self._fill_inside(gt_K) else: init_ls = gt_K - gt_K = seg2edges(gt_K, radius=render_radius) + gt_K = seg2edges(gt_K, radius=render_radius, dim=dim) h = self._compute_h(gt_K, pK_Image, lambda_, alpha, sigma) @@ -82,6 +89,9 @@ def _eval_singleK(self, gt_K, pK_Image, step_ckpts, lambda_, alpha, sigma, smoot evolution = morphological_geodesic_active_contour(h, step_ckpts, init_ls, smoothing=smoothing, balloon=balloon, threshold=threshold) + if dim == 2: + evolution = evolution[np.newaxis, :, :] + return evolution def process_batch_fn(self, args): @@ -120,7 +130,7 @@ def _multi_cpu_call(self, gt, pk): assert gt.shape == pk.shape global shared_mem_data shared_mem_data = (gt, pk) - N, K, H, W = pk.shape + N, K, D, H, W = pk.shape # pool = mp.Pool(min(N, self.n_workers)) @@ -140,21 +150,20 @@ def _multi_cpu_call_2(self, gt, pk): """ assert gt.shape == pk.shape - N, K, H, W = pk.shape + N, K, D, H, W = pk.shape - gt = np.reshape(gt, [N * K, 1, H, W]) - pk = np.reshape(pk, [N * K, 1, H, W]) + gt = np.reshape(gt, [N * K, 1, D, H, W]) + pk = np.reshape(pk, [N * K, 1, D, H, W]) global shared_mem_data shared_mem_data = (gt, pk) - # pool = mp.Pool(min(N * K, self.n_workers)) output_ = pool.map(self.process_batch_hack_multicpu, [(i, 1, id(gt), id(pk)) for i in range(N * K)]) - output_ = np.reshape(output_, [N, K, H, W]) + output_ = np.reshape(output_, [N, K, D, H, W]) pool.close() pool.join() # we need to reorder the array to make it compatible with the rest of the api. @@ -174,7 +183,8 @@ def __call__(self, gt_dict, pk): pk = pk.cpu().numpy() assert gt.shape == pk.shape - N, K, H, W = pk.shape + + N, K, D, H, W = pk.shape if N * K > 1 and self.n_workers > 1: output_ = self._multi_cpu_call_2(gt, pk) @@ -184,4 +194,4 @@ def __call__(self, gt_dict, pk): gt_hat = self.process_batch_fn((i, K, gt, pk)) output_.append(gt_hat) - return np.stack(output_, axis=0) # NxKxLStepsxHxW + return np.stack(output_, axis=0) # NxKxDxHxW diff --git a/utils/contours/cutils.py b/utils/contours/cutils.py index df8560b..6aa28fe 100755 --- a/utils/contours/cutils.py +++ b/utils/contours/cutils.py @@ -1,4 +1,5 @@ -# Modified by Rui Shen, 08/08/2019 +# modified by Rui Shen, Aug 2019 +# ------------------------------------------------------------------ # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. # @@ -54,7 +55,7 @@ def fn_post_process_callback(evol, pxwise): return fn_post_process_callback -def seg2edges(image, radius): +def seg2edges(image, radius, dim): """ :param image: semantic map should be HxW with values 0,1,label_ignores :param radius: radius size @@ -65,13 +66,20 @@ def seg2edges(image, radius): return image # we need to pad the borders, to solve problems with dt around the boundaries of the image. 
- image_pad = np.pad(image, ((1, 1), (1, 1)), mode='constant', constant_values=0) + if dim == 2: + image_pad = np.pad(image, ((1, 1), (1, 1)), mode='constant', constant_values=0) + elif dim == 3: + image_pad = np.pad(image, ((1, 1), (1, 1), (1, 1)), mode='constant', constant_values=0) + dist1 = distance_transform_edt(image_pad) dist2 = distance_transform_edt(1.0 - image_pad) dist = dist1 + dist2 # removing padding, it shouldnt affect result other than if the image is seg to the boundary. - dist = dist[1:-1, 1:-1] + if dim == 2: + dist = dist[1:-1, 1:-1] + elif dim == 3: + dist = dist[1:-1, 1:-1, 1:-1] assert dist.shape == image.shape dist[dist > radius] = 0 diff --git a/utils/contours/morph_snakes.py b/utils/contours/morph_snakes.py index 54daad0..c3d9dbf 100755 --- a/utils/contours/morph_snakes.py +++ b/utils/contours/morph_snakes.py @@ -1,3 +1,6 @@ +# modified by Rui Shen, Aug 2019 +# ------------------------------------------------------------------ + # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. # # Redistribution and use in source and binary forms, with or without diff --git a/utils/trainer.py b/utils/trainer.py index aa8989c..0459fe7 100755 --- a/utils/trainer.py +++ b/utils/trainer.py @@ -14,17 +14,17 @@ import nibabel as nib class NNTrainer: - """3D UNet trainer. + """Network trainer. Args: - model (Unet3D): UNet 3D model to be trained + model: network model to be trained optimizer (nn.optim.Optimizer): optimizer used for training lr_scheduler (torch.optim.lr_scheduler._LRScheduler): learning rate scheduler WARN: bear in mind that lr_scheduler.step() is invoked after every validation step (i.e. validate_after_iters) not after every epoch. So e.g. if one uses StepLR with step_size=30 the learning rate will be adjusted after every 30 * validate_after_iters iterations. 
        loss_criterion (callable): loss function
-        eval_criterion (callable): used to compute training/validation metric (such as Dice, IoU, AP or Rand score)
+        eval_criterion (callable): used to compute training/validation metric;
            saving the best checkpoint is based on the result of this function on the validation set
        device (torch.device): device to train on
        loaders (dict): 'train' and 'val' loaders
@@ -39,6 +39,9 @@ class NNTrainer:
        best_eval_score (float): best validation score so far (higher better)
        num_iterations (int): useful when loading the model from the checkpoint
        num_epoch (int): useful when loading the model from the checkpoint
+        align_start_iters (int): number of iterations before alignment starts
+        align_after_iters (int): number of iterations between two alignment steps
+        level_set_config (dict): configuration for level set alignment
    """

    def __init__(self, model, optimizer, lr_scheduler, loss_criterion,
@@ -228,7 +231,7 @@
                    self._log_params()
                    #self._log_images(input, target, output)

-                if (self.num_iterations >= self.align_start_iters) and (self.num_iterations % self.align_after_iters == 0):
+                if (self.num_iterations >= self.align_start_iters) and ((self.num_iterations - self.align_start_iters) % self.align_after_iters == 0):
                    self.loaders['train'] = self.align(self.loaders['train'])

                if self.max_num_iterations < self.num_iterations:
@@ -294,20 +297,38 @@
        for i in tqdm(range(len(datasets))):
            iz, iy, ix = datasets[i].raw.shape
            affine = datasets[i].affine
-            dz = self.level_set_config.get('dz', 16)
-            for js in range(0, iz, dz):
-                je = min(iz, js+dz)
+            if dim == 2:
+                dz = 1
+            elif dim == 3:
+                dz = self.level_set_config.get('dz', 1)
+            batch_size = self.level_set_config.get('batch_size', 1)
+            js = 0
+            while js < iz:
+                je = min(iz, js+dz*batch_size)
+                if (je - js) % dz != 0:
+                    je = iz - (je - js) % dz
+                    batch_size = 1
+                    dz = iz - je
                idx = slice(js, je)
+                js = je
+
+                data_sliced = (datasets[i].raw[idx]).astype(np.float32)
+                label_sliced = np.reshape((datasets[i].label[idx]), (-1, dz, iy, ix)).astype(np.long)
+
                if dim == 2:
-                    input = torch.from_numpy(((datasets[i].raw[idx]).astype(np.float32))[:,np.newaxis,:,:]).to(self.device)
+                    input = torch.from_numpy(data_sliced[:,np.newaxis,:,:]).to(self.device)
                    pred = torch.sigmoid(self.model(input))
+                    pred = pred.unsqueeze(2)
+
                elif dim == 3:
-                    input = torch.from_numpy(((datasets[i].raw[idx]).astype(np.float32))[np.newaxis,np.newaxis,:,:,:]).to(self.device)
-                    pred = torch.sigmoid(self.model(input)).squeeze(0).permute(1,0,2,3)
-                gt = self._expand_as_one_hot(torch.from_numpy((datasets[i].label[idx]).astype(np.long)).to(self.device), pred.shape[1])
+                    data_sliced = np.reshape(data_sliced, (-1, dz, iy, ix)).astype(np.float32)
+                    input = torch.from_numpy(data_sliced[:,np.newaxis,:,:,:]).to(self.device)
+                    pred = torch.sigmoid(self.model(input)).squeeze(0)
+
+                gt = self._expand_as_one_hot(torch.from_numpy(label_sliced).to(self.device), pred.shape[1])
                output = cbox({'seg': gt, 'bdry': None}, pred)
                output = np.multiply(np.sum(output, axis=1) > 0, np.argmax(output, axis=1) + 1)
-                loader.dataset.datasets[i].label[idx] = output
+                loader.dataset.datasets[i].label[idx] = np.reshape(output, (-1, iy, ix))
                if self.level_set_config['prefix'] is not None:
                    output_file = self._get_output_file(datasets[i], folderpath=folderpath, suffix='_refine')
                    nib.save(nib.Nifti1Image((np.transpose(loader.dataset.datasets[i].label).astype(np.int16)), affine), output_file)
@@ -461,7 +482,7 @@
shape.insert(1, C+1) shape = tuple(shape) - # expand the input tensor to Nx1x(D)xHxW + # expand the input tensor to NxCx(D)xHxW src = input.unsqueeze(1) if input.dim() == 3: return torch.zeros(shape).to(input.device).scatter_(1, src, 1)[:, 1:, :, :]