From 93ec608591bb0966168c4b6bd2844e6f4894753e Mon Sep 17 00:00:00 2001 From: Grzegorz Rusin Date: Tue, 2 Apr 2024 17:29:30 +0200 Subject: [PATCH 1/4] handle partial --- pyproject.toml | 1 + src/databricks/labs/blueprint/parallel.py | 19 +++++++++++++++++- tests/unit/test_parallel.py | 24 ++++++++++++++++++++++- 3 files changed, 42 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 4709962..ce1d8b0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,6 +49,7 @@ dependencies = [ "mypy", "types-PyYAML", "types-requests", + "black" ] python="3.10" diff --git a/src/databricks/labs/blueprint/parallel.py b/src/databricks/labs/blueprint/parallel.py index e8a350e..21683b7 100644 --- a/src/databricks/labs/blueprint/parallel.py +++ b/src/databricks/labs/blueprint/parallel.py @@ -9,6 +9,7 @@ import threading from collections.abc import Callable, Collection, Sequence from concurrent.futures import ThreadPoolExecutor +from functools import partial from typing import Generic, TypeVar MIN_THREADS = 8 @@ -136,10 +137,26 @@ def _wrap_result(func, name): @functools.wraps(func) def inner(*args, **kwargs): + def _get_signature(f): + if isinstance(f, partial): + try: + args = [] + args.extend(repr(x) for x in f.args) + args.extend(f"{k}={v!r}" for (k, v) in f.keywords.items()) + args_str = ", ".join(args) + if args_str: + return f"{name}({args_str})" + return name + except Exception: # pylint: disable=broad-exception-caught + return str(f) + + return name + try: return func(*args, **kwargs), None except Exception as err: # pylint: disable=broad-exception-caught - logger.error(f"{name} task failed: {err!s}", exc_info=err) + signature = _get_signature(func) + logger.error(f"{signature} task failed: {err!s}", exc_info=err) return None, err return inner diff --git a/tests/unit/test_parallel.py b/tests/unit/test_parallel.py index fb2f44e..006515e 100644 --- a/tests/unit/test_parallel.py +++ b/tests/unit/test_parallel.py @@ -2,7 +2,7 @@ from databricks.sdk.core import DatabricksError -from databricks.labs.blueprint.parallel import Threads +from databricks.labs.blueprint.parallel import Threads, partial def _predictable_messages(caplog): @@ -117,3 +117,25 @@ def works(): assert [True, True, True, True] == results assert 0 == len(errors) assert ["Finished 'testing' tasks: 100% results available (4/4)"] == _predictable_messages(caplog) + + +def test_odd_partial_failed(caplog): + caplog.set_level(logging.INFO) + + def fails_on_odd(n=1): + if n % 2: + msg = "failed" + raise DatabricksError(msg) + + tasks = [partial(fails_on_odd, n=1), partial(fails_on_odd, 1), partial(fails_on_odd), partial(fails_on_odd, n=3)] + results, errors = Threads.gather("testing", tasks) + + assert [] == results + assert 4 == len(errors) + assert [ + "All 'testing' tasks failed!!!", + "testing task failed: failed", + "testing(1) task failed: failed", + "testing(n=1) task failed: failed", + "testing(n=3) task failed: failed", + ] == _predictable_messages(caplog) From bb88635b6a034b92cca8ff6da6032efc6e7038a4 Mon Sep 17 00:00:00 2001 From: Grzegorz Rusin Date: Tue, 2 Apr 2024 18:13:08 +0200 Subject: [PATCH 2/4] remove explicit partial --- src/databricks/labs/blueprint/parallel.py | 3 +-- tests/unit/test_parallel.py | 19 ++++++++++++++----- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/src/databricks/labs/blueprint/parallel.py b/src/databricks/labs/blueprint/parallel.py index 21683b7..aeec9c6 100644 --- a/src/databricks/labs/blueprint/parallel.py +++ b/src/databricks/labs/blueprint/parallel.py @@ -9,7 +9,6 @@ import threading from collections.abc import Callable, Collection, Sequence from concurrent.futures import ThreadPoolExecutor -from functools import partial from typing import Generic, TypeVar MIN_THREADS = 8 @@ -138,7 +137,7 @@ def _wrap_result(func, name): @functools.wraps(func) def inner(*args, **kwargs): def _get_signature(f): - if isinstance(f, partial): + if isinstance(f, functools.partial): try: args = [] args.extend(repr(x) for x in f.args) diff --git a/tests/unit/test_parallel.py b/tests/unit/test_parallel.py index 006515e..a6572ca 100644 --- a/tests/unit/test_parallel.py +++ b/tests/unit/test_parallel.py @@ -1,8 +1,9 @@ import logging +from functools import partial from databricks.sdk.core import DatabricksError -from databricks.labs.blueprint.parallel import Threads, partial +from databricks.labs.blueprint.parallel import Threads def _predictable_messages(caplog): @@ -122,12 +123,20 @@ def works(): def test_odd_partial_failed(caplog): caplog.set_level(logging.INFO) - def fails_on_odd(n=1): + def fails_on_odd(n=1, dummy=None): + if isinstance(n, str): + raise RuntimeError("strings are not supported!") + if n % 2: msg = "failed" raise DatabricksError(msg) - tasks = [partial(fails_on_odd, n=1), partial(fails_on_odd, 1), partial(fails_on_odd), partial(fails_on_odd, n=3)] + tasks = [ + partial(fails_on_odd, n=1), + partial(fails_on_odd, 1, dummy="6"), + partial(fails_on_odd), + partial(fails_on_odd, n="aaa"), + ] results, errors = Threads.gather("testing", tasks) assert [] == results @@ -135,7 +144,7 @@ def fails_on_odd(n=1): assert [ "All 'testing' tasks failed!!!", "testing task failed: failed", - "testing(1) task failed: failed", + "testing(1, dummy='6') task failed: failed", + "testing(n='aaa') task failed: strings are not supported!", "testing(n=1) task failed: failed", - "testing(n=3) task failed: failed", ] == _predictable_messages(caplog) From 2b23ada334501c425d94dc3f171ebac43df7bdc1 Mon Sep 17 00:00:00 2001 From: Grzegorz Rusin Date: Wed, 3 Apr 2024 12:46:59 +0200 Subject: [PATCH 3/4] refactor --- pyproject.toml | 3 +- src/databricks/labs/blueprint/parallel.py | 35 ++++++++++++----------- tests/unit/test_parallel.py | 5 ++++ 3 files changed, 25 insertions(+), 18 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index ce1d8b0..b1c681c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,8 +48,7 @@ dependencies = [ "isort>=2.5.0", "mypy", "types-PyYAML", - "types-requests", - "black" + "types-requests" ] python="3.10" diff --git a/src/databricks/labs/blueprint/parallel.py b/src/databricks/labs/blueprint/parallel.py index aeec9c6..b50fe7a 100644 --- a/src/databricks/labs/blueprint/parallel.py +++ b/src/databricks/labs/blueprint/parallel.py @@ -130,31 +130,34 @@ def _progress_report(self, _): msg = f"{self._name} {self._completed_cnt}/{total_cnt}, rps: {rps:.3f}/sec" logger.info(msg) + @staticmethod + def _get_result_function_signature(func, name): + if isinstance(func, functools.partial): + # try to build up signature, this should never fail + try: + args = [] + args.extend(repr(x) for x in func.args) + args.extend(f"{k}={v!r}" for (k, v) in func.keywords.items()) + args_str = ", ".join(args) + if args_str: + return f"{name}({args_str})" + return name + # but if it would ever fail, better return generic serialized name, than messing up traceback even more... + except Exception: # pylint: disable=broad-exception-caught + return str(func) + + return name + @staticmethod def _wrap_result(func, name): """This method emulates GoLang's error return style""" @functools.wraps(func) def inner(*args, **kwargs): - def _get_signature(f): - if isinstance(f, functools.partial): - try: - args = [] - args.extend(repr(x) for x in f.args) - args.extend(f"{k}={v!r}" for (k, v) in f.keywords.items()) - args_str = ", ".join(args) - if args_str: - return f"{name}({args_str})" - return name - except Exception: # pylint: disable=broad-exception-caught - return str(f) - - return name - try: return func(*args, **kwargs), None except Exception as err: # pylint: disable=broad-exception-caught - signature = _get_signature(func) + signature = Threads._get_result_function_signature(func, name) logger.error(f"{signature} task failed: {err!s}", exc_info=err) return None, err diff --git a/tests/unit/test_parallel.py b/tests/unit/test_parallel.py index a6572ca..a60b7da 100644 --- a/tests/unit/test_parallel.py +++ b/tests/unit/test_parallel.py @@ -137,6 +137,11 @@ def fails_on_odd(n=1, dummy=None): partial(fails_on_odd), partial(fails_on_odd, n="aaa"), ] + + signatures = [Threads._get_result_function_signature(func, "test") for func in tasks] + + assert signatures == ["test(n=1)", "test(1, dummy='6')", "test", "test(n='aaa')"] + results, errors = Threads.gather("testing", tasks) assert [] == results From c26079fad7ba86e4c8f165b9542e1541508f48cb Mon Sep 17 00:00:00 2001 From: Grzegorz Rusin Date: Wed, 3 Apr 2024 13:31:42 +0200 Subject: [PATCH 4/4] better code style --- src/databricks/labs/blueprint/parallel.py | 36 +++++++++++------------ tests/unit/test_parallel.py | 4 --- 2 files changed, 18 insertions(+), 22 deletions(-) diff --git a/src/databricks/labs/blueprint/parallel.py b/src/databricks/labs/blueprint/parallel.py index b50fe7a..a1faa6c 100644 --- a/src/databricks/labs/blueprint/parallel.py +++ b/src/databricks/labs/blueprint/parallel.py @@ -132,24 +132,24 @@ def _progress_report(self, _): @staticmethod def _get_result_function_signature(func, name): - if isinstance(func, functools.partial): - # try to build up signature, this should never fail - try: - args = [] - args.extend(repr(x) for x in func.args) - args.extend(f"{k}={v!r}" for (k, v) in func.keywords.items()) - args_str = ", ".join(args) - if args_str: - return f"{name}({args_str})" - return name - # but if it would ever fail, better return generic serialized name, than messing up traceback even more... - except Exception: # pylint: disable=broad-exception-caught - return str(func) - - return name + if not isinstance(func, functools.partial): + return name + + # try to build up signature, this should never fail + try: + args = [] + args.extend(repr(x) for x in func.args) + args.extend(f"{k}={v!r}" for (k, v) in func.keywords.items()) + args_str = ", ".join(args) + if args_str: + return f"{name}({args_str})" + return name + # but if it would ever fail, better return generic serialized name, than messing up traceback even more... + except Exception: # pylint: disable=broad-exception-caught + return str(func) - @staticmethod - def _wrap_result(func, name): + @classmethod + def _wrap_result(cls, func, name): """This method emulates GoLang's error return style""" @functools.wraps(func) @@ -157,7 +157,7 @@ def inner(*args, **kwargs): try: return func(*args, **kwargs), None except Exception as err: # pylint: disable=broad-exception-caught - signature = Threads._get_result_function_signature(func, name) + signature = cls._get_result_function_signature(func, name) logger.error(f"{signature} task failed: {err!s}", exc_info=err) return None, err diff --git a/tests/unit/test_parallel.py b/tests/unit/test_parallel.py index a60b7da..8ca2c25 100644 --- a/tests/unit/test_parallel.py +++ b/tests/unit/test_parallel.py @@ -138,10 +138,6 @@ def fails_on_odd(n=1, dummy=None): partial(fails_on_odd, n="aaa"), ] - signatures = [Threads._get_result_function_signature(func, "test") for func in tasks] - - assert signatures == ["test(n=1)", "test(1, dummy='6')", "test", "test(n='aaa')"] - results, errors = Threads.gather("testing", tasks) assert [] == results