Merge pull request #13568 from AUTOMATIC1111/lora_emb_bundle
Add lora-embedding bundle system
AUTOMATIC1111 authored Oct 14, 2023
2 parents 19f5795 + a8cbe50 commit 4be7b62
Showing 4 changed files with 115 additions and 34 deletions.
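
The change teaches load_network() to treat any tensor whose key starts with "bundle_emb." as part of a textual-inversion embedding shipped inside the LoRA file itself: the second dot-separated component is the embedding name, and the remainder is either "string_to_param.<token>" or a plain tensor name such as clip_g/clip_l for SDXL embeddings. As a rough illustration of how such a bundle could be authored (a hypothetical sketch, not part of this commit; the file names, embedding name, and packing workflow are assumptions):

# Hypothetical packing script: merge an A1111-style .pt embedding into a LoRA
# safetensors file under the "bundle_emb." prefix that load_network() scans for.
import torch
from safetensors.torch import load_file, save_file

lora_sd = load_file("my_lora.safetensors")                  # existing LoRA weights
emb_sd = torch.load("my_embedding.pt", map_location="cpu")  # {'string_to_param': {'*': tensor}, ...}

emb_name = "my_embedding"
for token, tensor in emb_sd["string_to_param"].items():
    # key becomes e.g. "bundle_emb.my_embedding.string_to_param.*"
    lora_sd[f"bundle_emb.{emb_name}.string_to_param.{token}"] = tensor.detach().to(torch.float32).contiguous()

save_file(lora_sd, "my_lora_with_embedding.safetensors")

At load time, networks.py groups these keys per embedding name, hands each group to textual_inversion.create_embedding_from_data(), and registers the resulting embedding alongside the LoRA.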
33 changes: 33 additions & 0 deletions extensions-builtin/Lora/lora_logger.py
@@ -0,0 +1,33 @@
+import sys
+import copy
+import logging
+
+
+class ColoredFormatter(logging.Formatter):
+    COLORS = {
+        "DEBUG": "\033[0;36m", # CYAN
+        "INFO": "\033[0;32m", # GREEN
+        "WARNING": "\033[0;33m", # YELLOW
+        "ERROR": "\033[0;31m", # RED
+        "CRITICAL": "\033[0;37;41m", # WHITE ON RED
+        "RESET": "\033[0m", # RESET COLOR
+    }
+
+    def format(self, record):
+        colored_record = copy.copy(record)
+        levelname = colored_record.levelname
+        seq = self.COLORS.get(levelname, self.COLORS["RESET"])
+        colored_record.levelname = f"{seq}{levelname}{self.COLORS['RESET']}"
+        return super().format(colored_record)
+
+
+logger = logging.getLogger("lora")
+logger.propagate = False
+
+
+if not logger.handlers:
+    handler = logging.StreamHandler(sys.stdout)
+    handler.setFormatter(
+        ColoredFormatter("[%(name)s]-%(levelname)s: %(message)s")
+    )
+    logger.addHandler(handler)
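
The module configures a dedicated "lora" logger with colored level names and propagation disabled so messages are not duplicated by the root logger. A minimal usage sketch (illustrative, not part of the diff; the message mirrors the warning emitted from networks.py below):

# Illustrative use of the shared logger from code running inside the Lora
# extension, where lora_logger.py is importable.
from lora_logger import logger

# prints something like "[lora]-WARNING: Skip bundle embedding: ..." with the
# level name wrapped in yellow ANSI codes.
logger.warning('Skip bundle embedding: "my_emb" as it was already loaded from embeddings folder')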
1 change: 1 addition & 0 deletions extensions-builtin/Lora/network.py
@@ -93,6 +93,7 @@ def __init__(self, name, network_on_disk: NetworkOnDisk):
         self.unet_multiplier = 1.0
         self.dyn_dim = None
         self.modules = {}
+        self.bundle_embeddings = {}
         self.mtime = None
 
         self.mentioned_name = None
41 changes: 41 additions & 0 deletions extensions-builtin/Lora/networks.py
@@ -16,6 +16,9 @@
 from typing import Union
 
 from modules import shared, devices, sd_models, errors, scripts, sd_hijack
+import modules.textual_inversion.textual_inversion as textual_inversion
+
+from lora_logger import logger
 
 module_types = [
     network_lora.ModuleTypeLora(),
@@ -151,9 +154,19 @@ def load_network(name, network_on_disk):
     is_sd2 = 'model_transformer_resblocks' in shared.sd_model.network_layer_mapping
 
     matched_networks = {}
+    bundle_embeddings = {}
 
     for key_network, weight in sd.items():
         key_network_without_network_parts, network_part = key_network.split(".", 1)
+        if key_network_without_network_parts == "bundle_emb":
+            emb_name, vec_name = network_part.split(".", 1)
+            emb_dict = bundle_embeddings.get(emb_name, {})
+            if vec_name.split('.')[0] == 'string_to_param':
+                _, k2 = vec_name.split('.', 1)
+                emb_dict['string_to_param'] = {k2: weight}
+            else:
+                emb_dict[vec_name] = weight
+            bundle_embeddings[emb_name] = emb_dict
 
         key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2)
         sd_module = shared.sd_model.network_layer_mapping.get(key, None)
@@ -197,6 +210,14 @@ def load_network(name, network_on_disk):
 
         net.modules[key] = net_module
 
+    embeddings = {}
+    for emb_name, data in bundle_embeddings.items():
+        embedding = textual_inversion.create_embedding_from_data(data, emb_name, filename=network_on_disk.filename + "/" + emb_name)
+        embedding.loaded = None
+        embeddings[emb_name] = embedding
+
+    net.bundle_embeddings = embeddings
+
     if keys_failed_to_match:
         logging.debug(f"Network {network_on_disk.filename} didn't match keys: {keys_failed_to_match}")
 
@@ -212,11 +233,15 @@ def purge_networks_from_memory():
 
 
 def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
+    emb_db = sd_hijack.model_hijack.embedding_db
     already_loaded = {}
 
     for net in loaded_networks:
         if net.name in names:
             already_loaded[net.name] = net
+            for emb_name, embedding in net.bundle_embeddings.items():
+                if embedding.loaded:
+                    emb_db.register_embedding_by_name(None, shared.sd_model, emb_name)
 
     loaded_networks.clear()
 
@@ -259,6 +284,21 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
         net.dyn_dim = dyn_dims[i] if dyn_dims else 1.0
         loaded_networks.append(net)
 
+        for emb_name, embedding in net.bundle_embeddings.items():
+            if embedding.loaded is None and emb_name in emb_db.word_embeddings:
+                logger.warning(
+                    f'Skip bundle embedding: "{emb_name}"'
+                    ' as it was already loaded from embeddings folder'
+                )
+                continue
+
+            embedding.loaded = False
+            if emb_db.expected_shape == -1 or emb_db.expected_shape == embedding.shape:
+                embedding.loaded = True
+                emb_db.register_embedding(embedding, shared.sd_model)
+            else:
+                emb_db.skipped_embeddings[name] = embedding
+
     if failed_to_load_networks:
         sd_hijack.model_hijack.comments.append("Networks not found: " + ", ".join(failed_to_load_networks))
 
@@ -567,6 +607,7 @@ def infotext_pasted(infotext, params):
 available_networks = {}
 available_network_aliases = {}
 loaded_networks = []
+loaded_bundle_embeddings = {}
 networks_in_memory = {}
 available_network_hash_lookup = {}
 forbidden_network_aliases = {}
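
To make the bundle_emb key handling in load_network() concrete, the snippet below re-runs the grouping loop standalone on made-up keys (embedding names and tensor shapes are illustrative, not from the commit):

# Standalone re-run of the grouping logic from load_network(); the input keys
# and shapes are invented for illustration.
import torch

sd = {
    "bundle_emb.my_emb.string_to_param.*": torch.zeros(2, 768),  # SD1.x-style vectors
    "bundle_emb.sdxl_emb.clip_g": torch.zeros(2, 1280),          # SDXL embedding, part 1
    "bundle_emb.sdxl_emb.clip_l": torch.zeros(2, 768),           # SDXL embedding, part 2
}

bundle_embeddings = {}
for key_network, weight in sd.items():
    prefix, network_part = key_network.split(".", 1)
    if prefix == "bundle_emb":
        emb_name, vec_name = network_part.split(".", 1)
        emb_dict = bundle_embeddings.get(emb_name, {})
        if vec_name.split('.')[0] == 'string_to_param':
            _, k2 = vec_name.split('.', 1)
            emb_dict['string_to_param'] = {k2: weight}
        else:
            emb_dict[vec_name] = weight
        bundle_embeddings[emb_name] = emb_dict

# bundle_embeddings == {
#     "my_emb":   {"string_to_param": {"*": <2x768 tensor>}},
#     "sdxl_emb": {"clip_g": <2x1280 tensor>, "clip_l": <2x768 tensor>},
# }

Each per-name dict is then passed to textual_inversion.create_embedding_from_data(), and load_networks() registers the result in the embedding database unless an embedding with the same name was already loaded from the embeddings folder or its shape does not match the current model.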
74 changes: 40 additions & 34 deletions modules/textual_inversion/textual_inversion.py
@@ -181,40 +181,7 @@ def load_from_file(self, path, filename):
         else:
             return
 
-
-        # textual inversion embeddings
-        if 'string_to_param' in data:
-            param_dict = data['string_to_param']
-            param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11
-            assert len(param_dict) == 1, 'embedding file has multiple terms in it'
-            emb = next(iter(param_dict.items()))[1]
-            vec = emb.detach().to(devices.device, dtype=torch.float32)
-            shape = vec.shape[-1]
-            vectors = vec.shape[0]
-        elif type(data) == dict and 'clip_g' in data and 'clip_l' in data: # SDXL embedding
-            vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()}
-            shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]
-            vectors = data['clip_g'].shape[0]
-        elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: # diffuser concepts
-            assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
-
-            emb = next(iter(data.values()))
-            if len(emb.shape) == 1:
-                emb = emb.unsqueeze(0)
-            vec = emb.detach().to(devices.device, dtype=torch.float32)
-            shape = vec.shape[-1]
-            vectors = vec.shape[0]
-        else:
-            raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
-
-        embedding = Embedding(vec, name)
-        embedding.step = data.get('step', None)
-        embedding.sd_checkpoint = data.get('sd_checkpoint', None)
-        embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
-        embedding.vectors = vectors
-        embedding.shape = shape
-        embedding.filename = path
-        embedding.set_hash(hashes.sha256(embedding.filename, "textual_inversion/" + name) or '')
+        embedding = create_embedding_from_data(data, name, filename=filename, filepath=path)
 
         if self.expected_shape == -1 or self.expected_shape == embedding.shape:
             self.register_embedding(embedding, shared.sd_model)
@@ -313,6 +280,45 @@ def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
     return fn
 
 
+def create_embedding_from_data(data, name, filename='unknown embedding file', filepath=None):
+    if 'string_to_param' in data: # textual inversion embeddings
+        param_dict = data['string_to_param']
+        param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11
+        assert len(param_dict) == 1, 'embedding file has multiple terms in it'
+        emb = next(iter(param_dict.items()))[1]
+        vec = emb.detach().to(devices.device, dtype=torch.float32)
+        shape = vec.shape[-1]
+        vectors = vec.shape[0]
+    elif type(data) == dict and 'clip_g' in data and 'clip_l' in data: # SDXL embedding
+        vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()}
+        shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]
+        vectors = data['clip_g'].shape[0]
+    elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: # diffuser concepts
+        assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
+
+        emb = next(iter(data.values()))
+        if len(emb.shape) == 1:
+            emb = emb.unsqueeze(0)
+        vec = emb.detach().to(devices.device, dtype=torch.float32)
+        shape = vec.shape[-1]
+        vectors = vec.shape[0]
+    else:
+        raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
+
+    embedding = Embedding(vec, name)
+    embedding.step = data.get('step', None)
+    embedding.sd_checkpoint = data.get('sd_checkpoint', None)
+    embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
+    embedding.vectors = vectors
+    embedding.shape = shape
+
+    if filepath:
+        embedding.filename = filepath
+        embedding.set_hash(hashes.sha256(filepath, "textual_inversion/" + name) or '')
+
+    return embedding
+
+
 def write_loss(log_directory, filename, step, epoch_len, values):
     if shared.opts.training_write_csv_every == 0:
         return
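
Both load_from_file() and the new Lora bundle path now route through create_embedding_from_data(). A minimal sketch of calling it directly (assumes the webui modules are importable; the data payload, name, and shapes are illustrative):

# Illustrative call: the payload mimics an A1111 .pt embedding with 4 vectors
# of width 768.
import torch
from modules.textual_inversion.textual_inversion import create_embedding_from_data

data = {"string_to_param": {"*": torch.zeros(4, 768)}}
embedding = create_embedding_from_data(data, "my_emb", filename="my_lora.safetensors/my_emb")

print(embedding.vectors, embedding.shape)  # -> 4 768
# Without a filepath argument, no filename or hash is set on the Embedding,
# matching how load_network() creates bundled embeddings in networks.py.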