
Commit

better task status handling (#269)
wpbonelli committed Feb 7, 2022
1 parent b0c58b2 commit 60aab47
Showing 4 changed files with 195 additions and 83 deletions.
plantit/plantit/celery_tasks.py (241 changes: 172 additions & 69 deletions)
@@ -59,7 +59,7 @@ def create_and_submit_delayed(username, workflow, delayed_id: str = None):
     # submit task chain
     (prepare_task_environment.s(task.guid) | \
      submit_task.s() | \
-     poll_task_status.s()).apply_async(soft_time_limit=int(settings.TASKS_STEP_TIME_LIMIT_SECONDS), priority=1)
+     poll_job_status.s()).apply_async(soft_time_limit=int(settings.TASKS_STEP_TIME_LIMIT_SECONDS), priority=1)


 @app.task(track_started=True)
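The chain above composes three Celery task signatures with the `|` operator and enqueues them as one unit, so each step runs only if the previous one succeeds. A minimal, self-contained sketch of the same pattern (the broker URL, time limit, and task bodies are illustrative assumptions, not PlantIT's configuration):

    from celery import Celery

    app = Celery('sketch', broker='redis://localhost:6379/0')  # assumed broker

    @app.task(track_started=True)
    def prepare_task_environment(guid: str) -> str:
        # each link in a chain receives the previous link's return value
        return guid

    @app.task(track_started=True)
    def submit_task(guid: str) -> str:
        return guid

    @app.task(bind=True)
    def poll_job_status(self, guid: str) -> str:
        return guid

    # .s() builds a signature; | chains them; apply_async enqueues the whole chain
    pipeline = prepare_task_environment.s('some-guid') | submit_task.s() | poll_job_status.s()
    pipeline.apply_async(soft_time_limit=600, priority=1)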
@@ -78,7 +78,7 @@ def create_and_submit_repeating(username, workflow, repeating_id: str = None):
     # submit task chain
     (prepare_task_environment.s(task.guid) | \
      submit_task.s() | \
-     poll_task_status.s()).apply_async(soft_time_limit=int(settings.TASKS_STEP_TIME_LIMIT_SECONDS), priority=1)
+     poll_job_status.s()).apply_async(soft_time_limit=int(settings.TASKS_STEP_TIME_LIMIT_SECONDS), priority=1)


 @app.task(track_started=True, bind=True)
@@ -149,7 +149,7 @@ def submit_task(self, guid: str):


 @app.task(bind=True)
-def poll_task_status(self, guid: str):
+def poll_job_status(self, guid: str):
     try:
         task = Task.objects.get(guid=guid)
     except:
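Because `poll_job_status` is bound (`bind=True`), it can halt the remainder of a chain when its Task record is missing by clearing `self.request.callbacks`. A hedged sketch of that idiom, narrowing the bare `except` to the specific Django exception (the model and logger names are assumptions):

    @app.task(bind=True)
    def poll_job_status(self, guid: str):
        try:
            task = Task.objects.get(guid=guid)   # Django ORM lookup by GUID
        except Task.DoesNotExist:                # more precise than a bare except
            logger.warning(f"Could not find task with GUID {guid}")
            self.request.callbacks = None        # downstream chain links won't run
            return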
@@ -159,29 +159,86 @@ def poll_task_status(self, guid: str):

     refresh_delay = int(environ.get('TASKS_REFRESH_SECONDS'))
     cleanup_delay = int(environ.get('TASKS_CLEANUP_MINUTES')) * 60
-    logger.info(f"Checking {task.agent.name} scheduler status for run {guid} (SLURM job {task.job_id})")
+    logger.info(f"Checking {task.agent.name} scheduler status for task {guid} (job {task.job_id})")

-    # if the job already failed, schedule cleanup
+    # if the scheduler job failed...
     if task.job_status == 'FAILURE':
+        # mark the task as a failure
         task.status = TaskStatus.FAILURE
-        final_message = f"Job {task.job_id} failed"
-        log_task_orchestrator_status(task, [final_message])
         task.save()

+        # log the status update and push it to clients
+        message = f"Job {task.job_id} failed"
+        log_task_orchestrator_status(task, [message])
         async_to_sync(push_task_channel_event)(task)

+        # schedule cleanup
         cleanup_task.s(guid).apply_async(countdown=cleanup_delay, priority=2)
         task.cleanup_time = timezone.now() + timedelta(seconds=cleanup_delay)
         task.save()

+        # push AWS SNS notification
         if task.user.profile.push_notification_status == 'enabled':
-            SnsClient.get().publish_message(task.user.profile.push_notification_topic_arn, f"PlantIT task {task.guid}", final_message, {})
+            SnsClient.get().publish_message(task.user.profile.push_notification_topic_arn, f"PlantIT task {task.guid}", message, {})

         return guid

+    # otherwise poll the scheduler for job status and walltime and update the task
     try:
+        # get the job status from the scheduler
         check_job_logs_for_progress(task)
         job_status = get_job_status(task)
         job_walltime = get_job_walltime(task)

+        # get_job_status() returns None if the job isn't found in the agent's scheduler.
+        # there are 2 reasons this might happen:
+        # - the job was just submitted and hasn't been picked up for reporting by the scheduler yet
+        # - the job already completed and we waited too long between polls to check its status
+        #
+        # in both cases we return early
+        if job_status is None:
+            # we might have just submitted the job; scheduler may take a moment to reflect new submissions
+            if not (task.job_status == 'COMPLETED' or task.job_status == 'COMPLETING'):
+                # update the task and persist it
+                now = timezone.now()
+                task.updated = now
+                task.job_status = job_status
+                task.job_consumed_walltime = job_walltime
+                task.save()
+                retry_seconds = 10
+
+                # log the status update and push it to clients
+                message = f"Job {task.job_id} not found, retrying in {retry_seconds} seconds"
+                log_task_orchestrator_status(task, [message])
+                async_to_sync(push_task_channel_event)(task)
+
+                # wait a few seconds and poll again
+                poll_job_status.s(guid).apply_async(countdown=retry_seconds)
+                return
+            # otherwise the job completed and the scheduler's forgotten about it in the interval between polls
+            else:
+                # update the task and persist it
+                now = timezone.now()
+                task.updated = now
+                task.job_status = 'COMPLETED'
+                task.job_consumed_walltime = job_walltime
+                task.save()
+
+                # check that we have the results we expect
+                list_task_results.s(guid).apply_async()
+
+                # log the status update and push it to clients
+                message = f"Job {task.job_id} completed with unknown status" + (f" after {job_walltime}" if job_walltime is not None else '')
+                log_task_orchestrator_status(task, [message])
+                async_to_sync(push_task_channel_event)(task)
+
+                # push AWS SNS notification
+                if task.user.profile.push_notification_status == 'enabled':
+                    SnsClient.get().publish_message(task.user.profile.push_notification_topic_arn, f"PlantIT task {task.guid}", message, {})
+
+                return guid

+        # update the task and persist it
         task.job_status = job_status
         task.job_consumed_walltime = job_walltime
         now = timezone.now()
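The new `job_status is None` branch replaces the old StopIteration handling with an explicit self-rescheduling retry: rather than blocking a worker, the task re-enqueues itself with a countdown until the scheduler reports a terminal state. Roughly, under the assumption that `query_scheduler` wraps whatever status command the agent exposes (e.g. sacct/squeue for SLURM):

    @app.task(bind=True)
    def poll(self, job_id: str):
        status = query_scheduler(job_id)  # hypothetical helper
        if status is None:
            # job not visible yet (or already purged): check again shortly
            poll.s(job_id).apply_async(countdown=10)
            return
        if status in ('COMPLETED', 'FAILED', 'CANCELLED', 'TIMEOUT'):
            return status                 # terminal state: downstream handling takes over
        poll.s(job_id).apply_async(countdown=60)  # still running: poll again later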
@@ -192,71 +249,58 @@ def poll_task_status(self, guid: str):
         ssh = get_task_ssh_client(task)
         get_task_remote_logs(task, ssh)

-        # mark task status according to job status
-        if job_status == 'COMPLETED':
-            task.completed = now
-            task.status = TaskStatus.SUCCESS
-        elif job_status == 'FAILED':
-            task.completed = now
+        # if job did not complete, go ahead and mark the task failed/cancelled/timed out/etc
+        job_complete = False
+        if job_status == 'FAILED':
             task.status = TaskStatus.FAILURE
+            job_complete = True
         elif job_status == 'CANCELLED':
-            task.completed = now
             task.status = TaskStatus.CANCELED
+            job_complete = True
         elif job_status == 'TIMEOUT':
-            task.completed = now
             task.status = TaskStatus.TIMEOUT
+            job_complete = True
+        # but if it succeeded, we still need to check results before determining success/failure
+        elif job_status == 'COMPLETED':
+            job_complete = True

+        # update the task and persist it
         now = timezone.now()
         task.updated = now
         task.save()

-        if task.is_complete:
+        # job is done, task is complete, now we can list results
+        if job_complete:
             # check that we have the results we expect
             list_task_results.s(guid).apply_async()
-            final_message = f"Job {task.job_id} {job_status}" + (f" after {job_walltime}" if job_walltime is not None else '')
-            log_task_orchestrator_status(task, [final_message])

+            # log the status update and push it to clients
+            message = f"Job {task.job_id} completed with status {job_status}" + (f" after {job_walltime}" if job_walltime is not None else '')
+            log_task_orchestrator_status(task, [message])
             async_to_sync(push_task_channel_event)(task)

+            # push AWS SNS notification
             if task.user.profile.push_notification_status == 'enabled':
-                SnsClient.get().publish_message(task.user.profile.push_notification_topic_arn, f"PlantIT task {task.guid}", final_message, {})
+                SnsClient.get().publish_message(task.user.profile.push_notification_topic_arn, f"PlantIT task {task.guid}", message, {})

             return guid
         else:
-            # if task is past its due time, cancel it
+            # if past due time...
             if now > task.due_time:
-                log_task_orchestrator_status(task, [f"Job {task.job_id} {job_status} (walltime {job_walltime}) is past its due time {str(task.due_time)}"])
+                # log the status update and push it to clients
+                message = f"Job {task.job_id} {job_status} (walltime {job_walltime}) is past its due time {str(task.due_time)}"
+                log_task_orchestrator_status(task, [message])
                 async_to_sync(push_task_channel_event)(task)

+                # cancel the task
                 cancel_task(task)
+            # otherwise schedule another round of polling
             else:
-                log_task_orchestrator_status(task, [f"Job {task.job_id} {job_status} (walltime {job_walltime})"])
+                # log the status update and push it to clients
+                message = f"Job {task.job_id} {job_status} (walltime {job_walltime})"
+                log_task_orchestrator_status(task, [message])
                 async_to_sync(push_task_channel_event)(task)
-                poll_task_status.s(guid).apply_async(countdown=refresh_delay)
-    except StopIteration as e:
-        if not (task.job_status == 'COMPLETED' or task.job_status == 'COMPLETING'):
-            # we probably just created the task and
-            # it's not visible in the scheduler yet
-            # just wait a few seconds and try again
-            now = timezone.now()
-            task.updated = now
-            task.save()
-            retry_seconds = 10
-            log_task_orchestrator_status(task, [f"Job {task.job_id} not found, retrying in {retry_seconds} seconds"])
-            async_to_sync(push_task_channel_event)(task)
-            poll_task_status.s(guid).apply_async(countdown=retry_seconds)
-            return
-        else:
-            # job is done, task is complete, now we can list results
-            final_message = f"Job {task.job_id} succeeded"
-            log_task_orchestrator_status(task, [final_message])
-            async_to_sync(push_task_channel_event)(task)
-            cleanup_task.s(guid).apply_async(countdown=cleanup_delay, priority=2)
-            task.cleanup_time = timezone.now() + timedelta(seconds=cleanup_delay)
-            task.save()

-            # push AWS SNS notification
-            if task.user.profile.push_notification_status == 'enabled':
-                SnsClient.get().publish_message(task.user.profile.push_notification_topic_arn, f"PlantIT task {task.guid}", final_message, {})

-            return guid
+                # otherwise schedule another round of polling
+                poll_job_status.s(guid).apply_async(countdown=refresh_delay)
     except:
+        # mark the task failed
         task.status = TaskStatus.FAILURE
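The due-time check above leans on Django's timezone-aware datetimes: the deadline is stored once on the task and compared against `timezone.now()` on every poll. A sketch of the idea (the two-hour budget and the helper are assumptions):

    from datetime import timedelta
    from django.utils import timezone

    # at submission time: give the task a fixed budget, e.g. two hours
    task.due_time = timezone.now() + timedelta(hours=2)
    task.save()

    # on each poll: cancel once the budget is exhausted
    if timezone.now() > task.due_time:
        cancel_task(task)  # assumed helper that tells the scheduler to kill the job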
@@ -266,16 +310,16 @@ def poll_task_status(self, guid: str):
         task.save()

         # there was an unexpected runtime exception somewhere, need to catch and log it
-        final_message = f"Job {task.job_id} encountered unexpected error: {traceback.format_exc()}"
-        log_task_orchestrator_status(task, [final_message])
+        message = f"Job {task.job_id} encountered unexpected error: {traceback.format_exc()}"
+        log_task_orchestrator_status(task, [message])
         async_to_sync(push_task_channel_event)(task)
         cleanup_task.s(guid).apply_async(countdown=cleanup_delay, priority=2)
         task.cleanup_time = timezone.now() + timedelta(seconds=cleanup_delay)
         task.save()

         # push AWS SNS notification
         if task.user.profile.push_notification_status == 'enabled':
-            SnsClient.get().publish_message(task.user.profile.push_notification_topic_arn, f"PlantIT task {task.guid}", final_message, {})
+            SnsClient.get().publish_message(task.user.profile.push_notification_topic_arn, f"PlantIT task {task.guid}", message, {})

         # stop the task chain
         self.request.callbacks = None
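`SnsClient.get().publish_message(...)` is PlantIT's wrapper; its internals aren't shown in this diff, but a plain boto3 SNS publish presumably looks something like this (the region is an assumption):

    import boto3

    def publish_message(topic_arn: str, subject: str, body: str):
        sns = boto3.client('sns', region_name='us-east-1')  # assumed region
        sns.publish(TopicArn=topic_arn, Subject=subject, Message=body)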
@@ -293,53 +337,112 @@ def list_task_results(self, guid: str):
     redis = RedisClient.get()
     ssh = get_task_ssh_client(task)

-    log_task_orchestrator_status(task, [f"Retrieving logs"])
+    # log status update and push it to clients
+    message = f"Retrieving logs"
+    log_task_orchestrator_status(task, [message])
     async_to_sync(push_task_channel_event)(task)

+    # get logs from agent filesystem
     get_task_remote_logs(task, ssh)

-    log_task_orchestrator_status(task, [f"Retrieving results"])
+    # log status update and push it to clients
+    message = f"Retrieving results"
+    log_task_orchestrator_status(task, [message])
     async_to_sync(push_task_channel_event)(task)
-    expected = list_result_files(task)
-    found = [e for e in expected if e['exists']]

+    # get results from agent filesystem, then save them to cache and update the task
+    results = list_result_files(task)
+    found = [r for r in results if r['exists']]
+    missing = [r for r in results if not r['exists']]
     redis.set(f"results/{task.guid}", json.dumps(found))
     task.results_retrieved = True
     task.save()

-    log_task_orchestrator_status(task, [f"Expected {len(expected)} result(s), found {len(found)}, verifying data was transferred to CyVerse"])
+    # make sure we got the results we expected
+    if len(missing) > 0:
+        # mark the task failed
+        task.status = TaskStatus.FAILURE
+        now = timezone.now()
+        task.updated = now
+        task.completed = now
+        task.save()

+        # log status update and push it to clients
+        message = f"Found {len(found)} results, missing {len(missing)}: {', '.join([m['name'] for m in missing])}"
+        log_task_orchestrator_status(task, [message])
+        async_to_sync(push_task_channel_event)(task)
+    else:
+        # log status update and push it to clients
+        message = f"Found {len(found)} results"
+        log_task_orchestrator_status(task, [message])
+        async_to_sync(push_task_channel_event)(task)

+    # log status update and push it to clients
+    message = f"Verifying data was transferred to CyVerse"
+    log_task_orchestrator_status(task, [message])
+    async_to_sync(push_task_channel_event)(task)

+    # make sure the results and logs were pushed to their destination in CyVerse
     check_task_cyverse_transfer.s(guid).apply_async()

     return guid


 @app.task(bind=True)
-def check_task_cyverse_transfer(self, guid: str, iteration: int = 0):
+def check_task_cyverse_transfer(self, guid: str, attempts: int = 0):
     try:
         task = Task.objects.get(guid=guid)
     except:
         logger.warning(f"Could not find task with GUID {guid} (might have been deleted?)")
         self.request.callbacks = None  # stop the task chain
         return

+    # check the expected filenames against the contents of the CyVerse collection
     path = task.workflow['output']['to']
     actual = [file.rpartition('/')[2] for file in terrain.list_dir(path, task.user.profile.cyverse_access_token)]
     expected = [file['name'] for file in json.loads(RedisClient.get().get(f"results/{task.guid}"))]

     if not set(expected).issubset(set(actual)):
-        logger.warning(f"Expected {len(expected)} results but found {len(actual)}")
-        if iteration < 5:
-            logger.warning(f"Checking again in 30 seconds (iteration {iteration})")
-            check_task_cyverse_transfer.s(guid, iteration + 1).apply_async(countdown=30)
+        logger.warning(f"Expected {len(expected)} uploads to CyVerse but found {len(actual)}")

+        # TODO make this configurable
+        max_attempts = 10
+        countdown = 30

+        if attempts < max_attempts:
+            message = f"Transfer to CyVerse directory {path} incomplete, checking again in {countdown} seconds (attempt {attempts})"
+            logger.warning(message)
+            check_task_cyverse_transfer.s(guid, attempts + 1).apply_async(countdown=countdown)
+        else:
+            message = f"Transfer to CyVerse directory {path} failed to complete after {attempts * countdown} seconds"
+            logger.info(message)

+            # mark the task failed
+            now = timezone.now()
+            task.updated = now
+            task.completed = now
+            task.status = TaskStatus.FAILURE
+            task.transferred = True
+            task.results_transferred = len(expected)
+            task.transfer_path = path
+            task.save()
     else:
-        msg = f"Transfer to CyVerse directory {path} completed"
-        logger.info(msg)
+        message = f"Transfer to CyVerse directory {path} completed"
+        logger.info(message)

+        # mark the task succeeded
+        now = timezone.now()
+        task.updated = now
+        task.completed = now
+        task.status = TaskStatus.SUCCESS
         task.transferred = True
         task.results_transferred = len(expected)
         task.transfer_path = path
         task.save()

-        log_task_orchestrator_status(task, [msg])
-        async_to_sync(push_task_channel_event)(task)
+    # log status update and push it to clients
+    log_task_orchestrator_status(task, [message])
+    async_to_sync(push_task_channel_event)(task)

     cleanup_delay = int(environ.get('TASKS_CLEANUP_MINUTES')) * 60
     cleanup_task.s(guid).apply_async(priority=2, countdown=cleanup_delay)
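`check_task_cyverse_transfer` combines three small techniques: expected filenames cached in Redis as JSON, a set-subset test against the remote listing, and bounded retries via `apply_async(countdown=...)`. A condensed sketch with illustrative names and file lists (the real listing comes from `terrain.list_dir`):

    import json
    import redis

    r = redis.Redis()  # stands in for RedisClient.get() above

    # after listing results on the agent: cache what we expect to appear in CyVerse
    found = [{'name': 'output.csv', 'exists': True}, {'name': 'log.txt', 'exists': True}]
    r.set('results/some-guid', json.dumps(found))

    # later, when verifying the transfer: compare expectations to the remote listing
    expected = {f['name'] for f in json.loads(r.get('results/some-guid'))}
    actual = {'output.csv', 'log.txt', 'extra.dat'}  # e.g. parsed from terrain.list_dir()
    transfer_complete = expected.issubset(actual)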
