Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added a config option to use embeddings from the huggingface stable diffusion concept library. #1197

Merged
merged 1 commit into from
Sep 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions configs/webui/webui_streamlit.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ general:
default_model: "Stable Diffusion v1.4"
default_model_config: "configs/stable-diffusion/v1-inference.yaml"
default_model_path: "models/ldm/stable-diffusion-v1/model.ckpt"
fp:
name: ''
use_sd_concepts_library: True
sd_concepts_library_folder: "models/custom/sd-concepts-library"
GFPGAN_dir: "./src/gfpgan"
RealESRGAN_dir: "./src/realesrgan"
RealESRGAN_model: "RealESRGAN_x4plus"
Expand Down
13 changes: 7 additions & 6 deletions scripts/ModelManager.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,13 @@ def layout():
#search = st.text_input(label="Search", placeholder="Type the name of the model you want to search for.", help="")

csvString = f"""
,Stable Diffusion v1.4 , ./models/ldm/stable-diffusion-v1 , https://www.googleapis.com/storage/v1/b/aai-blog-files/o/sd-v1-4.ckpt?alt=media
,GFPGAN v1.3 , ./src/gfpgan/experiments/pretrained_models , https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth
,RealESRGAN_x4plus , ./src/realesrgan/experiments/pretrained_models , https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth
,RealESRGAN_x4plus_anime_6B , ./src/realesrgan/experiments/pretrained_models , https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth
,Waifu Diffusion v1.2 , ./models/custom , http://wd.links.sd:8880/wd-v1-2-full-ema.ckpt
,TrinArt Stable Diffusion v2 , ./models/custom , https://huggingface.co/naclbit/trinart_stable_diffusion_v2/resolve/main/trinart2_step115000.ckpt
,Stable Diffusion v1.4 , ./models/ldm/stable-diffusion-v1 , https://www.googleapis.com/storage/v1/b/aai-blog-files/o/sd-v1-4.ckpt?alt=media
,GFPGAN v1.3 , ./src/gfpgan/experiments/pretrained_models , https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth
,RealESRGAN_x4plus , ./src/realesrgan/experiments/pretrained_models , https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth
,RealESRGAN_x4plus_anime_6B , ./src/realesrgan/experiments/pretrained_models , https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth
,Waifu Diffusion v1.2 , ./models/custom , http://wd.links.sd:8880/wd-v1-2-full-ema.ckpt
,TrinArt Stable Diffusion v2 , ./models/custom , https://huggingface.co/naclbit/trinart_stable_diffusion_v2/resolve/main/trinart2_step115000.ckpt
,Stable Diffusion Concept Library , ./models/custom/sd-concepts-library , https://github.com/sd-webui/sd-concepts-library
"""
colms = st.columns((1, 3, 5, 5))
columns = ["№",'Model Name','Save Location','Download Link']
Expand Down
4 changes: 1 addition & 3 deletions scripts/img2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,6 @@ def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name)
use_GFPGAN=use_GFPGAN,
use_RealESRGAN=use_RealESRGAN, # Forcefully disable upscaling when using loopback
realesrgan_model_name=RealESRGAN_model,
fp=fp,
normalize_prompt_weights=normalize_prompt_weights,
save_individual_images=save_individual_images,
init_img=init_img,
Expand Down Expand Up @@ -329,7 +328,6 @@ def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name)
use_GFPGAN=use_GFPGAN,
use_RealESRGAN=use_RealESRGAN,
realesrgan_model_name=RealESRGAN_model,
fp=fp,
normalize_prompt_weights=normalize_prompt_weights,
save_individual_images=save_individual_images,
init_img=init_img,
Expand Down Expand Up @@ -569,7 +567,7 @@ def layout():
sampler_name=st.session_state["sampler_name"], n_iter=batch_count,
cfg_scale=cfg_scale, denoising_strength=st.session_state["denoising_strength"], variant_seed=variant_seed,
seed=seed, noise_mode=noise_mode, find_noise_steps=find_noise_steps, width=width,
height=height, fp=st.session_state['defaults'].general.fp, variant_amount=variant_amount,
height=height, variant_amount=variant_amount,
ddim_eta=0.0, write_info_files=write_info_files, RealESRGAN_model=st.session_state["RealESRGAN_model"],
separate_prompts=separate_prompts, normalize_prompt_weights=normalize_prompt_weights,
save_individual_images=save_individual_images, save_grid=save_grid,
Expand Down
127 changes: 119 additions & 8 deletions scripts/sd_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,14 @@


# streamlit imports


from streamlit import StopException
#other imports

import warnings
import json

import base64
import os, sys, re, random, datetime, time, math
import os, sys, re, random, datetime, time, math, glob
from PIL import Image, ImageFont, ImageDraw, ImageFilter
from PIL.PngImagePlugin import PngInfo
from scipy import integrate
Expand Down Expand Up @@ -65,6 +64,16 @@
opt_C = 4
opt_f = 8

# Ensure the Streamlit session has a "defaults" entry before loading config.
if not "defaults" in st.session_state:
    st.session_state["defaults"] = {}

# Base configuration shipped with the repo; always loaded fresh.
st.session_state["defaults"] = OmegaConf.load("configs/webui/webui_streamlit.yaml")

# Optional per-user override file: values merged on top of the base config
# (OmegaConf.merge gives precedence to the later argument, i.e. user values win).
if (os.path.exists("configs/webui/userconfig_streamlit.yaml")):
    user_defaults = OmegaConf.load("configs/webui/userconfig_streamlit.yaml")
    st.session_state["defaults"] = OmegaConf.merge(st.session_state["defaults"], user_defaults)


# should and will be moved to a settings menu in the UI at some point
grid_format = [s.lower() for s in st.session_state["defaults"].general.grid_format.split(':')]
grid_lossless = False
Expand Down Expand Up @@ -102,7 +111,7 @@
if save_quality < 0: # e.g. webp:-100 for lossless mode
save_lossless = True
save_quality = abs(save_quality)

# this should force GFPGAN and RealESRGAN onto the selected gpu as well
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"] = str(st.session_state["defaults"].general.gpu)
Expand Down Expand Up @@ -788,7 +797,7 @@ def ModelLoader(models,load=False,unload=False,imgproc_realesrgan_model_name='Re
@retry(tries=5)
def generation_callback(img, i=0):
if "update_preview_frequency" not in st.session_state:
return
raise StopException

try:
if i == 0:
Expand Down Expand Up @@ -928,6 +937,82 @@ def load_embeddings(fp):
if fp is not None and hasattr(st.session_state["model"], "embedding_manager"):
st.session_state["model"].embedding_manager.load(fp['name'])

def load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer, token=None):
    """Load a textual-inversion embedding file and register its token with CLIP.

    Supports both file formats found in the SD concepts library:
    ``.pt``  — CompVis style, a dict with ``string_to_token`` / ``string_to_param``;
    ``.bin`` — diffusers style, a single ``{token: tensor}`` mapping.

    Args:
        learned_embeds_path: Path to the ``.pt`` or ``.bin`` embedding file.
        text_encoder: CLIP text encoder whose input-embedding table is extended.
        tokenizer: Tokenizer that gains the new token.
        token: Optional override for the token string; defaults to the token
            stored in the embedding file.

    Returns:
        The token string that was registered.

    Raises:
        ValueError: If the file extension is neither ``.pt`` nor ``.bin``.
    """
    loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")

    # Separate the trained token from its embedding tensor.
    # BUGFIX: the old code unconditionally re-read
    # loaded_learned_embeds[trained_token] after this branch, which raised
    # KeyError for .pt files (their top-level keys are string_to_token /
    # string_to_param, not the token itself).
    if learned_embeds_path.endswith('.pt'):
        trained_token = list(loaded_learned_embeds['string_to_token'].keys())[0]
        embeds = list(loaded_learned_embeds['string_to_param'].values())[0]
    elif learned_embeds_path.endswith('.bin'):
        trained_token = list(loaded_learned_embeds.keys())[0]
        embeds = loaded_learned_embeds[trained_token]
    else:
        raise ValueError(f"Unsupported embedding file type: {learned_embeds_path}")

    # Cast to the dtype of the text encoder. BUGFIX: Tensor.to() is not
    # in-place — the old code discarded its result, so the cast never applied.
    dtype = text_encoder.get_input_embeddings().weight.dtype
    embeds = embeds.to(dtype)

    # Add the token to the tokenizer (a no-op if it already exists).
    token = token if token is not None else trained_token
    tokenizer.add_tokens(token)

    # Grow the embedding table so the new token has a row, then write the
    # learned embedding into that row.
    text_encoder.resize_token_embeddings(len(tokenizer))
    token_id = tokenizer.convert_tokens_to_ids(token)
    text_encoder.get_input_embeddings().weight.data[token_id] = embeds
    return token

def concepts_library():
    """Build an HTML gallery of the available SD concept-library embeddings.

    Produces one card per concept (token link to its huggingface page, a
    concept-type badge, and a thumbnail grid of its sample images),
    mirroring the markup of the huggingface sd-concepts-library gallery.

    NOTE(review): relies on a ``models`` iterable and the ``html`` module
    being available in the enclosing scope — neither is defined in this
    function; each ``model`` appears to be a dict with "id", "token",
    "concept_type" and "images" keys — confirm against the rest of the
    file. Also note the assembled ``html_gallery`` string is never
    returned or rendered here.
    """
    # Opening wrapper for the whole gallery grid.
    html_gallery = '''
<div class="flex gr-gap gr-form-gap row gap-4 w-full flex-wrap" id="main_row">
'''
    for model in models:
        # Card header: token (HTML-escaped) linking to the concept's
        # huggingface repo, plus the concept-type badge.
        html_gallery = html_gallery+f'''
<div class="gr-block gr-box relative w-full overflow-hidden border-solid border border-gray-200 gr-panel">
<div class="output-markdown gr-prose" style="max-width: 100%;">
<h3>
<a href="https://huggingface.co/{model["id"]}" target="_blank">
<code>{html.escape(model["token"])}</code>
</a>
</h3>
</div>
<div id="gallery" class="gr-block gr-box relative w-full overflow-hidden border-solid border border-gray-200">
<div class="wrap svelte-17ttdjv opacity-0"></div>
<div class="absolute left-0 top-0 py-1 px-2 rounded-br-lg shadow-sm text-xs text-gray-500 flex items-center pointer-events-none bg-white z-20 border-b border-r border-gray-100 dark:bg-gray-900">
<span class="mr-2 h-[12px] w-[12px] opacity-80">
<svg xmlns="http://www.w3.org/2000/svg" width="100%" height="100%" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="feather feather-image">
<rect x="3" y="3" width="18" height="18" rx="2" ry="2"></rect>
<circle cx="8.5" cy="8.5" r="1.5"></circle>
<polyline points="21 15 16 10 5 21"></polyline>
</svg>
</span> {model["concept_type"]}
</div>
<div class="overflow-y-auto h-full p-2" style="position: relative;">
<div class="grid gap-2 grid-cols-2 sm:grid-cols-2 md:grid-cols-2 lg:grid-cols-2 xl:grid-cols-2 2xl:grid-cols-2 svelte-1g9btlg pt-6">
'''
        # One lazily-loaded thumbnail button per sample image.
        for image in model["images"]:
            html_gallery = html_gallery + f'''
<button class="gallery-item svelte-1g9btlg">
<img alt="" loading="lazy" class="h-full w-full overflow-hidden object-contain" src="file/{image}">
</button>
'''
        # Close this concept's card.
        html_gallery = html_gallery+'''
</div>
<iframe style="display: block; position: absolute; top: 0; left: 0; width: 100%; height: 100%; overflow: hidden; border: 0; opacity: 0; pointer-events: none; z-index: -1;" aria-hidden="true" tabindex="-1" src="about:blank"></iframe>
</div>
</div>
</div>
'''
    # Close the outer gallery wrapper.
    html_gallery = html_gallery+'''
</div>
'''

def image_grid(imgs, batch_size, force_n_rows=None, captions=None):
#print (len(imgs))
if force_n_rows is not None:
Expand Down Expand Up @@ -1242,7 +1327,7 @@ def classToArrays( items, seed, n_iter ):
def process_images(
outpath, func_init, func_sample, prompt, seed, sampler_name, save_grid, batch_size,
n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, use_RealESRGAN, realesrgan_model_name,
fp=None, ddim_eta=0.0, normalize_prompt_weights=True, init_img=None, init_mask=None,
ddim_eta=0.0, normalize_prompt_weights=True, init_img=None, init_mask=None,
mask_blur_strength=3, mask_restore=False, denoising_strength=0.75, noise_mode=0, find_noise_steps=1, resize_mode=None, uses_loopback=False,
uses_random_seed_loopback=False, sort_samples=True, write_info_files=True, jpg_sample=False,
variant_amount=0.0, variant_seed=None, save_individual_images: bool = True):
Expand All @@ -1257,10 +1342,36 @@ def process_images(

mem_mon = MemUsageMonitor('MemMon')
mem_mon.start()

if st.session_state.defaults.general.use_sd_concepts_library:

if hasattr(st.session_state["model"], "embedding_manager"):
load_embeddings(fp)
prompt_tokens = re.findall('<([a-zA-Z0-9-]+)>', prompt)

if prompt_tokens:
# compviz
tokenizer = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelCS).cond_stage_model.tokenizer
text_encoder = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelCS).cond_stage_model.transformer

# diffusers
#tokenizer = pipe.tokenizer
#text_encoder = pipe.text_encoder

ext = ('pt', 'bin')

if len(prompt_tokens) > 1:
for token_name in prompt_tokens:
embedding_path = os.path.join(st.session_state['defaults'].general.sd_concepts_library_folder, token_name)
if os.path.exists(embedding_path):
for files in os.listdir(embedding_path):
if files.endswith(ext):
load_learned_embed_in_clip(f"{os.path.join(embedding_path, files)}", text_encoder, tokenizer, f"<{token_name}>")
else:
embedding_path = os.path.join(st.session_state['defaults'].general.sd_concepts_library_folder, prompt_tokens[0])
if os.path.exists(embedding_path):
for files in os.listdir(embedding_path):
if files.endswith(ext):
load_learned_embed_in_clip(f"{os.path.join(embedding_path, files)}", text_encoder, tokenizer, f"<{prompt_tokens[0]}>")

os.makedirs(outpath, exist_ok=True)

sample_path = os.path.join(outpath, "samples")
Expand Down
Loading