Skip to content

Commit

Permalink
Move image preprocessing inside the pipeline and allow non-512x512 masks
Browse files Browse the repository at this point in the history
  • Loading branch information
nagolinc committed Aug 25, 2022
1 parent 721d55d commit b6140f0
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions examples/inference/inpainting.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,13 @@ def preprocess(image):

def preprocess_mask(mask):
    """Convert a PIL mask image into a 4-channel mask tensor.

    The mask is converted to greyscale, its size is rounded down to a
    multiple of 32 and then divided by 8 (presumably the VAE's latent
    down-scaling factor -- the caller asserts the result has the same
    shape as the encoded image latents), and the values are inverted.

    Args:
        mask: PIL image. White regions are repainted and black regions
            are kept, per the original comment.

    Returns:
        A float32 ``torch.Tensor`` of shape ``(1, 4, h // 8, w // 8)``
        with values in ``[0, 1]``, where ``w`` and ``h`` are the input
        size rounded down to a multiple of 32. White input pixels map
        to 0 and black input pixels map to 1.

    Raises:
        ValueError: if either side of the mask is smaller than 32 pixels,
            since it would round down to an empty image.
    """
    mask = mask.convert("L")

    # Round each side down to an integer multiple of 32.
    w, h = (side - side % 32 for side in mask.size)
    if w == 0 or h == 0:
        raise ValueError(
            f"mask must be at least 32x32 pixels, got {mask.size[0]}x{mask.size[1]}"
        )

    # Resize directly to the reduced resolution. NEAREST keeps the mask
    # values crisp instead of blending the edge between kept and repainted
    # regions. (A previous fixed 64x64 resize is intentionally gone so that
    # non-512x512 inputs work.)
    mask = mask.resize((w // 8, h // 8), resample=PIL.Image.NEAREST)

    mask = np.array(mask).astype(np.float32) / 255.0  # (H, W) in [0, 1]
    mask = np.tile(mask, (4, 1, 1))  # replicate across 4 channels: (4, H, W)
    mask = mask[None]  # add leading batch axis: (1, 4, H, W)
    mask = 1 - mask  # repaint white, keep black
    return torch.from_numpy(mask)

Expand Down Expand Up @@ -87,13 +90,20 @@ def __call__(

self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)

#preprocess image
init_image = preprocess(init_image).to(self.device)

# encode the init image into latents and scale the latents
init_latents = self.vae.encode(init_image.to(self.device)).sample()
init_latents = self.vae.encode(init_image).sample()
init_latents = 0.18215 * init_latents
init_latents_orig = init_latents

# preprocess mask
mask = preprocess_mask(mask_image).to(self.device)
mask = torch.cat([mask] * batch_size)

#check sizes
assert mask.shape == init_latents.shape

# prepare init_latents noise to latents
init_latents = torch.cat([init_latents] * batch_size)
Expand Down

0 comments on commit b6140f0

Please sign in to comment.