diff --git a/README.rst b/README.rst index 305b114..c458e55 100644 --- a/README.rst +++ b/README.rst @@ -59,6 +59,10 @@ A simple script could look like this: # compatible with any environment. # You can enforce slurm with `slurminade.set_dispatcher(slurminade.SlurmDispatcher())` + @slurminade.node_setup + def setup(): + print("I will run automatically on every slurm node at the beginning!") + # use this decorator to make a function distributable with slurm @slurminade.slurmify( @@ -353,6 +357,7 @@ The project is reasonably easy: Changes ------- +- 0.9.0: Lots of improvements. - 0.8.1: Bugfix and automatic detection of wrong usage when using ``Batch`` with ``wait_for``. - 0.8.0: Added extensive logging and improved typing. - 0.7.0: Warning if a Batch is flushed multiple times, as we noticed this to be a common indentation error. diff --git a/pyproject.toml b/pyproject.toml index f7ff4f8..5d91ae0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ where = ["src"] [project] name = "slurminade" -dynamic = ["version"] +version = "1.0.0" authors = [ { name = "TU Braunschweig, IBR, Algorithms Group (Dominik Krupke)", email = "krupke@ibr.cs.tu-bs.de" }, ] @@ -21,14 +21,14 @@ classifiers = [ ] keywords=["slurm"] dependencies = [ - "simple_slurm>=0.2.6" + "simple_slurm>=0.2.6", + "click", ] [project.urls] Homepage = "https://github.com/d-krupke/slurminade" Issues = "https://github.com/d-krupke/slurminade/issues" -[tool.setuptools_scm] [tool.pytest.ini_options] minversion = "6.0" diff --git a/src/slurminade/__init__.py b/src/slurminade/__init__.py index 528c64d..f247b1a 100644 --- a/src/slurminade/__init__.py +++ b/src/slurminade/__init__.py @@ -73,6 +73,7 @@ def clean_up(): SubprocessDispatcher, ) from .function_map import set_entry_point +from .node_setup import node_setup __all__ = [ "slurmify", @@ -90,10 +91,12 @@ def clean_up(): "TestDispatcher", "SubprocessDispatcher", "set_entry_point", + "node_setup", ] # set default logging import logging +import sys -logging.getLogger("slurminade").setLevel(logging.INFO) -logging.getLogger("slurminade").addHandler(logging.StreamHandler()) +# Set up the root logger to print to stdout by default +logging.basicConfig(level=logging.INFO, stream=sys.stdout) diff --git a/src/slurminade/batch.py b/src/slurminade/batch.py index a4e5049..778f2d2 100644 --- a/src/slurminade/batch.py +++ b/src/slurminade/batch.py @@ -134,7 +134,7 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): if exc_type: - print("Aborted due to exception.") + logging.getLogger("slurminade").error("Aborted due to exception.") return self.flush() set_dispatcher(self.subdispatcher) diff --git a/src/slurminade/dispatcher.py b/src/slurminade/dispatcher.py index 159650e..e7ff5bb 100644 --- a/src/slurminade/dispatcher.py +++ b/src/slurminade/dispatcher.py @@ -367,19 +367,23 @@ def create_slurminade_command( :param max_arg_length: The maximum allowed length of a command line argument. :returns: A string representing the command to be executed in the terminal. """ - command = f"{sys.executable} -m slurminade.execute {shlex.quote(get_entry_point())}" + command = f"{sys.executable} -m slurminade.execute --root {shlex.quote(get_entry_point())}" # Serialize function calls as JSON - serialized_calls = json.dumps([f.to_json() for f in funcs]) + json_calls = json.dumps([f.to_json() for f in funcs]) + serialized_calls = shlex.quote(json_calls) - if len(shlex.quote(serialized_calls)) > max_arg_length: + if len(serialized_calls) > max_arg_length: # The argument is too long, create temporary file for the JSON fd, filename = mkstemp(prefix="slurminade_", suffix=".json", text=True, dir=".") + logging.getLogger("slurminade").info( + f"Long function calls. Serializing function calls to temporary file {filename}" + ) with os.fdopen(fd, "w") as f: - f.write(serialized_calls) - command += f" temp {shlex.quote(filename)}" + f.write(json_calls) + command += f" --fromfile {filename}" else: - command += f" arg {shlex.quote(serialized_calls)}" + command += f" --calls {serialized_calls}" return command @@ -427,6 +431,11 @@ def dispatch( :param options: The slurm options to be used. :return: The job id. """ + funcs = list(funcs) if not isinstance(funcs, FunctionCall) else [funcs] + for func in funcs: + if not FunctionMap.check_id(func.func_id): + msg = f"Function '{func.func_id}' cannot be called from the given entry point." + raise KeyError(msg) return get_dispatcher()(funcs, options) diff --git a/src/slurminade/execute.py b/src/slurminade/execute.py index e9b60db..58a26d4 100644 --- a/src/slurminade/execute.py +++ b/src/slurminade/execute.py @@ -2,63 +2,69 @@ This module provides the starting point for the slurm node. You do not have to call anything of this file yourself. """ + import json -import pathlib -import sys +import logging +from pathlib import Path + +import click -from .function import SlurmFunction +from .function import FunctionMap, SlurmFunction from .function_map import set_entry_point from .guard import prevent_distribution -def parse_args(): - batch_file_path = pathlib.Path( - sys.argv[1] - ) # the file with the code (function definition) - if not batch_file_path.exists(): - msg = "Batch file does not exist.\n" - msg += f" File: {batch_file_path}\n" - msg += "This should not happen. Please report this bug." - raise RuntimeError(msg) - # determine whether function calls are provided as an argument or in a temp file. - mode = sys.argv[2] - if mode == "arg": - function_calls = json.loads(sys.argv[3]) - elif mode == "temp": - tmp_file_path = pathlib.Path(sys.argv[3]) - if not tmp_file_path.exists(): - msg = "Using temporary file for passing function arguments, but file does not exist.\n" - msg += f" File: {tmp_file_path}\n" - msg += "This should not happen. Please report this bug." - raise RuntimeError(msg) - with open(tmp_file_path) as f: - function_calls = json.load(f) - tmp_file_path.unlink() # delete the temp file - else: - msg = "Unknown function call mode. Expected 'arg' or 'temp'.\n" - msg += f" Got: {mode}\n" - msg += "This should not happen. Please report this bug." - raise RuntimeError(msg) - assert isinstance(function_calls, list), "Expected a list of dicts" - return batch_file_path, function_calls - - -def main(): +@click.command() +@click.option( + "--root", + type=click.Path(exists=True), + help="The root file of the task.", + required=True, +) +@click.option("--calls", type=str, help="The function calls.", required=False) +@click.option( + "--fromfile", + type=click.Path(exists=True), + help="The file to read the function calls from.", + required=False, +) +@click.option( + "--listfuncs", + help="List all available functions.", + default=False, + is_flag=True, + required=False, +) +def main(root, calls, fromfile, listfuncs): prevent_distribution() # make sure, the code on the node does not distribute itself. - batch_file, function_calls = parse_args() - set_entry_point(batch_file) - with open(batch_file) as f: + set_entry_point(root) + with open(root) as f: code = "".join(f.readlines()) - # Workaround as otherwise __name__ is not defined - global __name__ - __name__ = None - glob = dict(globals()) - glob["__file__"] = batch_file + glob["__file__"] = root + glob["__name__"] = None exec(code, glob) + if listfuncs: + print(json.dumps(FunctionMap.get_all_ids())) # noqa T201 + return + if calls: + function_calls = json.loads(calls) + elif fromfile: + with open(fromfile) as f: + logging.getLogger("slurminade").info( + f"Reading function calls from {fromfile}." + ) + function_calls = json.load(f) + Path(fromfile).unlink() + else: + msg = "No function calls provided." + raise ValueError(msg) + if not isinstance(function_calls, list): + msg = "Expected a list of function calls." + raise ValueError(msg) # Execute the functions for fc in function_calls: SlurmFunction.call(fc["func_id"], *fc.get("args", []), **fc.get("kwargs", {})) diff --git a/src/slurminade/function_map.py b/src/slurminade/function_map.py index 3995bb6..b230b19 100644 --- a/src/slurminade/function_map.py +++ b/src/slurminade/function_map.py @@ -4,8 +4,11 @@ """ import inspect +import json import os import pathlib +import subprocess +import sys import typing @@ -21,6 +24,7 @@ class FunctionMap: # slurminade will set this value in the beginning to reconstruct it. entry_point: typing.Optional[str] = None _data = {} + _ids = set() @staticmethod def get_id(func: typing.Callable) -> str: @@ -92,6 +96,27 @@ def call( raise KeyError(msg) return FunctionMap._data[func_id](*args, **kwargs) + @staticmethod + def check_id(func_id: str) -> bool: + if func_id in FunctionMap._ids: + return True + cmd = [ + sys.executable, + "-m", + "slurminade.execute", + "--root", + get_entry_point(), + "--listfuncs", + ] + out = subprocess.check_output(cmd).decode() + ids = json.loads(out) + FunctionMap._ids = set(ids) + return func_id in FunctionMap._ids + + @staticmethod + def get_all_ids() -> typing.List[str]: + return list(FunctionMap._data.keys()) + def set_entry_point(entry_point: typing.Union[str, pathlib.Path]) -> None: """ diff --git a/src/slurminade/guard.py b/src/slurminade/guard.py index 55d6249..5077491 100644 --- a/src/slurminade/guard.py +++ b/src/slurminade/guard.py @@ -15,10 +15,12 @@ _exec_flag = False +def on_slurm_node(): + global _exec_flag + return _exec_flag def guard_recursive_distribution(): - global _exec_flag - if _exec_flag: + if on_slurm_node(): msg = """ You tried to distribute a task recursively. This is not allowed by default, because it probably indicates a bug in your code. To save you from accidentally diff --git a/src/slurminade/node_setup.py b/src/slurminade/node_setup.py new file mode 100644 index 0000000..1bed61e --- /dev/null +++ b/src/slurminade/node_setup.py @@ -0,0 +1,17 @@ +import inspect +import typing +from .guard import on_slurm_node + +def node_setup(func: typing.Callable): + """ + Decorator: Call this function on the node before running any function calls. + """ + if on_slurm_node(): + func() + else: + # check if the function has no arguments + sig = inspect.signature(func) + if sig.parameters: + msg = "The node setup function must not have any arguments." + raise ValueError(msg) + return func diff --git a/tests/test_create_command.py b/tests/test_create_command.py index 5a3028e..6004e98 100644 --- a/tests/test_create_command.py +++ b/tests/test_create_command.py @@ -1,9 +1,7 @@ -import shlex import unittest from pathlib import Path import slurminade -from slurminade.dispatcher import FunctionCall, create_slurminade_command test_file_path = Path("./f_test_file.txt") @@ -15,25 +13,6 @@ def f(s): class TestCreateCommand(unittest.TestCase): - def test_create_long_command(self): - slurminade.set_entry_point(__file__) - test_call = FunctionCall(f.func_id, ["." * 100], {}) - command = create_slurminade_command([test_call], 100) - args = shlex.split(command) - path = Path(args[-1]) - assert args[-2] == "temp" - # check creation of temporary file - assert Path(path).is_file() - if path.exists(): # delete the file - path.unlink() - - def test_create_short_command(self): - slurminade.set_entry_point(__file__) - test_call = FunctionCall(f.func_id, [""], {}) - command = create_slurminade_command([test_call], 100000) - args = shlex.split(command) - assert args[-2] == "arg" - def test_dispatch_with_temp_file(self): slurminade.set_entry_point(__file__) if test_file_path.exists(): diff --git a/tests/test_local_function.py b/tests/test_local_function.py new file mode 100644 index 0000000..53bede5 --- /dev/null +++ b/tests/test_local_function.py @@ -0,0 +1,13 @@ + +import slurminade +from pytest import raises + +def test_dispatch_limit_batch(): + @slurminade.slurmify() + def f(): + pass + + slurminade.set_entry_point(__file__) + + with raises(KeyError): + f.distribute() diff --git a/tests/test_node_setup.py b/tests/test_node_setup.py new file mode 100644 index 0000000..63f58c0 --- /dev/null +++ b/tests/test_node_setup.py @@ -0,0 +1,26 @@ +import slurminade +from pathlib import Path + +test_file_path = Path("./f_test_file.txt") + +@slurminade.node_setup +def f(): + with open(test_file_path, "w") as file: + file.write('node_setup') + +@slurminade.slurmify +def nil(): + pass + +def test_node_setup(): + slurminade.set_entry_point(__file__) + if test_file_path.exists(): + test_file_path.unlink() + dispatcher = slurminade.SubprocessDispatcher() + slurminade.set_dispatcher(dispatcher) + slurminade.set_dispatch_limit(100) + nil.distribute() + with open(test_file_path) as file: + assert file.readline() == 'node_setup' + if test_file_path.exists(): # delete the file + test_file_path.unlink() \ No newline at end of file