From 3d7c2f9dea45338b7ebcd459b452e2fad7abfa1f Mon Sep 17 00:00:00 2001
From: Ita Zaporozhets <31893021+itazap@users.noreply.github.com>
Date: Mon, 5 Aug 2024 09:22:48 +0200
Subject: [PATCH] #32184 save total_vocab_size (#32240)

* save total_vocab_size = vocab_size + user added tokens to speed up operation

* updating length when added_tokens_decoder is set

* add test len(tokenizer)
---
 src/transformers/tokenization_utils.py        | 15 ++++++++++++---
 tests/tokenization/test_tokenization_utils.py | 12 ++++++++++++
 2 files changed, 24 insertions(+), 3 deletions(-)

diff --git a/src/transformers/tokenization_utils.py b/src/transformers/tokenization_utils.py
index 1853d2de4560..f04eaae4525d 100644
--- a/src/transformers/tokenization_utils.py
+++ b/src/transformers/tokenization_utils.py
@@ -480,6 +480,7 @@ def added_tokens_decoder(self, value: Dict[int, Union[AddedToken, str]]) -> Dict
 
             self._added_tokens_decoder[index] = AddedToken(token) if isinstance(token, str) else token
             self._added_tokens_encoder[str(token)] = index
+        self._update_total_vocab_size()
 
     def get_added_vocab(self) -> Dict[str, int]:
         """
@@ -494,10 +495,17 @@ def get_added_vocab(self) -> Dict[str, int]:
 
     def __len__(self):
         """
-        Size of the full vocabulary with the added tokens. Counts the `keys` and not the `values` because otherwise if
-        there is a hole in the vocab, we will add tokenizers at a wrong index.
+        Size of the full vocabulary with the added tokens.
         """
-        return len(set(self.get_vocab().keys()))
+        return self.total_vocab_size
+
+    def _update_total_vocab_size(self):
+        """
+        Update the size of the full vocabulary with the added tokens. Counts the `keys` and not the `values` because
+        otherwise if there is a hole in the vocab, we will add tokenizers at a wrong index. This operation is slow and
+        is only updated when adding tokens.
+        """
+        self.total_vocab_size = len(self.get_vocab())
 
     def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int:
         """
@@ -574,6 +582,7 @@ def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_to
                logger.info(f"Adding {token} to the vocabulary")
 
         self._update_trie()
+        self._update_total_vocab_size()
         return added_tokens
 
     def _update_trie(self, unique_no_split_tokens: Optional[str] = []):
diff --git a/tests/tokenization/test_tokenization_utils.py b/tests/tokenization/test_tokenization_utils.py
index 7ff6b29629ea..f97ef6a63022 100644
--- a/tests/tokenization/test_tokenization_utils.py
+++ b/tests/tokenization/test_tokenization_utils.py
@@ -284,3 +284,15 @@ def test_instantiation_from_tokenizers_json_file(self):
         with tempfile.TemporaryDirectory() as tmpdirname:
             bert_tokenizer.save(os.path.join(tmpdirname, "tokenizer.json"))
             PreTrainedTokenizerFast(tokenizer_file=os.path.join(tmpdirname, "tokenizer.json"))
+
+    def test_len_tokenizer(self):
+        for tokenizer_class in [BertTokenizer, BertTokenizerFast]:
+            with self.subTest(f"{tokenizer_class}"):
+                tokenizer = tokenizer_class.from_pretrained("bert-base-uncased")
+                added_tokens_size = len(tokenizer.added_tokens_decoder)
+                self.assertEqual(len(tokenizer), tokenizer.vocab_size)
+
+                tokenizer.add_tokens(["<test_token>"])
+                self.assertEqual(len(tokenizer), tokenizer.vocab_size + 1)
+                self.assertEqual(len(tokenizer.added_tokens_decoder), added_tokens_size + 1)
+                self.assertEqual(len(tokenizer.added_tokens_encoder), added_tokens_size + 1)
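
Note (not part of the commit): a minimal usage sketch of the behavior this patch targets, assuming a transformers install that includes the change, network access to download "bert-base-uncased", and an arbitrary example token string "<my_new_token>". After the patch, len(tokenizer) returns the cached total_vocab_size, which _add_tokens and the added_tokens_decoder setter refresh, instead of rebuilding the full vocab dict on every call.

    # Usage sketch; assumes a patched transformers install and the
    # "bert-base-uncased" checkpoint is downloadable.
    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

    # With no user-added tokens, the cached total equals the base vocab size.
    assert len(tokenizer) == tokenizer.vocab_size

    # Adding a token triggers _update_total_vocab_size inside _add_tokens,
    # so len() reflects the new token without recomputing the whole vocab.
    tokenizer.add_tokens(["<my_new_token>"])  # "<my_new_token>" is an arbitrary example
    assert len(tokenizer) == tokenizer.vocab_size + 1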