Skip to content

Commit

Permalink
[from_pretrained] Make from_pretrained fast again (#27709)
Browse files Browse the repository at this point in the history
* Skip nn.Module.reset_parameters

* Actually skip

* Check quality

* Maybe change all inits

* Fix init issues: only modify public functions

* Add a small test for now

* Style

* test updates

* style

* nice tes

* style

* make it even faster

* one more second

* remove fx icompatible

* Update tests/test_modeling_common.py

Co-authored-by: Lysandre Debut <hi@lysand.re>

* Update tests/test_modeling_common.py

Co-authored-by: Lysandre Debut <hi@lysand.re>

* skip

* fix quality

* protect the import

---------

Co-authored-by: Lysandre Debut <hi@lysand.re>
  • Loading branch information
ArthurZucker and LysandreJik authored Dec 11, 2023
1 parent 9f18cc6 commit 0676d99
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 2 deletions.
35 changes: 34 additions & 1 deletion src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,23 @@ def is_local_dist_rank_0():
if is_peft_available():
from .utils import find_adapter_config_file

TORCH_INIT_FUNCTIONS = {
"uniform_": nn.init.uniform_,
"normal_": nn.init.normal_,
"trunc_normal_": nn.init.trunc_normal_,
"constant_": nn.init.constant_,
"xavier_uniform_": nn.init.xavier_uniform_,
"xavier_normal_": nn.init.xavier_normal_,
"kaiming_uniform_": nn.init.kaiming_uniform_,
"kaiming_normal_": nn.init.kaiming_normal_,
"uniform": nn.init.uniform,
"normal": nn.init.normal,
"xavier_uniform": nn.init.xavier_uniform,
"xavier_normal": nn.init.xavier_normal,
"kaiming_uniform": nn.init.kaiming_uniform,
"kaiming_normal": nn.init.kaiming_normal,
}


@contextmanager
def no_init_weights(_enable=True):
Expand All @@ -164,12 +181,24 @@ def no_init_weights(_enable=True):
"""
global _init_weights
old_init_weights = _init_weights

if _enable:
_init_weights = False

def _skip_init(*args, **kwargs):
pass

# # Save the original initialization functions
for name, init_func in TORCH_INIT_FUNCTIONS.items():
setattr(torch.nn.init, name, _skip_init)
try:
yield
finally:
_init_weights = old_init_weights
if _enable:
# # Restore the original initialization functions
for name, init_func in TORCH_INIT_FUNCTIONS.items():
setattr(torch.nn.init, name, init_func)


def get_parameter_device(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]):
Expand Down Expand Up @@ -1506,7 +1535,10 @@ def get_output_embeddings(self) -> nn.Module:

def _init_weights(self, module):
"""
Initialize the weights. This method should be overridden by derived class.
Initialize the weights. This method should be overridden by derived class and is
the only initialization method that will be called when loading a checkpoint
using `from_pretrained`. Any attempt to initialize outside of this function
will be useless as the torch.nn.init function are all replaced with skip.
"""
pass

Expand Down Expand Up @@ -3414,6 +3446,7 @@ def from_pretrained(
)

with ContextManagers(init_contexts):
# Let's make sure we don't run the init function of buffer modules
model = cls(config, *model_args, **model_kwargs)

# make sure we use the model's config since the __init__ call might have copied it
Expand Down
55 changes: 54 additions & 1 deletion tests/test_modeling_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,10 @@
AutoModelForCausalLM,
AutoModelForSequenceClassification,
PretrainedConfig,
PreTrainedModel,
is_torch_available,
logging,
set_seed,
)
from transformers.models.auto import get_values
from transformers.models.auto.modeling_auto import (
Expand Down Expand Up @@ -85,7 +87,7 @@
is_torch_fx_available,
is_torch_sdpa_available,
)
from transformers.utils.generic import ModelOutput
from transformers.utils.generic import ContextManagers, ModelOutput


if is_accelerate_available():
Expand All @@ -99,6 +101,7 @@
from torch import nn

from transformers import MODEL_MAPPING, AdaptiveEmbedding
from transformers.modeling_utils import no_init_weights
from transformers.pytorch_utils import id_tensor_storage


Expand Down Expand Up @@ -428,6 +431,56 @@ class CopyClass(model_class):
max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")

def test_fast_init_context_manager(self):
# 1. Create a dummy class. Should have buffers as well? To make sure we test __init__
class MyClass(PreTrainedModel):
config_class = PretrainedConfig

def __init__(self, config=None):
super().__init__(config if config is not None else PretrainedConfig())
self.linear = nn.Linear(10, 10, bias=True)
self.embedding = nn.Embedding(10, 10)
self.std = 1

def _init_weights(self, module):
if isinstance(module, nn.Linear):
module.weight.data = nn.init.kaiming_uniform_(module.weight.data, np.sqrt(5))
if module.bias is not None:
module.bias.data.normal_(mean=0.0, std=self.std)

# 2. Make sure a linear layer's reset params is properly skipped:
with ContextManagers([no_init_weights(True)]):
no_init_instance = MyClass()

set_seed(0)
expected_bias = torch.tensor(
([0.2975, 0.2131, -0.1379, -0.0796, -0.3012, -0.0057, -0.2381, -0.2439, -0.0174, 0.0475])
)
init_instance = MyClass()
torch.testing.assert_allclose(init_instance.linear.bias, expected_bias, rtol=1e-3, atol=1e-4)

set_seed(0)
torch.testing.assert_allclose(
init_instance.linear.weight, nn.init.kaiming_uniform_(no_init_instance.linear.weight, np.sqrt(5))
)

# 3. Make sure weights that are not present use init_weight_ and get expected values
with tempfile.TemporaryDirectory() as tmpdirname:
state_dict = init_instance.state_dict()
del state_dict["linear.weight"]

init_instance.config.save_pretrained(tmpdirname)
torch.save(state_dict, os.path.join(tmpdirname, "pytorch_model.bin"))
set_seed(0)
model_fast_init = MyClass.from_pretrained(tmpdirname)

set_seed(0)
model_slow_init = MyClass.from_pretrained(tmpdirname, _fast_init=False)

for key in model_fast_init.state_dict().keys():
max_diff = torch.max(torch.abs(model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]))
self.assertLessEqual(max_diff.item(), 1e-3, msg=f"{key} not identical")

def test_save_load_fast_init_to_base(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
if config.__class__ not in MODEL_MAPPING:
Expand Down

0 comments on commit 0676d99

Please sign in to comment.