feat: epochs_to_process #445

Merged (8 commits) on May 9, 2024
27 changes: 9 additions & 18 deletions src/modules/csm/checkpoint.py
@@ -2,7 +2,7 @@
import time
from threading import Lock
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Iterable, cast
from typing import cast

from timeout_decorator import TimeoutError as DecoratorTimeoutError

@@ -42,17 +42,7 @@ def prepare_checkpoints(
def _prepare_checkpoint(_slot: SlotNumber, _duty_epochs: list[EpochNumber]):
return Checkpoint(self.cc, self.converter, self.state, _slot, _duty_epochs)

def _max_in_seq(items: Iterable[Any]) -> Any:
sorted_ = sorted(items)
assert sorted_
item = sorted_[0]
for curr in sorted_:
if curr - item > 1:
break
item = curr
return item

l_epoch = _max_in_seq((l_epoch, *self.state.processed_epochs))
l_epoch = min(self.state.unprocessed_epochs) or l_epoch
if l_epoch == r_epoch:
logger.info({"msg": "All epochs processed. No checkpoint required."})
return []
@@ -114,11 +104,10 @@ def __init__(
def process(self, last_finalized_blockstamp: BlockStamp):
def _unprocessed():
for _epoch in self.duty_epochs:
if _epoch in self.state.processed_epochs:
continue
if not self.block_roots:
self._get_block_roots()
yield _epoch
if _epoch in self.state.unprocessed_epochs:
if not self.block_roots:
self._get_block_roots()
yield _epoch

with ThreadPoolExecutor() as ext:
try:
Expand Down Expand Up @@ -188,7 +177,9 @@ def _process_epoch(
ValidatorIndex(int(validator['index'])),
included=validator['included'],
)
self.state.processed_epochs.add(duty_epoch)
if duty_epoch not in self.state.unprocessed_epochs:
raise ValueError(f"Epoch {duty_epoch} is not in epochs that should be processed")
self.state.add_processed_epoch(duty_epoch)
self.state.commit()
self.state.status()

59 changes: 16 additions & 43 deletions src/modules/csm/csm.py
@@ -8,7 +8,7 @@
from src.metrics.prometheus.business import CONTRACT_ON_PAUSE
from src.metrics.prometheus.duration_meter import duration_meter
from src.modules.csm.checkpoint import CheckpointsFactory
from src.modules.csm.state import State
from src.modules.csm.state import State, InvalidState
from src.modules.csm.tree import Tree
from src.modules.csm.types import ReportData
from src.modules.submodules.consensus import ConsensusModule
@@ -17,7 +17,6 @@
from src.providers.execution.contracts.CSFeeOracle import CSFeeOracle
from src.typings import BlockStamp, EpochNumber, ReferenceBlockStamp, SlotNumber, ValidatorIndex
from src.utils.cache import global_lru_cache as lru_cache
from src.utils.range import sequence
from src.utils.slot import get_first_non_missed_slot
from src.utils.web3converter import Web3Converter
from src.web3py.extensions.lido_validators import NodeOperatorId, StakingModule, ValidatorsByNodeOperator
@@ -26,10 +25,6 @@
logger = logging.getLogger(__name__)


class InvalidState(Exception):
...


class CSOracle(BaseModule, ConsensusModule):
"""
CSM performance module collects performance of CSM node operators and creates a Merkle tree of the resulting
@@ -72,18 +67,22 @@ def execute_module(self, last_finalized_blockstamp: BlockStamp) -> ModuleExecute
@lru_cache(maxsize=1)
@duration_meter()
def build_report(self, blockstamp: ReferenceBlockStamp) -> tuple:
# pylint: disable=too-many-statements
# pylint: disable=too-many-branches
# pylint: disable=too-many-branches,too-many-statements
assert self.state
l_ref_slot, r_ref_slot = self.current_frame_range(blockstamp)
converter = self.converter(blockstamp)
l_epoch = EpochNumber(converter.get_epoch_by_slot(l_ref_slot) + 1)
r_epoch = converter.get_epoch_by_slot(r_ref_slot)

try:
self.validate_state(blockstamp, full=True)
except InvalidState as ex:
raise ValueError("Unable to build report") from ex
self.state.validate_for_report(l_epoch, r_epoch)
except InvalidState as e:
raise ValueError("State is not valid for the report") from e

self.state.status()

threshold = self.state.avg_perf * self.w3.csm.oracle.perf_threshold(blockstamp.block_hash)
l_ref_slot, r_ref_slot = self.current_frame_range(blockstamp)

# NOTE: r_block is guaranteed to be <= ref_slot, and the check
# in the inner frames assures the l_block <= r_block.
stuck_operators = self.w3.csm.get_csm_stuck_node_operators(
@@ -204,43 +203,13 @@ def module(self) -> StakingModule:
def module_validators_by_node_operators(self, blockstamp: BlockStamp) -> ValidatorsByNodeOperator:
return self.w3.lido_validators.get_module_validators_by_node_operators(self.module.id, blockstamp)

def validate_state(self, blockstamp: BlockStamp, full: bool = False) -> None:
assert self.state
converter = self.converter(blockstamp)
l_ref_slot, r_ref_slot = self.current_frame_range(blockstamp)
l_epoch = EpochNumber(converter.get_epoch_by_slot(l_ref_slot) + 1)
r_epoch = converter.get_epoch_by_slot(r_ref_slot)

for epoch in self.state.processed_epochs:
if l_epoch <= epoch <= r_epoch:
continue
logger.info({"msg": f"Invalid state: processed {epoch=}, but range is [{l_epoch};{r_epoch}]"})
raise InvalidState()

if full:
for epoch in sequence(l_epoch, r_epoch):
if epoch not in self.state.processed_epochs:
logger.info({"msg": f"Invalid state: {epoch} was not processed"})
raise InvalidState()

def collect_data(self, blockstamp: BlockStamp) -> bool:
"""Ongoing report data collection before the report ref slot and it's submission"""
logger.info({"msg": "Collecting data for the report"})

l_ref_slot, r_ref_slot = self.current_frame_range(blockstamp)
logger.info({"msg": f"Frame for performance data collect: ({l_ref_slot};{r_ref_slot}]"})

self.state = self.state or State.load()

try:
self.validate_state(blockstamp)
except InvalidState:
logger.info({"msg": "Discarding invalidated state cache"})
self.state.clear()
self.state.commit()

self.state.status()

converter = self.converter(blockstamp)
# Finalized slot is the first slot of justifying epoch, so we need to take the previous
finalized_epoch = EpochNumber(converter.get_epoch_by_slot(blockstamp.slot_number) - 1)
@@ -250,6 +219,10 @@ def collect_data(self, blockstamp: BlockStamp) -> bool:
return False
r_epoch = converter.get_epoch_by_slot(r_ref_slot)

self.state = self.state or State.load()
self.state.validate_for_collect(l_epoch, r_epoch)
self.state.status()

factory = CheckpointsFactory(self.w3.cc, converter, self.state)
checkpoints = factory.prepare_checkpoints(l_epoch, r_epoch, finalized_epoch)

@@ -265,7 +238,7 @@ def collect_data(self, blockstamp: BlockStamp) -> bool:
if checkpoints:
logger.info({"msg": f"All epochs were processed in {time.time() - start:.2f} seconds"})

return all(epoch in self.state.processed_epochs for epoch in sequence(l_epoch, r_epoch))
return self.state.is_fulfilled

@lru_cache(maxsize=1)
def current_frame_range(self, blockstamp: BlockStamp) -> tuple[SlotNumber, SlotNumber]:
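
The return value of collect_data is equivalent to the old all(...) expression but delegates the bookkeeping to the state. A quick sketch of that equivalence with plain integers (the module itself uses EpochNumber and the sequence helper from src.utils.range):

# Both expressions answer "has every epoch of the frame been processed?":
# the old code walked the frame and checked processed_epochs membership,
# the new State.is_fulfilled checks that nothing in _epochs_to_process is
# left outside _processed_epochs.

l_epoch, r_epoch = 100, 109
epochs_to_process = set(range(l_epoch, r_epoch + 1))
processed_epochs = set(range(l_epoch, r_epoch + 1))

old_style = all(epoch in processed_epochs for epoch in range(l_epoch, r_epoch + 1))
new_style = not (epochs_to_process - processed_epochs)  # what is_fulfilled computes
assert old_style == new_style  # both True once the frame is fully processed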
70 changes: 66 additions & 4 deletions src/modules/csm/state.py
@@ -9,10 +9,15 @@
from pathlib import Path

from src.typings import EpochNumber, ValidatorIndex
from src.utils.range import sequence

logger = logging.getLogger(__name__)


class InvalidState(Exception):
...


@dataclass
class AttestationsAggregate:
assigned: int
@@ -26,9 +31,9 @@ def perf(self) -> float:
@dataclass
class State(UserDict[ValidatorIndex, AttestationsAggregate]):
"""Tracks processing state of CSM performance oracle frame"""

data: dict[ValidatorIndex, AttestationsAggregate] = field(default_factory=dict)
processed_epochs: set[EpochNumber] = field(default_factory=set)
_epochs_to_process: set[EpochNumber] = field(default_factory=set)
_processed_epochs: set[EpochNumber] = field(default_factory=set)

EXTENSION = ".pkl"

@@ -60,28 +65,85 @@ def commit(self) -> None:

def clear(self) -> None:
self.data = {}
self.processed_epochs.clear()
self._epochs_to_process.clear()
self._processed_epochs.clear()
assert self.is_empty

def inc(self, key: ValidatorIndex, included: bool) -> None:
perf = self.get(key, AttestationsAggregate(0, 0))
perf.assigned += 1
perf.included += 1 if included else 0
self[key] = perf

def add_processed_epoch(self, epoch: EpochNumber) -> None:
self._processed_epochs.add(epoch)

def status(self) -> None:
assigned, included = reduce(
lambda acc, aggr: (acc[0] + aggr.assigned, acc[1] + aggr.included), self.values(), (0, 0)
)

logger.info(
{
"msg": f"Processed {len(self.processed_epochs)} epochs",
"msg": f"Processed {len(self._processed_epochs)} of {len(self._epochs_to_process)} epochs",
"assigned": assigned,
"included": included,
"avg_perf": self.avg_perf,
}
)

def validate_for_report(self, l_epoch: EpochNumber, r_epoch: EpochNumber) -> None:
if not self.is_fulfilled:
raise InvalidState()

for epoch in self._processed_epochs:
if l_epoch <= epoch <= r_epoch:
continue
raise InvalidState()

for epoch in sequence(l_epoch, r_epoch):
if epoch not in self._processed_epochs:
raise InvalidState()

def validate_for_collect(self, l_epoch: EpochNumber, r_epoch: EpochNumber):

invalidated = False

for epoch in self._epochs_to_process:
if l_epoch <= epoch <= r_epoch:
continue
invalidated = True
break

for epoch in self._processed_epochs:
if l_epoch <= epoch <= r_epoch:
continue
invalidated = True
break

if invalidated:
logger.warning({"msg": "Discarding invalidated state cache"})
self.clear()
self.commit()

if self.is_empty or r_epoch > max(self._epochs_to_process):
self._epochs_to_process.update(sequence(l_epoch, r_epoch))
self.commit()

@property
def is_empty(self) -> bool:
return not self.data and not self._epochs_to_process and not self._processed_epochs

@property
def unprocessed_epochs(self) -> set[EpochNumber]:
if not self._epochs_to_process:
raise ValueError("Epochs to process are not set")
return self._epochs_to_process - self._processed_epochs

@property
def is_fulfilled(self) -> bool:
return not self.unprocessed_epochs

@property
def avg_perf(self) -> float:
"""Returns average performance of all validators in the cache"""