From 0653e2f5492f070fdd618cfa0b5e31cac59d3f28 Mon Sep 17 00:00:00 2001 From: RunDevelopment Date: Wed, 18 Oct 2023 22:39:40 +0200 Subject: [PATCH 1/6] API mock --- backend/src/api.py | 46 ++++++++++++++++ backend/src/chain/chain.py | 22 +++----- .../ncnn/batch_processing/load_models.py | 9 ++-- .../onnx/batch_processing/load_models.py | 9 ++-- .../pytorch/iteration/load_models.py | 9 ++-- .../batch_processing/load_image_pairs.py | 19 ++----- .../image/batch_processing/load_images.py | 9 ++-- .../batch_processing/merge_spritesheet.py | 14 ++--- .../batch_processing/split_spritesheet.py | 3 +- .../image/video_frames/load_video.py | 15 ++++-- .../image/video_frames/save_video.py | 54 +++---------------- 11 files changed, 101 insertions(+), 108 deletions(-) diff --git a/backend/src/api.py b/backend/src/api.py index 98649c061..6065eab5f 100644 --- a/backend/src/api.py +++ b/backend/src/api.py @@ -19,6 +19,7 @@ from sanic.log import logger +import navi from base_types import InputId, OutputId from custom_types import NodeType, RunFn from node_check import ( @@ -80,6 +81,34 @@ class DefaultNode(TypedDict): schemaId: str +class IteratorInputInfo: + def __init__( + self, + inputs: int | InputId | List[int] | List[InputId] | List[int | InputId], + length_type: navi.ExpressionJson = "uint", + ) -> None: + self.inputs: List[InputId] = ( + [InputId(x) for x in inputs] + if isinstance(inputs, list) + else [InputId(inputs)] + ) + self.length_type: navi.ExpressionJson = length_type + + +class IteratorOutputInfo: + def __init__( + self, + outputs: int | OutputId | List[int] | List[OutputId] | List[int | OutputId], + length_type: navi.ExpressionJson = "uint", + ) -> None: + self.outputs: List[OutputId] = ( + [OutputId(x) for x in outputs] + if isinstance(outputs, list) + else [OutputId(outputs)] + ) + self.length_type: navi.ExpressionJson = length_type + + @dataclass(frozen=True) class NodeData: schema_id: str @@ -93,6 +122,9 @@ class NodeData: outputs: List[BaseOutput] group_layout: List[InputId | NestedIdGroup] + iterator_inputs: List[IteratorInputInfo] + iterator_outputs: List[IteratorOutputInfo] + side_effects: bool deprecated: bool features: List[FeatureId] @@ -129,6 +161,8 @@ def register( see_also: List[str] | str | None = None, features: List[FeatureId] | FeatureId | None = None, limited_to_8bpc: bool | str = False, + iterator_inputs: List[IteratorInputInfo] | IteratorInputInfo | None = None, + iterator_outputs: List[IteratorOutputInfo] | IteratorOutputInfo | None = None, ): if not isinstance(description, str): description = "\n\n".join(description) @@ -153,6 +187,16 @@ def to_list(x: List[S] | S | None) -> List[S]: see_also = to_list(see_also) features = to_list(features) + iterator_inputs = to_list(iterator_inputs) + iterator_outputs = to_list(iterator_outputs) + + if node_type == "collector": + assert len(iterator_inputs) == 1 and len(iterator_outputs) == 0 + elif node_type == "newIterator": + assert len(iterator_inputs) == 0 and len(iterator_outputs) == 1 + else: + assert len(iterator_inputs) == 0 and len(iterator_outputs) == 0 + def run_check(level: CheckLevel, run: Callable[[bool], None]): if level == CheckLevel.NONE: return @@ -193,6 +237,8 @@ def inner_wrapper(wrapped_func: T) -> T: inputs=p_inputs, group_layout=group_layout, outputs=p_outputs, + iterator_inputs=iterator_inputs, + iterator_outputs=iterator_outputs, side_effects=side_effects, deprecated=deprecated, features=features, diff --git a/backend/src/chain/chain.py b/backend/src/chain/chain.py index 89bbb1342..b0e56d619 100644 --- 
a/backend/src/chain/chain.py +++ b/backend/src/chain/chain.py @@ -31,16 +31,11 @@ class NewIteratorNode: def __init__(self, node_id: NodeId, schema_id: str): self.id: NodeId = node_id self.schema_id: str = schema_id - self.parent: None = None - self.__node = None - self.is_helper: bool = False def get_node(self) -> NodeData: - if self.__node is None: - node = registry.get_node(self.schema_id) - assert node.type == "newIterator", "Invalid iterator node" - self.__node = node - return self.__node + node = registry.get_node(self.schema_id) + assert node.type == "newIterator", "Invalid iterator node" + return node def has_side_effects(self) -> bool: return self.get_node().side_effects @@ -50,16 +45,11 @@ class CollectorNode: def __init__(self, node_id: NodeId, schema_id: str): self.id: NodeId = node_id self.schema_id: str = schema_id - self.parent: None = None - self.__node = None - self.is_helper: bool = False def get_node(self) -> NodeData: - if self.__node is None: - node = registry.get_node(self.schema_id) - assert node.type == "collector", "Invalid iterator node" - self.__node = node - return self.__node + node = registry.get_node(self.schema_id) + assert node.type == "collector", "Invalid collector node" + return node def has_side_effects(self) -> bool: return self.get_node().side_effects diff --git a/backend/src/packages/chaiNNer_ncnn/ncnn/batch_processing/load_models.py b/backend/src/packages/chaiNNer_ncnn/ncnn/batch_processing/load_models.py index 03037c687..3cbf6bcdc 100644 --- a/backend/src/packages/chaiNNer_ncnn/ncnn/batch_processing/load_models.py +++ b/backend/src/packages/chaiNNer_ncnn/ncnn/batch_processing/load_models.py @@ -5,7 +5,7 @@ from sanic.log import logger -from api import Iterator +from api import Iterator, IteratorOutputInfo from nodes.impl.ncnn.model import NcnnModelWrapper from nodes.properties.inputs import DirectoryInput from nodes.properties.outputs import ( @@ -41,18 +41,19 @@ "A counter that starts at 0 and increments by 1 for each model." 
), ], + iterator_outputs=IteratorOutputInfo(outputs=[0, 2, 3, 4]), node_type="newIterator", ) def load_models_node( directory: str, -) -> Iterator[Tuple[NcnnModelWrapper, str, str, str, int]]: +) -> Tuple[Iterator[Tuple[NcnnModelWrapper, str, str, int]], str]: logger.debug(f"Iterating over models in directory: {directory}") def load_model(filepath_pairs: Tuple[str, str], index: int): model, dirname, basename = load_model_node(filepath_pairs[0], filepath_pairs[1]) # Get relative path from root directory passed by Iterator directory input rel_path = os.path.relpath(dirname, directory) - return model, directory, rel_path, basename, index + return model, rel_path, basename, index param_files: List[str] = list_all_files_sorted(directory, [".param"]) bin_files: List[str] = list_all_files_sorted(directory, [".bin"]) @@ -76,4 +77,4 @@ def load_model(filepath_pairs: Tuple[str, str], index: int): model_files = list(zip(param_files, bin_files)) - return Iterator.from_list(model_files, load_model) + return Iterator.from_list(model_files, load_model), directory diff --git a/backend/src/packages/chaiNNer_onnx/onnx/batch_processing/load_models.py b/backend/src/packages/chaiNNer_onnx/onnx/batch_processing/load_models.py index 68d9baa6c..09412da84 100644 --- a/backend/src/packages/chaiNNer_onnx/onnx/batch_processing/load_models.py +++ b/backend/src/packages/chaiNNer_onnx/onnx/batch_processing/load_models.py @@ -5,7 +5,7 @@ from sanic.log import logger -from api import Iterator +from api import Iterator, IteratorOutputInfo from nodes.impl.onnx.model import OnnxModel from nodes.properties.inputs import DirectoryInput from nodes.properties.outputs import ( @@ -41,20 +41,21 @@ "A counter that starts at 0 and increments by 1 for each model." ), ], + iterator_outputs=IteratorOutputInfo(outputs=[0, 2, 3, 4]), node_type="newIterator", ) def load_models_node( directory: str, -) -> Iterator[Tuple[OnnxModel, str, str, str, int]]: +) -> Tuple[Iterator[Tuple[OnnxModel, str, str, int]], str]: logger.debug(f"Iterating over models in directory: {directory}") def load_model(path: str, index: int): model, dirname, basename = load_model_node(path) # Get relative path from root directory passed by Iterator directory input rel_path = os.path.relpath(dirname, directory) - return model, directory, rel_path, basename, index + return model, rel_path, basename, index supported_filetypes = [".onnx"] model_files = list_all_files_sorted(directory, supported_filetypes) - return Iterator.from_list(model_files, load_model) + return Iterator.from_list(model_files, load_model), directory diff --git a/backend/src/packages/chaiNNer_pytorch/pytorch/iteration/load_models.py b/backend/src/packages/chaiNNer_pytorch/pytorch/iteration/load_models.py index 0c0c0d480..a3f852af4 100644 --- a/backend/src/packages/chaiNNer_pytorch/pytorch/iteration/load_models.py +++ b/backend/src/packages/chaiNNer_pytorch/pytorch/iteration/load_models.py @@ -5,7 +5,7 @@ from sanic.log import logger -from api import Iterator +from api import Iterator, IteratorOutputInfo from nodes.impl.pytorch.types import PyTorchModel from nodes.properties.inputs import DirectoryInput from nodes.properties.outputs import DirectoryOutput, NumberOutput, TextOutput @@ -37,20 +37,21 @@ "A counter that starts at 0 and increments by 1 for each model." 
), ], + iterator_outputs=IteratorOutputInfo(outputs=[0, 2, 3, 4]), node_type="newIterator", ) def load_models_node( directory: str, -) -> Iterator[Tuple[PyTorchModel, str, str, str, int]]: +) -> Tuple[Iterator[Tuple[PyTorchModel, str, str, int]], str]: logger.debug(f"Iterating over models in directory: {directory}") def load_model(path: str, index: int): model, dirname, basename = load_model_node(path) # Get relative path from root directory passed by Iterator directory input rel_path = os.path.relpath(dirname, directory) - return model, directory, rel_path, basename, index + return model, rel_path, basename, index supported_filetypes = [".pt", ".pth", ".ckpt"] model_files: List[str] = list_all_files_sorted(directory, supported_filetypes) - return Iterator.from_list(model_files, load_model) + return Iterator.from_list(model_files, load_model), directory diff --git a/backend/src/packages/chaiNNer_standard/image/batch_processing/load_image_pairs.py b/backend/src/packages/chaiNNer_standard/image/batch_processing/load_image_pairs.py index 8d58949ac..2ce6934d7 100644 --- a/backend/src/packages/chaiNNer_standard/image/batch_processing/load_image_pairs.py +++ b/backend/src/packages/chaiNNer_standard/image/batch_processing/load_image_pairs.py @@ -5,7 +5,7 @@ import numpy as np -from api import Iterator +from api import Iterator, IteratorOutputInfo from nodes.groups import Condition, if_group from nodes.impl.image_formats import get_available_image_formats from nodes.properties.inputs import BoolInput, DirectoryInput, NumberInput @@ -49,6 +49,7 @@ "A counter that starts at 0 and increments by 1 for each image." ), ], + iterator_outputs=IteratorOutputInfo(outputs=[0, 1, 4, 5, 6, 7, 8]), node_type="newIterator", ) def load_image_pairs_node( @@ -56,7 +57,7 @@ def load_image_pairs_node( directory_b: str, use_limit: bool, limit: int, -) -> Iterator[Tuple[np.ndarray, np.ndarray, str, str, str, str, str, str, int]]: +) -> Tuple[Iterator[Tuple[np.ndarray, np.ndarray, str, str, str, str, int]], str, str]: def load_images(filepaths: Tuple[str, str], index: int): path_a, path_b = filepaths img_a, img_dir_a, basename_a = load_image_node(path_a) @@ -65,17 +66,7 @@ def load_images(filepaths: Tuple[str, str], index: int): # Get relative path from root directory passed by Iterator directory input rel_path_a = os.path.relpath(img_dir_a, directory_a) rel_path_b = os.path.relpath(img_dir_b, directory_b) - return ( - img_a, - img_b, - directory_a, - directory_b, - rel_path_a, - rel_path_b, - basename_a, - basename_b, - index, - ) + return img_a, img_b, rel_path_a, rel_path_b, basename_a, basename_b, index supported_filetypes = get_available_image_formats() @@ -94,4 +85,4 @@ def load_images(filepaths: Tuple[str, str], index: int): image_files = list(zip(image_files_a, image_files_b)) - return Iterator.from_list(image_files, load_images) + return Iterator.from_list(image_files, load_images), directory_a, directory_b diff --git a/backend/src/packages/chaiNNer_standard/image/batch_processing/load_images.py b/backend/src/packages/chaiNNer_standard/image/batch_processing/load_images.py index 632dd363d..44516c50a 100644 --- a/backend/src/packages/chaiNNer_standard/image/batch_processing/load_images.py +++ b/backend/src/packages/chaiNNer_standard/image/batch_processing/load_images.py @@ -7,7 +7,7 @@ import numpy as np from wcmatch import glob -from api import Iterator +from api import Iterator, IteratorOutputInfo from nodes.groups import Condition, if_group from nodes.impl.image_formats import get_available_image_formats from 
nodes.properties.inputs import BoolInput, DirectoryInput, NumberInput, TextInput @@ -79,6 +79,7 @@ def list_glob(directory: str, globexpr: str, ext_filter: List[str]) -> List[str] TextOutput("Name"), NumberOutput("Index"), ], + iterator_outputs=IteratorOutputInfo(outputs=[0, 2, 3, 4]), node_type="newIterator", ) def load_images_node( @@ -88,12 +89,12 @@ def load_images_node( glob_str: str, use_limit: bool, limit: int, -) -> Iterator[Tuple[np.ndarray, str, str, str, int]]: +) -> Tuple[Iterator[Tuple[np.ndarray, str, str, int]], str]: def load_image(path: str, index: int): img, img_dir, basename = load_image_node(path) # Get relative path from root directory passed by Iterator directory input rel_path = os.path.relpath(img_dir, directory) - return img, directory, rel_path, basename, index + return img, rel_path, basename, index supported_filetypes = get_available_image_formats() @@ -107,4 +108,4 @@ def load_image(path: str, index: int): if use_limit: just_image_files = just_image_files[:limit] - return Iterator.from_list(just_image_files, load_image) + return Iterator.from_list(just_image_files, load_image), directory diff --git a/backend/src/packages/chaiNNer_standard/image/batch_processing/merge_spritesheet.py b/backend/src/packages/chaiNNer_standard/image/batch_processing/merge_spritesheet.py index ce1d1ab14..78ab3a250 100644 --- a/backend/src/packages/chaiNNer_standard/image/batch_processing/merge_spritesheet.py +++ b/backend/src/packages/chaiNNer_standard/image/batch_processing/merge_spritesheet.py @@ -1,10 +1,8 @@ from __future__ import annotations -from typing import Tuple - import numpy as np -from api import Collector +from api import Collector, IteratorInputInfo from nodes.properties.inputs import ImageInput, NumberInput from nodes.properties.outputs import ImageOutput @@ -39,6 +37,7 @@ "The number of columns to split the image into. The width of the image must be a multiple of this number." ), ], + iterator_inputs=IteratorInputInfo(inputs=0), outputs=[ ImageOutput( image_type=""" @@ -52,16 +51,13 @@ node_type="collector", ) def merge_spritesheet_node( - _tile: np.ndarray, + _: None, rows: int, columns: int, -) -> Collector[Tuple[np.ndarray, int, int], np.ndarray]: +) -> Collector[np.ndarray, np.ndarray]: results = [] - # TODO: This system is pretty messy. We need to separate out the creation - # of the collector from the actual collection. As-is we have unused inputs - def on_iterate(inputs: Tuple[np.ndarray, int, int]): - tile = inputs[0] + def on_iterate(tile: np.ndarray): results.append(tile) def on_complete(): diff --git a/backend/src/packages/chaiNNer_standard/image/batch_processing/split_spritesheet.py b/backend/src/packages/chaiNNer_standard/image/batch_processing/split_spritesheet.py index 6567be877..c3b507e59 100644 --- a/backend/src/packages/chaiNNer_standard/image/batch_processing/split_spritesheet.py +++ b/backend/src/packages/chaiNNer_standard/image/batch_processing/split_spritesheet.py @@ -4,7 +4,7 @@ import numpy as np -from api import Iterator +from api import Iterator, IteratorOutputInfo from nodes.properties.inputs import ImageInput, NumberInput from nodes.properties.outputs import ImageOutput, NumberOutput from nodes.utils.utils import get_h_w_c @@ -52,6 +52,7 @@ "A counter that starts at 0 and increments by 1 for each image." 
), ], + iterator_outputs=IteratorOutputInfo(outputs=[0, 1], length_type="Input1 * Input2"), node_type="newIterator", ) def split_spritesheet_node( diff --git a/backend/src/packages/chaiNNer_standard/image/video_frames/load_video.py b/backend/src/packages/chaiNNer_standard/image/video_frames/load_video.py index 523198174..fd86f8923 100644 --- a/backend/src/packages/chaiNNer_standard/image/video_frames/load_video.py +++ b/backend/src/packages/chaiNNer_standard/image/video_frames/load_video.py @@ -7,7 +7,7 @@ import ffmpeg import numpy as np -from api import Iterator +from api import Iterator, IteratorOutputInfo from nodes.groups import Condition, if_group from nodes.properties.inputs import BoolInput, NumberInput, VideoFileInput from nodes.properties.outputs import ( @@ -54,13 +54,14 @@ NumberOutput("FPS"), AudioStreamOutput(), ], + iterator_outputs=IteratorOutputInfo(outputs=[0, 1]), node_type="newIterator", ) def load_video_node( path: str, use_limit: bool, limit: int, -) -> Iterator[Tuple[np.ndarray, int, str, str, float, Any]]: +) -> Tuple[Iterator[Tuple[np.ndarray, int]], str, str, float, Any]: video_dir, video_name, _ = split_file_path(path) ffmpeg_reader = ( @@ -121,7 +122,13 @@ def iterator(): break in_frame = np.frombuffer(in_bytes, np.uint8).reshape([height, width, 3]) in_frame = cv2.cvtColor(in_frame, cv2.COLOR_RGB2BGR) - yield in_frame, index, video_dir, video_name, fps, audio_stream + yield in_frame, index index += 1 - return Iterator.from_iter(iter_supplier=iterator, expected_length=frame_count) + return ( + Iterator.from_iter(iter_supplier=iterator, expected_length=frame_count), + video_dir, + video_name, + fps, + audio_stream, + ) diff --git a/backend/src/packages/chaiNNer_standard/image/video_frames/save_video.py b/backend/src/packages/chaiNNer_standard/image/video_frames/save_video.py index b71bc27ea..cef17386f 100644 --- a/backend/src/packages/chaiNNer_standard/image/video_frames/save_video.py +++ b/backend/src/packages/chaiNNer_standard/image/video_frames/save_video.py @@ -4,14 +4,14 @@ from dataclasses import dataclass from enum import Enum from subprocess import Popen -from typing import Any, Optional, Tuple +from typing import Any, Optional import cv2 import ffmpeg import numpy as np from sanic.log import logger -from api import Collector +from api import Collector, IteratorInputInfo from nodes.groups import Condition, if_enum_group, if_group from nodes.impl.image_utils import to_uint8 from nodes.properties.inputs import ( @@ -170,12 +170,13 @@ class Writer: .with_id(11) ), ], + iterator_inputs=IteratorInputInfo(inputs=0), outputs=[], node_type="collector", side_effects=True, ) def save_video_node( - _frames: np.ndarray, + _: None, save_dir: str, video_name: str, video_encoder: VideoEncoder, @@ -191,27 +192,7 @@ def save_video_node( audio: Any, audio_settings: AudioSettings, audio_reduced_settings: AudioReducedSettings, -) -> Collector[ - Tuple[ - np.ndarray, - str, - str, - VideoEncoder, # encoder - VideoContainer, # h264_container - VideoContainer, # h265_container - VideoContainer, # ffv1_container - VideoContainer, # vp9_container - str, # video_preset - int, - bool, - str, - float, - Any, - AudioSettings, - AudioReducedSettings, - ], - None, -]: +) -> Collector[np.ndarray, None,]: encoder = VideoEncoder(video_encoder) container = None @@ -279,30 +260,7 @@ def save_video_node( writer = Writer() - # TODO: This system is pretty messy. We need to separate out the creation - # of the collector from the actual collection. 
As-is we have unused inputs
-    def on_iterate(
-        inputs: Tuple[
-            np.ndarray,
-            str,
-            str,
-            VideoEncoder,  # encoder
-            VideoContainer,  # h264_container
-            VideoContainer,  # h265_container
-            VideoContainer,  # ffv1_container
-            VideoContainer,  # vp9_container
-            str,  # video_preset
-            int,
-            bool,
-            str,
-            float,
-            Any,
-            AudioSettings,
-            AudioReducedSettings,
-        ],
-    ):
-        img = inputs[0]
-
+    def on_iterate(img: np.ndarray):
         # Create the writer and run process
         if writer.out is None:
             h, w, _ = get_h_w_c(img)

From 5bdef9599dcda76de1d06b045e7db39325718788 Mon Sep 17 00:00:00 2001
From: RunDevelopment
Date: Fri, 20 Oct 2023 17:26:03 +0200
Subject: [PATCH 2/6] WIP

---
 backend/src/api.py                        |  24 +
 backend/src/chain/cache.py                |  39 +-
 backend/src/chain/chain.py                |  86 +-
 backend/src/chain/json.py                 |   4 +-
 backend/src/chain/optimize.py             |   8 +-
 backend/src/events.py                     |   5 +
 .../batch_processing/split_spritesheet.py |   4 +-
 backend/src/process.py                    | 740 +++++++++++-------
 backend/src/server.py                     |  75 +-
 9 files changed, 612 insertions(+), 373 deletions(-)

diff --git a/backend/src/api.py b/backend/src/api.py
index 6065eab5f..53acc1fd3 100644
--- a/backend/src/api.py
+++ b/backend/src/api.py
@@ -131,6 +131,16 @@ class NodeData:
 
     run: RunFn
 
+    @property
+    def single_iterator_input(self) -> IteratorInputInfo:
+        assert len(self.iterator_inputs) == 1
+        return self.iterator_inputs[0]
+
+    @property
+    def single_iterator_output(self) -> IteratorOutputInfo:
+        assert len(self.iterator_outputs) == 1
+        return self.iterator_outputs[0]
+
 
 T = TypeVar("T", bound=RunFn)
 S = TypeVar("S")
@@ -602,6 +612,20 @@ def supplier():
 
         return Iterator(supplier, len(l))
 
+    @staticmethod
+    def from_range(count: int, map_fn: Callable[[int], I]) -> "Iterator[I]":
+        """
+        Creates a new iterator with the given number of items where each item is
+        lazily evaluated. The iterable will be equivalent to `map(map_fn, range(count))`.
+ """ + assert count >= 0 + + def supplier(): + for i in range(count): + yield map_fn(i) + + return Iterator(supplier, count) + N = TypeVar("N") R = TypeVar("R") diff --git a/backend/src/chain/cache.py b/backend/src/chain/cache.py index a46781351..0264ca7e1 100644 --- a/backend/src/chain/cache.py +++ b/backend/src/chain/cache.py @@ -1,9 +1,11 @@ +from __future__ import annotations + import gc -from typing import Dict, Generic, Iterable, Optional, Set, TypeVar +from typing import Dict, Generic, Iterable, List, Optional, Set, TypeVar from sanic.log import logger -from .chain import Chain, NodeId +from .chain import Chain, Edge, FunctionNode, NewIteratorNode, NodeId class CacheStrategy: @@ -29,12 +31,41 @@ def no_caching(self) -> bool: def get_cache_strategies(chain: Chain) -> Dict[NodeId, CacheStrategy]: """Create a map with the cache strategies for all nodes in the given chain.""" + iterator_map = chain.get_parent_iterator_map() + + def any_are_iterated(out_edges: List[Edge]) -> bool: + for out_edge in out_edges: + target = chain.nodes[out_edge.target.id] + if isinstance(target, FunctionNode) and iterator_map[target] is not None: + return True + return False + result: Dict[NodeId, CacheStrategy] = {} for node in chain.nodes.values(): - out_edges = chain.edges_from(node.id) + strategy: CacheStrategy - strategy: CacheStrategy = CacheStrategy(len(out_edges)) + out_edges = chain.edges_from(node.id) + if isinstance(node, FunctionNode) and iterator_map[node] is not None: + # the function node is iterated + strategy = CacheStrategy(len(out_edges)) + else: + # the node is NOT implicitly iterated + + if isinstance(node, NewIteratorNode): + # we only care about non-iterator outputs + iterator_output = node.data.single_iterator_output + out_edges = [ + out_edge + for out_edge in out_edges + if out_edge.source.output_id not in iterator_output.outputs + ] + + if any_are_iterated(out_edges): + # some output is used by an iterated node + strategy = StaticCaching + else: + strategy = CacheStrategy(len(out_edges)) result[node.id] = strategy diff --git a/backend/src/chain/chain.py b/backend/src/chain/chain.py index b0e56d619..a0833f6df 100644 --- a/backend/src/chain/chain.py +++ b/backend/src/chain/chain.py @@ -1,4 +1,6 @@ -from typing import Callable, Dict, List, TypeVar, Union +from __future__ import annotations + +from typing import Callable, Dict, List, Set, TypeVar, Union from api import NodeData, registry from base_types import InputId, NodeId, OutputId @@ -19,40 +21,33 @@ class FunctionNode: def __init__(self, node_id: NodeId, schema_id: str): self.id: NodeId = node_id self.schema_id: str = schema_id - - def get_node(self) -> NodeData: - return registry.get_node(self.schema_id) + self.data: NodeData = registry.get_node(schema_id) + assert self.data.type == "regularNode" def has_side_effects(self) -> bool: - return self.get_node().side_effects + return self.data.side_effects class NewIteratorNode: def __init__(self, node_id: NodeId, schema_id: str): self.id: NodeId = node_id self.schema_id: str = schema_id - - def get_node(self) -> NodeData: - node = registry.get_node(self.schema_id) - assert node.type == "newIterator", "Invalid iterator node" - return node + self.data: NodeData = registry.get_node(schema_id) + assert self.data.type == "newIterator" def has_side_effects(self) -> bool: - return self.get_node().side_effects + return self.data.side_effects class CollectorNode: def __init__(self, node_id: NodeId, schema_id: str): self.id: NodeId = node_id self.schema_id: str = schema_id - - def 
get_node(self) -> NodeData:
-        node = registry.get_node(self.schema_id)
-        assert node.type == "collector", "Invalid collector node"
-        return node
+        self.data: NodeData = registry.get_node(schema_id)
+        assert self.data.type == "collector"
 
     def has_side_effects(self) -> bool:
-        return self.get_node().side_effects
+        return self.data.side_effects
 
 
 Node = Union[FunctionNode, NewIteratorNode, CollectorNode]
@@ -110,3 +105,60 @@ def remove_node(self, node_id: NodeId):
             self.__edges_by_target[e.target.id].remove(e)
         for e in self.__edges_by_target.pop(node_id, []):
             self.__edges_by_source[e.source.id].remove(e)
+
+    def topological_order(self) -> List[NodeId]:
+        """
+        Returns all nodes in topological order.
+        """
+        result: List[NodeId] = []
+        visited: Set[NodeId] = set()
+
+        def visit(node_id: NodeId):
+            if node_id in visited:
+                return
+            visited.add(node_id)
+
+            for e in self.edges_from(node_id):
+                visit(e.target.id)
+
+            result.append(node_id)
+
+        for node_id in self.nodes:
+            visit(node_id)
+
+        return result
+
+    def get_parent_iterator_map(self) -> Dict[FunctionNode, NewIteratorNode | None]:
+        """
+        Returns a map of all function nodes to their parent iterator node (if any).
+        """
+        iterator_cache: Dict[FunctionNode, NewIteratorNode | None] = {}
+
+        def get_iterator(r: FunctionNode) -> NewIteratorNode | None:
+            if r in iterator_cache:
+                return iterator_cache[r]
+
+            iterator: NewIteratorNode | None = None
+
+            for in_edge in self.edges_to(r.id):
+                source = self.nodes[in_edge.source.id]
+                if isinstance(source, FunctionNode):
+                    iterator = get_iterator(source)
+                    if iterator is not None:
+                        break
+                elif isinstance(source, NewIteratorNode):
+                    if (
+                        in_edge.source.output_id
+                        in source.data.single_iterator_output.outputs
+                    ):
+                        iterator = source
+                        break
+
+            iterator_cache[r] = iterator
+            return iterator
+
+        for node in self.nodes.values():
+            if isinstance(node, FunctionNode):
+                get_iterator(node)
+
+        return iterator_cache
diff --git a/backend/src/chain/json.py b/backend/src/chain/json.py
index 3f76fe250..73192b5ba 100644
--- a/backend/src/chain/json.py
+++ b/backend/src/chain/json.py
@@ -71,8 +71,8 @@ def parse_json(json: List[JsonNode]) -> Tuple[Chain, InputMap]:
         input_map.set(node.id, inputs)
 
     for index_edge in index_edges:
-        source_node = chain.nodes[index_edge.from_id].get_node()
-        target_node = chain.nodes[index_edge.to_id].get_node()
+        source_node = chain.nodes[index_edge.from_id].data
+        target_node = chain.nodes[index_edge.to_id].data
 
         chain.add_edge(
             Edge(
diff --git a/backend/src/chain/optimize.py b/backend/src/chain/optimize.py
index 0a7a9f39e..99bb037b9 100644
--- a/backend/src/chain/optimize.py
+++ b/backend/src/chain/optimize.py
@@ -1,10 +1,6 @@
 from sanic.log import logger
 
-from .chain import Chain, Node
-
-
-def __has_side_effects(node: Node) -> bool:
-    return node.has_side_effects()
+from .chain import Chain
 
 
 def __removed_dead_nodes(chain: Chain) -> bool:
@@ -14,7 +10,7 @@ def __removed_dead_nodes(chain: Chain) -> bool:
 
     changed = False
     for node in list(chain.nodes.values()):
-        is_dead = len(chain.edges_from(node.id)) == 0 and not __has_side_effects(node)
+        is_dead = len(chain.edges_from(node.id)) == 0 and not node.has_side_effects()
         if is_dead:
             chain.remove_node(node.id)
             changed = True
diff --git a/backend/src/events.py b/backend/src/events.py
index f3e05fe1d..7b1d710a8 100644
--- a/backend/src/events.py
+++ b/backend/src/events.py
@@ -4,6 +4,8 @@
 from base_types import InputId, NodeId, OutputId
 from nodes.base_input import ErrorValue
 
+# Data of events
+
 
 class FinishData(TypedDict):
     message: str
@@
-50,6 +52,9 @@ class BackendStatusData(TypedDict): statusProgress: Optional[float] +# Events + + class FinishEvent(TypedDict): event: Literal["finish"] data: FinishData diff --git a/backend/src/packages/chaiNNer_standard/image/batch_processing/split_spritesheet.py b/backend/src/packages/chaiNNer_standard/image/batch_processing/split_spritesheet.py index c3b507e59..5c41b9e27 100644 --- a/backend/src/packages/chaiNNer_standard/image/batch_processing/split_spritesheet.py +++ b/backend/src/packages/chaiNNer_standard/image/batch_processing/split_spritesheet.py @@ -71,7 +71,7 @@ def split_spritesheet_node( individual_h = h // rows individual_w = w // columns - def get_sprite(_, index: int): + def get_sprite(index: int): row = index // columns col = index % columns @@ -83,4 +83,4 @@ def get_sprite(_, index: int): return sprite, index # We just need the index, so we can pass in a list of None's - return Iterator.from_list([None] * (rows * columns), get_sprite) + return Iterator.from_range(rows * columns, get_sprite) diff --git a/backend/src/process.py b/backend/src/process.py index 9cbc05bc4..3627e3f22 100644 --- a/backend/src/process.py +++ b/backend/src/process.py @@ -6,15 +6,17 @@ import time import uuid from concurrent.futures import ThreadPoolExecutor -from typing import Dict, Iterable, List, Optional, Set +from contextlib import contextmanager +from dataclasses import dataclass +from typing import Dict, Iterable, List, Optional, Union from sanic.log import logger from api import Collector, Iterator, NodeData -from base_types import NodeId, OutputId -from chain.cache import CacheStrategy, OutputCache, get_cache_strategies -from chain.chain import Chain, CollectorNode, NewIteratorNode, Node -from chain.input import EdgeInput, InputMap +from base_types import InputId, NodeId, OutputId +from chain.cache import CacheStrategy, OutputCache, StaticCaching, get_cache_strategies +from chain.chain import Chain, CollectorNode, FunctionNode, NewIteratorNode, Node +from chain.input import EdgeInput, Input, InputMap from events import Event, EventQueue, InputsDict from nodes.base_output import BaseOutput from progress_controller import Aborted, ProgressController @@ -58,21 +60,28 @@ def collect_input_information( def enforce_inputs( - inputs: Iterable[object], node: NodeData, node_id: NodeId + inputs: List[object], + node: NodeData, + node_id: NodeId, + ignored_inputs: List[InputId], ) -> List[object]: - inputs = list(inputs) - try: enforced_inputs: List[object] = [] + for index, value in enumerate(inputs): - enforced_inputs.append(node.inputs[index].enforce_(value)) + i = node.inputs[index] + if i.id in ignored_inputs: + enforced_inputs.append(None) + else: + enforced_inputs.append(i.enforce_(value)) + return enforced_inputs except Exception as e: input_dict = collect_input_information(node, inputs, enforced=False) raise NodeExecutionError(node_id, node, str(e), input_dict) from e -def enforce_output(raw_output: object, node: NodeData) -> Output: +def enforce_output(raw_output: object, node: NodeData) -> RegularOutput: l = len(node.outputs) output: Output @@ -92,32 +101,56 @@ def enforce_output(raw_output: object, node: NodeData) -> Output: for i, o in enumerate(node.outputs): output[i] = o.enforce(output[i]) - return output + return RegularOutput(output) + + +def enforce_iterator_output(raw_output: object, node: NodeData) -> IteratorOutput: + l = len(node.outputs) + iterator_output = node.single_iterator_output + + partial: list[object] = [None] * l + + if l == len(iterator_output.outputs): + assert 
isinstance(raw_output, Iterator), "Expected the output to be an iterator" + return IteratorOutput(iterator=raw_output, partial_output=partial) + + assert l > len(iterator_output.outputs) + assert isinstance(raw_output, (tuple, list)) + + iterator, *rest = raw_output + assert isinstance( + iterator, Iterator + ), "Expected the first tuple element to be an iterator" + assert len(rest) == l - len(iterator_output.outputs) + + # output-specific validations + for i, o in enumerate(node.outputs): + if o.id not in iterator_output.outputs: + partial[i] = o.enforce(rest.pop(0)) + + return IteratorOutput(iterator=iterator, partial_output=partial) def run_node( node: NodeData, inputs: List[object], node_id: NodeId -) -> Output | Iterator | Collector: - assert ( - node.type == "regularNode" - or node.type == "newIterator" - or node.type == "collector" - ) +) -> NodeOutput | CollectorOutput: + if node.type == "collector": + ignored_inputs = node.single_iterator_input.inputs + else: + ignored_inputs = [] + + enforced_inputs = enforce_inputs(inputs, node, node_id, ignored_inputs) - enforced_inputs = [] - if node.type != "collector": - enforced_inputs = enforce_inputs(inputs, node, node_id) try: - if node.type != "collector": - raw_output = node.run(*enforced_inputs) - else: - raw_output = node.run(*list(inputs)) - if node.type == "newIterator": - assert isinstance(raw_output, Iterator) - return raw_output + raw_output = node.run(*enforced_inputs) + if node.type == "collector": assert isinstance(raw_output, Collector) - return raw_output + return CollectorOutput(raw_output) + if node.type == "newIterator": + return enforce_iterator_output(raw_output, node) + + assert node.type == "regularNode" return enforce_output(raw_output, node) except Aborted: raise @@ -125,20 +158,80 @@ def run_node( raise except Exception as e: # collect information to provide good error messages - if node.type != "collector": - input_dict = collect_input_information(node, enforced_inputs) - else: - input_dict = collect_input_information(node, list(inputs)) + input_dict = collect_input_information(node, enforced_inputs) raise NodeExecutionError(node_id, node, str(e), input_dict) from e +def run_collector_iterate( + node: CollectorNode, inputs: List[object], collector: Collector +) -> None: + iterator_input = node.data.single_iterator_input + + def get_partial_inputs(values: List[object]) -> List[object]: + partial_inputs: List[object] = [] + index = 0 + for i in node.data.inputs: + if i.id in iterator_input.inputs: + partial_inputs.append(values[index]) + index += 1 + else: + partial_inputs.append(None) + return partial_inputs + + enforced_inputs: List[object] = [] + try: + for i in node.data.inputs: + if i.id in iterator_input.inputs: + enforced_inputs.append(i.enforce_(inputs[len(enforced_inputs)])) + except Exception as e: + input_dict = collect_input_information( + node.data, get_partial_inputs(inputs), enforced=False + ) + raise NodeExecutionError(node.id, node.data, str(e), input_dict) from e + + input_value = ( + enforced_inputs[0] if len(enforced_inputs) == 1 else tuple(enforced_inputs) + ) + + try: + raw_output = collector.on_iterate(input_value) + assert raw_output is None + except Exception as e: + input_dict = collect_input_information( + node.data, get_partial_inputs(enforced_inputs) + ) + raise NodeExecutionError(node.id, node.data, str(e), input_dict) from e + + +class _Timer: + def __init__(self) -> None: + self.duration: float = 0 + + def run(self): + return _timer_run(self) + + def add_since(self, start: float): + 
self.duration += time.time() - start + + +@contextmanager +def _timer_run(timer: _Timer): + start = time.time() + try: + yield None + finally: + timer.add_since(start) + + def compute_broadcast(output: Output, node_outputs: Iterable[BaseOutput]): data: Dict[OutputId, object] = dict() types: Dict[OutputId, object] = dict() for index, node_output in enumerate(node_outputs): try: - data[node_output.id] = node_output.get_broadcast_data(output[index]) - types[node_output.id] = node_output.get_broadcast_type(output[index]) + value = output[index] + if value is not None: + data[node_output.id] = node_output.get_broadcast_data(value) + types[node_output.id] = node_output.get_broadcast_type(value) except Exception as e: logger.error(f"Error broadcasting output: {e}") return data, types @@ -158,6 +251,25 @@ def __init__( self.inputs: InputsDict = inputs +@dataclass(frozen=True) +class RegularOutput: + output: Output + + +@dataclass(frozen=True) +class IteratorOutput: + iterator: Iterator + partial_output: Output + + +@dataclass(frozen=True) +class CollectorOutput: + collector: Collector + + +NodeOutput = Union[RegularOutput, IteratorOutput] + + class Executor: """ Class for executing chaiNNer's processing logic @@ -171,14 +283,13 @@ def __init__( loop: asyncio.AbstractEventLoop, queue: EventQueue, pool: ThreadPoolExecutor, - parent_cache: Optional[OutputCache[Output]] = None, + parent_cache: Optional[OutputCache[NodeOutput]] = None, ): self.execution_id: str = uuid.uuid4().hex self.chain = chain self.inputs = inputs self.send_broadcast_data: bool = send_broadcast_data - self.cache: OutputCache[Output] = OutputCache(parent=parent_cache) - self.collector_cache: Dict[NodeId, Collector] = {} + self.cache: OutputCache[NodeOutput] = OutputCache(parent=parent_cache) self.__broadcast_tasks: List[asyncio.Task[None]] = [] self.progress = ProgressController() @@ -191,7 +302,17 @@ def __init__( self.cache_strategy: Dict[NodeId, CacheStrategy] = get_cache_strategies(chain) - async def process(self, node_id: NodeId) -> Output | Iterator: + @property + def completed_percentage(self) -> float: + completed = self.cache.keys().union(self.completed_node_ids) + return len(completed) / len(self.chain.nodes) + + async def process(self, node_id: NodeId) -> NodeOutput | CollectorOutput: + # Return cached output value from an already-run node if that cached output exists + cached = self.cache.get(node_id) + if cached is not None: + return cached + node = self.chain.nodes[node_id] try: return await self.__process(node) @@ -200,107 +321,149 @@ async def process(self, node_id: NodeId) -> Output | Iterator: except NodeExecutionError: raise except Exception as e: - raise NodeExecutionError(node.id, node.get_node(), str(e), {}) from e + raise NodeExecutionError(node.id, node.data, str(e), {}) from e - async def __process(self, node: Node) -> Output | Iterator: - """Process a single node""" + async def process_regular_node(self, node_id: NodeId) -> RegularOutput: + assert self.chain.nodes[node_id].data.type == "regularNode" + result = await self.process(node_id) + assert isinstance(result, RegularOutput) + return result - # Return cached output value from an already-run node if that cached output exists - cached = self.cache.get(node.id) - if cached is not None: - if not node.id in self.completed_node_ids: - self.completed_node_ids.add(node.id) - await self.queue.put(self.__create_node_finish(node.id)) - return cached + async def __get_node_output(self, node_id: NodeId, output_index: int) -> object: + """ + Returns the output 
value of the given node. - logger.debug(f"node: {node}") - logger.debug(f"Running node {node.id}") + Note: `output_index` is NOT an output ID. + """ - await self.queue.put(self.__create_node_start(node.id)) + # Recursively get the value of the input + output = await self.process(node_id) - await self.progress.suspend() + if isinstance(output, CollectorOutput): + # this generally shouldn't be possible + raise ValueError("A collector was not run before another node needed it.") - inputs = [] - for node_input in self.inputs.get(node.id): + if isinstance(output, IteratorOutput): + value = output.partial_output[output_index] + assert value is not None, "An iterator output was not assigned correctly" + return value + + assert isinstance(output, RegularOutput) + return output.output[output_index] + + async def __resolve_node_input(self, node_input: Input) -> object: + if isinstance(node_input, EdgeInput): # If input is a dict indicating another node, use that node's output value - if isinstance(node_input, EdgeInput): - # Recursively get the value of the input - processed_input = await self.process(node_input.id) - assert not isinstance(processed_input, Iterator) - assert not isinstance(processed_input, Collector) - inputs.append(processed_input[node_input.index]) + # Recursively get the value of the input + return await self.__get_node_output(node_input.id, node_input.index) + else: # Otherwise, just use the given input (number, string, etc) + return node_input.value + + async def __gather_inputs(self, node: Node) -> List[object]: + """ + Returns the list of input values for the given node. + """ + + # we want to ignore some inputs if we are running a collector node + ignore: set[int] = set() + if isinstance(node, CollectorNode): + iterator_input = node.data.single_iterator_input + for input_index, i in enumerate(node.data.inputs): + if i.id in iterator_input.inputs: + ignore.add(input_index) + + assigned_inputs = self.inputs.get(node.id) + assert len(assigned_inputs) == len(node.data.inputs) + + inputs = [] + for input_index, node_input in enumerate(assigned_inputs): + if input_index in ignore: + inputs.append(None) else: - inputs.append(node_input.value) + inputs.append(await self.__resolve_node_input(node_input)) - await self.progress.suspend() + return inputs - # Create node based on given category/name information - node_instance = node.get_node() + async def __gather_collector_inputs(self, node: Node) -> List[object]: + """ + Returns the input values to be consumed by `Collector.on_iterate`. 
+ """ - if node_instance.type == "newIterator": - output, execution_time = await self.loop.run_in_executor( - self.pool, - timed_supplier( - functools.partial(run_node, node_instance, inputs, node.id) - ), - ) - assert isinstance(output, Iterator) + assert isinstance(node, CollectorNode) + assert node.data.type == "collector" - await self.progress.suspend() - elif node_instance.type == "collector": - collector_node = self.collector_cache[node.id] - collector_node.on_iterate(inputs) - output = None + iterator_input = node.data.single_iterator_input - await self.progress.suspend() - else: - output, execution_time = await self.loop.run_in_executor( - self.pool, - timed_supplier( - functools.partial(run_node, node_instance, inputs, node.id) - ), - ) + assigned_inputs = self.inputs.get(node.id) + assert len(assigned_inputs) == len(node.data.inputs) - await self.progress.suspend() - await self.__broadcast_data(node_instance, node.id, execution_time, output) # type: ignore + inputs = [] + for input_index, node_input in enumerate(assigned_inputs): + i = node.data.inputs[input_index] + if i.id in iterator_input.inputs: + inputs.append(await self.__resolve_node_input(node_input)) + + return inputs + + async def __process(self, node: Node) -> NodeOutput | CollectorOutput: + """ + Process a single node. + + In the case of iterators and collectors, it will only run the node itself, + not the actual iteration or collection. + """ + + logger.debug(f"node: {node}") + logger.debug(f"Running node {node.id}") + + await self.queue.put(self.__create_node_start(node.id)) + await self.progress.suspend() + + inputs = await self.__gather_inputs(node) + + await self.progress.suspend() + + output, execution_time = await self.loop.run_in_executor( + self.pool, + timed_supplier(functools.partial(run_node, node.data, inputs, node.id)), + ) + await self.progress.suspend() + + if isinstance(output, RegularOutput): + await self.__broadcast_data(node, execution_time, output.output) + elif isinstance(output, IteratorOutput): + await self.__broadcast_data(node, execution_time, output.partial_output) # Cache the output of the node - # If we are executing a free node from within an iterator, - # we want to store the result in the cache of the parent executor - if node.get_node().type != "collector": - write_cache = self.cache - write_cache.set(node.id, output, self.cache_strategy[node.id]) # type: ignore + if not isinstance(output, CollectorOutput): + self.cache.set(node.id, output, self.cache_strategy[node.id]) - return output # type: ignore + await self.progress.suspend() + + return output async def __broadcast_data( self, - node_instance: NodeData, - node_id: NodeId, + node: Node, execution_time: float, output: Output, ): - finished = self.cache.keys() - finished.add(node_id) - finished = list(finished) - - self.completed_node_ids.add(node_id) + self.completed_node_ids.add(node.id) async def send_broadcast(): data, types = await self.loop.run_in_executor( - self.pool, lambda: compute_broadcast(output, node_instance.outputs) + self.pool, lambda: compute_broadcast(output, node.data.outputs) ) await self.queue.put( { "event": "node-finish", "data": { - "nodeId": node_id, + "nodeId": node.id, "executionTime": execution_time, "data": data, "types": types, - "progressPercent": len(self.completed_node_ids) - / len(self.chain.nodes), + "progressPercent": self.completed_percentage, }, } ) @@ -308,8 +471,8 @@ async def send_broadcast(): # Only broadcast the output if the node has outputs and the output is not cached if ( 
self.send_broadcast_data - and len(node_instance.outputs) > 0 - and not self.cache.has(node_id) + and len(node.data.outputs) > 0 + and not self.cache.has(node.id) ): # broadcasts are done is parallel, so don't wait self.__broadcast_tasks.append(self.loop.create_task(send_broadcast())) @@ -318,34 +481,15 @@ async def send_broadcast(): { "event": "node-finish", "data": { - "nodeId": node_id, + "nodeId": node.id, "executionTime": execution_time, "data": None, "types": None, - "progressPercent": len(self.completed_node_ids) - / len(self.chain.nodes), + "progressPercent": self.completed_percentage, }, } ) - def __create_node_finish(self, node_id: NodeId) -> Event: - finished = self.cache.keys() - finished.add(node_id) - finished = list(finished) - - self.completed_node_ids.add(node_id) - - return { - "event": "node-finish", - "data": { - "nodeId": node_id, - "executionTime": None, - "data": None, - "types": None, - "progressPercent": len(self.completed_node_ids) / len(self.chain.nodes), - }, - } - def __create_node_start(self, node_id: NodeId) -> Event: return { "event": "node-start", @@ -354,199 +498,199 @@ def __create_node_start(self, node_id: NodeId) -> Event: }, } - def __get_output_nodes(self) -> List[NodeId]: - output_nodes: List[NodeId] = [] - for node in self.chain.nodes.values(): - side_effects = node.has_side_effects() - if side_effects: - output_nodes.append(node.id) - return output_nodes - - def __get_iterator_nodes(self) -> List[NodeId]: - iterator_nodes: List[NodeId] = [] - for node in self.chain.nodes.values(): - if isinstance(node, NewIteratorNode): - iterator_nodes.append(node.id) - return iterator_nodes - - def __get_collector_nodes(self) -> List[NodeId]: - collector_nodes: List[NodeId] = [] - for node in self.chain.nodes.values(): - if isinstance(node, CollectorNode): - collector_nodes.append(node.id) - return collector_nodes - - def __get_downstream_nodes(self, node: NodeId) -> Set[NodeId]: - downstream_nodes: List[NodeId] = [] - for edge in self.chain.edges_from(node): - downstream_nodes.append(edge.target.id) - for downstream_node in downstream_nodes: - downstream_nodes.extend(self.__get_downstream_nodes(downstream_node)) - return set(downstream_nodes) - - def __get_upstream_nodes(self, node: NodeId) -> Set[NodeId]: - upstream_nodes: List[NodeId] = [] - for edge in self.chain.edges_to(node): - upstream_nodes.append(edge.source.id) - for upstream_node in upstream_nodes: - upstream_nodes.extend(self.__get_upstream_nodes(upstream_node)) - return set(upstream_nodes) + def __get_iterated_nodes( + self, node: NewIteratorNode + ) -> tuple[set[CollectorNode], set[FunctionNode], set[Node]]: + """ + Returns all collector and output nodes iterated by the given iterator node + """ + collectors: set[CollectorNode] = set() + output_nodes: set[FunctionNode] = set() + + seen: set[Node] = {node} + + def visit(n: Node): + if n in seen: + return + seen.add(n) + + if isinstance(n, CollectorNode): + collectors.add(n) + elif isinstance(n, NewIteratorNode): + raise ValueError("Nested iterators are not supported") + else: + assert isinstance(n, FunctionNode) + + if n.has_side_effects(): + output_nodes.add(n) + + # follow edges + for edge in self.chain.edges_from(n.id): + target_node = self.chain.nodes[edge.target.id] + visit(target_node) + + iterator_output = node.data.single_iterator_output + for edge in self.chain.edges_from(node.id): + # only follow iterator outputs + if edge.source.output_id in iterator_output.outputs: + target_node = self.chain.nodes[edge.target.id] + visit(target_node) + + 
return collectors, output_nodes, seen + + def __iterator_fill_partial_output( + self, node: NodeData, partial_output: Output, values: object + ) -> Output: + assert node.type == "newIterator" + iterator_output = node.single_iterator_output + + values_list: list[object] = [] + if len(iterator_output.outputs) == 1: + values_list.append(values) + else: + assert isinstance(values, (tuple, list)) + values_list.extend(values) - async def __process_nodes(self): + assert len(values_list) == len(iterator_output.outputs) + + output: Output = partial_output.copy() + for index, o in enumerate(node.outputs): + if o.id in iterator_output.outputs: + output[index] = o.enforce(values_list.pop(0)) + + return output + + async def __process_iterator_node(self, node: NewIteratorNode): await self.progress.suspend() - iterator_node_set = set() - chain_output_nodes = self.__get_output_nodes() + # run the iterator node itself before anything else + iterator_output = await self.process(node.id) + assert isinstance(iterator_output, IteratorOutput) + + collector_nodes, output_nodes, all_iterated_nodes = self.__get_iterated_nodes( + node + ) + all_iterated_nodes = {n.id for n in all_iterated_nodes} + + if len(collector_nodes) == 0 and len(output_nodes) == 0: + # unusual, but this can happen + # since we don't need to actually iterate the iterator, we can stop here + return - collector_nodes = self.__get_collector_nodes() - self.collector_cache: Dict[NodeId, Collector] = {} - collector_downstreams = set() + def fill_partial_output(values: object) -> RegularOutput: + return RegularOutput( + self.__iterator_fill_partial_output( + node.data, iterator_output.partial_output, values + ) + ) - # Run each of the collector nodes first. This gives us all the collector objects that we will use when iterating + # run each of the collector nodes + collectors: list[tuple[Collector, _Timer, CollectorNode]] = [] for collector_node in collector_nodes: - inputs = [] - for collector_input in self.inputs.get(collector_node): - # If input is a dict indicating another node, use that node's output value - if isinstance(collector_input, EdgeInput): - # We can't use connections for collectors in case the connection is to an iterator - inputs.append(None) - else: - inputs.append(collector_input.value) - node_instance = self.chain.nodes[collector_node].get_node() - collector_output, execution_time = await self.loop.run_in_executor( - self.pool, - timed_supplier( - functools.partial(run_node, node_instance, inputs, collector_node) - ), + await self.progress.suspend() + timer = _Timer() + with timer.run(): + collector_output = await self.process(collector_node.id) + assert isinstance(collector_output, CollectorOutput) + collectors.append((collector_output.collector, timer, collector_node)) + + # timing iterations + times: List[float] = [] + expected_length = iterator_output.iterator.expected_length + start_time = time.time() + last_time = [start_time] + + async def update_progress(): + times.append(time.time() - last_time[0]) + iterations = len(times) + last_time[0] = time.time() + await self.__update_progress( + node.id, times, iterations, max(expected_length, iterations) ) - assert isinstance(collector_output, Collector) - self.collector_cache[collector_node] = collector_output - # Anything downstream from the collector we don't want to run yet, so we keep track of them here - downstream_from_collector = self.__get_downstream_nodes(collector_node) - collector_downstreams.update(downstream_from_collector) - - before_iteration_time = time.time() - 
- # Now run each of the iterators - for iterator_node in self.__get_iterator_nodes(): - # Get all downstream nodes of the iterator - # This excludes any nodes that are downstream of a collector, as well as collectors themselves - downstream_nodes = [ - x - for x in self.__get_downstream_nodes(iterator_node) - if x not in collector_downstreams and x not in collector_nodes - ] - output_nodes = [x for x in chain_output_nodes if x in downstream_nodes] - upstream_nodes = [self.__get_upstream_nodes(x) for x in output_nodes] - flat_upstream_nodes = set() - for x in upstream_nodes: - flat_upstream_nodes.update(x) - combined_subchain = flat_upstream_nodes.union(downstream_nodes) - iterator_node_set = iterator_node_set.union(combined_subchain) - - node_instance = self.chain.nodes[iterator_node].get_node() - assert node_instance.type == "newIterator" - - self.cache.set(iterator_node, None, CacheStrategy(0)) # type: ignore - - iter_result = await self.process(iterator_node) - - assert isinstance(iter_result, Iterator) - - num_outgoers = len(self.chain.edges_from(iterator_node)) - - start_time = time.time() - last_time = start_time - times: List[float] = [] - enforced_values = None - for index, values in enumerate(iter_result.iter_supplier()): - await self.queue.put(self.__create_node_start(iterator_node)) - - self.cache.delete_many(downstream_nodes) - enforced_values = enforce_output( - values, self.chain.nodes[iterator_node].get_node() - ) - after_time = time.time() - execution_time = after_time - last_time - times.append(execution_time) - await self.__broadcast_data( - node_instance, iterator_node, execution_time, enforced_values - ) - await self.__update_progress( - iterator_node, times, index, iter_result.expected_length - ) - last_time = after_time + # iterate + await self.__update_progress(node.id, times, 0, expected_length) - # Set the cache to the value of the generator, so that downstream nodes will pull from that - self.cache.set( - iterator_node, enforced_values, CacheStrategy(num_outgoers) - ) - # Run each of the collector nodes - for collector_node in collector_nodes: - await self.progress.suspend() - await self.process(collector_node) - # Run each of the output nodes - for output_node in output_nodes: - await self.progress.suspend() - await self.process(output_node) - - logger.debug(self.cache.keys()) - end_time = time.time() - execution_time = end_time - start_time - if enforced_values is not None: - await self.__broadcast_data( - node_instance, iterator_node, execution_time, enforced_values + for values in iterator_output.iterator.iter_supplier(): + # write current values to cache + iter_output = fill_partial_output(values) + self.cache.set(node.id, iter_output, StaticCaching) + + # broadcast + # TODO: Execution time. 
I just don't think any values makes sense here + await self.__broadcast_data(node, 0, iter_output.output) + + # run each of the output nodes + for output_node in output_nodes: + await self.process(output_node.id) + + # run each of the collector nodes + for collector, timer, collector_node in collectors: + await self.progress.suspend() + iterate_inputs = await self.__gather_collector_inputs(collector_node) + await self.progress.suspend() + with timer.run(): + run_collector_iterate(collector_node, iterate_inputs, collector) + + # clear cache for next iteration + self.cache.delete_many(all_iterated_nodes) + + await self.progress.suspend() + await update_progress() + await self.progress.suspend() + + # reset cached value + self.cache.delete_many(all_iterated_nodes) + self.cache.set(node.id, iterator_output, self.cache_strategy[node.id]) + + # re-broadcast final value + iterations = len(times) + await self.__finish_progress(node.id, iterations) + await self.__broadcast_data( + node, time.time() - start_time, iterator_output.partial_output + ) + + # finalize collectors + for collector, timer, collector_node in collectors: + await self.progress.suspend() + with timer.run(): + collector_output = enforce_output( + collector.on_complete(), collector_node.data ) - await self.__finish_progress(iterator_node, iter_result.expected_length) - # Complete each of the collector nodes, and cache their values - for collector_node in collector_nodes: - collector_result = self.collector_cache[collector_node].on_complete() - enforced_values = enforce_output( - collector_result, self.chain.nodes[collector_node].get_node() + # TODO: execution time + await self.__broadcast_data( + collector_node, timer.duration, collector_output.output ) + self.cache.set( - collector_node, - enforced_values, - CacheStrategy(len(self.chain.edges_from(collector_node))), - ) - collector_time = time.time() - before_iteration_time - await self.__broadcast_data( - self.chain.nodes[collector_node].get_node(), - collector_node, - collector_time, - enforced_values, + collector_node.id, + collector_output, + self.cache_strategy[collector_node.id], ) - # Now run everything downstream of the collectors - collector_downstream_outputs = [ - x for x in chain_output_nodes if x in collector_downstreams - ] - for output_node in collector_downstream_outputs: - await self.progress.suspend() - await self.process(output_node) + async def __process_nodes(self): + # we first need to run iterator nodes in topological order + for node_id in self.chain.topological_order(): + node = self.chain.nodes[node_id] + if isinstance(node, NewIteratorNode): + await self.__process_iterator_node(node) - iterator_node_set.update(collector_nodes) - iterator_node_set.update(collector_downstreams) + # now the output nodes outside of iterators # Now run everything that is not in an iterator lineage - without_iterator_lineage = [ - x for x in self.chain.nodes.values() if x not in iterator_node_set + non_iterator_output_nodes = [ + node + for node, iter_node in self.chain.get_parent_iterator_map().items() + if iter_node is None and node.has_side_effects() ] + for output_node in non_iterator_output_nodes: + await self.progress.suspend() + await self.process(output_node.id) + # clear cache after the chain is done self.cache.clear() - if len(without_iterator_lineage) > 0: - non_iterator_output_nodes = [ - x for x in chain_output_nodes if x not in iterator_node_set - ] - for output_node in non_iterator_output_nodes: - await self.progress.suspend() - await self.process(output_node) - 
-        logger.debug(self.cache.keys())
-
         # await all broadcasts
         tasks = self.__broadcast_tasks
         self.__broadcast_tasks = []
diff --git a/backend/src/server.py b/backend/src/server.py
index 35b57177c..ee9f86554 100644
--- a/backend/src/server.py
+++ b/backend/src/server.py
@@ -1,5 +1,4 @@
 import asyncio
-import functools
 import gc
 import importlib
 import logging
@@ -21,6 +20,8 @@
 import api
 from base_types import NodeId
 from chain.cache import OutputCache
+from chain.chain import Chain, FunctionNode
+from chain.input import InputMap
 from chain.json import JsonNode, parse_json
 from chain.optimize import optimize
 from custom_types import UpdateProgressFn
@@ -33,7 +34,7 @@
     JsonExecutionOptions,
     set_execution_options,
 )
-from process import Executor, NodeExecutionError, Output, compute_broadcast, run_node
+from process import Executor, NodeExecutionError, NodeOutput
 from progress_controller import Aborted
 from response import (
     alreadyRunningResponse,
@@ -43,14 +44,13 @@
 )
 from server_config import ServerConfig
 from system import is_arm_mac
-from util import timed_supplier


 class AppContext:
     def __init__(self):
         self.config: ServerConfig = None  # type: ignore
         self.executor: Optional[Executor] = None
-        self.cache: Dict[NodeId, Output] = dict()
+        self.cache: Dict[NodeId, NodeOutput] = dict()
         # This will be initialized by after_server_start.
         # This is necessary because we don't know Sanic's event loop yet.
         self.queue: EventQueue = None  # type: ignore
@@ -169,12 +169,12 @@ async def run(request: Request):
         exec_opts = ExecutionOptions.parse(full_data["options"])
         set_execution_options(exec_opts)
         executor = Executor(
-            chain,
-            inputs,
-            full_data["sendBroadcastData"],
-            app.loop,
-            ctx.queue,
-            ctx.pool,
+            chain=chain,
+            inputs=inputs,
+            send_broadcast_data=full_data["sendBroadcastData"],
+            loop=app.loop,
+            queue=ctx.queue,
+            pool=ctx.pool,
             parent_cache=OutputCache(static_data=ctx.cache.copy()),
         )
         try:
@@ -230,44 +230,31 @@ async def run_individual(request: Request):
         logger.debug(full_data)
         exec_opts = ExecutionOptions.parse(full_data["options"])
         set_execution_options(exec_opts)
-        # Create node based on given category/name information
-        node_instance = api.registry.get_node(full_data["schemaId"])
+
+        chain = Chain()
+        chain.add_node(FunctionNode(node_id, full_data["schemaId"]))
+
+        input_map = InputMap()
+        input_map.set_values(node_id, full_data["inputs"])
+
+        executor = Executor(
+            chain=chain,
+            inputs=input_map,
+            send_broadcast_data=True,
+            loop=app.loop,
+            queue=ctx.queue,
+            pool=ctx.pool,
+        )
         with runIndividualCounter:
-            # Run the node and pass in inputs as args
-            output, execution_time = await app.loop.run_in_executor(
-                None,
-                timed_supplier(
-                    functools.partial(
-                        run_node, node_instance, full_data["inputs"], node_id
-                    )
-                ),
-            )
-            # Cache the output of the node
-            if not isinstance(output, api.Iterator) and not isinstance(
-                output, api.Collector
-            ):
+            try:
+                output = await executor.process_regular_node(node_id)
                 ctx.cache[node_id] = output
+            except Aborted:
+                pass
+            finally:
+                gc.collect()
-            # Broadcast the output from the individual run
-            node_outputs = node_instance.outputs
-            if len(node_outputs) > 0:
-                assert not isinstance(output, api.Iterator)
-                assert not isinstance(output, api.Collector)
-                data, types = compute_broadcast(output, node_outputs)
-                await ctx.queue.put(
-                    {
-                        "event": "node-finish",
-                        "data": {
-                            "nodeId": node_id,
-                            "executionTime": execution_time,
-                            "data": data,
-                            "types": types,
-                            "progressPercent": None,
-                        },
-                    }
-                )
-            gc.collect()
         return json({"success": True, "data": None})
     except Exception as exception:
         logger.error(exception, exc_info=True)

From 26aee9b57e865234e4b394be55eafe4e7041ebc3 Mon Sep 17 00:00:00 2001
From: RunDevelopment
Date: Fri, 20 Oct 2023 17:30:40 +0200
Subject: [PATCH 3/6] Minor change in timer

---
 backend/src/process.py | 16 ++++++----------
 1 file changed, 6 insertions(+), 10 deletions(-)

diff --git a/backend/src/process.py b/backend/src/process.py
index 3627e3f22..1e8fd2931 100644
--- a/backend/src/process.py
+++ b/backend/src/process.py
@@ -207,22 +207,18 @@ class _Timer:
     def __init__(self) -> None:
         self.duration: float = 0

+    @contextmanager
     def run(self):
-        return _timer_run(self)
+        start = time.time()
+        try:
+            yield None
+        finally:
+            self.add_since(start)

     def add_since(self, start: float):
         self.duration += time.time() - start


-@contextmanager
-def _timer_run(timer: _Timer):
-    start = time.time()
-    try:
-        yield None
-    finally:
-        timer.add_since(start)
-
-
 def compute_broadcast(output: Output, node_outputs: Iterable[BaseOutput]):
     data: Dict[OutputId, object] = dict()
     types: Dict[OutputId, object] = dict()
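
The change above turns _Timer.run() into a context manager method, which is what lets the executor wrap each collector iteration in a plain "with timer.run():" block (as in the executor rework earlier in this series) and accumulate the total duration across iterations. A minimal usage sketch, assuming only the class as shown in the hunk; time.sleep stands in for real node work:

import time
from contextlib import contextmanager

class _Timer:
    def __init__(self) -> None:
        self.duration: float = 0

    @contextmanager
    def run(self):
        # time the body of the `with` block and add it to the running total
        start = time.time()
        try:
            yield None
        finally:
            self.add_since(start)

    def add_since(self, start: float):
        self.duration += time.time() - start

timer = _Timer()
for _ in range(3):
    with timer.run():
        time.sleep(0.01)  # stand-in for one collector iteration
print(f"collector time over all iterations: {timer.duration:.3f}s")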

From 74c62bd5d6dcf675196886d1de30410d5b339bcc Mon Sep 17 00:00:00 2001
From: RunDevelopment
Date: Mon, 23 Oct 2023 18:26:42 +0200
Subject: [PATCH 4/6] Better types for iterator nodes

---
 .../chaiNNer_ncnn/ncnn/batch_processing/load_models.py      | 2 +-
 .../chaiNNer_onnx/onnx/batch_processing/load_models.py      | 2 +-
 .../chaiNNer_pytorch/pytorch/iteration/load_models.py       | 2 +-
 .../chaiNNer_standard/image/batch_processing/load_images.py | 2 +-
 .../chaiNNer_standard/image/video_frames/load_video.py      | 6 +++---
 5 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/backend/src/packages/chaiNNer_ncnn/ncnn/batch_processing/load_models.py b/backend/src/packages/chaiNNer_ncnn/ncnn/batch_processing/load_models.py
index 3cbf6bcdc..ea679a85e 100644
--- a/backend/src/packages/chaiNNer_ncnn/ncnn/batch_processing/load_models.py
+++ b/backend/src/packages/chaiNNer_ncnn/ncnn/batch_processing/load_models.py
@@ -34,7 +34,7 @@
     ],
     outputs=[
         NcnnModelOutput(),
-        DirectoryOutput("Directory"),
+        DirectoryOutput("Directory", output_type="Input0"),
         TextOutput("Subdirectory Path"),
         TextOutput("Name"),
         NumberOutput("Index", output_type="uint").with_docs(
diff --git a/backend/src/packages/chaiNNer_onnx/onnx/batch_processing/load_models.py b/backend/src/packages/chaiNNer_onnx/onnx/batch_processing/load_models.py
index 09412da84..aebf8d151 100644
--- a/backend/src/packages/chaiNNer_onnx/onnx/batch_processing/load_models.py
+++ b/backend/src/packages/chaiNNer_onnx/onnx/batch_processing/load_models.py
@@ -34,7 +34,7 @@
     ],
     outputs=[
         OnnxModelOutput(),
-        DirectoryOutput("Directory"),
+        DirectoryOutput("Directory", output_type="Input0"),
         TextOutput("Subdirectory Path"),
         TextOutput("Name"),
         NumberOutput("Index", output_type="uint").with_docs(
diff --git a/backend/src/packages/chaiNNer_pytorch/pytorch/iteration/load_models.py b/backend/src/packages/chaiNNer_pytorch/pytorch/iteration/load_models.py
index a3f852af4..ec51da8d2 100644
--- a/backend/src/packages/chaiNNer_pytorch/pytorch/iteration/load_models.py
+++ b/backend/src/packages/chaiNNer_pytorch/pytorch/iteration/load_models.py
@@ -30,7 +30,7 @@
     ],
     outputs=[
         ModelOutput(),
-        DirectoryOutput("Directory"),
+        DirectoryOutput("Directory", output_type="Input0"),
         TextOutput("Subdirectory Path"),
         TextOutput("Name"),
         NumberOutput("Index", output_type="uint").with_docs(
diff --git a/backend/src/packages/chaiNNer_standard/image/batch_processing/load_images.py b/backend/src/packages/chaiNNer_standard/image/batch_processing/load_images.py
index 44516c50a..7a688f9e0 100644
--- a/backend/src/packages/chaiNNer_standard/image/batch_processing/load_images.py
+++ b/backend/src/packages/chaiNNer_standard/image/batch_processing/load_images.py
@@ -74,7 +74,7 @@ def list_glob(directory: str, globexpr: str, ext_filter: List[str]) -> List[str]
     ],
     outputs=[
         ImageOutput(),
-        DirectoryOutput("Directory"),
+        DirectoryOutput("Directory", output_type="Input0"),
         TextOutput("Subdirectory Path"),
         TextOutput("Name"),
         NumberOutput("Index"),
diff --git a/backend/src/packages/chaiNNer_standard/image/video_frames/load_video.py b/backend/src/packages/chaiNNer_standard/image/video_frames/load_video.py
index fd86f8923..afbfd98f3 100644
--- a/backend/src/packages/chaiNNer_standard/image/video_frames/load_video.py
+++ b/backend/src/packages/chaiNNer_standard/image/video_frames/load_video.py
@@ -13,9 +13,9 @@
 from nodes.properties.outputs import (
     AudioStreamOutput,
     DirectoryOutput,
+    FileNameOutput,
     ImageOutput,
     NumberOutput,
-    TextOutput,
 )

 from nodes.utils.utils import split_file_path
@@ -49,8 +49,8 @@
         NumberOutput("Frame Index", output_type="uint").with_docs(
             "A counter that starts at 0 and increments by 1 for each frame."
         ),
-        DirectoryOutput("Video Directory"),
-        TextOutput("Video Name"),
+        DirectoryOutput("Video Directory", of_input=0),
+        FileNameOutput("Name", of_input=0),
         NumberOutput("FPS"),
         AudioStreamOutput(),
     ],

From a62b03cf68451c35795b7693cd08bec73f6fbe6a Mon Sep 17 00:00:00 2001
From: RunDevelopment
Date: Mon, 23 Oct 2023 18:30:12 +0200
Subject: [PATCH 5/6] Added range iterator

---
 .../chaiNNer_standard/utility/value/range.py | 46 +++++++++++++++++++
 1 file changed, 46 insertions(+)
 create mode 100644 backend/src/packages/chaiNNer_standard/utility/value/range.py

diff --git a/backend/src/packages/chaiNNer_standard/utility/value/range.py b/backend/src/packages/chaiNNer_standard/utility/value/range.py
new file mode 100644
index 000000000..f327f2d3e
--- /dev/null
+++ b/backend/src/packages/chaiNNer_standard/utility/value/range.py
@@ -0,0 +1,46 @@
+from __future__ import annotations
+
+from api import Iterator, IteratorOutputInfo
+from nodes.properties.inputs import BoolInput, NumberInput
+from nodes.properties.outputs import NumberOutput
+
+from .. import value_group
+
+
+@value_group.register(
+    schema_id="chainner:utility:range",
+    name="Range",
+    description="Iterates through all integers in the given range.",
+    icon="MdCalculate",
+    inputs=[
+        NumberInput("Start", default=0, minimum=None, maximum=None),
+        BoolInput("Start Inclusive", default=True),
+        NumberInput("Stop", default=10, minimum=None, maximum=None),
+        BoolInput("Stop Inclusive", default=False),
+    ],
+    outputs=[
+        NumberOutput(
+            "Number",
+            output_type="""
+                let start = if Input1 { Input0 } else { Input0 + 1 };
+                let stop = if Input3 { Input2 } else { Input2 - 1 };
+
+                max(int, start) & min(int, stop)
+            """,
+        ).with_never_reason("The range is empty."),
+    ],
+    iterator_outputs=IteratorOutputInfo(outputs=0),
+    node_type="newIterator",
+)
+def range_node(
+    start: int,
+    start_inclusive: bool,
+    end: int,
+    end_inclusive: bool,
+) -> Iterator[int]:
+    if not start_inclusive:
+        start += 1
+    if end_inclusive:
+        end += 1
+    count = end - start
+    return Iterator.from_range(count, lambda i: start + i)

From 3f8dd01021faffd567794f9c8a705114bc173ed3 Mon Sep 17 00:00:00 2001
From: RunDevelopment
Date: Tue, 24 Oct 2023 17:13:05 +0200
Subject: [PATCH 6/6] Disable type check for collectors and iterators

---
 backend/src/api.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/backend/src/api.py b/backend/src/api.py
index 53acc1fd3..de5994f10 100644
--- a/backend/src/api.py
+++ b/backend/src/api.py
@@ -224,10 +224,11 @@ def inner_wrapper(wrapped_func: T) -> T:
         p_inputs, group_layout = _process_inputs(inputs)
         p_outputs = _process_outputs(outputs)

-        run_check(
-            TYPE_CHECK_LEVEL,
-            lambda _: check_schema_types(wrapped_func, p_inputs, p_outputs),
-        )
+        if node_type == "regularNode":
+            run_check(
+                TYPE_CHECK_LEVEL,
+                lambda _: check_schema_types(wrapped_func, p_inputs, p_outputs),
+            )
         run_check(
             NAME_CHECK_LEVEL,
             lambda fix: check_naming_conventions(wrapped_func, name, fix),
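
The schema type check is now limited to regularNode, presumably because iterator and collector functions return api.Iterator or api.Collector wrappers rather than the plain values their output schemas describe; range_node above, for example, is annotated as Iterator[int] while its schema declares a NumberOutput. For reference, a small self-contained sketch of the values the Range node iterates over; range_values is a made-up helper name for illustration, and the real node wraps the same arithmetic in Iterator.from_range:

def range_values(start: int, start_inclusive: bool, stop: int, stop_inclusive: bool):
    # mirrors range_node: shift the bounds to a half-open [start, stop) interval
    if not start_inclusive:
        start += 1
    if stop_inclusive:
        stop += 1
    return [start + i for i in range(stop - start)]

# default inputs: Start=0 (inclusive), Stop=10 (exclusive) -> 0..9
assert range_values(0, True, 10, False) == list(range(0, 10))
# exclusive start, inclusive stop -> 1..10
assert range_values(0, False, 10, True) == list(range(1, 11))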