Skip to content

Commit

Permalink
Add support for combined iterator lineage (#2949)
Browse files Browse the repository at this point in the history
* rename to generators

* change node kind from newIterator to generator

* change more terminology

* we might be getting somewhere

* fix timing

* comment out validity blocker

* some more fixes

* Split by lineage

* clean up commented code

* more performant function

* Enforce that all connected iterators share the same expected length

* Add migration

* remove load image pairs

* finalize validity rules

* update snapshot

* fix type errors

* gen_supplier -> supplier

* supplier can still use Iterable

* move function, add doc comment

* use typing.Iterator instead of typing.Generator

* move function again, add unit tests

* Add identity functions, use frozen sets

* slight refactor
  • Loading branch information
joeyballentine authored Jun 16, 2024
1 parent 20cd09e commit 410e586
Show file tree
Hide file tree
Showing 29 changed files with 627 additions and 385 deletions.
33 changes: 17 additions & 16 deletions backend/src/api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import importlib
import os
import typing
from dataclasses import asdict, dataclass, field
from typing import (
Any,
Expand Down Expand Up @@ -149,10 +150,10 @@ def to_list(x: list[S] | S | None) -> list[S]:
iterator_inputs = to_list(iterator_inputs)
iterator_outputs = to_list(iterator_outputs)

if kind == "collector":
assert len(iterator_inputs) == 1 and len(iterator_outputs) == 0
elif kind == "newIterator":
if kind == "generator": # Generator
assert len(iterator_inputs) == 0 and len(iterator_outputs) == 1
elif kind == "collector":
assert len(iterator_inputs) == 1 and len(iterator_outputs) == 0
else:
assert len(iterator_inputs) == 0 and len(iterator_outputs) == 0

Expand Down Expand Up @@ -188,8 +189,8 @@ def inner_wrapper(wrapped_func: T) -> T:
inputs=p_inputs,
group_layout=group_layout,
outputs=p_outputs,
iterator_inputs=iterator_inputs,
iterator_outputs=iterator_outputs,
iterable_inputs=iterator_inputs,
iterable_outputs=iterator_outputs,
key_info=key_info,
suggestions=suggestions or [],
side_effects=side_effects,
Expand Down Expand Up @@ -511,25 +512,25 @@ def add_package(


@dataclass
class Iterator(Generic[I]):
iter_supplier: Callable[[], Iterable[I | Exception]]
class Generator(Generic[I]):
supplier: Callable[[], typing.Iterator[I | Exception]]
expected_length: int
fail_fast: bool = True

@staticmethod
def from_iter(
iter_supplier: Callable[[], Iterable[I | Exception]],
supplier: Callable[[], typing.Iterator[I | Exception]],
expected_length: int,
fail_fast: bool = True,
) -> Iterator[I]:
return Iterator(iter_supplier, expected_length, fail_fast=fail_fast)
) -> Generator[I]:
return Generator(supplier, expected_length, fail_fast=fail_fast)

@staticmethod
def from_list(
l: list[L], map_fn: Callable[[L, int], I], fail_fast: bool = True
) -> Iterator[I]:
) -> Generator[I]:
"""
Creates a new iterator from a list that is mapped using the given
Creates a new generator from a list that is mapped using the given
function. The iterable will be equivalent to `map(map_fn, l)`.
"""

Expand All @@ -540,14 +541,14 @@ def supplier():
except Exception as e:
yield e

return Iterator(supplier, len(l), fail_fast=fail_fast)
return Generator(supplier, len(l), fail_fast=fail_fast)

@staticmethod
def from_range(
count: int, map_fn: Callable[[int], I], fail_fast: bool = True
) -> Iterator[I]:
) -> Generator[I]:
"""
Creates a new iterator the given number of items where each item is
Creates a new generator with the given number of items where each item is
lazily evaluated. The iterable will be equivalent to `map(map_fn, range(count))`.
"""
assert count >= 0
Expand All @@ -559,7 +560,7 @@ def supplier():
except Exception as e:
yield e

return Iterator(supplier, count, fail_fast=fail_fast)
return Generator(supplier, count, fail_fast=fail_fast)


N = TypeVar("N")
Expand Down
16 changes: 8 additions & 8 deletions backend/src/api/node_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,8 @@ class NodeData:
outputs: list[BaseOutput]
group_layout: list[InputId | NestedIdGroup]

iterator_inputs: list[IteratorInputInfo]
iterator_outputs: list[IteratorOutputInfo]
iterable_inputs: list[IteratorInputInfo]
iterable_outputs: list[IteratorOutputInfo]

key_info: KeyInfo | None
suggestions: list[SpecialSuggestion]
Expand All @@ -150,11 +150,11 @@ class NodeData:
run: RunFn

@property
def single_iterator_input(self) -> IteratorInputInfo:
assert len(self.iterator_inputs) == 1
return self.iterator_inputs[0]
def single_iterable_input(self) -> IteratorInputInfo:
assert len(self.iterable_inputs) == 1
return self.iterable_inputs[0]

@property
def single_iterator_output(self) -> IteratorOutputInfo:
assert len(self.iterator_outputs) == 1
return self.iterator_outputs[0]
def single_iterable_output(self) -> IteratorOutputInfo:
assert len(self.iterable_outputs) == 1
return self.iterable_outputs[0]
2 changes: 1 addition & 1 deletion backend/src/api/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@

RunFn = Callable[..., Any]

NodeKind = Literal["regularNode", "newIterator", "collector"]
NodeKind = Literal["regularNode", "generator", "collector"]
6 changes: 3 additions & 3 deletions backend/src/chain/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from api import NodeId

from .chain import Chain, Edge, FunctionNode, NewIteratorNode
from .chain import Chain, Edge, FunctionNode, GeneratorNode


class CacheStrategy:
Expand Down Expand Up @@ -54,9 +54,9 @@ def any_are_iterated(out_edges: list[Edge]) -> bool:
else:
# the node is NOT implicitly iterated

if isinstance(node, NewIteratorNode):
if isinstance(node, GeneratorNode):
# we only care about non-iterator outputs
iterator_output = node.data.single_iterator_output
iterator_output = node.data.single_iterable_output
out_edges = [
out_edge
for out_edge in out_edges
Expand Down
18 changes: 9 additions & 9 deletions backend/src/chain/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,12 @@ def has_side_effects(self) -> bool:
return self.data.side_effects


class NewIteratorNode:
class GeneratorNode:
def __init__(self, node_id: NodeId, schema_id: str):
self.id: NodeId = node_id
self.schema_id: str = schema_id
self.data: NodeData = registry.get_node(schema_id)
assert self.data.kind == "newIterator"
assert self.data.kind == "generator"

def has_side_effects(self) -> bool:
return self.data.side_effects
Expand All @@ -50,7 +50,7 @@ def has_side_effects(self) -> bool:
return self.data.side_effects


Node = Union[FunctionNode, NewIteratorNode, CollectorNode]
Node = Union[FunctionNode, GeneratorNode, CollectorNode]


@dataclass(frozen=True)
Expand Down Expand Up @@ -176,28 +176,28 @@ def visit(node_id: NodeId):

return result

def get_parent_iterator_map(self) -> dict[FunctionNode, NewIteratorNode | None]:
def get_parent_iterator_map(self) -> dict[FunctionNode, GeneratorNode | None]:
"""
Returns a map of all function nodes to their parent generator node (if any).
"""
iterator_cache: dict[FunctionNode, NewIteratorNode | None] = {}
iterator_cache: dict[FunctionNode, GeneratorNode | None] = {}

def get_iterator(r: FunctionNode) -> NewIteratorNode | None:
def get_iterator(r: FunctionNode) -> GeneratorNode | None:
if r in iterator_cache:
return iterator_cache[r]

iterator: NewIteratorNode | None = None
iterator: GeneratorNode | None = None

for in_edge in self.edges_to(r.id):
source = self.nodes[in_edge.source.id]
if isinstance(source, FunctionNode):
iterator = get_iterator(source)
if iterator is not None:
break
elif isinstance(source, NewIteratorNode):
elif isinstance(source, GeneratorNode):
if (
in_edge.source.output_id
in source.data.single_iterator_output.outputs
in source.data.single_iterable_output.outputs
):
iterator = source
break
Expand Down
6 changes: 3 additions & 3 deletions backend/src/chain/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
EdgeSource,
EdgeTarget,
FunctionNode,
NewIteratorNode,
GeneratorNode,
)


Expand Down Expand Up @@ -53,8 +53,8 @@ def parse_json(json: list[JsonNode]) -> Chain:
index_edges: list[IndexEdge] = []

for json_node in json:
if json_node["nodeType"] == "newIterator":
node = NewIteratorNode(json_node["id"], json_node["schemaId"])
if json_node["nodeType"] == "generator":
node = GeneratorNode(json_node["id"], json_node["schemaId"])
elif json_node["nodeType"] == "collector":
node = CollectorNode(json_node["id"], json_node["schemaId"])
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from sanic.log import logger

from api import Iterator, IteratorOutputInfo
from api import Generator, IteratorOutputInfo
from nodes.impl.ncnn.model import NcnnModelWrapper
from nodes.properties.inputs import BoolInput, DirectoryInput
from nodes.properties.outputs import (
Expand Down Expand Up @@ -46,12 +46,12 @@
),
],
iterator_outputs=IteratorOutputInfo(outputs=[0, 2, 3, 4]),
kind="newIterator",
kind="generator",
)
def load_models_node(
directory: Path,
fail_fast: bool,
) -> tuple[Iterator[tuple[NcnnModelWrapper, str, str, int]], Path]:
) -> tuple[Generator[tuple[NcnnModelWrapper, str, str, int]], Path]:
logger.debug(f"Iterating over models in directory: {directory}")

def load_model(filepath_pairs: tuple[Path, Path], index: int):
Expand Down Expand Up @@ -82,4 +82,4 @@ def load_model(filepath_pairs: tuple[Path, Path], index: int):

model_files = list(zip(param_files, bin_files))

return Iterator.from_list(model_files, load_model, fail_fast), directory
return Generator.from_list(model_files, load_model, fail_fast), directory
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from sanic.log import logger

from api import Iterator, IteratorOutputInfo
from api import Generator, IteratorOutputInfo
from nodes.impl.onnx.model import OnnxModel
from nodes.properties.inputs import BoolInput, DirectoryInput
from nodes.properties.outputs import (
Expand Down Expand Up @@ -46,12 +46,12 @@
),
],
iterator_outputs=IteratorOutputInfo(outputs=[0, 2, 3, 4]),
kind="newIterator",
kind="generator",
)
def load_models_node(
directory: Path,
fail_fast: bool,
) -> tuple[Iterator[tuple[OnnxModel, str, str, int]], Path]:
) -> tuple[Generator[tuple[OnnxModel, str, str, int]], Path]:
logger.debug(f"Iterating over models in directory: {directory}")

def load_model(path: Path, index: int):
Expand All @@ -63,4 +63,4 @@ def load_model(path: Path, index: int):
supported_filetypes = [".onnx"]
model_files = list_all_files_sorted(directory, supported_filetypes)

return Iterator.from_list(model_files, load_model, fail_fast), directory
return Generator.from_list(model_files, load_model, fail_fast), directory
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from sanic.log import logger
from spandrel import ModelDescriptor

from api import Iterator, IteratorOutputInfo, NodeContext
from api import Generator, IteratorOutputInfo, NodeContext
from nodes.properties.inputs import DirectoryInput
from nodes.properties.inputs.generic_inputs import BoolInput
from nodes.properties.outputs import DirectoryOutput, NumberOutput, TextOutput
Expand Down Expand Up @@ -43,14 +43,14 @@
),
],
iterator_outputs=IteratorOutputInfo(outputs=[0, 2, 3, 4]),
kind="newIterator",
kind="generator",
node_context=True,
)
def load_models_node(
context: NodeContext,
directory: Path,
fail_fast: bool,
) -> tuple[Iterator[tuple[ModelDescriptor, str, str, int]], Path]:
) -> tuple[Generator[tuple[ModelDescriptor, str, str, int]], Path]:
logger.debug(f"Iterating over models in directory: {directory}")

def load_model(path: Path, index: int):
Expand All @@ -62,4 +62,4 @@ def load_model(path: Path, index: int):
supported_filetypes = [".pt", ".pth", ".ckpt", ".safetensors"]
model_files: list[Path] = list_all_files_sorted(directory, supported_filetypes)

return Iterator.from_list(model_files, load_model, fail_fast), directory
return Generator.from_list(model_files, load_model, fail_fast), directory
Loading

0 comments on commit 410e586

Please sign in to comment.