Skip to content

Commit

Permalink
refactor: build_report
Browse files Browse the repository at this point in the history
  • Loading branch information
madlabman committed Apr 24, 2024
1 parent 7e746b8 commit 94fbec1
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 21 deletions.
33 changes: 15 additions & 18 deletions src/modules/csm/csm.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,18 @@
from collections import defaultdict
import logging
import time
from collections import defaultdict
from functools import cached_property
import logging

from hexbytes import HexBytes
from web3.types import BlockIdentifier

from src import variables
from src.metrics.prometheus.business import CONTRACT_ON_PAUSE
from src.metrics.prometheus.duration_meter import duration_meter
from src.modules.csm.checkpoint import CheckpointsFactory
from src.modules.csm.tree import Tree
from src.modules.csm.typings import FramePerformance, ReportData
from src.modules.submodules.consensus import ConsensusModule
from src.modules.submodules.oracle_module import BaseModule, ModuleExecuteDelay
from src.typings import BlockStamp, ReferenceBlockStamp, SlotNumber, EpochNumber, ValidatorIndex
from src.typings import BlockStamp, EpochNumber, ReferenceBlockStamp, SlotNumber, ValidatorIndex
from src.utils.cache import global_lru_cache as lru_cache
from src.utils.web3converter import Web3Converter
from src.web3py.extensions.lido_validators import NodeOperatorId, StakingModule, ValidatorsByNodeOperator
Expand All @@ -41,7 +39,6 @@ def __init__(self, w3: Web3):
self.report_contract = w3.csm.oracle
super().__init__(w3)
self.frame_performance: FramePerformance | None = None
# TODO: Feed the cache with the data about the attestations observed so far.

def refresh_contracts(self):
self.report_contract = self.w3.csm.oracle
Expand Down Expand Up @@ -80,30 +77,29 @@ def build_report(self, blockstamp: ReferenceBlockStamp) -> tuple | None:
operators = self.module_validators_by_node_operators(blockstamp)
# Build the map of the current distribution operators.
distribution: dict[NodeOperatorId, int] = {}
total = 0

for (_, no_id), validators in operators.items():
if no_id in stuck_operators:
continue

share = len(
portion = len(
[
v
for v in validators
if self.frame_performance.perf(ValidatorIndex(int(v.index))) > threshold
]
)

distribution[no_id] = share
total += share
distribution[no_id] = portion

# Calculate share of each CSM node operator.
to_distribute = self.w3.csm.fee_distributor.pending_to_distribute(blockstamp.block_hash)
shares: dict[NodeOperatorId, int] = defaultdict(int)
for no_id, share in distribution.items():
shares[no_id] = to_distribute * share // total
total = sum(p for p in distribution.values())
for no_id, portion in distribution.items():
shares[no_id] = to_distribute * portion // total

distributed = sum((s for s in shares.values()))
distributed = sum(s for s in shares.values())
assert distributed <= to_distribute
if not distributed:
logger.info({"msg": "No shares distributed"})
return
Expand All @@ -117,9 +113,10 @@ def build_report(self, blockstamp: ReferenceBlockStamp) -> tuple | None:
root = self.w3.csm.get_csm_tree_root(blockstamp)
logger.info({"msg": "Restored tree from IPFS dump", "root": root})

if tree.root != root: # TODO: Is the `root` 0x-prefixed?
if tree.root.hex() != root: # TODO: Is the `root` 0x-prefixed?
raise ValueError("Unexpected tree root got from IPFS dump")

# Update cumulative amount of shares for all operators.
for v in tree.tree.values:
no_id, amount = v["value"]
shares[no_id] += amount
Expand All @@ -132,9 +129,9 @@ def build_report(self, blockstamp: ReferenceBlockStamp) -> tuple | None:
return ReportData(
self.CONSENSUS_VERSION,
blockstamp.ref_slot,
tree_root=HexBytes(tree.tree.root),
tree_root=tree.root,
tree_cid=cid,
distributed=distributed
distributed=distributed,
).as_tuple()

def is_main_data_submitted(self, blockstamp: BlockStamp) -> bool:
Expand All @@ -156,7 +153,7 @@ def module(self) -> StakingModule:
modules: list[StakingModule] = self.w3.lido_validators.get_staking_modules(self._receive_last_finalized_slot())

for mod in modules:
if mod.staking_module_address == variables.CSM_MODULE_ADDRESS:
if mod.staking_module_address == self.w3.csm.module.address:
return mod

raise ValueError("No CSM module found. Wrong address?")
Expand Down
6 changes: 3 additions & 3 deletions src/modules/csm/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
from dataclasses import dataclass
from typing import Self, Sequence, TypeAlias

from hexbytes import HexBytes
from oz_merkle_tree import StandardMerkleTree

from src.web3py.extensions.lido_validators import NodeOperatorId


Leaf: TypeAlias = tuple[NodeOperatorId, int]


Expand All @@ -17,8 +17,8 @@ class Tree:
tree: StandardMerkleTree[Leaf]

@property
def root(self) -> str:
return self.tree.root.hex()
def root(self) -> HexBytes:
return HexBytes(self.tree.root)

@classmethod
def decode(cls, content: bytes) -> Self:
Expand Down

0 comments on commit 94fbec1

Please sign in to comment.