Maisi readme #1743

Merged on Aug 1, 2024 (63 commits)
Commits (changes from 46 commits shown)
c97509f
add readme
Can-Zhao Jun 27, 2024
544ee06
add readme
Can-Zhao Jun 27, 2024
033e614
add readme
Can-Zhao Jun 27, 2024
fce7a02
Merge branch 'main' into maisi_readme
Can-Zhao Jun 27, 2024
d315e0a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 27, 2024
1af6f91
correct typo
Can-Zhao Jun 27, 2024
0f2a0e1
Merge branch 'maisi_readme' of https://github.com/Can-Zhao/tutorials …
Can-Zhao Jun 27, 2024
8959b8b
add mri training data number
Can-Zhao Jun 28, 2024
cdc6ae7
add more details for inference
Can-Zhao Jun 28, 2024
aeb301e
Merge branch 'main' into maisi_readme
guopengf Jul 3, 2024
31ffdf5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 3, 2024
58f637f
test commit
guopengf Jul 3, 2024
9648aef
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 3, 2024
03f905a
Merge branch 'main' into maisi_readme
Can-Zhao Jul 4, 2024
87305e6
Merge branch 'main' into maisi_readme
Can-Zhao Jul 8, 2024
3fea7e3
Merge branch 'main' into maisi_readme
guopengf Jul 9, 2024
9e0c8df
update controlnet readme
guopengf Jul 9, 2024
3f18e85
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 9, 2024
bd030b7
Merge branch 'main' into maisi_readme
mingxin-zheng Jul 10, 2024
6aee4ff
Merge branch 'main' into maisi_readme
Can-Zhao Jul 11, 2024
3d94f1a
Merge branch 'Project-MONAI:main' into maisi_readme
Can-Zhao Jul 12, 2024
7ebfba7
Merge branch 'Project-MONAI:main' into maisi_readme
Can-Zhao Jul 13, 2024
7ea4767
Update readme for highlight, infer, and vae
Can-Zhao Jul 13, 2024
746cfa8
Update readme for highlight, infer, and vae
Can-Zhao Jul 13, 2024
914022c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 13, 2024
b767492
Update readme for highlight, infer, and vae
Can-Zhao Jul 13, 2024
73f2edc
Merge branch 'maisi_readme' of https://github.com/Can-Zhao/tutorials …
Can-Zhao Jul 13, 2024
54bf9d8
update controlnet readme
guopengf Jul 13, 2024
9aadd0f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 13, 2024
cba35c0
update controlnet readme
guopengf Jul 13, 2024
279f8ea
update vae readme, update vae botebook
Can-Zhao Jul 15, 2024
fc6e869
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 15, 2024
1b02466
Merge branch 'main' into maisi_readme
mingxin-zheng Jul 15, 2024
d1bdd4c
resolve conflict
Can-Zhao Jul 17, 2024
383e0dd
Merge branch 'maisi_readme' of https://github.com/Can-Zhao/tutorials …
Can-Zhao Jul 17, 2024
3c41b60
add description about ARM64
Can-Zhao Jul 17, 2024
8d56497
add description about ARM64
Can-Zhao Jul 17, 2024
880945d
add description about VAE data
Can-Zhao Jul 19, 2024
1cd46e8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 19, 2024
aa7b524
Merge branch 'main' into maisi_readme
guopengf Jul 23, 2024
23090c2
update readme
guopengf Jul 23, 2024
ae3875b
Merge branch 'main' into maisi_readme
guopengf Jul 23, 2024
8d3e4f6
add detail info on vae data
Can-Zhao Jul 23, 2024
4deda70
add detail info on vae data
Can-Zhao Jul 23, 2024
5de4aa9
typo
Can-Zhao Jul 23, 2024
a1aaba5
typo
Can-Zhao Jul 23, 2024
994008b
update controlnet part
guopengf Jul 26, 2024
028fe7f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 26, 2024
37a48a5
Update generative/maisi/data/README.md
guopengf Jul 26, 2024
09c741e
Merge branch 'main' into maisi_readme
KumoLiu Jul 29, 2024
43c8542
update
dongyang0122 Jul 29, 2024
861ab9f
update
dongyang0122 Jul 29, 2024
452210e
update
dongyang0122 Jul 29, 2024
4cdc0e1
update
dongyang0122 Jul 29, 2024
4a36743
update
dongyang0122 Jul 29, 2024
4f89f54
update
dongyang0122 Jul 29, 2024
c0a4355
update
dongyang0122 Jul 29, 2024
ce271d5
update
dongyang0122 Jul 29, 2024
6a36781
update
dongyang0122 Jul 29, 2024
a655980
update
guopengf Jul 31, 2024
bb3d46a
update
guopengf Jul 31, 2024
97738a7
update license
guopengf Jul 31, 2024
7274da7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 31, 2024
97 changes: 97 additions & 0 deletions generative/maisi/README.md
@@ -0,0 +1,97 @@
# Medical AI for Synthetic Imaging (MAISI)
This example shows the use cases of training and validating NVIDIA MAISI (Medical AI for Synthetic Imaging), a 3D Latent Diffusion Model that can generate large CT images with paired segmentation masks, flexible volume and voxel sizes, and controllable organ/tumor size.

## MAISI Model Highlight
- A Foundation VAE model for latent feature compression that works for both CT and MRI with flexible volume size and voxel size
- A Foundation Diffusion model that can generate large CT volumes up to 512x512x768 size, with flexible volume size and voxel size
- A ControlNet to generate image/mask pairs that can improve downstream tasks, with controllable organ/tumor size

## Example Results and Evaluation

## MAISI Model Workflow
The training and inference workflows of MAISI are depicted in the figure below. It begins by training an autoencoder in pixel space to encode images into latent features. Following that, it trains a diffusion model in the latent space to denoise the noisy latent features. During inference, it first generates latent features from random noise by applying multiple denoising steps using the trained diffusion model. Finally, it decodes the denoised latent features into images using the trained autoencoder.
<p align="center">
<img src="./figures/maisi_train.jpg" alt="MAISI training scheme">
<br>
<em>Figure 1: MAISI training scheme</em>
</p>

<p align="center">
<img src="./figures/maisi_infer.jpg" alt="MAISI inference scheme")
<br>
<em>Figure 2: MAISI inference scheme</em>
</p>
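
In code, the inference workflow in Figure 2 can be sketched as follows. This is a minimal illustration using the MONAI Generative Models API, not the actual MAISI implementation: the network hyperparameters, latent shape, and the absence of checkpoint loading are placeholder assumptions, and the real settings live in the config files referenced below.

```python
import torch
from generative.networks.nets import AutoencoderKL, DiffusionModelUNet
from generative.networks.schedulers import DDPMScheduler

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Placeholder networks; MAISI's actual architectures and trained weights come from its configs.
autoencoder = AutoencoderKL(spatial_dims=3, in_channels=1, out_channels=1, latent_channels=4).to(device)
diffusion_unet = DiffusionModelUNet(spatial_dims=3, in_channels=4, out_channels=4).to(device)
scheduler = DDPMScheduler(num_train_timesteps=1000)

# Start from random noise in latent space (a small toy shape; MAISI's latent is much larger).
latents = torch.randn((1, 4, 32, 32, 32), device=device)

# Iteratively denoise the latents with the diffusion model, then decode with the autoencoder.
scheduler.set_timesteps(num_inference_steps=1000)
with torch.no_grad():
    for t in scheduler.timesteps:
        noise_pred = diffusion_unet(latents, timesteps=torch.as_tensor((t,), device=device))
        latents, _ = scheduler.step(noise_pred, t, latents)
    image = autoencoder.decode_stage_2_outputs(latents)
```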
MAISI is based on the following papers:

[**Latent Diffusion:** Rombach, Robin, et al. "High-resolution image synthesis with latent diffusion models." CVPR 2022.](https://openaccess.thecvf.com/content/CVPR2022/papers/Rombach_High-Resolution_Image_Synthesis_With_Latent_Diffusion_Models_CVPR_2022_paper.pdf)

[**ControlNet:** Lvmin Zhang, Anyi Rao, Maneesh Agrawala; “Adding Conditional Control to Text-to-Image Diffusion Models.” ICCV 2023.](https://openaccess.thecvf.com/content/ICCV2023/papers/Zhang_Adding_Conditional_Control_to_Text-to-Image_Diffusion_Models_ICCV_2023_paper.pdf)

### 1. Installation
Please refer to the [Installation of MONAI Generative Model](../README.md).

Note: MAISI depends on the [xFormers](https://github.com/facebookresearch/xformers) library, which does not yet support ARM64. We will update this tutorial once xFormers adds ARM64 support.

### 2. Model inference and example outputs
Please refer to [maisi_inference_tutorial.ipynb](maisi_inference_tutorial.ipynb) for a tutorial on MAISI model inference.

### 3. Training example
Training data preparation is described in [./data/README.md](./data/README.md).

#### [3.1 3D Autoencoder Training](./train_autoencoder.py)
Please refer to [maisi_train_vae_tutorial.ipynb](maisi_train_vae_tutorial.ipynb) for a tutorial on MAISI VAE model training.

#### [3.2 3D Latent Diffusion Training](./train_diffusion.py)
The training script uses the batch size and patch size defined in the configuration files. If your GPU memory size differs, adjust the `"batch_size"` and `"patch_size"` parameters in the `"diffusion_train"` section to match your GPU. Note that `"patch_size"` needs to be divisible by 16, as the quick check below illustrates.
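
After editing the config, something like the following can verify the constraint (a hedged sketch; the exact key layout inside `config_maisi.json` is assumed from the description above):

```python
import json

with open("./config/config_maisi.json") as f:
    config = json.load(f)

train_config = config["diffusion_train"]  # section name taken from the text above
print("batch_size:", train_config["batch_size"])
for dim in train_config["patch_size"]:
    # Latent diffusion training requires each patch dimension to be divisible by 16.
    assert dim % 16 == 0, f"patch_size entry {dim} is not divisible by 16"
```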

To train with a single 32GB GPU, please run:
```bash
python train_diffusion.py -c ./config/config_maisi.json -e ./config/environment.json -g 1
```

The training script also enables multi-GPU training. For instance, if you are using eight 32GB GPUs, you can run the training script with the following command:
```bash
export NUM_GPUS_PER_NODE=8
torchrun \
--nproc_per_node=${NUM_GPUS_PER_NODE} \
--nnodes=1 \
--master_addr=localhost --master_port=1234 \
train_diffusion.py -c ./config/config_maisi.json -e ./config/environment.json -g ${NUM_GPUS_PER_NODE}
```
<p align="center">
<img src="./figs/train_diffusion.png" alt="latent diffusion train curve" width="45%" >
&nbsp; &nbsp; &nbsp; &nbsp;
<img src="./figs/val_diffusion.png" alt="latent diffusion validation curve" width="45%" >
</p>

#### [3.3 3D ControlNet Training](./scripts/train_controlnet.py)

We provide a [training config](./configs/config_maisi_controlnet_train.json) for fine-tuning the pretrained ControlNet with a new class (i.e., Kidney Tumor). When fine-tuning with other new class names, please update `weighted_loss_label` in the training config and [label_dict.json](./configs/label_dict.json) accordingly. There are 8 dummy labels as placeholders in the default `label_dict.json` that can be used for fine-tuning, as illustrated in the sketch below. The preprocessed dataset for ControlNet training and more details about data preparation can be found in the [README](./data/README.md).
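
A minimal sketch of that `label_dict.json` edit follows; the placeholder key name below is hypothetical, so check the shipped file for the actual dummy entries:

```python
import json

with open("./configs/label_dict.json") as f:
    label_dict = json.load(f)

# Hypothetical edit: rename one dummy placeholder to the new class name,
# keeping its integer index so the mask encoding stays consistent.
dummy_key = "dummy_0"  # assumed placeholder name
if dummy_key in label_dict:
    label_dict["kidney tumor"] = label_dict.pop(dummy_key)

with open("./configs/label_dict.json", "w") as f:
    json.dump(label_dict, f, indent=4)
```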

#### Training configuration
The training was performed with the following:
- GPU: at least 60GB of GPU memory for a 512 x 512 x 512 volume
- Actual Model Input (the size of image embedding in latent space): 128 x 128 x 128
- AMP: True

#### Execute training:
To train with a single GPU, please run:
```bash
python ./scripts/train_controlnet.py -c ./config/config_maisi.json -t ./config/config_maisi_controlnet_train.json -e ./config/environment_maisi_controlnet_train.json -g 1
```

The training script also enables multi-GPU training. For instance, if you are using eight GPUs, you can run the training script with the following command:
```bash
export NUM_GPUS_PER_NODE=8
torchrun \
--nproc_per_node=${NUM_GPUS_PER_NODE} \
--nnodes=1 \
--master_addr=localhost --master_port=1234 \
./scripts/train_controlnet.py -c ./config/config_maisi.json -t ./config/config_maisi_controlnet_train.json -e ./config/environment_maisi_controlnet_train.json -g ${NUM_GPUS_PER_NODE}
```
### 4. Questions and bugs

- For questions relating to the use of MONAI, please use our [Discussions tab](https://github.com/Project-MONAI/MONAI/discussions) on the main repository of MONAI.
- For bugs relating to MONAI functionality, please create an issue on the [main repository](https://github.com/Project-MONAI/MONAI/issues).
- For bugs relating to the running of a tutorial, please create an issue in [this repository](https://github.com/Project-MONAI/Tutorials/issues).
107 changes: 107 additions & 0 deletions generative/maisi/data/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# Medical AI for Synthetic Imaging (MAISI) Data Preparation

Disclaimer: We are not the host of the data. Please make sure to read the requirements and usage policies of the data and give credit to the authors of the datasets!

### 1. VAE Training Data
For the released Foundation autoencoder model weights in MAISI, we used 37243 CT training volumes and 1963 CT validation volumes from the chest, abdomen, and head-and-neck regions, and 17887 MRI training volumes and 940 MRI validation volumes from the brain, skull-stripped brain, chest, and below-abdomen regions. The training data come from [TCIA Covid 19 Chest CT](https://wiki.cancerimagingarchive.net/display/Public/CT+Images+in+COVID-19#70227107b92475d33ae7421a9b9c426f5bb7d5b3), [TCIA Colon Abdomen CT](https://wiki.cancerimagingarchive.net/pages/viewpage.action?pageId=3539213), [MSD03 Liver Abdomen CT](http://medicaldecathlon.com/), [LIDC chest CT](https://www.cancerimagingarchive.net/collection/lidc-idri/), [TCIA Stony Brook Covid Chest CT](https://www.cancerimagingarchive.net/collection/covid-19-ny-sbu/), [NLST Chest CT](https://www.cancerimagingarchive.net/collection/nlst/), [TCIA Upenn GBM Brain MR](https://wiki.cancerimagingarchive.net/pages/viewpage.action?pageId=70225642), [Aomic Brain MR](https://openneuro.org/datasets/ds003097/versions/1.2.1), [QTIM Brain MR](https://openneuro.org/datasets/ds004169/versions/1.0.7), [TCIA Acrin Chest MR](https://www.cancerimagingarchive.net/collection/acrin-contralateral-breast-mr/), and [TCIA Prostate MR Below-Abdomen MR](https://wiki.cancerimagingarchive.net/pages/viewpage.action?pageId=68550661#68550661a2c52df5969d435eae49b9669bea21a6).

In total, we included:
| Index | Dataset Name | Number of Training Volumes | Number of Validation Volumes |
|-------|------------------------------------------------|-------------------------|---------------------------|
| 1 | Covid 19 Chest CT | 722 | 49 |
| 2 | TCIA Colon Abdomen CT | 1522 | 77 |
| 3 | MSD03 Liver Abdomen CT | 104 | 0 |
| 4 | LIDC chest CT | 450 | 24 |
| 5 | TCIA Stony Brook Covid Chest CT | 2644 | 139 |
| 6 | NLST Chest CT | 31801 | 1674 |
| 7 | TCIA Upenn GBM Brain MR (skull-stripped) | 2550 | 134 |
| 8 | Aomic Brain MR | 2630 | 138 |
| 9 | QTIM Brain MR | 1275 | 67 |
| 10 | Acrin Chest MR | 6599 | 347 |
| 11 | TCIA Prostate MR Below-Abdomen MR | 928 | 49 |
| 12 | Aomic Brain MR, skull-stripped | 2630 | 138 |
| 13 | QTIM Brain MR, skull-stripped | 1275 | 67 |
| | Total CT | 37243 | 1963 |
| | Total MRI | 17887 | 940 |


### 2. Diffusion Model Training Data


### 3. ControlNet Model Training Data

#### 3.1 Example preprocessed dataset

We provide a preprocessed subset of the [C4KC-KiTS](https://www.cancerimagingarchive.net/collection/c4kc-kits/) dataset used in the fine-tuning config `environment_maisi_controlnet_train.json`. The dataset and corresponding JSON data list can be downloaded from [this link](https://drive.google.com/drive/folders/1iMStdYxcl26dEXgJEXOjkWvx-I2fYZ2u?usp=sharing) and should be saved in the `maisi/dataset/` folder.

The structure of an example folder in the preprocessed dataset is:
```
            |-*arterial*.nii.gz           # original image
            |-*arterial_emb*.nii.gz       # encoded image embedding
KiTS-000* --|-mask*.nii.gz                # original labels
            |-mask_pseudo_label*.nii.gz   # pseudo labels
            |-mask_combined_label*.nii.gz # combined mask of original and pseudo labels

```
An example combined mask of original and pseudo labels is shown below:
![example_combined_mask](../figures/example_combined_mask.png)

Please note that the Kidney Tumor label is mapped to index `129` in this preprocessed dataset. The encoded image embeddings were generated during preprocessing by the provided `Autoencoder` in `./models/autoencoder_epoch273.pt` to reduce memory usage during training. The pseudo labels were generated by [VISTA 3D](https://github.com/Project-MONAI/VISTA). In addition, the dimensions of each volume and its corresponding pseudo label are resampled to the closest multiple of 128 (e.g., 128, 256, 384, 512, ...).
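
The rounding rule for the resampled dimensions can be expressed as a one-liner (a sketch of the behavior described above, not the preprocessing code itself):

```python
def closest_multiple_of_128(dim: int) -> int:
    """Round a volume dimension to the closest multiple of 128, with 128 as the minimum."""
    return max(128, round(dim / 128) * 128)

print([closest_multiple_of_128(d) for d in (70, 333, 512)])  # [128, 384, 512]
```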

The training workflow requires a JSON file that specifies the image embedding and segmentation pairs. An example file is located at `maisi/dataset/C4KC-KiTS_subset.json`.

The JSON file has the following structure:
```python
{
    "training": [
        {
            "image": "*/*arterial_emb*.nii.gz",        # relative path to the image embedding file
            "label": "*/mask_combined_label*.nii.gz",  # relative path to the combined label file
            "dim": [512, 512, 512],                    # the dimension of the image
            "spacing": [1.0, 1.0, 1.0],                # the voxel spacing of the image
            "top_region_index": [0, 1, 0, 0],          # the top region index of the image
            "bottom_region_index": [0, 0, 0, 1],       # the bottom region index of the image
            "fold": 0                                  # fold index for cross-validation; fold 0 is used for training
        },
        ...
    ]
}
```
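
A short sketch of how such a data list might be consumed is shown below; the fold-handling convention is an assumption based on the `"fold"` comment above:

```python
import json

with open("maisi/dataset/C4KC-KiTS_subset.json") as f:
    datalist = json.load(f)

# Assumption from the comment above: fold 0 is used for training,
# and the remaining folds are held out for validation.
train_files = [d for d in datalist["training"] if d["fold"] == 0]
val_files = [d for d in datalist["training"] if d["fold"] != 0]
print(f"{len(train_files)} training cases, {len(val_files)} validation cases")
```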

#### 3.2 ControlNet full training datasets
The ControlNet training dataset used in MAISI contains 6330 CT volumes (5058 for training and 1272 for validation) across 20 datasets, covering different body regions and diseases.

The table below summarizes the number of volumes for each dataset.

|Index| Dataset name|Number of volumes|
|:-----|:-----|:-----|
| 1 | AbdomenCT-1K | 789 |
| 2 | AeroPath | 15 |
| 3 | AMOS22 | 240 |
| 4 | Bone-Lesion | 237 |
| 5 | BTCV | 48 |
| 6 | CT-ORG | 94 |
| 7 | CTPelvic1K-CLINIC | 94 |
| 8 | LIDC | 422 |
| 9 | MSD Task03 | 105 |
| 10 | MSD Task06 | 50 |
| 11 | MSD Task07 | 225 |
| 12 | MSD Task08 | 235 |
| 13 | MSD Task09 | 33 |
| 14 | MSD Task10 | 101 |
| 15 | Multi-organ-Abdominal-CT | 64 |
| 16 | Pancreas-CT | 51 |
| 17 | StonyBrook-CT | 1258 |
| 18 | TCIA_Colon | 1436 |
| 19 | TotalSegmentatorV2 | 654 |
| 20 | VerSe | 179 |

### 4. Questions and bugs

- For questions relating to the use of MONAI, please use our [Discussions tab](https://github.com/Project-MONAI/MONAI/discussions) on the main repository of MONAI.
- For bugs relating to MONAI functionality, please create an issue on the [main repository](https://github.com/Project-MONAI/MONAI/issues).
- For bugs relating to the running of a tutorial, please create an issue in [this repository](https://github.com/Project-MONAI/Tutorials/issues).

### Reference
[1] [Rombach, Robin, et al. "High-resolution image synthesis with latent diffusion models." CVPR 2022.](https://openaccess.thecvf.com/content/CVPR2022/papers/Rombach_High-Resolution_Image_Synthesis_With_Latent_Diffusion_Models_CVPR_2022_paper.pdf)
Binary file added generative/maisi/figures/maisi_infer.jpg
Binary file added generative/maisi/figures/maisi_train.jpg
8 changes: 4 additions & 4 deletions generative/maisi/maisi_train_vae_tutorial.ipynb
@@ -872,6 +872,10 @@
" print(f\"Epoch {epoch} train_vae_loss {loss_weighted_sum(train_epoch_losses)}: {train_epoch_losses}.\")\n",
" for loss_name, loss_value in train_epoch_losses.items():\n",
" tensorboard_writer.add_scalar(f\"train_{loss_name}_epoch\", loss_value, epoch)\n",
" torch.save(autoencoder.state_dict(), trained_g_path)\n",
" torch.save(discriminator.state_dict(), trained_d_path)\n",
" print(\"Save trained autoencoder to\", trained_g_path)\n",
" print(\"Save trained discriminator to\", trained_d_path)\n",
"\n",
" # Validation\n",
" if epoch % val_interval == 0:\n",
@@ -891,12 +895,8 @@
" for key in val_epoch_losses:\n",
" val_epoch_losses[key] /= len(dataloader_val)\n",
"\n",
" torch.save(autoencoder.state_dict(), trained_g_path)\n",
" torch.save(discriminator.state_dict(), trained_d_path)\n",
" val_loss_g = loss_weighted_sum(val_epoch_losses)\n",
" print(f\"Epoch {epoch} val_vae_loss {val_loss_g}: {val_epoch_losses}.\")\n",
" print(\"Save trained autoencoder to\", trained_g_path)\n",
" print(\"Save trained discriminator to\", trained_d_path)\n",
"\n",
" if val_loss_g < best_val_recon_epoch_loss:\n",
" best_val_recon_epoch_loss = val_loss_g\n",