Skip to content

Commit

Permalink
Update Spandrel to 0.1.0 (#2354)
Browse files Browse the repository at this point in the history
* Update Spandrel to 0.1.0

* update dep num
  • Loading branch information
joeyballentine authored Nov 30, 2023
1 parent d19bb23 commit 7dc9e66
Show file tree
Hide file tree
Showing 10 changed files with 43 additions and 66 deletions.
49 changes: 26 additions & 23 deletions backend/src/nodes/properties/inputs/pytorch_inputs.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,16 @@
from __future__ import annotations

from typing import Union

try:
import torch
from spandrel import (
FaceSRModelDescriptor,
InpaintModelDescriptor,
ImageModelDescriptor,
MaskedImageModelDescriptor,
ModelDescriptor,
RestorationModelDescriptor,
SRModelDescriptor,
)
except Exception:
torch = None
ModelDescriptor = object
SRModelDescriptor = object
FaceSRModelDescriptor = object
InpaintModelDescriptor = object
RestorationModelDescriptor = object
ImageModelDescriptor = object
MaskedImageModelDescriptor = object

import navi
from api import BaseInput
Expand All @@ -38,7 +31,7 @@ def __init__(
def enforce(self, value: object):
if torch is not None:
assert isinstance(
value, ModelDescriptor
value, (ImageModelDescriptor, MaskedImageModelDescriptor)
), "Expected a supported PyTorch model."
return value

Expand All @@ -57,12 +50,16 @@ def __init__(
),
)
if torch is not None:
self.associated_type = Union[SRModelDescriptor, RestorationModelDescriptor]
self.associated_type = ImageModelDescriptor

def enforce(self, value: object):
def enforce(self, value: ModelDescriptor):
if torch is not None:
assert isinstance(
value, (RestorationModelDescriptor, SRModelDescriptor)
value, ImageModelDescriptor
), "Expected a supported single image PyTorch model."
assert value.purpose in (
"SR",
"Restoration",
), "Expected a Super-Resolution or Restoration model."
return value

Expand All @@ -79,13 +76,16 @@ def __init__(
),
)
if torch is not None:
self.associated_type = FaceSRModelDescriptor
self.associated_type = ImageModelDescriptor

def enforce(self, value: object):
def enforce(self, value: ModelDescriptor):
if torch is not None:
assert isinstance(
value, FaceSRModelDescriptor
), "Expected a Face-specific Super-Resolution model."
value, ImageModelDescriptor
), "Expected a supported single image PyTorch model."
assert value.purpose in (
"FaceSR"
), "Expected a Face Super-Resolution model."
return value


Expand All @@ -101,13 +101,16 @@ def __init__(
),
)
if torch is not None:
self.associated_type = InpaintModelDescriptor
self.associated_type = MaskedImageModelDescriptor

def enforce(self, value: object):
def enforce(self, value: ModelDescriptor):
if torch is not None:
assert isinstance(
value, InpaintModelDescriptor
), "Expected an inpainting-specific model."
value, MaskedImageModelDescriptor
), "Expected a supported masked-image PyTorch model."
assert value.purpose in (
"Inpaint"
), "Expected a Face Super-Resolution model."
return value


Expand Down
19 changes: 1 addition & 18 deletions backend/src/nodes/properties/outputs/pytorch_outputs.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
from __future__ import annotations

from spandrel import (
FaceSRModelDescriptor,
InpaintModelDescriptor,
ModelDescriptor,
RestorationModelDescriptor,
SRModelDescriptor,
)

import navi
Expand All @@ -14,19 +10,6 @@
from ...utils.format import format_channel_numbers


def get_sub_type(model_descriptor: ModelDescriptor) -> str:
if isinstance(model_descriptor, SRModelDescriptor):
return "SR"
elif isinstance(model_descriptor, InpaintModelDescriptor):
return "Inpainting"
elif isinstance(model_descriptor, RestorationModelDescriptor):
return "Restoration"
elif isinstance(model_descriptor, FaceSRModelDescriptor): # type: ignore <- it wants me to just put this in an else
return "FaceSR"
else:
return "Unknown"


class ModelOutput(BaseOutput):
def __init__(
self,
Expand All @@ -53,7 +36,7 @@ def get_broadcast_type(self, value: ModelDescriptor):
"inputChannels": value.input_channels,
"outputChannels": value.output_channels,
"arch": navi.literal(value.architecture),
"subType": navi.literal(get_sub_type(value)),
"subType": navi.literal(value.purpose),
"size": navi.literal("x".join(value.tags)),
},
)
Expand Down
2 changes: 1 addition & 1 deletion backend/src/packages/chaiNNer_pytorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def get_pytorch():
Dependency(
display_name="Spandrel",
pypi_name="spandrel",
version="0.0.4",
version="0.1.0",
size_estimate=180.7 * KB,
),
],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import numpy as np
import torch
from spandrel import InpaintModelDescriptor
from spandrel import MaskedImageModelDescriptor

import navi
from nodes.impl.image_utils import as_3d
Expand Down Expand Up @@ -51,7 +51,7 @@ def pad_img_to_modulo(
def inpaint(
img: np.ndarray,
mask: np.ndarray,
model: InpaintModelDescriptor,
model: MaskedImageModelDescriptor,
options: PyTorchSettings,
):
with torch.no_grad():
Expand Down Expand Up @@ -152,7 +152,7 @@ def inpaint(
def inpaint_node(
img: np.ndarray,
mask: np.ndarray,
model: InpaintModelDescriptor,
model: MaskedImageModelDescriptor,
) -> np.ndarray:
"""Inpaint an image"""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
import torch
from sanic.log import logger
from spandrel import RestorationModelDescriptor, SRModelDescriptor
from spandrel import ImageModelDescriptor

from nodes.groups import Condition, if_group
from nodes.impl.pytorch.auto_split import pytorch_auto_split
Expand All @@ -29,7 +29,7 @@

def upscale(
img: np.ndarray,
model: SRModelDescriptor | RestorationModelDescriptor,
model: ImageModelDescriptor,
tile_size: TileSize,
options: PyTorchSettings,
):
Expand Down Expand Up @@ -130,7 +130,7 @@ def estimate():
)
def upscale_image_node(
img: np.ndarray,
model: SRModelDescriptor | RestorationModelDescriptor,
model: ImageModelDescriptor,
tile_size: TileSize,
separate_alpha: bool,
) -> np.ndarray:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from appdirs import user_data_dir
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
from sanic.log import logger
from spandrel import FaceSRModelDescriptor
from spandrel import ImageModelDescriptor
from torchvision.transforms.functional import normalize as tv_normalize

from nodes.groups import Condition, if_group
Expand Down Expand Up @@ -37,7 +37,7 @@ def upscale(
img: np.ndarray,
background_img: np.ndarray | None,
face_helper: FaceRestoreHelper,
face_model: FaceSRModelDescriptor,
face_model: ImageModelDescriptor,
weight: float,
exec_options: PyTorchSettings,
device: torch.device,
Expand Down Expand Up @@ -149,7 +149,7 @@ def upscale(
)
def upscale_face_node(
img: np.ndarray,
face_model: FaceSRModelDescriptor,
face_model: ImageModelDescriptor,
background_img: np.ndarray | None,
scale: int,
weight: float,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from spandrel import RestorationModelDescriptor, SRModelDescriptor
from spandrel import ImageModelDescriptor
from spandrel.architectures.DAT import DAT
from spandrel.architectures.HAT import HAT
from spandrel.architectures.OmniSR import OmniSR
Expand Down Expand Up @@ -45,7 +45,7 @@
],
)
def convert_to_ncnn_node(
model: SRModelDescriptor | RestorationModelDescriptor, is_fp16: int
model: ImageModelDescriptor, is_fp16: int
) -> tuple[NcnnModelWrapper, str]:
if onnx_convert_to_ncnn_node is None:
raise ModuleNotFoundError(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from io import BytesIO

import torch
from spandrel import RestorationModelDescriptor, SRModelDescriptor
from spandrel import ImageModelDescriptor
from spandrel.architectures.SCUNet import SCUNet

from nodes.impl.onnx.model import OnnxGeneric
Expand Down Expand Up @@ -33,7 +33,7 @@
],
)
def convert_to_onnx_node(
model: SRModelDescriptor | RestorationModelDescriptor, is_fp16: int
model: ImageModelDescriptor, is_fp16: int
) -> tuple[OnnxGeneric, str]:
assert not isinstance(
model.model, SCUNet
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,7 @@ def check_can_interp(model_a: dict, model_b: dict):
return False
interp_50 = perform_interp(model_a, model_b, 50)
model_descriptor = ModelLoader(torch.device("cpu")).load_from_state_dict(interp_50)
size = (
model_descriptor.size_requirements.minimum
if model_descriptor.size_requirements.minimum is not None
else 3
)
size = max(model_descriptor.size_requirements.minimum, 3)
assert isinstance(size, int), "min_size_restriction must be an int"
fake_img = np.ones((size, size, model_descriptor.input_channels), dtype=np.float32)
del interp_50
Expand Down
7 changes: 1 addition & 6 deletions pyrightconfig.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,12 @@
"exclude": [
"**/__pycache__"
],
"ignore": [
"backend/src/nodes/impl/pytorch/architecture"
],

"ignore": [],
"typeCheckingMode": "basic",
"useLibraryCodeForTypes": false,

"strictListInference": true,
"strictDictionaryInference": true,
"strictSetInference": true,

"reportDuplicateImport": "warning",
"reportImportCycles": "error",
"reportIncompatibleVariableOverride": "error",
Expand Down

0 comments on commit 7dc9e66

Please sign in to comment.