From afb02ba36580e47d7bfa0c661ffee1aa302d303f Mon Sep 17 00:00:00 2001
From: Christopher Keibel
Date: Sun, 24 Mar 2024 11:39:44 +0100
Subject: [PATCH 1/6] add functions to get number of params which require grad, get optimizer group for parameters and get learning rates of param groups to trainer.py

---
 src/transformers/trainer.py | 22 ++++++++++++++++++++++
 1 file changed, 22 insertions(+)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 276c08788a13..1e69fadbac3c 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1048,6 +1048,28 @@ def create_optimizer(self):
 
         return self.optimizer
 
+    def get_num_trainable_parameters(self):
+        """
+        Get the number of trainable parameters.
+        """
+        return sum(p.numel() for p in self.model.parameters() if p.requires_grad)
+
+    def get_learning_rates(self):
+        """
+        Returns the learning rate of each parameter from self.optimizer.
+        """
+        return [group["lr"] for group in self.optimizer.param_groups]
+
+    def get_optimizer_group(self, param: Optional[Union[str, torch.nn.parameter.Parameter]] = None):
+        """
+        Returns optimizer group for a parameter if given, else returns all optimizer groups for params.
+        """
+        if param is not None:
+            for group in self.optimizer.param_groups:
+                if param in group["params"]:
+                    return group
+        return [group["params"] for group in self.optimizer.param_groups]
+
     @staticmethod
     def get_optimizer_cls_and_kwargs(
         args: TrainingArguments, model: Optional[PreTrainedModel] = None

From b0c39dcc7e89d3c105eedcb4f62400c46a82a25c Mon Sep 17 00:00:00 2001
From: Christopher Keibel
Date: Mon, 25 Mar 2024 23:43:38 +0100
Subject: [PATCH 2/6] add tests and raise ValueError when optimizer is None

---
 src/transformers/trainer.py   |  8 ++++++++
 tests/trainer/test_trainer.py | 34 ++++++++++++++++++++++++++++++++++
 2 files changed, 42 insertions(+)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 1e69fadbac3c..85655024bf5c 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1058,12 +1058,20 @@ def get_learning_rates(self):
         """
         Returns the learning rate of each parameter from self.optimizer.
         """
+        if self.optimizer is None:
+            raise ValueError("Trainer optimizer is None, please make sure you have setup the optimizer before.")
         return [group["lr"] for group in self.optimizer.param_groups]
 
     def get_optimizer_group(self, param: Optional[Union[str, torch.nn.parameter.Parameter]] = None):
         """
         Returns optimizer group for a parameter if given, else returns all optimizer groups for params.
+
+        Args:
+            param (`str` or `torch.nn.parameter.Parameter`, *optional*):
+                The parameter for which optimizer group needs to be returned.
""" + if self.optimizer is None: + raise ValueError("Trainer optimizer is None, please make sure you have setup the optimizer before.") if param is not None: for group in self.optimizer.param_groups: if param in group["params"]: diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index ebc628146b96..28c1f991a7b3 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -3769,3 +3769,37 @@ def test_hyperparameter_search_backends(self): list(ALL_HYPERPARAMETER_SEARCH_BACKENDS.keys()), list(HPSearchBackend), ) + + +class OptimizerAndModelInspectionTest(unittest.TestCase): + def test_get_num_trainable_parameters(self): + in_features = 128 + out_features = 64 + # in_features * out_features + bias + expected_num_params = in_features * out_features + out_features + model = nn.Sequential(nn.Linear(in_features, out_features)) + trainer = Trainer(model=model) + self.assertEqual(trainer.get_num_trainable_parameters(), expected_num_params) + + def test_get_learning_rates(self): + model = nn.Sequential(nn.Linear(128, 64)) + trainer = Trainer(model=model) + with self.assertRaises(ValueError): + trainer.get_learning_rates() + trainer.create_optimizer() + self.assertEqual(trainer.get_learning_rates(), [5e-05, 5e-05]) + + def test_get_optimizer_group(self): + model = nn.Sequential(nn.Linear(128, 64)) + trainer = Trainer(model=model) + # ValueError is raised if optimizer is None + with self.assertRaises(ValueError): + trainer.get_optimizer_group() + trainer.create_optimizer() + # Get groups + num_groups = len(trainer.get_optimizer_group()) + self.assertEqual(num_groups, 2) + # Get group of parameter + param = next(model.parameters()) + group = trainer.get_optimizer_group(param) + self.assertIn(param, group["params"]) From d7fac7e411024b3f354ade84b89cb044eaddffdc Mon Sep 17 00:00:00 2001 From: Christopher Keibel Date: Tue, 26 Mar 2024 08:23:07 +0100 Subject: [PATCH 3/6] add second layer to test and freeze its weigths --- tests/trainer/test_trainer.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 28c1f991a7b3..da60f842c138 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -3773,13 +3773,16 @@ def test_hyperparameter_search_backends(self): class OptimizerAndModelInspectionTest(unittest.TestCase): def test_get_num_trainable_parameters(self): - in_features = 128 - out_features = 64 + model = nn.Sequential(nn.Linear(128, 64), nn.Linear(64, 32)) # in_features * out_features + bias - expected_num_params = in_features * out_features + out_features - model = nn.Sequential(nn.Linear(in_features, out_features)) + layer_1 = 128 * 64 + 64 + layer_2 = 64 * 32 + 32 trainer = Trainer(model=model) - self.assertEqual(trainer.get_num_trainable_parameters(), expected_num_params) + self.assertEqual(trainer.get_num_trainable_parameters(), layer_1 + layer_2) + # Freeze the last layer + for param in model[-1].parameters(): + param.requires_grad = False + self.assertEqual(trainer.get_num_trainable_parameters(), layer_1) def test_get_learning_rates(self): model = nn.Sequential(nn.Linear(128, 64)) From 6b29066c5ab344b17983f09323c1454b854a142f Mon Sep 17 00:00:00 2001 From: Christopher Keibel Date: Tue, 26 Mar 2024 08:55:59 +0100 Subject: [PATCH 4/6] check if torch is available before running tests --- tests/trainer/test_trainer.py | 70 ++++++++++++++++++----------------- 1 file changed, 36 insertions(+), 34 deletions(-) diff --git 
index da60f842c138..757ad61eb5e6 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -3771,38 +3771,40 @@ def test_hyperparameter_search_backends(self):
         )
 
 
-class OptimizerAndModelInspectionTest(unittest.TestCase):
-    def test_get_num_trainable_parameters(self):
-        model = nn.Sequential(nn.Linear(128, 64), nn.Linear(64, 32))
-        # in_features * out_features + bias
-        layer_1 = 128 * 64 + 64
-        layer_2 = 64 * 32 + 32
-        trainer = Trainer(model=model)
-        self.assertEqual(trainer.get_num_trainable_parameters(), layer_1 + layer_2)
-        # Freeze the last layer
-        for param in model[-1].parameters():
-            param.requires_grad = False
-        self.assertEqual(trainer.get_num_trainable_parameters(), layer_1)
-
-    def test_get_learning_rates(self):
-        model = nn.Sequential(nn.Linear(128, 64))
-        trainer = Trainer(model=model)
-        with self.assertRaises(ValueError):
-            trainer.get_learning_rates()
-        trainer.create_optimizer()
-        self.assertEqual(trainer.get_learning_rates(), [5e-05, 5e-05])
+if is_torch_available():
 
-    def test_get_optimizer_group(self):
-        model = nn.Sequential(nn.Linear(128, 64))
-        trainer = Trainer(model=model)
-        # ValueError is raised if optimizer is None
-        with self.assertRaises(ValueError):
-            trainer.get_optimizer_group()
-        trainer.create_optimizer()
-        # Get groups
-        num_groups = len(trainer.get_optimizer_group())
-        self.assertEqual(num_groups, 2)
-        # Get group of parameter
-        param = next(model.parameters())
-        group = trainer.get_optimizer_group(param)
-        self.assertIn(param, group["params"])
+
+    class OptimizerAndModelInspectionTest(unittest.TestCase):
+        def test_get_num_trainable_parameters(self):
+            model = nn.Sequential(nn.Linear(128, 64), nn.Linear(64, 32))
+            # in_features * out_features + bias
+            layer_1 = 128 * 64 + 64
+            layer_2 = 64 * 32 + 32
+            trainer = Trainer(model=model)
+            self.assertEqual(trainer.get_num_trainable_parameters(), layer_1 + layer_2)
+            # Freeze the last layer
+            for param in model[-1].parameters():
+                param.requires_grad = False
+            self.assertEqual(trainer.get_num_trainable_parameters(), layer_1)
+
+        def test_get_learning_rates(self):
+            model = nn.Sequential(nn.Linear(128, 64))
+            trainer = Trainer(model=model)
+            with self.assertRaises(ValueError):
+                trainer.get_learning_rates()
+            trainer.create_optimizer()
+            self.assertEqual(trainer.get_learning_rates(), [5e-05, 5e-05])
+        def test_get_optimizer_group(self):
+            model = nn.Sequential(nn.Linear(128, 64))
+            trainer = Trainer(model=model)
+            # ValueError is raised if optimizer is None
+            with self.assertRaises(ValueError):
+                trainer.get_optimizer_group()
+            trainer.create_optimizer()
+            # Get groups
+            num_groups = len(trainer.get_optimizer_group())
+            self.assertEqual(num_groups, 2)
+            # Get group of parameter
+            param = next(model.parameters())
+            group = trainer.get_optimizer_group(param)
+            self.assertIn(param, group["params"])

From ab2cb7252b0ea45f51af9e7134279ede5284f4fd Mon Sep 17 00:00:00 2001
From: Christopher Keibel <55911084+CKeibel@users.noreply.github.com>
Date: Tue, 26 Mar 2024 16:06:48 +0100
Subject: [PATCH 5/6] use decorator to check if torch is available

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
---
 tests/trainer/test_trainer.py | 43 +++++++++++++++++------------------
 1 file changed, 21 insertions(+), 22 deletions(-)

diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 757ad61eb5e6..a34a4d8133ec 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -3771,28 +3771,27 @@ def test_hyperparameter_search_backends(self):
         )
 
 
-if is_torch_available():
-
-    class OptimizerAndModelInspectionTest(unittest.TestCase):
-        def test_get_num_trainable_parameters(self):
-            model = nn.Sequential(nn.Linear(128, 64), nn.Linear(64, 32))
-            # in_features * out_features + bias
-            layer_1 = 128 * 64 + 64
-            layer_2 = 64 * 32 + 32
-            trainer = Trainer(model=model)
-            self.assertEqual(trainer.get_num_trainable_parameters(), layer_1 + layer_2)
-            # Freeze the last layer
-            for param in model[-1].parameters():
-                param.requires_grad = False
-            self.assertEqual(trainer.get_num_trainable_parameters(), layer_1)
-
-        def test_get_learning_rates(self):
-            model = nn.Sequential(nn.Linear(128, 64))
-            trainer = Trainer(model=model)
-            with self.assertRaises(ValueError):
-                trainer.get_learning_rates()
-            trainer.create_optimizer()
-            self.assertEqual(trainer.get_learning_rates(), [5e-05, 5e-05])
+@require_torch
+class OptimizerAndModelInspectionTest(unittest.TestCase):
+    def test_get_num_trainable_parameters(self):
+        model = nn.Sequential(nn.Linear(128, 64), nn.Linear(64, 32))
+        # in_features * out_features + bias
+        layer_1 = 128 * 64 + 64
+        layer_2 = 64 * 32 + 32
+        trainer = Trainer(model=model)
+        self.assertEqual(trainer.get_num_trainable_parameters(), layer_1 + layer_2)
+        # Freeze the last layer
+        for param in model[-1].parameters():
+            param.requires_grad = False
+        self.assertEqual(trainer.get_num_trainable_parameters(), layer_1)
+
+    def test_get_learning_rates(self):
+        model = nn.Sequential(nn.Linear(128, 64))
+        trainer = Trainer(model=model)
+        with self.assertRaises(ValueError):
+            trainer.get_learning_rates()
+        trainer.create_optimizer()
+        self.assertEqual(trainer.get_learning_rates(), [5e-05, 5e-05])
         def test_get_optimizer_group(self):
             model = nn.Sequential(nn.Linear(128, 64))

From 383220e8bbf8845ddf5ad58745e575c0224dde4c Mon Sep 17 00:00:00 2001
From: Christopher Keibel <55911084+CKeibel@users.noreply.github.com>
Date: Tue, 26 Mar 2024 16:34:01 +0100
Subject: [PATCH 6/6] fix test indentation

Co-authored-by: Zach Mueller
---
 tests/trainer/test_trainer.py | 28 ++++++++++++++--------------
 1 file changed, 14 insertions(+), 14 deletions(-)

diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index a34a4d8133ec..f2f5d0feedac 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -3793,17 +3793,17 @@ def test_get_learning_rates(self):
         trainer.create_optimizer()
         self.assertEqual(trainer.get_learning_rates(), [5e-05, 5e-05])
-        def test_get_optimizer_group(self):
-            model = nn.Sequential(nn.Linear(128, 64))
-            trainer = Trainer(model=model)
-            # ValueError is raised if optimizer is None
-            with self.assertRaises(ValueError):
-                trainer.get_optimizer_group()
-            trainer.create_optimizer()
-            # Get groups
-            num_groups = len(trainer.get_optimizer_group())
-            self.assertEqual(num_groups, 2)
-            # Get group of parameter
-            param = next(model.parameters())
-            group = trainer.get_optimizer_group(param)
-            self.assertIn(param, group["params"])
+    def test_get_optimizer_group(self):
+        model = nn.Sequential(nn.Linear(128, 64))
+        trainer = Trainer(model=model)
+        # ValueError is raised if optimizer is None
+        with self.assertRaises(ValueError):
+            trainer.get_optimizer_group()
+        trainer.create_optimizer()
+        # Get groups
+        num_groups = len(trainer.get_optimizer_group())
+        self.assertEqual(num_groups, 2)
+        # Get group of parameter
+        param = next(model.parameters())
+        group = trainer.get_optimizer_group(param)
+        self.assertIn(param, group["params"])
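
Taken together, the six patches expose three small inspection helpers on Trainer:
get_num_trainable_parameters(), get_learning_rates(), and get_optimizer_group().
A minimal usage sketch, mirroring the setup used in the tests above (the toy
nn.Sequential model below is illustrative only and not part of the patches):

    import torch.nn as nn
    from transformers import Trainer

    # Toy two-layer model; any PyTorch module accepted by Trainer works the same way.
    model = nn.Sequential(nn.Linear(128, 64), nn.Linear(64, 32))
    trainer = Trainer(model=model)

    # Count parameters with requires_grad=True (works before the optimizer exists).
    print(trainer.get_num_trainable_parameters())

    # The two helpers below raise ValueError while trainer.optimizer is still None,
    # so create the optimizer first.
    trainer.create_optimizer()
    print(trainer.get_learning_rates())   # one learning rate per optimizer param group

    # With no argument, get_optimizer_group() returns the parameter lists of all groups;
    # with a parameter, it returns the group that contains it.
    param = next(model.parameters())
    group = trainer.get_optimizer_group(param)
    print(param in group["params"])       # True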