Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add "Defer errors" option to applicable iterators #2442

Merged
Merged 9 commits on Jan 7, 2024
31 changes: 22 additions & 9 deletions backend/src/api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,30 +605,40 @@ def add_package(

@dataclass
class Iterator(Generic[I]):
    """A lazily-evaluated sequence of items used by iterator nodes.

    `iter_supplier` returns a fresh iterable each time it is called. Instead of
    letting a failing item abort the whole iteration, the iterable yields the
    raised `Exception` object in that item's place; consumers check each yielded
    value with `isinstance(value, Exception)`.

    If `defer_errors` is True, consumers are expected to collect those errors
    and raise them only after the iteration has finished, so one bad item does
    not interrupt the rest of the batch.
    """

    iter_supplier: Callable[[], Iterable[I | Exception]]
    # The number of items the iterable is expected to yield (used for progress
    # reporting; the actual iterable may differ if items fail).
    expected_length: int
    defer_errors: bool = False

    @staticmethod
    def from_iter(
        iter_supplier: Callable[[], Iterable[I | Exception]],
        expected_length: int,
        defer_errors: bool = False,
    ) -> Iterator[I]:
        """Creates a new iterator directly from a supplier of iterables."""
        return Iterator(iter_supplier, expected_length, defer_errors=defer_errors)

    @staticmethod
    def from_list(
        l: list[L], map_fn: Callable[[L, int], I], defer_errors: bool = False
    ) -> Iterator[I]:
        """
        Creates a new iterator from a list that is mapped using the given
        function. The iterable will be equivalent to `map(map_fn, l)`.

        If `map_fn` raises for an item, the exception is yielded in that
        item's place instead of propagating.
        """

        def supplier():
            for i, x in enumerate(l):
                try:
                    yield map_fn(x, i)
                except Exception as e:
                    # Yield the error so iteration can continue with the
                    # remaining items; the consumer decides how to handle it.
                    yield e

        return Iterator(supplier, len(l), defer_errors=defer_errors)

    @staticmethod
    def from_range(
        count: int, map_fn: Callable[[int], I], defer_errors: bool = False
    ) -> Iterator[I]:
        """
        Creates a new iterator of the given number of items where each item is
        lazily evaluated. The iterable will be equivalent to
        `map(map_fn, range(count))`.

        If `map_fn` raises for an index, the exception is yielded in that
        item's place instead of propagating.
        """

        def supplier():
            for i in range(count):
                try:
                    yield map_fn(i)
                except Exception as e:
                    # Yield the error so iteration can continue with the
                    # remaining indices; the consumer decides how to handle it.
                    yield e

        return Iterator(supplier, count, defer_errors=defer_errors)


N = TypeVar("N")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from api import Iterator, IteratorOutputInfo
from nodes.impl.ncnn.model import NcnnModelWrapper
from nodes.properties.inputs import DirectoryInput
from nodes.properties.inputs import BoolInput, DirectoryInput
from nodes.properties.outputs import (
DirectoryOutput,
NcnnModelOutput,
Expand All @@ -30,6 +30,10 @@
icon="MdLoop",
inputs=[
DirectoryInput(),
BoolInput("Defer errors", default=True).with_docs(
"Ignore errors that occur during iteration and throw them after processing. Use this if you want to make sure one bad model doesn't interrupt your batch.",
hint=True,
),
],
outputs=[
NcnnModelOutput(),
Expand All @@ -45,6 +49,7 @@
)
def load_models_node(
directory: str,
defer_errors: bool,
) -> tuple[Iterator[tuple[NcnnModelWrapper, str, str, int]], str]:
logger.debug(f"Iterating over models in directory: {directory}")

Expand Down Expand Up @@ -76,4 +81,4 @@ def load_model(filepath_pairs: tuple[str, str], index: int):

model_files = list(zip(param_files, bin_files))

return Iterator.from_list(model_files, load_model), directory
return Iterator.from_list(model_files, load_model, defer_errors), directory
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from api import Iterator, IteratorOutputInfo
from nodes.impl.onnx.model import OnnxModel
from nodes.properties.inputs import DirectoryInput
from nodes.properties.inputs import BoolInput, DirectoryInput
from nodes.properties.outputs import (
DirectoryOutput,
NumberOutput,
Expand All @@ -30,6 +30,10 @@
icon="MdLoop",
inputs=[
DirectoryInput(),
BoolInput("Defer errors", default=True).with_docs(
"Ignore errors that occur during iteration and throw them after processing. Use this if you want to make sure one bad model doesn't interrupt your batch.",
hint=True,
),
],
outputs=[
OnnxModelOutput(),
Expand All @@ -45,6 +49,7 @@
)
def load_models_node(
directory: str,
defer_errors: bool,
) -> tuple[Iterator[tuple[OnnxModel, str, str, int]], str]:
logger.debug(f"Iterating over models in directory: {directory}")

Expand All @@ -57,4 +62,4 @@ def load_model(path: str, index: int):
supported_filetypes = [".onnx"]
model_files = list_all_files_sorted(directory, supported_filetypes)

return Iterator.from_list(model_files, load_model), directory
return Iterator.from_list(model_files, load_model, defer_errors), directory
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from api import Iterator, IteratorOutputInfo
from nodes.properties.inputs import DirectoryInput
from nodes.properties.inputs.generic_inputs import BoolInput
from nodes.properties.outputs import DirectoryOutput, NumberOutput, TextOutput
from nodes.properties.outputs.pytorch_outputs import ModelOutput
from nodes.utils.utils import list_all_files_sorted
Expand All @@ -26,6 +27,10 @@
icon="MdLoop",
inputs=[
DirectoryInput(),
BoolInput("Defer errors", default=True).with_docs(
"Ignore errors that occur during iteration and throw them after processing. Use this if you want to make sure one bad model doesn't interrupt your batch.",
hint=True,
),
],
outputs=[
ModelOutput(),
Expand All @@ -41,6 +46,7 @@
)
def load_models_node(
directory: str,
defer_errors: bool,
) -> tuple[Iterator[tuple[ModelDescriptor, str, str, int]], str]:
logger.debug(f"Iterating over models in directory: {directory}")

Expand All @@ -53,4 +59,4 @@ def load_model(path: str, index: int):
supported_filetypes = [".pt", ".pth", ".ckpt", ".safetensors"]
model_files: list[str] = list_all_files_sorted(directory, supported_filetypes)

return Iterator.from_list(model_files, load_model), directory
return Iterator.from_list(model_files, load_model, defer_errors), directory
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@
"Limit the number of images to iterate over. This can be useful for testing the iterator without having to iterate over all images."
)
),
BoolInput("Defer errors", default=True).with_docs(
"Ignore errors that occur during iteration and throw them after processing. Use this if you want to make sure one bad image doesn't interrupt your batch.",
hint=True,
),
],
outputs=[
ImageOutput("Image A"),
Expand All @@ -56,6 +60,7 @@ def load_image_pairs_node(
directory_b: str,
use_limit: bool,
limit: int,
defer_errors: bool,
) -> tuple[Iterator[tuple[np.ndarray, np.ndarray, str, str, str, str, int]], str, str]:
def load_images(filepaths: tuple[str, str], index: int):
path_a, path_b = filepaths
Expand Down Expand Up @@ -84,4 +89,8 @@ def load_images(filepaths: tuple[str, str], index: int):

image_files = list(zip(image_files_a, image_files_b))

return Iterator.from_list(image_files, load_images), directory_a, directory_b
return (
Iterator.from_list(image_files, load_images, defer_errors),
directory_a,
directory_b,
)
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,10 @@ def list_glob(directory: str, globexpr: str, ext_filter: list[str]) -> list[str]
"Limit the number of images to iterate over. This can be useful for testing the iterator without having to iterate over all images."
)
),
BoolInput("Defer errors", default=True).with_docs(
"Ignore errors that occur during iteration and throw them after processing. Use this if you want to make sure one bad image doesn't interrupt your batch.",
hint=True,
),
],
outputs=[
ImageOutput(),
Expand All @@ -91,6 +95,7 @@ def load_images_node(
glob_str: str,
use_limit: bool,
limit: int,
defer_errors: bool,
) -> tuple[Iterator[tuple[np.ndarray, str, str, int]], str]:
def load_image(path: str, index: int):
img, img_dir, basename = load_image_node(path)
Expand All @@ -110,4 +115,4 @@ def load_image(path: str, index: int):
if use_limit:
just_image_files = just_image_files[:limit]

return Iterator.from_list(just_image_files, load_image), directory
return Iterator.from_list(just_image_files, load_image, defer_errors), directory
58 changes: 38 additions & 20 deletions backend/src/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,32 +570,46 @@ async def update_progress():
# iterate
await self.__send_node_progress(node, times, 0, expected_length)

deferred_errors: list[str] = []
for values in iterator_output.iterator.iter_supplier():
# write current values to cache
iter_output = fill_partial_output(values)
self.cache.set(node.id, iter_output, StaticCaching)
try:
if isinstance(values, Exception):
raise values

# broadcast
await self.__send_node_broadcast(node, iter_output.output)
# write current values to cache
iter_output = fill_partial_output(values)
self.cache.set(node.id, iter_output, StaticCaching)

# run each of the output nodes
for output_node in output_nodes:
await self.process_regular_node(output_node)
# broadcast
await self.__send_node_broadcast(node, iter_output.output)

# run each of the collector nodes
for collector, timer, collector_node in collectors:
await self.progress.suspend()
iterate_inputs = await self.__gather_collector_inputs(collector_node)
await self.progress.suspend()
with timer.run():
run_collector_iterate(collector_node, iterate_inputs, collector)
# run each of the output nodes
for output_node in output_nodes:
await self.process_regular_node(output_node)

# clear cache for next iteration
self.cache.delete_many(all_iterated_nodes)
# run each of the collector nodes
for collector, timer, collector_node in collectors:
await self.progress.suspend()
iterate_inputs = await self.__gather_collector_inputs(
collector_node
)
await self.progress.suspend()
with timer.run():
run_collector_iterate(collector_node, iterate_inputs, collector)

await self.progress.suspend()
await update_progress()
await self.progress.suspend()
# clear cache for next iteration
self.cache.delete_many(all_iterated_nodes)

await self.progress.suspend()
await update_progress()
await self.progress.suspend()
except Aborted:
raise
except Exception as e:
(Review comment by joeyballentine on this line was marked as resolved.)
if iterator_output.iterator.defer_errors:
deferred_errors.append(str(e))
else:
raise e

# reset cached value
self.cache.delete_many(all_iterated_nodes)
Expand Down Expand Up @@ -628,6 +642,10 @@ async def update_progress():
self.cache_strategy[collector_node.id],
)

if len(deferred_errors) > 0:
error_string = "- " + "\n- ".join(deferred_errors)
raise Exception(f"Errors occurred during iteration:\n{error_string}")

async def __process_nodes(self):
await self.__send_chain_start()

Expand Down