From 08f6fe36d24b2a804cf713eefcf5c0e1091a59c1 Mon Sep 17 00:00:00 2001
From: wasserth
Date: Fri, 20 Oct 2023 11:09:27 +0200
Subject: [PATCH] change nnunet back to version 2.1

---
 setup.py                       |  2 +-
 totalsegmentator/libs.py       | 11 ++--
 totalsegmentator/nnunet.py     | 91 ++++++++++++++++++----------------
 totalsegmentator/python_api.py |  5 +-
 4 files changed, 56 insertions(+), 53 deletions(-)

diff --git a/setup.py b/setup.py
index fcf177784..f0aeb7882 100644
--- a/setup.py
+++ b/setup.py
@@ -31,7 +31,7 @@
         'p_tqdm',
         'xvfbwrapper',
         'fury',
-        'nnunetv2==2.2',
+        'nnunetv2==2.1',
         f'requests{requests_version}',
         'rt_utils',
         'dicom2nifti'
diff --git a/totalsegmentator/libs.py b/totalsegmentator/libs.py
index 74d7ffe56..bbf99f29d 100644
--- a/totalsegmentator/libs.py
+++ b/totalsegmentator/libs.py
@@ -159,8 +159,8 @@ def download_pretrained_weights(task_id):
         "nnUNet/3d_fullres/Task417_heart_mixed_317subj",
         "nnUNet/3d_fullres/Task278_TotalSegmentator_part6_bones_1259subj",
         "nnUNet/3d_fullres/Task435_Heart_vessels_118subj",
-        "Dataset297_TotalSegmentator_total_3mm_1559subj",  # for >= v2.0.4
-        # "Dataset297_TotalSegmentator_total_3mm_1559subj_v204",  # for >= v2.0.5
+        # "Dataset297_TotalSegmentator_total_3mm_1559subj",  # for >= v2.0.4
+        "Dataset297_TotalSegmentator_total_3mm_1559subj_v204",  # for >= v2.0.5
         # "Dataset298_TotalSegmentator_total_6mm_1559subj",  # for >= v2.0.5
     ]
 
@@ -193,11 +193,12 @@
         # WEIGHTS_URL = url + "/static/totalseg_v2/Dataset295_TotalSegmentator_part5_ribs_1559subj.zip"
         WEIGHTS_URL = url + "/v2.0.0-weights/Dataset295_TotalSegmentator_part5_ribs_1559subj.zip"
     elif task_id == 297:
-        weights_path = config_dir / "Dataset297_TotalSegmentator_total_3mm_1559subj_v204"
+        weights_path = config_dir / "Dataset297_TotalSegmentator_total_3mm_1559subj"
+        # weights_path = config_dir / "Dataset297_TotalSegmentator_total_3mm_1559subj_v204"
         # WEIGHTS_URL = "https://zenodo.org/record/6802052/files/Task256_TotalSegmentator_3mm_1139subj.zip?download=1"
         # WEIGHTS_URL = url + "/static/totalseg_v2/Dataset297_TotalSegmentator_total_3mm_1559subj.zip"
-        # WEIGHTS_URL = url + "/v2.0.0-weights/Dataset297_TotalSegmentator_total_3mm_1559subj.zip"  # v200
-        WEIGHTS_URL = url + "/v2.0.4-weights/Dataset297_TotalSegmentator_total_3mm_1559subj_v204.zip"
+        WEIGHTS_URL = url + "/v2.0.0-weights/Dataset297_TotalSegmentator_total_3mm_1559subj.zip"  # v200
+        # WEIGHTS_URL = url + "/v2.0.4-weights/Dataset297_TotalSegmentator_total_3mm_1559subj_v204.zip"
         # WEIGHTS_URL = url + "/v2.0.5-weights/Dataset297_TotalSegmentator_total_3mm_1559subj_v205.zip"
     elif task_id == 298:
         weights_path = config_dir / "Dataset298_TotalSegmentator_total_6mm_1559subj"
diff --git a/totalsegmentator/nnunet.py b/totalsegmentator/nnunet.py
index cc97d0130..d7c57dc6c 100644
--- a/totalsegmentator/nnunet.py
+++ b/totalsegmentator/nnunet.py
@@ -18,10 +18,12 @@
 from totalsegmentator.libs import nostdout
 
-# todo important: change
-# with nostdout():
-#     from nnunetv2.inference.predict_from_raw_data import predict_from_raw_data
-from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor
+# nnUNet 2.1
+with nostdout():
+    from nnunetv2.inference.predict_from_raw_data import predict_from_raw_data
+# nnUNet 2.2
+# from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor
+
 from nnunetv2.utilities.file_path_utilities import get_output_folder
 
 from totalsegmentator.map_to_binary import class_map, class_map_5_parts, map_taskid_to_partname
@@ -172,46 +174,47 @@ def nnUNetv2_predict(dir_in, dir_out, task_id, model="3d_fullres", folds=None,
     part_id = 0
     allow_tqdm = not quiet
 
-    # predict_from_raw_data(dir_in,
-    #                       dir_out,
-    #                       model_folder,
-    #                       folds,
-    #                       step_size,
-    #                       use_gaussian=True,
-    #                       use_mirroring=not disable_tta,
-    #                       perform_everything_on_gpu=True,
-    #                       verbose=verbose,
-    #                       save_probabilities=save_probabilities,
-    #                       overwrite=not continue_prediction,
-    #                       checkpoint_name=chk,
-    #                       num_processes_preprocessing=npp,
-    #                       num_processes_segmentation_export=nps,
-    #                       folder_with_segs_from_prev_stage=prev_stage_predictions,
-    #                       num_parts=num_parts,
-    #                       part_id=part_id,
-    #                       device=device)
-
-
-    predictor = nnUNetPredictor(
-        tile_step_size=step_size,
-        use_gaussian=True,
-        use_mirroring=not disable_tta,
-        perform_everything_on_gpu=True,
-        device=device,
-        verbose=verbose,
-        verbose_preprocessing=verbose,
-        allow_tqdm=allow_tqdm
-    )
-    predictor.initialize_from_trained_model_folder(
-        model_folder,
-        use_folds=folds,
-        checkpoint_name=chk,
-    )
-    predictor.predict_from_files(dir_in, dir_out,
-                                 save_probabilities=save_probabilities, overwrite=not continue_prediction,
-                                 num_processes_preprocessing=npp, num_processes_segmentation_export=nps,
-                                 folder_with_segs_from_prev_stage=prev_stage_predictions,
-                                 num_parts=num_parts, part_id=part_id)
+    # nnUNet 2.1
+    predict_from_raw_data(dir_in,
+                          dir_out,
+                          model_folder,
+                          folds,
+                          step_size,
+                          use_gaussian=True,
+                          use_mirroring=not disable_tta,
+                          perform_everything_on_gpu=True,
+                          verbose=verbose,
+                          save_probabilities=save_probabilities,
+                          overwrite=not continue_prediction,
+                          checkpoint_name=chk,
+                          num_processes_preprocessing=npp,
+                          num_processes_segmentation_export=nps,
+                          folder_with_segs_from_prev_stage=prev_stage_predictions,
+                          num_parts=num_parts,
+                          part_id=part_id,
+                          device=device)
+
+    # nnUNet 2.2
+    # predictor = nnUNetPredictor(
+    #     tile_step_size=step_size,
+    #     use_gaussian=True,
+    #     use_mirroring=not disable_tta,
+    #     perform_everything_on_gpu=True,
+    #     device=device,
+    #     verbose=verbose,
+    #     verbose_preprocessing=verbose,
+    #     allow_tqdm=allow_tqdm
+    # )
+    # predictor.initialize_from_trained_model_folder(
+    #     model_folder,
+    #     use_folds=folds,
+    #     checkpoint_name=chk,
+    # )
+    # predictor.predict_from_files(dir_in, dir_out,
+    #                              save_probabilities=save_probabilities, overwrite=not continue_prediction,
+    #                              num_processes_preprocessing=npp, num_processes_segmentation_export=nps,
+    #                              folder_with_segs_from_prev_stage=prev_stage_predictions,
+    #                              num_parts=num_parts, part_id=part_id)
+
 
 
 def save_segmentation_nifti(class_map_item, tmp_dir=None, file_out=None, nora_tag=None, header=None, task_name=None, quiet=None):
diff --git a/totalsegmentator/python_api.py b/totalsegmentator/python_api.py
index bca90efc9..faf9f7528 100644
--- a/totalsegmentator/python_api.py
+++ b/totalsegmentator/python_api.py
@@ -82,9 +82,8 @@ def totalsegmentator(input, output, ml=False, nr_thr_resamp=1, nr_thr_saving=6,
     if fast:
         task_id = 297
         resample = 3.0
-        # todo: change back to 4000 epochs in next release
-        # trainer = "nnUNetTrainer_4000epochs_NoMirroring"
-        trainer = "nnUNetTrainerNoMirroring"
+        trainer = "nnUNetTrainer_4000epochs_NoMirroring"
+        # trainer = "nnUNetTrainerNoMirroring"
         crop = None
         if not quiet: print("Using 'fast' option: resampling to lower resolution (3mm)")
     else:
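
Note (not part of the patch): for anyone who wants to exercise the restored nnUNet 2.1 code path outside of TotalSegmentator, the minimal sketch below mirrors the predict_from_raw_data call that this patch reinstates in nnUNetv2_predict. It assumes nnunetv2==2.1 and torch are installed; the input/output directories, the model folder path, the single fold (0,), and the "checkpoint_final.pth" checkpoint name are hypothetical placeholders, not values taken from the patch.

import torch
from nnunetv2.inference.predict_from_raw_data import predict_from_raw_data

predict_from_raw_data(
    "/tmp/ts_in",                     # dir_in: folder of *_0000.nii.gz images (hypothetical path)
    "/tmp/ts_out",                    # dir_out: segmentations are written here (hypothetical path)
    "/path/to/trained_model_folder",  # model folder, e.g. resolved via get_output_folder() (hypothetical)
    (0,),                             # folds (assumption: single fold 0)
    0.5,                              # step_size for the sliding-window inference
    use_gaussian=True,
    use_mirroring=False,              # corresponds to disable_tta=True in nnUNetv2_predict
    perform_everything_on_gpu=True,
    verbose=False,
    save_probabilities=False,
    overwrite=True,                   # corresponds to continue_prediction=False
    checkpoint_name="checkpoint_final.pth",
    num_processes_preprocessing=3,
    num_processes_segmentation_export=2,
    folder_with_segs_from_prev_stage=None,
    num_parts=1,
    part_id=0,
    device=torch.device("cuda"))

The keyword arguments are exactly the ones the patch passes; only the concrete values are illustrative. Under nnunetv2==2.2 this function no longer exists and the commented-out nnUNetPredictor class in the hunk above is the replacement, which is why the pin in setup.py and this call site have to change together.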