-
Notifications
You must be signed in to change notification settings - Fork 27.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Performance 1/6] use_checkpoint = False #15803
Conversation
def BasicTransformerBlock_forward(self, x, context=None): | ||
return checkpoint(self._forward, x, context) | ||
return checkpoint(self._forward, x, context, flag=False) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The checkpoint
here is torch.utils.checkpoint.checkpoint
, and it does not have flag=False
. I think you confused this with ldm.modules.diffusionmodules.util.checkpoint
. The sd_hijack_checkpoint.py
already removed the checkpointing in ldm, but we might need to do it on sgm as well.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As far as I understand, what you are looking for is actually
ldm.modules.attention.BasicTransformerBlock.forward = ldm.modules.attention.BasicTransformerBlock._forward
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A closer look indicates that the checkpoint here is only called when training occurs (textual_inversion & hypernetwork) so disabling checkpoint here may be undesirable.
Description
According to lllyasviel/stable-diffusion-webui-forge#716 (comment) ,
calls to
parameters
in checkpoint function is a significant overhead in A1111. However, checkpoint function is mainly used for training, disabling it does not affect inference at all.This PR disables checkpoint in A1111 in exchange for performance improvement. This reduces about 100ms/it on my local setup (4090). The duration/it before patch is ~580ms/it.
Screenshots/videos:
Checklist: