Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed iterator helper node type persistence #1222

Merged
merged 3 commits into from
Nov 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 9 additions & 11 deletions src/renderer/components/node/NodeOutputs.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import { GenericOutput } from '../outputs/GenericOutput';
import { LargeImageOutput } from '../outputs/LargeImageOutput';
import { NcnnModelOutput } from '../outputs/NcnnModelOutput';
import { OutputContainer } from '../outputs/OutputContainer';
import { OutputProps } from '../outputs/props';
import { OutputProps, UseOutputData } from '../outputs/props';
import { PyTorchOutput } from '../outputs/PyTorchOutput';

interface FullOutputProps extends Omit<Output, 'id' | 'type'>, OutputProps {
Expand Down Expand Up @@ -57,7 +57,7 @@ const pickOutput = (kind: OutputKind, props: FullOutputProps) => {
);
};

const NO_OUTPUT_DATA = [undefined, undefined] as const;
const NO_OUTPUT_DATA: UseOutputData<never> = { current: undefined, last: undefined, stale: false };

interface NodeOutputProps {
outputs: readonly Output[];
Expand All @@ -71,23 +71,21 @@ export const NodeOutputs = memo(({ outputs, id, schemaId, animated = false }: No
const outputDataEntry = useContextSelector(GlobalVolatileContext, (c) =>
c.outputDataMap.get(id)
);
const inputHash = useContextSelector(GlobalVolatileContext, (c) => c.inputHashes.get(id));

const useOutputData = useCallback(
// eslint-disable-next-line prefer-arrow-functions/prefer-arrow-functions, func-names
function <T>(
outputId: OutputId
):
| readonly [value: T, inputHash: string]
| readonly [value: undefined, inputHash: undefined] {
function <T>(outputId: OutputId): UseOutputData<T> {
if (outputDataEntry) {
const value = outputDataEntry.data?.[outputId] as T | undefined;
if (value !== undefined) {
return [value, outputDataEntry.inputHash];
const last = outputDataEntry.data?.[outputId] as T | undefined;
if (last !== undefined) {
const stale = inputHash !== outputDataEntry.inputHash;
return { current: stale ? undefined : last, last, stale };
}
}
return NO_OUTPUT_DATA;
},
[outputDataEntry]
[outputDataEntry, inputHash]
);

const functions = functionDefinitions.get(schemaId)!.outputDefaults;
Expand Down
14 changes: 6 additions & 8 deletions src/renderer/components/outputs/DefaultImageOutput.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -42,24 +42,22 @@ export const DefaultImageOutput = memo(
c.schemata.get(schemaId).outputs.findIndex((o) => o.id === outputId)
);

const inputHash = useContextSelector(GlobalVolatileContext, (c) => c.inputHashes.get(id));
const [value, valueInputHash] = useOutputData<ImageBroadcastData>(outputId);
const sameHash = valueInputHash === inputHash;
const { current } = useOutputData<ImageBroadcastData>(outputId);
useEffect(() => {
if (value && sameHash) {
if (current) {
setManualOutputType(
id,
outputId,
new NamedExpression('Image', [
new NamedExpressionField('width', literal(value.width)),
new NamedExpressionField('height', literal(value.height)),
new NamedExpressionField('channels', literal(value.channels)),
new NamedExpressionField('width', literal(current.width)),
new NamedExpressionField('height', literal(current.height)),
new NamedExpressionField('channels', literal(current.channels)),
])
);
} else {
setManualOutputType(id, outputId, undefined);
}
}, [id, outputId, value, sameHash, setManualOutputType]);
}, [id, outputId, current, setManualOutputType]);

const { getNodes, getEdges } = useReactFlow<NodeData, EdgeData>();

Expand Down
11 changes: 5 additions & 6 deletions src/renderer/components/outputs/GenericOutput.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -19,27 +19,26 @@ export const GenericOutput = memo(

const schema = schemata.get(schemaId);

const [value] = useOutputData(outputId);

const { current } = useOutputData(outputId);
useEffect(() => {
if (isStartingNode(schema)) {
if (value !== undefined) {
if (current !== undefined) {
if (kind === 'text') {
setManualOutputType(id, outputId, literal(value as string));
setManualOutputType(id, outputId, literal(current as string));
} else if (kind === 'directory') {
setManualOutputType(
id,
outputId,
new NamedExpression('Directory', [
new NamedExpressionField('path', literal(value as string)),
new NamedExpressionField('path', literal(current as string)),
])
);
}
} else {
setManualOutputType(id, outputId, undefined);
}
}
}, [id, schemaId, value, kind, outputId, schema, setManualOutputType]);
}, [id, schemaId, current, kind, outputId, schema, setManualOutputType]);

return (
<Flex
Expand Down
20 changes: 9 additions & 11 deletions src/renderer/components/outputs/LargeImageOutput.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -22,29 +22,27 @@ export const LargeImageOutput = memo(
const { schemata } = useContext(BackendContext);
const schema = schemata.get(schemaId);

const inputHash = useContextSelector(GlobalVolatileContext, (c) => c.inputHashes.get(id));
const zoom = useContextSelector(GlobalVolatileContext, (c) => c.zoom);

const [value, valueInputHash] = useOutputData<LargeImageBroadcastData>(outputId);
const stale = value !== undefined && valueInputHash !== inputHash;
const { current, last, stale } = useOutputData<LargeImageBroadcastData>(outputId);

useEffect(() => {
if (isStartingNode(schema)) {
if (value) {
if (current) {
setManualOutputType(
id,
outputId,
new NamedExpression('Image', [
new NamedExpressionField('width', literal(value.width)),
new NamedExpressionField('height', literal(value.height)),
new NamedExpressionField('channels', literal(value.channels)),
new NamedExpressionField('width', literal(current.width)),
new NamedExpressionField('height', literal(current.height)),
new NamedExpressionField('channels', literal(current.channels)),
])
);
} else {
setManualOutputType(id, outputId, undefined);
}
}
}, [id, schemaId, value, outputId, schema, setManualOutputType]);
}, [id, schemaId, current, outputId, schema, setManualOutputType]);

const imgBgColor = 'var(--node-image-preview-bg)';
const fontColor = 'var(--node-image-preview-color)';
Expand Down Expand Up @@ -107,22 +105,22 @@ export const LargeImageOutput = memo(
overflow="hidden"
w="200px"
>
{value ? (
{last ? (
<Center
maxH="200px"
maxW="200px"
>
<Image
alt="Image preview failed to load, probably unsupported file type."
backgroundImage={
value.channels === 4
last.channels === 4
? ''
: ''
}
draggable={false}
maxH="200px"
maxW="200px"
src={value.image}
src={last.image}
sx={{
imageRendering: zoom > 2 ? 'pixelated' : 'auto',
}}
Expand Down
26 changes: 13 additions & 13 deletions src/renderer/components/outputs/NcnnModelOutput.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ const getColorMode = (channels: number) => {

export const NcnnModelOutput = memo(
({ id, outputId, useOutputData, animated, schemaId }: OutputProps) => {
const [value] = useOutputData<NcnnModelData>(outputId);
const { current } = useOutputData<NcnnModelData>(outputId);

const { setManualOutputType } = useContext(GlobalContext);
const { schemata } = useContext(BackendContext);
Expand All @@ -41,23 +41,23 @@ export const NcnnModelOutput = memo(

useEffect(() => {
if (isStartingNode(schema)) {
if (value) {
if (current) {
setManualOutputType(
id,
outputId,
new NamedExpression('NcnnNetwork', [
new NamedExpressionField('scale', literal(value.scale)),
new NamedExpressionField('inputChannels', literal(value.inNc)),
new NamedExpressionField('outputChannels', literal(value.outNc)),
new NamedExpressionField('nf', literal(value.nf)),
new NamedExpressionField('fp', literal(value.fp)),
new NamedExpressionField('scale', literal(current.scale)),
new NamedExpressionField('inputChannels', literal(current.inNc)),
new NamedExpressionField('outputChannels', literal(current.outNc)),
new NamedExpressionField('nf', literal(current.nf)),
new NamedExpressionField('fp', literal(current.fp)),
])
);
} else {
setManualOutputType(id, outputId, undefined);
}
}
}, [id, schemaId, value, outputId, schema, setManualOutputType]);
}, [id, schemaId, current, outputId, schema, setManualOutputType]);

const tagColor = 'var(--tag-bg)';
const fontColor = 'var(--tag-fg)';
Expand All @@ -70,7 +70,7 @@ export const NcnnModelOutput = memo(
verticalAlign="middle"
w="full"
>
{value && !animated ? (
{current && !animated ? (
<Center mt={1}>
<Wrap
justify="center"
Expand All @@ -82,31 +82,31 @@ export const NcnnModelOutput = memo(
bgColor={tagColor}
textColor={fontColor}
>
{value.scale}x
{current.scale}x
</Tag>
</WrapItem>
<WrapItem>
<Tag
bgColor={tagColor}
textColor={fontColor}
>
{getColorMode(value.inNc)}→{getColorMode(value.outNc)}
{getColorMode(current.inNc)}→{getColorMode(current.outNc)}
</Tag>
</WrapItem>
<WrapItem>
<Tag
bgColor={tagColor}
textColor={fontColor}
>
{value.nf}nf
{current.nf}nf
</Tag>
</WrapItem>
<WrapItem>
<Tag
bgColor={tagColor}
textColor={fontColor}
>
{value.fp}
{current.fp}
</Tag>
</WrapItem>
</Wrap>
Expand Down
30 changes: 15 additions & 15 deletions src/renderer/components/outputs/PyTorchOutput.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ const getColorMode = (channels: number) => {

export const PyTorchOutput = memo(
({ id, outputId, useOutputData, animated, schemaId }: OutputProps) => {
const [value] = useOutputData<PyTorchModelData>(outputId);
const { current } = useOutputData<PyTorchModelData>(outputId);

const { setManualOutputType } = useContext(GlobalContext);
const { schemata } = useContext(BackendContext);
Expand All @@ -42,24 +42,24 @@ export const PyTorchOutput = memo(

useEffect(() => {
if (isStartingNode(schema)) {
if (value) {
if (current) {
setManualOutputType(
id,
outputId,
new NamedExpression('PyTorchModel', [
new NamedExpressionField('scale', literal(value.scale)),
new NamedExpressionField('inputChannels', literal(value.inNc)),
new NamedExpressionField('outputChannels', literal(value.outNc)),
new NamedExpressionField('arch', literal(value.arch)),
new NamedExpressionField('size', literal(value.size.join('x'))),
new NamedExpressionField('subType', literal(value.subType)),
new NamedExpressionField('scale', literal(current.scale)),
new NamedExpressionField('inputChannels', literal(current.inNc)),
new NamedExpressionField('outputChannels', literal(current.outNc)),
new NamedExpressionField('arch', literal(current.arch)),
new NamedExpressionField('size', literal(current.size.join('x'))),
new NamedExpressionField('subType', literal(current.subType)),
])
);
} else {
setManualOutputType(id, outputId, undefined);
}
}
}, [id, schemaId, value, outputId, schema, setManualOutputType]);
}, [id, schemaId, current, outputId, schema, setManualOutputType]);

const tagColor = 'var(--tag-bg)';
const fontColor = 'var(--tag-fg)';
Expand All @@ -72,7 +72,7 @@ export const PyTorchOutput = memo(
verticalAlign="middle"
w="full"
>
{value && !animated ? (
{current && !animated ? (
<Center mt={1}>
<Wrap
justify="center"
Expand All @@ -84,34 +84,34 @@ export const PyTorchOutput = memo(
bgColor={tagColor}
textColor={fontColor}
>
{value.arch}
{current.arch}
</Tag>
</WrapItem>
<WrapItem>
<Tag
bgColor={tagColor}
textColor={fontColor}
>
{value.subType}
{current.subType}
</Tag>
</WrapItem>
<WrapItem>
<Tag
bgColor={tagColor}
textColor={fontColor}
>
{value.scale}x
{current.scale}x
</Tag>
</WrapItem>
<WrapItem>
<Tag
bgColor={tagColor}
textColor={fontColor}
>
{getColorMode(value.inNc)}→{getColorMode(value.outNc)}
{getColorMode(current.inNc)}→{getColorMode(current.outNc)}
</Tag>
</WrapItem>
{value.size.map((size) => (
{current.size.map((size) => (
<WrapItem key={size}>
<Tag
bgColor={tagColor}
Expand Down
13 changes: 10 additions & 3 deletions src/renderer/components/outputs/props.ts
Original file line number Diff line number Diff line change
@@ -1,16 +1,23 @@
import { Type } from '@chainner/navi';
import { OutputId, OutputKind, SchemaId } from '../../../common/common-types';

export interface UseOutputData<T> {
/** The current output data. Current here means most recent + up to date (= same input hash). */
readonly current: T | undefined;
/** The most recent output data. */
readonly last: T | undefined;
/** Whether the most recent output data ({@link last}) is not the current output data ({@link current}). */
readonly stale: boolean;
}

export interface OutputProps {
readonly id: string;
readonly outputId: OutputId;
readonly label: string;
readonly schemaId: SchemaId;
readonly definitionType: Type;
readonly hasHandle: boolean;
readonly useOutputData: <T>(
outputId: OutputId
) => readonly [value: T, inputHash: string] | readonly [value: undefined, inputHash: undefined];
readonly useOutputData: <T>(outputId: OutputId) => UseOutputData<T>;
readonly animated: boolean;
readonly kind: OutputKind;
}
Loading