diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py
index 802d0ca68974..ac07281b3d33 100755
--- a/src/transformers/utils/import_utils.py
+++ b/src/transformers/utils/import_utils.py
@@ -865,7 +865,7 @@ def is_ninja_available():
     return True
 
 
-def is_ipex_available():
+def is_ipex_available(min_version: str = ""):
     def get_major_and_minor_from_version(full_version):
         return str(version.parse(full_version).major) + "." + str(version.parse(full_version).minor)
 
@@ -880,6 +880,8 @@ def get_major_and_minor_from_version(full_version):
             f" but PyTorch {_torch_version} is found. Please switch to the matching version and run again."
         )
         return False
+    if min_version:
+        return version.parse(_ipex_version) >= version.parse(min_version)
     return True
 
 
diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index d59a18c59db4..ba61d4b43677 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -24,6 +24,7 @@
 
 import numpy as np
 import pytest
+from packaging import version
 from parameterized import parameterized
 
 from transformers import AutoConfig, is_torch_available, pipeline
@@ -44,6 +45,7 @@
     slow,
     torch_device,
 )
+from transformers.utils import is_ipex_available
 
 from ..test_modeling_common import floats_tensor, ids_tensor
 from .test_framework_agnostic import GenerationIntegrationTestsMixin
@@ -675,10 +677,11 @@ def test_beam_search_generate_dict_outputs_use_cache(self):
     @require_torch_multi_accelerator
     @pytest.mark.generate
     def test_model_parallel_beam_search(self):
-        for model_class in self.all_generative_model_classes:
-            if "xpu" in torch_device:
-                return unittest.skip(reason="device_map='auto' does not work with XPU devices")
+        if "xpu" in torch_device:
+            if not (is_ipex_available("2.5") or version.parse(torch.__version__) >= version.parse("2.6")):
+                self.skipTest(reason="device_map='auto' does not work with XPU devices")
 
+        for model_class in self.all_generative_model_classes:
             if model_class._no_split_modules is None:
                 continue
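
For context, a minimal usage sketch of the extended helper (illustrative only; it assumes a transformers build containing this patch and an environment where IPEX may or may not be installed):

```python
from transformers.utils import is_ipex_available

# Without an argument the behaviour is unchanged: True only if IPEX is
# installed and its major.minor version matches the installed PyTorch.
if is_ipex_available():
    print("IPEX is installed and matches the PyTorch version")

# With min_version set, the helper additionally requires the installed
# IPEX to be at least that version (here 2.5, the threshold used in the
# XPU device_map='auto' test gating above).
if is_ipex_available(min_version="2.5"):
    print("IPEX >= 2.5 is available")
```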