Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: only use valid inline loras, add subfolder support #2968

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion modules/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,7 @@ def add_ratio(x):

model_filenames = []
lora_filenames = []
lora_filenames_no_special = []
vae_filenames = []
wildcard_filenames = []

Expand All @@ -556,6 +557,16 @@ def add_ratio(x):
loras_metadata_remove = [sdxl_lcm_lora, sdxl_lightning_lora, sdxl_hyper_sd_lora]


def remove_special_loras(lora_filenames):
global loras_metadata_remove

loras_no_special = lora_filenames.copy()
for lora_to_remove in loras_metadata_remove:
if lora_to_remove in loras_no_special:
loras_no_special.remove(lora_to_remove)
return loras_no_special


def get_model_filenames(folder_paths, extensions=None, name_filter=None):
if extensions is None:
extensions = ['.pth', '.ckpt', '.bin', '.safetensors', '.fooocus.patch']
Expand All @@ -570,9 +581,10 @@ def get_model_filenames(folder_paths, extensions=None, name_filter=None):


def update_files():
global model_filenames, lora_filenames, vae_filenames, wildcard_filenames, available_presets
global model_filenames, lora_filenames, lora_filenames_no_special, vae_filenames, wildcard_filenames, available_presets
model_filenames = get_model_filenames(paths_checkpoints)
lora_filenames = get_model_filenames(paths_loras)
lora_filenames_no_special = remove_special_loras(lora_filenames)
vae_filenames = get_model_filenames(path_vae)
wildcard_filenames = get_files_from_folder(path_wildcards, ['.txt'])
available_presets = get_presets()
Expand Down
21 changes: 4 additions & 17 deletions modules/meta_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,6 @@ def get_lora(key: str, fallback: str | None, source_dict: dict, results: list):
def get_sha256(filepath):
global hash_cache
if filepath not in hash_cache:
# is_safetensors = os.path.splitext(filepath)[1].lower() == '.safetensors'
hash_cache[filepath] = sha256(filepath)

return hash_cache[filepath]
Expand Down Expand Up @@ -293,12 +292,6 @@ def set_data(self, raw_prompt, full_prompt, raw_negative_prompt, full_negative_p
self.loras.append((Path(lora_name).stem, lora_weight, lora_hash))
self.vae_name = Path(vae_name).stem

@staticmethod
def remove_special_loras(lora_filenames):
for lora_to_remove in modules.config.loras_metadata_remove:
if lora_to_remove in lora_filenames:
lora_filenames.remove(lora_to_remove)


class A1111MetadataParser(MetadataParser):
def get_scheme(self) -> MetadataScheme:
Expand Down Expand Up @@ -415,13 +408,11 @@ def parse_json(self, metadata: str) -> dict:
lora_data = data['lora_hashes']

if lora_data != '':
lora_filenames = modules.config.lora_filenames.copy()
self.remove_special_loras(lora_filenames)
for li, lora in enumerate(lora_data.split(', ')):
lora_split = lora.split(': ')
lora_name = lora_split[0]
lora_weight = lora_split[2] if len(lora_split) == 3 else lora_split[1]
for filename in lora_filenames:
for filename in modules.config.lora_filenames_no_special:
path = Path(filename)
if lora_name == path.stem:
data[f'lora_combined_{li + 1}'] = f'{filename} : {lora_weight}'
Expand Down Expand Up @@ -510,19 +501,15 @@ def get_scheme(self) -> MetadataScheme:
return MetadataScheme.FOOOCUS

def parse_json(self, metadata: dict) -> dict:
model_filenames = modules.config.model_filenames.copy()
lora_filenames = modules.config.lora_filenames.copy()
vae_filenames = modules.config.vae_filenames.copy()
self.remove_special_loras(lora_filenames)
for key, value in metadata.items():
if value in ['', 'None']:
continue
if key in ['base_model', 'refiner_model']:
metadata[key] = self.replace_value_with_filename(key, value, model_filenames)
metadata[key] = self.replace_value_with_filename(key, value, modules.config.model_filenames)
elif key.startswith('lora_combined_'):
metadata[key] = self.replace_value_with_filename(key, value, lora_filenames)
metadata[key] = self.replace_value_with_filename(key, value, modules.config.lora_filenames_no_special)
elif key == 'vae':
metadata[key] = self.replace_value_with_filename(key, value, vae_filenames)
metadata[key] = self.replace_value_with_filename(key, value, modules.config.vae_filenames)
else:
continue

Expand Down
41 changes: 29 additions & 12 deletions modules/util.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from pathlib import Path

import numpy as np
import datetime
import random
Expand Down Expand Up @@ -360,6 +362,14 @@ def is_json(data: str) -> bool:
return True


def get_filname_by_stem(lora_name, filenames: List[str]) -> str | None:
for filename in filenames:
path = Path(filename)
if lora_name == path.stem:
return filename
return None


def get_file_from_folder_list(name, folders):
if not isinstance(folders, list):
folders = [folders]
Expand All @@ -377,28 +387,35 @@ def get_enabled_loras(loras: list, remove_none=True) -> list:


def parse_lora_references_from_prompt(prompt: str, loras: List[Tuple[AnyStr, float]], loras_limit: int = 5,
prompt_cleanup=True, deduplicate_loras=True) -> tuple[List[Tuple[AnyStr, float]], str]:
skip_file_check=False, prompt_cleanup=True, deduplicate_loras=True) -> tuple[List[Tuple[AnyStr, float]], str]:
found_loras = []
prompt_without_loras = ""
for token in prompt.split(" "):
prompt_without_loras = ''
cleaned_prompt = ''
for token in prompt.split(','):
matches = LORAS_PROMPT_PATTERN.findall(token)

if matches:
for match in matches:
found_loras.append((f"{match[1]}.safetensors", float(match[2])))
prompt_without_loras += token.replace(match[0], '')
else:
prompt_without_loras += token
prompt_without_loras += ' '
if len(matches) == 0:
prompt_without_loras += token + ', '
continue
for match in matches:
lora_name = match[1] + '.safetensors'
if not skip_file_check:
lora_name = get_filname_by_stem(match[1], modules.config.lora_filenames_no_special)
if lora_name is not None:
found_loras.append((lora_name, float(match[2])))
token = token.replace(match[0], '')
prompt_without_loras += token + ', '

if prompt_without_loras != '':
cleaned_prompt = prompt_without_loras[:-2]

cleaned_prompt = prompt_without_loras[:-1]
if prompt_cleanup:
cleaned_prompt = cleanup_prompt(prompt_without_loras)

new_loras = []
lora_names = [lora[0] for lora in loras]
for found_lora in found_loras:
if deduplicate_loras and found_lora[0] in lora_names:
if deduplicate_loras and (found_lora[0] in lora_names or found_lora in new_loras):
continue
new_loras.append(found_lora)

Expand Down
32 changes: 25 additions & 7 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@ class TestUtils(unittest.TestCase):
def test_can_parse_tokens_with_lora(self):
test_cases = [
{
"input": ("some prompt, very cool, <lora:hey-lora:0.4>, cool <lora:you-lora:0.2>", [], 5),
"input": ("some prompt, very cool, <lora:hey-lora:0.4>, cool <lora:you-lora:0.2>", [], 5, True),
"output": (
[('hey-lora.safetensors', 0.4), ('you-lora.safetensors', 0.2)], 'some prompt, very cool, cool'),
},
# Test can not exceed limit
{
"input": ("some prompt, very cool, <lora:hey-lora:0.4>, cool <lora:you-lora:0.2>", [], 1),
"input": ("some prompt, very cool, <lora:hey-lora:0.4>, cool <lora:you-lora:0.2>", [], 1, True),
"output": (
[('hey-lora.safetensors', 0.4)],
'some prompt, very cool, cool'
Expand All @@ -25,6 +25,7 @@ def test_can_parse_tokens_with_lora(self):
"some prompt, very cool, <lora:l1:0.4>, <lora:l2:-0.2>, <lora:l3:0.3>, <lora:l4:0.5>, <lora:l6:0.24>, <lora:l7:0.1>",
[("hey-lora.safetensors", 0.4)],
5,
True
),
"output": (
[
Expand All @@ -37,18 +38,35 @@ def test_can_parse_tokens_with_lora(self):
'some prompt, very cool'
)
},
# test correct matching even if there is no space separating loras in the same token
{
"input": ("some prompt, very cool, <lora:hey-lora:0.4><lora:you-lora:0.2>", [], 3),
"input": ("some prompt, very cool, <lora:hey-lora:0.4><lora:you-lora:0.2>", [], 3, True),
"output": (
[
('hey-lora.safetensors', 0.4),
('you-lora.safetensors', 0.2)
],
'some prompt, very cool, <lora:you-lora:0.2><lora:hey-lora:0.4>'
'some prompt, very cool'
),
},
# test deduplication, also selected loras are never overridden with loras in prompt
{
"input": (
"some prompt, very cool, <lora:hey-lora:0.4><lora:hey-lora:0.4><lora:you-lora:0.2>",
[('you-lora.safetensors', 0.3)],
3,
True
),
"output": (
[
('you-lora.safetensors', 0.3),
('hey-lora.safetensors', 0.4)
],
'some prompt, very cool'
),
},
{
"input": ("<lora:foo:1..2>, <lora:bar:.>, <test:1.0>, <lora:baz:+> and <lora:quux:>", [], 6),
"input": ("<lora:foo:1..2>, <lora:bar:.>, <test:1.0>, <lora:baz:+> and <lora:quux:>", [], 6, True),
"output": (
[],
'<lora:foo:1..2>, <lora:bar:.>, <test:1.0>, <lora:baz:+> and <lora:quux:>'
Expand All @@ -57,7 +75,7 @@ def test_can_parse_tokens_with_lora(self):
]

for test in test_cases:
prompt, loras, loras_limit = test["input"]
prompt, loras, loras_limit, skip_file_check = test["input"]
expected = test["output"]
actual = util.parse_lora_references_from_prompt(prompt, loras, loras_limit)
actual = util.parse_lora_references_from_prompt(prompt, loras, loras_limit=loras_limit, skip_file_check=skip_file_check)
self.assertEqual(expected, actual)