Skip to content

Commit

Permalink
Merge pull request #622 from lidofinance/feat/update-ejector-predicate
Browse files Browse the repository at this point in the history
feat: Update ejector network penetraction logic
  • Loading branch information
F4ever authored Feb 12, 2025
2 parents 24a2219 + 3490475 commit 4ed5fa4
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 44 deletions.
1 change: 1 addition & 0 deletions src/modules/ejector/ejector.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def get_validators_to_eject(self, blockstamp: ReferenceBlockStamp) -> list[tuple
chain_config = self.get_chain_config(blockstamp)
validators_iterator = iter(ValidatorExitIterator(
w3=self.w3,
consensus_version=self.get_consensus_version(blockstamp),
blockstamp=blockstamp,
seconds_per_slot=chain_config.seconds_per_slot
))
Expand Down
56 changes: 46 additions & 10 deletions src/services/exit_order_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@

from more_itertools import ilen

from src.constants import TOTAL_BASIS_POINTS
from src.constants import TOTAL_BASIS_POINTS, LIDO_DEPOSIT_AMOUNT
from src.metrics.prometheus.duration_meter import duration_meter
from src.providers.consensus.types import Validator
from src.services.validator_state import LidoValidatorStateService
from src.types import ReferenceBlockStamp, NodeOperatorGlobalIndex, StakingModuleId
from src.types import ReferenceBlockStamp, NodeOperatorGlobalIndex, StakingModuleId, Gwei
from src.utils.validator_state import is_on_exit, get_validator_age
from src.web3py.extensions.lido_validators import LidoValidator, StakingModule, NodeOperator, NodeOperatorLimitMode
from src.web3py.types import Web3


logger = logging.getLogger(__name__)


Expand All @@ -28,6 +28,7 @@ class NodeOperatorStats:
module_stats: StakingModuleStats

predictable_validators: int = 0
predictable_effective_balance: Gwei = Gwei(0)
delayed_validators: int = 0
total_age: int = 0
force_exit_to: int | None = None
Expand Down Expand Up @@ -60,10 +61,19 @@ class ValidatorExitIterator:

max_validators_to_exit: int = 0
no_penetration_threshold: float = 0
eth_validators_count: int = 0

def __init__(self, w3: Web3, blockstamp: ReferenceBlockStamp, seconds_per_slot: int):
eth_validators_count: int = 0
eth_validators_effective_balance: Gwei = Gwei(0)

def __init__(
self,
w3: Web3,
consensus_version: int,
blockstamp: ReferenceBlockStamp,
seconds_per_slot: int,
):
self.w3 = w3
self.consensus_version = consensus_version
self.blockstamp = blockstamp
self.seconds_per_slot = seconds_per_slot

Expand Down Expand Up @@ -138,6 +148,9 @@ def _calculate_lido_stats(self):
self.total_lido_validators += no_predictable_validators
self.module_stats[gid[0]].predictable_validators += no_predictable_validators
self.node_operators_stats[gid].predictable_validators = no_predictable_validators
self.node_operators_stats[gid].predictable_effective_balance = (
self._calculate_effective_balance_non_exiting_validators(validators) + transient_validators_count * LIDO_DEPOSIT_AMOUNT
)

self.node_operators_stats[gid].delayed_validators = delayed_validators[gid]
self.node_operators_stats[gid].total_age = self.calculate_validators_age(validators)
Expand All @@ -159,6 +172,18 @@ def _load_blockchain_state(self):

self.eth_validators_count = ilen(v for v in self.w3.cc.get_validators(self.blockstamp) if not is_on_exit(v))

self.eth_validators_effective_balance = self._calculate_effective_balance_non_exiting_validators(self.w3.cc.get_validators(self.blockstamp))

@staticmethod
def _calculate_effective_balance_non_exiting_validators(validators: list[Validator]) -> Gwei:
return sum(
(
v.validator.effective_balance for v in validators
if not is_on_exit(v)
),
Gwei(0),
)

def get_filter_non_exitable_validators(self, gid: NodeOperatorGlobalIndex):
"""Validators that are presented but not yet activated on CL can be requested to exit in advance."""
indexes = self.lvs.get_operators_with_last_exited_validator_indexes(self.blockstamp)
Expand Down Expand Up @@ -199,29 +224,32 @@ def calculate_validators_age(self, validators: list[LidoValidator]) -> int:
return result

def _eject_validator(self, gid: NodeOperatorGlobalIndex) -> LidoValidator:
validator = self.exitable_validators[gid].pop(0)
lido_validator = self.exitable_validators[gid].pop(0)

# Total validators
self.eth_validators_count -= 1
self.eth_validators_effective_balance -= lido_validator.validator.effective_balance # type: ignore
# Change lido total
self.total_lido_validators -= 1
# Change module total
self.module_stats[gid[0]].predictable_validators -= 1
# Change node operator stats
self.node_operators_stats[gid].predictable_validators -= 1
self.node_operators_stats[gid].total_age -= get_validator_age(validator, self.blockstamp.ref_epoch)
self.node_operators_stats[gid].predictable_effective_balance -= lido_validator.validator.effective_balance # type: ignore
self.node_operators_stats[gid].total_age -= get_validator_age(lido_validator, self.blockstamp.ref_epoch)

logger.debug({
'msg': 'Iterator state change. Eject validator.',
'eth_validators_count': self.eth_validators_count,
'eth_validators_effective_balance': self.eth_validators_effective_balance,
'total_lido_validators': self.total_lido_validators,
'no_gid': gid[0],
'module_stats': self.module_stats[gid[0]].predictable_validators,
'no_stats_exitable_validators': self.node_operators_stats[gid].predictable_validators,
'no_stats_total_age': self.node_operators_stats[gid].total_age,
})

return validator
return lido_validator

def _no_predicate(self, node_operator: NodeOperatorStats) -> tuple:
return (
Expand All @@ -232,7 +260,9 @@ def _no_predicate(self, node_operator: NodeOperatorStats) -> tuple:
- self._stake_weight_coefficient_predicate(
node_operator,
self.eth_validators_count,
self.eth_validators_effective_balance,
self.no_penetration_threshold,
self.consensus_version > 2 and self.w3.cc.is_electra_activated(self.blockstamp.ref_epoch),
),
- node_operator.predictable_validators,
self._lowest_validator_index_predicate(node_operator),
Expand Down Expand Up @@ -272,13 +302,19 @@ def _max_share_rate_coefficient_predicate(self, node_operator: NodeOperatorStats
def _stake_weight_coefficient_predicate(
node_operator: NodeOperatorStats,
total_validators: int,
total_effective_balance: Gwei,
no_penetration: float,
is_post_pectra: bool,
) -> int:
"""
The higher coefficient the higher priority to eject validator
"""
if total_validators * no_penetration < node_operator.predictable_validators:
return node_operator.total_age
if is_post_pectra:
if total_effective_balance * no_penetration < node_operator.predictable_effective_balance:
return node_operator.total_age
else:
if total_validators * no_penetration < node_operator.predictable_validators:
return node_operator.total_age

return 0

Expand Down
7 changes: 6 additions & 1 deletion tests/factory/no_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@
FAR_FUTURE_EPOCH,
MAX_EFFECTIVE_BALANCE,
MIN_ACTIVATION_BALANCE,
GWEI_TO_WEI,
)
from src.providers.consensus.types import PendingDeposit, Validator, ValidatorState
from src.providers.keys.types import LidoKey
from src.types import Gwei
from src.web3py.extensions.lido_validators import LidoValidator, NodeOperator, StakingModule
from tests.factory.web3_factory import Web3Factory

Expand Down Expand Up @@ -78,7 +80,10 @@ class LidoValidatorFactory(Web3Factory):
@classmethod
def build_with_activation_epoch_bound(cls, max_value: int, **kwargs: Any):
return cls.build(
validator=ValidatorStateFactory.build(activation_epoch=faker.pyint(max_value=max_value - 1)), **kwargs
validator=ValidatorStateFactory.build(
activation_epoch=faker.pyint(max_value=max_value - 1), effective_balance=Gwei(32 * 10**9)
),
**kwargs,
)

@classmethod
Expand Down
95 changes: 62 additions & 33 deletions tests/modules/ejector/test_validator_exit_order_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest

from src.services.exit_order_iterator import ValidatorExitIterator, NodeOperatorStats, StakingModuleStats
from src.types import Gwei
from src.web3py.extensions.lido_validators import NodeOperatorLimitMode
from tests.factory.blockstamp import ReferenceBlockStampFactory
from tests.factory.no_registry import (
Expand All @@ -29,6 +30,7 @@ class NodeOperatorStatsFactory(Web3Factory):
def iterator(web3, contracts, lido_validators):
return ValidatorExitIterator(
web3,
2,
ReferenceBlockStampFactory.build(),
12,
)
Expand Down Expand Up @@ -161,7 +163,7 @@ def test_eject_validator(iterator):
assert iterator.module_stats[1].predictable_validators == 5
assert iterator.module_stats[2].predictable_validators == 2
assert iterator.node_operators_stats[(1, 1)].predictable_validators == 3
assert iterator.node_operators_stats[(1, 1)].delayed_validators == 1
assert iterator.node_operators_stats[(1, 1)].predictable_effective_balance == 3 * 32 * 10**9
assert iterator.node_operators_stats[(1, 1)].delayed_validators == 1
assert iterator.node_operators_stats[(1, 2)].soft_exit_to is not None
assert iterator.node_operators_stats[(2, 1)].force_exit_to is not None
Expand All @@ -175,11 +177,12 @@ def test_eject_validator(iterator):
assert iterator.total_lido_validators == 6
assert iterator.module_stats[1].predictable_validators == 4
assert iterator.node_operators_stats[(1, 1)].predictable_validators == 2
assert iterator.node_operators_stats[(1, 1)].predictable_effective_balance == 2 * 32 * 10**9
assert iterator.node_operators_stats[(1, 1)].total_age < prev_total_age

iterator.max_validators_to_exit = 3
iterator.no_penetration_threshold = 0.1
iterator.eth_validators_count = 1000
iterator.eth_validators_effective_balance = Gwei(1000 * 32 * 10**9)
iterator._load_blockchain_state = Mock()

validators_to_eject = list(iterator)
Expand All @@ -198,6 +201,7 @@ def test_eject_validator(iterator):

@pytest.mark.unit
def test_no_predicate(iterator):
iterator.eth_validators_effective_balance = Gwei(1000 * 32 * 10**9)
iterator.total_lido_validators = 1000
iterator.no_penetration_threshold = 0.1
iterator.eth_validators_count = 10000
Expand All @@ -207,38 +211,57 @@ def test_no_predicate(iterator):
(2, 2): [LidoValidatorFactory.build(index=20)],
}

result = iterator._no_predicate(
NodeOperatorStatsFactory.build(
predictable_validators=100,
delayed_validators=1,
total_age=1000,
force_exit_to=50,
soft_exit_to=25,
node_operator=NodeOperatorFactory.build(id=1, staking_module=StakingModuleFactory.build(id=1)),
module_stats=ModuleStatsFactory.build(
predictable_validators=200,
staking_module=StakingModuleFactory.build(priority_exit_share_threshold=0.15 * 1000),
),
)
node_operator_1 = NodeOperatorStatsFactory.build(
predictable_validators=100,
predictable_effective_balance=Gwei(2000 * 32 * 10**9),
delayed_validators=1,
total_age=1000,
force_exit_to=50,
soft_exit_to=25,
node_operator=NodeOperatorFactory.build(id=1, staking_module=StakingModuleFactory.build(id=1)),
module_stats=ModuleStatsFactory.build(
predictable_validators=200,
staking_module=StakingModuleFactory.build(priority_exit_share_threshold=0.15 * 1000),
),
)

result = iterator._no_predicate(node_operator_1)

assert result == (1, -50, -75, -185, 0, -100, 10)

result = iterator._no_predicate(
NodeOperatorStatsFactory.build(
predictable_validators=2000,
delayed_validators=0,
total_age=1000,
force_exit_to=50,
soft_exit_to=25,
node_operator=NodeOperatorFactory.build(id=2, staking_module=StakingModuleFactory.build(id=2)),
module_stats=ModuleStatsFactory.build(
predictable_validators=200,
staking_module=StakingModuleFactory.build(priority_exit_share_threshold=0.15 * 1000),
),
)
node_operator_2 = NodeOperatorStatsFactory.build(
predictable_validators=2000,
predictable_effective_balance=Gwei(100 * 32 * 10**9),
delayed_validators=0,
total_age=1000,
force_exit_to=50,
soft_exit_to=25,
node_operator=NodeOperatorFactory.build(id=2, staking_module=StakingModuleFactory.build(id=2)),
module_stats=ModuleStatsFactory.build(
predictable_validators=200,
staking_module=StakingModuleFactory.build(priority_exit_share_threshold=0.15 * 1000),
),
)

result = iterator._no_predicate(node_operator_2)
assert result == (0, -1950, -1975, -185, -1000, -2000, 20)

iterator.consensus_version = 3
iterator.w3.cc.is_electra_activated = Mock(return_value=False)

# Check works with old alg before pectra
result = iterator._no_predicate(node_operator_2)
assert result == (0, -1950, -1975, -185, -1000, -2000, 20)

iterator.w3.cc.is_electra_activated = Mock(return_value=True)

# Check after pectra
result = iterator._no_predicate(node_operator_2)
assert result == (0, -1950, -1975, -185, 0, -2000, 20)

result = iterator._no_predicate(node_operator_1)
assert result == (1, -50, -75, -185, -1000, -100, 10)


@pytest.mark.unit
def test_no_force_and_soft_predicate(iterator):
Expand Down Expand Up @@ -315,25 +338,31 @@ def test_stake_weight_coefficient_predicate(iterator):
nos = [
NodeOperatorStatsFactory.build(
predictable_validators=900,
predictable_effective_balance=900 * 32 * 10**9,
total_age=3000,
),
NodeOperatorStatsFactory.build(
predictable_validators=1010,
predictable_effective_balance=1010 * 32 * 10**9,
total_age=2000,
),
NodeOperatorStatsFactory.build(
predictable_validators=2010,
predictable_effective_balance=2010 * 32 * 10**9,
total_age=1000,
),
]

sorted_nos = sorted(
nos,
key=lambda x: -iterator._stake_weight_coefficient_predicate(
x,
10000,
0.1,
),
key=lambda x: -iterator._stake_weight_coefficient_predicate(x, 10000, 10000 * 32 * 10**9, 0.1, False),
)

assert [nos[1], nos[2], nos[0]] == sorted_nos

sorted_nos = sorted(
nos,
key=lambda x: -iterator._stake_weight_coefficient_predicate(x, 10000, 10000 * 32 * 10**9, 0.1, True),
)

assert [nos[1], nos[2], nos[0]] == sorted_nos
Expand Down

0 comments on commit 4ed5fa4

Please sign in to comment.