Skip to content

Commit

Permalink
update maisi contorlnet train config file (Project-MONAI#1762)
Browse files Browse the repository at this point in the history
Fixes # .

### Description

Refactor the controlnet train configuration to reuse the network
definition in config_maisi.json, which can reduce redundancy.

### Checks
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [ ] Avoid including large-size files in the PR.
- [ ] Clean up long text outputs from code cells in the notebook.
- [ ] For security purposes, please check the contents and remove any
sensitive info such as user names and private key.
- [ ] Ensure (1) hyperlinks and markdown anchors are working (2) use
relative paths for tutorial repo files (3) put figure and graphs in the
`./figure` folder
- [ ] Notebook runs automatically `./runner.sh -t <path to .ipynb file>`

Signed-off-by: Pengfei Guo <pengfeig@nvidia.com>
  • Loading branch information
guopengf authored Jul 23, 2024
1 parent 5160f8a commit 76276f0
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 69 deletions.
68 changes: 0 additions & 68 deletions generative/maisi/configs/config_maisi_controlnet_train.json
Original file line number Diff line number Diff line change
@@ -1,72 +1,4 @@
{
"random_seed": null,
"spatial_dims": 3,
"image_channels": 1,
"latent_channels": 4,
"diffusion_unet_def": {
"_target_": "monai.apps.generation.maisi.networks.diffusion_model_unet_maisi.DiffusionModelUNetMaisi",
"spatial_dims": "@spatial_dims",
"in_channels": "@latent_channels",
"out_channels": "@latent_channels",
"num_channels": [
64,
128,
256,
512
],
"attention_levels": [
false,
false,
true,
true
],
"num_head_channels": [
0,
0,
32,
32
],
"num_res_blocks": 2,
"use_flash_attention": true,
"include_top_region_index_input": true,
"include_bottom_region_index_input": true,
"include_spacing_input": true
},
"controlnet_def": {
"_target_": "monai.apps.generation.maisi.networks.controlnet_maisi.ControlNetMaisi",
"spatial_dims": "@spatial_dims",
"in_channels": "@latent_channels",
"num_channels": [
64,
128,
256,
512
],
"attention_levels": [
false,
false,
true,
true
],
"num_head_channels": [
0,
0,
32,
32
],
"num_res_blocks": 2,
"use_flash_attention": true,
"conditioning_embedding_in_channels": 8,
"conditioning_embedding_num_channels": [8, 32, 64]
},
"noise_scheduler": {
"_target_": "generative.networks.schedulers.DDPMScheduler",
"num_train_timesteps": 1000,
"beta_start": 0.0015,
"beta_end": 0.0195,
"schedule": "scaled_linear_beta",
"clip_sample": false
},
"controlnet_train": {
"batch_size": 1,
"cache_rate": 0.0,
Expand Down
11 changes: 10 additions & 1 deletion generative/maisi/scripts/train_controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,14 @@ def main():
parser.add_argument(
"-c",
"--config-file",
default="./configs/config_maisi.json",
help="config json file that stores network hyper-parameters",
)
parser.add_argument(
"-t",
"--training-config",
default="./configs/config_maisi_controlnet_train.json",
help="config json file that stores hyper-parameters",
help="config json file that stores training hyper-parameters",
)
parser.add_argument("-g", "--gpus", default=1, type=int, help="number of gpus per node")
args = parser.parse_args()
Expand All @@ -66,11 +72,14 @@ def main():

env_dict = json.load(open(args.environment_file, "r"))
config_dict = json.load(open(args.config_file, "r"))
training_config_dict = json.load(open(args.training_config, "r"))

for k, v in env_dict.items():
setattr(args, k, v)
for k, v in config_dict.items():
setattr(args, k, v)
for k, v in training_config_dict.items():
setattr(args, k, v)

# initialize tensorboard writer
if rank == 0:
Expand Down

0 comments on commit 76276f0

Please sign in to comment.