Skip to content

Commit

Permalink
feat: only use valid inline loras, add subfolder support (#2968)
Browse files Browse the repository at this point in the history
  • Loading branch information
mashb1t authored May 20, 2024
1 parent ac14d9d commit 7537612
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 37 deletions.
14 changes: 13 additions & 1 deletion modules/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,7 @@ def add_ratio(x):

model_filenames = []
lora_filenames = []
lora_filenames_no_special = []
vae_filenames = []
wildcard_filenames = []

Expand All @@ -556,6 +557,16 @@ def add_ratio(x):
loras_metadata_remove = [sdxl_lcm_lora, sdxl_lightning_lora, sdxl_hyper_sd_lora]


def remove_special_loras(lora_filenames):
global loras_metadata_remove

loras_no_special = lora_filenames.copy()
for lora_to_remove in loras_metadata_remove:
if lora_to_remove in loras_no_special:
loras_no_special.remove(lora_to_remove)
return loras_no_special


def get_model_filenames(folder_paths, extensions=None, name_filter=None):
if extensions is None:
extensions = ['.pth', '.ckpt', '.bin', '.safetensors', '.fooocus.patch']
Expand All @@ -570,9 +581,10 @@ def get_model_filenames(folder_paths, extensions=None, name_filter=None):


def update_files():
global model_filenames, lora_filenames, vae_filenames, wildcard_filenames, available_presets
global model_filenames, lora_filenames, lora_filenames_no_special, vae_filenames, wildcard_filenames, available_presets
model_filenames = get_model_filenames(paths_checkpoints)
lora_filenames = get_model_filenames(paths_loras)
lora_filenames_no_special = remove_special_loras(lora_filenames)
vae_filenames = get_model_filenames(path_vae)
wildcard_filenames = get_files_from_folder(path_wildcards, ['.txt'])
available_presets = get_presets()
Expand Down
21 changes: 4 additions & 17 deletions modules/meta_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,6 @@ def get_lora(key: str, fallback: str | None, source_dict: dict, results: list):
def get_sha256(filepath):
global hash_cache
if filepath not in hash_cache:
# is_safetensors = os.path.splitext(filepath)[1].lower() == '.safetensors'
hash_cache[filepath] = sha256(filepath)

return hash_cache[filepath]
Expand Down Expand Up @@ -293,12 +292,6 @@ def set_data(self, raw_prompt, full_prompt, raw_negative_prompt, full_negative_p
self.loras.append((Path(lora_name).stem, lora_weight, lora_hash))
self.vae_name = Path(vae_name).stem

@staticmethod
def remove_special_loras(lora_filenames):
for lora_to_remove in modules.config.loras_metadata_remove:
if lora_to_remove in lora_filenames:
lora_filenames.remove(lora_to_remove)


class A1111MetadataParser(MetadataParser):
def get_scheme(self) -> MetadataScheme:
Expand Down Expand Up @@ -415,13 +408,11 @@ def parse_json(self, metadata: str) -> dict:
lora_data = data['lora_hashes']

if lora_data != '':
lora_filenames = modules.config.lora_filenames.copy()
self.remove_special_loras(lora_filenames)
for li, lora in enumerate(lora_data.split(', ')):
lora_split = lora.split(': ')
lora_name = lora_split[0]
lora_weight = lora_split[2] if len(lora_split) == 3 else lora_split[1]
for filename in lora_filenames:
for filename in modules.config.lora_filenames_no_special:
path = Path(filename)
if lora_name == path.stem:
data[f'lora_combined_{li + 1}'] = f'{filename} : {lora_weight}'
Expand Down Expand Up @@ -510,19 +501,15 @@ def get_scheme(self) -> MetadataScheme:
return MetadataScheme.FOOOCUS

def parse_json(self, metadata: dict) -> dict:
model_filenames = modules.config.model_filenames.copy()
lora_filenames = modules.config.lora_filenames.copy()
vae_filenames = modules.config.vae_filenames.copy()
self.remove_special_loras(lora_filenames)
for key, value in metadata.items():
if value in ['', 'None']:
continue
if key in ['base_model', 'refiner_model']:
metadata[key] = self.replace_value_with_filename(key, value, model_filenames)
metadata[key] = self.replace_value_with_filename(key, value, modules.config.model_filenames)
elif key.startswith('lora_combined_'):
metadata[key] = self.replace_value_with_filename(key, value, lora_filenames)
metadata[key] = self.replace_value_with_filename(key, value, modules.config.lora_filenames_no_special)
elif key == 'vae':
metadata[key] = self.replace_value_with_filename(key, value, vae_filenames)
metadata[key] = self.replace_value_with_filename(key, value, modules.config.vae_filenames)
else:
continue

Expand Down
41 changes: 29 additions & 12 deletions modules/util.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from pathlib import Path

import numpy as np
import datetime
import random
Expand Down Expand Up @@ -360,6 +362,14 @@ def is_json(data: str) -> bool:
return True


def get_filname_by_stem(lora_name, filenames: List[str]) -> str | None:
for filename in filenames:
path = Path(filename)
if lora_name == path.stem:
return filename
return None


def get_file_from_folder_list(name, folders):
if not isinstance(folders, list):
folders = [folders]
Expand All @@ -377,28 +387,35 @@ def get_enabled_loras(loras: list, remove_none=True) -> list:


def parse_lora_references_from_prompt(prompt: str, loras: List[Tuple[AnyStr, float]], loras_limit: int = 5,
prompt_cleanup=True, deduplicate_loras=True) -> tuple[List[Tuple[AnyStr, float]], str]:
skip_file_check=False, prompt_cleanup=True, deduplicate_loras=True) -> tuple[List[Tuple[AnyStr, float]], str]:
found_loras = []
prompt_without_loras = ""
for token in prompt.split(" "):
prompt_without_loras = ''
cleaned_prompt = ''
for token in prompt.split(','):
matches = LORAS_PROMPT_PATTERN.findall(token)

if matches:
for match in matches:
found_loras.append((f"{match[1]}.safetensors", float(match[2])))
prompt_without_loras += token.replace(match[0], '')
else:
prompt_without_loras += token
prompt_without_loras += ' '
if len(matches) == 0:
prompt_without_loras += token + ', '
continue
for match in matches:
lora_name = match[1] + '.safetensors'
if not skip_file_check:
lora_name = get_filname_by_stem(match[1], modules.config.lora_filenames_no_special)
if lora_name is not None:
found_loras.append((lora_name, float(match[2])))
token = token.replace(match[0], '')
prompt_without_loras += token + ', '

if prompt_without_loras != '':
cleaned_prompt = prompt_without_loras[:-2]

cleaned_prompt = prompt_without_loras[:-1]
if prompt_cleanup:
cleaned_prompt = cleanup_prompt(prompt_without_loras)

new_loras = []
lora_names = [lora[0] for lora in loras]
for found_lora in found_loras:
if deduplicate_loras and found_lora[0] in lora_names:
if deduplicate_loras and (found_lora[0] in lora_names or found_lora in new_loras):
continue
new_loras.append(found_lora)

Expand Down
32 changes: 25 additions & 7 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@ class TestUtils(unittest.TestCase):
def test_can_parse_tokens_with_lora(self):
test_cases = [
{
"input": ("some prompt, very cool, <lora:hey-lora:0.4>, cool <lora:you-lora:0.2>", [], 5),
"input": ("some prompt, very cool, <lora:hey-lora:0.4>, cool <lora:you-lora:0.2>", [], 5, True),
"output": (
[('hey-lora.safetensors', 0.4), ('you-lora.safetensors', 0.2)], 'some prompt, very cool, cool'),
},
# Test can not exceed limit
{
"input": ("some prompt, very cool, <lora:hey-lora:0.4>, cool <lora:you-lora:0.2>", [], 1),
"input": ("some prompt, very cool, <lora:hey-lora:0.4>, cool <lora:you-lora:0.2>", [], 1, True),
"output": (
[('hey-lora.safetensors', 0.4)],
'some prompt, very cool, cool'
Expand All @@ -25,6 +25,7 @@ def test_can_parse_tokens_with_lora(self):
"some prompt, very cool, <lora:l1:0.4>, <lora:l2:-0.2>, <lora:l3:0.3>, <lora:l4:0.5>, <lora:l6:0.24>, <lora:l7:0.1>",
[("hey-lora.safetensors", 0.4)],
5,
True
),
"output": (
[
Expand All @@ -37,18 +38,35 @@ def test_can_parse_tokens_with_lora(self):
'some prompt, very cool'
)
},
# test correct matching even if there is no space separating loras in the same token
{
"input": ("some prompt, very cool, <lora:hey-lora:0.4><lora:you-lora:0.2>", [], 3),
"input": ("some prompt, very cool, <lora:hey-lora:0.4><lora:you-lora:0.2>", [], 3, True),
"output": (
[
('hey-lora.safetensors', 0.4),
('you-lora.safetensors', 0.2)
],
'some prompt, very cool, <lora:you-lora:0.2><lora:hey-lora:0.4>'
'some prompt, very cool'
),
},
# test deduplication, also selected loras are never overridden with loras in prompt
{
"input": (
"some prompt, very cool, <lora:hey-lora:0.4><lora:hey-lora:0.4><lora:you-lora:0.2>",
[('you-lora.safetensors', 0.3)],
3,
True
),
"output": (
[
('you-lora.safetensors', 0.3),
('hey-lora.safetensors', 0.4)
],
'some prompt, very cool'
),
},
{
"input": ("<lora:foo:1..2>, <lora:bar:.>, <test:1.0>, <lora:baz:+> and <lora:quux:>", [], 6),
"input": ("<lora:foo:1..2>, <lora:bar:.>, <test:1.0>, <lora:baz:+> and <lora:quux:>", [], 6, True),
"output": (
[],
'<lora:foo:1..2>, <lora:bar:.>, <test:1.0>, <lora:baz:+> and <lora:quux:>'
Expand All @@ -57,7 +75,7 @@ def test_can_parse_tokens_with_lora(self):
]

for test in test_cases:
prompt, loras, loras_limit = test["input"]
prompt, loras, loras_limit, skip_file_check = test["input"]
expected = test["output"]
actual = util.parse_lora_references_from_prompt(prompt, loras, loras_limit)
actual = util.parse_lora_references_from_prompt(prompt, loras, loras_limit=loras_limit, skip_file_check=skip_file_check)
self.assertEqual(expected, actual)

0 comments on commit 7537612

Please sign in to comment.