AWQ: Patch for mismatched devices in RotaryEmbedding (#1480)
jambayk authored Nov 12, 2024
1 parent b740597 commit 61876e2
Showing 1 changed file with 38 additions and 14 deletions.
52 changes: 38 additions & 14 deletions olive/passes/pytorch/autoawq.py
@@ -7,11 +7,12 @@
 from typing import Any, Dict, Union
 
 import torch
+from packaging import version
 
-from olive.common.utils import StrEnumBase
+from olive.common.utils import StrEnumBase, get_attr
 from olive.data.config import DataConfig
 from olive.hardware.accelerator import AcceleratorSpec
-from olive.model import HfModelHandler, PyTorchModelHandler
+from olive.model import HfModelHandler
 from olive.passes import Pass
 from olive.passes.pass_config import PassConfigParam, get_user_script_data_config
 from olive.passes.pytorch.common import inherit_hf_from_hf
@@ -107,9 +108,7 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassConfigParam]:
         }
 
     @torch.no_grad()
-    def _run_for_config(
-        self, model: HfModelHandler, config: Dict[str, Any], output_model_path: str
-    ) -> Union[HfModelHandler, PyTorchModelHandler]:
+    def _run_for_config(self, model: HfModelHandler, config: Dict[str, Any], output_model_path: str) -> HfModelHandler:
         from awq import AutoAWQForCausalLM
 
         if not torch.cuda.is_available():
@@ -139,6 +138,7 @@ def _run_for_config(
         awq_model = AutoAWQForCausalLM.from_pretrained(
             model.model_name_or_path, **self._resolve_load_args(model.get_load_kwargs())
         )
+        awq_model = self._maybe_patch_awq_model(awq_model)
         tokenizer = model.get_hf_tokenizer()
 
         # quantize the model
@@ -167,12 +167,36 @@ def _run_for_config(
         new_load_kwargs["extra_args"]["use_safetensors"] = True
         return inherit_hf_from_hf(model, output_model_path, adapter_path=adapter_path, load_kwargs=new_load_kwargs)
 
-    def _resolve_load_args(self, hf_loading_args):
-        loading_args = {}
-        # default value for `safetensors` is True in AutoAWQ
-        loading_args["safetensors"] = hf_loading_args.get("use_safetensors", True)
-        if device_map := hf_loading_args.get("device_map"):
-            loading_args["device_map"] = device_map
-        if trust_remote_code := hf_loading_args.get("trust_remote_code"):
-            loading_args["trust_remote_code"] = trust_remote_code
-        return loading_args
+    def _resolve_load_args(self, hf_loading_args: Dict[str, Any]):
+        return {
+            # want to default to using safetensors, like in AutoAWQ
+            "safetensors": hf_loading_args.get("use_safetensors", True),
+            # only trust remote code if the user has explicitly set it
+            "trust_remote_code": hf_loading_args.get("trust_remote_code"),
+            # not much to be gained by using the "auto" device map, so default to None
+            "device_map": hf_loading_args.get("device_map"),
+        }
+
+    def _maybe_patch_awq_model(self, awq_model):
+        from awq import __version__ as autoawq_version
+        from transformers import __version__ as transformers_version
+
+        # https://github.com/huggingface/transformers/issues/32420
+        # there is an issue in transformers with the rotary embedding where some tensors are left on the CPU,
+        # causing a device mismatch error
+        # upper limit on the awq version in case a fix is released
+        # transformers releases too frequently, so we can't keep track of all versions
+        if version.parse(transformers_version) >= version.parse("4.43") and version.parse(
+            autoawq_version
+        ) <= version.parse("0.2.6"):
+            original_move_embed = awq_model.move_embed
+
+            def new_move_embed(model, device):
+                original_move_embed(model, "cuda")
+                # almost all model types keep the rotary embedding at model.model.rotary_emb, so won't keep a mapping
+                if rotary_embed_module := get_attr(model, "model.rotary_emb"):
+                    rotary_embed_module.to(device)
+
+            awq_model.move_embed = new_move_embed
+
+        return awq_model
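
For context, the fix works by monkey-patching AutoAWQ's move_embed on the loaded model instance: after the embeddings are moved, the shared rotary embedding module (kept at model.model.rotary_emb in transformers >= 4.43) is also moved to the requested device. Below is a minimal, self-contained sketch of that pattern; FakeInnerModel, FakeHfModel, FakeAwqModel, and patch_move_embed are hypothetical stand-ins for illustration only, not AutoAWQ or transformers classes.

import torch
from torch import nn


class FakeInnerModel(nn.Module):
    """Stand-in for a transformers >= 4.43 decoder that keeps a shared rotary embedding module."""

    def __init__(self):
        super().__init__()
        self.embed_tokens = nn.Embedding(16, 4)
        self.rotary_emb = nn.Linear(4, 4)  # placeholder for the real rotary embedding module


class FakeHfModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = FakeInnerModel()


class FakeAwqModel:
    """Stand-in for an AutoAWQ wrapper whose move_embed only moves the token embeddings."""

    @staticmethod
    def move_embed(model, device):
        model.model.embed_tokens.to(device)


def patch_move_embed(awq_model):
    """Wrap move_embed so the rotary embedding follows the embeddings to the target device."""
    original_move_embed = awq_model.move_embed

    def new_move_embed(model, device):
        original_move_embed(model, device)
        rotary = getattr(model.model, "rotary_emb", None)
        if rotary is not None:
            rotary.to(device)  # avoids the CPU/CUDA mismatch during quantization

    awq_model.move_embed = new_move_embed
    return awq_model


if __name__ == "__main__":
    awq_model = patch_move_embed(FakeAwqModel())
    hf_model = FakeHfModel()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    awq_model.move_embed(hf_model, device)
    print(next(hf_model.model.rotary_emb.parameters()).device)  # same device as the embeddings

Note that the commit itself hard-codes "cuda" for the original move_embed call (AutoAWQ quantization requires a GPU) and only moves the rotary embedding to the caller-supplied device; the sketch passes the device through in both places for simplicity.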
