Skip to content

Commit

Permalink
MAISI Quality check (Project-MONAI#1789)
Browse files Browse the repository at this point in the history
Fixes Project-MONAI#1791 .

### Description
Add MAISI Quality check algorithm

Add suggested spacing and output_size to generate results with better
quality.

### Checks
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [ ] Avoid including large-size files in the PR.
- [ ] Clean up long text outputs from code cells in the notebook.
- [ ] For security purposes, please check the contents and remove any
sensitive info such as user names and private key.
- [ ] Ensure (1) hyperlinks and markdown anchors are working (2) use
relative paths for tutorial repo files (3) put figure and graphs in the
`./figure` folder
- [ ] Notebook runs automatically `./runner.sh -t <path to .ipynb file>`

---------

Signed-off-by: Can-Zhao <volcanofly@gmail.com>
Signed-off-by: Can Zhao <69829124+Can-Zhao@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
Can-Zhao and pre-commit-ci[bot] authored Aug 21, 2024
1 parent 7f397e5 commit 2ec36f2
Show file tree
Hide file tree
Showing 6 changed files with 304 additions and 51 deletions.
11 changes: 7 additions & 4 deletions generation/maisi/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,27 +54,30 @@ MAISI is based on the following papers:
Network definition is stored in [./configs/config_maisi.json](./configs/config_maisi.json). Training and inference should use the same [./configs/config_maisi.json](./configs/config_maisi.json).

### 2. Model Inference
#### Inference parameters:
The information for the inference input, like body region and anatomy to generate, is stored in [./configs/config_infer.json](./configs/config_infer.json). Please feel free to play with it. Here are the details of the parameters.

- `"num_output_samples"`: int, the number of output image/mask pairs it will generate.
- `"spacing"`: voxel size of generated images. E.g., if set to `[1.5, 1.5, 2.0]`, it will generate images with a resolution of 1.5x1.5x2.0 mm.
- `"output_size"`: volume size of generated images. E.g., if set to `[512, 512, 256]`, it will generate images with size of 512x512x256. They need to be divisible by 16. If you have a small GPU memory size, you should adjust it to small numbers.
- `"output_size"`: volume size of generated images. E.g., if set to `[512, 512, 256]`, it will generate images with size of 512x512x256. They need to be divisible by 16. If you have a small GPU memory size, you should adjust it to small numbers. Note that `"spacing"` and `"output_size"` together decide the output field of view (FOV). For eample, if set them to `[1.5, 1.5, 2.0]`mm and `[512, 512, 256]`, the FOV is 768x768x512 mm. We recommend the FOV in x and y axis to be at least 256mm for head, and at least 384mm for other body regions like abdomen. There is no such restriction for z-axis.
- `"controllable_anatomy_size"`: a list of controllable anatomy and its size scale (0--1). E.g., if set to `[["liver", 0.5],["hepatic tumor", 0.3]]`, the generated image will contain liver that have a median size, with size around 50% percentile, and hepatic tumor that is relatively small, with around 30% percentile. The output will contain paired image and segmentation mask for the controllable anatomy.
- `"body_region"`: If "controllable_anatomy_size" is not specified, "body_region" will be used to constrain the region of generated images. It needs to be chosen from "head", "chest", "thorax", "abdomen", "pelvis", "lower".
- `"anatomy_list"`: If "controllable_anatomy_size" is not specified, the output will contain paired image and segmentation mask for the anatomy in "./configs/label_dict.json".
- `"autoencoder_sliding_window_infer_size"`: in order to save GPU memory, we use sliding window inference when decoding latents to image when `"output_size"` is large. This is the patch size of the sliding window. Small value will reduce GPU memory but increase time cost. They need to be divisible by 16.
- `"autoencoder_sliding_window_infer_overlap"`: float between 0 and 1. Large value will reduce the stitching artifacts when stitching patches during sliding window inference, but increase time cost. If you do not observe seam lines in the generated image result, you can use a smaller value to save inference time.


Please refer to [maisi_inference_tutorial.ipynb](maisi_inference_tutorial.ipynb) for the tutorial for MAISI model inference.

#### Execute Inference:
To run the inference script, please run:
```bash
export MONAI_DATA_DIRECTORY=<dir_you_will_download_data>
python -m scripts.inference -c ./configs/config_maisi.json -i ./configs/config_infer.json -e ./configs/environment.json --random-seed 0
```

Please refer to [maisi_inference_tutorial.ipynb](maisi_inference_tutorial.ipynb) for the tutorial for MAISI model inference.

#### Quality Check:
We have implemented a quality check function for the generated CT images. The main idea behind this function is to ensure that the Hounsfield units (HU) intensity for each organ in the CT images remains within a defined range. For each training image used in the Diffusion network, we computed the median value for a few major organs. Then we summarize the statistics of these median values and save it to [./configs/image_median_statistics.json](./configs/image_median_statistics.json). During inference, for each generated image, we compute the median HU values for the major organs and check whether they fall within the normal range.

### 3. Model Training
Training data preparation can be found in [./data/README.md](./data/README.md)

Expand Down
72 changes: 72 additions & 0 deletions generation/maisi/configs/image_median_statistics.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
{
"liver": {
"min_median": -14.0,
"max_median": 1000.0,
"percentile_0_5": 9.530000000000001,
"percentile_99_5": 162.0,
"sigma_6_low": -21.596463547885904,
"sigma_6_high": 156.27881534763367,
"sigma_12_low": -110.53410299564568,
"sigma_12_high": 245.21645479539342
},
"spleen": {
"min_median": -69.0,
"max_median": 1000.0,
"percentile_0_5": 16.925000000000004,
"percentile_99_5": 184.07500000000073,
"sigma_6_low": -43.133891656525165,
"sigma_6_high": 177.40494997185993,
"sigma_12_low": -153.4033124707177,
"sigma_12_high": 287.6743707860525
},
"pancreas": {
"min_median": -124.0,
"max_median": 1000.0,
"percentile_0_5": -29.0,
"percentile_99_5": 145.92000000000007,
"sigma_6_low": -56.59382515620725,
"sigma_6_high": 149.50627399318438,
"sigma_12_low": -159.64387473090306,
"sigma_12_high": 252.5563235678802
},
"kidney": {
"min_median": -165.5,
"max_median": 819.0,
"percentile_0_5": -40.0,
"percentile_99_5": 254.61999999999898,
"sigma_6_low": -130.56375604853028,
"sigma_6_high": 267.28163511081016,
"sigma_12_low": -329.4864516282005,
"sigma_12_high": 466.20433069048045
},
"lung": {
"min_median": -1000.0,
"max_median": 65.0,
"percentile_0_5": -937.0,
"percentile_99_5": -366.9500000000007,
"sigma_6_low": -1088.5583843889117,
"sigma_6_high": -551.8503346949108,
"sigma_12_low": -1356.912409235912,
"sigma_12_high": -283.4963098479103
},
"bone": {
"min_median": 77.5,
"max_median": 1000.0,
"percentile_0_5": 136.45499999999998,
"percentile_99_5": 551.6350000000002,
"sigma_6_low": 71.39901958080469,
"sigma_6_high": 471.9957615639765,
"sigma_12_low": -128.8993514107812,
"sigma_12_high": 672.2941325555623
},
"brain": {
"min_median": -1000.0,
"max_median": 238.0,
"percentile_0_5": -951.0,
"percentile_99_5": 126.25,
"sigma_6_low": -304.8208236135867,
"sigma_6_high": 369.5118535139189,
"sigma_12_low": -641.9871621773394,
"sigma_12_high": 706.6781920776717
}
}
2 changes: 1 addition & 1 deletion generation/maisi/scripts/diff_model_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

from .diff_model_setting import initialize_distributed, load_config, setup_logging
from .sample import ReconModel
from .utils import define_instance, load_autoencoder_ckpt
from .utils import define_instance


def set_random_seed(seed: int) -> int:
Expand Down
2 changes: 1 addition & 1 deletion generation/maisi/scripts/infer_controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from monai.utils import RankFilter

from .sample import ldm_conditional_sample_one_image
from .utils import define_instance, load_autoencoder_ckpt, prepare_maisi_controlnet_json_dataloader, setup_ddp
from .utils import define_instance, prepare_maisi_controlnet_json_dataloader, setup_ddp


@torch.inference_mode()
Expand Down
147 changes: 147 additions & 0 deletions generation/maisi/scripts/quality_check.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import nibabel as nib
import numpy as np


def get_masked_data(label_data, image_data, labels):
"""
Extracts and returns the image data corresponding to specified labels within a 3D volume.
This function efficiently masks the `image_data` array based on the provided `labels` in the `label_data` array.
The function handles cases with both a large and small number of labels, optimizing performance accordingly.
Args:
label_data (np.ndarray): A NumPy array containing label data, representing different anatomical
regions or classes in a 3D medical image.
image_data (np.ndarray): A NumPy array containing the image data from which the relevant regions
will be extracted.
labels (list of int): A list of integers representing the label values to be used for masking.
Returns:
np.ndarray: A NumPy array containing the elements of `image_data` that correspond to the specified
labels in `label_data`. If no labels are provided, an empty array is returned.
Raises:
ValueError: If `image_data` and `label_data` do not have the same shape.
Example:
label_int_dict = {"liver": [1], "kidney": [5, 14]}
masked_data = get_masked_data(label_data, image_data, label_int_dict["kidney"])
"""

# Check if the shapes of image_data and label_data match
if image_data.shape != label_data.shape:
raise ValueError(
f"Shape mismatch: image_data has shape {image_data.shape}, "
f"but label_data has shape {label_data.shape}. They must be the same."
)

if not labels:
return np.array([]) # Return an empty array if no labels are provided

labels = list(set(labels)) # remove duplicate items

# Optimize performance based on the number of labels
num_label_acceleration_thresh = 3
if len(labels) >= num_label_acceleration_thresh:
# if many labels, np.isin is faster
mask = np.isin(label_data, labels)
else:
# Use logical OR to combine masks if the number of labels is small
mask = np.zeros_like(label_data, dtype=bool)
for label in labels:
mask = np.logical_or(mask, label_data == label)

# Retrieve the masked data
masked_data = image_data[mask.astype(bool)]

return masked_data


def is_outlier(statistics, image_data, label_data, label_int_dict):
"""
Perform a quality check on the generated image by comparing its statistics with precomputed thresholds.
Args:
statistics (dict): Dictionary containing precomputed statistics including mean +/- 3sigma ranges.
image_data (np.ndarray): The image data to be checked, typically a 3D NumPy array.
label_data (np.ndarray): The label data corresponding to the image, used for masking regions of interest.
label_int_dict (dict): Dictionary mapping label names to their corresponding integer lists.
e.g., label_int_dict = {"liver": [1], "kidney": [5, 14]}
Returns:
dict: A dictionary with labels as keys, each containing the quality check result,
including whether it's an outlier, the median value, and the thresholds used.
If no data is found for a label, the median value will be `None` and `is_outlier` will be `False`.
Example:
# Example input data
statistics = {
"liver": {
"sigma_6_low": -21.596463547885904,
"sigma_6_high": 156.27881534763367
},
"kidney": {
"sigma_6_low": -15.0,
"sigma_6_high": 120.0
}
}
label_int_dict = {
"liver": [1],
"kidney": [5, 14]
}
image_data = np.random.rand(100, 100, 100) # Replace with actual image data
label_data = np.zeros((100, 100, 100)) # Replace with actual label data
label_data[40:60, 40:60, 40:60] = 1 # Example region for liver
label_data[70:90, 70:90, 70:90] = 5 # Example region for kidney
result = is_outlier(statistics, image_data, label_data, label_int_dict)
"""
outlier_results = {}

for label_name, stats in statistics.items():
# Get the thresholds from the statistics
low_thresh = stats["sigma_6_low"] # or "sigma_12_low" depending on your needs
high_thresh = stats["sigma_6_high"] # or "sigma_12_high" depending on your needs

# Retrieve the corresponding label integers
labels = label_int_dict.get(label_name, [])
masked_data = get_masked_data(label_data, image_data, labels)
masked_data = masked_data[~np.isnan(masked_data)]

if len(masked_data) == 0 or masked_data.size == 0:
outlier_results[label_name] = {
"is_outlier": False,
"median_value": None,
"low_thresh": low_thresh,
"high_thresh": high_thresh,
}
continue

# Compute the median of the masked region
median_value = np.nanmedian(masked_data)

if np.isnan(median_value):
median_value = None
is_outlier = False
else:
# Determine if the median value is an outlier
is_outlier = median_value < low_thresh or median_value > high_thresh

outlier_results[label_name] = {
"is_outlier": is_outlier,
"median_value": median_value,
"low_thresh": low_thresh,
"high_thresh": high_thresh,
}

return outlier_results
Loading

0 comments on commit 2ec36f2

Please sign in to comment.