From abb124e477a79f3de316c3920dad315bd5bb1560 Mon Sep 17 00:00:00 2001 From: Somasundaram Date: Wed, 8 May 2024 16:03:20 -0700 Subject: [PATCH] change hf config --- .../rolling_batch/rolling_batch_service.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/engines/python/setup/djl_python/rolling_batch/rolling_batch_service.py b/engines/python/setup/djl_python/rolling_batch/rolling_batch_service.py index be2df4371f..17ac5055dc 100644 --- a/engines/python/setup/djl_python/rolling_batch/rolling_batch_service.py +++ b/engines/python/setup/djl_python/rolling_batch/rolling_batch_service.py @@ -58,28 +58,28 @@ def __init__(self): self.initialized = False self.adapters = None self.adapter_registry = {} - self.rb_configs = None + self.hf_configs = None self.input_format_configs = None def initialize(self, properties: dict): - self.rb_configs = HuggingFaceProperties(**properties) + self.hf_configs = HuggingFaceProperties(**properties) self.model_config, self.peft_config = read_model_config( - self.rb_configs.model_id_or_path, - self.rb_configs.trust_remote_code, self.rb_configs.revision) + self.hf_configs.model_id_or_path, + self.hf_configs.trust_remote_code, self.hf_configs.revision) _rolling_batch_cls = get_rolling_batch_class_from_str( - self.rb_configs.rolling_batch.value, self.rb_configs.is_mpi, + self.hf_configs.rolling_batch.value, self.hf_configs.is_mpi, self.model_config) - self.rb_configs.kwargs["model_config"] = self.model_config + self.hf_configs.kwargs["model_config"] = self.model_config self.rolling_batch = _rolling_batch_cls(properties) - self.tokenizer = get_tokenizer(self.rb_configs.model_id_or_path, - self.rb_configs.trust_remote_code, - self.rb_configs.revision, + self.tokenizer = get_tokenizer(self.hf_configs.model_id_or_path, + self.hf_configs.trust_remote_code, + self.hf_configs.revision, peft_config=self.peft_config) self.input_format_configs = InputFormatConfigs( is_rolling_batch=True, is_adapters_supported=True, tokenizer=self.tokenizer, - output_formatter=self.rb_configs.output_formatter) + output_formatter=self.hf_configs.output_formatter) self.initialized = True def parse_input(