[LoRA] use the PyTorch classes wherever needed and start deprecation cycles #7204

Merged: 39 commits merged into main from remove-linear_cls on Mar 13, 2024

Changes shown below are from the first 32 of the 39 commits.

Commits
500051f  fix PyTorch classes and start deprecsation cycles. (sayakpaul, Mar 4, 2024)
8476239  Merge branch 'main' into remove-linear_cls (sayakpaul, Mar 4, 2024)
9061ebb  remove args crafting for accommodating scale. (sayakpaul, Mar 4, 2024)
76b7cbd  remove scale check in feedforward. (sayakpaul, Mar 4, 2024)
93b5106  assert against nn.Linear and not CompatibleLinear. (sayakpaul, Mar 4, 2024)
d1f3d3c  Merge branch 'main' into remove-linear_cls (sayakpaul, Mar 5, 2024)
8095317  remove conv_cls and lineaR_cls. (sayakpaul, Mar 5, 2024)
7e96549  remove scale (sayakpaul, Mar 5, 2024)
59181e4  👋 scale. (sayakpaul, Mar 5, 2024)
1413847  fix: unet2dcondition (sayakpaul, Mar 5, 2024)
0556738  fix attention.py (sayakpaul, Mar 5, 2024)
f56f9f3  fix: attention.py again (sayakpaul, Mar 5, 2024)
9997189  fix: unet_2d_blocks. (sayakpaul, Mar 5, 2024)
4f66db0  fix-copies. (sayakpaul, Mar 5, 2024)
fd348d1  more fixes. (sayakpaul, Mar 5, 2024)
10c4232  fix: resnet.py (sayakpaul, Mar 5, 2024)
6fe19d9  more fixes (sayakpaul, Mar 5, 2024)
c256b77  fix i2vgenxl unet. (sayakpaul, Mar 5, 2024)
45030fb  depcrecate scale gently. (sayakpaul, Mar 5, 2024)
6b5212b  fix-copies (sayakpaul, Mar 5, 2024)
b487395  Merge branch 'main' into remove-linear_cls (sayakpaul, Mar 7, 2024)
ee645c7  Apply suggestions from code review (sayakpaul, Mar 8, 2024)
bfdfc20  quality (sayakpaul, Mar 8, 2024)
d4fa31d  throw warning when scale is passed to the the BasicTransformerBlock c… (sayakpaul, Mar 8, 2024)
15ef1e7  remove scale from signature. (sayakpaul, Mar 8, 2024)
d0375fa  cross_attention_kwargs, very nice catch by Yiyi (sayakpaul, Mar 8, 2024)
99e557c  Merge branch 'main' into remove-linear_cls (sayakpaul, Mar 8, 2024)
8f76caa  fix: logger.warn (sayakpaul, Mar 8, 2024)
5d2dd72  Merge branch 'main' into remove-linear_cls (sayakpaul, Mar 8, 2024)
5956b54  Merge branch 'main' into remove-linear_cls (sayakpaul, Mar 11, 2024)
402ec90  make deprecation message clearer. (sayakpaul, Mar 11, 2024)
de01273  address final comments. (sayakpaul, Mar 11, 2024)
f56254e  maintain same depcrecation message and also add it to activations. (sayakpaul, Mar 12, 2024)
0ffc8e5  address yiyi (sayakpaul, Mar 12, 2024)
69bbe93  fix copies (sayakpaul, Mar 12, 2024)
5cd5a28  Merge branch 'main' into remove-linear_cls (sayakpaul, Mar 13, 2024)
416b412  Apply suggestions from code review (sayakpaul, Mar 13, 2024)
83181b4  more depcrecation (sayakpaul, Mar 13, 2024)
938c9b0  fix-copies (sayakpaul, Mar 13, 2024)
12 changes: 3 additions & 9 deletions src/diffusers/models/activations.py
@@ -17,9 +17,6 @@
 import torch.nn.functional as F
 from torch import nn
 
-from ..utils import USE_PEFT_BACKEND
-from .lora import LoRACompatibleLinear
-
 
 ACTIVATION_FUNCTIONS = {
     "swish": nn.SiLU(),
@@ -87,19 +84,16 @@ class GEGLU(nn.Module):
 
     def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
         super().__init__()
-        linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
-
-        self.proj = linear_cls(dim_in, dim_out * 2, bias=bias)
+        self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
 
     def gelu(self, gate: torch.Tensor) -> torch.Tensor:
         if gate.device.type != "mps":
             return F.gelu(gate)
         # mps: gelu is not implemented for float16
         return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
 
-    def forward(self, hidden_states, scale: float = 1.0):
-        args = () if USE_PEFT_BACKEND else (scale,)
-        hidden_states, gate = self.proj(hidden_states, *args).chunk(2, dim=-1)
+    def forward(self, hidden_states):
+        hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
         return hidden_states * self.gelu(gate)
 
 
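
As a quick check of the slimmed-down `GEGLU` above: the forward pass is now a plain `nn.Linear` projection followed by the gated split, and no `scale` keyword is accepted. A usage sketch, assuming a diffusers build that already contains this change:

```python
import torch
from diffusers.models.activations import GEGLU  # module path as in the diff above

geglu = GEGLU(dim_in=64, dim_out=256)
hidden_states = torch.randn(2, 16, 64)

# proj maps 64 -> 512; the chunk splits it into value and gate halves of 256 each.
out = geglu(hidden_states)
print(out.shape)  # torch.Size([2, 16, 256])
```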
53 changes: 20 additions & 33 deletions src/diffusers/models/attention.py
@@ -17,37 +17,29 @@
 import torch.nn.functional as F
 from torch import nn
 
-from ..utils import USE_PEFT_BACKEND
+from ..utils import logging
 from ..utils.torch_utils import maybe_allow_in_graph
 from .activations import GEGLU, GELU, ApproximateGELU
 from .attention_processor import Attention
 from .embeddings import SinusoidalPositionalEmbedding
-from .lora import LoRACompatibleLinear
 from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm
 
 
-def _chunked_feed_forward(
-    ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int, lora_scale: Optional[float] = None
-):
+logger = logging.get_logger(__name__)
+
+
+def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):
     # "feed_forward_chunk_size" can be used to save memory
     if hidden_states.shape[chunk_dim] % chunk_size != 0:
         raise ValueError(
             f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
         )
 
     num_chunks = hidden_states.shape[chunk_dim] // chunk_size
-    if lora_scale is None:
-        ff_output = torch.cat(
-            [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
-            dim=chunk_dim,
-        )
-    else:
-        # TOOD(Patrick): LoRA scale can be removed once PEFT refactor is complete
-        ff_output = torch.cat(
-            [ff(hid_slice, scale=lora_scale) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
-            dim=chunk_dim,
-        )
-
+    ff_output = torch.cat(
+        [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
+        dim=chunk_dim,
+    )
     return ff_output
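
With the `lora_scale` branch gone, `_chunked_feed_forward` reduces to a single `torch.cat` over per-chunk calls. The equivalence it relies on can be checked in isolation with any token-wise module; the small MLP below is a stand-in, not the real `FeedForward`:

```python
import torch
from torch import nn

torch.manual_seed(0)
ff = nn.Sequential(nn.Linear(32, 128), nn.GELU(), nn.Linear(128, 32))
hidden_states = torch.randn(2, 24, 32)  # (batch, seq_len, dim)

chunk_dim, chunk_size = 1, 8
num_chunks = hidden_states.shape[chunk_dim] // chunk_size

# Because the feed-forward acts on each token independently, chunking along the
# sequence dimension changes peak memory, not the result.
chunked = torch.cat(
    [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
    dim=chunk_dim,
)
print(torch.allclose(chunked, ff(hidden_states), atol=1e-5))  # True
```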


@@ -299,6 +291,10 @@ def forward(
         class_labels: Optional[torch.LongTensor] = None,
         added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
     ) -> torch.FloatTensor:
+        if cross_attention_kwargs is not None:
+            if cross_attention_kwargs.get("scale", None) is not None:
+                logger.warn("Passing `scale` to `cross_attention_kwargs` is deprecated.")
+
         # Notice that normalization is always applied before the real computation in the following blocks.
         # 0. Self-Attention
         batch_size = hidden_states.shape[0]
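
The new guard only warns; the `scale` entry is otherwise left alone, since LoRA scaling now happens inside the PEFT layers rather than in the block's forward pass. Detached from the block, the behavior amounts to something like the following (hypothetical helper name, standard `logging` in place of the diffusers logger):

```python
import logging
from typing import Any, Dict, Optional

logger = logging.getLogger(__name__)


def warn_on_deprecated_scale(cross_attention_kwargs: Optional[Dict[str, Any]]) -> None:
    # Mirrors the check in the diff: emit a warning when a `scale` key is present,
    # then carry on; the value no longer influences the block itself.
    if cross_attention_kwargs is not None and cross_attention_kwargs.get("scale", None) is not None:
        logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated and will be ignored.")


warn_on_deprecated_scale({"scale": 0.5})    # logs a warning
warn_on_deprecated_scale({"gligen": None})  # silent
warn_on_deprecated_scale(None)              # silent
```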
@@ -326,10 +322,7 @@ def forward(
         if self.pos_embed is not None:
             norm_hidden_states = self.pos_embed(norm_hidden_states)
 
-        # 1. Retrieve lora scale.
-        lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
-
-        # 2. Prepare GLIGEN inputs
+        # 1. Prepare GLIGEN inputs
         cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
         gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
 
@@ -348,7 +341,7 @@
         if hidden_states.ndim == 4:
             hidden_states = hidden_states.squeeze(1)
 
-        # 2.5 GLIGEN Control
+        # 1.2 GLIGEN Control
         if gligen_kwargs is not None:
             hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
 
@@ -394,11 +387,9 @@ def forward(
 
         if self._chunk_size is not None:
             # "feed_forward_chunk_size" can be used to save memory
-            ff_output = _chunked_feed_forward(
-                self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size, lora_scale=lora_scale
-            )
+            ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
         else:
-            ff_output = self.ff(norm_hidden_states, scale=lora_scale)
+            ff_output = self.ff(norm_hidden_states)
 
         if self.norm_type == "ada_norm_zero":
             ff_output = gate_mlp.unsqueeze(1) * ff_output
@@ -643,7 +634,7 @@ def __init__(
         if inner_dim is None:
             inner_dim = int(dim * mult)
         dim_out = dim_out if dim_out is not None else dim
-        linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
+        linear_cls = nn.Linear
 
         if activation_fn == "gelu":
             act_fn = GELU(dim, inner_dim, bias=bias)
@@ -665,11 +656,7 @@
         if final_dropout:
             self.net.append(nn.Dropout(dropout))
 
-    def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
-        compatible_cls = (GEGLU,) if USE_PEFT_BACKEND else (GEGLU, LoRACompatibleLinear)
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         for module in self.net:
-            if isinstance(module, compatible_cls):
-                hidden_states = module(hidden_states, scale)
-            else:
-                hidden_states = module(hidden_states)
+            hidden_states = module(hidden_states)
         return hidden_states
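
For completeness, `FeedForward.forward` is now a straight pass through `self.net`. A usage sketch against the public class, assuming a diffusers build that includes this change (the constructor arguments shown are pre-existing and unchanged by this PR):

```python
import torch
from diffusers.models.attention import FeedForward  # module path as in the diff above

ff = FeedForward(dim=64, mult=4, activation_fn="geglu")
hidden_states = torch.randn(2, 16, 64)

# No `scale` argument anymore; every sub-module is called with the hidden states only.
out = ff(hidden_states)
print(out.shape)  # torch.Size([2, 16, 64])
```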