From 144b409be583429bfff0893ffbaf18f3c50052ee Mon Sep 17 00:00:00 2001 From: Somasundaram Date: Fri, 19 Apr 2024 16:25:08 -0700 Subject: [PATCH] Refactor huggingface handler --- .../python/setup/djl_python/huggingface.py | 163 ++++-------------- .../properties_manager/properties.py | 5 +- .../rolling_batch/lmi_dist_rolling_batch.py | 5 +- .../rolling_batch/rolling_batch_service.py | 145 ++++++++++++++++ .../rolling_batch/scheduler_rolling_batch.py | 20 +-- .../rolling_batch/vllm_rolling_batch.py | 4 +- .../tests/rolling_batch/fake_rolling_batch.py | 16 +- .../setup/djl_python/tests/test_test_model.py | 15 +- engines/python/setup/djl_python/utils.py | 42 +++++ 9 files changed, 254 insertions(+), 161 deletions(-) create mode 100644 engines/python/setup/djl_python/rolling_batch/rolling_batch_service.py diff --git a/engines/python/setup/djl_python/huggingface.py b/engines/python/setup/djl_python/huggingface.py index b19141a5b3..aa7d8eb3e6 100644 --- a/engines/python/setup/djl_python/huggingface.py +++ b/engines/python/setup/djl_python/huggingface.py @@ -15,22 +15,25 @@ import re import torch -from transformers import ( - pipeline, Pipeline, Conversation, AutoModelForCausalLM, - AutoModelForSeq2SeqLM, AutoTokenizer, AutoConfig, - AutoModelForSequenceClassification, AutoModelForTokenClassification, - AutoModelForQuestionAnswering, StoppingCriteria, StoppingCriteriaList) +from transformers import (pipeline, Pipeline, Conversation, + AutoModelForCausalLM, AutoModelForSeq2SeqLM, + AutoTokenizer, AutoModelForSequenceClassification, + AutoModelForTokenClassification, + AutoModelForQuestionAnswering, StoppingCriteria, + StoppingCriteriaList) from transformers.tokenization_utils_base import PreTrainedTokenizerBase -from peft import PeftConfig, PeftModel, PeftModelForCausalLM +from peft import PeftModel, PeftModelForCausalLM from djl_python.encode_decode import encode from djl_python.inputs import Input from djl_python.outputs import Output from djl_python.streaming_utils import StreamingUtils -from djl_python.rolling_batch.rolling_batch import get_content_type_from_output_formatter -from djl_python.properties_manager.properties import StreamingEnum, is_rolling_batch_enabled, is_streaming_enabled +from djl_python.properties_manager.properties import StreamingEnum, is_rolling_batch_enabled, is_streaming_enabled, \ + RollingBatchEnum from djl_python.properties_manager.hf_properties import HuggingFaceProperties +from djl_python.utils import read_model_config, get_tokenizer +from djl_python.rolling_batch.rolling_batch_service import RollingBatchService from djl_python.utils import parse_input_with_formatter, InputFormatConfigs ARCHITECTURES_2_TASK = { @@ -82,28 +85,6 @@ def enable_flash(): return False -def get_rolling_batch_class_from_str(rolling_batch_type: str, is_mpi: bool, - model_config): - if rolling_batch_type == "auto": - architecture = model_config.architectures[0] - if architecture in LMI_DIST_ADV_MODEL and is_mpi: - from djl_python.rolling_batch.lmi_dist_rolling_batch import LmiDistRollingBatch - return LmiDistRollingBatch - else: - from djl_python.rolling_batch.scheduler_rolling_batch import SchedulerRollingBatch - return SchedulerRollingBatch - elif rolling_batch_type == "scheduler": - from djl_python.rolling_batch.scheduler_rolling_batch import SchedulerRollingBatch - return SchedulerRollingBatch - elif rolling_batch_type == "lmi-dist": - from djl_python.rolling_batch.lmi_dist_rolling_batch import LmiDistRollingBatch - return LmiDistRollingBatch - elif rolling_batch_type == "vllm": - from 
djl_python.rolling_batch.vllm_rolling_batch import VLLMRollingBatch - return VLLMRollingBatch - raise ValueError(f"Invalid rolling batch type: {rolling_batch_type}") - - class StopWord(StoppingCriteria): def __init__(self, tokenizer, stop_seq): @@ -132,7 +113,6 @@ def __init__(self): self.initialized = False self.model = None self.tokenizer = None - self.rolling_batch = None self.model_config = None self.peft_config = None self.stopping_criteria_list = None @@ -143,19 +123,17 @@ def __init__(self): def initialize(self, properties: dict): self.hf_configs = HuggingFaceProperties(**properties) - self._read_model_config(self.hf_configs.model_id_or_path) - - if is_rolling_batch_enabled(self.hf_configs.rolling_batch): - _rolling_batch_cls = get_rolling_batch_class_from_str( - self.hf_configs.rolling_batch.value, self.hf_configs.is_mpi, - self.model_config) - self.hf_configs.kwargs["model_config"] = self.model_config - self.rolling_batch = _rolling_batch_cls( - self.hf_configs.model_id_or_path, properties, - **self.hf_configs.kwargs) - self._init_tokenizer(self.hf_configs.model_id_or_path) - elif is_streaming_enabled(self.hf_configs.enable_streaming): - self._init_tokenizer(self.hf_configs.model_id_or_path) + self.model_config, self.peft_config = read_model_config( + self.hf_configs.model_id_or_path, + trust_remote_code=self.hf_configs.trust_remote_code, + revision=self.hf_configs.revision) + + if is_streaming_enabled(self.hf_configs.enable_streaming): + self.tokenizer = get_tokenizer( + self.hf_configs.model_id_or_path, + trust_remote_code=self.hf_configs.trust_remote_code, + revision=self.hf_configs.revision, + peft_config=self.peft_config) self._init_model(self.hf_configs.model_id_or_path, **self.hf_configs.kwargs) else: @@ -172,8 +150,7 @@ def initialize(self, properties: dict): self.load_stopping_criteria_list(properties["stop_sequence"]) self.input_format_configs = InputFormatConfigs( - is_rolling_batch=is_rolling_batch_enabled( - self.hf_configs.rolling_batch), + is_rolling_batch=False, is_adapters_supported=True, output_formatter=self.hf_configs.output_formatter, tokenizer=self.tokenizer) @@ -243,56 +220,10 @@ def inference(self, inputs): if len(input_data) == 0: for i in range(len(batch)): err = errors.get(i) - if is_rolling_batch_enabled(self.hf_configs.rolling_batch): - err = {"data": "", "last": True, "code": 424, "error": err} - outputs.add(Output.binary_encode(err), - key="data", - batch_index=i) - else: - outputs.add(err, key="data", batch_index=i) + outputs.add(err, key="data", batch_index=i) return outputs - if is_rolling_batch_enabled(self.hf_configs.rolling_batch): - if inputs.get_property("reset_rollingbatch"): - self.rolling_batch.reset() - if self.adapters is not None: - adapter_data = [] - for i, a in enumerate(self.adapters): - if a is None or a == "": - adapter_data.append(None) - elif a in self.adapter_registry: - adapter_data.append(self.adapter_registry[a]) - else: - adapter_data.append(None) - errors[i] = f"Unknown or invalid adapter {a}" - else: - adapter_data = None - result = self.rolling_batch.inference(input_data, - parameters, - adapters=adapter_data) - idx = 0 - for i in range(len(batch)): - err = errors.get(i) - if err: - err = {"data": "", "last": True, "code": 424, "error": err} - outputs.add(Output.binary_encode(err), - key="data", - batch_index=i) - else: - outputs.add(Output.binary_encode(result[idx]), - key="data", - batch_index=i) - idx += 1 - - formatter = parameters[i].get("output_formatter") - content_type = get_content_type_from_output_formatter( - 
formatter) - if content_type is not None: - outputs.add_property(f"batch_{i}_Content-Type", - content_type) - - return outputs - elif is_streaming_enabled(self.hf_configs.enable_streaming): + if is_streaming_enabled(self.hf_configs.enable_streaming): if len(batch) > 1: raise NotImplementedError( "Dynamic batch not supported for generic streaming") @@ -401,7 +332,11 @@ def get_pipeline(self, task: str, model_id_or_path: str, kwargs): **kwargs) self.model = hf_pipeline.model else: - self._init_tokenizer(model_id_or_path) + self.tokenizer = get_tokenizer( + model_id_or_path, + trust_remote_code=self.hf_configs.trust_remote_code, + revision=self.hf_configs.revision, + peft_config=self.peft_config) self._init_model(model_id_or_path, **kwargs) hf_pipeline = pipeline(task=task, model=self.model, @@ -426,15 +361,6 @@ def get_pipeline(self, task: str, model_id_or_path: str, kwargs): return hf_pipeline - def _init_tokenizer(self, model_id_or_path: str): - path_to_use = model_id_or_path if self.peft_config is None else self.peft_config.base_model_name_or_path - self.tokenizer = AutoTokenizer.from_pretrained( - path_to_use, - padding_size="left", - trust_remote_code=self.hf_configs.trust_remote_code, - revision=self.hf_configs.revision, - ) - def _init_model(self, model_id_or_path: str, **kwargs): architectures = self.model_config.architectures if architectures and architectures[0].endswith( @@ -523,28 +449,6 @@ def infer_task_from_model_architecture(self): ) return task - def _read_model_config(self, model_config_path: str): - try: - self.model_config = AutoConfig.from_pretrained( - model_config_path, - trust_remote_code=self.hf_configs.trust_remote_code, - revision=self.hf_configs.revision) - except OSError: - logging.warning( - f"config.json not found for {model_config_path}. Attempting to load with peft" - ) - self.peft_config = PeftConfig.from_pretrained(model_config_path) - self.model_config = AutoConfig.from_pretrained( - self.peft_config.base_model_name_or_path, - trust_remote_code=self.hf_configs.trust_remote_code, - revision=self.hf_configs.revision, - ) - except Exception as e: - logging.error( - f"{model_config_path} does not contain a config.json or adapter_config.json for lora models. " - f"This is required for loading huggingface models") - raise e - _service = HuggingFaceService() @@ -593,9 +497,14 @@ def handle(inputs: Input): """ Default handler function """ + global _service if not _service.initialized: + properties = inputs.get_properties() + if "rolling_batch" in properties and properties.get( + "rolling_batch") != RollingBatchEnum.disable.value: + _service = RollingBatchService() # stateful model - _service.initialize(inputs.get_properties()) + _service.initialize(properties) if inputs.is_empty(): # initialization request diff --git a/engines/python/setup/djl_python/properties_manager/properties.py b/engines/python/setup/djl_python/properties_manager/properties.py index cfe0a8045b..88d41117a3 100644 --- a/engines/python/setup/djl_python/properties_manager/properties.py +++ b/engines/python/setup/djl_python/properties_manager/properties.py @@ -14,7 +14,7 @@ import os from enum import Enum from typing import Optional, Union, Callable, Any -from pydantic import BaseModel, field_validator, model_validator, ValidationInfo, ConfigDict +from pydantic import BaseModel, field_validator, model_validator, ValidationInfo, Field, Extra, ConfigDict class RollingBatchEnum(str, Enum): @@ -64,7 +64,8 @@ class Properties(BaseModel): # model_config is for pydantic configurations for BaseModel. 
model_config = ConfigDict(arbitrary_types_allowed=True, - protected_namespaces=()) + protected_namespaces=(), + extra='allow') @model_validator(mode='before') def calculate_is_mpi(cls, properties): diff --git a/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py b/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py index 92a89c21e4..1175182dda 100644 --- a/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py +++ b/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py @@ -38,11 +38,10 @@ class LmiDistRollingBatch(RollingBatch): It also gets any new tokens from the backend and sends them back to the handler. """ - def __init__(self, model_id_or_path: str, properties: dict, **kwargs): + def __init__(self, properties: dict): """ Initializes the LmiDistRollingBatch. - :param model_id_or_path (str): Currently unused since there is a copy inside properties :param properties (dict): other properties of the model, such as decoder strategy """ self.lmi_dist_config = LmiDistRbProperties(**properties) @@ -84,8 +83,6 @@ def __init__(self, model_id_or_path: str, properties: dict, **kwargs): kwargs["warmup_prefill_tokens"] = _WARMUP_PREFILL_TOKENS self.engine = engine_from_args(engine_args, **kwargs) self.request_cache = OrderedDict() - self.model_type = getattr(kwargs.get("model_config", None), - "model_type", None) self.lora_ids = defaultdict(lambda: len(self.lora_ids) + 1) def reset(self) -> None: diff --git a/engines/python/setup/djl_python/rolling_batch/rolling_batch_service.py b/engines/python/setup/djl_python/rolling_batch/rolling_batch_service.py new file mode 100644 index 0000000000..be2df4371f --- /dev/null +++ b/engines/python/setup/djl_python/rolling_batch/rolling_batch_service.py @@ -0,0 +1,145 @@ +#!/usr/bin/env python +# +# Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file +# except in compliance with the License. A copy of the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" +# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for +# the specific language governing permissions and limitations under the License. 
+from djl_python import Output, Input +from djl_python.properties_manager.hf_properties import HuggingFaceProperties +from djl_python.utils import read_model_config, get_tokenizer, parse_input_with_formatter, InputFormatConfigs +from djl_python.rolling_batch.rolling_batch import get_content_type_from_output_formatter + +LMI_DIST_ADV_MODEL = { + "RWForCausalLM", + "GPTNeoXForCausalLM", + "T5ForConditionalGeneration", + "LlamaForCausalLM", + "FalconForCausalLM", + "MPTForCausalLM", + "GPTBigCodeForCausalLM", +} + + +def get_rolling_batch_class_from_str(rolling_batch_type: str, is_mpi: bool, + model_config): + if rolling_batch_type == "auto": + architecture = model_config.architectures[0] + if architecture in LMI_DIST_ADV_MODEL and is_mpi: + from djl_python.rolling_batch.lmi_dist_rolling_batch import LmiDistRollingBatch + return LmiDistRollingBatch + else: + from djl_python.rolling_batch.scheduler_rolling_batch import SchedulerRollingBatch + return SchedulerRollingBatch + elif rolling_batch_type == "scheduler": + from djl_python.rolling_batch.scheduler_rolling_batch import SchedulerRollingBatch + return SchedulerRollingBatch + elif rolling_batch_type == "lmi-dist": + from djl_python.rolling_batch.lmi_dist_rolling_batch import LmiDistRollingBatch + return LmiDistRollingBatch + elif rolling_batch_type == "vllm": + from djl_python.rolling_batch.vllm_rolling_batch import VLLMRollingBatch + return VLLMRollingBatch + raise ValueError(f"Invalid rolling batch type: {rolling_batch_type}") + + +class RollingBatchService: + + def __init__(self): + self.tokenizer = None + self.rolling_batch = None + self.model_config = None + self.peft_config = None + self.initialized = False + self.adapters = None + self.adapter_registry = {} + self.rb_configs = None + self.input_format_configs = None + + def initialize(self, properties: dict): + self.rb_configs = HuggingFaceProperties(**properties) + self.model_config, self.peft_config = read_model_config( + self.rb_configs.model_id_or_path, + self.rb_configs.trust_remote_code, self.rb_configs.revision) + _rolling_batch_cls = get_rolling_batch_class_from_str( + self.rb_configs.rolling_batch.value, self.rb_configs.is_mpi, + self.model_config) + self.rb_configs.kwargs["model_config"] = self.model_config + self.rolling_batch = _rolling_batch_cls(properties) + self.tokenizer = get_tokenizer(self.rb_configs.model_id_or_path, + self.rb_configs.trust_remote_code, + self.rb_configs.revision, + peft_config=self.peft_config) + self.input_format_configs = InputFormatConfigs( + is_rolling_batch=True, + is_adapters_supported=True, + tokenizer=self.tokenizer, + output_formatter=self.rb_configs.output_formatter) + self.initialized = True + + def parse_input( + self, inputs: Input + ) -> tuple[list[str], list[int], list[dict], dict, list]: + parsed_input = parse_input_with_formatter( + inputs, input_format_configs=self.input_format_configs) + + self.adapters = parsed_input.adapters if parsed_input.found_adapters else None + + return parsed_input.input_data, parsed_input.input_size, parsed_input.parameters, parsed_input.errors, parsed_input.batch + + def inference(self, inputs): + outputs = Output() + + input_data, input_size, parameters, errors, batch = self.parse_input( + inputs) + if len(input_data) == 0: + for i in range(len(batch)): + err = errors.get(i) + err = {"data": "", "last": True, "code": 424, "error": err} + outputs.add(Output.binary_encode(err), + key="data", + batch_index=i) + return outputs + + if inputs.get_property("reset_rollingbatch"): + self.rolling_batch.reset() 
+ if self.adapters is not None: + adapter_data = [] + for i, a in enumerate(self.adapters): + if a is None or a == "": + adapter_data.append(None) + elif a in self.adapter_registry: + adapter_data.append(self.adapter_registry[a]) + else: + adapter_data.append(None) + errors[i] = f"Unknown or invalid adapter {a}" + else: + adapter_data = None + result = self.rolling_batch.inference(input_data, + parameters, + adapters=adapter_data) + idx = 0 + for i in range(len(batch)): + err = errors.get(i) + if err: + err = {"data": "", "last": True, "code": 424, "error": err} + outputs.add(Output.binary_encode(err), + key="data", + batch_index=i) + else: + outputs.add(Output.binary_encode(result[idx]), + key="data", + batch_index=i) + idx += 1 + + formatter = parameters[i].get("output_formatter") + content_type = get_content_type_from_output_formatter(formatter) + if content_type is not None: + outputs.add_property(f"batch_{i}_Content-Type", content_type) + + return outputs diff --git a/engines/python/setup/djl_python/rolling_batch/scheduler_rolling_batch.py b/engines/python/setup/djl_python/rolling_batch/scheduler_rolling_batch.py index b898a05d49..97a4357c61 100644 --- a/engines/python/setup/djl_python/rolling_batch/scheduler_rolling_batch.py +++ b/engines/python/setup/djl_python/rolling_batch/scheduler_rolling_batch.py @@ -11,9 +11,9 @@ # BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for # the specific language governing permissions and limitations under the License. -from seq_scheduler.lm_block import HuggingfaceBlock, BloomBlock, FalconBlock -from seq_scheduler.search_config import SearchConfig -from seq_scheduler.seq_batch_scheduler import SeqBatchScheduler +from djl_python.seq_scheduler.lm_block import HuggingfaceBlock, BloomBlock, FalconBlock +from djl_python.seq_scheduler.search_config import SearchConfig +from djl_python.seq_scheduler.seq_batch_scheduler import SeqBatchScheduler from collections import namedtuple, defaultdict from djl_python.rolling_batch.rolling_batch import RollingBatch, stop_on_any_exception, filter_unused_generation_params from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig @@ -51,14 +51,11 @@ class SchedulerRollingBatch(RollingBatch): and other experimental features. """ - def __init__(self, model_id_or_path: str, properties: dict, - **kwargs) -> None: + def __init__(self, properties: dict) -> None: """ Initializes the rolling batch scheduler. 
- :param model_id_or_path (str): model id or path :param properties (dict): other properties of the model, such as decoder strategy - :param kwargs passed while loading the model """ self.scheduler_configs = SchedulerRbProperties(**properties) @@ -203,12 +200,9 @@ def _init_scheduler(self) -> None: eos_token_id=self.tokenizer.eos_token, pad_token_id=self.tokenizer.pad_token) self.search_algorithm = self.scheduler_configs.decoding_strategy - self.scheduler = SeqBatchScheduler( - self.lm_block, - self.search_algorithm, - self.search_config, - max_sparsity=self.scheduler_configs.max_sparsity, - max_splits=self.scheduler_configs.max_splits) + self.scheduler = SeqBatchScheduler(self.lm_block, + self.search_algorithm, + self.search_config) def _prefill_and_decode(self, new_requests) -> None: """ diff --git a/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py b/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py index f6dfbd5b42..4362a20b2a 100644 --- a/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py +++ b/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py @@ -33,12 +33,10 @@ class VLLMRollingBatch(RollingBatch): """ # TODO: Make properties is the only parameter, after refactoring all rolling batch handlers - def __init__(self, model_id_or_path: str, properties: dict, - **kwargs) -> None: + def __init__(self, properties: dict) -> None: """ Initializes the VLLMRollingBatch. - :param model_id_or_path: Currently unused since there is a copy inside properties :param properties: other properties of the model, such as decoder strategy """ self.vllm_configs = VllmRbProperties(**properties) diff --git a/engines/python/setup/djl_python/tests/rolling_batch/fake_rolling_batch.py b/engines/python/setup/djl_python/tests/rolling_batch/fake_rolling_batch.py index e761ffa1a5..9717d27449 100644 --- a/engines/python/setup/djl_python/tests/rolling_batch/fake_rolling_batch.py +++ b/engines/python/setup/djl_python/tests/rolling_batch/fake_rolling_batch.py @@ -19,11 +19,12 @@ class FakeRollingBatch(RollingBatch): # TODO: Make properties is the only parameter, after refactoring all rolling batch handlers - def __init__(self, model_id_or_path, properties, **kwargs): + def __init__(self, properties): """ Initializes the FakeRollingBatch. """ - super().__init__(**kwargs) + super().__init__(waiting_steps=properties.get("waiting_steps"), + output_formatter=properties.get("output_formatter")) self.sample_text = ( "DJL-Serving is a powerful and user-friendly deep learning model serving solution " "that enables developers to easily deploy and serve their trained deep learning models." @@ -36,9 +37,10 @@ def __init__(self, model_id_or_path, properties, **kwargs): " or a developer, DJL-Serving simplifies the process of serving deep learning models," " enabling you to focus on creating innovative applications with ease." 
) - self.tokenizer = AutoTokenizer.from_pretrained(model_id_or_path, - padding_side="left", - trust_remote_code=True) + self.tokenizer = AutoTokenizer.from_pretrained( + properties.get("model_id") or properties.get("model_dir"), + padding_side="left", + trust_remote_code=True) if not self.tokenizer.pad_token: self.tokenizer.pad_token = self.tokenizer.eos_token self.tokens = self.tokenizer.encode(self.sample_text) @@ -99,8 +101,8 @@ def preprocess_requests(self, requests): class FakeRollingBatchWithException(FakeRollingBatch): - def __init__(self, model_id_or_path, properties, **kwargs): - super().__init__(model_id_or_path, properties, **kwargs) + def __init__(self, properties): + super().__init__(properties) self.dead_counter = 0 self.dead_trigger = random.randint(1, 50) diff --git a/engines/python/setup/djl_python/tests/test_test_model.py b/engines/python/setup/djl_python/tests/test_test_model.py index 434d3c5ecb..dc91b77704 100644 --- a/engines/python/setup/djl_python/tests/test_test_model.py +++ b/engines/python/setup/djl_python/tests/test_test_model.py @@ -16,6 +16,7 @@ from djl_python.test_model import TestHandler from djl_python import huggingface from .rolling_batch.fake_rolling_batch import FakeRollingBatch, FakeRollingBatchWithException +from djl_python.rolling_batch import rolling_batch_service def override_rolling_batch(rolling_batch_type: str, is_mpi: bool, @@ -32,7 +33,7 @@ class TestTestModel(unittest.TestCase): def test_all_code(self): model_id = "NousResearch/Nous-Hermes-Llama2-13b" - huggingface.get_rolling_batch_class_from_str = override_rolling_batch + rolling_batch_service.get_rolling_batch_class_from_str = override_rolling_batch handler = TestHandler(huggingface) inputs = [{ "inputs": "The winner of oscar this year is", @@ -65,7 +66,7 @@ def test_with_env(self): } for key, value in envs.items(): os.environ[key] = value - huggingface.get_rolling_batch_class_from_str = override_rolling_batch + rolling_batch_service.get_rolling_batch_class_from_str = override_rolling_batch handler = TestHandler(huggingface) self.assertEqual(handler.serving_properties["model_id"], envs["OPTION_MODEL_ID"]) @@ -95,7 +96,7 @@ def test_with_env(self): def test_all_code_chat(self): model_id = "TheBloke/Llama-2-7B-Chat-fp16" - huggingface.get_rolling_batch_class_from_str = override_rolling_batch + rolling_batch_service.get_rolling_batch_class_from_str = override_rolling_batch handler = TestHandler(huggingface) inputs = [{ "inputs": @@ -127,7 +128,7 @@ def test_with_env_chat(self): } for key, value in envs.items(): os.environ[key] = value - huggingface.get_rolling_batch_class_from_str = override_rolling_batch + rolling_batch_service.get_rolling_batch_class_from_str = override_rolling_batch handler = TestHandler(huggingface) self.assertEqual(handler.serving_properties["model_id"], envs["OPTION_MODEL_ID"]) @@ -155,7 +156,7 @@ def test_with_env_chat(self): os.environ[key] = "" def test_exception_handling(self): - huggingface.get_rolling_batch_class_from_str = override_rolling_batch_with_exception + rolling_batch_service.get_rolling_batch_class_from_str = override_rolling_batch_with_exception model_id = "NousResearch/Nous-Hermes-Llama2-13b" handler = TestHandler(huggingface) inputs = [{ @@ -202,3 +203,7 @@ def test_exception_handling(self): for _, value in result.items(): final_dict = json.loads(value.splitlines()[-1]) self.assertEqual(final_dict["details"]["finish_reason"], 'error') + + +if __name__ == '__main__': + unittest.main() diff --git a/engines/python/setup/djl_python/utils.py 
b/engines/python/setup/djl_python/utils.py
index 005d6ee640..1e33b063de 100644
--- a/engines/python/setup/djl_python/utils.py
+++ b/engines/python/setup/djl_python/utils.py
@@ -1,6 +1,9 @@
 import logging
 from typing import Union, Callable, Any, List
 
+from peft import PeftConfig
+from transformers import AutoConfig, AutoTokenizer
+
 from djl_python.inputs import Input
 from djl_python.encode_decode import decode
 from djl_python.chat_completions.chat_utils import is_chat_completions_request, parse_chat_completions_request
@@ -147,3 +150,41 @@ def _fetch_adapters_from_input(input_map: dict, inputs: Input):
         adapters_per_item = [adapters_per_item]
 
     return adapters_per_item
+
+
+def get_tokenizer(model_id_or_path: str, trust_remote_code: bool,
+                  revision: str, peft_config):
+    path_to_use = model_id_or_path if peft_config is None else peft_config.base_model_name_or_path
+    return AutoTokenizer.from_pretrained(
+        path_to_use,
+        padding_side="left",
+        trust_remote_code=trust_remote_code,
+        revision=revision,
+    )
+
+
+def read_model_config(model_config_path: str, trust_remote_code: bool,
+                      revision: str):
+    model_config = None
+    peft_config = None
+    try:
+        model_config = AutoConfig.from_pretrained(
+            model_config_path,
+            trust_remote_code=trust_remote_code,
+            revision=revision)
+    except OSError:
+        logging.warning(
+            f"config.json not found for {model_config_path}. Attempting to load with peft"
+        )
+        peft_config = PeftConfig.from_pretrained(model_config_path)
+        model_config = AutoConfig.from_pretrained(
+            peft_config.base_model_name_or_path,
+            trust_remote_code=trust_remote_code,
+            revision=revision,
+        )
+    except Exception as e:
+        logging.error(
+            f"{model_config_path} does not contain a config.json or adapter_config.json for lora models. "
+            f"This is required for loading huggingface models")
+        raise e
+    return model_config, peft_config
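
A quick way to sanity-check the relocated dispatcher without a full serving stack is to call it directly. A minimal sketch, assuming djl_python and the resolved backend's package are importable (resolution imports the backend module lazily, which is what keeps unused engines optional); SimpleNamespace stands in for a transformers AutoConfig, since the dispatcher only reads `architectures`:

from types import SimpleNamespace

from djl_python.rolling_batch.rolling_batch_service import (
    get_rolling_batch_class_from_str)

# Stand-in for AutoConfig: only the `architectures` attribute is consulted.
fake_config = SimpleNamespace(architectures=["LlamaForCausalLM"])

# "auto" resolves to LmiDistRollingBatch only when is_mpi is True and the
# first architecture is in LMI_DIST_ADV_MODEL; otherwise it falls back to
# SchedulerRollingBatch.
cls = get_rolling_batch_class_from_str("auto", True, fake_config)
print(cls.__name__)  # LmiDistRollingBatch

# Explicit names ("scheduler", "lmi-dist", "vllm") map one-to-one; any other
# value raises ValueError.
cls = get_rolling_batch_class_from_str("scheduler", False, fake_config)
print(cls.__name__)  # SchedulerRollingBatch

Either class is then constructed with just the raw properties dict, which is the constructor contract this patch standardizes across backends.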
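
The config/tokenizer helpers that moved into djl_python.utils are also usable on their own. A sketch of the PEFT fallback path; "my-org/my-lora-adapter" is a placeholder for a repo that ships only an adapter_config.json:

from djl_python.utils import get_tokenizer, read_model_config

# For a plain model this returns (AutoConfig, None). For a LoRA adapter,
# AutoConfig raises OSError (no config.json), and the except branch loads
# the PeftConfig plus the base model's config instead.
model_config, peft_config = read_model_config("my-org/my-lora-adapter",
                                              trust_remote_code=False,
                                              revision=None)

# get_tokenizer applies the same indirection: when a peft_config is present,
# the tokenizer is loaded from peft_config.base_model_name_or_path.
tokenizer = get_tokenizer("my-org/my-lora-adapter",
                          trust_remote_code=False,
                          revision=None,
                          peft_config=peft_config)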
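
For unit tests, the updated fakes illustrate the same contract: no model_id_or_path positional, no **kwargs, just the properties dict. A sketch mirroring what the tests in this patch pass through (the model id is taken from test_test_model.py above; the fake still downloads a real tokenizer, so network access is assumed):

from djl_python.tests.rolling_batch.fake_rolling_batch import FakeRollingBatch

# FakeRollingBatch reads model_id/model_dir, waiting_steps and
# output_formatter straight out of the properties dict.
fake = FakeRollingBatch({
    "model_id": "NousResearch/Nous-Hermes-Llama2-13b",
    "waiting_steps": None,
    "output_formatter": None,
})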