change nnunet back to version 2.1
wasserth committed Oct 20, 2023
1 parent 676965b commit 08f6fe3
Showing 4 changed files with 56 additions and 53 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -31,7 +31,7 @@
         'p_tqdm',
         'xvfbwrapper',
         'fury',
-        'nnunetv2==2.2',
+        'nnunetv2==2.1',
         f'requests{requests_version}',
         'rt_utils',
         'dicom2nifti'
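An environment that already has nnunetv2 2.2 installed will keep the wrong API until it is reinstalled against this pin. Below is a minimal runtime guard sketch using only the standard library's importlib.metadata; the helper is illustrative and not part of the repository:

```python
# Illustrative guard, not part of TotalSegmentator: fail early if the
# installed nnunetv2 does not match the version pinned in setup.py.
from importlib.metadata import version, PackageNotFoundError

def check_nnunet_pin(expected: str = "2.1") -> None:
    try:
        installed = version("nnunetv2")
    except PackageNotFoundError:
        raise RuntimeError("nnunetv2 is not installed; try: pip install nnunetv2==2.1")
    if installed != expected:
        raise RuntimeError(f"nnunetv2=={expected} required, found {installed}")
```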
11 changes: 6 additions & 5 deletions totalsegmentator/libs.py
@@ -159,8 +159,8 @@ def download_pretrained_weights(task_id):
         "nnUNet/3d_fullres/Task417_heart_mixed_317subj",
         "nnUNet/3d_fullres/Task278_TotalSegmentator_part6_bones_1259subj",
         "nnUNet/3d_fullres/Task435_Heart_vessels_118subj",
-        "Dataset297_TotalSegmentator_total_3mm_1559subj", # for >= v2.0.4
-        # "Dataset297_TotalSegmentator_total_3mm_1559subj_v204", # for >= v2.0.5
+        # "Dataset297_TotalSegmentator_total_3mm_1559subj", # for >= v2.0.4
+        "Dataset297_TotalSegmentator_total_3mm_1559subj_v204", # for >= v2.0.5
         # "Dataset298_TotalSegmentator_total_6mm_1559subj", # for >= v2.0.5
     ]

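The list above appears to enumerate superseded weight folders so they can be cleared from the local config directory when current weights are fetched; after this commit the v204 folder is the one marked for removal. A hedged sketch of such a cleanup step, with a hypothetical helper name:

```python
# Hypothetical sketch of clearing superseded weight folders; the actual
# logic lives in download_pretrained_weights() in totalsegmentator/libs.py.
import shutil
from pathlib import Path

def remove_old_weights(config_dir: Path, old_weights: list[str]) -> None:
    for name in old_weights:
        folder = config_dir / name  # e.g. "Dataset297_TotalSegmentator_total_3mm_1559subj_v204"
        if folder.exists():
            shutil.rmtree(folder)
```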
@@ -193,11 +193,12 @@ def download_pretrained_weights(task_id):
         # WEIGHTS_URL = url + "/static/totalseg_v2/Dataset295_TotalSegmentator_part5_ribs_1559subj.zip"
         WEIGHTS_URL = url + "/v2.0.0-weights/Dataset295_TotalSegmentator_part5_ribs_1559subj.zip"
     elif task_id == 297:
-        weights_path = config_dir / "Dataset297_TotalSegmentator_total_3mm_1559subj_v204"
+        weights_path = config_dir / "Dataset297_TotalSegmentator_total_3mm_1559subj"
+        # weights_path = config_dir / "Dataset297_TotalSegmentator_total_3mm_1559subj_v204"
         # WEIGHTS_URL = "https://zenodo.org/record/6802052/files/Task256_TotalSegmentator_3mm_1139subj.zip?download=1"
         # WEIGHTS_URL = url + "/static/totalseg_v2/Dataset297_TotalSegmentator_total_3mm_1559subj.zip"
-        # WEIGHTS_URL = url + "/v2.0.0-weights/Dataset297_TotalSegmentator_total_3mm_1559subj.zip" # v200
-        WEIGHTS_URL = url + "/v2.0.4-weights/Dataset297_TotalSegmentator_total_3mm_1559subj_v204.zip"
+        WEIGHTS_URL = url + "/v2.0.0-weights/Dataset297_TotalSegmentator_total_3mm_1559subj.zip" # v200
+        # WEIGHTS_URL = url + "/v2.0.4-weights/Dataset297_TotalSegmentator_total_3mm_1559subj_v204.zip"
         # WEIGHTS_URL = url + "/v2.0.5-weights/Dataset297_TotalSegmentator_total_3mm_1559subj_v205.zip"
     elif task_id == 298:
         weights_path = config_dir / "Dataset298_TotalSegmentator_total_6mm_1559subj"
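For reference, resolving WEIGHTS_URL and unpacking the archive next to weights_path reduces to a fetch-and-extract step. A minimal sketch assuming requests and the standard zipfile module; the repository's real routine adds progress reporting and error handling:

```python
# Minimal sketch only; not the repository's actual download routine.
import io
import zipfile
from pathlib import Path

import requests

def fetch_weights(weights_url: str, config_dir: Path) -> None:
    resp = requests.get(weights_url, timeout=600)  # weight zips are large
    resp.raise_for_status()
    with zipfile.ZipFile(io.BytesIO(resp.content)) as zf:
        zf.extractall(config_dir)  # unpacks e.g. Dataset297_TotalSegmentator_total_3mm_1559subj/
```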
91 changes: 47 additions & 44 deletions totalsegmentator/nnunet.py
@@ -18,10 +18,12 @@

 from totalsegmentator.libs import nostdout

-# todo important: change
-# with nostdout():
-#     from nnunetv2.inference.predict_from_raw_data import predict_from_raw_data
-from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor
+# nnUNet 2.1
+with nostdout():
+    from nnunetv2.inference.predict_from_raw_data import predict_from_raw_data
+# nnUNet 2.2
+# from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor

 from nnunetv2.utilities.file_path_utilities import get_output_folder

 from totalsegmentator.map_to_binary import class_map, class_map_5_parts, map_taskid_to_partname
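Instead of commenting one import block in and out per release, the two APIs could also be selected at import time. A hedged alternative sketch, not what the repository does; both import targets appear in the diff above:

```python
# Hypothetical version shim: nnUNet 2.2 exposes the class-based
# nnUNetPredictor, while 2.1 only has the functional entry point.
try:
    from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor
    HAVE_PREDICTOR_API = True   # nnUNet >= 2.2
except ImportError:
    from nnunetv2.inference.predict_from_raw_data import predict_from_raw_data
    HAVE_PREDICTOR_API = False  # nnUNet 2.1
```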
@@ -172,46 +174,47 @@ def nnUNetv2_predict(dir_in, dir_out, task_id, model="3d_fullres", folds=None,
         part_id = 0
     allow_tqdm = not quiet

-    # predict_from_raw_data(dir_in,
-    #                       dir_out,
-    #                       model_folder,
-    #                       folds,
-    #                       step_size,
-    #                       use_gaussian=True,
-    #                       use_mirroring=not disable_tta,
-    #                       perform_everything_on_gpu=True,
-    #                       verbose=verbose,
-    #                       save_probabilities=save_probabilities,
-    #                       overwrite=not continue_prediction,
-    #                       checkpoint_name=chk,
-    #                       num_processes_preprocessing=npp,
-    #                       num_processes_segmentation_export=nps,
-    #                       folder_with_segs_from_prev_stage=prev_stage_predictions,
-    #                       num_parts=num_parts,
-    #                       part_id=part_id,
-    #                       device=device)
-
-
-    predictor = nnUNetPredictor(
-        tile_step_size=step_size,
-        use_gaussian=True,
-        use_mirroring=not disable_tta,
-        perform_everything_on_gpu=True,
-        device=device,
-        verbose=verbose,
-        verbose_preprocessing=verbose,
-        allow_tqdm=allow_tqdm
-    )
-    predictor.initialize_from_trained_model_folder(
-        model_folder,
-        use_folds=folds,
-        checkpoint_name=chk,
-    )
-    predictor.predict_from_files(dir_in, dir_out,
-                                 save_probabilities=save_probabilities, overwrite=not continue_prediction,
-                                 num_processes_preprocessing=npp, num_processes_segmentation_export=nps,
-                                 folder_with_segs_from_prev_stage=prev_stage_predictions,
-                                 num_parts=num_parts, part_id=part_id)
+    # nnUNet 2.1
+    predict_from_raw_data(dir_in,
+                          dir_out,
+                          model_folder,
+                          folds,
+                          step_size,
+                          use_gaussian=True,
+                          use_mirroring=not disable_tta,
+                          perform_everything_on_gpu=True,
+                          verbose=verbose,
+                          save_probabilities=save_probabilities,
+                          overwrite=not continue_prediction,
+                          checkpoint_name=chk,
+                          num_processes_preprocessing=npp,
+                          num_processes_segmentation_export=nps,
+                          folder_with_segs_from_prev_stage=prev_stage_predictions,
+                          num_parts=num_parts,
+                          part_id=part_id,
+                          device=device)
+
+    # nnUNet 2.2
+    # predictor = nnUNetPredictor(
+    #     tile_step_size=step_size,
+    #     use_gaussian=True,
+    #     use_mirroring=not disable_tta,
+    #     perform_everything_on_gpu=True,
+    #     device=device,
+    #     verbose=verbose,
+    #     verbose_preprocessing=verbose,
+    #     allow_tqdm=allow_tqdm
+    # )
+    # predictor.initialize_from_trained_model_folder(
+    #     model_folder,
+    #     use_folds=folds,
+    #     checkpoint_name=chk,
+    # )
+    # predictor.predict_from_files(dir_in, dir_out,
+    #                              save_probabilities=save_probabilities, overwrite=not continue_prediction,
+    #                              num_processes_preprocessing=npp, num_processes_segmentation_export=nps,
+    #                              folder_with_segs_from_prev_stage=prev_stage_predictions,
+    #                              num_parts=num_parts, part_id=part_id)


 def save_segmentation_nifti(class_map_item, tmp_dir=None, file_out=None, nora_tag=None, header=None, task_name=None, quiet=None):
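After this change, nnUNetv2_predict drives inference through the 2.1 functional API. A usage sketch with placeholder paths, using only the parameters visible in the signature above:

```python
# Sketch only: placeholder paths; dir_in must contain images named the
# way nnU-Net expects (e.g. case_0000.nii.gz).
nnUNetv2_predict("/tmp/ts_input", "/tmp/ts_output", task_id=297,
                 model="3d_fullres", folds=[0])
```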
5 changes: 2 additions & 3 deletions totalsegmentator/python_api.py
@@ -82,9 +82,8 @@ def totalsegmentator(input, output, ml=False, nr_thr_resamp=1, nr_thr_saving=6,
     if fast:
         task_id = 297
         resample = 3.0
-        # todo: change back to 4000 epochs in next release
-        # trainer = "nnUNetTrainer_4000epochs_NoMirroring"
-        trainer = "nnUNetTrainerNoMirroring"
+        trainer = "nnUNetTrainer_4000epochs_NoMirroring"
+        # trainer = "nnUNetTrainerNoMirroring"
         crop = None
         if not quiet: print("Using 'fast' option: resampling to lower resolution (3mm)")
     else:
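The branch above is what fast=True selects from the public API, now with the 4000-epoch trainer restored. A usage sketch; the input path is a placeholder:

```python
from totalsegmentator.python_api import totalsegmentator

# fast=True resamples to 3 mm and uses task 297 with the
# nnUNetTrainer_4000epochs_NoMirroring trainer restored by this commit.
totalsegmentator("ct.nii.gz", "segmentations", fast=True)
```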
