Skip to content

Commit

Permalink
✨ feat(translator): add initial asynchronous translation support and …
Browse files Browse the repository at this point in the history
…remove BaseTranslator.translate
  • Loading branch information
awwaawwa committed Dec 23, 2024
1 parent d699e38 commit 797192a
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 63 deletions.
3 changes: 2 additions & 1 deletion pdf2zh/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import concurrent.futures
import numpy as np
import unicodedata
import asyncio
from tenacity import retry, wait_fixed
from pdf2zh.translator import (
AzureOpenAITranslator,
Expand Down Expand Up @@ -333,7 +334,7 @@ def worker(s: str): # 多线程翻译
if not s.strip() or re.match(r"^\{v\d+\}$", s): # 空白和公式不翻译
return s
try:
new = self.translator.translate(s)
new = asyncio.run(self.translator.translate_async(s))
return new
except BaseException as e:
if log.isEnabledFor(logging.DEBUG):
Expand Down
36 changes: 5 additions & 31 deletions pdf2zh/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,32 +67,6 @@ def add_cache_impact_parameters(self, k: str, v):
"""
self.cache.add_params(k, v)

def translate(self, text, ignore_cache=False):
"""
Translate the text, and the other part should call this method.
Don't call this method from inside a running event loop, since it uses asyncio.run().
asyncio.run() raises RuntimeError if an event loop is already running in the current thread.
:param text: text to translate
:return: translated text
"""
if not (self.ignore_cache or ignore_cache):
cache = self.cache.get(text)
if cache is not None:
return cache

try:
translation = self.do_translate(text)
except NotImplementedError:
# asyncio.run() raises RuntimeError if an event loop is already running in this thread
if asyncio.get_running_loop() is not None:
raise NotImplementedError

translation = asyncio.run(self.do_translate_async(text))
if not (self.ignore_cache or ignore_cache):
self.cache.set(text, translation)
return translation

async def translate_async(self, text, ignore_cache=False):
"""
Translate the text, and the other part should call this method.
Expand Down Expand Up @@ -339,14 +313,14 @@ def __init__(
model = self.envs["OPENAI_MODEL"]
super().__init__(lang_in, lang_out, model)
self.options = {"temperature": 0} # 随机采样可能会打断公式标记
self.client = openai.OpenAI(base_url=base_url, api_key=api_key)
self.client = openai.AsyncOpenAI(base_url=base_url, api_key=api_key)
self.prompttext = prompt
self.add_cache_impact_parameters("temperature", self.options["temperature"])
if prompt:
self.add_cache_impact_parameters("prompt", prompt)

def do_translate(self, text) -> str:
response = self.client.chat.completions.create(
async def do_translate_async(self, text) -> str:
response = await self.client.chat.completions.create(
model=self.model,
**self.options,
messages=self.prompt(text, self.prompttext),
Expand Down Expand Up @@ -449,9 +423,9 @@ def __init__(self, lang_in, lang_out, model, envs=None, prompt=None):
if prompt:
self.add_cache_impact_parameters("prompt", prompt)

def do_translate(self, text) -> str:
async def do_translate_async(self, text) -> str:
try:
response = self.client.chat.completions.create(
response = await self.client.chat.completions.create(
model=self.model,
**self.options,
messages=self.prompt(text, self.prompttext),
Expand Down
44 changes: 13 additions & 31 deletions test/test_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,52 +28,52 @@ def setUp(self):
def tearDown(self):
cache.clean_test_db(self.test_db)

def test_cache(self):
async def test_cache(self):
translator = AutoIncreaseTranslator("en", "zh", "test")
# First translation should be cached
text = "Hello World"
first_result = translator.translate(text)
first_result = await translator.translate_async(text)

# Second translation should return the same result from cache
second_result = translator.translate(text)
second_result = await translator.translate_async(text)
self.assertEqual(first_result, second_result)

# Different input should give different result
different_text = "Different Text"
different_result = translator.translate(different_text)
different_result = await translator.translate_async(different_text)
self.assertNotEqual(first_result, different_result)

# Test cache with ignore_cache=True
translator.ignore_cache = True
no_cache_result = translator.translate(text)
no_cache_result = await translator.translate_async(text)
self.assertNotEqual(first_result, no_cache_result)

def test_add_cache_impact_parameters(self):
async def test_add_cache_impact_parameters(self):
translator = AutoIncreaseTranslator("en", "zh", "test")

# Test cache with added parameters
text = "Hello World"
first_result = translator.translate(text)
first_result = await translator.translate_async(text)
translator.add_cache_impact_parameters("test", "value")
second_result = translator.translate(text)
second_result = await translator.translate_async(text)
self.assertNotEqual(first_result, second_result)

# Test cache with ignore_cache=True
no_cache_result = translator.translate(text, ignore_cache=True)
no_cache_result = await translator.translate_async(text, ignore_cache=True)
self.assertNotEqual(first_result, no_cache_result)

translator.ignore_cache = True
no_cache_result = translator.translate(text)
no_cache_result = await translator.translate_async(text)
self.assertNotEqual(first_result, no_cache_result)

# Test cache with ignore_cache=False
translator.ignore_cache = False
cache_result = translator.translate(text)
cache_result = await translator.translate_async(text)
self.assertEqual(second_result, cache_result)

# Test cache with another parameter
translator.add_cache_impact_parameters("test2", "value2")
another_result = translator.translate(text)
another_result = await translator.translate_async(text)
self.assertNotEqual(second_result, another_result)

async def test_base_translator_throw(self):
Expand All @@ -83,30 +83,12 @@ async def test_base_translator_throw(self):
with self.assertRaises(NotImplementedError):
await translator.do_translate_async("Hello World")

async def test_async_and_sync_translator(self):
async_translator = AutoIncreaseAsyncTranslator("en", "zh", "test")
async def test_call_sync_from_async(self):
sync_translator = AutoIncreaseTranslator("en", "zh", "test")

# call async from async
self.assertEqual(await async_translator.translate_async("Hello World"), "1")

# call sync from async
self.assertEqual(await sync_translator.translate_async("Hello World"), "1")

# call sync from sync
self.assertEqual(sync_translator.translate("Hello World"), "1")

# call async from sync
with self.assertRaises(NotImplementedError):
self.assertEqual(
async_translator.translate("Hello World", ignore_cache=True), "1"
)

async def test_call_async_from_sync_inside_running_loop(self):
translator = AutoIncreaseAsyncTranslator("en", "zh", "test")
with self.assertRaises(NotImplementedError):
translator.translate("Hello World")


if __name__ == "__main__":
unittest.main()

0 comments on commit 797192a

Please sign in to comment.