[Torch FX][SD3] Add Dynamic Shapes #2656

Draft · wants to merge 12 commits into base: latest
31 changes: 21 additions & 10 deletions notebooks/stable-diffusion-v3/stable-diffusion-v3-torch-fx.ipynb
@@ -296,6 +296,17 @@
"unet_kwargs[\"encoder_hidden_states\"] = torch.ones((2, 154, 4096))\n",
"unet_kwargs[\"pooled_projections\"] = torch.ones((2, 2048))\n",
"\n",
"# Feature map height and width are dynamic\n",
"fm_height = torch.export.Dim(\"fm_height\", min=16, max=256)\n",
"fm_width = torch.export.Dim(\"fm_width\", min=16, max=256)\n",
"dim = torch.export.Dim(\"dim\", min=1, max=16)\n",
"fm_height = 16 * dim\n",
"fm_width = 16 * dim\n",
"\n",
"dynamic_shapes = {\"sample\": {2: fm_height, 3: fm_width}}\n",
"# iterate through the unet kwargs and set only hidden state kwarg to dynamic\n",
"dynamic_shapes_transformer = {key: (None if key != \"hidden_states\" else {2: fm_height, 3: fm_width}) for key in unet_kwargs.keys()}\n",
"\n",
"with torch.no_grad():\n",
" with disable_patching():\n",
" text_encoder = torch.export.export_for_training(\n",
Expand All @@ -308,10 +319,12 @@
" args=(text_encoder_input,),\n",
" kwargs=(text_encoder_kwargs),\n",
" ).module()\n",
" pipe.vae.decoder = torch.export.export_for_training(pipe.vae.decoder.eval(), args=(vae_decoder_input,)).module()\n",
" pipe.vae.encoder = torch.export.export_for_training(pipe.vae.encoder.eval(), args=(vae_encoder_input,)).module()\n",
" pipe.vae.decoder = torch.export.export_for_training(pipe.vae.decoder.eval(), args=(vae_decoder_input,), dynamic_shapes=dynamic_shapes).module()\n",
" pipe.vae.encoder = torch.export.export_for_training(pipe.vae.encoder.eval(), args=(vae_encoder_input,), dynamic_shapes=dynamic_shapes).module()\n",
" vae = pipe.vae\n",
" transformer = torch.export.export_for_training(pipe.transformer.eval(), args=(), kwargs=(unet_kwargs)).module()\n",
" transformer = torch.export.export_for_training(\n",
" pipe.transformer.eval(), args=(), kwargs=(unet_kwargs), dynamic_shapes=dynamic_shapes_transformer\n",
" ).module()\n",
"models_dict = {}\n",
"models_dict[\"transformer\"] = transformer\n",
"models_dict[\"vae\"] = vae\n",
@@ -450,8 +463,6 @@
" ).shuffle(seed=42)\n",
"\n",
" transformer_config = dict(pipe.transformer.config)\n",
" if \"model\" in transformer_config:\n",
" del transformer_config[\"model\"]\n",
" wrapped_unet = UNetWrapper(pipe.transformer.model, transformer_config)\n",
" pipe.transformer = wrapped_unet\n",
" # Run inference for data collection\n",
@@ -517,10 +528,10 @@
"if to_quantize:\n",
" with disable_patching():\n",
" with torch.no_grad():\n",
" nncf.compress_weights(text_encoder)\n",
" nncf.compress_weights(text_encoder_2)\n",
" nncf.compress_weights(vae_encoder)\n",
" nncf.compress_weights(vae_decoder)\n",
" text_encoder = nncf.compress_weights(text_encoder)\n",
" text_encoder_2 = nncf.compress_weights(text_encoder_2)\n",
" vae_encoder = nncf.compress_weights(vae_encoder)\n",
" vae_decoder = nncf.compress_weights(vae_decoder)\n",
" quantized_transformer = nncf.quantize(\n",
" model=original_transformer,\n",
" calibration_dataset=nncf.Dataset(unet_calibration_data),\n",
@@ -766,7 +777,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": ".venv",
"language": "python",
"name": "python3"
},