From 777af661a21821994993df3ef566b01df2bb61a0 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sun, 31 Dec 2023 00:09:51 +0200 Subject: [PATCH] Be more clear about Spandrel model nomenclature --- extensions-builtin/SwinIR/scripts/swinir_model.py | 6 +++--- modules/gfpgan_model.py | 10 ++++++---- modules/modelloader.py | 2 +- modules/realesrgan_model.py | 4 ++-- modules/upscaler_utils.py | 2 +- 5 files changed, 13 insertions(+), 11 deletions(-) diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py index aae159af56e..95c7ec648ef 100644 --- a/extensions-builtin/SwinIR/scripts/swinir_model.py +++ b/extensions-builtin/SwinIR/scripts/swinir_model.py @@ -71,7 +71,7 @@ def load_model(self, path, scale=4): else: filename = path - model = modelloader.load_spandrel_model( + model_descriptor = modelloader.load_spandrel_model( filename, device=self._get_device(), dtype=devices.dtype, @@ -79,10 +79,10 @@ def load_model(self, path, scale=4): ) if getattr(opts, 'SWIN_torch_compile', False): try: - model = torch.compile(model) + model_descriptor.model.compile() except Exception: logger.warning("Failed to compile SwinIR model, fallback to JIT", exc_info=True) - return model + return model_descriptor def _get_device(self): return devices.get_device_for('swinir') diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py index 48f8ad5e294..445b040925e 100644 --- a/modules/gfpgan_model.py +++ b/modules/gfpgan_model.py @@ -3,6 +3,8 @@ import logging import os +import torch + from modules import ( devices, errors, @@ -25,7 +27,7 @@ def name(self): def get_device(self): return devices.device_gfpgan - def load_net(self) -> None: + def load_net(self) -> torch.Module: for model_path in modelloader.load_models( model_path=self.model_path, model_url=model_url, @@ -34,13 +36,13 @@ def load_net(self) -> None: ext_filter=['.pth'], ): if 'GFPGAN' in os.path.basename(model_path): - net = modelloader.load_spandrel_model( + model = modelloader.load_spandrel_model( model_path, device=self.get_device(), expected_architecture='GFPGAN', ).model - net.different_w = True # see https://github.com/chaiNNer-org/spandrel/pull/81 - return net + model.different_w = True # see https://github.com/chaiNNer-org/spandrel/pull/81 + return model raise ValueError("No GFPGAN model found") def restore(self, np_image): diff --git a/modules/modelloader.py b/modules/modelloader.py index 8bcee08c155..a7194137571 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -143,7 +143,7 @@ def load_spandrel_model( *, device: str | torch.device | None, half: bool = False, - dtype: str | None = None, + dtype: str | torch.dtype | None = None, expected_architecture: str | None = None, ) -> spandrel.ModelDescriptor: import spandrel diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py index 65f2e880668..4d35b695c3b 100644 --- a/modules/realesrgan_model.py +++ b/modules/realesrgan_model.py @@ -36,14 +36,14 @@ def do_upscale(self, img, path): errors.report(f"Unable to load RealESRGAN model {path}", exc_info=True) return img - mod = modelloader.load_spandrel_model( + model_descriptor = modelloader.load_spandrel_model( info.local_data_path, device=self.device, half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling), expected_architecture="ESRGAN", # "RealESRGAN" isn't a specific thing for Spandrel ) return upscale_with_model( - mod, + model_descriptor, img, tile_size=opts.ESRGAN_tile, tile_overlap=opts.ESRGAN_tile_overlap, diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py index dde5d7ad43a..174c9bc3713 100644 --- a/modules/upscaler_utils.py +++ b/modules/upscaler_utils.py @@ -6,7 +6,7 @@ import tqdm from PIL import Image -from modules import devices, images +from modules import images logger = logging.getLogger(__name__)