🌐 README in Korean: 🇰🇷 한국어 버전
This repository contains a simplified implementation of a latent diffusion model. The code and content will be updated continuously.
| Dataset | Generation Process of Latents | Generated Data |
|---|---|---|
| Swiss-roll | | |
| CIFAR-10 | | |
| CelebA | | |
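The following example demonstrates how to use the code in this repository: it trains a variational autoencoder (VAE) and a latent diffusion model (LDM) on CIFAR-10, then generates and displays samples.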
import torch
import os
from auto_encoder.models.variational_auto_encoder import VariationalAutoEncoder
from helper.data_generator import DataGenerator
from helper.painter import Painter
from helper.trainer import Trainer
from helper.loader import Loader
from diffusion_model.models.latent_diffusion_model import LatentDiffusionModel
from diffusion_model.network.uncond_u_net import UnconditionalUnetwork
from diffusion_model.sampler.ddim import DDIM
# Path to the YAML configuration file shared by the VAE, the DDIM sampler, and the U-Net below
CONFIG_PATH = './configs/cifar10_config.yaml'
# Instantiate helper classes: image display (Painter), checkpoint loading (Loader),
# and dataset construction (DataGenerator)
painter = Painter()
loader = Loader()
data_generator = DataGenerator()
# Load the CIFAR-10 dataset as a data loader with a batch size of 128
data_loader = data_generator.cifar10(batch_size=128)
# Step 1: train the Variational Autoencoder (VAE).
# The LDM in step 2 is built around this same in-memory `vae` object, so run this step first.
vae = VariationalAutoEncoder(CONFIG_PATH) # Initialize the VAE model
trainer = Trainer(vae, vae.loss) # Create a trainer for the VAE
trainer.train(dl=data_loader, epochs=1000, file_name='vae', no_label=True) # Train the VAE for 1000 epochs
# Step 2: train the Latent Diffusion Model (LDM) on top of the trained VAE.
sampler = DDIM(CONFIG_PATH) # Initialize the DDIM sampler
network = UnconditionalUnetwork(CONFIG_PATH) # Initialize the U-Net network
# image_shape=(3, 32, 32) matches CIFAR-10 images (3 channels, 32x32 pixels)
ldm = LatentDiffusionModel(network, sampler, vae, image_shape=(3, 32, 32)) # Initialize the LDM
trainer = Trainer(ldm, ldm.loss) # Create a trainer for the LDM
# Train the LDM; set 'no_label=False' if the dataset includes labels
trainer.train(dl=data_loader, epochs=1000, file_name='ldm', no_label=True)
# Step 3: generate samples using the trained diffusion model.
# A fresh LDM wrapper is built from the same network, sampler, and VAE objects,
# and saved weights are then loaded into it.
ldm = LatentDiffusionModel(network, sampler, vae, image_shape=(3, 32, 32)) # Re-initialize the LDM
# NOTE(review): this path presumably corresponds to the checkpoint written by
# trainer.train(file_name='ldm', epochs=1000) above — confirm the naming scheme in helper.trainer.
# ema=True presumably selects exponential-moving-average weights — TODO confirm in helper.loader.
loader.model_load('./diffusion_model/check_points/ldm_epoch1000', ldm, ema=True) # Load the trained model
sample = ldm(n_samples=4) # Generate 4 sample images
painter.show_images(sample) # Display the generated images