From 0bf50024e661eaee32b7c61ad2deded0bf32c12d Mon Sep 17 00:00:00 2001
From: awwaawwa <8493196+awwaawwa@users.noreply.github.com>
Date: Tue, 24 Dec 2024 02:44:56 +0800
Subject: [PATCH] format

---
 pdf2zh/high_level.py    | 12 +++------
 pdf2zh/translator.py    | 56 +++++++++++++++++++++--------------------
 test/test_translator.py | 12 ++++++---
 3 files changed, 42 insertions(+), 38 deletions(-)

diff --git a/pdf2zh/high_level.py b/pdf2zh/high_level.py
index 544063fd..8135caae 100644
--- a/pdf2zh/high_level.py
+++ b/pdf2zh/high_level.py
@@ -108,7 +108,7 @@ def translate_patch(
         noto,
         kwarg.get("envs", {}),
         kwarg.get("prompt", []),
-        generate_cache_executor
+        generate_cache_executor,
     )
 
     assert device is not None
@@ -373,9 +373,7 @@ def translate(
         if file.startswith(tempfile.gettempdir()):
             os.unlink(file)
         generate_cache_start = time.time()
-        with concurrent.futures.ThreadPoolExecutor(
-            max_workers=thread
-        ) as executor:
+        with concurrent.futures.ThreadPoolExecutor(max_workers=thread) as executor:
             translate_stream(
                 s_raw,
                 envs=kwarg.get("envs", {}),
@@ -383,10 +381,8 @@ def translate(
                 generate_cache_executor=executor,
                 **locals(),
             )
-        print('Translating... Please wait...')
-        print(
-            f"Generate cache time: {time.time() - generate_cache_start:.2f} seconds"
-        )
+        print("Translating... Please wait...")
+        print(f"Generate cache time: {time.time() - generate_cache_start:.2f} seconds")
         s_mono, s_dual = translate_stream(
             s_raw,
             envs=kwarg.get("envs", {}),
diff --git a/pdf2zh/translator.py b/pdf2zh/translator.py
index 1b59b15a..59619ac2 100644
--- a/pdf2zh/translator.py
+++ b/pdf2zh/translator.py
@@ -93,7 +93,9 @@ def __init__(self, lang_in, lang_out, model):
 
     def __del__(self):
         print(f"{self.name} translate call count: {self.translate_call_count}")
-        print(f"{self.name} translate cache call count: {self.translate_cache_call_count}")
+        print(
+            f"{self.name} translate cache call count: {self.translate_cache_call_count}"
+        )
 
     def set_envs(self, envs):
         # Detach from self.__class__.envs
@@ -342,14 +344,14 @@ class OpenAITranslator(BaseTranslator):
     CustomPrompt = True
 
     def __init__(
-            self,
-            lang_in,
-            lang_out,
-            model,
-            base_url=None,
-            api_key=None,
-            envs=None,
-            prompt=None,
+        self,
+        lang_in,
+        lang_out,
+        model,
+        base_url=None,
+        api_key=None,
+        envs=None,
+        prompt=None,
     ):
         self.set_envs(envs)
         if not model:
@@ -381,14 +383,14 @@ class AzureOpenAITranslator(BaseTranslator):
     CustomPrompt = True
 
     def __init__(
-            self,
-            lang_in,
-            lang_out,
-            model,
-            base_url=None,
-            api_key=None,
-            envs=None,
-            prompt=None,
+        self,
+        lang_in,
+        lang_out,
+        model,
+        base_url=None,
+        api_key=None,
+        envs=None,
+        prompt=None,
     ):
         self.set_envs(envs)
         base_url = self.envs["AZURE_OPENAI_BASE_URL"]
@@ -426,14 +428,14 @@ class ModelScopeTranslator(OpenAITranslator):
     CustomPrompt = True
 
     def __init__(
-            self,
-            lang_in,
-            lang_out,
-            model,
-            base_url=None,
-            api_key=None,
-            envs=None,
-            prompt=None,
+        self,
+        lang_in,
+        lang_out,
+        model,
+        base_url=None,
+        api_key=None,
+        envs=None,
+        prompt=None,
     ):
         self.set_envs(envs)
         base_url = "https://api-inference.modelscope.cn/v1"
@@ -475,8 +477,8 @@ def do_translate(self, text) -> str:
             )
         except openai.BadRequestError as e:
             if (
-                    json.loads(response.choices[0].message.content.strip())["error"]["code"]
-                    == "1301"
+                json.loads(response.choices[0].message.content.strip())["error"]["code"]
+                == "1301"
             ):
                 return "IRREPARABLE TRANSLATION ERROR"
             raise e
diff --git a/test/test_translator.py b/test/test_translator.py
index ed05b7b0..3deddbc8 100644
--- a/test/test_translator.py
+++ b/test/test_translator.py
@@ -99,12 +99,18 @@ def task(i):
 
         # Verify timing
         total_time = timestamps[-1] - start_time
-        self.assertGreaterEqual(total_time, 1.0)  # Should take at least 1s for 20 requests at 10 QPS
+        self.assertGreaterEqual(
+            total_time, 1.0
+        )  # Should take at least 1s for 20 requests at 10 QPS
 
         # Check even distribution
-        intervals = [timestamps[i + 1] - timestamps[i] for i in range(len(timestamps) - 1)]
+        intervals = [
+            timestamps[i + 1] - timestamps[i] for i in range(len(timestamps) - 1)
+        ]
         avg_interval = sum(intervals) / len(intervals)
-        self.assertAlmostEqual(avg_interval, 0.1, delta=0.05)  # Should be close to 0.1s (1/10 QPS)
+        self.assertAlmostEqual(
+            avg_interval, 0.1, delta=0.05
+        )  # Should be close to 0.1s (1/10 QPS)
 
     def test_burst_handling(self):
         limiter = RateLimiter(10)  # 10 QPS