add controlnet training #551

Merged · 10 commits · Jun 17, 2023
Changes from all commits
32 changes: 15 additions & 17 deletions fine_tune.py
@@ -5,13 +5,11 @@
import gc
import math
import os
import toml
from multiprocessing import Value

from tqdm import tqdm
import torch
from accelerate.utils import set_seed
import diffusers
from diffusers import DDPMScheduler

import library.train_util as train_util
@@ -139,11 +137,11 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

    # incorporate xformers or memory-efficient attention into the model
if args.diffusers_xformers:
print("Use xformers by Diffusers")
accelerator.print("Use xformers by Diffusers")
set_diffusers_xformers_flag(unet, True)
else:
        # the Windows build of xformers cannot train in float precision, so a configuration that disables xformers must remain available
print("Disable Diffusers' xformers")
accelerator.print("Disable Diffusers' xformers")
set_diffusers_xformers_flag(unet, False)
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)

@@ -168,7 +166,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
training_models.append(unet)

if args.train_text_encoder:
print("enable text encoder training")
accelerator.print("enable text encoder training")
if args.gradient_checkpointing:
text_encoder.gradient_checkpointing_enable()
training_models.append(text_encoder)
@@ -194,7 +192,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
params_to_optimize = params

    # prepare the classes needed for training
print("prepare optimizer, data loader etc.")
accelerator.print("prepare optimizer, data loader etc.")
_, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)

    # prepare the dataloader
@@ -214,7 +212,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
args.max_train_steps = args.max_train_epochs * math.ceil(
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
)
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")

    # also pass the number of training steps to the dataset side
train_dataset_group.set_max_train_steps(args.max_train_steps)
@@ -227,7 +225,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
assert (
args.mixed_precision == "fp16"
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
print("enable full fp16 training.")
accelerator.print("enable full fp16 training.")
unet.to(weight_dtype)
text_encoder.to(weight_dtype)

@@ -257,14 +255,14 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

    # start training
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
print("running training / 学習開始")
print(f" num examples / サンプル数: {train_dataset_group.num_train_images}")
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
print(f" num epochs / epoch数: {num_train_epochs}")
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
accelerator.print("running training / 学習開始")
accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}")
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
accelerator.print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
accelerator.print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")

progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
global_step = 0
@@ -278,7 +276,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name)

for epoch in range(num_train_epochs):
print(f"\nepoch {epoch+1}/{num_train_epochs}")
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1

for m in training_models:
227 changes: 227 additions & 0 deletions library/attention_processors.py
@@ -0,0 +1,227 @@
import math
from typing import Any
from einops import rearrange
import torch
from diffusers.models.attention_processor import Attention


# flash attention forwards and backwards

# https://arxiv.org/abs/2205.14135

EPSILON = 1e-6


class FlashAttentionFunction(torch.autograd.function.Function):
@staticmethod
@torch.no_grad()
def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
"""Algorithm 2 in the paper"""

device = q.device
dtype = q.dtype
max_neg_value = -torch.finfo(q.dtype).max
qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)

o = torch.zeros_like(q)
all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device)
all_row_maxes = torch.full(
(*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device
)

scale = q.shape[-1] ** -0.5

if mask is None:
mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
else:
mask = rearrange(mask, "b n -> b 1 1 n")
mask = mask.split(q_bucket_size, dim=-1)

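        # Process the queries in row blocks; each block carries its own output chunk,
        # running softmax denominator (row_sums) and running row max (row_maxes).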
row_splits = zip(
q.split(q_bucket_size, dim=-2),
o.split(q_bucket_size, dim=-2),
mask,
all_row_sums.split(q_bucket_size, dim=-2),
all_row_maxes.split(q_bucket_size, dim=-2),
)

for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
q_start_index = ind * q_bucket_size - qk_len_diff

col_splits = zip(
k.split(k_bucket_size, dim=-2),
v.split(k_bucket_size, dim=-2),
)

for k_ind, (kc, vc) in enumerate(col_splits):
k_start_index = k_ind * k_bucket_size

attn_weights = (
torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
)

if row_mask is not None:
attn_weights.masked_fill_(~row_mask, max_neg_value)

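                # For causal attention, mask out keys that lie in the future of this query block.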
if causal and q_start_index < (k_start_index + k_bucket_size - 1):
causal_mask = torch.ones(
(qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device
).triu(q_start_index - k_start_index + 1)
attn_weights.masked_fill_(causal_mask, max_neg_value)

block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
attn_weights -= block_row_maxes
exp_weights = torch.exp(attn_weights)

if row_mask is not None:
exp_weights.masked_fill_(~row_mask, 0.0)

block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(
min=EPSILON
)

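                # Online softmax update: fold this key block's statistics into the running
                # row max/sum and rescale the partial output so it stays consistent.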
new_row_maxes = torch.maximum(block_row_maxes, row_maxes)

exp_values = torch.einsum(
"... i j, ... j d -> ... i d", exp_weights, vc
)

exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)

new_row_sums = (
exp_row_max_diff * row_sums
+ exp_block_row_max_diff * block_row_sums
)

oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_(
(exp_block_row_max_diff / new_row_sums) * exp_values
)

row_maxes.copy_(new_row_maxes)
row_sums.copy_(new_row_sums)

ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)

return o

@staticmethod
@torch.no_grad()
def backward(ctx, do):
"""Algorithm 4 in the paper"""

causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
q, k, v, o, l, m = ctx.saved_tensors

device = q.device

max_neg_value = -torch.finfo(q.dtype).max
qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)

dq = torch.zeros_like(q)
dk = torch.zeros_like(k)
dv = torch.zeros_like(v)

row_splits = zip(
q.split(q_bucket_size, dim=-2),
o.split(q_bucket_size, dim=-2),
do.split(q_bucket_size, dim=-2),
mask,
l.split(q_bucket_size, dim=-2),
m.split(q_bucket_size, dim=-2),
dq.split(q_bucket_size, dim=-2),
)

for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits):
q_start_index = ind * q_bucket_size - qk_len_diff

col_splits = zip(
k.split(k_bucket_size, dim=-2),
v.split(k_bucket_size, dim=-2),
dk.split(k_bucket_size, dim=-2),
dv.split(k_bucket_size, dim=-2),
)

for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
k_start_index = k_ind * k_bucket_size

attn_weights = (
torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
)

if causal and q_start_index < (k_start_index + k_bucket_size - 1):
causal_mask = torch.ones(
(qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device
).triu(q_start_index - k_start_index + 1)
attn_weights.masked_fill_(causal_mask, max_neg_value)

exp_attn_weights = torch.exp(attn_weights - mc)

if row_mask is not None:
exp_attn_weights.masked_fill_(~row_mask, 0.0)

p = exp_attn_weights / lc

dv_chunk = torch.einsum("... i j, ... i d -> ... j d", p, doc)
dp = torch.einsum("... i d, ... j d -> ... i j", doc, vc)

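                # D = rowsum(dO * O); together with dp this yields the gradient of the
                # attention scores (Algorithm 4), with the 1/sqrt(d) scale folded in.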
D = (doc * oc).sum(dim=-1, keepdims=True)
ds = p * scale * (dp - D)

dq_chunk = torch.einsum("... i j, ... j d -> ... i d", ds, kc)
dk_chunk = torch.einsum("... i j, ... i d -> ... j d", ds, qc)

dqc.add_(dq_chunk)
dkc.add_(dk_chunk)
dvc.add_(dv_chunk)

return dq, dk, dv, None, None, None, None


class FlashAttnProcessor:
def __call__(
self,
attn: Attention,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
) -> Any:
q_bucket_size = 512
k_bucket_size = 1024

h = attn.heads
q = attn.to_q(hidden_states)

encoder_hidden_states = (
encoder_hidden_states
if encoder_hidden_states is not None
else hidden_states
)
encoder_hidden_states = encoder_hidden_states.to(hidden_states.dtype)

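        # If a hypernetwork is attached to this attention module, let it produce the
        # key/value context; otherwise use the encoder hidden states directly.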
if hasattr(attn, "hypernetwork") and attn.hypernetwork is not None:
context_k, context_v = attn.hypernetwork.forward(
hidden_states, encoder_hidden_states
)
context_k = context_k.to(hidden_states.dtype)
context_v = context_v.to(hidden_states.dtype)
else:
context_k = encoder_hidden_states
context_v = encoder_hidden_states

k = attn.to_k(context_k)
v = attn.to_v(context_v)
del encoder_hidden_states, hidden_states

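        # Split heads: (b, n, h*d) -> (b, h, n, d) so the bucketed kernel attends per head.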
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))

out = FlashAttentionFunction.apply(
q, k, v, attention_mask, False, q_bucket_size, k_bucket_size
)

out = rearrange(out, "b h n d -> b n (h d)")

out = attn.to_out[0](out)
out = attn.to_out[1](out)
return out
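
A minimal usage sketch, not taken from this PR's diff: in diffusers, a custom attention processor such as FlashAttnProcessor can be applied to every attention layer of a UNet via set_attn_processor. The checkpoint ID and surrounding setup below are illustrative assumptions.

import torch
from diffusers import UNet2DConditionModel

from library.attention_processors import FlashAttnProcessor

# Illustrative checkpoint; any Stable Diffusion 1.x UNet in diffusers format would do.
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet", torch_dtype=torch.float16
).to("cuda")

# Replace the default attention processors with the flash-attention implementation above.
unet.set_attn_processor(FlashAttnProcessor())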