[Onnx] support half-precision and fix bugs for onnx pipelines (huggingface#932)

* [Onnx] support half-precision and fix bugs for onnx pipelines

* Update convert_stable_diffusion_checkpoint_to_onnx.py

* style

* fix has_nsfw_concept

* Update convert_stable_diffusion_checkpoint_to_onnx.py

* fix style
SkyTNT authored Oct 25, 2022
1 parent 5a47af1 commit fc89d4a
Showing 3 changed files with 90 additions and 19 deletions.
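For context, a minimal usage sketch of what this commit enables on the ONNX pipeline: deterministic generation through the new `generator` argument and batched generation through `num_images_per_prompt`. The checkpoint path and execution provider below are illustrative assumptions, not part of the diff.

import numpy as np
from diffusers import OnnxStableDiffusionPipeline

# assumption: an fp16 ONNX export produced by the updated
# convert_stable_diffusion_checkpoint_to_onnx.py script
pipe = OnnxStableDiffusionPipeline.from_pretrained(
    "./stable-diffusion-onnx-fp16", provider="CUDAExecutionProvider"
)

# `generator` makes sampling reproducible; `num_images_per_prompt`
# now correctly repeats the prompt embeddings and latents
images = pipe(
    "a photo of an astronaut riding a horse on mars",
    num_images_per_prompt=2,
    generator=np.random.RandomState(0),
).images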
pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py (35 changes: 30 additions & 5 deletions)
@@ -55,7 +55,9 @@ def __call__(
num_inference_steps: Optional[int] = 50,
guidance_scale: Optional[float] = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: Optional[float] = 0.0,
generator: Optional[np.random.RandomState] = None,
latents: Optional[np.ndarray] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
@@ -81,6 +83,9 @@ def __call__(
f" {type(callback_steps)}."
)

if generator is None:
generator = np.random

# get prompt text embeddings
text_inputs = self.tokenizer(
prompt,
@@ -98,6 +103,7 @@
)
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
text_embeddings = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
text_embeddings = np.repeat(text_embeddings, num_images_per_prompt, axis=0)

# here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -133,16 +139,18 @@ def __call__(
return_tensors="np",
)
uncond_embeddings = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0]
uncond_embeddings = np.repeat(uncond_embeddings, num_images_per_prompt, axis=0)

# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = np.concatenate([uncond_embeddings, text_embeddings])

# get the initial random noise unless the user supplied it
latents_shape = (batch_size, 4, height // 8, width // 8)
latents_dtype = text_embeddings.dtype
latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8)
if latents is None:
latents = np.random.randn(*latents_shape).astype(np.float32)
latents = generator.randn(*latents_shape).astype(latents_dtype)
elif latents.shape != latents_shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")

@@ -185,13 +193,30 @@ def __call__(
callback(i, t, latents)

latents = 1 / 0.18215 * latents
image = self.vae_decoder(latent_sample=latents)[0]
# image = self.vae_decoder(latent_sample=latents)[0]
# the half-precision vae decoder seems to give wrong results when batch size > 1, so decode one sample at a time
image = np.concatenate(
[self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]
)

image = np.clip(image / 2 + 0.5, 0, 1)
image = image.transpose((0, 2, 3, 1))

safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np")
image, has_nsfw_concept = self.safety_checker(clip_input=safety_checker_input.pixel_values, images=image)
if self.safety_checker is not None:
safety_checker_input = self.feature_extractor(
self.numpy_to_pil(image), return_tensors="np"
).pixel_values.astype(image.dtype)
# the safety_checker raises an error when called with batch size > 1, so run it image by image
images, has_nsfw_concept = [], []
for i in range(image.shape[0]):
image_i, has_nsfw_concept_i = self.safety_checker(
clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1]
)
images.append(image_i)
has_nsfw_concept.append(has_nsfw_concept_i[0])
image = np.concatenate(images)
else:
has_nsfw_concept = None

if output_type == "pil":
image = self.numpy_to_pil(image)
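The same decode workaround recurs in all three pipelines: run the VAE decoder on one latent at a time, then concatenate. A standalone sketch of the pattern, where `vae_decoder` stands in for the pipeline's ONNX Runtime session wrapper and the helper name is ours:

import numpy as np

def decode_latents_one_by_one(vae_decoder, latents):
    # the fp16 ONNX vae decoder appears to return wrong values for
    # batch size > 1, so each latent gets its own forward pass
    latents = latents / 0.18215
    image = np.concatenate(
        [vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]
    )
    image = np.clip(image / 2 + 0.5, 0, 1)  # [-1, 1] -> [0, 1]
    return image.transpose((0, 2, 3, 1))  # NCHW -> NHWC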
pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py
@@ -121,6 +121,7 @@ def __call__(
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: Optional[float] = 0.0,
generator: Optional[np.random.RandomState] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
@@ -159,6 +160,8 @@ def __call__(
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`np.random.RandomState`, *optional*):
An `np.random.RandomState` instance to make generation deterministic.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -197,6 +200,9 @@ def __call__(
f" {type(callback_steps)}."
)

if generator is None:
generator = np.random

# set timesteps
self.scheduler.set_timesteps(num_inference_steps)

@@ -239,7 +245,7 @@ def __call__(
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt]
uncond_tokens = [negative_prompt] * batch_size
elif batch_size != len(negative_prompt):
raise ValueError("The length of `negative_prompt` should be equal to batch_size.")
else:
@@ -257,13 +263,15 @@ def __call__(
uncond_embeddings = self.text_encoder(input_ids=uncond_input_ids.astype(np.int32))[0]

# duplicate unconditional embeddings for each generation per prompt
uncond_embeddings = np.repeat(uncond_embeddings, batch_size * num_images_per_prompt, axis=0)
uncond_embeddings = np.repeat(uncond_embeddings, num_images_per_prompt, axis=0)

# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = np.concatenate([uncond_embeddings, text_embeddings])

latents_dtype = text_embeddings.dtype
init_image = init_image.astype(latents_dtype)
# encode the init image into latents and scale the latents
init_latents = self.vae_encoder(sample=init_image)[0]
init_latents = 0.18215 * init_latents
@@ -297,7 +305,7 @@ def __call__(
timesteps = np.array([timesteps] * batch_size * num_images_per_prompt)

# add noise to latents using the timesteps
noise = np.random.randn(*init_latents.shape).astype(np.float32)
noise = generator.randn(*init_latents.shape).astype(latents_dtype)
init_latents = self.scheduler.add_noise(
torch.from_numpy(init_latents), torch.from_numpy(noise), torch.from_numpy(timesteps)
)
@@ -341,14 +349,28 @@ def __call__(
callback(i, t, latents)

latents = 1 / 0.18215 * latents
image = self.vae_decoder(latent_sample=latents)[0]
# image = self.vae_decoder(latent_sample=latents)[0]
# the half-precision vae decoder seems to give wrong results when batch size > 1, so decode one sample at a time
image = np.concatenate(
[self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]
)

image = np.clip(image / 2 + 0.5, 0, 1)
image = image.transpose((0, 2, 3, 1))

if self.safety_checker is not None:
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np")
image, has_nsfw_concept = self.safety_checker(clip_input=safety_checker_input.pixel_values, images=image)
safety_checker_input = self.feature_extractor(
self.numpy_to_pil(image), return_tensors="np"
).pixel_values.astype(image.dtype)
# the safety_checker raises an error when called with batch size > 1, so run it image by image
images, has_nsfw_concept = [], []
for i in range(image.shape[0]):
image_i, has_nsfw_concept_i = self.safety_checker(
clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1]
)
images.append(image_i)
has_nsfw_concept.append(has_nsfw_concept_i[0])
image = np.concatenate(images)
else:
has_nsfw_concept = None
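Because noise is now drawn from the supplied `generator` instead of the global `np.random` module, seeded runs are reproducible and leave global NumPy state untouched. A small sketch of that guarantee:

import numpy as np

latents_shape = (1, 4, 64, 64)  # batch, channels, height // 8, width // 8

gen = np.random.RandomState(42)
noise_a = gen.randn(*latents_shape).astype(np.float16)

gen = np.random.RandomState(42)
noise_b = gen.randn(*latents_shape).astype(np.float16)

assert (noise_a == noise_b).all()  # same seed, identical noise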

pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py
@@ -23,11 +23,11 @@


def prepare_mask_and_masked_image(image, mask, latents_shape):
image = np.array(image.convert("RGB"))
image = np.array(image.convert("RGB").resize((latents_shape[1] * 8, latents_shape[0] * 8)))
image = image[None].transpose(0, 3, 1, 2)
image = image.astype(np.float32) / 127.5 - 1.0

image_mask = np.array(mask.convert("L"))
image_mask = np.array(mask.convert("L").resize((latents_shape[1] * 8, latents_shape[0] * 8)))
masked_image = image * (image_mask < 127.5)

mask = mask.resize((latents_shape[1], latents_shape[0]), PIL.Image.NEAREST)
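The resize fix above forces both inputs to 8x the latent resolution, so mismatched image and mask sizes can no longer produce shape errors downstream. A hedged usage sketch, assuming the function returns the `(mask, masked_image)` pair built above and using illustrative file names:

import PIL.Image

init_image = PIL.Image.open("dog.png")  # illustrative path
mask_image = PIL.Image.open("dog_mask.png")

# latents_shape is (height // 8, width // 8); a 512x512 generation uses (64, 64)
mask, masked_image = prepare_mask_and_masked_image(init_image, mask_image, (64, 64))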
@@ -138,6 +138,7 @@ def __call__(
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[np.random.RandomState] = None,
latents: Optional[np.ndarray] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
@@ -180,6 +181,8 @@ def __call__(
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`np.random.RandomState`, *optional*):
An `np.random.RandomState` instance to make generation deterministic.
latents (`np.ndarray`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
@@ -222,6 +225,9 @@ def __call__(
f" {type(callback_steps)}."
)

if generator is None:
generator = np.random

# set timesteps
self.scheduler.set_timesteps(num_inference_steps)

@@ -261,7 +267,7 @@ def __call__(
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt]
uncond_tokens = [negative_prompt] * batch_size
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
@@ -283,7 +289,7 @@
uncond_embeddings = self.text_encoder(input_ids=uncond_input_ids.astype(np.int32))[0]

# duplicate unconditional embeddings for each generation per prompt
uncond_embeddings = np.repeat(uncond_embeddings, batch_size * num_images_per_prompt, axis=0)
uncond_embeddings = np.repeat(uncond_embeddings, num_images_per_prompt, axis=0)

# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
@@ -294,7 +300,7 @@
latents_shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8)
latents_dtype = text_embeddings.dtype
if latents is None:
latents = np.random.randn(*latents_shape).astype(latents_dtype)
latents = generator.randn(*latents_shape).astype(latents_dtype)
else:
if latents.shape != latents_shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
@@ -307,6 +313,10 @@
masked_image_latents = self.vae_encoder(sample=masked_image)[0]
masked_image_latents = 0.18215 * masked_image_latents

# duplicate mask and masked_image_latents for each generation per prompt
mask = mask.repeat(batch_size * num_images_per_prompt, 0)
masked_image_latents = masked_image_latents.repeat(batch_size * num_images_per_prompt, 0)

mask = np.concatenate([mask] * 2) if do_classifier_free_guidance else mask
masked_image_latents = (
np.concatenate([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
@@ -367,14 +377,28 @@ def __call__(
callback(i, t, latents)

latents = 1 / 0.18215 * latents
image = self.vae_decoder(latent_sample=latents)[0]
# image = self.vae_decoder(latent_sample=latents)[0]
# the half-precision vae decoder seems to give wrong results when batch size > 1, so decode one sample at a time
image = np.concatenate(
[self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]
)

image = np.clip(image / 2 + 0.5, 0, 1)
image = image.transpose((0, 2, 3, 1))

if self.safety_checker is not None:
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np")
image, has_nsfw_concept = self.safety_checker(clip_input=safety_checker_input.pixel_values, images=image)
safety_checker_input = self.feature_extractor(
self.numpy_to_pil(image), return_tensors="np"
).pixel_values.astype(image.dtype)
# the safety_checker raises an error when called with batch size > 1, so run it image by image
images, has_nsfw_concept = [], []
for i in range(image.shape[0]):
image_i, has_nsfw_concept_i = self.safety_checker(
clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1]
)
images.append(image_i)
has_nsfw_concept.append(has_nsfw_concept_i[0])
image = np.concatenate(images)
else:
has_nsfw_concept = None
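The per-image safety-checker loop is likewise identical in every pipeline and could be factored into a helper; a sketch under the same stand-in assumption that `safety_checker` is the pipeline's ONNX session wrapper:

import numpy as np

def run_safety_checker_one_by_one(safety_checker, clip_input, image):
    # the ONNX safety_checker raises an error for batch size > 1,
    # so each image is checked in its own call
    images, has_nsfw_concept = [], []
    for i in range(image.shape[0]):
        image_i, has_nsfw_i = safety_checker(
            clip_input=clip_input[i : i + 1], images=image[i : i + 1]
        )
        images.append(image_i)
        has_nsfw_concept.append(has_nsfw_i[0])
    return np.concatenate(images), has_nsfw_concept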

