Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

move fsdp handling to accelerate #23158

Merged
merged 13 commits into from
May 31, 2023
207 changes: 83 additions & 124 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@
from accelerate import skip_first_batches

from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs


if TYPE_CHECKING:
Expand Down Expand Up @@ -342,6 +343,12 @@ def __init__(
# create accelerator object
self.accelerator = Accelerator()

# post accelerator creation setup
if getattr(self.accelerator.state, "fsdp_plugin", None) is not None:
fsdp_plugin = self.accelerator.state.fsdp_plugin
fsdp_plugin.limit_all_gathers = self.args.fsdp_config.get("limit_all_gathers", False)
fsdp_plugin.use_orig_params = self.args.fsdp_config.get("use_orig_params", False)

# memory metrics - must set up as early as possible
self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics)
self._memory_tracker.start()
Expand Down Expand Up @@ -463,7 +470,7 @@ def __init__(
self.fsdp = ShardingStrategy.NO_SHARD

self.backward_prefetch = BackwardPrefetch.BACKWARD_PRE
if "backward_prefetch" in self.args.fsdp_config and "backward_pos" in self.args.fsdp_config.get(
if "backward_prefetch" in self.args.fsdp_config and "backward_post" in self.args.fsdp_config.get(
"backward_prefetch", []
):
self.backward_prefetch = BackwardPrefetch.BACKWARD_POST
Expand Down Expand Up @@ -1478,119 +1485,65 @@ def _wrap_model(self, model, training=True, dataloader=None):
cpu_offload=cpu_offload,
).to(self.args.device)
# Distributed training using PyTorch FSDP
elif self.fsdp is not None:
if not self.args.fsdp_config["xla"]:
# PyTorch FSDP!
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy

if FSDPOption.OFFLOAD in self.args.fsdp:
cpu_offload = CPUOffload(offload_params=True)
else:
cpu_offload = CPUOffload(offload_params=False)

auto_wrap_policy = None

if FSDPOption.AUTO_WRAP in self.args.fsdp:
if self.args.fsdp_config["fsdp_min_num_params"] > 0:
auto_wrap_policy = functools.partial(
size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"]
)
elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
transformer_cls_to_wrap = set()
for layer_class in self.args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]:
transformer_cls = get_module_class_from_name(model, layer_class)
if transformer_cls is None:
raise Exception("Could not find the transformer layer class to wrap in the model.")
else:
transformer_cls_to_wrap.add(transformer_cls)
auto_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
# Transformer layer class to wrap
transformer_layer_cls=transformer_cls_to_wrap,
)
mixed_precision_policy = None
dtype = None
if self.args.fp16:
dtype = torch.float16
elif self.args.bf16:
dtype = torch.bfloat16
if dtype is not None:
mixed_precision_policy = MixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype)
if type(model) != FSDP:
# XXX: Breaking the self.model convention but I see no way around it for now.
signature = inspect.signature(FSDP.__init__).parameters.keys()
kwargs = {}
for arg in ["limit_all_gathers", "forward_prefetch", "backward_prefetch"]:
if arg in signature:
kwargs[arg] = getattr(self, arg)
self.model = model = FSDP(
model,
sharding_strategy=self.fsdp,
cpu_offload=cpu_offload,
auto_wrap_policy=auto_wrap_policy,
mixed_precision=mixed_precision_policy,
device_id=self.args.device,
**kwargs,
)
else:
try:
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
from torch_xla.distributed.fsdp import checkpoint_module
from torch_xla.distributed.fsdp.wrap import (
size_based_auto_wrap_policy,
transformer_auto_wrap_policy,
)
except ImportError:
raise ImportError("Missing XLA FSDP related module; please make sure to use torch-xla >= 2.0.")
auto_wrap_policy = None
auto_wrapper_callable = None
if self.args.fsdp_config["fsdp_min_num_params"] > 0:
auto_wrap_policy = functools.partial(
size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"]
)
elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
transformer_cls_to_wrap = set()
for layer_class in self.args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]:
transformer_cls = get_module_class_from_name(model, layer_class)
if transformer_cls is None:
raise Exception("Could not find the transformer layer class to wrap in the model.")
else:
transformer_cls_to_wrap.add(transformer_cls)
auto_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
# Transformer layer class to wrap
transformer_layer_cls=transformer_cls_to_wrap,
)
fsdp_kwargs = self.args.xla_fsdp_config
if self.args.fsdp_config["xla_fsdp_grad_ckpt"]:
# Apply gradient checkpointing to auto-wrapped sub-modules if specified
def auto_wrapper_callable(m, *args, **kwargs):
return FSDP(checkpoint_module(m), *args, **kwargs)

# Wrap the base model with an outer FSDP wrapper
self.model = model = FSDP(
model,
auto_wrap_policy=auto_wrap_policy,
auto_wrapper_callable=auto_wrapper_callable,
**fsdp_kwargs,
elif self.fsdp is not None and self.args.fsdp_config["xla"]:
try:
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
from torch_xla.distributed.fsdp import checkpoint_module
from torch_xla.distributed.fsdp.wrap import (
size_based_auto_wrap_policy,
transformer_auto_wrap_policy,
)
except ImportError:
raise ImportError("Missing XLA FSDP related module; please make sure to use torch-xla >= 2.0.")
auto_wrap_policy = None
auto_wrapper_callable = None
if self.args.fsdp_config["fsdp_min_num_params"] > 0:
auto_wrap_policy = functools.partial(
size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"]
)
elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
transformer_cls_to_wrap = set()
for layer_class in self.args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]:
transformer_cls = get_module_class_from_name(model, layer_class)
if transformer_cls is None:
raise Exception("Could not find the transformer layer class to wrap in the model.")
else:
transformer_cls_to_wrap.add(transformer_cls)
auto_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
# Transformer layer class to wrap
transformer_layer_cls=transformer_cls_to_wrap,
)
fsdp_kwargs = self.args.xla_fsdp_config
if self.args.fsdp_config["xla_fsdp_grad_ckpt"]:
# Apply gradient checkpointing to auto-wrapped sub-modules if specified
def auto_wrapper_callable(m, *args, **kwargs):
return FSDP(checkpoint_module(m), *args, **kwargs)

# Wrap the base model with an outer FSDP wrapper
self.model = model = FSDP(
model,
auto_wrap_policy=auto_wrap_policy,
auto_wrapper_callable=auto_wrapper_callable,
**fsdp_kwargs,
)

# Patch `xm.optimizer_step` should not reduce gradients in this case,
# as FSDP does not need gradient reduction over sharded parameters.
def patched_optimizer_step(optimizer, barrier=False, optimizer_args={}):
loss = optimizer.step(**optimizer_args)
if barrier:
xm.mark_step()
return loss
# Patch `xm.optimizer_step` should not reduce gradients in this case,
# as FSDP does not need gradient reduction over sharded parameters.
def patched_optimizer_step(optimizer, barrier=False, optimizer_args={}):
loss = optimizer.step(**optimizer_args)
if barrier:
xm.mark_step()
return loss

xm.optimizer_step = patched_optimizer_step
xm.optimizer_step = patched_optimizer_step
elif is_sagemaker_dp_enabled():
model = nn.parallel.DistributedDataParallel(
model, device_ids=[int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))]
)
elif self.args.parallel_mode == ParallelMode.DISTRIBUTED:
if is_torch_neuroncore_available():
return model
kwargs = {}
if self.args.ddp_find_unused_parameters is not None:
kwargs["find_unused_parameters"] = self.args.ddp_find_unused_parameters
Expand All @@ -1603,15 +1556,8 @@ def patched_optimizer_step(optimizer, barrier=False, optimizer_args={}):

if self.args.ddp_bucket_cap_mb is not None:
kwargs["bucket_cap_mb"] = self.args.ddp_bucket_cap_mb
if is_torch_neuroncore_available():
return model
if any(p.requires_grad for p in model.parameters()):
model = nn.parallel.DistributedDataParallel(
model,
device_ids=[self.args.local_rank] if self.args._n_gpu != 0 else None,
output_device=self.args.local_rank if self.args._n_gpu != 0 else None,
**kwargs,
)

self.accelerator.ddp_handler = DistributedDataParallelKwargs(**kwargs)

# torch.compile() needs to be called after wrapping the model with FSDP or DDP
# to ensure that it accounts for the graph breaks required by those wrappers
Expand Down Expand Up @@ -1800,17 +1746,26 @@ def _inner_training_loop(
if is_sagemaker_mp_enabled() and resume_from_checkpoint is not None:
self._load_from_checkpoint(resume_from_checkpoint, model)

# for the rest of this function `model` is the outside model, whether it was wrapped or not
if model is not self.model:
self.model_wrapped = model
# as the model is wrapped, don't use `accelerator.prepare`
# this is for unhandled cases such as
# Fairscale Sharded DDP, FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX
use_accelerator_prepare = True if model is self.model else False

if delay_optimizer_creation:
self.create_optimizer_and_scheduler(num_training_steps=max_steps)

# prepare using `accelerator` prepare
model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
self.model, self.optimizer, self.lr_scheduler
)
if use_accelerator_prepare:
model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
self.model, self.optimizer, self.lr_scheduler
)

if getattr(self.accelerator.state, "fsdp_plugin", None) is not None:
self.model = model

# for the rest of this function `model` is the outside model, whether it was wrapped or not
if model is not self.model:
self.model_wrapped = model

# Check if saved optimizer or scheduler states exist
self._load_optimizer_and_scheduler(resume_from_checkpoint)
Expand Down Expand Up @@ -2898,11 +2853,15 @@ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = Fa
ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp
or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp
or self.fsdp is not None
or getattr(self.accelerator.state, "fsdp_plugin", None) is not None
):
state_dict = self.model.state_dict()
if getattr(self.accelerator.state, "fsdp_plugin", None) is not None:
self.accelerator.state.fsdp_plugin.save_model(self.accelerator, self.model, output_dir)
else:
state_dict = self.model.state_dict()

if self.args.should_save:
self._save(output_dir, state_dict=state_dict)
if self.args.should_save:
self._save(output_dir, state_dict=state_dict)
elif self.deepspeed:
# this takes care of everything as long as we aren't under zero3
if self.args.should_save:
Expand Down
28 changes: 27 additions & 1 deletion src/transformers/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,7 @@ class TrainingArguments:
- `"backward_pre"` : Prefetches the next set of parameters before the current set of parameter's
gradient
computation.
- `"backward_pos"` : This prefetches the next set of parameters after the current set of
- `"backward_post"` : This prefetches the next set of parameters after the current set of
parameter’s
gradient computation.
- fsdp_forward_prefetch (`bool`, *optional*, defaults to `False`)
Expand Down Expand Up @@ -1504,6 +1504,32 @@ def __post_init__(self):
if self.fsdp_config["xla_fsdp_grad_ckpt"]:
warnings.warn("`--xla_fsdp_grad_ckpt` is useful only when `--xla` is set to true.")

# accelerate integration for FSDP
if len(self.fsdp) > 0 and not self.fsdp_config["xla"]:
os.environ["ACCELERATE_USE_FSDP"] = "true"
from accelerate.utils.constants import (
FSDP_AUTO_WRAP_POLICY,
FSDP_SHARDING_STRATEGY,
)

for fsdp_option in self.fsdp:
if fsdp_option.upper() in FSDP_SHARDING_STRATEGY:
# set environment variable for FSDP sharding strategy
os.environ["FSDP_SHARDING_STRATEGY"] = str(FSDP_SHARDING_STRATEGY.index(fsdp_option.upper()) + 1)
elif fsdp_option == FSDPOption.OFFLOAD:
os.environ["FSDP_OFFLOAD_PARAMS"] = "true"
elif fsdp_option == FSDPOption.AUTO_WRAP:
if self.fsdp_config["fsdp_min_num_params"] > 0:
os.environ["FSDP_MIN_NUM_PARAMS"] = str(self.fsdp_config["fsdp_min_num_params"])
os.environ["FSDP_AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[1]
elif self.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
os.environ["FSDP_TRANSFORMER_CLS_TO_WRAP"] = ",".join(
self.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]
)
os.environ["FSDP_AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[0]
prefetch_policy = self.fsdp_config.get("fsdp_backward_prefetch", "NO_PREFETCH")
os.environ["FSDP_BACKWARD_PREFETCH"] = prefetch_policy.upper()

if self.tpu_metrics_debug:
warnings.warn(
"using `--tpu_metrics_debug` is deprecated and will be removed in version 5 of 🤗 Transformers. Use"
Expand Down