#!/usr/bin/env python3

"""Applies a text prompt to an existing image by finding a latent that would
produce it with the unconditioned DDIM ODE, then integrating the
text-conditional DDIM ODE starting from that latent."""

import argparse
from functools import partial
from pathlib import Path

from PIL import Image
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from torchvision.transforms import functional as TF
from tqdm import trange

from CLIP import clip
from diffusion import get_model, get_models, sampling, utils

MODULE_DIR = Path(__file__).resolve().parent


def parse_prompt(prompt, default_weight=3.):
    if prompt.startswith('http://') or prompt.startswith('https://'):
        vals = prompt.rsplit(':', 2)
        vals = [vals[0] + ':' + vals[1], *vals[2:]]
    else:
        vals = prompt.rsplit(':', 1)
    vals = vals + ['', default_weight][len(vals):]
    return vals[0], float(vals[1])


def resize_and_center_crop(image, size):
    fac = max(size[0] / image.size[0], size[1] / image.size[1])
    image = image.resize((int(fac * image.size[0]), int(fac * image.size[1])), Image.LANCZOS)
    return TF.center_crop(image, size[::-1])


def main():
    p = argparse.ArgumentParser(description=__doc__,
                                formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    p.add_argument('init', type=str,
                   help='the init image')
    p.add_argument('prompts', type=str, default=[], nargs='*',
                   help='the text prompts to use')
    p.add_argument('--images', type=str, default=[], nargs='*', metavar='IMAGE',
                   help='the image prompts')
    p.add_argument('--checkpoint', type=str,
                   help='the checkpoint to use')
    p.add_argument('--device', type=str,
                   help='the device to use')
    p.add_argument('--max-timestep', '-mt', type=float, default=1.,
                   help='the maximum timestep')
    p.add_argument('--method', type=str, default='plms',
                   choices=['ddim', 'prk', 'plms', 'pie', 'plms2', 'iplms'],
                   help='the sampling method to use')
    p.add_argument('--model', type=str, default='cc12m_1_cfg', choices=['cc12m_1_cfg'],
                   help='the model to use')
    p.add_argument('--output', '-o', type=str, default='out.png',
                   help='the output filename')
    p.add_argument('--size', type=int, nargs=2,
                   help='the output image size')
    p.add_argument('--steps', type=int, default=50,
                   help='the number of timesteps')
    args = p.parse_args()

    if args.device:
        device = torch.device(args.device)
    else:
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print('Using device:', device)

    model = get_model(args.model)()
    _, side_y, side_x = model.shape
    if args.size:
        side_x, side_y = args.size
    checkpoint = args.checkpoint
    if not checkpoint:
        checkpoint = MODULE_DIR / f'checkpoints/{args.model}.pth'
    model.load_state_dict(torch.load(checkpoint, map_location='cpu'))
    if device.type == 'cuda':
        model = model.half()
    model = model.to(device).eval().requires_grad_(False)
    clip_model_name = model.clip_model if hasattr(model, 'clip_model') else 'ViT-B/16'
    clip_model = clip.load(clip_model_name, jit=False, device=device)[0]
    clip_model.eval().requires_grad_(False)
    normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                     std=[0.26862954, 0.26130258, 0.27577711])

    init = Image.open(utils.fetch(args.init)).convert('RGB')
    init = resize_and_center_crop(init, (side_x, side_y))
    init = utils.from_pil_image(init).to(device)[None]

    zero_embed = torch.zeros([1, clip_model.visual.output_dim], device=device)
    target_embeds, weights = [zero_embed], []

    for prompt in args.prompts:
        txt, weight = parse_prompt(prompt)
        target_embeds.append(clip_model.encode_text(clip.tokenize(txt).to(device)).float())
        weights.append(weight)

    for prompt in args.images:
        path, weight = parse_prompt(prompt)
        img = Image.open(utils.fetch(path)).convert('RGB')
        clip_size = clip_model.visual.input_resolution
        img = resize_and_center_crop(img, (clip_size, clip_size))
        batch = TF.to_tensor(img)[None].to(device)
        embed = F.normalize(clip_model.encode_image(normalize(batch)).float(), dim=-1)
        target_embeds.append(embed)
        weights.append(weight)

    # The zero (unconditional) embedding gets whatever weight is left over, so the weights sum to 1.
    weights = torch.tensor([1 - sum(weights), *weights], device=device)

    # Classifier-free guidance: evaluate the model once per conditioning embedding and
    # take the weighted combination of the predicted velocities.
    def cfg_model_fn(x, t):
        n = x.shape[0]
        n_conds = len(target_embeds)
        x_in = x.repeat([n_conds, 1, 1, 1])
        t_in = t.repeat([n_conds])
        clip_embed_in = torch.cat([*target_embeds]).repeat_interleave(n, 0)
        vs = model(x_in, t_in, clip_embed_in).view([n_conds, n, *x.shape[1:]])
        v = vs.mul(weights[:, None, None, None, None]).sum(0)
        return v

    def run():
        t = torch.linspace(0, 1, args.steps + 1, device=device)
        steps = utils.get_spliced_ddpm_cosine_schedule(t)
        steps = steps[steps <= args.max_timestep]
        # Integrate the unconditioned DDIM ODE from the init image up to the latent at
        # max_timestep, then integrate the guided ODE from that latent back down to t=0.
        if args.method == 'ddim':
            x = sampling.reverse_sample(model, init, steps, {'clip_embed': zero_embed})
            out = sampling.sample(cfg_model_fn, x, steps.flip(0)[:-1], 0, {})
        if args.method == 'prk':
            x = sampling.prk_sample(model, init, steps, {'clip_embed': zero_embed}, is_reverse=True)
            out = sampling.prk_sample(cfg_model_fn, x, steps.flip(0)[:-1], {})
        if args.method == 'plms':
            x = sampling.plms_sample(model, init, steps, {'clip_embed': zero_embed}, is_reverse=True)
            out = sampling.plms_sample(cfg_model_fn, x, steps.flip(0)[:-1], {})
        if args.method == 'pie':
            x = sampling.pie_sample(model, init, steps, {'clip_embed': zero_embed}, is_reverse=True)
            out = sampling.pie_sample(cfg_model_fn, x, steps.flip(0)[:-1], {})
        if args.method == 'plms2':
            x = sampling.plms2_sample(model, init, steps, {'clip_embed': zero_embed}, is_reverse=True)
            out = sampling.plms2_sample(cfg_model_fn, x, steps.flip(0)[:-1], {})
        if args.method == 'iplms':
            x = sampling.iplms_sample(model, init, steps, {'clip_embed': zero_embed}, is_reverse=True)
            out = sampling.iplms_sample(cfg_model_fn, x, steps.flip(0)[:-1], {})
        utils.to_pil_image(out[0]).save(args.output)

    try:
        run()
    except KeyboardInterrupt:
        pass


if __name__ == '__main__':
    main()
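
# Example invocations (the script filename and input paths below are illustrative, not
# part of the repository):
#
#   python edit_image.py init.png 'a watercolor painting of a mountain lake' --steps 100 -o edited.png
#   python edit_image.py init.png 'an oil painting:2' --images style.jpg:1 --method plms -mt 0.8
#
# Prompts may carry a ':weight' suffix (default 3.); the unconditional branch receives
# 1 minus the sum of the weights. Lower --max-timestep values stop the reverse DDIM
# integration earlier, so less noise is applied and the output stays closer to the init
# image. With no --checkpoint argument the script loads checkpoints/cc12m_1_cfg.pth
# relative to this file.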