Skip to content

Commit

Permalink
add roi_subset_robust and fastest argument
Browse files Browse the repository at this point in the history
  • Loading branch information
wasserth committed Nov 23, 2023
1 parent 69771d4 commit dd1d716
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 5 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
* Bugfix: add flush to DummyFile
* Require python >= 3.9 in setup.py
* properly add `vertebrae_body` model
* add `--roi_subset_robust` argument
* add `--fastest` argument


## Release 2.0.5
Expand Down
12 changes: 10 additions & 2 deletions totalsegmentator/bin/TotalSegmentator.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@ def main():
parser.add_argument("-ns", "--nr_thr_saving", type=int, help="Nr of threads for saving segmentations",
default=6)

parser.add_argument("-f", "--fast", action="store_true", help="Run faster lower resolution model",
parser.add_argument("-f", "--fast", action="store_true", help="Run faster lower resolution model (3mm)",
default=False)

parser.add_argument("-ff", "--fastest", action="store_true", help="Run even faster lower resolution model (6mm)",
default=False)

parser.add_argument("-t", "--nora_tag", type=str,
Expand All @@ -57,6 +60,11 @@ def main():
parser.add_argument("-rs", "--roi_subset", type=str, nargs="+",
help="Define a subset of classes to save (space separated list of class names). If running 1.5mm model, will only run the appropriate models for these rois.")

# Will use 3mm model instead of 6mm model to crop to the rois specified in this argument.
# 3mm is slower but more accurate.
parser.add_argument("-rsr", "--roi_subset_robust", type=str, nargs="+",
help="Like roi_subset but uses a slower but more robust model to find the rois.")

parser.add_argument("-s", "--statistics", action="store_true",
help="Calc volume (in mm3) and mean intensity. Results will be in statistics.json",
default=False)
Expand Down Expand Up @@ -125,7 +133,7 @@ def main():
args.statistics, args.radiomics, args.crop_path, args.body_seg,
args.force_split, args.output_type, args.quiet, args.verbose, args.test, args.skip_saving,
args.device, args.license_number, not args.stats_include_incomplete,
args.no_derived_masks, args.v1_order)
args.no_derived_masks, args.v1_order, args.fastest, args.roi_subset_robust)


if __name__ == "__main__":
Expand Down
22 changes: 19 additions & 3 deletions totalsegmentator/python_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def totalsegmentator(input, output, ml=False, nr_thr_resamp=1, nr_thr_saving=6,
force_split=False, output_type="nifti", quiet=False, verbose=False, test=0,
skip_saving=False, device="gpu", license_number=None,
statistics_exclude_masks_at_border=True, no_derived_masks=False,
v1_order=False):
v1_order=False, fastest=False, roi_subset_robust=None):
"""
Run TotalSegmentator from within python.
Expand Down Expand Up @@ -86,6 +86,12 @@ def totalsegmentator(input, output, ml=False, nr_thr_resamp=1, nr_thr_saving=6,
# trainer = "nnUNetTrainerNoMirroring"
crop = None
if not quiet: print("Using 'fast' option: resampling to lower resolution (3mm)")
elif fastest:
task_id = 298
resample = 6.0
trainer = "nnUNetTrainer_4000epochs_NoMirroring"
crop = None
if not quiet: print("Using 'fastest' option: resampling to lower resolution (6mm)")
else:
task_id = [291, 292, 293, 294, 295]
resample = 1.5
Expand Down Expand Up @@ -246,6 +252,10 @@ def totalsegmentator(input, output, ml=False, nr_thr_resamp=1, nr_thr_saving=6,
else:
download_pretrained_weights(task_id)

if roi_subset_robust is not None:
roi_subset = roi_subset_robust
robust_rs = True

if roi_subset is not None and type(roi_subset) is not list:
raise ValueError("roi_subset must be a list of strings")
if roi_subset is not None and task != "total":
Expand All @@ -258,8 +268,14 @@ def totalsegmentator(input, output, ml=False, nr_thr_resamp=1, nr_thr_saving=6,
download_pretrained_weights(298)
st = time.time()
if not quiet: print("Generating rough body segmentation...")
organ_seg, _ = nnUNet_predict_image(input, None, 298, model="3d_fullres", folds=[0],
trainer="nnUNetTrainer_4000epochs_NoMirroring", tta=False, multilabel_image=True, resample=6.0,
if robust_rs:
crop_model_task = 297
crop_spacing = 3.0
else:
crop_model_task = 298
crop_spacing = 6.0
organ_seg, _ = nnUNet_predict_image(input, None, crop_model_task, model="3d_fullres", folds=[0],
trainer="nnUNetTrainer_4000epochs_NoMirroring", tta=False, multilabel_image=True, resample=crop_spacing,
crop=None, crop_path=None, task_name="total", nora_tag="None", preview=False,
save_binary=False, nr_threads_resampling=nr_thr_resamp, nr_threads_saving=1,
crop_addon=crop_addon, output_type=output_type, statistics=False,
Expand Down

0 comments on commit dd1d716

Please sign in to comment.