diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py index 44471eab3cd..79c7740485b 100644 --- a/invokeai/app/api_app.py +++ b/invokeai/app/api_app.py @@ -1,11 +1,8 @@ -import sys -from typing import Any - -from fastapi.responses import HTMLResponse - # parse_args() must be called before any other imports. if it is not called first, consumers of the config # which are imported/used before parse_args() is called will get the default config values instead of the # values from the command line or config file. +import sys + from invokeai.version.invokeai_version import __version__ from .services.config import InvokeAIAppConfig @@ -22,6 +19,7 @@ import socket from inspect import signature from pathlib import Path + from typing import Any import uvicorn from fastapi import FastAPI @@ -29,7 +27,7 @@ from fastapi.middleware.gzip import GZipMiddleware from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html from fastapi.openapi.utils import get_openapi - from fastapi.responses import FileResponse + from fastapi.responses import FileResponse, HTMLResponse from fastapi.staticfiles import StaticFiles from fastapi_events.handlers.local import local_handler from fastapi_events.middleware import EventHandlerASGIMiddleware @@ -58,9 +56,9 @@ from .api.sockets import SocketIO from .invocations.baseinvocation import ( BaseInvocation, + InputFieldJSONSchemaExtra, + OutputFieldJSONSchemaExtra, UIConfigBase, - _InputField, - _OutputField, ) if is_mps_available(): @@ -157,7 +155,11 @@ def custom_openapi() -> dict[str, Any]: # Add Node Editor UI helper schemas ui_config_schemas = models_json_schema( - [(UIConfigBase, "serialization"), (_InputField, "serialization"), (_OutputField, "serialization")], + [ + (UIConfigBase, "serialization"), + (InputFieldJSONSchemaExtra, "serialization"), + (OutputFieldJSONSchemaExtra, "serialization"), + ], ref_template="#/components/schemas/{model}", ) for schema_key, ui_config_schema in ui_config_schemas[1]["$defs"].items(): @@ -165,7 +167,7 @@ def custom_openapi() -> dict[str, Any]: # Add a reference to the output type to additionalProperties of the invoker schema for invoker in all_invocations: - invoker_name = invoker.__name__ + invoker_name = invoker.__name__ # type: ignore [attr-defined] # this is a valid attribute output_type = signature(obj=invoker.invoke).return_annotation output_type_title = output_type_titles[output_type.__name__] invoker_schema = openapi_schema["components"]["schemas"][f"{invoker_name}"] diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py index 1b3e535d340..cddbd071deb 100644 --- a/invokeai/app/invocations/baseinvocation.py +++ b/invokeai/app/invocations/baseinvocation.py @@ -17,11 +17,15 @@ from invokeai.app.services.config.config_default import InvokeAIAppConfig from invokeai.app.shared.fields import FieldDescriptions +from invokeai.app.util.metaenum import MetaEnum from invokeai.app.util.misc import uuid_string +from invokeai.backend.util.logging import InvokeAILogger if TYPE_CHECKING: from ..services.invocation_services import InvocationServices +logger = InvokeAILogger.get_logger() + class InvalidVersionError(ValueError): pass @@ -31,7 +35,7 @@ class InvalidFieldError(TypeError): pass -class Input(str, Enum): +class Input(str, Enum, metaclass=MetaEnum): """ The type of input a field accepts. - `Input.Direct`: The field must have its value provided directly, when the invocation and field \ @@ -45,86 +49,120 @@ class Input(str, Enum): Any = "any" -class UIType(str, Enum): +class FieldKind(str, Enum, metaclass=MetaEnum): """ - Type hints for the UI. - If a field should be provided a data type that does not exactly match the python type of the field, \ - use this to provide the type that should be used instead. See the node development docs for detail \ - on adding a new field type, which involves client-side changes. + The kind of field. + - `Input`: An input field on a node. + - `Output`: An output field on a node. + - `Internal`: A field which is treated as an input, but cannot be used in node definitions. Metadata is + one example. It is provided to nodes via the WithMetadata class, and we want to reserve the field name + "metadata" for this on all nodes. `FieldKind` is used to short-circuit the field name validation logic, + allowing "metadata" for that field. + - `NodeAttribute`: The field is a node attribute. These are fields which are not inputs or outputs, + but which are used to store information about the node. For example, the `id` and `type` fields are node + attributes. + + The presence of this in `json_schema_extra["field_kind"]` is used when initializing node schemas on app + startup, and when generating the OpenAPI schema for the workflow editor. """ - # region Primitives - Boolean = "boolean" - Color = "ColorField" - Conditioning = "ConditioningField" - Control = "ControlField" - Float = "float" - Image = "ImageField" - Integer = "integer" - Latents = "LatentsField" - String = "string" - # endregion + Input = "input" + Output = "output" + Internal = "internal" + NodeAttribute = "node_attribute" - # region Collection Primitives - BooleanCollection = "BooleanCollection" - ColorCollection = "ColorCollection" - ConditioningCollection = "ConditioningCollection" - ControlCollection = "ControlCollection" - FloatCollection = "FloatCollection" - ImageCollection = "ImageCollection" - IntegerCollection = "IntegerCollection" - LatentsCollection = "LatentsCollection" - StringCollection = "StringCollection" - # endregion - # region Polymorphic Primitives - BooleanPolymorphic = "BooleanPolymorphic" - ColorPolymorphic = "ColorPolymorphic" - ConditioningPolymorphic = "ConditioningPolymorphic" - ControlPolymorphic = "ControlPolymorphic" - FloatPolymorphic = "FloatPolymorphic" - ImagePolymorphic = "ImagePolymorphic" - IntegerPolymorphic = "IntegerPolymorphic" - LatentsPolymorphic = "LatentsPolymorphic" - StringPolymorphic = "StringPolymorphic" - # endregion +class UIType(str, Enum, metaclass=MetaEnum): + """ + Type hints for the UI for situations in which the field type is not enough to infer the correct UI type. + + - Model Fields + The most common node-author-facing use will be for model fields. Internally, there is no difference + between SD-1, SD-2 and SDXL model fields - they all use the class `MainModelField`. To ensure the + base-model-specific UI is rendered, use e.g. `ui_type=UIType.SDXLMainModelField` to indicate that + the field is an SDXL main model field. + + - Any Field + We cannot infer the usage of `typing.Any` via schema parsing, so you *must* use `ui_type=UIType.Any` to + indicate that the field accepts any type. Use with caution. This cannot be used on outputs. + + - Scheduler Field + Special handling in the UI is needed for this field, which otherwise would be parsed as a plain enum field. - # region Models - MainModel = "MainModelField" + - Internal Fields + Similar to the Any Field, the `collect` and `iterate` nodes make use of `typing.Any`. To facilitate + handling these types in the client, we use `UIType._Collection` and `UIType._CollectionItem`. These + should not be used by node authors. + """ + + # region Model Field Types SDXLMainModel = "SDXLMainModelField" SDXLRefinerModel = "SDXLRefinerModelField" ONNXModel = "ONNXModelField" - VaeModel = "VaeModelField" + VaeModel = "VAEModelField" LoRAModel = "LoRAModelField" ControlNetModel = "ControlNetModelField" IPAdapterModel = "IPAdapterModelField" - UNet = "UNetField" - Vae = "VaeField" - CLIP = "ClipField" # endregion - # region Iterate/Collect - Collection = "Collection" - CollectionItem = "CollectionItem" + # region Misc Field Types + Scheduler = "SchedulerField" + Any = "AnyField" + # endregion + + # region Internal Field Types + _Collection = "CollectionField" + _CollectionItem = "CollectionItemField" # endregion - # region Misc - Enum = "enum" - Scheduler = "Scheduler" - WorkflowField = "WorkflowField" - IsIntermediate = "IsIntermediate" - BoardField = "BoardField" - Any = "Any" - MetadataItem = "MetadataItem" - MetadataItemCollection = "MetadataItemCollection" - MetadataItemPolymorphic = "MetadataItemPolymorphic" - MetadataDict = "MetadataDict" + # region DEPRECATED + Boolean = "DEPRECATED_Boolean" + Color = "DEPRECATED_Color" + Conditioning = "DEPRECATED_Conditioning" + Control = "DEPRECATED_Control" + Float = "DEPRECATED_Float" + Image = "DEPRECATED_Image" + Integer = "DEPRECATED_Integer" + Latents = "DEPRECATED_Latents" + String = "DEPRECATED_String" + BooleanCollection = "DEPRECATED_BooleanCollection" + ColorCollection = "DEPRECATED_ColorCollection" + ConditioningCollection = "DEPRECATED_ConditioningCollection" + ControlCollection = "DEPRECATED_ControlCollection" + FloatCollection = "DEPRECATED_FloatCollection" + ImageCollection = "DEPRECATED_ImageCollection" + IntegerCollection = "DEPRECATED_IntegerCollection" + LatentsCollection = "DEPRECATED_LatentsCollection" + StringCollection = "DEPRECATED_StringCollection" + BooleanPolymorphic = "DEPRECATED_BooleanPolymorphic" + ColorPolymorphic = "DEPRECATED_ColorPolymorphic" + ConditioningPolymorphic = "DEPRECATED_ConditioningPolymorphic" + ControlPolymorphic = "DEPRECATED_ControlPolymorphic" + FloatPolymorphic = "DEPRECATED_FloatPolymorphic" + ImagePolymorphic = "DEPRECATED_ImagePolymorphic" + IntegerPolymorphic = "DEPRECATED_IntegerPolymorphic" + LatentsPolymorphic = "DEPRECATED_LatentsPolymorphic" + StringPolymorphic = "DEPRECATED_StringPolymorphic" + MainModel = "DEPRECATED_MainModel" + UNet = "DEPRECATED_UNet" + Vae = "DEPRECATED_Vae" + CLIP = "DEPRECATED_CLIP" + Collection = "DEPRECATED_Collection" + CollectionItem = "DEPRECATED_CollectionItem" + Enum = "DEPRECATED_Enum" + WorkflowField = "DEPRECATED_WorkflowField" + IsIntermediate = "DEPRECATED_IsIntermediate" + BoardField = "DEPRECATED_BoardField" + MetadataItem = "DEPRECATED_MetadataItem" + MetadataItemCollection = "DEPRECATED_MetadataItemCollection" + MetadataItemPolymorphic = "DEPRECATED_MetadataItemPolymorphic" + MetadataDict = "DEPRECATED_MetadataDict" # endregion -class UIComponent(str, Enum): +class UIComponent(str, Enum, metaclass=MetaEnum): """ - The type of UI component to use for a field, used to override the default components, which are \ + The type of UI component to use for a field, used to override the default components, which are inferred from the field type. """ @@ -133,7 +171,7 @@ class UIComponent(str, Enum): Slider = "slider" -class _InputField(BaseModel): +class InputFieldJSONSchemaExtra(BaseModel): """ *DO NOT USE* This helper class is used to tell the client about our custom field attributes via OpenAPI @@ -142,12 +180,15 @@ class _InputField(BaseModel): """ input: Input - ui_hidden: bool - ui_type: Optional[UIType] - ui_component: Optional[UIComponent] - ui_order: Optional[int] - ui_choice_labels: Optional[dict[str, str]] - item_default: Optional[Any] + orig_required: bool + field_kind: FieldKind + default: Optional[Any] = None + orig_default: Optional[Any] = None + ui_hidden: bool = False + ui_type: Optional[UIType] = None + ui_component: Optional[UIComponent] = None + ui_order: Optional[int] = None + ui_choice_labels: Optional[dict[str, str]] = None model_config = ConfigDict( validate_assignment=True, @@ -155,7 +196,7 @@ class _InputField(BaseModel): ) -class _OutputField(BaseModel): +class OutputFieldJSONSchemaExtra(BaseModel): """ *DO NOT USE* This helper class is used to tell the client about our custom field attributes via OpenAPI @@ -163,6 +204,7 @@ class _OutputField(BaseModel): purpose in the backend. """ + field_kind: FieldKind ui_hidden: bool ui_type: Optional[UIType] ui_order: Optional[int] @@ -180,6 +222,7 @@ def get_type(klass: BaseModel) -> str: def InputField( # copied from pydantic's Field + # TODO: Can we support default_factory? default: Any = _Unset, default_factory: Callable[[], Any] | None = _Unset, title: str | None = _Unset, @@ -203,12 +246,11 @@ def InputField( ui_hidden: bool = False, ui_order: Optional[int] = None, ui_choice_labels: Optional[dict[str, str]] = None, - item_default: Optional[Any] = None, ) -> Any: """ Creates an input field for an invocation. - This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/1.10/usage/schema/#field-customization) \ + This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/latest/api/fields/#pydantic.fields.Field) \ that adds a few extra parameters to support graph execution and the node editor UI. :param Input input: [Input.Any] The kind of input this field requires. \ @@ -228,28 +270,59 @@ def InputField( For example, a `string` field will default to a single-line input, but you may want a multi-line textarea instead. \ For this case, you could provide `UIComponent.Textarea`. - : param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI. + :param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI. - : param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. \ + :param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. - : param bool item_default: [None] Specifies the default item value, if this is a collection input. \ - Ignored for non-collection fields. + :param dict[str, str] ui_choice_labels: [None] Specifies the labels to use for the choices in an enum field. """ - json_schema_extra_: dict[str, Any] = { - "input": input, - "ui_type": ui_type, - "ui_component": ui_component, - "ui_hidden": ui_hidden, - "ui_order": ui_order, - "item_default": item_default, - "ui_choice_labels": ui_choice_labels, - "_field_kind": "input", - } + json_schema_extra_ = InputFieldJSONSchemaExtra( + input=input, + ui_type=ui_type, + ui_component=ui_component, + ui_hidden=ui_hidden, + ui_order=ui_order, + ui_choice_labels=ui_choice_labels, + field_kind=FieldKind.Input, + orig_required=True, + ) + + """ + There is a conflict between the typing of invocation definitions and the typing of an invocation's + `invoke()` function. + + On instantiation of a node, the invocation definition is used to create the python class. At this time, + any number of fields may be optional, because they may be provided by connections. + + On calling of `invoke()`, however, those fields may be required. + + For example, consider an ResizeImageInvocation with an `image: ImageField` field. + + `image` is required during the call to `invoke()`, but when the python class is instantiated, + the field may not be present. This is fine, because that image field will be provided by a + connection from an ancestor node, which outputs an image. + + This means we want to type the `image` field as optional for the node class definition, but required + for the `invoke()` function. + + If we use `typing.Optional` in the node class definition, the field will be typed as optional in the + `invoke()` method, and we'll have to do a lot of runtime checks to ensure the field is present - or + any static type analysis tools will complain. + + To get around this, in node class definitions, we type all fields correctly for the `invoke()` function, + but secretly make them optional in `InputField()`. We also store the original required bool and/or default + value. When we call `invoke()`, we use this stored information to do an additional check on the class. + """ + + if default_factory is not _Unset and default_factory is not None: + default = default_factory() + del default_factory + logger.warn('"default_factory" is not supported, calling it now to set "default"') + # These are the args we may wish pass to the pydantic `Field()` function field_args = { "default": default, - "default_factory": default_factory, "title": title, "description": description, "pattern": pattern, @@ -266,70 +339,34 @@ def InputField( "max_length": max_length, } - """ - Invocation definitions have their fields typed correctly for their `invoke()` functions. - This typing is often more specific than the actual invocation definition requires, because - fields may have values provided only by connections. - - For example, consider an ResizeImageInvocation with an `image: ImageField` field. - - `image` is required during the call to `invoke()`, but when the python class is instantiated, - the field may not be present. This is fine, because that image field will be provided by a - an ancestor node that outputs the image. - - So we'd like to type that `image` field as `Optional[ImageField]`. If we do that, however, then - we need to handle a lot of extra logic in the `invoke()` function to check if the field has a - value or not. This is very tedious. - - Ideally, the invocation definition would be able to specify that the field is required during - invocation, but optional during instantiation. So the field would be typed as `image: ImageField`, - but when calling the `invoke()` function, we raise an error if the field is not present. - - To do this, we need to do a bit of fanagling to make the pydantic field optional, and then do - extra validation when calling `invoke()`. - - There is some additional logic here to cleaning create the pydantic field via the wrapper. - """ - - # Filter out field args not provided + # We only want to pass the args that were provided, otherwise the `Field()`` function won't work as expected provided_args = {k: v for (k, v) in field_args.items() if v is not PydanticUndefined} - if (default is not PydanticUndefined) and (default_factory is not PydanticUndefined): - raise ValueError("Cannot specify both default and default_factory") + # Because we are manually making fields optional, we need to store the original required bool for reference later + json_schema_extra_.orig_required = default is PydanticUndefined - # because we are manually making fields optional, we need to store the original required bool for reference later - if default is PydanticUndefined and default_factory is PydanticUndefined: - json_schema_extra_.update({"orig_required": True}) - else: - json_schema_extra_.update({"orig_required": False}) - - # make Input.Any and Input.Connection fields optional, providing None as a default if the field doesn't already have one - if (input is Input.Any or input is Input.Connection) and default_factory is PydanticUndefined: + # Make Input.Any and Input.Connection fields optional, providing None as a default if the field doesn't already have one + if input is Input.Any or input is Input.Connection: default_ = None if default is PydanticUndefined else default provided_args.update({"default": default_}) if default is not PydanticUndefined: - # before invoking, we'll grab the original default value and set it on the field if the field wasn't provided a value - json_schema_extra_.update({"default": default}) - json_schema_extra_.update({"orig_default": default}) - elif default is not PydanticUndefined and default_factory is PydanticUndefined: + # Before invoking, we'll check for the original default value and set it on the field if the field has no value + json_schema_extra_.default = default + json_schema_extra_.orig_default = default + elif default is not PydanticUndefined: default_ = default provided_args.update({"default": default_}) - json_schema_extra_.update({"orig_default": default_}) - elif default_factory is not PydanticUndefined: - provided_args.update({"default_factory": default_factory}) - # TODO: cannot serialize default_factory... - # json_schema_extra_.update(dict(orig_default_factory=default_factory)) + json_schema_extra_.orig_default = default_ return Field( **provided_args, - json_schema_extra=json_schema_extra_, + json_schema_extra=json_schema_extra_.model_dump(exclude_none=True), ) def OutputField( # copied from pydantic's Field default: Any = _Unset, - default_factory: Callable[[], Any] | None = _Unset, title: str | None = _Unset, description: str | None = _Unset, pattern: str | None = _Unset, @@ -362,13 +399,12 @@ def OutputField( `MainModelField`. So to ensure the base-model-specific UI is rendered, you can use \ `UIType.SDXLMainModelField` to indicate that the field is an SDXL main model field. - : param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI. \ + :param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI. \ - : param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. \ + :param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. \ """ return Field( default=default, - default_factory=default_factory, title=title, description=description, pattern=pattern, @@ -383,12 +419,12 @@ def OutputField( decimal_places=decimal_places, min_length=min_length, max_length=max_length, - json_schema_extra={ - "ui_type": ui_type, - "ui_hidden": ui_hidden, - "ui_order": ui_order, - "_field_kind": "output", - }, + json_schema_extra=OutputFieldJSONSchemaExtra( + ui_type=ui_type, + ui_hidden=ui_hidden, + ui_order=ui_order, + field_kind=FieldKind.Output, + ).model_dump(exclude_none=True), ) @@ -538,7 +574,7 @@ def get_output_type(cls) -> BaseInvocationOutput: return signature(cls.invoke).return_annotation @staticmethod - def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None: + def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel], *args, **kwargs) -> None: # Add the various UI-facing attributes to the schema. These are used to build the invocation templates. uiconfig = getattr(model_class, "UIConfig", None) if uiconfig and hasattr(uiconfig, "title"): @@ -604,15 +640,17 @@ def get_type(self) -> str: id: str = Field( default_factory=uuid_string, description="The id of this instance of an invocation. Must be unique among all instances of invocations.", - json_schema_extra={"_field_kind": "internal"}, + json_schema_extra={"field_kind": FieldKind.NodeAttribute}, ) is_intermediate: bool = Field( default=False, description="Whether or not this is an intermediate invocation.", - json_schema_extra={"ui_type": UIType.IsIntermediate, "_field_kind": "internal"}, + json_schema_extra={"ui_type": "IsIntermediate", "field_kind": FieldKind.NodeAttribute}, ) use_cache: bool = Field( - default=True, description="Whether or not to use the cache", json_schema_extra={"_field_kind": "internal"} + default=True, + description="Whether or not to use the cache", + json_schema_extra={"field_kind": FieldKind.NodeAttribute}, ) UIConfig: ClassVar[Type[UIConfigBase]] @@ -629,12 +667,15 @@ def get_type(self) -> str: TBaseInvocation = TypeVar("TBaseInvocation", bound=BaseInvocation) -RESERVED_INPUT_FIELD_NAMES = { +RESERVED_NODE_ATTRIBUTE_FIELD_NAMES = { "id", "is_intermediate", "use_cache", "type", "workflow", +} + +RESERVED_INPUT_FIELD_NAMES = { "metadata", } @@ -653,39 +694,56 @@ def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None """ Validates the fields of an invocation or invocation output: - must not override any pydantic reserved fields + - must not end with "Collection" or "Polymorphic" as these are reserved for internal use - must be created via `InputField`, `OutputField`, or be an internal field defined in this file """ for name, field in model_fields.items(): if name in RESERVED_PYDANTIC_FIELD_NAMES: raise InvalidFieldError(f'Invalid field name "{name}" on "{model_type}" (reserved by pydantic)') - field_kind = ( - # _field_kind is defined via InputField(), OutputField() or by one of the internal fields defined in this file - field.json_schema_extra.get("_field_kind", None) if field.json_schema_extra else None - ) + if not field.annotation: + raise InvalidFieldError(f'Invalid field type "{name}" on "{model_type}" (missing annotation)') + + if not isinstance(field.json_schema_extra, dict): + raise InvalidFieldError( + f'Invalid field definition for "{name}" on "{model_type}" (missing json_schema_extra dict)' + ) + + field_kind = field.json_schema_extra.get("field_kind", None) # must have a field_kind - if field_kind is None or field_kind not in {"input", "output", "internal"}: + if not isinstance(field_kind, FieldKind): raise InvalidFieldError( f'Invalid field definition for "{name}" on "{model_type}" (maybe it\'s not an InputField or OutputField?)' ) - if field_kind == "input" and name in RESERVED_INPUT_FIELD_NAMES: + if field_kind is FieldKind.Input and ( + name in RESERVED_NODE_ATTRIBUTE_FIELD_NAMES or name in RESERVED_INPUT_FIELD_NAMES + ): raise InvalidFieldError(f'Invalid field name "{name}" on "{model_type}" (reserved input field name)') - if field_kind == "output" and name in RESERVED_OUTPUT_FIELD_NAMES: + if field_kind is FieldKind.Output and name in RESERVED_OUTPUT_FIELD_NAMES: raise InvalidFieldError(f'Invalid field name "{name}" on "{model_type}" (reserved output field name)') - # internal fields *must* be in the reserved list + if (field_kind is FieldKind.Internal) and name not in RESERVED_INPUT_FIELD_NAMES: + raise InvalidFieldError( + f'Invalid field name "{name}" on "{model_type}" (internal field without reserved name)' + ) + + # node attribute fields *must* be in the reserved list if ( - field_kind == "internal" - and name not in RESERVED_INPUT_FIELD_NAMES + field_kind is FieldKind.NodeAttribute + and name not in RESERVED_NODE_ATTRIBUTE_FIELD_NAMES and name not in RESERVED_OUTPUT_FIELD_NAMES ): raise InvalidFieldError( - f'Invalid field name "{name}" on "{model_type}" (internal field without reserved name)' + f'Invalid field name "{name}" on "{model_type}" (node attribute field without reserved name)' ) + ui_type = field.json_schema_extra.get("ui_type", None) + if isinstance(ui_type, str) and ui_type.startswith("DEPRECATED_"): + logger.warn(f"\"UIType.{ui_type.split('_')[-1]}\" is deprecated, ignoring") + field.json_schema_extra.pop("ui_type") return None @@ -749,7 +807,7 @@ def wrapper(cls: Type[TBaseInvocation]) -> Type[TBaseInvocation]: invocation_type_annotation = Literal[invocation_type] # type: ignore invocation_type_field = Field( - title="type", default=invocation_type, json_schema_extra={"_field_kind": "internal"} + title="type", default=invocation_type, json_schema_extra={"field_kind": FieldKind.NodeAttribute} ) docstring = cls.__doc__ @@ -795,7 +853,9 @@ def wrapper(cls: Type[TBaseInvocationOutput]) -> Type[TBaseInvocationOutput]: # Add the output type to the model. output_type_annotation = Literal[output_type] # type: ignore - output_type_field = Field(title="type", default=output_type, json_schema_extra={"_field_kind": "internal"}) + output_type_field = Field( + title="type", default=output_type, json_schema_extra={"field_kind": FieldKind.NodeAttribute} + ) docstring = cls.__doc__ cls = create_model( @@ -827,7 +887,7 @@ class WorkflowField(RootModel): class WithWorkflow(BaseModel): workflow: Optional[WorkflowField] = Field( - default=None, description=FieldDescriptions.workflow, json_schema_extra={"_field_kind": "internal"} + default=None, description=FieldDescriptions.workflow, json_schema_extra={"field_kind": FieldKind.NodeAttribute} ) @@ -845,5 +905,11 @@ class MetadataField(RootModel): class WithMetadata(BaseModel): metadata: Optional[MetadataField] = Field( - default=None, description=FieldDescriptions.metadata, json_schema_extra={"_field_kind": "internal"} + default=None, + description=FieldDescriptions.metadata, + json_schema_extra=InputFieldJSONSchemaExtra( + field_kind=FieldKind.Internal, + input=Input.Connection, + orig_required=False, + ).model_dump(exclude_none=True), ) diff --git a/invokeai/app/invocations/collections.py b/invokeai/app/invocations/collections.py index f26eebe1ff5..4c7b6f94cd4 100644 --- a/invokeai/app/invocations/collections.py +++ b/invokeai/app/invocations/collections.py @@ -5,7 +5,7 @@ from pydantic import ValidationInfo, field_validator from invokeai.app.invocations.primitives import IntegerCollectionOutput -from invokeai.app.util.misc import SEED_MAX, get_random_seed +from invokeai.app.util.misc import SEED_MAX from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation @@ -55,7 +55,7 @@ def invoke(self, context: InvocationContext) -> IntegerCollectionOutput: title="Random Range", tags=["range", "integer", "random", "collection"], category="collections", - version="1.0.0", + version="1.0.1", use_cache=False, ) class RandomRangeInvocation(BaseInvocation): @@ -65,10 +65,10 @@ class RandomRangeInvocation(BaseInvocation): high: int = InputField(default=np.iinfo(np.int32).max, description="The exclusive high value") size: int = InputField(default=1, description="The number of values to generate") seed: int = InputField( + default=0, ge=0, le=SEED_MAX, description="The seed for the RNG (omit for random)", - default_factory=get_random_seed, ) def invoke(self, context: InvocationContext) -> IntegerCollectionOutput: diff --git a/invokeai/app/invocations/custom_nodes/init.py b/invokeai/app/invocations/custom_nodes/init.py index c6708e95a7f..a379a35fbfe 100644 --- a/invokeai/app/invocations/custom_nodes/init.py +++ b/invokeai/app/invocations/custom_nodes/init.py @@ -39,6 +39,8 @@ logger.warn(f"Could not load {init}") continue + logger.info(f"Loading node pack {spec.name}") + module = module_from_spec(spec) sys.modules[spec.name] = module spec.loader.exec_module(module) @@ -47,5 +49,5 @@ del init, module_name - -logger.info(f"Loaded {loaded_count} modules from {Path(__file__).parent}") +if loaded_count > 0: + logger.info(f"Loaded {loaded_count} node packs from {Path(__file__).parent}") diff --git a/invokeai/app/invocations/infill.py b/invokeai/app/invocations/infill.py index 9905aa1b5ec..0822a4ce2df 100644 --- a/invokeai/app/invocations/infill.py +++ b/invokeai/app/invocations/infill.py @@ -8,7 +8,7 @@ from invokeai.app.invocations.primitives import ColorField, ImageField, ImageOutput from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin -from invokeai.app.util.misc import SEED_MAX, get_random_seed +from invokeai.app.util.misc import SEED_MAX from invokeai.backend.image_util.cv2_inpaint import cv2_inpaint from invokeai.backend.image_util.lama import LaMA from invokeai.backend.image_util.patchmatch import PatchMatch @@ -154,17 +154,17 @@ def invoke(self, context: InvocationContext) -> ImageOutput: ) -@invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint", version="1.1.0") +@invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint", version="1.1.1") class InfillTileInvocation(BaseInvocation, WithWorkflow, WithMetadata): """Infills transparent areas of an image with tiles of the image""" image: ImageField = InputField(description="The image to infill") tile_size: int = InputField(default=32, ge=1, description="The tile size (px)") seed: int = InputField( + default=0, ge=0, le=SEED_MAX, description="The seed to use for tile generation (omit for random)", - default_factory=get_random_seed, ) def invoke(self, context: InvocationContext) -> ImageOutput: diff --git a/invokeai/app/invocations/ip_adapter.py b/invokeai/app/invocations/ip_adapter.py index 485932e18dd..e0f582eab82 100644 --- a/invokeai/app/invocations/ip_adapter.py +++ b/invokeai/app/invocations/ip_adapter.py @@ -11,7 +11,6 @@ InputField, InvocationContext, OutputField, - UIType, invocation, invocation_output, ) @@ -67,7 +66,7 @@ class IPAdapterInvocation(BaseInvocation): # weight: float = InputField(default=1.0, description="The weight of the IP-Adapter.", ui_type=UIType.Float) weight: Union[float, List[float]] = InputField( - default=1, ge=-1, description="The weight given to the IP-Adapter", ui_type=UIType.Float, title="Weight" + default=1, ge=-1, description="The weight given to the IP-Adapter", title="Weight" ) begin_step_percent: float = InputField( diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 9d4afb70204..d438bcae02e 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -274,7 +274,10 @@ class DenoiseLatentsInvocation(BaseInvocation): ui_order=7, ) latents: Optional[LatentsField] = InputField( - default=None, description=FieldDescriptions.latents, input=Input.Connection + default=None, + description=FieldDescriptions.latents, + input=Input.Connection, + ui_order=4, ) denoise_mask: Optional[DenoiseMaskField] = InputField( default=None, diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index 8cce9bdb881..99dcc72999b 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -14,7 +14,6 @@ InputField, InvocationContext, OutputField, - UIType, invocation, invocation_output, ) @@ -395,7 +394,6 @@ class VaeLoaderInvocation(BaseInvocation): vae_model: VAEModelField = InputField( description=FieldDescriptions.vae_model, input=Input.Direct, - ui_type=UIType.VaeModel, title="VAE", ) diff --git a/invokeai/app/invocations/noise.py b/invokeai/app/invocations/noise.py index e975b7bf22b..b1ee91e1cdf 100644 --- a/invokeai/app/invocations/noise.py +++ b/invokeai/app/invocations/noise.py @@ -6,7 +6,7 @@ from invokeai.app.invocations.latent import LatentsField from invokeai.app.shared.fields import FieldDescriptions -from invokeai.app.util.misc import SEED_MAX, get_random_seed +from invokeai.app.util.misc import SEED_MAX from ...backend.util.devices import choose_torch_device, torch_dtype from .baseinvocation import ( @@ -83,16 +83,16 @@ def build_noise_output(latents_name: str, latents: torch.Tensor, seed: int): title="Noise", tags=["latents", "noise"], category="latents", - version="1.0.0", + version="1.0.1", ) class NoiseInvocation(BaseInvocation): """Generates latent noise.""" seed: int = InputField( + default=0, ge=0, le=SEED_MAX, description=FieldDescriptions.seed, - default_factory=get_random_seed, ) width: int = InputField( default=512, diff --git a/invokeai/app/invocations/primitives.py b/invokeai/app/invocations/primitives.py index ccfb7dcbb3e..afe8ff06d9d 100644 --- a/invokeai/app/invocations/primitives.py +++ b/invokeai/app/invocations/primitives.py @@ -62,12 +62,12 @@ def invoke(self, context: InvocationContext) -> BooleanOutput: title="Boolean Collection Primitive", tags=["primitives", "boolean", "collection"], category="primitives", - version="1.0.0", + version="1.0.1", ) class BooleanCollectionInvocation(BaseInvocation): """A collection of boolean primitive values""" - collection: list[bool] = InputField(default_factory=list, description="The collection of boolean values") + collection: list[bool] = InputField(default=[], description="The collection of boolean values") def invoke(self, context: InvocationContext) -> BooleanCollectionOutput: return BooleanCollectionOutput(collection=self.collection) @@ -111,12 +111,12 @@ def invoke(self, context: InvocationContext) -> IntegerOutput: title="Integer Collection Primitive", tags=["primitives", "integer", "collection"], category="primitives", - version="1.0.0", + version="1.0.1", ) class IntegerCollectionInvocation(BaseInvocation): """A collection of integer primitive values""" - collection: list[int] = InputField(default_factory=list, description="The collection of integer values") + collection: list[int] = InputField(default=[], description="The collection of integer values") def invoke(self, context: InvocationContext) -> IntegerCollectionOutput: return IntegerCollectionOutput(collection=self.collection) @@ -158,12 +158,12 @@ def invoke(self, context: InvocationContext) -> FloatOutput: title="Float Collection Primitive", tags=["primitives", "float", "collection"], category="primitives", - version="1.0.0", + version="1.0.1", ) class FloatCollectionInvocation(BaseInvocation): """A collection of float primitive values""" - collection: list[float] = InputField(default_factory=list, description="The collection of float values") + collection: list[float] = InputField(default=[], description="The collection of float values") def invoke(self, context: InvocationContext) -> FloatCollectionOutput: return FloatCollectionOutput(collection=self.collection) @@ -205,12 +205,12 @@ def invoke(self, context: InvocationContext) -> StringOutput: title="String Collection Primitive", tags=["primitives", "string", "collection"], category="primitives", - version="1.0.0", + version="1.0.1", ) class StringCollectionInvocation(BaseInvocation): """A collection of string primitive values""" - collection: list[str] = InputField(default_factory=list, description="The collection of string values") + collection: list[str] = InputField(default=[], description="The collection of string values") def invoke(self, context: InvocationContext) -> StringCollectionOutput: return StringCollectionOutput(collection=self.collection) @@ -467,13 +467,13 @@ def invoke(self, context: InvocationContext) -> ConditioningOutput: title="Conditioning Collection Primitive", tags=["primitives", "conditioning", "collection"], category="primitives", - version="1.0.0", + version="1.0.1", ) class ConditioningCollectionInvocation(BaseInvocation): """A collection of conditioning tensor primitive values""" collection: list[ConditioningField] = InputField( - default_factory=list, + default=[], description="The collection of conditioning tensors", ) diff --git a/invokeai/app/invocations/t2i_adapter.py b/invokeai/app/invocations/t2i_adapter.py index 8ff8ca762c2..2412a000798 100644 --- a/invokeai/app/invocations/t2i_adapter.py +++ b/invokeai/app/invocations/t2i_adapter.py @@ -9,7 +9,6 @@ InputField, InvocationContext, OutputField, - UIType, invocation, invocation_output, ) @@ -59,7 +58,7 @@ class T2IAdapterInvocation(BaseInvocation): ui_order=-1, ) weight: Union[float, list[float]] = InputField( - default=1, ge=0, description="The weight given to the T2I-Adapter", ui_type=UIType.Float, title="Weight" + default=1, ge=0, description="The weight given to the T2I-Adapter", title="Weight" ) begin_step_percent: float = InputField( default=0, ge=-1, le=2, description="When the T2I-Adapter is first applied (% of total steps)" diff --git a/invokeai/app/services/shared/graph.py b/invokeai/app/services/shared/graph.py index 29af1e2333c..ee86ef17c65 100644 --- a/invokeai/app/services/shared/graph.py +++ b/invokeai/app/services/shared/graph.py @@ -205,7 +205,7 @@ class IterateInvocationOutput(BaseInvocationOutput): """Used to connect iteration outputs. Will be expanded to a specific output.""" item: Any = OutputField( - description="The item being iterated over", title="Collection Item", ui_type=UIType.CollectionItem + description="The item being iterated over", title="Collection Item", ui_type=UIType._CollectionItem ) @@ -215,7 +215,7 @@ class IterateInvocation(BaseInvocation): """Iterates over a list of items""" collection: list[Any] = InputField( - description="The list of items to iterate over", default_factory=list, ui_type=UIType.Collection + description="The list of items to iterate over", default=[], ui_type=UIType._Collection ) index: int = InputField(description="The index, will be provided on executed iterators", default=0, ui_hidden=True) @@ -227,7 +227,7 @@ def invoke(self, context: InvocationContext) -> IterateInvocationOutput: @invocation_output("collect_output") class CollectInvocationOutput(BaseInvocationOutput): collection: list[Any] = OutputField( - description="The collection of input items", title="Collection", ui_type=UIType.Collection + description="The collection of input items", title="Collection", ui_type=UIType._Collection ) @@ -238,12 +238,12 @@ class CollectInvocation(BaseInvocation): item: Optional[Any] = InputField( default=None, description="The item to collect (all inputs must be of the same type)", - ui_type=UIType.CollectionItem, + ui_type=UIType._CollectionItem, title="Collection Item", input=Input.Connection, ) collection: list[Any] = InputField( - description="The collection, will be provided on execution", default_factory=list, ui_hidden=True + description="The collection, will be provided on execution", default=[], ui_hidden=True ) def invoke(self, context: InvocationContext) -> CollectInvocationOutput: diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 9951b21cd8c..faa870bd326 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -805,6 +805,8 @@ "clipField": "Clip", "clipFieldDescription": "Tokenizer and text_encoder submodels.", "collection": "Collection", + "collectionFieldType": "{{name}} Collection", + "polymorphicFieldType": "{{name}} Polymorphic", "collectionDescription": "TODO", "collectionItem": "Collection Item", "collectionItemDescription": "TODO", @@ -891,10 +893,15 @@ "mainModelField": "Model", "mainModelFieldDescription": "TODO", "maybeIncompatible": "May be Incompatible With Installed", - "mismatchedVersion": "Has Mismatched Version", + "mismatchedVersion": "Invalid node: node {{node}} of type {{type}} has mismatched version (try updating?)", "missingCanvaInitImage": "Missing canvas init image", "missingCanvaInitMaskImages": "Missing canvas init and mask images", - "missingTemplate": "Missing Template", + "missingTemplate": "Invalid node: node {{node}} of type {{type}} missing template (not installed?)", + "sourceNodeDoesNotExist": "Invalid edge: source/output node {{node}} does not exist", + "targetNodeDoesNotExist": "Invalid edge: target/input node {{node}} does not exist", + "sourceNodeFieldDoesNotExist": "Invalid edge: source/output field {{node}}.{{field}} does not exist", + "targetNodeFieldDoesNotExist": "Invalid edge: target/input field {{node}}.{{field}} does not exist", + "deletedInvalidEdge": "Deleted invalid edge {{source}} -> {{target}}", "noConnectionData": "No connection data", "noConnectionInProgress": "No connection in progress", "node": "Node", @@ -954,10 +961,17 @@ "stringDescription": "Strings are text.", "stringPolymorphic": "String Polymorphic", "stringPolymorphicDescription": "A collection of strings.", - "unableToLoadWorkflow": "Unable to Validate Workflow", + "unableToLoadWorkflow": "Unable to Load Workflow", "unableToParseEdge": "Unable to parse edge", "unableToParseNode": "Unable to parse node", + "unableToUpdateNode": "Unable to update node", "unableToValidateWorkflow": "Unable to Validate Workflow", + "unknownErrorValidatingWorkflow": "Unknown error validating workflow", + "inputFieldTypeParseError": "Unable to parse type of input field {{node}}.{{field}} ({{message}})", + "outputFieldTypeParseError": "Unable to parse type of output field {{node}}.{{field}} ({{message}})", + "unableToExtractSchemaNameFromRef": "unable to extract schema name from ref", + "unsupportedArrayItemType": "unsupported array item type \"{{type}}\"", + "unableToParseFieldType": "unable to parse field type", "uNetField": "UNet", "uNetFieldDescription": "UNet submodel.", "unhandledInputProperty": "Unhandled input property", @@ -971,8 +985,9 @@ "unkownInvocation": "Unknown Invocation type", "unknownOutput": "Unknown output", "updateNode": "Update Node", - "updateAllNodes": "Update All Nodes", "updateApp": "Update App", + "updateAllNodes": "Update All Nodes", + "allNodesUpdated": "All Nodes Updated", "unableToUpdateNodes_one": "Unable to update {{count}} node", "unableToUpdateNodes_other": "Unable to update {{count}} nodes", "vaeField": "Vae", @@ -981,6 +996,8 @@ "vaeModelFieldDescription": "TODO", "validateConnections": "Validate Connections and Graph", "validateConnectionsHelp": "Prevent invalid connections from being made, and invalid graphs from being invoked", + "unableToGetWorkflowVersion": "Unable to get workflow schema version", + "unrecognizedWorkflowVersion": "Unrecognized workflow schema version {{version}}", "version": "Version", "versionUnknown": " Version Unknown", "workflow": "Workflow", diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts index 9c1727fc79d..4a41cb3db67 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts @@ -71,7 +71,7 @@ import { addSocketUnsubscribedEventListener as addSocketUnsubscribedListener } f import { addStagingAreaImageSavedListener } from './listeners/stagingAreaImageSaved'; import { addTabChangedListener } from './listeners/tabChanged'; import { addUpscaleRequestedListener } from './listeners/upscaleRequested'; -import { addWorkflowLoadedListener } from './listeners/workflowLoaded'; +import { addWorkflowLoadRequestedListener } from './listeners/workflowLoadRequested'; import { addUpdateAllNodesRequestedListener } from './listeners/updateAllNodesRequested'; export const listenerMiddleware = createListenerMiddleware(); @@ -178,7 +178,7 @@ addBoardIdSelectedListener(); addReceivedOpenAPISchemaListener(); // Workflows -addWorkflowLoadedListener(); +addWorkflowLoadRequestedListener(); addUpdateAllNodesRequestedListener(); // DND diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlNetImageProcessed.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlNetImageProcessed.ts index 1996ec99a52..0966a8c86b6 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlNetImageProcessed.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlNetImageProcessed.ts @@ -12,10 +12,10 @@ import { addToast } from 'features/system/store/systemSlice'; import { t } from 'i18next'; import { imagesApi } from 'services/api/endpoints/images'; import { queueApi } from 'services/api/endpoints/queue'; -import { isImageOutput } from 'services/api/guards'; import { BatchConfig, ImageDTO } from 'services/api/types'; import { socketInvocationComplete } from 'services/events/actions'; import { startAppListening } from '..'; +import { isImageOutput } from 'features/nodes/types/common'; export const addControlNetImageProcessedListener = () => { startAppListening({ diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts index bd5422841f8..f23b7284fe9 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts @@ -5,19 +5,20 @@ import { controlAdapterProcessedImageChanged, selectControlAdapterAll, } from 'features/controlAdapters/store/controlAdaptersSlice'; +import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types'; import { imageDeletionConfirmed } from 'features/deleteImageModal/store/actions'; import { isModalOpenChanged } from 'features/deleteImageModal/store/slice'; import { selectListImagesBaseQueryArgs } from 'features/gallery/store/gallerySelectors'; import { imageSelected } from 'features/gallery/store/gallerySlice'; import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice'; -import { isInvocationNode } from 'features/nodes/types/types'; +import { isImageFieldInputInstance } from 'features/nodes/types/field'; +import { isInvocationNode } from 'features/nodes/types/invocation'; import { clearInitialImage } from 'features/parameters/store/generationSlice'; import { clamp, forEach } from 'lodash-es'; import { api } from 'services/api'; import { imagesApi } from 'services/api/endpoints/images'; import { imagesAdapter } from 'services/api/util'; import { startAppListening } from '..'; -import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types'; export const addRequestedSingleImageDeletionListener = () => { startAppListening({ @@ -121,7 +122,7 @@ export const addRequestedSingleImageDeletionListener = () => { forEach(node.data.inputs, (input) => { if ( - input.type === 'ImageField' && + isImageFieldInputInstance(input) && input.value?.image_name === imageDTO.image_name ) { dispatch( @@ -241,7 +242,7 @@ export const addRequestedMultipleImageDeletionListener = () => { forEach(node.data.inputs, (input) => { if ( - input.type === 'ImageField' && + isImageFieldInputInstance(input) && input.value?.image_name === imageDTO.image_name ) { dispatch( diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts index 6ed0b93e996..e4175affe6c 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts @@ -12,12 +12,12 @@ import { setWidth, vaeSelected, } from 'features/parameters/store/generationSlice'; -import { zMainOrOnnxModel } from 'features/parameters/types/parameterSchemas'; import { addToast } from 'features/system/store/systemSlice'; import { makeToast } from 'features/system/util/makeToast'; import { t } from 'i18next'; import { forEach } from 'lodash-es'; import { startAppListening } from '..'; +import { zParameterModel } from 'features/parameters/types/parameterSchemas'; export const addModelSelectedListener = () => { startAppListening({ @@ -26,7 +26,7 @@ export const addModelSelectedListener = () => { const log = logger('models'); const state = getState(); - const result = zMainOrOnnxModel.safeParse(action.payload); + const result = zParameterModel.safeParse(action.payload); if (!result.success) { log.error( diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts index 785630495bc..afb390470ba 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts @@ -11,9 +11,9 @@ import { vaeSelected, } from 'features/parameters/store/generationSlice'; import { - zMainOrOnnxModel, - zSDXLRefinerModel, - zVaeModel, + zParameterModel, + zParameterSDXLRefinerModel, + zParameterVAEModel, } from 'features/parameters/types/parameterSchemas'; import { refinerModelChanged, @@ -67,7 +67,7 @@ export const addModelsLoadedListener = () => { return; } - const result = zMainOrOnnxModel.safeParse(models[0]); + const result = zParameterModel.safeParse(models[0]); if (!result.success) { log.error( @@ -119,7 +119,7 @@ export const addModelsLoadedListener = () => { return; } - const result = zSDXLRefinerModel.safeParse(models[0]); + const result = zParameterSDXLRefinerModel.safeParse(models[0]); if (!result.success) { log.error( @@ -170,7 +170,7 @@ export const addModelsLoadedListener = () => { return; } - const result = zVaeModel.safeParse(firstModel); + const result = zParameterVAEModel.safeParse(firstModel); if (!result.success) { log.error( diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedOpenAPISchema.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedOpenAPISchema.ts index 5599913a189..f5b630a39d8 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedOpenAPISchema.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedOpenAPISchema.ts @@ -15,6 +15,7 @@ export const addReceivedOpenAPISchemaListener = () => { log.debug({ schemaJSON }, 'Received OpenAPI schema'); const { nodesAllowlist, nodesDenylist } = getState().config; + const nodeTemplates = parseSchema( schemaJSON, nodesAllowlist, diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts index cfd69ce9bc6..bc9959b8fc9 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts @@ -13,13 +13,13 @@ import { } from 'features/nodes/util/graphBuilders/constants'; import { boardsApi } from 'services/api/endpoints/boards'; import { imagesApi } from 'services/api/endpoints/images'; -import { isImageOutput } from 'services/api/guards'; import { imagesAdapter } from 'services/api/util'; import { appSocketInvocationComplete, socketInvocationComplete, } from 'services/events/actions'; import { startAppListening } from '../..'; +import { isImageOutput } from 'features/nodes/types/common'; // These nodes output an image, but do not actually *save* an image, so we don't want to handle the gallery logic on them const nodeTypeDenylist = ['load_image', 'image']; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested.ts index ece6702cebb..b2383410bde 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested.ts @@ -1,14 +1,16 @@ +import { logger } from 'app/logging/logger'; +import { updateAllNodesRequested } from 'features/nodes/store/actions'; +import { nodeReplaced } from 'features/nodes/store/nodesSlice'; import { getNeedsUpdate, updateNode, -} from 'features/nodes/hooks/useNodeVersion'; -import { updateAllNodesRequested } from 'features/nodes/store/actions'; -import { nodeReplaced } from 'features/nodes/store/nodesSlice'; -import { startAppListening } from '..'; -import { logger } from 'app/logging/logger'; +} from 'features/nodes/store/util/nodeUpdate'; +import { NodeUpdateError } from 'features/nodes/types/error'; +import { isInvocationNode } from 'features/nodes/types/invocation'; import { addToast } from 'features/system/store/systemSlice'; import { makeToast } from 'features/system/util/makeToast'; import { t } from 'i18next'; +import { startAppListening } from '..'; export const addUpdateAllNodesRequestedListener = () => { startAppListening({ @@ -20,22 +22,31 @@ export const addUpdateAllNodesRequestedListener = () => { let unableToUpdateCount = 0; - nodes.forEach((node) => { + nodes.filter(isInvocationNode).forEach((node) => { const template = templates[node.data.type]; - const needsUpdate = getNeedsUpdate(node, template); - const updatedNode = updateNode(node, template); - if (!updatedNode) { - if (needsUpdate) { + if (!template) { + unableToUpdateCount++; + return; + } + if (!getNeedsUpdate(node, template)) { + // No need to increment the count here, since we're not actually updating + return; + } + try { + const updatedNode = updateNode(node, template); + dispatch(nodeReplaced({ nodeId: updatedNode.id, node: updatedNode })); + } catch (e) { + if (e instanceof NodeUpdateError) { unableToUpdateCount++; } - return; } - dispatch(nodeReplaced({ nodeId: updatedNode.id, node: updatedNode })); }); if (unableToUpdateCount) { log.warn( - `Unable to update ${unableToUpdateCount} nodes. Please report this issue.` + t('nodes.unableToUpdateNodes', { + count: unableToUpdateCount, + }) ); dispatch( addToast( @@ -46,6 +57,15 @@ export const addUpdateAllNodesRequestedListener = () => { }) ) ); + } else { + dispatch( + addToast( + makeToast({ + title: t('nodes.allNodesUpdated'), + status: 'success', + }) + ) + ); } }, }); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested.ts new file mode 100644 index 00000000000..5336c639422 --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested.ts @@ -0,0 +1,105 @@ +import { logger } from 'app/logging/logger'; +import { parseify } from 'common/util/serialize'; +import { workflowLoadRequested } from 'features/nodes/store/actions'; +import { workflowLoaded } from 'features/nodes/store/nodesSlice'; +import { $flow } from 'features/nodes/store/reactFlowInstance'; +import { WorkflowVersionError } from 'features/nodes/types/error'; +import { validateWorkflow } from 'features/nodes/util/validateWorkflow'; +import { addToast } from 'features/system/store/systemSlice'; +import { makeToast } from 'features/system/util/makeToast'; +import { setActiveTab } from 'features/ui/store/uiSlice'; +import { t } from 'i18next'; +import { z } from 'zod'; +import { fromZodError } from 'zod-validation-error'; +import { startAppListening } from '..'; + +export const addWorkflowLoadRequestedListener = () => { + startAppListening({ + actionCreator: workflowLoadRequested, + effect: (action, { dispatch, getState }) => { + const log = logger('nodes'); + const workflow = action.payload; + const nodeTemplates = getState().nodes.nodeTemplates; + + try { + const { workflow: validatedWorkflow, warnings } = validateWorkflow( + workflow, + nodeTemplates + ); + dispatch(workflowLoaded(validatedWorkflow)); + if (!warnings.length) { + dispatch( + addToast( + makeToast({ + title: t('toast.workflowLoaded'), + status: 'success', + }) + ) + ); + } else { + dispatch( + addToast( + makeToast({ + title: t('toast.loadedWithWarnings'), + status: 'warning', + }) + ) + ); + warnings.forEach(({ message, ...rest }) => { + log.warn(rest, message); + }); + } + + dispatch(setActiveTab('nodes')); + requestAnimationFrame(() => { + $flow.get()?.fitView(); + }); + } catch (e) { + if (e instanceof WorkflowVersionError) { + // The workflow version was not recognized in the valid list of versions + log.error({ error: parseify(e) }, e.message); + dispatch( + addToast( + makeToast({ + title: t('nodes.unableToValidateWorkflow'), + status: 'error', + description: e.message, + }) + ) + ); + } else if (e instanceof z.ZodError) { + // There was a problem validating the workflow itself + const { message } = fromZodError(e, { + prefix: t('nodes.workflowValidation'), + }); + log.error({ error: parseify(e) }, message); + dispatch( + addToast( + makeToast({ + title: t('nodes.unableToValidateWorkflow'), + status: 'error', + description: message, + }) + ) + ); + } else { + // Some other error occurred + console.log(e); + log.error( + { error: parseify(e) }, + t('nodes.unknownErrorValidatingWorkflow') + ); + dispatch( + addToast( + makeToast({ + title: t('nodes.unableToValidateWorkflow'), + status: 'error', + description: t('nodes.unknownErrorValidatingWorkflow'), + }) + ) + ); + } + } + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/workflowLoaded.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/workflowLoaded.ts deleted file mode 100644 index de697a70e50..00000000000 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/workflowLoaded.ts +++ /dev/null @@ -1,56 +0,0 @@ -import { logger } from 'app/logging/logger'; -import { workflowLoadRequested } from 'features/nodes/store/actions'; -import { workflowLoaded } from 'features/nodes/store/nodesSlice'; -import { $flow } from 'features/nodes/store/reactFlowInstance'; -import { validateWorkflow } from 'features/nodes/util/validateWorkflow'; -import { addToast } from 'features/system/store/systemSlice'; -import { makeToast } from 'features/system/util/makeToast'; -import { setActiveTab } from 'features/ui/store/uiSlice'; -import { startAppListening } from '..'; -import { t } from 'i18next'; - -export const addWorkflowLoadedListener = () => { - startAppListening({ - actionCreator: workflowLoadRequested, - effect: (action, { dispatch, getState }) => { - const log = logger('nodes'); - const workflow = action.payload; - const nodeTemplates = getState().nodes.nodeTemplates; - - const { workflow: validatedWorkflow, errors } = validateWorkflow( - workflow, - nodeTemplates - ); - - dispatch(workflowLoaded(validatedWorkflow)); - - if (!errors.length) { - dispatch( - addToast( - makeToast({ - title: t('toast.workflowLoaded'), - status: 'success', - }) - ) - ); - } else { - dispatch( - addToast( - makeToast({ - title: t('toast.loadedWithWarnings'), - status: 'warning', - }) - ) - ); - errors.forEach(({ message, ...rest }) => { - log.warn(rest, message); - }); - } - - dispatch(setActiveTab('nodes')); - requestAnimationFrame(() => { - $flow.get()?.fitView(); - }); - }, - }); -}; diff --git a/invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts b/invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts index 309154db506..b61dfee857c 100644 --- a/invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts +++ b/invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts @@ -4,7 +4,7 @@ import { useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { selectControlAdapterAll } from 'features/controlAdapters/store/controlAdaptersSlice'; import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types'; -import { isInvocationNode } from 'features/nodes/types/types'; +import { isInvocationNode } from 'features/nodes/types/invocation'; import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; import i18n from 'i18next'; import { forEach } from 'lodash-es'; diff --git a/invokeai/frontend/web/src/features/controlAdapters/store/controlAdaptersSlice.ts b/invokeai/frontend/web/src/features/controlAdapters/store/controlAdaptersSlice.ts index 9e293f11048..cdaba3e9a19 100644 --- a/invokeai/frontend/web/src/features/controlAdapters/store/controlAdaptersSlice.ts +++ b/invokeai/frontend/web/src/features/controlAdapters/store/controlAdaptersSlice.ts @@ -6,9 +6,9 @@ import { isAnyOf, } from '@reduxjs/toolkit'; import { - ControlNetModelParam, - IPAdapterModelParam, - T2IAdapterModelParam, + ParameterControlNetModel, + ParameterIPAdapterModel, + ParameterT2IAdapterModel, } from 'features/parameters/types/parameterSchemas'; import { cloneDeep, merge, uniq } from 'lodash-es'; import { appSocketInvocationError } from 'services/events/actions'; @@ -243,9 +243,9 @@ export const controlAdaptersSlice = createSlice({ action: PayloadAction<{ id: string; model: - | ControlNetModelParam - | T2IAdapterModelParam - | IPAdapterModelParam; + | ParameterControlNetModel + | ParameterT2IAdapterModel + | ParameterIPAdapterModel; }> ) => { const { id, model } = action.payload; diff --git a/invokeai/frontend/web/src/features/controlAdapters/store/types.ts b/invokeai/frontend/web/src/features/controlAdapters/store/types.ts index afc6df45e4d..ea63600cdd7 100644 --- a/invokeai/frontend/web/src/features/controlAdapters/store/types.ts +++ b/invokeai/frontend/web/src/features/controlAdapters/store/types.ts @@ -1,8 +1,8 @@ import { EntityState } from '@reduxjs/toolkit'; import { - ControlNetModelParam, - IPAdapterModelParam, - T2IAdapterModelParam, + ParameterControlNetModel, + ParameterIPAdapterModel, + ParameterT2IAdapterModel, } from 'features/parameters/types/parameterSchemas'; import { isObject } from 'lodash-es'; import { components } from 'services/api/schema'; @@ -378,7 +378,7 @@ export type ControlNetConfig = { type: 'controlnet'; id: string; isEnabled: boolean; - model: ControlNetModelParam | null; + model: ParameterControlNetModel | null; weight: number; beginStepPct: number; endStepPct: number; @@ -395,7 +395,7 @@ export type T2IAdapterConfig = { type: 't2i_adapter'; id: string; isEnabled: boolean; - model: T2IAdapterModelParam | null; + model: ParameterT2IAdapterModel | null; weight: number; beginStepPct: number; endStepPct: number; @@ -412,7 +412,7 @@ export type IPAdapterConfig = { id: string; isEnabled: boolean; controlImage: string | null; - model: IPAdapterModelParam | null; + model: ParameterIPAdapterModel | null; weight: number; beginStepPct: number; endStepPct: number; diff --git a/invokeai/frontend/web/src/features/deleteImageModal/store/selectors.ts b/invokeai/frontend/web/src/features/deleteImageModal/store/selectors.ts index d8e68dca21c..387d5916fa7 100644 --- a/invokeai/frontend/web/src/features/deleteImageModal/store/selectors.ts +++ b/invokeai/frontend/web/src/features/deleteImageModal/store/selectors.ts @@ -1,11 +1,12 @@ import { createSelector } from '@reduxjs/toolkit'; import { RootState } from 'app/store/store'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; -import { isInvocationNode } from 'features/nodes/types/types'; +import { isInvocationNode } from 'features/nodes/types/invocation'; import { some } from 'lodash-es'; import { ImageUsage } from './types'; import { selectControlAdapterAll } from 'features/controlAdapters/store/controlAdaptersSlice'; import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types'; +import { isImageFieldInputInstance } from 'features/nodes/types/field'; export const getImageUsage = (state: RootState, image_name: string) => { const { generation, canvas, nodes, controlAdapters } = state; @@ -19,7 +20,8 @@ export const getImageUsage = (state: RootState, image_name: string) => { return some( node.data.inputs, (input) => - input.type === 'ImageField' && input.value?.image_name === image_name + isImageFieldInputInstance(input) && + input.value?.image_name === image_name ); }); diff --git a/invokeai/frontend/web/src/features/dnd/types/index.ts b/invokeai/frontend/web/src/features/dnd/types/index.ts index f5254f8a5a8..45f325ebd1b 100644 --- a/invokeai/frontend/web/src/features/dnd/types/index.ts +++ b/invokeai/frontend/web/src/features/dnd/types/index.ts @@ -11,9 +11,9 @@ import { useDroppable as useOriginalDroppable, } from '@dnd-kit/core'; import { - InputFieldTemplate, - InputFieldValue, -} from 'features/nodes/types/types'; + FieldInputTemplate, + FieldInputInstance, +} from 'features/nodes/types/field'; import { ImageDTO } from 'services/api/types'; type BaseDropData = { @@ -93,8 +93,8 @@ export type NodeFieldDraggableData = BaseDragData & { payloadType: 'NODE_FIELD'; payload: { nodeId: string; - field: InputFieldValue; - fieldTemplate: InputFieldTemplate; + field: FieldInputInstance; + fieldTemplate: FieldInputTemplate; }; }; diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx index ce5b178fa25..537df1bd283 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx @@ -4,14 +4,14 @@ import { LoRAMetadataItem, IPAdapterMetadataItem, T2IAdapterMetadataItem, -} from 'features/nodes/types/types'; +} from 'features/nodes/types/metadata'; import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters'; import { memo, useMemo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; import { - isValidControlNetModel, - isValidLoRAModel, - isValidT2IAdapterModel, + isParameterControlNetModel, + isParameterLoRAModel, + isParameterT2IAdapterModel, } from '../../../parameters/types/parameterSchemas'; import ImageMetadataItem from './ImageMetadataItem'; @@ -132,7 +132,7 @@ const ImageMetadataActions = (props: Props) => { const validControlNets: ControlNetMetadataItem[] = useMemo(() => { return metadata?.controlnets ? metadata.controlnets.filter((controlnet) => - isValidControlNetModel(controlnet.control_model) + isParameterControlNetModel(controlnet.control_model) ) : []; }, [metadata?.controlnets]); @@ -140,7 +140,7 @@ const ImageMetadataActions = (props: Props) => { const validIPAdapters: IPAdapterMetadataItem[] = useMemo(() => { return metadata?.ipAdapters ? metadata.ipAdapters.filter((ipAdapter) => - isValidControlNetModel(ipAdapter.ip_adapter_model) + isParameterControlNetModel(ipAdapter.ip_adapter_model) ) : []; }, [metadata?.ipAdapters]); @@ -148,7 +148,7 @@ const ImageMetadataActions = (props: Props) => { const validT2IAdapters: T2IAdapterMetadataItem[] = useMemo(() => { return metadata?.t2iAdapters ? metadata.t2iAdapters.filter((t2iAdapter) => - isValidT2IAdapterModel(t2iAdapter.t2i_adapter_model) + isParameterT2IAdapterModel(t2iAdapter.t2i_adapter_model) ) : []; }, [metadata?.t2iAdapters]); @@ -157,8 +157,6 @@ const ImageMetadataActions = (props: Props) => { return null; } - console.log(metadata); - return ( <> {metadata.created_by && ( @@ -275,7 +273,7 @@ const ImageMetadataActions = (props: Props) => { )} {metadata.loras && metadata.loras.map((lora, index) => { - if (isValidLoRAModel(lora.lora)) { + if (isParameterLoRAModel(lora.lora)) { return ( { const { t } = useTranslation(); const fieldFilter = useAppSelector( - (state) => state.nodes.currentConnectionFieldType + (state) => state.nodes.connectionStartFieldType ); const handleFilter = useAppSelector( (state) => state.nodes.connectionStartParams?.handleType @@ -111,7 +110,7 @@ const AddNodePopover = () => { data.sort((a, b) => a.label.localeCompare(b.label)); - return { data, t }; + return { data }; }, defaultSelectorOptions ); @@ -121,7 +120,7 @@ const AddNodePopover = () => { const inputRef = useRef(null); const addNode = useCallback( - (nodeType: AnyInvocationType) => { + (nodeType: string) => { const invocation = buildInvocation(nodeType); if (!invocation) { const errorMessage = t('nodes.unknownNode', { @@ -145,7 +144,7 @@ const AddNodePopover = () => { return; } - addNode(v as AnyInvocationType); + addNode(v); }, [addNode] ); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/connectionLines/CustomConnectionLine.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/connectionLines/CustomConnectionLine.tsx index a379be7ee28..f3d705b3476 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/connectionLines/CustomConnectionLine.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/connectionLines/CustomConnectionLine.tsx @@ -2,18 +2,17 @@ import { createSelector } from '@reduxjs/toolkit'; import { stateSelector } from 'app/store/store'; import { useAppSelector } from 'app/store/storeHooks'; import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar'; -import { FIELDS } from 'features/nodes/types/constants'; import { memo } from 'react'; import { ConnectionLineComponentProps, getBezierPath } from 'reactflow'; +import { getFieldColor } from '../edges/util/getEdgeColor'; const selector = createSelector(stateSelector, ({ nodes }) => { - const { shouldAnimateEdges, currentConnectionFieldType, shouldColorEdges } = + const { shouldAnimateEdges, connectionStartFieldType, shouldColorEdges } = nodes; - const stroke = - currentConnectionFieldType && shouldColorEdges - ? colorTokenToCssVar(FIELDS[currentConnectionFieldType].color) - : colorTokenToCssVar('base.500'); + const stroke = shouldColorEdges + ? getFieldColor(connectionStartFieldType) + : colorTokenToCssVar('base.500'); let className = 'react-flow__custom_connection-path'; diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/getEdgeColor.ts b/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/getEdgeColor.ts new file mode 100644 index 00000000000..15c63b0bae8 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/getEdgeColor.ts @@ -0,0 +1,12 @@ +import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar'; +import { FIELD_COLORS } from 'features/nodes/types/constants'; +import { FieldType } from 'features/nodes/types/field'; + +export const getFieldColor = (fieldType: FieldType | null): string => { + if (!fieldType) { + return colorTokenToCssVar('base.500'); + } + const color = FIELD_COLORS[fieldType.name]; + + return color ? colorTokenToCssVar(color) : colorTokenToCssVar('base.500'); +}; diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/makeEdgeSelector.ts b/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/makeEdgeSelector.ts index b5dc484eaea..73d3d5dc4d7 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/makeEdgeSelector.ts +++ b/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/makeEdgeSelector.ts @@ -2,8 +2,8 @@ import { createSelector } from '@reduxjs/toolkit'; import { stateSelector } from 'app/store/store'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar'; -import { FIELDS } from 'features/nodes/types/constants'; -import { isInvocationNode } from 'features/nodes/types/types'; +import { isInvocationNode } from 'features/nodes/types/invocation'; +import { getFieldColor } from './getEdgeColor'; export const makeEdgeSelector = ( source: string, @@ -29,7 +29,7 @@ export const makeEdgeSelector = ( const stroke = sourceType && nodes.shouldColorEdges - ? colorTokenToCssVar(FIELDS[sourceType].color) + ? getFieldColor(sourceType) : colorTokenToCssVar('base.500'); return { diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeCollapsedHandles.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeCollapsedHandles.tsx index 30e02bfd849..b1ca6ac22fb 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeCollapsedHandles.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeCollapsedHandles.tsx @@ -1,7 +1,7 @@ import { useColorModeValue } from '@chakra-ui/react'; import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens'; import { useNodeData } from 'features/nodes/hooks/useNodeData'; -import { isInvocationNodeData } from 'features/nodes/types/types'; +import { isInvocationNodeData } from 'features/nodes/types/invocation'; import { map } from 'lodash-es'; import { CSSProperties, memo, useMemo } from 'react'; import { Handle, Position } from 'reactflow'; diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeInfoIcon.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeInfoIcon.tsx index 83867a35cb2..a439538075c 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeInfoIcon.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeInfoIcon.tsx @@ -2,8 +2,8 @@ import { Flex, Icon, Text, Tooltip } from '@chakra-ui/react'; import { compare } from 'compare-versions'; import { useNodeData } from 'features/nodes/hooks/useNodeData'; import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate'; -import { useNodeVersion } from 'features/nodes/hooks/useNodeVersion'; -import { isInvocationNodeData } from 'features/nodes/types/types'; +import { useNodeNeedsUpdate } from 'features/nodes/hooks/useNodeNeedsUpdate'; +import { isInvocationNodeData } from 'features/nodes/types/invocation'; import { memo, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { FaInfoCircle } from 'react-icons/fa'; @@ -13,7 +13,7 @@ interface Props { } const InvocationNodeInfoIcon = ({ nodeId }: Props) => { - const { needsUpdate } = useNodeVersion(nodeId); + const needsUpdate = useNodeNeedsUpdate(nodeId); return ( { const { status, progress, progressImage } = nodeExecutionState; const { t } = useTranslation(); - if (status === NodeStatus.PENDING) { + if (status === zNodeStatus.enum.PENDING) { return {t('queue.pending')}; } - if (status === NodeStatus.IN_PROGRESS) { + if (status === zNodeStatus.enum.IN_PROGRESS) { if (progressImage) { return ( @@ -108,11 +111,11 @@ const TooltipLabel = memo(({ nodeExecutionState }: TooltipLabelProps) => { return {t('nodes.executionStateInProgress')}; } - if (status === NodeStatus.COMPLETED) { + if (status === zNodeStatus.enum.COMPLETED) { return {t('nodes.executionStateCompleted')}; } - if (status === NodeStatus.FAILED) { + if (status === zNodeStatus.enum.FAILED) { return {t('nodes.executionStateError')}; } @@ -127,7 +130,7 @@ type StatusIconProps = { const StatusIcon = memo((props: StatusIconProps) => { const { progress, status } = props.nodeExecutionState; - if (status === NodeStatus.PENDING) { + if (status === zNodeStatus.enum.PENDING) { return ( { /> ); } - if (status === NodeStatus.IN_PROGRESS) { + if (status === zNodeStatus.enum.IN_PROGRESS) { return progress === null ? ( { /> ); } - if (status === NodeStatus.COMPLETED) { + if (status === zNodeStatus.enum.COMPLETED) { return ( { /> ); } - if (status === NodeStatus.FAILED) { + if (status === zNodeStatus.enum.FAILED) { return ( { ); const mayExpose = useMemo( - () => ['any', 'direct'].includes(input ?? '__UNKNOWN_INPUT__'), + () => input && ['any', 'direct'].includes(input), [input] ); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/FieldHandle.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/FieldHandle.tsx index 31665902547..a622e5018c1 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/FieldHandle.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/FieldHandle.tsx @@ -1,18 +1,17 @@ import { Tooltip } from '@chakra-ui/react'; import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar'; +import { useFieldTypeName } from 'features/nodes/hooks/usePrettyFieldType'; import { - COLLECTION_TYPES, - FIELDS, HANDLE_TOOLTIP_OPEN_DELAY, MODEL_TYPES, - POLYMORPHIC_TYPES, } from 'features/nodes/types/constants'; import { - InputFieldTemplate, - OutputFieldTemplate, -} from 'features/nodes/types/types'; + FieldInputTemplate, + FieldOutputTemplate, +} from 'features/nodes/types/field'; import { CSSProperties, memo, useMemo } from 'react'; import { Handle, HandleType, Position } from 'reactflow'; +import { getFieldColor } from '../../../edges/util/getEdgeColor'; export const handleBaseStyles: CSSProperties = { position: 'absolute', @@ -32,11 +31,11 @@ export const outputHandleStyles: CSSProperties = { }; type FieldHandleProps = { - fieldTemplate: InputFieldTemplate | OutputFieldTemplate; + fieldTemplate: FieldInputTemplate | FieldOutputTemplate; handleType: HandleType; isConnectionInProgress: boolean; isConnectionStartField: boolean; - connectionError: string | null; + connectionError?: string; }; const FieldHandle = (props: FieldHandleProps) => { @@ -47,23 +46,21 @@ const FieldHandle = (props: FieldHandleProps) => { isConnectionStartField, connectionError, } = props; - const { name, type } = fieldTemplate; - const { color: typeColor, title } = FIELDS[type]; - + const { name } = fieldTemplate; + const type = fieldTemplate.type; + const fieldTypeName = useFieldTypeName(type); const styles: CSSProperties = useMemo(() => { - const isCollectionType = COLLECTION_TYPES.includes(type); - const isPolymorphicType = POLYMORPHIC_TYPES.includes(type); - const isModelType = MODEL_TYPES.includes(type); - const color = colorTokenToCssVar(typeColor); + const isModelType = MODEL_TYPES.some((t) => t === type.name); + const color = getFieldColor(type); const s: CSSProperties = { backgroundColor: - isCollectionType || isPolymorphicType - ? 'var(--invokeai-colors-base-900)' + type.isCollection || type.isPolymorphic + ? colorTokenToCssVar('base.900') : color, position: 'absolute', width: '1rem', height: '1rem', - borderWidth: isCollectionType || isPolymorphicType ? 4 : 0, + borderWidth: type.isCollection || type.isPolymorphic ? 4 : 0, borderStyle: 'solid', borderColor: color, borderRadius: isModelType ? 4 : '100%', @@ -97,18 +94,14 @@ const FieldHandle = (props: FieldHandleProps) => { isConnectionInProgress, isConnectionStartField, type, - typeColor, ]); const tooltip = useMemo(() => { - if (isConnectionInProgress && isConnectionStartField) { - return title; - } if (isConnectionInProgress && connectionError) { - return connectionError ?? title; + return connectionError; } - return title; - }, [connectionError, isConnectionInProgress, isConnectionStartField, title]); + return fieldTypeName; + }, [connectionError, fieldTypeName, isConnectionInProgress]); return ( { - const field = useFieldData(nodeId, fieldName); + const field = useFieldInstance(nodeId, fieldName); const fieldTemplate = useFieldTemplate(nodeId, fieldName, kind); - const isInputTemplate = isInputFieldTemplate(fieldTemplate); + const isInputTemplate = isFieldInputTemplate(fieldTemplate); + const fieldTypeName = useFieldTypeName(fieldTemplate?.type); const { t } = useTranslation(); const fieldTitle = useMemo(() => { - if (isInputFieldValue(field)) { + if (isFieldInputInstance(field)) { if (field.label && fieldTemplate?.title) { return `${field.label} (${fieldTemplate.title})`; } @@ -49,9 +49,9 @@ const FieldTooltipContent = ({ nodeId, fieldName, kind }: Props) => { {fieldTemplate.description} )} - {fieldTemplate && ( + {fieldTypeName && ( - {t('parameters.type')}: {FIELDS[fieldTemplate.type].title} + {t('parameters.type')}: {fieldTypeName} )} {isInputTemplate && ( diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputField.tsx index 4a489716025..dac9404c265 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputField.tsx @@ -77,10 +77,10 @@ const InputField = ({ nodeId, fieldName }: Props) => { sx={{ display: 'flex', alignItems: 'center', - h: 'full', mb: 0, px: 1, gap: 2, + h: 'full', }} > { const { t } = useTranslation(); - const field = useFieldData(nodeId, fieldName); + const fieldInstance = useFieldInstance(nodeId, fieldName); const fieldTemplate = useFieldTemplate(nodeId, fieldName, 'input'); if (fieldTemplate?.fieldKind === 'output') { return ( - {t('nodes.outputFieldInInput')}: {field?.type} + {t('nodes.outputFieldInInput')}: {fieldInstance?.type.name} ); } if ( - (field?.type === 'string' && fieldTemplate?.type === 'string') || - (field?.type === 'StringPolymorphic' && - fieldTemplate?.type === 'StringPolymorphic') + isStringFieldInputInstance(fieldInstance) && + isStringFieldInputTemplate(fieldTemplate) ) { return ( - ); } if ( - (field?.type === 'boolean' && fieldTemplate?.type === 'boolean') || - (field?.type === 'BooleanPolymorphic' && - fieldTemplate?.type === 'BooleanPolymorphic') + isBooleanFieldInputInstance(fieldInstance) && + isBooleanFieldInputTemplate(fieldTemplate) ) { return ( - ); } if ( - (field?.type === 'integer' && fieldTemplate?.type === 'integer') || - (field?.type === 'float' && fieldTemplate?.type === 'float') || - (field?.type === 'FloatPolymorphic' && - fieldTemplate?.type === 'FloatPolymorphic') || - (field?.type === 'IntegerPolymorphic' && - fieldTemplate?.type === 'IntegerPolymorphic') + (isIntegerFieldInputInstance(fieldInstance) && + isIntegerFieldInputTemplate(fieldTemplate)) || + (isFloatFieldInputInstance(fieldInstance) && + isFloatFieldInputTemplate(fieldTemplate)) ) { return ( - ); } - if (field?.type === 'enum' && fieldTemplate?.type === 'enum') { + if ( + isEnumFieldInputInstance(fieldInstance) && + isEnumFieldInputTemplate(fieldTemplate) + ) { return ( - ); } if ( - (field?.type === 'ImageField' && fieldTemplate?.type === 'ImageField') || - (field?.type === 'ImagePolymorphic' && - fieldTemplate?.type === 'ImagePolymorphic') + isImageFieldInputInstance(fieldInstance) && + isImageFieldInputTemplate(fieldTemplate) ) { return ( - ); } - if (field?.type === 'BoardField' && fieldTemplate?.type === 'BoardField') { + if ( + isBoardFieldInputInstance(fieldInstance) && + isBoardFieldInputTemplate(fieldTemplate) + ) { return ( - ); } if ( - field?.type === 'MainModelField' && - fieldTemplate?.type === 'MainModelField' + isMainModelFieldInputInstance(fieldInstance) && + isMainModelFieldInputTemplate(fieldTemplate) ) { return ( - ); } if ( - field?.type === 'SDXLRefinerModelField' && - fieldTemplate?.type === 'SDXLRefinerModelField' + isSDXLRefinerModelFieldInputInstance(fieldInstance) && + isSDXLRefinerModelFieldInputTemplate(fieldTemplate) ) { return ( - ); } if ( - field?.type === 'VaeModelField' && - fieldTemplate?.type === 'VaeModelField' + isVAEModelFieldInputInstance(fieldInstance) && + isVAEModelFieldInputTemplate(fieldTemplate) ) { return ( - ); } if ( - field?.type === 'LoRAModelField' && - fieldTemplate?.type === 'LoRAModelField' + isLoRAModelFieldInputInstance(fieldInstance) && + isLoRAModelFieldInputTemplate(fieldTemplate) ) { return ( - ); } if ( - field?.type === 'ControlNetModelField' && - fieldTemplate?.type === 'ControlNetModelField' + isControlNetModelFieldInputInstance(fieldInstance) && + isControlNetModelFieldInputTemplate(fieldTemplate) ) { return ( - ); } if ( - field?.type === 'IPAdapterModelField' && - fieldTemplate?.type === 'IPAdapterModelField' + isIPAdapterModelFieldInputInstance(fieldInstance) && + isIPAdapterModelFieldInputTemplate(fieldTemplate) ) { return ( - ); } if ( - field?.type === 'T2IAdapterModelField' && - fieldTemplate?.type === 'T2IAdapterModelField' + isT2IAdapterModelFieldInputInstance(fieldInstance) && + isT2IAdapterModelFieldInputTemplate(fieldTemplate) ) { return ( - ); } - if (field?.type === 'ColorField' && fieldTemplate?.type === 'ColorField') { + if ( + isColorFieldInputInstance(fieldInstance) && + isColorFieldInputTemplate(fieldTemplate) + ) { return ( - ); } if ( - field?.type === 'SDXLMainModelField' && - fieldTemplate?.type === 'SDXLMainModelField' + isSDXLMainModelFieldInputInstance(fieldInstance) && + isSDXLMainModelFieldInputTemplate(fieldTemplate) ) { return ( - ); } - if (field?.type === 'Scheduler' && fieldTemplate?.type === 'Scheduler') { + if ( + isSchedulerFieldInputInstance(fieldInstance) && + isSchedulerFieldInputTemplate(fieldTemplate) + ) { return ( - ); } - if (field && fieldTemplate) { + if (fieldInstance && fieldTemplate) { // Fallback for when there is no component for the type return null; } @@ -255,7 +298,7 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => { _dark: { color: 'error.300' }, }} > - {t('nodes.unknownFieldType')}: {field?.type} + {t('nodes.unknownFieldType')}: {fieldInstance?.type.name} ); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/BoardInputField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/BoardFieldInputComponent.tsx similarity index 82% rename from invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/BoardInputField.tsx rename to invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/BoardFieldInputComponent.tsx index a6e8cbb0c1b..8f0f3260a61 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/BoardInputField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/BoardFieldInputComponent.tsx @@ -3,15 +3,15 @@ import { useAppDispatch } from 'app/store/storeHooks'; import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect'; import { fieldBoardValueChanged } from 'features/nodes/store/nodesSlice'; import { - BoardInputFieldTemplate, - BoardInputFieldValue, - FieldComponentProps, -} from 'features/nodes/types/types'; + BoardFieldInputTemplate, + BoardFieldInputInstance, +} from 'features/nodes/types/field'; +import { FieldComponentProps } from './types'; import { memo, useCallback } from 'react'; import { useListAllBoardsQuery } from 'services/api/endpoints/boards'; -const BoardInputFieldComponent = ( - props: FieldComponentProps +const BoardFieldInputComponent = ( + props: FieldComponentProps ) => { const { nodeId, field } = props; const dispatch = useAppDispatch(); @@ -61,4 +61,4 @@ const BoardInputFieldComponent = ( ); }; -export default memo(BoardInputFieldComponent); +export default memo(BoardFieldInputComponent); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/BooleanInputField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/BooleanFieldInputComponent.tsx similarity index 65% rename from invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/BooleanInputField.tsx rename to invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/BooleanFieldInputComponent.tsx index d14756dbdb3..3bac81b0f7a 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/BooleanInputField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/BooleanFieldInputComponent.tsx @@ -2,18 +2,16 @@ import { Switch } from '@chakra-ui/react'; import { useAppDispatch } from 'app/store/storeHooks'; import { fieldBooleanValueChanged } from 'features/nodes/store/nodesSlice'; import { - BooleanInputFieldTemplate, - BooleanInputFieldValue, - BooleanPolymorphicInputFieldTemplate, - BooleanPolymorphicInputFieldValue, - FieldComponentProps, -} from 'features/nodes/types/types'; + BooleanFieldInputInstance, + BooleanFieldInputTemplate, +} from 'features/nodes/types/field'; +import { FieldComponentProps } from './types'; import { ChangeEvent, memo, useCallback } from 'react'; -const BooleanInputFieldComponent = ( +const BooleanFieldInputComponent = ( props: FieldComponentProps< - BooleanInputFieldValue | BooleanPolymorphicInputFieldValue, - BooleanInputFieldTemplate | BooleanPolymorphicInputFieldTemplate + BooleanFieldInputInstance, + BooleanFieldInputTemplate > ) => { const { nodeId, field } = props; @@ -42,4 +40,4 @@ const BooleanInputFieldComponent = ( ); }; -export default memo(BooleanInputFieldComponent); +export default memo(BooleanFieldInputComponent); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ColorInputField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ColorFieldInputComponent.tsx similarity index 70% rename from invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ColorInputField.tsx rename to invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ColorFieldInputComponent.tsx index c2af279cb59..875bb062708 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ColorInputField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ColorFieldInputComponent.tsx @@ -1,15 +1,15 @@ import { useAppDispatch } from 'app/store/storeHooks'; import { fieldColorValueChanged } from 'features/nodes/store/nodesSlice'; import { - ColorInputFieldTemplate, - ColorInputFieldValue, - FieldComponentProps, -} from 'features/nodes/types/types'; + ColorFieldInputTemplate, + ColorFieldInputInstance, +} from 'features/nodes/types/field'; +import { FieldComponentProps } from './types'; import { memo, useCallback } from 'react'; import { RgbaColor, RgbaColorPicker } from 'react-colorful'; -const ColorInputFieldComponent = ( - props: FieldComponentProps +const ColorFieldInputComponent = ( + props: FieldComponentProps ) => { const { nodeId, field } = props; @@ -37,4 +37,4 @@ const ColorInputFieldComponent = ( ); }; -export default memo(ColorInputFieldComponent); +export default memo(ColorFieldInputComponent); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ControlNetModelInputField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ControlNetModelFieldInputComponent.tsx similarity index 87% rename from invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ControlNetModelInputField.tsx rename to invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ControlNetModelFieldInputComponent.tsx index 804671204de..8604e6319e7 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ControlNetModelInputField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ControlNetModelFieldInputComponent.tsx @@ -3,20 +3,20 @@ import { useAppDispatch } from 'app/store/storeHooks'; import IAIMantineSelect from 'common/components/IAIMantineSelect'; import { fieldControlNetModelValueChanged } from 'features/nodes/store/nodesSlice'; import { - ControlNetModelInputFieldTemplate, - ControlNetModelInputFieldValue, - FieldComponentProps, -} from 'features/nodes/types/types'; + ControlNetModelFieldInputTemplate, + ControlNetModelFieldInputInstance, +} from 'features/nodes/types/field'; +import { FieldComponentProps } from './types'; import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; import { modelIdToControlNetModelParam } from 'features/parameters/util/modelIdToControlNetModelParam'; import { forEach } from 'lodash-es'; import { memo, useCallback, useMemo } from 'react'; import { useGetControlNetModelsQuery } from 'services/api/endpoints/models'; -const ControlNetModelInputFieldComponent = ( +const ControlNetModelFieldInputComponent = ( props: FieldComponentProps< - ControlNetModelInputFieldValue, - ControlNetModelInputFieldTemplate + ControlNetModelFieldInputInstance, + ControlNetModelFieldInputTemplate > ) => { const { nodeId, field } = props; @@ -97,4 +97,4 @@ const ControlNetModelInputFieldComponent = ( ); }; -export default memo(ControlNetModelInputFieldComponent); +export default memo(ControlNetModelFieldInputComponent); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/EnumInputField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/EnumFieldInputComponent.tsx similarity index 77% rename from invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/EnumInputField.tsx rename to invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/EnumFieldInputComponent.tsx index 277020d847e..e741afe964c 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/EnumInputField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/EnumFieldInputComponent.tsx @@ -2,14 +2,14 @@ import { Select } from '@chakra-ui/react'; import { useAppDispatch } from 'app/store/storeHooks'; import { fieldEnumModelValueChanged } from 'features/nodes/store/nodesSlice'; import { - EnumInputFieldTemplate, - EnumInputFieldValue, - FieldComponentProps, -} from 'features/nodes/types/types'; + EnumFieldInputInstance, + EnumFieldInputTemplate, +} from 'features/nodes/types/field'; +import { FieldComponentProps } from './types'; import { ChangeEvent, memo, useCallback } from 'react'; -const EnumInputFieldComponent = ( - props: FieldComponentProps +const EnumFieldInputComponent = ( + props: FieldComponentProps ) => { const { nodeId, field, fieldTemplate } = props; @@ -45,4 +45,4 @@ const EnumInputFieldComponent = ( ); }; -export default memo(EnumInputFieldComponent); +export default memo(EnumFieldInputComponent); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/IPAdapterModelInputField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/IPAdapterModelFieldInputComponent.tsx similarity index 87% rename from invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/IPAdapterModelInputField.tsx rename to invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/IPAdapterModelFieldInputComponent.tsx index 637fa79f60b..0ec332cd5a7 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/IPAdapterModelInputField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/IPAdapterModelFieldInputComponent.tsx @@ -3,20 +3,20 @@ import { useAppDispatch } from 'app/store/storeHooks'; import IAIMantineSelect from 'common/components/IAIMantineSelect'; import { fieldIPAdapterModelValueChanged } from 'features/nodes/store/nodesSlice'; import { - IPAdapterModelInputFieldTemplate, - IPAdapterModelInputFieldValue, - FieldComponentProps, -} from 'features/nodes/types/types'; + IPAdapterModelFieldInputTemplate, + IPAdapterModelFieldInputInstance, +} from 'features/nodes/types/field'; +import { FieldComponentProps } from './types'; import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; import { modelIdToIPAdapterModelParam } from 'features/parameters/util/modelIdToIPAdapterModelParams'; import { forEach } from 'lodash-es'; import { memo, useCallback, useMemo } from 'react'; import { useGetIPAdapterModelsQuery } from 'services/api/endpoints/models'; -const IPAdapterModelInputFieldComponent = ( +const IPAdapterModelFieldInputComponent = ( props: FieldComponentProps< - IPAdapterModelInputFieldValue, - IPAdapterModelInputFieldTemplate + IPAdapterModelFieldInputInstance, + IPAdapterModelFieldInputTemplate > ) => { const { nodeId, field } = props; @@ -97,4 +97,4 @@ const IPAdapterModelInputFieldComponent = ( ); }; -export default memo(IPAdapterModelInputFieldComponent); +export default memo(IPAdapterModelFieldInputComponent); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageInputField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageFieldInputComponent.tsx similarity index 87% rename from invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageInputField.tsx rename to invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageFieldInputComponent.tsx index 94095f26128..5feb870adc0 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageInputField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageFieldInputComponent.tsx @@ -9,23 +9,18 @@ import { } from 'features/dnd/types'; import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice'; import { - FieldComponentProps, - ImageInputFieldTemplate, - ImageInputFieldValue, - ImagePolymorphicInputFieldTemplate, - ImagePolymorphicInputFieldValue, -} from 'features/nodes/types/types'; + ImageFieldInputInstance, + ImageFieldInputTemplate, +} from 'features/nodes/types/field'; +import { FieldComponentProps } from './types'; import { memo, useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { FaUndo } from 'react-icons/fa'; import { useGetImageDTOQuery } from 'services/api/endpoints/images'; import { PostUploadAction } from 'services/api/types'; -const ImageInputFieldComponent = ( - props: FieldComponentProps< - ImageInputFieldValue | ImagePolymorphicInputFieldValue, - ImageInputFieldTemplate | ImagePolymorphicInputFieldTemplate - > +const ImageFieldInputComponent = ( + props: FieldComponentProps ) => { const { nodeId, field } = props; const dispatch = useAppDispatch(); @@ -102,7 +97,7 @@ const ImageInputFieldComponent = ( ); }; -export default memo(ImageInputFieldComponent); +export default memo(ImageFieldInputComponent); const UploadElement = memo(() => { const { t } = useTranslation(); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/LoRAModelInputField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/LoRAModelFieldInputComponent.tsx similarity index 91% rename from invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/LoRAModelInputField.tsx rename to invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/LoRAModelFieldInputComponent.tsx index dc79436ec6b..fa2bada6317 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/LoRAModelInputField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/LoRAModelFieldInputComponent.tsx @@ -5,10 +5,10 @@ import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSe import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip'; import { fieldLoRAModelValueChanged } from 'features/nodes/store/nodesSlice'; import { - LoRAModelInputFieldTemplate, - LoRAModelInputFieldValue, - FieldComponentProps, -} from 'features/nodes/types/types'; + LoRAModelFieldInputTemplate, + LoRAModelFieldInputInstance, +} from 'features/nodes/types/field'; +import { FieldComponentProps } from './types'; import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; import { modelIdToLoRAModelParam } from 'features/parameters/util/modelIdToLoRAModelParam'; import { forEach } from 'lodash-es'; @@ -16,10 +16,10 @@ import { memo, useCallback, useMemo } from 'react'; import { useGetLoRAModelsQuery } from 'services/api/endpoints/models'; import { useTranslation } from 'react-i18next'; -const LoRAModelInputFieldComponent = ( +const LoRAModelFieldInputComponent = ( props: FieldComponentProps< - LoRAModelInputFieldValue, - LoRAModelInputFieldTemplate + LoRAModelFieldInputInstance, + LoRAModelFieldInputTemplate > ) => { const { nodeId, field } = props; @@ -121,4 +121,4 @@ const LoRAModelInputFieldComponent = ( ); }; -export default memo(LoRAModelInputFieldComponent); +export default memo(LoRAModelFieldInputComponent); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/MainModelInputField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/MainModelFieldInputComponent.tsx similarity index 93% rename from invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/MainModelInputField.tsx rename to invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/MainModelFieldInputComponent.tsx index af68b4291c1..8c62548924a 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/MainModelInputField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/MainModelFieldInputComponent.tsx @@ -4,10 +4,10 @@ import { useAppDispatch } from 'app/store/storeHooks'; import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect'; import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice'; import { - MainModelInputFieldTemplate, - MainModelInputFieldValue, - FieldComponentProps, -} from 'features/nodes/types/types'; + MainModelFieldInputTemplate, + MainModelFieldInputInstance, +} from 'features/nodes/types/field'; +import { FieldComponentProps } from './types'; import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; @@ -21,10 +21,10 @@ import { } from 'services/api/endpoints/models'; import { useTranslation } from 'react-i18next'; -const MainModelInputFieldComponent = ( +const MainModelFieldInputComponent = ( props: FieldComponentProps< - MainModelInputFieldValue, - MainModelInputFieldTemplate + MainModelFieldInputInstance, + MainModelFieldInputTemplate > ) => { const { nodeId, field } = props; @@ -149,4 +149,4 @@ const MainModelInputFieldComponent = ( ); }; -export default memo(MainModelInputFieldComponent); +export default memo(MainModelFieldInputComponent); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/NumberInputField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/NumberFieldInputComponent.tsx similarity index 71% rename from invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/NumberInputField.tsx rename to invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/NumberFieldInputComponent.tsx index 2b2763ca3ec..9daff53448c 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/NumberInputField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/NumberFieldInputComponent.tsx @@ -9,28 +9,18 @@ import { useAppDispatch } from 'app/store/storeHooks'; import { numberStringRegex } from 'common/components/IAINumberInput'; import { fieldNumberValueChanged } from 'features/nodes/store/nodesSlice'; import { - FieldComponentProps, - FloatInputFieldTemplate, - FloatInputFieldValue, - FloatPolymorphicInputFieldTemplate, - FloatPolymorphicInputFieldValue, - IntegerInputFieldTemplate, - IntegerInputFieldValue, - IntegerPolymorphicInputFieldTemplate, - IntegerPolymorphicInputFieldValue, -} from 'features/nodes/types/types'; + FloatFieldInputInstance, + FloatFieldInputTemplate, + IntegerFieldInputInstance, + IntegerFieldInputTemplate, +} from 'features/nodes/types/field'; +import { FieldComponentProps } from './types'; import { memo, useCallback, useEffect, useMemo, useState } from 'react'; -const NumberInputFieldComponent = ( +const NumberFieldInputComponent = ( props: FieldComponentProps< - | IntegerInputFieldValue - | IntegerPolymorphicInputFieldValue - | FloatInputFieldValue - | FloatPolymorphicInputFieldValue, - | IntegerInputFieldTemplate - | IntegerPolymorphicInputFieldTemplate - | FloatInputFieldTemplate - | FloatPolymorphicInputFieldTemplate + IntegerFieldInputInstance | FloatFieldInputInstance, + IntegerFieldInputTemplate | FloatFieldInputTemplate > ) => { const { nodeId, field, fieldTemplate } = props; @@ -39,7 +29,7 @@ const NumberInputFieldComponent = ( String(field.value) ); const isIntegerField = useMemo( - () => fieldTemplate.type === 'integer', + () => fieldTemplate.type.name === 'IntegerField', [fieldTemplate.type] ); @@ -86,4 +76,4 @@ const NumberInputFieldComponent = ( ); }; -export default memo(NumberInputFieldComponent); +export default memo(NumberFieldInputComponent); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/RefinerModelInputField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/RefinerModelFieldInputComponent.tsx similarity index 90% rename from invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/RefinerModelInputField.tsx rename to invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/RefinerModelFieldInputComponent.tsx index e6db6031b88..42e63a8cb65 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/RefinerModelInputField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/RefinerModelFieldInputComponent.tsx @@ -4,10 +4,10 @@ import { useAppDispatch } from 'app/store/storeHooks'; import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect'; import { fieldRefinerModelValueChanged } from 'features/nodes/store/nodesSlice'; import { - FieldComponentProps, - SDXLRefinerModelInputFieldTemplate, - SDXLRefinerModelInputFieldValue, -} from 'features/nodes/types/types'; + SDXLRefinerModelFieldInputTemplate, + SDXLRefinerModelFieldInputInstance, +} from 'features/nodes/types/field'; +import { FieldComponentProps } from './types'; import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; @@ -18,10 +18,10 @@ import { useTranslation } from 'react-i18next'; import { REFINER_BASE_MODELS } from 'services/api/constants'; import { useGetMainModelsQuery } from 'services/api/endpoints/models'; -const RefinerModelInputFieldComponent = ( +const RefinerModelFieldInputComponent = ( props: FieldComponentProps< - SDXLRefinerModelInputFieldValue, - SDXLRefinerModelInputFieldTemplate + SDXLRefinerModelFieldInputInstance, + SDXLRefinerModelFieldInputTemplate > ) => { const { nodeId, field } = props; @@ -120,4 +120,4 @@ const RefinerModelInputFieldComponent = ( ); }; -export default memo(RefinerModelInputFieldComponent); +export default memo(RefinerModelFieldInputComponent); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/SDXLMainModelInputField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/SDXLMainModelFieldInputComponent.tsx similarity index 92% rename from invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/SDXLMainModelInputField.tsx rename to invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/SDXLMainModelFieldInputComponent.tsx index c6ef5c6bb4e..260f51ee8bc 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/SDXLMainModelInputField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/SDXLMainModelFieldInputComponent.tsx @@ -4,10 +4,10 @@ import { useAppDispatch } from 'app/store/storeHooks'; import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect'; import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice'; import { - SDXLMainModelInputFieldTemplate, - SDXLMainModelInputFieldValue, - FieldComponentProps, -} from 'features/nodes/types/types'; + SDXLMainModelFieldInputTemplate, + SDXLMainModelFieldInputInstance, +} from 'features/nodes/types/field'; +import { FieldComponentProps } from './types'; import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; @@ -21,10 +21,10 @@ import { useGetOnnxModelsQuery, } from 'services/api/endpoints/models'; -const ModelInputFieldComponent = ( +const SDXLMainModelFieldInputComponent = ( props: FieldComponentProps< - SDXLMainModelInputFieldValue, - SDXLMainModelInputFieldTemplate + SDXLMainModelFieldInputInstance, + SDXLMainModelFieldInputTemplate > ) => { const { nodeId, field } = props; @@ -147,4 +147,4 @@ const ModelInputFieldComponent = ( ); }; -export default memo(ModelInputFieldComponent); +export default memo(SDXLMainModelFieldInputComponent); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/SchedulerInputField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/SchedulerFieldInputComponent.tsx similarity index 72% rename from invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/SchedulerInputField.tsx rename to invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/SchedulerFieldInputComponent.tsx index e4a3fb2a3db..a3b30f5057a 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/SchedulerInputField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/SchedulerFieldInputComponent.tsx @@ -5,14 +5,12 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect'; import { fieldSchedulerValueChanged } from 'features/nodes/store/nodesSlice'; import { - SchedulerInputFieldTemplate, - SchedulerInputFieldValue, - FieldComponentProps, -} from 'features/nodes/types/types'; -import { - SCHEDULER_LABEL_MAP, - SchedulerParam, -} from 'features/parameters/types/parameterSchemas'; + SchedulerFieldInputTemplate, + SchedulerFieldInputInstance, +} from 'features/nodes/types/field'; +import { FieldComponentProps } from './types'; +import { ParameterScheduler } from 'features/parameters/types/parameterSchemas'; +import { SCHEDULER_LABEL_MAP } from 'features/parameters/types/constants'; import { map } from 'lodash-es'; import { memo, useCallback } from 'react'; @@ -24,7 +22,7 @@ const selector = createSelector( const data = map(SCHEDULER_LABEL_MAP, (label, name) => ({ value: name, label: label, - group: enabledSchedulers.includes(name as SchedulerParam) + group: enabledSchedulers.includes(name as ParameterScheduler) ? 'Favorites' : undefined, })).sort((a, b) => a.label.localeCompare(b.label)); @@ -36,10 +34,10 @@ const selector = createSelector( defaultSelectorOptions ); -const SchedulerInputField = ( +const SchedulerFieldInputComponent = ( props: FieldComponentProps< - SchedulerInputFieldValue, - SchedulerInputFieldTemplate + SchedulerFieldInputInstance, + SchedulerFieldInputTemplate > ) => { const { nodeId, field } = props; @@ -55,7 +53,7 @@ const SchedulerInputField = ( fieldSchedulerValueChanged({ nodeId, fieldName: field.name, - value: value as SchedulerParam, + value: value as ParameterScheduler, }) ); }, @@ -72,4 +70,4 @@ const SchedulerInputField = ( ); }; -export default memo(SchedulerInputField); +export default memo(SchedulerFieldInputComponent); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/StringInputField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/StringFieldInputComponent.tsx similarity index 70% rename from invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/StringInputField.tsx rename to invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/StringFieldInputComponent.tsx index 720722030be..50c8c487da5 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/StringInputField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/StringFieldInputComponent.tsx @@ -3,19 +3,14 @@ import IAIInput from 'common/components/IAIInput'; import IAITextarea from 'common/components/IAITextarea'; import { fieldStringValueChanged } from 'features/nodes/store/nodesSlice'; import { - StringInputFieldTemplate, - StringInputFieldValue, - FieldComponentProps, - StringPolymorphicInputFieldValue, - StringPolymorphicInputFieldTemplate, -} from 'features/nodes/types/types'; + StringFieldInputInstance, + StringFieldInputTemplate, +} from 'features/nodes/types/field'; +import { FieldComponentProps } from './types'; import { ChangeEvent, memo, useCallback } from 'react'; -const StringInputFieldComponent = ( - props: FieldComponentProps< - StringInputFieldValue | StringPolymorphicInputFieldValue, - StringInputFieldTemplate | StringPolymorphicInputFieldTemplate - > +const StringFieldInputComponent = ( + props: FieldComponentProps ) => { const { nodeId, field, fieldTemplate } = props; const dispatch = useAppDispatch(); @@ -48,4 +43,4 @@ const StringInputFieldComponent = ( return ; }; -export default memo(StringInputFieldComponent); +export default memo(StringFieldInputComponent); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/T2IAdapterModelInputField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/T2IAdapterModelFieldInputComponent.tsx similarity index 87% rename from invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/T2IAdapterModelInputField.tsx rename to invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/T2IAdapterModelFieldInputComponent.tsx index f5ae6b747a8..03b3cba4f01 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/T2IAdapterModelInputField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/T2IAdapterModelFieldInputComponent.tsx @@ -3,20 +3,20 @@ import { useAppDispatch } from 'app/store/storeHooks'; import IAIMantineSelect from 'common/components/IAIMantineSelect'; import { fieldT2IAdapterModelValueChanged } from 'features/nodes/store/nodesSlice'; import { - T2IAdapterModelInputFieldTemplate, - T2IAdapterModelInputFieldValue, - FieldComponentProps, -} from 'features/nodes/types/types'; + T2IAdapterModelFieldInputInstance, + T2IAdapterModelFieldInputTemplate, +} from 'features/nodes/types/field'; +import { FieldComponentProps } from './types'; import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; import { modelIdToT2IAdapterModelParam } from 'features/parameters/util/modelIdToT2IAdapterModelParam'; import { forEach } from 'lodash-es'; import { memo, useCallback, useMemo } from 'react'; import { useGetT2IAdapterModelsQuery } from 'services/api/endpoints/models'; -const T2IAdapterModelInputFieldComponent = ( +const T2IAdapterModelFieldInputComponent = ( props: FieldComponentProps< - T2IAdapterModelInputFieldValue, - T2IAdapterModelInputFieldTemplate + T2IAdapterModelFieldInputInstance, + T2IAdapterModelFieldInputTemplate > ) => { const { nodeId, field } = props; @@ -97,4 +97,4 @@ const T2IAdapterModelInputFieldComponent = ( ); }; -export default memo(T2IAdapterModelInputFieldComponent); +export default memo(T2IAdapterModelFieldInputComponent); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/VaeModelInputField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/VAEModelFieldInputComponent.tsx similarity index 89% rename from invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/VaeModelInputField.tsx rename to invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/VAEModelFieldInputComponent.tsx index 79ada94c3e3..93d397b2024 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/VaeModelInputField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/VAEModelFieldInputComponent.tsx @@ -4,20 +4,20 @@ import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSe import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip'; import { fieldVaeModelValueChanged } from 'features/nodes/store/nodesSlice'; import { - FieldComponentProps, - VaeModelInputFieldTemplate, - VaeModelInputFieldValue, -} from 'features/nodes/types/types'; + VAEModelFieldInputTemplate, + VAEModelFieldInputInstance, +} from 'features/nodes/types/field'; +import { FieldComponentProps } from './types'; import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; import { modelIdToVAEModelParam } from 'features/parameters/util/modelIdToVAEModelParam'; import { forEach } from 'lodash-es'; import { memo, useCallback, useMemo } from 'react'; import { useGetVaeModelsQuery } from 'services/api/endpoints/models'; -const VaeModelInputFieldComponent = ( +const VAEModelFieldInputComponent = ( props: FieldComponentProps< - VaeModelInputFieldValue, - VaeModelInputFieldTemplate + VAEModelFieldInputInstance, + VAEModelFieldInputTemplate > ) => { const { nodeId, field } = props; @@ -105,4 +105,4 @@ const VaeModelInputFieldComponent = ( ); }; -export default memo(VaeModelInputFieldComponent); +export default memo(VAEModelFieldInputComponent); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/types.ts b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/types.ts new file mode 100644 index 00000000000..22c488c16f4 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/types.ts @@ -0,0 +1,13 @@ +import { + FieldInputInstance, + FieldInputTemplate, +} from 'features/nodes/types/field'; + +export type FieldComponentProps< + V extends FieldInputInstance, + T extends FieldInputTemplate, +> = { + nodeId: string; + field: V; + fieldTemplate: T; +}; diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Notes/NotesNode.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Notes/NotesNode.tsx index ec869f3dad8..bbbb7b7372b 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Notes/NotesNode.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Notes/NotesNode.tsx @@ -2,7 +2,7 @@ import { Box, Flex } from '@chakra-ui/react'; import { useAppDispatch } from 'app/store/storeHooks'; import IAITextarea from 'common/components/IAITextarea'; import { notesNodeValueChanged } from 'features/nodes/store/nodesSlice'; -import { NotesNodeData } from 'features/nodes/types/types'; +import { NotesNodeData } from 'features/nodes/types/invocation'; import { ChangeEvent, memo, useCallback } from 'react'; import { NodeProps } from 'reactflow'; import NodeWrapper from '../common/NodeWrapper'; diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/common/NodeWrapper.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/common/NodeWrapper.tsx index 79de65760f9..b6ccd4ae9f7 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/common/NodeWrapper.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/common/NodeWrapper.tsx @@ -14,7 +14,7 @@ import { DRAG_HANDLE_CLASSNAME, NODE_WIDTH, } from 'features/nodes/types/constants'; -import { NodeStatus } from 'features/nodes/types/types'; +import { zNodeStatus } from 'features/nodes/types/invocation'; import { contextMenusClosed } from 'features/ui/store/uiSlice'; import { MouseEvent, @@ -40,7 +40,8 @@ const NodeWrapper = (props: NodeWrapperProps) => { createSelector( stateSelector, ({ nodes }) => - nodes.nodeExecutionStates[nodeId]?.status === NodeStatus.IN_PROGRESS + nodes.nodeExecutionStates[nodeId]?.status === + zNodeStatus.enum.IN_PROGRESS ), [nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/panels/TopCenterPanel/LoadWorkflowButton.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/panels/TopCenterPanel/LoadWorkflowButton.tsx index 8454f5539ff..eb593ee1db0 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/panels/TopCenterPanel/LoadWorkflowButton.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/panels/TopCenterPanel/LoadWorkflowButton.tsx @@ -8,7 +8,7 @@ import { FaUpload } from 'react-icons/fa'; const LoadWorkflowButton = () => { const { t } = useTranslation(); const resetRef = useRef<() => void>(null); - const loadWorkflowFromFile = useLoadWorkflowFromFile(); + const loadWorkflowFromFile = useLoadWorkflowFromFile(resetRef); return ( { - return ( - - {map(FIELDS, ({ title, description, color }, key) => ( - - - {title} - - - ))} - - ); -}; - -export default memo(FieldTypeLegend); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/panels/TopRightPanel/TopRightPanel.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/panels/TopRightPanel/TopRightPanel.tsx index db8f544c2e0..c289ea02dd1 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/panels/TopRightPanel/TopRightPanel.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/panels/TopRightPanel/TopRightPanel.tsx @@ -1,18 +1,11 @@ import { Flex } from '@chakra-ui/layout'; -import { useAppSelector } from 'app/store/storeHooks'; import { memo } from 'react'; -import FieldTypeLegend from './FieldTypeLegend'; import WorkflowEditorSettings from './WorkflowEditorSettings'; const TopRightPanel = () => { - const shouldShowFieldTypeLegend = useAppSelector( - (state) => state.nodes.shouldShowFieldTypeLegend - ); - return ( - {shouldShowFieldTypeLegend && } ); }; diff --git a/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorDetailsTab.tsx b/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorDetailsTab.tsx index d906557dd37..ecbe538fcc9 100644 --- a/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorDetailsTab.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorDetailsTab.tsx @@ -10,17 +10,15 @@ import { createSelector } from '@reduxjs/toolkit'; import { stateSelector } from 'app/store/store'; import { useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; -import IAIIconButton from 'common/components/IAIIconButton'; import { IAINoContentFallback } from 'common/components/IAIImageFallback'; -import { useNodeVersion } from 'features/nodes/hooks/useNodeVersion'; +import { getNeedsUpdate } from 'features/nodes/store/util/nodeUpdate'; import { InvocationNodeData, InvocationTemplate, isInvocationNode, -} from 'features/nodes/types/types'; -import { memo } from 'react'; +} from 'features/nodes/types/invocation'; +import { memo, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; -import { FaSync } from 'react-icons/fa'; import { Node } from 'reactflow'; import NotesTextarea from '../../flow/nodes/Invocation/NotesTextarea'; import ScrollableContent from '../ScrollableContent'; @@ -63,12 +61,17 @@ const InspectorDetailsTab = () => { export default memo(InspectorDetailsTab); -const Content = (props: { +type ContentProps = { node: Node; template: InvocationTemplate; -}) => { +}; + +const Content = memo(({ node, template }: ContentProps) => { const { t } = useTranslation(); - const { needsUpdate, updateNode } = useNodeVersion(props.node.id); + const needsUpdate = useMemo( + () => getNeedsUpdate(node, template), + [node, template] + ); return ( - + {t('nodes.nodeType')} - {props.template.title} + {template.title} {t('nodes.nodeVersion')} - {props.node.data.version} + {node.data.version} - {needsUpdate && ( - } - onClick={updateNode} - /> - )} - + ); -}; +}); + +Content.displayName = 'Content'; diff --git a/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorOutputsTab.tsx b/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorOutputsTab.tsx index f4abc621b4c..265d369f5ad 100644 --- a/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorOutputsTab.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorOutputsTab.tsx @@ -5,7 +5,7 @@ import { useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { IAINoContentFallback } from 'common/components/IAIImageFallback'; import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer'; -import { isInvocationNode } from 'features/nodes/types/types'; +import { isInvocationNode } from 'features/nodes/types/invocation'; import { memo } from 'react'; import { ImageOutput } from 'services/api/types'; import { AnyResult } from 'services/events/types'; diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useAnyOrDirectInputFieldNames.ts b/invokeai/frontend/web/src/features/nodes/hooks/useAnyOrDirectInputFieldNames.ts index dda2efc1568..ccfa0f57fdb 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useAnyOrDirectInputFieldNames.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useAnyOrDirectInputFieldNames.ts @@ -2,14 +2,11 @@ import { createSelector } from '@reduxjs/toolkit'; import { stateSelector } from 'app/store/store'; import { useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; -import { map } from 'lodash-es'; +import { keys, map } from 'lodash-es'; import { useMemo } from 'react'; -import { isInvocationNode } from '../types/types'; -import { - POLYMORPHIC_TYPES, - TYPES_WITH_INPUT_COMPONENTS, -} from '../types/constants'; +import { isInvocationNode } from '../types/invocation'; import { getSortedFilteredFieldNames } from '../util/getSortedFilteredFieldNames'; +import { TEMPLATE_BUILDER_MAP } from '../util/buildFieldInputTemplate'; export const useAnyOrDirectInputFieldNames = (nodeId: string) => { const selector = useMemo( @@ -28,8 +25,8 @@ export const useAnyOrDirectInputFieldNames = (nodeId: string) => { const fields = map(nodeTemplate.inputs).filter( (field) => (['any', 'direct'].includes(field.input) || - POLYMORPHIC_TYPES.includes(field.type)) && - TYPES_WITH_INPUT_COMPONENTS.includes(field.type) + field.type.isPolymorphic) && + keys(TEMPLATE_BUILDER_MAP).includes(field.type.name) ); return getSortedFilteredFieldNames(fields); }, diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useBuildNodeData.ts b/invokeai/frontend/web/src/features/nodes/hooks/useBuildNodeData.ts index 036ce8d44e6..694261d9439 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useBuildNodeData.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useBuildNodeData.ts @@ -3,10 +3,13 @@ import { RootState } from 'app/store/store'; import { useAppSelector } from 'app/store/storeHooks'; import { useCallback } from 'react'; import { Node, useReactFlow } from 'reactflow'; -import { AnyInvocationType } from 'services/events/types'; -import { buildNodeData } from '../store/util/buildNodeData'; +import { + buildCurrentImageNode, + buildInvocationNode, + buildNotesNode, +} from '../store/util/buildNodeData'; import { DRAG_HANDLE_CLASSNAME, NODE_WIDTH } from '../types/constants'; - +import { AnyNodeData, InvocationTemplate } from '../types/invocation'; const templatesSelector = createSelector( [(state: RootState) => state.nodes], (nodes) => nodes.nodeTemplates @@ -22,7 +25,8 @@ export const useBuildNodeData = () => { const flow = useReactFlow(); return useCallback( - (type: AnyInvocationType | 'current_image' | 'notes') => { + // string here is "any invocation type" + (type: string | 'current_image' | 'notes'): Node => { let _x = window.innerWidth / 2; let _y = window.innerHeight / 2; @@ -41,9 +45,19 @@ export const useBuildNodeData = () => { y: _y, }); - const template = nodeTemplates[type]; + if (type === 'current_image') { + return buildCurrentImageNode(position); + } + + if (type === 'notes') { + return buildNotesNode(position); + } + + // TODO: Keep track of invocation types so we do not need to cast this + // We know it is safe because the caller of this function gets the `type` arg from the list of invocation templates. + const template = nodeTemplates[type] as InvocationTemplate; - return buildNodeData(type, position, template); + return buildInvocationNode(position, template); }, [nodeTemplates, flow] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useConnectionInputFieldNames.ts b/invokeai/frontend/web/src/features/nodes/hooks/useConnectionInputFieldNames.ts index 9fb31df801d..2951167944c 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useConnectionInputFieldNames.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useConnectionInputFieldNames.ts @@ -2,14 +2,11 @@ import { createSelector } from '@reduxjs/toolkit'; import { stateSelector } from 'app/store/store'; import { useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; -import { map } from 'lodash-es'; +import { keys, map } from 'lodash-es'; import { useMemo } from 'react'; -import { - POLYMORPHIC_TYPES, - TYPES_WITH_INPUT_COMPONENTS, -} from '../types/constants'; -import { isInvocationNode } from '../types/types'; +import { isInvocationNode } from '../types/invocation'; import { getSortedFilteredFieldNames } from '../util/getSortedFilteredFieldNames'; +import { TEMPLATE_BUILDER_MAP } from '../util/buildFieldInputTemplate'; export const useConnectionInputFieldNames = (nodeId: string) => { const selector = useMemo( @@ -29,9 +26,8 @@ export const useConnectionInputFieldNames = (nodeId: string) => { // get the visible fields const fields = map(nodeTemplate.inputs).filter( (field) => - (field.input === 'connection' && - !POLYMORPHIC_TYPES.includes(field.type)) || - !TYPES_WITH_INPUT_COMPONENTS.includes(field.type) + (field.input === 'connection' && !field.type.isPolymorphic) || + !keys(TEMPLATE_BUILDER_MAP).includes(field.type.name) ); return getSortedFilteredFieldNames(fields); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts b/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts index 96b2d652e92..cc3b2ce7ac9 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts @@ -8,7 +8,7 @@ import { useFieldType } from './useFieldType.ts'; const selectIsConnectionInProgress = createSelector( stateSelector, ({ nodes }) => - nodes.currentConnectionFieldType !== null && + nodes.connectionStartFieldType !== null && nodes.connectionStartParams !== null ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useDoNodeVersionsMatch.ts b/invokeai/frontend/web/src/features/nodes/hooks/useDoNodeVersionsMatch.ts index 926c56ac1ed..82db025a557 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useDoNodeVersionsMatch.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useDoNodeVersionsMatch.ts @@ -4,7 +4,7 @@ import { useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { compareVersions } from 'compare-versions'; import { useMemo } from 'react'; -import { isInvocationNode } from '../types/types'; +import { isInvocationNode } from '../types/invocation'; export const useDoNodeVersionsMatch = (nodeId: string) => { const selector = useMemo( diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useDoesInputHaveValue.ts b/invokeai/frontend/web/src/features/nodes/hooks/useDoesInputHaveValue.ts index 83bf6b8af0c..e5264c07c49 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useDoesInputHaveValue.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useDoesInputHaveValue.ts @@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store'; import { useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { useMemo } from 'react'; -import { isInvocationNode } from '../types/types'; +import { isInvocationNode } from '../types/invocation'; export const useDoesInputHaveValue = (nodeId: string, fieldName: string) => { const selector = useMemo( diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useEmbedWorkflow.ts b/invokeai/frontend/web/src/features/nodes/hooks/useEmbedWorkflow.ts index 866d8ab9708..863bc64718e 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useEmbedWorkflow.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useEmbedWorkflow.ts @@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store'; import { useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { useMemo } from 'react'; -import { isInvocationNode } from '../types/types'; +import { isInvocationNode } from '../types/invocation'; export const useEmbedWorkflow = (nodeId: string) => { const selector = useMemo( diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldData.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldData.ts index ba2c4e2d5ce..7cdd44e4fd0 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldData.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useFieldData.ts @@ -3,9 +3,9 @@ import { stateSelector } from 'app/store/store'; import { useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { useMemo } from 'react'; -import { isInvocationNode } from '../types/types'; +import { isInvocationNode } from '../types/invocation'; -export const useFieldData = (nodeId: string, fieldName: string) => { +export const useFieldInstance = (nodeId: string, fieldName: string) => { const selector = useMemo( () => createSelector( diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputKind.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputKind.ts index 159815a6a62..82f90531ddd 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputKind.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputKind.ts @@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store'; import { useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { useMemo } from 'react'; -import { isInvocationNode } from '../types/types'; +import { isInvocationNode } from '../types/invocation'; export const useFieldInputKind = (nodeId: string, fieldName: string) => { const selector = useMemo( diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldLabel.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldLabel.ts index fcf33c34273..cabef729ae8 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldLabel.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useFieldLabel.ts @@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store'; import { useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { useMemo } from 'react'; -import { isInvocationNode } from '../types/types'; +import { isInvocationNode } from '../types/invocation'; export const useFieldLabel = (nodeId: string, fieldName: string) => { const selector = useMemo( diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldTemplate.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldTemplate.ts index 93d545aaead..a18a027c6b1 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldTemplate.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useFieldTemplate.ts @@ -4,7 +4,7 @@ import { useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { useMemo } from 'react'; import { KIND_MAP } from '../types/constants'; -import { isInvocationNode } from '../types/types'; +import { isInvocationNode } from '../types/invocation'; export const useFieldTemplate = ( nodeId: string, diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldTemplateTitle.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldTemplateTitle.ts index 923c25cc186..faec9c1ff37 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldTemplateTitle.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useFieldTemplateTitle.ts @@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store'; import { useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { useMemo } from 'react'; -import { isInvocationNode } from '../types/types'; +import { isInvocationNode } from '../types/invocation'; import { KIND_MAP } from '../types/constants'; export const useFieldTemplateTitle = ( diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldType.ts.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldType.ts.ts index f4d78f89548..0775c32cb29 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldType.ts.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useFieldType.ts.ts @@ -4,7 +4,7 @@ import { useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { useMemo } from 'react'; import { KIND_MAP } from '../types/constants'; -import { isInvocationNode } from '../types/types'; +import { isInvocationNode } from '../types/invocation'; export const useFieldType = ( nodeId: string, @@ -20,7 +20,8 @@ export const useFieldType = ( if (!isInvocationNode(node)) { return; } - return node?.data[KIND_MAP[kind]][fieldName]?.type; + const field = node.data[KIND_MAP[kind]][fieldName]; + return field?.type; }, defaultSelectorOptions ), diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useGetNodesNeedUpdate.ts b/invokeai/frontend/web/src/features/nodes/hooks/useGetNodesNeedUpdate.ts index a413de38aea..c22c0d95057 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useGetNodesNeedUpdate.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useGetNodesNeedUpdate.ts @@ -2,7 +2,8 @@ import { createSelector } from '@reduxjs/toolkit'; import { stateSelector } from 'app/store/store'; import { useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; -import { getNeedsUpdate } from './useNodeVersion'; +import { getNeedsUpdate } from '../store/util/nodeUpdate'; +import { isInvocationNode } from '../types/invocation'; const selector = createSelector( stateSelector, @@ -10,8 +11,11 @@ const selector = createSelector( const nodes = state.nodes.nodes; const templates = state.nodes.nodeTemplates; - const needsUpdate = nodes.some((node) => { + const needsUpdate = nodes.filter(isInvocationNode).some((node) => { const template = templates[node.data.type]; + if (!template) { + return false; + } return getNeedsUpdate(node, template); }); return needsUpdate; diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useHasImageOutput.ts b/invokeai/frontend/web/src/features/nodes/hooks/useHasImageOutput.ts index 111e48a45f2..6b99d75ef07 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useHasImageOutput.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useHasImageOutput.ts @@ -4,8 +4,7 @@ import { useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { some } from 'lodash-es'; import { useMemo } from 'react'; -import { IMAGE_FIELDS } from '../types/constants'; -import { isInvocationNode } from '../types/types'; +import { isInvocationNode } from '../types/invocation'; export const useHasImageOutput = (nodeId: string) => { const selector = useMemo( @@ -20,8 +19,8 @@ export const useHasImageOutput = (nodeId: string) => { return some( node.data.outputs, (output) => - IMAGE_FIELDS.includes(output.type) && - // the image primitive node does not actually save the image, do not show the image-saving checkboxes + output.type.name === 'ImageField' && + // the image primitive node (node type "image") does not actually save the image, do not show the image-saving checkboxes node.data.type !== 'image' ); }, diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useIsIntermediate.ts b/invokeai/frontend/web/src/features/nodes/hooks/useIsIntermediate.ts index 86b9371b034..167610c14fa 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useIsIntermediate.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useIsIntermediate.ts @@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store'; import { useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { useMemo } from 'react'; -import { isInvocationNode } from '../types/types'; +import { isInvocationNode } from '../types/invocation'; export const useIsIntermediate = (nodeId: string) => { const selector = useMemo( diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts b/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts index c88d4758af1..028b238c7b3 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts @@ -4,7 +4,7 @@ import { useCallback } from 'react'; import { Connection, Node, useReactFlow } from 'reactflow'; import { validateSourceAndTargetTypes } from '../store/util/validateSourceAndTargetTypes'; import { getIsGraphAcyclic } from '../store/util/getIsGraphAcyclic'; -import { InvocationNodeData } from '../types/types'; +import { InvocationNodeData } from '../types/invocation'; /** * NOTE: The logic here must be duplicated in `invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts` @@ -34,10 +34,10 @@ export const useIsValidConnection = () => { return false; } - const sourceType = sourceNode.data.outputs[sourceHandle]?.type; - const targetType = targetNode.data.inputs[targetHandle]?.type; + const sourceField = sourceNode.data.outputs[sourceHandle]; + const targetField = targetNode.data.inputs[targetHandle]; - if (!sourceType || !targetType) { + if (!sourceField || !targetField) { // something has gone terribly awry return false; } @@ -70,12 +70,13 @@ export const useIsValidConnection = () => { return edge.target === target && edge.targetHandle === targetHandle; }) && // except CollectionItem inputs can have multiples - targetType !== 'CollectionItem' + targetField.type.name !== 'CollectionItemField' ) { return false; } - if (!validateSourceAndTargetTypes(sourceType, targetType)) { + // Must use the originalType here if it exists + if (!validateSourceAndTargetTypes(sourceField.type, targetField.type)) { return false; } diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useLoadWorkflowFromFile.tsx b/invokeai/frontend/web/src/features/nodes/hooks/useLoadWorkflowFromFile.tsx index 890fa7a72d8..3646e8dc585 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useLoadWorkflowFromFile.tsx +++ b/invokeai/frontend/web/src/features/nodes/hooks/useLoadWorkflowFromFile.tsx @@ -1,17 +1,15 @@ import { ListItem, Text, UnorderedList } from '@chakra-ui/react'; import { useLogger } from 'app/logging/useLogger'; import { useAppDispatch } from 'app/store/storeHooks'; -import { parseify } from 'common/util/serialize'; -import { zWorkflow } from 'features/nodes/types/types'; import { addToast } from 'features/system/store/systemSlice'; import { makeToast } from 'features/system/util/makeToast'; -import { memo, useCallback } from 'react'; +import { RefObject, memo, useCallback } from 'react'; +import { useTranslation } from 'react-i18next'; import { ZodError } from 'zod'; -import { fromZodError, fromZodIssue } from 'zod-validation-error'; +import { fromZodIssue } from 'zod-validation-error'; import { workflowLoadRequested } from '../store/actions'; -import { useTranslation } from 'react-i18next'; -export const useLoadWorkflowFromFile = () => { +export const useLoadWorkflowFromFile = (resetRef: RefObject<() => void>) => { const dispatch = useAppDispatch(); const logger = useLogger('nodes'); const { t } = useTranslation(); @@ -26,33 +24,10 @@ export const useLoadWorkflowFromFile = () => { try { const parsedJSON = JSON.parse(String(rawJSON)); - const result = zWorkflow.safeParse(parsedJSON); - - if (!result.success) { - const { message } = fromZodError(result.error, { - prefix: t('nodes.workflowValidation'), - }); - - logger.error({ error: parseify(result.error) }, message); - - dispatch( - addToast( - makeToast({ - title: t('nodes.unableToValidateWorkflow'), - status: 'error', - duration: 5000, - }) - ) - ); - reader.abort(); - return; - } - - dispatch(workflowLoadRequested(result.data)); - - reader.abort(); - } catch { - // file reader error + dispatch(workflowLoadRequested(parsedJSON)); + } catch (e) { + // There was a problem reading the file + logger.error(t('nodes.unableToLoadWorkflow')); dispatch( addToast( makeToast({ @@ -61,12 +36,15 @@ export const useLoadWorkflowFromFile = () => { }) ) ); + reader.abort(); } }; reader.readAsText(file); + // Reset the file picker internal state so that the same file can be loaded again + resetRef.current?.(); }, - [dispatch, logger, t] + [dispatch, logger, resetRef, t] ); return loadWorkflowFromFile; diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useNodeLabel.ts b/invokeai/frontend/web/src/features/nodes/hooks/useNodeLabel.ts index f9bbe4cc1d1..edce18b52b7 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useNodeLabel.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useNodeLabel.ts @@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store'; import { useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { useMemo } from 'react'; -import { isInvocationNode } from '../types/types'; +import { isInvocationNode } from '../types/invocation'; export const useNodeLabel = (nodeId: string) => { const selector = useMemo( diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useNodeNeedsUpdate.ts b/invokeai/frontend/web/src/features/nodes/hooks/useNodeNeedsUpdate.ts new file mode 100644 index 00000000000..99a7c47170e --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/hooks/useNodeNeedsUpdate.ts @@ -0,0 +1,35 @@ +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import { useMemo } from 'react'; +import { isInvocationNode } from '../types/invocation'; +import { getNeedsUpdate } from '../store/util/nodeUpdate'; + +export const useNodeNeedsUpdate = (nodeId: string) => { + const selector = useMemo( + () => + createSelector( + stateSelector, + ({ nodes }) => { + const node = nodes.nodes.find((node) => node.id === nodeId); + const template = nodes.nodeTemplates[node?.data.type ?? '']; + return { node, template }; + }, + defaultSelectorOptions + ), + [nodeId] + ); + + const { node, template } = useAppSelector(selector); + + const needsUpdate = useMemo( + () => + isInvocationNode(node) && template + ? getNeedsUpdate(node, template) + : false, + [node, template] + ); + + return needsUpdate; +}; diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplateByType.ts b/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplateByType.ts index 6fd06155633..83012d0830a 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplateByType.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplateByType.ts @@ -3,16 +3,14 @@ import { stateSelector } from 'app/store/store'; import { useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { useMemo } from 'react'; -import { AnyInvocationType } from 'services/events/types'; +import { InvocationTemplate } from '../types/invocation'; -export const useNodeTemplateByType = ( - type: AnyInvocationType | 'current_image' | 'notes' -) => { +export const useNodeTemplateByType = (type: string) => { const selector = useMemo( () => createSelector( stateSelector, - ({ nodes }) => { + ({ nodes }): InvocationTemplate | undefined => { const nodeTemplate = nodes.nodeTemplates[type]; return nodeTemplate; }, diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplateTitle.ts b/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplateTitle.ts index 4ef3eed5d92..c3dc1507359 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplateTitle.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplateTitle.ts @@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store'; import { useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { useMemo } from 'react'; -import { isInvocationNode } from '../types/types'; +import { isInvocationNode } from '../types/invocation'; export const useNodeTemplateTitle = (nodeId: string) => { const selector = useMemo( diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useNodeVersion.ts b/invokeai/frontend/web/src/features/nodes/hooks/useNodeVersion.ts deleted file mode 100644 index 1f213d64810..00000000000 --- a/invokeai/frontend/web/src/features/nodes/hooks/useNodeVersion.ts +++ /dev/null @@ -1,119 +0,0 @@ -import { createSelector } from '@reduxjs/toolkit'; -import { stateSelector } from 'app/store/store'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; -import { satisfies } from 'compare-versions'; -import { cloneDeep, defaultsDeep } from 'lodash-es'; -import { useCallback, useMemo } from 'react'; -import { Node } from 'reactflow'; -import { AnyInvocationType } from 'services/events/types'; -import { nodeReplaced } from '../store/nodesSlice'; -import { buildNodeData } from '../store/util/buildNodeData'; -import { - InvocationNodeData, - InvocationTemplate, - NodeData, - isInvocationNode, - zParsedSemver, -} from '../types/types'; -import { useAppToaster } from 'app/components/Toaster'; -import { useTranslation } from 'react-i18next'; - -export const getNeedsUpdate = ( - node?: Node, - template?: InvocationTemplate -) => { - if (!isInvocationNode(node) || !template) { - return false; - } - return node.data.version !== template.version; -}; - -export const getMayUpdateNode = ( - node?: Node, - template?: InvocationTemplate -) => { - const needsUpdate = getNeedsUpdate(node, template); - if ( - !needsUpdate || - !isInvocationNode(node) || - !template || - !node.data.version - ) { - return false; - } - const templateMajor = zParsedSemver.parse(template.version).major; - - return satisfies(node.data.version, `^${templateMajor}`); -}; - -export const updateNode = ( - node?: Node, - template?: InvocationTemplate -) => { - const mayUpdate = getMayUpdateNode(node, template); - if ( - !mayUpdate || - !isInvocationNode(node) || - !template || - !node.data.version - ) { - return; - } - - const defaults = buildNodeData( - node.data.type as AnyInvocationType, - node.position, - template - ) as Node; - - const clone = cloneDeep(node); - clone.data.version = template.version; - defaultsDeep(clone, defaults); - return clone; -}; - -export const useNodeVersion = (nodeId: string) => { - const dispatch = useAppDispatch(); - const toast = useAppToaster(); - const { t } = useTranslation(); - const selector = useMemo( - () => - createSelector( - stateSelector, - ({ nodes }) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - const nodeTemplate = nodes.nodeTemplates[node?.data.type ?? '']; - return { node, nodeTemplate }; - }, - defaultSelectorOptions - ), - [nodeId] - ); - - const { node, nodeTemplate } = useAppSelector(selector); - - const needsUpdate = useMemo( - () => getNeedsUpdate(node, nodeTemplate), - [node, nodeTemplate] - ); - - const mayUpdate = useMemo( - () => getMayUpdateNode(node, nodeTemplate), - [node, nodeTemplate] - ); - - const _updateNode = useCallback(() => { - const needsUpdate = getNeedsUpdate(node, nodeTemplate); - const updatedNode = updateNode(node, nodeTemplate); - if (!updatedNode) { - if (needsUpdate) { - toast({ title: t('nodes.unableToUpdateNodes', { count: 1 }) }); - } - return; - } - dispatch(nodeReplaced({ nodeId: updatedNode.id, node: updatedNode })); - }, [dispatch, node, nodeTemplate, t, toast]); - - return { needsUpdate, mayUpdate, updateNode: _updateNode }; -}; diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useOutputFieldNames.ts b/invokeai/frontend/web/src/features/nodes/hooks/useOutputFieldNames.ts index e0a1e5433ed..93e4ccb8334 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useOutputFieldNames.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useOutputFieldNames.ts @@ -4,7 +4,7 @@ import { useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { map } from 'lodash-es'; import { useMemo } from 'react'; -import { isInvocationNode } from '../types/types'; +import { isInvocationNode } from '../types/invocation'; import { getSortedFilteredFieldNames } from '../util/getSortedFilteredFieldNames'; export const useOutputFieldNames = (nodeId: string) => { diff --git a/invokeai/frontend/web/src/features/nodes/hooks/usePrettyFieldType.ts b/invokeai/frontend/web/src/features/nodes/hooks/usePrettyFieldType.ts new file mode 100644 index 00000000000..bff58738649 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/hooks/usePrettyFieldType.ts @@ -0,0 +1,23 @@ +import { useTranslation } from 'react-i18next'; +import { FieldType } from '../types/field'; +import { useMemo } from 'react'; + +export const useFieldTypeName = (fieldType?: FieldType): string => { + const { t } = useTranslation(); + + const name = useMemo(() => { + if (!fieldType) { + return ''; + } + const { name } = fieldType; + if (fieldType.isCollection) { + return t('nodes.collectionFieldType', { name }); + } + if (fieldType.isPolymorphic) { + return t('nodes.polymorphicFieldType', { name }); + } + return name; + }, [fieldType, t]); + + return name; +}; diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useUseCache.ts b/invokeai/frontend/web/src/features/nodes/hooks/useUseCache.ts index 7416d7e66eb..e05cbd818fe 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useUseCache.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useUseCache.ts @@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store'; import { useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { useMemo } from 'react'; -import { isInvocationNode } from '../types/types'; +import { isInvocationNode } from '../types/invocation'; export const useUseCache = (nodeId: string) => { const selector = useMemo( diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useWithWorkflow.ts b/invokeai/frontend/web/src/features/nodes/hooks/useWithWorkflow.ts index 3c83e01731c..c495c54974f 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useWithWorkflow.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useWithWorkflow.ts @@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store'; import { useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { useMemo } from 'react'; -import { isInvocationNode } from '../types/types'; +import { isInvocationNode } from '../types/invocation'; export const useWithWorkflow = (nodeId: string) => { const selector = useMemo( diff --git a/invokeai/frontend/web/src/features/nodes/store/actions.ts b/invokeai/frontend/web/src/features/nodes/store/actions.ts index 0d75e6934db..5dd5344a99c 100644 --- a/invokeai/frontend/web/src/features/nodes/store/actions.ts +++ b/invokeai/frontend/web/src/features/nodes/store/actions.ts @@ -1,6 +1,5 @@ import { createAction, isAnyOf } from '@reduxjs/toolkit'; import { Graph } from 'services/api/types'; -import { Workflow } from '../types/types'; export const textToImageGraphBuilt = createAction( 'nodes/textToImageGraphBuilt' @@ -18,7 +17,7 @@ export const isAnyGraphBuilt = isAnyOf( nodesGraphBuilt ); -export const workflowLoadRequested = createAction( +export const workflowLoadRequested = createAction( 'nodes/workflowLoadRequested' ); diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesPersistDenylist.ts b/invokeai/frontend/web/src/features/nodes/store/nodesPersistDenylist.ts index 64fee2293f9..1322bafa431 100644 --- a/invokeai/frontend/web/src/features/nodes/store/nodesPersistDenylist.ts +++ b/invokeai/frontend/web/src/features/nodes/store/nodesPersistDenylist.ts @@ -6,7 +6,7 @@ import { NodesState } from './types'; export const nodesPersistDenylist: (keyof NodesState)[] = [ 'nodeTemplates', 'connectionStartParams', - 'currentConnectionFieldType', + 'connectionStartFieldType', 'selectedNodes', 'selectedEdges', 'isReady', diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts index 3acef5978f2..0c21d02feda 100644 --- a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts @@ -20,7 +20,6 @@ import { XYPosition, } from 'reactflow'; import { receivedOpenAPISchema } from 'services/api/thunks/schema'; -import { ImageField } from 'services/api/types'; import { appSocketGeneratorProgress, appSocketInvocationComplete, @@ -31,60 +30,58 @@ import { import { v4 as uuidv4 } from 'uuid'; import { DRAG_HANDLE_CLASSNAME } from '../types/constants'; import { - BoardInputFieldValue, - BooleanInputFieldValue, - ColorInputFieldValue, - ControlNetModelInputFieldValue, - CurrentImageNodeData, - EnumInputFieldValue, + BoardFieldValue, + BooleanFieldValue, + ColorFieldValue, + ControlNetModelFieldValue, + EnumFieldValue, FieldIdentifier, - FloatInputFieldValue, - ImageInputFieldValue, - InputFieldValue, - IntegerInputFieldValue, - InvocationNodeData, + FieldValue, + FloatFieldValue, + ImageFieldValue, + IntegerFieldValue, + IPAdapterModelFieldValue, + LoRAModelFieldValue, + MainModelFieldValue, + SchedulerFieldValue, + SDXLRefinerModelFieldValue, + StringFieldValue, + T2IAdapterModelFieldValue, + VAEModelFieldValue, +} from '../types/field'; +import { + AnyNodeData, InvocationTemplate, - IPAdapterModelInputFieldValue, isInvocationNode, isNotesNode, - LoRAModelInputFieldValue, - MainModelInputFieldValue, NodeExecutionState, - NodeStatus, - NotesNodeData, - SchedulerInputFieldValue, - SDXLRefinerModelInputFieldValue, - StringInputFieldValue, - T2IAdapterModelInputFieldValue, - VaeModelInputFieldValue, - Workflow, -} from '../types/types'; + zNodeStatus, +} from '../types/invocation'; +import { WorkflowV2 } from '../types/workflow'; import { NodesState } from './types'; -import { findUnoccupiedPosition } from './util/findUnoccupiedPosition'; import { findConnectionToValidHandle } from './util/findConnectionToValidHandle'; - -export const WORKFLOW_FORMAT_VERSION = '1.0.0'; +import { findUnoccupiedPosition } from './util/findUnoccupiedPosition'; const initialNodeExecutionState: Omit = { - status: NodeStatus.PENDING, + status: zNodeStatus.enum.PENDING, error: null, progress: null, progressImage: null, outputs: [], }; -export const initialWorkflow = { - meta: { - version: WORKFLOW_FORMAT_VERSION, - }, +const INITIAL_WORKFLOW: WorkflowV2 = { name: '', author: '', description: '', - notes: '', - tags: '', - contact: '', version: '', + contact: '', + tags: '', + notes: '', + nodes: [], + edges: [], exposedFields: [], + meta: { version: '2.0.0' }, }; export const initialNodesState: NodesState = { @@ -93,11 +90,10 @@ export const initialNodesState: NodesState = { nodeTemplates: {}, isReady: false, connectionStartParams: null, - currentConnectionFieldType: null, + connectionStartFieldType: null, connectionMade: false, modifyingEdge: false, addNewNodePosition: null, - shouldShowFieldTypeLegend: false, shouldShowMinimapPanel: true, shouldValidateGraph: true, shouldAnimateEdges: true, @@ -107,7 +103,7 @@ export const initialNodesState: NodesState = { nodeOpacity: 1, selectedNodes: [], selectedEdges: [], - workflow: initialWorkflow, + workflow: INITIAL_WORKFLOW, nodeExecutionStates: {}, viewport: { x: 0, y: 0, zoom: 1 }, mouseOverField: null, @@ -117,13 +113,13 @@ export const initialNodesState: NodesState = { selectionMode: SelectionMode.Partial, }; -type FieldValueAction = PayloadAction<{ +type FieldValueAction = PayloadAction<{ nodeId: string; fieldName: string; - value: T['value']; + value: T; }>; -const fieldValueReducer = ( +const fieldValueReducer = ( state: NodesState, action: FieldValueAction ) => { @@ -161,12 +157,7 @@ const nodesSlice = createSlice({ } state.nodes[nodeIndex] = action.payload.node; }, - nodeAdded: ( - state, - action: PayloadAction< - Node - > - ) => { + nodeAdded: (state, action: PayloadAction>) => { const node = action.payload; const position = findUnoccupiedPosition( state.nodes, @@ -203,7 +194,7 @@ const nodesSlice = createSlice({ nodeId && handleId && handleType && - state.currentConnectionFieldType + state.connectionStartFieldType ) { const newConnection = findConnectionToValidHandle( node, @@ -212,7 +203,7 @@ const nodesSlice = createSlice({ nodeId, handleId, handleType, - state.currentConnectionFieldType + state.connectionStartFieldType ); if (newConnection) { state.edges = addEdge( @@ -224,7 +215,7 @@ const nodesSlice = createSlice({ } state.connectionStartParams = null; - state.currentConnectionFieldType = null; + state.connectionStartFieldType = null; }, edgeChangeStarted: (state) => { state.modifyingEdge = true; @@ -258,10 +249,10 @@ const nodesSlice = createSlice({ handleType === 'source' ? node.data.outputs[handleId] : node.data.inputs[handleId]; - state.currentConnectionFieldType = field?.type ?? null; + state.connectionStartFieldType = field?.type ?? null; }, connectionMade: (state, action: PayloadAction) => { - const fieldType = state.currentConnectionFieldType; + const fieldType = state.connectionStartFieldType; if (!fieldType) { return; } @@ -286,7 +277,7 @@ const nodesSlice = createSlice({ nodeId && handleId && handleType && - state.currentConnectionFieldType + state.connectionStartFieldType ) { const newConnection = findConnectionToValidHandle( mouseOverNode, @@ -295,7 +286,7 @@ const nodesSlice = createSlice({ nodeId, handleId, handleType, - state.currentConnectionFieldType + state.connectionStartFieldType ); if (newConnection) { state.edges = addEdge( @@ -306,14 +297,14 @@ const nodesSlice = createSlice({ } } state.connectionStartParams = null; - state.currentConnectionFieldType = null; + state.connectionStartFieldType = null; } else { state.addNewNodePosition = action.payload.cursorPosition; state.isAddNodePopoverOpen = true; } } else { state.connectionStartParams = null; - state.currentConnectionFieldType = null; + state.connectionStartFieldType = null; } state.modifyingEdge = false; }, @@ -529,12 +520,7 @@ const nodesSlice = createSlice({ state.edges = applyEdgeChanges(edgeChanges, state.edges); } }, - nodesDeleted: ( - state, - action: PayloadAction< - Node[] - > - ) => { + nodesDeleted: (state, action: PayloadAction[]>) => { action.payload.forEach((node) => { state.workflow.exposedFields = state.workflow.exposedFields.filter( (f) => f.nodeId !== node.id @@ -588,132 +574,94 @@ const nodesSlice = createSlice({ }, fieldStringValueChanged: ( state, - action: FieldValueAction + action: FieldValueAction ) => { fieldValueReducer(state, action); }, fieldNumberValueChanged: ( state, - action: FieldValueAction + action: FieldValueAction ) => { fieldValueReducer(state, action); }, fieldBooleanValueChanged: ( state, - action: FieldValueAction + action: FieldValueAction ) => { fieldValueReducer(state, action); }, fieldBoardValueChanged: ( state, - action: FieldValueAction + action: FieldValueAction ) => { fieldValueReducer(state, action); }, fieldImageValueChanged: ( state, - action: FieldValueAction + action: FieldValueAction ) => { fieldValueReducer(state, action); }, fieldColorValueChanged: ( state, - action: FieldValueAction + action: FieldValueAction ) => { fieldValueReducer(state, action); }, fieldMainModelValueChanged: ( state, - action: FieldValueAction + action: FieldValueAction ) => { fieldValueReducer(state, action); }, fieldRefinerModelValueChanged: ( state, - action: FieldValueAction + action: FieldValueAction ) => { fieldValueReducer(state, action); }, fieldVaeModelValueChanged: ( state, - action: FieldValueAction + action: FieldValueAction ) => { fieldValueReducer(state, action); }, fieldLoRAModelValueChanged: ( state, - action: FieldValueAction + action: FieldValueAction ) => { fieldValueReducer(state, action); }, fieldControlNetModelValueChanged: ( state, - action: FieldValueAction + action: FieldValueAction ) => { fieldValueReducer(state, action); }, fieldIPAdapterModelValueChanged: ( state, - action: FieldValueAction + action: FieldValueAction ) => { fieldValueReducer(state, action); }, fieldT2IAdapterModelValueChanged: ( state, - action: FieldValueAction + action: FieldValueAction ) => { fieldValueReducer(state, action); }, fieldEnumModelValueChanged: ( state, - action: FieldValueAction + action: FieldValueAction ) => { fieldValueReducer(state, action); }, fieldSchedulerValueChanged: ( state, - action: FieldValueAction + action: FieldValueAction ) => { fieldValueReducer(state, action); }, - imageCollectionFieldValueChanged: ( - state, - action: PayloadAction<{ - nodeId: string; - fieldName: string; - value: ImageField[]; - }> - ) => { - const { nodeId, fieldName, value } = action.payload; - const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId); - - if (nodeIndex === -1) { - return; - } - - const node = state.nodes?.[nodeIndex]; - - if (!isInvocationNode(node)) { - return; - } - - const input = node.data?.inputs[fieldName]; - if (!input) { - return; - } - - const currentValue = cloneDeep(input.value); - - if (!currentValue) { - input.value = value; - return; - } - - input.value = uniqBy( - (currentValue as ImageField[]).concat(value), - 'image_name' - ); - }, notesNodeValueChanged: ( state, action: PayloadAction<{ nodeId: string; value: string }> @@ -726,12 +674,6 @@ const nodesSlice = createSlice({ } node.data.notes = value; }, - shouldShowFieldTypeLegendChanged: ( - state, - action: PayloadAction - ) => { - state.shouldShowFieldTypeLegend = action.payload; - }, shouldShowMinimapPanelChanged: (state, action: PayloadAction) => { state.shouldShowMinimapPanel = action.payload; }, @@ -745,7 +687,7 @@ const nodesSlice = createSlice({ nodeEditorReset: (state) => { state.nodes = []; state.edges = []; - state.workflow = cloneDeep(initialWorkflow); + state.workflow = cloneDeep(INITIAL_WORKFLOW); }, shouldValidateGraphChanged: (state, action: PayloadAction) => { state.shouldValidateGraph = action.payload; @@ -783,7 +725,7 @@ const nodesSlice = createSlice({ workflowContactChanged: (state, action: PayloadAction) => { state.workflow.contact = action.payload; }, - workflowLoaded: (state, action: PayloadAction) => { + workflowLoaded: (state, action: PayloadAction) => { const { nodes, edges, ...workflow } = action.payload; state.workflow = workflow; @@ -810,7 +752,7 @@ const nodesSlice = createSlice({ }, {}); }, workflowReset: (state) => { - state.workflow = cloneDeep(initialWorkflow); + state.workflow = cloneDeep(INITIAL_WORKFLOW); }, viewportChanged: (state, action: PayloadAction) => { state.viewport = action.payload; @@ -942,7 +884,7 @@ const nodesSlice = createSlice({ //Make sure these get reset if we close the popover and haven't selected a node state.connectionStartParams = null; - state.currentConnectionFieldType = null; + state.connectionStartFieldType = null; }, addNodePopoverToggled: (state) => { state.isAddNodePopoverOpen = !state.isAddNodePopoverOpen; @@ -961,14 +903,14 @@ const nodesSlice = createSlice({ const { source_node_id } = action.payload.data; const node = state.nodeExecutionStates[source_node_id]; if (node) { - node.status = NodeStatus.IN_PROGRESS; + node.status = zNodeStatus.enum.IN_PROGRESS; } }); builder.addCase(appSocketInvocationComplete, (state, action) => { const { source_node_id, result } = action.payload.data; const nes = state.nodeExecutionStates[source_node_id]; if (nes) { - nes.status = NodeStatus.COMPLETED; + nes.status = zNodeStatus.enum.COMPLETED; if (nes.progress !== null) { nes.progress = 1; } @@ -979,7 +921,7 @@ const nodesSlice = createSlice({ const { source_node_id } = action.payload.data; const node = state.nodeExecutionStates[source_node_id]; if (node) { - node.status = NodeStatus.FAILED; + node.status = zNodeStatus.enum.FAILED; node.error = action.payload.data.error; node.progress = null; node.progressImage = null; @@ -990,7 +932,7 @@ const nodesSlice = createSlice({ action.payload.data; const node = state.nodeExecutionStates[source_node_id]; if (node) { - node.status = NodeStatus.IN_PROGRESS; + node.status = zNodeStatus.enum.IN_PROGRESS; node.progress = (step + 1) / total_steps; node.progressImage = progress_image ?? null; } @@ -998,7 +940,7 @@ const nodesSlice = createSlice({ builder.addCase(appSocketQueueItemStatusChanged, (state, action) => { if (['in_progress'].includes(action.payload.data.queue_item.status)) { forEach(state.nodeExecutionStates, (nes) => { - nes.status = NodeStatus.PENDING; + nes.status = zNodeStatus.enum.PENDING; nes.error = null; nes.progress = null; nes.progressImage = null; @@ -1037,7 +979,6 @@ export const { fieldSchedulerValueChanged, fieldStringValueChanged, fieldVaeModelValueChanged, - imageCollectionFieldValueChanged, mouseOverFieldChanged, mouseOverNodeChanged, nodeAdded, @@ -1063,7 +1004,6 @@ export const { selectionPasted, shouldAnimateEdgesChanged, shouldColorEdgesChanged, - shouldShowFieldTypeLegendChanged, shouldShowMinimapPanelChanged, shouldSnapToGridChanged, shouldValidateGraphChanged, diff --git a/invokeai/frontend/web/src/features/nodes/store/types.ts b/invokeai/frontend/web/src/features/nodes/store/types.ts index f6bfa7cad8b..b865b9d3a10 100644 --- a/invokeai/frontend/web/src/features/nodes/store/types.ts +++ b/invokeai/frontend/web/src/features/nodes/store/types.ts @@ -6,25 +6,23 @@ import { Viewport, XYPosition, } from 'reactflow'; +import { FieldIdentifier, FieldType } from '../types/field'; import { - FieldIdentifier, - FieldType, + AnyNodeData, InvocationEdgeExtra, InvocationTemplate, - NodeData, NodeExecutionState, - Workflow, -} from '../types/types'; +} from '../types/invocation'; +import { WorkflowV2 } from '../types/workflow'; export type NodesState = { - nodes: Node[]; + nodes: Node[]; edges: Edge[]; nodeTemplates: Record; connectionStartParams: OnConnectStartParams | null; - currentConnectionFieldType: FieldType | null; + connectionStartFieldType: FieldType | null; connectionMade: boolean; modifyingEdge: boolean; - shouldShowFieldTypeLegend: boolean; shouldShowMinimapPanel: boolean; shouldValidateGraph: boolean; shouldAnimateEdges: boolean; @@ -33,13 +31,13 @@ export type NodesState = { shouldColorEdges: boolean; selectedNodes: string[]; selectedEdges: string[]; - workflow: Omit; + workflow: Omit; nodeExecutionStates: Record; viewport: Viewport; isReady: boolean; mouseOverField: FieldIdentifier | null; mouseOverNode: string | null; - nodesToCopy: Node[]; + nodesToCopy: Node[]; edgesToCopy: Edge[]; isAddNodePopoverOpen: boolean; addNewNodePosition: XYPosition | null; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/buildNodeData.ts b/invokeai/frontend/web/src/features/nodes/store/util/buildNodeData.ts index 6cecc8c4098..5328f789ad2 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/buildNodeData.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/buildNodeData.ts @@ -1,78 +1,73 @@ import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants'; +import { + FieldInputInstance, + FieldOutputInstance, +} from 'features/nodes/types/field'; import { CurrentImageNodeData, - InputFieldValue, InvocationNodeData, InvocationTemplate, NotesNodeData, - OutputFieldValue, -} from 'features/nodes/types/types'; -import { buildInputFieldValue } from 'features/nodes/util/fieldValueBuilders'; +} from 'features/nodes/types/invocation'; +import { buildFieldInputInstance } from 'features/nodes/util/buildFieldInputInstance'; import { reduce } from 'lodash-es'; import { Node, XYPosition } from 'reactflow'; -import { AnyInvocationType } from 'services/events/types'; import { v4 as uuidv4 } from 'uuid'; export const SHARED_NODE_PROPERTIES: Partial = { dragHandle: `.${DRAG_HANDLE_CLASSNAME}`, }; -export const buildNodeData = ( - type: AnyInvocationType | 'current_image' | 'notes', - position: XYPosition, - template?: InvocationTemplate -): - | Node - | Node - | Node - | undefined => { - const nodeId = uuidv4(); - - if (type === 'current_image') { - const node: Node = { - ...SHARED_NODE_PROPERTIES, - id: nodeId, - type: 'current_image', - position, - data: { - id: nodeId, - type: 'current_image', - isOpen: true, - label: 'Current Image', - }, - }; - return node; - } - - if (type === 'notes') { - const node: Node = { - ...SHARED_NODE_PROPERTIES, +export const buildNotesNode = (position: XYPosition): Node => { + const nodeId = uuidv4(); + const node: Node = { + ...SHARED_NODE_PROPERTIES, + id: nodeId, + type: 'notes', + position, + data: { id: nodeId, + isOpen: true, + label: 'Notes', + notes: '', type: 'notes', - position, - data: { - id: nodeId, - isOpen: true, - label: 'Notes', - notes: '', - type: 'notes', - }, - }; + }, + }; + return node; +}; - return node; - } +export const buildCurrentImageNode = ( + position: XYPosition +): Node => { + const nodeId = uuidv4(); + const node: Node = { + ...SHARED_NODE_PROPERTIES, + id: nodeId, + type: 'current_image', + position, + data: { + id: nodeId, + type: 'current_image', + isOpen: true, + label: 'Current Image', + }, + }; + return node; +}; - if (template === undefined) { - console.error(`Unable to find template ${type}.`); - return; - } +export const buildInvocationNode = ( + position: XYPosition, + template: InvocationTemplate +): Node => { + const nodeId = uuidv4(); + const { type } = template; const inputs = reduce( template.inputs, (inputsAccumulator, inputTemplate, inputName) => { const fieldId = uuidv4(); - const inputFieldValue: InputFieldValue = buildInputFieldValue( + const inputFieldValue: FieldInputInstance = buildFieldInputInstance( fieldId, inputTemplate ); @@ -81,7 +76,7 @@ export const buildNodeData = ( return inputsAccumulator; }, - {} as Record + {} as Record ); const outputs = reduce( @@ -89,7 +84,7 @@ export const buildNodeData = ( (outputsAccumulator, outputTemplate, outputName) => { const fieldId = uuidv4(); - const outputFieldValue: OutputFieldValue = { + const outputFieldValue: FieldOutputInstance = { id: fieldId, name: outputName, type: outputTemplate.type, @@ -100,10 +95,10 @@ export const buildNodeData = ( return outputsAccumulator; }, - {} as Record + {} as Record ); - const invocation: Node = { + const node: Node = { ...SHARED_NODE_PROPERTIES, id: nodeId, type: 'invocation', @@ -117,11 +112,11 @@ export const buildNodeData = ( isOpen: true, embedWorkflow: false, isIntermediate: type === 'save_image' ? false : true, + useCache: template.useCache, inputs, outputs, - useCache: template.useCache, }, }; - return invocation; + return node; }; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/findConnectionToValidHandle.ts b/invokeai/frontend/web/src/features/nodes/store/util/findConnectionToValidHandle.ts index 69386c1f23c..0a7adf77cbc 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/findConnectionToValidHandle.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/findConnectionToValidHandle.ts @@ -1,20 +1,19 @@ -import { Connection, HandleType } from 'reactflow'; -import { Node, Edge } from 'reactflow'; +import { Connection, Edge, HandleType, Node } from 'reactflow'; + import { + FieldInputInstance, + FieldOutputInstance, FieldType, - InputFieldValue, - OutputFieldValue, -} from 'features/nodes/types/types'; - -import { validateSourceAndTargetTypes } from './validateSourceAndTargetTypes'; +} from 'features/nodes/types/field'; import { getIsGraphAcyclic } from './getIsGraphAcyclic'; +import { validateSourceAndTargetTypes } from './validateSourceAndTargetTypes'; const isValidConnection = ( edges: Edge[], handleCurrentType: HandleType, handleCurrentFieldType: FieldType, node: Node, - handle: InputFieldValue | OutputFieldValue + handle: FieldInputInstance | FieldOutputInstance ) => { let isValidConnection = true; if (handleCurrentType === 'source') { diff --git a/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts b/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts index 57dd284b88d..de795612919 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts @@ -1,9 +1,9 @@ import { createSelector } from '@reduxjs/toolkit'; import { stateSelector } from 'app/store/store'; -import { getIsGraphAcyclic } from './getIsGraphAcyclic'; -import { FieldType } from 'features/nodes/types/types'; +import { FieldType } from 'features/nodes/types/field'; import i18n from 'i18next'; import { HandleType } from 'reactflow'; +import { getIsGraphAcyclic } from './getIsGraphAcyclic'; import { validateSourceAndTargetTypes } from './validateSourceAndTargetTypes'; /** @@ -17,15 +17,15 @@ export const makeConnectionErrorSelector = ( handleType: HandleType, fieldType?: FieldType ) => { - return createSelector(stateSelector, (state) => { + return createSelector(stateSelector, (state): string | undefined => { if (!fieldType) { return i18n.t('nodes.noFieldType'); } - const { currentConnectionFieldType, connectionStartParams, nodes, edges } = + const { connectionStartFieldType, connectionStartParams, nodes, edges } = state.nodes; - if (!connectionStartParams || !currentConnectionFieldType) { + if (!connectionStartParams || !connectionStartFieldType) { return i18n.t('nodes.noConnectionInProgress'); } @@ -40,9 +40,9 @@ export const makeConnectionErrorSelector = ( } const targetType = - handleType === 'target' ? fieldType : currentConnectionFieldType; + handleType === 'target' ? fieldType : connectionStartFieldType; const sourceType = - handleType === 'source' ? fieldType : currentConnectionFieldType; + handleType === 'source' ? fieldType : connectionStartFieldType; if (nodeId === connectionNodeId) { return i18n.t('nodes.cannotConnectToSelf'); @@ -80,7 +80,7 @@ export const makeConnectionErrorSelector = ( return edge.target === target && edge.targetHandle === targetHandle; }) && // except CollectionItem inputs can have multiples - targetType !== 'CollectionItem' + targetType.name !== 'CollectionItemField' ) { return i18n.t('nodes.inputMayOnlyHaveOneConnection'); } @@ -100,6 +100,6 @@ export const makeConnectionErrorSelector = ( return i18n.t('nodes.connectionWouldCreateCycle'); } - return null; + return; }); }; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/nodeUpdate.ts b/invokeai/frontend/web/src/features/nodes/store/util/nodeUpdate.ts new file mode 100644 index 00000000000..e9e24823f9a --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/store/util/nodeUpdate.ts @@ -0,0 +1,68 @@ +import { satisfies } from 'compare-versions'; +import { NodeUpdateError } from 'features/nodes/types/error'; +import { + InvocationNodeData, + InvocationTemplate, +} from 'features/nodes/types/invocation'; +import { zParsedSemver } from 'features/nodes/types/semver'; +import { cloneDeep, defaultsDeep } from 'lodash-es'; +import { Node } from 'reactflow'; +import { buildInvocationNode } from './buildNodeData'; + +export const getNeedsUpdate = ( + node: Node, + template: InvocationTemplate +): boolean => { + if (node.data.type !== template.type) { + return true; + } + return node.data.version !== template.version; +}; /** + * Checks if a node may be updated by comparing its major version with the template's major version. + * @param node The node to check. + * @param template The invocation template to check against. + */ + +export const getMayUpdateNode = ( + node: Node, + template: InvocationTemplate +): boolean => { + const needsUpdate = getNeedsUpdate(node, template); + if (!needsUpdate || node.data.type !== template.type) { + return false; + } + const templateMajor = zParsedSemver.parse(template.version).major; + + return satisfies(node.data.version, `^${templateMajor}`); +}; /** + * Updates a node to the latest version of its template: + * - Create a new node data object with the latest version of the template. + * - Recursively merge new node data object into the node to be updated. + * + * @param node The node to updated. + * @param template The invocation template to update to. + * @throws {NodeUpdateError} If the node is not an invocation node. + */ + +export const updateNode = ( + node: Node, + template: InvocationTemplate +): Node => { + const mayUpdate = getMayUpdateNode(node, template); + + if (!mayUpdate || node.data.type !== template.type) { + throw new NodeUpdateError(`Unable to update node ${node.id}`); + } + + // Start with a "fresh" node - just as if the user created a new node of this type + const defaults = buildInvocationNode(node.position, template); + + // The updateability of a node, via semver comparison, relies on the this kind of recursive merge + // being valid. We rely on the template's major version to be majorly incremented if this kind of + // merge would result in an invalid node. + const clone = cloneDeep(node); + clone.data.version = template.version; + defaultsDeep(clone, defaults); // mutates! + + return clone; +}; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/validateSourceAndTargetTypes.ts b/invokeai/frontend/web/src/features/nodes/store/util/validateSourceAndTargetTypes.ts index 2f47e47a787..2770af19e35 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/validateSourceAndTargetTypes.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/validateSourceAndTargetTypes.ts @@ -1,11 +1,12 @@ -import { - COLLECTION_MAP, - COLLECTION_TYPES, - POLYMORPHIC_TO_SINGLE_MAP, - POLYMORPHIC_TYPES, -} from 'features/nodes/types/constants'; -import { FieldType } from 'features/nodes/types/types'; +import { FieldType } from 'features/nodes/types/field'; +import { isEqual } from 'lodash-es'; +/** + * Validates that the source and target types are compatible for a connection. + * @param sourceType The type of the source field. + * @param targetType The type of the target field. + * @returns True if the connection is valid, false otherwise. + */ export const validateSourceAndTargetTypes = ( sourceType: FieldType, targetType: FieldType @@ -13,11 +14,14 @@ export const validateSourceAndTargetTypes = ( // TODO: There's a bug with Collect -> Iterate nodes: // https://github.com/invoke-ai/InvokeAI/issues/3956 // Once this is resolved, we can remove this check. - if (sourceType === 'Collection' && targetType === 'Collection') { + if ( + sourceType.name === 'CollectionField' && + targetType.name === 'CollectionField' + ) { return false; } - if (sourceType === targetType) { + if (isEqual(sourceType, targetType)) { return true; } @@ -31,46 +35,42 @@ export const validateSourceAndTargetTypes = ( */ const isCollectionItemToNonCollection = - sourceType === 'CollectionItem' && !COLLECTION_TYPES.includes(targetType); + sourceType.name === 'CollectionItemField' && !targetType.isCollection; const isNonCollectionToCollectionItem = - targetType === 'CollectionItem' && - !COLLECTION_TYPES.includes(sourceType) && - !POLYMORPHIC_TYPES.includes(sourceType); + targetType.name === 'CollectionItemField' && + !sourceType.isCollection && + !sourceType.isPolymorphic; const isAnythingToPolymorphicOfSameBaseType = - POLYMORPHIC_TYPES.includes(targetType) && - (() => { - if (!POLYMORPHIC_TYPES.includes(targetType)) { - return false; - } - const baseType = - POLYMORPHIC_TO_SINGLE_MAP[ - targetType as keyof typeof POLYMORPHIC_TO_SINGLE_MAP - ]; - - const collectionType = - COLLECTION_MAP[baseType as keyof typeof COLLECTION_MAP]; - - return sourceType === baseType || sourceType === collectionType; - })(); + targetType.isPolymorphic && sourceType.name === targetType.name; const isGenericCollectionToAnyCollectionOrPolymorphic = - sourceType === 'Collection' && - (COLLECTION_TYPES.includes(targetType) || - POLYMORPHIC_TYPES.includes(targetType)); + sourceType.name === 'CollectionField' && + (targetType.isCollection || targetType.isPolymorphic); const isCollectionToGenericCollection = - targetType === 'Collection' && COLLECTION_TYPES.includes(sourceType); + targetType.name === 'CollectionField' && sourceType.isCollection; + + const areBothTypesSingle = + !sourceType.isCollection && + !sourceType.isPolymorphic && + !targetType.isCollection && + !targetType.isPolymorphic; - const isIntToFloat = sourceType === 'integer' && targetType === 'float'; + const isIntToFloat = + areBothTypesSingle && + sourceType.name === 'IntegerField' && + targetType.name === 'FloatField'; const isIntOrFloatToString = - (sourceType === 'integer' || sourceType === 'float') && - targetType === 'string'; + areBothTypesSingle && + (sourceType.name === 'IntegerField' || sourceType.name === 'FloatField') && + targetType.name === 'StringField'; - const isTargetAnyType = targetType === 'Any'; + const isTargetAnyType = targetType.name === 'AnyField'; + // One of these must be true for the connection to be valid return ( isCollectionItemToNonCollection || isNonCollectionToCollectionItem || diff --git a/invokeai/frontend/web/src/features/nodes/types/common.ts b/invokeai/frontend/web/src/features/nodes/types/common.ts new file mode 100644 index 00000000000..0cab248c80e --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/types/common.ts @@ -0,0 +1,216 @@ +import { z } from 'zod'; + +// #region Field data schemas +export const zImageField = z.object({ + image_name: z.string().trim().min(1), +}); +export type ImageField = z.infer; + +export const zBoardField = z.object({ + board_id: z.string().trim().min(1), +}); +export type BoardField = z.infer; + +export const zColorField = z.object({ + r: z.number().int().min(0).max(255), + g: z.number().int().min(0).max(255), + b: z.number().int().min(0).max(255), + a: z.number().int().min(0).max(255), +}); +export type ColorField = z.infer; + +export const zSchedulerField = z.enum([ + 'euler', + 'deis', + 'ddim', + 'ddpm', + 'dpmpp_2s', + 'dpmpp_2m', + 'dpmpp_2m_sde', + 'dpmpp_sde', + 'heun', + 'kdpm_2', + 'lms', + 'pndm', + 'unipc', + 'euler_k', + 'dpmpp_2s_k', + 'dpmpp_2m_k', + 'dpmpp_2m_sde_k', + 'dpmpp_sde_k', + 'heun_k', + 'lms_k', + 'euler_a', + 'kdpm_2_a', + 'lcm', +]); +export type SchedulerField = z.infer; +// #endregion + +// #region Model-related schemas +export const zBaseModel = z.enum([ + 'any', + 'sd-1', + 'sd-2', + 'sdxl', + 'sdxl-refiner', +]); +export const zModelType = z.enum([ + 'onnx', + 'main', + 'vae', + 'lora', + 'controlnet', + 'embedding', +]); +export const zModelName = z.string().trim().min(1); +export const zModelIdentifier = z.object({ + model_name: zModelName, + base_model: zBaseModel, +}); +export type BaseModel = z.infer; +export type ModelType = z.infer; +export type ModelIdentifier = z.infer; + +export const zMainModelField = z.object({ + model_name: zModelName, + base_model: zBaseModel, + model_type: z.literal('main'), +}); +export const zONNXModelField = z.object({ + model_name: zModelName, + base_model: zBaseModel, + model_type: z.literal('onnx'), +}); +export const zMainOrONNXModelField = z.union([ + zMainModelField, + zONNXModelField, +]); +export const zSDXLRefinerModelField = z.object({ + model_name: z.string().min(1), + base_model: z.literal('sdxl-refiner'), + model_type: z.literal('main'), +}); +export type MainModelField = z.infer; +export type ONNXModelField = z.infer; +export type MainOrONNXModelField = z.infer; +export type SDXLRefinerModelField = z.infer; + +export const zSubModelType = z.enum([ + 'unet', + 'text_encoder', + 'text_encoder_2', + 'tokenizer', + 'tokenizer_2', + 'vae', + 'vae_decoder', + 'vae_encoder', + 'scheduler', + 'safety_checker', +]); +export type SubModelType = z.infer; + +export const zVAEModelField = zModelIdentifier; + +export const zModelInfo = zModelIdentifier.extend({ + model_type: zModelType, + submodel: zSubModelType.optional(), +}); +export type ModelInfo = z.infer; + +export const zLoRAModelField = zModelIdentifier; +export type LoRAModelField = z.infer; + +export const zControlNetModelField = zModelIdentifier; +export type ControlNetModelField = z.infer; + +export const zIPAdapterModelField = zModelIdentifier; +export type IPAdapterModelField = z.infer; + +export const zT2IAdapterModelField = zModelIdentifier; +export type T2IAdapterModelField = z.infer; + +export const zLoraInfo = zModelInfo.extend({ + weight: z.number().optional(), +}); +export type LoraInfo = z.infer; + +export const zUNetField = z.object({ + unet: zModelInfo, + scheduler: zModelInfo, + loras: z.array(zLoraInfo), +}); +export type UNetField = z.infer; + +export const zCLIPField = z.object({ + tokenizer: zModelInfo, + text_encoder: zModelInfo, + skipped_layers: z.number(), + loras: z.array(zLoraInfo), +}); +export type CLIPField = z.infer; + +export const zVAEField = z.object({ + vae: zModelInfo, +}); +export type VAEField = z.infer; +// #endregion + +// #region Control Adapters +export const zControlField = z.object({ + image: zImageField, + control_model: zControlNetModelField, + control_weight: z.union([z.number(), z.array(z.number())]).optional(), + begin_step_percent: z.number().optional(), + end_step_percent: z.number().optional(), + control_mode: z + .enum(['balanced', 'more_prompt', 'more_control', 'unbalanced']) + .optional(), + resize_mode: z + .enum(['just_resize', 'crop_resize', 'fill_resize', 'just_resize_simple']) + .optional(), +}); +export type ControlField = z.infer; + +export const zIPAdapterField = z.object({ + image: zImageField, + ip_adapter_model: zIPAdapterModelField, + weight: z.number(), + begin_step_percent: z.number().optional(), + end_step_percent: z.number().optional(), +}); +export type IPAdapterField = z.infer; + +export const zT2IAdapterField = z.object({ + image: zImageField, + t2i_adapter_model: zT2IAdapterModelField, + weight: z.union([z.number(), z.array(z.number())]).optional(), + begin_step_percent: z.number().optional(), + end_step_percent: z.number().optional(), + resize_mode: z + .enum(['just_resize', 'crop_resize', 'fill_resize', 'just_resize_simple']) + .optional(), +}); +export type T2IAdapterField = z.infer; +// #endregion + +// #region ProgressImage +export const zProgressImage = z.object({ + dataURL: z.string(), + width: z.number().int(), + height: z.number().int(), +}); +export type ProgressImage = z.infer; +// #endregion + +// #region ImageOutput +export const zImageOutput = z.object({ + image: zImageField, + width: z.number().int(), + height: z.number().int(), + type: z.literal('image_output'), +}); +export type ImageOutput = z.infer; +export const isImageOutput = (output: unknown): output is ImageOutput => + zImageOutput.safeParse(output).success; +// #endregion diff --git a/invokeai/frontend/web/src/features/nodes/types/constants.ts b/invokeai/frontend/web/src/features/nodes/types/constants.ts index c6eec736da0..a97899de91f 100644 --- a/invokeai/frontend/web/src/features/nodes/types/constants.ts +++ b/invokeai/frontend/web/src/features/nodes/types/constants.ts @@ -1,58 +1,31 @@ -import { - FieldType, - FieldTypeMap, - FieldTypeMapWithNumber, - FieldUIConfig, -} from './types'; -import { t } from 'i18next'; - +/** + * How long to wait before showing a tooltip when hovering a field handle. + */ export const HANDLE_TOOLTIP_OPEN_DELAY = 500; -export const COLOR_TOKEN_VALUE = 500; + +/** + * The width of a node in the UI in pixels. + */ export const NODE_WIDTH = 320; -export const NODE_MIN_WIDTH = 320; -export const DRAG_HANDLE_CLASSNAME = 'node-drag-handle'; -export const IMAGE_FIELDS = ['ImageField', 'ImageCollection']; -export const FOOTER_FIELDS = IMAGE_FIELDS; +/** + * This class name is special - reactflow uses it to identify the drag handle of a node, + * applying the appropriate listeners to it. + */ +export const DRAG_HANDLE_CLASSNAME = 'node-drag-handle'; +/** + * Helper for getting the kind of a field. + */ export const KIND_MAP = { input: 'inputs' as const, output: 'outputs' as const, }; -export const COLLECTION_TYPES: FieldType[] = [ - 'Collection', - 'IntegerCollection', - 'BooleanCollection', - 'FloatCollection', - 'StringCollection', - 'ImageCollection', - 'LatentsCollection', - 'ConditioningCollection', - 'ControlCollection', - 'ColorCollection', - 'T2IAdapterCollection', - 'IPAdapterCollection', - 'MetadataItemCollection', - 'MetadataCollection', -]; - -export const POLYMORPHIC_TYPES: FieldType[] = [ - 'IntegerPolymorphic', - 'BooleanPolymorphic', - 'FloatPolymorphic', - 'StringPolymorphic', - 'ImagePolymorphic', - 'LatentsPolymorphic', - 'ConditioningPolymorphic', - 'ControlPolymorphic', - 'ColorPolymorphic', - 'T2IAdapterPolymorphic', - 'IPAdapterPolymorphic', - 'MetadataItemPolymorphic', -]; - -export const MODEL_TYPES: FieldType[] = [ +/** + * Model types' handles are rendered as squares in the UI. + */ +export const MODEL_TYPES = [ 'IPAdapterModelField', 'ControlNetModelField', 'LoRAModelField', @@ -68,373 +41,33 @@ export const MODEL_TYPES: FieldType[] = [ 'IPAdapterModelField', ]; -export const COLLECTION_MAP: FieldTypeMapWithNumber = { - integer: 'IntegerCollection', - boolean: 'BooleanCollection', - number: 'FloatCollection', - float: 'FloatCollection', - string: 'StringCollection', - ImageField: 'ImageCollection', - LatentsField: 'LatentsCollection', - ConditioningField: 'ConditioningCollection', - ControlField: 'ControlCollection', - ColorField: 'ColorCollection', - T2IAdapterField: 'T2IAdapterCollection', - IPAdapterField: 'IPAdapterCollection', - MetadataItemField: 'MetadataItemCollection', - MetadataField: 'MetadataCollection', -}; -export const isCollectionItemType = ( - itemType: string | undefined -): itemType is keyof typeof COLLECTION_MAP => - Boolean(itemType && itemType in COLLECTION_MAP); - -export const SINGLE_TO_POLYMORPHIC_MAP: FieldTypeMapWithNumber = { - integer: 'IntegerPolymorphic', - boolean: 'BooleanPolymorphic', - number: 'FloatPolymorphic', - float: 'FloatPolymorphic', - string: 'StringPolymorphic', - ImageField: 'ImagePolymorphic', - LatentsField: 'LatentsPolymorphic', - ConditioningField: 'ConditioningPolymorphic', - ControlField: 'ControlPolymorphic', - ColorField: 'ColorPolymorphic', - T2IAdapterField: 'T2IAdapterPolymorphic', - IPAdapterField: 'IPAdapterPolymorphic', - MetadataItemField: 'MetadataItemPolymorphic', -}; - -export const POLYMORPHIC_TO_SINGLE_MAP: FieldTypeMap = { - IntegerPolymorphic: 'integer', - BooleanPolymorphic: 'boolean', - FloatPolymorphic: 'float', - StringPolymorphic: 'string', - ImagePolymorphic: 'ImageField', - LatentsPolymorphic: 'LatentsField', - ConditioningPolymorphic: 'ConditioningField', - ControlPolymorphic: 'ControlField', - ColorPolymorphic: 'ColorField', - T2IAdapterPolymorphic: 'T2IAdapterField', - IPAdapterPolymorphic: 'IPAdapterField', - MetadataItemPolymorphic: 'MetadataItemField', -}; - -export const TYPES_WITH_INPUT_COMPONENTS: FieldType[] = [ - 'string', - 'StringPolymorphic', - 'boolean', - 'BooleanPolymorphic', - 'integer', - 'float', - 'FloatPolymorphic', - 'IntegerPolymorphic', - 'enum', - 'ImageField', - 'ImagePolymorphic', - 'MainModelField', - 'SDXLRefinerModelField', - 'VaeModelField', - 'LoRAModelField', - 'ControlNetModelField', - 'ColorField', - 'SDXLMainModelField', - 'Scheduler', - 'IPAdapterModelField', - 'BoardField', - 'T2IAdapterModelField', -]; - -export const isPolymorphicItemType = ( - itemType: string | undefined -): itemType is keyof typeof SINGLE_TO_POLYMORPHIC_MAP => - Boolean(itemType && itemType in SINGLE_TO_POLYMORPHIC_MAP); - -export const FIELDS: Record = { - Any: { - color: 'gray.500', - description: 'Any field type is accepted.', - title: 'Any', - }, - MetadataField: { - color: 'gray.500', - description: 'A metadata dict.', - title: 'Metadata Dict', - }, - MetadataCollection: { - color: 'gray.500', - description: 'A collection of metadata dicts.', - title: 'Metadata Dict Collection', - }, - MetadataItemField: { - color: 'gray.500', - description: 'A metadata item.', - title: 'Metadata Item', - }, - MetadataItemCollection: { - color: 'gray.500', - description: 'Any field type is accepted.', - title: 'Metadata Item Collection', - }, - MetadataItemPolymorphic: { - color: 'gray.500', - description: - 'MetadataItem or MetadataItemCollection field types are accepted.', - title: 'Metadata Item Polymorphic', - }, - boolean: { - color: 'green.500', - description: t('nodes.booleanDescription'), - title: t('nodes.boolean'), - }, - BooleanCollection: { - color: 'green.500', - description: t('nodes.booleanCollectionDescription'), - title: t('nodes.booleanCollection'), - }, - BooleanPolymorphic: { - color: 'green.500', - description: t('nodes.booleanPolymorphicDescription'), - title: t('nodes.booleanPolymorphic'), - }, - ClipField: { - color: 'green.500', - description: t('nodes.clipFieldDescription'), - title: t('nodes.clipField'), - }, - Collection: { - color: 'base.500', - description: t('nodes.collectionDescription'), - title: t('nodes.collection'), - }, - CollectionItem: { - color: 'base.500', - description: t('nodes.collectionItemDescription'), - title: t('nodes.collectionItem'), - }, - ColorCollection: { - color: 'pink.300', - description: t('nodes.colorCollectionDescription'), - title: t('nodes.colorCollection'), - }, - ColorField: { - color: 'pink.300', - description: t('nodes.colorFieldDescription'), - title: t('nodes.colorField'), - }, - ColorPolymorphic: { - color: 'pink.300', - description: t('nodes.colorPolymorphicDescription'), - title: t('nodes.colorPolymorphic'), - }, - ConditioningCollection: { - color: 'cyan.500', - description: t('nodes.conditioningCollectionDescription'), - title: t('nodes.conditioningCollection'), - }, - ConditioningField: { - color: 'cyan.500', - description: t('nodes.conditioningFieldDescription'), - title: t('nodes.conditioningField'), - }, - ConditioningPolymorphic: { - color: 'cyan.500', - description: t('nodes.conditioningPolymorphicDescription'), - title: t('nodes.conditioningPolymorphic'), - }, - ControlCollection: { - color: 'teal.500', - description: t('nodes.controlCollectionDescription'), - title: t('nodes.controlCollection'), - }, - ControlField: { - color: 'teal.500', - description: t('nodes.controlFieldDescription'), - title: t('nodes.controlField'), - }, - ControlNetModelField: { - color: 'teal.500', - description: 'TODO', - title: 'ControlNet', - }, - ControlPolymorphic: { - color: 'teal.500', - description: 'Control info passed between nodes.', - title: 'Control Polymorphic', - }, - DenoiseMaskField: { - color: 'blue.300', - description: t('nodes.denoiseMaskFieldDescription'), - title: t('nodes.denoiseMaskField'), - }, - enum: { - color: 'blue.500', - description: t('nodes.enumDescription'), - title: t('nodes.enum'), - }, - float: { - color: 'orange.500', - description: t('nodes.floatDescription'), - title: t('nodes.float'), - }, - FloatCollection: { - color: 'orange.500', - description: t('nodes.floatCollectionDescription'), - title: t('nodes.floatCollection'), - }, - FloatPolymorphic: { - color: 'orange.500', - description: t('nodes.floatPolymorphicDescription'), - title: t('nodes.floatPolymorphic'), - }, - ImageCollection: { - color: 'purple.500', - description: t('nodes.imageCollectionDescription'), - title: t('nodes.imageCollection'), - }, - ImageField: { - color: 'purple.500', - description: t('nodes.imageFieldDescription'), - title: t('nodes.imageField'), - }, - BoardField: { - color: 'purple.500', - description: t('nodes.imageFieldDescription'), - title: t('nodes.imageField'), - }, - ImagePolymorphic: { - color: 'purple.500', - description: t('nodes.imagePolymorphicDescription'), - title: t('nodes.imagePolymorphic'), - }, - integer: { - color: 'red.500', - description: t('nodes.integerDescription'), - title: t('nodes.integer'), - }, - IntegerCollection: { - color: 'red.500', - description: t('nodes.integerCollectionDescription'), - title: t('nodes.integerCollection'), - }, - IntegerPolymorphic: { - color: 'red.500', - description: t('nodes.integerPolymorphicDescription'), - title: t('nodes.integerPolymorphic'), - }, - IPAdapterCollection: { - color: 'teal.500', - description: t('nodes.ipAdapterCollectionDescription'), - title: t('nodes.ipAdapterCollection'), - }, - IPAdapterField: { - color: 'teal.500', - description: t('nodes.ipAdapterDescription'), - title: t('nodes.ipAdapter'), - }, - IPAdapterModelField: { - color: 'teal.500', - description: t('nodes.ipAdapterModelDescription'), - title: t('nodes.ipAdapterModel'), - }, - IPAdapterPolymorphic: { - color: 'teal.500', - description: t('nodes.ipAdapterPolymorphicDescription'), - title: t('nodes.ipAdapterPolymorphic'), - }, - LatentsCollection: { - color: 'pink.500', - description: t('nodes.latentsCollectionDescription'), - title: t('nodes.latentsCollection'), - }, - LatentsField: { - color: 'pink.500', - description: t('nodes.latentsFieldDescription'), - title: t('nodes.latentsField'), - }, - LatentsPolymorphic: { - color: 'pink.500', - description: t('nodes.latentsPolymorphicDescription'), - title: t('nodes.latentsPolymorphic'), - }, - LoRAModelField: { - color: 'teal.500', - description: t('nodes.loRAModelFieldDescription'), - title: t('nodes.loRAModelField'), - }, - MainModelField: { - color: 'teal.500', - description: t('nodes.mainModelFieldDescription'), - title: t('nodes.mainModelField'), - }, - ONNXModelField: { - color: 'teal.500', - description: t('nodes.oNNXModelFieldDescription'), - title: t('nodes.oNNXModelField'), - }, - Scheduler: { - color: 'base.500', - description: t('nodes.schedulerDescription'), - title: t('nodes.scheduler'), - }, - SDXLMainModelField: { - color: 'teal.500', - description: t('nodes.sDXLMainModelFieldDescription'), - title: t('nodes.sDXLMainModelField'), - }, - SDXLRefinerModelField: { - color: 'teal.500', - description: t('nodes.sDXLRefinerModelFieldDescription'), - title: t('nodes.sDXLRefinerModelField'), - }, - string: { - color: 'yellow.500', - description: t('nodes.stringDescription'), - title: t('nodes.string'), - }, - StringCollection: { - color: 'yellow.500', - description: t('nodes.stringCollectionDescription'), - title: t('nodes.stringCollection'), - }, - StringPolymorphic: { - color: 'yellow.500', - description: t('nodes.stringPolymorphicDescription'), - title: t('nodes.stringPolymorphic'), - }, - T2IAdapterCollection: { - color: 'teal.500', - description: t('nodes.t2iAdapterCollectionDescription'), - title: t('nodes.t2iAdapterCollection'), - }, - T2IAdapterField: { - color: 'teal.500', - description: t('nodes.t2iAdapterFieldDescription'), - title: t('nodes.t2iAdapterField'), - }, - T2IAdapterModelField: { - color: 'teal.500', - description: 'TODO', - title: 'T2I-Adapter', - }, - T2IAdapterPolymorphic: { - color: 'teal.500', - description: 'T2I-Adapter info passed between nodes.', - title: 'T2I-Adapter Polymorphic', - }, - UNetField: { - color: 'red.500', - description: t('nodes.uNetFieldDescription'), - title: t('nodes.uNetField'), - }, - VaeField: { - color: 'blue.500', - description: t('nodes.vaeFieldDescription'), - title: t('nodes.vaeField'), - }, - VaeModelField: { - color: 'teal.500', - description: t('nodes.vaeModelFieldDescription'), - title: t('nodes.vaeModelField'), - }, +/** + * Colors for each field type - applies to their handles and edges. + */ +export const FIELD_COLORS: { [key: string]: string } = { + BoardField: 'purple.500', + BooleanField: 'green.500', + ClipField: 'green.500', + ColorField: 'pink.300', + ConditioningField: 'cyan.500', + ControlField: 'teal.500', + ControlNetModelField: 'teal.500', + EnumField: 'blue.500', + FloatField: 'orange.500', + ImageField: 'purple.500', + IntegerField: 'red.500', + IPAdapterField: 'teal.500', + IPAdapterModelField: 'teal.500', + LatentsField: 'pink.500', + LoRAModelField: 'teal.500', + MainModelField: 'teal.500', + ONNXModelField: 'teal.500', + SDXLMainModelField: 'teal.500', + SDXLRefinerModelField: 'teal.500', + StringField: 'yellow.500', + T2IAdapterField: 'teal.500', + T2IAdapterModelField: 'teal.500', + UNetField: 'red.500', + VaeField: 'blue.500', + VaeModelField: 'teal.500', }; diff --git a/invokeai/frontend/web/src/features/nodes/types/error.ts b/invokeai/frontend/web/src/features/nodes/types/error.ts new file mode 100644 index 00000000000..b7ffb753bc8 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/types/error.ts @@ -0,0 +1,59 @@ +/** + * Invalid Workflow Version Error + * Raised when a workflow version is not recognized. + */ +export class WorkflowVersionError extends Error { + /** + * Create WorkflowVersionError + * @param {String} message + */ + constructor(message: string) { + super(message); + this.name = this.constructor.name; + } +} + +/** + * Unable to Update Node Error + * Raised when a node cannot be updated. + */ +export class NodeUpdateError extends Error { + /** + * Create NodeUpdateError + * @param {String} message + */ + constructor(message: string) { + super(message); + this.name = this.constructor.name; + } +} + +/** + * FieldTypeParseError + * Raised when a field cannot be parsed from a field schema. + */ +export class FieldTypeParseError extends Error { + /** + * Create FieldTypeParseError + * @param {String} message + */ + constructor(message: string) { + super(message); + this.name = this.constructor.name; + } +} + +/** + * UnsupportedFieldTypeError + * Raised when an unsupported field type is parsed. + */ +export class UnsupportedFieldTypeError extends Error { + /** + * Create UnsupportedFieldTypeError + * @param {String} message + */ + constructor(message: string) { + super(message); + this.name = this.constructor.name; + } +} diff --git a/invokeai/frontend/web/src/features/nodes/types/field.ts b/invokeai/frontend/web/src/features/nodes/types/field.ts new file mode 100644 index 00000000000..dd1c50f6e30 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/types/field.ts @@ -0,0 +1,1114 @@ +import { z } from 'zod'; +import { + zBoardField, + zColorField, + zControlNetModelField, + zIPAdapterModelField, + zImageField, + zLoRAModelField, + zMainOrONNXModelField, + zSchedulerField, + zT2IAdapterModelField, + zVAEModelField, +} from './common'; + +/** + * zod schemas & inferred types for input field values. + * + * These schemas and types are only required for field types that have UI components and allow the + * user to directly provide values. + * + * This includes primitive values (numbers, strings, booleans), models, scheduler, etc. + * + * If a field type does not have a UI component, then it does not need to be included here, because + * we never store its value. Such field types will be handled via the "StatelessField" logic. + * + * Fields require: + * - zFieldType - zod schema for the field type + * - zFieldValue - zod schema for the field value + * - zFieldInputInstance - zod schema for the field's input instance + * - zFieldOutputInstance - zod schema for the field's output instance + * - zFieldInputTemplate - zod schema for the field's input template + * - zFieldOutputTemplate - zod schema for the field's output template + * + * These then must be added to the unions at the bottom of this file. + */ + +/** */ + +// #region Base schemas & misc +export const zFieldInput = z.enum(['connection', 'direct', 'any']); +export type FieldInput = z.infer; + +export const zFieldUIComponent = z.enum(['none', 'textarea', 'slider']); +export type FieldUIComponent = z.infer; + +export const zFieldInstanceBase = z.object({ + id: z.string().trim().min(1), + name: z.string().trim().min(1), +}); +export const zFieldInputInstanceBase = zFieldInstanceBase.extend({ + fieldKind: z.literal('input'), + label: z.string().nullish(), +}); +export const zFieldOutputInstanceBase = zFieldInstanceBase.extend({ + fieldKind: z.literal('output'), +}); +export type FieldInstanceBase = z.infer; +export type FieldInputInstanceBase = z.infer; +export type FieldOutputInstanceBase = z.infer; + +export const zFieldTemplateBase = z.object({ + name: z.string().min(1), + title: z.string().min(1), + description: z.string().nullish(), + ui_hidden: z.boolean(), + ui_type: z.string().nullish(), + ui_order: z.number().int().nullish(), +}); +export const zFieldInputTemplateBase = zFieldTemplateBase.extend({ + fieldKind: z.literal('input'), + input: zFieldInput, + required: z.boolean(), + ui_component: zFieldUIComponent.nullish(), + ui_choice_labels: z.record(z.string()).nullish(), +}); +export const zFieldOutputTemplateBase = zFieldTemplateBase.extend({ + fieldKind: z.literal('output'), +}); +export type FieldTemplateBase = z.infer; +export type FieldInputTemplateBase = z.infer; +export type FieldOutputTemplateBase = z.infer; + +export const zFieldTypeBase = z.object({ + isCollection: z.boolean(), + isPolymorphic: z.boolean(), +}); + +export const zFieldIdentifier = z.object({ + nodeId: z.string().trim().min(1), + fieldName: z.string().trim().min(1), +}); +export type FieldIdentifier = z.infer; +export const isFieldIdentifier = (val: unknown): val is FieldIdentifier => + zFieldIdentifier.safeParse(val).success; +// #endregion + +// #region IntegerField +export const zIntegerFieldType = zFieldTypeBase.extend({ + name: z.literal('IntegerField'), +}); +export const zIntegerFieldValue = z.number().int(); +export const zIntegerFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zIntegerFieldType, + value: zIntegerFieldValue, +}); +export const zIntegerFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zIntegerFieldType, +}); +export const zIntegerFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zIntegerFieldType, + default: zIntegerFieldValue, + multipleOf: z.number().int().optional(), + maximum: z.number().int().optional(), + exclusiveMaximum: z.number().int().optional(), + minimum: z.number().int().optional(), + exclusiveMinimum: z.number().int().optional(), +}); +export const zIntegerFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zIntegerFieldType, +}); +export type IntegerFieldType = z.infer; +export type IntegerFieldValue = z.infer; +export type IntegerFieldInputInstance = z.infer< + typeof zIntegerFieldInputInstance +>; +export type IntegerFieldInputTemplate = z.infer< + typeof zIntegerFieldInputTemplate +>; +export const isIntegerFieldInputInstance = ( + val: unknown +): val is IntegerFieldInputInstance => + zIntegerFieldInputInstance.safeParse(val).success; +export const isIntegerFieldInputTemplate = ( + val: unknown +): val is IntegerFieldInputTemplate => + zIntegerFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region FloatField +export const zFloatFieldType = zFieldTypeBase.extend({ + name: z.literal('FloatField'), +}); +export const zFloatFieldValue = z.number(); +export const zFloatFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zFloatFieldType, + value: zFloatFieldValue, +}); +export const zFloatFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zFloatFieldType, +}); +export const zFloatFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zFloatFieldType, + default: zFloatFieldValue, + multipleOf: z.number().optional(), + maximum: z.number().optional(), + exclusiveMaximum: z.number().optional(), + minimum: z.number().optional(), + exclusiveMinimum: z.number().optional(), +}); +export const zFloatFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zFloatFieldType, +}); +export type FloatFieldType = z.infer; +export type FloatFieldValue = z.infer; +export type FloatFieldInputInstance = z.infer; +export type FloatFieldOutputInstance = z.infer< + typeof zFloatFieldOutputInstance +>; +export type FloatFieldInputTemplate = z.infer; +export type FloatFieldOutputTemplate = z.infer< + typeof zFloatFieldOutputTemplate +>; +export const isFloatFieldInputInstance = ( + val: unknown +): val is FloatFieldInputInstance => + zFloatFieldInputInstance.safeParse(val).success; +export const isFloatFieldInputTemplate = ( + val: unknown +): val is FloatFieldInputTemplate => + zFloatFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region StringField +export const zStringFieldType = zFieldTypeBase.extend({ + name: z.literal('StringField'), +}); +export const zStringFieldValue = z.string(); +export const zStringFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zStringFieldType, + value: zStringFieldValue, +}); +export const zStringFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zStringFieldType, +}); +export const zStringFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zStringFieldType, + default: zStringFieldValue, + maxLength: z.number().int().optional(), + minLength: z.number().int().optional(), +}); +export const zStringFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zStringFieldType, +}); + +export type StringFieldType = z.infer; +export type StringFieldValue = z.infer; +export type StringFieldInputInstance = z.infer< + typeof zStringFieldInputInstance +>; +export type StringFieldOutputInstance = z.infer< + typeof zStringFieldOutputInstance +>; +export type StringFieldInputTemplate = z.infer< + typeof zStringFieldInputTemplate +>; +export type StringFieldOutputTemplate = z.infer< + typeof zStringFieldOutputTemplate +>; +export const isStringFieldInputInstance = ( + val: unknown +): val is StringFieldInputInstance => + zStringFieldInputInstance.safeParse(val).success; +export const isStringFieldInputTemplate = ( + val: unknown +): val is StringFieldInputTemplate => + zStringFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region BooleanField +export const zBooleanFieldType = zFieldTypeBase.extend({ + name: z.literal('BooleanField'), +}); +export const zBooleanFieldValue = z.boolean(); +export const zBooleanFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zBooleanFieldType, + value: zBooleanFieldValue, +}); +export const zBooleanFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zBooleanFieldType, +}); +export const zBooleanFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zBooleanFieldType, + default: zBooleanFieldValue, +}); +export const zBooleanFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zBooleanFieldType, +}); +export type BooleanFieldType = z.infer; +export type BooleanFieldValue = z.infer; +export type BooleanFieldInputInstance = z.infer< + typeof zBooleanFieldInputInstance +>; +export type BooleanFieldOutputInstance = z.infer< + typeof zBooleanFieldOutputInstance +>; +export type BooleanFieldInputTemplate = z.infer< + typeof zBooleanFieldInputTemplate +>; +export type BooleanFieldOutputTemplate = z.infer< + typeof zBooleanFieldOutputTemplate +>; +export const isBooleanFieldInputInstance = ( + val: unknown +): val is BooleanFieldInputInstance => + zBooleanFieldInputInstance.safeParse(val).success; +export const isBooleanFieldInputTemplate = ( + val: unknown +): val is BooleanFieldInputTemplate => + zBooleanFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region EnumField +export const zEnumFieldType = zFieldTypeBase.extend({ + name: z.literal('EnumField'), +}); +export const zEnumFieldValue = z.string(); +export const zEnumFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zEnumFieldType, + value: zEnumFieldValue, +}); +export const zEnumFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zEnumFieldType, +}); +export const zEnumFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zEnumFieldType, + default: zEnumFieldValue, + options: z.array(z.string()), + labels: z.record(z.string()).optional(), +}); +export const zEnumFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zEnumFieldType, +}); +export type EnumFieldType = z.infer; +export type EnumFieldValue = z.infer; +export type EnumFieldInputInstance = z.infer; +export type EnumFieldOutputInstance = z.infer; +export type EnumFieldInputTemplate = z.infer; +export type EnumFieldOutputTemplate = z.infer; +export const isEnumFieldInputInstance = ( + val: unknown +): val is EnumFieldInputInstance => + zEnumFieldInputInstance.safeParse(val).success; +export const isEnumFieldInputTemplate = ( + val: unknown +): val is EnumFieldInputTemplate => + zEnumFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region ImageField +export const zImageFieldType = zFieldTypeBase.extend({ + name: z.literal('ImageField'), +}); +export const zImageFieldValue = zImageField.optional(); +export const zImageFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zImageFieldType, + value: zImageFieldValue, +}); +export const zImageFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zImageFieldType, +}); +export const zImageFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zImageFieldType, + default: zImageFieldValue, +}); +export const zImageFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zImageFieldType, +}); +export type ImageFieldType = z.infer; +export type ImageFieldValue = z.infer; +export type ImageFieldInputInstance = z.infer; +export type ImageFieldOutputInstance = z.infer< + typeof zImageFieldOutputInstance +>; +export type ImageFieldInputTemplate = z.infer; +export type ImageFieldOutputTemplate = z.infer< + typeof zImageFieldOutputTemplate +>; +export const isImageFieldInputInstance = ( + val: unknown +): val is ImageFieldInputInstance => + zImageFieldInputInstance.safeParse(val).success; +export const isImageFieldInputTemplate = ( + val: unknown +): val is ImageFieldInputTemplate => + zImageFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region BoardField +export const zBoardFieldType = zFieldTypeBase.extend({ + name: z.literal('BoardField'), +}); +export const zBoardFieldValue = zBoardField.optional(); +export const zBoardFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zBoardFieldType, + value: zBoardFieldValue, +}); +export const zBoardFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zBoardFieldType, +}); +export const zBoardFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zBoardFieldType, + default: zBoardFieldValue, +}); +export const zBoardFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zBoardFieldType, +}); +export type BoardFieldType = z.infer; +export type BoardFieldValue = z.infer; +export type BoardFieldInputInstance = z.infer; +export type BoardFieldOutputInstance = z.infer< + typeof zBoardFieldOutputInstance +>; +export type BoardFieldInputTemplate = z.infer; +export type BoardFieldOutputTemplate = z.infer< + typeof zBoardFieldOutputTemplate +>; +export const isBoardFieldInputInstance = ( + val: unknown +): val is BoardFieldInputInstance => + zBoardFieldInputInstance.safeParse(val).success; +export const isBoardFieldInputTemplate = ( + val: unknown +): val is BoardFieldInputTemplate => + zBoardFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region ColorField +export const zColorFieldType = zFieldTypeBase.extend({ + name: z.literal('ColorField'), +}); +export const zColorFieldValue = zColorField.optional(); +export const zColorFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zColorFieldType, + value: zColorFieldValue, +}); +export const zColorFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zColorFieldType, +}); +export const zColorFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zColorFieldType, + default: zColorFieldValue, +}); +export const zColorFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zColorFieldType, +}); +export type ColorFieldType = z.infer; +export type ColorFieldValue = z.infer; +export type ColorFieldInputInstance = z.infer; +export type ColorFieldOutputInstance = z.infer< + typeof zColorFieldOutputInstance +>; +export type ColorFieldInputTemplate = z.infer; +export type ColorFieldOutputTemplate = z.infer< + typeof zColorFieldOutputTemplate +>; +export const isColorFieldInputInstance = ( + val: unknown +): val is ColorFieldInputInstance => + zColorFieldInputInstance.safeParse(val).success; +export const isColorFieldInputTemplate = ( + val: unknown +): val is ColorFieldInputTemplate => + zColorFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region MainModelField +export const zMainModelFieldType = zFieldTypeBase.extend({ + name: z.literal('MainModelField'), +}); +export const zMainModelFieldValue = zMainOrONNXModelField.optional(); +export const zMainModelFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zMainModelFieldType, + value: zMainModelFieldValue, +}); +export const zMainModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zMainModelFieldType, +}); +export const zMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zMainModelFieldType, + default: zMainModelFieldValue, +}); +export const zMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zMainModelFieldType, +}); +export type MainModelFieldType = z.infer; +export type MainModelFieldValue = z.infer; +export type MainModelFieldInputInstance = z.infer< + typeof zMainModelFieldInputInstance +>; +export type MainModelFieldOutputInstance = z.infer< + typeof zMainModelFieldOutputInstance +>; +export type MainModelFieldInputTemplate = z.infer< + typeof zMainModelFieldInputTemplate +>; +export type MainModelFieldOutputTemplate = z.infer< + typeof zMainModelFieldOutputTemplate +>; +export const isMainModelFieldInputInstance = ( + val: unknown +): val is MainModelFieldInputInstance => + zMainModelFieldInputInstance.safeParse(val).success; +export const isMainModelFieldInputTemplate = ( + val: unknown +): val is MainModelFieldInputTemplate => + zMainModelFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region SDXLMainModelField +export const zSDXLMainModelFieldType = zFieldTypeBase.extend({ + name: z.literal('SDXLMainModelField'), +}); +export const zSDXLMainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL models only. +export const zSDXLMainModelFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zSDXLMainModelFieldType, + value: zSDXLMainModelFieldValue, +}); +export const zSDXLMainModelFieldOutputInstance = + zFieldOutputInstanceBase.extend({ + type: zSDXLMainModelFieldType, + }); +export const zSDXLMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zSDXLMainModelFieldType, + default: zSDXLMainModelFieldValue, +}); +export const zSDXLMainModelFieldOutputTemplate = + zFieldOutputTemplateBase.extend({ + type: zSDXLMainModelFieldType, + }); +export type SDXLMainModelFieldType = z.infer; +export type SDXLMainModelFieldValue = z.infer; +export type SDXLMainModelFieldInputInstance = z.infer< + typeof zSDXLMainModelFieldInputInstance +>; +export type SDXLMainModelFieldOutputInstance = z.infer< + typeof zSDXLMainModelFieldOutputInstance +>; +export type SDXLMainModelFieldInputTemplate = z.infer< + typeof zSDXLMainModelFieldInputTemplate +>; +export type SDXLMainModelFieldOutputTemplate = z.infer< + typeof zSDXLMainModelFieldOutputTemplate +>; +export const isSDXLMainModelFieldInputInstance = ( + val: unknown +): val is SDXLMainModelFieldInputInstance => + zSDXLMainModelFieldInputInstance.safeParse(val).success; +export const isSDXLMainModelFieldInputTemplate = ( + val: unknown +): val is SDXLMainModelFieldInputTemplate => + zSDXLMainModelFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region SDXLRefinerModelField +export const zSDXLRefinerModelFieldType = zFieldTypeBase.extend({ + name: z.literal('SDXLRefinerModelField'), +}); +export const zSDXLRefinerModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL Refiner models only. +export const zSDXLRefinerModelFieldInputInstance = + zFieldInputInstanceBase.extend({ + type: zSDXLRefinerModelFieldType, + value: zSDXLRefinerModelFieldValue, + }); +export const zSDXLRefinerModelFieldOutputInstance = + zFieldOutputInstanceBase.extend({ + type: zSDXLRefinerModelFieldType, + }); +export const zSDXLRefinerModelFieldInputTemplate = + zFieldInputTemplateBase.extend({ + type: zSDXLRefinerModelFieldType, + default: zSDXLRefinerModelFieldValue, + }); +export const zSDXLRefinerModelFieldOutputTemplate = + zFieldOutputTemplateBase.extend({ + type: zSDXLRefinerModelFieldType, + }); +export type SDXLRefinerModelFieldType = z.infer< + typeof zSDXLRefinerModelFieldType +>; +export type SDXLRefinerModelFieldValue = z.infer< + typeof zSDXLRefinerModelFieldValue +>; +export type SDXLRefinerModelFieldInputInstance = z.infer< + typeof zSDXLRefinerModelFieldInputInstance +>; +export type SDXLRefinerModelFieldOutputInstance = z.infer< + typeof zSDXLRefinerModelFieldOutputInstance +>; +export type SDXLRefinerModelFieldInputTemplate = z.infer< + typeof zSDXLRefinerModelFieldInputTemplate +>; +export type SDXLRefinerModelFieldOutputTemplate = z.infer< + typeof zSDXLRefinerModelFieldOutputTemplate +>; +export const isSDXLRefinerModelFieldInputInstance = ( + val: unknown +): val is SDXLRefinerModelFieldInputInstance => + zSDXLRefinerModelFieldInputInstance.safeParse(val).success; +export const isSDXLRefinerModelFieldInputTemplate = ( + val: unknown +): val is SDXLRefinerModelFieldInputTemplate => + zSDXLRefinerModelFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region VAEModelField +export const zVAEModelFieldType = zFieldTypeBase.extend({ + name: z.literal('VAEModelField'), +}); +export const zVAEModelFieldValue = zVAEModelField.optional(); +export const zVAEModelFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zVAEModelFieldType, + value: zVAEModelFieldValue, +}); +export const zVAEModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zVAEModelFieldType, +}); +export const zVAEModelFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zVAEModelFieldType, + default: zVAEModelFieldValue, +}); +export const zVAEModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zVAEModelFieldType, +}); +export type VAEModelFieldType = z.infer; +export type VAEModelFieldValue = z.infer; +export type VAEModelFieldInputInstance = z.infer< + typeof zVAEModelFieldInputInstance +>; +export type VAEModelFieldOutputInstance = z.infer< + typeof zVAEModelFieldOutputInstance +>; +export type VAEModelFieldInputTemplate = z.infer< + typeof zVAEModelFieldInputTemplate +>; +export type VAEModelFieldOutputTemplate = z.infer< + typeof zVAEModelFieldOutputTemplate +>; +export const isVAEModelFieldInputInstance = ( + val: unknown +): val is VAEModelFieldInputInstance => + zVAEModelFieldInputInstance.safeParse(val).success; +export const isVAEModelFieldInputTemplate = ( + val: unknown +): val is VAEModelFieldInputTemplate => + zVAEModelFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region LoRAModelField +export const zLoRAModelFieldType = zFieldTypeBase.extend({ + name: z.literal('LoRAModelField'), +}); +export const zLoRAModelFieldValue = zLoRAModelField.optional(); +export const zLoRAModelFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zLoRAModelFieldType, + value: zLoRAModelFieldValue, +}); +export const zLoRAModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zLoRAModelFieldType, +}); +export const zLoRAModelFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zLoRAModelFieldType, + default: zLoRAModelFieldValue, +}); +export const zLoRAModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zLoRAModelFieldType, +}); +export type LoRAModelFieldType = z.infer; +export type LoRAModelFieldValue = z.infer; +export type LoRAModelFieldInputInstance = z.infer< + typeof zLoRAModelFieldInputInstance +>; +export type LoRAModelFieldOutputInstance = z.infer< + typeof zLoRAModelFieldOutputInstance +>; +export type LoRAModelFieldInputTemplate = z.infer< + typeof zLoRAModelFieldInputTemplate +>; +export type LoRAModelFieldOutputTemplate = z.infer< + typeof zLoRAModelFieldOutputTemplate +>; +export const isLoRAModelFieldInputInstance = ( + val: unknown +): val is LoRAModelFieldInputInstance => + zLoRAModelFieldInputInstance.safeParse(val).success; +export const isLoRAModelFieldInputTemplate = ( + val: unknown +): val is LoRAModelFieldInputTemplate => + zLoRAModelFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region ControlNetModelField +export const zControlNetModelFieldType = zFieldTypeBase.extend({ + name: z.literal('ControlNetModelField'), +}); +export const zControlNetModelFieldValue = zControlNetModelField.optional(); +export const zControlNetModelFieldInputInstance = + zFieldInputInstanceBase.extend({ + type: zControlNetModelFieldType, + value: zControlNetModelFieldValue, + }); +export const zControlNetModelFieldOutputInstance = + zFieldOutputInstanceBase.extend({ + type: zControlNetModelFieldType, + }); +export const zControlNetModelFieldInputTemplate = + zFieldInputTemplateBase.extend({ + type: zControlNetModelFieldType, + default: zControlNetModelFieldValue, + }); +export const zControlNetModelFieldOutputTemplate = + zFieldOutputTemplateBase.extend({ + type: zControlNetModelFieldType, + }); +export type ControlNetModelFieldType = z.infer< + typeof zControlNetModelFieldType +>; +export type ControlNetModelFieldValue = z.infer< + typeof zControlNetModelFieldValue +>; +export type ControlNetModelFieldInputInstance = z.infer< + typeof zControlNetModelFieldInputInstance +>; +export type ControlNetModelFieldOutputInstance = z.infer< + typeof zControlNetModelFieldOutputInstance +>; +export type ControlNetModelFieldInputTemplate = z.infer< + typeof zControlNetModelFieldInputTemplate +>; +export type ControlNetModelFieldOutputTemplate = z.infer< + typeof zControlNetModelFieldOutputTemplate +>; +export const isControlNetModelFieldInputInstance = ( + val: unknown +): val is ControlNetModelFieldInputInstance => + zControlNetModelFieldInputInstance.safeParse(val).success; +export const isControlNetModelFieldInputTemplate = ( + val: unknown +): val is ControlNetModelFieldInputTemplate => + zControlNetModelFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region IPAdapterModelField +export const zIPAdapterModelFieldType = zFieldTypeBase.extend({ + name: z.literal('IPAdapterModelField'), +}); +export const zIPAdapterModelFieldValue = zIPAdapterModelField.optional(); +export const zIPAdapterModelFieldInputInstance = zFieldInputInstanceBase.extend( + { + type: zIPAdapterModelFieldType, + value: zIPAdapterModelFieldValue, + } +); +export const zIPAdapterModelFieldOutputInstance = + zFieldOutputInstanceBase.extend({ + type: zIPAdapterModelFieldType, + }); +export const zIPAdapterModelFieldInputTemplate = zFieldInputTemplateBase.extend( + { type: zIPAdapterModelFieldType, default: zIPAdapterModelFieldValue } +); +export const zIPAdapterModelFieldOutputTemplate = + zFieldOutputTemplateBase.extend({ + type: zIPAdapterModelFieldType, + }); +export type IPAdapterModelFieldType = z.infer; +export type IPAdapterModelFieldValue = z.infer< + typeof zIPAdapterModelFieldValue +>; +export type IPAdapterModelFieldInputInstance = z.infer< + typeof zIPAdapterModelFieldInputInstance +>; +export type IPAdapterModelFieldOutputInstance = z.infer< + typeof zIPAdapterModelFieldOutputInstance +>; +export type IPAdapterModelFieldInputTemplate = z.infer< + typeof zIPAdapterModelFieldInputTemplate +>; +export type IPAdapterModelFieldOutputTemplate = z.infer< + typeof zIPAdapterModelFieldOutputTemplate +>; +export const isIPAdapterModelFieldInputInstance = ( + val: unknown +): val is IPAdapterModelFieldInputInstance => + zIPAdapterModelFieldInputInstance.safeParse(val).success; +export const isIPAdapterModelFieldInputTemplate = ( + val: unknown +): val is IPAdapterModelFieldInputTemplate => + zIPAdapterModelFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region T2IAdapterField +export const zT2IAdapterModelFieldType = zFieldTypeBase.extend({ + name: z.literal('T2IAdapterModelField'), +}); +export const zT2IAdapterModelFieldValue = zT2IAdapterModelField.optional(); +export const zT2IAdapterModelFieldInputInstance = + zFieldInputInstanceBase.extend({ + type: zT2IAdapterModelFieldType, + value: zT2IAdapterModelFieldValue, + }); +export const zT2IAdapterModelFieldOutputInstance = + zFieldOutputInstanceBase.extend({ + type: zT2IAdapterModelFieldType, + }); +export const zT2IAdapterModelFieldInputTemplate = + zFieldInputTemplateBase.extend({ + type: zT2IAdapterModelFieldType, + default: zT2IAdapterModelFieldValue, + }); +export const zT2IAdapterModelFieldOutputTemplate = + zFieldOutputTemplateBase.extend({ + type: zT2IAdapterModelFieldType, + }); +export type T2IAdapterModelFieldType = z.infer< + typeof zT2IAdapterModelFieldType +>; +export type T2IAdapterModelFieldValue = z.infer< + typeof zT2IAdapterModelFieldValue +>; +export type T2IAdapterModelFieldInputInstance = z.infer< + typeof zT2IAdapterModelFieldInputInstance +>; +export type T2IAdapterModelFieldOutputInstance = z.infer< + typeof zT2IAdapterModelFieldOutputInstance +>; +export type T2IAdapterModelFieldInputTemplate = z.infer< + typeof zT2IAdapterModelFieldInputTemplate +>; +export type T2IAdapterModelFieldOutputTemplate = z.infer< + typeof zT2IAdapterModelFieldOutputTemplate +>; +export const isT2IAdapterModelFieldInputInstance = ( + val: unknown +): val is T2IAdapterModelFieldInputInstance => + zT2IAdapterModelFieldInputInstance.safeParse(val).success; +export const isT2IAdapterModelFieldInputTemplate = ( + val: unknown +): val is T2IAdapterModelFieldInputTemplate => + zT2IAdapterModelFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region SchedulerField +export const zSchedulerFieldType = zFieldTypeBase.extend({ + name: z.literal('SchedulerField'), +}); +export const zSchedulerFieldValue = zSchedulerField.optional(); +export const zSchedulerFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zSchedulerFieldType, + value: zSchedulerFieldValue, +}); +export const zSchedulerFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zSchedulerFieldType, +}); +export const zSchedulerFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zSchedulerFieldType, + default: zSchedulerFieldValue, +}); +export const zSchedulerFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zSchedulerFieldType, +}); +export type SchedulerFieldType = z.infer; +export type SchedulerFieldValue = z.infer; +export type SchedulerFieldInputInstance = z.infer< + typeof zSchedulerFieldInputInstance +>; +export type SchedulerFieldOutputInstance = z.infer< + typeof zSchedulerFieldOutputInstance +>; +export type SchedulerFieldInputTemplate = z.infer< + typeof zSchedulerFieldInputTemplate +>; +export type SchedulerFieldOutputTemplate = z.infer< + typeof zSchedulerFieldOutputTemplate +>; +export const isSchedulerFieldInputInstance = ( + val: unknown +): val is SchedulerFieldInputInstance => + zSchedulerFieldInputInstance.safeParse(val).success; +export const isSchedulerFieldInputTemplate = ( + val: unknown +): val is SchedulerFieldInputTemplate => + zSchedulerFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region StatelessField +/** + * StatelessField is a catchall for stateless fields with no UI input components. They do not + * do not support "direct" input, instead only accepting connections from other fields. + * + * This field type serves as a "generic" field type. + * + * Examples include: + * - Fields like UNetField or LatentsField where we do not allow direct UI input + * - Reserved fields like IsIntermediate + * - Any other field we don't have full-on schemas for + */ +export const zStatelessFieldType = zFieldTypeBase.extend({ + name: z.string().min(1), // stateless --> we accept the field's name as the type +}); +export const zStatelessFieldValue = z.undefined().catch(undefined); // stateless --> no value, but making this z.never() introduces a lot of extra TS fanagling +export const zStatelessFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zStatelessFieldType, + value: zStatelessFieldValue, +}); +export const zStatelessFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zStatelessFieldType, +}); +export const zStatelessFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zStatelessFieldType, + default: zStatelessFieldValue, + input: z.literal('connection'), // stateless --> only accepts connection inputs +}); +export const zStatelessFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zStatelessFieldType, +}); + +export type StatelessFieldType = z.infer; +export type StatelessFieldValue = z.infer; +export type StatelessFieldInputInstance = z.infer< + typeof zStatelessFieldInputInstance +>; +export type StatelessFieldOutputInstance = z.infer< + typeof zStatelessFieldOutputInstance +>; +export type StatelessFieldInputTemplate = z.infer< + typeof zStatelessFieldInputTemplate +>; +export type StatelessFieldOutputTemplate = z.infer< + typeof zStatelessFieldOutputTemplate +>; +// #endregion + +/** + * Here we define the main field unions: + * - FieldType + * - FieldValue + * - FieldInputInstance + * - FieldOutputInstance + * - FieldInputTemplate + * - FieldOutputTemplate + * + * All stateful fields are unioned together, and then that union is unioned with StatelessField. + * + * This allows us to interact with stateful fields without needing to worry about "generic" handling + * for all other StatelessFields. + */ + +// #region StatefulFieldType & FieldType +export const zStatefulFieldType = z.union([ + zIntegerFieldType, + zFloatFieldType, + zStringFieldType, + zBooleanFieldType, + zEnumFieldType, + zImageFieldType, + zBoardFieldType, + zMainModelFieldType, + zSDXLMainModelFieldType, + zSDXLRefinerModelFieldType, + zVAEModelFieldType, + zLoRAModelFieldType, + zControlNetModelFieldType, + zIPAdapterModelFieldType, + zT2IAdapterModelFieldType, + zColorFieldType, + zSchedulerFieldType, +]); +export type StatefulFieldType = z.infer; +export const isStatefulFieldType = (val: unknown): val is StatefulFieldType => + zStatefulFieldType.safeParse(val).success; + +export const zFieldType = z.union([zStatefulFieldType, zStatelessFieldType]); +export type FieldType = z.infer; +export const isFieldType = (val: unknown): val is FieldType => + zFieldType.safeParse(val).success; +// #endregion + +// #region StatefulFieldValue & FieldValue +export const zStatefulFieldValue = z.union([ + zIntegerFieldValue, + zFloatFieldValue, + zStringFieldValue, + zBooleanFieldValue, + zEnumFieldValue, + zImageFieldValue, + zBoardFieldValue, + zMainModelFieldValue, + zSDXLMainModelFieldValue, + zSDXLRefinerModelFieldValue, + zVAEModelFieldValue, + zLoRAModelFieldValue, + zControlNetModelFieldValue, + zIPAdapterModelFieldValue, + zT2IAdapterModelFieldValue, + zColorFieldValue, + zSchedulerFieldValue, +]); +export type StatefulFieldValue = z.infer; +export const isStatefulFieldValue = (val: unknown): val is StatefulFieldValue => + zStatefulFieldValue.safeParse(val).success; + +export const zFieldValue = z.union([zStatefulFieldValue, zStatelessFieldValue]); +export type FieldValue = z.infer; +export const isFieldValue = (val: unknown): val is FieldValue => + zFieldValue.safeParse(val).success; +// #endregion + +// #region StatefulFieldInputInstance & FieldInputInstance +export const zStatefulFieldInputInstance = z.union([ + zIntegerFieldInputInstance, + zFloatFieldInputInstance, + zStringFieldInputInstance, + zBooleanFieldInputInstance, + zEnumFieldInputInstance, + zImageFieldInputInstance, + zBoardFieldInputInstance, + zMainModelFieldInputInstance, + zSDXLMainModelFieldInputInstance, + zSDXLRefinerModelFieldInputInstance, + zVAEModelFieldInputInstance, + zLoRAModelFieldInputInstance, + zControlNetModelFieldInputInstance, + zIPAdapterModelFieldInputInstance, + zT2IAdapterModelFieldInputInstance, + zColorFieldInputInstance, + zSchedulerFieldInputInstance, +]); +export type StatefulFieldInputInstance = z.infer< + typeof zStatefulFieldInputInstance +>; +export const isStatefulFieldInputInstance = ( + val: unknown +): val is StatefulFieldInputInstance => + zStatefulFieldInputInstance.safeParse(val).success; + +export const zFieldInputInstance = z.union([ + zStatefulFieldInputInstance, + zStatelessFieldInputInstance, +]); +export type FieldInputInstance = z.infer; +export const isFieldInputInstance = (val: unknown): val is FieldInputInstance => + zFieldInputInstance.safeParse(val).success; +// #endregion + +// #region StatefulFieldOutputInstance & FieldOutputInstance +export const zStatefulFieldOutputInstance = z.union([ + zIntegerFieldOutputInstance, + zFloatFieldOutputInstance, + zStringFieldOutputInstance, + zBooleanFieldOutputInstance, + zEnumFieldOutputInstance, + zImageFieldOutputInstance, + zBoardFieldOutputInstance, + zMainModelFieldOutputInstance, + zSDXLMainModelFieldOutputInstance, + zSDXLRefinerModelFieldOutputInstance, + zVAEModelFieldOutputInstance, + zLoRAModelFieldOutputInstance, + zControlNetModelFieldOutputInstance, + zIPAdapterModelFieldOutputInstance, + zT2IAdapterModelFieldOutputInstance, + zColorFieldOutputInstance, + zSchedulerFieldOutputInstance, +]); +export type StatefulFieldOutputInstance = z.infer< + typeof zStatefulFieldOutputInstance +>; +export const isStatefulFieldOutputInstance = ( + val: unknown +): val is StatefulFieldOutputInstance => + zStatefulFieldOutputInstance.safeParse(val).success; + +export const zFieldOutputInstance = z.union([ + zStatefulFieldOutputInstance, + zStatelessFieldOutputInstance, +]); +export type FieldOutputInstance = z.infer; +export const isFieldOutputInstance = ( + val: unknown +): val is FieldOutputInstance => zFieldOutputInstance.safeParse(val).success; +// #endregion + +// #region StatefulFieldInputTemplate & FieldInputTemplate +export const zStatefulFieldInputTemplate = z.union([ + zIntegerFieldInputTemplate, + zFloatFieldInputTemplate, + zStringFieldInputTemplate, + zBooleanFieldInputTemplate, + zEnumFieldInputTemplate, + zImageFieldInputTemplate, + zBoardFieldInputTemplate, + zMainModelFieldInputTemplate, + zSDXLMainModelFieldInputTemplate, + zSDXLRefinerModelFieldInputTemplate, + zVAEModelFieldInputTemplate, + zLoRAModelFieldInputTemplate, + zControlNetModelFieldInputTemplate, + zIPAdapterModelFieldInputTemplate, + zT2IAdapterModelFieldInputTemplate, + zColorFieldInputTemplate, + zSchedulerFieldInputTemplate, + zStatelessFieldInputTemplate, +]); +export type StatefulFieldInputTemplate = z.infer; +export const isStatefulFieldInputTemplate = ( + val: unknown +): val is StatefulFieldInputTemplate => + zStatefulFieldInputTemplate.safeParse(val).success; + +export const zFieldInputTemplate = z.union([ + zStatefulFieldInputTemplate, + zStatelessFieldInputTemplate, +]); +export type FieldInputTemplate = z.infer; +export const isFieldInputTemplate = (val: unknown): val is FieldInputTemplate => + zFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region StatefulFieldOutputTemplate & FieldOutputTemplate +export const zStatefulFieldOutputTemplate = z.union([ + zIntegerFieldOutputTemplate, + zFloatFieldOutputTemplate, + zStringFieldOutputTemplate, + zBooleanFieldOutputTemplate, + zEnumFieldOutputTemplate, + zImageFieldOutputTemplate, + zBoardFieldOutputTemplate, + zMainModelFieldOutputTemplate, + zSDXLMainModelFieldOutputTemplate, + zSDXLRefinerModelFieldOutputTemplate, + zVAEModelFieldOutputTemplate, + zLoRAModelFieldOutputTemplate, + zControlNetModelFieldOutputTemplate, + zIPAdapterModelFieldOutputTemplate, + zT2IAdapterModelFieldOutputTemplate, + zColorFieldOutputTemplate, + zSchedulerFieldOutputTemplate, +]); +export type StatefulFieldOutputTemplate = z.infer< + typeof zStatefulFieldOutputTemplate +>; +export const isStatefulFieldOutputTemplate = ( + val: unknown +): val is StatefulFieldOutputTemplate => + zStatefulFieldOutputTemplate.safeParse(val).success; + +export const zFieldOutputTemplate = z.union([ + zStatefulFieldOutputTemplate, + zStatelessFieldOutputTemplate, +]); +export type FieldOutputTemplate = z.infer; +export const isFieldOutputTemplate = ( + val: unknown +): val is FieldOutputTemplate => zFieldOutputTemplate.safeParse(val).success; +// #endregion diff --git a/invokeai/frontend/web/src/features/nodes/types/invocation.ts b/invokeai/frontend/web/src/features/nodes/types/invocation.ts new file mode 100644 index 00000000000..216db437b92 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/types/invocation.ts @@ -0,0 +1,108 @@ +import { Node } from 'reactflow'; +import { z } from 'zod'; +import { zProgressImage } from './common'; +import { + zFieldInputInstance, + zFieldInputTemplate, + zFieldOutputInstance, + zFieldOutputTemplate, +} from './field'; +import { zSemVer } from './semver'; + +// #region InvocationTemplate +export const zInvocationTemplate = z.object({ + type: z.string(), + title: z.string(), + description: z.string(), + tags: z.array(z.string().min(1)), + inputs: z.record(zFieldInputTemplate), + outputs: z.record(zFieldOutputTemplate), + outputType: z.string().min(1), + withWorkflow: z.boolean(), + version: zSemVer, + useCache: z.boolean(), +}); +export type InvocationTemplate = z.infer; +// #endregion + +// #region NodeData +export const zInvocationNodeData = z.object({ + id: z.string().trim().min(1), + type: z.string().trim().min(1), + label: z.string(), + isOpen: z.boolean(), + notes: z.string(), + embedWorkflow: z.boolean(), + isIntermediate: z.boolean(), + useCache: z.boolean(), + version: zSemVer, + inputs: z.record(zFieldInputInstance), + outputs: z.record(zFieldOutputInstance), +}); + +export const zNotesNodeData = z.object({ + id: z.string().trim().min(1), + type: z.literal('notes'), + label: z.string(), + isOpen: z.boolean(), + notes: z.string(), +}); +export const zCurrentImageNodeData = z.object({ + id: z.string().trim().min(1), + type: z.literal('current_image'), + label: z.string(), + isOpen: z.boolean(), +}); +export const zAnyNodeData = z.union([ + zInvocationNodeData, + zNotesNodeData, + zCurrentImageNodeData, +]); + +export type NotesNodeData = z.infer; +export type InvocationNodeData = z.infer; +export type CurrentImageNodeData = z.infer; +export type AnyNodeData = z.infer; + +export const isInvocationNode = ( + node?: Node +): node is Node => + Boolean(node && node.type === 'invocation'); +export const isNotesNode = ( + node?: Node +): node is Node => Boolean(node && node.type === 'notes'); +export const isProgressImageNode = ( + node?: Node +): node is Node => + Boolean(node && node.type === 'current_image'); +export const isInvocationNodeData = ( + node?: AnyNodeData +): node is InvocationNodeData => + Boolean(node && !['notes', 'current_image'].includes(node.type)); // node.type may be 'notes', 'current_image', or any invocation type +// #endregion + +// #region NodeExecutionState +export const zNodeStatus = z.enum([ + 'PENDING', + 'IN_PROGRESS', + 'COMPLETED', + 'FAILED', +]); +export const zNodeExecutionState = z.object({ + nodeId: z.string().trim().min(1), + status: zNodeStatus, + progress: z.number().nullable(), + progressImage: zProgressImage.nullable(), + error: z.string().nullable(), + outputs: z.array(z.any()), +}); +export type NodeExecutionState = z.infer; +export type NodeStatus = z.infer; +// #endregion + +// #region Edges +export const zInvocationEdgeExtra = z.object({ + type: z.union([z.literal('default'), z.literal('collapsed')]), +}); +export type InvocationEdgeExtra = z.infer; +// #endregion diff --git a/invokeai/frontend/web/src/features/nodes/types/metadata.ts b/invokeai/frontend/web/src/features/nodes/types/metadata.ts new file mode 100644 index 00000000000..a22b8aed0ea --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/types/metadata.ts @@ -0,0 +1,81 @@ +import { z } from 'zod'; +import { + zControlField, + zIPAdapterField, + zLoRAModelField, + zMainModelField, + zONNXModelField, + zSDXLRefinerModelField, + zT2IAdapterField, + zVAEModelField, +} from './common'; + +// #region Metadata-optimized versions of schemas +// TODO: It's possible that `deepPartial` will be deprecated: +// - https://github.com/colinhacks/zod/issues/2106 +// - https://github.com/colinhacks/zod/issues/2854 +export const zLoRAMetadataItem = z.object({ + lora: zLoRAModelField.deepPartial(), + weight: z.number(), +}); +const zControlNetMetadataItem = zControlField.deepPartial(); +const zIPAdapterMetadataItem = zIPAdapterField.deepPartial(); +const zT2IAdapterMetadataItem = zT2IAdapterField.deepPartial(); +const zSDXLRefinerModelMetadataItem = zSDXLRefinerModelField.deepPartial(); +const zModelMetadataitem = z.union([ + zMainModelField.deepPartial(), + zONNXModelField.deepPartial(), +]); +const zVAEModelMetadataItem = zVAEModelField.deepPartial(); +export type LoRAMetadataItem = z.infer; +export type ControlNetMetadataItem = z.infer; +export type IPAdapterMetadataItem = z.infer; +export type T2IAdapterMetadataItem = z.infer; +export type SDXLRefinerModelMetadataItem = z.infer< + typeof zSDXLRefinerModelMetadataItem +>; +export type ModelMetadataitem = z.infer; +export type VAEModelMetadataItem = z.infer; +// #endregion + +// #region CoreMetadata +export const zCoreMetadata = z + .object({ + app_version: z.string().nullish().catch(null), + generation_mode: z.string().nullish().catch(null), + created_by: z.string().nullish().catch(null), + positive_prompt: z.string().nullish().catch(null), + negative_prompt: z.string().nullish().catch(null), + width: z.number().int().nullish().catch(null), + height: z.number().int().nullish().catch(null), + seed: z.number().int().nullish().catch(null), + rand_device: z.string().nullish().catch(null), + cfg_scale: z.number().nullish().catch(null), + steps: z.number().int().nullish().catch(null), + scheduler: z.string().nullish().catch(null), + clip_skip: z.number().int().nullish().catch(null), + model: zModelMetadataitem.nullish().catch(null), + controlnets: z.array(zControlNetMetadataItem).nullish().catch(null), + ipAdapters: z.array(zIPAdapterMetadataItem).nullish().catch(null), + t2iAdapters: z.array(zT2IAdapterMetadataItem).nullish().catch(null), + loras: z.array(zLoRAMetadataItem).nullish().catch(null), + vae: zVAEModelMetadataItem.nullish().catch(null), + strength: z.number().nullish().catch(null), + hrf_enabled: z.boolean().nullish().catch(null), + hrf_strength: z.number().nullish().catch(null), + hrf_method: z.string().nullish().catch(null), + init_image: z.string().nullish().catch(null), + positive_style_prompt: z.string().nullish().catch(null), + negative_style_prompt: z.string().nullish().catch(null), + refiner_model: zSDXLRefinerModelMetadataItem.nullish().catch(null), + refiner_cfg_scale: z.number().nullish().catch(null), + refiner_steps: z.number().int().nullish().catch(null), + refiner_scheduler: z.string().nullish().catch(null), + refiner_positive_aesthetic_score: z.number().nullish().catch(null), + refiner_negative_aesthetic_score: z.number().nullish().catch(null), + refiner_start: z.number().nullish().catch(null), + }) + .passthrough(); +export type CoreMetadata = z.infer; + +// #endregion diff --git a/invokeai/frontend/web/src/features/nodes/types/migration/migrations.ts b/invokeai/frontend/web/src/features/nodes/types/migration/migrations.ts new file mode 100644 index 00000000000..45c38524930 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/types/migration/migrations.ts @@ -0,0 +1,69 @@ +import { forEach, isString } from 'lodash-es'; +import { z } from 'zod'; +import { WorkflowVersionError } from '../error'; +import { zSemVer } from '../semver'; +import { WorkflowV2, zWorkflowV2 } from '../workflow'; +import { FIELD_TYPE_V1_TO_FIELD_TYPE_V2_MAPPING } from './v1/fieldTypeMap'; +import { WorkflowV1, zWorkflowV1 } from './v1/workflowV1'; +import { t } from 'i18next'; + +/** + * Helper schema to extract the version from a workflow. + * + * All properties except for the version are ignored in this schema. + */ +const zWorkflowMetaVersion = z.object({ + meta: z.object({ version: zSemVer }), +}); + +/** + * Migrates a workflow from V1 to V2. + */ +const migrateV1toV2 = (workflowToMigrate: WorkflowV1): WorkflowV2 => { + workflowToMigrate.nodes.forEach((node) => { + if (node.type === 'invocation') { + forEach(node.data.inputs, (input) => { + if (!isString(input.type)) { + return; + } + (input.type as unknown) = + FIELD_TYPE_V1_TO_FIELD_TYPE_V2_MAPPING[input.type]; + }); + forEach(node.data.outputs, (output) => { + if (!isString(output.type)) { + return; + } + (output.type as unknown) = + FIELD_TYPE_V1_TO_FIELD_TYPE_V2_MAPPING[output.type]; + }); + } + }); + (workflowToMigrate.meta.version as WorkflowV2['meta']['version']) = '2.0.0'; + return zWorkflowV2.parse(workflowToMigrate); +}; + +/** + * Parses a workflow and migrates it to the latest version if necessary. + */ +export const parseAndMigrateWorkflow = (data: unknown): WorkflowV2 => { + const workflowVersionResult = zWorkflowMetaVersion.safeParse(data); + + if (!workflowVersionResult.success) { + throw new WorkflowVersionError(t('nodes.unableToGetWorkflowVersion')); + } + + const { version } = workflowVersionResult.data.meta; + + if (version === '1.0.0') { + const v1 = zWorkflowV1.parse(data); + return migrateV1toV2(v1); + } + + if (version === '2.0.0') { + return zWorkflowV2.parse(data); + } + + throw new WorkflowVersionError( + t('nodes.unrecognizedWorkflowVersion', { version }) + ); +}; diff --git a/invokeai/frontend/web/src/features/nodes/types/migration/v1/fieldTypeMap.ts b/invokeai/frontend/web/src/features/nodes/types/migration/v1/fieldTypeMap.ts new file mode 100644 index 00000000000..facf015b026 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/types/migration/v1/fieldTypeMap.ts @@ -0,0 +1,270 @@ +import { FieldType, StatefulFieldType } from '../../field'; +import { FieldTypeV1 } from './workflowV1'; + +/** + * Mapping of V1 field type strings to their *stateful* V2 field type counterparts. + */ +const FIELD_TYPE_V1_TO_STATEFUL_FIELD_TYPE_V2: { + [key in FieldTypeV1]?: StatefulFieldType; +} = { + BoardField: { name: 'BoardField', isCollection: false, isPolymorphic: false }, + boolean: { name: 'BooleanField', isCollection: false, isPolymorphic: false }, + BooleanCollection: { + name: 'BooleanField', + isCollection: true, + isPolymorphic: false, + }, + BooleanPolymorphic: { + name: 'BooleanField', + isCollection: false, + isPolymorphic: true, + }, + ColorField: { name: 'ColorField', isCollection: false, isPolymorphic: false }, + ColorCollection: { + name: 'ColorField', + isCollection: true, + isPolymorphic: false, + }, + ColorPolymorphic: { + name: 'ColorField', + isCollection: false, + isPolymorphic: true, + }, + ControlNetModelField: { + name: 'ControlNetModelField', + isCollection: false, + isPolymorphic: false, + }, + enum: { name: 'EnumField', isCollection: false, isPolymorphic: false }, + float: { name: 'FloatField', isCollection: false, isPolymorphic: false }, + FloatCollection: { + name: 'FloatField', + isCollection: true, + isPolymorphic: false, + }, + FloatPolymorphic: { + name: 'FloatField', + isCollection: false, + isPolymorphic: true, + }, + ImageCollection: { + name: 'ImageField', + isCollection: true, + isPolymorphic: false, + }, + ImageField: { name: 'ImageField', isCollection: false, isPolymorphic: false }, + ImagePolymorphic: { + name: 'ImageField', + isCollection: false, + isPolymorphic: true, + }, + integer: { name: 'IntegerField', isCollection: false, isPolymorphic: false }, + IntegerCollection: { + name: 'IntegerField', + isCollection: true, + isPolymorphic: false, + }, + IntegerPolymorphic: { + name: 'IntegerField', + isCollection: false, + isPolymorphic: true, + }, + IPAdapterModelField: { + name: 'IPAdapterModelField', + isCollection: false, + isPolymorphic: false, + }, + LoRAModelField: { + name: 'LoRAModelField', + isCollection: false, + isPolymorphic: false, + }, + MainModelField: { + name: 'MainModelField', + isCollection: false, + isPolymorphic: false, + }, + Scheduler: { + name: 'SchedulerField', + isCollection: false, + isPolymorphic: false, + }, + SDXLMainModelField: { + name: 'SDXLMainModelField', + isCollection: false, + isPolymorphic: false, + }, + SDXLRefinerModelField: { + name: 'SDXLRefinerModelField', + isCollection: false, + isPolymorphic: false, + }, + string: { name: 'StringField', isCollection: false, isPolymorphic: false }, + StringCollection: { + name: 'StringField', + isCollection: true, + isPolymorphic: false, + }, + StringPolymorphic: { + name: 'StringField', + isCollection: false, + isPolymorphic: true, + }, + T2IAdapterModelField: { + name: 'T2IAdapterModelField', + isCollection: false, + isPolymorphic: false, + }, + VaeModelField: { + name: 'VAEModelField', + isCollection: false, + isPolymorphic: false, + }, +}; + +/** + * Mapping of V1 field type strings to their *stateless* V2 field type counterparts. + * + * The type doesn't do what I want it to do. + * + * Ideally, the value of each propery would be a `FieldType` where `FieldType['name']` is not in + * `StatefulFieldType['name']`, but this is hard to represent. That's because `FieldType['name']` is + * actually widened to `string`, and TS's `Exclude` doesn't work on `string`. + * + * There's probably some way to do it with conditionals and intersections but I can't figure it out. + * + * Thus, this object was manually edited to ensure it is correct. + */ +const FIELD_TYPE_V1_TO_STATELESS_FIELD_TYPE_V2: { + [key in FieldTypeV1]?: FieldType; +} = { + Any: { name: 'AnyField', isCollection: false, isPolymorphic: false }, + ClipField: { name: 'ClipField', isCollection: false, isPolymorphic: false }, + Collection: { + name: 'CollectionField', + isCollection: true, + isPolymorphic: false, + }, + CollectionItem: { + name: 'CollectionItemField', + isCollection: false, + isPolymorphic: false, + }, + ConditioningCollection: { + name: 'ConditioningField', + isCollection: true, + isPolymorphic: false, + }, + ConditioningField: { + name: 'ConditioningField', + isCollection: false, + isPolymorphic: false, + }, + ConditioningPolymorphic: { + name: 'ConditioningField', + isCollection: false, + isPolymorphic: true, + }, + ControlCollection: { + name: 'ControlField', + isCollection: true, + isPolymorphic: false, + }, + ControlField: { + name: 'ControlField', + isCollection: false, + isPolymorphic: false, + }, + ControlPolymorphic: { + name: 'ControlField', + isCollection: false, + isPolymorphic: true, + }, + DenoiseMaskField: { + name: 'DenoiseMaskField', + isCollection: false, + isPolymorphic: false, + }, + IPAdapterField: { + name: 'IPAdapterField', + isCollection: false, + isPolymorphic: false, + }, + IPAdapterCollection: { + name: 'IPAdapterField', + isCollection: true, + isPolymorphic: false, + }, + IPAdapterPolymorphic: { + name: 'IPAdapterField', + isCollection: false, + isPolymorphic: true, + }, + LatentsField: { + name: 'LatentsField', + isCollection: false, + isPolymorphic: false, + }, + LatentsCollection: { + name: 'LatentsField', + isCollection: true, + isPolymorphic: false, + }, + LatentsPolymorphic: { + name: 'LatentsField', + isCollection: false, + isPolymorphic: true, + }, + MetadataField: { + name: 'MetadataField', + isCollection: false, + isPolymorphic: false, + }, + MetadataCollection: { + name: 'MetadataField', + isCollection: true, + isPolymorphic: false, + }, + MetadataItemField: { + name: 'MetadataItemField', + isCollection: false, + isPolymorphic: false, + }, + MetadataItemCollection: { + name: 'MetadataItemField', + isCollection: true, + isPolymorphic: false, + }, + MetadataItemPolymorphic: { + name: 'MetadataItemField', + isCollection: false, + isPolymorphic: true, + }, + ONNXModelField: { + name: 'ONNXModelField', + isCollection: false, + isPolymorphic: false, + }, + T2IAdapterField: { + name: 'T2IAdapterField', + isCollection: false, + isPolymorphic: false, + }, + T2IAdapterCollection: { + name: 'T2IAdapterField', + isCollection: true, + isPolymorphic: false, + }, + T2IAdapterPolymorphic: { + name: 'T2IAdapterField', + isCollection: false, + isPolymorphic: true, + }, + UNetField: { name: 'UNetField', isCollection: false, isPolymorphic: false }, + VaeField: { name: 'VaeField', isCollection: false, isPolymorphic: false }, +}; + +export const FIELD_TYPE_V1_TO_FIELD_TYPE_V2_MAPPING = { + ...FIELD_TYPE_V1_TO_STATEFUL_FIELD_TYPE_V2, + ...FIELD_TYPE_V1_TO_STATELESS_FIELD_TYPE_V2, +}; diff --git a/invokeai/frontend/web/src/features/nodes/types/migration/v1/workflowV1.ts b/invokeai/frontend/web/src/features/nodes/types/migration/v1/workflowV1.ts new file mode 100644 index 00000000000..98e4158f9ad --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/types/migration/v1/workflowV1.ts @@ -0,0 +1,711 @@ +import { z } from 'zod'; + +// WorkflowV1 Schema + +const zScheduler = z.enum([ + 'euler', + 'deis', + 'ddim', + 'ddpm', + 'dpmpp_2s', + 'dpmpp_2m', + 'dpmpp_2m_sde', + 'dpmpp_sde', + 'heun', + 'kdpm_2', + 'lms', + 'pndm', + 'unipc', + 'euler_k', + 'dpmpp_2s_k', + 'dpmpp_2m_k', + 'dpmpp_2m_sde_k', + 'dpmpp_sde_k', + 'heun_k', + 'lms_k', + 'euler_a', + 'kdpm_2_a', + 'lcm', +]); +const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']); +const zMainModel = z.object({ + model_name: z.string().min(1), + base_model: zBaseModel, + model_type: z.literal('main'), +}); +const zOnnxModel = z.object({ + model_name: z.string().min(1), + base_model: zBaseModel, + model_type: z.literal('onnx'), +}); + +const zMainOrOnnxModel = z.union([zMainModel, zOnnxModel]); + +// TODO: Get this from the OpenAPI schema? may be tricky... +const zFieldTypeV1 = z.enum([ + 'Any', + 'BoardField', + 'boolean', + 'BooleanCollection', + 'BooleanPolymorphic', + 'ClipField', + 'Collection', + 'CollectionItem', + 'ColorCollection', + 'ColorField', + 'ColorPolymorphic', + 'ConditioningCollection', + 'ConditioningField', + 'ConditioningPolymorphic', + 'ControlCollection', + 'ControlField', + 'ControlNetModelField', + 'ControlPolymorphic', + 'DenoiseMaskField', + 'enum', + 'float', + 'FloatCollection', + 'FloatPolymorphic', + 'ImageCollection', + 'ImageField', + 'ImagePolymorphic', + 'integer', + 'IntegerCollection', + 'IntegerPolymorphic', + 'IPAdapterCollection', + 'IPAdapterField', + 'IPAdapterModelField', + 'IPAdapterPolymorphic', + 'LatentsCollection', + 'LatentsField', + 'LatentsPolymorphic', + 'LoRAModelField', + 'MainModelField', + 'MetadataField', + 'MetadataCollection', + 'MetadataItemField', + 'MetadataItemCollection', + 'MetadataItemPolymorphic', + 'ONNXModelField', + 'Scheduler', + 'SDXLMainModelField', + 'SDXLRefinerModelField', + 'string', + 'StringCollection', + 'StringPolymorphic', + 'T2IAdapterCollection', + 'T2IAdapterField', + 'T2IAdapterModelField', + 'T2IAdapterPolymorphic', + 'UNetField', + 'VaeField', + 'VaeModelField', +]); +export type FieldTypeV1 = z.infer; + +const zFieldValueBase = z.object({ + id: z.string().trim().min(1), + name: z.string().trim().min(1), + type: zFieldTypeV1, +}); + +/** + * An output field is persisted across as part of the user's local state. + * + * An output field has two properties: + * - `id` a unique identifier + * - `name` the name of the field, which comes from the python dataclass + */ + +const zOutputFieldValue = zFieldValueBase.extend({ + fieldKind: z.literal('output'), +}); + +const zInputFieldValueBase = zFieldValueBase.extend({ + fieldKind: z.literal('input'), + label: z.string(), +}); + +const zModelIdentifier = z.object({ + model_name: z.string().trim().min(1), + base_model: zBaseModel, +}); + +const zImageField = z.object({ + image_name: z.string().trim().min(1), +}); + +const zBoardField = z.object({ + board_id: z.string().trim().min(1), +}); + +const zLatentsField = z.object({ + latents_name: z.string().trim().min(1), + seed: z.number().int().optional(), +}); + +const zConditioningField = z.object({ + conditioning_name: z.string().trim().min(1), +}); + +const zDenoiseMaskField = z.object({ + mask_name: z.string().trim().min(1), + masked_latents_name: z.string().trim().min(1).optional(), +}); + +const zIntegerInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('integer'), + value: z.number().int().optional(), +}); + +const zIntegerCollectionInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('IntegerCollection'), + value: z.array(z.number().int()).optional(), +}); + +const zIntegerPolymorphicInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('IntegerPolymorphic'), + value: z.number().int().optional(), +}); + +const zFloatInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('float'), + value: z.number().optional(), +}); + +const zFloatCollectionInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('FloatCollection'), + value: z.array(z.number()).optional(), +}); + +const zFloatPolymorphicInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('FloatPolymorphic'), + value: z.number().optional(), +}); + +const zStringInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('string'), + value: z.string().optional(), +}); + +const zStringCollectionInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('StringCollection'), + value: z.array(z.string()).optional(), +}); + +const zStringPolymorphicInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('StringPolymorphic'), + value: z.string().optional(), +}); + +const zBooleanInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('boolean'), + value: z.boolean().optional(), +}); + +const zBooleanCollectionInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('BooleanCollection'), + value: z.array(z.boolean()).optional(), +}); + +const zBooleanPolymorphicInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('BooleanPolymorphic'), + value: z.boolean().optional(), +}); + +const zEnumInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('enum'), + value: z.string().optional(), +}); + +const zLatentsInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('LatentsField'), + value: zLatentsField.optional(), +}); + +const zLatentsCollectionInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('LatentsCollection'), + value: z.array(zLatentsField).optional(), +}); + +const zLatentsPolymorphicInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('LatentsPolymorphic'), + value: z.union([zLatentsField, z.array(zLatentsField)]).optional(), +}); + +const zDenoiseMaskInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('DenoiseMaskField'), + value: zDenoiseMaskField.optional(), +}); + +const zConditioningInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('ConditioningField'), + value: zConditioningField.optional(), +}); + +const zConditioningCollectionInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('ConditioningCollection'), + value: z.array(zConditioningField).optional(), +}); + +const zConditioningPolymorphicInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('ConditioningPolymorphic'), + value: z.union([zConditioningField, z.array(zConditioningField)]).optional(), +}); + +const zControlNetModel = zModelIdentifier; + +const zControlField = z.object({ + image: zImageField, + control_model: zControlNetModel, + control_weight: z.union([z.number(), z.array(z.number())]).optional(), + begin_step_percent: z.number().optional(), + end_step_percent: z.number().optional(), + control_mode: z + .enum(['balanced', 'more_prompt', 'more_control', 'unbalanced']) + .optional(), + resize_mode: z + .enum(['just_resize', 'crop_resize', 'fill_resize', 'just_resize_simple']) + .optional(), +}); + +const zControlInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('ControlField'), + value: zControlField.optional(), +}); + +const zControlPolymorphicInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('ControlPolymorphic'), + value: z.union([zControlField, z.array(zControlField)]).optional(), +}); + +const zControlCollectionInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('ControlCollection'), + value: z.array(zControlField).optional(), +}); + +const zIPAdapterModel = zModelIdentifier; + +const zIPAdapterField = z.object({ + image: zImageField, + ip_adapter_model: zIPAdapterModel, + weight: z.number(), + begin_step_percent: z.number().optional(), + end_step_percent: z.number().optional(), +}); + +const zIPAdapterInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('IPAdapterField'), + value: zIPAdapterField.optional(), +}); + +const zIPAdapterPolymorphicInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('IPAdapterPolymorphic'), + value: z.union([zIPAdapterField, z.array(zIPAdapterField)]).optional(), +}); + +const zIPAdapterCollectionInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('IPAdapterCollection'), + value: z.array(zIPAdapterField).optional(), +}); + +const zT2IAdapterModel = zModelIdentifier; + +const zT2IAdapterField = z.object({ + image: zImageField, + t2i_adapter_model: zT2IAdapterModel, + weight: z.union([z.number(), z.array(z.number())]).optional(), + begin_step_percent: z.number().optional(), + end_step_percent: z.number().optional(), + resize_mode: z + .enum(['just_resize', 'crop_resize', 'fill_resize', 'just_resize_simple']) + .optional(), +}); + +const zT2IAdapterInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('T2IAdapterField'), + value: zT2IAdapterField.optional(), +}); + +const zT2IAdapterPolymorphicInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('T2IAdapterPolymorphic'), + value: z.union([zT2IAdapterField, z.array(zT2IAdapterField)]).optional(), +}); + +const zT2IAdapterCollectionInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('T2IAdapterCollection'), + value: z.array(zT2IAdapterField).optional(), +}); + +const zModelType = z.enum([ + 'onnx', + 'main', + 'vae', + 'lora', + 'controlnet', + 'embedding', +]); + +const zSubModelType = z.enum([ + 'unet', + 'text_encoder', + 'text_encoder_2', + 'tokenizer', + 'tokenizer_2', + 'vae', + 'vae_decoder', + 'vae_encoder', + 'scheduler', + 'safety_checker', +]); + +const zModelInfo = zModelIdentifier.extend({ + model_type: zModelType, + submodel: zSubModelType.optional(), +}); + +const zLoraInfo = zModelInfo.extend({ + weight: z.number().optional(), +}); + +const zUNetField = z.object({ + unet: zModelInfo, + scheduler: zModelInfo, + loras: z.array(zLoraInfo), +}); + +const zUNetInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('UNetField'), + value: zUNetField.optional(), +}); + +const zClipField = z.object({ + tokenizer: zModelInfo, + text_encoder: zModelInfo, + skipped_layers: z.number(), + loras: z.array(zLoraInfo), +}); + +const zClipInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('ClipField'), + value: zClipField.optional(), +}); + +const zVaeField = z.object({ + vae: zModelInfo, +}); + +const zVaeInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('VaeField'), + value: zVaeField.optional(), +}); + +const zImageInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('ImageField'), + value: zImageField.optional(), +}); + +const zBoardInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('BoardField'), + value: zBoardField.optional(), +}); + +const zImagePolymorphicInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('ImagePolymorphic'), + value: zImageField.optional(), +}); + +const zImageCollectionInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('ImageCollection'), + value: z.array(zImageField).optional(), +}); + +const zMainModelInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('MainModelField'), + value: zMainOrOnnxModel.optional(), +}); + +const zSDXLMainModelInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('SDXLMainModelField'), + value: zMainOrOnnxModel.optional(), +}); + +const zSDXLRefinerModelInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('SDXLRefinerModelField'), + value: zMainOrOnnxModel.optional(), // TODO: should narrow this down to a refiner model +}); + +const zVaeModelField = zModelIdentifier; + +const zVaeModelInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('VaeModelField'), + value: zVaeModelField.optional(), +}); + +const zLoRAModelField = zModelIdentifier; + +const zLoRAModelInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('LoRAModelField'), + value: zLoRAModelField.optional(), +}); + +const zControlNetModelField = zModelIdentifier; + +const zControlNetModelInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('ControlNetModelField'), + value: zControlNetModelField.optional(), +}); + +const zIPAdapterModelField = zModelIdentifier; + +const zIPAdapterModelInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('IPAdapterModelField'), + value: zIPAdapterModelField.optional(), +}); + +const zT2IAdapterModelField = zModelIdentifier; + +const zT2IAdapterModelInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('T2IAdapterModelField'), + value: zT2IAdapterModelField.optional(), +}); + +const zCollectionInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('Collection'), + value: z.array(z.any()).optional(), // TODO: should this field ever have a value? +}); + +const zCollectionItemInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('CollectionItem'), + value: z.any().optional(), // TODO: should this field ever have a value? +}); + +const zMetadataItemField = z.object({ + label: z.string(), + value: z.any(), +}); + +const zMetadataItemInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('MetadataItemField'), + value: zMetadataItemField.optional(), +}); + +const zMetadataItemCollectionInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('MetadataItemCollection'), + value: z.array(zMetadataItemField).optional(), +}); + +const zMetadataItemPolymorphicInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('MetadataItemPolymorphic'), + value: z.union([zMetadataItemField, z.array(zMetadataItemField)]).optional(), +}); + +const zMetadataField = z.record(z.any()); + +const zMetadataInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('MetadataField'), + value: zMetadataField.optional(), +}); + +const zMetadataCollectionInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('MetadataCollection'), + value: z.array(zMetadataField).optional(), +}); + +const zColorField = z.object({ + r: z.number().int().min(0).max(255), + g: z.number().int().min(0).max(255), + b: z.number().int().min(0).max(255), + a: z.number().int().min(0).max(255), +}); + +const zColorInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('ColorField'), + value: zColorField.optional(), +}); + +const zColorCollectionInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('ColorCollection'), + value: z.array(zColorField).optional(), +}); + +const zColorPolymorphicInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('ColorPolymorphic'), + value: z.union([zColorField, z.array(zColorField)]).optional(), +}); + +const zSchedulerInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('Scheduler'), + value: zScheduler.optional(), +}); + +const zAnyInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('Any'), + value: z.any().optional(), +}); + +const zInputFieldValue = z.discriminatedUnion('type', [ + zAnyInputFieldValue, + zBoardInputFieldValue, + zBooleanCollectionInputFieldValue, + zBooleanInputFieldValue, + zBooleanPolymorphicInputFieldValue, + zClipInputFieldValue, + zCollectionInputFieldValue, + zCollectionItemInputFieldValue, + zColorInputFieldValue, + zColorCollectionInputFieldValue, + zColorPolymorphicInputFieldValue, + zConditioningInputFieldValue, + zConditioningCollectionInputFieldValue, + zConditioningPolymorphicInputFieldValue, + zControlInputFieldValue, + zControlNetModelInputFieldValue, + zControlCollectionInputFieldValue, + zControlPolymorphicInputFieldValue, + zDenoiseMaskInputFieldValue, + zEnumInputFieldValue, + zFloatCollectionInputFieldValue, + zFloatInputFieldValue, + zFloatPolymorphicInputFieldValue, + zImageCollectionInputFieldValue, + zImagePolymorphicInputFieldValue, + zImageInputFieldValue, + zIntegerCollectionInputFieldValue, + zIntegerPolymorphicInputFieldValue, + zIntegerInputFieldValue, + zIPAdapterInputFieldValue, + zIPAdapterModelInputFieldValue, + zIPAdapterCollectionInputFieldValue, + zIPAdapterPolymorphicInputFieldValue, + zLatentsInputFieldValue, + zLatentsCollectionInputFieldValue, + zLatentsPolymorphicInputFieldValue, + zLoRAModelInputFieldValue, + zMainModelInputFieldValue, + zSchedulerInputFieldValue, + zSDXLMainModelInputFieldValue, + zSDXLRefinerModelInputFieldValue, + zStringCollectionInputFieldValue, + zStringPolymorphicInputFieldValue, + zStringInputFieldValue, + zT2IAdapterInputFieldValue, + zT2IAdapterModelInputFieldValue, + zT2IAdapterCollectionInputFieldValue, + zT2IAdapterPolymorphicInputFieldValue, + zUNetInputFieldValue, + zVaeInputFieldValue, + zVaeModelInputFieldValue, + zMetadataItemInputFieldValue, + zMetadataItemCollectionInputFieldValue, + zMetadataItemPolymorphicInputFieldValue, + zMetadataInputFieldValue, + zMetadataCollectionInputFieldValue, +]); + +const zSemVer = z.string().refine((val) => { + const [major, minor, patch] = val.split('.'); + return ( + major !== undefined && + Number.isInteger(Number(major)) && + minor !== undefined && + Number.isInteger(Number(minor)) && + patch !== undefined && + Number.isInteger(Number(patch)) + ); +}); + +const zInvocationNodeData = z.object({ + id: z.string().trim().min(1), + // no easy way to build this dynamically, and we don't want to anyways, because this will be used + // to validate incoming workflows, and we want to allow community nodes. + type: z.string().trim().min(1), + inputs: z.record(zInputFieldValue), + outputs: z.record(zOutputFieldValue), + label: z.string(), + isOpen: z.boolean(), + notes: z.string(), + embedWorkflow: z.boolean(), + isIntermediate: z.boolean(), + useCache: z.boolean().default(true), + version: zSemVer.optional(), +}); + +const zNotesNodeData = z.object({ + id: z.string().trim().min(1), + type: z.literal('notes'), + label: z.string(), + isOpen: z.boolean(), + notes: z.string(), +}); + +const zPosition = z + .object({ + x: z.number(), + y: z.number(), + }) + .default({ x: 0, y: 0 }); + +const zDimension = z.number().gt(0).nullish(); + +const zWorkflowInvocationNode = z.object({ + id: z.string().trim().min(1), + type: z.literal('invocation'), + data: zInvocationNodeData, + width: zDimension, + height: zDimension, + position: zPosition, +}); + +const zWorkflowNotesNode = z.object({ + id: z.string().trim().min(1), + type: z.literal('notes'), + data: zNotesNodeData, + width: zDimension, + height: zDimension, + position: zPosition, +}); + +const zWorkflowNode = z.discriminatedUnion('type', [ + zWorkflowInvocationNode, + zWorkflowNotesNode, +]); + +const zDefaultWorkflowEdge = z.object({ + source: z.string().trim().min(1), + sourceHandle: z.string().trim().min(1), + target: z.string().trim().min(1), + targetHandle: z.string().trim().min(1), + id: z.string().trim().min(1), + type: z.literal('default'), +}); +const zCollapsedWorkflowEdge = z.object({ + source: z.string().trim().min(1), + target: z.string().trim().min(1), + id: z.string().trim().min(1), + type: z.literal('collapsed'), +}); + +const zWorkflowEdge = z.union([zDefaultWorkflowEdge, zCollapsedWorkflowEdge]); + +const zFieldIdentifier = z.object({ + nodeId: z.string().trim().min(1), + fieldName: z.string().trim().min(1), +}); + +export const zWorkflowV1 = z.object({ + name: z.string().default(''), + author: z.string().default(''), + description: z.string().default(''), + version: z.string().default(''), + contact: z.string().default(''), + tags: z.string().default(''), + notes: z.string().default(''), + nodes: z.array(zWorkflowNode).default([]), + edges: z.array(zWorkflowEdge).default([]), + exposedFields: z.array(zFieldIdentifier).default([]), + meta: z.object({ + version: z.literal('1.0.0'), + }), +}); +export type WorkflowV1 = z.infer; diff --git a/invokeai/frontend/web/src/features/nodes/types/openapi.ts b/invokeai/frontend/web/src/features/nodes/types/openapi.ts new file mode 100644 index 00000000000..0d8ffeb9204 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/types/openapi.ts @@ -0,0 +1,108 @@ +import { OpenAPIV3_1 } from 'openapi-types'; +import { + InputFieldJSONSchemaExtra, + OutputFieldJSONSchemaExtra, +} from 'services/api/types'; + +// Janky customization of OpenAPI Schema :/ + +export type InvocationSchemaExtra = { + output: OpenAPIV3_1.ReferenceObject; // the output of the invocation + title: string; + category?: string; + tags?: string[]; + version: string; + properties: Omit< + NonNullable & + (InputFieldJSONSchemaExtra | OutputFieldJSONSchemaExtra), + 'type' + > & { + type: Omit & { + default: string; + }; + use_cache: Omit & { + default: boolean; + }; + }; +}; + +export type InvocationSchemaType = { + default: string; // the type of the invocation +}; + +export type InvocationBaseSchemaObject = Omit< + OpenAPIV3_1.BaseSchemaObject, + 'title' | 'type' | 'properties' +> & + InvocationSchemaExtra; + +export type InvocationOutputSchemaObject = Omit< + OpenAPIV3_1.SchemaObject, + 'properties' +> & { + properties: OpenAPIV3_1.SchemaObject['properties'] & { + type: Omit & { + default: string; + }; + } & { + class: 'output'; + }; +}; + +export type InvocationFieldSchema = OpenAPIV3_1.SchemaObject & + InputFieldJSONSchemaExtra; + +export type OpenAPIV3_1SchemaOrRef = + | OpenAPIV3_1.ReferenceObject + | OpenAPIV3_1.SchemaObject; + +export interface ArraySchemaObject extends InvocationBaseSchemaObject { + type: OpenAPIV3_1.ArraySchemaObjectType; + items: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject; +} +export interface NonArraySchemaObject extends InvocationBaseSchemaObject { + type?: OpenAPIV3_1.NonArraySchemaObjectType; +} + +export type InvocationSchemaObject = ( + | ArraySchemaObject + | NonArraySchemaObject +) & { class: 'invocation' }; + +export const isSchemaObject = ( + obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject | undefined +): obj is OpenAPIV3_1.SchemaObject => Boolean(obj && !('$ref' in obj)); + +export const isArraySchemaObject = ( + obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject | undefined +): obj is OpenAPIV3_1.ArraySchemaObject => + Boolean(obj && !('$ref' in obj) && obj.type === 'array'); + +export const isNonArraySchemaObject = ( + obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject | undefined +): obj is OpenAPIV3_1.NonArraySchemaObject => + Boolean(obj && !('$ref' in obj) && obj.type !== 'array'); + +export const isRefObject = ( + obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject | undefined +): obj is OpenAPIV3_1.ReferenceObject => Boolean(obj && '$ref' in obj); + +export const isInvocationSchemaObject = ( + obj: + | OpenAPIV3_1.ReferenceObject + | OpenAPIV3_1.SchemaObject + | InvocationSchemaObject +): obj is InvocationSchemaObject => + 'class' in obj && obj.class === 'invocation'; + +export const isInvocationOutputSchemaObject = ( + obj: + | OpenAPIV3_1.ReferenceObject + | OpenAPIV3_1.SchemaObject + | InvocationOutputSchemaObject +): obj is InvocationOutputSchemaObject => + 'class' in obj && obj.class === 'output'; + +export const isInvocationFieldSchema = ( + obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject +): obj is InvocationFieldSchema => !('$ref' in obj); diff --git a/invokeai/frontend/web/src/features/nodes/types/semver.ts b/invokeai/frontend/web/src/features/nodes/types/semver.ts new file mode 100644 index 00000000000..70dc2288193 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/types/semver.ts @@ -0,0 +1,23 @@ +import { z } from 'zod'; + +// Schemas and types for working with semver + +const zVersionInt = z.coerce.number().int().min(0); + +export const zSemVer = z.string().refine((val) => { + const [major, minor, patch] = val.split('.'); + return ( + zVersionInt.safeParse(major).success && + zVersionInt.safeParse(minor).success && + zVersionInt.safeParse(patch).success + ); +}); + +export const zParsedSemver = zSemVer.transform((val) => { + const [major, minor, patch] = val.split('.'); + return { + major: Number(major), + minor: Number(minor), + patch: Number(patch), + }; +}); diff --git a/invokeai/frontend/web/src/features/nodes/types/types.ts b/invokeai/frontend/web/src/features/nodes/types/types.ts deleted file mode 100644 index c55d114dcf6..00000000000 --- a/invokeai/frontend/web/src/features/nodes/types/types.ts +++ /dev/null @@ -1,1742 +0,0 @@ -import { $store } from 'app/store/nanostores/store'; -import { - SchedulerParam, - zBaseModel, - zMainModel, - zMainOrOnnxModel, - zOnnxModel, - zSDXLRefinerModel, - zScheduler, -} from 'features/parameters/types/parameterSchemas'; -import i18n from 'i18next'; -import { has, keyBy } from 'lodash-es'; -import { OpenAPIV3_1 } from 'openapi-types'; -import { RgbaColor } from 'react-colorful'; -import { Node } from 'reactflow'; -import { Graph, _InputField, _OutputField } from 'services/api/types'; -import { - AnyInvocationType, - AnyResult, - ProgressImage, -} from 'services/events/types'; -import { O } from 'ts-toolbelt'; -import { JsonObject } from 'type-fest'; -import { z } from 'zod'; - -export type NonNullableGraph = O.Required; - -export type InvocationTemplate = { - /** - * Unique type of the invocation - */ - type: AnyInvocationType; - /** - * Display name of the invocation - */ - title: string; - /** - * Description of the invocation - */ - description: string; - /** - * Invocation tags - */ - tags: string[]; - /** - * Array of invocation inputs - */ - inputs: Record; - /** - * Array of the invocation outputs - */ - outputs: Record; - /** - * The type of this node's output - */ - outputType: string; // TODO: generate a union of output types - /** - * Whether or not this invocation supports workflows - */ - withWorkflow: boolean; - /** - * The invocation's version. - */ - version?: string; - /** - * Whether or not this node should use the cache - */ - useCache: boolean; -}; - -export type FieldUIConfig = { - title: string; - description: string; - color: string; -}; - -// TODO: Get this from the OpenAPI schema? may be tricky... -export const zFieldType = z.enum([ - 'Any', - 'BoardField', - 'boolean', - 'BooleanCollection', - 'BooleanPolymorphic', - 'ClipField', - 'Collection', - 'CollectionItem', - 'ColorCollection', - 'ColorField', - 'ColorPolymorphic', - 'ConditioningCollection', - 'ConditioningField', - 'ConditioningPolymorphic', - 'ControlCollection', - 'ControlField', - 'ControlNetModelField', - 'ControlPolymorphic', - 'DenoiseMaskField', - 'enum', - 'float', - 'FloatCollection', - 'FloatPolymorphic', - 'ImageCollection', - 'ImageField', - 'ImagePolymorphic', - 'integer', - 'IntegerCollection', - 'IntegerPolymorphic', - 'IPAdapterCollection', - 'IPAdapterField', - 'IPAdapterModelField', - 'IPAdapterPolymorphic', - 'LatentsCollection', - 'LatentsField', - 'LatentsPolymorphic', - 'LoRAModelField', - 'MainModelField', - 'MetadataField', - 'MetadataCollection', - 'MetadataItemField', - 'MetadataItemCollection', - 'MetadataItemPolymorphic', - 'ONNXModelField', - 'Scheduler', - 'SDXLMainModelField', - 'SDXLRefinerModelField', - 'string', - 'StringCollection', - 'StringPolymorphic', - 'T2IAdapterCollection', - 'T2IAdapterField', - 'T2IAdapterModelField', - 'T2IAdapterPolymorphic', - 'UNetField', - 'VaeField', - 'VaeModelField', -]); - -export type FieldType = z.infer; -export type FieldTypeMap = { [key in FieldType]?: FieldType }; -export type FieldTypeMapWithNumber = { - [key in FieldType | 'number']?: FieldType; -}; - -export const zReservedFieldType = z.enum([ - 'WorkflowField', - 'IsIntermediate', - 'MetadataField', -]); - -export type ReservedFieldType = z.infer; - -export const isFieldType = (value: unknown): value is FieldType => - zFieldType.safeParse(value).success || - zReservedFieldType.safeParse(value).success; - -/** - * Indicates the kind of input(s) this field may have. - */ -export const zInputKind = z.enum(['connection', 'direct', 'any']); -export type InputKind = z.infer; - -export const zFieldValueBase = z.object({ - id: z.string().trim().min(1), - name: z.string().trim().min(1), - type: zFieldType, -}); -export type FieldValueBase = z.infer; - -/** - * An output field is persisted across as part of the user's local state. - * - * An output field has two properties: - * - `id` a unique identifier - * - `name` the name of the field, which comes from the python dataclass - */ - -export const zOutputFieldValue = zFieldValueBase.extend({ - fieldKind: z.literal('output'), -}); -export type OutputFieldValue = z.infer; - -/** - * An output field template is generated on each page load from the OpenAPI schema. - * - * The template provides the output field's name, type, title, and description. - */ -export type OutputFieldTemplate = { - fieldKind: 'output'; - name: string; - type: FieldType; - title: string; - description: string; -} & _OutputField; - -export const zInputFieldValueBase = zFieldValueBase.extend({ - fieldKind: z.literal('input'), - label: z.string(), -}); -export type InputFieldValueBase = z.infer; - -export const zModelIdentifier = z.object({ - model_name: z.string().trim().min(1), - base_model: zBaseModel, -}); - -export const zImageField = z.object({ - image_name: z.string().trim().min(1), -}); -export type ImageField = z.infer; - -export const zBoardField = z.object({ - board_id: z.string().trim().min(1), -}); -export type BoardField = z.infer; - -export const zLatentsField = z.object({ - latents_name: z.string().trim().min(1), - seed: z.number().int().optional(), -}); -export type LatentsField = z.infer; - -export const zConditioningField = z.object({ - conditioning_name: z.string().trim().min(1), -}); -export type ConditioningField = z.infer; - -export const zDenoiseMaskField = z.object({ - mask_name: z.string().trim().min(1), - masked_latents_name: z.string().trim().min(1).optional(), -}); -export type DenoiseMaskFieldValue = z.infer; - -export const zIntegerInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('integer'), - value: z.number().int().optional(), -}); -export type IntegerInputFieldValue = z.infer; - -export const zIntegerCollectionInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('IntegerCollection'), - value: z.array(z.number().int()).optional(), -}); -export type IntegerCollectionInputFieldValue = z.infer< - typeof zIntegerCollectionInputFieldValue ->; - -export const zIntegerPolymorphicInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('IntegerPolymorphic'), - value: z.number().int().optional(), -}); -export type IntegerPolymorphicInputFieldValue = z.infer< - typeof zIntegerPolymorphicInputFieldValue ->; - -export const zFloatInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('float'), - value: z.number().optional(), -}); -export type FloatInputFieldValue = z.infer; - -export const zFloatCollectionInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('FloatCollection'), - value: z.array(z.number()).optional(), -}); -export type FloatCollectionInputFieldValue = z.infer< - typeof zFloatCollectionInputFieldValue ->; - -export const zFloatPolymorphicInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('FloatPolymorphic'), - value: z.number().optional(), -}); -export type FloatPolymorphicInputFieldValue = z.infer< - typeof zFloatPolymorphicInputFieldValue ->; - -export const zStringInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('string'), - value: z.string().optional(), -}); -export type StringInputFieldValue = z.infer; - -export const zStringCollectionInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('StringCollection'), - value: z.array(z.string()).optional(), -}); -export type StringCollectionInputFieldValue = z.infer< - typeof zStringCollectionInputFieldValue ->; - -export const zStringPolymorphicInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('StringPolymorphic'), - value: z.string().optional(), -}); -export type StringPolymorphicInputFieldValue = z.infer< - typeof zStringPolymorphicInputFieldValue ->; - -export const zBooleanInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('boolean'), - value: z.boolean().optional(), -}); -export type BooleanInputFieldValue = z.infer; - -export const zBooleanCollectionInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('BooleanCollection'), - value: z.array(z.boolean()).optional(), -}); -export type BooleanCollectionInputFieldValue = z.infer< - typeof zBooleanCollectionInputFieldValue ->; - -export const zBooleanPolymorphicInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('BooleanPolymorphic'), - value: z.boolean().optional(), -}); -export type BooleanPolymorphicInputFieldValue = z.infer< - typeof zBooleanPolymorphicInputFieldValue ->; - -export const zEnumInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('enum'), - value: z.string().optional(), -}); -export type EnumInputFieldValue = z.infer; - -export const zLatentsInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('LatentsField'), - value: zLatentsField.optional(), -}); -export type LatentsInputFieldValue = z.infer; - -export const zLatentsCollectionInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('LatentsCollection'), - value: z.array(zLatentsField).optional(), -}); -export type LatentsCollectionInputFieldValue = z.infer< - typeof zLatentsCollectionInputFieldValue ->; - -export const zLatentsPolymorphicInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('LatentsPolymorphic'), - value: z.union([zLatentsField, z.array(zLatentsField)]).optional(), -}); -export type LatentsPolymorphicInputFieldValue = z.infer< - typeof zLatentsPolymorphicInputFieldValue ->; - -export const zDenoiseMaskInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('DenoiseMaskField'), - value: zDenoiseMaskField.optional(), -}); -export type DenoiseMaskInputFieldValue = z.infer< - typeof zDenoiseMaskInputFieldValue ->; - -export const zConditioningInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('ConditioningField'), - value: zConditioningField.optional(), -}); -export type ConditioningInputFieldValue = z.infer< - typeof zConditioningInputFieldValue ->; - -export const zConditioningCollectionInputFieldValue = - zInputFieldValueBase.extend({ - type: z.literal('ConditioningCollection'), - value: z.array(zConditioningField).optional(), - }); -export type ConditioningCollectionInputFieldValue = z.infer< - typeof zConditioningCollectionInputFieldValue ->; - -export const zConditioningPolymorphicInputFieldValue = - zInputFieldValueBase.extend({ - type: z.literal('ConditioningPolymorphic'), - value: z - .union([zConditioningField, z.array(zConditioningField)]) - .optional(), - }); -export type ConditioningPolymorphicInputFieldValue = z.infer< - typeof zConditioningPolymorphicInputFieldValue ->; - -export const zControlNetModel = zModelIdentifier; -export type ControlNetModel = z.infer; - -export const zControlField = z.object({ - image: zImageField, - control_model: zControlNetModel, - control_weight: z.union([z.number(), z.array(z.number())]).optional(), - begin_step_percent: z.number().optional(), - end_step_percent: z.number().optional(), - control_mode: z - .enum(['balanced', 'more_prompt', 'more_control', 'unbalanced']) - .optional(), - resize_mode: z - .enum(['just_resize', 'crop_resize', 'fill_resize', 'just_resize_simple']) - .optional(), -}); -export type ControlField = z.infer; - -export const zControlInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('ControlField'), - value: zControlField.optional(), -}); -export type ControlInputFieldValue = z.infer; - -export const zControlPolymorphicInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('ControlPolymorphic'), - value: z.union([zControlField, z.array(zControlField)]).optional(), -}); -export type ControlPolymorphicInputFieldValue = z.infer< - typeof zControlPolymorphicInputFieldValue ->; - -export const zControlCollectionInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('ControlCollection'), - value: z.array(zControlField).optional(), -}); -export type ControlCollectionInputFieldValue = z.infer< - typeof zControlCollectionInputFieldValue ->; - -export const zIPAdapterModel = zModelIdentifier; -export type IPAdapterModel = z.infer; - -export const zIPAdapterField = z.object({ - image: zImageField, - ip_adapter_model: zIPAdapterModel, - weight: z.number(), - begin_step_percent: z.number().optional(), - end_step_percent: z.number().optional(), -}); -export type IPAdapterField = z.infer; - -export const zIPAdapterInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('IPAdapterField'), - value: zIPAdapterField.optional(), -}); -export type IPAdapterInputFieldValue = z.infer< - typeof zIPAdapterInputFieldValue ->; - -export const zIPAdapterPolymorphicInputFieldValue = zInputFieldValueBase.extend( - { - type: z.literal('IPAdapterPolymorphic'), - value: z.union([zIPAdapterField, z.array(zIPAdapterField)]).optional(), - } -); -export type IPAdapterPolymorphicInputFieldValue = z.infer< - typeof zT2IAdapterPolymorphicInputFieldValue ->; - -export const zIPAdapterCollectionInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('IPAdapterCollection'), - value: z.array(zIPAdapterField).optional(), -}); -export type IPAdapterCollectionInputFieldValue = z.infer< - typeof zIPAdapterCollectionInputFieldValue ->; - -export const zT2IAdapterModel = zModelIdentifier; -export type T2IAdapterModel = z.infer; - -export const zT2IAdapterField = z.object({ - image: zImageField, - t2i_adapter_model: zT2IAdapterModel, - weight: z.union([z.number(), z.array(z.number())]).optional(), - begin_step_percent: z.number().optional(), - end_step_percent: z.number().optional(), - resize_mode: z - .enum(['just_resize', 'crop_resize', 'fill_resize', 'just_resize_simple']) - .optional(), -}); -export type T2IAdapterField = z.infer; - -export const zT2IAdapterInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('T2IAdapterField'), - value: zT2IAdapterField.optional(), -}); -export type T2IAdapterInputFieldValue = z.infer< - typeof zT2IAdapterInputFieldValue ->; - -export const zT2IAdapterPolymorphicInputFieldValue = - zInputFieldValueBase.extend({ - type: z.literal('T2IAdapterPolymorphic'), - value: z.union([zT2IAdapterField, z.array(zT2IAdapterField)]).optional(), - }); -export type T2IAdapterPolymorphicInputFieldValue = z.infer< - typeof zT2IAdapterPolymorphicInputFieldValue ->; - -export const zT2IAdapterCollectionInputFieldValue = zInputFieldValueBase.extend( - { - type: z.literal('T2IAdapterCollection'), - value: z.array(zT2IAdapterField).optional(), - } -); -export type T2IAdapterCollectionInputFieldValue = z.infer< - typeof zT2IAdapterCollectionInputFieldValue ->; - -export const zModelType = z.enum([ - 'onnx', - 'main', - 'vae', - 'lora', - 'controlnet', - 'embedding', -]); -export type ModelType = z.infer; - -export const zSubModelType = z.enum([ - 'unet', - 'text_encoder', - 'text_encoder_2', - 'tokenizer', - 'tokenizer_2', - 'vae', - 'vae_decoder', - 'vae_encoder', - 'scheduler', - 'safety_checker', -]); -export type SubModelType = z.infer; - -export const zModelInfo = zModelIdentifier.extend({ - model_type: zModelType, - submodel: zSubModelType.optional(), -}); -export type ModelInfo = z.infer; - -export const zLoraInfo = zModelInfo.extend({ - weight: z.number().optional(), -}); -export type LoraInfo = z.infer; - -export const zUNetField = z.object({ - unet: zModelInfo, - scheduler: zModelInfo, - loras: z.array(zLoraInfo), -}); -export type UNetField = z.infer; - -export const zUNetInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('UNetField'), - value: zUNetField.optional(), -}); -export type UNetInputFieldValue = z.infer; - -export const zClipField = z.object({ - tokenizer: zModelInfo, - text_encoder: zModelInfo, - skipped_layers: z.number(), - loras: z.array(zLoraInfo), -}); -export type ClipField = z.infer; - -export const zClipInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('ClipField'), - value: zClipField.optional(), -}); -export type ClipInputFieldValue = z.infer; - -export const zVaeField = z.object({ - vae: zModelInfo, -}); -export type VaeField = z.infer; - -export const zVaeInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('VaeField'), - value: zVaeField.optional(), -}); -export type VaeInputFieldValue = z.infer; - -export const zImageInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('ImageField'), - value: zImageField.optional(), -}); -export type ImageInputFieldValue = z.infer; - -export const zBoardInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('BoardField'), - value: zBoardField.optional(), -}); -export type BoardInputFieldValue = z.infer; - -export const zImagePolymorphicInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('ImagePolymorphic'), - value: zImageField.optional(), -}); -export type ImagePolymorphicInputFieldValue = z.infer< - typeof zImagePolymorphicInputFieldValue ->; - -export const zImageCollectionInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('ImageCollection'), - value: z.array(zImageField).optional(), -}); -export type ImageCollectionInputFieldValue = z.infer< - typeof zImageCollectionInputFieldValue ->; - -export const zMainModelInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('MainModelField'), - value: zMainOrOnnxModel.optional(), -}); -export type MainModelInputFieldValue = z.infer< - typeof zMainModelInputFieldValue ->; - -export const zSDXLMainModelInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('SDXLMainModelField'), - value: zMainOrOnnxModel.optional(), -}); -export type SDXLMainModelInputFieldValue = z.infer< - typeof zSDXLMainModelInputFieldValue ->; - -export const zSDXLRefinerModelInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('SDXLRefinerModelField'), - value: zMainOrOnnxModel.optional(), // TODO: should narrow this down to a refiner model -}); -export type SDXLRefinerModelInputFieldValue = z.infer< - typeof zSDXLRefinerModelInputFieldValue ->; - -export const zVaeModelField = zModelIdentifier; - -export const zVaeModelInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('VaeModelField'), - value: zVaeModelField.optional(), -}); -export type VaeModelInputFieldValue = z.infer; - -export const zLoRAModelField = zModelIdentifier; -export type LoRAModelField = z.infer; - -export const zLoRAModelInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('LoRAModelField'), - value: zLoRAModelField.optional(), -}); -export type LoRAModelInputFieldValue = z.infer< - typeof zLoRAModelInputFieldValue ->; - -export const zControlNetModelField = zModelIdentifier; -export type ControlNetModelField = z.infer; - -export const zControlNetModelInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('ControlNetModelField'), - value: zControlNetModelField.optional(), -}); -export type ControlNetModelInputFieldValue = z.infer< - typeof zControlNetModelInputFieldValue ->; - -export const zIPAdapterModelField = zModelIdentifier; -export type IPAdapterModelField = z.infer; - -export const zIPAdapterModelInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('IPAdapterModelField'), - value: zIPAdapterModelField.optional(), -}); -export type IPAdapterModelInputFieldValue = z.infer< - typeof zIPAdapterModelInputFieldValue ->; - -export const zT2IAdapterModelField = zModelIdentifier; -export type T2IAdapterModelField = z.infer; - -export const zT2IAdapterModelInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('T2IAdapterModelField'), - value: zT2IAdapterModelField.optional(), -}); -export type T2IAdapterModelInputFieldValue = z.infer< - typeof zT2IAdapterModelInputFieldValue ->; - -export const zCollectionInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('Collection'), - value: z.array(z.any()).optional(), // TODO: should this field ever have a value? -}); -export type CollectionInputFieldValue = z.infer< - typeof zCollectionInputFieldValue ->; - -export const zCollectionItemInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('CollectionItem'), - value: z.any().optional(), // TODO: should this field ever have a value? -}); -export type CollectionItemInputFieldValue = z.infer< - typeof zCollectionItemInputFieldValue ->; - -export const zMetadataItemField = z.object({ - label: z.string(), - value: z.any(), -}); -export type MetadataItemField = z.infer; - -export const zMetadataItemInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('MetadataItemField'), - value: zMetadataItemField.optional(), -}); -export type MetadataItemInputFieldValue = z.infer< - typeof zMetadataItemInputFieldValue ->; - -export const zMetadataItemCollectionInputFieldValue = - zInputFieldValueBase.extend({ - type: z.literal('MetadataItemCollection'), - value: z.array(zMetadataItemField).optional(), - }); -export type MetadataItemCollectionInputFieldValue = z.infer< - typeof zMetadataItemCollectionInputFieldValue ->; - -export const zMetadataItemPolymorphicInputFieldValue = - zInputFieldValueBase.extend({ - type: z.literal('MetadataItemPolymorphic'), - value: z - .union([zMetadataItemField, z.array(zMetadataItemField)]) - .optional(), - }); -export type MetadataItemPolymorphicInputFieldValue = z.infer< - typeof zMetadataItemPolymorphicInputFieldValue ->; - -export const zMetadataField = z.record(z.any()); -export type MetadataField = z.infer; - -export const zMetadataInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('MetadataField'), - value: zMetadataField.optional(), -}); -export type MetadataInputFieldValue = z.infer; - -export const zMetadataCollectionInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('MetadataCollection'), - value: z.array(zMetadataField).optional(), -}); -export type MetadataCollectionInputFieldValue = z.infer< - typeof zMetadataCollectionInputFieldValue ->; - -export const zColorField = z.object({ - r: z.number().int().min(0).max(255), - g: z.number().int().min(0).max(255), - b: z.number().int().min(0).max(255), - a: z.number().int().min(0).max(255), -}); -export type ColorField = z.infer; - -export const zColorInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('ColorField'), - value: zColorField.optional(), -}); -export type ColorInputFieldValue = z.infer; - -export const zColorCollectionInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('ColorCollection'), - value: z.array(zColorField).optional(), -}); -export type ColorCollectionInputFieldValue = z.infer< - typeof zColorCollectionInputFieldValue ->; - -export const zColorPolymorphicInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('ColorPolymorphic'), - value: z.union([zColorField, z.array(zColorField)]).optional(), -}); -export type ColorPolymorphicInputFieldValue = z.infer< - typeof zColorPolymorphicInputFieldValue ->; - -export const zSchedulerInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('Scheduler'), - value: zScheduler.optional(), -}); -export type SchedulerInputFieldValue = z.infer< - typeof zSchedulerInputFieldValue ->; - -export const zAnyInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('Any'), - value: z.any().optional(), -}); - -export const zInputFieldValue = z.discriminatedUnion('type', [ - zAnyInputFieldValue, - zBoardInputFieldValue, - zBooleanCollectionInputFieldValue, - zBooleanInputFieldValue, - zBooleanPolymorphicInputFieldValue, - zClipInputFieldValue, - zCollectionInputFieldValue, - zCollectionItemInputFieldValue, - zColorInputFieldValue, - zColorCollectionInputFieldValue, - zColorPolymorphicInputFieldValue, - zConditioningInputFieldValue, - zConditioningCollectionInputFieldValue, - zConditioningPolymorphicInputFieldValue, - zControlInputFieldValue, - zControlNetModelInputFieldValue, - zControlCollectionInputFieldValue, - zControlPolymorphicInputFieldValue, - zDenoiseMaskInputFieldValue, - zEnumInputFieldValue, - zFloatCollectionInputFieldValue, - zFloatInputFieldValue, - zFloatPolymorphicInputFieldValue, - zImageCollectionInputFieldValue, - zImagePolymorphicInputFieldValue, - zImageInputFieldValue, - zIntegerCollectionInputFieldValue, - zIntegerPolymorphicInputFieldValue, - zIntegerInputFieldValue, - zIPAdapterInputFieldValue, - zIPAdapterModelInputFieldValue, - zIPAdapterCollectionInputFieldValue, - zIPAdapterPolymorphicInputFieldValue, - zLatentsInputFieldValue, - zLatentsCollectionInputFieldValue, - zLatentsPolymorphicInputFieldValue, - zLoRAModelInputFieldValue, - zMainModelInputFieldValue, - zSchedulerInputFieldValue, - zSDXLMainModelInputFieldValue, - zSDXLRefinerModelInputFieldValue, - zStringCollectionInputFieldValue, - zStringPolymorphicInputFieldValue, - zStringInputFieldValue, - zT2IAdapterInputFieldValue, - zT2IAdapterModelInputFieldValue, - zT2IAdapterCollectionInputFieldValue, - zT2IAdapterPolymorphicInputFieldValue, - zUNetInputFieldValue, - zVaeInputFieldValue, - zVaeModelInputFieldValue, - zMetadataItemInputFieldValue, - zMetadataItemCollectionInputFieldValue, - zMetadataItemPolymorphicInputFieldValue, - zMetadataInputFieldValue, - zMetadataCollectionInputFieldValue, -]); - -export type InputFieldValue = z.infer; - -export type InputFieldTemplateBase = { - name: string; - title: string; - description: string; - required: boolean; - fieldKind: 'input'; -} & _InputField; - -export type AnyInputFieldTemplate = InputFieldTemplateBase & { - type: 'Any'; - default: undefined; -}; - -export type IntegerInputFieldTemplate = InputFieldTemplateBase & { - type: 'integer'; - default: number; - multipleOf?: number; - maximum?: number; - exclusiveMaximum?: number; - minimum?: number; - exclusiveMinimum?: number; -}; - -export type IntegerCollectionInputFieldTemplate = InputFieldTemplateBase & { - type: 'IntegerCollection'; - default: number[]; - item_default?: number; -}; - -export type IntegerPolymorphicInputFieldTemplate = Omit< - IntegerInputFieldTemplate, - 'type' -> & { - type: 'IntegerPolymorphic'; -}; - -export type FloatInputFieldTemplate = InputFieldTemplateBase & { - type: 'float'; - default: number; - multipleOf?: number; - maximum?: number; - exclusiveMaximum?: number; - minimum?: number; - exclusiveMinimum?: number; -}; - -export type FloatCollectionInputFieldTemplate = InputFieldTemplateBase & { - type: 'FloatCollection'; - default: number[]; - item_default?: number; -}; - -export type FloatPolymorphicInputFieldTemplate = Omit< - FloatInputFieldTemplate, - 'type' -> & { - type: 'FloatPolymorphic'; -}; - -export type StringInputFieldTemplate = InputFieldTemplateBase & { - type: 'string'; - default: string; - maxLength?: number; - minLength?: number; - pattern?: string; -}; - -export type StringCollectionInputFieldTemplate = InputFieldTemplateBase & { - type: 'StringCollection'; - default: string[]; - item_default?: string; -}; - -export type StringPolymorphicInputFieldTemplate = Omit< - StringInputFieldTemplate, - 'type' -> & { - type: 'StringPolymorphic'; -}; - -export type BooleanInputFieldTemplate = InputFieldTemplateBase & { - default: boolean; - type: 'boolean'; -}; - -export type BooleanCollectionInputFieldTemplate = InputFieldTemplateBase & { - type: 'BooleanCollection'; - default: boolean[]; - item_default?: boolean; -}; - -export type BooleanPolymorphicInputFieldTemplate = Omit< - BooleanInputFieldTemplate, - 'type' -> & { - type: 'BooleanPolymorphic'; -}; - -export type BoardInputFieldTemplate = InputFieldTemplateBase & { - default: BoardField; - type: 'BoardField'; -}; - -export type ImageInputFieldTemplate = InputFieldTemplateBase & { - default: ImageField; - type: 'ImageField'; -}; - -export type ImageCollectionInputFieldTemplate = InputFieldTemplateBase & { - default: ImageField[]; - type: 'ImageCollection'; - item_default?: ImageField; -}; - -export type ImagePolymorphicInputFieldTemplate = Omit< - ImageInputFieldTemplate, - 'type' -> & { - type: 'ImagePolymorphic'; -}; - -export type DenoiseMaskInputFieldTemplate = InputFieldTemplateBase & { - default: undefined; - type: 'DenoiseMaskField'; -}; - -export type LatentsInputFieldTemplate = InputFieldTemplateBase & { - default: LatentsField; - type: 'LatentsField'; -}; - -export type LatentsCollectionInputFieldTemplate = InputFieldTemplateBase & { - default: LatentsField[]; - type: 'LatentsCollection'; - item_default?: LatentsField; -}; - -export type LatentsPolymorphicInputFieldTemplate = InputFieldTemplateBase & { - default: LatentsField; - type: 'LatentsPolymorphic'; -}; - -export type ConditioningInputFieldTemplate = InputFieldTemplateBase & { - default: undefined; - type: 'ConditioningField'; -}; - -export type ConditioningCollectionInputFieldTemplate = - InputFieldTemplateBase & { - default: ConditioningField[]; - type: 'ConditioningCollection'; - item_default?: ConditioningField; - }; - -export type ConditioningPolymorphicInputFieldTemplate = Omit< - ConditioningInputFieldTemplate, - 'type' -> & { - type: 'ConditioningPolymorphic'; -}; - -export type UNetInputFieldTemplate = InputFieldTemplateBase & { - default: undefined; - type: 'UNetField'; -}; - -export type MetadataItemFieldTemplate = InputFieldTemplateBase & { - default: undefined; - type: 'MetadataItemField'; -}; - -export type ClipInputFieldTemplate = InputFieldTemplateBase & { - default: undefined; - type: 'ClipField'; -}; - -export type VaeInputFieldTemplate = InputFieldTemplateBase & { - default: undefined; - type: 'VaeField'; -}; - -export type ControlInputFieldTemplate = InputFieldTemplateBase & { - default: undefined; - type: 'ControlField'; -}; - -export type ControlCollectionInputFieldTemplate = InputFieldTemplateBase & { - default: undefined; - type: 'ControlCollection'; - item_default?: ControlField; -}; - -export type ControlPolymorphicInputFieldTemplate = Omit< - ControlInputFieldTemplate, - 'type' -> & { - type: 'ControlPolymorphic'; -}; - -export type IPAdapterInputFieldTemplate = InputFieldTemplateBase & { - default: undefined; - type: 'IPAdapterField'; -}; - -export type IPAdapterCollectionInputFieldTemplate = InputFieldTemplateBase & { - default: undefined; - type: 'IPAdapterCollection'; - item_default?: IPAdapterField; -}; - -export type IPAdapterPolymorphicInputFieldTemplate = Omit< - IPAdapterInputFieldTemplate, - 'type' -> & { - type: 'IPAdapterPolymorphic'; -}; - -export type T2IAdapterInputFieldTemplate = InputFieldTemplateBase & { - default: undefined; - type: 'T2IAdapterField'; -}; - -export type T2IAdapterCollectionInputFieldTemplate = InputFieldTemplateBase & { - default: undefined; - type: 'T2IAdapterCollection'; - item_default?: T2IAdapterField; -}; - -export type T2IAdapterPolymorphicInputFieldTemplate = Omit< - T2IAdapterInputFieldTemplate, - 'type' -> & { - type: 'T2IAdapterPolymorphic'; -}; - -export type EnumInputFieldTemplate = InputFieldTemplateBase & { - default: string; - type: 'enum'; - options: string[]; - labels?: { [key: string]: string }; -}; - -export type MainModelInputFieldTemplate = InputFieldTemplateBase & { - default: undefined; - type: 'MainModelField'; -}; - -export type SDXLMainModelInputFieldTemplate = InputFieldTemplateBase & { - default: undefined; - type: 'SDXLMainModelField'; -}; - -export type SDXLRefinerModelInputFieldTemplate = InputFieldTemplateBase & { - default: undefined; - type: 'SDXLRefinerModelField'; -}; - -export type VaeModelInputFieldTemplate = InputFieldTemplateBase & { - default: string; - type: 'VaeModelField'; -}; - -export type LoRAModelInputFieldTemplate = InputFieldTemplateBase & { - default: string; - type: 'LoRAModelField'; -}; - -export type ControlNetModelInputFieldTemplate = InputFieldTemplateBase & { - default: string; - type: 'ControlNetModelField'; -}; - -export type IPAdapterModelInputFieldTemplate = InputFieldTemplateBase & { - default: string; - type: 'IPAdapterModelField'; -}; - -export type T2IAdapterModelInputFieldTemplate = InputFieldTemplateBase & { - default: string; - type: 'T2IAdapterModelField'; -}; - -export type CollectionInputFieldTemplate = InputFieldTemplateBase & { - default: []; - type: 'Collection'; -}; - -export type CollectionItemInputFieldTemplate = InputFieldTemplateBase & { - default: undefined; - type: 'CollectionItem'; -}; - -export type ColorInputFieldTemplate = InputFieldTemplateBase & { - default: RgbaColor; - type: 'ColorField'; -}; - -export type ColorPolymorphicInputFieldTemplate = Omit< - ColorInputFieldTemplate, - 'type' -> & { - type: 'ColorPolymorphic'; -}; - -export type ColorCollectionInputFieldTemplate = InputFieldTemplateBase & { - default: []; - type: 'ColorCollection'; -}; - -export type SchedulerInputFieldTemplate = InputFieldTemplateBase & { - default: SchedulerParam; - type: 'Scheduler'; -}; - -export type WorkflowInputFieldTemplate = InputFieldTemplateBase & { - default: undefined; - type: 'WorkflowField'; -}; - -export type MetadataItemInputFieldTemplate = InputFieldTemplateBase & { - default: undefined; - type: 'MetadataItemField'; -}; - -export type MetadataItemCollectionInputFieldTemplate = - InputFieldTemplateBase & { - default: undefined; - type: 'MetadataItemCollection'; - }; - -export type MetadataItemPolymorphicInputFieldTemplate = Omit< - MetadataItemInputFieldTemplate, - 'type' -> & { - type: 'MetadataItemPolymorphic'; -}; - -export type MetadataInputFieldTemplate = InputFieldTemplateBase & { - default: undefined; - type: 'MetadataField'; -}; - -export type MetadataCollectionInputFieldTemplate = InputFieldTemplateBase & { - default: undefined; - type: 'MetadataCollection'; -}; - -/** - * An input field template is generated on each page load from the OpenAPI schema. - * - * The template provides the field type and other field metadata (e.g. title, description, - * maximum length, pattern to match, etc). - */ -export type InputFieldTemplate = - | AnyInputFieldTemplate - | BoardInputFieldTemplate - | BooleanCollectionInputFieldTemplate - | BooleanPolymorphicInputFieldTemplate - | BooleanInputFieldTemplate - | ClipInputFieldTemplate - | CollectionInputFieldTemplate - | CollectionItemInputFieldTemplate - | ColorInputFieldTemplate - | ColorCollectionInputFieldTemplate - | ColorPolymorphicInputFieldTemplate - | ConditioningInputFieldTemplate - | ConditioningCollectionInputFieldTemplate - | ConditioningPolymorphicInputFieldTemplate - | ControlInputFieldTemplate - | ControlCollectionInputFieldTemplate - | ControlNetModelInputFieldTemplate - | ControlPolymorphicInputFieldTemplate - | DenoiseMaskInputFieldTemplate - | EnumInputFieldTemplate - | FloatCollectionInputFieldTemplate - | FloatInputFieldTemplate - | FloatPolymorphicInputFieldTemplate - | ImageCollectionInputFieldTemplate - | ImagePolymorphicInputFieldTemplate - | ImageInputFieldTemplate - | IntegerCollectionInputFieldTemplate - | IntegerPolymorphicInputFieldTemplate - | IntegerInputFieldTemplate - | IPAdapterInputFieldTemplate - | IPAdapterCollectionInputFieldTemplate - | IPAdapterModelInputFieldTemplate - | IPAdapterPolymorphicInputFieldTemplate - | LatentsInputFieldTemplate - | LatentsCollectionInputFieldTemplate - | LatentsPolymorphicInputFieldTemplate - | LoRAModelInputFieldTemplate - | MainModelInputFieldTemplate - | SchedulerInputFieldTemplate - | SDXLMainModelInputFieldTemplate - | SDXLRefinerModelInputFieldTemplate - | StringCollectionInputFieldTemplate - | StringPolymorphicInputFieldTemplate - | StringInputFieldTemplate - | T2IAdapterInputFieldTemplate - | T2IAdapterCollectionInputFieldTemplate - | T2IAdapterModelInputFieldTemplate - | T2IAdapterPolymorphicInputFieldTemplate - | UNetInputFieldTemplate - | VaeInputFieldTemplate - | VaeModelInputFieldTemplate - | MetadataItemInputFieldTemplate - | MetadataItemCollectionInputFieldTemplate - | MetadataInputFieldTemplate - | MetadataItemPolymorphicInputFieldTemplate - | MetadataCollectionInputFieldTemplate; - -export const isInputFieldValue = ( - field?: InputFieldValue | OutputFieldValue -): field is InputFieldValue => Boolean(field && field.fieldKind === 'input'); - -export const isInputFieldTemplate = ( - fieldTemplate?: InputFieldTemplate | OutputFieldTemplate -): fieldTemplate is InputFieldTemplate => - Boolean(fieldTemplate && fieldTemplate.fieldKind === 'input'); - -/** - * JANKY CUSTOMISATION OF OpenAPI SCHEMA TYPES - */ - -export type TypeHints = { - [fieldName: string]: FieldType; -}; - -export type InvocationSchemaExtra = { - output: OpenAPIV3_1.ReferenceObject; // the output of the invocation - title: string; - category?: string; - tags?: string[]; - version?: string; - properties: Omit< - NonNullable & - (_InputField | _OutputField), - 'type' - > & { - type: Omit & { - default: AnyInvocationType; - }; - use_cache: Omit & { - default: boolean; - }; - }; -}; - -export type InvocationSchemaType = { - default: string; // the type of the invocation -}; - -export type InvocationBaseSchemaObject = Omit< - OpenAPIV3_1.BaseSchemaObject, - 'title' | 'type' | 'properties' -> & - InvocationSchemaExtra; - -export type InvocationOutputSchemaObject = Omit< - OpenAPIV3_1.SchemaObject, - 'properties' -> & { - properties: OpenAPIV3_1.SchemaObject['properties'] & { - type: Omit & { - default: string; - }; - } & { - class: 'output'; - }; -}; - -export type InvocationFieldSchema = OpenAPIV3_1.SchemaObject & _InputField; - -export type OpenAPIV3_1SchemaOrRef = - | OpenAPIV3_1.ReferenceObject - | OpenAPIV3_1.SchemaObject; - -export interface ArraySchemaObject extends InvocationBaseSchemaObject { - type: OpenAPIV3_1.ArraySchemaObjectType; - items: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject; -} -export interface NonArraySchemaObject extends InvocationBaseSchemaObject { - type?: OpenAPIV3_1.NonArraySchemaObjectType; -} - -export type InvocationSchemaObject = ( - | ArraySchemaObject - | NonArraySchemaObject -) & { class: 'invocation' }; - -export const isSchemaObject = ( - obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject | undefined -): obj is OpenAPIV3_1.SchemaObject => Boolean(obj && !('$ref' in obj)); - -export const isArraySchemaObject = ( - obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject | undefined -): obj is OpenAPIV3_1.ArraySchemaObject => - Boolean(obj && !('$ref' in obj) && obj.type === 'array'); - -export const isNonArraySchemaObject = ( - obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject | undefined -): obj is OpenAPIV3_1.NonArraySchemaObject => - Boolean(obj && !('$ref' in obj) && obj.type !== 'array'); - -export const isRefObject = ( - obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject | undefined -): obj is OpenAPIV3_1.ReferenceObject => Boolean(obj && '$ref' in obj); - -export const isInvocationSchemaObject = ( - obj: - | OpenAPIV3_1.ReferenceObject - | OpenAPIV3_1.SchemaObject - | InvocationSchemaObject -): obj is InvocationSchemaObject => - 'class' in obj && obj.class === 'invocation'; - -export const isInvocationOutputSchemaObject = ( - obj: - | OpenAPIV3_1.ReferenceObject - | OpenAPIV3_1.SchemaObject - | InvocationOutputSchemaObject -): obj is InvocationOutputSchemaObject => - 'class' in obj && obj.class === 'output'; - -export const isInvocationFieldSchema = ( - obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject -): obj is InvocationFieldSchema => !('$ref' in obj); - -export type InvocationEdgeExtra = { type: 'default' | 'collapsed' }; - -export const zLoRAMetadataItem = z.object({ - lora: zLoRAModelField.deepPartial(), - weight: z.number(), -}); - -export type LoRAMetadataItem = z.infer; - -const zControlNetMetadataItem = zControlField.deepPartial(); - -export type ControlNetMetadataItem = z.infer; - -const zIPAdapterMetadataItem = zIPAdapterField.deepPartial(); - -export type IPAdapterMetadataItem = z.infer; - -const zT2IAdapterMetadataItem = zT2IAdapterField.deepPartial(); - -export type T2IAdapterMetadataItem = z.infer; - -export const zCoreMetadata = z - .object({ - app_version: z.string().nullish().catch(null), - generation_mode: z.string().nullish().catch(null), - created_by: z.string().nullish().catch(null), - positive_prompt: z.string().nullish().catch(null), - negative_prompt: z.string().nullish().catch(null), - width: z.number().int().nullish().catch(null), - height: z.number().int().nullish().catch(null), - seed: z.number().int().nullish().catch(null), - rand_device: z.string().nullish().catch(null), - cfg_scale: z.number().nullish().catch(null), - steps: z.number().int().nullish().catch(null), - scheduler: z.string().nullish().catch(null), - clip_skip: z.number().int().nullish().catch(null), - model: z - .union([zMainModel.deepPartial(), zOnnxModel.deepPartial()]) - .nullish() - .catch(null), - controlnets: z.array(zControlNetMetadataItem).nullish().catch(null), - ipAdapters: z.array(zIPAdapterMetadataItem).nullish().catch(null), - t2iAdapters: z.array(zT2IAdapterMetadataItem).nullish().catch(null), - loras: z.array(zLoRAMetadataItem).nullish().catch(null), - vae: zVaeModelField.nullish().catch(null), - strength: z.number().nullish().catch(null), - hrf_enabled: z.boolean().nullish().catch(null), - hrf_strength: z.number().nullish().catch(null), - hrf_method: z.string().nullish().catch(null), - init_image: z.string().nullish().catch(null), - positive_style_prompt: z.string().nullish().catch(null), - negative_style_prompt: z.string().nullish().catch(null), - refiner_model: zSDXLRefinerModel.deepPartial().nullish().catch(null), - refiner_cfg_scale: z.number().nullish().catch(null), - refiner_steps: z.number().int().nullish().catch(null), - refiner_scheduler: z.string().nullish().catch(null), - refiner_positive_aesthetic_score: z.number().nullish().catch(null), - refiner_negative_aesthetic_score: z.number().nullish().catch(null), - refiner_start: z.number().nullish().catch(null), - }) - .passthrough(); - -export type CoreMetadata = z.infer; - -export const zSemVer = z.string().refine((val) => { - const [major, minor, patch] = val.split('.'); - return ( - major !== undefined && - Number.isInteger(Number(major)) && - minor !== undefined && - Number.isInteger(Number(minor)) && - patch !== undefined && - Number.isInteger(Number(patch)) - ); -}); - -export const zParsedSemver = zSemVer.transform((val) => { - const [major, minor, patch] = val.split('.'); - return { - major: Number(major), - minor: Number(minor), - patch: Number(patch), - }; -}); - -export type SemVer = z.infer; - -export const zInvocationNodeData = z.object({ - id: z.string().trim().min(1), - // no easy way to build this dynamically, and we don't want to anyways, because this will be used - // to validate incoming workflows, and we want to allow community nodes. - type: z.string().trim().min(1), - inputs: z.record(zInputFieldValue), - outputs: z.record(zOutputFieldValue), - label: z.string(), - isOpen: z.boolean(), - notes: z.string(), - embedWorkflow: z.boolean(), - isIntermediate: z.boolean(), - useCache: z.boolean().optional(), - version: zSemVer.optional(), -}); - -export const zInvocationNodeDataV2 = z.preprocess( - (arg) => { - try { - const data = zInvocationNodeData.parse(arg); - if (!has(data, 'useCache')) { - const nodeTemplates = $store.get()?.getState().nodes.nodeTemplates as - | Record - | undefined; - - const template = nodeTemplates?.[data.type]; - - let useCache = true; - if (template) { - useCache = template.useCache; - } - - Object.assign(data, { useCache }); - } - return data; - } catch { - return arg; - } - }, - zInvocationNodeData.extend({ - useCache: z.boolean(), - }) -); - -// Massage this to get better type safety while developing -export type InvocationNodeData = Omit< - z.infer, - 'type' -> & { - type: AnyInvocationType; -}; - -export const zNotesNodeData = z.object({ - id: z.string().trim().min(1), - type: z.literal('notes'), - label: z.string(), - isOpen: z.boolean(), - notes: z.string(), -}); - -export type NotesNodeData = z.infer; - -const zPosition = z - .object({ - x: z.number(), - y: z.number(), - }) - .default({ x: 0, y: 0 }); - -const zDimension = z.number().gt(0).nullish(); - -export const zWorkflowInvocationNode = z.object({ - id: z.string().trim().min(1), - type: z.literal('invocation'), - data: zInvocationNodeDataV2, - width: zDimension, - height: zDimension, - position: zPosition, -}); - -export type WorkflowInvocationNode = z.infer; - -export const isWorkflowInvocationNode = ( - val: unknown -): val is WorkflowInvocationNode => - zWorkflowInvocationNode.safeParse(val).success; - -export const zWorkflowNotesNode = z.object({ - id: z.string().trim().min(1), - type: z.literal('notes'), - data: zNotesNodeData, - width: zDimension, - height: zDimension, - position: zPosition, -}); - -export const zWorkflowNode = z.discriminatedUnion('type', [ - zWorkflowInvocationNode, - zWorkflowNotesNode, -]); - -export type WorkflowNode = z.infer; - -export const zDefaultWorkflowEdge = z.object({ - source: z.string().trim().min(1), - sourceHandle: z.string().trim().min(1), - target: z.string().trim().min(1), - targetHandle: z.string().trim().min(1), - id: z.string().trim().min(1), - type: z.literal('default'), -}); -export const zCollapsedWorkflowEdge = z.object({ - source: z.string().trim().min(1), - target: z.string().trim().min(1), - id: z.string().trim().min(1), - type: z.literal('collapsed'), -}); - -export const zWorkflowEdge = z.union([ - zDefaultWorkflowEdge, - zCollapsedWorkflowEdge, -]); - -export const zFieldIdentifier = z.object({ - nodeId: z.string().trim().min(1), - fieldName: z.string().trim().min(1), -}); - -export type FieldIdentifier = z.infer; - -export type WorkflowWarning = { - message: string; - issues: string[]; - data: JsonObject; -}; - -const CURRENT_WORKFLOW_VERSION = '1.0.0'; - -export const zWorkflow = z.object({ - name: z.string().default(''), - author: z.string().default(''), - description: z.string().default(''), - version: z.string().default(''), - contact: z.string().default(''), - tags: z.string().default(''), - notes: z.string().default(''), - nodes: z.array(zWorkflowNode).default([]), - edges: z.array(zWorkflowEdge).default([]), - exposedFields: z.array(zFieldIdentifier).default([]), - meta: z - .object({ - version: zSemVer, - }) - .default({ version: CURRENT_WORKFLOW_VERSION }), -}); - -export const zValidatedWorkflow = zWorkflow.transform((workflow) => { - const { nodes, edges } = workflow; - const warnings: WorkflowWarning[] = []; - const invocationNodes = nodes.filter(isWorkflowInvocationNode); - const keyedNodes = keyBy(invocationNodes, 'id'); - edges.forEach((edge, i) => { - const sourceNode = keyedNodes[edge.source]; - const targetNode = keyedNodes[edge.target]; - const issues: string[] = []; - if (!sourceNode) { - issues.push( - `${i18n.t('nodes.outputNode')} ${edge.source} ${i18n.t( - 'nodes.doesNotExist' - )}` - ); - } else if ( - edge.type === 'default' && - !(edge.sourceHandle in sourceNode.data.outputs) - ) { - issues.push( - `${i18n.t('nodes.outputField')}"${edge.source}.${ - edge.sourceHandle - }" ${i18n.t('nodes.doesNotExist')}` - ); - } - if (!targetNode) { - issues.push( - `${i18n.t('nodes.inputNode')} ${edge.target} ${i18n.t( - 'nodes.doesNotExist' - )}` - ); - } else if ( - edge.type === 'default' && - !(edge.targetHandle in targetNode.data.inputs) - ) { - issues.push( - `${i18n.t('nodes.inputField')} "${edge.target}.${ - edge.targetHandle - }" ${i18n.t('nodes.doesNotExist')}` - ); - } - if (issues.length) { - delete edges[i]; - const src = edge.type === 'default' ? edge.sourceHandle : edge.source; - const tgt = edge.type === 'default' ? edge.targetHandle : edge.target; - warnings.push({ - message: `${i18n.t('nodes.edge')} "${src} -> ${tgt}" ${i18n.t( - 'nodes.skipped' - )}`, - issues, - data: edge, - }); - } - }); - return { workflow, warnings }; -}); - -export type Workflow = z.infer; - -export type ImageMetadataAndWorkflow = { - metadata?: CoreMetadata; - workflow?: Workflow; -}; - -export type CurrentImageNodeData = { - id: string; - type: 'current_image'; - isOpen: boolean; - label: string; -}; - -export type NodeData = - | InvocationNodeData - | NotesNodeData - | CurrentImageNodeData; - -export const isInvocationNode = ( - node?: Node -): node is Node => - Boolean(node && node.type === 'invocation'); - -export const isInvocationNodeData = ( - node?: NodeData -): node is InvocationNodeData => - Boolean(node && !['notes', 'current_image'].includes(node.type)); - -export const isNotesNode = ( - node?: Node -): node is Node => Boolean(node && node.type === 'notes'); - -export const isProgressImageNode = ( - node?: Node -): node is Node => - Boolean(node && node.type === 'current_image'); - -export enum NodeStatus { - PENDING, - IN_PROGRESS, - COMPLETED, - FAILED, -} - -export type NodeExecutionState = { - nodeId: string; - status: NodeStatus; - progress: number | null; - progressImage: ProgressImage | null; - error: string | null; - outputs: AnyResult[]; -}; - -export type FieldComponentProps< - V extends InputFieldValue, - T extends InputFieldTemplate, -> = { - nodeId: string; - field: V; - fieldTemplate: T; -}; diff --git a/invokeai/frontend/web/src/features/nodes/types/workflow.ts b/invokeai/frontend/web/src/features/nodes/types/workflow.ts new file mode 100644 index 00000000000..7af8a2dd728 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/types/workflow.ts @@ -0,0 +1,91 @@ +import { z } from 'zod'; +import { zFieldIdentifier } from './field'; +import { zInvocationNodeData, zNotesNodeData } from './invocation'; + +// #region Workflow misc +export const zXYPosition = z + .object({ + x: z.number(), + y: z.number(), + }) + .default({ x: 0, y: 0 }); +export type XYPosition = z.infer; + +export const zDimension = z.number().gt(0).nullish(); +export type Dimension = z.infer; +// #endregion + +// #region Workflow Nodes +export const zWorkflowInvocationNode = z.object({ + id: z.string().trim().min(1), + type: z.literal('invocation'), + data: zInvocationNodeData, + width: zDimension, + height: zDimension, + position: zXYPosition, +}); +export const zWorkflowNotesNode = z.object({ + id: z.string().trim().min(1), + type: z.literal('notes'), + data: zNotesNodeData, + width: zDimension, + height: zDimension, + position: zXYPosition, +}); +export const zWorkflowNode = z.union([ + zWorkflowInvocationNode, + zWorkflowNotesNode, +]); + +export type WorkflowInvocationNode = z.infer; +export type WorkflowNotesNode = z.infer; +export type WorkflowNode = z.infer; + +export const isWorkflowInvocationNode = ( + val: unknown +): val is WorkflowInvocationNode => + zWorkflowInvocationNode.safeParse(val).success; +// #endregion + +// #region Workflow Edges +export const zWorkflowEdgeBase = z.object({ + id: z.string().trim().min(1), + source: z.string().trim().min(1), + target: z.string().trim().min(1), +}); +export const zWorkflowEdgeDefault = zWorkflowEdgeBase.extend({ + type: z.literal('default'), + sourceHandle: z.string().trim().min(1), + targetHandle: z.string().trim().min(1), +}); +export const zWorkflowEdgeCollapsed = zWorkflowEdgeBase.extend({ + type: z.literal('collapsed'), +}); +export const zWorkflowEdge = z.union([ + zWorkflowEdgeDefault, + zWorkflowEdgeCollapsed, +]); + +export type WorkflowEdgeDefault = z.infer; +export type WorkflowEdgeCollapsed = z.infer; +export type WorkflowEdge = z.infer; +// #endregion + +// #region Workflow +export const zWorkflowV2 = z.object({ + name: z.string(), + author: z.string(), + description: z.string(), + version: z.string(), + contact: z.string(), + tags: z.string(), + notes: z.string(), + nodes: z.array(zWorkflowNode), + edges: z.array(zWorkflowEdge), + exposedFields: z.array(zFieldIdentifier), + meta: z.object({ + version: z.literal('2.0.0'), + }), +}); +export type WorkflowV2 = z.infer; +// #endregion diff --git a/invokeai/frontend/web/src/features/nodes/util/buildFieldInputInstance.ts b/invokeai/frontend/web/src/features/nodes/util/buildFieldInputInstance.ts new file mode 100644 index 00000000000..200bd98e86f --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/buildFieldInputInstance.ts @@ -0,0 +1,42 @@ +import { get } from 'lodash-es'; +import { FieldInputInstance, FieldInputTemplate } from '../types/field'; + +const FIELD_VALUE_FALLBACK_MAP = { + EnumField: '', + BoardField: undefined, + BooleanField: false, + ClipField: undefined, + ColorField: { r: 0, g: 0, b: 0, a: 1 }, + FloatField: 0, + ImageField: undefined, + IntegerField: 0, + IPAdapterModelField: undefined, + LoRAModelField: undefined, + MainModelField: undefined, + ONNXModelField: undefined, + SchedulerField: 'euler', + SDXLMainModelField: undefined, + SDXLRefinerModelField: undefined, + StringField: '', + T2IAdapterModelField: undefined, + T2IAdapterPolymorphic: undefined, + VAEModelField: undefined, + ControlNetModelField: undefined, +}; + +export const buildFieldInputInstance = ( + id: string, + template: FieldInputTemplate +): FieldInputInstance => { + const fieldInstance: FieldInputInstance = { + id, + name: template.name, + type: template.type, + label: '', + fieldKind: 'input' as const, + value: + template.default ?? get(FIELD_VALUE_FALLBACK_MAP, template.type.name), + }; + + return fieldInstance; +}; diff --git a/invokeai/frontend/web/src/features/nodes/util/buildFieldInputTemplate.ts b/invokeai/frontend/web/src/features/nodes/util/buildFieldInputTemplate.ts new file mode 100644 index 00000000000..8d11ac25b98 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/buildFieldInputTemplate.ts @@ -0,0 +1,376 @@ +import { isNumber, startCase } from 'lodash-es'; +import { + BoardFieldInputTemplate, + BooleanFieldInputTemplate, + ColorFieldInputTemplate, + ControlNetModelFieldInputTemplate, + EnumFieldInputTemplate, + FieldInputTemplate, + FieldType, + FloatFieldInputTemplate, + IPAdapterModelFieldInputTemplate, + ImageFieldInputTemplate, + IntegerFieldInputTemplate, + LoRAModelFieldInputTemplate, + MainModelFieldInputTemplate, + SDXLMainModelFieldInputTemplate, + SDXLRefinerModelFieldInputTemplate, + SchedulerFieldInputTemplate, + StatefulFieldType, + StatelessFieldInputTemplate, + StringFieldInputTemplate, + T2IAdapterModelFieldInputTemplate, + VAEModelFieldInputTemplate, + isStatefulFieldType, +} from '../types/field'; +import { InvocationFieldSchema } from '../types/openapi'; + +// eslint-disable-next-line @typescript-eslint/no-explicit-any +type FieldInputTemplateBuilder = // valid `any`! + (arg: { + schemaObject: InvocationFieldSchema; + baseField: Omit; + isCollection: boolean; + isPolymorphic: boolean; + }) => T; + +const buildIntegerFieldInputTemplate: FieldInputTemplateBuilder< + IntegerFieldInputTemplate +> = ({ schemaObject, baseField, isCollection, isPolymorphic }) => { + const template: IntegerFieldInputTemplate = { + ...baseField, + type: { name: 'IntegerField', isCollection, isPolymorphic }, + default: schemaObject.default ?? 0, + }; + + if (schemaObject.multipleOf !== undefined) { + template.multipleOf = schemaObject.multipleOf; + } + + if (schemaObject.maximum !== undefined) { + template.maximum = schemaObject.maximum; + } + + if ( + schemaObject.exclusiveMaximum !== undefined && + isNumber(schemaObject.exclusiveMaximum) + ) { + template.exclusiveMaximum = schemaObject.exclusiveMaximum; + } + + if (schemaObject.minimum !== undefined) { + template.minimum = schemaObject.minimum; + } + + if ( + schemaObject.exclusiveMinimum !== undefined && + isNumber(schemaObject.exclusiveMinimum) + ) { + template.exclusiveMinimum = schemaObject.exclusiveMinimum; + } + + return template; +}; + +const buildFloatFieldInputTemplate: FieldInputTemplateBuilder< + FloatFieldInputTemplate +> = ({ schemaObject, baseField, isCollection, isPolymorphic }) => { + const template: FloatFieldInputTemplate = { + ...baseField, + type: { name: 'FloatField', isCollection, isPolymorphic }, + default: schemaObject.default ?? 0, + }; + + if (schemaObject.multipleOf !== undefined) { + template.multipleOf = schemaObject.multipleOf; + } + + if (schemaObject.maximum !== undefined) { + template.maximum = schemaObject.maximum; + } + + if ( + schemaObject.exclusiveMaximum !== undefined && + isNumber(schemaObject.exclusiveMaximum) + ) { + template.exclusiveMaximum = schemaObject.exclusiveMaximum; + } + + if (schemaObject.minimum !== undefined) { + template.minimum = schemaObject.minimum; + } + + if ( + schemaObject.exclusiveMinimum !== undefined && + isNumber(schemaObject.exclusiveMinimum) + ) { + template.exclusiveMinimum = schemaObject.exclusiveMinimum; + } + + return template; +}; + +const buildStringFieldInputTemplate: FieldInputTemplateBuilder< + StringFieldInputTemplate +> = ({ schemaObject, baseField, isCollection, isPolymorphic }) => { + const template: StringFieldInputTemplate = { + ...baseField, + type: { name: 'StringField', isCollection, isPolymorphic }, + default: schemaObject.default ?? '', + }; + + if (schemaObject.minLength !== undefined) { + template.minLength = schemaObject.minLength; + } + + if (schemaObject.maxLength !== undefined) { + template.maxLength = schemaObject.maxLength; + } + + return template; +}; + +const buildBooleanFieldInputTemplate: FieldInputTemplateBuilder< + BooleanFieldInputTemplate +> = ({ schemaObject, baseField, isCollection, isPolymorphic }) => { + const template: BooleanFieldInputTemplate = { + ...baseField, + type: { name: 'BooleanField', isCollection, isPolymorphic }, + default: schemaObject.default ?? false, + }; + + return template; +}; + +const buildMainModelFieldInputTemplate: FieldInputTemplateBuilder< + MainModelFieldInputTemplate +> = ({ schemaObject, baseField, isCollection, isPolymorphic }) => { + const template: MainModelFieldInputTemplate = { + ...baseField, + type: { name: 'MainModelField', isCollection, isPolymorphic }, + default: schemaObject.default ?? undefined, + }; + + return template; +}; + +const buildSDXLMainModelFieldInputTemplate: FieldInputTemplateBuilder< + SDXLMainModelFieldInputTemplate +> = ({ schemaObject, baseField, isCollection, isPolymorphic }) => { + const template: SDXLMainModelFieldInputTemplate = { + ...baseField, + type: { name: 'SDXLMainModelField', isCollection, isPolymorphic }, + default: schemaObject.default ?? undefined, + }; + + return template; +}; + +const buildRefinerModelFieldInputTemplate: FieldInputTemplateBuilder< + SDXLRefinerModelFieldInputTemplate +> = ({ schemaObject, baseField, isCollection, isPolymorphic }) => { + const template: SDXLRefinerModelFieldInputTemplate = { + ...baseField, + type: { name: 'SDXLRefinerModelField', isCollection, isPolymorphic }, + default: schemaObject.default ?? undefined, + }; + + return template; +}; + +const buildVAEModelFieldInputTemplate: FieldInputTemplateBuilder< + VAEModelFieldInputTemplate +> = ({ schemaObject, baseField, isCollection, isPolymorphic }) => { + const template: VAEModelFieldInputTemplate = { + ...baseField, + type: { name: 'VAEModelField', isCollection, isPolymorphic }, + default: schemaObject.default ?? undefined, + }; + + return template; +}; + +const buildLoRAModelFieldInputTemplate: FieldInputTemplateBuilder< + LoRAModelFieldInputTemplate +> = ({ schemaObject, baseField, isCollection, isPolymorphic }) => { + const template: LoRAModelFieldInputTemplate = { + ...baseField, + type: { name: 'LoRAModelField', isCollection, isPolymorphic }, + default: schemaObject.default ?? undefined, + }; + + return template; +}; + +const buildControlNetModelFieldInputTemplate: FieldInputTemplateBuilder< + ControlNetModelFieldInputTemplate +> = ({ schemaObject, baseField, isCollection, isPolymorphic }) => { + const template: ControlNetModelFieldInputTemplate = { + ...baseField, + type: { name: 'ControlNetModelField', isCollection, isPolymorphic }, + default: schemaObject.default ?? undefined, + }; + + return template; +}; + +const buildIPAdapterModelFieldInputTemplate: FieldInputTemplateBuilder< + IPAdapterModelFieldInputTemplate +> = ({ schemaObject, baseField, isCollection, isPolymorphic }) => { + const template: IPAdapterModelFieldInputTemplate = { + ...baseField, + type: { name: 'IPAdapterModelField', isCollection, isPolymorphic }, + default: schemaObject.default ?? undefined, + }; + + return template; +}; + +const buildT2IAdapterModelFieldInputTemplate: FieldInputTemplateBuilder< + T2IAdapterModelFieldInputTemplate +> = ({ schemaObject, baseField, isCollection, isPolymorphic }) => { + const template: T2IAdapterModelFieldInputTemplate = { + ...baseField, + type: { name: 'T2IAdapterModelField', isCollection, isPolymorphic }, + default: schemaObject.default ?? undefined, + }; + + return template; +}; + +const buildBoardFieldInputTemplate: FieldInputTemplateBuilder< + BoardFieldInputTemplate +> = ({ schemaObject, baseField, isCollection, isPolymorphic }) => { + const template: BoardFieldInputTemplate = { + ...baseField, + type: { name: 'BoardField', isCollection, isPolymorphic }, + default: schemaObject.default ?? undefined, + }; + + return template; +}; + +const buildImageFieldInputTemplate: FieldInputTemplateBuilder< + ImageFieldInputTemplate +> = ({ schemaObject, baseField, isCollection, isPolymorphic }) => { + const template: ImageFieldInputTemplate = { + ...baseField, + type: { name: 'ImageField', isCollection, isPolymorphic }, + default: schemaObject.default ?? undefined, + }; + + return template; +}; + +const buildEnumFieldInputTemplate: FieldInputTemplateBuilder< + EnumFieldInputTemplate +> = ({ schemaObject, baseField, isCollection, isPolymorphic }) => { + const options = schemaObject.enum ?? []; + const template: EnumFieldInputTemplate = { + ...baseField, + type: { name: 'EnumField', isCollection, isPolymorphic }, + options, + ui_choice_labels: schemaObject.ui_choice_labels, + default: schemaObject.default ?? options[0], + }; + + return template; +}; + +const buildColorFieldInputTemplate: FieldInputTemplateBuilder< + ColorFieldInputTemplate +> = ({ schemaObject, baseField, isCollection, isPolymorphic }) => { + const template: ColorFieldInputTemplate = { + ...baseField, + type: { name: 'ColorField', isCollection, isPolymorphic }, + default: schemaObject.default ?? { r: 127, g: 127, b: 127, a: 255 }, + }; + + return template; +}; + +const buildSchedulerFieldInputTemplate: FieldInputTemplateBuilder< + SchedulerFieldInputTemplate +> = ({ schemaObject, baseField, isCollection, isPolymorphic }) => { + const template: SchedulerFieldInputTemplate = { + ...baseField, + type: { name: 'SchedulerField', isCollection, isPolymorphic }, + default: schemaObject.default ?? 'euler', + }; + + return template; +}; + +export const TEMPLATE_BUILDER_MAP: Record< + StatefulFieldType['name'], + FieldInputTemplateBuilder +> = { + BoardField: buildBoardFieldInputTemplate, + BooleanField: buildBooleanFieldInputTemplate, + ColorField: buildColorFieldInputTemplate, + ControlNetModelField: buildControlNetModelFieldInputTemplate, + EnumField: buildEnumFieldInputTemplate, + FloatField: buildFloatFieldInputTemplate, + ImageField: buildImageFieldInputTemplate, + IntegerField: buildIntegerFieldInputTemplate, + IPAdapterModelField: buildIPAdapterModelFieldInputTemplate, + LoRAModelField: buildLoRAModelFieldInputTemplate, + MainModelField: buildMainModelFieldInputTemplate, + SchedulerField: buildSchedulerFieldInputTemplate, + SDXLMainModelField: buildSDXLMainModelFieldInputTemplate, + SDXLRefinerModelField: buildRefinerModelFieldInputTemplate, + StringField: buildStringFieldInputTemplate, + T2IAdapterModelField: buildT2IAdapterModelFieldInputTemplate, + VAEModelField: buildVAEModelFieldInputTemplate, +}; + +export const buildFieldInputTemplate = ( + fieldSchema: InvocationFieldSchema, + name: string, + fieldType: FieldType +): FieldInputTemplate => { + const { + input, + ui_hidden, + ui_component, + ui_type, + ui_order, + ui_choice_labels, + orig_required: required, + } = fieldSchema; + + // This is the base field template that is common to all fields. The builder function will add all other + // properties to this template. + const baseField: Omit = { + name, + title: fieldSchema.title ?? (name ? startCase(name) : ''), + required, + description: fieldSchema.description ?? '', + fieldKind: 'input' as const, + input, + ui_hidden, + ui_component, + ui_type, + ui_order, + ui_choice_labels, + }; + + if (isStatefulFieldType(fieldType)) { + const builder = TEMPLATE_BUILDER_MAP[fieldType.name]; + return builder({ + schemaObject: fieldSchema, + baseField, + isCollection: fieldType.isCollection, + isPolymorphic: fieldType.isPolymorphic, + }); + } + + // This is a StatelessField, create it directly. + const template: StatelessFieldInputTemplate = { + ...baseField, + input: 'connection', // stateless --> connection only inputs + type: fieldType, + default: undefined, // stateless --> no default value + }; + return template; +}; diff --git a/invokeai/frontend/web/src/features/nodes/util/buildWorkflow.ts b/invokeai/frontend/web/src/features/nodes/util/buildWorkflow.ts index 43ee75b735a..7e49be4068f 100644 --- a/invokeai/frontend/web/src/features/nodes/util/buildWorkflow.ts +++ b/invokeai/frontend/web/src/features/nodes/util/buildWorkflow.ts @@ -1,13 +1,13 @@ import { logger } from 'app/logging/logger'; import { NodesState } from '../store/types'; -import { Workflow, zWorkflowEdge, zWorkflowNode } from '../types/types'; +import { WorkflowV2, zWorkflowEdge, zWorkflowNode } from '../types/workflow'; import { fromZodError } from 'zod-validation-error'; import { parseify } from 'common/util/serialize'; import i18n from 'i18next'; -export const buildWorkflow = (nodesState: NodesState): Workflow => { +export const buildWorkflow = (nodesState: NodesState): WorkflowV2 => { const { workflow: workflowMeta, nodes, edges } = nodesState; - const workflow: Workflow = { + const workflow: WorkflowV2 = { ...workflowMeta, nodes: [], edges: [], diff --git a/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts b/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts deleted file mode 100644 index 92e44e9ab2c..00000000000 --- a/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts +++ /dev/null @@ -1,1210 +0,0 @@ -import { - isArray, - isBoolean, - isInteger, - isNumber, - isString, - startCase, -} from 'lodash-es'; -import { OpenAPIV3_1 } from 'openapi-types'; -import { ControlField } from 'services/api/types'; -import { - COLLECTION_MAP, - POLYMORPHIC_TYPES, - SINGLE_TO_POLYMORPHIC_MAP, - isCollectionItemType, - isPolymorphicItemType, -} from '../types/constants'; -import { - AnyInputFieldTemplate, - BoardInputFieldTemplate, - BooleanCollectionInputFieldTemplate, - BooleanInputFieldTemplate, - BooleanPolymorphicInputFieldTemplate, - ClipInputFieldTemplate, - CollectionInputFieldTemplate, - CollectionItemInputFieldTemplate, - ColorCollectionInputFieldTemplate, - ColorInputFieldTemplate, - ColorPolymorphicInputFieldTemplate, - ConditioningCollectionInputFieldTemplate, - ConditioningField, - ConditioningInputFieldTemplate, - ConditioningPolymorphicInputFieldTemplate, - ControlCollectionInputFieldTemplate, - ControlInputFieldTemplate, - ControlNetModelInputFieldTemplate, - ControlPolymorphicInputFieldTemplate, - DenoiseMaskInputFieldTemplate, - EnumInputFieldTemplate, - FieldType, - FloatCollectionInputFieldTemplate, - FloatInputFieldTemplate, - FloatPolymorphicInputFieldTemplate, - IPAdapterCollectionInputFieldTemplate, - IPAdapterField, - IPAdapterInputFieldTemplate, - IPAdapterModelInputFieldTemplate, - IPAdapterPolymorphicInputFieldTemplate, - ImageCollectionInputFieldTemplate, - ImageField, - ImageInputFieldTemplate, - ImagePolymorphicInputFieldTemplate, - InputFieldTemplate, - InputFieldTemplateBase, - IntegerCollectionInputFieldTemplate, - IntegerInputFieldTemplate, - IntegerPolymorphicInputFieldTemplate, - InvocationFieldSchema, - InvocationSchemaObject, - LatentsCollectionInputFieldTemplate, - LatentsField, - LatentsInputFieldTemplate, - LatentsPolymorphicInputFieldTemplate, - LoRAModelInputFieldTemplate, - MainModelInputFieldTemplate, - MetadataCollectionInputFieldTemplate, - MetadataInputFieldTemplate, - MetadataItemCollectionInputFieldTemplate, - MetadataItemInputFieldTemplate, - MetadataItemPolymorphicInputFieldTemplate, - OpenAPIV3_1SchemaOrRef, - SDXLMainModelInputFieldTemplate, - SDXLRefinerModelInputFieldTemplate, - SchedulerInputFieldTemplate, - StringCollectionInputFieldTemplate, - StringInputFieldTemplate, - StringPolymorphicInputFieldTemplate, - T2IAdapterCollectionInputFieldTemplate, - T2IAdapterField, - T2IAdapterInputFieldTemplate, - T2IAdapterModelInputFieldTemplate, - T2IAdapterPolymorphicInputFieldTemplate, - UNetInputFieldTemplate, - VaeInputFieldTemplate, - VaeModelInputFieldTemplate, - isArraySchemaObject, - isNonArraySchemaObject, - isRefObject, - isSchemaObject, -} from '../types/types'; - -export type BaseFieldProperties = 'name' | 'title' | 'description'; - -export type BuildInputFieldArg = { - schemaObject: InvocationFieldSchema; - baseField: Omit; -}; - -/** - * Transforms an invocation output ref object to field type. - * @param ref The ref string to transform - * @returns The field type. - * - * @example - * refObjectToFieldType({ "$ref": "#/components/schemas/ImageField" }) --> 'ImageField' - */ -export const refObjectToSchemaName = (refObject: OpenAPIV3_1.ReferenceObject) => - refObject.$ref.split('/').slice(-1)[0]; - -const buildIntegerInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): IntegerInputFieldTemplate => { - const template: IntegerInputFieldTemplate = { - ...baseField, - type: 'integer', - default: schemaObject.default ?? 0, - }; - - if (schemaObject.multipleOf !== undefined) { - template.multipleOf = schemaObject.multipleOf; - } - - if (schemaObject.maximum !== undefined) { - template.maximum = schemaObject.maximum; - } - - if ( - schemaObject.exclusiveMaximum !== undefined && - isNumber(schemaObject.exclusiveMaximum) - ) { - template.exclusiveMaximum = schemaObject.exclusiveMaximum; - } - - if (schemaObject.minimum !== undefined) { - template.minimum = schemaObject.minimum; - } - - if ( - schemaObject.exclusiveMinimum !== undefined && - isNumber(schemaObject.exclusiveMinimum) - ) { - template.exclusiveMinimum = schemaObject.exclusiveMinimum; - } - - return template; -}; - -const buildIntegerPolymorphicInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): IntegerPolymorphicInputFieldTemplate => { - const template: IntegerPolymorphicInputFieldTemplate = { - ...baseField, - type: 'IntegerPolymorphic', - default: schemaObject.default ?? 0, - }; - - if (schemaObject.multipleOf !== undefined) { - template.multipleOf = schemaObject.multipleOf; - } - - if (schemaObject.maximum !== undefined) { - template.maximum = schemaObject.maximum; - } - - if ( - schemaObject.exclusiveMaximum !== undefined && - isNumber(schemaObject.exclusiveMaximum) - ) { - template.exclusiveMaximum = schemaObject.exclusiveMaximum; - } - - if (schemaObject.minimum !== undefined) { - template.minimum = schemaObject.minimum; - } - - if ( - schemaObject.exclusiveMinimum !== undefined && - isNumber(schemaObject.exclusiveMinimum) - ) { - template.exclusiveMinimum = schemaObject.exclusiveMinimum; - } - - return template; -}; - -const buildIntegerCollectionInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): IntegerCollectionInputFieldTemplate => { - const item_default = - isNumber(schemaObject.item_default) && isInteger(schemaObject.item_default) - ? schemaObject.item_default - : 0; - const template: IntegerCollectionInputFieldTemplate = { - ...baseField, - type: 'IntegerCollection', - default: schemaObject.default ?? [], - item_default, - }; - - return template; -}; - -const buildFloatInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): FloatInputFieldTemplate => { - const template: FloatInputFieldTemplate = { - ...baseField, - type: 'float', - default: schemaObject.default ?? 0, - }; - - if (schemaObject.multipleOf !== undefined) { - template.multipleOf = schemaObject.multipleOf; - } - - if (schemaObject.maximum !== undefined) { - template.maximum = schemaObject.maximum; - } - - if ( - schemaObject.exclusiveMaximum !== undefined && - isNumber(schemaObject.exclusiveMaximum) - ) { - template.exclusiveMaximum = schemaObject.exclusiveMaximum; - } - - if (schemaObject.minimum !== undefined) { - template.minimum = schemaObject.minimum; - } - - if ( - schemaObject.exclusiveMinimum !== undefined && - isNumber(schemaObject.exclusiveMinimum) - ) { - template.exclusiveMinimum = schemaObject.exclusiveMinimum; - } - - return template; -}; - -const buildFloatPolymorphicInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): FloatPolymorphicInputFieldTemplate => { - const template: FloatPolymorphicInputFieldTemplate = { - ...baseField, - type: 'FloatPolymorphic', - default: schemaObject.default ?? 0, - }; - if (schemaObject.multipleOf !== undefined) { - template.multipleOf = schemaObject.multipleOf; - } - - if (schemaObject.maximum !== undefined) { - template.maximum = schemaObject.maximum; - } - - if ( - schemaObject.exclusiveMaximum !== undefined && - isNumber(schemaObject.exclusiveMaximum) - ) { - template.exclusiveMaximum = schemaObject.exclusiveMaximum; - } - - if (schemaObject.minimum !== undefined) { - template.minimum = schemaObject.minimum; - } - - if ( - schemaObject.exclusiveMinimum !== undefined && - isNumber(schemaObject.exclusiveMinimum) - ) { - template.exclusiveMinimum = schemaObject.exclusiveMinimum; - } - return template; -}; - -const buildFloatCollectionInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): FloatCollectionInputFieldTemplate => { - const item_default = isNumber(schemaObject.item_default) - ? schemaObject.item_default - : 0; - const template: FloatCollectionInputFieldTemplate = { - ...baseField, - type: 'FloatCollection', - default: schemaObject.default ?? [], - item_default, - }; - - return template; -}; - -const buildStringInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): StringInputFieldTemplate => { - const template: StringInputFieldTemplate = { - ...baseField, - type: 'string', - default: schemaObject.default ?? '', - }; - - if (schemaObject.minLength !== undefined) { - template.minLength = schemaObject.minLength; - } - - if (schemaObject.maxLength !== undefined) { - template.maxLength = schemaObject.maxLength; - } - - if (schemaObject.pattern !== undefined) { - template.pattern = schemaObject.pattern; - } - - return template; -}; - -const buildStringPolymorphicInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): StringPolymorphicInputFieldTemplate => { - const template: StringPolymorphicInputFieldTemplate = { - ...baseField, - type: 'StringPolymorphic', - default: schemaObject.default ?? '', - }; - - if (schemaObject.minLength !== undefined) { - template.minLength = schemaObject.minLength; - } - - if (schemaObject.maxLength !== undefined) { - template.maxLength = schemaObject.maxLength; - } - - if (schemaObject.pattern !== undefined) { - template.pattern = schemaObject.pattern; - } - - return template; -}; - -const buildStringCollectionInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): StringCollectionInputFieldTemplate => { - const item_default = isString(schemaObject.item_default) - ? schemaObject.item_default - : ''; - const template: StringCollectionInputFieldTemplate = { - ...baseField, - type: 'StringCollection', - default: schemaObject.default ?? [], - item_default, - }; - - return template; -}; - -const buildBooleanInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): BooleanInputFieldTemplate => { - const template: BooleanInputFieldTemplate = { - ...baseField, - type: 'boolean', - default: schemaObject.default ?? false, - }; - - return template; -}; - -const buildBooleanPolymorphicInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): BooleanPolymorphicInputFieldTemplate => { - const template: BooleanPolymorphicInputFieldTemplate = { - ...baseField, - type: 'BooleanPolymorphic', - default: schemaObject.default ?? false, - }; - - return template; -}; - -const buildBooleanCollectionInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): BooleanCollectionInputFieldTemplate => { - const item_default = - schemaObject.item_default && isBoolean(schemaObject.item_default) - ? schemaObject.item_default - : false; - const template: BooleanCollectionInputFieldTemplate = { - ...baseField, - type: 'BooleanCollection', - default: schemaObject.default ?? [], - item_default, - }; - - return template; -}; - -const buildMainModelInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): MainModelInputFieldTemplate => { - const template: MainModelInputFieldTemplate = { - ...baseField, - type: 'MainModelField', - default: schemaObject.default ?? undefined, - }; - - return template; -}; - -const buildSDXLMainModelInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): SDXLMainModelInputFieldTemplate => { - const template: SDXLMainModelInputFieldTemplate = { - ...baseField, - type: 'SDXLMainModelField', - default: schemaObject.default ?? undefined, - }; - - return template; -}; - -const buildRefinerModelInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): SDXLRefinerModelInputFieldTemplate => { - const template: SDXLRefinerModelInputFieldTemplate = { - ...baseField, - type: 'SDXLRefinerModelField', - default: schemaObject.default ?? undefined, - }; - - return template; -}; - -const buildVaeModelInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): VaeModelInputFieldTemplate => { - const template: VaeModelInputFieldTemplate = { - ...baseField, - type: 'VaeModelField', - default: schemaObject.default ?? undefined, - }; - - return template; -}; - -const buildLoRAModelInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): LoRAModelInputFieldTemplate => { - const template: LoRAModelInputFieldTemplate = { - ...baseField, - type: 'LoRAModelField', - default: schemaObject.default ?? undefined, - }; - - return template; -}; - -const buildControlNetModelInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): ControlNetModelInputFieldTemplate => { - const template: ControlNetModelInputFieldTemplate = { - ...baseField, - type: 'ControlNetModelField', - default: schemaObject.default ?? undefined, - }; - - return template; -}; - -const buildIPAdapterModelInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): IPAdapterModelInputFieldTemplate => { - const template: IPAdapterModelInputFieldTemplate = { - ...baseField, - type: 'IPAdapterModelField', - default: schemaObject.default ?? undefined, - }; - - return template; -}; - -const buildT2IAdapterModelInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): T2IAdapterModelInputFieldTemplate => { - const template: T2IAdapterModelInputFieldTemplate = { - ...baseField, - type: 'T2IAdapterModelField', - default: schemaObject.default ?? undefined, - }; - - return template; -}; - -const buildBoardInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): BoardInputFieldTemplate => { - const template: BoardInputFieldTemplate = { - ...baseField, - type: 'BoardField', - default: schemaObject.default ?? undefined, - }; - - return template; -}; - -const buildImageInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): ImageInputFieldTemplate => { - const template: ImageInputFieldTemplate = { - ...baseField, - type: 'ImageField', - default: schemaObject.default ?? undefined, - }; - - return template; -}; - -const buildImagePolymorphicInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): ImagePolymorphicInputFieldTemplate => { - const template: ImagePolymorphicInputFieldTemplate = { - ...baseField, - type: 'ImagePolymorphic', - default: schemaObject.default ?? undefined, - }; - - return template; -}; - -const buildImageCollectionInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): ImageCollectionInputFieldTemplate => { - const template: ImageCollectionInputFieldTemplate = { - ...baseField, - type: 'ImageCollection', - default: schemaObject.default ?? [], - item_default: (schemaObject.item_default as ImageField) ?? undefined, - }; - - return template; -}; - -const buildDenoiseMaskInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): DenoiseMaskInputFieldTemplate => { - const template: DenoiseMaskInputFieldTemplate = { - ...baseField, - type: 'DenoiseMaskField', - default: schemaObject.default ?? undefined, - }; - - return template; -}; - -const buildLatentsInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): LatentsInputFieldTemplate => { - const template: LatentsInputFieldTemplate = { - ...baseField, - type: 'LatentsField', - default: schemaObject.default ?? undefined, - }; - - return template; -}; - -const buildLatentsPolymorphicInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): LatentsPolymorphicInputFieldTemplate => { - const template: LatentsPolymorphicInputFieldTemplate = { - ...baseField, - type: 'LatentsPolymorphic', - default: schemaObject.default ?? undefined, - }; - - return template; -}; - -const buildLatentsCollectionInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): LatentsCollectionInputFieldTemplate => { - const template: LatentsCollectionInputFieldTemplate = { - ...baseField, - type: 'LatentsCollection', - default: schemaObject.default ?? [], - item_default: (schemaObject.item_default as LatentsField) ?? undefined, - }; - - return template; -}; - -const buildConditioningInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): ConditioningInputFieldTemplate => { - const template: ConditioningInputFieldTemplate = { - ...baseField, - type: 'ConditioningField', - default: schemaObject.default ?? undefined, - }; - - return template; -}; - -const buildConditioningPolymorphicInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): ConditioningPolymorphicInputFieldTemplate => { - const template: ConditioningPolymorphicInputFieldTemplate = { - ...baseField, - type: 'ConditioningPolymorphic', - default: schemaObject.default ?? undefined, - }; - - return template; -}; - -const buildConditioningCollectionInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): ConditioningCollectionInputFieldTemplate => { - const template: ConditioningCollectionInputFieldTemplate = { - ...baseField, - type: 'ConditioningCollection', - default: schemaObject.default ?? [], - item_default: (schemaObject.item_default as ConditioningField) ?? undefined, - }; - - return template; -}; - -const buildUNetInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): UNetInputFieldTemplate => { - const template: UNetInputFieldTemplate = { - ...baseField, - type: 'UNetField', - - default: schemaObject.default ?? undefined, - }; - - return template; -}; - -const buildClipInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): ClipInputFieldTemplate => { - const template: ClipInputFieldTemplate = { - ...baseField, - type: 'ClipField', - default: schemaObject.default ?? undefined, - }; - - return template; -}; - -const buildVaeInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): VaeInputFieldTemplate => { - const template: VaeInputFieldTemplate = { - ...baseField, - type: 'VaeField', - default: schemaObject.default ?? undefined, - }; - - return template; -}; - -const buildControlInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): ControlInputFieldTemplate => { - const template: ControlInputFieldTemplate = { - ...baseField, - type: 'ControlField', - default: schemaObject.default ?? undefined, - }; - - return template; -}; - -const buildControlPolymorphicInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): ControlPolymorphicInputFieldTemplate => { - const template: ControlPolymorphicInputFieldTemplate = { - ...baseField, - type: 'ControlPolymorphic', - default: schemaObject.default ?? undefined, - }; - - return template; -}; - -const buildControlCollectionInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): ControlCollectionInputFieldTemplate => { - const template: ControlCollectionInputFieldTemplate = { - ...baseField, - type: 'ControlCollection', - default: schemaObject.default ?? [], - item_default: (schemaObject.item_default as ControlField) ?? undefined, - }; - - return template; -}; - -const buildIPAdapterInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): IPAdapterInputFieldTemplate => { - const template: IPAdapterInputFieldTemplate = { - ...baseField, - type: 'IPAdapterField', - default: schemaObject.default ?? undefined, - }; - - return template; -}; - -const buildIPAdapterPolymorphicInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): IPAdapterPolymorphicInputFieldTemplate => { - const template: IPAdapterPolymorphicInputFieldTemplate = { - ...baseField, - type: 'IPAdapterPolymorphic', - default: schemaObject.default ?? undefined, - }; - - return template; -}; - -const buildIPAdapterCollectionInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): IPAdapterCollectionInputFieldTemplate => { - const template: IPAdapterCollectionInputFieldTemplate = { - ...baseField, - type: 'IPAdapterCollection', - default: schemaObject.default ?? [], - item_default: (schemaObject.item_default as IPAdapterField) ?? undefined, - }; - - return template; -}; - -const buildT2IAdapterInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): T2IAdapterInputFieldTemplate => { - const template: T2IAdapterInputFieldTemplate = { - ...baseField, - type: 'T2IAdapterField', - default: schemaObject.default ?? undefined, - }; - - return template; -}; - -const buildT2IAdapterPolymorphicInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): T2IAdapterPolymorphicInputFieldTemplate => { - const template: T2IAdapterPolymorphicInputFieldTemplate = { - ...baseField, - type: 'T2IAdapterPolymorphic', - default: schemaObject.default ?? undefined, - }; - - return template; -}; - -const buildT2IAdapterCollectionInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): T2IAdapterCollectionInputFieldTemplate => { - const template: T2IAdapterCollectionInputFieldTemplate = { - ...baseField, - type: 'T2IAdapterCollection', - default: schemaObject.default ?? [], - item_default: (schemaObject.item_default as T2IAdapterField) ?? undefined, - }; - - return template; -}; - -const buildEnumInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): EnumInputFieldTemplate => { - const options = schemaObject.enum ?? []; - const template: EnumInputFieldTemplate = { - ...baseField, - type: 'enum', - options, - ui_choice_labels: schemaObject.ui_choice_labels, - default: schemaObject.default ?? options[0], - }; - - return template; -}; - -const buildCollectionInputFieldTemplate = ({ - baseField, -}: BuildInputFieldArg): CollectionInputFieldTemplate => { - const template: CollectionInputFieldTemplate = { - ...baseField, - type: 'Collection', - default: [], - }; - - return template; -}; - -const buildCollectionItemInputFieldTemplate = ({ - baseField, -}: BuildInputFieldArg): CollectionItemInputFieldTemplate => { - const template: CollectionItemInputFieldTemplate = { - ...baseField, - type: 'CollectionItem', - default: undefined, - }; - - return template; -}; - -const buildAnyInputFieldTemplate = ({ - baseField, -}: BuildInputFieldArg): AnyInputFieldTemplate => { - const template: AnyInputFieldTemplate = { - ...baseField, - type: 'Any', - default: undefined, - }; - - return template; -}; - -const buildMetadataItemInputFieldTemplate = ({ - baseField, -}: BuildInputFieldArg): MetadataItemInputFieldTemplate => { - const template: MetadataItemInputFieldTemplate = { - ...baseField, - type: 'MetadataItemField', - default: undefined, - }; - - return template; -}; - -const buildMetadataItemCollectionInputFieldTemplate = ({ - baseField, -}: BuildInputFieldArg): MetadataItemCollectionInputFieldTemplate => { - const template: MetadataItemCollectionInputFieldTemplate = { - ...baseField, - type: 'MetadataItemCollection', - default: undefined, - }; - - return template; -}; - -const buildMetadataItemPolymorphicInputFieldTemplate = ({ - baseField, -}: BuildInputFieldArg): MetadataItemPolymorphicInputFieldTemplate => { - const template: MetadataItemPolymorphicInputFieldTemplate = { - ...baseField, - type: 'MetadataItemPolymorphic', - default: undefined, - }; - - return template; -}; - -const buildMetadataDictInputFieldTemplate = ({ - baseField, -}: BuildInputFieldArg): MetadataInputFieldTemplate => { - const template: MetadataInputFieldTemplate = { - ...baseField, - type: 'MetadataField', - default: undefined, - }; - - return template; -}; - -const buildMetadataCollectionInputFieldTemplate = ({ - baseField, -}: BuildInputFieldArg): MetadataCollectionInputFieldTemplate => { - const template: MetadataCollectionInputFieldTemplate = { - ...baseField, - type: 'MetadataCollection', - default: undefined, - }; - - return template; -}; - -const buildColorInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): ColorInputFieldTemplate => { - const template: ColorInputFieldTemplate = { - ...baseField, - type: 'ColorField', - default: schemaObject.default ?? { r: 127, g: 127, b: 127, a: 255 }, - }; - - return template; -}; - -const buildColorPolymorphicInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): ColorPolymorphicInputFieldTemplate => { - const template: ColorPolymorphicInputFieldTemplate = { - ...baseField, - type: 'ColorPolymorphic', - default: schemaObject.default ?? { r: 127, g: 127, b: 127, a: 255 }, - }; - - return template; -}; - -const buildColorCollectionInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): ColorCollectionInputFieldTemplate => { - const template: ColorCollectionInputFieldTemplate = { - ...baseField, - type: 'ColorCollection', - default: schemaObject.default ?? [], - }; - - return template; -}; - -const buildSchedulerInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): SchedulerInputFieldTemplate => { - const template: SchedulerInputFieldTemplate = { - ...baseField, - type: 'Scheduler', - default: schemaObject.default ?? 'euler', - }; - - return template; -}; - -export const getFieldType = ( - schemaObject: OpenAPIV3_1SchemaOrRef -): string | undefined => { - if (isSchemaObject(schemaObject)) { - if (!schemaObject.type) { - // if schemaObject has no type, then it should have one of allOf, anyOf, oneOf - - if (schemaObject.allOf) { - const allOf = schemaObject.allOf; - if (allOf && allOf[0] && isRefObject(allOf[0])) { - return refObjectToSchemaName(allOf[0]); - } - } else if (schemaObject.anyOf) { - // ignore null types - const anyOf = schemaObject.anyOf.filter((i) => { - if (isSchemaObject(i)) { - if (i.type === 'null') { - return false; - } - } - return true; - }); - if (anyOf.length === 1) { - if (isRefObject(anyOf[0])) { - return refObjectToSchemaName(anyOf[0]); - } else if (isSchemaObject(anyOf[0])) { - return getFieldType(anyOf[0]); - } - } - /** - * Handle Polymorphic inputs, eg string | string[]. In OpenAPI, this is: - * - an `anyOf` with two items - * - one is an `ArraySchemaObject` with a single `SchemaObject or ReferenceObject` of type T in its `items` - * - the other is a `SchemaObject` or `ReferenceObject` of type T - * - * Any other cases we ignore. - */ - - let firstType: string | undefined; - let secondType: string | undefined; - - if (isArraySchemaObject(anyOf[0])) { - // first is array, second is not - const first = anyOf[0].items; - const second = anyOf[1]; - if (isRefObject(first) && isRefObject(second)) { - firstType = refObjectToSchemaName(first); - secondType = refObjectToSchemaName(second); - } else if ( - isNonArraySchemaObject(first) && - isNonArraySchemaObject(second) - ) { - firstType = first.type; - secondType = second.type; - } - } else if (isArraySchemaObject(anyOf[1])) { - // first is not array, second is - const first = anyOf[0]; - const second = anyOf[1].items; - if (isRefObject(first) && isRefObject(second)) { - firstType = refObjectToSchemaName(first); - secondType = refObjectToSchemaName(second); - } else if ( - isNonArraySchemaObject(first) && - isNonArraySchemaObject(second) - ) { - firstType = first.type; - secondType = second.type; - } - } - if (firstType === secondType && isPolymorphicItemType(firstType)) { - return SINGLE_TO_POLYMORPHIC_MAP[firstType]; - } - } - } else if (schemaObject.enum) { - return 'enum'; - } else if (schemaObject.type) { - if (schemaObject.type === 'number') { - // floats are "number" in OpenAPI, while ints are "integer" - we need to distinguish them - return 'float'; - } else if (schemaObject.type === 'array') { - const itemType = isSchemaObject(schemaObject.items) - ? schemaObject.items.type - : refObjectToSchemaName(schemaObject.items); - - if (isArray(itemType)) { - // This is a nested array, which we don't support - return; - } - - if (isCollectionItemType(itemType)) { - return COLLECTION_MAP[itemType]; - } - - return; - } else if (!isArray(schemaObject.type)) { - return schemaObject.type; - } - } - } else if (isRefObject(schemaObject)) { - return refObjectToSchemaName(schemaObject); - } - return; -}; - -const TEMPLATE_BUILDER_MAP: { - [key in FieldType]?: (arg: BuildInputFieldArg) => InputFieldTemplate; -} = { - BoardField: buildBoardInputFieldTemplate, - Any: buildAnyInputFieldTemplate, - boolean: buildBooleanInputFieldTemplate, - BooleanCollection: buildBooleanCollectionInputFieldTemplate, - BooleanPolymorphic: buildBooleanPolymorphicInputFieldTemplate, - ClipField: buildClipInputFieldTemplate, - Collection: buildCollectionInputFieldTemplate, - CollectionItem: buildCollectionItemInputFieldTemplate, - ColorCollection: buildColorCollectionInputFieldTemplate, - ColorField: buildColorInputFieldTemplate, - ColorPolymorphic: buildColorPolymorphicInputFieldTemplate, - ConditioningCollection: buildConditioningCollectionInputFieldTemplate, - ConditioningField: buildConditioningInputFieldTemplate, - ConditioningPolymorphic: buildConditioningPolymorphicInputFieldTemplate, - ControlCollection: buildControlCollectionInputFieldTemplate, - ControlField: buildControlInputFieldTemplate, - ControlNetModelField: buildControlNetModelInputFieldTemplate, - ControlPolymorphic: buildControlPolymorphicInputFieldTemplate, - DenoiseMaskField: buildDenoiseMaskInputFieldTemplate, - enum: buildEnumInputFieldTemplate, - float: buildFloatInputFieldTemplate, - FloatCollection: buildFloatCollectionInputFieldTemplate, - FloatPolymorphic: buildFloatPolymorphicInputFieldTemplate, - ImageCollection: buildImageCollectionInputFieldTemplate, - ImageField: buildImageInputFieldTemplate, - ImagePolymorphic: buildImagePolymorphicInputFieldTemplate, - integer: buildIntegerInputFieldTemplate, - IntegerCollection: buildIntegerCollectionInputFieldTemplate, - IntegerPolymorphic: buildIntegerPolymorphicInputFieldTemplate, - IPAdapterCollection: buildIPAdapterCollectionInputFieldTemplate, - IPAdapterField: buildIPAdapterInputFieldTemplate, - IPAdapterModelField: buildIPAdapterModelInputFieldTemplate, - IPAdapterPolymorphic: buildIPAdapterPolymorphicInputFieldTemplate, - LatentsCollection: buildLatentsCollectionInputFieldTemplate, - LatentsField: buildLatentsInputFieldTemplate, - LatentsPolymorphic: buildLatentsPolymorphicInputFieldTemplate, - LoRAModelField: buildLoRAModelInputFieldTemplate, - MetadataItemField: buildMetadataItemInputFieldTemplate, - MetadataItemCollection: buildMetadataItemCollectionInputFieldTemplate, - MetadataItemPolymorphic: buildMetadataItemPolymorphicInputFieldTemplate, - MetadataField: buildMetadataDictInputFieldTemplate, - MetadataCollection: buildMetadataCollectionInputFieldTemplate, - MainModelField: buildMainModelInputFieldTemplate, - Scheduler: buildSchedulerInputFieldTemplate, - SDXLMainModelField: buildSDXLMainModelInputFieldTemplate, - SDXLRefinerModelField: buildRefinerModelInputFieldTemplate, - string: buildStringInputFieldTemplate, - StringCollection: buildStringCollectionInputFieldTemplate, - StringPolymorphic: buildStringPolymorphicInputFieldTemplate, - T2IAdapterCollection: buildT2IAdapterCollectionInputFieldTemplate, - T2IAdapterField: buildT2IAdapterInputFieldTemplate, - T2IAdapterModelField: buildT2IAdapterModelInputFieldTemplate, - T2IAdapterPolymorphic: buildT2IAdapterPolymorphicInputFieldTemplate, - UNetField: buildUNetInputFieldTemplate, - VaeField: buildVaeInputFieldTemplate, - VaeModelField: buildVaeModelInputFieldTemplate, -}; - -const isTemplatedFieldType = ( - fieldType: string | undefined -): fieldType is keyof typeof TEMPLATE_BUILDER_MAP => - Boolean(fieldType && fieldType in TEMPLATE_BUILDER_MAP); - -/** - * Builds an input field from an invocation schema property. - * @param fieldSchema The schema object - * @returns An input field - */ -export const buildInputFieldTemplate = ( - nodeSchema: InvocationSchemaObject, - fieldSchema: InvocationFieldSchema, - name: string, - fieldType: FieldType -) => { - const { - input, - ui_hidden, - ui_component, - ui_type, - ui_order, - ui_choice_labels, - item_default, - } = fieldSchema; - - const extra = { - // TODO: Can we support polymorphic inputs in the UI? - input: POLYMORPHIC_TYPES.includes(fieldType) ? 'connection' : input, - ui_hidden, - ui_component, - ui_type, - required: nodeSchema.required?.includes(name) ?? false, - ui_order, - ui_choice_labels, - item_default, - }; - - const baseField = { - name, - title: fieldSchema.title ?? (name ? startCase(name) : ''), - description: fieldSchema.description ?? '', - fieldKind: 'input' as const, - ...extra, - }; - - if (!isTemplatedFieldType(fieldType)) { - return; - } - - const builder = TEMPLATE_BUILDER_MAP[fieldType]; - - if (!builder) { - return; - } - - return builder({ - schemaObject: fieldSchema, - baseField, - }); -}; diff --git a/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts b/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts deleted file mode 100644 index ca2513649d3..00000000000 --- a/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts +++ /dev/null @@ -1,85 +0,0 @@ -import { FieldType, InputFieldTemplate, InputFieldValue } from '../types/types'; - -const FIELD_VALUE_FALLBACK_MAP: { - [key in FieldType]: InputFieldValue['value']; -} = { - Any: undefined, - enum: '', - BoardField: undefined, - boolean: false, - BooleanCollection: [], - BooleanPolymorphic: false, - ClipField: undefined, - Collection: [], - CollectionItem: undefined, - ColorCollection: [], - ColorField: undefined, - ColorPolymorphic: undefined, - ConditioningCollection: [], - ConditioningField: undefined, - ConditioningPolymorphic: undefined, - ControlCollection: [], - ControlField: undefined, - ControlNetModelField: undefined, - ControlPolymorphic: undefined, - DenoiseMaskField: undefined, - float: 0, - FloatCollection: [], - FloatPolymorphic: 0, - ImageCollection: [], - ImageField: undefined, - ImagePolymorphic: undefined, - integer: 0, - IntegerCollection: [], - IntegerPolymorphic: 0, - IPAdapterCollection: [], - IPAdapterField: undefined, - IPAdapterModelField: undefined, - IPAdapterPolymorphic: undefined, - LatentsCollection: [], - LatentsField: undefined, - LatentsPolymorphic: undefined, - MetadataItemField: undefined, - MetadataItemCollection: [], - MetadataItemPolymorphic: undefined, - MetadataField: undefined, - MetadataCollection: [], - LoRAModelField: undefined, - MainModelField: undefined, - ONNXModelField: undefined, - Scheduler: 'euler', - SDXLMainModelField: undefined, - SDXLRefinerModelField: undefined, - string: '', - StringCollection: [], - StringPolymorphic: '', - T2IAdapterCollection: [], - T2IAdapterField: undefined, - T2IAdapterModelField: undefined, - T2IAdapterPolymorphic: undefined, - UNetField: undefined, - VaeField: undefined, - VaeModelField: undefined, -}; - -export const buildInputFieldValue = ( - id: string, - template: InputFieldTemplate -): InputFieldValue => { - // TODO: this should be `fieldValue: InputFieldValue`, but that introduces a TS issue I couldn't - // resolve - for some reason, it doesn't like `template.type`, which is the discriminant for both - // `InputFieldTemplate` union. It is (type-structurally) equal to the discriminant for the - // `InputFieldValue` union, but TS doesn't seem to like it... - const fieldValue = { - id, - name: template.name, - type: template.type, - label: '', - fieldKind: 'input', - } as InputFieldValue; - - fieldValue.value = - template.default ?? FIELD_VALUE_FALLBACK_MAP[template.type]; - - return fieldValue; -}; diff --git a/invokeai/frontend/web/src/features/nodes/util/getSortedFilteredFieldNames.ts b/invokeai/frontend/web/src/features/nodes/util/getSortedFilteredFieldNames.ts index b235fe8a07f..2ed5faca293 100644 --- a/invokeai/frontend/web/src/features/nodes/util/getSortedFilteredFieldNames.ts +++ b/invokeai/frontend/web/src/features/nodes/util/getSortedFilteredFieldNames.ts @@ -1,8 +1,8 @@ import { isNil } from 'lodash-es'; -import { InputFieldTemplate, OutputFieldTemplate } from '../types/types'; +import { FieldInputTemplate, FieldOutputTemplate } from '../types/field'; export const getSortedFilteredFieldNames = ( - fields: InputFieldTemplate[] | OutputFieldTemplate[] + fields: FieldInputTemplate[] | FieldOutputTemplate[] ): string[] => { const visibleFields = fields.filter((field) => !field.ui_hidden); diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addControlNetToLinearGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addControlNetToLinearGraph.ts index 60d4e36dcab..ff6028b38eb 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addControlNetToLinearGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addControlNetToLinearGraph.ts @@ -6,8 +6,8 @@ import { ControlField, ControlNetInvocation, CoreMetadataInvocation, + NonNullableGraph, } from 'services/api/types'; -import { NonNullableGraph } from '../../types/types'; import { CANVAS_COHERENCE_DENOISE_LATENTS, CONTROL_NET_COLLECT, diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addHrfToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addHrfToGraph.ts index 9825ce754e2..edbba61b734 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addHrfToGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addHrfToGraph.ts @@ -1,13 +1,13 @@ import { logger } from 'app/logging/logger'; import { RootState } from 'app/store/store'; import { roundToMultiple } from 'common/util/roundDownToMultiple'; -import { NonNullableGraph } from 'features/nodes/types/types'; import { DenoiseLatentsInvocation, ESRGANInvocation, Edge, LatentsToImageInvocation, NoiseInvocation, + NonNullableGraph, } from 'services/api/types'; import { DENOISE_LATENTS, diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addIPAdapterToLinearGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addIPAdapterToLinearGraph.ts index 93c6cdb284d..9dd8b253683 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addIPAdapterToLinearGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addIPAdapterToLinearGraph.ts @@ -6,8 +6,8 @@ import { CoreMetadataInvocation, IPAdapterInvocation, IPAdapterMetadataField, + NonNullableGraph, } from 'services/api/types'; -import { NonNullableGraph } from '../../types/types'; import { CANVAS_COHERENCE_DENOISE_LATENTS, IP_ADAPTER_COLLECT, diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addLinearUIOutputNode.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addLinearUIOutputNode.ts index 926fa3a8f3f..1676a5f53d4 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addLinearUIOutputNode.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addLinearUIOutputNode.ts @@ -1,7 +1,6 @@ import { RootState } from 'app/store/store'; -import { NonNullableGraph } from 'features/nodes/types/types'; import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; -import { LinearUIOutputInvocation } from 'services/api/types'; +import { LinearUIOutputInvocation, NonNullableGraph } from 'services/api/types'; import { CANVAS_OUTPUT, LATENTS_TO_IMAGE, diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addLoRAsToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addLoRAsToGraph.ts index 66c2bd04445..acbe53e6114 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addLoRAsToGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addLoRAsToGraph.ts @@ -1,9 +1,9 @@ import { RootState } from 'app/store/store'; -import { NonNullableGraph } from 'features/nodes/types/types'; import { forEach, size } from 'lodash-es'; import { CoreMetadataInvocation, LoraLoaderInvocation, + NonNullableGraph, } from 'services/api/types'; import { CANVAS_COHERENCE_DENOISE_LATENTS, diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addNSFWCheckerToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addNSFWCheckerToGraph.ts index 94fddccc8f1..d4cd5e83cb2 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addNSFWCheckerToGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addNSFWCheckerToGraph.ts @@ -1,8 +1,8 @@ import { RootState } from 'app/store/store'; -import { NonNullableGraph } from 'features/nodes/types/types'; import { ImageNSFWBlurInvocation, LatentsToImageInvocation, + NonNullableGraph, } from 'services/api/types'; import { LATENTS_TO_IMAGE, NSFW_CHECKER } from './constants'; diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addSDXLLoRAstoGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addSDXLLoRAstoGraph.ts index 04841f0def3..544958c39dd 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addSDXLLoRAstoGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addSDXLLoRAstoGraph.ts @@ -1,11 +1,10 @@ import { RootState } from 'app/store/store'; import { LoRAMetadataItem, - NonNullableGraph, zLoRAMetadataItem, -} from 'features/nodes/types/types'; +} from 'features/nodes/types/metadata'; import { forEach, size } from 'lodash-es'; -import { SDXLLoraLoaderInvocation } from 'services/api/types'; +import { NonNullableGraph, SDXLLoraLoaderInvocation } from 'services/api/types'; import { CANVAS_COHERENCE_DENOISE_LATENTS, LORA_LOADER, diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addSDXLRefinerToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addSDXLRefinerToGraph.ts index 136263f63ed..8976d7ed5fa 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addSDXLRefinerToGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addSDXLRefinerToGraph.ts @@ -2,9 +2,9 @@ import { RootState } from 'app/store/store'; import { CreateDenoiseMaskInvocation, ImageDTO, + NonNullableGraph, SeamlessModeInvocation, } from 'services/api/types'; -import { NonNullableGraph } from '../../types/types'; import { CANVAS_OUTPUT, INPAINT_IMAGE_RESIZE_UP, diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addSeamlessToLinearGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addSeamlessToLinearGraph.ts index ba341a8a3da..d062b25309b 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addSeamlessToLinearGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addSeamlessToLinearGraph.ts @@ -1,7 +1,5 @@ import { RootState } from 'app/store/store'; -import { SeamlessModeInvocation } from 'services/api/types'; -import { NonNullableGraph } from '../../types/types'; -import { upsertMetadata } from './metadata'; +import { NonNullableGraph, SeamlessModeInvocation } from 'services/api/types'; import { CANVAS_COHERENCE_DENOISE_LATENTS, CANVAS_INPAINT_GRAPH, @@ -16,6 +14,7 @@ import { SDXL_TEXT_TO_IMAGE_GRAPH, SEAMLESS, } from './constants'; +import { upsertMetadata } from './metadata'; export const addSeamlessToLinearGraph = ( state: RootState, diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addT2IAdapterToLinearGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addT2IAdapterToLinearGraph.ts index 71c2aaeede5..550f9ba5f3c 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addT2IAdapterToLinearGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addT2IAdapterToLinearGraph.ts @@ -4,9 +4,10 @@ import { omit } from 'lodash-es'; import { CollectInvocation, CoreMetadataInvocation, + NonNullableGraph, + T2IAdapterField, T2IAdapterInvocation, } from 'services/api/types'; -import { NonNullableGraph, T2IAdapterField } from '../../types/types'; import { CANVAS_COHERENCE_DENOISE_LATENTS, T2I_ADAPTER_COLLECT, diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addVAEToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addVAEToGraph.ts index f049a89e362..438bbfd8922 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addVAEToGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addVAEToGraph.ts @@ -1,5 +1,5 @@ import { RootState } from 'app/store/store'; -import { NonNullableGraph } from 'features/nodes/types/types'; +import { NonNullableGraph } from 'services/api/types'; import { CANVAS_COHERENCE_INPAINT_CREATE_MASK, CANVAS_IMAGE_TO_IMAGE_GRAPH, diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addWatermarkerToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addWatermarkerToGraph.ts index c43437e4fcf..f553e6d0f90 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addWatermarkerToGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addWatermarkerToGraph.ts @@ -1,10 +1,10 @@ import { RootState } from 'app/store/store'; -import { NonNullableGraph } from 'features/nodes/types/types'; import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; import { ImageNSFWBlurInvocation, ImageWatermarkInvocation, LatentsToImageInvocation, + NonNullableGraph, } from 'services/api/types'; import { LATENTS_TO_IMAGE, NSFW_CHECKER, WATERMARKER } from './constants'; diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildAdHocUpscaleGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildAdHocUpscaleGraph.ts index 8331c81eb37..60143252ab6 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildAdHocUpscaleGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildAdHocUpscaleGraph.ts @@ -1,10 +1,10 @@ import { BoardId } from 'features/gallery/store/types'; -import { NonNullableGraph } from 'features/nodes/types/types'; import { ESRGANModelName } from 'features/parameters/store/postprocessingSlice'; import { ESRGANInvocation, Graph, LinearUIOutputInvocation, + NonNullableGraph, } from 'services/api/types'; import { ESRGAN, LINEAR_UI_OUTPUT } from './constants'; import { addCoreMetadataNode, upsertMetadata } from './metadata'; diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasGraph.ts index d268a3990d9..66500c9ce53 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasGraph.ts @@ -1,6 +1,5 @@ import { RootState } from 'app/store/store'; -import { NonNullableGraph } from 'features/nodes/types/types'; -import { ImageDTO } from 'services/api/types'; +import { ImageDTO, NonNullableGraph } from 'services/api/types'; import { buildCanvasImageToImageGraph } from './buildCanvasImageToImageGraph'; import { buildCanvasInpaintGraph } from './buildCanvasInpaintGraph'; import { buildCanvasOutpaintGraph } from './buildCanvasOutpaintGraph'; diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasImageToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasImageToImageGraph.ts index a86fdb4ce6f..2866aeff075 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasImageToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasImageToImageGraph.ts @@ -1,12 +1,15 @@ import { logger } from 'app/logging/logger'; import { RootState } from 'app/store/store'; -import { NonNullableGraph } from 'features/nodes/types/types'; -import { ImageDTO, ImageToLatentsInvocation } from 'services/api/types'; +import { + ImageDTO, + ImageToLatentsInvocation, + NonNullableGraph, +} from 'services/api/types'; import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; +import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addLoRAsToGraph } from './addLoRAsToGraph'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; -import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph'; import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph'; import { addVAEToGraph } from './addVAEToGraph'; diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasInpaintGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasInpaintGraph.ts index 48052e2a94b..6253ce1f18a 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasInpaintGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasInpaintGraph.ts @@ -1,6 +1,5 @@ import { logger } from 'app/logging/logger'; import { RootState } from 'app/store/store'; -import { NonNullableGraph } from 'features/nodes/types/types'; import { CreateDenoiseMaskInvocation, ImageBlurInvocation, @@ -8,12 +7,13 @@ import { ImageToLatentsInvocation, MaskEdgeInvocation, NoiseInvocation, + NonNullableGraph, } from 'services/api/types'; import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; +import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addLoRAsToGraph } from './addLoRAsToGraph'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; -import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph'; import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph'; import { addVAEToGraph } from './addVAEToGraph'; diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasOutpaintGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasOutpaintGraph.ts index 31cf5ca7e86..11573c21c28 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasOutpaintGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasOutpaintGraph.ts @@ -1,18 +1,18 @@ import { logger } from 'app/logging/logger'; import { RootState } from 'app/store/store'; -import { NonNullableGraph } from 'features/nodes/types/types'; import { ImageDTO, ImageToLatentsInvocation, InfillPatchMatchInvocation, InfillTileInvocation, NoiseInvocation, + NonNullableGraph, } from 'services/api/types'; import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; +import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addLoRAsToGraph } from './addLoRAsToGraph'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; -import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph'; import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph'; import { addVAEToGraph } from './addVAEToGraph'; diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLImageToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLImageToImageGraph.ts index 8281c9c2489..f579a7d9e7e 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLImageToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLImageToImageGraph.ts @@ -1,14 +1,18 @@ import { logger } from 'app/logging/logger'; import { RootState } from 'app/store/store'; -import { NonNullableGraph } from 'features/nodes/types/types'; -import { ImageDTO, ImageToLatentsInvocation } from 'services/api/types'; +import { + ImageDTO, + ImageToLatentsInvocation, + NonNullableGraph, +} from 'services/api/types'; import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; +import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph'; import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph'; -import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph'; +import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph'; import { addVAEToGraph } from './addVAEToGraph'; import { addWatermarkerToGraph } from './addWatermarkerToGraph'; import { @@ -26,7 +30,6 @@ import { SEAMLESS, } from './constants'; import { buildSDXLStylePrompts } from './helpers/craftSDXLStylePrompt'; -import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph'; import { addCoreMetadataNode } from './metadata'; /** diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLInpaintGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLInpaintGraph.ts index 40626e289a7..de1dd0dfd2c 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLInpaintGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLInpaintGraph.ts @@ -1,6 +1,5 @@ import { logger } from 'app/logging/logger'; import { RootState } from 'app/store/store'; -import { NonNullableGraph } from 'features/nodes/types/types'; import { CreateDenoiseMaskInvocation, ImageBlurInvocation, @@ -8,13 +7,14 @@ import { ImageToLatentsInvocation, MaskEdgeInvocation, NoiseInvocation, + NonNullableGraph, } from 'services/api/types'; import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; +import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph'; import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph'; -import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph'; import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph'; import { addVAEToGraph } from './addVAEToGraph'; diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLOutpaintGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLOutpaintGraph.ts index c7302cd56da..2f8b4fd6531 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLOutpaintGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLOutpaintGraph.ts @@ -1,19 +1,19 @@ import { logger } from 'app/logging/logger'; import { RootState } from 'app/store/store'; -import { NonNullableGraph } from 'features/nodes/types/types'; import { ImageDTO, ImageToLatentsInvocation, InfillPatchMatchInvocation, InfillTileInvocation, NoiseInvocation, + NonNullableGraph, } from 'services/api/types'; import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; +import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph'; import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph'; -import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph'; import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph'; import { addVAEToGraph } from './addVAEToGraph'; diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLTextToImageGraph.ts index 2a712f2ef31..0a456aaccfb 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLTextToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLTextToImageGraph.ts @@ -1,16 +1,16 @@ import { logger } from 'app/logging/logger'; import { RootState } from 'app/store/store'; -import { NonNullableGraph } from 'features/nodes/types/types'; import { DenoiseLatentsInvocation, + NonNullableGraph, ONNXTextToLatentsInvocation, } from 'services/api/types'; import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; +import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph'; import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph'; -import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph'; import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph'; import { addVAEToGraph } from './addVAEToGraph'; diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasTextToImageGraph.ts index 5c0c91ca71a..72d7c1e460c 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasTextToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasTextToImageGraph.ts @@ -1,15 +1,15 @@ import { logger } from 'app/logging/logger'; import { RootState } from 'app/store/store'; -import { NonNullableGraph } from 'features/nodes/types/types'; import { DenoiseLatentsInvocation, + NonNullableGraph, ONNXTextToLatentsInvocation, } from 'services/api/types'; import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; +import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addLoRAsToGraph } from './addLoRAsToGraph'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; -import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph'; import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph'; import { addVAEToGraph } from './addVAEToGraph'; diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearBatchConfig.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearBatchConfig.ts index 59f8d4123fe..865a4535aec 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearBatchConfig.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearBatchConfig.ts @@ -1,10 +1,9 @@ import { NUMPY_RAND_MAX } from 'app/constants'; import { RootState } from 'app/store/store'; import { generateSeeds } from 'common/util/generateSeeds'; -import { NonNullableGraph } from 'features/nodes/types/types'; import { range } from 'lodash-es'; import { components } from 'services/api/schema'; -import { Batch, BatchConfig } from 'services/api/types'; +import { Batch, BatchConfig, NonNullableGraph } from 'services/api/types'; import { CANVAS_COHERENCE_NOISE, METADATA, diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearImageToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearImageToImageGraph.ts index d1897ad9ee7..8d4d0ae35fb 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearImageToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearImageToImageGraph.ts @@ -1,15 +1,15 @@ import { logger } from 'app/logging/logger'; import { RootState } from 'app/store/store'; -import { NonNullableGraph } from 'features/nodes/types/types'; import { ImageResizeInvocation, ImageToLatentsInvocation, + NonNullableGraph, } from 'services/api/types'; import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; +import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addLoRAsToGraph } from './addLoRAsToGraph'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; -import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph'; import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph'; import { addVAEToGraph } from './addVAEToGraph'; diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearSDXLImageToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearSDXLImageToImageGraph.ts index 0b57dcd5bf4..621219eb679 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearSDXLImageToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearSDXLImageToImageGraph.ts @@ -1,16 +1,16 @@ import { logger } from 'app/logging/logger'; import { RootState } from 'app/store/store'; -import { NonNullableGraph } from 'features/nodes/types/types'; import { ImageResizeInvocation, ImageToLatentsInvocation, + NonNullableGraph, } from 'services/api/types'; import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; +import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph'; import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph'; -import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph'; import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph'; import { addVAEToGraph } from './addVAEToGraph'; diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearSDXLTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearSDXLTextToImageGraph.ts index 37e9b293c6c..df28fbbd620 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearSDXLTextToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearSDXLTextToImageGraph.ts @@ -1,17 +1,16 @@ import { logger } from 'app/logging/logger'; import { RootState } from 'app/store/store'; -import { NonNullableGraph } from 'features/nodes/types/types'; +import { NonNullableGraph } from 'services/api/types'; import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; +import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph'; import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph'; -import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph'; import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph'; import { addVAEToGraph } from './addVAEToGraph'; import { addWatermarkerToGraph } from './addWatermarkerToGraph'; -import { addCoreMetadataNode } from './metadata'; import { LATENTS_TO_IMAGE, NEGATIVE_CONDITIONING, @@ -24,6 +23,7 @@ import { SEAMLESS, } from './constants'; import { buildSDXLStylePrompts } from './helpers/craftSDXLStylePrompt'; +import { addCoreMetadataNode } from './metadata'; export const buildLinearSDXLTextToImageGraph = ( state: RootState diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearTextToImageGraph.ts index f097cf0c421..ffec2a409f1 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearTextToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearTextToImageGraph.ts @@ -1,21 +1,20 @@ import { logger } from 'app/logging/logger'; import { RootState } from 'app/store/store'; -import { NonNullableGraph } from 'features/nodes/types/types'; import { DenoiseLatentsInvocation, + NonNullableGraph, ONNXTextToLatentsInvocation, } from 'services/api/types'; import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addHrfToGraph } from './addHrfToGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; +import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addLoRAsToGraph } from './addLoRAsToGraph'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; -import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph'; import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph'; import { addVAEToGraph } from './addVAEToGraph'; import { addWatermarkerToGraph } from './addWatermarkerToGraph'; -import { addCoreMetadataNode } from './metadata'; import { CLIP_SKIP, DENOISE_LATENTS, @@ -28,6 +27,7 @@ import { SEAMLESS, TEXT_TO_IMAGE_GRAPH, } from './constants'; +import { addCoreMetadataNode } from './metadata'; export const buildLinearTextToImageGraph = ( state: RootState diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts index eb782f456a1..9ed0eb1d32c 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts @@ -1,16 +1,20 @@ import { NodesState } from 'features/nodes/store/types'; -import { InputFieldValue, isInvocationNode } from 'features/nodes/types/types'; +import { isInvocationNode } from 'features/nodes/types/invocation'; import { cloneDeep, omit, reduce } from 'lodash-es'; import { Graph } from 'services/api/types'; import { AnyInvocation } from 'services/events/types'; import { v4 as uuidv4 } from 'uuid'; import { buildWorkflow } from '../buildWorkflow'; +import { + FieldInputInstance, + isColorFieldInputInstance, +} from 'features/nodes/types/field'; /** * We need to do special handling for some fields */ -export const parseFieldValue = (field: InputFieldValue) => { - if (field.type === 'ColorField') { +export const parseFieldValue = (field: FieldInputInstance) => { + if (isColorFieldInputInstance(field)) { if (field.value) { const clonedValue = cloneDeep(field.value); diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/metadata.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/metadata.ts index c80e1c80c69..f78a0af035f 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/metadata.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/metadata.ts @@ -1,5 +1,4 @@ -import { NonNullableGraph } from 'features/nodes/types/types'; -import { CoreMetadataInvocation } from 'services/api/types'; +import { CoreMetadataInvocation, NonNullableGraph } from 'services/api/types'; import { JsonObject } from 'type-fest'; import { METADATA } from './constants'; diff --git a/invokeai/frontend/web/src/features/nodes/util/parseFieldType.ts b/invokeai/frontend/web/src/features/nodes/util/parseFieldType.ts new file mode 100644 index 00000000000..133a3d11c9a --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/parseFieldType.ts @@ -0,0 +1,233 @@ +import { t } from 'i18next'; +import { isArray } from 'lodash-es'; +import { OpenAPIV3_1 } from 'openapi-types'; +import { FieldTypeParseError, UnsupportedFieldTypeError } from '../types/error'; +import { FieldType } from '../types/field'; +import { + OpenAPIV3_1SchemaOrRef, + isArraySchemaObject, + isInvocationFieldSchema, + isNonArraySchemaObject, + isRefObject, + isSchemaObject, +} from '../types/openapi'; + +/** + * Transforms an invocation output ref object to field type. + * @param ref The ref string to transform + * @returns The field type. + * + * @example + * refObjectToFieldType({ "$ref": "#/components/schemas/ImageField" }) --> 'ImageField' + */ +export const refObjectToSchemaName = (refObject: OpenAPIV3_1.ReferenceObject) => + refObject.$ref.split('/').slice(-1)[0]; + +const OPENAPI_TO_FIELD_TYPE_MAP: Record = { + integer: 'IntegerField', + number: 'FloatField', + string: 'StringField', + boolean: 'BooleanField', +}; + +const isCollectionFieldType = (fieldType: string) => { + /** + * CollectionField is `list[Any]` in the pydantic schema, but we need to distinguish between + * it and other `list[Any]` fields, due to its special internal handling. + * + * In pydantic, it gets an explicit field type of `CollectionField`. + */ + if (fieldType === 'CollectionField') { + return true; + } + return false; +}; + +export const parseFieldType = ( + schemaObject: OpenAPIV3_1SchemaOrRef +): FieldType => { + if (isInvocationFieldSchema(schemaObject)) { + // Check if this field has an explicit type provided by the node schema + const { ui_type } = schemaObject; + if (ui_type) { + return { + name: ui_type, + isCollection: isCollectionFieldType(ui_type), + isPolymorphic: false, + }; + } + } + if (isSchemaObject(schemaObject)) { + if (!schemaObject.type) { + // if schemaObject has no type, then it should have one of allOf, anyOf, oneOf + + if (schemaObject.allOf) { + const allOf = schemaObject.allOf; + if (allOf && allOf[0] && isRefObject(allOf[0])) { + // This is a single ref type + const name = refObjectToSchemaName(allOf[0]); + if (!name) { + throw new FieldTypeParseError( + t('nodes.unableToExtractSchemaNameFromRef') + ); + } + return { + name, + isCollection: false, + isPolymorphic: false, + }; + } + } else if (schemaObject.anyOf) { + // ignore null types + const filteredAnyOf = schemaObject.anyOf.filter((i) => { + if (isSchemaObject(i)) { + if (i.type === 'null') { + return false; + } + } + return true; + }); + if (filteredAnyOf.length === 1) { + // This is a single ref type + if (isRefObject(filteredAnyOf[0])) { + const name = refObjectToSchemaName(filteredAnyOf[0]); + if (!name) { + throw new FieldTypeParseError( + t('nodes.unableToExtractSchemaNameFromRef') + ); + } + + return { + name, + isCollection: false, + isPolymorphic: false, + }; + } else if (isSchemaObject(filteredAnyOf[0])) { + return parseFieldType(filteredAnyOf[0]); + } + } + /** + * Handle Polymorphic inputs, eg string | string[]. In OpenAPI, this is: + * - an `anyOf` with two items + * - one is an `ArraySchemaObject` with a single `SchemaObject or ReferenceObject` of type T in its `items` + * - the other is a `SchemaObject` or `ReferenceObject` of type T + * + * Any other cases we ignore. + */ + + let firstType: string | undefined; + let secondType: string | undefined; + + if (isArraySchemaObject(filteredAnyOf[0])) { + // first is array, second is not + const first = filteredAnyOf[0].items; + const second = filteredAnyOf[1]; + if (isRefObject(first) && isRefObject(second)) { + firstType = refObjectToSchemaName(first); + secondType = refObjectToSchemaName(second); + } else if ( + isNonArraySchemaObject(first) && + isNonArraySchemaObject(second) + ) { + firstType = first.type; + secondType = second.type; + } + } else if (isArraySchemaObject(filteredAnyOf[1])) { + // first is not array, second is + const first = filteredAnyOf[0]; + const second = filteredAnyOf[1].items; + if (isRefObject(first) && isRefObject(second)) { + firstType = refObjectToSchemaName(first); + secondType = refObjectToSchemaName(second); + } else if ( + isNonArraySchemaObject(first) && + isNonArraySchemaObject(second) + ) { + firstType = first.type; + secondType = second.type; + } + } + if (firstType && firstType === secondType) { + return { + name: OPENAPI_TO_FIELD_TYPE_MAP[firstType] ?? firstType, + isCollection: false, + isPolymorphic: true, // <-- don't forget, polymorphic! + }; + } + } + } else if (schemaObject.enum) { + return { name: 'EnumField', isCollection: false, isPolymorphic: false }; + } else if (schemaObject.type) { + if (schemaObject.type === 'array') { + // We need to get the type of the items + if (isSchemaObject(schemaObject.items)) { + const itemType = schemaObject.items.type; + if (!itemType || isArray(itemType)) { + throw new UnsupportedFieldTypeError( + t('nodes.unsupportedArrayItemType', { + type: itemType, + }) + ); + } + // This is an OpenAPI primitive - 'null', 'object', 'array', 'integer', 'number', 'string', 'boolean' + const name = OPENAPI_TO_FIELD_TYPE_MAP[itemType]; + if (!name) { + // it's 'null', 'object', or 'array' - skip + throw new UnsupportedFieldTypeError( + t('nodes.unsupportedArrayItemType', { + type: itemType, + }) + ); + } + return { + name, + isCollection: true, // <-- don't forget, collection! + isPolymorphic: false, + }; + } + + // This is a ref object, extract the type name + const name = refObjectToSchemaName(schemaObject.items); + if (!name) { + throw new FieldTypeParseError( + t('nodes.unableToExtractSchemaNameFromRef') + ); + } + return { + name, + isCollection: true, // <-- don't forget, collection! + isPolymorphic: false, + }; + } else if (!isArray(schemaObject.type)) { + // This is an OpenAPI primitive - 'null', 'object', 'array', 'integer', 'number', 'string', 'boolean' + const name = OPENAPI_TO_FIELD_TYPE_MAP[schemaObject.type]; + if (!name) { + // it's 'null', 'object', or 'array' - skip + throw new UnsupportedFieldTypeError( + t('nodes.unsupportedArrayItemType', { + type: schemaObject.type, + }) + ); + } + return { + name, + isCollection: false, + isPolymorphic: false, + }; + } + } + } else if (isRefObject(schemaObject)) { + const name = refObjectToSchemaName(schemaObject); + if (!name) { + throw new FieldTypeParseError( + t('nodes.unableToExtractSchemaNameFromRef') + ); + } + return { + name, + isCollection: false, + isPolymorphic: false, + }; + } + throw new FieldTypeParseError(t('nodes.unableToParseFieldType')); +}; diff --git a/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts b/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts index 8737fc52b9c..2c59b6cb14f 100644 --- a/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts +++ b/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts @@ -2,24 +2,24 @@ import { logger } from 'app/logging/logger'; import { parseify } from 'common/util/serialize'; import { reduce, startCase } from 'lodash-es'; import { OpenAPIV3_1 } from 'openapi-types'; -import { AnyInvocationType } from 'services/events/types'; +import { FieldInputTemplate, FieldOutputTemplate } from '../types/field'; +import { InvocationTemplate } from '../types/invocation'; import { - InputFieldTemplate, InvocationSchemaObject, - InvocationTemplate, - OutputFieldTemplate, - isFieldType, isInvocationFieldSchema, isInvocationOutputSchemaObject, isInvocationSchemaObject, -} from '../types/types'; -import { buildInputFieldTemplate, getFieldType } from './fieldTemplateBuilders'; +} from '../types/openapi'; +import { buildFieldInputTemplate } from './buildFieldInputTemplate'; +import { parseFieldType } from './parseFieldType'; +import { FieldTypeParseError, UnsupportedFieldTypeError } from '../types/error'; +import { t } from 'i18next'; const RESERVED_INPUT_FIELD_NAMES = ['id', 'type', 'use_cache']; const RESERVED_OUTPUT_FIELD_NAMES = ['type']; const RESERVED_FIELD_TYPES = ['IsIntermediate']; -const invocationDenylist: AnyInvocationType[] = ['graph', 'linear_ui_output']; +const invocationDenylist: string[] = ['graph', 'linear_ui_output']; const isReservedInputField = (nodeType: string, fieldName: string) => { if (RESERVED_INPUT_FIELD_NAMES.includes(fieldName)) { @@ -83,13 +83,13 @@ export const parseSchema = ( const inputs = reduce( schema.properties, ( - inputsAccumulator: Record, + inputsAccumulator: Record, property, propertyName ) => { if (isReservedInputField(type, propertyName)) { logger('nodes').trace( - { node: type, fieldName: propertyName, field: parseify(property) }, + { node: type, field: propertyName, schema: parseify(property) }, 'Skipped reserved input field' ); return inputsAccumulator; @@ -97,79 +97,53 @@ export const parseSchema = ( if (!isInvocationFieldSchema(property)) { logger('nodes').warn( - { node: type, propertyName, property: parseify(property) }, + { node: type, field: propertyName, schema: parseify(property) }, 'Unhandled input property' ); return inputsAccumulator; } - const fieldType = property.ui_type ?? getFieldType(property); + try { + const fieldType = parseFieldType(property); - if (!fieldType) { - logger('nodes').warn( - { - node: type, - fieldName: propertyName, - fieldType, - field: parseify(property), - }, - 'Missing input field type' - ); - return inputsAccumulator; - } - - if (fieldType === 'WorkflowField') { - withWorkflow = true; - return inputsAccumulator; - } + if (fieldType.name === 'WorkflowField') { + // This supports workflows, set the flag and skip to next field + withWorkflow = true; + return inputsAccumulator; + } - if (isReservedFieldType(fieldType)) { - logger('nodes').trace( - { - node: type, - fieldName: propertyName, - fieldType, - field: parseify(property), - }, - `Skipping reserved input field type: ${fieldType}` - ); - return inputsAccumulator; - } + if (isReservedFieldType(fieldType.name)) { + // Skip processing this reserved field + return inputsAccumulator; + } - if (!isFieldType(fieldType)) { - logger('nodes').warn( - { - node: type, - fieldName: propertyName, - fieldType, - field: parseify(property), - }, - `Skipping unknown input field type: ${fieldType}` + const fieldInputTemplate = buildFieldInputTemplate( + property, + propertyName, + fieldType ); - return inputsAccumulator; - } - - const field = buildInputFieldTemplate( - schema, - property, - propertyName, - fieldType - ); - if (!field) { - logger('nodes').warn( - { - node: type, - fieldName: propertyName, - fieldType, - field: parseify(property), - }, - 'Skipping input field with no template' - ); - return inputsAccumulator; + inputsAccumulator[propertyName] = fieldInputTemplate; + } catch (e) { + if ( + e instanceof FieldTypeParseError || + e instanceof UnsupportedFieldTypeError + ) { + logger('nodes').warn( + { + node: type, + field: propertyName, + schema: parseify(property), + }, + t('nodes.inputFieldTypeParseError', { + node: type, + field: propertyName, + message: e.message, + }) + ); + } } - inputsAccumulator[propertyName] = field; return inputsAccumulator; }, {} @@ -206,7 +180,7 @@ export const parseSchema = ( (outputsAccumulator, property, propertyName) => { if (!isAllowedOutputField(type, propertyName)) { logger('nodes').trace( - { type, propertyName, property: parseify(property) }, + { node: type, field: propertyName, schema: parseify(property) }, 'Skipped reserved output field' ); return outputsAccumulator; @@ -214,37 +188,62 @@ export const parseSchema = ( if (!isInvocationFieldSchema(property)) { logger('nodes').warn( - { type, propertyName, property: parseify(property) }, + { node: type, field: propertyName, schema: parseify(property) }, 'Unhandled output property' ); return outputsAccumulator; } - const fieldType = property.ui_type ?? getFieldType(property); - - if (!isFieldType(fieldType)) { - logger('nodes').warn( - { fieldName: propertyName, fieldType, field: parseify(property) }, - 'Skipping unknown output field type' - ); - return outputsAccumulator; + try { + const fieldType = parseFieldType(property); + + if (!fieldType) { + logger('nodes').warn( + { + node: type, + field: propertyName, + schema: parseify(property), + }, + 'Missing output field type' + ); + return outputsAccumulator; + } + + const fieldOutputTemplate: FieldOutputTemplate = { + fieldKind: 'output', + name: propertyName, + title: + property.title ?? (propertyName ? startCase(propertyName) : ''), + description: property.description ?? '', + type: fieldType, + ui_hidden: property.ui_hidden ?? false, + ui_type: property.ui_type, + ui_order: property.ui_order, + }; + + outputsAccumulator[propertyName] = fieldOutputTemplate; + } catch (e) { + if ( + e instanceof FieldTypeParseError || + e instanceof UnsupportedFieldTypeError + ) { + logger('nodes').warn( + { + node: type, + field: propertyName, + schema: parseify(property), + }, + t('nodes.outputFieldTypeParseError', { + node: type, + field: propertyName, + message: e.message, + }) + ); + } } - - outputsAccumulator[propertyName] = { - fieldKind: 'output', - name: propertyName, - title: - property.title ?? (propertyName ? startCase(propertyName) : ''), - description: property.description ?? '', - type: fieldType, - ui_hidden: property.ui_hidden ?? false, - ui_type: property.ui_type, - ui_order: property.ui_order, - }; - return outputsAccumulator; }, - {} as Record + {} as Record ); const useCache = schema.properties.use_cache.default; diff --git a/invokeai/frontend/web/src/features/nodes/util/validateWorkflow.ts b/invokeai/frontend/web/src/features/nodes/util/validateWorkflow.ts index 9e5cea13f60..6d2ee13cf26 100644 --- a/invokeai/frontend/web/src/features/nodes/util/validateWorkflow.ts +++ b/invokeai/frontend/web/src/features/nodes/util/validateWorkflow.ts @@ -1,123 +1,159 @@ -import { compareVersions } from 'compare-versions'; -import { cloneDeep, keyBy } from 'lodash-es'; -import { - InvocationTemplate, - Workflow, - WorkflowWarning, - isWorkflowInvocationNode, -} from '../types/types'; import { parseify } from 'common/util/serialize'; -import i18n from 'i18next'; +import { t } from 'i18next'; +import { keyBy } from 'lodash-es'; +import { JsonObject } from 'type-fest'; +import { getNeedsUpdate } from '../store/util/nodeUpdate'; +import { InvocationTemplate } from '../types/invocation'; +import { parseAndMigrateWorkflow } from '../types/migration/migrations'; +import { WorkflowV2, isWorkflowInvocationNode } from '../types/workflow'; +type WorkflowWarning = { + message: string; + issues?: string[]; + data: JsonObject; +}; + +type ValidateWorkflowResult = { + workflow: WorkflowV2; + warnings: WorkflowWarning[]; +}; + +/** + * Parses and validates a workflow: + * - Parses the workflow schema, and migrates it to the latest version if necessary. + * - Validates the workflow against the node templates, warning if the template is not known. + * - Attempts to update nodes which have a mismatched version. + * - Removes edges which are invalid. + * @param workflow The raw workflow object (e.g. JSON.parse(stringifiedWorklow)) + * @param invocationTemplates The node templates to validate against. + * @throws {WorkflowVersionError} If the workflow version is not recognized. + * @throws {z.ZodError} If there is a validation error. + */ export const validateWorkflow = ( - workflow: Workflow, - nodeTemplates: Record -) => { - const clone = cloneDeep(workflow); - const { nodes, edges } = clone; - const errors: WorkflowWarning[] = []; + workflow: unknown, + invocationTemplates: Record +): ValidateWorkflowResult => { + // Parse the raw workflow data & migrate it to the latest version + const _workflow = parseAndMigrateWorkflow(workflow); + + // Now we can validate the graph + const { nodes, edges } = _workflow; + const warnings: WorkflowWarning[] = []; + + // We don't need to validate Note nodes or CurrentImage nodes - only Invocation nodes const invocationNodes = nodes.filter(isWorkflowInvocationNode); const keyedNodes = keyBy(invocationNodes, 'id'); - nodes.forEach((node) => { - if (!isWorkflowInvocationNode(node)) { - return; - } - const nodeTemplate = nodeTemplates[node.data.type]; - if (!nodeTemplate) { - errors.push({ - message: `${i18n.t('nodes.node')} "${node.data.type}" ${i18n.t( - 'nodes.skipped' - )}`, - issues: [ - `${i18n.t('nodes.nodeType')}"${node.data.type}" ${i18n.t( - 'nodes.doesNotExist' - )}`, - ], - data: node, + invocationNodes.forEach((node) => { + const template = invocationTemplates[node.data.type]; + if (!template) { + // This node's type template does not exist + const message = t('nodes.missingTemplate', { + node: node.id, + type: node.data.type, + }); + warnings.push({ + message, + data: parseify(node), }); return; } - if ( - nodeTemplate.version && - node.data.version && - compareVersions(nodeTemplate.version, node.data.version) !== 0 - ) { - errors.push({ - message: `${i18n.t('nodes.node')} "${node.data.type}" ${i18n.t( - 'nodes.mismatchedVersion' - )}`, - issues: [ - `${i18n.t('nodes.node')} "${node.data.type}" v${ - node.data.version - } ${i18n.t('nodes.maybeIncompatible')} v${nodeTemplate.version}`, - ], - data: { node, nodeTemplate: parseify(nodeTemplate) }, + if (getNeedsUpdate(node, template)) { + // This node needs to be updated, based on comparison of its version to the template version + const message = t('nodes.mismatchedVersion', { + node: node.id, + type: node.data.type, + }); + warnings.push({ + message, + data: parseify({ node, nodeTemplate: template }), }); return; } }); edges.forEach((edge, i) => { + // Validate each edge. If the edge is invalid, we must remove it to prevent runtime errors with reactflow. const sourceNode = keyedNodes[edge.source]; const targetNode = keyedNodes[edge.target]; const issues: string[] = []; + if (!sourceNode) { + // The edge's source/output node does not exist issues.push( - `${i18n.t('nodes.outputNode')} ${edge.source} ${i18n.t( - 'nodes.doesNotExist' - )}` + t('nodes.sourceNodeDoesNotExist', { + node: edge.source, + }) ); } else if ( edge.type === 'default' && !(edge.sourceHandle in sourceNode.data.outputs) ) { + // The edge's source/output node field does not exist issues.push( - `${i18n.t('nodes.outputNode')} "${edge.source}.${ - edge.sourceHandle - }" ${i18n.t('nodes.doesNotExist')}` + t('nodes.sourceNodeFieldDoesNotExist', { + node: edge.source, + field: edge.sourceHandle, + }) ); } + if (!targetNode) { + // The edge's target/input node does not exist issues.push( - `${i18n.t('nodes.inputNode')} ${edge.target} ${i18n.t( - 'nodes.doesNotExist' - )}` + t('nodes.targetNodeDoesNotExist', { + node: edge.target, + }) ); } else if ( edge.type === 'default' && !(edge.targetHandle in targetNode.data.inputs) ) { + // The edge's target/input node field does not exist issues.push( - `${i18n.t('nodes.inputField')} "${edge.target}.${ - edge.targetHandle - }" ${i18n.t('nodes.doesNotExist')}` + t('nodes.targetNodeFieldDoesNotExist', { + node: edge.target, + field: edge.targetHandle, + }) ); } - if (!nodeTemplates[sourceNode?.data.type ?? '__UNKNOWN_NODE_TYPE__']) { + + if (!sourceNode?.data.type || !invocationTemplates[sourceNode.data.type]) { + // The edge's source/output node template does not exist issues.push( - `${i18n.t('nodes.sourceNode')} "${edge.source}" ${i18n.t( - 'nodes.missingTemplate' - )} "${sourceNode?.data.type}"` + t('nodes.missingTemplate', { + node: edge.source, + type: sourceNode?.data.type, + }) ); } - if (!nodeTemplates[targetNode?.data.type ?? '__UNKNOWN_NODE_TYPE__']) { + if (!targetNode?.data.type || !invocationTemplates[targetNode?.data.type]) { + // The edge's target/input node template does not exist issues.push( - `${i18n.t('nodes.sourceNode')}"${edge.target}" ${i18n.t( - 'nodes.missingTemplate' - )} "${targetNode?.data.type}"` + t('nodes.missingTemplate', { + node: edge.target, + type: targetNode?.data.type, + }) ); } + if (issues.length) { + // This edge has some issues. Remove it. delete edges[i]; - const src = edge.type === 'default' ? edge.sourceHandle : edge.source; - const tgt = edge.type === 'default' ? edge.targetHandle : edge.target; - errors.push({ - message: `Edge "${src} -> ${tgt}" skipped`, + const source = + edge.type === 'default' + ? `${edge.source}.${edge.sourceHandle}` + : edge.source; + const target = + edge.type === 'default' + ? `${edge.source}.${edge.targetHandle}` + : edge.target; + warnings.push({ + message: t('nodes.deletedInvalidEdge', { source, target }), issues, data: edge, }); } }); - return { workflow: clone, errors }; + return { workflow: _workflow, warnings }; }; diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Advanced/ParamClipSkip.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Advanced/ParamClipSkip.tsx index bff8120b7bb..49dd60beb58 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Advanced/ParamClipSkip.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Advanced/ParamClipSkip.tsx @@ -3,7 +3,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import IAIInformationalPopover from 'common/components/IAIInformationalPopover/IAIInformationalPopover'; import IAISlider from 'common/components/IAISlider'; import { setClipSkip } from 'features/parameters/store/generationSlice'; -import { clipSkipMap } from 'features/parameters/types/constants'; +import { CLIP_SKIP_MAP } from 'features/parameters/types/constants'; import { useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; @@ -30,16 +30,16 @@ export default function ParamClipSkip() { const max = useMemo(() => { if (!model) { - return clipSkipMap['sd-1'].maxClip; + return CLIP_SKIP_MAP['sd-1'].maxClip; } - return clipSkipMap[model.base_model].maxClip; + return CLIP_SKIP_MAP[model.base_model].maxClip; }, [model]); const sliderMarks = useMemo(() => { if (!model) { - return clipSkipMap['sd-1'].markers; + return CLIP_SKIP_MAP['sd-1'].markers; } - return clipSkipMap[model.base_model].markers; + return CLIP_SKIP_MAP[model.base_model].markers; }, [model]); if (model?.base_model === 'sdxl') { diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/Compositing/CoherencePass/ParamCanvasCoherenceMode.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/Compositing/CoherencePass/ParamCanvasCoherenceMode.tsx index 1196719af35..1fe4f95c3ba 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/Compositing/CoherencePass/ParamCanvasCoherenceMode.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/Compositing/CoherencePass/ParamCanvasCoherenceMode.tsx @@ -4,7 +4,7 @@ import IAIInformationalPopover from 'common/components/IAIInformationalPopover/I import { IAISelectDataType } from 'common/components/IAIMantineSearchableSelect'; import IAIMantineSelect from 'common/components/IAIMantineSelect'; import { setCanvasCoherenceMode } from 'features/parameters/store/generationSlice'; -import { CanvasCoherenceModeParam } from 'features/parameters/types/parameterSchemas'; +import { ParameterCanvasCoherenceMode } from 'features/parameters/types/parameterSchemas'; import { memo, useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; @@ -30,7 +30,7 @@ const ParamCanvasCoherenceMode = () => { return; } - dispatch(setCanvasCoherenceMode(v as CanvasCoherenceModeParam)); + dispatch(setCanvasCoherenceMode(v as ParameterCanvasCoherenceMode)); }, [dispatch] ); diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamScheduler.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamScheduler.tsx index a44e6fb551b..8f277474973 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamScheduler.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamScheduler.tsx @@ -5,10 +5,8 @@ import IAIInformationalPopover from 'common/components/IAIInformationalPopover/I import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect'; import { generationSelector } from 'features/parameters/store/generationSelectors'; import { setScheduler } from 'features/parameters/store/generationSlice'; -import { - SCHEDULER_LABEL_MAP, - SchedulerParam, -} from 'features/parameters/types/parameterSchemas'; +import { ParameterScheduler } from 'features/parameters/types/parameterSchemas'; +import { SCHEDULER_LABEL_MAP } from 'features/parameters/types/constants'; import { uiSelector } from 'features/ui/store/uiSelectors'; import { map } from 'lodash-es'; import { memo, useCallback } from 'react'; @@ -23,7 +21,7 @@ const selector = createSelector( const data = map(SCHEDULER_LABEL_MAP, (label, name) => ({ value: name, label: label, - group: enabledSchedulers.includes(name as SchedulerParam) + group: enabledSchedulers.includes(name as ParameterScheduler) ? 'Favorites' : undefined, })).sort((a, b) => a.label.localeCompare(b.label)); @@ -46,7 +44,7 @@ const ParamScheduler = () => { if (!v) { return; } - dispatch(setScheduler(v as SchedulerParam)); + dispatch(setScheduler(v as ParameterScheduler)); }, [dispatch] ); diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/HighResFix/ParamHrfMethod.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/HighResFix/ParamHrfMethod.tsx index 403d2268c10..89c4f51356f 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/HighResFix/ParamHrfMethod.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/HighResFix/ParamHrfMethod.tsx @@ -4,7 +4,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAIMantineSelect from 'common/components/IAIMantineSelect'; import { setHrfMethod } from 'features/parameters/store/generationSlice'; -import { HrfMethodParam } from 'features/parameters/types/parameterSchemas'; +import { ParameterHRFMethod } from 'features/parameters/types/parameterSchemas'; import { memo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; @@ -26,7 +26,7 @@ const ParamHrfMethodSelect = () => { const { hrfMethod, hrfEnabled } = useAppSelector(selector); const handleChange = useCallback( - (v: HrfMethodParam | null) => { + (v: ParameterHRFMethod | null) => { if (!v) { return; } diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/VAEModel/ParamVAEPrecision.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/VAEModel/ParamVAEPrecision.tsx index 723e57a2881..ad75fa0b7b7 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/VAEModel/ParamVAEPrecision.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/VAEModel/ParamVAEPrecision.tsx @@ -5,7 +5,7 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAIInformationalPopover from 'common/components/IAIInformationalPopover/IAIInformationalPopover'; import IAIMantineSelect from 'common/components/IAIMantineSelect'; import { vaePrecisionChanged } from 'features/parameters/store/generationSlice'; -import { PrecisionParam } from 'features/parameters/types/parameterSchemas'; +import { ParameterPrecision } from 'features/parameters/types/parameterSchemas'; import { memo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; @@ -31,7 +31,7 @@ const ParamVAEModelSelect = () => { return; } - dispatch(vaePrecisionChanged(v as PrecisionParam)); + dispatch(vaePrecisionChanged(v as ParameterPrecision)); }, [dispatch] ); diff --git a/invokeai/frontend/web/src/features/parameters/util/useCoreParametersCollapseLabel.ts b/invokeai/frontend/web/src/features/parameters/hooks/useCoreParametersCollapseLabel.ts similarity index 100% rename from invokeai/frontend/web/src/features/parameters/util/useCoreParametersCollapseLabel.ts rename to invokeai/frontend/web/src/features/parameters/hooks/useCoreParametersCollapseLabel.ts diff --git a/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts b/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts index 5cecd03753e..898fe136189 100644 --- a/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts +++ b/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts @@ -24,7 +24,7 @@ import { IPAdapterMetadataItem, LoRAMetadataItem, T2IAdapterMetadataItem, -} from 'features/nodes/types/types'; +} from 'features/nodes/types/metadata'; import { refinerModelChanged, setNegativeStylePromptSDXL, @@ -69,28 +69,28 @@ import { vaeSelected, } from '../store/generationSlice'; import { - isValidBoolean, - isValidCfgScale, - isValidControlNetModel, - isValidHeight, - isValidHrfMethod, - isValidIPAdapterModel, - isValidLoRAModel, - isValidMainModel, - isValidNegativePrompt, - isValidPositivePrompt, - isValidSDXLNegativeStylePrompt, - isValidSDXLPositiveStylePrompt, - isValidSDXLRefinerModel, - isValidSDXLRefinerNegativeAestheticScore, - isValidSDXLRefinerPositiveAestheticScore, - isValidSDXLRefinerStart, - isValidScheduler, - isValidSeed, - isValidSteps, - isValidStrength, - isValidVaeModel, - isValidWidth, + isParameterHRFEnabled, + isParameterCFGScale, + isParameterControlNetModel, + isParameterHeight, + isParameterHRFMethod, + isParameterIPAdapterModel, + isParameterLoRAModel, + isParameterModel, + isParameterNegativePrompt, + isParameterPositivePrompt, + isParameterNegativeStylePromptSDXL, + isParameterPositiveStylePromptSDXL, + isParameterSDXLRefinerModel, + isParameterSDXLRefinerNegativeAestheticScore, + isParameterSDXLRefinerPositiveAestheticScore, + isParameterSDXLRefinerStart, + isParameterScheduler, + isParameterSeed, + isParameterSteps, + isParameterStrength, + isParameterVAEModel, + isParameterWidth, } from '../types/parameterSchemas'; const selector = createSelector( @@ -160,24 +160,24 @@ export const useRecallParameters = () => { negativeStylePrompt: unknown ) => { if ( - isValidPositivePrompt(positivePrompt) || - isValidNegativePrompt(negativePrompt) || - isValidSDXLPositiveStylePrompt(positiveStylePrompt) || - isValidSDXLNegativeStylePrompt(negativeStylePrompt) + isParameterPositivePrompt(positivePrompt) || + isParameterNegativePrompt(negativePrompt) || + isParameterPositiveStylePromptSDXL(positiveStylePrompt) || + isParameterNegativeStylePromptSDXL(negativeStylePrompt) ) { - if (isValidPositivePrompt(positivePrompt)) { + if (isParameterPositivePrompt(positivePrompt)) { dispatch(setPositivePrompt(positivePrompt)); } - if (isValidNegativePrompt(negativePrompt)) { + if (isParameterNegativePrompt(negativePrompt)) { dispatch(setNegativePrompt(negativePrompt)); } - if (isValidSDXLPositiveStylePrompt(positiveStylePrompt)) { + if (isParameterPositiveStylePromptSDXL(positiveStylePrompt)) { dispatch(setPositiveStylePromptSDXL(positiveStylePrompt)); } - if (isValidSDXLPositiveStylePrompt(negativeStylePrompt)) { + if (isParameterPositiveStylePromptSDXL(negativeStylePrompt)) { dispatch(setNegativeStylePromptSDXL(negativeStylePrompt)); } @@ -194,7 +194,7 @@ export const useRecallParameters = () => { */ const recallPositivePrompt = useCallback( (positivePrompt: unknown) => { - if (!isValidPositivePrompt(positivePrompt)) { + if (!isParameterPositivePrompt(positivePrompt)) { parameterNotSetToast(); return; } @@ -209,7 +209,7 @@ export const useRecallParameters = () => { */ const recallNegativePrompt = useCallback( (negativePrompt: unknown) => { - if (!isValidNegativePrompt(negativePrompt)) { + if (!isParameterNegativePrompt(negativePrompt)) { parameterNotSetToast(); return; } @@ -224,7 +224,7 @@ export const useRecallParameters = () => { */ const recallSDXLPositiveStylePrompt = useCallback( (positiveStylePrompt: unknown) => { - if (!isValidSDXLPositiveStylePrompt(positiveStylePrompt)) { + if (!isParameterPositiveStylePromptSDXL(positiveStylePrompt)) { parameterNotSetToast(); return; } @@ -239,7 +239,7 @@ export const useRecallParameters = () => { */ const recallSDXLNegativeStylePrompt = useCallback( (negativeStylePrompt: unknown) => { - if (!isValidSDXLNegativeStylePrompt(negativeStylePrompt)) { + if (!isParameterNegativeStylePromptSDXL(negativeStylePrompt)) { parameterNotSetToast(); return; } @@ -254,7 +254,7 @@ export const useRecallParameters = () => { */ const recallSeed = useCallback( (seed: unknown) => { - if (!isValidSeed(seed)) { + if (!isParameterSeed(seed)) { parameterNotSetToast(); return; } @@ -269,7 +269,7 @@ export const useRecallParameters = () => { */ const recallCfgScale = useCallback( (cfgScale: unknown) => { - if (!isValidCfgScale(cfgScale)) { + if (!isParameterCFGScale(cfgScale)) { parameterNotSetToast(); return; } @@ -284,7 +284,7 @@ export const useRecallParameters = () => { */ const recallModel = useCallback( (model: unknown) => { - if (!isValidMainModel(model)) { + if (!isParameterModel(model)) { parameterNotSetToast(); return; } @@ -299,7 +299,7 @@ export const useRecallParameters = () => { */ const recallScheduler = useCallback( (scheduler: unknown) => { - if (!isValidScheduler(scheduler)) { + if (!isParameterScheduler(scheduler)) { parameterNotSetToast(); return; } @@ -314,7 +314,7 @@ export const useRecallParameters = () => { */ const recallVaeModel = useCallback( (vae: unknown) => { - if (!isValidVaeModel(vae) && !isNil(vae)) { + if (!isParameterVAEModel(vae) && !isNil(vae)) { parameterNotSetToast(); return; } @@ -333,7 +333,7 @@ export const useRecallParameters = () => { */ const recallSteps = useCallback( (steps: unknown) => { - if (!isValidSteps(steps)) { + if (!isParameterSteps(steps)) { parameterNotSetToast(); return; } @@ -348,7 +348,7 @@ export const useRecallParameters = () => { */ const recallWidth = useCallback( (width: unknown) => { - if (!isValidWidth(width)) { + if (!isParameterWidth(width)) { parameterNotSetToast(); return; } @@ -363,7 +363,7 @@ export const useRecallParameters = () => { */ const recallHeight = useCallback( (height: unknown) => { - if (!isValidHeight(height)) { + if (!isParameterHeight(height)) { parameterNotSetToast(); return; } @@ -378,11 +378,11 @@ export const useRecallParameters = () => { */ const recallWidthAndHeight = useCallback( (width: unknown, height: unknown) => { - if (!isValidWidth(width)) { + if (!isParameterWidth(width)) { allParameterNotSetToast(); return; } - if (!isValidHeight(height)) { + if (!isParameterHeight(height)) { allParameterNotSetToast(); return; } @@ -398,7 +398,7 @@ export const useRecallParameters = () => { */ const recallStrength = useCallback( (strength: unknown) => { - if (!isValidStrength(strength)) { + if (!isParameterStrength(strength)) { parameterNotSetToast(); return; } @@ -413,7 +413,7 @@ export const useRecallParameters = () => { */ const recallHrfEnabled = useCallback( (hrfEnabled: unknown) => { - if (!isValidBoolean(hrfEnabled)) { + if (!isParameterHRFEnabled(hrfEnabled)) { parameterNotSetToast(); return; } @@ -428,7 +428,7 @@ export const useRecallParameters = () => { */ const recallHrfStrength = useCallback( (hrfStrength: unknown) => { - if (!isValidStrength(hrfStrength)) { + if (!isParameterStrength(hrfStrength)) { parameterNotSetToast(); return; } @@ -443,7 +443,7 @@ export const useRecallParameters = () => { */ const recallHrfMethod = useCallback( (hrfMethod: unknown) => { - if (!isValidHrfMethod(hrfMethod)) { + if (!isParameterHRFMethod(hrfMethod)) { parameterNotSetToast(); return; } @@ -461,7 +461,7 @@ export const useRecallParameters = () => { const prepareLoRAMetadataItem = useCallback( (loraMetadataItem: LoRAMetadataItem) => { - if (!isValidLoRAModel(loraMetadataItem.lora)) { + if (!isParameterLoRAModel(loraMetadataItem.lora)) { return { lora: null, error: 'Invalid LoRA model' }; } @@ -518,7 +518,7 @@ export const useRecallParameters = () => { const prepareControlNetMetadataItem = useCallback( (controlnetMetadataItem: ControlNetMetadataItem) => { - if (!isValidControlNetModel(controlnetMetadataItem.control_model)) { + if (!isParameterControlNetModel(controlnetMetadataItem.control_model)) { return { controlnet: null, error: 'Invalid ControlNet model' }; } @@ -613,7 +613,9 @@ export const useRecallParameters = () => { const prepareT2IAdapterMetadataItem = useCallback( (t2iAdapterMetadataItem: T2IAdapterMetadataItem) => { - if (!isValidControlNetModel(t2iAdapterMetadataItem.t2i_adapter_model)) { + if ( + !isParameterControlNetModel(t2iAdapterMetadataItem.t2i_adapter_model) + ) { return { controlnet: null, error: 'Invalid ControlNet model' }; } @@ -703,7 +705,7 @@ export const useRecallParameters = () => { const prepareIPAdapterMetadataItem = useCallback( (ipAdapterMetadataItem: IPAdapterMetadataItem) => { - if (!isValidIPAdapterModel(ipAdapterMetadataItem?.ip_adapter_model)) { + if (!isParameterIPAdapterModel(ipAdapterMetadataItem?.ip_adapter_model)) { return { ipAdapter: null, error: 'Invalid IP Adapter model' }; } @@ -822,26 +824,26 @@ export const useRecallParameters = () => { t2iAdapters, } = metadata; - if (isValidCfgScale(cfg_scale)) { + if (isParameterCFGScale(cfg_scale)) { dispatch(setCfgScale(cfg_scale)); } - if (isValidMainModel(model)) { + if (isParameterModel(model)) { dispatch(modelSelected(model)); } - if (isValidPositivePrompt(positive_prompt)) { + if (isParameterPositivePrompt(positive_prompt)) { dispatch(setPositivePrompt(positive_prompt)); } - if (isValidNegativePrompt(negative_prompt)) { + if (isParameterNegativePrompt(negative_prompt)) { dispatch(setNegativePrompt(negative_prompt)); } - if (isValidScheduler(scheduler)) { + if (isParameterScheduler(scheduler)) { dispatch(setScheduler(scheduler)); } - if (isValidVaeModel(vae) || isNil(vae)) { + if (isParameterVAEModel(vae) || isNil(vae)) { if (isNil(vae)) { dispatch(vaeSelected(null)); } else { @@ -849,64 +851,64 @@ export const useRecallParameters = () => { } } - if (isValidSeed(seed)) { + if (isParameterSeed(seed)) { dispatch(setSeed(seed)); } - if (isValidSteps(steps)) { + if (isParameterSteps(steps)) { dispatch(setSteps(steps)); } - if (isValidWidth(width)) { + if (isParameterWidth(width)) { dispatch(setWidth(width)); } - if (isValidHeight(height)) { + if (isParameterHeight(height)) { dispatch(setHeight(height)); } - if (isValidStrength(strength)) { + if (isParameterStrength(strength)) { dispatch(setImg2imgStrength(strength)); } - if (isValidBoolean(hrf_enabled)) { + if (isParameterHRFEnabled(hrf_enabled)) { dispatch(setHrfEnabled(hrf_enabled)); } - if (isValidStrength(hrf_strength)) { + if (isParameterStrength(hrf_strength)) { dispatch(setHrfStrength(hrf_strength)); } - if (isValidHrfMethod(hrf_method)) { + if (isParameterHRFMethod(hrf_method)) { dispatch(setHrfMethod(hrf_method)); } - if (isValidSDXLPositiveStylePrompt(positive_style_prompt)) { + if (isParameterPositiveStylePromptSDXL(positive_style_prompt)) { dispatch(setPositiveStylePromptSDXL(positive_style_prompt)); } - if (isValidSDXLNegativeStylePrompt(negative_style_prompt)) { + if (isParameterNegativeStylePromptSDXL(negative_style_prompt)) { dispatch(setNegativeStylePromptSDXL(negative_style_prompt)); } - if (isValidSDXLRefinerModel(refiner_model)) { + if (isParameterSDXLRefinerModel(refiner_model)) { dispatch(refinerModelChanged(refiner_model)); } - if (isValidSteps(refiner_steps)) { + if (isParameterSteps(refiner_steps)) { dispatch(setRefinerSteps(refiner_steps)); } - if (isValidCfgScale(refiner_cfg_scale)) { + if (isParameterCFGScale(refiner_cfg_scale)) { dispatch(setRefinerCFGScale(refiner_cfg_scale)); } - if (isValidScheduler(refiner_scheduler)) { + if (isParameterScheduler(refiner_scheduler)) { dispatch(setRefinerScheduler(refiner_scheduler)); } if ( - isValidSDXLRefinerPositiveAestheticScore( + isParameterSDXLRefinerPositiveAestheticScore( refiner_positive_aesthetic_score ) ) { @@ -916,7 +918,7 @@ export const useRecallParameters = () => { } if ( - isValidSDXLRefinerNegativeAestheticScore( + isParameterSDXLRefinerNegativeAestheticScore( refiner_negative_aesthetic_score ) ) { @@ -925,7 +927,7 @@ export const useRecallParameters = () => { ); } - if (isValidSDXLRefinerStart(refiner_start)) { + if (isParameterSDXLRefinerStart(refiner_start)) { dispatch(setRefinerStart(refiner_start)); } diff --git a/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts b/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts index 8fbdfafbde1..e23747c921f 100644 --- a/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts +++ b/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts @@ -6,63 +6,62 @@ import { clamp } from 'lodash-es'; import { ImageDTO } from 'services/api/types'; import { isAnyControlAdapterAdded } from 'features/controlAdapters/store/controlAdaptersSlice'; -import { clipSkipMap } from '../types/constants'; +import { CLIP_SKIP_MAP } from '../types/constants'; import { - CanvasCoherenceModeParam, - CfgScaleParam, - HeightParam, - HrfMethodParam, - MainModelParam, - MaskBlurMethodParam, - NegativePromptParam, - OnnxModelParam, - PositivePromptParam, - PrecisionParam, - SchedulerParam, - SeedParam, - StepsParam, - StrengthParam, - VaeModelParam, - WidthParam, - zMainModel, + ParameterCanvasCoherenceMode, + ParameterCFGScale, + ParameterHeight, + ParameterHRFMethod, + ParameterModel, + ParameterMaskBlurMethod, + ParameterNegativePrompt, + ParameterPositivePrompt, + ParameterPrecision, + ParameterScheduler, + ParameterSeed, + ParameterSteps, + ParameterStrength, + ParameterVAEModel, + ParameterWidth, + zParameterModel, } from '../types/parameterSchemas'; export interface GenerationState { hrfEnabled: boolean; - hrfStrength: StrengthParam; - hrfMethod: HrfMethodParam; - cfgScale: CfgScaleParam; - height: HeightParam; - img2imgStrength: StrengthParam; + hrfStrength: ParameterStrength; + hrfMethod: ParameterHRFMethod; + cfgScale: ParameterCFGScale; + height: ParameterHeight; + img2imgStrength: ParameterStrength; infillMethod: string; initialImage?: { imageName: string; width: number; height: number }; iterations: number; perlin: number; - positivePrompt: PositivePromptParam; - negativePrompt: NegativePromptParam; - scheduler: SchedulerParam; + positivePrompt: ParameterPositivePrompt; + negativePrompt: ParameterNegativePrompt; + scheduler: ParameterScheduler; maskBlur: number; - maskBlurMethod: MaskBlurMethodParam; - canvasCoherenceMode: CanvasCoherenceModeParam; + maskBlurMethod: ParameterMaskBlurMethod; + canvasCoherenceMode: ParameterCanvasCoherenceMode; canvasCoherenceSteps: number; - canvasCoherenceStrength: StrengthParam; - seed: SeedParam; + canvasCoherenceStrength: ParameterStrength; + seed: ParameterSeed; seedWeights: string; shouldFitToWidthHeight: boolean; shouldGenerateVariations: boolean; shouldRandomizeSeed: boolean; - steps: StepsParam; + steps: ParameterSteps; threshold: number; infillTileSize: number; infillPatchmatchDownscaleSize: number; variationAmount: number; - width: WidthParam; + width: ParameterWidth; shouldUseSymmetry: boolean; horizontalSymmetrySteps: number; verticalSymmetrySteps: number; - model: MainModelParam | OnnxModelParam | null; - vae: VaeModelParam | null; - vaePrecision: PrecisionParam; + model: ParameterModel | null; + vae: ParameterVAEModel | null; + vaePrecision: ParameterPrecision; seamlessXAxis: boolean; seamlessYAxis: boolean; clipSkip: number; @@ -166,7 +165,7 @@ export const generationSlice = createSlice({ state.width = height; state.height = width; }, - setScheduler: (state, action: PayloadAction) => { + setScheduler: (state, action: PayloadAction) => { state.scheduler = action.payload; }, setSeed: (state, action: PayloadAction) => { @@ -214,12 +213,15 @@ export const generationSlice = createSlice({ setMaskBlur: (state, action: PayloadAction) => { state.maskBlur = action.payload; }, - setMaskBlurMethod: (state, action: PayloadAction) => { + setMaskBlurMethod: ( + state, + action: PayloadAction + ) => { state.maskBlurMethod = action.payload; }, setCanvasCoherenceMode: ( state, - action: PayloadAction + action: PayloadAction ) => { state.canvasCoherenceMode = action.payload; }, @@ -254,10 +256,7 @@ export const generationSlice = createSlice({ const { image_name, width, height } = action.payload; state.initialImage = { imageName: image_name, width, height }; }, - modelChanged: ( - state, - action: PayloadAction - ) => { + modelChanged: (state, action: PayloadAction) => { state.model = action.payload; if (state.model === null) { @@ -265,14 +264,14 @@ export const generationSlice = createSlice({ } // Clamp ClipSkip Based On Selected Model - const { maxClip } = clipSkipMap[state.model.base_model]; + const { maxClip } = CLIP_SKIP_MAP[state.model.base_model]; state.clipSkip = clamp(state.clipSkip, 0, maxClip); }, - vaeSelected: (state, action: PayloadAction) => { + vaeSelected: (state, action: PayloadAction) => { // null is a valid VAE! state.vae = action.payload; }, - vaePrecisionChanged: (state, action: PayloadAction) => { + vaePrecisionChanged: (state, action: PayloadAction) => { state.vaePrecision = action.payload; }, setClipSkip: (state, action: PayloadAction) => { @@ -284,7 +283,7 @@ export const generationSlice = createSlice({ setHrfEnabled: (state, action: PayloadAction) => { state.hrfEnabled = action.payload; }, - setHrfMethod: (state, action: PayloadAction) => { + setHrfMethod: (state, action: PayloadAction) => { state.hrfMethod = action.payload; }, shouldUseCpuNoiseChanged: (state, action: PayloadAction) => { @@ -308,7 +307,7 @@ export const generationSlice = createSlice({ if (defaultModel && !state.model) { const [base_model, model_type, model_name] = defaultModel.split('/'); - const result = zMainModel.safeParse({ + const result = zParameterModel.safeParse({ model_name, base_model, model_type, diff --git a/invokeai/frontend/web/src/features/parameters/types/constants.ts b/invokeai/frontend/web/src/features/parameters/types/constants.ts index 4494d235afb..2d9fa62a794 100644 --- a/invokeai/frontend/web/src/features/parameters/types/constants.ts +++ b/invokeai/frontend/web/src/features/parameters/types/constants.ts @@ -1,5 +1,9 @@ -import { components } from 'services/api/schema'; +import { SchedulerField } from 'features/nodes/types/common'; +import { LoRAModelFormat } from 'services/api/types'; +/** + * Mapping of model type to human readable name + */ export const MODEL_TYPE_MAP = { any: 'Any', 'sd-1': 'Stable Diffusion 1.x', @@ -8,6 +12,9 @@ export const MODEL_TYPE_MAP = { 'sdxl-refiner': 'Stable Diffusion XL Refiner', }; +/** + * Mapping of model type to (short) human readable name + */ export const MODEL_TYPE_SHORT_MAP = { any: 'Any', 'sd-1': 'SD1', @@ -16,7 +23,10 @@ export const MODEL_TYPE_SHORT_MAP = { 'sdxl-refiner': 'SDXLR', }; -export const clipSkipMap = { +/** + * Mapping of model type to CLIP skip parameter constraints + */ +export const CLIP_SKIP_MAP = { any: { maxClip: 0, markers: [], @@ -39,11 +49,41 @@ export const clipSkipMap = { }, }; -type LoRAModelFormatMap = { - [key in components['schemas']['LoRAModelFormat']]: string; -}; - -export const LORA_MODEL_FORMAT_MAP: LoRAModelFormatMap = { +/** + * Mapping of LoRA format to human readable name + */ +export const LORA_MODEL_FORMAT_MAP: { + [key in LoRAModelFormat]: string; +} = { lycoris: 'LyCORIS', diffusers: 'Diffusers', }; + +/** + * Mapping of schedulers to human readable name + */ +export const SCHEDULER_LABEL_MAP: Record = { + euler: 'Euler', + deis: 'DEIS', + ddim: 'DDIM', + ddpm: 'DDPM', + dpmpp_sde: 'DPM++ SDE', + dpmpp_2s: 'DPM++ 2S', + dpmpp_2m: 'DPM++ 2M', + dpmpp_2m_sde: 'DPM++ 2M SDE', + heun: 'Heun', + kdpm_2: 'KDPM 2', + lms: 'LMS', + pndm: 'PNDM', + unipc: 'UniPC', + euler_k: 'Euler Karras', + dpmpp_sde_k: 'DPM++ SDE Karras', + dpmpp_2s_k: 'DPM++ 2S Karras', + dpmpp_2m_k: 'DPM++ 2M Karras', + dpmpp_2m_sde_k: 'DPM++ 2M SDE Karras', + heun_k: 'Heun Karras', + lms_k: 'LMS Karras', + euler_a: 'Euler Ancestral', + kdpm_2_a: 'KDPM 2 Ancestral', + lcm: 'LCM', +}; diff --git a/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts b/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts index ec3f9baba1c..a96e8af002f 100644 --- a/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts +++ b/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts @@ -1,522 +1,269 @@ import { NUMPY_RAND_MAX } from 'app/constants'; +import { + zControlNetModelField, + zIPAdapterModelField, + zLoRAModelField, + zMainOrONNXModelField, + zSDXLRefinerModelField, + zSchedulerField, + zT2IAdapterModelField, + zVAEModelField, +} from 'features/nodes/types/common'; import { z } from 'zod'; /** - * These zod schemas should match the pydantic node schemas. + * Schemas, types and type guards for parameters. * - * Parameters only need schemas if we want to recall them from metadata. + * Parameters need schemas if we want to recall them from metadata or some untrusted source. * * Each parameter needs: * - a zod schema * - a type alias, inferred from the zod schema - * - a combo validation/type guard function, which returns true if the value is valid + * - a combo validation/type guard function, which returns true if the value is valid, should + * simply be the zod schema's safeParse function */ -/** - * Zod schema for positive prompt parameter - */ -export const zPositivePrompt = z.string(); -/** - * Type alias for positive prompt parameter, inferred from its zod schema - */ -export type PositivePromptParam = z.infer; -/** - * Validates/type-guards a value as a positive prompt parameter - */ -export const isValidPositivePrompt = ( +// #region Positive prompt +export const zParameterPositivePrompt = z.string(); +export type ParameterPositivePrompt = z.infer; +export const isParameterPositivePrompt = ( val: unknown -): val is PositivePromptParam => zPositivePrompt.safeParse(val).success; +): val is ParameterPositivePrompt => + zParameterPositivePrompt.safeParse(val).success; +// #endregion -/** - * Zod schema for negative prompt parameter - */ -export const zNegativePrompt = z.string(); -/** - * Type alias for negative prompt parameter, inferred from its zod schema - */ -export type NegativePromptParam = z.infer; -/** - * Validates/type-guards a value as a negative prompt parameter - */ -export const isValidNegativePrompt = ( +// #region Negative prompt +export const zParameterNegativePrompt = z.string(); +export type ParameterNegativePrompt = z.infer; +export const isParameterNegativePrompt = ( val: unknown -): val is NegativePromptParam => zNegativePrompt.safeParse(val).success; +): val is ParameterNegativePrompt => + zParameterNegativePrompt.safeParse(val).success; +// #endregion -/** - * Zod schema for SDXL positive style prompt parameter - */ -export const zPositiveStylePromptSDXL = z.string(); -/** - * Type alias for SDXL positive style prompt parameter, inferred from its zod schema - */ -export type PositiveStylePromptSDXLParam = z.infer< - typeof zPositiveStylePromptSDXL +// #region Positive style prompt (SDXL) +export const zParameterPositiveStylePromptSDXL = z.string(); +export type ParameterPositiveStylePromptSDXL = z.infer< + typeof zParameterPositiveStylePromptSDXL >; -/** - * Validates/type-guards a value as a SDXL positive style prompt parameter - */ -export const isValidSDXLPositiveStylePrompt = ( +export const isParameterPositiveStylePromptSDXL = ( val: unknown -): val is PositiveStylePromptSDXLParam => - zPositiveStylePromptSDXL.safeParse(val).success; +): val is ParameterPositiveStylePromptSDXL => + zParameterPositiveStylePromptSDXL.safeParse(val).success; +// #endregion -/** - * Zod schema for SDXL negative style prompt parameter - */ -export const zNegativeStylePromptSDXL = z.string(); -/** - * Type alias for SDXL negative style prompt parameter, inferred from its zod schema - */ -export type NegativeStylePromptSDXLParam = z.infer< - typeof zNegativeStylePromptSDXL +// #region Positive style prompt (SDXL) +export const zParameterNegativeStylePromptSDXL = z.string(); +export type ParameterNegativeStylePromptSDXL = z.infer< + typeof zParameterNegativeStylePromptSDXL >; -/** - * Validates/type-guards a value as a SDXL negative style prompt parameter - */ -export const isValidSDXLNegativeStylePrompt = ( +export const isParameterNegativeStylePromptSDXL = ( val: unknown -): val is NegativeStylePromptSDXLParam => - zNegativeStylePromptSDXL.safeParse(val).success; +): val is ParameterNegativeStylePromptSDXL => + zParameterNegativeStylePromptSDXL.safeParse(val).success; +// #endregion -/** - * Zod schema for steps parameter - */ -export const zSteps = z.number().int().min(1); -/** - * Type alias for steps parameter, inferred from its zod schema - */ -export type StepsParam = z.infer; -/** - * Validates/type-guards a value as a steps parameter - */ -export const isValidSteps = (val: unknown): val is StepsParam => - zSteps.safeParse(val).success; +// #region Steps +export const zParameterSteps = z.number().int().min(1); +export type ParameterSteps = z.infer; +export const isParameterSteps = (val: unknown): val is ParameterSteps => + zParameterSteps.safeParse(val).success; +// #endregion -/** - * Zod schema for CFG scale parameter - */ -export const zCfgScale = z.number().min(1); -/** - * Type alias for CFG scale parameter, inferred from its zod schema - */ -export type CfgScaleParam = z.infer; -/** - * Validates/type-guards a value as a CFG scale parameter - */ -export const isValidCfgScale = (val: unknown): val is CfgScaleParam => - zCfgScale.safeParse(val).success; +// #region CFG scale parameter +export const zParameterCFGScale = z.number().min(1); +export type ParameterCFGScale = z.infer; +export const isParameterCFGScale = (val: unknown): val is ParameterCFGScale => + zParameterCFGScale.safeParse(val).success; +// #endregion -/** - * Zod schema for scheduler parameter - */ -export const zScheduler = z.enum([ - 'euler', - 'deis', - 'ddim', - 'ddpm', - 'dpmpp_2s', - 'dpmpp_2m', - 'dpmpp_2m_sde', - 'dpmpp_sde', - 'heun', - 'kdpm_2', - 'lms', - 'pndm', - 'unipc', - 'euler_k', - 'dpmpp_2s_k', - 'dpmpp_2m_k', - 'dpmpp_2m_sde_k', - 'dpmpp_sde_k', - 'heun_k', - 'lms_k', - 'euler_a', - 'kdpm_2_a', - 'lcm', -]); -/** - * Type alias for scheduler parameter, inferred from its zod schema - */ -export type SchedulerParam = z.infer; -/** - * Validates/type-guards a value as a scheduler parameter - */ -export const isValidScheduler = (val: unknown): val is SchedulerParam => - zScheduler.safeParse(val).success; +// #region Scheduler +export const zParameterScheduler = zSchedulerField; +export type ParameterScheduler = z.infer; +export const isParameterScheduler = (val: unknown): val is ParameterScheduler => + zParameterScheduler.safeParse(val).success; +// #endregion -export const SCHEDULER_LABEL_MAP: Record = { - euler: 'Euler', - deis: 'DEIS', - ddim: 'DDIM', - ddpm: 'DDPM', - dpmpp_sde: 'DPM++ SDE', - dpmpp_2s: 'DPM++ 2S', - dpmpp_2m: 'DPM++ 2M', - dpmpp_2m_sde: 'DPM++ 2M SDE', - heun: 'Heun', - kdpm_2: 'KDPM 2', - lms: 'LMS', - pndm: 'PNDM', - unipc: 'UniPC', - euler_k: 'Euler Karras', - dpmpp_sde_k: 'DPM++ SDE Karras', - dpmpp_2s_k: 'DPM++ 2S Karras', - dpmpp_2m_k: 'DPM++ 2M Karras', - dpmpp_2m_sde_k: 'DPM++ 2M SDE Karras', - heun_k: 'Heun Karras', - lms_k: 'LMS Karras', - euler_a: 'Euler Ancestral', - kdpm_2_a: 'KDPM 2 Ancestral', - lcm: 'LCM', -}; +// #region seed +export const zParameterSeed = z.number().int().min(0).max(NUMPY_RAND_MAX); +export type ParameterSeed = z.infer; +export const isParameterSeed = (val: unknown): val is ParameterSeed => + zParameterSeed.safeParse(val).success; +// #endregion -/** - * Zod schema for seed parameter - */ -export const zSeed = z.number().int().min(0).max(NUMPY_RAND_MAX); -/** - * Type alias for seed parameter, inferred from its zod schema - */ -export type SeedParam = z.infer; -/** - * Validates/type-guards a value as a seed parameter - */ -export const isValidSeed = (val: unknown): val is SeedParam => - zSeed.safeParse(val).success; - -/** - * Zod schema for width parameter - */ -export const zWidth = z.number().multipleOf(8).min(64); -/** - * Type alias for width parameter, inferred from its zod schema - */ -export type WidthParam = z.infer; -/** - * Validates/type-guards a value as a width parameter - */ -export const isValidWidth = (val: unknown): val is WidthParam => - zWidth.safeParse(val).success; +// #region Width +export const zParameterWidth = z.number().multipleOf(8).min(64); +export type ParameterWidth = z.infer; +export const isParameterWidth = (val: unknown): val is ParameterWidth => + zParameterWidth.safeParse(val).success; +// #endregion -/** - * Zod schema for height parameter - */ -export const zHeight = z.number().multipleOf(8).min(64); -/** - * Type alias for height parameter, inferred from its zod schema - */ -export type HeightParam = z.infer; -/** - * Validates/type-guards a value as a height parameter - */ -export const isValidHeight = (val: unknown): val is HeightParam => - zHeight.safeParse(val).success; - -/** - * Zod schema for resolution parameter - */ -export const zResolution = z.tuple([zWidth, zHeight]); -/** - * Type alias for resolution parameter, inferred from its zod schema - */ -export type ResolutionParam = z.infer; +// #region Height +export const zParameterHeight = zParameterWidth; +export type ParameterHeight = z.infer; +export const isParameterHeight = (val: unknown): val is ParameterHeight => + zParameterHeight.safeParse(val).success; +// #endregion -export const zBaseModel = z.enum([ - 'any', - 'sd-1', - 'sd-2', - 'sdxl', - 'sdxl-refiner', +// #region Resolution +export const zParameterResolution = z.tuple([ + zParameterWidth, + zParameterHeight, ]); +export type ParameterResolution = z.infer; +export const iParameterResolution = ( + val: unknown +): val is ParameterResolution => zParameterResolution.safeParse(val).success; +// #endregion -export type BaseModelParam = z.infer; +// #region Model +export const zParameterModel = zMainOrONNXModelField; +export type ParameterModel = z.infer; +export const isParameterModel = (val: unknown): val is ParameterModel => + zParameterModel.safeParse(val).success; +// #endregion -/** - * Zod schema for main model parameter - * TODO: Make this a dynamically generated enum? - */ -export const zMainModel = z.object({ - model_name: z.string().min(1), - base_model: zBaseModel, - model_type: z.literal('main'), -}); -/** - * Type alias for main model parameter, inferred from its zod schema - */ -export type MainModelParam = z.infer; -/** - * Validates/type-guards a value as a main model parameter - */ -export const isValidMainModel = (val: unknown): val is MainModelParam => - zMainModel.safeParse(val).success; - -/** - * Zod schema for SDXL refiner model parameter - * TODO: Make this a dynamically generated enum? - */ -export const zSDXLRefinerModel = z.object({ - model_name: z.string().min(1), - base_model: z.literal('sdxl-refiner'), - model_type: z.literal('main'), -}); -/** - * Type alias for SDXL refiner model parameter, inferred from its zod schema - */ -export type SDXLRefinerModelParam = z.infer; -/** - * Validates/type-guards a value as a SDXL refiner model parameter - */ -export const isValidSDXLRefinerModel = ( +// #region SDXL Refiner Model +export const zParameterSDXLRefinerModel = zSDXLRefinerModelField; +export type ParameterSDXLRefinerModel = z.infer< + typeof zParameterSDXLRefinerModel +>; +export const isParameterSDXLRefinerModel = ( val: unknown -): val is SDXLRefinerModelParam => zSDXLRefinerModel.safeParse(val).success; +): val is ParameterSDXLRefinerModel => + zParameterSDXLRefinerModel.safeParse(val).success; +// #endregion -/** - * Zod schema for Onnx model parameter - * TODO: Make this a dynamically generated enum? - */ -export const zOnnxModel = z.object({ - model_name: z.string().min(1), - base_model: zBaseModel, - model_type: z.literal('onnx'), -}); -/** - * Type alias for Onnx model parameter, inferred from its zod schema - */ -export type OnnxModelParam = z.infer; -/** - * Validates/type-guards a value as a Onnx model parameter - */ -export const isValidOnnxModel = (val: unknown): val is OnnxModelParam => - zOnnxModel.safeParse(val).success; +// #region VAE Model +export const zParameterVAEModel = zVAEModelField; +export type ParameterVAEModel = z.infer; +export const isParameterVAEModel = (val: unknown): val is ParameterVAEModel => + zParameterVAEModel.safeParse(val).success; +// #endregion -export const zMainOrOnnxModel = z.union([zMainModel, zOnnxModel]); +// #region LoRA Model +export const zParameterLoRAModel = zLoRAModelField; +export type ParameterLoRAModel = z.infer; +export const isParameterLoRAModel = (val: unknown): val is ParameterLoRAModel => + zParameterLoRAModel.safeParse(val).success; +// #endregion -/** - * Zod schema for VAE parameter - */ -export const zVaeModel = z.object({ - model_name: z.string().min(1), - base_model: zBaseModel, -}); -/** - * Type alias for model parameter, inferred from its zod schema - */ -export type VaeModelParam = z.infer; -/** - * Validates/type-guards a value as a model parameter - */ -export const isValidVaeModel = (val: unknown): val is VaeModelParam => - zVaeModel.safeParse(val).success; -/** - * Zod schema for LoRA - */ -export const zLoRAModel = z.object({ - model_name: z.string().min(1), - base_model: zBaseModel, -}); -/** - * Type alias for model parameter, inferred from its zod schema - */ -export type LoRAModelParam = z.infer; -/** - * Validates/type-guards a value as a model parameter - */ -export const isValidLoRAModel = (val: unknown): val is LoRAModelParam => - zLoRAModel.safeParse(val).success; -/** - * Zod schema for ControlNet models - */ -export const zControlNetModel = z.object({ - model_name: z.string().min(1), - base_model: zBaseModel, -}); -/** - * Type alias for model parameter, inferred from its zod schema - */ -export type ControlNetModelParam = z.infer; -/** - * Validates/type-guards a value as a model parameter - */ -export const isValidControlNetModel = ( +// #region ControlNet Model +export const zParameterControlNetModel = zControlNetModelField; +export type ParameterControlNetModel = z.infer; +export const isParameterControlNetModel = ( val: unknown -): val is ControlNetModelParam => zControlNetModel.safeParse(val).success; -/** - * Zod schema for IP-Adapter models - */ -export const zIPAdapterModel = z.object({ - model_name: z.string().min(1), - base_model: zBaseModel, -}); -/** - * Type alias for model parameter, inferred from its zod schema - */ -export type IPAdapterModelParam = z.infer; -/** - * Zod schema for T2I-Adapter models - */ -export const zT2IAdapterModel = z.object({ - model_name: z.string().min(1), - base_model: zBaseModel, -}); -export const isValidT2IAdapterModel = ( +): val is ParameterControlNetModel => + zParameterControlNetModel.safeParse(val).success; +// #endregion + +// #region IP Adapter Model +export const zParameterIPAdapterModel = zIPAdapterModelField; +export type ParameterIPAdapterModel = z.infer; +export const isParameterIPAdapterModel = ( val: unknown -): val is T2IAdapterModelParam => zT2IAdapterModel.safeParse(val).success; +): val is ParameterIPAdapterModel => + zParameterIPAdapterModel.safeParse(val).success; +// #endregion -/** - * Type alias for model parameter, inferred from its zod schema - */ -export type T2IAdapterModelParam = z.infer; -/** - * Zod schema for l2l strength parameter - */ -/** - * Validates/type-guards a value as a model parameter - */ -export const isValidIPAdapterModel = ( +// #region T2I Adapter Model +export const zParameterT2IAdapterModel = zT2IAdapterModelField; +export type ParameterT2IAdapterModel = z.infer< + typeof zParameterT2IAdapterModel +>; +export const isParameterT2IAdapterModel = ( val: unknown -): val is IPAdapterModelParam => zIPAdapterModel.safeParse(val).success; -export const zStrength = z.number().min(0).max(1); -/** - * Type alias for l2l strength parameter, inferred from its zod schema - */ -export type StrengthParam = z.infer; -/** - * Validates/type-guards a value as a l2l strength parameter - */ -export const isValidStrength = (val: unknown): val is StrengthParam => - zStrength.safeParse(val).success; +): val is ParameterT2IAdapterModel => + zParameterT2IAdapterModel.safeParse(val).success; +// #endregion -/** - * Zod schema for a precision parameter - */ -export const zPrecision = z.enum(['fp16', 'fp32']); -/** - * Type alias for precision parameter, inferred from its zod schema - */ -export type PrecisionParam = z.infer; -/** - * Validates/type-guards a value as a precision parameter - */ -export const isValidPrecision = (val: unknown): val is PrecisionParam => - zPrecision.safeParse(val).success; +// #region Strength (l2l strength) +export const zParameterStrength = z.number().min(0).max(1); +export type ParameterStrength = z.infer; +export const isParameterStrength = (val: unknown): val is ParameterStrength => + zParameterStrength.safeParse(val).success; +// #endregion -/** - * Zod schema for a high resolution fix method parameter. - */ -export const zHrfMethod = z.enum(['ESRGAN', 'bilinear']); -/** - * Type alias for high resolution fix method parameter, inferred from its zod schema - */ -export type HrfMethodParam = z.infer; -/** - * Validates/type-guards a value as a high resolution fix method parameter - */ -export const isValidHrfMethod = (val: unknown): val is HrfMethodParam => - zHrfMethod.safeParse(val).success; +// #region Precision +export const zParameterPrecision = z.enum(['fp16', 'fp32']); +export type ParameterPrecision = z.infer; +export const isParameterPrecision = (val: unknown): val is ParameterPrecision => + zParameterPrecision.safeParse(val).success; +// #endregion -/** - * Zod schema for SDXL refiner positive aesthetic score parameter - */ -export const zSDXLRefinerPositiveAestheticScore = z.number().min(1).max(10); -/** - * Type alias for SDXL refiner aesthetic positive score parameter, inferred from its zod schema - */ -export type SDXLRefinerPositiveAestheticScoreParam = z.infer< - typeof zSDXLRefinerPositiveAestheticScore +// #region HRF Method +export const zParameterHRFMethod = z.enum(['ESRGAN', 'bilinear']); +export type ParameterHRFMethod = z.infer; +export const isParameterHRFMethod = (val: unknown): val is ParameterHRFMethod => + zParameterHRFMethod.safeParse(val).success; +// #endregion + +// #region HRF Enabled +export const zParameterHRFEnabled = z.boolean(); +export type ParameterHRFEnabled = z.infer; +export const isParameterHRFEnabled = (val: unknown): val is boolean => + zParameterHRFEnabled.safeParse(val).success && + val !== null && + val !== undefined; +// #endregion + +// #region SDXL Refiner Positive Aesthetic Score +export const zParameterSDXLRefinerPositiveAestheticScore = z + .number() + .min(1) + .max(10); +export type ParameterSDXLRefinerPositiveAestheticScore = z.infer< + typeof zParameterSDXLRefinerPositiveAestheticScore >; -/** - * Validates/type-guards a value as a SDXL refiner positive aesthetic score parameter - */ -export const isValidSDXLRefinerPositiveAestheticScore = ( +export const isParameterSDXLRefinerPositiveAestheticScore = ( val: unknown -): val is SDXLRefinerPositiveAestheticScoreParam => - zSDXLRefinerPositiveAestheticScore.safeParse(val).success; +): val is ParameterSDXLRefinerPositiveAestheticScore => + zParameterSDXLRefinerPositiveAestheticScore.safeParse(val).success; +// #endregion -/** - * Zod schema for SDXL refiner negative aesthetic score parameter - */ -export const zSDXLRefinerNegativeAestheticScore = z.number().min(1).max(10); -/** - * Type alias for SDXL refiner aesthetic negative score parameter, inferred from its zod schema - */ -export type SDXLRefinerNegativeAestheticScoreParam = z.infer< - typeof zSDXLRefinerNegativeAestheticScore +// #region SDXL Refiner Negative Aesthetic Score +export const zParameterSDXLRefinerNegativeAestheticScore = + zParameterSDXLRefinerPositiveAestheticScore; +export type ParameterSDXLRefinerNegativeAestheticScore = z.infer< + typeof zParameterSDXLRefinerNegativeAestheticScore >; -/** - * Validates/type-guards a value as a SDXL refiner negative aesthetic score parameter - */ -export const isValidSDXLRefinerNegativeAestheticScore = ( +export const isParameterSDXLRefinerNegativeAestheticScore = ( val: unknown -): val is SDXLRefinerNegativeAestheticScoreParam => - zSDXLRefinerNegativeAestheticScore.safeParse(val).success; +): val is ParameterSDXLRefinerNegativeAestheticScore => + zParameterSDXLRefinerNegativeAestheticScore.safeParse(val).success; +// #endregion -/** - * Zod schema for SDXL start parameter - */ -export const zSDXLRefinerstart = z.number().min(0).max(1); -/** - * Type alias for SDXL start, inferred from its zod schema - */ -export type SDXLRefinerStartParam = z.infer; -/** - * Validates/type-guards a value as a SDXL refiner aesthetic score parameter - */ -export const isValidSDXLRefinerStart = ( +// #region SDXL Refiner Start +export const zParameterSDXLRefinerStart = z.number().min(0).max(1); +export type ParameterSDXLRefinerStart = z.infer< + typeof zParameterSDXLRefinerStart +>; +export const isParameterSDXLRefinerStart = ( val: unknown -): val is SDXLRefinerStartParam => zSDXLRefinerstart.safeParse(val).success; +): val is ParameterSDXLRefinerStart => + zParameterSDXLRefinerStart.safeParse(val).success; +// #endregion -/** - * Zod schema for a mask blur method parameter - */ -export const zMaskBlurMethod = z.enum(['box', 'gaussian']); -/** - * Type alias for mask blur method parameter, inferred from its zod schema - */ -export type MaskBlurMethodParam = z.infer; -/** - * Validates/type-guards a value as a mask blur method parameter - */ -export const isValidMaskBlurMethod = ( +// #region Mask Blur Method +export const zParameterMaskBlurMethod = z.enum(['box', 'gaussian']); +export type ParameterMaskBlurMethod = z.infer; +export const isParameterMaskBlurMethod = ( val: unknown -): val is MaskBlurMethodParam => zMaskBlurMethod.safeParse(val).success; +): val is ParameterMaskBlurMethod => + zParameterMaskBlurMethod.safeParse(val).success; +// #endregion -/** - * Zod schema for a Canvas Coherence Mode method parameter - */ -export const zCanvasCoherenceMode = z.enum(['unmasked', 'mask', 'edge']); -/** - * Type alias for Canvas Coherence Mode parameter, inferred from its zod schema - */ -export type CanvasCoherenceModeParam = z.infer; -/** - * Validates/type-guards a value as a mask blur method parameter - */ -export const isValidCoherenceModeParam = ( +// #region Canvas Coherence Mode +export const zParameterCanvasCoherenceMode = z.enum([ + 'unmasked', + 'mask', + 'edge', +]); +export type ParameterCanvasCoherenceMode = z.infer< + typeof zParameterCanvasCoherenceMode +>; +export const isParameterCanvasCoherenceMode = ( val: unknown -): val is CanvasCoherenceModeParam => - zCanvasCoherenceMode.safeParse(val).success; - -/** - * Zod schema for a boolean. - */ -export const zBoolean = z.boolean(); - -/** - * Validates/type-guards a value as a boolean parameter - */ -export const isValidBoolean = (val: unknown): val is boolean => - zBoolean.safeParse(val).success && val !== null && val !== undefined; - -// /** -// * Zod schema for BaseModelType -// */ -// export const zBaseModelType = z.enum(['sd-1', 'sd-2']); -// /** -// * Type alias for base model type, inferred from its zod schema. Should be identical to the type alias from OpenAPI. -// */ -// export type BaseModelType = z.infer; -// /** -// * Validates/type-guards a value as a base model type -// */ -// export const isValidBaseModelType = (val: unknown): val is BaseModelType => -// zBaseModelType.safeParse(val).success; +): val is ParameterCanvasCoherenceMode => + zParameterCanvasCoherenceMode.safeParse(val).success; +// #endregion diff --git a/invokeai/frontend/web/src/features/parameters/util/modelIdToControlNetModelParam.ts b/invokeai/frontend/web/src/features/parameters/util/modelIdToControlNetModelParam.ts index 30e6fdcd3dc..d823edbce22 100644 --- a/invokeai/frontend/web/src/features/parameters/util/modelIdToControlNetModelParam.ts +++ b/invokeai/frontend/web/src/features/parameters/util/modelIdToControlNetModelParam.ts @@ -1,5 +1,5 @@ import { logger } from 'app/logging/logger'; -import { zControlNetModel } from 'features/parameters/types/parameterSchemas'; +import { zParameterControlNetModel } from 'features/parameters/types/parameterSchemas'; import { ControlNetModelField } from 'services/api/types'; export const modelIdToControlNetModelParam = ( @@ -8,7 +8,7 @@ export const modelIdToControlNetModelParam = ( const log = logger('models'); const [base_model, _model_type, model_name] = controlNetModelId.split('/'); - const result = zControlNetModel.safeParse({ + const result = zParameterControlNetModel.safeParse({ base_model, model_name, }); diff --git a/invokeai/frontend/web/src/features/parameters/util/modelIdToIPAdapterModelParams.ts b/invokeai/frontend/web/src/features/parameters/util/modelIdToIPAdapterModelParams.ts index 4d580465453..f3ccce47dfd 100644 --- a/invokeai/frontend/web/src/features/parameters/util/modelIdToIPAdapterModelParams.ts +++ b/invokeai/frontend/web/src/features/parameters/util/modelIdToIPAdapterModelParams.ts @@ -1,5 +1,5 @@ import { logger } from 'app/logging/logger'; -import { zIPAdapterModel } from 'features/parameters/types/parameterSchemas'; +import { zParameterIPAdapterModel } from 'features/parameters/types/parameterSchemas'; import { IPAdapterModelField } from 'services/api/types'; export const modelIdToIPAdapterModelParam = ( @@ -8,7 +8,7 @@ export const modelIdToIPAdapterModelParam = ( const log = logger('models'); const [base_model, _model_type, model_name] = ipAdapterModelId.split('/'); - const result = zIPAdapterModel.safeParse({ + const result = zParameterIPAdapterModel.safeParse({ base_model, model_name, }); diff --git a/invokeai/frontend/web/src/features/parameters/util/modelIdToLoRAModelParam.ts b/invokeai/frontend/web/src/features/parameters/util/modelIdToLoRAModelParam.ts index bf4c6454fbc..abe0e9e58f8 100644 --- a/invokeai/frontend/web/src/features/parameters/util/modelIdToLoRAModelParam.ts +++ b/invokeai/frontend/web/src/features/parameters/util/modelIdToLoRAModelParam.ts @@ -1,14 +1,17 @@ import { logger } from 'app/logging/logger'; -import { LoRAModelParam, zLoRAModel } from '../types/parameterSchemas'; +import { + ParameterLoRAModel, + zParameterLoRAModel, +} from '../types/parameterSchemas'; export const modelIdToLoRAModelParam = ( loraModelId: string -): LoRAModelParam | undefined => { +): ParameterLoRAModel | undefined => { const log = logger('models'); const [base_model, _model_type, model_name] = loraModelId.split('/'); - const result = zLoRAModel.safeParse({ + const result = zParameterLoRAModel.safeParse({ base_model, model_name, }); diff --git a/invokeai/frontend/web/src/features/parameters/util/modelIdToMainModelParam.ts b/invokeai/frontend/web/src/features/parameters/util/modelIdToMainModelParam.ts index 78a3bcc515e..9500546f84f 100644 --- a/invokeai/frontend/web/src/features/parameters/util/modelIdToMainModelParam.ts +++ b/invokeai/frontend/web/src/features/parameters/util/modelIdToMainModelParam.ts @@ -1,17 +1,16 @@ import { logger } from 'app/logging/logger'; import { - MainModelParam, - OnnxModelParam, - zMainOrOnnxModel, + ParameterModel, + zParameterModel, } from 'features/parameters/types/parameterSchemas'; export const modelIdToMainModelParam = ( mainModelId: string -): OnnxModelParam | MainModelParam | undefined => { +): ParameterModel | undefined => { const log = logger('models'); const [base_model, model_type, model_name] = mainModelId.split('/'); - const result = zMainOrOnnxModel.safeParse({ + const result = zParameterModel.safeParse({ base_model, model_name, model_type, diff --git a/invokeai/frontend/web/src/features/parameters/util/modelIdToSDXLRefinerModelParam.ts b/invokeai/frontend/web/src/features/parameters/util/modelIdToSDXLRefinerModelParam.ts index 780ac564593..5ed185ef8ed 100644 --- a/invokeai/frontend/web/src/features/parameters/util/modelIdToSDXLRefinerModelParam.ts +++ b/invokeai/frontend/web/src/features/parameters/util/modelIdToSDXLRefinerModelParam.ts @@ -1,16 +1,16 @@ import { logger } from 'app/logging/logger'; import { - SDXLRefinerModelParam, - zSDXLRefinerModel, + ParameterSDXLRefinerModel, + zParameterSDXLRefinerModel, } from 'features/parameters/types/parameterSchemas'; export const modelIdToSDXLRefinerModelParam = ( mainModelId: string -): SDXLRefinerModelParam | undefined => { +): ParameterSDXLRefinerModel | undefined => { const log = logger('models'); const [base_model, model_type, model_name] = mainModelId.split('/'); - const result = zSDXLRefinerModel.safeParse({ + const result = zParameterSDXLRefinerModel.safeParse({ base_model, model_name, model_type, diff --git a/invokeai/frontend/web/src/features/parameters/util/modelIdToT2IAdapterModelParam.ts b/invokeai/frontend/web/src/features/parameters/util/modelIdToT2IAdapterModelParam.ts index 95f1a3f25af..3d66ef66e8c 100644 --- a/invokeai/frontend/web/src/features/parameters/util/modelIdToT2IAdapterModelParam.ts +++ b/invokeai/frontend/web/src/features/parameters/util/modelIdToT2IAdapterModelParam.ts @@ -1,5 +1,5 @@ import { logger } from 'app/logging/logger'; -import { zT2IAdapterModel } from 'features/parameters/types/parameterSchemas'; +import { zParameterT2IAdapterModel } from 'features/parameters/types/parameterSchemas'; import { T2IAdapterModelField } from 'services/api/types'; export const modelIdToT2IAdapterModelParam = ( @@ -8,7 +8,7 @@ export const modelIdToT2IAdapterModelParam = ( const log = logger('models'); const [base_model, _model_type, model_name] = t2iAdapterModelId.split('/'); - const result = zT2IAdapterModel.safeParse({ + const result = zParameterT2IAdapterModel.safeParse({ base_model, model_name, }); diff --git a/invokeai/frontend/web/src/features/parameters/util/modelIdToVAEModelParam.ts b/invokeai/frontend/web/src/features/parameters/util/modelIdToVAEModelParam.ts index 1f3908dd47a..a30dbcc12f8 100644 --- a/invokeai/frontend/web/src/features/parameters/util/modelIdToVAEModelParam.ts +++ b/invokeai/frontend/web/src/features/parameters/util/modelIdToVAEModelParam.ts @@ -1,13 +1,16 @@ import { logger } from 'app/logging/logger'; -import { VaeModelParam, zVaeModel } from '../types/parameterSchemas'; +import { + ParameterVAEModel, + zParameterVAEModel, +} from '../types/parameterSchemas'; export const modelIdToVAEModelParam = ( vaeModelId: string -): VaeModelParam | undefined => { +): ParameterVAEModel | undefined => { const log = logger('models'); const [base_model, _model_type, model_name] = vaeModelId.split('/'); - const result = zVaeModel.safeParse({ + const result = zParameterVAEModel.safeParse({ base_model, model_name, }); diff --git a/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerScheduler.tsx b/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerScheduler.tsx index 50400aef9fd..90a3e6eeed7 100644 --- a/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerScheduler.tsx +++ b/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerScheduler.tsx @@ -3,10 +3,8 @@ import { stateSelector } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect'; -import { - SCHEDULER_LABEL_MAP, - SchedulerParam, -} from 'features/parameters/types/parameterSchemas'; +import { ParameterScheduler } from 'features/parameters/types/parameterSchemas'; +import { SCHEDULER_LABEL_MAP } from 'features/parameters/types/constants'; import { setRefinerScheduler } from 'features/sdxl/store/sdxlSlice'; import { map } from 'lodash-es'; import { memo, useCallback } from 'react'; @@ -22,7 +20,7 @@ const selector = createSelector( const data = map(SCHEDULER_LABEL_MAP, (label, name) => ({ value: name, label: label, - group: enabledSchedulers.includes(name as SchedulerParam) + group: enabledSchedulers.includes(name as ParameterScheduler) ? 'Favorites' : undefined, })).sort((a, b) => a.label.localeCompare(b.label)); @@ -45,7 +43,7 @@ const ParamSDXLRefinerScheduler = () => { if (!v) { return; } - dispatch(setRefinerScheduler(v as SchedulerParam)); + dispatch(setRefinerScheduler(v as ParameterScheduler)); }, [dispatch] ); diff --git a/invokeai/frontend/web/src/features/sdxl/store/sdxlSlice.ts b/invokeai/frontend/web/src/features/sdxl/store/sdxlSlice.ts index 73f5779a52f..861fa91e23b 100644 --- a/invokeai/frontend/web/src/features/sdxl/store/sdxlSlice.ts +++ b/invokeai/frontend/web/src/features/sdxl/store/sdxlSlice.ts @@ -1,21 +1,21 @@ import { PayloadAction, createSlice } from '@reduxjs/toolkit'; import { - NegativeStylePromptSDXLParam, - PositiveStylePromptSDXLParam, - SDXLRefinerModelParam, - SchedulerParam, + ParameterNegativeStylePromptSDXL, + ParameterPositiveStylePromptSDXL, + ParameterSDXLRefinerModel, + ParameterScheduler, } from 'features/parameters/types/parameterSchemas'; type SDXLState = { - positiveStylePrompt: PositiveStylePromptSDXLParam; - negativeStylePrompt: NegativeStylePromptSDXLParam; + positiveStylePrompt: ParameterPositiveStylePromptSDXL; + negativeStylePrompt: ParameterNegativeStylePromptSDXL; shouldConcatSDXLStylePrompt: boolean; shouldUseSDXLRefiner: boolean; sdxlImg2ImgDenoisingStrength: number; - refinerModel: SDXLRefinerModelParam | null; + refinerModel: ParameterSDXLRefinerModel | null; refinerSteps: number; refinerCFGScale: number; - refinerScheduler: SchedulerParam; + refinerScheduler: ParameterScheduler; refinerPositiveAestheticScore: number; refinerNegativeAestheticScore: number; refinerStart: number; @@ -57,7 +57,7 @@ const sdxlSlice = createSlice({ }, refinerModelChanged: ( state, - action: PayloadAction + action: PayloadAction ) => { state.refinerModel = action.payload; }, @@ -67,7 +67,7 @@ const sdxlSlice = createSlice({ setRefinerCFGScale: (state, action: PayloadAction) => { state.refinerCFGScale = action.payload; }, - setRefinerScheduler: (state, action: PayloadAction) => { + setRefinerScheduler: (state, action: PayloadAction) => { state.refinerScheduler = action.payload; }, setRefinerPositiveAestheticScore: ( diff --git a/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsSchedulers.tsx b/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsSchedulers.tsx index 0be58a8815f..270d9aed2c7 100644 --- a/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsSchedulers.tsx +++ b/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsSchedulers.tsx @@ -1,10 +1,8 @@ import { RootState } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect'; -import { - SCHEDULER_LABEL_MAP, - SchedulerParam, -} from 'features/parameters/types/parameterSchemas'; +import { ParameterScheduler } from 'features/parameters/types/parameterSchemas'; +import { SCHEDULER_LABEL_MAP } from 'features/parameters/types/constants'; import { favoriteSchedulersChanged } from 'features/ui/store/uiSlice'; import { map } from 'lodash-es'; import { useCallback } from 'react'; @@ -26,7 +24,7 @@ export default function SettingsSchedulers() { const handleChange = useCallback( (v: string[]) => { - dispatch(favoriteSchedulersChanged(v as SchedulerParam[])); + dispatch(favoriteSchedulersChanged(v as ParameterScheduler[])); }, [dispatch] ); diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ImageToImage/ImageToImageTabCoreParameters.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ImageToImage/ImageToImageTabCoreParameters.tsx index 2668062da63..46ea9fb0512 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ImageToImage/ImageToImageTabCoreParameters.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ImageToImage/ImageToImageTabCoreParameters.tsx @@ -9,7 +9,7 @@ import ParamSteps from 'features/parameters/components/Parameters/Core/ParamStep import ImageToImageFit from 'features/parameters/components/Parameters/ImageToImage/ImageToImageFit'; import ImageToImageStrength from 'features/parameters/components/Parameters/ImageToImage/ImageToImageStrength'; import ParamSeedFull from 'features/parameters/components/Parameters/Seed/ParamSeedFull'; -import { useCoreParametersCollapseLabel } from 'features/parameters/util/useCoreParametersCollapseLabel'; +import { useCoreParametersCollapseLabel } from 'features/parameters/hooks/useCoreParametersCollapseLabel'; import { memo } from 'react'; import { useTranslation } from 'react-i18next'; diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/TextToImage/TextToImageTabCoreParameters.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/TextToImage/TextToImageTabCoreParameters.tsx index 3f3cf2db050..29ab63cb1c5 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/TextToImage/TextToImageTabCoreParameters.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/TextToImage/TextToImageTabCoreParameters.tsx @@ -7,7 +7,7 @@ import ParamModelandVAEandScheduler from 'features/parameters/components/Paramet import ParamSize from 'features/parameters/components/Parameters/Core/ParamSize'; import ParamSteps from 'features/parameters/components/Parameters/Core/ParamSteps'; import ParamSeedFull from 'features/parameters/components/Parameters/Seed/ParamSeedFull'; -import { useCoreParametersCollapseLabel } from 'features/parameters/util/useCoreParametersCollapseLabel'; +import { useCoreParametersCollapseLabel } from 'features/parameters/hooks/useCoreParametersCollapseLabel'; import { memo } from 'react'; import { useTranslation } from 'react-i18next'; diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasCoreParameters.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasCoreParameters.tsx index 40a5026d09c..bc86386515b 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasCoreParameters.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasCoreParameters.tsx @@ -8,7 +8,7 @@ import ParamModelandVAEandScheduler from 'features/parameters/components/Paramet import ParamSteps from 'features/parameters/components/Parameters/Core/ParamSteps'; import ImageToImageStrength from 'features/parameters/components/Parameters/ImageToImage/ImageToImageStrength'; import ParamSeedFull from 'features/parameters/components/Parameters/Seed/ParamSeedFull'; -import { useCoreParametersCollapseLabel } from 'features/parameters/util/useCoreParametersCollapseLabel'; +import { useCoreParametersCollapseLabel } from 'features/parameters/hooks/useCoreParametersCollapseLabel'; import { memo } from 'react'; import { useTranslation } from 'react-i18next'; diff --git a/invokeai/frontend/web/src/features/ui/store/uiSlice.ts b/invokeai/frontend/web/src/features/ui/store/uiSlice.ts index 9782d0bfac2..69cfe428271 100644 --- a/invokeai/frontend/web/src/features/ui/store/uiSlice.ts +++ b/invokeai/frontend/web/src/features/ui/store/uiSlice.ts @@ -1,7 +1,7 @@ import type { PayloadAction } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit'; import { initialImageChanged } from 'features/parameters/store/generationSlice'; -import { SchedulerParam } from 'features/parameters/types/parameterSchemas'; +import { ParameterScheduler } from 'features/parameters/types/parameterSchemas'; import { InvokeTabName } from './tabMap'; import { UIState } from './uiTypes'; @@ -50,7 +50,7 @@ export const uiSlice = createSlice({ }, favoriteSchedulersChanged: ( state, - action: PayloadAction + action: PayloadAction ) => { state.favoriteSchedulers = action.payload; }, diff --git a/invokeai/frontend/web/src/features/ui/store/uiTypes.ts b/invokeai/frontend/web/src/features/ui/store/uiTypes.ts index 1b9fee6989e..b5320430546 100644 --- a/invokeai/frontend/web/src/features/ui/store/uiTypes.ts +++ b/invokeai/frontend/web/src/features/ui/store/uiTypes.ts @@ -1,4 +1,4 @@ -import { SchedulerParam } from 'features/parameters/types/parameterSchemas'; +import { ParameterScheduler } from 'features/parameters/types/parameterSchemas'; import { InvokeTabName } from './tabMap'; export type Coordinates = { @@ -23,7 +23,7 @@ export interface UIState { shouldShowProgressInViewer: boolean; shouldShowEmbeddingPicker: boolean; shouldAutoChangeDimensions: boolean; - favoriteSchedulers: SchedulerParam[]; + favoriteSchedulers: ParameterScheduler[]; globalContextMenuCloseTrigger: number; panels: Record; } diff --git a/invokeai/frontend/web/src/services/api/endpoints/images.ts b/invokeai/frontend/web/src/services/api/endpoints/images.ts index 166d00a3dbc..97473b21f25 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/images.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/images.ts @@ -7,7 +7,7 @@ import { IMAGE_CATEGORIES, IMAGE_LIMIT, } from 'features/gallery/store/types'; -import { CoreMetadata, zCoreMetadata } from 'features/nodes/types/types'; +import { CoreMetadata, zCoreMetadata } from 'features/nodes/types/metadata'; import { keyBy } from 'lodash-es'; import { ApiTagDescription, LIST_TAG, api } from '..'; import { components, paths } from '../schema'; diff --git a/invokeai/frontend/web/src/services/api/endpoints/workflows.ts b/invokeai/frontend/web/src/services/api/endpoints/workflows.ts index 1792788d575..b7611cb397f 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/workflows.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/workflows.ts @@ -1,11 +1,11 @@ import { logger } from 'app/logging/logger'; -import { Workflow, zWorkflow } from 'features/nodes/types/types'; +import { WorkflowV2, zWorkflowV2 } from 'features/nodes/types/workflow'; import { api } from '..'; import { paths } from '../schema'; export const workflowsApi = api.injectEndpoints({ endpoints: (build) => ({ - getWorkflow: build.query({ + getWorkflow: build.query({ query: (workflow_id) => `workflows/i/${workflow_id}`, providesTags: (result, error, workflow_id) => [ { type: 'Workflow', id: workflow_id }, @@ -14,7 +14,7 @@ export const workflowsApi = api.injectEndpoints({ response: paths['/api/v1/workflows/i/{workflow_id}']['get']['responses']['200']['content']['application/json'] ) => { if (response) { - const result = zWorkflow.safeParse(response); + const result = zWorkflowV2.safeParse(response); if (result.success) { return result.data; } else { diff --git a/invokeai/frontend/web/src/services/api/guards.ts b/invokeai/frontend/web/src/services/api/guards.ts deleted file mode 100644 index 2893d88e07e..00000000000 --- a/invokeai/frontend/web/src/services/api/guards.ts +++ /dev/null @@ -1,67 +0,0 @@ -import { get, isObject, isString } from 'lodash-es'; -import { - GraphExecutionState, - GraphInvocationOutput, - ImageOutput, - IterateInvocationOutput, - CollectInvocationOutput, - ImageField, - LatentsOutput, - ResourceOrigin, - ImageDTO, - BoardDTO, -} from 'services/api/types'; - -export const isImageDTO = (obj: unknown): obj is ImageDTO => { - return ( - isObject(obj) && - 'image_name' in obj && - isString(obj?.image_name) && - 'thumbnail_url' in obj && - isString(obj?.thumbnail_url) && - 'image_url' in obj && - isString(obj?.image_url) && - 'image_origin' in obj && - isString(obj?.image_origin) && - 'created_at' in obj && - isString(obj?.created_at) - ); -}; - -export const isBoardDTO = (obj: unknown): obj is BoardDTO => { - return ( - isObject(obj) && - 'board_id' in obj && - isString(obj?.board_id) && - 'board_name' in obj && - isString(obj?.board_name) - ); -}; - -export const isImageOutput = ( - output: GraphExecutionState['results'][string] -): output is ImageOutput => output?.type === 'image_output'; - -export const isLatentsOutput = ( - output: GraphExecutionState['results'][string] -): output is LatentsOutput => output?.type === 'latents_output'; - -export const isGraphOutput = ( - output: GraphExecutionState['results'][string] -): output is GraphInvocationOutput => output?.type === 'graph_output'; - -export const isIterateOutput = ( - output: GraphExecutionState['results'][string] -): output is IterateInvocationOutput => output?.type === 'iterate_output'; - -export const isCollectOutput = ( - output: GraphExecutionState['results'][string] -): output is CollectInvocationOutput => output?.type === 'collect_output'; - -export const isResourceOrigin = (t: unknown): t is ResourceOrigin => - isString(t) && ['internal', 'external'].includes(t); - -export const isImageField = (imageField: unknown): imageField is ImageField => - isObject(imageField) && - isString(get(imageField, 'image_name')) && - isResourceOrigin(get(imageField, 'image_origin')); diff --git a/invokeai/frontend/web/src/services/api/schema.d.ts b/invokeai/frontend/web/src/services/api/schema.d.ts index c82a1950287..5d5b382c4c6 100644 --- a/invokeai/frontend/web/src/services/api/schema.d.ts +++ b/invokeai/frontend/web/src/services/api/schema.d.ts @@ -930,6 +930,7 @@ export type components = { /** * Collection * @description The collection of boolean values + * @default [] */ collection?: boolean[]; /** @@ -1310,6 +1311,7 @@ export type components = { /** * Collection * @description The collection, will be provided on execution + * @default [] */ collection?: unknown[]; /** @@ -1581,6 +1583,7 @@ export type components = { /** * Collection * @description The collection of conditioning tensors + * @default [] */ collection?: components["schemas"]["ConditioningField"][]; /** @@ -2893,6 +2896,7 @@ export type components = { /** * Collection * @description The collection of float values + * @default [] */ collection?: number[]; /** @@ -3216,7 +3220,7 @@ export type components = { * @description The nodes in this graph */ nodes?: { - [key: string]: components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["IntegerCollectionInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["BooleanCollectionInvocation"] | components["schemas"]["ONNXTextToLatentsInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["LatentsInvocation"] | components["schemas"]["LatentsCollectionInvocation"] | components["schemas"]["StringCollectionInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["FaceOffInvocation"] | components["schemas"]["FloatInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["StringReplaceInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["StringSplitNegInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["ImageChannelMultiplyInvocation"] | components["schemas"]["BooleanInvocation"] | components["schemas"]["RoundInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CoreMetadataInvocation"] | components["schemas"]["T2IAdapterInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["SeamlessModeInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["SaveImageInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["OnnxModelLoaderInvocation"] | components["schemas"]["SchedulerInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["ColorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MetadataInvocation"] | components["schemas"]["CreateDenoiseMaskInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["CV2InfillInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["RandomFloatInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["MergeMetadataInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["BlankImageInvocation"] | components["schemas"]["ImageChannelOffsetInvocation"] | components["schemas"]["IntegerInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["ONNXPromptInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["ConditioningInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["MetadataItemInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["LinearUIOutputInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["StringJoinInvocation"] | components["schemas"]["LaMaInfillInvocation"] | components["schemas"]["ImageInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["FreeUInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["StringJoinThreeInvocation"] | components["schemas"]["FaceIdentifierInvocation"] | components["schemas"]["FloatCollectionInvocation"] | components["schemas"]["BlendLatentsInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["StringSplitInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["IntegerMathInvocation"] | components["schemas"]["ONNXLatentsToImageInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["FaceMaskInvocation"] | components["schemas"]["IPAdapterInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["SDXLLoraLoaderInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["ConditioningCollectionInvocation"] | components["schemas"]["StringInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["FloatMathInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["FloatToIntegerInvocation"] | components["schemas"]["ColorMapImageProcessorInvocation"]; + [key: string]: components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["BooleanInvocation"] | components["schemas"]["BooleanCollectionInvocation"] | components["schemas"]["RandomFloatInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["FaceMaskInvocation"] | components["schemas"]["FaceIdentifierInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ConditioningInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LatentsInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["CreateDenoiseMaskInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["StringInvocation"] | components["schemas"]["ColorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["MetadataInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["StringReplaceInvocation"] | components["schemas"]["IntegerCollectionInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["LaMaInfillInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["RoundInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["BlendLatentsInvocation"] | components["schemas"]["SDXLLoraLoaderInvocation"] | components["schemas"]["FloatToIntegerInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["FreeUInvocation"] | components["schemas"]["StringCollectionInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["FloatMathInvocation"] | components["schemas"]["ImageInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["CoreMetadataInvocation"] | components["schemas"]["SeamlessModeInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["OnnxModelLoaderInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["SaveImageInvocation"] | components["schemas"]["StringSplitInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["MetadataItemInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["LinearUIOutputInvocation"] | components["schemas"]["ONNXLatentsToImageInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["LatentsCollectionInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["ColorMapImageProcessorInvocation"] | components["schemas"]["BlankImageInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["SchedulerInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ONNXPromptInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["FloatCollectionInvocation"] | components["schemas"]["T2IAdapterInvocation"] | components["schemas"]["IPAdapterInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["IntegerMathInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["FaceOffInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["ImageChannelOffsetInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["IntegerInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["ImageChannelMultiplyInvocation"] | components["schemas"]["ONNXTextToLatentsInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["FloatInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["ConditioningCollectionInvocation"] | components["schemas"]["CV2InfillInvocation"] | components["schemas"]["MergeMetadataInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["StringJoinThreeInvocation"] | components["schemas"]["StringSplitNegInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["StringJoinInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["StepParamEasingInvocation"]; }; /** * Edges @@ -3253,7 +3257,7 @@ export type components = { * @description The results of node executions */ results: { - [key: string]: components["schemas"]["String2Output"] | components["schemas"]["UNetOutput"] | components["schemas"]["IntegerOutput"] | components["schemas"]["ImageOutput"] | components["schemas"]["CollectInvocationOutput"] | components["schemas"]["LatentsCollectionOutput"] | components["schemas"]["ClipSkipInvocationOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["BooleanOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["CLIPOutput"] | components["schemas"]["GraphInvocationOutput"] | components["schemas"]["ColorCollectionOutput"] | components["schemas"]["T2IAdapterOutput"] | components["schemas"]["ColorOutput"] | components["schemas"]["IPAdapterOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["FaceMaskOutput"] | components["schemas"]["StringPosNegOutput"] | components["schemas"]["MetadataItemOutput"] | components["schemas"]["StringCollectionOutput"] | components["schemas"]["SchedulerOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["FaceOffOutput"] | components["schemas"]["LoraLoaderOutput"] | components["schemas"]["SeamlessModeOutput"] | components["schemas"]["VAEOutput"] | components["schemas"]["MetadataOutput"] | components["schemas"]["ConditioningOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["DenoiseMaskOutput"] | components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["ConditioningCollectionOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["ONNXModelLoaderOutput"] | components["schemas"]["SDXLLoraLoaderOutput"]; + [key: string]: components["schemas"]["SeamlessModeOutput"] | components["schemas"]["CollectInvocationOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["MetadataItemOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["GraphInvocationOutput"] | components["schemas"]["BooleanOutput"] | components["schemas"]["SDXLLoraLoaderOutput"] | components["schemas"]["IPAdapterOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["T2IAdapterOutput"] | components["schemas"]["ONNXModelLoaderOutput"] | components["schemas"]["String2Output"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["VAEOutput"] | components["schemas"]["ColorOutput"] | components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["ConditioningCollectionOutput"] | components["schemas"]["FaceMaskOutput"] | components["schemas"]["LatentsCollectionOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["StringCollectionOutput"] | components["schemas"]["IntegerOutput"] | components["schemas"]["UNetOutput"] | components["schemas"]["DenoiseMaskOutput"] | components["schemas"]["ConditioningOutput"] | components["schemas"]["FaceOffOutput"] | components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["ImageOutput"] | components["schemas"]["StringPosNegOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["ColorCollectionOutput"] | components["schemas"]["LoraLoaderOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["CLIPOutput"] | components["schemas"]["MetadataOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["ClipSkipInvocationOutput"] | components["schemas"]["SchedulerOutput"]; }; /** * Errors @@ -4727,6 +4731,7 @@ export type components = { /** * Seed * @description The seed to use for tile generation (omit for random) + * @default 0 */ seed?: number; /** @@ -4761,6 +4766,7 @@ export type components = { /** * Collection * @description The collection of integer values + * @default [] */ collection?: number[]; /** @@ -4940,6 +4946,7 @@ export type components = { /** * Collection * @description The list of items to iterate over + * @default [] */ collection?: unknown[]; /** @@ -6342,6 +6349,7 @@ export type components = { /** * Seed * @description Seed for random number generation + * @default 0 */ seed?: number; /** @@ -7194,6 +7202,7 @@ export type components = { /** * Seed * @description The seed for the RNG (omit for random) + * @default 0 */ seed?: number; /** @@ -8532,6 +8541,7 @@ export type components = { /** * Collection * @description The collection of string values + * @default [] */ collection?: string[]; /** @@ -9529,6 +9539,24 @@ export type components = { * @enum {string} */ invokeai__backend__model_manager__config__SchedulerPredictionType: "epsilon" | "v_prediction" | "sample"; + /** + * FieldKind + * @description The kind of field. + * - `Input`: An input field on a node. + * - `Output`: An output field on a node. + * - `Internal`: A field which is treated as an input, but cannot be used in node definitions. Metadata is + * one example. It is provided to nodes via the WithMetadata class, and we want to reserve the field name + * "metadata" for this on all nodes. `FieldKind` is used to short-circuit the field name validation logic, + * allowing "metadata" for that field. + * - `NodeAttribute`: The field is a node attribute. These are fields which are not inputs or outputs, + * but which are used to store information about the node. For example, the `id` and `type` fields are node + * attributes. + * + * The presence of this in `json_schema_extra["field_kind"]` is used when initializing node schemas on app + * startup, and when generating the OpenAPI schema for the workflow editor. + * @enum {string} + */ + FieldKind: "input" | "output" | "internal" | "node_attribute"; /** * Input * @description The type of input a field accepts. @@ -9538,9 +9566,65 @@ export type components = { * @enum {string} */ Input: "connection" | "direct" | "any"; + /** + * InputFieldJSONSchemaExtra + * @description Extra attributes to be added to input fields and their OpenAPI schema. Used during graph execution, + * and by the workflow editor during schema parsing and UI rendering. + */ + InputFieldJSONSchemaExtra: { + input: components["schemas"]["Input"]; + /** Orig Required */ + orig_required: boolean; + field_kind: components["schemas"]["FieldKind"]; + /** + * Default + * @default null + */ + default: unknown; + /** + * Orig Default + * @default null + */ + orig_default: unknown; + /** + * Ui Hidden + * @default false + */ + ui_hidden: boolean; + /** @default null */ + ui_type: components["schemas"]["UIType"] | null; + /** @default null */ + ui_component: components["schemas"]["UIComponent"] | null; + /** + * Ui Order + * @default null + */ + ui_order: number | null; + /** + * Ui Choice Labels + * @default null + */ + ui_choice_labels: { + [key: string]: string; + } | null; + }; + /** + * OutputFieldJSONSchemaExtra + * @description Extra attributes to be added to input fields and their OpenAPI schema. Used by the workflow editor + * during schema parsing and UI rendering. + */ + OutputFieldJSONSchemaExtra: { + field_kind: components["schemas"]["FieldKind"]; + /** Ui Hidden */ + ui_hidden: boolean; + ui_type: components["schemas"]["UIType"] | null; + /** Ui Order */ + ui_order: number | null; + }; /** * UIComponent - * @description The type of UI component to use for a field, used to override the default components, which are inferred from the field type. + * @description The type of UI component to use for a field, used to override the default components, which are + * inferred from the field type. * @enum {string} */ UIComponent: "none" | "textarea" | "slider"; @@ -9570,71 +9654,55 @@ export type components = { /** * Version * @description The node's version. Should be a valid semver string e.g. "1.0.0" or "3.8.13". - * @default null */ - version: string | null; + version: string; + /** + * Is Custom + * @description Whether or not this is a custom node + * @default false + */ + is_custom: boolean; }; /** * UIType - * @description Type hints for the UI. - * If a field should be provided a data type that does not exactly match the python type of the field, use this to provide the type that should be used instead. See the node development docs for detail on adding a new field type, which involves client-side changes. + * @description Type hints for the UI for situations in which the field type is not enough to infer the correct UI type. + * + * - Model Fields + * The most common node-author-facing use will be for model fields. Internally, there is no difference + * between SD-1, SD-2 and SDXL model fields - they all use the class `MainModelField`. To ensure the + * base-model-specific UI is rendered, use e.g. `ui_type=UIType.SDXLMainModelField` to indicate that + * the field is an SDXL main model field. + * + * - Any Field + * We cannot infer the usage of `typing.Any` via schema parsing, so you *must* use `ui_type=UIType.Any` to + * indicate that the field accepts any type. Use with caution. This cannot be used on outputs. + * + * - Scheduler Field + * Special handling in the UI is needed for this field, which otherwise would be parsed as a plain enum field. + * + * - Internal Fields + * Similar to the Any Field, the `collect` and `iterate` nodes make use of `typing.Any`. To facilitate + * handling these types in the client, we use `UIType._Collection` and `UIType._CollectionItem`. These + * should not be used by node authors. + * + * - DEPRECATED Fields + * These types are deprecated and should not be used by node authors. A warning will be logged if one is + * used, and the type will be ignored. They are included here for backwards compatibility. * @enum {string} */ - UIType: "boolean" | "ColorField" | "ConditioningField" | "ControlField" | "float" | "ImageField" | "integer" | "LatentsField" | "string" | "BooleanCollection" | "ColorCollection" | "ConditioningCollection" | "ControlCollection" | "FloatCollection" | "ImageCollection" | "IntegerCollection" | "LatentsCollection" | "StringCollection" | "BooleanPolymorphic" | "ColorPolymorphic" | "ConditioningPolymorphic" | "ControlPolymorphic" | "FloatPolymorphic" | "ImagePolymorphic" | "IntegerPolymorphic" | "LatentsPolymorphic" | "StringPolymorphic" | "MainModelField" | "SDXLMainModelField" | "SDXLRefinerModelField" | "ONNXModelField" | "VaeModelField" | "LoRAModelField" | "ControlNetModelField" | "IPAdapterModelField" | "UNetField" | "VaeField" | "ClipField" | "Collection" | "CollectionItem" | "enum" | "Scheduler" | "WorkflowField" | "IsIntermediate" | "BoardField" | "Any" | "MetadataItem" | "MetadataItemCollection" | "MetadataItemPolymorphic" | "MetadataDict"; - /** - * _InputField - * @description *DO NOT USE* - * This helper class is used to tell the client about our custom field attributes via OpenAPI - * schema generation, and Typescript type generation from that schema. It serves no functional - * purpose in the backend. - */ - _InputField: { - input: components["schemas"]["Input"]; - /** Ui Hidden */ - ui_hidden: boolean; - ui_type: components["schemas"]["UIType"] | null; - ui_component: components["schemas"]["UIComponent"] | null; - /** Ui Order */ - ui_order: number | null; - /** Ui Choice Labels */ - ui_choice_labels: { - [key: string]: string; - } | null; - /** Item Default */ - item_default: unknown; - }; - /** - * _OutputField - * @description *DO NOT USE* - * This helper class is used to tell the client about our custom field attributes via OpenAPI - * schema generation, and Typescript type generation from that schema. It serves no functional - * purpose in the backend. - */ - _OutputField: { - /** Ui Hidden */ - ui_hidden: boolean; - ui_type: components["schemas"]["UIType"] | null; - /** Ui Order */ - ui_order: number | null; - }; + UIType: "SDXLMainModelField" | "SDXLRefinerModelField" | "ONNXModelField" | "VAEModelField" | "LoRAModelField" | "ControlNetModelField" | "IPAdapterModelField" | "SchedulerField" | "AnyField" | "CollectionField" | "CollectionItemField" | "DEPRECATED_Boolean" | "DEPRECATED_Color" | "DEPRECATED_Conditioning" | "DEPRECATED_Control" | "DEPRECATED_Float" | "DEPRECATED_Image" | "DEPRECATED_Integer" | "DEPRECATED_Latents" | "DEPRECATED_String" | "DEPRECATED_BooleanCollection" | "DEPRECATED_ColorCollection" | "DEPRECATED_ConditioningCollection" | "DEPRECATED_ControlCollection" | "DEPRECATED_FloatCollection" | "DEPRECATED_ImageCollection" | "DEPRECATED_IntegerCollection" | "DEPRECATED_LatentsCollection" | "DEPRECATED_StringCollection" | "DEPRECATED_BooleanPolymorphic" | "DEPRECATED_ColorPolymorphic" | "DEPRECATED_ConditioningPolymorphic" | "DEPRECATED_ControlPolymorphic" | "DEPRECATED_FloatPolymorphic" | "DEPRECATED_ImagePolymorphic" | "DEPRECATED_IntegerPolymorphic" | "DEPRECATED_LatentsPolymorphic" | "DEPRECATED_StringPolymorphic" | "DEPRECATED_MainModel" | "DEPRECATED_UNet" | "DEPRECATED_Vae" | "DEPRECATED_CLIP" | "DEPRECATED_Collection" | "DEPRECATED_CollectionItem" | "DEPRECATED_Enum" | "DEPRECATED_WorkflowField" | "DEPRECATED_IsIntermediate" | "DEPRECATED_BoardField" | "DEPRECATED_MetadataItem" | "DEPRECATED_MetadataItemCollection" | "DEPRECATED_MetadataItemPolymorphic" | "DEPRECATED_MetadataDict"; /** - * IPAdapterModelFormat + * T2IAdapterModelFormat * @description An enumeration. * @enum {string} */ - IPAdapterModelFormat: "invokeai"; + T2IAdapterModelFormat: "diffusers"; /** * StableDiffusionXLModelFormat * @description An enumeration. * @enum {string} */ StableDiffusionXLModelFormat: "checkpoint" | "diffusers"; - /** - * ControlNetModelFormat - * @description An enumeration. - * @enum {string} - */ - ControlNetModelFormat: "checkpoint" | "diffusers"; /** * StableDiffusion2ModelFormat * @description An enumeration. @@ -9642,17 +9710,17 @@ export type components = { */ StableDiffusion2ModelFormat: "checkpoint" | "diffusers"; /** - * T2IAdapterModelFormat + * StableDiffusionOnnxModelFormat * @description An enumeration. * @enum {string} */ - T2IAdapterModelFormat: "diffusers"; + StableDiffusionOnnxModelFormat: "olive" | "onnx"; /** - * StableDiffusionOnnxModelFormat + * ControlNetModelFormat * @description An enumeration. * @enum {string} */ - StableDiffusionOnnxModelFormat: "olive" | "onnx"; + ControlNetModelFormat: "checkpoint" | "diffusers"; /** * CLIPVisionModelFormat * @description An enumeration. @@ -9665,6 +9733,12 @@ export type components = { * @enum {string} */ StableDiffusion1ModelFormat: "checkpoint" | "diffusers"; + /** + * IPAdapterModelFormat + * @description An enumeration. + * @enum {string} + */ + IPAdapterModelFormat: "invokeai"; }; responses: never; parameters: never; diff --git a/invokeai/frontend/web/src/services/api/types.ts b/invokeai/frontend/web/src/services/api/types.ts index ce3e75a5843..3c5e54536e4 100644 --- a/invokeai/frontend/web/src/services/api/types.ts +++ b/invokeai/frontend/web/src/services/api/types.ts @@ -1,6 +1,7 @@ import { UseToastOptions } from '@chakra-ui/react'; import { EntityState } from '@reduxjs/toolkit'; import { components, paths } from './schema'; +import { O } from 'ts-toolbelt'; type s = components['schemas']; @@ -27,8 +28,8 @@ export type BatchConfig = export type EnqueueBatchResult = components['schemas']['EnqueueBatchResult']; -export type _InputField = s['_InputField']; -export type _OutputField = s['_OutputField']; +export type InputFieldJSONSchemaExtra = s['InputFieldJSONSchemaExtra']; +export type OutputFieldJSONSchemaExtra = s['OutputFieldJSONSchemaExtra']; // App Info export type AppVersion = s['AppVersion']; @@ -57,6 +58,7 @@ export type MainModelField = s['MainModelField']; export type OnnxModelField = s['OnnxModelField']; export type VAEModelField = s['VAEModelField']; export type LoRAModelField = s['LoRAModelField']; +export type LoRAModelFormat = s['LoRAModelFormat']; export type ControlNetModelField = s['ControlNetModelField']; export type IPAdapterModelField = s['IPAdapterModelField']; export type T2IAdapterModelField = s['T2IAdapterModelField']; @@ -105,6 +107,7 @@ export type ImportModelConfig = s['Body_import_model']; // Graphs export type Graph = s['Graph']; +export type NonNullableGraph = O.Required; export type Edge = s['Edge']; export type GraphExecutionState = s['GraphExecutionState']; export type Batch = s['Batch']; diff --git a/invokeai/frontend/web/src/services/events/types.ts b/invokeai/frontend/web/src/services/events/types.ts index 543107bb132..b1d7f147316 100644 --- a/invokeai/frontend/web/src/services/events/types.ts +++ b/invokeai/frontend/web/src/services/events/types.ts @@ -1,5 +1,4 @@ import { components } from 'services/api/schema'; -import { O } from 'ts-toolbelt'; import { BaseModelType, Graph, @@ -17,11 +16,6 @@ export type ProgressImage = { height: number; }; -export type AnyInvocationType = O.Required< - NonNullable[string]>, - 'type' ->['type']; - export type AnyInvocation = NonNullable[string]>; export type AnyResult = NonNullable;