diff --git a/src/langcheck/metrics/eval_clients/_openai.py b/src/langcheck/metrics/eval_clients/_openai.py
index 50c75c03..8b5b2210 100644
--- a/src/langcheck/metrics/eval_clients/_openai.py
+++ b/src/langcheck/metrics/eval_clients/_openai.py
@@ -409,26 +409,28 @@ def __init__(
         self.openai_args = openai_args
         self._use_async = use_async
 
+    async def _async_embed(self, inputs: list[str]) -> Any:
+        """Embed the inputs using the OpenAI API (async version)."""
+        if self.openai_args:
+            responses = await self.openai_client.embeddings.create(
+                input=inputs, **self.openai_args
+            )
+        else:
+            responses = await self.openai_client.embeddings.create(
+                input=inputs, model="text-embedding-3-small"
+            )
+        return responses
+
     def _embed(self, inputs: list[str]) -> torch.Tensor:
         """Embed the inputs using the OpenAI API."""
         # TODO: Fix that this async call could be much slower than the sync
         # version. https://github.com/citadel-ai/langcheck/issues/160
         if self._use_async:
-
-            async def _call_async_api() -> Any:
-                assert isinstance(self.openai_client, AsyncOpenAI)
-                if self.openai_args:
-                    responses = await self.openai_client.embeddings.create(
-                        input=inputs, **self.openai_args
-                    )
-                else:
-                    responses = await self.openai_client.embeddings.create(
-                        input=inputs, model="text-embedding-3-small"
-                    )
-                return responses
-
-            embed_response = asyncio.run(_call_async_api())
+            assert isinstance(self.openai_client, AsyncOpenAI)
+            embed_response = asyncio.get_event_loop().run_until_complete(
+                self._async_embed(inputs)
+            )
             embeddings = [item.embedding for item in embed_response.data]
         else: