diff --git a/CHANGES.md b/CHANGES.md index d859baa9ff56..92e29983b99c 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,3 +1,17 @@ +For the next release +==================== + +Removal warning +--------------- + +Some older clients used a +[disallowed character](https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-register-email-requesttoken) +(`:`) in the `client_secret` parameter of various endpoints. The incorrect +behaviour was allowed for backwards compatibility, but is now being removed +from Synapse as most users have updated their client. Further context can be +found at [\#6766](https://github.com/matrix-org/synapse/issues/6766). + + Synapse 1.19.1 (2020-08-27) =========================== diff --git a/changelog.d/7757.misc b/changelog.d/7757.misc new file mode 100644 index 000000000000..091f40382e68 --- /dev/null +++ b/changelog.d/7757.misc @@ -0,0 +1 @@ +Reduce run times of some unit tests by advancing the reactor a fewer number of times. \ No newline at end of file diff --git a/changelog.d/8130.misc b/changelog.d/8130.misc new file mode 100644 index 000000000000..7944c09adee0 --- /dev/null +++ b/changelog.d/8130.misc @@ -0,0 +1 @@ +Update the test federation client to handle streaming responses. diff --git a/changelog.d/8144.docker b/changelog.d/8144.docker new file mode 100644 index 000000000000..9bb5881fa8f3 --- /dev/null +++ b/changelog.d/8144.docker @@ -0,0 +1 @@ +Fix builds of the Docker image on non-x86 platforms. diff --git a/changelog.d/8157.feature b/changelog.d/8157.feature new file mode 100644 index 000000000000..813e6d0903d9 --- /dev/null +++ b/changelog.d/8157.feature @@ -0,0 +1 @@ +Add support for shadow-banning users (ignoring any message send requests). diff --git a/changelog.d/8162.misc b/changelog.d/8162.misc new file mode 100644 index 000000000000..e26764dea15a --- /dev/null +++ b/changelog.d/8162.misc @@ -0,0 +1 @@ + Convert various parts of the codebase to async/await. diff --git a/changelog.d/8171.misc b/changelog.d/8171.misc new file mode 100644 index 000000000000..cafbf23d836f --- /dev/null +++ b/changelog.d/8171.misc @@ -0,0 +1 @@ +Make `SlavedIdTracker.advance` have the same interface as `MultiWriterIDGenerator`. diff --git a/changelog.d/8174.misc b/changelog.d/8174.misc new file mode 100644 index 000000000000..a39e9eab46cd --- /dev/null +++ b/changelog.d/8174.misc @@ -0,0 +1 @@ +Remove unused `is_guest` parameter from, and add safeguard to, `MessageHandler.get_room_data`. \ No newline at end of file diff --git a/changelog.d/8175.misc b/changelog.d/8175.misc new file mode 100644 index 000000000000..28af294dcf6c --- /dev/null +++ b/changelog.d/8175.misc @@ -0,0 +1 @@ +Standardize the mypy configuration. diff --git a/changelog.d/8176.feature b/changelog.d/8176.feature new file mode 100644 index 000000000000..813e6d0903d9 --- /dev/null +++ b/changelog.d/8176.feature @@ -0,0 +1 @@ +Add support for shadow-banning users (ignoring any message send requests). diff --git a/changelog.d/8181.misc b/changelog.d/8181.misc new file mode 100644 index 000000000000..a39e9eab46cd --- /dev/null +++ b/changelog.d/8181.misc @@ -0,0 +1 @@ +Remove unused `is_guest` parameter from, and add safeguard to, `MessageHandler.get_room_data`. \ No newline at end of file diff --git a/debian/changelog b/debian/changelog index 6676706dea12..bde3b636eeb3 100644 --- a/debian/changelog +++ b/debian/changelog @@ -1,3 +1,9 @@ +matrix-synapse-py3 (1.19.0ubuntu1) UNRELEASED; urgency=medium + + * Use Type=notify in systemd service + + -- Dexter Chua Wed, 26 Aug 2020 12:41:36 +0000 + matrix-synapse-py3 (1.19.1) stable; urgency=medium * New synapse release 1.19.1. diff --git a/debian/matrix-synapse.service b/debian/matrix-synapse.service index b0a8d72e6d25..553babf5492d 100644 --- a/debian/matrix-synapse.service +++ b/debian/matrix-synapse.service @@ -2,7 +2,7 @@ Description=Synapse Matrix homeserver [Service] -Type=simple +Type=notify User=matrix-synapse WorkingDirectory=/var/lib/matrix-synapse EnvironmentFile=/etc/default/matrix-synapse diff --git a/docker/Dockerfile b/docker/Dockerfile index 432d56a8ee11..27512f860092 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -19,11 +19,16 @@ ARG PYTHON_VERSION=3.7 FROM docker.io/python:${PYTHON_VERSION}-slim as builder # install the OS build deps - - RUN apt-get update && apt-get install -y \ build-essential \ + libffi-dev \ + libjpeg-dev \ libpq-dev \ + libssl-dev \ + libwebp-dev \ + libxml++2.6-dev \ + libxslt1-dev \ + zlib1g-dev \ && rm -rf /var/lib/apt/lists/* # Build dependencies that are not available as wheels, to speed up rebuilds @@ -56,9 +61,11 @@ FROM docker.io/python:${PYTHON_VERSION}-slim RUN apt-get update && apt-get install -y \ curl \ + gosu \ + libjpeg62-turbo \ libpq5 \ + libwebp6 \ xmlsec1 \ - gosu \ && rm -rf /var/lib/apt/lists/* COPY --from=builder /install /usr/local diff --git a/mypy.ini b/mypy.ini index c69cb5dc4064..4213e31b0320 100644 --- a/mypy.ini +++ b/mypy.ini @@ -6,6 +6,55 @@ check_untyped_defs = True show_error_codes = True show_traceback = True mypy_path = stubs +files = + synapse/api, + synapse/appservice, + synapse/config, + synapse/event_auth.py, + synapse/events/builder.py, + synapse/events/spamcheck.py, + synapse/federation, + synapse/handlers/auth.py, + synapse/handlers/cas_handler.py, + synapse/handlers/directory.py, + synapse/handlers/federation.py, + synapse/handlers/identity.py, + synapse/handlers/message.py, + synapse/handlers/oidc_handler.py, + synapse/handlers/presence.py, + synapse/handlers/room.py, + synapse/handlers/room_member.py, + synapse/handlers/room_member_worker.py, + synapse/handlers/saml_handler.py, + synapse/handlers/sync.py, + synapse/handlers/ui_auth, + synapse/http/server.py, + synapse/http/site.py, + synapse/logging/, + synapse/metrics, + synapse/module_api, + synapse/notifier.py, + synapse/push/pusherpool.py, + synapse/push/push_rule_evaluator.py, + synapse/replication, + synapse/rest, + synapse/server.py, + synapse/server_notices, + synapse/spam_checker_api, + synapse/state, + synapse/storage/databases/main/ui_auth.py, + synapse/storage/database.py, + synapse/storage/engines, + synapse/storage/state.py, + synapse/storage/util, + synapse/streams, + synapse/types.py, + synapse/util/caches/stream_change_cache.py, + synapse/util/metrics.py, + tests/replication, + tests/test_utils, + tests/rest/client/v2_alpha/test_auth.py, + tests/util/test_stream_change_cache.py [mypy-pymacaroons.*] ignore_missing_imports = True diff --git a/scripts-dev/federation_client.py b/scripts-dev/federation_client.py index 531010185d8f..ad12523c4d62 100755 --- a/scripts-dev/federation_client.py +++ b/scripts-dev/federation_client.py @@ -21,10 +21,12 @@ import base64 import json import sys +from typing import Any, Optional from urllib import parse as urlparse import nacl.signing import requests +import signedjson.types import srvlookup import yaml from requests.adapters import HTTPAdapter @@ -69,7 +71,9 @@ def encode_canonical_json(value): ).encode("UTF-8") -def sign_json(json_object, signing_key, signing_name): +def sign_json( + json_object: Any, signing_key: signedjson.types.SigningKey, signing_name: str +) -> Any: signatures = json_object.pop("signatures", {}) unsigned = json_object.pop("unsigned", None) @@ -122,7 +126,14 @@ def read_signing_keys(stream): return keys -def request_json(method, origin_name, origin_key, destination, path, content): +def request( + method: Optional[str], + origin_name: str, + origin_key: signedjson.types.SigningKey, + destination: str, + path: str, + content: Optional[str], +) -> requests.Response: if method is None: if content is None: method = "GET" @@ -159,11 +170,14 @@ def request_json(method, origin_name, origin_key, destination, path, content): if method == "POST": headers["Content-Type"] = "application/json" - result = s.request( - method=method, url=dest, headers=headers, verify=False, data=content + return s.request( + method=method, + url=dest, + headers=headers, + verify=False, + data=content, + stream=True, ) - sys.stderr.write("Status Code: %d\n" % (result.status_code,)) - return result.json() def main(): @@ -222,7 +236,7 @@ def main(): with open(args.signing_key_path) as f: key = read_signing_keys(f)[0] - result = request_json( + result = request( args.method, args.server_name, key, @@ -231,7 +245,12 @@ def main(): content=args.body, ) - json.dump(result, sys.stdout) + sys.stderr.write("Status Code: %d\n" % (result.status_code,)) + + for chunk in result.iter_content(): + # we write raw utf8 to stdout. + sys.stdout.buffer.write(chunk) + print("") diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 755a52a50dea..14b03490fad4 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -96,12 +96,7 @@ def __init__(self, hs): ) async def get_room_data( - self, - user_id: str, - room_id: str, - event_type: str, - state_key: str, - is_guest: bool, + self, user_id: str, room_id: str, event_type: str, state_key: str, ) -> dict: """ Get data from a room. @@ -110,11 +105,10 @@ async def get_room_data( room_id event_type state_key - is_guest Returns: The path data content. Raises: - SynapseError if something went wrong. + SynapseError or AuthError if the user is not in the room """ ( membership, @@ -131,6 +125,16 @@ async def get_room_data( [membership_event_id], StateFilter.from_types([key]) ) data = room_state[membership_event_id].get(key) + else: + # check_user_in_room_or_world_readable, if it doesn't raise an AuthError, should + # only ever return a Membership.JOIN/LEAVE object + # + # Safeguard in case it returned something else + logger.error( + "Attempted to retrieve data from a room for a user that has never been in it. " + "This should not have happened." + ) + raise SynapseError(403, "User not in room", errcode=Codes.FORBIDDEN) return data diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 4f3198896e83..adb9dc7c4289 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -15,11 +15,12 @@ # limitations under the License. import logging +import random from typing import List from signedjson.sign import sign_json -from twisted.internet import defer, reactor +from twisted.internet import reactor from synapse.api.errors import ( AuthError, @@ -73,36 +74,45 @@ def __init__(self, hs): ) if len(self.hs.config.replicate_user_profiles_to) > 0: - reactor.callWhenRunning(self._assign_profile_replication_batches) - reactor.callWhenRunning(self._replicate_profiles) + reactor.callWhenRunning(self._do_assign_profile_replication_batches) + reactor.callWhenRunning(self._start_replicate_profiles) # Add a looping call to replicate_profiles: this handles retries # if the replication is unsuccessful when the user updated their # profile. self.clock.looping_call( - self._replicate_profiles, self.PROFILE_REPLICATE_INTERVAL + self._start_replicate_profiles, self.PROFILE_REPLICATE_INTERVAL ) - @defer.inlineCallbacks - def _assign_profile_replication_batches(self): + def _do_assign_profile_replication_batches(self): + return run_as_background_process( + "_assign_profile_replication_batches", + self._assign_profile_replication_batches, + ) + + def _start_replicate_profiles(self): + return run_as_background_process( + "_replicate_profiles", self._replicate_profiles + ) + + async def _assign_profile_replication_batches(self): """If no profile replication has been done yet, allocate replication batch numbers to each profile to start the replication process. """ logger.info("Assigning profile batch numbers...") total = 0 while True: - assigned = yield self.store.assign_profile_batch() + assigned = await self.store.assign_profile_batch() total += assigned if assigned == 0: break logger.info("Assigned %d profile batch numbers", total) - @defer.inlineCallbacks - def _replicate_profiles(self): + async def _replicate_profiles(self): """If any profile data has been updated and not pushed to the replication targets, replicate it. """ - host_batches = yield self.store.get_replication_hosts() - latest_batch = yield self.store.get_latest_profile_replication_batch_number() + host_batches = await self.store.get_replication_hosts() + latest_batch = await self.store.get_latest_profile_replication_batch_number() if latest_batch is None: latest_batch = -1 for repl_host in self.hs.config.replicate_user_profiles_to: @@ -110,16 +120,15 @@ def _replicate_profiles(self): host_batches[repl_host] = -1 try: for i in range(host_batches[repl_host] + 1, latest_batch + 1): - yield self._replicate_host_profile_batch(repl_host, i) + await self._replicate_host_profile_batch(repl_host, i) except Exception: logger.exception( "Exception while replicating to %s: aborting for now", repl_host ) - @defer.inlineCallbacks - def _replicate_host_profile_batch(self, host, batchnum): + async def _replicate_host_profile_batch(self, host, batchnum): logger.info("Replicating profile batch %d to %s", batchnum, host) - batch_rows = yield self.store.get_profile_batch(batchnum) + batch_rows = await self.store.get_profile_batch(batchnum) batch = { UserID(r["user_id"], self.hs.hostname).to_string(): ( {"display_name": r["displayname"], "avatar_url": r["avatar_url"]} @@ -133,13 +142,11 @@ def _replicate_host_profile_batch(self, host, batchnum): body = {"batchnum": batchnum, "batch": batch, "origin_server": self.hs.hostname} signed_body = sign_json(body, self.hs.hostname, self.hs.config.signing_key[0]) try: - yield defer.ensureDeferred( - self.http_client.post_json_get_json(url, signed_body) - ) - yield defer.ensureDeferred( - self.store.update_replication_batch_for_host(host, batchnum) + await self.http_client.post_json_get_json(url, signed_body) + await self.store.update_replication_batch_for_host(host, batchnum) + logger.info( + "Successfully replicated profile batch %d to %s", batchnum, host ) - logger.info("Sucessfully replicated profile batch %d to %s", batchnum, host) except Exception: # This will get retried when the looping call next comes around logger.exception( @@ -292,8 +299,7 @@ async def set_displayname( # start a profile replication push run_in_background(self._replicate_profiles) - @defer.inlineCallbacks - def set_active( + async def set_active( self, users: List[UserID], active: bool, hide: bool, ): """ @@ -316,19 +322,16 @@ def set_active( hide: Whether to hide the user (withold from replication). If False and active is False, user will have their profile erased - - Returns: - Deferred """ if len(self.replicate_user_profiles_to) > 0: cur_batchnum = ( - yield self.store.get_latest_profile_replication_batch_number() + await self.store.get_latest_profile_replication_batch_number() ) new_batchnum = 0 if cur_batchnum is None else cur_batchnum + 1 else: new_batchnum = None - yield self.store.set_profiles_active(users, active, hide, new_batchnum) + await self.store.set_profiles_active(users, active, hide, new_batchnum) # start a profile replication push run_in_background(self._replicate_profiles) @@ -362,8 +365,14 @@ async def get_avatar_url(self, target_user): async def set_avatar_url( self, target_user, requester, new_avatar_url, by_admin=False ): - """target_user is the user whose avatar_url is to be changed; - auth_user is the user attempting to make this change.""" + """Set a new avatar URL for a user. + + Args: + target_user (UserID): the user whose avatar URL is to be changed. + requester (Requester): The user attempting to make this change. + new_avatar_url (str): The avatar URL to give this user. + by_admin (bool): Whether this change was made by an administrator. + """ if not self.hs.is_mine(target_user): raise SynapseError(400, "User is not hosted on this homeserver") @@ -484,6 +493,12 @@ async def _update_join_states(self, requester, target_user): await self.ratelimit(requester) + # Do not actually update the room state for shadow-banned users. + if requester.shadow_banned: + # We randomly sleep a bit just to annoy the requester. + await self.clock.sleep(random.randint(1, 10)) + return + room_ids = await self.store.get_rooms_for_user(target_user.to_string()) for room_id in room_ids: diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 8ee9b2063dac..50f375644613 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -384,7 +384,7 @@ async def _update_membership( # later on. content = dict(content) - if not self.allow_per_room_profiles: + if not self.allow_per_room_profiles or requester.shadow_banned: # Strip profile data, knowing that new profile data will be added to the # event's content in event_creation_handler.create_event() using the target's # global profile. diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index a86ac0150e05..1d828bd7be16 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -14,10 +14,11 @@ # limitations under the License. import logging +import random from collections import namedtuple from typing import TYPE_CHECKING, List, Set, Tuple -from synapse.api.errors import AuthError, SynapseError +from synapse.api.errors import AuthError, ShadowBanError, SynapseError from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.tcp.streams import TypingStream from synapse.types import UserID, get_domain_from_id @@ -227,9 +228,9 @@ def _handle_timeout_for_member(self, now: int, member: RoomMember): self._stopped_typing(member) return - async def started_typing(self, target_user, auth_user, room_id, timeout): + async def started_typing(self, target_user, requester, room_id, timeout): target_user_id = target_user.to_string() - auth_user_id = auth_user.to_string() + auth_user_id = requester.user.to_string() if not self.is_mine_id(target_user_id): raise SynapseError(400, "User is not hosted on this homeserver") @@ -237,6 +238,11 @@ async def started_typing(self, target_user, auth_user, room_id, timeout): if target_user_id != auth_user_id: raise AuthError(400, "Cannot set another user's typing state") + if requester.shadow_banned: + # We randomly sleep a bit just to annoy the requester. + await self.clock.sleep(random.randint(1, 10)) + raise ShadowBanError() + await self.auth.check_user_in_room(room_id, target_user_id) logger.debug("%s has started typing in %s", target_user_id, room_id) @@ -256,9 +262,9 @@ async def started_typing(self, target_user, auth_user, room_id, timeout): self._push_update(member=member, typing=True) - async def stopped_typing(self, target_user, auth_user, room_id): + async def stopped_typing(self, target_user, requester, room_id): target_user_id = target_user.to_string() - auth_user_id = auth_user.to_string() + auth_user_id = requester.user.to_string() if not self.is_mine_id(target_user_id): raise SynapseError(400, "User is not hosted on this homeserver") @@ -266,6 +272,11 @@ async def stopped_typing(self, target_user, auth_user, room_id): if target_user_id != auth_user_id: raise AuthError(400, "Cannot set another user's typing state") + if requester.shadow_banned: + # We randomly sleep a bit just to annoy the requester. + await self.clock.sleep(random.randint(1, 10)) + raise ShadowBanError() + await self.auth.check_user_in_room(room_id, target_user_id) logger.debug("%s has stopped typing in %s", target_user_id, room_id) diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py index d43eaf3a2994..047f2c50f78a 100644 --- a/synapse/replication/slave/storage/_slaved_id_tracker.py +++ b/synapse/replication/slave/storage/_slaved_id_tracker.py @@ -21,9 +21,9 @@ def __init__(self, db_conn, table, column, extra_tables=[], step=1): self.step = step self._current = _load_current_id(db_conn, table, column, step) for table, column in extra_tables: - self.advance(_load_current_id(db_conn, table, column)) + self.advance(None, _load_current_id(db_conn, table, column)) - def advance(self, new_id): + def advance(self, instance_name, new_id): self._current = (max if self.step > 0 else min)(self._current, new_id) def get_current_token(self): diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py index 154f0e687c58..bb66ba9b80f8 100644 --- a/synapse/replication/slave/storage/account_data.py +++ b/synapse/replication/slave/storage/account_data.py @@ -41,12 +41,12 @@ def get_max_account_data_stream_id(self): def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == TagAccountDataStream.NAME: - self._account_data_id_gen.advance(token) + self._account_data_id_gen.advance(instance_name, token) for row in rows: self.get_tags_for_user.invalidate((row.user_id,)) self._account_data_stream_cache.entity_has_changed(row.user_id, token) elif stream_name == AccountDataStream.NAME: - self._account_data_id_gen.advance(token) + self._account_data_id_gen.advance(instance_name, token) for row in rows: if not row.room_id: self.get_global_account_data_by_type_for_user.invalidate( diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py index ee7f69a91816..533d927701d3 100644 --- a/synapse/replication/slave/storage/deviceinbox.py +++ b/synapse/replication/slave/storage/deviceinbox.py @@ -46,7 +46,7 @@ def __init__(self, database: DatabasePool, db_conn, hs): def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == ToDeviceStream.NAME: - self._device_inbox_id_gen.advance(token) + self._device_inbox_id_gen.advance(instance_name, token) for row in rows: if row.entity.startswith("@"): self._device_inbox_stream_cache.entity_has_changed( diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index 722f3745e9bc..596c72eb92af 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -50,10 +50,10 @@ def __init__(self, database: DatabasePool, db_conn, hs): def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == DeviceListsStream.NAME: - self._device_list_id_gen.advance(token) + self._device_list_id_gen.advance(instance_name, token) self._invalidate_caches_for_devices(token, rows) elif stream_name == UserSignatureStream.NAME: - self._device_list_id_gen.advance(token) + self._device_list_id_gen.advance(instance_name, token) for row in rows: self._user_signature_stream_cache.entity_has_changed(row.user_id, token) return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py index 3291558c7a76..567b4a5cc1cc 100644 --- a/synapse/replication/slave/storage/groups.py +++ b/synapse/replication/slave/storage/groups.py @@ -40,7 +40,7 @@ def get_group_stream_token(self): def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == GroupServerStream.NAME: - self._group_updates_id_gen.advance(token) + self._group_updates_id_gen.advance(instance_name, token) for row in rows: self._group_updates_stream_cache.entity_has_changed(row.user_id, token) diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py index a912c04360e1..025f6f6be8e6 100644 --- a/synapse/replication/slave/storage/presence.py +++ b/synapse/replication/slave/storage/presence.py @@ -44,7 +44,7 @@ def get_current_presence_token(self): def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == PresenceStream.NAME: - self._presence_id_gen.advance(token) + self._presence_id_gen.advance(instance_name, token) for row in rows: self.presence_stream_cache.entity_has_changed(row.user_id, token) self._get_presence_for_user.invalidate((row.user_id,)) diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py index 90d90833f989..de904c943cc0 100644 --- a/synapse/replication/slave/storage/push_rule.py +++ b/synapse/replication/slave/storage/push_rule.py @@ -30,7 +30,7 @@ def process_replication_rows(self, stream_name, instance_name, token, rows): assert isinstance(self._push_rules_stream_id_gen, SlavedIdTracker) if stream_name == PushRulesStream.NAME: - self._push_rules_stream_id_gen.advance(token) + self._push_rules_stream_id_gen.advance(instance_name, token) for row in rows: self.get_push_rules_for_user.invalidate((row.user_id,)) self.get_push_rules_enabled_for_user.invalidate((row.user_id,)) diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py index 63300e5da608..9da218bfe855 100644 --- a/synapse/replication/slave/storage/pushers.py +++ b/synapse/replication/slave/storage/pushers.py @@ -34,5 +34,5 @@ def get_pushers_stream_token(self): def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == PushersStream.NAME: - self._pushers_id_gen.advance(token) + self._pushers_id_gen.advance(instance_name, token) return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py index 17ba1f22ac47..5c2986e05017 100644 --- a/synapse/replication/slave/storage/receipts.py +++ b/synapse/replication/slave/storage/receipts.py @@ -46,7 +46,7 @@ def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id): def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == ReceiptsStream.NAME: - self._receipts_id_gen.advance(token) + self._receipts_id_gen.advance(instance_name, token) for row in rows: self.invalidate_caches_for_receipt( row.room_id, row.receipt_type, row.user_id diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py index 427c81772b51..80ae803ad9ab 100644 --- a/synapse/replication/slave/storage/room.py +++ b/synapse/replication/slave/storage/room.py @@ -33,6 +33,6 @@ def get_current_public_room_stream_id(self): def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == PublicRoomsStream.NAME: - self._public_room_id_gen.advance(token) + self._public_room_id_gen.advance(instance_name, token) return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 7ed1ccb5a059..3929d5519d75 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -171,7 +171,6 @@ async def on_GET(self, request, room_id, event_type, state_key): room_id=room_id, event_type=event_type, state_key=state_key, - is_guest=requester.is_guest, ) if not data: @@ -870,17 +869,21 @@ async def on_PUT(self, request, room_id, user_id): # Limit timeout to stop people from setting silly typing timeouts. timeout = min(content.get("timeout", 30000), 120000) - if content["typing"]: - await self.typing_handler.started_typing( - target_user=target_user, - auth_user=requester.user, - room_id=room_id, - timeout=timeout, - ) - else: - await self.typing_handler.stopped_typing( - target_user=target_user, auth_user=requester.user, room_id=room_id - ) + try: + if content["typing"]: + await self.typing_handler.started_typing( + target_user=target_user, + requester=requester, + room_id=room_id, + timeout=timeout, + ) + else: + await self.typing_handler.stopped_typing( + target_user=target_user, requester=requester, room_id=room_id + ) + except ShadowBanError: + # Pretend this worked without error. + pass return 200, {} diff --git a/synapse/storage/database.py b/synapse/storage/database.py index bc327e344e63..181c3ec24994 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -29,9 +29,11 @@ Tuple, TypeVar, Union, + overload, ) from prometheus_client import Histogram +from typing_extensions import Literal from twisted.enterprise import adbapi from twisted.internet import defer @@ -1020,14 +1022,36 @@ def simple_upsert_many_txn_native_upsert( return txn.execute_batch(sql, args) - def simple_select_one( + @overload + async def simple_select_one( + self, + table: str, + keyvalues: Dict[str, Any], + retcols: Iterable[str], + allow_none: Literal[False] = False, + desc: str = "simple_select_one", + ) -> Dict[str, Any]: + ... + + @overload + async def simple_select_one( + self, + table: str, + keyvalues: Dict[str, Any], + retcols: Iterable[str], + allow_none: Literal[True] = True, + desc: str = "simple_select_one", + ) -> Optional[Dict[str, Any]]: + ... + + async def simple_select_one( self, table: str, keyvalues: Dict[str, Any], retcols: Iterable[str], allow_none: bool = False, desc: str = "simple_select_one", - ) -> defer.Deferred: + ) -> Optional[Dict[str, Any]]: """Executes a SELECT query on the named table, which is expected to return a single row, returning multiple columns from it. @@ -1038,18 +1062,18 @@ def simple_select_one( allow_none: If true, return None instead of failing if the SELECT statement returns no rows """ - return self.runInteraction( + return await self.runInteraction( desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none ) - def simple_select_one_onecol( + async def simple_select_one_onecol( self, table: str, keyvalues: Dict[str, Any], retcol: Iterable[str], allow_none: bool = False, desc: str = "simple_select_one_onecol", - ) -> defer.Deferred: + ) -> Optional[Any]: """Executes a SELECT query on the named table, which is expected to return a single row, returning a single column from it. @@ -1061,7 +1085,7 @@ def simple_select_one_onecol( statement returns no rows desc: description of the transaction, for logging and metrics """ - return self.runInteraction( + return await self.runInteraction( desc, self.simple_select_one_onecol_txn, table, diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 03b45dbc4d51..a811a39eb524 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Dict, Iterable, List, Optional, Set, Tuple +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple from synapse.api.errors import Codes, StoreError from synapse.logging.opentracing import ( @@ -47,7 +47,7 @@ class DeviceWorkerStore(SQLBaseStore): - def get_device(self, user_id: str, device_id: str): + async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]: """Retrieve a device. Only returns devices that are not marked as hidden. @@ -55,11 +55,11 @@ def get_device(self, user_id: str, device_id: str): user_id: The ID of the user which owns the device device_id: The ID of the device to retrieve Returns: - defer.Deferred for a dict containing the device information + A dict containing the device information Raises: StoreError: if the device is not found """ - return self.db_pool.simple_select_one( + return await self.db_pool.simple_select_one( table="devices", keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, retcols=("user_id", "device_id", "display_name"), @@ -656,11 +656,13 @@ def _get_all_device_list_changes_for_remotes(txn): ) @cached(max_entries=10000) - def get_device_list_last_stream_id_for_remote(self, user_id: str): + async def get_device_list_last_stream_id_for_remote( + self, user_id: str + ) -> Optional[Any]: """Get the last stream_id we got for a user. May be None if we haven't got any information for them. """ - return self.db_pool.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( table="device_lists_remote_extremeties", keyvalues={"user_id": user_id}, retcol="stream_id", diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py index 037e02603c7b..301d5d845ac8 100644 --- a/synapse/storage/databases/main/directory.py +++ b/synapse/storage/databases/main/directory.py @@ -59,8 +59,8 @@ async def get_association_from_room_alias( return RoomAliasMapping(room_id, room_alias.to_string(), servers) - def get_room_alias_creator(self, room_alias): - return self.db_pool.simple_select_one_onecol( + async def get_room_alias_creator(self, room_alias: str) -> str: + return await self.db_pool.simple_select_one_onecol( table="room_aliases", keyvalues={"room_alias": room_alias}, retcol="creator", diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py index 2eeb9f97dc14..46c3e33cc667 100644 --- a/synapse/storage/databases/main/e2e_room_keys.py +++ b/synapse/storage/databases/main/e2e_room_keys.py @@ -223,15 +223,15 @@ def _get_e2e_room_keys_multi_txn(txn, user_id, version, room_keys): return ret - def count_e2e_room_keys(self, user_id, version): + async def count_e2e_room_keys(self, user_id: str, version: str) -> int: """Get the number of keys in a backup version. Args: - user_id (str): the user whose backup we're querying - version (str): the version ID of the backup we're querying about + user_id: the user whose backup we're querying + version: the version ID of the backup we're querying about """ - return self.db_pool.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( table="e2e_room_keys", keyvalues={"user_id": user_id, "version": version}, retcol="COUNT(*)", diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index e1241a724b67..e6247d682d8c 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -113,25 +113,25 @@ def __init__(self, database: DatabasePool, db_conn, hs): def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == EventsStream.NAME: - self._stream_id_gen.advance(token) + self._stream_id_gen.advance(instance_name, token) elif stream_name == BackfillStream.NAME: - self._backfill_id_gen.advance(-token) + self._backfill_id_gen.advance(instance_name, -token) super().process_replication_rows(stream_name, instance_name, token, rows) - def get_received_ts(self, event_id): + async def get_received_ts(self, event_id: str) -> Optional[int]: """Get received_ts (when it was persisted) for the event. Raises an exception for unknown events. Args: - event_id (str) + event_id: The event ID to query. Returns: - Deferred[int|None]: Timestamp in milliseconds, or None for events - that were persisted before received_ts was implemented. + Timestamp in milliseconds, or None for events that were persisted + before received_ts was implemented. """ - return self.db_pool.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( table="events", keyvalues={"event_id": event_id}, retcol="received_ts", diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py index a488e0924b66..c39864f59f8d 100644 --- a/synapse/storage/databases/main/group_server.py +++ b/synapse/storage/databases/main/group_server.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple from synapse.api.errors import SynapseError from synapse.storage._base import SQLBaseStore, db_to_json @@ -28,8 +28,8 @@ class GroupServerWorkerStore(SQLBaseStore): - def get_group(self, group_id): - return self.db_pool.simple_select_one( + async def get_group(self, group_id: str) -> Optional[Dict[str, Any]]: + return await self.db_pool.simple_select_one( table="groups", keyvalues={"group_id": group_id}, retcols=( @@ -351,8 +351,10 @@ async def is_user_in_group(self, user_id: str, group_id: str) -> bool: ) return bool(result) - def is_user_admin_in_group(self, group_id, user_id): - return self.db_pool.simple_select_one_onecol( + async def is_user_admin_in_group( + self, group_id: str, user_id: str + ) -> Optional[bool]: + return await self.db_pool.simple_select_one_onecol( table="group_users", keyvalues={"group_id": group_id, "user_id": user_id}, retcol="is_admin", @@ -360,10 +362,12 @@ def is_user_admin_in_group(self, group_id, user_id): desc="is_user_admin_in_group", ) - def is_user_invited_to_local_group(self, group_id, user_id): + async def is_user_invited_to_local_group( + self, group_id: str, user_id: str + ) -> Optional[bool]: """Has the group server invited a user? """ - return self.db_pool.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( table="group_invites", keyvalues={"group_id": group_id, "user_id": user_id}, retcol="user_id", diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index 80fc1cd0092a..4ae255ebd8f5 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -12,6 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any, Dict, Optional + from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool @@ -37,12 +39,13 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): def __init__(self, database: DatabasePool, db_conn, hs): super(MediaRepositoryStore, self).__init__(database, db_conn, hs) - def get_local_media(self, media_id): + async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]: """Get the metadata for a local piece of media + Returns: None if the media_id doesn't exist. """ - return self.db_pool.simple_select_one( + return await self.db_pool.simple_select_one( "local_media_repository", {"media_id": media_id}, ( @@ -191,8 +194,10 @@ def store_local_thumbnail( desc="store_local_thumbnail", ) - def get_cached_remote_media(self, origin, media_id): - return self.db_pool.simple_select_one( + async def get_cached_remote_media( + self, origin, media_id: str + ) -> Optional[Dict[str, Any]]: + return await self.db_pool.simple_select_one( "remote_media_cache", {"media_origin": origin, "media_id": media_id}, ( diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py index e71cdd2cb4e2..fe30552c08ef 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py @@ -99,17 +99,18 @@ async def get_registered_reserved_users(self) -> List[str]: return users @cached(num_args=1) - def user_last_seen_monthly_active(self, user_id): + async def user_last_seen_monthly_active(self, user_id: str) -> int: """ - Checks if a given user is part of the monthly active user group - Arguments: - user_id (str): user to add/update - Return: - Deferred[int] : timestamp since last seen, None if never seen + Checks if a given user is part of the monthly active user group + Arguments: + user_id: user to add/update + + Return: + Timestamp since last seen, None if never seen """ - return self.db_pool.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( table="monthly_active_users", keyvalues={"user_id": user_id}, retcol="timestamp", diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py index 086cfbeed4cb..f607b823f8f8 100644 --- a/synapse/storage/databases/main/profile.py +++ b/synapse/storage/databases/main/profile.py @@ -13,8 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from typing import List, Tuple +from typing import Any, Dict, List, Optional, Tuple from synapse.api.errors import StoreError from synapse.storage._base import SQLBaseStore @@ -26,7 +25,7 @@ class ProfileWorkerStore(SQLBaseStore): - async def get_profileinfo(self, user_localpart): + async def get_profileinfo(self, user_localpart: str) -> ProfileInfo: try: profile = await self.db_pool.simple_select_one( table="profiles", @@ -46,8 +45,8 @@ async def get_profileinfo(self, user_localpart): ) @cached(max_entries=5000) - def get_profile_displayname(self, user_localpart): - return self.db_pool.simple_select_one_onecol( + async def get_profile_displayname(self, user_localpart: str) -> str: + return await self.db_pool.simple_select_one_onecol( table="profiles", keyvalues={"user_id": user_localpart}, retcol="displayname", @@ -55,33 +54,33 @@ def get_profile_displayname(self, user_localpart): ) @cached(max_entries=5000) - def get_profile_avatar_url(self, user_localpart): - return self.db_pool.simple_select_one_onecol( + async def get_profile_avatar_url(self, user_localpart: str) -> str: + return await self.db_pool.simple_select_one_onecol( table="profiles", keyvalues={"user_id": user_localpart}, retcol="avatar_url", desc="get_profile_avatar_url", ) - def get_latest_profile_replication_batch_number(self): + async def get_latest_profile_replication_batch_number(self): def f(txn): txn.execute("SELECT MAX(batch) as maxbatch FROM profiles") rows = self.db_pool.cursor_to_dict(txn) return rows[0]["maxbatch"] - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_latest_profile_replication_batch_number", f ) - def get_profile_batch(self, batchnum): - return self.db_pool.simple_select_list( + async def get_profile_batch(self, batchnum): + return await self.db_pool.simple_select_list( table="profiles", keyvalues={"batch": batchnum}, retcols=("user_id", "displayname", "avatar_url", "active"), desc="get_profile_batch", ) - def assign_profile_batch(self): + async def assign_profile_batch(self): def f(txn): sql = ( "UPDATE profiles SET batch = " @@ -93,9 +92,9 @@ def f(txn): txn.execute(sql, (BATCH_SIZE,)) return txn.rowcount - return self.db_pool.runInteraction("assign_profile_batch", f) + return await self.db_pool.runInteraction("assign_profile_batch", f) - def get_replication_hosts(self): + async def get_replication_hosts(self): def f(txn): txn.execute( "SELECT host, last_synced_batch FROM profile_replication_status" @@ -103,18 +102,22 @@ def f(txn): rows = self.db_pool.cursor_to_dict(txn) return {r["host"]: r["last_synced_batch"] for r in rows} - return self.db_pool.runInteraction("get_replication_hosts", f) + return await self.db_pool.runInteraction("get_replication_hosts", f) - def update_replication_batch_for_host(self, host, last_synced_batch): - return self.db_pool.simple_upsert( + async def update_replication_batch_for_host( + self, host: str, last_synced_batch: int + ): + return await self.db_pool.simple_upsert( table="profile_replication_status", keyvalues={"host": host}, values={"last_synced_batch": last_synced_batch}, desc="update_replication_batch_for_host", ) - def get_from_remote_profile_cache(self, user_id): - return self.db_pool.simple_select_one( + async def get_from_remote_profile_cache( + self, user_id: str + ) -> Optional[Dict[str, Any]]: + return await self.db_pool.simple_select_one( table="remote_profile_cache", keyvalues={"user_id": user_id}, retcols=("displayname", "avatar_url"), @@ -151,9 +154,9 @@ def set_profile_avatar_url(self, user_localpart, new_avatar_url, batchnum): lock=False, # we can do this because user_id has a unique index ) - def set_profiles_active( + async def set_profiles_active( self, users: List[UserID], active: bool, hide: bool, batchnum: int, - ): + ) -> None: """Given a set of users, set active and hidden flags on them. Args: @@ -163,9 +166,6 @@ def set_profiles_active( False and active is False, users will have their profiles erased batchnum: The batch number, used for profile replication - - Returns: - Deferred """ # Convert list of localparts to list of tuples containing localparts user_localparts = [(user.localpart,) for user in users] @@ -180,7 +180,7 @@ def set_profiles_active( value_names += ("avatar_url", "displayname") values = [v + (None, None) for v in values] - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "set_profiles_active", self.db_pool.simple_upsert_many_txn, table="profiles", @@ -225,7 +225,7 @@ def update_remote_profile_cache(self, user_id, displayname, avatar_url): return self.db_pool.simple_upsert( table="remote_profile_cache", keyvalues={"user_id": user_id}, - updatevalues={ + values={ "displayname": displayname, "avatar_url": avatar_url, "last_check": self._clock.time_msec(), diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 6821476ee084..cea5ac9a6862 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -71,8 +71,10 @@ def get_receipts_for_room(self, room_id, receipt_type): ) @cached(num_args=3) - def get_last_receipt_event_id_for_user(self, user_id, room_id, receipt_type): - return self.db_pool.simple_select_one_onecol( + async def get_last_receipt_event_id_for_user( + self, user_id: str, room_id: str, receipt_type: str + ) -> Optional[str]: + return await self.db_pool.simple_select_one_onecol( table="receipts_linearized", keyvalues={ "room_id": room_id, diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 336b578e23cf..48c979aeea1e 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -17,7 +17,7 @@ import logging import re -from typing import Awaitable, Dict, List, Optional +from typing import Any, Awaitable, Dict, List, Optional from synapse.api.constants import UserTypes from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError @@ -46,8 +46,8 @@ def __init__(self, database: DatabasePool, db_conn, hs): ) @cached() - def get_user_by_id(self, user_id): - return self.db_pool.simple_select_one( + async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]: + return await self.db_pool.simple_select_one( table="users", keyvalues={"name": user_id}, retcols=[ @@ -1338,12 +1338,12 @@ def del_user_pending_deactivation(self, user_id): desc="del_user_pending_deactivation", ) - def get_user_pending_deactivation(self): + async def get_user_pending_deactivation(self) -> Optional[str]: """ Gets one user from the table of users waiting to be parted from all the rooms they're in. """ - return self.db_pool.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( "users_pending_deactivation", keyvalues={}, retcol="user_id", diff --git a/synapse/storage/databases/main/rejections.py b/synapse/storage/databases/main/rejections.py index cf9ba5120594..1e361aaa9a73 100644 --- a/synapse/storage/databases/main/rejections.py +++ b/synapse/storage/databases/main/rejections.py @@ -14,6 +14,7 @@ # limitations under the License. import logging +from typing import Optional from synapse.storage._base import SQLBaseStore @@ -21,8 +22,8 @@ class RejectionsStore(SQLBaseStore): - def get_rejection_reason(self, event_id): - return self.db_pool.simple_select_one_onecol( + async def get_rejection_reason(self, event_id: str) -> Optional[str]: + return await self.db_pool.simple_select_one_onecol( table="rejections", retcol="reason", keyvalues={"event_id": event_id}, diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 99a8a9fab0fd..dc97f70c66cb 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -73,15 +73,15 @@ def __init__(self, database: DatabasePool, db_conn, hs): self.config = hs.config - def get_room(self, room_id): + async def get_room(self, room_id: str) -> dict: """Retrieve a room. Args: - room_id (str): The ID of the room to retrieve. + room_id: The ID of the room to retrieve. Returns: A dict containing the room information, or None if the room is unknown. """ - return self.db_pool.simple_select_one( + return await self.db_pool.simple_select_one( table="rooms", keyvalues={"room_id": room_id}, retcols=("room_id", "is_public", "creator"), @@ -330,8 +330,8 @@ def _get_largest_public_rooms_txn(txn): return ret_val @cached(max_entries=10000) - def is_room_blocked(self, room_id): - return self.db_pool.simple_select_one_onecol( + async def is_room_blocked(self, room_id: str) -> Optional[bool]: + return await self.db_pool.simple_select_one_onecol( table="blocked_rooms", keyvalues={"room_id": room_id}, retcol="1", diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index 991233a9bcae..458f169617e1 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -260,8 +260,8 @@ async def get_canonical_alias_for_room(self, room_id: str) -> Optional[str]: return event.content.get("canonical_alias") @cached(max_entries=50000) - def _get_state_group_for_event(self, event_id): - return self.db_pool.simple_select_one_onecol( + async def _get_state_group_for_event(self, event_id: str) -> Optional[int]: + return await self.db_pool.simple_select_one_onecol( table="event_to_state_groups", keyvalues={"event_id": event_id}, retcol="state_group", diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py index 802c9019b9f4..9fe97af56adb 100644 --- a/synapse/storage/databases/main/stats.py +++ b/synapse/storage/databases/main/stats.py @@ -211,11 +211,11 @@ def _get_next_batch(txn): return len(rooms_to_work_on) - def get_stats_positions(self): + async def get_stats_positions(self) -> int: """ Returns the stats processor positions. """ - return self.db_pool.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( table="stats_incremental_position", keyvalues={}, retcol="stream_id", @@ -300,7 +300,7 @@ def _get_statistics_for_subject_txn( return slice_list @cached() - def get_earliest_token_for_stats(self, stats_type, id): + async def get_earliest_token_for_stats(self, stats_type: str, id: str) -> int: """ Fetch the "earliest token". This is used by the room stats delta processor to ignore deltas that have been processed between the @@ -308,11 +308,11 @@ def get_earliest_token_for_stats(self, stats_type, id): being calculated. Returns: - Deferred[int] + The earliest token. """ table, id_col = TYPE_TO_TABLE[stats_type] - return self.db_pool.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( "%s_current" % (table,), keyvalues={id_col: id}, retcol="completed_delta_stream_id", diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index af21fe457adb..20cbcd851c04 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -15,6 +15,7 @@ import logging import re +from typing import Any, Dict, Optional from synapse.api.constants import EventTypes, JoinRules from synapse.storage.database import DatabasePool @@ -527,8 +528,8 @@ def _delete_all_from_user_dir_txn(txn): ) @cached() - def get_user_in_directory(self, user_id): - return self.db_pool.simple_select_one( + async def get_user_in_directory(self, user_id: str) -> Optional[Dict[str, Any]]: + return await self.db_pool.simple_select_one( table="user_directory", keyvalues={"user_id": user_id}, retcols=("display_name", "avatar_url"), @@ -663,8 +664,8 @@ async def get_user_dir_rooms_user_is_in(self, user_id): users.update(rows) return list(users) - def get_user_directory_stream_pos(self): - return self.db_pool.simple_select_one_onecol( + async def get_user_directory_stream_pos(self) -> int: + return await self.db_pool.simple_select_one_onecol( table="user_directory_stream_pos", keyvalues={}, retcol="stream_id", diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index d7f0c19c4cb9..3cd4f71d40e5 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -106,7 +106,12 @@ def test_set_my_name(self): ) self.assertEquals( - (yield self.store.get_profile_displayname(self.frank.localpart)), "Frank", + ( + yield defer.ensureDeferred( + self.store.get_profile_displayname(self.frank.localpart) + ) + ), + "Frank", ) @defer.inlineCallbacks @@ -201,7 +206,11 @@ def test_set_my_avatar(self): ) self.assertEquals( - (yield self.store.get_profile_avatar_url(self.frank.localpart)), + ( + yield defer.ensureDeferred( + self.store.get_profile_avatar_url(self.frank.localpart) + ) + ), "http://my.server/pic.gif", ) @@ -215,7 +224,11 @@ def test_set_my_avatar(self): ) self.assertEquals( - (yield self.store.get_profile_avatar_url(self.frank.localpart)), + ( + yield defer.ensureDeferred( + self.store.get_profile_avatar_url(self.frank.localpart) + ) + ), "http://my.server/me.png", ) @@ -231,7 +244,11 @@ def test_set_my_avatar_if_disabled(self): ) self.assertEquals( - (yield self.store.get_profile_avatar_url(self.frank.localpart)), + ( + yield defer.ensureDeferred( + self.store.get_profile_avatar_url(self.frank.localpart) + ) + ), "http://my.server/me.png", ) diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py index 4b627dac00b1..2b7eeef12947 100644 --- a/tests/handlers/test_stats.py +++ b/tests/handlers/test_stats.py @@ -262,7 +262,7 @@ def test_initial_earliest_token(self): # self.handler.notify_new_event() # We need to let the delta processor advanceā€¦ - self.pump(10 * 60) + self.reactor.advance(10 * 60) # Get the slices! There should be two -- day 1, and day 2. r = self.get_success(self.store.get_statistics_for_subject("room", room_1, 0)) diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index e01de158e5f1..81c1839637eb 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -21,7 +21,7 @@ from twisted.internet import defer from synapse.api.errors import AuthError -from synapse.types import UserID +from synapse.types import UserID, create_requester from tests import unittest from tests.test_utils import make_awaitable @@ -144,9 +144,9 @@ def get_users_in_room(room_id): self.datastore.get_users_in_room = get_users_in_room - self.datastore.get_user_directory_stream_pos.return_value = ( + self.datastore.get_user_directory_stream_pos.side_effect = ( # we deliberately return a non-None stream pos to avoid doing an initial_spam - defer.succeed(1) + lambda: make_awaitable(1) ) self.datastore.get_current_state_deltas.return_value = (0, None) @@ -167,7 +167,10 @@ def test_started_typing_local(self): self.get_success( self.handler.started_typing( - target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=20000 + target_user=U_APPLE, + requester=create_requester(U_APPLE), + room_id=ROOM_ID, + timeout=20000, ) ) @@ -194,7 +197,10 @@ def test_started_typing_remote_send(self): self.get_success( self.handler.started_typing( - target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=20000 + target_user=U_APPLE, + requester=create_requester(U_APPLE), + room_id=ROOM_ID, + timeout=20000, ) ) @@ -269,7 +275,9 @@ def test_stopped_typing(self): self.get_success( self.handler.stopped_typing( - target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID + target_user=U_APPLE, + requester=create_requester(U_APPLE), + room_id=ROOM_ID, ) ) @@ -309,7 +317,10 @@ def test_typing_timeout(self): self.get_success( self.handler.started_typing( - target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=10000 + target_user=U_APPLE, + requester=create_requester(U_APPLE), + room_id=ROOM_ID, + timeout=10000, ) ) @@ -348,7 +359,10 @@ def test_typing_timeout(self): self.get_success( self.handler.started_typing( - target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=10000 + target_user=U_APPLE, + requester=create_requester(U_APPLE), + room_id=ROOM_ID, + timeout=10000, ) ) diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py index ac598249e405..7f1dfb2128d7 100644 --- a/tests/http/test_fedclient.py +++ b/tests/http/test_fedclient.py @@ -508,6 +508,6 @@ def test_closes_connection(self): self.assertFalse(conn.disconnecting) # wait for a while - self.pump(120) + self.reactor.advance(120) self.assertTrue(conn.disconnecting) diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py index 9c778a0e4561..ccbb82f6a362 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py @@ -47,7 +47,7 @@ def test_can_register_user(self): # Check that the new user exists with all provided attributes self.assertEqual(user_id, "@bob:test") self.assertTrue(access_token) - self.assertTrue(self.store.get_user_by_id(user_id)) + self.assertTrue(self.get_success(self.store.get_user_by_id(user_id))) # Check that the email was assigned emails = self.get_success(self.store.user_get_threepids(user_id)) diff --git a/tests/push/test_email.py b/tests/push/test_email.py index 83032cc9eab6..227b0d32d047 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py @@ -170,7 +170,7 @@ def _check_for_mail(self): last_stream_ordering = pushers[0]["last_stream_ordering"] # Advance time a bit, so the pusher will register something has happened - self.pump(100) + self.pump(10) # It hasn't succeeded yet, so the stream ordering shouldn't have moved pushers = self.get_success( diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py index 83f9aa291c67..8b4982ecb160 100644 --- a/tests/replication/test_federation_sender_shard.py +++ b/tests/replication/test_federation_sender_shard.py @@ -20,7 +20,7 @@ from synapse.events.builder import EventBuilderFactory from synapse.rest.admin import register_servlets_for_client_rest_resource from synapse.rest.client.v1 import login, room -from synapse.types import UserID +from synapse.types import UserID, create_requester from tests.replication._base import BaseMultiWorkerStreamTestCase from tests.test_utils import make_awaitable @@ -175,7 +175,7 @@ def test_send_typing_sharded(self): self.get_success( typing_handler.started_typing( target_user=UserID.from_string(user), - auth_user=UserID.from_string(user), + requester=create_requester(user), room_id=room, timeout=20000, ) diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py index c7e287c61e04..47c0d5634cd2 100644 --- a/tests/rest/client/test_retention.py +++ b/tests/rest/client/test_retention.py @@ -179,7 +179,7 @@ def _test_retention_event_purged(self, room_id: str, increment: float): message_handler = self.hs.get_message_handler() create_event = self.get_success( message_handler.get_room_data( - self.user_id, room_id, EventTypes.Create, state_key="", is_guest=False + self.user_id, room_id, EventTypes.Create, state_key="" ) ) diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py new file mode 100644 index 000000000000..dfe4bf7762e3 --- /dev/null +++ b/tests/rest/client/test_shadow_banned.py @@ -0,0 +1,312 @@ +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from mock import Mock, patch + +import synapse.rest.admin +from synapse.api.constants import EventTypes +from synapse.rest.client.v1 import directory, login, profile, room +from synapse.rest.client.v2_alpha import room_upgrade_rest_servlet + +from tests import unittest + + +class _ShadowBannedBase(unittest.HomeserverTestCase): + def prepare(self, reactor, clock, homeserver): + # Create two users, one of which is shadow-banned. + self.banned_user_id = self.register_user("banned", "test") + self.banned_access_token = self.login("banned", "test") + + self.store = self.hs.get_datastore() + + self.get_success( + self.store.db_pool.simple_update( + table="users", + keyvalues={"name": self.banned_user_id}, + updatevalues={"shadow_banned": True}, + desc="shadow_ban", + ) + ) + + self.other_user_id = self.register_user("otheruser", "pass") + self.other_access_token = self.login("otheruser", "pass") + + +# To avoid the tests timing out don't add a delay to "annoy the requester". +@patch("random.randint", new=lambda a, b: 0) +class RoomTestCase(_ShadowBannedBase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + directory.register_servlets, + login.register_servlets, + room.register_servlets, + room_upgrade_rest_servlet.register_servlets, + ] + + def test_invite(self): + """Invites from shadow-banned users don't actually get sent.""" + + # The create works fine. + room_id = self.helper.create_room_as( + self.banned_user_id, tok=self.banned_access_token + ) + + # Inviting the user completes successfully. + self.helper.invite( + room=room_id, + src=self.banned_user_id, + tok=self.banned_access_token, + targ=self.other_user_id, + ) + + # But the user wasn't actually invited. + invited_rooms = self.get_success( + self.store.get_invited_rooms_for_local_user(self.other_user_id) + ) + self.assertEqual(invited_rooms, []) + + def test_invite_3pid(self): + """Ensure that a 3PID invite does not attempt to contact the identity server.""" + identity_handler = self.hs.get_handlers().identity_handler + identity_handler.lookup_3pid = Mock( + side_effect=AssertionError("This should not get called") + ) + + # The create works fine. + room_id = self.helper.create_room_as( + self.banned_user_id, tok=self.banned_access_token + ) + + # Inviting the user completes successfully. + request, channel = self.make_request( + "POST", + "/rooms/%s/invite" % (room_id,), + {"id_server": "test", "medium": "email", "address": "test@test.test"}, + access_token=self.banned_access_token, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.result) + + # This should have raised an error earlier, but double check this wasn't called. + identity_handler.lookup_3pid.assert_not_called() + + def test_create_room(self): + """Invitations during a room creation should be discarded, but the room still gets created.""" + # The room creation is successful. + request, channel = self.make_request( + "POST", + "/_matrix/client/r0/createRoom", + {"visibility": "public", "invite": [self.other_user_id]}, + access_token=self.banned_access_token, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.result) + room_id = channel.json_body["room_id"] + + # But the user wasn't actually invited. + invited_rooms = self.get_success( + self.store.get_invited_rooms_for_local_user(self.other_user_id) + ) + self.assertEqual(invited_rooms, []) + + # Since a real room was created, the other user should be able to join it. + self.helper.join(room_id, self.other_user_id, tok=self.other_access_token) + + # Both users should be in the room. + users = self.get_success(self.store.get_users_in_room(room_id)) + self.assertCountEqual(users, ["@banned:test", "@otheruser:test"]) + + def test_message(self): + """Messages from shadow-banned users don't actually get sent.""" + + room_id = self.helper.create_room_as( + self.other_user_id, tok=self.other_access_token + ) + + # The user should be in the room. + self.helper.join(room_id, self.banned_user_id, tok=self.banned_access_token) + + # Sending a message should complete successfully. + result = self.helper.send_event( + room_id=room_id, + type=EventTypes.Message, + content={"msgtype": "m.text", "body": "with right label"}, + tok=self.banned_access_token, + ) + self.assertIn("event_id", result) + event_id = result["event_id"] + + latest_events = self.get_success( + self.store.get_latest_event_ids_in_room(room_id) + ) + self.assertNotIn(event_id, latest_events) + + def test_upgrade(self): + """A room upgrade should fail, but look like it succeeded.""" + + # The create works fine. + room_id = self.helper.create_room_as( + self.banned_user_id, tok=self.banned_access_token + ) + + request, channel = self.make_request( + "POST", + "/_matrix/client/r0/rooms/%s/upgrade" % (room_id,), + {"new_version": "6"}, + access_token=self.banned_access_token, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.result) + # A new room_id should be returned. + self.assertIn("replacement_room", channel.json_body) + + new_room_id = channel.json_body["replacement_room"] + + # It doesn't really matter what API we use here, we just want to assert + # that the room doesn't exist. + summary = self.get_success(self.store.get_room_summary(new_room_id)) + # The summary should be empty since the room doesn't exist. + self.assertEqual(summary, {}) + + def test_typing(self): + """Typing notifications should not be propagated into the room.""" + # The create works fine. + room_id = self.helper.create_room_as( + self.banned_user_id, tok=self.banned_access_token + ) + + request, channel = self.make_request( + "PUT", + "/rooms/%s/typing/%s" % (room_id, self.banned_user_id), + {"typing": True, "timeout": 30000}, + access_token=self.banned_access_token, + ) + self.render(request) + self.assertEquals(200, channel.code) + + # There should be no typing events. + event_source = self.hs.get_event_sources().sources["typing"] + self.assertEquals(event_source.get_current_key(), 0) + + # The other user can join and send typing events. + self.helper.join(room_id, self.other_user_id, tok=self.other_access_token) + + request, channel = self.make_request( + "PUT", + "/rooms/%s/typing/%s" % (room_id, self.other_user_id), + {"typing": True, "timeout": 30000}, + access_token=self.other_access_token, + ) + self.render(request) + self.assertEquals(200, channel.code) + + # These appear in the room. + self.assertEquals(event_source.get_current_key(), 1) + events = self.get_success( + event_source.get_new_events(from_key=0, room_ids=[room_id]) + ) + self.assertEquals( + events[0], + [ + { + "type": "m.typing", + "room_id": room_id, + "content": {"user_ids": [self.other_user_id]}, + } + ], + ) + + +# To avoid the tests timing out don't add a delay to "annoy the requester". +@patch("random.randint", new=lambda a, b: 0) +class ProfileTestCase(_ShadowBannedBase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + profile.register_servlets, + room.register_servlets, + ] + + def test_displayname(self): + """Profile changes should succeed, but don't end up in a room.""" + original_display_name = "banned" + new_display_name = "new name" + + # Join a room. + room_id = self.helper.create_room_as( + self.banned_user_id, tok=self.banned_access_token + ) + + # The update should succeed. + request, channel = self.make_request( + "PUT", + "/_matrix/client/r0/profile/%s/displayname" % (self.banned_user_id,), + {"displayname": new_display_name}, + access_token=self.banned_access_token, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.result) + self.assertEqual(channel.json_body, {}) + + # The user's display name should be updated. + request, channel = self.make_request( + "GET", "/profile/%s/displayname" % (self.banned_user_id,) + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.json_body["displayname"], new_display_name) + + # But the display name in the room should not be. + message_handler = self.hs.get_message_handler() + event = self.get_success( + message_handler.get_room_data( + self.banned_user_id, room_id, "m.room.member", self.banned_user_id, + ) + ) + self.assertEqual( + event.content, {"membership": "join", "displayname": original_display_name} + ) + + def test_room_displayname(self): + """Changes to state events for a room should be processed, but not end up in the room.""" + original_display_name = "banned" + new_display_name = "new name" + + # Join a room. + room_id = self.helper.create_room_as( + self.banned_user_id, tok=self.banned_access_token + ) + + # The update should succeed. + request, channel = self.make_request( + "PUT", + "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" + % (room_id, self.banned_user_id), + {"membership": "join", "displayname": new_display_name}, + access_token=self.banned_access_token, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.result) + self.assertIn("event_id", channel.json_body) + + # The display name in the room should not be changed. + message_handler = self.hs.get_message_handler() + event = self.get_success( + message_handler.get_room_data( + self.banned_user_id, room_id, "m.room.member", self.banned_user_id, + ) + ) + self.assertEqual( + event.content, {"membership": "join", "displayname": original_display_name} + ) diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 68c4a6a8f7ea..0a567b032f45 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -21,13 +21,13 @@ import json from urllib import parse as urlparse -from mock import Mock, patch +from mock import Mock import synapse.rest.admin from synapse.api.constants import EventContentFields, EventTypes, Membership from synapse.handlers.pagination import PurgeStatus from synapse.rest.client.v1 import directory, login, profile, room -from synapse.rest.client.v2_alpha import account, room_upgrade_rest_servlet +from synapse.rest.client.v2_alpha import account from synapse.types import JsonDict, RoomAlias, UserID from synapse.util.stringutils import random_string @@ -2060,158 +2060,3 @@ def test_bad_alias(self): """An alias which does not point to the room raises a SynapseError.""" self._set_canonical_alias({"alias": "@unknown:test"}, expected_code=400) self._set_canonical_alias({"alt_aliases": ["@unknown:test"]}, expected_code=400) - - -# To avoid the tests timing out don't add a delay to "annoy the requester". -@patch("random.randint", new=lambda a, b: 0) -class ShadowBannedTestCase(unittest.HomeserverTestCase): - servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, - directory.register_servlets, - login.register_servlets, - room.register_servlets, - room_upgrade_rest_servlet.register_servlets, - ] - - def prepare(self, reactor, clock, homeserver): - self.banned_user_id = self.register_user("banned", "test") - self.banned_access_token = self.login("banned", "test") - - self.store = self.hs.get_datastore() - - self.get_success( - self.store.db_pool.simple_update( - table="users", - keyvalues={"name": self.banned_user_id}, - updatevalues={"shadow_banned": True}, - desc="shadow_ban", - ) - ) - - self.other_user_id = self.register_user("otheruser", "pass") - self.other_access_token = self.login("otheruser", "pass") - - def test_invite(self): - """Invites from shadow-banned users don't actually get sent.""" - - # The create works fine. - room_id = self.helper.create_room_as( - self.banned_user_id, tok=self.banned_access_token - ) - - # Inviting the user completes successfully. - self.helper.invite( - room=room_id, - src=self.banned_user_id, - tok=self.banned_access_token, - targ=self.other_user_id, - ) - - # But the user wasn't actually invited. - invited_rooms = self.get_success( - self.store.get_invited_rooms_for_local_user(self.other_user_id) - ) - self.assertEqual(invited_rooms, []) - - def test_invite_3pid(self): - """Ensure that a 3PID invite does not attempt to contact the identity server.""" - identity_handler = self.hs.get_handlers().identity_handler - identity_handler.lookup_3pid = Mock( - side_effect=AssertionError("This should not get called") - ) - - # The create works fine. - room_id = self.helper.create_room_as( - self.banned_user_id, tok=self.banned_access_token - ) - - # Inviting the user completes successfully. - request, channel = self.make_request( - "POST", - "/rooms/%s/invite" % (room_id,), - {"id_server": "test", "medium": "email", "address": "test@test.test"}, - access_token=self.banned_access_token, - ) - self.render(request) - self.assertEquals(200, channel.code, channel.result) - - # This should have raised an error earlier, but double check this wasn't called. - identity_handler.lookup_3pid.assert_not_called() - - def test_create_room(self): - """Invitations during a room creation should be discarded, but the room still gets created.""" - # The room creation is successful. - request, channel = self.make_request( - "POST", - "/_matrix/client/r0/createRoom", - {"visibility": "public", "invite": [self.other_user_id]}, - access_token=self.banned_access_token, - ) - self.render(request) - self.assertEquals(200, channel.code, channel.result) - room_id = channel.json_body["room_id"] - - # But the user wasn't actually invited. - invited_rooms = self.get_success( - self.store.get_invited_rooms_for_local_user(self.other_user_id) - ) - self.assertEqual(invited_rooms, []) - - # Since a real room was created, the other user should be able to join it. - self.helper.join(room_id, self.other_user_id, tok=self.other_access_token) - - # Both users should be in the room. - users = self.get_success(self.store.get_users_in_room(room_id)) - self.assertCountEqual(users, ["@banned:test", "@otheruser:test"]) - - def test_message(self): - """Messages from shadow-banned users don't actually get sent.""" - - room_id = self.helper.create_room_as( - self.other_user_id, tok=self.other_access_token - ) - - # The user should be in the room. - self.helper.join(room_id, self.banned_user_id, tok=self.banned_access_token) - - # Sending a message should complete successfully. - result = self.helper.send_event( - room_id=room_id, - type=EventTypes.Message, - content={"msgtype": "m.text", "body": "with right label"}, - tok=self.banned_access_token, - ) - self.assertIn("event_id", result) - event_id = result["event_id"] - - latest_events = self.get_success( - self.store.get_latest_event_ids_in_room(room_id) - ) - self.assertNotIn(event_id, latest_events) - - def test_upgrade(self): - """A room upgrade should fail, but look like it succeeded.""" - - # The create works fine. - room_id = self.helper.create_room_as( - self.banned_user_id, tok=self.banned_access_token - ) - - request, channel = self.make_request( - "POST", - "/_matrix/client/r0/rooms/%s/upgrade" % (room_id,), - {"new_version": "6"}, - access_token=self.banned_access_token, - ) - self.render(request) - self.assertEquals(200, channel.code, channel.result) - # A new room_id should be returned. - self.assertIn("replacement_room", channel.json_body) - - new_room_id = channel.json_body["replacement_room"] - - # It doesn't really matter what API we use here, we just want to assert - # that the room doesn't exist. - summary = self.get_success(self.store.get_room_summary(new_room_id)) - # The summary should be empty since the room doesn't exist. - self.assertEqual(summary, {}) diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index 13bcac743acf..bf22540d9905 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -97,8 +97,10 @@ def test_select_one_1col(self): self.mock_txn.rowcount = 1 self.mock_txn.__iter__ = Mock(return_value=iter([("Value",)])) - value = yield self.datastore.db_pool.simple_select_one_onecol( - table="tablename", keyvalues={"keycol": "TheKey"}, retcol="retcol" + value = yield defer.ensureDeferred( + self.datastore.db_pool.simple_select_one_onecol( + table="tablename", keyvalues={"keycol": "TheKey"}, retcol="retcol" + ) ) self.assertEquals("Value", value) @@ -111,10 +113,12 @@ def test_select_one_3col(self): self.mock_txn.rowcount = 1 self.mock_txn.fetchone.return_value = (1, 2, 3) - ret = yield self.datastore.db_pool.simple_select_one( - table="tablename", - keyvalues={"keycol": "TheKey"}, - retcols=["colA", "colB", "colC"], + ret = yield defer.ensureDeferred( + self.datastore.db_pool.simple_select_one( + table="tablename", + keyvalues={"keycol": "TheKey"}, + retcols=["colA", "colB", "colC"], + ) ) self.assertEquals({"colA": 1, "colB": 2, "colC": 3}, ret) @@ -127,11 +131,13 @@ def test_select_one_missing(self): self.mock_txn.rowcount = 0 self.mock_txn.fetchone.return_value = None - ret = yield self.datastore.db_pool.simple_select_one( - table="tablename", - keyvalues={"keycol": "Not here"}, - retcols=["colA"], - allow_none=True, + ret = yield defer.ensureDeferred( + self.datastore.db_pool.simple_select_one( + table="tablename", + keyvalues={"keycol": "Not here"}, + retcols=["colA"], + allow_none=True, + ) ) self.assertFalse(ret) diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py index 43639ca28615..080761d1d2dc 100644 --- a/tests/storage/test_cleanup_extrems.py +++ b/tests/storage/test_cleanup_extrems.py @@ -271,7 +271,7 @@ def test_send_dummy_event(self): # Pump the reactor repeatedly so that the background updates have a # chance to run. - self.pump(10 * 60) + self.pump(20) latest_event_ids = self.get_success( self.store.get_latest_event_ids_in_room(self.room_id) diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index 87ed8f8cd1b4..34ae8c9da7fc 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -38,7 +38,7 @@ def test_store_new_device(self): self.store.store_device("user_id", "device_id", "display_name") ) - res = yield self.store.get_device("user_id", "device_id") + res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id")) self.assertDictContainsSubset( { "user_id": "user_id", @@ -111,12 +111,12 @@ def test_update_device(self): self.store.store_device("user_id", "device_id", "display_name 1") ) - res = yield self.store.get_device("user_id", "device_id") + res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id")) self.assertEqual("display_name 1", res["display_name"]) # do a no-op first yield defer.ensureDeferred(self.store.update_device("user_id", "device_id")) - res = yield self.store.get_device("user_id", "device_id") + res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id")) self.assertEqual("display_name 1", res["display_name"]) # do the update @@ -127,7 +127,7 @@ def test_update_device(self): ) # check it worked - res = yield self.store.get_device("user_id", "device_id") + res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id")) self.assertEqual("display_name 2", res["display_name"]) @defer.inlineCallbacks diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py index 16a32cb819f4..7a38022e7189 100644 --- a/tests/storage/test_profile.py +++ b/tests/storage/test_profile.py @@ -40,7 +40,12 @@ def test_displayname(self): ) self.assertEquals( - "Frank", (yield self.store.get_profile_displayname(self.u_frank.localpart)) + "Frank", + ( + yield defer.ensureDeferred( + self.store.get_profile_displayname(self.u_frank.localpart) + ) + ), ) @defer.inlineCallbacks @@ -55,5 +60,9 @@ def test_avatar_url(self): self.assertEquals( "http://my.site/here", - (yield self.store.get_profile_avatar_url(self.u_frank.localpart)), + ( + yield defer.ensureDeferred( + self.store.get_profile_avatar_url(self.u_frank.localpart) + ) + ), ) diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index 58f827d8d329..70c55cd65040 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -53,7 +53,7 @@ def test_register(self): "user_type": None, "deactivated": 0, }, - (yield self.store.get_user_by_id(self.user_id)), + (yield defer.ensureDeferred(self.store.get_user_by_id(self.user_id))), ) @defer.inlineCallbacks diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index d07b985a8e04..bc8400f24072 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -54,12 +54,14 @@ def test_get_room(self): "creator": self.u_creator.to_string(), "is_public": True, }, - (yield self.store.get_room(self.room.to_string())), + (yield defer.ensureDeferred(self.store.get_room(self.room.to_string()))), ) @defer.inlineCallbacks def test_get_room_unknown_room(self): - self.assertIsNone((yield self.store.get_room("!uknown:test")),) + self.assertIsNone( + (yield defer.ensureDeferred(self.store.get_room("!uknown:test"))) + ) @defer.inlineCallbacks def test_get_room_with_stats(self): @@ -69,12 +71,22 @@ def test_get_room_with_stats(self): "creator": self.u_creator.to_string(), "public": True, }, - (yield self.store.get_room_with_stats(self.room.to_string())), + ( + yield defer.ensureDeferred( + self.store.get_room_with_stats(self.room.to_string()) + ) + ), ) @defer.inlineCallbacks def test_get_room_with_stats_unknown_room(self): - self.assertIsNone((yield self.store.get_room_with_stats("!uknown:test")),) + self.assertIsNone( + ( + yield defer.ensureDeferred( + self.store.get_room_with_stats("!uknown:test") + ) + ), + ) class RoomEventsStoreTestCase(unittest.TestCase): diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index d98fe8754dab..12ccc1f53e99 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -87,7 +87,7 @@ def test_count_known_servers_stat_counter_disabled(self): self.inject_room_member(self.room, self.u_bob, Membership.JOIN) self.inject_room_member(self.room, self.u_charlie.to_string(), Membership.JOIN) - self.pump(20) + self.pump() self.assertTrue("_known_servers_count" not in self.store.__dict__.keys()) @@ -101,7 +101,7 @@ def test_count_known_servers_stat_counter_enabled(self): # Initialises to 1 -- itself self.assertEqual(self.store._known_servers_count, 1) - self.pump(20) + self.pump() # No rooms have been joined, so technically the SQL returns 0, but it # will still say it knows about itself. @@ -111,7 +111,7 @@ def test_count_known_servers_stat_counter_enabled(self): self.inject_room_member(self.room, self.u_bob, Membership.JOIN) self.inject_room_member(self.room, self.u_charlie.to_string(), Membership.JOIN) - self.pump(20) + self.pump(1) # It now knows about Charlie's server. self.assertEqual(self.store._known_servers_count, 2) diff --git a/tests/util/test_retryutils.py b/tests/util/test_retryutils.py index bc42ffce880c..5f46ed0cefd9 100644 --- a/tests/util/test_retryutils.py +++ b/tests/util/test_retryutils.py @@ -91,7 +91,7 @@ def test_limiter(self): # # one more go, with success # - self.pump(MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 2.0) + self.reactor.advance(MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 2.0) limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store)) self.pump(1) diff --git a/tox.ini b/tox.ini index f8ecd1aa6974..6329ba286a29 100644 --- a/tox.ini +++ b/tox.ini @@ -172,58 +172,8 @@ deps = {[base]deps} mypy==0.782 mypy-zope -env = - MYPYPATH = stubs/ extras = all -commands = mypy \ - synapse/api \ - synapse/appservice \ - synapse/config \ - synapse/event_auth.py \ - synapse/events/builder.py \ - synapse/events/spamcheck.py \ - synapse/federation \ - synapse/handlers/auth.py \ - synapse/handlers/cas_handler.py \ - synapse/handlers/directory.py \ - synapse/handlers/federation.py \ - synapse/handlers/identity.py \ - synapse/handlers/message.py \ - synapse/handlers/oidc_handler.py \ - synapse/handlers/presence.py \ - synapse/handlers/room.py \ - synapse/handlers/room_member.py \ - synapse/handlers/room_member_worker.py \ - synapse/handlers/saml_handler.py \ - synapse/handlers/sync.py \ - synapse/handlers/ui_auth \ - synapse/http/server.py \ - synapse/http/site.py \ - synapse/logging/ \ - synapse/metrics \ - synapse/module_api \ - synapse/notifier.py \ - synapse/push/pusherpool.py \ - synapse/push/push_rule_evaluator.py \ - synapse/replication \ - synapse/rest \ - synapse/server.py \ - synapse/server_notices \ - synapse/spam_checker_api \ - synapse/state \ - synapse/storage/databases/main/ui_auth.py \ - synapse/storage/database.py \ - synapse/storage/engines \ - synapse/storage/state.py \ - synapse/storage/util \ - synapse/streams \ - synapse/types.py \ - synapse/util/caches/stream_change_cache.py \ - synapse/util/metrics.py \ - tests/replication \ - tests/test_utils \ - tests/rest/client/v2_alpha/test_auth.py \ - tests/util/test_stream_change_cache.py +commands = mypy # To find all folders that pass mypy you run: #