Skip to content

Commit

Permalink
get default device through PartialState().default_device as it has
Browse files Browse the repository at this point in the history
been officially released
  • Loading branch information
ji-huazhong committed Dec 9, 2023
1 parent 80377eb commit a81e853
Showing 1 changed file with 2 additions and 18 deletions.
20 changes: 2 additions & 18 deletions src/transformers/tools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import torch

if is_accelerate_available():
from accelerate import PartialState
from accelerate.utils import send_to_device


Expand Down Expand Up @@ -529,7 +530,7 @@ def setup(self):
if self.device_map is not None:
self.device = list(self.model.hf_device_map.values())[0]
else:
self.device = get_default_device()
self.device = PartialState().default_device

if self.device_map is None:
self.model.to(self.device)
Expand Down Expand Up @@ -597,23 +598,6 @@ def fn(*args, **kwargs):
).launch()


# TODO: Migrate to Accelerate for this once `PartialState.default_device` makes its way into a release.
def get_default_device():
logger.warning(
"`get_default_device` is deprecated and will be replaced with `accelerate`'s `PartialState().default_device` "
"in version 4.38 of 🤗 Transformers. "
)
if not is_torch_available():
raise ImportError("Please install torch in order to use this tool.")

if torch.backends.mps.is_available() and torch.backends.mps.is_built():
return torch.device("mps")
elif torch.cuda.is_available():
return torch.device("cuda")
else:
return torch.device("cpu")


TASK_MAPPING = {
"document-question-answering": "DocumentQuestionAnsweringTool",
"image-captioning": "ImageCaptioningTool",
Expand Down

0 comments on commit a81e853

Please sign in to comment.