From 55020b1f9203d69a2ceb98724194288b80a603c8 Mon Sep 17 00:00:00 2001
From: Logan zoellner
Date: Thu, 25 Aug 2022 09:40:17 -0400
Subject: [PATCH] move image preprocessing inside pipeline and allow non
 512x512 mask

---
 examples/inference/inpainting.py | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/examples/inference/inpainting.py b/examples/inference/inpainting.py
index c7c397289514c..de75c5636b356 100644
--- a/examples/inference/inpainting.py
+++ b/examples/inference/inpainting.py
@@ -22,10 +22,13 @@ def preprocess(image):
 
 def preprocess_mask(mask):
     mask=mask.convert("L")
-    mask = mask.resize((64,64), resample=PIL.Image.LANCZOS)
+    w, h = mask.size
+    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
+    mask = mask.resize((w//8, h//8), resample=PIL.Image.NEAREST)
     mask = np.array(mask).astype(np.float32) / 255.0
     mask = np.tile(mask,(4,1,1))
     mask = mask[None].transpose(0, 1, 2, 3)#what does this step do?
+    mask = 1 - mask #repaint white, keep black
     mask = torch.from_numpy(mask)
     return mask
 
@@ -87,13 +90,20 @@ def __call__(
 
         self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
 
+        #preprocess image
+        init_image = preprocess(init_image).to(self.device)
+
         # encode the init image into latents and scale the latents
-        init_latents = self.vae.encode(init_image.to(self.device)).sample()
+        init_latents = self.vae.encode(init_image).sample()
         init_latents = 0.18215 * init_latents
         init_latents_orig = init_latents
 
         # preprocess mask
         mask = preprocess_mask(mask_image).to(self.device)
+        mask = torch.cat([mask] * batch_size)
+
+        #check sizes
+        assert mask.shape == init_latents.shape
 
         # prepare init_latents noise to latents
         init_latents = torch.cat([init_latents] * batch_size)