Refactor huggingface handler
sindhuvahinis committed May 9, 2024
1 parent 1859b7e commit 144b409
Showing 9 changed files with 254 additions and 161 deletions.
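
In short, this change splits the handler along the rolling-batch seam: backend selection, adapter lookup, and rolling-batch inference (with its binary-encoded error payloads) move out of HuggingFaceService into a new RollingBatchService, which handle() now instantiates whenever the rolling_batch property is anything other than disable. Model-config and tokenizer loading are delegated to the read_model_config and get_tokenizer helpers in djl_python.utils, Properties now retains undeclared extra keys, and LmiDistRollingBatch drops its unused model_id_or_path constructor argument.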
163 changes: 36 additions & 127 deletions engines/python/setup/djl_python/huggingface.py
@@ -15,22 +15,25 @@
import re

import torch
from transformers import (
pipeline, Pipeline, Conversation, AutoModelForCausalLM,
AutoModelForSeq2SeqLM, AutoTokenizer, AutoConfig,
AutoModelForSequenceClassification, AutoModelForTokenClassification,
AutoModelForQuestionAnswering, StoppingCriteria, StoppingCriteriaList)
from transformers import (pipeline, Pipeline, Conversation,
AutoModelForCausalLM, AutoModelForSeq2SeqLM,
AutoTokenizer, AutoModelForSequenceClassification,
AutoModelForTokenClassification,
AutoModelForQuestionAnswering, StoppingCriteria,
StoppingCriteriaList)
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from peft import PeftConfig, PeftModel, PeftModelForCausalLM
from peft import PeftModel, PeftModelForCausalLM

from djl_python.encode_decode import encode
from djl_python.inputs import Input
from djl_python.outputs import Output
from djl_python.streaming_utils import StreamingUtils
from djl_python.rolling_batch.rolling_batch import get_content_type_from_output_formatter

from djl_python.properties_manager.properties import StreamingEnum, is_rolling_batch_enabled, is_streaming_enabled
from djl_python.properties_manager.properties import StreamingEnum, is_rolling_batch_enabled, is_streaming_enabled, \
RollingBatchEnum
from djl_python.properties_manager.hf_properties import HuggingFaceProperties
from djl_python.utils import read_model_config, get_tokenizer
from djl_python.rolling_batch.rolling_batch_service import RollingBatchService
from djl_python.utils import parse_input_with_formatter, InputFormatConfigs

ARCHITECTURES_2_TASK = {
@@ -82,28 +85,6 @@ def enable_flash():
return False


def get_rolling_batch_class_from_str(rolling_batch_type: str, is_mpi: bool,
model_config):
if rolling_batch_type == "auto":
architecture = model_config.architectures[0]
if architecture in LMI_DIST_ADV_MODEL and is_mpi:
from djl_python.rolling_batch.lmi_dist_rolling_batch import LmiDistRollingBatch
return LmiDistRollingBatch
else:
from djl_python.rolling_batch.scheduler_rolling_batch import SchedulerRollingBatch
return SchedulerRollingBatch
elif rolling_batch_type == "scheduler":
from djl_python.rolling_batch.scheduler_rolling_batch import SchedulerRollingBatch
return SchedulerRollingBatch
elif rolling_batch_type == "lmi-dist":
from djl_python.rolling_batch.lmi_dist_rolling_batch import LmiDistRollingBatch
return LmiDistRollingBatch
elif rolling_batch_type == "vllm":
from djl_python.rolling_batch.vllm_rolling_batch import VLLMRollingBatch
return VLLMRollingBatch
raise ValueError(f"Invalid rolling batch type: {rolling_batch_type}")


class StopWord(StoppingCriteria):

def __init__(self, tokenizer, stop_seq):
@@ -132,7 +113,6 @@ def __init__(self):
self.initialized = False
self.model = None
self.tokenizer = None
self.rolling_batch = None
self.model_config = None
self.peft_config = None
self.stopping_criteria_list = None
@@ -143,19 +123,17 @@

def initialize(self, properties: dict):
self.hf_configs = HuggingFaceProperties(**properties)
self._read_model_config(self.hf_configs.model_id_or_path)

if is_rolling_batch_enabled(self.hf_configs.rolling_batch):
_rolling_batch_cls = get_rolling_batch_class_from_str(
self.hf_configs.rolling_batch.value, self.hf_configs.is_mpi,
self.model_config)
self.hf_configs.kwargs["model_config"] = self.model_config
self.rolling_batch = _rolling_batch_cls(
self.hf_configs.model_id_or_path, properties,
**self.hf_configs.kwargs)
self._init_tokenizer(self.hf_configs.model_id_or_path)
elif is_streaming_enabled(self.hf_configs.enable_streaming):
self._init_tokenizer(self.hf_configs.model_id_or_path)
self.model_config, self.peft_config = read_model_config(
self.hf_configs.model_id_or_path,
trust_remote_code=self.hf_configs.trust_remote_code,
revision=self.hf_configs.revision)

if is_streaming_enabled(self.hf_configs.enable_streaming):
self.tokenizer = get_tokenizer(
self.hf_configs.model_id_or_path,
trust_remote_code=self.hf_configs.trust_remote_code,
revision=self.hf_configs.revision,
peft_config=self.peft_config)
self._init_model(self.hf_configs.model_id_or_path,
**self.hf_configs.kwargs)
else:
Expand All @@ -172,8 +150,7 @@ def initialize(self, properties: dict):
self.load_stopping_criteria_list(properties["stop_sequence"])

self.input_format_configs = InputFormatConfigs(
is_rolling_batch=is_rolling_batch_enabled(
self.hf_configs.rolling_batch),
is_rolling_batch=False,
is_adapters_supported=True,
output_formatter=self.hf_configs.output_formatter,
tokenizer=self.tokenizer)
@@ -243,56 +220,10 @@ def inference(self, inputs):
if len(input_data) == 0:
for i in range(len(batch)):
err = errors.get(i)
if is_rolling_batch_enabled(self.hf_configs.rolling_batch):
err = {"data": "", "last": True, "code": 424, "error": err}
outputs.add(Output.binary_encode(err),
key="data",
batch_index=i)
else:
outputs.add(err, key="data", batch_index=i)
outputs.add(err, key="data", batch_index=i)
return outputs

if is_rolling_batch_enabled(self.hf_configs.rolling_batch):
if inputs.get_property("reset_rollingbatch"):
self.rolling_batch.reset()
if self.adapters is not None:
adapter_data = []
for i, a in enumerate(self.adapters):
if a is None or a == "":
adapter_data.append(None)
elif a in self.adapter_registry:
adapter_data.append(self.adapter_registry[a])
else:
adapter_data.append(None)
errors[i] = f"Unknown or invalid adapter {a}"
else:
adapter_data = None
result = self.rolling_batch.inference(input_data,
parameters,
adapters=adapter_data)
idx = 0
for i in range(len(batch)):
err = errors.get(i)
if err:
err = {"data": "", "last": True, "code": 424, "error": err}
outputs.add(Output.binary_encode(err),
key="data",
batch_index=i)
else:
outputs.add(Output.binary_encode(result[idx]),
key="data",
batch_index=i)
idx += 1

formatter = parameters[i].get("output_formatter")
content_type = get_content_type_from_output_formatter(
formatter)
if content_type is not None:
outputs.add_property(f"batch_{i}_Content-Type",
content_type)

return outputs
elif is_streaming_enabled(self.hf_configs.enable_streaming):
if is_streaming_enabled(self.hf_configs.enable_streaming):
if len(batch) > 1:
raise NotImplementedError(
"Dynamic batch not supported for generic streaming")
@@ -401,7 +332,11 @@ def get_pipeline(self, task: str, model_id_or_path: str, kwargs):
**kwargs)
self.model = hf_pipeline.model
else:
self._init_tokenizer(model_id_or_path)
self.tokenizer = get_tokenizer(
model_id_or_path,
trust_remote_code=self.hf_configs.trust_remote_code,
revision=self.hf_configs.revision,
peft_config=self.peft_config)
self._init_model(model_id_or_path, **kwargs)
hf_pipeline = pipeline(task=task,
model=self.model,
@@ -426,15 +361,6 @@ def get_pipeline(self, task: str, model_id_or_path: str, kwargs):

return hf_pipeline

def _init_tokenizer(self, model_id_or_path: str):
path_to_use = model_id_or_path if self.peft_config is None else self.peft_config.base_model_name_or_path
self.tokenizer = AutoTokenizer.from_pretrained(
path_to_use,
padding_size="left",
trust_remote_code=self.hf_configs.trust_remote_code,
revision=self.hf_configs.revision,
)

def _init_model(self, model_id_or_path: str, **kwargs):
architectures = self.model_config.architectures
if architectures and architectures[0].endswith(
@@ -523,28 +449,6 @@ def infer_task_from_model_architecture(self):
)
return task

def _read_model_config(self, model_config_path: str):
try:
self.model_config = AutoConfig.from_pretrained(
model_config_path,
trust_remote_code=self.hf_configs.trust_remote_code,
revision=self.hf_configs.revision)
except OSError:
logging.warning(
f"config.json not found for {model_config_path}. Attempting to load with peft"
)
self.peft_config = PeftConfig.from_pretrained(model_config_path)
self.model_config = AutoConfig.from_pretrained(
self.peft_config.base_model_name_or_path,
trust_remote_code=self.hf_configs.trust_remote_code,
revision=self.hf_configs.revision,
)
except Exception as e:
logging.error(
f"{model_config_path} does not contain a config.json or adapter_config.json for lora models. "
f"This is required for loading huggingface models")
raise e


_service = HuggingFaceService()

@@ -593,9 +497,14 @@ def handle(inputs: Input):
"""
Default handler function
"""
global _service
if not _service.initialized:
properties = inputs.get_properties()
if "rolling_batch" in properties and properties.get(
"rolling_batch") != RollingBatchEnum.disable.value:
_service = RollingBatchService()
# stateful model
_service.initialize(inputs.get_properties())
_service.initialize(properties)

if inputs.is_empty():
# initialization request
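The pivotal structural move in this file is visible in handle() above: the module-level service singleton is swapped for a RollingBatchService before its first initialize() whenever the rolling_batch property is anything other than disable. A self-contained toy of that swap pattern (EchoService and BatchService are illustrative stand-ins, not repo classes):

# Toy sketch of the handler-swap pattern: a module-level default service is
# replaced before its first initialize() when a property opts into batching.
class EchoService:

    def __init__(self):
        self.initialized = False

    def initialize(self, properties: dict):
        self.initialized = True

    def inference(self, data: str) -> str:
        return f"echo:{data}"


class BatchService(EchoService):

    def inference(self, data: str) -> str:
        return f"batched:{data}"


_service = EchoService()


def handle(properties: dict, data: str) -> str:
    global _service
    if not _service.initialized:
        # The swap happens exactly once, before first initialization.
        if properties.get("rolling_batch", "disable") != "disable":
            _service = BatchService()
        _service.initialize(properties)
    return _service.inference(data)


print(handle({"rolling_batch": "vllm"}, "hi"))  # -> batched:hi
print(handle({}, "hi"))  # still batched:hi; the choice is sticky once initialized

Because the selection is made before initialize() and never revisited, the two services only need to agree on the initialize/inference surface, which is what lets the rolling-batch logic leave HuggingFaceService wholesale.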
engines/python/setup/djl_python/properties_manager/properties.py
@@ -14,7 +14,7 @@
import os
from enum import Enum
from typing import Optional, Union, Callable, Any
from pydantic import BaseModel, field_validator, model_validator, ValidationInfo, ConfigDict
from pydantic import BaseModel, field_validator, model_validator, ValidationInfo, Field, Extra, ConfigDict


class RollingBatchEnum(str, Enum):
@@ -64,7 +64,8 @@ class Properties(BaseModel):

# model_config is for pydantic configurations for BaseModel.
model_config = ConfigDict(arbitrary_types_allowed=True,
protected_namespaces=())
protected_namespaces=(),
extra='allow')

@model_validator(mode='before')
def calculate_is_mpi(cls, properties):
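A small but load-bearing change here: extra='allow' makes Properties keep undeclared keys instead of silently dropping them (pydantic's default is to ignore extras), presumably so a single properties dict can flow through to the downstream config classes intact. A standalone pydantic v2 illustration, not code from the repo:

from pydantic import BaseModel, ConfigDict

class Example(BaseModel):
    # Mirrors the Properties settings: keep unknown keys, and allow
    # field names that start with "model_".
    model_config = ConfigDict(extra='allow', protected_namespaces=())
    model_id: str

p = Example(model_id="gpt2", rolling_batch="vllm")  # undeclared key accepted
print(p.rolling_batch)  # -> vllm (stored in p.__pydantic_extra__)
print(p.model_dump())   # -> {'model_id': 'gpt2', 'rolling_batch': 'vllm'}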
engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py
@@ -38,11 +38,10 @@ class LmiDistRollingBatch(RollingBatch):
It also gets any new tokens from the backend and sends them back to the handler.
"""

def __init__(self, model_id_or_path: str, properties: dict, **kwargs):
def __init__(self, properties: dict):
"""
Initializes the LmiDistRollingBatch.
:param model_id_or_path (str): Currently unused since there is a copy inside properties
:param properties (dict): other properties of the model, such as decoder strategy
"""
self.lmi_dist_config = LmiDistRbProperties(**properties)
@@ -84,8 +83,6 @@ def __init__(self, model_id_or_path: str, properties: dict, **kwargs):
kwargs["warmup_prefill_tokens"] = _WARMUP_PREFILL_TOKENS
self.engine = engine_from_args(engine_args, **kwargs)
self.request_cache = OrderedDict()
self.model_type = getattr(kwargs.get("model_config", None),
"model_type", None)
self.lora_ids = defaultdict(lambda: len(self.lora_ids) + 1)

def reset(self) -> None:
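One idiom retained in this constructor deserves a note: lora_ids is a defaultdict whose factory assigns the next integer id the first time an adapter name is looked up, so adapter names map to stable, incrementing ids with no explicit bookkeeping. A standalone illustration:

from collections import defaultdict

# Factory returns current size + 1, so each new key gets the next id.
lora_ids = defaultdict(lambda: len(lora_ids) + 1)
print(lora_ids["adapter_a"])  # -> 1, assigned on first access
print(lora_ids["adapter_b"])  # -> 2
print(lora_ids["adapter_a"])  # -> 1, repeat lookups are stable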
(Diffs for the remaining changed files were not loaded.)
