From 49de631413e83b05a0d122f5dafb32f9bf9a6308 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 14 Sep 2021 15:16:19 -0400 Subject: [PATCH 1/7] Convert _GetStateGroupDelta to attrs. --- synapse/storage/databases/state/store.py | 18 +++++++++--------- synapse/storage/state.py | 3 ++- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index f839c0c24f1c..c9549edc94c4 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -13,9 +13,10 @@ # limitations under the License. import logging -from collections import namedtuple from typing import Dict, Iterable, List, Optional, Set, Tuple +import attr + from synapse.api.constants import EventTypes from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool @@ -33,16 +34,15 @@ MAX_STATE_DELTA_HOPS = 100 -class _GetStateGroupDelta( - namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids")) -): +@attr.s(slots=True, frozen=True, auto_attribs=True) +class _GetStateGroupDelta: """Return type of get_state_group_delta that implements __len__, which lets - us use the itrable flag when caching + us use the iterable flag when caching """ + prev_group: Optional[int] + delta_ids: Optional[StateMap[str]] - __slots__ = [] - - def __len__(self): + def __len__(self) -> int: return len(self.delta_ids) if self.delta_ids else 0 @@ -110,7 +110,7 @@ async def get_state_group_delta(self, state_group): the old and the new. Returns: - (prev_group, delta_ids), where both may be None. + _GetStateGroupDelta containing prev_group and delta_ids, where both may be None. """ def _get_state_group_delta_txn(txn): diff --git a/synapse/storage/state.py b/synapse/storage/state.py index c76529cb5710..5e86befde430 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -377,7 +377,8 @@ async def get_state_group_delta( make up the delta between the old and new state groups. """ - return await self.stores.state.get_state_group_delta(state_group) + state_group_delta = await self.stores.state.get_state_group_delta(state_group) + return state_group_delta.prev_group, state_group_delta.delta_ids async def get_state_groups_ids( self, _room_id: str, event_ids: Iterable[str] From 926d76f612b73698787799c79cf19b15e6f4f6f8 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 14 Sep 2021 15:18:53 -0400 Subject: [PATCH 2/7] Add type hints to state store. --- synapse/storage/databases/state/store.py | 114 +++++++++++++++-------- synapse/util/caches/dictionary_cache.py | 4 +- 2 files changed, 77 insertions(+), 41 deletions(-) diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index c9549edc94c4..33685c43f3db 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -13,21 +13,28 @@ # limitations under the License. import logging -from typing import Dict, Iterable, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple import attr from synapse.api.constants import EventTypes from synapse.storage._base import SQLBaseStore -from synapse.storage.database import DatabasePool +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) from synapse.storage.databases.state.bg_updates import StateBackgroundUpdateStore from synapse.storage.state import StateFilter from synapse.storage.types import Cursor from synapse.storage.util.sequence import build_sequence_generator -from synapse.types import MutableStateMap, StateMap +from synapse.types import MutableStateMap, StateKey, StateMap from synapse.util.caches.descriptors import cached from synapse.util.caches.dictionary_cache import DictionaryCache +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -39,6 +46,7 @@ class _GetStateGroupDelta: """Return type of get_state_group_delta that implements __len__, which lets us use the iterable flag when caching """ + prev_group: Optional[int] delta_ids: Optional[StateMap[str]] @@ -49,7 +57,12 @@ def __len__(self) -> int: class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): """A data store for fetching/storing state groups.""" - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) # Originally the state store used a single DictionaryCache to cache the @@ -81,19 +94,21 @@ def __init__(self, database: DatabasePool, db_conn, hs): # We size the non-members cache to be smaller than the members cache as the # vast majority of state in Matrix (today) is member events. - self._state_group_cache = DictionaryCache( + self._state_group_cache: DictionaryCache[int, StateKey, str] = DictionaryCache( "*stateGroupCache*", # TODO: this hasn't been tuned yet 50000, ) - self._state_group_members_cache = DictionaryCache( + self._state_group_members_cache: DictionaryCache[ + int, StateKey, str + ] = DictionaryCache( "*stateGroupMembersCache*", 500000, ) - def get_max_state_group_txn(txn: Cursor): + def get_max_state_group_txn(txn: Cursor) -> int: txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups") - return txn.fetchone()[0] + return txn.fetchone()[0] # type: ignore self._state_group_seq_gen = build_sequence_generator( db_conn, @@ -105,7 +120,7 @@ def get_max_state_group_txn(txn: Cursor): ) @cached(max_entries=10000, iterable=True) - async def get_state_group_delta(self, state_group): + async def get_state_group_delta(self, state_group: int) -> _GetStateGroupDelta: """Given a state group try to return a previous group and a delta between the old and the new. @@ -113,7 +128,7 @@ async def get_state_group_delta(self, state_group): _GetStateGroupDelta containing prev_group and delta_ids, where both may be None. """ - def _get_state_group_delta_txn(txn): + def _get_state_group_delta_txn(txn: LoggingTransaction) -> _GetStateGroupDelta: prev_group = self.db_pool.simple_select_one_onecol_txn( txn, table="state_group_edges", @@ -168,19 +183,24 @@ async def _get_state_groups_from_groups( return results - def _get_state_for_group_using_cache(self, cache, group, state_filter): + def _get_state_for_group_using_cache( + self, + cache: DictionaryCache[int, StateKey, str], + group: int, + state_filter: StateFilter, + ) -> Tuple[MutableStateMap[str], bool]: """Checks if group is in cache. See `_get_state_for_groups` Args: - cache(DictionaryCache): the state group cache to use - group(int): The state group to lookup - state_filter (StateFilter): The state filter used to fetch state - from the database. + cache: the state group cache to use + group: The state group to lookup + state_filter: The state filter used to fetch state from the database. - Returns 2-tuple (`state_dict`, `got_all`). - `got_all` is a bool indicating if we successfully retrieved all - requests state from the cache, if False we need to query the DB for the - missing state. + Returns: + 2-tuple (`state_dict`, `got_all`). + `got_all` is a bool indicating if we successfully retrieved all + requests state from the cache, if False we need to query the DB for the + missing state. """ cache_entry = cache.get(group) state_dict_ids = cache_entry.value @@ -277,8 +297,11 @@ async def _get_state_for_groups( return state def _get_state_for_groups_using_cache( - self, groups: Iterable[int], cache: DictionaryCache, state_filter: StateFilter - ) -> Tuple[Dict[int, StateMap[str]], Set[int]]: + self, + groups: Iterable[int], + cache: DictionaryCache[int, StateKey, str], + state_filter: StateFilter, + ) -> Tuple[Dict[int, MutableStateMap[str]], Set[int]]: """Gets the state at each of a list of state groups, optionally filtering by type/state_key, querying from a specific cache. @@ -310,21 +333,21 @@ def _get_state_for_groups_using_cache( def _insert_into_cache( self, - group_to_state_dict, - state_filter, - cache_seq_num_members, - cache_seq_num_non_members, - ): + group_to_state_dict: Dict[int, StateMap[str]], + state_filter: StateFilter, + cache_seq_num_members: int, + cache_seq_num_non_members: int, + ) -> None: """Inserts results from querying the database into the relevant cache. Args: - group_to_state_dict (dict): The new entries pulled from database. + group_to_state_dict: The new entries pulled from database. Map from state group to state dict - state_filter (StateFilter): The state filter used to fetch state + state_filter: The state filter used to fetch state from the database. - cache_seq_num_members (int): Sequence number of member cache since + cache_seq_num_members: Sequence number of member cache since last lookup in cache - cache_seq_num_non_members (int): Sequence number of member cache since + cache_seq_num_non_members: Sequence number of member cache since last lookup in cache """ @@ -395,7 +418,7 @@ async def store_state_group( The state group ID """ - def _store_state_group_txn(txn): + def _store_state_group_txn(txn: LoggingTransaction) -> int: if current_state_ids is None: # AFAIK, this can never happen raise Exception("current_state_ids cannot be None") @@ -426,6 +449,8 @@ def _store_state_group_txn(txn): potential_hops = self._count_state_group_hops_txn(txn, prev_group) if prev_group and potential_hops < MAX_STATE_DELTA_HOPS: + assert delta_ids is not None + self.db_pool.simple_insert_txn( txn, table="state_group_edges", @@ -498,7 +523,7 @@ def _store_state_group_txn(txn): ) async def purge_unreferenced_state_groups( - self, room_id: str, state_groups_to_delete + self, room_id: str, state_groups_to_delete: Collection[int] ) -> None: """Deletes no longer referenced state groups and de-deltas any state groups that reference them. @@ -506,8 +531,7 @@ async def purge_unreferenced_state_groups( Args: room_id: The room the state groups belong to (must all be in the same room). - state_groups_to_delete (Collection[int]): Set of all state groups - to delete. + state_groups_to_delete: Set of all state groups to delete. """ await self.db_pool.runInteraction( @@ -517,7 +541,12 @@ async def purge_unreferenced_state_groups( state_groups_to_delete, ) - def _purge_unreferenced_state_groups(self, txn, room_id, state_groups_to_delete): + def _purge_unreferenced_state_groups( + self, + txn: LoggingTransaction, + room_id: str, + state_groups_to_delete: Collection[int], + ) -> None: logger.info( "[purge] found %i state groups to delete", len(state_groups_to_delete) ) @@ -605,12 +634,14 @@ async def get_previous_state_groups( return {row["state_group"]: row["prev_state_group"] for row in rows} - async def purge_room_state(self, room_id, state_groups_to_delete): + async def purge_room_state( + self, room_id: str, state_groups_to_delete: Collection[int] + ) -> None: """Deletes all record of a room from state tables Args: - room_id (str): - state_groups_to_delete (list[int]): State groups to delete + room_id: + state_groups_to_delete: State groups to delete """ await self.db_pool.runInteraction( @@ -620,7 +651,12 @@ async def purge_room_state(self, room_id, state_groups_to_delete): state_groups_to_delete, ) - def _purge_room_state_txn(self, txn, room_id, state_groups_to_delete): + def _purge_room_state_txn( + self, + txn: LoggingTransaction, + room_id: str, + state_groups_to_delete: Collection[int], + ) -> None: # first we have to delete the state groups states logger.info("[purge] removing %s from state_groups_state", room_id) diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py index ade088aae2ef..485ddb189373 100644 --- a/synapse/util/caches/dictionary_cache.py +++ b/synapse/util/caches/dictionary_cache.py @@ -130,7 +130,7 @@ def update( sequence: int, key: KT, value: Dict[DKT, DV], - fetched_keys: Optional[Set[DKT]] = None, + fetched_keys: Optional[Iterable[DKT]] = None, ) -> None: """Updates the entry in the cache @@ -155,7 +155,7 @@ def update( self._update_or_insert(key, value, fetched_keys) def _update_or_insert( - self, key: KT, value: Dict[DKT, DV], known_absent: Set[DKT] + self, key: KT, value: Dict[DKT, DV], known_absent: Iterable[DKT] ) -> None: # We pop and reinsert as we need to tell the cache the size may have # changed From 203985540fab551c835317059553c94edc76cab9 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 14 Sep 2021 15:35:07 -0400 Subject: [PATCH 3/7] Add type hints to state bg updates. --- mypy.ini | 1 + synapse/storage/databases/state/bg_updates.py | 52 +++++++++++++------ 2 files changed, 37 insertions(+), 16 deletions(-) diff --git a/mypy.ini b/mypy.ini index 60dadc478144..78600e8aebb6 100644 --- a/mypy.ini +++ b/mypy.ini @@ -60,6 +60,7 @@ files = synapse/storage/databases/main/session.py, synapse/storage/databases/main/stream.py, synapse/storage/databases/main/ui_auth.py, + synapse/storage/databases/state, synapse/storage/database.py, synapse/storage/engines, synapse/storage/keys.py, diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py index c2891cb07f11..d7c7344fb612 100644 --- a/synapse/storage/databases/state/bg_updates.py +++ b/synapse/storage/databases/state/bg_updates.py @@ -13,13 +13,20 @@ # limitations under the License. import logging -from typing import Optional +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union from synapse.storage._base import SQLBaseStore -from synapse.storage.database import DatabasePool +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) from synapse.storage.engines import PostgresEngine from synapse.storage.state import StateFilter +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -31,7 +38,9 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore): updates. """ - def _count_state_group_hops_txn(self, txn, state_group): + def _count_state_group_hops_txn( + self, txn: LoggingTransaction, state_group: int + ) -> int: """Given a state group, count how many hops there are in the tree. This is used to ensure the delta chains don't get too long. @@ -56,7 +65,7 @@ def _count_state_group_hops_txn(self, txn, state_group): else: # We don't use WITH RECURSIVE on sqlite3 as there are distributions # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) - next_group = state_group + next_group: Optional[int] = state_group count = 0 while next_group: @@ -73,11 +82,14 @@ def _count_state_group_hops_txn(self, txn, state_group): return count def _get_state_groups_from_groups_txn( - self, txn, groups, state_filter: Optional[StateFilter] = None + self, + txn: LoggingTransaction, + groups: List[int], + state_filter: Optional[StateFilter] = None, ): state_filter = state_filter or StateFilter.all() - results = {group: {} for group in groups} + results: Dict[int, Dict[Tuple[str, str], str]] = {group: {} for group in groups} where_clause, where_args = state_filter.make_sql_filter_clause() @@ -117,7 +129,7 @@ def _get_state_groups_from_groups_txn( """ for group in groups: - args = [group] + args: List[Union[int, str]] = [group] args.extend(where_args) txn.execute(sql % (where_clause,), args) @@ -131,7 +143,7 @@ def _get_state_groups_from_groups_txn( # We don't use WITH RECURSIVE on sqlite3 as there are distributions # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) for group in groups: - next_group = group + next_group: Optional[int] = group while next_group: # We did this before by getting the list of group ids, and @@ -182,7 +194,12 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" STATE_GROUPS_ROOM_INDEX_UPDATE_NAME = "state_groups_room_id_idx" - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_update_handler( self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, @@ -198,7 +215,9 @@ def __init__(self, database: DatabasePool, db_conn, hs): columns=["room_id"], ) - async def _background_deduplicate_state(self, progress, batch_size): + async def _background_deduplicate_state( + self, progress: dict, batch_size: int + ) -> int: """This background update will slowly deduplicate state by reencoding them as deltas. """ @@ -218,7 +237,7 @@ async def _background_deduplicate_state(self, progress, batch_size): ) max_group = rows[0][0] - def reindex_txn(txn): + def reindex_txn(txn: LoggingTransaction) -> Tuple[bool, int]: new_last_state_group = last_state_group for count in range(batch_size): txn.execute( @@ -251,7 +270,8 @@ def reindex_txn(txn): " WHERE id < ? AND room_id = ?", (state_group, room_id), ) - (prev_group,) = txn.fetchone() + # There will be a result due to the coalesce. + (prev_group,) = txn.fetchone() # type: ignore new_last_state_group = state_group if prev_group: @@ -340,14 +360,14 @@ def reindex_txn(txn): return result * BATCH_SIZE_SCALE_FACTOR - async def _background_index_state(self, progress, batch_size): - def reindex_txn(conn): + async def _background_index_state(self, progress: dict, batch_size: int) -> int: + def reindex_txn(conn: LoggingDatabaseConnection) -> None: conn.rollback() if isinstance(self.database_engine, PostgresEngine): # postgres insists on autocommit for the index conn.set_session(autocommit=True) try: - txn = conn.cursor() + txn = conn.LoggingTransaction() txn.execute( "CREATE INDEX CONCURRENTLY state_groups_state_type_idx" " ON state_groups_state(state_group, type, state_key)" @@ -356,7 +376,7 @@ def reindex_txn(conn): finally: conn.set_session(autocommit=False) else: - txn = conn.cursor() + txn = conn.LoggingTransaction() txn.execute( "CREATE INDEX state_groups_state_type_idx" " ON state_groups_state(state_group, type, state_key)" From e9f66f28a30453de4bacab7f9cdcf4796d37e86d Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 14 Sep 2021 15:40:04 -0400 Subject: [PATCH 4/7] Newsfragment --- changelog.d/10823.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/10823.misc diff --git a/changelog.d/10823.misc b/changelog.d/10823.misc new file mode 100644 index 000000000000..0532969900f8 --- /dev/null +++ b/changelog.d/10823.misc @@ -0,0 +1 @@ +Add type hints to the state database. From 8fd777c92256df93a8806ded02a4431d8ff0f8ef Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 15 Sep 2021 08:04:17 -0400 Subject: [PATCH 5/7] Fix calls to cursor. --- synapse/storage/databases/state/bg_updates.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py index d7c7344fb612..7a7799ae0c78 100644 --- a/synapse/storage/databases/state/bg_updates.py +++ b/synapse/storage/databases/state/bg_updates.py @@ -367,7 +367,7 @@ def reindex_txn(conn: LoggingDatabaseConnection) -> None: # postgres insists on autocommit for the index conn.set_session(autocommit=True) try: - txn = conn.LoggingTransaction() + txn = conn.cursor() txn.execute( "CREATE INDEX CONCURRENTLY state_groups_state_type_idx" " ON state_groups_state(state_group, type, state_key)" @@ -376,7 +376,7 @@ def reindex_txn(conn: LoggingDatabaseConnection) -> None: finally: conn.set_session(autocommit=False) else: - txn = conn.LoggingTransaction() + txn = conn.cursor() txn.execute( "CREATE INDEX state_groups_state_type_idx" " ON state_groups_state(state_group, type, state_key)" From d2cbda85e98335414b2f4d6b50d439c3eb6edf16 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 15 Sep 2021 08:24:01 -0400 Subject: [PATCH 6/7] Add a missing return type. --- synapse/storage/databases/state/bg_updates.py | 18 ++++++++++-------- synapse/storage/databases/state/store.py | 4 ++-- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py index 7a7799ae0c78..5bcef5b08607 100644 --- a/synapse/storage/databases/state/bg_updates.py +++ b/synapse/storage/databases/state/bg_updates.py @@ -13,7 +13,7 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( @@ -23,6 +23,7 @@ ) from synapse.storage.engines import PostgresEngine from synapse.storage.state import StateFilter +from synapse.types import MutableStateMap, StateMap if TYPE_CHECKING: from synapse.server import HomeServer @@ -86,10 +87,10 @@ def _get_state_groups_from_groups_txn( txn: LoggingTransaction, groups: List[int], state_filter: Optional[StateFilter] = None, - ): + ) -> Dict[int, StateMap[str]]: state_filter = state_filter or StateFilter.all() - results: Dict[int, Dict[Tuple[str, str], str]] = {group: {} for group in groups} + results: Dict[int, MutableStateMap[str]] = {group: {} for group in groups} where_clause, where_args = state_filter.make_sql_filter_clause() @@ -185,7 +186,8 @@ def _get_state_groups_from_groups_txn( allow_none=True, ) - return results + # The results shouldn't be considered mutable. + return cast(Dict[int, StateMap[str]], results) class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): @@ -281,15 +283,15 @@ def reindex_txn(txn: LoggingTransaction) -> Tuple[bool, int]: # otherwise read performance degrades. continue - prev_state = self._get_state_groups_from_groups_txn( + prev_state_by_group = self._get_state_groups_from_groups_txn( txn, [prev_group] ) - prev_state = prev_state[prev_group] + prev_state = prev_state_by_group[prev_group] - curr_state = self._get_state_groups_from_groups_txn( + curr_state_by_group = self._get_state_groups_from_groups_txn( txn, [state_group] ) - curr_state = curr_state[state_group] + curr_state = curr_state_by_group[state_group] if not set(prev_state.keys()) - set(curr_state.keys()): # We can only do a delta if the current has a strict super set diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index 33685c43f3db..c59e7e331e2c 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -575,8 +575,8 @@ def _purge_unreferenced_state_groups( # groups to non delta versions. for sg in remaining_state_groups: logger.info("[purge] de-delta-ing remaining state group %s", sg) - curr_state = self._get_state_groups_from_groups_txn(txn, [sg]) - curr_state = curr_state[sg] + curr_state_by_group = self._get_state_groups_from_groups_txn(txn, [sg]) + curr_state = curr_state_by_group[sg] self.db_pool.simple_delete_txn( txn, table="state_groups_state", keyvalues={"state_group": sg} From ae9ef6625a1a9ee9f88188dcf5670606afbef7b9 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 15 Sep 2021 09:22:10 -0400 Subject: [PATCH 7/7] Remove cast. --- synapse/storage/databases/state/bg_updates.py | 6 +++--- synapse/storage/databases/state/store.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py index 5bcef5b08607..eb1118d2cb20 100644 --- a/synapse/storage/databases/state/bg_updates.py +++ b/synapse/storage/databases/state/bg_updates.py @@ -13,7 +13,7 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast +from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Union from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( @@ -87,7 +87,7 @@ def _get_state_groups_from_groups_txn( txn: LoggingTransaction, groups: List[int], state_filter: Optional[StateFilter] = None, - ) -> Dict[int, StateMap[str]]: + ) -> Mapping[int, StateMap[str]]: state_filter = state_filter or StateFilter.all() results: Dict[int, MutableStateMap[str]] = {group: {} for group in groups} @@ -187,7 +187,7 @@ def _get_state_groups_from_groups_txn( ) # The results shouldn't be considered mutable. - return cast(Dict[int, StateMap[str]], results) + return results class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index c59e7e331e2c..f1e3a27e637e 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -169,7 +169,7 @@ async def _get_state_groups_from_groups( Returns: Dict of state group to state map. """ - results = {} + results: Dict[int, StateMap[str]] = {} chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)] for chunk in chunks: