From 61876e211f3973d0ce11dd604d9c1650c53e1e37 Mon Sep 17 00:00:00 2001
From: Jambay Kinley
Date: Mon, 11 Nov 2024 23:00:55 -0800
Subject: [PATCH] AWQ: Patch for mismatched devices in RotaryEmbedding (#1480)

---
 olive/passes/pytorch/autoawq.py | 52 ++++++++++++++++++++++++---------
 1 file changed, 38 insertions(+), 14 deletions(-)

diff --git a/olive/passes/pytorch/autoawq.py b/olive/passes/pytorch/autoawq.py
index b295bd980..12cf86fb6 100644
--- a/olive/passes/pytorch/autoawq.py
+++ b/olive/passes/pytorch/autoawq.py
@@ -7,11 +7,12 @@
 from typing import Any, Dict, Union
 
 import torch
+from packaging import version
 
-from olive.common.utils import StrEnumBase
+from olive.common.utils import StrEnumBase, get_attr
 from olive.data.config import DataConfig
 from olive.hardware.accelerator import AcceleratorSpec
-from olive.model import HfModelHandler, PyTorchModelHandler
+from olive.model import HfModelHandler
 from olive.passes import Pass
 from olive.passes.pass_config import PassConfigParam, get_user_script_data_config
 from olive.passes.pytorch.common import inherit_hf_from_hf
@@ -107,9 +108,7 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon
         }
 
     @torch.no_grad()
-    def _run_for_config(
-        self, model: HfModelHandler, config: Dict[str, Any], output_model_path: str
-    ) -> Union[HfModelHandler, PyTorchModelHandler]:
+    def _run_for_config(self, model: HfModelHandler, config: Dict[str, Any], output_model_path: str) -> HfModelHandler:
         from awq import AutoAWQForCausalLM
 
         if not torch.cuda.is_available():
@@ -139,6 +138,7 @@ def _run_for_config(
         awq_model = AutoAWQForCausalLM.from_pretrained(
             model.model_name_or_path, **self._resolve_load_args(model.get_load_kwargs())
         )
+        awq_model = self._maybe_patch_awq_model(awq_model)
         tokenizer = model.get_hf_tokenizer()
 
         # quantize the model
@@ -167,12 +167,36 @@ def _run_for_config(
         new_load_kwargs["extra_args"]["use_safetensors"] = True
         return inherit_hf_from_hf(model, output_model_path, adapter_path=adapter_path, load_kwargs=new_load_kwargs)
 
-    def _resolve_load_args(self, hf_loading_args):
-        loading_args = {}
-        # default value for `safetensors` is True in auto AWQ
-        loading_args["safetensors"] = hf_loading_args.get("use_safetensors", True)
-        if device_map := hf_loading_args.get("device_map"):
-            loading_args["device_map"] = device_map
-        if trust_remote_code := hf_loading_args.get("trust_remote_code"):
-            loading_args["trust_remote_code"] = trust_remote_code
-        return loading_args
+    def _resolve_load_args(self, hf_loading_args: Dict[str, Any]) -> Dict[str, Any]:
+        return {
+            # default to using safetensors, as AutoAWQ does
+            "safetensors": hf_loading_args.get("use_safetensors", True),
+            # only trust remote code if the user has explicitly set it
+            "trust_remote_code": hf_loading_args.get("trust_remote_code"),
+            # not much to be gained by using the "auto" device map, so default to None
+            "device_map": hf_loading_args.get("device_map"),
+        }
+
+    def _maybe_patch_awq_model(self, awq_model):
+        from awq import __version__ as autoawq_version
+        from transformers import __version__ as transformers_version
+
+        # https://github.com/huggingface/transformers/issues/32420
+        # transformers has an issue where some rotary embedding tensors stay on the CPU,
+        # causing a device mismatch error
+        # cap the autoawq version in case a fix is released
+        # transformers releases too frequently for us to track every affected version
+        if version.parse(transformers_version) >= version.parse("4.43") and version.parse(
+            autoawq_version
+        ) <= version.parse("0.2.6"):
+            original_move_embed = awq_model.move_embed
+
+            def new_move_embed(model, device):
+                original_move_embed(model, "cuda")
+                # almost all model types keep the rotary embedding at model.model.rotary_emb, so we don't keep a mapping
+                if rotary_embed_module := get_attr(model, "model.rotary_emb"):
+                    rotary_embed_module.to(device)
+
+            awq_model.move_embed = new_move_embed
+
+        return awq_model
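Note on the version gate: the patch only activates for the affected combination of library versions, compared with `packaging.version` so that pre-release and patch suffixes compare correctly. A quick illustration (the version strings below are example values, not requirements from the patch):

```python
# Sketch of the version gate used in _maybe_patch_awq_model; the version
# strings are example values standing in for the installed packages.
from packaging import version

transformers_version = "4.46.2"  # example value of transformers.__version__
autoawq_version = "0.2.6"  # example value of awq.__version__

needs_patch = version.parse(transformers_version) >= version.parse("4.43") and version.parse(
    autoawq_version
) <= version.parse("0.2.6")
print(needs_patch)  # True: 4.46.2 >= 4.43 and 0.2.6 <= 0.2.6
```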
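The core of the fix is a wrap-and-delegate monkey patch on `move_embed`: the original AutoAWQ method still moves the token embeddings to CUDA, and the wrapper then moves the rotary embedding module to the caller's requested device. A minimal, self-contained sketch of that pattern, where `DummyModel` and `DummyRotary` are hypothetical stand-ins for the real transformers modules:

```python
# Sketch of the wrap-and-delegate pattern applied to move_embed above.
# DummyModel and DummyRotary are hypothetical stand-ins; the real patch wraps
# awq_model.move_embed and locates the rotary module via get_attr.
class DummyRotary:
    device = "cpu"

    def to(self, device):
        self.device = device


class DummyModel:
    def __init__(self):
        self.rotary_emb = DummyRotary()
        self.embed_device = None


def original_move_embed(model, device):
    # stands in for AutoAWQ's move_embed, which only moves the token embeddings
    model.embed_device = device


def new_move_embed(model, device):
    original_move_embed(model, "cuda")  # embeddings always go to the GPU
    model.rotary_emb.to(device)  # the rotary embedding follows the caller's device


model = DummyModel()
new_move_embed(model, "cpu")
print(model.embed_device, model.rotary_emb.device)  # -> cuda cpu
```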
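The `get_attr(model, "model.rotary_emb")` call resolves a dotted attribute path against the model, which is why the walrus check safely skips architectures without a rotary embedding at that location. A rough, hypothetical equivalent for readers unfamiliar with the helper (the real `olive.common.utils.get_attr` may differ in signature and error handling):

```python
# Rough, hypothetical equivalent of olive.common.utils.get_attr; the real
# helper's behavior on missing attributes may differ.
from functools import reduce


def get_attr(obj, dotted_name, default=None):
    """Resolve a dotted path like "model.rotary_emb" against obj, or return default."""
    try:
        return reduce(getattr, dotted_name.split("."), obj)
    except AttributeError:
        return default
```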