Skip to content

Commit

Permalink
Trying to get Check running.
Browse files Browse the repository at this point in the history
  • Loading branch information
d-krupke committed Feb 25, 2024
1 parent 75cc777 commit 90039f5
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 13 deletions.
17 changes: 6 additions & 11 deletions src/slurminade/check.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,7 @@

import click

from .conf import update_default_configuration
from .dispatcher import set_dispatcher
from .function import slurmify
from .function_map import set_entry_point
from slurminade import slurmify


@slurmify()
Expand All @@ -29,7 +26,9 @@ def check_slurm(partition, constraint):
Check if the code is running on a slurm node.
"""
# enforce slurm
from .dispatcher import SlurmDispatcher
from slurminade.conf import update_default_configuration
from slurminade.dispatcher import SlurmDispatcher, set_dispatcher
from slurminade.function_map import set_entry_point

set_dispatcher(SlurmDispatcher())
print("Setting entry point to ", __file__)
Expand All @@ -47,9 +46,7 @@ def check_slurm(partition, constraint):
_write_to_file.distribute_and_wait(tmp_file_path, "test")
if not Path(tmp_file_path).exists():
msg = "Slurminade failed: The file was not written to the temporary directory."
raise Exception(
msg
)
raise Exception(msg)
with open(tmp_file_path) as file:
content = file.readlines()
print(
Expand All @@ -67,9 +64,7 @@ def check_slurm(partition, constraint):
time.sleep(1)
if not Path(tmp_file_path).exists():
msg = "Slurminade failed: The file was not written to the temporary directory."
raise Exception(
msg
)
raise Exception(msg)
with open(tmp_file_path) as file:
content = file.readlines()
print(
Expand Down
8 changes: 6 additions & 2 deletions src/slurminade/function_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
"""

import inspect
import logging
import os
import pathlib
import typing
from pathlib import Path
import logging

from .execute_cmds import call_slurminade_to_get_function_ids

Expand Down Expand Up @@ -102,7 +102,11 @@ def check_id(func_id: str) -> bool:
if func_id in FunctionMap._ids:
return True
FunctionMap._ids = call_slurminade_to_get_function_ids(get_entry_point())
logging.getLogger("slurminade").info("Entry point '%s' has functions %s", get_entry_point(), list(FunctionMap._ids))
logging.getLogger("slurminade").info(
"Entry point '%s' has functions %s",
get_entry_point(),
list(FunctionMap._ids),
)
return func_id in FunctionMap._ids

@staticmethod
Expand Down

0 comments on commit 90039f5

Please sign in to comment.