Fix PixArt 256px inference #6789
Conversation
```python
from diffusers import PixArtAlphaPipeline, Transformer2DModel
import torch

transformer = Transformer2DModel.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-256x256", subfolder="transformer", torch_dtype=torch.float16
)
pipe = PixArtAlphaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-1024-MS", transformer=transformer, torch_dtype=torch.float16
)
pipe = pipe.to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]
image.save("astronaut_rides_horse.png")
```

The test code is here for reference.
"--orig_ckpt_path", default=None, type=str, required=False, help="Path to the checkpoint to convert." | ||
) | ||
# set multi_scale_train=True if using PixArtMS structure during training else set it to False | ||
parser.add_argument("--multi_scale_train", default=True, type=str, required=True, help="If use Multi-Scale PixArtMS structure during training.") |
The type is `str` and we're defaulting to a bool. This needs to be fixed.
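For context, a minimal sketch (not necessarily the fix this PR adopted) of two common ways to keep the declared type and the default consistent for a boolean flag:

```python
import argparse

def str2bool(value: str) -> bool:
    # Explicit converter so "--multi_scale_train false" actually parses as False;
    # with type=str, any non-empty string (including "false") is truthy.
    if value.lower() in {"true", "1", "yes"}:
        return True
    if value.lower() in {"false", "0", "no"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")

parser = argparse.ArgumentParser()
# Option 1: a presence flag -- False unless the flag is passed.
parser.add_argument("--multi_scale_train", action="store_true",
                    help="Use the Multi-Scale PixArtMS structure during training.")
# Option 2 (alternative): keep an explicit value but give it a matching converter.
# parser.add_argument("--multi_scale_train", type=str2bool, default=True)
```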
Also, do we have any other 1024x1024 checkpoints that are affected by this? If not, do we really need this flag?
Thanks for the changes. Left a couple of comments.
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
thanks!
can we fix the tests here too?
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Is there any test failing here?
```diff
@@ -97,6 +97,7 @@ def __init__(
     norm_eps: float = 1e-5,
     attention_type: str = "default",
     caption_channels: int = None,
+    interpolation_scale: float = None,
```
We are not leveraging this anywhere, no? Let's remove it?
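For readers following along, a rough numpy sketch of how an `interpolation_scale` typically enters 2D sin-cos position embeddings (an illustration of the general technique, not the code from this PR): the coordinate grid is divided by the scale so a model trained at one resolution can run at another.

```python
import numpy as np

def sincos_embed_1d(embed_dim: int, pos: np.ndarray) -> np.ndarray:
    # Standard 1D sin-cos embedding over the flattened positions.
    omega = 1.0 / 10000 ** (np.arange(embed_dim // 2) / (embed_dim / 2))
    angles = np.einsum("p,d->pd", pos.reshape(-1), omega)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)

def sincos_embed_2d(embed_dim: int, grid_size: int, interpolation_scale: float = 1.0) -> np.ndarray:
    # Dividing the grid by interpolation_scale keeps the effective
    # coordinate range comparable across training resolutions.
    coords = np.arange(grid_size, dtype=np.float64) / interpolation_scale
    gy, gx = np.meshgrid(coords, coords, indexing="ij")
    emb_h = sincos_embed_1d(embed_dim // 2, gy)
    emb_w = sincos_embed_1d(embed_dim // 2, gx)
    return np.concatenate([emb_h, emb_w], axis=1)  # (grid_size**2, embed_dim)

# e.g. a 16x16 token grid with a 72-dim embedding
pos_embed = sincos_embed_2d(72, 16, interpolation_scale=1.0)
```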
```diff
@@ -228,6 +264,7 @@ def __init__(
     vae: AutoencoderKL,
     transformer: Transformer2DModel,
     scheduler: DPMSolverMultistepScheduler,
+    model_token_max_length: int = 120,
```
@yiyixuxu why did we decide to make this a config variable instead of a pipeline call arg?
@sayakpaul @lawrence-cj
Oh, I'm not sure. Is this something we need to change for each generation? E.g., if we need to adjust this value based on the prompt, I think it makes sense to add it as a pipeline call arg; otherwise, we can add it as a config, no?
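To make the two options concrete, a hypothetical sketch of the API shapes under discussion (the placement of `model_token_max_length` is the only thing that differs; the init kwarg mirrors the `__init__` diff above and is an assumption, not merged behavior):

```python
import torch
from diffusers import PixArtAlphaPipeline

# Option A (config): fixed when the pipeline is constructed/loaded.
pipe = PixArtAlphaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-1024-MS",
    model_token_max_length=120,  # hypothetical init/config kwarg
    torch_dtype=torch.float16,
).to("cuda")

prompt = "A small cactus with a happy face in the Sahara desert."

# Option B (call arg): tunable per generation, e.g. for unusually long prompts.
image = pipe(prompt, model_token_max_length=120).images[0]
```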
The maximum sequence length bit should be something a user wants to experiment with. I don't think it needs to be a configuration variable.
ok!
Looks very nice to me! Thank you ❤️
It'd be very nice to also include an example of using the 256x256 checkpoint in the PixArt-Alpha doc: https://huggingface.co/docs/diffusers/main/en/api/pipelines/pixart. WDYT?
Let's fix the quality tests :)
```python
import torch
from diffusers import PixArtAlphaPipeline

# You can replace the checkpoint id with "PixArt-alpha/PixArt-XL-2-512x512" or "PixArt-alpha/PixArt-XL-2-256x256" too.
pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16)

# Enable memory optimizations.
pipe.enable_model_cpu_offload()

prompt = "A small cactus with a happy face in the Sahara desert."
image = pipe(prompt).images[0]
```

How about this one? Using the 256px checkpoint is exactly the same as using the 512px or 1024px ones, which keeps things simple and efficient.
Yeah, that should work. But we still need to fix the tests here.
Sure. Most of the failures are caused by the …
@lawrence-cj let's keep …
```diff
@@ -688,6 +725,7 @@ def __call__(
     callback_steps: int = 1,
     clean_caption: bool = True,
     use_resolution_binning: bool = True,
+    model_token_max_length: int = 120,
```
```diff
- model_token_max_length: int = 120,
+ max_sequence_length: int = 120,
```
Let's use this variable name throughout?
Can we also add this argument to the call docstrings?
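Assuming the suggestion above is applied, per-call usage would look like this (a sketch; 120 is the default from the diff):

```python
# `pipe` is a PixArtAlphaPipeline loaded as in the snippets above.
image = pipe(
    "A small cactus with a happy face in the Sahara desert.",
    max_sequence_length=120,  # number of prompt tokens kept by the T5 text encoder
).images[0]
```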
Looking very nice. Just one comment and then I think we can merge it!
Done on my side. Thanks so much. @sayakpaul @yiyixuxu
Everything looks good. We just need to add …
Thanks a lot @lawrence-cj for your contributions here!
This PR removes the `interpolation_scale` >= 1 checking. Instead, we change it into the config file (config.json). Besides, we add a 256 bin for 256px generation. Fixes #6783 too.
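For readers unfamiliar with the binning mechanism: when `use_resolution_binning` is enabled, the pipeline snaps the requested size to the nearest trained aspect-ratio bucket, generates there, and resizes back. Below is a sketch with illustrative values (not the PR's actual table) of what a 256px bin looks like:

```python
# Illustrative values only -- the real table lives in the PixArt-Alpha pipeline.
ASPECT_RATIO_256_BIN_SKETCH = {
    "0.5": [176.0, 352.0],   # portrait: [height, width]
    "1.0": [256.0, 256.0],   # square, the native 256px resolution
    "2.0": [352.0, 176.0],   # landscape
}

def closest_bin(height: int, width: int, bins: dict) -> tuple:
    # Pick the bucket whose height/width ratio is nearest to the requested one.
    ratio = height / width
    key = min(bins, key=lambda k: abs(float(k) - ratio))
    h, w = bins[key]
    return h, w

print(closest_bin(300, 300, ASPECT_RATIO_256_BIN_SKETCH))  # -> (256.0, 256.0)
```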