diff --git a/changelog.d/15617.feature b/changelog.d/15617.feature new file mode 100644 index 000000000000..092d5f483147 --- /dev/null +++ b/changelog.d/15617.feature @@ -0,0 +1 @@ +Make `/messages` faster by efficiently grabbing state out of database whenever we have to backfill and process new events. diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py index 097dea51828c..6826e9676187 100644 --- a/synapse/storage/databases/state/bg_updates.py +++ b/synapse/storage/databases/state/bg_updates.py @@ -13,7 +13,7 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Union +from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Set, Tuple, Union from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( @@ -89,6 +89,18 @@ def _get_state_groups_from_groups_txn( groups: List[int], state_filter: Optional[StateFilter] = None, ) -> Mapping[int, StateMap[str]]: + """ + Given a number of state groups, fetch the latest state for each group. + + Args: + txn: The transaction object. + groups: The given state groups that you want to fetch the latest state for. + state_filter: The state filter to apply to the state we fetch from the database. + + Returns: + Map from state_group to a StateMap at that point. + """ + state_filter = state_filter or StateFilter.all() results: Dict[int, MutableStateMap[str]] = {group: {} for group in groups} @@ -98,24 +110,49 @@ def _get_state_groups_from_groups_txn( # a temporary hack until we can add the right indices in txn.execute("SET LOCAL enable_seqscan=off") - # The below query walks the state_group tree so that the "state" + # The query below walks the state_group tree so that the "state" # table includes all state_groups in the tree. It then joins # against `state_groups_state` to fetch the latest state. # It assumes that previous state groups are always numerically # lesser. 
- # This may return multiple rows per (type, state_key), but last_value - # should be the same. sql = """ - WITH RECURSIVE sgs(state_group) AS ( - VALUES(?::bigint) + WITH RECURSIVE sgs(state_group, state_group_reached) AS ( + VALUES(?::bigint, NULL::bigint) UNION ALL - SELECT prev_state_group FROM state_group_edges e, sgs s - WHERE s.state_group = e.state_group + SELECT + prev_state_group, + CASE + /* Specify state_groups we have already done the work for */ + WHEN @prev_state_group IN (%s /* state_groups_we_have_already_fetched_string */) THEN prev_state_group + ELSE NULL + END AS state_group_reached + FROM + state_group_edges e, sgs s + WHERE + s.state_group = e.state_group + /* Stop when we connect up to another state_group that we already did the work for */ + AND s.state_group_reached IS NULL ) - %s + %s /* overall_select_clause */ """ overall_select_query_args: List[Union[int, str]] = [] + # Make sure we always have a row that tells us if we linked up to another + # state_group chain that we already processed (indicated by + # `state_group_reached`) regardless of whether we find any state according + # to the state_filter. + # + # We use a `UNION ALL` to make sure it is always the first row returned. + # `UNION` will merge and sort in with the rows from the next query + # otherwise. + overall_select_clause = """ + ( + SELECT NULL, NULL, NULL, state_group_reached + FROM sgs + ORDER BY state_group ASC + LIMIT 1 + ) UNION ALL (%s /* main_select_clause */) + """ # This is an optimization to create a select clause per-condition. 
This # makes the query planner a lot smarter on what rows should pull out in the @@ -154,7 +191,7 @@ def _get_state_groups_from_groups_txn( f""" ( SELECT DISTINCT ON (type, state_key) - type, state_key, event_id + type, state_key, event_id, state_group FROM state_groups_state INNER JOIN sgs USING (state_group) WHERE {where_clause} @@ -163,7 +200,7 @@ def _get_state_groups_from_groups_txn( """ ) - overall_select_clause = " UNION ".join(select_clause_list) + main_select_clause = " UNION ".join(select_clause_list) else: where_clause, where_args = state_filter.make_sql_filter_clause() # Unless the filter clause is empty, we're going to append it after an @@ -173,9 +210,9 @@ def _get_state_groups_from_groups_txn( overall_select_query_args.extend(where_args) - overall_select_clause = f""" + main_select_clause = f""" SELECT DISTINCT ON (type, state_key) - type, state_key, event_id + type, state_key, event_id, state_group FROM state_groups_state WHERE state_group IN ( SELECT state_group FROM sgs @@ -183,15 +220,73 @@ def _get_state_groups_from_groups_txn( ORDER BY type, state_key, state_group DESC """ - for group in groups: + # We can sort from least to greatest state_group and re-use the work from a + # lesser state_group for a greater one if we see that the edge chain links + # up. + # + # What this means in practice is that if we fetch the latest state for + # `state_group = 20`, and then we want `state_group = 30`, it will traverse + # down the edge chain to `20`, see that we linked up to `20` and bail out + # early and re-use the work we did for `20`. This can have massive savings + # in rooms like Matrix HQ where the edge chain is 88k events long and + # fetching the mostly-same chain over and over isn't very efficient. 
+ sorted_groups = sorted(groups) + state_groups_we_have_already_fetched: Set[int] = { + # We default to `[-1]` just to fill in the query with something that + # will have no effect but not bork our query when it would be empty + # otherwise + -1 + } + for group in sorted_groups: args: List[Union[int, str]] = [group] + args.extend(state_groups_we_have_already_fetched) args.extend(overall_select_query_args) - txn.execute(sql % (overall_select_clause,), args) + state_groups_we_have_already_fetched_string = ", ".join( + ["?::bigint"] * len(state_groups_we_have_already_fetched) + ) + + txn.execute( + sql + % ( + state_groups_we_have_already_fetched_string, + overall_select_clause % (main_select_clause,), + ), + args, + ) + + # The first row is always our special `state_group_reached` row which + # tells us if we linked up to any other existing state_group that we + # already fetched and if so, which one we linked up to (see the `UNION + # ALL` above which drives this special row) + first_row = txn.fetchone() + if first_row: + _, _, _, state_group_reached = first_row + + partial_state_map_for_state_group: MutableStateMap[str] = {} for row in txn: - typ, state_key, event_id = row + typ, state_key, event_id, _state_group = row key = (intern_string(typ), intern_string(state_key)) - results[group][key] = event_id + partial_state_map_for_state_group[key] = event_id + + # If we see a state_group edge link to a previous state_group that we + # already fetched from the database, link up the base state to the + # partial state we retrieved from the database to build on top of. 
+ if state_group_reached in results: + resultant_state_map = dict(results[state_group_reached]) + resultant_state_map.update(partial_state_map_for_state_group) + + results[group] = resultant_state_map + else: + # It's also completely normal for us not to have a previous + # state_group to build on top of if this is the first group being + # processed or we are processing a bunch of groups from different + # rooms which of course will never link together (completely + # different DAGs). + results[group] = partial_state_map_for_state_group + + state_groups_we_have_already_fetched.add(group) + else: max_entries_returned = state_filter.max_entries_returned() @@ -201,8 +296,9 @@ def _get_state_groups_from_groups_txn( if where_clause: where_clause = " AND (%s)" % (where_clause,) - # We don't use WITH RECURSIVE on sqlite3 as there are distributions - # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) + # XXX: We could `WITH RECURSIVE` here since it's supported on SQLite 3.8.3 + # or higher and our minimum supported version is greater than that. We just + # haven't put in the time to refactor this. for group in groups: next_group: Optional[int] = group