
Update unwrap from accelerate #29933

Merged: 13 commits, Apr 19, 2024
src/transformers/modeling_utils.py (22 additions & 5 deletions)
@@ -109,6 +109,7 @@
     from accelerate.hooks import add_hook_to_module
     from accelerate.utils import (
         check_tied_parameters_on_same_device,
+        extract_model_from_parallel,
         find_tied_parameters,
         get_balanced_memory,
         get_max_memory,
@@ -4805,18 +4806,34 @@ def forward(
         return output


-def unwrap_model(model: nn.Module) -> nn.Module:
+def unwrap_model(model: nn.Module, recursive: bool = False) -> nn.Module:
     """
     Recursively unwraps a model from potential containers (as used in distributed training).

     Args:
         model (`torch.nn.Module`): The model to unwrap.
+        recursive (`bool`, *optional*, defaults to `False`):
+            Whether to recursively extract all cases of `module.module` from `model` as well as unwrap child sublayers
+            recursively, not just the top-level distributed containers.
     """
-    # since there could be multiple levels of wrapping, unwrap recursively
-    if hasattr(model, "module"):
-        return unwrap_model(model.module)
+    # Use accelerate implementation if available (should always be the case when using torch)
+    # This is for pytorch, as we also have to handle things like dynamo
+    if is_accelerate_available():
+        kwargs = {}
+        if recursive:
+            if not is_accelerate_available("0.29.0"):
+                raise RuntimeError(
+                    "Setting `recursive=True` to `unwrap_model` requires `accelerate` v0.29.0. Please upgrade your version of accelerate"
+                )
+            else:
+                kwargs["recursive"] = recursive
+        return extract_model_from_parallel(model, **kwargs)
     else:
-        return model
+        # since there could be multiple levels of wrapping, unwrap recursively
+        if hasattr(model, "module"):
+            return unwrap_model(model.module)
+        else:
+            return model


 def expand_device_map(device_map, param_names, start_prefix):
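For context, a minimal sketch of how the updated helper behaves, assuming `accelerate` is installed (v0.29.0 or newer for the `recursive` path). `TinyNet` is a hypothetical stand-in model; the nesting mimics containers such as `torch.nn.DataParallel` that expose the wrapped model as `.module`:

    import torch.nn as nn

    from transformers.modeling_utils import unwrap_model


    class TinyNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(4, 2)


    model = TinyNet()
    wrapped = nn.DataParallel(nn.DataParallel(model))

    # Top-level containers are peeled off however deeply they are nested.
    assert unwrap_model(wrapped) is model

    # With accelerate >= 0.29.0, `recursive=True` also unwraps wrapped *sublayers*,
    # not just the outermost containers.
    model.linear = nn.DataParallel(model.linear)
    assert isinstance(unwrap_model(model, recursive=True).linear, nn.Linear)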
src/transformers/trainer.py (10 additions & 10 deletions)
@@ -63,7 +63,7 @@
 from .integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available
 from .integrations.tpu import tpu_spmd_dataloader
 from .modelcard import TrainingSummary
-from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
+from .modeling_utils import PreTrainedModel, load_sharded_checkpoint
 from .models.auto.modeling_auto import (
     MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
     MODEL_MAPPING_NAMES,
@@ -684,7 +684,7 @@ def _activate_neftune(self, model):
         Activates the neftune as presented in this code: https://github.com/neelsjain/NEFTune and paper:
         https://arxiv.org/abs/2310.05914
         """
-        unwrapped_model = unwrap_model(model)
+        unwrapped_model = self.accelerator.unwrap_model(model)

         if _is_peft_model(unwrapped_model):
             embeddings = unwrapped_model.base_model.model.get_input_embeddings()
@@ -705,7 +705,7 @@ def _deactivate_neftune(self, model):
         if not hasattr(self, "neftune_hook_handle"):
             raise ValueError("Neftune is not activated make sure to call `trainer._activate_neftune()` first")

-        unwrapped_model = unwrap_model(model)
+        unwrapped_model = self.accelerator.unwrap_model(model)

         if _is_peft_model(unwrapped_model):
             embeddings = unwrapped_model.base_model.model.get_input_embeddings()
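Both helpers unwrap before touching the embeddings because the hook must attach to the real embedding layer, not to a distributed wrapper. A rough sketch of the hook pattern they manage (not the Trainer's exact code; the bare `nn.Embedding` and the alpha value are placeholders):

    import torch
    import torch.nn as nn


    def noisy_embedding_hook(module, inputs, output):
        # NEFTune-style noise: only in training mode, scaled by alpha / sqrt(seq_len * hidden_dim).
        if module.training:
            dims = torch.tensor(output.size(1) * output.size(2))
            magnitude = module.neftune_noise_alpha / torch.sqrt(dims)
            output = output + torch.zeros_like(output).uniform_(-magnitude, magnitude)
        return output


    embeddings = nn.Embedding(100, 8)
    embeddings.neftune_noise_alpha = 0.1  # placeholder alpha
    handle = embeddings.register_forward_hook(noisy_embedding_hook)

    out = embeddings(torch.randint(0, 100, (2, 5)))  # noise added while training

    handle.remove()  # mirrors what `_deactivate_neftune` does with the stored handle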
@@ -1617,7 +1617,7 @@ def _wrap_model(self, model, training=True, dataloader=None):
             return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps)

         # train/eval could be run multiple-times - if already wrapped, don't re-wrap it again
-        if unwrap_model(model) is not model:
+        if self.accelerator.unwrap_model(model) is not model:
             return model

         # Mixed precision training with apex (torch < 1.6)
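The guard above works because unwrapping an already-wrapped model returns a different object. A minimal sketch of that check, using accelerate's underlying helper directly, which `Accelerator.unwrap_model` calls under the hood (`net` is a placeholder module):

    import torch.nn as nn

    from accelerate.utils import extract_model_from_parallel

    net = nn.Linear(4, 2)
    wrapped = nn.DataParallel(net)

    # Unwrapping a bare model is the identity, so it is safe to wrap it.
    assert extract_model_from_parallel(net) is net

    # Unwrapping a container yields the inner model, so the Trainer returns
    # early instead of wrapping a second time.
    assert extract_model_from_parallel(wrapped) is not wrapped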
@@ -3165,7 +3165,7 @@ def compute_loss(self, model, inputs, return_outputs=False):
             self._past = outputs[self.args.past_index]

         if labels is not None:
-            unwrapped_model = unwrap_model(model)
+            unwrapped_model = self.accelerator.unwrap_model(model)
             if _is_peft_model(unwrapped_model):
                 model_name = unwrapped_model.base_model.model._get_name()
             else:
@@ -3272,8 +3272,8 @@ def _save_tpu(self, output_dir: Optional[str] = None):
         supported_classes = (PushToHubMixin,)
         xm.rendezvous("saving_checkpoint")
         if not isinstance(model, supported_classes):
-            if isinstance(unwrap_model(model), supported_classes):
-                unwrap_model(model).save_pretrained(
+            if isinstance(self.accelerator.unwrap_model(model), supported_classes):
+                self.accelerator.unwrap_model(model).save_pretrained(
                     output_dir,
                     is_main_process=self.args.should_save,
                     state_dict=model.state_dict(),
@@ -3311,8 +3311,8 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None):
         if state_dict is None:
             state_dict = self.model.state_dict()

-        if isinstance(unwrap_model(self.model), supported_classes):
-            unwrap_model(self.model).save_pretrained(
+        if isinstance(self.accelerator.unwrap_model(self.model), supported_classes):
+            self.accelerator.unwrap_model(self.model).save_pretrained(
                 output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
             )
         else:
@@ -3969,7 +3969,7 @@ def create_model_card(
                 f.write(model_card)

         if is_peft_library:
-            unwrap_model(self.model).create_or_update_model_card(self.args.output_dir)
+            self.accelerator.unwrap_model(self.model).create_or_update_model_card(self.args.output_dir)

     def _push_from_checkpoint(self, checkpoint_folder):
         # Only push from one node.
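Taken together, the saving paths now follow a single pattern: let the `Accelerator` strip whatever containers `prepare` added, then call `save_pretrained` on the underlying model. A minimal sketch of that pattern; the checkpoint name and output path are placeholders:

    from accelerate import Accelerator
    from transformers import AutoModelForSequenceClassification

    accelerator = Accelerator()
    model = accelerator.prepare(
        AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")
    )

    # `prepare` may have wrapped the model (e.g. in DDP); `save_pretrained`
    # lives on the underlying `PreTrainedModel`, so unwrap first.
    unwrapped = accelerator.unwrap_model(model)
    unwrapped.save_pretrained("output_dir", safe_serialization=True)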
tests/trainer/test_trainer.py (4 additions & 3 deletions)
@@ -123,7 +123,6 @@
     Trainer,
     TrainerState,
 )
-from transformers.modeling_utils import unwrap_model
 from transformers.trainer_pt_utils import AcceleratorConfig

 if is_safetensors_available():
@@ -2468,8 +2467,10 @@ def test_flos_extraction(self):
         trainer = get_regression_trainer(learning_rate=0.1)

         def assert_flos_extraction(trainer, wrapped_model_to_check):
-            self.assertEqual(trainer.model, unwrap_model(wrapped_model_to_check))
-            self.assertGreaterEqual(getattr(unwrap_model(wrapped_model_to_check).config, "total_flos", 0), 0)
+            self.assertEqual(trainer.model, trainer.accelerator.unwrap_model(wrapped_model_to_check))
+            self.assertGreaterEqual(
+                getattr(trainer.accelerator.unwrap_model(wrapped_model_to_check).config, "total_flos", 0), 0
+            )

         # with plain model
         assert_flos_extraction(trainer, trainer.model)