Merge pull request #166 from Vela-zz/Vela-zz/issue160
Async call of OpenAISimilarityScorer is slower than sync version
liwii authored Nov 20, 2024
2 parents 81a72ef + 6168eab commit 7f7e9ca
Showing 1 changed file with 21 additions and 16 deletions.
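
The old `_embed` wrapped every embedding request in `asyncio.run()`, which creates a fresh event loop, runs the coroutine, and closes the loop again on each call. The diff below factors the API call out into an `_async_embed` coroutine and drives it through a single reused event loop via `run_until_complete`, which is the likely source of the speedup. A minimal sketch of the two patterns follows; `fake_api_call`, `embed_with_asyncio_run`, and `embed_with_reused_loop` are illustrative names, not part of langcheck:

```python
import asyncio
import time


async def fake_api_call() -> list[float]:
    """Stand-in for an async embeddings request (no real I/O)."""
    await asyncio.sleep(0)
    return [0.0, 0.1, 0.2]


def embed_with_asyncio_run(n: int) -> None:
    # Old pattern: asyncio.run() creates, runs, and closes a brand-new
    # event loop for every call, so the loop setup cost is paid n times.
    for _ in range(n):
        asyncio.run(fake_api_call())


def embed_with_reused_loop(n: int) -> None:
    # New pattern (mirrors the try/except in the diff): obtain or create
    # one event loop up front and reuse it for every call.
    # Note: newer Python versions may warn when get_event_loop() is called
    # with no running loop; the except branch covers the strict behavior.
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
    for _ in range(n):
        loop.run_until_complete(fake_api_call())


if __name__ == "__main__":
    for fn in (embed_with_asyncio_run, embed_with_reused_loop):
        start = time.perf_counter()
        fn(1_000)
        print(f"{fn.__name__}: {time.perf_counter() - start:.3f}s")
```

The reused-loop version avoids the per-call loop construction and teardown; exact timings will vary with Python version and hardware.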
37 changes: 21 additions & 16 deletions src/langcheck/metrics/eval_clients/_openai.py
@@ -8,6 +8,7 @@
 
 import torch
 from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI
+from openai.types.create_embedding_response import CreateEmbeddingResponse
 
 from langcheck.utils.progress_bar import tqdm_wrapper
 
@@ -409,28 +410,32 @@ def __init__(
         self.openai_args = openai_args
         self._use_async = use_async
 
+    async def _async_embed(self, inputs: list[str]) -> CreateEmbeddingResponse:
+        """Embed the inputs using the OpenAI API in async mode."""
+        assert isinstance(self.openai_client, AsyncOpenAI)
+        if self.openai_args:
+            responses = await self.openai_client.embeddings.create(
+                input=inputs, **self.openai_args
+            )
+        else:
+            responses = await self.openai_client.embeddings.create(
+                input=inputs, model="text-embedding-3-small"
+            )
+        return responses
+
     def _embed(self, inputs: list[str]) -> torch.Tensor:
         """Embed the inputs using the OpenAI API."""
 
-        # TODO: Fix that this async call could be much slower than the sync
-        # version. https://github.com/citadel-ai/langcheck/issues/160
         if self._use_async:
 
-            async def _call_async_api() -> Any:
-                assert isinstance(self.openai_client, AsyncOpenAI)
-                if self.openai_args:
-                    responses = await self.openai_client.embeddings.create(
-                        input=inputs, **self.openai_args
-                    )
-                else:
-                    responses = await self.openai_client.embeddings.create(
-                        input=inputs, model="text-embedding-3-small"
-                    )
-                return responses
-
-            embed_response = asyncio.run(_call_async_api())
-            embeddings = [item.embedding for item in embed_response.data]
+            try:
+                loop = asyncio.get_event_loop()
+            except RuntimeError:  # pragma: py-lt-310
+                loop = asyncio.new_event_loop()
+                asyncio.set_event_loop(loop)
+            embed_response = loop.run_until_complete(self._async_embed(inputs))
+            embeddings = [item.embedding for item in embed_response.data]
 
         else:
             assert isinstance(self.openai_client, OpenAI)
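For completeness, here is how the refactored structure can be exercised from both sync and async call sites. This is a self-contained toy that mirrors the shape of the diff; `FakeAsyncEmbeddings`, `ToyScorer`, and `toy-model` are illustrative stand-ins, not langcheck or OpenAI classes:

```python
import asyncio


class FakeAsyncEmbeddings:
    """Illustrative stand-in for an async embeddings resource."""

    async def create(self, input: list[str], model: str) -> list[list[float]]:
        # Pretend each text maps to a 3-dimensional embedding.
        return [[0.0, 0.1, 0.2] for _ in input]


class ToyScorer:
    """Toy mirror of the diff's structure, not the langcheck class."""

    def __init__(self) -> None:
        self.embeddings = FakeAsyncEmbeddings()

    async def _async_embed(self, inputs: list[str]) -> list[list[float]]:
        # The API call lives in its own coroutine, so async callers can
        # await it directly.
        return await self.embeddings.create(input=inputs, model="toy-model")

    def _embed(self, inputs: list[str]) -> list[list[float]]:
        # Sync callers drive the coroutine through one reused event loop
        # instead of calling asyncio.run() on every request.
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
        return loop.run_until_complete(self._async_embed(inputs))


scorer = ToyScorer()
print(scorer._embed(["hello", "world"]))            # sync call path
print(asyncio.run(scorer._async_embed(["hello"])))  # direct async path
```

Splitting the coroutine out of `_embed` keeps the sync wrapper thin and lets async code reuse the same request logic without an extra event loop.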
