refactor: move model helper function in pipeline to a mixin class #6571
Conversation
I am not comfortable changing a legacy structure here. So, I will let @patrickvonplaten comment.
@@ -2104,3 +2106,123 @@ def set_attention_slice(self, slice_size: Optional[int]):

        for module in modules:
            module.set_attention_slice(slice_size)


class EfficiencyMixin:
Can we maybe rename this:
-class EfficiencyMixin:
+class StableDiffusionMixin:
But it seems like it doesn't just affect the Stable Diffusion family. If a similar mixin class were to be designed, what would be the process? It might be better to have a base EfficiencyMixin class and then subclass it to write more pipeline-specific classes such as StableDiffusionEfficiencyMixin.

I don't like the StableDiffusionMixin name -- it's uninformative and confusing in light of DiffusionPipeline.
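A minimal sketch of that layered idea (all class names here are hypothetical, not part of the PR):

```python
class EfficiencyMixin:
    """Generic helpers for any pipeline that exposes `self.vae` and `self.unet`."""
    ...


class StableDiffusionEfficiencyMixin(EfficiencyMixin):
    """Home for helpers that really are specific to the Stable Diffusion family."""
    ...
```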
I think this is actually a very nice change! Could we maybe rename the mixin to StableDiffusionMixin, as all functions that are moved into the mixin are only applicable to stable-diffusion-like models?

I don't know if ...
        cpu_offload(self.safety_checker.vision_model, device)

    @property
    def _execution_device(self):
Would leave this as is in the Pipeline for now and not add to the Mixin.
I apologize, some of those changes should be split out into their own commit (like: remove unnecessary overload ...), but as they have been inherited from DiffusionPipeline, those functions enable_sequential_cpu_offload, _execution_device, and enable_attention_slicing could be removed without changing any behaviour?
The name is acceptable, and let's also decide which pipelines to add this mixin to. BTW, as those functions are only related to the VAE and UNet models, this mixin class is not limited to Stable Diffusion pipelines, as long as the pipeline contains a VAE and a UNet. Splitting it further into a VAEMixin and a UNetMixin seems too tedious, though it would make them available to more pipelines like PixArt-alpha.

That's why I'm also thinking about some alternative designs, but for now none of them seems perfect. For example, a more direct (but not that friendly) way is to deprecate those functions and encourage users to call them directly through the models lol. BTW, we could just add them to DiffusionPipeline if all pipelines had UNets or VAEs, but unfortunately they don't. That's why I propose this mixin class. Again, looking forward to a better design!
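For illustration, the "call them directly through the models" alternative would look roughly like this; the model id is just an example, and the component-level methods shown already exist on AutoencoderKL and UNet2DConditionModel:

```python
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Instead of pipeline-level helpers, users would call the components directly:
pipe.vae.enable_slicing()                               # ~ pipe.enable_vae_slicing()
pipe.vae.enable_tiling()                                # ~ pipe.enable_vae_tiling()
pipe.unet.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)  # ~ pipe.enable_freeu(...)
pipe.unet.fuse_qkv_projections()                        # ~ pipe.fuse_qkv_projections()
```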
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Looks nice!
I have a major design question but apart from that, things look good!
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
I think we should only add this mixin to pipelines that support ALL four methods (freeu, vae_slicing, vae_tiling, fuse_qkv_projections), no? Not sure if that is the case here. I think maybe some of the pipelines only support some of the methods, e.g. AudioLDM.

We should add a TesterMixin that runs fast tests for all pipelines inheriting from this mixin for all four methods; it is also the easiest way to find out if they all support these methods, similar to what's done here: #6862
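A rough sketch of what such a tester mixin could look like (the class name and the get_pipeline() helper are hypothetical; a real version would follow the conventions of the existing pipeline test mixins):

```python
class EfficiencyMethodsTesterMixin:
    # Assumes the concrete test class provides self.get_pipeline() and, as the
    # existing pipeline tests do, inherits from unittest.TestCase.

    def test_freeu_toggle(self):
        pipe = self.get_pipeline()
        pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
        pipe.disable_freeu()

    def test_vae_slicing_toggle(self):
        pipe = self.get_pipeline()
        pipe.enable_vae_slicing()
        pipe.disable_vae_slicing()

    def test_vae_tiling_toggle(self):
        pipe = self.get_pipeline()
        pipe.enable_vae_tiling()
        pipe.disable_vae_tiling()

    def test_fuse_qkv_projections_toggle(self):
        pipe = self.get_pipeline()
        pipe.fuse_qkv_projections()
        pipe.unfuse_qkv_projections()
```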
Gentle ping @ultranity
Thanks!! First, I'm still open to a better mixin name :)

Technically, any pipeline with a UNet2DConditionModel UNet component will support freeu and fuse_qkv_projections, and any pipeline with an AutoencoderKL VAE component will support vae_slicing, vae_tiling, and fuse_qkv_projections, which is basically the reason why this PR exists.

I will check if this TesterMixin would help.
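To make that concrete: the mixin methods are thin wrappers that just forward to those components, so any pipeline holding a compatible UNet and VAE gets them for free. A simplified sketch, not the exact PR code:

```python
class EfficiencyMixin:  # name still under discussion in this thread
    def enable_vae_slicing(self):
        self.vae.enable_slicing()

    def disable_vae_slicing(self):
        self.vae.disable_slicing()

    def enable_vae_tiling(self):
        self.vae.enable_tiling()

    def disable_vae_tiling(self):
        self.vae.disable_tiling()

    def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
        self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)

    def disable_freeu(self):
        self.unet.disable_freeu()

    def fuse_qkv_projections(self, unet: bool = True, vae: bool = True):
        if unet:
            self.unet.fuse_qkv_projections()
        if vae:
            self.vae.fuse_qkv_projections()
```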
Force-pushed from fb94f42 to fc71e97 (compare)
Rebased this PR on the current main branch (sorry for #6871) and added an initial ...
@ultranity
b3c3de0 fixed a bug in fuse_projections to: ...

Though we probably should make this fix in a separate PR, it is causing the tests for text_to_video and i2vgen_xl to fail, so I fixed it here. @sayakpaul
@yiyixuxu @sayakpaul any updates?
Left a few nits, thanks a ton for working on this!

Can merge once these final comments are addressed and @sayakpaul confirms the change we made to the Attention class is OK.
@@ -116,6 +116,8 @@ def __init__(
        super().__init__()
        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
        self.query_dim = query_dim
        self.use_bias = bias
        self.is_cross_attention = cross_attention_dim is not None
cc @sayakpaul
the change here looks good to me - can you take a look and confirm if it's ok?
-        self.to_qkv = self.linear_cls(in_features, out_features, bias=False, device=device, dtype=dtype)
+        self.to_qkv = self.linear_cls(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
I think it's safe to always set it to False because attention layers don't use bias.
Then why does the Attention block need a bias param at init? And it actually makes tests fail in some cases where attention_bias==True.
You mean fused projections make some of the tests fail when attention bias is true?
Yes, the results will no longer be the same when bias is enabled in the original Attention but the fused projections do not have a bias.
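For context, a simplified sketch of the self-attention fusion path and why the bias has to be carried over (not the exact diffusers code):

```python
import torch
from torch import nn


def fuse_qkv(to_q: nn.Linear, to_k: nn.Linear, to_v: nn.Linear, use_bias: bool) -> nn.Linear:
    # Stack the three projection weights into one (3 * inner_dim, query_dim) matrix.
    weights = torch.cat([to_q.weight.data, to_k.weight.data, to_v.weight.data])
    fused = nn.Linear(
        weights.shape[1], weights.shape[0], bias=use_bias,
        device=weights.device, dtype=weights.dtype,
    )
    fused.weight.data.copy_(weights)
    if use_bias:
        # Skipping this step is what makes the fused outputs diverge whenever the
        # original projections were created with attention_bias=True.
        fused.bias.data.copy_(torch.cat([to_q.bias.data, to_k.bias.data, to_v.bias.data]))
    return fused
```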
As mentioned by @yiyixuxu as well, if we decide not to add projection fusion to the other models, then we don't have this case anymore. Am I right?
@@ -503,6 +504,44 @@ def disable_freeu(self):
            if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
                setattr(upsample_block, k, None)

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
    def fuse_qkv_projections(self):
I think we don't have a use case for projection fusion for the 3D UNet yet. So, it's better not to bloat the codebase here.
@@ -474,6 +474,44 @@ def disable_freeu(self):
            if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
                setattr(upsample_block, k, None)

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
Same, let's not add it here since we don't have the use case yet.
inputs["return_dict"] = False | ||
output_2 = pipe(**inputs) | ||
|
||
assert np.abs(output_2[0].flatten() - output_1[0].flatten()).max() < 3e-3 |
Here, when I was testing the I2VGen pipeline, I found that this pipeline's results are not that reproducible: even with the same inputs (both from inputs = self.get_dummy_inputs(device)) and with any changes like pipe.enable_vae_slicing() disabled by commenting them out, the result still failed to pass the check. But I have not figured out why. @yiyixuxu any idea?
@@ -701,6 +702,44 @@ def disable_freeu(self) -> None:
            if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
                setattr(upsample_block, k, None)

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
Same.
@@ -697,27 +699,32 @@ def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:

    @torch.no_grad()
    def fuse_projections(self, fuse=True):
        is_cross_attention = self.cross_attention_dim != self.query_dim
Why can't we keep it as is?
Why do we need to leave it ambiguous when we could actually check whether an Attention block is_cross_attention during init?

It's not true that all cross-attention blocks have to meet the self.cross_attention_dim != self.query_dim constraint. For example, in the I2VGen case, note that some cross-attention blocks have self.cross_attention_dim == self.query_dim.
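A simplified sketch of the change being argued for here; the real Attention class takes many more arguments:

```python
from typing import Optional

import torch


class Attention(torch.nn.Module):
    def __init__(self, query_dim: int, cross_attention_dim: Optional[int] = None, bias: bool = False):
        super().__init__()
        self.query_dim = query_dim
        self.use_bias = bias
        # Record cross-attention explicitly at construction time instead of inferring it
        # later from `cross_attention_dim != query_dim`, which misclassifies cross-attention
        # blocks whose cross_attention_dim happens to equal query_dim.
        self.is_cross_attention = cross_attention_dim is not None
        self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim

    @torch.no_grad()
    def fuse_projections(self, fuse: bool = True):
        if self.is_cross_attention:
            ...  # fuse to_k / to_v into a single to_kv projection
        else:
            ...  # fuse to_q / to_k / to_v into a single to_qkv projection
```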
This is actually a good change regardless of whether we want to support I2VGen or not.
Left some comments.

My only concern is that we don't have to support projection fusion for all the UNets right now, as we don't know if that's going to improve the end performance with ...
Can we confirm the classes we should add this mixin to? vae: AutoencoderKL, AutoencoderTiny
I also think we don't need to support these; we are not currently supporting ...
I think we can, but we need to make sure that we are not allowing any unsupported method.
self.assertTrue(hasattr(pipe, "vae") and isinstance(pipe.vae, (AutoencoderKL, AutoencoderTiny))) | ||
self.assertTrue( | ||
hasattr(pipe, "unet") | ||
and isinstance(pipe.unet, (UNet2DConditionModel, UNet3DConditionModel, I2VGenXLUNet, UNetMotionModel)) |
@sayakpaul we decide which Unet we support here
-            and isinstance(pipe.unet, (UNet2DConditionModel, UNet3DConditionModel, I2VGenXLUNet, UNetMotionModel))
+            and isinstance(pipe.unet, UNet2DConditionModel)
That sounds neat!
@ultranity vae: AutoencoderKL, AutoencoderTiny. Sorry about the back-and-forth 🥺
As we have passed all tests for UNet3DConditionModel, I2VGenXLUNet, and UNetMotionModel, I hope we can keep them for now, for simplicity; otherwise we would need to disable fuse_qkv_projections on those UNet pipelines by checking the UNet class. We may reduce the bloated code by introducing some interface or plugin mechanics in a future PR.
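One possible shape for such a guard, if we did want to gate the method per UNet class (purely illustrative, not part of this PR):

```python
from diffusers import UNet2DConditionModel


def maybe_fuse_qkv(pipe):
    # Only fuse when the UNet class is one we have verified; otherwise skip quietly.
    if isinstance(pipe.unet, UNet2DConditionModel):
        pipe.fuse_qkv_projections()
    else:
        print(f"Skipping QKV fusion: {type(pipe.unet).__name__} is not covered yet.")
```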
@ultranity
What does this PR do?

Separate the following functions into a mixin class.

Those functions are only related to the VAE and UNet models, not the pipeline implementation, so let us move them to a mixin class instead of copying them everywhere (just like what has been done for xformers_memory_efficient_attention or attention_slice).

Alternative design: shared helper functions (_pipeline_helper_functions). A structural sketch of the mixin approach follows below.
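A structural sketch of the mixin approach (the pipeline class name is illustrative, and the mixin's import path is assumed to be where the current diff places it):

```python
# Import path for the mixin assumes it lands next to DiffusionPipeline in
# pipeline_utils.py, as in the current diff.
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, EfficiencyMixin


class MyUNetVAEPipeline(EfficiencyMixin, DiffusionPipeline):
    # enable_vae_slicing/tiling, enable_freeu and fuse_qkv_projections are
    # inherited from the mixin instead of being copy-pasted into every pipeline.
    ...
```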
Before submitting

Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@patrickvonplaten and @sayakpaul