ASP-2256 Migrate the SBATCH param mapper to jobbergate-api #193

Merged · 11 commits · Dec 7, 2022

jobbergate-api/jobbergate_api/apps/job_submissions/properties_parser.py (new file, 320 additions)

"""
Parse SBATCH parameters from a job script and map them to Slurm REST API job properties.
"""
from argparse import ArgumentParser
from dataclasses import dataclass, field
from itertools import chain
from typing import Any, Dict, Iterator, List, Sequence, Union, cast

from bidict import bidict
from loguru import logger

from jobbergate_api.apps.job_scripts.job_script_files import JobScriptFiles
from jobbergate_api.apps.job_submissions.schemas import JobProperties

_IDENTIFICATION_FLAG = "#SBATCH"
_INLINE_COMMENT_MARK = "#"


def _flagged_line(line: str) -> bool:
"""
Identify if a provided line starts with the identification flag.
"""
return line.startswith(_IDENTIFICATION_FLAG)


def _clean_line(line: str) -> str:
    """
    Clean the provided line.

    This removes the identification flag at the beginning of the line, drops
    the inline comment mark and anything after it, and strips whitespace from
    both ends.
    """
    # Slice off the flag instead of using str.lstrip, which strips a *set* of
    # characters and could eat into the text that follows the flag.
    return line[len(_IDENTIFICATION_FLAG):].split(_INLINE_COMMENT_MARK)[0].strip()


def _split_line(line: str) -> List[str]:
"""
    Split the provided line at equal signs and whitespace characters.

This procedure is important because it is the way argparse expects to receive the parameters.
"""
return line.replace("=", " ").split()


def _clean_jobscript(jobscript: str) -> Iterator[str]:
    """
    Transform a job script string into a flat iterator of SBATCH tokens.

    This filters only the lines that start with the identification flag,
    cleans each of them to remove the flag, inline comments, and extra
    whitespace, then splits each parameter/value pair and chains the results
    into a single iterator.
    """
    jobscript_filtered = filter(_flagged_line, jobscript.splitlines())
    jobscript_cleaned = map(_clean_line, jobscript_filtered)
    jobscript_split = map(_split_line, jobscript_cleaned)
    return chain.from_iterable(jobscript_split)
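
# A minimal sketch of the transformation (the job script content below is
# hypothetical, shown only for illustration):
#
#     >>> list(_clean_jobscript("#!/bin/bash\n#SBATCH --nodes=2\n#SBATCH -J demo  # job name"))
#     ['--nodes', '2', '-J', 'demo']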


@dataclass(frozen=True)
class SbatchToSlurm:
"""
    Store the information for each parameter, including its name at the Slurm API and at SBATCH.

    It also stores any extra argument the parameter needs when added to the
    parser. This information is used to build the jobscript/SBATCH parser
    and the two-way mapping between Slurm API and SBATCH names.
"""

slurmrestd_var_name: str
sbatch: str
sbatch_short: str = ""
argparser_param: dict = field(default_factory=dict)


sbatch_to_slurm = [
SbatchToSlurm("account", "--account", "-A"),
SbatchToSlurm("account_gather_freqency", "--acctg-freq"),
SbatchToSlurm("array", "--array", "-a"),
SbatchToSlurm("batch_features", "--batch"),
SbatchToSlurm("burst_buffer", "--bb"),
SbatchToSlurm("", "--bbf"),
SbatchToSlurm("begin_time", "--begin", "-b"),
SbatchToSlurm("current_working_directory", "--chdir", "-D"),
SbatchToSlurm("cluster_constraints", "--cluster-constraint"),
SbatchToSlurm("", "--clusters", "-M"),
SbatchToSlurm("comment", "--comment"),
SbatchToSlurm("constraints", "--constraint", "-C"),
SbatchToSlurm("", "--container"),
SbatchToSlurm("", "--contiguous", "", dict(action="store_const", const=True)),
SbatchToSlurm("core_specification", "--core-spec", "-S", dict(type=int)),
SbatchToSlurm("cores_per_socket", "--cores-per-socket", "", dict(type=int)),
SbatchToSlurm("cpu_binding", "--cpu-bind"),
SbatchToSlurm("cpu_frequency", "--cpu-freq"),
SbatchToSlurm("cpus_per_gpu", "--cpus-per-gpu"),
SbatchToSlurm("cpus_per_task", "--cpus-per-task", "-c", dict(type=int)),
SbatchToSlurm("deadline", "--deadline"),
SbatchToSlurm("delay_boot", "--delay-boot", "", dict(type=int)),
SbatchToSlurm("dependency", "--dependency", "-d"),
SbatchToSlurm("distribution", "--distribution", "-m"),
SbatchToSlurm("standard_error", "--error", "-e"),
SbatchToSlurm("", "--exclude", "-x"),
SbatchToSlurm(
"exclusive",
"--exclusive",
"",
dict(
type=str,
choices={"user", "mcs", "exclusive", "oversubscribe"},
nargs="?",
const="exclusive",
),
),
SbatchToSlurm("", "--export"),
SbatchToSlurm("", "--export-file"),
SbatchToSlurm("", "--extra-node-info", "-B"),
SbatchToSlurm("get_user_environment", "--get-user-env", "", dict(type=int)),
SbatchToSlurm("", "--gid"),
SbatchToSlurm("gpu_binding", "--gpu-bind"),
SbatchToSlurm("gpu_frequency", "--gpu-freq"),
SbatchToSlurm("gpus", "--gpus", "-G"),
SbatchToSlurm("gpus_per_node", "--gpus-per-node"),
SbatchToSlurm("gpus_per_socket", "--gpus-per-socket"),
SbatchToSlurm("gpus_per_task", "--gpus-per-task"),
SbatchToSlurm("gres", "--gres"),
SbatchToSlurm("gres_flags", "--gres-flags"),
SbatchToSlurm("", "--hint"),
SbatchToSlurm("hold", "--hold", "-H", dict(action="store_const", const=True)),
SbatchToSlurm("", "--ignore-pbs", "", dict(action="store_const", const=True)),
SbatchToSlurm("standard_input", "--input", "-i"),
SbatchToSlurm("name", "--job-name", "-J"),
# kill_on_invalid_dependency is an invalid key for Slurm API according to our tests
# SbatchToSlurm(
# "kill_on_invalid_dependency", "--kill-on-invalid-dep", "", dict(type=int)
# ),
SbatchToSlurm("licenses", "--licenses", "-L"),
SbatchToSlurm("mail_type", "--mail-type"),
SbatchToSlurm("mail_user", "--mail-user"),
SbatchToSlurm("mcs_label", "--mcs-label"),
SbatchToSlurm("memory_per_node", "--mem"),
SbatchToSlurm("memory_binding", "--mem-bind"),
SbatchToSlurm("memory_per_cpu", "--mem-per-cpu"),
SbatchToSlurm("memory_per_gpu", "--mem-per-gpu"),
SbatchToSlurm("minimum_cpus_per_node", "--mincpus", "", dict(type=int)),
SbatchToSlurm("", "--network"),
SbatchToSlurm("nice", "--nice"),
SbatchToSlurm("no_kill", "--no-kill", "-k", dict(action="store_const", const=True)),
SbatchToSlurm("", "--no-requeue", "", dict(action="store_false", dest="requeue")),
SbatchToSlurm("", "--nodefile", "-F"),
SbatchToSlurm("", "--nodelist", "-w"),
SbatchToSlurm("nodes", "--nodes", "-N"),
SbatchToSlurm("tasks", "--ntasks", "-n", dict(type=int)),
SbatchToSlurm("tasks_per_core", "--ntasks-per-core", "", dict(type=int)),
SbatchToSlurm("", "--ntasks-per-gpu"),
SbatchToSlurm("tasks_per_node", "--ntasks-per-node", "", dict(type=int)),
SbatchToSlurm("tasks_per_socket", "--ntasks-per-socket", "", dict(type=int)),
SbatchToSlurm("open_mode", "--open-mode"),
SbatchToSlurm("standard_output", "--output", "-o"),
SbatchToSlurm("", "--overcommit", "-O", dict(action="store_const", const=True)),
SbatchToSlurm(
"",
"--oversubscribe",
"-s",
dict(action="store_const", const="oversubscribe", dest="exclusive"),
),
SbatchToSlurm("", "--parsable", "", dict(action="store_const", const=True)),
SbatchToSlurm("partition", "--partition", "-p"),
SbatchToSlurm("", "--power"),
SbatchToSlurm("priority", "--priority"),
SbatchToSlurm("", "--profile"),
SbatchToSlurm("", "--propagate"),
SbatchToSlurm("qos", "--qos", "-q"),
SbatchToSlurm("", "--quiet", "-Q", dict(action="store_const", const=True)),
SbatchToSlurm("", "--reboot", "", dict(action="store_const", const=True)),
SbatchToSlurm("requeue", "--requeue", "", dict(action="store_const", const=True)),
SbatchToSlurm("reservation", "--reservation"),
SbatchToSlurm("signal", "--signal"),
SbatchToSlurm("sockets_per_node", "--sockets-per-node", "", dict(type=int)),
SbatchToSlurm("spread_job", "--spread-job", "", dict(action="store_const", const=True)),
SbatchToSlurm("", "--switches"),
SbatchToSlurm("", "--test-only", "", dict(action="store_const", const=True)),
SbatchToSlurm("thread_specification", "--thread-spec", "", dict(type=int)),
SbatchToSlurm("threads_per_core", "--threads-per-core", "", dict(type=int)),
SbatchToSlurm("time_limit", "--time", "-t"),
SbatchToSlurm("time_minimum", "--time-min"),
SbatchToSlurm("", "--tmp"),
SbatchToSlurm("", "--uid"),
SbatchToSlurm("", "--usage", "", dict(action="store_const", const=True)),
SbatchToSlurm("minimum_nodes", "--use-min-nodes", "", dict(action="store_const", const=True)),
SbatchToSlurm("", "--verbose", "-v", dict(action="store_const", const=True)),
SbatchToSlurm("", "--version", "-V", dict(action="store_const", const=True)),
SbatchToSlurm("", "--wait", "-W", dict(action="store_const", const=True)),
SbatchToSlurm("wait_all_nodes", "--wait-all-nodes", "", dict(type=int)),
SbatchToSlurm("wckey", "--wckey"),
SbatchToSlurm("", "--wrap"),
]


class ArgumentParserCustomExit(ArgumentParser):
"""
Custom implementation of the built-in class for argument parsing.

    The sys.exit triggered by the original implementation is replaced by a
    ValueError, along with friendly logging messages.
"""

def exit(self, status=0, message=None):
"""
        Raise ValueError when invalid parameters are parsed or when their values have the wrong type.
"""
log_message = f"Argparse exit status {status}: {message}"
if status:
logger.error(log_message)
else:
logger.info(log_message)
raise ValueError(message)
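
# Illustration (not part of the PR): with this subclass, a bad value raises
# ValueError instead of exiting the process, so callers can handle it:
#
#     parser = ArgumentParserCustomExit()
#     parser.add_argument("--cpus-per-task", type=int)
#     try:
#         parser.parse_args(["--cpus-per-task", "not-a-number"])
#     except ValueError as err:
#         ...  # argparse reported: invalid int value 'not-a-number'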


def build_parser() -> ArgumentParser:
"""
    Build an ArgumentParser to handle all SBATCH parameters declared in sbatch_to_slurm.
"""
parser = ArgumentParserCustomExit()
for item in sbatch_to_slurm:
args = (i for i in (item.sbatch_short, item.sbatch) if i)
parser.add_argument(*args, **item.argparser_param)
    # make --requeue and --no-requeue work together, with the default set to None
parser.set_defaults(requeue=None)
return parser
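
# A sketch of the resulting tri-state requeue behavior (illustrative only):
#
#     >>> parser = build_parser()
#     >>> vars(parser.parse_args([]))["requeue"] is None
#     True
#     >>> vars(parser.parse_args(["--requeue"]))["requeue"]
#     True
#     >>> vars(parser.parse_args(["--no-requeue"]))["requeue"]
#     False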


def build_mapping_sbatch_to_slurm() -> bidict:
"""
    Create a two-way mapping between the parameter names expected by the Slurm REST API and by SBATCH.
"""
mapping: bidict = bidict()

for item in sbatch_to_slurm:
if item.slurmrestd_var_name:
sbatch_name = item.sbatch.lstrip("-").replace("-", "_")
mapping[sbatch_name] = item.slurmrestd_var_name

return mapping
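
# Since the mapping is a bidict, lookups work in both directions
# (illustrative only):
#
#     >>> mapping = build_mapping_sbatch_to_slurm()
#     >>> mapping["job_name"]
#     'name'
#     >>> mapping.inverse["name"]
#     'job_name'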


def jobscript_to_dict(jobscript: str) -> Dict[str, Union[str, bool]]:
"""
Extract the SBATCH params from a given job script.

    The parameters are returned in a dictionary mapping their names to their values.

Raise ValueError if any of the parameters are unknown to the parser.
"""
parsed_args, unknown_arg = parser.parse_known_args(
cast(Sequence[str], _clean_jobscript(jobscript)),
)

if unknown_arg:
raise ValueError("Unrecognized SBATCH arguments: {}".format(" ".join(unknown_arg)))

sbatch_params = {key: value for key, value in vars(parsed_args).items() if value is not None}

logger.debug(f"SBATCH params parsed from job script: {sbatch_params}")

return sbatch_params
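
# For example (hypothetical script content; key order follows the order in
# which arguments are declared in sbatch_to_slurm):
#
#     >>> jobscript_to_dict("#!/bin/bash\n#SBATCH --nodes=2\n#SBATCH -J demo")
#     {'job_name': 'demo', 'nodes': '2'}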


def convert_sbatch_to_slurm_api(input: Dict[str, Any]) -> Dict[str, Any]:
"""
    Translate the keys of a dictionary from the SBATCH parameter namespace to the Slurm REST API namespace.

    Notice the values are not affected.

Raise KeyError if any of the keys are unknown to the mapper.
"""
mapped = {}
unknown_keys = []

for sbatch_name, value in input.items():
try:
slurm_name = mapping_sbatch_to_slurm[sbatch_name]
except KeyError:
unknown_keys.append(sbatch_name)
else:
mapped[slurm_name] = value

if unknown_keys:
error_message = "Impossible to convert from SBATCH to Slurm REST API: {}"
raise KeyError(error_message.format(", ".join(unknown_keys)))

logger.debug(f"Slurm API params mapped from SBATCH: {mapped}")

return mapped
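
# Continuing the example above (illustrative only):
#
#     >>> convert_sbatch_to_slurm_api({"job_name": "demo", "nodes": "2"})
#     {'name': 'demo', 'nodes': '2'}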


def get_job_parameters(jobscript: str) -> Dict[str, Any]:
"""
    Parse the SBATCH parameters from a job script and map their names to Slurm REST API parameters.

    The result is returned as a key-value dictionary.
"""
return convert_sbatch_to_slurm_api(jobscript_to_dict(jobscript))


def get_job_properties_from_job_script(job_script_id: int, **kwargs) -> JobProperties:
"""
Get the job properties for Slurm REST API from a job script file, given its id.

Extra keyword arguments can be used to overwrite any parameter from the
job script, like name or current_working_directory.
"""
job_script_files = JobScriptFiles.get_from_s3(job_script_id)
slurm_parameters = get_job_parameters(job_script_files.main_file)
merged_parameters = {**slurm_parameters, **kwargs}
return JobProperties.parse_obj(merged_parameters)
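
# End-to-end sketch (the job script id and override below are hypothetical;
# note this fetches the job script files from S3):
#
#     job_properties = get_job_properties_from_job_script(123, name="new-name")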


parser = build_parser()
mapping_sbatch_to_slurm = build_mapping_sbatch_to_slurm()
jobbergate-api/jobbergate_api/apps/job_submissions/routers.py (21 changes: 16 additions & 5 deletions)
@@ -17,6 +17,7 @@
     searchable_fields,
     sortable_fields,
 )
+from jobbergate_api.apps.job_submissions.properties_parser import get_job_properties_from_job_script
 from jobbergate_api.apps.job_submissions.schemas import (
     ActiveJobSubmission,
     JobSubmissionCreateRequest,
@@ -66,17 +67,17 @@ async def job_submission_create(
             detail=message,
         )
 
-    create_dict = dict(
+    new_job_submission_data: Dict[str, Any] = dict(
         **job_submission.dict(exclude_unset=True),
         job_submission_owner_email=identity_claims.email,
         status=JobSubmissionStatus.CREATED,
     )
     if job_submission.client_id is None:
-        create_dict.update(client_id=client_id)
+        new_job_submission_data.update(client_id=client_id)
 
-    exec_dir = create_dict.pop("execution_directory", None)
+    exec_dir = new_job_submission_data.pop("execution_directory", None)
     if exec_dir is not None:
-        create_dict.update(execution_directory=str(exec_dir))
+        new_job_submission_data.update(execution_directory=str(exec_dir))
 
     select_query = job_scripts_table.select().where(job_scripts_table.c.id == job_submission.job_script_id)
     logger.trace(f"job_scripts select_query = {render_sql(select_query)}")
@@ -90,13 +91,23 @@
             detail=message,
         )
 
+    try:
+        job_properties = get_job_properties_from_job_script(
+            job_submission.job_script_id, **job_submission.execution_parameters.dict(exclude_unset=True)
+        )
+        new_job_submission_data["execution_parameters"] = job_properties.dict(exclude_unset=True)
+    except Exception as e:
+        message = f"Error extracting execution parameters from job script: {str(e)}"
+        logger.error(message)
+        raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=message)
+
     logger.debug("Inserting job-submission")
 
     async with database.transaction():
         try:
             insert_query = job_submissions_table.insert().returning(job_submissions_table)
             logger.trace(f"job_submissions insert_query = {render_sql(insert_query)}")
-            job_submission_data = await database.fetch_one(query=insert_query, values=create_dict)
+            job_submission_data = await database.fetch_one(query=insert_query, values=new_job_submission_data)
 
         except INTEGRITY_CHECK_EXCEPTIONS as e:
             logger.error(f"Reverting database transaction: {str(e)}")
jobbergate-api/jobbergate_api/apps/job_submissions/schemas.py
@@ -284,7 +284,7 @@ class JobSubmissionCreateRequest(BaseModel):
     job_script_id: int
     execution_directory: Optional[Path]
     client_id: Optional[str]
-    execution_parameters: Optional[JobProperties]
+    execution_parameters: JobProperties = Field(default_factory=dict)
 
     class Config:
         schema_extra = job_submission_meta_mapper