Skip to content

Commit

Permalink
add Cambricon MLUs support (#29627)
Browse files Browse the repository at this point in the history
* add Cambricon MLUs support

* fix mlu device rng state

* up for quality check

* up mlu to support fp16

* fix mlu device dependency error

* fix mlu device dependency error

* enable mlu device for bf16

* fix mlu device memory tracker
  • Loading branch information
huismiling authored Mar 27, 2024
1 parent 0efcf32 commit 7576974
Show file tree
Hide file tree
Showing 8 changed files with 78 additions and 5 deletions.
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1109,6 +1109,7 @@
"is_timm_available",
"is_tokenizers_available",
"is_torch_available",
"is_torch_mlu_available",
"is_torch_neuroncore_available",
"is_torch_npu_available",
"is_torch_tpu_available",
Expand Down Expand Up @@ -5987,6 +5988,7 @@
is_timm_available,
is_tokenizers_available,
is_torch_available,
is_torch_mlu_available,
is_torch_neuroncore_available,
is_torch_npu_available,
is_torch_tpu_available,
Expand Down
5 changes: 4 additions & 1 deletion src/transformers/integrations/deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from functools import partialmethod

from ..dependency_versions_check import dep_version_check
from ..utils import is_accelerate_available, is_torch_available, logging
from ..utils import is_accelerate_available, is_torch_available, is_torch_mlu_available, logging


if is_torch_available():
Expand All @@ -38,6 +38,9 @@ def is_deepspeed_available():
# AND checking it has an author field in the metadata that is HuggingFace.
if package_exists:
try:
if is_torch_mlu_available():
_ = importlib_metadata.metadata("deepspeed-mlu")
return True
_ = importlib_metadata.metadata("deepspeed")
return True
except importlib_metadata.PackageNotFoundError:
Expand Down
6 changes: 6 additions & 0 deletions src/transformers/pipelines/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
is_tf_available,
is_torch_available,
is_torch_cuda_available,
is_torch_mlu_available,
is_torch_npu_available,
is_torch_xpu_available,
logging,
Expand Down Expand Up @@ -851,6 +852,8 @@ def __init__(
self.device = torch.device(device)
elif device < 0:
self.device = torch.device("cpu")
elif is_torch_mlu_available():
self.device = torch.device(f"mlu:{device}")
elif is_torch_cuda_available():
self.device = torch.device(f"cuda:{device}")
elif is_torch_npu_available():
Expand Down Expand Up @@ -995,6 +998,9 @@ def device_placement(self):
if self.device.type == "cuda":
with torch.cuda.device(self.device):
yield
elif self.device.type == "mlu":
with torch.mlu.device(self.device):
yield
else:
yield

Expand Down
18 changes: 18 additions & 0 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@
is_sagemaker_dp_enabled,
is_sagemaker_mp_enabled,
is_torch_compile_available,
is_torch_mlu_available,
is_torch_neuroncore_available,
is_torch_npu_available,
is_torch_xla_available,
Expand Down Expand Up @@ -2671,6 +2672,17 @@ def _load_rng_state(self, checkpoint):
f"Didn't manage to set back the RNG states of the NPU because of the following error:\n {e}"
"\nThis won't yield the same results as if the training had not been interrupted."
)
if is_torch_mlu_available():
if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
torch.mlu.random.set_rng_state_all(checkpoint_rng_state["mlu"])
else:
try:
torch.mlu.random.set_rng_state(checkpoint_rng_state["mlu"])
except Exception as e:
logger.info(
f"Didn't manage to set back the RNG states of the MLU because of the following error:\n {e}"
"\nThis won't yield the same results as if the training had not been interrupted."
)

def _save_checkpoint(self, model, trial, metrics=None):
# In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
Expand Down Expand Up @@ -2745,6 +2757,12 @@ def _save_rng_state(self, output_dir):
else:
rng_states["npu"] = torch.npu.random.get_rng_state()

if is_torch_mlu_available():
if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
rng_states["mlu"] = torch.mlu.random.get_rng_state_all()
else:
rng_states["mlu"] = torch.mlu.random.get_rng_state()

# A process can arrive here before the process 0 has a chance to save the model, in which case output_dir may
# not yet exist.
os.makedirs(output_dir, exist_ok=True)
Expand Down
15 changes: 14 additions & 1 deletion src/transformers/trainer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
is_tf_available,
is_torch_available,
is_torch_cuda_available,
is_torch_mlu_available,
is_torch_mps_available,
is_torch_npu_available,
is_torch_xla_available,
Expand Down Expand Up @@ -100,6 +101,8 @@ def set_seed(seed: int, deterministic: bool = False):
# ^^ safe to call this function even if cuda is not available
if deterministic:
torch.use_deterministic_algorithms(True)
if is_torch_mlu_available():
torch.mlu.manual_seed_all(seed)
if is_torch_npu_available():
torch.npu.manual_seed_all(seed)
if is_torch_xpu_available():
Expand Down Expand Up @@ -455,7 +458,7 @@ def __init__(self, skip_memory_metrics=False):

import psutil # noqa

if is_torch_cuda_available():
if is_torch_cuda_available() or is_torch_mlu_available():
import torch

self.torch = torch
Expand Down Expand Up @@ -528,6 +531,9 @@ def start(self):
if torch.cuda.is_available():
self.torch.cuda.reset_peak_memory_stats()
self.torch.cuda.empty_cache()
elif is_torch_mlu_available():
self.torch.mlu.reset_peak_memory_stats()
self.torch.mlu.empty_cache()
elif is_torch_xpu_available():
self.torch.xpu.reset_peak_memory_stats()
self.torch.xpu.empty_cache()
Expand All @@ -541,6 +547,8 @@ def start(self):
if self.torch is not None:
if torch.cuda.is_available():
self.gpu_mem_used_at_start = self.torch.cuda.memory_allocated()
elif is_torch_mlu_available():
self.gpu_mem_used_at_start = self.torch.mlu.memory_allocated()
elif is_torch_xpu_available():
self.gpu_mem_used_at_start = self.torch.xpu.memory_allocated()
elif is_torch_npu_available():
Expand Down Expand Up @@ -572,6 +580,8 @@ def stop(self, stage):
if self.torch is not None:
if torch.cuda.is_available():
self.torch.cuda.empty_cache()
elif is_torch_mlu_available():
self.torch.mlu.empty_cache()
elif is_torch_xpu_available():
self.torch.xpu.empty_cache()
elif is_torch_npu_available():
Expand All @@ -589,6 +599,9 @@ def stop(self, stage):
if torch.cuda.is_available():
self.gpu_mem_used_now = self.torch.cuda.memory_allocated()
self.gpu_mem_used_peak = self.torch.cuda.max_memory_allocated()
elif is_torch_mlu_available():
self.gpu_mem_used_now = self.torch.mlu.memory_allocated()
self.gpu_mem_used_peak = self.torch.mlu.max_memory_allocated()
elif is_torch_xpu_available():
self.gpu_mem_used_now = self.torch.xpu.memory_allocated()
self.gpu_mem_used_peak = self.torch.xpu.max_memory_allocated()
Expand Down
13 changes: 10 additions & 3 deletions src/transformers/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
is_torch_available,
is_torch_bf16_cpu_available,
is_torch_bf16_gpu_available,
is_torch_mlu_available,
is_torch_neuroncore_available,
is_torch_npu_available,
is_torch_tf32_available,
Expand Down Expand Up @@ -993,7 +994,7 @@ class TrainingArguments:
default=None,
metadata={
"help": "The backend to be used for distributed training",
"choices": ["nccl", "gloo", "mpi", "ccl", "hccl"],
"choices": ["nccl", "gloo", "mpi", "ccl", "hccl", "cncl"],
},
)
tpu_num_cores: Optional[int] = field(
Expand Down Expand Up @@ -1549,20 +1550,22 @@ def __post_init__(self):
self.framework == "pt"
and is_torch_available()
and (self.device.type != "cuda")
and (self.device.type != "mlu")
and (self.device.type != "npu")
and (self.device.type != "xpu")
and (get_xla_device_type(self.device) not in ["GPU", "CUDA"])
and (self.fp16 or self.fp16_full_eval)
):
raise ValueError(
"FP16 Mixed precision training with AMP or APEX (`--fp16`) and FP16 half precision evaluation"
" (`--fp16_full_eval`) can only be used on CUDA or NPU devices or certain XPU devices (with IPEX)."
" (`--fp16_full_eval`) can only be used on CUDA or MLU devices or NPU devices or certain XPU devices (with IPEX)."
)

if (
self.framework == "pt"
and is_torch_available()
and (self.device.type != "cuda")
and (self.device.type != "mlu")
and (self.device.type != "npu")
and (self.device.type != "xpu")
and (get_xla_device_type(self.device) not in ["GPU", "CUDA"])
Expand All @@ -1572,7 +1575,7 @@ def __post_init__(self):
):
raise ValueError(
"BF16 Mixed precision training with AMP (`--bf16`) and BF16 half precision evaluation"
" (`--bf16_full_eval`) can only be used on CUDA, XPU (with IPEX), NPU or CPU/TPU/NeuronCore devices."
" (`--bf16_full_eval`) can only be used on CUDA, XPU (with IPEX), NPU, MLU or CPU/TPU/NeuronCore devices."
)

if self.torchdynamo is not None:
Expand Down Expand Up @@ -1999,6 +2002,10 @@ def _setup_devices(self) -> "torch.device":
device = torch.device("xpu:0")
torch.xpu.set_device(device)
self._n_gpu = 1
elif is_torch_mlu_available():
device = torch.device("mlu:0")
torch.mlu.set_device(device)
self._n_gpu = 1
elif is_torch_npu_available():
device = torch.device("npu:0")
torch.npu.set_device(device)
Expand Down
1 change: 1 addition & 0 deletions src/transformers/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@
is_torch_fp16_available_on_device,
is_torch_fx_available,
is_torch_fx_proxy,
is_torch_mlu_available,
is_torch_mps_available,
is_torch_neuroncore_available,
is_torch_npu_available,
Expand Down
23 changes: 23 additions & 0 deletions src/transformers/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,6 +586,29 @@ def is_torch_npu_available(check_device=False):
return hasattr(torch, "npu") and torch.npu.is_available()


@lru_cache()
def is_torch_mlu_available(check_device=False):
"Checks if `torch_mlu` is installed and potentially if a MLU is in the environment"
if not _torch_available or importlib.util.find_spec("torch_mlu") is None:
return False

import torch
import torch_mlu # noqa: F401

from ..dependency_versions_table import deps

deps["deepspeed"] = "deepspeed-mlu>=0.10.1"

if check_device:
try:
# Will raise a RuntimeError if no MLU is found
_ = torch.mlu.device_count()
return torch.mlu.is_available()
except RuntimeError:
return False
return hasattr(torch, "mlu") and torch.mlu.is_available()


def is_torchdynamo_available():
if not is_torch_available():
return False
Expand Down

0 comments on commit 7576974

Please sign in to comment.