diff --git a/src/modules/csm/checkpoint.py b/src/modules/csm/checkpoint.py index b16f35f0f..480518aac 100644 --- a/src/modules/csm/checkpoint.py +++ b/src/modules/csm/checkpoint.py @@ -1,14 +1,15 @@ import logging import time -from threading import Thread +from threading import Thread, Lock from typing import cast -from src.modules.csm.typings import FramePerformance, AttestationsAggregate +from src.modules.csm.typings import FramePerformance from src.providers.consensus.client import ConsensusClient -from src.typings import EpochNumber, BlockRoot, SlotNumber, BlockStamp, ValidatorIndex +from src.typings import EpochNumber, BlockRoot, SlotNumber, BlockStamp from src.utils.web3converter import Web3Converter logger = logging.getLogger(__name__) +lock = Lock() class CheckpointsFactory: @@ -115,7 +116,7 @@ def _select_roots_to_check( roots_to_check = [] slots = range( self.converter.get_epoch_first_slot(duty_epoch), - self.converter.get_epoch_last_slot(EpochNumber(duty_epoch + 1)) + self.converter.get_epoch_last_slot(EpochNumber(duty_epoch + 1)) + 1 ) for slot in slots: # TODO: get the magic number from the CL spec @@ -145,22 +146,19 @@ def _process_epoch( continue slot_data = self.cc.get_block_details_raw(BlockRoot(root)) self._process_attestations(slot_data, committees) - - self.frame_performance.processed.add(EpochNumber(duty_epoch)) - self.frame_performance.dump() + with lock: + self.frame_performance.dump(duty_epoch, committees) logger.info({"msg": f"Epoch {duty_epoch} processed in {time.time() - start:.2f} seconds"}) def _prepare_committees(self, last_finalized_blockstamp: BlockStamp, epoch: int) -> dict: start = time.time() committees = {} for committee in self.cc.get_attestation_committees(last_finalized_blockstamp, EpochNumber(epoch)): - committees[f"{committee.slot}{committee.index}"] = committee.validators + validators = [] for validator in committee.validators: - val = self.frame_performance.aggr_per_val.get( - ValidatorIndex(int(validator)), AttestationsAggregate(0, 0) - ) - val.assigned += 1 - self.frame_performance.aggr_per_val[ValidatorIndex(int(validator))] = val + data = {"index": validator, "included": False} + validators.append(data) + committees[f"{committee.slot}{committee.index}"] = validators logger.info({"msg": f"Committees for epoch {epoch} processed in {time.time() - start:.2f} seconds"}) return committees @@ -178,12 +176,11 @@ def to_bits(aggregation_bits: str): att_bits = to_bits(attestation['aggregation_bits']) if not committee: continue - for index, validator in enumerate(committee): - if validator is None: + for index_in_committee, validator in enumerate(committee): + if validator['included']: # validator has already fulfilled its duties continue - attested = att_bits[index] + attested = att_bits[index_in_committee] if attested: - self.frame_performance.aggr_per_val[ValidatorIndex(int(validator))].included += 1 - # duty is fulfilled, so we can remove validator from the committee - committees[committee_id][index] = None + validator['included'] = True + committees[committee_id][index_in_committee] = validator diff --git a/src/modules/csm/typings.py b/src/modules/csm/typings.py index 608ae775e..9f363ab93 100644 --- a/src/modules/csm/typings.py +++ b/src/modules/csm/typings.py @@ -51,14 +51,21 @@ class FramePerformance: to_distribute: int = 0 last_cid: str | None = None - @property def avg_perf(self) -> float: """Returns average performance of all validators in the cache.""" return mean((a.perf for a in self.aggr_per_val.values())) - def dump(self) -> None: + def dump(self, epoch: EpochNumber, committees: dict) -> None: """Used to persist the current state of the structure.""" + # TODO: persist the data. no sense to keep it in memory (except of `processed` ?) + self.processed.add(epoch) + for committee in committees.values(): + for validator in committee: + perf_data = self.aggr_per_val.get(validator['index'], AttestationsAggregate(0, 0)) + perf_data.assigned += 1 + perf_data.included += validator['included'] + self.aggr_per_val[validator['index']] = perf_data @classmethod def try_read(cls, ref_slot: SlotNumber) -> Self | None: