Skip to content

Commit

Permalink
CHORES: Lots of cleaning up
Browse files Browse the repository at this point in the history
  • Loading branch information
d-krupke committed Aug 3, 2024
1 parent 50c506e commit 1bfe812
Show file tree
Hide file tree
Showing 14 changed files with 148 additions and 149 deletions.
9 changes: 6 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,14 @@ repos:
args: [--prose-wrap=always]
exclude: "^tests"

- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.0.269
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.5.1
hooks:
- id: ruff
args: ["--fix", "--show-fixes"]
types_or: [python, pyi, jupyter]
#args: [--fix,--ignore, T201, --ignore, E402]
- id: ruff-format
types_or: [python, pyi, jupyter]

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.3.0
Expand Down
28 changes: 14 additions & 14 deletions src/slurminade/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,25 +55,25 @@ def clean_up():
"""

# flake8: noqa F401
from .function import slurmify, shell
from .conf import update_default_configuration, set_default_configuration
from .guard import (
set_dispatch_limit,
allow_recursive_distribution,
disable_warning_on_repeated_flushes,
)
from .bundling import JobBundling, Batch
from .bundling import Batch, JobBundling
from .conf import set_default_configuration, update_default_configuration
from .dispatcher import (
srun,
sbatch,
join,
SlurmDispatcher,
set_dispatcher,
get_dispatcher,
TestDispatcher,
SubprocessDispatcher,
TestDispatcher,
get_dispatcher,
join,
sbatch,
set_dispatcher,
srun,
)
from .function import shell, slurmify
from .function_map import set_entry_point
from .guard import (
allow_recursive_distribution,
disable_warning_on_repeated_flushes,
set_dispatch_limit,
)
from .node_setup import node_setup

__all__ = [
Expand Down
1 change: 1 addition & 0 deletions src/slurminade/bundling.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Contains code for bundling function calls together.
"""

import logging
import typing
from collections import defaultdict
Expand Down
8 changes: 4 additions & 4 deletions src/slurminade/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
This allows to change the behaviour of the distribution, e.g., we use it for batch:
Batch simply wraps the dispatcher by a buffered version.
"""

import abc
import logging
import os
Expand All @@ -21,13 +22,12 @@
from .function_call import FunctionCall
from .function_map import FunctionMap, get_entry_point
from .guard import dispatch_guard
from .job_reference import JobReference
from .options import SlurmOptions

# MAX_ARG_STRLEN on a Linux system with PAGE_SIZE 4096 is 131072
DEFAULT_MAX_ARG_LENGTH = 100000

from .job_reference import JobReference


class Dispatcher(abc.ABC):
"""
Expand Down Expand Up @@ -443,7 +443,7 @@ def get_dispatcher() -> Dispatcher:
to allow compatibility.
:return: The dispatcher.
"""
global __dispatcher
global __dispatcher # noqa: PLW0603
if __dispatcher is None:
try:
__dispatcher = SlurmDispatcher()
Expand All @@ -460,7 +460,7 @@ def set_dispatcher(dispatcher: Dispatcher) -> None:
:param dispatcher: The dispatcher to be used.
:return: None
"""
global __dispatcher
global __dispatcher # noqa: PLW0603
__dispatcher = dispatcher
assert dispatcher == get_dispatcher()

Expand Down
9 changes: 4 additions & 5 deletions src/slurminade/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,13 +118,12 @@ def __call__(self, *args, **kwargs):
"""
if self.call_policy == CallPolicy.LOCALLY:
return self.run_locally(*args, **kwargs)
elif self.call_policy == CallPolicy.DISTRIBUTED:
if self.call_policy == CallPolicy.DISTRIBUTED:
return self.distribute(*args, **kwargs)
elif self.call_policy == CallPolicy.DISTRIBUTED_BLOCKING:
if self.call_policy == CallPolicy.DISTRIBUTED_BLOCKING:
return self.distribute_and_wait(*args, **kwargs)
else:
msg = "Unknown call policy."
raise RuntimeError(msg)
msg = "Unknown call policy."
raise RuntimeError(msg)

def get_entry_point(self) -> Path:
"""
Expand Down
9 changes: 5 additions & 4 deletions src/slurminade/function_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
import logging
import pathlib
import typing
from typing import Optional
from pathlib import Path
from typing import Optional

from .execute_cmds import call_slurminade_to_get_function_ids

Expand All @@ -24,8 +24,8 @@ class FunctionMap:
# The slurm node just executes the file content of a script, so the file name is lost.
# slurminade will set this value in the beginning to reconstruct it.
entry_point: typing.Optional[str] = None
_data = {}
_ids: Optional[set] = set()
_data: typing.ClassVar[dict] = {}
_ids: typing.ClassVar[Optional[set]] = set()

@staticmethod
def get_id(func: typing.Callable) -> str:
Expand Down Expand Up @@ -113,7 +113,8 @@ def check_id(func_id: str, entry_point: Path) -> bool:
FunctionMap._ids = call_slurminade_to_get_function_ids(entry_point)
except Exception as e:
logging.getLogger("slurminade").warning(
"Cannot verify function ids before submitting to slurm: %s. This is not critical, things will just be more difficult to debug in case you make an error.", e
"Cannot verify function ids before submitting to slurm: %s. This is not critical, things will just be more difficult to debug in case you make an error.",
e,
)
FunctionMap._ids = None
return True
Expand Down
1 change: 1 addition & 0 deletions src/slurminade/guard.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
``allow_recursive_distribution``, ``set_dispatch_limit(None)``, and
``disable_warning_for_multiple_flushes``.
"""

import logging
import typing

Expand Down
31 changes: 14 additions & 17 deletions tests/test_create_command.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import unittest
from pathlib import Path

import slurminade
Expand All @@ -8,22 +7,20 @@

@slurminade.slurmify()
def f(s):
with open(test_file_path, "w") as file:
with test_file_path.open("w") as file:
file.write(s)


class TestCreateCommand(unittest.TestCase):
def test_dispatch_with_temp_file(self):
slurminade.set_entry_point(__file__)
if test_file_path.exists():
test_file_path.unlink()
dispatcher = slurminade.SubprocessDispatcher()
dispatcher.max_arg_length = 1
slurminade.set_dispatcher(dispatcher)
s = "test"
f.distribute(s)
assert test_file_path.is_file()
with open(test_file_path) as file:
assert file.readline() == s
if test_file_path.exists(): # delete the file
test_file_path.unlink()
def test_create_command():
slurminade.set_entry_point(__file__)
if test_file_path.exists():
test_file_path.unlink()
dispatcher = slurminade.SubprocessDispatcher()
dispatcher.max_arg_length = 1
slurminade.set_dispatcher(dispatcher)
s = "test"
f.distribute(s)
assert test_file_path.is_file()
with test_file_path.open() as file:
assert file.readline() == s
test_file_path.unlink(missing_ok=True)
32 changes: 15 additions & 17 deletions tests/test_create_command_with_noise.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,27 @@
import unittest
from pathlib import Path

import slurminade

test_file_path = Path("./f_test_file.txt")
print("NOISE123")


@slurminade.slurmify()
def f(s):
with open(test_file_path, "w") as file:
with test_file_path.open("w") as file:
file.write(s)


class TestCreateCommand(unittest.TestCase):
def test_dispatch_with_temp_file(self):
slurminade.set_entry_point(__file__)
if test_file_path.exists():
test_file_path.unlink()
dispatcher = slurminade.SubprocessDispatcher()
dispatcher.max_arg_length = 1
slurminade.set_dispatcher(dispatcher)
s = "test"
f.distribute(s)
assert test_file_path.is_file()
with open(test_file_path) as file:
assert file.readline() == s
if test_file_path.exists(): # delete the file
test_file_path.unlink()
def test_create_command_with_noise():
slurminade.set_entry_point(__file__)
if test_file_path.exists():
test_file_path.unlink()
dispatcher = slurminade.SubprocessDispatcher()
dispatcher.max_arg_length = 1
slurminade.set_dispatcher(dispatcher)
s = "test"
f.distribute(s)
assert test_file_path.is_file()
with test_file_path.open() as file:
assert file.readline() == s
test_file_path.unlink(missing_ok=True)
46 changes: 25 additions & 21 deletions tests/test_dispatch_guard.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import unittest
import pytest

import slurminade
from slurminade.guard import TooManyDispatchesError, _DispatchGuard, set_dispatch_limit
Expand All @@ -9,27 +9,31 @@ def f():
pass


class TestDispatchGuard(unittest.TestCase):
def test_simple(self):
slurminade.set_entry_point(__file__)
dg = _DispatchGuard(3)
def test_dispatch_guard_simple():
slurminade.set_entry_point(__file__)
dg = _DispatchGuard(3)
dg()
dg()
dg()
with pytest.raises(TooManyDispatchesError):
dg()
dg()
dg()
self.assertRaises(TooManyDispatchesError, dg)

def test_dispatch_limit(self):
slurminade.set_entry_point(__file__)
set_dispatch_limit(3)
f.distribute()

def test_dispatch_guard_dispatch_limit():
slurminade.set_entry_point(__file__)
set_dispatch_limit(3)
f.distribute()
f.distribute()
f.distribute()
with pytest.raises(TooManyDispatchesError):
f.distribute()


def test_dispatch_guard_dispatch_limit_batch():
slurminade.set_entry_point(__file__)
set_dispatch_limit(2)
with slurminade.JobBundling(max_size=2):
for _ in range(4):
f.distribute()
with pytest.raises(TooManyDispatchesError):
f.distribute()
self.assertRaises(TooManyDispatchesError, f.distribute)

def test_dispatch_limit_batch(self):
slurminade.set_entry_point(__file__)
set_dispatch_limit(2)
with slurminade.JobBundling(max_size=2):
for _ in range(4):
f.distribute()
self.assertRaises(TooManyDispatchesError, f.distribute)
46 changes: 22 additions & 24 deletions tests/test_local.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os.path
import unittest
from pathlib import Path

import slurminade
from slurminade.function import SlurmFunction
Expand All @@ -10,37 +9,36 @@

@slurminade.slurmify()
def f():
with open(f_file, "w") as file:
with Path(f_file).open("w") as file:
file.write("test")


def delete_f():
if os.path.exists(f_file):
os.remove(f_file)
Path(f_file).unlink(missing_ok=True)


@slurminade.slurmify()
def g(x, y):
with open(g_file, "w") as file:
with Path(g_file).open("w") as file:
file.write(f"{x}:{y}")


def delete_g():
if os.path.exists(g_file):
os.remove(g_file)


class TestLocal(unittest.TestCase):
def test_1(self):
delete_f()
SlurmFunction.call(f.func_id)
assert os.path.exists(f_file)
delete_f()

def test_2(self):
delete_g()
SlurmFunction.call(g.func_id, x="a", y=2)
assert os.path.exists(g_file)
with open(g_file) as file:
assert file.readline() == "a:2"
delete_g()
Path(g_file).unlink(missing_ok=True)


def test_local_1():
delete_f()
SlurmFunction.call(f.func_id)
assert Path(f_file).exists()
delete_f()


def test_local_2():
delete_g()
SlurmFunction.call(g.func_id, x="a", y=2)

assert Path(g_file).exists()
with Path(g_file).open() as file:
assert file.readline() == "a:2"
delete_g()
4 changes: 2 additions & 2 deletions tests/test_local_function.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from pytest import raises
import pytest

import slurminade

Expand All @@ -10,5 +10,5 @@ def f():

slurminade.set_entry_point(__file__)

with raises(KeyError):
with pytest.raises(KeyError):
f.distribute()
Loading

0 comments on commit 1bfe812

Please sign in to comment.