From 319b8b6bef1a8666b9ef57dfea5030da0e1effc2 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 14 Sep 2021 11:25:05 +0100 Subject: [PATCH 01/74] Name the type of token in "Invalid token" messages (#10815) I had one of these error messages yesterday and assumed it was an invalid auth token (because that was an HTTP query parameter in the test) I was working on. In fact, it was an invalid next batch token for syncing. --- changelog.d/10815.misc | 1 + synapse/handlers/auth.py | 2 +- synapse/storage/relations.py | 4 ++-- synapse/types.py | 6 +++--- 4 files changed, 7 insertions(+), 6 deletions(-) create mode 100644 changelog.d/10815.misc diff --git a/changelog.d/10815.misc b/changelog.d/10815.misc new file mode 100644 index 000000000..fc2534dc1 --- /dev/null +++ b/changelog.d/10815.misc @@ -0,0 +1 @@ +Specify the type of token in generic "Invalid token" error messages. \ No newline at end of file diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index fbbf6fd83..3ea627008 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -1347,7 +1347,7 @@ async def validate_short_term_login_token( try: res = self.macaroon_gen.verify_short_term_login_token(login_token) except Exception: - raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN) + raise AuthError(403, "Invalid login token", errcode=Codes.FORBIDDEN) await self.auth.check_auth_blocking(res.user_id) return res diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py index c552dbf04..10a46b5e8 100644 --- a/synapse/storage/relations.py +++ b/synapse/storage/relations.py @@ -73,7 +73,7 @@ def from_string(string: str) -> "RelationPaginationToken": t, s = string.split("-") return RelationPaginationToken(int(t), int(s)) except ValueError: - raise SynapseError(400, "Invalid token") + raise SynapseError(400, "Invalid relation pagination token") def to_string(self) -> str: return "%d-%d" % (self.topological, self.stream) @@ -103,7 +103,7 @@ def from_string(string: str) -> "AggregationPaginationToken": c, s = string.split("-") return AggregationPaginationToken(int(c), int(s)) except ValueError: - raise SynapseError(400, "Invalid token") + raise SynapseError(400, "Invalid aggregation pagination token") def to_string(self) -> str: return "%d-%d" % (self.count, self.stream) diff --git a/synapse/types.py b/synapse/types.py index d4759b2df..90168ce8f 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -511,7 +511,7 @@ async def parse(cls, store: "DataStore", string: str) -> "RoomStreamToken": ) except Exception: pass - raise SynapseError(400, "Invalid token %r" % (string,)) + raise SynapseError(400, "Invalid room stream token %r" % (string,)) @classmethod def parse_stream_token(cls, string: str) -> "RoomStreamToken": @@ -520,7 +520,7 @@ def parse_stream_token(cls, string: str) -> "RoomStreamToken": return cls(topological=None, stream=int(string[1:])) except Exception: pass - raise SynapseError(400, "Invalid token %r" % (string,)) + raise SynapseError(400, "Invalid room stream token %r" % (string,)) def copy_and_advance(self, other: "RoomStreamToken") -> "RoomStreamToken": """Return a new token such that if an event is after both this token and @@ -619,7 +619,7 @@ async def from_string(cls, store: "DataStore", string: str) -> "StreamToken": await RoomStreamToken.parse(store, keys[0]), *(int(k) for k in keys[1:]) ) except Exception: - raise SynapseError(400, "Invalid Token") + raise SynapseError(400, "Invalid stream token") async def to_string(self, store: "DataStore") -> str: return self._SEPARATOR.join( From b996782df51eaa5dd30635a7c59c93994d3a735e Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 14 Sep 2021 07:09:38 -0400 Subject: [PATCH 02/74] Convert media repo's FileInfo to attrs. (#10785) This is mostly an internal change, but improves type hints in the media code. --- changelog.d/10785.misc | 1 + synapse/rest/media/v1/_base.py | 94 +++++++++++++-------- synapse/rest/media/v1/media_repository.py | 40 +++++---- synapse/rest/media/v1/media_storage.py | 36 ++++---- synapse/rest/media/v1/thumbnail_resource.py | 77 +++++++++-------- 5 files changed, 140 insertions(+), 108 deletions(-) create mode 100644 changelog.d/10785.misc diff --git a/changelog.d/10785.misc b/changelog.d/10785.misc new file mode 100644 index 000000000..3d7f91d51 --- /dev/null +++ b/changelog.d/10785.misc @@ -0,0 +1 @@ +Convert the internal `FileInfo` class to attrs and add type hints. diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py index 90364ebcf..814f4309f 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py @@ -18,6 +18,8 @@ import urllib from typing import Awaitable, Dict, Generator, List, Optional, Tuple +import attr + from twisted.internet.interfaces import IConsumer from twisted.protocols.basic import FileSender from twisted.web.server import Request @@ -287,44 +289,62 @@ def __exit__(self, exc_type, exc_val, exc_tb): pass -class FileInfo: - """Details about a requested/uploaded file. - - Attributes: - server_name (str): The server name where the media originated from, - or None if local. - file_id (str): The local ID of the file. For local files this is the - same as the media_id - url_cache (bool): If the file is for the url preview cache - thumbnail (bool): Whether the file is a thumbnail or not. - thumbnail_width (int) - thumbnail_height (int) - thumbnail_method (str) - thumbnail_type (str): Content type of thumbnail, e.g. image/png - thumbnail_length (int): The size of the media file, in bytes. - """ +@attr.s(slots=True, frozen=True, auto_attribs=True) +class ThumbnailInfo: + """Details about a generated thumbnail.""" - def __init__( - self, - server_name, - file_id, - url_cache=False, - thumbnail=False, - thumbnail_width=None, - thumbnail_height=None, - thumbnail_method=None, - thumbnail_type=None, - thumbnail_length=None, - ): - self.server_name = server_name - self.file_id = file_id - self.url_cache = url_cache - self.thumbnail = thumbnail - self.thumbnail_width = thumbnail_width - self.thumbnail_height = thumbnail_height - self.thumbnail_method = thumbnail_method - self.thumbnail_type = thumbnail_type - self.thumbnail_length = thumbnail_length + width: int + height: int + method: str + # Content type of thumbnail, e.g. image/png + type: str + # The size of the media file, in bytes. + length: Optional[int] = None + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class FileInfo: + """Details about a requested/uploaded file.""" + + # The server name where the media originated from, or None if local. + server_name: Optional[str] + # The local ID of the file. For local files this is the same as the media_id + file_id: str + # If the file is for the url preview cache + url_cache: bool = False + # Whether the file is a thumbnail or not. + thumbnail: Optional[ThumbnailInfo] = None + + # The below properties exist to maintain compatibility with third-party modules. + @property + def thumbnail_width(self): + if not self.thumbnail: + return None + return self.thumbnail.width + + @property + def thumbnail_height(self): + if not self.thumbnail: + return None + return self.thumbnail.height + + @property + def thumbnail_method(self): + if not self.thumbnail: + return None + return self.thumbnail.method + + @property + def thumbnail_type(self): + if not self.thumbnail: + return None + return self.thumbnail.type + + @property + def thumbnail_length(self): + if not self.thumbnail: + return None + return self.thumbnail.length def get_filename_from_headers(headers: Dict[bytes, List[bytes]]) -> Optional[str]: diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 0f5ce41ff..40ce8d2bc 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -42,6 +42,7 @@ from ._base import ( FileInfo, Responder, + ThumbnailInfo, get_filename_from_headers, respond_404, respond_with_responder, @@ -210,7 +211,7 @@ async def get_local_media( upload_name = name if name else media_info["upload_name"] url_cache = media_info["url_cache"] - file_info = FileInfo(None, media_id, url_cache=url_cache) + file_info = FileInfo(None, media_id, url_cache=bool(url_cache)) responder = await self.media_storage.fetch_media(file_info) await respond_with_responder( @@ -514,7 +515,7 @@ async def generate_local_exact_thumbnail( t_height: int, t_method: str, t_type: str, - url_cache: Optional[str], + url_cache: bool, ) -> Optional[str]: input_path = await self.media_storage.ensure_media_is_in_local_cache( FileInfo(None, media_id, url_cache=url_cache) @@ -548,11 +549,12 @@ async def generate_local_exact_thumbnail( server_name=None, file_id=media_id, url_cache=url_cache, - thumbnail=True, - thumbnail_width=t_width, - thumbnail_height=t_height, - thumbnail_method=t_method, - thumbnail_type=t_type, + thumbnail=ThumbnailInfo( + width=t_width, + height=t_height, + method=t_method, + type=t_type, + ), ) output_path = await self.media_storage.store_file( @@ -585,7 +587,7 @@ async def generate_remote_exact_thumbnail( t_type: str, ) -> Optional[str]: input_path = await self.media_storage.ensure_media_is_in_local_cache( - FileInfo(server_name, file_id, url_cache=False) + FileInfo(server_name, file_id) ) try: @@ -616,11 +618,12 @@ async def generate_remote_exact_thumbnail( file_info = FileInfo( server_name=server_name, file_id=file_id, - thumbnail=True, - thumbnail_width=t_width, - thumbnail_height=t_height, - thumbnail_method=t_method, - thumbnail_type=t_type, + thumbnail=ThumbnailInfo( + width=t_width, + height=t_height, + method=t_method, + type=t_type, + ), ) output_path = await self.media_storage.store_file( @@ -742,12 +745,13 @@ async def _generate_thumbnails( file_info = FileInfo( server_name=server_name, file_id=file_id, - thumbnail=True, - thumbnail_width=t_width, - thumbnail_height=t_height, - thumbnail_method=t_method, - thumbnail_type=t_type, url_cache=url_cache, + thumbnail=ThumbnailInfo( + width=t_width, + height=t_height, + method=t_method, + type=t_type, + ), ) with self.media_storage.store_into_file(file_info) as (f, fname, finish): diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py index 56cdc1b4e..c0bb40c11 100644 --- a/synapse/rest/media/v1/media_storage.py +++ b/synapse/rest/media/v1/media_storage.py @@ -176,9 +176,9 @@ async def fetch_media(self, file_info: FileInfo) -> Optional[Responder]: self.filepaths.remote_media_thumbnail_rel_legacy( server_name=file_info.server_name, file_id=file_info.file_id, - width=file_info.thumbnail_width, - height=file_info.thumbnail_height, - content_type=file_info.thumbnail_type, + width=file_info.thumbnail.width, + height=file_info.thumbnail.height, + content_type=file_info.thumbnail.type, ) ) @@ -220,9 +220,9 @@ async def ensure_media_is_in_local_cache(self, file_info: FileInfo) -> str: legacy_path = self.filepaths.remote_media_thumbnail_rel_legacy( server_name=file_info.server_name, file_id=file_info.file_id, - width=file_info.thumbnail_width, - height=file_info.thumbnail_height, - content_type=file_info.thumbnail_type, + width=file_info.thumbnail.width, + height=file_info.thumbnail.height, + content_type=file_info.thumbnail.type, ) legacy_local_path = os.path.join(self.local_media_directory, legacy_path) if os.path.exists(legacy_local_path): @@ -255,10 +255,10 @@ def _file_info_to_path(self, file_info: FileInfo) -> str: if file_info.thumbnail: return self.filepaths.url_cache_thumbnail_rel( media_id=file_info.file_id, - width=file_info.thumbnail_width, - height=file_info.thumbnail_height, - content_type=file_info.thumbnail_type, - method=file_info.thumbnail_method, + width=file_info.thumbnail.width, + height=file_info.thumbnail.height, + content_type=file_info.thumbnail.type, + method=file_info.thumbnail.method, ) return self.filepaths.url_cache_filepath_rel(file_info.file_id) @@ -267,10 +267,10 @@ def _file_info_to_path(self, file_info: FileInfo) -> str: return self.filepaths.remote_media_thumbnail_rel( server_name=file_info.server_name, file_id=file_info.file_id, - width=file_info.thumbnail_width, - height=file_info.thumbnail_height, - content_type=file_info.thumbnail_type, - method=file_info.thumbnail_method, + width=file_info.thumbnail.width, + height=file_info.thumbnail.height, + content_type=file_info.thumbnail.type, + method=file_info.thumbnail.method, ) return self.filepaths.remote_media_filepath_rel( file_info.server_name, file_info.file_id @@ -279,10 +279,10 @@ def _file_info_to_path(self, file_info: FileInfo) -> str: if file_info.thumbnail: return self.filepaths.local_media_thumbnail_rel( media_id=file_info.file_id, - width=file_info.thumbnail_width, - height=file_info.thumbnail_height, - content_type=file_info.thumbnail_type, - method=file_info.thumbnail_method, + width=file_info.thumbnail.width, + height=file_info.thumbnail.height, + content_type=file_info.thumbnail.type, + method=file_info.thumbnail.method, ) return self.filepaths.local_media_filepath_rel(file_info.file_id) diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py index 12bd745cb..22f43d853 100644 --- a/synapse/rest/media/v1/thumbnail_resource.py +++ b/synapse/rest/media/v1/thumbnail_resource.py @@ -26,6 +26,7 @@ from ._base import ( FileInfo, + ThumbnailInfo, parse_media_id, respond_404, respond_with_file, @@ -114,7 +115,7 @@ async def _respond_local_thumbnail( thumbnail_infos, media_id, media_id, - url_cache=media_info["url_cache"], + url_cache=bool(media_info["url_cache"]), server_name=None, ) @@ -149,11 +150,12 @@ async def _select_or_generate_local_thumbnail( server_name=None, file_id=media_id, url_cache=media_info["url_cache"], - thumbnail=True, - thumbnail_width=info["thumbnail_width"], - thumbnail_height=info["thumbnail_height"], - thumbnail_type=info["thumbnail_type"], - thumbnail_method=info["thumbnail_method"], + thumbnail=ThumbnailInfo( + width=info["thumbnail_width"], + height=info["thumbnail_height"], + type=info["thumbnail_type"], + method=info["thumbnail_method"], + ), ) t_type = file_info.thumbnail_type @@ -173,7 +175,7 @@ async def _select_or_generate_local_thumbnail( desired_height, desired_method, desired_type, - url_cache=media_info["url_cache"], + url_cache=bool(media_info["url_cache"]), ) if file_path: @@ -210,11 +212,12 @@ async def _select_or_generate_remote_thumbnail( file_info = FileInfo( server_name=server_name, file_id=media_info["filesystem_id"], - thumbnail=True, - thumbnail_width=info["thumbnail_width"], - thumbnail_height=info["thumbnail_height"], - thumbnail_type=info["thumbnail_type"], - thumbnail_method=info["thumbnail_method"], + thumbnail=ThumbnailInfo( + width=info["thumbnail_width"], + height=info["thumbnail_height"], + type=info["thumbnail_type"], + method=info["thumbnail_method"], + ), ) t_type = file_info.thumbnail_type @@ -271,7 +274,7 @@ async def _respond_remote_thumbnail( thumbnail_infos, media_id, media_info["filesystem_id"], - url_cache=None, + url_cache=False, server_name=server_name, ) @@ -285,7 +288,7 @@ async def _select_and_respond_with_thumbnail( thumbnail_infos: List[Dict[str, Any]], media_id: str, file_id: str, - url_cache: Optional[str] = None, + url_cache: bool, server_name: Optional[str] = None, ) -> None: """ @@ -299,7 +302,7 @@ async def _select_and_respond_with_thumbnail( desired_type: The desired content-type of the thumbnail. thumbnail_infos: A list of dictionaries of candidate thumbnails. file_id: The ID of the media that a thumbnail is being requested for. - url_cache: The URL cache value. + url_cache: True if this is from a URL cache. server_name: The server name, if this is a remote thumbnail. """ if thumbnail_infos: @@ -318,13 +321,16 @@ async def _select_and_respond_with_thumbnail( respond_404(request) return + # The thumbnail property must exist. + assert file_info.thumbnail is not None + responder = await self.media_storage.fetch_media(file_info) if responder: await respond_with_responder( request, responder, - file_info.thumbnail_type, - file_info.thumbnail_length, + file_info.thumbnail.type, + file_info.thumbnail.length, ) return @@ -351,18 +357,18 @@ async def _select_and_respond_with_thumbnail( server_name, file_id=file_id, media_id=media_id, - t_width=file_info.thumbnail_width, - t_height=file_info.thumbnail_height, - t_method=file_info.thumbnail_method, - t_type=file_info.thumbnail_type, + t_width=file_info.thumbnail.width, + t_height=file_info.thumbnail.height, + t_method=file_info.thumbnail.method, + t_type=file_info.thumbnail.type, ) else: await self.media_repo.generate_local_exact_thumbnail( media_id=media_id, - t_width=file_info.thumbnail_width, - t_height=file_info.thumbnail_height, - t_method=file_info.thumbnail_method, - t_type=file_info.thumbnail_type, + t_width=file_info.thumbnail.width, + t_height=file_info.thumbnail.height, + t_method=file_info.thumbnail.method, + t_type=file_info.thumbnail.type, url_cache=url_cache, ) @@ -370,8 +376,8 @@ async def _select_and_respond_with_thumbnail( await respond_with_responder( request, responder, - file_info.thumbnail_type, - file_info.thumbnail_length, + file_info.thumbnail.type, + file_info.thumbnail.length, ) else: logger.info("Failed to find any generated thumbnails") @@ -385,7 +391,7 @@ def _select_thumbnail( desired_type: str, thumbnail_infos: List[Dict[str, Any]], file_id: str, - url_cache: Optional[str], + url_cache: bool, server_name: Optional[str], ) -> Optional[FileInfo]: """ @@ -398,7 +404,7 @@ def _select_thumbnail( desired_type: The desired content-type of the thumbnail. thumbnail_infos: A list of dictionaries of candidate thumbnails. file_id: The ID of the media that a thumbnail is being requested for. - url_cache: The URL cache value. + url_cache: True if this is from a URL cache. server_name: The server name, if this is a remote thumbnail. Returns: @@ -495,12 +501,13 @@ def _select_thumbnail( file_id=file_id, url_cache=url_cache, server_name=server_name, - thumbnail=True, - thumbnail_width=thumbnail_info["thumbnail_width"], - thumbnail_height=thumbnail_info["thumbnail_height"], - thumbnail_type=thumbnail_info["thumbnail_type"], - thumbnail_method=thumbnail_info["thumbnail_method"], - thumbnail_length=thumbnail_info["thumbnail_length"], + thumbnail=ThumbnailInfo( + width=thumbnail_info["thumbnail_width"], + height=thumbnail_info["thumbnail_height"], + type=thumbnail_info["thumbnail_type"], + method=thumbnail_info["thumbnail_method"], + length=thumbnail_info["thumbnail_length"], + ), ) # No matching thumbnail was found. From 14b8c0476f93ea2ed3134e75733e45aa0ab6f5a5 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Tue, 14 Sep 2021 13:01:30 +0100 Subject: [PATCH 03/74] Prevent logging context going missing on federation request timeout (#10810) In `MatrixFederationHttpClient._send_request()`, we make a HTTP request using an `Agent`, wrap that request in a timeout and await the resulting `Deferred`. On its own, the `Agent` performing the HTTP request correctly stashes and restores the logging context while waiting. The addition of the timeout introduces a path where the logging context is not restored when execution resumes. To address this, we wrap the timeout `Deferred` in a `make_deferred_yieldable()` to stash the logging context and restore it on completion of the `await`. However this is not sufficient, since by the time we construct the timeout `Deferred`, the `Agent` has already stashed and cleared the logging context when using `make_deferred_yieldable()` to produce its `Deferred` for the request. Hence, we wrap the `Agent` request in a `run_in_background()` to "fork" and preserve the logging context so that we can stash and restore it when `await`ing the timeout `Deferred`. This approach is similar to the one used with `defer.gatherResults`. Note that the code is still not fully correct. When a timeout occurs, the request remains running in the background (existing behavior which is nothing to do with the new call to `run_in_background`) and may re-start the logging context after it has finished. --- changelog.d/10810.bugfix | 1 + synapse/http/matrixfederationclient.py | 17 +++++++++++++---- 2 files changed, 14 insertions(+), 4 deletions(-) create mode 100644 changelog.d/10810.bugfix diff --git a/changelog.d/10810.bugfix b/changelog.d/10810.bugfix new file mode 100644 index 000000000..43e91f1f5 --- /dev/null +++ b/changelog.d/10810.bugfix @@ -0,0 +1 @@ +Fix a case where logging contexts would go missing when federation requests time out. diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 2e9898997..ef10ec093 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -66,7 +66,7 @@ ) from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent from synapse.logging import opentracing -from synapse.logging.context import make_deferred_yieldable +from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.opentracing import set_tag, start_active_span, tags from synapse.types import JsonDict from synapse.util import json_decoder @@ -553,20 +553,29 @@ async def _send_request( with Measure(self.clock, "outbound_request"): # we don't want all the fancy cookie and redirect handling # that treq.request gives: just use the raw Agent. - request_deferred = self.agent.request( + + # To preserve the logging context, the timeout is treated + # in a similar way to `defer.gatherResults`: + # * Each logging context-preserving fork is wrapped in + # `run_in_background`. In this case there is only one, + # since the timeout fork is not logging-context aware. + # * The `Deferred` that joins the forks back together is + # wrapped in `make_deferred_yieldable` to restore the + # logging context regardless of the path taken. + request_deferred = run_in_background( + self.agent.request, method_bytes, url_bytes, headers=Headers(headers_dict), bodyProducer=producer, ) - request_deferred = timeout_deferred( request_deferred, timeout=_sec_timeout, reactor=self.reactor, ) - response = await request_deferred + response = await make_deferred_yieldable(request_deferred) except DNSLookupError as e: raise RequestSendFailed(e, can_retry=retry_on_dns_fail) from e except Exception as e: From 8eb7cb2e0dd66d2eb350c1822fb448e09148cd7e Mon Sep 17 00:00:00 2001 From: reivilibre <38398653+reivilibre@users.noreply.github.com> Date: Tue, 14 Sep 2021 16:35:53 +0100 Subject: [PATCH 04/74] Make StateFilter frozen so we can hash it (#10816) Also enables Mypy for related tests. --- changelog.d/10816.misc | 1 + mypy.ini | 1 + synapse/storage/state.py | 45 +++++++++++++++++++++++++----------- tests/storage/test_state.py | 46 +++++++++++++++++++++++-------------- 4 files changed, 63 insertions(+), 30 deletions(-) create mode 100644 changelog.d/10816.misc diff --git a/changelog.d/10816.misc b/changelog.d/10816.misc new file mode 100644 index 000000000..2ca55b334 --- /dev/null +++ b/changelog.d/10816.misc @@ -0,0 +1 @@ +Make `StateFilter` frozen so it is hashable. diff --git a/mypy.ini b/mypy.ini index 09ffdda1b..60dadc478 100644 --- a/mypy.ini +++ b/mypy.ini @@ -86,6 +86,7 @@ files = tests/handlers/test_sync.py, tests/rest/client/test_login.py, tests/rest/client/test_auth.py, + tests/storage/test_state.py, tests/util/test_itertools.py, tests/util/test_stream_change_cache.py diff --git a/synapse/storage/state.py b/synapse/storage/state.py index e5400d681..c76529cb5 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -25,12 +25,15 @@ ) import attr +from frozendict import frozendict from synapse.api.constants import EventTypes from synapse.events import EventBase from synapse.types import MutableStateMap, StateMap if TYPE_CHECKING: + from typing import FrozenSet # noqa: used within quoted type hint; flake8 sad + from synapse.server import HomeServer from synapse.storage.databases import Databases @@ -40,7 +43,7 @@ T = TypeVar("T") -@attr.s(slots=True) +@attr.s(slots=True, frozen=True) class StateFilter: """A filter used when querying for state. @@ -53,14 +56,19 @@ class StateFilter: appear in `types`. """ - types = attr.ib(type=Dict[str, Optional[Set[str]]]) + types = attr.ib(type="frozendict[str, Optional[FrozenSet[str]]]") include_others = attr.ib(default=False, type=bool) def __attrs_post_init__(self): # If `include_others` is set we canonicalise the filter by removing # wildcards from the types dictionary if self.include_others: - self.types = {k: v for k, v in self.types.items() if v is not None} + # this is needed to work around the fact that StateFilter is frozen + object.__setattr__( + self, + "types", + frozendict({k: v for k, v in self.types.items() if v is not None}), + ) @staticmethod def all() -> "StateFilter": @@ -69,7 +77,7 @@ def all() -> "StateFilter": Returns: The new state filter. """ - return StateFilter(types={}, include_others=True) + return StateFilter(types=frozendict(), include_others=True) @staticmethod def none() -> "StateFilter": @@ -78,7 +86,7 @@ def none() -> "StateFilter": Returns: The new state filter. """ - return StateFilter(types={}, include_others=False) + return StateFilter(types=frozendict(), include_others=False) @staticmethod def from_types(types: Iterable[Tuple[str, Optional[str]]]) -> "StateFilter": @@ -103,7 +111,12 @@ def from_types(types: Iterable[Tuple[str, Optional[str]]]) -> "StateFilter": type_dict.setdefault(typ, set()).add(s) # type: ignore - return StateFilter(types=type_dict) + return StateFilter( + types=frozendict( + (k, frozenset(v) if v is not None else None) + for k, v in type_dict.items() + ) + ) @staticmethod def from_lazy_load_member_list(members: Iterable[str]) -> "StateFilter": @@ -116,7 +129,10 @@ def from_lazy_load_member_list(members: Iterable[str]) -> "StateFilter": Returns: The new state filter """ - return StateFilter(types={EventTypes.Member: set(members)}, include_others=True) + return StateFilter( + types=frozendict({EventTypes.Member: frozenset(members)}), + include_others=True, + ) def return_expanded(self) -> "StateFilter": """Creates a new StateFilter where type wild cards have been removed @@ -173,7 +189,7 @@ def return_expanded(self) -> "StateFilter": # We want to return all non-members, but only particular # memberships return StateFilter( - types={EventTypes.Member: self.types[EventTypes.Member]}, + types=frozendict({EventTypes.Member: self.types[EventTypes.Member]}), include_others=True, ) @@ -245,14 +261,15 @@ def max_entries_returned(self) -> Optional[int]: return len(self.concrete_types()) - def filter_state(self, state_dict: StateMap[T]) -> StateMap[T]: - """Returns the state filtered with by this StateFilter + def filter_state(self, state_dict: StateMap[T]) -> MutableStateMap[T]: + """Returns the state filtered with by this StateFilter. Args: state: The state map to filter Returns: - The filtered state map + The filtered state map. + This is a copy, so it's safe to mutate. """ if self.is_full(): return dict(state_dict) @@ -324,14 +341,16 @@ def get_member_split(self) -> Tuple["StateFilter", "StateFilter"]: if state_keys is None: member_filter = StateFilter.all() else: - member_filter = StateFilter({EventTypes.Member: state_keys}) + member_filter = StateFilter(frozendict({EventTypes.Member: state_keys})) elif self.include_others: member_filter = StateFilter.all() else: member_filter = StateFilter.none() non_member_filter = StateFilter( - types={k: v for k, v in self.types.items() if k != EventTypes.Member}, + types=frozendict( + {k: v for k, v in self.types.items() if k != EventTypes.Member} + ), include_others=self.include_others, ) diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index 869526459..32060f2ab 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -14,6 +14,8 @@ import logging +from frozendict import frozendict + from synapse.api.constants import EventTypes, Membership from synapse.api.room_versions import RoomVersions from synapse.storage.state import StateFilter @@ -183,7 +185,9 @@ def test_get_state_for_event(self): self.storage.state.get_state_for_event( e5.event_id, state_filter=StateFilter( - types={EventTypes.Member: {self.u_alice.to_string()}}, + types=frozendict( + {EventTypes.Member: frozenset({self.u_alice.to_string()})} + ), include_others=True, ), ) @@ -203,7 +207,8 @@ def test_get_state_for_event(self): self.storage.state.get_state_for_event( e5.event_id, state_filter=StateFilter( - types={EventTypes.Member: set()}, include_others=True + types=frozendict({EventTypes.Member: frozenset()}), + include_others=True, ), ) ) @@ -228,7 +233,7 @@ def test_get_state_for_event(self): self.state_datastore._state_group_cache, group, state_filter=StateFilter( - types={EventTypes.Member: set()}, include_others=True + types=frozendict({EventTypes.Member: frozenset()}), include_others=True ), ) @@ -245,7 +250,7 @@ def test_get_state_for_event(self): self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( - types={EventTypes.Member: set()}, include_others=True + types=frozendict({EventTypes.Member: frozenset()}), include_others=True ), ) @@ -258,7 +263,7 @@ def test_get_state_for_event(self): self.state_datastore._state_group_cache, group, state_filter=StateFilter( - types={EventTypes.Member: None}, include_others=True + types=frozendict({EventTypes.Member: None}), include_others=True ), ) @@ -275,7 +280,7 @@ def test_get_state_for_event(self): self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( - types={EventTypes.Member: None}, include_others=True + types=frozendict({EventTypes.Member: None}), include_others=True ), ) @@ -295,7 +300,8 @@ def test_get_state_for_event(self): self.state_datastore._state_group_cache, group, state_filter=StateFilter( - types={EventTypes.Member: {e5.state_key}}, include_others=True + types=frozendict({EventTypes.Member: frozenset({e5.state_key})}), + include_others=True, ), ) @@ -312,7 +318,8 @@ def test_get_state_for_event(self): self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( - types={EventTypes.Member: {e5.state_key}}, include_others=True + types=frozendict({EventTypes.Member: frozenset({e5.state_key})}), + include_others=True, ), ) @@ -325,7 +332,8 @@ def test_get_state_for_event(self): self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( - types={EventTypes.Member: {e5.state_key}}, include_others=False + types=frozendict({EventTypes.Member: frozenset({e5.state_key})}), + include_others=False, ), ) @@ -375,7 +383,7 @@ def test_get_state_for_event(self): self.state_datastore._state_group_cache, group, state_filter=StateFilter( - types={EventTypes.Member: set()}, include_others=True + types=frozendict({EventTypes.Member: frozenset()}), include_others=True ), ) @@ -387,7 +395,7 @@ def test_get_state_for_event(self): self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( - types={EventTypes.Member: set()}, include_others=True + types=frozendict({EventTypes.Member: frozenset()}), include_others=True ), ) @@ -400,7 +408,7 @@ def test_get_state_for_event(self): self.state_datastore._state_group_cache, group, state_filter=StateFilter( - types={EventTypes.Member: None}, include_others=True + types=frozendict({EventTypes.Member: None}), include_others=True ), ) @@ -411,7 +419,7 @@ def test_get_state_for_event(self): self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( - types={EventTypes.Member: None}, include_others=True + types=frozendict({EventTypes.Member: None}), include_others=True ), ) @@ -430,7 +438,8 @@ def test_get_state_for_event(self): self.state_datastore._state_group_cache, group, state_filter=StateFilter( - types={EventTypes.Member: {e5.state_key}}, include_others=True + types=frozendict({EventTypes.Member: frozenset({e5.state_key})}), + include_others=True, ), ) @@ -441,7 +450,8 @@ def test_get_state_for_event(self): self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( - types={EventTypes.Member: {e5.state_key}}, include_others=True + types=frozendict({EventTypes.Member: frozenset({e5.state_key})}), + include_others=True, ), ) @@ -454,7 +464,8 @@ def test_get_state_for_event(self): self.state_datastore._state_group_cache, group, state_filter=StateFilter( - types={EventTypes.Member: {e5.state_key}}, include_others=False + types=frozendict({EventTypes.Member: frozenset({e5.state_key})}), + include_others=False, ), ) @@ -465,7 +476,8 @@ def test_get_state_for_event(self): self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( - types={EventTypes.Member: {e5.state_key}}, include_others=False + types=frozendict({EventTypes.Member: frozenset({e5.state_key})}), + include_others=False, ), ) From 1c555527b351a8b0dcdf54ba7091141347af2a73 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Wed, 15 Sep 2021 03:30:58 -0500 Subject: [PATCH 05/74] Split out `/batch_send` meta events to their own fields (MSC2716) (#10777) --- changelog.d/10777.misc | 1 + synapse/rest/client/room_batch.py | 29 ++++++++++++++++++----------- 2 files changed, 19 insertions(+), 11 deletions(-) create mode 100644 changelog.d/10777.misc diff --git a/changelog.d/10777.misc b/changelog.d/10777.misc new file mode 100644 index 000000000..aed78a16f --- /dev/null +++ b/changelog.d/10777.misc @@ -0,0 +1 @@ +Split out [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) meta events to their own fields in the `/batch_send` response. diff --git a/synapse/rest/client/room_batch.py b/synapse/rest/client/room_batch.py index ed9697844..783fecf19 100644 --- a/synapse/rest/client/room_batch.py +++ b/synapse/rest/client/room_batch.py @@ -14,6 +14,7 @@ import logging import re +from http import HTTPStatus from typing import TYPE_CHECKING, Awaitable, List, Tuple from twisted.web.server import Request @@ -179,7 +180,7 @@ async def on_POST( if not requester.app_service: raise AuthError( - 403, + HTTPStatus.FORBIDDEN, "Only application services can use the /batchsend endpoint", ) @@ -192,7 +193,7 @@ async def on_POST( if prev_events_from_query is None: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "prev_event query parameter is required when inserting historical messages back in time", errcode=Codes.MISSING_PARAM, ) @@ -213,7 +214,7 @@ async def on_POST( prev_state_ids = list(prev_state_map.values()) auth_event_ids = prev_state_ids - state_events_at_start = [] + state_event_ids_at_start = [] for state_event in body["state_events_at_start"]: assert_params_in_dict( state_event, ["type", "origin_server_ts", "content", "sender"] @@ -279,7 +280,7 @@ async def on_POST( ) event_id = event.event_id - state_events_at_start.append(event_id) + state_event_ids_at_start.append(event_id) auth_event_ids.append(event_id) events_to_create = body["events"] @@ -424,20 +425,26 @@ async def on_POST( context=context, ) - # Add the base_insertion_event to the bottom of the list we return - if base_insertion_event is not None: - event_ids.append(base_insertion_event.event_id) + insertion_event_id = event_ids[0] + chunk_event_id = event_ids[-1] + historical_event_ids = event_ids[1:-1] - return 200, { - "state_events": state_events_at_start, - "events": event_ids, + response_dict = { + "state_event_ids": state_event_ids_at_start, + "event_ids": historical_event_ids, "next_chunk_id": insertion_event["content"][ EventContentFields.MSC2716_NEXT_CHUNK_ID ], + "insertion_event_id": insertion_event_id, + "chunk_event_id": chunk_event_id, } + if base_insertion_event is not None: + response_dict["base_insertion_event_id"] = base_insertion_event.event_id + + return HTTPStatus.OK, response_dict def on_GET(self, request: Request, room_id: str) -> Tuple[int, str]: - return 501, "Not implemented" + return HTTPStatus.NOT_IMPLEMENTED, "Not implemented" def on_PUT( self, request: SynapseRequest, room_id: str From 145c006ef76ab3955fb8294203cb8e6e61372cd1 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Wed, 15 Sep 2021 03:34:30 -0500 Subject: [PATCH 06/74] Verify `?chunk_id` actually corresponds to an insertion event that exists (MSC2716) (#10776) --- changelog.d/10776.feature | 1 + synapse/rest/client/room_batch.py | 13 ++++++- synapse/storage/databases/main/__init__.py | 2 ++ synapse/storage/databases/main/room_batch.py | 36 ++++++++++++++++++++ 4 files changed, 51 insertions(+), 1 deletion(-) create mode 100644 changelog.d/10776.feature create mode 100644 synapse/storage/databases/main/room_batch.py diff --git a/changelog.d/10776.feature b/changelog.d/10776.feature new file mode 100644 index 000000000..aec0685a3 --- /dev/null +++ b/changelog.d/10776.feature @@ -0,0 +1 @@ +Only allow the [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) `/batch_send?chunk_id=xxx` endpoint to connect to an already existing insertion event. diff --git a/synapse/rest/client/room_batch.py b/synapse/rest/client/room_batch.py index 783fecf19..d466edeec 100644 --- a/synapse/rest/client/room_batch.py +++ b/synapse/rest/client/room_batch.py @@ -300,7 +300,18 @@ async def on_POST( # event, which causes the HS to ask for the state at the start of # the chunk later. prev_event_ids = [fake_prev_event_id] - # TODO: Verify the chunk_id_from_query corresponds to an insertion event + + # Verify the chunk_id_from_query corresponds to an actual insertion event + # and have the chunk connected. + corresponding_insertion_event_id = ( + await self.store.get_insertion_event_by_chunk_id(chunk_id_from_query) + ) + if corresponding_insertion_event_id is None: + raise SynapseError( + 400, + "No insertion event corresponds to the given ?chunk_id", + errcode=Codes.INVALID_PARAM, + ) pass # Otherwise, create an insertion event to act as a starting point. # diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 1dc347f0c..5c21402de 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -61,6 +61,7 @@ from .rejections import RejectionsStore from .relations import RelationsStore from .room import RoomStore +from .room_batch import RoomBatchStore from .roommember import RoomMemberStore from .search import SearchStore from .session import SessionStore @@ -81,6 +82,7 @@ class DataStore( EventsBackgroundUpdatesStore, RoomMemberStore, RoomStore, + RoomBatchStore, RegistrationStore, StreamStore, ProfileStore, diff --git a/synapse/storage/databases/main/room_batch.py b/synapse/storage/databases/main/room_batch.py new file mode 100644 index 000000000..54fa361d3 --- /dev/null +++ b/synapse/storage/databases/main/room_batch.py @@ -0,0 +1,36 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +from synapse.storage._base import SQLBaseStore + + +class RoomBatchStore(SQLBaseStore): + async def get_insertion_event_by_chunk_id(self, chunk_id: str) -> Optional[str]: + """Retrieve a insertion event ID. + + Args: + chunk_id: The chunk ID of the insertion event to retrieve. + + Returns: + The event_id of an insertion event, or None if there is no known + insertion event for the given insertion event. + """ + return await self.db_pool.simple_select_one_onecol( + table="insertion_events", + keyvalues={"next_chunk_id": chunk_id}, + retcol="event_id", + allow_none=True, + ) From 8c7a531e277f98ac6b7981b9738649f3a70feb94 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 15 Sep 2021 08:34:52 -0400 Subject: [PATCH 07/74] Use direct references for some configuration variables (part 2) (#10812) --- changelog.d/10812.misc | 1 + synapse/api/auth.py | 4 ++-- synapse/api/auth_blocking.py | 16 +++++++++------- synapse/crypto/context_factory.py | 8 ++++---- synapse/crypto/keyring.py | 2 +- synapse/federation/federation_server.py | 2 +- synapse/federation/sender/__init__.py | 2 +- synapse/handlers/initial_sync.py | 2 +- synapse/handlers/presence.py | 12 ++++++------ synapse/handlers/sync.py | 2 +- synapse/http/client.py | 7 +++++-- synapse/push/httppusher.py | 2 +- synapse/push/mailer.py | 10 +++++----- synapse/push/pusher.py | 8 ++++---- synapse/push/pusherpool.py | 2 +- synapse/server.py | 16 ++++++++-------- 16 files changed, 51 insertions(+), 45 deletions(-) create mode 100644 changelog.d/10812.misc diff --git a/changelog.d/10812.misc b/changelog.d/10812.misc new file mode 100644 index 000000000..586a0b3a9 --- /dev/null +++ b/changelog.d/10812.misc @@ -0,0 +1 @@ +Use direct references to config flags. diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 05699714e..e6ca9232e 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -70,8 +70,8 @@ def __init__(self, hs: "HomeServer"): self._auth_blocking = AuthBlocking(self.hs) - self._track_appservice_user_ips = hs.config.track_appservice_user_ips - self._macaroon_secret_key = hs.config.macaroon_secret_key + self._track_appservice_user_ips = hs.config.appservice.track_appservice_user_ips + self._macaroon_secret_key = hs.config.key.macaroon_secret_key self._force_tracing_for_users = hs.config.tracing.force_tracing_for_users async def check_user_in_room( diff --git a/synapse/api/auth_blocking.py b/synapse/api/auth_blocking.py index e6bced93d..a3b95f4de 100644 --- a/synapse/api/auth_blocking.py +++ b/synapse/api/auth_blocking.py @@ -30,13 +30,15 @@ class AuthBlocking: def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() - self._server_notices_mxid = hs.config.server_notices_mxid - self._hs_disabled = hs.config.hs_disabled - self._hs_disabled_message = hs.config.hs_disabled_message - self._admin_contact = hs.config.admin_contact - self._max_mau_value = hs.config.max_mau_value - self._limit_usage_by_mau = hs.config.limit_usage_by_mau - self._mau_limits_reserved_threepids = hs.config.mau_limits_reserved_threepids + self._server_notices_mxid = hs.config.servernotices.server_notices_mxid + self._hs_disabled = hs.config.server.hs_disabled + self._hs_disabled_message = hs.config.server.hs_disabled_message + self._admin_contact = hs.config.server.admin_contact + self._max_mau_value = hs.config.server.max_mau_value + self._limit_usage_by_mau = hs.config.server.limit_usage_by_mau + self._mau_limits_reserved_threepids = ( + hs.config.server.mau_limits_reserved_threepids + ) self._server_name = hs.hostname self._track_appservice_user_ips = hs.config.appservice.track_appservice_user_ips diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py index c644b4dfc..d310976fe 100644 --- a/synapse/crypto/context_factory.py +++ b/synapse/crypto/context_factory.py @@ -102,7 +102,7 @@ def __init__(self, config): self._config = config # Check if we're using a custom list of a CA certificates - trust_root = config.federation_ca_trust_root + trust_root = config.tls.federation_ca_trust_root if trust_root is None: # Use CA root certs provided by OpenSSL trust_root = platformTrust() @@ -113,7 +113,7 @@ def __init__(self, config): # moving to TLS 1.2 by default, we want to respect the config option if # it is set to 1.0 (which the alternate option, raiseMinimumTo, will not # let us do). - minTLS = _TLS_VERSION_MAP[config.federation_client_minimum_tls_version] + minTLS = _TLS_VERSION_MAP[config.tls.federation_client_minimum_tls_version] _verify_ssl = CertificateOptions( trustRoot=trust_root, insecurelyLowerMinimumTo=minTLS @@ -125,10 +125,10 @@ def __init__(self, config): self._no_verify_ssl_context = _no_verify_ssl.getContext() self._no_verify_ssl_context.set_info_callback(_context_info_cb) - self._should_verify = self._config.federation_verify_certificates + self._should_verify = self._config.tls.federation_verify_certificates self._federation_certificate_verification_whitelist = ( - self._config.federation_certificate_verification_whitelist + self._config.tls.federation_certificate_verification_whitelist ) def get_options(self, host: bytes): diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 9e9b1c1c8..e1e13a241 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -572,7 +572,7 @@ def __init__(self, hs: "HomeServer"): super().__init__(hs) self.clock = hs.get_clock() self.client = hs.get_federation_http_client() - self.key_servers = self.config.key_servers + self.key_servers = self.config.key.key_servers async def _fetch_keys( self, keys_to_fetch: List[_FetchKeyRequest] diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 214ee948f..638959cbe 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -1237,7 +1237,7 @@ def register_instances_for_edu( self._edu_type_to_instance[edu_type] = instance_names async def on_edu(self, edu_type: str, origin: str, content: dict) -> None: - if not self.config.use_presence and edu_type == EduTypes.Presence: + if not self.config.server.use_presence and edu_type == EduTypes.Presence: return # Check if we have a handler on this instance diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 4671ac024..720d7bd74 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -594,7 +594,7 @@ def send_presence_to_destinations( destinations (list[str]) """ - if not states or not self.hs.config.use_presence: + if not states or not self.hs.config.server.use_presence: # No-op if presence is disabled. return diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 4e8f7f1d8..0b24b40eb 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -413,7 +413,7 @@ async def _room_initial_sync_joined( async def get_presence(): # If presence is disabled, return an empty list - if not self.hs.config.use_presence: + if not self.hs.config.server.use_presence: return [] states = await presence_handler.get_states( diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 39b39cd3e..4ab962a84 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -374,7 +374,7 @@ def __init__(self, hs: "HomeServer"): self._presence_writer_instance = hs.config.worker.writers.presence[0] - self._presence_enabled = hs.config.use_presence + self._presence_enabled = hs.config.server.use_presence # Route presence EDUs to the right worker hs.get_federation_registry().register_instances_for_edu( @@ -584,7 +584,7 @@ async def set_state( user_id = target_user.to_string() # If presence is disabled, no-op - if not self.hs.config.use_presence: + if not self.hs.config.server.use_presence: return # Proxy request to instance that writes presence @@ -601,7 +601,7 @@ async def bump_presence_active_time(self, user: UserID) -> None: with the app. """ # If presence is disabled, no-op - if not self.hs.config.use_presence: + if not self.hs.config.server.use_presence: return # Proxy request to instance that writes presence @@ -618,7 +618,7 @@ def __init__(self, hs: "HomeServer"): self.server_name = hs.hostname self.wheel_timer: WheelTimer[str] = WheelTimer() self.notifier = hs.get_notifier() - self._presence_enabled = hs.config.use_presence + self._presence_enabled = hs.config.server.use_presence federation_registry = hs.get_federation_registry() @@ -916,7 +916,7 @@ async def bump_presence_active_time(self, user: UserID) -> None: with the app. """ # If presence is disabled, no-op - if not self.hs.config.use_presence: + if not self.hs.config.server.use_presence: return user_id = user.to_string() @@ -949,7 +949,7 @@ async def user_syncing( """ # Override if it should affect the user's presence, if presence is # disabled. - if not self.hs.config.use_presence: + if not self.hs.config.server.use_presence: affect_presence = False if affect_presence: diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index e017b28cd..91d24534e 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1090,7 +1090,7 @@ async def generate_sync_result( block_all_presence_data = ( since_token is None and sync_config.filter_collection.blocks_all_presence() ) - if self.hs_config.use_presence and not block_all_presence_data: + if self.hs_config.server.use_presence and not block_all_presence_data: logger.debug("Fetching presence data") await self._generate_sync_entry_for_presence( sync_result_builder, diff --git a/synapse/http/client.py b/synapse/http/client.py index c2ea51ee1..5204c3d08 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -321,8 +321,11 @@ def __init__( self.user_agent = hs.version_string self.clock = hs.get_clock() - if hs.config.user_agent_suffix: - self.user_agent = "%s %s" % (self.user_agent, hs.config.user_agent_suffix) + if hs.config.server.user_agent_suffix: + self.user_agent = "%s %s" % ( + self.user_agent, + hs.config.server.user_agent_suffix, + ) # We use this for our body producers to ensure that they use the correct # reactor. diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index 36aabd842..065948f98 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -365,7 +365,7 @@ async def _build_notification_dict( if event.type == "m.room.member" and event.is_state(): d["notification"]["membership"] = event.content["membership"] d["notification"]["user_is_target"] = event.state_key == self.user_id - if self.hs.config.push_include_content and event.content: + if self.hs.config.push.push_include_content and event.content: d["notification"]["content"] = event.content # We no longer send aliases separately, instead, we send the human diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index b89c6e6f2..e38e3c5d4 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -110,7 +110,7 @@ def __init__( self.state_handler = self.hs.get_state_handler() self.storage = hs.get_storage() self.app_name = app_name - self.email_subjects: EmailSubjectConfig = hs.config.email_subjects + self.email_subjects: EmailSubjectConfig = hs.config.email.email_subjects logger.info("Created Mailer for app_name %s" % app_name) @@ -796,8 +796,8 @@ def _make_room_link(self, room_id: str) -> str: Returns: A link to open a room in the web client. """ - if self.hs.config.email_riot_base_url: - base_url = "%s/#/room" % (self.hs.config.email_riot_base_url) + if self.hs.config.email.email_riot_base_url: + base_url = "%s/#/room" % (self.hs.config.email.email_riot_base_url) elif self.app_name == "Vector": # need /beta for Universal Links to work on iOS base_url = "https://vector.im/beta/#/room" @@ -815,9 +815,9 @@ def _make_notif_link(self, notif: Dict[str, str]) -> str: Returns: A link to open the notification in the web client. """ - if self.hs.config.email_riot_base_url: + if self.hs.config.email.email_riot_base_url: return "%s/#/room/%s/%s" % ( - self.hs.config.email_riot_base_url, + self.hs.config.email.email_riot_base_url, notif["room_id"], notif["event_id"], ) diff --git a/synapse/push/pusher.py b/synapse/push/pusher.py index 021275437..29ed346d3 100644 --- a/synapse/push/pusher.py +++ b/synapse/push/pusher.py @@ -35,12 +35,12 @@ def __init__(self, hs: "HomeServer"): "http": HttpPusher } - logger.info("email enable notifs: %r", hs.config.email_enable_notifs) - if hs.config.email_enable_notifs: + logger.info("email enable notifs: %r", hs.config.email.email_enable_notifs) + if hs.config.email.email_enable_notifs: self.mailers: Dict[str, Mailer] = {} - self._notif_template_html = hs.config.email_notif_template_html - self._notif_template_text = hs.config.email_notif_template_text + self._notif_template_html = hs.config.email.email_notif_template_html + self._notif_template_text = hs.config.email.email_notif_template_text self.pusher_types["email"] = self._create_email_pusher diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index a1436f393..26735447a 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -62,7 +62,7 @@ def __init__(self, hs: "HomeServer"): self.clock = self.hs.get_clock() # We shard the handling of push notifications by user ID. - self._pusher_shard_config = hs.config.push.pusher_shard_config + self._pusher_shard_config = hs.config.worker.pusher_shard_config self._instance_name = hs.get_instance_name() self._should_start_pushers = ( self._instance_name in self._pusher_shard_config.instances diff --git a/synapse/server.py b/synapse/server.py index 4777ef585..637eb15b7 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -392,7 +392,7 @@ def get_auth(self) -> Auth: @cache_in_self def get_http_client_context_factory(self) -> IPolicyForHTTPS: - if self.config.use_insecure_ssl_client_just_for_testing_do_not_use: + if self.config.tls.use_insecure_ssl_client_just_for_testing_do_not_use: return InsecureInterceptableContextFactory() return RegularPolicyForHTTPS() @@ -418,8 +418,8 @@ def get_proxied_blacklisted_http_client(self) -> SimpleHttpClient: """ return SimpleHttpClient( self, - ip_whitelist=self.config.ip_range_whitelist, - ip_blacklist=self.config.ip_range_blacklist, + ip_whitelist=self.config.server.ip_range_whitelist, + ip_blacklist=self.config.server.ip_range_blacklist, use_proxy=True, ) @@ -801,18 +801,18 @@ def get_outbound_redis_connection(self) -> Optional["RedisProtocol"]: logger.info( "Connecting to redis (host=%r port=%r) for external cache", - self.config.redis_host, - self.config.redis_port, + self.config.redis.redis_host, + self.config.redis.redis_port, ) return lazyConnection( hs=self, - host=self.config.redis_host, - port=self.config.redis_port, + host=self.config.redis.redis_host, + port=self.config.redis.redis_port, password=self.config.redis.redis_password, reconnect=True, ) def should_send_federation(self) -> bool: "Should this server be sending federation traffic directly?" - return self.config.send_federation + return self.config.worker.send_federation From b93259082c7d8d3fe8376a646e130213d90069dc Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 15 Sep 2021 08:45:32 -0400 Subject: [PATCH 08/74] Add missing type hints to non-client REST servlets. (#10817) Including admin, consent, key, synapse, and media. All REST servlets (the synapse.rest module) now require typed method definitions. --- changelog.d/10785.misc | 2 +- changelog.d/10817.misc | 1 + mypy.ini | 2 +- synapse/rest/__init__.py | 11 ++++-- synapse/rest/admin/devices.py | 2 +- synapse/rest/admin/server_notice_servlet.py | 2 +- synapse/rest/admin/users.py | 2 +- synapse/rest/consent/consent_resource.py | 39 ++++++++----------- synapse/rest/health.py | 3 +- synapse/rest/key/v2/__init__.py | 7 +++- synapse/rest/key/v2/local_key_resource.py | 15 ++++--- synapse/rest/key/v2/remote_key_resource.py | 30 +++++++++----- synapse/rest/media/v1/_base.py | 24 +++++++----- synapse/rest/media/v1/filepath.py | 6 +-- synapse/rest/media/v1/media_repository.py | 8 +++- synapse/rest/media/v1/media_storage.py | 32 ++++++++++++--- synapse/rest/media/v1/preview_url_resource.py | 5 ++- synapse/rest/media/v1/storage_provider.py | 4 +- synapse/rest/media/v1/thumbnailer.py | 2 +- .../rest/synapse/client/new_user_consent.py | 6 +-- synapse/rest/synapse/client/oidc/__init__.py | 6 ++- .../synapse/client/oidc/callback_resource.py | 5 ++- synapse/rest/synapse/client/pick_username.py | 9 +++-- synapse/rest/synapse/client/saml2/__init__.py | 6 ++- .../synapse/client/saml2/metadata_resource.py | 9 ++++- .../synapse/client/saml2/response_resource.py | 7 +++- synapse/rest/well_known.py | 20 +++++----- 27 files changed, 169 insertions(+), 96 deletions(-) create mode 100644 changelog.d/10817.misc diff --git a/changelog.d/10785.misc b/changelog.d/10785.misc index 3d7f91d51..39a37b90b 100644 --- a/changelog.d/10785.misc +++ b/changelog.d/10785.misc @@ -1 +1 @@ -Convert the internal `FileInfo` class to attrs and add type hints. +Add missing type hints to REST servlets. diff --git a/changelog.d/10817.misc b/changelog.d/10817.misc new file mode 100644 index 000000000..39a37b90b --- /dev/null +++ b/changelog.d/10817.misc @@ -0,0 +1 @@ +Add missing type hints to REST servlets. diff --git a/mypy.ini b/mypy.ini index 60dadc478..e9052fa01 100644 --- a/mypy.ini +++ b/mypy.ini @@ -90,7 +90,7 @@ files = tests/util/test_itertools.py, tests/util/test_stream_change_cache.py -[mypy-synapse.rest.client.*] +[mypy-synapse.rest.*] disallow_untyped_defs = True [mypy-synapse.util.batching_queue] diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index 3adc57612..e04af705e 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -12,7 +12,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from synapse.http.server import JsonResource +from typing import TYPE_CHECKING + +from synapse.http.server import HttpServer, JsonResource from synapse.rest import admin from synapse.rest.client import ( account, @@ -57,6 +59,9 @@ voip, ) +if TYPE_CHECKING: + from synapse.server import HomeServer + class ClientRestResource(JsonResource): """Matrix Client API REST resource. @@ -68,12 +73,12 @@ class ClientRestResource(JsonResource): * etc """ - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): JsonResource.__init__(self, hs, canonical_json=False) self.register_servlets(self, hs) @staticmethod - def register_servlets(client_resource, hs): + def register_servlets(client_resource: HttpServer, hs: "HomeServer") -> None: versions.register_servlets(hs, client_resource) # Deprecated in r0 diff --git a/synapse/rest/admin/devices.py b/synapse/rest/admin/devices.py index 5715190a7..a6fa03c90 100644 --- a/synapse/rest/admin/devices.py +++ b/synapse/rest/admin/devices.py @@ -47,7 +47,7 @@ def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() async def on_GET( - self, request: SynapseRequest, user_id, device_id: str + self, request: SynapseRequest, user_id: str, device_id: str ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) diff --git a/synapse/rest/admin/server_notice_servlet.py b/synapse/rest/admin/server_notice_servlet.py index f5a38c267..19f84f33f 100644 --- a/synapse/rest/admin/server_notice_servlet.py +++ b/synapse/rest/admin/server_notice_servlet.py @@ -57,7 +57,7 @@ def __init__(self, hs: "HomeServer"): self.admin_handler = hs.get_admin_handler() self.txns = HttpTransactionCache(hs) - def register(self, json_resource: HttpServer): + def register(self, json_resource: HttpServer) -> None: PATTERN = "/send_server_notice" json_resource.register_paths( "POST", admin_patterns(PATTERN + "$"), self.on_POST, self.__class__.__name__ diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index c1a1ba645..681e49182 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -419,7 +419,7 @@ def __init__(self, hs: "HomeServer"): self.nonces: Dict[str, int] = {} self.hs = hs - def _clear_old_nonces(self): + def _clear_old_nonces(self) -> None: """ Clear out old nonces that are older than NONCE_TIMEOUT. """ diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py index 11f732083..06e0fbde2 100644 --- a/synapse/rest/consent/consent_resource.py +++ b/synapse/rest/consent/consent_resource.py @@ -17,17 +17,22 @@ from hashlib import sha256 from http import HTTPStatus from os import path -from typing import Dict, List +from typing import TYPE_CHECKING, Any, Dict, List import jinja2 from jinja2 import TemplateNotFound +from twisted.web.server import Request + from synapse.api.errors import NotFoundError, StoreError, SynapseError from synapse.config import ConfigError from synapse.http.server import DirectServeHtmlResource, respond_with_html from synapse.http.servlet import parse_bytes_from_args, parse_string from synapse.types import UserID +if TYPE_CHECKING: + from synapse.server import HomeServer + # language to use for the templates. TODO: figure this out from Accept-Language TEMPLATE_LANGUAGE = "en" @@ -69,11 +74,7 @@ class ConsentResource(DirectServeHtmlResource): against the user. """ - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): homeserver - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs @@ -106,18 +107,14 @@ def __init__(self, hs): self._hmac_secret = hs.config.form_secret.encode("utf-8") - async def _async_render_GET(self, request): - """ - Args: - request (twisted.web.http.Request): - """ + async def _async_render_GET(self, request: Request) -> None: version = parse_string(request, "v", default=self._default_consent_version) username = parse_string(request, "u", default="") userhmac = None has_consented = False public_version = username == "" if not public_version: - args: Dict[bytes, List[bytes]] = request.args + args: Dict[bytes, List[bytes]] = request.args # type: ignore userhmac_bytes = parse_bytes_from_args(args, "h", required=True) self._check_hash(username, userhmac_bytes) @@ -147,14 +144,10 @@ async def _async_render_GET(self, request): except TemplateNotFound: raise NotFoundError("Unknown policy version") - async def _async_render_POST(self, request): - """ - Args: - request (twisted.web.http.Request): - """ + async def _async_render_POST(self, request: Request) -> None: version = parse_string(request, "v", required=True) username = parse_string(request, "u", required=True) - args: Dict[bytes, List[bytes]] = request.args + args: Dict[bytes, List[bytes]] = request.args # type: ignore userhmac = parse_bytes_from_args(args, "h", required=True) self._check_hash(username, userhmac) @@ -177,7 +170,9 @@ async def _async_render_POST(self, request): except TemplateNotFound: raise NotFoundError("success.html not found") - def _render_template(self, request, template_name, **template_args): + def _render_template( + self, request: Request, template_name: str, **template_args: Any + ) -> None: # get_template checks for ".." so we don't need to worry too much # about path traversal here. template_html = self._jinja_env.get_template( @@ -186,11 +181,11 @@ def _render_template(self, request, template_name, **template_args): html = template_html.render(**template_args) respond_with_html(request, 200, html) - def _check_hash(self, userid, userhmac): + def _check_hash(self, userid: str, userhmac: bytes) -> None: """ Args: - userid (unicode): - userhmac (bytes): + userid: + userhmac: Raises: SynapseError if the hash doesn't match diff --git a/synapse/rest/health.py b/synapse/rest/health.py index 4487b54ab..78df7af2c 100644 --- a/synapse/rest/health.py +++ b/synapse/rest/health.py @@ -13,6 +13,7 @@ # limitations under the License. from twisted.web.resource import Resource +from twisted.web.server import Request class HealthResource(Resource): @@ -25,6 +26,6 @@ class HealthResource(Resource): isLeaf = 1 - def render_GET(self, request): + def render_GET(self, request: Request) -> bytes: request.setHeader(b"Content-Type", b"text/plain") return b"OK" diff --git a/synapse/rest/key/v2/__init__.py b/synapse/rest/key/v2/__init__.py index c6c63073e..7f8c1de1f 100644 --- a/synapse/rest/key/v2/__init__.py +++ b/synapse/rest/key/v2/__init__.py @@ -12,14 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import TYPE_CHECKING + from twisted.web.resource import Resource from .local_key_resource import LocalKey from .remote_key_resource import RemoteKey +if TYPE_CHECKING: + from synapse.server import HomeServer + class KeyApiV2Resource(Resource): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): Resource.__init__(self) self.putChild(b"server", LocalKey(hs)) self.putChild(b"query", RemoteKey(hs)) diff --git a/synapse/rest/key/v2/local_key_resource.py b/synapse/rest/key/v2/local_key_resource.py index 25f6eb842..ebe243bcf 100644 --- a/synapse/rest/key/v2/local_key_resource.py +++ b/synapse/rest/key/v2/local_key_resource.py @@ -12,16 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. - import logging +from typing import TYPE_CHECKING from canonicaljson import encode_canonical_json from signedjson.sign import sign_json from unpaddedbase64 import encode_base64 from twisted.web.resource import Resource +from twisted.web.server import Request from synapse.http.server import respond_with_json_bytes +from synapse.types import JsonDict + +if TYPE_CHECKING: + from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -58,18 +63,18 @@ class LocalKey(Resource): isLeaf = True - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.config = hs.config self.clock = hs.get_clock() self.update_response_body(self.clock.time_msec()) Resource.__init__(self) - def update_response_body(self, time_now_msec): + def update_response_body(self, time_now_msec: int) -> None: refresh_interval = self.config.key_refresh_interval self.valid_until_ts = int(time_now_msec + refresh_interval) self.response_body = encode_canonical_json(self.response_json_object()) - def response_json_object(self): + def response_json_object(self) -> JsonDict: verify_keys = {} for key in self.config.signing_key: verify_key_bytes = key.verify_key.encode() @@ -94,7 +99,7 @@ def response_json_object(self): json_object = sign_json(json_object, self.config.server.server_name, key) return json_object - def render_GET(self, request): + def render_GET(self, request: Request) -> int: time_now = self.clock.time_msec() # Update the expiry time if less than half the interval remains. if time_now + self.config.key_refresh_interval / 2 > self.valid_until_ts: diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index 744360e5f..d8fd7938a 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -13,17 +13,23 @@ # limitations under the License. import logging -from typing import Dict +from typing import TYPE_CHECKING, Dict from signedjson.sign import sign_json +from twisted.web.server import Request + from synapse.api.errors import Codes, SynapseError from synapse.crypto.keyring import ServerKeyFetcher from synapse.http.server import DirectServeJsonResource, respond_with_json from synapse.http.servlet import parse_integer, parse_json_object_from_request +from synapse.types import JsonDict from synapse.util import json_decoder from synapse.util.async_helpers import yieldable_gather_results +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -85,7 +91,7 @@ class RemoteKey(DirectServeJsonResource): isLeaf = True - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.fetcher = ServerKeyFetcher(hs) @@ -94,7 +100,8 @@ def __init__(self, hs): self.federation_domain_whitelist = hs.config.federation_domain_whitelist self.config = hs.config - async def _async_render_GET(self, request): + async def _async_render_GET(self, request: Request) -> None: + assert request.postpath is not None if len(request.postpath) == 1: (server,) = request.postpath query: dict = {server.decode("ascii"): {}} @@ -110,14 +117,19 @@ async def _async_render_GET(self, request): await self.query_keys(request, query, query_remote_on_cache_miss=True) - async def _async_render_POST(self, request): + async def _async_render_POST(self, request: Request) -> None: content = parse_json_object_from_request(request) query = content["server_keys"] await self.query_keys(request, query, query_remote_on_cache_miss=True) - async def query_keys(self, request, query, query_remote_on_cache_miss=False): + async def query_keys( + self, + request: Request, + query: JsonDict, + query_remote_on_cache_miss: bool = False, + ) -> None: logger.info("Handling query for keys %r", query) store_queries = [] @@ -142,8 +154,8 @@ async def query_keys(self, request, query, query_remote_on_cache_miss=False): # Note that the value is unused. cache_misses: Dict[str, Dict[str, int]] = {} - for (server_name, key_id, _), results in cached.items(): - results = [(result["ts_added_ms"], result) for result in results] + for (server_name, key_id, _), key_results in cached.items(): + results = [(result["ts_added_ms"], result) for result in key_results] if not results and key_id is not None: cache_misses.setdefault(server_name, {})[key_id] = 0 @@ -230,6 +242,6 @@ async def query_keys(self, request, query, query_remote_on_cache_miss=False): signed_keys.append(key_json) - results = {"server_keys": signed_keys} + response = {"server_keys": signed_keys} - respond_with_json(request, 200, results, canonical_json=True) + respond_with_json(request, 200, response, canonical_json=True) diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py index 814f4309f..7c881f2bd 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py @@ -16,7 +16,8 @@ import logging import os import urllib -from typing import Awaitable, Dict, Generator, List, Optional, Tuple +from types import TracebackType +from typing import Awaitable, Dict, Generator, List, Optional, Tuple, Type import attr @@ -122,7 +123,7 @@ def add_file_headers( upload_name: The name of the requested file, if any. """ - def _quote(x): + def _quote(x: str) -> str: return urllib.parse.quote(x.encode("utf-8")) # Default to a UTF-8 charset for text content types. @@ -282,10 +283,15 @@ def write_to_consumer(self, consumer: IConsumer) -> Awaitable: """ pass - def __enter__(self): + def __enter__(self) -> None: pass - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: pass @@ -317,31 +323,31 @@ class FileInfo: # The below properties exist to maintain compatibility with third-party modules. @property - def thumbnail_width(self): + def thumbnail_width(self) -> Optional[int]: if not self.thumbnail: return None return self.thumbnail.width @property - def thumbnail_height(self): + def thumbnail_height(self) -> Optional[int]: if not self.thumbnail: return None return self.thumbnail.height @property - def thumbnail_method(self): + def thumbnail_method(self) -> Optional[str]: if not self.thumbnail: return None return self.thumbnail.method @property - def thumbnail_type(self): + def thumbnail_type(self) -> Optional[str]: if not self.thumbnail: return None return self.thumbnail.type @property - def thumbnail_length(self): + def thumbnail_length(self) -> Optional[int]: if not self.thumbnail: return None return self.thumbnail.length diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py index 09531ebf5..39bbe4e87 100644 --- a/synapse/rest/media/v1/filepath.py +++ b/synapse/rest/media/v1/filepath.py @@ -16,7 +16,7 @@ import functools import os import re -from typing import Callable, List +from typing import Any, Callable, List NEW_FORMAT_ID_RE = re.compile(r"^\d\d\d\d-\d\d-\d\d") @@ -27,7 +27,7 @@ def _wrap_in_base_path(func: Callable[..., str]) -> Callable[..., str]: """ @functools.wraps(func) - def _wrapped(self, *args, **kwargs): + def _wrapped(self: "MediaFilePaths", *args: Any, **kwargs: Any) -> str: path = func(self, *args, **kwargs) return os.path.join(self.base_path, path) @@ -129,7 +129,7 @@ def remote_media_thumbnail_rel( # using the new path. def remote_media_thumbnail_rel_legacy( self, server_name: str, file_id: str, width: int, height: int, content_type: str - ): + ) -> str: top_level_type, sub_type = content_type.split("/") file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type) return os.path.join( diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 40ce8d2bc..50e4c9e29 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -21,6 +21,7 @@ import twisted.internet.error import twisted.web.http +from twisted.internet.defer import Deferred from twisted.web.resource import Resource from twisted.web.server import Request @@ -32,6 +33,7 @@ SynapseError, ) from synapse.config._base import ConfigError +from synapse.config.repository import ThumbnailRequirement from synapse.logging.context import defer_to_thread from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import UserID @@ -114,7 +116,7 @@ def __init__(self, hs: "HomeServer"): self._start_update_recently_accessed, UPDATE_RECENTLY_ACCESSED_TS ) - def _start_update_recently_accessed(self): + def _start_update_recently_accessed(self) -> Deferred: return run_as_background_process( "update_recently_accessed_media", self._update_recently_accessed ) @@ -469,7 +471,9 @@ async def _download_remote_file( return media_info - def _get_thumbnail_requirements(self, media_type): + def _get_thumbnail_requirements( + self, media_type: str + ) -> Tuple[ThumbnailRequirement, ...]: scpos = media_type.find(";") if scpos > 0: media_type = media_type[:scpos] diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py index c0bb40c11..01fada8fb 100644 --- a/synapse/rest/media/v1/media_storage.py +++ b/synapse/rest/media/v1/media_storage.py @@ -15,7 +15,20 @@ import logging import os import shutil -from typing import IO, TYPE_CHECKING, Any, Callable, Optional, Sequence +from types import TracebackType +from typing import ( + IO, + TYPE_CHECKING, + Any, + Awaitable, + BinaryIO, + Callable, + Generator, + Optional, + Sequence, + Tuple, + Type, +) import attr @@ -83,12 +96,14 @@ async def store_file(self, source: IO, file_info: FileInfo) -> str: return fname - async def write_to_file(self, source: IO, output: IO): + async def write_to_file(self, source: IO, output: IO) -> None: """Asynchronously write the `source` to `output`.""" await defer_to_thread(self.reactor, _write_file_synchronously, source, output) @contextlib.contextmanager - def store_into_file(self, file_info: FileInfo): + def store_into_file( + self, file_info: FileInfo + ) -> Generator[Tuple[BinaryIO, str, Callable[[], Awaitable[None]]], None, None]: """Context manager used to get a file like object to write into, as described by file_info. @@ -125,7 +140,7 @@ def store_into_file(self, file_info: FileInfo): try: with open(fname, "wb") as f: - async def finish(): + async def finish() -> None: # Ensure that all writes have been flushed and close the # file. f.flush() @@ -315,7 +330,12 @@ def write_to_consumer(self, consumer: IConsumer) -> Deferred: FileSender().beginFileTransfer(self.open_file, consumer) ) - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: self.open_file.close() @@ -339,7 +359,7 @@ class ReadableFileWrapper: clock = attr.ib(type=Clock) path = attr.ib(type=str) - async def write_chunks_to(self, callback: Callable[[bytes], None]): + async def write_chunks_to(self, callback: Callable[[bytes], None]) -> None: """Reads the file in chunks and calls the callback with each chunk.""" with open(self.path, "rb") as file: diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index f108da05d..fe0627d9b 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -27,6 +27,7 @@ import attr +from twisted.internet.defer import Deferred from twisted.internet.error import DNSLookupError from twisted.web.server import Request @@ -473,7 +474,7 @@ async def _download_url(self, url: str, user: str) -> MediaInfo: etag=etag, ) - def _start_expire_url_cache_data(self): + def _start_expire_url_cache_data(self) -> Deferred: return run_as_background_process( "expire_url_cache_data", self._expire_url_cache_data ) @@ -782,7 +783,7 @@ def _calc_og(tree: "etree.Element", media_uri: str) -> Dict[str, Optional[str]]: def _iterate_over_text( - tree, *tags_to_ignore: Iterable[Union[str, "etree.Comment"]] + tree: "etree.Element", *tags_to_ignore: Iterable[Union[str, "etree.Comment"]] ) -> Generator[str, None, None]: """Iterate over the tree returning text nodes in a depth first fashion, skipping text nodes inside certain tags. diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py index 0ff6ad3c0..6c9969e55 100644 --- a/synapse/rest/media/v1/storage_provider.py +++ b/synapse/rest/media/v1/storage_provider.py @@ -99,7 +99,7 @@ async def store_file(self, path: str, file_info: FileInfo) -> None: await maybe_awaitable(self.backend.store_file(path, file_info)) # type: ignore else: # TODO: Handle errors. - async def store(): + async def store() -> None: try: return await maybe_awaitable( self.backend.store_file(path, file_info) @@ -128,7 +128,7 @@ def __init__(self, hs: "HomeServer", config: str): self.cache_directory = hs.config.media_store_path self.base_directory = config - def __str__(self): + def __str__(self) -> str: return "FileStorageProviderBackend[%s]" % (self.base_directory,) async def store_file(self, path: str, file_info: FileInfo) -> None: diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py index a65e9e180..df54a4064 100644 --- a/synapse/rest/media/v1/thumbnailer.py +++ b/synapse/rest/media/v1/thumbnailer.py @@ -41,7 +41,7 @@ class Thumbnailer: FORMATS = {"image/jpeg": "JPEG", "image/png": "PNG"} @staticmethod - def set_limits(max_image_pixels: int): + def set_limits(max_image_pixels: int) -> None: Image.MAX_IMAGE_PIXELS = max_image_pixels def __init__(self, input_path: str): diff --git a/synapse/rest/synapse/client/new_user_consent.py b/synapse/rest/synapse/client/new_user_consent.py index 67c1ed1f5..1c1c7b361 100644 --- a/synapse/rest/synapse/client/new_user_consent.py +++ b/synapse/rest/synapse/client/new_user_consent.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Generator from twisted.web.server import Request @@ -45,7 +45,7 @@ def __init__(self, hs: "HomeServer"): self._server_name = hs.hostname self._consent_version = hs.config.consent.user_consent_version - def template_search_dirs(): + def template_search_dirs() -> Generator[str, None, None]: if hs.config.server.custom_template_directory: yield hs.config.server.custom_template_directory if hs.config.sso.sso_template_dir: @@ -88,7 +88,7 @@ async def _async_render_GET(self, request: Request) -> None: html = template.render(template_params) respond_with_html(request, 200, html) - async def _async_render_POST(self, request: Request): + async def _async_render_POST(self, request: Request) -> None: try: session_id = get_username_mapping_session_cookie_from_request(request) except SynapseError as e: diff --git a/synapse/rest/synapse/client/oidc/__init__.py b/synapse/rest/synapse/client/oidc/__init__.py index 36ba40165..81fec3965 100644 --- a/synapse/rest/synapse/client/oidc/__init__.py +++ b/synapse/rest/synapse/client/oidc/__init__.py @@ -13,16 +13,20 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING from twisted.web.resource import Resource from synapse.rest.synapse.client.oidc.callback_resource import OIDCCallbackResource +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) class OIDCResource(Resource): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): Resource.__init__(self) self.putChild(b"callback", OIDCCallbackResource(hs)) diff --git a/synapse/rest/synapse/client/oidc/callback_resource.py b/synapse/rest/synapse/client/oidc/callback_resource.py index 7785f17e9..4f375cb74 100644 --- a/synapse/rest/synapse/client/oidc/callback_resource.py +++ b/synapse/rest/synapse/client/oidc/callback_resource.py @@ -16,6 +16,7 @@ from typing import TYPE_CHECKING from synapse.http.server import DirectServeHtmlResource +from synapse.http.site import SynapseRequest if TYPE_CHECKING: from synapse.server import HomeServer @@ -30,10 +31,10 @@ def __init__(self, hs: "HomeServer"): super().__init__() self._oidc_handler = hs.get_oidc_handler() - async def _async_render_GET(self, request): + async def _async_render_GET(self, request: SynapseRequest) -> None: await self._oidc_handler.handle_oidc_callback(request) - async def _async_render_POST(self, request): + async def _async_render_POST(self, request: SynapseRequest) -> None: # the auth response can be returned via an x-www-form-urlencoded form instead # of GET params, as per # https://openid.net/specs/oauth-v2-form-post-response-mode-1_0.html. diff --git a/synapse/rest/synapse/client/pick_username.py b/synapse/rest/synapse/client/pick_username.py index d30b478b9..28ae08349 100644 --- a/synapse/rest/synapse/client/pick_username.py +++ b/synapse/rest/synapse/client/pick_username.py @@ -13,7 +13,7 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, List +from typing import TYPE_CHECKING, Generator, List, Tuple from twisted.web.resource import Resource from twisted.web.server import Request @@ -27,6 +27,7 @@ ) from synapse.http.servlet import parse_boolean, parse_string from synapse.http.site import SynapseRequest +from synapse.types import JsonDict from synapse.util.templates import build_jinja_env if TYPE_CHECKING: @@ -57,7 +58,7 @@ def __init__(self, hs: "HomeServer"): super().__init__() self._sso_handler = hs.get_sso_handler() - async def _async_render_GET(self, request: Request): + async def _async_render_GET(self, request: Request) -> Tuple[int, JsonDict]: localpart = parse_string(request, "username", required=True) session_id = get_username_mapping_session_cookie_from_request(request) @@ -73,7 +74,7 @@ def __init__(self, hs: "HomeServer"): super().__init__() self._sso_handler = hs.get_sso_handler() - def template_search_dirs(): + def template_search_dirs() -> Generator[str, None, None]: if hs.config.server.custom_template_directory: yield hs.config.server.custom_template_directory if hs.config.sso.sso_template_dir: @@ -104,7 +105,7 @@ async def _async_render_GET(self, request: Request) -> None: html = template.render(template_params) respond_with_html(request, 200, html) - async def _async_render_POST(self, request: SynapseRequest): + async def _async_render_POST(self, request: SynapseRequest) -> None: # This will always be set by the time Twisted calls us. assert request.args is not None diff --git a/synapse/rest/synapse/client/saml2/__init__.py b/synapse/rest/synapse/client/saml2/__init__.py index 781ccb237..3f247e6a2 100644 --- a/synapse/rest/synapse/client/saml2/__init__.py +++ b/synapse/rest/synapse/client/saml2/__init__.py @@ -13,17 +13,21 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING from twisted.web.resource import Resource from synapse.rest.synapse.client.saml2.metadata_resource import SAML2MetadataResource from synapse.rest.synapse.client.saml2.response_resource import SAML2ResponseResource +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) class SAML2Resource(Resource): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): Resource.__init__(self) self.putChild(b"metadata.xml", SAML2MetadataResource(hs)) self.putChild(b"authn_response", SAML2ResponseResource(hs)) diff --git a/synapse/rest/synapse/client/saml2/metadata_resource.py b/synapse/rest/synapse/client/saml2/metadata_resource.py index b37c7083d..64378ed57 100644 --- a/synapse/rest/synapse/client/saml2/metadata_resource.py +++ b/synapse/rest/synapse/client/saml2/metadata_resource.py @@ -12,10 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import TYPE_CHECKING import saml2.metadata from twisted.web.resource import Resource +from twisted.web.server import Request + +if TYPE_CHECKING: + from synapse.server import HomeServer class SAML2MetadataResource(Resource): @@ -23,11 +28,11 @@ class SAML2MetadataResource(Resource): isLeaf = 1 - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): Resource.__init__(self) self.sp_config = hs.config.saml2_sp_config - def render_GET(self, request): + def render_GET(self, request: Request) -> bytes: metadata_xml = saml2.metadata.create_metadata_string( configfile=None, config=self.sp_config ) diff --git a/synapse/rest/synapse/client/saml2/response_resource.py b/synapse/rest/synapse/client/saml2/response_resource.py index 774ccd870..47d2a6a22 100644 --- a/synapse/rest/synapse/client/saml2/response_resource.py +++ b/synapse/rest/synapse/client/saml2/response_resource.py @@ -15,7 +15,10 @@ from typing import TYPE_CHECKING +from twisted.web.server import Request + from synapse.http.server import DirectServeHtmlResource +from synapse.http.site import SynapseRequest if TYPE_CHECKING: from synapse.server import HomeServer @@ -31,7 +34,7 @@ def __init__(self, hs: "HomeServer"): self._saml_handler = hs.get_saml_handler() self._sso_handler = hs.get_sso_handler() - async def _async_render_GET(self, request): + async def _async_render_GET(self, request: Request) -> None: # We're not expecting any GET request on that resource if everything goes right, # but some IdPs sometimes end up responding with a 302 redirect on this endpoint. # In this case, just tell the user that something went wrong and they should @@ -40,5 +43,5 @@ async def _async_render_GET(self, request): request, "unexpected_get", "Unexpected GET request on /saml2/authn_response" ) - async def _async_render_POST(self, request): + async def _async_render_POST(self, request: SynapseRequest) -> None: await self._saml_handler.handle_saml_response(request) diff --git a/synapse/rest/well_known.py b/synapse/rest/well_known.py index 6a66a88c5..c80a3a99a 100644 --- a/synapse/rest/well_known.py +++ b/synapse/rest/well_known.py @@ -13,26 +13,26 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING, Optional from twisted.web.resource import Resource +from twisted.web.server import Request from synapse.http.server import set_cors_headers +from synapse.types import JsonDict from synapse.util import json_encoder +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) class WellKnownBuilder: - """Utility to construct the well-known response - - Args: - hs (synapse.server.HomeServer): - """ - - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self._config = hs.config - def get_well_known(self): + def get_well_known(self) -> Optional[JsonDict]: # if we don't have a public_baseurl, we can't help much here. if self._config.server.public_baseurl is None: return None @@ -52,11 +52,11 @@ class WellKnownResource(Resource): isLeaf = 1 - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): Resource.__init__(self) self._well_known_builder = WellKnownBuilder(hs) - def render_GET(self, request): + def render_GET(self, request: Request) -> bytes: set_cors_headers(request) r = self._well_known_builder.get_well_known() if not r: From 3eba047d388fd0d798229a0779f343dbda8a2887 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 15 Sep 2021 09:54:13 -0400 Subject: [PATCH 09/74] Add type hints to state database module. (#10823) --- changelog.d/10823.misc | 1 + mypy.ini | 1 + synapse/storage/databases/state/bg_updates.py | 60 +++++--- synapse/storage/databases/state/store.py | 136 +++++++++++------- synapse/storage/state.py | 3 +- synapse/util/caches/dictionary_cache.py | 4 +- 6 files changed, 133 insertions(+), 72 deletions(-) create mode 100644 changelog.d/10823.misc diff --git a/changelog.d/10823.misc b/changelog.d/10823.misc new file mode 100644 index 000000000..053296990 --- /dev/null +++ b/changelog.d/10823.misc @@ -0,0 +1 @@ +Add type hints to the state database. diff --git a/mypy.ini b/mypy.ini index e9052fa01..b21e1555a 100644 --- a/mypy.ini +++ b/mypy.ini @@ -60,6 +60,7 @@ files = synapse/storage/databases/main/session.py, synapse/storage/databases/main/stream.py, synapse/storage/databases/main/ui_auth.py, + synapse/storage/databases/state, synapse/storage/database.py, synapse/storage/engines, synapse/storage/keys.py, diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py index c2891cb07..eb1118d2c 100644 --- a/synapse/storage/databases/state/bg_updates.py +++ b/synapse/storage/databases/state/bg_updates.py @@ -13,12 +13,20 @@ # limitations under the License. import logging -from typing import Optional +from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Union from synapse.storage._base import SQLBaseStore -from synapse.storage.database import DatabasePool +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) from synapse.storage.engines import PostgresEngine from synapse.storage.state import StateFilter +from synapse.types import MutableStateMap, StateMap + +if TYPE_CHECKING: + from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -31,7 +39,9 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore): updates. """ - def _count_state_group_hops_txn(self, txn, state_group): + def _count_state_group_hops_txn( + self, txn: LoggingTransaction, state_group: int + ) -> int: """Given a state group, count how many hops there are in the tree. This is used to ensure the delta chains don't get too long. @@ -56,7 +66,7 @@ def _count_state_group_hops_txn(self, txn, state_group): else: # We don't use WITH RECURSIVE on sqlite3 as there are distributions # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) - next_group = state_group + next_group: Optional[int] = state_group count = 0 while next_group: @@ -73,11 +83,14 @@ def _count_state_group_hops_txn(self, txn, state_group): return count def _get_state_groups_from_groups_txn( - self, txn, groups, state_filter: Optional[StateFilter] = None - ): + self, + txn: LoggingTransaction, + groups: List[int], + state_filter: Optional[StateFilter] = None, + ) -> Mapping[int, StateMap[str]]: state_filter = state_filter or StateFilter.all() - results = {group: {} for group in groups} + results: Dict[int, MutableStateMap[str]] = {group: {} for group in groups} where_clause, where_args = state_filter.make_sql_filter_clause() @@ -117,7 +130,7 @@ def _get_state_groups_from_groups_txn( """ for group in groups: - args = [group] + args: List[Union[int, str]] = [group] args.extend(where_args) txn.execute(sql % (where_clause,), args) @@ -131,7 +144,7 @@ def _get_state_groups_from_groups_txn( # We don't use WITH RECURSIVE on sqlite3 as there are distributions # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) for group in groups: - next_group = group + next_group: Optional[int] = group while next_group: # We did this before by getting the list of group ids, and @@ -173,6 +186,7 @@ def _get_state_groups_from_groups_txn( allow_none=True, ) + # The results shouldn't be considered mutable. return results @@ -182,7 +196,12 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" STATE_GROUPS_ROOM_INDEX_UPDATE_NAME = "state_groups_room_id_idx" - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_update_handler( self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, @@ -198,7 +217,9 @@ def __init__(self, database: DatabasePool, db_conn, hs): columns=["room_id"], ) - async def _background_deduplicate_state(self, progress, batch_size): + async def _background_deduplicate_state( + self, progress: dict, batch_size: int + ) -> int: """This background update will slowly deduplicate state by reencoding them as deltas. """ @@ -218,7 +239,7 @@ async def _background_deduplicate_state(self, progress, batch_size): ) max_group = rows[0][0] - def reindex_txn(txn): + def reindex_txn(txn: LoggingTransaction) -> Tuple[bool, int]: new_last_state_group = last_state_group for count in range(batch_size): txn.execute( @@ -251,7 +272,8 @@ def reindex_txn(txn): " WHERE id < ? AND room_id = ?", (state_group, room_id), ) - (prev_group,) = txn.fetchone() + # There will be a result due to the coalesce. + (prev_group,) = txn.fetchone() # type: ignore new_last_state_group = state_group if prev_group: @@ -261,15 +283,15 @@ def reindex_txn(txn): # otherwise read performance degrades. continue - prev_state = self._get_state_groups_from_groups_txn( + prev_state_by_group = self._get_state_groups_from_groups_txn( txn, [prev_group] ) - prev_state = prev_state[prev_group] + prev_state = prev_state_by_group[prev_group] - curr_state = self._get_state_groups_from_groups_txn( + curr_state_by_group = self._get_state_groups_from_groups_txn( txn, [state_group] ) - curr_state = curr_state[state_group] + curr_state = curr_state_by_group[state_group] if not set(prev_state.keys()) - set(curr_state.keys()): # We can only do a delta if the current has a strict super set @@ -340,8 +362,8 @@ def reindex_txn(txn): return result * BATCH_SIZE_SCALE_FACTOR - async def _background_index_state(self, progress, batch_size): - def reindex_txn(conn): + async def _background_index_state(self, progress: dict, batch_size: int) -> int: + def reindex_txn(conn: LoggingDatabaseConnection) -> None: conn.rollback() if isinstance(self.database_engine, PostgresEngine): # postgres insists on autocommit for the index diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index f839c0c24..f1e3a27e6 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -13,43 +13,56 @@ # limitations under the License. import logging -from collections import namedtuple -from typing import Dict, Iterable, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple + +import attr from synapse.api.constants import EventTypes from synapse.storage._base import SQLBaseStore -from synapse.storage.database import DatabasePool +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) from synapse.storage.databases.state.bg_updates import StateBackgroundUpdateStore from synapse.storage.state import StateFilter from synapse.storage.types import Cursor from synapse.storage.util.sequence import build_sequence_generator -from synapse.types import MutableStateMap, StateMap +from synapse.types import MutableStateMap, StateKey, StateMap from synapse.util.caches.descriptors import cached from synapse.util.caches.dictionary_cache import DictionaryCache +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) MAX_STATE_DELTA_HOPS = 100 -class _GetStateGroupDelta( - namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids")) -): +@attr.s(slots=True, frozen=True, auto_attribs=True) +class _GetStateGroupDelta: """Return type of get_state_group_delta that implements __len__, which lets - us use the itrable flag when caching + us use the iterable flag when caching """ - __slots__ = [] + prev_group: Optional[int] + delta_ids: Optional[StateMap[str]] - def __len__(self): + def __len__(self) -> int: return len(self.delta_ids) if self.delta_ids else 0 class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): """A data store for fetching/storing state groups.""" - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) # Originally the state store used a single DictionaryCache to cache the @@ -81,19 +94,21 @@ def __init__(self, database: DatabasePool, db_conn, hs): # We size the non-members cache to be smaller than the members cache as the # vast majority of state in Matrix (today) is member events. - self._state_group_cache = DictionaryCache( + self._state_group_cache: DictionaryCache[int, StateKey, str] = DictionaryCache( "*stateGroupCache*", # TODO: this hasn't been tuned yet 50000, ) - self._state_group_members_cache = DictionaryCache( + self._state_group_members_cache: DictionaryCache[ + int, StateKey, str + ] = DictionaryCache( "*stateGroupMembersCache*", 500000, ) - def get_max_state_group_txn(txn: Cursor): + def get_max_state_group_txn(txn: Cursor) -> int: txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups") - return txn.fetchone()[0] + return txn.fetchone()[0] # type: ignore self._state_group_seq_gen = build_sequence_generator( db_conn, @@ -105,15 +120,15 @@ def get_max_state_group_txn(txn: Cursor): ) @cached(max_entries=10000, iterable=True) - async def get_state_group_delta(self, state_group): + async def get_state_group_delta(self, state_group: int) -> _GetStateGroupDelta: """Given a state group try to return a previous group and a delta between the old and the new. Returns: - (prev_group, delta_ids), where both may be None. + _GetStateGroupDelta containing prev_group and delta_ids, where both may be None. """ - def _get_state_group_delta_txn(txn): + def _get_state_group_delta_txn(txn: LoggingTransaction) -> _GetStateGroupDelta: prev_group = self.db_pool.simple_select_one_onecol_txn( txn, table="state_group_edges", @@ -154,7 +169,7 @@ async def _get_state_groups_from_groups( Returns: Dict of state group to state map. """ - results = {} + results: Dict[int, StateMap[str]] = {} chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)] for chunk in chunks: @@ -168,19 +183,24 @@ async def _get_state_groups_from_groups( return results - def _get_state_for_group_using_cache(self, cache, group, state_filter): + def _get_state_for_group_using_cache( + self, + cache: DictionaryCache[int, StateKey, str], + group: int, + state_filter: StateFilter, + ) -> Tuple[MutableStateMap[str], bool]: """Checks if group is in cache. See `_get_state_for_groups` Args: - cache(DictionaryCache): the state group cache to use - group(int): The state group to lookup - state_filter (StateFilter): The state filter used to fetch state - from the database. + cache: the state group cache to use + group: The state group to lookup + state_filter: The state filter used to fetch state from the database. - Returns 2-tuple (`state_dict`, `got_all`). - `got_all` is a bool indicating if we successfully retrieved all - requests state from the cache, if False we need to query the DB for the - missing state. + Returns: + 2-tuple (`state_dict`, `got_all`). + `got_all` is a bool indicating if we successfully retrieved all + requests state from the cache, if False we need to query the DB for the + missing state. """ cache_entry = cache.get(group) state_dict_ids = cache_entry.value @@ -277,8 +297,11 @@ async def _get_state_for_groups( return state def _get_state_for_groups_using_cache( - self, groups: Iterable[int], cache: DictionaryCache, state_filter: StateFilter - ) -> Tuple[Dict[int, StateMap[str]], Set[int]]: + self, + groups: Iterable[int], + cache: DictionaryCache[int, StateKey, str], + state_filter: StateFilter, + ) -> Tuple[Dict[int, MutableStateMap[str]], Set[int]]: """Gets the state at each of a list of state groups, optionally filtering by type/state_key, querying from a specific cache. @@ -310,21 +333,21 @@ def _get_state_for_groups_using_cache( def _insert_into_cache( self, - group_to_state_dict, - state_filter, - cache_seq_num_members, - cache_seq_num_non_members, - ): + group_to_state_dict: Dict[int, StateMap[str]], + state_filter: StateFilter, + cache_seq_num_members: int, + cache_seq_num_non_members: int, + ) -> None: """Inserts results from querying the database into the relevant cache. Args: - group_to_state_dict (dict): The new entries pulled from database. + group_to_state_dict: The new entries pulled from database. Map from state group to state dict - state_filter (StateFilter): The state filter used to fetch state + state_filter: The state filter used to fetch state from the database. - cache_seq_num_members (int): Sequence number of member cache since + cache_seq_num_members: Sequence number of member cache since last lookup in cache - cache_seq_num_non_members (int): Sequence number of member cache since + cache_seq_num_non_members: Sequence number of member cache since last lookup in cache """ @@ -395,7 +418,7 @@ async def store_state_group( The state group ID """ - def _store_state_group_txn(txn): + def _store_state_group_txn(txn: LoggingTransaction) -> int: if current_state_ids is None: # AFAIK, this can never happen raise Exception("current_state_ids cannot be None") @@ -426,6 +449,8 @@ def _store_state_group_txn(txn): potential_hops = self._count_state_group_hops_txn(txn, prev_group) if prev_group and potential_hops < MAX_STATE_DELTA_HOPS: + assert delta_ids is not None + self.db_pool.simple_insert_txn( txn, table="state_group_edges", @@ -498,7 +523,7 @@ def _store_state_group_txn(txn): ) async def purge_unreferenced_state_groups( - self, room_id: str, state_groups_to_delete + self, room_id: str, state_groups_to_delete: Collection[int] ) -> None: """Deletes no longer referenced state groups and de-deltas any state groups that reference them. @@ -506,8 +531,7 @@ async def purge_unreferenced_state_groups( Args: room_id: The room the state groups belong to (must all be in the same room). - state_groups_to_delete (Collection[int]): Set of all state groups - to delete. + state_groups_to_delete: Set of all state groups to delete. """ await self.db_pool.runInteraction( @@ -517,7 +541,12 @@ async def purge_unreferenced_state_groups( state_groups_to_delete, ) - def _purge_unreferenced_state_groups(self, txn, room_id, state_groups_to_delete): + def _purge_unreferenced_state_groups( + self, + txn: LoggingTransaction, + room_id: str, + state_groups_to_delete: Collection[int], + ) -> None: logger.info( "[purge] found %i state groups to delete", len(state_groups_to_delete) ) @@ -546,8 +575,8 @@ def _purge_unreferenced_state_groups(self, txn, room_id, state_groups_to_delete) # groups to non delta versions. for sg in remaining_state_groups: logger.info("[purge] de-delta-ing remaining state group %s", sg) - curr_state = self._get_state_groups_from_groups_txn(txn, [sg]) - curr_state = curr_state[sg] + curr_state_by_group = self._get_state_groups_from_groups_txn(txn, [sg]) + curr_state = curr_state_by_group[sg] self.db_pool.simple_delete_txn( txn, table="state_groups_state", keyvalues={"state_group": sg} @@ -605,12 +634,14 @@ async def get_previous_state_groups( return {row["state_group"]: row["prev_state_group"] for row in rows} - async def purge_room_state(self, room_id, state_groups_to_delete): + async def purge_room_state( + self, room_id: str, state_groups_to_delete: Collection[int] + ) -> None: """Deletes all record of a room from state tables Args: - room_id (str): - state_groups_to_delete (list[int]): State groups to delete + room_id: + state_groups_to_delete: State groups to delete """ await self.db_pool.runInteraction( @@ -620,7 +651,12 @@ async def purge_room_state(self, room_id, state_groups_to_delete): state_groups_to_delete, ) - def _purge_room_state_txn(self, txn, room_id, state_groups_to_delete): + def _purge_room_state_txn( + self, + txn: LoggingTransaction, + room_id: str, + state_groups_to_delete: Collection[int], + ) -> None: # first we have to delete the state groups states logger.info("[purge] removing %s from state_groups_state", room_id) diff --git a/synapse/storage/state.py b/synapse/storage/state.py index c76529cb5..5e86befde 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -377,7 +377,8 @@ async def get_state_group_delta( make up the delta between the old and new state groups. """ - return await self.stores.state.get_state_group_delta(state_group) + state_group_delta = await self.stores.state.get_state_group_delta(state_group) + return state_group_delta.prev_group, state_group_delta.delta_ids async def get_state_groups_ids( self, _room_id: str, event_ids: Iterable[str] diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py index ade088aae..485ddb189 100644 --- a/synapse/util/caches/dictionary_cache.py +++ b/synapse/util/caches/dictionary_cache.py @@ -130,7 +130,7 @@ def update( sequence: int, key: KT, value: Dict[DKT, DV], - fetched_keys: Optional[Set[DKT]] = None, + fetched_keys: Optional[Iterable[DKT]] = None, ) -> None: """Updates the entry in the cache @@ -155,7 +155,7 @@ def update( self._update_or_insert(key, value, fetched_keys) def _update_or_insert( - self, key: KT, value: Dict[DKT, DV], known_absent: Set[DKT] + self, key: KT, value: Dict[DKT, DV], known_absent: Iterable[DKT] ) -> None: # We pop and reinsert as we need to tell the cache the size may have # changed From bfb4b858a999684ba2459ee4c3aa20270d13062d Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 16 Sep 2021 12:01:14 -0400 Subject: [PATCH 10/74] Create a constant for a small png image in tests. (#10834) To avoid duplicating it between a few tests. --- changelog.d/10834.misc | 1 + tests/replication/test_multi_media_repo.py | 18 ++++-------- tests/rest/admin/test_admin.py | 23 ++++++--------- tests/rest/admin/test_media.py | 34 ++++------------------ tests/rest/admin/test_statistics.py | 12 ++------ tests/rest/admin/test_user.py | 19 +++--------- tests/rest/media/v1/test_media_storage.py | 18 +++--------- tests/test_utils/__init__.py | 14 +++++++-- 8 files changed, 45 insertions(+), 94 deletions(-) create mode 100644 changelog.d/10834.misc diff --git a/changelog.d/10834.misc b/changelog.d/10834.misc new file mode 100644 index 000000000..037695e6e --- /dev/null +++ b/changelog.d/10834.misc @@ -0,0 +1 @@ +Factor out PNG image data to a constant to be used in several tests. diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py index ac419f0db..01b1b0d4a 100644 --- a/tests/replication/test_multi_media_repo.py +++ b/tests/replication/test_multi_media_repo.py @@ -1,4 +1,4 @@ -# Copyright 2020 The Matrix.org Foundation C.I.C. +# Copyright 2020-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,7 +13,6 @@ # limitations under the License. import logging import os -from binascii import unhexlify from typing import Optional, Tuple from twisted.internet.protocol import Factory @@ -28,6 +27,7 @@ from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file from tests.replication._base import BaseMultiWorkerStreamTestCase from tests.server import FakeChannel, FakeSite, FakeTransport, make_request +from tests.test_utils import SMALL_PNG logger = logging.getLogger(__name__) @@ -190,31 +190,25 @@ def test_download_image_race(self): channel1, request1 = self._get_media_req(hs1, "example.com:443", "PIC1") channel2, request2 = self._get_media_req(hs2, "example.com:443", "PIC1") - png_data = unhexlify( - b"89504e470d0a1a0a0000000d4948445200000001000000010806" - b"0000001f15c4890000000a49444154789c63000100000500010d" - b"0a2db40000000049454e44ae426082" - ) - request1.setResponseCode(200) request1.responseHeaders.setRawHeaders(b"Content-Type", [b"image/png"]) - request1.write(png_data) + request1.write(SMALL_PNG) request1.finish() self.pump(0.1) self.assertEqual(channel1.code, 200, channel1.result["body"]) - self.assertEqual(channel1.result["body"], png_data) + self.assertEqual(channel1.result["body"], SMALL_PNG) request2.setResponseCode(200) request2.responseHeaders.setRawHeaders(b"Content-Type", [b"image/png"]) - request2.write(png_data) + request2.write(SMALL_PNG) request2.finish() self.pump(0.1) self.assertEqual(channel2.code, 200, channel2.result["body"]) - self.assertEqual(channel2.result["body"], png_data) + self.assertEqual(channel2.result["body"], SMALL_PNG) # We expect only three new thumbnails to have been persisted. self.assertEqual(start_count + 3, self._count_remote_thumbnails()) diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index bfa638fb4..febd40b65 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -1,4 +1,4 @@ -# Copyright 2018 New Vector Ltd +# Copyright 2018-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,7 +15,6 @@ import json import os import urllib.parse -from binascii import unhexlify from unittest.mock import Mock from twisted.internet.defer import Deferred @@ -28,6 +27,7 @@ from tests import unittest from tests.server import FakeSite, make_request +from tests.test_utils import SMALL_PNG class VersionTestCase(unittest.HomeserverTestCase): @@ -150,11 +150,6 @@ def prepare(self, reactor, clock, hs): self.media_repo = hs.get_media_repository_resource() self.download_resource = self.media_repo.children[b"download"] self.upload_resource = self.media_repo.children[b"upload"] - self.image_data = unhexlify( - b"89504e470d0a1a0a0000000d4948445200000001000000010806" - b"0000001f15c4890000000a49444154789c63000100000500010d" - b"0a2db40000000049454e44ae426082" - ) def make_homeserver(self, reactor, clock): @@ -266,7 +261,7 @@ def test_quarantine_media_by_id(self): # Upload some media into the room response = self.helper.upload_media( - self.upload_resource, self.image_data, tok=admin_user_tok + self.upload_resource, SMALL_PNG, tok=admin_user_tok ) # Extract media ID from the response @@ -314,10 +309,10 @@ def test_quarantine_all_media_in_room(self, override_url_template=None): # Upload some media response_1 = self.helper.upload_media( - self.upload_resource, self.image_data, tok=non_admin_user_tok + self.upload_resource, SMALL_PNG, tok=non_admin_user_tok ) response_2 = self.helper.upload_media( - self.upload_resource, self.image_data, tok=non_admin_user_tok + self.upload_resource, SMALL_PNG, tok=non_admin_user_tok ) # Extract mxcs @@ -381,10 +376,10 @@ def test_quarantine_all_media_by_user(self): # Upload some media response_1 = self.helper.upload_media( - self.upload_resource, self.image_data, tok=non_admin_user_tok + self.upload_resource, SMALL_PNG, tok=non_admin_user_tok ) response_2 = self.helper.upload_media( - self.upload_resource, self.image_data, tok=non_admin_user_tok + self.upload_resource, SMALL_PNG, tok=non_admin_user_tok ) # Extract media IDs @@ -421,10 +416,10 @@ def test_cannot_quarantine_safe_media(self): # Upload some media response_1 = self.helper.upload_media( - self.upload_resource, self.image_data, tok=non_admin_user_tok + self.upload_resource, SMALL_PNG, tok=non_admin_user_tok ) response_2 = self.helper.upload_media( - self.upload_resource, self.image_data, tok=non_admin_user_tok + self.upload_resource, SMALL_PNG, tok=non_admin_user_tok ) # Extract media IDs diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py index 972d60570..2f02934e7 100644 --- a/tests/rest/admin/test_media.py +++ b/tests/rest/admin/test_media.py @@ -1,4 +1,5 @@ # Copyright 2020 Dirk Klimpel +# Copyright 2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,7 +15,6 @@ import json import os -from binascii import unhexlify from parameterized import parameterized @@ -25,6 +25,7 @@ from tests import unittest from tests.server import FakeSite, make_request +from tests.test_utils import SMALL_PNG class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): @@ -110,15 +111,10 @@ def test_delete_media(self): download_resource = self.media_repo.children[b"download"] upload_resource = self.media_repo.children[b"upload"] - image_data = unhexlify( - b"89504e470d0a1a0a0000000d4948445200000001000000010806" - b"0000001f15c4890000000a49444154789c63000100000500010d" - b"0a2db40000000049454e44ae426082" - ) # Upload some media into the room response = self.helper.upload_media( - upload_resource, image_data, tok=self.admin_user_tok, expect_code=200 + upload_resource, SMALL_PNG, tok=self.admin_user_tok, expect_code=200 ) # Extract media ID from the response server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://' @@ -504,16 +500,10 @@ def _create_media(self): Create a media and return media_id and server_and_media_id """ upload_resource = self.media_repo.children[b"upload"] - # file size is 67 Byte - image_data = unhexlify( - b"89504e470d0a1a0a0000000d4948445200000001000000010806" - b"0000001f15c4890000000a49444154789c63000100000500010d" - b"0a2db40000000049454e44ae426082" - ) # Upload some media into the room response = self.helper.upload_media( - upload_resource, image_data, tok=self.admin_user_tok, expect_code=200 + upload_resource, SMALL_PNG, tok=self.admin_user_tok, expect_code=200 ) # Extract media ID from the response server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://' @@ -584,16 +574,10 @@ def prepare(self, reactor, clock, hs): # Create media upload_resource = media_repo.children[b"upload"] - # file size is 67 Byte - image_data = unhexlify( - b"89504e470d0a1a0a0000000d4948445200000001000000010806" - b"0000001f15c4890000000a49444154789c63000100000500010d" - b"0a2db40000000049454e44ae426082" - ) # Upload some media into the room response = self.helper.upload_media( - upload_resource, image_data, tok=self.admin_user_tok, expect_code=200 + upload_resource, SMALL_PNG, tok=self.admin_user_tok, expect_code=200 ) # Extract media ID from the response server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://' @@ -711,16 +695,10 @@ def prepare(self, reactor, clock, hs): # Create media upload_resource = media_repo.children[b"upload"] - # file size is 67 Byte - image_data = unhexlify( - b"89504e470d0a1a0a0000000d4948445200000001000000010806" - b"0000001f15c4890000000a49444154789c63000100000500010d" - b"0a2db40000000049454e44ae426082" - ) # Upload some media into the room response = self.helper.upload_media( - upload_resource, image_data, tok=self.admin_user_tok, expect_code=200 + upload_resource, SMALL_PNG, tok=self.admin_user_tok, expect_code=200 ) # Extract media ID from the response server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://' diff --git a/tests/rest/admin/test_statistics.py b/tests/rest/admin/test_statistics.py index 5cd82209c..ece89a65a 100644 --- a/tests/rest/admin/test_statistics.py +++ b/tests/rest/admin/test_statistics.py @@ -1,4 +1,5 @@ # Copyright 2020 Dirk Klimpel +# Copyright 2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,7 +14,6 @@ # limitations under the License. import json -from binascii import unhexlify from typing import Any, Dict, List, Optional import synapse.rest.admin @@ -21,6 +21,7 @@ from synapse.rest.client import login from tests import unittest +from tests.test_utils import SMALL_PNG class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): @@ -468,16 +469,9 @@ def _create_media(self, user_token: str, number_media: int): """ upload_resource = self.media_repo.children[b"upload"] for _ in range(number_media): - # file size is 67 Byte - image_data = unhexlify( - b"89504e470d0a1a0a0000000d4948445200000001000000010806" - b"0000001f15c4890000000a49444154789c63000100000500010d" - b"0a2db40000000049454e44ae426082" - ) - # Upload some media into the room self.helper.upload_media( - upload_resource, image_data, tok=user_token, expect_code=200 + upload_resource, SMALL_PNG, tok=user_token, expect_code=200 ) def _check_fields(self, content: List[Dict[str, Any]]): diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index ee204c404..cc3f16c62 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -1,4 +1,4 @@ -# Copyright 2018 New Vector Ltd +# Copyright 2018-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -33,7 +33,7 @@ from tests import unittest from tests.server import FakeSite, make_request -from tests.test_utils import make_awaitable +from tests.test_utils import SMALL_PNG, make_awaitable from tests.unittest import override_config @@ -2835,11 +2835,7 @@ def test_order_by(self): other_user_tok = self.login("user", "pass") # Resolution: 1×1, MIME type: image/png, Extension: png, Size: 67 B - image_data1 = unhexlify( - b"89504e470d0a1a0a0000000d4948445200000001000000010806" - b"0000001f15c4890000000a49444154789c63000100000500010d" - b"0a2db40000000049454e44ae426082" - ) + image_data1 = SMALL_PNG # Resolution: 1×1, MIME type: image/gif, Extension: gif, Size: 35 B image_data2 = unhexlify( b"47494638376101000100800100000000" @@ -2943,14 +2939,7 @@ def _create_media_for_user(self, user_token: str, number_media: int) -> List[str """ media_ids = [] for _ in range(number_media): - # file size is 67 Byte - image_data = unhexlify( - b"89504e470d0a1a0a0000000d4948445200000001000000010806" - b"0000001f15c4890000000a49444154789c63000100000500010d" - b"0a2db40000000049454e44ae426082" - ) - - media_ids.append(self._create_media_and_access(user_token, image_data)) + media_ids.append(self._create_media_and_access(user_token, SMALL_PNG)) return media_ids diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index 2f7eebfe6..9ea1c2bf2 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -1,4 +1,4 @@ -# Copyright 2018 New Vector Ltd +# Copyright 2018-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -38,6 +38,7 @@ from tests import unittest from tests.server import FakeSite, make_request +from tests.test_utils import SMALL_PNG from tests.utils import default_config @@ -134,11 +135,7 @@ class _TestImage: # smoll png ( _TestImage( - unhexlify( - b"89504e470d0a1a0a0000000d4948445200000001000000010806" - b"0000001f15c4890000000a49444154789c63000100000500010d" - b"0a2db40000000049454e44ae426082" - ), + SMALL_PNG, b"image/png", b".png", unhexlify( @@ -593,15 +590,8 @@ def default_config(self): def test_upload_innocent(self): """Attempt to upload some innocent data that should be allowed.""" - - image_data = unhexlify( - b"89504e470d0a1a0a0000000d4948445200000001000000010806" - b"0000001f15c4890000000a49444154789c63000100000500010d" - b"0a2db40000000049454e44ae426082" - ) - self.helper.upload_media( - self.upload_resource, image_data, tok=self.tok, expect_code=200 + self.upload_resource, SMALL_PNG, tok=self.tok, expect_code=200 ) def test_upload_ban(self): diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py index be6302d17..15ac2bfeb 100644 --- a/tests/test_utils/__init__.py +++ b/tests/test_utils/__init__.py @@ -1,5 +1,4 @@ -# Copyright 2019 New Vector Ltd -# Copyright 2020 The Matrix.org Foundation C.I.C +# Copyright 2019-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,6 +18,7 @@ import sys import warnings from asyncio import Future +from binascii import unhexlify from typing import Any, Awaitable, Callable, TypeVar from unittest.mock import Mock @@ -117,3 +117,13 @@ class FakeResponse: def deliverBody(self, protocol): protocol.dataReceived(self.body) protocol.connectionLost(Failure(ResponseDone())) + + +# A small image used in some tests. +# +# Resolution: 1×1, MIME type: image/png, Extension: png, Size: 67 B +SMALL_PNG = unhexlify( + b"89504e470d0a1a0a0000000d4948445200000001000000010806" + b"0000001f15c4890000000a49444154789c63000100000500010d" + b"0a2db40000000049454e44ae426082" +) From 6b6bb81b23425cf4f8e0c739946783b98ad056b8 Mon Sep 17 00:00:00 2001 From: Charles Wright Date: Fri, 17 Sep 2021 12:04:37 -0500 Subject: [PATCH 11/74] =?UTF-8?q?Fix=20#10837=20by=20adding=20JSON=20encod?= =?UTF-8?q?ing/decoding=20to=20the=20Module=20API=20example=E2=80=A6=20(#1?= =?UTF-8?q?0845)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- changelog.d/10845.doc | 1 + docs/modules/spam_checker_callbacks.md | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) create mode 100644 changelog.d/10845.doc diff --git a/changelog.d/10845.doc b/changelog.d/10845.doc new file mode 100644 index 000000000..a13c845ae --- /dev/null +++ b/changelog.d/10845.doc @@ -0,0 +1 @@ +Fix some crashes in the Module API example code, by adding JSON encoding/decoding. diff --git a/docs/modules/spam_checker_callbacks.md b/docs/modules/spam_checker_callbacks.md index c45eafcc4..81574a015 100644 --- a/docs/modules/spam_checker_callbacks.md +++ b/docs/modules/spam_checker_callbacks.md @@ -136,9 +136,9 @@ class IsUserEvilResource(Resource): self.evil_users = config.get("evil_users") or [] def render_GET(self, request: Request): - user = request.args.get(b"user")[0] + user = request.args.get(b"user")[0].decode() request.setHeader(b"Content-Type", b"application/json") - return json.dumps({"evil": user in self.evil_users}) + return json.dumps({"evil": user in self.evil_users}).encode() class ListSpamChecker: From 437961744c6c8761e6483bb215e5e779123ffd97 Mon Sep 17 00:00:00 2001 From: reivilibre <38398653+reivilibre@users.noreply.github.com> Date: Mon, 20 Sep 2021 10:26:13 +0100 Subject: [PATCH 12/74] Fix remove_stale_pushers job on SQLite. (#10843) --- changelog.d/10843.bugfix | 1 + synapse/storage/database.py | 21 +++++++++++-------- .../storage/databases/main/account_data.py | 2 +- synapse/storage/databases/main/events.py | 2 +- .../databases/main/events_bg_updates.py | 4 ++-- synapse/storage/databases/main/pusher.py | 4 ++-- synapse/storage/databases/main/state.py | 4 ++-- synapse/storage/databases/main/ui_auth.py | 6 +++--- synapse/storage/databases/state/store.py | 6 +++--- 9 files changed, 27 insertions(+), 23 deletions(-) create mode 100644 changelog.d/10843.bugfix diff --git a/changelog.d/10843.bugfix b/changelog.d/10843.bugfix new file mode 100644 index 000000000..5027a1dbe --- /dev/null +++ b/changelog.d/10843.bugfix @@ -0,0 +1 @@ +Fix a bug causing the `remove_stale_pushers` background job to repeatedly fail and log errors. This bug affected Synapse servers that had been upgraded from version 1.28 or older and are using SQLite. diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 0084d9f96..f5a8f90a0 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -1632,7 +1632,7 @@ def simple_select_many_txn( txn: LoggingTransaction, table: str, column: str, - iterable: Iterable[Any], + iterable: Collection[Any], keyvalues: Dict[str, Any], retcols: Iterable[str], ) -> List[Dict[str, Any]]: @@ -1891,29 +1891,32 @@ def simple_delete_many_txn( txn: LoggingTransaction, table: str, column: str, - iterable: Iterable[Any], + values: Collection[Any], keyvalues: Dict[str, Any], ) -> int: """Executes a DELETE query on the named table. - Filters rows by if value of `column` is in `iterable`. + Deletes the rows: + - whose value of `column` is in `values`; AND + - that match extra column-value pairs specified in `keyvalues`. Args: txn: Transaction object table: string giving the table name - column: column name to test for inclusion against `iterable` - iterable: list - keyvalues: dict of column names and values to select the rows with + column: column name to test for inclusion against `values` + values: values of `column` which choose rows to delete + keyvalues: dict of extra column names and values to select the rows + with. They will be ANDed together with the main predicate. Returns: Number rows deleted """ - if not iterable: + if not values: return 0 sql = "DELETE FROM %s" % table - clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable) + clause, values = make_in_list_sql_clause(txn.database_engine, column, values) clauses = [clause] for key, value in keyvalues.items(): @@ -2098,7 +2101,7 @@ def simple_search_list_txn( def make_in_list_sql_clause( - database_engine: BaseDatabaseEngine, column: str, iterable: Iterable + database_engine: BaseDatabaseEngine, column: str, iterable: Collection[Any] ) -> Tuple[str, list]: """Returns an SQL clause that checks the given column is in the iterable. diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index 1d02795f4..d0cf3460d 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -494,7 +494,7 @@ def _add_account_data_for_user( txn, table="ignored_users", column="ignored_user_id", - iterable=previously_ignored_users - currently_ignored_users, + values=previously_ignored_users - currently_ignored_users, keyvalues={"ignorer_user_id": user_id}, ) diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 8e691678e..dec7e8594 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -667,7 +667,7 @@ def _add_chain_cover_index( table="event_auth_chain_to_calculate", keyvalues={}, column="event_id", - iterable=new_chain_tuples, + values=new_chain_tuples, ) # Now we need to calculate any new links between chains caused by diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py index 6fcb2b835..1afc59faf 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py @@ -490,7 +490,7 @@ def _cleanup_extremities_bg_update_txn(txn): txn=txn, table="event_forward_extremities", column="event_id", - iterable=to_delete, + values=to_delete, keyvalues={}, ) @@ -520,7 +520,7 @@ def _cleanup_extremities_bg_update_txn(txn): txn=txn, table="_extremities_to_check", column="event_id", - iterable=original_set, + values=original_set, keyvalues={}, ) diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index 63ac09c61..a93caae8d 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -324,7 +324,7 @@ def _delete_pushers(txn) -> int: txn, table="pushers", column="user_name", - iterable=users, + values=users, keyvalues={}, ) @@ -373,7 +373,7 @@ def _delete_pushers(txn) -> int: txn, table="pushers", column="id", - iterable=(pusher_id for pusher_id, token in pushers if token is None), + values=[pusher_id for pusher_id, token in pushers if token is None], keyvalues={}, ) diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index 8e22da99a..a8e8dd457 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -473,7 +473,7 @@ def _background_remove_left_rooms_txn(txn): txn, table="current_state_events", column="room_id", - iterable=to_delete, + values=to_delete, keyvalues={}, ) @@ -481,7 +481,7 @@ def _background_remove_left_rooms_txn(txn): txn, table="event_forward_extremities", column="room_id", - iterable=to_delete, + values=to_delete, keyvalues={}, ) diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py index 4d6bbc94c..340ca9e47 100644 --- a/synapse/storage/databases/main/ui_auth.py +++ b/synapse/storage/databases/main/ui_auth.py @@ -326,7 +326,7 @@ def _delete_old_ui_auth_sessions_txn( txn, table="ui_auth_sessions_ips", column="session_id", - iterable=session_ids, + values=session_ids, keyvalues={}, ) @@ -377,7 +377,7 @@ def _delete_old_ui_auth_sessions_txn( txn, table="ui_auth_sessions_credentials", column="session_id", - iterable=session_ids, + values=session_ids, keyvalues={}, ) @@ -386,7 +386,7 @@ def _delete_old_ui_auth_sessions_txn( txn, table="ui_auth_sessions", column="session_id", - iterable=session_ids, + values=session_ids, keyvalues={}, ) diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index f1e3a27e6..c4c8c0021 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -664,7 +664,7 @@ def _purge_room_state_txn( txn, table="state_groups_state", column="state_group", - iterable=state_groups_to_delete, + values=state_groups_to_delete, keyvalues={}, ) @@ -675,7 +675,7 @@ def _purge_room_state_txn( txn, table="state_group_edges", column="state_group", - iterable=state_groups_to_delete, + values=state_groups_to_delete, keyvalues={}, ) @@ -686,6 +686,6 @@ def _purge_room_state_txn( txn, table="state_groups", column="id", - iterable=state_groups_to_delete, + values=state_groups_to_delete, keyvalues={}, ) From b3590614da7e3e17e75530a9d4808df17be9b127 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 20 Sep 2021 08:56:23 -0400 Subject: [PATCH 13/74] Require type hints in the handlers module. (#10831) Adds missing type hints to methods in the synapse.handlers module and requires all methods to have type hints there. This also removes the unused construct_auth_difference method from the FederationHandler. --- changelog.d/10831.misc | 1 + mypy.ini | 3 + synapse/config/password_auth_providers.py | 4 +- synapse/handlers/_base.py | 14 ++- synapse/handlers/account_data.py | 4 +- synapse/handlers/account_validity.py | 4 +- synapse/handlers/appservice.py | 18 +-- synapse/handlers/auth.py | 45 ++++---- synapse/handlers/cas.py | 18 ++- synapse/handlers/device.py | 2 +- synapse/handlers/e2e_keys.py | 4 +- synapse/handlers/event_auth.py | 7 +- synapse/handlers/federation.py | 130 ---------------------- synapse/handlers/federation_event.py | 8 +- synapse/handlers/groups_local.py | 8 +- synapse/handlers/initial_sync.py | 8 +- synapse/handlers/message.py | 20 ++-- synapse/handlers/oidc.py | 34 +++--- synapse/handlers/pagination.py | 19 ++-- synapse/handlers/presence.py | 45 +++++--- synapse/handlers/profile.py | 4 +- synapse/handlers/receipts.py | 4 +- synapse/handlers/register.py | 2 +- synapse/handlers/room.py | 16 ++- synapse/handlers/room_list.py | 12 +- synapse/handlers/room_member.py | 8 +- synapse/handlers/room_summary.py | 2 +- synapse/handlers/saml.py | 14 +-- synapse/handlers/send_email.py | 4 +- synapse/handlers/sso.py | 6 +- synapse/handlers/stats.py | 2 +- synapse/handlers/sync.py | 11 +- synapse/handlers/typing.py | 4 +- synapse/handlers/ui_auth/checkers.py | 2 +- synapse/handlers/user_directory.py | 2 +- 35 files changed, 194 insertions(+), 295 deletions(-) create mode 100644 changelog.d/10831.misc diff --git a/changelog.d/10831.misc b/changelog.d/10831.misc new file mode 100644 index 000000000..f09af2e00 --- /dev/null +++ b/changelog.d/10831.misc @@ -0,0 +1 @@ +Add missing type hints to handlers. diff --git a/mypy.ini b/mypy.ini index b21e1555a..3cb6cecd7 100644 --- a/mypy.ini +++ b/mypy.ini @@ -91,6 +91,9 @@ files = tests/util/test_itertools.py, tests/util/test_stream_change_cache.py +[mypy-synapse.handlers.*] +disallow_untyped_defs = True + [mypy-synapse.rest.*] disallow_untyped_defs = True diff --git a/synapse/config/password_auth_providers.py b/synapse/config/password_auth_providers.py index 0f5b2b397..83994df79 100644 --- a/synapse/config/password_auth_providers.py +++ b/synapse/config/password_auth_providers.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List +from typing import Any, List, Tuple, Type from synapse.util.module_loader import load_module @@ -25,7 +25,7 @@ class PasswordAuthProviderConfig(Config): section = "authproviders" def read_config(self, config, **kwargs): - self.password_providers: List[Any] = [] + self.password_providers: List[Tuple[Type, Any]] = [] providers = [] # We want to be backwards compatible with the old `ldap_config` diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index c23ccd6dd..0ccef884e 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -16,6 +16,7 @@ from typing import TYPE_CHECKING, Optional from synapse.api.ratelimiting import Ratelimiter +from synapse.types import Requester if TYPE_CHECKING: from synapse.server import HomeServer @@ -63,16 +64,21 @@ def __init__(self, hs: "HomeServer"): self.event_builder_factory = hs.get_event_builder_factory() - async def ratelimit(self, requester, update=True, is_admin_redaction=False): + async def ratelimit( + self, + requester: Requester, + update: bool = True, + is_admin_redaction: bool = False, + ) -> None: """Ratelimits requests. Args: - requester (Requester) - update (bool): Whether to record that a request is being processed. + requester + update: Whether to record that a request is being processed. Set to False when doing multiple checks for one request (e.g. to check up front if we would reject the request), and set to True for the last call for a given request. - is_admin_redaction (bool): Whether this is a room admin/moderator + is_admin_redaction: Whether this is a room admin/moderator redacting an event. If so then we may apply different ratelimits depending on config. diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py index affb54e0e..e9e7a7854 100644 --- a/synapse/handlers/account_data.py +++ b/synapse/handlers/account_data.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import random -from typing import TYPE_CHECKING, List, Tuple +from typing import TYPE_CHECKING, Any, List, Tuple from synapse.replication.http.account_data import ( ReplicationAddTagRestServlet, @@ -171,7 +171,7 @@ def get_current_key(self, direction: str = "f") -> int: return self.store.get_max_account_data_stream_id() async def get_new_events( - self, user: UserID, from_key: int, **kwargs + self, user: UserID, from_key: int, **kwargs: Any ) -> Tuple[List[JsonDict], int]: user_id = user.to_string() last_stream_id = from_key diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py index a9c2222f4..4724565ba 100644 --- a/synapse/handlers/account_validity.py +++ b/synapse/handlers/account_validity.py @@ -99,7 +99,7 @@ def register_account_validity_callbacks( on_legacy_send_mail: Optional[ON_LEGACY_SEND_MAIL_CALLBACK] = None, on_legacy_renew: Optional[ON_LEGACY_RENEW_CALLBACK] = None, on_legacy_admin_request: Optional[ON_LEGACY_ADMIN_REQUEST] = None, - ): + ) -> None: """Register callbacks from module for each hook.""" if is_user_expired is not None: self._is_user_expired_callbacks.append(is_user_expired) @@ -165,7 +165,7 @@ async def is_user_expired(self, user_id: str) -> bool: return False - async def on_user_registration(self, user_id: str): + async def on_user_registration(self, user_id: str) -> None: """Tell third-party modules about a user's registration. Args: diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index a7b5a4e9c..8bde9ed66 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Union from prometheus_client import Counter @@ -58,7 +58,7 @@ def __init__(self, hs: "HomeServer"): self.current_max = 0 self.is_processing = False - def notify_interested_services(self, max_token: RoomStreamToken): + def notify_interested_services(self, max_token: RoomStreamToken) -> None: """Notifies (pushes) all application services interested in this event. Pushing is done asynchronously, so this method won't block for any @@ -82,7 +82,7 @@ def notify_interested_services(self, max_token: RoomStreamToken): self._notify_interested_services(max_token) @wrap_as_background_process("notify_interested_services") - async def _notify_interested_services(self, max_token: RoomStreamToken): + async def _notify_interested_services(self, max_token: RoomStreamToken) -> None: with Measure(self.clock, "notify_interested_services"): self.is_processing = True try: @@ -100,7 +100,7 @@ async def _notify_interested_services(self, max_token: RoomStreamToken): for event in events: events_by_room.setdefault(event.room_id, []).append(event) - async def handle_event(event): + async def handle_event(event: EventBase) -> None: # Gather interested services services = await self._get_services_for_event(event) if len(services) == 0: @@ -116,9 +116,9 @@ async def handle_event(event): if not self.started_scheduler: - async def start_scheduler(): + async def start_scheduler() -> None: try: - return await self.scheduler.start() + await self.scheduler.start() except Exception: logger.error("Application Services Failure") @@ -137,7 +137,7 @@ async def start_scheduler(): "appservice_sender" ).observe((now - ts) / 1000) - async def handle_room_events(events): + async def handle_room_events(events: Iterable[EventBase]) -> None: for event in events: await handle_event(event) @@ -184,7 +184,7 @@ def notify_interested_services_ephemeral( stream_key: str, new_token: Optional[int], users: Optional[Collection[Union[str, UserID]]] = None, - ): + ) -> None: """This is called by the notifier in the background when a ephemeral event handled by the homeserver. @@ -226,7 +226,7 @@ async def _notify_interested_services_ephemeral( stream_key: str, new_token: Optional[int], users: Collection[Union[str, UserID]], - ): + ) -> None: logger.debug("Checking interested services for %s" % (stream_key)) with Measure(self.clock, "notify_interested_services_ephemeral"): for service in services: diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 3ea627008..bcd4249e0 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -29,6 +29,7 @@ Mapping, Optional, Tuple, + Type, Union, cast, ) @@ -439,7 +440,7 @@ async def _get_available_ui_auth_types(self, user: UserID) -> Iterable[str]: return ui_auth_types - def get_enabled_auth_types(self): + def get_enabled_auth_types(self) -> Iterable[str]: """Return the enabled user-interactive authentication types Returns the UI-Auth types which are supported by the homeserver's current @@ -702,7 +703,7 @@ async def get_session_data( except StoreError: raise SynapseError(400, "Unknown session ID: %s" % (session_id,)) - async def _expire_old_sessions(self): + async def _expire_old_sessions(self) -> None: """ Invalidate any user interactive authentication sessions that have expired. """ @@ -1352,7 +1353,7 @@ async def validate_short_term_login_token( await self.auth.check_auth_blocking(res.user_id) return res - async def delete_access_token(self, access_token: str): + async def delete_access_token(self, access_token: str) -> None: """Invalidate a single access token Args: @@ -1381,7 +1382,7 @@ async def delete_access_tokens_for_user( user_id: str, except_token_id: Optional[int] = None, device_id: Optional[str] = None, - ): + ) -> None: """Invalidate access tokens belonging to a user Args: @@ -1409,7 +1410,7 @@ async def delete_access_tokens_for_user( async def add_threepid( self, user_id: str, medium: str, address: str, validated_at: int - ): + ) -> None: # check if medium has a valid value if medium not in ["email", "msisdn"]: raise SynapseError( @@ -1480,7 +1481,7 @@ async def hash(self, password: str) -> str: Hashed password. """ - def _do_hash(): + def _do_hash() -> str: # Normalise the Unicode in the password pw = unicodedata.normalize("NFKC", password) @@ -1504,7 +1505,7 @@ async def validate_hash( Whether self.hash(password) == stored_hash. """ - def _do_validate_hash(checked_hash: bytes): + def _do_validate_hash(checked_hash: bytes) -> bool: # Normalise the Unicode in the password pw = unicodedata.normalize("NFKC", password) @@ -1581,7 +1582,7 @@ async def complete_sso_login( client_redirect_url: str, extra_attributes: Optional[JsonDict] = None, new_user: bool = False, - ): + ) -> None: """Having figured out a mxid for this user, complete the HTTP request Args: @@ -1627,7 +1628,7 @@ def _complete_sso_login( extra_attributes: Optional[JsonDict] = None, new_user: bool = False, user_profile_data: Optional[ProfileInfo] = None, - ): + ) -> None: """ The synchronous portion of complete_sso_login. @@ -1726,7 +1727,7 @@ def _expire_sso_extra_attributes(self) -> None: del self._extra_attributes[user_id] @staticmethod - def add_query_param_to_url(url: str, param_name: str, param: Any): + def add_query_param_to_url(url: str, param_name: str, param: Any) -> str: url_parts = list(urllib.parse.urlparse(url)) query = urllib.parse.parse_qsl(url_parts[4], keep_blank_values=True) query.append((param_name, param)) @@ -1734,9 +1735,9 @@ def add_query_param_to_url(url: str, param_name: str, param: Any): return urllib.parse.urlunparse(url_parts) -@attr.s(slots=True) +@attr.s(slots=True, auto_attribs=True) class MacaroonGenerator: - hs = attr.ib() + hs: "HomeServer" def generate_guest_access_token(self, user_id: str) -> str: macaroon = self._generate_base_macaroon(user_id) @@ -1816,7 +1817,9 @@ class PasswordProvider: """ @classmethod - def load(cls, module, config, module_api: ModuleApi) -> "PasswordProvider": + def load( + cls, module: Type, config: JsonDict, module_api: ModuleApi + ) -> "PasswordProvider": try: pp = module(config=config, account_handler=module_api) except Exception as e: @@ -1824,7 +1827,7 @@ def load(cls, module, config, module_api: ModuleApi) -> "PasswordProvider": raise return cls(pp, module_api) - def __init__(self, pp, module_api: ModuleApi): + def __init__(self, pp: "PasswordProvider", module_api: ModuleApi): self._pp = pp self._module_api = module_api @@ -1838,7 +1841,7 @@ def __init__(self, pp, module_api: ModuleApi): if g: self._supported_login_types.update(g()) - def __str__(self): + def __str__(self) -> str: return str(self._pp) def get_supported_login_types(self) -> Mapping[str, Iterable[str]]: @@ -1876,19 +1879,19 @@ async def check_auth( """ # first grandfather in a call to check_password if login_type == LoginType.PASSWORD: - g = getattr(self._pp, "check_password", None) - if g: + check_password = getattr(self._pp, "check_password", None) + if check_password: qualified_user_id = self._module_api.get_qualified_user_id(username) - is_valid = await self._pp.check_password( + is_valid = await check_password( qualified_user_id, login_dict["password"] ) if is_valid: return qualified_user_id, None - g = getattr(self._pp, "check_auth", None) - if not g: + check_auth = getattr(self._pp, "check_auth", None) + if not check_auth: return None - result = await g(username, login_type, login_dict) + result = await check_auth(username, login_type, login_dict) # Check if the return value is a str or a tuple if isinstance(result, str): diff --git a/synapse/handlers/cas.py b/synapse/handlers/cas.py index 47ddabbe4..b0b188dc7 100644 --- a/synapse/handlers/cas.py +++ b/synapse/handlers/cas.py @@ -34,20 +34,20 @@ class CasError(Exception): """Used to catch errors when validating the CAS ticket.""" - def __init__(self, error, error_description=None): + def __init__(self, error: str, error_description: Optional[str] = None): self.error = error self.error_description = error_description - def __str__(self): + def __str__(self) -> str: if self.error_description: return f"{self.error}: {self.error_description}" return self.error -@attr.s(slots=True, frozen=True) +@attr.s(slots=True, frozen=True, auto_attribs=True) class CasResponse: - username = attr.ib(type=str) - attributes = attr.ib(type=Dict[str, List[Optional[str]]]) + username: str + attributes: Dict[str, List[Optional[str]]] class CasHandler: @@ -133,11 +133,9 @@ async def _validate_ticket( body = pde.response except HttpResponseException as e: description = ( - ( - 'Authorization server responded with a "{status}" error ' - "while exchanging the authorization code." - ).format(status=e.code), - ) + 'Authorization server responded with a "{status}" error ' + "while exchanging the authorization code." + ).format(status=e.code) raise CasError("server_error", description) from e return self._parse_cas_response(body) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 46ee83440..35334725d 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -267,7 +267,7 @@ def __init__(self, hs: "HomeServer"): hs.get_distributor().observe("user_left_room", self.user_left_room) - def _check_device_name_length(self, name: Optional[str]): + def _check_device_name_length(self, name: Optional[str]) -> None: """ Checks whether a device name is longer than the maximum allowed length. diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 08a137561..d0fb2fc7d 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -202,7 +202,7 @@ async def query_devices( # Now fetch any devices that we don't have in our cache @trace - async def do_remote_query(destination): + async def do_remote_query(destination: str) -> None: """This is called when we are querying the device list of a user on a remote homeserver and their device list is not in the device list cache. If we share a room with this user and we're not querying for @@ -447,7 +447,7 @@ async def claim_one_time_keys( } @trace - async def claim_client_keys(destination): + async def claim_client_keys(destination: str) -> None: set_tag("destination", destination) device_keys = remote_queries[destination] try: diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py index 4288ffff0..cb81fa098 100644 --- a/synapse/handlers/event_auth.py +++ b/synapse/handlers/event_auth.py @@ -25,6 +25,7 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.events import EventBase from synapse.events.builder import EventBuilder +from synapse.events.snapshot import EventContext from synapse.types import StateMap, get_domain_from_id from synapse.util.metrics import Measure @@ -45,7 +46,11 @@ def __init__(self, hs: "HomeServer"): self._server_name = hs.hostname async def check_from_context( - self, room_version: str, event, context, do_sig_check=True + self, + room_version: str, + event: EventBase, + context: EventContext, + do_sig_check: bool = True, ) -> None: auth_event_ids = event.auth_event_ids() auth_events_by_id = await self._store.get_events(auth_event_ids) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 6754c64c3..8e2cf3387 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1221,136 +1221,6 @@ async def on_get_missing_events( return missing_events - async def construct_auth_difference( - self, local_auth: Iterable[EventBase], remote_auth: Iterable[EventBase] - ) -> Dict: - """Given a local and remote auth chain, find the differences. This - assumes that we have already processed all events in remote_auth - - Params: - local_auth - remote_auth - - Returns: - dict - """ - - logger.debug("construct_auth_difference Start!") - - # TODO: Make sure we are OK with local_auth or remote_auth having more - # auth events in them than strictly necessary. - - def sort_fun(ev): - return ev.depth, ev.event_id - - logger.debug("construct_auth_difference after sort_fun!") - - # We find the differences by starting at the "bottom" of each list - # and iterating up on both lists. The lists are ordered by depth and - # then event_id, we iterate up both lists until we find the event ids - # don't match. Then we look at depth/event_id to see which side is - # missing that event, and iterate only up that list. Repeat. - - remote_list = list(remote_auth) - remote_list.sort(key=sort_fun) - - local_list = list(local_auth) - local_list.sort(key=sort_fun) - - local_iter = iter(local_list) - remote_iter = iter(remote_list) - - logger.debug("construct_auth_difference before get_next!") - - def get_next(it, opt=None): - try: - return next(it) - except Exception: - return opt - - current_local = get_next(local_iter) - current_remote = get_next(remote_iter) - - logger.debug("construct_auth_difference before while") - - missing_remotes = [] - missing_locals = [] - while current_local or current_remote: - if current_remote is None: - missing_locals.append(current_local) - current_local = get_next(local_iter) - continue - - if current_local is None: - missing_remotes.append(current_remote) - current_remote = get_next(remote_iter) - continue - - if current_local.event_id == current_remote.event_id: - current_local = get_next(local_iter) - current_remote = get_next(remote_iter) - continue - - if current_local.depth < current_remote.depth: - missing_locals.append(current_local) - current_local = get_next(local_iter) - continue - - if current_local.depth > current_remote.depth: - missing_remotes.append(current_remote) - current_remote = get_next(remote_iter) - continue - - # They have the same depth, so we fall back to the event_id order - if current_local.event_id < current_remote.event_id: - missing_locals.append(current_local) - current_local = get_next(local_iter) - - if current_local.event_id > current_remote.event_id: - missing_remotes.append(current_remote) - current_remote = get_next(remote_iter) - continue - - logger.debug("construct_auth_difference after while") - - # missing locals should be sent to the server - # We should find why we are missing remotes, as they will have been - # rejected. - - # Remove events from missing_remotes if they are referencing a missing - # remote. We only care about the "root" rejected ones. - missing_remote_ids = [e.event_id for e in missing_remotes] - base_remote_rejected = list(missing_remotes) - for e in missing_remotes: - for e_id in e.auth_event_ids(): - if e_id in missing_remote_ids: - try: - base_remote_rejected.remove(e) - except ValueError: - pass - - reason_map = {} - - for e in base_remote_rejected: - reason = await self.store.get_rejection_reason(e.event_id) - if reason is None: - # TODO: e is not in the current state, so we should - # construct some proof of that. - continue - - reason_map[e.event_id] = reason - - logger.debug("construct_auth_difference returning") - - return { - "auth_chain": local_auth, - "rejects": { - e.event_id: {"reason": reason_map[e.event_id], "proof": None} - for e in base_remote_rejected - }, - "missing": [e.event_id for e in missing_locals], - } - @log_function async def exchange_third_party_invite( self, sender_user_id: str, target_user_id: str, room_id: str, signed: JsonDict diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 946343fa2..3b95beeb0 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -1016,7 +1016,7 @@ async def _resync_device(self, sender: str) -> None: except Exception: logger.exception("Failed to resync device for %s", sender) - async def _handle_marker_event(self, origin: str, marker_event: EventBase): + async def _handle_marker_event(self, origin: str, marker_event: EventBase) -> None: """Handles backfilling the insertion event when we receive a marker event that points to one. @@ -1109,7 +1109,7 @@ async def _get_events_and_persist( event_map: Dict[str, EventBase] = {} - async def get_event(event_id: str): + async def get_event(event_id: str) -> None: with nested_logging_context(event_id): try: event = await self._federation_client.get_pdu( @@ -1218,7 +1218,7 @@ async def _auth_and_persist_fetched_events( if not event_infos: return - async def prep(ev_info: _NewEventInfo): + async def prep(ev_info: _NewEventInfo) -> EventContext: event = ev_info.event with nested_logging_context(suffix=event.event_id): res = await self._state_handler.compute_event_context(event) @@ -1692,7 +1692,7 @@ async def _update_context_for_auth_events( async def _run_push_actions_and_persist_event( self, event: EventBase, context: EventContext, backfilled: bool = False - ): + ) -> None: """Run the push actions for a received event, and persist it. Args: diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py index 1a6c5c64a..9e270d461 100644 --- a/synapse/handlers/groups_local.py +++ b/synapse/handlers/groups_local.py @@ -14,7 +14,7 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Dict, Iterable, List, Set +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Iterable, List, Set from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError from synapse.types import GroupID, JsonDict, get_domain_from_id @@ -25,12 +25,14 @@ logger = logging.getLogger(__name__) -def _create_rerouter(func_name): +def _create_rerouter(func_name: str) -> Callable[..., Awaitable[JsonDict]]: """Returns an async function that looks at the group id and calls the function on federation or the local group server if the group is local """ - async def f(self, group_id, *args, **kwargs): + async def f( + self: "GroupsLocalWorkerHandler", group_id: str, *args: Any, **kwargs: Any + ) -> JsonDict: if not GroupID.is_valid(group_id): raise SynapseError(400, "%s is not a legal group ID" % (group_id,)) diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 0b24b40eb..c942086e7 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -13,7 +13,7 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Optional, Tuple +from typing import TYPE_CHECKING, List, Optional, Tuple from twisted.internet import defer @@ -150,7 +150,7 @@ async def _snapshot_all_rooms( if limit is None: limit = 10 - async def handle_room(event: RoomsForUser): + async def handle_room(event: RoomsForUser) -> None: d: JsonDict = { "room_id": event.room_id, "membership": event.membership, @@ -411,7 +411,7 @@ async def _room_initial_sync_joined( presence_handler = self.hs.get_presence_handler() - async def get_presence(): + async def get_presence() -> List[JsonDict]: # If presence is disabled, return an empty list if not self.hs.config.server.use_presence: return [] @@ -428,7 +428,7 @@ async def get_presence(): for s in states ] - async def get_receipts(): + async def get_receipts() -> List[JsonDict]: receipts = await self.store.get_linearized_receipts_for_room( room_id, to_key=now_token.receipt_key ) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 10f1584a0..bf4853630 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -46,6 +46,7 @@ from synapse.events.builder import EventBuilder from synapse.events.snapshot import EventContext from synapse.events.validator import EventValidator +from synapse.handlers.directory import DirectoryHandler from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.http.send_event import ReplicationSendEventRestServlet @@ -298,7 +299,7 @@ async def get_joined_members(self, requester: Requester, room_id: str) -> dict: for user_id, profile in users_with_profile.items() } - def maybe_schedule_expiry(self, event: EventBase): + def maybe_schedule_expiry(self, event: EventBase) -> None: """Schedule the expiry of an event if there's not already one scheduled, or if the one running is for an event that will expire after the provided timestamp. @@ -318,7 +319,7 @@ def maybe_schedule_expiry(self, event: EventBase): # a task scheduled for a timestamp that's sooner than the provided one. self._schedule_expiry_for_event(event.event_id, expiry_ts) - async def _schedule_next_expiry(self): + async def _schedule_next_expiry(self) -> None: """Retrieve the ID and the expiry timestamp of the next event to be expired, and schedule an expiry task for it. @@ -331,7 +332,7 @@ async def _schedule_next_expiry(self): event_id, expiry_ts = res self._schedule_expiry_for_event(event_id, expiry_ts) - def _schedule_expiry_for_event(self, event_id: str, expiry_ts: int): + def _schedule_expiry_for_event(self, event_id: str, expiry_ts: int) -> None: """Schedule an expiry task for the provided event if there's not already one scheduled at a timestamp that's sooner than the provided one. @@ -367,7 +368,7 @@ def _schedule_expiry_for_event(self, event_id: str, expiry_ts: int): event_id, ) - async def _expire_event(self, event_id: str): + async def _expire_event(self, event_id: str) -> None: """Retrieve and expire an event that needs to be expired from the database. If the event doesn't exist in the database, log it and delete the expiry date @@ -1229,7 +1230,10 @@ async def cache_joined_hosts_for_event( self._external_cache_joined_hosts_updates[state_entry.state_group] = None async def _validate_canonical_alias( - self, directory_handler, room_alias_str: str, expected_room_id: str + self, + directory_handler: DirectoryHandler, + room_alias_str: str, + expected_room_id: str, ) -> None: """ Ensure that the given room alias points to the expected room ID. @@ -1477,7 +1481,7 @@ async def persist_and_notify_client_event( # If there's an expiry timestamp on the event, schedule its expiry. self._message_handler.maybe_schedule_expiry(event) - def _notify(): + def _notify() -> None: try: self.notifier.on_new_room_event( event, event_pos, max_stream_token, extra_users=extra_users @@ -1523,7 +1527,7 @@ async def _bump_active_time(self, user: UserID) -> None: except Exception: logger.exception("Error bumping presence active time") - async def _send_dummy_events_to_fill_extremities(self): + async def _send_dummy_events_to_fill_extremities(self) -> None: """Background task to send dummy events into rooms that have a large number of extremities """ @@ -1600,7 +1604,7 @@ async def _send_dummy_event_for_room(self, room_id: str) -> bool: ) return False - def _expire_rooms_to_exclude_from_dummy_event_insertion(self): + def _expire_rooms_to_exclude_from_dummy_event_insertion(self) -> None: expire_before = self.clock.time_msec() - _DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY to_expire = set() for room_id, time in self._rooms_to_exclude_from_dummy_event_insertion.items(): diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py index dfc251b2a..aed5a40a7 100644 --- a/synapse/handlers/oidc.py +++ b/synapse/handlers/oidc.py @@ -14,7 +14,7 @@ # limitations under the License. import inspect import logging -from typing import TYPE_CHECKING, Dict, Generic, List, Optional, TypeVar, Union +from typing import TYPE_CHECKING, Any, Dict, Generic, List, Optional, TypeVar, Union from urllib.parse import urlencode, urlparse import attr @@ -249,11 +249,11 @@ async def handle_oidc_callback(self, request: SynapseRequest) -> None: class OidcError(Exception): """Used to catch errors when calling the token_endpoint""" - def __init__(self, error, error_description=None): + def __init__(self, error: str, error_description: Optional[str] = None): self.error = error self.error_description = error_description - def __str__(self): + def __str__(self) -> str: if self.error_description: return f"{self.error}: {self.error_description}" return self.error @@ -1057,13 +1057,13 @@ def __init__( self._cached_secret = b"" self._cached_secret_replacement_time = 0 - def __str__(self): + def __str__(self) -> str: # if client_auth_method is client_secret_basic, then ClientAuth.prepare calls # encode_client_secret_basic, which calls "{}".format(secret), which ends up # here. return self._get_secret().decode("ascii") - def __bytes__(self): + def __bytes__(self) -> bytes: # if client_auth_method is client_secret_post, then ClientAuth.prepare calls # encode_client_secret_post, which ends up here. return self._get_secret() @@ -1197,21 +1197,21 @@ def verify_oidc_session_token( ) -@attr.s(frozen=True, slots=True) +@attr.s(frozen=True, slots=True, auto_attribs=True) class OidcSessionData: """The attributes which are stored in a OIDC session cookie""" # the Identity Provider being used - idp_id = attr.ib(type=str) + idp_id: str # The `nonce` parameter passed to the OIDC provider. - nonce = attr.ib(type=str) + nonce: str # The URL the client gave when it initiated the flow. ("" if this is a UI Auth) - client_redirect_url = attr.ib(type=str) + client_redirect_url: str # The session ID of the ongoing UI Auth ("" if this is a login) - ui_auth_session_id = attr.ib(type=str) + ui_auth_session_id: str class UserAttributeDict(TypedDict): @@ -1290,20 +1290,20 @@ async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDi # Used to clear out "None" values in templates -def jinja_finalize(thing): +def jinja_finalize(thing: Any) -> Any: return thing if thing is not None else "" env = Environment(finalize=jinja_finalize) -@attr.s(slots=True, frozen=True) +@attr.s(slots=True, frozen=True, auto_attribs=True) class JinjaOidcMappingConfig: - subject_claim = attr.ib(type=str) - localpart_template = attr.ib(type=Optional[Template]) - display_name_template = attr.ib(type=Optional[Template]) - email_template = attr.ib(type=Optional[Template]) - extra_attributes = attr.ib(type=Dict[str, Template]) + subject_claim: str + localpart_template: Optional[Template] + display_name_template: Optional[Template] + email_template: Optional[Template] + extra_attributes: Dict[str, Template] class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 7dc0ee4be..08b93b3ec 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -15,6 +15,8 @@ import logging from typing import TYPE_CHECKING, Any, Dict, Optional, Set +import attr + from twisted.python.failure import Failure from synapse.api.constants import EventTypes, Membership @@ -24,7 +26,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.state import StateFilter from synapse.streams.config import PaginationConfig -from synapse.types import Requester +from synapse.types import JsonDict, Requester from synapse.util.async_helpers import ReadWriteLock from synapse.util.stringutils import random_string from synapse.visibility import filter_events_for_client @@ -36,15 +38,12 @@ logger = logging.getLogger(__name__) +@attr.s(slots=True, auto_attribs=True) class PurgeStatus: """Object tracking the status of a purge request This class contains information on the progress of a purge request, for return by get_purge_status. - - Attributes: - status (int): Tracks whether this request has completed. One of - STATUS_{ACTIVE,COMPLETE,FAILED} """ STATUS_ACTIVE = 0 @@ -57,10 +56,10 @@ class PurgeStatus: STATUS_FAILED: "failed", } - def __init__(self): - self.status = PurgeStatus.STATUS_ACTIVE + # Tracks whether this request has completed. One of STATUS_{ACTIVE,COMPLETE,FAILED}. + status: int = STATUS_ACTIVE - def asdict(self): + def asdict(self) -> JsonDict: return {"status": PurgeStatus.STATUS_TEXT[self.status]} @@ -107,7 +106,7 @@ def __init__(self, hs: "HomeServer"): async def purge_history_for_rooms_in_range( self, min_ms: Optional[int], max_ms: Optional[int] - ): + ) -> None: """Purge outdated events from rooms within the given retention range. If a default retention policy is defined in the server's configuration and its @@ -291,7 +290,7 @@ async def _purge_history( self._purges_in_progress_by_room.discard(room_id) # remove the purge from the list 24 hours after it completes - def clear_purge(): + def clear_purge() -> None: del self._purges_by_id[purge_id] self.hs.get_reactor().callLater(24 * 3600, clear_purge) diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 4ab962a84..841c8815b 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -26,18 +26,22 @@ import logging from bisect import bisect from contextlib import contextmanager +from types import TracebackType from typing import ( TYPE_CHECKING, Any, + Awaitable, Callable, Collection, Dict, FrozenSet, + Generator, Iterable, List, Optional, Set, Tuple, + Type, Union, ) @@ -240,7 +244,7 @@ async def set_state( """ @abc.abstractmethod - async def bump_presence_active_time(self, user: UserID): + async def bump_presence_active_time(self, user: UserID) -> None: """We've seen the user do something that indicates they're interacting with the app. """ @@ -274,7 +278,7 @@ async def update_external_syncs_clear(self, process_id: str) -> None: async def process_replication_rows( self, stream_name: str, instance_name: str, token: int, rows: list - ): + ) -> None: """Process streams received over replication.""" await self._federation_queue.process_replication_rows( stream_name, instance_name, token, rows @@ -286,7 +290,7 @@ def get_federation_queue(self) -> "PresenceFederationQueue": async def maybe_send_presence_to_interested_destinations( self, states: List[UserPresenceState] - ): + ) -> None: """If this instance is a federation sender, send the states to all destinations that are interested. Filters out any states for remote users. @@ -309,7 +313,7 @@ async def maybe_send_presence_to_interested_destinations( for destination, host_states in hosts_to_states.items(): self._federation.send_presence_to_destinations(host_states, [destination]) - async def send_full_presence_to_users(self, user_ids: Collection[str]): + async def send_full_presence_to_users(self, user_ids: Collection[str]) -> None: """ Adds to the list of users who should receive a full snapshot of presence upon their next sync. Note that this only works for local users. @@ -363,7 +367,12 @@ async def is_visible(self, observed_user: UserID, observer_user: UserID) -> bool class _NullContextManager(ContextManager[None]): """A context manager which does nothing.""" - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: pass @@ -468,7 +477,7 @@ async def user_syncing( if self._user_to_num_current_syncs[user_id] == 1: self.mark_as_coming_online(user_id) - def _end(): + def _end() -> None: # We check that the user_id is in user_to_num_current_syncs because # user_to_num_current_syncs may have been cleared if we are # shutting down. @@ -480,7 +489,7 @@ def _end(): self.mark_as_going_offline(user_id) @contextlib.contextmanager - def _user_syncing(): + def _user_syncing() -> Generator[None, None, None]: try: yield finally: @@ -503,7 +512,7 @@ async def notify_from_replication( async def process_replication_rows( self, stream_name: str, instance_name: str, token: int, rows: list - ): + ) -> None: await super().process_replication_rows(stream_name, instance_name, token, rows) if stream_name != PresenceStream.NAME: @@ -689,7 +698,7 @@ def __init__(self, hs: "HomeServer"): # Start a LoopingCall in 30s that fires every 5s. # The initial delay is to allow disconnected clients a chance to # reconnect before we treat them as offline. - def run_timeout_handler(): + def run_timeout_handler() -> Awaitable[None]: return run_as_background_process( "handle_presence_timeouts", self._handle_timeouts ) @@ -698,7 +707,7 @@ def run_timeout_handler(): 30, self.clock.looping_call, run_timeout_handler, 5000 ) - def run_persister(): + def run_persister() -> Awaitable[None]: return run_as_background_process( "persist_presence_changes", self._persist_unpersisted_changes ) @@ -942,8 +951,8 @@ async def user_syncing( when users disconnect/reconnect. Args: - user_id (str) - affect_presence (bool): If false this function will be a no-op. + user_id + affect_presence: If false this function will be a no-op. Useful for streams that are not associated with an actual client that is being used by a user. """ @@ -978,7 +987,7 @@ async def user_syncing( ] ) - async def _end(): + async def _end() -> None: try: self.user_to_num_current_syncs[user_id] -= 1 @@ -994,7 +1003,7 @@ async def _end(): logger.exception("Error updating presence after sync") @contextmanager - def _user_syncing(): + def _user_syncing() -> Generator[None, None, None]: try: yield finally: @@ -1264,7 +1273,7 @@ def notify_new_event(self) -> None: if self._event_processing: return - async def _process_presence(): + async def _process_presence() -> None: assert not self._event_processing self._event_processing = True @@ -1513,7 +1522,7 @@ async def get_new_events( room_ids: Optional[List[str]] = None, include_offline: bool = True, explicit_room_id: Optional[str] = None, - **kwargs, + **kwargs: Any, ) -> Tuple[List[UserPresenceState], int]: # The process for getting presence events are: # 1. Get the rooms the user is in. @@ -2074,7 +2083,7 @@ def __init__(self, hs: "HomeServer", presence_handler: BasePresenceHandler): if self._queue_presence_updates: self._clock.looping_call(self._clear_queue, self._CLEAR_ITEMS_EVERY_MS) - def _clear_queue(self): + def _clear_queue(self) -> None: """Clear out older entries from the queue.""" clear_before = self._clock.time_msec() - self._KEEP_ITEMS_IN_QUEUE_FOR_MS @@ -2205,7 +2214,7 @@ async def get_replication_rows( async def process_replication_rows( self, stream_name: str, instance_name: str, token: int, rows: list - ): + ) -> None: if stream_name != PresenceFederationStream.NAME: return diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 51adf8762..246eb9828 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -254,7 +254,7 @@ async def set_avatar_url( requester: Requester, new_avatar_url: str, by_admin: bool = False, - ): + ) -> None: """Set a new avatar URL for a user. Args: @@ -425,7 +425,7 @@ async def check_profile_query_allowed( raise @wrap_as_background_process("Update remote profile") - async def _update_remote_profile_cache(self): + async def _update_remote_profile_cache(self) -> None: """Called periodically to check profiles of remote users we haven't checked in a while. """ diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index a49b8ee4b..c7567ac05 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, List, Optional, Tuple from synapse.api.constants import ReadReceiptEventFields from synapse.appservice import ApplicationService @@ -216,7 +216,7 @@ def filter_out_hidden(events: List[JsonDict], user_id: str) -> List[JsonDict]: return visible_events async def get_new_events( - self, from_key: int, room_ids: List[str], user: UserID, **kwargs + self, from_key: int, room_ids: List[str], user: UserID, **kwargs: Any ) -> Tuple[List[JsonDict], int]: from_key = int(from_key) to_key = self.get_current_key() diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 38c4993da..efb7d2676 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -125,7 +125,7 @@ async def check_username( localpart: str, guest_access_token: Optional[str] = None, assigned_user_id: Optional[str] = None, - ): + ) -> None: if types.contains_invalid_mxid_characters(localpart): raise SynapseError( 400, diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 9345ae02e..abdd50616 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -1,6 +1,4 @@ -# Copyright 2014 - 2016 OpenMarket Ltd -# Copyright 2018-2019 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. +# Copyright 2016-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -186,7 +184,7 @@ async def upgrade_room( async def _upgrade_room( self, requester: Requester, old_room_id: str, new_version: RoomVersion - ): + ) -> str: """ Args: requester: the user requesting the upgrade @@ -512,7 +510,7 @@ async def _move_aliases_to_new_room( old_room_id: str, new_room_id: str, old_room_state: StateMap[str], - ): + ) -> None: # check to see if we have a canonical alias. canonical_alias_event = None canonical_alias_event_id = old_room_state.get((EventTypes.CanonicalAlias, "")) @@ -902,7 +900,7 @@ async def _send_events_for_new_room( event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""} - def create(etype: str, content: JsonDict, **kwargs) -> JsonDict: + def create(etype: str, content: JsonDict, **kwargs: Any) -> JsonDict: e = {"type": etype, "content": content} e.update(event_keys) @@ -910,7 +908,7 @@ def create(etype: str, content: JsonDict, **kwargs) -> JsonDict: return e - async def send(etype: str, content: JsonDict, **kwargs) -> int: + async def send(etype: str, content: JsonDict, **kwargs: Any) -> int: event = create(etype, content, **kwargs) logger.debug("Sending %s in new room", etype) # Allow these events to be sent even if the user is shadow-banned to @@ -1033,7 +1031,7 @@ async def _generate_room_id( creator_id: str, is_public: bool, room_version: RoomVersion, - ): + ) -> str: # autogen room IDs and try to create it. We may clash, so just # try a few times till one goes through, giving up eventually. attempts = 0 @@ -1097,7 +1095,7 @@ async def get_event_context( users = await self.store.get_users_in_room(room_id) is_peeking = user.to_string() not in users - async def filter_evts(events): + async def filter_evts(events: List[EventBase]) -> List[EventBase]: if use_admin_priviledge: return events return await filter_events_for_client( diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index 81680b8df..c83ff585e 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -14,7 +14,7 @@ import logging from collections import namedtuple -from typing import TYPE_CHECKING, Optional, Tuple +from typing import TYPE_CHECKING, Any, Optional, Tuple import msgpack from unpaddedbase64 import decode_base64, encode_base64 @@ -33,7 +33,7 @@ SynapseError, ) from synapse.types import JsonDict, ThirdPartyInstanceID -from synapse.util.caches.descriptors import cached +from synapse.util.caches.descriptors import _CacheContext, cached from synapse.util.caches.response_cache import ResponseCache from ._base import BaseHandler @@ -169,7 +169,7 @@ async def _get_public_room_list( ignore_non_federatable=from_federation, ) - def build_room_entry(room): + def build_room_entry(room: JsonDict) -> JsonDict: entry = { "room_id": room["room_id"], "name": room["name"], @@ -249,10 +249,10 @@ async def generate_room_entry( self, room_id: str, num_joined_users: int, - cache_context, + cache_context: _CacheContext, with_alias: bool = True, allow_private: bool = False, - ) -> Optional[dict]: + ) -> Optional[JsonDict]: """Returns the entry for a room Args: @@ -507,7 +507,7 @@ def to_token(self) -> str: ) ) - def copy_and_replace(self, **kwds) -> "RoomListNextBatch": + def copy_and_replace(self, **kwds: Any) -> "RoomListNextBatch": return self._replace(**kwds) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 439020164..a3e13c227 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -225,7 +225,7 @@ async def ratelimit_multiple_invites( room_id: Optional[str], n_invites: int, update: bool = True, - ): + ) -> None: """Ratelimit more than one invite sent by the given requester in the given room. Args: @@ -249,7 +249,7 @@ async def ratelimit_invite( requester: Optional[Requester], room_id: Optional[str], invitee_user_id: str, - ): + ) -> None: """Ratelimit invites by room and by target user. If room ID is missing then we just rate limit by target user. @@ -386,7 +386,7 @@ async def _local_membership_update( return result_event.event_id, result_event.internal_metadata.stream_ordering async def copy_room_tags_and_direct_to_room( - self, old_room_id, new_room_id, user_id + self, old_room_id: str, new_room_id: str, user_id: str ) -> None: """Copies the tags and direct room state from one room to another. @@ -1030,7 +1030,7 @@ async def send_membership_event( event: EventBase, context: EventContext, ratelimit: bool = True, - ): + ) -> None: """ Change the membership status of a user in a room. diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py index 781da9e81..4e28fb968 100644 --- a/synapse/handlers/room_summary.py +++ b/synapse/handlers/room_summary.py @@ -541,7 +541,7 @@ async def get_federation_hierarchy( origin: str, requested_room_id: str, suggested_only: bool, - ): + ) -> JsonDict: """ Implementation of the room hierarchy Federation API. diff --git a/synapse/handlers/saml.py b/synapse/handlers/saml.py index 0066d570c..185befbe9 100644 --- a/synapse/handlers/saml.py +++ b/synapse/handlers/saml.py @@ -40,15 +40,15 @@ logger = logging.getLogger(__name__) -@attr.s(slots=True) +@attr.s(slots=True, auto_attribs=True) class Saml2SessionData: """Data we track about SAML2 sessions""" # time the session was created, in milliseconds - creation_time = attr.ib() + creation_time: int # The user interactive authentication session ID associated with this SAML # session (or None if this SAML session is for an initial login). - ui_auth_session_id = attr.ib(type=Optional[str], default=None) + ui_auth_session_id: Optional[str] = None class SamlHandler(BaseHandler): @@ -359,7 +359,7 @@ def _remote_id_from_saml_response( return remote_user_id - def expire_sessions(self): + def expire_sessions(self) -> None: expire_before = self.clock.time_msec() - self._saml2_session_lifetime to_expire = set() for reqid, data in self._outstanding_requests_dict.items(): @@ -391,10 +391,10 @@ def dot_replace_for_mxid(username: str) -> str: } -@attr.s +@attr.s(auto_attribs=True) class SamlConfig: - mxid_source_attribute = attr.ib() - mxid_mapper = attr.ib() + mxid_source_attribute: str + mxid_mapper: Callable[[str], str] class DefaultSamlMappingProvider: diff --git a/synapse/handlers/send_email.py b/synapse/handlers/send_email.py index a31fe3e3c..25e6b012b 100644 --- a/synapse/handlers/send_email.py +++ b/synapse/handlers/send_email.py @@ -17,7 +17,7 @@ from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText from io import BytesIO -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Any, Optional from pkg_resources import parse_version @@ -79,7 +79,7 @@ async def _sendmail( msg = BytesIO(msg_bytes) d: "Deferred[object]" = Deferred() - def build_sender_factory(**kwargs) -> ESMTPSenderFactory: + def build_sender_factory(**kwargs: Any) -> ESMTPSenderFactory: return ESMTPSenderFactory( username, password, diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index 05aa76d6a..e044251a1 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -205,7 +205,7 @@ def __init__(self, hs: "HomeServer"): self._consent_at_registration = hs.config.consent.user_consent_at_registration - def register_identity_provider(self, p: SsoIdentityProvider): + def register_identity_provider(self, p: SsoIdentityProvider) -> None: p_id = p.idp_id assert p_id not in self._identity_providers self._identity_providers[p_id] = p @@ -856,7 +856,7 @@ async def handle_submit_username_request( async def handle_terms_accepted( self, request: Request, session_id: str, terms_version: str - ): + ) -> None: """Handle a request to the new-user 'consent' endpoint Will serve an HTTP response to the request. @@ -959,7 +959,7 @@ async def register_sso_user(self, request: Request, session_id: str) -> None: new_user=True, ) - def _expire_old_sessions(self): + def _expire_old_sessions(self) -> None: to_expire = [] now = int(self._clock.time_msec()) diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py index b64ce8cab..9fc53333f 100644 --- a/synapse/handlers/stats.py +++ b/synapse/handlers/stats.py @@ -68,7 +68,7 @@ def notify_new_event(self) -> None: self._is_processing = True - async def process(): + async def process() -> None: try: await self._unsafe_process() finally: diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 7523d8e83..e93db4bdc 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -364,7 +364,9 @@ async def _wait_for_sync_for_user( ) else: - async def current_sync_callback(before_token, after_token) -> SyncResult: + async def current_sync_callback( + before_token: StreamToken, after_token: StreamToken + ) -> SyncResult: return await self.current_sync_for_user(sync_config, since_token) result = await self.notifier.wait_for_events( @@ -1532,9 +1534,9 @@ async def _generate_sync_entry_for_rooms( newly_joined_rooms = room_changes.newly_joined_rooms newly_left_rooms = room_changes.newly_left_rooms - async def handle_room_entries(room_entry: "RoomSyncResultBuilder"): + async def handle_room_entries(room_entry: "RoomSyncResultBuilder") -> None: logger.debug("Generating room entry for %s", room_entry.room_id) - res = await self._generate_room_entry( + await self._generate_room_entry( sync_result_builder, ignored_users, room_entry, @@ -1544,7 +1546,6 @@ async def handle_room_entries(room_entry: "RoomSyncResultBuilder"): always_include=sync_result_builder.full_state, ) logger.debug("Generated room entry for %s", room_entry.room_id) - return res await concurrently_execute(handle_room_entries, room_entries, 10) @@ -1925,7 +1926,7 @@ async def _generate_room_entry( tags: Optional[Dict[str, Dict[str, Any]]], account_data: Dict[str, JsonDict], always_include: bool = False, - ): + ) -> None: """Populates the `joined` and `archived` section of `sync_result_builder` based on the `room_builder`. diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 9cea011e6..4492c8567 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -14,7 +14,7 @@ import logging import random from collections import namedtuple -from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple from synapse.api.errors import AuthError, ShadowBanError, SynapseError from synapse.appservice import ApplicationService @@ -485,7 +485,7 @@ async def get_new_events_as( return (events, handler._latest_room_serial) async def get_new_events( - self, from_key: int, room_ids: Iterable[str], **kwargs + self, from_key: int, room_ids: Iterable[str], **kwargs: Any ) -> Tuple[List[JsonDict], int]: with Measure(self.clock, "typing.get_new_events"): from_key = int(from_key) diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py index d3828dec6..ea9325e96 100644 --- a/synapse/handlers/ui_auth/checkers.py +++ b/synapse/handlers/ui_auth/checkers.py @@ -70,7 +70,7 @@ async def check_auth(self, authdict: dict, clientip: str) -> Any: class TermsAuthChecker(UserInteractiveAuthChecker): AUTH_TYPE = LoginType.TERMS - def is_enabled(self): + def is_enabled(self) -> bool: return True async def check_auth(self, authdict: dict, clientip: str) -> Any: diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index 6faa1d84b..8dc46d767 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -114,7 +114,7 @@ def notify_new_event(self) -> None: if self._is_processing: return - async def process(): + async def process() -> None: try: await self._unsafe_process() finally: From f455b0e4209cd581ea7b75436b178d2843cc6e0d Mon Sep 17 00:00:00 2001 From: David Robertson Date: Mon, 20 Sep 2021 17:35:16 +0100 Subject: [PATCH 14/74] GHA: reintroduce an env var for `$GITHUB_HEAD_REF` (#10659) This should ensure GHA runs synapse against the same-named sytest branch --- .github/workflows/tests.yml | 1 + changelog.d/10659.misc | 1 + 2 files changed, 2 insertions(+) create mode 100644 changelog.d/10659.misc diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 8736699ad..fa9c5e036 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -192,6 +192,7 @@ jobs: volumes: - ${{ github.workspace }}:/src env: + SYTEST_BRANCH: ${{ github.head_ref }} POSTGRES: ${{ matrix.postgres && 1}} MULTI_POSTGRES: ${{ (matrix.postgres == 'multi-postgres') && 1}} WORKERS: ${{ matrix.workers && 1 }} diff --git a/changelog.d/10659.misc b/changelog.d/10659.misc new file mode 100644 index 000000000..d677a521c --- /dev/null +++ b/changelog.d/10659.misc @@ -0,0 +1 @@ +Fix GitHub Actions config so we can run sytest on synapse from parallel branches. \ No newline at end of file From 6a751ff5e064bbb1fae2915e533031531c9d74e7 Mon Sep 17 00:00:00 2001 From: Aaron Raimist Date: Tue, 21 Sep 2021 05:23:34 -0500 Subject: [PATCH 15/74] Allow sending a membership event to unban a user (#10807) * Allow membership event to unban user Signed-off-by: Aaron Raimist --- changelog.d/10807.bugfix | 1 + synapse/handlers/room_member.py | 2 +- tests/rest/client/test_rooms.py | 87 ++++++++++++++++++++++++++++++++- tests/rest/client/utils.py | 11 +++++ 4 files changed, 99 insertions(+), 2 deletions(-) create mode 100644 changelog.d/10807.bugfix diff --git a/changelog.d/10807.bugfix b/changelog.d/10807.bugfix new file mode 100644 index 000000000..be03f5c73 --- /dev/null +++ b/changelog.d/10807.bugfix @@ -0,0 +1 @@ +Allow sending a membership event to unban a user. Contributed by @aaronraimist. \ No newline at end of file diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index a3e13c227..7bb3f0bc4 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -668,7 +668,7 @@ async def update_membership_locked( " (membership=%s)" % old_membership, errcode=Codes.BAD_STATE, ) - if old_membership == "ban" and action != "unban": + if old_membership == "ban" and action not in ["ban", "unban", "leave"]: raise SynapseError( 403, "Cannot %s user who was banned" % (action,), diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 50100a5ae..5a01765f4 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -26,7 +26,7 @@ import synapse.rest.admin from synapse.api.constants import EventContentFields, EventTypes, Membership -from synapse.api.errors import HttpResponseException +from synapse.api.errors import Codes, HttpResponseException from synapse.handlers.pagination import PurgeStatus from synapse.rest import admin from synapse.rest.client import account, directory, login, profile, room, sync @@ -377,6 +377,91 @@ def test_leave_permissions(self): expect_code=403, ) + # tests the "from banned" line from the table in https://spec.matrix.org/unstable/client-server-api/#mroommember + def test_member_event_from_ban(self): + room = self.created_rmid + self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id) + self.helper.join(room=room, user=self.user_id) + + other = "@burgundy:red" + + # User cannot ban other since they do not have required power level + self.helper.change_membership( + room=room, + src=self.user_id, + targ=other, + membership=Membership.BAN, + expect_code=403, # expect failure + expect_errcode=Codes.FORBIDDEN, + ) + + # Admin bans other + self.helper.change_membership( + room=room, + src=self.rmcreator_id, + targ=other, + membership=Membership.BAN, + expect_code=200, + ) + + # from ban to invite: Must never happen. + self.helper.change_membership( + room=room, + src=self.rmcreator_id, + targ=other, + membership=Membership.INVITE, + expect_code=403, # expect failure + expect_errcode=Codes.BAD_STATE, + ) + + # from ban to join: Must never happen. + self.helper.change_membership( + room=room, + src=other, + targ=other, + membership=Membership.JOIN, + expect_code=403, # expect failure + expect_errcode=Codes.BAD_STATE, + ) + + # from ban to ban: No change. + self.helper.change_membership( + room=room, + src=self.rmcreator_id, + targ=other, + membership=Membership.BAN, + expect_code=200, + ) + + # from ban to knock: Must never happen. + self.helper.change_membership( + room=room, + src=self.rmcreator_id, + targ=other, + membership=Membership.KNOCK, + expect_code=403, # expect failure + expect_errcode=Codes.BAD_STATE, + ) + + # User cannot unban other since they do not have required power level + self.helper.change_membership( + room=room, + src=self.user_id, + targ=other, + membership=Membership.LEAVE, + expect_code=403, # expect failure + expect_errcode=Codes.FORBIDDEN, + ) + + # from ban to leave: User was unbanned. + self.helper.change_membership( + room=room, + src=self.rmcreator_id, + targ=other, + membership=Membership.LEAVE, + expect_code=200, + ) + class RoomsMemberListTestCase(RoomBase): """Tests /rooms/$room_id/members/list REST events.""" diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py index 954ad1a1f..c56e45fc1 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py @@ -138,6 +138,7 @@ def change_membership( extra_data: Optional[dict] = None, tok: Optional[str] = None, expect_code: int = 200, + expect_errcode: str = None, ) -> None: """ Send a membership state event into a room. @@ -150,6 +151,7 @@ def change_membership( extra_data: Extra information to include in the content of the event tok: The user access token to use expect_code: The expected HTTP response code + expect_errcode: The expected Matrix error code """ temp_id = self.auth_user_id self.auth_user_id = src @@ -177,6 +179,15 @@ def change_membership( channel.result["body"], ) + if expect_errcode: + assert ( + str(channel.json_body["errcode"]) == expect_errcode + ), "Expected: %r, got: %r, resp: %r" % ( + expect_errcode, + channel.json_body["errcode"], + channel.result["body"], + ) + self.auth_user_id = temp_id def send( From 60453315bdbbbd364f13ca386de965e015f1062f Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 21 Sep 2021 13:02:34 +0100 Subject: [PATCH 16/74] Always add local users to the user directory (#10796) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit It's a simplification, but one that'll help make the user directory logic easier to follow with the other changes upcoming. It's not strictly required for those changes, but this will help simplify the resulting logic that listens for `m.room.member` events and generally make the logic easier to follow. This means the config option `search_all_users` ends up controlling the search query only, and not the data we store. The cost of doing so is an extra row in the `user_directory` and `user_directory_search` tables for each local user which - belongs to no public rooms - belongs to no private rooms of size ≥ 2 I think the cost of this will be marginal (since they'll already have entries in `users` and `profiles` anyway). As a small upside, a homeserver whose directory was built with this change can toggle `search_all_users` without having to rebuild their directory. Co-authored-by: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> --- changelog.d/10796.misc | 1 + docs/sample_config.yaml | 14 ++++++---- synapse/config/user_directory.py | 14 ++++++---- synapse/handlers/deactivate_account.py | 7 ++--- synapse/handlers/profile.py | 18 ++++++------- synapse/handlers/register.py | 9 +++---- .../storage/databases/main/user_directory.py | 27 +++++++------------ tests/handlers/test_profile.py | 7 +++-- tests/rest/client/test_rooms.py | 12 ++++----- 9 files changed, 54 insertions(+), 55 deletions(-) create mode 100644 changelog.d/10796.misc diff --git a/changelog.d/10796.misc b/changelog.d/10796.misc new file mode 100644 index 000000000..1873b2386 --- /dev/null +++ b/changelog.d/10796.misc @@ -0,0 +1 @@ +Simplify the internal logic which maintains the user directory database tables. \ No newline at end of file diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 95cca1655..166cec38d 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -2362,12 +2362,16 @@ user_directory: #enabled: false # Defines whether to search all users visible to your HS when searching - # the user directory, rather than limiting to users visible in public - # rooms. Defaults to false. + # the user directory. If false, search results will only contain users + # visible in public rooms and users sharing a room with the requester. + # Defaults to false. # - # If you set it true, you'll have to rebuild the user_directory search - # indexes, see: - # https://matrix-org.github.io/synapse/latest/user_directory.html + # NB. If you set this to true, and the last time the user_directory search + # indexes were (re)built was before Synapse 1.44, you'll have to + # rebuild the indexes in order to search through all known users. + # These indexes are built the first time Synapse starts; admins can + # manually trigger a rebuild following the instructions at + # https://matrix-org.github.io/synapse/latest/user_directory.html # # Uncomment to return search results containing all known users, even if that # user does not share a room with the requester. diff --git a/synapse/config/user_directory.py b/synapse/config/user_directory.py index b10df8a23..2552f688d 100644 --- a/synapse/config/user_directory.py +++ b/synapse/config/user_directory.py @@ -45,12 +45,16 @@ def generate_config_section(self, config_dir_path, server_name, **kwargs): #enabled: false # Defines whether to search all users visible to your HS when searching - # the user directory, rather than limiting to users visible in public - # rooms. Defaults to false. + # the user directory. If false, search results will only contain users + # visible in public rooms and users sharing a room with the requester. + # Defaults to false. # - # If you set it true, you'll have to rebuild the user_directory search - # indexes, see: - # https://matrix-org.github.io/synapse/latest/user_directory.html + # NB. If you set this to true, and the last time the user_directory search + # indexes were (re)built was before Synapse 1.44, you'll have to + # rebuild the indexes in order to search through all known users. + # These indexes are built the first time Synapse starts; admins can + # manually trigger a rebuild following the instructions at + # https://matrix-org.github.io/synapse/latest/user_directory.html # # Uncomment to return search results containing all known users, even if that # user does not share a room with the requester. diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index dcd320c55..a03ff9842 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -257,11 +257,8 @@ async def activate_account(self, user_id: str) -> None: """ # Add the user to the directory, if necessary. user = UserID.from_string(user_id) - if self.hs.config.user_directory_search_all_users: - profile = await self.store.get_profileinfo(user.localpart) - await self.user_directory_handler.handle_local_profile_change( - user_id, profile - ) + profile = await self.store.get_profileinfo(user.localpart) + await self.user_directory_handler.handle_local_profile_change(user_id, profile) # Ensure the user is not marked as erased. await self.store.mark_user_not_erased(user_id) diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 246eb9828..f06070bfc 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -214,11 +214,10 @@ async def set_displayname( target_user.localpart, displayname_to_set ) - if self.hs.config.user_directory_search_all_users: - profile = await self.store.get_profileinfo(target_user.localpart) - await self.user_directory_handler.handle_local_profile_change( - target_user.to_string(), profile - ) + profile = await self.store.get_profileinfo(target_user.localpart) + await self.user_directory_handler.handle_local_profile_change( + target_user.to_string(), profile + ) await self._update_join_states(requester, target_user) @@ -300,11 +299,10 @@ async def set_avatar_url( target_user.localpart, avatar_url_to_set ) - if self.hs.config.user_directory_search_all_users: - profile = await self.store.get_profileinfo(target_user.localpart) - await self.user_directory_handler.handle_local_profile_change( - target_user.to_string(), profile - ) + profile = await self.store.get_profileinfo(target_user.localpart) + await self.user_directory_handler.handle_local_profile_change( + target_user.to_string(), profile + ) await self._update_join_states(requester, target_user) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index efb7d2676..1c195c65d 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -295,11 +295,10 @@ async def register_user( shadow_banned=shadow_banned, ) - if self.hs.config.user_directory_search_all_users: - profile = await self.store.get_profileinfo(localpart) - await self.user_directory_handler.handle_local_profile_change( - user_id, profile - ) + profile = await self.store.get_profileinfo(localpart) + await self.user_directory_handler.handle_local_profile_change( + user_id, profile + ) else: # autogen a sequential user ID diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index 8aebdc281..718f3e997 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -85,19 +85,17 @@ def _make_staging_area(txn): self.db_pool.simple_insert_many_txn(txn, TEMP_TABLE + "_rooms", rooms) del rooms - # If search all users is on, get all the users we want to add. - if self.hs.config.user_directory_search_all_users: - sql = ( - "CREATE TABLE IF NOT EXISTS " - + TEMP_TABLE - + "_users(user_id TEXT NOT NULL)" - ) - txn.execute(sql) + sql = ( + "CREATE TABLE IF NOT EXISTS " + + TEMP_TABLE + + "_users(user_id TEXT NOT NULL)" + ) + txn.execute(sql) - txn.execute("SELECT name FROM users") - users = [{"user_id": x[0]} for x in txn.fetchall()] + txn.execute("SELECT name FROM users") + users = [{"user_id": x[0]} for x in txn.fetchall()] - self.db_pool.simple_insert_many_txn(txn, TEMP_TABLE + "_users", users) + self.db_pool.simple_insert_many_txn(txn, TEMP_TABLE + "_users", users) new_pos = await self.get_max_stream_id_in_current_state_deltas() await self.db_pool.runInteraction( @@ -265,13 +263,8 @@ def _get_next_batch(txn): async def _populate_user_directory_process_users(self, progress, batch_size): """ - If search_all_users is enabled, add all of the users to the user directory. + Add all local users to the user directory. """ - if not self.hs.config.user_directory_search_all_users: - await self.db_pool.updates._end_background_update( - "populate_user_directory_process_users" - ) - return 1 def _get_next_batch(txn): sql = "SELECT user_id FROM %s LIMIT %s" % ( diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 2928c4f48..57cc3e264 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -16,6 +16,7 @@ import synapse.types from synapse.api.errors import AuthError, SynapseError +from synapse.rest import admin from synapse.types import UserID from tests import unittest @@ -25,6 +26,8 @@ class ProfileTestCase(unittest.HomeserverTestCase): """Tests profile management.""" + servlets = [admin.register_servlets] + def make_homeserver(self, reactor, clock): self.mock_federation = Mock() self.mock_registry = Mock() @@ -46,11 +49,11 @@ def register_query_handler(query_type, handler): def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() - self.frank = UserID.from_string("@1234ABCD:test") + self.frank = UserID.from_string("@1234abcd:test") self.bob = UserID.from_string("@4567:test") self.alice = UserID.from_string("@alice:remote") - self.get_success(self.store.create_profile(self.frank.localpart)) + self.get_success(self.register_user(self.frank.localpart, "frankpassword")) self.handler = hs.get_profile_handler() diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 5a01765f4..ef847f0f5 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -869,6 +869,12 @@ class RoomJoinRatelimitTestCase(RoomBase): room.register_servlets, ] + def prepare(self, reactor, clock, homeserver): + super().prepare(reactor, clock, homeserver) + # profile changes expect that the user is actually registered + user = UserID.from_string(self.user_id) + self.get_success(self.register_user(user.localpart, "supersecretpassword")) + @unittest.override_config( {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}} ) @@ -898,12 +904,6 @@ def test_join_local_ratelimit_profile_change(self): # join in a second. room_ids.append(self.helper.create_room_as(self.user_id)) - # Create a profile for the user, since it hasn't been done on registration. - store = self.hs.get_datastore() - self.get_success( - store.create_profile(UserID.from_string(self.user_id).localpart) - ) - # Update the display name for the user. path = "/_matrix/client/r0/profile/%s/displayname" % self.user_id channel = self.make_request("PUT", path, {"displayname": "John Doe"}) From ee557b5375e376e5664f6a3e372c946f7a754f75 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Tue, 21 Sep 2021 08:10:01 -0500 Subject: [PATCH 17/74] Rename `/batch_send` query parameter from `?prev_event` to more obvious usage with `?prev_event_id` (MSC2716) (#10839) As mentioned in https://github.com/matrix-org/matrix-doc/pull/2716#discussion_r705872887 and https://github.com/matrix-org/synapse/issues/10737 --- changelog.d/10839.misc | 1 + synapse/rest/client/room_batch.py | 16 +++++++++------- 2 files changed, 10 insertions(+), 7 deletions(-) create mode 100644 changelog.d/10839.misc diff --git a/changelog.d/10839.misc b/changelog.d/10839.misc new file mode 100644 index 000000000..d0e10f31d --- /dev/null +++ b/changelog.d/10839.misc @@ -0,0 +1 @@ +Rename [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) `/batch_send` query parameter from `?prev_event` to more obvious usage with `?prev_event_id`. diff --git a/synapse/rest/client/room_batch.py b/synapse/rest/client/room_batch.py index d466edeec..f73ccc7f6 100644 --- a/synapse/rest/client/room_batch.py +++ b/synapse/rest/client/room_batch.py @@ -61,7 +61,7 @@ class RoomBatchSendEventRestServlet(RestServlet): some messages, you can only insert older ones after that. tldr; Insert chunks from your most recent history -> oldest history. - POST /_matrix/client/unstable/org.matrix.msc2716/rooms//batch_send?prev_event=&chunk_id= + POST /_matrix/client/unstable/org.matrix.msc2716/rooms//batch_send?prev_event_id=&chunk_id= { "events": [ ... ], "state_events_at_start": [ ... ] @@ -188,24 +188,26 @@ async def on_POST( assert_params_in_dict(body, ["state_events_at_start", "events"]) assert request.args is not None - prev_events_from_query = parse_strings_from_args(request.args, "prev_event") + prev_event_ids_from_query = parse_strings_from_args( + request.args, "prev_event_id" + ) chunk_id_from_query = parse_string(request, "chunk_id") - if prev_events_from_query is None: + if prev_event_ids_from_query is None: raise SynapseError( HTTPStatus.BAD_REQUEST, "prev_event query parameter is required when inserting historical messages back in time", errcode=Codes.MISSING_PARAM, ) - # For the event we are inserting next to (`prev_events_from_query`), + # For the event we are inserting next to (`prev_event_ids_from_query`), # find the most recent auth events (derived from state events) that # allowed that message to be sent. We will use that as a base # to auth our historical messages against. ( most_recent_prev_event_id, _, - ) = await self.store.get_max_depth_of(prev_events_from_query) + ) = await self.store.get_max_depth_of(prev_event_ids_from_query) # mapping from (type, state_key) -> state_event_id prev_state_map = await self.state_store.get_state_ids_for_event( most_recent_prev_event_id @@ -286,7 +288,7 @@ async def on_POST( events_to_create = body["events"] inherited_depth = await self._inherit_depth_from_prev_ids( - prev_events_from_query + prev_event_ids_from_query ) # Figure out which chunk to connect to. If they passed in @@ -321,7 +323,7 @@ async def on_POST( # an insertion event), in which case we just create a new insertion event # that can then get pointed to by a "marker" event later. else: - prev_event_ids = prev_events_from_query + prev_event_ids = prev_event_ids_from_query base_insertion_event_dict = self._create_insertion_event_dict( sender=requester.user.to_string(), From 5fca3c8ae62c66a1777bcb85c98679669369b061 Mon Sep 17 00:00:00 2001 From: Hillery Shay Date: Tue, 21 Sep 2021 08:04:35 -0700 Subject: [PATCH 18/74] Allow Synapse Admin API's Room Search to accept non-ASCII characters (#10859) * add tests for checking if room search works with non-ascii char * change encoding on parse_string to UTF-8 * lints * properly encode search term * lints * add changelog file * update changelog number * set changelog entry filetype to .bugfix * Revert "set changelog entry filetype to .bugfix" This reverts commit be8e5a314251438ec4ec7dbc59ba32162c93e550. * update changelog message and file type * change parse_string default encoding back to ascii and update room search admin api calll to parse string * refactor tests * Update tests/rest/admin/test_room.py Co-authored-by: Patrick Cloke Co-authored-by: Patrick Cloke --- changelog.d/10859.bugfix | 1 + synapse/rest/admin/rooms.py | 2 +- tests/rest/admin/test_room.py | 27 +++++++++++++++++++++++++++ 3 files changed, 29 insertions(+), 1 deletion(-) create mode 100644 changelog.d/10859.bugfix diff --git a/changelog.d/10859.bugfix b/changelog.d/10859.bugfix new file mode 100644 index 000000000..c1bfe22d5 --- /dev/null +++ b/changelog.d/10859.bugfix @@ -0,0 +1 @@ +Fix a bug in Unicode support of the room search admin API. It is now possible to search for rooms with non-ASCII characters. \ No newline at end of file diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index ad83d4b54..8f781f745 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -125,7 +125,7 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: errcode=Codes.INVALID_PARAM, ) - search_term = parse_string(request, "search_term") + search_term = parse_string(request, "search_term", encoding="utf-8") if search_term == "": raise SynapseError( 400, diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index 40e032df7..e798513ac 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -941,6 +941,33 @@ def _search_test( _search_test(None, "bar") _search_test(None, "", expected_http_code=400) + def test_search_term_non_ascii(self): + """Test that searching for a room with non-ASCII characters works correctly""" + + # Create test room + room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + room_name = "ж" + + # Set the name for the room + self.helper.send_state( + room_id, + "m.room.name", + {"name": room_name}, + tok=self.admin_user_tok, + ) + + # make the request and test that the response is what we wanted + search_term = urllib.parse.quote("ж", "utf-8") + url = "/_synapse/admin/v1/rooms?search_term=%s" % (search_term,) + channel = self.make_request( + "GET", + url.encode("ascii"), + access_token=self.admin_user_tok, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(room_id, channel.json_body.get("rooms")[0].get("room_id")) + self.assertEqual("ж", channel.json_body.get("rooms")[0].get("name")) + def test_single_room(self): """Test that a single room can be requested correctly""" # Create two test rooms From 2843058a8b5456dd63b83ad39a992d5d1a285eb6 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Tue, 21 Sep 2021 17:40:20 +0200 Subject: [PATCH 19/74] Test that state events sent by modules correctly end up in the room's state (#10835) Test for #10830 Ideally the test would also make sure the new state event comes down sync, but this is probably good enough. --- changelog.d/10835.misc | 1 + tests/rest/client/test_third_party_rules.py | 84 +++++++++++++++++++++ 2 files changed, 85 insertions(+) create mode 100644 changelog.d/10835.misc diff --git a/changelog.d/10835.misc b/changelog.d/10835.misc new file mode 100644 index 000000000..0c3d13477 --- /dev/null +++ b/changelog.d/10835.misc @@ -0,0 +1 @@ +Add a test to ensure state events sent by modules get persisted correctly. diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py index 0ae402964..38ac9be11 100644 --- a/tests/rest/client/test_third_party_rules.py +++ b/tests/rest/client/test_third_party_rules.py @@ -15,6 +15,7 @@ from typing import Dict from unittest.mock import Mock +from synapse.api.constants import EventTypes from synapse.events import EventBase from synapse.events.third_party_rules import load_legacy_third_party_event_rules from synapse.module_api import ModuleApi @@ -327,3 +328,86 @@ def test_legacy_on_create_room(self): correctly. """ self.helper.create_room_as(self.user_id, tok=self.tok, expect_code=403) + + def test_sent_event_end_up_in_room_state(self): + """Tests that a state event sent by a module while processing another state event + doesn't get dropped from the state of the room. This is to guard against a bug + where Synapse has been observed doing so, see https://github.com/matrix-org/synapse/issues/10830 + """ + event_type = "org.matrix.test_state" + + # This content will be updated later on, and since we actually use a reference on + # the dict it does the right thing. It's a bit hacky but a handy way of making + # sure the state actually gets updated. + event_content = {"i": -1} + + api = self.hs.get_module_api() + + # Define a callback that sends a custom event on power levels update. + async def test_fn(event: EventBase, state_events): + if event.is_state and event.type == EventTypes.PowerLevels: + await api.create_and_send_event_into_room( + { + "room_id": event.room_id, + "sender": event.sender, + "type": event_type, + "content": event_content, + "state_key": "", + } + ) + return True, None + + self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [test_fn] + + # Sometimes the bug might not happen the first time the event type is added + # to the state but might happen when an event updates the state of the room for + # that type, so we test updating the state several times. + for i in range(5): + # Update the content of the custom state event to be sent by the callback. + event_content["i"] = i + + # Update the room's power levels with a different value each time so Synapse + # doesn't consider an update redundant. + self._update_power_levels(event_default=i) + + # Check that the new event made it to the room's state. + channel = self.make_request( + method="GET", + path="/rooms/" + self.room_id + "/state/" + event_type, + access_token=self.tok, + ) + + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["i"], i) + + def _update_power_levels(self, event_default: int = 0): + """Updates the room's power levels. + + Args: + event_default: Value to use for 'events_default'. + """ + self.helper.send_state( + room_id=self.room_id, + event_type=EventTypes.PowerLevels, + body={ + "ban": 50, + "events": { + "m.room.avatar": 50, + "m.room.canonical_alias": 50, + "m.room.encryption": 100, + "m.room.history_visibility": 100, + "m.room.name": 50, + "m.room.power_levels": 100, + "m.room.server_acl": 100, + "m.room.tombstone": 100, + }, + "events_default": event_default, + "invite": 0, + "kick": 50, + "redact": 50, + "state_default": 50, + "users": {self.user_id: 100}, + "users_default": 0, + }, + tok=self.tok, + ) From ba7a91aea5fd624bf048f0fda0dca80da7a1945e Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 21 Sep 2021 12:09:57 -0400 Subject: [PATCH 20/74] Refactor oEmbed previews (#10814) The major change is moving the decision of whether to use oEmbed further up the call-stack. This reverts the _download_url method to being a "dumb" functionwhich takes a single URL and downloads it (as it was before #7920). This also makes more minor refactorings: * Renames internal variables for clarity. * Factors out shared code between the HTML and rich oEmbed previews. * Fixes tests to preview an oEmbed image. --- changelog.d/10814.feature | 1 + docs/development/url_previews.md | 21 +- synapse/rest/media/v1/oembed.py | 145 +++++--- synapse/rest/media/v1/preview_url_resource.py | 326 ++++++++++-------- tests/rest/media/v1/test_url_preview.py | 26 +- 5 files changed, 299 insertions(+), 220 deletions(-) create mode 100644 changelog.d/10814.feature diff --git a/changelog.d/10814.feature b/changelog.d/10814.feature new file mode 100644 index 000000000..4fa95a6cc --- /dev/null +++ b/changelog.d/10814.feature @@ -0,0 +1 @@ +Improve oEmbed previews by processing the author name, photo, and video information. diff --git a/docs/development/url_previews.md b/docs/development/url_previews.md index bbe05e281..aff381360 100644 --- a/docs/development/url_previews.md +++ b/docs/development/url_previews.md @@ -25,16 +25,14 @@ When Synapse is asked to preview a URL it does the following: 3. Kicks off a background process to generate a preview: 1. Checks the database cache by URL and timestamp and returns the result if it has not expired and was successful (a 2xx return code). - 2. Checks if the URL matches an oEmbed pattern. If it does, fetch the oEmbed - response. If this is an image, replace the URL to fetch and continue. If - if it is HTML content, use the HTML as the document and continue. - 3. If it doesn't match an oEmbed pattern, downloads the URL and stores it - into a file via the media storage provider and saves the local media - metadata. - 5. If the media is an image: + 2. Checks if the URL matches an [oEmbed](https://oembed.com/) pattern. If it + does, update the URL to download. + 3. Downloads the URL and stores it into a file via the media storage provider + and saves the local media metadata. + 4. If the media is an image: 1. Generates thumbnails. 2. Generates an Open Graph response based on image properties. - 6. If the media is HTML: + 5. If the media is HTML: 1. Decodes the HTML via the stored file. 2. Generates an Open Graph response from the HTML. 3. If an image exists in the Open Graph response: @@ -42,6 +40,13 @@ When Synapse is asked to preview a URL it does the following: provider and saves the local media metadata. 2. Generates thumbnails. 3. Updates the Open Graph response based on image properties. + 6. If the media is JSON and an oEmbed URL was found: + 1. Convert the oEmbed response to an Open Graph response. + 2. If a thumbnail or image is in the oEmbed response: + 1. Downloads the URL and stores it into a file via the media storage + provider and saves the local media metadata. + 2. Generates thumbnails. + 3. Updates the Open Graph response based on image properties. 7. Stores the result in the database cache. 4. Returns the result. diff --git a/synapse/rest/media/v1/oembed.py b/synapse/rest/media/v1/oembed.py index 2e6706dbf..8b74e7265 100644 --- a/synapse/rest/media/v1/oembed.py +++ b/synapse/rest/media/v1/oembed.py @@ -12,11 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +import urllib.parse from typing import TYPE_CHECKING, Optional import attr from synapse.http.client import SimpleHttpClient +from synapse.types import JsonDict +from synapse.util import json_decoder if TYPE_CHECKING: from synapse.server import HomeServer @@ -24,18 +27,15 @@ logger = logging.getLogger(__name__) -@attr.s(slots=True, auto_attribs=True) +@attr.s(slots=True, frozen=True, auto_attribs=True) class OEmbedResult: - # Either HTML content or URL must be provided. - html: Optional[str] - url: Optional[str] - title: Optional[str] - # Number of seconds to cache the content. - cache_age: int - - -class OEmbedError(Exception): - """An error occurred processing the oEmbed object.""" + # The Open Graph result (converted from the oEmbed result). + open_graph_result: JsonDict + # Number of seconds to cache the content, according to the oEmbed response. + # + # This will be None if no cache-age is provided in the oEmbed response (or + # if the oEmbed response cannot be turned into an Open Graph response). + cache_age: Optional[int] class OEmbedProvider: @@ -81,75 +81,106 @@ def get_oembed_url(self, url: str) -> Optional[str]: """ for url_pattern, endpoint in self._oembed_patterns.items(): if url_pattern.fullmatch(url): - return endpoint + # TODO Specify max height / width. + + # Note that only the JSON format is supported, some endpoints want + # this in the URL, others want it as an argument. + endpoint = endpoint.replace("{format}", "json") + + args = {"url": url, "format": "json"} + query_str = urllib.parse.urlencode(args, True) + return f"{endpoint}?{query_str}" # No match. return None - async def get_oembed_content(self, endpoint: str, url: str) -> OEmbedResult: + def parse_oembed_response(self, url: str, raw_body: bytes) -> OEmbedResult: """ - Request content from an oEmbed endpoint. + Parse the oEmbed response into an Open Graph response. Args: - endpoint: The oEmbed API endpoint. - url: The URL to pass to the API. + url: The URL which is being previewed (not the one which was + requested). + raw_body: The oEmbed response as JSON encoded as bytes. Returns: - An object representing the metadata returned. - - Raises: - OEmbedError if fetching or parsing of the oEmbed information fails. + json-encoded Open Graph data """ - try: - logger.debug("Trying to get oEmbed content for url '%s'", url) - # Note that only the JSON format is supported, some endpoints want - # this in the URL, others want it as an argument. - endpoint = endpoint.replace("{format}", "json") - - result = await self._client.get_json( - endpoint, - # TODO Specify max height / width. - args={"url": url, "format": "json"}, - ) + try: + # oEmbed responses *must* be UTF-8 according to the spec. + oembed = json_decoder.decode(raw_body.decode("utf-8")) # Ensure there's a version of 1.0. - if result.get("version") != "1.0": - raise OEmbedError("Invalid version: %s" % (result.get("version"),)) - - oembed_type = result.get("type") + oembed_version = oembed["version"] + if oembed_version != "1.0": + raise RuntimeError(f"Invalid version: {oembed_version}") # Ensure the cache age is None or an int. - cache_age = result.get("cache_age") + cache_age = oembed.get("cache_age") if cache_age: cache_age = int(cache_age) - oembed_result = OEmbedResult(None, None, result.get("title"), cache_age) + # The results. + open_graph_response = {"og:title": oembed.get("title")} - # HTML content. + # If a thumbnail exists, use it. Note that dimensions will be calculated later. + if "thumbnail_url" in oembed: + open_graph_response["og:image"] = oembed["thumbnail_url"] + + # Process each type separately. + oembed_type = oembed["type"] if oembed_type == "rich": - oembed_result.html = result.get("html") - return oembed_result + calc_description_and_urls(open_graph_response, oembed["html"]) - if oembed_type == "photo": - oembed_result.url = result.get("url") - return oembed_result + elif oembed_type == "photo": + # If this is a photo, use the full image, not the thumbnail. + open_graph_response["og:image"] = oembed["url"] - # TODO Handle link and video types. + else: + raise RuntimeError(f"Unknown oEmbed type: {oembed_type}") - if "thumbnail_url" in result: - oembed_result.url = result.get("thumbnail_url") - return oembed_result + except Exception as e: + # Trap any exception and let the code follow as usual. + logger.warning(f"Error parsing oEmbed metadata from {url}: {e:r}") + open_graph_response = {} + cache_age = None - raise OEmbedError("Incompatible oEmbed information.") + return OEmbedResult(open_graph_response, cache_age) - except OEmbedError as e: - # Trap OEmbedErrors first so we can directly re-raise them. - logger.warning("Error parsing oEmbed metadata from %s: %r", url, e) - raise - except Exception as e: - # Trap any exception and let the code follow as usual. - # FIXME: pass through 404s and other error messages nicely - logger.warning("Error downloading oEmbed metadata from %s: %r", url, e) - raise OEmbedError() from e +def calc_description_and_urls(open_graph_response: JsonDict, html_body: str) -> None: + """ + Calculate description for an HTML document. + + This uses lxml to convert the HTML document into plaintext. If errors + occur during processing of the document, an empty response is returned. + + Args: + open_graph_response: The current Open Graph summary. This is updated with additional fields. + html_body: The HTML document, as bytes. + + Returns: + The summary + """ + # If there's no body, nothing useful is going to be found. + if not html_body: + return + + from lxml import etree + + # Create an HTML parser. If this fails, log and return no metadata. + parser = etree.HTMLParser(recover=True, encoding="utf-8") + + # Attempt to parse the body. If this fails, log and return no metadata. + tree = etree.fromstring(html_body, parser) + + # The data was successfully parsed, but no tree was found. + if tree is None: + return + + from synapse.rest.media.v1.preview_url_resource import _calc_description + + description = _calc_description(tree) + if description: + open_graph_response["og:description"] = description diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index fe0627d9b..0a0b476d2 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -44,7 +44,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.rest.media.v1._base import get_filename_from_headers from synapse.rest.media.v1.media_storage import MediaStorage -from synapse.rest.media.v1.oembed import OEmbedError, OEmbedProvider +from synapse.rest.media.v1.oembed import OEmbedProvider from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.async_helpers import ObservableDeferred @@ -73,6 +73,7 @@ OG_TAG_VALUE_MAXLEN = 1000 ONE_HOUR = 60 * 60 * 1000 +ONE_DAY = 24 * ONE_HOUR @attr.s(slots=True, frozen=True, auto_attribs=True) @@ -255,10 +256,19 @@ async def _do_preview(self, url: str, user: str, ts: int) -> bytes: og = og.encode("utf8") return og - media_info = await self._download_url(url, user) + # If this URL can be accessed via oEmbed, use that instead. + url_to_download = url + oembed_url = self._oembed.get_oembed_url(url) + if oembed_url: + url_to_download = oembed_url + + media_info = await self._download_url(url_to_download, user) logger.debug("got media_info of '%s'", media_info) + # The number of milliseconds that the response should be considered valid. + expiration_ms = media_info.expires + if _is_media(media_info.media_type): file_id = media_info.filesystem_id dims = await self.media_repo._generate_thumbnails( @@ -288,34 +298,22 @@ async def _do_preview(self, url: str, user: str, ts: int) -> bytes: encoding = get_html_media_encoding(body, media_info.media_type) og = decode_and_calc_og(body, media_info.uri, encoding) - # pre-cache the image for posterity - # FIXME: it might be cleaner to use the same flow as the main /preview_url - # request itself and benefit from the same caching etc. But for now we - # just rely on the caching on the master request to speed things up. - if "og:image" in og and og["og:image"]: - image_info = await self._download_url( - _rebase_url(og["og:image"], media_info.uri), user - ) + await self._precache_image_url(user, media_info, og) + + elif oembed_url and _is_json(media_info.media_type): + # Handle an oEmbed response. + with open(media_info.filename, "rb") as file: + body = file.read() + + oembed_response = self._oembed.parse_oembed_response(media_info.uri, body) + og = oembed_response.open_graph_result + + # Use the cache age from the oEmbed result, instead of the HTTP response. + if oembed_response.cache_age is not None: + expiration_ms = oembed_response.cache_age + + await self._precache_image_url(user, media_info, og) - if _is_media(image_info.media_type): - # TODO: make sure we don't choke on white-on-transparent images - file_id = image_info.filesystem_id - dims = await self.media_repo._generate_thumbnails( - None, file_id, file_id, image_info.media_type, url_cache=True - ) - if dims: - og["og:image:width"] = dims["width"] - og["og:image:height"] = dims["height"] - else: - logger.warning("Couldn't get dims for %s", og["og:image"]) - - og[ - "og:image" - ] = f"mxc://{self.server_name}/{image_info.filesystem_id}" - og["og:image:type"] = image_info.media_type - og["matrix:image:size"] = image_info.media_length - else: - del og["og:image"] else: logger.warning("Failed to find any OG data in %s", url) og = {} @@ -336,12 +334,15 @@ async def _do_preview(self, url: str, user: str, ts: int) -> bytes: jsonog = json_encoder.encode(og) + # Cap the amount of time to consider a response valid. + expiration_ms = min(expiration_ms, ONE_DAY) + # store OG in history-aware DB cache await self.store.store_url_cache( url, media_info.response_code, media_info.etag, - media_info.expires + media_info.created_ts_ms, + media_info.created_ts_ms + expiration_ms, jsonog, media_info.filesystem_id, media_info.created_ts_ms, @@ -358,88 +359,52 @@ async def _download_url(self, url: str, user: str) -> MediaInfo: file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True) - # If this URL can be accessed via oEmbed, use that instead. - url_to_download: Optional[str] = url - oembed_url = self._oembed.get_oembed_url(url) - if oembed_url: - # The result might be a new URL to download, or it might be HTML content. + with self.media_storage.store_into_file(file_info) as (f, fname, finish): try: - oembed_result = await self._oembed.get_oembed_content(oembed_url, url) - if oembed_result.url: - url_to_download = oembed_result.url - elif oembed_result.html: - url_to_download = None - except OEmbedError: - # If an error occurs, try doing a normal preview. - pass + logger.debug("Trying to get preview for url '%s'", url) + length, headers, uri, code = await self.client.get_file( + url, + output_stream=f, + max_size=self.max_spider_size, + headers={"Accept-Language": self.url_preview_accept_language}, + ) + except SynapseError: + # Pass SynapseErrors through directly, so that the servlet + # handler will return a SynapseError to the client instead of + # blank data or a 500. + raise + except DNSLookupError: + # DNS lookup returned no results + # Note: This will also be the case if one of the resolved IP + # addresses is blacklisted + raise SynapseError( + 502, + "DNS resolution failure during URL preview generation", + Codes.UNKNOWN, + ) + except Exception as e: + # FIXME: pass through 404s and other error messages nicely + logger.warning("Error downloading %s: %r", url, e) - if url_to_download: - with self.media_storage.store_into_file(file_info) as (f, fname, finish): - try: - logger.debug("Trying to get preview for url '%s'", url_to_download) - length, headers, uri, code = await self.client.get_file( - url_to_download, - output_stream=f, - max_size=self.max_spider_size, - headers={"Accept-Language": self.url_preview_accept_language}, - ) - except SynapseError: - # Pass SynapseErrors through directly, so that the servlet - # handler will return a SynapseError to the client instead of - # blank data or a 500. - raise - except DNSLookupError: - # DNS lookup returned no results - # Note: This will also be the case if one of the resolved IP - # addresses is blacklisted - raise SynapseError( - 502, - "DNS resolution failure during URL preview generation", - Codes.UNKNOWN, - ) - except Exception as e: - # FIXME: pass through 404s and other error messages nicely - logger.warning("Error downloading %s: %r", url_to_download, e) - - raise SynapseError( - 500, - "Failed to download content: %s" - % (traceback.format_exception_only(sys.exc_info()[0], e),), - Codes.UNKNOWN, - ) - await finish() - - if b"Content-Type" in headers: - media_type = headers[b"Content-Type"][0].decode("ascii") - else: - media_type = "application/octet-stream" + raise SynapseError( + 500, + "Failed to download content: %s" + % (traceback.format_exception_only(sys.exc_info()[0], e),), + Codes.UNKNOWN, + ) + await finish() - download_name = get_filename_from_headers(headers) + if b"Content-Type" in headers: + media_type = headers[b"Content-Type"][0].decode("ascii") + else: + media_type = "application/octet-stream" - # FIXME: we should calculate a proper expiration based on the - # Cache-Control and Expire headers. But for now, assume 1 hour. - expires = ONE_HOUR - etag = ( - headers[b"ETag"][0].decode("ascii") if b"ETag" in headers else None - ) - else: - # we can only get here if we did an oembed request and have an oembed_result.html - assert oembed_result.html is not None - assert oembed_url is not None - - html_bytes = oembed_result.html.encode("utf-8") - with self.media_storage.store_into_file(file_info) as (f, fname, finish): - f.write(html_bytes) - await finish() - - media_type = "text/html" - download_name = oembed_result.title - length = len(html_bytes) - # If a specific cache age was not given, assume 1 hour. - expires = oembed_result.cache_age or ONE_HOUR - uri = oembed_url - code = 200 - etag = None + download_name = get_filename_from_headers(headers) + + # FIXME: we should calculate a proper expiration based on the + # Cache-Control and Expire headers. But for now, assume 1 hour. + expires = ONE_HOUR + etag = headers[b"ETag"][0].decode("ascii") if b"ETag" in headers else None try: time_now_ms = self.clock.time_msec() @@ -474,6 +439,46 @@ async def _download_url(self, url: str, user: str) -> MediaInfo: etag=etag, ) + async def _precache_image_url( + self, user: str, media_info: MediaInfo, og: JsonDict + ) -> None: + """ + Pre-cache the image (if one exists) for posterity + + Args: + user: The user requesting the preview. + media_info: The media being previewed. + og: The Open Graph dictionary. This is modified with image information. + """ + # If there's no image or it is blank, there's nothing to do. + if "og:image" not in og or not og["og:image"]: + return + + # FIXME: it might be cleaner to use the same flow as the main /preview_url + # request itself and benefit from the same caching etc. But for now we + # just rely on the caching on the master request to speed things up. + image_info = await self._download_url( + _rebase_url(og["og:image"], media_info.uri), user + ) + + if _is_media(image_info.media_type): + # TODO: make sure we don't choke on white-on-transparent images + file_id = image_info.filesystem_id + dims = await self.media_repo._generate_thumbnails( + None, file_id, file_id, image_info.media_type, url_cache=True + ) + if dims: + og["og:image:width"] = dims["width"] + og["og:image:height"] = dims["height"] + else: + logger.warning("Couldn't get dims for %s", og["og:image"]) + + og["og:image"] = f"mxc://{self.server_name}/{image_info.filesystem_id}" + og["og:image:type"] = image_info.media_type + og["matrix:image:size"] = image_info.media_length + else: + del og["og:image"] + def _start_expire_url_cache_data(self) -> Deferred: return run_as_background_process( "expire_url_cache_data", self._expire_url_cache_data @@ -527,7 +532,7 @@ async def _expire_url_cache_data(self) -> None: # These may be cached for a bit on the client (i.e., they # may have a room open with a preview url thing open). # So we wait a couple of days before deleting, just in case. - expire_before = now - 2 * 24 * ONE_HOUR + expire_before = now - 2 * ONE_DAY media_ids = await self.store.get_url_cache_media_before(expire_before) removed_media = [] @@ -669,7 +674,18 @@ def _attempt_calc_og(body_attempt: Union[bytes, str]) -> Dict[str, Optional[str] def _calc_og(tree: "etree.Element", media_uri: str) -> Dict[str, Optional[str]]: - # suck our tree into lxml and define our OG response. + """ + Calculate metadata for an HTML document. + + This uses lxml to search the HTML document for Open Graph data. + + Args: + tree: The parsed HTML document. + media_url: The URI used to download the body. + + Returns: + The Open Graph response as a dictionary. + """ # if we see any image URLs in the OG response, then spider them # (although the client could choose to do this by asking for previews of those @@ -743,35 +759,7 @@ def _calc_og(tree: "etree.Element", media_uri: str) -> Dict[str, Optional[str]]: if meta_description: og["og:description"] = meta_description[0] else: - # grab any text nodes which are inside the tag... - # unless they are within an HTML5 semantic markup tag... - #
,