
Commit

Migrate import checks to not need accelerate, and be clearer on min versions (#32292)

* Migrate import checks to secondary accelerate calls

* Better error messages too

* Revert, just keep the import checks + remove accelerate-specific things

* Remove extra `'`

* Empty commit for CI

* Small nits

* Final
muellerzr authored Aug 6, 2024
1 parent 80b90e7 commit 194cf1f
Showing 7 changed files with 40 additions and 22 deletions.
6 changes: 4 additions & 2 deletions src/transformers/integrations/aqlm.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 "AQLM (Additive Quantization of Language Model) integration file"

-from ..utils import is_accelerate_available, is_aqlm_available, is_torch_available
+from ..utils import ACCELERATE_MIN_VERSION, is_accelerate_available, is_aqlm_available, is_torch_available


 if is_torch_available():
@@ -50,7 +50,9 @@ def replace_with_aqlm_linear(
         raise ValueError("AQLM is not available. Please install it with `pip install aqlm[cpu,gpu]`")

     if not is_accelerate_available():
-        raise ValueError("AQLM requires Accelerate to be installed: `pip install accelerate`")
+        raise ValueError(
+            f"AQLM requires Accelerate to be installed: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`"
+        )

     if linear_weights_not_to_quantize is None:
         linear_weights_not_to_quantize = []
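
For context, a minimal sketch of the pattern this commit standardizes: guard the optional dependency behind the availability check and interpolate the minimum supported version into the error message. The helper name `require_accelerate` is hypothetical, for illustration only; the two imports are the real ones the diff uses.

from transformers.utils import ACCELERATE_MIN_VERSION, is_accelerate_available


def require_accelerate(feature: str) -> None:
    # Hypothetical helper showing the commit's error-message pattern:
    # fail fast with a version-pinned install hint, not a bare package name.
    if not is_accelerate_available():
        raise ImportError(
            f"{feature} requires Accelerate to be installed: "
            f"`pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`"
        )
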
3 changes: 2 additions & 1 deletion src/transformers/modeling_utils.py
@@ -59,6 +59,7 @@
 from .quantizers.quantizers_utils import get_module_from_name
 from .safetensors_conversion import auto_conversion
 from .utils import (
+    ACCELERATE_MIN_VERSION,
     ADAPTER_SAFE_WEIGHTS_NAME,
     ADAPTER_WEIGHTS_NAME,
     CONFIG_NAME,
@@ -3299,7 +3300,7 @@ def from_pretrained(
            )
        elif not is_accelerate_available():
            raise ImportError(
-                "Using `low_cpu_mem_usage=True` or a `device_map` requires Accelerate: `pip install accelerate`"
+                f"Using `low_cpu_mem_usage=True` or a `device_map` requires Accelerate: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`"
            )

        # handling bnb config from kwargs, remove after `load_in_{4/8}bit` deprecation.
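
A usage sketch of how this surfaces (hypothetical session, assuming accelerate is absent; the checkpoint name is illustrative):

from transformers import AutoModelForCausalLM

# Without accelerate installed, requesting a device_map now raises:
#   ImportError: Using `low_cpu_mem_usage=True` or a `device_map` requires
#   Accelerate: `pip install 'accelerate>=<ACCELERATE_MIN_VERSION>'`
model = AutoModelForCausalLM.from_pretrained("gpt2", device_map="auto")
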
12 changes: 10 additions & 2 deletions src/transformers/quantizers/quantizer_bnb_4bit.py
@@ -24,7 +24,13 @@
 if TYPE_CHECKING:
     from ..modeling_utils import PreTrainedModel

-from ..utils import is_accelerate_available, is_bitsandbytes_available, is_torch_available, logging
+from ..utils import (
+    ACCELERATE_MIN_VERSION,
+    is_accelerate_available,
+    is_bitsandbytes_available,
+    is_torch_available,
+    logging,
+)


 if is_torch_available():
@@ -62,7 +68,9 @@ def validate_environment(self, *args, **kwargs):
        if not torch.cuda.is_available():
            raise RuntimeError("No GPU found. A GPU is needed for quantization.")
        if not is_accelerate_available():
-            raise ImportError("Using `bitsandbytes` 4-bit quantization requires Accelerate: `pip install accelerate`")
+            raise ImportError(
+                f"Using `bitsandbytes` 4-bit quantization requires Accelerate: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`"
+            )
        if not is_bitsandbytes_available():
            raise ImportError(
                "Using `bitsandbytes` 4-bit quantization requires the latest version of bitsandbytes: `pip install -U bitsandbytes`"
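
A minimal sketch of how this check fires in practice (the 8-bit quantizer below mirrors it); the checkpoint name is illustrative, and the sketch assumes a CUDA GPU is present but accelerate is missing:

from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# validate_environment() runs before any weights are loaded, so a missing
# accelerate now fails fast with the version-pinned install hint.
config = BitsAndBytesConfig(load_in_4bit=True)
model = AutoModelForCausalLM.from_pretrained("gpt2", quantization_config=config)
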
12 changes: 10 additions & 2 deletions src/transformers/quantizers/quantizer_bnb_8bit.py
@@ -22,7 +22,13 @@
 if TYPE_CHECKING:
     from ..modeling_utils import PreTrainedModel

-from ..utils import is_accelerate_available, is_bitsandbytes_available, is_torch_available, logging
+from ..utils import (
+    ACCELERATE_MIN_VERSION,
+    is_accelerate_available,
+    is_bitsandbytes_available,
+    is_torch_available,
+    logging,
+)
 from .quantizers_utils import get_module_from_name


@@ -62,7 +68,9 @@ def validate_environment(self, *args, **kwargs):
            raise RuntimeError("No GPU found. A GPU is needed for quantization.")

        if not is_accelerate_available():
-            raise ImportError("Using `bitsandbytes` 8-bit quantization requires Accelerate: `pip install accelerate`")
+            raise ImportError(
+                f"Using `bitsandbytes` 8-bit quantization requires Accelerate: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`"
+            )
        if not is_bitsandbytes_available():
            raise ImportError(
                "Using `bitsandbytes` 8-bit quantization requires the latest version of bitsandbytes: `pip install -U bitsandbytes`"
15 changes: 6 additions & 9 deletions src/transformers/trainer.py
@@ -162,9 +162,11 @@
     is_sagemaker_mp_enabled,
     is_torch_compile_available,
     is_torch_mlu_available,
+    is_torch_mps_available,
     is_torch_neuroncore_available,
     is_torch_npu_available,
     is_torch_xla_available,
+    is_torch_xpu_available,
     logging,
     strtobool,
 )
@@ -223,11 +225,6 @@
        DistributedDataParallelKwargs,
        DistributedType,
        GradientAccumulationPlugin,
-        is_mlu_available,
-        is_mps_available,
-        is_npu_available,
-        is_torch_version,
-        is_xpu_available,
        load_fsdp_model,
        load_fsdp_optimizer,
        save_fsdp_model,
@@ -3332,13 +3329,13 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
            self.args.torch_empty_cache_steps is not None
            and self.state.global_step % self.args.torch_empty_cache_steps == 0
        ):
-            if is_xpu_available():
+            if is_torch_xpu_available():
                torch.xpu.empty_cache()
-            elif is_mlu_available():
+            elif is_torch_mlu_available():
                torch.mlu.empty_cache()
-            elif is_npu_available():
+            elif is_torch_npu_available():
                torch.npu.empty_cache()
-            elif is_torch_version(">=", "2.0") and is_mps_available():
+            elif is_torch_mps_available(min_version="2.0"):
                torch.mps.empty_cache()
            else:
                torch.cuda.empty_cache()
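
A self-contained sketch of the dispatch that last hunk implements, now built only from transformers' own availability helpers rather than accelerate's. The function name `empty_device_cache` is hypothetical; the helpers are the real ones imported above.

import torch

from transformers.utils import (
    is_torch_mlu_available,
    is_torch_mps_available,
    is_torch_npu_available,
    is_torch_xpu_available,
)


def empty_device_cache() -> None:
    # Hypothetical standalone version of the Trainer's cache-clearing branch.
    if is_torch_xpu_available():
        torch.xpu.empty_cache()
    elif is_torch_mlu_available():
        torch.mlu.empty_cache()
    elif is_torch_npu_available():
        torch.npu.empty_cache()
    elif is_torch_mps_available(min_version="2.0"):
        # torch.mps.empty_cache() only exists from torch 2.0, hence the floor,
        # which replaces the old separate is_torch_version(">=", "2.0") check.
        torch.mps.empty_cache()
    else:
        torch.cuda.empty_cache()
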
6 changes: 4 additions & 2 deletions src/transformers/training_args.py
@@ -1968,7 +1968,9 @@ def __post_init__(self):
        # - must be run very last in arg parsing, since it will use a lot of these settings.
        # - must be run before the model is created.
        if not is_accelerate_available():
-            raise ValueError("--deepspeed requires Accelerate to be installed: `pip install accelerate`.")
+            raise ValueError(
+                f"--deepspeed requires Accelerate to be installed: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`."
+            )
        from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig

        # will be used later by the Trainer
@@ -2102,7 +2104,7 @@ def _setup_devices(self) -> "torch.device":
        if not is_accelerate_available():
            raise ImportError(
                f"Using the `Trainer` with `PyTorch` requires `accelerate>={ACCELERATE_MIN_VERSION}`: "
-                "Please run `pip install transformers[torch]` or `pip install accelerate -U`"
+                f"Please run `pip install transformers[torch]` or `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`"
            )
        # We delay the init of `PartialState` to the end for clarity
        accelerator_state_kwargs = {"enabled": True, "use_configured_state": False}
8 changes: 4 additions & 4 deletions src/transformers/utils/import_utils.py
@@ -297,6 +297,10 @@ def is_torch_available():
     return _torch_available


+def is_accelerate_available(min_version: str = ACCELERATE_MIN_VERSION):
+    return _accelerate_available and version.parse(_accelerate_version) >= version.parse(min_version)
+
+
 def is_torch_deterministic():
     """
     Check whether pytorch uses deterministic algorithms by looking if torch.set_deterministic_debug_mode() is set to 1 or 2
@@ -885,10 +889,6 @@ def is_protobuf_available():
     return importlib.util.find_spec("google.protobuf") is not None


-def is_accelerate_available(min_version: str = ACCELERATE_MIN_VERSION):
-    return _accelerate_available and version.parse(_accelerate_version) >= version.parse(min_version)
-
-
 def is_fsdp_available(min_version: str = FSDP_MIN_VERSION):
     return is_torch_available() and version.parse(_torch_version) >= version.parse(min_version)
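
A quick usage sketch of the relocated helper; the stricter "0.26.0" floor in the second call is an arbitrary illustrative value, not one the commit introduces:

from transformers.utils import is_accelerate_available

# The default call enforces the library-wide minimum, ACCELERATE_MIN_VERSION.
if is_accelerate_available():
    from accelerate import Accelerator

# Call sites with stricter requirements can raise the floor per check.
if is_accelerate_available(min_version="0.26.0"):
    ...
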
