Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Add a get_next_txn method to StreamIdGenerator to match `MultiWri…
Browse files Browse the repository at this point in the history
…terIdGenerator` (#15191)
  • Loading branch information
anoadragon453 authored Mar 2, 2023
1 parent ecbe0dd commit 1eea662
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 11 deletions.
1 change: 1 addition & 0 deletions changelog.d/15191.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add a `get_next_txn` method to `StreamIdGenerator` to match `MultiWriterIdGenerator`.
11 changes: 2 additions & 9 deletions synapse/storage/databases/main/account_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import (
AbstractStreamIdGenerator,
AbstractStreamIdTracker,
MultiWriterIdGenerator,
StreamIdGenerator,
)
Expand All @@ -64,14 +63,12 @@ def __init__(
):
super().__init__(database, db_conn, hs)

# `_can_write_to_account_data` indicates whether the current worker is allowed
# to write account data. A value of `True` implies that `_account_data_id_gen`
# is an `AbstractStreamIdGenerator` and not just a tracker.
self._account_data_id_gen: AbstractStreamIdTracker
self._can_write_to_account_data = (
self._instance_name in hs.config.worker.writers.account_data
)

self._account_data_id_gen: AbstractStreamIdGenerator

if isinstance(database.engine, PostgresEngine):
self._account_data_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
Expand Down Expand Up @@ -558,7 +555,6 @@ async def add_account_data_to_room(
The maximum stream ID.
"""
assert self._can_write_to_account_data
assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)

content_json = json_encoder.encode(content)

Expand Down Expand Up @@ -598,7 +594,6 @@ async def remove_account_data_for_room(
data to delete.
"""
assert self._can_write_to_account_data
assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)

def _remove_account_data_for_room_txn(
txn: LoggingTransaction, next_id: int
Expand Down Expand Up @@ -663,7 +658,6 @@ async def add_account_data_for_user(
The maximum stream ID.
"""
assert self._can_write_to_account_data
assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)

async with self._account_data_id_gen.get_next() as next_id:
await self.db_pool.runInteraction(
Expand Down Expand Up @@ -770,7 +764,6 @@ async def remove_account_data_for_user(
to delete.
"""
assert self._can_write_to_account_data
assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)

def _remove_account_data_for_user_txn(
txn: LoggingTransaction, next_id: int
Expand Down
45 changes: 44 additions & 1 deletion synapse/storage/util/id_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,15 @@ def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
"""
raise NotImplementedError()

@abc.abstractmethod
def get_next_txn(self, txn: LoggingTransaction) -> int:
    """Reserve the next stream ID from within an existing database transaction.

    Usage:
        stream_id_gen.get_next_txn(txn)
        # ... persist events ...

    Args:
        txn: The database transaction to allocate the ID within.

    Returns:
        The newly allocated stream ID.
    """
    raise NotImplementedError()


class StreamIdGenerator(AbstractStreamIdGenerator):
"""Generates and tracks stream IDs for a stream with a single writer.
Expand Down Expand Up @@ -263,6 +272,40 @@ def manager() -> Generator[Sequence[int], None, None]:

return _AsyncCtxManagerWrapper(manager())

def get_next_txn(self, txn: LoggingTransaction) -> int:
    """
    Retrieve the next stream ID from within a database transaction.

    Clean-up functions will be called when the transaction finishes.

    Args:
        txn: The database transaction object.

    Returns:
        The next stream ID.
    """
    # Only the writer instance may hand out new stream IDs.
    if not self._is_writer:
        raise Exception("Tried to allocate stream ID on non-writer")

    # Advance the counter and record the new ID as in-flight, all
    # under the lock so concurrent allocations stay consistent.
    with self._lock:
        self._current += self._step
        new_id = self._current
        self._unfinished_ids[new_id] = new_id

    def _mark_finished(finished_id: int) -> None:
        """Drop *finished_id* from the set of in-flight IDs."""
        with self._lock:
            self._unfinished_ids.pop(finished_id)

    # Release the ID once the transaction completes, whether it
    # commits successfully or raises.
    txn.call_after(_mark_finished, new_id)
    txn.call_on_exception(_mark_finished, new_id)

    return new_id

def get_current_token(self) -> int:
if not self._is_writer:
return self._current
Expand Down Expand Up @@ -568,7 +611,7 @@ def get_next_txn(self, txn: LoggingTransaction) -> int:
"""
Usage:
stream_id = stream_id_gen.get_next(txn)
stream_id = stream_id_gen.get_next_txn(txn)
# ... persist event ...
"""

Expand Down
2 changes: 1 addition & 1 deletion synapse/storage/util/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def __init__(self, get_first_callback: GetFirstCallbackType):
"""
Args:
get_first_callback: a callback which is called on the first call to
get_next_id_txn; should return the curreent maximum id
get_next_id_txn; should return the current maximum id
"""
# the callback. this is cleared after it is called, so that it can be GCed.
self._callback: Optional[GetFirstCallbackType] = get_first_callback
Expand Down

0 comments on commit 1eea662

Please sign in to comment.