[batch] E2E works with driver and request proxy (#272)
* e2e driver and test

* comment functions

* check job status in test

* format update

* update copyright

* add examples with instructions and interfaces

* move batch tutorial

---------

Co-authored-by: xin.chen <xin.chen@bytedance.com>
xinchen384 and xin.chen authored Oct 10, 2024
1 parent 8b8d120 commit 2f32a01
Showing 9 changed files with 559 additions and 17 deletions.
44 changes: 44 additions & 0 deletions docs/tutorial/batch/README.md
@@ -0,0 +1,44 @@
# Batch API Tutorial

## Prepare dataset
Before submitting a batch job, you need to prepare the input data as a file.
Each line in this file represents one request, formatted as JSON.
The JSON should cover the attributes specified at https://platform.openai.com/docs/guides/batch/getting-started, such as the endpoint and the completion window, as in the example below.
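
For example, one line of the input file might look like the following (a minimal sketch following the OpenAI batch format; the custom_id, model name, and message body are illustrative):

```
{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "sample-model", "messages": [{"role": "user", "content": "Hello!"}]}}
```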

## Submit job input data
Before submitting input data, we first need to construct a driver. Assuming the input data file is named "one_job_input.json", we can
submit the data as follows. This call returns a job ID, which all of the remaining operations rely on.

```
_driver = BatchDriver()
job_id = _driver.upload_batch_data("./one_job_input.json")
```

## Create and submit batch job
This submits the batch job for inference.
One parameter is the endpoint name, which should be consistent with the endpoint given in the input file.
Another parameter is the job's time window, after which the job is considered expired. For now, only a plain minute or hour string such as "20m" or "2h" is supported; mixed strings like "4h 40m" are not supported yet.

```
_driver.create_job(job_id, "sample_endpoint", "20m")
```

## Check job status
After submitting the job, we can check its status using the following operation, which requires the job ID.
The returned status is one of JobStatus.PENDING, JobStatus.IN_PROGRESS, or JobStatus.COMPLETED.

```
status = _driver.get_job_status(job_id)
```
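
Note that the driver executes requests on the same asyncio event loop (see jobs_running_loop in driver.py), so a caller polling for completion from that loop should yield between checks. A minimal polling sketch, assuming JobStatus is importable from aibrix.batch.job_manager:

```
while _driver.get_job_status(job_id) != JobStatus.COMPLETED:
    await asyncio.sleep(1)
```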

## Retrieve job's results
Lastly, we can retrieve the job's results as follows.
The actual result depends on the job's execution status.
If the job has completed, the results of all requests are returned.
If the job has not completed yet, the call may return partial results for the input requests.

```
results = _driver.retrieve_job_result(job_id)
```
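
Putting it all together, the sketch below runs the whole flow end to end. It assumes the module paths introduced in this commit (aibrix.batch.driver and aibrix.batch.job_manager, with JobStatus exported from the latter); note that BatchDriver must be constructed inside a running event loop, because its constructor schedules the jobs_running_loop task.

```
import asyncio

from aibrix.batch.driver import BatchDriver
from aibrix.batch.job_manager import JobStatus


async def main():
    # Construct the driver inside the event loop so that
    # asyncio.create_task() in __init__ can schedule the running loop.
    _driver = BatchDriver()

    job_id = _driver.upload_batch_data("./one_job_input.json")
    _driver.create_job(job_id, "sample_endpoint", "20m")

    # Yield control between polls so the driver's running loop
    # can actually execute the queued requests.
    while _driver.get_job_status(job_id) != JobStatus.COMPLETED:
        await asyncio.sleep(1)

    results = _driver.retrieve_job_result(job_id)
    print(results)

    _driver.clear_job(job_id)


asyncio.run(main())
```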

21 changes: 21 additions & 0 deletions python/aibrix/aibrix/batch/constant.py
@@ -0,0 +1,21 @@
# Copyright 2024 The Aibrix Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# The time interval at which the sliding window checks for job expiration.
EXPIRE_INTERVAL = 1
# The job pool size in the job scheduler.
# It should be proportional to the resource size in the backend.
DEFAULT_JOB_POOL_SIZE = 1
66 changes: 66 additions & 0 deletions python/aibrix/aibrix/batch/driver.py
@@ -0,0 +1,66 @@
# Copyright 2024 The Aibrix Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
from aibrix.batch.constant import DEFAULT_JOB_POOL_SIZE

import aibrix.batch.storage as _storage
from aibrix.batch.scheduler import JobScheduler
from aibrix.batch.job_manager import JobManager
from aibrix.batch.request_proxy import RequestProxy


class BatchDriver:
def __init__(self):
"""
This is the main entrance that binds all components together to serve job requests.
"""
self._storage = _storage
self._job_manager = JobManager()
self._scheduler = JobScheduler(self._job_manager, DEFAULT_JOB_POOL_SIZE)
self._proxy = RequestProxy(self._storage, self._job_manager)
asyncio.create_task(self.jobs_running_loop())

def upload_batch_data(self, input_file_name):
job_id = self._storage.submit_job_input(input_file_name)
return job_id

def create_job(self, job_id, endpoint, window_due_time):
self._job_manager.create_job(job_id, endpoint, window_due_time)

due_time = self._job_manager.get_job_window_due(job_id)
self._scheduler.append_job(job_id, due_time)

def get_job_status(self, job_id):
return self._job_manager.get_job_status(job_id)

def retrieve_job_result(self, job_id):
num_requests = _storage.get_job_num_request(job_id)
req_results = _storage.get_job_results(job_id, 0, num_requests)
return req_results

async def jobs_running_loop(self):
"""
This loop goes through all active jobs in the scheduler.
For now, the execution unit is a single request. Later, if necessary,
we can support a batch of requests per execution.
"""
while True:
one_job = self._scheduler.round_robin_get_job()
if one_job:
await self._proxy.execute_queries(one_job)
await asyncio.sleep(0)

def clear_job(self, job_id):
self._storage.delete_job(job_id)
89 changes: 79 additions & 10 deletions python/aibrix/aibrix/batch/job_manager.py
@@ -13,6 +13,7 @@
# limitations under the License.

import time
import asyncio
from enum import Enum
import aibrix.batch.storage as _storage

@@ -45,13 +46,56 @@ def __init__(self, job_id, model_endpoint, completion_window):
self._completion_window_str = completion_window
self._completion_window = 0

self._async_lock = asyncio.Lock()
self._current_request_id = 0
self._request_progress_bits = []
self._succeed_num_requests = 0
self._job_status = JobStatus.CREATED

# extra metadata
self._meta_data = {}

def set_request_executed(self, req_id):
# This marks the request as successfully executed.
self._request_progress_bits[req_id] = True

def get_request_bit(self, req_id):
return self._request_progress_bits[req_id]

def set_job_status(self, status):
self._job_status = status

def get_job_status(self):
return self._job_status

def complete_one_request(self, req_id):
"""
This is called after an inference call. Once all requests
are done, the job status is updated to COMPLETED.
"""
if not self._request_progress_bits[req_id]:
self.set_request_executed(req_id)
self._succeed_num_requests += 1
if self._succeed_num_requests == self._num_requests:
self._job_status = JobStatus.COMPLETED

def next_request_id(self):
"""
Returns the next request for inference. Because some requests
may have failed, this returns a request that is not yet
marked as executed.
"""
if self._succeed_num_requests == self._num_requests:
return -1

req_id = self._current_request_id
while self._request_progress_bits[req_id]:
req_id += 1
if req_id == self._num_requests:
req_id = 0

return req_id

def validate_job(self):
"""
This handles all validations before successfully creating a job.
@@ -67,7 +111,7 @@ def validate_job(self):
return False

# 2. check window time is a valid time string
completion_time_str = self._completion_window
completion_time_str = self._completion_window_str
try:
# For now, this only supports either minute or hour.
# A mix of both h and m, like "4h 40m", needs to be extended later.
@@ -204,9 +248,37 @@ def start_execute_job(self, job_id):

meta_data = self._pending_jobs[job_id]
self._in_progress_jobs[job_id] = meta_data
meta_data.set_job_status(JobStatus.IN_PROGRESS)
del self._pending_jobs[job_id]
return True

def get_job_next_request(self, job_id):
request_id = -1
if job_id not in self._in_progress_jobs:
print(f"Job {job_id} has not been scheduled yet.")
return request_id
meta_data = self._in_progress_jobs[job_id]

return meta_data.next_request_id()

def get_job_window_due(self, job_id):
if job_id not in self._pending_jobs:
print(f"Job {job_id} is not in pending state, its due may change.")
return -1

meta_data = self._pending_jobs[job_id]
return meta_data._completion_window

def get_job_endpoint(self, job_id):
if job_id in self._pending_jobs:
meta_data = self._pending_jobs[job_id]
elif job_id in self._in_progress_jobs:
meta_data = self._in_progress_jobs[job_id]
else:
print(f"Job {job_id} is discarded.")
return -1
return meta_data._model_endpoint

def mark_job_progress(self, job_id, executed_requests):
"""
This is used to sync a job's progress; it is called by the execution proxy.
@@ -217,25 +289,22 @@ def mark_job_progress(self, job_id, executed_requests):
return False

meta_data = self._in_progress_jobs[job_id]
request_len = len(meta_data._request_progress_bits)
succeed_num = 0
request_len = meta_data._num_requests
invalid_flag = False

for req_id in executed_requests:
if req_id < 0 or req_id >= request_len:
print(f"makr job {job_id} progress, request index out of boundary!")
invalid_flag = True
continue
if not meta_data._request_progress_bits[req_id]:
meta_data._request_progress_bits[req_id] = True
succeed_num += 1
meta_data.complete_one_request(req_id)

meta_data._succeed_num_requests += succeed_num
if meta_data._succeed_num_requests == request_len:
status = meta_data.get_job_status()
if status == JobStatus.COMPLETED:
# Mark the job to be completed if all requests are finished.
del self._in_progress_jobs[job_id]
meta_data._job_status = JobStatus.COMPLETED
self._done_jobs[job_id] = meta_data
print(f"Job {job_id} is completed.")
else:
self._in_progress_jobs[job_id] = meta_data

@@ -252,7 +321,7 @@ def expire_job(self, job_id):

if job_id in self._pending_jobs:
meta_data = self._pending_jobs[job_id]
meta_data._job_status = JobStatus.EXPIRED
meta_data.set_job_status(JobStatus.EXPIRED)
self._done_jobs[job_id] = meta_data
del self._pending_jobs[job_id]
elif job_id in self._in_progress_jobs:
79 changes: 79 additions & 0 deletions python/aibrix/aibrix/batch/request_proxy.py
@@ -0,0 +1,79 @@
# Copyright 2024 The Aibrix Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time


class RequestProxy:
def __init__(self, storage, manager):
""" """
self._storage = storage
self._job_manager = manager
self._inference_client = InferenceEngineClient()

async def execute_queries(self, job_id):
"""
This is the entrance to the inference engine.
It fetches the request input from storage and submits the request
to the inference engine. Lastly, the result is stored back to storage.
"""
request_id = self._job_manager.get_job_next_request(job_id)
if request_id == -1:
print(f"Job {job_id} has something wrong with metadata in job manager.")
return

endpoint = self._job_manager.get_job_endpoint(job_id)
request_input = self.fetch_request_input(job_id, request_id)

print(f"executing job {job_id} request {request_id}")
request_output = self._inference_client.inference_request(
endpoint, request_input
)
self.store_output(job_id, request_id, request_output)

self.sync_job_status(job_id, request_id)

def fetch_request_input(self, job_id, request_id):
"""
Read the request input from storage. For now this only reads one request;
later we can read a batch of requests per call.
"""
num_request = 1
requests = self._storage.get_job_input_requests(job_id, request_id, num_request)
return requests[0]

def store_output(self, job_id, request_id, result):
"""
Write the request result back to storage.
"""
self._storage.put_job_results(job_id, request_id, [result])

def sync_job_status(self, job_id, request_id):
"""
Update job's status back to job manager.
"""
self._job_manager.mark_job_progress(job_id, [request_id])


class InferenceEngineClient:
def __init__(self):
"""
Initialize the client to the inference engine, including the account
and its authentication.
"""
pass

def inference_request(self, endpoint, prompt_list):
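# Placeholder: simulate inference latency and echo the prompt list back.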
time.sleep(1)
return prompt_list
