Skip to content

Commit

Permalink
feat(ui): custom field types connection validation
Browse files Browse the repository at this point in the history
In the initial commit, a custom field's original type was added to the *field templates* only as `originalType`. Custom fields' `type` property was `"Custom"`*. This allowed for type safety throughout the UI logic.

*Actually, it was `"Unknown"`, but I changed it to custom for clarity.

Connection validation logic, however, uses the *field instance* of the node/field. Like the templates, *field instances* with custom types have their `type` set to `"Custom"`, but they didn't have an `originalType` property. As a result, all custom fields could be connected to all other custom fields.

To resolve this, we need to add `originalType` to the *field instances*, then switch the validation logic to use this instead of `type`.

This ended up needing a bit of fanagling:

- If we make `originalType` a required property on field instances, existing workflows will break during connection validation, because they won't have this property. We'd need a new layer of logic to migrate the workflows, adding the new `originalType` property.

While this layer is probably needed anyways, typing `originalType` as optional is much simpler. Workflow migration logic can come layer.

(Technically, we could remove all references to field types from the workflow files, and let the templates hold all this information. This feels like a significant change and I'm reluctant to do it now.)

- Because `originalType` is optional, anywhere we care about the type of a field, we need to use it over `type`. So there are a number of `field.originalType ?? field.type` expressions. This is a bit of a gotcha, we'll need to remember this in the future.

- We use `Array.prototype.includes()` often in the workflow editor, e.g. `COLLECTION_TYPES.includes(type)`. In these cases, the const array is of type `FieldType[]`, and `type` is is `FieldType`.

Because we now support custom types, the arg `type` is now widened from `FieldType` to `string`.

This causes a TS error. This behaviour is somewhat controversial (see microsoft/TypeScript#14520). These expressions are now rewritten as `COLLECTION_TYPES.some((t) => t === type)` to satisfy TS. It's logically equivalent.
  • Loading branch information
psychedelicious committed Nov 17, 2023
1 parent 7b93b5e commit 98a0ce0
Show file tree
Hide file tree
Showing 17 changed files with 98 additions and 76 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,13 @@ const AddNodePopover = () => {

return some(handles, (handle) => {
const sourceType =
handleFilter == 'source' ? fieldFilter : handle.type;
handleFilter == 'source'
? fieldFilter
: handle.originalType ?? handle.type;
const targetType =
handleFilter == 'target' ? fieldFilter : handle.type;
handleFilter == 'target'
? fieldFilter
: handle.originalType ?? handle.type;

return validateSourceAndTargetTypes(sourceType, targetType);
});
Expand Down Expand Up @@ -111,7 +115,7 @@ const AddNodePopover = () => {

data.sort((a, b) => a.label.localeCompare(b.label));

return { data, t };
return { data };
},
defaultSelectorOptions
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,17 @@ import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
import { FIELDS } from 'features/nodes/types/constants';
import { memo } from 'react';
import { ConnectionLineComponentProps, getBezierPath } from 'reactflow';
import { getFieldColor } from '../edges/util/getEdgeColor';

const selector = createSelector(stateSelector, ({ nodes }) => {
const { shouldAnimateEdges, currentConnectionFieldType, shouldColorEdges } =
nodes;

const stroke =
currentConnectionFieldType && shouldColorEdges
? colorTokenToCssVar(FIELDS[currentConnectionFieldType].color)
: colorTokenToCssVar('base.500');
const stroke = shouldColorEdges
? getFieldColor(currentConnectionFieldType)
: colorTokenToCssVar('base.500');

let className = 'react-flow__custom_connection-path';

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
import { FIELDS } from 'features/nodes/types/constants';
import { FieldType } from 'features/nodes/types/types';

export const getFieldColor = (fieldType: FieldType | string | null): string => {
if (!fieldType) {
return colorTokenToCssVar('base.500');
}
const color = FIELDS[fieldType]?.color;

return color ? colorTokenToCssVar(color) : colorTokenToCssVar('base.500');
};
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
import { FIELDS } from 'features/nodes/types/constants';
import { isInvocationNode } from 'features/nodes/types/types';
import { getFieldColor } from './getEdgeColor';

export const makeEdgeSelector = (
source: string,
Expand All @@ -29,7 +29,7 @@ export const makeEdgeSelector = (

const stroke =
sourceType && nodes.shouldColorEdges
? colorTokenToCssVar(FIELDS[sourceType].color)
? getFieldColor(sourceType)
: colorTokenToCssVar('base.500');

return {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import { Tooltip } from '@chakra-ui/react';
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
import {
COLLECTION_TYPES,
FIELDS,
HANDLE_TOOLTIP_OPEN_DELAY,
MODEL_TYPES,
POLYMORPHIC_TYPES,
Expand All @@ -13,6 +11,7 @@ import {
} from 'features/nodes/types/types';
import { CSSProperties, memo, useMemo } from 'react';
import { Handle, HandleType, Position } from 'reactflow';
import { getFieldColor } from '../../../edges/util/getEdgeColor';

export const handleBaseStyles: CSSProperties = {
position: 'absolute',
Expand Down Expand Up @@ -47,14 +46,14 @@ const FieldHandle = (props: FieldHandleProps) => {
isConnectionStartField,
connectionError,
} = props;
const { name, type, originalType } = fieldTemplate;
const { color: typeColor } = FIELDS[type];
const { name } = fieldTemplate;
const type = fieldTemplate.originalType ?? fieldTemplate.type;

const styles: CSSProperties = useMemo(() => {
const isCollectionType = COLLECTION_TYPES.includes(type);
const isPolymorphicType = POLYMORPHIC_TYPES.includes(type);
const isModelType = MODEL_TYPES.includes(type);
const color = colorTokenToCssVar(typeColor);
const isCollectionType = COLLECTION_TYPES.some((t) => t === type);
const isPolymorphicType = POLYMORPHIC_TYPES.some((t) => t === type);
const isModelType = MODEL_TYPES.some((t) => t === type);
const color = getFieldColor(type);
const s: CSSProperties = {
backgroundColor:
isCollectionType || isPolymorphicType
Expand Down Expand Up @@ -97,23 +96,14 @@ const FieldHandle = (props: FieldHandleProps) => {
isConnectionInProgress,
isConnectionStartField,
type,
typeColor,
]);

const tooltip = useMemo(() => {
if (isConnectionInProgress && isConnectionStartField) {
return originalType;
}
if (isConnectionInProgress && connectionError) {
return connectionError ?? originalType;
return connectionError;
}
return originalType;
}, [
connectionError,
isConnectionInProgress,
isConnectionStartField,
originalType,
]);
return type;
}, [connectionError, isConnectionInProgress, type]);

return (
<Tooltip
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ export const useFieldType = (
if (!isInvocationNode(node)) {
return;
}
return node?.data[KIND_MAP[kind]][fieldName]?.type;
const field = node.data[KIND_MAP[kind]][fieldName];
return field?.originalType ?? field?.type;
},
defaultSelectorOptions
),
Expand Down
3 changes: 2 additions & 1 deletion invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,8 @@ const nodesSlice = createSlice({
handleType === 'source'
? node.data.outputs[handleId]
: node.data.inputs[handleId];
state.currentConnectionFieldType = field?.type ?? null;
state.currentConnectionFieldType =
field?.originalType ?? field?.type ?? null;
},
connectionMade: (state, action: PayloadAction<Connection>) => {
const fieldType = state.currentConnectionFieldType;
Expand Down
2 changes: 1 addition & 1 deletion invokeai/frontend/web/src/features/nodes/store/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ export type NodesState = {
edges: Edge<InvocationEdgeExtra>[];
nodeTemplates: Record<string, InvocationTemplate>;
connectionStartParams: OnConnectStartParams | null;
currentConnectionFieldType: FieldType | null;
currentConnectionFieldType: FieldType | string | null;
connectionMade: boolean;
modifyingEdge: boolean;
shouldShowFieldTypeLegend: boolean;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ export const buildNodeData = (
name: outputName,
type: outputTemplate.type,
fieldKind: 'output',
originalType: outputTemplate.originalType,
};

outputsAccumulator[outputName] = outputFieldValue;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import { getIsGraphAcyclic } from './getIsGraphAcyclic';
const isValidConnection = (
edges: Edge[],
handleCurrentType: HandleType,
handleCurrentFieldType: FieldType,
handleCurrentFieldType: FieldType | string,
node: Node,
handle: InputFieldValue | OutputFieldValue
) => {
Expand All @@ -35,7 +35,12 @@ const isValidConnection = (
}
}

if (!validateSourceAndTargetTypes(handleCurrentFieldType, handle.type)) {
if (
!validateSourceAndTargetTypes(
handleCurrentFieldType,
handle.originalType ?? handle.type
)
) {
isValidConnection = false;
}

Expand All @@ -49,7 +54,7 @@ export const findConnectionToValidHandle = (
handleCurrentNodeId: string,
handleCurrentName: string,
handleCurrentType: HandleType,
handleCurrentFieldType: FieldType
handleCurrentFieldType: FieldType | string
): Connection | null => {
if (node.id === handleCurrentNodeId) {
return null;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { getIsGraphAcyclic } from './getIsGraphAcyclic';
import { FieldType } from 'features/nodes/types/types';
import i18n from 'i18next';
import { HandleType } from 'reactflow';
import { getIsGraphAcyclic } from './getIsGraphAcyclic';
import { validateSourceAndTargetTypes } from './validateSourceAndTargetTypes';

/**
Expand All @@ -15,7 +15,7 @@ export const makeConnectionErrorSelector = (
nodeId: string,
fieldName: string,
handleType: HandleType,
fieldType?: FieldType
fieldType?: FieldType | string
) => {
return createSelector(stateSelector, (state) => {
if (!fieldType) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ import {
import { FieldType } from 'features/nodes/types/types';

export const validateSourceAndTargetTypes = (
sourceType: FieldType,
targetType: FieldType
sourceType: FieldType | string,
targetType: FieldType | string
) => {
// TODO: There's a bug with Collect -> Iterate nodes:
// https://github.com/invoke-ai/InvokeAI/issues/3956
Expand All @@ -31,17 +31,18 @@ export const validateSourceAndTargetTypes = (
*/

const isCollectionItemToNonCollection =
sourceType === 'CollectionItem' && !COLLECTION_TYPES.includes(targetType);
sourceType === 'CollectionItem' &&
!COLLECTION_TYPES.some((t) => t === targetType);

const isNonCollectionToCollectionItem =
targetType === 'CollectionItem' &&
!COLLECTION_TYPES.includes(sourceType) &&
!POLYMORPHIC_TYPES.includes(sourceType);
!COLLECTION_TYPES.some((t) => t === sourceType) &&
!POLYMORPHIC_TYPES.some((t) => t === sourceType);

const isAnythingToPolymorphicOfSameBaseType =
POLYMORPHIC_TYPES.includes(targetType) &&
POLYMORPHIC_TYPES.some((t) => t === targetType) &&
(() => {
if (!POLYMORPHIC_TYPES.includes(targetType)) {
if (!POLYMORPHIC_TYPES.some((t) => t === targetType)) {
return false;
}
const baseType =
Expand All @@ -57,11 +58,12 @@ export const validateSourceAndTargetTypes = (

const isGenericCollectionToAnyCollectionOrPolymorphic =
sourceType === 'Collection' &&
(COLLECTION_TYPES.includes(targetType) ||
POLYMORPHIC_TYPES.includes(targetType));
(COLLECTION_TYPES.some((t) => t === targetType) ||
POLYMORPHIC_TYPES.some((t) => t === targetType));

const isCollectionToGenericCollection =
targetType === 'Collection' && COLLECTION_TYPES.includes(sourceType);
targetType === 'Collection' &&
COLLECTION_TYPES.some((t) => t === sourceType);

const isIntToFloat = sourceType === 'integer' && targetType === 'float';

Expand Down
8 changes: 4 additions & 4 deletions invokeai/frontend/web/src/features/nodes/types/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -150,16 +150,16 @@ export const isPolymorphicItemType = (
): itemType is keyof typeof SINGLE_TO_POLYMORPHIC_MAP =>
Boolean(itemType && itemType in SINGLE_TO_POLYMORPHIC_MAP);

export const FIELDS: Record<FieldType, FieldUIConfig> = {
export const FIELDS: Record<FieldType | string, FieldUIConfig> = {
Any: {
color: 'gray.500',
description: 'Any field type is accepted.',
title: 'Any',
},
Unknown: {
Custom: {
color: 'gray.500',
description: 'Unknown field type is accepted.',
title: 'Unknown',
description: 'A custom field, provided by an external node.',
title: 'Custom',
},
MetadataField: {
color: 'gray.500',
Expand Down
19 changes: 10 additions & 9 deletions invokeai/frontend/web/src/features/nodes/types/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ export const zFieldType = z.enum([
'UNetField',
'VaeField',
'VaeModelField',
'Unknown',
'Custom',
]);

export type FieldType = z.infer<typeof zFieldType>;
Expand Down Expand Up @@ -164,6 +164,7 @@ export const zFieldValueBase = z.object({
id: z.string().trim().min(1),
name: z.string().trim().min(1),
type: zFieldType,
originalType: z.string().optional(),
});
export type FieldValueBase = z.infer<typeof zFieldValueBase>;

Expand Down Expand Up @@ -191,7 +192,7 @@ export type OutputFieldTemplate = {
type: FieldType;
title: string;
description: string;
originalType: string; // used for custom types
originalType?: string; // used for custom types
} & _OutputField;

export const zInputFieldValueBase = zFieldValueBase.extend({
Expand Down Expand Up @@ -791,8 +792,8 @@ export const zAnyInputFieldValue = zInputFieldValueBase.extend({
value: z.any().optional(),
});

export const zUnknownInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('Unknown'),
export const zCustomInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('Custom'),
value: z.any().optional(),
});

Expand Down Expand Up @@ -853,7 +854,7 @@ export const zInputFieldValue = z.discriminatedUnion('type', [
zMetadataItemPolymorphicInputFieldValue,
zMetadataInputFieldValue,
zMetadataCollectionInputFieldValue,
zUnknownInputFieldValue,
zCustomInputFieldValue,
]);

export type InputFieldValue = z.infer<typeof zInputFieldValue>;
Expand All @@ -864,16 +865,16 @@ export type InputFieldTemplateBase = {
description: string;
required: boolean;
fieldKind: 'input';
originalType: string; // used for custom types
originalType?: string; // used for custom types
} & _InputField;

export type AnyInputFieldTemplate = InputFieldTemplateBase & {
type: 'Any';
default: undefined;
};

export type UnknownInputFieldTemplate = InputFieldTemplateBase & {
type: 'Unknown';
export type CustomInputFieldTemplate = InputFieldTemplateBase & {
type: 'Custom';
default: undefined;
};

Expand Down Expand Up @@ -1274,7 +1275,7 @@ export type InputFieldTemplate =
| MetadataInputFieldTemplate
| MetadataItemPolymorphicInputFieldTemplate
| MetadataCollectionInputFieldTemplate
| UnknownInputFieldTemplate;
| CustomInputFieldTemplate;

export const isInputFieldValue = (
field?: InputFieldValue | OutputFieldValue
Expand Down
Loading

0 comments on commit 98a0ce0

Please sign in to comment.