From 76276f02e9f0b17c9931de88e0aa5c3351950e9b Mon Sep 17 00:00:00 2001 From: Pengfei Guo <32000655+guopengf@users.noreply.github.com> Date: Tue, 23 Jul 2024 13:25:43 -0400 Subject: [PATCH] update maisi contorlnet train config file (#1762) Fixes # . ### Description Refactor the controlnet train configuration to reuse the network definition in config_maisi.json, which can reduce redundancy. ### Checks - [ ] Avoid including large-size files in the PR. - [ ] Clean up long text outputs from code cells in the notebook. - [ ] For security purposes, please check the contents and remove any sensitive info such as user names and private key. - [ ] Ensure (1) hyperlinks and markdown anchors are working (2) use relative paths for tutorial repo files (3) put figure and graphs in the `./figure` folder - [ ] Notebook runs automatically `./runner.sh -t ` Signed-off-by: Pengfei Guo --- .../config_maisi_controlnet_train.json | 68 ------------------- generative/maisi/scripts/train_controlnet.py | 11 ++- 2 files changed, 10 insertions(+), 69 deletions(-) diff --git a/generative/maisi/configs/config_maisi_controlnet_train.json b/generative/maisi/configs/config_maisi_controlnet_train.json index 69fb5855e6..50adf9a478 100644 --- a/generative/maisi/configs/config_maisi_controlnet_train.json +++ b/generative/maisi/configs/config_maisi_controlnet_train.json @@ -1,72 +1,4 @@ { - "random_seed": null, - "spatial_dims": 3, - "image_channels": 1, - "latent_channels": 4, - "diffusion_unet_def": { - "_target_": "monai.apps.generation.maisi.networks.diffusion_model_unet_maisi.DiffusionModelUNetMaisi", - "spatial_dims": "@spatial_dims", - "in_channels": "@latent_channels", - "out_channels": "@latent_channels", - "num_channels": [ - 64, - 128, - 256, - 512 - ], - "attention_levels": [ - false, - false, - true, - true - ], - "num_head_channels": [ - 0, - 0, - 32, - 32 - ], - "num_res_blocks": 2, - "use_flash_attention": true, - "include_top_region_index_input": true, - "include_bottom_region_index_input": true, - "include_spacing_input": true - }, - "controlnet_def": { - "_target_": "monai.apps.generation.maisi.networks.controlnet_maisi.ControlNetMaisi", - "spatial_dims": "@spatial_dims", - "in_channels": "@latent_channels", - "num_channels": [ - 64, - 128, - 256, - 512 - ], - "attention_levels": [ - false, - false, - true, - true - ], - "num_head_channels": [ - 0, - 0, - 32, - 32 - ], - "num_res_blocks": 2, - "use_flash_attention": true, - "conditioning_embedding_in_channels": 8, - "conditioning_embedding_num_channels": [8, 32, 64] - }, - "noise_scheduler": { - "_target_": "generative.networks.schedulers.DDPMScheduler", - "num_train_timesteps": 1000, - "beta_start": 0.0015, - "beta_end": 0.0195, - "schedule": "scaled_linear_beta", - "clip_sample": false - }, "controlnet_train": { "batch_size": 1, "cache_rate": 0.0, diff --git a/generative/maisi/scripts/train_controlnet.py b/generative/maisi/scripts/train_controlnet.py index 40f4ea3a1f..f059fe205e 100644 --- a/generative/maisi/scripts/train_controlnet.py +++ b/generative/maisi/scripts/train_controlnet.py @@ -40,8 +40,14 @@ def main(): parser.add_argument( "-c", "--config-file", + default="./configs/config_maisi.json", + help="config json file that stores network hyper-parameters", + ) + parser.add_argument( + "-t", + "--training-config", default="./configs/config_maisi_controlnet_train.json", - help="config json file that stores hyper-parameters", + help="config json file that stores training hyper-parameters", ) parser.add_argument("-g", "--gpus", default=1, type=int, help="number of gpus per node") args = parser.parse_args() @@ -66,11 +72,14 @@ def main(): env_dict = json.load(open(args.environment_file, "r")) config_dict = json.load(open(args.config_file, "r")) + training_config_dict = json.load(open(args.training_config, "r")) for k, v in env_dict.items(): setattr(args, k, v) for k, v in config_dict.items(): setattr(args, k, v) + for k, v in training_config_dict.items(): + setattr(args, k, v) # initialize tensorboard writer if rank == 0: