
Add a workflow stats viewer
TorecLuik committed Aug 12, 2024
1 parent 1f07e44 commit 2fcf853
Showing 2 changed files with 327 additions and 20 deletions.
6 changes: 4 additions & 2 deletions biomero/slurm_client.py
@@ -31,7 +31,7 @@
import io
import os
from biomero.eventsourcing import WorkflowTracker
from biomero.views import JobAccounting, JobProgress
from biomero.views import JobAccounting, JobProgress, WorkflowAnalytics
from eventsourcing.system import System, SingleThreadedRunner

logger = logging.getLogger(__name__)
@@ -402,7 +402,8 @@ def __init__(self,
self.track_workflows = track_workflows
system = System(pipes=[
[WorkflowTracker, JobAccounting],
[WorkflowTracker, JobProgress]
[WorkflowTracker, JobProgress],
[WorkflowTracker, WorkflowAnalytics]
])
if self.track_workflows: # use configured persistence from env
runner = SingleThreadedRunner(system)
@@ -413,6 +414,7 @@
self.workflowTracker = runner.get(WorkflowTracker)
self.jobAccounting = runner.get(JobAccounting)
self.jobProgress = runner.get(JobProgress)
self.workflowAnalytics = runner.get(WorkflowAnalytics)

def init_workflows(self, force_update: bool = False):
"""
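
The new pipe feeds every WorkflowTracker event into a WorkflowAnalytics view, which the runner then exposes next to the existing accounting and progress views. Below is a minimal, hypothetical sketch of querying the new handle; it assumes the usual SlurmClient.from_config() entry point and a configuration with workflow tracking enabled, and the user filter value is made up:

from biomero.slurm_client import SlurmClient

# Hypothetical usage sketch, not part of this commit.
with SlurmClient.from_config() as client:
    counts = client.workflowAnalytics.get_task_counts(user=5)
    for (name, version), n in counts.items():
        print(f"{name} {version}: {n} runs")
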
341 changes: 323 additions & 18 deletions biomero/views.py
@@ -5,20 +5,55 @@
from uuid import NAMESPACE_URL, UUID, uuid5
from typing import Any, Dict, List
import logging
from sqlalchemy import create_engine, text, Column, Integer, String, URL
from sqlalchemy import create_engine, text, Column, Integer, String, URL, DateTime, Float
from sqlalchemy.orm import sessionmaker, declarative_base
from sqlalchemy.exc import IntegrityError
from sqlalchemy.sql import func
from sqlalchemy.dialects.postgresql import UUID as PGUUID
from biomero.eventsourcing import WorkflowRun, Task


logger = logging.getLogger(__name__)

# --------------------- VIEWS ---------------------------- #
# --------------------- VIEWS DB tables/classes ---------------------------- #

# Base class for declarative class definitions
Base = declarative_base()


class JobView(Base):
__tablename__ = 'biomero_job_view'

slurm_job_id = Column(Integer, primary_key=True)
user = Column(Integer, nullable=False)
group = Column(Integer, nullable=False)


class JobProgressView(Base):
__tablename__ = 'biomero_job_progress_view'

slurm_job_id = Column(Integer, primary_key=True)
status = Column(String, nullable=False)
progress = Column(String, nullable=True)


class TaskExecution(Base):
__tablename__ = 'biomero_task_execution'

task_id = Column(PGUUID(as_uuid=True), primary_key=True)
task_name = Column(String, nullable=False)
task_version = Column(String)
user_id = Column(Integer, nullable=True)
group_id = Column(Integer, nullable=True)
status = Column(String, nullable=False)
start_time = Column(DateTime, nullable=False)
end_time = Column(DateTime, nullable=True)
error_type = Column(String, nullable=True)
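
# NOTE (hypothetical sketch, not part of this commit): rows in this table
# can also be queried directly through the session factory that
# BaseApplication (below) creates, for example:
#
#     with app.SessionLocal() as session:
#         failed = (session.query(TaskExecution)
#                   .filter(TaskExecution.error_type.isnot(None))
#                   .all())
#
# where "app" is any BaseApplication instance.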


# ------------------- View Listener Applications ------------------ #


class BaseApplication:
def __init__(self):
# Read database configuration from environment variables
@@ -50,22 +85,6 @@ def __init__(self):
Base.metadata.create_all(self.engine)


class JobView(Base):
__tablename__ = 'biomero_job_view'

slurm_job_id = Column(Integer, primary_key=True)
user = Column(Integer, nullable=False)
group = Column(Integer, nullable=False)


class JobProgressView(Base):
__tablename__ = 'biomero_job_progress_view'

slurm_job_id = Column(Integer, primary_key=True)
status = Column(String, nullable=False)
progress = Column(String, nullable=True)


class JobAccounting(ProcessApplication, BaseApplication):
def __init__(self, *args, **kwargs):
ProcessApplication.__init__(self, *args, **kwargs)
@@ -263,3 +282,289 @@ def update_view_table(self, job_id):
except IntegrityError:
session.rollback()
logger.error(f"Failed to insert/update job progress in view table: job_id={job_id}")


class WorkflowAnalytics(ProcessApplication, BaseApplication):
def __init__(self, *args, **kwargs):
ProcessApplication.__init__(self, *args, **kwargs)
BaseApplication.__init__(self)

# State tracking for workflows and tasks
self.workflows = {} # {wf_id: {"user": user, "group": group}}
# {task_id: {"wf_id": ..., "task_name": ..., "task_version": ...,
#            "start_time": ..., "status": ..., "end_time": ...,
#            "error_type": ...}}
self.tasks = {}

@singledispatchmethod
def policy(self, domain_event, process_event):
"""Default policy"""
pass

@policy.register(WorkflowRun.WorkflowInitiated)
def _(self, domain_event, process_event):
"""Handle WorkflowInitiated event"""
user = domain_event.user
group = domain_event.group
wf_id = domain_event.originator_id

# Track workflow
self.workflows[wf_id] = {"user": user, "group": group}
logger.debug(f"Workflow initiated: wf_id={wf_id}, user={user}, group={group}")

@policy.register(WorkflowRun.TaskAdded)
def _(self, domain_event, process_event):
"""Handle TaskAdded event"""
task_id = domain_event.task_id
wf_id = domain_event.originator_id

# Add workflow ID to the existing task information
if task_id in self.tasks:
self.tasks[task_id]["wf_id"] = wf_id
else:
# In case TaskAdded arrives before TaskCreated (unlikely but possible)
self.tasks[task_id] = {"wf_id": wf_id}

logger.debug(f"Task added: task_id={task_id}, wf_id={wf_id}")

@policy.register(Task.TaskCreated)
def _(self, domain_event, process_event):
"""Handle TaskCreated event"""
task_id = domain_event.originator_id
task_name = domain_event.task_name
task_version = domain_event.task_version
timestamp_created = domain_event.timestamp

# Track task creation details
if task_id in self.tasks:
self.tasks[task_id].update({
"task_name": task_name,
"task_version": task_version,
"start_time": timestamp_created
})
else:
# Initialize task tracking if TaskAdded hasn't been processed yet
self.tasks[task_id] = {
"task_name": task_name,
"task_version": task_version,
"start_time": timestamp_created
}

logger.debug(f"Task created: task_id={task_id}, task_name={task_name}, timestamp={timestamp_created}")
self.update_view_table(task_id)

@policy.register(Task.StatusUpdated)
def _(self, domain_event, process_event):
"""Handle StatusUpdated event"""
task_id = domain_event.originator_id
status = domain_event.status

# Update task with status
if task_id in self.tasks:
self.tasks[task_id]["status"] = status
logger.debug(f"Task status updated: task_id={task_id}, status={status}")
self.update_view_table(task_id)

@policy.register(Task.TaskCompleted)
def _(self, domain_event, process_event):
"""Handle TaskCompleted event"""
task_id = domain_event.originator_id
timestamp_completed = domain_event.timestamp

# Update task with end time
if task_id in self.tasks:
self.tasks[task_id]["end_time"] = timestamp_completed
logger.debug(f"Task completed: task_id={task_id}, end_time={timestamp_completed}")
self.update_view_table(task_id)

@policy.register(Task.TaskFailed)
def _(self, domain_event, process_event):
"""Handle TaskFailed event"""
task_id = domain_event.originator_id
timestamp_failed = domain_event.timestamp
error_message = domain_event.error_message

# Update task with end time and error message
if task_id in self.tasks:
self.tasks[task_id]["end_time"] = timestamp_failed
self.tasks[task_id]["error_type"] = error_message
logger.debug(f"Task failed: task_id={task_id}, end_time={timestamp_failed}, error={error_message}")
self.update_view_table(task_id)

def update_view_table(self, task_id):
"""Update the view table with new task execution information."""
task_info = self.tasks.get(task_id)
if not task_info:
return  # Nothing known yet about this task; skip

wf_id = task_info.get("wf_id")
user_id = None
group_id = None

# Retrieve user and group from workflow
if wf_id and wf_id in self.workflows:
user_id = self.workflows[wf_id]["user"]
group_id = self.workflows[wf_id]["group"]

with self.SessionLocal() as session:
try:
existing_task = session.query(TaskExecution).filter_by(task_id=task_id).first()
if existing_task:
# Update existing task execution record
existing_task.task_name = task_info.get("task_name", existing_task.task_name)
existing_task.task_version = task_info.get("task_version", existing_task.task_version)
existing_task.user_id = user_id
existing_task.group_id = group_id
existing_task.status = task_info.get("status", existing_task.status)
existing_task.start_time = task_info.get("start_time", existing_task.start_time)
existing_task.end_time = task_info.get("end_time", existing_task.end_time)
existing_task.error_type = task_info.get("error_type", existing_task.error_type)
else:
# Create a new task execution record
new_task_execution = TaskExecution(
task_id=task_id,
task_name=task_info.get("task_name"),
task_version=task_info.get("task_version"),
user_id=user_id,
group_id=group_id,
status=task_info.get("status"),
start_time=task_info.get("start_time"),
end_time=task_info.get("end_time"),
error_type=task_info.get("error_type")
)
session.add(new_task_execution)

session.commit()
logger.debug(f"Updated/Inserted task execution into view table: task_id={task_id}, task_name={task_info.get('task_name')}")
except IntegrityError:
session.rollback()
logger.error(f"Failed to insert/update task execution into view table: task_id={task_id}")

def get_task_counts(self, user=None, group=None):
"""Retrieve task execution counts grouped by task name and version.
Parameters:
- user (int, optional): The user ID to filter by.
- group (int, optional): The group ID to filter by.
Returns:
- Dictionary of task names and versions to counts.
"""
with self.SessionLocal() as session:
query = session.query(
TaskExecution.task_name,
TaskExecution.task_version,
func.count(TaskExecution.task_name)
).group_by(TaskExecution.task_name, TaskExecution.task_version)

if user is not None:
query = query.filter_by(user_id=user)

if group is not None:
query = query.filter_by(group_id=group)

task_counts = query.all()
result = {
(task_name, task_version): count
for task_name, task_version, count in task_counts
}
logger.debug(f"Retrieved task counts: {result}")
return result
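
# Hypothetical usage (sketch, not part of this commit):
#   analytics.get_task_counts(user=5)
#   -> {("cellpose", "v1.2"): 12, ("stardist", "v0.8"): 4}
# Keys are (task_name, task_version) tuples; values are execution counts.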

def get_average_task_duration(self, user=None, group=None):
"""Retrieve the average task duration grouped by task name and version.
Parameters:
- user (int, optional): The user ID to filter by.
- group (int, optional): The group ID to filter by.
Returns:
- Dictionary of task names and versions to average duration (in seconds).
"""
with self.SessionLocal() as session:
query = session.query(
TaskExecution.task_name,
TaskExecution.task_version,
func.avg(
func.extract('epoch', TaskExecution.end_time) - func.extract('epoch', TaskExecution.start_time)
).label('avg_duration')
).filter(TaskExecution.end_time.isnot(None))  # Only tasks with an end_time (completed or failed)
query = query.group_by(TaskExecution.task_name, TaskExecution.task_version)

if user is not None:
query = query.filter_by(user_id=user)

if group is not None:
query = query.filter_by(group_id=group)

task_durations = query.all()
result = {
(task_name, task_version): avg_duration
for task_name, task_version, avg_duration in task_durations
}
logger.debug(f"Retrieved average task durations: {result}")
return result
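
# Hypothetical usage (sketch, not part of this commit):
#   analytics.get_average_task_duration()
#   -> {("cellpose", "v1.2"): 342.5}   # mean duration in seconds
# func.extract('epoch', ...) relies on EXTRACT(EPOCH FROM ...), which is
# PostgreSQL behaviour; other backends may need a different expression.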

def get_task_failures(self, user=None, group=None):
"""Retrieve tasks that failed, grouped by task name and version.
Parameters:
- user (int, optional): The user ID to filter by.
- group (int, optional): The group ID to filter by.
Returns:
- Dictionary of task names and versions to lists of failure reasons.
"""
with self.SessionLocal() as session:
query = session.query(
TaskExecution.task_name,
TaskExecution.task_version,
TaskExecution.error_type
).filter(TaskExecution.error_type.isnot(None)) # Only include failed tasks
query = query.group_by(TaskExecution.task_name, TaskExecution.task_version, TaskExecution.error_type)

if user is not None:
query = query.filter_by(user_id=user)

if group is not None:
query = query.filter_by(group_id=group)

task_failures = query.all()
result = {}
for task_name, task_version, error_type in task_failures:
key = (task_name, task_version)
if key not in result:
result[key] = []
result[key].append(error_type)

logger.debug(f"Retrieved task failures: {result}")
return result
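
# Hypothetical usage (sketch, not part of this commit):
#   analytics.get_task_failures()
#   -> {("cellpose", "v1.2"): ["Out of memory", "Timeout"]}
# Because rows are grouped on (task_name, task_version, error_type),
# repeated identical failure reasons collapse into one list entry.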

def get_task_usage_over_time(self, task_name, user=None, group=None):
"""Retrieve task usage over time for a specific task.
Parameters:
- task_name (str): The name of the task to filter by.
- user (int, optional): The user ID to filter by.
- group (int, optional): The group ID to filter by.
Returns:
- Dictionary mapping date to the count of task executions on that date.
"""
with self.SessionLocal() as session:
query = session.query(
func.date(TaskExecution.start_time),
func.count(TaskExecution.task_name)
).filter(TaskExecution.task_name == task_name)
query = query.group_by(func.date(TaskExecution.start_time))

if user is not None:
query = query.filter_by(user_id=user)

if group is not None:
query = query.filter_by(group_id=group)

usage_over_time = query.all()
result = {
date: count
for date, count in usage_over_time
}
logger.debug(f"Retrieved task usage over time for {task_name}: {result}")
return result
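
Taken together, these getters form a small reporting API over the biomero_task_execution table. The sketch below is hypothetical and not part of this commit: "analytics" stands for the WorkflowAnalytics handle that slurm_client.py exposes as self.workflowAnalytics, and "cellpose" stands for any task name present in the table.

usage = analytics.get_task_usage_over_time("cellpose", group=3)
for day in sorted(usage):
    print(f"{day}: {usage[day]} executions started")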
