
Commit

ref: bypass queueing jobs with invalid payload (#121)
* removed JobStatus.SKIPPED
* now we throw 422 for invalid payloads (same as pydantic validation)

Co-authored-by: Avram Tudor <tudor.avram@8x8.com>
quitrk and Avram Tudor authored Nov 14, 2024
1 parent bfac47e commit 53eea13
Showing 3 changed files with 48 additions and 50 deletions.
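
Net effect for API clients: a summary or action-items request whose payload is shorter than summary_minimum_payload_length is now rejected immediately with HTTP 422 (matching pydantic's own validation errors) instead of being queued and later marked SKIPPED. A minimal sketch of the new behaviour, assuming the FastAPI app object lives at skynet.main, the versioned route is exposed as /v1/summary, and DocumentPayload only requires a text field (none of these details appear in this diff):

from fastapi.testclient import TestClient

from skynet.main import app  # assumed entry point; not shown in this diff

client = TestClient(app)

# A too-short payload no longer reaches the job queue: the validate_payload
# dependency added in router.py rejects it before a job is created.
response = client.post("/v1/summary", json={"text": "hi"})
assert response.status_code == 422
assert response.json() == {"detail": "Payload is too short"}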
84 changes: 38 additions & 46 deletions skynet/modules/ttt/summaries/jobs.py
@@ -5,7 +5,7 @@

 from skynet.auth.openai import CredentialsType, get_credentials
 
-from skynet.env import enable_batching, job_timeout, modules, redis_exp_seconds, summary_minimum_payload_length
+from skynet.env import enable_batching, job_timeout, modules, redis_exp_seconds
 from skynet.logs import get_logger
 from skynet.modules.monitoring import (
     OPENAI_API_RESTART_COUNTER,
@@ -149,10 +149,7 @@ async def run_job(job: Job) -> None:

 async def update_done_job(job: Job, result: str, processor: Processors, has_failed: bool = False) -> None:
     should_expire = not has_failed or processor != Processors.LOCAL
-    status = job.status
-
-    if status != JobStatus.SKIPPED:
-        status = JobStatus.ERROR if has_failed else JobStatus.SUCCESS
+    status = JobStatus.ERROR if has_failed else JobStatus.SUCCESS
 
     updated_job = await update_job(
         expires=redis_exp_seconds if should_expire else None,
@@ -168,67 +165,62 @@ async def update_done_job(job: Job, result: str, processor: Processors, has_failed: bool = False) -> None:

     await db.lrem(RUNNING_JOBS_KEY, 0, job.id)
 
-    if updated_job.status != JobStatus.SKIPPED:
-        SUMMARY_DURATION_METRIC.labels(updated_job.metadata.app_id).observe(updated_job.computed_duration)
-        SUMMARY_FULL_DURATION_METRIC.observe(updated_job.computed_full_duration)
-        SUMMARY_INPUT_LENGTH_METRIC.observe(len(updated_job.payload.text))
+    SUMMARY_DURATION_METRIC.labels(updated_job.metadata.app_id).observe(updated_job.computed_duration)
+    SUMMARY_FULL_DURATION_METRIC.observe(updated_job.computed_full_duration)
+    SUMMARY_INPUT_LENGTH_METRIC.observe(len(updated_job.payload.text))
 
-        log.info(
-            f"Job {updated_job.id} duration: {updated_job.computed_duration}s full duration: {updated_job.computed_full_duration}s"
-        )
+    log.info(
+        f"Job {updated_job.id} duration: {updated_job.computed_duration}s full duration: {updated_job.computed_full_duration}s"
+    )
 
 
 async def _run_job(job: Job) -> None:
     has_failed = False
     result = None
     worker_id = await db.db.client_id()
     start = time.time()
-    status = JobStatus.SKIPPED if len(job.payload.text) < summary_minimum_payload_length else JobStatus.RUNNING
     customer_id = job.metadata.customer_id
     processor = get_job_processor(customer_id)  # may have changed since job was created
 
     SUMMARY_TIME_IN_QUEUE_METRIC.observe(start - job.created)
 
     log.info(f"Running job {job.id}. Queue time: {round(start - job.created, 3)} seconds")
 
-    job = await update_job(job_id=job.id, start=start, status=status, worker_id=worker_id, processor=processor)
+    job = await update_job(
+        job_id=job.id, start=start, status=JobStatus.RUNNING, worker_id=worker_id, processor=processor
+    )
 
     # add to running jobs list if not already there (which may occur on multiple worker disconnects while running the same job)
     if job.id not in await db.lrange(RUNNING_JOBS_KEY, 0, -1):
         await db.rpush(RUNNING_JOBS_KEY, job.id)
 
-    if status == JobStatus.SKIPPED:
-        log.info(f"Summarisation for {job.id} did not run because payload is too short: \"{job.payload.text}\"")
-
-        result = job.payload.text
-    else:
-        try:
-            options = get_credentials(customer_id)
-            secret = options.get('secret')
-
-            if processor == Processors.OPENAI:
-                log.info(f"Forwarding inference to OpenAI for customer {customer_id}")
-
-                # needed for backwards compatibility
-                model = options.get('model') or options.get('metadata').get('model')
-                result = await process_open_ai(job.payload, job.type, secret, model)
-            elif processor == Processors.AZURE:
-                log.info(f"Forwarding inference to Azure openai for customer {customer_id}")
-
-                metadata = options.get('metadata')
-                result = await process_azure(
-                    job.payload, job.type, secret, metadata.get('endpoint'), metadata.get('deploymentName')
-                )
-            else:
-                if customer_id:
-                    log.info(f'Customer {customer_id} has no API key configured, falling back to local processing')
-
-                result = await process(job.payload, job.type)
-        except Exception as e:
-            log.warning(f"Job {job.id} failed: {e}")
-
-            has_failed = True
-            result = str(e)
+    try:
+        options = get_credentials(customer_id)
+        secret = options.get('secret')
+
+        if processor == Processors.OPENAI:
+            log.info(f"Forwarding inference to OpenAI for customer {customer_id}")
+
+            # needed for backwards compatibility
+            model = options.get('model') or options.get('metadata').get('model')
+            result = await process_open_ai(job.payload, job.type, secret, model)
+        elif processor == Processors.AZURE:
+            log.info(f"Forwarding inference to Azure openai for customer {customer_id}")
+
+            metadata = options.get('metadata')
+            result = await process_azure(
+                job.payload, job.type, secret, metadata.get('endpoint'), metadata.get('deploymentName')
+            )
+        else:
+            if customer_id:
+                log.info(f'Customer {customer_id} has no API key configured, falling back to local processing')
+
+            result = await process(job.payload, job.type)
+    except Exception as e:
+        log.warning(f"Job {job.id} failed: {e}")
+
+        has_failed = True
+        result = str(e)
 
     await update_done_job(job, result, processor, has_failed)
 
1 change: 0 additions & 1 deletion skynet/modules/ttt/summaries/v1/models.py
@@ -56,7 +56,6 @@ class JobStatus(Enum):
     ERROR = 'error'
     PENDING = 'pending'
     RUNNING = 'running'
-    SKIPPED = 'skipped'
     SUCCESS = 'success'


13 changes: 10 additions & 3 deletions skynet/modules/ttt/summaries/v1/router.py
@@ -1,6 +1,8 @@
-from fastapi import HTTPException, Request
+from fastapi import Depends, HTTPException, Request
 from fastapi_versionizer.versionizer import api_version
 
+from skynet.env import summary_minimum_payload_length
+
 from skynet.utils import get_router
 
 from ..jobs import create_job, get_job as get_job
@@ -26,8 +28,13 @@ def get_metadata(request: Request) -> DocumentMetadata:
     return DocumentMetadata(app_id=get_app_id(request), customer_id=get_customer_id(request))
 
 
+def validate_payload(payload: DocumentPayload) -> None:
+    if len(payload.text) < summary_minimum_payload_length:
+        raise HTTPException(status_code=422, detail="Payload is too short")
+
+
 @api_version(1)
-@router.post("/action-items")
+@router.post("/action-items", dependencies=[Depends(validate_payload)])
 async def get_action_items(payload: DocumentPayload, request: Request) -> JobId:
     """
     Starts a job to extract action items from the given payload.
@@ -37,7 +44,7 @@ async def get_action_items(payload: DocumentPayload, request: Request) -> JobId:


 @api_version(1)
-@router.post("/summary")
+@router.post("/summary", dependencies=[Depends(validate_payload)])
 async def get_summary(payload: DocumentPayload, request: Request) -> JobId:
     """
     Starts a job to summarize the given payload.
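
Note on the pattern: because validate_payload declares a DocumentPayload parameter, FastAPI parses and validates the request body once and hands the same model to both the dependency and the route handler, and raising HTTPException inside the dependency short-circuits the request before create_job ever runs. A self-contained sketch of the same mechanism, with illustrative names rather than code from this repository:

from fastapi import Depends, FastAPI, HTTPException
from pydantic import BaseModel

app = FastAPI()


class Payload(BaseModel):
    text: str


def validate_payload(payload: Payload) -> None:
    # Dependencies run after body parsing but before the handler; raising
    # here returns 422 without invoking the route function at all.
    if len(payload.text) < 10:  # illustrative threshold
        raise HTTPException(status_code=422, detail="Payload is too short")


@app.post("/echo", dependencies=[Depends(validate_payload)])
async def echo(payload: Payload) -> dict:
    return {"text": payload.text}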
