This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Optimise _get_state_after_missing_prev_event: use /state #12040

Merged
merged 3 commits into from
Apr 1, 2022
1 change: 1 addition & 0 deletions changelog.d/12040.feature
@@ -0,0 +1 @@
+Optimise fetching large quantities of missing room state over federation.
43 changes: 39 additions & 4 deletions synapse/handlers/federation_event.py
@@ -897,10 +897,24 @@ async def _get_state_after_missing_prev_event(
         logger.debug("We are also missing %i auth events", len(missing_auth_events))

         missing_events = missing_desired_events | missing_auth_events
-        logger.debug("Fetching %i events from remote", len(missing_events))
-        await self._get_events_and_persist(
-            destination=destination, room_id=room_id, event_ids=missing_events
-        )
+
+        # Making an individual request for each of 1000s of events has a lot of
+        # overhead. On the other hand, we don't really want to fetch all of the events
+        # if we already have most of them.
+        #
+        # As an arbitrary heuristic, if we are missing more than 10% of the events, then
+        # we fetch the whole state.
+        #
+        # TODO: might it be better to have an API which lets us do an aggregate event
+        #   request
+        if (len(missing_events) * 10) >= len(auth_event_ids) + len(state_event_ids):
+            logger.debug("Requesting complete state from remote")
+            await self._get_state_and_persist(destination, room_id, event_id)
+        else:
+            logger.debug("Fetching %i events from remote", len(missing_events))
+            await self._get_events_and_persist(
+                destination=destination, room_id=room_id, event_ids=missing_events
+            )

         # we need to make sure we re-load from the database to get the rejected
         # state correct.
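
The threshold check in the hunk above can be read in isolation. The following is a minimal sketch, not code from the PR: `should_fetch_full_state` and its parameter names are hypothetical, and only the comparison itself is taken from the diff. Note that the comparison is inclusive, so missing exactly 10% of the events already selects the full `/state` request.

```python
def should_fetch_full_state(missing: int, auth_count: int, state_count: int) -> bool:
    """Hypothetical helper mirroring the comparison in the diff.

    Returns True when the missing events make up at least a tenth of the
    combined auth and state event IDs, in which case one bulk request is
    preferred over one request per event.
    """
    return missing * 10 >= auth_count + state_count


# 100 missing out of 1000 known IDs: exactly 10%, so fetch everything.
print(should_fetch_full_state(100, 200, 800))
# 99 missing out of 1000: below the threshold, so fetch events individually.
print(should_fetch_full_state(99, 200, 800))
```

Multiplying by 10 rather than dividing keeps the check in integer arithmetic and avoids a division by zero when both lists are empty.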
@@ -959,6 +973,27 @@ async def _get_state_after_missing_prev_event(

         return remote_state

+    async def _get_state_and_persist(
+        self, destination: str, room_id: str, event_id: str
+    ) -> None:
+        """Get the complete room state at a given event, and persist any new events
+        as outliers"""
+        room_version = await self._store.get_room_version(room_id)
+        auth_events, state_events = await self._federation_client.get_room_state(
+            destination, room_id, event_id=event_id, room_version=room_version
+        )
+        logger.info("/state returned %i events", len(auth_events) + len(state_events))
+
+        await self._auth_and_persist_outliers(
+            room_id, itertools.chain(auth_events, state_events)
+        )
+
+        # we also need the event itself.
+        if not await self._store.have_seen_events(room_id, event_id):
Contributor

I think you want

Suggested change
-        if not await self._store.have_seen_events(room_id, event_id):
+        if not await self._store.have_seen_events(room_id, (event_id,)):

Mypy can't spot this because a str is itself an Iterable[str], so passing a bare string where an Iterable[str] is expected still type-checks. @squahtx noted this pain point a few weeks ago, I think.

(Given that CI passed, are we missing some kind of test here?)
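
To illustrate the pitfall described above, here is a minimal sketch. It assumes a simplified, synchronous stand-in for the batch lookup: `SEEN` and this `have_seen_events` are hypothetical and are not Synapse's storage method, which is async and backed by the database.

```python
from typing import Iterable, Set

# Hypothetical set of event IDs we already hold.
SEEN = {"$abc"}


def have_seen_events(room_id: str, event_ids: Iterable[str]) -> Set[str]:
    """Return the subset of event_ids that are already known."""
    return {event_id for event_id in event_ids if event_id in SEEN}


# A bare string type-checks as Iterable[str], but it is iterated character by
# character, so the lookup checks "$", "a", "b" and "c" instead of "$abc".
wrong = have_seen_events("!room:example.org", "$abc")

# Wrapping the ID in a one-element tuple looks up the ID itself.
right = have_seen_events("!room:example.org", ("$abc",))

print(wrong, right)
```

In this sketch `wrong` comes back empty even though the event is known, so a `not` check on it would conclude that the event still has to be fetched.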

Member Author

Good spot, thank you.

I've fixed this by implementing have_seen_event and using that instead, which seems to have a lower potential for foot-shooting.

I've also added a unit test which catches this.

The failure mode here was relatively edge-casey (it requires us to have the event, but not its state), and benign (it resulted in a redundant call to _matrix/federation/event), which was why this wasn't being picked up in the tests before.
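
The failure mode and the fix can be sketched end to end. This is an illustration under stated assumptions, not the PR's implementation: `FakeStore` and `count_fetches` are hypothetical, and the `have_seen_event` shown here only mirrors the name mentioned above, not its real body.

```python
import asyncio
from typing import Iterable, List, Set


class FakeStore:
    """Hypothetical in-memory stand-in for the event store."""

    def __init__(self, seen: Iterable[str]) -> None:
        self._seen = set(seen)

    async def have_seen_events(self, room_id: str, event_ids: Iterable[str]) -> Set[str]:
        # Batch lookup: returns the subset of IDs we already have.
        return {e for e in event_ids if e in self._seen}

    async def have_seen_event(self, room_id: str, event_id: str) -> bool:
        # Single-event lookup: a bare string cannot be misread as a collection.
        return event_id in self._seen


async def count_fetches(use_single_event_check: bool) -> int:
    """Count remote fetches made for an event we already have."""
    store = FakeStore(seen=["$abc"])
    fetched: List[str] = []
    if use_single_event_check:
        seen = await store.have_seen_event("!room:example.org", "$abc")
    else:
        # The original bug: the string is iterated character by character.
        seen = bool(await store.have_seen_events("!room:example.org", "$abc"))
    if not seen:
        fetched.append("$abc")  # stands in for the redundant remote request
    return len(fetched)


print(asyncio.run(count_fetches(False)), asyncio.run(count_fetches(True)))
```

With the batch call the event is fetched once more even though it is already stored; with the single-event check no fetch happens.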

+            await self._get_events_and_persist(
+                destination=destination, room_id=room_id, event_ids=(event_id,)
+            )
+
     async def _process_received_pdu(
         self,
         origin: str,