Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
Signed-off-by: yan ma <yan.ma@intel.com>
  • Loading branch information
yma11 committed Nov 8, 2024
1 parent 0a91e6d commit 9ef0bd0
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 27 deletions.
5 changes: 2 additions & 3 deletions tests/quantization/test_ipex_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,8 @@
from vllm.platforms import current_platform

MODELS = [
"casperhansen/llama-3-8b-instruct-awq",
"TechxGenus/Meta-Llama-3-8B-GPTQ", # with g_idx
"TheBloke/Llama-2-7B-GPTQ", # w/o g_idx
"AMead10/Llama-3.2-1B-Instruct-AWQ",
"shuyuej/Llama-3.2-1B-Instruct-GPTQ", # with g_idx
]
DTYPE = ["bfloat16"]

Expand Down
44 changes: 20 additions & 24 deletions vllm/model_executor/layers/quantization/ipex_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

import torch

from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.awq import AWQLinearMethod
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization.awq import (AWQLinearMethod,
is_layer_skipped_awq)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
Expand All @@ -25,12 +27,14 @@ def __init__(
method: str,
weight_bits: int,
group_size: int,
modules_to_not_convert: Optional[List[str]] = None,
desc_act: Optional[bool] = None,
lm_head_quantized: Optional[bool] = None,
) -> None:
self.method = method
self.weight_bits = weight_bits
self.group_size = group_size
self.modules_to_not_convert = modules_to_not_convert or []
self.desc_act = desc_act
self.lm_head_quantized = lm_head_quantized
self.pack_factor = 32 // self.weight_bits
Expand All @@ -44,12 +48,9 @@ def __init__(
f"but got {self.method}.")

def __repr__(self) -> str:
return (f"IPEXConfig(method={self.method}"
return (f"IPEXConfig(method={self.method},"
f"weight_bits={self.weight_bits}, "
f"group_size={self.group_size}")

def get_ipex_quant_method_id(self) -> int:
return IPEXConfig.IPEX_QUANT_METHOD_MAP[self.method]
f"group_size={self.group_size})")

@classmethod
def get_name(cls) -> str:
Expand Down Expand Up @@ -77,17 +78,18 @@ def from_config(cls, config: Dict[str, Any]) -> "IPEXConfig":
weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
group_size = cls.get_from_keys(config,
["q_group_size", "group_size"])
return cls(method, weight_bits, group_size, False, False)
if method == "gptq":
modules_to_not_convert = cls.get_from_keys_or(
config, ["modules_to_not_convert"], None)
return cls(method, weight_bits, group_size, modules_to_not_convert,
False, False)
# otherwise for gptq
weight_bits = cls.get_from_keys(config, ["bits"])
group_size = cls.get_from_keys(config, ["group_size"])
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
default=False)
try:
desc_act = cls.get_from_keys(config, ["desc_act"])
except Exception:
desc_act = False
return cls(method, weight_bits, group_size, desc_act,
desc_act = cls.get_from_keys_or(config, ["desc_act"],
default=False)
return cls(method, weight_bits, group_size, [], desc_act,
lm_head_quantized)

@classmethod
Expand All @@ -107,17 +109,13 @@ def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["LinearMethodBase"]:
if isinstance(layer, LinearBase):
if self.method == "awq":
if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
return UnquantizedLinearMethod()
return IPEXAWQLinearMethod(self)
if self.method == "gptq":
return IPEXGPTQLinearMethod(self)
return None

def get_scaled_act_names(self) -> List[str]:
if self.method == "awq":
return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
else:
return []


class IPEXGPTQLinearMethod(GPTQLinearMethod):
"""GPTQ linear method using IPEX for the CPU/XPU backend.
Expand Down Expand Up @@ -168,8 +166,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
g_idx=g_idx,
bias=bias,
group_size=self.quant_config.group_size,
quant_method=self.quant_config.get_ipex_quant_method_id(
) # type: ignore
quant_method=IPEXConfig.IPEX_QUANT_METHOD_MAP["gptq"]
)

def apply(self,
Expand Down Expand Up @@ -235,8 +232,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
qconfig=qconfig,
bias=bias,
group_size=self.quant_config.group_size,
quant_method=self.quant_config.get_ipex_quant_method_id(
) # type: ignore
quant_method=IPEXConfig.IPEX_QUANT_METHOD_MAP["awq"] # type: ignore
)

def apply(self,
Expand Down

0 comments on commit 9ef0bd0

Please sign in to comment.