From bd8457abe74152a7103e364b6bb222a1eb3781e0 Mon Sep 17 00:00:00 2001 From: Yannick Schnider Date: Thu, 12 Sep 2024 09:53:15 +0200 Subject: [PATCH] Simplified model loading (#32) This PR simplifies model loading by taking advantage of the new functionality of `get_model()` from `fms.models`. The current implementation automatically infers `architecture` and `variant` from a given `model_path` pointing to a directory with weights in **hf** (Hugging Face) format. ### Changes: - replacing `as_fms_model()` with `get_model()` for **hf** models. - removing the `if` condition for **meta** weights. Note: make sure to use the **hf** format of the weights for model **7B-F** (checkpoint trained by Meta) from now on. --- vllm/model_executor/model_loader/sendnn.py | 31 +++------------------- 1 file changed, 3 insertions(+), 28 deletions(-) diff --git a/vllm/model_executor/model_loader/sendnn.py b/vllm/model_executor/model_loader/sendnn.py index c71ab289b..c6c66ef5e 100644 --- a/vllm/model_executor/model_loader/sendnn.py +++ b/vllm/model_executor/model_loader/sendnn.py @@ -182,34 +182,9 @@ def sample( def load_weights(self, model_name_or_path: str, device_type, max_prompt_length, max_decode_length, **kwargs): - # check model source: hf or meta - files = os.listdir(model_name_or_path) - model_source = 'hf' - # default huggingface, but if .pth file in model directory, then it is meta weights - for f in files: - if f.endswith('.pth'): - model_source = 'meta' - break - - if model_source == 'hf': # hugging face - # load hugging face model - self.model = as_fms_model(model_name_or_path) - else: # meta - variant = "7b" - architecture = "llama" - distr_param = None - - # Load the weights from the cached or downloaded files.
- self.model = get_model( - architecture=architecture, - variant=variant, - model_path=model_name_or_path, - source=model_source, - device_type=device_type, - distributed_strategy=distr_param, - group=dist.group.WORLD, - ) - + # function will infer architecture and variant for hf models based on model_name_or_path + self.model = get_model("hf_pretrained", model_name_or_path) + compile_mode = "default" dynamo_backend = DYN_BACKEND