Skip to content

Commit

Permalink
Optimize translation failure logic
Browse files Browse the repository at this point in the history
  • Loading branch information
ionic-bond committed Dec 24, 2024
1 parent 0b82d1b commit 7d5b9c7
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 27 deletions.
1 change: 1 addition & 0 deletions stream_translator_gpt/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ def __init__(self, audio: np.array, time_range: tuple[float, float]):
self.translated_text = None
self.time_range = time_range
self.start_time = None
self.translation_failed = False


def _auto_args(func, kwargs):
Expand Down
73 changes: 46 additions & 27 deletions stream_translator_gpt/llm_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@


# The double quotes in the values of JSON have not been escaped, so manual escaping is necessary.
def escape_specific_quotes(input_string):
def _escape_specific_quotes(input_string):
quote_positions = [i for i, char in enumerate(input_string) if char == '"']

if len(quote_positions) <= 4:
Expand All @@ -37,7 +37,7 @@ def _parse_json_completion(completion):
return completion

json_str = json_match.group(0)
json_str = escape_specific_quotes(json_str)
json_str = _escape_specific_quotes(json_str)

try:
json_obj = json.loads(json_str)
Expand All @@ -49,6 +49,10 @@ def _parse_json_completion(completion):
return completion


def _is_task_timeout(task: TranslationTask, timeout: float) -> bool:
return datetime.utcnow() - task.start_time > timedelta(seconds=timeout)


class LLMClint():

class LLM_TYPE:
Expand Down Expand Up @@ -106,6 +110,7 @@ def _translate_by_gpt(self, translation_task: TranslationTask):
if self.use_json_result:
translation_task.translated_text = _parse_json_completion(translation_task.translated_text)
except (APITimeoutError, APIConnectionError) as e:
translation_task.translation_failed = True
print(e)
return
if self.history_size:
Expand Down Expand Up @@ -142,6 +147,7 @@ def _translate_by_gemini(self, translation_task: TranslationTask):
if self.use_json_result:
translation_task.translated_text = _parse_json_completion(translation_task.translated_text)
except (ValueError, InternalServerError, ResourceExhausted, TooManyRequests) as e:
translation_task.translation_failed = True
print(e)
return
if self.history_size:
Expand All @@ -165,36 +171,44 @@ def __init__(self, llm_client: LLMClint, timeout: int, retry_if_translation_fail
self.retry_if_translation_fails = retry_if_translation_fails
self.processing_queue = deque()

def trigger(self, translation_task: TranslationTask):
self.processing_queue.append(translation_task)
translation_task.start_time = datetime.utcnow()
def _trigger(self, translation_task: TranslationTask):
if not translation_task.start_time:
translation_task.start_time = datetime.utcnow()
translation_task.translation_failed = False
thread = threading.Thread(target=self.llm_client.translate, args=(translation_task,))
thread.daemon = True
thread.start()

def get_results(self):

def _retrigger_failed_tasks(self):
for task in self.processing_queue:
if task.translation_failed and not _is_task_timeout(task, self.timeout):
self._trigger(task)
print('Translation failed: {}'.format(task.transcribed_text))
time.sleep(1)

def _get_results(self):
results = []
while self.processing_queue and (self.processing_queue[0].translated_text or datetime.utcnow() -
self.processing_queue[0].start_time > timedelta(seconds=self.timeout)):
while self.processing_queue and (self.processing_queue[0].translated_text or _is_task_timeout(self.processing_queue[0], self.timeout) or (self.processing_queue[0].translation_failed and not self.retry_if_translation_fails)):
task = self.processing_queue.popleft()
if task.translated_text:
results.append(task)
else:
if self.retry_if_translation_fails:
self.trigger(task)
if not task.translated_text:
if _is_task_timeout(task, self.timeout):
print('Translation timeout: {}'.format(task.transcribed_text))
else:
results.append(task)
print('Translation timeout or failed: {}'.format(task.transcribed_text))
print('Translation failed: {}'.format(task.transcribed_text))
results.append(task)
return results

def loop(self, input_queue: queue.SimpleQueue[TranslationTask], output_queue: queue.SimpleQueue[TranslationTask]):
while True:
if not input_queue.empty() and len(self.processing_queue) < self.PARALLEL_MAX_NUMBER:
task = input_queue.get()
self.trigger(task)
finished_tasks = self.get_results()
self.processing_queue.append(task)
self._trigger(task)
finished_tasks = self._get_results()
for task in finished_tasks:
output_queue.put(task)
if self.retry_if_translation_fails:
self._retrigger_failed_tasks()
time.sleep(0.1)


Expand All @@ -205,8 +219,10 @@ def __init__(self, llm_client: LLMClint, timeout: int, retry_if_translation_fail
self.timeout = timeout
self.retry_if_translation_fails = retry_if_translation_fails

def trigger(self, translation_task: TranslationTask):
translation_task.start_time = datetime.utcnow()
def _trigger(self, translation_task: TranslationTask):
if not translation_task.start_time:
translation_task.start_time = datetime.utcnow()
translation_task.translation_failed = False
thread = threading.Thread(target=self.llm_client.translate, args=(translation_task,))
thread.daemon = True
thread.start()
Expand All @@ -215,17 +231,20 @@ def loop(self, input_queue: queue.SimpleQueue[TranslationTask], output_queue: qu
current_task = None
while True:
if current_task:
if (current_task.translated_text or
datetime.utcnow() - current_task.start_time > timedelta(seconds=self.timeout)):
if (current_task.translated_text or current_task.translation_failed or _is_task_timeout(current_task, self.timeout)):
if not current_task.translated_text:
if self.retry_if_translation_fails:
self.trigger(current_task)
continue
print('Translation timeout or failed: {}'.format(current_task.transcribed_text))
if _is_task_timeout(current_task, self.timeout):
print('Translation timeout: {}'.format(current_task.transcribed_text))
else:
print('Translation failed: {}'.format(current_task.transcribed_text))
if self.retry_if_translation_fails:
self._trigger(current_task)
time.sleep(1)
continue
output_queue.put(current_task)
current_task = None

if current_task is None and not input_queue.empty():
current_task = input_queue.get()
self.trigger(current_task)
self._trigger(current_task)
time.sleep(0.1)

0 comments on commit 7d5b9c7

Please sign in to comment.