Skip to content

Commit

Permalink
add labelmap even if empty nifti output
Browse files Browse the repository at this point in the history
  • Loading branch information
wasserth committed Aug 23, 2024
1 parent 4c507d9 commit 3186eb1
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions totalsegmentator/nnunet.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,15 @@ def nnUNet_predict_image(file_in: Union[str, Path, Nifti1Image], file_out, task_
if type(resample) is float:
resample = [resample, resample, resample]

if v1_order and task_name == "total":
label_map = class_map["total_v1"]
else:
label_map = class_map[task_name]

# Keep only voxel values corresponding to the roi_subset
if roi_subset is not None:
label_map = {k: v for k, v in label_map.items() if v in roi_subset}

# for debugging
# tmp_dir = file_in.parent / ("nnunet_tmp_" + ''.join(random.Random().choices(string.ascii_uppercase + string.digits, k=8)))
# (tmp_dir).mkdir(exist_ok=True)
Expand Down Expand Up @@ -384,6 +393,7 @@ def nnUNet_predict_image(file_in: Union[str, Path, Nifti1Image], file_out, task_
if not quiet:
print("INFO: Crop is empty. Returning empty segmentation.")
img_out = nib.Nifti1Image(np.zeros(img_in.shape, dtype=np.uint8), img_in.affine)
img_out = add_label_map_to_nifti(img_out, label_map)
nib.save(img_out, file_out)
if nora_tag != "None":
subprocess.call(f"/opt/nora/src/node/nora -p {nora_tag} --add {file_out} --addtag atlas", shell=True)
Expand Down Expand Up @@ -603,13 +613,9 @@ def nnUNet_predict_image(file_in: Union[str, Path, Nifti1Image], file_out, task_
# Reorder labels if needed
if v1_order and task_name == "total":
img_data = reorder_multilabel_like_v1(img_data, class_map["total"], class_map["total_v1"])
label_map = class_map["total_v1"]
else:
label_map = class_map[task_name]

# Keep only voxel values corresponding to the roi_subset
if roi_subset is not None:
label_map = {k: v for k, v in label_map.items() if v in roi_subset}
img_data *= np.isin(img_data, list(label_map.keys()))

# Prepare output nifti
Expand Down

0 comments on commit 3186eb1

Please sign in to comment.