Commit
[Test] build test handler (#1537)
Qing Lan authored Feb 21, 2024
1 parent 66b1bca commit 20c2d7e
Showing 9 changed files with 383 additions and 25 deletions.
5 changes: 4 additions & 1 deletion engines/python/setup/djl_python/inputs.py
@@ -15,6 +15,7 @@
import struct
import json
import re
from typing import List

from .np_util import from_nd_list
from .pair_list import PairList
@@ -88,7 +89,7 @@ def is_batch(self) -> bool:
def get_batch_size(self) -> int:
return int(self.properties.get("batch_size", "1"))

def get_batches(self) -> list:
def get_batches(self) -> List["Input"]:
if not self.is_batch():
return [self]

@@ -98,12 +99,14 @@ def get_batches(self) -> list:
batch.append(Input())

for key, value in self.properties.items():
# e.g batch_001_eula and eula is the key
if key.startswith("batch_") and key != "batch_size":
index = int(key[6:9])
key = key[10:]
batch[index].properties[key] = value

for i in range(self.content.size()):
# e.g batch_001_inputs and inputs is the key
key = self.content.key_at(i)
index = int(key[6:9])
key = key[10:]
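The inputs.py change above tightens the typing around the flattened-batch convention: per-request properties and payloads are stored on a single Input under zero-padded keys such as batch_000_data, and get_batches() splits them back into one Input per request. A minimal round-trip sketch, not part of this commit (the keys and payloads are illustrative):

from djl_python import PairList
from djl_python.inputs import Input

batched = Input()
batched.properties = {
    "batch_size": "2",          # read by get_batch_size()
    "batch_000_eula": "true",   # becomes properties["eula"] on request 0
    "batch_001_eula": "false",  # becomes properties["eula"] on request 1
}
batched.content = PairList()
batched.content.add("batch_000_data", b'{"inputs": "hello"}')
batched.content.add("batch_001_data", b'{"inputs": "world"}')

per_request = batched.get_batches()  # now typed as List["Input"]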
@@ -13,7 +13,7 @@
import json
import logging
from abc import ABC, abstractmethod
from typing import Union
from typing import Union, List

FINISH_REASON_MAPPER = ["length", "eos_token", "stop_sequence"]

@@ -215,8 +215,8 @@ def __init__(self, **kwargs):
:param kwargs passed while loading the model
"""

self.pending_requests = []
self.active_requests = []
self.pending_requests: List[Request] = []
self.active_requests: List[Request] = []
self.req_id_counter = 0
self.output_formatter = None
self.waiting_steps = kwargs.get("waiting_steps", None)
@@ -251,7 +251,8 @@ def inference(self, input_data, parameters):
"""
pass

def get_new_requests(self, input_data, parameters, batch_size):
def get_new_requests(self, input_data, parameters,
batch_size) -> List[Request]:
total_req_len = len(self.active_requests) + len(self.pending_requests)
if batch_size > total_req_len:
for i in range(total_req_len, batch_size):
173 changes: 158 additions & 15 deletions engines/python/setup/djl_python/test_model.py
@@ -13,13 +13,17 @@

import logging
import os
import struct
import sys
from types import ModuleType
from typing import List, Dict, Union

from djl_python import PairList
from .arg_parser import ArgParser
from .inputs import Input
from .outputs import Output
from .np_util import to_nd_list
from .service_loader import load_model_service
from .service_loader import load_model_service, ModelService


def create_request(input_files, parameters):
@@ -74,6 +78,31 @@ def create_request(input_files, parameters):
return request


def create_concurrent_batch_request(inputs: List[Dict],
properties: List[Dict] = None,
serving_properties={}) -> Input:
if properties is None:
properties = []
# Flatten operation properties
flatten_properties = serving_properties
for idx, data in enumerate(properties):
for key, value in data.items():
key = f"batch_{str(idx).zfill(3)}_{key}"
flatten_properties[key] = value
pair_list = PairList()
# Flatten operation data field
for idx, data in enumerate(inputs):
key = f"batch_{str(idx).zfill(3)}_data"
pair_list.add(key, Output._encode_json(data))

inputs_obj = Input()
inputs_obj.properties = flatten_properties
inputs_obj.function_name = flatten_properties.get("handler", "handle")
inputs_obj.content = pair_list
flatten_properties['batch_size'] = len(inputs)
return inputs_obj


def create_text_request(text: str, key: str = None) -> Input:
request = Input()
request.properties["device_id"] = "-1"
@@ -103,27 +132,148 @@ def create_npz_request(list, key: str = None) -> Input:
return request


def _extract_output(outputs: Output) -> Input:
def extract_output_as_input(outputs: Output) -> Input:
inputs = Input()
inputs.properties = outputs.properties
inputs.content = outputs.content
return inputs


def extract_output_as_bytes(outputs: Output, key=None):
return _extract_output(outputs).get_as_bytes(key)
return extract_output_as_input(outputs).get_as_bytes(key)


def extract_output_as_numpy(outputs: Output, key=None):
return _extract_output(outputs).get_as_numpy(key)
return extract_output_as_input(outputs).get_as_numpy(key)


def extract_output_as_npz(outputs: Output, key=None):
return _extract_output(outputs).get_as_npz(key)
return extract_output_as_input(outputs).get_as_npz(key)


def extract_output_as_string(outputs: Output, key=None):
return _extract_output(outputs).get_as_string(key)
return extract_output_as_input(outputs).get_as_string(key)


def retrieve_int(bytearr: bytearray, start_iter):
end_iter = start_iter + 4
data = bytearr[start_iter:end_iter]
return struct.unpack(">i", data)[0], end_iter


def retrieve_short(bytearr: bytearray, start_iter):
end_iter = start_iter + 2
data = bytearr[start_iter:end_iter]
return struct.unpack(">h", data)[0], end_iter


def retrieve_utf8(bytearr: bytearray, start_iter):
length, start_iter = retrieve_short(bytearr, start_iter)
if length < 0:
return None
end_iter = start_iter + length
data = bytearr[start_iter:end_iter]
return data.decode("utf8"), end_iter


def decode_encoded_output_binary(binary: bytearray):
start_iter = 0
prop_size, start_iter = retrieve_short(binary, start_iter)
content = {}
for _ in range(prop_size):
key, start_iter = retrieve_utf8(binary, start_iter)
val, start_iter = retrieve_utf8(binary, start_iter)
content[key] = val

return content


def load_properties(properties_dir):
if not properties_dir:
return {}
properties = {}
properties_file = os.path.join(properties_dir, 'serving.properties')
if os.path.exists(properties_file):
with open(properties_file, 'r') as f:
for line in f:
# ignoring line starting with #
if line.startswith("#") or not line.strip():
continue
key, value = line.strip().split('=', 1)
key = key.strip()
if key.startswith("option."):
key = key[7:]
value = value.strip()
properties[key] = value
return properties


def update_properties_with_env_vars(kwargs):
env_vars = os.environ
for key, value in env_vars.items():
if key.startswith("OPTION_"):
key = key[7:].lower()
if key == "entrypoint":
key = "entryPoint"
kwargs.setdefault(key, value)
return kwargs


class TestHandler:

def __init__(self,
entry_point: Union[str, ModuleType],
model_dir: str = None):
self.serving_properties = update_properties_with_env_vars({})
self.serving_properties.update(load_properties(model_dir))

if isinstance(entry_point, str):
os.chdir(model_dir)
model_dir = os.getcwd()
sys.path.append(model_dir)
self.service = load_model_service(model_dir, entry_point, "-1")
else:
self.service = ModelService(entry_point, model_dir)

def inference(self, inputs: Input) -> Output:
function_name = inputs.get_function_name()
return self.service.invoke_handler(function_name, inputs)

def inference_batch(self,
inputs: List[Dict],
properties: List[Dict] = None,
serving_properties=None) -> Output:
if serving_properties is None:
serving_properties = self.serving_properties
return self.inference(
create_concurrent_batch_request(inputs, properties,
serving_properties))

def inference_rolling_batch(self,
inputs: List[Dict],
properties: List[Dict] = None,
serving_properties=None):
cached_result = {}
for idx in range(len(inputs)):
cached_result[idx] = ""
live_indices = [_ for _ in range(len(inputs))]
while len(live_indices) > 0:
outputs = self.inference_batch(inputs, properties,
serving_properties)
read_only_outputs = extract_output_as_input(outputs)
encoded_content = read_only_outputs.get_content().values
finished_indices = []
for idx, binary in enumerate(encoded_content):
data = decode_encoded_output_binary(binary)
cached_result[live_indices[idx]] += data['data']
if data['last'].lower() == 'true':
print(f"Finished request {live_indices[idx]}")
finished_indices.append(idx)
for index in sorted(finished_indices, reverse=True):
del live_indices[index]
inputs = [{"inputs": ""} for _ in range(len(live_indices))]

return cached_result


def run():
@@ -133,17 +283,10 @@ def run():
args = ArgParser.test_model_args().parse_args()

inputs = create_request(args.input, args.parameters)
handler = TestHandler(args.entry_point, args.model_dir)
inputs.function_name = args.handler

os.chdir(args.model_dir)
model_dir = os.getcwd()
sys.path.append(model_dir)

entry_point = args.entry_point
service = load_model_service(model_dir, entry_point, "-1")

function_name = inputs.get_function_name()
outputs = service.invoke_handler(function_name, inputs)
outputs = handler.inference(inputs)
print("output: " + str(outputs))


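Taken together, these test_model.py additions let a model handler be exercised locally without starting the serving frontend. A rough usage sketch of the new TestHandler, not from this commit (the entry point, model directory, and payloads are hypothetical):

from djl_python.test_model import TestHandler

handler = TestHandler("model.py", "/path/to/model_dir")

# dynamic batch: each dict becomes one flattened batch_XXX_data entry
outputs = handler.inference_batch([{"inputs": "Hello"}, {"inputs": "World"}])

# rolling batch: re-issues requests until every stream reports last == true,
# accumulating the decoded "data" chunks per original request index
results = handler.inference_rolling_batch(
    [{"inputs": "The future of AI is", "parameters": {"max_new_tokens": 16}}])
print(results)  # e.g. {0: " ...generated text..."}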
12 changes: 12 additions & 0 deletions engines/python/setup/djl_python/tests/rolling_batch/__init__.py
@@ -0,0 +1,12 @@
#!/usr/bin/env python
#
# Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
# except in compliance with the License. A copy of the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.
@@ -0,0 +1,94 @@
#!/usr/bin/env python
#
# Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
# except in compliance with the License. A copy of the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.
import random
from collections import OrderedDict
from transformers import AutoTokenizer
from djl_python.rolling_batch.rolling_batch import RollingBatch, stop_on_any_exception, Token


class FakeRollingBatch(RollingBatch):

# TODO: Make properties is the only parameter, after refactoring all rolling batch handlers
def __init__(self, model_id_or_path, properties, **kwargs):
"""
Initializes the FakeRollingBatch.
"""
super().__init__(**kwargs)
self.sample_text = (
"DJL-Serving is a powerful and user-friendly deep learning model serving solution "
"that enables developers to easily deploy and serve their trained deep learning models."
" With DJL-Serving, developers can quickly expose their models as web services or APIs,"
" allowing them to integrate their deep learning models into various applications "
"and systems seamlessly. The framework supports various deep learning frameworks like "
"TensorFlow, PyTorch, MXNet, and more, making it versatile and adaptable to different model"
" architectures. DJL-Serving is designed to be highly scalable and efficient, ensuring that"
" models can handle high volumes of requests with low latency. Whether you are a researcher"
" or a developer, DJL-Serving simplifies the process of serving deep learning models,"
" enabling you to focus on creating innovative applications with ease."
)
self.tokenizer = AutoTokenizer.from_pretrained(model_id_or_path,
padding_side="left",
trust_remote_code=True)
if not self.tokenizer.pad_token:
self.tokenizer.pad_token = self.tokenizer.eos_token
self.tokens = self.tokenizer.encode(self.sample_text)
self.total_length = 32000
while len(self.tokens) < self.total_length:
self.tokens += self.tokens
self.tokens = self.tokens[:self.total_length]
self.cache = OrderedDict()

def reset(self):
self.cache = OrderedDict()
super().reset()

@stop_on_any_exception
def inference(self, input_data, parameters):
batch_size = len(input_data)
new_requests = self.get_new_requests(input_data, parameters,
batch_size)

for new_request in new_requests:
max_len = new_request.parameters[
"max_new_tokens"] if "max_new_tokens" in new_request.parameters else 256
min_len = new_request.parameters[
"min_new_tokens"] if "min_new_tokens" in new_request.parameters else 1
max_len = max(min_len, max_len)
max_len = random.randint(min_len, max_len)
self.cache[new_request.id] = {
"max_len": max_len,
"cur_pos": -1,
"finished": False
}

# fake inference
for value in self.cache.values():
value["cur_pos"] += 1
if value["cur_pos"] == value["max_len"]:
value["finished"] = True

finished_id = []
for (key, cache), request in zip(self.cache.items(),
self.active_requests):
# finish condition match
if cache["finished"]:
finished_id.append(key)
token_id = self.tokens[cache["cur_pos"]]
token_txt = " " + self.tokenizer.decode(token_id)
request.set_next_token(Token(token_id, token_txt),
self.output_formatter, cache["finished"])

return self.postprocess_results()

def preprocess_requests(self, requests):
raise NotImplementedError("Not implemented for vLLM rolling batcher")
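FakeRollingBatch emits deterministic tokens from a tokenizer instead of running a model, which keeps rolling-batch tests fast. A minimal sketch of driving it directly, assuming the class above is in scope (the model id is only a placeholder for anything AutoTokenizer can load, and the return shape of postprocess_results() is not shown in this diff):

fake = FakeRollingBatch("gpt2", properties=None)
prompts = ["Hello world", "DJL-Serving is"]
params = [{"max_new_tokens": 8}, {"min_new_tokens": 2, "max_new_tokens": 6}]

# each call adds at most one token per active request; a request finishes
# once its randomly chosen fake max_len is reached
step_results = fake.inference(prompts, params)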