Skip to content

Commit

Permalink
refactor: use fork epoch instead of querying a fork version
Browse files Browse the repository at this point in the history
  • Loading branch information
madlabman committed Jan 8, 2025
1 parent 9f060a5 commit 43f853b
Show file tree
Hide file tree
Showing 7 changed files with 16 additions and 75 deletions.
4 changes: 2 additions & 2 deletions src/modules/ejector/ejector.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,9 +234,9 @@ def _get_predicted_withdrawable_epoch(
"""
Returns epoch when all validators in queue and validators_to_eject will be withdrawn.
"""
fork = self.fork(blockstamp)
spec = self.w3.cc.get_config_spec()

if fork < fork.ELECTRA:
if blockstamp.ref_epoch < int(spec.ELECTRA_FORK_EPOCH):
return self._get_predicted_withdrawable_epoch_pre_electra(blockstamp, validators_to_eject)

return self._get_predicted_withdrawable_epoch_post_electra(blockstamp, validators_to_eject)
Expand Down
4 changes: 0 additions & 4 deletions src/modules/submodules/consensus.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,6 @@ def _get_consensus_contract_members(self, blockstamp: BlockStamp):
def consensus_version(self, blockstamp: BlockStamp):
return self.report_contract.get_consensus_version(blockstamp.block_hash)

@lru_cache(maxsize=1)
def fork(self, blockstamp: BlockStamp):
return self.w3.cc.get_fork(blockstamp.slot_number)

@lru_cache(maxsize=1)
def get_chain_config(self, blockstamp: BlockStamp) -> ChainConfig:
consensus_contract = self._get_consensus_contract(blockstamp)
Expand Down
35 changes: 2 additions & 33 deletions src/providers/consensus/client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from enum import StrEnum
from http import HTTPStatus
from typing import Literal, cast

Expand All @@ -18,10 +17,9 @@
SlotAttestationCommittee, BlockAttestation,
)
from src.providers.http_provider import HTTPProvider, NotOkResponse
from src.types import BlockRoot, BlockStamp, Fork, SlotNumber, EpochNumber, StateRoot
from src.types import BlockRoot, BlockStamp, SlotNumber, EpochNumber, StateRoot
from src.utils.dataclass import list_of_dataclasses
from src.utils.cache import global_lru_cache as lru_cache
from src.utils.types import is_4bytes_hex

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -56,40 +54,11 @@ class ConsensusClient(HTTPProvider):
API_GET_FORK = '/eth/v1/beacon/states/{}/fork'

def get_config_spec(self) -> BeaconSpecResponse:
data = self._get_raw_spec()
return BeaconSpecResponse.from_response(**data)

def get_fork(self, state_id: LiteralState | SlotNumber) -> Fork:
data, _ = self._get(
self.API_GET_FORK,
path_params=(state_id,),
)
if not isinstance(data, dict):
raise ValueError("Expected mapping response from getFork")

current_version = data["current_version"]
return self._forks()(current_version) # type: ignore[operator]

def _forks(self) -> Fork:
spec = self._get_raw_spec()

versions = {}
for k, v in spec.items():
if k.endswith("FORK_VERSION"):
if not is_4bytes_hex(v):
raise ValueError(f"Got invalid fork version {v}")
versions[k.split("_")[0].upper()] = v

if not versions:
raise ValueError("No forks defined in the spec")
return cast(Fork, StrEnum("Fork", versions.items()))

def _get_raw_spec(self) -> dict[str, str]:
"""Spec: https://ethereum.github.io/beacon-APIs/#/Config/getSpec"""
data, _ = self._get(self.API_GET_SPEC)
if not isinstance(data, dict):
raise ValueError("Expected mapping response from getSpec")
return data
return BeaconSpecResponse.from_response(**data)

def get_genesis(self):
"""
Expand Down
2 changes: 2 additions & 0 deletions src/providers/consensus/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from src.types import BlockHash, BlockRoot, Gwei, SlotNumber, StateRoot
from src.utils.dataclass import Nested, FromResponse
from src.constants import FAR_FUTURE_EPOCH


@dataclass
Expand All @@ -12,6 +13,7 @@ class BeaconSpecResponse(FromResponse):
SECONDS_PER_SLOT: str
DEPOSIT_CONTRACT_ADDRESS: str
SLOTS_PER_HISTORICAL_ROOT: str
ELECTRA_FORK_EPOCH: str = str(FAR_FUTURE_EPOCH)


@dataclass
Expand Down
16 changes: 1 addition & 15 deletions src/types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from enum import StrEnum, auto
from enum import StrEnum
from typing import NewType

from eth_typing import BlockNumber, ChecksumAddress, HexStr
Expand Down Expand Up @@ -38,20 +38,6 @@ class OracleModule(StrEnum):
type OperatorsValidatorCount = dict[NodeOperatorGlobalIndex, int]


class _Fork(StrEnum):
"""We store fork versions as an enum of hex encoded 4 bytes, so the values are comparable as strings"""

GENESIS = auto()
ALTAIR = auto()
BELLATRIX = auto()
CAPELLA = auto()
DENEB = auto()
ELECTRA = auto()


type Fork = _Fork


@dataclass(frozen=True)
class BlockStamp:
state_root: StateRoot
Expand Down
9 changes: 5 additions & 4 deletions tests/modules/ejector/test_ejector.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@
from src.modules.submodules.oracle_module import ModuleExecuteDelay
from src.modules.submodules.types import ChainConfig, CurrentFrame
from src.providers.consensus.types import BeaconStateView
from src.types import BlockStamp, Gwei, ReferenceBlockStamp, SlotNumber
from src.types import BlockStamp, Gwei, ReferenceBlockStamp
from src.types import _Fork as Fork
from src.utils import validator_state
from src.web3py.extensions.contracts import LidoContracts
from src.web3py.extensions.lido_validators import LidoValidator, NodeOperatorId, StakingModuleId
from src.web3py.extensions.lido_validators import NodeOperatorId, StakingModuleId
from src.web3py.types import Web3
from tests.factory.base_oracle import EjectorProcessingStateFactory
from tests.factory.blockstamp import BlockStampFactory, ReferenceBlockStampFactory
Expand Down Expand Up @@ -233,7 +233,8 @@ def test_is_contract_reportable(ejector: Ejector, blockstamp: BlockStamp) -> Non

@pytest.mark.unit
def test_get_predicted_withdrawable_epoch_pre_electra(ejector: Ejector) -> None:
ejector.fork = Mock(return_value=Fork.CAPELLA)
ejector.w3.cc = Mock()
ejector.w3.cc.get_config_spec = Mock(return_value=Mock(ELECTRA_FORK_EPOCH=FAR_FUTURE_EPOCH))
ejector._get_latest_exit_epoch = Mock(return_value=[1, 32])
ejector._get_churn_limit = Mock(return_value=2)
ref_blockstamp = ReferenceBlockStampFactory.build(ref_epoch=3546)
Expand Down Expand Up @@ -318,8 +319,8 @@ def test_exit_exceeds_churn_limit(self, ejector: Ejector, ref_blockstamp: Refere

@pytest.fixture(autouse=True)
def _patch_ejector(self, ejector: Ejector):
ejector.fork = Mock(return_value=Fork.ELECTRA)
ejector.w3.cc = Mock()
ejector.w3.cc.get_config_spec = Mock(return_value=Mock(ELECTRA_FORK_EPOCH=0))


@pytest.mark.unit
Expand Down
21 changes: 4 additions & 17 deletions tests/providers/consensus/test_consensus_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from src.providers.consensus.client import ConsensusClient
from src.providers.consensus.types import Validator
from src.types import SlotNumber, _Fork
from src.types import SlotNumber
from src.utils.blockstamp import build_blockstamp
from src.variables import CONSENSUS_CLIENT_URI
from tests.factory.blockstamp import BlockStampFactory
Expand Down Expand Up @@ -71,20 +71,15 @@ def test_get_validators(consensus_client: ConsensusClient):
assert validator_by_pub_key[0] == validator


@pytest.mark.integration
def test_get_fork(consensus_client: ConsensusClient):
fork = consensus_client.get_fork("head")
assert fork >= fork.GENESIS


@pytest.mark.integration
@pytest.mark.skip(reason="Too long to complete in CI")
def test_get_state_view(consensus_client: ConsensusClient):
state_view = consensus_client.get_state_view("head")
assert state_view.slot > 0

fork = consensus_client.get_fork(state_view.slot)
if fork >= fork.ELECTRA:
spec = consensus_client.get_config_spec()
epoch = state_view.slot // 32
if epoch >= int(spec.ELECTRA_FORK_EPOCH):
assert state_view.earliest_exit_epoch != 0
assert state_view.exit_balance_to_consume >= 0

Expand Down Expand Up @@ -119,11 +114,3 @@ def test_get_returns_nor_dict_nor_list(consensus_client: ConsensusClient):

with raises:
consensus_client._get_chain_id_with_provider(0)


@pytest.mark.unit
def test_unknown_fork(consensus_client: ConsensusClient):
consensus_client._get = Mock(return_value=[{"current_version": "UNKNOWN"}, None])
consensus_client._forks = Mock(return_value=_Fork)
with pytest.raises(ValueError, match="'UNKNOWN' is not a valid _Fork"):
consensus_client.get_fork("head")

0 comments on commit 43f853b

Please sign in to comment.