diff --git a/engines/python/setup/djl_python/inputs.py b/engines/python/setup/djl_python/inputs.py
index c25b0bc30..7ed3401cf 100644
--- a/engines/python/setup/djl_python/inputs.py
+++ b/engines/python/setup/djl_python/inputs.py
@@ -15,6 +15,7 @@
 import struct
 import json
 import re
+from typing import List
 
 from .np_util import from_nd_list
 from .pair_list import PairList
@@ -88,7 +89,7 @@ def is_batch(self) -> bool:
     def get_batch_size(self) -> int:
         return int(self.properties.get("batch_size", "1"))
 
-    def get_batches(self) -> list:
+    def get_batches(self) -> List["Input"]:
         if not self.is_batch():
             return [self]
 
@@ -98,12 +99,14 @@
             batch.append(Input())
 
         for key, value in self.properties.items():
+            # e.g. property key "batch_001_eula" becomes request key "eula"
             if key.startswith("batch_") and key != "batch_size":
                 index = int(key[6:9])
                 key = key[10:]
                 batch[index].properties[key] = value
 
         for i in range(self.content.size()):
+            # e.g. content key "batch_001_inputs" becomes request key "inputs"
             key = self.content.key_at(i)
             index = int(key[6:9])
             key = key[10:]
diff --git a/engines/python/setup/djl_python/rolling_batch/rolling_batch.py b/engines/python/setup/djl_python/rolling_batch/rolling_batch.py
index c2544e035..09d0c3a94 100644
--- a/engines/python/setup/djl_python/rolling_batch/rolling_batch.py
+++ b/engines/python/setup/djl_python/rolling_batch/rolling_batch.py
@@ -13,7 +13,7 @@
 import json
 import logging
 from abc import ABC, abstractmethod
-from typing import Union
+from typing import Union, List
 
 FINISH_REASON_MAPPER = ["length", "eos_token", "stop_sequence"]
 
@@ -215,8 +215,8 @@ def __init__(self, **kwargs):
 
         :param kwargs passed while loading the model
         """
-        self.pending_requests = []
-        self.active_requests = []
+        self.pending_requests: List[Request] = []
+        self.active_requests: List[Request] = []
         self.req_id_counter = 0
         self.output_formatter = None
         self.waiting_steps = kwargs.get("waiting_steps", None)
@@ -251,7 +251,8 @@ def inference(self, input_data, parameters):
         """
         pass
 
-    def get_new_requests(self, input_data, parameters, batch_size):
+    def get_new_requests(self, input_data, parameters,
+                         batch_size) -> List[Request]:
         total_req_len = len(self.active_requests) + len(self.pending_requests)
         if batch_size > total_req_len:
            for i in range(total_req_len, batch_size):
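Review note (not part of the patch): the `batch_NNN_` naming convention that `get_batches()` decodes above is easy to verify in isolation. A minimal sketch, assuming `Input()` starts with an empty `PairList` as its content (as in the current `inputs.py`); the property names here are made up for illustration:

```python
from djl_python.inputs import Input

# Build a flattened two-request batch by hand.
flat = Input()
flat.properties = {
    "batch_size": "2",
    "batch_000_eula": "true",
    "batch_001_eula": "false",
}
flat.content.add("batch_000_data", b'{"inputs": "a"}')
flat.content.add("batch_001_data", b'{"inputs": "b"}')

# get_batches() reads the request index from characters 6-8 and strips
# the 10-character "batch_NNN_" prefix, yielding one Input per request.
for request in flat.get_batches():
    print(request.properties["eula"], request.get_as_bytes("data"))
```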
key = f"batch_{str(idx).zfill(3)}_data" + pair_list.add(key, Output._encode_json(data)) + + inputs_obj = Input() + inputs_obj.properties = flatten_properties + inputs_obj.function_name = flatten_properties.get("handler", "handle") + inputs_obj.content = pair_list + flatten_properties['batch_size'] = len(inputs) + return inputs_obj + + def create_text_request(text: str, key: str = None) -> Input: request = Input() request.properties["device_id"] = "-1" @@ -103,7 +132,7 @@ def create_npz_request(list, key: str = None) -> Input: return request -def _extract_output(outputs: Output) -> Input: +def extract_output_as_input(outputs: Output) -> Input: inputs = Input() inputs.properties = outputs.properties inputs.content = outputs.content @@ -111,19 +140,140 @@ def _extract_output(outputs: Output) -> Input: def extract_output_as_bytes(outputs: Output, key=None): - return _extract_output(outputs).get_as_bytes(key) + return extract_output_as_input(outputs).get_as_bytes(key) def extract_output_as_numpy(outputs: Output, key=None): - return _extract_output(outputs).get_as_numpy(key) + return extract_output_as_input(outputs).get_as_numpy(key) def extract_output_as_npz(outputs: Output, key=None): - return _extract_output(outputs).get_as_npz(key) + return extract_output_as_input(outputs).get_as_npz(key) def extract_output_as_string(outputs: Output, key=None): - return _extract_output(outputs).get_as_string(key) + return extract_output_as_input(outputs).get_as_string(key) + + +def retrieve_int(bytearr: bytearray, start_iter): + end_iter = start_iter + 4 + data = bytearr[start_iter:end_iter] + return struct.unpack(">i", data)[0], end_iter + + +def retrieve_short(bytearr: bytearray, start_iter): + end_iter = start_iter + 2 + data = bytearr[start_iter:end_iter] + return struct.unpack(">h", data)[0], end_iter + + +def retrieve_utf8(bytearr: bytearray, start_iter): + length, start_iter = retrieve_short(bytearr, start_iter) + if length < 0: + return None + end_iter = start_iter + length + data = bytearr[start_iter:end_iter] + return data.decode("utf8"), end_iter + + +def decode_encoded_output_binary(binary: bytearray): + start_iter = 0 + prop_size, start_iter = retrieve_short(binary, start_iter) + content = {} + for _ in range(prop_size): + key, start_iter = retrieve_utf8(binary, start_iter) + val, start_iter = retrieve_utf8(binary, start_iter) + content[key] = val + + return content + + +def load_properties(properties_dir): + if not properties_dir: + return {} + properties = {} + properties_file = os.path.join(properties_dir, 'serving.properties') + if os.path.exists(properties_file): + with open(properties_file, 'r') as f: + for line in f: + # ignoring line starting with # + if line.startswith("#") or not line.strip(): + continue + key, value = line.strip().split('=', 1) + key = key.strip() + if key.startswith("option."): + key = key[7:] + value = value.strip() + properties[key] = value + return properties + + +def update_properties_with_env_vars(kwargs): + env_vars = os.environ + for key, value in env_vars.items(): + if key.startswith("OPTION_"): + key = key[7:].lower() + if key == "entrypoint": + key = "entryPoint" + kwargs.setdefault(key, value) + return kwargs + + +class TestHandler: + + def __init__(self, + entry_point: Union[str, ModuleType], + model_dir: str = None): + self.serving_properties = update_properties_with_env_vars({}) + self.serving_properties.update(load_properties(model_dir)) + + if isinstance(entry_point, str): + os.chdir(model_dir) + model_dir = os.getcwd() + sys.path.append(model_dir) + 
diff --git a/engines/python/setup/djl_python/tests/rolling_batch/__init__.py b/engines/python/setup/djl_python/tests/rolling_batch/__init__.py
new file mode 100644
index 000000000..546f319b3
--- /dev/null
+++ b/engines/python/setup/djl_python/tests/rolling_batch/__init__.py
@@ -0,0 +1,12 @@
+#!/usr/bin/env python
+#
+# Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
+# except in compliance with the License. A copy of the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
+# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
+# the specific language governing permissions and limitations under the License.
diff --git a/engines/python/setup/djl_python/tests/rolling_batch/fake_rolling_batch.py b/engines/python/setup/djl_python/tests/rolling_batch/fake_rolling_batch.py
new file mode 100644
index 000000000..88ea24502
--- /dev/null
+++ b/engines/python/setup/djl_python/tests/rolling_batch/fake_rolling_batch.py
@@ -0,0 +1,94 @@
+#!/usr/bin/env python
+#
+# Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
+# except in compliance with the License. A copy of the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
+# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
+# the specific language governing permissions and limitations under the License.
+import random
+from collections import OrderedDict
+
+from transformers import AutoTokenizer
+
+from djl_python.rolling_batch.rolling_batch import RollingBatch, stop_on_any_exception, Token
+
+
+class FakeRollingBatch(RollingBatch):
+
+    # TODO: make properties the only parameter once all rolling batch handlers are refactored
+    def __init__(self, model_id_or_path, properties, **kwargs):
+        """
+        Initializes the FakeRollingBatch.
+        """
+        super().__init__(**kwargs)
+        self.sample_text = (
+            "DJL-Serving is a powerful and user-friendly deep learning model serving solution "
+            "that enables developers to easily deploy and serve their trained deep learning models."
+            " With DJL-Serving, developers can quickly expose their models as web services or APIs,"
+            " allowing them to integrate their deep learning models into various applications "
+            "and systems seamlessly. The framework supports various deep learning frameworks like "
+            "TensorFlow, PyTorch, MXNet, and more, making it versatile and adaptable to different model"
+            " architectures. DJL-Serving is designed to be highly scalable and efficient, ensuring that"
+            " models can handle high volumes of requests with low latency. Whether you are a researcher"
+            " or a developer, DJL-Serving simplifies the process of serving deep learning models,"
+            " enabling you to focus on creating innovative applications with ease.")
+        self.tokenizer = AutoTokenizer.from_pretrained(model_id_or_path,
+                                                       padding_side="left",
+                                                       trust_remote_code=True)
+        if not self.tokenizer.pad_token:
+            self.tokenizer.pad_token = self.tokenizer.eos_token
+        self.tokens = self.tokenizer.encode(self.sample_text)
+        self.total_length = 32000
+        while len(self.tokens) < self.total_length:
+            self.tokens += self.tokens
+        self.tokens = self.tokens[:self.total_length]
+        self.cache = OrderedDict()
+
+    def reset(self):
+        self.cache = OrderedDict()
+        super().reset()
+
+    @stop_on_any_exception
+    def inference(self, input_data, parameters):
+        batch_size = len(input_data)
+        new_requests = self.get_new_requests(input_data, parameters,
+                                             batch_size)
+
+        for new_request in new_requests:
+            max_len = new_request.parameters[
+                "max_new_tokens"] if "max_new_tokens" in new_request.parameters else 256
+            min_len = new_request.parameters[
+                "min_new_tokens"] if "min_new_tokens" in new_request.parameters else 1
+            max_len = max(min_len, max_len)
+            max_len = random.randint(min_len, max_len)
+            self.cache[new_request.id] = {
+                "max_len": max_len,
+                "cur_pos": -1,
+                "finished": False
+            }
+
+        # fake inference
+        for value in self.cache.values():
+            value["cur_pos"] += 1
+            if value["cur_pos"] == value["max_len"]:
+                value["finished"] = True
+
+        finished_id = []
+        for (key, cache), request in zip(self.cache.items(),
+                                         self.active_requests):
+            # finish condition match
+            if cache["finished"]:
+                finished_id.append(key)
+            token_id = self.tokens[cache["cur_pos"]]
+            token_txt = " " + self.tokenizer.decode(token_id)
+            request.set_next_token(Token(token_id, token_txt),
+                                   self.output_formatter, cache["finished"])
+
+        return self.postprocess_results()
+
+    def preprocess_requests(self, requests):
+        raise NotImplementedError("Not implemented for the fake rolling batcher")
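Review note (not part of the patch): `FakeRollingBatch` can be driven directly, without a model server. A minimal sketch, assuming network access to download a tokenizer ("gpt2" is an arbitrary choice) and that `postprocess_results()` returns one `{"data": ..., "last": ...}` entry per live request, which is the shape `inference_rolling_batch()` decodes above:

```python
from djl_python.tests.rolling_batch.fake_rolling_batch import FakeRollingBatch

# The model id is only used to load a tokenizer vocabulary.
batcher = FakeRollingBatch("gpt2", properties={})

prompts = ["tell me a story", "what is deep learning?"]
params = [{"max_new_tokens": 8}, {"max_new_tokens": 8}]

# Each inference() call emits at most one token per live request and
# drops requests whose randomly sampled length has been reached.
results = batcher.inference(prompts, params)
for result in results:
    print(result["data"], result["last"])
```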
diff --git a/engines/python/setup/djl_python/tests/test_input_output.py b/engines/python/setup/djl_python/tests/test_input_output.py
index 400a7f568..c50b6b2b5 100644
--- a/engines/python/setup/djl_python/tests/test_input_output.py
+++ b/engines/python/setup/djl_python/tests/test_input_output.py
@@ -27,6 +27,27 @@ def test_numpy_input(self):
         result = inputs.get_as_npz()
         self.assertTrue(np.array_equal(result[0], nd[0]))
 
+    def test_concurrent_batch(self):
+        input_list = [{
+            "inputs": "who win the oscar this year?",
+            "parameters": {
+                "max_new_tokens": 256
+            }
+        }]
+        properties = [{"eula": "true", "Content-type": "application/json"}]
+        serving_properties = {
+            "engine": "MPI",
+            "option.rolling_batch": "lmi-dist",
+            "option.model_id": "llama-70b"
+        }
+        inputs = test_model.create_concurrent_batch_request(
+            input_list, properties, serving_properties)
+        batches = inputs.get_batches()
+        self.assertEqual(batches[0].properties, properties[0])
+        self.assertEqual(batches[0].get_as_json(), input_list[0])
+        for key, value in serving_properties.items():
+            self.assertEqual(inputs.get_properties()[key], value)
+
     def test_output(self):
         test_dict = {"Key": "Value"}
         nd = [np.ones((1, 3, 2))]
diff --git a/engines/python/setup/djl_python/tests/test_test_model.py b/engines/python/setup/djl_python/tests/test_test_model.py
new file mode 100644
index 000000000..473192a27
--- /dev/null
+++ b/engines/python/setup/djl_python/tests/test_test_model.py
@@ -0,0 +1,83 @@
+#!/usr/bin/env python3
+#
+# Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
+# except in compliance with the License. A copy of the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
+# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
+# the specific language governing permissions and limitations under the License.
+import os
+import unittest
+
+from djl_python.test_model import TestHandler
+from djl_python import huggingface
+
+from .rolling_batch.fake_rolling_batch import FakeRollingBatch
+
+
+def override_rolling_batch(rolling_batch_type: str, is_mpi: bool,
+                           model_config):
+    return FakeRollingBatch
+
+
+huggingface.get_rolling_batch_class_from_str = override_rolling_batch
+
+
+class TestTestModel(unittest.TestCase):
+
+    def test_all_code(self):
+        model_id = "NousResearch/Nous-Hermes-Llama2-13b"
+        handler = TestHandler(huggingface)
+        inputs = [{
+            "inputs": "The winner of oscar this year is",
+            "parameters": {
+                "max_new_tokens": 256
+            }
+        }, {
+            "inputs": "A little redhood is",
+            "parameters": {
+                "max_new_tokens": 50
+            }
+        }]
+        serving_properties = {
+            "engine": "Python",
+            "rolling_batch": "auto",
+            "model_id": model_id
+        }
+        result = handler.inference_rolling_batch(
+            inputs, serving_properties=serving_properties)
+        self.assertEqual(len(result), len(inputs))
+
+    def test_with_env(self):
+        envs = {
+            "OPTION_MODEL_ID": "NousResearch/Nous-Hermes-Llama2-13b",
+            "SERVING_LOAD_MODELS": "test::MPI=/opt/ml/model",
+            "OPTION_ROLLING_BATCH": "auto"
+        }
+        for key, value in envs.items():
+            os.environ[key] = value
+
+        handler = TestHandler(huggingface)
+        self.assertEqual(handler.serving_properties["model_id"],
+                         envs["OPTION_MODEL_ID"])
+        self.assertEqual(handler.serving_properties["rolling_batch"],
+                         envs["OPTION_ROLLING_BATCH"])
+        inputs = [{
+            "inputs": "The winner of oscar this year is",
+            "parameters": {
+                "max_new_tokens": 50
+            }
+        }, {
+            "inputs": "A little redhood is",
+            "parameters": {
+                "min_new_tokens": 51,
+                "max_new_tokens": 256
+            }
+        }]
+        result = handler.inference_rolling_batch(inputs)
+        self.assertEqual(len(result), len(inputs))
+        self.assertGreater(len(result[1]), len(result[0]))
+
+        for key in envs.keys():
+            del os.environ[key]
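Review note (not part of the patch): the module-level reassignment of `huggingface.get_rolling_batch_class_from_str` leaks into every other test that imports this module. A scoped alternative using `unittest.mock`, shown as a sketch rather than a requested change, restores the real implementation on exit:

```python
from unittest import mock

from djl_python import huggingface
from djl_python.test_model import TestHandler
from djl_python.tests.rolling_batch.fake_rolling_batch import FakeRollingBatch

# Patch only for the duration of the with-block.
with mock.patch.object(huggingface,
                       "get_rolling_batch_class_from_str",
                       return_value=FakeRollingBatch):
    handler = TestHandler(huggingface)
    # ... run handler.inference_rolling_batch() against the fake batcher ...
```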
diff --git a/engines/python/setup/setup.py b/engines/python/setup/setup.py
index 7277082ae..f42d50671 100644
--- a/engines/python/setup/setup.py
+++ b/engines/python/setup/setup.py
@@ -65,6 +65,7 @@ def run(self):
             'accelerate',
             'sentencepiece',
             'protobuf',
+            'peft',
             'yapf',
             'pydantic==1.10.13',
         ]
diff --git a/serving/docker/partition/trt_llm_partition.py b/serving/docker/partition/trt_llm_partition.py
index 3b4bed783..a544274c2 100644
--- a/serving/docker/partition/trt_llm_partition.py
+++ b/serving/docker/partition/trt_llm_partition.py
@@ -27,7 +27,6 @@ def create_trt_llm_repo(properties, args):
             kwargs[key[7:]] = value
         else:
             kwargs[key] = value
-    kwargs = update_kwargs_with_env_vars(kwargs)
     kwargs['trt_llm_model_repo'] = args.trt_llm_model_repo
     kwargs["tensor_parallel_degree"] = args.tensor_parallel_degree
     model_id_or_path = args.model_path or kwargs['model_id']
@@ -38,9 +37,9 @@ def update_kwargs_with_env_vars(kwargs):
     env_vars = os.environ
     for key, value in env_vars.items():
         if key.startswith("OPTION_"):
-            key = key[7:].lower()
-            if key == "entrypoint":
-                key = "entryPoint"
+            key = key.lower()
+            if key == "option_entrypoint":
+                key = "option.entryPoint"
             kwargs.setdefault(key, value)
     return kwargs
 
@@ -72,7 +71,8 @@ def main():
                         help='local path to downloaded model')
     args = parser.parse_args()
 
-    properties = load_properties(args.properties_dir)
+    properties = update_kwargs_with_env_vars({})
+    properties.update(load_properties(args.properties_dir))
 
     create_trt_llm_repo(properties, args)
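Review note (not part of the patch): both `TestHandler` and the revised `trt_llm_partition.py` now follow the same merge pattern: seed a dict from `OPTION_*` environment variables, then `update()` it from `serving.properties`, so the properties file wins on key collisions. A precedence sketch with made-up values (key normalization differs between the two code paths and is ignored here):

```python
env_derived = {"model_id": "model-from-env", "rolling_batch": "auto"}
file_derived = {"model_id": "model-from-file"}

properties = dict(env_derived)   # environment variables seed the dict
properties.update(file_derived)  # serving.properties overrides on collision

print(properties["model_id"])    # -> model-from-file
```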