* start refactoring - not yet functional
* first phase of refactor done - not sure weighted prompts working
* Second phase of refactoring. Everything mostly working.
* The refactoring has moved all the hard-core inference work into ldm.dream.generator.*, where there are submodules for txt2img and img2img. inpaint will go in there as well.
* Some additional refactoring will be done soon, but relatively minor work.
* fix -save_orig flag to actually work
* add @neonsecret attention.py memory optimization
* remove unneeded imports
* move token logging into conditioning.py
* add placeholder version of inpaint; porting in progress
* fix crash in img2img
* inpainting working; not tested on variations
* fix crashes in img2img
* ported attention.py memory optimization #117 from basujindal branch
* added @torch.no_grad() decorators to img2img, txt2img, inpaint closures
* Final commit prior to PR against development
* fixup crash when generating intermediate images in web UI
* rename ldm.simplet2i to ldm.generate
* add backward-compatibility simplet2i shell with deprecation warning
* add back in mps exception, addresses @Vargol comment in #354
* replaced Conditioning class with exported functions
* fix wrong type of with_variations attribute during initialization
* changed "image_iterator()" to "get_make_image()"
* raise NotImplementedError for calling get_make_image() in parent class
* Update ldm/generate.py: better error message (Co-authored-by: Kevin Gibbons <bakkot@gmail.com>)
* minor stylistic fixes and assertion checks from code review
* moved get_noise() method into img2img class
* break get_noise() into two methods, one for txt2img and the other for img2img
* inpainting works on non-square images now
* make get_noise() an abstract method in base class
* much improved inpainting

Co-authored-by: Kevin Gibbons <bakkot@gmail.com>
Showing 16 changed files with 1,261 additions and 990 deletions.
ldm/dream/conditioning.py
@@ -0,0 +1,96 @@
'''
This module handles the generation of the conditioning tensors, including management of
weighted subprompts.

Useful function exports:

get_uc_and_c()                  get the conditioned and unconditioned latent
split_weighted_subprompts()     split subprompts, normalize and weight them
log_tokenization()              print out colour-coded tokens and warn if truncated
'''
import re
import torch

def get_uc_and_c(prompt, model, log_tokens=False, skip_normalize=False):
    uc = model.get_learned_conditioning([''])

    # get weighted sub-prompts
    weighted_subprompts = split_weighted_subprompts(
        prompt, skip_normalize
    )

    if len(weighted_subprompts) > 1:
        # i dont know if this is correct.. but it works
        c = torch.zeros_like(uc)
        # normalize each "sub prompt" and add it
        for subprompt, weight in weighted_subprompts:
            log_tokenization(subprompt, model, log_tokens)
            c = torch.add(
                c,
                model.get_learned_conditioning([subprompt]),
                alpha=weight,
            )
    else:   # just standard 1 prompt
        log_tokenization(prompt, model, log_tokens)
        c = model.get_learned_conditioning([prompt])
    return (uc, c)

def split_weighted_subprompts(text, skip_normalize=False) -> list:
    """
    Grabs all text up to the first occurrence of ':',
    uses the grabbed text as a sub-prompt, and takes the value following ':' as weight.
    If ':' has no value defined, defaults to 1.0.
    Repeats until no text remains.
    """
    prompt_parser = re.compile("""
            (?P<prompt>         # capture group for 'prompt'
            (?:\\\:|[^:])+      # match one or more non ':' characters or escaped colons '\:'
            )                   # end 'prompt'
            (?:                 # non-capture group
            :+                  # match one or more ':' characters
            (?P<weight>         # capture group for 'weight'
            -?\d+(?:\.\d+)?     # match positive or negative integer or decimal number
            )?                  # end weight capture group, make optional
            \s*                 # strip spaces after weight
            |                   # OR
            $                   # else, if no ':' then match end of line
            )                   # end non-capture group
            """, re.VERBOSE)
    parsed_prompts = [(match.group("prompt").replace("\\:", ":"), float(
        match.group("weight") or 1)) for match in re.finditer(prompt_parser, text)]
    if skip_normalize:
        return parsed_prompts
    weight_sum = sum(map(lambda x: x[1], parsed_prompts))
    if weight_sum == 0:
        print(
            "Warning: Subprompt weights add up to zero. Discarding and using even weights instead.")
        equal_weight = 1 / len(parsed_prompts)
        return [(x[0], equal_weight) for x in parsed_prompts]
    return [(x[0], x[1] / weight_sum) for x in parsed_prompts]

# shows how the prompt is tokenized
# usually tokens have '</w>' to indicate end-of-word,
# but for readability it has been replaced with ' '
def log_tokenization(text, model, log=False):
    if not log:
        return
    tokens = model.cond_stage_model.tokenizer._tokenize(text)
    tokenized = ""
    discarded = ""
    usedTokens = 0
    totalTokens = len(tokens)
    for i in range(0, totalTokens):
        token = tokens[i].replace('</w>', ' ')
        # alternate color
        s = (usedTokens % 6) + 1
        if i < model.cond_stage_model.max_length:
            tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
            usedTokens += 1
        else:  # over max token length
            discarded = discarded + f"\x1b[0;3{s};40m{token}"
    print(f"\n>> Tokens ({usedTokens}):\n{tokenized}\x1b[0m")
    if discarded != "":
        print(
            f">> Tokens Discarded ({totalTokens-usedTokens}):\n{discarded}\x1b[0m"
        )
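For reference, a minimal sketch of how the weighted-subprompt syntax behaves (the prompt string is illustrative; it assumes ldm.dream.conditioning above is importable):

from ldm.dream.conditioning import split_weighted_subprompts

# unescaped colons delimit subprompts; the number after each colon is its
# weight (default 1.0), and weights are normalized to sum to 1
parts = split_weighted_subprompts('a cat playing with string:2 a dog:1')
for subprompt, weight in parts:
    print(f'{weight:.3f} {subprompt!r}')
# -> 0.667 'a cat playing with string'
# -> 0.333 'a dog'

An escaped colon (\:) passes through as a literal colon instead of starting a new weight.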
ldm/dream/generator/__init__.py
@@ -0,0 +1,4 @@
'''
Initialization file for the ldm.dream.generator package
'''
from .base import Generator
ldm/dream/generator/base.py
@@ -0,0 +1,158 @@
'''
Base class for ldm.dream.generator.*,
including img2img, txt2img, and inpaint
'''
import torch
import numpy as np
import random
from tqdm import tqdm, trange
from PIL import Image
from einops import rearrange, repeat
from pytorch_lightning import seed_everything
from ldm.dream.devices import choose_autocast_device

downsampling = 8

class Generator():
    def __init__(self, model):
        self.model               = model
        self.seed                = None
        self.latent_channels     = model.channels
        self.downsampling_factor = downsampling   # BUG: should come from model or config
        self.variation_amount    = 0
        self.with_variations     = []

    # this is going to be overridden in img2img.py, txt2img.py and inpaint.py
    def get_make_image(self, prompt, **kwargs):
        """
        Returns a function returning an image derived from the prompt and the initial image.
        Return value depends on the seed at the time you call it.
        """
        raise NotImplementedError("get_make_image() must be implemented in a descendant class")

    def set_variation(self, seed, variation_amount, with_variations):
        self.seed             = seed
        self.variation_amount = variation_amount
        self.with_variations  = with_variations

    def generate(self, prompt, init_image, width, height, iterations=1, seed=None,
                 image_callback=None, step_callback=None,
                 **kwargs):
        device_type, scope = choose_autocast_device(self.model.device)
        make_image = self.get_make_image(
            prompt,
            init_image    = init_image,
            width         = width,
            height        = height,
            step_callback = step_callback,
            **kwargs
        )

        results = []
        seed    = seed if seed else self.new_seed()
        seed, initial_noise = self.generate_initial_noise(seed, width, height)
        with scope(device_type), self.model.ema_scope():
            for n in trange(iterations, desc='Generating'):
                x_T = None
                if self.variation_amount > 0:
                    seed_everything(seed)
                    target_noise = self.get_noise(width, height)
                    x_T = self.slerp(self.variation_amount, initial_noise, target_noise)
                elif initial_noise is not None:
                    # i.e. we specified particular variations
                    x_T = initial_noise
                else:
                    seed_everything(seed)
                    if self.model.device.type == 'mps':
                        x_T = self.get_noise(width, height)

                # make_image will do the equivalent of get_noise itself
                image = make_image(x_T)
                results.append([image, seed])
                if image_callback is not None:
                    image_callback(image, seed)
                seed = self.new_seed()
        return results

    def sample_to_image(self, samples):
        """
        Decodes a batch of exactly one latent sample and returns the
        corresponding PIL Image; raises if the batch contains more than one.
        """
        x_samples = self.model.decode_first_stage(samples)
        x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
        if len(x_samples) != 1:
            raise Exception(
                f'>> expected to get a single image, but got {len(x_samples)}')
        x_sample = 255.0 * rearrange(
            x_samples[0].cpu().numpy(), 'c h w -> h w c'
        )
        return Image.fromarray(x_sample.astype(np.uint8))

    def generate_initial_noise(self, seed, width, height):
        initial_noise = None
        if self.variation_amount > 0 or len(self.with_variations) > 0:
            # use fixed initial noise plus random noise per iteration
            seed_everything(seed)
            initial_noise = self.get_noise(width, height)
            for v_seed, v_weight in self.with_variations:
                seed = v_seed
                seed_everything(seed)
                next_noise = self.get_noise(width, height)
                initial_noise = self.slerp(v_weight, initial_noise, next_noise)
            if self.variation_amount > 0:
                random.seed()   # reset RNG to an actually random state, so we can get a random seed for variations
                seed = random.randrange(0, np.iinfo(np.uint32).max)
            return (seed, initial_noise)
        else:
            return (seed, None)

    def get_noise(self, width, height):
        """
        Returns a tensor filled with random numbers, either from a normal distribution
        (txt2img) or from the latent image (img2img, inpaint)
        """
        raise NotImplementedError("get_noise() must be implemented in a descendant class")

    def new_seed(self):
        self.seed = random.randrange(0, np.iinfo(np.uint32).max)
        return self.seed

    def slerp(self, t, v0, v1, DOT_THRESHOLD=0.9995):
        '''
        Spherical linear interpolation

        Args:
            t (float/np.ndarray): Float value between 0.0 and 1.0
            v0 (np.ndarray): Starting vector
            v1 (np.ndarray): Final vector
            DOT_THRESHOLD (float): Threshold for considering the two vectors as
                                   colinear. Not recommended to alter this.

        Returns:
            v2 (np.ndarray): Interpolation vector between v0 and v1
        '''
        inputs_are_torch = False
        if not isinstance(v0, np.ndarray):
            inputs_are_torch = True
            v0 = v0.detach().cpu().numpy()
        if not isinstance(v1, np.ndarray):
            inputs_are_torch = True
            v1 = v1.detach().cpu().numpy()

        dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
        if np.abs(dot) > DOT_THRESHOLD:
            v2 = (1 - t) * v0 + t * v1
        else:
            theta_0 = np.arccos(dot)
            sin_theta_0 = np.sin(theta_0)
            theta_t = theta_0 * t
            sin_theta_t = np.sin(theta_t)
            s0 = np.sin(theta_0 - theta_t) / sin_theta_0
            s1 = sin_theta_t / sin_theta_0
            v2 = s0 * v0 + s1 * v1

        if inputs_are_torch:
            v2 = torch.from_numpy(v2).to(self.model.device)

        return v2
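To make the subclass contract concrete, here is a minimal hypothetical descendant (the class name NullGenerator and its blank-canvas output are illustrative only, not part of this commit): get_make_image() closes over per-call state and returns a make_image(x_T) callable, while get_noise() produces noise in whatever space the subclass samples in.

import torch
from PIL import Image
from ldm.dream.generator.base import Generator

class NullGenerator(Generator):
    '''Hypothetical do-nothing generator illustrating the subclass contract.'''

    def get_make_image(self, prompt, width=512, height=512, **kwargs):
        # generate() calls the returned closure once per iteration,
        # passing the initial noise tensor (or None) as x_T
        def make_image(x_T):
            noise = x_T if x_T is not None else self.get_noise(width, height)
            # a real subclass would run its sampler on `noise` here and
            # decode the result with self.sample_to_image()
            return Image.new('RGB', (width, height))
        return make_image

    def get_noise(self, width, height):
        # txt2img-style noise, sampled in the downsampled latent space
        return torch.randn(
            1,
            self.latent_channels,
            height // self.downsampling_factor,
            width  // self.downsampling_factor,
            device=self.model.device,
        )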
ldm/dream/generator/img2img.py
@@ -0,0 +1,72 @@
'''
ldm.dream.generator.img2img descends from ldm.dream.generator
'''

import torch
import numpy as np
from ldm.dream.devices import choose_autocast_device
from ldm.dream.generator.base import Generator
from ldm.models.diffusion.ddim import DDIMSampler

class Img2Img(Generator):
    def __init__(self, model):
        super().__init__(model)
        self.init_latent = None    # set by get_make_image(), used by get_noise()

    @torch.no_grad()
    def get_make_image(self, prompt, sampler, steps, cfg_scale, ddim_eta,
                       conditioning, init_image, strength, step_callback=None, **kwargs):
        """
        Returns a function returning an image derived from the prompt and the initial image.
        Return value depends on the seed at the time you call it.
        """

        # PLMS sampler not supported yet, so ignore previous sampler
        if not isinstance(sampler, DDIMSampler):
            print(
                f">> sampler '{sampler.__class__.__name__}' is not yet supported. Using DDIM sampler"
            )
            sampler = DDIMSampler(self.model, device=self.model.device)

        sampler.make_schedule(
            ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
        )

        device_type, scope = choose_autocast_device(self.model.device)
        with scope(device_type):
            self.init_latent = self.model.get_first_stage_encoding(
                self.model.encode_first_stage(init_image)
            )  # move to latent space

        t_enc = int(strength * steps)
        uc, c = conditioning

        @torch.no_grad()
        def make_image(x_T):
            # encode (scaled latent)
            z_enc = sampler.stochastic_encode(
                self.init_latent,
                torch.tensor([t_enc]).to(self.model.device),
                noise=x_T
            )
            # decode it
            samples = sampler.decode(
                z_enc,
                c,
                t_enc,
                img_callback=step_callback,
                unconditional_guidance_scale=cfg_scale,
                unconditional_conditioning=uc,
            )
            return self.sample_to_image(samples)

        return make_image

    def get_noise(self, width, height):
        device      = self.model.device
        init_latent = self.init_latent
        assert init_latent is not None, 'call to get_noise() when init_latent not set'
        if device.type == 'mps':
            return torch.randn_like(init_latent, device='cpu').to(device)
        else:
            return torch.randn_like(init_latent, device=device)
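Putting the pieces together, a sketch of the end-to-end call path (load_model() and preprocess_image() are hypothetical stand-ins for checkpoint loading and PIL-to-tensor preprocessing; everything else follows the signatures above):

from ldm.dream.conditioning import get_uc_and_c
from ldm.dream.generator.img2img import Img2Img
from ldm.models.diffusion.ddim import DDIMSampler

model      = load_model('models/ldm/stable-diffusion-v1/model.ckpt')  # hypothetical helper
init_image = preprocess_image('sketch.png')                           # hypothetical helper

prompt = 'a watercolor landscape:2 sunset:1'
uc, c  = get_uc_and_c(prompt, model, log_tokens=True)

generator = Img2Img(model)
results   = generator.generate(
    prompt,
    init_image   = init_image,
    width        = 512,
    height       = 512,
    iterations   = 2,
    sampler      = DDIMSampler(model, device=model.device),
    steps        = 50,
    cfg_scale    = 7.5,
    ddim_eta     = 0.0,
    conditioning = (uc, c),
    strength     = 0.75,
)
for image, seed in results:
    image.save(f'img2img.{seed}.png')

Each entry in results is an [image, seed] pair, with a fresh random seed drawn via new_seed() between iterations.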