feat: collect data for report building

vgorkavenko committed Apr 22, 2024
1 parent 28f056d commit 3a73f87
Showing 11 changed files with 406 additions and 97 deletions.
37 changes: 7 additions & 30 deletions poetry.lock

Some generated files are not rendered by default.

189 changes: 189 additions & 0 deletions src/modules/csm/checkpoint.py
@@ -0,0 +1,189 @@
import logging
import time
from threading import Thread
from typing import cast

from src.modules.csm.typings import FramePerformance, AttestationsAggregate
from src.providers.consensus.client import ConsensusClient
from src.typings import EpochNumber, BlockRoot, SlotNumber, BlockStamp, ValidatorIndex
from src.utils.web3converter import Web3Converter

logger = logging.getLogger(__name__)


class CheckpointsFactory:
cc: ConsensusClient
converter: Web3Converter
frame_performance: FramePerformance

def __init__(self, cc: ConsensusClient, converter: Web3Converter, frame_performance: FramePerformance):
self.cc = cc
self.converter = converter
self.frame_performance = frame_performance

def prepare_checkpoints(
self,
l_epoch: EpochNumber,
r_epoch: EpochNumber,
finalized_epoch: EpochNumber
):
def _prepare_checkpoint(_slot: SlotNumber, _duty_epochs: list[EpochNumber]):
return Checkpoint(self.cc, self.converter, self.frame_performance, _slot, _duty_epochs)

processing_delay = finalized_epoch - (max(self.frame_performance.processed, default=0) or l_epoch)
# - max checkpoint step is 255 epochs because the range checked from one state must fit into
# the state block roots buffer (8192 slots = 256 epochs), and each duty epoch needs 64 roots (2 epochs)
# - min checkpoint step is 10 epochs because it's a reasonable amount of work to process at once (~1 hour)
checkpoint_step = min(255, max(processing_delay, 10))
duty_epochs = cast(list[EpochNumber], list(range(l_epoch, r_epoch + 1)))

checkpoints: list[Checkpoint] = []
for index, epoch in enumerate(duty_epochs, 1):
if index % checkpoint_step != 0 and epoch != r_epoch:
continue
slot = self.converter.get_epoch_last_slot(EpochNumber(epoch + 1))
if index % checkpoint_step == 0:
checkpoints.append(_prepare_checkpoint(slot, duty_epochs[index - checkpoint_step: index]))
else:
# r_epoch reached mid-step: take the remaining tail of duty epochs
checkpoints.append(_prepare_checkpoint(slot, duty_epochs[index - index % checkpoint_step: index]))
return checkpoints
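For illustration, a minimal, self-contained sketch of how the step logic above cuts a frame into checkpoints. All numbers below are made up and serve only to trace the slicing.

# Hypothetical epochs, chosen only to trace the slicing above.
l_epoch, r_epoch, finalized_epoch = 1000, 1300, 1100
processed: set[int] = set()  # nothing processed yet

processing_delay = finalized_epoch - (max(processed, default=0) or l_epoch)  # 100
checkpoint_step = min(255, max(processing_delay, 10))                        # 100
duty_epochs = list(range(l_epoch, r_epoch + 1))                              # 301 duty epochs

# Walking the loop above with these values produces four checkpoints covering
# epochs [1000..1099], [1100..1199], [1200..1299] and the tail [1300].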


class Checkpoint:
# TODO: should be configurable or calculated based on the system resources
MAX_THREADS: int = 4

cc: ConsensusClient
converter: Web3Converter

threads: list[Thread]
frame_performance: FramePerformance

slot: SlotNumber # the checkpoint slot: last slot of the epoch following the last duty epoch
duty_epochs: list[EpochNumber] # max 255 elements
block_roots: list[BlockRoot] # max 8192 elements

def __init__(
self,
cc: ConsensusClient,
converter: Web3Converter,
frame_performance: FramePerformance,
slot: SlotNumber,
duty_epochs: list[EpochNumber]
):
self.cc = cc
self.converter = converter
self.slot = slot
self.duty_epochs = duty_epochs
self.block_roots = []
self.threads = []
self.frame_performance = frame_performance

@property
def free_threads(self):
return self.MAX_THREADS - len(self.threads)

def process(self, last_finalized_blockstamp: BlockStamp):
for duty_epoch in self.duty_epochs:
if duty_epoch in self.frame_performance.processed:
continue
if not self.block_roots:
self._get_block_roots()
roots_to_check = self._select_roots_to_check(duty_epoch)
if not self.free_threads:
self._await_oldest_thread()
thread = Thread(
target=self._process_epoch, args=(last_finalized_blockstamp, duty_epoch, roots_to_check)
)
thread.start()
self.threads.append(thread)
self._await_all_threads()

def _await_oldest_thread(self):
old = self.threads.pop(0)
old.join()

def _await_all_threads(self):
for thread in self.threads:
thread.join()
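The bounded-thread bookkeeping above could also be expressed with the standard library's executor. Below is a rough, illustrative equivalent, not the module's actual implementation. One practical difference is that future.result() re-raises exceptions from workers, which plain Thread.join() does not.

from concurrent.futures import ThreadPoolExecutor

def process_with_executor(self, last_finalized_blockstamp):
    # Rough equivalent of Checkpoint.process(); names mirror the class above.
    with ThreadPoolExecutor(max_workers=self.MAX_THREADS) as pool:
        futures = []
        for duty_epoch in self.duty_epochs:
            if duty_epoch in self.frame_performance.processed:
                continue
            if not self.block_roots:
                self._get_block_roots()
            futures.append(
                pool.submit(
                    self._process_epoch,
                    last_finalized_blockstamp,
                    duty_epoch,
                    self._select_roots_to_check(duty_epoch),
                )
            )
        for future in futures:
            future.result()  # re-raises any exception raised in a worker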

def _select_roots_to_check(
self, duty_epoch: EpochNumber
) -> list[BlockRoot]:
# copy of
# https://github.com/ethereum/consensus-specs/blob/dev/specs/phase0/beacon-chain.md#get_block_root_at_slot
roots_to_check = []
slots = range(
self.converter.get_epoch_first_slot(duty_epoch),
self.converter.get_epoch_last_slot(EpochNumber(duty_epoch + 1))
)
for slot in slots:
# TODO: get the magic number from the CL spec
if not slot < self.slot <= slot + 8192:
raise ValueError("Slot is out of the state block roots range")
roots_to_check.append(self.block_roots[slot % 8192])
return roots_to_check
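For reference, the spec routine linked above, paraphrased here, bounds the lookup the same way the range check above intends to (SLOTS_PER_HISTORICAL_ROOT is 8192 on mainnet):

SLOTS_PER_HISTORICAL_ROOT = 8192

def get_block_root_at_slot(state, slot):
    # Paraphrased from the consensus-specs phase0 beacon chain spec.
    assert slot < state.slot <= slot + SLOTS_PER_HISTORICAL_ROOT
    return state.block_roots[slot % SLOTS_PER_HISTORICAL_ROOT]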

def _get_block_roots(self):
logger.info({"msg": f"Get block roots for slot {self.slot}"})
# a checkpoint is essentially a point in time, which is why the state is addressed by slot, not by root
br = self.cc.get_state_block_roots(self.slot)
# replace duplicated roots with None to mark missed slots
self.block_roots = [None if br[i] == br[i - 1] else br[i] for i in range(len(br))]
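A small trace of the de-duplication above, with invented, truncated root values. A repeated root means the slot in between was missed; note that index 0 is compared against the last element of the buffer.

br = ["0xa1", "0xb2", "0xb2", "0xc3"]
deduped = [None if br[i] == br[i - 1] else br[i] for i in range(len(br))]
# -> ["0xa1", "0xb2", None, "0xc3"]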

def _process_epoch(
self,
last_finalized_blockstamp: BlockStamp,
duty_epoch: EpochNumber,
roots_to_check: list[BlockRoot]
):
logger.info({"msg": f"Process epoch {duty_epoch}"})
start = time.time()
committees = self._prepare_committees(last_finalized_blockstamp, EpochNumber(duty_epoch))
for root in roots_to_check:
if root is None:
continue
slot_data = self.cc.get_block_details_raw(BlockRoot(root))
self._process_attestations(slot_data, committees)

self.frame_performance.processed.add(EpochNumber(duty_epoch))
self.frame_performance.dump()
logger.info({"msg": f"Epoch {duty_epoch} processed in {time.time() - start:.2f} seconds"})

def _prepare_committees(self, last_finalized_blockstamp: BlockStamp, epoch: int) -> dict:
start = time.time()
committees = {}
for committee in self.cc.get_attestation_committees(last_finalized_blockstamp, EpochNumber(epoch)):
committees[f"{committee.slot}{committee.index}"] = committee.validators
for validator in committee.validators:
val = self.frame_performance.aggr_per_val.get(
ValidatorIndex(int(validator)), AttestationsAggregate(0, 0)
)
val.assigned += 1
self.frame_performance.aggr_per_val[ValidatorIndex(int(validator))] = val
logger.info({"msg": f"Committees for epoch {epoch} processed in {time.time() - start:.2f} seconds"})
return committees
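Roughly, the structure built above looks like this; the slot, index and validator values are invented:

committees = {
    # key is f"{committee.slot}{committee.index}"
    "91000120": ["1024", "77", "30511"],
    "91000121": ["512", "4096"],
}
# frame_performance.aggr_per_val is seeded so that, for example,
# aggr_per_val[ValidatorIndex(1024)] == AttestationsAggregate(assigned=1, included=0)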

def _process_attestations(self, slot_data: dict, committees: dict) -> None:
def to_bits(aggregation_bits: str):
# copied from https://github.com/ethereum/py-ssz/blob/main/ssz/sedes/bitvector.py#L66
att_bytes = bytes.fromhex(aggregation_bits[2:])
return [
bool((att_bytes[bit_index // 8] >> bit_index % 8) % 2) for bit_index in range(len(att_bytes) * 8)
]

for attestation in slot_data['message']['body']['attestations']:
committee_id = f"{attestation['data']['slot']}{attestation['data']['index']}"
committee = committees.get(committee_id)
att_bits = to_bits(attestation['aggregation_bits'])
if not committee:
continue
for index, validator in enumerate(committee):
if validator is None:
# validator has already fulfilled its duties
continue
attested = att_bits[index]
if attested:
self.frame_performance.aggr_per_val[ValidatorIndex(int(validator))].included += 1
# duty is fulfilled, so we can remove validator from the committee
committees[committee_id][index] = None
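A tiny usage sketch of the to_bits helper defined above (illustrative only, since it is nested inside _process_attestations; the hex string is made up). Bits are read least-significant first within each byte.

bits = to_bits("0x0b")  # 0x0b == 0b00001011
assert bits == [True, True, False, True, False, False, False, False]
# i.e. the validators at committee positions 0, 1 and 3 are marked as attested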
70 changes: 54 additions & 16 deletions src/modules/csm/csm.py
@@ -1,15 +1,18 @@
import time
from functools import cached_property
import logging

from web3.types import BlockIdentifier

from src.metrics.prometheus.business import CONTRACT_ON_PAUSE
from src.metrics.prometheus.duration_meter import duration_meter
from src.modules.csm.checkpoint import CheckpointsFactory
from src.modules.csm.typings import FramePerformance, ReportData
from src.modules.submodules.consensus import ConsensusModule
from src.modules.submodules.oracle_module import BaseModule, ModuleExecuteDelay
from src.typings import BlockStamp, ReferenceBlockStamp, SlotNumber
from src.typings import BlockStamp, ReferenceBlockStamp, SlotNumber, EpochNumber
from src.utils.cache import global_lru_cache as lru_cache
from src.utils.web3converter import Web3Converter
from src.web3py.extensions.lido_validators import NodeOperatorId, StakingModule, ValidatorsByNodeOperator
from src.web3py.typings import Web3

@@ -33,19 +36,22 @@ class CSFeeOracle(BaseModule, ConsensusModule):
def __init__(self, w3: Web3):
self.report_contract = w3.csm.oracle
super().__init__(w3)
self.frame_performance: FramePerformance | None
self.frame_performance: FramePerformance | None = None
# TODO: Feed the cache with the data about the attestations observed so far.

def refresh_contracts(self):
self.report_contract = self.w3.csm.oracle

def execute_module(self, last_finalized_blockstamp: BlockStamp) -> ModuleExecuteDelay:
report_blockstamp = self.get_blockstamp_for_report(last_finalized_blockstamp)

collected = self._collect_data(last_finalized_blockstamp)

if not collected:
# The data is not fully collected yet, wait for the next epoch
return ModuleExecuteDelay.NEXT_FINALIZED_EPOCH

if not report_blockstamp:
# TODO: To get ref_slot and if it's in the finalized epoch, wait one more epoch.
# Feed the cache with the data about the attestations observed so far.
self._collect_data(self._get_latest_blockstamp())
return ModuleExecuteDelay.NEXT_FINALIZED_EPOCH

self.process_report(report_blockstamp)
@@ -124,25 +130,57 @@ def module_validators_by_node_operators(self, blockstamp: BlockStamp) -> ValidatorsByNodeOperator:
def _is_paused(self, blockstamp: ReferenceBlockStamp) -> bool:
return self.report_contract.functions.isPaused().call(block_identifier=blockstamp.block_hash)

def _collect_data(self, blockstamp: BlockStamp) -> None:
last_ref_slot = self.w3.csm.get_csm_last_processing_ref_slot(blockstamp)
ref_slot = self.get_current_frame(blockstamp).ref_slot
def _collect_data(self, last_finalized_blockstamp: BlockStamp) -> bool:
"""Ongoing report data collection before the report ref slot and it's submission"""
converter = Web3Converter(
self.get_chain_config(last_finalized_blockstamp), self.get_frame_config(last_finalized_blockstamp)
)

l_ref_slot = self.w3.csm.get_csm_last_processing_ref_slot(last_finalized_blockstamp)
r_ref_slot = self.get_current_frame(last_finalized_blockstamp).ref_slot

# TODO: To think about the proper cache invalidation conditions.
if self.frame_performance:
if self.frame_performance.l_slot < last_ref_slot:
if self.frame_performance.l_slot < l_ref_slot:
self.frame_performance = None

if not self.frame_performance:
self.frame_performance = FramePerformance.try_read(ref_slot) or FramePerformance(
l_slot=last_ref_slot, r_slot=ref_slot
self.frame_performance = FramePerformance.try_read(r_ref_slot) or FramePerformance(
l_slot=l_ref_slot, r_slot=r_ref_slot
)

# Get the network validators from the 'finalized' state.
# Starting from the min(r_slot, finalized) slot, follow the parent block roots to collect the attestation data back to the l_slot.
# TODO: widen the boundaries by 1 epoch to capture all the attestations.

self.frame_performance.dump()
# The finalized slot is the first slot of the justifying epoch, so take the previous epoch
finalized_epoch = EpochNumber(converter.get_epoch_by_slot(last_finalized_blockstamp.slot_number) - 1)

l_epoch = EpochNumber(converter.get_epoch_by_slot(l_ref_slot) + 1)
if l_epoch > finalized_epoch:
return False
r_epoch = converter.get_epoch_by_slot(r_ref_slot)

factory = CheckpointsFactory(self.w3.cc, converter, self.frame_performance)
checkpoints = factory.prepare_checkpoints(l_epoch, r_epoch, finalized_epoch)

start = time.time()
for checkpoint in checkpoints:
if converter.get_epoch_by_slot(checkpoint.slot) > finalized_epoch:
# checkpoint isn't finalized yet, can't be processed
break
checkpoint.process(last_finalized_blockstamp)
delay = time.time() - start
logger.info({"msg": f"All epochs processed in {delay:.2f} seconds"})

self._print_result()
return self.frame_performance.is_coherent
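A worked example of the window derived above, assuming mainnet timing (32 slots per epoch); the slot numbers are hypothetical:

SLOTS_PER_EPOCH = 32  # mainnet

l_ref_slot = 262_143          # last processed ref slot -> epoch 8191
r_ref_slot = 271_359          # current frame ref slot  -> epoch 8479
finalized_slot = 266_240      # first slot of the justifying epoch 8320

l_epoch = l_ref_slot // SLOTS_PER_EPOCH + 1               # 8192: start right after the last report
r_epoch = r_ref_slot // SLOTS_PER_EPOCH                   # 8479: epoch of the upcoming ref slot
finalized_epoch = finalized_slot // SLOTS_PER_EPOCH - 1   # 8319: last epoch that is surely finalized

assert l_epoch <= finalized_epoch  # otherwise _collect_data() returns False and waits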

def _print_result(self):
assigned = 0
inc = 0
for aggr in self.frame_performance.aggr_per_val.values():
assigned += aggr.assigned
inc += aggr.included

logger.info({"msg": f"Assigned: {assigned}"})
logger.info({"msg": f"Included: {inc}"})

def _to_distribute(self, blockstamp: ReferenceBlockStamp) -> int:
return self.w3.csm.fee_distributor.pending_to_distribute(blockstamp.block_hash)
2 changes: 2 additions & 0 deletions src/modules/csm/typings.py
@@ -24,6 +24,7 @@ def as_tuple(self):
self.distributed,
)


@dataclass
class AttestationsAggregate:
assigned: int
@@ -33,6 +34,7 @@ class AttestationsAggregate:
def perf(self) -> float:
return self.included / self.assigned


@dataclass
class FramePerformance:
"""Data structure to store required data for performance calculation within the given frame."""
