Skip to content

Commit

Permalink
refactor: State and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
vgorkavenko committed Feb 13, 2025
1 parent 8fcfe96 commit 1aa7228
Show file tree
Hide file tree
Showing 3 changed files with 445 additions and 286 deletions.
4 changes: 2 additions & 2 deletions src/modules/csm/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,12 +205,12 @@ def _check_duty(
for root in block_roots:
attestations = self.cc.get_block_attestations(root)
process_attestations(attestations, committees, self.eip7549_supported)

frame = self.state.find_frame(duty_epoch)
with lock:
for committee in committees.values():
for validator_duty in committee:
self.state.increment_duty(
duty_epoch,
frame,
validator_duty.index,
included=validator_duty.included,
)
Expand Down
135 changes: 67 additions & 68 deletions src/modules/csm/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@

logger = logging.getLogger(__name__)

type Frame = tuple[EpochNumber, EpochNumber]


class InvalidState(ValueError):
"""State has data considered as invalid for a report"""
Expand All @@ -36,6 +34,10 @@ def add_duty(self, included: bool) -> None:
self.included += 1 if included else 0


type Frame = tuple[EpochNumber, EpochNumber]
type StateData = dict[Frame, defaultdict[ValidatorIndex, AttestationsAccumulator]]


class State:
"""
Processing state of a CSM performance oracle frame.
Expand All @@ -46,18 +48,16 @@ class State:
The state can be migrated to be used for another frame's report by calling the `migrate` method.
"""
data: dict[Frame, defaultdict[ValidatorIndex, AttestationsAccumulator]]
data: StateData

_epochs_to_process: tuple[EpochNumber, ...]
_processed_epochs: set[EpochNumber]
_epochs_per_frame: int

_consensus_version: int = 1

def __init__(self, data: dict[Frame, dict[ValidatorIndex, AttestationsAccumulator]] | None = None) -> None:
self.data = {
frame: defaultdict(AttestationsAccumulator, validators) for frame, validators in (data or {}).items()
}
def __init__(self) -> None:
    # Start from a completely empty cache. NOTE(review): `init_or_migrate` is
    # expected to lay out the per-frame mappings in `self.data` before any
    # duties are accumulated — confirm callers always invoke it first.
    self.data = {}
    self._epochs_to_process = tuple()
    self._processed_epochs = set()
    self._epochs_per_frame = 0
Expand Down Expand Up @@ -110,6 +110,16 @@ def unprocessed_epochs(self) -> set[EpochNumber]:
def is_fulfilled(self) -> bool:
return not self.unprocessed_epochs

@staticmethod
def calculate_frames(epochs_to_process: tuple[EpochNumber, ...], epochs_per_frame: int) -> list[Frame]:
    """Partition the epochs into consecutive frames of `epochs_per_frame` length.

    Each frame is reported as a `(first_epoch, last_epoch)` pair.
    Raises ValueError when the epochs do not divide evenly into whole frames.
    """
    boundaries: list[Frame] = []
    for chunk in batched(epochs_to_process, epochs_per_frame):
        if len(chunk) < epochs_per_frame:
            raise ValueError("Insufficient epochs to form a frame")
        boundaries.append((chunk[0], chunk[-1]))
    return boundaries

def clear(self) -> None:
    """Reset the cached duties data and processing bookkeeping."""
    self.data = {}
    self._epochs_to_process = tuple()
Expand All @@ -123,17 +133,20 @@ def find_frame(self, epoch: EpochNumber) -> Frame:
return epoch_range
raise ValueError(f"Epoch {epoch} is out of frames range: {frames}")

def increment_duty(self, epoch: EpochNumber, val_index: ValidatorIndex, included: bool) -> None:
    """Account one attestation duty for `val_index` within the frame covering `epoch`."""
    frame = self.find_frame(epoch)
    self.data[frame][val_index].add_duty(included)
def increment_duty(self, frame: Frame, val_index: ValidatorIndex, included: bool) -> None:
    """Account one attestation duty for `val_index` in `frame`.

    The frame must already exist in the state; otherwise ValueError is raised.
    """
    try:
        per_validator = self.data[frame]
    except KeyError:
        raise ValueError(f"Frame {frame} is not found in the state") from None
    per_validator[val_index].add_duty(included)

def add_processed_epoch(self, epoch: EpochNumber) -> None:
    """Record `epoch` in the set of epochs already processed."""
    self._processed_epochs.add(epoch)

def log_progress(self) -> None:
    """Log how many of the scheduled epochs have been processed so far."""
    done = len(self._processed_epochs)
    total = len(self._epochs_to_process)
    logger.info({"msg": f"Processed {done} of {total} epochs"})

def init_or_migrate(self, l_epoch: EpochNumber, r_epoch: EpochNumber, epochs_per_frame: int, consensus_version: int) -> None:
def init_or_migrate(
self, l_epoch: EpochNumber, r_epoch: EpochNumber, epochs_per_frame: int, consensus_version: int
) -> None:
if consensus_version != self._consensus_version:
logger.warning(
{
Expand All @@ -143,59 +156,55 @@ def init_or_migrate(self, l_epoch: EpochNumber, r_epoch: EpochNumber, epochs_per
)
self.clear()

frames = self.calculate_frames(tuple(sequence(l_epoch, r_epoch)), epochs_per_frame)
frames_data: StateData = {frame: defaultdict(AttestationsAccumulator) for frame in frames}

if not self.is_empty:
invalidated = self._migrate_or_invalidate(l_epoch, r_epoch, epochs_per_frame)
if invalidated:
self.clear()
cached_frames = self.calculate_frames(self._epochs_to_process, self._epochs_per_frame)
if cached_frames == frames:
logger.info({"msg": "No need to migrate duties data cache"})
return

frames_data, migration_status = self._migrate_frames_data(cached_frames, frames)

for current_frame, migrated in migration_status.items():
if not migrated:
logger.warning({"msg": f"Invalidating frame duties data cache: {current_frame}"})
for epoch in sequence(*current_frame):
self._processed_epochs.discard(epoch)

self._fill_frames(l_epoch, r_epoch, epochs_per_frame)
self.data = frames_data
self._epochs_per_frame = epochs_per_frame
self._epochs_to_process = tuple(sequence(l_epoch, r_epoch))
self._consensus_version = consensus_version
self.commit()

def _fill_frames(self, l_epoch: EpochNumber, r_epoch: EpochNumber, epochs_per_frame: int) -> None:
    """Ensure every frame spanning [l_epoch, r_epoch] has an accumulator mapping in the cache."""
    epochs = tuple(sequence(l_epoch, r_epoch))
    for frame in self.calculate_frames(epochs, epochs_per_frame):
        if frame not in self.data:
            self.data[frame] = defaultdict(AttestationsAccumulator)

def _migrate_or_invalidate(self, l_epoch: EpochNumber, r_epoch: EpochNumber, epochs_per_frame: int) -> bool:
    """Try to carry the cached state over to the new epoch range.

    Returns True when the cache is incompatible with the new range and
    must be discarded by the caller.
    """
    current_frames = self.calculate_frames(self._epochs_to_process, self._epochs_per_frame)
    new_frames = self.calculate_frames(tuple(sequence(l_epoch, r_epoch)), epochs_per_frame)
    inv_msg = f"Discarding invalid state cache because of frames change. {current_frames=}, {new_frames=}"

    def _discard() -> bool:
        # Single place that reports and requests invalidation.
        logger.warning({"msg": inv_msg})
        return True

    if self._invalidate_on_epoch_range_change(l_epoch, r_epoch):
        return _discard()

    frame_expanded = epochs_per_frame > self._epochs_per_frame
    frame_shrunk = epochs_per_frame < self._epochs_per_frame
    has_single_frame = len(current_frames) == 1 and len(new_frames) == 1

    if has_single_frame:
        if frame_expanded:
            # A lone frame that only grew keeps its accumulated data.
            current_frame = current_frames[0]
            new_frame = new_frames[0]
            self.data[new_frame] = self.data.pop(current_frame)
            logger.info({"msg": f"Migrated state cache to a new frame. {current_frame=}, {new_frame=}"})
            return False
        if frame_shrunk:
            return _discard()
        return False

    # Multiple frames: any change of the frame length invalidates the cache.
    if frame_expanded or frame_shrunk:
        return _discard()
    return False

def _invalidate_on_epoch_range_change(self, l_epoch: EpochNumber, r_epoch: EpochNumber) -> bool:
    """Return True when any tracked epoch falls outside [l_epoch, r_epoch]."""
    tracked = (*self._epochs_to_process, *self._processed_epochs)
    return any(e < l_epoch or e > r_epoch for e in tracked)
def _migrate_frames_data(
    self, current_frames: list[Frame], new_frames: list[Frame]
) -> tuple[StateData, dict[Frame, bool]]:
    """Carry per-validator accumulators from `current_frames` into `new_frames`.

    A current frame migrates when it equals, or is fully contained in, some
    new frame; the returned status dict maps each current frame to whether
    it was migrated.
    """
    new_data: StateData = {frame: defaultdict(AttestationsAccumulator) for frame in new_frames}
    migration_status = dict.fromkeys(current_frames, False)

    logger.info({"msg": f"Migrating duties data cache: {current_frames=} -> {new_frames=}"})

    for current_frame in current_frames:
        curr_l, curr_r = current_frame
        for new_frame in new_frames:
            if new_frame == current_frame:
                # Exact match: reuse the accumulator mapping as-is.
                new_data[new_frame] = self.data[current_frame]
                migration_status[current_frame] = True
                break

            new_l, new_r = new_frame
            if new_l <= curr_l and curr_r <= new_r:
                # Old frame fits inside the wider new frame: merge the counters.
                logger.info({"msg": f"Migrating frame duties data cache: {current_frame=} -> {new_frame=}"})
                for val_index, acc in self.data[current_frame].items():
                    target = new_data[new_frame][val_index]
                    target.assigned += acc.assigned
                    target.included += acc.included
                migration_status[current_frame] = True
                break

    return new_data, migration_status

def validate(self, l_epoch: EpochNumber, r_epoch: EpochNumber) -> None:
if not self.is_fulfilled:
Expand All @@ -209,21 +218,11 @@ def validate(self, l_epoch: EpochNumber, r_epoch: EpochNumber) -> None:
if epoch not in self._processed_epochs:
raise InvalidState(f"Epoch {epoch} missing in processed epochs")

@staticmethod
def calculate_frames(epochs_to_process: tuple[EpochNumber, ...], epochs_per_frame: int) -> list[Frame]:
    """Group the epochs into whole frames, returning `(first, last)` epoch pairs.

    Raises ValueError if a trailing group is shorter than `epochs_per_frame`.
    """
    result: list[Frame] = []
    for group in batched(epochs_to_process, epochs_per_frame):
        if len(group) < epochs_per_frame:
            raise ValueError("Insufficient epochs to form a frame")
        result.append((group[0], group[-1]))
    return result

def get_network_aggr(self, frame: Frame) -> AttestationsAccumulator:
# TODO: exclude `active_slashed` validators from the calculation
included = assigned = 0
frame_data = self.data.get(frame)
if not frame_data:
if frame_data is None:
raise ValueError(f"No data for frame {frame} to calculate network aggregate")
for validator, acc in frame_data.items():
if acc.included > acc.assigned:
Expand Down
Loading

0 comments on commit 1aa7228

Please sign in to comment.