Skip to content

Commit

Permalink
feat: ao-multi-third-phase-partial
Browse files Browse the repository at this point in the history
  • Loading branch information
F4ever committed Jun 6, 2024
1 parent 1222ca1 commit 2778776
Show file tree
Hide file tree
Showing 10 changed files with 230 additions and 78 deletions.
10 changes: 5 additions & 5 deletions src/modules/accounting/accounting.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,15 +91,15 @@ def process_extra_data(self, blockstamp: ReferenceBlockStamp):
self._submit_extra_data(blockstamp)

def _submit_extra_data(self, blockstamp: ReferenceBlockStamp) -> None:
extra_data = self.get_extra_data(blockstamp, self.get_chain_config(blockstamp))
extra_data = self.get_extra_data(blockstamp)

if extra_data.format == FormatList.EXTRA_DATA_FORMAT_LIST_EMPTY:
tx = self.report_contract.submit_report_extra_data_empty()
return self.w3.transaction.check_and_send_transaction(tx, variables.ACCOUNT)

for tx_data in extra_data.extra_data_list:
tx = self.report_contract.submit_report_extra_data_list(tx_data)
self.w3.transaction.check_and_send_transaction(tx, variables.ACCOUNT)
else:
for tx_data in extra_data.extra_data_list:
tx = self.report_contract.submit_report_extra_data_list(tx_data)
self.w3.transaction.check_and_send_transaction(tx, variables.ACCOUNT)

@lru_cache(maxsize=1)
@duration_meter()
Expand Down
18 changes: 14 additions & 4 deletions src/modules/accounting/third_phase/extra_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,20 @@

from hexbytes import HexBytes

from src.modules.accounting.third_phase.types import ItemType, ExtraData, FormatList, ItemPayload, ExtraDataLengths
from src.modules.accounting.third_phase.types import ItemType, ExtraData, FormatList, ExtraDataLengths
from src.modules.submodules.types import ZERO_HASH
from src.types import NodeOperatorGlobalIndex
from src.web3py.types import Web3


@dataclass
class ItemPayload:
module_id: bytes
node_ops_count: bytes
node_operator_ids: bytes
vals_counts: bytes


@dataclass
class ExtraDataItem:
item_index: bytes
Expand Down Expand Up @@ -46,16 +54,18 @@ def collect(
extra_data_bytes = cls.to_bytes(extra_data)

if extra_data:
extra_data_list = [extra_data_bytes]
data_format = FormatList.EXTRA_DATA_FORMAT_LIST_NON_EMPTY
data_hash = Web3.keccak(extra_data_bytes)
else:
extra_data_list = []
data_format = FormatList.EXTRA_DATA_FORMAT_LIST_EMPTY
data_hash = ZERO_HASH
data_hash = HexBytes(ZERO_HASH)

return ExtraData(
extra_data_list=[extra_data_bytes],
extra_data_list=extra_data_list,
data_hash=data_hash,
format=data_format.value,
format=data_format,
items_count=len(extra_data),
)

Expand Down
74 changes: 50 additions & 24 deletions src/modules/accounting/third_phase/extra_data_v2.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,28 @@
import itertools
from dataclasses import dataclass
from itertools import groupby

from src.modules.accounting.third_phase.types import ExtraData, ItemType, ItemPayload, ExtraDataLengths, FormatList
from src.modules.accounting.third_phase.types import ExtraData, ItemType, ExtraDataLengths, FormatList
from src.modules.submodules.types import ZERO_HASH
from src.types import NodeOperatorGlobalIndex
from src.web3py.types import Web3


def batch(iterable: list, n: int):
"""
ToDo: Replace with batched from itertools when python 3.12
"""
l = len(iterable)
for ndx in range(0, l, n):
yield iterable[ndx:min(ndx + n, l)]


@dataclass
class ItemPayload:
module_id: int
node_operator_ids: list[int]
vals_counts: list[int]


class ExtraDataServiceV2:
"""
Service that encodes extra data into bytes in correct order.
Expand All @@ -31,7 +48,8 @@ def collect(
) -> ExtraData:
stuck_payloads = cls.build_validators_payloads(stuck_validators, max_no_in_payload_count)
exited_payloads = cls.build_validators_payloads(exited_validators, max_no_in_payload_count)
items_count, first_hash, txs = cls.build_extra_data_transactions_data(stuck_payloads, exited_payloads, max_items_count)
items_count, txs = cls.build_extra_transactions_data(stuck_payloads, exited_payloads, max_items_count)
first_hash, hashed_txs = cls.add_hashes_to_transactions(txs)

if items_count:
extra_data_format = FormatList.EXTRA_DATA_FORMAT_LIST_NON_EMPTY
Expand All @@ -40,7 +58,7 @@ def collect(

return ExtraData(
items_count=items_count,
extra_data_list=txs,
extra_data_list=hashed_txs,
data_hash=first_hash,
format=extra_data_format,
)
Expand All @@ -55,33 +73,32 @@ def build_validators_payloads(

payloads = []

for module_id, operators_by_module in itertools.groupby(operator_validators, key=lambda x: x[0][0]):
operator_ids = []
vals_count = []
for module_id, operators_by_module in groupby(operator_validators, key=lambda x: x[0][0]):
for nos_in_batch in batch(list(operators_by_module), max_no_in_payload_count):
operator_ids = []
vals_count = []

for nos_in_batch in itertools.batched(operators_by_module, max_no_in_payload_count):
for ((_, no_id), validators_count) in nos_in_batch:
operator_ids.append(no_id.to_bytes(ExtraDataLengths.NODE_OPERATOR_IDS))
vals_count.append(validators_count.to_bytes(ExtraDataLengths.STUCK_OR_EXITED_VALS_COUNT))
operator_ids.append(no_id)
vals_count.append(validators_count)

payloads.append(
ItemPayload(
module_id=module_id.to_bytes(ExtraDataLengths.MODULE_ID),
node_ops_count=len(operator_ids).to_bytes(ExtraDataLengths.NODE_OPS_COUNT),
node_operator_ids=b"".join(operator_ids),
vals_counts=b"".join(vals_count),
module_id=module_id,
node_operator_ids=operator_ids,
vals_counts=vals_count,
)
)

return payloads

@classmethod
def build_extra_data_transactions_data(
def build_extra_transactions_data(
cls,
stuck_payloads: list[ItemPayload],
exited_payloads: list[ItemPayload],
max_items_count: int,
) -> tuple[int, bytes, list[bytes]]:
) -> tuple[int, list[bytes]]:
all_payloads = [
*[(ItemType.EXTRA_DATA_TYPE_STUCK_VALIDATORS, payload) for payload in stuck_payloads],
*[(ItemType.EXTRA_DATA_TYPE_EXITED_VALIDATORS, payload) for payload in exited_payloads],
Expand All @@ -90,20 +107,27 @@ def build_extra_data_transactions_data(
index = 0
result = []

for batch in itertools.batched(all_payloads, max_items_count):
for payload_batch in batch(all_payloads, max_items_count):
tx = b''
for item_type, payload in batch:
for item_type, payload in payload_batch:
tx += index.to_bytes(ExtraDataLengths.ITEM_INDEX)
tx += item_type.value.to_bytes(ExtraDataLengths.ITEM_TYPE)
tx += payload.module_id.to_bytes(ExtraDataLengths.MODULE_ID)
tx += payload.node_ops_count.to_bytes(ExtraDataLengths.NODE_OPS_COUNT)
tx += payload.node_operator_ids
tx += payload.vals_counts
tx += len(payload.node_operator_ids).to_bytes(ExtraDataLengths.NODE_OPS_COUNT)
tx += b''.join(
no_id.to_bytes(ExtraDataLengths.NODE_OPERATOR_IDS)
for no_id in payload.node_operator_ids
)
tx += b''.join(
count.to_bytes(ExtraDataLengths.STUCK_OR_EXITED_VALS_COUNT)
for count in payload.vals_counts
)

index += 1

index += 1
result.append(tx)

return index, cls.add_hashes_to_transactions(result)[0], cls.add_hashes_to_transactions(result)[1]
return index, result

@staticmethod
def add_hashes_to_transactions(txs_data: list[bytes]) -> tuple[bytes, list[bytes]]:
Expand All @@ -115,6 +139,8 @@ def add_hashes_to_transactions(txs_data: list[bytes]) -> tuple[bytes, list[bytes
for tx in txs_data:
full_tx_data = next_hash + tx
txs_with_hashes.append(full_tx_data)
next_hash = Web3.keccak(next_hash)
next_hash = Web3.keccak(full_tx_data)

txs_with_hashes.reverse()

return next_hash, txs_with_hashes
10 changes: 0 additions & 10 deletions src/modules/accounting/third_phase/types.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from dataclasses import dataclass
from enum import Enum

from hexbytes import HexBytes


class ItemType(Enum):
EXTRA_DATA_TYPE_STUCK_VALIDATORS = 1
Expand All @@ -22,14 +20,6 @@ class ExtraData:
items_count: int


@dataclass
class ItemPayload:
module_id: bytes
node_ops_count: bytes
node_operator_ids: bytes
vals_counts: bytes


class ExtraDataLengths:
NEXT_HASH = 32
ITEM_INDEX = 3
Expand Down
4 changes: 2 additions & 2 deletions src/modules/accounting/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,13 @@ class OracleReportLimits:
max_node_operators_per_extra_data_item_count: int
request_timestamp_margin: int
max_positive_token_rebase: int
appeared_validators_per_day_limit: Optional[int]
appeared_validators_per_day_limit: Optional[int] = None

@classmethod
def from_response(cls, **kwargs) -> Self:
# churn_validators_per_day_limit was renamed in new version
# Unpack structure by order
return cls(*kwargs.values())
return cls(*kwargs.values()) # pylint: disable=no-value-for-parameter


@dataclass(frozen=True)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def get_oracle_report_limits(self, block_identifier: BlockIdentifier = 'latest')
Returns the limits list for the Lido's oracle report sanity checks
"""
response = self.functions.getOracleReportLimits().call(block_identifier=block_identifier)
response = named_tuple_to_dataclass(response, OracleReportLimits)
response = named_tuple_to_dataclass(response, OracleReportLimits.from_response)

logger.info({
'msg': 'Call `getOracleReportLimits()`.',
Expand Down
1 change: 0 additions & 1 deletion src/services/validator_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
ACCOUNTING_EXITED_VALIDATORS,
ACCOUNTING_DELAYED_VALIDATORS,
)
from src.modules.accounting.third_phase import ExtraDataService, ExtraData
from src.modules.submodules.types import ChainConfig
from src.types import BlockStamp, ReferenceBlockStamp, EpochNumber
from src.utils.events import get_events_in_past
Expand Down
27 changes: 6 additions & 21 deletions tests/modules/accounting/test_accounting_module.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from dataclasses import asdict
from typing import Any, Iterable, cast
from unittest.mock import Mock, patch

Expand All @@ -9,6 +8,7 @@
from src.modules.accounting import accounting as accounting_module
from src.modules.accounting.accounting import Accounting
from src.modules.accounting.accounting import logger as accounting_logger
from src.modules.accounting.third_phase.types import FormatList
from src.modules.accounting.types import LidoReportRebase
from src.modules.submodules.oracle_module import ModuleExecuteDelay
from src.modules.submodules.types import ChainConfig, FrameConfig
Expand Down Expand Up @@ -245,38 +245,24 @@ def test_submit_extra_data_non_empty(
):
extra_data = bytes(32)

accounting.get_chain_config = Mock(return_value=chain_config)
accounting.lido_validator_state_service.get_extra_data = Mock(return_value=Mock(extra_data=extra_data))
accounting.w3.lido_contracts.accounting_oracle.get_consensus_version = Mock(return_value=1)
accounting.get_extra_data = Mock(return_value=Mock(extra_data_list=[extra_data]))
accounting.report_contract.submit_report_extra_data_list = Mock() # type: ignore
accounting.w3.transaction = Mock()

accounting._submit_extra_data(ref_bs)

accounting.report_contract.submit_report_extra_data_list.assert_called_once_with(extra_data)
accounting.lido_validator_state_service.get_extra_data.assert_called_once_with(ref_bs, chain_config)
accounting.get_chain_config.assert_called_once_with(ref_bs)
accounting.get_extra_data.assert_called_once_with(ref_bs)

@pytest.mark.unit
@pytest.mark.parametrize(
("third_phase",),
[
(None,),
(bytes(0),),
([],),
(b'',),
('',),
(False,),
],
)
def test_submit_extra_data_empty(
self,
accounting: Accounting,
ref_bs: ReferenceBlockStamp,
chain_config: ChainConfig,
extra_data: Any,
):
accounting.get_chain_config = Mock(return_value=chain_config)
accounting.lido_validator_state_service.get_extra_data = Mock(return_value=Mock(extra_data=extra_data))
accounting.get_extra_data = Mock(return_value=Mock(format=FormatList.EXTRA_DATA_FORMAT_LIST_EMPTY))
accounting.report_contract.submit_report_extra_data_list = Mock() # type: ignore
accounting.report_contract.submit_report_extra_data_empty = Mock() # type: ignore
accounting.w3.transaction = Mock()
Expand All @@ -285,8 +271,7 @@ def test_submit_extra_data_empty(

accounting.report_contract.submit_report_extra_data_empty.assert_called_once()
accounting.report_contract.submit_report_extra_data_list.assert_not_called()
accounting.lido_validator_state_service.get_extra_data.assert_called_once_with(ref_bs, chain_config)
accounting.get_chain_config.assert_called_once_with(ref_bs)
accounting.get_extra_data.assert_called_once_with(ref_bs)


@pytest.mark.unit
Expand Down
25 changes: 15 additions & 10 deletions tests/modules/accounting/test_extra_data.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import pytest
from hexbytes import HexBytes

from src.modules.accounting.third_phase import ExtraDataService, ExtraData, FormatList
from src.web3py.extensions.lido_validators import NodeOperatorGlobalIndex, LidoValidator
from src.modules.accounting.third_phase.extra_data import ExtraDataService
from src.modules.accounting.third_phase.types import FormatList, ExtraData
from src.web3py.extensions.lido_validators import NodeOperatorGlobalIndex


pytestmark = pytest.mark.unit
Expand All @@ -21,8 +22,8 @@ class TestBuildValidators:
def test_collect_zero(self, extra_data_service, contracts):
extra_data = extra_data_service.collect({}, {}, 10, 10)
assert isinstance(extra_data, ExtraData)
assert extra_data.format == FormatList.EXTRA_DATA_FORMAT_LIST_EMPTY.value
assert extra_data.extra_data == b''
assert extra_data.format == FormatList.EXTRA_DATA_FORMAT_LIST_EMPTY
assert not extra_data.extra_data_list
assert extra_data.data_hash == HexBytes('0x0000000000000000000000000000000000000000000000000000000000000000')

def test_collect_non_zero(self, extra_data_service):
Expand All @@ -34,9 +35,10 @@ def test_collect_non_zero(self, extra_data_service):
}
extra_data = extra_data_service.collect(vals_stuck_non_zero, vals_exited_non_zero, 10, 10)
assert isinstance(extra_data, ExtraData)
assert extra_data.format == FormatList.EXTRA_DATA_FORMAT_LIST_NON_EMPTY.value
assert extra_data.format == FormatList.EXTRA_DATA_FORMAT_LIST_NON_EMPTY
assert len(extra_data.extra_data_list) == 1
assert (
extra_data.extra_data
extra_data.extra_data_list[0]
== b'\x00\x00\x00\x00\x01\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x01\x00\x02\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02'
)
assert extra_data.data_hash == HexBytes(
Expand Down Expand Up @@ -67,9 +69,10 @@ def test_collect_stuck_vals_in_cap(self, extra_data_service):
}
extra_data = extra_data_service.collect(vals_stuck_non_zero, vals_exited_non_zero, 1, 2)
assert isinstance(extra_data, ExtraData)
assert extra_data.format == FormatList.EXTRA_DATA_FORMAT_LIST_NON_EMPTY.value
assert extra_data.format == FormatList.EXTRA_DATA_FORMAT_LIST_NON_EMPTY
assert len(extra_data.extra_data_list) == 1
assert (
extra_data.extra_data
extra_data.extra_data_list[0]
== b'\x00\x00\x00\x00\x01\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01'
)
assert extra_data.data_hash == HexBytes(
Expand All @@ -82,10 +85,12 @@ def test_collect_stuck_vals_in_cap(self, extra_data_service):
item_length = 3 + 2 + 3 + 8
no_payload_length = 8 + 16
# Expecting one module
assert len(extra_data.extra_data) == item_length + no_payload_length * 2
assert len(extra_data.extra_data_list[0]) == item_length + no_payload_length * 2
# Expecting two modules
extra_data = extra_data_service.collect(vals_stuck_non_zero, vals_exited_non_zero, 2, 2)
assert len(extra_data.extra_data) == item_length + no_payload_length * 2 + item_length + no_payload_length
assert (
len(extra_data.extra_data_list[0]) == item_length + no_payload_length * 2 + item_length + no_payload_length
)

def test_order(self, extra_data_service, monkeypatch):
vals_order = {
Expand Down
Loading

0 comments on commit 2778776

Please sign in to comment.