Merge pull request #162 from citadel-ai/rephrase-evalclient
Use `EvalClient` in `langcheck.augment.rephrase`
taniokay authored Oct 26, 2024
2 parents 71fb7fa + 8cd8c9a commit eb799fd
Showing 10 changed files with 151 additions and 144 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -29,7 +29,7 @@ dependencies = [
'tomli; python_version < "3.11"',
'tokenizers >= 0.13.2; python_version >= "3.11"', # See https://github.com/citadel-ai/langcheck/pull/45
'torch >= 2',
'transformers >= 4.6',
'transformers >= 4.6, < 4.46',
'tabulate >= 0.9.0', # For model manager print table
'omegaconf >= 2.3.0' # For model manager print table
]
102 changes: 21 additions & 81 deletions src/langcheck/augment/en/_rephrase.py
@@ -1,101 +1,41 @@
from __future__ import annotations

import os
from typing import Optional

from openai import AzureOpenAI, OpenAI
from langcheck.metrics.eval_clients import (
EvalClient,
)


def rephrase(
instances: list[str] | str,
*,
num_perturbations: int = 1,
model_type: str = "openai",
openai_client: Optional[OpenAI] = None,
openai_args: Optional[dict[str, str]] = None) -> list[Optional[str]]:
'''Rephrases each string in instances (usually a list of prompts) without
instances: list[str] | str,
*,
num_perturbations: int = 1,
eval_client: EvalClient,
) -> list[str | None]:
"""Rephrases each string in instances (usually a list of prompts) without
changing their meaning. We use a modified version of the prompt presented
in `"Rethinking Benchmark and Contamination for Language Models with
Rephrased Samples" <https://arxiv.org/abs/2311.04850>`__ to make an LLM
rephrase the given text.
We currently support two model types:
1. The 'openai' type, where we use OpenAI's 'gpt-turbo-3.5' model
by default.
2. The 'azure_openai' type. Essentially the same as the 'openai' type,
except that it uses the AzureOpenAI client. Note that you must specify the
model to use in ``openai_args``, e.g.
``openai_args={'model': 'YOUR_DEPLOYMENT_NAME'}``
Args:
instances: A single string or a list of strings to be augmented.
num_perturbations: The number of perturbed instances to generate for
each string in instances
model_type: The type of model to use ('openai' or 'azure_openai'),
default 'openai'
openai_client: OpenAI or AzureOpenAI client, default None. If this is
None, we will attempt to create a default client.
openai_args: Dict of additional args to pass in to the
``client.chat.completions.create`` function, default None
eval_client: The EvalClient instance used to generate the rephrased instances.
Returns:
A list of rephrased instances.
'''
# Initialize the openai object if openai_client is None
# TODO: Refactor this into OpenAIEvalClient?
if openai_client is None:
if model_type == "openai":
openai_client = OpenAI()
elif model_type == "azure_openai":
if not openai_args:
raise AssertionError(
'The model deployment must be specified in `openai_args` '
'for the azure_openai type, e.g. '
'`openai_args={"model": "YOUR_DEPLOYMENT_NAME"}`')
openai_client = AzureOpenAI(
api_key=os.getenv("AZURE_OPENAI_KEY"),
api_version=os.getenv("OPENAI_API_VERSION"),
azure_endpoint=os.getenv(
"AZURE_OPENAI_ENDPOINT")) # type: ignore
else:
raise AssertionError(f'Unexpected model type "{model_type}"')
"""

prompt_template = eval_client.load_prompt_template(
language="en", metric_name="rephrase"
)

instances = [instances] if isinstance(instances, str) else instances
rephrased_instances = []
for instance in instances:
for i in range(num_perturbations):
prompt = f"""
Please rephrase the following prompt without altering its meaning,
ensuring you adjust the word order appropriately.
Ensure that no more than five consecutive words are repeated
and try to use similar words as substitutes where possible.
[BEGIN DATA]
************
[Prompt]: {instance}
************
[END DATA]
"""
messages = [{"role": "user", "content": prompt}]
chat_completions = openai_client.chat.completions
try:
if openai_args is None:
response = chat_completions.create(
model="gpt-3.5-turbo",
messages=messages, # type: ignore
seed=i)
else:
response = chat_completions.create( # type: ignore
messages=messages, # type: ignore
seed=i,
**openai_args, # type: ignore
)
rephrased_instance = response.choices[0].message.content
rephrased_instances.append(rephrased_instance)
except Exception as e:
print(f"OpenAI failed to return a rephrased prompt: {e}")
print(f"Prompt that triggered the failure is:\n{prompt}")
rephrased_instances.append(None)
prompt_template_inputs = [{"instance": instance} for instance in instances]

return rephrased_instances
return eval_client.repeat_requests_from_template(
prompt_template_inputs=prompt_template_inputs,
template=prompt_template,
num_perturbations=num_perturbations,
)
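
After this change, callers build an `EvalClient` and hand it to `rephrase` instead of passing OpenAI-specific arguments. A minimal usage sketch, assuming `OpenAIEvalClient` is the concrete client and that `rephrase` is importable from `langcheck.augment` as the commit title suggests (the default client reads `OPENAI_API_KEY` from the environment):

from langcheck.augment import rephrase
from langcheck.metrics.eval_clients import OpenAIEvalClient

# Any EvalClient subclass should work; OpenAIEvalClient is assumed here and
# picks up OPENAI_API_KEY from the environment by default.
client = OpenAIEvalClient()

prompts = [
    "Summarize the quarterly report in three bullet points.",
    "List the main risks of the proposed rollout.",
]

# Two rephrasings per prompt; a failed request comes back as None.
rephrased = rephrase(prompts, num_perturbations=2, eval_client=client)
print(rephrased)  # flat list; rephrasings of the same prompt are consecutive
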
2 changes: 1 addition & 1 deletion src/langcheck/metrics/eval_clients/_anthropic.py
@@ -101,7 +101,7 @@ async def _call_async_api() -> list[Any]:
def get_text_responses(
self, prompts: Iterable[str], *, tqdm_description: str | None = None
) -> list[str | None]:
"""The function that gets resonses to the given prompt texts.
"""The function that gets responses to the given prompt texts.
We use Anthropic's 'claude-3-haiku-20240307' model by default, but you
can configure it by passing the 'model' parameter in the anthropic_args.
37 changes: 36 additions & 1 deletion src/langcheck/metrics/eval_clients/_base.py
@@ -47,7 +47,10 @@ def load_prompt_template(
)

def get_text_responses(
self, prompts: Iterable[str], *, tqdm_description: str | None = None
self,
prompts: Iterable[str],
*,
tqdm_description: str | None = None,
) -> list[str | None]:
"""The function that gets responses to the given prompt texts. Each
concrete subclass needs to define the concrete implementation of this
@@ -222,3 +225,35 @@ def compute_metric_values_from_template(
metric_values=scores,
language=language,
)

def repeat_requests_from_template(
self,
prompt_template_inputs: list[dict[str, str]],
template: Template,
num_perturbations: int = 1,
) -> list[str | None]:
"""Repeats the request using the given Jinja template for
`num_perturbations` times. Note that every EvalClient subclass is
expected to implement `get_text_responses` method to get different
responses for the same input.
Args:
instances: A single string or a list of strings to be augmented.
template: The Jinja template ready to be rendered.
num_perturbations: The number of perturbed instances to generate
for each string in instances.
Returns:
A list of responses for each input. If `num_pertuations` is > 1, the
multiple responses for the same input are included consecutively.
"""

populated_prompts = [
template.render(prompt_template_input)
for prompt_template_input in prompt_template_inputs
for _ in range(num_perturbations)
]

responses = self.get_text_responses(populated_prompts)

return responses
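
The new helper expands its inputs with a nested comprehension: every input dict is rendered `num_perturbations` times and the resulting prompts go to `get_text_responses` in one batch, so responses for the same input sit next to each other. A standalone sketch of that expansion using plain Jinja2 (not the library code itself):

from jinja2 import Template

# Illustrative inputs; the real template is loaded via load_prompt_template().
template = Template("Please rephrase: {{ instance }}")
inputs = [
    {"instance": "What is the capital of France?"},
    {"instance": "Explain transformers in one sentence."},
]
num_perturbations = 3

populated_prompts = [
    template.render(d) for d in inputs for _ in range(num_perturbations)
]
assert len(populated_prompts) == len(inputs) * num_perturbations
# populated_prompts[0:3] all come from inputs[0], [3:6] from inputs[1], etc.
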
2 changes: 1 addition & 1 deletion src/langcheck/metrics/eval_clients/_gemini.py
@@ -94,7 +94,7 @@ def _call_api_with_exception_filter(prompt: str) -> Any:
def get_text_responses(
self, prompts: Iterable[str], *, tqdm_description: str | None = None
) -> list[str | None]:
"""The function that gets resonses to the given prompt texts.
"""The function that gets responses to the given prompt texts.
Args:
prompts: The prompts you want to get the responses for.
2 changes: 1 addition & 1 deletion src/langcheck/metrics/eval_clients/_llama.py
@@ -68,7 +68,7 @@ def get_text_responses(
prompts: Iterable[str],
language: str,
) -> list[str | None]:
"""The function that generates resonses to the given prompt texts.
"""The function that generates responses to the given prompt texts.
Args:
prompts: The prompts you want to get the responses for.
60 changes: 22 additions & 38 deletions src/langcheck/metrics/eval_clients/_openai.py
@@ -31,7 +31,7 @@ def __init__(
Args:
openai_client: (Optional) The OpenAI client to use.
openai_args: (Optional) dict of additional args to pass in to the
``client.chat.completions.create`` function
``client.chat.completions.create`` function.
use_async: (Optional) If True, the async client will be used.
"""
if openai_client:
@@ -61,9 +61,14 @@ def _call_api_with_exception_filter(model_input: dict[str, Any]) -> Any:
except Exception as e:
return e

# Call API with different seed values for each prompt.
model_inputs = [
{"messages": [{"role": "user", "content": prompt}], **config}
for prompt in prompts
{
"messages": [{"role": "user", "content": prompt}],
"seed": i,
**config,
}
for i, prompt in enumerate(prompts)
]

if self._use_async:
@@ -101,7 +106,10 @@ async def _call_async_api() -> list[Any]:
return responses

def get_text_responses(
self, prompts: Iterable[str], *, tqdm_description: str | None = None
self,
prompts: Iterable[str],
*,
tqdm_description: str | None = None,
) -> list[str | None]:
"""The function that gets responses to the given prompt texts.
We use OpenAI's 'gpt-3.5-turbo' model by default, but you can configure
@@ -114,11 +122,13 @@ def get_text_responses(
A list of responses to the prompts. The responses can be None if the
evaluation fails.
"""
config = {"model": "gpt-3.5-turbo", "seed": 123}
config = {"model": "gpt-3.5-turbo"}
config.update(self._openai_args or {})
tqdm_description = tqdm_description or "Intermediate assessments (1/2)"
responses = self._call_api(
prompts=prompts, config=config, tqdm_description=tqdm_description
prompts=prompts,
config=config,
tqdm_description=tqdm_description,
)
response_texts = [
response.choices[0].message.content if response else None
@@ -151,7 +161,7 @@ def get_text_responses_with_log_likelihood(
output text and the list of tuples of the output tokens and the log
probabilities. The responses can be None if the evaluation fails.
"""
config = {"model": "gpt-3.5-turbo", "seed": 123, "logprobs": True}
config = {"model": "gpt-3.5-turbo", "logprobs": True}
if top_logprobs:
config["top_logprobs"] = top_logprobs
config.update(self._openai_args or {})
@@ -256,7 +266,6 @@ def get_float_score(
]

config_structured_assessments = {
"seed": 123,
"functions": functions,
"function_call": {
"name": "save_assessment",
@@ -354,40 +363,14 @@ def __init__(
else:
self._client = AzureOpenAI(**kargs) # type: ignore

self._openai_args = openai_args or {}

self._text_model_name = text_model_name
self._embedding_model_name = embedding_model_name
self._openai_args = openai_args or {}

self._use_async = use_async
if self._text_model_name is not None:
self._openai_args["model"] = self._text_model_name

def get_score(
self,
metric_name: str,
language: str,
prompts: str | Iterable[str],
score_map: dict[str, float],
*,
intermediate_tqdm_description: str | None = None,
score_tqdm_description: str | None = None,
) -> tuple[list[float | None], list[str | None]]:
"""This method does the sanity check for the text_model_name and then
calls the parent class's get_score method with the additional "model"
parameter. See the parent class for the detailed documentation.
"""
assert self._text_model_name is not None, (
"You need to specify the text_model_name to get the score for this "
"metric."
)
self._openai_args["model"] = self._text_model_name
return super().get_score(
metric_name,
language,
prompts,
score_map,
intermediate_tqdm_description=intermediate_tqdm_description,
score_tqdm_description=score_tqdm_description,
)
self._use_async = use_async

def similarity_scorer(self) -> OpenAISimilarityScorer:
"""This method does the sanity check for the embedding_model_name and
@@ -431,6 +414,7 @@ def _embed(self, inputs: list[str]) -> torch.Tensor:
# TODO: Fix that this async call could be much slower than the sync
# version. https://github.com/citadel-ai/langcheck/issues/160
if self._use_async:

async def _call_async_api() -> Any:
assert isinstance(self.openai_client, AsyncOpenAI)
if self.openai_args:
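
One behavioural note on the `_call_api` change near the top of this file: the fixed `seed: 123` entries are gone, and each prompt in a batch now gets its list index as its seed (a user-supplied `seed` in `openai_args` still wins, since `**config` is unpacked last). A small sketch with made-up prompts:

# Sketch only; prompts are made up, the dict construction mirrors _call_api.
prompts = ["rephrase A", "rephrase A", "rephrase B"]
config = {"model": "gpt-3.5-turbo"}  # plus any user-supplied openai_args

model_inputs = [
    {"messages": [{"role": "user", "content": prompt}], "seed": i, **config}
    for i, prompt in enumerate(prompts)
]

assert [m["seed"] for m in model_inputs] == [0, 1, 2]
# The duplicated prompt at indices 0 and 1 now carries different seeds, so
# repeated requests for the same input are no longer pinned to one output.
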
2 changes: 1 addition & 1 deletion src/langcheck/metrics/eval_clients/_prometheus.py
@@ -90,7 +90,7 @@ def load_prompt_template(
)

def get_text_responses(self, prompts: Iterable[str]) -> list[str | None]:
"""The function that generates resonses to the given prompt texts.
"""The function that generates responses to the given prompt texts.
Args:
prompts: The prompts you want to get the responses for.
7 changes: 7 additions & 0 deletions src/langcheck/metrics/prompts/en/metrics/rephrase.j2
@@ -0,0 +1,7 @@
Please rephrase the following prompt without altering its meaning, ensuring you adjust the word order appropriately.
Ensure that no more than five consecutive words are repeated and try to use similar words as substitutes where possible.
[BEGIN DATA]
************
[Prompt]: {{ instance }}
************
[END DATA]
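
The template takes a single `instance` variable. A quick way to sanity-check its output outside of an EvalClient is to render it directly with Jinja2; the path below assumes the working directory is the root of a local checkout of the repository:

from jinja2 import Environment, FileSystemLoader

# Path assumes a local checkout with the repository root as the working dir.
prompt_dir = "src/langcheck/metrics/prompts/en/metrics"
env = Environment(loader=FileSystemLoader(prompt_dir))
prompt = env.get_template("rephrase.j2").render(
    instance="Write a haiku about autumn."
)
print(prompt)
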