diff --git a/backend/src/nodes/properties/inputs/pytorch_inputs.py b/backend/src/nodes/properties/inputs/pytorch_inputs.py
index 9f9a3dce7..2860df834 100644
--- a/backend/src/nodes/properties/inputs/pytorch_inputs.py
+++ b/backend/src/nodes/properties/inputs/pytorch_inputs.py
@@ -1,23 +1,16 @@
 from __future__ import annotations
 
-from typing import Union
-
 try:
     import torch
     from spandrel import (
-        FaceSRModelDescriptor,
-        InpaintModelDescriptor,
+        ImageModelDescriptor,
+        MaskedImageModelDescriptor,
         ModelDescriptor,
-        RestorationModelDescriptor,
-        SRModelDescriptor,
     )
 except Exception:
     torch = None
-    ModelDescriptor = object
-    SRModelDescriptor = object
-    FaceSRModelDescriptor = object
-    InpaintModelDescriptor = object
-    RestorationModelDescriptor = object
+    ImageModelDescriptor = object
+    MaskedImageModelDescriptor = object
 
 import navi
 from api import BaseInput
@@ -38,7 +31,7 @@ def __init__(
     def enforce(self, value: object):
         if torch is not None:
             assert isinstance(
-                value, ModelDescriptor
+                value, (ImageModelDescriptor, MaskedImageModelDescriptor)
             ), "Expected a supported PyTorch model."
         return value
 
@@ -57,12 +50,16 @@ def __init__(
             ),
         )
         if torch is not None:
-            self.associated_type = Union[SRModelDescriptor, RestorationModelDescriptor]
+            self.associated_type = ImageModelDescriptor
 
-    def enforce(self, value: object):
+    def enforce(self, value: ModelDescriptor):
         if torch is not None:
             assert isinstance(
-                value, (RestorationModelDescriptor, SRModelDescriptor)
+                value, ImageModelDescriptor
+            ), "Expected a supported single image PyTorch model."
+            assert value.purpose in (
+                "SR",
+                "Restoration",
             ), "Expected a Super-Resolution or Restoration model."
         return value
 
@@ -79,13 +76,16 @@ def __init__(
             ),
         )
         if torch is not None:
-            self.associated_type = FaceSRModelDescriptor
+            self.associated_type = ImageModelDescriptor
 
-    def enforce(self, value: object):
+    def enforce(self, value: ModelDescriptor):
         if torch is not None:
             assert isinstance(
-                value, FaceSRModelDescriptor
-            ), "Expected a Face-specific Super-Resolution model."
+                value, ImageModelDescriptor
+            ), "Expected a supported single image PyTorch model."
+            assert value.purpose in (
+                "FaceSR",
+            ), "Expected a Face Super-Resolution model."
         return value
 
@@ -101,13 +101,16 @@ def __init__(
             ),
         )
         if torch is not None:
-            self.associated_type = InpaintModelDescriptor
+            self.associated_type = MaskedImageModelDescriptor
 
-    def enforce(self, value: object):
+    def enforce(self, value: ModelDescriptor):
         if torch is not None:
             assert isinstance(
-                value, InpaintModelDescriptor
-            ), "Expected an inpainting-specific model."
+                value, MaskedImageModelDescriptor
+            ), "Expected a supported masked-image PyTorch model."
+            assert value.purpose in (
+                "Inpaint",
+            ), "Expected an inpainting model."
         return value
diff --git a/backend/src/nodes/properties/outputs/pytorch_outputs.py b/backend/src/nodes/properties/outputs/pytorch_outputs.py
index 4c170aee2..74da70bc0 100644
--- a/backend/src/nodes/properties/outputs/pytorch_outputs.py
+++ b/backend/src/nodes/properties/outputs/pytorch_outputs.py
@@ -1,11 +1,7 @@
 from __future__ import annotations
 
 from spandrel import (
-    FaceSRModelDescriptor,
-    InpaintModelDescriptor,
     ModelDescriptor,
-    RestorationModelDescriptor,
-    SRModelDescriptor,
 )
 
 import navi
@@ -14,19 +10,6 @@
 from ...utils.format import format_channel_numbers
 
 
-def get_sub_type(model_descriptor: ModelDescriptor) -> str:
-    if isinstance(model_descriptor, SRModelDescriptor):
-        return "SR"
-    elif isinstance(model_descriptor, InpaintModelDescriptor):
-        return "Inpainting"
-    elif isinstance(model_descriptor, RestorationModelDescriptor):
-        return "Restoration"
-    elif isinstance(model_descriptor, FaceSRModelDescriptor):  # type: ignore <- it wants me to just put this in an else
-        return "FaceSR"
-    else:
-        return "Unknown"
-
-
 class ModelOutput(BaseOutput):
     def __init__(
         self,
@@ -53,7 +36,7 @@ def get_broadcast_type(self, value: ModelDescriptor):
                 "inputChannels": value.input_channels,
                 "outputChannels": value.output_channels,
                 "arch": navi.literal(value.architecture),
-                "subType": navi.literal(get_sub_type(value)),
+                "subType": navi.literal(value.purpose),
                 "size": navi.literal("x".join(value.tags)),
             },
         )
diff --git a/backend/src/packages/chaiNNer_pytorch/__init__.py b/backend/src/packages/chaiNNer_pytorch/__init__.py
index 6351c9d7f..9a5c4e0c1 100644
--- a/backend/src/packages/chaiNNer_pytorch/__init__.py
+++ b/backend/src/packages/chaiNNer_pytorch/__init__.py
@@ -95,7 +95,7 @@ def get_pytorch():
         Dependency(
             display_name="Spandrel",
             pypi_name="spandrel",
-            version="0.0.4",
+            version="0.1.0",
            size_estimate=180.7 * KB,
         ),
     ],
diff --git a/backend/src/packages/chaiNNer_pytorch/pytorch/processing/inpaint.py b/backend/src/packages/chaiNNer_pytorch/pytorch/processing/inpaint.py
index 311404634..029f76523 100644
--- a/backend/src/packages/chaiNNer_pytorch/pytorch/processing/inpaint.py
+++ b/backend/src/packages/chaiNNer_pytorch/pytorch/processing/inpaint.py
@@ -4,7 +4,7 @@
 import numpy as np
 import torch
-from spandrel import InpaintModelDescriptor
+from spandrel import MaskedImageModelDescriptor
 
 import navi
 from nodes.impl.image_utils import as_3d
@@ -51,7 +51,7 @@ def pad_img_to_modulo(
 def inpaint(
     img: np.ndarray,
     mask: np.ndarray,
-    model: InpaintModelDescriptor,
+    model: MaskedImageModelDescriptor,
     options: PyTorchSettings,
 ):
     with torch.no_grad():
@@ -152,7 +152,7 @@ def inpaint(
 def inpaint_node(
     img: np.ndarray,
     mask: np.ndarray,
-    model: InpaintModelDescriptor,
+    model: MaskedImageModelDescriptor,
 ) -> np.ndarray:
     """Inpaint an image"""
diff --git a/backend/src/packages/chaiNNer_pytorch/pytorch/processing/upscale_image.py b/backend/src/packages/chaiNNer_pytorch/pytorch/processing/upscale_image.py
index ebb210384..b6cbaa82e 100644
--- a/backend/src/packages/chaiNNer_pytorch/pytorch/processing/upscale_image.py
+++ b/backend/src/packages/chaiNNer_pytorch/pytorch/processing/upscale_image.py
@@ -3,7 +3,7 @@
 import numpy as np
 import torch
 from sanic.log import logger
-from spandrel import RestorationModelDescriptor, SRModelDescriptor
+from spandrel import ImageModelDescriptor
 
 from nodes.groups import Condition, if_group
 from nodes.impl.pytorch.auto_split import pytorch_auto_split
@@ -29,7 +29,7 @@
 def upscale(
     img: np.ndarray,
-    model: SRModelDescriptor | RestorationModelDescriptor,
+    model: ImageModelDescriptor,
     tile_size: TileSize,
     options: PyTorchSettings,
 ):
@@ -130,7 +130,7 @@ def estimate():
 )
 def upscale_image_node(
     img: np.ndarray,
-    model: SRModelDescriptor | RestorationModelDescriptor,
+    model: ImageModelDescriptor,
     tile_size: TileSize,
     separate_alpha: bool,
 ) -> np.ndarray:
diff --git a/backend/src/packages/chaiNNer_pytorch/pytorch/restoration/upscale_face.py b/backend/src/packages/chaiNNer_pytorch/pytorch/restoration/upscale_face.py
index fb79cfd59..38cad9e3a 100644
--- a/backend/src/packages/chaiNNer_pytorch/pytorch/restoration/upscale_face.py
+++ b/backend/src/packages/chaiNNer_pytorch/pytorch/restoration/upscale_face.py
@@ -8,7 +8,7 @@
 from appdirs import user_data_dir
 from facexlib.utils.face_restoration_helper import FaceRestoreHelper
 from sanic.log import logger
-from spandrel import FaceSRModelDescriptor
+from spandrel import ImageModelDescriptor
 from torchvision.transforms.functional import normalize as tv_normalize
 
 from nodes.groups import Condition, if_group
@@ -37,7 +37,7 @@ def upscale(
     img: np.ndarray,
     background_img: np.ndarray | None,
     face_helper: FaceRestoreHelper,
-    face_model: FaceSRModelDescriptor,
+    face_model: ImageModelDescriptor,
     weight: float,
     exec_options: PyTorchSettings,
     device: torch.device,
@@ -149,7 +149,7 @@
 )
 def upscale_face_node(
     img: np.ndarray,
-    face_model: FaceSRModelDescriptor,
+    face_model: ImageModelDescriptor,
     background_img: np.ndarray | None,
     scale: int,
     weight: float,
diff --git a/backend/src/packages/chaiNNer_pytorch/pytorch/utility/convert_to_ncnn.py b/backend/src/packages/chaiNNer_pytorch/pytorch/utility/convert_to_ncnn.py
index d158a6f59..bbcee54dd 100644
--- a/backend/src/packages/chaiNNer_pytorch/pytorch/utility/convert_to_ncnn.py
+++ b/backend/src/packages/chaiNNer_pytorch/pytorch/utility/convert_to_ncnn.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from spandrel import RestorationModelDescriptor, SRModelDescriptor
+from spandrel import ImageModelDescriptor
 from spandrel.architectures.DAT import DAT
 from spandrel.architectures.HAT import HAT
 from spandrel.architectures.OmniSR import OmniSR
@@ -45,7 +45,7 @@
     ],
 )
 def convert_to_ncnn_node(
-    model: SRModelDescriptor | RestorationModelDescriptor, is_fp16: int
+    model: ImageModelDescriptor, is_fp16: int
 ) -> tuple[NcnnModelWrapper, str]:
     if onnx_convert_to_ncnn_node is None:
         raise ModuleNotFoundError(
diff --git a/backend/src/packages/chaiNNer_pytorch/pytorch/utility/convert_to_onnx.py b/backend/src/packages/chaiNNer_pytorch/pytorch/utility/convert_to_onnx.py
index fe731bef9..ad9469795 100644
--- a/backend/src/packages/chaiNNer_pytorch/pytorch/utility/convert_to_onnx.py
+++ b/backend/src/packages/chaiNNer_pytorch/pytorch/utility/convert_to_onnx.py
@@ -3,7 +3,7 @@
 from io import BytesIO
 
 import torch
-from spandrel import RestorationModelDescriptor, SRModelDescriptor
+from spandrel import ImageModelDescriptor
 from spandrel.architectures.SCUNet import SCUNet
 
 from nodes.impl.onnx.model import OnnxGeneric
@@ -33,7 +33,7 @@
     ],
 )
 def convert_to_onnx_node(
-    model: SRModelDescriptor | RestorationModelDescriptor, is_fp16: int
+    model: ImageModelDescriptor, is_fp16: int
 ) -> tuple[OnnxGeneric, str]:
     assert not isinstance(
         model.model, SCUNet
diff --git a/backend/src/packages/chaiNNer_pytorch/pytorch/utility/interpolate_models.py b/backend/src/packages/chaiNNer_pytorch/pytorch/utility/interpolate_models.py
index d3fee4390..2d88ea2e7 100644
--- a/backend/src/packages/chaiNNer_pytorch/pytorch/utility/interpolate_models.py
+++ b/backend/src/packages/chaiNNer_pytorch/pytorch/utility/interpolate_models.py
@@ -38,11 +38,7 @@ def check_can_interp(model_a: dict, model_b: dict):
         return False
     interp_50 = perform_interp(model_a, model_b, 50)
     model_descriptor = ModelLoader(torch.device("cpu")).load_from_state_dict(interp_50)
-    size = (
-        model_descriptor.size_requirements.minimum
-        if model_descriptor.size_requirements.minimum is not None
-        else 3
-    )
+    size = max(model_descriptor.size_requirements.minimum, 3)
     assert isinstance(size, int), "min_size_restriction must be an int"
     fake_img = np.ones((size, size, model_descriptor.input_channels), dtype=np.float32)
     del interp_50
diff --git a/pyrightconfig.json b/pyrightconfig.json
index f9bb28ddd..64bc99b07 100644
--- a/pyrightconfig.json
+++ b/pyrightconfig.json
@@ -5,17 +5,12 @@
     "exclude": [
        "**/__pycache__"
     ],
-    "ignore": [
-        "backend/src/nodes/impl/pytorch/architecture"
-    ],
-
+    "ignore": [],
     "typeCheckingMode": "basic",
     "useLibraryCodeForTypes": false,
-
     "strictListInference": true,
     "strictDictionaryInference": true,
     "strictSetInference": true,
-
    "reportDuplicateImport": "warning",
     "reportImportCycles": "error",
     "reportIncompatibleVariableOverride": "error",
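
Reviewer note (not part of the diff): a minimal sketch of how the spandrel 0.1.0 API this PR migrates to is meant to be consumed. The per-task descriptor classes (SRModelDescriptor, RestorationModelDescriptor, FaceSRModelDescriptor, InpaintModelDescriptor) are replaced by ImageModelDescriptor / MaskedImageModelDescriptor plus a purpose field, which is what the input assertions and the subType broadcast above rely on. The checkpoint path and prints below are hypothetical; the spandrel calls mirror the ones already used in interpolate_models.py and pytorch_outputs.py.

import torch
from spandrel import ImageModelDescriptor, MaskedImageModelDescriptor, ModelLoader

# Hypothetical checkpoint path; load_from_state_dict mirrors the call in
# interpolate_models.py above.
state_dict = torch.load("some_model.pth", map_location="cpu")
descriptor = ModelLoader(torch.device("cpu")).load_from_state_dict(state_dict)

if isinstance(descriptor, ImageModelDescriptor):
    # Single-image models: the task that used to be encoded in the class name
    # now lives in descriptor.purpose ("SR", "Restoration", or "FaceSR" per the
    # asserts in pytorch_inputs.py).
    print(descriptor.purpose, descriptor.architecture, descriptor.input_channels)
elif isinstance(descriptor, MaskedImageModelDescriptor):
    # Masked-image (inpainting) models replace the old InpaintModelDescriptor.
    print("inpainting model:", descriptor.architecture)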