Skip to content

Commit

Permalink
Allow any valid node with no incoming connections and with side effec…
Browse files Browse the repository at this point in the history
…ts to run automatically (#2944)

* Allow any valid node with no incoming connections to be a "starting node"

* lint

* lint

* Use side effects to determine whether we should run the node or not

* Rename hook

* Update .vscode/settings.json
  • Loading branch information
joeyballentine authored Jun 17, 2024
1 parent 410e586 commit 5035430
Show file tree
Hide file tree
Showing 9 changed files with 61 additions and 26 deletions.
1 change: 1 addition & 0 deletions backend/src/packages/chaiNNer_ncnn/ncnn/io/load_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
see_also=[
"chainner:ncnn:load_models",
],
side_effects=True,
)
def load_model_node(
param_path: Path, bin_path: Path
Expand Down
1 change: 1 addition & 0 deletions backend/src/packages/chaiNNer_onnx/onnx/io/load_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
see_also=[
"chainner:onnx:load_models",
],
side_effects=True,
)
def load_model_node(path: Path) -> tuple[OnnxModel, Path, str]:
assert os.path.exists(path), f"Model file at location {path} does not exist"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def parse_ckpt_state_dict(checkpoint: dict):
see_also=[
"chainner:pytorch:load_models",
],
side_effects=True,
)
def load_model_node(
context: NodeContext, path: Path
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ def _for_ext(ext: str | Iterable[str], decoder: _Decoder) -> _Decoder:
DirectoryOutput("Directory", of_input=0),
FileNameOutput("Name", of_input=0),
],
side_effects=True,
)
def load_image_node(path: Path) -> tuple[np.ndarray, Path, str]:
logger.debug(f"Reading image from path: {path}")
Expand Down
4 changes: 0 additions & 4 deletions src/common/util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -237,10 +237,6 @@ export const topologicalSort = <T>(
return result.reverse();
};

export const isStartingNode = (schema: NodeSchema) => {
return !schema.inputs.some((i) => i.hasHandle) && schema.outputs.length > 0;
};

export const isEndingNode = (schema: NodeSchema) => {
return !schema.outputs.some((i) => i.hasHandle) && schema.inputs.length > 0;
};
Expand Down
22 changes: 6 additions & 16 deletions src/renderer/components/node/Node.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,7 @@ import { useReactFlow } from 'reactflow';
import { useContext, useContextSelector } from 'use-context-selector';
import { Input, NodeData } from '../../../common/common-types';
import { DisabledStatus } from '../../../common/nodes/disabled';
import {
EMPTY_ARRAY,
getInputValue,
isStartingNode,
parseSourceHandle,
} from '../../../common/util';
import { EMPTY_ARRAY, getInputValue, parseSourceHandle } from '../../../common/util';
import { Validity } from '../../../common/Validity';
import { AlertBoxContext } from '../../contexts/AlertBoxContext';
import { BackendContext } from '../../contexts/BackendContext';
Expand All @@ -30,6 +25,7 @@ import {
} from '../../contexts/ExecutionContext';
import { GlobalContext, GlobalVolatileContext } from '../../contexts/GlobalNodeState';
import { getCategoryAccentColor, getTypeAccentColors } from '../../helpers/accentColors';

import { getSingleFileWithExtension } from '../../helpers/dataTransfer';
import { NodeState, useNodeStateFromData } from '../../helpers/nodeState';
import { NO_DISABLED, UseDisabled, useDisabled } from '../../hooks/useDisabled';
Expand Down Expand Up @@ -256,15 +252,9 @@ const NodeInner = memo(({ data, selected }: NodeProps) => {
}
};

const startingNode = isStartingNode(schema);
const isNewIterator = schema.kind === 'generator';
const hasStaticValueInput = schema.inputs.some((i) => i.kind === 'static');
const reload = useRunNode(
data,
validity.isValid && startingNode && !isNewIterator && !hasStaticValueInput
);
const { reload, isLive } = useRunNode(data, validity.isValid);
const filesToWatch = useMemo(() => {
if (!startingNode) return EMPTY_ARRAY;
if (!isLive) return EMPTY_ARRAY;

const files: string[] = [];
for (const input of schema.inputs) {
Expand All @@ -278,15 +268,15 @@ const NodeInner = memo(({ data, selected }: NodeProps) => {

if (files.length === 0) return EMPTY_ARRAY;
return files;
}, [startingNode, data.inputData, schema]);
}, [isLive, data.inputData, schema]);
useWatchFiles(filesToWatch, reload);

const disabled = useDisabled(data);
const passthrough = usePassthrough(data);
const menu = useNodeMenu(data, {
disabled,
passthrough,
reload: startingNode ? reload : undefined,
reload: isLive ? reload : undefined,
});

const toggleCollapse = useCallback(() => {
Expand Down
8 changes: 5 additions & 3 deletions src/renderer/components/node/NodeOutputs.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ import { OutputId, OutputKind, Size } from '../../../common/common-types';
import { log } from '../../../common/log';
import { getChainnerScope } from '../../../common/types/chainner-scope';
import { ExpressionJson, fromJson } from '../../../common/types/json';
import { isStartingNode } from '../../../common/util';
import { BackendContext } from '../../contexts/BackendContext';
import { GlobalContext, GlobalVolatileContext } from '../../contexts/GlobalNodeState';
import { NodeState } from '../../helpers/nodeState';
import { useAutomaticFeatures } from '../../hooks/useAutomaticFeatures';
import { useIsCollapsedNode } from '../../hooks/useIsCollapsedNode';
import { GenericOutput } from '../outputs/GenericOutput';
import { LargeImageOutput } from '../outputs/LargeImageOutput';
Expand Down Expand Up @@ -81,14 +81,16 @@ export const NodeOutputs = memo(({ nodeState, animated }: NodeOutputProps) => {

const currentTypes = stale ? undefined : outputDataEntry?.types;

const { isAutomatic } = useAutomaticFeatures(id, schemaId);

useEffect(() => {
if (isStartingNode(schema)) {
if (isAutomatic) {
for (const output of schema.outputs) {
const type = evalExpression(currentTypes?.[output.id]);
setManualOutputType(id, output.id, type);
}
}
}, [id, currentTypes, schema, setManualOutputType]);
}, [id, currentTypes, schema, setManualOutputType, isAutomatic]);

const isCollapsed = useIsCollapsedNode();
if (isCollapsed) {
Expand Down
32 changes: 32 additions & 0 deletions src/renderer/hooks/useAutomaticFeatures.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import { getIncomers, useReactFlow } from 'reactflow';
import { useContext } from 'use-context-selector';
import { EdgeData, NodeData, SchemaId } from '../../common/common-types';
import { BackendContext } from '../contexts/BackendContext';

/**
* Determines whether a node should use automatic ahead-of-time features, such as individually running the node or determining certain type features automatically.
*/
export const useAutomaticFeatures = (id: string, schemaId: SchemaId) => {
const { schemata } = useContext(BackendContext);
const schema = schemata.get(schemaId);

const { getEdges, getNodes, getNode } = useReactFlow<NodeData, EdgeData>();
const thisNode = getNode(id);

// A node should not use automatic features if it has incoming connections
const hasIncomingConnections =
thisNode && getIncomers(thisNode, getNodes(), getEdges()).length > 0;

// If the node is a generator, it should not use automatic features
const isGenerator = schema.kind === 'generator';
// Same if it has any static input values
const hasStaticValueInput = schema.inputs.some((i) => i.kind === 'static');
// We should only use automatic features if the node has side effects
const { hasSideEffects } = schema;

return {
isAutomatic:
hasSideEffects && !hasIncomingConnections && !isGenerator && !hasStaticValueInput,
hasIncomingConnections,
};
};
17 changes: 14 additions & 3 deletions src/renderer/hooks/useRunNode.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import { AlertBoxContext } from '../contexts/AlertBoxContext';
import { BackendContext } from '../contexts/BackendContext';
import { GlobalContext } from '../contexts/GlobalNodeState';
import { useAsyncEffect } from './useAsyncEffect';
import { useAutomaticFeatures } from './useAutomaticFeatures';
import { useSettings } from './useSettings';

/**
Expand All @@ -16,8 +17,8 @@ import { useSettings } from './useSettings';
*/
export const useRunNode = (
{ inputData, id, schemaId }: NodeData,
shouldRun: boolean
): (() => void) => {
isValid: boolean
): { reload: () => void; isLive: boolean } => {
const { sendToast } = useContext(AlertBoxContext);
const { addIndividuallyRunning, removeIndividuallyRunning } = useContext(GlobalContext);
const { schemata, backend } = useContext(BackendContext);
Expand All @@ -39,6 +40,10 @@ export const useRunNode = (
[reloadCounter, inputs]
);
const lastInputHash = useRef<string>();

const { isAutomatic, hasIncomingConnections } = useAutomaticFeatures(id, schemaId);
const shouldRun = isValid && isAutomatic;

useAsyncEffect(
() => async (token) => {
if (inputHash === lastInputHash.current) {
Expand Down Expand Up @@ -85,5 +90,11 @@ export const useRunNode = (
};
}, [backend, id]);

return reload;
useEffect(() => {
if (hasIncomingConnections && didEverRun.current) {
backend.clearNodeCacheIndividual(id).catch(log.error);
}
}, [backend, hasIncomingConnections, id]);

return { reload, isLive: shouldRun };
};

0 comments on commit 5035430

Please sign in to comment.