Skip to content

Commit

Permalink
xhinker parser implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
AI-Casanova committed Aug 25, 2024
1 parent ab9a4d3 commit 517ee93
Show file tree
Hide file tree
Showing 3 changed files with 1,464 additions and 3 deletions.
34 changes: 32 additions & 2 deletions modules/prompt_parser_diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
from compel.embeddings_provider import BaseTextualInversionManager, EmbeddingsProvider
from transformers import PreTrainedTokenizer
from modules import shared, prompt_parser, devices, sd_models
from modules.prompt_parser_xhinker import get_weighted_text_embeddings_sd15, get_weighted_text_embeddings_sdxl_2p, \
get_weighted_text_embeddings_sd3, get_weighted_text_embeddings_flux1


print("testing")
debug_enabled = os.environ.get('SD_PROMPT_DEBUG', None)
debug = shared.log.trace if os.environ.get('SD_PROMPT_DEBUG', None) is not None else lambda *args, **kwargs: None
debug('Trace: PROMPT')
Expand Down Expand Up @@ -183,7 +185,10 @@ def encode_prompts(pipe, p, prompts: list, negative_prompts: list, steps: int, c
for i in range(max(len(positive_schedule), len(negative_schedule))):
positive_prompt = positive_schedule[i % len(positive_schedule)]
negative_prompt = negative_schedule[i % len(negative_schedule)]
prompt_embed, positive_pooled, negative_embed, negative_pooled = get_weighted_text_embeddings(pipe, positive_prompt, negative_prompt, clip_skip)
if shared.opts.prompt_attention == "xhinker parser":
prompt_embed, positive_pooled, negative_embed, negative_pooled = get_xhinker_text_embeddings(pipe, positive_prompt, negative_prompt, clip_skip)
else:
prompt_embed, positive_pooled, negative_embed, negative_pooled = get_weighted_text_embeddings(pipe, positive_prompt, negative_prompt, clip_skip)
if prompt_embed is not None:
p.prompt_embeds.append(torch.cat([prompt_embed] * len(prompts), dim=0))
if negative_embed is not None:
Expand Down Expand Up @@ -442,3 +447,28 @@ def get_weighted_text_embeddings(pipe, prompt: str = "", neg_prompt: str = "", c
).to(device)
negative_prompt_embeds = torch.cat([negative_prompt_embeds, t5_negative_prompt_embed], dim=-2)
return prompt_embeds, pooled_prompt_embeds, negative_prompt_embeds, negative_pooled_prompt_embeds


def get_xhinker_text_embeddings(pipe, prompt: str = "", neg_prompt: str = "", clip_skip: int = None):
print("using xhinker parser")
device = devices.device
SD3 = hasattr(pipe, 'text_encoder_3')
prompt, prompt_2, prompt_3 = split_prompts(prompt, SD3)
neg_prompt, neg_prompt_2, neg_prompt_3 = split_prompts(neg_prompt, SD3)
try:
prompt = pipe.maybe_convert_prompt(prompt, pipe.tokenizer)
neg_prompt = pipe.maybe_convert_prompt(neg_prompt, pipe.tokenizer)
prompt_2 = pipe.maybe_convert_prompt(prompt_2, pipe.tokenizer_2)
neg_prompt_2 = pipe.maybe_convert_prompt(neg_prompt_2, pipe.tokenizer_2)
except:
pass
prompt_embed = positive_pooled = negative_embed = negative_pooled = None
if SD3:
prompt_embed, negative_embed, positive_pooled, negative_pooled = get_weighted_text_embeddings_sd3(pipe=pipe, prompt=prompt, neg_prompt=neg_prompt, use_t5_encoder=bool(pipe.text_encoder_3))
elif 'Flux' in pipe.__class__.__name__:
prompt_embed, positive_pooled = get_weighted_text_embeddings_flux1(pipe=pipe, prompt=prompt, prompt_2=prompt_2)
elif 'XL' in pipe.__class__.__name__:
prompt_embed, negative_embed, positive_pooled, negative_pooled = get_weighted_text_embeddings_sdxl_2p(pipe=pipe, prompt=prompt, prompt_2=prompt_2, neg_prompt=neg_prompt, neg_prompt_2=neg_prompt_2)
else:
prompt_embed, negative_embed = get_weighted_text_embeddings_sd15(pipe=pipe, prompt=prompt, neg_prompt=neg_prompt, clip_skip=clip_skip)
return prompt_embed, positive_pooled, negative_embed, negative_pooled
Loading

0 comments on commit 517ee93

Please sign in to comment.