Skip to content

Commit

Permalink
Merge branch 'stable-cascade' of https://github.com/kohya-ss/sd-scripts
Browse files Browse the repository at this point in the history
… into stable-cascade
  • Loading branch information
bmaltais committed Feb 24, 2024
2 parents 9564145 + 13f49d1 commit 4214874
Show file tree
Hide file tree
Showing 7 changed files with 208 additions and 9 deletions.
4 changes: 2 additions & 2 deletions gen_img.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import numpy as np
import torch

from library.ipex_interop import init_ipex
from library.device_utils import init_ipex, clean_memory, get_preferred_device

init_ipex()

Expand Down Expand Up @@ -338,7 +338,7 @@ def __init__(
self.clip_vision_model: CLIPVisionModelWithProjection = None
self.clip_vision_processor: CLIPImageProcessor = None
self.clip_vision_strength = 0.0

# Textual Inversion
self.token_replacements_list = []
for _ in range(len(self.text_encoders)):
Expand Down
132 changes: 131 additions & 1 deletion library/stable_cascade.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import math
from types import SimpleNamespace
from typing import List, Optional
from einops import rearrange
import numpy as np
import torch
import torch.nn as nn
Expand Down Expand Up @@ -148,7 +149,7 @@ def encode(self, x):
The method to make it usable like VAE. It should be separated properly, but it is a temporary response.
"""
# latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")

# x is -1 to 1, so we need to convert it to 0 to 1, and then preprocess it with EfficientNet's preprocessing.
x = (x + 1) / 2
x = EFFNET_PREPROCESS(x)
Expand All @@ -172,6 +173,7 @@ def encode(self, x):
from torch.nn import Linear


r"""
class Attention2D(nn.Module):
def __init__(self, c, nhead, dropout=0.0):
super().__init__()
Expand All @@ -185,6 +187,119 @@ def forward(self, x, kv, self_attn=False):
x = self.attn(x, kv, kv, need_weights=False)[0]
x = x.permute(0, 2, 1).view(*orig_shape)
return x
"""


class Attention(nn.Module):
def __init__(self, c, nhead, dropout=0.0):
# dropout is for attn_output_weights, so we may not need it. however, if we use sdpa, we enable it.
# xformers and normal attn are not affected by dropout
super().__init__()

self.to_q = Linear(c, c, bias=True)
self.to_k = Linear(c, c, bias=True)
self.to_v = Linear(c, c, bias=True)
self.to_out = Linear(c, c, bias=True)
self.nhead = nhead
self.dropout = dropout
self.scale = (c // nhead) ** -0.5

# default is to use sdpa
self.use_memory_efficient_attention_xformers = False
self.use_sdpa = True

def set_use_xformers_or_sdpa(self, xformers, sdpa):
# print(f"Attention: set_use_xformers_or_sdpa: xformers={xformers}, sdpa={sdpa}")
self.use_memory_efficient_attention_xformers = xformers
self.use_sdpa = sdpa

def forward(self, q_in, k_in, v_in):
q_in = self.to_q(q_in)
k_in = self.to_k(k_in)
v_in = self.to_v(v_in)

if self.use_memory_efficient_attention_xformers:
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=self.nhead), (q_in, k_in, v_in))
del q_in, k_in, v_in
out = self.forward_memory_efficient_xformers(q, k, v)
del q, k, v
out = rearrange(out, "b n h d -> b n (h d)", h=self.nhead)
elif self.use_sdpa:
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.nhead), (q_in, k_in, v_in))
del q_in, k_in, v_in
out = self.forward_sdpa(q, k, v)
del q, k, v
out = rearrange(out, "b h n d -> b n (h d)", h=self.nhead)
else:
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=self.nhead), (q_in, k_in, v_in))
del q_in, k_in, v_in
out = self._attention(q, k, v)
del q, k, v
out = rearrange(out, "(b h) n d -> b n (h d)", h=self.nhead)

return self.to_out(out)

def _attention(self, query, key, value):
# if self.upcast_attention:
# query = query.float()
# key = key.float()

attention_scores = torch.baddbmm(
torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
query,
key.transpose(-1, -2),
beta=0,
alpha=self.scale,
)
attention_probs = attention_scores.softmax(dim=-1)

# cast back to the original dtype
attention_probs = attention_probs.to(value.dtype)

# compute attention output
hidden_states = torch.bmm(attention_probs, value)

return hidden_states

def forward_memory_efficient_xformers(self, q, k, v):
import xformers.ops

q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる
del q, k, v

return out

def forward_sdpa(self, q, k, v):
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=self.dropout, is_causal=False)
return out


class Attention2D(nn.Module):
r"""
to_q/k/v を個別に重みをもつように変更
modified to have separate weights for to_q/k/v
"""

def __init__(self, c, nhead, dropout=0.0):
super().__init__()
# self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True)
self.attn = Attention(c, nhead, dropout=dropout) # , bias=True, batch_first=True)

def forward(self, x, kv, self_attn=False):
orig_shape = x.shape
x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # Bx4xHxW -> Bx(HxW)x4
if self_attn:
kv = torch.cat([x, kv], dim=1)
# x = self.attn(x, kv, kv, need_weights=False)[0]
x = self.attn(x, kv, kv)
x = x.permute(0, 2, 1).view(*orig_shape)
return x

def set_use_xformers_or_sdpa(self, xformers, sdpa):
self.attn.set_use_xformers_or_sdpa(xformers, sdpa)


class LayerNorm2d(nn.LayerNorm):
Expand Down Expand Up @@ -262,6 +377,9 @@ def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0):
def set_gradient_checkpointing(self, value):
self.gradient_checkpointing = value

def set_use_xformers_or_sdpa(self, xformers, sdpa):
self.attention.set_use_xformers_or_sdpa(xformers, sdpa)

def forward_body(self, x, kv):
kv = self.kv_mapper(kv)
x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn)
Expand Down Expand Up @@ -657,6 +775,12 @@ def _init_weights(self, m):
if m.bias is not None:
nn.init.constant_(m.bias, 0)

def set_use_xformers_or_sdpa(self, xformers, sdpa):
for block in self.down_blocks + self.up_blocks:
for layer in block:
if hasattr(layer, "set_use_xformers_or_sdpa"):
layer.set_use_xformers_or_sdpa(xformers, sdpa)

def gen_r_embedding(self, r, max_positions=10000):
r = r * max_positions
half_dim = self.c_r // 2
Expand Down Expand Up @@ -920,6 +1044,12 @@ def set_gradient_checkpointing(self, value):
if hasattr(layer, "set_gradient_checkpointing"):
layer.set_gradient_checkpointing(value)

def set_use_xformers_or_sdpa(self, xformers, sdpa):
for block in self.down_blocks + self.up_blocks:
for layer in block:
if hasattr(layer, "set_use_xformers_or_sdpa"):
layer.set_use_xformers_or_sdpa(xformers, sdpa)

def gen_r_embedding(self, r, max_positions=10000):
r = r * max_positions
half_dim = self.c_r // 2
Expand Down
57 changes: 57 additions & 0 deletions library/stable_cascade_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,9 @@ def load_stage_c_model(stage_c_checkpoint_path, dtype=None, device="cpu") -> sc.
generator_c = sc.StageC()
logger.info(f"Loading Stage C generator from {stage_c_checkpoint_path}")
stage_c_checkpoint = load_file(stage_c_checkpoint_path)

stage_c_checkpoint = convert_state_dict_mha_to_normal_attn(stage_c_checkpoint)

logger.info(f"Loading state dict")
info = _load_state_dict_on_device(generator_c, stage_c_checkpoint, device, dtype=dtype)
logger.info(info)
Expand All @@ -115,6 +118,9 @@ def load_stage_b_model(stage_b_checkpoint_path, dtype=None, device="cpu") -> sc.
generator_b = sc.StageB()
logger.info(f"Loading Stage B generator from {stage_b_checkpoint_path}")
stage_b_checkpoint = load_file(stage_b_checkpoint_path)

stage_b_checkpoint = convert_state_dict_mha_to_normal_attn(stage_b_checkpoint)

logger.info(f"Loading state dict")
info = _load_state_dict_on_device(generator_b, stage_b_checkpoint, device, dtype=dtype)
logger.info(info)
Expand Down Expand Up @@ -189,6 +195,55 @@ def load_previewer_model(previewer_checkpoint_path, dtype=None, device="cpu") ->
return previewer


def convert_state_dict_mha_to_normal_attn(state_dict):
# convert nn.MultiheadAttention to to_q/k/v and to_out
print("convert_state_dict_mha_to_normal_attn")
for key in list(state_dict.keys()):
if "attention.attn." in key:
if "in_proj_bias" in key:
value = state_dict.pop(key)
qkv = torch.chunk(value, 3, dim=0)
state_dict[key.replace("in_proj_bias", "to_q.bias")] = qkv[0]
state_dict[key.replace("in_proj_bias", "to_k.bias")] = qkv[1]
state_dict[key.replace("in_proj_bias", "to_v.bias")] = qkv[2]
elif "in_proj_weight" in key:
value = state_dict.pop(key)
qkv = torch.chunk(value, 3, dim=0)
state_dict[key.replace("in_proj_weight", "to_q.weight")] = qkv[0]
state_dict[key.replace("in_proj_weight", "to_k.weight")] = qkv[1]
state_dict[key.replace("in_proj_weight", "to_v.weight")] = qkv[2]
elif "out_proj.bias" in key:
value = state_dict.pop(key)
state_dict[key.replace("out_proj.bias", "to_out.bias")] = value
elif "out_proj.weight" in key:
value = state_dict.pop(key)
state_dict[key.replace("out_proj.weight", "to_out.weight")] = value
return state_dict


def convert_state_dict_normal_attn_to_mha(state_dict):
# convert to_q/k/v and to_out to nn.MultiheadAttention
for key in list(state_dict.keys()):
if "attention.attn." in key:
if "to_q.bias" in key:
q = state_dict.pop(key)
k = state_dict.pop(key.replace("to_q.bias", "to_k.bias"))
v = state_dict.pop(key.replace("to_q.bias", "to_v.bias"))
state_dict[key.replace("to_q.bias", "in_proj_bias")] = torch.cat([q, k, v])
elif "to_q.weight" in key:
q = state_dict.pop(key)
k = state_dict.pop(key.replace("to_q.weight", "to_k.weight"))
v = state_dict.pop(key.replace("to_q.weight", "to_v.weight"))
state_dict[key.replace("to_q.weight", "in_proj_weight")] = torch.cat([q, k, v])
elif "to_out.bias" in key:
v = state_dict.pop(key)
state_dict[key.replace("to_out.bias", "out_proj.bias")] = v
elif "to_out.weight" in key:
v = state_dict.pop(key)
state_dict[key.replace("to_out.weight", "out_proj.weight")] = v
return state_dict


def get_sai_model_spec(args, lora=False):
timestamp = time.time()

Expand Down Expand Up @@ -230,6 +285,8 @@ def stage_c_saver_common(ckpt_file, stage_c, text_model, save_dtype, sai_metadat
if save_dtype is not None:
state_dict = {k: v.to(save_dtype) for k, v in state_dict.items()}

state_dict = convert_state_dict_normal_attn_to_mha(state_dict)

save_file(state_dict, ckpt_file, metadata=sai_metadata)

# save text model
Expand Down
2 changes: 1 addition & 1 deletion networks/check_lora_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def main(file):

for key, value in values:
value = value.to(torch.float32)
logger.info(f"{key},{str(tuple(value.size())).replace(', ', '-')},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}")
print(f"{key},{str(tuple(value.size())).replace(', ', '-')},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}")


def setup_parser() -> argparse.ArgumentParser:
Expand Down
8 changes: 8 additions & 0 deletions stable_cascade_gen_img.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,15 @@ def main(args):

generator_c = sc_utils.load_stage_c_model(args.stage_c_checkpoint_path, dtype=dtype, device=loading_device)
generator_c.eval().requires_grad_(False).to(loading_device)
# if args.xformers or args.sdpa:
print(f"Stage C: use_xformers_or_sdpa: {args.xformers} {args.sdpa}")
generator_c.set_use_xformers_or_sdpa(args.xformers, args.sdpa)

generator_b = sc_utils.load_stage_b_model(args.stage_b_checkpoint_path, dtype=dtype, device=loading_device)
generator_b.eval().requires_grad_(False).to(loading_device)
# if args.xformers or args.sdpa:
print(f"Stage B: use_xformers_or_sdpa: {args.xformers} {args.sdpa}")
generator_b.set_use_xformers_or_sdpa(args.xformers, args.sdpa)

# CLIP encoders
tokenizer = sc_utils.load_tokenizer(args)
Expand Down Expand Up @@ -332,6 +338,8 @@ def main(args):
sc_utils.add_text_model_arguments(parser)
parser.add_argument("--bf16", action="store_true")
parser.add_argument("--fp16", action="store_true")
parser.add_argument("--xformers", action="store_true")
parser.add_argument("--sdpa", action="store_true")
parser.add_argument("--outdir", type=str, default="../outputs", help="dir to write results to / 生成画像の出力先")
parser.add_argument("--lowvram", action="store_true", help="if specified, use low VRAM mode")
parser.add_argument(
Expand Down
10 changes: 5 additions & 5 deletions stable_cascade_train_c_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,10 +295,9 @@ def train(self, args):
# text_encoder is List[CLIPTextModel] or CLIPTextModel
text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder]

# # モデルに xformers とか memory efficient attention を組み込む
# モデルに xformers とか memory efficient attention を組み込む
# train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
# if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える
# vae.set_use_memory_efficient_attention_xformers(args.xformers)
stage_c.set_use_xformers_or_sdpa(args.xformers, args.sdpa)

# 差分追加学習のためにモデルを読み込む
sys.path.append(os.path.dirname(__file__))
Expand Down Expand Up @@ -730,8 +729,8 @@ def train(self, args):
metadata["ss_network_args"] = json.dumps(net_kwargs)

# model name and hash
if args.pretrained_model_name_or_path is not None:
sd_model_name = args.pretrained_model_name_or_path
if args.stage_c_checkpoint_path is not None:
sd_model_name = args.stage_c_checkpoint_path
if os.path.exists(sd_model_name):
metadata["ss_sd_model_hash"] = train_util.model_hash(sd_model_name)
metadata["ss_new_sd_model_hash"] = train_util.calculate_sha256(sd_model_name)
Expand Down Expand Up @@ -992,6 +991,7 @@ def setup_parser() -> argparse.ArgumentParser:
train_util.add_tokenizer_arguments(parser)
train_util.add_dataset_arguments(parser, True, True, True)
train_util.add_training_arguments(parser, True)
train_util.add_sd_saving_arguments(parser)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)
Expand Down
4 changes: 4 additions & 0 deletions stable_cascade_train_stage_c.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,9 @@ def train(args):
else:
previewer = None

# モデルに xformers とか memory efficient attention を組み込む
stage_c.set_use_xformers_or_sdpa(args.xformers, args.sdpa)

# 学習を準備する
if cache_latents:
effnet.to(accelerator.device, dtype=effnet_dtype)
Expand Down Expand Up @@ -531,6 +534,7 @@ def setup_parser() -> argparse.ArgumentParser:
train_util.add_tokenizer_arguments(parser)
train_util.add_dataset_arguments(parser, True, True, True)
train_util.add_training_arguments(parser, False)
train_util.add_sd_saving_arguments(parser)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
add_sdxl_training_arguments(parser) # cache text encoder outputs
Expand Down

0 comments on commit 4214874

Please sign in to comment.