From cabde40fb8edaf7ea4d39044dc1d09b19dc667a5 Mon Sep 17 00:00:00 2001 From: vgorkavenko Date: Mon, 6 May 2024 12:09:52 +0400 Subject: [PATCH] feat: handle thread exception in processing --- src/modules/csm/checkpoint.py | 49 ++++++++++++++++++++--------------- 1 file changed, 28 insertions(+), 21 deletions(-) diff --git a/src/modules/csm/checkpoint.py b/src/modules/csm/checkpoint.py index 84a0d27a5..46ab2b9fd 100644 --- a/src/modules/csm/checkpoint.py +++ b/src/modules/csm/checkpoint.py @@ -1,6 +1,8 @@ import logging import time + from threading import Thread, Lock +from concurrent.futures import ThreadPoolExecutor, as_completed from typing import Any, Iterable, cast from src.modules.csm.state import State @@ -91,6 +93,7 @@ class Checkpoint: converter: Web3Converter threads: list[Thread] + thread_exception: Exception state: State slot: SlotNumber # last slot of the epoch @@ -118,30 +121,35 @@ def free_threads(self): return self.MAX_THREADS - len(self.threads) def process(self, last_finalized_blockstamp: BlockStamp): - for duty_epoch in self.duty_epochs: - if duty_epoch in self.state.processed_epochs: - continue - if not self.block_roots: - self._get_block_roots() - roots_to_check = self._select_roots_to_check(duty_epoch) - if not self.free_threads: - self._await_oldest_thread() - # TODO: handle error in the thread. wait all, then raise - thread = Thread( - target=self._process_epoch, args=(last_finalized_blockstamp, duty_epoch, roots_to_check) - ) - thread.start() - self.threads.append(thread) - self._await_all_threads() + def _unprocessed(): + for _epoch in self.duty_epochs: + if _epoch in self.state.processed_epochs: + continue + if not self.block_roots: + self._get_block_roots() + yield _epoch + + with ThreadPoolExecutor() as ext: + futures = { + ext.submit(self._process_epoch, last_finalized_blockstamp, duty_epoch): duty_epoch + for duty_epoch in _unprocessed() + } + for future in as_completed(futures): + duty_epoch = futures[future] + try: + future.result() + except Exception as e: + logger.error({"msg": f"Error processing epoch {duty_epoch} in thread", "error": str(e)}) + raise e + + def _await_all_threads(self): + while self.threads: + self._await_oldest_thread() def _await_oldest_thread(self): old = self.threads.pop(0) old.join() - def _await_all_threads(self): - for thread in self.threads: - thread.join() - def _select_roots_to_check( self, duty_epoch: EpochNumber ) -> list[BlockRoot | None]: @@ -171,12 +179,11 @@ def _process_epoch( self, last_finalized_blockstamp: BlockStamp, duty_epoch: EpochNumber, - roots_to_check: list[BlockRoot] ): logger.info({"msg": f"Process epoch {duty_epoch}"}) start = time.time() committees = self._prepare_committees(last_finalized_blockstamp, EpochNumber(duty_epoch)) - for root in roots_to_check: + for root in self._select_roots_to_check(duty_epoch): if root is None: continue slot_data = self.cc.get_block_details_raw(BlockRoot(root))