Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mixed iterator and non-iterator inputs/outputs #2280

Merged
merged 6 commits into from
Oct 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 75 additions & 4 deletions backend/src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from sanic.log import logger

import navi
from base_types import InputId, OutputId
from custom_types import NodeType, RunFn
from node_check import (
Expand Down Expand Up @@ -80,6 +81,34 @@ class DefaultNode(TypedDict):
schemaId: str


class IteratorInputInfo:
def __init__(
self,
inputs: int | InputId | List[int] | List[InputId] | List[int | InputId],
length_type: navi.ExpressionJson = "uint",
) -> None:
self.inputs: List[InputId] = (
[InputId(x) for x in inputs]
if isinstance(inputs, list)
else [InputId(inputs)]
)
self.length_type: navi.ExpressionJson = length_type


class IteratorOutputInfo:
def __init__(
self,
outputs: int | OutputId | List[int] | List[OutputId] | List[int | OutputId],
length_type: navi.ExpressionJson = "uint",
) -> None:
self.outputs: List[OutputId] = (
[OutputId(x) for x in outputs]
if isinstance(outputs, list)
else [OutputId(outputs)]
)
self.length_type: navi.ExpressionJson = length_type


@dataclass(frozen=True)
class NodeData:
schema_id: str
Expand All @@ -93,12 +122,25 @@ class NodeData:
outputs: List[BaseOutput]
group_layout: List[InputId | NestedIdGroup]

iterator_inputs: List[IteratorInputInfo]
iterator_outputs: List[IteratorOutputInfo]

side_effects: bool
deprecated: bool
features: List[FeatureId]

run: RunFn

@property
def single_iterator_input(self) -> IteratorInputInfo:
assert len(self.iterator_inputs) == 1
return self.iterator_inputs[0]

@property
def single_iterator_output(self) -> IteratorOutputInfo:
assert len(self.iterator_outputs) == 1
return self.iterator_outputs[0]


T = TypeVar("T", bound=RunFn)
S = TypeVar("S")
Expand Down Expand Up @@ -129,6 +171,8 @@ def register(
see_also: List[str] | str | None = None,
features: List[FeatureId] | FeatureId | None = None,
limited_to_8bpc: bool | str = False,
iterator_inputs: List[IteratorInputInfo] | IteratorInputInfo | None = None,
iterator_outputs: List[IteratorOutputInfo] | IteratorOutputInfo | None = None,
):
if not isinstance(description, str):
description = "\n\n".join(description)
Expand All @@ -153,6 +197,16 @@ def to_list(x: List[S] | S | None) -> List[S]:
see_also = to_list(see_also)
features = to_list(features)

iterator_inputs = to_list(iterator_inputs)
iterator_outputs = to_list(iterator_outputs)

if node_type == "collector":
assert len(iterator_inputs) == 1 and len(iterator_outputs) == 0
elif node_type == "newIterator":
assert len(iterator_inputs) == 0 and len(iterator_outputs) == 1
else:
assert len(iterator_inputs) == 0 and len(iterator_outputs) == 0

def run_check(level: CheckLevel, run: Callable[[bool], None]):
if level == CheckLevel.NONE:
return
Expand All @@ -170,10 +224,11 @@ def inner_wrapper(wrapped_func: T) -> T:
p_inputs, group_layout = _process_inputs(inputs)
p_outputs = _process_outputs(outputs)

run_check(
TYPE_CHECK_LEVEL,
lambda _: check_schema_types(wrapped_func, p_inputs, p_outputs),
)
if node_type == "regularNode":
run_check(
TYPE_CHECK_LEVEL,
lambda _: check_schema_types(wrapped_func, p_inputs, p_outputs),
)
run_check(
NAME_CHECK_LEVEL,
lambda fix: check_naming_conventions(wrapped_func, name, fix),
Expand All @@ -193,6 +248,8 @@ def inner_wrapper(wrapped_func: T) -> T:
inputs=p_inputs,
group_layout=group_layout,
outputs=p_outputs,
iterator_inputs=iterator_inputs,
iterator_outputs=iterator_outputs,
side_effects=side_effects,
deprecated=deprecated,
features=features,
Expand Down Expand Up @@ -556,6 +613,20 @@ def supplier():

return Iterator(supplier, len(l))

@staticmethod
def from_range(count: int, map_fn: Callable[[int], I]) -> "Iterator[I]":
"""
Creates a new iterator the given number of items where each item is
lazily evaluated. The iterable will be equivalent to `map(map_fn, range(count))`.
"""
assert count >= 0

def supplier():
for i in range(count):
yield map_fn(i)

return Iterator(supplier, count)


N = TypeVar("N")
R = TypeVar("R")
Expand Down
39 changes: 35 additions & 4 deletions backend/src/chain/cache.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from __future__ import annotations

import gc
from typing import Dict, Generic, Iterable, Optional, Set, TypeVar
from typing import Dict, Generic, Iterable, List, Optional, Set, TypeVar

from sanic.log import logger

from .chain import Chain, NodeId
from .chain import Chain, Edge, FunctionNode, NewIteratorNode, NodeId


class CacheStrategy:
Expand All @@ -29,12 +31,41 @@ def no_caching(self) -> bool:
def get_cache_strategies(chain: Chain) -> Dict[NodeId, CacheStrategy]:
"""Create a map with the cache strategies for all nodes in the given chain."""

iterator_map = chain.get_parent_iterator_map()

def any_are_iterated(out_edges: List[Edge]) -> bool:
for out_edge in out_edges:
target = chain.nodes[out_edge.target.id]
if isinstance(target, FunctionNode) and iterator_map[target] is not None:
return True
return False

result: Dict[NodeId, CacheStrategy] = {}

for node in chain.nodes.values():
out_edges = chain.edges_from(node.id)
strategy: CacheStrategy

strategy: CacheStrategy = CacheStrategy(len(out_edges))
out_edges = chain.edges_from(node.id)
if isinstance(node, FunctionNode) and iterator_map[node] is not None:
# the function node is iterated
strategy = CacheStrategy(len(out_edges))
else:
# the node is NOT implicitly iterated

if isinstance(node, NewIteratorNode):
# we only care about non-iterator outputs
iterator_output = node.data.single_iterator_output
out_edges = [
out_edge
for out_edge in out_edges
if out_edge.source.output_id not in iterator_output.outputs
]

if any_are_iterated(out_edges):
# some output is used by an iterated node
strategy = StaticCaching
else:
strategy = CacheStrategy(len(out_edges))

result[node.id] = strategy

Expand Down
96 changes: 69 additions & 27 deletions backend/src/chain/chain.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import Callable, Dict, List, TypeVar, Union
from __future__ import annotations

from typing import Callable, Dict, List, Set, TypeVar, Union

from api import NodeData, registry
from base_types import InputId, NodeId, OutputId
Expand All @@ -19,50 +21,33 @@ class FunctionNode:
def __init__(self, node_id: NodeId, schema_id: str):
self.id: NodeId = node_id
self.schema_id: str = schema_id

def get_node(self) -> NodeData:
return registry.get_node(self.schema_id)
self.data: NodeData = registry.get_node(schema_id)
assert self.data.type == "regularNode"

def has_side_effects(self) -> bool:
return self.get_node().side_effects
return self.data.side_effects


class NewIteratorNode:
def __init__(self, node_id: NodeId, schema_id: str):
self.id: NodeId = node_id
self.schema_id: str = schema_id
self.parent: None = None
self.__node = None
self.is_helper: bool = False

def get_node(self) -> NodeData:
if self.__node is None:
node = registry.get_node(self.schema_id)
assert node.type == "newIterator", "Invalid iterator node"
self.__node = node
return self.__node
self.data: NodeData = registry.get_node(schema_id)
assert self.data.type == "newIterator"

def has_side_effects(self) -> bool:
return self.get_node().side_effects
return self.data.side_effects


class CollectorNode:
def __init__(self, node_id: NodeId, schema_id: str):
self.id: NodeId = node_id
self.schema_id: str = schema_id
self.parent: None = None
self.__node = None
self.is_helper: bool = False

def get_node(self) -> NodeData:
if self.__node is None:
node = registry.get_node(self.schema_id)
assert node.type == "collector", "Invalid iterator node"
self.__node = node
return self.__node
self.data: NodeData = registry.get_node(schema_id)
assert self.data.type == "collector"

def has_side_effects(self) -> bool:
return self.get_node().side_effects
return self.data.side_effects


Node = Union[FunctionNode, NewIteratorNode, CollectorNode]
Expand Down Expand Up @@ -120,3 +105,60 @@ def remove_node(self, node_id: NodeId):
self.__edges_by_target[e.target.id].remove(e)
for e in self.__edges_by_target.pop(node_id, []):
self.__edges_by_source[e.source.id].remove(e)

def topological_order(self) -> List[NodeId]:
"""
Returns all nodes in topological order.
"""
result: List[NodeId] = []
visited: Set[NodeId] = set()

def visit(node_id: NodeId):
if node_id in visited:
return
visited.add(node_id)

for e in self.edges_from(node_id):
visit(e.target.id)

result.append(node_id)

for node_id in self.nodes:
visit(node_id)

return result

def get_parent_iterator_map(self) -> Dict[FunctionNode, NewIteratorNode | None]:
"""
Returns a map of all function nodes to their parent iterator node (if any).
"""
iterator_cache: Dict[FunctionNode, NewIteratorNode | None] = {}

def get_iterator(r: FunctionNode) -> NewIteratorNode | None:
if r in iterator_cache:
return iterator_cache[r]

iterator: NewIteratorNode | None = None

for in_edge in self.edges_to(r.id):
source = self.nodes[in_edge.source.id]
if isinstance(source, FunctionNode):
iterator = get_iterator(source)
if iterator is not None:
break
elif isinstance(source, NewIteratorNode):
if (
in_edge.source.output_id
in source.data.single_iterator_output.outputs
):
iterator = source
break

iterator_cache[r] = iterator
return iterator

for node in self.nodes.values():
if isinstance(node, FunctionNode):
get_iterator(node)

return iterator_cache
4 changes: 2 additions & 2 deletions backend/src/chain/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ def parse_json(json: List[JsonNode]) -> Tuple[Chain, InputMap]:
input_map.set(node.id, inputs)

for index_edge in index_edges:
source_node = chain.nodes[index_edge.from_id].get_node()
target_node = chain.nodes[index_edge.to_id].get_node()
source_node = chain.nodes[index_edge.from_id].data
target_node = chain.nodes[index_edge.to_id].data

chain.add_edge(
Edge(
Expand Down
8 changes: 2 additions & 6 deletions backend/src/chain/optimize.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
from sanic.log import logger

from .chain import Chain, Node


def __has_side_effects(node: Node) -> bool:
return node.has_side_effects()
from .chain import Chain


def __removed_dead_nodes(chain: Chain) -> bool:
Expand All @@ -14,7 +10,7 @@ def __removed_dead_nodes(chain: Chain) -> bool:
changed = False

for node in list(chain.nodes.values()):
is_dead = len(chain.edges_from(node.id)) == 0 and not __has_side_effects(node)
is_dead = len(chain.edges_from(node.id)) == 0 and not node.has_side_effects()
if is_dead:
chain.remove_node(node.id)
changed = True
Expand Down
5 changes: 5 additions & 0 deletions backend/src/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from base_types import InputId, NodeId, OutputId
from nodes.base_input import ErrorValue

# Data of events


class FinishData(TypedDict):
message: str
Expand Down Expand Up @@ -50,6 +52,9 @@ class BackendStatusData(TypedDict):
statusProgress: Optional[float]


# Events


class FinishEvent(TypedDict):
event: Literal["finish"]
data: FinishData
Expand Down
Loading