From 92f152e3095b8fedee516ab5fb0b822489e5e701 Mon Sep 17 00:00:00 2001 From: Koki Ryu Date: Mon, 28 Oct 2024 06:59:25 +0000 Subject: [PATCH 1/5] Bump version --- .github/workflows/pip_install_matrix.yml | 6 ++---- .readthedocs.yaml | 2 +- docs/installation.md | 2 +- docs/tutorial_langcheckchat.md | 2 +- pyproject.toml | 2 +- 5 files changed, 6 insertions(+), 8 deletions(-) diff --git a/.github/workflows/pip_install_matrix.yml b/.github/workflows/pip_install_matrix.yml index 5468eb60..debf8ca4 100644 --- a/.github/workflows/pip_install_matrix.yml +++ b/.github/workflows/pip_install_matrix.yml @@ -17,17 +17,15 @@ jobs: fail-fast: false # Continue running jobs even if another fails matrix: # We specify Python versions as strings so 3.10 doesn't become 3.1 - python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] + python-version: ["3.9", "3.10", "3.11", "3.12"] os: [ubuntu-latest, windows-latest, macos-14] # "en", "de", and "" are equivalent # "all" is tested by pytest.yml language: ["en", "ja", "zh"] exclude: - # GitHub Actions doesn't support Python 3.8 and 3.9 on M1 macOS yet: + # GitHub Actions doesn't support Python 3.9 on M1 macOS yet: # https://github.com/actions/setup-python/issues/696 - - python-version: "3.8" - os: macos-14 - python-version: "3.9" os: macos-14 # TODO: Figure out how to install MeCab on Windows to install diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 9aa05816..5790595d 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -7,7 +7,7 @@ version: 2 build: os: ubuntu-22.04 tools: - python: "3.8" + python: "3.9" sphinx: configuration: docs/conf.py diff --git a/docs/installation.md b/docs/installation.md index 22e880a8..b72d5400 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -14,7 +14,7 @@ pip install --upgrade pip pip install langcheck[all] ``` -LangCheck works with Python 3.8 or higher. +LangCheck works with Python 3.9 or higher. :::{note} Model files are lazily downloaded the first time you run a metric function. For example, the first time you run the ``langcheck.metrics.sentiment()`` function, LangCheck will automatically download the Twitter-roBERTa-base model. diff --git a/docs/tutorial_langcheckchat.md b/docs/tutorial_langcheckchat.md index ccab624a..fb257cc8 100644 --- a/docs/tutorial_langcheckchat.md +++ b/docs/tutorial_langcheckchat.md @@ -69,7 +69,7 @@ Here’s the response from the LLM: > > pip install langcheck > -> Please note that LangCheck requires Python 3.8 or higher to work properly. +> Please note that LangCheck requires Python 3.9 or higher to work properly. We can also see the sources that were retrieved from the index. By default, the top 2 most relevant source nodes are returned, which is what we see in `response.source_nodes`. diff --git a/pyproject.toml b/pyproject.toml index cf7a48bc..0e8b0e58 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,7 +33,7 @@ dependencies = [ 'tabulate >= 0.9.0', # For model manager print table 'omegaconf >= 2.3.0' # For model manager print table ] -requires-python = ">=3.8" +requires-python = ">=3.9" [project.optional-dependencies] de = [] # No extra dependencies needed for German From c43ededae2b3f7f06f3d0272058da87ff62c17bb Mon Sep 17 00:00:00 2001 From: Koki Ryu Date: Mon, 28 Oct 2024 07:38:25 +0000 Subject: [PATCH 2/5] Fix typing --- src/langcheck/metrics/de/_tokenizers.py | 4 +- .../de/reference_based_text_quality.py | 30 ++-- .../metrics/de/reference_free_text_quality.py | 40 +++-- .../metrics/de/source_based_text_quality.py | 14 +- .../metrics/en/pairwise_text_quality.py | 24 +-- .../en/reference_based_text_quality.py | 38 +++-- .../metrics/en/reference_free_text_quality.py | 44 +++-- .../metrics/en/source_based_text_quality.py | 18 +-- .../metrics/eval_clients/_anthropic.py | 3 +- src/langcheck/metrics/eval_clients/_base.py | 11 +- src/langcheck/metrics/eval_clients/_gemini.py | 3 +- src/langcheck/metrics/eval_clients/_llama.py | 2 +- src/langcheck/metrics/eval_clients/_openai.py | 5 +- .../metrics/eval_clients/_prometheus.py | 2 +- .../metrics/ja/pairwise_text_quality.py | 16 +- .../ja/reference_based_text_quality.py | 48 +++--- .../metrics/ja/reference_free_text_quality.py | 38 +++-- .../metrics/ja/source_based_text_quality.py | 20 +-- src/langcheck/metrics/metric_inputs.py | 5 +- src/langcheck/metrics/metric_value.py | 12 +- .../metrics/model_manager/_model_loader.py | 46 +++--- .../model_manager/_model_management.py | 15 +- src/langcheck/metrics/prompts/_utils.py | 2 +- .../metrics/reference_based_text_quality.py | 8 +- src/langcheck/metrics/scorer/_base.py | 8 +- .../metrics/scorer/detoxify_models.py | 22 ++- src/langcheck/metrics/scorer/hf_models.py | 16 +- src/langcheck/metrics/text_structure.py | 46 +++--- .../zh/reference_based_text_quality.py | 40 +++-- .../metrics/zh/reference_free_text_quality.py | 30 ++-- .../metrics/zh/source_based_text_quality.py | 10 +- src/langcheck/plot/_scatter.py | 7 +- src/langcheck/plot/_utils.py | 32 ++-- src/langcheck/utils/progress_bar.py | 15 +- tests/augment/en/test_change_case.py | 128 +++++++++++---- tests/augment/en/test_gender.py | 72 ++++++--- tests/augment/en/test_keyboard_typo.py | 13 +- tests/augment/en/test_ocr_typo.py | 13 +- tests/augment/en/test_remove_punctuation.py | 31 ++-- tests/augment/en/test_to_full_width.py | 5 +- tests/augment/ja/test_conv_kana.py | 5 +- tests/metrics/de/test_tokenizers.py | 58 +++++-- tests/metrics/de/test_translation.py | 58 ++++--- .../ja/test_reference_based_text_quality.py | 12 +- tests/metrics/ja/test_tokenizers.py | 35 ++-- tests/metrics/test_metric_value.py | 4 +- .../zh/test_reference_based_text_quality.py | 153 ++++++++++++------ tests/metrics/zh/test_tokenizers.py | 36 +++-- tests/utils.py | 27 ++-- 49 files changed, 769 insertions(+), 555 deletions(-) diff --git a/src/langcheck/metrics/de/_tokenizers.py b/src/langcheck/metrics/de/_tokenizers.py index cee6d96a..1dbbb675 100644 --- a/src/langcheck/metrics/de/_tokenizers.py +++ b/src/langcheck/metrics/de/_tokenizers.py @@ -1,5 +1,3 @@ -from typing import List - from nltk.stem.cistem import Cistem from nltk.tokenize import word_tokenize from rouge_score.tokenizers import Tokenizer as BaseTokenizer @@ -16,7 +14,7 @@ def __init__(self, stemmer=False): if stemmer: self.stemmer = Cistem() - def tokenize(self, text: str) -> List[str]: + def tokenize(self, text: str) -> list[str]: if self.stemmer: # use only the stem part of the word text, _ = self.stemmer.segment(text) diff --git a/src/langcheck/metrics/de/reference_based_text_quality.py b/src/langcheck/metrics/de/reference_based_text_quality.py index a27225db..51f0c104 100644 --- a/src/langcheck/metrics/de/reference_based_text_quality.py +++ b/src/langcheck/metrics/de/reference_based_text_quality.py @@ -1,7 +1,5 @@ from __future__ import annotations -from typing import List, Optional - from rouge_score import rouge_scorer from langcheck.metrics.de._tokenizers import DeTokenizer @@ -19,9 +17,9 @@ def semantic_similarity( - generated_outputs: List[str] | str, - reference_outputs: List[str] | str, - prompts: Optional[List[str] | str] = None, + generated_outputs: list[str] | str, + reference_outputs: list[str] | str, + prompts: list[str] | str | None = None, eval_model: str | EvalClient = "local", ) -> MetricValue[float]: """Calculates the semantic similarities between the generated outputs and @@ -85,9 +83,9 @@ def semantic_similarity( def rouge1( - generated_outputs: List[str] | str, - reference_outputs: List[str] | str, - prompts: Optional[List[str] | str] = None, + generated_outputs: list[str] | str, + reference_outputs: list[str] | str, + prompts: list[str] | str | None = None, ) -> MetricValue[float]: """Calculates the F1 metrics of the ROUGE-1 scores between the generated outputs and the reference outputs. It evaluates the overlap of unigrams @@ -127,9 +125,9 @@ def rouge1( def rouge2( - generated_outputs: List[str] | str, - reference_outputs: List[str] | str, - prompts: Optional[List[str] | str] = None, + generated_outputs: list[str] | str, + reference_outputs: list[str] | str, + prompts: list[str] | str | None = None, ) -> MetricValue[float]: """Calculates the F1 metrics of the ROUGE-2 scores between the generated outputs and the reference outputs. It evaluates the overlap of bigrams @@ -169,9 +167,9 @@ def rouge2( def rougeL( - generated_outputs: List[str] | str, - reference_outputs: List[str] | str, - prompts: Optional[List[str] | str] = None, + generated_outputs: list[str] | str, + reference_outputs: list[str] | str, + prompts: list[str] | str | None = None, ) -> MetricValue[float]: """Calculates the F1 metrics of the ROUGE-L scores between the generated outputs and the reference outputs. It evaluates the longest common @@ -221,8 +219,8 @@ def rougeL( def _rouge( - generated_outputs: List[str], reference_outputs: List[str], rouge_type: str -) -> List[float]: + generated_outputs: list[str], reference_outputs: list[str], rouge_type: str +) -> list[float]: """Helper function for computing the rouge1, rouge2, and rougeL metrics. This uses Google Research's implementation of ROUGE: https://github.com/google-research/google-research/tree/master/rouge diff --git a/src/langcheck/metrics/de/reference_free_text_quality.py b/src/langcheck/metrics/de/reference_free_text_quality.py index 070e7db8..30a60afb 100644 --- a/src/langcheck/metrics/de/reference_free_text_quality.py +++ b/src/langcheck/metrics/de/reference_free_text_quality.py @@ -1,7 +1,5 @@ from __future__ import annotations -from typing import List, Optional - from langcheck.metrics.de._translation import Translate from langcheck.metrics.de.reference_based_text_quality import ( semantic_similarity, @@ -30,11 +28,11 @@ def sentiment( - generated_outputs: List[str] | str, - prompts: Optional[List[str] | str] = None, + generated_outputs: list[str] | str, + prompts: list[str] | str | None = None, eval_model: str | EvalClient = "local", local_overflow_strategy: str = "truncate", -) -> MetricValue[Optional[float]]: +) -> MetricValue[float | None]: """Calculates the sentiment scores of generated outputs. This metric takes on float values between [0, 1], where 0 is negative sentiment and 1 is positive sentiment. (NOTE: when using an EvalClient, the sentiment scores @@ -112,8 +110,8 @@ def sentiment( def _sentiment_local( - generated_outputs: List[str], overflow_strategy: str -) -> List[Optional[float]]: + generated_outputs: list[str], overflow_strategy: str +) -> list[float | None]: """Calculates the sentiment scores of generated outputs using the twitter-xlm-roberta-base-sentiment-finetunned model. This metric takes on float values between [0, 1], where 0 is negative sentiment and 1 is positive @@ -142,10 +140,10 @@ def _sentiment_local( def fluency( - generated_outputs: List[str] | str, - prompts: Optional[List[str] | str] = None, + generated_outputs: list[str] | str, + prompts: list[str] | str | None = None, eval_model: str | EvalClient = "local", -) -> MetricValue[Optional[float]]: +) -> MetricValue[float | None]: """Calculates the fluency scores of generated outputs. This metric takes on float values between [0, 1], where 0 is low fluency and 1 is high fluency. (NOTE: when using an EvalClient, the fluency scores are either 0.0 @@ -220,11 +218,11 @@ def fluency( def toxicity( - generated_outputs: List[str] | str, - prompts: Optional[List[str] | str] = None, + generated_outputs: list[str] | str, + prompts: list[str] | str | None = None, eval_model: str | EvalClient = "local", local_overflow_strategy: str = "truncate", -) -> MetricValue[Optional[float]]: +) -> MetricValue[float | None]: """Calculates the toxicity scores of generated outputs. This metric takes on float values between [0, 1], where 0 is low toxicity and 1 is high toxicity. (NOTE: when using an EvalClient, the toxicity scores are in steps of @@ -301,8 +299,8 @@ def toxicity( def _toxicity_local( - generated_outputs: List[str], overflow_strategy: str -) -> List[Optional[float]]: + generated_outputs: list[str], overflow_strategy: str +) -> list[float | None]: """Calculates the toxicity scores of generated outputs using the Detoxify model. This metric takes on float values between [0, 1], where 0 is low toxicity and 1 is high toxicity. @@ -324,8 +322,8 @@ def _toxicity_local( def flesch_kincaid_grade( - generated_outputs: List[str] | str, - prompts: Optional[List[str] | str] = None, + generated_outputs: list[str] | str, + prompts: list[str] | str | None = None, ) -> MetricValue[float]: """Calculates the readability of generated outputs using the Flesch-Kincaid. It is the same as in English (but higher): @@ -338,8 +336,8 @@ def flesch_kincaid_grade( def flesch_reading_ease( - generated_outputs: List[str] | str, - prompts: Optional[List[str] | str] = None, + generated_outputs: list[str] | str, + prompts: list[str] | str | None = None, ) -> MetricValue[float]: """Calculates the readability of generated outputs using the Flesch Reading Ease Score. This metric takes on float values between (-∞, 121.22], but @@ -387,8 +385,8 @@ def flesch_reading_ease( def ai_disclaimer_similarity( - generated_outputs: List[str] | str, - prompts: Optional[List[str] | str] = None, + generated_outputs: list[str] | str, + prompts: list[str] | str | None = None, ai_disclaimer_phrase: str = ( "Ich habe keine persönlichen Meinungen, Emotionen oder Bewusstsein." ), diff --git a/src/langcheck/metrics/de/source_based_text_quality.py b/src/langcheck/metrics/de/source_based_text_quality.py index 91ef5407..56be205b 100644 --- a/src/langcheck/metrics/de/source_based_text_quality.py +++ b/src/langcheck/metrics/de/source_based_text_quality.py @@ -1,7 +1,5 @@ from __future__ import annotations -from typing import List, Optional - from langcheck.metrics.de._translation import Translate from langcheck.metrics.en.source_based_text_quality import ( factual_consistency as en_factual_consistency, @@ -20,11 +18,11 @@ def factual_consistency( - generated_outputs: List[str] | str, - sources: List[str] | str, - prompts: Optional[List[str] | str] = None, + generated_outputs: list[str] | str, + sources: list[str] | str, + prompts: list[str] | str | None = None, eval_model: str | EvalClient = "local", -) -> MetricValue[Optional[float]]: +) -> MetricValue[float | None]: """Calculates the factual consistency between the generated outputs and the sources. This metric takes on float values between [0, 1], where 0 means that the output is not at all consistent with the source text, and 1 @@ -123,8 +121,8 @@ def factual_consistency( def context_relevance( - sources: List[str] | str, prompts: List[str] | str, eval_model: EvalClient -) -> MetricValue[Optional[float]]: + sources: list[str] | str, prompts: list[str] | str, eval_model: EvalClient +) -> MetricValue[float | None]: """Calculates the relevance of the sources to the prompts. This metric takes on float values between [0, 1], where 0 means that the source text is not at all relevant to the prompt, and 1 means that the source text is fully diff --git a/src/langcheck/metrics/en/pairwise_text_quality.py b/src/langcheck/metrics/en/pairwise_text_quality.py index 23263a38..07482507 100644 --- a/src/langcheck/metrics/en/pairwise_text_quality.py +++ b/src/langcheck/metrics/en/pairwise_text_quality.py @@ -2,7 +2,7 @@ import math import random -from typing import List, Optional, cast +from typing import cast from langcheck.metrics._pairwise_text_quality_utils import ( compute_pairwise_comparison_metric_values_with_consistency, @@ -16,13 +16,13 @@ def simulated_annotators( - prompt_params: List[dict[str, str | None]], + prompt_params: list[dict[str, str | None]], eval_model: EvalClient, preference_data_path: str = "en/confidence_estimating/preference_data_examples.jsonl", k: int = 5, n: int = 5, seed: int | None = None, -) -> List[float | None]: +) -> list[float | None]: """Compute a confidence score for the pairwise comparison metric based on the method Simulated Annotators proposed in the paper "Trust or Escalate: LLM Judges with Provable Guarantees for Human Agreement" @@ -73,7 +73,7 @@ def simulated_annotators( prompts.append(prompt_template.render(prompt_param)) # Get the response and top five logprobs of the first token - responses: List[Optional[TextResponseWithLogProbs]] = ( + responses: list[TextResponseWithLogProbs | None] = ( eval_model.get_text_responses_with_log_likelihood( prompts, top_logprobs=5 ) @@ -83,7 +83,7 @@ def simulated_annotators( if response: response = cast(TextResponseWithLogProbs, response) top_five_first_token_logprobs = cast( - List[TokenLogProb], response["response_logprobs"][0] + list[TokenLogProb], response["response_logprobs"][0] ) # Extract logprobs for tokens 'A' and 'B' logprobs_dict = { @@ -110,12 +110,12 @@ def simulated_annotators( def pairwise_comparison( - generated_outputs_a: List[str] | str, - generated_outputs_b: List[str] | str, - prompts: List[str] | str, - sources_a: Optional[List[str] | str] = None, - sources_b: Optional[List[str] | str] = None, - reference_outputs: Optional[List[str] | str] = None, + generated_outputs_a: list[str] | str, + generated_outputs_b: list[str] | str, + prompts: list[str] | str, + sources_a: list[str] | str | None = None, + sources_b: list[str] | str | None = None, + reference_outputs: list[str] | str | None = None, enforce_consistency: bool = True, calculated_confidence: bool = False, preference_data_path: str = "en/confidence_estimating/preference_data_examples.jsonl", @@ -123,7 +123,7 @@ def pairwise_comparison( n: int = 5, seed: int | None = None, eval_model: EvalClient | None = None, -) -> MetricValue[Optional[float]]: +) -> MetricValue[float | None]: """Calculates the pairwise comparison metric. This metric takes on float values of either 0.0 (Response A is better), 0.5 (Tie), or 1.0 (Response B is better). The score may also be `None` if it could not be computed. diff --git a/src/langcheck/metrics/en/reference_based_text_quality.py b/src/langcheck/metrics/en/reference_based_text_quality.py index 578f62cc..cc0e9680 100644 --- a/src/langcheck/metrics/en/reference_based_text_quality.py +++ b/src/langcheck/metrics/en/reference_based_text_quality.py @@ -1,7 +1,5 @@ from __future__ import annotations -from typing import List, Optional - from rouge_score import rouge_scorer from langcheck.metrics.eval_clients import EvalClient @@ -19,11 +17,11 @@ def answer_correctness( - generated_outputs: List[str] | str, - reference_outputs: List[str] | str, - prompts: List[str] | str, + generated_outputs: list[str] | str, + reference_outputs: list[str] | str, + prompts: list[str] | str, eval_model: EvalClient, -) -> MetricValue[Optional[float]]: +) -> MetricValue[float | None]: """Calculates the correctness of the generated outputs. This metric takes on float values of either 0.0 (Incorrect), 0.5 (Partially Correct), or 1.0 (Correct). The score may also be `None` if it could not be computed. @@ -61,9 +59,9 @@ def answer_correctness( def semantic_similarity( - generated_outputs: List[str] | str, - reference_outputs: List[str] | str, - prompts: Optional[List[str] | str] = None, + generated_outputs: list[str] | str, + reference_outputs: list[str] | str, + prompts: list[str] | str | None = None, eval_model: str | EvalClient = "local", ) -> MetricValue[float]: """Calculates the semantic similarities between the generated outputs and @@ -126,9 +124,9 @@ def semantic_similarity( def rouge1( - generated_outputs: List[str] | str, - reference_outputs: List[str] | str, - prompts: Optional[List[str] | str] = None, + generated_outputs: list[str] | str, + reference_outputs: list[str] | str, + prompts: list[str] | str | None = None, ) -> MetricValue[float]: """Calculates the F1 metrics of the ROUGE-1 scores between the generated outputs and the reference outputs. It evaluates the overlap of unigrams @@ -168,9 +166,9 @@ def rouge1( def rouge2( - generated_outputs: List[str] | str, - reference_outputs: List[str] | str, - prompts: Optional[List[str] | str] = None, + generated_outputs: list[str] | str, + reference_outputs: list[str] | str, + prompts: list[str] | str | None = None, ) -> MetricValue[float]: """Calculates the F1 metrics of the ROUGE-2 scores between the generated outputs and the reference outputs. It evaluates the overlap of bigrams @@ -210,9 +208,9 @@ def rouge2( def rougeL( - generated_outputs: List[str] | str, - reference_outputs: List[str] | str, - prompts: Optional[List[str] | str] = None, + generated_outputs: list[str] | str, + reference_outputs: list[str] | str, + prompts: list[str] | str | None = None, ) -> MetricValue[float]: """Calculates the F1 metrics of the ROUGE-L scores between the generated outputs and the reference outputs. It evaluates the longest common @@ -262,8 +260,8 @@ def rougeL( def _rouge( - generated_outputs: List[str], reference_outputs: List[str], rouge_type: str -) -> List[float]: + generated_outputs: list[str], reference_outputs: list[str], rouge_type: str +) -> list[float]: """Helper function for computing the rouge1, rouge2, and rougeL metrics. This uses Google Research's implementation of ROUGE: https://github.com/google-research/google-research/tree/master/rouge diff --git a/src/langcheck/metrics/en/reference_free_text_quality.py b/src/langcheck/metrics/en/reference_free_text_quality.py index 43636091..efbcef86 100644 --- a/src/langcheck/metrics/en/reference_free_text_quality.py +++ b/src/langcheck/metrics/en/reference_free_text_quality.py @@ -1,7 +1,5 @@ from __future__ import annotations -from typing import List, Optional - from langcheck.metrics.en.reference_based_text_quality import ( semantic_similarity, ) @@ -22,11 +20,11 @@ def sentiment( - generated_outputs: List[str] | str, - prompts: Optional[List[str] | str] = None, + generated_outputs: list[str] | str, + prompts: list[str] | str | None = None, eval_model: str | EvalClient = "local", local_overflow_strategy: str = "truncate", -) -> MetricValue[Optional[float]]: +) -> MetricValue[float | None]: """Calculates the sentiment scores of generated outputs. This metric takes on float values between [0, 1], where 0 is negative sentiment and 1 is positive sentiment. (NOTE: when using an EvalClient, the sentiment scores @@ -101,8 +99,8 @@ def sentiment( def _sentiment_local( - generated_outputs: List[str], overflow_strategy: str -) -> List[Optional[float]]: + generated_outputs: list[str], overflow_strategy: str +) -> list[float | None]: """Calculates the sentiment scores of generated outputs using the Twitter-roBERTa-base model. This metric takes on float values between [0, 1], where 0 is negative sentiment and 1 is positive sentiment. @@ -131,11 +129,11 @@ def _sentiment_local( def fluency( - generated_outputs: List[str] | str, - prompts: Optional[List[str] | str] = None, + generated_outputs: list[str] | str, + prompts: list[str] | str | None = None, eval_model: str | EvalClient = "local", local_overflow_strategy: str = "truncate", -) -> MetricValue[Optional[float]]: +) -> MetricValue[float | None]: """Calculates the fluency scores of generated outputs. This metric takes on float values between [0, 1], where 0 is low fluency and 1 is high fluency. (NOTE: when using an EvalClient, the fluency scores are either 0.0 @@ -210,8 +208,8 @@ def fluency( def _fluency_local( - generated_outputs: List[str], overflow_strategy: str -) -> List[Optional[float]]: + generated_outputs: list[str], overflow_strategy: str +) -> list[float | None]: """Calculates the fluency scores of generated outputs using the Parrot fluency model. This metric takes on float values between [0, 1], where 0 is low fluency and 1 is high fluency. @@ -238,12 +236,12 @@ def _fluency_local( def toxicity( - generated_outputs: List[str] | str, - prompts: Optional[List[str] | str] = None, + generated_outputs: list[str] | str, + prompts: list[str] | str | None = None, eval_model: str | EvalClient = "local", local_overflow_strategy: str = "truncate", eval_prompt_version: str = "v2", -) -> MetricValue[Optional[float]]: +) -> MetricValue[float | None]: """Calculates the toxicity scores of generated outputs. This metric takes on float values between [0, 1], where 0 is low toxicity and 1 is high toxicity. (NOTE: when using an EvalClient, the toxicity scores are either 0.0 @@ -335,8 +333,8 @@ def toxicity( def _toxicity_local( - generated_outputs: List[str], overflow_strategy: str -) -> List[Optional[float]]: + generated_outputs: list[str], overflow_strategy: str +) -> list[float | None]: """Calculates the toxicity scores of generated outputs using the Detoxify model. This metric takes on float values between [0, 1], where 0 is low toxicity and 1 is high toxicity. @@ -356,8 +354,8 @@ def _toxicity_local( def flesch_reading_ease( - generated_outputs: List[str] | str, - prompts: Optional[List[str] | str] = None, + generated_outputs: list[str] | str, + prompts: list[str] | str | None = None, ) -> MetricValue[float]: """Calculates the readability of generated outputs using the Flesch Reading Ease Score. This metric takes on float values between (-∞, 121.22], but @@ -403,8 +401,8 @@ def flesch_reading_ease( def flesch_kincaid_grade( - generated_outputs: List[str] | str, - prompts: Optional[List[str] | str] = None, + generated_outputs: list[str] | str, + prompts: list[str] | str | None = None, ) -> MetricValue[float]: """Calculates the readability of generated outputs using the Flesch-Kincaid Grade Level metric. This metric takes on float values between [-3.40, ∞), @@ -451,8 +449,8 @@ def flesch_kincaid_grade( def ai_disclaimer_similarity( - generated_outputs: List[str] | str, - prompts: Optional[List[str] | str] = None, + generated_outputs: list[str] | str, + prompts: list[str] | str | None = None, ai_disclaimer_phrase: str = ( "I don't have personal opinions, emotions, or consciousness." ), diff --git a/src/langcheck/metrics/en/source_based_text_quality.py b/src/langcheck/metrics/en/source_based_text_quality.py index 7f87f798..efe171b3 100644 --- a/src/langcheck/metrics/en/source_based_text_quality.py +++ b/src/langcheck/metrics/en/source_based_text_quality.py @@ -1,7 +1,5 @@ from __future__ import annotations -from typing import List, Optional - import nltk import torch import torch.nn as nn @@ -26,11 +24,11 @@ def factual_consistency( - generated_outputs: List[str] | str, - sources: List[str] | str, - prompts: Optional[List[str] | str] = None, + generated_outputs: list[str] | str, + sources: list[str] | str, + prompts: list[str] | str | None = None, eval_model: str | EvalClient = "local", -) -> MetricValue[Optional[float]]: +) -> MetricValue[float | None]: """Calculates the factual consistency between the generated outputs and the sources. This metric takes on float values between [0, 1], where 0 means that the output is not at all consistent with the source text, and 1 @@ -104,8 +102,8 @@ def factual_consistency( def _factual_consistency_local( - generated_outputs: List[str], sources: List[str] -) -> List[float]: + generated_outputs: list[str], sources: list[str] +) -> list[float]: """Calculates the factual consistency between each generated sentence and its corresponding source text. The factual consistency score for one generated output is computed as the average of the per-sentence @@ -226,8 +224,8 @@ def _factual_consistency_local( def context_relevance( - sources: List[str] | str, prompts: List[str] | str, eval_model: EvalClient -) -> MetricValue[Optional[float]]: + sources: list[str] | str, prompts: list[str] | str, eval_model: EvalClient +) -> MetricValue[float | None]: """Calculates the relevance of the sources to the prompts. This metric takes on float values between [0, 1], where 0 means that the source text is not at all relevant to the prompt, and 1 means that the source text is fully diff --git a/src/langcheck/metrics/eval_clients/_anthropic.py b/src/langcheck/metrics/eval_clients/_anthropic.py index dd28fd9a..e8989740 100644 --- a/src/langcheck/metrics/eval_clients/_anthropic.py +++ b/src/langcheck/metrics/eval_clients/_anthropic.py @@ -1,7 +1,8 @@ from __future__ import annotations import asyncio -from typing import Any, Iterable +from collections.abc import Iterable +from typing import Any from anthropic import Anthropic, AsyncAnthropic diff --git a/src/langcheck/metrics/eval_clients/_base.py b/src/langcheck/metrics/eval_clients/_base.py index e2075dae..9e1ac635 100644 --- a/src/langcheck/metrics/eval_clients/_base.py +++ b/src/langcheck/metrics/eval_clients/_base.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import Dict, Iterable, List, Optional, Union +from collections.abc import Iterable +from typing import Union from jinja2 import Template @@ -10,9 +11,9 @@ from ..prompts._utils import get_template from ..scorer._base import BaseSimilarityScorer -TokenLogProb = Dict[str, Union[str, float]] -TopKLogProbs = List[List[TokenLogProb]] -TextResponseWithLogProbs = Dict[str, Union[str, List[TopKLogProbs]]] +TokenLogProb = dict[str, Union[str, float]] +TopKLogProbs = list[list[TokenLogProb]] +TextResponseWithLogProbs = dict[str, Union[str, list[TopKLogProbs]]] class EvalClient: @@ -71,7 +72,7 @@ def get_text_responses_with_log_likelihood( top_logprobs: int | None = None, *, tqdm_description: str | None = None, - ) -> List[Optional[TextResponseWithLogProbs]]: + ) -> list[TextResponseWithLogProbs | None]: """The function that gets responses with log likelihood to the given prompt texts. Each concrete subclass needs to define the concrete implementation of this function to enable text scoring. diff --git a/src/langcheck/metrics/eval_clients/_gemini.py b/src/langcheck/metrics/eval_clients/_gemini.py index e066c086..f5196b07 100644 --- a/src/langcheck/metrics/eval_clients/_gemini.py +++ b/src/langcheck/metrics/eval_clients/_gemini.py @@ -1,7 +1,8 @@ from __future__ import annotations import os -from typing import Any, Iterable +from collections.abc import Iterable +from typing import Any import google.ai.generativelanguage as glm import google.generativeai as genai diff --git a/src/langcheck/metrics/eval_clients/_llama.py b/src/langcheck/metrics/eval_clients/_llama.py index c26afa2e..a4e2b1d9 100644 --- a/src/langcheck/metrics/eval_clients/_llama.py +++ b/src/langcheck/metrics/eval_clients/_llama.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Iterable +from collections.abc import Iterable from transformers import AutoTokenizer from vllm import LLM, SamplingParams diff --git a/src/langcheck/metrics/eval_clients/_openai.py b/src/langcheck/metrics/eval_clients/_openai.py index 2a044ddd..50c75c03 100644 --- a/src/langcheck/metrics/eval_clients/_openai.py +++ b/src/langcheck/metrics/eval_clients/_openai.py @@ -3,7 +3,8 @@ import asyncio import json import os -from typing import Any, Iterable, List, Optional +from collections.abc import Iterable +from typing import Any import torch from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI @@ -143,7 +144,7 @@ def get_text_responses_with_log_likelihood( top_logprobs: int | None = None, *, tqdm_description: str | None = None, - ) -> List[Optional[TextResponseWithLogProbs]]: + ) -> list[TextResponseWithLogProbs | None]: """The function that gets responses with log likelihood to the given prompt texts. Each concrete subclass needs to define the concrete implementation of this function to enable text scoring. diff --git a/src/langcheck/metrics/eval_clients/_prometheus.py b/src/langcheck/metrics/eval_clients/_prometheus.py index 62691975..c51ee9bb 100644 --- a/src/langcheck/metrics/eval_clients/_prometheus.py +++ b/src/langcheck/metrics/eval_clients/_prometheus.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Iterable +from collections.abc import Iterable from jinja2 import Template from transformers import AutoTokenizer diff --git a/src/langcheck/metrics/ja/pairwise_text_quality.py b/src/langcheck/metrics/ja/pairwise_text_quality.py index 22d6d1e2..cf6ab7e8 100644 --- a/src/langcheck/metrics/ja/pairwise_text_quality.py +++ b/src/langcheck/metrics/ja/pairwise_text_quality.py @@ -1,7 +1,5 @@ from __future__ import annotations -from typing import List, Optional - from langcheck.metrics._pairwise_text_quality_utils import ( compute_pairwise_comparison_metric_values_with_consistency, ) @@ -11,15 +9,15 @@ def pairwise_comparison( - generated_outputs_a: List[str] | str, - generated_outputs_b: List[str] | str, - prompts: List[str] | str, - sources_a: Optional[List[str] | str] = None, - sources_b: Optional[List[str] | str] = None, - reference_outputs: Optional[List[str] | str] = None, + generated_outputs_a: list[str] | str, + generated_outputs_b: list[str] | str, + prompts: list[str] | str, + sources_a: list[str] | str | None = None, + sources_b: list[str] | str | None = None, + reference_outputs: list[str] | str | None = None, enforce_consistency: bool = True, eval_model: EvalClient | None = None, -) -> MetricValue[Optional[float]]: +) -> MetricValue[float | None]: """Calculates the pairwise comparison metric. This metric takes on float values of either 0.0 (Response A is better), 0.5 (Tie), or 1.0 (Response B is better). The score may also be `None` if it could not be computed. diff --git a/src/langcheck/metrics/ja/reference_based_text_quality.py b/src/langcheck/metrics/ja/reference_based_text_quality.py index 4b8f7419..a8027875 100644 --- a/src/langcheck/metrics/ja/reference_based_text_quality.py +++ b/src/langcheck/metrics/ja/reference_based_text_quality.py @@ -1,7 +1,5 @@ from __future__ import annotations -from typing import List, Optional - from rouge_score import rouge_scorer from rouge_score.tokenizers import Tokenizer @@ -21,11 +19,11 @@ def answer_correctness( - generated_outputs: List[str] | str, - reference_outputs: List[str] | str, - prompts: List[str] | str, + generated_outputs: list[str] | str, + reference_outputs: list[str] | str, + prompts: list[str] | str, eval_model: EvalClient, -) -> MetricValue[Optional[float]]: +) -> MetricValue[float | None]: """Calculates the correctness of the generated outputs. This metric takes on float values of either 0.0 (Incorrect), 0.5 (Partially Correct), or 1.0 (Correct). The score may also be `None` if it could not be computed. @@ -63,9 +61,9 @@ def answer_correctness( def semantic_similarity( - generated_outputs: List[str] | str, - reference_outputs: List[str] | str, - prompts: Optional[List[str] | str] = None, + generated_outputs: list[str] | str, + reference_outputs: list[str] | str, + prompts: list[str] | str | None = None, eval_model: str | EvalClient = "local", ) -> MetricValue[float]: """Calculates the semantic similarities between the generated outputs and @@ -130,11 +128,11 @@ def semantic_similarity( def rouge1( - generated_outputs: List[str] | str, - reference_outputs: List[str] | str, - prompts: Optional[List[str] | str] = None, + generated_outputs: list[str] | str, + reference_outputs: list[str] | str, + prompts: list[str] | str | None = None, *, - tokenizer: Optional[Tokenizer] = None, + tokenizer: Tokenizer | None = None, ) -> MetricValue[float]: """Calculates the F1 metrics of the ROUGE-1 scores between the generated (single tokens) between the generated outputs and the reference outputs. @@ -176,11 +174,11 @@ def rouge1( def rouge2( - generated_outputs: List[str] | str, - reference_outputs: List[str] | str, - prompts: Optional[List[str] | str] = None, + generated_outputs: list[str] | str, + reference_outputs: list[str] | str, + prompts: list[str] | str | None = None, *, - tokenizer: Optional[Tokenizer] = None, + tokenizer: Tokenizer | None = None, ) -> MetricValue[float]: """Calculates the F1 metrics of the ROUGE-2 scores between the generated outputs and the reference outputs. It evaluates the overlap of bigrams @@ -222,11 +220,11 @@ def rouge2( def rougeL( - generated_outputs: List[str] | str, - reference_outputs: List[str] | str, - prompts: Optional[List[str] | str] = None, + generated_outputs: list[str] | str, + reference_outputs: list[str] | str, + prompts: list[str] | str | None = None, *, - tokenizer: Optional[Tokenizer] = None, + tokenizer: Tokenizer | None = None, ) -> MetricValue[float]: """Calculates the F1 metrics of the ROUGE-L scores between the generated outputs and the reference outputs. It evaluates the longest common @@ -278,12 +276,12 @@ def rougeL( def _rouge( - generated_outputs: List[str], - reference_outputs: List[str], + generated_outputs: list[str], + reference_outputs: list[str], rouge_type: str, *, - tokenizer: Optional[Tokenizer] = None, -) -> List[float]: + tokenizer: Tokenizer | None = None, +) -> list[float]: """Helper function for computing the rouge1, rouge2, and rougeL metrics. This uses Google Research's implementation of ROUGE: https://github.com/google-research/google-research/tree/master/rouge diff --git a/src/langcheck/metrics/ja/reference_free_text_quality.py b/src/langcheck/metrics/ja/reference_free_text_quality.py index c8b0efb3..ec6d6565 100644 --- a/src/langcheck/metrics/ja/reference_free_text_quality.py +++ b/src/langcheck/metrics/ja/reference_free_text_quality.py @@ -1,7 +1,5 @@ from __future__ import annotations -from typing import List, Optional - import regex as re from langcheck.metrics.eval_clients import EvalClient @@ -19,11 +17,11 @@ def sentiment( - generated_outputs: List[str] | str, - prompts: Optional[List[str] | str] = None, + generated_outputs: list[str] | str, + prompts: list[str] | str | None = None, eval_model: str | EvalClient = "local", local_overflow_strategy: str = "truncate", -) -> MetricValue[Optional[float]]: +) -> MetricValue[float | None]: """Calculates the sentiment scores of generated outputs. This metric takes on float values between [0, 1], where 0 is negative sentiment and 1 is positive sentiment. (NOTE: when using an EvalClient, the sentiment scores @@ -101,8 +99,8 @@ def sentiment( def _sentiment_local( - generated_outputs: List[str], overflow_strategy: str -) -> List[Optional[float]]: + generated_outputs: list[str], overflow_strategy: str +) -> list[float | None]: """Calculates the sentiment scores of generated outputs using the Twitter-roBERTa-base-sentiment-multilingual model. This metric takes on float values between [0, 1], where 0 is negative sentiment and 1 is positive @@ -132,12 +130,12 @@ def _sentiment_local( def toxicity( - generated_outputs: List[str] | str, - prompts: Optional[List[str] | str] = None, + generated_outputs: list[str] | str, + prompts: list[str] | str | None = None, eval_model: str | EvalClient = "local", local_overflow_strategy: str = "truncate", eval_prompt_version: str = "v2", -) -> MetricValue[Optional[float]]: +) -> MetricValue[float | None]: """Calculates the toxicity scores of generated outputs. This metric takes on float values between [0, 1], where 0 is low toxicity and 1 is high toxicity. (NOTE: when using an EvalClient, the toxicity scores are either 0.0 @@ -235,8 +233,8 @@ def toxicity( def _toxicity_local( - generated_outputs: List[str], overflow_strategy: str -) -> List[Optional[float]]: + generated_outputs: list[str], overflow_strategy: str +) -> list[float | None]: """Calculates the toxicity scores of generated outputs using a fine-tuned model from `line-corporation/line-distilbert-base-japanese`. This metric takes on float values between [0, 1], where 0 is low toxicity and 1 is high @@ -265,11 +263,11 @@ def _toxicity_local( def fluency( - generated_outputs: List[str] | str, - prompts: Optional[List[str] | str] = None, + generated_outputs: list[str] | str, + prompts: list[str] | str | None = None, eval_model: str | EvalClient = "local", local_overflow_strategy: str = "truncate", -) -> MetricValue[Optional[float]]: +) -> MetricValue[float | None]: """Calculates the fluency scores of generated outputs. This metric takes on float values between [0, 1], where 0 is low fluency and 1 is high fluency. (NOTE: when using an EvalClient, the fluency scores are either 0.0 @@ -350,8 +348,8 @@ def fluency( def _fluency_local( - generated_outputs: List[str], overflow_strategy: str -) -> List[Optional[float]]: + generated_outputs: list[str], overflow_strategy: str +) -> list[float | None]: """Calculates the fluency scores of generated outputs using a fine-tuned model from `line-corporation/line-distilbert-base-japanese`. This metric takes on float values between [0, 1], where 0 is low fluency and 1 is high @@ -380,8 +378,8 @@ def _fluency_local( def tateishi_ono_yamada_reading_ease( - generated_outputs: List[str] | str, - prompts: Optional[List[str] | str] = None, + generated_outputs: list[str] | str, + prompts: list[str] | str | None = None, ) -> MetricValue[float]: """Calculates the readability of generated Japanese outputs using the reading ease score introduced in "日本文の読みやすさの評価式 (A Computer @@ -423,7 +421,7 @@ def tateishi_ono_yamada_reading_ease( delimiters_re = r"[、|。|!|?|!|?|「|」|,|,|.|.|…|『|』]" # Aux function to compute the average length of strings in the list - def _mean_str_length(ls: List[str]) -> float: + def _mean_str_length(ls: list[str]) -> float: if len(ls) == 0: return 0 lens = [len(el) for el in ls] diff --git a/src/langcheck/metrics/ja/source_based_text_quality.py b/src/langcheck/metrics/ja/source_based_text_quality.py index b9e9ee93..708ac66a 100644 --- a/src/langcheck/metrics/ja/source_based_text_quality.py +++ b/src/langcheck/metrics/ja/source_based_text_quality.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import List, Optional, cast +from typing import cast from transformers.pipelines import pipeline from transformers.pipelines.base import Pipeline @@ -23,11 +23,11 @@ def factual_consistency( - generated_outputs: List[str] | str, - sources: List[str] | str, - prompts: Optional[List[str] | str] = None, + generated_outputs: list[str] | str, + sources: list[str] | str, + prompts: list[str] | str | None = None, eval_model: str | EvalClient = "local", -) -> MetricValue[Optional[float]]: +) -> MetricValue[float | None]: """Calculates the factual consistency between the generated outputs and the sources. This metric takes on float values between [0, 1], where 0 means that the output is not at all consistent with the source text, and 1 @@ -106,8 +106,8 @@ def factual_consistency( def _factual_consistency_local( - generated_outputs: List[str], sources: List[str] -) -> List[float]: + generated_outputs: list[str], sources: list[str] +) -> list[float]: """Calculates the factual consistency between each generated sentence and its corresponding source text. The factual consistency score for one generated output is computed as the average of the per-sentence @@ -177,13 +177,13 @@ def _factual_consistency_local( generated_outputs=en_generated_outputs, sources=en_source ).metric_values - # Local factual consistency scores are of type List[float] + # Local factual consistency scores are of type list[float] return factual_consistency_scores # type: ignore def context_relevance( - sources: List[str] | str, prompts: List[str] | str, eval_model: EvalClient -) -> MetricValue[Optional[float]]: + sources: list[str] | str, prompts: list[str] | str, eval_model: EvalClient +) -> MetricValue[float | None]: """Calculates the relevance of the sources to the prompts. This metric takes on float values between [0, 1], where 0 means that the source text is not at all relevant to the prompt, and 1 means that the source text is fully diff --git a/src/langcheck/metrics/metric_inputs.py b/src/langcheck/metrics/metric_inputs.py index c7e0d8ac..98b0d335 100644 --- a/src/langcheck/metrics/metric_inputs.py +++ b/src/langcheck/metrics/metric_inputs.py @@ -1,11 +1,12 @@ from __future__ import annotations -from typing import List, Union +from typing import Union import pandas as pd from jinja2 import Environment, meta -IndividualInputType = Union[str, List[str], None] +# You need "Union" to declare a type in Python < 3.10 +IndividualInputType = Union[str, list[str], None] def _map_pairwise_input_to_list( diff --git a/src/langcheck/metrics/metric_value.py b/src/langcheck/metrics/metric_value.py index 4bf45bdd..1e41dfd4 100644 --- a/src/langcheck/metrics/metric_value.py +++ b/src/langcheck/metrics/metric_value.py @@ -4,7 +4,7 @@ import warnings from dataclasses import dataclass, fields from statistics import mean -from typing import Generic, List, Optional, TypeVar +from typing import Generic, TypeVar import pandas as pd @@ -12,7 +12,7 @@ # Metrics take on float or integer values # Some metrics may return `None` values when the score fails to be computed -NumericType = TypeVar("NumericType", float, int, Optional[float], Optional[int]) +NumericType = TypeVar("NumericType", float, int, float | None, int | None) @dataclass @@ -20,14 +20,14 @@ class MetricValue(Generic[NumericType]): """A rich object that is the output of any langcheck.metrics function.""" metric_name: str - metric_values: List[NumericType] + metric_values: list[NumericType] # Input of the metrics such as prompts, generated outputs... etc metric_inputs: MetricInputs # An explanation can be None if the metric could not be computed - explanations: Optional[List[Optional[str]]] - language: Optional[str] + explanations: list[str | None] | None + language: str | None def to_df(self) -> pd.DataFrame: """Returns a DataFrame of metric values for each data point.""" @@ -235,7 +235,7 @@ def pass_rate(self) -> float: return self._pass_rate @property - def threshold_results(self) -> List[bool]: + def threshold_results(self) -> list[bool]: """Returns a list of booleans indicating whether each data point passes the threshold. """ diff --git a/src/langcheck/metrics/model_manager/_model_loader.py b/src/langcheck/metrics/model_manager/_model_loader.py index fcf75dd4..c8b8e1b0 100644 --- a/src/langcheck/metrics/model_manager/_model_loader.py +++ b/src/langcheck/metrics/model_manager/_model_loader.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple +from __future__ import annotations from sentence_transformers import SentenceTransformer from transformers.models.auto.modeling_auto import ( @@ -11,10 +11,11 @@ def load_sentence_transformers( - model_name: str, - model_revision: Optional[str] = None, - tokenizer_name: Optional[str] = None, - tokenizer_revision: Optional[str] = None) -> SentenceTransformer: + model_name: str, + model_revision: str | None = None, + tokenizer_name: str | None = None, + tokenizer_revision: str | None = None, +) -> SentenceTransformer: """ Loads a SentenceTransformer model. @@ -44,10 +45,10 @@ def load_sentence_transformers( def load_auto_model_for_text_classification( model_name: str, - model_revision: Optional[str] = None, - tokenizer_name: Optional[str] = None, - tokenizer_revision: Optional[str] = None -) -> Tuple[AutoTokenizer, AutoModelForSequenceClassification]: + model_revision: str | None = None, + tokenizer_name: str | None = None, + tokenizer_revision: str | None = None, +) -> tuple[AutoTokenizer, AutoModelForSequenceClassification]: """ Loads a sequence classification model and its tokenizer. @@ -67,20 +68,21 @@ def load_auto_model_for_text_classification( # There are "Some weights are not used warning" for some models, but we # ignore it because that is intended. with _handle_logging_level(): - tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, - trust_remote_code=True, - revision=tokenizer_revision) + tokenizer = AutoTokenizer.from_pretrained( + tokenizer_name, trust_remote_code=True, revision=tokenizer_revision + ) model = AutoModelForSequenceClassification.from_pretrained( - model_name, revision=model_revision) + model_name, revision=model_revision + ) return tokenizer, model # type: ignore def load_auto_model_for_seq2seq( model_name: str, - model_revision: Optional[str] = None, - tokenizer_name: Optional[str] = None, - tokenizer_revision: Optional[str] = None -) -> Tuple[AutoTokenizer, AutoModelForSeq2SeqLM]: + model_revision: str | None = None, + tokenizer_name: str | None = None, + tokenizer_revision: str | None = None, +) -> tuple[AutoTokenizer, AutoModelForSeq2SeqLM]: """ Loads a sequence-to-sequence model and its tokenizer. @@ -97,11 +99,13 @@ def load_auto_model_for_seq2seq( """ if tokenizer_name is None: tokenizer_name = model_name - tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, - revision=tokenizer_revision) + tokenizer = AutoTokenizer.from_pretrained( + tokenizer_name, revision=tokenizer_revision + ) # There are "Some weights are not used warning" for some models, but we # ignore it because that is intended. with _handle_logging_level(): - model = AutoModelForSeq2SeqLM.from_pretrained(model_name, - revision=model_revision) + model = AutoModelForSeq2SeqLM.from_pretrained( + model_name, revision=model_revision + ) return tokenizer, model # type: ignore diff --git a/src/langcheck/metrics/model_manager/_model_management.py b/src/langcheck/metrics/model_manager/_model_management.py index 5995b164..08809890 100644 --- a/src/langcheck/metrics/model_manager/_model_management.py +++ b/src/langcheck/metrics/model_manager/_model_management.py @@ -1,7 +1,8 @@ +from __future__ import annotations + import os from copy import deepcopy from functools import lru_cache -from typing import Optional, Tuple, Union import pandas as pd import requests @@ -36,7 +37,7 @@ VALID_LANGUAGE = ["zh", "en", "ja", "de"] -def check_model_availability(model_name: str, revision: Optional[str]) -> bool: +def check_model_availability(model_name: str, revision: str | None) -> bool: # TODO: add local cached model availability check for offline environment if revision is None or revision == "": url = f"https://huggingface.co/api/models/{model_name}" @@ -88,11 +89,11 @@ def __load_config(self, path: str) -> None: @lru_cache def fetch_model( self, language: str, metric: str - ) -> Union[ - Tuple[AutoTokenizer, AutoModelForSequenceClassification], - Tuple[AutoTokenizer, AutoModelForSeq2SeqLM], - SentenceTransformer, - ]: + ) -> ( + tuple[AutoTokenizer | AutoModelForSequenceClassification] + | tuple[AutoTokenizer | AutoModelForSeq2SeqLM] + | SentenceTransformer + ): """ Return the model (and if applicable, the tokenizer) used for the given metric and language. diff --git a/src/langcheck/metrics/prompts/_utils.py b/src/langcheck/metrics/prompts/_utils.py index c263eff8..3a323b3c 100644 --- a/src/langcheck/metrics/prompts/_utils.py +++ b/src/langcheck/metrics/prompts/_utils.py @@ -28,7 +28,7 @@ def load_few_shot_examples(relative_path: str) -> list[dict[str, str]]: relative_path (str): The relative path of the JSONL file. Returns: - List[str]: The few-shot examples. + list[str]: The few-shot examples. """ cwd = Path(__file__).parent with open(cwd / relative_path) as f: diff --git a/src/langcheck/metrics/reference_based_text_quality.py b/src/langcheck/metrics/reference_based_text_quality.py index 88cb5276..1c6c640e 100644 --- a/src/langcheck/metrics/reference_based_text_quality.py +++ b/src/langcheck/metrics/reference_based_text_quality.py @@ -1,7 +1,5 @@ from __future__ import annotations -from typing import List, Optional - from langcheck.metrics.metric_inputs import ( get_metric_inputs_with_required_lists, ) @@ -12,9 +10,9 @@ def exact_match( - generated_outputs: List[str] | str, - reference_outputs: List[str] | str, - prompts: Optional[List[str] | str] = None, + generated_outputs: list[str] | str, + reference_outputs: list[str] | str, + prompts: list[str] | str | None = None, ) -> MetricValue[int]: """Checks if the generated outputs exact matches with the reference outputs. This metric takes on binary 0 or 1 values. diff --git a/src/langcheck/metrics/scorer/_base.py b/src/langcheck/metrics/scorer/_base.py index a1f22482..67ccc47f 100644 --- a/src/langcheck/metrics/scorer/_base.py +++ b/src/langcheck/metrics/scorer/_base.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Generic, Optional, TypeVar +from typing import Generic, TypeVar import torch from sentence_transformers import util @@ -27,7 +27,7 @@ def _tokenize(self, inputs: list[str]) -> _TokensType: """ raise NotImplementedError - def _score_tokens(self, tokens: _TokensType) -> list[Optional[float]]: + def _score_tokens(self, tokens: _TokensType) -> list[float | None]: """Score the tokens. The returned list should have the same length as the tokens. Each element in the list should be the score of the token. """ @@ -42,14 +42,14 @@ def _slice_tokens( """ raise NotImplementedError - def score(self, inputs: list[str]) -> list[Optional[float]]: + def score(self, inputs: list[str]) -> list[float | None]: """Score the inputs. Basically subclasses should not override this.""" tokens = self._tokenize(inputs) input_length = len(inputs) - scores: list[Optional[float]] = [] + scores: list[float | None] = [] for i in tqdm_wrapper( range(0, input_length, self.batch_size), total=(input_length + self.batch_size - 1) // self.batch_size, diff --git a/src/langcheck/metrics/scorer/detoxify_models.py b/src/langcheck/metrics/scorer/detoxify_models.py index e20292f1..49ae199e 100644 --- a/src/langcheck/metrics/scorer/detoxify_models.py +++ b/src/langcheck/metrics/scorer/detoxify_models.py @@ -1,7 +1,5 @@ from __future__ import annotations -from typing import List, Optional, Tuple, Union - import torch from transformers import ( BatchEncoding, @@ -28,10 +26,10 @@ def load_checkpoint( device: str, lang: str -) -> Tuple[ - Union[BertForSequenceClassification, XLMRobertaForSequenceClassification], - Union[BertTokenizer, XLMRobertaTokenizer], - List[str], +) -> tuple[ + BertForSequenceClassification | XLMRobertaForSequenceClassification, + BertTokenizer | XLMRobertaTokenizer, + list[str], ]: checkpoint_url = _checkpoints[lang] class_model_type, tokenizer_type = _model_types[lang] @@ -72,7 +70,7 @@ def __init__( device: str = "cpu", lang: str = "en", overflow_strategy: str = "truncate", - max_input_length: Optional[int] = None, + max_input_length: int | None = None, ): """ Initialize the scorer with the provided configs. @@ -96,7 +94,7 @@ def __init__( max_input_length or self.tokenizer.model_max_length ) - def _tokenize(self, inputs: list[str]) -> Tuple[BatchEncoding, list[bool]]: + def _tokenize(self, inputs: list[str]) -> tuple[BatchEncoding, list[bool]]: """Tokenize the inputs. It also does the validation on the token length, and return the results as a list of boolean values. If the validation mode is 'raise', it raises an error when the token length is invalid. @@ -143,10 +141,10 @@ def _validate_inputs(self, inputs: list[str]) -> list[bool]: def _slice_tokens( self, - tokens: Tuple[BatchEncoding, list[bool]], + tokens: tuple[BatchEncoding, list[bool]], start_idx: int, end_idx: int, - ) -> Tuple[BatchEncoding, list[bool]]: + ) -> tuple[BatchEncoding, list[bool]]: input_tokens, validation_results = tokens return ( @@ -158,8 +156,8 @@ def _slice_tokens( ) def _score_tokens( - self, tokens: Tuple[BatchEncoding, list[bool]] - ) -> list[Optional[float]]: + self, tokens: tuple[BatchEncoding, list[bool]] + ) -> list[float | None]: input_tokens, validation_results = tokens out = self.model(**input_tokens)[0] scores = torch.sigmoid(out).cpu().detach().numpy() diff --git a/src/langcheck/metrics/scorer/hf_models.py b/src/langcheck/metrics/scorer/hf_models.py index 20cfb1f2..599fcdd3 100644 --- a/src/langcheck/metrics/scorer/hf_models.py +++ b/src/langcheck/metrics/scorer/hf_models.py @@ -1,7 +1,5 @@ from __future__ import annotations -from typing import Optional, Tuple - import torch from transformers import BatchEncoding @@ -19,7 +17,7 @@ def __init__( metric, class_weights, overflow_strategy: str = "truncate", - max_input_length: Optional[int] = None, + max_input_length: int | None = None, ): """ Initialize the scorer with the provided configs. @@ -49,7 +47,7 @@ def __init__( else: self.max_input_length = self.model.config.max_position_embeddings # type: ignore - def _tokenize(self, inputs: list[str]) -> Tuple[BatchEncoding, list[bool]]: + def _tokenize(self, inputs: list[str]) -> tuple[BatchEncoding, list[bool]]: """Tokenize the inputs. It also does the validation on the token length, and return the results as a list of boolean values. If the validation mode is 'raise', it raises an error when the token length is invalid. @@ -95,13 +93,13 @@ def _validate_inputs(self, inputs: list[str]) -> list[bool]: return validation_results def _score_tokens( - self, tokens: Tuple[BatchEncoding, list[bool]] - ) -> list[Optional[float]]: + self, tokens: tuple[BatchEncoding, list[bool]] + ) -> list[float | None]: """Return the prediction results as scores.""" input_tokens, validation_results = tokens with torch.no_grad(): logits: torch.Tensor = self.model(**input_tokens).logits # type: ignore - scores: list[Optional[float]] = self._logits_to_scores(logits) # type: ignore + scores: list[float | None] = self._logits_to_scores(logits) # type: ignore for i, validation_result in enumerate(validation_results): if not validation_result: @@ -111,10 +109,10 @@ def _score_tokens( def _slice_tokens( self, - tokens: Tuple[BatchEncoding, list[bool]], + tokens: tuple[BatchEncoding, list[bool]], start_idx: int, end_idx: int, - ) -> Tuple[BatchEncoding, list[bool]]: + ) -> tuple[BatchEncoding, list[bool]]: input_tokens, validation_results = tokens return ( diff --git a/src/langcheck/metrics/text_structure.py b/src/langcheck/metrics/text_structure.py index 3315490d..a659efc7 100644 --- a/src/langcheck/metrics/text_structure.py +++ b/src/langcheck/metrics/text_structure.py @@ -2,7 +2,7 @@ import json import re -from typing import Callable, Container, Iterable, List, Optional +from collections.abc import Callable, Container, Iterable from langcheck.metrics.metric_inputs import ( get_metric_inputs_with_required_lists, @@ -12,9 +12,9 @@ def is_int( - generated_outputs: List[str] | str, + generated_outputs: list[str] | str, domain: Iterable[int] | Container[int] | None = None, - prompts: Optional[List[str] | str] = None, + prompts: list[str] | str | None = None, ) -> MetricValue[int]: """Checks if generated outputs can be parsed as integers, optionally within a domain of integers like `range(1, 11)` or `{1, 3, 5}`. This metric takes @@ -57,10 +57,10 @@ def is_int( def is_float( - generated_outputs: List[str] | str, - min: Optional[float] = None, - max: Optional[float] = None, - prompts: Optional[List[str] | str] = None, + generated_outputs: list[str] | str, + min: float | None = None, + max: float | None = None, + prompts: list[str] | str | None = None, ) -> MetricValue[int]: """Checks if generated outputs can be parsed as floating point numbers, optionally within a min/max range. This metric takes on binary 0 or 1 @@ -109,8 +109,8 @@ def is_float( def is_json_object( - generated_outputs: List[str] | str, - prompts: Optional[List[str] | str] = None, + generated_outputs: list[str] | str, + prompts: list[str] | str | None = None, ) -> MetricValue[int]: """Checks if generated outputs can be parsed as JSON objects. This metric takes on binary 0 or 1 values. @@ -151,8 +151,8 @@ def is_json_object( def is_json_array( - generated_outputs: List[str] | str, - prompts: Optional[List[str] | str] = None, + generated_outputs: list[str] | str, + prompts: list[str] | str | None = None, ) -> MetricValue[int]: """Checks if generated outputs can be parsed as JSON arrays. This metric takes on binary 0 or 1 values. @@ -193,9 +193,9 @@ def is_json_array( def matches_regex( - generated_outputs: List[str] | str, + generated_outputs: list[str] | str, regex: str, - prompts: Optional[List[str] | str] = None, + prompts: list[str] | str | None = None, ) -> MetricValue[int]: """Checks if generated outputs fully match a given regular expression. This metric takes on binary 0 or 1 values. @@ -233,9 +233,9 @@ def matches_regex( def contains_regex( - generated_outputs: List[str] | str, + generated_outputs: list[str] | str, regex: str, - prompts: Optional[List[str] | str] = None, + prompts: list[str] | str | None = None, ) -> MetricValue[int]: """Checks if generated outputs partially contain a given regular expression. This metric takes on binary 0 or 1 values. @@ -273,10 +273,10 @@ def contains_regex( def contains_all_strings( - generated_outputs: List[str] | str, - strings: List[str], + generated_outputs: list[str] | str, + strings: list[str], case_sensitive: bool = False, - prompts: Optional[List[str] | str] = None, + prompts: list[str] | str | None = None, ) -> MetricValue[int]: """Checks if generated outputs contain all strings in of a given list. This metric takes on binary 0 or 1 values. @@ -323,10 +323,10 @@ def contains_all_strings( def contains_any_strings( - generated_outputs: List[str] | str, - strings: List[str], + generated_outputs: list[str] | str, + strings: list[str], case_sensitive: bool = False, - prompts: Optional[List[str] | str] = None, + prompts: list[str] | str | None = None, ) -> MetricValue[int]: """Checks if generated outputs contain any strings in a given list. This metric takes on binary 0 or 1 values. @@ -374,9 +374,9 @@ def contains_any_strings( def validation_fn( - generated_outputs: List[str] | str, + generated_outputs: list[str] | str, valid_fn: Callable[[str], bool], - prompts: Optional[List[str] | str] = None, + prompts: list[str] | str | None = None, ) -> MetricValue[int]: """Checks if generated outputs are valid according to an arbitrary function. This metric takes on binary 0 or 1 values. diff --git a/src/langcheck/metrics/zh/reference_based_text_quality.py b/src/langcheck/metrics/zh/reference_based_text_quality.py index 2498af09..dbae15b9 100644 --- a/src/langcheck/metrics/zh/reference_based_text_quality.py +++ b/src/langcheck/metrics/zh/reference_based_text_quality.py @@ -1,7 +1,5 @@ from __future__ import annotations -from typing import List, Optional - from rouge_score import rouge_scorer from rouge_score.tokenizers import Tokenizer @@ -19,9 +17,9 @@ def semantic_similarity( - generated_outputs: List[str] | str, - reference_outputs: List[str] | str, - prompts: Optional[List[str] | str] = None, + generated_outputs: list[str] | str, + reference_outputs: list[str] | str, + prompts: list[str] | str | None = None, eval_model: str | EvalClient = "local", ) -> MetricValue[float]: """ @@ -91,11 +89,11 @@ def semantic_similarity( def rouge1( - generated_outputs: List[str] | str, - reference_outputs: List[str] | str, - prompts: Optional[List[str] | str] = None, + generated_outputs: list[str] | str, + reference_outputs: list[str] | str, + prompts: list[str] | str | None = None, *, - tokenizer: Optional[Tokenizer] = None, + tokenizer: Tokenizer | None = None, ) -> MetricValue[float]: """Calculates the F1 metrics of the ROUGE-1 scores between the generated (single tokens) between the generated outputs and the reference outputs. @@ -136,11 +134,11 @@ def rouge1( def rouge2( - generated_outputs: List[str] | str, - reference_outputs: List[str] | str, - prompts: Optional[List[str] | str] = None, + generated_outputs: list[str] | str, + reference_outputs: list[str] | str, + prompts: list[str] | str | None = None, *, - tokenizer: Optional[Tokenizer] = None, + tokenizer: Tokenizer | None = None, ) -> MetricValue[float]: """Calculates the F1 metrics of the ROUGE-2 scores between the generated outputs and the reference outputs. It evaluates the overlap of bigrams @@ -182,11 +180,11 @@ def rouge2( def rougeL( - generated_outputs: List[str] | str, - reference_outputs: List[str] | str, - prompts: Optional[List[str] | str] = None, + generated_outputs: list[str] | str, + reference_outputs: list[str] | str, + prompts: list[str] | str | None = None, *, - tokenizer: Optional[Tokenizer] = None, + tokenizer: Tokenizer | None = None, ) -> MetricValue[float]: """Calculates the F1 metrics of the ROUGE-L scores between the generated outputs and the reference outputs. It evaluates the longest common @@ -238,12 +236,12 @@ def rougeL( def _rouge( - generated_outputs: List[str], - reference_outputs: List[str], + generated_outputs: list[str], + reference_outputs: list[str], rouge_type: str, *, - tokenizer: Optional[Tokenizer] = None, -) -> List[float]: + tokenizer: Tokenizer | None = None, +) -> list[float]: """Helper function for computing the rouge1, rouge2, and rougeL metrics. This uses Google Research's implementation of ROUGE: https://github.com/google-research/google-research/tree/master/rouge diff --git a/src/langcheck/metrics/zh/reference_free_text_quality.py b/src/langcheck/metrics/zh/reference_free_text_quality.py index e74ee6d3..ed0bd504 100644 --- a/src/langcheck/metrics/zh/reference_free_text_quality.py +++ b/src/langcheck/metrics/zh/reference_free_text_quality.py @@ -1,7 +1,5 @@ from __future__ import annotations -from typing import List, Optional - import hanlp from transformers.pipelines import pipeline @@ -21,10 +19,10 @@ def sentiment( - generated_outputs: List[str] | str, - prompts: Optional[List[str] | str] = None, + generated_outputs: list[str] | str, + prompts: list[str] | str | None = None, eval_model: str | EvalClient = "local", -) -> MetricValue[Optional[float]]: +) -> MetricValue[float | None]: """Calculates the sentiment scores of generated outputs. This metric takes on float values between [0, 1], where 0 is negative sentiment and 1 is positive sentiment. (NOTE: when using an EvalClient, the sentiment scores @@ -100,11 +98,11 @@ def sentiment( def toxicity( - generated_outputs: List[str] | str, - prompts: Optional[List[str] | str] = None, + generated_outputs: list[str] | str, + prompts: list[str] | str | None = None, eval_model: str | EvalClient = "local", eval_prompt_version: str = "v2", -) -> MetricValue[Optional[float]]: +) -> MetricValue[float | None]: """Calculates the toxicity scores of generated outputs. This metric takes on float values between [0, 1], where 0 is low toxicity and 1 is high toxicity. (NOTE: when using an EvalClient, the toxicity scores are in steps of @@ -167,7 +165,7 @@ def toxicity( ) -def _toxicity_local(generated_outputs: List[str]) -> List[float]: +def _toxicity_local(generated_outputs: list[str]) -> list[float]: """Calculates the toxicity scores of generated outputs using a fine-tuned model from `alibaba-pai/pai-bert-base-zh-llm-risk-detection`. This metric takes on float values between [0, 1], where 0 is low toxicity and 1 is high @@ -183,7 +181,7 @@ def _toxicity_local(generated_outputs: List[str]) -> List[float]: A list of scores """ # this pipeline output predict probability for each text on each label. - # the output format is List[List[Dict(str)]] + # the output format is list[list[dict(str)]] from langcheck.metrics.model_manager import manager tokenizer, model = manager.fetch_model(language="zh", metric="toxicity") @@ -210,8 +208,8 @@ def _toxicity_local(generated_outputs: List[str]) -> List[float]: def xuyaochen_report_readability( - generated_outputs: List[str] | str, - prompts: Optional[List[str] | str] = None, + generated_outputs: list[str] | str, + prompts: list[str] | str | None = None, ) -> MetricValue[float]: """Calculates the readability scores of generated outputs introduced in "中文年报可读性"(Chinese annual report readability). This metric calculates @@ -262,27 +260,27 @@ def xuyaochen_report_readability( # List[List[List[POS]]] output_pos = list(map(pos_pipeline, generated_outputs)) - def count_tokens(sent_tokens: List[str]) -> int: + def count_tokens(sent_tokens: list[str]) -> int: count = sum([ not hanlp.utils.string_util.ispunct(token) for token in # type: ignore[reportGeneralTypeIssues] sent_tokens ]) return count - def count_postags(sent_poses: List[str]) -> int: + def count_postags(sent_poses: list[str]) -> int: # AD: adverb, CC: coordinating conjunction, # CS: subordinating conjunction count = sum([pos in ["AD", "CC", "CS"] for pos in sent_poses]) return count - def calc_r1(content: List[List[str]]) -> float: + def calc_r1(content: list[list[str]]) -> float: token_count_by_sentence = list(map(count_tokens, content)) if len(token_count_by_sentence) == 0: return 0 else: return sum(token_count_by_sentence) / len(token_count_by_sentence) - def calc_r2(content: List[List[str]]) -> float: + def calc_r2(content: list[list[str]]) -> float: pos_count_by_sentence = list(map(count_postags, content)) if len(pos_count_by_sentence) == 0: return 0 diff --git a/src/langcheck/metrics/zh/source_based_text_quality.py b/src/langcheck/metrics/zh/source_based_text_quality.py index edb84beb..27ae8ba4 100644 --- a/src/langcheck/metrics/zh/source_based_text_quality.py +++ b/src/langcheck/metrics/zh/source_based_text_quality.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import List, Optional, cast +from typing import cast from transformers.pipelines import pipeline @@ -17,11 +17,11 @@ def factual_consistency( - generated_outputs: List[str] | str, - sources: List[str] | str, - prompts: Optional[List[str] | str] = None, + generated_outputs: list[str] | str, + sources: list[str] | str, + prompts: list[str] | str | None = None, eval_model: str | EvalClient = "local", -) -> MetricValue[Optional[float]]: +) -> MetricValue[float | None]: """Calculates the factual consistency between the generated outputs and the sources. This metric takes on float values between [0, 1], where 0 means that the output is not at all consistent with the source text, and 1 diff --git a/src/langcheck/plot/_scatter.py b/src/langcheck/plot/_scatter.py index eb37e8ad..5721325f 100644 --- a/src/langcheck/plot/_scatter.py +++ b/src/langcheck/plot/_scatter.py @@ -1,7 +1,8 @@ +from __future__ import annotations + import math import textwrap from copy import deepcopy -from typing import Optional, Union import plotly.express as px from dash import Dash, Input, Output, dcc, html @@ -14,7 +15,7 @@ def scatter( metric_value: MetricValue, - other_metric_value: Optional[MetricValue] = None, + other_metric_value: MetricValue | None = None, jupyter_mode: str = "inline", ) -> None: """Shows an interactive scatter plot of all data points in an @@ -422,7 +423,7 @@ def update_figure( # Unfortunately it's not possible to make "index" show up at the top of # the tooltip like _scatter_one_metric_value() since Plotly always # displays the x and y values at the top.) - hover_data: dict[str, Union[bool, Index]] = { + hover_data: dict[str, bool | Index] = { col: True for col in filtered_df.columns } hover_data["index"] = filtered_df.index diff --git a/src/langcheck/plot/_utils.py b/src/langcheck/plot/_utils.py index 726bcd05..1ad7db31 100644 --- a/src/langcheck/plot/_utils.py +++ b/src/langcheck/plot/_utils.py @@ -1,5 +1,6 @@ +from __future__ import annotations + from enum import Enum -from typing import Union from plotly.graph_objects import Figure @@ -9,8 +10,9 @@ class Axis(Enum): horizontal = 1 -def _plot_threshold(fig: Figure, threshold_op: str, - threshold: Union[float, int], direction: Axis): +def _plot_threshold( + fig: Figure, threshold_op: str, threshold: float | int, direction: Axis +): """Draw a dashed line on the target figure at the specified threshold value along either the horizontal or vertical axis. @@ -23,15 +25,19 @@ def _plot_threshold(fig: Figure, threshold_op: str, """ threshold_text = f"{threshold_op} {threshold}" if direction == Axis.horizontal: # Draw a horizontal line - fig.add_hline(y=threshold, - line_width=3, - line_dash="dash", - annotation_text=threshold_text, - annotation_position="right") + fig.add_hline( + y=threshold, + line_width=3, + line_dash="dash", + annotation_text=threshold_text, + annotation_position="right", + ) elif direction == Axis.vertical: # Draw a vertical line - fig.add_vline(x=threshold, - line_width=3, - line_dash="dash", - annotation_text=threshold_text, - annotation_position="top") + fig.add_vline( + x=threshold, + line_width=3, + line_dash="dash", + annotation_text=threshold_text, + annotation_position="top", + ) return diff --git a/src/langcheck/utils/progress_bar.py b/src/langcheck/utils/progress_bar.py index bba07e15..a56a1995 100644 --- a/src/langcheck/utils/progress_bar.py +++ b/src/langcheck/utils/progress_bar.py @@ -1,12 +1,17 @@ -from typing import Any, Iterable, Optional +from __future__ import annotations + +from collections.abc import Iterable +from typing import Any from tqdm import tqdm -def tqdm_wrapper(iterable: Iterable[Any], - desc: Optional[str] = None, - total: Optional[int] = None, - unit: str = "it"): +def tqdm_wrapper( + iterable: Iterable[Any], + desc: str | None = None, + total: int | None = None, + unit: str = "it", +): """ Wrapper for tqdm to make it optional """ diff --git a/tests/augment/en/test_change_case.py b/tests/augment/en/test_change_case.py index 127c3ce6..cac143cc 100644 --- a/tests/augment/en/test_change_case.py +++ b/tests/augment/en/test_change_case.py @@ -1,7 +1,6 @@ from __future__ import annotations import random -from typing import List import pytest @@ -15,51 +14,118 @@ # To uppercase, single input ######################################################################## ("Hello, world!", 1, 0.9, "uppercase", ["HELLO, WORLD!"]), - ("Hello, world!", 2, 0.9, "uppercase", - ["HELLO, WORLD!", "HELLO, WORLd!"]), + ( + "Hello, world!", + 2, + 0.9, + "uppercase", + ["HELLO, WORLD!", "HELLO, WORLd!"], + ), (["Hello, world!"], 1, 0.9, "uppercase", ["HELLO, WORLD!"]), - (["Hello, world!" - ], 2, 0.9, "uppercase", ["HELLO, WORLD!", "HELLO, WORLd!"]), + ( + ["Hello, world!"], + 2, + 0.9, + "uppercase", + ["HELLO, WORLD!", "HELLO, WORLd!"], + ), ("Hello, world!", 1, 0.1, "uppercase", ["HEllo, WoRld!"]), - ("Hello, world!", 2, 0.1, "uppercase", - ["HEllo, WoRld!", "Hello, world!"]), + ( + "Hello, world!", + 2, + 0.1, + "uppercase", + ["HEllo, WoRld!", "Hello, world!"], + ), (["Hello, world!"], 1, 0.1, "uppercase", ["HEllo, WoRld!"]), - (["Hello, world!" - ], 2, 0.1, "uppercase", ["HEllo, WoRld!", "Hello, world!"]), + ( + ["Hello, world!"], + 2, + 0.1, + "uppercase", + ["HEllo, WoRld!", "Hello, world!"], + ), ######################################################################## # To lowercase, single input ######################################################################## ("HELLO, world!", 1, 0.9, "lowercase", ["hello, world!"]), - ("HELLO, world!", 2, 0.9, "lowercase", - ["hello, world!", "hello, world!"]), + ( + "HELLO, world!", + 2, + 0.9, + "lowercase", + ["hello, world!", "hello, world!"], + ), (["HELLO, world!"], 1, 0.9, "lowercase", ["hello, world!"]), - (["HELLO, world!" - ], 2, 0.9, "lowercase", ["hello, world!", "hello, world!"]), + ( + ["HELLO, world!"], + 2, + 0.9, + "lowercase", + ["hello, world!", "hello, world!"], + ), ("HELLO, world!", 1, 0.1, "lowercase", ["HeLLO, world!"]), - ("HELLO, world!", 2, 0.1, "lowercase", - ["HeLLO, world!", "HELLO, world!"]), + ( + "HELLO, world!", + 2, + 0.1, + "lowercase", + ["HeLLO, world!", "HELLO, world!"], + ), (["HELLO, world!"], 1, 0.1, "lowercase", ["HeLLO, world!"]), - (["HELLO, world!" - ], 2, 0.1, "lowercase", ["HeLLO, world!", "HELLO, world!"]), + ( + ["HELLO, world!"], + 2, + 0.1, + "lowercase", + ["HeLLO, world!", "HELLO, world!"], + ), ######################################################################## # Multiple inputs ######################################################################## - (["HELLO, world!", "I'm hungry" - ], 1, 0.9, "lowercase", ["hello, world!", "i'm hungry"]), - (["HELLO, world!", "I'm hungry"], 2, 0.9, "lowercase", - ["hello, world!", "hello, world!", "i'm hungry", "i'm hungry"]), - (["HELLO, world!", "I'm hungry" - ], 1, 0.1, "uppercase", ["HELLO, WoRld!", "I'm huNgry"]), - (["HELLO, world!", "I'm hungry"], 2, 0.1, "uppercase", - ["HELLO, WoRld!", "HELLO, world!", "I'm hungry", "I'm hUngRy"]) + ( + ["HELLO, world!", "I'm hungry"], + 1, + 0.9, + "lowercase", + ["hello, world!", "i'm hungry"], + ), + ( + ["HELLO, world!", "I'm hungry"], + 2, + 0.9, + "lowercase", + ["hello, world!", "hello, world!", "i'm hungry", "i'm hungry"], + ), + ( + ["HELLO, world!", "I'm hungry"], + 1, + 0.1, + "uppercase", + ["HELLO, WoRld!", "I'm huNgry"], + ), + ( + ["HELLO, world!", "I'm hungry"], + 2, + 0.1, + "uppercase", + ["HELLO, WoRld!", "HELLO, world!", "I'm hungry", "I'm hUngRy"], + ), ], ) -def test_change_case(instances: List[str] | str, num_perturbations: int, - aug_char_p: float, to_case: str, expected: List[str]): +def test_change_case( + instances: list[str] | str, + num_perturbations: int, + aug_char_p: float, + to_case: str, + expected: list[str], +): seed = 42 random.seed(seed) - actual = change_case(instances, - to_case=to_case, - aug_char_p=aug_char_p, - num_perturbations=num_perturbations) + actual = change_case( + instances, + to_case=to_case, + aug_char_p=aug_char_p, + num_perturbations=num_perturbations, + ) assert actual == expected diff --git a/tests/augment/en/test_gender.py b/tests/augment/en/test_gender.py index 70f8d21a..29e7dec0 100644 --- a/tests/augment/en/test_gender.py +++ b/tests/augment/en/test_gender.py @@ -1,5 +1,4 @@ import random -from typing import List, Optional import pytest @@ -21,33 +20,58 @@ def test_invalid_input(): [ (["He cooks by himself.", "This is his dog.", "I gave him a book."]), (["She cooks by herself.", "This is her dog.", "I gave her a book."]), - ([ - "They cooks by themselves.", "This is their dog.", - "I gave them a book." - ]), + ( + [ + "They cooks by themselves.", + "This is their dog.", + "I gave them a book.", + ] + ), + ], +) +@pytest.mark.parametrize( + "to_gender, expected", + [ + ( + None, + [ + "They cooks by themselves.", + "This is their dog.", + "I gave them a book.", + ], + ), + ( + "female", + ["She cooks by herself.", "This is her dog.", "I gave her a book."], + ), + ( + "male", + ["He cooks by himself.", "This is his dog.", "I gave him a book."], + ), + ( + "neutral", + ["Xe cooks by xyrself.", "This is xyr dog.", "I gave xem a book."], + ), + ( + "plural", + [ + "They cooks by themselves.", + "This is their dog.", + "I gave them a book.", + ], + ), ], ) -@pytest.mark.parametrize("to_gender, expected", [ - (None, [ - "They cooks by themselves.", "This is their dog.", "I gave them a book." - ]), - ("female", - ["She cooks by herself.", "This is her dog.", "I gave her a book."]), - ("male", ["He cooks by himself.", "This is his dog.", "I gave him a book." - ]), - ("neutral", - ["Xe cooks by xyrself.", "This is xyr dog.", "I gave xem a book."]), - ("plural", [ - "They cooks by themselves.", "This is their dog.", "I gave them a book." - ]), -]) def test_gender( - texts: List[str], - to_gender: Optional[str], - expected: List[str], + texts: list[str], + to_gender: str | None, + expected: list[str], ): seed = 42 random.seed(seed) - actual = gender(texts) if to_gender is None else gender(texts, - to_gender=to_gender) + actual = ( + gender(texts) + if to_gender is None + else gender(texts, to_gender=to_gender) + ) assert actual == expected diff --git a/tests/augment/en/test_keyboard_typo.py b/tests/augment/en/test_keyboard_typo.py index b740dd3b..8a9cfcfa 100644 --- a/tests/augment/en/test_keyboard_typo.py +++ b/tests/augment/en/test_keyboard_typo.py @@ -1,7 +1,6 @@ from __future__ import annotations import random -from typing import List import pytest @@ -16,12 +15,16 @@ (["Hello, world!"], 1, ["HePlo, wLrld!"]), (["Hello, world!"], 2, ["HePlo, wLrld!", "Helll, Aorld!"]), (["Hello, world!", "I'm hungry"], 1, ["HePlo, wLrld!", "I ' m hungrt"]), - (["Hello, world!", "I'm hungry"], 2, - ["HePlo, wLrld!", "Helll, Aorld!", "I ' m hKngry", "I ' m hungGy"]), + ( + ["Hello, world!", "I'm hungry"], + 2, + ["HePlo, wLrld!", "Helll, Aorld!", "I ' m hKngry", "I ' m hungGy"], + ), ], ) -def test_keyboard_typo(instances: List[str] | str, num_perturbations: int, - expected: List[str]): +def test_keyboard_typo( + instances: list[str] | str, num_perturbations: int, expected: list[str] +): seed = 42 random.seed(seed) actual = keyboard_typo(instances, num_perturbations=num_perturbations) diff --git a/tests/augment/en/test_ocr_typo.py b/tests/augment/en/test_ocr_typo.py index 80676eaf..bc0d876a 100644 --- a/tests/augment/en/test_ocr_typo.py +++ b/tests/augment/en/test_ocr_typo.py @@ -1,7 +1,6 @@ from __future__ import annotations import random -from typing import List import pytest @@ -16,12 +15,16 @@ (["Hello, world!"], 1, ["Hel1u, world!"]), (["Hello, world!"], 2, ["Hel1u, world!", "Hello, w0r1d!"]), (["Hello, world!", "I'm hungry"], 1, ["Hel1u, world!", "I ' m hungry"]), - (["Hello, world!", "I'm hungry"], 2, - ["Hel1u, world!", "Hello, w0r1d!", "1 ' m hongky", "I ' m hun9ky"]), + ( + ["Hello, world!", "I'm hungry"], + 2, + ["Hel1u, world!", "Hello, w0r1d!", "1 ' m hongky", "I ' m hun9ky"], + ), ], ) -def test_ocr_typo(instances: List[str] | str, num_perturbations: int, - expected: List[str]): +def test_ocr_typo( + instances: list[str] | str, num_perturbations: int, expected: list[str] +): seed = 42 random.seed(seed) actual = ocr_typo(instances, num_perturbations=num_perturbations) diff --git a/tests/augment/en/test_remove_punctuation.py b/tests/augment/en/test_remove_punctuation.py index da2e217f..a1680c57 100644 --- a/tests/augment/en/test_remove_punctuation.py +++ b/tests/augment/en/test_remove_punctuation.py @@ -1,7 +1,6 @@ from __future__ import annotations import random -from typing import List import pytest @@ -15,17 +14,29 @@ ("Hello, world...!?", 2, 0.5, ["Hello, world!?", "Hello, world?"]), (["Hello, world...!?"], 1, 0.5, ["Hello, world!?"]), (["Hello, world...!?"], 2, 0.5, ["Hello, world!?", "Hello, world?"]), - (["Hello, world...!?", "!@#$%^&*()_+,./" - ], 1, 0.5, ["Hello, world!?", "!^()+,/"]), - (["Hello, world...!?", "!@#$%^&*()_+,./"], 2, 0.5, - ["Hello, world!?", "Hello, world?", "#$^&(),", "@#$%^&()_+,."]), + ( + ["Hello, world...!?", "!@#$%^&*()_+,./"], + 1, + 0.5, + ["Hello, world!?", "!^()+,/"], + ), + ( + ["Hello, world...!?", "!@#$%^&*()_+,./"], + 2, + 0.5, + ["Hello, world!?", "Hello, world?", "#$^&(),", "@#$%^&()_+,."], + ), ], ) -def test_remove_punctuation(instances: List[str] | str, num_perturbations: int, - aug_char_p: float, expected: List[str]): +def test_remove_punctuation( + instances: list[str] | str, + num_perturbations: int, + aug_char_p: float, + expected: list[str], +): seed = 42 random.seed(seed) - actual = remove_punctuation(instances, - aug_char_p=aug_char_p, - num_perturbations=num_perturbations) + actual = remove_punctuation( + instances, aug_char_p=aug_char_p, num_perturbations=num_perturbations + ) assert actual == expected diff --git a/tests/augment/en/test_to_full_width.py b/tests/augment/en/test_to_full_width.py index 2de533f7..c81032a8 100644 --- a/tests/augment/en/test_to_full_width.py +++ b/tests/augment/en/test_to_full_width.py @@ -1,7 +1,6 @@ from __future__ import annotations import random -from typing import List import pytest @@ -82,10 +81,10 @@ ], ) def test_to_ful_width( - instances: List[str] | str, + instances: list[str] | str, num_perturbations: int, aug_char_p: float, - expected: List[str], + expected: list[str], ): seed = 42 random.seed(seed) diff --git a/tests/augment/ja/test_conv_kana.py b/tests/augment/ja/test_conv_kana.py index 8da87bd4..50eabf64 100644 --- a/tests/augment/ja/test_conv_kana.py +++ b/tests/augment/ja/test_conv_kana.py @@ -1,7 +1,6 @@ from __future__ import annotations import random -from typing import List import pytest @@ -226,11 +225,11 @@ ], ) def test_change_case( - instances: List[str] | str, + instances: list[str] | str, num_perturbations: int, aug_char_p: float, convert_to: str, - expected: List[str], + expected: list[str], ): seed = 42 random.seed(seed) diff --git a/tests/metrics/de/test_tokenizers.py b/tests/metrics/de/test_tokenizers.py index ef6d5e1c..0762c910 100644 --- a/tests/metrics/de/test_tokenizers.py +++ b/tests/metrics/de/test_tokenizers.py @@ -1,24 +1,48 @@ -from typing import List - import pytest from langcheck.metrics.de import DeTokenizer -@pytest.mark.parametrize("text,expected_tokens", [ - ([ - "Ich habe keine persönlichen Meinungen, Emotionen oder Bewusstsein.", - [ - "Ich", "habe", "keine", "persönlichen", "Meinungen", ",", - "Emotionen", "oder", "Bewusstsein", "." - ] - ]), - ("Mein Freund. Willkommen in den Karpaten. Ich erwarte dich sehnsüchtig.\n", - [ - "Mein", "Freund", ".", "Willkommen", "in", "den", "Karpaten", ".", - "Ich", "erwarte", "dich", "sehnsüchtig", "." - ]), -]) -def test_de_tokenizer(text: str, expected_tokens: List[str]) -> None: +@pytest.mark.parametrize( + "text,expected_tokens", + [ + ( + [ + "Ich habe keine persönlichen Meinungen, Emotionen oder Bewusstsein.", + [ + "Ich", + "habe", + "keine", + "persönlichen", + "Meinungen", + ",", + "Emotionen", + "oder", + "Bewusstsein", + ".", + ], + ] + ), + ( + "Mein Freund. Willkommen in den Karpaten. Ich erwarte dich sehnsüchtig.\n", + [ + "Mein", + "Freund", + ".", + "Willkommen", + "in", + "den", + "Karpaten", + ".", + "Ich", + "erwarte", + "dich", + "sehnsüchtig", + ".", + ], + ), + ], +) +def test_de_tokenizer(text: str, expected_tokens: list[str]) -> None: tokenizer = DeTokenizer() # type: ignore[reportGeneralTypeIssues] assert tokenizer.tokenize(text) == expected_tokens diff --git a/tests/metrics/de/test_translation.py b/tests/metrics/de/test_translation.py index 948dd698..2d285809 100644 --- a/tests/metrics/de/test_translation.py +++ b/tests/metrics/de/test_translation.py @@ -1,5 +1,3 @@ -from typing import List - import pytest from langcheck.metrics.de import Translate @@ -8,31 +6,45 @@ @pytest.mark.parametrize( "de_text,en_text", [ - ([ - "Ich habe keine persönlichen Meinungen, Emotionen oder Bewusstsein.", # noqa: E501 - "I have no personal opinions, emotions or consciousness." - ]), - ([ - "Mein Freund. Willkommen in den Karpaten.", - "My friend, welcome to the Carpathians." - ]), - ([ - "Tokio ist die Hauptstadt von Japan.", - "Tokyo is the capital of Japan." - ]), - ]) + ( + [ + "Ich habe keine persönlichen Meinungen, Emotionen oder Bewusstsein.", # noqa: E501 + "I have no personal opinions, emotions or consciousness.", + ] + ), + ( + [ + "Mein Freund. Willkommen in den Karpaten.", + "My friend, welcome to the Carpathians.", + ] + ), + ( + [ + "Tokio ist die Hauptstadt von Japan.", + "Tokyo is the capital of Japan.", + ] + ), + ], +) def test_translate_de_en(de_text: str, en_text: str) -> None: translation = Translate("Helsinki-NLP/opus-mt-de-en") assert translation(de_text) == en_text -@pytest.mark.parametrize("en_text,de_text", [ - ("I have no personal opinions, emotions or consciousness.", - "Ich habe keine persönlichen Meinungen, Emotionen oder Bewusstsein."), - ("My Friend. Welcome to the Carpathians. I am anxiously expecting you.", - "Willkommen bei den Karpaten, ich erwarte Sie."), - ("Tokyo is the capital of Japan.", "Tokio ist die Hauptstadt Japans."), -]) -def test_translate_en_de(en_text: str, de_text: List[str]) -> None: +@pytest.mark.parametrize( + "en_text,de_text", + [ + ( + "I have no personal opinions, emotions or consciousness.", + "Ich habe keine persönlichen Meinungen, Emotionen oder Bewusstsein.", + ), + ( + "My Friend. Welcome to the Carpathians. I am anxiously expecting you.", + "Willkommen bei den Karpaten, ich erwarte Sie.", + ), + ("Tokyo is the capital of Japan.", "Tokio ist die Hauptstadt Japans."), + ], +) +def test_translate_en_de(en_text: str, de_text: list[str]) -> None: translation = Translate("Helsinki-NLP/opus-mt-en-de") assert translation(en_text) == de_text diff --git a/tests/metrics/ja/test_reference_based_text_quality.py b/tests/metrics/ja/test_reference_based_text_quality.py index f1b5df84..ce781bad 100644 --- a/tests/metrics/ja/test_reference_based_text_quality.py +++ b/tests/metrics/ja/test_reference_based_text_quality.py @@ -1,5 +1,7 @@ +from __future__ import annotations + import os -from typing import Callable, Optional +from collections.abc import Callable from unittest.mock import Mock, patch import pytest @@ -95,9 +97,9 @@ def test_rouge_identical( generated_outputs: str, reference_outputs: str, rouge_function: Callable[ - [str, str, Optional[_JapaneseTokenizer]], MetricValue[float] + [str, str, _JapaneseTokenizer | None], MetricValue[float] ], - tokenizer: Optional[_JapaneseTokenizer], + tokenizer: _JapaneseTokenizer | None, ) -> None: # All ROUGE scores are 1 if the generated and reference outputs are # identical @@ -125,7 +127,7 @@ def test_rouge_no_overlap( generated_outputs: str, reference_outputs: str, rouge_function: Callable[[str, str], MetricValue[float]], - tokenizer: Optional[_JapaneseTokenizer], + tokenizer: _JapaneseTokenizer | None, ) -> None: # All ROUGE scores are 0 if the generated and reference outputs have no # overlapping words @@ -153,7 +155,7 @@ def test_rouge_some_overlap( generated_outputs: str, reference_outputs: str, rouge_function: Callable[[str, str], MetricValue[float]], - tokenizer: Optional[_JapaneseTokenizer], + tokenizer: _JapaneseTokenizer | None, ) -> None: expected_value = { "rouge1": [0.823529411764706], diff --git a/tests/metrics/ja/test_tokenizers.py b/tests/metrics/ja/test_tokenizers.py index 78172df0..ebaa19a7 100644 --- a/tests/metrics/ja/test_tokenizers.py +++ b/tests/metrics/ja/test_tokenizers.py @@ -1,5 +1,4 @@ import pkgutil -from typing import List import pytest @@ -7,22 +6,36 @@ from langcheck.metrics.ja._tokenizers import _JapaneseTokenizer -@pytest.mark.parametrize("text,expected_tokens", [ - (["頭が赤い魚を食べる猫", ["頭", "が", "赤い", "魚", "を", "食べる", "猫"]]), - ("猫が、マットの上に座った。", ["猫", "が", "マット", "の", "上", "に", "座っ", "た"]), -]) +@pytest.mark.parametrize( + "text,expected_tokens", + [ + ( + [ + "頭が赤い魚を食べる猫", + ["頭", "が", "赤い", "魚", "を", "食べる", "猫"], + ] + ), + ( + "猫が、マットの上に座った。", + ["猫", "が", "マット", "の", "上", "に", "座っ", "た"], + ), + ], +) @pytest.mark.parametrize( "tokenizer", - [JanomeTokenizer, - pytest.param(MeCabTokenizer, marks=pytest.mark.optional)]) -def test_janome_tokenizer(text: str, expected_tokens: List[str], - tokenizer: _JapaneseTokenizer) -> None: + [JanomeTokenizer, pytest.param(MeCabTokenizer, marks=pytest.mark.optional)], +) +def test_janome_tokenizer( + text: str, expected_tokens: list[str], tokenizer: _JapaneseTokenizer +) -> None: tokenizer = tokenizer() # type: ignore[reportGeneralTypeIssues] assert tokenizer.tokenize(text) == expected_tokens -@pytest.mark.skipif(pkgutil.find_loader("MeCab") is not None, - reason="MeCab has already been installed.") +@pytest.mark.skipif( + pkgutil.find_loader("MeCab") is not None, + reason="MeCab has already been installed.", +) def test_handle_mecab_not_found() -> None: with pytest.raises(ModuleNotFoundError): MeCabTokenizer() diff --git a/tests/metrics/test_metric_value.py b/tests/metrics/test_metric_value.py index f44dfe06..983ef98f 100644 --- a/tests/metrics/test_metric_value.py +++ b/tests/metrics/test_metric_value.py @@ -1,4 +1,4 @@ -from typing import Optional +from __future__ import annotations import pandas as pd import pytest @@ -69,7 +69,7 @@ def test_optional_metric_values(): }, required_params=["generated_outputs"], ) - metric_value: MetricValue[Optional[float]] = MetricValue( + metric_value: MetricValue[float | None] = MetricValue( metric_name="test", metric_inputs=metric_inputs, explanations=None, diff --git a/tests/metrics/zh/test_reference_based_text_quality.py b/tests/metrics/zh/test_reference_based_text_quality.py index 05c4776f..813bcbd6 100644 --- a/tests/metrics/zh/test_reference_based_text_quality.py +++ b/tests/metrics/zh/test_reference_based_text_quality.py @@ -1,5 +1,5 @@ import os -from typing import Callable, Optional +from collections.abc import Callable from unittest.mock import Mock, patch import pytest @@ -23,81 +23,114 @@ ################################################################################ # Tests ################################################################################ -parametrize_rouge_function = pytest.mark.parametrize("rouge_function", - [rouge1, rouge2, rougeL]) -parametrize_tokenizer = pytest.mark.parametrize("tokenizer", - [None, HanLPTokenizer]) +parametrize_rouge_function = pytest.mark.parametrize( + "rouge_function", [rouge1, rouge2, rougeL] +) +parametrize_tokenizer = pytest.mark.parametrize( + "tokenizer", [None, HanLPTokenizer] +) -@pytest.mark.parametrize("generated_outputs,reference_outputs", - [("宇宙的终极答案是什么?", "宇宙的终极答案是什么。"), - (["宇宙的终极答案是什么。"], ["宇宙的终极答案是什么?"])]) +@pytest.mark.parametrize( + "generated_outputs,reference_outputs", + [ + ("宇宙的终极答案是什么?", "宇宙的终极答案是什么。"), + (["宇宙的终极答案是什么。"], ["宇宙的终极答案是什么?"]), + ], +) @parametrize_rouge_function @parametrize_tokenizer -def test_rouge_identical(generated_outputs: str, reference_outputs: str, - rouge_function: Callable[ - [str, str, Optional[_ChineseTokenizer]], - MetricValue[float]], - tokenizer: Optional[_ChineseTokenizer]) -> None: +def test_rouge_identical( + generated_outputs: str, + reference_outputs: str, + rouge_function: Callable[ + [str, str, _ChineseTokenizer | None], MetricValue[float] + ], + tokenizer: _ChineseTokenizer | None, +) -> None: # All ROUGE scores are 1 if the generated and reference outputs are # identical actual_metric_value = rouge_function( generated_outputs, reference_outputs, tokenizer=tokenizer() # type: ignore[reportGeneralTypeIssues] - if tokenizer else None) - assert actual_metric_value.metric_values == [1.] + if tokenizer + else None, + ) + assert actual_metric_value.metric_values == [1.0] assert actual_metric_value.language == "zh" -@pytest.mark.parametrize("generated_outputs,reference_outputs", - [("这样的姑娘是受不了的。", "您到底有什么事?"), - (["这样的姑娘是受不了的。"], ["您到底有什么事?"])]) +@pytest.mark.parametrize( + "generated_outputs,reference_outputs", + [ + ("这样的姑娘是受不了的。", "您到底有什么事?"), + (["这样的姑娘是受不了的。"], ["您到底有什么事?"]), + ], +) @parametrize_rouge_function @parametrize_tokenizer -def test_rouge_no_overlap(generated_outputs: str, reference_outputs: str, - rouge_function: Callable[[str, str], - MetricValue[float]], - tokenizer: Optional[_ChineseTokenizer]) -> None: +def test_rouge_no_overlap( + generated_outputs: str, + reference_outputs: str, + rouge_function: Callable[[str, str], MetricValue[float]], + tokenizer: _ChineseTokenizer | None, +) -> None: # All ROUGE scores are 0 if the generated and reference outputs have no # overlapping words actual_metric_value = rouge_function( generated_outputs, reference_outputs, tokenizer=tokenizer() # type: ignore[reportGeneralTypeIssues] - if tokenizer else None) - assert actual_metric_value.metric_values == [0.] + if tokenizer + else None, + ) + assert actual_metric_value.metric_values == [0.0] assert actual_metric_value.language == "zh" -@pytest.mark.parametrize("generated_outputs,reference_outputs", - [("床前明月光,下一句是什么?", "床前明月光的下一句是什么?"), - (["床前明月光,下一句是什么?"], ["床前明月光的下一句是什么?"])]) +@pytest.mark.parametrize( + "generated_outputs,reference_outputs", + [ + ("床前明月光,下一句是什么?", "床前明月光的下一句是什么?"), + (["床前明月光,下一句是什么?"], ["床前明月光的下一句是什么?"]), + ], +) @parametrize_rouge_function @parametrize_tokenizer -def test_rouge_some_overlap(generated_outputs: str, reference_outputs: str, - rouge_function: Callable[[str, str], - MetricValue[float]], - tokenizer: Optional[_ChineseTokenizer]) -> None: +def test_rouge_some_overlap( + generated_outputs: str, + reference_outputs: str, + rouge_function: Callable[[str, str], MetricValue[float]], + tokenizer: _ChineseTokenizer | None, +) -> None: expected_value = { "rouge1": [0.941176], "rouge2": [0.8], - "rougeL": [0.941176] + "rougeL": [0.941176], } # The ROUGE-2 score is lower than the ROUGE-1 and ROUGE-L scores actual_metric_value = rouge_function( generated_outputs, reference_outputs, tokenizer=tokenizer() # type: ignore[reportGeneralTypeIssues] - if tokenizer else None) - is_close(actual_metric_value.metric_values, - expected_value[rouge_function.__name__]) + if tokenizer + else None, + ) + is_close( + actual_metric_value.metric_values, + expected_value[rouge_function.__name__], + ) assert actual_metric_value.language == "zh" -@pytest.mark.parametrize("generated_outputs,reference_outputs", - [("那里有一本三体小说。", "那里有一本三体小说。"), - (["那里有一本三体小说。"], ["那里有一本三体小说。"])]) +@pytest.mark.parametrize( + "generated_outputs,reference_outputs", + [ + ("那里有一本三体小说。", "那里有一本三体小说。"), + (["那里有一本三体小说。"], ["那里有一本三体小说。"]), + ], +) def test_semantic_similarity_identical(generated_outputs, reference_outputs): metric_value = semantic_similarity(generated_outputs, reference_outputs) assert 0.99 <= metric_value <= 1 @@ -105,29 +138,44 @@ def test_semantic_similarity_identical(generated_outputs, reference_outputs): @pytest.mark.parametrize( "generated_outputs,reference_outputs", - [("php是世界上最好的语言,学计算机要从娃娃抓起。", "在石家庄,有一支摇滚乐队,他们创作了很多音乐。"), - (["php是世界上最好的语言,学计算机要从娃娃抓起。"], ["在石家庄,有一支摇滚乐队,他们创作了很多音乐。"])]) + [ + ( + "php是世界上最好的语言,学计算机要从娃娃抓起。", + "在石家庄,有一支摇滚乐队,他们创作了很多音乐。", + ), + ( + ["php是世界上最好的语言,学计算机要从娃娃抓起。"], + ["在石家庄,有一支摇滚乐队,他们创作了很多音乐。"], + ), + ], +) def test_semantic_similarity_not_similar(generated_outputs, reference_outputs): metric_value = semantic_similarity(generated_outputs, reference_outputs) assert 0.0 <= metric_value <= 0.5 -@pytest.mark.parametrize("generated_outputs,reference_outputs", - [("学习中文很快乐。", "学习中文很快乐。"), - (["学习中文很快乐。"], ["学习中文很快乐。"])]) +@pytest.mark.parametrize( + "generated_outputs,reference_outputs", + [ + ("学习中文很快乐。", "学习中文很快乐。"), + (["学习中文很快乐。"], ["学习中文很快乐。"]), + ], +) def test_semantic_similarity_openai(generated_outputs, reference_outputs): mock_embedding_response = Mock(spec=CreateEmbeddingResponse) mock_embedding_response.data = [Mock(embedding=[0.1, 0.2, 0.3])] # Calling the openai.Embedding.create method requires an OpenAI API key, so # we mock the return value instead - with patch("openai.resources.Embeddings.create", - Mock(return_value=mock_embedding_response)): + with patch( + "openai.resources.Embeddings.create", + Mock(return_value=mock_embedding_response), + ): # Set the necessary env vars for the 'openai' embedding model type os.environ["OPENAI_API_KEY"] = "dummy_key" openai_client = OpenAIEvalClient() - metric_value = semantic_similarity(generated_outputs, - reference_outputs, - eval_model=openai_client) + metric_value = semantic_similarity( + generated_outputs, reference_outputs, eval_model=openai_client + ) # Since the mock embeddings are the same for the generated and reference # outputs, the semantic similarity should be 1. assert 0.99 <= metric_value <= 1 @@ -137,10 +185,11 @@ def test_semantic_similarity_openai(generated_outputs, reference_outputs): os.environ["OPENAI_API_VERSION"] = "dummy_version" os.environ["AZURE_OPENAI_ENDPOINT"] = "dummy_endpoint" azure_openai_client = AzureOpenAIEvalClient( - embedding_model_name="foo bar") - metric_value = semantic_similarity(generated_outputs, - reference_outputs, - eval_model=azure_openai_client) + embedding_model_name="foo bar" + ) + metric_value = semantic_similarity( + generated_outputs, reference_outputs, eval_model=azure_openai_client + ) # Since the mock embeddings are the same for the generated and reference # outputs, the semantic similarity should be 1. assert 0.99 <= metric_value <= 1 diff --git a/tests/metrics/zh/test_tokenizers.py b/tests/metrics/zh/test_tokenizers.py index 4e5d309a..79b41e65 100644 --- a/tests/metrics/zh/test_tokenizers.py +++ b/tests/metrics/zh/test_tokenizers.py @@ -1,18 +1,36 @@ -from typing import List - import pytest from langcheck.metrics.zh import HanLPTokenizer from langcheck.metrics.zh._tokenizers import _ChineseTokenizer -@pytest.mark.parametrize("text,expected_tokens", [ - ("吃葡萄不吐葡萄皮。不吃葡萄到吐葡萄皮。", - ["吃", "葡萄", "不", "吐", "葡萄", "皮", "不", "吃", "葡萄", "到", "吐", "葡萄", "皮"]), - ("北京是中国的首都", ["北京", "是", "中国", "的", "首都"]), -]) +@pytest.mark.parametrize( + "text,expected_tokens", + [ + ( + "吃葡萄不吐葡萄皮。不吃葡萄到吐葡萄皮。", + [ + "吃", + "葡萄", + "不", + "吐", + "葡萄", + "皮", + "不", + "吃", + "葡萄", + "到", + "吐", + "葡萄", + "皮", + ], + ), + ("北京是中国的首都", ["北京", "是", "中国", "的", "首都"]), + ], +) @pytest.mark.parametrize("tokenizer", [HanLPTokenizer]) -def test_hanlp_tokenizer(text: str, expected_tokens: List[str], - tokenizer: _ChineseTokenizer) -> None: +def test_hanlp_tokenizer( + text: str, expected_tokens: list[str], tokenizer: _ChineseTokenizer +) -> None: tokenizer = tokenizer() # type: ignore[reportGeneralTypeIssues] assert tokenizer.tokenize(text) == expected_tokens diff --git a/tests/utils.py b/tests/utils.py index 7f9a375e..70759b87 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,7 +1,7 @@ from __future__ import annotations import math -from typing import Iterable, List +from collections.abc import Iterable from langcheck.metrics.eval_clients import EvalClient @@ -20,20 +20,19 @@ def __init__(self, evaluation_result: str | None = None) -> None: self.evaluation_result = evaluation_result def get_text_responses( - self, - prompts: Iterable[str], - *, - tqdm_description: str | None = None) -> list[str | None]: + self, prompts: Iterable[str], *, tqdm_description: str | None = None + ) -> list[str | None]: return [self.evaluation_result] * len(list(prompts)) def get_float_score( - self, - metric_name: str, - language: str, - unstructured_assessment_result: list[str | None], - score_map: dict[str, float], - *, - tqdm_description: str | None = None) -> list[float | None]: + self, + metric_name: str, + language: str, + unstructured_assessment_result: list[str | None], + score_map: dict[str, float], + *, + tqdm_description: str | None = None, + ) -> list[float | None]: eval_results = [] # Assume that the evaluation result is actually structured and it can be # put into the score_map directly @@ -51,13 +50,13 @@ def get_float_score( ################################################################################ -def is_close(a: List, b: List) -> bool: +def is_close(a: list, b: list) -> bool: """Returns True if two lists of numbers are element-wise close.""" assert len(a) == len(b) return all(math.isclose(x, y) for x, y in zip(a, b)) -def lists_are_equal(a: List[str] | str, b: List[str] | str) -> bool: +def lists_are_equal(a: list[str] | str, b: list[str] | str) -> bool: """Returns True if two lists of strings are equal. If either argument is a single string, it's automatically converted to a list. """ From 476a7f667eedb576d6e105cf6eefb65d413ce908 Mon Sep 17 00:00:00 2001 From: Koki Ryu Date: Mon, 28 Oct 2024 07:50:43 +0000 Subject: [PATCH 3/5] Fix types --- src/langcheck/metrics/metric_value.py | 6 ++++-- src/langcheck/metrics/model_manager/_model_management.py | 4 ++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/langcheck/metrics/metric_value.py b/src/langcheck/metrics/metric_value.py index 1e41dfd4..d07e0b17 100644 --- a/src/langcheck/metrics/metric_value.py +++ b/src/langcheck/metrics/metric_value.py @@ -4,7 +4,7 @@ import warnings from dataclasses import dataclass, fields from statistics import mean -from typing import Generic, TypeVar +from typing import Generic, TypeVar, Union import pandas as pd @@ -12,7 +12,9 @@ # Metrics take on float or integer values # Some metrics may return `None` values when the score fails to be computed -NumericType = TypeVar("NumericType", float, int, float | None, int | None) +NumericType = TypeVar( + "NumericType", float, int, Union[float, None], Union[int, None] +) @dataclass diff --git a/src/langcheck/metrics/model_manager/_model_management.py b/src/langcheck/metrics/model_manager/_model_management.py index 08809890..b95e0439 100644 --- a/src/langcheck/metrics/model_manager/_model_management.py +++ b/src/langcheck/metrics/model_manager/_model_management.py @@ -90,8 +90,8 @@ def __load_config(self, path: str) -> None: def fetch_model( self, language: str, metric: str ) -> ( - tuple[AutoTokenizer | AutoModelForSequenceClassification] - | tuple[AutoTokenizer | AutoModelForSeq2SeqLM] + tuple[AutoTokenizer, AutoModelForSequenceClassification] + | tuple[AutoTokenizer, AutoModelForSeq2SeqLM] | SentenceTransformer ): """ From d1bfb46387913436b50a5a4cd764f2887225ea52 Mon Sep 17 00:00:00 2001 From: Koki Ryu Date: Mon, 28 Oct 2024 07:57:41 +0000 Subject: [PATCH 4/5] Add __future__ --- tests/augment/en/test_gender.py | 2 ++ tests/metrics/zh/test_reference_based_text_quality.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/tests/augment/en/test_gender.py b/tests/augment/en/test_gender.py index 29e7dec0..17479e9b 100644 --- a/tests/augment/en/test_gender.py +++ b/tests/augment/en/test_gender.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import random import pytest diff --git a/tests/metrics/zh/test_reference_based_text_quality.py b/tests/metrics/zh/test_reference_based_text_quality.py index 813bcbd6..460a8791 100644 --- a/tests/metrics/zh/test_reference_based_text_quality.py +++ b/tests/metrics/zh/test_reference_based_text_quality.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os from collections.abc import Callable from unittest.mock import Mock, patch From 0a79fc7e1c8d512f521bfa79557a98b586ffcaa7 Mon Sep 17 00:00:00 2001 From: Koki Ryu Date: Mon, 28 Oct 2024 13:34:30 +0000 Subject: [PATCH 5/5] Unpin old transformers --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 0e8b0e58..3ec923a9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,7 @@ dependencies = [ 'tomli; python_version < "3.11"', 'tokenizers >= 0.13.2; python_version >= "3.11"', # See https://github.com/citadel-ai/langcheck/pull/45 'torch >= 2', - 'transformers >= 4.6, < 4.46', + 'transformers >= 4.6', 'tabulate >= 0.9.0', # For model manager print table 'omegaconf >= 2.3.0' # For model manager print table ]