diff --git a/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py b/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py
index 9fd0a635a8d4..a2cb99be421f 100644
--- a/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py
+++ b/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py
@@ -55,7 +55,9 @@ def __call__(
         num_inference_steps: Optional[int] = 50,
         guidance_scale: Optional[float] = 7.5,
         negative_prompt: Optional[Union[str, List[str]]] = None,
+        num_images_per_prompt: Optional[int] = 1,
         eta: Optional[float] = 0.0,
+        generator: Optional[np.random.RandomState] = None,
         latents: Optional[np.ndarray] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
@@ -81,6 +83,9 @@ def __call__(
                 f" {type(callback_steps)}."
             )
 
+        if generator is None:
+            generator = np.random
+
         # get prompt text embeddings
         text_inputs = self.tokenizer(
             prompt,
@@ -98,6 +103,7 @@ def __call__(
             )
             text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
         text_embeddings = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
+        text_embeddings = np.repeat(text_embeddings, num_images_per_prompt, axis=0)
 
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -133,6 +139,7 @@ def __call__(
                 return_tensors="np",
             )
             uncond_embeddings = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0]
+            uncond_embeddings = np.repeat(uncond_embeddings, num_images_per_prompt, axis=0)
 
             # For classifier free guidance, we need to do two forward passes.
             # Here we concatenate the unconditional and text embeddings into a single batch
@@ -140,9 +147,10 @@ def __call__(
             text_embeddings = np.concatenate([uncond_embeddings, text_embeddings])
 
         # get the initial random noise unless the user supplied it
-        latents_shape = (batch_size, 4, height // 8, width // 8)
+        latents_dtype = text_embeddings.dtype
+        latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8)
         if latents is None:
-            latents = np.random.randn(*latents_shape).astype(np.float32)
+            latents = generator.randn(*latents_shape).astype(latents_dtype)
         elif latents.shape != latents_shape:
             raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
@@ -185,13 +193,30 @@ def __call__(
                 callback(i, t, latents)
 
         latents = 1 / 0.18215 * latents
-        image = self.vae_decoder(latent_sample=latents)[0]
+        # decode latents one sample at a time: the half-precision VAE decoder
+        # produces incorrect results for batch sizes greater than 1
+        image = np.concatenate(
+            [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]
+        )
 
         image = np.clip(image / 2 + 0.5, 0, 1)
         image = image.transpose((0, 2, 3, 1))
 
-        safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np")
-        image, has_nsfw_concept = self.safety_checker(clip_input=safety_checker_input.pixel_values, images=image)
+        if self.safety_checker is not None:
+            safety_checker_input = self.feature_extractor(
+                self.numpy_to_pil(image), return_tensors="np"
+            ).pixel_values.astype(image.dtype)
+            # the safety checker raises an error for batch sizes greater than 1, so run it one image at a time
+            images, has_nsfw_concept = [], []
+            for i in range(image.shape[0]):
+                image_i, has_nsfw_concept_i = self.safety_checker(
+                    clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1]
+                )
+                images.append(image_i)
+                has_nsfw_concept.append(has_nsfw_concept_i[0])
+            image = np.concatenate(images)
+        else:
+            has_nsfw_concept = None
 
         if output_type == "pil":
            image = self.numpy_to_pil(image)
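For reference, a minimal usage sketch of the two new arguments. The checkpoint path is a placeholder, and the call assumes the `OnnxStableDiffusionPipeline` class from diffusers:

import numpy as np
from diffusers import OnnxStableDiffusionPipeline

# hypothetical local path; any ONNX-exported Stable Diffusion checkpoint works
pipe = OnnxStableDiffusionPipeline.from_pretrained("./stable-diffusion-onnx")

# a seeded RandomState (rather than the global np.random) makes runs reproducible
rng = np.random.RandomState(42)
result = pipe(
    "an astronaut riding a horse on mars",
    num_images_per_prompt=2,  # two samples for the single prompt
    generator=rng,
)
images = result.images  # list of two PIL images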
diff --git a/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py b/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py
index ce3f3fbacbc7..b2105458b574 100644
--- a/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py
+++ b/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py
@@ -121,6 +121,7 @@ def __call__(
         negative_prompt: Optional[Union[str, List[str]]] = None,
         num_images_per_prompt: Optional[int] = 1,
         eta: Optional[float] = 0.0,
+        generator: Optional[np.random.RandomState] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
@@ -159,6 +160,8 @@ def __call__(
             eta (`float`, *optional*, defaults to 0.0):
                 Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                 [`schedulers.DDIMScheduler`], will be ignored for others.
+            generator (`np.random.RandomState`, *optional*):
+                A `np.random.RandomState` instance to make generation deterministic.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generated image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -197,6 +200,9 @@ def __call__(
                 f" {type(callback_steps)}."
             )
 
+        if generator is None:
+            generator = np.random
+
         # set timesteps
         self.scheduler.set_timesteps(num_inference_steps)
 
@@ -239,7 +245,7 @@ def __call__(
                 f" {type(prompt)}."
             )
         elif isinstance(negative_prompt, str):
-            uncond_tokens = [negative_prompt]
+            uncond_tokens = [negative_prompt] * batch_size
         elif batch_size != len(negative_prompt):
             raise ValueError("The length of `negative_prompt` should be equal to batch_size.")
         else:
@@ -257,13 +263,15 @@ def __call__(
             uncond_embeddings = self.text_encoder(input_ids=uncond_input_ids.astype(np.int32))[0]
 
             # duplicate unconditional embeddings for each generation per prompt
-            uncond_embeddings = np.repeat(uncond_embeddings, batch_size * num_images_per_prompt, axis=0)
+            uncond_embeddings = np.repeat(uncond_embeddings, num_images_per_prompt, axis=0)
 
             # For classifier free guidance, we need to do two forward passes.
             # Here we concatenate the unconditional and text embeddings into a single batch
             # to avoid doing two forward passes
             text_embeddings = np.concatenate([uncond_embeddings, text_embeddings])
 
+        latents_dtype = text_embeddings.dtype
+        init_image = init_image.astype(latents_dtype)
         # encode the init image into latents and scale the latents
         init_latents = self.vae_encoder(sample=init_image)[0]
         init_latents = 0.18215 * init_latents
@@ -297,7 +305,7 @@ def __call__(
             timesteps = np.array([timesteps] * batch_size * num_images_per_prompt)
 
         # add noise to latents using the timesteps
-        noise = np.random.randn(*init_latents.shape).astype(np.float32)
+        noise = generator.randn(*init_latents.shape).astype(latents_dtype)
         init_latents = self.scheduler.add_noise(
             torch.from_numpy(init_latents), torch.from_numpy(noise), torch.from_numpy(timesteps)
         )
@@ -341,14 +349,28 @@ def __call__(
                 callback(i, t, latents)
 
         latents = 1 / 0.18215 * latents
-        image = self.vae_decoder(latent_sample=latents)[0]
+        # decode latents one sample at a time: the half-precision VAE decoder
+        # produces incorrect results for batch sizes greater than 1
+        image = np.concatenate(
+            [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]
+        )
 
         image = np.clip(image / 2 + 0.5, 0, 1)
         image = image.transpose((0, 2, 3, 1))
 
         if self.safety_checker is not None:
-            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np")
-            image, has_nsfw_concept = self.safety_checker(clip_input=safety_checker_input.pixel_values, images=image)
+            safety_checker_input = self.feature_extractor(
+                self.numpy_to_pil(image), return_tensors="np"
+            ).pixel_values.astype(image.dtype)
+            # the safety checker raises an error for batch sizes greater than 1, so run it one image at a time
+            images, has_nsfw_concept = [], []
+            for i in range(image.shape[0]):
+                image_i, has_nsfw_concept_i = self.safety_checker(
+                    clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1]
+                )
+                images.append(image_i)
+                has_nsfw_concept.append(has_nsfw_concept_i[0])
+            image = np.concatenate(images)
         else:
             has_nsfw_concept = None
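A similar sketch for img2img (paths are placeholders, and `init_image` is assumed to be this pipeline's image argument), showing that seeding now makes the added noise, and therefore the output, reproducible:

import numpy as np
from PIL import Image
from diffusers import OnnxStableDiffusionImg2ImgPipeline

# hypothetical paths; substitute your own ONNX checkpoint and input image
pipe = OnnxStableDiffusionImg2ImgPipeline.from_pretrained("./stable-diffusion-onnx")
init_image = Image.open("sketch.png").convert("RGB").resize((512, 512))

# identical seeds now give identical outputs, because the noise comes
# from `generator` instead of the global np.random state
out1 = pipe("a fantasy landscape", init_image=init_image, generator=np.random.RandomState(0)).images[0]
out2 = pipe("a fantasy landscape", init_image=init_image, generator=np.random.RandomState(0)).images[0]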
diff --git a/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py b/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py
index b45d968f66e3..37374aa4b1bc 100644
--- a/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py
+++ b/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py
@@ -23,11 +23,11 @@
 def prepare_mask_and_masked_image(image, mask, latents_shape):
-    image = np.array(image.convert("RGB"))
+    image = np.array(image.convert("RGB").resize((latents_shape[1] * 8, latents_shape[0] * 8)))
     image = image[None].transpose(0, 3, 1, 2)
     image = image.astype(np.float32) / 127.5 - 1.0
 
-    image_mask = np.array(mask.convert("L"))
+    image_mask = np.array(mask.convert("L").resize((latents_shape[1] * 8, latents_shape[0] * 8)))
     masked_image = image * (image_mask < 127.5)
 
     mask = mask.resize((latents_shape[1], latents_shape[0]), PIL.Image.NEAREST)
@@ -138,6 +138,7 @@ def __call__(
         negative_prompt: Optional[Union[str, List[str]]] = None,
         num_images_per_prompt: Optional[int] = 1,
         eta: float = 0.0,
+        generator: Optional[np.random.RandomState] = None,
         latents: Optional[np.ndarray] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
@@ -180,6 +181,8 @@ def __call__(
             eta (`float`, *optional*, defaults to 0.0):
                 Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                 [`schedulers.DDIMScheduler`], will be ignored for others.
+            generator (`np.random.RandomState`, *optional*):
+                A `np.random.RandomState` instance to make generation deterministic.
             latents (`np.ndarray`, *optional*):
                 Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                 generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
@@ -222,6 +225,9 @@ def __call__(
                 f" {type(callback_steps)}."
             )
 
+        if generator is None:
+            generator = np.random
+
         # set timesteps
         self.scheduler.set_timesteps(num_inference_steps)
 
@@ -261,7 +267,7 @@ def __call__(
                 f" {type(prompt)}."
             )
         elif isinstance(negative_prompt, str):
-            uncond_tokens = [negative_prompt]
+            uncond_tokens = [negative_prompt] * batch_size
         elif batch_size != len(negative_prompt):
             raise ValueError(
                 f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
@@ -283,7 +289,7 @@ def __call__(
             uncond_embeddings = self.text_encoder(input_ids=uncond_input_ids.astype(np.int32))[0]
 
             # duplicate unconditional embeddings for each generation per prompt
-            uncond_embeddings = np.repeat(uncond_embeddings, batch_size * num_images_per_prompt, axis=0)
+            uncond_embeddings = np.repeat(uncond_embeddings, num_images_per_prompt, axis=0)
 
             # For classifier free guidance, we need to do two forward passes.
             # Here we concatenate the unconditional and text embeddings into a single batch
@@ -294,7 +300,7 @@ def __call__(
         latents_shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8)
         latents_dtype = text_embeddings.dtype
         if latents is None:
-            latents = np.random.randn(*latents_shape).astype(latents_dtype)
+            latents = generator.randn(*latents_shape).astype(latents_dtype)
         else:
             if latents.shape != latents_shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
@@ -307,6 +313,10 @@ def __call__(
         masked_image_latents = self.vae_encoder(sample=masked_image)[0]
         masked_image_latents = 0.18215 * masked_image_latents
 
+        # duplicate mask and masked_image_latents for each generation per prompt
+        mask = mask.repeat(batch_size * num_images_per_prompt, 0)
+        masked_image_latents = masked_image_latents.repeat(batch_size * num_images_per_prompt, 0)
+
         mask = np.concatenate([mask] * 2) if do_classifier_free_guidance else mask
         masked_image_latents = (
             np.concatenate([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
@@ -367,14 +377,28 @@ def __call__(
                 callback(i, t, latents)
 
         latents = 1 / 0.18215 * latents
-        image = self.vae_decoder(latent_sample=latents)[0]
+        # decode latents one sample at a time: the half-precision VAE decoder
+        # produces incorrect results for batch sizes greater than 1
+        image = np.concatenate(
+            [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]
+        )
 
         image = np.clip(image / 2 + 0.5, 0, 1)
         image = image.transpose((0, 2, 3, 1))
 
         if self.safety_checker is not None:
-            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np")
-            image, has_nsfw_concept = self.safety_checker(clip_input=safety_checker_input.pixel_values, images=image)
+            safety_checker_input = self.feature_extractor(
+                self.numpy_to_pil(image), return_tensors="np"
+            ).pixel_values.astype(image.dtype)
+            # the safety checker raises an error for batch sizes greater than 1, so run it one image at a time
+            images, has_nsfw_concept = [], []
+            for i in range(image.shape[0]):
+                image_i, has_nsfw_concept_i = self.safety_checker(
+                    clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1]
+                )
+                images.append(image_i)
+                has_nsfw_concept.append(has_nsfw_concept_i[0])
+            image = np.concatenate(images)
         else:
             has_nsfw_concept = None
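Finally, an inpainting sketch (paths are placeholders, and the class and `mask_image` argument are assumptions about the exported pipeline). With this change, `prepare_mask_and_masked_image` resizes both image and mask to the target resolution, and the mask and masked-image latents are duplicated to match `num_images_per_prompt`:

import numpy as np
from PIL import Image
from diffusers import OnnxStableDiffusionInpaintPipeline

# hypothetical paths; any ONNX-exported inpainting checkpoint works
pipe = OnnxStableDiffusionInpaintPipeline.from_pretrained("./stable-diffusion-inpainting-onnx")
image = Image.open("photo.png").convert("RGB")  # any size: now resized internally
mask = Image.open("mask.png").convert("L")      # white pixels are repainted

result = pipe(
    "a red brick fireplace",
    image=image,
    mask_image=mask,
    num_images_per_prompt=4,            # mask/masked latents are duplicated to match
    generator=np.random.RandomState(7),  # reproducible noise
)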