
Fix PixArt 256px inference #6789

Merged: 18 commits merged into huggingface:main on Mar 3, 2024

Conversation

@lawrence-cj (Contributor) commented Jan 31, 2024

This PR:

  1. Removes the interpolation_scale >= 1 check; the value is now taken from the transformer's config file (config.json) instead (see the sketch below). It also adds a 256 aspect-ratio bin for 256px generation.
  2. Moves the T5 max token length definition into the pipeline config file (model_index.json), as discussed here.
  3. Fixes a bug in the weight-conversion script.

Fixes #6783 as well.
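To make item 1 concrete, here is a minimal sketch (not code from this PR) of what reading the scale from the model config looks like; it assumes interpolation_scale is registered in the transformer's config.json as described above:

# Illustrative sketch: after this PR, the interpolation scale is carried by the
# checkpoint's config.json (and may be below 1 for the 256px model) rather than
# being clamped to >= 1 in code.
import torch
from diffusers import Transformer2DModel

transformer = Transformer2DModel.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-256x256", subfolder="transformer", torch_dtype=torch.float16
)
print(transformer.config.interpolation_scale)  # value comes straight from config.json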

@lawrence-cj (Contributor, Author)

from diffusers import PixArtAlphaPipeline, Transformer2DModel
import torch

# Load the 256px transformer and plug it into the existing 1024-MS pipeline.
transformer = Transformer2DModel.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-256x256", subfolder="transformer", torch_dtype=torch.float16
)
pipe = PixArtAlphaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-1024-MS", transformer=transformer, torch_dtype=torch.float16
)
pipe = pipe.to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]

image.save("astronaut_rides_horse.png")

The test code is here for reference.

"--orig_ckpt_path", default=None, type=str, required=False, help="Path to the checkpoint to convert."
)
# set multi_scale_train=True if using PixArtMS structure during training else set it to False
parser.add_argument("--multi_scale_train", default=True, type=str, required=True, help="If use Multi-Scale PixArtMS structure during training.")
Member

The type is str and we're defaulting to a bool. This needs to be fixed.
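For reference, a couple of standard ways to express such a flag as a real boolean (a sketch only, not the wording the PR ultimately uses):

# Illustrative only: making --multi_scale_train a proper boolean flag instead of
# a str argument with a bool default.
import argparse

parser = argparse.ArgumentParser()

# Option 1: an on/off switch; the flag is False unless passed on the command line.
parser.add_argument(
    "--multi_scale_train",
    action="store_true",
    help="Use the Multi-Scale PixArtMS structure during training.",
)

# Option 2 (Python 3.9+): generates both --multi_scale_train and --no-multi_scale_train.
# parser.add_argument("--multi_scale_train", action=argparse.BooleanOptionalAction, default=True)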

Member

Also, do we have any other 1024x1024 checkpoints that are affected by this? If not, do we really need this flag?

@sayakpaul (Member) left a comment

Thanks for the changes. Left a couple of comments.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@yiyixuxu (Collaborator) left a comment

thanks!

@yiyixuxu (Collaborator) commented Feb 8, 2024

can we fix the tests here too?

@lawrence-cj (Contributor, Author)

Is there any test failing here?

@@ -97,6 +97,7 @@ def __init__(
norm_eps: float = 1e-5,
attention_type: str = "default",
caption_channels: int = None,
interpolation_scale: float = None,
Member

We are not leveraging this anywhere, are we? Let's remove it?

@@ -228,6 +264,7 @@ def __init__(
vae: AutoencoderKL,
transformer: Transformer2DModel,
scheduler: DPMSolverMultistepScheduler,
model_token_max_length: int = 120,
Member

@yiyixuxu why did we decide to make this as a config variable instead of a pipeline call arg?

Collaborator

@sayakpaul @lawrence-cj
Oh, I'm not sure. Is this something we need to change for each generation? If we need to adjust this value based on the prompt, I think it makes sense to add it as a pipeline call argument; otherwise, we can add it as a config, no?

Member

The maximum sequence length is something a user may want to experiment with, so I don't think it needs to be a configuration variable.

Collaborator

ok!

@sayakpaul (Member) left a comment

Looks very nice to me! Thank you ❤️

It'd be very nice to also include an example of using the 256x256 checkpoint in the PixArt-Alpha doc: https://huggingface.co/docs/diffusers/main/en/api/pipelines/pixart. WDYT?

@sayakpaul (Member)

Let's fix the quality tests :)

@lawrence-cj (Contributor, Author) commented Feb 13, 2024

> Looks very nice to me! Thank you ❤️
>
> It'd be very nice to also include an example of using the 256x256 checkpoint in the PixArt-Alpha doc: https://huggingface.co/docs/diffusers/main/en/api/pipelines/pixart. WDYT?

import torch
from diffusers import PixArtAlphaPipeline

# You can replace the checkpoint id with "PixArt-alpha/PixArt-XL-2-512x512" or "PixArt-alpha/PixArt-XL-2-256x256" too.
pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16)
# Enable memory optimizations.
pipe.enable_model_cpu_offload()

prompt = "A small cactus with a happy face in the Sahara desert."
image = pipe(prompt).images[0]

How about this one? For simplicity and efficiency, using the 256px checkpoint is exactly the same as using the 512px or 1024px ones.

@sayakpaul (Member)

Yeah that should work. But we still need to fix the tests here.

@lawrence-cj (Contributor, Author)

Sure. Most of the failures are caused by model_token_max_length. Maybe you should decide how to handle it first, and then I can commit a new version.

@sayakpaul (Member)

@lawrence-cj let's keep max_sequence_length as a pipeline call argument and default it to 120. WDYT? @yiyixuxu do we have your go-ahead here?
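For illustration only, a sketch of the agreed-upon call-level API (reusing pipe and prompt from the example above):

# max_sequence_length defaults to 120 but can be raised per generation for longer prompts.
image = pipe(prompt, max_sequence_length=200).images[0]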

@@ -688,6 +725,7 @@ def __call__(
callback_steps: int = 1,
clean_caption: bool = True,
use_resolution_binning: bool = True,
model_token_max_length: int = 120,
Member

Suggested change
model_token_max_length: int = 120,
max_sequence_length: int = 120,

Let's use this variable name throughout?

Can we also add this argument to the call docstrings?

@sayakpaul (Member) left a comment

Looking very nice. Just one comment and then I think we can merge it!

@lawrence-cj (Contributor, Author)

Done on my side. Thanks so much. @sayakpaul @yiyixuxu

@sayakpaul (Member)

Everything looks good. We just need to add max_sequence_length to the pipeline docstrings here:

use_resolution_binning (`bool` defaults to `True`):
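An entry in the same style (illustrative wording, not necessarily the exact text that was merged) could read:

max_sequence_length (`int` defaults to `120`):
    Maximum sequence length to use with the `prompt`.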

@sayakpaul merged commit f55873b into huggingface:main on Mar 3, 2024
15 checks passed
@sayakpaul (Member)

Thanks a lot @lawrence-cj for your contributions here!

Successfully merging this pull request may close these issues.

PixArt-XL-2-256x256 generations are messed up