diff --git a/changelog.d/0000.misc b/changelog.d/0000.misc new file mode 100644 index 000000000000..942ba1217bf0 --- /dev/null +++ b/changelog.d/0000.misc @@ -0,0 +1 @@ +Implemented cancellation support in `EventsWorkerStore._get_events_from_cache_or_db`. diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 6d6e146ff160..c31fc00eaace 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -75,7 +75,7 @@ from synapse.storage.util.sequence import build_sequence_generator from synapse.types import JsonDict, get_domain_from_id from synapse.util import unwrapFirstError -from synapse.util.async_helpers import ObservableDeferred +from synapse.util.async_helpers import ObservableDeferred, delay_cancellation from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.lrucache import LruCache from synapse.util.iterutils import batch_iter @@ -640,42 +640,57 @@ async def _get_events_from_cache_or_db( missing_events_ids.difference_update(already_fetching_ids) if missing_events_ids: - log_ctx = current_context() - log_ctx.record_event_fetch(len(missing_events_ids)) - - # Add entries to `self._current_event_fetches` for each event we're - # going to pull from the DB. We use a single deferred that resolves - # to all the events we pulled from the DB (this will result in this - # function returning more events than requested, but that can happen - # already due to `_get_events_from_db`). - fetching_deferred: ObservableDeferred[ - Dict[str, EventCacheEntry] - ] = ObservableDeferred(defer.Deferred(), consumeErrors=True) - for event_id in missing_events_ids: - self._current_event_fetches[event_id] = fetching_deferred - - # Note that _get_events_from_db is also responsible for turning db rows - # into FrozenEvents (via _get_event_from_row), which involves seeing if - # the events have been redacted, and if so pulling the redaction event out - # of the database to check it. - # - try: - missing_events = await self._get_events_from_db( - missing_events_ids, - ) - event_entry_map.update(missing_events) - except Exception as e: - with PreserveLoggingContext(): - fetching_deferred.errback(e) - raise e - finally: - # Ensure that we mark these events as no longer being fetched. + async def get_missing_events_from_db() -> Dict[str, EventCacheEntry]: + """Fetches the events in `missing_event_ids` from the database. + + Also creates entries in `self._current_event_fetches` to allow + concurrent `_get_events_from_cache_or_db` calls to reuse the same fetch. + """ + log_ctx = current_context() + log_ctx.record_event_fetch(len(missing_events_ids)) + + # Add entries to `self._current_event_fetches` for each event we're + # going to pull from the DB. We use a single deferred that resolves + # to all the events we pulled from the DB (this will result in this + # function returning more events than requested, but that can happen + # already due to `_get_events_from_db`). + fetching_deferred: ObservableDeferred[ + Dict[str, EventCacheEntry] + ] = ObservableDeferred(defer.Deferred(), consumeErrors=True) for event_id in missing_events_ids: - self._current_event_fetches.pop(event_id, None) + self._current_event_fetches[event_id] = fetching_deferred - with PreserveLoggingContext(): - fetching_deferred.callback(missing_events) + # Note that _get_events_from_db is also responsible for turning db rows + # into FrozenEvents (via _get_event_from_row), which involves seeing if + # the events have been redacted, and if so pulling the redaction event + # out of the database to check it. + # + try: + missing_events = await self._get_events_from_db( + missing_events_ids, + ) + except Exception as e: + with PreserveLoggingContext(): + fetching_deferred.errback(e) + raise e + finally: + # Ensure that we mark these events as no longer being fetched. + for event_id in missing_events_ids: + self._current_event_fetches.pop(event_id, None) + + with PreserveLoggingContext(): + fetching_deferred.callback(missing_events) + + return missing_events + + # We must allow the database fetch to complete in the presence of + # cancellations, since multiple `_get_events_from_cache_or_db` calls can + # reuse the same fetch. + missing_events: Dict[str, EventCacheEntry] = await delay_cancellation( + get_missing_events_from_db() + ) + event_entry_map.update(missing_events) if already_fetching_deferreds: # Wait for the other event requests to finish and add their results diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py index 1f6a9eb07bfc..bf6374f93d52 100644 --- a/tests/storage/databases/main/test_events_worker.py +++ b/tests/storage/databases/main/test_events_worker.py @@ -13,10 +13,11 @@ # limitations under the License. import json from contextlib import contextmanager -from typing import Generator +from typing import Generator, Tuple +from unittest import mock from twisted.enterprise.adbapi import ConnectionPool -from twisted.internet.defer import ensureDeferred +from twisted.internet.defer import CancelledError, Deferred, ensureDeferred from twisted.test.proto_helpers import MemoryReactor from synapse.api.room_versions import EventFormatVersions, RoomVersions @@ -281,3 +282,119 @@ def test_recovery(self) -> None: # This next event fetch should succeed self.get_success(self.store.get_event(self.event_ids[0])) + + +class GetEventCancellationTestCase(unittest.HomeserverTestCase): + """Test cancellation of `get_event` calls.""" + + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer): + self.store: EventsWorkerStore = hs.get_datastores().main + + self.user = self.register_user("user", "pass") + self.token = self.login(self.user, "pass") + + self.room = self.helper.create_room_as(self.user, tok=self.token) + + res = self.helper.send(self.room, tok=self.token) + self.event_id = res["event_id"] + + # Reset the event cache so the tests start with it empty + self.store._get_event_cache.clear() + + @contextmanager + def blocking_get_event_calls( + self, + ) -> Generator[ + Tuple["Deferred[None]", "Deferred[None]", "Deferred[None]"], None, None + ]: + """Starts two concurrent `get_event` calls for the same event. + + Both `get_event` calls will use the same database fetch, which will be blocked + at the time this function returns. + + Returns: + A tuple containing: + * A `Deferred` that unblocks the database fetch. + * A cancellable `Deferred` for the first `get_event` call. + * A cancellable `Deferred` for the second `get_event` call. + """ + # Patch `DatabasePool.runWithConnection` to block. + unblock: "Deferred[None]" = Deferred() + original_runWithConnection = self.store.db_pool.runWithConnection + + async def runWithConnection(*args, **kwargs): + await unblock + return await original_runWithConnection(*args, **kwargs) + + with mock.patch.object( + self.store.db_pool, + "runWithConnection", + new=runWithConnection, + ): + ctx1 = LoggingContext("get_event1") + ctx2 = LoggingContext("get_event2") + + async def get_event(ctx: LoggingContext) -> None: + with ctx: + await self.store.get_event(self.event_id) + + get_event1 = ensureDeferred(get_event(ctx1)) + get_event2 = ensureDeferred(get_event(ctx2)) + + # Both `get_event` calls ought to be blocked. + self.assertNoResult(get_event1) + self.assertNoResult(get_event2) + + yield unblock, get_event1, get_event2 + + # Confirm that the two `get_event` calls shared the same database fetch. + self.assertEqual(ctx1.get_resource_usage().evt_db_fetch_count, 1) + self.assertEqual(ctx2.get_resource_usage().evt_db_fetch_count, 0) + + def test_first_get_event_cancelled(self): + """Test cancellation of the first `get_event` call sharing a database fetch. + + The first `get_event` call is the one which initiates the fetch. We expect the + fetch to complete despite the cancellation. Furthermore, the first `get_event` + call must not abort before the fetch is complete, otherwise the fetch will be + using a finished logging context. + """ + with self.blocking_get_event_calls() as (unblock, get_event1, get_event2): + # Cancel the first `get_event` call. + get_event1.cancel() + # The first `get_event` call must not abort immediately, otherwise its + # logging context will be finished while it is still in use by the database + # fetch. + self.assertNoResult(get_event1) + # The second `get_event` call must not be cancelled. + self.assertNoResult(get_event2) + + # Unblock the database fetch. + unblock.callback(None) + # A `CancelledError` should be raised out of the first `get_event` call. + exc = self.get_failure(get_event1, CancelledError).value + self.assertIsInstance(exc, CancelledError) + # The second `get_event` call should complete successfully. + self.get_success(get_event2) + + def test_second_get_event_cancelled(self): + """Test cancellation of the second `get_event` call sharing a database fetch.""" + with self.blocking_get_event_calls() as (unblock, get_event1, get_event2): + # Cancel the second `get_event` call. + get_event2.cancel() + # The first `get_event` call must not be cancelled. + self.assertNoResult(get_event1) + # The second `get_event` call gets cancelled immediately. + exc = self.get_failure(get_event2, CancelledError).value + self.assertIsInstance(exc, CancelledError) + + # Unblock the database fetch. + unblock.callback(None) + # The first `get_event` call should complete successfully. + self.get_success(get_event1)