Skip to content

Commit

Permalink
Fixed iterator helper node type persistence (#1222)
Browse files Browse the repository at this point in the history
* WIP

* new useOutputData
  • Loading branch information
RunDevelopment authored Nov 12, 2022
1 parent 6e31218 commit 8640760
Show file tree
Hide file tree
Showing 8 changed files with 78 additions and 70 deletions.
20 changes: 9 additions & 11 deletions src/renderer/components/node/NodeOutputs.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import { GenericOutput } from '../outputs/GenericOutput';
import { LargeImageOutput } from '../outputs/LargeImageOutput';
import { NcnnModelOutput } from '../outputs/NcnnModelOutput';
import { OutputContainer } from '../outputs/OutputContainer';
import { OutputProps } from '../outputs/props';
import { OutputProps, UseOutputData } from '../outputs/props';
import { PyTorchOutput } from '../outputs/PyTorchOutput';

interface FullOutputProps extends Omit<Output, 'id' | 'type'>, OutputProps {
Expand Down Expand Up @@ -57,7 +57,7 @@ const pickOutput = (kind: OutputKind, props: FullOutputProps) => {
);
};

const NO_OUTPUT_DATA = [undefined, undefined] as const;
const NO_OUTPUT_DATA: UseOutputData<never> = { current: undefined, last: undefined, stale: false };

interface NodeOutputProps {
outputs: readonly Output[];
Expand All @@ -71,23 +71,21 @@ export const NodeOutputs = memo(({ outputs, id, schemaId, animated = false }: No
const outputDataEntry = useContextSelector(GlobalVolatileContext, (c) =>
c.outputDataMap.get(id)
);
const inputHash = useContextSelector(GlobalVolatileContext, (c) => c.inputHashes.get(id));

const useOutputData = useCallback(
// eslint-disable-next-line prefer-arrow-functions/prefer-arrow-functions, func-names
function <T>(
outputId: OutputId
):
| readonly [value: T, inputHash: string]
| readonly [value: undefined, inputHash: undefined] {
function <T>(outputId: OutputId): UseOutputData<T> {
if (outputDataEntry) {
const value = outputDataEntry.data?.[outputId] as T | undefined;
if (value !== undefined) {
return [value, outputDataEntry.inputHash];
const last = outputDataEntry.data?.[outputId] as T | undefined;
if (last !== undefined) {
const stale = inputHash !== outputDataEntry.inputHash;
return { current: stale ? undefined : last, last, stale };
}
}
return NO_OUTPUT_DATA;
},
[outputDataEntry]
[outputDataEntry, inputHash]
);

const functions = functionDefinitions.get(schemaId)!.outputDefaults;
Expand Down
14 changes: 6 additions & 8 deletions src/renderer/components/outputs/DefaultImageOutput.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -42,24 +42,22 @@ export const DefaultImageOutput = memo(
c.schemata.get(schemaId).outputs.findIndex((o) => o.id === outputId)
);

const inputHash = useContextSelector(GlobalVolatileContext, (c) => c.inputHashes.get(id));
const [value, valueInputHash] = useOutputData<ImageBroadcastData>(outputId);
const sameHash = valueInputHash === inputHash;
const { current } = useOutputData<ImageBroadcastData>(outputId);
useEffect(() => {
if (value && sameHash) {
if (current) {
setManualOutputType(
id,
outputId,
new NamedExpression('Image', [
new NamedExpressionField('width', literal(value.width)),
new NamedExpressionField('height', literal(value.height)),
new NamedExpressionField('channels', literal(value.channels)),
new NamedExpressionField('width', literal(current.width)),
new NamedExpressionField('height', literal(current.height)),
new NamedExpressionField('channels', literal(current.channels)),
])
);
} else {
setManualOutputType(id, outputId, undefined);
}
}, [id, outputId, value, sameHash, setManualOutputType]);
}, [id, outputId, current, setManualOutputType]);

const { getNodes, getEdges } = useReactFlow<NodeData, EdgeData>();

Expand Down
11 changes: 5 additions & 6 deletions src/renderer/components/outputs/GenericOutput.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -19,27 +19,26 @@ export const GenericOutput = memo(

const schema = schemata.get(schemaId);

const [value] = useOutputData(outputId);

const { current } = useOutputData(outputId);
useEffect(() => {
if (isStartingNode(schema)) {
if (value !== undefined) {
if (current !== undefined) {
if (kind === 'text') {
setManualOutputType(id, outputId, literal(value as string));
setManualOutputType(id, outputId, literal(current as string));
} else if (kind === 'directory') {
setManualOutputType(
id,
outputId,
new NamedExpression('Directory', [
new NamedExpressionField('path', literal(value as string)),
new NamedExpressionField('path', literal(current as string)),
])
);
}
} else {
setManualOutputType(id, outputId, undefined);
}
}
}, [id, schemaId, value, kind, outputId, schema, setManualOutputType]);
}, [id, schemaId, current, kind, outputId, schema, setManualOutputType]);

return (
<Flex
Expand Down
20 changes: 9 additions & 11 deletions src/renderer/components/outputs/LargeImageOutput.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -22,29 +22,27 @@ export const LargeImageOutput = memo(
const { schemata } = useContext(BackendContext);
const schema = schemata.get(schemaId);

const inputHash = useContextSelector(GlobalVolatileContext, (c) => c.inputHashes.get(id));
const zoom = useContextSelector(GlobalVolatileContext, (c) => c.zoom);

const [value, valueInputHash] = useOutputData<LargeImageBroadcastData>(outputId);
const stale = value !== undefined && valueInputHash !== inputHash;
const { current, last, stale } = useOutputData<LargeImageBroadcastData>(outputId);

useEffect(() => {
if (isStartingNode(schema)) {
if (value) {
if (current) {
setManualOutputType(
id,
outputId,
new NamedExpression('Image', [
new NamedExpressionField('width', literal(value.width)),
new NamedExpressionField('height', literal(value.height)),
new NamedExpressionField('channels', literal(value.channels)),
new NamedExpressionField('width', literal(current.width)),
new NamedExpressionField('height', literal(current.height)),
new NamedExpressionField('channels', literal(current.channels)),
])
);
} else {
setManualOutputType(id, outputId, undefined);
}
}
}, [id, schemaId, value, outputId, schema, setManualOutputType]);
}, [id, schemaId, current, outputId, schema, setManualOutputType]);

const imgBgColor = 'var(--node-image-preview-bg)';
const fontColor = 'var(--node-image-preview-color)';
Expand Down Expand Up @@ -107,22 +105,22 @@ export const LargeImageOutput = memo(
overflow="hidden"
w="200px"
>
{value ? (
{last ? (
<Center
maxH="200px"
maxW="200px"
>
<Image
alt="Image preview failed to load, probably unsupported file type."
backgroundImage={
value.channels === 4
last.channels === 4
? 'data:image/webp;base64,UklGRigAAABXRUJQVlA4IBwAAAAwAQCdASoQABAACMCWJaQAA3AA/u11j//aQAAA'
: ''
}
draggable={false}
maxH="200px"
maxW="200px"
src={value.image}
src={last.image}
sx={{
imageRendering: zoom > 2 ? 'pixelated' : 'auto',
}}
Expand Down
26 changes: 13 additions & 13 deletions src/renderer/components/outputs/NcnnModelOutput.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ const getColorMode = (channels: number) => {

export const NcnnModelOutput = memo(
({ id, outputId, useOutputData, animated, schemaId }: OutputProps) => {
const [value] = useOutputData<NcnnModelData>(outputId);
const { current } = useOutputData<NcnnModelData>(outputId);

const { setManualOutputType } = useContext(GlobalContext);
const { schemata } = useContext(BackendContext);
Expand All @@ -41,23 +41,23 @@ export const NcnnModelOutput = memo(

useEffect(() => {
if (isStartingNode(schema)) {
if (value) {
if (current) {
setManualOutputType(
id,
outputId,
new NamedExpression('NcnnNetwork', [
new NamedExpressionField('scale', literal(value.scale)),
new NamedExpressionField('inputChannels', literal(value.inNc)),
new NamedExpressionField('outputChannels', literal(value.outNc)),
new NamedExpressionField('nf', literal(value.nf)),
new NamedExpressionField('fp', literal(value.fp)),
new NamedExpressionField('scale', literal(current.scale)),
new NamedExpressionField('inputChannels', literal(current.inNc)),
new NamedExpressionField('outputChannels', literal(current.outNc)),
new NamedExpressionField('nf', literal(current.nf)),
new NamedExpressionField('fp', literal(current.fp)),
])
);
} else {
setManualOutputType(id, outputId, undefined);
}
}
}, [id, schemaId, value, outputId, schema, setManualOutputType]);
}, [id, schemaId, current, outputId, schema, setManualOutputType]);

const tagColor = 'var(--tag-bg)';
const fontColor = 'var(--tag-fg)';
Expand All @@ -70,7 +70,7 @@ export const NcnnModelOutput = memo(
verticalAlign="middle"
w="full"
>
{value && !animated ? (
{current && !animated ? (
<Center mt={1}>
<Wrap
justify="center"
Expand All @@ -82,31 +82,31 @@ export const NcnnModelOutput = memo(
bgColor={tagColor}
textColor={fontColor}
>
{value.scale}x
{current.scale}x
</Tag>
</WrapItem>
<WrapItem>
<Tag
bgColor={tagColor}
textColor={fontColor}
>
{getColorMode(value.inNc)}{getColorMode(value.outNc)}
{getColorMode(current.inNc)}{getColorMode(current.outNc)}
</Tag>
</WrapItem>
<WrapItem>
<Tag
bgColor={tagColor}
textColor={fontColor}
>
{value.nf}nf
{current.nf}nf
</Tag>
</WrapItem>
<WrapItem>
<Tag
bgColor={tagColor}
textColor={fontColor}
>
{value.fp}
{current.fp}
</Tag>
</WrapItem>
</Wrap>
Expand Down
30 changes: 15 additions & 15 deletions src/renderer/components/outputs/PyTorchOutput.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ const getColorMode = (channels: number) => {

export const PyTorchOutput = memo(
({ id, outputId, useOutputData, animated, schemaId }: OutputProps) => {
const [value] = useOutputData<PyTorchModelData>(outputId);
const { current } = useOutputData<PyTorchModelData>(outputId);

const { setManualOutputType } = useContext(GlobalContext);
const { schemata } = useContext(BackendContext);
Expand All @@ -42,24 +42,24 @@ export const PyTorchOutput = memo(

useEffect(() => {
if (isStartingNode(schema)) {
if (value) {
if (current) {
setManualOutputType(
id,
outputId,
new NamedExpression('PyTorchModel', [
new NamedExpressionField('scale', literal(value.scale)),
new NamedExpressionField('inputChannels', literal(value.inNc)),
new NamedExpressionField('outputChannels', literal(value.outNc)),
new NamedExpressionField('arch', literal(value.arch)),
new NamedExpressionField('size', literal(value.size.join('x'))),
new NamedExpressionField('subType', literal(value.subType)),
new NamedExpressionField('scale', literal(current.scale)),
new NamedExpressionField('inputChannels', literal(current.inNc)),
new NamedExpressionField('outputChannels', literal(current.outNc)),
new NamedExpressionField('arch', literal(current.arch)),
new NamedExpressionField('size', literal(current.size.join('x'))),
new NamedExpressionField('subType', literal(current.subType)),
])
);
} else {
setManualOutputType(id, outputId, undefined);
}
}
}, [id, schemaId, value, outputId, schema, setManualOutputType]);
}, [id, schemaId, current, outputId, schema, setManualOutputType]);

const tagColor = 'var(--tag-bg)';
const fontColor = 'var(--tag-fg)';
Expand All @@ -72,7 +72,7 @@ export const PyTorchOutput = memo(
verticalAlign="middle"
w="full"
>
{value && !animated ? (
{current && !animated ? (
<Center mt={1}>
<Wrap
justify="center"
Expand All @@ -84,34 +84,34 @@ export const PyTorchOutput = memo(
bgColor={tagColor}
textColor={fontColor}
>
{value.arch}
{current.arch}
</Tag>
</WrapItem>
<WrapItem>
<Tag
bgColor={tagColor}
textColor={fontColor}
>
{value.subType}
{current.subType}
</Tag>
</WrapItem>
<WrapItem>
<Tag
bgColor={tagColor}
textColor={fontColor}
>
{value.scale}x
{current.scale}x
</Tag>
</WrapItem>
<WrapItem>
<Tag
bgColor={tagColor}
textColor={fontColor}
>
{getColorMode(value.inNc)}{getColorMode(value.outNc)}
{getColorMode(current.inNc)}{getColorMode(current.outNc)}
</Tag>
</WrapItem>
{value.size.map((size) => (
{current.size.map((size) => (
<WrapItem key={size}>
<Tag
bgColor={tagColor}
Expand Down
13 changes: 10 additions & 3 deletions src/renderer/components/outputs/props.ts
Original file line number Diff line number Diff line change
@@ -1,16 +1,23 @@
import { Type } from '@chainner/navi';
import { OutputId, OutputKind, SchemaId } from '../../../common/common-types';

export interface UseOutputData<T> {
/** The current output data. Current here means most recent + up to date (= same input hash). */
readonly current: T | undefined;
/** The most recent output data. */
readonly last: T | undefined;
/** Whether the most recent output data ({@link last}) is not the current output data ({@link current}). */
readonly stale: boolean;
}

export interface OutputProps {
readonly id: string;
readonly outputId: OutputId;
readonly label: string;
readonly schemaId: SchemaId;
readonly definitionType: Type;
readonly hasHandle: boolean;
readonly useOutputData: <T>(
outputId: OutputId
) => readonly [value: T, inputHash: string] | readonly [value: undefined, inputHash: undefined];
readonly useOutputData: <T>(outputId: OutputId) => UseOutputData<T>;
readonly animated: boolean;
readonly kind: OutputKind;
}
Loading

0 comments on commit 8640760

Please sign in to comment.