From 931cc64c060b805896713d828946521cf7de5e64 Mon Sep 17 00:00:00 2001 From: eaidova Date: Thu, 14 Dec 2023 10:33:22 +0400 Subject: [PATCH] fix compatibility with transformers 4.36 --- optimum/exporters/openvino/__main__.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/optimum/exporters/openvino/__main__.py b/optimum/exporters/openvino/__main__.py index cb011706c8..553be718bd 100644 --- a/optimum/exporters/openvino/__main__.py +++ b/optimum/exporters/openvino/__main__.py @@ -23,10 +23,21 @@ from optimum.exporters import TasksManager from optimum.exporters.onnx import __main__ as optimum_main from optimum.exporters.onnx.base import OnnxConfig, OnnxConfigWithPast + + +try: + from optimum.exporters.onnx.constants import SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED +except ImportError: + # Duplicated from https://github.com/huggingface/optimum/blob/main/optimum/exporters/onnx/constants.py + # until it is included in a released version of the optimum package + SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED = [ + "bart", + "whisper", + ] from optimum.utils import DEFAULT_DUMMY_SHAPES from optimum.utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors -from ...intel.utils.import_utils import is_nncf_available +from ...intel.utils.import_utils import is_nncf_available, is_transformers_version from .convert import export_models @@ -140,10 +151,12 @@ def main_export( do_gptq_patching = False try: config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=trust_remote_code) + model_type = config.model_type.replace("_", "-") config_dict = config.to_dict() quantization_config = config_dict.get("quantization_config", None) do_gptq_patching = quantization_config and quantization_config["quant_method"] == "gptq" except Exception: + model_type = None pass if do_gptq_patching: @@ -192,6 +205,10 @@ class StoreAttr(object): f"The task could not be automatically inferred as this is available only for models hosted on the Hugging Face Hub. 
Please provide the argument --task with the relevant task from {', '.join(TasksManager.get_all_tasks())}. Detailed error: {e}" ) + loading_kwargs = {} + if is_transformers_version(">=", "4.36") and model_type in SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED: + loading_kwargs["attn_implementation"] = "eager" + model = TasksManager.get_model_from_task( task, model_name_or_path, @@ -204,6 +221,7 @@ class StoreAttr(object): trust_remote_code=trust_remote_code, framework=framework, device=device, + **loading_kwargs, ) custom_architecture = False