Update unwrap from accelerate #29933

Merged (13 commits, Apr 19, 2024)
src/transformers/modeling_utils.py (13 additions & 1 deletion)

@@ -109,6 +109,7 @@
    from accelerate.hooks import add_hook_to_module
    from accelerate.utils import (
        check_tied_parameters_on_same_device,
        extract_model_from_parallel,
        find_tied_parameters,
        get_balanced_memory,
        get_max_memory,
@@ -4805,13 +4806,24 @@ def forward(
        return output


def unwrap_model(model: nn.Module) -> nn.Module:
def unwrap_model(model: nn.Module, recursive: bool = False) -> nn.Module:
    """
    Recursively unwraps a model from potential containers (as used in distributed training).

    Args:
        model (`torch.nn.Module`): The model to unwrap.
        recursive (`bool`, *optional*, defaults to `False`):
            Whether to also unwrap child sublayers recursively, not just the top-level distributed containers.
    """
    if is_accelerate_available():
        kwargs = {}
        if version.parse(importlib.metadata.version("accelerate")) >= version.parse("0.29.0"):
            kwargs["recursive"] = recursive
        # Needs accelerate >= 0.29.0 if one uses recursive=True
        elif recursive:
            logger.error(
                "Using recursive=True in unwrap_model requires a version of accelerate >= 0.29.0. Please upgrade your version of accelerate."
            )
        return extract_model_from_parallel(model, **kwargs)

    # since there could be multiple levels of wrapping, unwrap recursively
    if hasattr(model, "module"):
        return unwrap_model(model.module)
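
Note: a minimal standalone sketch of the version-gating pattern this hunk introduces, mirroring the diff above. `unwrap_model_sketch` is a hypothetical name for illustration; it assumes `torch` and `packaging` are installed, and falls back to manual unwrapping when `accelerate` is absent:

```python
import importlib.metadata

from packaging import version
from torch import nn


def unwrap_model_sketch(model: nn.Module, recursive: bool = False) -> nn.Module:
    """Hypothetical helper mirroring the diff: delegate to accelerate when present."""
    try:
        from accelerate.utils import extract_model_from_parallel
    except ImportError:
        extract_model_from_parallel = None

    if extract_model_from_parallel is not None:
        kwargs = {}
        # `recursive` only exists in accelerate >= 0.29.0, so gate it on the installed version.
        if version.parse(importlib.metadata.version("accelerate")) >= version.parse("0.29.0"):
            kwargs["recursive"] = recursive
        return extract_model_from_parallel(model, **kwargs)

    # Fallback: peel `.module` wrappers (e.g. DistributedDataParallel) by hand.
    if hasattr(model, "module"):
        return unwrap_model_sketch(model.module)
    return model
```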
src/transformers/trainer.py (10 additions & 10 deletions)

@@ -63,7 +63,7 @@
from .integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available
from .integrations.tpu import tpu_spmd_dataloader
from .modelcard import TrainingSummary
from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
from .modeling_utils import PreTrainedModel, load_sharded_checkpoint
from .models.auto.modeling_auto import (
    MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
    MODEL_MAPPING_NAMES,
@@ -684,7 +684,7 @@ def _activate_neftune(self, model):
        Activates the neftune as presented in this code: https://github.com/neelsjain/NEFTune and paper:
        https://arxiv.org/abs/2310.05914
        """
        unwrapped_model = unwrap_model(model)
        unwrapped_model = self.accelerator.unwrap_model(model)

        if _is_peft_model(unwrapped_model):
            embeddings = unwrapped_model.base_model.model.get_input_embeddings()
@@ -705,7 +705,7 @@ def _deactivate_neftune(self, model):
        if not hasattr(self, "neftune_hook_handle"):
            raise ValueError("Neftune is not activated make sure to call `trainer._activate_neftune()` first")

        unwrapped_model = unwrap_model(model)
        unwrapped_model = self.accelerator.unwrap_model(model)

        if _is_peft_model(unwrapped_model):
            embeddings = unwrapped_model.base_model.model.get_input_embeddings()
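
For reference, `Accelerator.unwrap_model` is the accelerate-provided equivalent that the trainer now calls instead of the module-level helper. A minimal usage sketch (single-process, CPU is fine):

```python
from accelerate import Accelerator
from torch import nn

accelerator = Accelerator()
model = nn.Linear(4, 4)

# prepare() may wrap the model (e.g. in DistributedDataParallel under a
# multi-GPU launch); unwrap_model() recovers the underlying module either way.
prepared = accelerator.prepare(model)
unwrapped = accelerator.unwrap_model(prepared)
print(unwrapped is model)  # True in a plain single-process run
```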
@@ -1617,7 +1617,7 @@ def _wrap_model(self, model, training=True, dataloader=None):
            return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps)

        # train/eval could be run multiple times - if already wrapped, don't re-wrap it again
        if unwrap_model(model) is not model:
        if self.accelerator.unwrap_model(model) is not model:
            return model

        # Mixed precision training with apex (torch < 1.6)
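
The guard above relies on unwrapping being a no-op for bare models. A small sketch of why the identity check detects an already-wrapped model (assumes a single-process run; `nn.DataParallel` stores the inner module even on CPU):

```python
from accelerate import Accelerator
from torch import nn

accelerator = Accelerator()
model = nn.Linear(2, 2)

# A bare model unwraps to itself, so `unwrap_model(model) is not model` is
# False and _wrap_model proceeds to wrap it.
assert accelerator.unwrap_model(model) is model

# A wrapped model unwraps to the inner module, the identity check trips,
# and _wrap_model returns early instead of wrapping twice.
wrapped = nn.DataParallel(model)
assert accelerator.unwrap_model(wrapped) is not wrapped
assert accelerator.unwrap_model(wrapped) is model
```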
@@ -3165,7 +3165,7 @@ def compute_loss(self, model, inputs, return_outputs=False):
            self._past = outputs[self.args.past_index]

        if labels is not None:
            unwrapped_model = unwrap_model(model)
            unwrapped_model = self.accelerator.unwrap_model(model)
            if _is_peft_model(unwrapped_model):
                model_name = unwrapped_model.base_model.model._get_name()
            else:
@@ -3272,8 +3272,8 @@ def _save_tpu(self, output_dir: Optional[str] = None):
        supported_classes = (PushToHubMixin,)
        xm.rendezvous("saving_checkpoint")
        if not isinstance(model, supported_classes):
            if isinstance(unwrap_model(model), supported_classes):
                unwrap_model(model).save_pretrained(
            if isinstance(self.accelerator.unwrap_model(model), supported_classes):
                self.accelerator.unwrap_model(model).save_pretrained(
                    output_dir,
                    is_main_process=self.args.should_save,
                    state_dict=model.state_dict(),
@@ -3311,8 +3311,8 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None):
            if state_dict is None:
                state_dict = self.model.state_dict()

            if isinstance(unwrap_model(self.model), supported_classes):
                unwrap_model(self.model).save_pretrained(
            if isinstance(self.accelerator.unwrap_model(self.model), supported_classes):
                self.accelerator.unwrap_model(self.model).save_pretrained(
                    output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
                )
            else:
@@ -3969,7 +3969,7 @@ def create_model_card(
            f.write(model_card)

        if is_peft_library:
            unwrap_model(self.model).create_or_update_model_card(self.args.output_dir)
            self.accelerator.unwrap_model(self.model).create_or_update_model_card(self.args.output_dir)

    def _push_from_checkpoint(self, checkpoint_folder):
        # Only push from one node.
tests/trainer/test_trainer.py (4 additions & 3 deletions)

@@ -123,7 +123,6 @@
    Trainer,
    TrainerState,
)
from transformers.modeling_utils import unwrap_model
from transformers.trainer_pt_utils import AcceleratorConfig

if is_safetensors_available():
@@ -2468,8 +2467,10 @@ def test_flos_extraction(self):
        trainer = get_regression_trainer(learning_rate=0.1)

        def assert_flos_extraction(trainer, wrapped_model_to_check):
            self.assertEqual(trainer.model, unwrap_model(wrapped_model_to_check))
            self.assertGreaterEqual(getattr(unwrap_model(wrapped_model_to_check).config, "total_flos", 0), 0)
            self.assertEqual(trainer.model, trainer.accelerator.unwrap_model(wrapped_model_to_check))
            self.assertGreaterEqual(
                getattr(trainer.accelerator.unwrap_model(wrapped_model_to_check).config, "total_flos", 0), 0
            )

        # with plain model
        assert_flos_extraction(trainer, trainer.model)