diff --git a/src/slurminade/__init__.py b/src/slurminade/__init__.py index 784beb9..53695d3 100644 --- a/src/slurminade/__init__.py +++ b/src/slurminade/__init__.py @@ -53,6 +53,12 @@ def clean_up(): - guard.py: Contains code to prevent you accidentally DDoSing your infrastructure. - options.py: Contains a simple data structure to save slurm options. """ +# set default logging +import logging +import sys + +# Set up the root logger to print to stdout by default +logging.basicConfig(level=logging.INFO, stream=sys.stdout) # flake8: noqa F401 from .function import slurmify, shell @@ -97,10 +103,3 @@ def clean_up(): "shell", "node_setup", ] - -# set default logging -import logging -import sys - -# Set up the root logger to print to stdout by default -logging.basicConfig(level=logging.INFO, stream=sys.stdout) diff --git a/src/slurminade/conf.py b/src/slurminade/conf.py index 479d78d..1a63f43 100644 --- a/src/slurminade/conf.py +++ b/src/slurminade/conf.py @@ -3,6 +3,7 @@ """ import json +import logging import os.path import typing from pathlib import Path @@ -20,7 +21,9 @@ def _load_conf(path: Path): else: return {} except Exception as e: - print(f"slurminade could not open default configuration {path}!\n{e!s}") + logging.getLogger("slurminade").error( + f"slurminade could not open default configuration {path}!\n{e!s}" + ) return {} diff --git a/src/slurminade/dispatcher.py b/src/slurminade/dispatcher.py index 0b43dc8..8aa0a27 100644 --- a/src/slurminade/dispatcher.py +++ b/src/slurminade/dispatcher.py @@ -21,13 +21,12 @@ from .function_call import FunctionCall from .function_map import FunctionMap, get_entry_point from .guard import dispatch_guard +from .job_reference import JobReference from .options import SlurmOptions # MAX_ARG_STRLEN on a Linux system with PAGE_SIZE 4096 is 131072 DEFAULT_MAX_ARG_LENGTH = 100000 -from .job_reference import JobReference - class Dispatcher(abc.ABC): """ diff --git a/src/slurminade/function.py b/src/slurminade/function.py index 941776d..bee029a 100644 --- a/src/slurminade/function.py +++ b/src/slurminade/function.py @@ -118,13 +118,12 @@ def __call__(self, *args, **kwargs): """ if self.call_policy == CallPolicy.LOCALLY: return self.run_locally(*args, **kwargs) - elif self.call_policy == CallPolicy.DISTRIBUTED: + if self.call_policy == CallPolicy.DISTRIBUTED: return self.distribute(*args, **kwargs) - elif self.call_policy == CallPolicy.DISTRIBUTED_BLOCKING: + if self.call_policy == CallPolicy.DISTRIBUTED_BLOCKING: return self.distribute_and_wait(*args, **kwargs) - else: - msg = "Unknown call policy." - raise RuntimeError(msg) + msg = "Unknown call policy." + raise RuntimeError(msg) def get_entry_point(self) -> Path: """ diff --git a/src/slurminade/function_map.py b/src/slurminade/function_map.py index a5ff94b..3868add 100644 --- a/src/slurminade/function_map.py +++ b/src/slurminade/function_map.py @@ -23,8 +23,8 @@ class FunctionMap: # The slurm node just executes the file content of a script, so the file name is lost. # slurminade will set this value in the beginning to reconstruct it. entry_point: typing.Optional[str] = None - _data = {} - _ids = set() + _data: typing.ClassVar[typing.Dict[str, typing.Callable]] = {} + _ids: typing.ClassVar[typing.Set[str]] = set() @staticmethod def get_id(func: typing.Callable) -> str: diff --git a/src/slurminade/options.py b/src/slurminade/options.py index b22fb3e..ee1ca70 100644 --- a/src/slurminade/options.py +++ b/src/slurminade/options.py @@ -11,8 +11,9 @@ class SlurmOptions(dict): def _items(self): for k, v in self.items(): if isinstance(v, dict): - v = SlurmOptions(**v) - yield k, v + yield k, SlurmOptions(**v) + else: + yield k, v def __hash__(self): return hash(tuple(sorted(hash((k, v)) for k, v in self._items())))