diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx index 7d854484e02..fd8cd7ccbda 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx @@ -73,9 +73,13 @@ const AddNodePopover = () => { return some(handles, (handle) => { const sourceType = - handleFilter == 'source' ? fieldFilter : handle.type; + handleFilter == 'source' + ? fieldFilter + : handle.originalType ?? handle.type; const targetType = - handleFilter == 'target' ? fieldFilter : handle.type; + handleFilter == 'target' + ? fieldFilter + : handle.originalType ?? handle.type; return validateSourceAndTargetTypes(sourceType, targetType); }); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/connectionLines/CustomConnectionLine.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/connectionLines/CustomConnectionLine.tsx index 724494502c6..a14b7b23c6d 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/connectionLines/CustomConnectionLine.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/connectionLines/CustomConnectionLine.tsx @@ -4,14 +4,14 @@ import { useAppSelector } from 'app/store/storeHooks'; import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar'; import { memo } from 'react'; import { ConnectionLineComponentProps, getBezierPath } from 'reactflow'; -import { getFieldColor } from 'features/nodes/components/flow/edges/util/getEdgeColor'; +import { getFieldColor } from '../edges/util/getEdgeColor'; const selector = createSelector(stateSelector, ({ nodes }) => { - const { shouldAnimateEdges, connectionStartFieldType, shouldColorEdges } = + const { shouldAnimateEdges, currentConnectionFieldType, shouldColorEdges } = nodes; const stroke = shouldColorEdges - ? getFieldColor(connectionStartFieldType) + ? getFieldColor(currentConnectionFieldType) : colorTokenToCssVar('base.500'); let className = 'react-flow__custom_connection-path'; diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/getEdgeColor.ts b/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/getEdgeColor.ts index 15c63b0bae8..99ada97de14 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/getEdgeColor.ts +++ b/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/getEdgeColor.ts @@ -1,12 +1,12 @@ import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar'; -import { FIELD_COLORS } from 'features/nodes/types/constants'; -import { FieldType } from 'features/nodes/types/field'; +import { FIELDS } from 'features/nodes/types/constants'; +import { FieldType } from 'features/nodes/types/types'; -export const getFieldColor = (fieldType: FieldType | null): string => { +export const getFieldColor = (fieldType: FieldType | string | null): string => { if (!fieldType) { return colorTokenToCssVar('base.500'); } - const color = FIELD_COLORS[fieldType.name]; + const color = FIELDS[fieldType]?.color; return color ? colorTokenToCssVar(color) : colorTokenToCssVar('base.500'); }; diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/makeEdgeSelector.ts b/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/makeEdgeSelector.ts index 73d3d5dc4d7..a6a409e1ad2 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/makeEdgeSelector.ts +++ b/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/makeEdgeSelector.ts @@ -2,7 +2,7 @@ import { createSelector } from '@reduxjs/toolkit'; import { stateSelector } from 'app/store/store'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { isInvocationNode } from 'features/nodes/types/types'; import { getFieldColor } from './getEdgeColor'; export const makeEdgeSelector = ( diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/FieldHandle.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/FieldHandle.tsx index b458f2ca255..849003ffbeb 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/FieldHandle.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/FieldHandle.tsx @@ -1,8 +1,6 @@ import { Tooltip } from '@chakra-ui/react'; -import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar'; import { COLLECTION_TYPES, - FIELDS, HANDLE_TOOLTIP_OPEN_DELAY, MODEL_TYPES, POLYMORPHIC_TYPES, @@ -13,6 +11,7 @@ import { } from 'features/nodes/types/types'; import { CSSProperties, memo, useMemo } from 'react'; import { Handle, HandleType, Position } from 'reactflow'; +import { getFieldColor } from '../../../edges/util/getEdgeColor'; export const handleBaseStyles: CSSProperties = { position: 'absolute', @@ -47,14 +46,14 @@ const FieldHandle = (props: FieldHandleProps) => { isConnectionStartField, connectionError, } = props; - const { name, type, originalType } = fieldTemplate; - const { color: typeColor } = FIELDS[type]; + const { name } = fieldTemplate; + const type = fieldTemplate.originalType ?? fieldTemplate.type; const styles: CSSProperties = useMemo(() => { - const isCollectionType = COLLECTION_TYPES.includes(type); - const isPolymorphicType = POLYMORPHIC_TYPES.includes(type); - const isModelType = MODEL_TYPES.includes(type); - const color = colorTokenToCssVar(typeColor); + const isCollectionType = COLLECTION_TYPES.some((t) => t === type); + const isPolymorphicType = POLYMORPHIC_TYPES.some((t) => t === type); + const isModelType = MODEL_TYPES.some((t) => t === type); + const color = getFieldColor(type); const s: CSSProperties = { backgroundColor: isCollectionType || isPolymorphicType @@ -97,23 +96,14 @@ const FieldHandle = (props: FieldHandleProps) => { isConnectionInProgress, isConnectionStartField, type, - typeColor, ]); const tooltip = useMemo(() => { - if (isConnectionInProgress && isConnectionStartField) { - return originalType; - } if (isConnectionInProgress && connectionError) { - return connectionError ?? originalType; + return connectionError; } - return originalType; - }, [ - connectionError, - isConnectionInProgress, - isConnectionStartField, - originalType, - ]); + return type; + }, [connectionError, isConnectionInProgress, type]); return ( = { - status: zNodeStatus.enum.PENDING, + status: NodeStatus.PENDING, error: null, progress: null, progressImage: null, outputs: [], }; -const INITIAL_WORKFLOW: WorkflowV2 = { +export const initialWorkflow = { + meta: { + version: WORKFLOW_FORMAT_VERSION, + }, name: '', author: '', description: '', - version: '', - contact: '', - tags: '', notes: '', - nodes: [], - edges: [], + tags: '', + contact: '', + version: '', exposedFields: [], - meta: { version: '2.0.0' }, }; export const initialNodesState: NodesState = { @@ -90,10 +93,11 @@ export const initialNodesState: NodesState = { nodeTemplates: {}, isReady: false, connectionStartParams: null, - connectionStartFieldType: null, + currentConnectionFieldType: null, connectionMade: false, modifyingEdge: false, addNewNodePosition: null, + shouldShowFieldTypeLegend: false, shouldShowMinimapPanel: true, shouldValidateGraph: true, shouldAnimateEdges: true, @@ -103,7 +107,7 @@ export const initialNodesState: NodesState = { nodeOpacity: 1, selectedNodes: [], selectedEdges: [], - workflow: INITIAL_WORKFLOW, + workflow: initialWorkflow, nodeExecutionStates: {}, viewport: { x: 0, y: 0, zoom: 1 }, mouseOverField: null, @@ -113,13 +117,13 @@ export const initialNodesState: NodesState = { selectionMode: SelectionMode.Partial, }; -type FieldValueAction = PayloadAction<{ +type FieldValueAction = PayloadAction<{ nodeId: string; fieldName: string; - value: T; + value: T['value']; }>; -const fieldValueReducer = ( +const fieldValueReducer = ( state: NodesState, action: FieldValueAction ) => { @@ -157,7 +161,12 @@ const nodesSlice = createSlice({ } state.nodes[nodeIndex] = action.payload.node; }, - nodeAdded: (state, action: PayloadAction) => { + nodeAdded: ( + state, + action: PayloadAction< + Node + > + ) => { const node = action.payload; const position = findUnoccupiedPosition( state.nodes, @@ -194,7 +203,7 @@ const nodesSlice = createSlice({ nodeId && handleId && handleType && - state.connectionStartFieldType + state.currentConnectionFieldType ) { const newConnection = findConnectionToValidHandle( node, @@ -203,7 +212,7 @@ const nodesSlice = createSlice({ nodeId, handleId, handleType, - state.connectionStartFieldType + state.currentConnectionFieldType ); if (newConnection) { state.edges = addEdge( @@ -215,7 +224,7 @@ const nodesSlice = createSlice({ } state.connectionStartParams = null; - state.connectionStartFieldType = null; + state.currentConnectionFieldType = null; }, edgeChangeStarted: (state) => { state.modifyingEdge = true; @@ -249,10 +258,11 @@ const nodesSlice = createSlice({ handleType === 'source' ? node.data.outputs[handleId] : node.data.inputs[handleId]; - state.connectionStartFieldType = field?.type ?? null; + state.currentConnectionFieldType = + field?.originalType ?? field?.type ?? null; }, connectionMade: (state, action: PayloadAction) => { - const fieldType = state.connectionStartFieldType; + const fieldType = state.currentConnectionFieldType; if (!fieldType) { return; } @@ -277,7 +287,7 @@ const nodesSlice = createSlice({ nodeId && handleId && handleType && - state.connectionStartFieldType + state.currentConnectionFieldType ) { const newConnection = findConnectionToValidHandle( mouseOverNode, @@ -286,7 +296,7 @@ const nodesSlice = createSlice({ nodeId, handleId, handleType, - state.connectionStartFieldType + state.currentConnectionFieldType ); if (newConnection) { state.edges = addEdge( @@ -297,14 +307,14 @@ const nodesSlice = createSlice({ } } state.connectionStartParams = null; - state.connectionStartFieldType = null; + state.currentConnectionFieldType = null; } else { state.addNewNodePosition = action.payload.cursorPosition; state.isAddNodePopoverOpen = true; } } else { state.connectionStartParams = null; - state.connectionStartFieldType = null; + state.currentConnectionFieldType = null; } state.modifyingEdge = false; }, @@ -520,7 +530,12 @@ const nodesSlice = createSlice({ state.edges = applyEdgeChanges(edgeChanges, state.edges); } }, - nodesDeleted: (state, action: PayloadAction) => { + nodesDeleted: ( + state, + action: PayloadAction< + Node[] + > + ) => { action.payload.forEach((node) => { state.workflow.exposedFields = state.workflow.exposedFields.filter( (f) => f.nodeId !== node.id @@ -574,94 +589,132 @@ const nodesSlice = createSlice({ }, fieldStringValueChanged: ( state, - action: FieldValueAction + action: FieldValueAction ) => { fieldValueReducer(state, action); }, fieldNumberValueChanged: ( state, - action: FieldValueAction + action: FieldValueAction ) => { fieldValueReducer(state, action); }, fieldBooleanValueChanged: ( state, - action: FieldValueAction + action: FieldValueAction ) => { fieldValueReducer(state, action); }, fieldBoardValueChanged: ( state, - action: FieldValueAction + action: FieldValueAction ) => { fieldValueReducer(state, action); }, fieldImageValueChanged: ( state, - action: FieldValueAction + action: FieldValueAction ) => { fieldValueReducer(state, action); }, fieldColorValueChanged: ( state, - action: FieldValueAction + action: FieldValueAction ) => { fieldValueReducer(state, action); }, fieldMainModelValueChanged: ( state, - action: FieldValueAction + action: FieldValueAction ) => { fieldValueReducer(state, action); }, fieldRefinerModelValueChanged: ( state, - action: FieldValueAction + action: FieldValueAction ) => { fieldValueReducer(state, action); }, fieldVaeModelValueChanged: ( state, - action: FieldValueAction + action: FieldValueAction ) => { fieldValueReducer(state, action); }, fieldLoRAModelValueChanged: ( state, - action: FieldValueAction + action: FieldValueAction ) => { fieldValueReducer(state, action); }, fieldControlNetModelValueChanged: ( state, - action: FieldValueAction + action: FieldValueAction ) => { fieldValueReducer(state, action); }, fieldIPAdapterModelValueChanged: ( state, - action: FieldValueAction + action: FieldValueAction ) => { fieldValueReducer(state, action); }, fieldT2IAdapterModelValueChanged: ( state, - action: FieldValueAction + action: FieldValueAction ) => { fieldValueReducer(state, action); }, fieldEnumModelValueChanged: ( state, - action: FieldValueAction + action: FieldValueAction ) => { fieldValueReducer(state, action); }, fieldSchedulerValueChanged: ( state, - action: FieldValueAction + action: FieldValueAction ) => { fieldValueReducer(state, action); }, + imageCollectionFieldValueChanged: ( + state, + action: PayloadAction<{ + nodeId: string; + fieldName: string; + value: ImageField[]; + }> + ) => { + const { nodeId, fieldName, value } = action.payload; + const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId); + + if (nodeIndex === -1) { + return; + } + + const node = state.nodes?.[nodeIndex]; + + if (!isInvocationNode(node)) { + return; + } + + const input = node.data?.inputs[fieldName]; + if (!input) { + return; + } + + const currentValue = cloneDeep(input.value); + + if (!currentValue) { + input.value = value; + return; + } + + input.value = uniqBy( + (currentValue as ImageField[]).concat(value), + 'image_name' + ); + }, notesNodeValueChanged: ( state, action: PayloadAction<{ nodeId: string; value: string }> @@ -674,6 +727,12 @@ const nodesSlice = createSlice({ } node.data.notes = value; }, + shouldShowFieldTypeLegendChanged: ( + state, + action: PayloadAction + ) => { + state.shouldShowFieldTypeLegend = action.payload; + }, shouldShowMinimapPanelChanged: (state, action: PayloadAction) => { state.shouldShowMinimapPanel = action.payload; }, @@ -687,7 +746,7 @@ const nodesSlice = createSlice({ nodeEditorReset: (state) => { state.nodes = []; state.edges = []; - state.workflow = cloneDeep(INITIAL_WORKFLOW); + state.workflow = cloneDeep(initialWorkflow); }, shouldValidateGraphChanged: (state, action: PayloadAction) => { state.shouldValidateGraph = action.payload; @@ -725,13 +784,13 @@ const nodesSlice = createSlice({ workflowContactChanged: (state, action: PayloadAction) => { state.workflow.contact = action.payload; }, - workflowLoaded: (state, action: PayloadAction) => { + workflowLoaded: (state, action: PayloadAction) => { const { nodes, edges, ...workflow } = action.payload; state.workflow = workflow; state.nodes = applyNodeChanges( nodes.map((node) => ({ - item: { ...node, ...SHARED_NODE_PROPERTIES }, + item: { ...node, dragHandle: `.${DRAG_HANDLE_CLASSNAME}` }, type: 'add', })), [] @@ -752,7 +811,7 @@ const nodesSlice = createSlice({ }, {}); }, workflowReset: (state) => { - state.workflow = cloneDeep(INITIAL_WORKFLOW); + state.workflow = cloneDeep(initialWorkflow); }, viewportChanged: (state, action: PayloadAction) => { state.viewport = action.payload; @@ -884,7 +943,7 @@ const nodesSlice = createSlice({ //Make sure these get reset if we close the popover and haven't selected a node state.connectionStartParams = null; - state.connectionStartFieldType = null; + state.currentConnectionFieldType = null; }, addNodePopoverToggled: (state) => { state.isAddNodePopoverOpen = !state.isAddNodePopoverOpen; @@ -903,14 +962,14 @@ const nodesSlice = createSlice({ const { source_node_id } = action.payload.data; const node = state.nodeExecutionStates[source_node_id]; if (node) { - node.status = zNodeStatus.enum.IN_PROGRESS; + node.status = NodeStatus.IN_PROGRESS; } }); builder.addCase(appSocketInvocationComplete, (state, action) => { const { source_node_id, result } = action.payload.data; const nes = state.nodeExecutionStates[source_node_id]; if (nes) { - nes.status = zNodeStatus.enum.COMPLETED; + nes.status = NodeStatus.COMPLETED; if (nes.progress !== null) { nes.progress = 1; } @@ -921,7 +980,7 @@ const nodesSlice = createSlice({ const { source_node_id } = action.payload.data; const node = state.nodeExecutionStates[source_node_id]; if (node) { - node.status = zNodeStatus.enum.FAILED; + node.status = NodeStatus.FAILED; node.error = action.payload.data.error; node.progress = null; node.progressImage = null; @@ -932,7 +991,7 @@ const nodesSlice = createSlice({ action.payload.data; const node = state.nodeExecutionStates[source_node_id]; if (node) { - node.status = zNodeStatus.enum.IN_PROGRESS; + node.status = NodeStatus.IN_PROGRESS; node.progress = (step + 1) / total_steps; node.progressImage = progress_image ?? null; } @@ -940,7 +999,7 @@ const nodesSlice = createSlice({ builder.addCase(appSocketQueueItemStatusChanged, (state, action) => { if (['in_progress'].includes(action.payload.data.queue_item.status)) { forEach(state.nodeExecutionStates, (nes) => { - nes.status = zNodeStatus.enum.PENDING; + nes.status = NodeStatus.PENDING; nes.error = null; nes.progress = null; nes.progressImage = null; @@ -979,6 +1038,7 @@ export const { fieldSchedulerValueChanged, fieldStringValueChanged, fieldVaeModelValueChanged, + imageCollectionFieldValueChanged, mouseOverFieldChanged, mouseOverNodeChanged, nodeAdded, @@ -1004,6 +1064,7 @@ export const { selectionPasted, shouldAnimateEdgesChanged, shouldColorEdgesChanged, + shouldShowFieldTypeLegendChanged, shouldShowMinimapPanelChanged, shouldSnapToGridChanged, shouldValidateGraphChanged, diff --git a/invokeai/frontend/web/src/features/nodes/store/types.ts b/invokeai/frontend/web/src/features/nodes/store/types.ts index bfd351ac5d5..b81dd286d72 100644 --- a/invokeai/frontend/web/src/features/nodes/store/types.ts +++ b/invokeai/frontend/web/src/features/nodes/store/types.ts @@ -1,26 +1,30 @@ import { + Edge, + Node, OnConnectStartParams, SelectionMode, Viewport, XYPosition, } from 'reactflow'; -import { FieldIdentifier, FieldType } from 'features/nodes/types/field'; import { - AnyNode, - InvocationNodeEdge, + FieldIdentifier, + FieldType, + InvocationEdgeExtra, InvocationTemplate, + NodeData, NodeExecutionState, -} from 'features/nodes/types/invocation'; -import { WorkflowV2 } from 'features/nodes/types/workflow'; + Workflow, +} from '../types/types'; export type NodesState = { - nodes: AnyNode[]; - edges: InvocationNodeEdge[]; + nodes: Node[]; + edges: Edge[]; nodeTemplates: Record; connectionStartParams: OnConnectStartParams | null; - connectionStartFieldType: FieldType | null; + currentConnectionFieldType: FieldType | string | null; connectionMade: boolean; modifyingEdge: boolean; + shouldShowFieldTypeLegend: boolean; shouldShowMinimapPanel: boolean; shouldValidateGraph: boolean; shouldAnimateEdges: boolean; @@ -29,14 +33,14 @@ export type NodesState = { shouldColorEdges: boolean; selectedNodes: string[]; selectedEdges: string[]; - workflow: Omit; + workflow: Omit; nodeExecutionStates: Record; viewport: Viewport; isReady: boolean; mouseOverField: FieldIdentifier | null; mouseOverNode: string | null; - nodesToCopy: AnyNode[]; - edgesToCopy: InvocationNodeEdge[]; + nodesToCopy: Node[]; + edgesToCopy: Edge[]; isAddNodePopoverOpen: boolean; addNewNodePosition: XYPosition | null; selectionMode: SelectionMode; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/buildNodeData.ts b/invokeai/frontend/web/src/features/nodes/store/util/buildNodeData.ts new file mode 100644 index 00000000000..0efd3d17c6a --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/store/util/buildNodeData.ts @@ -0,0 +1,128 @@ +import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants'; +import { + CurrentImageNodeData, + InputFieldValue, + InvocationNodeData, + InvocationTemplate, + NotesNodeData, + OutputFieldValue, +} from 'features/nodes/types/types'; +import { buildInputFieldValue } from 'features/nodes/util/fieldValueBuilders'; +import { reduce } from 'lodash-es'; +import { Node, XYPosition } from 'reactflow'; +import { AnyInvocationType } from 'services/events/types'; +import { v4 as uuidv4 } from 'uuid'; + +export const SHARED_NODE_PROPERTIES: Partial = { + dragHandle: `.${DRAG_HANDLE_CLASSNAME}`, +}; +export const buildNodeData = ( + type: AnyInvocationType | 'current_image' | 'notes', + position: XYPosition, + template?: InvocationTemplate +): + | Node + | Node + | Node + | undefined => { + const nodeId = uuidv4(); + + if (type === 'current_image') { + const node: Node = { + ...SHARED_NODE_PROPERTIES, + id: nodeId, + type: 'current_image', + position, + data: { + id: nodeId, + type: 'current_image', + isOpen: true, + label: 'Current Image', + }, + }; + + return node; + } + + if (type === 'notes') { + const node: Node = { + ...SHARED_NODE_PROPERTIES, + id: nodeId, + type: 'notes', + position, + data: { + id: nodeId, + isOpen: true, + label: 'Notes', + notes: '', + type: 'notes', + }, + }; + + return node; + } + + if (template === undefined) { + console.error(`Unable to find template ${type}.`); + return; + } + + const inputs = reduce( + template.inputs, + (inputsAccumulator, inputTemplate, inputName) => { + const fieldId = uuidv4(); + + const inputFieldValue: InputFieldValue = buildInputFieldValue( + fieldId, + inputTemplate + ); + + inputsAccumulator[inputName] = inputFieldValue; + + return inputsAccumulator; + }, + {} as Record + ); + + const outputs = reduce( + template.outputs, + (outputsAccumulator, outputTemplate, outputName) => { + const fieldId = uuidv4(); + + const outputFieldValue: OutputFieldValue = { + id: fieldId, + name: outputName, + type: outputTemplate.type, + fieldKind: 'output', + originalType: outputTemplate.originalType, + }; + + outputsAccumulator[outputName] = outputFieldValue; + + return outputsAccumulator; + }, + {} as Record + ); + + const invocation: Node = { + ...SHARED_NODE_PROPERTIES, + id: nodeId, + type: 'invocation', + position, + data: { + id: nodeId, + type, + version: template.version, + label: '', + notes: '', + isOpen: true, + embedWorkflow: false, + isIntermediate: type === 'save_image' ? false : true, + inputs, + outputs, + useCache: template.useCache, + }, + }; + + return invocation; +}; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/findConnectionToValidHandle.ts b/invokeai/frontend/web/src/features/nodes/store/util/findConnectionToValidHandle.ts index 0a7adf77cbc..290a5714444 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/findConnectionToValidHandle.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/findConnectionToValidHandle.ts @@ -11,7 +11,7 @@ import { validateSourceAndTargetTypes } from './validateSourceAndTargetTypes'; const isValidConnection = ( edges: Edge[], handleCurrentType: HandleType, - handleCurrentFieldType: FieldType, + handleCurrentFieldType: FieldType | string, node: Node, handle: FieldInputInstance | FieldOutputInstance ) => { @@ -34,7 +34,12 @@ const isValidConnection = ( } } - if (!validateSourceAndTargetTypes(handleCurrentFieldType, handle.type)) { + if ( + !validateSourceAndTargetTypes( + handleCurrentFieldType, + handle.originalType ?? handle.type + ) + ) { isValidConnection = false; } @@ -48,7 +53,7 @@ export const findConnectionToValidHandle = ( handleCurrentNodeId: string, handleCurrentName: string, handleCurrentType: HandleType, - handleCurrentFieldType: FieldType + handleCurrentFieldType: FieldType | string ): Connection | null => { if (node.id === handleCurrentNodeId) { return null; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts b/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts index de795612919..cb7886e57e1 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts @@ -1,6 +1,6 @@ import { createSelector } from '@reduxjs/toolkit'; import { stateSelector } from 'app/store/store'; -import { FieldType } from 'features/nodes/types/field'; +import { FieldType } from 'features/nodes/types/types'; import i18n from 'i18next'; import { HandleType } from 'reactflow'; import { getIsGraphAcyclic } from './getIsGraphAcyclic'; @@ -15,17 +15,17 @@ export const makeConnectionErrorSelector = ( nodeId: string, fieldName: string, handleType: HandleType, - fieldType?: FieldType + fieldType?: FieldType | string ) => { - return createSelector(stateSelector, (state): string | undefined => { + return createSelector(stateSelector, (state) => { if (!fieldType) { return i18n.t('nodes.noFieldType'); } - const { connectionStartFieldType, connectionStartParams, nodes, edges } = + const { currentConnectionFieldType, connectionStartParams, nodes, edges } = state.nodes; - if (!connectionStartParams || !connectionStartFieldType) { + if (!connectionStartParams || !currentConnectionFieldType) { return i18n.t('nodes.noConnectionInProgress'); } @@ -40,9 +40,9 @@ export const makeConnectionErrorSelector = ( } const targetType = - handleType === 'target' ? fieldType : connectionStartFieldType; + handleType === 'target' ? fieldType : currentConnectionFieldType; const sourceType = - handleType === 'source' ? fieldType : connectionStartFieldType; + handleType === 'source' ? fieldType : currentConnectionFieldType; if (nodeId === connectionNodeId) { return i18n.t('nodes.cannotConnectToSelf'); @@ -80,7 +80,7 @@ export const makeConnectionErrorSelector = ( return edge.target === target && edge.targetHandle === targetHandle; }) && // except CollectionItem inputs can have multiples - targetType.name !== 'CollectionItemField' + targetType !== 'CollectionItem' ) { return i18n.t('nodes.inputMayOnlyHaveOneConnection'); } @@ -100,6 +100,6 @@ export const makeConnectionErrorSelector = ( return i18n.t('nodes.connectionWouldCreateCycle'); } - return; + return null; }); }; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/validateSourceAndTargetTypes.ts b/invokeai/frontend/web/src/features/nodes/store/util/validateSourceAndTargetTypes.ts index e3ad0b96213..123cda8e044 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/validateSourceAndTargetTypes.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/validateSourceAndTargetTypes.ts @@ -1,27 +1,23 @@ -import { FieldType } from 'features/nodes/types/field'; -import { isEqual } from 'lodash-es'; +import { + COLLECTION_MAP, + COLLECTION_TYPES, + POLYMORPHIC_TO_SINGLE_MAP, + POLYMORPHIC_TYPES, +} from 'features/nodes/types/constants'; +import { FieldType } from 'features/nodes/types/types'; -/** - * Validates that the source and target types are compatible for a connection. - * @param sourceType The type of the source field. - * @param targetType The type of the target field. - * @returns True if the connection is valid, false otherwise. - */ export const validateSourceAndTargetTypes = ( - sourceType: FieldType, - targetType: FieldType + sourceType: FieldType | string, + targetType: FieldType | string ) => { // TODO: There's a bug with Collect -> Iterate nodes: // https://github.com/invoke-ai/InvokeAI/issues/3956 // Once this is resolved, we can remove this check. - if ( - sourceType.name === 'CollectionField' && - targetType.name === 'CollectionField' - ) { + if (sourceType === 'Collection' && targetType === 'Collection') { return false; } - if (isEqual(sourceType, targetType)) { + if (sourceType === targetType) { return true; } @@ -29,53 +25,59 @@ export const validateSourceAndTargetTypes = ( * Connection types must be the same for a connection, with exceptions: * - CollectionItem can connect to any non-Collection * - Non-Collections can connect to CollectionItem - * - Anything (non-Collections, Collections, CollectionOrScalar) can connect to CollectionOrScalar of the same base type - * - Generic Collection can connect to any other Collection or CollectionOrScalar + * - Anything (non-Collections, Collections, Polymorphics) can connect to Polymorphics of the same base type + * - Generic Collection can connect to any other Collection or Polymorphic * - Any Collection can connect to a Generic Collection */ const isCollectionItemToNonCollection = - sourceType.name === 'CollectionItemField' && !targetType.isCollection; + sourceType === 'CollectionItem' && + !COLLECTION_TYPES.some((t) => t === targetType); const isNonCollectionToCollectionItem = - targetType.name === 'CollectionItemField' && - !sourceType.isCollection && - !sourceType.isCollectionOrScalar; + targetType === 'CollectionItem' && + !COLLECTION_TYPES.some((t) => t === sourceType) && + !POLYMORPHIC_TYPES.some((t) => t === sourceType); - const isAnythingToCollectionOrScalarOfSameBaseType = - targetType.isCollectionOrScalar && sourceType.name === targetType.name; + const isAnythingToPolymorphicOfSameBaseType = + POLYMORPHIC_TYPES.some((t) => t === targetType) && + (() => { + if (!POLYMORPHIC_TYPES.some((t) => t === targetType)) { + return false; + } + const baseType = + POLYMORPHIC_TO_SINGLE_MAP[ + targetType as keyof typeof POLYMORPHIC_TO_SINGLE_MAP + ]; - const isGenericCollectionToAnyCollectionOrCollectionOrScalar = - sourceType.name === 'CollectionField' && - (targetType.isCollection || targetType.isCollectionOrScalar); + const collectionType = + COLLECTION_MAP[baseType as keyof typeof COLLECTION_MAP]; - const isCollectionToGenericCollection = - targetType.name === 'CollectionField' && sourceType.isCollection; + return sourceType === baseType || sourceType === collectionType; + })(); + + const isGenericCollectionToAnyCollectionOrPolymorphic = + sourceType === 'Collection' && + (COLLECTION_TYPES.some((t) => t === targetType) || + POLYMORPHIC_TYPES.some((t) => t === targetType)); - const areBothTypesSingle = - !sourceType.isCollection && - !sourceType.isCollectionOrScalar && - !targetType.isCollection && - !targetType.isCollectionOrScalar; + const isCollectionToGenericCollection = + targetType === 'Collection' && + COLLECTION_TYPES.some((t) => t === sourceType); - const isIntToFloat = - areBothTypesSingle && - sourceType.name === 'IntegerField' && - targetType.name === 'FloatField'; + const isIntToFloat = sourceType === 'integer' && targetType === 'float'; const isIntOrFloatToString = - areBothTypesSingle && - (sourceType.name === 'IntegerField' || sourceType.name === 'FloatField') && - targetType.name === 'StringField'; + (sourceType === 'integer' || sourceType === 'float') && + targetType === 'string'; - const isTargetAnyType = targetType.name === 'AnyField'; + const isTargetAnyType = targetType === 'Any'; - // One of these must be true for the connection to be valid return ( isCollectionItemToNonCollection || isNonCollectionToCollectionItem || - isAnythingToCollectionOrScalarOfSameBaseType || - isGenericCollectionToAnyCollectionOrCollectionOrScalar || + isAnythingToPolymorphicOfSameBaseType || + isGenericCollectionToAnyCollectionOrPolymorphic || isCollectionToGenericCollection || isIntToFloat || isIntOrFloatToString || diff --git a/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts b/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts index 89d6729f4c3..f8db78ecc39 100644 --- a/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts +++ b/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts @@ -60,7 +60,7 @@ const FIELD_VALUE_FALLBACK_MAP: { UNetField: undefined, VaeField: undefined, VaeModelField: undefined, - Unknown: undefined, + Custom: undefined, }; export const buildInputFieldValue = ( @@ -77,10 +77,9 @@ export const buildInputFieldValue = ( type: template.type, label: '', fieldKind: 'input', + originalType: template.originalType, + value: template.default ?? FIELD_VALUE_FALLBACK_MAP[template.type], } as InputFieldValue; - fieldValue.value = - template.default ?? FIELD_VALUE_FALLBACK_MAP[template.type]; - return fieldValue; };