Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make clearer about zero_init requirements #29879

Merged
merged 8 commits into from
Apr 3, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion src/transformers/integrations/deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import copy
import importlib.metadata as importlib_metadata
import importlib.util
import os
import weakref
from functools import partialmethod

Expand Down Expand Up @@ -282,9 +283,19 @@ def unset_hf_deepspeed_config():
_hf_deepspeed_config_weak_ref = None


def is_deepspeed_zero3_enabled():
def is_deepspeed_zero3_enabled(check_accelerate=False):
"""
If `check_accelerate`, will also check if `deepspeed_zero3` has been enabled through
the environment variables setup during `accelerate launch`.
"""
accelerate_zero_stage = int(os.environ.get("ACCELERATE_DEEPSPEED_ZERO_STAGE", -1))
accelerate_zero_init = os.environ.get("ACCELERATE_DEEPSPEED_ZERO3_INIT", "0")
if _hf_deepspeed_config_weak_ref is not None and _hf_deepspeed_config_weak_ref() is not None:
return _hf_deepspeed_config_weak_ref().is_zero3()
# This only gets triggered passively if the user launches code with a configured
# `accelerate launch` without making `TrainingArguments`
elif check_accelerate and accelerate_zero_stage != -1 and accelerate_zero_init != "0":
return True, False
else:
return False

Expand Down
20 changes: 18 additions & 2 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1312,7 +1312,15 @@ def _from_config(cls, config, **kwargs):
torch_dtype=torch_dtype,
)

if is_deepspeed_zero3_enabled():
deepspeed_enabled, accelerate_enabled = is_deepspeed_zero3_enabled(check_accelerate=True)

if deepspeed_enabled:
if not accelerate_enabled:
raise ValueError(
"Detected that you want to use `zero-3` Init, but the environment "
"has not been setup yet. Please create `TrainingArguments` before "
"initializing the model."
)
import deepspeed

logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
Expand Down Expand Up @@ -3386,7 +3394,15 @@ def from_pretrained(
# Instantiate model.
init_contexts = [no_init_weights(_enable=_fast_init)]

if is_deepspeed_zero3_enabled() and not is_quantized:
deepspeed_enabled, accelerate_enabled = is_deepspeed_zero3_enabled(check_accelerate=True)

if deepspeed_enabled and not is_quantized:
if not accelerate_enabled:
raise ValueError(
"Detected that you want to use `zero-3` Init, but the environment "
"has not been setup yet. Please create `TrainingArguments` before "
"initializing the model."
)
import deepspeed

logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
Expand Down
5 changes: 5 additions & 0 deletions src/transformers/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,11 @@ class TrainingArguments:
evolve in the future. The value is either the location of DeepSpeed json config file (e.g.,
`ds_config.json`) or an already loaded json file as a `dict`"

<Tip warning={true}>
If enabling any Zero-init, make sure that your model is not initialized until
*after* initializing the `TrainingArguments`, else it will not be applied.
</Tip>

accelerator_config (`str`, `dict`, or `AcceleratorConfig`, *optional*):
Config to be used with the internal `Accelerator` implementation. The value is either a location of
accelerator json config file (e.g., `accelerator_config.json`), an already loaded json file as `dict`,
Expand Down
Loading