Skip to content

Commit

Permalink
[IP2P] Make text encoder truly optional in InstructPi2Pix (#6995)
Browse files Browse the repository at this point in the history
* make text encoder component truly optional.

* more fixes

* Apply suggestions from code review

Co-authored-by: YiYi Xu <yixu310@gmail.com>

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
  • Loading branch information
sayakpaul and yiyixuxu authored Feb 18, 2024
1 parent 07349c2 commit 31de879
Showing 1 changed file with 8 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -553,13 +553,15 @@ def _encode_prompt(
else:
attention_mask = None

prompt_embeds = self.text_encoder(
text_input_ids.to(device),
attention_mask=attention_mask,
)
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
prompt_embeds = prompt_embeds[0]

prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
if self.text_encoder is not None:
prompt_embeds_dtype = self.text_encoder.dtype
else:
prompt_embeds_dtype = self.unet.dtype

prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
Expand Down Expand Up @@ -615,7 +617,7 @@ def _encode_prompt(
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]

negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
Expand Down

0 comments on commit 31de879

Please sign in to comment.