Add VAE select #2867

Merged: 6 commits, May 9, 2024
Changes from all commits
34 changes: 33 additions & 1 deletion entrypoint.sh
@@ -1 +1,33 @@
(deleted: the same script collapsed onto a single line; this PR reformats it with the line breaks shown below)
#!/bin/bash

ORIGINALDIR=/content/app
# Use predefined DATADIR if it is defined
[[ x"${DATADIR}" == "x" ]] && DATADIR=/content/data

# Make persistent dir from original dir
function mklink () {
    mkdir -p $DATADIR/$1
    ln -s $DATADIR/$1 $ORIGINALDIR
}

# Copy old files from import dir
function import () {
    (test -d /import/$1 && cd /import/$1 && cp -Rpn . $DATADIR/$1/)
}

cd $ORIGINALDIR

# models
mklink models
# Copy original files
(cd $ORIGINALDIR/models.org && cp -Rpn . $ORIGINALDIR/models/)
# Import old files
import models

# outputs
mklink outputs
# Import old files
import outputs

# Start application
python launch.py $*
2 changes: 2 additions & 0 deletions language/en.json
@@ -339,6 +339,8 @@
    "sgm_uniform": "sgm_uniform",
    "simple": "simple",
    "ddim_uniform": "ddim_uniform",
    "VAE": "VAE",
    "Default (model)": "Default (model)",
    "Forced Overwrite of Sampling Step": "Forced Overwrite of Sampling Step",
    "Set as -1 to disable. For developer debugging.": "Set as -1 to disable. For developer debugging.",
    "Forced Overwrite of Refiner Switch Step": "Forced Overwrite of Refiner Switch Step",
13 changes: 9 additions & 4 deletions ldm_patched/modules/sd.py
@@ -427,12 +427,13 @@ class EmptyClass:

    return (ldm_patched.modules.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae)

def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True):
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, vae_filename_param=None):
    sd = ldm_patched.modules.utils.load_torch_file(ckpt_path)
    sd_keys = sd.keys()
    clip = None
    clipvision = None
    vae = None
    vae_filename = None
    model = None
    model_patcher = None
    clip_target = None
@@ -462,8 +463,12 @@ class WeightsLoader(torch.nn.Module):
        model.load_model_weights(sd, "model.diffusion_model.")

    if output_vae:
        vae_sd = ldm_patched.modules.utils.state_dict_prefix_replace(sd, {"first_stage_model.": ""}, filter_keys=True)
        vae_sd = model_config.process_vae_state_dict(vae_sd)
        if vae_filename_param is None:
            vae_sd = ldm_patched.modules.utils.state_dict_prefix_replace(sd, {"first_stage_model.": ""}, filter_keys=True)
            vae_sd = model_config.process_vae_state_dict(vae_sd)
        else:
            vae_sd = ldm_patched.modules.utils.load_torch_file(vae_filename_param)
            vae_filename = vae_filename_param
        vae = VAE(sd=vae_sd)

    if output_clip:
@@ -485,7 +490,7 @@ class WeightsLoader(torch.nn.Module):
        print("loaded straight to GPU")
        model_management.load_model_gpu(model_patcher)

    return (model_patcher, clip, vae, clipvision)
    return model_patcher, clip, vae, vae_filename, clipvision


def load_unet_state_dict(sd): #load unet in diffusers format
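Taken together, the changes to ldm_patched/modules/sd.py alter the loader's return contract: callers now receive a 5-tuple whose extra vae_filename element is None unless an external VAE was requested. A minimal usage sketch (file paths are illustrative, not part of this PR):

# Hedged sketch of the new loader contract in ldm_patched/modules/sd.py.
from ldm_patched.modules.sd import load_checkpoint_guess_config

unet, clip, vae, vae_filename, clip_vision = load_checkpoint_guess_config(
    'models/checkpoints/some_sdxl_model.safetensors',      # illustrative checkpoint path
    vae_filename_param='models/vae/sdxl_vae.safetensors',  # external VAE; omit to keep the baked-in VAE
)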
6 changes: 4 additions & 2 deletions modules/async_worker.py
@@ -166,6 +166,7 @@ def handler(async_task):
        adaptive_cfg = args.pop()
        sampler_name = args.pop()
        scheduler_name = args.pop()
        vae_name = args.pop()
        overwrite_step = args.pop()
        overwrite_switch = args.pop()
        overwrite_width = args.pop()
@@ -428,7 +429,7 @@ def handler(async_task):
        progressbar(async_task, 3, 'Loading models ...')
        pipeline.refresh_everything(refiner_model_name=refiner_model_name, base_model_name=base_model_name,
                                    loras=loras, base_model_additional_loras=base_model_additional_loras,
                                    use_synthetic_refiner=use_synthetic_refiner)
                                    use_synthetic_refiner=use_synthetic_refiner, vae_name=vae_name)

        progressbar(async_task, 3, 'Processing prompts ...')
        tasks = []
@@ -869,6 +870,7 @@ def callback(step, x0, x, total_steps, y):

            d.append(('Sampler', 'sampler', sampler_name))
            d.append(('Scheduler', 'scheduler', scheduler_name))
            d.append(('VAE', 'vae', vae_name))
            d.append(('Seed', 'seed', str(task['task_seed'])))

            if freeu_enabled:
@@ -883,7 +885,7 @@ def callback(step, x0, x, total_steps, y):
                metadata_parser = modules.meta_parser.get_metadata_parser(metadata_scheme)
                metadata_parser.set_data(task['log_positive_prompt'], task['positive'],
                                         task['log_negative_prompt'], task['negative'],
                                         steps, base_model_name, refiner_model_name, loras)
                                         steps, base_model_name, refiner_model_name, loras, vae_name)
            d.append(('Metadata Scheme', 'metadata_scheme', metadata_scheme.value if save_metadata_to_images else save_metadata_to_images))
            d.append(('Version', 'version', 'Fooocus v' + fooocus_version.version))
            img_paths.append(log(x, d, metadata_parser, output_format))
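Note that handler() consumes its arguments positionally, so the new vae_name pop only lines up if the UI appends the matching control in the same relative position (that webui.py change is outside this diff). A rough sketch of the ordering contract, with assumed values:

# Assumption: the UI hands over a flat list of control values that handler()
# pops in a fixed order; inserting vae_name between scheduler_name and
# overwrite_step must happen on both sides or every later value shifts.
args = [-1, 'Default (model)', 'karras', 'dpmpp_2m_sde_gpu', 7.0]  # reversed, illustrative
adaptive_cfg = args.pop()     # 7.0
sampler_name = args.pop()     # 'dpmpp_2m_sde_gpu'
scheduler_name = args.pop()   # 'karras'
vae_name = args.pop()         # 'Default (model)'  <- the new control
overwrite_step = args.pop()   # -1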
14 changes: 13 additions & 1 deletion modules/config.py
@@ -189,6 +189,7 @@ def get_dir_or_set_default(key, default_value, as_array=False, make_directory=Fa
paths_loras = get_dir_or_set_default('path_loras', ['../models/loras/'], True)
path_embeddings = get_dir_or_set_default('path_embeddings', '../models/embeddings/')
path_vae_approx = get_dir_or_set_default('path_vae_approx', '../models/vae_approx/')
path_vae = get_dir_or_set_default('path_vae', '../models/vae/')
path_upscale_models = get_dir_or_set_default('path_upscale_models', '../models/upscale_models/')
path_inpaint = get_dir_or_set_default('path_inpaint', '../models/inpaint/')
path_controlnet = get_dir_or_set_default('path_controlnet', '../models/controlnet/')
@@ -346,6 +347,11 @@ def init_temp_path(path: str | None, default_path: str) -> str:
    default_value='karras',
    validator=lambda x: x in modules.flags.scheduler_list
)
default_vae = get_config_item_or_set_default(
    key='default_vae',
    default_value=modules.flags.default_vae,
    validator=lambda x: isinstance(x, str)
)
default_styles = get_config_item_or_set_default(
    key='default_styles',
    default_value=[
@@ -535,6 +541,7 @@ def add_ratio(x):

model_filenames = []
lora_filenames = []
vae_filenames = []
wildcard_filenames = []

sdxl_lcm_lora = 'sdxl_lcm_lora.safetensors'
@@ -546,15 +553,20 @@ def get_model_filenames(folder_paths, extensions=None, name_filter=None):
    if extensions is None:
        extensions = ['.pth', '.ckpt', '.bin', '.safetensors', '.fooocus.patch']
    files = []

    if not isinstance(folder_paths, list):
        folder_paths = [folder_paths]
    for folder in folder_paths:
        files += get_files_from_folder(folder, extensions, name_filter)

    return files


def update_files():
    global model_filenames, lora_filenames, wildcard_filenames, available_presets
    global model_filenames, lora_filenames, vae_filenames, wildcard_filenames, available_presets
    model_filenames = get_model_filenames(paths_checkpoints)
    lora_filenames = get_model_filenames(paths_loras)
    vae_filenames = get_model_filenames(path_vae)
    wildcard_filenames = get_files_from_folder(path_wildcards, ['.txt'])
    available_presets = get_presets()
    return
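Unlike paths_checkpoints and paths_loras, which are lists, the new path_vae is a single directory string; get_model_filenames now coerces a lone string into a list (and get_file_from_folder_list in modules/util.py gets the same guard further down). A small sketch of the resulting behavior, assuming a Fooocus checkout on sys.path:

# Sketch: after update_files(), vae_filenames mirrors model_filenames /
# lora_filenames and presumably feeds the VAE selector in the UI. Output
# is whatever model files sit in path_vae (example name below).
import modules.config as config

config.update_files()
print(config.vae_filenames)  # e.g. ['sdxl_vae.safetensors']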
10 changes: 6 additions & 4 deletions modules/core.py
@@ -35,12 +35,13 @@


class StableDiffusionModel:
    def __init__(self, unet=None, vae=None, clip=None, clip_vision=None, filename=None):
    def __init__(self, unet=None, vae=None, clip=None, clip_vision=None, filename=None, vae_filename=None):
        self.unet = unet
        self.vae = vae
        self.clip = clip
        self.clip_vision = clip_vision
        self.filename = filename
        self.vae_filename = vae_filename
        self.unet_with_lora = unet
        self.clip_with_lora = clip
        self.visited_loras = ''
@@ -142,9 +143,10 @@ def apply_controlnet(positive, negative, control_net, image, strength, start_per

@torch.no_grad()
@torch.inference_mode()
def load_model(ckpt_filename):
    unet, clip, vae, clip_vision = load_checkpoint_guess_config(ckpt_filename, embedding_directory=path_embeddings)
    return StableDiffusionModel(unet=unet, clip=clip, vae=vae, clip_vision=clip_vision, filename=ckpt_filename)
def load_model(ckpt_filename, vae_filename=None):
    unet, clip, vae, vae_filename, clip_vision = load_checkpoint_guess_config(ckpt_filename, embedding_directory=path_embeddings,
                                                                              vae_filename_param=vae_filename)
    return StableDiffusionModel(unet=unet, clip=clip, vae=vae, clip_vision=clip_vision, filename=ckpt_filename, vae_filename=vae_filename)


@torch.no_grad()
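A quick sketch of the new core-level entry point (paths illustrative): load_model threads the optional VAE file through to the checkpoint loader and records whichever filename comes back on the model wrapper.

# Sketch, assuming a Fooocus checkout; the wrapper now carries vae_filename
# so default_pipeline can detect when a VAE swap requires a reload.
from modules.core import load_model

sd_model = load_model('models/checkpoints/some_sdxl_model.safetensors',
                      vae_filename='models/vae/sdxl_vae.safetensors')
print(sd_model.filename, sd_model.vae_filename)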
22 changes: 14 additions & 8 deletions modules/default_pipeline.py
@@ -3,6 +3,7 @@
import torch
import modules.patch
import modules.config
import modules.flags
import ldm_patched.modules.model_management
import ldm_patched.modules.latent_formats
import modules.inpaint_worker
@@ -58,17 +59,21 @@ def assert_model_integrity():

@torch.no_grad()
@torch.inference_mode()
def refresh_base_model(name):
def refresh_base_model(name, vae_name=None):
    global model_base

    filename = get_file_from_folder_list(name, modules.config.paths_checkpoints)

    if model_base.filename == filename:
    vae_filename = None
    if vae_name is not None and vae_name != modules.flags.default_vae:
        vae_filename = get_file_from_folder_list(vae_name, modules.config.path_vae)

    if model_base.filename == filename and model_base.vae_filename == vae_filename:
        return

    model_base = core.StableDiffusionModel()
    model_base = core.load_model(filename)
    model_base = core.load_model(filename, vae_filename)
    print(f'Base model loaded: {model_base.filename}')
    print(f'VAE loaded: {model_base.vae_filename}')
    return


@@ -216,7 +221,7 @@ def prepare_text_encoder(async_call=True):
@torch.no_grad()
@torch.inference_mode()
def refresh_everything(refiner_model_name, base_model_name, loras,
                       base_model_additional_loras=None, use_synthetic_refiner=False):
                       base_model_additional_loras=None, use_synthetic_refiner=False, vae_name=None):
    global final_unet, final_clip, final_vae, final_refiner_unet, final_refiner_vae, final_expansion

    final_unet = None
@@ -227,11 +232,11 @@ def refresh_everything(refiner_model_name, base_model_name, loras,

    if use_synthetic_refiner and refiner_model_name == 'None':
        print('Synthetic Refiner Activated')
        refresh_base_model(base_model_name)
        refresh_base_model(base_model_name, vae_name)
        synthesize_refiner_model()
    else:
        refresh_refiner_model(refiner_model_name)
        refresh_base_model(base_model_name)
        refresh_base_model(base_model_name, vae_name)

    refresh_loras(loras, base_model_additional_loras=base_model_additional_loras)
    assert_model_integrity()
@@ -254,7 +259,8 @@ def refresh_everything(refiner_model_name, base_model_name, loras,
refresh_everything(
    refiner_model_name=modules.config.default_refiner_model_name,
    base_model_name=modules.config.default_base_model_name,
    loras=get_enabled_loras(modules.config.default_loras)
    loras=get_enabled_loras(modules.config.default_loras),
    vae_name=modules.config.default_vae,
)


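The 'Default (model)' sentinel from modules/flags.py is what preserves the old behavior: refresh_base_model only resolves a VAE file when the selected name differs from the sentinel, and the reload check now compares both filenames. A usage sketch, assuming a Fooocus checkout (checkpoint and VAE names are illustrative):

# Sketch: the sentinel keeps the VAE baked into the checkpoint; any other
# name is resolved inside modules.config.path_vae and swapped in.
import modules.flags
import modules.default_pipeline as pipeline

pipeline.refresh_base_model('some_sdxl_model.safetensors', vae_name=modules.flags.default_vae)
pipeline.refresh_base_model('some_sdxl_model.safetensors', vae_name='sdxl_vae.safetensors')
# The second call triggers a reload even though the checkpoint is unchanged,
# because model_base.vae_filename no longer matches.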
2 changes: 2 additions & 0 deletions modules/flags.py
@@ -53,6 +53,8 @@
sampler_list = SAMPLER_NAMES
scheduler_list = SCHEDULER_NAMES

default_vae = 'Default (model)'

refiner_swap_method = 'joint'

cn_ip = "ImagePrompt"
31 changes: 24 additions & 7 deletions modules/meta_parser.py
@@ -46,6 +46,7 @@ def load_parameter_button_click(raw_metadata: dict | str, is_generating: bool):
    get_float('refiner_switch', 'Refiner Switch', loaded_parameter_dict, results)
    get_str('sampler', 'Sampler', loaded_parameter_dict, results)
    get_str('scheduler', 'Scheduler', loaded_parameter_dict, results)
    get_str('vae', 'VAE', loaded_parameter_dict, results)
    get_seed('seed', 'Seed', loaded_parameter_dict, results)

    if is_generating:
@@ -253,6 +254,7 @@ def __init__(self):
        self.refiner_model_name: str = ''
        self.refiner_model_hash: str = ''
        self.loras: list = []
        self.vae_name: str = ''

    @abstractmethod
    def get_scheme(self) -> MetadataScheme:
@@ -267,7 +269,7 @@ def parse_string(self, metadata: dict) -> str:
        raise NotImplementedError

    def set_data(self, raw_prompt, full_prompt, raw_negative_prompt, full_negative_prompt, steps, base_model_name,
                 refiner_model_name, loras):
                 refiner_model_name, loras, vae_name):
        self.raw_prompt = raw_prompt
        self.full_prompt = full_prompt
        self.raw_negative_prompt = raw_negative_prompt
@@ -289,6 +291,7 @@ def set_data(self, raw_prompt, full_prompt, raw_negative_prompt, full_negative_p
            lora_path = get_file_from_folder_list(lora_name, modules.config.paths_loras)
            lora_hash = get_sha256(lora_path)
            self.loras.append((Path(lora_name).stem, lora_weight, lora_hash))
        self.vae_name = Path(vae_name).stem

    @staticmethod
    def remove_special_loras(lora_filenames):
@@ -310,6 +313,7 @@ def get_scheme(self) -> MetadataScheme:
        'steps': 'Steps',
        'sampler': 'Sampler',
        'scheduler': 'Scheduler',
        'vae': 'VAE',
        'guidance_scale': 'CFG scale',
        'seed': 'Seed',
        'resolution': 'Size',
@@ -397,13 +401,12 @@ def parse_json(self, metadata: str) -> dict:
                data['sampler'] = k
                break

        for key in ['base_model', 'refiner_model']:
        for key in ['base_model', 'refiner_model', 'vae']:
            if key in data:
                for filename in modules.config.model_filenames:
                    path = Path(filename)
                    if data[key] == path.stem:
                        data[key] = filename
                        break
                if key == 'vae':
                    self.add_extension_to_filename(data, modules.config.vae_filenames, 'vae')
                else:
                    self.add_extension_to_filename(data, modules.config.model_filenames, key)

        lora_data = ''
        if 'lora_weights' in data and data['lora_weights'] != '':
@@ -433,6 +436,7 @@ def parse_string(self, metadata: dict) -> str:

        sampler = data['sampler']
        scheduler = data['scheduler']

        if sampler in SAMPLERS and SAMPLERS[sampler] != '':
            sampler = SAMPLERS[sampler]
            if sampler not in CIVITAI_NO_KARRAS and scheduler == 'karras':
@@ -451,6 +455,7 @@

            self.fooocus_to_a1111['performance']: data['performance'],
            self.fooocus_to_a1111['scheduler']: scheduler,
            self.fooocus_to_a1111['vae']: Path(data['vae']).stem,
            # workaround for multiline prompts
            self.fooocus_to_a1111['raw_prompt']: self.raw_prompt,
            self.fooocus_to_a1111['raw_negative_prompt']: self.raw_negative_prompt,
@@ -491,6 +496,14 @@ def parse_string(self, metadata: dict) -> str:
        negative_prompt_text = f"\nNegative prompt: {negative_prompt_resolved}" if negative_prompt_resolved else ""
        return f"{positive_prompt_resolved}{negative_prompt_text}\n{generation_params_text}".strip()

    @staticmethod
    def add_extension_to_filename(data, filenames, key):
        for filename in filenames:
            path = Path(filename)
            if data[key] == path.stem:
                data[key] = filename
                break


class FooocusMetadataParser(MetadataParser):
    def get_scheme(self) -> MetadataScheme:
@@ -499,6 +512,7 @@ def get_scheme(self) -> MetadataScheme:
    def parse_json(self, metadata: dict) -> dict:
        model_filenames = modules.config.model_filenames.copy()
        lora_filenames = modules.config.lora_filenames.copy()
        vae_filenames = modules.config.vae_filenames.copy()
        self.remove_special_loras(lora_filenames)
        for key, value in metadata.items():
            if value in ['', 'None']:
@@ -507,6 +521,8 @@ def parse_json(self, metadata: dict) -> dict:
                metadata[key] = self.replace_value_with_filename(key, value, model_filenames)
            elif key.startswith('lora_combined_'):
                metadata[key] = self.replace_value_with_filename(key, value, lora_filenames)
            elif key == 'vae':
                metadata[key] = self.replace_value_with_filename(key, value, vae_filenames)
            else:
                continue

@@ -533,6 +549,7 @@ def parse_string(self, metadata: list) -> str:
        res['refiner_model'] = self.refiner_model_name
        res['refiner_model_hash'] = self.refiner_model_hash

        res['vae'] = self.vae_name
        res['loras'] = self.loras

        if modules.config.metadata_created_by != '':
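Metadata stores only the VAE's stem (filename without extension), so both parsers have to map it back to a real file from the scanned vae_filenames list. A standalone sketch of that round-trip, mirroring the add_extension_to_filename helper added above (names illustrative):

from pathlib import Path

# Restore the full filename for a stored stem, leaving unknown stems untouched.
def add_extension_to_filename(data, filenames, key):
    for filename in filenames:
        if data[key] == Path(filename).stem:
            data[key] = filename
            break

data = {'vae': 'sdxl_vae'}
add_extension_to_filename(data, ['sdxl_vae.safetensors', 'other_vae.pt'], 'vae')
print(data)  # {'vae': 'sdxl_vae.safetensors'}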
3 changes: 3 additions & 0 deletions modules/util.py
@@ -371,6 +371,9 @@ def is_json(data: str) -> bool:


def get_file_from_folder_list(name, folders):
    if not isinstance(folders, list):
        folders = [folders]

    for folder in folders:
        filename = os.path.abspath(os.path.realpath(os.path.join(folder, name)))
        if os.path.isfile(filename):
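This guard exists because path_vae is a single string while the checkpoint and LoRA paths are lists; coercing keeps one lookup helper for both shapes. A sketch of the two call forms, assuming a Fooocus checkout and that the file exists (paths illustrative):

from modules.util import get_file_from_folder_list

# Both calls now behave identically; without the guard the string form
# would iterate over the string's characters as "folders".
print(get_file_from_folder_list('sdxl_vae.safetensors', '../models/vae/'))
print(get_file_from_folder_list('sdxl_vae.safetensors', ['../models/vae/']))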