Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove code pertaining to old iterators #2267

Merged
merged 9 commits into from
Oct 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions backend/src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,15 +174,11 @@ def inner_wrapper(wrapped_func: T) -> T:

run_check(
TYPE_CHECK_LEVEL,
lambda _: check_schema_types(
wrapped_func, node_type, p_inputs, p_outputs
),
lambda _: check_schema_types(wrapped_func, p_inputs, p_outputs),
)
run_check(
NAME_CHECK_LEVEL,
lambda fix: check_naming_conventions(
wrapped_func, node_type, name, fix
),
lambda fix: check_naming_conventions(wrapped_func, name, fix),
)

if decorators is not None:
Expand Down
12 changes: 2 additions & 10 deletions backend/src/chain/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,8 @@ def get_cache_strategies(chain: Chain) -> Dict[NodeId, CacheStrategy]:

for node in chain.nodes.values():
out_edges = chain.edges_from(node.id)
connected_to_child_node = any(
chain.nodes[e.target.id].parent for e in out_edges
)

strategy: CacheStrategy
if node.parent is None and connected_to_child_node:
# free nodes that are connected to child nodes need to live as the execution
strategy = StaticCaching
else:
strategy = CacheStrategy(len(out_edges))

strategy: CacheStrategy = CacheStrategy(len(out_edges))

result[node.id] = strategy

Expand Down
35 changes: 1 addition & 34 deletions backend/src/chain/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@ class FunctionNode:
def __init__(self, node_id: NodeId, schema_id: str):
self.id: NodeId = node_id
self.schema_id: str = schema_id
self.parent: Union[NodeId, None] = None
self.is_helper: bool = False

def get_node(self) -> NodeData:
return registry.get_node(self.schema_id)
Expand All @@ -29,21 +27,6 @@ def has_side_effects(self) -> bool:
return self.get_node().side_effects


class IteratorNode:
def __init__(self, node_id: NodeId, schema_id: str):
self.id: NodeId = node_id
self.schema_id: str = schema_id
self.parent: None = None
self.__node = None

def get_node(self) -> NodeData:
if self.__node is None:
node = registry.get_node(self.schema_id)
assert node.type == "iterator", "Invalid iterator node"
self.__node = node
return self.__node


class NewIteratorNode:
def __init__(self, node_id: NodeId, schema_id: str):
self.id: NodeId = node_id
Expand Down Expand Up @@ -82,7 +65,7 @@ def has_side_effects(self) -> bool:
return self.get_node().side_effects


Node = Union[FunctionNode, IteratorNode, NewIteratorNode, CollectorNode]
Node = Union[FunctionNode, NewIteratorNode, CollectorNode]


class EdgeSource:
Expand Down Expand Up @@ -137,19 +120,3 @@ def remove_node(self, node_id: NodeId):
self.__edges_by_target[e.target.id].remove(e)
for e in self.__edges_by_target.pop(node_id, []):
self.__edges_by_source[e.source.id].remove(e)

if isinstance(node, IteratorNode):
# remove all child nodes
for n in list(self.nodes.values()):
if n.parent == node_id:
self.remove_node(n.id)


class SubChain:
def __init__(self, chain: Chain, iterator_id: NodeId):
self.nodes: Dict[NodeId, FunctionNode] = {}
self.iterator_id = iterator_id

for node in chain.nodes.values():
if node.parent is not None and node.parent == iterator_id:
self.nodes[node.id] = node
7 changes: 1 addition & 6 deletions backend/src/chain/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
EdgeSource,
EdgeTarget,
FunctionNode,
IteratorNode,
NewIteratorNode,
)
from .input import EdgeInput, Input, InputMap, ValueInput
Expand Down Expand Up @@ -54,16 +53,12 @@ def parse_json(json: List[JsonNode]) -> Tuple[Chain, InputMap]:
index_edges: List[IndexEdge] = []

for json_node in json:
if json_node["nodeType"] == "iterator":
node = IteratorNode(json_node["id"], json_node["schemaId"])
elif json_node["nodeType"] == "newIterator":
if json_node["nodeType"] == "newIterator":
node = NewIteratorNode(json_node["id"], json_node["schemaId"])
elif json_node["nodeType"] == "collector":
node = CollectorNode(json_node["id"], json_node["schemaId"])
else:
node = FunctionNode(json_node["id"], json_node["schemaId"])
node.parent = json_node["parent"]
node.is_helper = json_node["nodeType"] == "iteratorHelper"
chain.add_node(node)

inputs: List[Input] = []
Expand Down
35 changes: 2 additions & 33 deletions backend/src/chain/optimize.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,12 @@
from sanic.log import logger

from .chain import Chain, EdgeSource, IteratorNode, Node
from .chain import Chain, Node


def __has_side_effects(node: Node) -> bool:
if isinstance(node, IteratorNode) or node.is_helper:
# we assume that both iterators and their helper nodes always have side effects
return True
return node.has_side_effects()


def __outline_child_nodes(chain: Chain) -> bool:
"""
If a child node of an iterator is not downstream of any iterator helper node,
then this child node can be lifted out of the iterator (outlined) to be a free node.
"""
changed = False

for node in chain.nodes.values():
# we try to outline child nodes that are not iterator helper nodes
if node.parent is not None and not node.is_helper:

def has_no_parent(source: EdgeSource) -> bool:
n = chain.nodes.get(source.id)
assert n is not None
return n.parent is None

# we can only outline if all of its inputs are independent of the iterator
can_outline = all(has_no_parent(n.source) for n in chain.edges_to(node.id))
if can_outline:
node.parent = None
changed = True
logger.debug(
f"Chain optimization: Outlined {node.schema_id} node {node.id}"
)

return changed


def __removed_dead_nodes(chain: Chain) -> bool:
"""
If a node does not have side effects and has no downstream nodes, then it can be removed.
Expand All @@ -57,4 +26,4 @@ def __removed_dead_nodes(chain: Chain) -> bool:
def optimize(chain: Chain):
changed = True
while changed:
changed = __removed_dead_nodes(chain) or __outline_child_nodes(chain)
changed = __removed_dead_nodes(chain)
4 changes: 1 addition & 3 deletions backend/src/custom_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@

RunFn = Callable[..., Any]

NodeType = Literal[
"regularNode", "iterator", "iteratorHelper", "newIterator", "collector"
]
NodeType = Literal["regularNode", "newIterator", "collector"]

UpdateProgressFn = Callable[[str, float, Union[float, None]], Awaitable[None]]
17 changes: 1 addition & 16 deletions backend/src/events.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import asyncio
from typing import Dict, List, Literal, Optional, TypedDict, Union
from typing import Dict, Literal, Optional, TypedDict, Union

from base_types import InputId, NodeId, OutputId
from nodes.base_input import ErrorValue
Expand Down Expand Up @@ -36,15 +36,6 @@ class NodeStartData(TypedDict):
nodeId: NodeId


class IteratorProgressUpdateData(TypedDict):
percent: float
index: int
total: int
eta: float
iteratorId: NodeId
running: Optional[List[NodeId]]


class NodeProgressUpdateData(TypedDict):
percent: float
index: int
Expand Down Expand Up @@ -79,11 +70,6 @@ class NodeStartEvent(TypedDict):
data: NodeStartData


class IteratorProgressUpdateEvent(TypedDict):
event: Literal["iterator-progress-update"]
data: IteratorProgressUpdateData


class NodeProgressUpdateEvent(TypedDict):
event: Literal["node-progress-update"]
data: NodeProgressUpdateData
Expand All @@ -104,7 +90,6 @@ class BackendStateEvent(TypedDict):
ExecutionErrorEvent,
NodeFinishEvent,
NodeStartEvent,
IteratorProgressUpdateEvent,
NodeProgressUpdateEvent,
BackendStatusEvent,
BackendStateEvent,
Expand Down
22 changes: 1 addition & 21 deletions backend/src/node_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from enum import Enum
from typing import Any, Callable, Dict, List, NewType, Set, Union, cast, get_args

from custom_types import NodeType
from nodes.base_input import BaseInput
from nodes.base_output import BaseOutput

Expand Down Expand Up @@ -174,7 +173,6 @@ def validate_return_type(return_type: _Ty, outputs: list[BaseOutput]):

def check_schema_types(
wrapped_func: Callable,
node_type: NodeType,
inputs: list[BaseInput],
outputs: list[BaseOutput],
):
Expand All @@ -195,19 +193,6 @@ def check_schema_types(
if not arg in ann:
raise CheckFailedError(f"Missing type annotation for '{arg}'")

if node_type == "iteratorHelper":
# iterator helpers have inputs that do not describe the arguments of the function, so we can't check them
return

if node_type == "iterator":
# the last argument of an iterator is the iterator context, so we have to account for that
context = [*ann.keys()][-1]
context_type = ann.pop(context)
if str(context_type) != "<class 'process.IteratorContext'>":
raise CheckFailedError(
f"Last argument of an iterator must be an IteratorContext, not '{context_type}'"
)

if arg_spec.varargs is not None:
if not arg_spec.varargs in ann:
raise CheckFailedError(f"Missing type annotation for '{arg_spec.varargs}'")
Expand Down Expand Up @@ -255,23 +240,18 @@ def check_schema_types(

def check_naming_conventions(
wrapped_func: Callable,
node_type: NodeType,
name: str,
fix: bool,
):
expected_name = (
name.lower()
.replace(" (iterator)", "")
.replace(" ", "_")
.replace("-", "_")
.replace("(", "")
.replace(")", "")
.replace("&", "and")
)

if node_type == "iteratorHelper":
expected_name = "iterator_helper_" + expected_name

func_name = wrapped_func.__name__
file_path = pathlib.Path(inspect.getfile(wrapped_func))
file_name = file_path.stem
Expand All @@ -289,7 +269,7 @@ def check_naming_conventions(
file_path.write_text(fixed_code, encoding="utf-8")

# check file name
if node_type != "iteratorHelper" and file_name != expected_name:
if file_name != expected_name:
if not fix:
raise CheckFailedError(
f"File name is '{file_name}.py', but it should be '{expected_name}.py'"
Expand Down
5 changes: 0 additions & 5 deletions backend/src/nodes/properties/inputs/generic_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,11 +408,6 @@ def make_optional(self):
raise ValueError("ColorInput cannot be made optional")


def IteratorInput():
"""Input for showing that an iterator automatically handles the input"""
return BaseInput("IteratorAuto", "Auto (Iterator)", has_handle=False)


class VideoContainer(Enum):
MKV = "mkv"
MP4 = "mp4"
Expand Down
2 changes: 1 addition & 1 deletion backend/src/packages/chaiNNer_ncnn/ncnn/io/load_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
FileNameOutput("Name", of_input=0).with_id(1),
],
see_also=[
"chainner:ncnn:model_file_iterator",
"chainner:ncnn:load_models",
],
)
def load_model_node(
Expand Down
2 changes: 1 addition & 1 deletion backend/src/packages/chaiNNer_onnx/onnx/io/load_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
FileNameOutput("Name", of_input=0).with_id(1),
],
see_also=[
"chainner:onnx:model_file_iterator",
"chainner:onnx:load_models",
],
)
def load_model_node(path: str) -> Tuple[OnnxModel, str, str]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def parse_ckpt_state_dict(checkpoint: dict):
FileNameOutput("Name", of_input=0).with_id(1),
],
see_also=[
"chainner:pytorch:model_file_iterator",
"chainner:pytorch:load_models",
],
)
def load_model_node(path: str) -> Tuple[PyTorchModel, str, str]:
Expand Down
Loading