diff --git a/plantit/plantit/celery_tasks.py b/plantit/plantit/celery_tasks.py
index 430d5a94..72c1021d 100644
--- a/plantit/plantit/celery_tasks.py
+++ b/plantit/plantit/celery_tasks.py
@@ -59,7 +59,7 @@ def create_and_submit_delayed(username, workflow, delayed_id: str = None):
     # submit task chain
     (prepare_task_environment.s(task.guid) | \
      submit_task.s() | \
-     poll_task_status.s()).apply_async(soft_time_limit=int(settings.TASKS_STEP_TIME_LIMIT_SECONDS), priority=1)
+     poll_job_status.s()).apply_async(soft_time_limit=int(settings.TASKS_STEP_TIME_LIMIT_SECONDS), priority=1)
 
 
 @app.task(track_started=True)
@@ -78,7 +78,7 @@ def create_and_submit_repeating(username, workflow, repeating_id: str = None):
     # submit task chain
     (prepare_task_environment.s(task.guid) | \
      submit_task.s() | \
-     poll_task_status.s()).apply_async(soft_time_limit=int(settings.TASKS_STEP_TIME_LIMIT_SECONDS), priority=1)
+     poll_job_status.s()).apply_async(soft_time_limit=int(settings.TASKS_STEP_TIME_LIMIT_SECONDS), priority=1)
 
 
 @app.task(track_started=True, bind=True)
@@ -149,7 +149,7 @@ def submit_task(self, guid: str):
 
 
 @app.task(bind=True)
-def poll_task_status(self, guid: str):
+def poll_job_status(self, guid: str):
     try:
         task = Task.objects.get(guid=guid)
     except:
@@ -159,29 +159,86 @@ def poll_task_status(self, guid: str):
     refresh_delay = int(environ.get('TASKS_REFRESH_SECONDS'))
     cleanup_delay = int(environ.get('TASKS_CLEANUP_MINUTES')) * 60
 
-    logger.info(f"Checking {task.agent.name} scheduler status for run {guid} (SLURM job {task.job_id})")
+    logger.info(f"Checking {task.agent.name} scheduler status for task {guid} (job {task.job_id})")
 
-    # if the job already failed, schedule cleanup
+    # if the scheduler job failed...
     if task.job_status == 'FAILURE':
+        # mark the task as a failure
         task.status = TaskStatus.FAILURE
-        final_message = f"Job {task.job_id} failed"
-        log_task_orchestrator_status(task, [final_message])
+        task.save()
+
+        # log the status update and push it to clients
+        message = f"Job {task.job_id} failed"
+        log_task_orchestrator_status(task, [message])
         async_to_sync(push_task_channel_event)(task)
+
+        # schedule cleanup
        cleanup_task.s(guid).apply_async(countdown=cleanup_delay, priority=2)
         task.cleanup_time = timezone.now() + timedelta(seconds=cleanup_delay)
         task.save()
 
         # push AWS SNS notification
         if task.user.profile.push_notification_status == 'enabled':
-            SnsClient.get().publish_message(task.user.profile.push_notification_topic_arn, f"PlantIT task {task.guid}", final_message, {})
+            SnsClient.get().publish_message(task.user.profile.push_notification_topic_arn, f"PlantIT task {task.guid}", message, {})
 
         return guid
 
     # otherwise poll the scheduler for job status and walltime and update the task
     try:
+        # get the job status from the scheduler
         check_job_logs_for_progress(task)
         job_status = get_job_status(task)
         job_walltime = get_job_walltime(task)
+
+        # get_job_status() returns None if the job isn't found in the agent's scheduler.
+        # there are 2 reasons this might happen:
+        # - the job was just submitted and hasn't been picked up for reporting by the scheduler yet
+        # - the job already completed and we waited too long between polls to check its status
+        #
+        # in both cases we return early
+        if job_status is None:
+            # we might have just submitted the job; scheduler may take a moment to reflect new submissions
+            if not (task.job_status == 'COMPLETED' or task.job_status == 'COMPLETING'):
+                # update the task and persist it
+                now = timezone.now()
+                task.updated = now
+                task.job_status = job_status
+                task.job_consumed_walltime = job_walltime
+                task.save()
+                retry_seconds = 10
+
+                # log the status update and push it to clients
+                message = f"Job {task.job_id} not found, retrying in {retry_seconds} seconds"
+                log_task_orchestrator_status(task, [message])
+                async_to_sync(push_task_channel_event)(task)
+
+                # wait a few seconds and poll again
+                poll_job_status.s(guid).apply_async(countdown=retry_seconds)
+                return
+            # otherwise the job completed and the scheduler's forgotten about it in the interval between polls
+            else:
+                # update the task and persist it
+                now = timezone.now()
+                task.updated = now
+                task.job_status = 'COMPLETED'
+                task.job_consumed_walltime = job_walltime
+                task.save()
+
+                # check that we have the results we expect
+                list_task_results.s(guid).apply_async()
+
+                # log the status update and push it to clients
+                message = f"Job {task.job_id} completed with unknown status" + (f" after {job_walltime}" if job_walltime is not None else '')
+                log_task_orchestrator_status(task, [message])
+                async_to_sync(push_task_channel_event)(task)
+
+                # push AWS SNS notification
+                if task.user.profile.push_notification_status == 'enabled':
+                    SnsClient.get().publish_message(task.user.profile.push_notification_topic_arn, f"PlantIT task {task.guid}", message, {})
+
+                return guid
+
+        # update the task and persist it
         task.job_status = job_status
         task.job_consumed_walltime = job_walltime
         now = timezone.now()
@@ -192,71 +249,58 @@ def poll_task_status(self, guid: str):
         ssh = get_task_ssh_client(task)
         get_task_remote_logs(task, ssh)
 
-        # mark task status according to job status
-        if job_status == 'COMPLETED':
-            task.completed = now
-            task.status = TaskStatus.SUCCESS
-        elif job_status == 'FAILED':
-            task.completed = now
+        # if the job failed, was cancelled, or timed out, mark the task accordingly
+        job_complete = False
+        if job_status == 'FAILED':
             task.status = TaskStatus.FAILURE
+            job_complete = True
         elif job_status == 'CANCELLED':
-            task.completed = now
             task.status = TaskStatus.CANCELED
+            job_complete = True
         elif job_status == 'TIMEOUT':
-            task.completed = now
             task.status = TaskStatus.TIMEOUT
+            job_complete = True
+        # but if it succeeded, we still need to check results before determining success/failure
+        elif job_status == 'COMPLETED':
+            job_complete = True
+
+        # update the task and persist it
+        now = timezone.now()
+        task.updated = now
 
         task.save()
 
-        if task.is_complete:
-            # job is done, task is complete, now we can list results
+        if job_complete:
+            # check that we have the results we expect
             list_task_results.s(guid).apply_async()
-            final_message = f"Job {task.job_id} {job_status}" + (f" after {job_walltime}" if job_walltime is not None else '')
-            log_task_orchestrator_status(task, [final_message])
+
+            # log the status update and push it to clients
+            message = f"Job {task.job_id} completed with status {job_status}" + (f" after {job_walltime}" if job_walltime is not None else '')
+            log_task_orchestrator_status(task, [message])
             async_to_sync(push_task_channel_event)(task)
 
             # push AWS SNS notification
             if task.user.profile.push_notification_status == 'enabled':
-                SnsClient.get().publish_message(task.user.profile.push_notification_topic_arn, f"PlantIT task {task.guid}", final_message, {})
+                SnsClient.get().publish_message(task.user.profile.push_notification_topic_arn, f"PlantIT task {task.guid}", message, {})
 
             return guid
         else:
-            # if task is past its due time, cancel it
+            # if past due time...
             if now > task.due_time:
-                log_task_orchestrator_status(task, [f"Job {task.job_id} {job_status} (walltime {job_walltime}) is past its due time {str(task.due_time)}"])
+                # log the status update and push it to clients
+                message = f"Job {task.job_id} {job_status} (walltime {job_walltime}) is past its due time {str(task.due_time)}"
+                log_task_orchestrator_status(task, [message])
                 async_to_sync(push_task_channel_event)(task)
+
+                # cancel the task
                 cancel_task(task)
-            # otherwise schedule another round of polling
             else:
-                log_task_orchestrator_status(task, [f"Job {task.job_id} {job_status} (walltime {job_walltime})"])
+                # log the status update and push it to clients
+                message = f"Job {task.job_id} {job_status} (walltime {job_walltime})"
+                log_task_orchestrator_status(task, [message])
                 async_to_sync(push_task_channel_event)(task)
-                poll_task_status.s(guid).apply_async(countdown=refresh_delay)
-    except StopIteration as e:
-        if not (task.job_status == 'COMPLETED' or task.job_status == 'COMPLETING'):
-            # we probably just created the task and
-            # it's not visible in the scheduler yet
-            # just wait a few seconds and try again
-            now = timezone.now()
-            task.updated = now
-            task.save()
-            retry_seconds = 10
-            log_task_orchestrator_status(task, [f"Job {task.job_id} not found, retrying in {retry_seconds} seconds"])
-            async_to_sync(push_task_channel_event)(task)
-            poll_task_status.s(guid).apply_async(countdown=retry_seconds)
-            return
-        else:
-            # job is done, task is complete, now we can list results
-            final_message = f"Job {task.job_id} succeeded"
-            log_task_orchestrator_status(task, [final_message])
-            async_to_sync(push_task_channel_event)(task)
-            cleanup_task.s(guid).apply_async(countdown=cleanup_delay, priority=2)
-            task.cleanup_time = timezone.now() + timedelta(seconds=cleanup_delay)
-            task.save()
-
-            # push AWS SNS notification
-            if task.user.profile.push_notification_status == 'enabled':
-                SnsClient.get().publish_message(task.user.profile.push_notification_topic_arn, f"PlantIT task {task.guid}", final_message, {})
-            return guid
+
+                # otherwise schedule another round of polling
+                poll_job_status.s(guid).apply_async(countdown=refresh_delay)
     except:
         # mark the task failed
         task.status = TaskStatus.FAILURE
@@ -266,8 +310,8 @@ def poll_task_status(self, guid: str):
         task.save()
 
         # there was an unexpected runtime exception somewhere, need to catch and log it
-        final_message = f"Job {task.job_id} encountered unexpected error: {traceback.format_exc()}"
-        log_task_orchestrator_status(task, [final_message])
+        message = f"Job {task.job_id} encountered unexpected error: {traceback.format_exc()}"
+        log_task_orchestrator_status(task, [message])
         async_to_sync(push_task_channel_event)(task)
         cleanup_task.s(guid).apply_async(countdown=cleanup_delay, priority=2)
         task.cleanup_time = timezone.now() + timedelta(seconds=cleanup_delay)
@@ -275,7 +319,7 @@
 
         # push AWS SNS notification
         if task.user.profile.push_notification_status == 'enabled':
{task.guid}", final_message, {}) + SnsClient.get().publish_message(task.user.profile.push_notification_topic_arn, f"PlantIT task {task.guid}", message, {}) # stop the task chain self.request.callbacks = None @@ -293,27 +337,59 @@ def list_task_results(self, guid: str): redis = RedisClient.get() ssh = get_task_ssh_client(task) - log_task_orchestrator_status(task, [f"Retrieving logs"]) + # log status update and push it to clients + message = f"Retrieving logs" + log_task_orchestrator_status(task, [message]) async_to_sync(push_task_channel_event)(task) + + # get logs from agent filesystem get_task_remote_logs(task, ssh) - log_task_orchestrator_status(task, [f"Retrieving results"]) + # log status update and push it to clients + message = f"Retrieving results" + log_task_orchestrator_status(task, [message]) async_to_sync(push_task_channel_event)(task) - expected = list_result_files(task) - found = [e for e in expected if e['exists']] + + # get results from agent filesystem, then save them to cache and update the task + results = list_result_files(task) + found = [r for r in results if r['exists']] + missing = [r for r in results if not r['exists']] redis.set(f"results/{task.guid}", json.dumps(found)) task.results_retrieved = True task.save() - log_task_orchestrator_status(task, [f"Expected {len(expected)} result(s), found {len(found)}, verifying data was transferred to CyVerse"]) + # make sure we got the results we expected + if len(missing) > 0: + # mark the task failed + task.status = TaskStatus.FAILURE + now = timezone.now() + task.updated = now + task.completed = now + task.save() + + # log status update and push it to clients + message = f"Found {len(found)} results, missing {len(missing)}: {', '.join([m['name'] for m in missing])}" + log_task_orchestrator_status(task, [message]) + async_to_sync(push_task_channel_event)(task) + else: + # log status update and push it to clients + message = f"Found {len(found)} results" + log_task_orchestrator_status(task, [message]) + async_to_sync(push_task_channel_event)(task) + + # log status update and push it to clients + message = f"Verifying data was transferred to CyVerse" + log_task_orchestrator_status(task, [message]) async_to_sync(push_task_channel_event)(task) + + # make sure the results and logs were pushed to their destination in CyVerse check_task_cyverse_transfer.s(guid).apply_async() return guid @app.task(bind=True) -def check_task_cyverse_transfer(self, guid: str, iteration: int = 0): +def check_task_cyverse_transfer(self, guid: str, attempts: int = 0): try: task = Task.objects.get(guid=guid) except: @@ -321,25 +397,52 @@ def check_task_cyverse_transfer(self, guid: str, iteration: int = 0): self.request.callbacks = None # stop the task chain return + # check the expected filenames against the contents of the CyVerse collection path = task.workflow['output']['to'] actual = [file.rpartition('/')[2] for file in terrain.list_dir(path, task.user.profile.cyverse_access_token)] expected = [file['name'] for file in json.loads(RedisClient.get().get(f"results/{task.guid}"))] if not set(expected).issubset(set(actual)): - logger.warning(f"Expected {len(expected)} results but found {len(actual)}") - if iteration < 5: - logger.warning(f"Checking again in 30 seconds (iteration {iteration})") - check_task_cyverse_transfer.s(guid, iteration + 1).apply_async(countdown=30) + logger.warning(f"Expected {len(expected)} uploads to CyVerse but found {len(actual)}") + + # TODO make this configurable + max_attempts = 10 + countdown = 30 + + if attempts < 
+        if attempts < max_attempts:
+            message = f"Transfer to CyVerse directory {path} incomplete, checking again in {countdown} seconds (attempt {attempts})"
+            logger.warning(message)
+            check_task_cyverse_transfer.s(guid, attempts + 1).apply_async(countdown=countdown)
+        else:
+            message = f"Transfer to CyVerse directory {path} failed to complete after {attempts * countdown} seconds"
+            logger.info(message)
+
+            # mark the task failed
+            now = timezone.now()
+            task.updated = now
+            task.completed = now
+            task.status = TaskStatus.FAILURE
+            task.transferred = True
+            task.results_transferred = len(expected)
+            task.transfer_path = path
+            task.save()
     else:
-        msg = f"Transfer to CyVerse directory {path} completed"
-        logger.info(msg)
+        message = f"Transfer to CyVerse directory {path} completed"
+        logger.info(message)
+
+        # mark the task succeeded
+        now = timezone.now()
+        task.updated = now
+        task.completed = now
+        task.status = TaskStatus.SUCCESS
         task.transferred = True
         task.results_transferred = len(expected)
         task.transfer_path = path
         task.save()
 
-        log_task_orchestrator_status(task, [msg])
-        async_to_sync(push_task_channel_event)(task)
+    # log status update and push it to clients
+    log_task_orchestrator_status(task, [message])
+    async_to_sync(push_task_channel_event)(task)
 
     cleanup_delay = int(environ.get('TASKS_CLEANUP_MINUTES')) * 60
     cleanup_task.s(guid).apply_async(priority=2, countdown=cleanup_delay)
diff --git a/plantit/plantit/task_lifecycle.py b/plantit/plantit/task_lifecycle.py
index e1a0d187..593aecee 100644
--- a/plantit/plantit/task_lifecycle.py
+++ b/plantit/plantit/task_lifecycle.py
@@ -6,7 +6,7 @@
 from os import environ
 from os.path import join, isdir
 from pathlib import Path
-from typing import List
+from typing import List, Tuple
 
 import binascii
 import json
@@ -361,7 +361,7 @@ def get_job_walltime(task: Task) -> (str, str):
     return None
 
 
-def get_job_status(task: Task) -> str:
+def get_job_status(task: Task):
     ssh = get_task_ssh_client(task)
     with ssh:
         lines = execute_command(
@@ -371,8 +371,12 @@
             directory=join(task.agent.workdir, task.workdir),
             allow_stderr=True)
 
-        line = next(l for l in lines if task.job_id in l)
-        status = line.split()[5].replace('+', '')
+        try:
+            line = next(l for l in lines if task.job_id in l)
+            return line.split()[5].replace('+', '')
+        except StopIteration:
+            # `sacct` didn't report this job, so fall back to the scheduler log file (status stays None if nothing is found there)
+            status = None
 
         # check the scheduler log file in case `sacct` is no longer displaying info
         # about this job so we don't miss a cancellation/timeout/failure/completion
@@ -380,23 +384,28 @@
             log_file_path = get_task_scheduler_log_file_path(task)
             stdin, stdout, stderr = ssh.client.exec_command(f"test -e {log_file_path} && echo exists")
 
-            if stdout.read().decode().strip() != 'exists': return status
+            # if log file doesn't exist, return None
+            if stdout.read().decode().strip() != 'exists': return None
+
+            # otherwise check the log file to see if job status was written there
             with sftp.open(log_file_path, 'r') as log_file:
                 logger.info(f"Checking scheduler log file {log_file_path} for job {task.job_id} status")
+
                 for line in log_file.readlines():
+                    # if we find success or failure, return immediately
+                    if 'FAILED' in line or 'FAILURE' in line or 'NODE_FAIL' in line:
+                        return 'FAILED'
+                    if 'SUCCESS' in line or 'COMPLETED' in line:
+                        return 'SUCCESS'
+
+                    # otherwise use the most recent status (last line of the log file)
                     if 'CANCELLED' in line or 'CANCELED' in line:
                         status = 'CANCELED'
                         continue
                     if 'TIMEOUT' in line:
                         status = 'TIMEOUT'
                         continue
-                    if 'FAILED' in line or 'FAILURE' in line or 'NODE_FAIL' in line:
-                        status = 'FAILED'
-                        break
-                    if 'SUCCESS' in line or 'COMPLETED' in line:
-                        status = 'SUCCESS'
-                        break
 
     return status
 
 
diff --git a/plantit/plantit/task_resources.py b/plantit/plantit/task_resources.py
index f7e15839..78becd5b 100644
--- a/plantit/plantit/task_resources.py
+++ b/plantit/plantit/task_resources.py
@@ -115,4 +115,4 @@ def get_scheduler_log_file_contents(task: Task) -> List[str]:
     else:
         scheduler_logs = []
 
-    return scheduler_logs
\ No newline at end of file
+    return scheduler_logs
diff --git a/plantit/plantit/tasks/views.py b/plantit/plantit/tasks/views.py
index 227e8003..96c946ce 100644
--- a/plantit/plantit/tasks/views.py
+++ b/plantit/plantit/tasks/views.py
@@ -21,7 +21,7 @@
 from plantit.task_lifecycle import create_immediate_task, create_delayed_task, create_repeating_task, cancel_task
 from plantit.task_resources import get_task_ssh_client, push_task_channel_event, log_task_orchestrator_status
 from plantit.tasks.models import Task, DelayedTask, RepeatingTask, TaskStatus
-from plantit.celery_tasks import prepare_task_environment, submit_task, poll_task_status
+from plantit.celery_tasks import prepare_task_environment, submit_task, poll_job_status
 from plantit.utils.tasks import parse_task_time_limit, get_task_orchestrator_log_file_path, \
     get_task_scheduler_log_file_path, \
     get_task_agent_log_file_path
@@ -67,7 +67,7 @@ def get_or_create(request):
     # submit task chain
     (prepare_task_environment.s(task.guid) | \
      submit_task.s() | \
-     poll_task_status.s()).apply_async(
+     poll_job_status.s()).apply_async(
         countdown=5,  # TODO: make initial delay configurable
         soft_time_limit=int(settings.TASKS_STEP_TIME_LIMIT_SECONDS),
         priority=1)
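
The behavioral core of this patch is a new polling contract: get_job_status() now returns None when the scheduler has no record of the job, and poll_job_status() decides what None means from the last job status it recorded (a just-submitted job that is not visible yet versus a finished job the scheduler has since forgotten). Below is a minimal, self-contained sketch of that contract for reference; it is illustrative only and not part of the patch, and the helper names (poll_once, get_status, schedule) and delay values are assumptions rather than plantit APIs.

# Illustrative sketch only (not part of the patch): a simplified model of the polling
# contract introduced above, using hypothetical helpers get_status() and schedule().
from typing import Callable, Optional

TERMINAL_STATUSES = {'COMPLETED', 'FAILED', 'CANCELLED', 'TIMEOUT'}


def poll_once(get_status: Callable[[], Optional[str]],
              last_status: Optional[str],
              schedule: Callable[[int], None],
              refresh_delay: int = 60,
              retry_delay: int = 10) -> Optional[str]:
    status = get_status()
    if status is None:
        # the scheduler has no record of the job
        if last_status in ('COMPLETED', 'COMPLETING'):
            # the job finished earlier and the scheduler has since forgotten it
            return 'COMPLETED'
        # the job was probably just submitted and isn't visible yet: retry soon
        schedule(retry_delay)
        return last_status
    if status in TERMINAL_STATUSES:
        # terminal status: the caller can verify results and schedule cleanup
        return status
    # still pending or running: keep polling at the normal interval
    schedule(refresh_delay)
    return status


# example: a job that finished and then aged out of the scheduler's accounting
delays = []
print(poll_once(lambda: None, last_status='COMPLETING', schedule=delays.append))  # -> COMPLETED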