Skip to content

Commit

Permalink
Fix the IndexError caused by a race condition in parallel sequenci…
Browse files Browse the repository at this point in the history
…ng batches and pruning old batches. (#83)
  • Loading branch information
AlirezaRoshanzamir authored Feb 15, 2025
1 parent 6c5652f commit 33cac9f
Showing 1 changed file with 54 additions and 21 deletions.
75 changes: 54 additions & 21 deletions common/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@
import itertools
import portion # type: ignore[import-untyped]
from typing import TypedDict
from collections.abc import Iterable

from collections.abc import Iterable, Callable

State = Literal["initialized", "sequenced", "locked", "finalized"]
OperationalState = Literal["sequenced", "locked", "finalized"]
Expand Down Expand Up @@ -312,7 +311,9 @@ def get_global_operational_batches_sequence(

self._append_unique_batches_after_index(
self._filter_operational_batches_sequence(
app_name, self._get_batch_index_interval(app_name, state)
app_name,
self._get_batch_index_interval(app_name, state)
& portion.open(after, portion.inf),
),
after,
batches_sequence,
Expand Down Expand Up @@ -463,9 +464,7 @@ def lock_batches(self, app_name: str, signature_data: SignatureData) -> None:
for batch in self._filter_operational_batches_sequence(
app_name,
self._get_batch_index_interval(app_name, "finalized").complement()
& portion.closed(
lower=self._GLOBAL_FIRST_BATCH_INDEX, upper=signature_data["index"]
),
& portion.closed(self._GLOBAL_FIRST_BATCH_INDEX, signature_data["index"]),
):
batch["state"] = "locked"

Expand Down Expand Up @@ -501,9 +500,7 @@ def finalize_batches(self, app_name: str, signature_data: SignatureData) -> None
snapshot_indexes: list[int] = []
for batch in self._filter_operational_batches_sequence(
app_name,
portion.closed(
lower=self._GLOBAL_FIRST_BATCH_INDEX, upper=signature_data["index"]
),
portion.closed(self._GLOBAL_FIRST_BATCH_INDEX, signature_data["index"]),
):
batch["state"] = "finalized"
if batch["index"] % zconfig.SNAPSHOT_CHUNK == 0:
Expand Down Expand Up @@ -768,22 +765,52 @@ def _filter_operational_batches_sequence(
self,
app_name: str,
index_interval: portion.Interval,
limit: int | None = None,
) -> list[Batch]:
feasible_index_interval = index_interval.intersection(
self._get_batch_index_interval(app_name)
)
relative_index_interval = feasible_index_interval.apply(
if index_interval.empty:
return []

relative_index_calculator = self._generate_relative_index_calculator(app_name)
relative_index_interval = index_interval.apply(
lambda x: x.replace(
lower=self._calculate_relative_index(app_name, x.lower),
upper=self._calculate_relative_index(app_name, x.upper),
lower=(
x.lower
if x.lower == -portion.inf
else relative_index_calculator(x.lower)
),
upper=(
x.upper
if x.upper == portion.inf
else relative_index_calculator(x.upper)
),
)
)
return [
self.apps[app_name]["operational_batches_sequence"][i]
for i in itertools.islice(
portion.iterate(relative_index_interval, step=1), limit

feasible_relative_index_interval = relative_index_interval & portion.closed(
0, portion.inf
)

if not feasible_relative_index_interval.atomic:
raise ValueError(
f"The {feasible_relative_index_interval=} from "
f"{relative_index_interval=} and {index_interval=} is not atomic."
)

closed_lower_relative_index = None
if feasible_relative_index_interval.lower != 0:
if relative_index_interval.left == portion.CLOSED:
closed_lower_relative_index = feasible_relative_index_interval.lower
else:
closed_lower_relative_index = feasible_relative_index_interval.lower + 1

open_upper_relative_index = None
if feasible_relative_index_interval.upper != portion.inf:
if feasible_relative_index_interval.right == portion.CLOSED:
open_upper_relative_index = feasible_relative_index_interval.upper + 1
else:
open_upper_relative_index = feasible_relative_index_interval.upper

return self.apps[app_name]["operational_batches_sequence"][
closed_lower_relative_index:open_upper_relative_index
]

def _get_operational_batch_by_hash(self, app_name: str, batch_hash: str) -> Batch:
Expand All @@ -804,9 +831,15 @@ def _batch_exists(self, app_name: str, batch_hash: str) -> bool:
)

def _calculate_relative_index(self, app_name: str, index: int) -> int:
return index - self._get_first_batch_index(
return self._generate_relative_index_calculator(app_name)(index)

def _generate_relative_index_calculator(
    self, app_name: str
) -> Callable[[int], int]:
    """Build a converter from batch indexes to relative indexes for an app.

    The app's first batch index is looked up once, when this method is
    called, passing ``_GLOBAL_FIRST_BATCH_INDEX`` as the lookup's
    ``default``. The returned callable subtracts that captured value from
    every index it receives, so all conversions made through the same
    calculator use the same offset.

    Args:
        app_name: Name of the app whose first batch index is used.

    Returns:
        A callable mapping ``index`` to ``index - first_batch_index``.
    """
    offset = self._get_first_batch_index(
        app_name, default=self._GLOBAL_FIRST_BATCH_INDEX
    )

    def to_relative(index: int) -> int:
        # Uses the offset captured above rather than re-reading it per call.
        return index - offset

    return to_relative

def _get_batch_index_interval(
self,
Expand Down

0 comments on commit 33cac9f

Please sign in to comment.