From a55bee140f57a4dd0f01856a21eaaf8fe2389023 Mon Sep 17 00:00:00 2001 From: Lysandre Debut Date: Mon, 12 Aug 2024 12:57:23 +0200 Subject: [PATCH 1/2] Automatically add `transformers` tag to the modelcard --- src/transformers/modelcard.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/transformers/modelcard.py b/src/transformers/modelcard.py index 60394f569cd8..fd227e585797 100644 --- a/src/transformers/modelcard.py +++ b/src/transformers/modelcard.py @@ -604,6 +604,9 @@ def from_trainer( elif "generated_from_trainer" not in tags: tags.append("generated_from_trainer") + if "transformers" not in tags: + tags.append("transformers") + _, eval_lines, eval_results = parse_log_history(trainer.state.log_history) hyperparameters = extract_hyperparameters_from_trainer(trainer) From c334c1a5af84974dd1898910629bac912ef97bf4 Mon Sep 17 00:00:00 2001 From: Lysandre Date: Mon, 12 Aug 2024 16:02:37 +0200 Subject: [PATCH 2/2] Specify library_name and test --- src/transformers/modelcard.py | 4 +--- tests/utils/test_model_card.py | 7 ++++++- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/transformers/modelcard.py b/src/transformers/modelcard.py index fd227e585797..acabf94d9546 100644 --- a/src/transformers/modelcard.py +++ b/src/transformers/modelcard.py @@ -454,6 +454,7 @@ def create_metadata(self): metric_mapping = infer_metric_tags_from_eval_results(self.eval_results) metadata = {} + metadata = _insert_value(metadata, "library_name", "transformers") metadata = _insert_values_as_list(metadata, "language", self.language) metadata = _insert_value(metadata, "license", self.license) if self.finetuned_from is not None and isinstance(self.finetuned_from, str) and len(self.finetuned_from) > 0: @@ -604,9 +605,6 @@ def from_trainer( elif "generated_from_trainer" not in tags: tags.append("generated_from_trainer") - if "transformers" not in tags: - tags.append("transformers") - _, eval_lines, eval_results = parse_log_history(trainer.state.log_history) hyperparameters = extract_hyperparameters_from_trainer(trainer) diff --git a/tests/utils/test_model_card.py b/tests/utils/test_model_card.py index 7d0e8795e0aa..6235bb10ed7b 100644 --- a/tests/utils/test_model_card.py +++ b/tests/utils/test_model_card.py @@ -19,7 +19,7 @@ import tempfile import unittest -from transformers.modelcard import ModelCard +from transformers.modelcard import ModelCard, TrainingSummary class ModelCardTester(unittest.TestCase): @@ -82,3 +82,8 @@ def test_model_card_from_and_save_pretrained(self): model_card_second = ModelCard.from_pretrained(tmpdirname) self.assertEqual(model_card_second.to_dict(), model_card_first.to_dict()) + + def test_model_summary_modelcard_base_metadata(self): + metadata = TrainingSummary("Model name").create_metadata() + self.assertTrue("library_name" in metadata) + self.assertTrue(metadata["library_name"] == "transformers")