[batch] E2E works with driver and request proxy (#272)

* e2e driver and test
* comment functions
* check job status in test
* format update
* update copyright
* add examples with instructions and interfaces
* move batch tutorial

--------

Co-authored-by: xin.chen <xin.chen@bytedance.com>
1 parent 8b8d120 · commit 2f32a01
Showing 9 changed files with 559 additions and 17 deletions.

@@ -0,0 +1,44 @@
# Batch API Tutorial

## Prepare dataset
Before submitting a batch job, you need to prepare the input data as a file.
Each line in this file represents one request, formatted as JSON.
The JSON should include the attributes described in the OpenAI Batch getting-started guide (https://platform.openai.com/docs/guides/batch/getting-started), such as the endpoint and the completion window.

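As an illustration, the sketch below writes one request line in the OpenAI batch input format; the custom ID, model name, and endpoint URL are placeholders, not values required by this repository.

```
import json

# A hypothetical request in the OpenAI batch input format.
# custom_id, model, and url are placeholders; adjust them to your deployment.
request = {
    "custom_id": "request-1",
    "method": "POST",
    "url": "/v1/chat/completions",
    "body": {
        "model": "sample-model",
        "messages": [{"role": "user", "content": "Hello!"}],
    },
}

# Each line of the input file is one JSON-encoded request.
with open("one_job_input.json", "w") as f:
    f.write(json.dumps(request) + "\n")
```
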
## Submit job input data
Before submitting the input data, we need to construct a driver first. Assuming that the input data file is named "one_job_input.json", we can submit the data as follows. This call returns a job ID, and the remaining operations rely on this job ID.

```
_driver = BatchDriver()
job_id = _driver.upload_batch_data("./one_job_input.json")
```

## Create and submit batch job
This submits the batch job for inference.
One parameter is the endpoint name, which should be consistent with the endpoint given in the input file.
The other parameter is the job's time window, after which the job is considered expired.

```
_driver.create_job(job_id, "sample_endpoint", "20m")
```

## Check job status
After the job is submitted, we can check its status with the following operation, which requires the job ID.
The returned status is one of JobStatus.PENDING, JobStatus.IN_PROGRESS, or JobStatus.COMPLETED.

```
status = _driver.get_job_status(job_id)
```

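Right after submission the status may still be JobStatus.PENDING or JobStatus.IN_PROGRESS, so a caller typically polls until completion. Because the driver runs its processing loop as an asyncio task, a polling caller should yield to the event loop. A minimal sketch, assuming it runs in the same event loop that created the driver and that JobStatus is importable from the batch package (the interval is arbitrary):

```
import asyncio

# Poll the job status until completion, yielding to the event loop in between.
async def wait_for_completion(driver, job_id, interval_s=5):
    while driver.get_job_status(job_id) != JobStatus.COMPLETED:
        await asyncio.sleep(interval_s)
```
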
## Retrieve job's results
Lastly, we can retrieve the job's results as follows.
The actual result depends on the job's execution status.
If the job has completed, the returned results cover all requests.
If the job has not completed, the results may be partial, covering only some of the input requests.

```
results = _driver.retrieve_job_result(job_id)
```

@@ -0,0 +1,21 @@
# Copyright 2024 The Aibrix Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# The following are all constants.
# This is the time interval for checking the sliding window.
EXPIRE_INTERVAL = 1
# This is the job pool size in the job scheduler.
# It should be proportional to the resource size in the backend.
DEFAULT_JOB_POOL_SIZE = 1

@@ -0,0 +1,66 @@
# Copyright 2024 The Aibrix Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio

import aibrix.batch.storage as _storage
from aibrix.batch.constant import DEFAULT_JOB_POOL_SIZE
from aibrix.batch.job_manager import JobManager
from aibrix.batch.request_proxy import RequestProxy
from aibrix.batch.scheduler import JobScheduler


class BatchDriver:
    def __init__(self):
        """
        This is the main entrance that binds all components together
        to serve job requests.
        """
        self._storage = _storage
        self._job_manager = JobManager()
        self._scheduler = JobScheduler(self._job_manager, DEFAULT_JOB_POOL_SIZE)
        self._proxy = RequestProxy(self._storage, self._job_manager)
        # asyncio.create_task requires a running event loop, so the driver
        # must be constructed inside an asyncio context.
        asyncio.create_task(self.jobs_running_loop())

    def upload_batch_data(self, input_file_name):
        # Persist the input file to storage and return the assigned job ID.
        job_id = self._storage.submit_job_input(input_file_name)
        return job_id

    def create_job(self, job_id, endpoint, window_due_time):
        self._job_manager.create_job(job_id, endpoint, window_due_time)

        due_time = self._job_manager.get_job_window_due(job_id)
        self._scheduler.append_job(job_id, due_time)

    def get_job_status(self, job_id):
        return self._job_manager.get_job_status(job_id)

    def retrieve_job_result(self, job_id):
        num_requests = _storage.get_job_num_request(job_id)
        req_results = _storage.get_job_results(job_id, 0, num_requests)
        return req_results

    async def jobs_running_loop(self):
        """
        This loop goes through all active jobs in the scheduler.
        For now, the executing unit is one request. Later, if necessary,
        we can support a batch of requests per execution.
        """
        while True:
            one_job = self._scheduler.round_robin_get_job()
            if one_job:
                await self._proxy.execute_queries(one_job)
            await asyncio.sleep(0)

    def clear_job(self, job_id):
        self._storage.delete_job(job_id)

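Because `__init__` schedules `jobs_running_loop` with `asyncio.create_task`, a `BatchDriver` has to be created inside a running event loop. A minimal end-to-end usage sketch, assuming `BatchDriver` and `JobStatus` are importable from the batch package (exact module paths are not shown in this diff) and that "one_job_input.json" exists:

```
import asyncio

async def main():
    driver = BatchDriver()  # must be constructed inside a running event loop
    job_id = driver.upload_batch_data("./one_job_input.json")
    driver.create_job(job_id, "sample_endpoint", "20m")

    # Yield to the event loop so jobs_running_loop can make progress.
    while driver.get_job_status(job_id) != JobStatus.COMPLETED:
        await asyncio.sleep(1)

    print(driver.retrieve_job_result(job_id))

asyncio.run(main())
```
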
@@ -0,0 +1,79 @@
# Copyright 2024 The Aibrix Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time


class RequestProxy:
    def __init__(self, storage, manager):
        """
        Bind the storage backend and the job manager, and create a client
        to the inference engine.
        """
        self._storage = storage
        self._job_manager = manager
        self._inference_client = InferenceEngineClient()

    async def execute_queries(self, job_id):
        """
        This is the entrance to the inference engine.
        It fetches the request input from storage, submits the request
        to the inference engine, and finally stores the result back to storage.
        """
        request_id = self._job_manager.get_job_next_request(job_id)
        if request_id == -1:
            print(f"Job {job_id} has something wrong with metadata in job manager.")
            return

        endpoint = self._job_manager.get_job_endpoint(job_id)
        request_input = self.fetch_request_input(job_id, request_id)

        print(f"executing job {job_id} request {request_id}")
        request_output = self._inference_client.inference_request(
            endpoint, request_input
        )
        self.store_output(job_id, request_id, request_output)

        self.sync_job_status(job_id, request_id)

    def fetch_request_input(self, job_id, request_id):
        """
        Read request input from storage. For now it only reads one request.
        Later we can read a list of requests as a batch per call.
        """
        num_request = 1
        requests = self._storage.get_job_input_requests(job_id, request_id, num_request)
        return requests[0]

    def store_output(self, job_id, request_id, result):
        """
        Write the request result back to storage.
        """
        self._storage.put_job_results(job_id, request_id, [result])

    def sync_job_status(self, job_id, request_id):
        """
        Update the job's progress back in the job manager.
        """
        self._job_manager.mark_job_progress(job_id, [request_id])


class InferenceEngineClient:
    def __init__(self):
        """
        Initiate the client to the inference engine, including the account
        and its authentication.
        """
        pass

    def inference_request(self, endpoint, prompt_list):
        # Placeholder implementation: simulate engine latency and echo the
        # input back as the result.
        time.sleep(1)
        return prompt_list

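The InferenceEngineClient above is a stub that only echoes its input. A real client would forward each request to an OpenAI-compatible endpoint; the sketch below is a hypothetical illustration (the base URL, API key handling, and payload shape are assumptions, not part of this commit):

```
import requests

class HTTPInferenceEngineClient:
    """Hypothetical client that posts requests to an OpenAI-compatible server."""

    def __init__(self, base_url="http://localhost:8000", api_key=None):
        self._base_url = base_url
        self._headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}

    def inference_request(self, endpoint, request_input):
        # endpoint is assumed to be a path such as "/v1/chat/completions";
        # request_input is assumed to carry the body prepared in the input file.
        response = requests.post(
            self._base_url + endpoint,
            json=request_input.get("body", request_input),
            headers=self._headers,
        )
        response.raise_for_status()
        return response.json()
```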