feat(translate): Initial implementation of two-phase translation
- First phase generates translation cache
- Second phase reformats using cache
awwaawwa committed Dec 23, 2024
1 parent 11bbd55 commit 261cac6
Showing 3 changed files with 69 additions and 33 deletions.
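
At a glance, the commit splits translation into a cache-generation pass and a layout pass. A minimal sketch of the flow, assuming translate_stream accepts the new generate_cache_executor keyword as in the diff below (the driver function name is illustrative, not part of the commit):

import concurrent.futures

def two_phase_translate(stream, thread, **kwargs):
    # Phase 1: run the pipeline once only to queue translations on a shared
    # executor; results land in the translator's cache, nothing is typeset.
    with concurrent.futures.ThreadPoolExecutor(max_workers=thread) as executor:
        translate_stream(stream, generate_cache_executor=executor, **kwargs)
        # leaving the `with` block waits for all queued translations to finish
    # Phase 2: run again without an executor; every lookup now hits the cache,
    # so the typesetting pass does not block on the translation service.
    return translate_stream(stream, generate_cache_executor=None, **kwargs)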
13 changes: 9 additions & 4 deletions pdf2zh/converter.py
@@ -136,6 +136,7 @@ def __init__(
noto: Font = None,
envs: Dict = None,
prompt: List = None,
generate_cache_executor=None,
) -> None:
super().__init__(rsrcmgr)
self.vfont = vfont
@@ -154,6 +155,7 @@ def __init__(
self.translator = translator(lang_in, lang_out, service_model, envs=envs, prompt=prompt)
if not self.translator:
raise ValueError("Unsupported translation service")
self.generate_cache_executor = generate_cache_executor

def receive_layout(self, ltpage: LTPage):
# 段落 (paragraphs)
@@ -341,10 +343,13 @@ def worker(s: str):  # 多线程翻译 (multi-threaded translation)
else:
log.exception(e, exc_info=False)
raise e
with concurrent.futures.ThreadPoolExecutor(
max_workers=self.thread
) as executor:
news = list(executor.map(worker, sstk))

if self.generate_cache_executor is not None:
self.generate_cache_executor.map(worker, sstk)
return
else:
with concurrent.futures.ThreadPoolExecutor(max_workers=self.thread) as executor:
news = list(executor.map(worker, sstk))

############################################################
# C. 新文档排版 (typeset the new document)
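Reading the converter change above: when a shared executor is supplied, receive_layout only schedules the translation work and returns before any typesetting happens. A sketch of the same branch with its behavior spelled out (the comments are my reading of the code, not part of the commit):

if self.generate_cache_executor is not None:
    # Phase 1: Executor.map submits every string immediately, so the work is
    # queued even though the returned iterator is never consumed; any worker
    # exception is therefore not re-raised on this pass.
    self.generate_cache_executor.map(worker, sstk)
    return  # skip typesetting; the caller's executor finishes the queued jobs
else:
    # Phase 2 (or single-pass use): translate locally and keep the results
    # for the typesetting step that follows.
    with concurrent.futures.ThreadPoolExecutor(max_workers=self.thread) as executor:
        news = list(executor.map(worker, sstk))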
19 changes: 19 additions & 0 deletions pdf2zh/high_level.py
@@ -9,6 +9,7 @@
from asyncio import CancelledError
from pathlib import Path
from typing import Any, BinaryIO, List, Optional
import concurrent.futures

import numpy as np
import requests
@@ -88,6 +89,7 @@ def translate_patch(
noto: Font = None,
callback: object = None,
cancellation_event: asyncio.Event = None,
generate_cache_executor=None,
**kwarg: Any,
) -> None:
rsrcmgr = PDFResourceManager()
@@ -105,6 +107,7 @@
noto,
kwarg.get("envs", {}),
kwarg.get("prompt", []),
generate_cache_executor
)

assert device is not None
@@ -179,6 +182,7 @@ def translate_stream(
vchar: str = "",
callback: object = None,
cancellation_event: asyncio.Event = None,
generate_cache_executor=None,
**kwarg: Any,
):
font_list = [("tiro", None)]
@@ -235,6 +239,9 @@
fp = io.BytesIO()
doc_zh.save(fp)
obj_patch: dict = translate_patch(fp, prompt=kwarg["prompt"], **locals())
if generate_cache_executor:
return
print()

for obj_id, ops_new in obj_patch.items():
# ops_old=doc_en.xref_stream(obj_id)
@@ -365,10 +372,22 @@ def translate(
if file.startswith(tempfile.gettempdir()):
os.unlink(file)

with concurrent.futures.ThreadPoolExecutor(
max_workers=thread
) as executor:
translate_stream(
s_raw,
envs=kwarg.get("envs", {}),
prompt=kwarg.get("prompt", []),
generate_cache_executor=executor,
**locals(),
)
print('Translating... Please wait...')
s_mono, s_dual = translate_stream(
s_raw,
envs=kwarg.get("envs", {}),
prompt=kwarg.get("prompt", []),
generate_cache_executor=None,
**locals(),
)
file_mono = Path(output) / f"{filename}-mono.pdf"
70 changes: 41 additions & 29 deletions pdf2zh/translator.py
@@ -88,6 +88,13 @@ def __init__(self, lang_in, lang_out, model):
},
)

self.translate_call_count = 0
self.translate_cache_call_count = 0

def __del__(self):
print(f"{self.name} translate call count: {self.translate_call_count}")
print(f"{self.name} translate cache call count: {self.translate_cache_call_count}")

def set_envs(self, envs):
# Detach from self.__class__.envs
# Cannot use self.envs = copy(self.__class__.envs)
@@ -114,9 +121,11 @@ def translate(self, text, ignore_cache=False):
:param text: text to translate
:return: translated text
"""
self.translate_call_count += 1
if not (self.ignore_cache or ignore_cache):
cache = self.cache.get(text)
if cache is not None:
self.translate_cache_call_count += 1
return cache
_translate_rate_limiter.wait()
translation = self.do_translate(text)
@@ -148,7 +157,8 @@ def prompt(self, text, prompt):
},
{
"role": "user",
"content": f"Translate the following markdown source text to {self.lang_out}. Keep the formula notation {{v*}} unchanged. Output translation directly without any additional text.\nSource Text: {text}\nTranslated Text:", # noqa: E501
"content": f"Translate the following markdown source text to {self.lang_out}. Keep the formula notation {{v*}} unchanged. Output translation directly without any additional text.\nSource Text: {text}\nTranslated Text:",
# noqa: E501
},
]

@@ -165,7 +175,8 @@ def __init__(self, lang_in, lang_out, model, **kwargs):
self.session = requests.Session()
self.endpoint = "http://translate.google.com/m"
self.headers = {
"User-Agent": "Mozilla/4.0 (compatible;MSIE 6.0;Windows NT 5.1;SV1;.NET CLR 1.1.4322;.NET CLR 2.0.50727;.NET CLR 3.0.04506.30)" # noqa: E501
"User-Agent": "Mozilla/4.0 (compatible;MSIE 6.0;Windows NT 5.1;SV1;.NET CLR 1.1.4322;.NET CLR 2.0.50727;.NET CLR 3.0.04506.30)"
# noqa: E501
}

def do_translate(self, text):
@@ -196,7 +207,8 @@ def __init__(self, lang_in, lang_out, model, **kwargs):
self.session = requests.Session()
self.endpoint = "https://www.bing.com/translator"
self.headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36 Edg/131.0.0.0", # noqa: E501
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36 Edg/131.0.0.0",
# noqa: E501
}

def find_sid(self):
@@ -330,14 +342,14 @@ class OpenAITranslator(BaseTranslator):
CustomPrompt = True

def __init__(
self,
lang_in,
lang_out,
model,
base_url=None,
api_key=None,
envs=None,
prompt=None,
self,
lang_in,
lang_out,
model,
base_url=None,
api_key=None,
envs=None,
prompt=None,
):
self.set_envs(envs)
if not model:
@@ -369,14 +381,14 @@ class AzureOpenAITranslator(BaseTranslator):
CustomPrompt = True

def __init__(
self,
lang_in,
lang_out,
model,
base_url=None,
api_key=None,
envs=None,
prompt=None,
self,
lang_in,
lang_out,
model,
base_url=None,
api_key=None,
envs=None,
prompt=None,
):
self.set_envs(envs)
base_url = self.envs["AZURE_OPENAI_BASE_URL"]
@@ -414,14 +426,14 @@ class ModelScopeTranslator(OpenAITranslator):
CustomPrompt = True

def __init__(
self,
lang_in,
lang_out,
model,
base_url=None,
api_key=None,
envs=None,
prompt=None,
self,
lang_in,
lang_out,
model,
base_url=None,
api_key=None,
envs=None,
prompt=None,
):
self.set_envs(envs)
base_url = "https://api-inference.modelscope.cn/v1"
@@ -463,8 +475,8 @@ def do_translate(self, text) -> str:
)
except openai.BadRequestError as e:
if (
json.loads(response.choices[0].message.content.strip())["error"]["code"]
== "1301"
json.loads(response.choices[0].message.content.strip())["error"]["code"]
== "1301"
):
return "IRREPARABLE TRANSLATION ERROR"
raise e
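Putting the translator.py pieces together: the two counters distinguish total translate() calls from cache hits, which is how the two phases can be checked — the first pass should report mostly misses, the second mostly hits. A reconstruction of the instrumented method from the hunks above; the cache-write line is an assumption about the code the diff truncates:

def translate(self, text, ignore_cache=False):
    self.translate_call_count += 1
    if not (self.ignore_cache or ignore_cache):
        cache = self.cache.get(text)
        if cache is not None:
            self.translate_cache_call_count += 1  # phase-2 lookups land here
            return cache
    _translate_rate_limiter.wait()
    translation = self.do_translate(text)
    # assumed: the result is written back to the cache so the second pass can reuse it
    self.cache.set(text, translation)
    return translation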
