
Commit

Merge branch 'feature/optimize-performance-lora-filtering-in-metadata'
# Conflicts:
#	modules/flags.py
#	modules/util.py
#	webui.py
mashb1t committed May 30, 2024
2 parents f5863d8 + 83ef32a commit 5768330
Showing 8 changed files with 142 additions and 48 deletions.
4 changes: 3 additions & 1 deletion modules/async_worker.py
@@ -474,8 +474,10 @@ def handler(async_task):

progressbar(async_task, 2, 'Loading models ...')

loras, prompt = parse_lora_references_from_prompt(prompt, loras, modules.config.default_max_lora_number)
lora_filenames = modules.util.remove_performance_lora(modules.config.lora_filenames, performance_selection)
loras, prompt = parse_lora_references_from_prompt(prompt, loras, modules.config.default_max_lora_number, lora_filenames=lora_filenames)
loras += performance_loras

pipeline.refresh_everything(refiner_model_name=refiner_model_name, base_model_name=base_model_name,
loras=loras, base_model_additional_loras=base_model_additional_loras,
use_synthetic_refiner=use_synthetic_refiner, vae_name=vae_name)
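
Taken together, the handler now filters the selected performance LoRA out of the candidate filenames before resolving '<lora:name:weight>' prompt references, then appends the performance LoRAs separately. A minimal sketch of the resulting flow (the wrapper function below is hypothetical; in the repository this logic is inlined in handler()):

```python
import modules.config
from modules.util import parse_lora_references_from_prompt, remove_performance_lora

def resolve_loras(prompt, loras, performance_selection, performance_loras):
    # Drop the selected performance LoRA from the candidates so a
    # '<lora:name:weight>' prompt reference cannot add it a second time.
    lora_filenames = remove_performance_lora(modules.config.lora_filenames,
                                             performance_selection)
    loras, prompt = parse_lora_references_from_prompt(
        prompt, loras, modules.config.default_max_lora_number,
        lora_filenames=lora_filenames)
    # The performance LoRAs still enter exactly once, appended at the end.
    return loras + performance_loras, prompt
```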
32 changes: 8 additions & 24 deletions modules/config.py
@@ -578,25 +578,9 @@ def add_ratio(x):

model_filenames = []
lora_filenames = []
lora_filenames_no_special = []
vae_filenames = []
wildcard_filenames = []

sdxl_lcm_lora = 'sdxl_lcm_lora.safetensors'
sdxl_lightning_lora = 'sdxl_lightning_4step_lora.safetensors'
sdxl_hyper_sd_lora = 'sdxl_hyper_sd_4step_lora.safetensors'
loras_metadata_remove = [sdxl_lcm_lora, sdxl_lightning_lora, sdxl_hyper_sd_lora]


def remove_special_loras(lora_filenames):
global loras_metadata_remove

loras_no_special = lora_filenames.copy()
for lora_to_remove in loras_metadata_remove:
if lora_to_remove in loras_no_special:
loras_no_special.remove(lora_to_remove)
return loras_no_special


def get_model_filenames(folder_paths, extensions=None, name_filter=None):
if extensions is None:
@@ -612,10 +596,9 @@ def get_model_filenames(folder_paths, extensions=None, name_filter=None):


def update_files():
global model_filenames, lora_filenames, lora_filenames_no_special, vae_filenames, wildcard_filenames, available_presets
global model_filenames, lora_filenames, vae_filenames, wildcard_filenames, available_presets
model_filenames = get_model_filenames(paths_checkpoints)
lora_filenames = get_model_filenames(paths_loras)
lora_filenames_no_special = remove_special_loras(lora_filenames)
vae_filenames = get_model_filenames(path_vae)
wildcard_filenames = get_files_from_folder(path_wildcards, ['.txt'])
available_presets = get_presets()
@@ -664,26 +647,27 @@ def downloading_sdxl_lcm_lora():
load_file_from_url(
url='https://huggingface.co/lllyasviel/misc/resolve/main/sdxl_lcm_lora.safetensors',
model_dir=paths_loras[0],
file_name=sdxl_lcm_lora
file_name=modules.flags.PerformanceLoRA.EXTREME_SPEED.value
)
return sdxl_lcm_lora
return modules.flags.PerformanceLoRA.EXTREME_SPEED.value


def downloading_sdxl_lightning_lora():
load_file_from_url(
url='https://huggingface.co/mashb1t/misc/resolve/main/sdxl_lightning_4step_lora.safetensors',
model_dir=paths_loras[0],
file_name=sdxl_lightning_lora
file_name=modules.flags.PerformanceLoRA.LIGHTNING.value
)
return sdxl_lightning_lora
return modules.flags.PerformanceLoRA.LIGHTNING.value


def downloading_sdxl_hyper_sd_lora():
load_file_from_url(
url='https://huggingface.co/mashb1t/misc/resolve/main/sdxl_hyper_sd_4step_lora.safetensors',
model_dir=paths_loras[0],
file_name=sdxl_hyper_sd_lora
file_name=modules.flags.PerformanceLoRA.HYPER_SD.value
)
return sdxl_hyper_sd_lora
return modules.flags.PerformanceLoRA.HYPER_SD.value


def downloading_controlnet_canny():
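
With the module-level filename constants removed, each download helper now reads its target filename from PerformanceLoRA in modules/flags.py, making the enum the single source of truth for these names. A hypothetical caller (the dispatcher below is illustrative, not part of the repository):

```python
import modules.config
from modules.flags import Performance

def ensure_performance_lora(performance_selection: Performance) -> str | None:
    # Hypothetical dispatcher: download the LoRA for the selected mode
    # (if missing) and return its filename, or None for Quality/Speed.
    if performance_selection == Performance.EXTREME_SPEED:
        return modules.config.downloading_sdxl_lcm_lora()
    if performance_selection == Performance.LIGHTNING:
        return modules.config.downloading_sdxl_lightning_lora()
    if performance_selection == Performance.HYPER_SD:
        return modules.config.downloading_sdxl_hyper_sd_lora()
    return None
```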
23 changes: 20 additions & 3 deletions modules/flags.py
@@ -48,7 +48,8 @@

KSAMPLER_NAMES = list(KSAMPLER.keys())

SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "lcm", "turbo", "align_your_steps", "tcd"]
SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "lcm", "turbo",
"align_your_steps", "tcd"]
SAMPLER_NAMES = KSAMPLER_NAMES + list(SAMPLER_EXTRA.keys())

sampler_list = SAMPLER_NAMES
@@ -95,6 +96,7 @@
'1664*576', '1728*576'
]


class MetadataScheme(Enum):
FOOOCUS = 'fooocus'
A1111 = 'a1111'
@@ -119,6 +121,14 @@ def list(cls) -> list:
return list(map(lambda c: c.value, cls))


class PerformanceLoRA(Enum):
QUALITY = None
SPEED = None
EXTREME_SPEED = 'sdxl_lcm_lora.safetensors'
LIGHTNING = 'sdxl_lightning_4step_lora.safetensors'
HYPER_SD = 'sdxl_hyper_sd_4step_lora.safetensors'


class Steps(IntEnum):
QUALITY = 60
SPEED = 30
@@ -150,17 +160,24 @@ def list(cls) -> list:
def values(cls) -> list:
return list(map(lambda c: c.value, cls))

@classmethod
def by_steps(cls, steps: int | str):
return cls[Steps(int(steps)).name]

@classmethod
def has_restricted_features(cls, x) -> bool:
if isinstance(x, Performance):
x = x.value
return x in [cls.EXTREME_SPEED.value, cls.LIGHTNING.value, cls.HYPER_SD.value]

def steps(self) -> int | None:
return Steps[self.name].value if Steps[self.name] else None
return Steps[self.name].value if self.name in Steps.__members__ else None

def steps_uov(self) -> int | None:
return StepsUOV[self.name].value if Steps[self.name] else None
return StepsUOV[self.name].value if self.name in StepsUOV.__members__ else None

def lora_filename(self) -> str | None:
return PerformanceLoRA[self.name].value if self.name in PerformanceLoRA.__members__ else None


performance_selections = []
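
Steps, StepsUOV, and PerformanceLoRA share member names with Performance, so a performance mode maps to its step count and LoRA file purely by name, with no lookup tables. Note that QUALITY and SPEED both carry the value None in PerformanceLoRA, which makes SPEED an alias of QUALITY; this is harmless here because lora_filename() resolves members by name through __members__. A quick illustration (values are shown only where this diff confirms them; the rest is assumed):

```python
from modules.flags import Performance

perf = Performance.by_steps(30)        # Steps.SPEED == 30, so this is Performance.SPEED
print(perf.steps())                    # 30
print(perf.lora_filename())            # None: PerformanceLoRA.SPEED names no file
print(Performance.EXTREME_SPEED.lora_filename())   # 'sdxl_lcm_lora.safetensors'
print(Performance.has_restricted_features(perf))   # False: Speed is unrestricted
```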
43 changes: 28 additions & 15 deletions modules/meta_parser.py
@@ -32,7 +32,7 @@ def load_parameter_button_click(raw_metadata: dict | str, is_generating: bool):
get_str('prompt', 'Prompt', loaded_parameter_dict, results)
get_str('negative_prompt', 'Negative Prompt', loaded_parameter_dict, results)
get_list('styles', 'Styles', loaded_parameter_dict, results)
get_str('performance', 'Performance', loaded_parameter_dict, results)
performance = get_str('performance', 'Performance', loaded_parameter_dict, results)
get_steps('steps', 'Steps', loaded_parameter_dict, results)
get_number('overwrite_switch', 'Overwrite Switch', loaded_parameter_dict, results)
get_resolution('resolution', 'Resolution', loaded_parameter_dict, results)
@@ -59,19 +59,27 @@ def load_parameter_button_click(raw_metadata: dict | str, is_generating: bool):

get_freeu('freeu', 'FreeU', loaded_parameter_dict, results)

# prevent performance LoRAs from being added twice, once via performance and once via lora
performance_filename = None
if performance is not None and performance in Performance.list():
performance = Performance(performance)
performance_filename = performance.lora_filename()

for i in range(modules.config.default_max_lora_number):
get_lora(f'lora_combined_{i + 1}', f'LoRA {i + 1}', loaded_parameter_dict, results)
get_lora(f'lora_combined_{i + 1}', f'LoRA {i + 1}', loaded_parameter_dict, results, performance_filename)

return results


def get_str(key: str, fallback: str | None, source_dict: dict, results: list, default=None):
def get_str(key: str, fallback: str | None, source_dict: dict, results: list, default=None) -> str | None:
try:
h = source_dict.get(key, source_dict.get(fallback, default))
assert isinstance(h, str)
results.append(h)
return h
except:
results.append(gr.update())
return None


def get_list(key: str, fallback: str | None, source_dict: dict, results: list, default=None):
@@ -181,7 +189,7 @@ def get_freeu(key: str, fallback: str | None, source_dict: dict, results: list,
results.append(gr.update())


def get_lora(key: str, fallback: str | None, source_dict: dict, results: list):
def get_lora(key: str, fallback: str | None, source_dict: dict, results: list, performance_filename: str | None):
try:
split_data = source_dict.get(key, source_dict.get(fallback)).split(' : ')
enabled = True
Expand All @@ -193,6 +201,9 @@ def get_lora(key: str, fallback: str | None, source_dict: dict, results: list):
name = split_data[1]
weight = split_data[2]

if name == performance_filename:
raise Exception

weight = float(weight)
results.append(enabled)
results.append(name)
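
Inside get_lora, a metadata entry naming the performance LoRA now raises, and the surrounding try/except falls through to the gr.update() placeholders, so the slot is left untouched instead of double-loading the file. A condensed stand-in for the parsing and skip logic (hypothetical; the real function also fills the Gradio results list):

```python
def parse_lora_entry(raw: str, performance_filename: str | None):
    # Accepts 'name : weight' or 'enabled : name : weight'.
    parts = raw.split(' : ')
    if len(parts) == 2:
        enabled, name, weight = True, parts[0], parts[1]
    else:
        enabled, name, weight = parts[0] == 'True', parts[1], parts[2]
    if name == performance_filename:
        return None  # already covered by the performance selection
    return enabled, name, float(weight)

# Skipped, because the LCM LoRA already comes in via Performance.EXTREME_SPEED:
print(parse_lora_entry('sdxl_lcm_lora.safetensors : 1.0', 'sdxl_lcm_lora.safetensors'))  # None
```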
@@ -247,7 +258,7 @@ def __init__(self):
self.full_prompt: str = ''
self.raw_negative_prompt: str = ''
self.full_negative_prompt: str = ''
self.steps: int = 30
self.steps: int = Steps.SPEED.value
self.base_model_name: str = ''
self.base_model_hash: str = ''
self.refiner_model_name: str = ''
@@ -260,11 +271,11 @@ def get_scheme(self) -> MetadataScheme:
raise NotImplementedError

@abstractmethod
def parse_json(self, metadata: dict | str) -> dict:
def to_json(self, metadata: dict | str) -> dict:
raise NotImplementedError

@abstractmethod
def parse_string(self, metadata: dict) -> str:
def to_string(self, metadata: dict) -> str:
raise NotImplementedError

def set_data(self, raw_prompt, full_prompt, raw_negative_prompt, full_negative_prompt, steps, base_model_name,
Expand Down Expand Up @@ -327,7 +338,7 @@ def get_scheme(self) -> MetadataScheme:
'version': 'Version'
}

def parse_json(self, metadata: str) -> dict:
def to_json(self, metadata: str) -> dict:
metadata_prompt = ''
metadata_negative_prompt = ''

@@ -381,9 +392,9 @@ def parse_json(self, metadata: str) -> dict:
data['styles'] = str(found_styles)

# try to load performance based on steps, fallback for direct A1111 imports
if 'steps' in data and 'performance' not in data:
try:
data['performance'] = Performance[Steps(int(data['steps'])).name].value
data['performance'] = Performance.by_steps(data['steps']).value
except (ValueError, KeyError):
pass

@@ -413,15 +424,15 @@ def parse_json(self, metadata: str) -> dict:
lora_split = lora.split(': ')
lora_name = lora_split[0]
lora_weight = lora_split[2] if len(lora_split) == 3 else lora_split[1]
for filename in modules.config.lora_filenames_no_special:
for filename in modules.config.lora_filenames:
path = Path(filename)
if lora_name == path.stem:
data[f'lora_combined_{li + 1}'] = f'{filename} : {lora_weight}'
break

return data

def parse_string(self, metadata: dict) -> str:
def to_string(self, metadata: dict) -> str:
data = {k: v for _, k, v in metadata}

width, height = eval(data['resolution'])
@@ -501,22 +512,22 @@ class FooocusMetadataParser(MetadataParser):
def get_scheme(self) -> MetadataScheme:
return MetadataScheme.FOOOCUS

def parse_json(self, metadata: dict) -> dict:
def to_json(self, metadata: dict) -> dict:
for key, value in metadata.items():
if value in ['', 'None']:
continue
if key in ['base_model', 'refiner_model']:
metadata[key] = self.replace_value_with_filename(key, value, modules.config.model_filenames)
elif key.startswith('lora_combined_'):
metadata[key] = self.replace_value_with_filename(key, value, modules.config.lora_filenames_no_special)
metadata[key] = self.replace_value_with_filename(key, value, modules.config.lora_filenames)
elif key == 'vae':
metadata[key] = self.replace_value_with_filename(key, value, modules.config.vae_filenames)
else:
continue

return metadata

def parse_string(self, metadata: list) -> str:
def to_string(self, metadata: list) -> str:
for li, (label, key, value) in enumerate(metadata):
# remove model folder paths from metadata
if key.startswith('lora_combined_'):
@@ -556,6 +567,8 @@ def replace_value_with_filename(key, value, filenames):
elif value == path.stem:
return filename

return None


def get_metadata_parser(metadata_scheme: MetadataScheme) -> MetadataParser:
match metadata_scheme:
2 changes: 1 addition & 1 deletion modules/private_logger.py
@@ -27,7 +27,7 @@ def log(img, metadata, metadata_parser: MetadataParser | None = None, output_for
date_string, local_temp_filename, only_name = generate_temp_filename(folder=path_outputs, extension=output_format)
os.makedirs(os.path.dirname(local_temp_filename), exist_ok=True)

parsed_parameters = metadata_parser.parse_string(metadata.copy()) if metadata_parser is not None else ''
parsed_parameters = metadata_parser.to_string(metadata.copy()) if metadata_parser is not None else ''
image = Image.fromarray(img)

if output_format == OutputFormat.PNG.value:
26 changes: 24 additions & 2 deletions modules/util.py
@@ -16,6 +16,7 @@

import modules.config
import modules.sdxl_styles
from modules.flags import Performance

LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)

@@ -387,10 +388,15 @@ def get_enabled_loras(loras: list, remove_none=True) -> list:


def parse_lora_references_from_prompt(prompt: str, loras: List[Tuple[AnyStr, float]], loras_limit: int = 5,
skip_file_check=False, prompt_cleanup=True, deduplicate_loras=True) -> tuple[List[Tuple[AnyStr, float]], str]:
skip_file_check=False, prompt_cleanup=True, deduplicate_loras=True,
lora_filenames=None) -> tuple[List[Tuple[AnyStr, float]], str]:
if lora_filenames is None:
lora_filenames = []

found_loras = []
prompt_without_loras = ''
cleaned_prompt = ''

for token in prompt.split(','):
matches = LORAS_PROMPT_PATTERN.findall(token)

@@ -400,7 +406,7 @@ def parse_lora_references_from_prompt(prompt: str, loras: List[Tuple[AnyStr, flo
for match in matches:
lora_name = match[1] + '.safetensors'
if not skip_file_check:
lora_name = get_filname_by_stem(match[1], modules.config.lora_filenames_no_special)
lora_name = get_filname_by_stem(match[1], lora_filenames)
if lora_name is not None:
found_loras.append((lora_name, float(match[2])))
token = token.replace(match[0], '')
@@ -430,6 +436,22 @@ def parse_lora_references_from_prompt(prompt: str, loras: List[Tuple[AnyStr, flo
return updated_loras[:loras_limit], cleaned_prompt


def remove_performance_lora(filenames: list, performance: Performance | None):
loras_without_performance = filenames.copy()

if performance is None:
return loras_without_performance

performance_lora = performance.lora_filename()

for filename in filenames:
path = Path(filename)
if performance_lora == path.name:
loras_without_performance.remove(filename)

return loras_without_performance
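
The helper copies the input list and strips only the file that matches the selected mode, so every other LoRA stays available to the UI. Comparing on path.name also lets the match work for LoRAs stored in subfolders. A usage sketch (the filenames here are illustrative):

```python
from modules.util import remove_performance_lora
from modules.flags import Performance

filenames = ['my_style.safetensors', 'sdxl_lcm_lora.safetensors']
print(remove_performance_lora(filenames, Performance.EXTREME_SPEED))
# ['my_style.safetensors']
print(remove_performance_lora(filenames, None))  # unchanged copy
```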


def cleanup_prompt(prompt):
prompt = re.sub(' +', ' ', prompt)
prompt = re.sub(',+', ',', prompt)