diff --git a/optimum_benchmark/backends/pytorch/backend.py b/optimum_benchmark/backends/pytorch/backend.py index ea8aa0a1..5b245e81 100644 --- a/optimum_benchmark/backends/pytorch/backend.py +++ b/optimum_benchmark/backends/pytorch/backend.py @@ -8,15 +8,12 @@ from datasets import Dataset from safetensors.torch import save_file from transformers import ( - AwqConfig, - BitsAndBytesConfig, - GPTQConfig, - TorchAoConfig, Trainer, TrainerCallback, TrainerState, TrainingArguments, ) +from transformers.quantizers import AutoQuantizationConfig from ...import_utils import is_deepspeed_available, is_torch_distributed_available, is_zentorch_available from ..base import Backend @@ -286,8 +283,6 @@ def create_no_weights_model(self) -> None: def process_quantization_config(self) -> None: if self.is_gptq_quantized: - self.logger.info("\t+ Processing GPTQ config") - try: import exllamav2_kernels # noqa: F401 except ImportError: @@ -299,12 +294,7 @@ def process_quantization_config(self) -> None: "`optimum-benchmark` repository at `https://github.com/huggingface/optimum-benchmark`." ) - self.quantization_config = GPTQConfig( - **dict(getattr(self.pretrained_config, "quantization_config", {}), **self.config.quantization_config) - ) elif self.is_awq_quantized: - self.logger.info("\t+ Processing AWQ config") - try: import exlv2_ext # noqa: F401 except ImportError: @@ -316,21 +306,10 @@ def process_quantization_config(self) -> None: "`optimum-benchmark` repository at `https://github.com/huggingface/optimum-benchmark`." ) - self.quantization_config = AwqConfig( - **dict(getattr(self.pretrained_config, "quantization_config", {}), **self.config.quantization_config) - ) - elif self.is_bnb_quantized: - self.logger.info("\t+ Processing BitsAndBytes config") - self.quantization_config = BitsAndBytesConfig( - **dict(getattr(self.pretrained_config, "quantization_config", {}), **self.config.quantization_config) - ) - elif self.is_torchao_quantized: - self.logger.info("\t+ Processing TorchAO config") - self.quantization_config = TorchAoConfig( - **dict(getattr(self.pretrained_config, "quantization_config", {}), **self.config.quantization_config) - ) - else: - raise ValueError(f"Quantization scheme {self.config.quantization_scheme} not recognized") + self.logger.info("\t+ Processing AutoQuantization config") + self.quantization_config = AutoQuantizationConfig.from_dict( + getattr(self.pretrained_config, "quantization_config", {}).update(self.config.quantization_config) + ) @property def is_quantized(self) -> bool: @@ -339,13 +318,6 @@ def is_quantized(self) -> bool: and self.pretrained_config.quantization_config.get("quant_method", None) is not None ) - @property - def is_bnb_quantized(self) -> bool: - return self.config.quantization_scheme == "bnb" or ( - hasattr(self.pretrained_config, "quantization_config") - and self.pretrained_config.quantization_config.get("quant_method", None) == "bnb" - ) - @property def is_gptq_quantized(self) -> bool: return self.config.quantization_scheme == "gptq" or ( @@ -360,13 +332,6 @@ def is_awq_quantized(self) -> bool: and self.pretrained_config.quantization_config.get("quant_method", None) == "awq" ) - @property - def is_torchao_quantized(self) -> bool: - return self.config.quantization_scheme == "torchao" or ( - hasattr(self.pretrained_config, "quantization_config") - and self.pretrained_config.quantization_config.get("quant_method", None) == "torchao" - ) - @property def is_exllamav2(self) -> bool: return ( @@ -390,7 +355,10 @@ def automodel_kwargs(self) -> Dict[str, Any]: kwargs = {} if self.config.torch_dtype is not None: - kwargs["torch_dtype"] = getattr(torch, self.config.torch_dtype) + if hasattr(torch, self.config.torch_dtype): + kwargs["torch_dtype"] = getattr(torch, self.config.torch_dtype) + else: + kwargs["torch_dtype"] = self.config.torch_dtype if self.is_quantized: kwargs["quantization_config"] = self.quantization_config @@ -436,9 +404,9 @@ def forward(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict @torch.inference_mode() def prefill(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict: - assert kwargs.get("max_new_tokens") == kwargs.get("min_new_tokens") == 1, ( - "For prefilling, max_new_tokens and min_new_tokens must be equal to 1" - ) + assert ( + kwargs.get("max_new_tokens") == kwargs.get("min_new_tokens") == 1 + ), "For prefilling, max_new_tokens and min_new_tokens must be equal to 1" return self.pretrained_model.generate(**inputs, **kwargs) @torch.inference_mode()