nn_upscale.py
import os

import torch

from .latent_resizer import LatentResizer
from comfy import model_management


class NNLatentUpscale:
    """
    Upscales SDXL or SD 1.x latents using a small neural network resizer.
    """

    def __init__(self):
        self.local_dir = os.path.dirname(os.path.realpath(__file__))
        self.scale_factor = 0.13025
        self.dtype = torch.float32
        if model_management.should_use_fp16():
            self.dtype = torch.float16
        self.weight_path = {
            "SDXL": os.path.join(self.local_dir, "sdxl_resizer.pt"),
            "SD 1.x": os.path.join(self.local_dir, "sd15_resizer.pt"),
        }
        self.version = "none"

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "latent": ("LATENT",),
                "version": (["SDXL", "SD 1.x"],),
                "upscale": (
                    "FLOAT",
                    {
                        "default": 1.5,
                        "min": 1.0,
                        "max": 2.0,
                        "step": 0.01,
                        "display": "number",
                    },
                ),
            },
        }

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "upscale"
    CATEGORY = "latent"

    def upscale(self, latent, version, upscale):
        device = model_management.get_torch_device()
        samples = latent["samples"].to(device=device, dtype=self.dtype)

        # Load (or reload) the resizer weights when the requested version changes.
        if version != self.version:
            self.model = LatentResizer.load_model(
                self.weight_path[version], device, self.dtype
            )
            self.version = version

        self.model.to(device=device)

        # Scale latents into the resizer's expected range, upscale, then undo the scaling.
        latent_out = (
            self.model(self.scale_factor * samples, scale=upscale) / self.scale_factor
        )

        if self.dtype != torch.float32:
            latent_out = latent_out.to(dtype=torch.float32)

        latent_out = latent_out.to(device="cpu")

        # Offload the resizer alongside the VAE to free VRAM between calls.
        self.model.to(device=model_management.vae_offload_device())
        return ({"samples": latent_out},)