Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
awwaawwa committed Dec 23, 2024
1 parent b4a1084 commit 0bf5002
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 38 deletions.
12 changes: 4 additions & 8 deletions pdf2zh/high_level.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def translate_patch(
noto,
kwarg.get("envs", {}),
kwarg.get("prompt", []),
generate_cache_executor
generate_cache_executor,
)

assert device is not None
Expand Down Expand Up @@ -373,20 +373,16 @@ def translate(
if file.startswith(tempfile.gettempdir()):
os.unlink(file)
generate_cache_start = time.time()
with concurrent.futures.ThreadPoolExecutor(
max_workers=thread
) as executor:
with concurrent.futures.ThreadPoolExecutor(max_workers=thread) as executor:
translate_stream(
s_raw,
envs=kwarg.get("envs", {}),
prompt=kwarg.get("prompt", []),
generate_cache_executor=executor,
**locals(),
)
print('Translating... Please wait...')
print(
f"Generate cache time: {time.time() - generate_cache_start:.2f} seconds"
)
print("Translating... Please wait...")
print(f"Generate cache time: {time.time() - generate_cache_start:.2f} seconds")
s_mono, s_dual = translate_stream(
s_raw,
envs=kwarg.get("envs", {}),
Expand Down
56 changes: 29 additions & 27 deletions pdf2zh/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,9 @@ def __init__(self, lang_in, lang_out, model):

def __del__(self):
print(f"{self.name} translate call count: {self.translate_call_count}")
print(f"{self.name} translate cache call count: {self.translate_cache_call_count}")
print(
f"{self.name} translate cache call count: {self.translate_cache_call_count}"
)

def set_envs(self, envs):
# Detach from self.__class__.envs
Expand Down Expand Up @@ -342,14 +344,14 @@ class OpenAITranslator(BaseTranslator):
CustomPrompt = True

def __init__(
self,
lang_in,
lang_out,
model,
base_url=None,
api_key=None,
envs=None,
prompt=None,
self,
lang_in,
lang_out,
model,
base_url=None,
api_key=None,
envs=None,
prompt=None,
):
self.set_envs(envs)
if not model:
Expand Down Expand Up @@ -381,14 +383,14 @@ class AzureOpenAITranslator(BaseTranslator):
CustomPrompt = True

def __init__(
self,
lang_in,
lang_out,
model,
base_url=None,
api_key=None,
envs=None,
prompt=None,
self,
lang_in,
lang_out,
model,
base_url=None,
api_key=None,
envs=None,
prompt=None,
):
self.set_envs(envs)
base_url = self.envs["AZURE_OPENAI_BASE_URL"]
Expand Down Expand Up @@ -426,14 +428,14 @@ class ModelScopeTranslator(OpenAITranslator):
CustomPrompt = True

def __init__(
self,
lang_in,
lang_out,
model,
base_url=None,
api_key=None,
envs=None,
prompt=None,
self,
lang_in,
lang_out,
model,
base_url=None,
api_key=None,
envs=None,
prompt=None,
):
self.set_envs(envs)
base_url = "https://api-inference.modelscope.cn/v1"
Expand Down Expand Up @@ -475,8 +477,8 @@ def do_translate(self, text) -> str:
)
except openai.BadRequestError as e:
if (
json.loads(response.choices[0].message.content.strip())["error"]["code"]
== "1301"
json.loads(response.choices[0].message.content.strip())["error"]["code"]
== "1301"
):
return "IRREPARABLE TRANSLATION ERROR"
raise e
Expand Down
12 changes: 9 additions & 3 deletions test/test_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,12 +99,18 @@ def task(i):

# Verify timing
total_time = timestamps[-1] - start_time
self.assertGreaterEqual(total_time, 1.0) # Should take at least 1s for 20 requests at 10 QPS
self.assertGreaterEqual(
total_time, 1.0
) # Should take at least 1s for 20 requests at 10 QPS

# Check even distribution
intervals = [timestamps[i + 1] - timestamps[i] for i in range(len(timestamps) - 1)]
intervals = [
timestamps[i + 1] - timestamps[i] for i in range(len(timestamps) - 1)
]
avg_interval = sum(intervals) / len(intervals)
self.assertAlmostEqual(avg_interval, 0.1, delta=0.05) # Should be close to 0.1s (1/10 QPS)
self.assertAlmostEqual(
avg_interval, 0.1, delta=0.05
) # Should be close to 0.1s (1/10 QPS)

def test_burst_handling(self):
limiter = RateLimiter(10) # 10 QPS
Expand Down

0 comments on commit 0bf5002

Please sign in to comment.