From 1aa722885ffbda44f09d8417864f62a307c27355 Mon Sep 17 00:00:00 2001 From: vgorkavenko Date: Thu, 13 Feb 2025 12:43:14 +0100 Subject: [PATCH] refactor: `State` and tests --- src/modules/csm/checkpoint.py | 4 +- src/modules/csm/state.py | 135 ++++---- tests/modules/csm/test_state.py | 592 ++++++++++++++++++++------------ 3 files changed, 445 insertions(+), 286 deletions(-) diff --git a/src/modules/csm/checkpoint.py b/src/modules/csm/checkpoint.py index b111fe197..69d0a79dd 100644 --- a/src/modules/csm/checkpoint.py +++ b/src/modules/csm/checkpoint.py @@ -205,12 +205,12 @@ def _check_duty( for root in block_roots: attestations = self.cc.get_block_attestations(root) process_attestations(attestations, committees, self.eip7549_supported) - + frame = self.state.find_frame(duty_epoch) with lock: for committee in committees.values(): for validator_duty in committee: self.state.increment_duty( - duty_epoch, + frame, validator_duty.index, included=validator_duty.included, ) diff --git a/src/modules/csm/state.py b/src/modules/csm/state.py index fd27a8d62..c269b7fcb 100644 --- a/src/modules/csm/state.py +++ b/src/modules/csm/state.py @@ -13,8 +13,6 @@ logger = logging.getLogger(__name__) -type Frame = tuple[EpochNumber, EpochNumber] - class InvalidState(ValueError): """State has data considered as invalid for a report""" @@ -36,6 +34,10 @@ def add_duty(self, included: bool) -> None: self.included += 1 if included else 0 +type Frame = tuple[EpochNumber, EpochNumber] +type StateData = dict[Frame, defaultdict[ValidatorIndex, AttestationsAccumulator]] + + class State: """ Processing state of a CSM performance oracle frame. @@ -46,7 +48,7 @@ class State: The state can be migrated to be used for another frame's report by calling the `migrate` method. """ - data: dict[Frame, defaultdict[ValidatorIndex, AttestationsAccumulator]] + data: StateData _epochs_to_process: tuple[EpochNumber, ...] 
_processed_epochs: set[EpochNumber] @@ -54,10 +56,8 @@ class State: _consensus_version: int = 1 - def __init__(self, data: dict[Frame, dict[ValidatorIndex, AttestationsAccumulator]] | None = None) -> None: - self.data = { - frame: defaultdict(AttestationsAccumulator, validators) for frame, validators in (data or {}).items() - } + def __init__(self) -> None: + self.data = {} self._epochs_to_process = tuple() self._processed_epochs = set() self._epochs_per_frame = 0 @@ -110,6 +110,16 @@ def unprocessed_epochs(self) -> set[EpochNumber]: def is_fulfilled(self) -> bool: return not self.unprocessed_epochs + @staticmethod + def calculate_frames(epochs_to_process: tuple[EpochNumber, ...], epochs_per_frame: int) -> list[Frame]: + """Split epochs to process into frames of `epochs_per_frame` length""" + frames = [] + for frame_epochs in batched(epochs_to_process, epochs_per_frame): + if len(frame_epochs) < epochs_per_frame: + raise ValueError("Insufficient epochs to form a frame") + frames.append((frame_epochs[0], frame_epochs[-1])) + return frames + def clear(self) -> None: self.data = {} self._epochs_to_process = tuple() @@ -123,9 +133,10 @@ def find_frame(self, epoch: EpochNumber) -> Frame: return epoch_range raise ValueError(f"Epoch {epoch} is out of frames range: {frames}") - def increment_duty(self, epoch: EpochNumber, val_index: ValidatorIndex, included: bool) -> None: - epoch_range = self.find_frame(epoch) - self.data[epoch_range][val_index].add_duty(included) + def increment_duty(self, frame: Frame, val_index: ValidatorIndex, included: bool) -> None: + if frame not in self.data: + raise ValueError(f"Frame {frame} is not found in the state") + self.data[frame][val_index].add_duty(included) def add_processed_epoch(self, epoch: EpochNumber) -> None: self._processed_epochs.add(epoch) @@ -133,7 +144,9 @@ def add_processed_epoch(self, epoch: EpochNumber) -> None: def log_progress(self) -> None: logger.info({"msg": f"Processed {len(self._processed_epochs)} of {len(self._epochs_to_process)} epochs"}) - def init_or_migrate(self, l_epoch: EpochNumber, r_epoch: EpochNumber, epochs_per_frame: int, consensus_version: int) -> None: + def init_or_migrate( + self, l_epoch: EpochNumber, r_epoch: EpochNumber, epochs_per_frame: int, consensus_version: int + ) -> None: if consensus_version != self._consensus_version: logger.warning( { @@ -143,59 +156,55 @@ def init_or_migrate(self, l_epoch: EpochNumber, r_epoch: EpochNumber, epochs_per ) self.clear() + frames = self.calculate_frames(tuple(sequence(l_epoch, r_epoch)), epochs_per_frame) + frames_data: StateData = {frame: defaultdict(AttestationsAccumulator) for frame in frames} + if not self.is_empty: - invalidated = self._migrate_or_invalidate(l_epoch, r_epoch, epochs_per_frame) - if invalidated: - self.clear() + cached_frames = self.calculate_frames(self._epochs_to_process, self._epochs_per_frame) + if cached_frames == frames: + logger.info({"msg": "No need to migrate duties data cache"}) + return + + frames_data, migration_status = self._migrate_frames_data(cached_frames, frames) + + for current_frame, migrated in migration_status.items(): + if not migrated: + logger.warning({"msg": f"Invalidating frame duties data cache: {current_frame}"}) + for epoch in sequence(*current_frame): + self._processed_epochs.discard(epoch) - self._fill_frames(l_epoch, r_epoch, epochs_per_frame) + self.data = frames_data self._epochs_per_frame = epochs_per_frame self._epochs_to_process = tuple(sequence(l_epoch, r_epoch)) self._consensus_version = consensus_version self.commit() - 
def _fill_frames(self, l_epoch: EpochNumber, r_epoch: EpochNumber, epochs_per_frame: int) -> None: - frames = self.calculate_frames(tuple(sequence(l_epoch, r_epoch)), epochs_per_frame) - for frame in frames: - self.data.setdefault(frame, defaultdict(AttestationsAccumulator)) - - def _migrate_or_invalidate(self, l_epoch: EpochNumber, r_epoch: EpochNumber, epochs_per_frame: int) -> bool: - current_frames = self.calculate_frames(self._epochs_to_process, self._epochs_per_frame) - new_frames = self.calculate_frames(tuple(sequence(l_epoch, r_epoch)), epochs_per_frame) - inv_msg = f"Discarding invalid state cache because of frames change. {current_frames=}, {new_frames=}" - - if self._invalidate_on_epoch_range_change(l_epoch, r_epoch): - logger.warning({"msg": inv_msg}) - return True - - frame_expanded = epochs_per_frame > self._epochs_per_frame - frame_shrunk = epochs_per_frame < self._epochs_per_frame - - has_single_frame = len(current_frames) == len(new_frames) == 1 - - if has_single_frame and frame_expanded: - current_frame, *_ = current_frames - new_frame, *_ = new_frames - self.data[new_frame] = self.data.pop(current_frame) - logger.info({"msg": f"Migrated state cache to a new frame. {current_frame=}, {new_frame=}"}) - return False - - if has_single_frame and frame_shrunk: - logger.warning({"msg": inv_msg}) - return True - - if not has_single_frame and frame_expanded or frame_shrunk: - logger.warning({"msg": inv_msg}) - return True - - return False - - def _invalidate_on_epoch_range_change(self, l_epoch: EpochNumber, r_epoch: EpochNumber) -> bool: - """Check if the epoch range has been invalidated.""" - for epoch_set in (self._epochs_to_process, self._processed_epochs): - if any(epoch < l_epoch or epoch > r_epoch for epoch in epoch_set): - return True - return False + def _migrate_frames_data( + self, current_frames: list[Frame], new_frames: list[Frame] + ) -> tuple[StateData, dict[Frame, bool]]: + migration_status = {frame: False for frame in current_frames} + new_data: StateData = {frame: defaultdict(AttestationsAccumulator) for frame in new_frames} + + logger.info({"msg": f"Migrating duties data cache: {current_frames=} -> {new_frames=}"}) + + for current_frame in current_frames: + curr_frame_l_epoch, curr_frame_r_epoch = current_frame + for new_frame in new_frames: + if current_frame == new_frame: + new_data[new_frame] = self.data[current_frame] + migration_status[current_frame] = True + break + + new_frame_l_epoch, new_frame_r_epoch = new_frame + if curr_frame_l_epoch >= new_frame_l_epoch and curr_frame_r_epoch <= new_frame_r_epoch: + logger.info({"msg": f"Migrating frame duties data cache: {current_frame=} -> {new_frame=}"}) + for val in self.data[current_frame]: + new_data[new_frame][val].assigned += self.data[current_frame][val].assigned + new_data[new_frame][val].included += self.data[current_frame][val].included + migration_status[current_frame] = True + break + + return new_data, migration_status def validate(self, l_epoch: EpochNumber, r_epoch: EpochNumber) -> None: if not self.is_fulfilled: @@ -209,21 +218,11 @@ def validate(self, l_epoch: EpochNumber, r_epoch: EpochNumber) -> None: if epoch not in self._processed_epochs: raise InvalidState(f"Epoch {epoch} missing in processed epochs") - @staticmethod - def calculate_frames(epochs_to_process: tuple[EpochNumber, ...], epochs_per_frame: int) -> list[Frame]: - """Split epochs to process into frames of `epochs_per_frame` length""" - frames = [] - for frame_epochs in batched(epochs_to_process, epochs_per_frame): - if 
len(frame_epochs) < epochs_per_frame: - raise ValueError("Insufficient epochs to form a frame") - frames.append((frame_epochs[0], frame_epochs[-1])) - return frames - def get_network_aggr(self, frame: Frame) -> AttestationsAccumulator: # TODO: exclude `active_slashed` validators from the calculation included = assigned = 0 frame_data = self.data.get(frame) - if not frame_data: + if frame_data is None: raise ValueError(f"No data for frame {frame} to calculate network aggregate") for validator, acc in frame_data.items(): if acc.included > acc.assigned: diff --git a/tests/modules/csm/test_state.py b/tests/modules/csm/test_state.py index d781522e2..b5d8f8808 100644 --- a/tests/modules/csm/test_state.py +++ b/tests/modules/csm/test_state.py @@ -1,258 +1,418 @@ +import os +import pickle +from collections import defaultdict from pathlib import Path from unittest.mock import Mock import pytest -from src.modules.csm.state import AttestationsAccumulator, State -from src.types import EpochNumber, ValidatorIndex +from src import variables +from src.modules.csm.state import AttestationsAccumulator, State, InvalidState +from src.types import ValidatorIndex from src.utils.range import sequence -@pytest.fixture() -def state_file_path(tmp_path: Path) -> Path: - return (tmp_path / "mock").with_suffix(State.EXTENSION) +@pytest.fixture(autouse=True) +def remove_state_files(): + state_file = Path("/tmp/state.pkl") + state_buf = Path("/tmp/state.buf") + state_file.unlink(missing_ok=True) + state_buf.unlink(missing_ok=True) + yield + state_file.unlink(missing_ok=True) + state_buf.unlink(missing_ok=True) + + +def test_load_restores_state_from_file(monkeypatch): + monkeypatch.setattr("src.modules.csm.state.State.file", lambda _=None: Path("/tmp/state.pkl")) + state = State() + state.data = { + (0, 31): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 5)}), + } + state.commit() + loaded_state = State.load() + assert loaded_state.data == state.data -@pytest.fixture(autouse=True) -def mock_state_file(state_file_path: Path): - State.file = Mock(return_value=state_file_path) +def test_load_returns_new_instance_if_file_not_found(monkeypatch): + monkeypatch.setattr("src.modules.csm.state.State.file", lambda: Path("/non/existent/path")) + state = State.load() + assert state.is_empty -def test_attestation_aggregate_perf(): - aggr = AttestationsAccumulator(included=333, assigned=777) - assert aggr.perf == pytest.approx(0.4285, abs=1e-4) +def test_load_returns_new_instance_if_empty_object(monkeypatch, tmp_path): + with open('/tmp/state.pkl', "wb") as f: + pickle.dump(None, f) + monkeypatch.setattr("src.modules.csm.state.State.file", lambda: Path("/tmp/state.pkl")) + state = State.load() + assert state.is_empty + + +def test_commit_saves_state_to_file(monkeypatch): + state = State() + state.data = { + (0, 31): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 5)}), + } + monkeypatch.setattr("src.modules.csm.state.State.file", lambda _: Path("/tmp/state.pkl")) + monkeypatch.setattr("os.replace", Mock(side_effect=os.replace)) + state.commit() + with open("/tmp/state.pkl", "rb") as f: + loaded_state = pickle.load(f) + assert loaded_state.data == state.data + os.replace.assert_called_once_with(Path("/tmp/state.buf"), Path("/tmp/state.pkl")) + + +def test_file_returns_correct_path(monkeypatch): + monkeypatch.setattr(variables, "CACHE_PATH", Path("/tmp")) + assert State.file() == Path("/tmp/cache.pkl") + + +def test_buffer_returns_correct_path(monkeypatch): + 
monkeypatch.setattr(variables, "CACHE_PATH", Path("/tmp")) + state = State() + assert state.buffer == Path("/tmp/cache.buf") + + +def test_is_empty_returns_true_for_empty_state(): + state = State() + assert state.is_empty + + +def test_is_empty_returns_false_for_non_empty_state(): + state = State() + state.data = {(0, 31): defaultdict(AttestationsAccumulator)} + assert not state.is_empty + + +def test_unprocessed_epochs_raises_error_if_epochs_not_set(): + state = State() + with pytest.raises(ValueError, match="Epochs to process are not set"): + state.unprocessed_epochs + + +def test_unprocessed_epochs_returns_correct_set(): + state = State() + state._epochs_to_process = tuple(sequence(0, 95)) + state._processed_epochs = set(sequence(0, 63)) + assert state.unprocessed_epochs == set(sequence(64, 95)) + + +def test_is_fulfilled_returns_true_if_no_unprocessed_epochs(): + state = State() + state._epochs_to_process = tuple(sequence(0, 95)) + state._processed_epochs = set(sequence(0, 95)) + assert state.is_fulfilled + + +def test_is_fulfilled_returns_false_if_unprocessed_epochs_exist(): + state = State() + state._epochs_to_process = tuple(sequence(0, 95)) + state._processed_epochs = set(sequence(0, 63)) + assert not state.is_fulfilled + + +def test_calculate_frames_handles_exact_frame_size(): + epochs = tuple(range(10)) + frames = State.calculate_frames(epochs, 5) + assert frames == [(0, 4), (5, 9)] + + +def test_calculate_frames_raises_error_for_insufficient_epochs(): + epochs = tuple(range(8)) + with pytest.raises(ValueError, match="Insufficient epochs to form a frame"): + State.calculate_frames(epochs, 5) + + +def test_clear_resets_state_to_empty(): + state = State() + state.data = {(0, 31): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 5)})} + state.clear() + assert state.is_empty + + +def test_find_frame_returns_correct_frame(): + state = State() + state.data = {(0, 31): defaultdict(AttestationsAccumulator)} + assert state.find_frame(15) == (0, 31) -def test_state_avg_perf(): +def test_find_frame_raises_error_for_out_of_range_epoch(): state = State() + state.data = {(0, 31): defaultdict(AttestationsAccumulator)} + with pytest.raises(ValueError, match="Epoch 32 is out of frames range"): + state.find_frame(32) - frame = (0, 999) - with pytest.raises(ValueError): - state.get_network_aggr(frame) +def test_increment_duty_adds_duty_correctly(): + state = State() + frame = (0, 31) + state.data = { + frame: defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 5)}), + } + state.increment_duty(frame, ValidatorIndex(1), True) + assert state.data[frame][ValidatorIndex(1)].assigned == 11 + assert state.data[frame][ValidatorIndex(1)].included == 6 + +def test_increment_duty_creates_new_validator_entry(): state = State() - state.init_or_migrate(*frame, 1000, 1) + frame = (0, 31) state.data = { - frame: { - ValidatorIndex(0): AttestationsAccumulator(included=0, assigned=0), - ValidatorIndex(1): AttestationsAccumulator(included=0, assigned=0), - } + frame: defaultdict(AttestationsAccumulator), } + state.increment_duty(frame, ValidatorIndex(2), True) + assert state.data[frame][ValidatorIndex(2)].assigned == 1 + assert state.data[frame][ValidatorIndex(2)].included == 1 - assert state.get_network_aggr(frame).perf == 0 +def test_increment_duty_handles_non_included_duty(): + state = State() + frame = (0, 31) state.data = { - frame: { - ValidatorIndex(0): AttestationsAccumulator(included=333, assigned=777), - ValidatorIndex(1): 
AttestationsAccumulator(included=167, assigned=223), - } + frame: defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 5)}), } + state.increment_duty(frame, ValidatorIndex(1), False) + assert state.data[frame][ValidatorIndex(1)].assigned == 11 + assert state.data[frame][ValidatorIndex(1)].included == 5 - assert state.get_network_aggr(frame).perf == 0.5 +def test_increment_duty_raises_error_for_out_of_range_epoch(): + state = State() + state.data = { + (0, 31): defaultdict(AttestationsAccumulator), + } + with pytest.raises(ValueError, match="is not found in the state"): + state.increment_duty((0, 32), ValidatorIndex(1), True) -def test_state_attestations(): - state = State( - { - (0, 999): { - ValidatorIndex(0): AttestationsAccumulator(included=333, assigned=777), - ValidatorIndex(1): AttestationsAccumulator(included=167, assigned=223), - } - } - ) - network_aggr = state.get_network_aggr((0, 999)) +def test_add_processed_epoch_adds_epoch_to_processed_set(): + state = State() + state.add_processed_epoch(5) + assert 5 in state._processed_epochs - assert network_aggr.assigned == 1000 - assert network_aggr.included == 500 +def test_add_processed_epoch_does_not_duplicate_epochs(): + state = State() + state.add_processed_epoch(5) + state.add_processed_epoch(5) + assert len(state._processed_epochs) == 1 -def test_state_load(): - orig = State( - { - (0, 999): { - ValidatorIndex(0): AttestationsAccumulator(included=333, assigned=777), - ValidatorIndex(1): AttestationsAccumulator(included=167, assigned=223), - } - } - ) - orig.commit() - copy = State.load() - assert copy.data == orig.data +def test_init_or_migrate_discards_data_on_version_change(): + state = State() + state._consensus_version = 1 + state.clear = Mock() + state.commit = Mock() + state.init_or_migrate(0, 63, 32, 2) + state.clear.assert_called_once() + state.commit.assert_called_once() -def test_state_clear(): - state = State( - { - (0, 999): { - ValidatorIndex(0): AttestationsAccumulator(included=333, assigned=777), - ValidatorIndex(1): AttestationsAccumulator(included=167, assigned=223), - } - } - ) +def test_init_or_migrate_no_migration_needed(): + state = State() + state._consensus_version = 1 + state._epochs_to_process = tuple(sequence(0, 63)) + state._epochs_per_frame = 32 + state.data = { + (0, 31): defaultdict(AttestationsAccumulator), + (32, 63): defaultdict(AttestationsAccumulator), + } + state.commit = Mock() + state.init_or_migrate(0, 63, 32, 1) + state.commit.assert_not_called() - state._epochs_to_process = (EpochNumber(1), EpochNumber(33)) - state._processed_epochs = {EpochNumber(42), EpochNumber(17)} - state.clear() - assert state.is_empty - assert not state.data +def test_init_or_migrate_migrates_data(): + state = State() + state._consensus_version = 1 + state._epochs_to_process = tuple(sequence(0, 63)) + state._epochs_per_frame = 32 + state.data = { + (0, 31): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 5)}), + (32, 63): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(20, 15)}), + } + state.commit = Mock() + state.init_or_migrate(0, 63, 64, 1) + assert state.data == { + (0, 63): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(30, 20)}), + } + state.commit.assert_called_once() + + +def test_init_or_migrate_invalidates_unmigrated_frames(): + state = State() + state._consensus_version = 1 + state._epochs_to_process = tuple(sequence(0, 63)) + state._epochs_per_frame = 64 + state.data = { 
+ (0, 63): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(30, 20)}), + } + state.commit = Mock() + state.init_or_migrate(0, 31, 32, 1) + assert state.data == { + (0, 31): defaultdict(AttestationsAccumulator), + } + assert state._processed_epochs == set() + state.commit.assert_called_once() + + +def test_init_or_migrate_discards_unmigrated_frame(): + state = State() + state._consensus_version = 1 + state._epochs_to_process = tuple(sequence(0, 95)) + state._epochs_per_frame = 32 + state.data = { + (0, 31): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 5)}), + (32, 63): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(20, 15)}), + (64, 95): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(30, 25)}), + } + state._processed_epochs = set(sequence(0, 95)) + state.commit = Mock() + state.init_or_migrate(0, 63, 32, 1) + assert state.data == { + (0, 31): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 5)}), + (32, 63): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(20, 15)}), + } + assert state._processed_epochs == set(sequence(0, 63)) + state.commit.assert_called_once() + + +def test_migrate_frames_data_creates_new_data_correctly(): + state = State() + current_frames = [(0, 31), (32, 63)] + new_frames = [(0, 63)] + state.data = { + (0, 31): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 5)}), + (32, 63): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(20, 15)}), + } + new_data, migration_status = state._migrate_frames_data(current_frames, new_frames) + assert new_data == { + (0, 63): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(30, 20)}) + } + assert migration_status == {(0, 31): True, (32, 63): True} + + +def test_migrate_frames_data_handles_no_migration(): + state = State() + current_frames = [(0, 31)] + new_frames = [(0, 31)] + state.data = { + (0, 31): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 5)}), + } + new_data, migration_status = state._migrate_frames_data(current_frames, new_frames) + assert new_data == { + (0, 31): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 5)}) + } + assert migration_status == {(0, 31): True} + + +def test_migrate_frames_data_handles_partial_migration(): + state = State() + current_frames = [(0, 31), (32, 63)] + new_frames = [(0, 31), (32, 95)] + state.data = { + (0, 31): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 5)}), + (32, 63): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(20, 15)}), + } + new_data, migration_status = state._migrate_frames_data(current_frames, new_frames) + assert new_data == { + (0, 31): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 5)}), + (32, 95): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(20, 15)}), + } + assert migration_status == {(0, 31): True, (32, 63): True} + + +def test_migrate_frames_data_handles_no_data(): + state = State() + current_frames = [(0, 31)] + new_frames = [(0, 31)] + state.data = {frame: defaultdict(AttestationsAccumulator) for frame in current_frames} + new_data, migration_status = state._migrate_frames_data(current_frames, new_frames) + assert new_data == {(0, 31): 
defaultdict(AttestationsAccumulator)} + assert migration_status == {(0, 31): True} + + +def test_migrate_frames_data_handles_wider_old_frame(): + state = State() + current_frames = [(0, 63)] + new_frames = [(0, 31), (32, 63)] + state.data = { + (0, 63): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(30, 20)}), + } + new_data, migration_status = state._migrate_frames_data(current_frames, new_frames) + assert new_data == { + (0, 31): defaultdict(AttestationsAccumulator), + (32, 63): defaultdict(AttestationsAccumulator), + } + assert migration_status == {(0, 63): False} + + +def test_validate_raises_error_if_state_not_fulfilled(): + state = State() + state._epochs_to_process = tuple(sequence(0, 95)) + state._processed_epochs = set(sequence(0, 94)) + with pytest.raises(InvalidState, match="State is not fulfilled"): + state.validate(0, 95) + + +def test_validate_raises_error_if_processed_epoch_out_of_range(): + state = State() + state._epochs_to_process = tuple(sequence(0, 95)) + state._processed_epochs = set(sequence(0, 95)) + state._processed_epochs.add(96) + with pytest.raises(InvalidState, match="Processed epoch 96 is out of range"): + state.validate(0, 95) + + +def test_validate_raises_error_if_epoch_missing_in_processed_epochs(): + state = State() + state._epochs_to_process = tuple(sequence(0, 94)) + state._processed_epochs = set(sequence(0, 94)) + with pytest.raises(InvalidState, match="Epoch 95 missing in processed epochs"): + state.validate(0, 95) -def test_state_add_processed_epoch(): +def test_validate_passes_for_fulfilled_state(): state = State() - state.add_processed_epoch(EpochNumber(42)) - state.add_processed_epoch(EpochNumber(17)) - assert state._processed_epochs == {EpochNumber(42), EpochNumber(17)} + state._epochs_to_process = tuple(sequence(0, 95)) + state._processed_epochs = set(sequence(0, 95)) + state.validate(0, 95) -def test_state_inc(): - - frame_0 = (0, 999) - frame_1 = (1000, 1999) - - state = State( - { - frame_0: { - ValidatorIndex(0): AttestationsAccumulator(included=333, assigned=777), - ValidatorIndex(1): AttestationsAccumulator(included=167, assigned=223), - }, - frame_1: { - ValidatorIndex(0): AttestationsAccumulator(included=1, assigned=1), - ValidatorIndex(1): AttestationsAccumulator(included=0, assigned=1), - }, - } - ) - - state.increment_duty(999, ValidatorIndex(0), True) - state.increment_duty(999, ValidatorIndex(0), False) - state.increment_duty(999, ValidatorIndex(1), True) - state.increment_duty(999, ValidatorIndex(1), True) - state.increment_duty(999, ValidatorIndex(1), False) - state.increment_duty(999, ValidatorIndex(2), True) - - state.increment_duty(1000, ValidatorIndex(2), False) - - assert tuple(state.data[frame_0].values()) == ( - AttestationsAccumulator(included=334, assigned=779), - AttestationsAccumulator(included=169, assigned=226), - AttestationsAccumulator(included=1, assigned=1), - ) - - assert tuple(state.data[frame_1].values()) == ( - AttestationsAccumulator(included=1, assigned=1), - AttestationsAccumulator(included=0, assigned=1), - AttestationsAccumulator(included=0, assigned=1), - ) - - -def test_state_file_is_path(): - assert isinstance(State.file(), Path) - - -class TestStateTransition: - """Tests for State's transition for different l_epoch, r_epoch values""" - - @pytest.fixture(autouse=True) - def no_commit(self, monkeypatch: pytest.MonkeyPatch): - monkeypatch.setattr(State, "commit", Mock()) - - def test_empty_to_new_frame(self): - state = State() - assert state.is_empty - - l_epoch = 
EpochNumber(1) - r_epoch = EpochNumber(255) - - state.init_or_migrate(l_epoch, r_epoch, 255, 1) - - assert not state.is_empty - assert state.unprocessed_epochs == set(sequence(l_epoch, r_epoch)) - - @pytest.mark.parametrize( - ("l_epoch_old", "r_epoch_old", "l_epoch_new", "r_epoch_new"), - [ - pytest.param(1, 255, 256, 510, id="Migrate a..bA..B"), - pytest.param(1, 255, 32, 510, id="Migrate a..A..b..B"), - pytest.param(32, 510, 1, 255, id="Migrate: A..a..B..b"), - ], - ) - def test_new_frame_requires_discarding_state(self, l_epoch_old, r_epoch_old, l_epoch_new, r_epoch_new): - state = State() - state.clear = Mock(side_effect=state.clear) - state.init_or_migrate(l_epoch_old, r_epoch_old, r_epoch_old - l_epoch_old + 1, 1) - state.clear.assert_not_called() - - state.init_or_migrate(l_epoch_new, r_epoch_new, r_epoch_new - l_epoch_new + 1, 1) - state.clear.assert_called_once() - - assert state.unprocessed_epochs == set(sequence(l_epoch_new, r_epoch_new)) - - @pytest.mark.parametrize( - ("l_epoch_old", "r_epoch_old", "l_epoch_new", "r_epoch_new", "epochs_per_frame"), - [ - pytest.param(1, 255, 1, 510, 255, id="Migrate Aa..b..B"), - ], - ) - def test_new_frame_extends_old_state(self, l_epoch_old, r_epoch_old, l_epoch_new, r_epoch_new, epochs_per_frame): - state = State() - state.clear = Mock(side_effect=state.clear) - - state.init_or_migrate(l_epoch_old, r_epoch_old, epochs_per_frame, 1) - state.clear.assert_not_called() - - state.init_or_migrate(l_epoch_new, r_epoch_new, epochs_per_frame, 1) - state.clear.assert_not_called() - - assert state.unprocessed_epochs == set(sequence(l_epoch_new, r_epoch_new)) - assert len(state.data) == 2 - assert list(state.data.keys()) == [(l_epoch_old, r_epoch_old), (r_epoch_old + 1, r_epoch_new)] - assert state.calculate_frames(state._epochs_to_process, epochs_per_frame) == [ - (l_epoch_old, r_epoch_old), - (r_epoch_old + 1, r_epoch_new), - ] - - @pytest.mark.parametrize( - ("l_epoch_old", "r_epoch_old", "epochs_per_frame_old", "l_epoch_new", "r_epoch_new", "epochs_per_frame_new"), - [ - pytest.param(32, 510, 479, 1, 510, 510, id="Migrate: A..a..b..B"), - ], - ) - def test_new_frame_extends_old_state_with_single_frame( - self, l_epoch_old, r_epoch_old, epochs_per_frame_old, l_epoch_new, r_epoch_new, epochs_per_frame_new - ): - state = State() - state.clear = Mock(side_effect=state.clear) - - state.init_or_migrate(l_epoch_old, r_epoch_old, epochs_per_frame_old, 1) - state.clear.assert_not_called() - - state.init_or_migrate(l_epoch_new, r_epoch_new, epochs_per_frame_new, 1) - state.clear.assert_not_called() - - assert state.unprocessed_epochs == set(sequence(l_epoch_new, r_epoch_new)) - assert len(state.data) == 1 - assert list(state.data.keys())[0] == (l_epoch_new, r_epoch_new) - assert state.calculate_frames(state._epochs_to_process, epochs_per_frame_new) == [(l_epoch_new, r_epoch_new)] - - @pytest.mark.parametrize( - ("old_version", "new_version"), - [ - pytest.param(2, 3, id="Increase consensus version"), - pytest.param(3, 2, id="Decrease consensus version"), - ], - ) - def test_consensus_version_change(self, old_version, new_version): - state = State() - state.clear = Mock(side_effect=state.clear) - state._consensus_version = old_version - - l_epoch = r_epoch = EpochNumber(255) - - state.init_or_migrate(l_epoch, r_epoch, 1, old_version) - state.clear.assert_not_called() - - state.init_or_migrate(l_epoch, r_epoch, 1, new_version) - state.clear.assert_called_once() +def test_attestation_aggregate_perf(): + aggr = AttestationsAccumulator(included=333, assigned=777) 
+ assert aggr.perf == pytest.approx(0.4285, abs=1e-4) + + +def test_get_network_aggr_computes_correctly(): + state = State() + state.data = { + (0, 31): defaultdict( + AttestationsAccumulator, + {ValidatorIndex(1): AttestationsAccumulator(10, 5), ValidatorIndex(2): AttestationsAccumulator(20, 15)}, + ) + } + aggr = state.get_network_aggr((0, 31)) + assert aggr.assigned == 30 + assert aggr.included == 20 + + +def test_get_network_aggr_raises_error_for_invalid_accumulator(): + state = State() + state.data = {(0, 31): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 15)})} + with pytest.raises(ValueError, match="Invalid accumulator"): + state.get_network_aggr((0, 31)) + + +def test_get_network_aggr_raises_error_for_missing_frame_data(): + state = State() + with pytest.raises(ValueError, match="No data for frame"): + state.get_network_aggr((0, 31)) + + +def test_get_network_aggr_handles_empty_frame_data(): + state = State() + state.data = {(0, 31): defaultdict(AttestationsAccumulator)} + aggr = state.get_network_aggr((0, 31)) + assert aggr.assigned == 0 + assert aggr.included == 0
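
For readers skimming the hunks above, here is a minimal, self-contained sketch of the frame-splitting and migration semantics this patch introduces. It only reproduces the behaviour visible in the diff; the function names (`calculate_frames`, `migrate_frames_data`) mirror the patched `State` methods but are standalone stand-ins, not the module's API, and `batched` is a local shim rather than the project's import.

```python
# Sketch of the patched frame handling: epochs are split into fixed-size frames,
# and cached per-frame duty accumulators survive a migration only when an old
# frame maps onto exactly one new frame (identical, or fully contained in a
# wider new frame). Illustrative only; not the oracle's actual module.
from collections import defaultdict
from dataclasses import dataclass
from itertools import islice

Frame = tuple[int, int]


@dataclass
class AttestationsAccumulator:
    assigned: int = 0
    included: int = 0


def batched(iterable, n):
    # Stand-in for itertools.batched (Python 3.12+) so the sketch runs on 3.10+.
    it = iter(iterable)
    while chunk := tuple(islice(it, n)):
        yield chunk


def calculate_frames(epochs: tuple[int, ...], epochs_per_frame: int) -> list[Frame]:
    # Split the epoch range into full frames; a trailing partial frame is an error.
    frames = []
    for frame_epochs in batched(epochs, epochs_per_frame):
        if len(frame_epochs) < epochs_per_frame:
            raise ValueError("Insufficient epochs to form a frame")
        frames.append((frame_epochs[0], frame_epochs[-1]))
    return frames


def migrate_frames_data(data, current_frames, new_frames):
    # An old frame is "migrated" if it is identical to a new frame or fully
    # contained in one; otherwise it stays unmigrated and the caller is expected
    # to invalidate it (discard its processed epochs), as init_or_migrate does.
    migration_status = {frame: False for frame in current_frames}
    new_data = {frame: defaultdict(AttestationsAccumulator) for frame in new_frames}
    for cur in current_frames:
        for new in new_frames:
            if cur == new:
                new_data[new] = data[cur]
                migration_status[cur] = True
                break
            if new[0] <= cur[0] and cur[1] <= new[1]:
                for val, acc in data[cur].items():
                    new_data[new][val].assigned += acc.assigned
                    new_data[new][val].included += acc.included
                migration_status[cur] = True
                break
    return new_data, migration_status


if __name__ == "__main__":
    old = calculate_frames(tuple(range(64)), 32)   # [(0, 31), (32, 63)]
    new = calculate_frames(tuple(range(64)), 64)   # [(0, 63)]
    data = {
        (0, 31): defaultdict(AttestationsAccumulator, {1: AttestationsAccumulator(10, 5)}),
        (32, 63): defaultdict(AttestationsAccumulator, {1: AttestationsAccumulator(20, 15)}),
    }
    merged, status = migrate_frames_data(data, old, new)
    assert merged[(0, 63)][1] == AttestationsAccumulator(30, 20)
    assert status == {(0, 31): True, (32, 63): True}
```

This matches the behaviour exercised by `test_init_or_migrate_migrates_data` and `test_migrate_frames_data_handles_wider_old_frame`: widening a frame merges accumulators, while narrowing leaves the old frame unmigrated so its epochs are dropped from `_processed_epochs` and re-collected.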