From 3186eb10261be4c5658beb4d632ec99f3e081c5f Mon Sep 17 00:00:00 2001 From: wasserth Date: Fri, 23 Aug 2024 14:58:50 +0200 Subject: [PATCH] add labelmap even if empty nifti output --- totalsegmentator/nnunet.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/totalsegmentator/nnunet.py b/totalsegmentator/nnunet.py index 2c2b7894d..40ee2d92a 100644 --- a/totalsegmentator/nnunet.py +++ b/totalsegmentator/nnunet.py @@ -327,6 +327,15 @@ def nnUNet_predict_image(file_in: Union[str, Path, Nifti1Image], file_out, task_ if type(resample) is float: resample = [resample, resample, resample] + if v1_order and task_name == "total": + label_map = class_map["total_v1"] + else: + label_map = class_map[task_name] + + # Keep only voxel values corresponding to the roi_subset + if roi_subset is not None: + label_map = {k: v for k, v in label_map.items() if v in roi_subset} + # for debugging # tmp_dir = file_in.parent / ("nnunet_tmp_" + ''.join(random.Random().choices(string.ascii_uppercase + string.digits, k=8))) # (tmp_dir).mkdir(exist_ok=True) @@ -384,6 +393,7 @@ def nnUNet_predict_image(file_in: Union[str, Path, Nifti1Image], file_out, task_ if not quiet: print("INFO: Crop is empty. Returning empty segmentation.") img_out = nib.Nifti1Image(np.zeros(img_in.shape, dtype=np.uint8), img_in.affine) + img_out = add_label_map_to_nifti(img_out, label_map) nib.save(img_out, file_out) if nora_tag != "None": subprocess.call(f"/opt/nora/src/node/nora -p {nora_tag} --add {file_out} --addtag atlas", shell=True) @@ -603,13 +613,9 @@ def nnUNet_predict_image(file_in: Union[str, Path, Nifti1Image], file_out, task_ # Reorder labels if needed if v1_order and task_name == "total": img_data = reorder_multilabel_like_v1(img_data, class_map["total"], class_map["total_v1"]) - label_map = class_map["total_v1"] - else: - label_map = class_map[task_name] # Keep only voxel values corresponding to the roi_subset if roi_subset is not None: - label_map = {k: v for k, v in label_map.items() if v in roi_subset} img_data *= np.isin(img_data, list(label_map.keys())) # Prepare output nifti