From 98125bba7a63f34bf623fdef3902f2e4ab7c1231 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 26 Aug 2020 09:59:37 -0400 Subject: [PATCH 01/39] Allow running mypy directly. (#8175) --- changelog.d/8175.misc | 1 + mypy.ini | 49 ++++++++++++++++++++++++++++++++++++++++ tox.ini | 52 +------------------------------------------ 3 files changed, 51 insertions(+), 51 deletions(-) create mode 100644 changelog.d/8175.misc diff --git a/changelog.d/8175.misc b/changelog.d/8175.misc new file mode 100644 index 000000000000..28af294dcf6c --- /dev/null +++ b/changelog.d/8175.misc @@ -0,0 +1 @@ +Standardize the mypy configuration. diff --git a/mypy.ini b/mypy.ini index c69cb5dc4064..4213e31b0320 100644 --- a/mypy.ini +++ b/mypy.ini @@ -6,6 +6,55 @@ check_untyped_defs = True show_error_codes = True show_traceback = True mypy_path = stubs +files = + synapse/api, + synapse/appservice, + synapse/config, + synapse/event_auth.py, + synapse/events/builder.py, + synapse/events/spamcheck.py, + synapse/federation, + synapse/handlers/auth.py, + synapse/handlers/cas_handler.py, + synapse/handlers/directory.py, + synapse/handlers/federation.py, + synapse/handlers/identity.py, + synapse/handlers/message.py, + synapse/handlers/oidc_handler.py, + synapse/handlers/presence.py, + synapse/handlers/room.py, + synapse/handlers/room_member.py, + synapse/handlers/room_member_worker.py, + synapse/handlers/saml_handler.py, + synapse/handlers/sync.py, + synapse/handlers/ui_auth, + synapse/http/server.py, + synapse/http/site.py, + synapse/logging/, + synapse/metrics, + synapse/module_api, + synapse/notifier.py, + synapse/push/pusherpool.py, + synapse/push/push_rule_evaluator.py, + synapse/replication, + synapse/rest, + synapse/server.py, + synapse/server_notices, + synapse/spam_checker_api, + synapse/state, + synapse/storage/databases/main/ui_auth.py, + synapse/storage/database.py, + synapse/storage/engines, + synapse/storage/state.py, + synapse/storage/util, + synapse/streams, + synapse/types.py, + synapse/util/caches/stream_change_cache.py, + synapse/util/metrics.py, + tests/replication, + tests/test_utils, + tests/rest/client/v2_alpha/test_auth.py, + tests/util/test_stream_change_cache.py [mypy-pymacaroons.*] ignore_missing_imports = True diff --git a/tox.ini b/tox.ini index edeb757f7b60..df473bd234a6 100644 --- a/tox.ini +++ b/tox.ini @@ -171,58 +171,8 @@ deps = {[base]deps} mypy==0.782 mypy-zope -env = - MYPYPATH = stubs/ extras = all -commands = mypy \ - synapse/api \ - synapse/appservice \ - synapse/config \ - synapse/event_auth.py \ - synapse/events/builder.py \ - synapse/events/spamcheck.py \ - synapse/federation \ - synapse/handlers/auth.py \ - synapse/handlers/cas_handler.py \ - synapse/handlers/directory.py \ - synapse/handlers/federation.py \ - synapse/handlers/identity.py \ - synapse/handlers/message.py \ - synapse/handlers/oidc_handler.py \ - synapse/handlers/presence.py \ - synapse/handlers/room.py \ - synapse/handlers/room_member.py \ - synapse/handlers/room_member_worker.py \ - synapse/handlers/saml_handler.py \ - synapse/handlers/sync.py \ - synapse/handlers/ui_auth \ - synapse/http/server.py \ - synapse/http/site.py \ - synapse/logging/ \ - synapse/metrics \ - synapse/module_api \ - synapse/notifier.py \ - synapse/push/pusherpool.py \ - synapse/push/push_rule_evaluator.py \ - synapse/replication \ - synapse/rest \ - synapse/server.py \ - synapse/server_notices \ - synapse/spam_checker_api \ - synapse/state \ - synapse/storage/databases/main/ui_auth.py \ - synapse/storage/database.py \ - synapse/storage/engines \ - synapse/storage/state.py \ - synapse/storage/util \ - synapse/streams \ - synapse/types.py \ - synapse/util/caches/stream_change_cache.py \ - synapse/util/metrics.py \ - tests/replication \ - tests/test_utils \ - tests/rest/client/v2_alpha/test_auth.py \ - tests/util/test_stream_change_cache.py +commands = mypy # To find all folders that pass mypy you run: # From ed18f32e1b7bf734303e040400a2da2e27501154 Mon Sep 17 00:00:00 2001 From: Christopher May-Townsend Date: Wed, 26 Aug 2020 15:03:20 +0100 Subject: [PATCH 02/39] Add required Debian dependencies to allow docker builds on the arm platform (#8144) Signed-off-by: Christopher May-Townsend --- changelog.d/8144.docker | 1 + docker/Dockerfile | 13 ++++++++++--- 2 files changed, 11 insertions(+), 3 deletions(-) create mode 100644 changelog.d/8144.docker diff --git a/changelog.d/8144.docker b/changelog.d/8144.docker new file mode 100644 index 000000000000..9bb5881fa8f3 --- /dev/null +++ b/changelog.d/8144.docker @@ -0,0 +1 @@ +Fix builds of the Docker image on non-x86 platforms. diff --git a/docker/Dockerfile b/docker/Dockerfile index 432d56a8ee11..27512f860092 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -19,11 +19,16 @@ ARG PYTHON_VERSION=3.7 FROM docker.io/python:${PYTHON_VERSION}-slim as builder # install the OS build deps - - RUN apt-get update && apt-get install -y \ build-essential \ + libffi-dev \ + libjpeg-dev \ libpq-dev \ + libssl-dev \ + libwebp-dev \ + libxml++2.6-dev \ + libxslt1-dev \ + zlib1g-dev \ && rm -rf /var/lib/apt/lists/* # Build dependencies that are not available as wheels, to speed up rebuilds @@ -56,9 +61,11 @@ FROM docker.io/python:${PYTHON_VERSION}-slim RUN apt-get update && apt-get install -y \ curl \ + gosu \ + libjpeg62-turbo \ libpq5 \ + libwebp6 \ xmlsec1 \ - gosu \ && rm -rf /var/lib/apt/lists/* COPY --from=builder /install /usr/local From e0d6244beb0165417a2817f8b36c828ad22f8dbd Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Wed, 26 Aug 2020 15:07:35 +0100 Subject: [PATCH 03/39] Remove unused parameter from, and add safeguard in, get_room_data (#8174) Small cleanup PR. * Removed the unused `is_guest` argument * Added a safeguard to a (currently) impossible code path, fixing static checking at the same time. --- changelog.d/8174.misc | 1 + synapse/handlers/message.py | 20 ++++++++++++-------- synapse/rest/client/v1/room.py | 1 - tests/rest/client/test_retention.py | 2 +- 4 files changed, 14 insertions(+), 10 deletions(-) create mode 100644 changelog.d/8174.misc diff --git a/changelog.d/8174.misc b/changelog.d/8174.misc new file mode 100644 index 000000000000..a39e9eab46cd --- /dev/null +++ b/changelog.d/8174.misc @@ -0,0 +1 @@ +Remove unused `is_guest` parameter from, and add safeguard to, `MessageHandler.get_room_data`. \ No newline at end of file diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 02d624268bee..9d0c38f4df02 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -95,12 +95,7 @@ def __init__(self, hs): ) async def get_room_data( - self, - user_id: str, - room_id: str, - event_type: str, - state_key: str, - is_guest: bool, + self, user_id: str, room_id: str, event_type: str, state_key: str, ) -> dict: """ Get data from a room. @@ -109,11 +104,10 @@ async def get_room_data( room_id event_type state_key - is_guest Returns: The path data content. Raises: - SynapseError if something went wrong. + SynapseError or AuthError if the user is not in the room """ ( membership, @@ -130,6 +124,16 @@ async def get_room_data( [membership_event_id], StateFilter.from_types([key]) ) data = room_state[membership_event_id].get(key) + else: + # check_user_in_room_or_world_readable, if it doesn't raise an AuthError, should + # only ever return a Membership.JOIN/LEAVE object + # + # Safeguard in case it returned something else + logger.error( + "Attempted to retrieve data from a room for a user that has never been in it. " + "This should not have happened." + ) + raise SynapseError(403, "User not in room", errcode=Codes.FORBIDDEN) return data diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 11da8bc0371f..88245fc177dd 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -171,7 +171,6 @@ async def on_GET(self, request, room_id, event_type, state_key): room_id=room_id, event_type=event_type, state_key=state_key, - is_guest=requester.is_guest, ) if not data: diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py index d4e7fa129334..7d3773ff7835 100644 --- a/tests/rest/client/test_retention.py +++ b/tests/rest/client/test_retention.py @@ -178,7 +178,7 @@ def _test_retention_event_purged(self, room_id: str, increment: float): message_handler = self.hs.get_message_handler() create_event = self.get_success( message_handler.get_room_data( - self.user_id, room_id, EventTypes.Create, state_key="", is_guest=False + self.user_id, room_id, EventTypes.Create, state_key="" ) ) From 6fe12c9512aa480a17d477c46f748a1b63beb539 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 26 Aug 2020 12:05:36 -0400 Subject: [PATCH 04/39] Do not propagate typing notifications from shadow-banned users. (#8176) --- changelog.d/8176.feature | 1 + synapse/handlers/typing.py | 21 ++++++-- synapse/rest/client/v1/room.py | 26 +++++----- tests/handlers/test_typing.py | 26 +++++++--- .../test_federation_sender_shard.py | 4 +- tests/rest/client/test_shadow_banned.py | 48 +++++++++++++++++++ 6 files changed, 102 insertions(+), 24 deletions(-) create mode 100644 changelog.d/8176.feature diff --git a/changelog.d/8176.feature b/changelog.d/8176.feature new file mode 100644 index 000000000000..813e6d0903d9 --- /dev/null +++ b/changelog.d/8176.feature @@ -0,0 +1 @@ +Add support for shadow-banning users (ignoring any message send requests). diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index a86ac0150e05..1d828bd7be16 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -14,10 +14,11 @@ # limitations under the License. import logging +import random from collections import namedtuple from typing import TYPE_CHECKING, List, Set, Tuple -from synapse.api.errors import AuthError, SynapseError +from synapse.api.errors import AuthError, ShadowBanError, SynapseError from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.tcp.streams import TypingStream from synapse.types import UserID, get_domain_from_id @@ -227,9 +228,9 @@ def _handle_timeout_for_member(self, now: int, member: RoomMember): self._stopped_typing(member) return - async def started_typing(self, target_user, auth_user, room_id, timeout): + async def started_typing(self, target_user, requester, room_id, timeout): target_user_id = target_user.to_string() - auth_user_id = auth_user.to_string() + auth_user_id = requester.user.to_string() if not self.is_mine_id(target_user_id): raise SynapseError(400, "User is not hosted on this homeserver") @@ -237,6 +238,11 @@ async def started_typing(self, target_user, auth_user, room_id, timeout): if target_user_id != auth_user_id: raise AuthError(400, "Cannot set another user's typing state") + if requester.shadow_banned: + # We randomly sleep a bit just to annoy the requester. + await self.clock.sleep(random.randint(1, 10)) + raise ShadowBanError() + await self.auth.check_user_in_room(room_id, target_user_id) logger.debug("%s has started typing in %s", target_user_id, room_id) @@ -256,9 +262,9 @@ async def started_typing(self, target_user, auth_user, room_id, timeout): self._push_update(member=member, typing=True) - async def stopped_typing(self, target_user, auth_user, room_id): + async def stopped_typing(self, target_user, requester, room_id): target_user_id = target_user.to_string() - auth_user_id = auth_user.to_string() + auth_user_id = requester.user.to_string() if not self.is_mine_id(target_user_id): raise SynapseError(400, "User is not hosted on this homeserver") @@ -266,6 +272,11 @@ async def stopped_typing(self, target_user, auth_user, room_id): if target_user_id != auth_user_id: raise AuthError(400, "Cannot set another user's typing state") + if requester.shadow_banned: + # We randomly sleep a bit just to annoy the requester. + await self.clock.sleep(random.randint(1, 10)) + raise ShadowBanError() + await self.auth.check_user_in_room(room_id, target_user_id) logger.debug("%s has stopped typing in %s", target_user_id, room_id) diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 88245fc177dd..84baf3d59bca 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -868,17 +868,21 @@ async def on_PUT(self, request, room_id, user_id): # Limit timeout to stop people from setting silly typing timeouts. timeout = min(content.get("timeout", 30000), 120000) - if content["typing"]: - await self.typing_handler.started_typing( - target_user=target_user, - auth_user=requester.user, - room_id=room_id, - timeout=timeout, - ) - else: - await self.typing_handler.stopped_typing( - target_user=target_user, auth_user=requester.user, room_id=room_id - ) + try: + if content["typing"]: + await self.typing_handler.started_typing( + target_user=target_user, + requester=requester, + room_id=room_id, + timeout=timeout, + ) + else: + await self.typing_handler.stopped_typing( + target_user=target_user, requester=requester, room_id=room_id + ) + except ShadowBanError: + # Pretend this worked without error. + pass return 200, {} diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 834b4a0af62b..81c1839637eb 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -21,7 +21,7 @@ from twisted.internet import defer from synapse.api.errors import AuthError -from synapse.types import UserID +from synapse.types import UserID, create_requester from tests import unittest from tests.test_utils import make_awaitable @@ -167,7 +167,10 @@ def test_started_typing_local(self): self.get_success( self.handler.started_typing( - target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=20000 + target_user=U_APPLE, + requester=create_requester(U_APPLE), + room_id=ROOM_ID, + timeout=20000, ) ) @@ -194,7 +197,10 @@ def test_started_typing_remote_send(self): self.get_success( self.handler.started_typing( - target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=20000 + target_user=U_APPLE, + requester=create_requester(U_APPLE), + room_id=ROOM_ID, + timeout=20000, ) ) @@ -269,7 +275,9 @@ def test_stopped_typing(self): self.get_success( self.handler.stopped_typing( - target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID + target_user=U_APPLE, + requester=create_requester(U_APPLE), + room_id=ROOM_ID, ) ) @@ -309,7 +317,10 @@ def test_typing_timeout(self): self.get_success( self.handler.started_typing( - target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=10000 + target_user=U_APPLE, + requester=create_requester(U_APPLE), + room_id=ROOM_ID, + timeout=10000, ) ) @@ -348,7 +359,10 @@ def test_typing_timeout(self): self.get_success( self.handler.started_typing( - target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=10000 + target_user=U_APPLE, + requester=create_requester(U_APPLE), + room_id=ROOM_ID, + timeout=10000, ) ) diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py index 83f9aa291c67..8b4982ecb160 100644 --- a/tests/replication/test_federation_sender_shard.py +++ b/tests/replication/test_federation_sender_shard.py @@ -20,7 +20,7 @@ from synapse.events.builder import EventBuilderFactory from synapse.rest.admin import register_servlets_for_client_rest_resource from synapse.rest.client.v1 import login, room -from synapse.types import UserID +from synapse.types import UserID, create_requester from tests.replication._base import BaseMultiWorkerStreamTestCase from tests.test_utils import make_awaitable @@ -175,7 +175,7 @@ def test_send_typing_sharded(self): self.get_success( typing_handler.started_typing( target_user=UserID.from_string(user), - auth_user=UserID.from_string(user), + requester=create_requester(user), room_id=room, timeout=20000, ) diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py index 3eb9aeaa9eab..0c48a9fd5e66 100644 --- a/tests/rest/client/test_shadow_banned.py +++ b/tests/rest/client/test_shadow_banned.py @@ -179,6 +179,54 @@ def test_upgrade(self): # The summary should be empty since the room doesn't exist. self.assertEqual(summary, {}) + def test_typing(self): + """Typing notifications should not be propagated into the room.""" + # The create works fine. + room_id = self.helper.create_room_as( + self.banned_user_id, tok=self.banned_access_token + ) + + request, channel = self.make_request( + "PUT", + "/rooms/%s/typing/%s" % (room_id, self.banned_user_id), + {"typing": True, "timeout": 30000}, + access_token=self.banned_access_token, + ) + self.render(request) + self.assertEquals(200, channel.code) + + # There should be no typing events. + event_source = self.hs.get_event_sources().sources["typing"] + self.assertEquals(event_source.get_current_key(), 0) + + # The other user can join and send typing events. + self.helper.join(room_id, self.other_user_id, tok=self.other_access_token) + + request, channel = self.make_request( + "PUT", + "/rooms/%s/typing/%s" % (room_id, self.other_user_id), + {"typing": True, "timeout": 30000}, + access_token=self.other_access_token, + ) + self.render(request) + self.assertEquals(200, channel.code) + + # These appear in the room. + self.assertEquals(event_source.get_current_key(), 1) + events = self.get_success( + event_source.get_new_events(from_key=0, room_ids=[room_id]) + ) + self.assertEquals( + events[0], + [ + { + "type": "m.typing", + "room_id": room_id, + "content": {"user_ids": [self.other_user_id]}, + } + ], + ) + # To avoid the tests timing out don't add a delay to "annoy the requester". @patch("random.randint", new=lambda a, b: 0) From b8f20e4276ea23dabcc3882dcee5773f856c39d0 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Wed, 26 Aug 2020 17:26:56 +0100 Subject: [PATCH 05/39] Remove remaining is_guest argument uses from get_room_data calls (#8181) #8174 removed the `is_guest` parameter from `get_room_data`, at the same time that #8157 was merged using it, colliding together to break unit tests on develop. This PR removes the `is_guest` parameter from the call in the broken test. Uses the same changelog as #8174. --- changelog.d/8181.misc | 1 + tests/rest/client/test_shadow_banned.py | 12 ++---------- 2 files changed, 3 insertions(+), 10 deletions(-) create mode 100644 changelog.d/8181.misc diff --git a/changelog.d/8181.misc b/changelog.d/8181.misc new file mode 100644 index 000000000000..a39e9eab46cd --- /dev/null +++ b/changelog.d/8181.misc @@ -0,0 +1 @@ +Remove unused `is_guest` parameter from, and add safeguard to, `MessageHandler.get_room_data`. \ No newline at end of file diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py index 0c48a9fd5e66..dfe4bf7762e3 100644 --- a/tests/rest/client/test_shadow_banned.py +++ b/tests/rest/client/test_shadow_banned.py @@ -271,11 +271,7 @@ def test_displayname(self): message_handler = self.hs.get_message_handler() event = self.get_success( message_handler.get_room_data( - self.banned_user_id, - room_id, - "m.room.member", - self.banned_user_id, - False, + self.banned_user_id, room_id, "m.room.member", self.banned_user_id, ) ) self.assertEqual( @@ -308,11 +304,7 @@ def test_room_displayname(self): message_handler = self.hs.get_message_handler() event = self.get_success( message_handler.get_room_data( - self.banned_user_id, - room_id, - "m.room.member", - self.banned_user_id, - False, + self.banned_user_id, room_id, "m.room.member", self.banned_user_id, ) ) self.assertEqual( From cf2f6c3d22811314ecc9082574e6b6bef49ab696 Mon Sep 17 00:00:00 2001 From: Dexter Chua Date: Thu, 27 Aug 2020 17:39:13 +0800 Subject: [PATCH 06/39] Update debian systemd service to use Type=notify (#8169) This ensures systemctl start matrix-synapse returns only after synapse is actually started, which is very useful for automated deployments. Fixes #5761 Signed-off-by: Dexter Chua --- debian/changelog | 6 ++++++ debian/matrix-synapse.service | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/debian/changelog b/debian/changelog index bdaf59e9b948..d20ca06b5bc5 100644 --- a/debian/changelog +++ b/debian/changelog @@ -1,3 +1,9 @@ +matrix-synapse-py3 (1.19.0ubuntu1) UNRELEASED; urgency=medium + + * Use Type=notify in systemd service + + -- Dexter Chua Wed, 26 Aug 2020 12:41:36 +0000 + matrix-synapse-py3 (1.19.0) stable; urgency=medium [ Synapse Packaging team ] diff --git a/debian/matrix-synapse.service b/debian/matrix-synapse.service index b0a8d72e6d25..553babf5492d 100644 --- a/debian/matrix-synapse.service +++ b/debian/matrix-synapse.service @@ -2,7 +2,7 @@ Description=Synapse Matrix homeserver [Service] -Type=simple +Type=notify User=matrix-synapse WorkingDirectory=/var/lib/matrix-synapse EnvironmentFile=/etc/default/matrix-synapse From eadfda3ebc165317c634948826d50794b11dad46 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Thu, 27 Aug 2020 10:50:39 +0100 Subject: [PATCH 07/39] 1.19.1 --- CHANGES.md | 6 ++++++ debian/changelog | 6 ++++++ synapse/__init__.py | 2 +- 3 files changed, 13 insertions(+), 1 deletion(-) diff --git a/CHANGES.md b/CHANGES.md index 688f824da1b4..d859baa9ff56 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,3 +1,9 @@ +Synapse 1.19.1 (2020-08-27) +=========================== + +No significant changes. + + Synapse 1.19.1rc1 (2020-08-25) ============================== diff --git a/debian/changelog b/debian/changelog index bdaf59e9b948..6676706dea12 100644 --- a/debian/changelog +++ b/debian/changelog @@ -1,3 +1,9 @@ +matrix-synapse-py3 (1.19.1) stable; urgency=medium + + * New synapse release 1.19.1. + + -- Synapse Packaging team Thu, 27 Aug 2020 10:50:19 +0100 + matrix-synapse-py3 (1.19.0) stable; urgency=medium [ Synapse Packaging team ] diff --git a/synapse/__init__.py b/synapse/__init__.py index 2195723613c4..1282d19b3c74 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -48,7 +48,7 @@ except ImportError: pass -__version__ = "1.19.1rc1" +__version__ = "1.19.1" if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): # We import here so that we don't have to install a bunch of deps when From a466b6797219203cbf539c5c0bbdfe8ab5a41fa4 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Thu, 27 Aug 2020 11:39:53 +0100 Subject: [PATCH 08/39] Reduce run-times of tests by advancing the reactor less (#7757) --- changelog.d/7757.misc | 1 + tests/handlers/test_stats.py | 2 +- tests/http/test_fedclient.py | 2 +- tests/push/test_email.py | 2 +- tests/storage/test_cleanup_extrems.py | 2 +- tests/storage/test_roommember.py | 6 +++--- tests/util/test_retryutils.py | 2 +- 7 files changed, 9 insertions(+), 8 deletions(-) create mode 100644 changelog.d/7757.misc diff --git a/changelog.d/7757.misc b/changelog.d/7757.misc new file mode 100644 index 000000000000..091f40382e68 --- /dev/null +++ b/changelog.d/7757.misc @@ -0,0 +1 @@ +Reduce run times of some unit tests by advancing the reactor a fewer number of times. \ No newline at end of file diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py index 0e666492f629..88b05c23a051 100644 --- a/tests/handlers/test_stats.py +++ b/tests/handlers/test_stats.py @@ -256,7 +256,7 @@ def test_initial_earliest_token(self): # self.handler.notify_new_event() # We need to let the delta processor advanceā€¦ - self.pump(10 * 60) + self.reactor.advance(10 * 60) # Get the slices! There should be two -- day 1, and day 2. r = self.get_success(self.store.get_statistics_for_subject("room", room_1, 0)) diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py index ac598249e405..7f1dfb2128d7 100644 --- a/tests/http/test_fedclient.py +++ b/tests/http/test_fedclient.py @@ -508,6 +508,6 @@ def test_closes_connection(self): self.assertFalse(conn.disconnecting) # wait for a while - self.pump(120) + self.reactor.advance(120) self.assertTrue(conn.disconnecting) diff --git a/tests/push/test_email.py b/tests/push/test_email.py index 83032cc9eab6..227b0d32d047 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py @@ -170,7 +170,7 @@ def _check_for_mail(self): last_stream_ordering = pushers[0]["last_stream_ordering"] # Advance time a bit, so the pusher will register something has happened - self.pump(100) + self.pump(10) # It hasn't succeeded yet, so the stream ordering shouldn't have moved pushers = self.get_success( diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py index 43639ca28615..080761d1d2dc 100644 --- a/tests/storage/test_cleanup_extrems.py +++ b/tests/storage/test_cleanup_extrems.py @@ -271,7 +271,7 @@ def test_send_dummy_event(self): # Pump the reactor repeatedly so that the background updates have a # chance to run. - self.pump(10 * 60) + self.pump(20) latest_event_ids = self.get_success( self.store.get_latest_event_ids_in_room(self.room_id) diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index d98fe8754dab..12ccc1f53e99 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -87,7 +87,7 @@ def test_count_known_servers_stat_counter_disabled(self): self.inject_room_member(self.room, self.u_bob, Membership.JOIN) self.inject_room_member(self.room, self.u_charlie.to_string(), Membership.JOIN) - self.pump(20) + self.pump() self.assertTrue("_known_servers_count" not in self.store.__dict__.keys()) @@ -101,7 +101,7 @@ def test_count_known_servers_stat_counter_enabled(self): # Initialises to 1 -- itself self.assertEqual(self.store._known_servers_count, 1) - self.pump(20) + self.pump() # No rooms have been joined, so technically the SQL returns 0, but it # will still say it knows about itself. @@ -111,7 +111,7 @@ def test_count_known_servers_stat_counter_enabled(self): self.inject_room_member(self.room, self.u_bob, Membership.JOIN) self.inject_room_member(self.room, self.u_charlie.to_string(), Membership.JOIN) - self.pump(20) + self.pump(1) # It now knows about Charlie's server. self.assertEqual(self.store._known_servers_count, 2) diff --git a/tests/util/test_retryutils.py b/tests/util/test_retryutils.py index bc42ffce880c..5f46ed0cefd9 100644 --- a/tests/util/test_retryutils.py +++ b/tests/util/test_retryutils.py @@ -91,7 +91,7 @@ def test_limiter(self): # # one more go, with success # - self.pump(MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 2.0) + self.reactor.advance(MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 2.0) limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store)) self.pump(1) From 4a739c73b404284253a548f60197e70c6c385645 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 27 Aug 2020 07:08:38 -0400 Subject: [PATCH 09/39] Convert simple_update* and simple_select* to async (#8173) --- changelog.d/8173.misc | 1 + synapse/handlers/room.py | 6 +- synapse/storage/database.py | 29 +++++----- synapse/storage/databases/main/__init__.py | 8 +-- synapse/storage/databases/main/directory.py | 6 +- .../storage/databases/main/e2e_room_keys.py | 26 +++++---- .../databases/main/event_federation.py | 4 +- .../storage/databases/main/group_server.py | 55 +++++++++++-------- .../databases/main/media_repository.py | 16 +++--- synapse/storage/databases/main/profile.py | 18 ++++-- synapse/storage/databases/main/receipts.py | 8 ++- .../storage/databases/main/registration.py | 22 ++++---- synapse/storage/databases/main/room.py | 4 +- .../storage/databases/main/user_directory.py | 4 +- tests/handlers/test_stats.py | 4 +- tests/storage/test_base.py | 26 +++++---- tests/storage/test_directory.py | 6 +- tests/storage/test_main.py | 4 +- tests/test_federation.py | 50 +++++++---------- 19 files changed, 164 insertions(+), 133 deletions(-) create mode 100644 changelog.d/8173.misc diff --git a/changelog.d/8173.misc b/changelog.d/8173.misc new file mode 100644 index 000000000000..dfe4c03171d6 --- /dev/null +++ b/changelog.d/8173.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index e4788ef86b96..236a37f777c2 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -51,7 +51,7 @@ create_requester, ) from synapse.util import stringutils -from synapse.util.async_helpers import Linearizer, maybe_awaitable +from synapse.util.async_helpers import Linearizer from synapse.util.caches.response_cache import ResponseCache from synapse.visibility import filter_events_for_client @@ -1329,9 +1329,7 @@ async def shutdown_room( ratelimit=False, ) - aliases_for_room = await maybe_awaitable( - self.store.get_aliases_for_room(room_id) - ) + aliases_for_room = await self.store.get_aliases_for_room(room_id) await self.store.update_aliases_for_room( room_id, new_room_id, requester_user_id diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 181c3ec24994..38010af60041 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -1132,13 +1132,13 @@ def simple_select_onecol_txn( return [r[0] for r in txn] - def simple_select_onecol( + async def simple_select_onecol( self, table: str, keyvalues: Optional[Dict[str, Any]], retcol: str, desc: str = "simple_select_onecol", - ) -> defer.Deferred: + ) -> List[Any]: """Executes a SELECT query on the named table, which returns a list comprising of the values of the named column from the selected rows. @@ -1148,19 +1148,19 @@ def simple_select_onecol( retcol: column whos value we wish to retrieve. Returns: - Deferred: Results in a list + Results in a list """ - return self.runInteraction( + return await self.runInteraction( desc, self.simple_select_onecol_txn, table, keyvalues, retcol ) - def simple_select_list( + async def simple_select_list( self, table: str, keyvalues: Optional[Dict[str, Any]], retcols: Iterable[str], desc: str = "simple_select_list", - ) -> defer.Deferred: + ) -> List[Dict[str, Any]]: """Executes a SELECT query on the named table, which may return zero or more rows, returning the result as a list of dicts. @@ -1170,10 +1170,11 @@ def simple_select_list( column names and values to select the rows with, or None to not apply a WHERE clause. retcols: the names of the columns to return + Returns: - defer.Deferred: resolves to list[dict[str, Any]] + A list of dictionaries. """ - return self.runInteraction( + return await self.runInteraction( desc, self.simple_select_list_txn, table, keyvalues, retcols ) @@ -1299,14 +1300,14 @@ def simple_select_many_txn( txn.execute(sql, values) return cls.cursor_to_dict(txn) - def simple_update( + async def simple_update( self, table: str, keyvalues: Dict[str, Any], updatevalues: Dict[str, Any], desc: str, - ) -> defer.Deferred: - return self.runInteraction( + ) -> int: + return await self.runInteraction( desc, self.simple_update_txn, table, keyvalues, updatevalues ) @@ -1332,13 +1333,13 @@ def simple_update_txn( return txn.rowcount - def simple_update_one( + async def simple_update_one( self, table: str, keyvalues: Dict[str, Any], updatevalues: Dict[str, Any], desc: str = "simple_update_one", - ) -> defer.Deferred: + ) -> None: """Executes an UPDATE query on the named table, setting new values for columns in a row matching the key values. @@ -1347,7 +1348,7 @@ def simple_update_one( keyvalues: dict of column names and values to select the row with updatevalues: dict giving column names and values to update """ - return self.runInteraction( + await self.runInteraction( desc, self.simple_update_one_txn, table, keyvalues, updatevalues ) diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 0934ae276c2e..8b9b6eb47241 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -18,6 +18,7 @@ import calendar import logging import time +from typing import Any, Dict, List from synapse.api.constants import PresenceState from synapse.config.homeserver import HomeServerConfig @@ -476,14 +477,13 @@ def _generate_user_daily_visits(txn): "generate_user_daily_visits", _generate_user_daily_visits ) - def get_users(self): + async def get_users(self) -> List[Dict[str, Any]]: """Function to retrieve a list of users in users table. - Args: Returns: - defer.Deferred: resolves to list[dict[str, Any]] + A list of dictionaries representing users. """ - return self.db_pool.simple_select_list( + return await self.db_pool.simple_select_list( table="users", keyvalues={}, retcols=[ diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py index 301d5d845ac8..405b5eafa579 100644 --- a/synapse/storage/databases/main/directory.py +++ b/synapse/storage/databases/main/directory.py @@ -14,7 +14,7 @@ # limitations under the License. from collections import namedtuple -from typing import Iterable, Optional +from typing import Iterable, List, Optional from synapse.api.errors import SynapseError from synapse.storage._base import SQLBaseStore @@ -68,8 +68,8 @@ async def get_room_alias_creator(self, room_alias: str) -> str: ) @cached(max_entries=5000) - def get_aliases_for_room(self, room_id): - return self.db_pool.simple_select_onecol( + async def get_aliases_for_room(self, room_id: str) -> List[str]: + return await self.db_pool.simple_select_onecol( "room_aliases", {"room_id": room_id}, "room_alias", diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py index 46c3e33cc667..82f9d870fd06 100644 --- a/synapse/storage/databases/main/e2e_room_keys.py +++ b/synapse/storage/databases/main/e2e_room_keys.py @@ -14,6 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional + from synapse.api.errors import StoreError from synapse.logging.opentracing import log_kv, trace from synapse.storage._base import SQLBaseStore, db_to_json @@ -368,18 +370,22 @@ def _create_e2e_room_keys_version_txn(txn): ) @trace - def update_e2e_room_keys_version( - self, user_id, version, info=None, version_etag=None - ): + async def update_e2e_room_keys_version( + self, + user_id: str, + version: str, + info: Optional[dict] = None, + version_etag: Optional[int] = None, + ) -> None: """Update a given backup version Args: - user_id(str): the user whose backup version we're updating - version(str): the version ID of the backup version we're updating - info (dict): the new backup version info to store. If None, then - the backup version info is not updated - version_etag (Optional[int]): etag of the keys in the backup. If - None, then the etag is not updated + user_id: the user whose backup version we're updating + version: the version ID of the backup version we're updating + info: the new backup version info to store. If None, then the backup + version info is not updated. + version_etag: etag of the keys in the backup. If None, then the etag + is not updated. """ updatevalues = {} @@ -389,7 +395,7 @@ def update_e2e_room_keys_version( updatevalues["etag"] = version_etag if updatevalues: - return self.db_pool.simple_update( + await self.db_pool.simple_update( table="e2e_room_keys_versions", keyvalues={"user_id": user_id, "version": version}, updatevalues=updatevalues, diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index e6a97b018cfe..6e5761c7b75a 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -368,8 +368,8 @@ def _get_rooms_with_many_extremities_txn(txn): ) @cached(max_entries=5000, iterable=True) - def get_latest_event_ids_in_room(self, room_id): - return self.db_pool.simple_select_onecol( + async def get_latest_event_ids_in_room(self, room_id: str) -> List[str]: + return await self.db_pool.simple_select_onecol( table="event_forward_extremities", keyvalues={"room_id": room_id}, retcol="event_id", diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py index c39864f59f8d..e3ead71853ea 100644 --- a/synapse/storage/databases/main/group_server.py +++ b/synapse/storage/databases/main/group_server.py @@ -44,24 +44,26 @@ async def get_group(self, group_id: str) -> Optional[Dict[str, Any]]: desc="get_group", ) - def get_users_in_group(self, group_id, include_private=False): + async def get_users_in_group( + self, group_id: str, include_private: bool = False + ) -> List[Dict[str, Any]]: # TODO: Pagination keyvalues = {"group_id": group_id} if not include_private: keyvalues["is_public"] = True - return self.db_pool.simple_select_list( + return await self.db_pool.simple_select_list( table="group_users", keyvalues=keyvalues, retcols=("user_id", "is_public", "is_admin"), desc="get_users_in_group", ) - def get_invited_users_in_group(self, group_id): + async def get_invited_users_in_group(self, group_id: str) -> List[str]: # TODO: Pagination - return self.db_pool.simple_select_onecol( + return await self.db_pool.simple_select_onecol( table="group_invites", keyvalues={"group_id": group_id}, retcol="user_id", @@ -265,15 +267,14 @@ async def get_group_role(self, group_id, role_id): return role - def get_local_groups_for_room(self, room_id): + async def get_local_groups_for_room(self, room_id: str) -> List[str]: """Get all of the local group that contain a given room Args: - room_id (str): The ID of a room + room_id: The ID of a room Returns: - Deferred[list[str]]: A twisted.Deferred containing a list of group ids - containing this room + A list of group ids containing this room """ - return self.db_pool.simple_select_onecol( + return await self.db_pool.simple_select_onecol( table="group_rooms", keyvalues={"room_id": room_id}, retcol="group_id", @@ -422,10 +423,10 @@ def _get_users_membership_in_group_txn(txn): "get_users_membership_info_in_group", _get_users_membership_in_group_txn ) - def get_publicised_groups_for_user(self, user_id): + async def get_publicised_groups_for_user(self, user_id: str) -> List[str]: """Get all groups a user is publicising """ - return self.db_pool.simple_select_onecol( + return await self.db_pool.simple_select_onecol( table="local_group_membership", keyvalues={"user_id": user_id, "membership": "join", "is_publicised": True}, retcol="group_id", @@ -466,8 +467,8 @@ async def get_remote_attestation(self, group_id, user_id): return None - def get_joined_groups(self, user_id): - return self.db_pool.simple_select_onecol( + async def get_joined_groups(self, user_id: str) -> List[str]: + return await self.db_pool.simple_select_onecol( table="local_group_membership", keyvalues={"user_id": user_id, "membership": "join"}, retcol="group_id", @@ -585,14 +586,14 @@ def _get_all_groups_changes_txn(txn): class GroupServerStore(GroupServerWorkerStore): - def set_group_join_policy(self, group_id, join_policy): + async def set_group_join_policy(self, group_id: str, join_policy: str) -> None: """Set the join policy of a group. join_policy can be one of: * "invite" * "open" """ - return self.db_pool.simple_update_one( + await self.db_pool.simple_update_one( table="groups", keyvalues={"group_id": group_id}, updatevalues={"join_policy": join_policy}, @@ -1050,8 +1051,10 @@ def add_room_to_group(self, group_id, room_id, is_public): desc="add_room_to_group", ) - def update_room_in_group_visibility(self, group_id, room_id, is_public): - return self.db_pool.simple_update( + async def update_room_in_group_visibility( + self, group_id: str, room_id: str, is_public: bool + ) -> int: + return await self.db_pool.simple_update( table="group_rooms", keyvalues={"group_id": group_id, "room_id": room_id}, updatevalues={"is_public": is_public}, @@ -1076,10 +1079,12 @@ def _remove_room_from_group_txn(txn): "remove_room_from_group", _remove_room_from_group_txn ) - def update_group_publicity(self, group_id, user_id, publicise): + async def update_group_publicity( + self, group_id: str, user_id: str, publicise: bool + ) -> None: """Update whether the user is publicising their membership of the group """ - return self.db_pool.simple_update_one( + await self.db_pool.simple_update_one( table="local_group_membership", keyvalues={"group_id": group_id, "user_id": user_id}, updatevalues={"is_publicised": publicise}, @@ -1218,20 +1223,24 @@ async def update_group_profile(self, group_id, profile): desc="update_group_profile", ) - def update_attestation_renewal(self, group_id, user_id, attestation): + async def update_attestation_renewal( + self, group_id: str, user_id: str, attestation: dict + ) -> None: """Update an attestation that we have renewed """ - return self.db_pool.simple_update_one( + await self.db_pool.simple_update_one( table="group_attestations_renewals", keyvalues={"group_id": group_id, "user_id": user_id}, updatevalues={"valid_until_ms": attestation["valid_until_ms"]}, desc="update_attestation_renewal", ) - def update_remote_attestion(self, group_id, user_id, attestation): + async def update_remote_attestion( + self, group_id: str, user_id: str, attestation: dict + ) -> None: """Update an attestation that a remote has renewed """ - return self.db_pool.simple_update_one( + await self.db_pool.simple_update_one( table="group_attestations_remote", keyvalues={"group_id": group_id, "user_id": user_id}, updatevalues={ diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index 4ae255ebd8f5..fc223f5a2ab3 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool @@ -84,9 +84,9 @@ def store_local_media( desc="store_local_media", ) - def mark_local_media_as_safe(self, media_id: str): + async def mark_local_media_as_safe(self, media_id: str) -> None: """Mark a local media as safe from quarantining.""" - return self.db_pool.simple_update_one( + await self.db_pool.simple_update_one( table="local_media_repository", keyvalues={"media_id": media_id}, updatevalues={"safe_from_quarantine": True}, @@ -158,8 +158,8 @@ def store_url_cache( desc="store_url_cache", ) - def get_local_media_thumbnails(self, media_id): - return self.db_pool.simple_select_list( + async def get_local_media_thumbnails(self, media_id: str) -> List[Dict[str, Any]]: + return await self.db_pool.simple_select_list( "local_media_repository_thumbnails", {"media_id": media_id}, ( @@ -271,8 +271,10 @@ def update_cache_txn(txn): "update_cached_last_access_time", update_cache_txn ) - def get_remote_media_thumbnails(self, origin, media_id): - return self.db_pool.simple_select_list( + async def get_remote_media_thumbnails( + self, origin: str, media_id: str + ) -> List[Dict[str, Any]]: + return await self.db_pool.simple_select_list( "remote_media_cache_thumbnails", {"media_origin": origin, "media_id": media_id}, ( diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py index b8233c4848ae..858fd92420b5 100644 --- a/synapse/storage/databases/main/profile.py +++ b/synapse/storage/databases/main/profile.py @@ -71,16 +71,20 @@ def create_profile(self, user_localpart): table="profiles", values={"user_id": user_localpart}, desc="create_profile" ) - def set_profile_displayname(self, user_localpart, new_displayname): - return self.db_pool.simple_update_one( + async def set_profile_displayname( + self, user_localpart: str, new_displayname: str + ) -> None: + await self.db_pool.simple_update_one( table="profiles", keyvalues={"user_id": user_localpart}, updatevalues={"displayname": new_displayname}, desc="set_profile_displayname", ) - def set_profile_avatar_url(self, user_localpart, new_avatar_url): - return self.db_pool.simple_update_one( + async def set_profile_avatar_url( + self, user_localpart: str, new_avatar_url: str + ) -> None: + await self.db_pool.simple_update_one( table="profiles", keyvalues={"user_id": user_localpart}, updatevalues={"avatar_url": new_avatar_url}, @@ -106,8 +110,10 @@ def add_remote_profile_cache(self, user_id, displayname, avatar_url): desc="add_remote_profile_cache", ) - def update_remote_profile_cache(self, user_id, displayname, avatar_url): - return self.db_pool.simple_update( + async def update_remote_profile_cache( + self, user_id: str, displayname: str, avatar_url: str + ) -> int: + return await self.db_pool.simple_update( table="remote_profile_cache", keyvalues={"user_id": user_id}, updatevalues={ diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index cea5ac9a6862..436f22ad2dbd 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -16,7 +16,7 @@ import abc import logging -from typing import List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple from twisted.internet import defer @@ -62,8 +62,10 @@ async def get_users_with_read_receipts_in_room(self, room_id): return {r["user_id"] for r in receipts} @cached(num_args=2) - def get_receipts_for_room(self, room_id, receipt_type): - return self.db_pool.simple_select_list( + async def get_receipts_for_room( + self, room_id: str, receipt_type: str + ) -> List[Dict[str, Any]]: + return await self.db_pool.simple_select_list( table="receipts_linearized", keyvalues={"room_id": room_id, "receipt_type": receipt_type}, retcols=("user_id", "event_id"), diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index eced53d470ce..48bda66f3ec3 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -578,20 +578,20 @@ def add_user_bound_threepid(self, user_id, medium, address, id_server): desc="add_user_bound_threepid", ) - def user_get_bound_threepids(self, user_id): + async def user_get_bound_threepids(self, user_id: str) -> List[Dict[str, Any]]: """Get the threepids that a user has bound to an identity server through the homeserver The homeserver remembers where binds to an identity server occurred. Using this method can retrieve those threepids. Args: - user_id (str): The ID of the user to retrieve threepids for + user_id: The ID of the user to retrieve threepids for Returns: - Deferred[list[dict]]: List of dictionaries containing the following: + List of dictionaries containing the following keys: medium (str): The medium of the threepid (e.g "email") address (str): The address of the threepid (e.g "bob@example.com") """ - return self.db_pool.simple_select_list( + return await self.db_pool.simple_select_list( table="user_threepid_id_server", keyvalues={"user_id": user_id}, retcols=["medium", "address"], @@ -623,19 +623,21 @@ def remove_user_bound_threepid(self, user_id, medium, address, id_server): desc="remove_user_bound_threepid", ) - def get_id_servers_user_bound(self, user_id, medium, address): + async def get_id_servers_user_bound( + self, user_id: str, medium: str, address: str + ) -> List[str]: """Get the list of identity servers that the server proxied bind requests to for given user and threepid Args: - user_id (str) - medium (str) - address (str) + user_id: The user to query for identity servers. + medium: The medium to query for identity servers. + address: The address to query for identity servers. Returns: - Deferred[list[str]]: Resolves to a list of identity servers + A list of identity servers """ - return self.db_pool.simple_select_onecol( + return await self.db_pool.simple_select_onecol( table="user_threepid_id_server", keyvalues={"user_id": user_id, "medium": medium, "address": address}, retcol="id_server", diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 97ecdb16e4ec..66d713541389 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -125,8 +125,8 @@ def get_room_with_stats_txn(txn, room_id): "get_room_with_stats", get_room_with_stats_txn, room_id ) - def get_public_room_ids(self): - return self.db_pool.simple_select_onecol( + async def get_public_room_ids(self) -> List[str]: + return await self.db_pool.simple_select_onecol( table="rooms", keyvalues={"is_public": True}, retcol="room_id", diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index 20cbcd851c04..a9f2e9361482 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -537,8 +537,8 @@ async def get_user_in_directory(self, user_id: str) -> Optional[Dict[str, Any]]: desc="get_user_in_directory", ) - def update_user_directory_stream_pos(self, stream_id): - return self.db_pool.simple_update_one( + async def update_user_directory_stream_pos(self, stream_id: str) -> None: + await self.db_pool.simple_update_one( table="user_directory_stream_pos", keyvalues={}, updatevalues={"stream_id": stream_id}, diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py index 88b05c23a051..a609f148c080 100644 --- a/tests/handlers/test_stats.py +++ b/tests/handlers/test_stats.py @@ -81,8 +81,8 @@ def _add_background_updates(self): ) ) - def get_all_room_state(self): - return self.store.db_pool.simple_select_list( + async def get_all_room_state(self): + return await self.store.db_pool.simple_select_list( "room_stats_state", None, retcols=("name", "topic", "canonical_alias") ) diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index bf22540d9905..64abe8cc49d1 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -148,8 +148,10 @@ def test_select_list(self): self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)])) self.mock_txn.description = (("colA", None, None, None, None, None, None),) - ret = yield self.datastore.db_pool.simple_select_list( - table="tablename", keyvalues={"keycol": "A set"}, retcols=["colA"] + ret = yield defer.ensureDeferred( + self.datastore.db_pool.simple_select_list( + table="tablename", keyvalues={"keycol": "A set"}, retcols=["colA"] + ) ) self.assertEquals([{"colA": 1}, {"colA": 2}, {"colA": 3}], ret) @@ -161,10 +163,12 @@ def test_select_list(self): def test_update_one_1col(self): self.mock_txn.rowcount = 1 - yield self.datastore.db_pool.simple_update_one( - table="tablename", - keyvalues={"keycol": "TheKey"}, - updatevalues={"columnname": "New Value"}, + yield defer.ensureDeferred( + self.datastore.db_pool.simple_update_one( + table="tablename", + keyvalues={"keycol": "TheKey"}, + updatevalues={"columnname": "New Value"}, + ) ) self.mock_txn.execute.assert_called_with( @@ -176,10 +180,12 @@ def test_update_one_1col(self): def test_update_one_4cols(self): self.mock_txn.rowcount = 1 - yield self.datastore.db_pool.simple_update_one( - table="tablename", - keyvalues=OrderedDict([("colA", 1), ("colB", 2)]), - updatevalues=OrderedDict([("colC", 3), ("colD", 4)]), + yield defer.ensureDeferred( + self.datastore.db_pool.simple_update_one( + table="tablename", + keyvalues=OrderedDict([("colA", 1), ("colB", 2)]), + updatevalues=OrderedDict([("colC", 3), ("colD", 4)]), + ) ) self.mock_txn.execute.assert_called_with( diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py index daac947cb2b8..da93ca398039 100644 --- a/tests/storage/test_directory.py +++ b/tests/storage/test_directory.py @@ -42,7 +42,11 @@ def test_room_to_alias(self): self.assertEquals( ["#my-room:test"], - (yield self.store.get_aliases_for_room(self.room.to_string())), + ( + yield defer.ensureDeferred( + self.store.get_aliases_for_room(self.room.to_string()) + ) + ), ) @defer.inlineCallbacks diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py index fbf8af940acc..954338a59276 100644 --- a/tests/storage/test_main.py +++ b/tests/storage/test_main.py @@ -36,7 +36,9 @@ def setUp(self): def test_get_users_paginate(self): yield self.store.register_user(self.user.to_string(), "pass") yield defer.ensureDeferred(self.store.create_profile(self.user.localpart)) - yield self.store.set_profile_displayname(self.user.localpart, self.displayname) + yield defer.ensureDeferred( + self.store.set_profile_displayname(self.user.localpart, self.displayname) + ) users, total = yield self.store.get_users_paginate( 0, 10, name="bc", guests=False diff --git a/tests/test_federation.py b/tests/test_federation.py index 4a4548433f97..27a7fc9ed71a 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -15,8 +15,9 @@ from mock import Mock -from twisted.internet.defer import ensureDeferred, maybeDeferred, succeed +from twisted.internet.defer import succeed +from synapse.api.errors import FederationError from synapse.events import make_event_from_dict from synapse.logging.context import LoggingContext from synapse.types import Requester, UserID @@ -44,22 +45,17 @@ def setUp(self): user_id = UserID("us", "test") our_user = Requester(user_id, None, False, False, None, None) room_creator = self.homeserver.get_room_creation_handler() - room_deferred = ensureDeferred( + self.room_id = self.get_success( room_creator.create_room( our_user, room_creator._presets_dict["public_chat"], ratelimit=False ) - ) - self.reactor.advance(0.1) - self.room_id = self.successResultOf(room_deferred)[0]["room_id"] + )[0]["room_id"] self.store = self.homeserver.get_datastore() # Figure out what the most recent event is - most_recent = self.successResultOf( - maybeDeferred( - self.homeserver.get_datastore().get_latest_event_ids_in_room, - self.room_id, - ) + most_recent = self.get_success( + self.homeserver.get_datastore().get_latest_event_ids_in_room(self.room_id) )[0] join_event = make_event_from_dict( @@ -89,19 +85,18 @@ def setUp(self): ) # Send the join, it should return None (which is not an error) - d = ensureDeferred( - self.handler.on_receive_pdu( - "test.serv", join_event, sent_to_us_directly=True - ) + self.assertEqual( + self.get_success( + self.handler.on_receive_pdu( + "test.serv", join_event, sent_to_us_directly=True + ) + ), + None, ) - self.reactor.advance(1) - self.assertEqual(self.successResultOf(d), None) # Make sure we actually joined the room self.assertEqual( - self.successResultOf( - maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id) - )[0], + self.get_success(self.store.get_latest_event_ids_in_room(self.room_id))[0], "$join:test.serv", ) @@ -119,8 +114,8 @@ async def post_json(destination, path, data, headers=None, timeout=0): self.http_client.post_json = post_json # Figure out what the most recent event is - most_recent = self.successResultOf( - maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id) + most_recent = self.get_success( + self.store.get_latest_event_ids_in_room(self.room_id) )[0] # Now lie about an event @@ -140,17 +135,14 @@ async def post_json(destination, path, data, headers=None, timeout=0): ) with LoggingContext(request="lying_event"): - d = ensureDeferred( + failure = self.get_failure( self.handler.on_receive_pdu( "test.serv", lying_event, sent_to_us_directly=True - ) + ), + FederationError, ) - # Step the reactor, so the database fetches come back - self.reactor.advance(1) - # on_receive_pdu should throw an error - failure = self.failureResultOf(d) self.assertEqual( failure.value.args[0], ( @@ -160,8 +152,8 @@ async def post_json(destination, path, data, headers=None, timeout=0): ) # Make sure the invalid event isn't there - extrem = maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id) - self.assertEqual(self.successResultOf(extrem)[0], "$join:test.serv") + extrem = self.get_success(self.store.get_latest_event_ids_in_room(self.room_id)) + self.assertEqual(extrem[0], "$join:test.serv") def test_retry_device_list_resync(self): """Tests that device lists are marked as stale if they couldn't be synced, and From 30426c7063a7e5567ac21cd10267651ef1935360 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 27 Aug 2020 07:41:01 -0400 Subject: [PATCH 10/39] Convert additional database methods to async (select list, search, insert_many, delete_*) (#8168) --- changelog.d/8168.misc | 1 + synapse/storage/background_updates.py | 9 +- synapse/storage/database.py | 99 +++++++------------ synapse/storage/databases/main/__init__.py | 12 +-- .../storage/databases/main/end_to_end_keys.py | 15 ++- .../databases/main/media_repository.py | 4 +- tests/storage/test_background_update.py | 5 +- tests/storage/test_base.py | 6 +- 8 files changed, 67 insertions(+), 84 deletions(-) create mode 100644 changelog.d/8168.misc diff --git a/changelog.d/8168.misc b/changelog.d/8168.misc new file mode 100644 index 000000000000..dfe4c03171d6 --- /dev/null +++ b/changelog.d/8168.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index 56818f4df883..0db900fa0e0b 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -414,13 +414,14 @@ async def updater(progress, batch_size): self.register_background_update_handler(update_name, updater) - def _end_background_update(self, update_name): + async def _end_background_update(self, update_name: str) -> None: """Removes a completed background update task from the queue. Args: - update_name(str): The name of the completed task to remove + update_name:: The name of the completed task to remove + Returns: - A deferred that completes once the task is removed. + None, completes once the task is removed. """ if update_name != self._current_background_update: raise Exception( @@ -428,7 +429,7 @@ def _end_background_update(self, update_name): % update_name ) self._current_background_update = None - return self.db_pool.simple_delete_one( + await self.db_pool.simple_delete_one( "background_updates", keyvalues={"update_name": update_name} ) diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 38010af60041..2f6f49a4bf84 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -605,7 +605,13 @@ def cursor_to_dict(cursor: Cursor) -> List[Dict[str, Any]]: results = [dict(zip(col_headers, row)) for row in cursor] return results - def execute(self, desc: str, decoder: Callable, query: str, *args: Any): + async def execute( + self, + desc: str, + decoder: Optional[Callable[[Cursor], R]], + query: str, + *args: Any + ) -> R: """Runs a single query for a result set. Args: @@ -614,7 +620,7 @@ def execute(self, desc: str, decoder: Callable, query: str, *args: Any): query - The query string to execute *args - Query args. Returns: - Deferred which results to the result of decoder(results) + The result of decoder(results) """ def interaction(txn): @@ -624,7 +630,7 @@ def interaction(txn): else: return txn.fetchall() - return self.runInteraction(desc, interaction) + return await self.runInteraction(desc, interaction) # "Simple" SQL API methods that operate on a single table with no JOINs, # no complex WHERE clauses, just a dict of values for columns. @@ -673,15 +679,30 @@ def simple_insert_txn( txn.execute(sql, vals) - def simple_insert_many( + async def simple_insert_many( self, table: str, values: List[Dict[str, Any]], desc: str - ) -> defer.Deferred: - return self.runInteraction(desc, self.simple_insert_many_txn, table, values) + ) -> None: + """Executes an INSERT query on the named table. + + Args: + table: string giving the table name + values: dict of new column names and values for them + desc: string giving a description of the transaction + """ + await self.runInteraction(desc, self.simple_insert_many_txn, table, values) @staticmethod def simple_insert_many_txn( txn: LoggingTransaction, table: str, values: List[Dict[str, Any]] ) -> None: + """Executes an INSERT query on the named table. + + Args: + txn: The transaction to use. + table: string giving the table name + values: dict of new column names and values for them + desc: string giving a description of the transaction + """ if not values: return @@ -1397,9 +1418,9 @@ def simple_select_one_txn( return dict(zip(retcols, row)) - def simple_delete_one( + async def simple_delete_one( self, table: str, keyvalues: Dict[str, Any], desc: str = "simple_delete_one" - ) -> defer.Deferred: + ) -> None: """Executes a DELETE query on the named table, expecting to delete a single row. @@ -1407,7 +1428,7 @@ def simple_delete_one( table: string giving the table name keyvalues: dict of column names and values to select the row with """ - return self.runInteraction(desc, self.simple_delete_one_txn, table, keyvalues) + await self.runInteraction(desc, self.simple_delete_one_txn, table, keyvalues) @staticmethod def simple_delete_one_txn( @@ -1446,15 +1467,15 @@ def simple_delete_txn( txn.execute(sql, list(keyvalues.values())) return txn.rowcount - def simple_delete_many( + async def simple_delete_many( self, table: str, column: str, iterable: Iterable[Any], keyvalues: Dict[str, Any], desc: str, - ) -> defer.Deferred: - return self.runInteraction( + ) -> int: + return await self.runInteraction( desc, self.simple_delete_many_txn, table, column, iterable, keyvalues ) @@ -1537,52 +1558,6 @@ def get_cache_dict( return cache, min_val - def simple_select_list_paginate( - self, - table: str, - orderby: str, - start: int, - limit: int, - retcols: Iterable[str], - filters: Optional[Dict[str, Any]] = None, - keyvalues: Optional[Dict[str, Any]] = None, - order_direction: str = "ASC", - desc: str = "simple_select_list_paginate", - ) -> defer.Deferred: - """ - Executes a SELECT query on the named table with start and limit, - of row numbers, which may return zero or number of rows from start to limit, - returning the result as a list of dicts. - - Args: - table: the table name - orderby: Column to order the results by. - start: Index to begin the query at. - limit: Number of results to return. - retcols: the names of the columns to return - filters: - column names and values to filter the rows with, or None to not - apply a WHERE ? LIKE ? clause. - keyvalues: - column names and values to select the rows with, or None to not - apply a WHERE clause. - order_direction: Whether the results should be ordered "ASC" or "DESC". - Returns: - defer.Deferred: resolves to list[dict[str, Any]] - """ - return self.runInteraction( - desc, - self.simple_select_list_paginate_txn, - table, - orderby, - start, - limit, - retcols, - filters=filters, - keyvalues=keyvalues, - order_direction=order_direction, - ) - @classmethod def simple_select_list_paginate_txn( cls, @@ -1647,14 +1622,14 @@ def simple_select_list_paginate_txn( return cls.cursor_to_dict(txn) - def simple_search_list( + async def simple_search_list( self, table: str, term: Optional[str], col: str, retcols: Iterable[str], desc="simple_search_list", - ): + ) -> Optional[List[Dict[str, Any]]]: """Executes a SELECT query on the named table, which may return zero or more rows, returning the result as a list of dicts. @@ -1665,10 +1640,10 @@ def simple_search_list( retcols: the names of the columns to return Returns: - defer.Deferred: resolves to list[dict[str, Any]] or None + A list of dictionaries or None. """ - return self.runInteraction( + return await self.runInteraction( desc, self.simple_search_list_txn, table, term, col, retcols ) diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 8b9b6eb47241..70cf15dd7f46 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -18,7 +18,7 @@ import calendar import logging import time -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional from synapse.api.constants import PresenceState from synapse.config.homeserver import HomeServerConfig @@ -559,17 +559,17 @@ def get_users_paginate_txn(txn): "get_users_paginate_txn", get_users_paginate_txn ) - def search_users(self, term): + async def search_users(self, term: str) -> Optional[List[Dict[str, Any]]]: """Function to search users list for one or more users with the matched term. Args: - term (str): search term - col (str): column to query term should be matched to + term: search term + Returns: - defer.Deferred: resolves to list[dict[str, Any]] + A list of dictionaries or None. """ - return self.db_pool.simple_search_list( + return await self.db_pool.simple_search_list( table="users", term=term, col="name", diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 385868bdab3f..af0b85e2c92c 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -14,7 +14,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Iterable, List, Optional, Tuple +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple from canonicaljson import encode_canonical_json @@ -27,6 +27,9 @@ from synapse.util.caches.descriptors import cached, cachedList from synapse.util.iterutils import batch_iter +if TYPE_CHECKING: + from synapse.handlers.e2e_keys import SignatureListItem + class EndToEndKeyWorkerStore(SQLBaseStore): @trace @@ -730,14 +733,16 @@ async def set_e2e_cross_signing_key(self, user_id, key_type, key): stream_id, ) - def store_e2e_cross_signing_signatures(self, user_id, signatures): + async def store_e2e_cross_signing_signatures( + self, user_id: str, signatures: "Iterable[SignatureListItem]" + ) -> None: """Stores cross-signing signatures. Args: - user_id (str): the user who made the signatures - signatures (iterable[SignatureListItem]): signatures to add + user_id: the user who made the signatures + signatures: signatures to add """ - return self.db_pool.simple_insert_many( + await self.db_pool.simple_insert_many( "e2e_cross_signing_signatures", [ { diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index fc223f5a2ab3..8361dd63d964 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -314,14 +314,14 @@ def store_remote_media_thumbnail( desc="store_remote_media_thumbnail", ) - def get_remote_media_before(self, before_ts): + async def get_remote_media_before(self, before_ts): sql = ( "SELECT media_origin, media_id, filesystem_id" " FROM remote_media_cache" " WHERE last_access_ts < ?" ) - return self.db_pool.execute( + return await self.db_pool.execute( "get_remote_media_before", self.db_pool.cursor_to_dict, sql, before_ts ) diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index 2efbc97c2e62..1a1c59256cfc 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -67,13 +67,12 @@ def update(progress, count): # second step: complete the update # we should now get run with a much bigger number of items to update - @defer.inlineCallbacks - def update(progress, count): + async def update(progress, count): self.assertEqual(progress, {"my_key": 2}) self.assertAlmostEqual( count, target_background_update_duration_ms / duration_ms, places=0, ) - yield self.updates._end_background_update("test_update") + await self.updates._end_background_update("test_update") return count self.update_handler.side_effect = update diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index 64abe8cc49d1..40ba652248ce 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -197,8 +197,10 @@ def test_update_one_4cols(self): def test_delete_one(self): self.mock_txn.rowcount = 1 - yield self.datastore.db_pool.simple_delete_one( - table="tablename", keyvalues={"keycol": "Go away"} + yield defer.ensureDeferred( + self.datastore.db_pool.simple_delete_one( + table="tablename", keyvalues={"keycol": "Go away"} + ) ) self.mock_txn.execute.assert_called_with( From 5649b7f3d05f8e550fba8349433eefd09aacce27 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 27 Aug 2020 13:20:34 +0100 Subject: [PATCH 11/39] Fix missing _add_persisted_position (#8179) This was forgotten in #8164. --- changelog.d/8179.misc | 1 + synapse/storage/util/id_generators.py | 2 ++ tests/storage/test_id_generators.py | 52 +++++++++++++++++++++++++-- 3 files changed, 52 insertions(+), 3 deletions(-) create mode 100644 changelog.d/8179.misc diff --git a/changelog.d/8179.misc b/changelog.d/8179.misc new file mode 100644 index 000000000000..55bc079cdba8 --- /dev/null +++ b/changelog.d/8179.misc @@ -0,0 +1 @@ +Add functions to `MultiWriterIdGen` used by events stream. diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 5b0784777392..b27a4843d010 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -343,6 +343,8 @@ def _mark_id_as_finished(self, next_id: int): curr = self._current_positions.get(self._instance_name, 0) self._current_positions[self._instance_name] = max(curr, next_id) + self._add_persisted_position(next_id) + def get_current_token(self) -> int: """Returns the maximum stream id such that all stream ids less than or equal to it have been successfully persisted. diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py index 9b9a183e7f2b..14ce21c7869d 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py @@ -58,6 +58,10 @@ def _create(conn): return self.get_success(self.db_pool.runWithConnection(_create)) def _insert_rows(self, instance_name: str, number: int): + """Insert N rows as the given instance, inserting with stream IDs pulled + from the postgres sequence. + """ + def _insert(txn): for _ in range(number): txn.execute( @@ -65,7 +69,20 @@ def _insert(txn): (instance_name,), ) - self.get_success(self.db_pool.runInteraction("test_single_instance", _insert)) + self.get_success(self.db_pool.runInteraction("_insert_rows", _insert)) + + def _insert_row_with_id(self, instance_name: str, stream_id: int): + """Insert one row as the given instance with given stream_id, updating + the postgres sequence position to match. + """ + + def _insert(txn): + txn.execute( + "INSERT INTO foobar VALUES (?, ?)", (stream_id, instance_name,), + ) + txn.execute("SELECT setval('foobar_seq', ?)", (stream_id,)) + + self.get_success(self.db_pool.runInteraction("_insert_row_with_id", _insert)) def test_empty(self): """Test an ID generator against an empty database gives sensible @@ -188,11 +205,17 @@ def test_get_persisted_upto_position(self): positions. """ - self._insert_rows("first", 3) - self._insert_rows("second", 5) + # The following tests are a bit cheeky in that we notify about new + # positions via `advance` without *actually* advancing the postgres + # sequence. + + self._insert_row_with_id("first", 3) + self._insert_row_with_id("second", 5) id_gen = self._create_id_generator("first") + self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5}) + # Min is 3 and there is a gap between 5, so we expect it to be 3. self.assertEqual(id_gen.get_persisted_upto_position(), 3) @@ -218,3 +241,26 @@ def test_get_persisted_upto_position(self): id_gen.advance("first", 11) id_gen.advance("second", 15) self.assertEqual(id_gen.get_persisted_upto_position(), 11) + + def test_get_persisted_upto_position_get_next(self): + """Test that `get_persisted_upto_position` correctly tracks updates to + positions when `get_next` is called. + """ + + self._insert_row_with_id("first", 3) + self._insert_row_with_id("second", 5) + + id_gen = self._create_id_generator("first") + + self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5}) + + self.assertEqual(id_gen.get_persisted_upto_position(), 3) + with self.get_success(id_gen.get_next()) as stream_id: + self.assertEqual(stream_id, 6) + self.assertEqual(id_gen.get_persisted_upto_position(), 3) + + self.assertEqual(id_gen.get_persisted_upto_position(), 6) + + # We assume that so long as `get_next` does correctly advance the + # `persisted_upto_position` in this case, then it will be correct in the + # other cases that are tested above (since they'll hit the same code). From c9fa696ea2ad5bc32430aeb1bc555df537a71a59 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 27 Aug 2020 12:07:13 -0400 Subject: [PATCH 12/39] simple_search_list_txn should return None, not 0. (#8187) --- changelog.d/8187.misc | 1 + synapse/storage/database.py | 7 +++---- 2 files changed, 4 insertions(+), 4 deletions(-) create mode 100644 changelog.d/8187.misc diff --git a/changelog.d/8187.misc b/changelog.d/8187.misc new file mode 100644 index 000000000000..cb557122aaa1 --- /dev/null +++ b/changelog.d/8187.misc @@ -0,0 +1 @@ +Add type hints to `synapse.storage.database`. diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 2f6f49a4bf84..ba4c0c9af6d1 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -28,7 +28,6 @@ Optional, Tuple, TypeVar, - Union, overload, ) @@ -1655,7 +1654,7 @@ def simple_search_list_txn( term: Optional[str], col: str, retcols: Iterable[str], - ) -> Union[List[Dict[str, Any]], int]: + ) -> Optional[List[Dict[str, Any]]]: """Executes a SELECT query on the named table, which may return zero or more rows, returning the result as a list of dicts. @@ -1667,14 +1666,14 @@ def simple_search_list_txn( retcols: the names of the columns to return Returns: - 0 if no term is given, otherwise a list of dictionaries. + None if no term is given, otherwise a list of dictionaries. """ if term: sql = "SELECT %s FROM %s WHERE %s LIKE ?" % (", ".join(retcols), table, col) termvalues = ["%%" + term + "%%"] txn.execute(sql, termvalues) else: - return 0 + return None return cls.cursor_to_dict(txn) From 9b7ac03af3e7ceae7d1933db566ee407cfdef72d Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 27 Aug 2020 13:38:41 -0400 Subject: [PATCH 13/39] Convert calls of async database methods to async (#8166) --- changelog.d/8166.misc | 1 + synapse/federation/persistence.py | 16 +++++----- synapse/federation/units.py | 4 +-- synapse/storage/databases/main/appservice.py | 6 ++-- synapse/storage/databases/main/devices.py | 4 +-- .../storage/databases/main/group_server.py | 30 ++++++++++++++----- synapse/storage/databases/main/keys.py | 26 +++++++++------- .../databases/main/media_repository.py | 22 +++++++------- synapse/storage/databases/main/openid.py | 6 ++-- synapse/storage/databases/main/profile.py | 10 ++++--- .../storage/databases/main/registration.py | 29 +++++++++--------- synapse/storage/databases/main/room.py | 16 ++++++---- synapse/storage/databases/main/stats.py | 10 +++---- .../storage/databases/main/transactions.py | 18 ++++++----- 14 files changed, 114 insertions(+), 84 deletions(-) create mode 100644 changelog.d/8166.misc diff --git a/changelog.d/8166.misc b/changelog.d/8166.misc new file mode 100644 index 000000000000..dfe4c03171d6 --- /dev/null +++ b/changelog.d/8166.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py index d68b4bd670f6..769cd5de28ab 100644 --- a/synapse/federation/persistence.py +++ b/synapse/federation/persistence.py @@ -21,7 +21,9 @@ import logging +from synapse.federation.units import Transaction from synapse.logging.utils import log_function +from synapse.types import JsonDict logger = logging.getLogger(__name__) @@ -49,15 +51,15 @@ def have_responded(self, origin, transaction): return self.store.get_received_txn_response(transaction.transaction_id, origin) @log_function - def set_response(self, origin, transaction, code, response): + async def set_response( + self, origin: str, transaction: Transaction, code: int, response: JsonDict + ) -> None: """ Persist how we responded to a transaction. - - Returns: - Deferred """ - if not transaction.transaction_id: + transaction_id = transaction.transaction_id # type: ignore + if not transaction_id: raise RuntimeError("Cannot persist a transaction with no transaction_id") - return self.store.set_received_txn_response( - transaction.transaction_id, origin, code, response + await self.store.set_received_txn_response( + transaction_id, origin, code, response ) diff --git a/synapse/federation/units.py b/synapse/federation/units.py index 6b32e0dcbfb8..64d98fc8f675 100644 --- a/synapse/federation/units.py +++ b/synapse/federation/units.py @@ -107,9 +107,7 @@ def __init__(self, transaction_id=None, pdus=[], **kwargs): if "edus" in kwargs and not kwargs["edus"]: del kwargs["edus"] - super(Transaction, self).__init__( - transaction_id=transaction_id, pdus=pdus, **kwargs - ) + super().__init__(transaction_id=transaction_id, pdus=pdus, **kwargs) @staticmethod def create_new(pdus, **kwargs): diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index 77723f7d4dc7..92f56f1602a2 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -161,16 +161,14 @@ async def get_appservice_state(self, service): return result.get("state") return None - def set_appservice_state(self, service, state): + async def set_appservice_state(self, service, state) -> None: """Set the application service state. Args: service(ApplicationService): The service whose state to set. state(ApplicationServiceState): The connectivity state to apply. - Returns: - An Awaitable which resolves when the state was set successfully. """ - return self.db_pool.simple_upsert( + await self.db_pool.simple_upsert( "application_services_state", {"as_id": service.id}, {"state": state} ) diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index a811a39eb524..ecd3f3b3108b 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -716,11 +716,11 @@ async def get_user_ids_requiring_device_list_resync( return {row["user_id"] for row in rows} - def mark_remote_user_device_cache_as_stale(self, user_id: str): + async def mark_remote_user_device_cache_as_stale(self, user_id: str) -> None: """Records that the server has reason to believe the cache of the devices for the remote users is out of date. """ - return self.db_pool.simple_upsert( + await self.db_pool.simple_upsert( table="device_lists_remote_resync", keyvalues={"user_id": user_id}, values={}, diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py index e3ead71853ea..8acf254bf351 100644 --- a/synapse/storage/databases/main/group_server.py +++ b/synapse/storage/databases/main/group_server.py @@ -742,7 +742,13 @@ def remove_room_from_summary(self, group_id, room_id, category_id): desc="remove_room_from_summary", ) - def upsert_group_category(self, group_id, category_id, profile, is_public): + async def upsert_group_category( + self, + group_id: str, + category_id: str, + profile: Optional[JsonDict], + is_public: Optional[bool], + ) -> None: """Add/update room category for group """ insertion_values = {} @@ -758,7 +764,7 @@ def upsert_group_category(self, group_id, category_id, profile, is_public): else: update_values["is_public"] = is_public - return self.db_pool.simple_upsert( + await self.db_pool.simple_upsert( table="group_room_categories", keyvalues={"group_id": group_id, "category_id": category_id}, values=update_values, @@ -773,7 +779,13 @@ def remove_group_category(self, group_id, category_id): desc="remove_group_category", ) - def upsert_group_role(self, group_id, role_id, profile, is_public): + async def upsert_group_role( + self, + group_id: str, + role_id: str, + profile: Optional[JsonDict], + is_public: Optional[bool], + ) -> None: """Add/remove user role """ insertion_values = {} @@ -789,7 +801,7 @@ def upsert_group_role(self, group_id, role_id, profile, is_public): else: update_values["is_public"] = is_public - return self.db_pool.simple_upsert( + await self.db_pool.simple_upsert( table="group_roles", keyvalues={"group_id": group_id, "role_id": role_id}, values=update_values, @@ -938,10 +950,10 @@ def remove_user_from_summary(self, group_id, user_id, role_id): desc="remove_user_from_summary", ) - def add_group_invite(self, group_id, user_id): + async def add_group_invite(self, group_id: str, user_id: str) -> None: """Record that the group server has invited a user """ - return self.db_pool.simple_insert( + await self.db_pool.simple_insert( table="group_invites", values={"group_id": group_id, "user_id": user_id}, desc="add_group_invite", @@ -1044,8 +1056,10 @@ def _remove_user_from_group_txn(txn): "remove_user_from_group", _remove_user_from_group_txn ) - def add_room_to_group(self, group_id, room_id, is_public): - return self.db_pool.simple_insert( + async def add_room_to_group( + self, group_id: str, room_id: str, is_public: bool + ) -> None: + await self.db_pool.simple_insert( table="group_rooms", values={"group_id": group_id, "room_id": room_id, "is_public": is_public}, desc="add_room_to_group", diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py index fadcad51e7a1..1c0a049c5548 100644 --- a/synapse/storage/databases/main/keys.py +++ b/synapse/storage/databases/main/keys.py @@ -140,22 +140,28 @@ async def store_server_verify_keys( for i in invalidations: invalidate((i,)) - def store_server_keys_json( - self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes - ): + async def store_server_keys_json( + self, + server_name: str, + key_id: str, + from_server: str, + ts_now_ms: int, + ts_expires_ms: int, + key_json_bytes: bytes, + ) -> None: """Stores the JSON bytes for a set of keys from a server The JSON should be signed by the originating server, the intermediate server, and by this server. Updates the value for the (server_name, key_id, from_server) triplet if one already existed. Args: - server_name (str): The name of the server. - key_id (str): The identifer of the key this JSON is for. - from_server (str): The server this JSON was fetched from. - ts_now_ms (int): The time now in milliseconds. - ts_valid_until_ms (int): The time when this json stops being valid. - key_json (bytes): The encoded JSON. + server_name: The name of the server. + key_id: The identifer of the key this JSON is for. + from_server: The server this JSON was fetched from. + ts_now_ms: The time now in milliseconds. + ts_valid_until_ms: The time when this json stops being valid. + key_json_bytes: The encoded JSON. """ - return self.db_pool.simple_upsert( + await self.db_pool.simple_upsert( table="server_keys_json", keyvalues={ "server_name": server_name, diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index 8361dd63d964..3919ecad69a8 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -60,7 +60,7 @@ async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]: desc="get_local_media", ) - def store_local_media( + async def store_local_media( self, media_id, media_type, @@ -69,8 +69,8 @@ def store_local_media( media_length, user_id, url_cache=None, - ): - return self.db_pool.simple_insert( + ) -> None: + await self.db_pool.simple_insert( "local_media_repository", { "media_id": media_id, @@ -141,10 +141,10 @@ def get_url_cache_txn(txn): return self.db_pool.runInteraction("get_url_cache", get_url_cache_txn) - def store_url_cache( + async def store_url_cache( self, url, response_code, etag, expires_ts, og, media_id, download_ts ): - return self.db_pool.simple_insert( + await self.db_pool.simple_insert( "local_media_repository_url_cache", { "url": url, @@ -172,7 +172,7 @@ async def get_local_media_thumbnails(self, media_id: str) -> List[Dict[str, Any] desc="get_local_media_thumbnails", ) - def store_local_thumbnail( + async def store_local_thumbnail( self, media_id, thumbnail_width, @@ -181,7 +181,7 @@ def store_local_thumbnail( thumbnail_method, thumbnail_length, ): - return self.db_pool.simple_insert( + await self.db_pool.simple_insert( "local_media_repository_thumbnails", { "media_id": media_id, @@ -212,7 +212,7 @@ async def get_cached_remote_media( desc="get_cached_remote_media", ) - def store_cached_remote_media( + async def store_cached_remote_media( self, origin, media_id, @@ -222,7 +222,7 @@ def store_cached_remote_media( upload_name, filesystem_id, ): - return self.db_pool.simple_insert( + await self.db_pool.simple_insert( "remote_media_cache", { "media_origin": origin, @@ -288,7 +288,7 @@ async def get_remote_media_thumbnails( desc="get_remote_media_thumbnails", ) - def store_remote_media_thumbnail( + async def store_remote_media_thumbnail( self, origin, media_id, @@ -299,7 +299,7 @@ def store_remote_media_thumbnail( thumbnail_method, thumbnail_length, ): - return self.db_pool.simple_insert( + await self.db_pool.simple_insert( "remote_media_cache_thumbnails", { "media_origin": origin, diff --git a/synapse/storage/databases/main/openid.py b/synapse/storage/databases/main/openid.py index dcd1ff911a20..4db8949da767 100644 --- a/synapse/storage/databases/main/openid.py +++ b/synapse/storage/databases/main/openid.py @@ -2,8 +2,10 @@ class OpenIdStore(SQLBaseStore): - def insert_open_id_token(self, token, ts_valid_until_ms, user_id): - return self.db_pool.simple_insert( + async def insert_open_id_token( + self, token: str, ts_valid_until_ms: int, user_id: str + ) -> None: + await self.db_pool.simple_insert( table="open_id_tokens", values={ "token": token, diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py index 858fd92420b5..301875a6725a 100644 --- a/synapse/storage/databases/main/profile.py +++ b/synapse/storage/databases/main/profile.py @@ -66,8 +66,8 @@ async def get_from_remote_profile_cache( desc="get_from_remote_profile_cache", ) - def create_profile(self, user_localpart): - return self.db_pool.simple_insert( + async def create_profile(self, user_localpart: str) -> None: + await self.db_pool.simple_insert( table="profiles", values={"user_id": user_localpart}, desc="create_profile" ) @@ -93,13 +93,15 @@ async def set_profile_avatar_url( class ProfileStore(ProfileWorkerStore): - def add_remote_profile_cache(self, user_id, displayname, avatar_url): + async def add_remote_profile_cache( + self, user_id: str, displayname: str, avatar_url: str + ) -> None: """Ensure we are caching the remote user's profiles. This should only be called when `is_subscribed_remote_profile_for_user` would return true for the user. """ - return self.db_pool.simple_upsert( + await self.db_pool.simple_upsert( table="remote_profile_cache", keyvalues={"user_id": user_id}, values={ diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 48bda66f3ec3..28f7ae043092 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -17,7 +17,7 @@ import logging import re -from typing import Any, Awaitable, Dict, List, Optional +from typing import Any, Dict, List, Optional from synapse.api.constants import UserTypes from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError @@ -549,23 +549,22 @@ def user_delete_threepids(self, user_id: str): desc="user_delete_threepids", ) - def add_user_bound_threepid(self, user_id, medium, address, id_server): + async def add_user_bound_threepid( + self, user_id: str, medium: str, address: str, id_server: str + ): """The server proxied a bind request to the given identity server on behalf of the given user. We need to remember this in case the user asks us to unbind the threepid. Args: - user_id (str) - medium (str) - address (str) - id_server (str) - - Returns: - Awaitable + user_id + medium + address + id_server """ # We need to use an upsert, in case they user had already bound the # threepid - return self.db_pool.simple_upsert( + await self.db_pool.simple_upsert( table="user_threepid_id_server", keyvalues={ "user_id": user_id, @@ -1083,9 +1082,9 @@ def _register_user( self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) - def record_user_external_id( + async def record_user_external_id( self, auth_provider: str, external_id: str, user_id: str - ) -> Awaitable: + ) -> None: """Record a mapping from an external user id to a mxid Args: @@ -1093,7 +1092,7 @@ def record_user_external_id( external_id: id on that system user_id: complete mxid that it is mapped to """ - return self.db_pool.simple_insert( + await self.db_pool.simple_insert( table="user_external_ids", values={ "auth_provider": auth_provider, @@ -1237,12 +1236,12 @@ async def is_guest(self, user_id: str) -> bool: return res if res else False - def add_user_pending_deactivation(self, user_id): + async def add_user_pending_deactivation(self, user_id: str) -> None: """ Adds a user to the table of users who need to be parted from all the rooms they're in """ - return self.db_pool.simple_insert( + await self.db_pool.simple_insert( "users_pending_deactivation", values={"user_id": user_id}, desc="add_user_pending_deactivation", diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 66d713541389..a92641c33927 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -27,7 +27,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.databases.main.search import SearchStore -from synapse.types import ThirdPartyInstanceID +from synapse.types import JsonDict, ThirdPartyInstanceID from synapse.util import json_encoder from synapse.util.caches.descriptors import cached @@ -1296,11 +1296,17 @@ def f(txn): return self.db_pool.runInteraction("get_rooms", f) - def add_event_report( - self, room_id, event_id, user_id, reason, content, received_ts - ): + async def add_event_report( + self, + room_id: str, + event_id: str, + user_id: str, + reason: str, + content: JsonDict, + received_ts: int, + ) -> None: next_id = self._event_reports_id_gen.get_next() - return self.db_pool.simple_insert( + await self.db_pool.simple_insert( table="event_reports", values={ "id": next_id, diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py index 9fe97af56adb..7af2608ca486 100644 --- a/synapse/storage/databases/main/stats.py +++ b/synapse/storage/databases/main/stats.py @@ -16,7 +16,7 @@ import logging from itertools import chain -from typing import Tuple +from typing import Any, Dict, Tuple from twisted.internet.defer import DeferredLock @@ -222,11 +222,11 @@ async def get_stats_positions(self) -> int: desc="stats_incremental_position", ) - def update_room_state(self, room_id, fields): + async def update_room_state(self, room_id: str, fields: Dict[str, Any]) -> None: """ Args: - room_id (str) - fields (dict[str:Any]) + room_id + fields """ # For whatever reason some of the fields may contain null bytes, which @@ -244,7 +244,7 @@ def update_room_state(self, room_id, fields): if field and "\0" in field: fields[col] = None - return self.db_pool.simple_upsert( + await self.db_pool.simple_upsert( table="room_stats_state", keyvalues={"room_id": room_id}, values=fields, diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index 52668dbdf9cf..2efcc0dc66f4 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -21,6 +21,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool +from synapse.types import JsonDict from synapse.util.caches.expiringcache import ExpiringCache db_binary_type = memoryview @@ -98,20 +99,21 @@ def _get_received_txn_response(self, txn, transaction_id, origin): else: return None - def set_received_txn_response(self, transaction_id, origin, code, response_dict): - """Persist the response we returened for an incoming transaction, and + async def set_received_txn_response( + self, transaction_id: str, origin: str, code: int, response_dict: JsonDict + ) -> None: + """Persist the response we returned for an incoming transaction, and should return for subsequent transactions with the same transaction_id and origin. Args: - txn - transaction_id (str) - origin (str) - code (int) - response_json (str) + transaction_id: The incoming transaction ID. + origin: The origin server. + code: The response code. + response_dict: The response, to be encoded into JSON. """ - return self.db_pool.simple_insert( + await self.db_pool.simple_insert( table="received_transactions", values={ "transaction_id": transaction_id, From b71d4a094c4370d0229fbd4545b85c049364ecf3 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 27 Aug 2020 14:16:41 -0400 Subject: [PATCH 14/39] Convert simple_delete to async/await. (#8191) --- changelog.d/8191.misc | 1 + synapse/storage/database.py | 63 ++++++++++++++++--- .../storage/databases/main/group_server.py | 28 +++++---- .../storage/databases/main/registration.py | 29 +++++---- tests/storage/test_event_push_actions.py | 6 +- 5 files changed, 90 insertions(+), 37 deletions(-) create mode 100644 changelog.d/8191.misc diff --git a/changelog.d/8191.misc b/changelog.d/8191.misc new file mode 100644 index 000000000000..dfe4c03171d6 --- /dev/null +++ b/changelog.d/8191.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/storage/database.py b/synapse/storage/database.py index ba4c0c9af6d1..7ab370efef15 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -614,6 +614,7 @@ async def execute( """Runs a single query for a result set. Args: + desc: description of the transaction, for logging and metrics decoder - The function which can resolve the cursor results to something meaningful. query - The query string to execute @@ -649,7 +650,7 @@ async def simple_insert( or_ignore: bool stating whether an exception should be raised when a conflicting row already exists. If True, False will be returned by the function instead - desc: string giving a description of the transaction + desc: description of the transaction, for logging and metrics Returns: Whether the row was inserted or not. Only useful when `or_ignore` is True @@ -686,7 +687,7 @@ async def simple_insert_many( Args: table: string giving the table name values: dict of new column names and values for them - desc: string giving a description of the transaction + desc: description of the transaction, for logging and metrics """ await self.runInteraction(desc, self.simple_insert_many_txn, table, values) @@ -700,7 +701,6 @@ def simple_insert_many_txn( txn: The transaction to use. table: string giving the table name values: dict of new column names and values for them - desc: string giving a description of the transaction """ if not values: return @@ -755,6 +755,7 @@ async def simple_upsert( keyvalues: The unique key columns and their new values values: The nonunique columns and their new values insertion_values: additional key/values to use only when inserting + desc: description of the transaction, for logging and metrics lock: True to lock the table when doing the upsert. Returns: Native upserts always return None. Emulated upserts return True if a @@ -1081,6 +1082,7 @@ async def simple_select_one( retcols: list of strings giving the names of the columns to return allow_none: If true, return None instead of failing if the SELECT statement returns no rows + desc: description of the transaction, for logging and metrics """ return await self.runInteraction( desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none @@ -1166,6 +1168,7 @@ async def simple_select_onecol( table: table name keyvalues: column names and values to select the rows with retcol: column whos value we wish to retrieve. + desc: description of the transaction, for logging and metrics Returns: Results in a list @@ -1190,6 +1193,7 @@ async def simple_select_list( column names and values to select the rows with, or None to not apply a WHERE clause. retcols: the names of the columns to return + desc: description of the transaction, for logging and metrics Returns: A list of dictionaries. @@ -1243,14 +1247,16 @@ async def simple_select_many_batch( """Executes a SELECT query on the named table, which may return zero or more rows, returning the result as a list of dicts. - Filters rows by if value of `column` is in `iterable`. + Filters rows by whether the value of `column` is in `iterable`. Args: table: string giving the table name column: column name to test for inclusion against `iterable` iterable: list - keyvalues: dict of column names and values to select the rows with retcols: list of strings giving the names of the columns to return + keyvalues: dict of column names and values to select the rows with + desc: description of the transaction, for logging and metrics + batch_size: the number of rows for each select query """ results = [] # type: List[Dict[str, Any]] @@ -1291,7 +1297,7 @@ def simple_select_many_txn( """Executes a SELECT query on the named table, which may return zero or more rows, returning the result as a list of dicts. - Filters rows by if value of `column` is in `iterable`. + Filters rows by whether the value of `column` is in `iterable`. Args: txn: Transaction object @@ -1367,6 +1373,7 @@ async def simple_update_one( table: string giving the table name keyvalues: dict of column names and values to select the row with updatevalues: dict giving column names and values to update + desc: description of the transaction, for logging and metrics """ await self.runInteraction( desc, self.simple_update_one_txn, table, keyvalues, updatevalues @@ -1426,6 +1433,7 @@ async def simple_delete_one( Args: table: string giving the table name keyvalues: dict of column names and values to select the row with + desc: description of the transaction, for logging and metrics """ await self.runInteraction(desc, self.simple_delete_one_txn, table, keyvalues) @@ -1451,13 +1459,38 @@ def simple_delete_one_txn( if txn.rowcount > 1: raise StoreError(500, "More than one row matched (%s)" % (table,)) - def simple_delete(self, table: str, keyvalues: Dict[str, Any], desc: str): - return self.runInteraction(desc, self.simple_delete_txn, table, keyvalues) + async def simple_delete( + self, table: str, keyvalues: Dict[str, Any], desc: str + ) -> int: + """Executes a DELETE query on the named table. + + Filters rows by the key-value pairs. + + Args: + table: string giving the table name + keyvalues: dict of column names and values to select the row with + desc: description of the transaction, for logging and metrics + + Returns: + The number of deleted rows. + """ + return await self.runInteraction(desc, self.simple_delete_txn, table, keyvalues) @staticmethod def simple_delete_txn( txn: LoggingTransaction, table: str, keyvalues: Dict[str, Any] ) -> int: + """Executes a DELETE query on the named table. + + Filters rows by the key-value pairs. + + Args: + table: string giving the table name + keyvalues: dict of column names and values to select the row with + + Returns: + The number of deleted rows. + """ sql = "DELETE FROM %s WHERE %s" % ( table, " AND ".join("%s = ?" % (k,) for k in keyvalues), @@ -1474,6 +1507,20 @@ async def simple_delete_many( keyvalues: Dict[str, Any], desc: str, ) -> int: + """Executes a DELETE query on the named table. + + Filters rows by if value of `column` is in `iterable`. + + Args: + table: string giving the table name + column: column name to test for inclusion against `iterable` + iterable: list + keyvalues: dict of column names and values to select the rows with + desc: description of the transaction, for logging and metrics + + Returns: + Number rows deleted + """ return await self.runInteraction( desc, self.simple_delete_many_txn, table, column, iterable, keyvalues ) diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py index 8acf254bf351..6c6017188845 100644 --- a/synapse/storage/databases/main/group_server.py +++ b/synapse/storage/databases/main/group_server.py @@ -728,11 +728,13 @@ def _add_room_to_summary_txn( }, ) - def remove_room_from_summary(self, group_id, room_id, category_id): + async def remove_room_from_summary( + self, group_id: str, room_id: str, category_id: str + ) -> int: if category_id is None: category_id = _DEFAULT_CATEGORY_ID - return self.db_pool.simple_delete( + return await self.db_pool.simple_delete( table="group_summary_rooms", keyvalues={ "group_id": group_id, @@ -772,8 +774,8 @@ async def upsert_group_category( desc="upsert_group_category", ) - def remove_group_category(self, group_id, category_id): - return self.db_pool.simple_delete( + async def remove_group_category(self, group_id: str, category_id: str) -> int: + return await self.db_pool.simple_delete( table="group_room_categories", keyvalues={"group_id": group_id, "category_id": category_id}, desc="remove_group_category", @@ -809,8 +811,8 @@ async def upsert_group_role( desc="upsert_group_role", ) - def remove_group_role(self, group_id, role_id): - return self.db_pool.simple_delete( + async def remove_group_role(self, group_id: str, role_id: str) -> int: + return await self.db_pool.simple_delete( table="group_roles", keyvalues={"group_id": group_id, "role_id": role_id}, desc="remove_group_role", @@ -940,11 +942,13 @@ def _add_user_to_summary_txn( }, ) - def remove_user_from_summary(self, group_id, user_id, role_id): + async def remove_user_from_summary( + self, group_id: str, user_id: str, role_id: str + ) -> int: if role_id is None: role_id = _DEFAULT_ROLE_ID - return self.db_pool.simple_delete( + return await self.db_pool.simple_delete( table="group_summary_users", keyvalues={"group_id": group_id, "role_id": role_id, "user_id": user_id}, desc="remove_user_from_summary", @@ -1264,16 +1268,16 @@ async def update_remote_attestion( desc="update_remote_attestion", ) - def remove_attestation_renewal(self, group_id, user_id): + async def remove_attestation_renewal(self, group_id: str, user_id: str) -> int: """Remove an attestation that we thought we should renew, but actually shouldn't. Ideally this would never get called as we would never incorrectly try and do attestations for local users on local groups. Args: - group_id (str) - user_id (str) + group_id + user_id """ - return self.db_pool.simple_delete( + return await self.db_pool.simple_delete( table="group_attestations_renewals", keyvalues={"group_id": group_id, "user_id": user_id}, desc="remove_attestation_renewal", diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 28f7ae043092..12689f430868 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -529,21 +529,21 @@ async def user_get_threepids(self, user_id): "user_get_threepids", ) - def user_delete_threepid(self, user_id, medium, address): - return self.db_pool.simple_delete( + async def user_delete_threepid(self, user_id, medium, address) -> None: + await self.db_pool.simple_delete( "user_threepids", keyvalues={"user_id": user_id, "medium": medium, "address": address}, desc="user_delete_threepid", ) - def user_delete_threepids(self, user_id: str): + async def user_delete_threepids(self, user_id: str) -> None: """Delete all threepid this user has bound Args: user_id: The user id to delete all threepids of """ - return self.db_pool.simple_delete( + await self.db_pool.simple_delete( "user_threepids", keyvalues={"user_id": user_id}, desc="user_delete_threepids", @@ -597,21 +597,20 @@ async def user_get_bound_threepids(self, user_id: str) -> List[Dict[str, Any]]: desc="user_get_bound_threepids", ) - def remove_user_bound_threepid(self, user_id, medium, address, id_server): + async def remove_user_bound_threepid( + self, user_id: str, medium: str, address: str, id_server: str + ) -> None: """The server proxied an unbind request to the given identity server on behalf of the given user, so we remove the mapping of threepid to identity server. Args: - user_id (str) - medium (str) - address (str) - id_server (str) - - Returns: - Deferred + user_id + medium + address + id_server """ - return self.db_pool.simple_delete( + await self.db_pool.simple_delete( table="user_threepid_id_server", keyvalues={ "user_id": user_id, @@ -1247,14 +1246,14 @@ async def add_user_pending_deactivation(self, user_id: str) -> None: desc="add_user_pending_deactivation", ) - def del_user_pending_deactivation(self, user_id): + async def del_user_pending_deactivation(self, user_id: str) -> None: """ Removes the given user to the table of users who need to be parted from all the rooms they're in, effectively marking that user as fully deactivated. """ # XXX: This should be simple_delete_one but we failed to put a unique index on # the table, so somehow duplicate entries have ended up in it. - return self.db_pool.simple_delete( + await self.db_pool.simple_delete( "users_pending_deactivation", keyvalues={"user_id": user_id}, desc="del_user_pending_deactivation", diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index 238bad5b4513..0e7427e57abe 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -123,8 +123,10 @@ def _mark_read(stream, depth): yield _inject_actions(6, PlAIN_NOTIF) yield _rotate(7) - yield self.store.db_pool.simple_delete( - table="event_push_actions", keyvalues={"1": 1}, desc="" + yield defer.ensureDeferred( + self.store.db_pool.simple_delete( + table="event_push_actions", keyvalues={"1": 1}, desc="" + ) ) yield _assert_counts(1, 0) From b49a5b9307fbbc9032e28e532e9036db07555d3d Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 27 Aug 2020 17:24:37 -0400 Subject: [PATCH 15/39] Convert stats and related calls to async/await (#8192) --- changelog.d/8192.misc | 1 + .../databases/main/monthly_active_users.py | 15 ++-- synapse/storage/databases/main/stats.py | 82 ++++++++++--------- tests/handlers/test_auth.py | 21 ++--- tests/handlers/test_register.py | 12 ++- tests/rest/admin/test_user.py | 9 +- .../test_resource_limits_server_notices.py | 10 ++- tests/storage/test_client_ips.py | 5 +- 8 files changed, 78 insertions(+), 77 deletions(-) create mode 100644 changelog.d/8192.misc diff --git a/changelog.d/8192.misc b/changelog.d/8192.misc new file mode 100644 index 000000000000..dfe4c03171d6 --- /dev/null +++ b/changelog.d/8192.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py index fe30552c08ef..1d793d3debdf 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import List +from typing import Dict, List from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool, make_in_list_sql_clause @@ -33,11 +33,11 @@ def __init__(self, database: DatabasePool, db_conn, hs): self.hs = hs @cached(num_args=0) - def get_monthly_active_count(self): + async def get_monthly_active_count(self) -> int: """Generates current count of monthly active users Returns: - Defered[int]: Number of current monthly active users + Number of current monthly active users """ def _count_users(txn): @@ -46,10 +46,10 @@ def _count_users(txn): (count,) = txn.fetchone() return count - return self.db_pool.runInteraction("count_users", _count_users) + return await self.db_pool.runInteraction("count_users", _count_users) @cached(num_args=0) - def get_monthly_active_count_by_service(self): + async def get_monthly_active_count_by_service(self) -> Dict[str, int]: """Generates current count of monthly active users broken down by service. A service is typically an appservice but also includes native matrix users. Since the `monthly_active_users` table is populated from the `user_ips` table @@ -57,8 +57,7 @@ def get_monthly_active_count_by_service(self): method to return anything other than native matrix users. Returns: - Deferred[dict]: dict that includes a mapping between app_service_id - and the number of occurrences. + A mapping between app_service_id and the number of occurrences. """ @@ -74,7 +73,7 @@ def _count_users_by_service(txn): result = txn.fetchall() return dict(result) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "count_users_by_service", _count_users_by_service ) diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py index 7af2608ca486..9b9bc304a857 100644 --- a/synapse/storage/databases/main/stats.py +++ b/synapse/storage/databases/main/stats.py @@ -15,8 +15,9 @@ # limitations under the License. import logging +from collections import Counter from itertools import chain -from typing import Any, Dict, Tuple +from typing import Any, Dict, List, Optional, Tuple from twisted.internet.defer import DeferredLock @@ -251,21 +252,23 @@ async def update_room_state(self, room_id: str, fields: Dict[str, Any]) -> None: desc="update_room_state", ) - def get_statistics_for_subject(self, stats_type, stats_id, start, size=100): + async def get_statistics_for_subject( + self, stats_type: str, stats_id: str, start: str, size: int = 100 + ) -> List[dict]: """ Get statistics for a given subject. Args: - stats_type (str): The type of subject - stats_id (str): The ID of the subject (e.g. room_id or user_id) - start (int): Pagination start. Number of entries, not timestamp. - size (int): How many entries to return. + stats_type: The type of subject + stats_id: The ID of the subject (e.g. room_id or user_id) + start: Pagination start. Number of entries, not timestamp. + size: How many entries to return. Returns: - Deferred[list[dict]], where the dict has the keys of + A list of dicts, where the dict has the keys of ABSOLUTE_STATS_FIELDS[stats_type], and "bucket_size" and "end_ts". """ - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_statistics_for_subject", self._get_statistics_for_subject_txn, stats_type, @@ -319,18 +322,17 @@ async def get_earliest_token_for_stats(self, stats_type: str, id: str) -> int: allow_none=True, ) - def bulk_update_stats_delta(self, ts, updates, stream_id): + async def bulk_update_stats_delta( + self, ts: int, updates: Dict[str, Dict[str, Dict[str, Counter]]], stream_id: int + ) -> None: """Bulk update stats tables for a given stream_id and updates the stats incremental position. Args: - ts (int): Current timestamp in ms - updates(dict[str, dict[str, dict[str, Counter]]]): The updates to - commit as a mapping stats_type -> stats_id -> field -> delta. - stream_id (int): Current position. - - Returns: - Deferred + ts: Current timestamp in ms + updates: The updates to commit as a mapping of + stats_type -> stats_id -> field -> delta. + stream_id: Current position. """ def _bulk_update_stats_delta_txn(txn): @@ -355,38 +357,37 @@ def _bulk_update_stats_delta_txn(txn): updatevalues={"stream_id": stream_id}, ) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "bulk_update_stats_delta", _bulk_update_stats_delta_txn ) - def update_stats_delta( + async def update_stats_delta( self, - ts, - stats_type, - stats_id, - fields, - complete_with_stream_id, - absolute_field_overrides=None, - ): + ts: int, + stats_type: str, + stats_id: str, + fields: Dict[str, int], + complete_with_stream_id: Optional[int], + absolute_field_overrides: Optional[Dict[str, int]] = None, + ) -> None: """ Updates the statistics for a subject, with a delta (difference/relative change). Args: - ts (int): timestamp of the change - stats_type (str): "room" or "user" ā€“ the kind of subject - stats_id (str): the subject's ID (room ID or user ID) - fields (dict[str, int]): Deltas of stats values. - complete_with_stream_id (int, optional): + ts: timestamp of the change + stats_type: "room" or "user" ā€“ the kind of subject + stats_id: the subject's ID (room ID or user ID) + fields: Deltas of stats values. + complete_with_stream_id: If supplied, converts an incomplete row into a complete row, with the supplied stream_id marked as the stream_id where the row was completed. - absolute_field_overrides (dict[str, int]): Current stats values - (i.e. not deltas) of absolute fields. - Does not work with per-slice fields. + absolute_field_overrides: Current stats values (i.e. not deltas) of + absolute fields. Does not work with per-slice fields. """ - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "update_stats_delta", self._update_stats_delta_txn, ts, @@ -646,19 +647,20 @@ def _upsert_copy_from_table_with_additive_relatives_txn( txn, into_table, all_dest_keyvalues, src_row ) - def get_changes_room_total_events_and_bytes(self, min_pos, max_pos): + async def get_changes_room_total_events_and_bytes( + self, min_pos: int, max_pos: int + ) -> Dict[str, Dict[str, int]]: """Fetches the counts of events in the given range of stream IDs. Args: - min_pos (int) - max_pos (int) + min_pos + max_pos Returns: - Deferred[dict[str, dict[str, int]]]: Mapping of room ID to field - changes. + Mapping of room ID to field changes. """ - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "stats_incremental_total_events_and_bytes", self.get_changes_room_total_events_and_bytes_txn, min_pos, diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index c01b04e1dcad..4b3fb018b1dc 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -24,6 +24,7 @@ from synapse.handlers.auth import AuthHandler from tests import unittest +from tests.test_utils import make_awaitable from tests.utils import setup_test_homeserver @@ -142,7 +143,7 @@ def test_mau_limits_disabled(self): def test_mau_limits_exceeded_large(self): self.auth_blocking._limit_usage_by_mau = True self.hs.get_datastore().get_monthly_active_count = Mock( - return_value=defer.succeed(self.large_number_of_users) + side_effect=lambda: make_awaitable(self.large_number_of_users) ) with self.assertRaises(ResourceLimitError): @@ -153,7 +154,7 @@ def test_mau_limits_exceeded_large(self): ) self.hs.get_datastore().get_monthly_active_count = Mock( - return_value=defer.succeed(self.large_number_of_users) + side_effect=lambda: make_awaitable(self.large_number_of_users) ) with self.assertRaises(ResourceLimitError): yield defer.ensureDeferred( @@ -168,7 +169,7 @@ def test_mau_limits_parity(self): # If not in monthly active cohort self.hs.get_datastore().get_monthly_active_count = Mock( - return_value=defer.succeed(self.auth_blocking._max_mau_value) + side_effect=lambda: make_awaitable(self.auth_blocking._max_mau_value) ) with self.assertRaises(ResourceLimitError): yield defer.ensureDeferred( @@ -178,7 +179,7 @@ def test_mau_limits_parity(self): ) self.hs.get_datastore().get_monthly_active_count = Mock( - return_value=defer.succeed(self.auth_blocking._max_mau_value) + side_effect=lambda: make_awaitable(self.auth_blocking._max_mau_value) ) with self.assertRaises(ResourceLimitError): yield defer.ensureDeferred( @@ -188,10 +189,10 @@ def test_mau_limits_parity(self): ) # If in monthly active cohort self.hs.get_datastore().user_last_seen_monthly_active = Mock( - return_value=defer.succeed(self.hs.get_clock().time_msec()) + side_effect=lambda user_id: make_awaitable(self.hs.get_clock().time_msec()) ) self.hs.get_datastore().get_monthly_active_count = Mock( - return_value=defer.succeed(self.auth_blocking._max_mau_value) + side_effect=lambda: make_awaitable(self.auth_blocking._max_mau_value) ) yield defer.ensureDeferred( self.auth_handler.get_access_token_for_user_id( @@ -199,10 +200,10 @@ def test_mau_limits_parity(self): ) ) self.hs.get_datastore().user_last_seen_monthly_active = Mock( - return_value=defer.succeed(self.hs.get_clock().time_msec()) + side_effect=lambda user_id: make_awaitable(self.hs.get_clock().time_msec()) ) self.hs.get_datastore().get_monthly_active_count = Mock( - return_value=defer.succeed(self.auth_blocking._max_mau_value) + side_effect=lambda: make_awaitable(self.auth_blocking._max_mau_value) ) yield defer.ensureDeferred( self.auth_handler.validate_short_term_login_token_and_get_user_id( @@ -215,7 +216,7 @@ def test_mau_limits_not_exceeded(self): self.auth_blocking._limit_usage_by_mau = True self.hs.get_datastore().get_monthly_active_count = Mock( - return_value=defer.succeed(self.small_number_of_users) + side_effect=lambda: make_awaitable(self.small_number_of_users) ) # Ensure does not raise exception yield defer.ensureDeferred( @@ -225,7 +226,7 @@ def test_mau_limits_not_exceeded(self): ) self.hs.get_datastore().get_monthly_active_count = Mock( - return_value=defer.succeed(self.small_number_of_users) + side_effect=lambda: make_awaitable(self.small_number_of_users) ) yield defer.ensureDeferred( self.auth_handler.validate_short_term_login_token_and_get_user_id( diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 5c92d0e8c9a2..eddf5e2498be 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -15,8 +15,6 @@ from mock import Mock -from twisted.internet import defer - from synapse.api.auth import Auth from synapse.api.constants import UserTypes from synapse.api.errors import Codes, ResourceLimitError, SynapseError @@ -102,7 +100,7 @@ def test_mau_limits_when_disabled(self): def test_get_or_create_user_mau_not_blocked(self): self.hs.config.limit_usage_by_mau = True self.store.count_monthly_users = Mock( - return_value=defer.succeed(self.hs.config.max_mau_value - 1) + side_effect=lambda: make_awaitable(self.hs.config.max_mau_value - 1) ) # Ensure does not throw exception self.get_success(self.get_or_create_user(self.requester, "c", "User")) @@ -110,7 +108,7 @@ def test_get_or_create_user_mau_not_blocked(self): def test_get_or_create_user_mau_blocked(self): self.hs.config.limit_usage_by_mau = True self.store.get_monthly_active_count = Mock( - return_value=defer.succeed(self.lots_of_users) + side_effect=lambda: make_awaitable(self.lots_of_users) ) self.get_failure( self.get_or_create_user(self.requester, "b", "display_name"), @@ -118,7 +116,7 @@ def test_get_or_create_user_mau_blocked(self): ) self.store.get_monthly_active_count = Mock( - return_value=defer.succeed(self.hs.config.max_mau_value) + side_effect=lambda: make_awaitable(self.hs.config.max_mau_value) ) self.get_failure( self.get_or_create_user(self.requester, "b", "display_name"), @@ -128,14 +126,14 @@ def test_get_or_create_user_mau_blocked(self): def test_register_mau_blocked(self): self.hs.config.limit_usage_by_mau = True self.store.get_monthly_active_count = Mock( - return_value=defer.succeed(self.lots_of_users) + side_effect=lambda: make_awaitable(self.lots_of_users) ) self.get_failure( self.handler.register_user(localpart="local_part"), ResourceLimitError ) self.store.get_monthly_active_count = Mock( - return_value=defer.succeed(self.hs.config.max_mau_value) + side_effect=lambda: make_awaitable(self.hs.config.max_mau_value) ) self.get_failure( self.handler.register_user(localpart="local_part"), ResourceLimitError diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 17d0aae2e9b2..160c63023555 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -20,8 +20,6 @@ from mock import Mock -from twisted.internet import defer - import synapse.rest.admin from synapse.api.constants import UserTypes from synapse.api.errors import HttpResponseException, ResourceLimitError @@ -29,6 +27,7 @@ from synapse.rest.client.v2_alpha import sync from tests import unittest +from tests.test_utils import make_awaitable from tests.unittest import override_config @@ -338,7 +337,7 @@ def test_register_mau_limit_reached(self): # Set monthly active users to the limit store.get_monthly_active_count = Mock( - return_value=defer.succeed(self.hs.config.max_mau_value) + side_effect=lambda: make_awaitable(self.hs.config.max_mau_value) ) # Check that the blocking of monthly active users is working as expected # The registration of a new user fails due to the limit @@ -592,7 +591,7 @@ def test_create_user_mau_limit_reached_active_admin(self): # Set monthly active users to the limit self.store.get_monthly_active_count = Mock( - return_value=defer.succeed(self.hs.config.max_mau_value) + side_effect=lambda: make_awaitable(self.hs.config.max_mau_value) ) # Check that the blocking of monthly active users is working as expected # The registration of a new user fails due to the limit @@ -632,7 +631,7 @@ def test_create_user_mau_limit_reached_passive_admin(self): # Set monthly active users to the limit self.store.get_monthly_active_count = Mock( - return_value=defer.succeed(self.hs.config.max_mau_value) + side_effect=lambda: make_awaitable(self.hs.config.max_mau_value) ) # Check that the blocking of monthly active users is working as expected # The registration of a new user fails due to the limit diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index 23db821fb7a0..973338ea71ad 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -67,7 +67,7 @@ def prepare(self, reactor, clock, hs): raise Exception("Failed to find reference to ResourceLimitsServerNotices") self._rlsn._store.user_last_seen_monthly_active = Mock( - return_value=defer.succeed(1000) + side_effect=lambda user_id: make_awaitable(1000) ) self._rlsn._server_notices_manager.send_notice = Mock( return_value=defer.succeed(Mock()) @@ -158,7 +158,7 @@ def test_maybe_send_server_notice_to_user_not_in_mau_cohort(self): """ self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None)) self._rlsn._store.user_last_seen_monthly_active = Mock( - return_value=defer.succeed(None) + side_effect=lambda user_id: make_awaitable(None) ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) @@ -261,10 +261,12 @@ def prepare(self, reactor, clock, hs): self.user_id = "@user_id:test" def test_server_notice_only_sent_once(self): - self.store.get_monthly_active_count = Mock(return_value=defer.succeed(1000)) + self.store.get_monthly_active_count = Mock( + side_effect=lambda: make_awaitable(1000) + ) self.store.user_last_seen_monthly_active = Mock( - return_value=defer.succeed(1000) + side_effect=lambda user_id: make_awaitable(1000) ) # Call the function multiple times to ensure we only send the notice once diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index 224ea6fd79d3..370c247e16c4 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -16,13 +16,12 @@ from mock import Mock -from twisted.internet import defer - import synapse.rest.admin from synapse.http.site import XForwardedForRequest from synapse.rest.client.v1 import login from tests import unittest +from tests.test_utils import make_awaitable from tests.unittest import override_config @@ -155,7 +154,7 @@ def test_adding_monthly_active_user_when_full(self): user_id = "@user:server" self.store.get_monthly_active_count = Mock( - return_value=defer.succeed(lots_of_users) + side_effect=lambda: make_awaitable(lots_of_users) ) self.get_success( self.store.insert_client_ip( From e00816ad98a1165b67238f9711cb1b0e7135f25f Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 27 Aug 2020 17:24:46 -0400 Subject: [PATCH 16/39] Do not yield on awaitables in tests. (#8193) --- changelog.d/8193.misc | 1 + tests/api/test_filtering.py | 36 ++++++---- tests/crypto/test_keyring.py | 4 +- tests/federation/test_complexity.py | 6 +- tests/handlers/test_typing.py | 4 +- tests/rest/client/v2_alpha/test_filter.py | 8 ++- tests/storage/test_appservice.py | 24 +++++-- tests/storage/test_background_update.py | 9 +-- tests/storage/test_end_to_end_keys.py | 32 ++++++--- tests/storage/test_event_push_actions.py | 78 +++++++++++++------- tests/storage/test_main.py | 8 ++- tests/storage/test_registration.py | 44 +++++++----- tests/storage/test_user_directory.py | 16 +++-- tests/test_state.py | 86 ++++++++++++----------- tests/test_visibility.py | 5 +- 15 files changed, 230 insertions(+), 131 deletions(-) create mode 100644 changelog.d/8193.misc diff --git a/changelog.d/8193.misc b/changelog.d/8193.misc new file mode 100644 index 000000000000..dfe4c03171d6 --- /dev/null +++ b/changelog.d/8193.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index 1fab1d6b6902..d2d535d23c20 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -369,8 +369,10 @@ def test_filter_not_labels(self): @defer.inlineCallbacks def test_filter_presence_match(self): user_filter_json = {"presence": {"types": ["m.*"]}} - filter_id = yield self.datastore.add_user_filter( - user_localpart=user_localpart, user_filter=user_filter_json + filter_id = yield defer.ensureDeferred( + self.datastore.add_user_filter( + user_localpart=user_localpart, user_filter=user_filter_json + ) ) event = MockEvent(sender="@foo:bar", type="m.profile") events = [event] @@ -388,8 +390,10 @@ def test_filter_presence_match(self): def test_filter_presence_no_match(self): user_filter_json = {"presence": {"types": ["m.*"]}} - filter_id = yield self.datastore.add_user_filter( - user_localpart=user_localpart + "2", user_filter=user_filter_json + filter_id = yield defer.ensureDeferred( + self.datastore.add_user_filter( + user_localpart=user_localpart + "2", user_filter=user_filter_json + ) ) event = MockEvent( event_id="$asdasd:localhost", @@ -410,8 +414,10 @@ def test_filter_presence_no_match(self): @defer.inlineCallbacks def test_filter_room_state_match(self): user_filter_json = {"room": {"state": {"types": ["m.*"]}}} - filter_id = yield self.datastore.add_user_filter( - user_localpart=user_localpart, user_filter=user_filter_json + filter_id = yield defer.ensureDeferred( + self.datastore.add_user_filter( + user_localpart=user_localpart, user_filter=user_filter_json + ) ) event = MockEvent(sender="@foo:bar", type="m.room.topic", room_id="!foo:bar") events = [event] @@ -428,8 +434,10 @@ def test_filter_room_state_match(self): @defer.inlineCallbacks def test_filter_room_state_no_match(self): user_filter_json = {"room": {"state": {"types": ["m.*"]}}} - filter_id = yield self.datastore.add_user_filter( - user_localpart=user_localpart, user_filter=user_filter_json + filter_id = yield defer.ensureDeferred( + self.datastore.add_user_filter( + user_localpart=user_localpart, user_filter=user_filter_json + ) ) event = MockEvent( sender="@foo:bar", type="org.matrix.custom.event", room_id="!foo:bar" @@ -465,8 +473,10 @@ def test_filter_rooms(self): def test_add_filter(self): user_filter_json = {"room": {"state": {"types": ["m.*"]}}} - filter_id = yield self.filtering.add_user_filter( - user_localpart=user_localpart, user_filter=user_filter_json + filter_id = yield defer.ensureDeferred( + self.filtering.add_user_filter( + user_localpart=user_localpart, user_filter=user_filter_json + ) ) self.assertEquals(filter_id, 0) @@ -485,8 +495,10 @@ def test_add_filter(self): def test_get_filter(self): user_filter_json = {"room": {"state": {"types": ["m.*"]}}} - filter_id = yield self.datastore.add_user_filter( - user_localpart=user_localpart, user_filter=user_filter_json + filter_id = yield defer.ensureDeferred( + self.datastore.add_user_filter( + user_localpart=user_localpart, user_filter=user_filter_json + ) ) filter = yield defer.ensureDeferred( diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index 0d4b05304b2c..d264653e74e9 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -190,7 +190,7 @@ def test_verify_json_for_server(self): # should fail immediately on an unsigned object d = _verify_json_for_server(kr, "server9", {}, 0, "test unsigned") - self.failureResultOf(d, SynapseError) + self.get_failure(d, SynapseError) # should succeed on a signed object d = _verify_json_for_server(kr, "server9", json1, 500, "test signed") @@ -221,7 +221,7 @@ def test_verify_json_for_server_with_null_valid_until_ms(self): # should fail immediately on an unsigned object d = _verify_json_for_server(kr, "server9", {}, 0, "test unsigned") - self.failureResultOf(d, SynapseError) + self.get_failure(d, SynapseError) # should fail on a signed object with a non-zero minimum_valid_until_ms, # as it tries to refetch the keys and fails. diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py index 9bd515080c7d..3d880c499d8b 100644 --- a/tests/federation/test_complexity.py +++ b/tests/federation/test_complexity.py @@ -15,8 +15,6 @@ from mock import Mock -from twisted.internet import defer - from synapse.api.errors import Codes, SynapseError from synapse.rest import admin from synapse.rest.client.v1 import login, room @@ -60,7 +58,7 @@ def test_complexity_simple(self): # Artificially raise the complexity store = self.hs.get_datastore() - store.get_current_state_event_counts = lambda x: defer.succeed(500 * 1.23) + store.get_current_state_event_counts = lambda x: make_awaitable(500 * 1.23) # Get the room complexity again -- make sure it's our artificial value request, channel = self.make_request( @@ -160,7 +158,7 @@ def test_join_too_large_once_joined(self): ) # Artificially raise the complexity - self.hs.get_datastore().get_current_state_event_counts = lambda x: defer.succeed( + self.hs.get_datastore().get_current_state_event_counts = lambda x: make_awaitable( 600 ) diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 81c1839637eb..7bf15c4ba93f 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -155,7 +155,9 @@ def get_users_in_room(room_id): self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: make_awaitable( ([], 0) ) - self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: None + self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: make_awaitable( + None + ) self.datastore.set_received_txn_response = lambda *args, **kwargs: make_awaitable( None ) diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/v2_alpha/test_filter.py index e0e9e94fbf41..de00350580f4 100644 --- a/tests/rest/client/v2_alpha/test_filter.py +++ b/tests/rest/client/v2_alpha/test_filter.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from twisted.internet import defer + from synapse.api.errors import Codes from synapse.rest.client.v2_alpha import filter @@ -73,8 +75,10 @@ def test_add_filter_non_local_user(self): self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN) def test_get_filter(self): - filter_id = self.filtering.add_user_filter( - user_localpart="apple", user_filter=self.EXAMPLE_FILTER + filter_id = defer.ensureDeferred( + self.filtering.add_user_filter( + user_localpart="apple", user_filter=self.EXAMPLE_FILTER + ) ) self.reactor.advance(1) filter_id = filter_id.result diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 17fbde284aea..cb808d4de4d7 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -243,7 +243,9 @@ def test_set_appservices_state_multiple_up(self): def test_create_appservice_txn_first(self): service = Mock(id=self.as_list[0]["id"]) events = [Mock(event_id="e1"), Mock(event_id="e2")] - txn = yield self.store.create_appservice_txn(service, events) + txn = yield defer.ensureDeferred( + self.store.create_appservice_txn(service, events) + ) self.assertEquals(txn.id, 1) self.assertEquals(txn.events, events) self.assertEquals(txn.service, service) @@ -255,7 +257,9 @@ def test_create_appservice_txn_older_last_txn(self): yield self._set_last_txn(service.id, 9643) # AS is falling behind yield self._insert_txn(service.id, 9644, events) yield self._insert_txn(service.id, 9645, events) - txn = yield self.store.create_appservice_txn(service, events) + txn = yield defer.ensureDeferred( + self.store.create_appservice_txn(service, events) + ) self.assertEquals(txn.id, 9646) self.assertEquals(txn.events, events) self.assertEquals(txn.service, service) @@ -265,7 +269,9 @@ def test_create_appservice_txn_up_to_date_last_txn(self): service = Mock(id=self.as_list[0]["id"]) events = [Mock(event_id="e1"), Mock(event_id="e2")] yield self._set_last_txn(service.id, 9643) - txn = yield self.store.create_appservice_txn(service, events) + txn = yield defer.ensureDeferred( + self.store.create_appservice_txn(service, events) + ) self.assertEquals(txn.id, 9644) self.assertEquals(txn.events, events) self.assertEquals(txn.service, service) @@ -286,7 +292,9 @@ def test_create_appservice_txn_up_fuzzing(self): yield self._insert_txn(self.as_list[2]["id"], 10, events) yield self._insert_txn(self.as_list[3]["id"], 9643, events) - txn = yield self.store.create_appservice_txn(service, events) + txn = yield defer.ensureDeferred( + self.store.create_appservice_txn(service, events) + ) self.assertEquals(txn.id, 9644) self.assertEquals(txn.events, events) self.assertEquals(txn.service, service) @@ -298,7 +306,9 @@ def test_complete_appservice_txn_first_txn(self): txn_id = 1 yield self._insert_txn(service.id, txn_id, events) - yield self.store.complete_appservice_txn(txn_id=txn_id, service=service) + yield defer.ensureDeferred( + self.store.complete_appservice_txn(txn_id=txn_id, service=service) + ) res = yield self.db_pool.runQuery( self.engine.convert_param_style( @@ -324,7 +334,9 @@ def test_complete_appservice_txn_existing_in_state_table(self): txn_id = 5 yield self._set_last_txn(service.id, 4) yield self._insert_txn(service.id, txn_id, events) - yield self.store.complete_appservice_txn(txn_id=txn_id, service=service) + yield defer.ensureDeferred( + self.store.complete_appservice_txn(txn_id=txn_id, service=service) + ) res = yield self.db_pool.runQuery( self.engine.convert_param_style( diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index 1a1c59256cfc..02aae1c13d21 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -1,7 +1,5 @@ from mock import Mock -from twisted.internet import defer - from synapse.storage.background_updates import BackgroundUpdater from tests import unittest @@ -38,11 +36,10 @@ def test_do_background_update(self): ) # first step: make a bit of progress - @defer.inlineCallbacks - def update(progress, count): - yield self.clock.sleep((count * duration_ms) / 1000) + async def update(progress, count): + await self.clock.sleep((count * duration_ms) / 1000) progress = {"my_key": progress["my_key"] + 1} - yield store.db_pool.runInteraction( + await store.db_pool.runInteraction( "update_progress", self.updates._background_update_progress_txn, "test_update", diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py index d57cdffd8ba4..261bf5b08b1b 100644 --- a/tests/storage/test_end_to_end_keys.py +++ b/tests/storage/test_end_to_end_keys.py @@ -32,7 +32,9 @@ def test_key_without_device_name(self): yield defer.ensureDeferred(self.store.store_device("user", "device", None)) - yield self.store.set_e2e_device_keys("user", "device", now, json) + yield defer.ensureDeferred( + self.store.set_e2e_device_keys("user", "device", now, json) + ) res = yield defer.ensureDeferred( self.store.get_e2e_device_keys((("user", "device"),)) @@ -49,12 +51,16 @@ def test_reupload_key(self): yield defer.ensureDeferred(self.store.store_device("user", "device", None)) - changed = yield self.store.set_e2e_device_keys("user", "device", now, json) + changed = yield defer.ensureDeferred( + self.store.set_e2e_device_keys("user", "device", now, json) + ) self.assertTrue(changed) # If we try to upload the same key then we should be told nothing # changed - changed = yield self.store.set_e2e_device_keys("user", "device", now, json) + changed = yield defer.ensureDeferred( + self.store.set_e2e_device_keys("user", "device", now, json) + ) self.assertFalse(changed) @defer.inlineCallbacks @@ -62,7 +68,9 @@ def test_get_key_with_device_name(self): now = 1470174257070 json = {"key": "value"} - yield self.store.set_e2e_device_keys("user", "device", now, json) + yield defer.ensureDeferred( + self.store.set_e2e_device_keys("user", "device", now, json) + ) yield defer.ensureDeferred( self.store.store_device("user", "device", "display_name") ) @@ -86,10 +94,18 @@ def test_multiple_devices(self): yield defer.ensureDeferred(self.store.store_device("user2", "device1", None)) yield defer.ensureDeferred(self.store.store_device("user2", "device2", None)) - yield self.store.set_e2e_device_keys("user1", "device1", now, {"key": "json11"}) - yield self.store.set_e2e_device_keys("user1", "device2", now, {"key": "json12"}) - yield self.store.set_e2e_device_keys("user2", "device1", now, {"key": "json21"}) - yield self.store.set_e2e_device_keys("user2", "device2", now, {"key": "json22"}) + yield defer.ensureDeferred( + self.store.set_e2e_device_keys("user1", "device1", now, {"key": "json11"}) + ) + yield defer.ensureDeferred( + self.store.set_e2e_device_keys("user1", "device2", now, {"key": "json12"}) + ) + yield defer.ensureDeferred( + self.store.set_e2e_device_keys("user2", "device1", now, {"key": "json21"}) + ) + yield defer.ensureDeferred( + self.store.set_e2e_device_keys("user2", "device2", now, {"key": "json22"}) + ) res = yield defer.ensureDeferred( self.store.get_e2e_device_keys((("user1", "device1"), ("user2", "device2"))) diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index 0e7427e57abe..cdfd2634aa9f 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -60,8 +60,10 @@ def test_count_aggregation(self): @defer.inlineCallbacks def _assert_counts(noitf_count, highlight_count): - counts = yield self.store.db_pool.runInteraction( - "", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0 + counts = yield defer.ensureDeferred( + self.store.db_pool.runInteraction( + "", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0 + ) ) self.assertEquals( counts, @@ -81,25 +83,31 @@ def _inject_actions(stream, action): event.event_id, {user_id: action} ) ) - yield self.store.db_pool.runInteraction( - "", - self.persist_events_store._set_push_actions_for_event_and_users_txn, - [(event, None)], - [(event, None)], + yield defer.ensureDeferred( + self.store.db_pool.runInteraction( + "", + self.persist_events_store._set_push_actions_for_event_and_users_txn, + [(event, None)], + [(event, None)], + ) ) def _rotate(stream): - return self.store.db_pool.runInteraction( - "", self.store._rotate_notifs_before_txn, stream + return defer.ensureDeferred( + self.store.db_pool.runInteraction( + "", self.store._rotate_notifs_before_txn, stream + ) ) def _mark_read(stream, depth): - return self.store.db_pool.runInteraction( - "", - self.store._remove_old_push_actions_before_txn, - room_id, - user_id, - stream, + return defer.ensureDeferred( + self.store.db_pool.runInteraction( + "", + self.store._remove_old_push_actions_before_txn, + room_id, + user_id, + stream, + ) ) yield _assert_counts(0, 0) @@ -163,16 +171,24 @@ def add_event(so, ts): ) # start with the base case where there are no events in the table - r = yield self.store.find_first_stream_ordering_after_ts(11) + r = yield defer.ensureDeferred( + self.store.find_first_stream_ordering_after_ts(11) + ) self.assertEqual(r, 0) # now with one event yield add_event(2, 10) - r = yield self.store.find_first_stream_ordering_after_ts(9) + r = yield defer.ensureDeferred( + self.store.find_first_stream_ordering_after_ts(9) + ) self.assertEqual(r, 2) - r = yield self.store.find_first_stream_ordering_after_ts(10) + r = yield defer.ensureDeferred( + self.store.find_first_stream_ordering_after_ts(10) + ) self.assertEqual(r, 2) - r = yield self.store.find_first_stream_ordering_after_ts(11) + r = yield defer.ensureDeferred( + self.store.find_first_stream_ordering_after_ts(11) + ) self.assertEqual(r, 3) # add a bunch of dummy events to the events table @@ -185,25 +201,37 @@ def add_event(so, ts): ): yield add_event(stream_ordering, ts) - r = yield self.store.find_first_stream_ordering_after_ts(110) + r = yield defer.ensureDeferred( + self.store.find_first_stream_ordering_after_ts(110) + ) self.assertEqual(r, 3, "First event after 110ms should be 3, was %i" % r) # 4 and 5 are both after 120: we want 4 rather than 5 - r = yield self.store.find_first_stream_ordering_after_ts(120) + r = yield defer.ensureDeferred( + self.store.find_first_stream_ordering_after_ts(120) + ) self.assertEqual(r, 4, "First event after 120ms should be 4, was %i" % r) - r = yield self.store.find_first_stream_ordering_after_ts(129) + r = yield defer.ensureDeferred( + self.store.find_first_stream_ordering_after_ts(129) + ) self.assertEqual(r, 10, "First event after 129ms should be 10, was %i" % r) # check we can get the last event - r = yield self.store.find_first_stream_ordering_after_ts(140) + r = yield defer.ensureDeferred( + self.store.find_first_stream_ordering_after_ts(140) + ) self.assertEqual(r, 20, "First event after 14ms should be 20, was %i" % r) # off the end - r = yield self.store.find_first_stream_ordering_after_ts(160) + r = yield defer.ensureDeferred( + self.store.find_first_stream_ordering_after_ts(160) + ) self.assertEqual(r, 21) # check we can find an event at ordering zero yield add_event(0, 5) - r = yield self.store.find_first_stream_ordering_after_ts(1) + r = yield defer.ensureDeferred( + self.store.find_first_stream_ordering_after_ts(1) + ) self.assertEqual(r, 0) diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py index 954338a59276..7e7f1286d961 100644 --- a/tests/storage/test_main.py +++ b/tests/storage/test_main.py @@ -34,14 +34,16 @@ def setUp(self): @defer.inlineCallbacks def test_get_users_paginate(self): - yield self.store.register_user(self.user.to_string(), "pass") + yield defer.ensureDeferred( + self.store.register_user(self.user.to_string(), "pass") + ) yield defer.ensureDeferred(self.store.create_profile(self.user.localpart)) yield defer.ensureDeferred( self.store.set_profile_displayname(self.user.localpart, self.displayname) ) - users, total = yield self.store.get_users_paginate( - 0, 10, name="bc", guests=False + users, total = yield defer.ensureDeferred( + self.store.get_users_paginate(0, 10, name="bc", guests=False) ) self.assertEquals(1, total) diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index 70c55cd65040..6b582771feaf 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -37,7 +37,7 @@ def setUp(self): @defer.inlineCallbacks def test_register(self): - yield self.store.register_user(self.user_id, self.pwhash) + yield defer.ensureDeferred(self.store.register_user(self.user_id, self.pwhash)) self.assertEquals( { @@ -58,14 +58,16 @@ def test_register(self): @defer.inlineCallbacks def test_add_tokens(self): - yield self.store.register_user(self.user_id, self.pwhash) + yield defer.ensureDeferred(self.store.register_user(self.user_id, self.pwhash)) yield defer.ensureDeferred( self.store.add_access_token_to_user( self.user_id, self.tokens[1], self.device_id, valid_until_ms=None ) ) - result = yield self.store.get_user_by_access_token(self.tokens[1]) + result = yield defer.ensureDeferred( + self.store.get_user_by_access_token(self.tokens[1]) + ) self.assertDictContainsSubset( {"name": self.user_id, "device_id": self.device_id}, result @@ -76,7 +78,7 @@ def test_add_tokens(self): @defer.inlineCallbacks def test_user_delete_access_tokens(self): # add some tokens - yield self.store.register_user(self.user_id, self.pwhash) + yield defer.ensureDeferred(self.store.register_user(self.user_id, self.pwhash)) yield defer.ensureDeferred( self.store.add_access_token_to_user( self.user_id, self.tokens[0], device_id=None, valid_until_ms=None @@ -89,22 +91,28 @@ def test_user_delete_access_tokens(self): ) # now delete some - yield self.store.user_delete_access_tokens( - self.user_id, device_id=self.device_id + yield defer.ensureDeferred( + self.store.user_delete_access_tokens(self.user_id, device_id=self.device_id) ) # check they were deleted - user = yield self.store.get_user_by_access_token(self.tokens[1]) + user = yield defer.ensureDeferred( + self.store.get_user_by_access_token(self.tokens[1]) + ) self.assertIsNone(user, "access token was not deleted by device_id") # check the one not associated with the device was not deleted - user = yield self.store.get_user_by_access_token(self.tokens[0]) + user = yield defer.ensureDeferred( + self.store.get_user_by_access_token(self.tokens[0]) + ) self.assertEqual(self.user_id, user["name"]) # now delete the rest - yield self.store.user_delete_access_tokens(self.user_id) + yield defer.ensureDeferred(self.store.user_delete_access_tokens(self.user_id)) - user = yield self.store.get_user_by_access_token(self.tokens[0]) + user = yield defer.ensureDeferred( + self.store.get_user_by_access_token(self.tokens[0]) + ) self.assertIsNone(user, "access token was not deleted without device_id") @defer.inlineCallbacks @@ -112,16 +120,20 @@ def test_is_support_user(self): TEST_USER = "@test:test" SUPPORT_USER = "@support:test" - res = yield self.store.is_support_user(None) + res = yield defer.ensureDeferred(self.store.is_support_user(None)) self.assertFalse(res) - yield self.store.register_user(user_id=TEST_USER, password_hash=None) - res = yield self.store.is_support_user(TEST_USER) + yield defer.ensureDeferred( + self.store.register_user(user_id=TEST_USER, password_hash=None) + ) + res = yield defer.ensureDeferred(self.store.is_support_user(TEST_USER)) self.assertFalse(res) - yield self.store.register_user( - user_id=SUPPORT_USER, password_hash=None, user_type=UserTypes.SUPPORT + yield defer.ensureDeferred( + self.store.register_user( + user_id=SUPPORT_USER, password_hash=None, user_type=UserTypes.SUPPORT + ) ) - res = yield self.store.is_support_user(SUPPORT_USER) + res = yield defer.ensureDeferred(self.store.is_support_user(SUPPORT_USER)) self.assertTrue(res) @defer.inlineCallbacks diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py index ecfafe68a965..738e91246829 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py @@ -31,10 +31,18 @@ def setUp(self): # alice and bob are both in !room_id. bobby is not but shares # a homeserver with alice. - yield self.store.update_profile_in_user_dir(ALICE, "alice", None) - yield self.store.update_profile_in_user_dir(BOB, "bob", None) - yield self.store.update_profile_in_user_dir(BOBBY, "bobby", None) - yield self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB)) + yield defer.ensureDeferred( + self.store.update_profile_in_user_dir(ALICE, "alice", None) + ) + yield defer.ensureDeferred( + self.store.update_profile_in_user_dir(BOB, "bob", None) + ) + yield defer.ensureDeferred( + self.store.update_profile_in_user_dir(BOBBY, "bobby", None) + ) + yield defer.ensureDeferred( + self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB)) + ) @defer.inlineCallbacks def test_search_user_dir(self): diff --git a/tests/test_state.py b/tests/test_state.py index b5c3667d2a8c..56ba0fecf550 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -80,16 +80,16 @@ def __init__(self): self._next_group = 1 - def get_state_groups_ids(self, room_id, event_ids): + async def get_state_groups_ids(self, room_id, event_ids): groups = {} for event_id in event_ids: group = self._event_to_state_group.get(event_id) if group: groups[group] = self._group_to_state[group] - return defer.succeed(groups) + return groups - def store_state_group( + async def store_state_group( self, event_id, room_id, prev_group, delta_ids, current_state_ids ): state_group = self._next_group @@ -97,19 +97,17 @@ def store_state_group( self._group_to_state[state_group] = dict(current_state_ids) - return defer.succeed(state_group) + return state_group - def get_events(self, event_ids, **kwargs): - return defer.succeed( - { - e_id: self._event_id_to_event[e_id] - for e_id in event_ids - if e_id in self._event_id_to_event - } - ) + async def get_events(self, event_ids, **kwargs): + return { + e_id: self._event_id_to_event[e_id] + for e_id in event_ids + if e_id in self._event_id_to_event + } - def get_state_group_delta(self, name): - return defer.succeed((None, None)) + async def get_state_group_delta(self, name): + return (None, None) def register_events(self, events): for e in events: @@ -121,8 +119,8 @@ def register_event_context(self, event, context): def register_event_id_state_group(self, event_id, state_group): self._event_to_state_group[event_id] = state_group - def get_room_version_id(self, room_id): - return defer.succeed(RoomVersions.V1.identifier) + async def get_room_version_id(self, room_id): + return RoomVersions.V1.identifier class DictObj(dict): @@ -476,12 +474,14 @@ def test_trivial_annotate_message(self): create_event(type="test2", state_key=""), ] - group_name = yield self.store.store_state_group( - prev_event_id, - event.room_id, - None, - None, - {(e.type, e.state_key): e.event_id for e in old_state}, + group_name = yield defer.ensureDeferred( + self.store.store_state_group( + prev_event_id, + event.room_id, + None, + None, + {(e.type, e.state_key): e.event_id for e in old_state}, + ) ) self.store.register_event_id_state_group(prev_event_id, group_name) @@ -508,12 +508,14 @@ def test_trivial_annotate_state(self): create_event(type="test2", state_key=""), ] - group_name = yield self.store.store_state_group( - prev_event_id, - event.room_id, - None, - None, - {(e.type, e.state_key): e.event_id for e in old_state}, + group_name = yield defer.ensureDeferred( + self.store.store_state_group( + prev_event_id, + event.room_id, + None, + None, + {(e.type, e.state_key): e.event_id for e in old_state}, + ) ) self.store.register_event_id_state_group(prev_event_id, group_name) @@ -691,21 +693,25 @@ def test_standard_depth_conflict(self): def _get_context( self, event, prev_event_id_1, old_state_1, prev_event_id_2, old_state_2 ): - sg1 = yield self.store.store_state_group( - prev_event_id_1, - event.room_id, - None, - None, - {(e.type, e.state_key): e.event_id for e in old_state_1}, + sg1 = yield defer.ensureDeferred( + self.store.store_state_group( + prev_event_id_1, + event.room_id, + None, + None, + {(e.type, e.state_key): e.event_id for e in old_state_1}, + ) ) self.store.register_event_id_state_group(prev_event_id_1, sg1) - sg2 = yield self.store.store_state_group( - prev_event_id_2, - event.room_id, - None, - None, - {(e.type, e.state_key): e.event_id for e in old_state_2}, + sg2 = yield defer.ensureDeferred( + self.store.store_state_group( + prev_event_id_2, + event.room_id, + None, + None, + {(e.type, e.state_key): e.event_id for e in old_state_2}, + ) ) self.store.register_event_id_state_group(prev_event_id_2, sg2) diff --git a/tests/test_visibility.py b/tests/test_visibility.py index 531a9b9118b6..4a4483ba120e 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -37,7 +37,6 @@ def setUp(self): self.hs = yield setup_test_homeserver(self.addCleanup) self.event_creation_handler = self.hs.get_event_creation_handler() self.event_builder_factory = self.hs.get_event_builder_factory() - self.store = self.hs.get_datastore() self.storage = self.hs.get_storage() yield defer.ensureDeferred(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM")) @@ -99,7 +98,9 @@ def test_erased_user(self): events_to_filter.append(evt) # the erasey user gets erased - yield self.hs.get_datastore().mark_user_erased("@erased:local_hs") + yield defer.ensureDeferred( + self.hs.get_datastore().mark_user_erased("@erased:local_hs") + ) # ... and the filtering happens. filtered = yield defer.ensureDeferred( From 2c2e649be29390061d2bc7d7a7aea1daa32e68f6 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Fri, 28 Aug 2020 09:58:17 +0100 Subject: [PATCH 17/39] Move and refactor LoginRestServlet helper methods (#8182) This is split out from https://github.com/matrix-org/synapse/pull/7438, which had gotten rather large. `LoginRestServlet` has a couple helper methods, `login_submission_legacy_convert` and `login_id_thirdparty_from_phone`. They're primarily used for converting legacy user login submissions to "identifier" dicts ([see spec](https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-login)). Identifying information such as usernames or 3PID information used to be top-level in the login body. They're now supposed to be put inside an [identifier](https://matrix.org/docs/spec/client_server/r0.6.1#identifier-types) parameter instead. #7438's purpose is to allow using the new identifier parameter during User-Interactive Authentication, which is currently handled in AuthHandler. That's why I've moved these helper methods there. I also moved the refactoring of these method from #7438 as they're relevant. --- changelog.d/8182.misc | 1 + synapse/handlers/auth.py | 88 ++++++++++++++++++++++++++++++++- synapse/rest/client/v1/login.py | 60 +++------------------- 3 files changed, 94 insertions(+), 55 deletions(-) create mode 100644 changelog.d/8182.misc diff --git a/changelog.d/8182.misc b/changelog.d/8182.misc new file mode 100644 index 000000000000..4fcdf1c45213 --- /dev/null +++ b/changelog.d/8182.misc @@ -0,0 +1 @@ +Refactor some of `LoginRestServlet`'s helper methods, and move them to `AuthHandler` for easier reuse. \ No newline at end of file diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 654f58ddaefe..f0b0a4d76ab7 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -42,8 +42,9 @@ from synapse.logging.context import defer_to_thread from synapse.metrics.background_process_metrics import run_as_background_process from synapse.module_api import ModuleApi -from synapse.types import Requester, UserID +from synapse.types import JsonDict, Requester, UserID from synapse.util import stringutils as stringutils +from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.threepids import canonicalise_email from ._base import BaseHandler @@ -51,6 +52,91 @@ logger = logging.getLogger(__name__) +def convert_client_dict_legacy_fields_to_identifier( + submission: JsonDict, +) -> Dict[str, str]: + """ + Convert a legacy-formatted login submission to an identifier dict. + + Legacy login submissions (used in both login and user-interactive authentication) + provide user-identifying information at the top-level instead. + + These are now deprecated and replaced with identifiers: + https://matrix.org/docs/spec/client_server/r0.6.1#identifier-types + + Args: + submission: The client dict to convert + + Returns: + The matching identifier dict + + Raises: + SynapseError: If the format of the client dict is invalid + """ + identifier = submission.get("identifier", {}) + + # Generate an m.id.user identifier if "user" parameter is present + user = submission.get("user") + if user: + identifier = {"type": "m.id.user", "user": user} + + # Generate an m.id.thirdparty identifier if "medium" and "address" parameters are present + medium = submission.get("medium") + address = submission.get("address") + if medium and address: + identifier = { + "type": "m.id.thirdparty", + "medium": medium, + "address": address, + } + + # We've converted valid, legacy login submissions to an identifier. If the + # submission still doesn't have an identifier, it's invalid + if not identifier: + raise SynapseError(400, "Invalid login submission", Codes.INVALID_PARAM) + + # Ensure the identifier has a type + if "type" not in identifier: + raise SynapseError( + 400, "'identifier' dict has no key 'type'", errcode=Codes.MISSING_PARAM, + ) + + return identifier + + +def login_id_phone_to_thirdparty(identifier: JsonDict) -> Dict[str, str]: + """ + Convert a phone login identifier type to a generic threepid identifier. + + Args: + identifier: Login identifier dict of type 'm.id.phone' + + Returns: + An equivalent m.id.thirdparty identifier dict + """ + if "country" not in identifier or ( + # The specification requires a "phone" field, while Synapse used to require a "number" + # field. Accept both for backwards compatibility. + "phone" not in identifier + and "number" not in identifier + ): + raise SynapseError( + 400, "Invalid phone-type identifier", errcode=Codes.INVALID_PARAM + ) + + # Accept both "phone" and "number" as valid keys in m.id.phone + phone_number = identifier.get("phone", identifier["number"]) + + # Convert user-provided phone number to a consistent representation + msisdn = phone_number_to_msisdn(identifier["country"], phone_number) + + return { + "type": "m.id.thirdparty", + "medium": "msisdn", + "address": msisdn, + } + + class AuthHandler(BaseHandler): SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000 diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 379f668d6f8a..a14618ac84fb 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -18,6 +18,10 @@ from synapse.api.errors import Codes, LoginError, SynapseError from synapse.api.ratelimiting import Ratelimiter +from synapse.handlers.auth import ( + convert_client_dict_legacy_fields_to_identifier, + login_id_phone_to_thirdparty, +) from synapse.http.server import finish_request from synapse.http.servlet import ( RestServlet, @@ -28,56 +32,11 @@ from synapse.rest.client.v2_alpha._base import client_patterns from synapse.rest.well_known import WellKnownBuilder from synapse.types import JsonDict, UserID -from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.threepids import canonicalise_email logger = logging.getLogger(__name__) -def login_submission_legacy_convert(submission): - """ - If the input login submission is an old style object - (ie. with top-level user / medium / address) convert it - to a typed object. - """ - if "user" in submission: - submission["identifier"] = {"type": "m.id.user", "user": submission["user"]} - del submission["user"] - - if "medium" in submission and "address" in submission: - submission["identifier"] = { - "type": "m.id.thirdparty", - "medium": submission["medium"], - "address": submission["address"], - } - del submission["medium"] - del submission["address"] - - -def login_id_thirdparty_from_phone(identifier): - """ - Convert a phone login identifier type to a generic threepid identifier - Args: - identifier(dict): Login identifier dict of type 'm.id.phone' - - Returns: Login identifier dict of type 'm.id.threepid' - """ - if "country" not in identifier or ( - # The specification requires a "phone" field, while Synapse used to require a "number" - # field. Accept both for backwards compatibility. - "phone" not in identifier - and "number" not in identifier - ): - raise SynapseError(400, "Invalid phone-type identifier") - - # Accept both "phone" and "number" as valid keys in m.id.phone - phone_number = identifier.get("phone", identifier["number"]) - - msisdn = phone_number_to_msisdn(identifier["country"], phone_number) - - return {"type": "m.id.thirdparty", "medium": "msisdn", "address": msisdn} - - class LoginRestServlet(RestServlet): PATTERNS = client_patterns("/login$", v1=True) CAS_TYPE = "m.login.cas" @@ -194,18 +153,11 @@ async def _do_other_login(self, login_submission: JsonDict) -> Dict[str, str]: login_submission.get("address"), login_submission.get("user"), ) - login_submission_legacy_convert(login_submission) - - if "identifier" not in login_submission: - raise SynapseError(400, "Missing param: identifier") - - identifier = login_submission["identifier"] - if "type" not in identifier: - raise SynapseError(400, "Login identifier has no type") + identifier = convert_client_dict_legacy_fields_to_identifier(login_submission) # convert phone type identifiers to generic threepids if identifier["type"] == "m.id.phone": - identifier = login_id_thirdparty_from_phone(identifier) + identifier = login_id_phone_to_thirdparty(identifier) # convert threepid identifiers to user IDs if identifier["type"] == "m.id.thirdparty": From d5e73cb6aa56fbd267ca957e64ad893a9ef28708 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 28 Aug 2020 07:28:53 -0400 Subject: [PATCH 18/39] Define StateMap as immutable and add a MutableStateMap type. (#8183) --- changelog.d/8183.misc | 1 + synapse/handlers/federation.py | 20 ++++++++++++++------ synapse/handlers/room.py | 3 ++- synapse/handlers/sync.py | 5 +++-- synapse/state/__init__.py | 32 ++++++++++++++++++++------------ synapse/state/v1.py | 10 +++++----- synapse/state/v2.py | 6 +++--- synapse/types.py | 7 ++++--- 8 files changed, 52 insertions(+), 32 deletions(-) create mode 100644 changelog.d/8183.misc diff --git a/changelog.d/8183.misc b/changelog.d/8183.misc new file mode 100644 index 000000000000..78d8834328a5 --- /dev/null +++ b/changelog.d/8183.misc @@ -0,0 +1 @@ +Add type hints to `synapse.state`. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index f8b234cee21a..155d0874137d 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -72,7 +72,13 @@ from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet from synapse.state import StateResolutionStore, resolve_events_with_store from synapse.storage.databases.main.events_worker import EventRedactBehaviour -from synapse.types import JsonDict, StateMap, UserID, get_domain_from_id +from synapse.types import ( + JsonDict, + MutableStateMap, + StateMap, + UserID, + get_domain_from_id, +) from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.distributor import user_joined_room from synapse.util.retryutils import NotRetryingDestination @@ -96,7 +102,7 @@ class _NewEventInfo: event = attr.ib(type=EventBase) state = attr.ib(type=Optional[Sequence[EventBase]], default=None) - auth_events = attr.ib(type=Optional[StateMap[EventBase]], default=None) + auth_events = attr.ib(type=Optional[MutableStateMap[EventBase]], default=None) class FederationHandler(BaseHandler): @@ -2053,7 +2059,7 @@ async def _prep_event( origin: str, event: EventBase, state: Optional[Iterable[EventBase]], - auth_events: Optional[StateMap[EventBase]], + auth_events: Optional[MutableStateMap[EventBase]], backfilled: bool, ) -> EventContext: context = await self.state_handler.compute_event_context(event, old_state=state) @@ -2137,7 +2143,9 @@ async def _check_for_soft_fail( current_states = await self.state_handler.resolve_events( room_version, state_sets, event ) - current_state_ids = {k: e.event_id for k, e in current_states.items()} + current_state_ids = { + k: e.event_id for k, e in current_states.items() + } # type: StateMap[str] else: current_state_ids = await self.state_handler.get_current_state_ids( event.room_id, latest_event_ids=extrem_ids @@ -2223,7 +2231,7 @@ async def do_auth( origin: str, event: EventBase, context: EventContext, - auth_events: StateMap[EventBase], + auth_events: MutableStateMap[EventBase], ) -> EventContext: """ @@ -2274,7 +2282,7 @@ async def _update_auth_events_and_context_for_auth( origin: str, event: EventBase, context: EventContext, - auth_events: StateMap[EventBase], + auth_events: MutableStateMap[EventBase], ) -> EventContext: """Helper for do_auth. See there for docs. diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 236a37f777c2..1419d72e9429 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -41,6 +41,7 @@ from synapse.storage.state import StateFilter from synapse.types import ( JsonDict, + MutableStateMap, Requester, RoomAlias, RoomID, @@ -814,7 +815,7 @@ async def _send_events_for_new_room( room_id: str, preset_config: str, invite_list: List[str], - initial_state: StateMap, + initial_state: MutableStateMap, creation_content: JsonDict, room_alias: Optional[RoomAlias] = None, power_level_content_override: Optional[JsonDict] = None, diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index c42dac18f5f3..9a86eb01c975 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -31,6 +31,7 @@ from synapse.types import ( Collection, JsonDict, + MutableStateMap, RoomStreamToken, StateMap, StreamToken, @@ -588,7 +589,7 @@ async def compute_summary( room_id: str, sync_config: SyncConfig, batch: TimelineBatch, - state: StateMap[EventBase], + state: MutableStateMap[EventBase], now_token: StreamToken, ) -> Optional[JsonDict]: """ Works out a room summary block for this room, summarising the number @@ -736,7 +737,7 @@ async def compute_state_delta( since_token: Optional[StreamToken], now_token: StreamToken, full_state: bool, - ) -> StateMap[EventBase]: + ) -> MutableStateMap[EventBase]: """ Works out the difference in state between the start of the timeline and the previous sync. diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index a601303fa34e..9bf2ec368f0d 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -25,6 +25,7 @@ Sequence, Set, Union, + cast, overload, ) @@ -41,7 +42,7 @@ from synapse.state import v1, v2 from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.roommember import ProfileInfo -from synapse.types import Collection, StateMap +from synapse.types import Collection, MutableStateMap, StateMap from synapse.util import Clock from synapse.util.async_helpers import Linearizer from synapse.util.caches.expiringcache import ExpiringCache @@ -205,7 +206,7 @@ async def get_current_state_ids( logger.debug("calling resolve_state_groups from get_current_state_ids") ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids) - return dict(ret.state) + return ret.state async def get_current_users_in_room( self, room_id: str, latest_event_ids: Optional[List[str]] = None @@ -302,7 +303,7 @@ async def compute_event_context( # if we're given the state before the event, then we use that state_ids_before_event = { (s.type, s.state_key): s.event_id for s in old_state - } + } # type: StateMap[str] state_group_before_event = None state_group_before_event_prev_group = None deltas_to_state_group_before_event = None @@ -315,7 +316,7 @@ async def compute_event_context( event.room_id, event.prev_event_ids() ) - state_ids_before_event = dict(entry.state) + state_ids_before_event = entry.state state_group_before_event = entry.state_group state_group_before_event_prev_group = entry.prev_group deltas_to_state_group_before_event = entry.delta_ids @@ -540,7 +541,7 @@ async def resolve_state_groups( # # XXX: is this actually worthwhile, or should we just let # resolve_events_with_store do it? - new_state = {} + new_state = {} # type: MutableStateMap[str] conflicted_state = False for st in state_groups_ids.values(): for key, e_id in st.items(): @@ -554,13 +555,20 @@ async def resolve_state_groups( if conflicted_state: logger.info("Resolving conflicted state for %r", room_id) with Measure(self.clock, "state._resolve_events"): - new_state = await resolve_events_with_store( - self.clock, - room_id, - room_version, - list(state_groups_ids.values()), - event_map=event_map, - state_res_store=state_res_store, + # resolve_events_with_store returns a StateMap, but we can + # treat it as a MutableStateMap as it is above. It isn't + # actually mutated anymore (and is frozen in + # _make_state_cache_entry below). + new_state = cast( + MutableStateMap, + await resolve_events_with_store( + self.clock, + room_id, + room_version, + list(state_groups_ids.values()), + event_map=event_map, + state_res_store=state_res_store, + ), ) # if the new state matches any of the input state groups, we can diff --git a/synapse/state/v1.py b/synapse/state/v1.py index 0eb7fdd9e5d3..a493279cbd2e 100644 --- a/synapse/state/v1.py +++ b/synapse/state/v1.py @@ -32,7 +32,7 @@ from synapse.api.errors import AuthError from synapse.api.room_versions import RoomVersions from synapse.events import EventBase -from synapse.types import StateMap +from synapse.types import MutableStateMap, StateMap logger = logging.getLogger(__name__) @@ -131,7 +131,7 @@ async def resolve_events_with_store( def _seperate( state_sets: Iterable[StateMap[str]], -) -> Tuple[StateMap[str], StateMap[Set[str]]]: +) -> Tuple[MutableStateMap[str], MutableStateMap[Set[str]]]: """Takes the state_sets and figures out which keys are conflicted and which aren't. i.e., which have multiple different event_ids associated with them in different state sets. @@ -152,7 +152,7 @@ def _seperate( """ state_set_iterator = iter(state_sets) unconflicted_state = dict(next(state_set_iterator)) - conflicted_state = {} # type: StateMap[Set[str]] + conflicted_state = {} # type: MutableStateMap[Set[str]] for state_set in state_set_iterator: for key, value in state_set.items(): @@ -208,7 +208,7 @@ def _create_auth_events_from_maps( def _resolve_with_state( - unconflicted_state_ids: StateMap[str], + unconflicted_state_ids: MutableStateMap[str], conflicted_state_ids: StateMap[Set[str]], auth_event_ids: StateMap[str], state_map: Dict[str, EventBase], @@ -241,7 +241,7 @@ def _resolve_with_state( def _resolve_state_events( - conflicted_state: StateMap[List[EventBase]], auth_events: StateMap[EventBase] + conflicted_state: StateMap[List[EventBase]], auth_events: MutableStateMap[EventBase] ) -> StateMap[EventBase]: """ This is where we actually decide which of the conflicted state to use. diff --git a/synapse/state/v2.py b/synapse/state/v2.py index 0e9ffbd6e623..edf94e7ad683 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py @@ -38,7 +38,7 @@ from synapse.api.errors import AuthError from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import EventBase -from synapse.types import StateMap +from synapse.types import MutableStateMap, StateMap from synapse.util import Clock logger = logging.getLogger(__name__) @@ -414,7 +414,7 @@ async def _iterative_auth_checks( base_state: StateMap[str], event_map: Dict[str, EventBase], state_res_store: "synapse.state.StateResolutionStore", -) -> StateMap[str]: +) -> MutableStateMap[str]: """Sequentially apply auth checks to each event in given list, updating the state as it goes along. @@ -430,7 +430,7 @@ async def _iterative_auth_checks( Returns: Returns the final updated state """ - resolved_state = base_state.copy() + resolved_state = dict(base_state) room_version_obj = KNOWN_ROOM_VERSIONS[room_version] for idx, event_id in enumerate(event_ids, start=1): diff --git a/synapse/types.py b/synapse/types.py index bc36cdde308c..f8b9b0385007 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -18,7 +18,7 @@ import string import sys from collections import namedtuple -from typing import Any, Dict, Tuple, Type, TypeVar +from typing import Any, Dict, Mapping, MutableMapping, Tuple, Type, TypeVar import attr from signedjson.key import decode_verify_key_bytes @@ -41,8 +41,9 @@ class Collection(Iterable[T_co], Container[T_co], Sized): # type: ignore # Define a state map type from type/state_key to T (usually an event ID or # event) T = TypeVar("T") -StateMap = Dict[Tuple[str, str], T] - +StateKey = Tuple[str, str] +StateMap = Mapping[StateKey, T] +MutableStateMap = MutableMapping[StateKey, T] # the type of a JSON-serialisable dict. This could be made stronger, but it will # do for now. From 5c03134d0f8dd157ea1800ce1a4bcddbdb73ddf1 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 28 Aug 2020 07:54:27 -0400 Subject: [PATCH 19/39] Convert additional database code to async/await. (#8195) --- changelog.d/8195.misc | 1 + synapse/appservice/__init__.py | 19 +- synapse/federation/persistence.py | 19 +- synapse/handlers/federation.py | 4 +- synapse/storage/databases/main/appservice.py | 15 +- synapse/storage/databases/main/deviceinbox.py | 12 +- .../storage/databases/main/e2e_room_keys.py | 30 +-- .../databases/main/event_federation.py | 71 +++---- .../storage/databases/main/group_server.py | 187 +++++++++++------- synapse/storage/databases/main/keys.py | 24 +-- .../storage/databases/main/transactions.py | 39 ++-- 11 files changed, 246 insertions(+), 175 deletions(-) create mode 100644 changelog.d/8195.misc diff --git a/changelog.d/8195.misc b/changelog.d/8195.misc new file mode 100644 index 000000000000..dfe4c03171d6 --- /dev/null +++ b/changelog.d/8195.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index 1ffdc1ed9591..69a7182ef4a2 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -14,11 +14,16 @@ # limitations under the License. import logging import re +from typing import TYPE_CHECKING from synapse.api.constants import EventTypes +from synapse.appservice.api import ApplicationServiceApi from synapse.types import GroupID, get_domain_from_id from synapse.util.caches.descriptors import cached +if TYPE_CHECKING: + from synapse.storage.databases.main import DataStore + logger = logging.getLogger(__name__) @@ -35,19 +40,19 @@ def __init__(self, service, id, events): self.id = id self.events = events - def send(self, as_api): + async def send(self, as_api: ApplicationServiceApi) -> bool: """Sends this transaction using the provided AS API interface. Args: - as_api(ApplicationServiceApi): The API to use to send. + as_api: The API to use to send. Returns: - An Awaitable which resolves to True if the transaction was sent. + True if the transaction was sent. """ - return as_api.push_bulk( + return await as_api.push_bulk( service=self.service, events=self.events, txn_id=self.id ) - def complete(self, store): + async def complete(self, store: "DataStore") -> None: """Completes this transaction as successful. Marks this transaction ID on the application service and removes the @@ -55,10 +60,8 @@ def complete(self, store): Args: store: The database store to operate on. - Returns: - A Deferred which resolves to True if the transaction was completed. """ - return store.complete_appservice_txn(service=self.service, txn_id=self.id) + await store.complete_appservice_txn(service=self.service, txn_id=self.id) class ApplicationService(object): diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py index 769cd5de28ab..de1fe7da3865 100644 --- a/synapse/federation/persistence.py +++ b/synapse/federation/persistence.py @@ -20,6 +20,7 @@ """ import logging +from typing import Optional, Tuple from synapse.federation.units import Transaction from synapse.logging.utils import log_function @@ -36,25 +37,27 @@ def __init__(self, datastore): self.store = datastore @log_function - def have_responded(self, origin, transaction): - """ Have we already responded to a transaction with the same id and + async def have_responded( + self, origin: str, transaction: Transaction + ) -> Optional[Tuple[int, JsonDict]]: + """Have we already responded to a transaction with the same id and origin? Returns: - Deferred: Results in `None` if we have not previously responded to - this transaction or a 2-tuple of `(int, dict)` representing the - response code and response body. + `None` if we have not previously responded to this transaction or a + 2-tuple of `(int, dict)` representing the response code and response body. """ - if not transaction.transaction_id: + transaction_id = transaction.transaction_id # type: ignore + if not transaction_id: raise RuntimeError("Cannot persist a transaction with no transaction_id") - return self.store.get_received_txn_response(transaction.transaction_id, origin) + return await self.store.get_received_txn_response(transaction_id, origin) @log_function async def set_response( self, origin: str, transaction: Transaction, code: int, response: JsonDict ) -> None: - """ Persist how we responded to a transaction. + """Persist how we responded to a transaction. """ transaction_id = transaction.transaction_id # type: ignore if not transaction_id: diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 155d0874137d..16389a0dca4c 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1879,8 +1879,8 @@ async def get_persisted_pdu( else: return None - def get_min_depth_for_context(self, context): - return self.store.get_min_depth(context) + async def get_min_depth_for_context(self, context): + return await self.store.get_min_depth(context) async def _handle_new_event( self, origin, event, state=None, auth_events=None, backfilled=False diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index 92f56f1602a2..454c0bc50cb7 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -172,7 +172,7 @@ async def set_appservice_state(self, service, state) -> None: "application_services_state", {"as_id": service.id}, {"state": state} ) - def create_appservice_txn(self, service, events): + async def create_appservice_txn(self, service, events): """Atomically creates a new transaction for this application service with the given list of events. @@ -209,20 +209,17 @@ def _create_appservice_txn(txn): ) return AppServiceTransaction(service=service, id=new_txn_id, events=events) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "create_appservice_txn", _create_appservice_txn ) - def complete_appservice_txn(self, txn_id, service): + async def complete_appservice_txn(self, txn_id, service) -> None: """Completes an application service transaction. Args: txn_id(str): The transaction ID being completed. service(ApplicationService): The application service which was sent this transaction. - Returns: - A Deferred which resolves if this transaction was stored - successfully. """ txn_id = int(txn_id) @@ -258,7 +255,7 @@ def _complete_appservice_txn(txn): {"txn_id": txn_id, "as_id": service.id}, ) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "complete_appservice_txn", _complete_appservice_txn ) @@ -312,13 +309,13 @@ def _get_last_txn(self, txn, service_id): else: return int(last_txn_id[0]) # select 'last_txn' col - def set_appservice_last_pos(self, pos): + async def set_appservice_last_pos(self, pos) -> None: def set_appservice_last_pos_txn(txn): txn.execute( "UPDATE appservice_stream_position SET stream_ordering = ?", (pos,) ) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "set_appservice_last_pos", set_appservice_last_pos_txn ) diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index bb85637a95e3..00444331102e 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -190,15 +190,15 @@ def get_new_messages_for_remote_destination_txn(txn): ) @trace - def delete_device_msgs_for_remote(self, destination, up_to_stream_id): + async def delete_device_msgs_for_remote( + self, destination: str, up_to_stream_id: int + ) -> None: """Used to delete messages when the remote destination acknowledges their receipt. Args: - destination(str): The destination server_name - up_to_stream_id(int): Where to delete messages up to. - Returns: - A deferred that resolves when the messages have been deleted. + destination: The destination server_name + up_to_stream_id: Where to delete messages up to. """ def delete_messages_for_remote_destination_txn(txn): @@ -209,7 +209,7 @@ def delete_messages_for_remote_destination_txn(txn): ) txn.execute(sql, (destination, up_to_stream_id)) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn ) diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py index 82f9d870fd06..12cecceec2a4 100644 --- a/synapse/storage/databases/main/e2e_room_keys.py +++ b/synapse/storage/databases/main/e2e_room_keys.py @@ -151,7 +151,7 @@ async def get_e2e_room_keys(self, user_id, version, room_id=None, session_id=Non return sessions - def get_e2e_room_keys_multi(self, user_id, version, room_keys): + async def get_e2e_room_keys_multi(self, user_id, version, room_keys): """Get multiple room keys at a time. The difference between this function and get_e2e_room_keys is that this function can be used to retrieve multiple specific keys at a time, whereas get_e2e_room_keys is used for @@ -166,10 +166,10 @@ def get_e2e_room_keys_multi(self, user_id, version, room_keys): that we want to query Returns: - Deferred[dict[str, dict[str, dict]]]: a map of room IDs to session IDs to room key + dict[str, dict[str, dict]]: a map of room IDs to session IDs to room key """ - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_e2e_room_keys_multi", self._get_e2e_room_keys_multi_txn, user_id, @@ -283,7 +283,7 @@ def _get_current_version(txn, user_id): raise StoreError(404, "No current backup version") return row[0] - def get_e2e_room_keys_version_info(self, user_id, version=None): + async def get_e2e_room_keys_version_info(self, user_id, version=None): """Get info metadata about a version of our room_keys backup. Args: @@ -293,7 +293,7 @@ def get_e2e_room_keys_version_info(self, user_id, version=None): Raises: StoreError: with code 404 if there are no e2e_room_keys_versions present Returns: - A deferred dict giving the info metadata for this backup version, with + A dict giving the info metadata for this backup version, with fields including: version(str) algorithm(str) @@ -324,12 +324,12 @@ def _get_e2e_room_keys_version_info_txn(txn): result["etag"] = 0 return result - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_e2e_room_keys_version_info", _get_e2e_room_keys_version_info_txn ) @trace - def create_e2e_room_keys_version(self, user_id, info): + async def create_e2e_room_keys_version(self, user_id: str, info: dict) -> str: """Atomically creates a new version of this user's e2e_room_keys store with the given version info. @@ -338,7 +338,7 @@ def create_e2e_room_keys_version(self, user_id, info): info(dict): the info about the backup version to be created Returns: - A deferred string for the newly created version ID + The newly created version ID """ def _create_e2e_room_keys_version_txn(txn): @@ -365,7 +365,7 @@ def _create_e2e_room_keys_version_txn(txn): return new_version - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "create_e2e_room_keys_version_txn", _create_e2e_room_keys_version_txn ) @@ -403,13 +403,15 @@ async def update_e2e_room_keys_version( ) @trace - def delete_e2e_room_keys_version(self, user_id, version=None): + async def delete_e2e_room_keys_version( + self, user_id: str, version: Optional[str] = None + ) -> None: """Delete a given backup version of the user's room keys. Doesn't delete their actual key data. Args: - user_id(str): the user whose backup version we're deleting - version(str): Optional. the version ID of the backup version we're deleting + user_id: the user whose backup version we're deleting + version: Optional. the version ID of the backup version we're deleting If missing, we delete the current backup version info. Raises: StoreError: with code 404 if there are no e2e_room_keys_versions present, @@ -430,13 +432,13 @@ def _delete_e2e_room_keys_version_txn(txn): keyvalues={"user_id": user_id, "version": this_version}, ) - return self.db_pool.simple_update_one_txn( + self.db_pool.simple_update_one_txn( txn, table="e2e_room_keys_versions", keyvalues={"user_id": user_id, "version": this_version}, updatevalues={"deleted": 1}, ) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "delete_e2e_room_keys_version", _delete_e2e_room_keys_version_txn ) diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 6e5761c7b75a..0b69aa6a940a 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -59,7 +59,7 @@ async def get_auth_chain_ids( include_given: include the given events in result Returns: - list of event_ids + An awaitable which resolve to a list of event_ids """ return await self.db_pool.runInteraction( "get_auth_chain_ids", @@ -95,7 +95,7 @@ def _get_auth_chain_ids_txn( return list(results) - def get_auth_chain_difference(self, state_sets: List[Set[str]]): + async def get_auth_chain_difference(self, state_sets: List[Set[str]]) -> Set[str]: """Given sets of state events figure out the auth chain difference (as per state res v2 algorithm). @@ -104,10 +104,10 @@ def get_auth_chain_difference(self, state_sets: List[Set[str]]): chain. Returns: - Deferred[Set[str]] + The set of the difference in auth chains. """ - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_auth_chain_difference", self._get_auth_chain_difference_txn, state_sets, @@ -252,8 +252,8 @@ def _get_auth_chain_difference_txn( # Return all events where not all sets can reach them. return {eid for eid, n in event_to_missing_sets.items() if n} - def get_oldest_events_with_depth_in_room(self, room_id): - return self.db_pool.runInteraction( + async def get_oldest_events_with_depth_in_room(self, room_id): + return await self.db_pool.runInteraction( "get_oldest_events_with_depth_in_room", self.get_oldest_events_with_depth_in_room_txn, room_id, @@ -293,7 +293,7 @@ async def get_max_depth_of(self, event_ids: List[str]) -> int: else: return max(row["depth"] for row in rows) - def get_prev_events_for_room(self, room_id: str): + async def get_prev_events_for_room(self, room_id: str) -> List[str]: """ Gets a subset of the current forward extremities in the given room. @@ -301,14 +301,14 @@ def get_prev_events_for_room(self, room_id: str): events which refer to hundreds of prev_events. Args: - room_id (str): room_id + room_id: room_id Returns: - Deferred[List[str]]: the event ids of the forward extremites + The event ids of the forward extremities. """ - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_prev_events_for_room", self._get_prev_events_for_room_txn, room_id ) @@ -328,17 +328,19 @@ def _get_prev_events_for_room_txn(self, txn, room_id: str): return [row[0] for row in txn] - def get_rooms_with_many_extremities(self, min_count, limit, room_id_filter): + async def get_rooms_with_many_extremities( + self, min_count: int, limit: int, room_id_filter: Iterable[str] + ) -> List[str]: """Get the top rooms with at least N extremities. Args: - min_count (int): The minimum number of extremities - limit (int): The maximum number of rooms to return. - room_id_filter (iterable[str]): room_ids to exclude from the results + min_count: The minimum number of extremities + limit: The maximum number of rooms to return. + room_id_filter: room_ids to exclude from the results Returns: - Deferred[list]: At most `limit` room IDs that have at least - `min_count` extremities, sorted by extremity count. + At most `limit` room IDs that have at least `min_count` extremities, + sorted by extremity count. """ def _get_rooms_with_many_extremities_txn(txn): @@ -363,7 +365,7 @@ def _get_rooms_with_many_extremities_txn(txn): txn.execute(sql, query_args) return [room_id for room_id, in txn] - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_rooms_with_many_extremities", _get_rooms_with_many_extremities_txn ) @@ -376,10 +378,10 @@ async def get_latest_event_ids_in_room(self, room_id: str) -> List[str]: desc="get_latest_event_ids_in_room", ) - def get_min_depth(self, room_id): - """ For hte given room, get the minimum depth we have seen for it. + async def get_min_depth(self, room_id: str) -> int: + """For the given room, get the minimum depth we have seen for it. """ - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_min_depth", self._get_min_depth_interaction, room_id ) @@ -394,7 +396,9 @@ def _get_min_depth_interaction(self, txn, room_id): return int(min_depth) if min_depth is not None else None - def get_forward_extremeties_for_room(self, room_id, stream_ordering): + async def get_forward_extremeties_for_room( + self, room_id: str, stream_ordering: int + ) -> List[str]: """For a given room_id and stream_ordering, return the forward extremeties of the room at that point in "time". @@ -402,11 +406,11 @@ def get_forward_extremeties_for_room(self, room_id, stream_ordering): stream_orderings from that point. Args: - room_id (str): - stream_ordering (int): + room_id: + stream_ordering: Returns: - deferred, which resolves to a list of event_ids + A list of event_ids """ # We want to make the cache more effective, so we clamp to the last # change before the given ordering. @@ -422,10 +426,10 @@ def get_forward_extremeties_for_room(self, room_id, stream_ordering): if last_change > self.stream_ordering_month_ago: stream_ordering = min(last_change, stream_ordering) - return self._get_forward_extremeties_for_room(room_id, stream_ordering) + return await self._get_forward_extremeties_for_room(room_id, stream_ordering) @cached(max_entries=5000, num_args=2) - def _get_forward_extremeties_for_room(self, room_id, stream_ordering): + async def _get_forward_extremeties_for_room(self, room_id, stream_ordering): """For a given room_id and stream_ordering, return the forward extremeties of the room at that point in "time". @@ -450,19 +454,18 @@ def get_forward_extremeties_for_room_txn(txn): txn.execute(sql, (stream_ordering, room_id)) return [event_id for event_id, in txn] - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn ) - async def get_backfill_events(self, room_id, event_list, limit): + async def get_backfill_events(self, room_id: str, event_list: list, limit: int): """Get a list of Events for a given topic that occurred before (and including) the events in event_list. Return a list of max size `limit` Args: - txn - room_id (str) - event_list (list) - limit (int) + room_id + event_list + limit """ event_ids = await self.db_pool.runInteraction( "get_backfill_events", @@ -631,8 +634,8 @@ def _delete_old_forward_extrem_cache_txn(txn): _delete_old_forward_extrem_cache_txn, ) - def clean_room_for_join(self, room_id): - return self.db_pool.runInteraction( + async def clean_room_for_join(self, room_id): + return await self.db_pool.runInteraction( "clean_room_for_join", self._clean_room_for_join_txn, room_id ) diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py index 6c6017188845..ccfbb2135eba 100644 --- a/synapse/storage/databases/main/group_server.py +++ b/synapse/storage/databases/main/group_server.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union from synapse.api.errors import SynapseError from synapse.storage._base import SQLBaseStore, db_to_json @@ -70,7 +70,9 @@ async def get_invited_users_in_group(self, group_id: str) -> List[str]: desc="get_invited_users_in_group", ) - def get_rooms_in_group(self, group_id: str, include_private: bool = False): + async def get_rooms_in_group( + self, group_id: str, include_private: bool = False + ) -> List[Dict[str, Union[str, bool]]]: """Retrieve the rooms that belong to a given group. Does not return rooms that lack members. @@ -79,8 +81,7 @@ def get_rooms_in_group(self, group_id: str, include_private: bool = False): include_private: Whether to return private rooms in results Returns: - Deferred[List[Dict[str, str|bool]]]: A list of dictionaries, each in the - form of: + A list of dictionaries, each in the form of: { "room_id": "!a_room_id:example.com", # The ID of the room @@ -117,13 +118,13 @@ def _get_rooms_in_group_txn(txn): for room_id, is_public in txn ] - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_rooms_in_group", _get_rooms_in_group_txn ) - def get_rooms_for_summary_by_category( + async def get_rooms_for_summary_by_category( self, group_id: str, include_private: bool = False, - ): + ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]: """Get the rooms and categories that should be included in a summary request Args: @@ -131,7 +132,7 @@ def get_rooms_for_summary_by_category( include_private: Whether to return private rooms in results Returns: - Deferred[Tuple[List, Dict]]: A tuple containing: + A tuple containing: * A list of dictionaries with the keys: * "room_id": str, the room ID @@ -207,7 +208,7 @@ def _get_rooms_for_summary_txn(txn): return rooms, categories - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_rooms_for_summary", _get_rooms_for_summary_txn ) @@ -281,10 +282,11 @@ async def get_local_groups_for_room(self, room_id: str) -> List[str]: desc="get_local_groups_for_room", ) - def get_users_for_summary_by_role(self, group_id, include_private=False): + async def get_users_for_summary_by_role(self, group_id, include_private=False): """Get the users and roles that should be included in a summary request - Returns ([users], [roles]) + Returns: + ([users], [roles]) """ def _get_users_for_summary_txn(txn): @@ -338,7 +340,7 @@ def _get_users_for_summary_txn(txn): return users, roles - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_users_for_summary_by_role", _get_users_for_summary_txn ) @@ -376,7 +378,7 @@ async def is_user_invited_to_local_group( allow_none=True, ) - def get_users_membership_info_in_group(self, group_id, user_id): + async def get_users_membership_info_in_group(self, group_id, user_id): """Get a dict describing the membership of a user in a group. Example if joined: @@ -387,7 +389,8 @@ def get_users_membership_info_in_group(self, group_id, user_id): "is_privileged": False, } - Returns an empty dict if the user is not join/invite/etc + Returns: + An empty dict if the user is not join/invite/etc """ def _get_users_membership_in_group_txn(txn): @@ -419,7 +422,7 @@ def _get_users_membership_in_group_txn(txn): return {} - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_users_membership_info_in_group", _get_users_membership_in_group_txn ) @@ -433,7 +436,7 @@ async def get_publicised_groups_for_user(self, user_id: str) -> List[str]: desc="get_publicised_groups_for_user", ) - def get_attestations_need_renewals(self, valid_until_ms): + async def get_attestations_need_renewals(self, valid_until_ms): """Get all attestations that need to be renewed until givent time """ @@ -445,7 +448,7 @@ def _get_attestations_need_renewals_txn(txn): txn.execute(sql, (valid_until_ms,)) return self.db_pool.cursor_to_dict(txn) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_attestations_need_renewals", _get_attestations_need_renewals_txn ) @@ -475,7 +478,7 @@ async def get_joined_groups(self, user_id: str) -> List[str]: desc="get_joined_groups", ) - def get_all_groups_for_user(self, user_id, now_token): + async def get_all_groups_for_user(self, user_id, now_token): def _get_all_groups_for_user_txn(txn): sql = """ SELECT group_id, type, membership, u.content @@ -495,7 +498,7 @@ def _get_all_groups_for_user_txn(txn): for row in txn ] - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_all_groups_for_user", _get_all_groups_for_user_txn ) @@ -600,8 +603,27 @@ async def set_group_join_policy(self, group_id: str, join_policy: str) -> None: desc="set_group_join_policy", ) - def add_room_to_summary(self, group_id, room_id, category_id, order, is_public): - return self.db_pool.runInteraction( + async def add_room_to_summary( + self, + group_id: str, + room_id: str, + category_id: str, + order: int, + is_public: Optional[bool], + ) -> None: + """Add (or update) room's entry in summary. + + Args: + group_id + room_id + category_id: If not None then adds the category to the end of + the summary if its not already there. + order: If not None inserts the room at that position, e.g. an order + of 1 will put the room first. Otherwise, the room gets added to + the end. + is_public + """ + await self.db_pool.runInteraction( "add_room_to_summary", self._add_room_to_summary_txn, group_id, @@ -612,18 +634,26 @@ def add_room_to_summary(self, group_id, room_id, category_id, order, is_public): ) def _add_room_to_summary_txn( - self, txn, group_id, room_id, category_id, order, is_public - ): + self, + txn, + group_id: str, + room_id: str, + category_id: str, + order: int, + is_public: Optional[bool], + ) -> None: """Add (or update) room's entry in summary. Args: - group_id (str) - room_id (str) - category_id (str): If not None then adds the category to the end of - the summary if its not already there. [Optional] - order (int): If not None inserts the room at that position, e.g. - an order of 1 will put the room first. Otherwise, the room gets - added to the end. + txn + group_id + room_id + category_id: If not None then adds the category to the end of + the summary if its not already there. + order: If not None inserts the room at that position, e.g. an order + of 1 will put the room first. Otherwise, the room gets added to + the end. + is_public """ room_in_group = self.db_pool.simple_select_one_onecol_txn( txn, @@ -818,8 +848,27 @@ async def remove_group_role(self, group_id: str, role_id: str) -> int: desc="remove_group_role", ) - def add_user_to_summary(self, group_id, user_id, role_id, order, is_public): - return self.db_pool.runInteraction( + async def add_user_to_summary( + self, + group_id: str, + user_id: str, + role_id: str, + order: int, + is_public: Optional[bool], + ) -> None: + """Add (or update) user's entry in summary. + + Args: + group_id + user_id + role_id: If not None then adds the role to the end of the summary if + its not already there. + order: If not None inserts the user at that position, e.g. an order + of 1 will put the user first. Otherwise, the user gets added to + the end. + is_public + """ + await self.db_pool.runInteraction( "add_user_to_summary", self._add_user_to_summary_txn, group_id, @@ -830,18 +879,26 @@ def add_user_to_summary(self, group_id, user_id, role_id, order, is_public): ) def _add_user_to_summary_txn( - self, txn, group_id, user_id, role_id, order, is_public + self, + txn, + group_id: str, + user_id: str, + role_id: str, + order: int, + is_public: Optional[bool], ): """Add (or update) user's entry in summary. Args: - group_id (str) - user_id (str) - role_id (str): If not None then adds the role to the end of - the summary if its not already there. [Optional] - order (int): If not None inserts the user at that position, e.g. - an order of 1 will put the user first. Otherwise, the user gets - added to the end. + txn + group_id + user_id + role_id: If not None then adds the role to the end of the summary if + its not already there. + order: If not None inserts the user at that position, e.g. an order + of 1 will put the user first. Otherwise, the user gets added to + the end. + is_public """ user_in_group = self.db_pool.simple_select_one_onecol_txn( txn, @@ -963,27 +1020,26 @@ async def add_group_invite(self, group_id: str, user_id: str) -> None: desc="add_group_invite", ) - def add_user_to_group( + async def add_user_to_group( self, - group_id, - user_id, - is_admin=False, - is_public=True, - local_attestation=None, - remote_attestation=None, - ): + group_id: str, + user_id: str, + is_admin: bool = False, + is_public: bool = True, + local_attestation: dict = None, + remote_attestation: dict = None, + ) -> None: """Add a user to the group server. Args: - group_id (str) - user_id (str) - is_admin (bool) - is_public (bool) - local_attestation (dict): The attestation the GS created to give - to the remote server. Optional if the user and group are on the - same server - remote_attestation (dict): The attestation given to GS by remote + group_id + user_id + is_admin + is_public + local_attestation: The attestation the GS created to give to the remote server. Optional if the user and group are on the same server + remote_attestation: The attestation given to GS by remote server. + Optional if the user and group are on the same server """ def _add_user_to_group_txn(txn): @@ -1026,9 +1082,9 @@ def _add_user_to_group_txn(txn): }, ) - return self.db_pool.runInteraction("add_user_to_group", _add_user_to_group_txn) + await self.db_pool.runInteraction("add_user_to_group", _add_user_to_group_txn) - def remove_user_from_group(self, group_id, user_id): + async def remove_user_from_group(self, group_id: str, user_id: str) -> None: def _remove_user_from_group_txn(txn): self.db_pool.simple_delete_txn( txn, @@ -1056,7 +1112,7 @@ def _remove_user_from_group_txn(txn): keyvalues={"group_id": group_id, "user_id": user_id}, ) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "remove_user_from_group", _remove_user_from_group_txn ) @@ -1079,7 +1135,7 @@ async def update_room_in_group_visibility( desc="update_room_in_group_visibility", ) - def remove_room_from_group(self, group_id, room_id): + async def remove_room_from_group(self, group_id: str, room_id: str) -> None: def _remove_room_from_group_txn(txn): self.db_pool.simple_delete_txn( txn, @@ -1093,7 +1149,7 @@ def _remove_room_from_group_txn(txn): keyvalues={"group_id": group_id, "room_id": room_id}, ) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "remove_room_from_group", _remove_room_from_group_txn ) @@ -1286,14 +1342,11 @@ async def remove_attestation_renewal(self, group_id: str, user_id: str) -> int: def get_group_stream_token(self): return self._group_updates_id_gen.get_current_token() - def delete_group(self, group_id): + async def delete_group(self, group_id: str) -> None: """Deletes a group fully from the database. Args: - group_id (str) - - Returns: - Deferred + group_id: The group ID to delete. """ def _delete_group_txn(txn): @@ -1317,4 +1370,4 @@ def _delete_group_txn(txn): txn, table=table, keyvalues={"group_id": group_id} ) - return self.db_pool.runInteraction("delete_group", _delete_group_txn) + await self.db_pool.runInteraction("delete_group", _delete_group_txn) diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py index 1c0a049c5548..ad43bb05abb5 100644 --- a/synapse/storage/databases/main/keys.py +++ b/synapse/storage/databases/main/keys.py @@ -16,7 +16,7 @@ import itertools import logging -from typing import Iterable, Tuple +from typing import Dict, Iterable, List, Optional, Tuple from signedjson.key import decode_verify_key_bytes @@ -42,16 +42,17 @@ def _get_server_verify_key(self, server_name_and_key_id): @cachedList( cached_method_name="_get_server_verify_key", list_name="server_name_and_key_ids" ) - def get_server_verify_keys(self, server_name_and_key_ids): + async def get_server_verify_keys( + self, server_name_and_key_ids: Iterable[Tuple[str, str]] + ) -> Dict[Tuple[str, str], Optional[FetchKeyResult]]: """ Args: - server_name_and_key_ids (iterable[Tuple[str, str]]): + server_name_and_key_ids: iterable of (server_name, key-id) tuples to fetch keys for Returns: - Deferred: resolves to dict[Tuple[str, str], FetchKeyResult|None]: - map from (server_name, key_id) -> FetchKeyResult, or None if the key is - unknown + A map from (server_name, key_id) -> FetchKeyResult, or None if the + key is unknown """ keys = {} @@ -87,7 +88,7 @@ def _txn(txn): _get_keys(txn, batch) return keys - return self.db_pool.runInteraction("get_server_verify_keys", _txn) + return await self.db_pool.runInteraction("get_server_verify_keys", _txn) async def store_server_verify_keys( self, @@ -179,7 +180,9 @@ async def store_server_keys_json( desc="store_server_keys_json", ) - def get_server_keys_json(self, server_keys): + async def get_server_keys_json( + self, server_keys: Iterable[Tuple[str, Optional[str], Optional[str]]] + ) -> Dict[Tuple[str, Optional[str], Optional[str]], List[dict]]: """Retrive the key json for a list of server_keys and key ids. If no keys are found for a given server, key_id and source then that server, key_id, and source triplet entry will be an empty list. @@ -188,8 +191,7 @@ def get_server_keys_json(self, server_keys): Args: server_keys (list): List of (server_name, key_id, source) triplets. Returns: - Deferred[dict[Tuple[str, str, str|None], list[dict]]]: - Dict mapping (server_name, key_id, source) triplets to lists of dicts + A mapping from (server_name, key_id, source) triplets to a list of dicts """ def _get_server_keys_json_txn(txn): @@ -215,6 +217,6 @@ def _get_server_keys_json_txn(txn): results[(server_name, key_id, from_server)] = rows return results - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_server_keys_json", _get_server_keys_json_txn ) diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index 2efcc0dc66f4..5b31aab700f9 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -15,6 +15,7 @@ import logging from collections import namedtuple +from typing import Optional, Tuple from canonicaljson import encode_canonical_json @@ -56,21 +57,23 @@ def __init__(self, database: DatabasePool, db_conn, hs): expiry_ms=5 * 60 * 1000, ) - def get_received_txn_response(self, transaction_id, origin): + async def get_received_txn_response( + self, transaction_id: str, origin: str + ) -> Optional[Tuple[int, JsonDict]]: """For an incoming transaction from a given origin, check if we have already responded to it. If so, return the response code and response body (as a dict). Args: - transaction_id (str) - origin(str) + transaction_id + origin Returns: - tuple: None if we have not previously responded to - this transaction or a 2-tuple of (int, dict) + None if we have not previously responded to this transaction or a + 2-tuple of (int, dict) """ - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_received_txn_response", self._get_received_txn_response, transaction_id, @@ -166,21 +169,25 @@ def _get_destination_retry_timings(self, txn, destination): else: return None - def set_destination_retry_timings( - self, destination, failure_ts, retry_last_ts, retry_interval - ): + async def set_destination_retry_timings( + self, + destination: str, + failure_ts: Optional[int], + retry_last_ts: int, + retry_interval: int, + ) -> None: """Sets the current retry timings for a given destination. Both timings should be zero if retrying is no longer occuring. Args: - destination (str) - failure_ts (int|None) - when the server started failing (ms since epoch) - retry_last_ts (int) - time of last retry attempt in unix epoch ms - retry_interval (int) - how long until next retry in ms + destination + failure_ts: when the server started failing (ms since epoch) + retry_last_ts: time of last retry attempt in unix epoch ms + retry_interval: how long until next retry in ms """ self._destination_retry_cache.pop(destination, None) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "set_destination_retry_timings", self._set_destination_retry_timings, destination, @@ -256,13 +263,13 @@ def _start_cleanup_transactions(self): "cleanup_transactions", self._cleanup_transactions ) - def _cleanup_transactions(self): + async def _cleanup_transactions(self) -> None: now = self._clock.time_msec() month_ago = now - 30 * 24 * 60 * 60 * 1000 def _cleanup_transactions_txn(txn): txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,)) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "_cleanup_transactions", _cleanup_transactions_txn ) From b055dc93220217fe55f8f4d28945f86353c2f3a8 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 28 Aug 2020 08:56:36 -0400 Subject: [PATCH 20/39] Ensure that the OpenID Connect remote ID is a string. (#8190) --- changelog.d/8190.bugfix | 1 + synapse/handlers/oidc_handler.py | 3 +++ tests/handlers/test_oidc.py | 41 ++++++++++++++++++++++++++++++-- 3 files changed, 43 insertions(+), 2 deletions(-) create mode 100644 changelog.d/8190.bugfix diff --git a/changelog.d/8190.bugfix b/changelog.d/8190.bugfix new file mode 100644 index 000000000000..bf6717ab28d9 --- /dev/null +++ b/changelog.d/8190.bugfix @@ -0,0 +1 @@ +Fix logging in via OpenID Connect with a provider that uses integer user IDs. diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index c5bd2fea68ff..1b06f3173fa0 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -869,6 +869,9 @@ async def _map_userinfo_to_user( raise MappingException( "Failed to extract subject from OIDC response: %s" % (e,) ) + # Some OIDC providers use integer IDs, but Synapse expects external IDs + # to be strings. + remote_user_id = str(remote_user_id) logger.info( "Looking for existing mapping for user %s:%s", diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index f92f3b8c1527..89ec5fcb31bb 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -75,7 +75,17 @@ def deliverBody(self, protocol): COOKIE_NAME = b"oidc_session" COOKIE_PATH = "/_synapse/oidc" -MockedMappingProvider = Mock(OidcMappingProvider) + +class TestMappingProvider(OidcMappingProvider): + @staticmethod + def parse_config(config): + return + + def get_remote_user_id(self, userinfo): + return userinfo["sub"] + + async def map_user_attributes(self, userinfo, token): + return {"localpart": userinfo["username"], "display_name": None} def simple_async_mock(return_value=None, raises=None): @@ -123,7 +133,7 @@ def make_homeserver(self, reactor, clock): oidc_config["issuer"] = ISSUER oidc_config["scopes"] = SCOPES oidc_config["user_mapping_provider"] = { - "module": __name__ + ".MockedMappingProvider" + "module": __name__ + ".TestMappingProvider", } config["oidc_config"] = oidc_config @@ -580,3 +590,30 @@ def test_exchange_code(self): with self.assertRaises(OidcError) as exc: yield defer.ensureDeferred(self.handler._exchange_code(code)) self.assertEqual(exc.exception.error, "some_error") + + def test_map_userinfo_to_user(self): + """Ensure that mapping the userinfo returned from a provider to an MXID works properly.""" + userinfo = { + "sub": "test_user", + "username": "test_user", + } + # The token doesn't matter with the default user mapping provider. + token = {} + mxid = self.get_success( + self.handler._map_userinfo_to_user( + userinfo, token, "user-agent", "10.10.10.10" + ) + ) + self.assertEqual(mxid, "@test_user:test") + + # Some providers return an integer ID. + userinfo = { + "sub": 1234, + "username": "test_user_2", + } + mxid = self.get_success( + self.handler._map_userinfo_to_user( + userinfo, token, "user-agent", "10.10.10.10" + ) + ) + self.assertEqual(mxid, "@test_user_2:test") From aec7085179464cc0a05177108ad2962d09ca4f4a Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 28 Aug 2020 09:37:55 -0400 Subject: [PATCH 21/39] Convert state and stream stores and related code to async (#8194) --- changelog.d/8194.misc | 1 + synapse/handlers/room.py | 2 +- synapse/storage/databases/main/state.py | 19 +++++++------- .../storage/databases/main/state_deltas.py | 21 ++++++++------- synapse/storage/databases/main/stream.py | 11 +++++--- synapse/storage/databases/state/store.py | 26 +++++++++---------- synapse/storage/state.py | 16 ++++++------ 7 files changed, 51 insertions(+), 45 deletions(-) create mode 100644 changelog.d/8194.misc diff --git a/changelog.d/8194.misc b/changelog.d/8194.misc new file mode 100644 index 000000000000..dfe4c03171d6 --- /dev/null +++ b/changelog.d/8194.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 1419d72e9429..9d5b1828df00 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -451,7 +451,7 @@ async def clone_existing_room( old_room_member_state_events = await self.store.get_events( old_room_member_state_ids.values() ) - for k, old_event in old_room_member_state_events.items(): + for old_event in old_room_member_state_events.values(): # Only transfer ban events if ( "membership" in old_event.content diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index 458f169617e1..5c6168e30171 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -27,6 +27,7 @@ from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.roommember import RoomMemberWorkerStore from synapse.storage.state import StateFilter +from synapse.types import StateMap from synapse.util.caches import intern_string from synapse.util.caches.descriptors import cached, cachedList @@ -163,15 +164,15 @@ async def get_create_event_for_room(self, room_id: str) -> EventBase: return create_event @cached(max_entries=100000, iterable=True) - def get_current_state_ids(self, room_id): + async def get_current_state_ids(self, room_id: str) -> StateMap[str]: """Get the current state event ids for a room based on the current_state_events table. Args: - room_id (str) + room_id: The room to get the state IDs of. Returns: - deferred: dict of (type, state_key) -> event_id + The current state of the room. """ def _get_current_state_ids_txn(txn): @@ -184,14 +185,14 @@ def _get_current_state_ids_txn(txn): return {(intern_string(r[0]), intern_string(r[1])): r[2] for r in txn} - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_current_state_ids", _get_current_state_ids_txn ) # FIXME: how should this be cached? - def get_filtered_current_state_ids( + async def get_filtered_current_state_ids( self, room_id: str, state_filter: StateFilter = StateFilter.all() - ): + ) -> StateMap[str]: """Get the current state event of a given type for a room based on the current_state_events table. This may not be as up-to-date as the result of doing a fresh state resolution as per state_handler.get_current_state @@ -202,14 +203,14 @@ def get_filtered_current_state_ids( from the database. Returns: - defer.Deferred[StateMap[str]]: Map from type/state_key to event ID. + Map from type/state_key to event ID. """ where_clause, where_args = state_filter.make_sql_filter_clause() if not where_clause: # We delegate to the cached version - return self.get_current_state_ids(room_id) + return await self.get_current_state_ids(room_id) def _get_filtered_current_state_ids_txn(txn): results = {} @@ -231,7 +232,7 @@ def _get_filtered_current_state_ids_txn(txn): return results - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_filtered_current_state_ids", _get_filtered_current_state_ids_txn ) diff --git a/synapse/storage/databases/main/state_deltas.py b/synapse/storage/databases/main/state_deltas.py index 0d963c98ffa9..356623fc6e3d 100644 --- a/synapse/storage/databases/main/state_deltas.py +++ b/synapse/storage/databases/main/state_deltas.py @@ -14,8 +14,7 @@ # limitations under the License. import logging - -from twisted.internet import defer +from typing import Any, Dict, List, Tuple from synapse.storage._base import SQLBaseStore @@ -23,7 +22,9 @@ class StateDeltasStore(SQLBaseStore): - def get_current_state_deltas(self, prev_stream_id: int, max_stream_id: int): + async def get_current_state_deltas( + self, prev_stream_id: int, max_stream_id: int + ) -> Tuple[int, List[Dict[str, Any]]]: """Fetch a list of room state changes since the given stream id Each entry in the result contains the following fields: @@ -37,12 +38,12 @@ def get_current_state_deltas(self, prev_stream_id: int, max_stream_id: int): if it's new state. Args: - prev_stream_id (int): point to get changes since (exclusive) - max_stream_id (int): the point that we know has been correctly persisted + prev_stream_id: point to get changes since (exclusive) + max_stream_id: the point that we know has been correctly persisted - ie, an upper limit to return changes from. Returns: - Deferred[tuple[int, list[dict]]: A tuple consisting of: + A tuple consisting of: - the stream id which these results go up to - list of current_state_delta_stream rows. If it is empty, we are up to date. @@ -58,7 +59,7 @@ def get_current_state_deltas(self, prev_stream_id: int, max_stream_id: int): # if the CSDs haven't changed between prev_stream_id and now, we # know for certain that they haven't changed between prev_stream_id and # max_stream_id. - return defer.succeed((max_stream_id, [])) + return (max_stream_id, []) def get_current_state_deltas_txn(txn): # First we calculate the max stream id that will give us less than @@ -102,7 +103,7 @@ def get_current_state_deltas_txn(txn): txn.execute(sql, (prev_stream_id, clipped_stream_id)) return clipped_stream_id, self.db_pool.cursor_to_dict(txn) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_current_state_deltas", get_current_state_deltas_txn ) @@ -114,8 +115,8 @@ def _get_max_stream_id_in_current_state_deltas_txn(self, txn): retcol="COALESCE(MAX(stream_id), -1)", ) - def get_max_stream_id_in_current_state_deltas(self): - return self.db_pool.runInteraction( + async def get_max_stream_id_in_current_state_deltas(self): + return await self.db_pool.runInteraction( "get_max_stream_id_in_current_state_deltas", self._get_max_stream_id_in_current_state_deltas_txn, ) diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 497f6077039b..24f44a7e3607 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -539,7 +539,9 @@ async def get_recent_event_ids_for_room( return rows, token - def get_room_event_before_stream_ordering(self, room_id: str, stream_ordering: int): + async def get_room_event_before_stream_ordering( + self, room_id: str, stream_ordering: int + ) -> Tuple[int, int, str]: """Gets details of the first event in a room at or before a stream ordering Args: @@ -547,8 +549,7 @@ def get_room_event_before_stream_ordering(self, room_id: str, stream_ordering: i stream_ordering: Returns: - Deferred[(int, int, str)]: - (stream ordering, topological ordering, event_id) + A tuple of (stream ordering, topological ordering, event_id) """ def _f(txn): @@ -563,7 +564,9 @@ def _f(txn): txn.execute(sql, (room_id, stream_ordering)) return txn.fetchone() - return self.db_pool.runInteraction("get_room_event_before_stream_ordering", _f) + return await self.db_pool.runInteraction( + "get_room_event_before_stream_ordering", _f + ) async def get_room_events_max_id(self, room_id: Optional[str] = None) -> str: """Returns the current token for rooms stream. diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index 7f104ad93640..e924f1ca3b31 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -17,8 +17,6 @@ from collections import namedtuple from typing import Dict, Iterable, List, Set, Tuple -from twisted.internet import defer - from synapse.api.constants import EventTypes from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool @@ -103,7 +101,7 @@ def get_max_state_group_txn(txn: Cursor): ) @cached(max_entries=10000, iterable=True) - def get_state_group_delta(self, state_group): + async def get_state_group_delta(self, state_group): """Given a state group try to return a previous group and a delta between the old and the new. @@ -135,7 +133,7 @@ def _get_state_group_delta_txn(txn): {(row["type"], row["state_key"]): row["event_id"] for row in delta_ids}, ) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_state_group_delta", _get_state_group_delta_txn ) @@ -367,9 +365,9 @@ def _insert_into_cache( fetched_keys=non_member_types, ) - def store_state_group( + async def store_state_group( self, event_id, room_id, prev_group, delta_ids, current_state_ids - ): + ) -> int: """Store a new set of state, returning a newly assigned state group. Args: @@ -383,7 +381,7 @@ def store_state_group( to event_id. Returns: - Deferred[int]: The state group ID + The state group ID """ def _store_state_group_txn(txn): @@ -484,11 +482,13 @@ def _store_state_group_txn(txn): return state_group - return self.db_pool.runInteraction("store_state_group", _store_state_group_txn) + return await self.db_pool.runInteraction( + "store_state_group", _store_state_group_txn + ) - def purge_unreferenced_state_groups( + async def purge_unreferenced_state_groups( self, room_id: str, state_groups_to_delete - ) -> defer.Deferred: + ) -> None: """Deletes no longer referenced state groups and de-deltas any state groups that reference them. @@ -499,7 +499,7 @@ def purge_unreferenced_state_groups( to delete. """ - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "purge_unreferenced_state_groups", self._purge_unreferenced_state_groups, room_id, @@ -594,7 +594,7 @@ async def get_previous_state_groups( return {row["state_group"]: row["prev_state_group"] for row in rows} - def purge_room_state(self, room_id, state_groups_to_delete): + async def purge_room_state(self, room_id, state_groups_to_delete): """Deletes all record of a room from state tables Args: @@ -602,7 +602,7 @@ def purge_room_state(self, room_id, state_groups_to_delete): state_groups_to_delete (list[int]): State groups to delete """ - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "purge_room_state", self._purge_room_state_txn, room_id, diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 534883361fd7..96a1b59d649b 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -333,7 +333,7 @@ class StateGroupStorage(object): def __init__(self, hs, stores): self.stores = stores - def get_state_group_delta(self, state_group: int): + async def get_state_group_delta(self, state_group: int): """Given a state group try to return a previous group and a delta between the old and the new. @@ -341,11 +341,11 @@ def get_state_group_delta(self, state_group: int): state_group: The state group used to retrieve state deltas. Returns: - Deferred[Tuple[Optional[int], Optional[StateMap[str]]]]: + Tuple[Optional[int], Optional[StateMap[str]]]: (prev_group, delta_ids) """ - return self.stores.state.get_state_group_delta(state_group) + return await self.stores.state.get_state_group_delta(state_group) async def get_state_groups_ids( self, _room_id: str, event_ids: Iterable[str] @@ -525,7 +525,7 @@ async def get_state_ids_for_event( state_filter: The state filter used to fetch state from the database. Returns: - A deferred dict from (type, state_key) -> state_event + A dict from (type, state_key) -> state_event """ state_map = await self.get_state_ids_for_events([event_id], state_filter) return state_map[event_id] @@ -546,14 +546,14 @@ def _get_state_for_groups( """ return self.stores.state._get_state_for_groups(groups, state_filter) - def store_state_group( + async def store_state_group( self, event_id: str, room_id: str, prev_group: Optional[int], delta_ids: Optional[dict], current_state_ids: dict, - ): + ) -> int: """Store a new set of state, returning a newly assigned state group. Args: @@ -567,8 +567,8 @@ def store_state_group( to event_id. Returns: - Deferred[int]: The state group ID + The state group ID """ - return self.stores.state.store_state_group( + return await self.stores.state.store_state_group( event_id, room_id, prev_group, delta_ids, current_state_ids ) From 22b926c284f98b5507583a3a7866da12a9f4bb47 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 28 Aug 2020 15:59:28 +0100 Subject: [PATCH 22/39] Only return devices with keys from `/federation/v1/user/devices/` (#8198) There's not much point in returning all the others, and some people have a silly number of devices. --- changelog.d/8198.feature | 1 + synapse/storage/databases/main/devices.py | 4 +--- 2 files changed, 2 insertions(+), 3 deletions(-) create mode 100644 changelog.d/8198.feature diff --git a/changelog.d/8198.feature b/changelog.d/8198.feature new file mode 100644 index 000000000000..c4401288bf36 --- /dev/null +++ b/changelog.d/8198.feature @@ -0,0 +1 @@ +Optimise `/federation/v1/user/devices/` API by only returning devices with encryption keys. diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index ecd3f3b3108b..def96637a261 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -498,9 +498,7 @@ def _get_devices_with_keys_by_user_txn( ) -> Tuple[int, List[JsonDict]]: now_stream_id = self._device_list_id_gen.get_current_token() - devices = self._get_e2e_device_keys_txn( - txn, [(user_id, None)], include_all_devices=True - ) + devices = self._get_e2e_device_keys_txn(txn, [(user_id, None)]) if devices: user_devices = devices[user_id] From d58fda99ffd29c15254788476fb8775370eebec5 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 28 Aug 2020 11:34:50 -0400 Subject: [PATCH 23/39] Convert `event_push_actions`, `registration`, and `roommember` datastores to async (#8197) --- changelog.d/8197.misc | 1 + .../databases/main/event_push_actions.py | 38 ++- .../storage/databases/main/registration.py | 238 +++++++++--------- synapse/storage/databases/main/roommember.py | 52 ++-- 4 files changed, 169 insertions(+), 160 deletions(-) create mode 100644 changelog.d/8197.misc diff --git a/changelog.d/8197.misc b/changelog.d/8197.misc new file mode 100644 index 000000000000..dfe4c03171d6 --- /dev/null +++ b/changelog.d/8197.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index e8834b2162ba..384c39f4d718 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -15,7 +15,7 @@ # limitations under the License. import logging -from typing import List +from typing import Dict, List, Union from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json @@ -383,19 +383,20 @@ def get_no_receipt(txn): # Now return the first `limit` return notifs[:limit] - def get_if_maybe_push_in_range_for_user(self, user_id, min_stream_ordering): + async def get_if_maybe_push_in_range_for_user( + self, user_id: str, min_stream_ordering: int + ) -> bool: """A fast check to see if there might be something to push for the user since the given stream ordering. May return false positives. Useful to know whether to bother starting a pusher on start up or not. Args: - user_id (str) - min_stream_ordering (int) + user_id + min_stream_ordering Returns: - Deferred[bool]: True if there may be push to process, False if - there definitely isn't. + True if there may be push to process, False if there definitely isn't. """ def _get_if_maybe_push_in_range_for_user_txn(txn): @@ -408,22 +409,20 @@ def _get_if_maybe_push_in_range_for_user_txn(txn): txn.execute(sql, (user_id, min_stream_ordering)) return bool(txn.fetchone()) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_if_maybe_push_in_range_for_user", _get_if_maybe_push_in_range_for_user_txn, ) - async def add_push_actions_to_staging(self, event_id, user_id_actions): + async def add_push_actions_to_staging( + self, event_id: str, user_id_actions: Dict[str, List[Union[dict, str]]] + ) -> None: """Add the push actions for the event to the push action staging area. Args: - event_id (str) - user_id_actions (dict[str, list[dict|str])]): A dictionary mapping - user_id to list of push actions, where an action can either be - a string or dict. - - Returns: - Deferred + event_id + user_id_actions: A mapping of user_id to list of push actions, where + an action can either be a string or dict. """ if not user_id_actions: @@ -507,7 +506,7 @@ def _find_stream_orderings_for_times_txn(self, txn): "Found stream ordering 1 day ago: it's %d", self.stream_ordering_day_ago ) - def find_first_stream_ordering_after_ts(self, ts): + async def find_first_stream_ordering_after_ts(self, ts: int) -> int: """Gets the stream ordering corresponding to a given timestamp. Specifically, finds the stream_ordering of the first event that was @@ -516,13 +515,12 @@ def find_first_stream_ordering_after_ts(self, ts): relatively slow. Args: - ts (int): timestamp in millis + ts: timestamp in millis Returns: - Deferred[int]: stream ordering of the first event received on/after - the timestamp + stream ordering of the first event received on/after the timestamp """ - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "_find_first_stream_ordering_after_ts_txn", self._find_first_stream_ordering_after_ts_txn, ts, diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 12689f430868..000b8ccff436 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -17,7 +17,7 @@ import logging import re -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple from synapse.api.constants import UserTypes from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError @@ -84,17 +84,17 @@ async def is_trial_user(self, user_id: str) -> bool: return is_trial @cached() - def get_user_by_access_token(self, token): + async def get_user_by_access_token(self, token: str) -> Optional[dict]: """Get a user from the given access token. Args: - token (str): The access token of a user. + token: The access token of a user. Returns: - defer.Deferred: None, if the token did not match, otherwise dict - including the keys `name`, `is_guest`, `device_id`, `token_id`, - `valid_until_ms`. + None, if the token did not match, otherwise dict + including the keys `name`, `is_guest`, `device_id`, `token_id`, + `valid_until_ms`. """ - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_user_by_access_token", self._query_for_auth, token ) @@ -281,13 +281,12 @@ async def is_server_admin(self, user: UserID) -> bool: return bool(res) if res else False - def set_server_admin(self, user, admin): + async def set_server_admin(self, user: UserID, admin: bool) -> None: """Sets whether a user is an admin of this homeserver. Args: - user (UserID): user ID of the user to test - admin (bool): true iff the user is to be a server admin, - false otherwise. + user: user ID of the user to test + admin: true iff the user is to be a server admin, false otherwise. """ def set_server_admin_txn(txn): @@ -298,7 +297,7 @@ def set_server_admin_txn(txn): txn, self.get_user_by_id, (user.to_string(),) ) - return self.db_pool.runInteraction("set_server_admin", set_server_admin_txn) + await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn) def _query_for_auth(self, txn, token): sql = ( @@ -364,9 +363,11 @@ def is_support_user_txn(self, txn, user_id): ) return True if res == UserTypes.SUPPORT else False - def get_users_by_id_case_insensitive(self, user_id): + async def get_users_by_id_case_insensitive(self, user_id: str) -> Dict[str, str]: """Gets users that match user_id case insensitively. - Returns a mapping of user_id -> password_hash. + + Returns: + A mapping of user_id -> password_hash. """ def f(txn): @@ -374,7 +375,7 @@ def f(txn): txn.execute(sql, (user_id,)) return dict(txn) - return self.db_pool.runInteraction("get_users_by_id_case_insensitive", f) + return await self.db_pool.runInteraction("get_users_by_id_case_insensitive", f) async def get_user_by_external_id( self, auth_provider: str, external_id: str @@ -408,7 +409,7 @@ def _count_users(txn): return await self.db_pool.runInteraction("count_users", _count_users) - def count_daily_user_type(self): + async def count_daily_user_type(self) -> Dict[str, int]: """ Counts 1) native non guest users 2) native guests users @@ -437,7 +438,7 @@ def _count_daily_user_type(txn): results[row[0]] = row[1] return results - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "count_daily_user_type", _count_daily_user_type ) @@ -663,24 +664,29 @@ async def get_user_deactivated_status(self, user_id: str) -> bool: # Convert the integer into a boolean. return res == 1 - def get_threepid_validation_session( - self, medium, client_secret, address=None, sid=None, validated=True - ): + async def get_threepid_validation_session( + self, + medium: Optional[str], + client_secret: str, + address: Optional[str] = None, + sid: Optional[str] = None, + validated: Optional[bool] = True, + ) -> Optional[Dict[str, Any]]: """Gets a session_id and last_send_attempt (if available) for a combination of validation metadata Args: - medium (str|None): The medium of the 3PID - address (str|None): The address of the 3PID - sid (str|None): The ID of the validation session - client_secret (str): A unique string provided by the client to help identify this + medium: The medium of the 3PID + client_secret: A unique string provided by the client to help identify this validation attempt - validated (bool|None): Whether sessions should be filtered by + address: The address of the 3PID + sid: The ID of the validation session + validated: Whether sessions should be filtered by whether they have been validated already or not. None to perform no filtering Returns: - Deferred[dict|None]: A dict containing the following: + A dict containing the following: * address - address of the 3pid * medium - medium of the 3pid * client_secret - a secret provided by the client for this validation session @@ -726,17 +732,17 @@ def get_threepid_validation_session_txn(txn): return rows[0] - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_threepid_validation_session", get_threepid_validation_session_txn ) - def delete_threepid_session(self, session_id): + async def delete_threepid_session(self, session_id: str) -> None: """Removes a threepid validation session from the database. This can be done after validation has been performed and whatever action was waiting on it has been carried out Args: - session_id (str): The ID of the session to delete + session_id: The ID of the session to delete """ def delete_threepid_session_txn(txn): @@ -751,7 +757,7 @@ def delete_threepid_session_txn(txn): keyvalues={"session_id": session_id}, ) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "delete_threepid_session", delete_threepid_session_txn ) @@ -941,43 +947,40 @@ async def add_access_token_to_user( desc="add_access_token_to_user", ) - def register_user( + async def register_user( self, - user_id, - password_hash=None, - was_guest=False, - make_guest=False, - appservice_id=None, - create_profile_with_displayname=None, - admin=False, - user_type=None, - shadow_banned=False, - ): + user_id: str, + password_hash: Optional[str] = None, + was_guest: bool = False, + make_guest: bool = False, + appservice_id: Optional[str] = None, + create_profile_with_displayname: Optional[str] = None, + admin: bool = False, + user_type: Optional[str] = None, + shadow_banned: bool = False, + ) -> None: """Attempts to register an account. Args: - user_id (str): The desired user ID to register. - password_hash (str|None): Optional. The password hash for this user. - was_guest (bool): Optional. Whether this is a guest account being - upgraded to a non-guest account. - make_guest (boolean): True if the the new user should be guest, - false to add a regular user account. - appservice_id (str): The ID of the appservice registering the user. - create_profile_with_displayname (unicode): Optionally create a profile for + user_id: The desired user ID to register. + password_hash: Optional. The password hash for this user. + was_guest: Whether this is a guest account being upgraded to a + non-guest account. + make_guest: True if the the new user should be guest, false to add a + regular user account. + appservice_id: The ID of the appservice registering the user. + create_profile_with_displayname: Optionally create a profile for the user, setting their displayname to the given value - admin (boolean): is an admin user? - user_type (str|None): type of user. One of the values from - api.constants.UserTypes, or None for a normal user. - shadow_banned (bool): Whether the user is shadow-banned, - i.e. they may be told their requests succeeded but we ignore them. + admin: is an admin user? + user_type: type of user. One of the values from api.constants.UserTypes, + or None for a normal user. + shadow_banned: Whether the user is shadow-banned, i.e. they may be + told their requests succeeded but we ignore them. Raises: StoreError if the user_id could not be registered. - - Returns: - Deferred """ - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "register_user", self._register_user, user_id, @@ -1101,7 +1104,7 @@ async def record_user_external_id( desc="record_user_external_id", ) - def user_set_password_hash(self, user_id, password_hash): + async def user_set_password_hash(self, user_id: str, password_hash: str) -> None: """ NB. This does *not* evict any cache because the one use for this removes most of the entries subsequently anyway so it would be @@ -1114,17 +1117,18 @@ def user_set_password_hash_txn(txn): ) self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "user_set_password_hash", user_set_password_hash_txn ) - def user_set_consent_version(self, user_id, consent_version): + async def user_set_consent_version( + self, user_id: str, consent_version: str + ) -> None: """Updates the user table to record privacy policy consent Args: - user_id (str): full mxid of the user to update - consent_version (str): version of the policy the user has consented - to + user_id: full mxid of the user to update + consent_version: version of the policy the user has consented to Raises: StoreError(404) if user not found @@ -1139,16 +1143,17 @@ def f(txn): ) self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) - return self.db_pool.runInteraction("user_set_consent_version", f) + await self.db_pool.runInteraction("user_set_consent_version", f) - def user_set_consent_server_notice_sent(self, user_id, consent_version): + async def user_set_consent_server_notice_sent( + self, user_id: str, consent_version: str + ) -> None: """Updates the user table to record that we have sent the user a server notice about privacy policy consent Args: - user_id (str): full mxid of the user to update - consent_version (str): version of the policy we have notified the - user about + user_id: full mxid of the user to update + consent_version: version of the policy we have notified the user about Raises: StoreError(404) if user not found @@ -1163,22 +1168,25 @@ def f(txn): ) self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) - return self.db_pool.runInteraction("user_set_consent_server_notice_sent", f) + await self.db_pool.runInteraction("user_set_consent_server_notice_sent", f) - def user_delete_access_tokens(self, user_id, except_token_id=None, device_id=None): + async def user_delete_access_tokens( + self, + user_id: str, + except_token_id: Optional[str] = None, + device_id: Optional[str] = None, + ) -> List[Tuple[str, int, Optional[str]]]: """ Invalidate access tokens belonging to a user Args: - user_id (str): ID of user the tokens belong to - except_token_id (str): list of access_tokens IDs which should - *not* be deleted - device_id (str|None): ID of device the tokens are associated with. + user_id: ID of user the tokens belong to + except_token_id: access_tokens ID which should *not* be deleted + device_id: ID of device the tokens are associated with. If None, tokens associated with any device (or no device) will be deleted Returns: - defer.Deferred[list[str, int, str|None, int]]: a list of - (token, token id, device id) for each of the deleted tokens + A tuple of (token, token id, device id) for each of the deleted tokens """ def f(txn): @@ -1209,9 +1217,9 @@ def f(txn): return tokens_and_devices - return self.db_pool.runInteraction("user_delete_access_tokens", f) + return await self.db_pool.runInteraction("user_delete_access_tokens", f) - def delete_access_token(self, access_token): + async def delete_access_token(self, access_token: str) -> None: def f(txn): self.db_pool.simple_delete_one_txn( txn, table="access_tokens", keyvalues={"token": access_token} @@ -1221,7 +1229,7 @@ def f(txn): txn, self.get_user_by_access_token, (access_token,) ) - return self.db_pool.runInteraction("delete_access_token", f) + await self.db_pool.runInteraction("delete_access_token", f) @cached() async def is_guest(self, user_id: str) -> bool: @@ -1272,24 +1280,25 @@ async def get_user_pending_deactivation(self) -> Optional[str]: desc="get_users_pending_deactivation", ) - def validate_threepid_session(self, session_id, client_secret, token, current_ts): + async def validate_threepid_session( + self, session_id: str, client_secret: str, token: str, current_ts: int + ) -> Optional[str]: """Attempt to validate a threepid session using a token Args: - session_id (str): The id of a validation session - client_secret (str): A unique string provided by the client to - help identify this validation attempt - token (str): A validation token - current_ts (int): The current unix time in milliseconds. Used for - checking token expiry status + session_id: The id of a validation session + client_secret: A unique string provided by the client to help identify + this validation attempt + token: A validation token + current_ts: The current unix time in milliseconds. Used for checking + token expiry status Raises: ThreepidValidationError: if a matching validation token was not found or has expired Returns: - deferred str|None: A str representing a link to redirect the user - to if there is one. + A str representing a link to redirect the user to if there is one. """ # Insert everything into a transaction in order to run atomically @@ -1359,36 +1368,35 @@ def validate_threepid_session_txn(txn): return next_link # Return next_link if it exists - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "validate_threepid_session_txn", validate_threepid_session_txn ) - def start_or_continue_validation_session( + async def start_or_continue_validation_session( self, - medium, - address, - session_id, - client_secret, - send_attempt, - next_link, - token, - token_expires, - ): + medium: str, + address: str, + session_id: str, + client_secret: str, + send_attempt: int, + next_link: Optional[str], + token: str, + token_expires: int, + ) -> None: """Creates a new threepid validation session if it does not already exist and associates a new validation token with it Args: - medium (str): The medium of the 3PID - address (str): The address of the 3PID - session_id (str): The id of this validation session - client_secret (str): A unique string provided by the client to - help identify this validation attempt - send_attempt (int): The latest send_attempt on this session - next_link (str|None): The link to redirect the user to upon - successful validation - token (str): The validation token - token_expires (int): The timestamp for which after the token - will no longer be valid + medium: The medium of the 3PID + address: The address of the 3PID + session_id: The id of this validation session + client_secret: A unique string provided by the client to help + identify this validation attempt + send_attempt: The latest send_attempt on this session + next_link: The link to redirect the user to upon successful validation + token: The validation token + token_expires: The timestamp for which after the token will no + longer be valid """ def start_or_continue_validation_session_txn(txn): @@ -1417,12 +1425,12 @@ def start_or_continue_validation_session_txn(txn): }, ) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "start_or_continue_validation_session", start_or_continue_validation_session_txn, ) - def cull_expired_threepid_validation_tokens(self): + async def cull_expired_threepid_validation_tokens(self) -> None: """Remove threepid validation tokens with expiry dates that have passed""" def cull_expired_threepid_validation_tokens_txn(txn, ts): @@ -1430,9 +1438,9 @@ def cull_expired_threepid_validation_tokens_txn(txn, ts): DELETE FROM threepid_validation_token WHERE expires < ? """ - return txn.execute(sql, (ts,)) + txn.execute(sql, (ts,)) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "cull_expired_threepid_validation_tokens", cull_expired_threepid_validation_tokens_txn, self.clock.time_msec(), diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 161edbeccb8a..c4ce5c17c215 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -15,7 +15,7 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Awaitable, Iterable, List, Optional, Set +from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Set from synapse.api.constants import EventTypes, Membership from synapse.events import EventBase @@ -152,8 +152,8 @@ def _check_safe_current_state_events_membership_updated_txn(self, txn): ) @cached(max_entries=100000, iterable=True) - def get_users_in_room(self, room_id: str): - return self.db_pool.runInteraction( + async def get_users_in_room(self, room_id: str) -> List[str]: + return await self.db_pool.runInteraction( "get_users_in_room", self.get_users_in_room_txn, room_id ) @@ -180,14 +180,13 @@ def get_users_in_room_txn(self, txn, room_id: str) -> List[str]: return [r[0] for r in txn] @cached(max_entries=100000) - def get_room_summary(self, room_id: str): + async def get_room_summary(self, room_id: str) -> Dict[str, MemberSummary]: """ Get the details of a room roughly suitable for use by the room summary extension to /sync. Useful when lazy loading room members. Args: room_id: The room ID to query Returns: - Deferred[dict[str, MemberSummary]: - dict of membership states, pointing to a MemberSummary named tuple. + dict of membership states, pointing to a MemberSummary named tuple. """ def _get_room_summary_txn(txn): @@ -261,20 +260,22 @@ def _get_room_summary_txn(txn): return res - return self.db_pool.runInteraction("get_room_summary", _get_room_summary_txn) + return await self.db_pool.runInteraction( + "get_room_summary", _get_room_summary_txn + ) @cached() - def get_invited_rooms_for_local_user(self, user_id: str) -> Awaitable[RoomsForUser]: + async def get_invited_rooms_for_local_user(self, user_id: str) -> RoomsForUser: """Get all the rooms the *local* user is invited to. Args: user_id: The user ID. Returns: - A awaitable list of RoomsForUser. + A list of RoomsForUser. """ - return self.get_rooms_for_local_user_where_membership_is( + return await self.get_rooms_for_local_user_where_membership_is( user_id, [Membership.INVITE] ) @@ -357,7 +358,9 @@ def _get_rooms_for_local_user_where_membership_is_txn( return results @cached(max_entries=500000, iterable=True) - def get_rooms_for_user_with_stream_ordering(self, user_id: str): + async def get_rooms_for_user_with_stream_ordering( + self, user_id: str + ) -> FrozenSet[GetRoomsForUserWithStreamOrdering]: """Returns a set of room_ids the user is currently joined to. If a remote user only returns rooms this server is currently @@ -367,17 +370,18 @@ def get_rooms_for_user_with_stream_ordering(self, user_id: str): user_id Returns: - Deferred[frozenset[GetRoomsForUserWithStreamOrdering]]: Returns - the rooms the user is in currently, along with the stream ordering - of the most recent join for that user and room. + Returns the rooms the user is in currently, along with the stream + ordering of the most recent join for that user and room. """ - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_rooms_for_user_with_stream_ordering", self._get_rooms_for_user_with_stream_ordering_txn, user_id, ) - def _get_rooms_for_user_with_stream_ordering_txn(self, txn, user_id: str): + def _get_rooms_for_user_with_stream_ordering_txn( + self, txn, user_id: str + ) -> FrozenSet[GetRoomsForUserWithStreamOrdering]: # We use `current_state_events` here and not `local_current_membership` # as a) this gets called with remote users and b) this only gets called # for rooms the server is participating in. @@ -404,9 +408,7 @@ def _get_rooms_for_user_with_stream_ordering_txn(self, txn, user_id: str): """ txn.execute(sql, (user_id, Membership.JOIN)) - results = frozenset(GetRoomsForUserWithStreamOrdering(*row) for row in txn) - - return results + return frozenset(GetRoomsForUserWithStreamOrdering(*row) for row in txn) async def get_users_server_still_shares_room_with( self, user_ids: Collection[str] @@ -711,14 +713,14 @@ def f(txn): return count == 0 @cached() - def get_forgotten_rooms_for_user(self, user_id: str): + async def get_forgotten_rooms_for_user(self, user_id: str) -> Set[str]: """Gets all rooms the user has forgotten. Args: - user_id + user_id: The user ID to query the rooms of. Returns: - Deferred[set[str]] + The forgotten rooms. """ def _get_forgotten_rooms_for_user_txn(txn): @@ -744,7 +746,7 @@ def _get_forgotten_rooms_for_user_txn(txn): txn.execute(sql, (user_id,)) return {row[0] for row in txn if row[1] == 0} - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn ) @@ -973,7 +975,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): def __init__(self, database: DatabasePool, db_conn, hs): super(RoomMemberStore, self).__init__(database, db_conn, hs) - def forget(self, user_id: str, room_id: str): + async def forget(self, user_id: str, room_id: str) -> None: """Indicate that user_id wishes to discard history for room_id.""" def f(txn): @@ -994,7 +996,7 @@ def f(txn): txn, self.get_forgotten_rooms_for_user, (user_id,) ) - return self.db_pool.runInteraction("forget_membership", f) + await self.db_pool.runInteraction("forget_membership", f) class _JoinedHostsCache(object): From 3b4556cf87666bb6f40d89c8c7fff42d336237b6 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 28 Aug 2020 17:12:45 +0100 Subject: [PATCH 24/39] Fix `wait_for_stream_position` for multiple waiters. (#8196) This fixes a bug where having multiple callers waiting on the same stream and position will cause it to try and compare two deferreds, which fails (due to the sorted list having an entry of `Tuple[int, Deferred]`). --- changelog.d/8196.misc | 1 + synapse/python_dependencies.py | 4 +++- synapse/replication/tcp/client.py | 6 ++---- 3 files changed, 6 insertions(+), 5 deletions(-) create mode 100644 changelog.d/8196.misc diff --git a/changelog.d/8196.misc b/changelog.d/8196.misc new file mode 100644 index 000000000000..c42baf0e56c6 --- /dev/null +++ b/changelog.d/8196.misc @@ -0,0 +1 @@ +Fix `wait_for_stream_position` to allow multiple waiters on same stream ID. diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index dd77a44b8db0..2d995ec456a5 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -66,7 +66,9 @@ "msgpack>=0.5.2", "phonenumbers>=8.2.0", "prometheus_client>=0.0.18,<0.9.0", - # we use attr.validators.deep_iterable, which arrived in 19.1.0 + # we use attr.validators.deep_iterable, which arrived in 19.1.0 (Note: + # Fedora 31 only has 19.1, so if we want to upgrade we should wait until 33 + # is out in November.) "attrs>=19.1.0", "netaddr>=0.7.18", "Jinja2>=2.9", diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index fcf8ebf1e74f..d6ecf5b32703 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -14,7 +14,6 @@ # limitations under the License. """A replication client for use by synapse workers. """ -import heapq import logging from typing import TYPE_CHECKING, Dict, List, Tuple @@ -219,9 +218,8 @@ async def wait_for_stream_position( waiting_list = self._streams_to_waiters.setdefault(stream_name, []) - # We insert into the list using heapq as it is more efficient than - # pushing then resorting each time. - heapq.heappush(waiting_list, (position, deferred)) + waiting_list.append((position, deferred)) + waiting_list.sort(key=lambda t: t[0]) # We measure here to get in flight counts and average waiting time. with Measure(self._clock, "repl.wait_for_stream_position"): From b4826d6eb18eeb05c973882efd4c0f5d68359449 Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Fri, 28 Aug 2020 17:39:14 +0100 Subject: [PATCH 25/39] Fix incorrect return signature --- synapse/storage/databases/main/registration.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 000b8ccff436..01f20c03c213 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -99,7 +99,7 @@ async def get_user_by_access_token(self, token: str) -> Optional[dict]: ) @cached() - async def get_expiration_ts_for_user(self, user_id: str) -> Optional[None]: + async def get_expiration_ts_for_user(self, user_id: str) -> Optional[int]: """Get the expiration timestamp for the account bearing a given user ID. Args: From d2ac767de2ce895a965fb2fcdcb883636f19a5c5 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 28 Aug 2020 16:47:11 -0400 Subject: [PATCH 26/39] Convert ReadWriteLock to async/await. (#8202) --- changelog.d/8202.misc | 1 + synapse/handlers/pagination.py | 49 ++++++++++++++++++---------------- synapse/util/async_helpers.py | 16 +++++------ tests/util/test_rwlock.py | 6 +++-- 4 files changed, 39 insertions(+), 33 deletions(-) create mode 100644 changelog.d/8202.misc diff --git a/changelog.d/8202.misc b/changelog.d/8202.misc new file mode 100644 index 000000000000..dfe4c03171d6 --- /dev/null +++ b/changelog.d/8202.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index ac3418d69d9f..5a1aa7d83086 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -14,15 +14,18 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from typing import Any, Dict, Optional from twisted.python.failure import Failure from synapse.api.constants import EventTypes, Membership from synapse.api.errors import SynapseError +from synapse.api.filtering import Filter from synapse.logging.context import run_in_background from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.state import StateFilter -from synapse.types import RoomStreamToken +from synapse.streams.config import PaginationConfig +from synapse.types import Requester, RoomStreamToken from synapse.util.async_helpers import ReadWriteLock from synapse.util.stringutils import random_string from synapse.visibility import filter_events_for_client @@ -247,15 +250,16 @@ def start_purge_history(self, room_id, token, delete_local_events=False): ) return purge_id - async def _purge_history(self, purge_id, room_id, token, delete_local_events): + async def _purge_history( + self, purge_id: str, room_id: str, token: str, delete_local_events: bool + ) -> None: """Carry out a history purge on a room. Args: - purge_id (str): The id for this purge - room_id (str): The room to purge from - token (str): topological token to delete events before - delete_local_events (bool): True to delete local events as well as - remote ones + purge_id: The id for this purge + room_id: The room to purge from + token: topological token to delete events before + delete_local_events: True to delete local events as well as remote ones """ self._purges_in_progress_by_room.add(room_id) try: @@ -291,9 +295,9 @@ def get_purge_status(self, purge_id): """ return self._purges_by_id.get(purge_id) - async def purge_room(self, room_id): + async def purge_room(self, room_id: str) -> None: """Purge the given room from the database""" - with (await self.pagination_lock.write(room_id)): + with await self.pagination_lock.write(room_id): # check we know about the room await self.store.get_room_version_id(room_id) @@ -307,23 +311,22 @@ async def purge_room(self, room_id): async def get_messages( self, - requester, - room_id=None, - pagin_config=None, - as_client_event=True, - event_filter=None, - ): + requester: Requester, + room_id: Optional[str] = None, + pagin_config: Optional[PaginationConfig] = None, + as_client_event: bool = True, + event_filter: Optional[Filter] = None, + ) -> Dict[str, Any]: """Get messages in a room. Args: - requester (Requester): The user requesting messages. - room_id (str): The room they want messages from. - pagin_config (synapse.api.streams.PaginationConfig): The pagination - config rules to apply, if any. - as_client_event (bool): True to get events in client-server format. - event_filter (Filter): Filter to apply to results or None + requester: The user requesting messages. + room_id: The room they want messages from. + pagin_config: The pagination config rules to apply, if any. + as_client_event: True to get events in client-server format. + event_filter: Filter to apply to results or None Returns: - dict: Pagination API results + Pagination API results """ user_id = requester.user.to_string() @@ -343,7 +346,7 @@ async def get_messages( source_config = pagin_config.get_source_config("room") - with (await self.pagination_lock.read(room_id)): + with await self.pagination_lock.read(room_id): ( membership, member_event_id, diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index f562770922d0..dfefbd996d51 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -20,6 +20,7 @@ from typing import Dict, Sequence, Set, Union import attr +from typing_extensions import ContextManager from twisted.internet import defer from twisted.internet.defer import CancelledError @@ -338,11 +339,11 @@ def eb(e): class ReadWriteLock(object): - """A deferred style read write lock. + """An async read write lock. Example: - with (yield read_write_lock.read("test_key")): + with await read_write_lock.read("test_key"): # do some work """ @@ -365,8 +366,7 @@ def __init__(self): # Latest writer queued self.key_to_current_writer = {} # type: Dict[str, defer.Deferred] - @defer.inlineCallbacks - def read(self, key): + async def read(self, key: str) -> ContextManager: new_defer = defer.Deferred() curr_readers = self.key_to_current_readers.setdefault(key, set()) @@ -376,7 +376,8 @@ def read(self, key): # We wait for the latest writer to finish writing. We can safely ignore # any existing readers... as they're readers. - yield make_deferred_yieldable(curr_writer) + if curr_writer: + await make_deferred_yieldable(curr_writer) @contextmanager def _ctx_manager(): @@ -388,8 +389,7 @@ def _ctx_manager(): return _ctx_manager() - @defer.inlineCallbacks - def write(self, key): + async def write(self, key: str) -> ContextManager: new_defer = defer.Deferred() curr_readers = self.key_to_current_readers.get(key, set()) @@ -405,7 +405,7 @@ def write(self, key): curr_readers.clear() self.key_to_current_writer[key] = new_defer - yield make_deferred_yieldable(defer.gatherResults(to_wait_on)) + await make_deferred_yieldable(defer.gatherResults(to_wait_on)) @contextmanager def _ctx_manager(): diff --git a/tests/util/test_rwlock.py b/tests/util/test_rwlock.py index bd32e2cee7d1..d3dea3b52a8a 100644 --- a/tests/util/test_rwlock.py +++ b/tests/util/test_rwlock.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from twisted.internet import defer from synapse.util.async_helpers import ReadWriteLock @@ -43,6 +44,7 @@ def test_rwlock(self): rwlock.read(key), # 5 rwlock.write(key), # 6 ] + ds = [defer.ensureDeferred(d) for d in ds] self._assert_called_before_not_after(ds, 2) @@ -73,12 +75,12 @@ def test_rwlock(self): with ds[6].result: pass - d = rwlock.write(key) + d = defer.ensureDeferred(rwlock.write(key)) self.assertTrue(d.called) with d.result: pass - d = rwlock.read(key) + d = defer.ensureDeferred(rwlock.read(key)) self.assertTrue(d.called) with d.result: pass From 8027166dd546811316e7c8395deaf929110c2f3a Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Sat, 29 Aug 2020 00:05:25 +0100 Subject: [PATCH 27/39] Add a comment about _LimitedHostnameResolver --- synapse/app/_base.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 2b2cd795e072..a43dc5b2c9da 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -334,6 +334,13 @@ def install_dns_limiter(reactor, max_dns_requests_in_flight=100): This is to workaround https://twistedmatrix.com/trac/ticket/9620, where we can run out of file descriptors and infinite loop if we attempt to do too many DNS queries at once + + XXX: I'm confused by this. reactor.nameResolver does not use twisted.names unless + you explicitly install twisted.names as the resolver; rather it uses a GAIResolver + backed by the reactor's default threadpool (which is limited to 10 threads). So + (a) I don't understand why twisted ticket 9620 is relevant, and (b) I don't + understand why we would run out of FDs if we did too many lookups at once. + -- richvdh 2020/08/29 """ new_resolver = _LimitedHostnameResolver( reactor.nameResolver, max_dns_requests_in_flight From 45e8f7726f24d98a9d3fa06ea52ae960cc1d8689 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Sat, 29 Aug 2020 00:14:17 +0100 Subject: [PATCH 28/39] Rename `get_e2e_device_keys` to better reflect its purpose (#8205) ... and to show that it does something slightly different to `_get_e2e_device_keys_txn`. `include_all_devices` and `include_deleted_devices` were never used (and `include_deleted_devices` was broken, since that would cause `None`s in the result which were not handled in the loop below. Add some typing too. --- changelog.d/8205.misc | 1 + synapse/handlers/e2e_keys.py | 4 ++-- .../storage/databases/main/end_to_end_keys.py | 20 ++++++------------- tests/storage/test_end_to_end_keys.py | 8 +++++--- 4 files changed, 14 insertions(+), 19 deletions(-) create mode 100644 changelog.d/8205.misc diff --git a/changelog.d/8205.misc b/changelog.d/8205.misc new file mode 100644 index 000000000000..fb8fd832789a --- /dev/null +++ b/changelog.d/8205.misc @@ -0,0 +1 @@ + Refactor queries for device keys and cross-signatures. diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index d8def45e388e..dfd1c785495a 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -353,7 +353,7 @@ async def query_local_devices( # make sure that each queried user appears in the result dict result_dict[user_id] = {} - results = await self.store.get_e2e_device_keys(local_query) + results = await self.store.get_e2e_device_keys_for_cs_api(local_query) # Build the result structure for user_id, device_keys in results.items(): @@ -734,7 +734,7 @@ async def _process_self_signatures(self, user_id, signatures): # fetch our stored devices. This is used to 1. verify # signatures on the master key, and 2. to compare with what # was sent if the device was signed - devices = await self.store.get_e2e_device_keys([(user_id, None)]) + devices = await self.store.get_e2e_device_keys_for_cs_api([(user_id, None)]) if user_id not in devices: raise NotFoundError("No device keys found") diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index af0b85e2c92c..50ecddf7fa70 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -23,6 +23,7 @@ from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import make_in_list_sql_clause +from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedList from synapse.util.iterutils import batch_iter @@ -33,17 +34,12 @@ class EndToEndKeyWorkerStore(SQLBaseStore): @trace - async def get_e2e_device_keys( - self, query_list, include_all_devices=False, include_deleted_devices=False - ): - """Fetch a list of device keys. + async def get_e2e_device_keys_for_cs_api( + self, query_list: List[Tuple[str, Optional[str]]] + ) -> Dict[str, Dict[str, JsonDict]]: + """Fetch a list of device keys, formatted suitably for the C/S API. Args: query_list(list): List of pairs of user_ids and device_ids. - include_all_devices (bool): whether to include entries for devices - that don't have device keys - include_deleted_devices (bool): whether to include null entries for - devices which no longer exist (but were in the query_list). - This option only takes effect if include_all_devices is true. Returns: Dict mapping from user-id to dict mapping from device_id to key data. The key data will be a dict in the same format as the @@ -54,11 +50,7 @@ async def get_e2e_device_keys( return {} results = await self.db_pool.runInteraction( - "get_e2e_device_keys", - self._get_e2e_device_keys_txn, - query_list, - include_all_devices, - include_deleted_devices, + "get_e2e_device_keys", self._get_e2e_device_keys_txn, query_list, ) # Build the result structure, un-jsonify the results, and add the diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py index 261bf5b08b1b..3fc4bb13b64c 100644 --- a/tests/storage/test_end_to_end_keys.py +++ b/tests/storage/test_end_to_end_keys.py @@ -37,7 +37,7 @@ def test_key_without_device_name(self): ) res = yield defer.ensureDeferred( - self.store.get_e2e_device_keys((("user", "device"),)) + self.store.get_e2e_device_keys_for_cs_api((("user", "device"),)) ) self.assertIn("user", res) self.assertIn("device", res["user"]) @@ -76,7 +76,7 @@ def test_get_key_with_device_name(self): ) res = yield defer.ensureDeferred( - self.store.get_e2e_device_keys((("user", "device"),)) + self.store.get_e2e_device_keys_for_cs_api((("user", "device"),)) ) self.assertIn("user", res) self.assertIn("device", res["user"]) @@ -108,7 +108,9 @@ def test_multiple_devices(self): ) res = yield defer.ensureDeferred( - self.store.get_e2e_device_keys((("user1", "device1"), ("user2", "device2"))) + self.store.get_e2e_device_keys_for_cs_api( + (("user1", "device1"), ("user2", "device2")) + ) ) self.assertIn("user1", res) self.assertIn("device1", res["user1"]) From aa07c37cf0a3b812e6aa1bb2d97d543e6925c8e2 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 1 Sep 2020 12:41:21 +0100 Subject: [PATCH 29/39] Move and rename `get_devices_with_keys_by_user` (#8204) * Move `get_devices_with_keys_by_user` to `EndToEndKeyWorkerStore` this seems a better fit for it. This commit simply moves the existing code: no other changes at all. * Rename `get_devices_with_keys_by_user` to better reflect what it does. * get_device_stream_token abstract method To avoid referencing fields which are declared in the derived classes, make `get_device_stream_token` abstract, and define that in the classes which define `_device_list_id_gen`. --- changelog.d/8204.misc | 1 + synapse/handlers/device.py | 4 +- synapse/replication/slave/storage/devices.py | 3 ++ synapse/storage/databases/main/__init__.py | 3 ++ synapse/storage/databases/main/devices.py | 52 ++---------------- .../storage/databases/main/end_to_end_keys.py | 53 ++++++++++++++++++- 6 files changed, 67 insertions(+), 49 deletions(-) create mode 100644 changelog.d/8204.misc diff --git a/changelog.d/8204.misc b/changelog.d/8204.misc new file mode 100644 index 000000000000..979c8b227bbc --- /dev/null +++ b/changelog.d/8204.misc @@ -0,0 +1 @@ +Refactor queries for device keys and cross-signatures. diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index db417d60deb4..ee4666337a62 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -234,7 +234,9 @@ async def get_user_ids_changed(self, user_id, from_token): return result async def on_federation_query_user_devices(self, user_id): - stream_id, devices = await self.store.get_devices_with_keys_by_user(user_id) + stream_id, devices = await self.store.get_e2e_device_keys_for_federation_query( + user_id + ) master_key = await self.store.get_e2e_cross_signing_key(user_id, "master") self_signing_key = await self.store.get_e2e_cross_signing_key( user_id, "self_signing" diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index 596c72eb92af..3b788c96250d 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -48,6 +48,9 @@ def __init__(self, database: DatabasePool, db_conn, hs): "DeviceListFederationStreamChangeCache", device_list_max ) + def get_device_stream_token(self) -> int: + return self._device_list_id_gen.get_current_token() + def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == DeviceListsStream.NAME: self._device_list_id_gen.advance(instance_name, token) diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 70cf15dd7f46..e6536c8456de 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -264,6 +264,9 @@ def __init__(self, database: DatabasePool, db_conn, hs): # Used in _generate_user_daily_visits to keep track of progress self._last_user_visit_update = self._get_start_of_day() + def get_device_stream_token(self) -> int: + return self._device_list_id_gen.get_current_token() + def take_presence_startup_info(self): active_on_startup = self._presence_on_startup self._presence_on_startup = None diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index def96637a261..e8379c73c460 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -14,6 +14,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import abc import logging from typing import Any, Dict, Iterable, List, Optional, Set, Tuple @@ -101,7 +102,7 @@ async def get_device_updates_by_remote( update included in the response), and the list of updates, where each update is a pair of EDU type and EDU contents. """ - now_stream_id = self._device_list_id_gen.get_current_token() + now_stream_id = self.get_device_stream_token() has_changed = self._device_list_federation_stream_cache.has_entity_changed( destination, int(from_stream_id) @@ -412,8 +413,10 @@ def _add_user_signature_change_txn( }, ) + @abc.abstractmethod def get_device_stream_token(self) -> int: - return self._device_list_id_gen.get_current_token() + """Get the current stream id from the _device_list_id_gen""" + ... @trace async def get_user_devices_from_cache( @@ -481,51 +484,6 @@ async def get_cached_devices_for_user(self, user_id: str) -> Dict[str, JsonDict] device["device_id"]: db_to_json(device["content"]) for device in devices } - def get_devices_with_keys_by_user(self, user_id: str): - """Get all devices (with any device keys) for a user - - Returns: - Deferred which resolves to (stream_id, devices) - """ - return self.db_pool.runInteraction( - "get_devices_with_keys_by_user", - self._get_devices_with_keys_by_user_txn, - user_id, - ) - - def _get_devices_with_keys_by_user_txn( - self, txn: LoggingTransaction, user_id: str - ) -> Tuple[int, List[JsonDict]]: - now_stream_id = self._device_list_id_gen.get_current_token() - - devices = self._get_e2e_device_keys_txn(txn, [(user_id, None)]) - - if devices: - user_devices = devices[user_id] - results = [] - for device_id, device in user_devices.items(): - result = {"device_id": device_id} - - key_json = device.get("key_json", None) - if key_json: - result["keys"] = db_to_json(key_json) - - if "signatures" in device: - for sig_user_id, sigs in device["signatures"].items(): - result["keys"].setdefault("signatures", {}).setdefault( - sig_user_id, {} - ).update(sigs) - - device_display_name = device.get("device_display_name", None) - if device_display_name: - result["device_display_name"] = device_display_name - - results.append(result) - - return now_stream_id, results - - return now_stream_id, [] - async def get_users_whose_devices_changed( self, from_key: str, user_ids: Iterable[str] ) -> Set[str]: diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 50ecddf7fa70..fb3b1f94de92 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -14,6 +14,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import abc from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple from canonicaljson import encode_canonical_json @@ -22,7 +23,7 @@ from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.database import make_in_list_sql_clause +from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedList @@ -33,6 +34,51 @@ class EndToEndKeyWorkerStore(SQLBaseStore): + def get_e2e_device_keys_for_federation_query(self, user_id: str): + """Get all devices (with any device keys) for a user + + Returns: + Deferred which resolves to (stream_id, devices) + """ + return self.db_pool.runInteraction( + "get_e2e_device_keys_for_federation_query", + self._get_e2e_device_keys_for_federation_query_txn, + user_id, + ) + + def _get_e2e_device_keys_for_federation_query_txn( + self, txn: LoggingTransaction, user_id: str + ) -> Tuple[int, List[JsonDict]]: + now_stream_id = self.get_device_stream_token() + + devices = self._get_e2e_device_keys_txn(txn, [(user_id, None)]) + + if devices: + user_devices = devices[user_id] + results = [] + for device_id, device in user_devices.items(): + result = {"device_id": device_id} + + key_json = device.get("key_json", None) + if key_json: + result["keys"] = db_to_json(key_json) + + if "signatures" in device: + for sig_user_id, sigs in device["signatures"].items(): + result["keys"].setdefault("signatures", {}).setdefault( + sig_user_id, {} + ).update(sigs) + + device_display_name = device.get("device_display_name", None) + if device_display_name: + result["device_display_name"] = device_display_name + + results.append(result) + + return now_stream_id, results + + return now_stream_id, [] + @trace async def get_e2e_device_keys_for_cs_api( self, query_list: List[Tuple[str, Optional[str]]] @@ -533,6 +579,11 @@ def _get_all_user_signature_changes_for_remotes_txn(txn): _get_all_user_signature_changes_for_remotes_txn, ) + @abc.abstractmethod + def get_device_stream_token(self) -> int: + """Get the current stream id from the _device_list_id_gen""" + ... + class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys): From 318245eaa6d37a27ca72168356198fdd90abfbb7 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 1 Sep 2020 08:16:58 -0400 Subject: [PATCH 30/39] Do not install setuptools 50.0. (#8212) This is due to compatibility issues with old Python versions. --- INSTALL.md | 2 +- changelog.d/8212.bugfix | 1 + synapse/python_dependencies.py | 4 ++++ 3 files changed, 6 insertions(+), 1 deletion(-) create mode 100644 changelog.d/8212.bugfix diff --git a/INSTALL.md b/INSTALL.md index 22f7b7c0293c..bdb7769fe929 100644 --- a/INSTALL.md +++ b/INSTALL.md @@ -73,7 +73,7 @@ mkdir -p ~/synapse virtualenv -p python3 ~/synapse/env source ~/synapse/env/bin/activate pip install --upgrade pip -pip install --upgrade setuptools +pip install --upgrade setuptools!=50.0 # setuptools==50.0 fails on some older Python versions pip install matrix-synapse ``` diff --git a/changelog.d/8212.bugfix b/changelog.d/8212.bugfix new file mode 100644 index 000000000000..0f8c0aed9204 --- /dev/null +++ b/changelog.d/8212.bugfix @@ -0,0 +1 @@ +Do not install setuptools 50.0. It can lead to a broken configuration on some older Python versions. diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 2d995ec456a5..d666f2267442 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -74,6 +74,10 @@ "Jinja2>=2.9", "bleach>=1.4.3", "typing-extensions>=3.7.4", + # setuptools is required by a variety of dependencies, unfortunately version + # 50.0 is incompatible with older Python versions, see + # https://github.com/pypa/setuptools/issues/2352 + "setuptools!=50.0", ] CONDITIONAL_REQUIREMENTS = { From bbb3c8641ca1214702dbddd88acfe81cc4fc44ae Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 1 Sep 2020 13:36:25 +0100 Subject: [PATCH 31/39] Make MultiWriterIDGenerator work for streams that use negative stream IDs (#8203) This is so that we can use it for the backfill events stream. --- changelog.d/8203.misc | 1 + synapse/storage/util/id_generators.py | 39 +++++++--- tests/storage/test_id_generators.py | 105 ++++++++++++++++++++++++++ 3 files changed, 134 insertions(+), 11 deletions(-) create mode 100644 changelog.d/8203.misc diff --git a/changelog.d/8203.misc b/changelog.d/8203.misc new file mode 100644 index 000000000000..9fe2224aaac2 --- /dev/null +++ b/changelog.d/8203.misc @@ -0,0 +1 @@ +Make `MultiWriterIDGenerator` work for streams that use negative values. diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index b27a4843d010..9f3d23f0a524 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -185,6 +185,8 @@ class MultiWriterIdGenerator: id_column: Column that stores the stream ID. sequence_name: The name of the postgres sequence used to generate new IDs. + positive: Whether the IDs are positive (true) or negative (false). + When using negative IDs we go backwards from -1 to -2, -3, etc. """ def __init__( @@ -196,13 +198,19 @@ def __init__( instance_column: str, id_column: str, sequence_name: str, + positive: bool = True, ): self._db = db self._instance_name = instance_name + self._positive = positive + self._return_factor = 1 if positive else -1 # We lock as some functions may be called from DB threads. self._lock = threading.Lock() + # Note: If we are a negative stream then we still store all the IDs as + # positive to make life easier for us, and simply negate the IDs when we + # return them. self._current_positions = self._load_current_ids( db_conn, table, instance_column, id_column ) @@ -233,13 +241,16 @@ def __init__( def _load_current_ids( self, db_conn, table: str, instance_column: str, id_column: str ) -> Dict[str, int]: + # If positive stream aggregate via MAX. For negative stream use MIN + # *and* negate the result to get a positive number. sql = """ - SELECT %(instance)s, MAX(%(id)s) FROM %(table)s + SELECT %(instance)s, %(agg)s(%(id)s) FROM %(table)s GROUP BY %(instance)s """ % { "instance": instance_column, "id": id_column, "table": table, + "agg": "MAX" if self._positive else "-MIN", } cur = db_conn.cursor() @@ -269,15 +280,16 @@ async def get_next(self): # Assert the fetched ID is actually greater than what we currently # believe the ID to be. If not, then the sequence and table have got # out of sync somehow. - assert self.get_current_token_for_writer(self._instance_name) < next_id - with self._lock: + assert self._current_positions.get(self._instance_name, 0) < next_id + self._unfinished_ids.add(next_id) @contextlib.contextmanager def manager(): try: - yield next_id + # Multiply by the return factor so that the ID has correct sign. + yield self._return_factor * next_id finally: self._mark_id_as_finished(next_id) @@ -296,15 +308,15 @@ async def get_next_mult(self, n: int): # Assert the fetched ID is actually greater than any ID we've already # seen. If not, then the sequence and table have got out of sync # somehow. - assert max(self.get_positions().values(), default=0) < min(next_ids) - with self._lock: + assert max(self._current_positions.values(), default=0) < min(next_ids) + self._unfinished_ids.update(next_ids) @contextlib.contextmanager def manager(): try: - yield next_ids + yield [self._return_factor * i for i in next_ids] finally: for i in next_ids: self._mark_id_as_finished(i) @@ -327,7 +339,7 @@ def get_next_txn(self, txn: LoggingTransaction): txn.call_after(self._mark_id_as_finished, next_id) txn.call_on_exception(self._mark_id_as_finished, next_id) - return next_id + return self._return_factor * next_id def _mark_id_as_finished(self, next_id: int): """The ID has finished being processed so we should advance the @@ -359,20 +371,25 @@ def get_current_token_for_writer(self, instance_name: str) -> int: """ with self._lock: - return self._current_positions.get(instance_name, 0) + return self._return_factor * self._current_positions.get(instance_name, 0) def get_positions(self) -> Dict[str, int]: """Get a copy of the current positon map. """ with self._lock: - return dict(self._current_positions) + return { + name: self._return_factor * i + for name, i in self._current_positions.items() + } def advance(self, instance_name: str, new_id: int): """Advance the postion of the named writer to the given ID, if greater than existing entry. """ + new_id *= self._return_factor + with self._lock: self._current_positions[instance_name] = max( new_id, self._current_positions.get(instance_name, 0) @@ -390,7 +407,7 @@ def get_persisted_upto_position(self) -> int: """ with self._lock: - return self._persisted_upto_position + return self._return_factor * self._persisted_upto_position def _add_persisted_position(self, new_id: int): """Record that we have persisted a position. diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py index 14ce21c7869d..f0a8e32f1eaf 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py @@ -264,3 +264,108 @@ def test_get_persisted_upto_position_get_next(self): # We assume that so long as `get_next` does correctly advance the # `persisted_upto_position` in this case, then it will be correct in the # other cases that are tested above (since they'll hit the same code). + + +class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): + """Tests MultiWriterIdGenerator that produce *negative* stream IDs. + """ + + if not USE_POSTGRES_FOR_TESTS: + skip = "Requires Postgres" + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.db_pool = self.store.db_pool # type: DatabasePool + + self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) + + def _setup_db(self, txn): + txn.execute("CREATE SEQUENCE foobar_seq") + txn.execute( + """ + CREATE TABLE foobar ( + stream_id BIGINT NOT NULL, + instance_name TEXT NOT NULL, + data TEXT + ); + """ + ) + + def _create_id_generator(self, instance_name="master") -> MultiWriterIdGenerator: + def _create(conn): + return MultiWriterIdGenerator( + conn, + self.db_pool, + instance_name=instance_name, + table="foobar", + instance_column="instance_name", + id_column="stream_id", + sequence_name="foobar_seq", + positive=False, + ) + + return self.get_success(self.db_pool.runWithConnection(_create)) + + def _insert_row(self, instance_name: str, stream_id: int): + """Insert one row as the given instance with given stream_id. + """ + + def _insert(txn): + txn.execute( + "INSERT INTO foobar VALUES (?, ?)", (stream_id, instance_name,), + ) + + self.get_success(self.db_pool.runInteraction("_insert_row", _insert)) + + def test_single_instance(self): + """Test that reads and writes from a single process are handled + correctly. + """ + id_gen = self._create_id_generator() + + with self.get_success(id_gen.get_next()) as stream_id: + self._insert_row("master", stream_id) + + self.assertEqual(id_gen.get_positions(), {"master": -1}) + self.assertEqual(id_gen.get_current_token_for_writer("master"), -1) + self.assertEqual(id_gen.get_persisted_upto_position(), -1) + + with self.get_success(id_gen.get_next_mult(3)) as stream_ids: + for stream_id in stream_ids: + self._insert_row("master", stream_id) + + self.assertEqual(id_gen.get_positions(), {"master": -4}) + self.assertEqual(id_gen.get_current_token_for_writer("master"), -4) + self.assertEqual(id_gen.get_persisted_upto_position(), -4) + + # Test loading from DB by creating a second ID gen + second_id_gen = self._create_id_generator() + + self.assertEqual(second_id_gen.get_positions(), {"master": -4}) + self.assertEqual(second_id_gen.get_current_token_for_writer("master"), -4) + self.assertEqual(second_id_gen.get_persisted_upto_position(), -4) + + def test_multiple_instance(self): + """Tests that having multiple instances that get advanced over + federation works corretly. + """ + id_gen_1 = self._create_id_generator("first") + id_gen_2 = self._create_id_generator("second") + + with self.get_success(id_gen_1.get_next()) as stream_id: + self._insert_row("first", stream_id) + id_gen_2.advance("first", stream_id) + + self.assertEqual(id_gen_1.get_positions(), {"first": -1}) + self.assertEqual(id_gen_2.get_positions(), {"first": -1}) + self.assertEqual(id_gen_1.get_persisted_upto_position(), -1) + self.assertEqual(id_gen_2.get_persisted_upto_position(), -1) + + with self.get_success(id_gen_2.get_next()) as stream_id: + self._insert_row("second", stream_id) + id_gen_1.advance("second", stream_id) + + self.assertEqual(id_gen_1.get_positions(), {"first": -1, "second": -2}) + self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -2}) + self.assertEqual(id_gen_1.get_persisted_upto_position(), -2) + self.assertEqual(id_gen_2.get_persisted_upto_position(), -2) From da77520cd1c414c9341da287967feb1bab14cbec Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 1 Sep 2020 08:39:04 -0400 Subject: [PATCH 32/39] Convert additional databases to async/await part 2 (#8200) --- changelog.d/8200.misc | 1 + synapse/events/builder.py | 19 ++++--- synapse/handlers/message.py | 13 ++--- synapse/handlers/room_member.py | 12 +---- synapse/storage/databases/main/client_ips.py | 4 +- synapse/storage/databases/main/directory.py | 6 +-- synapse/storage/databases/main/filtering.py | 5 +- synapse/storage/databases/main/openid.py | 8 ++- synapse/storage/databases/main/profile.py | 6 ++- synapse/storage/databases/main/push_rule.py | 10 ++-- synapse/storage/databases/main/room.py | 49 +++++++++++-------- synapse/storage/databases/main/signatures.py | 40 ++++++++++++--- synapse/storage/databases/main/ui_auth.py | 4 +- .../databases/main/user_erasure_store.py | 8 +-- tests/test_utils/event_injection.py | 7 ++- 15 files changed, 111 insertions(+), 81 deletions(-) create mode 100644 changelog.d/8200.misc diff --git a/changelog.d/8200.misc b/changelog.d/8200.misc new file mode 100644 index 000000000000..dfe4c03171d6 --- /dev/null +++ b/changelog.d/8200.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/events/builder.py b/synapse/events/builder.py index 9ed24380dd26..7878cd7044a1 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import Any, Dict, List, Optional, Tuple, Union import attr from nacl.signing import SigningKey @@ -97,14 +97,14 @@ def state_key(self): def is_state(self): return self._state_key is not None - async def build(self, prev_event_ids): + async def build(self, prev_event_ids: List[str]) -> EventBase: """Transform into a fully signed and hashed event Args: - prev_event_ids (list[str]): The event IDs to use as the prev events + prev_event_ids: The event IDs to use as the prev events Returns: - FrozenEvent + The signed and hashed event. """ state_ids = await self._state.get_current_state_ids( @@ -114,8 +114,13 @@ async def build(self, prev_event_ids): format_version = self.room_version.event_format if format_version == EventFormatVersions.V1: - auth_events = await self._store.add_event_hashes(auth_ids) - prev_events = await self._store.add_event_hashes(prev_event_ids) + # The types of auth/prev events changes between event versions. + auth_events = await self._store.add_event_hashes( + auth_ids + ) # type: Union[List[str], List[Tuple[str, Dict[str, str]]]] + prev_events = await self._store.add_event_hashes( + prev_event_ids + ) # type: Union[List[str], List[Tuple[str, Dict[str, str]]]] else: auth_events = auth_ids prev_events = prev_event_ids @@ -138,7 +143,7 @@ async def build(self, prev_event_ids): "unsigned": self.unsigned, "depth": depth, "prev_state": [], - } + } # type: Dict[str, Any] if self.is_state(): event_dict["state_key"] = self._state_key diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 9d0c38f4df02..72bb638167e2 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -49,14 +49,7 @@ from synapse.replication.http.send_event import ReplicationSendEventRestServlet from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.state import StateFilter -from synapse.types import ( - Collection, - Requester, - RoomAlias, - StreamToken, - UserID, - create_requester, -) +from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester from synapse.util import json_decoder from synapse.util.async_helpers import Linearizer from synapse.util.frozenutils import frozendict_json_encoder @@ -446,7 +439,7 @@ async def create_event( event_dict: dict, token_id: Optional[str] = None, txn_id: Optional[str] = None, - prev_event_ids: Optional[Collection[str]] = None, + prev_event_ids: Optional[List[str]] = None, require_consent: bool = True, ) -> Tuple[EventBase, EventContext]: """ @@ -786,7 +779,7 @@ async def create_new_client_event( self, builder: EventBuilder, requester: Optional[Requester] = None, - prev_event_ids: Optional[Collection[str]] = None, + prev_event_ids: Optional[List[str]] = None, ) -> Tuple[EventBase, EventContext]: """Create a new event for a local client diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index cae4d013b8ae..a7962b0ada39 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -38,15 +38,7 @@ from synapse.events.snapshot import EventContext from synapse.events.validator import EventValidator from synapse.storage.roommember import RoomsForUser -from synapse.types import ( - Collection, - JsonDict, - Requester, - RoomAlias, - RoomID, - StateMap, - UserID, -) +from synapse.types import JsonDict, Requester, RoomAlias, RoomID, StateMap, UserID from synapse.util.async_helpers import Linearizer from synapse.util.distributor import user_joined_room, user_left_room @@ -184,7 +176,7 @@ async def _local_membership_update( target: UserID, room_id: str, membership: str, - prev_event_ids: Collection[str], + prev_event_ids: List[str], txn_id: Optional[str] = None, ratelimit: bool = True, content: Optional[dict] = None, diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py index 216a5925fc37..c2fc847fbc78 100644 --- a/synapse/storage/databases/main/client_ips.py +++ b/synapse/storage/databases/main/client_ips.py @@ -396,7 +396,7 @@ async def insert_client_ip( self._batch_row_update[key] = (user_agent, device_id, now) @wrap_as_background_process("update_client_ips") - def _update_client_ips_batch(self): + async def _update_client_ips_batch(self) -> None: # If the DB pool has already terminated, don't try updating if not self.db_pool.is_running(): @@ -405,7 +405,7 @@ def _update_client_ips_batch(self): to_update = self._batch_row_update self._batch_row_update = {} - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update ) diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py index 405b5eafa579..e5060d4c4685 100644 --- a/synapse/storage/databases/main/directory.py +++ b/synapse/storage/databases/main/directory.py @@ -159,9 +159,9 @@ def _delete_room_alias_txn(self, txn, room_alias: RoomAlias) -> str: return room_id - def update_aliases_for_room( + async def update_aliases_for_room( self, old_room_id: str, new_room_id: str, creator: Optional[str] = None, - ): + ) -> None: """Repoint all of the aliases for a given room, to a different room. Args: @@ -189,6 +189,6 @@ def _update_aliases_for_room_txn(txn): txn, self.get_aliases_for_room, (new_room_id,) ) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "_update_aliases_for_room_txn", _update_aliases_for_room_txn ) diff --git a/synapse/storage/databases/main/filtering.py b/synapse/storage/databases/main/filtering.py index 45a1760170bc..d2f5b9a50220 100644 --- a/synapse/storage/databases/main/filtering.py +++ b/synapse/storage/databases/main/filtering.py @@ -17,6 +17,7 @@ from synapse.api.errors import Codes, SynapseError from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.types import JsonDict from synapse.util.caches.descriptors import cached @@ -40,7 +41,7 @@ async def get_user_filter(self, user_localpart, filter_id): return db_to_json(def_json) - def add_user_filter(self, user_localpart, user_filter): + async def add_user_filter(self, user_localpart: str, user_filter: JsonDict) -> str: def_json = encode_canonical_json(user_filter) # Need an atomic transaction to SELECT the maximal ID so far then @@ -71,4 +72,4 @@ def _do_txn(txn): return filter_id - return self.db_pool.runInteraction("add_user_filter", _do_txn) + return await self.db_pool.runInteraction("add_user_filter", _do_txn) diff --git a/synapse/storage/databases/main/openid.py b/synapse/storage/databases/main/openid.py index 4db8949da767..2aac64901bb4 100644 --- a/synapse/storage/databases/main/openid.py +++ b/synapse/storage/databases/main/openid.py @@ -1,3 +1,5 @@ +from typing import Optional + from synapse.storage._base import SQLBaseStore @@ -15,7 +17,9 @@ async def insert_open_id_token( desc="insert_open_id_token", ) - def get_user_id_for_open_id_token(self, token, ts_now_ms): + async def get_user_id_for_open_id_token( + self, token: str, ts_now_ms: int + ) -> Optional[str]: def get_user_id_for_token_txn(txn): sql = ( "SELECT user_id FROM open_id_tokens" @@ -30,6 +34,6 @@ def get_user_id_for_token_txn(txn): else: return rows[0][0] - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_user_id_for_token", get_user_id_for_token_txn ) diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py index 301875a6725a..d2e0685e9e66 100644 --- a/synapse/storage/databases/main/profile.py +++ b/synapse/storage/databases/main/profile.py @@ -138,7 +138,9 @@ async def maybe_delete_remote_profile_cache(self, user_id): desc="delete_remote_profile_cache", ) - def get_remote_profile_cache_entries_that_expire(self, last_checked): + async def get_remote_profile_cache_entries_that_expire( + self, last_checked: int + ) -> Dict[str, str]: """Get all users who haven't been checked since `last_checked` """ @@ -153,7 +155,7 @@ def _get_remote_profile_cache_entries_that_expire_txn(txn): return self.db_pool.cursor_to_dict(txn) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_remote_profile_cache_entries_that_expire", _get_remote_profile_cache_entries_that_expire_txn, ) diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 2fb5b02d7d00..0de802a86bde 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -18,8 +18,6 @@ import logging from typing import List, Tuple, Union -from twisted.internet import defer - from synapse.push.baserules import list_with_base_rules from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.storage._base import SQLBaseStore, db_to_json @@ -149,9 +147,11 @@ async def get_push_rules_enabled_for_user(self, user_id): ) return {r["rule_id"]: False if r["enabled"] == 0 else True for r in results} - def have_push_rules_changed_for_user(self, user_id, last_id): + async def have_push_rules_changed_for_user( + self, user_id: str, last_id: int + ) -> bool: if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id): - return defer.succeed(False) + return False else: def have_push_rules_changed_txn(txn): @@ -163,7 +163,7 @@ def have_push_rules_changed_txn(txn): (count,) = txn.fetchone() return bool(count) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "have_push_rules_changed", have_push_rules_changed_txn ) diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index a92641c33927..717df97301bb 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -89,7 +89,7 @@ async def get_room(self, room_id: str) -> dict: allow_none=True, ) - def get_room_with_stats(self, room_id: str): + async def get_room_with_stats(self, room_id: str) -> Optional[Dict[str, Any]]: """Retrieve room with statistics. Args: @@ -121,7 +121,7 @@ def get_room_with_stats_txn(txn, room_id): res["public"] = bool(res["public"]) return res - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_room_with_stats", get_room_with_stats_txn, room_id ) @@ -133,13 +133,17 @@ async def get_public_room_ids(self) -> List[str]: desc="get_public_room_ids", ) - def count_public_rooms(self, network_tuple, ignore_non_federatable): + async def count_public_rooms( + self, + network_tuple: Optional[ThirdPartyInstanceID], + ignore_non_federatable: bool, + ) -> int: """Counts the number of public rooms as tracked in the room_stats_current and room_stats_state table. Args: - network_tuple (ThirdPartyInstanceID|None) - ignore_non_federatable (bool): If true filters out non-federatable rooms + network_tuple + ignore_non_federatable: If true filters out non-federatable rooms """ def _count_public_rooms_txn(txn): @@ -183,7 +187,7 @@ def _count_public_rooms_txn(txn): txn.execute(sql, query_args) return txn.fetchone()[0] - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "count_public_rooms", _count_public_rooms_txn ) @@ -586,15 +590,14 @@ def get_retention_policy_for_room_txn(txn): return row - def get_media_mxcs_in_room(self, room_id): + async def get_media_mxcs_in_room(self, room_id: str) -> Tuple[List[str], List[str]]: """Retrieves all the local and remote media MXC URIs in a given room Args: - room_id (str) + room_id Returns: - The local and remote media as a lists of tuples where the key is - the hostname and the value is the media ID. + The local and remote media as a lists of the media IDs. """ def _get_media_mxcs_in_room_txn(txn): @@ -610,11 +613,13 @@ def _get_media_mxcs_in_room_txn(txn): return local_media_mxcs, remote_media_mxcs - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_media_ids_in_room", _get_media_mxcs_in_room_txn ) - def quarantine_media_ids_in_room(self, room_id, quarantined_by): + async def quarantine_media_ids_in_room( + self, room_id: str, quarantined_by: str + ) -> int: """For a room loops through all events with media and quarantines the associated media """ @@ -627,7 +632,7 @@ def _quarantine_media_in_room_txn(txn): txn, local_mxcs, remote_mxcs, quarantined_by ) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "quarantine_media_in_room", _quarantine_media_in_room_txn ) @@ -690,9 +695,9 @@ def _get_media_mxcs_in_room_txn(self, txn, room_id): return local_media_mxcs, remote_media_mxcs - def quarantine_media_by_id( + async def quarantine_media_by_id( self, server_name: str, media_id: str, quarantined_by: str, - ): + ) -> int: """quarantines a single local or remote media id Args: @@ -711,11 +716,13 @@ def _quarantine_media_by_id_txn(txn): txn, local_mxcs, remote_mxcs, quarantined_by ) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "quarantine_media_by_user", _quarantine_media_by_id_txn ) - def quarantine_media_ids_by_user(self, user_id: str, quarantined_by: str): + async def quarantine_media_ids_by_user( + self, user_id: str, quarantined_by: str + ) -> int: """quarantines all local media associated with a single user Args: @@ -727,7 +734,7 @@ def _quarantine_media_by_user_txn(txn): local_media_ids = self._get_media_ids_by_user_txn(txn, user_id) return self._quarantine_media_txn(txn, local_media_ids, [], quarantined_by) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "quarantine_media_by_user", _quarantine_media_by_user_txn ) @@ -1284,8 +1291,8 @@ def set_room_is_public_appservice_txn(txn, next_id): ) self.hs.get_notifier().on_new_replication_data() - def get_room_count(self): - """Retrieve a list of all rooms + async def get_room_count(self) -> int: + """Retrieve the total number of rooms. """ def f(txn): @@ -1294,7 +1301,7 @@ def f(txn): row = txn.fetchone() return row[0] or 0 - return self.db_pool.runInteraction("get_rooms", f) + return await self.db_pool.runInteraction("get_rooms", f) async def add_event_report( self, diff --git a/synapse/storage/databases/main/signatures.py b/synapse/storage/databases/main/signatures.py index be191dd8708c..c8c67953e47b 100644 --- a/synapse/storage/databases/main/signatures.py +++ b/synapse/storage/databases/main/signatures.py @@ -13,9 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Dict, Iterable, List, Tuple + from unpaddedbase64 import encode_base64 from synapse.storage._base import SQLBaseStore +from synapse.storage.types import Cursor from synapse.util.caches.descriptors import cached, cachedList @@ -29,16 +32,37 @@ def get_event_reference_hash(self, event_id): @cachedList( cached_method_name="get_event_reference_hash", list_name="event_ids", num_args=1 ) - def get_event_reference_hashes(self, event_ids): + async def get_event_reference_hashes( + self, event_ids: Iterable[str] + ) -> Dict[str, Dict[str, bytes]]: + """Get all hashes for given events. + + Args: + event_ids: The event IDs to get hashes for. + + Returns: + A mapping of event ID to a mapping of algorithm to hash. + """ + def f(txn): return { event_id: self._get_event_reference_hashes_txn(txn, event_id) for event_id in event_ids } - return self.db_pool.runInteraction("get_event_reference_hashes", f) + return await self.db_pool.runInteraction("get_event_reference_hashes", f) - async def add_event_hashes(self, event_ids): + async def add_event_hashes( + self, event_ids: Iterable[str] + ) -> List[Tuple[str, Dict[str, str]]]: + """ + + Args: + event_ids: The event IDs + + Returns: + A list of tuples of event ID and a mapping of algorithm to base-64 encoded hash. + """ hashes = await self.get_event_reference_hashes(event_ids) hashes = { e_id: {k: encode_base64(v) for k, v in h.items() if k == "sha256"} @@ -47,13 +71,15 @@ async def add_event_hashes(self, event_ids): return list(hashes.items()) - def _get_event_reference_hashes_txn(self, txn, event_id): + def _get_event_reference_hashes_txn( + self, txn: Cursor, event_id: str + ) -> Dict[str, bytes]: """Get all the hashes for a given PDU. Args: - txn (cursor): - event_id (str): Id for the Event. + txn: + event_id: Id for the Event. Returns: - A dict[unicode, bytes] of algorithm -> hash. + A mapping of algorithm -> hash. """ query = ( "SELECT algorithm, hash" diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py index 9eef8e57c5b9..b89668d561b1 100644 --- a/synapse/storage/databases/main/ui_auth.py +++ b/synapse/storage/databases/main/ui_auth.py @@ -290,7 +290,7 @@ async def get_user_agents_ips_to_ui_auth_session( class UIAuthStore(UIAuthWorkerStore): - def delete_old_ui_auth_sessions(self, expiration_time: int): + async def delete_old_ui_auth_sessions(self, expiration_time: int) -> None: """ Remove sessions which were last used earlier than the expiration time. @@ -299,7 +299,7 @@ def delete_old_ui_auth_sessions(self, expiration_time: int): This is an epoch time in milliseconds. """ - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "delete_old_ui_auth_sessions", self._delete_old_ui_auth_sessions_txn, expiration_time, diff --git a/synapse/storage/databases/main/user_erasure_store.py b/synapse/storage/databases/main/user_erasure_store.py index e3547e53b3eb..2f7c95fc7431 100644 --- a/synapse/storage/databases/main/user_erasure_store.py +++ b/synapse/storage/databases/main/user_erasure_store.py @@ -66,7 +66,7 @@ async def are_users_erased(self, user_ids): class UserErasureStore(UserErasureWorkerStore): - def mark_user_erased(self, user_id: str) -> None: + async def mark_user_erased(self, user_id: str) -> None: """Indicate that user_id wishes their message history to be erased. Args: @@ -84,9 +84,9 @@ def f(txn): self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,)) - return self.db_pool.runInteraction("mark_user_erased", f) + await self.db_pool.runInteraction("mark_user_erased", f) - def mark_user_not_erased(self, user_id: str) -> None: + async def mark_user_not_erased(self, user_id: str) -> None: """Indicate that user_id is no longer erased. Args: @@ -106,4 +106,4 @@ def f(txn): self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,)) - return self.db_pool.runInteraction("mark_user_not_erased", f) + await self.db_pool.runInteraction("mark_user_not_erased", f) diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py index 8522c6fc0910..fb1ca9033670 100644 --- a/tests/test_utils/event_injection.py +++ b/tests/test_utils/event_injection.py @@ -13,14 +13,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple +from typing import List, Optional, Tuple import synapse.server from synapse.api.constants import EventTypes from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import EventBase from synapse.events.snapshot import EventContext -from synapse.types import Collection """ Utility functions for poking events into the storage of the server under test. @@ -58,7 +57,7 @@ async def inject_member_event( async def inject_event( hs: synapse.server.HomeServer, room_version: Optional[str] = None, - prev_event_ids: Optional[Collection[str]] = None, + prev_event_ids: Optional[List[str]] = None, **kwargs ) -> EventBase: """Inject a generic event into a room @@ -80,7 +79,7 @@ async def inject_event( async def create_event( hs: synapse.server.HomeServer, room_version: Optional[str] = None, - prev_event_ids: Optional[Collection[str]] = None, + prev_event_ids: Optional[List[str]] = None, **kwargs ) -> Tuple[EventBase, EventContext]: if room_version is None: From 5bf8e5f55b49f9e46a7fe7d7872e6b16d38bffd3 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 1 Sep 2020 09:15:22 -0400 Subject: [PATCH 33/39] Convert the well known resolver to async (#8214) --- changelog.d/8214.misc | 1 + mypy.ini | 1 + .../federation/matrix_federation_agent.py | 4 +- .../http/federation/well_known_resolver.py | 57 ++++++++++--------- .../test_matrix_federation_agent.py | 24 ++++++-- 5 files changed, 53 insertions(+), 34 deletions(-) create mode 100644 changelog.d/8214.misc diff --git a/changelog.d/8214.misc b/changelog.d/8214.misc new file mode 100644 index 000000000000..e26764dea15a --- /dev/null +++ b/changelog.d/8214.misc @@ -0,0 +1 @@ + Convert various parts of the codebase to async/await. diff --git a/mypy.ini b/mypy.ini index 4213e31b0320..21c6f523a0b9 100644 --- a/mypy.ini +++ b/mypy.ini @@ -28,6 +28,7 @@ files = synapse/handlers/saml_handler.py, synapse/handlers/sync.py, synapse/handlers/ui_auth, + synapse/http/federation/well_known_resolver.py, synapse/http/server.py, synapse/http/site.py, synapse/logging/, diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py index 369bf9c2fc37..782d39d4ca16 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py @@ -134,8 +134,8 @@ def request(self, method, uri, headers=None, bodyProducer=None): and not _is_ip_literal(parsed_uri.hostname) and not parsed_uri.port ): - well_known_result = yield self._well_known_resolver.get_well_known( - parsed_uri.hostname + well_known_result = yield defer.ensureDeferred( + self._well_known_resolver.get_well_known(parsed_uri.hostname) ) delegated_server = well_known_result.delegated_server diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py index f794315debde..cdb6bec56e8e 100644 --- a/synapse/http/federation/well_known_resolver.py +++ b/synapse/http/federation/well_known_resolver.py @@ -16,6 +16,7 @@ import logging import random import time +from typing import Callable, Dict, Optional, Tuple import attr @@ -23,6 +24,7 @@ from twisted.web.client import RedirectAgent, readBody from twisted.web.http import stringToDatetime from twisted.web.http_headers import Headers +from twisted.web.iweb import IResponse from synapse.logging.context import make_deferred_yieldable from synapse.util import Clock, json_decoder @@ -99,15 +101,14 @@ def __init__( self._well_known_agent = RedirectAgent(agent) self.user_agent = user_agent - @defer.inlineCallbacks - def get_well_known(self, server_name): + async def get_well_known(self, server_name: bytes) -> WellKnownLookupResult: """Attempt to fetch and parse a .well-known file for the given server Args: - server_name (bytes): name of the server, from the requested url + server_name: name of the server, from the requested url Returns: - Deferred[WellKnownLookupResult]: The result of the lookup + The result of the lookup """ try: prev_result, expiry, ttl = self._well_known_cache.get_with_expiry( @@ -124,7 +125,9 @@ def get_well_known(self, server_name): # requests for the same server in parallel? try: with Measure(self._clock, "get_well_known"): - result, cache_period = yield self._fetch_well_known(server_name) + result, cache_period = await self._fetch_well_known( + server_name + ) # type: Tuple[Optional[bytes], float] except _FetchWellKnownFailure as e: if prev_result and e.temporary: @@ -153,18 +156,17 @@ def get_well_known(self, server_name): return WellKnownLookupResult(delegated_server=result) - @defer.inlineCallbacks - def _fetch_well_known(self, server_name): + async def _fetch_well_known(self, server_name: bytes) -> Tuple[bytes, float]: """Actually fetch and parse a .well-known, without checking the cache Args: - server_name (bytes): name of the server, from the requested url + server_name: name of the server, from the requested url Raises: _FetchWellKnownFailure if we fail to lookup a result Returns: - Deferred[Tuple[bytes,int]]: The lookup result and cache period. + The lookup result and cache period. """ had_valid_well_known = self._had_valid_well_known_cache.get(server_name, False) @@ -172,7 +174,7 @@ def _fetch_well_known(self, server_name): # We do this in two steps to differentiate between possibly transient # errors (e.g. can't connect to host, 503 response) and more permenant # errors (such as getting a 404 response). - response, body = yield self._make_well_known_request( + response, body = await self._make_well_known_request( server_name, retry=had_valid_well_known ) @@ -215,20 +217,20 @@ def _fetch_well_known(self, server_name): return result, cache_period - @defer.inlineCallbacks - def _make_well_known_request(self, server_name, retry): + async def _make_well_known_request( + self, server_name: bytes, retry: bool + ) -> Tuple[IResponse, bytes]: """Make the well known request. This will retry the request if requested and it fails (with unable to connect or receives a 5xx error). Args: - server_name (bytes) - retry (bool): Whether to retry the request if it fails. + server_name: name of the server, from the requested url + retry: Whether to retry the request if it fails. Returns: - Deferred[tuple[IResponse, bytes]] Returns the response object and - body. Response may be a non-200 response. + Returns the response object and body. Response may be a non-200 response. """ uri = b"https://%s/.well-known/matrix/server" % (server_name,) uri_str = uri.decode("ascii") @@ -243,12 +245,12 @@ def _make_well_known_request(self, server_name, retry): logger.info("Fetching %s", uri_str) try: - response = yield make_deferred_yieldable( + response = await make_deferred_yieldable( self._well_known_agent.request( b"GET", uri, headers=Headers(headers) ) ) - body = yield make_deferred_yieldable(readBody(response)) + body = await make_deferred_yieldable(readBody(response)) if 500 <= response.code < 600: raise Exception("Non-200 response %s" % (response.code,)) @@ -265,21 +267,24 @@ def _make_well_known_request(self, server_name, retry): logger.info("Error fetching %s: %s. Retrying", uri_str, e) # Sleep briefly in the hopes that they come back up - yield self._clock.sleep(0.5) + await self._clock.sleep(0.5) -def _cache_period_from_headers(headers, time_now=time.time): +def _cache_period_from_headers( + headers: Headers, time_now: Callable[[], float] = time.time +) -> Optional[float]: cache_controls = _parse_cache_control(headers) if b"no-store" in cache_controls: return 0 if b"max-age" in cache_controls: - try: - max_age = int(cache_controls[b"max-age"]) - return max_age - except ValueError: - pass + max_age = cache_controls[b"max-age"] + if max_age: + try: + return int(max_age) + except ValueError: + pass expires = headers.getRawHeaders(b"expires") if expires is not None: @@ -295,7 +300,7 @@ def _cache_period_from_headers(headers, time_now=time.time): return None -def _parse_cache_control(headers): +def _parse_cache_control(headers: Headers) -> Dict[bytes, Optional[bytes]]: cache_controls = {} for hdr in headers.getRawHeaders(b"cache-control", []): for directive in hdr.split(b","): diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index 69945a8f98a3..eb78ab412a51 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -972,7 +972,9 @@ def test_idna_srv_target(self): def test_well_known_cache(self): self.reactor.lookups["testserv"] = "1.2.3.4" - fetch_d = self.well_known_resolver.get_well_known(b"testserv") + fetch_d = defer.ensureDeferred( + self.well_known_resolver.get_well_known(b"testserv") + ) # there should be an attempt to connect on port 443 for the .well-known clients = self.reactor.tcpClients @@ -995,7 +997,9 @@ def test_well_known_cache(self): well_known_server.loseConnection() # repeat the request: it should hit the cache - fetch_d = self.well_known_resolver.get_well_known(b"testserv") + fetch_d = defer.ensureDeferred( + self.well_known_resolver.get_well_known(b"testserv") + ) r = self.successResultOf(fetch_d) self.assertEqual(r.delegated_server, b"target-server") @@ -1003,7 +1007,9 @@ def test_well_known_cache(self): self.reactor.pump((1000.0,)) # now it should connect again - fetch_d = self.well_known_resolver.get_well_known(b"testserv") + fetch_d = defer.ensureDeferred( + self.well_known_resolver.get_well_known(b"testserv") + ) self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) @@ -1026,7 +1032,9 @@ def test_well_known_cache_with_temp_failure(self): self.reactor.lookups["testserv"] = "1.2.3.4" - fetch_d = self.well_known_resolver.get_well_known(b"testserv") + fetch_d = defer.ensureDeferred( + self.well_known_resolver.get_well_known(b"testserv") + ) # there should be an attempt to connect on port 443 for the .well-known clients = self.reactor.tcpClients @@ -1052,7 +1060,9 @@ def test_well_known_cache_with_temp_failure(self): # another lookup. self.reactor.pump((900.0,)) - fetch_d = self.well_known_resolver.get_well_known(b"testserv") + fetch_d = defer.ensureDeferred( + self.well_known_resolver.get_well_known(b"testserv") + ) # The resolver may retry a few times, so fonx all requests that come along attempts = 0 @@ -1082,7 +1092,9 @@ def test_well_known_cache_with_temp_failure(self): self.reactor.pump((10000.0,)) # Repated the request, this time it should fail if the lookup fails. - fetch_d = self.well_known_resolver.get_well_known(b"testserv") + fetch_d = defer.ensureDeferred( + self.well_known_resolver.get_well_known(b"testserv") + ) clients = self.reactor.tcpClients (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) From 54f8d73c005cf0401d05fc90e857da253f9d1168 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 1 Sep 2020 09:21:48 -0400 Subject: [PATCH 34/39] Convert additional databases to async/await (#8199) --- changelog.d/8199.misc | 1 + synapse/storage/databases/main/__init__.py | 50 +++++---- synapse/storage/databases/main/devices.py | 38 +++---- .../storage/databases/main/events_worker.py | 48 ++++---- .../storage/databases/main/purge_events.py | 30 ++--- synapse/storage/databases/main/receipts.py | 14 ++- synapse/storage/databases/main/relations.py | 103 ++++++++---------- 7 files changed, 147 insertions(+), 137 deletions(-) create mode 100644 changelog.d/8199.misc diff --git a/changelog.d/8199.misc b/changelog.d/8199.misc new file mode 100644 index 000000000000..dfe4c03171d6 --- /dev/null +++ b/changelog.d/8199.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index e6536c8456de..99890ffbf32f 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -18,7 +18,7 @@ import calendar import logging import time -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple from synapse.api.constants import PresenceState from synapse.config.homeserver import HomeServerConfig @@ -294,16 +294,16 @@ def _get_active_presence(self, db_conn): return [UserPresenceState(**row) for row in rows] - def count_daily_users(self): + async def count_daily_users(self) -> int: """ Counts the number of users who used this homeserver in the last 24 hours. """ yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "count_daily_users", self._count_users, yesterday ) - def count_monthly_users(self): + async def count_monthly_users(self) -> int: """ Counts the number of users who used this homeserver in the last 30 days. Note this method is intended for phonehome metrics only and is different @@ -311,7 +311,7 @@ def count_monthly_users(self): amongst other things, includes a 3 day grace period before a user counts. """ thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "count_monthly_users", self._count_users, thirty_days_ago ) @@ -330,15 +330,15 @@ def _count_users(self, txn, time_from): (count,) = txn.fetchone() return count - def count_r30_users(self): + async def count_r30_users(self) -> Dict[str, int]: """ Counts the number of 30 day retained users, defined as:- * Users who have created their accounts more than 30 days ago * Where last seen at most 30 days ago * Where account creation and last_seen are > 30 days apart - Returns counts globaly for a given user as well as breaking - by platform + Returns: + A mapping of counts globally as well as broken out by platform. """ def _count_r30_users(txn): @@ -411,7 +411,7 @@ def _count_r30_users(txn): return results - return self.db_pool.runInteraction("count_r30_users", _count_r30_users) + return await self.db_pool.runInteraction("count_r30_users", _count_r30_users) def _get_start_of_day(self): """ @@ -421,7 +421,7 @@ def _get_start_of_day(self): today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0)) return today_start * 1000 - def generate_user_daily_visits(self): + async def generate_user_daily_visits(self) -> None: """ Generates daily visit data for use in cohort/ retention analysis """ @@ -476,7 +476,7 @@ def _generate_user_daily_visits(txn): # frequently self._last_user_visit_update = now - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "generate_user_daily_visits", _generate_user_daily_visits ) @@ -500,22 +500,28 @@ async def get_users(self) -> List[Dict[str, Any]]: desc="get_users", ) - def get_users_paginate( - self, start, limit, user_id=None, name=None, guests=True, deactivated=False - ): + async def get_users_paginate( + self, + start: int, + limit: int, + user_id: Optional[str] = None, + name: Optional[str] = None, + guests: bool = True, + deactivated: bool = False, + ) -> Tuple[List[Dict[str, Any]], int]: """Function to retrieve a paginated list of users from users list. This will return a json list of users and the total number of users matching the filter criteria. Args: - start (int): start number to begin the query from - limit (int): number of rows to retrieve - user_id (string): search for user_id. ignored if name is not None - name (string): search for local part of user_id or display name - guests (bool): whether to in include guest users - deactivated (bool): whether to include deactivated users + start: start number to begin the query from + limit: number of rows to retrieve + user_id: search for user_id. ignored if name is not None + name: search for local part of user_id or display name + guests: whether to in include guest users + deactivated: whether to include deactivated users Returns: - defer.Deferred: resolves to list[dict[str, Any]], int + A tuple of a list of mappings from user to information and a count of total users. """ def get_users_paginate_txn(txn): @@ -558,7 +564,7 @@ def get_users_paginate_txn(txn): users = self.db_pool.cursor_to_dict(txn) return users, count - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_users_paginate_txn", get_users_paginate_txn ) diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index e8379c73c460..a29157d979a3 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -313,9 +313,9 @@ async def _get_device_update_edus_by_remote( return results - def _get_last_device_update_for_remote_user( + async def _get_last_device_update_for_remote_user( self, destination: str, user_id: str, from_stream_id: int - ): + ) -> int: def f(txn): prev_sent_id_sql = """ SELECT coalesce(max(stream_id), 0) as stream_id @@ -326,12 +326,16 @@ def f(txn): rows = txn.fetchall() return rows[0][0] - return self.db_pool.runInteraction("get_last_device_update_for_remote_user", f) + return await self.db_pool.runInteraction( + "get_last_device_update_for_remote_user", f + ) - def mark_as_sent_devices_by_remote(self, destination: str, stream_id: int): + async def mark_as_sent_devices_by_remote( + self, destination: str, stream_id: int + ) -> None: """Mark that updates have successfully been sent to the destination. """ - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "mark_as_sent_devices_by_remote", self._mark_as_sent_devices_by_remote_txn, destination, @@ -684,7 +688,7 @@ async def mark_remote_user_device_cache_as_stale(self, user_id: str) -> None: desc="make_remote_user_device_cache_as_stale", ) - def mark_remote_user_device_list_as_unsubscribed(self, user_id: str): + async def mark_remote_user_device_list_as_unsubscribed(self, user_id: str) -> None: """Mark that we no longer track device lists for remote user. """ @@ -698,7 +702,7 @@ def _mark_remote_user_device_list_as_unsubscribed_txn(txn): txn, self.get_device_list_last_stream_id_for_remote, (user_id,) ) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "mark_remote_user_device_list_as_unsubscribed", _mark_remote_user_device_list_as_unsubscribed_txn, ) @@ -959,9 +963,9 @@ async def update_device( desc="update_device", ) - def update_remote_device_list_cache_entry( + async def update_remote_device_list_cache_entry( self, user_id: str, device_id: str, content: JsonDict, stream_id: int - ): + ) -> None: """Updates a single device in the cache of a remote user's devicelist. Note: assumes that we are the only thread that can be updating this user's @@ -972,11 +976,8 @@ def update_remote_device_list_cache_entry( device_id: ID of decivice being updated content: new data on this device stream_id: the version of the device list - - Returns: - Deferred[None] """ - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "update_remote_device_list_cache_entry", self._update_remote_device_list_cache_entry_txn, user_id, @@ -1028,9 +1029,9 @@ def _update_remote_device_list_cache_entry_txn( lock=False, ) - def update_remote_device_list_cache( + async def update_remote_device_list_cache( self, user_id: str, devices: List[dict], stream_id: int - ): + ) -> None: """Replace the entire cache of the remote user's devices. Note: assumes that we are the only thread that can be updating this user's @@ -1040,11 +1041,8 @@ def update_remote_device_list_cache( user_id: User to update device list for devices: list of device objects supplied over federation stream_id: the version of the device list - - Returns: - Deferred[None] """ - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "update_remote_device_list_cache", self._update_remote_device_list_cache_txn, user_id, @@ -1054,7 +1052,7 @@ def update_remote_device_list_cache( def _update_remote_device_list_cache_txn( self, txn: LoggingTransaction, user_id: str, devices: List[dict], stream_id: int - ): + ) -> None: self.db_pool.simple_delete_txn( txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id} ) diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index e6247d682d8c..a7a73cc3d8bb 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -823,20 +823,24 @@ def _fetch_event_rows(self, txn, event_ids): return event_dict - def _maybe_redact_event_row(self, original_ev, redactions, event_map): + def _maybe_redact_event_row( + self, + original_ev: EventBase, + redactions: Iterable[str], + event_map: Dict[str, EventBase], + ) -> Optional[EventBase]: """Given an event object and a list of possible redacting event ids, determine whether to honour any of those redactions and if so return a redacted event. Args: - original_ev (EventBase): - redactions (iterable[str]): list of event ids of potential redaction events - event_map (dict[str, EventBase]): other events which have been fetched, in - which we can look up the redaaction events. Map from event id to event. + original_ev: The original event. + redactions: list of event ids of potential redaction events + event_map: other events which have been fetched, in which we can + look up the redaaction events. Map from event id to event. Returns: - Deferred[EventBase|None]: if the event should be redacted, a pruned - event object. Otherwise, None. + If the event should be redacted, a pruned event object. Otherwise, None. """ if original_ev.type == "m.room.create": # we choose to ignore redactions of m.room.create events. @@ -946,17 +950,17 @@ def _get_current_state_event_counts_txn(self, txn, room_id): row = txn.fetchone() return row[0] if row else 0 - def get_current_state_event_counts(self, room_id): + async def get_current_state_event_counts(self, room_id: str) -> int: """ Gets the current number of state events in a room. Args: - room_id (str) + room_id: The room ID to query. Returns: - Deferred[int] + The current number of state events. """ - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_current_state_event_counts", self._get_current_state_event_counts_txn, room_id, @@ -991,7 +995,9 @@ def get_current_events_token(self): """The current maximum token that events have reached""" return self._stream_id_gen.get_current_token() - def get_all_new_forward_event_rows(self, last_id, current_id, limit): + async def get_all_new_forward_event_rows( + self, last_id: int, current_id: int, limit: int + ) -> List[Tuple]: """Returns new events, for the Events replication stream Args: @@ -999,7 +1005,7 @@ def get_all_new_forward_event_rows(self, last_id, current_id, limit): current_id: the maximum stream_id to return up to limit: the maximum number of rows to return - Returns: Deferred[List[Tuple]] + Returns: a list of events stream rows. Each tuple consists of a stream id as the first element, followed by fields suitable for casting into an EventsStreamRow. @@ -1020,18 +1026,20 @@ def get_all_new_forward_event_rows(txn): txn.execute(sql, (last_id, current_id, limit)) return txn.fetchall() - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_all_new_forward_event_rows", get_all_new_forward_event_rows ) - def get_ex_outlier_stream_rows(self, last_id, current_id): + async def get_ex_outlier_stream_rows( + self, last_id: int, current_id: int + ) -> List[Tuple]: """Returns de-outliered events, for the Events replication stream Args: last_id: the last stream_id from the previous batch. current_id: the maximum stream_id to return up to - Returns: Deferred[List[Tuple]] + Returns: a list of events stream rows. Each tuple consists of a stream id as the first element, followed by fields suitable for casting into an EventsStreamRow. @@ -1054,7 +1062,7 @@ def get_ex_outlier_stream_rows_txn(txn): txn.execute(sql, (last_id, current_id)) return txn.fetchall() - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_ex_outlier_stream_rows", get_ex_outlier_stream_rows_txn ) @@ -1226,11 +1234,11 @@ async def get_event_ordering(self, event_id): return (int(res["topological_ordering"]), int(res["stream_ordering"])) - def get_next_event_to_expire(self): + async def get_next_event_to_expire(self) -> Optional[Tuple[str, int]]: """Retrieve the entry with the lowest expiry timestamp in the event_expiry table, or None if there's no more event to expire. - Returns: Deferred[Optional[Tuple[str, int]]] + Returns: A tuple containing the event ID as its first element and an expiry timestamp as its second one, if there's at least one row in the event_expiry table. None otherwise. @@ -1246,6 +1254,6 @@ def get_next_event_to_expire_txn(txn): return txn.fetchone() - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( desc="get_next_event_to_expire", func=get_next_event_to_expire_txn ) diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py index 3526b6fd6696..ea833829aeb5 100644 --- a/synapse/storage/databases/main/purge_events.py +++ b/synapse/storage/databases/main/purge_events.py @@ -14,7 +14,7 @@ # limitations under the License. import logging -from typing import Any, Tuple +from typing import Any, List, Set, Tuple from synapse.api.errors import SynapseError from synapse.storage._base import SQLBaseStore @@ -25,25 +25,24 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore): - def purge_history(self, room_id, token, delete_local_events): + async def purge_history( + self, room_id: str, token: str, delete_local_events: bool + ) -> Set[int]: """Deletes room history before a certain point Args: - room_id (str): - - token (str): A topological token to delete events before - - delete_local_events (bool): + room_id: + token: A topological token to delete events before + delete_local_events: if True, we will delete local events as well as remote ones (instead of just marking them as outliers and deleting their state groups). Returns: - Deferred[set[int]]: The set of state groups that are referenced by - deleted events. + The set of state groups that are referenced by deleted events. """ - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "purge_history", self._purge_history_txn, room_id, @@ -283,17 +282,18 @@ def _purge_history_txn(self, txn, room_id, token_str, delete_local_events): return referenced_state_groups - def purge_room(self, room_id): + async def purge_room(self, room_id: str) -> List[int]: """Deletes all record of a room Args: - room_id (str) + room_id Returns: - Deferred[List[int]]: The list of state groups to delete. + The list of state groups to delete. """ - - return self.db_pool.runInteraction("purge_room", self._purge_room_txn, room_id) + return await self.db_pool.runInteraction( + "purge_room", self._purge_room_txn, room_id + ) def _purge_room_txn(self, txn, room_id): # First we fetch all the state groups that should be deleted, before diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 436f22ad2dbd..4a0d5a320efb 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -276,12 +276,14 @@ def f(txn): } return results - def get_users_sent_receipts_between(self, last_id: int, current_id: int): + async def get_users_sent_receipts_between( + self, last_id: int, current_id: int + ) -> List[str]: """Get all users who sent receipts between `last_id` exclusive and `current_id` inclusive. Returns: - Deferred[List[str]] + The list of users. """ if last_id == current_id: @@ -296,7 +298,7 @@ def _get_users_sent_receipts_between_txn(txn): return [r[0] for r in txn] - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_users_sent_receipts_between", _get_users_sent_receipts_between_txn ) @@ -553,8 +555,10 @@ def graph_to_linear(txn): return stream_id, max_persisted_id - def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids, data): - return self.db_pool.runInteraction( + async def insert_graph_receipt( + self, room_id, receipt_type, user_id, event_ids, data + ): + return await self.db_pool.runInteraction( "insert_graph_receipt", self.insert_graph_receipt_txn, room_id, diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index a9ceffc20e3c..5cd61547f733 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -34,38 +34,33 @@ class RelationsWorkerStore(SQLBaseStore): @cached(tree=True) - def get_relations_for_event( + async def get_relations_for_event( self, - event_id, - relation_type=None, - event_type=None, - aggregation_key=None, - limit=5, - direction="b", - from_token=None, - to_token=None, - ): + event_id: str, + relation_type: Optional[str] = None, + event_type: Optional[str] = None, + aggregation_key: Optional[str] = None, + limit: int = 5, + direction: str = "b", + from_token: Optional[RelationPaginationToken] = None, + to_token: Optional[RelationPaginationToken] = None, + ) -> PaginationChunk: """Get a list of relations for an event, ordered by topological ordering. Args: - event_id (str): Fetch events that relate to this event ID. - relation_type (str|None): Only fetch events with this relation - type, if given. - event_type (str|None): Only fetch events with this event type, if - given. - aggregation_key (str|None): Only fetch events with this aggregation - key, if given. - limit (int): Only fetch the most recent `limit` events. - direction (str): Whether to fetch the most recent first (`"b"`) or - the oldest first (`"f"`). - from_token (RelationPaginationToken|None): Fetch rows from the given - token, or from the start if None. - to_token (RelationPaginationToken|None): Fetch rows up to the given - token, or up to the end if None. + event_id: Fetch events that relate to this event ID. + relation_type: Only fetch events with this relation type, if given. + event_type: Only fetch events with this event type, if given. + aggregation_key: Only fetch events with this aggregation key, if given. + limit: Only fetch the most recent `limit` events. + direction: Whether to fetch the most recent first (`"b"`) or the + oldest first (`"f"`). + from_token: Fetch rows from the given token, or from the start if None. + to_token: Fetch rows up to the given token, or up to the end if None. Returns: - Deferred[PaginationChunk]: List of event IDs that match relations - requested. The rows are of the form `{"event_id": "..."}`. + List of event IDs that match relations requested. The rows are of + the form `{"event_id": "..."}`. """ where_clause = ["relates_to_id = ?"] @@ -131,20 +126,20 @@ def _get_recent_references_for_event_txn(txn): chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token ) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_recent_references_for_event", _get_recent_references_for_event_txn ) @cached(tree=True) - def get_aggregation_groups_for_event( + async def get_aggregation_groups_for_event( self, - event_id, - event_type=None, - limit=5, - direction="b", - from_token=None, - to_token=None, - ): + event_id: str, + event_type: Optional[str] = None, + limit: int = 5, + direction: str = "b", + from_token: Optional[AggregationPaginationToken] = None, + to_token: Optional[AggregationPaginationToken] = None, + ) -> PaginationChunk: """Get a list of annotations on the event, grouped by event type and aggregation key, sorted by count. @@ -152,21 +147,17 @@ def get_aggregation_groups_for_event( on an event. Args: - event_id (str): Fetch events that relate to this event ID. - event_type (str|None): Only fetch events with this event type, if - given. - limit (int): Only fetch the `limit` groups. - direction (str): Whether to fetch the highest count first (`"b"`) or + event_id: Fetch events that relate to this event ID. + event_type: Only fetch events with this event type, if given. + limit: Only fetch the `limit` groups. + direction: Whether to fetch the highest count first (`"b"`) or the lowest count first (`"f"`). - from_token (AggregationPaginationToken|None): Fetch rows from the - given token, or from the start if None. - to_token (AggregationPaginationToken|None): Fetch rows up to the - given token, or up to the end if None. - + from_token: Fetch rows from the given token, or from the start if None. + to_token: Fetch rows up to the given token, or up to the end if None. Returns: - Deferred[PaginationChunk]: List of groups of annotations that - match. Each row is a dict with `type`, `key` and `count` fields. + List of groups of annotations that match. Each row is a dict with + `type`, `key` and `count` fields. """ where_clause = ["relates_to_id = ?", "relation_type = ?"] @@ -225,7 +216,7 @@ def _get_aggregation_groups_for_event_txn(txn): chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token ) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn ) @@ -279,18 +270,20 @@ def _get_applicable_edit_txn(txn): return await self.get_event(edit_id, allow_none=True) - def has_user_annotated_event(self, parent_id, event_type, aggregation_key, sender): + async def has_user_annotated_event( + self, parent_id: str, event_type: str, aggregation_key: str, sender: str + ) -> bool: """Check if a user has already annotated an event with the same key (e.g. already liked an event). Args: - parent_id (str): The event being annotated - event_type (str): The event type of the annotation - aggregation_key (str): The aggregation key of the annotation - sender (str): The sender of the annotation + parent_id: The event being annotated + event_type: The event type of the annotation + aggregation_key: The aggregation key of the annotation + sender: The sender of the annotation Returns: - Deferred[bool] + True if the event is already annotated. """ sql = """ @@ -319,7 +312,7 @@ def _get_if_user_has_annotated_event(txn): return bool(txn.fetchone()) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_if_user_has_annotated_event", _get_if_user_has_annotated_event ) From 5615eb5cb48d63df15c391fe395c8740dc4af017 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 1 Sep 2020 16:02:17 +0100 Subject: [PATCH 35/39] Rename `_get_e2e_device_keys_txn` (#8222) ... to `_get_e2e_device_keys_and_signatures_txn`, to better reflect what it does. --- changelog.d/8222.misc | 1 + synapse/storage/databases/main/devices.py | 4 ++-- synapse/storage/databases/main/end_to_end_keys.py | 10 ++++++---- 3 files changed, 9 insertions(+), 6 deletions(-) create mode 100644 changelog.d/8222.misc diff --git a/changelog.d/8222.misc b/changelog.d/8222.misc new file mode 100644 index 000000000000..979c8b227bbc --- /dev/null +++ b/changelog.d/8222.misc @@ -0,0 +1 @@ +Refactor queries for device keys and cross-signatures. diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index a29157d979a3..11956cc48e94 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -256,8 +256,8 @@ async def _get_device_update_edus_by_remote( """ devices = ( await self.db_pool.runInteraction( - "_get_e2e_device_keys_txn", - self._get_e2e_device_keys_txn, + "get_e2e_device_keys_and_signatures_txn", + self._get_e2e_device_keys_and_signatures_txn, query_map.keys(), include_all_devices=True, include_deleted_devices=True, diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index fb3b1f94de92..1ee062e3c499 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -51,7 +51,7 @@ def _get_e2e_device_keys_for_federation_query_txn( ) -> Tuple[int, List[JsonDict]]: now_stream_id = self.get_device_stream_token() - devices = self._get_e2e_device_keys_txn(txn, [(user_id, None)]) + devices = self._get_e2e_device_keys_and_signatures_txn(txn, [(user_id, None)]) if devices: user_devices = devices[user_id] @@ -96,7 +96,9 @@ async def get_e2e_device_keys_for_cs_api( return {} results = await self.db_pool.runInteraction( - "get_e2e_device_keys", self._get_e2e_device_keys_txn, query_list, + "get_e2e_device_keys_and_signatures_txn", + self._get_e2e_device_keys_and_signatures_txn, + query_list, ) # Build the result structure, un-jsonify the results, and add the @@ -120,9 +122,9 @@ async def get_e2e_device_keys_for_cs_api( return rv @trace - def _get_e2e_device_keys_txn( + def _get_e2e_device_keys_and_signatures_txn( self, txn, query_list, include_all_devices=False, include_deleted_devices=False - ): + ) -> Dict[str, Dict[str, Optional[Dict]]]: set_tag("include_all_devices", include_all_devices) set_tag("include_deleted_devices", include_deleted_devices) From 7d103a594e10f4040096bf38716c9cdddf42eca0 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 1 Sep 2020 11:03:49 -0400 Subject: [PATCH 36/39] Convert appservice code to async/await. (#8207) --- changelog.d/8207.misc | 1 + synapse/appservice/api.py | 19 +++++++++++-------- 2 files changed, 12 insertions(+), 8 deletions(-) create mode 100644 changelog.d/8207.misc diff --git a/changelog.d/8207.misc b/changelog.d/8207.misc new file mode 100644 index 000000000000..dfe4c03171d6 --- /dev/null +++ b/changelog.d/8207.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index e72a0b9ac05b..bb6fa8299a8d 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -14,18 +14,20 @@ # limitations under the License. import logging import urllib +from typing import TYPE_CHECKING, Optional from prometheus_client import Counter -from twisted.internet import defer - from synapse.api.constants import EventTypes, ThirdPartyEntityKind from synapse.api.errors import CodeMessageException from synapse.events.utils import serialize_event from synapse.http.client import SimpleHttpClient -from synapse.types import ThirdPartyInstanceID +from synapse.types import JsonDict, ThirdPartyInstanceID from synapse.util.caches.response_cache import ResponseCache +if TYPE_CHECKING: + from synapse.appservice import ApplicationService + logger = logging.getLogger(__name__) sent_transactions_counter = Counter( @@ -163,19 +165,20 @@ async def query_3pe(self, service, kind, protocol, fields): logger.warning("query_3pe to %s threw exception %s", uri, ex) return [] - def get_3pe_protocol(self, service, protocol): + async def get_3pe_protocol( + self, service: "ApplicationService", protocol: str + ) -> Optional[JsonDict]: if service.url is None: return {} - @defer.inlineCallbacks - def _get(): + async def _get() -> Optional[JsonDict]: uri = "%s%s/thirdparty/protocol/%s" % ( service.url, APP_SERVICE_PREFIX, urllib.parse.quote(protocol), ) try: - info = yield defer.ensureDeferred(self.get_json(uri, {})) + info = await self.get_json(uri, {}) if not _is_valid_3pe_metadata(info): logger.warning( @@ -196,7 +199,7 @@ def _get(): return None key = (service.id, protocol) - return self.protocol_meta_cache.wrap(key, _get) + return await self.protocol_meta_cache.wrap(key, _get) async def push_bulk(self, service, events, txn_id=None): if service.url is None: From 37db6252b7ee0f3e9798a561e2919a67299e08f4 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 1 Sep 2020 11:04:17 -0400 Subject: [PATCH 37/39] Convert additional databases to async/await part 3 (#8201) --- changelog.d/8201.misc | 1 + synapse/storage/background_updates.py | 4 +- .../storage/databases/main/account_data.py | 59 ++++++++++--------- .../storage/databases/main/end_to_end_keys.py | 54 ++++++++++------- .../databases/main/media_repository.py | 31 ++++++---- synapse/storage/databases/main/search.py | 15 +++-- .../storage/databases/main/user_directory.py | 44 ++++++++------ 7 files changed, 121 insertions(+), 87 deletions(-) create mode 100644 changelog.d/8201.misc diff --git a/changelog.d/8201.misc b/changelog.d/8201.misc new file mode 100644 index 000000000000..dfe4c03171d6 --- /dev/null +++ b/changelog.d/8201.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index 0db900fa0e0b..67a89cd51a30 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -433,7 +433,7 @@ async def _end_background_update(self, update_name: str) -> None: "background_updates", keyvalues={"update_name": update_name} ) - def _background_update_progress(self, update_name: str, progress: dict): + async def _background_update_progress(self, update_name: str, progress: dict): """Update the progress of a background update Args: @@ -441,7 +441,7 @@ def _background_update_progress(self, update_name: str, progress: dict): progress: The progress of the update. """ - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "background_update_progress", self._background_update_progress_txn, update_name, diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index 04042a2c981f..4436b1a83d97 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -16,9 +16,7 @@ import abc import logging -from typing import List, Optional, Tuple - -from twisted.internet import defer +from typing import Dict, List, Optional, Tuple from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool @@ -58,14 +56,16 @@ def get_max_account_data_stream_id(self): raise NotImplementedError() @cached() - def get_account_data_for_user(self, user_id): + async def get_account_data_for_user( + self, user_id: str + ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]: """Get all the client account_data for a user. Args: - user_id(str): The user to get the account_data for. + user_id: The user to get the account_data for. Returns: - A deferred pair of a dict of global account_data and a dict - mapping from room_id string to per room account_data dicts. + A 2-tuple of a dict of global account_data and a dict mapping from + room_id string to per room account_data dicts. """ def get_account_data_for_user_txn(txn): @@ -94,7 +94,7 @@ def get_account_data_for_user_txn(txn): return global_account_data, by_room - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_account_data_for_user", get_account_data_for_user_txn ) @@ -120,14 +120,16 @@ async def get_global_account_data_by_type_for_user( return None @cached(num_args=2) - def get_account_data_for_room(self, user_id, room_id): + async def get_account_data_for_room( + self, user_id: str, room_id: str + ) -> Dict[str, JsonDict]: """Get all the client account_data for a user for a room. Args: - user_id(str): The user to get the account_data for. - room_id(str): The room to get the account_data for. + user_id: The user to get the account_data for. + room_id: The room to get the account_data for. Returns: - A deferred dict of the room account_data + A dict of the room account_data """ def get_account_data_for_room_txn(txn): @@ -142,21 +144,22 @@ def get_account_data_for_room_txn(txn): row["account_data_type"]: db_to_json(row["content"]) for row in rows } - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_account_data_for_room", get_account_data_for_room_txn ) @cached(num_args=3, max_entries=5000) - def get_account_data_for_room_and_type(self, user_id, room_id, account_data_type): + async def get_account_data_for_room_and_type( + self, user_id: str, room_id: str, account_data_type: str + ) -> Optional[JsonDict]: """Get the client account_data of given type for a user for a room. Args: - user_id(str): The user to get the account_data for. - room_id(str): The room to get the account_data for. - account_data_type (str): The account data type to get. + user_id: The user to get the account_data for. + room_id: The room to get the account_data for. + account_data_type: The account data type to get. Returns: - A deferred of the room account_data for that type, or None if - there isn't any set. + The room account_data for that type, or None if there isn't any set. """ def get_account_data_for_room_and_type_txn(txn): @@ -174,7 +177,7 @@ def get_account_data_for_room_and_type_txn(txn): return db_to_json(content_json) if content_json else None - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn ) @@ -238,12 +241,14 @@ def get_updated_room_account_data_txn(txn): "get_updated_room_account_data", get_updated_room_account_data_txn ) - def get_updated_account_data_for_user(self, user_id, stream_id): + async def get_updated_account_data_for_user( + self, user_id: str, stream_id: int + ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]: """Get all the client account_data for a that's changed for a user Args: - user_id(str): The user to get the account_data for. - stream_id(int): The point in the stream since which to get updates + user_id: The user to get the account_data for. + stream_id: The point in the stream since which to get updates Returns: A deferred pair of a dict of global account_data and a dict mapping from room_id string to per room account_data dicts. @@ -277,9 +282,9 @@ def get_updated_account_data_for_user_txn(txn): user_id, int(stream_id) ) if not changed: - return defer.succeed(({}, {})) + return ({}, {}) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_updated_account_data_for_user", get_updated_account_data_for_user_txn ) @@ -416,7 +421,7 @@ async def add_account_data_for_user( return self._account_data_id_gen.get_current_token() - def _update_max_stream_id(self, next_id: int): + async def _update_max_stream_id(self, next_id: int) -> None: """Update the max stream_id Args: @@ -435,4 +440,4 @@ def _update(txn): ) txn.execute(update_max_id_sql, (next_id, next_id)) - return self.db_pool.runInteraction("update_account_data_max_stream_id", _update) + await self.db_pool.runInteraction("update_account_data_max_stream_id", _update) diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 1ee062e3c499..5a7de44b3344 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -34,13 +34,15 @@ class EndToEndKeyWorkerStore(SQLBaseStore): - def get_e2e_device_keys_for_federation_query(self, user_id: str): + async def get_e2e_device_keys_for_federation_query( + self, user_id: str + ) -> Tuple[int, List[JsonDict]]: """Get all devices (with any device keys) for a user Returns: - Deferred which resolves to (stream_id, devices) + (stream_id, devices) """ - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_e2e_device_keys_for_federation_query", self._get_e2e_device_keys_for_federation_query_txn, user_id, @@ -292,10 +294,12 @@ def _add_e2e_one_time_keys(txn): ) @cached(max_entries=10000) - def count_e2e_one_time_keys(self, user_id, device_id): + async def count_e2e_one_time_keys( + self, user_id: str, device_id: str + ) -> Dict[str, int]: """ Count the number of one time keys the server has for a device Returns: - Dict mapping from algorithm to number of keys for that algorithm. + A mapping from algorithm to number of keys for that algorithm. """ def _count_e2e_one_time_keys(txn): @@ -310,7 +314,7 @@ def _count_e2e_one_time_keys(txn): result[algorithm] = key_count return result - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "count_e2e_one_time_keys", _count_e2e_one_time_keys ) @@ -348,7 +352,7 @@ def _get_bare_e2e_cross_signing_keys(self, user_id): list_name="user_ids", num_args=1, ) - def _get_bare_e2e_cross_signing_keys_bulk( + async def _get_bare_e2e_cross_signing_keys_bulk( self, user_ids: List[str] ) -> Dict[str, Dict[str, dict]]: """Returns the cross-signing keys for a set of users. The output of this @@ -356,16 +360,15 @@ def _get_bare_e2e_cross_signing_keys_bulk( the signatures for the calling user need to be fetched. Args: - user_ids (list[str]): the users whose keys are being requested + user_ids: the users whose keys are being requested Returns: - dict[str, dict[str, dict]]: mapping from user ID to key type to key - data. If a user's cross-signing keys were not found, either - their user ID will not be in the dict, or their user ID will map - to None. + A mapping from user ID to key type to key data. If a user's cross-signing + keys were not found, either their user ID will not be in the dict, or + their user ID will map to None. """ - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_bare_e2e_cross_signing_keys_bulk", self._get_bare_e2e_cross_signing_keys_bulk_txn, user_ids, @@ -588,7 +591,9 @@ def get_device_stream_token(self) -> int: class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): - def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys): + async def set_e2e_device_keys( + self, user_id: str, device_id: str, time_now: int, device_keys: JsonDict + ) -> bool: """Stores device keys for a device. Returns whether there was a change or the keys were already in the database. """ @@ -624,12 +629,21 @@ def _set_e2e_device_keys_txn(txn): log_kv({"message": "Device keys stored."}) return True - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "set_e2e_device_keys", _set_e2e_device_keys_txn ) - def claim_e2e_one_time_keys(self, query_list): - """Take a list of one time keys out of the database""" + async def claim_e2e_one_time_keys( + self, query_list: Iterable[Tuple[str, str, str]] + ) -> Dict[str, Dict[str, Dict[str, bytes]]]: + """Take a list of one time keys out of the database. + + Args: + query_list: An iterable of tuples of (user ID, device ID, algorithm). + + Returns: + A map of user ID -> a map device ID -> a map of key ID -> JSON bytes. + """ @trace def _claim_e2e_one_time_keys(txn): @@ -665,11 +679,11 @@ def _claim_e2e_one_time_keys(txn): ) return result - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "claim_e2e_one_time_keys", _claim_e2e_one_time_keys ) - def delete_e2e_keys_by_device(self, user_id, device_id): + async def delete_e2e_keys_by_device(self, user_id: str, device_id: str) -> None: def delete_e2e_keys_by_device_txn(txn): log_kv( { @@ -692,7 +706,7 @@ def delete_e2e_keys_by_device_txn(txn): txn, self.count_e2e_one_time_keys, (user_id, device_id) ) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn ) diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index 3919ecad69a8..86557d551277 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional +from typing import Any, Dict, Iterable, List, Optional, Tuple from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool @@ -93,7 +93,7 @@ async def mark_local_media_as_safe(self, media_id: str) -> None: desc="mark_local_media_as_safe", ) - def get_url_cache(self, url, ts): + async def get_url_cache(self, url: str, ts: int) -> Optional[Dict[str, Any]]: """Get the media_id and ts for a cached URL as of the given timestamp Returns: None if the URL isn't cached. @@ -139,7 +139,7 @@ def get_url_cache_txn(txn): ) ) - return self.db_pool.runInteraction("get_url_cache", get_url_cache_txn) + return await self.db_pool.runInteraction("get_url_cache", get_url_cache_txn) async def store_url_cache( self, url, response_code, etag, expires_ts, og, media_id, download_ts @@ -237,12 +237,17 @@ async def store_cached_remote_media( desc="store_cached_remote_media", ) - def update_cached_last_access_time(self, local_media, remote_media, time_ms): + async def update_cached_last_access_time( + self, + local_media: Iterable[str], + remote_media: Iterable[Tuple[str, str]], + time_ms: int, + ): """Updates the last access time of the given media Args: - local_media (iterable[str]): Set of media_ids - remote_media (iterable[(str, str)]): Set of (server_name, media_id) + local_media: Set of media_ids + remote_media: Set of (server_name, media_id) time_ms: Current time in milliseconds """ @@ -267,7 +272,7 @@ def update_cache_txn(txn): txn.executemany(sql, ((time_ms, media_id) for media_id in local_media)) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "update_cached_last_access_time", update_cache_txn ) @@ -325,7 +330,7 @@ async def get_remote_media_before(self, before_ts): "get_remote_media_before", self.db_pool.cursor_to_dict, sql, before_ts ) - def delete_remote_media(self, media_origin, media_id): + async def delete_remote_media(self, media_origin: str, media_id: str) -> None: def delete_remote_media_txn(txn): self.db_pool.simple_delete_txn( txn, @@ -338,11 +343,11 @@ def delete_remote_media_txn(txn): keyvalues={"media_origin": media_origin, "media_id": media_id}, ) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "delete_remote_media", delete_remote_media_txn ) - def get_expired_url_cache(self, now_ts): + async def get_expired_url_cache(self, now_ts: int) -> List[str]: sql = ( "SELECT media_id FROM local_media_repository_url_cache" " WHERE expires_ts < ?" @@ -354,7 +359,7 @@ def _get_expired_url_cache_txn(txn): txn.execute(sql, (now_ts,)) return [row[0] for row in txn] - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_expired_url_cache", _get_expired_url_cache_txn ) @@ -371,7 +376,7 @@ def _delete_url_cache_txn(txn): "delete_url_cache", _delete_url_cache_txn ) - def get_url_cache_media_before(self, before_ts): + async def get_url_cache_media_before(self, before_ts: int) -> List[str]: sql = ( "SELECT media_id FROM local_media_repository" " WHERE created_ts < ? AND url_cache IS NOT NULL" @@ -383,7 +388,7 @@ def _get_url_cache_media_before_txn(txn): txn.execute(sql, (before_ts,)) return [row[0] for row in txn] - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_url_cache_media_before", _get_url_cache_media_before_txn ) diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py index 7f8d1880e57e..f01cf2fd02e9 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -16,9 +16,10 @@ import logging import re from collections import namedtuple -from typing import List, Optional +from typing import List, Optional, Set from synapse.api.errors import SynapseError +from synapse.events import EventBase from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool from synapse.storage.databases.main.events_worker import EventRedactBehaviour @@ -620,7 +621,9 @@ async def search_rooms( "count": count, } - def _find_highlights_in_postgres(self, search_query, events): + async def _find_highlights_in_postgres( + self, search_query: str, events: List[EventBase] + ) -> Set[str]: """Given a list of events and a search term, return a list of words that match from the content of the event. @@ -628,11 +631,11 @@ def _find_highlights_in_postgres(self, search_query, events): highlight the matching parts. Args: - search_query (str) - events (list): A list of events + search_query + events: A list of events Returns: - deferred : A set of strings. + A set of strings. """ def f(txn): @@ -685,7 +688,7 @@ def f(txn): return highlight_words - return self.db_pool.runInteraction("_find_highlights", f) + return await self.db_pool.runInteraction("_find_highlights", f) def _to_postgres_options(options_dict): diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index a9f2e9361482..1e96ae7828f8 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -15,7 +15,7 @@ import logging import re -from typing import Any, Dict, Optional +from typing import Any, Dict, Iterable, Optional, Tuple from synapse.api.constants import EventTypes, JoinRules from synapse.storage.database import DatabasePool @@ -365,7 +365,9 @@ async def is_room_world_readable_or_publicly_joinable(self, room_id): return False - def update_profile_in_user_dir(self, user_id, display_name, avatar_url): + async def update_profile_in_user_dir( + self, user_id: str, display_name: str, avatar_url: str + ) -> None: """ Update or add a user's profile in the user directory. """ @@ -458,17 +460,19 @@ def _update_profile_in_user_dir_txn(txn): txn.call_after(self.get_user_in_directory.invalidate, (user_id,)) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "update_profile_in_user_dir", _update_profile_in_user_dir_txn ) - def add_users_who_share_private_room(self, room_id, user_id_tuples): + async def add_users_who_share_private_room( + self, room_id: str, user_id_tuples: Iterable[Tuple[str, str]] + ) -> None: """Insert entries into the users_who_share_private_rooms table. The first user should be a local user. Args: - room_id (str) - user_id_tuples([(str, str)]): iterable of 2-tuple of user IDs. + room_id + user_id_tuples: iterable of 2-tuple of user IDs. """ def _add_users_who_share_room_txn(txn): @@ -484,17 +488,19 @@ def _add_users_who_share_room_txn(txn): value_values=None, ) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "add_users_who_share_room", _add_users_who_share_room_txn ) - def add_users_in_public_rooms(self, room_id, user_ids): + async def add_users_in_public_rooms( + self, room_id: str, user_ids: Iterable[str] + ) -> None: """Insert entries into the users_who_share_private_rooms table. The first user should be a local user. Args: - room_id (str) - user_ids (list[str]) + room_id + user_ids """ def _add_users_in_public_rooms_txn(txn): @@ -508,11 +514,11 @@ def _add_users_in_public_rooms_txn(txn): value_values=None, ) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "add_users_in_public_rooms", _add_users_in_public_rooms_txn ) - def delete_all_from_user_dir(self): + async def delete_all_from_user_dir(self) -> None: """Delete the entire user directory """ @@ -523,7 +529,7 @@ def _delete_all_from_user_dir_txn(txn): txn.execute("DELETE FROM users_who_share_private_rooms") txn.call_after(self.get_user_in_directory.invalidate_all) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "delete_all_from_user_dir", _delete_all_from_user_dir_txn ) @@ -555,7 +561,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): def __init__(self, database: DatabasePool, db_conn, hs): super(UserDirectoryStore, self).__init__(database, db_conn, hs) - def remove_from_user_dir(self, user_id): + async def remove_from_user_dir(self, user_id: str) -> None: def _remove_from_user_dir_txn(txn): self.db_pool.simple_delete_txn( txn, table="user_directory", keyvalues={"user_id": user_id} @@ -578,7 +584,7 @@ def _remove_from_user_dir_txn(txn): ) txn.call_after(self.get_user_in_directory.invalidate, (user_id,)) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "remove_from_user_dir", _remove_from_user_dir_txn ) @@ -605,14 +611,14 @@ async def get_users_in_dir_due_to_room(self, room_id): return user_ids - def remove_user_who_share_room(self, user_id, room_id): + async def remove_user_who_share_room(self, user_id: str, room_id: str) -> None: """ Deletes entries in the users_who_share_*_rooms table. The first user should be a local user. Args: - user_id (str) - room_id (str) + user_id + room_id """ def _remove_user_who_share_room_txn(txn): @@ -632,7 +638,7 @@ def _remove_user_who_share_room_txn(txn): keyvalues={"user_id": user_id, "room_id": room_id}, ) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "remove_user_who_share_room", _remove_user_who_share_room_txn ) From b5133dd97f693ca213b30f4f3e874e9ab3958ea7 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Tue, 1 Sep 2020 16:31:59 +0100 Subject: [PATCH 38/39] Explain better what GDPR-erased means (#8189) Fixes https://github.com/matrix-org/synapse/issues/8185 --- changelog.d/8189.doc | 1 + docs/admin_api/user_admin_api.rst | 8 +++++--- 2 files changed, 6 insertions(+), 3 deletions(-) create mode 100644 changelog.d/8189.doc diff --git a/changelog.d/8189.doc b/changelog.d/8189.doc new file mode 100644 index 000000000000..800ff89dc58f --- /dev/null +++ b/changelog.d/8189.doc @@ -0,0 +1 @@ +Explain better what GDPR-erased means when deactivating a user. diff --git a/docs/admin_api/user_admin_api.rst b/docs/admin_api/user_admin_api.rst index d6e3194cda5c..e21c78a9c62b 100644 --- a/docs/admin_api/user_admin_api.rst +++ b/docs/admin_api/user_admin_api.rst @@ -214,9 +214,11 @@ Deactivate Account This API deactivates an account. It removes active access tokens, resets the password, and deletes third-party IDs (to prevent the user requesting a -password reset). It can also mark the user as GDPR-erased (stopping their data -from distributed further, and deleting it entirely if there are no other -references to it). +password reset). + +It can also mark the user as GDPR-erased. This means messages sent by the +user will still be visible by anyone that was in the room when these messages +were sent, but hidden from users joining the room afterwards. The api is:: From b939251c37d748a4be6346eb27bd5fdfaff17738 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 1 Sep 2020 13:02:41 -0400 Subject: [PATCH 39/39] Fix errors when updating the user directory with invalid data (#8223) --- changelog.d/8223.bugfix | 1 + synapse/handlers/profile.py | 6 ++++++ synapse/handlers/user_directory.py | 8 +++++++- synapse/storage/databases/main/user_directory.py | 5 +++++ 4 files changed, 19 insertions(+), 1 deletion(-) create mode 100644 changelog.d/8223.bugfix diff --git a/changelog.d/8223.bugfix b/changelog.d/8223.bugfix new file mode 100644 index 000000000000..60655ce3e11f --- /dev/null +++ b/changelog.d/8223.bugfix @@ -0,0 +1 @@ +Fixes a longstanding bug where user directory updates could break when unexpected profile data was included in events. diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 96c9d6bab4f9..0cb8fad89a2c 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -161,6 +161,9 @@ async def set_displayname( Codes.FORBIDDEN, ) + if not isinstance(new_displayname, str): + raise SynapseError(400, "Invalid displayname") + if len(new_displayname) > MAX_DISPLAYNAME_LEN: raise SynapseError( 400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,) @@ -235,6 +238,9 @@ async def set_avatar_url( 400, "Changing avatar is disabled on this server", Codes.FORBIDDEN ) + if not isinstance(new_avatar_url, str): + raise SynapseError(400, "Invalid displayname") + if len(new_avatar_url) > MAX_AVATAR_URL_LEN: raise SynapseError( 400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,) diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index 521b6d620d3c..e21f8dbc58e3 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -234,7 +234,7 @@ async def _handle_deltas(self, deltas): async def _handle_room_publicity_change( self, room_id, prev_event_id, event_id, typ ): - """Handle a room having potentially changed from/to world_readable/publically + """Handle a room having potentially changed from/to world_readable/publicly joinable. Args: @@ -388,9 +388,15 @@ async def _handle_profile_change(self, user_id, room_id, prev_event_id, event_id prev_name = prev_event.content.get("displayname") new_name = event.content.get("displayname") + # If the new name is an unexpected form, do not update the directory. + if not isinstance(new_name, str): + new_name = prev_name prev_avatar = prev_event.content.get("avatar_url") new_avatar = event.content.get("avatar_url") + # If the new avatar is an unexpected form, do not update the directory. + if not isinstance(new_avatar, str): + new_avatar = prev_avatar if prev_name != new_name or prev_avatar != new_avatar: await self.store.update_profile_in_user_dir(user_id, new_name, new_avatar) diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index 1e96ae7828f8..c977db042e63 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -371,6 +371,11 @@ async def update_profile_in_user_dir( """ Update or add a user's profile in the user directory. """ + # If the display name or avatar URL are unexpected types, overwrite them. + if not isinstance(display_name, str): + display_name = None + if not isinstance(avatar_url, str): + avatar_url = None def _update_profile_in_user_dir_txn(txn): new_entry = self.db_pool.simple_upsert_txn(