add Cambricon MLUs support #29627

Merged (9 commits) on Mar 27, 2024
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -1104,6 +1104,7 @@
"is_timm_available",
"is_tokenizers_available",
"is_torch_available",
"is_torch_mlu_available",
"is_torch_neuroncore_available",
"is_torch_npu_available",
"is_torch_tpu_available",
@@ -5973,6 +5974,7 @@
is_timm_available,
is_tokenizers_available,
is_torch_available,
is_torch_mlu_available,
is_torch_neuroncore_available,
is_torch_npu_available,
is_torch_tpu_available,
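
Note (not part of the diff): the new symbol is re-exported at the top level, so a minimal check like the following is expected to work once a build containing this PR is installed:

from transformers import is_torch_mlu_available

# Returns False unless torch plus the Cambricon torch_mlu extension are present.
print(is_torch_mlu_available())
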
5 changes: 4 additions & 1 deletion src/transformers/integrations/deepspeed.py
@@ -21,7 +21,7 @@
from functools import partialmethod

from ..dependency_versions_check import dep_version_check
from ..utils import is_accelerate_available, is_torch_available, logging
from ..utils import is_accelerate_available, is_torch_available, is_torch_mlu_available, logging


if is_torch_available():
@@ -38,6 +38,9 @@ def is_deepspeed_available():
# AND checking it has an author field in the metadata that is HuggingFace.
if package_exists:
try:
if is_torch_mlu_available():
_ = importlib_metadata.metadata("deepspeed-mlu")
return True
_ = importlib_metadata.metadata("deepspeed")
return True
except importlib_metadata.PackageNotFoundError:
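
A standalone sketch of the lookup this hunk performs; the helper name _has_deepspeed is illustrative only, while the package names mirror the diff above:

import importlib.metadata as importlib_metadata

def _has_deepspeed(on_mlu: bool) -> bool:
    # On Cambricon MLUs the fork is published as "deepspeed-mlu";
    # elsewhere the upstream "deepspeed" distribution is expected.
    name = "deepspeed-mlu" if on_mlu else "deepspeed"
    try:
        importlib_metadata.metadata(name)
        return True
    except importlib_metadata.PackageNotFoundError:
        return False
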
6 changes: 6 additions & 0 deletions src/transformers/pipelines/base.py
@@ -41,6 +41,7 @@
is_tf_available,
is_torch_available,
is_torch_cuda_available,
is_torch_mlu_available,
is_torch_npu_available,
is_torch_xpu_available,
logging,
@@ -851,6 +852,8 @@ def __init__(
self.device = torch.device(device)
elif device < 0:
self.device = torch.device("cpu")
elif is_torch_mlu_available():
self.device = torch.device(f"mlu:{device}")
elif is_torch_cuda_available():
self.device = torch.device(f"cuda:{device}")
elif is_torch_npu_available():
@@ -988,6 +991,9 @@ def device_placement(self):
if self.device.type == "cuda":
with torch.cuda.device(self.device):
yield
elif self.device.type == "mlu":
with torch.mlu.device(self.device):
yield
else:
yield

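
Hedged usage sketch: with the Cambricon torch_mlu extension installed and an MLU visible, an integer device index now resolves to an mlu:<index> device (the default model for the task is downloaded on first use):

from transformers import pipeline

pipe = pipeline("text-classification", device=0)
print(pipe.device)  # expected: device(type='mlu', index=0) on an MLU host; cuda/npu/xpu otherwise
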
18 changes: 18 additions & 0 deletions src/transformers/trainer.py
@@ -151,6 +151,7 @@
is_sagemaker_dp_enabled,
is_sagemaker_mp_enabled,
is_torch_compile_available,
is_torch_mlu_available,
is_torch_neuroncore_available,
is_torch_npu_available,
is_torch_xla_available,
@@ -2642,6 +2643,17 @@ def _load_rng_state(self, checkpoint):
f"Didn't manage to set back the RNG states of the NPU because of the following error:\n {e}"
"\nThis won't yield the same results as if the training had not been interrupted."
)
if is_torch_mlu_available():
if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
torch.mlu.random.set_rng_state_all(checkpoint_rng_state["mlu"])
else:
try:
torch.mlu.random.set_rng_state(checkpoint_rng_state["mlu"])
except Exception as e:
logger.info(
f"Didn't manage to set back the RNG states of the MLU because of the following error:\n {e}"
"\nThis won't yield the same results as if the training had not been interrupted."
)

def _save_checkpoint(self, model, trial, metrics=None):
# In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
@@ -2716,6 +2728,12 @@ def _save_rng_state(self, output_dir):
else:
rng_states["npu"] = torch.npu.random.get_rng_state()

if is_torch_mlu_available():
if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
rng_states["mlu"] = torch.mlu.random.get_rng_state_all()
else:
rng_states["mlu"] = torch.mlu.random.get_rng_state()

# A process can arrive here before the process 0 has a chance to save the model, in which case output_dir may
# not yet exist.
os.makedirs(output_dir, exist_ok=True)
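
The save and restore calls added here are symmetric; a minimal single-process sketch (torch.mlu is provided by the Cambricon torch_mlu extension):

import torch

# What _save_rng_state stores for the MLU generator ...
rng_states = {"mlu": torch.mlu.random.get_rng_state()}

# ... and what _load_rng_state puts back on resume.
torch.mlu.random.set_rng_state(rng_states["mlu"])
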
15 changes: 14 additions & 1 deletion src/transformers/trainer_utils.py
@@ -35,6 +35,7 @@
is_tf_available,
is_torch_available,
is_torch_cuda_available,
is_torch_mlu_available,
is_torch_mps_available,
is_torch_npu_available,
is_torch_xla_available,
@@ -100,6 +101,8 @@ def set_seed(seed: int, deterministic: bool = False):
# ^^ safe to call this function even if cuda is not available
if deterministic:
torch.use_deterministic_algorithms(True)
if is_torch_mlu_available():
torch.mlu.manual_seed_all(seed)
if is_torch_npu_available():
torch.npu.manual_seed_all(seed)
if is_torch_xpu_available():
@@ -454,7 +457,7 @@ def __init__(self, skip_memory_metrics=False):

import psutil # noqa

if is_torch_cuda_available():
if is_torch_cuda_available() or is_torch_mlu_available():
import torch

self.torch = torch
@@ -527,6 +530,9 @@ def start(self):
if torch.cuda.is_available():
self.torch.cuda.reset_peak_memory_stats()
self.torch.cuda.empty_cache()
elif is_torch_mlu_available():
self.torch.mlu.reset_peak_memory_stats()
self.torch.mlu.empty_cache()
elif is_torch_xpu_available():
self.torch.xpu.reset_peak_memory_stats()
self.torch.xpu.empty_cache()
@@ -540,6 +546,8 @@ def start(self):
if self.torch is not None:
if torch.cuda.is_available():
self.gpu_mem_used_at_start = self.torch.cuda.memory_allocated()
elif is_torch_mlu_available():
self.gpu_mem_used_at_start = self.torch.mlu.memory_allocated()
elif is_torch_xpu_available():
self.gpu_mem_used_at_start = self.torch.xpu.memory_allocated()
elif is_torch_npu_available():
@@ -571,6 +579,8 @@ def stop(self, stage):
if self.torch is not None:
if torch.cuda.is_available():
self.torch.cuda.empty_cache()
elif is_torch_mlu_available():
self.torch.mlu.empty_cache()
elif is_torch_xpu_available():
self.torch.xpu.empty_cache()
elif is_torch_npu_available():
@@ -588,6 +598,9 @@ def stop(self, stage):
if torch.cuda.is_available():
self.gpu_mem_used_now = self.torch.cuda.memory_allocated()
self.gpu_mem_used_peak = self.torch.cuda.max_memory_allocated()
elif is_torch_mlu_available():
self.gpu_mem_used_now = self.torch.mlu.memory_allocated()
self.gpu_mem_used_peak = self.torch.mlu.max_memory_allocated()
elif is_torch_xpu_available():
self.gpu_mem_used_now = self.torch.xpu.memory_allocated()
self.gpu_mem_used_peak = self.torch.xpu.max_memory_allocated()
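
With this change set_seed also seeds the MLU generators when torch_mlu is present, mirroring the CUDA/NPU/XPU branches; a minimal sketch:

from transformers import set_seed

set_seed(42)  # additionally calls torch.mlu.manual_seed_all(42) on MLU hosts
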
13 changes: 10 additions & 3 deletions src/transformers/training_args.py
@@ -46,6 +46,7 @@
is_torch_available,
is_torch_bf16_cpu_available,
is_torch_bf16_gpu_available,
is_torch_mlu_available,
is_torch_neuroncore_available,
is_torch_npu_available,
is_torch_tf32_available,
@@ -993,7 +994,7 @@ class TrainingArguments:
default=None,
metadata={
"help": "The backend to be used for distributed training",
"choices": ["nccl", "gloo", "mpi", "ccl", "hccl"],
"choices": ["nccl", "gloo", "mpi", "ccl", "hccl", "cncl"],
},
)
tpu_num_cores: Optional[int] = field(
@@ -1549,20 +1550,22 @@ def __post_init__(self):
self.framework == "pt"
and is_torch_available()
and (self.device.type != "cuda")
and (self.device.type != "mlu")
and (self.device.type != "npu")
and (self.device.type != "xpu")
and (get_xla_device_type(self.device) not in ["GPU", "CUDA"])
and (self.fp16 or self.fp16_full_eval)
):
raise ValueError(
"FP16 Mixed precision training with AMP or APEX (`--fp16`) and FP16 half precision evaluation"
" (`--fp16_full_eval`) can only be used on CUDA or NPU devices or certain XPU devices (with IPEX)."
" (`--fp16_full_eval`) can only be used on CUDA or MLU devices or NPU devices or certain XPU devices (with IPEX)."
)

if (
self.framework == "pt"
and is_torch_available()
and (self.device.type != "cuda")
and (self.device.type != "mlu")
and (self.device.type != "npu")
and (self.device.type != "xpu")
and (get_xla_device_type(self.device) not in ["GPU", "CUDA"])
@@ -1572,7 +1575,7 @@ def __post_init__(self):
):
raise ValueError(
"BF16 Mixed precision training with AMP (`--bf16`) and BF16 half precision evaluation"
" (`--bf16_full_eval`) can only be used on CUDA, XPU (with IPEX), NPU or CPU/TPU/NeuronCore devices."
" (`--bf16_full_eval`) can only be used on CUDA, XPU (with IPEX), NPU, MLU or CPU/TPU/NeuronCore devices."
)

if self.torchdynamo is not None:
@@ -1999,6 +2002,10 @@ def _setup_devices(self) -> "torch.device":
device = torch.device("xpu:0")
torch.xpu.set_device(device)
self._n_gpu = 1
elif is_torch_mlu_available():
device = torch.device("mlu:0")
torch.mlu.set_device(device)
self._n_gpu = 1
elif is_torch_npu_available():
device = torch.device("npu:0")
torch.npu.set_device(device)
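
Hedged usage sketch: on an MLU host TrainingArguments now resolves its device to mlu:0, and "cncl" is accepted as a distributed backend (output_dir="out" is only a placeholder; accelerate must be installed for device setup):

from transformers import TrainingArguments

args = TrainingArguments(output_dir="out", ddp_backend="cncl")
print(args.device)  # expected: mlu:0 when torch_mlu is installed and an MLU is visible
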
1 change: 1 addition & 0 deletions src/transformers/utils/__init__.py
@@ -184,6 +184,7 @@
is_torch_fp16_available_on_device,
is_torch_fx_available,
is_torch_fx_proxy,
is_torch_mlu_available,
is_torch_mps_available,
is_torch_neuroncore_available,
is_torch_npu_available,
23 changes: 23 additions & 0 deletions src/transformers/utils/import_utils.py
@@ -585,6 +585,29 @@ def is_torch_npu_available(check_device=False):
return hasattr(torch, "npu") and torch.npu.is_available()


@lru_cache()
def is_torch_mlu_available(check_device=False):
"Checks if `torch_mlu` is installed and potentially if a MLU is in the environment"
if not _torch_available or importlib.util.find_spec("torch_mlu") is None:
return False

import torch
import torch_mlu # noqa: F401

from ..dependency_versions_table import deps

deps["deepspeed"] = "deepspeed-mlu>=0.10.1"

if check_device:
try:
# Will raise a RuntimeError if no MLU is found
_ = torch.mlu.device_count()
return torch.mlu.is_available()
except RuntimeError:
return False
return hasattr(torch, "mlu") and torch.mlu.is_available()


def is_torchdynamo_available():
if not is_torch_available():
return False
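
Callers are expected to gate MLU-specific code on the new helper, in the same pattern as the other is_torch_*_available checks; a minimal hedged sketch:

from transformers.utils import is_torch_mlu_available

if is_torch_mlu_available():
    import torch

    x = torch.ones(2, device="mlu")  # the "mlu" device type is registered by the torch_mlu extension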