Skip to content

Commit

Permalink
[lmi] remove redundant auto logic from python handler (#2152)
Browse files (browse the repository at this point in the history)
  • Loading branch information
siddvenk authored Jul 10, 2024
1 parent 36d0212 commit ecdc519
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 27 deletions.
26 changes: 3 additions & 23 deletions engines/python/setup/djl_python/huggingface.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -53,16 +53,6 @@
"BloomModel": "text-generation",
}

LMI_DIST_ADV_MODEL = {
"RWForCausalLM",
"GPTNeoXForCausalLM",
"T5ForConditionalGeneration",
"LlamaForCausalLM",
"FalconForCausalLM",
"MPTForCausalLM",
"GPTBigCodeForCausalLM",
}

# https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#efficient-inference-on-a-single-gpu
FLASH_2_SUPPORTED_MODELS = {
"LlamaForCausalLM", "RWForCausalLM", "FalconForCausalLM"
Expand All @@ -85,17 +75,8 @@ def enable_flash():
return False


def get_rolling_batch_class_from_str(rolling_batch_type: str, is_mpi: bool,
model_config):
if rolling_batch_type == "auto":
architecture = model_config.architectures[0]
if architecture in LMI_DIST_ADV_MODEL and is_mpi:
from djl_python.rolling_batch.lmi_dist_rolling_batch import LmiDistRollingBatch
return LmiDistRollingBatch
else:
from djl_python.rolling_batch.scheduler_rolling_batch import SchedulerRollingBatch
return SchedulerRollingBatch
elif rolling_batch_type == "scheduler":
def get_rolling_batch_class_from_str(rolling_batch_type: str):
if rolling_batch_type == "scheduler":
from djl_python.rolling_batch.scheduler_rolling_batch import SchedulerRollingBatch
return SchedulerRollingBatch
elif rolling_batch_type == "lmi-dist":
Expand Down Expand Up @@ -149,8 +130,7 @@ def initialize(self, properties: dict):

if is_rolling_batch_enabled(self.hf_configs.rolling_batch):
_rolling_batch_cls = get_rolling_batch_class_from_str(
self.hf_configs.rolling_batch.value, self.hf_configs.mpi_mode,
self.model_config)
self.hf_configs.rolling_batch.value)
self.hf_configs.kwargs["model_config"] = self.model_config
self.rolling_batch = _rolling_batch_cls(
self.hf_configs.model_id_or_path, properties,
Expand Down
6 changes: 2 additions & 4 deletions engines/python/setup/djl_python/tests/test_test_model.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -19,14 +19,12 @@
from djl_python import huggingface


def override_rolling_batch(rolling_batch_type: str, is_mpi: bool,
model_config):
def override_rolling_batch(rolling_batch_type: str):
from djl_python.tests.rolling_batch.fake_rolling_batch import FakeRollingBatch
return FakeRollingBatch


def override_rolling_batch_with_exception(rolling_batch_type: str,
is_mpi: bool, model_config):
def override_rolling_batch_with_exception(rolling_batch_type: str):
from djl_python.tests.rolling_batch.fake_rolling_batch import FakeRollingBatchWithException
return FakeRollingBatchWithException

Expand Down

0 comments on commit ecdc519

Please sign in to comment.