Skip to content

Commit

Permalink
Merge branch 'seed-fuzz' of github.com:bakkot/stable-diffusion into b…
Browse files Browse the repository at this point in the history
…akkot-seed-fuzz
  • Loading branch information
lstein committed Sep 2, 2022
2 parents 92d1ed7 + 4fe2657 commit 2d65b03
Show file tree
Hide file tree
Showing 4 changed files with 191 additions and 47 deletions.
5 changes: 5 additions & 0 deletions ldm/dream/pngwriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,11 @@ def normalize_prompt(self):
switches.append(f'-G{opt.gfpgan_strength}')
if opt.upscale:
switches.append(f'-U {" ".join([str(u) for u in opt.upscale])}')
if opt.variation_amount > 0:
switches.append(f'-v {opt.variation_amount}')
if opt.with_variations:
formatted_variations = ';'.join(f'{seed},{weight}' for seed, weight in opt.with_variations)
switches.append(f'-V {formatted_variations}')
if t2i.full_precision:
switches.append('-F')
return ' '.join(switches)
4 changes: 2 additions & 2 deletions ldm/models/diffusion/ksampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ def route_callback(k_callback_values):
img_callback(k_callback_values['x'], k_callback_values['i'])

sigmas = self.model.get_sigmas(S)
if x_T:
x = x_T
if x_T is not None:
x = x_T * sigmas[0]
else:
x = (
torch.randn([batch_size, *shape], device=self.device)
Expand Down
155 changes: 120 additions & 35 deletions ldm/simplet2i.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,8 @@ def prompt2image(
upscale = None,
sampler_name = None,
log_tokenization= False,
with_variations = None,
variation_amount = 0.0,
**args,
): # eat up additional cruft
"""
Expand All @@ -244,6 +246,8 @@ def prompt2image(
ddim_eta // image randomness (eta=0.0 means the same seed always produces the same image)
step_callback // a function or method that will be called each step
image_callback // a function or method that will be called each time an image is generated
with_variations // a weighted list [(seed_1, weight_1), (seed_2, weight_2), ...] of variations which should be applied before doing any generation
variation_amount // optional 0-1 value to slerp from -S noise to random noise (allows variations on an image)
To use the step callback, define a function that receives two arguments:
- Image GPU data
Expand All @@ -270,6 +274,7 @@ def process_image(image,seed):
iterations = iterations or self.iterations
strength = strength or self.strength
self.log_tokenization = log_tokenization
with_variations = [] if with_variations is None else with_variations

model = (
self.load_model()
Expand All @@ -278,6 +283,18 @@ def process_image(image,seed):
assert (
0.0 <= strength <= 1.0
), 'can only work with strength in [0.0, 1.0]'
assert (
0.0 <= variation_amount <= 1.0
), '-v --variation_amount must be in [0.0, 1.0]'

if len(with_variations) > 0:
assert seed is not None,\
'seed must be specified when using with_variations'
if variation_amount == 0.0:
assert iterations == 1,\
'when using --with_variations, multiple iterations are only possible when using --variation_amount'
assert all(0 <= weight <= 1 for _, weight in with_variations),\
f'variation weights must be in [0.0, 1.0]: got {[weight for _, weight in with_variations]}'

width, height, _ = self._resolution_check(width, height, log=True)

Expand All @@ -301,24 +318,25 @@ def process_image(image,seed):
try:
if init_img:
assert os.path.exists(init_img), f'{init_img}: File not found'
images_iterator = self._img2img(
init_image = self._load_img(init_img, width, height, fit).to(self.device)
with scope(device.type):
init_latent = self.model.get_first_stage_encoding(
self.model.encode_first_stage(init_image)
) # move to latent space

make_image = self._img2img(
prompt,
precision_scope=scope,
steps=steps,
cfg_scale=cfg_scale,
ddim_eta=ddim_eta,
skip_normalize=skip_normalize,
init_img=init_img,
width=width,
height=height,
fit=fit,
init_latent=init_latent,
strength=strength,
callback=step_callback,
)
else:
images_iterator = self._txt2img(
make_image = self._txt2img(
prompt,
precision_scope=scope,
steps=steps,
cfg_scale=cfg_scale,
ddim_eta=ddim_eta,
Expand All @@ -328,11 +346,45 @@ def process_image(image,seed):
callback=step_callback,
)

def get_noise():
if init_img:
return torch.randn_like(init_latent, device=self.device)
else:
return torch.randn([1,
self.latent_channels,
height // self.downsampling_factor,
width // self.downsampling_factor],
device=self.device)

initial_noise = None
if variation_amount > 0 or len(with_variations) > 0:
# use fixed initial noise plus random noise per iteration
seed_everything(seed)
initial_noise = get_noise()
for v_seed, v_weight in with_variations:
seed = v_seed
seed_everything(seed)
next_noise = get_noise()
initial_noise = self.slerp(v_weight, initial_noise, next_noise)
if variation_amount > 0:
random.seed() # reset RNG to an actually random state, so we can get a random seed for variations
seed = random.randrange(0,np.iinfo(np.uint32).max)

device_type = choose_autocast_device(self.device)
with scope(device_type), self.model.ema_scope():
for n in trange(iterations, desc='Generating'):
seed_everything(seed)
image = next(images_iterator)
x_T = None
if variation_amount > 0:
seed_everything(seed)
target_noise = get_noise()
x_T = self.slerp(variation_amount, initial_noise, target_noise)
elif initial_noise is not None:
# i.e. we specified particular variations
x_T = initial_noise
else:
seed_everything(seed)
# make_image will do the equivalent of get_noise itself
image = make_image(x_T)
results.append([image, seed])
if image_callback is not None:
image_callback(image, seed)
Expand Down Expand Up @@ -406,7 +458,6 @@ def process_image(image,seed):
def _txt2img(
self,
prompt,
precision_scope,
steps,
cfg_scale,
ddim_eta,
Expand All @@ -416,12 +467,13 @@ def _txt2img(
callback,
):
"""
An infinite iterator of images from the prompt.
Returns a function returning an image derived from the prompt and the initial image
Return value depends on the seed at the time you call it
"""

sampler = self.sampler

while True:
def make_image(x_T):
uc, c = self._get_uc_and_c(prompt, skip_normalize)
shape = [
self.latent_channels,
Expand All @@ -431,6 +483,7 @@ def _txt2img(
samples, _ = sampler.sample(
batch_size=1,
S=steps,
x_T=x_T,
conditioning=c,
shape=shape,
verbose=False,
Expand All @@ -439,26 +492,24 @@ def _txt2img(
eta=ddim_eta,
img_callback=callback
)
yield self._sample_to_image(samples)
return self._sample_to_image(samples)
return make_image

@torch.no_grad()
def _img2img(
self,
prompt,
precision_scope,
steps,
cfg_scale,
ddim_eta,
skip_normalize,
init_img,
width,
height,
fit,
init_latent,
strength,
callback, # Currently not implemented for img2img
):
"""
An infinite iterator of images from the prompt and the initial image
Returns a function returning an image derived from the prompt and the initial image
Return value depends on the seed at the time you call it
"""

# PLMS sampler not supported yet, so ignore previous sampler
Expand All @@ -470,24 +521,20 @@ def _img2img(
else:
sampler = self.sampler

init_image = self._load_img(init_img, width, height,fit).to(self.device)
with precision_scope(self.device.type):
init_latent = self.model.get_first_stage_encoding(
self.model.encode_first_stage(init_image)
) # move to latent space

sampler.make_schedule(
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
)

t_enc = int(strength * steps)

while True:
def make_image(x_T):
uc, c = self._get_uc_and_c(prompt, skip_normalize)

# encode (scaled latent)
z_enc = sampler.stochastic_encode(
init_latent, torch.tensor([t_enc]).to(self.device)
init_latent,
torch.tensor([t_enc]).to(self.device),
noise=x_T
)
# decode it
samples = sampler.decode(
Expand All @@ -498,7 +545,8 @@ def _img2img(
unconditional_guidance_scale=cfg_scale,
unconditional_conditioning=uc,
)
yield self._sample_to_image(samples)
return self._sample_to_image(samples)
return make_image

# TODO: does this actually need to run every loop? does anything in it vary by random seed?
def _get_uc_and_c(self, prompt, skip_normalize):
Expand All @@ -513,8 +561,7 @@ def _get_uc_and_c(self, prompt, skip_normalize):
# i dont know if this is correct.. but it works
c = torch.zeros_like(uc)
# normalize each "sub prompt" and add it
for i in range(0, len(weighted_subprompts)):
subprompt, weight = weighted_subprompts[i]
for subprompt, weight in weighted_subprompts:
self._log_tokenization(subprompt)
c = torch.add(
c,
Expand Down Expand Up @@ -619,7 +666,7 @@ def _load_img(self, path, width, height, fit=False):
print(
f'>> loaded input image of size {image.width}x{image.height} from {path}'
)

# The logic here is:
# 1. If "fit" is true, then the image will be fit into the bounding box defined
# by width and height. It will do this in a way that preserves the init image's
Expand All @@ -644,7 +691,7 @@ def _squeeze_image(self,image):
if resize_needed:
return InitImageResizer(image).resize(x,y)
return image


def _fit_image(self,image,max_dimensions):
w,h = max_dimensions
Expand Down Expand Up @@ -677,10 +724,10 @@ def _split_weighted_subprompts(text, skip_normalize=False):
(?:\\\:|[^:])+ # match one or more non ':' characters or escaped colons '\:'
) # end 'prompt'
(?: # non-capture group
:+ # match one or more ':' characters
:+ # match one or more ':' characters
(?P<weight> # capture group for 'weight'
-?\d+(?:\.\d+)? # match positive or negative integer or decimal number
)? # end weight capture group, make optional
)? # end weight capture group, make optional
\s* # strip spaces after weight
| # OR
$ # else, if no ':' then match end of line
Expand Down Expand Up @@ -741,3 +788,41 @@ def _resolution_check(self, width, height, log=False):
print(">> This input is larger than your defaults. If you run out of memory, please use a smaller image.")

return width, height, resize_needed


def slerp(self, t, v0, v1, DOT_THRESHOLD=0.9995):
'''
Spherical linear interpolation
Args:
t (float/np.ndarray): Float value between 0.0 and 1.0
v0 (np.ndarray): Starting vector
v1 (np.ndarray): Final vector
DOT_THRESHOLD (float): Threshold for considering the two vectors as
colineal. Not recommended to alter this.
Returns:
v2 (np.ndarray): Interpolation vector between v0 and v1
'''
inputs_are_torch = False
if not isinstance(v0, np.ndarray):
inputs_are_torch = True
v0 = v0.detach().cpu().numpy()
if not isinstance(v1, np.ndarray):
inputs_are_torch = True
v1 = v1.detach().cpu().numpy()

dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
if np.abs(dot) > DOT_THRESHOLD:
v2 = (1 - t) * v0 + t * v1
else:
theta_0 = np.arccos(dot)
sin_theta_0 = np.sin(theta_0)
theta_t = theta_0 * t
sin_theta_t = np.sin(theta_t)
s0 = np.sin(theta_0 - theta_t) / sin_theta_0
s1 = sin_theta_t / sin_theta_0
v2 = s0 * v0 + s1 * v1

if inputs_are_torch:
v2 = torch.from_numpy(v2).to(self.device)

return v2
Loading

0 comments on commit 2d65b03

Please sign in to comment.