Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Torch FX][SD3] Add Dynamic Shapes #2656

Draft
wants to merge 12 commits into
base: latest
Choose a base branch
from

Conversation

anzr299
Copy link
Contributor

@anzr299 anzr299 commented Jan 15, 2025

Add Dynamic shapes to export of transformer model in SD3 Torch FX notebook.

Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

@eaidova
Copy link
Collaborator

eaidova commented Jan 15, 2025

@anzr299 looks like something went wrong

Cell In[11], line 48, in collect_calibration_data(pipe, calibration_dataset_size, num_inference_steps)
     41 dataset = datasets.load_dataset(
     42     "google-research-datasets/conceptual_captions",
     43     split="train",
     44     trust_remote_code=True,
     45 ).shuffle(seed=42)
     47 transformer_config = dict(pipe.transformer.config)
---> 48 del transformer_config["model"]
     49 wrapped_unet = UNetWrapper(pipe.transformer.model, transformer_config)
     50 pipe.transformer = wrapped_unet

KeyError: 'model'

could you please fix?

@eaidova
Copy link
Collaborator

eaidova commented Jan 21, 2025

@anzr299, I made rude fix in this commit: 894d859#diff-ced57002da68fe81391edf1eab6f32e3320f21f35656a2211871d180924b1f27

if "model" in config:
   del config["model"]

but not sure that it s correct fix, lack of expected key in config may hide some other issues (e.g. applying patching several times or changes in config structure).I think it will be better if you look on it

@eaidova
Copy link
Collaborator

eaidova commented Jan 21, 2025

Now, next error happens:

RuntimeError: Cannot access data pointer of Tensor (e.g. FakeTensor, FunctionalTensor). If you're using torch.compile/export/fx, it is likely that we are erroneously tracing into a custom kernel. To fix this, please wrap the custom kernel into an opaque custom op. Please see the following for details: https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html

@anzr299
Copy link
Contributor Author

anzr299 commented Jan 23, 2025

Sorry for being unresposnive, I am looking into this issue now.

@anzr299
Copy link
Contributor Author

anzr299 commented Jan 23, 2025

@anzr299, I made rude fix in this commit: 894d859#diff-ced57002da68fe81391edf1eab6f32e3320f21f35656a2211871d180924b1f27

if "model" in config:
   del config["model"]

but not sure that it s correct fix, lack of expected key in config may hide some other issues (e.g. applying patching several times or changes in config structure).I think it will be better if you look on it

It actually slipped in from an older version I had locally. I have removed it now.

Now, next error happens:

This looks like it is caused due to dynamic shapes in torch export

@anzr299 anzr299 marked this pull request as draft January 24, 2025 06:04
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants