Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Faster training #84

Merged
merged 4 commits into from
Oct 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
FROM nvcr.io/nvidia/pytorch:23.07-py3
FROM nvcr.io/nvidia/pytorch:23.10-py3

RUN apt update -y && apt install -y \
git tmux
Expand All @@ -20,6 +20,10 @@ RUN pip install --upgrade pip && \
pip install . && \
pip install pre-commit

# Install xformers
# RUN export TORCH_CUDA_ARCH_LIST="9.0+PTX" MAX_JOBS=1 && \
# pip install -v -U git+https://github.com/facebookresearch/xformers.git@v0.0.22.post7#egg=xformers

# Language settings
ENV LANG C.UTF-8
ENV LANGUAGE en_US
Expand Down
16 changes: 16 additions & 0 deletions configs/stable_diffusion_xl/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,22 @@ $ mim train diffengine ${CONFIG_FILE} --gpus 2 --launcher pytorch
$ mim train diffengine configs/stable_diffusion_xl/stable_diffusion_xl_pokemon_blip.py
```

## Training Speed

Environment:

- A6000 Single GPU
- nvcr.io/nvidia/pytorch:23.10-py3

Settings:

- 1 epoch of training.

| Model | total time |
| :-------------------------------------: | :--------: |
| stable_diffusion_xl_pokemon_blip (fp16) | 12 m 37 s |
| stable_diffusion_xl_pokemon_blip_fast | 12 m 10 s |

## Inference with diffusers

Once you have trained a model, specify the path to the saved model and utilize it for inference using the `diffusers.pipeline` module.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# SDXL Pokemon BLIP training config: inherits the base model, dataset,
# schedule and runtime files, then overrides the settings below.
_base_ = [
    "../_base_/models/stable_diffusion_xl.py",
    "../_base_/datasets/pokemon_blip_xl.py",
    "../_base_/schedules/stable_diffusion_xl_50e.py",
    "../_base_/default_runtime.py",
]

# Per-step batch size override for the training dataloader.
train_dataloader = {"batch_size": 1}

# float16 optimizer wrapper with an accumulation count of 4.
optim_wrapper = {
    "dtype": "float16",
    "accumulative_counts": 4,
}

env_cfg = {
    "cudnn_benchmark": True,
}

# FastNormHook is listed last, after visualization and checkpoint hooks.
custom_hooks = [
    {
        "type": "VisualizationHook",
        "prompt": ["yoda pokemon"] * 4,
        "height": 1024,
        "width": 1024,
    },
    {"type": "SDCheckpointHook"},
    {"type": "FastNormHook"},
]
4 changes: 4 additions & 0 deletions diffengine/engine/hooks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from .compile_hook import CompileHook
from .controlnet_save_hook import ControlNetSaveHook
from .fast_norm_hook import FastNormHook
from .ip_adapter_save_hook import IPAdapterSaveHook
from .lora_save_hook import LoRASaveHook
from .sd_checkpoint_hook import SDCheckpointHook
Expand All @@ -14,4 +16,6 @@
"ControlNetSaveHook",
"IPAdapterSaveHook",
"T2IAdapterSaveHook",
"CompileHook",
"FastNormHook",
]
45 changes: 45 additions & 0 deletions diffengine/engine/hooks/compile_hook.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import torch
from mmengine.hooks import Hook
from mmengine.model import is_model_wrapper
from mmengine.registry import HOOKS


@HOOKS.register_module()
class CompileHook(Hook):
    """Hook that wraps the model's sub-modules with ``torch.compile``.

    Args:
    ----
        backend (str): Backend name forwarded to ``torch.compile``.
            Defaults to "inductor".
    """

    priority = "VERY_LOW"

    def __init__(self, backend: str = "inductor") -> None:
        super().__init__()
        self.backend = backend

    def before_train(self, runner) -> None:
        """Compile the UNet plus whichever encoders / VAE the model has.

        Args:
        ----
            runner (Runner): The runner of the training process.
        """
        target = runner.model
        if is_model_wrapper(target):
            # Operate on the wrapped module, not on the wrapper itself.
            target = target.module

        # The UNet is compiled unconditionally.
        target.unet = torch.compile(target.unet, backend=self.backend)

        # The remaining parts are compiled only when the attribute exists.
        optional_parts = (
            "text_encoder",
            "text_encoder_one",
            "text_encoder_two",
            "vae",
        )
        for part in optional_parts:
            if not hasattr(target, part):
                continue
            compiled = torch.compile(
                getattr(target, part), backend=self.backend)
            setattr(target, part, compiled)
107 changes: 107 additions & 0 deletions diffengine/engine/hooks/fast_norm_hook.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import torch
from mmengine.hooks import Hook
from mmengine.logging import print_log
from mmengine.model import is_model_wrapper
from mmengine.registry import HOOKS
from torch import nn
from torch.nn import functional as F # noqa

try:
import apex
except ImportError:
apex = None


def _fast_gn_forward(self, x) -> torch.Tensor:
    """Faster group normalization forward.

    Written to be bound onto a ``torch.nn.GroupNorm`` instance in place of
    its ``forward``. When autocast is enabled, the input and the affine
    parameters are cast to the autocast GPU dtype; ``F.group_norm`` is then
    called with autocast disabled.

    Copied from
    https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/
    fast_norm.py

    Args:
    ----
        x (torch.Tensor): Input tensor passed to ``F.group_norm``.

    Returns:
    -------
        torch.Tensor: Result of ``F.group_norm`` on ``x``.
    """
    weight, bias = self.weight, self.bias
    if torch.is_autocast_enabled():
        dt = torch.get_autocast_gpu_dtype()
        x = x.to(dt)
        # ``GroupNorm(affine=False)`` has neither weight nor bias, so both
        # need the ``None`` guard (previously only ``bias`` had one).
        weight = weight.to(dt) if weight is not None else None
        bias = bias.to(dt) if bias is not None else None

    with torch.cuda.amp.autocast(enabled=False):
        return F.group_norm(x, self.num_groups, weight, bias, self.eps)


@HOOKS.register_module()
class FastNormHook(Hook):
    """Fast Normalization Hook.

    Replace the normalization layers with faster ones before training:
    ``LayerNorm`` modules are swapped for apex ``FusedLayerNorm`` and
    ``GroupNorm`` modules get their ``forward`` rebound to
    ``_fast_gn_forward``.

    Args:
    ----
        fuse_text_encoder (bool, optional): Whether to fuse the text encoder.
            Defaults to False.

    Raises:
    ------
        ImportError: If apex is not installed.
    """

    priority = "VERY_LOW"

    def __init__(self, *, fuse_text_encoder: bool = False) -> None:
        super().__init__()
        if apex is None:
            msg = "Please install apex to use FastNormHook."
            raise ImportError(msg)
        self.fuse_text_encoder = fuse_text_encoder

    def _replace_ln(self, module: nn.Module, name: str, device: str) -> None:
        """Replace the layer normalization with a fused one.

        Args:
        ----
            module (nn.Module): Module whose descendants are scanned.
            name (str): Name of ``module``; used in the log message.
            device (str): Device each new fused layer is moved to.
        """
        from apex.normalization import FusedLayerNorm

        # Walk the registered children instead of ``dir(module)``:
        # ``nn.Module.__dir__`` drops names starting with a digit, so norm
        # layers held directly in ``nn.Sequential`` / ``nn.ModuleList`` were
        # skipped, and ``getattr`` on every attribute evaluated unrelated
        # properties. Snapshot the children because entries are replaced.
        for child_name, child in list(module.named_children()):
            if isinstance(child, nn.LayerNorm):
                print_log(f"replaced LN: {name}")
                # Create a new fused layer normalization with the same
                # arguments and copy the trained parameters into it.
                fused_ln = FusedLayerNorm(
                    child.normalized_shape,
                    child.eps,
                    child.elementwise_affine)
                fused_ln.load_state_dict(child.state_dict())
                fused_ln.to(device)
                setattr(module, child_name, fused_ln)
            else:
                self._replace_ln(child, child_name, device)

    def _replace_gn_forward(self, module: nn.Module, name: str) -> None:
        """Replace the group normalization forward with a faster one.

        Args:
        ----
            module (nn.Module): Module whose descendants are scanned.
            name (str): Name of ``module``; used in the log message.
        """
        # Same traversal as ``_replace_ln``: registered children only, so
        # digit-named entries of containers are covered too.
        for child_name, child in module.named_children():
            if isinstance(child, nn.GroupNorm):
                print_log(f"replaced GN: {name}")
                # Bind the function to this instance so only this layer's
                # ``forward`` changes, not the ``GroupNorm`` class.
                child.forward = _fast_gn_forward.__get__(
                    child, nn.GroupNorm)
            else:
                self._replace_gn_forward(child, child_name)

    def before_train(self, runner) -> None:
        """Replace the normalization layer with a faster one.

        Args:
        ----
            runner (Runner): The runner of the training process.
        """
        model = runner.model
        if is_model_wrapper(model):
            model = model.module
        self._replace_ln(model.unet, "model", model.device)
        self._replace_gn_forward(model.unet, "unet")

        if self.fuse_text_encoder:
            # Only the encoders the model actually defines are touched.
            encoder_names = (
                "text_encoder",
                "text_encoder_one",
                "text_encoder_two",
            )
            for encoder_name in encoder_names:
                if hasattr(model, encoder_name):
                    self._replace_ln(
                        getattr(model, encoder_name), "model", model.device)
4 changes: 2 additions & 2 deletions diffengine/engine/optimizers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .builder import TRANSFORMER_OPTIMIZERS
from .builder import APEX_OPTIMIZERS

__all__ = ["TRANSFORMER_OPTIMIZERS"]
__all__ = ["APEX_OPTIMIZERS"]
22 changes: 14 additions & 8 deletions diffengine/engine/optimizers/builder.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
from transformers import Adafactor

from diffengine.registry import OPTIMIZERS

try:
import apex
except ImportError:
apex = None

def register_transformer_optimizers() -> list:
def register_apex_optimizers() -> list:
"""Register apex optimizers (FusedAdam and FusedSGD) if apex is installed."""
transformer_optimizers = []
OPTIMIZERS.register_module(name="Adafactor")(Adafactor)
transformer_optimizers.append("Adafactor")
return transformer_optimizers
apex_optimizers = []
if apex is not None:
from apex.optimizers import FusedAdam, FusedSGD
OPTIMIZERS.register_module(name="FusedAdam")(FusedAdam)
apex_optimizers.append("FusedAdam")
OPTIMIZERS.register_module(name="FusedSGD")(FusedSGD)
apex_optimizers.append("FusedSGD")
return apex_optimizers


TRANSFORMER_OPTIMIZERS = register_transformer_optimizers()
APEX_OPTIMIZERS = register_apex_optimizers()
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@ dependencies = [
"torch>=2.0.1",
"torchvision>=0.15.2",
"openmim>=0.3.9",
"datasets==2.14.5",
"datasets==2.14.6",
"diffusers==0.21.4",
"mmengine>=0.8.5",
"mmengine>=0.9.0",
"sentencepiece>=0.1.99",
"tqdm",
"transformers==4.33.3",
"transformers==4.34.1",
"ujson"
]
license = {file = "LICENSE"}
Expand Down
74 changes: 74 additions & 0 deletions tests/test_engine/test_hooks/test_compile_hook.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import copy

from diffusers import AutoencoderKL, UNet2DConditionModel
from mmengine.registry import MODELS
from mmengine.testing import RunnerTestCase
from transformers import CLIPTextModel, CLIPTextModelWithProjection

from diffengine.engine.hooks import CompileHook
from diffengine.models.editors import (
SDDataPreprocessor,
SDXLDataPreprocessor,
StableDiffusion,
StableDiffusionXL,
)
from diffengine.models.losses import L2Loss


class TestCompileHook(RunnerTestCase):
    """Runner-level tests for ``CompileHook``."""

    def setUp(self) -> None:
        # Register the model pieces so the runner can build them from cfg.
        to_register = {
            "StableDiffusion": StableDiffusion,
            "StableDiffusionXL": StableDiffusionXL,
            "SDDataPreprocessor": SDDataPreprocessor,
            "SDXLDataPreprocessor": SDXLDataPreprocessor,
            "L2Loss": L2Loss,
        }
        for key, module_cls in to_register.items():
            MODELS.register_module(name=key, module=module_cls)
        return super().setUp()

    def tearDown(self) -> None:
        # Undo the registrations made in ``setUp``, in the same order.
        registered_keys = (
            "StableDiffusion",
            "StableDiffusionXL",
            "SDDataPreprocessor",
            "SDXLDataPreprocessor",
            "L2Loss",
        )
        for key in registered_keys:
            MODELS.module_dict.pop(key)
        return super().tearDown()

    def test_init(self) -> None:
        CompileHook()

    def test_before_train(self) -> None:
        # Test StableDiffusion
        sd_cfg = copy.deepcopy(self.epoch_based_cfg)
        sd_cfg.model.type = "StableDiffusion"
        sd_cfg.model.model = "diffusers/tiny-stable-diffusion-torch"
        runner = self.build_runner(sd_cfg)
        hook = CompileHook()
        sd_parts = {
            "unet": UNet2DConditionModel,
            "vae": AutoencoderKL,
            "text_encoder": CLIPTextModel,
        }
        for attr, part_cls in sd_parts.items():
            assert isinstance(getattr(runner.model, attr), part_cls)
        # compile
        hook.before_train(runner)
        for attr, part_cls in sd_parts.items():
            assert not isinstance(getattr(runner.model, attr), part_cls)

        # Test StableDiffusionXL
        sdxl_cfg = copy.deepcopy(self.epoch_based_cfg)
        sdxl_cfg.model.type = "StableDiffusionXL"
        sdxl_cfg.model.model = (
            "hf-internal-testing/tiny-stable-diffusion-xl-pipe")
        runner = self.build_runner(sdxl_cfg)
        hook = CompileHook()
        sdxl_parts = {
            "unet": UNet2DConditionModel,
            "vae": AutoencoderKL,
            "text_encoder_one": CLIPTextModel,
            "text_encoder_two": CLIPTextModelWithProjection,
        }
        for attr, part_cls in sdxl_parts.items():
            assert isinstance(getattr(runner.model, attr), part_cls)
        # compile
        hook.before_train(runner)
        for attr, part_cls in sdxl_parts.items():
            assert not isinstance(getattr(runner.model, attr), part_cls)
Loading