Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved output types for the Resolutions node #2937

Merged
merged 1 commit into from
Jun 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 20 additions & 6 deletions backend/src/nodes/properties/inputs/generic_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,13 +258,11 @@ def __init__(
for variant in enum:
value = variant.value
assert isinstance(value, (int, str))
assert (
re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", variant.name) is not None
), f"Expected the name of {enum.__name__}.{variant.name} to be snake case."

name = split_snake_case(variant.name)
variant_type = f"{type_name}::{join_pascal_case(name)}"
option_label = option_labels.get(variant, join_space_case(name))
variant_type = EnumInput.get_variant_type(variant, type_name)
option_label = option_labels.get(
variant, join_space_case(split_snake_case(variant.name))
)
condition = conditions.get(variant)
if condition is not None:
condition = condition.to_json()
Expand Down Expand Up @@ -301,6 +299,22 @@ def __init__(

self.associated_type = enum

@staticmethod
def get_variant_type(variant: Enum, type_name: str | None = None) -> str:
"""
Returns the full type name of a variant of an enum.
"""

enum = variant.__class__
if type_name is None:
type_name = enum.__name__

assert (
re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", variant.name) is not None
), f"Expected the name of {enum.__name__}.{variant.name} to be snake case."

return f"{type_name}::{join_pascal_case(split_snake_case(variant.name))}"

def enforce(self, value: object) -> E:
value = super().enforce(value)
return self.enum(value)
Expand Down
51 changes: 32 additions & 19 deletions backend/src/packages/chaiNNer_standard/utility/value/resolutions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from enum import Enum

import navi
from nodes.groups import if_enum_group
from nodes.properties.inputs import EnumInput, NumberInput
from nodes.properties.outputs import NumberOutput
Expand Down Expand Up @@ -96,29 +97,41 @@ class ResList(Enum):
ResList.SQ8192: "Square 8192x8192",
ResList.CUSTOM: "Custom Resolution",
},
),
).with_id(0),
if_enum_group(0, ResList.CUSTOM)(
NumberInput(
"Width",
min=1,
max=None,
default=1920,
unit="px",
has_handle=False,
),
NumberInput(
"Height",
min=1,
max=None,
default=1080,
unit="px",
has_handle=False,
),
NumberInput("Width", min=1, default=1920, unit="px"),
NumberInput("Height", min=1, default=1080, unit="px"),
),
],
outputs=[
NumberOutput("Width", output_type="int(1..)"),
NumberOutput("Height", output_type="int(1..)"),
NumberOutput(
"Width",
output_type=navi.match(
"Input0",
(EnumInput.get_variant_type(ResList.CUSTOM), None, "Input1"),
default=navi.match(
"Input0",
*(
(EnumInput.get_variant_type(v), None, w)
for v, (w, _) in RESOLUTIONS.items()
),
),
),
),
NumberOutput(
"Height",
output_type=navi.match(
"Input0",
(EnumInput.get_variant_type(ResList.CUSTOM), None, "Input2"),
default=navi.match(
"Input0",
*(
(EnumInput.get_variant_type(v), None, h)
for v, (_, h) in RESOLUTIONS.items()
),
),
),
),
],
)
def resolutions_node(
Expand Down
Loading