Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

xformers attention #1851

Merged
merged 21 commits into from
Oct 8, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions modules/sd_hijack.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,17 @@


def apply_optimizations():
ldm.modules.diffusionmodules.model.nonlinearity = silu

if cmd_opts.opt_split_attention_v1:
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
if cmd_opts.opt_split_attention:
C43H66N12O12S2 marked this conversation as resolved.
Show resolved Hide resolved
ldm.modules.attention_CrossAttention_forward = sd_hijack_optimizations.split_cross_attention_forward
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you naughty boy

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wtf

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have 0 recollection of this

ldm.modules.diffusionmodules.model.nonlinearity = sd_hijack_optimizations.nonlinearity_hijack
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward
elif not cmd_opts.disable_opt_xformers_attention:
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
ldm.modules.attention.CrossAttention._maybe_init = sd_hijack_optimizations._maybe_init
ldm.modules.attention.CrossAttention.attention_op = None
ldm.modules.diffusionmodules.model.nonlinearity = sd_hijack_optimizations.nonlinearity_hijack
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward


Expand Down
39 changes: 38 additions & 1 deletion modules/sd_hijack_optimizations.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import math
import torch
from torch import einsum

import xformers.ops
import functorch
xformers._is_functorch_available=True
from ldm.util import default
from einops import rearrange

Expand Down Expand Up @@ -92,6 +94,41 @@ def split_cross_attention_forward(self, x, context=None, mask=None):

return self.to_out(r2)

def _maybe_init(self, x):
"""
Initialize the attention operator, if required We expect the head dimension to be exposed here, meaning that x
: B, Head, Length
"""
if self.attention_op is not None:
return
_, M, K = x.shape
try:
self.attention_op = xformers.ops.AttentionOpDispatch(
dtype=x.dtype,
device=x.device,
k=K,
attn_bias_type=type(None),
has_dropout=False,
kv_len=M,
q_len=M,
).op
except NotImplementedError as err:
raise NotImplementedError(f"Please install xformers with the flash attention / cutlass components.\n{err}")

def xformers_attention_forward(self, x, context=None, mask=None):
h = self.heads
q_in = self.to_q(x)
context = default(context, x)
k_in = self.to_k(context)
v_in = self.to_v(context)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
C43H66N12O12S2 marked this conversation as resolved.
Show resolved Hide resolved
del q_in, k_in, v_in
self._maybe_init(q)
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
C43H66N12O12S2 marked this conversation as resolved.
Show resolved Hide resolved

out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
return self.to_out(out)

def cross_attention_attnblock_forward(self, x):
h_ = x
h_ = self.norm(h_)
Expand Down
1 change: 1 addition & 0 deletions modules/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(models_path, 'ScuNET'))
parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(models_path, 'SwinIR'))
parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(models_path, 'LDSR'))
parser.add_argument("--disable-opt-xformers-attention", action='store_true', help="force-disables xformers attention optimization")
parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.")
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
Expand Down
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,5 @@ resize-right
torchdiffeq
kornia
lark
functorch
#xformers?