"""
Top-p sampling generation and filling gaps with ngrams
"""
# pylint:disable=too-few-public-methods, too-many-arguments
from random import choice

from lab_3_generate_by_ngrams.main import (BeamSearchTextGenerator, GreedyTextGenerator,
                                           NGramLanguageModel, TextProcessor)


# NOTE(review): the class header is not part of the visible patch; it is reconstructed
# from the imported TextProcessor and the inherited attributes the methods rely on
# (`_end_of_word_token`, `_storage`) - confirm against the real file.
class WordProcessor(TextProcessor):
    """
    Word-level text processor: tokenizes text into words and restores text from words.
    """

    def _tokenize(self, text: str) -> tuple[str, ...]:  # type: ignore
        """
        Split text into lowercase alphabetic words and mark sentence ends.

        Every whitespace-separated chunk is stripped of non-alphabetic characters.
        Chunks with no letters are dropped. When the raw chunk ends with '?', '!' or '.',
        the end-of-sentence token is appended right after the cleaned word.

        Args:
            text (str): Original text

        Returns:
            tuple[str, ...]: Words and end-of-sentence tokens in text order

        Raises:
            ValueError: In case of inappropriate type input argument or if input argument is empty.
        """
        if not isinstance(text, str) or not text:
            raise ValueError('Type input is inappropriate or input argument is empty.')

        tokens = []
        for word in text.lower().split():
            cleaned_word = ''.join(filter(str.isalpha, word))
            if not cleaned_word:
                continue
            tokens.append(cleaned_word)
            # Only the last character of the raw chunk decides whether a sentence ends here.
            if word.endswith(('?', '!', '.')):
                tokens.append(self._end_of_word_token)
        return tuple(tokens)

    def _put(self, element: str) -> None:
        """
        Register an element in the storage under the next free integer identifier.

        Elements that are already stored keep their identifier.

        Args:
            element (str): Element to put into storage

        Raises:
            ValueError: In case of inappropriate type input argument or if input argument is empty.
        """
        if not isinstance(element, str) or not element:
            raise ValueError('Type input is inappropriate or input argument is empty.')

        if element not in self._storage:
            self._storage[element] = len(self._storage)

    def _postprocess_decoded_text(self, decoded_corpus: tuple[str, ...]) -> str:  # type: ignore
        """
        Turn a decoded token sequence into capitalized, dot-terminated sentences.

        Tokens are joined with spaces, the result is split on the end-of-sentence token,
        every sentence is stripped and capitalized, and sentences are glued with '. '.
        The returned text always ends with a single dot.

        Args:
            decoded_corpus (tuple[str, ...]): Tuple of decoded tokens

        Returns:
            str: Resulting text

        Raises:
            ValueError: In case of inappropriate type input argument or if input argument is empty.
        """
        if not isinstance(decoded_corpus, tuple) or not decoded_corpus:
            raise ValueError('Type input is inappropriate or input argument is empty.')

        sentences = ' '.join(decoded_corpus).split(self._end_of_word_token)
        resulted_text = '. '.join(sentence.strip().capitalize() for sentence in sentences)

        # A trailing space means the corpus ended with the end-of-sentence token, so the
        # final dot is already there. `endswith` is also safe for an empty result.
        if resulted_text.endswith(' '):
            return resulted_text[:-1]
        return f'{resulted_text}.'


class TopPGenerator:
    """
    Text generator that samples each next token from the top-p candidates.
    """

    def __init__(
        self,
        language_model: NGramLanguageModel,
        word_processor: WordProcessor,
        p_value: float
    ) -> None:
        """
        Initialize an instance of TopPGenerator.

        Args:
            language_model (NGramLanguageModel): Language model used to get next-token candidates
            word_processor (WordProcessor): WordProcessor instance to handle text processing
            p_value (float): Collective probability mass threshold
        """
        self._model = language_model
        self._word_processor = word_processor
        self._p_value = p_value

    def run(self, seq_len: int, prompt: str) -> str:  # type: ignore
        """
        Generate a sequence with the top-p sampling strategy.

        On every step the candidates returned by the model are ordered by value and then
        by token, both descending. The shortest prefix whose accumulated value reaches
        the threshold is kept, and one token is drawn uniformly from that prefix.
        Generation stops early when the model returns no candidates.

        Args:
            seq_len (int): Number of tokens to generate
            prompt (str): Beginning of the sequence

        Returns:
            str: Generated sequence

        Raises:
            ValueError: In case of inappropriate type input arguments,
                or if input arguments are empty,
                or if sequence has inappropriate length,
                or if methods used return None.
        """
        if not (isinstance(seq_len, int) and isinstance(prompt, str) and
                seq_len > 0 and prompt):
            raise ValueError('Type input is inappropriate or input argument is empty.')

        encoded_prompt = self._word_processor.encode(prompt)
        if encoded_prompt is None:
            raise ValueError('None is returned')

        for _ in range(seq_len):
            candidates = self._model.generate_next_token(encoded_prompt)
            if candidates is None:
                raise ValueError('None is returned.')
            if not candidates:
                break

            sorted_candidates = sorted(candidates.items(),
                                       key=lambda pair: (pair[1], pair[0]), reverse=True)

            # Count the candidate first and test the threshold afterwards: the selection is
            # the same as before for a positive threshold, but the best candidate is always
            # kept, so a non-positive p_value can no longer produce an empty pool
            # (random.choice raises IndexError on an empty sequence).
            accumulated = 0.0
            num_candidates = 0
            for pair in sorted_candidates:
                num_candidates += 1
                accumulated += pair[1]
                if accumulated >= self._p_value:
                    break

            random_token = choice(sorted_candidates[:num_candidates])[0]
            encoded_prompt = (*encoded_prompt, random_token)

        decoded = self._word_processor.decode(encoded_prompt)
        if decoded is None:
            raise ValueError('None is returned')

        return decoded
"""
Filling word by ngrams starter
"""
# pylint:disable=too-many-locals,unused-import
from lab_4_fill_words_by_ngrams.main import NGramLanguageModel, TopPGenerator, WordProcessor


def main() -> None:
    """
    Build a language model on the Harry Potter text and print a top-p generated sample.

    The text is read from ./assets/Harry_Potter.txt, encoded with a WordProcessor and
    passed to an NGramLanguageModel created with the argument 2 (presumably the n-gram
    size - confirm against NGramLanguageModel). A TopPGenerator with threshold 0.5 then
    generates 51 tokens starting from the prompt "Vernon".
    """
    with open("./assets/Harry_Potter.txt", "r", encoding="utf-8") as text_file:
        text = text_file.read()

    # The end-of-sentence token must be a non-empty string: WordProcessor._put rejects
    # empty elements with ValueError, and text post-processing splits on this token
    # (str.split raises ValueError for an empty separator).
    word_processor = WordProcessor("<eos>")
    encoded_text = word_processor.encode(text)
    lang_model = NGramLanguageModel(encoded_text, 2)
    lang_model.build()

    top_p_generator = TopPGenerator(lang_model, word_processor, 0.5)
    result = top_p_generator.run(51, "Vernon")
    print(result)

    assert result