diff --git a/src/transformers/tools/base.py b/src/transformers/tools/base.py index 4042b28ac64c..f81096b3d855 100644 --- a/src/transformers/tools/base.py +++ b/src/transformers/tools/base.py @@ -46,6 +46,7 @@ import torch if is_accelerate_available(): + from accelerate import PartialState from accelerate.utils import send_to_device @@ -529,7 +530,7 @@ def setup(self): if self.device_map is not None: self.device = list(self.model.hf_device_map.values())[0] else: - self.device = get_default_device() + self.device = PartialState().default_device if self.device_map is None: self.model.to(self.device) @@ -597,23 +598,6 @@ def fn(*args, **kwargs): ).launch() -# TODO: Migrate to Accelerate for this once `PartialState.default_device` makes its way into a release. -def get_default_device(): - logger.warning( - "`get_default_device` is deprecated and will be replaced with `accelerate`'s `PartialState().default_device` " - "in version 4.38 of 🤗 Transformers. " - ) - if not is_torch_available(): - raise ImportError("Please install torch in order to use this tool.") - - if torch.backends.mps.is_available() and torch.backends.mps.is_built(): - return torch.device("mps") - elif torch.cuda.is_available(): - return torch.device("cuda") - else: - return torch.device("cpu") - - TASK_MAPPING = { "document-question-answering": "DocumentQuestionAnsweringTool", "image-captioning": "ImageCaptioningTool",