From 797192a46d4bbcf4eb82aa931f1edd04cf89112e Mon Sep 17 00:00:00 2001 From: awwaawwa <8493196+awwaawwa@users.noreply.github.com> Date: Tue, 24 Dec 2024 00:55:06 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20feat(translator):=20add=20initial?= =?UTF-8?q?=20asynchronous=20translation=20support=20and=20remove=20BaseTr?= =?UTF-8?q?anslator.translate?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pdf2zh/converter.py | 3 ++- pdf2zh/translator.py | 36 +++++---------------------------- test/test_translator.py | 44 ++++++++++++----------------------------- 3 files changed, 20 insertions(+), 63 deletions(-) diff --git a/pdf2zh/converter.py b/pdf2zh/converter.py index ea623536..49b2873c 100644 --- a/pdf2zh/converter.py +++ b/pdf2zh/converter.py @@ -16,6 +16,7 @@ import concurrent.futures import numpy as np import unicodedata +import asyncio from tenacity import retry, wait_fixed from pdf2zh.translator import ( AzureOpenAITranslator, @@ -333,7 +334,7 @@ def worker(s: str): # 多线程翻译 if not s.strip() or re.match(r"^\{v\d+\}$", s): # 空白和公式不翻译 return s try: - new = self.translator.translate(s) + new = asyncio.run(self.translator.translate_async(s)) return new except BaseException as e: if log.isEnabledFor(logging.DEBUG): diff --git a/pdf2zh/translator.py b/pdf2zh/translator.py index bdc32b00..2556a3a9 100644 --- a/pdf2zh/translator.py +++ b/pdf2zh/translator.py @@ -67,32 +67,6 @@ def add_cache_impact_parameters(self, k: str, v): """ self.cache.add_params(k, v) - def translate(self, text, ignore_cache=False): - """ - Translate the text, and the other part should call this method. - - Don't call this method in asyncio since we use asyncio.run() - asyncio.run() will raise RuntimeError if the event loop is not running - :param text: text to translate - :return: translated text - """ - if not (self.ignore_cache or ignore_cache): - cache = self.cache.get(text) - if cache is not None: - return cache - - try: - translation = self.do_translate(text) - except NotImplementedError: - # asyncio.run() will raise RuntimeError if the event loop is not running - if asyncio.get_running_loop() is not None: - raise NotImplementedError - - translation = asyncio.run(self.do_translate_async(text)) - if not (self.ignore_cache or ignore_cache): - self.cache.set(text, translation) - return translation - async def translate_async(self, text, ignore_cache=False): """ Translate the text, and the other part should call this method. @@ -339,14 +313,14 @@ def __init__( model = self.envs["OPENAI_MODEL"] super().__init__(lang_in, lang_out, model) self.options = {"temperature": 0} # 随机采样可能会打断公式标记 - self.client = openai.OpenAI(base_url=base_url, api_key=api_key) + self.client = openai.AsyncOpenAI(base_url=base_url, api_key=api_key) self.prompttext = prompt self.add_cache_impact_parameters("temperature", self.options["temperature"]) if prompt: self.add_cache_impact_parameters("prompt", prompt) - def do_translate(self, text) -> str: - response = self.client.chat.completions.create( + async def do_translate_async(self, text) -> str: + response = await self.client.chat.completions.create( model=self.model, **self.options, messages=self.prompt(text, self.prompttext), @@ -449,9 +423,9 @@ def __init__(self, lang_in, lang_out, model, envs=None, prompt=None): if prompt: self.add_cache_impact_parameters("prompt", prompt) - def do_translate(self, text) -> str: + async def do_translate_async(self, text) -> str: try: - response = self.client.chat.completions.create( + response = await self.client.chat.completions.create( model=self.model, **self.options, messages=self.prompt(text, self.prompttext), diff --git a/test/test_translator.py b/test/test_translator.py index 8c0aefbe..e0a865ec 100644 --- a/test/test_translator.py +++ b/test/test_translator.py @@ -28,52 +28,52 @@ def setUp(self): def tearDown(self): cache.clean_test_db(self.test_db) - def test_cache(self): + async def test_cache(self): translator = AutoIncreaseTranslator("en", "zh", "test") # First translation should be cached text = "Hello World" - first_result = translator.translate(text) + first_result = await translator.translate_async(text) # Second translation should return the same result from cache - second_result = translator.translate(text) + second_result = await translator.translate_async(text) self.assertEqual(first_result, second_result) # Different input should give different result different_text = "Different Text" - different_result = translator.translate(different_text) + different_result = await translator.translate_async(different_text) self.assertNotEqual(first_result, different_result) # Test cache with ignore_cache=True translator.ignore_cache = True - no_cache_result = translator.translate(text) + no_cache_result = await translator.translate_async(text) self.assertNotEqual(first_result, no_cache_result) - def test_add_cache_impact_parameters(self): + async def test_add_cache_impact_parameters(self): translator = AutoIncreaseTranslator("en", "zh", "test") # Test cache with added parameters text = "Hello World" - first_result = translator.translate(text) + first_result = await translator.translate_async(text) translator.add_cache_impact_parameters("test", "value") - second_result = translator.translate(text) + second_result = await translator.translate_async(text) self.assertNotEqual(first_result, second_result) # Test cache with ignore_cache=True - no_cache_result = translator.translate(text, ignore_cache=True) + no_cache_result = await translator.translate_async(text, ignore_cache=True) self.assertNotEqual(first_result, no_cache_result) translator.ignore_cache = True - no_cache_result = translator.translate(text) + no_cache_result = await translator.translate_async(text) self.assertNotEqual(first_result, no_cache_result) # Test cache with ignore_cache=False translator.ignore_cache = False - cache_result = translator.translate(text) + cache_result = await translator.translate_async(text) self.assertEqual(second_result, cache_result) # Test cache with another parameter translator.add_cache_impact_parameters("test2", "value2") - another_result = translator.translate(text) + another_result = await translator.translate_async(text) self.assertNotEqual(second_result, another_result) async def test_base_translator_throw(self): @@ -83,30 +83,12 @@ async def test_base_translator_throw(self): with self.assertRaises(NotImplementedError): await translator.do_translate_async("Hello World") - async def test_async_and_sync_translator(self): - async_translator = AutoIncreaseAsyncTranslator("en", "zh", "test") + async def test_call_sync_from_async(self): sync_translator = AutoIncreaseTranslator("en", "zh", "test") - # call async from async - self.assertEqual(await async_translator.translate_async("Hello World"), "1") - # call sync from async self.assertEqual(await sync_translator.translate_async("Hello World"), "1") - # call sync from sync - self.assertEqual(sync_translator.translate("Hello World"), "1") - - # call async from sync - with self.assertRaises(NotImplementedError): - self.assertEqual( - async_translator.translate("Hello World", ignore_cache=True), "1" - ) - - async def test_call_async_from_sync_inside_running_loop(self): - translator = AutoIncreaseAsyncTranslator("en", "zh", "test") - with self.assertRaises(NotImplementedError): - translator.translate("Hello World") - if __name__ == "__main__": unittest.main()