Skip to content

Commit

Permalink
feat: build CSM oracle report
Browse files Browse the repository at this point in the history
  • Loading branch information
madlabman committed Apr 24, 2024
1 parent e71ef05 commit 7809314
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 20 deletions.
55 changes: 36 additions & 19 deletions src/modules/csm/csm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@

from web3.types import BlockIdentifier

from src import variables
from src.metrics.prometheus.business import CONTRACT_ON_PAUSE
from src.metrics.prometheus.duration_meter import duration_meter
from src.modules.csm.checkpoint import CheckpointsFactory
from src.modules.csm.typings import FramePerformance, ReportData
from src.modules.submodules.consensus import ConsensusModule
from src.modules.submodules.oracle_module import BaseModule, ModuleExecuteDelay
from src.typings import BlockStamp, ReferenceBlockStamp, SlotNumber, EpochNumber
from src.typings import BlockStamp, ReferenceBlockStamp, SlotNumber, EpochNumber, ValidatorIndex
from src.utils.cache import global_lru_cache as lru_cache
from src.utils.web3converter import Web3Converter
from src.web3py.extensions.lido_validators import NodeOperatorId, StakingModule, ValidatorsByNodeOperator
Expand Down Expand Up @@ -66,22 +67,41 @@ def build_report(self, blockstamp: ReferenceBlockStamp) -> tuple:
last_ref_slot = self.w3.csm.get_csm_last_processing_ref_slot(blockstamp)
ref_slot = self.get_current_frame(blockstamp).ref_slot

# Get module's node operators.
_ = self.module_validators_by_node_operators(blockstamp)
# Read performance threshold value from somewhere (hardcoded?).
_ = self.frame_performance.avg_perf * 0.95
# Build the map of the current distribution operators.
# _ = groupby(self.frame_performance.aggr_per_val, operators)
# Exclude validators of operators with stuck keys.
_ = self.w3.csm.get_csm_stuck_node_operators(
threshold = self.frame_performance.avg_perf * self.w3.csm.oracle.perf_threshold(blockstamp.block_hash)
stuck_operators = self.w3.csm.get_csm_stuck_node_operators(
self._slot_to_block_identifier(last_ref_slot),
self._slot_to_block_identifier(ref_slot),
)
# Exclude underperforming validators.

operators = self.module_validators_by_node_operators(blockstamp)
# Build the map of the current distribution operators.
distribution: dict[NodeOperatorId, int] = {}
total = 0

for (_, no_id), validators in operators.items():
if no_id in stuck_operators:
continue

share = len(
[
v
for v in validators
if self.frame_performance.perf(ValidatorIndex(int(v.index))) > threshold
]
)

distribution[no_id] = share
total += share

# Calculate share of each CSM node operator.
_ = self._to_distribute(blockstamp)
shares: tuple[tuple[NodeOperatorId, int]] = tuple() # type: ignore
to_distribute = self.w3.csm.fee_distributor.pending_to_distribute(blockstamp.block_hash)
shares: list[tuple[NodeOperatorId, int]] = []
for no_id, share in distribution.items():
shares.append((no_id, to_distribute * share // total))

distributed = sum((s for (_, s) in shares))
if not distributed:
... # TODO: The code expects the report built, but it doesn't make sense.

# Load the previous tree if any.
cid = self.w3.csm.get_csm_tree_cid(blockstamp)
Expand All @@ -98,7 +118,7 @@ def build_report(self, blockstamp: ReferenceBlockStamp) -> tuple:
blockstamp.ref_slot,
tree_root=b"", # type: ignore
tree_cid="",
distributed=sum((s for (_, s) in shares)),
distributed=distributed
).as_tuple()

def is_main_data_submitted(self, blockstamp: BlockStamp) -> bool:
Expand All @@ -107,7 +127,7 @@ def is_main_data_submitted(self, blockstamp: BlockStamp) -> bool:
return last_ref_slot == ref_slot

def is_contract_reportable(self, blockstamp: BlockStamp) -> bool:
return not self.is_main_data_submitted(blockstamp)
return not self.is_main_data_submitted(blockstamp) and not self.w3.csm.module.is_paused()

def is_reporting_allowed(self, blockstamp: ReferenceBlockStamp) -> bool:
on_pause = self._is_paused(blockstamp)
Expand All @@ -120,10 +140,10 @@ def module(self) -> StakingModule:
modules: list[StakingModule] = self.w3.lido_validators.get_staking_modules(self._receive_last_finalized_slot())

for mod in modules:
if mod.name == "": # FIXME
if mod.staking_module_address == variables.CSM_MODULE_ADDRESS:
return mod

raise ValueError("No CSM module found")
raise ValueError("No CSM module found. Wrong address?")

@lru_cache(maxsize=1)
def module_validators_by_node_operators(self, blockstamp: BlockStamp) -> ValidatorsByNodeOperator:
Expand Down Expand Up @@ -182,9 +202,6 @@ def _print_collect_result(self):
logger.info({"msg": f"Assigned: {assigned}"})
logger.info({"msg": f"Included: {inc}"})

def _to_distribute(self, blockstamp: ReferenceBlockStamp) -> int:
return self.w3.csm.fee_distributor.pending_to_distribute(blockstamp.block_hash)

def _slot_to_block_identifier(self, slot: SlotNumber) -> BlockIdentifier:
block = self.w3.cc.get_block_details(slot)

Expand Down
4 changes: 4 additions & 0 deletions src/modules/csm/typings.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ def avg_perf(self) -> float:
"""Returns average performance of all validators in the cache."""
return mean((a.perf for a in self.aggr_per_val.values()))

def perf(self, index: ValidatorIndex) -> float:
"""Returns performance of a validator by its index."""
return self.aggr_per_val[index].perf

def dump(self, epoch: EpochNumber, committees: dict, roots: set[BlockRoot]) -> None:
"""Used to persist the current state of the structure."""
# TODO: persist the data. no sense to keep it in memory (except of `processed` ?)
Expand Down
13 changes: 13 additions & 0 deletions src/providers/execution/contracts/CSFeeOracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,19 @@ def fee_distributor(self, block: BlockIdentifier = "latest") -> Address:
)
return cast(Address, resp)

def perf_threshold(self, block: BlockIdentifier = "latest") -> float:
"""Performance threshold used to determine underperforming validators"""

resp = self.functions.perfThresholdBP().call(block_identifier=block)
logger.debug(
{
"msg": "Call to perfThresholdBP()",
"value": resp,
"block_identifier": repr(block),
}
)
return resp / 10_000 # Convert from basis points

# TODO: Inherit the method from the BaseOracle class.
def get_last_processing_ref_slot(self, block: BlockIdentifier = "latest") -> SlotNumber:
resp = self.functions.getLastProcessingRefSlot().call(block_identifier=block)
Expand Down
11 changes: 11 additions & 0 deletions src/providers/execution/contracts/CSModule.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,14 @@ def get_stuck_keys_events(self, block: BlockIdentifier) -> Iterable[EventData]:

def is_deployed(self, block: BlockIdentifier) -> bool:
return self.w3.eth.get_code(self.address, block_identifier=block) != b""

def is_paused(self, block: BlockIdentifier = "latest") -> bool:
resp = self.functions.isPaused().call(block_identifier=block)
logger.debug(
{
"msg": "Call to isPaused()",
"value": resp,
"block_identifier": repr(block),
}
)
return resp
4 changes: 3 additions & 1 deletion src/web3py/extensions/lido_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,9 @@ def get_lido_validators_by_node_operators(self, blockstamp: BlockStamp) -> Valid
@lru_cache(maxsize=1)
def get_module_validators_by_node_operators(self, module_id: StakingModuleId, blockstamp: BlockStamp) -> ValidatorsByNodeOperator:
"""Get module validators by querying the KeysAPI for the module keys"""
raise NotImplementedError()
# TODO: Re-evaluate the implementation, sub-optimal solution.
all = self.get_lido_validators_by_node_operators(blockstamp)
return {k: v for (k, v) in all.items() if k[0] == module_id}

@lru_cache(maxsize=1)
def get_lido_node_operators(self, blockstamp: BlockStamp) -> list[NodeOperator]:
Expand Down

0 comments on commit 7809314

Please sign in to comment.