Skip to content

Commit

Permalink
[lmi] remove redundant auto logic from python handler
Browse files Browse the repository at this point in the history
  • Loading branch information
siddvenk committed Jul 9, 2024
1 parent b6bc5ca commit 6546d43
Showing 1 changed file with 3 additions and 23 deletions.
26 changes: 3 additions & 23 deletions engines/python/setup/djl_python/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,16 +53,6 @@
"BloomModel": "text-generation",
}

LMI_DIST_ADV_MODEL = {
"RWForCausalLM",
"GPTNeoXForCausalLM",
"T5ForConditionalGeneration",
"LlamaForCausalLM",
"FalconForCausalLM",
"MPTForCausalLM",
"GPTBigCodeForCausalLM",
}

# https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#efficient-inference-on-a-single-gpu
FLASH_2_SUPPORTED_MODELS = {
"LlamaForCausalLM", "RWForCausalLM", "FalconForCausalLM"
Expand All @@ -85,17 +75,8 @@ def enable_flash():
return False


def get_rolling_batch_class_from_str(rolling_batch_type: str, is_mpi: bool,
model_config):
if rolling_batch_type == "auto":
architecture = model_config.architectures[0]
if architecture in LMI_DIST_ADV_MODEL and is_mpi:
from djl_python.rolling_batch.lmi_dist_rolling_batch import LmiDistRollingBatch
return LmiDistRollingBatch
else:
from djl_python.rolling_batch.scheduler_rolling_batch import SchedulerRollingBatch
return SchedulerRollingBatch
elif rolling_batch_type == "scheduler":
def get_rolling_batch_class_from_str(rolling_batch_type: str):
if rolling_batch_type == "scheduler":
from djl_python.rolling_batch.scheduler_rolling_batch import SchedulerRollingBatch
return SchedulerRollingBatch
elif rolling_batch_type == "lmi-dist":
Expand Down Expand Up @@ -149,8 +130,7 @@ def initialize(self, properties: dict):

if is_rolling_batch_enabled(self.hf_configs.rolling_batch):
_rolling_batch_cls = get_rolling_batch_class_from_str(
self.hf_configs.rolling_batch.value, self.hf_configs.mpi_mode,
self.model_config)
self.hf_configs.rolling_batch.value)
self.hf_configs.kwargs["model_config"] = self.model_config
self.rolling_batch = _rolling_batch_cls(
self.hf_configs.model_id_or_path, properties,
Expand Down

0 comments on commit 6546d43

Please sign in to comment.