From 261e1a335b59050ea6910140cb599d6571402fe0 Mon Sep 17 00:00:00 2001
From: m-misiura
Date: Tue, 3 Dec 2024 11:40:27 +0000
Subject: [PATCH 1/2] :sparkles: added `run_tokenizer` in
 `text_generation_local.py`

- :art: tox formatting
- :construction: added a test to assert the length of the run_tokenizer
  output
- :construction: made a more comprehensive test for the run tokenizer
  method

Signed-off-by: m-misiura
---
 .../text_generation/text_generation_local.py |  7 ++++-
 .../test_text_generation_local.py            | 30 +++++++++++++++----
 2 files changed, 31 insertions(+), 6 deletions(-)

diff --git a/caikit_nlp/modules/text_generation/text_generation_local.py b/caikit_nlp/modules/text_generation/text_generation_local.py
index 290551ba..885028f4 100644
--- a/caikit_nlp/modules/text_generation/text_generation_local.py
+++ b/caikit_nlp/modules/text_generation/text_generation_local.py
@@ -55,6 +55,7 @@
 
 TRAINING_LOSS_LOG_FILENAME = "training_logs.jsonl"
 
+
 # pylint: disable=too-many-lines,too-many-instance-attributes
 @module(
     id="f9181353-4ccf-4572-bd1e-f12bcda26792",
@@ -590,7 +591,11 @@ def run_tokenizer(
             TokenizationResults
                 The token count
         """
-        raise NotImplementedError("Tokenization not implemented for local")
+        error.type_check("", str, text=text)
+        tokenized_output = self.model.tokenizer(text)
+        return TokenizationResults(
+            token_count=len(tokenized_output["input_ids"]),
+        )
 
     ################################## Private Functions ######################################
 
diff --git a/tests/modules/text_generation/test_text_generation_local.py b/tests/modules/text_generation/test_text_generation_local.py
index f76400f1..8338a163 100644
--- a/tests/modules/text_generation/test_text_generation_local.py
+++ b/tests/modules/text_generation/test_text_generation_local.py
@@ -1,5 +1,6 @@
 """Tests for text-generation module
 """
+
 # Standard
 import os
 import platform
@@ -10,7 +11,7 @@
 import torch
 
 # First Party
-from caikit.interfaces.nlp.data_model import GeneratedTextResult
+from caikit.interfaces.nlp.data_model import GeneratedTextResult, TokenizationResults
 import caikit
 
 # Local
@@ -211,7 +212,26 @@ def test_zero_epoch_case(disable_wip):
     assert isinstance(model.model, HFAutoSeq2SeqLM)
 
 
-def test_run_tokenizer_not_implemented():
-    with pytest.raises(NotImplementedError):
-        model = TextGeneration.bootstrap(SEQ2SEQ_LM_MODEL)
-        model.run_tokenizer("This text doesn't matter")
+# ############################## Run Tokenizer ################################
+
+
+def test_run_tokenizer_edge_cases(disable_wip, set_cpu_device):
+    """Test tokenizer on edge cases like empty strings and long input."""
+    model = TextGeneration.bootstrap(CAUSAL_LM_MODEL)
+
+    # Edge case: Empty string
+    empty_result = model.run_tokenizer("")
+    assert isinstance(empty_result, TokenizationResults)
+    assert empty_result.token_count == 0
+
+    # Normal case: short sentence
+    short_text = "This is a test sentence."
+    short_result = model.run_tokenizer(short_text)
+    assert isinstance(short_result, TokenizationResults)
+    assert short_result.token_count > 0
+
+    # Edge case: Long input
+    long_text = "This is a test sentence. " * 1000
+    long_result = model.run_tokenizer(long_text)
+    assert isinstance(long_result, TokenizationResults)
+    assert long_result.token_count > 0
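For context on the API this patch introduces: once a model is bootstrapped, `run_tokenizer` can be called directly. A minimal usage sketch, assuming the import path used by the test module; the checkpoint path is a placeholder, not part of the patch:

    # Sketch only: "path/to/causal-lm-checkpoint" stands in for any local
    # HF model directory that TextGeneration.bootstrap accepts.
    from caikit_nlp.modules.text_generation import TextGeneration

    model = TextGeneration.bootstrap("path/to/causal-lm-checkpoint")
    result = model.run_tokenizer("This is a test sentence.")
    print(result.token_count)  # length of the HF tokenizer's input_ids
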
" * 1000 + long_result = model.run_tokenizer(long_text) + assert isinstance(long_result, TokenizationResults) + assert long_result.token_count > 0 From f629248b689b0c041ea4546e78826f0139c7602d Mon Sep 17 00:00:00 2001 From: m-misiura Date: Wed, 4 Dec 2024 20:16:17 +0000 Subject: [PATCH 2/2] :white_check_mark: based on the PR comments, changed test case to check for an expected number instead of checking if length is non-zero; added `return_attention_mask=True` in the `run_tokenizer` method Signed-off-by: m-misiura --- caikit_nlp/modules/text_generation/text_generation_local.py | 2 +- tests/modules/text_generation/test_text_generation_local.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/caikit_nlp/modules/text_generation/text_generation_local.py b/caikit_nlp/modules/text_generation/text_generation_local.py index 885028f4..ba19a585 100644 --- a/caikit_nlp/modules/text_generation/text_generation_local.py +++ b/caikit_nlp/modules/text_generation/text_generation_local.py @@ -592,7 +592,7 @@ def run_tokenizer( The token count """ error.type_check("", str, text=text) - tokenized_output = self.model.tokenizer(text) + tokenized_output = self.model.tokenizer(text, return_attention_mask=True) return TokenizationResults( token_count=len(tokenized_output["input_ids"]), ) diff --git a/tests/modules/text_generation/test_text_generation_local.py b/tests/modules/text_generation/test_text_generation_local.py index 8338a163..5e91bea0 100644 --- a/tests/modules/text_generation/test_text_generation_local.py +++ b/tests/modules/text_generation/test_text_generation_local.py @@ -228,10 +228,10 @@ def test_run_tokenizer_edge_cases(disable_wip, set_cpu_device): short_text = "This is a test sentence." short_result = model.run_tokenizer(short_text) assert isinstance(short_result, TokenizationResults) - assert short_result.token_count > 0 + assert short_result.token_count == len(model.model.tokenizer.encode(short_text)) # Edge case: Long input long_text = "This is a test sentence. " * 1000 long_result = model.run_tokenizer(long_text) assert isinstance(long_result, TokenizationResults) - assert long_result.token_count > 0 + assert long_result.token_count == len(model.model.tokenizer.encode(long_text))