feat: epochs_to_process (#445)
* feat: `epochs_to_process`

* feat: two sets

* feat: advanced validation

* fix: use `unprocessed_epochs`

* fix: private state properties

* revert: `_data` -> `data`
vgorkavenko authored May 9, 2024
1 parent a3ed61f commit a7e6a78
Showing 3 changed files with 91 additions and 65 deletions.
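
Taken together, the bullets above describe one refactoring: the single public `processed_epochs` set on `State` becomes two private sets, `_epochs_to_process` and `_processed_epochs`, and frame validation moves into the state itself. The snippet below is a simplified, standalone sketch of the intended lifecycle; `MiniState` and the inline driver are illustrative only and re-implement a stripped-down version of the class shown in the state.py diff further down.

from dataclasses import dataclass, field

EpochNumber = int  # simplified stand-in for src.typings.EpochNumber


@dataclass
class MiniState:
    """Stripped-down illustration of the two-set bookkeeping introduced by this commit."""
    _epochs_to_process: set[EpochNumber] = field(default_factory=set)
    _processed_epochs: set[EpochNumber] = field(default_factory=set)

    @property
    def unprocessed_epochs(self) -> set[EpochNumber]:
        if not self._epochs_to_process:
            raise ValueError("Epochs to process are not set")
        return self._epochs_to_process - self._processed_epochs

    @property
    def is_fulfilled(self) -> bool:
        return not self.unprocessed_epochs

    def add_processed_epoch(self, epoch: EpochNumber) -> None:
        self._processed_epochs.add(epoch)


# Intended lifecycle: declare the frame, process epochs one by one, then check fulfillment.
state = MiniState()
state._epochs_to_process.update(range(100, 105))  # the real validate_for_collect() populates this set
for epoch in sorted(state.unprocessed_epochs):
    state.add_processed_epoch(epoch)               # the real _process_epoch() does this after collecting duties
assert state.is_fulfilled                          # collect_data() now returns exactly this flag
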
27 changes: 9 additions & 18 deletions src/modules/csm/checkpoint.py
@@ -2,7 +2,7 @@
import time
from threading import Lock
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Iterable, cast
from typing import cast

from timeout_decorator import TimeoutError as DecoratorTimeoutError

@@ -42,17 +42,7 @@ def prepare_checkpoints(
def _prepare_checkpoint(_slot: SlotNumber, _duty_epochs: list[EpochNumber]):
return Checkpoint(self.cc, self.converter, self.state, _slot, _duty_epochs)

def _max_in_seq(items: Iterable[Any]) -> Any:
sorted_ = sorted(items)
assert sorted_
item = sorted_[0]
for curr in sorted_:
if curr - item > 1:
break
item = curr
return item

l_epoch = _max_in_seq((l_epoch, *self.state.processed_epochs))
l_epoch = min(self.state.unprocessed_epochs) or l_epoch
if l_epoch == r_epoch:
logger.info({"msg": "All epochs processed. No checkpoint required."})
return []
@@ -114,11 +104,10 @@ def __init__(
def process(self, last_finalized_blockstamp: BlockStamp):
def _unprocessed():
for _epoch in self.duty_epochs:
if _epoch in self.state.processed_epochs:
continue
if not self.block_roots:
self._get_block_roots()
yield _epoch
if _epoch in self.state.unprocessed_epochs:
if not self.block_roots:
self._get_block_roots()
yield _epoch

with ThreadPoolExecutor() as ext:
try:
@@ -188,7 +177,9 @@ def _process_epoch(
ValidatorIndex(int(validator['index'])),
included=validator['included'],
)
self.state.processed_epochs.add(duty_epoch)
if duty_epoch not in self.state.unprocessed_epochs:
raise ValueError(f"Epoch {duty_epoch} is not in epochs that should be processed")
self.state.add_processed_epoch(duty_epoch)
self.state.commit()
self.state.status()

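In checkpoint.py both the recalculated left boundary and the per-checkpoint work queue are now driven by `state.unprocessed_epochs` rather than ad-hoc scans over `processed_epochs`, and `_process_epoch` refuses epochs that are not in that set. Below is a toy version of the filtering generator used in `process()`; `fetch_roots` is a placeholder standing in for `_get_block_roots`.

from typing import Callable, Iterator


def unprocessed(duty_epochs: list[int], unprocessed_epochs: set[int], fetch_roots: Callable[[], None]) -> Iterator[int]:
    """Yield only the duty epochs that still need work, fetching block roots lazily on first hit."""
    roots_loaded = False
    for epoch in duty_epochs:
        if epoch in unprocessed_epochs:
            if not roots_loaded:
                fetch_roots()
                roots_loaded = True
            yield epoch


# Epochs 1 and 3 were already processed, so only 2 and 4 are yielded.
print(list(unprocessed([1, 2, 3, 4], {2, 4}, lambda: print("fetching block roots"))))
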
59 changes: 16 additions & 43 deletions src/modules/csm/csm.py
@@ -8,7 +8,7 @@
from src.metrics.prometheus.business import CONTRACT_ON_PAUSE
from src.metrics.prometheus.duration_meter import duration_meter
from src.modules.csm.checkpoint import CheckpointsFactory
from src.modules.csm.state import State
from src.modules.csm.state import State, InvalidState
from src.modules.csm.tree import Tree
from src.modules.csm.types import ReportData
from src.modules.submodules.consensus import ConsensusModule
@@ -17,7 +17,6 @@
from src.providers.execution.contracts.CSFeeOracle import CSFeeOracle
from src.typings import BlockStamp, EpochNumber, ReferenceBlockStamp, SlotNumber, ValidatorIndex
from src.utils.cache import global_lru_cache as lru_cache
from src.utils.range import sequence
from src.utils.slot import get_first_non_missed_slot
from src.utils.web3converter import Web3Converter
from src.web3py.extensions.lido_validators import NodeOperatorId, StakingModule, ValidatorsByNodeOperator
@@ -26,10 +25,6 @@
logger = logging.getLogger(__name__)


class InvalidState(Exception):
...


class CSOracle(BaseModule, ConsensusModule):
"""
CSM performance module collects performance of CSM node operators and creates a Merkle tree of the resulting
@@ -72,18 +67,22 @@ def execute_module(self, last_finalized_blockstamp: BlockStamp) -> ModuleExecute
@lru_cache(maxsize=1)
@duration_meter()
def build_report(self, blockstamp: ReferenceBlockStamp) -> tuple:
# pylint: disable=too-many-statements
# pylint: disable=too-many-branches
# pylint: disable=too-many-branches,too-many-statements
assert self.state
l_ref_slot, r_ref_slot = self.current_frame_range(blockstamp)
converter = self.converter(blockstamp)
l_epoch = EpochNumber(converter.get_epoch_by_slot(l_ref_slot) + 1)
r_epoch = converter.get_epoch_by_slot(r_ref_slot)

try:
self.validate_state(blockstamp, full=True)
except InvalidState as ex:
raise ValueError("Unable to build report") from ex
self.state.validate_for_report(l_epoch, r_epoch)
except InvalidState as e:
raise ValueError("State is not valid for the report") from e

self.state.status()

threshold = self.state.avg_perf * self.w3.csm.oracle.perf_threshold(blockstamp.block_hash)
l_ref_slot, r_ref_slot = self.current_frame_range(blockstamp)

# NOTE: r_block is guaranteed to be <= ref_slot, and the check
# in the inner frames assures the l_block <= r_block.
stuck_operators = self.w3.csm.get_csm_stuck_node_operators(
@@ -204,43 +203,13 @@ def module(self) -> StakingModule:
def module_validators_by_node_operators(self, blockstamp: BlockStamp) -> ValidatorsByNodeOperator:
return self.w3.lido_validators.get_module_validators_by_node_operators(self.module.id, blockstamp)

def validate_state(self, blockstamp: BlockStamp, full: bool = False) -> None:
assert self.state
converter = self.converter(blockstamp)
l_ref_slot, r_ref_slot = self.current_frame_range(blockstamp)
l_epoch = EpochNumber(converter.get_epoch_by_slot(l_ref_slot) + 1)
r_epoch = converter.get_epoch_by_slot(r_ref_slot)

for epoch in self.state.processed_epochs:
if l_epoch <= epoch <= r_epoch:
continue
logger.info({"msg": f"Invalid state: processed {epoch=}, but range is [{l_epoch};{r_epoch}]"})
raise InvalidState()

if full:
for epoch in sequence(l_epoch, r_epoch):
if epoch not in self.state.processed_epochs:
logger.info({"msg": f"Invalid state: {epoch} was not processed"})
raise InvalidState()

def collect_data(self, blockstamp: BlockStamp) -> bool:
"""Ongoing report data collection before the report ref slot and it's submission"""
logger.info({"msg": "Collecting data for the report"})

l_ref_slot, r_ref_slot = self.current_frame_range(blockstamp)
logger.info({"msg": f"Frame for performance data collect: ({l_ref_slot};{r_ref_slot}]"})

self.state = self.state or State.load()

try:
self.validate_state(blockstamp)
except InvalidState:
logger.info({"msg": "Discarding invalidated state cache"})
self.state.clear()
self.state.commit()

self.state.status()

converter = self.converter(blockstamp)
# Finalized slot is the first slot of justifying epoch, so we need to take the previous
finalized_epoch = EpochNumber(converter.get_epoch_by_slot(blockstamp.slot_number) - 1)
@@ -250,6 +219,10 @@ def collect_data(self, blockstamp: BlockStamp) -> bool:
return False
r_epoch = converter.get_epoch_by_slot(r_ref_slot)

self.state = self.state or State.load()
self.state.validate_for_collect(l_epoch, r_epoch)
self.state.status()

factory = CheckpointsFactory(self.w3.cc, converter, self.state)
checkpoints = factory.prepare_checkpoints(l_epoch, r_epoch, finalized_epoch)

@@ -265,7 +238,7 @@ def collect_data(self, blockstamp: BlockStamp) -> bool:
if checkpoints:
logger.info({"msg": f"All epochs were processed in {time.time() - start:.2f} seconds"})

return all(epoch in self.state.processed_epochs for epoch in sequence(l_epoch, r_epoch))
return self.state.is_fulfilled

@lru_cache(maxsize=1)
def current_frame_range(self, blockstamp: BlockStamp) -> tuple[SlotNumber, SlotNumber]:
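On the oracle side, `CSOracle.validate_state` is gone: `InvalidState` now lives in `src.modules.csm.state`, `build_report` delegates to `state.validate_for_report(l_epoch, r_epoch)` and surfaces failures as a `ValueError`, and `collect_data` simply returns `state.is_fulfilled`. A minimal, self-contained sketch of that exception-chaining shape follows; the function names here are placeholders, and the real code imports `InvalidState` instead of redefining it.

from typing import Callable


class InvalidState(Exception):
    """Raised by the state when it cannot serve the requested frame."""


def guarded_build_report(validate: Callable[[], None]) -> None:
    # Mirrors the pattern in build_report(): re-raise a state problem as a ValueError.
    try:
        validate()
    except InvalidState as e:
        raise ValueError("State is not valid for the report") from e


def failing_validate() -> None:
    raise InvalidState("epoch 123 was not processed")


try:
    guarded_build_report(failing_validate)
except ValueError as err:
    print(err, "<-", repr(err.__cause__))  # State is not valid for the report <- InvalidState('epoch 123 was not processed')
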
70 changes: 66 additions & 4 deletions src/modules/csm/state.py
@@ -9,10 +9,15 @@
from pathlib import Path

from src.typings import EpochNumber, ValidatorIndex
from src.utils.range import sequence

logger = logging.getLogger(__name__)


class InvalidState(Exception):
...


@dataclass
class AttestationsAggregate:
assigned: int
@@ -26,9 +31,9 @@ def perf(self) -> float:
@dataclass
class State(UserDict[ValidatorIndex, AttestationsAggregate]):
"""Tracks processing state of CSM performance oracle frame"""

data: dict[ValidatorIndex, AttestationsAggregate] = field(default_factory=dict)
processed_epochs: set[EpochNumber] = field(default_factory=set)
_epochs_to_process: set[EpochNumber] = field(default_factory=set)
_processed_epochs: set[EpochNumber] = field(default_factory=set)

EXTENSION = ".pkl"

@@ -60,28 +65,85 @@ def commit(self) -> None:

def clear(self) -> None:
self.data = {}
self.processed_epochs.clear()
self._epochs_to_process.clear()
self._processed_epochs.clear()
assert self.is_empty

def inc(self, key: ValidatorIndex, included: bool) -> None:
perf = self.get(key, AttestationsAggregate(0, 0))
perf.assigned += 1
perf.included += 1 if included else 0
self[key] = perf

def add_processed_epoch(self, epoch: EpochNumber) -> None:
self._processed_epochs.add(epoch)

def status(self) -> None:
assigned, included = reduce(
lambda acc, aggr: (acc[0] + aggr.assigned, acc[1] + aggr.included), self.values(), (0, 0)
)

logger.info(
{
"msg": f"Processed {len(self.processed_epochs)} epochs",
"msg": f"Processed {len(self._processed_epochs)} of {len(self._epochs_to_process)} epochs",
"assigned": assigned,
"included": included,
"avg_perf": self.avg_perf,
}
)

def validate_for_report(self, l_epoch: EpochNumber, r_epoch: EpochNumber) -> None:
if not self.is_fulfilled:
raise InvalidState()

for epoch in self._processed_epochs:
if l_epoch <= epoch <= r_epoch:
continue
raise InvalidState()

for epoch in sequence(l_epoch, r_epoch):
if epoch not in self._processed_epochs:
raise InvalidState()

def validate_for_collect(self, l_epoch: EpochNumber, r_epoch: EpochNumber):

invalidated = False

for epoch in self._epochs_to_process:
if l_epoch <= epoch <= r_epoch:
continue
invalidated = True
break

for epoch in self._processed_epochs:
if l_epoch <= epoch <= r_epoch:
continue
invalidated = True
break

if invalidated:
logger.warning({"msg": "Discarding invalidated state cache"})
self.clear()
self.commit()

if self.is_empty or r_epoch > max(self._epochs_to_process):
self._epochs_to_process.update(sequence(l_epoch, r_epoch))
self.commit()

@property
def is_empty(self) -> bool:
return not self.data and not self._epochs_to_process and not self._processed_epochs

@property
def unprocessed_epochs(self) -> set[EpochNumber]:
if not self._epochs_to_process:
raise ValueError("Epochs to process are not set")
return self._epochs_to_process - self._processed_epochs

@property
def is_fulfilled(self) -> bool:
return not self.unprocessed_epochs

@property
def avg_perf(self) -> float:
"""Returns average performance of all validators in the cache"""
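The new `validate_for_collect` deserves a note: it discards the whole cache when either set contains an epoch outside the requested frame, and it extends `_epochs_to_process` whenever the state is empty or the frame's right edge moved past the current maximum. Restated below as pure functions for illustration only; the real method also clears and commits the cache, and these function names are placeholders.

def should_discard(epochs_to_process: set[int], processed: set[int], l_epoch: int, r_epoch: int) -> bool:
    # Any planned or processed epoch outside [l_epoch, r_epoch] invalidates the cache.
    return any(e < l_epoch or e > r_epoch for e in epochs_to_process | processed)


def should_extend(epochs_to_process: set[int], r_epoch: int) -> bool:
    # An empty plan, or a frame that now ends later, means new epochs must be appended.
    return not epochs_to_process or r_epoch > max(epochs_to_process)


print(should_discard({10, 11, 12}, {10}, l_epoch=10, r_epoch=12))  # False: same frame, keep the cache
print(should_discard({10, 11, 12}, {10}, l_epoch=11, r_epoch=13))  # True: epoch 10 falls outside the new frame
print(should_extend({10, 11, 12}, r_epoch=15))                     # True: epochs 13..15 get added to the plan
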
