From 6adb89ff007500ea9c41fb5bd1a9e644cc6397cd Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 2 Mar 2022 06:56:16 -0500 Subject: [PATCH 001/230] Improve and refactor the tests for relations. (#12113) * Modernizes code (f-strings, etc.) * Fixes incorrect comments. * Splits the test case into two. * Factors out some duplicated code. --- changelog.d/12113.misc | 1 + tests/rest/client/test_relations.py | 386 +++++++++++++--------------- 2 files changed, 179 insertions(+), 208 deletions(-) create mode 100644 changelog.d/12113.misc diff --git a/changelog.d/12113.misc b/changelog.d/12113.misc new file mode 100644 index 000000000000..102e064053c2 --- /dev/null +++ b/changelog.d/12113.misc @@ -0,0 +1 @@ +Refactor the tests for event relations. diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index c8db45719e07..a087cd7b2149 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -34,7 +34,7 @@ from tests.test_utils.event_injection import inject_event -class RelationsTestCase(unittest.HomeserverTestCase): +class BaseRelationsTestCase(unittest.HomeserverTestCase): servlets = [ relations.register_servlets, room.register_servlets, @@ -48,7 +48,6 @@ class RelationsTestCase(unittest.HomeserverTestCase): def default_config(self) -> dict: # We need to enable msc1849 support for aggregations config = super().default_config() - config["experimental_msc1849_support_enabled"] = True # We enable frozen dicts as relations/edits change event contents, so we # want to test that we don't modify the events in the caches. @@ -67,10 +66,62 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: res = self.helper.send(self.room, body="Hi!", tok=self.user_token) self.parent_id = res["event_id"] - def test_send_relation(self) -> None: - """Tests that sending a relation using the new /send_relation works - creates the right shape of event. + def _create_user(self, localpart: str) -> Tuple[str, str]: + user_id = self.register_user(localpart, "abc123") + access_token = self.login(localpart, "abc123") + + return user_id, access_token + + def _send_relation( + self, + relation_type: str, + event_type: str, + key: Optional[str] = None, + content: Optional[dict] = None, + access_token: Optional[str] = None, + parent_id: Optional[str] = None, + ) -> FakeChannel: + """Helper function to send a relation pointing at `self.parent_id` + + Args: + relation_type: One of `RelationTypes` + event_type: The type of the event to create + key: The aggregation key used for m.annotation relation type. + content: The content of the created event. Will be modified to configure + the m.relates_to key based on the other provided parameters. + access_token: The access token used to send the relation, defaults + to `self.user_token` + parent_id: The event_id this relation relates to. 
If None, then self.parent_id + + Returns: + FakeChannel """ + if not access_token: + access_token = self.user_token + + original_id = parent_id if parent_id else self.parent_id + + if content is None: + content = {} + content["m.relates_to"] = { + "event_id": original_id, + "rel_type": relation_type, + } + if key is not None: + content["m.relates_to"]["key"] = key + + channel = self.make_request( + "POST", + f"/_matrix/client/v3/rooms/{self.room}/send/{event_type}", + content, + access_token=access_token, + ) + return channel + + +class RelationsTestCase(BaseRelationsTestCase): + def test_send_relation(self) -> None: + """Tests that sending a relation works.""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍") self.assertEqual(200, channel.code, channel.json_body) @@ -79,7 +130,7 @@ def test_send_relation(self) -> None: channel = self.make_request( "GET", - "/rooms/%s/event/%s" % (self.room, event_id), + f"/rooms/{self.room}/event/{event_id}", access_token=self.user_token, ) self.assertEqual(200, channel.code, channel.json_body) @@ -317,9 +368,7 @@ def test_pagination_from_sync_and_messages(self) -> None: # Request /sync, limiting it such that only the latest event is returned # (and not the relation). - filter = urllib.parse.quote_plus( - '{"room": {"timeline": {"limit": 1}}}'.encode() - ) + filter = urllib.parse.quote_plus(b'{"room": {"timeline": {"limit": 1}}}') channel = self.make_request( "GET", f"/sync?filter={filter}", access_token=self.user_token ) @@ -404,8 +453,7 @@ def test_aggregation_pagination_groups(self) -> None: channel = self.make_request( "GET", - "/_matrix/client/unstable/rooms/%s/aggregations/%s?limit=1%s" - % (self.room, self.parent_id, from_token), + f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}?limit=1{from_token}", access_token=self.user_token, ) self.assertEqual(200, channel.code, channel.json_body) @@ -544,8 +592,7 @@ def test_aggregation(self) -> None: channel = self.make_request( "GET", - "/_matrix/client/unstable/rooms/%s/aggregations/%s" - % (self.room, self.parent_id), + f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}", access_token=self.user_token, ) self.assertEqual(200, channel.code, channel.json_body) @@ -560,47 +607,13 @@ def test_aggregation(self) -> None: }, ) - def test_aggregation_redactions(self) -> None: - """Test that annotations get correctly aggregated after a redaction.""" - - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self.assertEqual(200, channel.code, channel.json_body) - to_redact_event_id = channel.json_body["event_id"] - - channel = self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token - ) - self.assertEqual(200, channel.code, channel.json_body) - - # Now lets redact one of the 'a' reactions - channel = self.make_request( - "POST", - "/_matrix/client/r0/rooms/%s/redact/%s" % (self.room, to_redact_event_id), - access_token=self.user_token, - content={}, - ) - self.assertEqual(200, channel.code, channel.json_body) - - channel = self.make_request( - "GET", - "/_matrix/client/unstable/rooms/%s/aggregations/%s" - % (self.room, self.parent_id), - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - - self.assertEqual( - channel.json_body, - {"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]}, - ) - def test_aggregation_must_be_annotation(self) -> None: """Test that aggregations must be annotations.""" channel = self.make_request( "GET", 
- "/_matrix/client/unstable/rooms/%s/aggregations/%s/%s?limit=1" - % (self.room, self.parent_id, RelationTypes.REPLACE), + f"/_matrix/client/unstable/rooms/{self.room}/aggregations" + f"/{self.parent_id}/{RelationTypes.REPLACE}?limit=1", access_token=self.user_token, ) self.assertEqual(400, channel.code, channel.json_body) @@ -986,9 +999,7 @@ def assert_bundle(event_json: JsonDict) -> None: # Request sync, but limit the timeline so it becomes limited (and includes # bundled aggregations). - filter = urllib.parse.quote_plus( - '{"room": {"timeline": {"limit": 2}}}'.encode() - ) + filter = urllib.parse.quote_plus(b'{"room": {"timeline": {"limit": 2}}}') channel = self.make_request( "GET", f"/sync?filter={filter}", access_token=self.user_token ) @@ -1053,7 +1064,7 @@ def test_multi_edit(self) -> None: channel = self.make_request( "GET", - "/rooms/%s/event/%s" % (self.room, self.parent_id), + f"/rooms/{self.room}/event/{self.parent_id}", access_token=self.user_token, ) self.assertEqual(200, channel.code, channel.json_body) @@ -1096,7 +1107,7 @@ def test_edit_reply(self) -> None: channel = self.make_request( "GET", - "/rooms/%s/event/%s" % (self.room, reply), + f"/rooms/{self.room}/event/{reply}", access_token=self.user_token, ) self.assertEqual(200, channel.code, channel.json_body) @@ -1198,7 +1209,7 @@ def test_edit_edit(self) -> None: # Request the original event. channel = self.make_request( "GET", - "/rooms/%s/event/%s" % (self.room, self.parent_id), + f"/rooms/{self.room}/event/{self.parent_id}", access_token=self.user_token, ) self.assertEqual(200, channel.code, channel.json_body) @@ -1217,102 +1228,6 @@ def test_edit_edit(self) -> None: {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict ) - def test_relations_redaction_redacts_edits(self) -> None: - """Test that edits of an event are redacted when the original event - is redacted. 
- """ - # Send a new event - res = self.helper.send(self.room, body="Heyo!", tok=self.user_token) - original_event_id = res["event_id"] - - # Add a relation - channel = self._send_relation( - RelationTypes.REPLACE, - "m.room.message", - parent_id=original_event_id, - content={ - "msgtype": "m.text", - "body": "Wibble", - "m.new_content": {"msgtype": "m.text", "body": "First edit"}, - }, - ) - self.assertEqual(200, channel.code, channel.json_body) - - # Check the relation is returned - channel = self.make_request( - "GET", - "/_matrix/client/unstable/rooms/%s/relations/%s/m.replace/m.room.message" - % (self.room, original_event_id), - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - - self.assertIn("chunk", channel.json_body) - self.assertEqual(len(channel.json_body["chunk"]), 1) - - # Redact the original event - channel = self.make_request( - "PUT", - "/rooms/%s/redact/%s/%s" - % (self.room, original_event_id, "test_relations_redaction_redacts_edits"), - access_token=self.user_token, - content="{}", - ) - self.assertEqual(200, channel.code, channel.json_body) - - # Try to check for remaining m.replace relations - channel = self.make_request( - "GET", - "/_matrix/client/unstable/rooms/%s/relations/%s/m.replace/m.room.message" - % (self.room, original_event_id), - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - - # Check that no relations are returned - self.assertIn("chunk", channel.json_body) - self.assertEqual(channel.json_body["chunk"], []) - - def test_aggregations_redaction_prevents_access_to_aggregations(self) -> None: - """Test that annotations of an event are redacted when the original event - is redacted. - """ - # Send a new event - res = self.helper.send(self.room, body="Hello!", tok=self.user_token) - original_event_id = res["event_id"] - - # Add a relation - channel = self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", key="👍", parent_id=original_event_id - ) - self.assertEqual(200, channel.code, channel.json_body) - - # Redact the original - channel = self.make_request( - "PUT", - "/rooms/%s/redact/%s/%s" - % ( - self.room, - original_event_id, - "test_aggregations_redaction_prevents_access_to_aggregations", - ), - access_token=self.user_token, - content="{}", - ) - self.assertEqual(200, channel.code, channel.json_body) - - # Check that aggregations returns zero - channel = self.make_request( - "GET", - "/_matrix/client/unstable/rooms/%s/aggregations/%s/m.annotation/m.reaction" - % (self.room, original_event_id), - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - - self.assertIn("chunk", channel.json_body) - self.assertEqual(channel.json_body["chunk"], []) - def test_unknown_relations(self) -> None: """Unknown relations should be accepted.""" channel = self._send_relation("m.relation.test", "m.room.test") @@ -1321,8 +1236,7 @@ def test_unknown_relations(self) -> None: channel = self.make_request( "GET", - "/_matrix/client/unstable/rooms/%s/relations/%s?limit=1" - % (self.room, self.parent_id), + f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1", access_token=self.user_token, ) self.assertEqual(200, channel.code, channel.json_body) @@ -1343,7 +1257,7 @@ def test_unknown_relations(self) -> None: # When bundling the unknown relation is not included. 
channel = self.make_request( "GET", - "/rooms/%s/event/%s" % (self.room, self.parent_id), + f"/rooms/{self.room}/event/{self.parent_id}", access_token=self.user_token, ) self.assertEqual(200, channel.code, channel.json_body) @@ -1352,8 +1266,7 @@ def test_unknown_relations(self) -> None: # But unknown relations can be directly queried. channel = self.make_request( "GET", - "/_matrix/client/unstable/rooms/%s/aggregations/%s?limit=1" - % (self.room, self.parent_id), + f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}?limit=1", access_token=self.user_token, ) self.assertEqual(200, channel.code, channel.json_body) @@ -1369,58 +1282,6 @@ def _find_event_in_chunk(self, events: List[JsonDict]) -> JsonDict: raise AssertionError(f"Event {self.parent_id} not found in chunk") - def _send_relation( - self, - relation_type: str, - event_type: str, - key: Optional[str] = None, - content: Optional[dict] = None, - access_token: Optional[str] = None, - parent_id: Optional[str] = None, - ) -> FakeChannel: - """Helper function to send a relation pointing at `self.parent_id` - - Args: - relation_type: One of `RelationTypes` - event_type: The type of the event to create - key: The aggregation key used for m.annotation relation type. - content: The content of the created event. Will be modified to configure - the m.relates_to key based on the other provided parameters. - access_token: The access token used to send the relation, defaults - to `self.user_token` - parent_id: The event_id this relation relates to. If None, then self.parent_id - - Returns: - FakeChannel - """ - if not access_token: - access_token = self.user_token - - original_id = parent_id if parent_id else self.parent_id - - if content is None: - content = {} - content["m.relates_to"] = { - "event_id": original_id, - "rel_type": relation_type, - } - if key is not None: - content["m.relates_to"]["key"] = key - - channel = self.make_request( - "POST", - f"/_matrix/client/v3/rooms/{self.room}/send/{event_type}", - content, - access_token=access_token, - ) - return channel - - def _create_user(self, localpart: str) -> Tuple[str, str]: - user_id = self.register_user(localpart, "abc123") - access_token = self.login(localpart, "abc123") - - return user_id, access_token - def test_background_update(self) -> None: """Test the event_arbitrary_relations background update.""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍") @@ -1482,3 +1343,112 @@ def test_background_update(self) -> None: [ev["event_id"] for ev in channel.json_body["chunk"]], [annotation_event_id_good, thread_event_id], ) + + +class RelationRedactionTestCase(BaseRelationsTestCase): + """Test the behaviour of relations when the parent or child event is redacted.""" + + def _redact(self, event_id: str) -> None: + channel = self.make_request( + "POST", + f"/_matrix/client/r0/rooms/{self.room}/redact/{event_id}", + access_token=self.user_token, + content={}, + ) + self.assertEqual(200, channel.code, channel.json_body) + + def test_redact_relation_annotation(self) -> None: + """Test that annotations of an event are properly handled after the + annotation is redacted. 
+ """ + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + self.assertEqual(200, channel.code, channel.json_body) + to_redact_event_id = channel.json_body["event_id"] + + channel = self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token + ) + self.assertEqual(200, channel.code, channel.json_body) + + # Redact one of the reactions. + self._redact(to_redact_event_id) + + # Ensure that the aggregations are correct. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + + self.assertEqual( + channel.json_body, + {"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]}, + ) + + def test_redact_relation_edit(self) -> None: + """Test that edits of an event are redacted when the original event + is redacted. + """ + # Add a relation + channel = self._send_relation( + RelationTypes.REPLACE, + "m.room.message", + parent_id=self.parent_id, + content={ + "msgtype": "m.text", + "body": "Wibble", + "m.new_content": {"msgtype": "m.text", "body": "First edit"}, + }, + ) + self.assertEqual(200, channel.code, channel.json_body) + + # Check the relation is returned + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/relations" + f"/{self.parent_id}/m.replace/m.room.message", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + + self.assertIn("chunk", channel.json_body) + self.assertEqual(len(channel.json_body["chunk"]), 1) + + # Redact the original event + self._redact(self.parent_id) + + # Try to check for remaining m.replace relations + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/relations" + f"/{self.parent_id}/m.replace/m.room.message", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + + # Check that no relations are returned + self.assertIn("chunk", channel.json_body) + self.assertEqual(channel.json_body["chunk"], []) + + def test_redact_parent(self) -> None: + """Test that annotations of an event are redacted when the original event + is redacted. + """ + # Add a relation + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍") + self.assertEqual(200, channel.code, channel.json_body) + + # Redact the original event. 
+ self._redact(self.parent_id) + + # Check that aggregations returns zero + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}/m.annotation/m.reaction", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + + self.assertIn("chunk", channel.json_body) + self.assertEqual(channel.json_body["chunk"], []) From f3f0ab10fe766c766dedf9d80e4ef198e3e45c09 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Wed, 2 Mar 2022 13:00:16 +0000 Subject: [PATCH 002/230] Move scripts directory inside synapse, exposing as setuptools entry_points (#12118) * Two scripts are basically entry_points already * Move and rename scripts/* to synapse/_scripts/*.py * Delete sync_room_to_group.pl * Expose entry points in setup.py * Update linter script and config * Fixup scripts & docs mentioning scripts that moved Co-authored-by: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> --- .ci/scripts/test_export_data_command.sh | 4 +- .ci/scripts/test_synapse_port_db.sh | 12 ++--- .dockerignore | 1 - MANIFEST.in | 1 - changelog.d/12118.misc | 1 + docker/Dockerfile | 1 - docs/development/database_schema.md | 6 +-- docs/usage/administration/admin_api/README.md | 2 +- mypy.ini | 4 ++ scripts-dev/generate_sample_config | 10 ++--- scripts-dev/lint.sh | 7 --- scripts-dev/make_full_schema.sh | 6 +-- scripts/register_new_matrix_user | 19 -------- scripts/synapse_review_recent_signups | 19 -------- scripts/sync_room_to_group.pl | 45 ------------------- setup.py | 14 +++++- snap/snapcraft.yaml | 2 +- .../_scripts/export_signing_key.py | 7 ++- .../_scripts/generate_config.py | 7 ++- .../_scripts/generate_log_config.py | 7 ++- .../_scripts}/generate_signing_key.py | 7 ++- .../_scripts/hash_password.py | 12 +++-- .../move_remote_media_to_new_store.py | 2 +- .../_scripts/synapse_port_db.py | 6 ++- .../_scripts/update_synapse_database.py | 0 synapse/config/_base.py | 2 +- tox.ini | 8 ---- 27 files changed, 77 insertions(+), 135 deletions(-) create mode 100644 changelog.d/12118.misc delete mode 100755 scripts/register_new_matrix_user delete mode 100755 scripts/synapse_review_recent_signups delete mode 100755 scripts/sync_room_to_group.pl rename scripts/export_signing_key => synapse/_scripts/export_signing_key.py (99%) rename scripts/generate_config => synapse/_scripts/generate_config.py (98%) rename scripts/generate_log_config => synapse/_scripts/generate_log_config.py (98%) rename {scripts => synapse/_scripts}/generate_signing_key.py (97%) rename scripts/hash_password => synapse/_scripts/hash_password.py (96%) rename {scripts => synapse/_scripts}/move_remote_media_to_new_store.py (97%) rename scripts/synapse_port_db => synapse/_scripts/synapse_port_db.py (99%) rename scripts/update_synapse_database => synapse/_scripts/update_synapse_database.py (100%) diff --git a/.ci/scripts/test_export_data_command.sh b/.ci/scripts/test_export_data_command.sh index ab96387a0aef..224cae921658 100755 --- a/.ci/scripts/test_export_data_command.sh +++ b/.ci/scripts/test_export_data_command.sh @@ -21,7 +21,7 @@ python -m synapse.app.homeserver --generate-keys -c .ci/sqlite-config.yaml echo "--- Prepare test database" # Make sure the SQLite3 database is using the latest schema and has no pending background update. 
-scripts/update_synapse_database --database-config .ci/sqlite-config.yaml --run-background-updates +update_synapse_database --database-config .ci/sqlite-config.yaml --run-background-updates # Run the export-data command on the sqlite test database python -m synapse.app.admin_cmd -c .ci/sqlite-config.yaml export-data @anon-20191002_181700-832:localhost:8800 \ @@ -41,7 +41,7 @@ fi # Port the SQLite databse to postgres so we can check command works against postgres echo "+++ Port SQLite3 databse to postgres" -scripts/synapse_port_db --sqlite-database .ci/test_db.db --postgres-config .ci/postgres-config.yaml +synapse_port_db --sqlite-database .ci/test_db.db --postgres-config .ci/postgres-config.yaml # Run the export-data command on postgres database python -m synapse.app.admin_cmd -c .ci/postgres-config.yaml export-data @anon-20191002_181700-832:localhost:8800 \ diff --git a/.ci/scripts/test_synapse_port_db.sh b/.ci/scripts/test_synapse_port_db.sh index 797904e64ca5..91bd966f32bd 100755 --- a/.ci/scripts/test_synapse_port_db.sh +++ b/.ci/scripts/test_synapse_port_db.sh @@ -25,17 +25,19 @@ python -m synapse.app.homeserver --generate-keys -c .ci/sqlite-config.yaml echo "--- Prepare test database" # Make sure the SQLite3 database is using the latest schema and has no pending background update. -scripts/update_synapse_database --database-config .ci/sqlite-config.yaml --run-background-updates +update_synapse_database --database-config .ci/sqlite-config.yaml --run-background-updates # Create the PostgreSQL database. .ci/scripts/postgres_exec.py "CREATE DATABASE synapse" echo "+++ Run synapse_port_db against test database" -coverage run scripts/synapse_port_db --sqlite-database .ci/test_db.db --postgres-config .ci/postgres-config.yaml +# TODO: this invocation of synapse_port_db (and others below) used to be prepended with `coverage run`, +# but coverage seems unable to find the entrypoints installed by `pip install -e .`. +synapse_port_db --sqlite-database .ci/test_db.db --postgres-config .ci/postgres-config.yaml # We should be able to run twice against the same database. echo "+++ Run synapse_port_db a second time" -coverage run scripts/synapse_port_db --sqlite-database .ci/test_db.db --postgres-config .ci/postgres-config.yaml +synapse_port_db --sqlite-database .ci/test_db.db --postgres-config .ci/postgres-config.yaml ##### @@ -46,7 +48,7 @@ echo "--- Prepare empty SQLite database" # we do this by deleting the sqlite db, and then doing the same again. rm .ci/test_db.db -scripts/update_synapse_database --database-config .ci/sqlite-config.yaml --run-background-updates +update_synapse_database --database-config .ci/sqlite-config.yaml --run-background-updates # re-create the PostgreSQL database. 
.ci/scripts/postgres_exec.py \ @@ -54,4 +56,4 @@ scripts/update_synapse_database --database-config .ci/sqlite-config.yaml --run-b "CREATE DATABASE synapse" echo "+++ Run synapse_port_db against empty database" -coverage run scripts/synapse_port_db --sqlite-database .ci/test_db.db --postgres-config .ci/postgres-config.yaml +synapse_port_db --sqlite-database .ci/test_db.db --postgres-config .ci/postgres-config.yaml diff --git a/.dockerignore b/.dockerignore index f6c638b0a221..617f7015971b 100644 --- a/.dockerignore +++ b/.dockerignore @@ -3,7 +3,6 @@ # things to include !docker -!scripts !synapse !MANIFEST.in !README.rst diff --git a/MANIFEST.in b/MANIFEST.in index 76d14eb642b2..7e903518e152 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -17,7 +17,6 @@ recursive-include synapse/storage *.txt recursive-include synapse/storage *.md recursive-include docs * -recursive-include scripts * recursive-include scripts-dev * recursive-include synapse *.pyi recursive-include tests *.py diff --git a/changelog.d/12118.misc b/changelog.d/12118.misc new file mode 100644 index 000000000000..a2c397d90755 --- /dev/null +++ b/changelog.d/12118.misc @@ -0,0 +1 @@ +Move scripts to Synapse package and expose as setuptools entry points. diff --git a/docker/Dockerfile b/docker/Dockerfile index a8bb9b0e7f7f..327275a9cae6 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -46,7 +46,6 @@ RUN \ && rm -rf /var/lib/apt/lists/* # Copy just what we need to pip install -COPY scripts /synapse/scripts/ COPY MANIFEST.in README.rst setup.py synctl /synapse/ COPY synapse/__init__.py /synapse/synapse/__init__.py COPY synapse/python_dependencies.py /synapse/synapse/python_dependencies.py diff --git a/docs/development/database_schema.md b/docs/development/database_schema.md index a767d3af9fd3..d996a7caa2c6 100644 --- a/docs/development/database_schema.md +++ b/docs/development/database_schema.md @@ -158,9 +158,9 @@ same as integers. There are three separate aspects to this: * Any new boolean column must be added to the `BOOLEAN_COLUMNS` list in - `scripts/synapse_port_db`. This tells the port script to cast the integer - value from SQLite to a boolean before writing the value to the postgres - database. + `synapse/_scripts/synapse_port_db.py`. This tells the port script to cast + the integer value from SQLite to a boolean before writing the value to the + postgres database. * Before SQLite 3.23, `TRUE` and `FALSE` were not recognised as constants by SQLite, and the `IS [NOT] TRUE`/`IS [NOT] FALSE` operators were not diff --git a/docs/usage/administration/admin_api/README.md b/docs/usage/administration/admin_api/README.md index 2fca96f8be4e..3cbedc5dfa30 100644 --- a/docs/usage/administration/admin_api/README.md +++ b/docs/usage/administration/admin_api/README.md @@ -12,7 +12,7 @@ UPDATE users SET admin = 1 WHERE name = '@foo:bar.com'; ``` A new server admin user can also be created using the `register_new_matrix_user` -command. This is a script that is located in the `scripts/` directory, or possibly +command. This is a script that is distributed as part of synapse. It is possibly already on your `$PATH` depending on how Synapse was installed. Finding your user's `access_token` is client-dependent, but will usually be shown in the client's settings. 
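For reference, a typical invocation of the installed `register_new_matrix_user` entry point looks something like this (a sketch only; the exact flags, config path and homeserver URL depend on your deployment, so check `register_new_matrix_user --help`):

```sh
# Register a new admin user against a locally running homeserver,
# reading the registration shared secret from the homeserver config.
register_new_matrix_user -u alice -p 'a-strong-password' -a \
    -c /etc/matrix-synapse/homeserver.yaml http://localhost:8008
```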
diff --git a/mypy.ini b/mypy.ini index 38ff78760931..6b1e995e64e8 100644 --- a/mypy.ini +++ b/mypy.ini @@ -23,6 +23,10 @@ files = # https://docs.python.org/3/library/re.html#re.X exclude = (?x) ^( + |synapse/_scripts/export_signing_key.py + |synapse/_scripts/move_remote_media_to_new_store.py + |synapse/_scripts/synapse_port_db.py + |synapse/_scripts/update_synapse_database.py |synapse/storage/databases/__init__.py |synapse/storage/databases/main/__init__.py |synapse/storage/databases/main/cache.py diff --git a/scripts-dev/generate_sample_config b/scripts-dev/generate_sample_config index 4cd1d1d5b829..185e277933e3 100755 --- a/scripts-dev/generate_sample_config +++ b/scripts-dev/generate_sample_config @@ -10,19 +10,19 @@ SAMPLE_CONFIG="docs/sample_config.yaml" SAMPLE_LOG_CONFIG="docs/sample_log_config.yaml" check() { - diff -u "$SAMPLE_LOG_CONFIG" <(./scripts/generate_log_config) >/dev/null || return 1 + diff -u "$SAMPLE_LOG_CONFIG" <(synapse/_scripts/generate_log_config.py) >/dev/null || return 1 } if [ "$1" == "--check" ]; then - diff -u "$SAMPLE_CONFIG" <(./scripts/generate_config --header-file docs/.sample_config_header.yaml) >/dev/null || { + diff -u "$SAMPLE_CONFIG" <(synapse/_scripts/generate_config.py --header-file docs/.sample_config_header.yaml) >/dev/null || { echo -e "\e[1m\e[31m$SAMPLE_CONFIG is not up-to-date. Regenerate it with \`scripts-dev/generate_sample_config\`.\e[0m" >&2 exit 1 } - diff -u "$SAMPLE_LOG_CONFIG" <(./scripts/generate_log_config) >/dev/null || { + diff -u "$SAMPLE_LOG_CONFIG" <(synapse/_scripts/generate_log_config.py) >/dev/null || { echo -e "\e[1m\e[31m$SAMPLE_LOG_CONFIG is not up-to-date. Regenerate it with \`scripts-dev/generate_sample_config\`.\e[0m" >&2 exit 1 } else - ./scripts/generate_config --header-file docs/.sample_config_header.yaml -o "$SAMPLE_CONFIG" - ./scripts/generate_log_config -o "$SAMPLE_LOG_CONFIG" + synapse/_scripts/generate_config.py --header-file docs/.sample_config_header.yaml -o "$SAMPLE_CONFIG" + synapse/_scripts/generate_log_config.py -o "$SAMPLE_LOG_CONFIG" fi diff --git a/scripts-dev/lint.sh b/scripts-dev/lint.sh index b6554a73c115..df4d4934d06f 100755 --- a/scripts-dev/lint.sh +++ b/scripts-dev/lint.sh @@ -84,13 +84,6 @@ else files=( "synapse" "docker" "tests" # annoyingly, black doesn't find these so we have to list them - "scripts/export_signing_key" - "scripts/generate_config" - "scripts/generate_log_config" - "scripts/hash_password" - "scripts/register_new_matrix_user" - "scripts/synapse_port_db" - "scripts/update_synapse_database" "scripts-dev" "scripts-dev/build_debian_packages" "scripts-dev/sign_json" diff --git a/scripts-dev/make_full_schema.sh b/scripts-dev/make_full_schema.sh index c3c90f4ec637..f0e22d4ca25b 100755 --- a/scripts-dev/make_full_schema.sh +++ b/scripts-dev/make_full_schema.sh @@ -147,7 +147,7 @@ python -m synapse.app.homeserver --generate-keys -c "$SQLITE_CONFIG" # Make sure the SQLite3 database is using the latest schema and has no pending background update. echo "Running db background jobs..." -scripts/update_synapse_database --database-config --run-background-updates "$SQLITE_CONFIG" +synapse/_scripts/update_synapse_database.py --database-config --run-background-updates "$SQLITE_CONFIG" # Create the PostgreSQL database. echo "Creating postgres database..." @@ -156,10 +156,10 @@ createdb --lc-collate=C --lc-ctype=C --template=template0 "$POSTGRES_DB_NAME" echo "Copying data from SQLite3 to Postgres with synapse_port_db..." 
if [ -z "$COVERAGE" ]; then # No coverage needed - scripts/synapse_port_db --sqlite-database "$SQLITE_DB" --postgres-config "$POSTGRES_CONFIG" + synapse/_scripts/synapse_port_db.py --sqlite-database "$SQLITE_DB" --postgres-config "$POSTGRES_CONFIG" else # Coverage desired - coverage run scripts/synapse_port_db --sqlite-database "$SQLITE_DB" --postgres-config "$POSTGRES_CONFIG" + coverage run synapse/_scripts/synapse_port_db.py --sqlite-database "$SQLITE_DB" --postgres-config "$POSTGRES_CONFIG" fi # Delete schema_version, applied_schema_deltas and applied_module_schemas tables diff --git a/scripts/register_new_matrix_user b/scripts/register_new_matrix_user deleted file mode 100755 index 00104b9d62cb..000000000000 --- a/scripts/register_new_matrix_user +++ /dev/null @@ -1,19 +0,0 @@ -#!/usr/bin/env python -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from synapse._scripts.register_new_matrix_user import main - -if __name__ == "__main__": - main() diff --git a/scripts/synapse_review_recent_signups b/scripts/synapse_review_recent_signups deleted file mode 100755 index a36d46e14cde..000000000000 --- a/scripts/synapse_review_recent_signups +++ /dev/null @@ -1,19 +0,0 @@ -#!/usr/bin/env python -# Copyright 2021 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from synapse._scripts.review_recent_signups import main - -if __name__ == "__main__": - main() diff --git a/scripts/sync_room_to_group.pl b/scripts/sync_room_to_group.pl deleted file mode 100755 index f0c2dfadfa11..000000000000 --- a/scripts/sync_room_to_group.pl +++ /dev/null @@ -1,45 +0,0 @@ -#!/usr/bin/env perl - -use strict; -use warnings; - -use JSON::XS; -use LWP::UserAgent; -use URI::Escape; - -if (@ARGV < 4) { - die "usage: $0 \n"; -} - -my ($hs, $access_token, $room_id, $group_id) = @ARGV; -my $ua = LWP::UserAgent->new(); -$ua->timeout(10); - -if ($room_id =~ /^#/) { - $room_id = uri_escape($room_id); - $room_id = decode_json($ua->get("${hs}/_matrix/client/r0/directory/room/${room_id}?access_token=${access_token}")->decoded_content)->{room_id}; -} - -my $room_users = [ keys %{decode_json($ua->get("${hs}/_matrix/client/r0/rooms/${room_id}/joined_members?access_token=${access_token}")->decoded_content)->{joined}} ]; -my $group_users = [ - (map { $_->{user_id} } @{decode_json($ua->get("${hs}/_matrix/client/unstable/groups/${group_id}/users?access_token=${access_token}" )->decoded_content)->{chunk}}), - (map { $_->{user_id} } @{decode_json($ua->get("${hs}/_matrix/client/unstable/groups/${group_id}/invited_users?access_token=${access_token}" )->decoded_content)->{chunk}}), -]; - -die "refusing to sync from empty room" unless (@$room_users); -die "refusing to sync to empty group" unless (@$group_users); - -my $diff = {}; -foreach my $user (@$room_users) { $diff->{$user}++ } -foreach my $user (@$group_users) { $diff->{$user}-- } - -foreach my $user (keys %$diff) { - if ($diff->{$user} == 1) { - warn "inviting $user"; - print STDERR $ua->put("${hs}/_matrix/client/unstable/groups/${group_id}/admin/users/invite/${user}?access_token=${access_token}", Content=>'{}')->status_line."\n"; - } - elsif ($diff->{$user} == -1) { - warn "removing $user"; - print STDERR $ua->put("${hs}/_matrix/client/unstable/groups/${group_id}/admin/users/remove/${user}?access_token=${access_token}", Content=>'{}')->status_line."\n"; - } -} diff --git a/setup.py b/setup.py index 26f4650348d5..318df16766ec 100755 --- a/setup.py +++ b/setup.py @@ -15,7 +15,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import glob import os from typing import Any, Dict @@ -153,8 +152,19 @@ def exec_file(path_segments): python_requires="~=3.7", entry_points={ "console_scripts": [ + # Application "synapse_homeserver = synapse.app.homeserver:main", "synapse_worker = synapse.app.generic_worker:main", + # Scripts + "export_signing_key = synapse._scripts.export_signing_key:main", + "generate_config = synapse._scripts.generate_config:main", + "generate_log_config = synapse._scripts.generate_log_config:main", + "generate_signing_key = synapse._scripts.generate_signing_key:main", + "hash_password = synapse._scripts.hash_password:main", + "register_new_matrix_user = synapse._scripts.register_new_matrix_user:main", + "synapse_port_db = synapse._scripts.synapse_port_db:main", + "synapse_review_recent_signups = synapse._scripts.review_recent_signups:main", + "update_synapse_database = synapse._scripts.update_synapse_database:main", ] }, classifiers=[ @@ -167,6 +177,6 @@ def exec_file(path_segments): "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", ], - scripts=["synctl"] + glob.glob("scripts/*"), + scripts=["synctl"], cmdclass={"test": TestCommand}, ) diff --git a/snap/snapcraft.yaml b/snap/snapcraft.yaml index 9a01152c156b..dd4c8478d59d 100644 --- a/snap/snapcraft.yaml +++ b/snap/snapcraft.yaml @@ -20,7 +20,7 @@ apps: generate-config: command: generate_config generate-signing-key: - command: generate_signing_key.py + command: generate_signing_key register-new-matrix-user: command: register_new_matrix_user plugs: [network] diff --git a/scripts/export_signing_key b/synapse/_scripts/export_signing_key.py similarity index 99% rename from scripts/export_signing_key rename to synapse/_scripts/export_signing_key.py index bf0139bd64b8..3d254348f165 100755 --- a/scripts/export_signing_key +++ b/synapse/_scripts/export_signing_key.py @@ -50,7 +50,7 @@ def format_for_config(public_key: nacl.signing.VerifyKey, expiry_ts: int): ) -if __name__ == "__main__": +def main(): parser = argparse.ArgumentParser() parser.add_argument( @@ -85,7 +85,6 @@ def format_for_config(public_key: nacl.signing.VerifyKey, expiry_ts: int): else format_plain ) - keys = [] for file in args.key_file: try: res = read_signing_keys(file) @@ -98,3 +97,7 @@ def format_for_config(public_key: nacl.signing.VerifyKey, expiry_ts: int): res = [] for key in res: formatter(get_verify_key(key)) + + +if __name__ == "__main__": + main() diff --git a/scripts/generate_config b/synapse/_scripts/generate_config.py similarity index 98% rename from scripts/generate_config rename to synapse/_scripts/generate_config.py index 931b40c045f3..75fce20b1234 100755 --- a/scripts/generate_config +++ b/synapse/_scripts/generate_config.py @@ -6,7 +6,8 @@ from synapse.config.homeserver import HomeServerConfig -if __name__ == "__main__": + +def main(): parser = argparse.ArgumentParser() parser.add_argument( "--config-dir", @@ -76,3 +77,7 @@ shutil.copyfileobj(args.header_file, args.output_file) args.output_file.write(conf) + + +if __name__ == "__main__": + main() diff --git a/scripts/generate_log_config b/synapse/_scripts/generate_log_config.py similarity index 98% rename from scripts/generate_log_config rename to synapse/_scripts/generate_log_config.py index e72a0dafb769..82fc7631401a 100755 --- a/scripts/generate_log_config +++ b/synapse/_scripts/generate_log_config.py @@ -19,7 +19,8 @@ from synapse.config.logger import DEFAULT_LOG_CONFIG -if __name__ == "__main__": + +def main(): parser = argparse.ArgumentParser() parser.add_argument( @@ -42,3 +43,7 @@ 
out = args.output_file out.write(DEFAULT_LOG_CONFIG.substitute(log_file=args.log_file)) out.flush() + + +if __name__ == "__main__": + main() diff --git a/scripts/generate_signing_key.py b/synapse/_scripts/generate_signing_key.py similarity index 97% rename from scripts/generate_signing_key.py rename to synapse/_scripts/generate_signing_key.py index 07df25a8099d..bc26d25bfd36 100755 --- a/scripts/generate_signing_key.py +++ b/synapse/_scripts/generate_signing_key.py @@ -19,7 +19,8 @@ from synapse.util.stringutils import random_string -if __name__ == "__main__": + +def main(): parser = argparse.ArgumentParser() parser.add_argument( @@ -34,3 +35,7 @@ key_id = "a_" + random_string(4) key = (generate_signing_key(key_id),) write_signing_keys(args.output_file, key) + + +if __name__ == "__main__": + main() diff --git a/scripts/hash_password b/synapse/_scripts/hash_password.py similarity index 96% rename from scripts/hash_password rename to synapse/_scripts/hash_password.py index 1d6fb0d70022..708640c7de8d 100755 --- a/scripts/hash_password +++ b/synapse/_scripts/hash_password.py @@ -8,9 +8,6 @@ import bcrypt import yaml -bcrypt_rounds = 12 -password_pepper = "" - def prompt_for_pass(): password = getpass.getpass("Password: ") @@ -26,7 +23,10 @@ def prompt_for_pass(): return password -if __name__ == "__main__": +def main(): + bcrypt_rounds = 12 + password_pepper = "" + parser = argparse.ArgumentParser( description=( "Calculate the hash of a new password, so that passwords can be reset" @@ -77,3 +77,7 @@ def prompt_for_pass(): ).decode("ascii") print(hashed) + + +if __name__ == "__main__": + main() diff --git a/scripts/move_remote_media_to_new_store.py b/synapse/_scripts/move_remote_media_to_new_store.py similarity index 97% rename from scripts/move_remote_media_to_new_store.py rename to synapse/_scripts/move_remote_media_to_new_store.py index 875aa4781f49..9667d95dfe44 100755 --- a/scripts/move_remote_media_to_new_store.py +++ b/synapse/_scripts/move_remote_media_to_new_store.py @@ -28,7 +28,7 @@ To use, pipe the above into:: - PYTHON_PATH=. ./scripts/move_remote_media_to_new_store.py + PYTHON_PATH=. synapse/_scripts/move_remote_media_to_new_store.py """ import argparse diff --git a/scripts/synapse_port_db b/synapse/_scripts/synapse_port_db.py similarity index 99% rename from scripts/synapse_port_db rename to synapse/_scripts/synapse_port_db.py index db354b3c8c5c..c38666da18e6 100755 --- a/scripts/synapse_port_db +++ b/synapse/_scripts/synapse_port_db.py @@ -1146,7 +1146,7 @@ def set_state(self, state): ############################################## -if __name__ == "__main__": +def main(): parser = argparse.ArgumentParser( description="A script to port an existing synapse SQLite database to" " a new PostgreSQL database." @@ -1251,3 +1251,7 @@ def run(): sys.stderr.write(end_error) sys.exit(5) + + +if __name__ == "__main__": + main() diff --git a/scripts/update_synapse_database b/synapse/_scripts/update_synapse_database.py similarity index 100% rename from scripts/update_synapse_database rename to synapse/_scripts/update_synapse_database.py diff --git a/synapse/config/_base.py b/synapse/config/_base.py index 1265738dc12e..8e19e2fc2668 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -383,7 +383,7 @@ def generate_config( Build a default configuration file This is used when the user explicitly asks us to generate a config file - (eg with --generate_config). + (eg with --generate-config). Args: config_dir_path: The path where the config files are kept. 
Used to diff --git a/tox.ini b/tox.ini index 04b972e2c58c..8d6aa7580bb8 100644 --- a/tox.ini +++ b/tox.ini @@ -38,15 +38,7 @@ lint_targets = setup.py synapse tests - scripts # annoyingly, black doesn't find these so we have to list them - scripts/export_signing_key - scripts/generate_config - scripts/generate_log_config - scripts/hash_password - scripts/register_new_matrix_user - scripts/synapse_port_db - scripts/update_synapse_database scripts-dev scripts-dev/build_debian_packages scripts-dev/sign_json From 1103c5fe8a795eafc4aeedc547faa1b68d5a12f5 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 2 Mar 2022 08:18:51 -0500 Subject: [PATCH 003/230] Check if instances are lists, not sequences. (#12128) As a str is a sequence, the checks were not granular enough and would allow lists or strings, when only lists were valid. --- changelog.d/12128.misc | 1 + synapse/federation/federation_client.py | 8 ++++---- synapse/handlers/room_summary.py | 2 +- 3 files changed, 6 insertions(+), 5 deletions(-) create mode 100644 changelog.d/12128.misc diff --git a/changelog.d/12128.misc b/changelog.d/12128.misc new file mode 100644 index 000000000000..0570a8e3272f --- /dev/null +++ b/changelog.d/12128.misc @@ -0,0 +1 @@ +Fix data validation to compare to lists, not sequences. diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 64e595e748f4..467275b98c3c 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -1428,7 +1428,7 @@ async def send_request( # Validate children_state of the room. children_state = room.pop("children_state", []) - if not isinstance(children_state, Sequence): + if not isinstance(children_state, list): raise InvalidResponseError("'room.children_state' must be a list") if any(not isinstance(e, dict) for e in children_state): raise InvalidResponseError("Invalid event in 'children_state' list") @@ -1440,14 +1440,14 @@ async def send_request( # Validate the children rooms. children = res.get("children", []) - if not isinstance(children, Sequence): + if not isinstance(children, list): raise InvalidResponseError("'children' must be a list") if any(not isinstance(r, dict) for r in children): raise InvalidResponseError("Invalid room in 'children' list") # Validate the inaccessible children. 
inaccessible_children = res.get("inaccessible_children", []) - if not isinstance(inaccessible_children, Sequence): + if not isinstance(inaccessible_children, list): raise InvalidResponseError("'inaccessible_children' must be a list") if any(not isinstance(r, str) for r in inaccessible_children): raise InvalidResponseError( @@ -1630,7 +1630,7 @@ def _validate_hierarchy_event(d: JsonDict) -> None: raise ValueError("Invalid event: 'content' must be a dict") via = content.get("via") - if not isinstance(via, Sequence): + if not isinstance(via, list): raise ValueError("Invalid event: 'via' must be a list") if any(not isinstance(v, str) for v in via): raise ValueError("Invalid event: 'via' must be a list of strings") diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py index 55c2cbdba8bc..3979cbba71bd 100644 --- a/synapse/handlers/room_summary.py +++ b/synapse/handlers/room_summary.py @@ -857,7 +857,7 @@ def as_json(self) -> JsonDict: def _has_valid_via(e: EventBase) -> bool: via = e.content.get("via") - if not via or not isinstance(via, Sequence): + if not via or not isinstance(via, list): return False for v in via: if not isinstance(v, str): From b4461e7d8ab6cfe150f39f62aa68f7f13ef97a24 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Wed, 2 Mar 2022 16:11:16 +0000 Subject: [PATCH 004/230] Enable complexity checking in complexity checking docs example (#11998) --- changelog.d/11998.doc | 1 + ...nning_synapse_on_single_board_computers.md | 19 ++++++++++--------- 2 files changed, 11 insertions(+), 9 deletions(-) create mode 100644 changelog.d/11998.doc diff --git a/changelog.d/11998.doc b/changelog.d/11998.doc new file mode 100644 index 000000000000..33ab7b7880be --- /dev/null +++ b/changelog.d/11998.doc @@ -0,0 +1 @@ +Fix complexity checking config example in [Resource Constrained Devices](https://matrix-org.github.io/synapse/v1.54/other/running_synapse_on_single_board_computers.html) docs page. \ No newline at end of file diff --git a/docs/other/running_synapse_on_single_board_computers.md b/docs/other/running_synapse_on_single_board_computers.md index ea14afa8b2df..dcf96f0056ba 100644 --- a/docs/other/running_synapse_on_single_board_computers.md +++ b/docs/other/running_synapse_on_single_board_computers.md @@ -31,28 +31,29 @@ Anything that requires modifying the device list [#7721](https://github.com/matr Put the below in a new file at /etc/matrix-synapse/conf.d/sbc.yaml to override the defaults in homeserver.yaml. ``` -# Set to false to disable presence tracking on this homeserver. +# Disable presence tracking, which is currently fairly resource intensive +# More info: https://github.com/matrix-org/synapse/issues/9478 use_presence: false -# When this is enabled, the room "complexity" will be checked before a user -# joins a new remote room. If it is above the complexity limit, the server will -# disallow joining, or will instantly leave. +# Set a small complexity limit, preventing users from joining large rooms +# which may be resource-intensive to remain a part of. +# +# Note that this will not prevent users from joining smaller rooms that +# eventually become complex. limit_remote_rooms: - # Uncomment to enable room complexity checking. 
- #enabled: true + enabled: true complexity: 3.0 # Database configuration database: + # Use postgres for the best performance name: psycopg2 args: user: matrix-synapse - # Generate a long, secure one with a password manager + # Generate a long, secure password using a password manager password: hunter2 database: matrix-synapse host: localhost - cp_min: 5 - cp_max: 10 ``` Currently the complexity is measured by [current_state_events / 500](https://github.com/matrix-org/synapse/blob/v1.20.1/synapse/storage/databases/main/events_worker.py#L986). You can find join times and your most complex rooms like this: From 2ffaf30803f93273a4d8a65c9e6c3110c8433488 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Wed, 2 Mar 2022 17:34:14 +0100 Subject: [PATCH 005/230] Add type hints to `tests/rest/client` (#12108) * Add type hints to `tests/rest/client` * newsfile * fix imports * add `test_account.py` * Remove one type hint in `test_report_event.py` * change `on_create_room` to `async` * update new functions in `test_third_party_rules.py` * Add `test_filter.py` * add `test_rooms.py` * change to `assertEquals` to `assertEqual` * lint --- changelog.d/12108.misc | 1 + mypy.ini | 6 - tests/rest/client/test_account.py | 290 +++++++++++--------- tests/rest/client/test_filter.py | 29 +- tests/rest/client/test_relations.py | 4 +- tests/rest/client/test_report_event.py | 25 +- tests/rest/client/test_rooms.py | 271 +++++++++--------- tests/rest/client/test_third_party_rules.py | 108 +++++--- tests/rest/client/test_typing.py | 41 +-- 9 files changed, 423 insertions(+), 352 deletions(-) create mode 100644 changelog.d/12108.misc diff --git a/changelog.d/12108.misc b/changelog.d/12108.misc new file mode 100644 index 000000000000..0360dbd61edc --- /dev/null +++ b/changelog.d/12108.misc @@ -0,0 +1 @@ +Add type hints to `tests/rest/client`. 
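The typing pattern applied across these test files follows the shape below (a condensed, hypothetical sketch of the convention, not an excerpt of any single test):

```python
from twisted.test.proto_helpers import MemoryReactor

from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock

from tests import unittest


class ExampleTestCase(unittest.HomeserverTestCase):
    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
        # Fixture hooks gain explicit parameter and return annotations.
        self.store = hs.get_datastores().main

    def test_example(self) -> None:
        # Test methods are annotated as returning None; JSON payloads use JsonDict.
        content: JsonDict = {"msgtype": "m.text", "body": "Hi!"}
        self.assertEqual("m.text", content["msgtype"])
```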
diff --git a/mypy.ini b/mypy.ini index 6b1e995e64e8..23ca4eaa5a8b 100644 --- a/mypy.ini +++ b/mypy.ini @@ -78,13 +78,7 @@ exclude = (?x) |tests/push/test_http.py |tests/push/test_presentable_names.py |tests/push/test_push_rule_evaluator.py - |tests/rest/client/test_account.py - |tests/rest/client/test_filter.py - |tests/rest/client/test_report_event.py - |tests/rest/client/test_rooms.py - |tests/rest/client/test_third_party_rules.py |tests/rest/client/test_transactions.py - |tests/rest/client/test_typing.py |tests/rest/key/v2/test_remote_key_resource.py |tests/rest/media/v1/test_base.py |tests/rest/media/v1/test_media_storage.py diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py index 6c4462e74a53..def836054db1 100644 --- a/tests/rest/client/test_account.py +++ b/tests/rest/client/test_account.py @@ -15,11 +15,12 @@ import os import re from email.parser import Parser -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional, Union from unittest.mock import Mock import pkg_resources +from twisted.internet.interfaces import IReactorTCP from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin @@ -30,6 +31,7 @@ from synapse.rest.client import account, login, register, room from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource from synapse.server import HomeServer +from synapse.types import JsonDict from synapse.util import Clock from tests import unittest @@ -46,7 +48,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() # Email config. @@ -67,20 +69,27 @@ def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver(config=config) async def sendmail( - reactor, smtphost, smtpport, from_addr, to_addrs, msg, **kwargs - ): - self.email_attempts.append(msg) - - self.email_attempts = [] + reactor: IReactorTCP, + smtphost: str, + smtpport: int, + from_addr: str, + to_addr: str, + msg_bytes: bytes, + *args: Any, + **kwargs: Any, + ) -> None: + self.email_attempts.append(msg_bytes) + + self.email_attempts: List[bytes] = [] hs.get_send_email_handler()._sendmail = sendmail return hs - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main self.submit_token_resource = PasswordResetSubmitTokenResource(hs) - def test_basic_password_reset(self): + def test_basic_password_reset(self) -> None: """Test basic password reset flow""" old_password = "monkey" new_password = "kangeroo" @@ -118,7 +127,7 @@ def test_basic_password_reset(self): self.attempt_wrong_password_login("kermit", old_password) @override_config({"rc_3pid_validation": {"burst_count": 3}}) - def test_ratelimit_by_email(self): + def test_ratelimit_by_email(self) -> None: """Test that we ratelimit /requestToken for the same email.""" old_password = "monkey" new_password = "kangeroo" @@ -139,7 +148,7 @@ def test_ratelimit_by_email(self): ) ) - def reset(ip): + def reset(ip: str) -> None: client_secret = "foobar" session_id = self._request_token(email, client_secret, ip) @@ -166,7 +175,7 @@ def reset(ip): self.assertEqual(cm.exception.code, 429) - def test_basic_password_reset_canonicalise_email(self): + def test_basic_password_reset_canonicalise_email(self) -> None: """Test basic password reset flow Request password reset with 
different spelling """ @@ -206,7 +215,7 @@ def test_basic_password_reset_canonicalise_email(self): # Assert we can't log in with the old password self.attempt_wrong_password_login("kermit", old_password) - def test_cant_reset_password_without_clicking_link(self): + def test_cant_reset_password_without_clicking_link(self) -> None: """Test that we do actually need to click the link in the email""" old_password = "monkey" new_password = "kangeroo" @@ -241,7 +250,7 @@ def test_cant_reset_password_without_clicking_link(self): # Assert we can't log in with the new password self.attempt_wrong_password_login("kermit", new_password) - def test_no_valid_token(self): + def test_no_valid_token(self) -> None: """Test that we do actually need to request a token and can't just make a session up. """ @@ -277,7 +286,7 @@ def test_no_valid_token(self): self.attempt_wrong_password_login("kermit", new_password) @unittest.override_config({"request_token_inhibit_3pid_errors": True}) - def test_password_reset_bad_email_inhibit_error(self): + def test_password_reset_bad_email_inhibit_error(self) -> None: """Test that triggering a password reset with an email address that isn't bound to an account doesn't leak the lack of binding for that address if configured that way. @@ -292,7 +301,12 @@ def test_password_reset_bad_email_inhibit_error(self): self.assertIsNotNone(session_id) - def _request_token(self, email, client_secret, ip="127.0.0.1"): + def _request_token( + self, + email: str, + client_secret: str, + ip: str = "127.0.0.1", + ) -> str: channel = self.make_request( "POST", b"account/password/email/requestToken", @@ -309,7 +323,7 @@ def _request_token(self, email, client_secret, ip="127.0.0.1"): return channel.json_body["sid"] - def _validate_token(self, link): + def _validate_token(self, link: str) -> None: # Remove the host path = link.replace("https://example.com", "") @@ -339,7 +353,7 @@ def _validate_token(self, link): ) self.assertEqual(200, channel.code, channel.result) - def _get_link_from_email(self): + def _get_link_from_email(self) -> str: assert self.email_attempts, "No emails have been sent" raw_msg = self.email_attempts[-1].decode("UTF-8") @@ -354,14 +368,19 @@ def _get_link_from_email(self): if not text: self.fail("Could not find text portion of email to parse") + assert text is not None match = re.search(r"https://example.com\S+", text) assert match, "Could not find link in email" return match.group(0) def _reset_password( - self, new_password, session_id, client_secret, expected_code=200 - ): + self, + new_password: str, + session_id: str, + client_secret: str, + expected_code: int = 200, + ) -> None: channel = self.make_request( "POST", b"account/password", @@ -388,11 +407,11 @@ class DeactivateTestCase(unittest.HomeserverTestCase): room.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.hs = self.setup_test_homeserver() return self.hs - def test_deactivate_account(self): + def test_deactivate_account(self) -> None: user_id = self.register_user("kermit", "test") tok = self.login("kermit", "test") @@ -407,7 +426,7 @@ def test_deactivate_account(self): channel = self.make_request("GET", "account/whoami", access_token=tok) self.assertEqual(channel.code, 401) - def test_pending_invites(self): + def test_pending_invites(self) -> None: """Tests that deactivating a user rejects every pending invite for them.""" store = self.hs.get_datastores().main @@ -448,7 +467,7 @@ def test_pending_invites(self): 
self.assertEqual(len(memberships), 1, memberships) self.assertEqual(memberships[0].room_id, room_id, memberships) - def deactivate(self, user_id, tok): + def deactivate(self, user_id: str, tok: str) -> None: request_data = json.dumps( { "auth": { @@ -474,12 +493,12 @@ class WhoamiTestCase(unittest.HomeserverTestCase): register.register_servlets, ] - def default_config(self): + def default_config(self) -> Dict[str, Any]: config = super().default_config() config["allow_guest_access"] = True return config - def test_GET_whoami(self): + def test_GET_whoami(self) -> None: device_id = "wouldgohere" user_id = self.register_user("kermit", "test") tok = self.login("kermit", "test", device_id=device_id) @@ -496,7 +515,7 @@ def test_GET_whoami(self): }, ) - def test_GET_whoami_guests(self): + def test_GET_whoami_guests(self) -> None: channel = self.make_request( b"POST", b"/_matrix/client/r0/register?kind=guest", b"{}" ) @@ -516,7 +535,7 @@ def test_GET_whoami_guests(self): }, ) - def test_GET_whoami_appservices(self): + def test_GET_whoami_appservices(self) -> None: user_id = "@as:test" as_token = "i_am_an_app_service" @@ -541,7 +560,7 @@ def test_GET_whoami_appservices(self): ) self.assertFalse(hasattr(whoami, "device_id")) - def _whoami(self, tok): + def _whoami(self, tok: str) -> JsonDict: channel = self.make_request("GET", "account/whoami", {}, access_token=tok) self.assertEqual(channel.code, 200) return channel.json_body @@ -555,7 +574,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): synapse.rest.admin.register_servlets_for_client_rest_resource, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() # Email config. @@ -576,16 +595,23 @@ def make_homeserver(self, reactor, clock): self.hs = self.setup_test_homeserver(config=config) async def sendmail( - reactor, smtphost, smtpport, from_addr, to_addrs, msg, **kwargs - ): - self.email_attempts.append(msg) - - self.email_attempts = [] + reactor: IReactorTCP, + smtphost: str, + smtpport: int, + from_addr: str, + to_addr: str, + msg_bytes: bytes, + *args: Any, + **kwargs: Any, + ) -> None: + self.email_attempts.append(msg_bytes) + + self.email_attempts: List[bytes] = [] self.hs.get_send_email_handler()._sendmail = sendmail return self.hs - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main self.user_id = self.register_user("kermit", "test") @@ -593,83 +619,73 @@ def prepare(self, reactor, clock, hs): self.email = "test@example.com" self.url_3pid = b"account/3pid" - def test_add_valid_email(self): - self.get_success(self._add_email(self.email, self.email)) + def test_add_valid_email(self) -> None: + self._add_email(self.email, self.email) - def test_add_valid_email_second_time(self): - self.get_success(self._add_email(self.email, self.email)) - self.get_success( - self._request_token_invalid_email( - self.email, - expected_errcode=Codes.THREEPID_IN_USE, - expected_error="Email is already in use", - ) + def test_add_valid_email_second_time(self) -> None: + self._add_email(self.email, self.email) + self._request_token_invalid_email( + self.email, + expected_errcode=Codes.THREEPID_IN_USE, + expected_error="Email is already in use", ) - def test_add_valid_email_second_time_canonicalise(self): - self.get_success(self._add_email(self.email, self.email)) - self.get_success( - self._request_token_invalid_email( - "TEST@EXAMPLE.COM", 
- expected_errcode=Codes.THREEPID_IN_USE, - expected_error="Email is already in use", - ) + def test_add_valid_email_second_time_canonicalise(self) -> None: + self._add_email(self.email, self.email) + self._request_token_invalid_email( + "TEST@EXAMPLE.COM", + expected_errcode=Codes.THREEPID_IN_USE, + expected_error="Email is already in use", ) - def test_add_email_no_at(self): - self.get_success( - self._request_token_invalid_email( - "address-without-at.bar", - expected_errcode=Codes.UNKNOWN, - expected_error="Unable to parse email address", - ) + def test_add_email_no_at(self) -> None: + self._request_token_invalid_email( + "address-without-at.bar", + expected_errcode=Codes.UNKNOWN, + expected_error="Unable to parse email address", ) - def test_add_email_two_at(self): - self.get_success( - self._request_token_invalid_email( - "foo@foo@test.bar", - expected_errcode=Codes.UNKNOWN, - expected_error="Unable to parse email address", - ) + def test_add_email_two_at(self) -> None: + self._request_token_invalid_email( + "foo@foo@test.bar", + expected_errcode=Codes.UNKNOWN, + expected_error="Unable to parse email address", ) - def test_add_email_bad_format(self): - self.get_success( - self._request_token_invalid_email( - "user@bad.example.net@good.example.com", - expected_errcode=Codes.UNKNOWN, - expected_error="Unable to parse email address", - ) + def test_add_email_bad_format(self) -> None: + self._request_token_invalid_email( + "user@bad.example.net@good.example.com", + expected_errcode=Codes.UNKNOWN, + expected_error="Unable to parse email address", ) - def test_add_email_domain_to_lower(self): - self.get_success(self._add_email("foo@TEST.BAR", "foo@test.bar")) + def test_add_email_domain_to_lower(self) -> None: + self._add_email("foo@TEST.BAR", "foo@test.bar") - def test_add_email_domain_with_umlaut(self): - self.get_success(self._add_email("foo@Öumlaut.com", "foo@öumlaut.com")) + def test_add_email_domain_with_umlaut(self) -> None: + self._add_email("foo@Öumlaut.com", "foo@öumlaut.com") - def test_add_email_address_casefold(self): - self.get_success(self._add_email("Strauß@Example.com", "strauss@example.com")) + def test_add_email_address_casefold(self) -> None: + self._add_email("Strauß@Example.com", "strauss@example.com") - def test_address_trim(self): - self.get_success(self._add_email(" foo@test.bar ", "foo@test.bar")) + def test_address_trim(self) -> None: + self._add_email(" foo@test.bar ", "foo@test.bar") @override_config({"rc_3pid_validation": {"burst_count": 3}}) - def test_ratelimit_by_ip(self): + def test_ratelimit_by_ip(self) -> None: """Tests that adding emails is ratelimited by IP""" # We expect to be able to set three emails before getting ratelimited. 
- self.get_success(self._add_email("foo1@test.bar", "foo1@test.bar")) - self.get_success(self._add_email("foo2@test.bar", "foo2@test.bar")) - self.get_success(self._add_email("foo3@test.bar", "foo3@test.bar")) + self._add_email("foo1@test.bar", "foo1@test.bar") + self._add_email("foo2@test.bar", "foo2@test.bar") + self._add_email("foo3@test.bar", "foo3@test.bar") with self.assertRaises(HttpResponseException) as cm: - self.get_success(self._add_email("foo4@test.bar", "foo4@test.bar")) + self._add_email("foo4@test.bar", "foo4@test.bar") self.assertEqual(cm.exception.code, 429) - def test_add_email_if_disabled(self): + def test_add_email_if_disabled(self) -> None: """Test adding email to profile when doing so is disallowed""" self.hs.config.registration.enable_3pid_changes = False @@ -695,7 +711,7 @@ def test_add_email_if_disabled(self): }, access_token=self.user_id_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) # Get user @@ -705,10 +721,10 @@ def test_add_email_if_disabled(self): access_token=self.user_id_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) self.assertFalse(channel.json_body["threepids"]) - def test_delete_email(self): + def test_delete_email(self) -> None: """Test deleting an email from profile""" # Add a threepid self.get_success( @@ -727,7 +743,7 @@ def test_delete_email(self): {"medium": "email", "address": self.email}, access_token=self.user_id_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) # Get user channel = self.make_request( @@ -736,10 +752,10 @@ def test_delete_email(self): access_token=self.user_id_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) self.assertFalse(channel.json_body["threepids"]) - def test_delete_email_if_disabled(self): + def test_delete_email_if_disabled(self) -> None: """Test deleting an email from profile when disallowed""" self.hs.config.registration.enable_3pid_changes = False @@ -761,7 +777,7 @@ def test_delete_email_if_disabled(self): access_token=self.user_id_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) # Get user @@ -771,11 +787,11 @@ def test_delete_email_if_disabled(self): access_token=self.user_id_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual(self.email, channel.json_body["threepids"][0]["address"]) - def test_cant_add_email_without_clicking_link(self): + def test_cant_add_email_without_clicking_link(self) -> None: """Test that we do actually need to click the link in the email""" client_secret = "foobar" session_id = self._request_token(self.email, client_secret) @@ -797,7 +813,7 @@ def test_cant_add_email_without_clicking_link(self): }, access_token=self.user_id_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, 
msg=channel.result["body"]) self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"]) # Get user @@ -807,10 +823,10 @@ def test_cant_add_email_without_clicking_link(self): access_token=self.user_id_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) self.assertFalse(channel.json_body["threepids"]) - def test_no_valid_token(self): + def test_no_valid_token(self) -> None: """Test that we do actually need to request a token and can't just make a session up. """ @@ -832,7 +848,7 @@ def test_no_valid_token(self): }, access_token=self.user_id_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"]) # Get user @@ -842,11 +858,11 @@ def test_no_valid_token(self): access_token=self.user_id_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) self.assertFalse(channel.json_body["threepids"]) @override_config({"next_link_domain_whitelist": None}) - def test_next_link(self): + def test_next_link(self) -> None: """Tests a valid next_link parameter value with no whitelist (good case)""" self._request_token( "something@example.com", @@ -856,7 +872,7 @@ def test_next_link(self): ) @override_config({"next_link_domain_whitelist": None}) - def test_next_link_exotic_protocol(self): + def test_next_link_exotic_protocol(self) -> None: """Tests using a esoteric protocol as a next_link parameter value. Someone may be hosting a client on IPFS etc. """ @@ -868,7 +884,7 @@ def test_next_link_exotic_protocol(self): ) @override_config({"next_link_domain_whitelist": None}) - def test_next_link_file_uri(self): + def test_next_link_file_uri(self) -> None: """Tests next_link parameters cannot be file URI""" # Attempt to use a next_link value that points to the local disk self._request_token( @@ -879,7 +895,7 @@ def test_next_link_file_uri(self): ) @override_config({"next_link_domain_whitelist": ["example.com", "example.org"]}) - def test_next_link_domain_whitelist(self): + def test_next_link_domain_whitelist(self) -> None: """Tests next_link parameters must fit the whitelist if provided""" # Ensure not providing a next_link parameter still works @@ -912,7 +928,7 @@ def test_next_link_domain_whitelist(self): ) @override_config({"next_link_domain_whitelist": []}) - def test_empty_next_link_domain_whitelist(self): + def test_empty_next_link_domain_whitelist(self) -> None: """Tests an empty next_lint_domain_whitelist value, meaning next_link is essentially disallowed """ @@ -962,28 +978,28 @@ def _request_token( def _request_token_invalid_email( self, - email, - expected_errcode, - expected_error, - client_secret="foobar", - ): + email: str, + expected_errcode: str, + expected_error: str, + client_secret: str = "foobar", + ) -> None: channel = self.make_request( "POST", b"account/3pid/email/requestToken", {"client_secret": client_secret, "email": email, "send_attempt": 1}, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) self.assertEqual(expected_errcode, channel.json_body["errcode"]) self.assertEqual(expected_error, channel.json_body["error"]) - def _validate_token(self, link): + def _validate_token(self, link: str) -> None: # Remove the host 
path = link.replace("https://example.com", "") channel = self.make_request("GET", path, shorthand=False) self.assertEqual(200, channel.code, channel.result) - def _get_link_from_email(self): + def _get_link_from_email(self) -> str: assert self.email_attempts, "No emails have been sent" raw_msg = self.email_attempts[-1].decode("UTF-8") @@ -998,12 +1014,13 @@ def _get_link_from_email(self): if not text: self.fail("Could not find text portion of email to parse") + assert text is not None match = re.search(r"https://example.com\S+", text) assert match, "Could not find link in email" return match.group(0) - def _add_email(self, request_email, expected_email): + def _add_email(self, request_email: str, expected_email: str) -> None: """Test adding an email to profile""" previous_email_attempts = len(self.email_attempts) @@ -1030,7 +1047,7 @@ def _add_email(self, request_email, expected_email): access_token=self.user_id_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) # Get user channel = self.make_request( @@ -1039,7 +1056,7 @@ def _add_email(self, request_email, expected_email): access_token=self.user_id_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) threepids = {threepid["address"] for threepid in channel.json_body["threepids"]} @@ -1055,18 +1072,18 @@ class AccountStatusTestCase(unittest.HomeserverTestCase): url = "/_matrix/client/unstable/org.matrix.msc3720/account_status" - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() config["experimental_features"] = {"msc3720_enabled": True} return self.setup_test_homeserver(config=config) - def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.requester = self.register_user("requester", "password") self.requester_tok = self.login("requester", "password") - self.server_name = homeserver.config.server.server_name + self.server_name = hs.config.server.server_name - def test_missing_mxid(self): + def test_missing_mxid(self) -> None: """Tests that not providing any MXID raises an error.""" self._test_status( users=None, @@ -1074,7 +1091,7 @@ def test_missing_mxid(self): expected_errcode=Codes.MISSING_PARAM, ) - def test_invalid_mxid(self): + def test_invalid_mxid(self) -> None: """Tests that providing an invalid MXID raises an error.""" self._test_status( users=["bad:test"], @@ -1082,7 +1099,7 @@ def test_invalid_mxid(self): expected_errcode=Codes.INVALID_PARAM, ) - def test_local_user_not_exists(self): + def test_local_user_not_exists(self) -> None: """Tests that the account status endpoints correctly reports that a user doesn't exist. """ @@ -1098,7 +1115,7 @@ def test_local_user_not_exists(self): expected_failures=[], ) - def test_local_user_exists(self): + def test_local_user_exists(self) -> None: """Tests that the account status endpoint correctly reports that a user doesn't exist. 
""" @@ -1115,7 +1132,7 @@ def test_local_user_exists(self): expected_failures=[], ) - def test_local_user_deactivated(self): + def test_local_user_deactivated(self) -> None: """Tests that the account status endpoint correctly reports a deactivated user.""" user = self.register_user("someuser", "password") self.get_success( @@ -1135,7 +1152,7 @@ def test_local_user_deactivated(self): expected_failures=[], ) - def test_mixed_local_and_remote_users(self): + def test_mixed_local_and_remote_users(self) -> None: """Tests that if some users are remote the account status endpoint correctly merges the remote responses with the local result. """ @@ -1150,7 +1167,13 @@ def test_mixed_local_and_remote_users(self): "@bad:badremote", ] - async def post_json(destination, path, data, *a, **kwa): + async def post_json( + destination: str, + path: str, + data: Optional[JsonDict] = None, + *a: Any, + **kwa: Any, + ) -> Union[JsonDict, list]: if destination == "remote": return { "account_statuses": { @@ -1160,9 +1183,7 @@ async def post_json(destination, path, data, *a, **kwa): }, } } - if destination == "otherremote": - return {} - if destination == "badremote": + elif destination == "badremote": # badremote tries to overwrite the status of a user that doesn't belong # to it (i.e. users[1]) with false data, which Synapse is expected to # ignore. @@ -1176,6 +1197,9 @@ async def post_json(destination, path, data, *a, **kwa): }, } } + # if destination == "otherremote" + else: + return {} # Register a mock that will return the expected result depending on the remote. self.hs.get_federation_http_client().post_json = Mock(side_effect=post_json) @@ -1205,7 +1229,7 @@ def _test_status( expected_statuses: Optional[Dict[str, Dict[str, bool]]] = None, expected_failures: Optional[List[str]] = None, expected_errcode: Optional[str] = None, - ): + ) -> None: """Send a request to the account status endpoint and check that the response matches with what's expected. diff --git a/tests/rest/client/test_filter.py b/tests/rest/client/test_filter.py index 5c31a54421df..823e8ab8c474 100644 --- a/tests/rest/client/test_filter.py +++ b/tests/rest/client/test_filter.py @@ -12,10 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from twisted.internet import defer +from twisted.test.proto_helpers import MemoryReactor from synapse.api.errors import Codes from synapse.rest.client import filter +from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest @@ -30,11 +32,11 @@ class FilterTestCase(unittest.HomeserverTestCase): EXAMPLE_FILTER_JSON = b'{"room": {"timeline": {"types": ["m.room.message"]}}}' servlets = [filter.register_servlets] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.filtering = hs.get_filtering() self.store = hs.get_datastores().main - def test_add_filter(self): + def test_add_filter(self) -> None: channel = self.make_request( "POST", "/_matrix/client/r0/user/%s/filter" % (self.user_id), @@ -43,11 +45,13 @@ def test_add_filter(self): self.assertEqual(channel.result["code"], b"200") self.assertEqual(channel.json_body, {"filter_id": "0"}) - filter = self.store.get_user_filter(user_localpart="apple", filter_id=0) + filter = self.get_success( + self.store.get_user_filter(user_localpart="apple", filter_id=0) + ) self.pump() - self.assertEqual(filter.result, self.EXAMPLE_FILTER) + self.assertEqual(filter, self.EXAMPLE_FILTER) - def test_add_filter_for_other_user(self): + def test_add_filter_for_other_user(self) -> None: channel = self.make_request( "POST", "/_matrix/client/r0/user/%s/filter" % ("@watermelon:test"), @@ -57,7 +61,7 @@ def test_add_filter_for_other_user(self): self.assertEqual(channel.result["code"], b"403") self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN) - def test_add_filter_non_local_user(self): + def test_add_filter_non_local_user(self) -> None: _is_mine = self.hs.is_mine self.hs.is_mine = lambda target_user: False channel = self.make_request( @@ -70,14 +74,13 @@ def test_add_filter_non_local_user(self): self.assertEqual(channel.result["code"], b"403") self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN) - def test_get_filter(self): - filter_id = defer.ensureDeferred( + def test_get_filter(self) -> None: + filter_id = self.get_success( self.filtering.add_user_filter( user_localpart="apple", user_filter=self.EXAMPLE_FILTER ) ) self.reactor.advance(1) - filter_id = filter_id.result channel = self.make_request( "GET", "/_matrix/client/r0/user/%s/filter/%s" % (self.user_id, filter_id) ) @@ -85,7 +88,7 @@ def test_get_filter(self): self.assertEqual(channel.result["code"], b"200") self.assertEqual(channel.json_body, self.EXAMPLE_FILTER) - def test_get_filter_non_existant(self): + def test_get_filter_non_existant(self) -> None: channel = self.make_request( "GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.user_id) ) @@ -95,7 +98,7 @@ def test_get_filter_non_existant(self): # Currently invalid params do not have an appropriate errcode # in errors.py - def test_get_filter_invalid_id(self): + def test_get_filter_invalid_id(self) -> None: channel = self.make_request( "GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.user_id) ) @@ -103,7 +106,7 @@ def test_get_filter_invalid_id(self): self.assertEqual(channel.result["code"], b"400") # No ID also returns an invalid_id error - def test_get_filter_no_id(self): + def test_get_filter_no_id(self) -> None: channel = self.make_request( "GET", "/_matrix/client/r0/user/%s/filter/" % (self.user_id) ) diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index a087cd7b2149..709f851a3874 100644 --- a/tests/rest/client/test_relations.py +++ 
b/tests/rest/client/test_relations.py @@ -15,7 +15,7 @@ import itertools import urllib.parse -from typing import Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple from unittest.mock import patch from twisted.test.proto_helpers import MemoryReactor @@ -45,7 +45,7 @@ class BaseRelationsTestCase(unittest.HomeserverTestCase): ] hijack_auth = False - def default_config(self) -> dict: + def default_config(self) -> Dict[str, Any]: # We need to enable msc1849 support for aggregations config = super().default_config() diff --git a/tests/rest/client/test_report_event.py b/tests/rest/client/test_report_event.py index ee6b0b9ebfb0..20a259fc4388 100644 --- a/tests/rest/client/test_report_event.py +++ b/tests/rest/client/test_report_event.py @@ -14,8 +14,13 @@ import json +from twisted.test.proto_helpers import MemoryReactor + import synapse.rest.admin from synapse.rest.client import login, report_event, room +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util import Clock from tests import unittest @@ -28,7 +33,7 @@ class ReportEventTestCase(unittest.HomeserverTestCase): report_event.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") self.other_user = self.register_user("user", "pass") @@ -42,35 +47,35 @@ def prepare(self, reactor, clock, hs): self.event_id = resp["event_id"] self.report_path = f"rooms/{self.room_id}/report/{self.event_id}" - def test_reason_str_and_score_int(self): + def test_reason_str_and_score_int(self) -> None: data = {"reason": "this makes me sad", "score": -100} self._assert_status(200, data) - def test_no_reason(self): + def test_no_reason(self) -> None: data = {"score": 0} self._assert_status(200, data) - def test_no_score(self): + def test_no_score(self) -> None: data = {"reason": "this makes me sad"} self._assert_status(200, data) - def test_no_reason_and_no_score(self): - data = {} + def test_no_reason_and_no_score(self) -> None: + data: JsonDict = {} self._assert_status(200, data) - def test_reason_int_and_score_str(self): + def test_reason_int_and_score_str(self) -> None: data = {"reason": 10, "score": "string"} self._assert_status(400, data) - def test_reason_zero_and_score_blank(self): + def test_reason_zero_and_score_blank(self) -> None: data = {"reason": 0, "score": ""} self._assert_status(400, data) - def test_reason_and_score_null(self): + def test_reason_and_score_null(self) -> None: data = {"reason": None, "score": None} self._assert_status(400, data) - def _assert_status(self, response_status, data): + def _assert_status(self, response_status: int, data: JsonDict) -> None: channel = self.make_request( "POST", self.report_path, diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index e0b11e726433..37866ee330f3 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -18,11 +18,12 @@ """Tests REST events for /rooms paths.""" import json -from typing import Iterable, List +from typing import Any, Dict, Iterable, List, Optional from unittest.mock import Mock, call from urllib import parse as urlparse from twisted.internet import defer +from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin from synapse.api.constants import ( @@ -35,7 +36,9 @@ from synapse.handlers.pagination import PurgeStatus from 
synapse.rest import admin from synapse.rest.client import account, directory, login, profile, room, sync +from synapse.server import HomeServer from synapse.types import JsonDict, RoomAlias, UserID, create_requester +from synapse.util import Clock from synapse.util.stringutils import random_string from tests import unittest @@ -45,11 +48,11 @@ class RoomBase(unittest.HomeserverTestCase): - rmcreator_id = None + rmcreator_id: Optional[str] = None servlets = [room.register_servlets, room.register_deprecated_servlets] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.hs = self.setup_test_homeserver( "red", @@ -57,15 +60,15 @@ def make_homeserver(self, reactor, clock): federation_client=Mock(), ) - self.hs.get_federation_handler = Mock() + self.hs.get_federation_handler = Mock() # type: ignore[assignment] self.hs.get_federation_handler.return_value.maybe_backfill = Mock( return_value=make_awaitable(None) ) - async def _insert_client_ip(*args, **kwargs): + async def _insert_client_ip(*args: Any, **kwargs: Any) -> None: return None - self.hs.get_datastores().main.insert_client_ip = _insert_client_ip + self.hs.get_datastores().main.insert_client_ip = _insert_client_ip # type: ignore[assignment] return self.hs @@ -76,7 +79,7 @@ class RoomPermissionsTestCase(RoomBase): user_id = "@sid1:red" rmcreator_id = "@notme:red" - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.helper.auth_user_id = self.rmcreator_id # create some rooms under the name rmcreator_id @@ -108,12 +111,12 @@ def prepare(self, reactor, clock, hs): # auth as user_id now self.helper.auth_user_id = self.user_id - def test_can_do_action(self): + def test_can_do_action(self) -> None: msg_content = b'{"msgtype":"m.text","body":"hello"}' seq = iter(range(100)) - def send_msg_path(): + def send_msg_path() -> str: return "/rooms/%s/send/m.room.message/mid%s" % ( self.created_rmid, str(next(seq)), @@ -148,7 +151,7 @@ def send_msg_path(): channel = self.make_request("PUT", send_msg_path(), msg_content) self.assertEqual(403, channel.code, msg=channel.result["body"]) - def test_topic_perms(self): + def test_topic_perms(self) -> None: topic_content = b'{"topic":"My Topic Name"}' topic_path = "/rooms/%s/state/m.room.topic" % self.created_rmid @@ -214,14 +217,14 @@ def test_topic_perms(self): self.assertEqual(403, channel.code, msg=channel.result["body"]) def _test_get_membership( - self, room=None, members: Iterable = frozenset(), expect_code=None - ): + self, room: str, members: Iterable = frozenset(), expect_code: int = 200 + ) -> None: for member in members: path = "/rooms/%s/state/m.room.member/%s" % (room, member) channel = self.make_request("GET", path) self.assertEqual(expect_code, channel.code) - def test_membership_basic_room_perms(self): + def test_membership_basic_room_perms(self) -> None: # === room does not exist === room = self.uncreated_rmid # get membership of self, get membership of other, uncreated room @@ -241,7 +244,7 @@ def test_membership_basic_room_perms(self): self.helper.join(room=room, user=usr, expect_code=404) self.helper.leave(room=room, user=usr, expect_code=404) - def test_membership_private_room_perms(self): + def test_membership_private_room_perms(self) -> None: room = self.created_rmid # get membership of self, get membership of other, private room + invite # expect all 403s @@ -264,7 +267,7 @@ def test_membership_private_room_perms(self): 
members=[self.user_id, self.rmcreator_id], room=room, expect_code=200 ) - def test_membership_public_room_perms(self): + def test_membership_public_room_perms(self) -> None: room = self.created_public_rmid # get membership of self, get membership of other, public room + invite # expect 403 @@ -287,7 +290,7 @@ def test_membership_public_room_perms(self): members=[self.user_id, self.rmcreator_id], room=room, expect_code=200 ) - def test_invited_permissions(self): + def test_invited_permissions(self) -> None: room = self.created_rmid self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id) @@ -310,7 +313,7 @@ def test_invited_permissions(self): expect_code=403, ) - def test_joined_permissions(self): + def test_joined_permissions(self) -> None: room = self.created_rmid self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id) self.helper.join(room=room, user=self.user_id) @@ -348,7 +351,7 @@ def test_joined_permissions(self): # set left of self, expect 200 self.helper.leave(room=room, user=self.user_id) - def test_leave_permissions(self): + def test_leave_permissions(self) -> None: room = self.created_rmid self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id) self.helper.join(room=room, user=self.user_id) @@ -383,7 +386,7 @@ def test_leave_permissions(self): ) # tests the "from banned" line from the table in https://spec.matrix.org/unstable/client-server-api/#mroommember - def test_member_event_from_ban(self): + def test_member_event_from_ban(self) -> None: room = self.created_rmid self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id) self.helper.join(room=room, user=self.user_id) @@ -475,21 +478,21 @@ class RoomsMemberListTestCase(RoomBase): user_id = "@sid1:red" - def test_get_member_list(self): + def test_get_member_list(self) -> None: room_id = self.helper.create_room_as(self.user_id) channel = self.make_request("GET", "/rooms/%s/members" % room_id) self.assertEqual(200, channel.code, msg=channel.result["body"]) - def test_get_member_list_no_room(self): + def test_get_member_list_no_room(self) -> None: channel = self.make_request("GET", "/rooms/roomdoesnotexist/members") self.assertEqual(403, channel.code, msg=channel.result["body"]) - def test_get_member_list_no_permission(self): + def test_get_member_list_no_permission(self) -> None: room_id = self.helper.create_room_as("@some_other_guy:red") channel = self.make_request("GET", "/rooms/%s/members" % room_id) self.assertEqual(403, channel.code, msg=channel.result["body"]) - def test_get_member_list_no_permission_with_at_token(self): + def test_get_member_list_no_permission_with_at_token(self) -> None: """ Tests that a stranger to the room cannot get the member list (in the case that they use an at token). @@ -509,7 +512,7 @@ def test_get_member_list_no_permission_with_at_token(self): ) self.assertEqual(403, channel.code, msg=channel.result["body"]) - def test_get_member_list_no_permission_former_member(self): + def test_get_member_list_no_permission_former_member(self) -> None: """ Tests that a former member of the room can not get the member list. 
""" @@ -529,7 +532,7 @@ def test_get_member_list_no_permission_former_member(self): channel = self.make_request("GET", "/rooms/%s/members" % room_id) self.assertEqual(403, channel.code, msg=channel.result["body"]) - def test_get_member_list_no_permission_former_member_with_at_token(self): + def test_get_member_list_no_permission_former_member_with_at_token(self) -> None: """ Tests that a former member of the room can not get the member list (in the case that they use an at token). @@ -569,7 +572,7 @@ def test_get_member_list_no_permission_former_member_with_at_token(self): ) self.assertEqual(403, channel.code, msg=channel.result["body"]) - def test_get_member_list_mixed_memberships(self): + def test_get_member_list_mixed_memberships(self) -> None: room_creator = "@some_other_guy:red" room_id = self.helper.create_room_as(room_creator) room_path = "/rooms/%s/members" % room_id @@ -594,26 +597,26 @@ class RoomsCreateTestCase(RoomBase): user_id = "@sid1:red" - def test_post_room_no_keys(self): + def test_post_room_no_keys(self) -> None: # POST with no config keys, expect new room id channel = self.make_request("POST", "/createRoom", "{}") self.assertEqual(200, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) - def test_post_room_visibility_key(self): + def test_post_room_visibility_key(self) -> None: # POST with visibility config key, expect new room id channel = self.make_request("POST", "/createRoom", b'{"visibility":"private"}') self.assertEqual(200, channel.code) self.assertTrue("room_id" in channel.json_body) - def test_post_room_custom_key(self): + def test_post_room_custom_key(self) -> None: # POST with custom config keys, expect new room id channel = self.make_request("POST", "/createRoom", b'{"custom":"stuff"}') self.assertEqual(200, channel.code) self.assertTrue("room_id" in channel.json_body) - def test_post_room_known_and_unknown_keys(self): + def test_post_room_known_and_unknown_keys(self) -> None: # POST with custom + known config keys, expect new room id channel = self.make_request( "POST", "/createRoom", b'{"visibility":"private","custom":"things"}' @@ -621,7 +624,7 @@ def test_post_room_known_and_unknown_keys(self): self.assertEqual(200, channel.code) self.assertTrue("room_id" in channel.json_body) - def test_post_room_invalid_content(self): + def test_post_room_invalid_content(self) -> None: # POST with invalid content / paths, expect 400 channel = self.make_request("POST", "/createRoom", b'{"visibili') self.assertEqual(400, channel.code) @@ -629,7 +632,7 @@ def test_post_room_invalid_content(self): channel = self.make_request("POST", "/createRoom", b'["hello"]') self.assertEqual(400, channel.code) - def test_post_room_invitees_invalid_mxid(self): + def test_post_room_invitees_invalid_mxid(self) -> None: # POST with invalid invitee, see https://github.com/matrix-org/synapse/issues/4088 # Note the trailing space in the MXID here! channel = self.make_request( @@ -638,7 +641,7 @@ def test_post_room_invitees_invalid_mxid(self): self.assertEqual(400, channel.code) @unittest.override_config({"rc_invites": {"per_room": {"burst_count": 3}}}) - def test_post_room_invitees_ratelimit(self): + def test_post_room_invitees_ratelimit(self) -> None: """Test that invites sent when creating a room are ratelimited by a RateLimiter, which ratelimits them correctly, including by not limiting when the requester is exempt from ratelimiting. 
@@ -674,7 +677,7 @@ def test_post_room_invitees_ratelimit(self): channel = self.make_request("POST", "/createRoom", content) self.assertEqual(200, channel.code) - def test_spam_checker_may_join_room(self): + def test_spam_checker_may_join_room(self) -> None: """Tests that the user_may_join_room spam checker callback is correctly bypassed when creating a new room. """ @@ -704,12 +707,12 @@ class RoomTopicTestCase(RoomBase): user_id = "@sid1:red" - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # create the room self.room_id = self.helper.create_room_as(self.user_id) self.path = "/rooms/%s/state/m.room.topic" % (self.room_id,) - def test_invalid_puts(self): + def test_invalid_puts(self) -> None: # missing keys or invalid json channel = self.make_request("PUT", self.path, "{}") self.assertEqual(400, channel.code, msg=channel.result["body"]) @@ -736,7 +739,7 @@ def test_invalid_puts(self): channel = self.make_request("PUT", self.path, content) self.assertEqual(400, channel.code, msg=channel.result["body"]) - def test_rooms_topic(self): + def test_rooms_topic(self) -> None: # nothing should be there channel = self.make_request("GET", self.path) self.assertEqual(404, channel.code, msg=channel.result["body"]) @@ -751,7 +754,7 @@ def test_rooms_topic(self): self.assertEqual(200, channel.code, msg=channel.result["body"]) self.assert_dict(json.loads(content), channel.json_body) - def test_rooms_topic_with_extra_keys(self): + def test_rooms_topic_with_extra_keys(self) -> None: # valid put with extra keys content = '{"topic":"Seasons","subtopic":"Summer"}' channel = self.make_request("PUT", self.path, content) @@ -768,10 +771,10 @@ class RoomMemberStateTestCase(RoomBase): user_id = "@sid1:red" - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.room_id = self.helper.create_room_as(self.user_id) - def test_invalid_puts(self): + def test_invalid_puts(self) -> None: path = "/rooms/%s/state/m.room.member/%s" % (self.room_id, self.user_id) # missing keys or invalid json channel = self.make_request("PUT", path, "{}") @@ -801,7 +804,7 @@ def test_invalid_puts(self): channel = self.make_request("PUT", path, content.encode("ascii")) self.assertEqual(400, channel.code, msg=channel.result["body"]) - def test_rooms_members_self(self): + def test_rooms_members_self(self) -> None: path = "/rooms/%s/state/m.room.member/%s" % ( urlparse.quote(self.room_id), self.user_id, @@ -812,13 +815,13 @@ def test_rooms_members_self(self): channel = self.make_request("PUT", path, content.encode("ascii")) self.assertEqual(200, channel.code, msg=channel.result["body"]) - channel = self.make_request("GET", path, None) + channel = self.make_request("GET", path, content=b"") self.assertEqual(200, channel.code, msg=channel.result["body"]) expected_response = {"membership": Membership.JOIN} self.assertEqual(expected_response, channel.json_body) - def test_rooms_members_other(self): + def test_rooms_members_other(self) -> None: self.other_id = "@zzsid1:red" path = "/rooms/%s/state/m.room.member/%s" % ( urlparse.quote(self.room_id), @@ -830,11 +833,11 @@ def test_rooms_members_other(self): channel = self.make_request("PUT", path, content) self.assertEqual(200, channel.code, msg=channel.result["body"]) - channel = self.make_request("GET", path, None) + channel = self.make_request("GET", path, content=b"") self.assertEqual(200, channel.code, msg=channel.result["body"]) 
self.assertEqual(json.loads(content), channel.json_body) - def test_rooms_members_other_custom_keys(self): + def test_rooms_members_other_custom_keys(self) -> None: self.other_id = "@zzsid1:red" path = "/rooms/%s/state/m.room.member/%s" % ( urlparse.quote(self.room_id), @@ -849,7 +852,7 @@ def test_rooms_members_other_custom_keys(self): channel = self.make_request("PUT", path, content) self.assertEqual(200, channel.code, msg=channel.result["body"]) - channel = self.make_request("GET", path, None) + channel = self.make_request("GET", path, content=b"") self.assertEqual(200, channel.code, msg=channel.result["body"]) self.assertEqual(json.loads(content), channel.json_body) @@ -866,7 +869,7 @@ class RoomInviteRatelimitTestCase(RoomBase): @unittest.override_config( {"rc_invites": {"per_room": {"per_second": 0.5, "burst_count": 3}}} ) - def test_invites_by_rooms_ratelimit(self): + def test_invites_by_rooms_ratelimit(self) -> None: """Tests that invites in a room are actually rate-limited.""" room_id = self.helper.create_room_as(self.user_id) @@ -878,7 +881,7 @@ def test_invites_by_rooms_ratelimit(self): @unittest.override_config( {"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}} ) - def test_invites_by_users_ratelimit(self): + def test_invites_by_users_ratelimit(self) -> None: """Tests that invites to a specific user are actually rate-limited.""" for _ in range(3): @@ -897,7 +900,7 @@ class RoomJoinTestCase(RoomBase): room.register_servlets, ] - def prepare(self, reactor, clock, homeserver): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.user1 = self.register_user("thomas", "hackme") self.tok1 = self.login("thomas", "hackme") @@ -908,7 +911,7 @@ def prepare(self, reactor, clock, homeserver): self.room2 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1) self.room3 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1) - def test_spam_checker_may_join_room(self): + def test_spam_checker_may_join_room(self) -> None: """Tests that the user_may_join_room spam checker callback is correctly called and blocks room joins when needed. 
""" @@ -975,8 +978,8 @@ class RoomJoinRatelimitTestCase(RoomBase): room.register_servlets, ] - def prepare(self, reactor, clock, homeserver): - super().prepare(reactor, clock, homeserver) + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + super().prepare(reactor, clock, hs) # profile changes expect that the user is actually registered user = UserID.from_string(self.user_id) self.get_success(self.register_user(user.localpart, "supersecretpassword")) @@ -984,7 +987,7 @@ def prepare(self, reactor, clock, homeserver): @unittest.override_config( {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}} ) - def test_join_local_ratelimit(self): + def test_join_local_ratelimit(self) -> None: """Tests that local joins are actually rate-limited.""" for _ in range(3): self.helper.create_room_as(self.user_id) @@ -994,7 +997,7 @@ def test_join_local_ratelimit(self): @unittest.override_config( {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}} ) - def test_join_local_ratelimit_profile_change(self): + def test_join_local_ratelimit_profile_change(self) -> None: """Tests that sending a profile update into all of the user's joined rooms isn't rate-limited by the rate-limiter on joins.""" @@ -1031,7 +1034,7 @@ def test_join_local_ratelimit_profile_change(self): @unittest.override_config( {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}} ) - def test_join_local_ratelimit_idempotent(self): + def test_join_local_ratelimit_idempotent(self) -> None: """Tests that the room join endpoints remain idempotent despite rate-limiting on room joins.""" room_id = self.helper.create_room_as(self.user_id) @@ -1056,7 +1059,7 @@ def test_join_local_ratelimit_idempotent(self): "autocreate_auto_join_rooms": True, }, ) - def test_autojoin_rooms(self): + def test_autojoin_rooms(self) -> None: user_id = self.register_user("testuser", "password") # Check that the new user successfully joined the four rooms @@ -1071,10 +1074,10 @@ class RoomMessagesTestCase(RoomBase): user_id = "@sid1:red" - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.room_id = self.helper.create_room_as(self.user_id) - def test_invalid_puts(self): + def test_invalid_puts(self) -> None: path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id)) # missing keys or invalid json channel = self.make_request("PUT", path, b"{}") @@ -1095,7 +1098,7 @@ def test_invalid_puts(self): channel = self.make_request("PUT", path, b"") self.assertEqual(400, channel.code, msg=channel.result["body"]) - def test_rooms_messages_sent(self): + def test_rooms_messages_sent(self) -> None: path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id)) content = b'{"body":"test","msgtype":{"type":"a"}}' @@ -1119,11 +1122,11 @@ class RoomInitialSyncTestCase(RoomBase): user_id = "@sid1:red" - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # create the room self.room_id = self.helper.create_room_as(self.user_id) - def test_initial_sync(self): + def test_initial_sync(self) -> None: channel = self.make_request("GET", "/rooms/%s/initialSync" % self.room_id) self.assertEqual(200, channel.code) @@ -1131,7 +1134,7 @@ def test_initial_sync(self): self.assertEqual("join", channel.json_body["membership"]) # Room state is easier to assert on if we unpack it into a dict - state = {} + state: JsonDict = {} for event in channel.json_body["state"]: if "state_key" 
not in event: continue @@ -1160,10 +1163,10 @@ class RoomMessageListTestCase(RoomBase): user_id = "@sid1:red" - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.room_id = self.helper.create_room_as(self.user_id) - def test_topo_token_is_accepted(self): + def test_topo_token_is_accepted(self) -> None: token = "t1-0_0_0_0_0_0_0_0_0" channel = self.make_request( "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) @@ -1174,7 +1177,7 @@ def test_topo_token_is_accepted(self): self.assertTrue("chunk" in channel.json_body) self.assertTrue("end" in channel.json_body) - def test_stream_token_is_accepted_for_fwd_pagianation(self): + def test_stream_token_is_accepted_for_fwd_pagianation(self) -> None: token = "s0_0_0_0_0_0_0_0_0" channel = self.make_request( "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) @@ -1185,7 +1188,7 @@ def test_stream_token_is_accepted_for_fwd_pagianation(self): self.assertTrue("chunk" in channel.json_body) self.assertTrue("end" in channel.json_body) - def test_room_messages_purge(self): + def test_room_messages_purge(self) -> None: store = self.hs.get_datastores().main pagination_handler = self.hs.get_pagination_handler() @@ -1278,10 +1281,10 @@ class RoomSearchTestCase(unittest.HomeserverTestCase): user_id = True hijack_auth = False - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # Register the user who does the searching - self.user_id = self.register_user("user", "pass") + self.user_id2 = self.register_user("user", "pass") self.access_token = self.login("user", "pass") # Register the user who sends the message @@ -1289,12 +1292,12 @@ def prepare(self, reactor, clock, hs): self.other_access_token = self.login("otheruser", "pass") # Create a room - self.room = self.helper.create_room_as(self.user_id, tok=self.access_token) + self.room = self.helper.create_room_as(self.user_id2, tok=self.access_token) # Invite the other person self.helper.invite( room=self.room, - src=self.user_id, + src=self.user_id2, tok=self.access_token, targ=self.other_user_id, ) @@ -1304,7 +1307,7 @@ def prepare(self, reactor, clock, hs): room=self.room, user=self.other_user_id, tok=self.other_access_token ) - def test_finds_message(self): + def test_finds_message(self) -> None: """ The search functionality will search for content in messages if asked to do so. @@ -1333,7 +1336,7 @@ def test_finds_message(self): # No context was requested, so we should get none. self.assertEqual(results["results"][0]["context"], {}) - def test_include_context(self): + def test_include_context(self) -> None: """ When event_context includes include_profile, profile information will be included in the search response. 
@@ -1379,7 +1382,7 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.url = b"/_matrix/client/r0/publicRooms" @@ -1389,11 +1392,11 @@ def make_homeserver(self, reactor, clock): return self.hs - def test_restricted_no_auth(self): + def test_restricted_no_auth(self) -> None: channel = self.make_request("GET", self.url) self.assertEqual(channel.code, 401, channel.result) - def test_restricted_auth(self): + def test_restricted_auth(self) -> None: self.register_user("user", "pass") tok = self.login("user", "pass") @@ -1412,19 +1415,19 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: return self.setup_test_homeserver(federation_client=Mock()) - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.register_user("user", "pass") self.token = self.login("user", "pass") self.federation_client = hs.get_federation_client() - def test_simple(self): + def test_simple(self) -> None: "Simple test for searching rooms over federation" - self.federation_client.get_public_rooms.side_effect = ( - lambda *a, **k: defer.succeed({}) + self.federation_client.get_public_rooms.side_effect = lambda *a, **k: defer.succeed( # type: ignore[attr-defined] + {} ) search_filter = {"generic_search_term": "foobar"} @@ -1437,7 +1440,7 @@ def test_simple(self): ) self.assertEqual(channel.code, 200, channel.result) - self.federation_client.get_public_rooms.assert_called_once_with( + self.federation_client.get_public_rooms.assert_called_once_with( # type: ignore[attr-defined] "testserv", limit=100, since_token=None, @@ -1446,12 +1449,12 @@ def test_simple(self): third_party_instance_id=None, ) - def test_fallback(self): + def test_fallback(self) -> None: "Test that searching public rooms over federation falls back if it gets a 404" # The `get_public_rooms` should be called again if the first call fails # with a 404, when using search filters. 
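Background on the mock arranged just below: when side_effect is given a sequence, unittest.mock raises any exception instances in it and returns the other values in call order, which is how "the first call 404s, the retry succeeds" is expressed. A standalone sketch of that idiom, with a bare Mock standing in for federation_client.get_public_rooms (the import path for HttpResponseException is an assumption here, since it is not shown in this hunk):

    from unittest.mock import Mock

    from twisted.internet import defer

    from synapse.api.errors import HttpResponseException

    get_public_rooms = Mock(
        side_effect=[
            HttpResponseException(404, "Not Found", b""),  # first call: raised
            defer.succeed({}),  # second call: returned as-is
        ]
    )

    try:
        get_public_rooms("testserv", limit=100)
    except HttpResponseException as e:
        assert e.code == 404

    # The retry returns the canned (empty) response as a Deferred.
    result = get_public_rooms("testserv", limit=100)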
- self.federation_client.get_public_rooms.side_effect = ( + self.federation_client.get_public_rooms.side_effect = ( # type: ignore[attr-defined] HttpResponseException(404, "Not Found", b""), defer.succeed({}), ) @@ -1466,7 +1469,7 @@ def test_fallback(self): ) self.assertEqual(channel.code, 200, channel.result) - self.federation_client.get_public_rooms.assert_has_calls( + self.federation_client.get_public_rooms.assert_has_calls( # type: ignore[attr-defined] [ call( "testserv", @@ -1497,14 +1500,14 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase): profile.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() config["allow_per_room_profiles"] = False self.hs = self.setup_test_homeserver(config=config) return self.hs - def prepare(self, reactor, clock, homeserver): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.user_id = self.register_user("test", "test") self.tok = self.login("test", "test") @@ -1522,7 +1525,7 @@ def prepare(self, reactor, clock, homeserver): self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) - def test_per_room_profile_forbidden(self): + def test_per_room_profile_forbidden(self) -> None: data = {"membership": "join", "displayname": "other test user"} request_data = json.dumps(data) channel = self.make_request( @@ -1557,7 +1560,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor, clock, homeserver): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.creator = self.register_user("creator", "test") self.creator_tok = self.login("creator", "test") @@ -1566,7 +1569,7 @@ def prepare(self, reactor, clock, homeserver): self.room_id = self.helper.create_room_as(self.creator, tok=self.creator_tok) - def test_join_reason(self): + def test_join_reason(self) -> None: reason = "hello" channel = self.make_request( "POST", @@ -1578,7 +1581,7 @@ def test_join_reason(self): self._check_for_reason(reason) - def test_leave_reason(self): + def test_leave_reason(self) -> None: self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok) reason = "hello" @@ -1592,7 +1595,7 @@ def test_leave_reason(self): self._check_for_reason(reason) - def test_kick_reason(self): + def test_kick_reason(self) -> None: self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok) reason = "hello" @@ -1606,7 +1609,7 @@ def test_kick_reason(self): self._check_for_reason(reason) - def test_ban_reason(self): + def test_ban_reason(self) -> None: self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok) reason = "hello" @@ -1620,7 +1623,7 @@ def test_ban_reason(self): self._check_for_reason(reason) - def test_unban_reason(self): + def test_unban_reason(self) -> None: reason = "hello" channel = self.make_request( "POST", @@ -1632,7 +1635,7 @@ def test_unban_reason(self): self._check_for_reason(reason) - def test_invite_reason(self): + def test_invite_reason(self) -> None: reason = "hello" channel = self.make_request( "POST", @@ -1644,7 +1647,7 @@ def test_invite_reason(self): self._check_for_reason(reason) - def test_reject_invite_reason(self): + def test_reject_invite_reason(self) -> None: self.helper.invite( self.room_id, src=self.creator, @@ -1663,7 +1666,7 @@ def test_reject_invite_reason(self): self._check_for_reason(reason) - def _check_for_reason(self, 
reason): + def _check_for_reason(self, reason: str) -> None: channel = self.make_request( "GET", "/_matrix/client/r0/rooms/{}/state/m.room.member/{}".format( @@ -1704,12 +1707,12 @@ class LabelsTestCase(unittest.HomeserverTestCase): "org.matrix.not_labels": ["#notfun"], } - def prepare(self, reactor, clock, homeserver): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.user_id = self.register_user("test", "test") self.tok = self.login("test", "test") self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) - def test_context_filter_labels(self): + def test_context_filter_labels(self) -> None: """Test that we can filter by a label on a /context request.""" event_id = self._send_labelled_messages_in_room() @@ -1739,7 +1742,7 @@ def test_context_filter_labels(self): events_after[0]["content"]["body"], "with right label", events_after[0] ) - def test_context_filter_not_labels(self): + def test_context_filter_not_labels(self) -> None: """Test that we can filter by the absence of a label on a /context request.""" event_id = self._send_labelled_messages_in_room() @@ -1772,7 +1775,7 @@ def test_context_filter_not_labels(self): events_after[1]["content"]["body"], "with two wrong labels", events_after[1] ) - def test_context_filter_labels_not_labels(self): + def test_context_filter_labels_not_labels(self) -> None: """Test that we can filter by both a label and the absence of another label on a /context request. """ @@ -1801,7 +1804,7 @@ def test_context_filter_labels_not_labels(self): events_after[0]["content"]["body"], "with wrong label", events_after[0] ) - def test_messages_filter_labels(self): + def test_messages_filter_labels(self) -> None: """Test that we can filter by a label on a /messages request.""" self._send_labelled_messages_in_room() @@ -1818,7 +1821,7 @@ def test_messages_filter_labels(self): self.assertEqual(events[0]["content"]["body"], "with right label", events[0]) self.assertEqual(events[1]["content"]["body"], "with right label", events[1]) - def test_messages_filter_not_labels(self): + def test_messages_filter_not_labels(self) -> None: """Test that we can filter by the absence of a label on a /messages request.""" self._send_labelled_messages_in_room() @@ -1839,7 +1842,7 @@ def test_messages_filter_not_labels(self): events[3]["content"]["body"], "with two wrong labels", events[3] ) - def test_messages_filter_labels_not_labels(self): + def test_messages_filter_labels_not_labels(self) -> None: """Test that we can filter by both a label and the absence of another label on a /messages request. 
""" @@ -1862,7 +1865,7 @@ def test_messages_filter_labels_not_labels(self): self.assertEqual(len(events), 1, [event["content"] for event in events]) self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0]) - def test_search_filter_labels(self): + def test_search_filter_labels(self) -> None: """Test that we can filter by a label on a /search request.""" request_data = json.dumps( { @@ -1899,7 +1902,7 @@ def test_search_filter_labels(self): results[1]["result"]["content"]["body"], ) - def test_search_filter_not_labels(self): + def test_search_filter_not_labels(self) -> None: """Test that we can filter by the absence of a label on a /search request.""" request_data = json.dumps( { @@ -1946,7 +1949,7 @@ def test_search_filter_not_labels(self): results[3]["result"]["content"]["body"], ) - def test_search_filter_labels_not_labels(self): + def test_search_filter_labels_not_labels(self) -> None: """Test that we can filter by both a label and the absence of another label on a /search request. """ @@ -1980,7 +1983,7 @@ def test_search_filter_labels_not_labels(self): results[0]["result"]["content"]["body"], ) - def _send_labelled_messages_in_room(self): + def _send_labelled_messages_in_room(self) -> str: """Sends several messages to a room with different labels (or without any) to test filtering by label. Returns: @@ -2056,12 +2059,12 @@ class RelationsTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def default_config(self): + def default_config(self) -> Dict[str, Any]: config = super().default_config() config["experimental_features"] = {"msc3440_enabled": True} return config - def prepare(self, reactor, clock, homeserver): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.user_id = self.register_user("test", "test") self.tok = self.login("test", "test") self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) @@ -2136,7 +2139,7 @@ def _filter_messages(self, filter: JsonDict) -> List[JsonDict]: return channel.json_body["chunk"] - def test_filter_relation_senders(self): + def test_filter_relation_senders(self) -> None: # Messages which second user reacted to. filter = {"io.element.relation_senders": [self.second_user_id]} chunk = self._filter_messages(filter) @@ -2159,7 +2162,7 @@ def test_filter_relation_senders(self): [c["event_id"] for c in chunk], [self.event_id_1, self.event_id_2] ) - def test_filter_relation_type(self): + def test_filter_relation_type(self) -> None: # Messages which have annotations. filter = {"io.element.relation_types": [RelationTypes.ANNOTATION]} chunk = self._filter_messages(filter) @@ -2185,7 +2188,7 @@ def test_filter_relation_type(self): [c["event_id"] for c in chunk], [self.event_id_1, self.event_id_2] ) - def test_filter_relation_senders_and_type(self): + def test_filter_relation_senders_and_type(self) -> None: # Messages which second user reacted to. 
filter = { "io.element.relation_senders": [self.second_user_id], @@ -2205,7 +2208,7 @@ class ContextTestCase(unittest.HomeserverTestCase): account.register_servlets, ] - def prepare(self, reactor, clock, homeserver): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.user_id = self.register_user("user", "password") self.tok = self.login("user", "password") self.room_id = self.helper.create_room_as( @@ -2218,7 +2221,7 @@ def prepare(self, reactor, clock, homeserver): self.helper.invite(self.room_id, self.user_id, self.other_user_id, tok=self.tok) self.helper.join(self.room_id, self.other_user_id, tok=self.other_tok) - def test_erased_sender(self): + def test_erased_sender(self) -> None: """Test that an erasure request results in the requester's events being hidden from any new member of the room. """ @@ -2332,7 +2335,7 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase): room.register_servlets, ] - def prepare(self, reactor, clock, homeserver): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.room_owner = self.register_user("room_owner", "test") self.room_owner_tok = self.login("room_owner", "test") @@ -2340,17 +2343,17 @@ def prepare(self, reactor, clock, homeserver): self.room_owner, tok=self.room_owner_tok ) - def test_no_aliases(self): + def test_no_aliases(self) -> None: res = self._get_aliases(self.room_owner_tok) self.assertEqual(res["aliases"], []) - def test_not_in_room(self): + def test_not_in_room(self) -> None: self.register_user("user", "test") user_tok = self.login("user", "test") res = self._get_aliases(user_tok, expected_code=403) self.assertEqual(res["errcode"], "M_FORBIDDEN") - def test_admin_user(self): + def test_admin_user(self) -> None: alias1 = self._random_alias() self._set_alias_via_directory(alias1) @@ -2360,7 +2363,7 @@ def test_admin_user(self): res = self._get_aliases(user_tok) self.assertEqual(res["aliases"], [alias1]) - def test_with_aliases(self): + def test_with_aliases(self) -> None: alias1 = self._random_alias() alias2 = self._random_alias() @@ -2370,7 +2373,7 @@ def test_with_aliases(self): res = self._get_aliases(self.room_owner_tok) self.assertEqual(set(res["aliases"]), {alias1, alias2}) - def test_peekable_room(self): + def test_peekable_room(self) -> None: alias1 = self._random_alias() self._set_alias_via_directory(alias1) @@ -2404,7 +2407,7 @@ def _get_aliases(self, access_token: str, expected_code: int = 200) -> JsonDict: def _random_alias(self) -> str: return RoomAlias(random_string(5), self.hs.hostname).to_string() - def _set_alias_via_directory(self, alias: str, expected_code: int = 200): + def _set_alias_via_directory(self, alias: str, expected_code: int = 200) -> None: url = "/_matrix/client/r0/directory/room/" + alias data = {"room_id": self.room_id} request_data = json.dumps(data) @@ -2423,7 +2426,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase): room.register_servlets, ] - def prepare(self, reactor, clock, homeserver): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.room_owner = self.register_user("room_owner", "test") self.room_owner_tok = self.login("room_owner", "test") @@ -2434,7 +2437,7 @@ def prepare(self, reactor, clock, homeserver): self.alias = "#alias:test" self._set_alias_via_directory(self.alias) - def _set_alias_via_directory(self, alias: str, expected_code: int = 200): + def _set_alias_via_directory(self, alias: str, expected_code: int = 200) -> None: url = 
"/_matrix/client/r0/directory/room/" + alias data = {"room_id": self.room_id} request_data = json.dumps(data) @@ -2456,7 +2459,9 @@ def _get_canonical_alias(self, expected_code: int = 200) -> JsonDict: self.assertIsInstance(res, dict) return res - def _set_canonical_alias(self, content: str, expected_code: int = 200) -> JsonDict: + def _set_canonical_alias( + self, content: JsonDict, expected_code: int = 200 + ) -> JsonDict: """Calls the endpoint under test. returns the json response object.""" channel = self.make_request( "PUT", @@ -2469,7 +2474,7 @@ def _set_canonical_alias(self, content: str, expected_code: int = 200) -> JsonDi self.assertIsInstance(res, dict) return res - def test_canonical_alias(self): + def test_canonical_alias(self) -> None: """Test a basic alias message.""" # There is no canonical alias to start with. self._get_canonical_alias(expected_code=404) @@ -2488,7 +2493,7 @@ def test_canonical_alias(self): res = self._get_canonical_alias() self.assertEqual(res, {}) - def test_alt_aliases(self): + def test_alt_aliases(self) -> None: """Test a canonical alias message with alt_aliases.""" # Create an alias. self._set_canonical_alias({"alt_aliases": [self.alias]}) @@ -2504,7 +2509,7 @@ def test_alt_aliases(self): res = self._get_canonical_alias() self.assertEqual(res, {}) - def test_alias_alt_aliases(self): + def test_alias_alt_aliases(self) -> None: """Test a canonical alias message with an alias and alt_aliases.""" # Create an alias. self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]}) @@ -2520,7 +2525,7 @@ def test_alias_alt_aliases(self): res = self._get_canonical_alias() self.assertEqual(res, {}) - def test_partial_modify(self): + def test_partial_modify(self) -> None: """Test removing only the alt_aliases.""" # Create an alias. self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]}) @@ -2536,7 +2541,7 @@ def test_partial_modify(self): res = self._get_canonical_alias() self.assertEqual(res, {"alias": self.alias}) - def test_add_alias(self): + def test_add_alias(self) -> None: """Test removing only the alt_aliases.""" # Create an additional alias. 
second_alias = "#second:test" @@ -2556,7 +2561,7 @@ def test_add_alias(self): res, {"alias": self.alias, "alt_aliases": [self.alias, second_alias]} ) - def test_bad_data(self): + def test_bad_data(self) -> None: """Invalid data for alt_aliases should cause errors.""" self._set_canonical_alias({"alt_aliases": "@bad:test"}, expected_code=400) self._set_canonical_alias({"alt_aliases": None}, expected_code=400) @@ -2566,7 +2571,7 @@ def test_bad_data(self): self._set_canonical_alias({"alt_aliases": True}, expected_code=400) self._set_canonical_alias({"alt_aliases": {}}, expected_code=400) - def test_bad_alias(self): + def test_bad_alias(self) -> None: """An alias which does not point to the room raises a SynapseError.""" self._set_canonical_alias({"alias": "@unknown:test"}, expected_code=400) self._set_canonical_alias({"alt_aliases": ["@unknown:test"]}, expected_code=400) @@ -2580,13 +2585,13 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase): room.register_servlets, ] - def prepare(self, reactor, clock, homeserver): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.user_id = self.register_user("thomas", "hackme") self.tok = self.login("thomas", "hackme") self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) - def test_threepid_invite_spamcheck(self): + def test_threepid_invite_spamcheck(self) -> None: # Mock a few functions to prevent the test from failing due to failing to talk to # a remote IS. We keep the mock for _mock_make_and_store_3pid_invite around so we # can check its call_count later on during the test. diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py index bfc04785b7b2..58f1ea11b7da 100644 --- a/tests/rest/client/test_third_party_rules.py +++ b/tests/rest/client/test_third_party_rules.py @@ -12,16 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. import threading -from typing import TYPE_CHECKING, Dict, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union from unittest.mock import Mock +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.constants import EventTypes, LoginType, Membership from synapse.api.errors import SynapseError +from synapse.api.room_versions import RoomVersion from synapse.events import EventBase +from synapse.events.snapshot import EventContext from synapse.events.third_party_rules import load_legacy_third_party_event_rules from synapse.rest import admin from synapse.rest.client import account, login, profile, room +from synapse.server import HomeServer from synapse.types import JsonDict, Requester, StateMap +from synapse.util import Clock from synapse.util.frozenutils import unfreeze from tests import unittest @@ -34,7 +40,7 @@ class LegacyThirdPartyRulesTestModule: - def __init__(self, config: Dict, module_api: "ModuleApi"): + def __init__(self, config: Dict, module_api: "ModuleApi") -> None: # keep a record of the "current" rules module, so that the test can patch # it if desired. 
thread_local.rules_module = self @@ -42,32 +48,36 @@ def __init__(self, config: Dict, module_api: "ModuleApi"): async def on_create_room( self, requester: Requester, config: dict, is_requester_admin: bool - ): + ) -> bool: return True - async def check_event_allowed(self, event: EventBase, state: StateMap[EventBase]): + async def check_event_allowed( + self, event: EventBase, state: StateMap[EventBase] + ) -> Union[bool, dict]: return True @staticmethod - def parse_config(config): + def parse_config(config: Dict[str, Any]) -> Dict[str, Any]: return config class LegacyDenyNewRooms(LegacyThirdPartyRulesTestModule): - def __init__(self, config: Dict, module_api: "ModuleApi"): + def __init__(self, config: Dict, module_api: "ModuleApi") -> None: super().__init__(config, module_api) - def on_create_room( + async def on_create_room( self, requester: Requester, config: dict, is_requester_admin: bool - ): + ) -> bool: return False class LegacyChangeEvents(LegacyThirdPartyRulesTestModule): - def __init__(self, config: Dict, module_api: "ModuleApi"): + def __init__(self, config: Dict, module_api: "ModuleApi") -> None: super().__init__(config, module_api) - async def check_event_allowed(self, event: EventBase, state: StateMap[EventBase]): + async def check_event_allowed( + self, event: EventBase, state: StateMap[EventBase] + ) -> JsonDict: d = event.get_dict() content = unfreeze(event.content) content["foo"] = "bar" @@ -84,7 +94,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): account.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: hs = self.setup_test_homeserver() load_legacy_third_party_event_rules(hs) @@ -94,22 +104,30 @@ def make_homeserver(self, reactor, clock): # Note that these checks are not relevant to this test case. # Have this homeserver auto-approve all event signature checking. - async def approve_all_signature_checking(_, pdu): + async def approve_all_signature_checking( + _: RoomVersion, pdu: EventBase + ) -> EventBase: return pdu - hs.get_federation_server()._check_sigs_and_hash = approve_all_signature_checking + hs.get_federation_server()._check_sigs_and_hash = approve_all_signature_checking # type: ignore[assignment] # Have this homeserver skip event auth checks. This is necessary due to # event auth checks ensuring that events were signed by the sender's homeserver. - async def _check_event_auth(origin, event, context, *args, **kwargs): + async def _check_event_auth( + origin: str, + event: EventBase, + context: EventContext, + *args: Any, + **kwargs: Any, + ) -> EventContext: return context - hs.get_federation_event_handler()._check_event_auth = _check_event_auth + hs.get_federation_event_handler()._check_event_auth = _check_event_auth # type: ignore[assignment] return hs - def prepare(self, reactor, clock, homeserver): - super().prepare(reactor, clock, homeserver) + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + super().prepare(reactor, clock, hs) # Create some users and a room to play with during the tests self.user_id = self.register_user("kermit", "monkey") self.invitee = self.register_user("invitee", "hackme") @@ -121,13 +139,15 @@ def prepare(self, reactor, clock, homeserver): except Exception: pass - def test_third_party_rules(self): + def test_third_party_rules(self) -> None: """Tests that a forbidden event is forbidden from being sent, but an allowed one can be sent. 
""" # patch the rules module with a Mock which will return False for some event # types - async def check(ev, state): + async def check( + ev: EventBase, state: StateMap[EventBase] + ) -> Tuple[bool, Optional[JsonDict]]: return ev.type != "foo.bar.forbidden", None callback = Mock(spec=[], side_effect=check) @@ -161,7 +181,7 @@ async def check(ev, state): ) self.assertEqual(channel.result["code"], b"403", channel.result) - def test_third_party_rules_workaround_synapse_errors_pass_through(self): + def test_third_party_rules_workaround_synapse_errors_pass_through(self) -> None: """ Tests that the workaround introduced by https://github.com/matrix-org/synapse/pull/11042 is functional: that SynapseErrors are passed through from check_event_allowed @@ -172,7 +192,7 @@ def test_third_party_rules_workaround_synapse_errors_pass_through(self): """ class NastyHackException(SynapseError): - def error_dict(self): + def error_dict(self) -> JsonDict: """ This overrides SynapseError's `error_dict` to nastily inject JSON into the error response. @@ -182,7 +202,9 @@ def error_dict(self): return result # add a callback that will raise our hacky exception - async def check(ev, state) -> Tuple[bool, Optional[JsonDict]]: + async def check( + ev: EventBase, state: StateMap[EventBase] + ) -> Tuple[bool, Optional[JsonDict]]: raise NastyHackException(429, "message") self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [check] @@ -202,11 +224,13 @@ async def check(ev, state) -> Tuple[bool, Optional[JsonDict]]: {"errcode": "M_UNKNOWN", "error": "message", "nasty": "very"}, ) - def test_cannot_modify_event(self): + def test_cannot_modify_event(self) -> None: """cannot accidentally modify an event before it is persisted""" # first patch the event checker so that it will try to modify the event - async def check(ev: EventBase, state): + async def check( + ev: EventBase, state: StateMap[EventBase] + ) -> Tuple[bool, Optional[JsonDict]]: ev.content = {"x": "y"} return True, None @@ -223,10 +247,12 @@ async def check(ev: EventBase, state): # 500 Internal Server Error self.assertEqual(channel.code, 500, channel.result) - def test_modify_event(self): + def test_modify_event(self) -> None: """The module can return a modified version of the event""" # first patch the event checker so that it will modify the event - async def check(ev: EventBase, state): + async def check( + ev: EventBase, state: StateMap[EventBase] + ) -> Tuple[bool, Optional[JsonDict]]: d = ev.get_dict() d["content"] = {"x": "y"} return True, d @@ -253,10 +279,12 @@ async def check(ev: EventBase, state): ev = channel.json_body self.assertEqual(ev["content"]["x"], "y") - def test_message_edit(self): + def test_message_edit(self) -> None: """Ensure that the module doesn't cause issues with edited messages.""" # first patch the event checker so that it will modify the event - async def check(ev: EventBase, state): + async def check( + ev: EventBase, state: StateMap[EventBase] + ) -> Tuple[bool, Optional[JsonDict]]: d = ev.get_dict() d["content"] = { "msgtype": "m.text", @@ -315,7 +343,7 @@ async def check(ev: EventBase, state): ev = channel.json_body self.assertEqual(ev["content"]["body"], "EDITED BODY") - def test_send_event(self): + def test_send_event(self) -> None: """Tests that a module can send an event into a room via the module api""" content = { "msgtype": "m.text", @@ -344,7 +372,7 @@ def test_send_event(self): } } ) - def test_legacy_check_event_allowed(self): + def test_legacy_check_event_allowed(self) -> None: """Tests that 
the wrapper for legacy check_event_allowed callbacks works correctly. """ @@ -379,13 +407,13 @@ def test_legacy_check_event_allowed(self): } } ) - def test_legacy_on_create_room(self): + def test_legacy_on_create_room(self) -> None: """Tests that the wrapper for legacy on_create_room callbacks works correctly. """ self.helper.create_room_as(self.user_id, tok=self.tok, expect_code=403) - def test_sent_event_end_up_in_room_state(self): + def test_sent_event_end_up_in_room_state(self) -> None: """Tests that a state event sent by a module while processing another state event doesn't get dropped from the state of the room. This is to guard against a bug where Synapse has been observed doing so, see https://github.com/matrix-org/synapse/issues/10830 @@ -400,7 +428,9 @@ def test_sent_event_end_up_in_room_state(self): api = self.hs.get_module_api() # Define a callback that sends a custom event on power levels update. - async def test_fn(event: EventBase, state_events): + async def test_fn( + event: EventBase, state_events: StateMap[EventBase] + ) -> Tuple[bool, Optional[JsonDict]]: if event.is_state and event.type == EventTypes.PowerLevels: await api.create_and_send_event_into_room( { @@ -436,7 +466,7 @@ async def test_fn(event: EventBase, state_events): self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body["i"], i) - def test_on_new_event(self): + def test_on_new_event(self) -> None: """Test that the on_new_event callback is called on new events""" on_new_event = Mock(make_awaitable(None)) self.hs.get_third_party_event_rules()._on_new_event_callbacks.append( @@ -501,7 +531,7 @@ def _send_event_over_federation(self) -> None: self.assertEqual(channel.code, 200, channel.result) - def _update_power_levels(self, event_default: int = 0): + def _update_power_levels(self, event_default: int = 0) -> None: """Updates the room's power levels. Args: @@ -533,7 +563,7 @@ def _update_power_levels(self, event_default: int = 0): tok=self.tok, ) - def test_on_profile_update(self): + def test_on_profile_update(self) -> None: """Tests that the on_profile_update module callback is correctly called on profile updates. """ @@ -592,7 +622,7 @@ def test_on_profile_update(self): self.assertEqual(profile_info.display_name, displayname) self.assertEqual(profile_info.avatar_url, avatar_url) - def test_on_profile_update_admin(self): + def test_on_profile_update_admin(self) -> None: """Tests that the on_profile_update module callback is correctly called on profile updates triggered by a server admin. """ @@ -634,7 +664,7 @@ def test_on_profile_update_admin(self): self.assertEqual(profile_info.display_name, displayname) self.assertEqual(profile_info.avatar_url, avatar_url) - def test_on_user_deactivation_status_changed(self): + def test_on_user_deactivation_status_changed(self) -> None: """Tests that the on_user_deactivation_status_changed module callback is called correctly when processing a user's deactivation. """ @@ -691,7 +721,7 @@ def test_on_user_deactivation_status_changed(self): args = profile_mock.call_args[0] self.assertTrue(args[3]) - def test_on_user_deactivation_status_changed_admin(self): + def test_on_user_deactivation_status_changed_admin(self) -> None: """Tests that the on_user_deactivation_status_changed module callback is called correctly when processing a user's deactivation triggered by a server admin as well as a reactivation. 
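For reference, a minimal sketch of the callback shape these hunks settle on; it is a hedged illustration rather than part of the patch. The rejected event type is the made-up `foo.bar.forbidden` used by the test above, and the registration shown in the trailing comment mirrors how the tests patch the private `_check_event_allowed_callbacks` list rather than any public API.

    from typing import Optional, Tuple

    from synapse.events import EventBase
    from synapse.types import JsonDict, StateMap


    async def check_event_allowed(
        event: EventBase, state_events: StateMap[EventBase]
    ) -> Tuple[bool, Optional[JsonDict]]:
        # Allow everything except the test's made-up "foo.bar.forbidden" type.
        # The second element is an optional replacement event dict; None keeps
        # the event unchanged.
        return event.type != "foo.bar.forbidden", None

    # The tests above wire this up directly on the homeserver under test:
    #     hs.get_third_party_event_rules()._check_event_allowed_callbacks = [check_event_allowed]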
diff --git a/tests/rest/client/test_typing.py b/tests/rest/client/test_typing.py index 8b2da88e8a58..43be711a64f2 100644 --- a/tests/rest/client/test_typing.py +++ b/tests/rest/client/test_typing.py @@ -14,11 +14,16 @@ # limitations under the License. """Tests REST events for /rooms paths.""" - +from typing import Any from unittest.mock import Mock +from twisted.test.proto_helpers import MemoryReactor + from synapse.rest.client import room +from synapse.server import HomeServer +from synapse.storage.databases.main.registration import TokenLookupResult from synapse.types import UserID +from synapse.util import Clock from tests import unittest @@ -33,7 +38,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): user = UserID.from_string(user_id) servlets = [room.register_servlets] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: hs = self.setup_test_homeserver( "red", @@ -43,30 +48,34 @@ def make_homeserver(self, reactor, clock): self.event_source = hs.get_event_sources().sources.typing - hs.get_federation_handler = Mock() + hs.get_federation_handler = Mock() # type: ignore[assignment] - async def get_user_by_access_token(token=None, allow_guest=False): - return { - "user": UserID.from_string(self.auth_user_id), - "token_id": 1, - "is_guest": False, - } + async def get_user_by_access_token( + token: str, + rights: str = "access", + allow_expired: bool = False, + ) -> TokenLookupResult: + return TokenLookupResult( + user_id=self.user_id, + is_guest=False, + token_id=1, + ) - hs.get_auth().get_user_by_access_token = get_user_by_access_token + hs.get_auth().get_user_by_access_token = get_user_by_access_token # type: ignore[assignment] - async def _insert_client_ip(*args, **kwargs): + async def _insert_client_ip(*args: Any, **kwargs: Any) -> None: return None - hs.get_datastores().main.insert_client_ip = _insert_client_ip + hs.get_datastores().main.insert_client_ip = _insert_client_ip # type: ignore[assignment] return hs - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.room_id = self.helper.create_room_as(self.user_id) # Need another user to make notifications actually work self.helper.join(self.room_id, user="@jim:red") - def test_set_typing(self): + def test_set_typing(self) -> None: channel = self.make_request( "PUT", "/rooms/%s/typing/%s" % (self.room_id, self.user_id), @@ -95,7 +104,7 @@ def test_set_typing(self): ], ) - def test_set_not_typing(self): + def test_set_not_typing(self) -> None: channel = self.make_request( "PUT", "/rooms/%s/typing/%s" % (self.room_id, self.user_id), @@ -103,7 +112,7 @@ def test_set_not_typing(self): ) self.assertEqual(200, channel.code) - def test_typing_timeout(self): + def test_typing_timeout(self) -> None: channel = self.make_request( "PUT", "/rooms/%s/typing/%s" % (self.room_id, self.user_id), From 106959b3cf1a59ab5469db639223b6a5b84fb7d7 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Wed, 2 Mar 2022 17:24:52 +0000 Subject: [PATCH 006/230] Remove unused mocks from `test_typing` (#12136) * Remove unused mocks from `test_typing` It's not clear what these do. `get_user_by_access_token` has the wrong signature, including the return type. Tests all pass without these. I think we should nuke them. 
* Changelog * Fixup imports --- changelog.d/12136.misc | 1 + tests/rest/client/test_typing.py | 32 +------------------------------- 2 files changed, 2 insertions(+), 31 deletions(-) create mode 100644 changelog.d/12136.misc diff --git a/changelog.d/12136.misc b/changelog.d/12136.misc new file mode 100644 index 000000000000..98b1c1c9d8ac --- /dev/null +++ b/changelog.d/12136.misc @@ -0,0 +1 @@ +Remove unused mocks from `test_typing`. \ No newline at end of file diff --git a/tests/rest/client/test_typing.py b/tests/rest/client/test_typing.py index 43be711a64f2..d6da510773af 100644 --- a/tests/rest/client/test_typing.py +++ b/tests/rest/client/test_typing.py @@ -14,14 +14,11 @@ # limitations under the License. """Tests REST events for /rooms paths.""" -from typing import Any -from unittest.mock import Mock from twisted.test.proto_helpers import MemoryReactor from synapse.rest.client import room from synapse.server import HomeServer -from synapse.storage.databases.main.registration import TokenLookupResult from synapse.types import UserID from synapse.util import Clock @@ -39,35 +36,8 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): servlets = [room.register_servlets] def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - - hs = self.setup_test_homeserver( - "red", - federation_http_client=None, - federation_client=Mock(), - ) - + hs = self.setup_test_homeserver("red") self.event_source = hs.get_event_sources().sources.typing - - hs.get_federation_handler = Mock() # type: ignore[assignment] - - async def get_user_by_access_token( - token: str, - rights: str = "access", - allow_expired: bool = False, - ) -> TokenLookupResult: - return TokenLookupResult( - user_id=self.user_id, - is_guest=False, - token_id=1, - ) - - hs.get_auth().get_user_by_access_token = get_user_by_access_token # type: ignore[assignment] - - async def _insert_client_ip(*args: Any, **kwargs: Any) -> None: - return None - - hs.get_datastores().main.insert_client_ip = _insert_client_ip # type: ignore[assignment] - return hs def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: From 1fbe0316a991e77289d4577b16ff3fcd27c26dc8 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Wed, 2 Mar 2022 18:00:26 +0000 Subject: [PATCH 007/230] Add suffices to scripts in scripts-dev (#12137) * Rename scripts-dev to have suffices * Update references to `scripts-dev` * Changelog * These scripts don't pass mypy --- .github/workflows/release-artifacts.yml | 4 ++-- .github/workflows/tests.yml | 4 ++-- changelog.d/12137.misc | 1 + docs/code_style.md | 2 +- mypy.ini | 12 +++++++++++- ...uild_debian_packages => build_debian_packages.py} | 0 .../{check-newsfragment => check-newsfragment.sh} | 0 ...erate_sample_config => generate_sample_config.sh} | 4 ++-- scripts-dev/lint.sh | 2 -- scripts-dev/{sign_json => sign_json.py} | 0 tox.ini | 2 -- 11 files changed, 19 insertions(+), 12 deletions(-) create mode 100644 changelog.d/12137.misc rename scripts-dev/{build_debian_packages => build_debian_packages.py} (100%) rename scripts-dev/{check-newsfragment => check-newsfragment.sh} (100%) rename scripts-dev/{generate_sample_config => generate_sample_config.sh} (86%) rename scripts-dev/{sign_json => sign_json.py} (100%) diff --git a/.github/workflows/release-artifacts.yml b/.github/workflows/release-artifacts.yml index eee3633d5043..65ea761ad713 100644 --- a/.github/workflows/release-artifacts.yml +++ b/.github/workflows/release-artifacts.yml @@ -31,7 +31,7 @@ jobs: # if we're running from a tag, get 
the full list of distros; otherwise just use debian:sid dists='["debian:sid"]' if [[ $GITHUB_REF == refs/tags/* ]]; then - dists=$(scripts-dev/build_debian_packages --show-dists-json) + dists=$(scripts-dev/build_debian_packages.py --show-dists-json) fi echo "::set-output name=distros::$dists" # map the step outputs to job outputs @@ -74,7 +74,7 @@ jobs: # see https://github.com/docker/build-push-action/issues/252 # for the cache magic here run: | - ./src/scripts-dev/build_debian_packages \ + ./src/scripts-dev/build_debian_packages.py \ --docker-build-arg=--cache-from=type=local,src=/tmp/.buildx-cache \ --docker-build-arg=--cache-to=type=local,mode=max,dest=/tmp/.buildx-cache-new \ --docker-build-arg=--progress=plain \ diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index e9e427732239..3f4e44ca592d 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -16,7 +16,7 @@ jobs: - uses: actions/checkout@v2 - uses: actions/setup-python@v2 - run: pip install -e . - - run: scripts-dev/generate_sample_config --check + - run: scripts-dev/generate_sample_config.sh --check lint: runs-on: ubuntu-latest @@ -51,7 +51,7 @@ jobs: fetch-depth: 0 - uses: actions/setup-python@v2 - run: "pip install 'towncrier>=18.6.0rc1'" - - run: scripts-dev/check-newsfragment + - run: scripts-dev/check-newsfragment.sh env: PULL_REQUEST_NUMBER: ${{ github.event.number }} diff --git a/changelog.d/12137.misc b/changelog.d/12137.misc new file mode 100644 index 000000000000..118ff77a91c6 --- /dev/null +++ b/changelog.d/12137.misc @@ -0,0 +1 @@ +Give `scripts-dev` scripts suffixes for neater CI config. \ No newline at end of file diff --git a/docs/code_style.md b/docs/code_style.md index 4d8e7c973d05..e7c9cd1a5e4f 100644 --- a/docs/code_style.md +++ b/docs/code_style.md @@ -172,6 +172,6 @@ frobber: ``` Note that the sample configuration is generated from the synapse code -and is maintained by a script, `scripts-dev/generate_sample_config`. +and is maintained by a script, `scripts-dev/generate_sample_config.sh`. Making sure that the output from this script matches the desired format is left as an exercise for the reader! 
diff --git a/mypy.ini b/mypy.ini index 23ca4eaa5a8b..10971b722514 100644 --- a/mypy.ini +++ b/mypy.ini @@ -11,7 +11,7 @@ local_partial_types = True no_implicit_optional = True files = - scripts-dev/sign_json, + scripts-dev/, setup.py, synapse/, tests/ @@ -23,10 +23,20 @@ files = # https://docs.python.org/3/library/re.html#re.X exclude = (?x) ^( + |scripts-dev/build_debian_packages.py + |scripts-dev/check_signature.py + |scripts-dev/definitions.py + |scripts-dev/federation_client.py + |scripts-dev/hash_history.py + |scripts-dev/list_url_patterns.py + |scripts-dev/release.py + |scripts-dev/tail-synapse.py + |synapse/_scripts/export_signing_key.py |synapse/_scripts/move_remote_media_to_new_store.py |synapse/_scripts/synapse_port_db.py |synapse/_scripts/update_synapse_database.py + |synapse/storage/databases/__init__.py |synapse/storage/databases/main/__init__.py |synapse/storage/databases/main/cache.py diff --git a/scripts-dev/build_debian_packages b/scripts-dev/build_debian_packages.py similarity index 100% rename from scripts-dev/build_debian_packages rename to scripts-dev/build_debian_packages.py diff --git a/scripts-dev/check-newsfragment b/scripts-dev/check-newsfragment.sh similarity index 100% rename from scripts-dev/check-newsfragment rename to scripts-dev/check-newsfragment.sh diff --git a/scripts-dev/generate_sample_config b/scripts-dev/generate_sample_config.sh similarity index 86% rename from scripts-dev/generate_sample_config rename to scripts-dev/generate_sample_config.sh index 185e277933e3..375897eacb67 100755 --- a/scripts-dev/generate_sample_config +++ b/scripts-dev/generate_sample_config.sh @@ -15,11 +15,11 @@ check() { if [ "$1" == "--check" ]; then diff -u "$SAMPLE_CONFIG" <(synapse/_scripts/generate_config.py --header-file docs/.sample_config_header.yaml) >/dev/null || { - echo -e "\e[1m\e[31m$SAMPLE_CONFIG is not up-to-date. Regenerate it with \`scripts-dev/generate_sample_config\`.\e[0m" >&2 + echo -e "\e[1m\e[31m$SAMPLE_CONFIG is not up-to-date. Regenerate it with \`scripts-dev/generate_sample_config.sh\`.\e[0m" >&2 exit 1 } diff -u "$SAMPLE_LOG_CONFIG" <(synapse/_scripts/generate_log_config.py) >/dev/null || { - echo -e "\e[1m\e[31m$SAMPLE_LOG_CONFIG is not up-to-date. Regenerate it with \`scripts-dev/generate_sample_config\`.\e[0m" >&2 + echo -e "\e[1m\e[31m$SAMPLE_LOG_CONFIG is not up-to-date. Regenerate it with \`scripts-dev/generate_sample_config.sh\`.\e[0m" >&2 exit 1 } else diff --git a/scripts-dev/lint.sh b/scripts-dev/lint.sh index df4d4934d06f..2f5f2c356674 100755 --- a/scripts-dev/lint.sh +++ b/scripts-dev/lint.sh @@ -85,8 +85,6 @@ else "synapse" "docker" "tests" # annoyingly, black doesn't find these so we have to list them "scripts-dev" - "scripts-dev/build_debian_packages" - "scripts-dev/sign_json" "contrib" "synctl" "setup.py" "synmark" "stubs" ".ci" ) fi diff --git a/scripts-dev/sign_json b/scripts-dev/sign_json.py similarity index 100% rename from scripts-dev/sign_json rename to scripts-dev/sign_json.py diff --git a/tox.ini b/tox.ini index 8d6aa7580bb8..f4829200cca6 100644 --- a/tox.ini +++ b/tox.ini @@ -40,8 +40,6 @@ lint_targets = tests # annoyingly, black doesn't find these so we have to list them scripts-dev - scripts-dev/build_debian_packages - scripts-dev/sign_json stubs contrib synctl From 11282ade1d8deeafa042a639e2685472d6347e69 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Wed, 2 Mar 2022 19:22:44 +0000 Subject: [PATCH 008/230] Move the `snapcraft` configuration to `contrib`. (#12142) * Move the `snapcraft` configuration to `contrib`. 
We're happy for people to package this as a snap image if it's useful, but we don't support or maintain it. I'd like to move the config to `contrib` to reflect this state of affairs. * Changelog --- MANIFEST.in | 1 - changelog.d/12142.misc | 1 + {snap => contrib/snap}/snapcraft.yaml | 0 3 files changed, 1 insertion(+), 1 deletion(-) create mode 100644 changelog.d/12142.misc rename {snap => contrib/snap}/snapcraft.yaml (100%) diff --git a/MANIFEST.in b/MANIFEST.in index 7e903518e152..f1e295e5837f 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -52,5 +52,4 @@ prune contrib prune debian prune demo/etc prune docker -prune snap prune stubs diff --git a/changelog.d/12142.misc b/changelog.d/12142.misc new file mode 100644 index 000000000000..5d09f90b5244 --- /dev/null +++ b/changelog.d/12142.misc @@ -0,0 +1 @@ +Move the snapcraft configuration file to `contrib`. \ No newline at end of file diff --git a/snap/snapcraft.yaml b/contrib/snap/snapcraft.yaml similarity index 100% rename from snap/snapcraft.yaml rename to contrib/snap/snapcraft.yaml From 31b125ccec75e708b09f40205c8cfe692edfa6b4 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Thu, 3 Mar 2022 04:45:23 -0600 Subject: [PATCH 009/230] Enable MSC3030 Complement tests in Synapse (#12144) The Complement tests for MSC3030 are now merged, https://github.com/matrix-org/complement/pull/178 Synapse implmentation: https://github.com/matrix-org/synapse/pull/9445 --- .github/workflows/tests.yml | 2 +- changelog.d/12144.misc | 1 + scripts-dev/complement.sh | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) create mode 100644 changelog.d/12144.misc diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 3f4e44ca592d..fa9611d42b30 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -376,7 +376,7 @@ jobs: # Run Complement - run: | set -o pipefail - go test -v -json -p 1 -tags synapse_blacklist,msc2403 ./tests/... 2>&1 | gotestfmt + go test -v -json -p 1 -tags synapse_blacklist,msc2403,msc3030 ./tests/... 2>&1 | gotestfmt shell: bash name: Run Complement Tests env: diff --git a/changelog.d/12144.misc b/changelog.d/12144.misc new file mode 100644 index 000000000000..d8f71bb203eb --- /dev/null +++ b/changelog.d/12144.misc @@ -0,0 +1 @@ +Enable [MSC3030](https://github.com/matrix-org/matrix-doc/pull/3030) Complement tests in CI. diff --git a/scripts-dev/complement.sh b/scripts-dev/complement.sh index 0aecb3daf158..e3d3e0f293ac 100755 --- a/scripts-dev/complement.sh +++ b/scripts-dev/complement.sh @@ -71,4 +71,4 @@ fi # Run the tests! echo "Images built; running complement" -go test -v -tags synapse_blacklist,msc2403 -count=1 $EXTRA_COMPLEMENT_ARGS ./tests/... +go test -v -tags synapse_blacklist,msc2403,msc3030 -count=1 $EXTRA_COMPLEMENT_ARGS ./tests/... From 61fd2a8f591f20fe9d1cffe659336664bf44e742 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 3 Mar 2022 10:52:35 +0000 Subject: [PATCH 010/230] Limit the size of the aggregation_key (#12101) There's no reason to let people use long keys. --- changelog.d/12101.misc | 1 + synapse/handlers/message.py | 3 +++ 2 files changed, 4 insertions(+) create mode 100644 changelog.d/12101.misc diff --git a/changelog.d/12101.misc b/changelog.d/12101.misc new file mode 100644 index 000000000000..d165f73d13e8 --- /dev/null +++ b/changelog.d/12101.misc @@ -0,0 +1 @@ +Limit the size of `aggregation_key` on annotations. 
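To make the new limit concrete, a hedged illustration (the parent event ID is a placeholder) of an annotation that the handler change below rejects with a 400 once its aggregation key passes 500 characters:

    # Illustrative only: reaction content whose aggregation key exceeds the
    # 500-character cap added to synapse/handlers/message.py below.
    oversized_reaction_content = {
        "m.relates_to": {
            "rel_type": "m.annotation",
            "event_id": "$parent-event-id",  # placeholder parent event
            "key": "x" * 501,  # one character over the limit
        }
    }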
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 61cb133ef265..0799ec9a84df 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -1069,6 +1069,9 @@ async def _validate_event_relation(self, event: EventBase) -> None: if relation_type == RelationTypes.ANNOTATION: aggregation_key = relation["key"] + if len(aggregation_key) > 500: + raise SynapseError(400, "Aggregation key is too long") + already_exists = await self.store.has_user_annotated_event( relates_to, event.type, aggregation_key, event.sender ) From a511a890d7c556ad357d27443e5665e6cc25e0b5 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Thu, 3 Mar 2022 05:19:20 -0600 Subject: [PATCH 011/230] Enable MSC2716 Complement tests in Synapse (#12145) Co-authored-by: Brendan Abolivier --- .github/workflows/tests.yml | 2 +- changelog.d/12145.misc | 1 + scripts-dev/complement.sh | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) create mode 100644 changelog.d/12145.misc diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index fa9611d42b30..c89c50cd07e2 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -376,7 +376,7 @@ jobs: # Run Complement - run: | set -o pipefail - go test -v -json -p 1 -tags synapse_blacklist,msc2403,msc3030 ./tests/... 2>&1 | gotestfmt + go test -v -json -p 1 -tags synapse_blacklist,msc2403,msc2716,msc3030 ./tests/... 2>&1 | gotestfmt shell: bash name: Run Complement Tests env: diff --git a/changelog.d/12145.misc b/changelog.d/12145.misc new file mode 100644 index 000000000000..4092a2d66e45 --- /dev/null +++ b/changelog.d/12145.misc @@ -0,0 +1 @@ +Enable [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) Complement tests in CI. diff --git a/scripts-dev/complement.sh b/scripts-dev/complement.sh index e3d3e0f293ac..0a79a4063f5a 100755 --- a/scripts-dev/complement.sh +++ b/scripts-dev/complement.sh @@ -71,4 +71,4 @@ fi # Run the tests! echo "Images built; running complement" -go test -v -tags synapse_blacklist,msc2403,msc3030 -count=1 $EXTRA_COMPLEMENT_ARGS ./tests/... +go test -v -tags synapse_blacklist,msc2403,msc2716,msc3030 -count=1 $EXTRA_COMPLEMENT_ARGS ./tests/... From 1d11b452b70c768e4919bd9cf6bcaeda2050a3d4 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 3 Mar 2022 10:43:06 -0500 Subject: [PATCH 012/230] Use the proper serialization format when bundling aggregations. (#12090) This ensures that the `latest_event` field of the bundled aggregation for threads uses the same format as the other events in the response. --- changelog.d/12090.bugfix | 1 + synapse/appservice/api.py | 24 ++--- synapse/events/utils.py | 81 ++++++++++------ synapse/handlers/events.py | 3 +- synapse/handlers/initial_sync.py | 9 +- synapse/handlers/pagination.py | 7 +- synapse/rest/client/notifications.py | 9 +- synapse/rest/client/sync.py | 132 ++++++++------------------- tests/events/test_utils.py | 5 +- tests/rest/client/test_relations.py | 2 - 10 files changed, 130 insertions(+), 143 deletions(-) create mode 100644 changelog.d/12090.bugfix diff --git a/changelog.d/12090.bugfix b/changelog.d/12090.bugfix new file mode 100644 index 000000000000..087065dcb1cd --- /dev/null +++ b/changelog.d/12090.bugfix @@ -0,0 +1 @@ +Use the proper serialization format for bundled thread aggregations. The bug has existed since Synapse v1.48.0. 
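In practice this means the thread's `latest_event` is rendered through the same client-format serializer as every other event in the response. A hedged illustration of the resulting shape (event ID, sender and body are placeholders; the fields follow the assertion updated in `tests/rest/client/test_relations.py` at the end of this patch):

    # Illustrative only: client-format shape of a bundled thread's latest_event.
    latest_event = {
        "event_id": "$latest-thread-reply",  # placeholder
        "sender": "@alice:example.org",  # placeholder
        "type": "m.room.test",
        "content": {"body": "latest reply"},  # placeholder
        # The federation-format "room_id" and "user_id" duplicates are no
        # longer included here.
    }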
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index a0ea958af62a..98fe354014c4 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -25,7 +25,7 @@ TransactionUnusedFallbackKeys, ) from synapse.events import EventBase -from synapse.events.utils import serialize_event +from synapse.events.utils import SerializeEventConfig, serialize_event from synapse.http.client import SimpleHttpClient from synapse.types import JsonDict, ThirdPartyInstanceID from synapse.util.caches.response_cache import ResponseCache @@ -321,16 +321,18 @@ def _serialize( serialize_event( e, time_now, - as_client_event=True, - # If this is an invite or a knock membership event, and we're interested - # in this user, then include any stripped state alongside the event. - include_stripped_room_state=( - e.type == EventTypes.Member - and ( - e.membership == Membership.INVITE - or e.membership == Membership.KNOCK - ) - and service.is_interested_in_user(e.state_key) + config=SerializeEventConfig( + as_client_event=True, + # If this is an invite or a knock membership event, and we're interested + # in this user, then include any stripped state alongside the event. + include_stripped_room_state=( + e.type == EventTypes.Member + and ( + e.membership == Membership.INVITE + or e.membership == Membership.KNOCK + ) + and service.is_interested_in_user(e.state_key) + ), ), ) for e in events diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 9386fa29ddd3..ee34cb46e437 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -26,6 +26,7 @@ Union, ) +import attr from frozendict import frozendict from synapse.api.constants import EventContentFields, EventTypes, RelationTypes @@ -303,29 +304,37 @@ def format_event_for_client_v2_without_room_id(d: JsonDict) -> JsonDict: return d +@attr.s(slots=True, frozen=True, auto_attribs=True) +class SerializeEventConfig: + as_client_event: bool = True + # Function to convert from federation format to client format + event_format: Callable[[JsonDict], JsonDict] = format_event_for_client_v1 + # ID of the user's auth token - used for namespacing of transaction IDs + token_id: Optional[int] = None + # List of event fields to include. If empty, all fields will be returned. + only_event_fields: Optional[List[str]] = None + # Some events can have stripped room state stored in the `unsigned` field. + # This is required for invite and knock functionality. If this option is + # False, that state will be removed from the event before it is returned. + # Otherwise, it will be kept. + include_stripped_room_state: bool = False + + +_DEFAULT_SERIALIZE_EVENT_CONFIG = SerializeEventConfig() + + def serialize_event( e: Union[JsonDict, EventBase], time_now_ms: int, *, - as_client_event: bool = True, - event_format: Callable[[JsonDict], JsonDict] = format_event_for_client_v1, - token_id: Optional[str] = None, - only_event_fields: Optional[List[str]] = None, - include_stripped_room_state: bool = False, + config: SerializeEventConfig = _DEFAULT_SERIALIZE_EVENT_CONFIG, ) -> JsonDict: """Serialize event for clients Args: e time_now_ms - as_client_event - event_format - token_id - only_event_fields - include_stripped_room_state: Some events can have stripped room state - stored in the `unsigned` field. This is required for invite and knock - functionality. If this option is False, that state will be removed from the - event before it is returned. Otherwise, it will be kept. + config: Event serialization config Returns: The serialized event dictionary. 
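A short usage sketch of the new config object introduced above; the wrapper function and its field filter are made up for illustration, but the imports and the `config=` keyword match the hunk:

    from synapse.events import EventBase
    from synapse.events.utils import SerializeEventConfig, serialize_event
    from synapse.types import JsonDict


    def serialize_for_client(event: EventBase, now_ms: int) -> JsonDict:
        # Options that used to be separate keyword arguments to serialize_event
        # are now carried together in one frozen config object.
        config = SerializeEventConfig(
            as_client_event=True,
            only_event_fields=["content", "sender"],  # illustrative field filter
        )
        return serialize_event(event, now_ms, config=config)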
@@ -348,11 +357,11 @@ def serialize_event( if "redacted_because" in e.unsigned: d["unsigned"]["redacted_because"] = serialize_event( - e.unsigned["redacted_because"], time_now_ms, event_format=event_format + e.unsigned["redacted_because"], time_now_ms, config=config ) - if token_id is not None: - if token_id == getattr(e.internal_metadata, "token_id", None): + if config.token_id is not None: + if config.token_id == getattr(e.internal_metadata, "token_id", None): txn_id = getattr(e.internal_metadata, "txn_id", None) if txn_id is not None: d["unsigned"]["transaction_id"] = txn_id @@ -361,13 +370,14 @@ def serialize_event( # that are meant to provide metadata about a room to an invitee/knocker. They are # intended to only be included in specific circumstances, such as down sync, and # should not be included in any other case. - if not include_stripped_room_state: + if not config.include_stripped_room_state: d["unsigned"].pop("invite_room_state", None) d["unsigned"].pop("knock_room_state", None) - if as_client_event: - d = event_format(d) + if config.as_client_event: + d = config.event_format(d) + only_event_fields = config.only_event_fields if only_event_fields: if not isinstance(only_event_fields, list) or not all( isinstance(f, str) for f in only_event_fields @@ -390,18 +400,18 @@ def serialize_event( event: Union[JsonDict, EventBase], time_now: int, *, + config: SerializeEventConfig = _DEFAULT_SERIALIZE_EVENT_CONFIG, bundle_aggregations: Optional[Dict[str, "BundledAggregations"]] = None, - **kwargs: Any, ) -> JsonDict: """Serializes a single event. Args: event: The event being serialized. time_now: The current time in milliseconds + config: Event serialization config bundle_aggregations: Whether to include the bundled aggregations for this event. Only applies to non-state events. (State events never include bundled aggregations.) - **kwargs: Arguments to pass to `serialize_event` Returns: The serialized event @@ -410,7 +420,7 @@ def serialize_event( if not isinstance(event, EventBase): return event - serialized_event = serialize_event(event, time_now, **kwargs) + serialized_event = serialize_event(event, time_now, config=config) # Check if there are any bundled aggregations to include with the event. if bundle_aggregations: @@ -419,6 +429,7 @@ def serialize_event( self._inject_bundled_aggregations( event, time_now, + config, bundle_aggregations[event.event_id], serialized_event, ) @@ -456,6 +467,7 @@ def _inject_bundled_aggregations( self, event: EventBase, time_now: int, + config: SerializeEventConfig, aggregations: "BundledAggregations", serialized_event: JsonDict, ) -> None: @@ -466,6 +478,7 @@ def _inject_bundled_aggregations( time_now: The current time in milliseconds aggregations: The bundled aggregation to serialize. serialized_event: The serialized event which may be modified. + config: Event serialization config """ serialized_aggregations = {} @@ -493,8 +506,8 @@ def _inject_bundled_aggregations( thread = aggregations.thread # Don't bundle aggregations as this could recurse forever. - serialized_latest_event = self.serialize_event( - thread.latest_event, time_now, bundle_aggregations=None + serialized_latest_event = serialize_event( + thread.latest_event, time_now, config=config ) # Manually apply an edit, if one exists. 
if thread.latest_edit: @@ -515,20 +528,34 @@ def _inject_bundled_aggregations( ) def serialize_events( - self, events: Iterable[Union[JsonDict, EventBase]], time_now: int, **kwargs: Any + self, + events: Iterable[Union[JsonDict, EventBase]], + time_now: int, + *, + config: SerializeEventConfig = _DEFAULT_SERIALIZE_EVENT_CONFIG, + bundle_aggregations: Optional[Dict[str, "BundledAggregations"]] = None, ) -> List[JsonDict]: """Serializes multiple events. Args: event time_now: The current time in milliseconds - **kwargs: Arguments to pass to `serialize_event` + config: Event serialization config + bundle_aggregations: Whether to include the bundled aggregations for this + event. Only applies to non-state events. (State events never include + bundled aggregations.) Returns: The list of serialized events """ return [ - self.serialize_event(event, time_now=time_now, **kwargs) for event in events + self.serialize_event( + event, + time_now, + config=config, + bundle_aggregations=bundle_aggregations, + ) + for event in events ] diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 97e75e60c32e..d2ccb5c5d311 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -19,6 +19,7 @@ from synapse.api.constants import EduTypes, EventTypes, Membership from synapse.api.errors import AuthError, SynapseError from synapse.events import EventBase +from synapse.events.utils import SerializeEventConfig from synapse.handlers.presence import format_user_presence_state from synapse.streams.config import PaginationConfig from synapse.types import JsonDict, UserID @@ -120,7 +121,7 @@ async def get_stream( chunks = self._event_serializer.serialize_events( events, time_now, - as_client_event=as_client_event, + config=SerializeEventConfig(as_client_event=as_client_event), ) chunk = { diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 344f20f37cce..316cfae24ff0 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -18,6 +18,7 @@ from synapse.api.constants import EduTypes, EventTypes, Membership from synapse.api.errors import SynapseError from synapse.events import EventBase +from synapse.events.utils import SerializeEventConfig from synapse.events.validator import EventValidator from synapse.handlers.presence import format_user_presence_state from synapse.handlers.receipts import ReceiptEventSource @@ -156,6 +157,8 @@ async def _snapshot_all_rooms( if limit is None: limit = 10 + serializer_options = SerializeEventConfig(as_client_event=as_client_event) + async def handle_room(event: RoomsForUser) -> None: d: JsonDict = { "room_id": event.room_id, @@ -173,7 +176,7 @@ async def handle_room(event: RoomsForUser) -> None: d["invite"] = self._event_serializer.serialize_event( invite_event, time_now, - as_client_event=as_client_event, + config=serializer_options, ) rooms_ret.append(d) @@ -225,7 +228,7 @@ async def handle_room(event: RoomsForUser) -> None: self._event_serializer.serialize_events( messages, time_now=time_now, - as_client_event=as_client_event, + config=serializer_options, ) ), "start": await start_token.to_string(self.store), @@ -235,7 +238,7 @@ async def handle_room(event: RoomsForUser) -> None: d["state"] = self._event_serializer.serialize_events( current_state.values(), time_now=time_now, - as_client_event=as_client_event, + config=serializer_options, ) account_data_events = [] diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 5c01a426ff32..183fabcfc09e 100644 --- 
a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -22,6 +22,7 @@ from synapse.api.constants import EventTypes, Membership from synapse.api.errors import SynapseError from synapse.api.filtering import Filter +from synapse.events.utils import SerializeEventConfig from synapse.handlers.room import ShutdownRoomResponse from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.state import StateFilter @@ -541,13 +542,15 @@ async def get_messages( time_now = self.clock.time_msec() + serialize_options = SerializeEventConfig(as_client_event=as_client_event) + chunk = { "chunk": ( self._event_serializer.serialize_events( events, time_now, + config=serialize_options, bundle_aggregations=aggregations, - as_client_event=as_client_event, ) ), "start": await from_token.to_string(self.store), @@ -556,7 +559,7 @@ async def get_messages( if state: chunk["state"] = self._event_serializer.serialize_events( - state, time_now, as_client_event=as_client_event + state, time_now, config=serialize_options ) return chunk diff --git a/synapse/rest/client/notifications.py b/synapse/rest/client/notifications.py index 20377a9ac628..ff040de6b840 100644 --- a/synapse/rest/client/notifications.py +++ b/synapse/rest/client/notifications.py @@ -16,7 +16,10 @@ from typing import TYPE_CHECKING, Tuple from synapse.api.constants import ReceiptTypes -from synapse.events.utils import format_event_for_client_v2_without_room_id +from synapse.events.utils import ( + SerializeEventConfig, + format_event_for_client_v2_without_room_id, +) from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_integer, parse_string from synapse.http.site import SynapseRequest @@ -75,7 +78,9 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: self._event_serializer.serialize_event( notif_events[pa.event_id], self.clock.time_msec(), - event_format=format_event_for_client_v2_without_room_id, + config=SerializeEventConfig( + event_format=format_event_for_client_v2_without_room_id + ), ) ), } diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index f3018ff69077..53c385a86cc1 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -14,24 +14,14 @@ import itertools import logging from collections import defaultdict -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Dict, - Iterable, - List, - Optional, - Tuple, - Union, -) +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union from synapse.api.constants import Membership, PresenceState from synapse.api.errors import Codes, StoreError, SynapseError from synapse.api.filtering import FilterCollection from synapse.api.presence import UserPresenceState -from synapse.events import EventBase from synapse.events.utils import ( + SerializeEventConfig, format_event_for_client_v2_without_room_id, format_event_raw, ) @@ -48,7 +38,6 @@ from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string from synapse.http.site import SynapseRequest from synapse.logging.opentracing import trace -from synapse.storage.databases.main.relations import BundledAggregations from synapse.types import JsonDict, StreamToken from synapse.util import json_decoder @@ -239,28 +228,31 @@ async def encode_response( else: raise Exception("Unknown event format %s" % (filter.event_format,)) + serialize_options = SerializeEventConfig( + event_format=event_formatter, + token_id=access_token_id, + 
only_event_fields=filter.event_fields, + ) + stripped_serialize_options = SerializeEventConfig( + event_format=event_formatter, + token_id=access_token_id, + include_stripped_room_state=True, + ) + joined = await self.encode_joined( - sync_result.joined, - time_now, - access_token_id, - filter.event_fields, - event_formatter, + sync_result.joined, time_now, serialize_options ) invited = await self.encode_invited( - sync_result.invited, time_now, access_token_id, event_formatter + sync_result.invited, time_now, stripped_serialize_options ) knocked = await self.encode_knocked( - sync_result.knocked, time_now, access_token_id, event_formatter + sync_result.knocked, time_now, stripped_serialize_options ) archived = await self.encode_archived( - sync_result.archived, - time_now, - access_token_id, - filter.event_fields, - event_formatter, + sync_result.archived, time_now, serialize_options ) logger.debug("building sync response dict") @@ -339,9 +331,7 @@ async def encode_joined( self, rooms: List[JoinedSyncResult], time_now: int, - token_id: Optional[int], - event_fields: List[str], - event_formatter: Callable[[JsonDict], JsonDict], + serialize_options: SerializeEventConfig, ) -> JsonDict: """ Encode the joined rooms in a sync result @@ -349,24 +339,14 @@ async def encode_joined( Args: rooms: list of sync results for rooms this user is joined to time_now: current time - used as a baseline for age calculations - token_id: ID of the user's auth token - used for namespacing - of transaction IDs - event_fields: List of event fields to include. If empty, - all fields will be returned. - event_formatter: function to convert from federation format - to client format + serialize_options: Event serializer options Returns: The joined rooms list, in our response format """ joined = {} for room in rooms: joined[room.room_id] = await self.encode_room( - room, - time_now, - token_id, - joined=True, - only_fields=event_fields, - event_formatter=event_formatter, + room, time_now, joined=True, serialize_options=serialize_options ) return joined @@ -376,8 +356,7 @@ async def encode_invited( self, rooms: List[InvitedSyncResult], time_now: int, - token_id: Optional[int], - event_formatter: Callable[[JsonDict], JsonDict], + serialize_options: SerializeEventConfig, ) -> JsonDict: """ Encode the invited rooms in a sync result @@ -385,10 +364,7 @@ async def encode_invited( Args: rooms: list of sync results for rooms this user is invited to time_now: current time - used as a baseline for age calculations - token_id: ID of the user's auth token - used for namespacing - of transaction IDs - event_formatter: function to convert from federation format - to client format + serialize_options: Event serializer options Returns: The invited rooms list, in our response format @@ -396,11 +372,7 @@ async def encode_invited( invited = {} for room in rooms: invite = self._event_serializer.serialize_event( - room.invite, - time_now, - token_id=token_id, - event_format=event_formatter, - include_stripped_room_state=True, + room.invite, time_now, config=serialize_options ) unsigned = dict(invite.get("unsigned", {})) invite["unsigned"] = unsigned @@ -415,8 +387,7 @@ async def encode_knocked( self, rooms: List[KnockedSyncResult], time_now: int, - token_id: Optional[int], - event_formatter: Callable[[Dict], Dict], + serialize_options: SerializeEventConfig, ) -> Dict[str, Dict[str, Any]]: """ Encode the rooms we've knocked on in a sync result. 
@@ -424,8 +395,7 @@ async def encode_knocked( Args: rooms: list of sync results for rooms this user is knocking on time_now: current time - used as a baseline for age calculations - token_id: ID of the user's auth token - used for namespacing of transaction IDs - event_formatter: function to convert from federation format to client format + serialize_options: Event serializer options Returns: The list of rooms the user has knocked on, in our response format. @@ -433,11 +403,7 @@ async def encode_knocked( knocked = {} for room in rooms: knock = self._event_serializer.serialize_event( - room.knock, - time_now, - token_id=token_id, - event_format=event_formatter, - include_stripped_room_state=True, + room.knock, time_now, config=serialize_options ) # Extract the `unsigned` key from the knock event. @@ -470,9 +436,7 @@ async def encode_archived( self, rooms: List[ArchivedSyncResult], time_now: int, - token_id: Optional[int], - event_fields: List[str], - event_formatter: Callable[[JsonDict], JsonDict], + serialize_options: SerializeEventConfig, ) -> JsonDict: """ Encode the archived rooms in a sync result @@ -480,23 +444,14 @@ async def encode_archived( Args: rooms: list of sync results for rooms this user is joined to time_now: current time - used as a baseline for age calculations - token_id: ID of the user's auth token - used for namespacing - of transaction IDs - event_fields: List of event fields to include. If empty, - all fields will be returned. - event_formatter: function to convert from federation format to client format + serialize_options: Event serializer options Returns: The archived rooms list, in our response format """ joined = {} for room in rooms: joined[room.room_id] = await self.encode_room( - room, - time_now, - token_id, - joined=False, - only_fields=event_fields, - event_formatter=event_formatter, + room, time_now, joined=False, serialize_options=serialize_options ) return joined @@ -505,10 +460,8 @@ async def encode_room( self, room: Union[JoinedSyncResult, ArchivedSyncResult], time_now: int, - token_id: Optional[int], joined: bool, - only_fields: Optional[List[str]], - event_formatter: Callable[[JsonDict], JsonDict], + serialize_options: SerializeEventConfig, ) -> JsonDict: """ Args: @@ -524,20 +477,6 @@ async def encode_room( Returns: The room, encoded in our response format """ - - def serialize( - events: Iterable[EventBase], - aggregations: Optional[Dict[str, BundledAggregations]] = None, - ) -> List[JsonDict]: - return self._event_serializer.serialize_events( - events, - time_now=time_now, - bundle_aggregations=aggregations, - token_id=token_id, - event_format=event_formatter, - only_event_fields=only_fields, - ) - state_dict = room.state timeline_events = room.timeline.events @@ -554,9 +493,14 @@ def serialize( event.room_id, ) - serialized_state = serialize(state_events) - serialized_timeline = serialize( - timeline_events, room.timeline.bundled_aggregations + serialized_state = self._event_serializer.serialize_events( + state_events, time_now, config=serialize_options + ) + serialized_timeline = self._event_serializer.serialize_events( + timeline_events, + time_now, + config=serialize_options, + bundle_aggregations=room.timeline.bundled_aggregations, ) account_data = room.account_data diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py index 45e3395b3361..00ad19e446db 100644 --- a/tests/events/test_utils.py +++ b/tests/events/test_utils.py @@ -16,6 +16,7 @@ from synapse.api.room_versions import RoomVersions from synapse.events import 
make_event_from_dict from synapse.events.utils import ( + SerializeEventConfig, copy_power_levels_contents, prune_event, serialize_event, @@ -392,7 +393,9 @@ def test_member(self): class SerializeEventTestCase(unittest.TestCase): def serialize(self, ev, fields): - return serialize_event(ev, 1479807801915, only_event_fields=fields) + return serialize_event( + ev, 1479807801915, config=SerializeEventConfig(only_event_fields=fields) + ) def test_event_fields_works_with_keys(self): self.assertEqual( diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index 709f851a3874..53062b41deaa 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -704,10 +704,8 @@ def assert_bundle(event_json: JsonDict) -> None: } }, "event_id": thread_2, - "room_id": self.room, "sender": self.user_id, "type": "m.room.test", - "user_id": self.user_id, }, relations_dict[RelationTypes.THREAD].get("latest_event"), ) From 7e91107be1a4287873266e588a3c5b415279f4c8 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Thu, 3 Mar 2022 17:05:44 +0100 Subject: [PATCH 013/230] Add type hints to `tests/rest` (#12146) * Add type hints to `tests/rest` * newsfile * change import from `SigningKey` --- changelog.d/12146.misc | 1 + mypy.ini | 7 +-- tests/rest/key/v2/test_remote_key_resource.py | 44 +++++++++------ tests/rest/media/v1/test_base.py | 4 +- tests/rest/media/v1/test_filepath.py | 48 ++++++++--------- tests/rest/media/v1/test_html_preview.py | 54 +++++++++---------- tests/rest/media/v1/test_oembed.py | 10 ++-- tests/rest/test_health.py | 8 +-- tests/rest/test_well_known.py | 20 +++---- 9 files changed, 104 insertions(+), 92 deletions(-) create mode 100644 changelog.d/12146.misc diff --git a/changelog.d/12146.misc b/changelog.d/12146.misc new file mode 100644 index 000000000000..3ca7c47212fd --- /dev/null +++ b/changelog.d/12146.misc @@ -0,0 +1 @@ +Add type hints to `tests/rest`. diff --git a/mypy.ini b/mypy.ini index 10971b722514..481e8a5366b0 100644 --- a/mypy.ini +++ b/mypy.ini @@ -89,8 +89,6 @@ exclude = (?x) |tests/push/test_presentable_names.py |tests/push/test_push_rule_evaluator.py |tests/rest/client/test_transactions.py - |tests/rest/key/v2/test_remote_key_resource.py - |tests/rest/media/v1/test_base.py |tests/rest/media/v1/test_media_storage.py |tests/rest/media/v1/test_url_preview.py |tests/scripts/test_new_matrix_user.py @@ -254,10 +252,7 @@ disallow_untyped_defs = True [mypy-tests.storage.test_user_directory] disallow_untyped_defs = True -[mypy-tests.rest.admin.*] -disallow_untyped_defs = True - -[mypy-tests.rest.client.*] +[mypy-tests.rest.*] disallow_untyped_defs = True [mypy-tests.federation.transport.test_client] diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py index 4672a6859684..978c252f8482 100644 --- a/tests/rest/key/v2/test_remote_key_resource.py +++ b/tests/rest/key/v2/test_remote_key_resource.py @@ -13,19 +13,24 @@ # limitations under the License. 
import urllib.parse from io import BytesIO, StringIO +from typing import Any, Dict, Optional, Union from unittest.mock import Mock import signedjson.key from canonicaljson import encode_canonical_json -from nacl.signing import SigningKey from signedjson.sign import sign_json +from signedjson.types import SigningKey -from twisted.web.resource import NoResource +from twisted.test.proto_helpers import MemoryReactor +from twisted.web.resource import NoResource, Resource from synapse.crypto.keyring import PerspectivesKeyFetcher from synapse.http.site import SynapseRequest from synapse.rest.key.v2 import KeyApiV2Resource +from synapse.server import HomeServer from synapse.storage.keys import FetchKeyResult +from synapse.types import JsonDict +from synapse.util import Clock from synapse.util.httpresourcetree import create_resource_tree from synapse.util.stringutils import random_string @@ -35,11 +40,11 @@ class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase): - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.http_client = Mock() return self.setup_test_homeserver(federation_http_client=self.http_client) - def create_test_resource(self): + def create_test_resource(self) -> Resource: return create_resource_tree( {"/_matrix/key/v2": KeyApiV2Resource(self.hs)}, root_resource=NoResource() ) @@ -51,7 +56,12 @@ def expect_outgoing_key_request( Tell the mock http client to expect an outgoing GET request for the given key """ - async def get_json(destination, path, ignore_backoff=False, **kwargs): + async def get_json( + destination: str, + path: str, + ignore_backoff: bool = False, + **kwargs: Any, + ) -> Union[JsonDict, list]: self.assertTrue(ignore_backoff) self.assertEqual(destination, server_name) key_id = "%s:%s" % (signing_key.alg, signing_key.version) @@ -84,7 +94,8 @@ def make_notary_request(self, server_name: str, key_id: str) -> dict: Checks that the response is a 200 and returns the decoded json body. """ channel = FakeChannel(self.site, self.reactor) - req = SynapseRequest(channel, self.site) + # channel is a `FakeChannel` but `HTTPChannel` is expected + req = SynapseRequest(channel, self.site) # type: ignore[arg-type] req.content = BytesIO(b"") req.requestReceived( b"GET", @@ -97,7 +108,7 @@ def make_notary_request(self, server_name: str, key_id: str) -> dict: resp = channel.json_body return resp - def test_get_key(self): + def test_get_key(self) -> None: """Fetch a remote key""" SERVER_NAME = "remote.server" testkey = signedjson.key.generate_signing_key("ver1") @@ -114,7 +125,7 @@ def test_get_key(self): self.assertIn(SERVER_NAME, keys[0]["signatures"]) self.assertIn(self.hs.hostname, keys[0]["signatures"]) - def test_get_own_key(self): + def test_get_own_key(self) -> None: """Fetch our own key""" testkey = signedjson.key.generate_signing_key("ver1") self.expect_outgoing_key_request(self.hs.hostname, testkey) @@ -141,7 +152,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase): endpoint, to check that the two implementations are compatible. 
""" - def default_config(self): + def default_config(self) -> Dict[str, Any]: config = super().default_config() # replace the signing key with our own @@ -152,7 +163,7 @@ def default_config(self): return config - def prepare(self, reactor, clock, homeserver): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # make a second homeserver, configured to use the first one as a key notary self.http_client2 = Mock() config = default_config(name="keyclient") @@ -175,7 +186,9 @@ def prepare(self, reactor, clock, homeserver): # wire up outbound POST /key/v2/query requests from hs2 so that they # will be forwarded to hs1 - async def post_json(destination, path, data): + async def post_json( + destination: str, path: str, data: Optional[JsonDict] = None + ) -> Union[JsonDict, list]: self.assertEqual(destination, self.hs.hostname) self.assertEqual( path, @@ -183,7 +196,8 @@ async def post_json(destination, path, data): ) channel = FakeChannel(self.site, self.reactor) - req = SynapseRequest(channel, self.site) + # channel is a `FakeChannel` but `HTTPChannel` is expected + req = SynapseRequest(channel, self.site) # type: ignore[arg-type] req.content = BytesIO(encode_canonical_json(data)) req.requestReceived( @@ -198,7 +212,7 @@ async def post_json(destination, path, data): self.http_client2.post_json.side_effect = post_json - def test_get_key(self): + def test_get_key(self) -> None: """Fetch a key belonging to a random server""" # make up a key to be fetched. testkey = signedjson.key.generate_signing_key("abc") @@ -218,7 +232,7 @@ def test_get_key(self): signedjson.key.encode_verify_key_base64(testkey.verify_key), ) - def test_get_notary_key(self): + def test_get_notary_key(self) -> None: """Fetch a key belonging to the notary server""" # make up a key to be fetched. 
We randomise the keyid to try to get it to # appear before the key server signing key sometimes (otherwise we bail out @@ -240,7 +254,7 @@ def test_get_notary_key(self): signedjson.key.encode_verify_key_base64(testkey.verify_key), ) - def test_get_notary_keyserver_key(self): + def test_get_notary_keyserver_key(self) -> None: """Fetch the notary's keyserver key""" # we expect hs1 to make a regular key request to itself self.expect_outgoing_key_request(self.hs.hostname, self.hs_signing_key) diff --git a/tests/rest/media/v1/test_base.py b/tests/rest/media/v1/test_base.py index f761e23f1bf0..c73179151adb 100644 --- a/tests/rest/media/v1/test_base.py +++ b/tests/rest/media/v1/test_base.py @@ -28,11 +28,11 @@ class GetFileNameFromHeadersTests(unittest.TestCase): b"inline; filename*=utf-8''foo%C2%A3bar": "foo£bar", } - def tests(self): + def tests(self) -> None: for hdr, expected in self.TEST_CASES.items(): res = get_filename_from_headers({b"Content-Disposition": [hdr]}) self.assertEqual( res, expected, - "expected output for %s to be %s but was %s" % (hdr, expected, res), + f"expected output for {hdr!r} to be {expected} but was {res}", ) diff --git a/tests/rest/media/v1/test_filepath.py b/tests/rest/media/v1/test_filepath.py index 913bc530aac1..43e6f0f70aad 100644 --- a/tests/rest/media/v1/test_filepath.py +++ b/tests/rest/media/v1/test_filepath.py @@ -21,12 +21,12 @@ class MediaFilePathsTestCase(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: super().setUp() self.filepaths = MediaFilePaths("/media_store") - def test_local_media_filepath(self): + def test_local_media_filepath(self) -> None: """Test local media paths""" self.assertEqual( self.filepaths.local_media_filepath_rel("GerZNDnDZVjsOtardLuwfIBg"), @@ -37,7 +37,7 @@ def test_local_media_filepath(self): "/media_store/local_content/Ge/rZ/NDnDZVjsOtardLuwfIBg", ) - def test_local_media_thumbnail(self): + def test_local_media_thumbnail(self) -> None: """Test local media thumbnail paths""" self.assertEqual( self.filepaths.local_media_thumbnail_rel( @@ -52,14 +52,14 @@ def test_local_media_thumbnail(self): "/media_store/local_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale", ) - def test_local_media_thumbnail_dir(self): + def test_local_media_thumbnail_dir(self) -> None: """Test local media thumbnail directory paths""" self.assertEqual( self.filepaths.local_media_thumbnail_dir("GerZNDnDZVjsOtardLuwfIBg"), "/media_store/local_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg", ) - def test_remote_media_filepath(self): + def test_remote_media_filepath(self) -> None: """Test remote media paths""" self.assertEqual( self.filepaths.remote_media_filepath_rel( @@ -74,7 +74,7 @@ def test_remote_media_filepath(self): "/media_store/remote_content/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg", ) - def test_remote_media_thumbnail(self): + def test_remote_media_thumbnail(self) -> None: """Test remote media thumbnail paths""" self.assertEqual( self.filepaths.remote_media_thumbnail_rel( @@ -99,7 +99,7 @@ def test_remote_media_thumbnail(self): "/media_store/remote_thumbnail/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale", ) - def test_remote_media_thumbnail_legacy(self): + def test_remote_media_thumbnail_legacy(self) -> None: """Test old-style remote media thumbnail paths""" self.assertEqual( self.filepaths.remote_media_thumbnail_rel_legacy( @@ -108,7 +108,7 @@ def test_remote_media_thumbnail_legacy(self): "remote_thumbnail/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg", ) - def test_remote_media_thumbnail_dir(self): + 
def test_remote_media_thumbnail_dir(self) -> None: """Test remote media thumbnail directory paths""" self.assertEqual( self.filepaths.remote_media_thumbnail_dir( @@ -117,7 +117,7 @@ def test_remote_media_thumbnail_dir(self): "/media_store/remote_thumbnail/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg", ) - def test_url_cache_filepath(self): + def test_url_cache_filepath(self) -> None: """Test URL cache paths""" self.assertEqual( self.filepaths.url_cache_filepath_rel("2020-01-02_GerZNDnDZVjsOtar"), @@ -128,7 +128,7 @@ def test_url_cache_filepath(self): "/media_store/url_cache/2020-01-02/GerZNDnDZVjsOtar", ) - def test_url_cache_filepath_legacy(self): + def test_url_cache_filepath_legacy(self) -> None: """Test old-style URL cache paths""" self.assertEqual( self.filepaths.url_cache_filepath_rel("GerZNDnDZVjsOtardLuwfIBg"), @@ -139,7 +139,7 @@ def test_url_cache_filepath_legacy(self): "/media_store/url_cache/Ge/rZ/NDnDZVjsOtardLuwfIBg", ) - def test_url_cache_filepath_dirs_to_delete(self): + def test_url_cache_filepath_dirs_to_delete(self) -> None: """Test URL cache cleanup paths""" self.assertEqual( self.filepaths.url_cache_filepath_dirs_to_delete( @@ -148,7 +148,7 @@ def test_url_cache_filepath_dirs_to_delete(self): ["/media_store/url_cache/2020-01-02"], ) - def test_url_cache_filepath_dirs_to_delete_legacy(self): + def test_url_cache_filepath_dirs_to_delete_legacy(self) -> None: """Test old-style URL cache cleanup paths""" self.assertEqual( self.filepaths.url_cache_filepath_dirs_to_delete( @@ -160,7 +160,7 @@ def test_url_cache_filepath_dirs_to_delete_legacy(self): ], ) - def test_url_cache_thumbnail(self): + def test_url_cache_thumbnail(self) -> None: """Test URL cache thumbnail paths""" self.assertEqual( self.filepaths.url_cache_thumbnail_rel( @@ -175,7 +175,7 @@ def test_url_cache_thumbnail(self): "/media_store/url_cache_thumbnails/2020-01-02/GerZNDnDZVjsOtar/800-600-image-jpeg-scale", ) - def test_url_cache_thumbnail_legacy(self): + def test_url_cache_thumbnail_legacy(self) -> None: """Test old-style URL cache thumbnail paths""" self.assertEqual( self.filepaths.url_cache_thumbnail_rel( @@ -190,7 +190,7 @@ def test_url_cache_thumbnail_legacy(self): "/media_store/url_cache_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale", ) - def test_url_cache_thumbnail_directory(self): + def test_url_cache_thumbnail_directory(self) -> None: """Test URL cache thumbnail directory paths""" self.assertEqual( self.filepaths.url_cache_thumbnail_directory_rel( @@ -203,7 +203,7 @@ def test_url_cache_thumbnail_directory(self): "/media_store/url_cache_thumbnails/2020-01-02/GerZNDnDZVjsOtar", ) - def test_url_cache_thumbnail_directory_legacy(self): + def test_url_cache_thumbnail_directory_legacy(self) -> None: """Test old-style URL cache thumbnail directory paths""" self.assertEqual( self.filepaths.url_cache_thumbnail_directory_rel( @@ -216,7 +216,7 @@ def test_url_cache_thumbnail_directory_legacy(self): "/media_store/url_cache_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg", ) - def test_url_cache_thumbnail_dirs_to_delete(self): + def test_url_cache_thumbnail_dirs_to_delete(self) -> None: """Test URL cache thumbnail cleanup paths""" self.assertEqual( self.filepaths.url_cache_thumbnail_dirs_to_delete( @@ -228,7 +228,7 @@ def test_url_cache_thumbnail_dirs_to_delete(self): ], ) - def test_url_cache_thumbnail_dirs_to_delete_legacy(self): + def test_url_cache_thumbnail_dirs_to_delete_legacy(self) -> None: """Test old-style URL cache thumbnail cleanup paths""" self.assertEqual( 
self.filepaths.url_cache_thumbnail_dirs_to_delete( @@ -241,7 +241,7 @@ def test_url_cache_thumbnail_dirs_to_delete_legacy(self): ], ) - def test_server_name_validation(self): + def test_server_name_validation(self) -> None: """Test validation of server names""" self._test_path_validation( [ @@ -274,7 +274,7 @@ def test_server_name_validation(self): ], ) - def test_file_id_validation(self): + def test_file_id_validation(self) -> None: """Test validation of local, remote and legacy URL cache file / media IDs""" # File / media IDs get split into three parts to form paths, consisting of the # first two characters, next two characters and rest of the ID. @@ -357,7 +357,7 @@ def test_file_id_validation(self): invalid_values=invalid_file_ids, ) - def test_url_cache_media_id_validation(self): + def test_url_cache_media_id_validation(self) -> None: """Test validation of URL cache media IDs""" self._test_path_validation( [ @@ -387,7 +387,7 @@ def test_url_cache_media_id_validation(self): ], ) - def test_content_type_validation(self): + def test_content_type_validation(self) -> None: """Test validation of thumbnail content types""" self._test_path_validation( [ @@ -410,7 +410,7 @@ def test_content_type_validation(self): ], ) - def test_thumbnail_method_validation(self): + def test_thumbnail_method_validation(self) -> None: """Test validation of thumbnail methods""" self._test_path_validation( [ @@ -440,7 +440,7 @@ def _test_path_validation( parameter: str, valid_values: Iterable[str], invalid_values: Iterable[str], - ): + ) -> None: """Test that the specified methods validate the named parameter as expected Args: diff --git a/tests/rest/media/v1/test_html_preview.py b/tests/rest/media/v1/test_html_preview.py index a4b57e3d1feb..3fb37a2a5970 100644 --- a/tests/rest/media/v1/test_html_preview.py +++ b/tests/rest/media/v1/test_html_preview.py @@ -32,7 +32,7 @@ class SummarizeTestCase(unittest.TestCase): if not lxml: skip = "url preview feature requires lxml" - def test_long_summarize(self): + def test_long_summarize(self) -> None: example_paras = [ """Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami: Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in @@ -90,7 +90,7 @@ def test_long_summarize(self): " Tromsøya had a population of 36,088. 
Substantial parts of the urban…", ) - def test_short_summarize(self): + def test_short_summarize(self) -> None: example_paras = [ "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" @@ -117,7 +117,7 @@ def test_short_summarize(self): " most of the year.", ) - def test_small_then_large_summarize(self): + def test_small_then_large_summarize(self) -> None: example_paras = [ "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" @@ -150,7 +150,7 @@ class CalcOgTestCase(unittest.TestCase): if not lxml: skip = "url preview feature requires lxml" - def test_simple(self): + def test_simple(self) -> None: html = b""" Foo @@ -165,7 +165,7 @@ def test_simple(self): self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) - def test_comment(self): + def test_comment(self) -> None: html = b""" Foo @@ -181,7 +181,7 @@ def test_comment(self): self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) - def test_comment2(self): + def test_comment2(self) -> None: html = b""" Foo @@ -206,7 +206,7 @@ def test_comment2(self): }, ) - def test_script(self): + def test_script(self) -> None: html = b""" Foo @@ -222,7 +222,7 @@ def test_script(self): self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) - def test_missing_title(self): + def test_missing_title(self) -> None: html = b""" @@ -236,7 +236,7 @@ def test_missing_title(self): self.assertEqual(og, {"og:title": None, "og:description": "Some text."}) - def test_h1_as_title(self): + def test_h1_as_title(self) -> None: html = b""" @@ -251,7 +251,7 @@ def test_h1_as_title(self): self.assertEqual(og, {"og:title": "Title", "og:description": "Some text."}) - def test_missing_title_and_broken_h1(self): + def test_missing_title_and_broken_h1(self) -> None: html = b""" @@ -266,19 +266,19 @@ def test_missing_title_and_broken_h1(self): self.assertEqual(og, {"og:title": None, "og:description": "Some text."}) - def test_empty(self): + def test_empty(self) -> None: """Test a body with no data in it.""" html = b"" tree = decode_body(html, "http://example.com/test.html") self.assertIsNone(tree) - def test_no_tree(self): + def test_no_tree(self) -> None: """A valid body with no tree in it.""" html = b"\x00" tree = decode_body(html, "http://example.com/test.html") self.assertIsNone(tree) - def test_xml(self): + def test_xml(self) -> None: """Test decoding XML and ensure it works properly.""" # Note that the strip() call is important to ensure the xml tag starts # at the initial byte. @@ -293,7 +293,7 @@ def test_xml(self): og = parse_html_to_open_graph(tree, "http://example.com/test.html") self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) - def test_invalid_encoding(self): + def test_invalid_encoding(self) -> None: """An invalid character encoding should be ignored and treated as UTF-8, if possible.""" html = b""" @@ -307,7 +307,7 @@ def test_invalid_encoding(self): og = parse_html_to_open_graph(tree, "http://example.com/test.html") self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) - def test_invalid_encoding2(self): + def test_invalid_encoding2(self) -> None: """A body which doesn't match the sent character encoding.""" # Note that this contains an invalid UTF-8 sequence in the title. 
html = b""" @@ -322,7 +322,7 @@ def test_invalid_encoding2(self): og = parse_html_to_open_graph(tree, "http://example.com/test.html") self.assertEqual(og, {"og:title": "ÿÿ Foo", "og:description": "Some text."}) - def test_windows_1252(self): + def test_windows_1252(self) -> None: """A body which uses cp1252, but doesn't declare that.""" html = b""" @@ -338,7 +338,7 @@ def test_windows_1252(self): class MediaEncodingTestCase(unittest.TestCase): - def test_meta_charset(self): + def test_meta_charset(self) -> None: """A character encoding is found via the meta tag.""" encodings = _get_html_media_encodings( b""" @@ -363,7 +363,7 @@ def test_meta_charset(self): ) self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"]) - def test_meta_charset_underscores(self): + def test_meta_charset_underscores(self) -> None: """A character encoding contains underscore.""" encodings = _get_html_media_encodings( b""" @@ -376,7 +376,7 @@ def test_meta_charset_underscores(self): ) self.assertEqual(list(encodings), ["shift_jis", "utf-8", "cp1252"]) - def test_xml_encoding(self): + def test_xml_encoding(self) -> None: """A character encoding is found via the meta tag.""" encodings = _get_html_media_encodings( b""" @@ -388,7 +388,7 @@ def test_xml_encoding(self): ) self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"]) - def test_meta_xml_encoding(self): + def test_meta_xml_encoding(self) -> None: """Meta tags take precedence over XML encoding.""" encodings = _get_html_media_encodings( b""" @@ -402,7 +402,7 @@ def test_meta_xml_encoding(self): ) self.assertEqual(list(encodings), ["utf-16", "ascii", "utf-8", "cp1252"]) - def test_content_type(self): + def test_content_type(self) -> None: """A character encoding is found via the Content-Type header.""" # Test a few variations of the header. 
headers = ( @@ -417,12 +417,12 @@ def test_content_type(self): encodings = _get_html_media_encodings(b"", header) self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"]) - def test_fallback(self): + def test_fallback(self) -> None: """A character encoding cannot be found in the body or header.""" encodings = _get_html_media_encodings(b"", "text/html") self.assertEqual(list(encodings), ["utf-8", "cp1252"]) - def test_duplicates(self): + def test_duplicates(self) -> None: """Ensure each encoding is only attempted once.""" encodings = _get_html_media_encodings( b""" @@ -436,7 +436,7 @@ def test_duplicates(self): ) self.assertEqual(list(encodings), ["utf-8", "cp1252"]) - def test_unknown_invalid(self): + def test_unknown_invalid(self) -> None: """A character encoding should be ignored if it is unknown or invalid.""" encodings = _get_html_media_encodings( b""" @@ -451,7 +451,7 @@ def test_unknown_invalid(self): class RebaseUrlTestCase(unittest.TestCase): - def test_relative(self): + def test_relative(self) -> None: """Relative URLs should be resolved based on the context of the base URL.""" self.assertEqual( rebase_url("subpage", "https://example.com/foo/"), @@ -466,14 +466,14 @@ def test_relative(self): "https://example.com/bar", ) - def test_absolute(self): + def test_absolute(self) -> None: """Absolute URLs should not be modified.""" self.assertEqual( rebase_url("https://alice.com/a/", "https://example.com/foo/"), "https://alice.com/a/", ) - def test_data(self): + def test_data(self) -> None: """Data URLs should not be modified.""" self.assertEqual( rebase_url("data:,Hello%2C%20World%21", "https://example.com/foo/"), diff --git a/tests/rest/media/v1/test_oembed.py b/tests/rest/media/v1/test_oembed.py index 048d0ca44a95..f38d7225f8f6 100644 --- a/tests/rest/media/v1/test_oembed.py +++ b/tests/rest/media/v1/test_oembed.py @@ -16,7 +16,7 @@ from twisted.test.proto_helpers import MemoryReactor -from synapse.rest.media.v1.oembed import OEmbedProvider +from synapse.rest.media.v1.oembed import OEmbedProvider, OEmbedResult from synapse.server import HomeServer from synapse.types import JsonDict from synapse.util import Clock @@ -25,15 +25,15 @@ class OEmbedTests(HomeserverTestCase): - def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer): - self.oembed = OEmbedProvider(homeserver) + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.oembed = OEmbedProvider(hs) - def parse_response(self, response: JsonDict): + def parse_response(self, response: JsonDict) -> OEmbedResult: return self.oembed.parse_oembed_response( "https://test", json.dumps(response).encode("utf-8") ) - def test_version(self): + def test_version(self) -> None: """Accept versions that are similar to 1.0 as a string or int (or missing).""" for version in ("1.0", 1.0, 1): result = self.parse_response({"version": version, "type": "link"}) diff --git a/tests/rest/test_health.py b/tests/rest/test_health.py index 01d48c3860d4..da325955f86f 100644 --- a/tests/rest/test_health.py +++ b/tests/rest/test_health.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- +from http import HTTPStatus from synapse.rest.health import HealthResource @@ -19,12 +19,12 @@ class HealthCheckTests(unittest.HomeserverTestCase): - def create_test_resource(self): + def create_test_resource(self) -> HealthResource: # replace the JsonResource with a HealthResource. return HealthResource() - def test_health(self): + def test_health(self) -> None: channel = self.make_request("GET", "/health", shorthand=False) - self.assertEqual(channel.code, 200) + self.assertEqual(channel.code, HTTPStatus.OK) self.assertEqual(channel.result["body"], b"OK") diff --git a/tests/rest/test_well_known.py b/tests/rest/test_well_known.py index 118aa93a320d..11f78f52b87a 100644 --- a/tests/rest/test_well_known.py +++ b/tests/rest/test_well_known.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from http import HTTPStatus + from twisted.web.resource import Resource from synapse.rest.well_known import well_known_resource @@ -19,7 +21,7 @@ class WellKnownTests(unittest.HomeserverTestCase): - def create_test_resource(self): + def create_test_resource(self) -> Resource: # replace the JsonResource with a Resource wrapping the WellKnownResource res = Resource() res.putChild(b".well-known", well_known_resource(self.hs)) @@ -31,12 +33,12 @@ def create_test_resource(self): "default_identity_server": "https://testis", } ) - def test_client_well_known(self): + def test_client_well_known(self) -> None: channel = self.make_request( "GET", "/.well-known/matrix/client", shorthand=False ) - self.assertEqual(channel.code, 200) + self.assertEqual(channel.code, HTTPStatus.OK) self.assertEqual( channel.json_body, { @@ -50,27 +52,27 @@ def test_client_well_known(self): "public_baseurl": None, } ) - def test_client_well_known_no_public_baseurl(self): + def test_client_well_known_no_public_baseurl(self) -> None: channel = self.make_request( "GET", "/.well-known/matrix/client", shorthand=False ) - self.assertEqual(channel.code, 404) + self.assertEqual(channel.code, HTTPStatus.NOT_FOUND) @unittest.override_config({"serve_server_wellknown": True}) - def test_server_well_known(self): + def test_server_well_known(self) -> None: channel = self.make_request( "GET", "/.well-known/matrix/server", shorthand=False ) - self.assertEqual(channel.code, 200) + self.assertEqual(channel.code, HTTPStatus.OK) self.assertEqual( channel.json_body, {"m.server": "test:443"}, ) - def test_server_well_known_disabled(self): + def test_server_well_known_disabled(self) -> None: channel = self.make_request( "GET", "/.well-known/matrix/server", shorthand=False ) - self.assertEqual(channel.code, 404) + self.assertEqual(channel.code, HTTPStatus.NOT_FOUND) From 9297d040a72b70c7cc0ec15319afdd99b01ba885 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Thu, 3 Mar 2022 17:14:09 +0000 Subject: [PATCH 014/230] Detox, part 2 of N (#12152) I've argued in #11537 that poetry and tox don't cooperate well at the moment. (See also #12119.) Therefore I'm pruning away bits of tox to make the transition to poetry easier. This change removes the commands for coverage. We don't use coverage in anger at the moment. It shouldn't be too hard to add coverage as a dev-dependency and reintroduce this if we really want it. 
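For reference, reintroducing it later needn't go through tox at all. A rough sketch of what the removed environments did, driven through coverage.py's Python API instead — this is an assumption on my part, not part of this change, and it presumes `coverage` has been added as a dev-dependency and that `.coverage.*` data files already exist from earlier `coverage run` invocations:

    # Rough equivalents of the removed tox environments, via coverage.py's API.
    from coverage import Coverage

    cov = Coverage()
    cov.combine()      # [testenv:combine]: merge the per-run .coverage.* files
    cov.report()       # [testenv:combine]: print the textual coverage report
    cov.html_report()  # [testenv:cov-html]: write the HTML report to htmlcov/
    cov.erase()        # [testenv:cov-erase]: delete the collected data
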
--- changelog.d/12152.misc | 1 + tox.ini | 26 -------------------------- 2 files changed, 1 insertion(+), 26 deletions(-) create mode 100644 changelog.d/12152.misc diff --git a/changelog.d/12152.misc b/changelog.d/12152.misc new file mode 100644 index 000000000000..b9877eaccbee --- /dev/null +++ b/changelog.d/12152.misc @@ -0,0 +1 @@ +Prune unused jobs from `tox` config. \ No newline at end of file diff --git a/tox.ini b/tox.ini index f4829200cca6..04d282a705af 100644 --- a/tox.ini +++ b/tox.ini @@ -158,32 +158,6 @@ commands = extras = lint commands = isort -c --df {[base]lint_targets} -[testenv:combine] -skip_install = true -usedevelop = false -deps = - coverage - pip>=10 -commands= - coverage combine - coverage report - -[testenv:cov-erase] -skip_install = true -usedevelop = false -deps = - coverage -commands= - coverage erase - -[testenv:cov-html] -skip_install = true -usedevelop = false -deps = - coverage -commands= - coverage html - [testenv:mypy] deps = {[base]deps} From fb0ffa96766a4b6f298f53af2d212e4c4d09d9e9 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Thu, 3 Mar 2022 18:14:09 +0000 Subject: [PATCH 015/230] Rename various ApplicationServices interested methods (#11915) --- changelog.d/11915.misc | 1 + synapse/appservice/__init__.py | 133 +++++++++++++++++++--------- synapse/handlers/appservice.py | 4 +- synapse/handlers/directory.py | 6 +- synapse/handlers/receipts.py | 2 +- synapse/handlers/typing.py | 4 +- tests/appservice/test_appservice.py | 45 +++++++--- tests/handlers/test_appservice.py | 56 +++++++++--- 8 files changed, 175 insertions(+), 76 deletions(-) create mode 100644 changelog.d/11915.misc diff --git a/changelog.d/11915.misc b/changelog.d/11915.misc new file mode 100644 index 000000000000..e3cef1511eb6 --- /dev/null +++ b/changelog.d/11915.misc @@ -0,0 +1 @@ +Simplify the `ApplicationService` class' set of public methods related to interest checking. \ No newline at end of file diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index 4d3f8e492384..07ec95f1d67e 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -175,27 +175,14 @@ def _is_exclusive(self, namespace_key: str, test_string: str) -> bool: return namespace.exclusive return False - async def _matches_user(self, event: EventBase, store: "DataStore") -> bool: - if self.is_interested_in_user(event.sender): - return True - - # also check m.room.member state key - if event.type == EventTypes.Member and self.is_interested_in_user( - event.state_key - ): - return True - - does_match = await self.matches_user_in_member_list(event.room_id, store) - return does_match - @cached(num_args=1, cache_context=True) - async def matches_user_in_member_list( + async def _matches_user_in_member_list( self, room_id: str, store: "DataStore", cache_context: _CacheContext, ) -> bool: - """Check if this service is interested a room based upon it's membership + """Check if this service is interested a room based upon its membership Args: room_id: The room to check. @@ -214,47 +201,110 @@ async def matches_user_in_member_list( return True return False - def _matches_room_id(self, event: EventBase) -> bool: - if hasattr(event, "room_id"): - return self.is_interested_in_room(event.room_id) - return False + def is_interested_in_user( + self, + user_id: str, + ) -> bool: + """ + Returns whether the application is interested in a given user ID. 
+ + The appservice is considered to be interested in a user if either: the + user ID is in the appservice's user namespace, or if the user is the + appservice's configured sender_localpart. + + Args: + user_id: The ID of the user to check. + + Returns: + True if the application service is interested in the user, False if not. + """ + return ( + # User is the appservice's sender_localpart user + user_id == self.sender + # User is in the appservice's user namespace + or self.is_user_in_namespace(user_id) + ) + + @cached(num_args=1, cache_context=True) + async def is_interested_in_room( + self, + room_id: str, + store: "DataStore", + cache_context: _CacheContext, + ) -> bool: + """ + Returns whether the application service is interested in a given room ID. + + The appservice is considered to be interested in the room if either: the ID or one + of the aliases of the room is in the appservice's room ID or alias namespace + respectively, or if one of the members of the room fall into the appservice's user + namespace. - async def _matches_aliases(self, event: EventBase, store: "DataStore") -> bool: - alias_list = await store.get_aliases_for_room(event.room_id) + Args: + room_id: The ID of the room to check. + store: The homeserver's datastore class. + + Returns: + True if the application service is interested in the room, False if not. + """ + # Check if we have interest in this room ID + if self.is_room_id_in_namespace(room_id): + return True + + # likewise with the room's aliases (if it has any) + alias_list = await store.get_aliases_for_room(room_id) for alias in alias_list: - if self.is_interested_in_alias(alias): + if self.is_room_alias_in_namespace(alias): return True - return False + # And finally, perform an expensive check on whether any of the + # users in the room match the appservice's user namespace + return await self._matches_user_in_member_list( + room_id, store, on_invalidate=cache_context.invalidate + ) - async def is_interested(self, event: EventBase, store: "DataStore") -> bool: + @cached(num_args=1, cache_context=True) + async def is_interested_in_event( + self, + event_id: str, + event: EventBase, + store: "DataStore", + cache_context: _CacheContext, + ) -> bool: """Check if this service is interested in this event. Args: + event_id: The ID of the event to check. This is purely used for simplifying the + caching of calls to this method. event: The event to check. store: The datastore to query. Returns: - True if this service would like to know about this event. + True if this service would like to know about this event, otherwise False. 
""" - # Do cheap checks first - if self._matches_room_id(event): + # Check if we're interested in this event's sender by namespace (or if they're the + # sender_localpart user) + if self.is_interested_in_user(event.sender): return True - # This will check the namespaces first before - # checking the store, so should be run before _matches_aliases - if await self._matches_user(event, store): + # additionally, if this is a membership event, perform the same checks on + # the user it references + if event.type == EventTypes.Member and self.is_interested_in_user( + event.state_key + ): return True - # This will check the store, so should be run last - if await self._matches_aliases(event, store): + # This will check the datastore, so should be run last + if await self.is_interested_in_room( + event.room_id, store, on_invalidate=cache_context.invalidate + ): return True return False - @cached(num_args=1) + @cached(num_args=1, cache_context=True) async def is_interested_in_presence( - self, user_id: UserID, store: "DataStore" + self, user_id: UserID, store: "DataStore", cache_context: _CacheContext ) -> bool: """Check if this service is interested a user's presence @@ -272,20 +322,19 @@ async def is_interested_in_presence( # Then find out if the appservice is interested in any of those rooms for room_id in room_ids: - if await self.matches_user_in_member_list(room_id, store): + if await self.is_interested_in_room( + room_id, store, on_invalidate=cache_context.invalidate + ): return True return False - def is_interested_in_user(self, user_id: str) -> bool: - return ( - bool(self._matches_regex(ApplicationService.NS_USERS, user_id)) - or user_id == self.sender - ) + def is_user_in_namespace(self, user_id: str) -> bool: + return bool(self._matches_regex(ApplicationService.NS_USERS, user_id)) - def is_interested_in_alias(self, alias: str) -> bool: + def is_room_alias_in_namespace(self, alias: str) -> bool: return bool(self._matches_regex(ApplicationService.NS_ALIASES, alias)) - def is_interested_in_room(self, room_id: str) -> bool: + def is_room_id_in_namespace(self, room_id: str) -> bool: return bool(self._matches_regex(ApplicationService.NS_ROOMS, room_id)) def is_exclusive_user(self, user_id: str) -> bool: diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index e6461cc3c980..bd913e524e7b 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -571,7 +571,7 @@ async def query_room_alias_exists( room_alias_str = room_alias.to_string() services = self.store.get_app_services() alias_query_services = [ - s for s in services if (s.is_interested_in_alias(room_alias_str)) + s for s in services if (s.is_room_alias_in_namespace(room_alias_str)) ] for alias_service in alias_query_services: is_known_alias = await self.appservice_api.query_alias( @@ -660,7 +660,7 @@ async def _get_services_for_event( # inside of a list comprehension anymore. 
interested_list = [] for s in services: - if await s.is_interested(event, self.store): + if await s.is_interested_in_event(event.event_id, event, self.store): interested_list.append(s) return interested_list diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index b7064c6624b7..33d827a45b33 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -119,7 +119,7 @@ async def create_association( service = requester.app_service if service: - if not service.is_interested_in_alias(room_alias_str): + if not service.is_room_alias_in_namespace(room_alias_str): raise SynapseError( 400, "This application service has not reserved this kind of alias.", @@ -221,7 +221,7 @@ async def delete_association( async def delete_appservice_association( self, service: ApplicationService, room_alias: RoomAlias ) -> None: - if not service.is_interested_in_alias(room_alias.to_string()): + if not service.is_room_alias_in_namespace(room_alias.to_string()): raise SynapseError( 400, "This application service has not reserved this kind of alias", @@ -376,7 +376,7 @@ def can_modify_alias(self, alias: RoomAlias, user_id: Optional[str] = None) -> b # non-exclusive locks on the alias (or there are no interested services) services = self.store.get_app_services() interested_services = [ - s for s in services if s.is_interested_in_alias(alias.to_string()) + s for s in services if s.is_room_alias_in_namespace(alias.to_string()) ] for service in interested_services: diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index b4132c353ae2..6250bb3bdf2b 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -269,7 +269,7 @@ async def get_new_events_as( # Then filter down to rooms that the AS can read events = [] for room_id, event in rooms_to_events.items(): - if not await service.matches_user_in_member_list(room_id, self.store): + if not await service.is_interested_in_room(room_id, self.store): continue events.append(event) diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 843c68eb0fdf..3b8912652856 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -486,9 +486,7 @@ async def get_new_events_as( if handler._room_serials[room_id] <= from_key: continue - if not await service.matches_user_in_member_list( - room_id, self._main_store - ): + if not await service.is_interested_in_room(room_id, self._main_store): continue events.append(self._make_event_for(room_id)) diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py index 9bd6275e92db..edc584d0cf50 100644 --- a/tests/appservice/test_appservice.py +++ b/tests/appservice/test_appservice.py @@ -36,7 +36,10 @@ def setUp(self): hostname="matrix.org", # only used by get_groups_for_user ) self.event = Mock( - type="m.something", room_id="!foo:bar", sender="@someone:somewhere" + event_id="$abc:xyz", + type="m.something", + room_id="!foo:bar", + sender="@someone:somewhere", ) self.store = Mock() @@ -50,7 +53,9 @@ def test_regex_user_id_prefix_match(self): self.assertTrue( ( yield defer.ensureDeferred( - self.service.is_interested(self.event, self.store) + self.service.is_interested_in_event( + self.event.event_id, self.event, self.store + ) ) ) ) @@ -62,7 +67,9 @@ def test_regex_user_id_prefix_no_match(self): self.assertFalse( ( yield defer.ensureDeferred( - self.service.is_interested(self.event, self.store) + self.service.is_interested_in_event( + self.event.event_id, self.event, self.store + ) ) ) ) @@ -76,7 +83,9 @@ def 
test_regex_room_member_is_checked(self): self.assertTrue( ( yield defer.ensureDeferred( - self.service.is_interested(self.event, self.store) + self.service.is_interested_in_event( + self.event.event_id, self.event, self.store + ) ) ) ) @@ -90,7 +99,9 @@ def test_regex_room_id_match(self): self.assertTrue( ( yield defer.ensureDeferred( - self.service.is_interested(self.event, self.store) + self.service.is_interested_in_event( + self.event.event_id, self.event, self.store + ) ) ) ) @@ -104,7 +115,9 @@ def test_regex_room_id_no_match(self): self.assertFalse( ( yield defer.ensureDeferred( - self.service.is_interested(self.event, self.store) + self.service.is_interested_in_event( + self.event.event_id, self.event, self.store + ) ) ) ) @@ -121,7 +134,9 @@ def test_regex_alias_match(self): self.assertTrue( ( yield defer.ensureDeferred( - self.service.is_interested(self.event, self.store) + self.service.is_interested_in_event( + self.event.event_id, self.event, self.store + ) ) ) ) @@ -174,7 +189,9 @@ def test_regex_alias_no_match(self): self.assertFalse( ( yield defer.ensureDeferred( - self.service.is_interested(self.event, self.store) + self.service.is_interested_in_event( + self.event.event_id, self.event, self.store + ) ) ) ) @@ -191,7 +208,9 @@ def test_regex_multiple_matches(self): self.assertTrue( ( yield defer.ensureDeferred( - self.service.is_interested(self.event, self.store) + self.service.is_interested_in_event( + self.event.event_id, self.event, self.store + ) ) ) ) @@ -207,7 +226,9 @@ def test_interested_in_self(self): self.assertTrue( ( yield defer.ensureDeferred( - self.service.is_interested(self.event, self.store) + self.service.is_interested_in_event( + self.event.event_id, self.event, self.store + ) ) ) ) @@ -225,7 +246,9 @@ def test_member_list_match(self): self.assertTrue( ( yield defer.ensureDeferred( - self.service.is_interested(event=self.event, store=self.store) + self.service.is_interested_in_event( + self.event.event_id, self.event, self.store + ) ) ) ) diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index 072e6bbcdd6e..cead9f90df56 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -59,11 +59,11 @@ def setUp(self): self.event_source = hs.get_event_sources() def test_notify_interested_services(self): - interested_service = self._mkservice(is_interested=True) + interested_service = self._mkservice(is_interested_in_event=True) services = [ - self._mkservice(is_interested=False), + self._mkservice(is_interested_in_event=False), interested_service, - self._mkservice(is_interested=False), + self._mkservice(is_interested_in_event=False), ] self.mock_as_api.query_user.return_value = make_awaitable(True) @@ -85,7 +85,7 @@ def test_notify_interested_services(self): def test_query_user_exists_unknown_user(self): user_id = "@someone:anywhere" - services = [self._mkservice(is_interested=True)] + services = [self._mkservice(is_interested_in_event=True)] services[0].is_interested_in_user.return_value = True self.mock_store.get_app_services.return_value = services self.mock_store.get_user_by_id.return_value = make_awaitable(None) @@ -102,7 +102,7 @@ def test_query_user_exists_unknown_user(self): def test_query_user_exists_known_user(self): user_id = "@someone:anywhere" - services = [self._mkservice(is_interested=True)] + services = [self._mkservice(is_interested_in_event=True)] services[0].is_interested_in_user.return_value = True self.mock_store.get_app_services.return_value = services 
self.mock_store.get_user_by_id.return_value = make_awaitable({"name": user_id}) @@ -127,11 +127,11 @@ def test_query_room_alias_exists(self): room_id = "!alpha:bet" servers = ["aperture"] - interested_service = self._mkservice_alias(is_interested_in_alias=True) + interested_service = self._mkservice_alias(is_room_alias_in_namespace=True) services = [ - self._mkservice_alias(is_interested_in_alias=False), + self._mkservice_alias(is_room_alias_in_namespace=False), interested_service, - self._mkservice_alias(is_interested_in_alias=False), + self._mkservice_alias(is_room_alias_in_namespace=False), ] self.mock_as_api.query_alias.return_value = make_awaitable(True) @@ -275,7 +275,7 @@ def test_notify_interested_services_ephemeral(self): to be pushed out to interested appservices, and that the stream ID is updated accordingly. """ - interested_service = self._mkservice(is_interested=True) + interested_service = self._mkservice(is_interested_in_event=True) services = [interested_service] self.mock_store.get_app_services.return_value = services self.mock_store.get_type_stream_id_for_appservice.return_value = make_awaitable( @@ -304,7 +304,7 @@ def test_notify_interested_services_ephemeral_out_of_order(self): Test sending out of order ephemeral events to the appservice handler are ignored. """ - interested_service = self._mkservice(is_interested=True) + interested_service = self._mkservice(is_interested_in_event=True) services = [interested_service] self.mock_store.get_app_services.return_value = services @@ -325,17 +325,45 @@ def test_notify_interested_services_ephemeral_out_of_order(self): interested_service, ephemeral=[] ) - def _mkservice(self, is_interested, protocols=None): + def _mkservice( + self, is_interested_in_event: bool, protocols: Optional[Iterable] = None + ) -> Mock: + """ + Create a new mock representing an ApplicationService. + + Args: + is_interested_in_event: Whether this application service will be considered + interested in all events. + protocols: The third-party protocols that this application service claims to + support. + + Returns: + A mock representing the ApplicationService. + """ service = Mock() - service.is_interested.return_value = make_awaitable(is_interested) + service.is_interested_in_event.return_value = make_awaitable( + is_interested_in_event + ) service.token = "mock_service_token" service.url = "mock_service_url" service.protocols = protocols return service - def _mkservice_alias(self, is_interested_in_alias): + def _mkservice_alias(self, is_room_alias_in_namespace: bool) -> Mock: + """ + Create a new mock representing an ApplicationService that is or is not interested + any given room aliase. + + Args: + is_room_alias_in_namespace: If true, the application service will be interested + in all room aliases that are queried against it. If false, the application + service will not be interested in any room aliases. + + Returns: + A mock representing the ApplicationService. 
+ """ service = Mock() - service.is_interested_in_alias.return_value = is_interested_in_alias + service.is_room_alias_in_namespace.return_value = is_room_alias_in_namespace service.token = "mock_service_token" service.url = "mock_service_url" return service From 8533c8b03d8916e3805c7d0e0020226017680147 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Thu, 3 Mar 2022 19:58:08 +0000 Subject: [PATCH 016/230] Avoid generating state groups for local out-of-band leaves (#12154) If we locally generate a rejection for an invite received over federation, it is stored as an outlier (because we probably don't have the state for the room). However, currently we still generate a state group for it (even though the state in that state group will be nonsense). By setting the `outlier` param on `create_event`, we avoid the nonsensical state. --- changelog.d/12154.misc | 1 + synapse/handlers/room_member.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 changelog.d/12154.misc diff --git a/changelog.d/12154.misc b/changelog.d/12154.misc new file mode 100644 index 000000000000..18d2a4728be9 --- /dev/null +++ b/changelog.d/12154.misc @@ -0,0 +1 @@ +Avoid generating state groups for local out-of-band leaves. diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index a582837cf0aa..7cbc484b0654 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -1736,8 +1736,8 @@ async def _generate_local_out_of_band_leave( txn_id=txn_id, prev_event_ids=prev_event_ids, auth_event_ids=auth_event_ids, + outlier=True, ) - event.internal_metadata.outlier = True event.internal_metadata.out_of_band_membership = True result_event = await self.event_creation_handler.handle_new_client_event( From d56202b0383627fdb4e04404d62922dce16868f8 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 4 Mar 2022 10:25:18 +0000 Subject: [PATCH 017/230] Fix type of `events` in `StateGroupStorage` and `StateHandler` (#12156) We make multiple passes over this, so a regular iterable won't do. --- changelog.d/12156.misc | 1 + synapse/state/__init__.py | 6 +++--- synapse/storage/state.py | 8 ++++---- 3 files changed, 8 insertions(+), 7 deletions(-) create mode 100644 changelog.d/12156.misc diff --git a/changelog.d/12156.misc b/changelog.d/12156.misc new file mode 100644 index 000000000000..4818d988d771 --- /dev/null +++ b/changelog.d/12156.misc @@ -0,0 +1 @@ +Fix some type annotations. 
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 6babd5963cc1..21888cc8c561 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -194,7 +194,7 @@ async def get_current_state( } async def get_current_state_ids( - self, room_id: str, latest_event_ids: Optional[Iterable[str]] = None + self, room_id: str, latest_event_ids: Optional[Collection[str]] = None ) -> StateMap[str]: """Get the current state, or the state at a set of events, for a room @@ -243,7 +243,7 @@ async def get_current_hosts_in_room(self, room_id: str) -> Set[str]: return await self.get_hosts_in_room_at_events(room_id, event_ids) async def get_hosts_in_room_at_events( - self, room_id: str, event_ids: Iterable[str] + self, room_id: str, event_ids: Collection[str] ) -> Set[str]: """Get the hosts that were in a room at the given event ids @@ -404,7 +404,7 @@ async def compute_event_context( @measure_func() async def resolve_state_groups_for_events( - self, room_id: str, event_ids: Iterable[str] + self, room_id: str, event_ids: Collection[str] ) -> _StateCacheEntry: """Given a list of event_ids this method fetches the state at each event, resolves conflicts between them and returns them. diff --git a/synapse/storage/state.py b/synapse/storage/state.py index e79ecf64a0ec..86f1a5373bad 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -561,7 +561,7 @@ async def get_state_group_delta( return state_group_delta.prev_group, state_group_delta.delta_ids async def get_state_groups_ids( - self, _room_id: str, event_ids: Iterable[str] + self, _room_id: str, event_ids: Collection[str] ) -> Dict[int, MutableStateMap[str]]: """Get the event IDs of all the state for the state groups for the given events @@ -596,7 +596,7 @@ async def get_state_ids_for_group(self, state_group: int) -> StateMap[str]: return group_to_state[state_group] async def get_state_groups( - self, room_id: str, event_ids: Iterable[str] + self, room_id: str, event_ids: Collection[str] ) -> Dict[int, List[EventBase]]: """Get the state groups for the given list of event_ids @@ -648,7 +648,7 @@ def _get_state_groups_from_groups( return self.stores.state._get_state_groups_from_groups(groups, state_filter) async def get_state_for_events( - self, event_ids: Iterable[str], state_filter: Optional[StateFilter] = None + self, event_ids: Collection[str], state_filter: Optional[StateFilter] = None ) -> Dict[str, StateMap[EventBase]]: """Given a list of event_ids and type tuples, return a list of state dicts for each event. @@ -684,7 +684,7 @@ async def get_state_for_events( return {event: event_to_state[event] for event in event_ids} async def get_state_ids_for_events( - self, event_ids: Iterable[str], state_filter: Optional[StateFilter] = None + self, event_ids: Collection[str], state_filter: Optional[StateFilter] = None ) -> Dict[str, StateMap[str]]: """ Get the state dicts corresponding to a list of events, containing the event_ids From 87c230c27cdeb7e421f61f1271a500c760f1f63b Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 4 Mar 2022 10:31:19 +0000 Subject: [PATCH 018/230] Update client-visibility filtering for outlier events (#12155) Avoid trying to get the state for outliers, which isn't a sensible thing to do. 
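In other words, the rule introduced here is: clients never see outlier events (we have no room state to check them against), with one exception — out-of-band membership events about the requesting user themselves, such as an invite or a locally-generated rejection of one. A standalone sketch of that rule, using simplified stand-ins rather than real Synapse types (the actual change is to `filter_events_for_client`, below):

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class SketchEvent:
        # Minimal stand-in for an EventBase, just for this illustration.
        type: str
        state_key: Optional[str]
        outlier: bool

    def visible_to(event: SketchEvent, user_id: str) -> bool:
        if not event.outlier:
            return True  # non-outliers go through the normal state-based checks
        # Outliers are hidden, except out-of-band membership events
        # (invites, or rejections of said invites) about the user themselves.
        return event.type == "m.room.member" and event.state_key == user_id

    # An out-of-band invite rejection stays visible to the invited user...
    assert visible_to(SketchEvent("m.room.member", "@user:test", True), "@user:test")
    # ...but other users, and other kinds of outlier event, see nothing.
    assert not visible_to(SketchEvent("m.room.member", "@user:test", True), "@other:test")
    assert not visible_to(SketchEvent("m.room.message", None, True), "@user:test")
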
--- changelog.d/12155.misc | 1 + synapse/visibility.py | 17 ++++++++- tests/test_visibility.py | 76 ++++++++++++++++++++++++++++++++++++++-- 3 files changed, 90 insertions(+), 4 deletions(-) create mode 100644 changelog.d/12155.misc diff --git a/changelog.d/12155.misc b/changelog.d/12155.misc new file mode 100644 index 000000000000..9f333e718a86 --- /dev/null +++ b/changelog.d/12155.misc @@ -0,0 +1 @@ +Avoid trying to calculate the state at outlier events. diff --git a/synapse/visibility.py b/synapse/visibility.py index 1b970ce479d0..281cbe4d8877 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -81,8 +81,9 @@ async def filter_events_for_client( types = ((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, user_id)) + # we exclude outliers at this point, and then handle them separately later event_id_to_state = await storage.state.get_state_for_events( - frozenset(e.event_id for e in events), + frozenset(e.event_id for e in events if not e.internal_metadata.outlier), state_filter=StateFilter.from_types(types), ) @@ -154,6 +155,17 @@ def allowed(event: EventBase) -> Optional[EventBase]: if event.event_id in always_include_ids: return event + # we need to handle outliers separately, since we don't have the room state. + if event.internal_metadata.outlier: + # Normally these can't be seen by clients, but we make an exception for + # for out-of-band membership events (eg, incoming invites, or rejections of + # said invite) for the user themselves. + if event.type == EventTypes.Member and event.state_key == user_id: + logger.debug("Returning out-of-band-membership event %s", event) + return event + + return None + state = event_id_to_state[event.event_id] # get the room_visibility at the time of the event. @@ -198,6 +210,9 @@ def allowed(event: EventBase) -> Optional[EventBase]: # Always allow the user to see their own leave events, otherwise # they won't see the room disappear if they reject the invite + # + # (Note this doesn't work for out-of-band invite rejections, which don't + # have prev_state populated. They are handled above in the outlier code.) if membership == "leave" and ( prev_membership == "join" or prev_membership == "invite" ): diff --git a/tests/test_visibility.py b/tests/test_visibility.py index 219b5660b117..532e3fe9cd92 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -13,11 +13,12 @@ # limitations under the License. import logging from typing import Optional +from unittest.mock import patch from synapse.api.room_versions import RoomVersions -from synapse.events import EventBase -from synapse.types import JsonDict -from synapse.visibility import filter_events_for_server +from synapse.events import EventBase, make_event_from_dict +from synapse.types import JsonDict, create_requester +from synapse.visibility import filter_events_for_client, filter_events_for_server from tests import unittest from tests.utils import create_room @@ -185,3 +186,72 @@ def _inject_message( self.get_success(self.storage.persistence.persist_event(event, context)) return event + + +class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase): + def test_out_of_band_invite_rejection(self): + # this is where we have received an invite event over federation, and then + # rejected it. 
+ invite_pdu = { + "room_id": "!room:id", + "depth": 1, + "auth_events": [], + "prev_events": [], + "origin_server_ts": 1, + "sender": "@someone:" + self.OTHER_SERVER_NAME, + "type": "m.room.member", + "state_key": "@user:test", + "content": {"membership": "invite"}, + } + self.add_hashes_and_signatures(invite_pdu) + invite_event_id = make_event_from_dict(invite_pdu, RoomVersions.V9).event_id + + self.get_success( + self.hs.get_federation_server().on_invite_request( + self.OTHER_SERVER_NAME, + invite_pdu, + "9", + ) + ) + + # stub out do_remotely_reject_invite so that we fall back to a locally- + # generated rejection + with patch.object( + self.hs.get_federation_handler(), + "do_remotely_reject_invite", + side_effect=Exception(), + ): + reject_event_id, _ = self.get_success( + self.hs.get_room_member_handler().remote_reject_invite( + invite_event_id, + txn_id=None, + requester=create_requester("@user:test"), + content={}, + ) + ) + + invite_event, reject_event = self.get_success( + self.hs.get_datastores().main.get_events_as_list( + [invite_event_id, reject_event_id] + ) + ) + + # the invited user should be able to see both the invite and the rejection + self.assertEqual( + self.get_success( + filter_events_for_client( + self.hs.get_storage(), "@user:test", [invite_event, reject_event] + ) + ), + [invite_event, reject_event], + ) + + # other users should see neither + self.assertEqual( + self.get_success( + filter_events_for_client( + self.hs.get_storage(), "@other:test", [invite_event, reject_event] + ) + ), + [], + ) From 423cca9efe06d78aaca5f62fb74ee7e5bceebe49 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 4 Mar 2022 11:48:15 +0000 Subject: [PATCH 019/230] Spread out sending device lists to remote hosts (#12132) --- changelog.d/12132.feature | 1 + synapse/federation/send_queue.py | 2 +- synapse/federation/sender/__init__.py | 26 ++++++---- .../sender/per_destination_queue.py | 10 ++++ synapse/handlers/device.py | 2 +- synapse/replication/tcp/client.py | 2 +- tests/federation/test_federation_sender.py | 52 +++++++++++++++++-- 7 files changed, 79 insertions(+), 16 deletions(-) create mode 100644 changelog.d/12132.feature diff --git a/changelog.d/12132.feature b/changelog.d/12132.feature new file mode 100644 index 000000000000..3b8362ad35ed --- /dev/null +++ b/changelog.d/12132.feature @@ -0,0 +1 @@ +Improve performance of logging in for large accounts. diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index 0d7c4f506758..d720b5fd3fe2 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -244,7 +244,7 @@ def send_presence_to_destinations( self.notifier.on_new_replication_data() - def send_device_messages(self, destination: str) -> None: + def send_device_messages(self, destination: str, immediate: bool = False) -> None: """As per FederationSender""" # We don't need to replicate this as it gets sent down a different # stream. diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 6106a486d10a..30e2421efc6d 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -118,7 +118,12 @@ def build_and_send_edu( raise NotImplementedError() @abc.abstractmethod - def send_device_messages(self, destination: str) -> None: + def send_device_messages(self, destination: str, immediate: bool = True) -> None: + """Tells the sender that a new device message is ready to be sent to the + destination. 
The `immediate` flag specifies whether the messages should + be tried to be sent immediately, or whether it can be delayed for a + short while (to aid performance). + """ raise NotImplementedError() @abc.abstractmethod @@ -146,9 +151,8 @@ async def get_replication_rows( @attr.s -class _PresenceQueue: - """A queue of destinations that need to be woken up due to new presence - updates. +class _DestinationWakeupQueue: + """A queue of destinations that need to be woken up due to new updates. Staggers waking up of per destination queues to ensure that we don't attempt to start TLS connections with many hosts all at once, leading to pinned CPU. @@ -175,7 +179,7 @@ def add_to_queue(self, destination: str) -> None: if not self.processing: self._handle() - @wrap_as_background_process("_PresenceQueue.handle") + @wrap_as_background_process("_DestinationWakeupQueue.handle") async def _handle(self) -> None: """Background process to drain the queue.""" @@ -297,7 +301,7 @@ def __init__(self, hs: "HomeServer"): self._external_cache = hs.get_external_cache() - self._presence_queue = _PresenceQueue(self, self.clock) + self._destination_wakeup_queue = _DestinationWakeupQueue(self, self.clock) def _get_per_destination_queue(self, destination: str) -> PerDestinationQueue: """Get or create a PerDestinationQueue for the given destination @@ -614,7 +618,7 @@ def send_presence_to_destinations( states, start_loop=False ) - self._presence_queue.add_to_queue(destination) + self._destination_wakeup_queue.add_to_queue(destination) def build_and_send_edu( self, @@ -667,7 +671,7 @@ def send_edu(self, edu: Edu, key: Optional[Hashable]) -> None: else: queue.send_edu(edu) - def send_device_messages(self, destination: str) -> None: + def send_device_messages(self, destination: str, immediate: bool = False) -> None: if destination == self.server_name: logger.warning("Not sending device update to ourselves") return @@ -677,7 +681,11 @@ def send_device_messages(self, destination: str) -> None: ): return - self._get_per_destination_queue(destination).attempt_new_transaction() + if immediate: + self._get_per_destination_queue(destination).attempt_new_transaction() + else: + self._get_per_destination_queue(destination).mark_new_data() + self._destination_wakeup_queue.add_to_queue(destination) def wake_destination(self, destination: str) -> None: """Called when we want to retry sending transactions to a remote. diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index c8768f22bc6b..d80f0ac5e8c3 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -219,6 +219,16 @@ def send_edu(self, edu: Edu) -> None: self._pending_edus.append(edu) self.attempt_new_transaction() + def mark_new_data(self) -> None: + """Marks that the destination has new data to send, without starting a + new transaction. + + If a transaction loop is already in progress then a new transcation will + be attempted when the current one finishes. 
+ """ + + self._new_data_to_send = True + def attempt_new_transaction(self) -> None: """Try to start a new transaction to this destination diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 934b5bd7349b..d90cb259a65c 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -506,7 +506,7 @@ async def notify_device_update( "Sending device list update notif for %r to: %r", user_id, hosts ) for host in hosts: - self.federation_sender.send_device_messages(host) + self.federation_sender.send_device_messages(host, immediate=False) log_kv({"message": "sent device update to host", "host": host}) async def notify_user_signature_update( diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 1b8479b0b4ec..b8fc1d4db95e 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -380,7 +380,7 @@ async def process_replication_rows( # changes. hosts = {row.entity for row in rows if not row.entity.startswith("@")} for host in hosts: - self.federation_sender.send_device_messages(host) + self.federation_sender.send_device_messages(host, immediate=False) elif stream_name == ToDeviceStream.NAME: # The to_device stream includes stuff to be pushed to both local diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index 60e0c31f4384..e90592855ad9 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -201,9 +201,12 @@ def test_send_device_updates(self): self.assertEqual(len(self.edus), 1) stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D1", None) + # We queue up device list updates to be sent over federation, so we + # advance to clear the queue. + self.reactor.advance(1) + # a second call should produce no new device EDUs self.hs.get_federation_sender().send_device_messages("host2") - self.pump() self.assertEqual(self.edus, []) # a second device @@ -232,6 +235,10 @@ def test_upload_signatures(self): device1_signing_key = self.generate_and_upload_device_signing_key(u1, "D1") device2_signing_key = self.generate_and_upload_device_signing_key(u1, "D2") + # We queue up device list updates to be sent over federation, so we + # advance to clear the queue. + self.reactor.advance(1) + # expect two more edus self.assertEqual(len(self.edus), 2) stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D1", stream_id) @@ -265,6 +272,10 @@ def test_upload_signatures(self): e2e_handler.upload_signing_keys_for_user(u1, cross_signing_keys) ) + # We queue up device list updates to be sent over federation, so we + # advance to clear the queue. + self.reactor.advance(1) + # expect signing key update edu self.assertEqual(len(self.edus), 2) self.assertEqual(self.edus.pop(0)["edu_type"], "m.signing_key_update") @@ -284,6 +295,10 @@ def test_upload_signatures(self): ) self.assertEqual(ret["failures"], {}) + # We queue up device list updates to be sent over federation, so we + # advance to clear the queue. + self.reactor.advance(1) + # expect two edus, in one or two transactions. We don't know what order the # devices will be updated. self.assertEqual(len(self.edus), 2) @@ -307,6 +322,10 @@ def test_delete_devices(self): self.login("user", "pass", device_id="D2") self.login("user", "pass", device_id="D3") + # We queue up device list updates to be sent over federation, so we + # advance to clear the queue. 
+ self.reactor.advance(1) + # expect three edus self.assertEqual(len(self.edus), 3) stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D1", None) @@ -318,6 +337,10 @@ def test_delete_devices(self): self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"]) ) + # We queue up device list updates to be sent over federation, so we + # advance to clear the queue. + self.reactor.advance(1) + # expect three edus, in an unknown order self.assertEqual(len(self.edus), 3) for edu in self.edus: @@ -350,12 +373,19 @@ def test_unreachable_server(self): self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"]) ) + # We queue up device list updates to be sent over federation, so we + # advance to clear the queue. + self.reactor.advance(1) + self.assertGreaterEqual(mock_send_txn.call_count, 4) # recover the server mock_send_txn.side_effect = self.record_transaction self.hs.get_federation_sender().send_device_messages("host2") - self.pump() + + # We queue up device list updates to be sent over federation, so we + # advance to clear the queue. + self.reactor.advance(1) # for each device, there should be a single update self.assertEqual(len(self.edus), 3) @@ -390,6 +420,10 @@ def test_prune_outbound_device_pokes1(self): self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"]) ) + # We queue up device list updates to be sent over federation, so we + # advance to clear the queue. + self.reactor.advance(1) + self.assertGreaterEqual(mock_send_txn.call_count, 4) # run the prune job @@ -401,7 +435,10 @@ def test_prune_outbound_device_pokes1(self): # recover the server mock_send_txn.side_effect = self.record_transaction self.hs.get_federation_sender().send_device_messages("host2") - self.pump() + + # We queue up device list updates to be sent over federation, so we + # advance to clear the queue. + self.reactor.advance(1) # there should be a single update for this user. self.assertEqual(len(self.edus), 1) @@ -435,6 +472,10 @@ def test_prune_outbound_device_pokes2(self): self.login("user", "pass", device_id="D2") self.login("user", "pass", device_id="D3") + # We queue up device list updates to be sent over federation, so we + # advance to clear the queue. + self.reactor.advance(1) + # delete them again self.get_success( self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"]) @@ -451,7 +492,10 @@ def test_prune_outbound_device_pokes2(self): # recover the server mock_send_txn.side_effect = self.record_transaction self.hs.get_federation_sender().send_device_messages("host2") - self.pump() + + # We queue up device list updates to be sent over federation, so we + # advance to clear the queue. + self.reactor.advance(1) # ... and we should get a single update for this user. 
self.assertEqual(len(self.edus), 1) From 4aeb00ca20a0d9dbb2a104591aca081c723eb6d9 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Fri, 4 Mar 2022 11:58:49 +0000 Subject: [PATCH 020/230] Move synctl into `synapse._scripts` and expose as an entrypoint (#12140) --- .dockerignore | 1 - MANIFEST.in | 1 - changelog.d/12140.misc | 1 + docker/Dockerfile | 2 +- docs/postgres.md | 8 ++++---- docs/turn-howto.md | 5 +++-- docs/upgrade.md | 23 ++++++++++++++++++++++- scripts-dev/lint.sh | 2 +- setup.py | 2 +- synctl => synapse/_scripts/synctl.py | 0 tox.ini | 1 - 11 files changed, 33 insertions(+), 13 deletions(-) create mode 100644 changelog.d/12140.misc rename synctl => synapse/_scripts/synctl.py (100%) diff --git a/.dockerignore b/.dockerignore index 617f7015971b..434231fce9fc 100644 --- a/.dockerignore +++ b/.dockerignore @@ -7,6 +7,5 @@ !MANIFEST.in !README.rst !setup.py -!synctl **/__pycache__ diff --git a/MANIFEST.in b/MANIFEST.in index f1e295e5837f..d744c090acde 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,4 +1,3 @@ -include synctl include LICENSE include VERSION include *.rst diff --git a/changelog.d/12140.misc b/changelog.d/12140.misc new file mode 100644 index 000000000000..33a21a29f0f4 --- /dev/null +++ b/changelog.d/12140.misc @@ -0,0 +1 @@ +Move `synctl` into `synapse._scripts` and expose as an entry point. \ No newline at end of file diff --git a/docker/Dockerfile b/docker/Dockerfile index 327275a9cae6..24b5515eb99e 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -46,7 +46,7 @@ RUN \ && rm -rf /var/lib/apt/lists/* # Copy just what we need to pip install -COPY MANIFEST.in README.rst setup.py synctl /synapse/ +COPY MANIFEST.in README.rst setup.py /synapse/ COPY synapse/__init__.py /synapse/synapse/__init__.py COPY synapse/python_dependencies.py /synapse/synapse/python_dependencies.py diff --git a/docs/postgres.md b/docs/postgres.md index 0562021da526..de4e2ba4b756 100644 --- a/docs/postgres.md +++ b/docs/postgres.md @@ -153,9 +153,9 @@ database file (typically `homeserver.db`) to another location. Once the copy is complete, restart synapse. For instance: ```sh -./synctl stop +synctl stop cp homeserver.db homeserver.db.snapshot -./synctl start +synctl start ``` Copy the old config file into a new config file: @@ -192,10 +192,10 @@ Once that has completed, change the synapse config to point at the PostgreSQL database configuration file `homeserver-postgres.yaml`: ```sh -./synctl stop +synctl stop mv homeserver.yaml homeserver-old-sqlite.yaml mv homeserver-postgres.yaml homeserver.yaml -./synctl start +synctl start ``` Synapse should now be running against PostgreSQL. diff --git a/docs/turn-howto.md b/docs/turn-howto.md index eba7ca6124a5..3a2cd04e36a9 100644 --- a/docs/turn-howto.md +++ b/docs/turn-howto.md @@ -238,8 +238,9 @@ After updating the homeserver configuration, you must restart synapse: * If you use synctl: ```sh - cd /where/you/run/synapse - ./synctl restart + # Depending on how Synapse is installed, synctl may already be on + # your PATH. If not, you may need to activate a virtual environment. + synctl restart ``` * If you use systemd: ```sh diff --git a/docs/upgrade.md b/docs/upgrade.md index f9be3ac6bc15..0d0bb066ee63 100644 --- a/docs/upgrade.md +++ b/docs/upgrade.md @@ -47,7 +47,7 @@ this document. 3. 
Restart Synapse: ```bash - ./synctl restart + synctl restart ``` To check whether your update was successful, you can check the running @@ -85,6 +85,27 @@ process, for example: dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb ``` +# Upgrading to v1.55.0 + +## `synctl` script has been moved + +The `synctl` script +[has been made](https://github.com/matrix-org/synapse/pull/12140) an +[entry point](https://packaging.python.org/en/latest/specifications/entry-points/) +and no longer exists at the root of Synapse's source tree. If you wish to use +`synctl` to manage your homeserver, you should invoke `synctl` directly, e.g. +`synctl start` instead of `./synctl start` or `/path/to/synctl start`. + +You will need to ensure `synctl` is on your `PATH`. + - This is automatically the case when using + [Debian packages](https://packages.matrix.org/debian/) or + [docker images](https://hub.docker.com/r/matrixdotorg/synapse) + provided by Matrix.org. + - When installing from a wheel, sdist, or PyPI, a `synctl` executable is added + to your Python installation's `bin`. This should be on your `PATH` + automatically, though you might need to activate a virtual environment + depending on how you installed Synapse. + # Upgrading to v1.54.0 ## Legacy structured logging configuration removal diff --git a/scripts-dev/lint.sh b/scripts-dev/lint.sh index 2f5f2c356674..c063fafa973b 100755 --- a/scripts-dev/lint.sh +++ b/scripts-dev/lint.sh @@ -85,7 +85,7 @@ else "synapse" "docker" "tests" # annoyingly, black doesn't find these so we have to list them "scripts-dev" - "contrib" "synctl" "setup.py" "synmark" "stubs" ".ci" + "contrib" "setup.py" "synmark" "stubs" ".ci" ) fi fi diff --git a/setup.py b/setup.py index 318df16766ec..439ed75d7282 100755 --- a/setup.py +++ b/setup.py @@ -155,6 +155,7 @@ def exec_file(path_segments): # Application "synapse_homeserver = synapse.app.homeserver:main", "synapse_worker = synapse.app.generic_worker:main", + "synctl = synapse._scripts.synctl:main", # Scripts "export_signing_key = synapse._scripts.export_signing_key:main", "generate_config = synapse._scripts.generate_config:main", @@ -177,6 +178,5 @@ def exec_file(path_segments): "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", ], - scripts=["synctl"], cmdclass={"test": TestCommand}, ) diff --git a/synctl b/synapse/_scripts/synctl.py similarity index 100% rename from synctl rename to synapse/_scripts/synctl.py diff --git a/tox.ini b/tox.ini index 04d282a705af..f1f96b27ea14 100644 --- a/tox.ini +++ b/tox.ini @@ -42,7 +42,6 @@ lint_targets = scripts-dev stubs contrib - synctl synmark .ci docker From 36071d39f784ddc2271c91a72a55d9ee4f8689bb Mon Sep 17 00:00:00 2001 From: David Robertson Date: Fri, 4 Mar 2022 12:01:51 +0000 Subject: [PATCH 021/230] Changelog (#12153) --- .github/workflows/tests.yml | 1 + changelog.d/12153.misc | 1 + tox.ini | 1 - 3 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 changelog.d/12153.misc diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index c89c50cd07e2..3bce95b0e057 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -17,6 +17,7 @@ jobs: - uses: actions/setup-python@v2 - run: pip install -e . 
- run: scripts-dev/generate_sample_config.sh --check + - run: scripts-dev/config-lint.sh lint: runs-on: ubuntu-latest diff --git a/changelog.d/12153.misc b/changelog.d/12153.misc new file mode 100644 index 000000000000..f02d140f3871 --- /dev/null +++ b/changelog.d/12153.misc @@ -0,0 +1 @@ +Move CI checks out of tox, to facilitate a move to using poetry. \ No newline at end of file diff --git a/tox.ini b/tox.ini index f1f96b27ea14..3ffd2c3e97e8 100644 --- a/tox.ini +++ b/tox.ini @@ -151,7 +151,6 @@ extras = lint commands = python -m black --check --diff {[base]lint_targets} flake8 {[base]lint_targets} {env:PEP8SUFFIX:} - {toxinidir}/scripts-dev/config-lint.sh [testenv:check_isort] extras = lint From cd1ae3d0b438ff453b7d4750c4fe901f266fcbb6 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 4 Mar 2022 07:10:10 -0500 Subject: [PATCH 022/230] Remove backwards compatibility with RelationPaginationToken. (#12138) --- changelog.d/12138.removal | 1 + synapse/rest/client/relations.py | 55 ++++++---------------- synapse/storage/relations.py | 31 ------------ tests/rest/client/test_relations.py | 73 +---------------------------- 4 files changed, 16 insertions(+), 144 deletions(-) create mode 100644 changelog.d/12138.removal diff --git a/changelog.d/12138.removal b/changelog.d/12138.removal new file mode 100644 index 000000000000..6ed84d476cd9 --- /dev/null +++ b/changelog.d/12138.removal @@ -0,0 +1 @@ +Remove backwards compatibilty with pagination tokens from the `/relations` and `/aggregations` endpoints generated from Synapse < v1.52.0. diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index 487ea38b55cc..07fa1cdd4c67 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -27,50 +27,15 @@ from synapse.http.servlet import RestServlet, parse_integer, parse_string from synapse.http.site import SynapseRequest from synapse.rest.client._base import client_patterns -from synapse.storage.relations import ( - AggregationPaginationToken, - PaginationChunk, - RelationPaginationToken, -) -from synapse.types import JsonDict, RoomStreamToken, StreamToken +from synapse.storage.relations import AggregationPaginationToken, PaginationChunk +from synapse.types import JsonDict, StreamToken if TYPE_CHECKING: from synapse.server import HomeServer - from synapse.storage.databases.main import DataStore logger = logging.getLogger(__name__) -async def _parse_token( - store: "DataStore", token: Optional[str] -) -> Optional[StreamToken]: - """ - For backwards compatibility support RelationPaginationToken, but new pagination - tokens are generated as full StreamTokens, to be compatible with /sync and /messages. - """ - if not token: - return None - # Luckily the format for StreamToken and RelationPaginationToken differ enough - # that they can easily be separated. An "_" appears in the serialization of - # RoomStreamToken (as part of StreamToken), but RelationPaginationToken uses - # "-" only for separators. 
- if "_" in token: - return await StreamToken.from_string(store, token) - else: - relation_token = RelationPaginationToken.from_string(token) - return StreamToken( - room_key=RoomStreamToken(relation_token.topological, relation_token.stream), - presence_key=0, - typing_key=0, - receipt_key=0, - account_data_key=0, - push_rules_key=0, - to_device_key=0, - device_list_key=0, - groups_key=0, - ) - - class RelationPaginationServlet(RestServlet): """API to paginate relations on an event by topological ordering, optionally filtered by relation type and event type. @@ -122,8 +87,12 @@ async def on_GET( pagination_chunk = PaginationChunk(chunk=[]) else: # Return the relations - from_token = await _parse_token(self.store, from_token_str) - to_token = await _parse_token(self.store, to_token_str) + from_token = None + if from_token_str: + from_token = await StreamToken.from_string(self.store, from_token_str) + to_token = None + if to_token_str: + to_token = await StreamToken.from_string(self.store, to_token_str) pagination_chunk = await self.store.get_relations_for_event( event_id=parent_id, @@ -317,8 +286,12 @@ async def on_GET( from_token_str = parse_string(request, "from") to_token_str = parse_string(request, "to") - from_token = await _parse_token(self.store, from_token_str) - to_token = await _parse_token(self.store, to_token_str) + from_token = None + if from_token_str: + from_token = await StreamToken.from_string(self.store, from_token_str) + to_token = None + if to_token_str: + to_token = await StreamToken.from_string(self.store, to_token_str) result = await self.store.get_relations_for_event( event_id=parent_id, diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py index 36ca2b827398..fba270150b63 100644 --- a/synapse/storage/relations.py +++ b/synapse/storage/relations.py @@ -54,37 +54,6 @@ async def to_dict(self, store: "DataStore") -> Dict[str, Any]: return d -@attr.s(frozen=True, slots=True, auto_attribs=True) -class RelationPaginationToken: - """Pagination token for relation pagination API. - - As the results are in topological order, we can use the - `topological_ordering` and `stream_ordering` fields of the events at the - boundaries of the chunk as pagination tokens. - - Attributes: - topological: The topological ordering of the boundary event - stream: The stream ordering of the boundary event. - """ - - topological: int - stream: int - - @staticmethod - def from_string(string: str) -> "RelationPaginationToken": - try: - t, s = string.split("-") - return RelationPaginationToken(int(t), int(s)) - except ValueError: - raise SynapseError(400, "Invalid relation pagination token") - - async def to_string(self, store: "DataStore") -> str: - return "%d-%d" % (self.topological, self.stream) - - def as_tuple(self) -> Tuple[Any, ...]: - return attr.astuple(self) - - @attr.s(frozen=True, slots=True, auto_attribs=True) class AggregationPaginationToken: """Pagination token for relation aggregation pagination API. 
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index 53062b41deaa..274f9c44c164 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -24,8 +24,7 @@ from synapse.rest import admin from synapse.rest.client import login, register, relations, room, sync from synapse.server import HomeServer -from synapse.storage.relations import RelationPaginationToken -from synapse.types import JsonDict, StreamToken +from synapse.types import JsonDict from synapse.util import Clock from tests import unittest @@ -281,15 +280,6 @@ def test_basic_paginate_relations(self) -> None: channel.json_body["chunk"][0], ) - def _stream_token_to_relation_token(self, token: str) -> str: - """Convert a StreamToken into a legacy token (RelationPaginationToken).""" - room_key = self.get_success(StreamToken.from_string(self.store, token)).room_key - return self.get_success( - RelationPaginationToken( - topological=room_key.topological, stream=room_key.stream - ).to_string(self.store) - ) - def test_repeated_paginate_relations(self) -> None: """Test that if we paginate using a limit and tokens then we get the expected events. @@ -330,34 +320,6 @@ def test_repeated_paginate_relations(self) -> None: found_event_ids.reverse() self.assertEqual(found_event_ids, expected_event_ids) - # Reset and try again, but convert the tokens to the legacy format. - prev_token = "" - found_event_ids = [] - for _ in range(20): - from_token = "" - if prev_token: - from_token = "&from=" + self._stream_token_to_relation_token(prev_token) - - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - - found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) - next_batch = channel.json_body.get("next_batch") - - self.assertNotEqual(prev_token, next_batch) - prev_token = next_batch - - if not prev_token: - break - - # We paginated backwards, so reverse - found_event_ids.reverse() - self.assertEqual(found_event_ids, expected_event_ids) - def test_pagination_from_sync_and_messages(self) -> None: """Pagination tokens from /sync and /messages can be used to paginate /relations.""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "A") @@ -543,39 +505,6 @@ def test_aggregation_pagination_within_group(self) -> None: found_event_ids.reverse() self.assertEqual(found_event_ids, expected_event_ids) - # Reset and try again, but convert the tokens to the legacy format. 
- prev_token = "" - found_event_ids = [] - for _ in range(20): - from_token = "" - if prev_token: - from_token = "&from=" + self._stream_token_to_relation_token(prev_token) - - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}" - f"/aggregations/{self.parent_id}/{RelationTypes.ANNOTATION}" - f"/m.reaction/{encoded_key}?limit=1{from_token}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - - self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) - - found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) - - next_batch = channel.json_body.get("next_batch") - - self.assertNotEqual(prev_token, next_batch) - prev_token = next_batch - - if not prev_token: - break - - # We paginated backwards, so reverse - found_event_ids.reverse() - self.assertEqual(found_event_ids, expected_event_ids) - def test_aggregation(self) -> None: """Test that annotations get correctly aggregated.""" From 158e0937eb56f14dd851549939037de499526fd2 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Fri, 4 Mar 2022 13:10:05 +0000 Subject: [PATCH 023/230] Add test for `ObservableDeferred`'s cancellation behaviour (#12149) Signed-off-by: Sean Quah --- changelog.d/12149.misc | 1 + tests/util/test_async_helpers.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 29 insertions(+) create mode 100644 changelog.d/12149.misc diff --git a/changelog.d/12149.misc b/changelog.d/12149.misc new file mode 100644 index 000000000000..d39af9672365 --- /dev/null +++ b/changelog.d/12149.misc @@ -0,0 +1 @@ +Add test for `ObservableDeferred`'s cancellation behaviour. diff --git a/tests/util/test_async_helpers.py b/tests/util/test_async_helpers.py index 362014f4cb6f..ff53ce114bd7 100644 --- a/tests/util/test_async_helpers.py +++ b/tests/util/test_async_helpers.py @@ -100,6 +100,34 @@ def check_val(res, idx): self.assertEqual(str(results[0].value), "gah!", "observer 1 errback result") self.assertEqual(str(results[1].value), "gah!", "observer 2 errback result") + def test_cancellation(self): + """Test that cancelling an observer does not affect other observers.""" + origin_d: "Deferred[int]" = Deferred() + observable = ObservableDeferred(origin_d, consumeErrors=True) + + observer1 = observable.observe() + observer2 = observable.observe() + observer3 = observable.observe() + + self.assertFalse(observer1.called) + self.assertFalse(observer2.called) + self.assertFalse(observer3.called) + + # cancel the second observer + observer2.cancel() + self.assertFalse(observer1.called) + self.failureResultOf(observer2, CancelledError) + self.assertFalse(observer3.called) + + # other observers resolve as normal + origin_d.callback(123) + self.assertEqual(observer1.result, 123, "observer 1 callback result") + self.assertEqual(observer3.result, 123, "observer 3 callback result") + + # additional observers resolve as normal + observer4 = observable.observe() + self.assertEqual(observer4.result, 123, "observer 4 callback result") + class TimeoutDeferredTest(TestCase): def setUp(self): From 75574726a766f09d955c05672d400c65cb341810 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Fri, 4 Mar 2022 15:37:02 +0000 Subject: [PATCH 024/230] Add type hints for `ObservableDeferred` attributes (#12159) Signed-off-by: Sean Quah --- changelog.d/12159.misc | 1 + synapse/util/async_helpers.py | 14 +++++++++++--- 2 files changed, 12 insertions(+), 3 deletions(-) create mode 100644 
changelog.d/12159.misc diff --git a/changelog.d/12159.misc b/changelog.d/12159.misc new file mode 100644 index 000000000000..30500f2fd95d --- /dev/null +++ b/changelog.d/12159.misc @@ -0,0 +1 @@ +Add type hints for `ObservableDeferred` attributes. diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 60c03a66fd16..a9f67dcbac6a 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -40,7 +40,7 @@ ) import attr -from typing_extensions import ContextManager +from typing_extensions import ContextManager, Literal from twisted.internet import defer from twisted.internet.defer import CancelledError @@ -96,6 +96,10 @@ class ObservableDeferred(Generic[_T], AbstractObservableDeferred[_T]): __slots__ = ["_deferred", "_observers", "_result"] + _deferred: "defer.Deferred[_T]" + _observers: Union[List["defer.Deferred[_T]"], Tuple[()]] + _result: Union[None, Tuple[Literal[True], _T], Tuple[Literal[False], Failure]] + def __init__(self, deferred: "defer.Deferred[_T]", consumeErrors: bool = False): object.__setattr__(self, "_deferred", deferred) object.__setattr__(self, "_result", None) @@ -158,12 +162,14 @@ def observe(self) -> "defer.Deferred[_T]": effect the underlying deferred. """ if not self._result: + assert isinstance(self._observers, list) d: "defer.Deferred[_T]" = defer.Deferred() self._observers.append(d) return d + elif self._result[0]: + return defer.succeed(self._result[1]) else: - success, res = self._result - return defer.succeed(res) if success else defer.fail(res) + return defer.fail(self._result[1]) def observers(self) -> "Collection[defer.Deferred[_T]]": return self._observers @@ -175,6 +181,8 @@ def has_succeeded(self) -> bool: return self._result is not None and self._result[0] is True def get_result(self) -> Union[_T, Failure]: + if self._result is None: + raise ValueError(f"{self!r} has no result yet") return self._result[1] def __getattr__(self, name: str) -> Any: From 0752ab7a3621b90073f9332fbfdc8afe16a3be01 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 4 Mar 2022 17:57:27 +0000 Subject: [PATCH 025/230] Reduce to-device queries for /sync. (#12163) --- changelog.d/12163.misc | 1 + synapse/storage/databases/main/deviceinbox.py | 3 +++ 2 files changed, 4 insertions(+) create mode 100644 changelog.d/12163.misc diff --git a/changelog.d/12163.misc b/changelog.d/12163.misc new file mode 100644 index 000000000000..13de0895f5fa --- /dev/null +++ b/changelog.d/12163.misc @@ -0,0 +1 @@ +Reduce number of DB queries made during processing of `/sync`. diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index 1392363de15a..b4a1b041b1f8 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -298,6 +298,9 @@ async def _get_device_messages( # This user has new messages sent to them. Query messages for them user_ids_to_query.add(user_id) + if not user_ids_to_query: + return {}, to_stream_id + def get_device_messages_txn(txn: LoggingTransaction): # Build a query to select messages from any of the given devices that # are between the given stream id bounds. From 0211f18d65b20c9bd77e64f296f8790f4267cf28 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 7 Mar 2022 12:24:06 +0000 Subject: [PATCH 026/230] Switch the `tests-done` job to an Action (#12161) I've factored it out for easier use in other workflows. 
--- .github/workflows/tests.yml | 30 +++++++++--------------------- changelog.d/12161.misc | 1 + 2 files changed, 10 insertions(+), 21 deletions(-) create mode 100644 changelog.d/12161.misc diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 3bce95b0e057..613a773775a2 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -388,34 +388,22 @@ jobs: tests-done: if: ${{ always() }} needs: + - check-sampleconfig - lint - lint-crlf - lint-newsfile - trial - trial-olddeps - sytest + - export-data - portdb - complement runs-on: ubuntu-latest steps: - - name: Set build result - env: - NEEDS_CONTEXT: ${{ toJSON(needs) }} - # the `jq` incantation dumps out a series of " " lines. - # we set it to an intermediate variable to avoid a pipe, which makes it - # hard to set $rc. - run: | - rc=0 - results=$(jq -r 'to_entries[] | [.key,.value.result] | join(" ")' <<< $NEEDS_CONTEXT) - while read job result ; do - # The newsfile lint may be skipped on non PR builds - if [ $result == "skipped" ] && [ $job == "lint-newsfile" ]; then - continue - fi - - if [ "$result" != "success" ]; then - echo "::set-failed ::Job $job returned $result" - rc=1 - fi - done <<< $results - exit $rc + - uses: matrix-org/done-action@v2 + with: + needs: ${{ toJSON(needs) }} + + # The newsfile lint may be skipped on non PR builds + skippable: + lint-newsfile diff --git a/changelog.d/12161.misc b/changelog.d/12161.misc new file mode 100644 index 000000000000..43eff08d467e --- /dev/null +++ b/changelog.d/12161.misc @@ -0,0 +1 @@ +Use a prebuilt Action for the `tests-done` CI job. From f63bedef07360216a8de71dc38f00f1aea503903 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 7 Mar 2022 09:00:05 -0500 Subject: [PATCH 027/230] Invalidate caches when an event with a relation is redacted. (#12121) The caches for the target of the relation must be cleared so that the bundled aggregations are re-calculated after the redaction is processed. --- changelog.d/12113.bugfix | 1 + changelog.d/12113.misc | 1 - changelog.d/12121.bugfix | 1 + synapse/storage/databases/main/cache.py | 2 + synapse/storage/databases/main/events.py | 38 ++++- tests/rest/client/test_relations.py | 207 ++++++++++++++++++----- 6 files changed, 202 insertions(+), 48 deletions(-) create mode 100644 changelog.d/12113.bugfix delete mode 100644 changelog.d/12113.misc create mode 100644 changelog.d/12121.bugfix diff --git a/changelog.d/12113.bugfix b/changelog.d/12113.bugfix new file mode 100644 index 000000000000..df9b0dc413dd --- /dev/null +++ b/changelog.d/12113.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug when redacting events with relations. diff --git a/changelog.d/12113.misc b/changelog.d/12113.misc deleted file mode 100644 index 102e064053c2..000000000000 --- a/changelog.d/12113.misc +++ /dev/null @@ -1 +0,0 @@ -Refactor the tests for event relations. diff --git a/changelog.d/12121.bugfix b/changelog.d/12121.bugfix new file mode 100644 index 000000000000..df9b0dc413dd --- /dev/null +++ b/changelog.d/12121.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug when redacting events with relations. 
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index c428dd5596af..abd54c7dc703 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -200,6 +200,8 @@ def _invalidate_caches_for_event( self.get_relations_for_event.invalidate((relates_to,)) self.get_aggregation_groups_for_event.invalidate((relates_to,)) self.get_applicable_edit.invalidate((relates_to,)) + self.get_thread_summary.invalidate((relates_to,)) + self.get_thread_participated.invalidate((relates_to,)) async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]): """Invalidates the cache and adds it to the cache stream so slaves diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index ca2a9ba9d116..1dc83aa5e3a6 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1518,7 +1518,7 @@ def _update_metadata_tables_txn( ) # Remove from relations table. - self._handle_redaction(txn, event.redacts) + self._handle_redact_relations(txn, event.redacts) # Update the event_forward_extremities, event_backward_extremities and # event_edges tables. @@ -1943,15 +1943,43 @@ def _handle_batch_event(self, txn: LoggingTransaction, event: EventBase): txn.execute(sql, (batch_id,)) - def _handle_redaction(self, txn, redacted_event_id): - """Handles receiving a redaction and checking whether we need to remove - any redacted relations from the database. + def _handle_redact_relations( + self, txn: LoggingTransaction, redacted_event_id: str + ) -> None: + """Handles receiving a redaction and checking whether the redacted event + has any relations which must be removed from the database. Args: txn - redacted_event_id (str): The event that was redacted. + redacted_event_id: The event that was redacted. """ + # Fetch the current relation of the event being redacted. + redacted_relates_to = self.db_pool.simple_select_one_onecol_txn( + txn, + table="event_relations", + keyvalues={"event_id": redacted_event_id}, + retcol="relates_to_id", + allow_none=True, + ) + # Any relation information for the related event must be cleared. + if redacted_relates_to is not None: + self.store._invalidate_cache_and_stream( + txn, self.store.get_relations_for_event, (redacted_relates_to,) + ) + self.store._invalidate_cache_and_stream( + txn, self.store.get_aggregation_groups_for_event, (redacted_relates_to,) + ) + self.store._invalidate_cache_and_stream( + txn, self.store.get_applicable_edit, (redacted_relates_to,) + ) + self.store._invalidate_cache_and_stream( + txn, self.store.get_thread_summary, (redacted_relates_to,) + ) + self.store._invalidate_cache_and_stream( + txn, self.store.get_thread_participated, (redacted_relates_to,) + ) + self.db_pool.simple_delete_txn( txn, table="event_relations", keyvalues={"event_id": redacted_event_id} ) diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index 274f9c44c164..a40a5de3991c 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -1273,7 +1273,21 @@ def test_background_update(self) -> None: class RelationRedactionTestCase(BaseRelationsTestCase): - """Test the behaviour of relations when the parent or child event is redacted.""" + """ + Test the behaviour of relations when the parent or child event is redacted. 
+ + The behaviour of each relation type is subtly different which causes the tests + to be a bit repetitive, they follow a naming scheme of: + + test_redact_(relation|parent)_{relation_type} + + The first bit of "relation" means that the event with the relation defined + on it (the child event) is to be redacted. A "parent" means that the target + of the relation (the parent event) is to be redacted. + + The relation_type describes which type of relation is under test (i.e. it is + related to the value of rel_type in the event content). + """ def _redact(self, event_id: str) -> None: channel = self.make_request( @@ -1284,9 +1298,53 @@ def _redact(self, event_id: str) -> None: ) self.assertEqual(200, channel.code, channel.json_body) + def _make_relation_requests(self) -> Tuple[List[str], JsonDict]: + """ + Makes requests and ensures they result in a 200 response, returns a + tuple of results: + + 1. `/relations` -> Returns a list of event IDs. + 2. `/event` -> Returns the response's m.relations field (from unsigned), + if it exists. + """ + + # Request the relations of the event. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + event_ids = [ev["event_id"] for ev in channel.json_body["chunk"]] + + # Fetch the bundled aggregations of the event. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/event/{self.parent_id}", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + bundled_relations = channel.json_body["unsigned"].get("m.relations", {}) + + return event_ids, bundled_relations + + def _get_aggregations(self) -> List[JsonDict]: + """Request /aggregations on the parent ID and includes the returned chunk.""" + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + return channel.json_body["chunk"] + def test_redact_relation_annotation(self) -> None: - """Test that annotations of an event are properly handled after the + """ + Test that annotations of an event are properly handled after the annotation is redacted. + + The redacted relation should not be included in bundled aggregations or + the response to relations. """ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") self.assertEqual(200, channel.code, channel.json_body) @@ -1296,24 +1354,97 @@ def test_redact_relation_annotation(self) -> None: RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token ) self.assertEqual(200, channel.code, channel.json_body) + unredacted_event_id = channel.json_body["event_id"] + + # Both relations should exist. + event_ids, relations = self._make_relation_requests() + self.assertCountEqual(event_ids, [to_redact_event_id, unredacted_event_id]) + self.assertEquals( + relations["m.annotation"], + {"chunk": [{"type": "m.reaction", "key": "a", "count": 2}]}, + ) + + # Both relations appear in the aggregation. + chunk = self._get_aggregations() + self.assertEqual(chunk, [{"type": "m.reaction", "key": "a", "count": 2}]) # Redact one of the reactions. self._redact(to_redact_event_id) - # Ensure that the aggregations are correct. 
- channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}", - access_token=self.user_token, + # The unredacted relation should still exist. + event_ids, relations = self._make_relation_requests() + self.assertEquals(event_ids, [unredacted_event_id]) + self.assertEquals( + relations["m.annotation"], + {"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]}, + ) + + # The unredacted aggregation should still exist. + chunk = self._get_aggregations() + self.assertEqual(chunk, [{"type": "m.reaction", "key": "a", "count": 1}]) + + @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) + def test_redact_relation_thread(self) -> None: + """ + Test that thread replies are properly handled after the thread reply redacted. + + The redacted event should not be included in bundled aggregations or + the response to relations. + """ + channel = self._send_relation( + RelationTypes.THREAD, + EventTypes.Message, + content={"body": "reply 1", "msgtype": "m.text"}, ) self.assertEqual(200, channel.code, channel.json_body) + unredacted_event_id = channel.json_body["event_id"] + # Note that the *last* event in the thread is redacted, as that gets + # included in the bundled aggregation. + channel = self._send_relation( + RelationTypes.THREAD, + EventTypes.Message, + content={"body": "reply 2", "msgtype": "m.text"}, + ) + self.assertEqual(200, channel.code, channel.json_body) + to_redact_event_id = channel.json_body["event_id"] + + # Both relations exist. + event_ids, relations = self._make_relation_requests() + self.assertEquals(event_ids, [to_redact_event_id, unredacted_event_id]) + self.assertDictContainsSubset( + { + "count": 2, + "current_user_participated": True, + }, + relations[RelationTypes.THREAD], + ) + # And the latest event returned is the event that will be redacted. self.assertEqual( - channel.json_body, - {"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]}, + relations[RelationTypes.THREAD]["latest_event"]["event_id"], + to_redact_event_id, ) - def test_redact_relation_edit(self) -> None: + # Redact one of the reactions. + self._redact(to_redact_event_id) + + # The unredacted relation should still exist. + event_ids, relations = self._make_relation_requests() + self.assertEquals(event_ids, [unredacted_event_id]) + self.assertDictContainsSubset( + { + "count": 1, + "current_user_participated": True, + }, + relations[RelationTypes.THREAD], + ) + # And the latest event is now the unredacted event. + self.assertEqual( + relations[RelationTypes.THREAD]["latest_event"]["event_id"], + unredacted_event_id, + ) + + def test_redact_parent_edit(self) -> None: """Test that edits of an event are redacted when the original event is redacted. 
""" @@ -1331,34 +1462,19 @@ def test_redact_relation_edit(self) -> None: self.assertEqual(200, channel.code, channel.json_body) # Check the relation is returned - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/relations" - f"/{self.parent_id}/m.replace/m.room.message", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - - self.assertIn("chunk", channel.json_body) - self.assertEqual(len(channel.json_body["chunk"]), 1) + event_ids, relations = self._make_relation_requests() + self.assertEqual(len(event_ids), 1) + self.assertIn(RelationTypes.REPLACE, relations) # Redact the original event self._redact(self.parent_id) - # Try to check for remaining m.replace relations - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/relations" - f"/{self.parent_id}/m.replace/m.room.message", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - - # Check that no relations are returned - self.assertIn("chunk", channel.json_body) - self.assertEqual(channel.json_body["chunk"], []) + # The relations are not returned. + event_ids, relations = self._make_relation_requests() + self.assertEqual(len(event_ids), 0) + self.assertEqual(relations, {}) - def test_redact_parent(self) -> None: + def test_redact_parent_annotation(self) -> None: """Test that annotations of an event are redacted when the original event is redacted. """ @@ -1366,16 +1482,23 @@ def test_redact_parent(self) -> None: channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍") self.assertEqual(200, channel.code, channel.json_body) + # The relations should exist. + event_ids, relations = self._make_relation_requests() + self.assertEqual(len(event_ids), 1) + self.assertIn(RelationTypes.ANNOTATION, relations) + + # The aggregation should exist. + chunk = self._get_aggregations() + self.assertEqual(chunk, [{"type": "m.reaction", "key": "👍", "count": 1}]) + # Redact the original event. self._redact(self.parent_id) - # Check that aggregations returns zero - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}/m.annotation/m.reaction", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) + # The relations are not returned. + event_ids, relations = self._make_relation_requests() + self.assertEqual(event_ids, []) + self.assertEqual(relations, {}) - self.assertIn("chunk", channel.json_body) - self.assertEqual(channel.json_body["chunk"], []) + # There's nothing to aggregate. + chunk = self._get_aggregations() + self.assertEqual(chunk, []) From 26211fec24d8d0a967de33147e148166359ec8cb Mon Sep 17 00:00:00 2001 From: Shay Date: Mon, 7 Mar 2022 09:44:33 -0800 Subject: [PATCH 028/230] Fix a bug in background updates wherein background updates are never run using the default batch size (#12157) --- changelog.d/12157.bugfix | 1 + synapse/storage/background_updates.py | 8 +++++--- tests/rest/admin/test_background_updates.py | 18 ++++++++---------- tests/storage/test_background_update.py | 4 ++-- 4 files changed, 16 insertions(+), 15 deletions(-) create mode 100644 changelog.d/12157.bugfix diff --git a/changelog.d/12157.bugfix b/changelog.d/12157.bugfix new file mode 100644 index 000000000000..c3d2e700bb1d --- /dev/null +++ b/changelog.d/12157.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in #4864 whereby background updates are never run with the default background batch size. 
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index d64910aded33..4acc2c997dce 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -102,10 +102,12 @@ def average_items_per_ms(self) -> Optional[float]: Returns: A duration in ms as a float """ - if self.avg_duration_ms == 0: - return 0 - elif self.total_item_count == 0: + # We want to return None if this is the first background update item + if self.total_item_count == 0: return None + # Avoid dividing by zero + elif self.avg_duration_ms == 0: + return 0 else: # Use the exponential moving average so that we can adapt to # changes in how long the update process takes. diff --git a/tests/rest/admin/test_background_updates.py b/tests/rest/admin/test_background_updates.py index fb36aa994090..becec84524cb 100644 --- a/tests/rest/admin/test_background_updates.py +++ b/tests/rest/admin/test_background_updates.py @@ -155,10 +155,10 @@ def test_status_bg_update(self) -> None: "current_updates": { "master": { "name": "test_update", - "average_items_per_ms": 0.001, + "average_items_per_ms": 0.1, "total_duration_ms": 1000.0, "total_item_count": ( - BackgroundUpdater.MINIMUM_BACKGROUND_BATCH_SIZE + BackgroundUpdater.DEFAULT_BACKGROUND_BATCH_SIZE ), } }, @@ -210,10 +210,10 @@ def test_enabled(self) -> None: "current_updates": { "master": { "name": "test_update", - "average_items_per_ms": 0.001, + "average_items_per_ms": 0.1, "total_duration_ms": 1000.0, "total_item_count": ( - BackgroundUpdater.MINIMUM_BACKGROUND_BATCH_SIZE + BackgroundUpdater.DEFAULT_BACKGROUND_BATCH_SIZE ), } }, @@ -239,10 +239,10 @@ def test_enabled(self) -> None: "current_updates": { "master": { "name": "test_update", - "average_items_per_ms": 0.001, + "average_items_per_ms": 0.1, "total_duration_ms": 1000.0, "total_item_count": ( - BackgroundUpdater.MINIMUM_BACKGROUND_BATCH_SIZE + BackgroundUpdater.DEFAULT_BACKGROUND_BATCH_SIZE ), } }, @@ -278,11 +278,9 @@ def test_enabled(self) -> None: "current_updates": { "master": { "name": "test_update", - "average_items_per_ms": 0.001, + "average_items_per_ms": 0.05263157894736842, "total_duration_ms": 2000.0, - "total_item_count": ( - 2 * BackgroundUpdater.MINIMUM_BACKGROUND_BATCH_SIZE - ), + "total_item_count": (110), } }, "enabled": True, diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index 39dcc094bd8b..9fdf54ea31f9 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -66,13 +66,13 @@ async def update(progress, count): self.update_handler.reset_mock() res = self.get_success( self.updates.do_next_background_update(False), - by=0.01, + by=0.02, ) self.assertFalse(res) # on the first call, we should get run with the default background update size self.update_handler.assert_called_once_with( - {"my_key": 1}, self.updates.MINIMUM_BACKGROUND_BATCH_SIZE + {"my_key": 1}, self.updates.DEFAULT_BACKGROUND_BATCH_SIZE ) # second step: complete the update From d8bab6793c75774db4bde8aeec6897b607e08799 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 8 Mar 2022 07:26:05 -0500 Subject: [PATCH 029/230] Fix incorrect type hints for txredis. (#12042) Some properties were marked as RedisProtocol instead of ConnectionHandler, which wraps RedisProtocol instance(s). 
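For background, the stub can get away with typing `ConnectionHandler` as a `RedisProtocol` because the wrapper forwards attribute lookups to a pooled protocol instance via `__getattr__`, so callers see the same interface either way. Below is a minimal sketch of that proxying pattern; the class names are made-up stand-ins, not txredisapi's real internals.

```python
# Illustrative sketch only; FakeProtocol/FakeConnectionHandler are hypothetical
# stand-ins for txredisapi's RedisProtocol/ConnectionHandler.
from typing import Any


class FakeProtocol:
    """Plays the role of the concrete connection protocol."""

    def ping(self) -> str:
        return "PONG"


class FakeConnectionHandler:
    """Plays the role of the pooling wrapper.

    It does not inherit from FakeProtocol; any attribute it does not define
    itself is looked up on the wrapped protocol instead.
    """

    def __init__(self, protocol: FakeProtocol) -> None:
        self._protocol = protocol

    def disconnect(self) -> str:
        # Something the wrapper handles itself.
        return "disconnected"

    def __getattr__(self, name: str) -> Any:
        # Only called for attributes not found on the handler itself, so
        # ping() is transparently forwarded to the wrapped protocol.
        return getattr(self._protocol, name)


handler = FakeConnectionHandler(FakeProtocol())
assert handler.ping() == "PONG"                 # forwarded to the protocol
assert handler.disconnect() == "disconnected"   # handled by the wrapper
```

Declaring the wrapper as a subclass in the `.pyi` stub is a pragmatic approximation of this behaviour: it lets call sites keep precise return types without the type checker having to model `__getattr__` forwarding.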
--- changelog.d/12042.misc | 1 + stubs/txredisapi.pyi | 9 ++++++--- synapse/replication/tcp/external_cache.py | 4 ++-- synapse/replication/tcp/redis.py | 6 +++--- synapse/server.py | 4 ++-- 5 files changed, 14 insertions(+), 10 deletions(-) create mode 100644 changelog.d/12042.misc diff --git a/changelog.d/12042.misc b/changelog.d/12042.misc new file mode 100644 index 000000000000..6ecdc960210c --- /dev/null +++ b/changelog.d/12042.misc @@ -0,0 +1 @@ +Correct type hints for txredis. diff --git a/stubs/txredisapi.pyi b/stubs/txredisapi.pyi index 429234d7ae7f..2d8ca018fbfc 100644 --- a/stubs/txredisapi.pyi +++ b/stubs/txredisapi.pyi @@ -20,7 +20,7 @@ from twisted.internet import protocol from twisted.internet.defer import Deferred class RedisProtocol(protocol.Protocol): - def publish(self, channel: str, message: bytes): ... + def publish(self, channel: str, message: bytes) -> "Deferred[None]": ... def ping(self) -> "Deferred[None]": ... def set( self, @@ -52,11 +52,14 @@ def lazyConnection( convertNumbers: bool = ..., ) -> RedisProtocol: ... -class ConnectionHandler: ... +# ConnectionHandler doesn't actually inherit from RedisProtocol, but it proxies +# most methods to it via ConnectionHandler.__getattr__. +class ConnectionHandler(RedisProtocol): + def disconnect(self) -> "Deferred[None]": ... class RedisFactory(protocol.ReconnectingClientFactory): continueTrying: bool - handler: RedisProtocol + handler: ConnectionHandler pool: List[RedisProtocol] replyTimeout: Optional[int] def __init__( diff --git a/synapse/replication/tcp/external_cache.py b/synapse/replication/tcp/external_cache.py index aaf91e5e0253..bf7d017968f9 100644 --- a/synapse/replication/tcp/external_cache.py +++ b/synapse/replication/tcp/external_cache.py @@ -21,7 +21,7 @@ from synapse.util import json_decoder, json_encoder if TYPE_CHECKING: - from txredisapi import RedisProtocol + from txredisapi import ConnectionHandler from synapse.server import HomeServer @@ -63,7 +63,7 @@ class ExternalCache: def __init__(self, hs: "HomeServer"): if hs.config.redis.redis_enabled: self._redis_connection: Optional[ - "RedisProtocol" + "ConnectionHandler" ] = hs.get_outbound_redis_connection() else: self._redis_connection = None diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py index 3170f7c59b04..b84e572da136 100644 --- a/synapse/replication/tcp/redis.py +++ b/synapse/replication/tcp/redis.py @@ -93,7 +93,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol): synapse_handler: "ReplicationCommandHandler" synapse_stream_name: str - synapse_outbound_redis_connection: txredisapi.RedisProtocol + synapse_outbound_redis_connection: txredisapi.ConnectionHandler def __init__(self, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) @@ -313,7 +313,7 @@ class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory): protocol = RedisSubscriber def __init__( - self, hs: "HomeServer", outbound_redis_connection: txredisapi.RedisProtocol + self, hs: "HomeServer", outbound_redis_connection: txredisapi.ConnectionHandler ): super().__init__( @@ -353,7 +353,7 @@ def lazyConnection( reconnect: bool = True, password: Optional[str] = None, replyTimeout: int = 30, -) -> txredisapi.RedisProtocol: +) -> txredisapi.ConnectionHandler: """Creates a connection to Redis that is lazily set up and reconnects if the connections is lost. 
""" diff --git a/synapse/server.py b/synapse/server.py index b5e2a319bcef..46a64418ea0c 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -145,7 +145,7 @@ logger = logging.getLogger(__name__) if TYPE_CHECKING: - from txredisapi import RedisProtocol + from txredisapi import ConnectionHandler from synapse.handlers.oidc import OidcHandler from synapse.handlers.saml import SamlHandler @@ -807,7 +807,7 @@ def get_account_handler(self) -> AccountHandler: return AccountHandler(self) @cache_in_self - def get_outbound_redis_connection(self) -> "RedisProtocol": + def get_outbound_redis_connection(self) -> "ConnectionHandler": """ The Redis connection used for replication. From ca9234a9eba4fba02d8d50e5d5eff079bfaf0ebd Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 8 Mar 2022 08:09:11 -0500 Subject: [PATCH 030/230] Do not return allowed_room_ids from /hierarchy response. (#12175) This field is only to be used in the Server-Server API, and not the Client-Server API, but was being leaked when a federation response was used in the /hierarchy API. --- changelog.d/12175.bugfix | 1 + synapse/handlers/room_summary.py | 15 +++++++++++++-- tests/handlers/test_room_summary.py | 3 +++ 3 files changed, 17 insertions(+), 2 deletions(-) create mode 100644 changelog.d/12175.bugfix diff --git a/changelog.d/12175.bugfix b/changelog.d/12175.bugfix new file mode 100644 index 000000000000..881cb9b76c20 --- /dev/null +++ b/changelog.d/12175.bugfix @@ -0,0 +1 @@ +Fix a bug where non-standard information was returned from the `/hierarchy` API. Introduced in Synapse v1.41.0. diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py index 3979cbba71bd..486145f48aca 100644 --- a/synapse/handlers/room_summary.py +++ b/synapse/handlers/room_summary.py @@ -295,7 +295,7 @@ async def _get_room_hierarchy( # inaccessible to the requesting user. if room_entry: # Add the room (including the stripped m.space.child events). - rooms_result.append(room_entry.as_json()) + rooms_result.append(room_entry.as_json(for_client=True)) # If this room is not at the max-depth, check if there are any # children to process. @@ -843,14 +843,25 @@ class _RoomEntry: # This may not include all children. children_state_events: Sequence[JsonDict] = () - def as_json(self) -> JsonDict: + def as_json(self, for_client: bool = False) -> JsonDict: """ Returns a JSON dictionary suitable for the room hierarchy endpoint. It returns the room summary including the stripped m.space.child events as a sub-key. + + Args: + for_client: If true, any server-server only fields are stripped from + the result. + """ result = dict(self.room) + + # Before returning to the client, remove the allowed_room_ids key, if it + # exists. + if for_client: + result.pop("allowed_room_ids", False) + result["children_state"] = self.children_state_events return result diff --git a/tests/handlers/test_room_summary.py b/tests/handlers/test_room_summary.py index cff07a8973b3..d37292ce138e 100644 --- a/tests/handlers/test_room_summary.py +++ b/tests/handlers/test_room_summary.py @@ -172,6 +172,9 @@ def _assert_hierarchy( result_room_ids = [] result_children_ids = [] for result_room in result["rooms"]: + # Ensure federation results are not leaking over the client-server API. 
+ self.assertNotIn("allowed_room_ids", result_room) + result_room_ids.append(result_room["room_id"]) result_children_ids.append( [ From 2ce27a24fe29104ca54e0a879c7ad37d88a3fc69 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 8 Mar 2022 13:23:18 +0000 Subject: [PATCH 031/230] Add experimental environment variable to enable asyncio reactor (#12135) --- changelog.d/12135.feature | 1 + mypy.ini | 3 +++ synapse/__init__.py | 21 +++++++++++++++++++++ 3 files changed, 25 insertions(+) create mode 100644 changelog.d/12135.feature diff --git a/changelog.d/12135.feature b/changelog.d/12135.feature new file mode 100644 index 000000000000..b337f51730e6 --- /dev/null +++ b/changelog.d/12135.feature @@ -0,0 +1 @@ +Add experimental env var `SYNAPSE_ASYNC_IO_REACTOR` that causes Synapse to use the asyncio reactor for Twisted. diff --git a/mypy.ini b/mypy.ini index 481e8a5366b0..c8390ddba96d 100644 --- a/mypy.ini +++ b/mypy.ini @@ -353,3 +353,6 @@ ignore_missing_imports = True [mypy-zope] ignore_missing_imports = True + +[mypy-incremental.*] +ignore_missing_imports = True diff --git a/synapse/__init__.py b/synapse/__init__.py index b21e1ed0f342..674acc713503 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -25,6 +25,27 @@ print("Synapse requires Python 3.7 or above.") sys.exit(1) +# Allow using the asyncio reactor via env var. +if bool(os.environ.get("SYNAPSE_ASYNC_IO_REACTOR", False)): + try: + from incremental import Version + + import twisted + + # We need a bugfix that is included in Twisted 21.2.0: + # https://twistedmatrix.com/trac/ticket/9787 + if twisted.version < Version("Twisted", 21, 2, 0): + print("Using asyncio reactor requires Twisted>=21.2.0") + sys.exit(1) + + import asyncio + + from twisted.internet import asyncioreactor + + asyncioreactor.install(asyncio.get_event_loop()) + except ImportError: + pass + # Twisted and canonicaljson will fail to import when this file is executed to # get the __version__ during a fresh install. That's OK and subsequent calls to # actually start Synapse will import these libraries fine. From bfa7d6b03588ec3eca488f404d8fcac38f9e6427 Mon Sep 17 00:00:00 2001 From: reivilibre Date: Tue, 8 Mar 2022 15:11:50 +0000 Subject: [PATCH 032/230] Fix CI not attaching source distributions and wheels to the GitHub releases. (#12131) --- .github/workflows/release-artifacts.yml | 3 ++- changelog.d/12131.misc | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) create mode 100644 changelog.d/12131.misc diff --git a/.github/workflows/release-artifacts.yml b/.github/workflows/release-artifacts.yml index 65ea761ad713..ed4fc6179db3 100644 --- a/.github/workflows/release-artifacts.yml +++ b/.github/workflows/release-artifacts.yml @@ -112,7 +112,8 @@ jobs: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} with: files: | - python-dist/* + Sdist/* + Wheel/* debs.tar.xz # if it's not already published, keep the release as a draft. draft: true diff --git a/changelog.d/12131.misc b/changelog.d/12131.misc new file mode 100644 index 000000000000..8ef23c22d524 --- /dev/null +++ b/changelog.d/12131.misc @@ -0,0 +1 @@ +Fix CI not attaching source distributions and wheels to the GitHub releases. 
\ No newline at end of file From 562718278847375636ead2ed3afcc9d9d482ef96 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Tue, 8 Mar 2022 15:58:14 +0000 Subject: [PATCH 033/230] Use `ParamSpec` in type hints for `synapse.logging.context` (#12150) Signed-off-by: Sean Quah --- changelog.d/12150.misc | 1 + synapse/handlers/initial_sync.py | 5 +-- synapse/logging/context.py | 44 ++++++++++++----------- synapse/python_dependencies.py | 3 +- synapse/rest/media/v1/storage_provider.py | 9 +++-- 5 files changed, 37 insertions(+), 25 deletions(-) create mode 100644 changelog.d/12150.misc diff --git a/changelog.d/12150.misc b/changelog.d/12150.misc new file mode 100644 index 000000000000..2d2706dac769 --- /dev/null +++ b/changelog.d/12150.misc @@ -0,0 +1 @@ +Use `ParamSpec` in type hints for `synapse.logging.context`. diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 316cfae24ff0..a7db8feb57eb 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -153,8 +153,9 @@ async def _snapshot_all_rooms( public_room_ids = await self.store.get_public_room_ids() - limit = pagin_config.limit - if limit is None: + if pagin_config.limit is not None: + limit = pagin_config.limit + else: limit = 10 serializer_options = SerializeEventConfig(as_client_event=as_client_event) diff --git a/synapse/logging/context.py b/synapse/logging/context.py index c31c2960ad95..88cd8a9e1c39 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -29,7 +29,6 @@ from types import TracebackType from typing import ( TYPE_CHECKING, - Any, Awaitable, Callable, Optional, @@ -41,7 +40,7 @@ ) import attr -from typing_extensions import Literal +from typing_extensions import Literal, ParamSpec from twisted.internet import defer, threads from twisted.python.threadpool import ThreadPool @@ -719,32 +718,33 @@ def nested_logging_context(suffix: str) -> LoggingContext: ) +P = ParamSpec("P") R = TypeVar("R") @overload def preserve_fn( # type: ignore[misc] - f: Callable[..., Awaitable[R]], -) -> Callable[..., "defer.Deferred[R]"]: + f: Callable[P, Awaitable[R]], +) -> Callable[P, "defer.Deferred[R]"]: # The `type: ignore[misc]` above suppresses # "Overloaded function signatures 1 and 2 overlap with incompatible return types" ... @overload -def preserve_fn(f: Callable[..., R]) -> Callable[..., "defer.Deferred[R]"]: +def preserve_fn(f: Callable[P, R]) -> Callable[P, "defer.Deferred[R]"]: ... 
def preserve_fn( f: Union[ - Callable[..., R], - Callable[..., Awaitable[R]], + Callable[P, R], + Callable[P, Awaitable[R]], ] -) -> Callable[..., "defer.Deferred[R]"]: +) -> Callable[P, "defer.Deferred[R]"]: """Function decorator which wraps the function with run_in_background""" - def g(*args: Any, **kwargs: Any) -> "defer.Deferred[R]": + def g(*args: P.args, **kwargs: P.kwargs) -> "defer.Deferred[R]": return run_in_background(f, *args, **kwargs) return g @@ -752,7 +752,7 @@ def g(*args: Any, **kwargs: Any) -> "defer.Deferred[R]": @overload def run_in_background( # type: ignore[misc] - f: Callable[..., Awaitable[R]], *args: Any, **kwargs: Any + f: Callable[P, Awaitable[R]], *args: P.args, **kwargs: P.kwargs ) -> "defer.Deferred[R]": # The `type: ignore[misc]` above suppresses # "Overloaded function signatures 1 and 2 overlap with incompatible return types" @@ -761,18 +761,22 @@ def run_in_background( # type: ignore[misc] @overload def run_in_background( - f: Callable[..., R], *args: Any, **kwargs: Any + f: Callable[P, R], *args: P.args, **kwargs: P.kwargs ) -> "defer.Deferred[R]": ... -def run_in_background( +def run_in_background( # type: ignore[misc] + # The `type: ignore[misc]` above suppresses + # "Overloaded function implementation does not accept all possible arguments of signature 1" + # "Overloaded function implementation does not accept all possible arguments of signature 2" + # which seems like a bug in mypy. f: Union[ - Callable[..., R], - Callable[..., Awaitable[R]], + Callable[P, R], + Callable[P, Awaitable[R]], ], - *args: Any, - **kwargs: Any, + *args: P.args, + **kwargs: P.kwargs, ) -> "defer.Deferred[R]": """Calls a function, ensuring that the current context is restored after return from the function, and that the sentinel context is set once the @@ -872,7 +876,7 @@ def _set_context_cb(result: ResultT, context: LoggingContext) -> ResultT: def defer_to_thread( - reactor: "ISynapseReactor", f: Callable[..., R], *args: Any, **kwargs: Any + reactor: "ISynapseReactor", f: Callable[P, R], *args: P.args, **kwargs: P.kwargs ) -> "defer.Deferred[R]": """ Calls the function `f` using a thread from the reactor's default threadpool and @@ -908,9 +912,9 @@ def defer_to_thread( def defer_to_threadpool( reactor: "ISynapseReactor", threadpool: ThreadPool, - f: Callable[..., R], - *args: Any, - **kwargs: Any, + f: Callable[P, R], + *args: P.args, + **kwargs: P.kwargs, ) -> "defer.Deferred[R]": """ A wrapper for twisted.internet.threads.deferToThreadpool, which handles diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index b40a7bbb76ca..1dd39f06cffb 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -76,7 +76,8 @@ "netaddr>=0.7.18", "Jinja2>=2.9", "bleach>=1.4.3", - "typing-extensions>=3.7.4", + # We use `ParamSpec`, which was added in `typing-extensions` 3.10.0.0. + "typing-extensions>=3.10.0", # We enforce that we have a `cryptography` version that bundles an `openssl` # with the latest security patches. 
"cryptography>=3.4.7", diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py index 18bf977d3d9f..1c9b71d69c77 100644 --- a/synapse/rest/media/v1/storage_provider.py +++ b/synapse/rest/media/v1/storage_provider.py @@ -16,7 +16,7 @@ import logging import os import shutil -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Callable, Optional from synapse.config._base import Config from synapse.logging.context import defer_to_thread, run_in_background @@ -150,8 +150,13 @@ async def store_file(self, path: str, file_info: FileInfo) -> None: dirname = os.path.dirname(backup_fname) os.makedirs(dirname, exist_ok=True) + # mypy needs help inferring the type of the second parameter, which is generic + shutil_copyfile: Callable[[str, str], str] = shutil.copyfile await defer_to_thread( - self.hs.get_reactor(), shutil.copyfile, primary_fname, backup_fname + self.hs.get_reactor(), + shutil_copyfile, + primary_fname, + backup_fname, ) async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]: From 9a0172d49f3da46c615304c7df3353494500fd49 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 8 Mar 2022 15:02:59 -0500 Subject: [PATCH 034/230] Clean-up demo scripts & documentation (#12143) * Rewrites the demo documentation to be clearer, accurate, and moves it to our documentation tree. * Improvements to the demo scripts: * `clean.sh` now runs `stop.sh` first to avoid zombie processes. * Uses more modern Synapse configuration (and removes some obsolete configuration). * Consistently use the HTTP ports for server name, etc. * Remove the `demo/etc` directory and place everything into the `demo/808x` directories. --- README.rst | 3 ++ changelog.d/12143.doc | 1 + demo/.gitignore | 11 +++---- demo/README | 26 --------------- demo/clean.sh | 3 ++ demo/start.sh | 71 +++++++++++++++++++--------------------- docs/SUMMARY.md | 1 + docs/development/demo.md | 41 +++++++++++++++++++++++ docs/federate.md | 3 +- 9 files changed, 89 insertions(+), 71 deletions(-) create mode 100644 changelog.d/12143.doc delete mode 100644 demo/README create mode 100644 docs/development/demo.md diff --git a/README.rst b/README.rst index 4281c87d1f80..595fb5ff62a6 100644 --- a/README.rst +++ b/README.rst @@ -312,6 +312,9 @@ We recommend using the demo which starts 3 federated instances running on ports (to stop, you can use `./demo/stop.sh`) +See the [demo documentation](https://matrix-org.github.io/synapse/develop/development/demo.html) +for more information. + If you just want to start a single instance of the app and run it directly:: # Create the homeserver.yaml config once diff --git a/changelog.d/12143.doc b/changelog.d/12143.doc new file mode 100644 index 000000000000..4b9db74b1fc9 --- /dev/null +++ b/changelog.d/12143.doc @@ -0,0 +1 @@ +Improve documentation for demo scripts. diff --git a/demo/.gitignore b/demo/.gitignore index 4d1271234312..5663aba8e758 100644 --- a/demo/.gitignore +++ b/demo/.gitignore @@ -1,7 +1,4 @@ -*.db -*.log -*.log.* -*.pid - -/media_store.* -/etc +# Ignore all the temporary files from the demo servers. +8080/ +8081/ +8082/ diff --git a/demo/README b/demo/README deleted file mode 100644 index a5a95bd19666..000000000000 --- a/demo/README +++ /dev/null @@ -1,26 +0,0 @@ -DO NOT USE THESE DEMO SERVERS IN PRODUCTION - -Requires you to have done: - python setup.py develop - - -The demo start.sh will start three synapse servers on ports 8080, 8081 and 8082, with host names localhost:$port. 
This can be easily changed to `hostname`:$port in start.sh if required. - -To enable the servers to communicate untrusted ssl certs are used. In order to do this the servers do not check the certs -and are configured in a highly insecure way. Do not use these configuration files in production. - -stop.sh will stop the synapse servers and the webclient. - -clean.sh will delete the databases and log files. - -To start a completely new set of servers, run: - - ./demo/stop.sh; ./demo/clean.sh && ./demo/start.sh - - -Logs and sqlitedb will be stored in demo/808{0,1,2}.{log,db} - - - -Also note that when joining a public room on a different HS via "#foo:bar.net", then you are (in the current impl) joining a room with room_id "foo". This means that it won't work if your HS already has a room with that name. - diff --git a/demo/clean.sh b/demo/clean.sh index e9b440d90dfd..7f1e1920215a 100755 --- a/demo/clean.sh +++ b/demo/clean.sh @@ -4,6 +4,9 @@ set -e DIR="$( cd "$( dirname "$0" )" && pwd )" +# Ensure that the servers are stopped. +$DIR/stop.sh + PID_FILE="$DIR/servers.pid" if [ -f "$PID_FILE" ]; then diff --git a/demo/start.sh b/demo/start.sh index 8ffb14e30add..55e69685e3c2 100755 --- a/demo/start.sh +++ b/demo/start.sh @@ -6,8 +6,6 @@ CWD=$(pwd) cd "$DIR/.." || exit -mkdir -p demo/etc - PYTHONPATH=$(readlink -f "$(pwd)") export PYTHONPATH @@ -21,22 +19,26 @@ for port in 8080 8081 8082; do mkdir -p demo/$port pushd demo/$port || exit - #rm $DIR/etc/$port.config + # Generate the configuration for the homeserver at localhost:848x. python3 -m synapse.app.homeserver \ --generate-config \ - -H "localhost:$https_port" \ - --config-path "$DIR/etc/$port.config" \ + --server-name "localhost:$port" \ + --config-path "$port.config" \ --report-stats no - if ! grep -F "Customisation made by demo/start.sh" -q "$DIR/etc/$port.config"; then - # Generate tls keys - openssl req -x509 -newkey rsa:4096 -keyout "$DIR/etc/localhost:$https_port.tls.key" -out "$DIR/etc/localhost:$https_port.tls.crt" -days 365 -nodes -subj "/O=matrix" + if ! grep -F "Customisation made by demo/start.sh" -q "$port.config"; then + # Generate TLS keys. + openssl req -x509 -newkey rsa:4096 \ + -keyout "localhost:$port.tls.key" \ + -out "localhost:$port.tls.crt" \ + -days 365 -nodes -subj "/O=matrix" - # Regenerate configuration + # Add customisations to the configuration. { - printf '\n\n# Customisation made by demo/start.sh\n' + printf '\n\n# Customisation made by demo/start.sh\n\n' echo "public_baseurl: http://localhost:$port/" echo 'enable_registration: true' + echo '' # Warning, this heredoc depends on the interaction of tabs and spaces. # Please don't accidentaly bork me with your fancy settings. @@ -63,38 +65,34 @@ for port in 8080 8081 8082; do echo "${listeners}" - # Disable tls for the servers - printf '\n\n# Disable tls on the servers.' + # Disable TLS for the servers + printf '\n\n# Disable TLS for the servers.' echo '# DO NOT USE IN PRODUCTION' echo 'use_insecure_ssl_client_just_for_testing_do_not_use: true' echo 'federation_verify_certificates: false' - # Set tls paths - echo "tls_certificate_path: \"$DIR/etc/localhost:$https_port.tls.crt\"" - echo "tls_private_key_path: \"$DIR/etc/localhost:$https_port.tls.key\"" + # Set paths for the TLS certificates. 
+ echo "tls_certificate_path: \"$DIR/$port/localhost:$port.tls.crt\"" + echo "tls_private_key_path: \"$DIR/$port/localhost:$port.tls.key\"" # Ignore keys from the trusted keys server echo '# Ignore keys from the trusted keys server' echo 'trusted_key_servers:' echo ' - server_name: "matrix.org"' echo ' accept_keys_insecurely: true' - - # Reduce the blacklist - blacklist=$(cat <<-BLACK - # Set the blacklist so that it doesn't include 127.0.0.1, ::1 - federation_ip_range_blacklist: - - '10.0.0.0/8' - - '172.16.0.0/12' - - '192.168.0.0/16' - - '100.64.0.0/10' - - '169.254.0.0/16' - - 'fe80::/64' - - 'fc00::/7' - BLACK + echo '' + + # Allow the servers to communicate over localhost. + allow_list=$(cat <<-ALLOW_LIST + # Allow the servers to communicate over localhost. + ip_range_whitelist: + - '127.0.0.1/8' + - '::1/128' + ALLOW_LIST ) - echo "${blacklist}" - } >> "$DIR/etc/$port.config" + echo "${allow_list}" + } >> "$port.config" fi # Check script parameters @@ -141,19 +139,18 @@ for port in 8080 8081 8082; do burst_count: 1000 RC ) - echo "${ratelimiting}" >> "$DIR/etc/$port.config" + echo "${ratelimiting}" >> "$port.config" fi fi - if ! grep -F "full_twisted_stacktraces" -q "$DIR/etc/$port.config"; then - echo "full_twisted_stacktraces: true" >> "$DIR/etc/$port.config" - fi - if ! grep -F "report_stats" -q "$DIR/etc/$port.config" ; then - echo "report_stats: false" >> "$DIR/etc/$port.config" + # Always disable reporting of stats if the option is not there. + if ! grep -F "report_stats" -q "$port.config" ; then + echo "report_stats: false" >> "$port.config" fi + # Run the homeserver in the background. python3 -m synapse.app.homeserver \ - --config-path "$DIR/etc/$port.config" \ + --config-path "$port.config" \ -D \ popd || exit diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md index ef9cabf55524..21f80efc9998 100644 --- a/docs/SUMMARY.md +++ b/docs/SUMMARY.md @@ -82,6 +82,7 @@ - [Release Cycle](development/releases.md) - [Git Usage](development/git.md) - [Testing]() + - [Demo scripts](development/demo.md) - [OpenTracing](opentracing.md) - [Database Schemas](development/database_schema.md) - [Experimental features](development/experimental_features.md) diff --git a/docs/development/demo.md b/docs/development/demo.md new file mode 100644 index 000000000000..4277252ceb60 --- /dev/null +++ b/docs/development/demo.md @@ -0,0 +1,41 @@ +# Synapse demo setup + +**DO NOT USE THESE DEMO SERVERS IN PRODUCTION** + +Requires you to have a [Synapse development environment setup](https://matrix-org.github.io/synapse/develop/development/contributing_guide.html#4-install-the-dependencies). + +The demo setup allows running three federation Synapse servers, with server +names `localhost:8080`, `localhost:8081`, and `localhost:8082`. + +You can access them via any Matrix client over HTTP at `localhost:8080`, +`localhost:8081`, and `localhost:8082` or over HTTPS at `localhost:8480`, +`localhost:8481`, and `localhost:8482`. + +To enable the servers to communicate, self-signed SSL certificates are generated +and the servers are configured in a highly insecure way, including: + +* Not checking certificates over federation. +* Not verifying keys. + +The servers are configured to store their data under `demo/8080`, `demo/8081`, and +`demo/8082`. This includes configuration, logs, SQLite databases, and media. + +Note that when joining a public room on a different HS via "#foo:bar.net", then +you are (in the current impl) joining a room with room_id "foo". 
This means that +it won't work if your HS already has a room with that name. + +## Using the demo scripts + +There's three main scripts with straightforward purposes: + +* `start.sh` will start the Synapse servers, generating any missing configuration. + * This accepts a single parameter `--no-rate-limit` to "disable" rate limits + (they actually still exist, but are very high). +* `stop.sh` will stop the Synapse servers. +* `clean.sh` will delete the configuration, databases, log files, etc. + +To start a completely new set of servers, run: + +```sh +./demo/stop.sh; ./demo/clean.sh && ./demo/start.sh +``` diff --git a/docs/federate.md b/docs/federate.md index 5107f995be98..df4c87da51e2 100644 --- a/docs/federate.md +++ b/docs/federate.md @@ -63,4 +63,5 @@ release of Synapse. If you want to get up and running quickly with a trio of homeservers in a private federation, there is a script in the `demo` directory. This is mainly -useful just for development purposes. See [demo/README](https://github.com/matrix-org/synapse/tree/develop/demo/). +useful just for development purposes. See +[demo scripts](https://matrix-org.github.io/synapse/develop/development/demo.html). From dc8d825ef26714f610db9c286f2f2517db064b79 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 9 Mar 2022 11:00:48 +0000 Subject: [PATCH 035/230] Skip attempt to get state at backwards-extremities (#12173) We don't *have* the state at a backwards-extremity, so this is never going to do anything useful. --- changelog.d/12173.misc | 1 + synapse/handlers/federation.py | 60 ++-------------------------------- 2 files changed, 4 insertions(+), 57 deletions(-) create mode 100644 changelog.d/12173.misc diff --git a/changelog.d/12173.misc b/changelog.d/12173.misc new file mode 100644 index 000000000000..9f333e718a86 --- /dev/null +++ b/changelog.d/12173.misc @@ -0,0 +1 @@ +Avoid trying to calculate the state at outlier events. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index eb03a5accbac..db39aeabded6 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -23,8 +23,6 @@ from signedjson.sign import verify_signed_json from unpaddedbase64 import decode_base64 -from twisted.internet import defer - from synapse import event_auth from synapse.api.constants import EventContentFields, EventTypes, Membership from synapse.api.errors import ( @@ -45,11 +43,7 @@ from synapse.events.validator import EventValidator from synapse.federation.federation_client import InvalidResponseError from synapse.http.servlet import assert_params_in_dict -from synapse.logging.context import ( - make_deferred_yieldable, - nested_logging_context, - preserve_fn, -) +from synapse.logging.context import nested_logging_context from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.http.federation import ( ReplicationCleanRoomRestServlet, @@ -355,56 +349,8 @@ async def try_backfill(domains: List[str]) -> bool: if success: return True - # Huh, well *those* domains didn't work out. Lets try some domains - # from the time. 
- - tried_domains = set(likely_domains) - tried_domains.add(self.server_name) - - event_ids = list(extremities.keys()) - - logger.debug("calling resolve_state_groups in _maybe_backfill") - resolve = preserve_fn(self.state_handler.resolve_state_groups_for_events) - states_list = await make_deferred_yieldable( - defer.gatherResults( - [resolve(room_id, [e]) for e in event_ids], consumeErrors=True - ) - ) - - # A map from event_id to state map of event_ids. - state_ids: Dict[str, StateMap[str]] = dict( - zip(event_ids, [s.state for s in states_list]) - ) - - state_map = await self.store.get_events( - [e_id for ids in state_ids.values() for e_id in ids.values()], - get_prev_content=False, - ) - - # A map from event_id to state map of events. - state_events: Dict[str, StateMap[EventBase]] = { - key: { - k: state_map[e_id] - for k, e_id in state_dict.items() - if e_id in state_map - } - for key, state_dict in state_ids.items() - } - - for e_id in event_ids: - likely_extremeties_domains = get_domains_from_state(state_events[e_id]) - - success = await try_backfill( - [ - dom - for dom, _ in likely_extremeties_domains - if dom not in tried_domains - ] - ) - if success: - return True - - tried_domains.update(dom for dom, _ in likely_extremeties_domains) + # TODO: we could also try servers which were previously in the room, but + # are no longer. return False From 180d8ff0d4d706344fa984abbd9ed6fa02ca13dc Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Wed, 9 Mar 2022 14:53:28 +0000 Subject: [PATCH 036/230] Retry some http replication failures (#12182) This allows for the target process to be down for around a minute which provides time for restarts during synapse upgrades/config updates. Closes: #12178 Signed off by Nick Mills-Barrett nick@beeper.com --- changelog.d/12182.misc | 1 + synapse/replication/http/_base.py | 47 +++++++++++++++++++++++-------- 2 files changed, 37 insertions(+), 11 deletions(-) create mode 100644 changelog.d/12182.misc diff --git a/changelog.d/12182.misc b/changelog.d/12182.misc new file mode 100644 index 000000000000..7e9ad2c75244 --- /dev/null +++ b/changelog.d/12182.misc @@ -0,0 +1 @@ +Retry HTTP replication failures, this should prevent 502's when restarting stateful workers (main, event persisters, stream writers). Contributed by Nick @ Beeper. diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index 2e697c74a6bb..f1abb986534b 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -21,6 +21,7 @@ from prometheus_client import Counter, Gauge +from twisted.internet.error import ConnectError, DNSLookupError from twisted.web.server import Request from synapse.api.errors import HttpResponseException, SynapseError @@ -87,6 +88,10 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): `_handle_request` must return a Deferred. RETRY_ON_TIMEOUT(bool): Whether or not to retry the request when a 504 is received. + RETRY_ON_CONNECT_ERROR (bool): Whether or not to retry the request when + a connection error is received. + RETRY_ON_CONNECT_ERROR_ATTEMPTS (int): Number of attempts to retry when + receiving connection errors, each will backoff exponentially longer. 
""" NAME: str = abc.abstractproperty() # type: ignore @@ -94,6 +99,8 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): METHOD = "POST" CACHE = True RETRY_ON_TIMEOUT = True + RETRY_ON_CONNECT_ERROR = True + RETRY_ON_CONNECT_ERROR_ATTEMPTS = 5 # =63s (2^6-1) def __init__(self, hs: "HomeServer"): if self.CACHE: @@ -236,18 +243,20 @@ async def send_request(*, instance_name: str = "master", **kwargs: Any) -> Any: "/".join(url_args), ) + headers: Dict[bytes, List[bytes]] = {} + # Add an authorization header, if configured. + if replication_secret: + headers[b"Authorization"] = [b"Bearer " + replication_secret] + opentracing.inject_header_dict(headers, check_destination=False) + try: + # Keep track of attempts made so we can bail if we don't manage to + # connect to the target after N tries. + attempts = 0 # We keep retrying the same request for timeouts. This is so that we # have a good idea that the request has either succeeded or failed # on the master, and so whether we should clean up or not. while True: - headers: Dict[bytes, List[bytes]] = {} - # Add an authorization header, if configured. - if replication_secret: - headers[b"Authorization"] = [ - b"Bearer " + replication_secret - ] - opentracing.inject_header_dict(headers, check_destination=False) try: result = await request_func(uri, data, headers=headers) break @@ -255,11 +264,27 @@ async def send_request(*, instance_name: str = "master", **kwargs: Any) -> Any: if not cls.RETRY_ON_TIMEOUT: raise - logger.warning("%s request timed out; retrying", cls.NAME) + logger.warning("%s request timed out; retrying", cls.NAME) + + # If we timed out we probably don't need to worry about backing + # off too much, but lets just wait a little anyway. + await clock.sleep(1) + except (ConnectError, DNSLookupError) as e: + if not cls.RETRY_ON_CONNECT_ERROR: + raise + if attempts > cls.RETRY_ON_CONNECT_ERROR_ATTEMPTS: + raise + + delay = 2 ** attempts + logger.warning( + "%s request connection failed; retrying in %ds: %r", + cls.NAME, + delay, + e, + ) - # If we timed out we probably don't need to worry about backing - # off too much, but lets just wait a little anyway. - await clock.sleep(1) + await clock.sleep(delay) + attempts += 1 except HttpResponseException as e: # We convert to SynapseError as we know that it was a SynapseError # on the main process that we should send to the client. (And From 032688854babeea832cbb4f762fc70fe31e73cc6 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 9 Mar 2022 10:29:39 -0500 Subject: [PATCH 037/230] Remove some unused variables/parameters. (#12187) --- changelog.d/12187.misc | 1 + synapse/storage/databases/main/roommember.py | 14 +++++--------- 2 files changed, 6 insertions(+), 9 deletions(-) create mode 100644 changelog.d/12187.misc diff --git a/changelog.d/12187.misc b/changelog.d/12187.misc new file mode 100644 index 000000000000..c53e68faa508 --- /dev/null +++ b/changelog.d/12187.misc @@ -0,0 +1 @@ +Remove unused variables. 
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index e48ec5f495af..bef675b8453c 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -46,7 +46,7 @@ ProfileInfo, RoomsForUser, ) -from synapse.types import PersistedEventPosition, StateMap, get_domain_from_id +from synapse.types import PersistedEventPosition, get_domain_from_id from synapse.util.async_helpers import Linearizer from synapse.util.caches import intern_string from synapse.util.caches.descriptors import _CacheContext, cached, cachedList @@ -273,7 +273,7 @@ def _get_room_summary_txn(txn): txn.execute(sql, (room_id,)) res = {} for count, membership in txn: - summary = res.setdefault(membership, MemberSummary([], count)) + res.setdefault(membership, MemberSummary([], count)) # we order by membership and then fairly arbitrarily by event_id so # heroes are consistent @@ -839,18 +839,14 @@ async def get_joined_hosts(self, room_id: str, state_entry): with Measure(self._clock, "get_joined_hosts"): return await self._get_joined_hosts( - room_id, state_group, state_entry.state, state_entry=state_entry + room_id, state_group, state_entry=state_entry ) @cached(num_args=2, max_entries=10000, iterable=True) async def _get_joined_hosts( - self, - room_id: str, - state_group: int, - current_state_ids: StateMap[str], - state_entry: "_StateCacheEntry", + self, room_id: str, state_group: int, state_entry: "_StateCacheEntry" ) -> FrozenSet[str]: - # We don't use `state_group`, its there so that we can cache based on + # We don't use `state_group`, it's there so that we can cache based on # it. However, its important that its never None, since two # current_state's with a state_group of None are likely to be different. # From 690cb4f3b32938f5ced5590abe9429733040a129 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 9 Mar 2022 13:07:41 -0500 Subject: [PATCH 038/230] Allow for ignoring some arguments when caching. (#12189) * `@cached` can now take an `uncached_args` which is an iterable of names to not use in the cache key. * Requires `@cached`, @cachedList` and `@lru_cache` to use keyword arguments for clarity. * Asserts that keyword-only arguments in cached functions are not accepted. (I tested this briefly and I don't believe this works properly.) --- changelog.d/12189.misc | 1 + .../storage/databases/main/events_worker.py | 4 +- synapse/util/caches/descriptors.py | 74 ++++++++++++---- tests/util/caches/test_descriptors.py | 84 ++++++++++++++++++- 4 files changed, 142 insertions(+), 21 deletions(-) create mode 100644 changelog.d/12189.misc diff --git a/changelog.d/12189.misc b/changelog.d/12189.misc new file mode 100644 index 000000000000..015e808e63c7 --- /dev/null +++ b/changelog.d/12189.misc @@ -0,0 +1 @@ +Support skipping some arguments when generating cache keys. 
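The `uncached_args` option added in the changes below lets a cached method exclude named arguments from its cache key; the test in this patch exercises it as `@descriptors.cached(uncached_args=("arg2",))`. A hypothetical store method using it (the `ExampleStore` class and its method are invented purely for illustration; only the decorator and its `uncached_args` keyword come from the patch) might look like:

```python
from synapse.util.caches.descriptors import cached


class ExampleStore:
    """Not a real Synapse store; just shows the decorator's new keyword."""

    @cached(uncached_args=("for_user",))
    async def get_display_name(self, user_id: str, for_user: str) -> str:
        # `for_user` is left out of the cache key, so lookups that differ
        # only in who asked still share a single cache entry.
        return f"Display name of {user_id}"
```

Note that `@cached`, `@cachedList` and `@lru_cache` must now be invoked with keyword arguments, and `num_args` cannot be combined with `uncached_args`.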
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 26784f755e40..59454a47dfdd 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -1286,7 +1286,7 @@ async def have_seen_events( ) return {eid for ((_rid, eid), have_event) in res.items() if have_event} - @cachedList("have_seen_event", "keys") + @cachedList(cached_method_name="have_seen_event", list_name="keys") async def _have_seen_events_dict( self, keys: Iterable[Tuple[str, str]] ) -> Dict[Tuple[str, str], bool]: @@ -1954,7 +1954,7 @@ def get_event_id_for_timestamp_txn(txn: LoggingTransaction) -> Optional[str]: get_event_id_for_timestamp_txn, ) - @cachedList("is_partial_state_event", list_name="event_ids") + @cachedList(cached_method_name="is_partial_state_event", list_name="event_ids") async def get_partial_state_events( self, event_ids: Collection[str] ) -> Dict[str, bool]: diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 1cdead02f14b..c3c5c16db96e 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -20,6 +20,7 @@ Any, Awaitable, Callable, + Collection, Dict, Generic, Hashable, @@ -69,6 +70,7 @@ def __init__( self, orig: Callable[..., Any], num_args: Optional[int], + uncached_args: Optional[Collection[str]] = None, cache_context: bool = False, ): self.orig = orig @@ -76,6 +78,13 @@ def __init__( arg_spec = inspect.getfullargspec(orig) all_args = arg_spec.args + # There's no reason that keyword-only arguments couldn't be supported, + # but right now they're buggy so do not allow them. + if arg_spec.kwonlyargs: + raise ValueError( + "_CacheDescriptorBase does not support keyword-only arguments." + ) + if "cache_context" in all_args: if not cache_context: raise ValueError( @@ -88,6 +97,9 @@ def __init__( " named `cache_context`" ) + if num_args is not None and uncached_args is not None: + raise ValueError("Cannot provide both num_args and uncached_args") + if num_args is None: num_args = len(all_args) - 1 if cache_context: @@ -105,6 +117,12 @@ def __init__( # list of the names of the args used as the cache key self.arg_names = all_args[1 : num_args + 1] + # If there are args to not cache on, filter them out (and fix the size of num_args). + if uncached_args is not None: + include_arg_in_cache_key = [n not in uncached_args for n in self.arg_names] + else: + include_arg_in_cache_key = [True] * len(self.arg_names) + # self.arg_defaults is a map of arg name to its default value for each # argument that has a default value if arg_spec.defaults: @@ -119,8 +137,8 @@ def __init__( self.add_cache_context = cache_context - self.cache_key_builder = get_cache_key_builder( - self.arg_names, self.arg_defaults + self.cache_key_builder = _get_cache_key_builder( + self.arg_names, include_arg_in_cache_key, self.arg_defaults ) @@ -130,8 +148,7 @@ class _LruCachedFunction(Generic[F]): def lru_cache( - max_entries: int = 1000, - cache_context: bool = False, + *, max_entries: int = 1000, cache_context: bool = False ) -> Callable[[F], _LruCachedFunction[F]]: """A method decorator that applies a memoizing cache around the function. 
@@ -186,7 +203,9 @@ def __init__( max_entries: int = 1000, cache_context: bool = False, ): - super().__init__(orig, num_args=None, cache_context=cache_context) + super().__init__( + orig, num_args=None, uncached_args=None, cache_context=cache_context + ) self.max_entries = max_entries def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]: @@ -260,6 +279,9 @@ def foo(self, key, cache_context): num_args: number of positional arguments (excluding ``self`` and ``cache_context``) to use as cache keys. Defaults to all named args of the function. + uncached_args: a list of argument names to not use as the cache key. + (``self`` and ``cache_context`` are always ignored.) Cannot be used + with num_args. tree: cache_context: iterable: @@ -273,12 +295,18 @@ def __init__( orig: Callable[..., Any], max_entries: int = 1000, num_args: Optional[int] = None, + uncached_args: Optional[Collection[str]] = None, tree: bool = False, cache_context: bool = False, iterable: bool = False, prune_unread_entries: bool = True, ): - super().__init__(orig, num_args=num_args, cache_context=cache_context) + super().__init__( + orig, + num_args=num_args, + uncached_args=uncached_args, + cache_context=cache_context, + ) if tree and self.num_args < 2: raise RuntimeError( @@ -369,7 +397,7 @@ def __init__( but including list_name) to use as cache keys. Defaults to all named args of the function. """ - super().__init__(orig, num_args=num_args) + super().__init__(orig, num_args=num_args, uncached_args=None) self.list_name = list_name @@ -530,8 +558,10 @@ def get_instance( def cached( + *, max_entries: int = 1000, num_args: Optional[int] = None, + uncached_args: Optional[Collection[str]] = None, tree: bool = False, cache_context: bool = False, iterable: bool = False, @@ -541,6 +571,7 @@ def cached( orig, max_entries=max_entries, num_args=num_args, + uncached_args=uncached_args, tree=tree, cache_context=cache_context, iterable=iterable, @@ -551,7 +582,7 @@ def cached( def cachedList( - cached_method_name: str, list_name: str, num_args: Optional[int] = None + *, cached_method_name: str, list_name: str, num_args: Optional[int] = None ) -> Callable[[F], _CachedFunction[F]]: """Creates a descriptor that wraps a function in a `CacheListDescriptor`. 
@@ -590,13 +621,16 @@ def batch_do_something(self, first_arg, second_args): return cast(Callable[[F], _CachedFunction[F]], func) -def get_cache_key_builder( - param_names: Sequence[str], param_defaults: Mapping[str, Any] +def _get_cache_key_builder( + param_names: Sequence[str], + include_params: Sequence[bool], + param_defaults: Mapping[str, Any], ) -> Callable[[Sequence[Any], Mapping[str, Any]], CacheKey]: """Construct a function which will build cache keys suitable for a cached function Args: param_names: list of formal parameter names for the cached function + include_params: list of bools of whether to include the parameter name in the cache key param_defaults: a mapping from parameter name to default value for that param Returns: @@ -608,6 +642,7 @@ def get_cache_key_builder( if len(param_names) == 1: nm = param_names[0] + assert include_params[0] is True def get_cache_key(args: Sequence[Any], kwargs: Mapping[str, Any]) -> CacheKey: if nm in kwargs: @@ -620,13 +655,18 @@ def get_cache_key(args: Sequence[Any], kwargs: Mapping[str, Any]) -> CacheKey: else: def get_cache_key(args: Sequence[Any], kwargs: Mapping[str, Any]) -> CacheKey: - return tuple(_get_cache_key_gen(param_names, param_defaults, args, kwargs)) + return tuple( + _get_cache_key_gen( + param_names, include_params, param_defaults, args, kwargs + ) + ) return get_cache_key def _get_cache_key_gen( param_names: Iterable[str], + include_params: Iterable[bool], param_defaults: Mapping[str, Any], args: Sequence[Any], kwargs: Mapping[str, Any], @@ -637,16 +677,18 @@ def _get_cache_key_gen( This is essentially the same operation as `inspect.getcallargs`, but optimised so that we don't need to inspect the target function for each call. """ - # We loop through each arg name, looking up if its in the `kwargs`, # otherwise using the next argument in `args`. If there are no more # args then we try looking the arg name up in the defaults. pos = 0 - for nm in param_names: + for nm, inc in zip(param_names, include_params): if nm in kwargs: - yield kwargs[nm] + if inc: + yield kwargs[nm] elif pos < len(args): - yield args[pos] + if inc: + yield args[pos] pos += 1 else: - yield param_defaults[nm] + if inc: + yield param_defaults[nm] diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 19741ffcdaf1..6a4b17527a7f 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -141,6 +141,84 @@ def fn(self, arg1, arg2): self.assertEqual(r, "chips") obj.mock.assert_not_called() + @defer.inlineCallbacks + def test_cache_uncached_args(self): + """ + Only the arguments not named in uncached_args should matter to the cache + + Note that this is identical to test_cache_num_args, but provides the + arguments differently. + """ + + class Cls: + # Note that it is important that this is not the last argument to + # test behaviour of skipping arguments properly. 
+ @descriptors.cached(uncached_args=("arg2",)) + def fn(self, arg1, arg2, arg3): + return self.mock(arg1, arg2, arg3) + + def __init__(self): + self.mock = mock.Mock() + + obj = Cls() + obj.mock.return_value = "fish" + r = yield obj.fn(1, 2, 3) + self.assertEqual(r, "fish") + obj.mock.assert_called_once_with(1, 2, 3) + obj.mock.reset_mock() + + # a call with different params should call the mock again + obj.mock.return_value = "chips" + r = yield obj.fn(2, 3, 4) + self.assertEqual(r, "chips") + obj.mock.assert_called_once_with(2, 3, 4) + obj.mock.reset_mock() + + # the two values should now be cached; we should be able to vary + # the second argument and still get the cached result. + r = yield obj.fn(1, 4, 3) + self.assertEqual(r, "fish") + r = yield obj.fn(2, 5, 4) + self.assertEqual(r, "chips") + obj.mock.assert_not_called() + + @defer.inlineCallbacks + def test_cache_kwargs(self): + """Test that keyword arguments are treated properly""" + + class Cls: + def __init__(self): + self.mock = mock.Mock() + + @descriptors.cached() + def fn(self, arg1, kwarg1=2): + return self.mock(arg1, kwarg1=kwarg1) + + obj = Cls() + obj.mock.return_value = "fish" + r = yield obj.fn(1, kwarg1=2) + self.assertEqual(r, "fish") + obj.mock.assert_called_once_with(1, kwarg1=2) + obj.mock.reset_mock() + + # a call with different params should call the mock again + obj.mock.return_value = "chips" + r = yield obj.fn(1, kwarg1=3) + self.assertEqual(r, "chips") + obj.mock.assert_called_once_with(1, kwarg1=3) + obj.mock.reset_mock() + + # the values should now be cached. + r = yield obj.fn(1, kwarg1=2) + self.assertEqual(r, "fish") + # We should be able to not provide kwarg1 and get the cached value back. + r = yield obj.fn(1) + self.assertEqual(r, "fish") + # Keyword arguments can be in any order. 
+ r = yield obj.fn(kwarg1=2, arg1=1) + self.assertEqual(r, "fish") + obj.mock.assert_not_called() + def test_cache_with_sync_exception(self): """If the wrapped function throws synchronously, things should continue to work""" @@ -656,7 +734,7 @@ def __init__(self): def fn(self, arg1, arg2): pass - @descriptors.cachedList("fn", "args1") + @descriptors.cachedList(cached_method_name="fn", list_name="args1") async def list_fn(self, args1, arg2): assert current_context().name == "c1" # we want this to behave like an asynchronous function @@ -715,7 +793,7 @@ def __init__(self): def fn(self, arg1): pass - @descriptors.cachedList("fn", "args1") + @descriptors.cachedList(cached_method_name="fn", list_name="args1") def list_fn(self, args1) -> "Deferred[dict]": return self.mock(args1) @@ -758,7 +836,7 @@ def __init__(self): def fn(self, arg1, arg2): pass - @descriptors.cachedList("fn", "args1") + @descriptors.cachedList(cached_method_name="fn", list_name="args1") async def list_fn(self, args1, arg2): # we want this to behave like an asynchronous function await run_on_reactor() From 15382b1afad65366df13c3b9040b6fdfb1eccfca Mon Sep 17 00:00:00 2001 From: Will Hunt Date: Wed, 9 Mar 2022 18:23:57 +0000 Subject: [PATCH 039/230] Add third_party module callbacks to check if a user can delete a room and deactivate a user (#12028) * Add check_can_deactivate_user * Add check_can_shutdown_rooms * Documentation * callbacks, not functions * Various suggested tweaks * Add tests for test_check_can_shutdown_room and test_check_can_deactivate_user * Update check_can_deactivate_user to not take a Requester * Fix check_can_shutdown_room docs * Renegade and use `by_admin` instead of `admin_user_id` * fix lint * Update docs/modules/third_party_rules_callbacks.md Co-authored-by: Brendan Abolivier * Update docs/modules/third_party_rules_callbacks.md Co-authored-by: Brendan Abolivier * Update docs/modules/third_party_rules_callbacks.md Co-authored-by: Brendan Abolivier * Update docs/modules/third_party_rules_callbacks.md Co-authored-by: Brendan Abolivier Co-authored-by: Brendan Abolivier --- changelog.d/12028.feature | 1 + docs/modules/third_party_rules_callbacks.md | 43 +++++++ synapse/events/third_party_rules.py | 55 +++++++++ synapse/handlers/deactivate_account.py | 12 +- synapse/handlers/room.py | 8 ++ synapse/module_api/__init__.py | 6 + synapse/rest/admin/rooms.py | 9 ++ tests/rest/client/test_third_party_rules.py | 121 ++++++++++++++++++++ 8 files changed, 254 insertions(+), 1 deletion(-) create mode 100644 changelog.d/12028.feature diff --git a/changelog.d/12028.feature b/changelog.d/12028.feature new file mode 100644 index 000000000000..5549c8f6fcf6 --- /dev/null +++ b/changelog.d/12028.feature @@ -0,0 +1 @@ +Add third-party rules rules callbacks `check_can_shutdown_room` and `check_can_deactivate_user`. diff --git a/docs/modules/third_party_rules_callbacks.md b/docs/modules/third_party_rules_callbacks.md index 09ac838107b8..1d3c39967faa 100644 --- a/docs/modules/third_party_rules_callbacks.md +++ b/docs/modules/third_party_rules_callbacks.md @@ -148,6 +148,49 @@ deny an incoming event, see [`check_event_for_spam`](spam_checker_callbacks.md#c If multiple modules implement this callback, Synapse runs them all in order. +### `check_can_shutdown_room` + +_First introduced in Synapse v1.55.0_ + +```python +async def check_can_shutdown_room( + user_id: str, room_id: str, +) -> bool: +``` + +Called when an admin user requests the shutdown of a room. 
The module must return a +boolean indicating whether the shutdown can go through. If the callback returns `False`, +the shutdown will not proceed and the caller will see a `M_FORBIDDEN` error. + +If multiple modules implement this callback, they will be considered in order. If a +callback returns `True`, Synapse falls through to the next one. The value of the first +callback that does not return `True` will be used. If this happens, Synapse will not call +any of the subsequent implementations of this callback. + +### `check_can_deactivate_user` + +_First introduced in Synapse v1.55.0_ + +```python +async def check_can_deactivate_user( + user_id: str, by_admin: bool, +) -> bool: +``` + +Called when the deactivation of a user is requested. User deactivation can be +performed by an admin or the user themselves, so developers are encouraged to check the +requester when implementing this callback. The module must return a +boolean indicating whether the deactivation can go through. If the callback returns `False`, +the deactivation will not proceed and the caller will see a `M_FORBIDDEN` error. + +The module is passed two parameters, `user_id` which is the ID of the user being deactivated, and `by_admin` which is `True` if the request is made by a serve admin, and `False` otherwise. + +If multiple modules implement this callback, they will be considered in order. If a +callback returns `True`, Synapse falls through to the next one. The value of the first +callback that does not return `True` will be used. If this happens, Synapse will not call +any of the subsequent implementations of this callback. + + ### `on_profile_update` _First introduced in Synapse v1.54.0_ diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py index ede72ee87631..bfca454f510d 100644 --- a/synapse/events/third_party_rules.py +++ b/synapse/events/third_party_rules.py @@ -38,6 +38,8 @@ [str, StateMap[EventBase], str], Awaitable[bool] ] ON_NEW_EVENT_CALLBACK = Callable[[EventBase, StateMap[EventBase]], Awaitable] +CHECK_CAN_SHUTDOWN_ROOM_CALLBACK = Callable[[str, str], Awaitable[bool]] +CHECK_CAN_DEACTIVATE_USER_CALLBACK = Callable[[str, bool], Awaitable[bool]] ON_PROFILE_UPDATE_CALLBACK = Callable[[str, ProfileInfo, bool, bool], Awaitable] ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK = Callable[[str, bool, bool], Awaitable] @@ -157,6 +159,12 @@ def __init__(self, hs: "HomeServer"): CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK ] = [] self._on_new_event_callbacks: List[ON_NEW_EVENT_CALLBACK] = [] + self._check_can_shutdown_room_callbacks: List[ + CHECK_CAN_SHUTDOWN_ROOM_CALLBACK + ] = [] + self._check_can_deactivate_user_callbacks: List[ + CHECK_CAN_DEACTIVATE_USER_CALLBACK + ] = [] self._on_profile_update_callbacks: List[ON_PROFILE_UPDATE_CALLBACK] = [] self._on_user_deactivation_status_changed_callbacks: List[ ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK @@ -173,6 +181,8 @@ def register_third_party_rules_callbacks( CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK ] = None, on_new_event: Optional[ON_NEW_EVENT_CALLBACK] = None, + check_can_shutdown_room: Optional[CHECK_CAN_SHUTDOWN_ROOM_CALLBACK] = None, + check_can_deactivate_user: Optional[CHECK_CAN_DEACTIVATE_USER_CALLBACK] = None, on_profile_update: Optional[ON_PROFILE_UPDATE_CALLBACK] = None, on_user_deactivation_status_changed: Optional[ ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK @@ -198,6 +208,11 @@ def register_third_party_rules_callbacks( if on_new_event is not None: self._on_new_event_callbacks.append(on_new_event) + if 
check_can_shutdown_room is not None: + self._check_can_shutdown_room_callbacks.append(check_can_shutdown_room) + + if check_can_deactivate_user is not None: + self._check_can_deactivate_user_callbacks.append(check_can_deactivate_user) if on_profile_update is not None: self._on_profile_update_callbacks.append(on_profile_update) @@ -369,6 +384,46 @@ async def on_new_event(self, event_id: str) -> None: "Failed to run module API callback %s: %s", callback, e ) + async def check_can_shutdown_room(self, user_id: str, room_id: str) -> bool: + """Intercept requests to shutdown a room. If `False` is returned, the + room must not be shut down. + + Args: + requester: The ID of the user requesting the shutdown. + room_id: The ID of the room. + """ + for callback in self._check_can_shutdown_room_callbacks: + try: + if await callback(user_id, room_id) is False: + return False + except Exception as e: + logger.exception( + "Failed to run module API callback %s: %s", callback, e + ) + return True + + async def check_can_deactivate_user( + self, + user_id: str, + by_admin: bool, + ) -> bool: + """Intercept requests to deactivate a user. If `False` is returned, the + user should not be deactivated. + + Args: + requester + user_id: The ID of the room. + """ + for callback in self._check_can_deactivate_user_callbacks: + try: + if await callback(user_id, by_admin) is False: + return False + except Exception as e: + logger.exception( + "Failed to run module API callback %s: %s", callback, e + ) + return True + async def _get_state_map_for_room(self, room_id: str) -> StateMap[EventBase]: """Given a room ID, return the state events of that room. diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index 76ae768e6ef5..816e1a6d79c8 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -17,7 +17,7 @@ from synapse.api.errors import SynapseError from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.types import Requester, UserID, create_requester +from synapse.types import Codes, Requester, UserID, create_requester if TYPE_CHECKING: from synapse.server import HomeServer @@ -42,6 +42,7 @@ def __init__(self, hs: "HomeServer"): # Flag that indicates whether the process to part users from rooms is running self._user_parter_running = False + self._third_party_rules = hs.get_third_party_event_rules() # Start the user parter loop so it can resume parting users from rooms where # it left off (if it has work left to do). @@ -74,6 +75,15 @@ async def deactivate_account( Returns: True if identity server supports removing threepids, otherwise False. """ + + # Check if this user can be deactivated + if not await self._third_party_rules.check_can_deactivate_user( + user_id, by_admin + ): + raise SynapseError( + 403, "Deactivation of this user is forbidden", Codes.FORBIDDEN + ) + # FIXME: Theoretically there is a race here wherein user resets # password using threepid. 
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 7b965b4b962a..b9735631fcd3 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -1475,6 +1475,7 @@ def __init__(self, hs: "HomeServer"): self.room_member_handler = hs.get_room_member_handler() self._room_creation_handler = hs.get_room_creation_handler() self._replication = hs.get_replication_data_handler() + self._third_party_rules = hs.get_third_party_event_rules() self.event_creation_handler = hs.get_event_creation_handler() self.store = hs.get_datastores().main @@ -1548,6 +1549,13 @@ async def shutdown_room( if not RoomID.is_valid(room_id): raise SynapseError(400, "%s is not a legal room ID" % (room_id,)) + if not await self._third_party_rules.check_can_shutdown_room( + requester_user_id, room_id + ): + raise SynapseError( + 403, "Shutdown of this room is forbidden", Codes.FORBIDDEN + ) + # Action the block first (even if the room doesn't exist yet) if block: # This will work even if the room is already blocked, but that is diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index c42eeedd87ae..d735c1d4616e 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -54,6 +54,8 @@ USER_MAY_SEND_3PID_INVITE_CALLBACK, ) from synapse.events.third_party_rules import ( + CHECK_CAN_DEACTIVATE_USER_CALLBACK, + CHECK_CAN_SHUTDOWN_ROOM_CALLBACK, CHECK_EVENT_ALLOWED_CALLBACK, CHECK_THREEPID_CAN_BE_INVITED_CALLBACK, CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK, @@ -283,6 +285,8 @@ def register_third_party_rules_callbacks( CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK ] = None, on_new_event: Optional[ON_NEW_EVENT_CALLBACK] = None, + check_can_shutdown_room: Optional[CHECK_CAN_SHUTDOWN_ROOM_CALLBACK] = None, + check_can_deactivate_user: Optional[CHECK_CAN_DEACTIVATE_USER_CALLBACK] = None, on_profile_update: Optional[ON_PROFILE_UPDATE_CALLBACK] = None, on_user_deactivation_status_changed: Optional[ ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK @@ -298,6 +302,8 @@ def register_third_party_rules_callbacks( check_threepid_can_be_invited=check_threepid_can_be_invited, check_visibility_can_be_modified=check_visibility_can_be_modified, on_new_event=on_new_event, + check_can_shutdown_room=check_can_shutdown_room, + check_can_deactivate_user=check_can_deactivate_user, on_profile_update=on_profile_update, on_user_deactivation_status_changed=on_user_deactivation_status_changed, ) diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index f4736a3dad83..356d6f74d7ef 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -67,6 +67,7 @@ def __init__(self, hs: "HomeServer"): self._auth = hs.get_auth() self._store = hs.get_datastores().main self._pagination_handler = hs.get_pagination_handler() + self._third_party_rules = hs.get_third_party_event_rules() async def on_DELETE( self, request: SynapseRequest, room_id: str @@ -106,6 +107,14 @@ async def on_DELETE( HTTPStatus.BAD_REQUEST, "%s is not a legal room ID" % (room_id,) ) + # Check this here, as otherwise we'll only fail after the background job has been started. 
+ if not await self._third_party_rules.check_can_shutdown_room( + requester.user.to_string(), room_id + ): + raise SynapseError( + 403, "Shutdown of this room is forbidden", Codes.FORBIDDEN + ) + delete_id = self._pagination_handler.start_shutdown_and_purge_room( room_id=room_id, new_room_user_id=content.get("new_room_user_id"), diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py index 58f1ea11b7da..e7de67e3a3b0 100644 --- a/tests/rest/client/test_third_party_rules.py +++ b/tests/rest/client/test_third_party_rules.py @@ -775,3 +775,124 @@ def test_on_user_deactivation_status_changed_admin(self) -> None: self.assertEqual(args[0], user_id) self.assertFalse(args[1]) self.assertTrue(args[2]) + + def test_check_can_deactivate_user(self) -> None: + """Tests that the on_user_deactivation_status_changed module callback is called + correctly when processing a user's deactivation. + """ + # Register a mocked callback. + deactivation_mock = Mock(return_value=make_awaitable(False)) + third_party_rules = self.hs.get_third_party_event_rules() + third_party_rules._check_can_deactivate_user_callbacks.append( + deactivation_mock, + ) + + # Register a user that we'll deactivate. + user_id = self.register_user("altan", "password") + tok = self.login("altan", "password") + + # Deactivate that user. + channel = self.make_request( + "POST", + "/_matrix/client/v3/account/deactivate", + { + "auth": { + "type": LoginType.PASSWORD, + "password": "password", + "identifier": { + "type": "m.id.user", + "user": user_id, + }, + }, + "erase": True, + }, + access_token=tok, + ) + + # Check that the deactivation was blocked + self.assertEqual(channel.code, 403, channel.json_body) + + # Check that the mock was called once. + deactivation_mock.assert_called_once() + args = deactivation_mock.call_args[0] + + # Check that the mock was called with the right user ID + self.assertEqual(args[0], user_id) + + # Check that the request was not made by an admin + self.assertEqual(args[1], False) + + def test_check_can_deactivate_user_admin(self) -> None: + """Tests that the on_user_deactivation_status_changed module callback is called + correctly when processing a user's deactivation triggered by a server admin. + """ + # Register a mocked callback. + deactivation_mock = Mock(return_value=make_awaitable(False)) + third_party_rules = self.hs.get_third_party_event_rules() + third_party_rules._check_can_deactivate_user_callbacks.append( + deactivation_mock, + ) + + # Register an admin user. + self.register_user("admin", "password", admin=True) + admin_tok = self.login("admin", "password") + + # Register a user that we'll deactivate. + user_id = self.register_user("altan", "password") + + # Deactivate the user. + channel = self.make_request( + "PUT", + "/_synapse/admin/v2/users/%s" % user_id, + {"deactivated": True}, + access_token=admin_tok, + ) + + # Check that the deactivation was blocked + self.assertEqual(channel.code, 403, channel.json_body) + + # Check that the mock was called once. + deactivation_mock.assert_called_once() + args = deactivation_mock.call_args[0] + + # Check that the mock was called with the right user ID + self.assertEqual(args[0], user_id) + + # Check that the mock was made by an admin + self.assertEqual(args[1], True) + + def test_check_can_shutdown_room(self) -> None: + """Tests that the check_can_shutdown_room module callback is called + correctly when processing an admin's shutdown room request. + """ + # Register a mocked callback. 
+ shutdown_mock = Mock(return_value=make_awaitable(False)) + third_party_rules = self.hs.get_third_party_event_rules() + third_party_rules._check_can_shutdown_room_callbacks.append( + shutdown_mock, + ) + + # Register an admin user. + admin_user_id = self.register_user("admin", "password", admin=True) + admin_tok = self.login("admin", "password") + + # Shutdown the room. + channel = self.make_request( + "DELETE", + "/_synapse/admin/v2/rooms/%s" % self.room_id, + {}, + access_token=admin_tok, + ) + + # Check that the shutdown was blocked + self.assertEqual(channel.code, 403, channel.json_body) + + # Check that the mock was called once. + shutdown_mock.assert_called_once() + args = shutdown_mock.call_args[0] + + # Check that the mock was called with the right user ID + self.assertEqual(args[0], admin_user_id) + + # Check that the mock was called with the right room ID + self.assertEqual(args[1], self.room_id) From a4c1fdb44a16471964ed6a347be6a191102f5c07 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Wed, 9 Mar 2022 18:45:21 +0000 Subject: [PATCH 040/230] Remove dead code in `tests/storage/test_database.py` (#12197) Signed-off-by: Sean Quah --- changelog.d/12197.misc | 1 + tests/storage/test_database.py | 16 ---------------- 2 files changed, 1 insertion(+), 16 deletions(-) create mode 100644 changelog.d/12197.misc diff --git a/changelog.d/12197.misc b/changelog.d/12197.misc new file mode 100644 index 000000000000..7d0e9b6bbf4c --- /dev/null +++ b/changelog.d/12197.misc @@ -0,0 +1 @@ +Remove some dead code. diff --git a/tests/storage/test_database.py b/tests/storage/test_database.py index 6fbac0ab1466..85978675634b 100644 --- a/tests/storage/test_database.py +++ b/tests/storage/test_database.py @@ -13,26 +13,10 @@ # limitations under the License. from synapse.storage.database import make_tuple_comparison_clause -from synapse.storage.engines import BaseDatabaseEngine from tests import unittest -def _stub_db_engine(**kwargs) -> BaseDatabaseEngine: - # returns a DatabaseEngine, circumventing the abc mechanism - # any kwargs are set as attributes on the class before instantiating it - t = type( - "TestBaseDatabaseEngine", - (BaseDatabaseEngine,), - dict(BaseDatabaseEngine.__dict__), - ) - # defeat the abc mechanism - t.__abstractmethods__ = set() - for k, v in kwargs.items(): - setattr(t, k, v) - return t(None, None) - - class TupleComparisonClauseTestCase(unittest.TestCase): def test_native_tuple_comparison(self): clause, args = make_tuple_comparison_clause([("a", 1), ("b", 2)]) From 3e4af36bc8515504721b3c1b1d64d4f45359bf88 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 10 Mar 2022 08:01:56 -0500 Subject: [PATCH 041/230] Rename get_tcp_replication to get_replication_command_handler. (#12192) Since the object it returns is a ReplicationCommandHandler. This is clean-up from adding support to Redis where the command handler was added as an additional layer of abstraction from the TCP protocol. 
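For code outside this diff that still uses the old accessor (for example a local worker entry point), the migration is a mechanical rename; a minimal sketch, assuming a `HomeServer` reference is already in hand (the helper function name here is hypothetical, not part of this patch):

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from synapse.server import HomeServer


def start_worker_replication(hs: "HomeServer") -> None:
    # Hypothetical call site: previously hs.get_tcp_replication().start_replication(hs).
    # Only the accessor name changes; the returned ReplicationCommandHandler and its
    # start_replication()/send_command() API are untouched.
    hs.get_replication_command_handler().start_replication(hs)
```
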
--- changelog.d/12192.misc | 1 + synapse/app/generic_worker.py | 2 +- synapse/app/homeserver.py | 2 +- synapse/federation/transport/server/_base.py | 2 +- synapse/handlers/presence.py | 4 ++-- synapse/replication/slave/storage/client_ips.py | 2 +- synapse/replication/tcp/client.py | 4 +++- synapse/replication/tcp/handler.py | 4 +--- synapse/replication/tcp/redis.py | 2 +- synapse/replication/tcp/resource.py | 4 ++-- synapse/server.py | 2 +- tests/replication/_base.py | 4 ++-- tests/replication/tcp/streams/test_events.py | 2 +- tests/replication/tcp/streams/test_typing.py | 2 +- tests/replication/test_federation_ack.py | 2 +- 15 files changed, 20 insertions(+), 19 deletions(-) create mode 100644 changelog.d/12192.misc diff --git a/changelog.d/12192.misc b/changelog.d/12192.misc new file mode 100644 index 000000000000..bdfe8dad98a6 --- /dev/null +++ b/changelog.d/12192.misc @@ -0,0 +1 @@ +Rename `HomeServer.get_tcp_replication` to `get_replication_command_handler`. diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 1536a4272333..a10a63b06c7e 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -417,7 +417,7 @@ def start_listening(self) -> None: else: logger.warning("Unsupported listener type: %s", listener.type) - self.get_tcp_replication().start_replication(self) + self.get_replication_command_handler().start_replication(self) def start(config_options: List[str]) -> None: diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index a6789a840ede..e4dc04c0b40f 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -273,7 +273,7 @@ def start_listening(self) -> None: # If redis is enabled we connect via the replication command handler # in the same way as the workers (since we're effectively a client # rather than a server). 
- self.get_tcp_replication().start_replication(self) + self.get_replication_command_handler().start_replication(self) for listener in self.config.server.listeners: if listener.type == "http": diff --git a/synapse/federation/transport/server/_base.py b/synapse/federation/transport/server/_base.py index 87e99c7ddf5b..2529dee613aa 100644 --- a/synapse/federation/transport/server/_base.py +++ b/synapse/federation/transport/server/_base.py @@ -63,7 +63,7 @@ def __init__(self, hs: "HomeServer"): self.replication_client = None if hs.config.worker.worker_app: - self.replication_client = hs.get_tcp_replication() + self.replication_client = hs.get_replication_command_handler() # A method just so we can pass 'self' as the authenticator to the Servlets async def authenticate_request( diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index c155098beeb8..9927a30e6ed5 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -424,13 +424,13 @@ def __init__(self, hs: "HomeServer"): async def _on_shutdown(self) -> None: if self._presence_enabled: - self.hs.get_tcp_replication().send_command( + self.hs.get_replication_command_handler().send_command( ClearUserSyncsCommand(self.instance_id) ) def send_user_sync(self, user_id: str, is_syncing: bool, last_sync_ms: int) -> None: if self._presence_enabled: - self.hs.get_tcp_replication().send_user_sync( + self.hs.get_replication_command_handler().send_user_sync( self.instance_id, user_id, is_syncing, last_sync_ms ) diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py index b5b84c09ae41..14706a081755 100644 --- a/synapse/replication/slave/storage/client_ips.py +++ b/synapse/replication/slave/storage/client_ips.py @@ -54,6 +54,6 @@ async def insert_client_ip( self.client_ip_last_seen.set(key, now) - self.hs.get_tcp_replication().send_user_ip( + self.hs.get_replication_command_handler().send_user_ip( user_id, access_token, ip, user_agent, device_id, now ) diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index b8fc1d4db95e..deeaaec4e66c 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -462,6 +462,8 @@ async def _save_and_send_ack(self) -> None: # We ACK this token over replication so that the master can drop # its in memory queues - self._hs.get_tcp_replication().send_federation_ack(current_position) + self._hs.get_replication_command_handler().send_federation_ack( + current_position + ) except Exception: logger.exception("Error updating federation stream position") diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 0d2013a3cfc5..d51f045f229a 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -295,9 +295,7 @@ async def _process_command( raise Exception("Unrecognised command %s in stream queue", cmd.NAME) def start_replication(self, hs: "HomeServer") -> None: - """Helper method to start a replication connection to the remote server - using TCP. 
- """ + """Helper method to start replication.""" if hs.config.redis.redis_enabled: from synapse.replication.tcp.redis import ( RedisDirectTcpReplicationClientFactory, diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py index b84e572da136..989c5be0327e 100644 --- a/synapse/replication/tcp/redis.py +++ b/synapse/replication/tcp/redis.py @@ -325,7 +325,7 @@ def __init__( password=hs.config.redis.redis_password, ) - self.synapse_handler = hs.get_tcp_replication() + self.synapse_handler = hs.get_replication_command_handler() self.synapse_stream_name = hs.hostname self.synapse_outbound_redis_connection = outbound_redis_connection diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index 494e42a2be8f..ab829040cde9 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -44,7 +44,7 @@ class ReplicationStreamProtocolFactory(ServerFactory): """Factory for new replication connections.""" def __init__(self, hs: "HomeServer"): - self.command_handler = hs.get_tcp_replication() + self.command_handler = hs.get_replication_command_handler() self.clock = hs.get_clock() self.server_name = hs.config.server.server_name @@ -85,7 +85,7 @@ def __init__(self, hs: "HomeServer"): self.is_looping = False self.pending_updates = False - self.command_handler = hs.get_tcp_replication() + self.command_handler = hs.get_replication_command_handler() # Set of streams to replicate. self.streams = self.command_handler.get_streams_to_replicate() diff --git a/synapse/server.py b/synapse/server.py index 46a64418ea0c..1270abb5a335 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -639,7 +639,7 @@ def get_read_marker_handler(self) -> ReadMarkerHandler: return ReadMarkerHandler(self) @cache_in_self - def get_tcp_replication(self) -> ReplicationCommandHandler: + def get_replication_command_handler(self) -> ReplicationCommandHandler: return ReplicationCommandHandler(self) @cache_in_self diff --git a/tests/replication/_base.py b/tests/replication/_base.py index a7a05a564fe9..9c5df266bd1f 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -251,7 +251,7 @@ def setUp(self): self.connect_any_redis_attempts, ) - self.hs.get_tcp_replication().start_replication(self.hs) + self.hs.get_replication_command_handler().start_replication(self.hs) # When we see a connection attempt to the master replication listener we # automatically set up the connection. This is so that tests don't @@ -375,7 +375,7 @@ def make_worker_hs( ) if worker_hs.config.redis.redis_enabled: - worker_hs.get_tcp_replication().start_replication(worker_hs) + worker_hs.get_replication_command_handler().start_replication(worker_hs) return worker_hs diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py index f9d5da723cce..641a94133b1d 100644 --- a/tests/replication/tcp/streams/test_events.py +++ b/tests/replication/tcp/streams/test_events.py @@ -420,7 +420,7 @@ def test_backwards_stream_id(self): # Manually send an old RDATA command, which should get dropped. This # re-uses the row from above, but with an earlier stream token. 
- self.hs.get_tcp_replication().send_command( + self.hs.get_replication_command_handler().send_command( RdataCommand("events", "master", 1, row) ) diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py index 3ff5afc6e501..9a229dd23f4a 100644 --- a/tests/replication/tcp/streams/test_typing.py +++ b/tests/replication/tcp/streams/test_typing.py @@ -118,7 +118,7 @@ def test_reset(self): # Reset the typing handler self.hs.get_replication_streams()["typing"].last_token = 0 - self.hs.get_tcp_replication()._streams["typing"].last_token = 0 + self.hs.get_replication_command_handler()._streams["typing"].last_token = 0 typing._latest_room_serial = 0 typing._typing_stream_change_cache = StreamChangeCache( "TypingStreamChangeCache", typing._latest_room_serial diff --git a/tests/replication/test_federation_ack.py b/tests/replication/test_federation_ack.py index 1b6a4bf4b0b1..26b8bd512a7f 100644 --- a/tests/replication/test_federation_ack.py +++ b/tests/replication/test_federation_ack.py @@ -48,7 +48,7 @@ def test_federation_ack_sent(self): transport, rather than assuming that the implementation has a ReplicationCommandHandler. """ - rch = self.hs.get_tcp_replication() + rch = self.hs.get_replication_command_handler() # wire up the ReplicationCommandHandler to a mock connection, which needs # to implement IReplicationConnection. (Note that Mock doesn't understand From 88cd6f937807e64c05458cec86ef0ba0c1c656b3 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 10 Mar 2022 09:03:59 -0500 Subject: [PATCH 042/230] Allow retrieving the relations of a redacted event. (#12130) This is allowed per MSC2675, although the original implementation did not allow for it and would return an empty chunk / not bundle aggregations. The main thing to improve is that the various caches get cleared properly when an event is redacted, and that edits must not leak if the original event is redacted (as that would presumably leak something similar to the original event content). --- changelog.d/12130.bugfix | 1 + changelog.d/12189.bugfix | 1 + changelog.d/12189.misc | 1 - synapse/rest/client/relations.py | 82 ++++++++++----------- synapse/storage/databases/main/cache.py | 4 + synapse/storage/databases/main/events.py | 11 +-- synapse/storage/databases/main/relations.py | 60 ++++++++------- tests/rest/client/test_relations.py | 45 +++++++++-- 8 files changed, 122 insertions(+), 83 deletions(-) create mode 100644 changelog.d/12130.bugfix create mode 100644 changelog.d/12189.bugfix delete mode 100644 changelog.d/12189.misc diff --git a/changelog.d/12130.bugfix b/changelog.d/12130.bugfix new file mode 100644 index 000000000000..df9b0dc413dd --- /dev/null +++ b/changelog.d/12130.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug when redacting events with relations. diff --git a/changelog.d/12189.bugfix b/changelog.d/12189.bugfix new file mode 100644 index 000000000000..df9b0dc413dd --- /dev/null +++ b/changelog.d/12189.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug when redacting events with relations. diff --git a/changelog.d/12189.misc b/changelog.d/12189.misc deleted file mode 100644 index 015e808e63c7..000000000000 --- a/changelog.d/12189.misc +++ /dev/null @@ -1 +0,0 @@ -Support skipping some arguments when generating cache keys. 
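The rule the storage changes below implement can be summarised with a small, self-contained sketch (the function name and sample event IDs are illustrative, not Synapse APIs): relations of a redacted event are still listed, but `m.replace` (edit) relations are dropped, since returning an edit would effectively reveal the redacted content.

```python
from typing import List, Tuple

REPLACE_REL_TYPE = "m.replace"


def visible_relations(
    rows: List[Tuple[str, str]], parent_is_redacted: bool
) -> List[str]:
    """Filter (child_event_id, relation_type) rows for a single parent event."""
    return [
        event_id
        for event_id, rel_type in rows
        if not (parent_is_redacted and rel_type == REPLACE_REL_TYPE)
    ]


# With a redacted parent, the reaction is still returned but the edit is not.
assert visible_relations(
    [("$reaction", "m.annotation"), ("$edit", "m.replace")],
    parent_is_redacted=True,
) == ["$reaction"]
```
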
diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index 07fa1cdd4c67..d9a6be43f793 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -27,7 +27,7 @@ from synapse.http.servlet import RestServlet, parse_integer, parse_string from synapse.http.site import SynapseRequest from synapse.rest.client._base import client_patterns -from synapse.storage.relations import AggregationPaginationToken, PaginationChunk +from synapse.storage.relations import AggregationPaginationToken from synapse.types import JsonDict, StreamToken if TYPE_CHECKING: @@ -82,28 +82,25 @@ async def on_GET( from_token_str = parse_string(request, "from") to_token_str = parse_string(request, "to") - if event.internal_metadata.is_redacted(): - # If the event is redacted, return an empty list of relations - pagination_chunk = PaginationChunk(chunk=[]) - else: - # Return the relations - from_token = None - if from_token_str: - from_token = await StreamToken.from_string(self.store, from_token_str) - to_token = None - if to_token_str: - to_token = await StreamToken.from_string(self.store, to_token_str) - - pagination_chunk = await self.store.get_relations_for_event( - event_id=parent_id, - room_id=room_id, - relation_type=relation_type, - event_type=event_type, - limit=limit, - direction=direction, - from_token=from_token, - to_token=to_token, - ) + # Return the relations + from_token = None + if from_token_str: + from_token = await StreamToken.from_string(self.store, from_token_str) + to_token = None + if to_token_str: + to_token = await StreamToken.from_string(self.store, to_token_str) + + pagination_chunk = await self.store.get_relations_for_event( + event_id=parent_id, + event=event, + room_id=room_id, + relation_type=relation_type, + event_type=event_type, + limit=limit, + direction=direction, + from_token=from_token, + to_token=to_token, + ) events = await self.store.get_events_as_list( [c["event_id"] for c in pagination_chunk.chunk] @@ -193,27 +190,23 @@ async def on_GET( from_token_str = parse_string(request, "from") to_token_str = parse_string(request, "to") - if event.internal_metadata.is_redacted(): - # If the event is redacted, return an empty list of relations - pagination_chunk = PaginationChunk(chunk=[]) - else: - # Return the relations - from_token = None - if from_token_str: - from_token = AggregationPaginationToken.from_string(from_token_str) - - to_token = None - if to_token_str: - to_token = AggregationPaginationToken.from_string(to_token_str) - - pagination_chunk = await self.store.get_aggregation_groups_for_event( - event_id=parent_id, - room_id=room_id, - event_type=event_type, - limit=limit, - from_token=from_token, - to_token=to_token, - ) + # Return the relations + from_token = None + if from_token_str: + from_token = AggregationPaginationToken.from_string(from_token_str) + + to_token = None + if to_token_str: + to_token = AggregationPaginationToken.from_string(to_token_str) + + pagination_chunk = await self.store.get_aggregation_groups_for_event( + event_id=parent_id, + room_id=room_id, + event_type=event_type, + limit=limit, + from_token=from_token, + to_token=to_token, + ) return 200, await pagination_chunk.to_dict(self.store) @@ -295,6 +288,7 @@ async def on_GET( result = await self.store.get_relations_for_event( event_id=parent_id, + event=event, room_id=room_id, relation_type=relation_type, event_type=event_type, diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index 
abd54c7dc703..d6a2df1afeb6 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -191,6 +191,10 @@ def _invalidate_caches_for_event( if redacts: self._invalidate_get_event_cache(redacts) + # Caches which might leak edits must be invalidated for the event being + # redacted. + self.get_relations_for_event.invalidate((redacts,)) + self.get_applicable_edit.invalidate((redacts,)) if etype == EventTypes.Member: self._membership_stream_cache.entity_has_changed(state_key, stream_ordering) diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 1dc83aa5e3a6..1a322882bf3f 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1619,9 +1619,12 @@ def prefill(): txn.call_after(prefill) - def _store_redaction(self, txn, event): - # invalidate the cache for the redacted event + def _store_redaction(self, txn: LoggingTransaction, event: EventBase) -> None: + # Invalidate the caches for the redacted event, note that these caches + # are also cleared as part of event replication in _invalidate_caches_for_event. txn.call_after(self.store._invalidate_get_event_cache, event.redacts) + txn.call_after(self.store.get_relations_for_event.invalidate, (event.redacts,)) + txn.call_after(self.store.get_applicable_edit.invalidate, (event.redacts,)) self.db_pool.simple_upsert_txn( txn, @@ -1812,9 +1815,7 @@ def _handle_event_relations( txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,)) if rel_type == RelationTypes.THREAD: - txn.call_after( - self.store.get_thread_summary.invalidate, (parent_id, event.room_id) - ) + txn.call_after(self.store.get_thread_summary.invalidate, (parent_id,)) # It should be safe to only invalidate the cache if the user has not # previously participated in the thread, but that's difficult (and # potentially error-prone) so it is always invalidated. diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 36aa1092f602..be1500092b5b 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -91,10 +91,11 @@ def __init__( self._msc3440_enabled = hs.config.experimental.msc3440_enabled - @cached(tree=True) + @cached(uncached_args=("event",), tree=True) async def get_relations_for_event( self, event_id: str, + event: EventBase, room_id: str, relation_type: Optional[str] = None, event_type: Optional[str] = None, @@ -108,6 +109,7 @@ async def get_relations_for_event( Args: event_id: Fetch events that relate to this event ID. + event: The matching EventBase to event_id. room_id: The room the event belongs to. relation_type: Only fetch events with this relation type, if given. event_type: Only fetch events with this event type, if given. @@ -122,9 +124,13 @@ async def get_relations_for_event( List of event IDs that match relations requested. The rows are of the form `{"event_id": "..."}`. """ + # We don't use `event_id`, it's there so that we can cache based on + # it. The `event_id` must match the `event.event_id`. 
+ assert event.event_id == event_id where_clause = ["relates_to_id = ?", "room_id = ?"] - where_args: List[Union[str, int]] = [event_id, room_id] + where_args: List[Union[str, int]] = [event.event_id, room_id] + is_redacted = event.internal_metadata.is_redacted() if relation_type is not None: where_clause.append("relation_type = ?") @@ -157,7 +163,7 @@ async def get_relations_for_event( order = "ASC" sql = """ - SELECT event_id, topological_ordering, stream_ordering + SELECT event_id, relation_type, topological_ordering, stream_ordering FROM event_relations INNER JOIN events USING (event_id) WHERE %s @@ -178,9 +184,12 @@ def _get_recent_references_for_event_txn( last_stream_id = None events = [] for row in txn: - events.append({"event_id": row[0]}) - last_topo_id = row[1] - last_stream_id = row[2] + # Do not include edits for redacted events as they leak event + # content. + if not is_redacted or row[1] != RelationTypes.REPLACE: + events.append({"event_id": row[0]}) + last_topo_id = row[2] + last_stream_id = row[3] # If there are more events, generate the next pagination key. next_token = None @@ -776,7 +785,7 @@ async def _get_bundled_aggregation_for_event( ) references = await self.get_relations_for_event( - event_id, room_id, RelationTypes.REFERENCE, direction="f" + event_id, event, room_id, RelationTypes.REFERENCE, direction="f" ) if references.chunk: aggregations.references = await references.to_dict(cast("DataStore", self)) @@ -797,41 +806,36 @@ async def get_bundled_aggregations( A map of event ID to the bundled aggregation for the event. Not all events may have bundled aggregations in the results. """ - # The already processed event IDs. Tracked separately from the result - # since the result omits events which do not have bundled aggregations. - seen_event_ids = set() - - # State events and redacted events do not get bundled aggregations. - events = [ - event - for event in events - if not event.is_state() and not event.internal_metadata.is_redacted() - ] + # De-duplicate events by ID to handle the same event requested multiple times. + # + # State events do not get bundled aggregations. + events_by_id = { + event.event_id: event for event in events if not event.is_state() + } # event ID -> bundled aggregation in non-serialized form. results: Dict[str, BundledAggregations] = {} # Fetch other relations per event. - for event in events: - # De-duplicate events by ID to handle the same event requested multiple - # times. The caches that _get_bundled_aggregation_for_event use should - # capture this, but best to reduce work. - if event.event_id in seen_event_ids: - continue - seen_event_ids.add(event.event_id) - + for event in events_by_id.values(): event_result = await self._get_bundled_aggregation_for_event(event, user_id) if event_result: results[event.event_id] = event_result - # Fetch any edits. - edits = await self._get_applicable_edits(seen_event_ids) + # Fetch any edits (but not for redacted events). + edits = await self._get_applicable_edits( + [ + event_id + for event_id, event in events_by_id.items() + if not event.internal_metadata.is_redacted() + ] + ) for event_id, edit in edits.items(): results.setdefault(event_id, BundledAggregations()).replace = edit # Fetch thread summaries. if self._msc3440_enabled: - summaries = await self._get_thread_summaries(seen_event_ids) + summaries = await self._get_thread_summaries(events_by_id.keys()) # Only fetch participated for a limited selection based on what had # summaries. 
participated = await self._get_threads_participated( diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index a40a5de3991c..f9ae6e663f95 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -1475,12 +1475,13 @@ def test_redact_parent_edit(self) -> None: self.assertEqual(relations, {}) def test_redact_parent_annotation(self) -> None: - """Test that annotations of an event are redacted when the original event + """Test that annotations of an event are viewable when the original event is redacted. """ # Add a relation channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍") self.assertEqual(200, channel.code, channel.json_body) + related_event_id = channel.json_body["event_id"] # The relations should exist. event_ids, relations = self._make_relation_requests() @@ -1494,11 +1495,45 @@ def test_redact_parent_annotation(self) -> None: # Redact the original event. self._redact(self.parent_id) - # The relations are not returned. + # The relations are returned. event_ids, relations = self._make_relation_requests() - self.assertEqual(event_ids, []) - self.assertEqual(relations, {}) + self.assertEquals(event_ids, [related_event_id]) + self.assertEquals( + relations["m.annotation"], + {"chunk": [{"type": "m.reaction", "key": "👍", "count": 1}]}, + ) # There's nothing to aggregate. chunk = self._get_aggregations() - self.assertEqual(chunk, []) + self.assertEqual(chunk, [{"count": 1, "key": "👍", "type": "m.reaction"}]) + + @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) + def test_redact_parent_thread(self) -> None: + """ + Test that thread replies are still available when the root event is redacted. + """ + channel = self._send_relation( + RelationTypes.THREAD, + EventTypes.Message, + content={"body": "reply 1", "msgtype": "m.text"}, + ) + self.assertEqual(200, channel.code, channel.json_body) + related_event_id = channel.json_body["event_id"] + + # Redact one of the reactions. + self._redact(self.parent_id) + + # The unredacted relation should still exist. + event_ids, relations = self._make_relation_requests() + self.assertEquals(len(event_ids), 1) + self.assertDictContainsSubset( + { + "count": 1, + "current_user_participated": True, + }, + relations[RelationTypes.THREAD], + ) + self.assertEqual( + relations[RelationTypes.THREAD]["latest_event"]["event_id"], + related_event_id, + ) From 52a947dc4603e1bd14916efb8822c4fe58f0d200 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Thu, 10 Mar 2022 15:18:31 +0000 Subject: [PATCH 043/230] Updates to the Room DAG concepts development document (#12179) Some stuff that came up while we were talking about #12173. --- changelog.d/12179.doc | 1 + docs/development/room-dag-concepts.md | 71 ++++++++++++++++++++------- 2 files changed, 54 insertions(+), 18 deletions(-) create mode 100644 changelog.d/12179.doc diff --git a/changelog.d/12179.doc b/changelog.d/12179.doc new file mode 100644 index 000000000000..55d8caa45a8c --- /dev/null +++ b/changelog.d/12179.doc @@ -0,0 +1 @@ +Updates to the Room DAG concepts development document. diff --git a/docs/development/room-dag-concepts.md b/docs/development/room-dag-concepts.md index cbc7cf29491c..3eb4d5acc463 100644 --- a/docs/development/room-dag-concepts.md +++ b/docs/development/room-dag-concepts.md @@ -30,37 +30,72 @@ rather than skipping any that arrived late; whereas if you're looking at a historical section of timeline (i.e. 
`/messages`), you want to see the best representation of the state of the room as others were seeing it at the time. +## Outliers -## Forward extremity +We mark an event as an `outlier` when we haven't figured out the state for the +room at that point in the DAG yet. They are "floating" events that we haven't +yet correlated to the DAG. -Most-recent-in-time events in the DAG which are not referenced by any other events' `prev_events` yet. +Outliers typically arise when we fetch the auth chain or state for a given +event. When that happens, we just grab the events in the state/auth chain, +without calculating the state at those events, or backfilling their +`prev_events`. -The forward extremities of a room are used as the `prev_events` when the next event is sent. +So, typically, we won't have the `prev_events` of an `outlier` in the database, +(though it's entirely possible that we *might* have them for some other +reason). Other things that make outliers different from regular events: + * We don't have state for them, so there should be no entry in + `event_to_state_groups` for an outlier. (In practice this isn't always + the case, though I'm not sure why: see https://github.com/matrix-org/synapse/issues/12201). -## Backward extremity + * We don't record entries for them in the `event_edges`, + `event_forward_extremeties` or `event_backward_extremities` tables. -The current marker of where we have backfilled up to and will generally be the -`prev_events` of the oldest-in-time events we have in the DAG. This gives a starting point when -backfilling history. +Since outliers are not tied into the DAG, they do not normally form part of the +timeline sent down to clients via `/sync` or `/messages`; however there is an +exception: -When we persist a non-outlier event, we clear it as a backward extremity and set -all of its `prev_events` as the new backward extremities if they aren't already -persisted in the `events` table. +### Out-of-band membership events +A special case of outlier events are some membership events for federated rooms +that we aren't full members of. For example: -## Outliers + * invites received over federation, before we join the room + * *rejections* for said invites + * knock events for rooms that we would like to join but have not yet joined. -We mark an event as an `outlier` when we haven't figured out the state for the -room at that point in the DAG yet. +In all the above cases, we don't have the state for the room, which is why they +are treated as outliers. They are a bit special though, in that they are +proactively sent to clients via `/sync`. -We won't *necessarily* have the `prev_events` of an `outlier` in the database, -but it's entirely possible that we *might*. +## Forward extremity + +Most-recent-in-time events in the DAG which are not referenced by any other +events' `prev_events` yet. (In this definition, outliers, rejected events, and +soft-failed events don't count.) + +The forward extremities of a room (or at least, a subset of them, if there are +more than ten) are used as the `prev_events` when the next event is sent. + +The "current state" of a room (ie: the state which would be used if we +generated a new event) is, therefore, the resolution of the room states +at each of the forward extremities. + +## Backward extremity + +The current marker of where we have backfilled up to and will generally be the +`prev_events` of the oldest-in-time events we have in the DAG. This gives a starting point when +backfilling history. 
-For example, when we fetch the event auth chain or state for a given event, we -mark all of those claimed auth events as outliers because we haven't done the -state calculation ourself. +Note that, unlike forward extremities, we typically don't have any backward +extremity events themselves in the database - or, if we do, they will be "outliers" (see +above). Either way, we don't expect to have the room state at a backward extremity. +When we persist a non-outlier event, if it was previously a backward extremity, +we clear it as a backward extremity and set all of its `prev_events` as the new +backward extremities if they aren't already persisted as non-outliers. This +therefore keeps the backward extremities up-to-date. ## State groups From ea27528b5d177dcfc5a4e38b463baeace916dc8e Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 10 Mar 2022 10:36:13 -0500 Subject: [PATCH 044/230] Support stable identifiers for MSC3440: Threading (#12151) The unstable identifiers are still supported if the experimental configuration flag is enabled. The unstable identifiers will be removed in a future release. --- changelog.d/12151.feature | 1 + synapse/api/constants.py | 4 +- synapse/api/filtering.py | 23 +++--- synapse/events/utils.py | 9 ++- synapse/handlers/message.py | 5 +- synapse/rest/client/versions.py | 1 + synapse/server.py | 2 +- synapse/storage/databases/main/events.py | 5 +- synapse/storage/databases/main/relations.py | 77 +++++++++++++-------- synapse/storage/databases/main/stream.py | 18 ++--- tests/rest/client/test_relations.py | 7 +- tests/rest/client/test_rooms.py | 18 +++-- tests/storage/test_stream.py | 20 +++--- 13 files changed, 109 insertions(+), 81 deletions(-) create mode 100644 changelog.d/12151.feature diff --git a/changelog.d/12151.feature b/changelog.d/12151.feature new file mode 100644 index 000000000000..18432b2da9a5 --- /dev/null +++ b/changelog.d/12151.feature @@ -0,0 +1 @@ +Support the stable identifiers from [MSC3440](https://github.com/matrix-org/matrix-doc/pull/3440): threads. diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 36ace7c6134f..b0c08a074d83 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -178,7 +178,9 @@ class RelationTypes: ANNOTATION: Final = "m.annotation" REPLACE: Final = "m.replace" REFERENCE: Final = "m.reference" - THREAD: Final = "io.element.thread" + THREAD: Final = "m.thread" + # TODO Remove this in Synapse >= v1.57.0. + UNSTABLE_THREAD: Final = "io.element.thread" class LimitBlockingTypes: diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index cb532d723828..27e97d6f372d 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -88,7 +88,9 @@ "org.matrix.labels": {"type": "array", "items": {"type": "string"}}, "org.matrix.not_labels": {"type": "array", "items": {"type": "string"}}, # MSC3440, filtering by event relations. 
+ "related_by_senders": {"type": "array", "items": {"type": "string"}}, "io.element.relation_senders": {"type": "array", "items": {"type": "string"}}, + "related_by_rel_types": {"type": "array", "items": {"type": "string"}}, "io.element.relation_types": {"type": "array", "items": {"type": "string"}}, }, } @@ -318,19 +320,18 @@ def __init__(self, hs: "HomeServer", filter_json: JsonDict): self.labels = filter_json.get("org.matrix.labels", None) self.not_labels = filter_json.get("org.matrix.not_labels", []) - # Ideally these would be rejected at the endpoint if they were provided - # and not supported, but that would involve modifying the JSON schema - # based on the homeserver configuration. + self.related_by_senders = self.filter_json.get("related_by_senders", None) + self.related_by_rel_types = self.filter_json.get("related_by_rel_types", None) + + # Fallback to the unstable prefix if the stable version is not given. if hs.config.experimental.msc3440_enabled: - self.relation_senders = self.filter_json.get( + self.related_by_senders = self.related_by_senders or self.filter_json.get( "io.element.relation_senders", None ) - self.relation_types = self.filter_json.get( - "io.element.relation_types", None + self.related_by_rel_types = ( + self.related_by_rel_types + or self.filter_json.get("io.element.relation_types", None) ) - else: - self.relation_senders = None - self.relation_types = None def filters_all_types(self) -> bool: return "*" in self.not_types @@ -461,7 +462,7 @@ async def _check_event_relations( event_ids = [event.event_id for event in events if isinstance(event, EventBase)] # type: ignore[attr-defined] event_ids_to_keep = set( await self._store.events_have_relations( - event_ids, self.relation_senders, self.relation_types + event_ids, self.related_by_senders, self.related_by_rel_types ) ) @@ -474,7 +475,7 @@ async def _check_event_relations( async def filter(self, events: Iterable[FilterEvent]) -> List[FilterEvent]: result = [event for event in events if self._check(event)] - if self.relation_senders or self.relation_types: + if self.related_by_senders or self.related_by_rel_types: return await self._check_event_relations(result) return result diff --git a/synapse/events/utils.py b/synapse/events/utils.py index ee34cb46e437..b2a237c1e04a 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -38,6 +38,7 @@ from . import EventBase if TYPE_CHECKING: + from synapse.server import HomeServer from synapse.storage.databases.main.relations import BundledAggregations @@ -395,6 +396,9 @@ class EventClientSerializer: clients. """ + def __init__(self, hs: "HomeServer"): + self._msc3440_enabled = hs.config.experimental.msc3440_enabled + def serialize_event( self, event: Union[JsonDict, EventBase], @@ -515,11 +519,14 @@ def _inject_bundled_aggregations( thread.latest_event, serialized_latest_event, thread.latest_edit ) - serialized_aggregations[RelationTypes.THREAD] = { + thread_summary = { "latest_event": serialized_latest_event, "count": thread.count, "current_user_participated": thread.current_user_participated, } + serialized_aggregations[RelationTypes.THREAD] = thread_summary + if self._msc3440_enabled: + serialized_aggregations[RelationTypes.UNSTABLE_THREAD] = thread_summary # Include the bundled aggregations in the event. 
if serialized_aggregations: diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 0799ec9a84df..f9544fe7fb83 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -1079,7 +1079,10 @@ async def _validate_event_relation(self, event: EventBase) -> None: raise SynapseError(400, "Can't send same reaction twice") # Don't attempt to start a thread if the parent event is a relation. - elif relation_type == RelationTypes.THREAD: + elif ( + relation_type == RelationTypes.THREAD + or relation_type == RelationTypes.UNSTABLE_THREAD + ): if await self.store.event_includes_relation(relates_to): raise SynapseError( 400, "Cannot start threads from an event with a relation" diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index 2e5d0e4e2258..9a65aa484360 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -101,6 +101,7 @@ def on_GET(self, request: Request) -> Tuple[int, JsonDict]: "org.matrix.msc3030": self.config.experimental.msc3030_enabled, # Adds support for thread relations, per MSC3440. "org.matrix.msc3440": self.config.experimental.msc3440_enabled, + "org.matrix.msc3440.stable": True, # TODO: remove when "v1.3" is added above }, }, ) diff --git a/synapse/server.py b/synapse/server.py index 1270abb5a335..7741ff29dc3f 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -754,7 +754,7 @@ def get_oidc_handler(self) -> "OidcHandler": @cache_in_self def get_event_client_serializer(self) -> EventClientSerializer: - return EventClientSerializer() + return EventClientSerializer(self) @cache_in_self def get_password_policy_handler(self) -> PasswordPolicyHandler: diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 1a322882bf3f..1f60aef180d0 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1814,7 +1814,10 @@ def _handle_event_relations( if rel_type == RelationTypes.REPLACE: txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,)) - if rel_type == RelationTypes.THREAD: + if ( + rel_type == RelationTypes.THREAD + or rel_type == RelationTypes.UNSTABLE_THREAD + ): txn.call_after(self.store.get_thread_summary.invalidate, (parent_id,)) # It should be safe to only invalidate the cache if the user has not # previously participated in the thread, but that's difficult (and diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index be1500092b5b..c4869d64e663 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -508,7 +508,7 @@ def _get_thread_summaries_txn( AND parent.room_id = child.room_id WHERE %s - AND relation_type = ? + AND %s ORDER BY parent.event_id, child.topological_ordering DESC, child.stream_ordering DESC """ else: @@ -523,16 +523,22 @@ def _get_thread_summaries_txn( AND parent.room_id = child.room_id WHERE %s - AND relation_type = ? + AND %s ORDER BY child.topological_ordering DESC, child.stream_ordering DESC """ clause, args = make_in_list_sql_clause( txn.database_engine, "relates_to_id", event_ids ) - args.append(RelationTypes.THREAD) - txn.execute(sql % (clause,), args) + if self._msc3440_enabled: + relations_clause = "(relation_type = ? OR relation_type = ?)" + args.extend((RelationTypes.THREAD, RelationTypes.UNSTABLE_THREAD)) + else: + relations_clause = "relation_type = ?" 
+ args.append(RelationTypes.THREAD) + + txn.execute(sql % (clause, relations_clause), args) latest_event_ids = {} for parent_event_id, child_event_id in txn: # Only consider the latest threaded reply (by topological ordering). @@ -552,7 +558,7 @@ def _get_thread_summaries_txn( AND parent.room_id = child.room_id WHERE %s - AND relation_type = ? + AND %s GROUP BY parent.event_id """ @@ -561,9 +567,15 @@ def _get_thread_summaries_txn( clause, args = make_in_list_sql_clause( txn.database_engine, "relates_to_id", latest_event_ids.keys() ) - args.append(RelationTypes.THREAD) - txn.execute(sql % (clause,), args) + if self._msc3440_enabled: + relations_clause = "(relation_type = ? OR relation_type = ?)" + args.extend((RelationTypes.THREAD, RelationTypes.UNSTABLE_THREAD)) + else: + relations_clause = "relation_type = ?" + args.append(RelationTypes.THREAD) + + txn.execute(sql % (clause, relations_clause), args) counts = dict(cast(List[Tuple[str, int]], txn.fetchall())) return counts, latest_event_ids @@ -626,16 +638,24 @@ def _get_thread_summary_txn(txn: LoggingTransaction) -> Set[str]: AND parent.room_id = child.room_id WHERE %s - AND relation_type = ? + AND %s AND child.sender = ? """ clause, args = make_in_list_sql_clause( txn.database_engine, "relates_to_id", event_ids ) - args.extend((RelationTypes.THREAD, user_id)) - txn.execute(sql % (clause,), args) + if self._msc3440_enabled: + relations_clause = "(relation_type = ? OR relation_type = ?)" + args.extend((RelationTypes.THREAD, RelationTypes.UNSTABLE_THREAD)) + else: + relations_clause = "relation_type = ?" + args.append(RelationTypes.THREAD) + + args.append(user_id) + + txn.execute(sql % (clause, relations_clause), args) return {row[0] for row in txn.fetchall()} participated_threads = await self.db_pool.runInteraction( @@ -834,26 +854,23 @@ async def get_bundled_aggregations( results.setdefault(event_id, BundledAggregations()).replace = edit # Fetch thread summaries. - if self._msc3440_enabled: - summaries = await self._get_thread_summaries(events_by_id.keys()) - # Only fetch participated for a limited selection based on what had - # summaries. - participated = await self._get_threads_participated( - summaries.keys(), user_id - ) - for event_id, summary in summaries.items(): - if summary: - thread_count, latest_thread_event, edit = summary - results.setdefault( - event_id, BundledAggregations() - ).thread = _ThreadAggregation( - latest_event=latest_thread_event, - latest_edit=edit, - count=thread_count, - # If there's a thread summary it must also exist in the - # participated dictionary. - current_user_participated=participated[event_id], - ) + summaries = await self._get_thread_summaries(events_by_id.keys()) + # Only fetch participated for a limited selection based on what had + # summaries. + participated = await self._get_threads_participated(summaries.keys(), user_id) + for event_id, summary in summaries.items(): + if summary: + thread_count, latest_thread_event, edit = summary + results.setdefault( + event_id, BundledAggregations() + ).thread = _ThreadAggregation( + latest_event=latest_thread_event, + latest_edit=edit, + count=thread_count, + # If there's a thread summary it must also exist in the + # participated dictionary. 
+ current_user_participated=participated[event_id], + ) return results diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index a898f847e7d5..39e1efe37348 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -325,21 +325,23 @@ def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]: args.extend(event_filter.labels) # Filter on relation_senders / relation types from the joined tables. - if event_filter.relation_senders: + if event_filter.related_by_senders: clauses.append( "(%s)" % " OR ".join( - "related_event.sender = ?" for _ in event_filter.relation_senders + "related_event.sender = ?" for _ in event_filter.related_by_senders ) ) - args.extend(event_filter.relation_senders) + args.extend(event_filter.related_by_senders) - if event_filter.relation_types: + if event_filter.related_by_rel_types: clauses.append( "(%s)" - % " OR ".join("relation_type = ?" for _ in event_filter.relation_types) + % " OR ".join( + "relation_type = ?" for _ in event_filter.related_by_rel_types + ) ) - args.extend(event_filter.relation_types) + args.extend(event_filter.related_by_rel_types) return " AND ".join(clauses), args @@ -1203,7 +1205,7 @@ def _paginate_room_events_txn( # If there is a filter on relation_senders and relation_types join to the # relations table. if event_filter and ( - event_filter.relation_senders or event_filter.relation_types + event_filter.related_by_senders or event_filter.related_by_rel_types ): # Filtering by relations could cause the same event to appear multiple # times (since there's no limit on the number of relations to an event). @@ -1211,7 +1213,7 @@ def _paginate_room_events_txn( join_clause += """ LEFT JOIN event_relations AS relation ON (event.event_id = relation.relates_to_id) """ - if event_filter.relation_senders: + if event_filter.related_by_senders: join_clause += """ LEFT JOIN events AS related_event ON (relation.event_id = related_event.event_id) """ diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index f9ae6e663f95..0cbe6c0cf754 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -547,9 +547,7 @@ def test_aggregation_must_be_annotation(self) -> None: ) self.assertEqual(400, channel.code, channel.json_body) - @unittest.override_config( - {"experimental_features": {"msc3440_enabled": True, "msc3666_enabled": True}} - ) + @unittest.override_config({"experimental_features": {"msc3666_enabled": True}}) def test_bundled_aggregations(self) -> None: """ Test that annotations, references, and threads get correctly bundled. @@ -758,7 +756,6 @@ def test_aggregation_get_event_for_thread(self) -> None: }, ) - @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) def test_ignore_invalid_room(self) -> None: """Test that we ignore invalid relations over federation.""" # Create another room and send a message in it. 
@@ -1065,7 +1062,6 @@ def test_edit_reply(self) -> None: {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict ) - @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) def test_edit_thread(self) -> None: """Test that editing a thread works.""" @@ -1383,7 +1379,6 @@ def test_redact_relation_annotation(self) -> None: chunk = self._get_aggregations() self.assertEqual(chunk, [{"type": "m.reaction", "key": "a", "count": 1}]) - @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) def test_redact_relation_thread(self) -> None: """ Test that thread replies are properly handled after the thread reply redacted. diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 37866ee330f3..3a9617d6da8e 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -2141,21 +2141,19 @@ def _filter_messages(self, filter: JsonDict) -> List[JsonDict]: def test_filter_relation_senders(self) -> None: # Messages which second user reacted to. - filter = {"io.element.relation_senders": [self.second_user_id]} + filter = {"related_by_senders": [self.second_user_id]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) self.assertEqual(chunk[0]["event_id"], self.event_id_1) # Messages which third user reacted to. - filter = {"io.element.relation_senders": [self.third_user_id]} + filter = {"related_by_senders": [self.third_user_id]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) self.assertEqual(chunk[0]["event_id"], self.event_id_2) # Messages which either user reacted to. - filter = { - "io.element.relation_senders": [self.second_user_id, self.third_user_id] - } + filter = {"related_by_senders": [self.second_user_id, self.third_user_id]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 2, chunk) self.assertCountEqual( @@ -2164,20 +2162,20 @@ def test_filter_relation_senders(self) -> None: def test_filter_relation_type(self) -> None: # Messages which have annotations. - filter = {"io.element.relation_types": [RelationTypes.ANNOTATION]} + filter = {"related_by_rel_types": [RelationTypes.ANNOTATION]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) self.assertEqual(chunk[0]["event_id"], self.event_id_1) # Messages which have references. - filter = {"io.element.relation_types": [RelationTypes.REFERENCE]} + filter = {"related_by_rel_types": [RelationTypes.REFERENCE]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) self.assertEqual(chunk[0]["event_id"], self.event_id_2) # Messages which have either annotations or references. filter = { - "io.element.relation_types": [ + "related_by_rel_types": [ RelationTypes.ANNOTATION, RelationTypes.REFERENCE, ] @@ -2191,8 +2189,8 @@ def test_filter_relation_type(self) -> None: def test_filter_relation_senders_and_type(self) -> None: # Messages which second user reacted to. 
filter = { - "io.element.relation_senders": [self.second_user_id], - "io.element.relation_types": [RelationTypes.ANNOTATION], + "related_by_senders": [self.second_user_id], + "related_by_rel_types": [RelationTypes.ANNOTATION], } chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py index 6a1cf3305455..eaa0d7d749b7 100644 --- a/tests/storage/test_stream.py +++ b/tests/storage/test_stream.py @@ -129,21 +129,19 @@ def _filter_messages(self, filter: JsonDict) -> List[EventBase]: def test_filter_relation_senders(self): # Messages which second user reacted to. - filter = {"io.element.relation_senders": [self.second_user_id]} + filter = {"related_by_senders": [self.second_user_id]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) self.assertEqual(chunk[0].event_id, self.event_id_1) # Messages which third user reacted to. - filter = {"io.element.relation_senders": [self.third_user_id]} + filter = {"related_by_senders": [self.third_user_id]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) self.assertEqual(chunk[0].event_id, self.event_id_2) # Messages which either user reacted to. - filter = { - "io.element.relation_senders": [self.second_user_id, self.third_user_id] - } + filter = {"related_by_senders": [self.second_user_id, self.third_user_id]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 2, chunk) self.assertCountEqual( @@ -152,20 +150,20 @@ def test_filter_relation_senders(self): def test_filter_relation_type(self): # Messages which have annotations. - filter = {"io.element.relation_types": [RelationTypes.ANNOTATION]} + filter = {"related_by_rel_types": [RelationTypes.ANNOTATION]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) self.assertEqual(chunk[0].event_id, self.event_id_1) # Messages which have references. - filter = {"io.element.relation_types": [RelationTypes.REFERENCE]} + filter = {"related_by_rel_types": [RelationTypes.REFERENCE]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) self.assertEqual(chunk[0].event_id, self.event_id_2) # Messages which have either annotations or references. filter = { - "io.element.relation_types": [ + "related_by_rel_types": [ RelationTypes.ANNOTATION, RelationTypes.REFERENCE, ] @@ -179,8 +177,8 @@ def test_filter_relation_type(self): def test_filter_relation_senders_and_type(self): # Messages which second user reacted to. filter = { - "io.element.relation_senders": [self.second_user_id], - "io.element.relation_types": [RelationTypes.ANNOTATION], + "related_by_senders": [self.second_user_id], + "related_by_rel_types": [RelationTypes.ANNOTATION], } chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) @@ -201,7 +199,7 @@ def test_duplicate_relation(self): tok=self.second_tok, ) - filter = {"io.element.relation_senders": [self.second_user_id]} + filter = {"related_by_senders": [self.second_user_id]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) self.assertEqual(chunk[0].event_id, self.event_id_1) From 72e7f1c420b879a0a1ef1430771698b868693ab0 Mon Sep 17 00:00:00 2001 From: reivilibre Date: Thu, 10 Mar 2022 15:53:23 +0000 Subject: [PATCH 045/230] Remove workaround introduced in Synapse v1.50.0rc1 for Mjolnir compatibility. Breaks compatibility with Mjolnir v1.3.1 and earlier. 
(#11700) --- changelog.d/11700.removal | 1 + docs/upgrade.md | 8 ++++++++ synapse/util/__init__.py | 7 ------- 3 files changed, 9 insertions(+), 7 deletions(-) create mode 100644 changelog.d/11700.removal diff --git a/changelog.d/11700.removal b/changelog.d/11700.removal new file mode 100644 index 000000000000..d3d3c48f0fc4 --- /dev/null +++ b/changelog.d/11700.removal @@ -0,0 +1 @@ +Remove workaround introduced in Synapse 1.50.0 for Mjolnir compatibility. Breaks compatibility with Mjolnir 1.3.1 and earlier. diff --git a/docs/upgrade.md b/docs/upgrade.md index 0d0bb066ee63..95005962dc49 100644 --- a/docs/upgrade.md +++ b/docs/upgrade.md @@ -106,6 +106,14 @@ You will need to ensure `synctl` is on your `PATH`. automatically, though you might need to activate a virtual environment depending on how you installed Synapse. + +## Compatibility dropped for Mjolnir 1.3.1 and earlier + +Synapse v1.55.0 drops support for Mjolnir 1.3.1 and earlier. +If you use the Mjolnir module to moderate your homeserver, +please upgrade Mjolnir to version 1.3.2 or later before upgrading Synapse. + + # Upgrading to v1.54.0 ## Legacy structured logging configuration removal diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index 58b4220ff355..d8046b7553cb 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -31,13 +31,6 @@ if typing.TYPE_CHECKING: pass -# FIXME Mjolnir imports glob_to_regex from this file, but it was moved to -# matrix_common. -# As a temporary workaround, we import glob_to_regex here for -# compatibility with current versions of Mjolnir. -# See https://github.com/matrix-org/mjolnir/pull/174 -from matrix_common.regex import glob_to_regex # noqa - logger = logging.getLogger(__name__) From ed9aea42fa991428406be96a67c311a8f9cec544 Mon Sep 17 00:00:00 2001 From: Shay Date: Thu, 10 Mar 2022 09:40:07 -0800 Subject: [PATCH 046/230] fix misleading comment in `check_events_for_spam` (#12203) --- changelog.d/12203.misc | 1 + synapse/events/spamcheck.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) create mode 100644 changelog.d/12203.misc diff --git a/changelog.d/12203.misc b/changelog.d/12203.misc new file mode 100644 index 000000000000..892dc5bfb7e3 --- /dev/null +++ b/changelog.d/12203.misc @@ -0,0 +1 @@ +Fix a misleading comment in the function `check_event_for_spam`. diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py index 04afd48274e1..60904a55f5c0 100644 --- a/synapse/events/spamcheck.py +++ b/synapse/events/spamcheck.py @@ -245,8 +245,8 @@ async def check_event_for_spam( """Checks if a given event is considered "spammy" by this server. If the server considers an event spammy, then it will be rejected if - sent by a local user. If it is sent by a user on another server, then - users receive a blank event. + sent by a local user. If it is sent by a user on another server, the + event is soft-failed. Args: event: the event to be checked From 7577894bec78a063f3e85ec7d386a58a0c60fb11 Mon Sep 17 00:00:00 2001 From: ~creme Date: Thu, 10 Mar 2022 19:15:19 +0100 Subject: [PATCH 047/230] Document that most streams can only have a single writer. (#12196) This includes the `typing`, `to_device`, `account_data`, `receipts`, and `presence` streams (really anything except the `events` stream). 
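As an illustration, a hypothetical worker layout consistent with this rule (worker names and ports are invented; the documentation change below gives the authoritative configuration syntax) might look like:

```yaml
instance_map:
  misc_writer:
    host: localhost
    port: 8034
  event_persister1:
    host: localhost
    port: 8035

stream_writers:
  # Each of these streams must point at exactly one worker (they may share one).
  typing: misc_writer
  to_device: misc_writer
  receipts: misc_writer
  # The `events` stream is the exception: it may be sharded across several writers.
  events:
    - event_persister1
```
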
--- changelog.d/12196.doc | 1 + docs/workers.md | 31 +++++++++++++++++-------------- 2 files changed, 18 insertions(+), 14 deletions(-) create mode 100644 changelog.d/12196.doc diff --git a/changelog.d/12196.doc b/changelog.d/12196.doc new file mode 100644 index 000000000000..269f06aa3386 --- /dev/null +++ b/changelog.d/12196.doc @@ -0,0 +1 @@ +Document that the `typing`, `to_device`, `account_data`, `receipts`, and `presence` stream writer can only be used on a single worker. \ No newline at end of file diff --git a/docs/workers.md b/docs/workers.md index b0f8599ef062..8751134e654d 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -351,8 +351,11 @@ is only supported with Redis-based replication.) To enable this, the worker must have a HTTP replication listener configured, have a `worker_name` and be listed in the `instance_map` config. The same worker -can handle multiple streams. For example, to move event persistence off to a -dedicated worker, the shared configuration would include: +can handle multiple streams, but unless otherwise documented, each stream can only +have a single writer. + +For example, to move event persistence off to a dedicated worker, the shared +configuration would include: ```yaml instance_map: @@ -370,8 +373,8 @@ streams and the endpoints associated with them: ##### The `events` stream -The `events` stream also experimentally supports having multiple writers, where -work is sharded between them by room ID. Note that you *must* restart all worker +The `events` stream experimentally supports having multiple writers, where work +is sharded between them by room ID. Note that you *must* restart all worker instances when adding or removing event persisters. An example `stream_writers` configuration with multiple writers: @@ -384,38 +387,38 @@ stream_writers: ##### The `typing` stream -The following endpoints should be routed directly to the workers configured as -stream writers for the `typing` stream: +The following endpoints should be routed directly to the worker configured as +the stream writer for the `typing` stream: ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/typing ##### The `to_device` stream -The following endpoints should be routed directly to the workers configured as -stream writers for the `to_device` stream: +The following endpoints should be routed directly to the worker configured as +the stream writer for the `to_device` stream: ^/_matrix/client/(api/v1|r0|v3|unstable)/sendToDevice/ ##### The `account_data` stream -The following endpoints should be routed directly to the workers configured as -stream writers for the `account_data` stream: +The following endpoints should be routed directly to the worker configured as +the stream writer for the `account_data` stream: ^/_matrix/client/(api/v1|r0|v3|unstable)/.*/tags ^/_matrix/client/(api/v1|r0|v3|unstable)/.*/account_data ##### The `receipts` stream -The following endpoints should be routed directly to the workers configured as -stream writers for the `receipts` stream: +The following endpoints should be routed directly to the worker configured as +the stream writer for the `receipts` stream: ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/receipt ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/read_markers ##### The `presence` stream -The following endpoints should be routed directly to the workers configured as -stream writers for the `presence` stream: +The following endpoints should be routed directly to the worker configured as +the stream writer for the `presence` stream: 
^/_matrix/client/(api/v1|r0|v3|unstable)/presence/ From 483f2aa2eca98500046847364ede04b034530aac Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 11 Mar 2022 10:33:49 +0000 Subject: [PATCH 048/230] Retention test: avoid relying on state at purged events (#12202) This test was relying on poking events which weren't in the database into filter_events_for_client. --- changelog.d/12202.misc | 1 + tests/rest/client/test_retention.py | 29 +++++++++++++++++------------ 2 files changed, 18 insertions(+), 12 deletions(-) create mode 100644 changelog.d/12202.misc diff --git a/changelog.d/12202.misc b/changelog.d/12202.misc new file mode 100644 index 000000000000..9f333e718a86 --- /dev/null +++ b/changelog.d/12202.misc @@ -0,0 +1 @@ +Avoid trying to calculate the state at outlier events. diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py index f3bf8d0934e6..7b8fe6d02522 100644 --- a/tests/rest/client/test_retention.py +++ b/tests/rest/client/test_retention.py @@ -24,6 +24,7 @@ from synapse.visibility import filter_events_for_client from tests import unittest +from tests.unittest import override_config one_hour_ms = 3600000 one_day_ms = one_hour_ms * 24 @@ -38,7 +39,10 @@ class RetentionTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() - config["retention"] = { + + # merge this default retention config with anything that was specified in + # @override_config + retention_config = { "enabled": True, "default_policy": { "min_lifetime": one_day_ms, @@ -47,6 +51,8 @@ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: "allowed_lifetime_min": one_day_ms, "allowed_lifetime_max": one_day_ms * 3, } + retention_config.update(config.get("retention", {})) + config["retention"] = retention_config self.hs = self.setup_test_homeserver(config=config) @@ -115,22 +121,20 @@ def test_retention_event_purged_without_state_event(self) -> None: self._test_retention_event_purged(room_id, one_day_ms * 2) + @override_config({"retention": {"purge_jobs": [{"interval": "5d"}]}}) def test_visibility(self) -> None: """Tests that synapse.visibility.filter_events_for_client correctly filters out - outdated events + outdated events, even if the purge job hasn't got to them yet. + + We do this by setting a very long time between purge jobs. """ store = self.hs.get_datastores().main storage = self.hs.get_storage() room_id = self.helper.create_room_as(self.user_id, tok=self.token) - events = [] # Send a first event, which should be filtered out at the end of the test. resp = self.helper.send(room_id=room_id, body="1", tok=self.token) - - # Get the event from the store so that we end up with a FrozenEvent that we can - # give to filter_events_for_client. We need to do this now because the event won't - # be in the database anymore after it has expired. - events.append(self.get_success(store.get_event(resp.get("event_id")))) + first_event_id = resp.get("event_id") # Advance the time by 2 days. We're using the default retention policy, therefore # after this the first event will still be valid. @@ -138,16 +142,17 @@ def test_visibility(self) -> None: # Send another event, which shouldn't get filtered out. resp = self.helper.send(room_id=room_id, body="2", tok=self.token) - valid_event_id = resp.get("event_id") - events.append(self.get_success(store.get_event(valid_event_id))) - # Advance the time by another 2 days. 
After this, the first event should be # outdated but not the second one. self.reactor.advance(one_day_ms * 2 / 1000) - # Run filter_events_for_client with our list of FrozenEvents. + # Fetch the events, and run filter_events_for_client on them + events = self.get_success( + store.get_events_as_list([first_event_id, valid_event_id]) + ) + self.assertEqual(2, len(events), "events retrieved from database") filtered_events = self.get_success( filter_events_for_client(storage, self.user_id, events) ) From 3b12f6d61b3b10b57e7b3a45f1d7a96f9790d674 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Fri, 11 Mar 2022 11:10:20 +0000 Subject: [PATCH 049/230] Note that contributors can sign off privately (#12204) Co-authored-by: Patrick Cloke --- changelog.d/12204.doc | 1 + docs/development/contributing_guide.md | 11 +++++++++++ 2 files changed, 12 insertions(+) create mode 100644 changelog.d/12204.doc diff --git a/changelog.d/12204.doc b/changelog.d/12204.doc new file mode 100644 index 000000000000..c4b2805bb112 --- /dev/null +++ b/changelog.d/12204.doc @@ -0,0 +1 @@ +Document that contributors can sign off privately by email. diff --git a/docs/development/contributing_guide.md b/docs/development/contributing_guide.md index 8448685952dc..071202e1965f 100644 --- a/docs/development/contributing_guide.md +++ b/docs/development/contributing_guide.md @@ -458,6 +458,17 @@ Git allows you to add this signoff automatically when using the `-s` flag to `git commit`, which uses the name and email set in your `user.name` and `user.email` git configs. +### Private Sign off + +If you would like to provide your legal name privately to the Matrix.org +Foundation (instead of in a public commit or comment), you can do so +by emailing your legal name and a link to the pull request to +[dco@matrix.org](mailto:dco@matrix.org?subject=Private%20sign%20off). +It helps to include "sign off" or similar in the subject line. You will then +be instructed further. + +Once private sign off is complete, doing so for future contributions will not +be required. # 10. Turn feedback into better code. From bc9dff1d9597251a15a15475cb8e8194b2d14910 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 11 Mar 2022 07:06:21 -0500 Subject: [PATCH 050/230] Remove unnecessary pass statements. (#12206) --- changelog.d/12206.misc | 1 + synapse/handlers/device.py | 2 -- synapse/handlers/presence.py | 2 -- synapse/http/matrixfederationclient.py | 2 -- synapse/http/server.py | 1 - synapse/rest/media/v1/_base.py | 1 - synapse/server.py | 1 - synapse/storage/databases/main/registration.py | 2 -- synapse/storage/schema/main/delta/30/as_users.py | 1 - synapse/util/caches/treecache.py | 2 -- tests/handlers/test_password_providers.py | 1 - 11 files changed, 1 insertion(+), 15 deletions(-) create mode 100644 changelog.d/12206.misc diff --git a/changelog.d/12206.misc b/changelog.d/12206.misc new file mode 100644 index 000000000000..df59bb56cdb8 --- /dev/null +++ b/changelog.d/12206.misc @@ -0,0 +1 @@ +Remove unnecessary `pass` statements. 
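For context, the `pass` statements removed below are redundant because each enclosing block already contains at least one statement (a logging call or a docstring). A minimal before/after sketch of the pattern, using made-up names rather than code from the diff:

```python
import logging

logger = logging.getLogger(__name__)


def log_missing_device_redundant(device_id: str) -> None:
    if device_id:
        logger.warning("User doesn't have device id %s", device_id)
        pass  # redundant: the block already contains a statement


def log_missing_device(device_id: str) -> None:
    # Equivalent behaviour without the superfluous `pass`.
    if device_id:
        logger.warning("User doesn't have device id %s", device_id)
```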
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index d90cb259a65c..d5ccaa0c37cc 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -371,7 +371,6 @@ async def delete_device(self, user_id: str, device_id: str) -> None: log_kv( {"reason": "User doesn't have device id.", "device_id": device_id} ) - pass else: raise @@ -414,7 +413,6 @@ async def delete_devices(self, user_id: str, device_ids: List[str]) -> None: # no match set_tag("error", True) set_tag("reason", "User doesn't have that device id.") - pass else: raise diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 9927a30e6ed5..34d9411bbf61 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -267,7 +267,6 @@ async def update_external_syncs_row( is_syncing: Whether or not the user is now syncing sync_time_msec: Time in ms when the user was last syncing """ - pass async def update_external_syncs_clear(self, process_id: str) -> None: """Marks all users that had been marked as syncing by a given process @@ -277,7 +276,6 @@ async def update_external_syncs_clear(self, process_id: str) -> None: This is a no-op when presence is handled by a different worker. """ - pass async def process_replication_rows( self, stream_name: str, instance_name: str, token: int, rows: list diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 40bf1e06d602..6b98d865f5bb 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -120,7 +120,6 @@ def finish(self) -> T: """Called when response has finished streaming and the parser should return the final result (or error). """ - pass @attr.s(slots=True, frozen=True, auto_attribs=True) @@ -601,7 +600,6 @@ async def _send_request( response.code, response_phrase, ) - pass else: logger.info( "{%s} [%s] Got response headers: %d %s", diff --git a/synapse/http/server.py b/synapse/http/server.py index 09b412548968..31ca84188975 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -233,7 +233,6 @@ def register_paths( servlet_classname (str): The name of the handler to be used in prometheus and opentracing logs. """ - pass class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta): diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py index 9b40fd8a6c23..c35d42fab89d 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py @@ -298,7 +298,6 @@ def write_to_consumer(self, consumer: IConsumer) -> Awaitable: Returns: Resolves once the response has finished being written """ - pass def __enter__(self) -> None: pass diff --git a/synapse/server.py b/synapse/server.py index 7741ff29dc3f..2fcf18a7a69a 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -328,7 +328,6 @@ def start_listening(self) -> None: Does nothing in this base class; overridden in derived classes to start the appropriate listeners. 
""" - pass def setup_background_tasks(self) -> None: """ diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index dc6665237abf..a698d10cc535 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -48,8 +48,6 @@ class ExternalIDReuseException(Exception): """Exception if writing an external id for a user fails, because this external id is given to an other user.""" - pass - @attr.s(frozen=True, slots=True, auto_attribs=True) class TokenLookupResult: diff --git a/synapse/storage/schema/main/delta/30/as_users.py b/synapse/storage/schema/main/delta/30/as_users.py index 22a7901e15df..4b4b166e37a6 100644 --- a/synapse/storage/schema/main/delta/30/as_users.py +++ b/synapse/storage/schema/main/delta/30/as_users.py @@ -36,7 +36,6 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs): config_files = config.appservice.app_service_config_files except AttributeError: logger.warning("Could not get app_service_config_files from config") - pass appservices = load_appservices(config.server.server_name, config_files) diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py index 563845f86769..e78305f78746 100644 --- a/synapse/util/caches/treecache.py +++ b/synapse/util/caches/treecache.py @@ -22,8 +22,6 @@ class TreeCacheNode(dict): leaves. """ - pass - class TreeCache: """ diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py index 49d832de814d..d401fda93855 100644 --- a/tests/handlers/test_password_providers.py +++ b/tests/handlers/test_password_providers.py @@ -124,7 +124,6 @@ def __init__(self, config, api: ModuleApi): ("m.login.password", ("password",)): self.check_auth, } ) - pass def check_auth(self, *args): return mock_password_provider.check_auth(*args) From e10a2fe0c28ec9206c0e2275df492f61ff5025f2 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 11 Mar 2022 07:07:15 -0500 Subject: [PATCH 051/230] Add some type hints to the tests.handlers module. (#12207) --- changelog.d/12108.misc | 2 +- changelog.d/12146.misc | 2 +- changelog.d/12207.misc | 1 + tests/handlers/test_admin.py | 18 +++--- tests/handlers/test_auth.py | 34 +++++----- tests/handlers/test_deactivate_account.py | 2 +- tests/handlers/test_device.py | 76 ++++++++++++----------- 7 files changed, 74 insertions(+), 61 deletions(-) create mode 100644 changelog.d/12207.misc diff --git a/changelog.d/12108.misc b/changelog.d/12108.misc index 0360dbd61edc..b67a701dbb52 100644 --- a/changelog.d/12108.misc +++ b/changelog.d/12108.misc @@ -1 +1 @@ -Add type hints to `tests/rest/client`. +Add type hints to tests files. diff --git a/changelog.d/12146.misc b/changelog.d/12146.misc index 3ca7c47212fd..b67a701dbb52 100644 --- a/changelog.d/12146.misc +++ b/changelog.d/12146.misc @@ -1 +1 @@ -Add type hints to `tests/rest`. +Add type hints to tests files. diff --git a/changelog.d/12207.misc b/changelog.d/12207.misc new file mode 100644 index 000000000000..b67a701dbb52 --- /dev/null +++ b/changelog.d/12207.misc @@ -0,0 +1 @@ +Add type hints to tests files. 
diff --git a/tests/handlers/test_admin.py b/tests/handlers/test_admin.py index abf2a0fe0dc5..c1579dac610f 100644 --- a/tests/handlers/test_admin.py +++ b/tests/handlers/test_admin.py @@ -15,11 +15,15 @@ from collections import Counter from unittest.mock import Mock +from twisted.test.proto_helpers import MemoryReactor + import synapse.rest.admin import synapse.storage from synapse.api.constants import EventTypes, JoinRules from synapse.api.room_versions import RoomVersions from synapse.rest.client import knock, login, room +from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest @@ -32,7 +36,7 @@ class ExfiltrateData(unittest.HomeserverTestCase): knock.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.admin_handler = hs.get_admin_handler() self.user1 = self.register_user("user1", "password") @@ -41,7 +45,7 @@ def prepare(self, reactor, clock, hs): self.user2 = self.register_user("user2", "password") self.token2 = self.login("user2", "password") - def test_single_public_joined_room(self): + def test_single_public_joined_room(self) -> None: """Test that we write *all* events for a public room""" room_id = self.helper.create_room_as( self.user1, tok=self.token1, is_public=True @@ -74,7 +78,7 @@ def test_single_public_joined_room(self): self.assertEqual(counter[(EventTypes.Member, self.user1)], 1) self.assertEqual(counter[(EventTypes.Member, self.user2)], 1) - def test_single_private_joined_room(self): + def test_single_private_joined_room(self) -> None: """Tests that we correctly write state when we can't see all events in a room. """ @@ -112,7 +116,7 @@ def test_single_private_joined_room(self): self.assertEqual(counter[(EventTypes.Member, self.user1)], 1) self.assertEqual(counter[(EventTypes.Member, self.user2)], 1) - def test_single_left_room(self): + def test_single_left_room(self) -> None: """Tests that we don't see events in the room after we leave.""" room_id = self.helper.create_room_as(self.user1, tok=self.token1) self.helper.send(room_id, body="Hello!", tok=self.token1) @@ -144,7 +148,7 @@ def test_single_left_room(self): self.assertEqual(counter[(EventTypes.Member, self.user1)], 1) self.assertEqual(counter[(EventTypes.Member, self.user2)], 2) - def test_single_left_rejoined_private_room(self): + def test_single_left_rejoined_private_room(self) -> None: """Tests that see the correct events in private rooms when we repeatedly join and leave. 
""" @@ -185,7 +189,7 @@ def test_single_left_rejoined_private_room(self): self.assertEqual(counter[(EventTypes.Member, self.user1)], 1) self.assertEqual(counter[(EventTypes.Member, self.user2)], 3) - def test_invite(self): + def test_invite(self) -> None: """Tests that pending invites get handled correctly.""" room_id = self.helper.create_room_as(self.user1, tok=self.token1) self.helper.send(room_id, body="Hello!", tok=self.token1) @@ -204,7 +208,7 @@ def test_invite(self): self.assertEqual(args[1].content["membership"], "invite") self.assertTrue(args[2]) # Assert there is at least one bit of state - def test_knock(self): + def test_knock(self) -> None: """Tests that knock get handled correctly.""" # create a knockable v7 room room_id = self.helper.create_room_as( diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index 0c6e55e72592..67a7829769b6 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -15,8 +15,12 @@ import pymacaroons +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.errors import AuthError, ResourceLimitError from synapse.rest import admin +from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest from tests.test_utils import make_awaitable @@ -27,7 +31,7 @@ class AuthTestCase(unittest.HomeserverTestCase): admin.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.auth_handler = hs.get_auth_handler() self.macaroon_generator = hs.get_macaroon_generator() @@ -42,23 +46,23 @@ def prepare(self, reactor, clock, hs): self.user1 = self.register_user("a_user", "pass") - def test_macaroon_caveats(self): + def test_macaroon_caveats(self) -> None: token = self.macaroon_generator.generate_guest_access_token("a_user") macaroon = pymacaroons.Macaroon.deserialize(token) - def verify_gen(caveat): + def verify_gen(caveat: str) -> bool: return caveat == "gen = 1" - def verify_user(caveat): + def verify_user(caveat: str) -> bool: return caveat == "user_id = a_user" - def verify_type(caveat): + def verify_type(caveat: str) -> bool: return caveat == "type = access" - def verify_nonce(caveat): + def verify_nonce(caveat: str) -> bool: return caveat.startswith("nonce =") - def verify_guest(caveat): + def verify_guest(caveat: str) -> bool: return caveat == "guest = true" v = pymacaroons.Verifier() @@ -69,7 +73,7 @@ def verify_guest(caveat): v.satisfy_general(verify_guest) v.verify(macaroon, self.hs.config.key.macaroon_secret_key) - def test_short_term_login_token_gives_user_id(self): + def test_short_term_login_token_gives_user_id(self) -> None: token = self.macaroon_generator.generate_short_term_login_token( self.user1, "", duration_in_ms=5000 ) @@ -84,7 +88,7 @@ def test_short_term_login_token_gives_user_id(self): AuthError, ) - def test_short_term_login_token_gives_auth_provider(self): + def test_short_term_login_token_gives_auth_provider(self) -> None: token = self.macaroon_generator.generate_short_term_login_token( self.user1, auth_provider_id="my_idp" ) @@ -92,7 +96,7 @@ def test_short_term_login_token_gives_auth_provider(self): self.assertEqual(self.user1, res.user_id) self.assertEqual("my_idp", res.auth_provider_id) - def test_short_term_login_token_cannot_replace_user_id(self): + def test_short_term_login_token_cannot_replace_user_id(self) -> None: token = self.macaroon_generator.generate_short_term_login_token( self.user1, "", duration_in_ms=5000 ) @@ -112,7 +116,7 @@ def 
test_short_term_login_token_cannot_replace_user_id(self): AuthError, ) - def test_mau_limits_disabled(self): + def test_mau_limits_disabled(self) -> None: self.auth_blocking._limit_usage_by_mau = False # Ensure does not throw exception self.get_success( @@ -127,7 +131,7 @@ def test_mau_limits_disabled(self): ) ) - def test_mau_limits_exceeded_large(self): + def test_mau_limits_exceeded_large(self) -> None: self.auth_blocking._limit_usage_by_mau = True self.hs.get_datastores().main.get_monthly_active_count = Mock( return_value=make_awaitable(self.large_number_of_users) @@ -150,7 +154,7 @@ def test_mau_limits_exceeded_large(self): ResourceLimitError, ) - def test_mau_limits_parity(self): + def test_mau_limits_parity(self) -> None: # Ensure we're not at the unix epoch. self.reactor.advance(1) self.auth_blocking._limit_usage_by_mau = True @@ -189,7 +193,7 @@ def test_mau_limits_parity(self): ) ) - def test_mau_limits_not_exceeded(self): + def test_mau_limits_not_exceeded(self) -> None: self.auth_blocking._limit_usage_by_mau = True self.hs.get_datastores().main.get_monthly_active_count = Mock( @@ -211,7 +215,7 @@ def test_mau_limits_not_exceeded(self): ) ) - def _get_macaroon(self): + def _get_macaroon(self) -> pymacaroons.Macaroon: token = self.macaroon_generator.generate_short_term_login_token( self.user1, "", duration_in_ms=5000 ) diff --git a/tests/handlers/test_deactivate_account.py b/tests/handlers/test_deactivate_account.py index ddda36c5a93b..3a107912265e 100644 --- a/tests/handlers/test_deactivate_account.py +++ b/tests/handlers/test_deactivate_account.py @@ -39,7 +39,7 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.user = self.register_user("user", "pass") self.token = self.login("user", "pass") - def _deactivate_my_account(self): + def _deactivate_my_account(self) -> None: """ Deactivates the account `self.user` using `self.token` and asserts that it returns a 200 success code. diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index 683677fd0770..01ea7d2a4281 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -14,9 +14,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -import synapse.api.errors -import synapse.handlers.device -import synapse.storage +from typing import Optional + +from twisted.test.proto_helpers import MemoryReactor + +from synapse.api.errors import NotFoundError, SynapseError +from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN +from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest @@ -25,28 +30,27 @@ class DeviceTestCase(unittest.HomeserverTestCase): - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: hs = self.setup_test_homeserver("server", federation_http_client=None) self.handler = hs.get_device_handler() self.store = hs.get_datastores().main return hs - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # These tests assume that it starts 1000 seconds in. 
self.reactor.advance(1000) - def test_device_is_created_with_invalid_name(self): + def test_device_is_created_with_invalid_name(self) -> None: self.get_failure( self.handler.check_device_registered( user_id="@boris:foo", device_id="foo", - initial_device_display_name="a" - * (synapse.handlers.device.MAX_DEVICE_DISPLAY_NAME_LEN + 1), + initial_device_display_name="a" * (MAX_DEVICE_DISPLAY_NAME_LEN + 1), ), - synapse.api.errors.SynapseError, + SynapseError, ) - def test_device_is_created_if_doesnt_exist(self): + def test_device_is_created_if_doesnt_exist(self) -> None: res = self.get_success( self.handler.check_device_registered( user_id="@boris:foo", @@ -59,7 +63,7 @@ def test_device_is_created_if_doesnt_exist(self): dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco")) self.assertEqual(dev["display_name"], "display name") - def test_device_is_preserved_if_exists(self): + def test_device_is_preserved_if_exists(self) -> None: res1 = self.get_success( self.handler.check_device_registered( user_id="@boris:foo", @@ -81,7 +85,7 @@ def test_device_is_preserved_if_exists(self): dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco")) self.assertEqual(dev["display_name"], "display name") - def test_device_id_is_made_up_if_unspecified(self): + def test_device_id_is_made_up_if_unspecified(self) -> None: device_id = self.get_success( self.handler.check_device_registered( user_id="@theresa:foo", @@ -93,7 +97,7 @@ def test_device_id_is_made_up_if_unspecified(self): dev = self.get_success(self.handler.store.get_device("@theresa:foo", device_id)) self.assertEqual(dev["display_name"], "display") - def test_get_devices_by_user(self): + def test_get_devices_by_user(self) -> None: self._record_users() res = self.get_success(self.handler.get_devices_by_user(user1)) @@ -131,7 +135,7 @@ def test_get_devices_by_user(self): device_map["abc"], ) - def test_get_device(self): + def test_get_device(self) -> None: self._record_users() res = self.get_success(self.handler.get_device(user1, "abc")) @@ -146,21 +150,19 @@ def test_get_device(self): res, ) - def test_delete_device(self): + def test_delete_device(self) -> None: self._record_users() # delete the device self.get_success(self.handler.delete_device(user1, "abc")) # check the device was deleted - self.get_failure( - self.handler.get_device(user1, "abc"), synapse.api.errors.NotFoundError - ) + self.get_failure(self.handler.get_device(user1, "abc"), NotFoundError) # we'd like to check the access token was invalidated, but that's a # bit of a PITA. - def test_delete_device_and_device_inbox(self): + def test_delete_device_and_device_inbox(self) -> None: self._record_users() # add an device_inbox @@ -191,7 +193,7 @@ def test_delete_device_and_device_inbox(self): ) self.assertIsNone(res) - def test_update_device(self): + def test_update_device(self) -> None: self._record_users() update = {"display_name": "new display"} @@ -200,32 +202,29 @@ def test_update_device(self): res = self.get_success(self.handler.get_device(user1, "abc")) self.assertEqual(res["display_name"], "new display") - def test_update_device_too_long_display_name(self): + def test_update_device_too_long_display_name(self) -> None: """Update a device with a display name that is invalid (too long).""" self._record_users() # Request to update a device display name with a new value that is longer than allowed. 
- update = { - "display_name": "a" - * (synapse.handlers.device.MAX_DEVICE_DISPLAY_NAME_LEN + 1) - } + update = {"display_name": "a" * (MAX_DEVICE_DISPLAY_NAME_LEN + 1)} self.get_failure( self.handler.update_device(user1, "abc", update), - synapse.api.errors.SynapseError, + SynapseError, ) # Ensure the display name was not updated. res = self.get_success(self.handler.get_device(user1, "abc")) self.assertEqual(res["display_name"], "display 2") - def test_update_unknown_device(self): + def test_update_unknown_device(self) -> None: update = {"display_name": "new_display"} self.get_failure( self.handler.update_device("user_id", "unknown_device_id", update), - synapse.api.errors.NotFoundError, + NotFoundError, ) - def _record_users(self): + def _record_users(self) -> None: # check this works for both devices which have a recorded client_ip, # and those which don't. self._record_user(user1, "xyz", "display 0") @@ -238,8 +237,13 @@ def _record_users(self): self.reactor.advance(10000) def _record_user( - self, user_id, device_id, display_name, access_token=None, ip=None - ): + self, + user_id: str, + device_id: str, + display_name: str, + access_token: Optional[str] = None, + ip: Optional[str] = None, + ) -> None: device_id = self.get_success( self.handler.check_device_registered( user_id=user_id, @@ -248,7 +252,7 @@ def _record_user( ) ) - if ip is not None: + if access_token is not None and ip is not None: self.get_success( self.store.insert_client_ip( user_id, access_token, ip, "user_agent", device_id @@ -258,7 +262,7 @@ def _record_user( class DehydrationTestCase(unittest.HomeserverTestCase): - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: hs = self.setup_test_homeserver("server", federation_http_client=None) self.handler = hs.get_device_handler() self.registration = hs.get_registration_handler() @@ -266,7 +270,7 @@ def make_homeserver(self, reactor, clock): self.store = hs.get_datastores().main return hs - def test_dehydrate_and_rehydrate_device(self): + def test_dehydrate_and_rehydrate_device(self) -> None: user_id = "@boris:dehydration" self.get_success(self.store.register_user(user_id, "foobar")) @@ -303,7 +307,7 @@ def test_dehydrate_and_rehydrate_device(self): access_token=access_token, device_id="not the right device ID", ), - synapse.api.errors.NotFoundError, + NotFoundError, ) # dehydrating the right devices should succeed and change our device ID @@ -331,7 +335,7 @@ def test_dehydrate_and_rehydrate_device(self): # make sure that the device ID that we were initially assigned no longer exists self.get_failure( self.handler.get_device(user_id, device_id), - synapse.api.errors.NotFoundError, + NotFoundError, ) # make sure that there's no device available for dehydrating now From 32c828d0f760492711a98b11376e229d795fd1b3 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Fri, 11 Mar 2022 13:42:22 +0100 Subject: [PATCH 052/230] Add type hints to `tests/rest`. 
(#12208) Co-authored-by: Patrick Cloke --- changelog.d/12208.misc | 1 + mypy.ini | 1 - tests/rest/client/test_transactions.py | 19 +++- tests/rest/media/v1/test_media_storage.py | 110 +++++++++++++--------- tests/rest/media/v1/test_url_preview.py | 83 ++++++++-------- 5 files changed, 129 insertions(+), 85 deletions(-) create mode 100644 changelog.d/12208.misc diff --git a/changelog.d/12208.misc b/changelog.d/12208.misc new file mode 100644 index 000000000000..c5b635679931 --- /dev/null +++ b/changelog.d/12208.misc @@ -0,0 +1 @@ +Add type hints to tests files. \ No newline at end of file diff --git a/mypy.ini b/mypy.ini index c8390ddba96d..f9c39fcaaee3 100644 --- a/mypy.ini +++ b/mypy.ini @@ -90,7 +90,6 @@ exclude = (?x) |tests/push/test_push_rule_evaluator.py |tests/rest/client/test_transactions.py |tests/rest/media/v1/test_media_storage.py - |tests/rest/media/v1/test_url_preview.py |tests/scripts/test_new_matrix_user.py |tests/server.py |tests/server_notices/test_resource_limits_server_notices.py diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py index 3b5747cb12b8..8d8251b2ac99 100644 --- a/tests/rest/client/test_transactions.py +++ b/tests/rest/client/test_transactions.py @@ -1,3 +1,18 @@ +# Copyright 2018-2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from http import HTTPStatus from unittest.mock import Mock, call from twisted.internet import defer, reactor @@ -11,14 +26,14 @@ class HttpTransactionCacheTestCase(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.clock = MockClock() self.hs = Mock() self.hs.get_clock = Mock(return_value=self.clock) self.hs.get_auth = Mock() self.cache = HttpTransactionCache(self.hs) - self.mock_http_response = (200, "GOOD JOB!") + self.mock_http_response = (HTTPStatus.OK, "GOOD JOB!") self.mock_key = "foo" @defer.inlineCallbacks diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index cba9be17c4ca..7204b2dfe075 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -16,7 +16,7 @@ import tempfile from binascii import unhexlify from io import BytesIO -from typing import Optional +from typing import Any, BinaryIO, Dict, List, Optional, Union from unittest.mock import Mock from urllib import parse @@ -26,18 +26,24 @@ from twisted.internet import defer from twisted.internet.defer import Deferred +from twisted.test.proto_helpers import MemoryReactor +from synapse.events import EventBase from synapse.events.spamcheck import load_legacy_spam_checkers from synapse.logging.context import make_deferred_yieldable +from synapse.module_api import ModuleApi from synapse.rest import admin from synapse.rest.client import login from synapse.rest.media.v1._base import FileInfo from synapse.rest.media.v1.filepath import MediaFilePaths -from synapse.rest.media.v1.media_storage import MediaStorage +from synapse.rest.media.v1.media_storage import MediaStorage, ReadableFileWrapper from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend +from synapse.server import HomeServer +from synapse.types import RoomAlias +from synapse.util import Clock from tests import unittest -from tests.server import FakeSite, make_request +from tests.server import FakeChannel, FakeSite, make_request from tests.test_utils import SMALL_PNG from tests.utils import default_config @@ -46,7 +52,7 @@ class MediaStorageTests(unittest.HomeserverTestCase): needs_threadpool = True - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.test_dir = tempfile.mkdtemp(prefix="synapse-tests-") self.addCleanup(shutil.rmtree, self.test_dir) @@ -62,7 +68,7 @@ def prepare(self, reactor, clock, hs): hs, self.primary_base_path, self.filepaths, storage_providers ) - def test_ensure_media_is_in_local_cache(self): + def test_ensure_media_is_in_local_cache(self) -> None: media_id = "some_media_id" test_body = "Test\n" @@ -105,7 +111,7 @@ def test_ensure_media_is_in_local_cache(self): self.assertEqual(test_body, body) -@attr.s(slots=True, frozen=True) +@attr.s(auto_attribs=True, slots=True, frozen=True) class _TestImage: """An image for testing thumbnailing with the expected results @@ -121,18 +127,18 @@ class _TestImage: a 404 is expected. 
""" - data = attr.ib(type=bytes) - content_type = attr.ib(type=bytes) - extension = attr.ib(type=bytes) - expected_cropped = attr.ib(type=Optional[bytes], default=None) - expected_scaled = attr.ib(type=Optional[bytes], default=None) - expected_found = attr.ib(default=True, type=bool) + data: bytes + content_type: bytes + extension: bytes + expected_cropped: Optional[bytes] = None + expected_scaled: Optional[bytes] = None + expected_found: bool = True @parameterized_class( ("test_image",), [ - # smoll png + # small png ( _TestImage( SMALL_PNG, @@ -193,11 +199,17 @@ class MediaRepoTests(unittest.HomeserverTestCase): hijack_auth = True user_id = "@test:user" - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.fetches = [] - def get_file(destination, path, output_stream, args=None, max_size=None): + def get_file( + destination: str, + path: str, + output_stream: BinaryIO, + args: Optional[Dict[str, Union[str, List[str]]]] = None, + max_size: Optional[int] = None, + ) -> Deferred: """ Returns tuple[int,dict,str,int] of file length, response headers, absolute URI, and response code. @@ -238,7 +250,7 @@ def write_to(r): return hs - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: media_resource = hs.get_media_repository_resource() self.download_resource = media_resource.children[b"download"] @@ -248,8 +260,9 @@ def prepare(self, reactor, clock, hs): self.media_id = "example.com/12345" - def _req(self, content_disposition, include_content_type=True): - + def _req( + self, content_disposition: Optional[bytes], include_content_type: bool = True + ) -> FakeChannel: channel = make_request( self.reactor, FakeSite(self.download_resource, self.reactor), @@ -288,7 +301,7 @@ def _req(self, content_disposition, include_content_type=True): return channel - def test_handle_missing_content_type(self): + def test_handle_missing_content_type(self) -> None: channel = self._req( b"inline; filename=out" + self.test_image.extension, include_content_type=False, @@ -299,7 +312,7 @@ def test_handle_missing_content_type(self): headers.getRawHeaders(b"Content-Type"), [b"application/octet-stream"] ) - def test_disposition_filename_ascii(self): + def test_disposition_filename_ascii(self) -> None: """ If the filename is filename= then Synapse will decode it as an ASCII string, and use filename= in the response. @@ -315,7 +328,7 @@ def test_disposition_filename_ascii(self): [b"inline; filename=out" + self.test_image.extension], ) - def test_disposition_filenamestar_utf8escaped(self): + def test_disposition_filenamestar_utf8escaped(self) -> None: """ If the filename is filename=*utf8'' then Synapse will correctly decode it as the UTF-8 string, and use filename* in the @@ -335,7 +348,7 @@ def test_disposition_filenamestar_utf8escaped(self): [b"inline; filename*=utf-8''" + filename + self.test_image.extension], ) - def test_disposition_none(self): + def test_disposition_none(self) -> None: """ If there is no filename, one isn't passed on in the Content-Disposition of the request. 
@@ -348,26 +361,26 @@ def test_disposition_none(self): ) self.assertEqual(headers.getRawHeaders(b"Content-Disposition"), None) - def test_thumbnail_crop(self): + def test_thumbnail_crop(self) -> None: """Test that a cropped remote thumbnail is available.""" self._test_thumbnail( "crop", self.test_image.expected_cropped, self.test_image.expected_found ) - def test_thumbnail_scale(self): + def test_thumbnail_scale(self) -> None: """Test that a scaled remote thumbnail is available.""" self._test_thumbnail( "scale", self.test_image.expected_scaled, self.test_image.expected_found ) - def test_invalid_type(self): + def test_invalid_type(self) -> None: """An invalid thumbnail type is never available.""" self._test_thumbnail("invalid", None, False) @unittest.override_config( {"thumbnail_sizes": [{"width": 32, "height": 32, "method": "scale"}]} ) - def test_no_thumbnail_crop(self): + def test_no_thumbnail_crop(self) -> None: """ Override the config to generate only scaled thumbnails, but request a cropped one. """ @@ -376,13 +389,13 @@ def test_no_thumbnail_crop(self): @unittest.override_config( {"thumbnail_sizes": [{"width": 32, "height": 32, "method": "crop"}]} ) - def test_no_thumbnail_scale(self): + def test_no_thumbnail_scale(self) -> None: """ Override the config to generate only cropped thumbnails, but request a scaled one. """ self._test_thumbnail("scale", None, False) - def test_thumbnail_repeated_thumbnail(self): + def test_thumbnail_repeated_thumbnail(self) -> None: """Test that fetching the same thumbnail works, and deleting the on disk thumbnail regenerates it. """ @@ -443,7 +456,9 @@ def test_thumbnail_repeated_thumbnail(self): channel.result["body"], ) - def _test_thumbnail(self, method, expected_body, expected_found): + def _test_thumbnail( + self, method: str, expected_body: Optional[bytes], expected_found: bool + ) -> None: params = "?width=32&height=32&method=" + method channel = make_request( self.reactor, @@ -485,7 +500,7 @@ def _test_thumbnail(self, method, expected_body, expected_found): ) @parameterized.expand([("crop", 16), ("crop", 64), ("scale", 16), ("scale", 64)]) - def test_same_quality(self, method, desired_size): + def test_same_quality(self, method: str, desired_size: int) -> None: """Test that choosing between thumbnails with the same quality rating succeeds. We are not particular about which thumbnail is chosen.""" @@ -521,7 +536,7 @@ def test_same_quality(self, method, desired_size): ) ) - def test_x_robots_tag_header(self): + def test_x_robots_tag_header(self) -> None: """ Tests that the `X-Robots-Tag` header is present, which informs web crawlers to not index, archive, or follow links in media. @@ -540,29 +555,38 @@ class TestSpamChecker: `evil`. 
""" - def __init__(self, config, api): + def __init__(self, config: Dict[str, Any], api: ModuleApi) -> None: self.config = config self.api = api - def parse_config(config): + def parse_config(config: Dict[str, Any]) -> Dict[str, Any]: return config - async def check_event_for_spam(self, foo): + async def check_event_for_spam(self, event: EventBase) -> Union[bool, str]: return False # allow all events - async def user_may_invite(self, inviter_userid, invitee_userid, room_id): + async def user_may_invite( + self, + inviter_userid: str, + invitee_userid: str, + room_id: str, + ) -> bool: return True # allow all invites - async def user_may_create_room(self, userid): + async def user_may_create_room(self, userid: str) -> bool: return True # allow all room creations - async def user_may_create_room_alias(self, userid, room_alias): + async def user_may_create_room_alias( + self, userid: str, room_alias: RoomAlias + ) -> bool: return True # allow all room aliases - async def user_may_publish_room(self, userid, room_id): + async def user_may_publish_room(self, userid: str, room_id: str) -> bool: return True # allow publishing of all rooms - async def check_media_file_for_spam(self, file_wrapper, file_info) -> bool: + async def check_media_file_for_spam( + self, file_wrapper: ReadableFileWrapper, file_info: FileInfo + ) -> bool: buf = BytesIO() await file_wrapper.write_chunks_to(buf.write) @@ -575,7 +599,7 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase): admin.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.user = self.register_user("user", "pass") self.tok = self.login("user", "pass") @@ -586,7 +610,7 @@ def prepare(self, reactor, clock, hs): load_legacy_spam_checkers(hs) - def default_config(self): + def default_config(self) -> Dict[str, Any]: config = default_config("test") config.update( @@ -602,13 +626,13 @@ def default_config(self): return config - def test_upload_innocent(self): + def test_upload_innocent(self) -> None: """Attempt to upload some innocent data that should be allowed.""" self.helper.upload_media( self.upload_resource, SMALL_PNG, tok=self.tok, expect_code=200 ) - def test_upload_ban(self): + def test_upload_ban(self) -> None: """Attempt to upload some data that includes bytes "evil", which should get rejected by the spam checker. 
""" diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py index da2c53326019..5148c39874e2 100644 --- a/tests/rest/media/v1/test_url_preview.py +++ b/tests/rest/media/v1/test_url_preview.py @@ -16,16 +16,21 @@ import json import os import re +from typing import Any, Dict, Optional, Sequence, Tuple, Type from urllib.parse import urlencode from twisted.internet._resolver import HostResolution from twisted.internet.address import IPv4Address, IPv6Address from twisted.internet.error import DNSLookupError -from twisted.test.proto_helpers import AccumulatingProtocol +from twisted.internet.interfaces import IAddress, IResolutionReceiver +from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactor from synapse.config.oembed import OEmbedEndpointConfig +from synapse.rest.media.v1.media_repository import MediaRepositoryResource from synapse.rest.media.v1.preview_url_resource import IMAGE_CACHE_EXPIRY_MS +from synapse.server import HomeServer from synapse.types import JsonDict +from synapse.util import Clock from synapse.util.stringutils import parse_and_validate_mxc_uri from tests import unittest @@ -52,7 +57,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): b"" ) - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() config["url_preview_enabled"] = True @@ -113,22 +118,22 @@ def make_homeserver(self, reactor, clock): return hs - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.media_repo = hs.get_media_repository_resource() self.preview_url = self.media_repo.children[b"preview_url"] - self.lookups = {} + self.lookups: Dict[str, Any] = {} class Resolver: def resolveHostName( _self, - resolutionReceiver, - hostName, - portNumber=0, - addressTypes=None, - transportSemantics="TCP", - ): + resolutionReceiver: IResolutionReceiver, + hostName: str, + portNumber: int = 0, + addressTypes: Optional[Sequence[Type[IAddress]]] = None, + transportSemantics: str = "TCP", + ) -> IResolutionReceiver: resolution = HostResolution(hostName) resolutionReceiver.resolutionBegan(resolution) @@ -140,9 +145,9 @@ def resolveHostName( resolutionReceiver.resolutionComplete() return resolutionReceiver - self.reactor.nameResolver = Resolver() + self.reactor.nameResolver = Resolver() # type: ignore[assignment] - def create_test_resource(self): + def create_test_resource(self) -> MediaRepositoryResource: return self.hs.get_media_repository_resource() def _assert_small_png(self, json_body: JsonDict) -> None: @@ -153,7 +158,7 @@ def _assert_small_png(self, json_body: JsonDict) -> None: self.assertEqual(json_body["og:image:type"], "image/png") self.assertEqual(json_body["matrix:image:size"], 67) - def test_cache_returns_correct_type(self): + def test_cache_returns_correct_type(self) -> None: self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] channel = self.make_request( @@ -207,7 +212,7 @@ def test_cache_returns_correct_type(self): channel.json_body, {"og:title": "~matrix~", "og:description": "hi"} ) - def test_non_ascii_preview_httpequiv(self): + def test_non_ascii_preview_httpequiv(self) -> None: self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] end_content = ( @@ -243,7 +248,7 @@ def test_non_ascii_preview_httpequiv(self): self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body["og:title"], "\u0434\u043a\u0430") - def test_video_rejected(self): + def 
test_video_rejected(self) -> None: self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] end_content = b"anything" @@ -279,7 +284,7 @@ def test_video_rejected(self): }, ) - def test_audio_rejected(self): + def test_audio_rejected(self) -> None: self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] end_content = b"anything" @@ -315,7 +320,7 @@ def test_audio_rejected(self): }, ) - def test_non_ascii_preview_content_type(self): + def test_non_ascii_preview_content_type(self) -> None: self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] end_content = ( @@ -350,7 +355,7 @@ def test_non_ascii_preview_content_type(self): self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body["og:title"], "\u0434\u043a\u0430") - def test_overlong_title(self): + def test_overlong_title(self) -> None: self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] end_content = ( @@ -387,7 +392,7 @@ def test_overlong_title(self): # We should only see the `og:description` field, as `title` is too long and should be stripped out self.assertCountEqual(["og:description"], res.keys()) - def test_ipaddr(self): + def test_ipaddr(self) -> None: """ IP addresses can be previewed directly. """ @@ -417,7 +422,7 @@ def test_ipaddr(self): channel.json_body, {"og:title": "~matrix~", "og:description": "hi"} ) - def test_blacklisted_ip_specific(self): + def test_blacklisted_ip_specific(self) -> None: """ Blacklisted IP addresses, found via DNS, are not spidered. """ @@ -438,7 +443,7 @@ def test_blacklisted_ip_specific(self): }, ) - def test_blacklisted_ip_range(self): + def test_blacklisted_ip_range(self) -> None: """ Blacklisted IP ranges, IPs found over DNS, are not spidered. """ @@ -457,7 +462,7 @@ def test_blacklisted_ip_range(self): }, ) - def test_blacklisted_ip_specific_direct(self): + def test_blacklisted_ip_specific_direct(self) -> None: """ Blacklisted IP addresses, accessed directly, are not spidered. """ @@ -476,7 +481,7 @@ def test_blacklisted_ip_specific_direct(self): ) self.assertEqual(channel.code, 403) - def test_blacklisted_ip_range_direct(self): + def test_blacklisted_ip_range_direct(self) -> None: """ Blacklisted IP ranges, accessed directly, are not spidered. """ @@ -493,7 +498,7 @@ def test_blacklisted_ip_range_direct(self): }, ) - def test_blacklisted_ip_range_whitelisted_ip(self): + def test_blacklisted_ip_range_whitelisted_ip(self) -> None: """ Blacklisted but then subsequently whitelisted IP addresses can be spidered. @@ -526,7 +531,7 @@ def test_blacklisted_ip_range_whitelisted_ip(self): channel.json_body, {"og:title": "~matrix~", "og:description": "hi"} ) - def test_blacklisted_ip_with_external_ip(self): + def test_blacklisted_ip_with_external_ip(self) -> None: """ If a hostname resolves a blacklisted IP, even if there's a non-blacklisted one, it will be rejected. @@ -549,7 +554,7 @@ def test_blacklisted_ip_with_external_ip(self): }, ) - def test_blacklisted_ipv6_specific(self): + def test_blacklisted_ipv6_specific(self) -> None: """ Blacklisted IP addresses, found via DNS, are not spidered. """ @@ -572,7 +577,7 @@ def test_blacklisted_ipv6_specific(self): }, ) - def test_blacklisted_ipv6_range(self): + def test_blacklisted_ipv6_range(self) -> None: """ Blacklisted IP ranges, IPs found over DNS, are not spidered. """ @@ -591,7 +596,7 @@ def test_blacklisted_ipv6_range(self): }, ) - def test_OPTIONS(self): + def test_OPTIONS(self) -> None: """ OPTIONS returns the OPTIONS. 
""" @@ -601,7 +606,7 @@ def test_OPTIONS(self): self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body, {}) - def test_accept_language_config_option(self): + def test_accept_language_config_option(self) -> None: """ Accept-Language header is sent to the remote server """ @@ -652,7 +657,7 @@ def test_accept_language_config_option(self): server.data, ) - def test_data_url(self): + def test_data_url(self) -> None: """ Requesting to preview a data URL is not supported. """ @@ -675,7 +680,7 @@ def test_data_url(self): self.assertEqual(channel.code, 500) - def test_inline_data_url(self): + def test_inline_data_url(self) -> None: """ An inline image (as a data URL) should be parsed properly. """ @@ -712,7 +717,7 @@ def test_inline_data_url(self): self.assertEqual(channel.code, 200) self._assert_small_png(channel.json_body) - def test_oembed_photo(self): + def test_oembed_photo(self) -> None: """Test an oEmbed endpoint which returns a 'photo' type which redirects the preview to a new URL.""" self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")] self.lookups["cdn.twitter.com"] = [(IPv4Address, "10.1.2.3")] @@ -771,7 +776,7 @@ def test_oembed_photo(self): self.assertEqual(body["og:url"], "http://twitter.com/matrixdotorg/status/12345") self._assert_small_png(body) - def test_oembed_rich(self): + def test_oembed_rich(self) -> None: """Test an oEmbed endpoint which returns HTML content via the 'rich' type.""" self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")] @@ -817,7 +822,7 @@ def test_oembed_rich(self): }, ) - def test_oembed_format(self): + def test_oembed_format(self) -> None: """Test an oEmbed endpoint which requires the format in the URL.""" self.lookups["www.hulu.com"] = [(IPv4Address, "10.1.2.3")] @@ -866,7 +871,7 @@ def test_oembed_format(self): }, ) - def test_oembed_autodiscovery(self): + def test_oembed_autodiscovery(self) -> None: """ Autodiscovery works by finding the link in the HTML response and then requesting an oEmbed URL. 1. Request a preview of a URL which is not known to the oEmbed code. @@ -962,7 +967,7 @@ def test_oembed_autodiscovery(self): ) self._assert_small_png(body) - def _download_image(self): + def _download_image(self) -> Tuple[str, str]: """Downloads an image into the URL cache. Returns: A (host, media_id) tuple representing the MXC URI of the image. 
@@ -995,7 +1000,7 @@ def _download_image(self): self.assertIsNone(_port) return host, media_id - def test_storage_providers_exclude_files(self): + def test_storage_providers_exclude_files(self) -> None: """Test that files are not stored in or fetched from storage providers.""" host, media_id = self._download_image() @@ -1037,7 +1042,7 @@ def test_storage_providers_exclude_files(self): "URL cache file was unexpectedly retrieved from a storage provider", ) - def test_storage_providers_exclude_thumbnails(self): + def test_storage_providers_exclude_thumbnails(self) -> None: """Test that thumbnails are not stored in or fetched from storage providers.""" host, media_id = self._download_image() @@ -1090,7 +1095,7 @@ def test_storage_providers_exclude_thumbnails(self): "URL cache thumbnail was unexpectedly retrieved from a storage provider", ) - def test_cache_expiry(self): + def test_cache_expiry(self) -> None: """Test that URL cache files and thumbnails are cleaned up properly on expiry.""" self.preview_url.clock = MockClock() From 003cc6910af177fec86ae7f43683d146975c7f4b Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Fri, 11 Mar 2022 14:20:00 +0100 Subject: [PATCH 053/230] Update the SSO username picker template to comply with SIWA guidelines (#12210) Fixes https://github.com/matrix-org/synapse/issues/12205 --- changelog.d/12210.misc | 1 + docs/sample_config.yaml | 9 +++++++-- docs/templates.md | 7 +++++-- synapse/config/oidc.py | 9 +++++++-- synapse/handlers/oidc.py | 12 +++++++++++- synapse/handlers/sso.py | 8 +++++--- synapse/res/templates/sso_auth_account_details.html | 6 +++--- synapse/rest/synapse/client/pick_username.py | 8 ++++++++ 8 files changed, 47 insertions(+), 13 deletions(-) create mode 100644 changelog.d/12210.misc diff --git a/changelog.d/12210.misc b/changelog.d/12210.misc new file mode 100644 index 000000000000..3f6a8747c256 --- /dev/null +++ b/changelog.d/12210.misc @@ -0,0 +1 @@ +Update the SSO username picker template to comply with SIWA guidelines. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 6f3623c88ab9..ef25a3175f98 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1947,8 +1947,13 @@ saml2_config: # # localpart_template: Jinja2 template for the localpart of the MXID. # If this is not set, the user will be prompted to choose their -# own username (see 'sso_auth_account_details.html' in the 'sso' -# section of this file). +# own username (see the documentation for the +# 'sso_auth_account_details.html' template). +# +# confirm_localpart: Whether to prompt the user to validate (or +# change) the generated localpart (see the documentation for the +# 'sso_auth_account_details.html' template), instead of +# registering the account right away. # # display_name_template: Jinja2 template for the display name to set # on first login. If unset, no displayname will be set. diff --git a/docs/templates.md b/docs/templates.md index 2b66e9d86294..b251d05cb9eb 100644 --- a/docs/templates.md +++ b/docs/templates.md @@ -176,8 +176,11 @@ Below are the templates Synapse will look for when generating pages related to S for the brand of the IdP * `user_attributes`: an object containing details about the user that we received from the IdP. 
May have the following attributes: - * display_name: the user's display_name - * emails: a list of email addresses + * `display_name`: the user's display name + * `emails`: a list of email addresses + * `localpart`: the local part of the Matrix user ID to register, + if `localpart_template` is set in the mapping provider configuration (empty + string if not) The template should render a form which submits the following fields: * `username`: the localpart of the user's chosen user id * `sso_new_user_consent.html`: HTML page allowing the user to consent to the diff --git a/synapse/config/oidc.py b/synapse/config/oidc.py index f7e4f9ef22b1..fc95912d9b3f 100644 --- a/synapse/config/oidc.py +++ b/synapse/config/oidc.py @@ -182,8 +182,13 @@ def generate_config_section(self, config_dir_path, server_name, **kwargs) -> str # # localpart_template: Jinja2 template for the localpart of the MXID. # If this is not set, the user will be prompted to choose their - # own username (see 'sso_auth_account_details.html' in the 'sso' - # section of this file). + # own username (see the documentation for the + # 'sso_auth_account_details.html' template). + # + # confirm_localpart: Whether to prompt the user to validate (or + # change) the generated localpart (see the documentation for the + # 'sso_auth_account_details.html' template), instead of + # registering the account right away. # # display_name_template: Jinja2 template for the display name to set # on first login. If unset, no displayname will be set. diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py index 593a2aac6691..d98659edc761 100644 --- a/synapse/handlers/oidc.py +++ b/synapse/handlers/oidc.py @@ -1228,6 +1228,7 @@ class OidcSessionData: class UserAttributeDict(TypedDict): localpart: Optional[str] + confirm_localpart: bool display_name: Optional[str] emails: List[str] @@ -1316,6 +1317,7 @@ class JinjaOidcMappingConfig: display_name_template: Optional[Template] email_template: Optional[Template] extra_attributes: Dict[str, Template] + confirm_localpart: bool = False class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): @@ -1357,12 +1359,17 @@ def parse_template_config(option_name: str) -> Optional[Template]: "invalid jinja template", path=["extra_attributes", key] ) from e + confirm_localpart = config.get("confirm_localpart") or False + if not isinstance(confirm_localpart, bool): + raise ConfigError("must be a bool", path=["confirm_localpart"]) + return JinjaOidcMappingConfig( subject_claim=subject_claim, localpart_template=localpart_template, display_name_template=display_name_template, email_template=email_template, extra_attributes=extra_attributes, + confirm_localpart=confirm_localpart, ) def get_remote_user_id(self, userinfo: UserInfo) -> str: @@ -1398,7 +1405,10 @@ def render_template_field(template: Optional[Template]) -> Optional[str]: emails.append(email) return UserAttributeDict( - localpart=localpart, display_name=display_name, emails=emails + localpart=localpart, + display_name=display_name, + emails=emails, + confirm_localpart=self._config.confirm_localpart, ) async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict: diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index ff5b5169cac1..4f02a060d953 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -132,6 +132,7 @@ class UserAttributes: # if `None`, the mapper has not picked a userid, and the user should be prompted to # enter one. 
localpart: Optional[str] + confirm_localpart: bool = False display_name: Optional[str] = None emails: Collection[str] = attr.Factory(list) @@ -561,9 +562,10 @@ def _get_url_for_next_new_user_step( # Must provide either attributes or session, not both assert (attributes is not None) != (session is not None) - if (attributes and attributes.localpart is None) or ( - session and session.chosen_localpart is None - ): + if ( + attributes + and (attributes.localpart is None or attributes.confirm_localpart is True) + ) or (session and session.chosen_localpart is None): return b"/_synapse/client/pick_username/account_details" elif self._consent_at_registration and not ( session and session.terms_accepted_version diff --git a/synapse/res/templates/sso_auth_account_details.html b/synapse/res/templates/sso_auth_account_details.html index 00e1dcdbb866..41315e4fd4da 100644 --- a/synapse/res/templates/sso_auth_account_details.html +++ b/synapse/res/templates/sso_auth_account_details.html @@ -130,15 +130,15 @@
-      Your account is nearly ready
-      Check your details before creating an account on {{ server_name }}
+      Choose your user name
+      This is required to create your account on {{ server_name }}, and you can't change this later.
       @
       [username <input> line changed here; HTML markup not recoverable from the extracted text]
       :{{ server_name }}
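Illustrative aside, not part of the patch: the synapse/handlers/sso.py hunk earlier in this commit sends the user to the account-details page either when the mapping provider produced no localpart or when the new confirm_localpart option is set. Below is a minimal standalone Python sketch of that decision only; UserAttributes is a simplified stand-in for Synapse's own class and needs_account_details_page is a hypothetical helper name, not Synapse API.

from dataclasses import dataclass
from typing import Optional


@dataclass
class UserAttributes:
    # Simplified stand-in for synapse.handlers.sso.UserAttributes.
    localpart: Optional[str]
    confirm_localpart: bool = False


def needs_account_details_page(attributes: UserAttributes) -> bool:
    # Mirrors the condition in _get_url_for_next_new_user_step: prompt the user
    # whenever no localpart was generated, or the generated one must be confirmed.
    return attributes.localpart is None or attributes.confirm_localpart


assert needs_account_details_page(UserAttributes(localpart=None))
assert needs_account_details_page(UserAttributes(localpart="alice", confirm_localpart=True))
assert not needs_account_details_page(UserAttributes(localpart="alice"))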
diff --git a/synapse/rest/synapse/client/pick_username.py b/synapse/rest/synapse/client/pick_username.py index 28ae08349705..6338fbaaa96d 100644 --- a/synapse/rest/synapse/client/pick_username.py +++ b/synapse/rest/synapse/client/pick_username.py @@ -92,12 +92,20 @@ async def _async_render_GET(self, request: Request) -> None: self._sso_handler.render_error(request, "bad_session", e.msg, code=e.code) return + # The configuration might mandate going through this step to validate an + # automatically generated localpart, so session.chosen_localpart might already + # be set. + localpart = "" + if session.chosen_localpart is not None: + localpart = session.chosen_localpart + idp_id = session.auth_provider_id template_params = { "idp": self._sso_handler.get_identity_providers()[idp_id], "user_attributes": { "display_name": session.display_name, "emails": session.emails, + "localpart": localpart, }, } From 735e89bd3a0755883ef0a19649adf84192b5d9fc Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Fri, 11 Mar 2022 14:45:26 +0100 Subject: [PATCH 054/230] Add an additional HTTP pusher + push rule tests. (#12188) And rename the field used for caching from _id to _cache_key. --- changelog.d/12188.misc | 1 + synapse/push/baserules.py | 38 ++++++------- synapse/push/bulk_push_rule_evaluator.py | 10 ++-- synapse/push/clientformat.py | 2 +- tests/push/test_http.py | 72 +++++++++++++++++++++++- 5 files changed, 95 insertions(+), 28 deletions(-) create mode 100644 changelog.d/12188.misc diff --git a/changelog.d/12188.misc b/changelog.d/12188.misc new file mode 100644 index 000000000000..403158481cee --- /dev/null +++ b/changelog.d/12188.misc @@ -0,0 +1 @@ +Add combined test for HTTP pusher and push rule. Contributed by Nick @ Beeper. diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py index 832eaa34e9ef..f42f605f2383 100644 --- a/synapse/push/baserules.py +++ b/synapse/push/baserules.py @@ -169,7 +169,7 @@ def make_base_prepend_rules( "kind": "event_match", "key": "content.msgtype", "pattern": "m.notice", - "_id": "_suppress_notices", + "_cache_key": "_suppress_notices", } ], "actions": ["dont_notify"], @@ -183,13 +183,13 @@ def make_base_prepend_rules( "kind": "event_match", "key": "type", "pattern": "m.room.member", - "_id": "_member", + "_cache_key": "_member", }, { "kind": "event_match", "key": "content.membership", "pattern": "invite", - "_id": "_invite_member", + "_cache_key": "_invite_member", }, {"kind": "event_match", "key": "state_key", "pattern_type": "user_id"}, ], @@ -212,7 +212,7 @@ def make_base_prepend_rules( "kind": "event_match", "key": "type", "pattern": "m.room.member", - "_id": "_member", + "_cache_key": "_member", } ], "actions": ["dont_notify"], @@ -237,12 +237,12 @@ def make_base_prepend_rules( "kind": "event_match", "key": "content.body", "pattern": "@room", - "_id": "_roomnotif_content", + "_cache_key": "_roomnotif_content", }, { "kind": "sender_notification_permission", "key": "room", - "_id": "_roomnotif_pl", + "_cache_key": "_roomnotif_pl", }, ], "actions": ["notify", {"set_tweak": "highlight", "value": True}], @@ -254,13 +254,13 @@ def make_base_prepend_rules( "kind": "event_match", "key": "type", "pattern": "m.room.tombstone", - "_id": "_tombstone", + "_cache_key": "_tombstone", }, { "kind": "event_match", "key": "state_key", "pattern": "", - "_id": "_tombstone_statekey", + "_cache_key": "_tombstone_statekey", }, ], "actions": ["notify", {"set_tweak": "highlight", "value": True}], @@ -272,7 +272,7 @@ def make_base_prepend_rules( "kind": "event_match", 
"key": "type", "pattern": "m.reaction", - "_id": "_reaction", + "_cache_key": "_reaction", } ], "actions": ["dont_notify"], @@ -288,7 +288,7 @@ def make_base_prepend_rules( "kind": "event_match", "key": "type", "pattern": "m.call.invite", - "_id": "_call", + "_cache_key": "_call", } ], "actions": [ @@ -302,12 +302,12 @@ def make_base_prepend_rules( { "rule_id": "global/underride/.m.rule.room_one_to_one", "conditions": [ - {"kind": "room_member_count", "is": "2", "_id": "member_count"}, + {"kind": "room_member_count", "is": "2", "_cache_key": "member_count"}, { "kind": "event_match", "key": "type", "pattern": "m.room.message", - "_id": "_message", + "_cache_key": "_message", }, ], "actions": [ @@ -321,12 +321,12 @@ def make_base_prepend_rules( { "rule_id": "global/underride/.m.rule.encrypted_room_one_to_one", "conditions": [ - {"kind": "room_member_count", "is": "2", "_id": "member_count"}, + {"kind": "room_member_count", "is": "2", "_cache_key": "member_count"}, { "kind": "event_match", "key": "type", "pattern": "m.room.encrypted", - "_id": "_encrypted", + "_cache_key": "_encrypted", }, ], "actions": [ @@ -342,7 +342,7 @@ def make_base_prepend_rules( "kind": "event_match", "key": "type", "pattern": "m.room.message", - "_id": "_message", + "_cache_key": "_message", } ], "actions": ["notify", {"set_tweak": "highlight", "value": False}], @@ -356,7 +356,7 @@ def make_base_prepend_rules( "kind": "event_match", "key": "type", "pattern": "m.room.encrypted", - "_id": "_encrypted", + "_cache_key": "_encrypted", } ], "actions": ["notify", {"set_tweak": "highlight", "value": False}], @@ -368,19 +368,19 @@ def make_base_prepend_rules( "kind": "event_match", "key": "type", "pattern": "im.vector.modular.widgets", - "_id": "_type_modular_widgets", + "_cache_key": "_type_modular_widgets", }, { "kind": "event_match", "key": "content.type", "pattern": "jitsi", - "_id": "_content_type_jitsi", + "_cache_key": "_content_type_jitsi", }, { "kind": "event_match", "key": "state_key", "pattern": "*", - "_id": "_is_state_event", + "_cache_key": "_is_state_event", }, ], "actions": ["notify", {"set_tweak": "highlight", "value": False}], diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index fecf86034eb9..8140afcb6b37 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -274,17 +274,17 @@ def _condition_checker( cache: Dict[str, bool], ) -> bool: for cond in conditions: - _id = cond.get("_id", None) - if _id: - res = cache.get(_id, None) + _cache_key = cond.get("_cache_key", None) + if _cache_key: + res = cache.get(_cache_key, None) if res is False: return False elif res is True: continue res = evaluator.matches(cond, uid, display_name) - if _id: - cache[_id] = bool(res) + if _cache_key: + cache[_cache_key] = bool(res) if not res: return False diff --git a/synapse/push/clientformat.py b/synapse/push/clientformat.py index c5708cd8885b..63b22d50aea9 100644 --- a/synapse/push/clientformat.py +++ b/synapse/push/clientformat.py @@ -40,7 +40,7 @@ def format_push_rules_for_user( # Remove internal stuff. for c in r["conditions"]: - c.pop("_id", None) + c.pop("_cache_key", None) pattern_type = c.pop("pattern_type", None) if pattern_type == "user_id": diff --git a/tests/push/test_http.py b/tests/push/test_http.py index c284beb37ce0..6691e0712896 100644 --- a/tests/push/test_http.py +++ b/tests/push/test_http.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. +from typing import List, Tuple from unittest.mock import Mock from twisted.internet.defer import Deferred @@ -18,7 +19,7 @@ import synapse.rest.admin from synapse.logging.context import make_deferred_yieldable from synapse.push import PusherConfigException -from synapse.rest.client import login, receipts, room +from synapse.rest.client import login, push_rule, receipts, room from tests.unittest import HomeserverTestCase, override_config @@ -29,6 +30,7 @@ class HTTPPusherTests(HomeserverTestCase): room.register_servlets, login.register_servlets, receipts.register_servlets, + push_rule.register_servlets, ] user_id = True hijack_auth = False @@ -39,12 +41,12 @@ def default_config(self): return config def make_homeserver(self, reactor, clock): - self.push_attempts = [] + self.push_attempts: List[tuple[Deferred, str, dict]] = [] m = Mock() def post_json_get_json(url, body): - d = Deferred() + d: Deferred = Deferred() self.push_attempts.append((d, url, body)) return make_deferred_yieldable(d) @@ -719,3 +721,67 @@ def _send_read_request(self, access_token, message_event_id, room_id): access_token=access_token, ) self.assertEqual(channel.code, 200, channel.json_body) + + def _make_user_with_pusher(self, username: str) -> Tuple[str, str]: + user_id = self.register_user(username, "pass") + access_token = self.login(username, "pass") + + # Register the pusher + user_tuple = self.get_success( + self.hs.get_datastores().main.get_user_by_access_token(access_token) + ) + token_id = user_tuple.token_id + + self.get_success( + self.hs.get_pusherpool().add_pusher( + user_id=user_id, + access_token=token_id, + kind="http", + app_id="m.http", + app_display_name="HTTP Push Notifications", + device_display_name="pushy push", + pushkey="a@example.com", + lang=None, + data={"url": "http://example.com/_matrix/push/v1/notify"}, + ) + ) + + return user_id, access_token + + def test_dont_notify_rule_overrides_message(self): + """ + The override push rule will suppress notification + """ + + user_id, access_token = self._make_user_with_pusher("user") + other_user_id, other_access_token = self._make_user_with_pusher("otheruser") + + # Create a room + room = self.helper.create_room_as(user_id, tok=access_token) + + # Disable user notifications for this room -> user + body = { + "conditions": [{"kind": "event_match", "key": "room_id", "pattern": room}], + "actions": ["dont_notify"], + } + channel = self.make_request( + "PUT", + "/pushrules/global/override/best.friend", + body, + access_token=access_token, + ) + self.assertEqual(channel.code, 200) + + # Check we start with no pushes + self.assertEqual(len(self.push_attempts), 0) + + # The other user joins + self.helper.join(room=room, user=other_user_id, tok=other_access_token) + + # The other user sends a message (ignored by dont_notify push rule set above) + self.helper.send(room, body="Hi!", tok=other_access_token) + self.assertEqual(len(self.push_attempts), 0) + + # The user sends a message back (sends a notification) + self.helper.send(room, body="Hello", tok=access_token) + self.assertEqual(len(self.push_attempts), 1) From 4a53f357379c2dc407617a3d39e6da4790dec9aa Mon Sep 17 00:00:00 2001 From: reivilibre Date: Fri, 11 Mar 2022 14:00:15 +0000 Subject: [PATCH 055/230] Improve code documentation for the typing stream over replication. 
(#12211) --- changelog.d/12211.misc | 1 + synapse/handlers/typing.py | 5 +++-- synapse/replication/tcp/handler.py | 2 +- synapse/replication/tcp/resource.py | 6 +++--- synapse/replication/tcp/streams/_base.py | 12 ++++++++++++ 5 files changed, 20 insertions(+), 6 deletions(-) create mode 100644 changelog.d/12211.misc diff --git a/changelog.d/12211.misc b/changelog.d/12211.misc new file mode 100644 index 000000000000..d11634a1ee0f --- /dev/null +++ b/changelog.d/12211.misc @@ -0,0 +1 @@ +Improve code documentation for the typing stream over replication. \ No newline at end of file diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 3b8912652856..6854428b7ca5 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -160,8 +160,9 @@ def process_replication_rows( """Should be called whenever we receive updates for typing stream.""" if self._latest_room_serial > token: - # The master has gone backwards. To prevent inconsistent data, just - # clear everything. + # The typing worker has gone backwards (e.g. it may have restarted). + # To prevent inconsistent data, just clear everything. + logger.info("Typing handler stream went backwards; resetting") self._reset() # Set the latest serial token to whatever the server gave us. diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index d51f045f229a..b217c35f995d 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -709,7 +709,7 @@ def send_remote_server_up(self, server: str) -> None: self.send_command(RemoteServerUpCommand(server)) def stream_update(self, stream_name: str, token: Optional[int], data: Any) -> None: - """Called when a new update is available to stream to clients. + """Called when a new update is available to stream to Redis subscribers. We need to check if the client is interested in the stream or not """ diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index ab829040cde9..c6870df8f954 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -67,8 +67,8 @@ def buildProtocol(self, addr: IAddress) -> ServerReplicationStreamProtocol: class ReplicationStreamer: """Handles replication connections. - This needs to be poked when new replication data may be available. When new - data is available it will propagate to all connected clients. + This needs to be poked when new replication data may be available. + When new data is available it will propagate to all Redis subscribers. """ def __init__(self, hs: "HomeServer"): @@ -109,7 +109,7 @@ def __init__(self, hs: "HomeServer"): def on_notifier_poke(self) -> None: """Checks if there is actually any new data and sends it to the - connections if there are. + Redis subscribers if there are. This should get called each time new data is available, even if it is currently being executed, so that nothing gets missed diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 23d631a76944..495f2f0285ba 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -316,7 +316,19 @@ def __init__(self, hs: "HomeServer"): class TypingStream(Stream): @attr.s(slots=True, frozen=True, auto_attribs=True) class TypingStreamRow: + """ + An entry in the typing stream. + Describes all the users that are 'typing' right now in one room. 
+ + When a user stops typing, it will be streamed as a new update with that + user absent; you can think of the `user_ids` list as overwriting the + entire list that was there previously. + """ + + # The room that this update is for. room_id: str + + # All the users that are 'typing' right now in the specified room. user_ids: List[str] NAME = "typing" From e6a106fd5ebbf30a7c84f8ba09dc903d20213be3 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Fri, 11 Mar 2022 16:15:11 +0100 Subject: [PATCH 056/230] Implement a Jinja2 filter to extract localparts from email addresses (#12212) --- changelog.d/12212.feature | 1 + docs/sample_config.yaml | 3 ++- docs/templates.md | 7 +++++++ synapse/config/oidc.py | 3 ++- synapse/handlers/oidc.py | 6 ++++++ synapse/util/templates.py | 5 +++++ 6 files changed, 23 insertions(+), 2 deletions(-) create mode 100644 changelog.d/12212.feature diff --git a/changelog.d/12212.feature b/changelog.d/12212.feature new file mode 100644 index 000000000000..fe337ff99057 --- /dev/null +++ b/changelog.d/12212.feature @@ -0,0 +1 @@ +Add a new Jinja2 template filter to extract the local part of an email address. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index ef25a3175f98..d634fd8ff54c 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1948,7 +1948,8 @@ saml2_config: # localpart_template: Jinja2 template for the localpart of the MXID. # If this is not set, the user will be prompted to choose their # own username (see the documentation for the -# 'sso_auth_account_details.html' template). +# 'sso_auth_account_details.html' template). This template can +# use the 'localpart_from_email' filter. # # confirm_localpart: Whether to prompt the user to validate (or # change) the generated localpart (see the documentation for the diff --git a/docs/templates.md b/docs/templates.md index b251d05cb9eb..f87692a4538d 100644 --- a/docs/templates.md +++ b/docs/templates.md @@ -36,6 +36,13 @@ Turns a `mxc://` URL for media content into an HTTP(S) one using the homeserver' Example: `message.sender_avatar_url|mxc_to_http(32,32)` +```python +localpart_from_email(address: str) -> str +``` + +Returns the local part of an email address (e.g. `alice` in `alice@example.com`). + +Example: `user.email_address|localpart_from_email` ## Email templates diff --git a/synapse/config/oidc.py b/synapse/config/oidc.py index fc95912d9b3f..5d571651cbc9 100644 --- a/synapse/config/oidc.py +++ b/synapse/config/oidc.py @@ -183,7 +183,8 @@ def generate_config_section(self, config_dir_path, server_name, **kwargs) -> str # localpart_template: Jinja2 template for the localpart of the MXID. # If this is not set, the user will be prompted to choose their # own username (see the documentation for the - # 'sso_auth_account_details.html' template). + # 'sso_auth_account_details.html' template). This template can + # use the 'localpart_from_email' filter. 
# # confirm_localpart: Whether to prompt the user to validate (or # change) the generated localpart (see the documentation for the diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py index d98659edc761..724b9cfcb4bb 100644 --- a/synapse/handlers/oidc.py +++ b/synapse/handlers/oidc.py @@ -45,6 +45,7 @@ from synapse.util import Clock, json_decoder from synapse.util.caches.cached_call import RetryOnExceptionCachedCall from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry +from synapse.util.templates import _localpart_from_email_filter if TYPE_CHECKING: from synapse.server import HomeServer @@ -1308,6 +1309,11 @@ def jinja_finalize(thing: Any) -> Any: env = Environment(finalize=jinja_finalize) +env.filters.update( + { + "localpart_from_email": _localpart_from_email_filter, + } +) @attr.s(slots=True, frozen=True, auto_attribs=True) diff --git a/synapse/util/templates.py b/synapse/util/templates.py index 12941065ca60..fb758b7180f2 100644 --- a/synapse/util/templates.py +++ b/synapse/util/templates.py @@ -64,6 +64,7 @@ def build_jinja_env( { "format_ts": _format_ts_filter, "mxc_to_http": _create_mxc_to_http_filter(config.server.public_baseurl), + "localpart_from_email": _localpart_from_email_filter, } ) @@ -112,3 +113,7 @@ def mxc_to_http_filter( def _format_ts_filter(value: int, format: str) -> str: return time.strftime(format, time.localtime(value / 1000)) + + +def _localpart_from_email_filter(address: str) -> str: + return address.rsplit("@", 1)[0] From ef3619e61d84493d98470eb2a69131d15eb1166b Mon Sep 17 00:00:00 2001 From: Shay Date: Fri, 11 Mar 2022 10:46:45 -0800 Subject: [PATCH 057/230] Add config settings for background update parameters (#11980) --- changelog.d/11980.misc | 1 + docs/sample_config.yaml | 32 +++ synapse/config/_base.pyi | 2 + synapse/config/background_updates.py | 68 ++++++ synapse/config/homeserver.py | 2 + synapse/storage/background_updates.py | 39 +-- tests/config/test_background_update.py | 58 +++++ tests/rest/admin/test_background_updates.py | 9 +- tests/storage/test_background_update.py | 253 ++++++++++++++++++-- 9 files changed, 430 insertions(+), 34 deletions(-) create mode 100644 changelog.d/11980.misc create mode 100644 synapse/config/background_updates.py create mode 100644 tests/config/test_background_update.py diff --git a/changelog.d/11980.misc b/changelog.d/11980.misc new file mode 100644 index 000000000000..36e992e645a3 --- /dev/null +++ b/changelog.d/11980.misc @@ -0,0 +1 @@ +Add config settings for background update parameters. \ No newline at end of file diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index d634fd8ff54c..36c6c56e58f7 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -2735,3 +2735,35 @@ redis: # Optional password if configured on the Redis instance # #password: + + +## Background Updates ## + +# Background updates are database updates that are run in the background in batches. +# The duration, minimum batch size, default batch size, whether to sleep between batches and if so, how long to +# sleep can all be configured. This is helpful to speed up or slow down the updates. +# +background_updates: + # How long in milliseconds to run a batch of background updates for. Defaults to 100. Uncomment and set + # a time to change the default. + # + #background_update_duration_ms: 500 + + # Whether to sleep between updates. Defaults to True. Uncomment to change the default. + # + #sleep_enabled: false + + # If sleeping between updates, how long in milliseconds to sleep for. 
Defaults to 1000. Uncomment + # and set a duration to change the default. + # + #sleep_duration_ms: 300 + + # Minimum size a batch of background updates can be. Must be greater than 0. Defaults to 1. Uncomment and + # set a size to change the default. + # + #min_batch_size: 10 + + # The batch size to use for the first iteration of a new background update. The default is 100. + # Uncomment and set a size to change the default. + # + #default_batch_size: 50 diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi index 1eb5f5a68cf8..363d8b45545e 100644 --- a/synapse/config/_base.pyi +++ b/synapse/config/_base.pyi @@ -19,6 +19,7 @@ from synapse.config import ( api, appservice, auth, + background_updates, cache, captcha, cas, @@ -113,6 +114,7 @@ class RootConfig: caches: cache.CacheConfig federation: federation.FederationConfig retention: retention.RetentionConfig + background_updates: background_updates.BackgroundUpdateConfig config_classes: List[Type["Config"]] = ... def __init__(self) -> None: ... diff --git a/synapse/config/background_updates.py b/synapse/config/background_updates.py new file mode 100644 index 000000000000..f6cdeacc4b19 --- /dev/null +++ b/synapse/config/background_updates.py @@ -0,0 +1,68 @@ +# Copyright 2022 Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import Config + + +class BackgroundUpdateConfig(Config): + section = "background_updates" + + def generate_config_section(self, **kwargs) -> str: + return """\ + ## Background Updates ## + + # Background updates are database updates that are run in the background in batches. + # The duration, minimum batch size, default batch size, whether to sleep between batches and if so, how long to + # sleep can all be configured. This is helpful to speed up or slow down the updates. + # + background_updates: + # How long in milliseconds to run a batch of background updates for. Defaults to 100. Uncomment and set + # a time to change the default. + # + #background_update_duration_ms: 500 + + # Whether to sleep between updates. Defaults to True. Uncomment to change the default. + # + #sleep_enabled: false + + # If sleeping between updates, how long in milliseconds to sleep for. Defaults to 1000. Uncomment + # and set a duration to change the default. + # + #sleep_duration_ms: 300 + + # Minimum size a batch of background updates can be. Must be greater than 0. Defaults to 1. Uncomment and + # set a size to change the default. + # + #min_batch_size: 10 + + # The batch size to use for the first iteration of a new background update. The default is 100. + # Uncomment and set a size to change the default. 
+ # + #default_batch_size: 50 + """ + + def read_config(self, config, **kwargs) -> None: + bg_update_config = config.get("background_updates") or {} + + self.update_duration_ms = bg_update_config.get( + "background_update_duration_ms", 100 + ) + + self.sleep_enabled = bg_update_config.get("sleep_enabled", True) + + self.sleep_duration_ms = bg_update_config.get("sleep_duration_ms", 1000) + + self.min_batch_size = bg_update_config.get("min_batch_size", 1) + + self.default_batch_size = bg_update_config.get("default_batch_size", 100) diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index 001605c265fb..a4ec70690802 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -16,6 +16,7 @@ from .api import ApiConfig from .appservice import AppServiceConfig from .auth import AuthConfig +from .background_updates import BackgroundUpdateConfig from .cache import CacheConfig from .captcha import CaptchaConfig from .cas import CasConfig @@ -99,4 +100,5 @@ class HomeServerConfig(RootConfig): WorkerConfig, RedisConfig, ExperimentalConfig, + BackgroundUpdateConfig, ] diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index 4acc2c997dce..08c6eabc6d1a 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -60,18 +60,19 @@ class _BackgroundUpdateHandler: class _BackgroundUpdateContextManager: - BACKGROUND_UPDATE_INTERVAL_MS = 1000 - BACKGROUND_UPDATE_DURATION_MS = 100 - - def __init__(self, sleep: bool, clock: Clock): + def __init__( + self, sleep: bool, clock: Clock, sleep_duration_ms: int, update_duration: int + ): self._sleep = sleep self._clock = clock + self._sleep_duration_ms = sleep_duration_ms + self._update_duration_ms = update_duration async def __aenter__(self) -> int: if self._sleep: - await self._clock.sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000) + await self._clock.sleep(self._sleep_duration_ms / 1000) - return self.BACKGROUND_UPDATE_DURATION_MS + return self._update_duration_ms async def __aexit__(self, *exc) -> None: pass @@ -133,9 +134,6 @@ class BackgroundUpdater: process and autotuning the batch size. """ - MINIMUM_BACKGROUND_BATCH_SIZE = 1 - DEFAULT_BACKGROUND_BATCH_SIZE = 100 - def __init__(self, hs: "HomeServer", database: "DatabasePool"): self._clock = hs.get_clock() self.db_pool = database @@ -160,6 +158,14 @@ def __init__(self, hs: "HomeServer", database: "DatabasePool"): # enable/disable background updates via the admin API. 
self.enabled = True + self.minimum_background_batch_size = hs.config.background_updates.min_batch_size + self.default_background_batch_size = ( + hs.config.background_updates.default_batch_size + ) + self.update_duration_ms = hs.config.background_updates.update_duration_ms + self.sleep_duration_ms = hs.config.background_updates.sleep_duration_ms + self.sleep_enabled = hs.config.background_updates.sleep_enabled + def register_update_controller_callbacks( self, on_update: ON_UPDATE_CALLBACK, @@ -216,7 +222,9 @@ def _get_context_manager_for_update( if self._on_update_callback is not None: return self._on_update_callback(update_name, database_name, oneshot) - return _BackgroundUpdateContextManager(sleep, self._clock) + return _BackgroundUpdateContextManager( + sleep, self._clock, self.sleep_duration_ms, self.update_duration_ms + ) async def _default_batch_size(self, update_name: str, database_name: str) -> int: """The batch size to use for the first iteration of a new background @@ -225,7 +233,7 @@ async def _default_batch_size(self, update_name: str, database_name: str) -> int if self._default_batch_size_callback is not None: return await self._default_batch_size_callback(update_name, database_name) - return self.DEFAULT_BACKGROUND_BATCH_SIZE + return self.default_background_batch_size async def _min_batch_size(self, update_name: str, database_name: str) -> int: """A lower bound on the batch size of a new background update. @@ -235,7 +243,7 @@ async def _min_batch_size(self, update_name: str, database_name: str) -> int: if self._min_batch_size_callback is not None: return await self._min_batch_size_callback(update_name, database_name) - return self.MINIMUM_BACKGROUND_BATCH_SIZE + return self.minimum_background_batch_size def get_current_update(self) -> Optional[BackgroundUpdatePerformance]: """Returns the current background update, if any.""" @@ -254,9 +262,12 @@ def start_doing_background_updates(self) -> None: if self.enabled: # if we start a new background update, not all updates are done. self._all_done = False - run_as_background_process("background_updates", self.run_background_updates) + sleep = self.sleep_enabled + run_as_background_process( + "background_updates", self.run_background_updates, sleep + ) - async def run_background_updates(self, sleep: bool = True) -> None: + async def run_background_updates(self, sleep: bool) -> None: if self._running or not self.enabled: return diff --git a/tests/config/test_background_update.py b/tests/config/test_background_update.py new file mode 100644 index 000000000000..0c32c1ca299e --- /dev/null +++ b/tests/config/test_background_update.py @@ -0,0 +1,58 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import yaml + +from synapse.storage.background_updates import BackgroundUpdater + +from tests.unittest import HomeserverTestCase, override_config + + +class BackgroundUpdateConfigTestCase(HomeserverTestCase): + # Tests that the default values in the config are correctly loaded. 
Note that the default + # values are loaded when the corresponding config options are commented out, which is why there isn't + # a config specified here. + def test_default_configuration(self): + background_updater = BackgroundUpdater( + self.hs, self.hs.get_datastores().main.db_pool + ) + + self.assertEqual(background_updater.minimum_background_batch_size, 1) + self.assertEqual(background_updater.default_background_batch_size, 100) + self.assertEqual(background_updater.sleep_enabled, True) + self.assertEqual(background_updater.sleep_duration_ms, 1000) + self.assertEqual(background_updater.update_duration_ms, 100) + + # Tests that non-default values for the config options are properly picked up and passed on. + @override_config( + yaml.safe_load( + """ + background_updates: + background_update_duration_ms: 1000 + sleep_enabled: false + sleep_duration_ms: 600 + min_batch_size: 5 + default_batch_size: 50 + """ + ) + ) + def test_custom_configuration(self): + background_updater = BackgroundUpdater( + self.hs, self.hs.get_datastores().main.db_pool + ) + + self.assertEqual(background_updater.minimum_background_batch_size, 5) + self.assertEqual(background_updater.default_background_batch_size, 50) + self.assertEqual(background_updater.sleep_enabled, False) + self.assertEqual(background_updater.sleep_duration_ms, 600) + self.assertEqual(background_updater.update_duration_ms, 1000) diff --git a/tests/rest/admin/test_background_updates.py b/tests/rest/admin/test_background_updates.py index becec84524cb..6cf56b1e352f 100644 --- a/tests/rest/admin/test_background_updates.py +++ b/tests/rest/admin/test_background_updates.py @@ -39,6 +39,7 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") + self.updater = BackgroundUpdater(hs, self.store.db_pool) @parameterized.expand( [ @@ -135,10 +136,10 @@ def test_status_bg_update(self) -> None: """Test the status API works with a background update.""" # Create a new background update - self._register_bg_update() self.store.db_pool.updates.start_doing_background_updates() + self.reactor.pump([1.0, 1.0, 1.0]) channel = self.make_request( @@ -158,7 +159,7 @@ def test_status_bg_update(self) -> None: "average_items_per_ms": 0.1, "total_duration_ms": 1000.0, "total_item_count": ( - BackgroundUpdater.DEFAULT_BACKGROUND_BATCH_SIZE + self.updater.default_background_batch_size ), } }, @@ -213,7 +214,7 @@ def test_enabled(self) -> None: "average_items_per_ms": 0.1, "total_duration_ms": 1000.0, "total_item_count": ( - BackgroundUpdater.DEFAULT_BACKGROUND_BATCH_SIZE + self.updater.default_background_batch_size ), } }, @@ -242,7 +243,7 @@ def test_enabled(self) -> None: "average_items_per_ms": 0.1, "total_duration_ms": 1000.0, "total_item_count": ( - BackgroundUpdater.DEFAULT_BACKGROUND_BATCH_SIZE + self.updater.default_background_batch_size ), } }, diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index 9fdf54ea31f9..5cf18b690e48 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -14,12 +14,15 @@ from unittest.mock import Mock +import yaml + from twisted.internet.defer import Deferred, ensureDeferred from synapse.storage.background_updates import BackgroundUpdater from tests import unittest from tests.test_utils import make_awaitable, simple_async_mock +from tests.unittest import override_config 
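Illustrative aside, not part of the patch: the test cases below exercise the new background_updates options. As a rough sketch of how BackgroundUpdateConfig.read_config (earlier in this commit) resolves those options against their defaults, expressed over a plain dict rather than Synapse's config classes; read_background_update_config is a hypothetical helper name.

def read_background_update_config(config: dict) -> dict:
    # Option names and default values mirror synapse/config/background_updates.py.
    bg = config.get("background_updates") or {}
    return {
        "update_duration_ms": bg.get("background_update_duration_ms", 100),
        "sleep_enabled": bg.get("sleep_enabled", True),
        "sleep_duration_ms": bg.get("sleep_duration_ms", 1000),
        "min_batch_size": bg.get("min_batch_size", 1),
        "default_batch_size": bg.get("default_batch_size", 100),
    }


# With an empty config every option falls back to its documented default.
assert read_background_update_config({}) == {
    "update_duration_ms": 100,
    "sleep_enabled": True,
    "sleep_duration_ms": 1000,
    "min_batch_size": 1,
    "default_batch_size": 100,
}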
class BackgroundUpdateTestCase(unittest.HomeserverTestCase): @@ -34,6 +37,19 @@ def prepare(self, reactor, clock, homeserver): self.updates.register_background_update_handler( "test_update", self.update_handler ) + self.store = self.hs.get_datastores().main + + async def update(self, progress, count): + duration_ms = 10 + await self.clock.sleep((count * duration_ms) / 1000) + progress = {"my_key": progress["my_key"] + 1} + await self.store.db_pool.runInteraction( + "update_progress", + self.updates._background_update_progress_txn, + "test_update", + progress, + ) + return count def test_do_background_update(self): # the time we claim it takes to update one item when running the update @@ -42,27 +58,14 @@ def test_do_background_update(self): # the target runtime for each bg update target_background_update_duration_ms = 100 - store = self.hs.get_datastores().main self.get_success( - store.db_pool.simple_insert( + self.store.db_pool.simple_insert( "background_updates", values={"update_name": "test_update", "progress_json": '{"my_key": 1}'}, ) ) - # first step: make a bit of progress - async def update(progress, count): - await self.clock.sleep((count * duration_ms) / 1000) - progress = {"my_key": progress["my_key"] + 1} - await store.db_pool.runInteraction( - "update_progress", - self.updates._background_update_progress_txn, - "test_update", - progress, - ) - return count - - self.update_handler.side_effect = update + self.update_handler.side_effect = self.update self.update_handler.reset_mock() res = self.get_success( self.updates.do_next_background_update(False), @@ -72,7 +75,7 @@ async def update(progress, count): # on the first call, we should get run with the default background update size self.update_handler.assert_called_once_with( - {"my_key": 1}, self.updates.DEFAULT_BACKGROUND_BATCH_SIZE + {"my_key": 1}, self.updates.default_background_batch_size ) # second step: complete the update @@ -99,6 +102,224 @@ async def update(progress, count): self.assertTrue(result) self.assertFalse(self.update_handler.called) + @override_config( + yaml.safe_load( + """ + background_updates: + default_batch_size: 20 + """ + ) + ) + def test_background_update_default_batch_set_by_config(self): + """ + Test that the background update is run with the default_batch_size set by the config + """ + + self.get_success( + self.store.db_pool.simple_insert( + "background_updates", + values={"update_name": "test_update", "progress_json": '{"my_key": 1}'}, + ) + ) + + self.update_handler.side_effect = self.update + self.update_handler.reset_mock() + res = self.get_success( + self.updates.do_next_background_update(False), + by=0.01, + ) + self.assertFalse(res) + + # on the first call, we should get run with the default background update size specified in the config + self.update_handler.assert_called_once_with({"my_key": 1}, 20) + + def test_background_update_default_sleep_behavior(self): + """ + Test default background update behavior, which is to sleep + """ + + self.get_success( + self.store.db_pool.simple_insert( + "background_updates", + values={"update_name": "test_update", "progress_json": '{"my_key": 1}'}, + ) + ) + + self.update_handler.side_effect = self.update + self.update_handler.reset_mock() + self.updates.start_doing_background_updates(), + + # 2: advance the reactor less than the default sleep duration (1000ms) + self.reactor.pump([0.5]) + # check that an update has not been run + self.update_handler.assert_not_called() + + # advance reactor past default sleep duration + self.reactor.pump([1]) + # 
check that update has been run + self.update_handler.assert_called() + + @override_config( + yaml.safe_load( + """ + background_updates: + sleep_duration_ms: 500 + """ + ) + ) + def test_background_update_sleep_set_in_config(self): + """ + Test that changing the sleep time in the config changes how long it sleeps + """ + + self.get_success( + self.store.db_pool.simple_insert( + "background_updates", + values={"update_name": "test_update", "progress_json": '{"my_key": 1}'}, + ) + ) + + self.update_handler.side_effect = self.update + self.update_handler.reset_mock() + self.updates.start_doing_background_updates(), + + # 2: advance the reactor less than the configured sleep duration (500ms) + self.reactor.pump([0.45]) + # check that an update has not been run + self.update_handler.assert_not_called() + + # advance reactor past config sleep duration but less than default duration + self.reactor.pump([0.75]) + # check that update has been run + self.update_handler.assert_called() + + @override_config( + yaml.safe_load( + """ + background_updates: + sleep_enabled: false + """ + ) + ) + def test_disabling_background_update_sleep(self): + """ + Test that disabling sleep in the config results in bg update not sleeping + """ + + self.get_success( + self.store.db_pool.simple_insert( + "background_updates", + values={"update_name": "test_update", "progress_json": '{"my_key": 1}'}, + ) + ) + + self.update_handler.side_effect = self.update + self.update_handler.reset_mock() + self.updates.start_doing_background_updates(), + + # 2: advance the reactor very little + self.reactor.pump([0.025]) + # check that an update has run + self.update_handler.assert_called() + + @override_config( + yaml.safe_load( + """ + background_updates: + background_update_duration_ms: 500 + """ + ) + ) + def test_background_update_duration_set_in_config(self): + """ + Test that the desired duration set in the config is used in determining batch size + """ + # Duration of one background update item + duration_ms = 10 + + self.get_success( + self.store.db_pool.simple_insert( + "background_updates", + values={"update_name": "test_update", "progress_json": '{"my_key": 1}'}, + ) + ) + + self.update_handler.side_effect = self.update + self.update_handler.reset_mock() + res = self.get_success( + self.updates.do_next_background_update(False), + by=0.02, + ) + self.assertFalse(res) + + # the first update was run with the default batch size, this should be run with 500ms as the + # desired duration + async def update(progress, count): + self.assertEqual(progress, {"my_key": 2}) + self.assertAlmostEqual( + count, + 500 / duration_ms, + places=0, + ) + await self.updates._end_background_update("test_update") + return count + + self.update_handler.side_effect = update + self.get_success(self.updates.do_next_background_update(False)) + + @override_config( + yaml.safe_load( + """ + background_updates: + min_batch_size: 5 + """ + ) + ) + def test_background_update_min_batch_set_in_config(self): + """ + Test that the minimum batch size set in the config is used + """ + # a very long-running individual update + duration_ms = 50 + + self.get_success( + self.store.db_pool.simple_insert( + "background_updates", + values={"update_name": "test_update", "progress_json": '{"my_key": 1}'}, + ) + ) + + # Run the update with the long-running update item + async def update(progress, count): + await self.clock.sleep((count * duration_ms) / 1000) + progress = {"my_key": progress["my_key"] + 1} + await self.store.db_pool.runInteraction( + "update_progress", + 
self.updates._background_update_progress_txn, + "test_update", + progress, + ) + return count + + self.update_handler.side_effect = update + self.update_handler.reset_mock() + res = self.get_success( + self.updates.do_next_background_update(False), + by=1, + ) + self.assertFalse(res) + + # the first update was run with the default batch size, this should be run with minimum batch size + # as the first items took a very long time + async def update(progress, count): + self.assertEqual(progress, {"my_key": 2}) + self.assertEqual(count, 5) + await self.updates._end_background_update("test_update") + return count + + self.update_handler.side_effect = update + self.get_success(self.updates.do_next_background_update(False)) + class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, homeserver): From 54f674f7a9107d3dccd6c126c3e99337314a12c2 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Sat, 12 Mar 2022 13:23:37 -0500 Subject: [PATCH 058/230] Deprecate the groups/communities endpoints and add an experimental configuration flag. (#12200) --- changelog.d/12200.removal | 1 + docs/upgrade.md | 14 ++++++++++++++ synapse/app/generic_worker.py | 3 ++- synapse/config/experimental.py | 3 +++ synapse/federation/transport/server/__init__.py | 15 +++++++++++---- synapse/rest/__init__.py | 3 ++- synapse/rest/admin/__init__.py | 3 ++- 7 files changed, 35 insertions(+), 7 deletions(-) create mode 100644 changelog.d/12200.removal diff --git a/changelog.d/12200.removal b/changelog.d/12200.removal new file mode 100644 index 000000000000..312c7ae32597 --- /dev/null +++ b/changelog.d/12200.removal @@ -0,0 +1 @@ +The groups/communities feature in Synapse has been deprecated. diff --git a/docs/upgrade.md b/docs/upgrade.md index 95005962dc49..f9ac605e7b29 100644 --- a/docs/upgrade.md +++ b/docs/upgrade.md @@ -85,6 +85,20 @@ process, for example: dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb ``` +# Upgrading to v1.56.0 + +## Groups/communities feature has been deprecated + +The non-standard groups/communities feature in Synapse has been deprecated and will +be disabled by default in Synapse v1.58.0. + +You can test disabling it by adding the following to your homeserver configuration: + +```yaml +experimental_features: + groups_enabled: false +``` + # Upgrading to v1.55.0 ## `synctl` script has been moved diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index a10a63b06c7e..b6f510ed3058 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -322,7 +322,8 @@ def _listen_http(self, listener_config: ListenerConfig) -> None: presence.register_servlets(self, resource) - groups.register_servlets(self, resource) + if self.config.experimental.groups_enabled: + groups.register_servlets(self, resource) resources.update({CLIENT_API_PREFIX: resource}) diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 41338b39df21..064db4487c85 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -74,3 +74,6 @@ def read_config(self, config: JsonDict, **kwargs): # MSC3720 (Account status endpoint) self.msc3720_enabled: bool = experimental.get("msc3720_enabled", False) + + # The deprecated groups feature. 
+ self.groups_enabled: bool = experimental.get("groups_enabled", True) diff --git a/synapse/federation/transport/server/__init__.py b/synapse/federation/transport/server/__init__.py index 67a634790712..71b2f90eb920 100644 --- a/synapse/federation/transport/server/__init__.py +++ b/synapse/federation/transport/server/__init__.py @@ -289,7 +289,7 @@ async def on_GET( return 200, {"sub": user_id} -DEFAULT_SERVLET_GROUPS: Dict[str, Iterable[Type[BaseFederationServlet]]] = { +SERVLET_GROUPS: Dict[str, Iterable[Type[BaseFederationServlet]]] = { "federation": FEDERATION_SERVLET_CLASSES, "room_list": (PublicRoomList,), "group_server": GROUP_SERVER_SERVLET_CLASSES, @@ -298,6 +298,10 @@ async def on_GET( "openid": (OpenIdUserInfo,), } +DEFAULT_SERVLET_GROUPS = ("federation", "room_list", "openid") + +GROUP_SERVLET_GROUPS = ("group_server", "group_local", "group_attestation") + def register_servlets( hs: "HomeServer", @@ -320,16 +324,19 @@ def register_servlets( Defaults to ``DEFAULT_SERVLET_GROUPS``. """ if not servlet_groups: - servlet_groups = DEFAULT_SERVLET_GROUPS.keys() + servlet_groups = DEFAULT_SERVLET_GROUPS + # Only allow the groups servlets if the deprecated groups feature is enabled. + if hs.config.experimental.groups_enabled: + servlet_groups = servlet_groups + GROUP_SERVLET_GROUPS for servlet_group in servlet_groups: # Skip unknown servlet groups. - if servlet_group not in DEFAULT_SERVLET_GROUPS: + if servlet_group not in SERVLET_GROUPS: raise RuntimeError( f"Attempting to register unknown federation servlet: '{servlet_group}'" ) - for servletclass in DEFAULT_SERVLET_GROUPS[servlet_group]: + for servletclass in SERVLET_GROUPS[servlet_group]: # Only allow the `/timestamp_to_event` servlet if msc3030 is enabled if ( servletclass == FederationTimestampLookupServlet diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index cebdeecb8127..762808a5717b 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -118,7 +118,8 @@ def register_servlets(client_resource: HttpServer, hs: "HomeServer") -> None: thirdparty.register_servlets(hs, client_resource) sendtodevice.register_servlets(hs, client_resource) user_directory.register_servlets(hs, client_resource) - groups.register_servlets(hs, client_resource) + if hs.config.experimental.groups_enabled: + groups.register_servlets(hs, client_resource) room_upgrade_rest_servlet.register_servlets(hs, client_resource) room_batch.register_servlets(hs, client_resource) capabilities.register_servlets(hs, client_resource) diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 6de302f81352..cb4d55c89d78 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -293,7 +293,8 @@ def register_servlets_for_client_rest_resource( ResetPasswordRestServlet(hs).register(http_server) SearchUsersRestServlet(hs).register(http_server) UserRegisterServlet(hs).register(http_server) - DeleteGroupAdminRestServlet(hs).register(http_server) + if hs.config.experimental.groups_enabled: + DeleteGroupAdminRestServlet(hs).register(http_server) AccountValidityRenewServlet(hs).register(http_server) # Load the media repo ones if we're using them. 
Otherwise load the servlets which From 90b2327066d2343faa86c464a182b6f3c4422ecd Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Mon, 14 Mar 2022 17:52:15 +0000 Subject: [PATCH 059/230] Add `delay_cancellation` utility function (#12180) `delay_cancellation` behaves like `stop_cancellation`, except it delays `CancelledError`s until the original `Deferred` resolves. This is handy for unifying cleanup paths and ensuring that uncancelled coroutines don't use finished logcontexts. Signed-off-by: Sean Quah --- changelog.d/12180.misc | 1 + synapse/util/async_helpers.py | 48 ++++++++++-- tests/util/test_async_helpers.py | 124 +++++++++++++++++++++++++++++-- 3 files changed, 161 insertions(+), 12 deletions(-) create mode 100644 changelog.d/12180.misc diff --git a/changelog.d/12180.misc b/changelog.d/12180.misc new file mode 100644 index 000000000000..7a347352fd91 --- /dev/null +++ b/changelog.d/12180.misc @@ -0,0 +1 @@ +Add `delay_cancellation` utility function, which behaves like `stop_cancellation` but waits until the original `Deferred` resolves before raising a `CancelledError`. diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index a9f67dcbac6a..69c8c1baa9fc 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -686,12 +686,48 @@ def stop_cancellation(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]": Synapse logcontext rules. Returns: - A new `Deferred`, which will contain the result of the original `Deferred`, - but will not propagate cancellation through to the original. When cancelled, - the new `Deferred` will fail with a `CancelledError` and will not follow the - Synapse logcontext rules. `make_deferred_yieldable` should be used to wrap - the new `Deferred`. + A new `Deferred`, which will contain the result of the original `Deferred`. + The new `Deferred` will not propagate cancellation through to the original. + When cancelled, the new `Deferred` will fail with a `CancelledError`. + + The new `Deferred` will not follow the Synapse logcontext rules and should be + wrapped with `make_deferred_yieldable`. + """ + new_deferred: "defer.Deferred[T]" = defer.Deferred() + deferred.chainDeferred(new_deferred) + return new_deferred + + +def delay_cancellation(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]": + """Delay cancellation of a `Deferred` until it resolves. + + Has the same effect as `stop_cancellation`, but the returned `Deferred` will not + resolve with a `CancelledError` until the original `Deferred` resolves. + + Args: + deferred: The `Deferred` to protect against cancellation. May optionally follow + the Synapse logcontext rules. + + Returns: + A new `Deferred`, which will contain the result of the original `Deferred`. + The new `Deferred` will not propagate cancellation through to the original. + When cancelled, the new `Deferred` will wait until the original `Deferred` + resolves before failing with a `CancelledError`. + + The new `Deferred` will follow the Synapse logcontext rules if `deferred` + follows the Synapse logcontext rules. Otherwise the new `Deferred` should be + wrapped with `make_deferred_yieldable`. """ - new_deferred: defer.Deferred[T] = defer.Deferred() + + def handle_cancel(new_deferred: "defer.Deferred[T]") -> None: + # before the new deferred is cancelled, we `pause` it to stop the cancellation + # propagating. we then `unpause` it once the wrapped deferred completes, to + # propagate the exception. 
+ new_deferred.pause() + new_deferred.errback(Failure(CancelledError())) + + deferred.addBoth(lambda _: new_deferred.unpause()) + + new_deferred: "defer.Deferred[T]" = defer.Deferred(handle_cancel) deferred.chainDeferred(new_deferred) return new_deferred diff --git a/tests/util/test_async_helpers.py b/tests/util/test_async_helpers.py index ff53ce114bd7..e5bc416de12f 100644 --- a/tests/util/test_async_helpers.py +++ b/tests/util/test_async_helpers.py @@ -13,6 +13,8 @@ # limitations under the License. import traceback +from parameterized import parameterized_class + from twisted.internet import defer from twisted.internet.defer import CancelledError, Deferred, ensureDeferred from twisted.internet.task import Clock @@ -23,10 +25,12 @@ LoggingContext, PreserveLoggingContext, current_context, + make_deferred_yieldable, ) from synapse.util.async_helpers import ( ObservableDeferred, concurrently_execute, + delay_cancellation, stop_cancellation, timeout_deferred, ) @@ -313,13 +317,27 @@ async def caller(): self.successResultOf(d2) -class StopCancellationTests(TestCase): - """Tests for the `stop_cancellation` function.""" +@parameterized_class( + ("wrapper",), + [("stop_cancellation",), ("delay_cancellation",)], +) +class CancellationWrapperTests(TestCase): + """Common tests for the `stop_cancellation` and `delay_cancellation` functions.""" + + wrapper: str + + def wrap_deferred(self, deferred: "Deferred[str]") -> "Deferred[str]": + if self.wrapper == "stop_cancellation": + return stop_cancellation(deferred) + elif self.wrapper == "delay_cancellation": + return delay_cancellation(deferred) + else: + raise ValueError(f"Unsupported wrapper type: {self.wrapper}") def test_succeed(self): """Test that the new `Deferred` receives the result.""" deferred: "Deferred[str]" = Deferred() - wrapper_deferred = stop_cancellation(deferred) + wrapper_deferred = self.wrap_deferred(deferred) # Success should propagate through. deferred.callback("success") @@ -329,7 +347,7 @@ def test_succeed(self): def test_failure(self): """Test that the new `Deferred` receives the `Failure`.""" deferred: "Deferred[str]" = Deferred() - wrapper_deferred = stop_cancellation(deferred) + wrapper_deferred = self.wrap_deferred(deferred) # Failure should propagate through. deferred.errback(ValueError("abc")) @@ -337,6 +355,10 @@ def test_failure(self): self.failureResultOf(wrapper_deferred, ValueError) self.assertIsNone(deferred.result, "`Failure` was not consumed") + +class StopCancellationTests(TestCase): + """Tests for the `stop_cancellation` function.""" + def test_cancellation(self): """Test that cancellation of the new `Deferred` leaves the original running.""" deferred: "Deferred[str]" = Deferred() @@ -347,11 +369,101 @@ def test_cancellation(self): self.assertTrue(wrapper_deferred.called) self.failureResultOf(wrapper_deferred, CancelledError) self.assertFalse( - deferred.called, "Original `Deferred` was unexpectedly cancelled." + deferred.called, "Original `Deferred` was unexpectedly cancelled" + ) + + # Now make the original `Deferred` fail. + # The `Failure` must be consumed, otherwise unwanted tracebacks will be printed + # in logs. 
+ deferred.errback(ValueError("abc")) + self.assertIsNone(deferred.result, "`Failure` was not consumed") + + +class DelayCancellationTests(TestCase): + """Tests for the `delay_cancellation` function.""" + + def test_cancellation(self): + """Test that cancellation of the new `Deferred` waits for the original.""" + deferred: "Deferred[str]" = Deferred() + wrapper_deferred = delay_cancellation(deferred) + + # Cancel the new `Deferred`. + wrapper_deferred.cancel() + self.assertNoResult(wrapper_deferred) + self.assertFalse( + deferred.called, "Original `Deferred` was unexpectedly cancelled" + ) + + # Now make the original `Deferred` fail. + # The `Failure` must be consumed, otherwise unwanted tracebacks will be printed + # in logs. + deferred.errback(ValueError("abc")) + self.assertIsNone(deferred.result, "`Failure` was not consumed") + + # Now that the original `Deferred` has failed, we should get a `CancelledError`. + self.failureResultOf(wrapper_deferred, CancelledError) + + def test_suppresses_second_cancellation(self): + """Test that a second cancellation is suppressed. + + Identical to `test_cancellation` except the new `Deferred` is cancelled twice. + """ + deferred: "Deferred[str]" = Deferred() + wrapper_deferred = delay_cancellation(deferred) + + # Cancel the new `Deferred`, twice. + wrapper_deferred.cancel() + wrapper_deferred.cancel() + self.assertNoResult(wrapper_deferred) + self.assertFalse( + deferred.called, "Original `Deferred` was unexpectedly cancelled" ) - # Now make the inner `Deferred` fail. + # Now make the original `Deferred` fail. # The `Failure` must be consumed, otherwise unwanted tracebacks will be printed # in logs. deferred.errback(ValueError("abc")) self.assertIsNone(deferred.result, "`Failure` was not consumed") + + # Now that the original `Deferred` has failed, we should get a `CancelledError`. + self.failureResultOf(wrapper_deferred, CancelledError) + + def test_propagates_cancelled_error(self): + """Test that a `CancelledError` from the original `Deferred` gets propagated.""" + deferred: "Deferred[str]" = Deferred() + wrapper_deferred = delay_cancellation(deferred) + + # Fail the original `Deferred` with a `CancelledError`. + cancelled_error = CancelledError() + deferred.errback(cancelled_error) + + # The new `Deferred` should fail with exactly the same `CancelledError`. + self.assertTrue(wrapper_deferred.called) + self.assertIs(cancelled_error, self.failureResultOf(wrapper_deferred).value) + + def test_preserves_logcontext(self): + """Test that logging contexts are preserved.""" + blocking_d: "Deferred[None]" = Deferred() + + async def inner(): + await make_deferred_yieldable(blocking_d) + + async def outer(): + with LoggingContext("c") as c: + try: + await delay_cancellation(defer.ensureDeferred(inner())) + self.fail("`CancelledError` was not raised") + except CancelledError: + self.assertEqual(c, current_context()) + # Succeed with no error, unless the logging context is wrong. + + # Run and block inside `inner()`. + d = defer.ensureDeferred(outer()) + self.assertEqual(SENTINEL_CONTEXT, current_context()) + + d.cancel() + + # Now unblock. `outer()` will consume the `CancelledError` and check the + # logging context. 
+ blocking_d.callback(None) + self.successResultOf(d) From 8e5706d14448c0fe8d1c55eaca38a672c701d7a9 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Mon, 14 Mar 2022 17:52:58 +0000 Subject: [PATCH 060/230] Fix broken background updates when using sqlite with `enable_search` off (#12215) Signed-off-by: Sean Quah --- changelog.d/12215.bugfix | 1 + synapse/storage/databases/main/search.py | 13 +++++++------ 2 files changed, 8 insertions(+), 6 deletions(-) create mode 100644 changelog.d/12215.bugfix diff --git a/changelog.d/12215.bugfix b/changelog.d/12215.bugfix new file mode 100644 index 000000000000..593b12556beb --- /dev/null +++ b/changelog.d/12215.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in 1.54.0 that broke background updates on sqlite homeservers while search was disabled. diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py index e23b1190726b..c5e9010c83b5 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -125,9 +125,6 @@ def __init__( ): super().__init__(database, db_conn, hs) - if not hs.config.server.enable_search: - return - self.db_pool.updates.register_background_update_handler( self.EVENT_SEARCH_UPDATE_NAME, self._background_reindex_search ) @@ -243,9 +240,13 @@ def reindex_search_txn(txn): return len(event_search_rows) - result = await self.db_pool.runInteraction( - self.EVENT_SEARCH_UPDATE_NAME, reindex_search_txn - ) + if self.hs.config.server.enable_search: + result = await self.db_pool.runInteraction( + self.EVENT_SEARCH_UPDATE_NAME, reindex_search_txn + ) + else: + # Don't index anything if search is not enabled. + result = 0 if not result: await self.db_pool.updates._end_background_update( From 605d161d7d585847fd1bb98d14d5281daeac8e86 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Mon, 14 Mar 2022 18:49:07 +0000 Subject: [PATCH 061/230] Add cancellation support to `ReadWriteLock` (#12120) Also convert `ReadWriteLock` to use async context managers. Signed-off-by: Sean Quah --- changelog.d/12120.misc | 1 + synapse/handlers/pagination.py | 8 +- synapse/util/async_helpers.py | 71 +++--- tests/util/test_rwlock.py | 395 ++++++++++++++++++++++++++++----- 4 files changed, 382 insertions(+), 93 deletions(-) create mode 100644 changelog.d/12120.misc diff --git a/changelog.d/12120.misc b/changelog.d/12120.misc new file mode 100644 index 000000000000..360309650032 --- /dev/null +++ b/changelog.d/12120.misc @@ -0,0 +1 @@ +Add support for cancellation to `ReadWriteLock`. diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 183fabcfc09e..60059fec3e0f 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -350,7 +350,7 @@ async def _purge_history( """ self._purges_in_progress_by_room.add(room_id) try: - with await self.pagination_lock.write(room_id): + async with self.pagination_lock.write(room_id): await self.storage.purge_events.purge_history( room_id, token, delete_local_events ) @@ -406,7 +406,7 @@ async def purge_room(self, room_id: str, force: bool = False) -> None: room_id: room to be purged force: set true to skip checking for joined users. 
""" - with await self.pagination_lock.write(room_id): + async with self.pagination_lock.write(room_id): # first check that we have no users in this room if not force: joined = await self.store.is_host_joined(room_id, self._server_name) @@ -448,7 +448,7 @@ async def get_messages( room_token = from_token.room_key - with await self.pagination_lock.read(room_id): + async with self.pagination_lock.read(room_id): ( membership, member_event_id, @@ -615,7 +615,7 @@ async def _shutdown_and_purge_room( self._purges_in_progress_by_room.add(room_id) try: - with await self.pagination_lock.write(room_id): + async with self.pagination_lock.write(room_id): self._delete_by_id[delete_id].status = DeleteStatus.STATUS_SHUTTING_DOWN self._delete_by_id[ delete_id diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 69c8c1baa9fc..6a8e844d6365 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -18,9 +18,10 @@ import inspect import itertools import logging -from contextlib import contextmanager +from contextlib import asynccontextmanager, contextmanager from typing import ( Any, + AsyncIterator, Awaitable, Callable, Collection, @@ -40,7 +41,7 @@ ) import attr -from typing_extensions import ContextManager, Literal +from typing_extensions import AsyncContextManager, Literal from twisted.internet import defer from twisted.internet.defer import CancelledError @@ -491,7 +492,7 @@ class ReadWriteLock: Example: - with await read_write_lock.read("test_key"): + async with read_write_lock.read("test_key"): # do some work """ @@ -514,22 +515,24 @@ def __init__(self) -> None: # Latest writer queued self.key_to_current_writer: Dict[str, defer.Deferred] = {} - async def read(self, key: str) -> ContextManager: - new_defer: "defer.Deferred[None]" = defer.Deferred() + def read(self, key: str) -> AsyncContextManager: + @asynccontextmanager + async def _ctx_manager() -> AsyncIterator[None]: + new_defer: "defer.Deferred[None]" = defer.Deferred() - curr_readers = self.key_to_current_readers.setdefault(key, set()) - curr_writer = self.key_to_current_writer.get(key, None) + curr_readers = self.key_to_current_readers.setdefault(key, set()) + curr_writer = self.key_to_current_writer.get(key, None) - curr_readers.add(new_defer) + curr_readers.add(new_defer) - # We wait for the latest writer to finish writing. We can safely ignore - # any existing readers... as they're readers. - if curr_writer: - await make_deferred_yieldable(curr_writer) - - @contextmanager - def _ctx_manager() -> Iterator[None]: try: + # We wait for the latest writer to finish writing. We can safely ignore + # any existing readers... as they're readers. + # May raise a `CancelledError` if the `Deferred` wrapping us is + # cancelled. The `Deferred` we are waiting on must not be cancelled, + # since we do not own it. 
+ if curr_writer: + await make_deferred_yieldable(stop_cancellation(curr_writer)) yield finally: with PreserveLoggingContext(): @@ -538,29 +541,35 @@ def _ctx_manager() -> Iterator[None]: return _ctx_manager() - async def write(self, key: str) -> ContextManager: - new_defer: "defer.Deferred[None]" = defer.Deferred() + def write(self, key: str) -> AsyncContextManager: + @asynccontextmanager + async def _ctx_manager() -> AsyncIterator[None]: + new_defer: "defer.Deferred[None]" = defer.Deferred() - curr_readers = self.key_to_current_readers.get(key, set()) - curr_writer = self.key_to_current_writer.get(key, None) + curr_readers = self.key_to_current_readers.get(key, set()) + curr_writer = self.key_to_current_writer.get(key, None) - # We wait on all latest readers and writer. - to_wait_on = list(curr_readers) - if curr_writer: - to_wait_on.append(curr_writer) + # We wait on all latest readers and writer. + to_wait_on = list(curr_readers) + if curr_writer: + to_wait_on.append(curr_writer) - # We can clear the list of current readers since the new writer waits - # for them to finish. - curr_readers.clear() - self.key_to_current_writer[key] = new_defer + # We can clear the list of current readers since `new_defer` waits + # for them to finish. + curr_readers.clear() + self.key_to_current_writer[key] = new_defer - await make_deferred_yieldable(defer.gatherResults(to_wait_on)) - - @contextmanager - def _ctx_manager() -> Iterator[None]: + to_wait_on_defer = defer.gatherResults(to_wait_on) try: + # Wait for all current readers and the latest writer to finish. + # May raise a `CancelledError` immediately after the wait if the + # `Deferred` wrapping us is cancelled. We must only release the lock + # once we have acquired it, hence the use of `delay_cancellation` + # rather than `stop_cancellation`. + await make_deferred_yieldable(delay_cancellation(to_wait_on_defer)) yield finally: + # Release the lock. with PreserveLoggingContext(): new_defer.callback(None) # `self.key_to_current_writer[key]` may be missing if there was another diff --git a/tests/util/test_rwlock.py b/tests/util/test_rwlock.py index 0774625b8551..0c842261971f 100644 --- a/tests/util/test_rwlock.py +++ b/tests/util/test_rwlock.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import AsyncContextManager, Callable, Sequence, Tuple + from twisted.internet import defer -from twisted.internet.defer import Deferred +from twisted.internet.defer import CancelledError, Deferred from synapse.util.async_helpers import ReadWriteLock @@ -21,87 +23,187 @@ class ReadWriteLockTestCase(unittest.TestCase): - def _assert_called_before_not_after(self, lst, first_false): - for i, d in enumerate(lst[:first_false]): - self.assertTrue(d.called, msg="%d was unexpectedly false" % i) + def _start_reader_or_writer( + self, + read_or_write: Callable[[str], AsyncContextManager], + key: str, + return_value: str, + ) -> Tuple["Deferred[str]", "Deferred[None]", "Deferred[None]"]: + """Starts a reader or writer which acquires the lock, blocks, then completes. + + Args: + read_or_write: A function returning a context manager for a lock. + Either a bound `ReadWriteLock.read` or `ReadWriteLock.write`. + key: The key to read or write. + return_value: A string that the reader or writer will resolve with when + done. + + Returns: + A tuple of three `Deferred`s: + * A `Deferred` that resolves with `return_value` once the reader or writer + completes successfully. 
+ * A `Deferred` that resolves once the reader or writer acquires the lock. + * A `Deferred` that blocks the reader or writer. Must be resolved by the + caller to allow the reader or writer to release the lock and complete. + """ + acquired_d: "Deferred[None]" = Deferred() + unblock_d: "Deferred[None]" = Deferred() + + async def reader_or_writer(): + async with read_or_write(key): + acquired_d.callback(None) + await unblock_d + return return_value + + d = defer.ensureDeferred(reader_or_writer()) + return d, acquired_d, unblock_d + + def _start_blocking_reader( + self, rwlock: ReadWriteLock, key: str, return_value: str + ) -> Tuple["Deferred[str]", "Deferred[None]", "Deferred[None]"]: + """Starts a reader which acquires the lock, blocks, then releases the lock. + + See the docstring for `_start_reader_or_writer` for details about the arguments + and return values. + """ + return self._start_reader_or_writer(rwlock.read, key, return_value) + + def _start_blocking_writer( + self, rwlock: ReadWriteLock, key: str, return_value: str + ) -> Tuple["Deferred[str]", "Deferred[None]", "Deferred[None]"]: + """Starts a writer which acquires the lock, blocks, then releases the lock. + + See the docstring for `_start_reader_or_writer` for details about the arguments + and return values. + """ + return self._start_reader_or_writer(rwlock.write, key, return_value) + + def _start_nonblocking_reader( + self, rwlock: ReadWriteLock, key: str, return_value: str + ) -> Tuple["Deferred[str]", "Deferred[None]"]: + """Starts a reader which acquires the lock, then releases it immediately. + + See the docstring for `_start_reader_or_writer` for details about the arguments. + + Returns: + A tuple of two `Deferred`s: + * A `Deferred` that resolves with `return_value` once the reader completes + successfully. + * A `Deferred` that resolves once the reader acquires the lock. + """ + d, acquired_d, unblock_d = self._start_reader_or_writer( + rwlock.read, key, return_value + ) + unblock_d.callback(None) + return d, acquired_d + + def _start_nonblocking_writer( + self, rwlock: ReadWriteLock, key: str, return_value: str + ) -> Tuple["Deferred[str]", "Deferred[None]"]: + """Starts a writer which acquires the lock, then releases it immediately. + + See the docstring for `_start_reader_or_writer` for details about the arguments. + + Returns: + A tuple of two `Deferred`s: + * A `Deferred` that resolves with `return_value` once the writer completes + successfully. + * A `Deferred` that resolves once the writer acquires the lock. + """ + d, acquired_d, unblock_d = self._start_reader_or_writer( + rwlock.write, key, return_value + ) + unblock_d.callback(None) + return d, acquired_d + + def _assert_first_n_resolved( + self, deferreds: Sequence["defer.Deferred[None]"], n: int + ) -> None: + """Assert that exactly the first n `Deferred`s in the given list are resolved. - for i, d in enumerate(lst[first_false:]): + Args: + deferreds: The list of `Deferred`s to be checked. + n: The number of `Deferred`s at the start of `deferreds` that should be + resolved. 
+ """ + for i, d in enumerate(deferreds[:n]): + self.assertTrue(d.called, msg="deferred %d was unexpectedly unresolved" % i) + + for i, d in enumerate(deferreds[n:]): self.assertFalse( - d.called, msg="%d was unexpectedly true" % (i + first_false) + d.called, msg="deferred %d was unexpectedly resolved" % (i + n) ) def test_rwlock(self): rwlock = ReadWriteLock() - - key = object() + key = "key" ds = [ - rwlock.read(key), # 0 - rwlock.read(key), # 1 - rwlock.write(key), # 2 - rwlock.write(key), # 3 - rwlock.read(key), # 4 - rwlock.read(key), # 5 - rwlock.write(key), # 6 + self._start_blocking_reader(rwlock, key, "0"), + self._start_blocking_reader(rwlock, key, "1"), + self._start_blocking_writer(rwlock, key, "2"), + self._start_blocking_writer(rwlock, key, "3"), + self._start_blocking_reader(rwlock, key, "4"), + self._start_blocking_reader(rwlock, key, "5"), + self._start_blocking_writer(rwlock, key, "6"), ] - ds = [defer.ensureDeferred(d) for d in ds] + # `Deferred`s that resolve when each reader or writer acquires the lock. + acquired_ds = [acquired_d for _, acquired_d, _ in ds] + # `Deferred`s that will trigger the release of locks when resolved. + release_ds = [release_d for _, _, release_d in ds] - self._assert_called_before_not_after(ds, 2) + # The first two readers should acquire their locks. + self._assert_first_n_resolved(acquired_ds, 2) - with ds[0].result: - self._assert_called_before_not_after(ds, 2) - self._assert_called_before_not_after(ds, 2) + # Release one of the read locks. The next writer should not acquire the lock, + # because there is another reader holding the lock. + self._assert_first_n_resolved(acquired_ds, 2) + release_ds[0].callback(None) + self._assert_first_n_resolved(acquired_ds, 2) - with ds[1].result: - self._assert_called_before_not_after(ds, 2) - self._assert_called_before_not_after(ds, 3) + # Release the other read lock. The next writer should acquire the lock. + self._assert_first_n_resolved(acquired_ds, 2) + release_ds[1].callback(None) + self._assert_first_n_resolved(acquired_ds, 3) - with ds[2].result: - self._assert_called_before_not_after(ds, 3) - self._assert_called_before_not_after(ds, 4) + # Release the write lock. The next writer should acquire the lock. + self._assert_first_n_resolved(acquired_ds, 3) + release_ds[2].callback(None) + self._assert_first_n_resolved(acquired_ds, 4) - with ds[3].result: - self._assert_called_before_not_after(ds, 4) - self._assert_called_before_not_after(ds, 6) + # Release the write lock. The next two readers should acquire locks. + self._assert_first_n_resolved(acquired_ds, 4) + release_ds[3].callback(None) + self._assert_first_n_resolved(acquired_ds, 6) - with ds[5].result: - self._assert_called_before_not_after(ds, 6) - self._assert_called_before_not_after(ds, 6) + # Release one of the read locks. The next writer should not acquire the lock, + # because there is another reader holding the lock. + self._assert_first_n_resolved(acquired_ds, 6) + release_ds[5].callback(None) + self._assert_first_n_resolved(acquired_ds, 6) - with ds[4].result: - self._assert_called_before_not_after(ds, 6) - self._assert_called_before_not_after(ds, 7) + # Release the other read lock. The next writer should acquire the lock. + self._assert_first_n_resolved(acquired_ds, 6) + release_ds[4].callback(None) + self._assert_first_n_resolved(acquired_ds, 7) - with ds[6].result: - pass + # Release the write lock. 
+ release_ds[6].callback(None) - d = defer.ensureDeferred(rwlock.write(key)) - self.assertTrue(d.called) - with d.result: - pass + # Acquire and release the write and read locks one last time for good measure. + _, acquired_d = self._start_nonblocking_writer(rwlock, key, "last writer") + self.assertTrue(acquired_d.called) - d = defer.ensureDeferred(rwlock.read(key)) - self.assertTrue(d.called) - with d.result: - pass + _, acquired_d = self._start_nonblocking_reader(rwlock, key, "last reader") + self.assertTrue(acquired_d.called) def test_lock_handoff_to_nonblocking_writer(self): """Test a writer handing the lock to another writer that completes instantly.""" rwlock = ReadWriteLock() key = "key" - unblock: "Deferred[None]" = Deferred() - - async def blocking_write(): - with await rwlock.write(key): - await unblock - - async def nonblocking_write(): - with await rwlock.write(key): - pass - - d1 = defer.ensureDeferred(blocking_write()) - d2 = defer.ensureDeferred(nonblocking_write()) + d1, _, unblock = self._start_blocking_writer(rwlock, key, "write 1 completed") + d2, _ = self._start_nonblocking_writer(rwlock, key, "write 2 completed") self.assertFalse(d1.called) self.assertFalse(d2.called) @@ -111,5 +213,182 @@ async def nonblocking_write(): self.assertTrue(d2.called) # The `ReadWriteLock` should operate as normal. - d3 = defer.ensureDeferred(nonblocking_write()) + d3, _ = self._start_nonblocking_writer(rwlock, key, "write 3 completed") self.assertTrue(d3.called) + + def test_cancellation_while_holding_read_lock(self): + """Test cancellation while holding a read lock. + + A waiting writer should be given the lock when the reader holding the lock is + cancelled. + """ + rwlock = ReadWriteLock() + key = "key" + + # 1. A reader takes the lock and blocks. + reader_d, _, _ = self._start_blocking_reader(rwlock, key, "read completed") + + # 2. A writer waits for the reader to complete. + writer_d, _ = self._start_nonblocking_writer(rwlock, key, "write completed") + self.assertFalse(writer_d.called) + + # 3. The reader is cancelled. + reader_d.cancel() + self.failureResultOf(reader_d, CancelledError) + + # 4. The writer should take the lock and complete. + self.assertTrue( + writer_d.called, "Writer is stuck waiting for a cancelled reader" + ) + self.assertEqual("write completed", self.successResultOf(writer_d)) + + def test_cancellation_while_holding_write_lock(self): + """Test cancellation while holding a write lock. + + A waiting reader should be given the lock when the writer holding the lock is + cancelled. + """ + rwlock = ReadWriteLock() + key = "key" + + # 1. A writer takes the lock and blocks. + writer_d, _, _ = self._start_blocking_writer(rwlock, key, "write completed") + + # 2. A reader waits for the writer to complete. + reader_d, _ = self._start_nonblocking_reader(rwlock, key, "read completed") + self.assertFalse(reader_d.called) + + # 3. The writer is cancelled. + writer_d.cancel() + self.failureResultOf(writer_d, CancelledError) + + # 4. The reader should take the lock and complete. + self.assertTrue( + reader_d.called, "Reader is stuck waiting for a cancelled writer" + ) + self.assertEqual("read completed", self.successResultOf(reader_d)) + + def test_cancellation_while_waiting_for_read_lock(self): + """Test cancellation while waiting for a read lock. 
+ + Tests that cancelling a waiting reader: + * does not cancel the writer it is waiting on + * does not cancel the next writer waiting on it + * does not allow the next writer to acquire the lock before an earlier writer + has finished + * does not keep the next writer waiting indefinitely + + These correspond to the asserts with explicit messages. + """ + rwlock = ReadWriteLock() + key = "key" + + # 1. A writer takes the lock and blocks. + writer1_d, _, unblock_writer1 = self._start_blocking_writer( + rwlock, key, "write 1 completed" + ) + + # 2. A reader waits for the first writer to complete. + # This reader will be cancelled later. + reader_d, _ = self._start_nonblocking_reader(rwlock, key, "read completed") + self.assertFalse(reader_d.called) + + # 3. A second writer waits for both the first writer and the reader to complete. + writer2_d, _ = self._start_nonblocking_writer(rwlock, key, "write 2 completed") + self.assertFalse(writer2_d.called) + + # 4. The waiting reader is cancelled. + # Neither of the writers should be cancelled. + # The second writer should still be waiting, but only on the first writer. + reader_d.cancel() + self.failureResultOf(reader_d, CancelledError) + self.assertFalse(writer1_d.called, "First writer was unexpectedly cancelled") + self.assertFalse( + writer2_d.called, + "Second writer was unexpectedly cancelled or given the lock before the " + "first writer finished", + ) + + # 5. Unblock the first writer, which should complete. + unblock_writer1.callback(None) + self.assertEqual("write 1 completed", self.successResultOf(writer1_d)) + + # 6. The second writer should take the lock and complete. + self.assertTrue( + writer2_d.called, "Second writer is stuck waiting for a cancelled reader" + ) + self.assertEqual("write 2 completed", self.successResultOf(writer2_d)) + + def test_cancellation_while_waiting_for_write_lock(self): + """Test cancellation while waiting for a write lock. + + Tests that cancelling a waiting writer: + * does not cancel the reader or writer it is waiting on + * does not cancel the next writer waiting on it + * does not allow the next writer to acquire the lock before an earlier reader + and writer have finished + * does not keep the next writer waiting indefinitely + + These correspond to the asserts with explicit messages. + """ + rwlock = ReadWriteLock() + key = "key" + + # 1. A reader takes the lock and blocks. + reader_d, _, unblock_reader = self._start_blocking_reader( + rwlock, key, "read completed" + ) + + # 2. A writer waits for the reader to complete. + writer1_d, _, unblock_writer1 = self._start_blocking_writer( + rwlock, key, "write 1 completed" + ) + + # 3. A second writer waits for both the reader and first writer to complete. + # This writer will be cancelled later. + writer2_d, _ = self._start_nonblocking_writer(rwlock, key, "write 2 completed") + self.assertFalse(writer2_d.called) + + # 4. A third writer waits for the second writer to complete. + writer3_d, _ = self._start_nonblocking_writer(rwlock, key, "write 3 completed") + self.assertFalse(writer3_d.called) + + # 5. The second writer is cancelled, but continues waiting for the lock. + # The reader, first writer and third writer should not be cancelled. + # The first writer should still be waiting on the reader. + # The third writer should still be waiting on the second writer. 
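+        # (Illustrative aside, not part of the original patch: the cancelled
+        # writer still takes and then immediately releases the lock once its turn
+        # comes, because `write()` uses `delay_cancellation` rather than
+        # `stop_cancellation`; step 8 below checks this.)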
+ writer2_d.cancel() + self.assertNoResult(writer2_d) + self.assertFalse(reader_d.called, "Reader was unexpectedly cancelled") + self.assertFalse(writer1_d.called, "First writer was unexpectedly cancelled") + self.assertFalse( + writer3_d.called, + "Third writer was unexpectedly cancelled or given the lock before the first " + "writer finished", + ) + + # 6. Unblock the reader, which should complete. + # The first writer should be given the lock and block. + # The third writer should still be waiting on the second writer. + unblock_reader.callback(None) + self.assertEqual("read completed", self.successResultOf(reader_d)) + self.assertNoResult(writer2_d) + self.assertFalse( + writer3_d.called, + "Third writer was unexpectedly given the lock before the first writer " + "finished", + ) + + # 7. Unblock the first writer, which should complete. + unblock_writer1.callback(None) + self.assertEqual("write 1 completed", self.successResultOf(writer1_d)) + + # 8. The second writer should take the lock and release it immediately, since it + # has been cancelled. + self.failureResultOf(writer2_d, CancelledError) + + # 9. The third writer should take the lock and complete. + self.assertTrue( + writer3_d.called, "Third writer is stuck waiting for a cancelled writer" + ) + self.assertEqual("write 3 completed", self.successResultOf(writer3_d)) From 2fcf4b3f6cd2a0be6597622664636d2219957c2a Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Mon, 14 Mar 2022 19:04:29 +0000 Subject: [PATCH 062/230] Add cancellation support to `@cached` and `@cachedList` decorators (#12183) These decorators mostly support cancellation already. Add cancellation tests and fix use of finished logging contexts by delaying cancellation, as suggested by @erikjohnston. Signed-off-by: Sean Quah --- changelog.d/12183.misc | 1 + synapse/util/caches/descriptors.py | 11 ++ tests/util/caches/test_descriptors.py | 147 +++++++++++++++++++++++++- 3 files changed, 157 insertions(+), 2 deletions(-) create mode 100644 changelog.d/12183.misc diff --git a/changelog.d/12183.misc b/changelog.d/12183.misc new file mode 100644 index 000000000000..dd441bb64ff7 --- /dev/null +++ b/changelog.d/12183.misc @@ -0,0 +1 @@ +Add cancellation support to `@cached` and `@cachedList` decorators. diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index c3c5c16db96e..eda92d864dea 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -41,6 +41,7 @@ from synapse.logging.context import make_deferred_yieldable, preserve_fn from synapse.util import unwrapFirstError +from synapse.util.async_helpers import delay_cancellation from synapse.util.caches.deferred_cache import DeferredCache from synapse.util.caches.lrucache import LruCache @@ -350,6 +351,11 @@ def _wrapped(*args: Any, **kwargs: Any) -> Any: ret = defer.maybeDeferred(preserve_fn(self.orig), obj, *args, **kwargs) ret = cache.set(cache_key, ret, callback=invalidate_callback) + # We started a new call to `self.orig`, so we must always wait for it to + # complete. Otherwise we might mark our current logging context as + # finished while `self.orig` is still using it in the background. 
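+        # Illustrative sketch, not part of the original patch: with this wrapper,
+        #
+        #     d1 = obj.fn(123)
+        #     d2 = obj.fn(123)
+        #     d1.cancel()
+        #
+        # fails only `d1` with `CancelledError`; the underlying lookup keeps
+        # running and `d2` still receives the result, as the new `test_cancel`
+        # below verifies.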
+ ret = delay_cancellation(ret) + return make_deferred_yieldable(ret) wrapped = cast(_CachedFunction, _wrapped) @@ -510,6 +516,11 @@ def errback_all(f: Failure) -> None: d = defer.gatherResults(cached_defers, consumeErrors=True).addCallbacks( lambda _: results, unwrapFirstError ) + if missing: + # We started a new call to `self.orig`, so we must always wait for it to + # complete. Otherwise we might mark our current logging context as + # finished while `self.orig` is still using it in the background. + d = delay_cancellation(d) return make_deferred_yieldable(d) else: return defer.succeed(results) diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 6a4b17527a7f..48e616ac7419 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -17,7 +17,7 @@ from unittest import mock from twisted.internet import defer, reactor -from twisted.internet.defer import Deferred +from twisted.internet.defer import CancelledError, Deferred from synapse.api.errors import SynapseError from synapse.logging.context import ( @@ -28,7 +28,7 @@ make_deferred_yieldable, ) from synapse.util.caches import descriptors -from synapse.util.caches.descriptors import cached, lru_cache +from synapse.util.caches.descriptors import cached, cachedList, lru_cache from tests import unittest from tests.test_utils import get_awaitable_result @@ -493,6 +493,74 @@ def func3(self, key, cache_context): obj.invalidate() top_invalidate.assert_called_once() + def test_cancel(self): + """Test that cancelling a lookup does not cancel other lookups""" + complete_lookup: "Deferred[None]" = Deferred() + + class Cls: + @cached() + async def fn(self, arg1): + await complete_lookup + return str(arg1) + + obj = Cls() + + d1 = obj.fn(123) + d2 = obj.fn(123) + self.assertFalse(d1.called) + self.assertFalse(d2.called) + + # Cancel `d1`, which is the lookup that caused `fn` to run. + d1.cancel() + + # `d2` should complete normally. + complete_lookup.callback(None) + self.failureResultOf(d1, CancelledError) + self.assertEqual(d2.result, "123") + + def test_cancel_logcontexts(self): + """Test that cancellation does not break logcontexts. + + * The `CancelledError` must be raised with the correct logcontext. + * The inner lookup must not resume with a finished logcontext. + * The inner lookup must not restore a finished logcontext when done. 
+ """ + complete_lookup: "Deferred[None]" = Deferred() + + class Cls: + inner_context_was_finished = False + + @cached() + async def fn(self, arg1): + await make_deferred_yieldable(complete_lookup) + self.inner_context_was_finished = current_context().finished + return str(arg1) + + obj = Cls() + + async def do_lookup(): + with LoggingContext("c1") as c1: + try: + await obj.fn(123) + self.fail("No CancelledError thrown") + except CancelledError: + self.assertEqual( + current_context(), + c1, + "CancelledError was not raised with the correct logcontext", + ) + # suppress the error and succeed + + d = defer.ensureDeferred(do_lookup()) + d.cancel() + + complete_lookup.callback(None) + self.successResultOf(d) + self.assertFalse( + obj.inner_context_was_finished, "Tried to restart a finished logcontext" + ) + self.assertEqual(current_context(), SENTINEL_CONTEXT) + class CacheDecoratorTestCase(unittest.HomeserverTestCase): """More tests for @cached @@ -865,3 +933,78 @@ async def list_fn(self, args1, arg2): obj.fn.invalidate((10, 2)) invalidate0.assert_called_once() invalidate1.assert_called_once() + + def test_cancel(self): + """Test that cancelling a lookup does not cancel other lookups""" + complete_lookup: "Deferred[None]" = Deferred() + + class Cls: + @cached() + def fn(self, arg1): + pass + + @cachedList(cached_method_name="fn", list_name="args") + async def list_fn(self, args): + await complete_lookup + return {arg: str(arg) for arg in args} + + obj = Cls() + + d1 = obj.list_fn([123, 456]) + d2 = obj.list_fn([123, 456, 789]) + self.assertFalse(d1.called) + self.assertFalse(d2.called) + + d1.cancel() + + # `d2` should complete normally. + complete_lookup.callback(None) + self.failureResultOf(d1, CancelledError) + self.assertEqual(d2.result, {123: "123", 456: "456", 789: "789"}) + + def test_cancel_logcontexts(self): + """Test that cancellation does not break logcontexts. + + * The `CancelledError` must be raised with the correct logcontext. + * The inner lookup must not resume with a finished logcontext. + * The inner lookup must not restore a finished logcontext when done. 
+ """ + complete_lookup: "Deferred[None]" = Deferred() + + class Cls: + inner_context_was_finished = False + + @cached() + def fn(self, arg1): + pass + + @cachedList(cached_method_name="fn", list_name="args") + async def list_fn(self, args): + await make_deferred_yieldable(complete_lookup) + self.inner_context_was_finished = current_context().finished + return {arg: str(arg) for arg in args} + + obj = Cls() + + async def do_lookup(): + with LoggingContext("c1") as c1: + try: + await obj.list_fn([123]) + self.fail("No CancelledError thrown") + except CancelledError: + self.assertEqual( + current_context(), + c1, + "CancelledError was not raised with the correct logcontext", + ) + # suppress the error and succeed + + d = defer.ensureDeferred(do_lookup()) + d.cancel() + + complete_lookup.callback(None) + self.successResultOf(d) + self.assertFalse( + obj.inner_context_was_finished, "Tried to restart a finished logcontext" + ) + self.assertEqual(current_context(), SENTINEL_CONTEXT) From d1130a249b4f462a3e457b783b483d5a6c7486f0 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 15 Mar 2022 11:00:01 +0000 Subject: [PATCH 063/230] 1.55.0rc1 --- CHANGES.md | 82 +++++++++++++++++++++++++++++++++++++++ changelog.d/11700.removal | 1 - changelog.d/11915.misc | 1 - changelog.d/11980.misc | 1 - changelog.d/11998.doc | 1 - changelog.d/12028.feature | 1 - changelog.d/12042.misc | 1 - changelog.d/12090.bugfix | 1 - changelog.d/12101.misc | 1 - changelog.d/12108.misc | 1 - changelog.d/12113.bugfix | 1 - changelog.d/12118.misc | 1 - changelog.d/12120.misc | 1 - changelog.d/12121.bugfix | 1 - changelog.d/12128.misc | 1 - changelog.d/12130.bugfix | 1 - changelog.d/12131.misc | 1 - changelog.d/12132.feature | 1 - changelog.d/12135.feature | 1 - changelog.d/12136.misc | 1 - changelog.d/12137.misc | 1 - changelog.d/12138.removal | 1 - changelog.d/12140.misc | 1 - changelog.d/12142.misc | 1 - changelog.d/12143.doc | 1 - changelog.d/12144.misc | 1 - changelog.d/12145.misc | 1 - changelog.d/12146.misc | 1 - changelog.d/12149.misc | 1 - changelog.d/12150.misc | 1 - changelog.d/12151.feature | 1 - changelog.d/12152.misc | 1 - changelog.d/12153.misc | 1 - changelog.d/12154.misc | 1 - changelog.d/12155.misc | 1 - changelog.d/12156.misc | 1 - changelog.d/12157.bugfix | 1 - changelog.d/12159.misc | 1 - changelog.d/12161.misc | 1 - changelog.d/12163.misc | 1 - changelog.d/12173.misc | 1 - changelog.d/12175.bugfix | 1 - changelog.d/12179.doc | 1 - changelog.d/12180.misc | 1 - changelog.d/12182.misc | 1 - changelog.d/12183.misc | 1 - changelog.d/12187.misc | 1 - changelog.d/12188.misc | 1 - changelog.d/12189.bugfix | 1 - changelog.d/12192.misc | 1 - changelog.d/12196.doc | 1 - changelog.d/12197.misc | 1 - changelog.d/12200.removal | 1 - changelog.d/12202.misc | 1 - changelog.d/12203.misc | 1 - changelog.d/12204.doc | 1 - changelog.d/12206.misc | 1 - changelog.d/12207.misc | 1 - changelog.d/12208.misc | 1 - changelog.d/12210.misc | 1 - changelog.d/12211.misc | 1 - changelog.d/12212.feature | 1 - changelog.d/12215.bugfix | 1 - debian/changelog | 6 +++ synapse/__init__.py | 2 +- 65 files changed, 89 insertions(+), 63 deletions(-) delete mode 100644 changelog.d/11700.removal delete mode 100644 changelog.d/11915.misc delete mode 100644 changelog.d/11980.misc delete mode 100644 changelog.d/11998.doc delete mode 100644 changelog.d/12028.feature delete mode 100644 changelog.d/12042.misc delete mode 100644 changelog.d/12090.bugfix delete mode 100644 changelog.d/12101.misc delete mode 100644 changelog.d/12108.misc delete mode 
100644 changelog.d/12113.bugfix delete mode 100644 changelog.d/12118.misc delete mode 100644 changelog.d/12120.misc delete mode 100644 changelog.d/12121.bugfix delete mode 100644 changelog.d/12128.misc delete mode 100644 changelog.d/12130.bugfix delete mode 100644 changelog.d/12131.misc delete mode 100644 changelog.d/12132.feature delete mode 100644 changelog.d/12135.feature delete mode 100644 changelog.d/12136.misc delete mode 100644 changelog.d/12137.misc delete mode 100644 changelog.d/12138.removal delete mode 100644 changelog.d/12140.misc delete mode 100644 changelog.d/12142.misc delete mode 100644 changelog.d/12143.doc delete mode 100644 changelog.d/12144.misc delete mode 100644 changelog.d/12145.misc delete mode 100644 changelog.d/12146.misc delete mode 100644 changelog.d/12149.misc delete mode 100644 changelog.d/12150.misc delete mode 100644 changelog.d/12151.feature delete mode 100644 changelog.d/12152.misc delete mode 100644 changelog.d/12153.misc delete mode 100644 changelog.d/12154.misc delete mode 100644 changelog.d/12155.misc delete mode 100644 changelog.d/12156.misc delete mode 100644 changelog.d/12157.bugfix delete mode 100644 changelog.d/12159.misc delete mode 100644 changelog.d/12161.misc delete mode 100644 changelog.d/12163.misc delete mode 100644 changelog.d/12173.misc delete mode 100644 changelog.d/12175.bugfix delete mode 100644 changelog.d/12179.doc delete mode 100644 changelog.d/12180.misc delete mode 100644 changelog.d/12182.misc delete mode 100644 changelog.d/12183.misc delete mode 100644 changelog.d/12187.misc delete mode 100644 changelog.d/12188.misc delete mode 100644 changelog.d/12189.bugfix delete mode 100644 changelog.d/12192.misc delete mode 100644 changelog.d/12196.doc delete mode 100644 changelog.d/12197.misc delete mode 100644 changelog.d/12200.removal delete mode 100644 changelog.d/12202.misc delete mode 100644 changelog.d/12203.misc delete mode 100644 changelog.d/12204.doc delete mode 100644 changelog.d/12206.misc delete mode 100644 changelog.d/12207.misc delete mode 100644 changelog.d/12208.misc delete mode 100644 changelog.d/12210.misc delete mode 100644 changelog.d/12211.misc delete mode 100644 changelog.d/12212.feature delete mode 100644 changelog.d/12215.bugfix diff --git a/CHANGES.md b/CHANGES.md index ef671e73f178..b0311f73bf84 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,3 +1,85 @@ +Synapse 1.55.0rc1 (2022-03-15) +============================== + +Features +-------- + +- Add third-party rules rules callbacks `check_can_shutdown_room` and `check_can_deactivate_user`. ([\#12028](https://github.com/matrix-org/synapse/issues/12028)) +- Improve performance of logging in for large accounts. ([\#12132](https://github.com/matrix-org/synapse/issues/12132)) +- Add experimental env var `SYNAPSE_ASYNC_IO_REACTOR` that causes Synapse to use the asyncio reactor for Twisted. ([\#12135](https://github.com/matrix-org/synapse/issues/12135)) +- Support the stable identifiers from [MSC3440](https://github.com/matrix-org/matrix-doc/pull/3440): threads. ([\#12151](https://github.com/matrix-org/synapse/issues/12151)) +- Add a new Jinja2 template filter to extract the local part of an email address. ([\#12212](https://github.com/matrix-org/synapse/issues/12212)) + + +Bugfixes +-------- + +- Use the proper serialization format for bundled thread aggregations. The bug has existed since Synapse v1.48.0. 
([\#12090](https://github.com/matrix-org/synapse/issues/12090)) +- Fix a long-standing bug when redacting events with relations. ([\#12113](https://github.com/matrix-org/synapse/issues/12113), [\#12121](https://github.com/matrix-org/synapse/issues/12121), [\#12130](https://github.com/matrix-org/synapse/issues/12130), [\#12189](https://github.com/matrix-org/synapse/issues/12189)) +- Fix a bug introduced in #4864 whereby background updates are never run with the default background batch size. ([\#12157](https://github.com/matrix-org/synapse/issues/12157)) +- Fix a bug where non-standard information was returned from the `/hierarchy` API. Introduced in Synapse v1.41.0. ([\#12175](https://github.com/matrix-org/synapse/issues/12175)) +- Fix a bug introduced in 1.54.0 that broke background updates on sqlite homeservers while search was disabled. ([\#12215](https://github.com/matrix-org/synapse/issues/12215)) + + +Improved Documentation +---------------------- + +- Fix complexity checking config example in [Resource Constrained Devices](https://matrix-org.github.io/synapse/v1.54/other/running_synapse_on_single_board_computers.html) docs page. ([\#11998](https://github.com/matrix-org/synapse/issues/11998)) +- Improve documentation for demo scripts. ([\#12143](https://github.com/matrix-org/synapse/issues/12143)) +- Updates to the Room DAG concepts development document. ([\#12179](https://github.com/matrix-org/synapse/issues/12179)) +- Document that the `typing`, `to_device`, `account_data`, `receipts`, and `presence` stream writer can only be used on a single worker. ([\#12196](https://github.com/matrix-org/synapse/issues/12196)) +- Document that contributors can sign off privately by email. ([\#12204](https://github.com/matrix-org/synapse/issues/12204)) + + +Deprecations and Removals +------------------------- + +- Remove workaround introduced in Synapse 1.50.0 for Mjolnir compatibility. Breaks compatibility with Mjolnir 1.3.1 and earlier. ([\#11700](https://github.com/matrix-org/synapse/issues/11700)) +- Remove backwards compatibilty with pagination tokens from the `/relations` and `/aggregations` endpoints generated from Synapse < v1.52.0. ([\#12138](https://github.com/matrix-org/synapse/issues/12138)) +- The groups/communities feature in Synapse has been deprecated. ([\#12200](https://github.com/matrix-org/synapse/issues/12200)) + + +Internal Changes +---------------- + +- Simplify the `ApplicationService` class' set of public methods related to interest checking. ([\#11915](https://github.com/matrix-org/synapse/issues/11915)) +- Add config settings for background update parameters. ([\#11980](https://github.com/matrix-org/synapse/issues/11980)) +- Correct type hints for txredis. ([\#12042](https://github.com/matrix-org/synapse/issues/12042)) +- Limit the size of `aggregation_key` on annotations. ([\#12101](https://github.com/matrix-org/synapse/issues/12101)) +- Add type hints to tests files. ([\#12108](https://github.com/matrix-org/synapse/issues/12108), [\#12146](https://github.com/matrix-org/synapse/issues/12146), [\#12207](https://github.com/matrix-org/synapse/issues/12207), [\#12208](https://github.com/matrix-org/synapse/issues/12208)) +- Move scripts to Synapse package and expose as setuptools entry points. 
([\#12118](https://github.com/matrix-org/synapse/issues/12118)) +- Add support for cancellation to `ReadWriteLock`. ([\#12120](https://github.com/matrix-org/synapse/issues/12120)) +- Fix data validation to compare to lists, not sequences. ([\#12128](https://github.com/matrix-org/synapse/issues/12128)) +- Fix CI not attaching source distributions and wheels to the GitHub releases. ([\#12131](https://github.com/matrix-org/synapse/issues/12131)) +- Remove unused mocks from `test_typing`. ([\#12136](https://github.com/matrix-org/synapse/issues/12136)) +- Give `scripts-dev` scripts suffixes for neater CI config. ([\#12137](https://github.com/matrix-org/synapse/issues/12137)) +- Move `synctl` into `synapse._scripts` and expose as an entry point. ([\#12140](https://github.com/matrix-org/synapse/issues/12140)) +- Move the snapcraft configuration file to `contrib`. ([\#12142](https://github.com/matrix-org/synapse/issues/12142)) +- Enable [MSC3030](https://github.com/matrix-org/matrix-doc/pull/3030) Complement tests in CI. ([\#12144](https://github.com/matrix-org/synapse/issues/12144)) +- Enable [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) Complement tests in CI. ([\#12145](https://github.com/matrix-org/synapse/issues/12145)) +- Add test for `ObservableDeferred`'s cancellation behaviour. ([\#12149](https://github.com/matrix-org/synapse/issues/12149)) +- Use `ParamSpec` in type hints for `synapse.logging.context`. ([\#12150](https://github.com/matrix-org/synapse/issues/12150)) +- Prune unused jobs from `tox` config. ([\#12152](https://github.com/matrix-org/synapse/issues/12152)) +- Move CI checks out of tox, to facilitate a move to using poetry. ([\#12153](https://github.com/matrix-org/synapse/issues/12153)) +- Avoid generating state groups for local out-of-band leaves. ([\#12154](https://github.com/matrix-org/synapse/issues/12154)) +- Avoid trying to calculate the state at outlier events. ([\#12155](https://github.com/matrix-org/synapse/issues/12155), [\#12173](https://github.com/matrix-org/synapse/issues/12173), [\#12202](https://github.com/matrix-org/synapse/issues/12202)) +- Fix some type annotations. ([\#12156](https://github.com/matrix-org/synapse/issues/12156)) +- Add type hints for `ObservableDeferred` attributes. ([\#12159](https://github.com/matrix-org/synapse/issues/12159)) +- Use a prebuilt Action for the `tests-done` CI job. ([\#12161](https://github.com/matrix-org/synapse/issues/12161)) +- Reduce number of DB queries made during processing of `/sync`. ([\#12163](https://github.com/matrix-org/synapse/issues/12163)) +- Add `delay_cancellation` utility function, which behaves like `stop_cancellation` but waits until the original `Deferred` resolves before raising a `CancelledError`. ([\#12180](https://github.com/matrix-org/synapse/issues/12180)) +- Retry HTTP replication failures, this should prevent 502's when restarting stateful workers (main, event persisters, stream writers). Contributed by Nick @ Beeper. ([\#12182](https://github.com/matrix-org/synapse/issues/12182)) +- Add cancellation support to `@cached` and `@cachedList` decorators. ([\#12183](https://github.com/matrix-org/synapse/issues/12183)) +- Remove unused variables. ([\#12187](https://github.com/matrix-org/synapse/issues/12187)) +- Add combined test for HTTP pusher and push rule. Contributed by Nick @ Beeper. 
([\#12188](https://github.com/matrix-org/synapse/issues/12188)) +- Rename `HomeServer.get_tcp_replication` to `get_replication_command_handler`. ([\#12192](https://github.com/matrix-org/synapse/issues/12192)) +- Remove some dead code. ([\#12197](https://github.com/matrix-org/synapse/issues/12197)) +- Fix a misleading comment in the function `check_event_for_spam`. ([\#12203](https://github.com/matrix-org/synapse/issues/12203)) +- Remove unnecessary `pass` statements. ([\#12206](https://github.com/matrix-org/synapse/issues/12206)) +- Update the SSO username picker template to comply with SIWA guidelines. ([\#12210](https://github.com/matrix-org/synapse/issues/12210)) +- Improve code documentation for the typing stream over replication. ([\#12211](https://github.com/matrix-org/synapse/issues/12211)) + + Synapse 1.54.0 (2022-03-08) =========================== diff --git a/changelog.d/11700.removal b/changelog.d/11700.removal deleted file mode 100644 index d3d3c48f0fc4..000000000000 --- a/changelog.d/11700.removal +++ /dev/null @@ -1 +0,0 @@ -Remove workaround introduced in Synapse 1.50.0 for Mjolnir compatibility. Breaks compatibility with Mjolnir 1.3.1 and earlier. diff --git a/changelog.d/11915.misc b/changelog.d/11915.misc deleted file mode 100644 index e3cef1511eb6..000000000000 --- a/changelog.d/11915.misc +++ /dev/null @@ -1 +0,0 @@ -Simplify the `ApplicationService` class' set of public methods related to interest checking. \ No newline at end of file diff --git a/changelog.d/11980.misc b/changelog.d/11980.misc deleted file mode 100644 index 36e992e645a3..000000000000 --- a/changelog.d/11980.misc +++ /dev/null @@ -1 +0,0 @@ -Add config settings for background update parameters. \ No newline at end of file diff --git a/changelog.d/11998.doc b/changelog.d/11998.doc deleted file mode 100644 index 33ab7b7880be..000000000000 --- a/changelog.d/11998.doc +++ /dev/null @@ -1 +0,0 @@ -Fix complexity checking config example in [Resource Constrained Devices](https://matrix-org.github.io/synapse/v1.54/other/running_synapse_on_single_board_computers.html) docs page. \ No newline at end of file diff --git a/changelog.d/12028.feature b/changelog.d/12028.feature deleted file mode 100644 index 5549c8f6fcf6..000000000000 --- a/changelog.d/12028.feature +++ /dev/null @@ -1 +0,0 @@ -Add third-party rules rules callbacks `check_can_shutdown_room` and `check_can_deactivate_user`. diff --git a/changelog.d/12042.misc b/changelog.d/12042.misc deleted file mode 100644 index 6ecdc960210c..000000000000 --- a/changelog.d/12042.misc +++ /dev/null @@ -1 +0,0 @@ -Correct type hints for txredis. diff --git a/changelog.d/12090.bugfix b/changelog.d/12090.bugfix deleted file mode 100644 index 087065dcb1cd..000000000000 --- a/changelog.d/12090.bugfix +++ /dev/null @@ -1 +0,0 @@ -Use the proper serialization format for bundled thread aggregations. The bug has existed since Synapse v1.48.0. diff --git a/changelog.d/12101.misc b/changelog.d/12101.misc deleted file mode 100644 index d165f73d13e8..000000000000 --- a/changelog.d/12101.misc +++ /dev/null @@ -1 +0,0 @@ -Limit the size of `aggregation_key` on annotations. diff --git a/changelog.d/12108.misc b/changelog.d/12108.misc deleted file mode 100644 index b67a701dbb52..000000000000 --- a/changelog.d/12108.misc +++ /dev/null @@ -1 +0,0 @@ -Add type hints to tests files. 
diff --git a/changelog.d/12113.bugfix b/changelog.d/12113.bugfix deleted file mode 100644 index df9b0dc413dd..000000000000 --- a/changelog.d/12113.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a long-standing bug when redacting events with relations. diff --git a/changelog.d/12118.misc b/changelog.d/12118.misc deleted file mode 100644 index a2c397d90755..000000000000 --- a/changelog.d/12118.misc +++ /dev/null @@ -1 +0,0 @@ -Move scripts to Synapse package and expose as setuptools entry points. diff --git a/changelog.d/12120.misc b/changelog.d/12120.misc deleted file mode 100644 index 360309650032..000000000000 --- a/changelog.d/12120.misc +++ /dev/null @@ -1 +0,0 @@ -Add support for cancellation to `ReadWriteLock`. diff --git a/changelog.d/12121.bugfix b/changelog.d/12121.bugfix deleted file mode 100644 index df9b0dc413dd..000000000000 --- a/changelog.d/12121.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a long-standing bug when redacting events with relations. diff --git a/changelog.d/12128.misc b/changelog.d/12128.misc deleted file mode 100644 index 0570a8e3272f..000000000000 --- a/changelog.d/12128.misc +++ /dev/null @@ -1 +0,0 @@ -Fix data validation to compare to lists, not sequences. diff --git a/changelog.d/12130.bugfix b/changelog.d/12130.bugfix deleted file mode 100644 index df9b0dc413dd..000000000000 --- a/changelog.d/12130.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a long-standing bug when redacting events with relations. diff --git a/changelog.d/12131.misc b/changelog.d/12131.misc deleted file mode 100644 index 8ef23c22d524..000000000000 --- a/changelog.d/12131.misc +++ /dev/null @@ -1 +0,0 @@ -Fix CI not attaching source distributions and wheels to the GitHub releases. \ No newline at end of file diff --git a/changelog.d/12132.feature b/changelog.d/12132.feature deleted file mode 100644 index 3b8362ad35ed..000000000000 --- a/changelog.d/12132.feature +++ /dev/null @@ -1 +0,0 @@ -Improve performance of logging in for large accounts. diff --git a/changelog.d/12135.feature b/changelog.d/12135.feature deleted file mode 100644 index b337f51730e6..000000000000 --- a/changelog.d/12135.feature +++ /dev/null @@ -1 +0,0 @@ -Add experimental env var `SYNAPSE_ASYNC_IO_REACTOR` that causes Synapse to use the asyncio reactor for Twisted. diff --git a/changelog.d/12136.misc b/changelog.d/12136.misc deleted file mode 100644 index 98b1c1c9d8ac..000000000000 --- a/changelog.d/12136.misc +++ /dev/null @@ -1 +0,0 @@ -Remove unused mocks from `test_typing`. \ No newline at end of file diff --git a/changelog.d/12137.misc b/changelog.d/12137.misc deleted file mode 100644 index 118ff77a91c6..000000000000 --- a/changelog.d/12137.misc +++ /dev/null @@ -1 +0,0 @@ -Give `scripts-dev` scripts suffixes for neater CI config. \ No newline at end of file diff --git a/changelog.d/12138.removal b/changelog.d/12138.removal deleted file mode 100644 index 6ed84d476cd9..000000000000 --- a/changelog.d/12138.removal +++ /dev/null @@ -1 +0,0 @@ -Remove backwards compatibilty with pagination tokens from the `/relations` and `/aggregations` endpoints generated from Synapse < v1.52.0. diff --git a/changelog.d/12140.misc b/changelog.d/12140.misc deleted file mode 100644 index 33a21a29f0f4..000000000000 --- a/changelog.d/12140.misc +++ /dev/null @@ -1 +0,0 @@ -Move `synctl` into `synapse._scripts` and expose as an entry point. 
\ No newline at end of file diff --git a/changelog.d/12142.misc b/changelog.d/12142.misc deleted file mode 100644 index 5d09f90b5244..000000000000 --- a/changelog.d/12142.misc +++ /dev/null @@ -1 +0,0 @@ -Move the snapcraft configuration file to `contrib`. \ No newline at end of file diff --git a/changelog.d/12143.doc b/changelog.d/12143.doc deleted file mode 100644 index 4b9db74b1fc9..000000000000 --- a/changelog.d/12143.doc +++ /dev/null @@ -1 +0,0 @@ -Improve documentation for demo scripts. diff --git a/changelog.d/12144.misc b/changelog.d/12144.misc deleted file mode 100644 index d8f71bb203eb..000000000000 --- a/changelog.d/12144.misc +++ /dev/null @@ -1 +0,0 @@ -Enable [MSC3030](https://github.com/matrix-org/matrix-doc/pull/3030) Complement tests in CI. diff --git a/changelog.d/12145.misc b/changelog.d/12145.misc deleted file mode 100644 index 4092a2d66e45..000000000000 --- a/changelog.d/12145.misc +++ /dev/null @@ -1 +0,0 @@ -Enable [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) Complement tests in CI. diff --git a/changelog.d/12146.misc b/changelog.d/12146.misc deleted file mode 100644 index b67a701dbb52..000000000000 --- a/changelog.d/12146.misc +++ /dev/null @@ -1 +0,0 @@ -Add type hints to tests files. diff --git a/changelog.d/12149.misc b/changelog.d/12149.misc deleted file mode 100644 index d39af9672365..000000000000 --- a/changelog.d/12149.misc +++ /dev/null @@ -1 +0,0 @@ -Add test for `ObservableDeferred`'s cancellation behaviour. diff --git a/changelog.d/12150.misc b/changelog.d/12150.misc deleted file mode 100644 index 2d2706dac769..000000000000 --- a/changelog.d/12150.misc +++ /dev/null @@ -1 +0,0 @@ -Use `ParamSpec` in type hints for `synapse.logging.context`. diff --git a/changelog.d/12151.feature b/changelog.d/12151.feature deleted file mode 100644 index 18432b2da9a5..000000000000 --- a/changelog.d/12151.feature +++ /dev/null @@ -1 +0,0 @@ -Support the stable identifiers from [MSC3440](https://github.com/matrix-org/matrix-doc/pull/3440): threads. diff --git a/changelog.d/12152.misc b/changelog.d/12152.misc deleted file mode 100644 index b9877eaccbee..000000000000 --- a/changelog.d/12152.misc +++ /dev/null @@ -1 +0,0 @@ -Prune unused jobs from `tox` config. \ No newline at end of file diff --git a/changelog.d/12153.misc b/changelog.d/12153.misc deleted file mode 100644 index f02d140f3871..000000000000 --- a/changelog.d/12153.misc +++ /dev/null @@ -1 +0,0 @@ -Move CI checks out of tox, to facilitate a move to using poetry. \ No newline at end of file diff --git a/changelog.d/12154.misc b/changelog.d/12154.misc deleted file mode 100644 index 18d2a4728be9..000000000000 --- a/changelog.d/12154.misc +++ /dev/null @@ -1 +0,0 @@ -Avoid generating state groups for local out-of-band leaves. diff --git a/changelog.d/12155.misc b/changelog.d/12155.misc deleted file mode 100644 index 9f333e718a86..000000000000 --- a/changelog.d/12155.misc +++ /dev/null @@ -1 +0,0 @@ -Avoid trying to calculate the state at outlier events. diff --git a/changelog.d/12156.misc b/changelog.d/12156.misc deleted file mode 100644 index 4818d988d771..000000000000 --- a/changelog.d/12156.misc +++ /dev/null @@ -1 +0,0 @@ -Fix some type annotations. diff --git a/changelog.d/12157.bugfix b/changelog.d/12157.bugfix deleted file mode 100644 index c3d2e700bb1d..000000000000 --- a/changelog.d/12157.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a bug introduced in #4864 whereby background updates are never run with the default background batch size. 
diff --git a/changelog.d/12159.misc b/changelog.d/12159.misc deleted file mode 100644 index 30500f2fd95d..000000000000 --- a/changelog.d/12159.misc +++ /dev/null @@ -1 +0,0 @@ -Add type hints for `ObservableDeferred` attributes. diff --git a/changelog.d/12161.misc b/changelog.d/12161.misc deleted file mode 100644 index 43eff08d467e..000000000000 --- a/changelog.d/12161.misc +++ /dev/null @@ -1 +0,0 @@ -Use a prebuilt Action for the `tests-done` CI job. diff --git a/changelog.d/12163.misc b/changelog.d/12163.misc deleted file mode 100644 index 13de0895f5fa..000000000000 --- a/changelog.d/12163.misc +++ /dev/null @@ -1 +0,0 @@ -Reduce number of DB queries made during processing of `/sync`. diff --git a/changelog.d/12173.misc b/changelog.d/12173.misc deleted file mode 100644 index 9f333e718a86..000000000000 --- a/changelog.d/12173.misc +++ /dev/null @@ -1 +0,0 @@ -Avoid trying to calculate the state at outlier events. diff --git a/changelog.d/12175.bugfix b/changelog.d/12175.bugfix deleted file mode 100644 index 881cb9b76c20..000000000000 --- a/changelog.d/12175.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a bug where non-standard information was returned from the `/hierarchy` API. Introduced in Synapse v1.41.0. diff --git a/changelog.d/12179.doc b/changelog.d/12179.doc deleted file mode 100644 index 55d8caa45a8c..000000000000 --- a/changelog.d/12179.doc +++ /dev/null @@ -1 +0,0 @@ -Updates to the Room DAG concepts development document. diff --git a/changelog.d/12180.misc b/changelog.d/12180.misc deleted file mode 100644 index 7a347352fd91..000000000000 --- a/changelog.d/12180.misc +++ /dev/null @@ -1 +0,0 @@ -Add `delay_cancellation` utility function, which behaves like `stop_cancellation` but waits until the original `Deferred` resolves before raising a `CancelledError`. diff --git a/changelog.d/12182.misc b/changelog.d/12182.misc deleted file mode 100644 index 7e9ad2c75244..000000000000 --- a/changelog.d/12182.misc +++ /dev/null @@ -1 +0,0 @@ -Retry HTTP replication failures, this should prevent 502's when restarting stateful workers (main, event persisters, stream writers). Contributed by Nick @ Beeper. diff --git a/changelog.d/12183.misc b/changelog.d/12183.misc deleted file mode 100644 index dd441bb64ff7..000000000000 --- a/changelog.d/12183.misc +++ /dev/null @@ -1 +0,0 @@ -Add cancellation support to `@cached` and `@cachedList` decorators. diff --git a/changelog.d/12187.misc b/changelog.d/12187.misc deleted file mode 100644 index c53e68faa508..000000000000 --- a/changelog.d/12187.misc +++ /dev/null @@ -1 +0,0 @@ -Remove unused variables. diff --git a/changelog.d/12188.misc b/changelog.d/12188.misc deleted file mode 100644 index 403158481cee..000000000000 --- a/changelog.d/12188.misc +++ /dev/null @@ -1 +0,0 @@ -Add combined test for HTTP pusher and push rule. Contributed by Nick @ Beeper. diff --git a/changelog.d/12189.bugfix b/changelog.d/12189.bugfix deleted file mode 100644 index df9b0dc413dd..000000000000 --- a/changelog.d/12189.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a long-standing bug when redacting events with relations. diff --git a/changelog.d/12192.misc b/changelog.d/12192.misc deleted file mode 100644 index bdfe8dad98a6..000000000000 --- a/changelog.d/12192.misc +++ /dev/null @@ -1 +0,0 @@ -Rename `HomeServer.get_tcp_replication` to `get_replication_command_handler`. 
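Another removed entry above (changelog.d/12180.misc) describes `delay_cancellation` as behaving like `stop_cancellation` except that the `CancelledError` is only raised once the original `Deferred` resolves. A rough usage sketch of that contract follows; the import path, and the assumption that the helper takes and returns a `Deferred`, are not confirmed by this patch series:

```python
from twisted.internet import defer

# Assumed import path; this patch set does not show where the helper lives.
from synapse.util.async_helpers import delay_cancellation


def demo() -> None:
    underlying: "defer.Deferred[str]" = defer.Deferred()
    protected = delay_cancellation(underlying)  # assumed to wrap a Deferred

    protected.cancel()
    # Per the description above, the cancellation does not propagate to
    # `underlying`, and `protected` has not failed yet.
    underlying.callback("done")
    # Only now does `protected` errback with CancelledError.
    protected.addErrback(lambda f: f.trap(defer.CancelledError))
```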
diff --git a/changelog.d/12196.doc b/changelog.d/12196.doc deleted file mode 100644 index 269f06aa3386..000000000000 --- a/changelog.d/12196.doc +++ /dev/null @@ -1 +0,0 @@ -Document that the `typing`, `to_device`, `account_data`, `receipts`, and `presence` stream writer can only be used on a single worker. \ No newline at end of file diff --git a/changelog.d/12197.misc b/changelog.d/12197.misc deleted file mode 100644 index 7d0e9b6bbf4c..000000000000 --- a/changelog.d/12197.misc +++ /dev/null @@ -1 +0,0 @@ -Remove some dead code. diff --git a/changelog.d/12200.removal b/changelog.d/12200.removal deleted file mode 100644 index 312c7ae32597..000000000000 --- a/changelog.d/12200.removal +++ /dev/null @@ -1 +0,0 @@ -The groups/communities feature in Synapse has been deprecated. diff --git a/changelog.d/12202.misc b/changelog.d/12202.misc deleted file mode 100644 index 9f333e718a86..000000000000 --- a/changelog.d/12202.misc +++ /dev/null @@ -1 +0,0 @@ -Avoid trying to calculate the state at outlier events. diff --git a/changelog.d/12203.misc b/changelog.d/12203.misc deleted file mode 100644 index 892dc5bfb7e3..000000000000 --- a/changelog.d/12203.misc +++ /dev/null @@ -1 +0,0 @@ -Fix a misleading comment in the function `check_event_for_spam`. diff --git a/changelog.d/12204.doc b/changelog.d/12204.doc deleted file mode 100644 index c4b2805bb112..000000000000 --- a/changelog.d/12204.doc +++ /dev/null @@ -1 +0,0 @@ -Document that contributors can sign off privately by email. diff --git a/changelog.d/12206.misc b/changelog.d/12206.misc deleted file mode 100644 index df59bb56cdb8..000000000000 --- a/changelog.d/12206.misc +++ /dev/null @@ -1 +0,0 @@ -Remove unnecessary `pass` statements. diff --git a/changelog.d/12207.misc b/changelog.d/12207.misc deleted file mode 100644 index b67a701dbb52..000000000000 --- a/changelog.d/12207.misc +++ /dev/null @@ -1 +0,0 @@ -Add type hints to tests files. diff --git a/changelog.d/12208.misc b/changelog.d/12208.misc deleted file mode 100644 index c5b635679931..000000000000 --- a/changelog.d/12208.misc +++ /dev/null @@ -1 +0,0 @@ -Add type hints to tests files. \ No newline at end of file diff --git a/changelog.d/12210.misc b/changelog.d/12210.misc deleted file mode 100644 index 3f6a8747c256..000000000000 --- a/changelog.d/12210.misc +++ /dev/null @@ -1 +0,0 @@ -Update the SSO username picker template to comply with SIWA guidelines. diff --git a/changelog.d/12211.misc b/changelog.d/12211.misc deleted file mode 100644 index d11634a1ee0f..000000000000 --- a/changelog.d/12211.misc +++ /dev/null @@ -1 +0,0 @@ -Improve code documentation for the typing stream over replication. \ No newline at end of file diff --git a/changelog.d/12212.feature b/changelog.d/12212.feature deleted file mode 100644 index fe337ff99057..000000000000 --- a/changelog.d/12212.feature +++ /dev/null @@ -1 +0,0 @@ -Add a new Jinja2 template filter to extract the local part of an email address. diff --git a/changelog.d/12215.bugfix b/changelog.d/12215.bugfix deleted file mode 100644 index 593b12556beb..000000000000 --- a/changelog.d/12215.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a bug introduced in 1.54.0 that broke background updates on sqlite homeservers while search was disabled. diff --git a/debian/changelog b/debian/changelog index 02136a0d606f..09ef24ebb0d3 100644 --- a/debian/changelog +++ b/debian/changelog @@ -1,3 +1,9 @@ +matrix-synapse-py3 (1.55.0~rc1) stable; urgency=medium + + * New synapse release 1.55.0~rc1. 
+ + -- Synapse Packaging team Tue, 15 Mar 2022 10:59:31 +0000 + matrix-synapse-py3 (1.54.0) stable; urgency=medium * New synapse release 1.54.0. diff --git a/synapse/__init__.py b/synapse/__init__.py index 4b0056597692..870707f47643 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -68,7 +68,7 @@ except ImportError: pass -__version__ = "1.54.0" +__version__ = "1.55.0rc1" if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): # We import here so that we don't have to install a bunch of deps when From 9e90d643e6947c2fa4286c21a6351061cf510026 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 15 Mar 2022 11:16:36 +0000 Subject: [PATCH 064/230] Changelog tweaks --- CHANGES.md | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index b0311f73bf84..60e7ecb1b9c1 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,10 +1,12 @@ Synapse 1.55.0rc1 (2022-03-15) ============================== +This release removes a workaround introduced in Synapse 1.50.0 for Mjolnir compatibility. **This breaks compatibility with Mjolnir 1.3.1 and earlier. ([\#11700](https://github.com/matrix-org/synapse/issues/11700))**; Mjolnir users should upgrade Mjolnir before upgrading Synapse to this version. + Features -------- -- Add third-party rules rules callbacks `check_can_shutdown_room` and `check_can_deactivate_user`. ([\#12028](https://github.com/matrix-org/synapse/issues/12028)) +- Add third-party rules callbacks `check_can_shutdown_room` and `check_can_deactivate_user`. ([\#12028](https://github.com/matrix-org/synapse/issues/12028)) - Improve performance of logging in for large accounts. ([\#12132](https://github.com/matrix-org/synapse/issues/12132)) - Add experimental env var `SYNAPSE_ASYNC_IO_REACTOR` that causes Synapse to use the asyncio reactor for Twisted. ([\#12135](https://github.com/matrix-org/synapse/issues/12135)) - Support the stable identifiers from [MSC3440](https://github.com/matrix-org/matrix-doc/pull/3440): threads. ([\#12151](https://github.com/matrix-org/synapse/issues/12151)) @@ -16,9 +18,9 @@ Bugfixes - Use the proper serialization format for bundled thread aggregations. The bug has existed since Synapse v1.48.0. ([\#12090](https://github.com/matrix-org/synapse/issues/12090)) - Fix a long-standing bug when redacting events with relations. ([\#12113](https://github.com/matrix-org/synapse/issues/12113), [\#12121](https://github.com/matrix-org/synapse/issues/12121), [\#12130](https://github.com/matrix-org/synapse/issues/12130), [\#12189](https://github.com/matrix-org/synapse/issues/12189)) -- Fix a bug introduced in #4864 whereby background updates are never run with the default background batch size. ([\#12157](https://github.com/matrix-org/synapse/issues/12157)) +- Fix a bug introduced in Synapse 1.7.2 whereby background updates are never run with the default background batch size. ([\#12157](https://github.com/matrix-org/synapse/issues/12157)) - Fix a bug where non-standard information was returned from the `/hierarchy` API. Introduced in Synapse v1.41.0. ([\#12175](https://github.com/matrix-org/synapse/issues/12175)) -- Fix a bug introduced in 1.54.0 that broke background updates on sqlite homeservers while search was disabled. ([\#12215](https://github.com/matrix-org/synapse/issues/12215)) +- Fix a bug introduced in Synapse 1.54.0 that broke background updates on sqlite homeservers while search was disabled. 
([\#12215](https://github.com/matrix-org/synapse/issues/12215)) Improved Documentation @@ -34,7 +36,7 @@ Improved Documentation Deprecations and Removals ------------------------- -- Remove workaround introduced in Synapse 1.50.0 for Mjolnir compatibility. Breaks compatibility with Mjolnir 1.3.1 and earlier. ([\#11700](https://github.com/matrix-org/synapse/issues/11700)) +- **Remove workaround introduced in Synapse 1.50.0 for Mjolnir compatibility. Breaks compatibility with Mjolnir 1.3.1 and earlier. ([\#11700](https://github.com/matrix-org/synapse/issues/11700))** - Remove backwards compatibilty with pagination tokens from the `/relations` and `/aggregations` endpoints generated from Synapse < v1.52.0. ([\#12138](https://github.com/matrix-org/synapse/issues/12138)) - The groups/communities feature in Synapse has been deprecated. ([\#12200](https://github.com/matrix-org/synapse/issues/12200)) From 5dd949bee6158a8b651db9f2ae417a62c8184bfd Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Tue, 15 Mar 2022 14:16:37 +0100 Subject: [PATCH 065/230] Add type hints to some tests/handlers files. (#12224) --- changelog.d/12224.misc | 1 + mypy.ini | 5 -- tests/handlers/test_directory.py | 84 +++++++++++++++------------- tests/handlers/test_e2e_keys.py | 36 ++++++------ tests/handlers/test_oidc.py | 94 +++++++++++++++++--------------- tests/handlers/test_profile.py | 43 ++++++++------- tests/handlers/test_saml.py | 24 ++++---- 7 files changed, 156 insertions(+), 131 deletions(-) create mode 100644 changelog.d/12224.misc diff --git a/changelog.d/12224.misc b/changelog.d/12224.misc new file mode 100644 index 000000000000..b67a701dbb52 --- /dev/null +++ b/changelog.d/12224.misc @@ -0,0 +1 @@ +Add type hints to tests files. diff --git a/mypy.ini b/mypy.ini index f9c39fcaaee3..fe31bfb8bb37 100644 --- a/mypy.ini +++ b/mypy.ini @@ -67,13 +67,8 @@ exclude = (?x) |tests/federation/transport/test_knocking.py |tests/federation/transport/test_server.py |tests/handlers/test_cas.py - |tests/handlers/test_directory.py - |tests/handlers/test_e2e_keys.py |tests/handlers/test_federation.py - |tests/handlers/test_oidc.py |tests/handlers/test_presence.py - |tests/handlers/test_profile.py - |tests/handlers/test_saml.py |tests/handlers/test_typing.py |tests/http/federation/test_matrix_federation_agent.py |tests/http/federation/test_srv_resolver.py diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index 6e403a87c5d0..11ad44223d39 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -12,14 +12,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- +from typing import Any, Awaitable, Callable, Dict from unittest.mock import Mock +from twisted.test.proto_helpers import MemoryReactor + import synapse.api.errors import synapse.rest.admin from synapse.api.constants import EventTypes from synapse.rest.client import directory, login, room -from synapse.types import RoomAlias, create_requester +from synapse.server import HomeServer +from synapse.types import JsonDict, RoomAlias, create_requester +from synapse.util import Clock from tests import unittest from tests.test_utils import make_awaitable @@ -28,13 +32,15 @@ class DirectoryTestCase(unittest.HomeserverTestCase): """Tests the directory service.""" - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.mock_federation = Mock() self.mock_registry = Mock() - self.query_handlers = {} + self.query_handlers: Dict[str, Callable[[dict], Awaitable[JsonDict]]] = {} - def register_query_handler(query_type, handler): + def register_query_handler( + query_type: str, handler: Callable[[dict], Awaitable[JsonDict]] + ) -> None: self.query_handlers[query_type] = handler self.mock_registry.register_query_handler = register_query_handler @@ -54,7 +60,7 @@ def register_query_handler(query_type, handler): return hs - def test_get_local_association(self): + def test_get_local_association(self) -> None: self.get_success( self.store.create_room_alias_association( self.my_room, "!8765qwer:test", ["test"] @@ -65,7 +71,7 @@ def test_get_local_association(self): self.assertEqual({"room_id": "!8765qwer:test", "servers": ["test"]}, result) - def test_get_remote_association(self): + def test_get_remote_association(self) -> None: self.mock_federation.make_query.return_value = make_awaitable( {"room_id": "!8765qwer:test", "servers": ["test", "remote"]} ) @@ -83,7 +89,7 @@ def test_get_remote_association(self): ignore_backoff=True, ) - def test_incoming_fed_query(self): + def test_incoming_fed_query(self) -> None: self.get_success( self.store.create_room_alias_association( self.your_room, "!8765asdf:test", ["test"] @@ -105,7 +111,7 @@ class TestCreateAlias(unittest.HomeserverTestCase): directory.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.handler = hs.get_directory_handler() # Create user @@ -125,7 +131,7 @@ def prepare(self, reactor, clock, hs): self.test_user_tok = self.login("user", "pass") self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok) - def test_create_alias_joined_room(self): + def test_create_alias_joined_room(self) -> None: """A user can create an alias for a room they're in.""" self.get_success( self.handler.create_association( @@ -135,7 +141,7 @@ def test_create_alias_joined_room(self): ) ) - def test_create_alias_other_room(self): + def test_create_alias_other_room(self) -> None: """A user cannot create an alias for a room they're NOT in.""" other_room_id = self.helper.create_room_as( self.admin_user, tok=self.admin_user_tok @@ -150,7 +156,7 @@ def test_create_alias_other_room(self): synapse.api.errors.SynapseError, ) - def test_create_alias_admin(self): + def test_create_alias_admin(self) -> None: """An admin can create an alias for a room they're NOT in.""" other_room_id = self.helper.create_room_as( self.test_user, tok=self.test_user_tok @@ -173,7 +179,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase): directory.register_servlets, ] - def prepare(self, reactor, clock, hs): + def 
prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main self.handler = hs.get_directory_handler() self.state_handler = hs.get_state_handler() @@ -195,7 +201,7 @@ def prepare(self, reactor, clock, hs): self.test_user_tok = self.login("user", "pass") self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok) - def _create_alias(self, user): + def _create_alias(self, user) -> None: # Create a new alias to this room. self.get_success( self.store.create_room_alias_association( @@ -203,7 +209,7 @@ def _create_alias(self, user): ) ) - def test_delete_alias_not_allowed(self): + def test_delete_alias_not_allowed(self) -> None: """A user that doesn't meet the expected guidelines cannot delete an alias.""" self._create_alias(self.admin_user) self.get_failure( @@ -213,7 +219,7 @@ def test_delete_alias_not_allowed(self): synapse.api.errors.AuthError, ) - def test_delete_alias_creator(self): + def test_delete_alias_creator(self) -> None: """An alias creator can delete their own alias.""" # Create an alias from a different user. self._create_alias(self.test_user) @@ -232,7 +238,7 @@ def test_delete_alias_creator(self): synapse.api.errors.SynapseError, ) - def test_delete_alias_admin(self): + def test_delete_alias_admin(self) -> None: """A server admin can delete an alias created by another user.""" # Create an alias from a different user. self._create_alias(self.test_user) @@ -251,7 +257,7 @@ def test_delete_alias_admin(self): synapse.api.errors.SynapseError, ) - def test_delete_alias_sufficient_power(self): + def test_delete_alias_sufficient_power(self) -> None: """A user with a sufficient power level should be able to delete an alias.""" self._create_alias(self.admin_user) @@ -288,7 +294,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase): directory.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main self.handler = hs.get_directory_handler() self.state_handler = hs.get_state_handler() @@ -317,7 +323,7 @@ def _add_alias(self, alias: str) -> RoomAlias: ) return room_alias - def _set_canonical_alias(self, content): + def _set_canonical_alias(self, content) -> None: """Configure the canonical alias state on the room.""" self.helper.send_state( self.room_id, @@ -334,7 +340,7 @@ def _get_canonical_alias(self): ) ) - def test_remove_alias(self): + def test_remove_alias(self) -> None: """Removing an alias that is the canonical alias should remove it there too.""" # Set this new alias as the canonical alias for this room self._set_canonical_alias( @@ -356,7 +362,7 @@ def test_remove_alias(self): self.assertNotIn("alias", data["content"]) self.assertNotIn("alt_aliases", data["content"]) - def test_remove_other_alias(self): + def test_remove_other_alias(self) -> None: """Removing an alias listed as in alt_aliases should remove it there too.""" # Create a second alias. other_test_alias = "#test2:test" @@ -393,7 +399,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase): servlets = [directory.register_servlets, room.register_servlets] - def default_config(self): + def default_config(self) -> Dict[str, Any]: config = super().default_config() # Add custom alias creation rules to the config. 
@@ -403,7 +409,7 @@ def default_config(self): return config - def test_denied(self): + def test_denied(self) -> None: room_id = self.helper.create_room_as(self.user_id) channel = self.make_request( @@ -413,7 +419,7 @@ def test_denied(self): ) self.assertEqual(403, channel.code, channel.result) - def test_allowed(self): + def test_allowed(self) -> None: room_id = self.helper.create_room_as(self.user_id) channel = self.make_request( @@ -423,7 +429,7 @@ def test_allowed(self): ) self.assertEqual(200, channel.code, channel.result) - def test_denied_during_creation(self): + def test_denied_during_creation(self) -> None: """A room alias that is not allowed should be rejected during creation.""" # Invalid room alias. self.helper.create_room_as( @@ -432,7 +438,7 @@ def test_denied_during_creation(self): extra_content={"room_alias_name": "foo"}, ) - def test_allowed_during_creation(self): + def test_allowed_during_creation(self) -> None: """A valid room alias should be allowed during creation.""" room_id = self.helper.create_room_as( self.user_id, @@ -459,7 +465,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase): data = {"room_alias_name": "unofficial_test"} allowed_localpart = "allowed" - def default_config(self): + def default_config(self) -> Dict[str, Any]: config = super().default_config() # Add custom room list publication rules to the config. @@ -474,7 +480,9 @@ def default_config(self): return config - def prepare(self, reactor, clock, hs): + def prepare( + self, reactor: MemoryReactor, clock: Clock, hs: HomeServer + ) -> HomeServer: self.allowed_user_id = self.register_user(self.allowed_localpart, "pass") self.allowed_access_token = self.login(self.allowed_localpart, "pass") @@ -483,7 +491,7 @@ def prepare(self, reactor, clock, hs): return hs - def test_denied_without_publication_permission(self): + def test_denied_without_publication_permission(self) -> None: """ Try to create a room, register an alias for it, and publish it, as a user without permission to publish rooms. @@ -497,7 +505,7 @@ def test_denied_without_publication_permission(self): expect_code=403, ) - def test_allowed_when_creating_private_room(self): + def test_allowed_when_creating_private_room(self) -> None: """ Try to create a room, register an alias for it, and NOT publish it, as a user without permission to publish rooms. @@ -511,7 +519,7 @@ def test_allowed_when_creating_private_room(self): expect_code=200, ) - def test_allowed_with_publication_permission(self): + def test_allowed_with_publication_permission(self) -> None: """ Try to create a room, register an alias for it, and publish it, as a user WITH permission to publish rooms. @@ -525,7 +533,7 @@ def test_allowed_with_publication_permission(self): expect_code=200, ) - def test_denied_publication_with_invalid_alias(self): + def test_denied_publication_with_invalid_alias(self) -> None: """ Try to create a room, register an alias for it, and publish it, as a user WITH permission to publish rooms. @@ -538,7 +546,7 @@ def test_denied_publication_with_invalid_alias(self): expect_code=403, ) - def test_can_create_as_private_room_after_rejection(self): + def test_can_create_as_private_room_after_rejection(self) -> None: """ After failing to publish a room with an alias as a user without publish permission, retry as the same user, but without publishing the room. 
@@ -549,7 +557,7 @@ def test_can_create_as_private_room_after_rejection(self): self.test_denied_without_publication_permission() self.test_allowed_when_creating_private_room() - def test_can_create_with_permission_after_rejection(self): + def test_can_create_with_permission_after_rejection(self) -> None: """ After failing to publish a room with an alias as a user without publish permission, retry as someone with permission, using the same alias. @@ -566,7 +574,9 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase): servlets = [directory.register_servlets, room.register_servlets] - def prepare(self, reactor, clock, hs): + def prepare( + self, reactor: MemoryReactor, clock: Clock, hs: HomeServer + ) -> HomeServer: room_id = self.helper.create_room_as(self.user_id) channel = self.make_request( @@ -579,7 +589,7 @@ def prepare(self, reactor, clock, hs): return hs - def test_disabling_room_list(self): + def test_disabling_room_list(self) -> None: self.room_list_handler.enable_room_list_search = True self.directory_handler.enable_room_list_search = True diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 9338ab92e98e..ac21a28c4331 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -20,33 +20,37 @@ from signedjson import key as key, sign as sign from twisted.internet import defer +from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import RoomEncryptionAlgorithms from synapse.api.errors import Codes, SynapseError +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util import Clock from tests import unittest from tests.test_utils import make_awaitable class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: return self.setup_test_homeserver(federation_client=mock.Mock()) - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.handler = hs.get_e2e_keys_handler() self.store = self.hs.get_datastores().main - def test_query_local_devices_no_devices(self): + def test_query_local_devices_no_devices(self) -> None: """If the user has no devices, we expect an empty list.""" local_user = "@boris:" + self.hs.hostname res = self.get_success(self.handler.query_local_devices({local_user: None})) self.assertDictEqual(res, {local_user: {}}) - def test_reupload_one_time_keys(self): + def test_reupload_one_time_keys(self) -> None: """we should be able to re-upload the same keys""" local_user = "@boris:" + self.hs.hostname device_id = "xyz" - keys = { + keys: JsonDict = { "alg1:k1": "key1", "alg2:k2": {"key": "key2", "signatures": {"k1": "sig1"}}, "alg2:k3": {"key": "key3"}, @@ -74,7 +78,7 @@ def test_reupload_one_time_keys(self): res, {"one_time_key_counts": {"alg1": 1, "alg2": 2, "signed_curve25519": 0}} ) - def test_change_one_time_keys(self): + def test_change_one_time_keys(self) -> None: """attempts to change one-time-keys should be rejected""" local_user = "@boris:" + self.hs.hostname @@ -134,7 +138,7 @@ def test_change_one_time_keys(self): SynapseError, ) - def test_claim_one_time_key(self): + def test_claim_one_time_key(self) -> None: local_user = "@boris:" + self.hs.hostname device_id = "xyz" keys = {"alg1:k1": "key1"} @@ -161,7 +165,7 @@ def test_claim_one_time_key(self): }, ) - def test_fallback_key(self): + def test_fallback_key(self) -> None: local_user = 
"@boris:" + self.hs.hostname device_id = "xyz" fallback_key = {"alg1:k1": "fallback_key1"} @@ -294,7 +298,7 @@ def test_fallback_key(self): {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key3}}}, ) - def test_replace_master_key(self): + def test_replace_master_key(self) -> None: """uploading a new signing key should make the old signing key unavailable""" local_user = "@boris:" + self.hs.hostname keys1 = { @@ -328,7 +332,7 @@ def test_replace_master_key(self): ) self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]}) - def test_reupload_signatures(self): + def test_reupload_signatures(self) -> None: """re-uploading a signature should not fail""" local_user = "@boris:" + self.hs.hostname keys1 = { @@ -433,7 +437,7 @@ def test_reupload_signatures(self): self.assertDictEqual(devices["device_keys"][local_user]["abc"], device_key_1) self.assertDictEqual(devices["device_keys"][local_user]["def"], device_key_2) - def test_self_signing_key_doesnt_show_up_as_device(self): + def test_self_signing_key_doesnt_show_up_as_device(self) -> None: """signing keys should be hidden when fetching a user's devices""" local_user = "@boris:" + self.hs.hostname keys1 = { @@ -462,7 +466,7 @@ def test_self_signing_key_doesnt_show_up_as_device(self): res = self.get_success(self.handler.query_local_devices({local_user: None})) self.assertDictEqual(res, {local_user: {}}) - def test_upload_signatures(self): + def test_upload_signatures(self) -> None: """should check signatures that are uploaded""" # set up a user with cross-signing keys and a device. This user will # try uploading signatures @@ -686,7 +690,7 @@ def test_upload_signatures(self): other_master_key["signatures"][local_user]["ed25519:" + usersigning_pubkey], ) - def test_query_devices_remote_no_sync(self): + def test_query_devices_remote_no_sync(self) -> None: """Tests that querying keys for a remote user that we don't share a room with returns the cross signing keys correctly. """ @@ -759,7 +763,7 @@ def test_query_devices_remote_no_sync(self): }, ) - def test_query_devices_remote_sync(self): + def test_query_devices_remote_sync(self) -> None: """Tests that querying keys for a remote user that we share a room with, but haven't yet fetched the keys for, returns the cross signing keys correctly. @@ -845,7 +849,7 @@ def test_query_devices_remote_sync(self): (["device_1", "device_2"],), ] ) - def test_query_all_devices_caches_result(self, device_ids: Iterable[str]): + def test_query_all_devices_caches_result(self, device_ids: Iterable[str]) -> None: """Test that requests for all of a remote user's devices are cached. We do this by asserting that only one call over federation was made, and that @@ -853,7 +857,7 @@ def test_query_all_devices_caches_result(self, device_ids: Iterable[str]): """ local_user_id = "@test:test" remote_user_id = "@test:other" - request_body = {"device_keys": {remote_user_id: []}} + request_body: JsonDict = {"device_keys": {remote_user_id: []}} response_devices = [ { diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index e8418b6638e4..014815db6ee4 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -13,14 +13,18 @@ # limitations under the License. 
import json import os +from typing import Any, Dict from unittest.mock import ANY, Mock, patch from urllib.parse import parse_qs, urlparse import pymacaroons +from twisted.test.proto_helpers import MemoryReactor + from synapse.handlers.sso import MappingException from synapse.server import HomeServer -from synapse.types import UserID +from synapse.types import JsonDict, UserID +from synapse.util import Clock from synapse.util.macaroons import get_value_from_macaroon from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock @@ -98,7 +102,7 @@ async def map_user_attributes(self, userinfo, token, failures): } -async def get_json(url): +async def get_json(url: str) -> JsonDict: # Mock get_json calls to handle jwks & oidc discovery endpoints if url == WELL_KNOWN: # Minimal discovery document, as defined in OpenID.Discovery @@ -116,6 +120,8 @@ async def get_json(url): elif url == JWKS_URI: return {"keys": []} + return {} + def _key_file_path() -> str: """path to a file containing the private half of a test key""" @@ -147,12 +153,12 @@ class OidcHandlerTestCase(HomeserverTestCase): if not HAS_OIDC: skip = "requires OIDC" - def default_config(self): + def default_config(self) -> Dict[str, Any]: config = super().default_config() config["public_baseurl"] = BASE_URL return config - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.http_client = Mock(spec=["get_json"]) self.http_client.get_json.side_effect = get_json self.http_client.user_agent = b"Synapse Test" @@ -164,7 +170,7 @@ def make_homeserver(self, reactor, clock): sso_handler = hs.get_sso_handler() # Mock the render error method. self.render_error = Mock(return_value=None) - sso_handler.render_error = self.render_error + sso_handler.render_error = self.render_error # type: ignore[assignment] # Reduce the number of attempts when generating MXIDs. 
sso_handler._MAP_USERNAME_RETRIES = 3 @@ -193,14 +199,14 @@ def assertRenderedError(self, error, error_description=None): return args @override_config({"oidc_config": DEFAULT_CONFIG}) - def test_config(self): + def test_config(self) -> None: """Basic config correctly sets up the callback URL and client auth correctly.""" self.assertEqual(self.provider._callback_url, CALLBACK_URL) self.assertEqual(self.provider._client_auth.client_id, CLIENT_ID) self.assertEqual(self.provider._client_auth.client_secret, CLIENT_SECRET) @override_config({"oidc_config": {**DEFAULT_CONFIG, "discover": True}}) - def test_discovery(self): + def test_discovery(self) -> None: """The handler should discover the endpoints from OIDC discovery document.""" # This would throw if some metadata were invalid metadata = self.get_success(self.provider.load_metadata()) @@ -219,13 +225,13 @@ def test_discovery(self): self.http_client.get_json.assert_not_called() @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG}) - def test_no_discovery(self): + def test_no_discovery(self) -> None: """When discovery is disabled, it should not try to load from discovery document.""" self.get_success(self.provider.load_metadata()) self.http_client.get_json.assert_not_called() @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG}) - def test_load_jwks(self): + def test_load_jwks(self) -> None: """JWKS loading is done once (then cached) if used.""" jwks = self.get_success(self.provider.load_jwks()) self.http_client.get_json.assert_called_once_with(JWKS_URI) @@ -253,7 +259,7 @@ async def patched_load_metadata(): self.get_failure(self.provider.load_jwks(force=True), RuntimeError) @override_config({"oidc_config": DEFAULT_CONFIG}) - def test_validate_config(self): + def test_validate_config(self) -> None: """Provider metadatas are extensively validated.""" h = self.provider @@ -336,14 +342,14 @@ async def force_load(): force_load_metadata() @override_config({"oidc_config": {**DEFAULT_CONFIG, "skip_verification": True}}) - def test_skip_verification(self): + def test_skip_verification(self) -> None: """Provider metadata validation can be disabled by config.""" with self.metadata_edit({"issuer": "http://insecure"}): # This should not throw get_awaitable_result(self.provider.load_metadata()) @override_config({"oidc_config": DEFAULT_CONFIG}) - def test_redirect_request(self): + def test_redirect_request(self) -> None: """The redirect request has the right arguments & generates a valid session cookie.""" req = Mock(spec=["cookies"]) req.cookies = [] @@ -387,7 +393,7 @@ def test_redirect_request(self): self.assertEqual(redirect, "http://client/redirect") @override_config({"oidc_config": DEFAULT_CONFIG}) - def test_callback_error(self): + def test_callback_error(self) -> None: """Errors from the provider returned in the callback are displayed.""" request = Mock(args={}) request.args[b"error"] = [b"invalid_client"] @@ -399,7 +405,7 @@ def test_callback_error(self): self.assertRenderedError("invalid_client", "some description") @override_config({"oidc_config": DEFAULT_CONFIG}) - def test_callback(self): + def test_callback(self) -> None: """Code callback works and display errors if something went wrong. 
A lot of scenarios are tested here: @@ -428,9 +434,9 @@ def test_callback(self): "username": username, } expected_user_id = "@%s:%s" % (username, self.hs.hostname) - self.provider._exchange_code = simple_async_mock(return_value=token) - self.provider._parse_id_token = simple_async_mock(return_value=userinfo) - self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo) + self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment] + self.provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment] + self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo) # type: ignore[assignment] auth_handler = self.hs.get_auth_handler() auth_handler.complete_sso_login = simple_async_mock() @@ -468,7 +474,7 @@ def test_callback(self): self.assertRenderedError("mapping_error") # Handle ID token errors - self.provider._parse_id_token = simple_async_mock(raises=Exception()) + self.provider._parse_id_token = simple_async_mock(raises=Exception()) # type: ignore[assignment] self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("invalid_token") @@ -483,7 +489,7 @@ def test_callback(self): "type": "bearer", "access_token": "access_token", } - self.provider._exchange_code = simple_async_mock(return_value=token) + self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment] self.get_success(self.handler.handle_oidc_callback(request)) auth_handler.complete_sso_login.assert_called_once_with( @@ -510,8 +516,8 @@ def test_callback(self): id_token = { "sid": "abcdefgh", } - self.provider._parse_id_token = simple_async_mock(return_value=id_token) - self.provider._exchange_code = simple_async_mock(return_value=token) + self.provider._parse_id_token = simple_async_mock(return_value=id_token) # type: ignore[assignment] + self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment] auth_handler.complete_sso_login.reset_mock() self.provider._fetch_userinfo.reset_mock() self.get_success(self.handler.handle_oidc_callback(request)) @@ -531,21 +537,21 @@ def test_callback(self): self.render_error.assert_not_called() # Handle userinfo fetching error - self.provider._fetch_userinfo = simple_async_mock(raises=Exception()) + self.provider._fetch_userinfo = simple_async_mock(raises=Exception()) # type: ignore[assignment] self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("fetch_error") # Handle code exchange failure from synapse.handlers.oidc import OidcError - self.provider._exchange_code = simple_async_mock( + self.provider._exchange_code = simple_async_mock( # type: ignore[assignment] raises=OidcError("invalid_request") ) self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("invalid_request") @override_config({"oidc_config": DEFAULT_CONFIG}) - def test_callback_session(self): + def test_callback_session(self) -> None: """The callback verifies the session presence and validity""" request = Mock(spec=["args", "getCookie", "cookies"]) @@ -590,7 +596,7 @@ def test_callback_session(self): @override_config( {"oidc_config": {**DEFAULT_CONFIG, "client_auth_method": "client_secret_post"}} ) - def test_exchange_code(self): + def test_exchange_code(self) -> None: """Code exchange behaves correctly and handles various error scenarios.""" token = {"type": "bearer"} token_json = json.dumps(token).encode("utf-8") @@ -686,7 +692,7 @@ def test_exchange_code(self): } } ) - def 
test_exchange_code_jwt_key(self): + def test_exchange_code_jwt_key(self) -> None: """Test that code exchange works with a JWK client secret.""" from authlib.jose import jwt @@ -741,7 +747,7 @@ def test_exchange_code_jwt_key(self): } } ) - def test_exchange_code_no_auth(self): + def test_exchange_code_no_auth(self) -> None: """Test that code exchange works with no client secret.""" token = {"type": "bearer"} self.http_client.request = simple_async_mock( @@ -776,7 +782,7 @@ def test_exchange_code_no_auth(self): } } ) - def test_extra_attributes(self): + def test_extra_attributes(self) -> None: """ Login while using a mapping provider that implements get_extra_attributes. """ @@ -790,8 +796,8 @@ def test_extra_attributes(self): "username": "foo", "phone": "1234567", } - self.provider._exchange_code = simple_async_mock(return_value=token) - self.provider._parse_id_token = simple_async_mock(return_value=userinfo) + self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment] + self.provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment] auth_handler = self.hs.get_auth_handler() auth_handler.complete_sso_login = simple_async_mock() @@ -817,12 +823,12 @@ def test_extra_attributes(self): ) @override_config({"oidc_config": DEFAULT_CONFIG}) - def test_map_userinfo_to_user(self): + def test_map_userinfo_to_user(self) -> None: """Ensure that mapping the userinfo returned from a provider to an MXID works properly.""" auth_handler = self.hs.get_auth_handler() auth_handler.complete_sso_login = simple_async_mock() - userinfo = { + userinfo: dict = { "sub": "test_user", "username": "test_user", } @@ -870,7 +876,7 @@ def test_map_userinfo_to_user(self): ) @override_config({"oidc_config": {**DEFAULT_CONFIG, "allow_existing_users": True}}) - def test_map_userinfo_to_existing_user(self): + def test_map_userinfo_to_existing_user(self) -> None: """Existing users can log in with OpenID Connect when allow_existing_users is True.""" store = self.hs.get_datastores().main user = UserID.from_string("@test_user:test") @@ -974,7 +980,7 @@ def test_map_userinfo_to_existing_user(self): ) @override_config({"oidc_config": DEFAULT_CONFIG}) - def test_map_userinfo_to_invalid_localpart(self): + def test_map_userinfo_to_invalid_localpart(self) -> None: """If the mapping provider generates an invalid localpart it should be rejected.""" self.get_success( _make_callback_with_userinfo(self.hs, {"sub": "test2", "username": "föö"}) @@ -991,7 +997,7 @@ def test_map_userinfo_to_invalid_localpart(self): } } ) - def test_map_userinfo_to_user_retries(self): + def test_map_userinfo_to_user_retries(self) -> None: """The mapping provider can retry generating an MXID if the MXID is already in use.""" auth_handler = self.hs.get_auth_handler() auth_handler.complete_sso_login = simple_async_mock() @@ -1039,7 +1045,7 @@ def test_map_userinfo_to_user_retries(self): ) @override_config({"oidc_config": DEFAULT_CONFIG}) - def test_empty_localpart(self): + def test_empty_localpart(self) -> None: """Attempts to map onto an empty localpart should be rejected.""" userinfo = { "sub": "tester", @@ -1058,7 +1064,7 @@ def test_empty_localpart(self): } } ) - def test_null_localpart(self): + def test_null_localpart(self) -> None: """Mapping onto a null localpart via an empty OIDC attribute should be rejected""" userinfo = { "sub": "tester", @@ -1075,7 +1081,7 @@ def test_null_localpart(self): } } ) - def test_attribute_requirements(self): + def test_attribute_requirements(self) -> None: 
"""The required attributes must be met from the OIDC userinfo response.""" auth_handler = self.hs.get_auth_handler() auth_handler.complete_sso_login = simple_async_mock() @@ -1115,7 +1121,7 @@ def test_attribute_requirements(self): } } ) - def test_attribute_requirements_contains(self): + def test_attribute_requirements_contains(self) -> None: """Test that auth succeeds if userinfo attribute CONTAINS required value""" auth_handler = self.hs.get_auth_handler() auth_handler.complete_sso_login = simple_async_mock() @@ -1146,7 +1152,7 @@ def test_attribute_requirements_contains(self): } } ) - def test_attribute_requirements_mismatch(self): + def test_attribute_requirements_mismatch(self) -> None: """ Test that auth fails if attributes exist but don't match, or are non-string values. @@ -1154,7 +1160,7 @@ def test_attribute_requirements_mismatch(self): auth_handler = self.hs.get_auth_handler() auth_handler.complete_sso_login = simple_async_mock() # userinfo with "test": "not_foobar" attribute should fail - userinfo = { + userinfo: dict = { "sub": "tester", "username": "tester", "test": "not_foobar", @@ -1248,9 +1254,9 @@ async def _make_callback_with_userinfo( handler = hs.get_oidc_handler() provider = handler._providers["oidc"] - provider._exchange_code = simple_async_mock(return_value={"id_token": ""}) - provider._parse_id_token = simple_async_mock(return_value=userinfo) - provider._fetch_userinfo = simple_async_mock(return_value=userinfo) + provider._exchange_code = simple_async_mock(return_value={"id_token": ""}) # type: ignore[assignment] + provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment] + provider._fetch_userinfo = simple_async_mock(return_value=userinfo) # type: ignore[assignment] state = "state" session = handler._token_generator.generate_oidc_session_token( diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 972cbac6e496..08733a9f2d42 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -11,14 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Dict +from typing import Any, Awaitable, Callable, Dict from unittest.mock import Mock +from twisted.test.proto_helpers import MemoryReactor + import synapse.types from synapse.api.errors import AuthError, SynapseError from synapse.rest import admin from synapse.server import HomeServer -from synapse.types import UserID +from synapse.types import JsonDict, UserID +from synapse.util import Clock from tests import unittest from tests.test_utils import make_awaitable @@ -29,13 +32,15 @@ class ProfileTestCase(unittest.HomeserverTestCase): servlets = [admin.register_servlets] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.mock_federation = Mock() self.mock_registry = Mock() - self.query_handlers = {} + self.query_handlers: Dict[str, Callable[[dict], Awaitable[JsonDict]]] = {} - def register_query_handler(query_type, handler): + def register_query_handler( + query_type: str, handler: Callable[[dict], Awaitable[JsonDict]] + ) -> None: self.query_handlers[query_type] = handler self.mock_registry.register_query_handler = register_query_handler @@ -47,7 +52,7 @@ def register_query_handler(query_type, handler): ) return hs - def prepare(self, reactor, clock, hs: HomeServer): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main self.frank = UserID.from_string("@1234abcd:test") @@ -58,7 +63,7 @@ def prepare(self, reactor, clock, hs: HomeServer): self.handler = hs.get_profile_handler() - def test_get_my_name(self): + def test_get_my_name(self) -> None: self.get_success( self.store.set_profile_displayname(self.frank.localpart, "Frank") ) @@ -67,7 +72,7 @@ def test_get_my_name(self): self.assertEqual("Frank", displayname) - def test_set_my_name(self): + def test_set_my_name(self) -> None: self.get_success( self.handler.set_displayname( self.frank, synapse.types.create_requester(self.frank), "Frank Jr." @@ -110,7 +115,7 @@ def test_set_my_name(self): self.get_success(self.store.get_profile_displayname(self.frank.localpart)) ) - def test_set_my_name_if_disabled(self): + def test_set_my_name_if_disabled(self) -> None: self.hs.config.registration.enable_set_displayname = False # Setting displayname for the first time is allowed @@ -135,7 +140,7 @@ def test_set_my_name_if_disabled(self): SynapseError, ) - def test_set_my_name_noauth(self): + def test_set_my_name_noauth(self) -> None: self.get_failure( self.handler.set_displayname( self.frank, synapse.types.create_requester(self.bob), "Frank Jr." 
@@ -143,7 +148,7 @@ def test_set_my_name_noauth(self): AuthError, ) - def test_get_other_name(self): + def test_get_other_name(self) -> None: self.mock_federation.make_query.return_value = make_awaitable( {"displayname": "Alice"} ) @@ -158,7 +163,7 @@ def test_get_other_name(self): ignore_backoff=True, ) - def test_incoming_fed_query(self): + def test_incoming_fed_query(self) -> None: self.get_success(self.store.create_profile("caroline")) self.get_success(self.store.set_profile_displayname("caroline", "Caroline")) @@ -174,7 +179,7 @@ def test_incoming_fed_query(self): self.assertEqual({"displayname": "Caroline"}, response) - def test_get_my_avatar(self): + def test_get_my_avatar(self) -> None: self.get_success( self.store.set_profile_avatar_url( self.frank.localpart, "http://my.server/me.png" @@ -184,7 +189,7 @@ def test_get_my_avatar(self): self.assertEqual("http://my.server/me.png", avatar_url) - def test_set_my_avatar(self): + def test_set_my_avatar(self) -> None: self.get_success( self.handler.set_avatar_url( self.frank, @@ -225,7 +230,7 @@ def test_set_my_avatar(self): (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))), ) - def test_set_my_avatar_if_disabled(self): + def test_set_my_avatar_if_disabled(self) -> None: self.hs.config.registration.enable_set_avatar_url = False # Setting displayname for the first time is allowed @@ -250,7 +255,7 @@ def test_set_my_avatar_if_disabled(self): SynapseError, ) - def test_avatar_constraints_no_config(self): + def test_avatar_constraints_no_config(self) -> None: """Tests that the method to check an avatar against configured constraints skips all of its check if no constraint is configured. """ @@ -263,7 +268,7 @@ def test_avatar_constraints_no_config(self): self.assertTrue(res) @unittest.override_config({"max_avatar_size": 50}) - def test_avatar_constraints_missing(self): + def test_avatar_constraints_missing(self) -> None: """Tests that an avatar isn't allowed if the file at the given MXC URI couldn't be found. """ @@ -273,7 +278,7 @@ def test_avatar_constraints_missing(self): self.assertFalse(res) @unittest.override_config({"max_avatar_size": 50}) - def test_avatar_constraints_file_size(self): + def test_avatar_constraints_file_size(self) -> None: """Tests that a file that's above the allowed file size is forbidden but one that's below it is allowed. """ @@ -295,7 +300,7 @@ def test_avatar_constraints_file_size(self): self.assertFalse(res) @unittest.override_config({"allowed_avatar_mimetypes": ["image/png"]}) - def test_avatar_constraint_mime_type(self): + def test_avatar_constraint_mime_type(self) -> None: """Tests that a file with an unauthorised MIME type is forbidden but one with an authorised content type is allowed. """ diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py index 23941abed852..8d4404eda10d 100644 --- a/tests/handlers/test_saml.py +++ b/tests/handlers/test_saml.py @@ -12,12 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Optional +from typing import Any, Dict, Optional from unittest.mock import Mock import attr +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.errors import RedirectException +from synapse.server import HomeServer +from synapse.util import Clock from tests.test_utils import simple_async_mock from tests.unittest import HomeserverTestCase, override_config @@ -81,10 +85,10 @@ def saml_response_to_user_attributes( class SamlHandlerTestCase(HomeserverTestCase): - def default_config(self): + def default_config(self) -> Dict[str, Any]: config = super().default_config() config["public_baseurl"] = BASE_URL - saml_config = { + saml_config: Dict[str, Any] = { "sp_config": {"metadata": {}}, # Disable grandfathering. "grandfathered_mxid_source_attribute": None, @@ -98,7 +102,7 @@ def default_config(self): return config - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: hs = self.setup_test_homeserver() self.handler = hs.get_saml_handler() @@ -114,7 +118,7 @@ def make_homeserver(self, reactor, clock): elif not has_xmlsec1: skip = "Requires xmlsec1" - def test_map_saml_response_to_user(self): + def test_map_saml_response_to_user(self) -> None: """Ensure that mapping the SAML response returned from a provider to an MXID works properly.""" # stub out the auth handler @@ -140,7 +144,7 @@ def test_map_saml_response_to_user(self): ) @override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}}) - def test_map_saml_response_to_existing_user(self): + def test_map_saml_response_to_existing_user(self) -> None: """Existing users can log in with SAML account.""" store = self.hs.get_datastores().main self.get_success( @@ -186,7 +190,7 @@ def test_map_saml_response_to_existing_user(self): auth_provider_session_id=None, ) - def test_map_saml_response_to_invalid_localpart(self): + def test_map_saml_response_to_invalid_localpart(self) -> None: """If the mapping provider generates an invalid localpart it should be rejected.""" # stub out the auth handler @@ -207,7 +211,7 @@ def test_map_saml_response_to_invalid_localpart(self): ) auth_handler.complete_sso_login.assert_not_called() - def test_map_saml_response_to_user_retries(self): + def test_map_saml_response_to_user_retries(self) -> None: """The mapping provider can retry generating an MXID if the MXID is already in use.""" # stub out the auth handler and error renderer @@ -271,7 +275,7 @@ def test_map_saml_response_to_user_retries(self): } } ) - def test_map_saml_response_redirect(self): + def test_map_saml_response_redirect(self) -> None: """Test a mapping provider that raises a RedirectException""" saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"}) @@ -292,7 +296,7 @@ def test_map_saml_response_redirect(self): }, } ) - def test_attribute_requirements(self): + def test_attribute_requirements(self) -> None: """The required attributes must be met from the SAML response.""" # stub out the auth handler From dea577998f221297d3ff30bdf904f7147f3c3d8a Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Tue, 15 Mar 2022 15:40:34 +0000 Subject: [PATCH 066/230] Add tests for database transaction callbacks (#12198) Signed-off-by: Sean Quah --- changelog.d/12198.misc | 1 + tests/storage/test_database.py | 104 ++++++++++++++++++++++++++++++++- 2 files changed, 104 insertions(+), 1 deletion(-) create mode 100644 changelog.d/12198.misc diff --git a/changelog.d/12198.misc 
b/changelog.d/12198.misc new file mode 100644 index 000000000000..6b184a9053f8 --- /dev/null +++ b/changelog.d/12198.misc @@ -0,0 +1 @@ +Add tests for database transaction callbacks. diff --git a/tests/storage/test_database.py b/tests/storage/test_database.py index 85978675634b..ae13bed08621 100644 --- a/tests/storage/test_database.py +++ b/tests/storage/test_database.py @@ -12,7 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.database import make_tuple_comparison_clause +from typing import Callable, Tuple +from unittest.mock import Mock, call + +from twisted.test.proto_helpers import MemoryReactor + +from synapse.server import HomeServer +from synapse.storage.database import ( + DatabasePool, + LoggingTransaction, + make_tuple_comparison_clause, +) +from synapse.util import Clock from tests import unittest @@ -22,3 +33,94 @@ def test_native_tuple_comparison(self): clause, args = make_tuple_comparison_clause([("a", 1), ("b", 2)]) self.assertEqual(clause, "(a,b) > (?,?)") self.assertEqual(args, [1, 2]) + + +class CallbacksTestCase(unittest.HomeserverTestCase): + """Tests for transaction callbacks.""" + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastores().main + self.db_pool: DatabasePool = self.store.db_pool + + def _run_interaction( + self, func: Callable[[LoggingTransaction], object] + ) -> Tuple[Mock, Mock]: + """Run the given function in a database transaction, with callbacks registered. + + Args: + func: The function to be run in a transaction. The transaction will be + retried if `func` raises an `OperationalError`. + + Returns: + Two mocks, which were registered as an `after_callback` and an + `exception_callback` respectively, on every transaction attempt. + """ + after_callback = Mock() + exception_callback = Mock() + + def _test_txn(txn: LoggingTransaction) -> None: + txn.call_after(after_callback, 123, 456, extra=789) + txn.call_on_exception(exception_callback, 987, 654, extra=321) + func(txn) + + try: + self.get_success_or_raise( + self.db_pool.runInteraction("test_transaction", _test_txn) + ) + except Exception: + pass + + return after_callback, exception_callback + + def test_after_callback(self) -> None: + """Test that the after callback is called when a transaction succeeds.""" + after_callback, exception_callback = self._run_interaction(lambda txn: None) + + after_callback.assert_called_once_with(123, 456, extra=789) + exception_callback.assert_not_called() + + def test_exception_callback(self) -> None: + """Test that the exception callback is called when a transaction fails.""" + _test_txn = Mock(side_effect=ZeroDivisionError) + after_callback, exception_callback = self._run_interaction(_test_txn) + + after_callback.assert_not_called() + exception_callback.assert_called_once_with(987, 654, extra=321) + + def test_failed_retry(self) -> None: + """Test that the exception callback is called for every failed attempt.""" + # Always raise an `OperationalError`. 
+ _test_txn = Mock(side_effect=self.db_pool.engine.module.OperationalError) + after_callback, exception_callback = self._run_interaction(_test_txn) + + after_callback.assert_not_called() + exception_callback.assert_has_calls( + [ + call(987, 654, extra=321), + call(987, 654, extra=321), + call(987, 654, extra=321), + call(987, 654, extra=321), + call(987, 654, extra=321), + call(987, 654, extra=321), + ] + ) + self.assertEqual(exception_callback.call_count, 6) # no additional calls + + def test_successful_retry(self) -> None: + """Test callbacks for a failed transaction followed by a successful attempt.""" + # Raise an `OperationalError` on the first attempt only. + _test_txn = Mock( + side_effect=[self.db_pool.engine.module.OperationalError, None] + ) + after_callback, exception_callback = self._run_interaction(_test_txn) + + # Calling both `after_callback`s when the first attempt failed is rather + # surprising (#12184). Let's document the behaviour in a test. + after_callback.assert_has_calls( + [ + call(123, 456, extra=789), + call(123, 456, extra=789), + ] + ) + self.assertEqual(after_callback.call_count, 2) # no additional calls + exception_callback.assert_not_called() From dda9b7fc4d2e6ca84a1a994a7ff1943b590e71df Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 15 Mar 2022 14:06:05 -0400 Subject: [PATCH 067/230] Use the ignored_users table to test event visibility & sync. (#12225) Instead of fetching the raw account data and re-parsing it. The ignored_users table is a denormalised version of the account data for quick searching. --- changelog.d/12225.misc | 1 + synapse/handlers/sync.py | 30 +------------- synapse/push/bulk_push_rule_evaluator.py | 2 +- .../storage/databases/main/account_data.py | 41 +++++++++++++++++-- synapse/visibility.py | 18 ++------ tests/storage/test_account_data.py | 17 ++++++++ 6 files changed, 62 insertions(+), 47 deletions(-) create mode 100644 changelog.d/12225.misc diff --git a/changelog.d/12225.misc b/changelog.d/12225.misc new file mode 100644 index 000000000000..23105c727c3c --- /dev/null +++ b/changelog.d/12225.misc @@ -0,0 +1 @@ +Use the `ignored_users` table in additional places instead of re-parsing the account data. diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 0aa3052fd6ac..c9d6a18bd700 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -28,7 +28,7 @@ import attr from prometheus_client import Counter -from synapse.api.constants import AccountDataTypes, EventTypes, Membership, ReceiptTypes +from synapse.api.constants import EventTypes, Membership, ReceiptTypes from synapse.api.filtering import FilterCollection from synapse.api.presence import UserPresenceState from synapse.api.room_versions import KNOWN_ROOM_VERSIONS @@ -1601,7 +1601,7 @@ async def _generate_sync_entry_for_rooms( return set(), set(), set(), set() # 3. Work out which rooms need reporting in the sync response. 
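The change immediately below is the heart of this patch: instead of fetching the raw `m.ignored_user_list` account data and re-parsing it on every sync, the handler now reads the denormalised `ignored_users` table. A condensed sketch of the two paths (type checks from the removed helper omitted; `store` stands in for `self.store`, and the names come from the surrounding diff):

# Before: fetch the whole account data event and re-parse its JSON content.
data = await store.get_global_account_data_by_type_for_user(
    user_id=user_id, data_type=AccountDataTypes.IGNORED_USER_LIST
)
ignored_users = frozenset((data or {}).get("ignored_users", {}).keys())

# After: a single cached, indexed lookup against the denormalised table
# (the `ignored_users` method added to account_data.py later in this patch).
ignored_users = await store.ignored_users(user_id)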
- ignored_users = await self._get_ignored_users(user_id) + ignored_users = await self.store.ignored_users(user_id) if since_token: room_changes = await self._get_rooms_changed( sync_result_builder, ignored_users @@ -1627,7 +1627,6 @@ async def handle_room_entries(room_entry: "RoomSyncResultBuilder") -> None: logger.debug("Generating room entry for %s", room_entry.room_id) await self._generate_room_entry( sync_result_builder, - ignored_users, room_entry, ephemeral=ephemeral_by_room.get(room_entry.room_id, []), tags=tags_by_room.get(room_entry.room_id), @@ -1657,29 +1656,6 @@ async def handle_room_entries(room_entry: "RoomSyncResultBuilder") -> None: newly_left_users, ) - async def _get_ignored_users(self, user_id: str) -> FrozenSet[str]: - """Retrieve the users ignored by the given user from their global account_data. - - Returns an empty set if - - there is no global account_data entry for ignored_users - - there is such an entry, but it's not a JSON object. - """ - # TODO: Can we `SELECT ignored_user_id FROM ignored_users WHERE ignorer_user_id=?;` instead? - ignored_account_data = ( - await self.store.get_global_account_data_by_type_for_user( - user_id=user_id, data_type=AccountDataTypes.IGNORED_USER_LIST - ) - ) - - # If there is ignored users account data and it matches the proper type, - # then use it. - ignored_users: FrozenSet[str] = frozenset() - if ignored_account_data: - ignored_users_data = ignored_account_data.get("ignored_users", {}) - if isinstance(ignored_users_data, dict): - ignored_users = frozenset(ignored_users_data.keys()) - return ignored_users - async def _have_rooms_changed( self, sync_result_builder: "SyncResultBuilder" ) -> bool: @@ -2022,7 +1998,6 @@ async def _get_all_rooms( async def _generate_room_entry( self, sync_result_builder: "SyncResultBuilder", - ignored_users: FrozenSet[str], room_builder: "RoomSyncResultBuilder", ephemeral: List[JsonDict], tags: Optional[Dict[str, Dict[str, Any]]], @@ -2051,7 +2026,6 @@ async def _generate_room_entry( Args: sync_result_builder - ignored_users: Set of users ignored by user. room_builder ephemeral: List of new ephemeral events for room tags: List of *all* tags for room, or None if there has been diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 8140afcb6b37..030898e4d0cc 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -213,7 +213,7 @@ async def action_for_event_by_user( if not event.is_state(): ignorers = await self.store.ignored_by(event.sender) else: - ignorers = set() + ignorers = frozenset() for uid, rules in rules_by_user.items(): if event.sender == uid: diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index 52146aacc8c0..9af9f4f18e19 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -14,7 +14,17 @@ # limitations under the License. 
import logging -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple, cast +from typing import ( + TYPE_CHECKING, + Any, + Dict, + FrozenSet, + Iterable, + List, + Optional, + Tuple, + cast, +) from synapse.api.constants import AccountDataTypes from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker @@ -365,7 +375,7 @@ def get_updated_account_data_for_user_txn( ) @cached(max_entries=5000, iterable=True) - async def ignored_by(self, user_id: str) -> Set[str]: + async def ignored_by(self, user_id: str) -> FrozenSet[str]: """ Get users which ignore the given user. @@ -375,7 +385,7 @@ async def ignored_by(self, user_id: str) -> Set[str]: Return: The user IDs which ignore the given user. """ - return set( + return frozenset( await self.db_pool.simple_select_onecol( table="ignored_users", keyvalues={"ignored_user_id": user_id}, @@ -384,6 +394,26 @@ async def ignored_by(self, user_id: str) -> Set[str]: ) ) + @cached(max_entries=5000, iterable=True) + async def ignored_users(self, user_id: str) -> FrozenSet[str]: + """ + Get users which the given user ignores. + + Params: + user_id: The user ID which is making the request. + + Return: + The user IDs which are ignored by the given user. + """ + return frozenset( + await self.db_pool.simple_select_onecol( + table="ignored_users", + keyvalues={"ignorer_user_id": user_id}, + retcol="ignored_user_id", + desc="ignored_users", + ) + ) + def process_replication_rows( self, stream_name: str, @@ -529,6 +559,10 @@ def _add_account_data_for_user( else: currently_ignored_users = set() + # If the data has not changed, nothing to do. + if previously_ignored_users == currently_ignored_users: + return + # Delete entries which are no longer ignored. self.db_pool.simple_delete_many_txn( txn, @@ -551,6 +585,7 @@ def _add_account_data_for_user( # Invalidate the cache for any ignored users which were added or removed. for ignored_user_id in previously_ignored_users ^ currently_ignored_users: self._invalidate_cache_and_stream(txn, self.ignored_by, (ignored_user_id,)) + self._invalidate_cache_and_stream(txn, self.ignored_users, (user_id,)) async def purge_account_data_for_user(self, user_id: str) -> None: """ diff --git a/synapse/visibility.py b/synapse/visibility.py index 281cbe4d8877..49519eb8f51e 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -14,12 +14,7 @@ import logging from typing import Dict, FrozenSet, List, Optional -from synapse.api.constants import ( - AccountDataTypes, - EventTypes, - HistoryVisibility, - Membership, -) +from synapse.api.constants import EventTypes, HistoryVisibility, Membership from synapse.events import EventBase from synapse.events.utils import prune_event from synapse.storage import Storage @@ -87,15 +82,8 @@ async def filter_events_for_client( state_filter=StateFilter.from_types(types), ) - ignore_dict_content = await storage.main.get_global_account_data_by_type_for_user( - user_id, AccountDataTypes.IGNORED_USER_LIST - ) - - ignore_list: FrozenSet[str] = frozenset() - if ignore_dict_content: - ignored_users_dict = ignore_dict_content.get("ignored_users", {}) - if isinstance(ignored_users_dict, dict): - ignore_list = frozenset(ignored_users_dict.keys()) + # Get the users who are ignored by the requesting user. 
+ ignore_list = await storage.main.ignored_users(user_id) erased_senders = await storage.main.are_users_erased(e.sender for e in events) diff --git a/tests/storage/test_account_data.py b/tests/storage/test_account_data.py index 272cd3540220..72bf5b3d311c 100644 --- a/tests/storage/test_account_data.py +++ b/tests/storage/test_account_data.py @@ -47,9 +47,18 @@ def assert_ignorers( expected_ignorer_user_ids, ) + def assert_ignored( + self, ignorer_user_id: str, expected_ignored_user_ids: Set[str] + ) -> None: + self.assertEqual( + self.get_success(self.store.ignored_users(ignorer_user_id)), + expected_ignored_user_ids, + ) + def test_ignoring_users(self): """Basic adding/removing of users from the ignore list.""" self._update_ignore_list("@other:test", "@another:remote") + self.assert_ignored(self.user, {"@other:test", "@another:remote"}) # Check a user which no one ignores. self.assert_ignorers("@user:test", set()) @@ -62,6 +71,7 @@ def test_ignoring_users(self): # Add one user, remove one user, and leave one user. self._update_ignore_list("@foo:test", "@another:remote") + self.assert_ignored(self.user, {"@foo:test", "@another:remote"}) # Check the removed user. self.assert_ignorers("@other:test", set()) @@ -76,20 +86,24 @@ def test_caching(self): """Ensure that caching works properly between different users.""" # The first user ignores a user. self._update_ignore_list("@other:test") + self.assert_ignored(self.user, {"@other:test"}) self.assert_ignorers("@other:test", {self.user}) # The second user ignores them. self._update_ignore_list("@other:test", ignorer_user_id="@second:test") + self.assert_ignored("@second:test", {"@other:test"}) self.assert_ignorers("@other:test", {self.user, "@second:test"}) # The first user un-ignores them. self._update_ignore_list() + self.assert_ignored(self.user, set()) self.assert_ignorers("@other:test", {"@second:test"}) def test_invalid_data(self): """Invalid data ends up clearing out the ignored users list.""" # Add some data and ensure it is there. self._update_ignore_list("@other:test") + self.assert_ignored(self.user, {"@other:test"}) self.assert_ignorers("@other:test", {self.user}) # No ignored_users key. @@ -102,10 +116,12 @@ def test_invalid_data(self): ) # No one ignores the user now. + self.assert_ignored(self.user, set()) self.assert_ignorers("@other:test", set()) # Add some data and ensure it is there. self._update_ignore_list("@other:test") + self.assert_ignored(self.user, {"@other:test"}) self.assert_ignorers("@other:test", {self.user}) # Invalid data. @@ -118,4 +134,5 @@ def test_invalid_data(self): ) # No one ignores the user now. + self.assert_ignored(self.user, set()) self.assert_ignorers("@other:test", set()) From 4587b35929d22731644a11120a9e7d6a9c3bc304 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 16 Mar 2022 07:21:36 -0400 Subject: [PATCH 068/230] Clean-up logic for rebasing URLs during URL preview. (#12219) By using urljoin from the standard library and reducing the number of places URLs are rebased. 
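Since the whole point of the change is that the standard library already implements the resolution rules `rebase_url` hand-rolled, here is a quick sketch of `urljoin`'s behaviour; the cases mirror the `RebaseUrlTestCase` tests this patch deletes:

from urllib.parse import urljoin, urlparse

# Relative references resolve against the base, exactly as rebase_url did.
assert urljoin("https://example.com/foo/", "subpage") == "https://example.com/foo/subpage"
assert urljoin("https://example.com/foo", "sibling") == "https://example.com/sibling"
assert urljoin("https://example.com/foo/", "/bar") == "https://example.com/bar"
# Absolute URLs come back unchanged.
assert urljoin("https://example.com/foo/", "https://alice.com/a/") == "https://alice.com/a/"
# data: URLs never need rebasing; the new _precache_image_url code checks the
# parsed scheme and skips urljoin for them.
assert urlparse("data:,Hello%2C%20World%21").scheme == "data"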
--- changelog.d/12219.misc | 1 + synapse/rest/media/v1/preview_html.py | 39 +------------- synapse/rest/media/v1/preview_url_resource.py | 23 ++++---- tests/rest/media/v1/test_html_preview.py | 54 ++++--------------- 4 files changed, 26 insertions(+), 91 deletions(-) create mode 100644 changelog.d/12219.misc diff --git a/changelog.d/12219.misc b/changelog.d/12219.misc new file mode 100644 index 000000000000..607941409295 --- /dev/null +++ b/changelog.d/12219.misc @@ -0,0 +1 @@ +Clean-up logic around rebasing URLs for URL image previews. diff --git a/synapse/rest/media/v1/preview_html.py b/synapse/rest/media/v1/preview_html.py index 872a9e72e818..4cc9c66fbe68 100644 --- a/synapse/rest/media/v1/preview_html.py +++ b/synapse/rest/media/v1/preview_html.py @@ -16,7 +16,6 @@ import logging import re from typing import TYPE_CHECKING, Dict, Generator, Iterable, Optional, Set, Union -from urllib import parse as urlparse if TYPE_CHECKING: from lxml import etree @@ -144,9 +143,7 @@ def decode_body( return etree.fromstring(body, parser) -def parse_html_to_open_graph( - tree: "etree.Element", media_uri: str -) -> Dict[str, Optional[str]]: +def parse_html_to_open_graph(tree: "etree.Element") -> Dict[str, Optional[str]]: """ Parse the HTML document into an Open Graph response. @@ -155,7 +152,6 @@ def parse_html_to_open_graph( Args: tree: The parsed HTML document. - media_url: The URI used to download the body. Returns: The Open Graph response as a dictionary. @@ -209,7 +205,7 @@ def parse_html_to_open_graph( "//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content" ) if meta_image: - og["og:image"] = rebase_url(meta_image[0], media_uri) + og["og:image"] = meta_image[0] else: # TODO: consider inlined CSS styles as well as width & height attribs images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]") @@ -320,37 +316,6 @@ def _iterate_over_text( ) -def rebase_url(url: str, base: str) -> str: - """ - Resolves a potentially relative `url` against an absolute `base` URL. - - For example: - - >>> rebase_url("subpage", "https://example.com/foo/") - 'https://example.com/foo/subpage' - >>> rebase_url("sibling", "https://example.com/foo") - 'https://example.com/sibling' - >>> rebase_url("/bar", "https://example.com/foo/") - 'https://example.com/bar' - >>> rebase_url("https://alice.com/a/", "https://example.com/foo/") - 'https://alice.com/a' - """ - base_parts = urlparse.urlparse(base) - # Convert the parsed URL to a list for (potential) modification. - url_parts = list(urlparse.urlparse(url)) - # Add a scheme, if one does not exist. - if not url_parts[0]: - url_parts[0] = base_parts.scheme or "http" - # Fix up the hostname, if this is not a data URL. - if url_parts[0] != "data" and not url_parts[1]: - url_parts[1] = base_parts.netloc - # If the path does not start with a /, nest it under the base path's last - # directory. 
- if not url_parts[2].startswith("/"): - url_parts[2] = re.sub(r"/[^/]+$", "/", base_parts.path) + url_parts[2] - return urlparse.urlunparse(url_parts) - - def summarize_paragraphs( text_nodes: Iterable[str], min_size: int = 200, max_size: int = 500 ) -> Optional[str]: diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 14ea88b24052..d47af8ead6b7 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -22,7 +22,7 @@ import sys import traceback from typing import TYPE_CHECKING, BinaryIO, Iterable, Optional, Tuple -from urllib import parse as urlparse +from urllib.parse import urljoin, urlparse, urlsplit from urllib.request import urlopen import attr @@ -44,11 +44,7 @@ from synapse.rest.media.v1._base import get_filename_from_headers from synapse.rest.media.v1.media_storage import MediaStorage from synapse.rest.media.v1.oembed import OEmbedProvider -from synapse.rest.media.v1.preview_html import ( - decode_body, - parse_html_to_open_graph, - rebase_url, -) +from synapse.rest.media.v1.preview_html import decode_body, parse_html_to_open_graph from synapse.types import JsonDict, UserID from synapse.util import json_encoder from synapse.util.async_helpers import ObservableDeferred @@ -187,7 +183,7 @@ async def _async_render_GET(self, request: SynapseRequest) -> None: ts = self.clock.time_msec() # XXX: we could move this into _do_preview if we wanted. - url_tuple = urlparse.urlsplit(url) + url_tuple = urlsplit(url) for entry in self.url_preview_url_blacklist: match = True for attrib in entry: @@ -322,7 +318,7 @@ async def _do_preview(self, url: str, user: UserID, ts: int) -> bytes: # Parse Open Graph information from the HTML in case the oEmbed # response failed or is incomplete. - og_from_html = parse_html_to_open_graph(tree, media_info.uri) + og_from_html = parse_html_to_open_graph(tree) # Compile the Open Graph response by using the scraped # information from the HTML and overlaying any information @@ -588,12 +584,17 @@ async def _precache_image_url( if "og:image" not in og or not og["og:image"]: return + # The image URL from the HTML might be relative to the previewed page, + # convert it to an URL which can be requested directly. + image_url = og["og:image"] + url_parts = urlparse(image_url) + if url_parts.scheme != "data": + image_url = urljoin(media_info.uri, image_url) + # FIXME: it might be cleaner to use the same flow as the main /preview_url # request itself and benefit from the same caching etc. But for now we # just rely on the caching on the master request to speed things up. 
- image_info = await self._handle_url( - rebase_url(og["og:image"], media_info.uri), user, allow_data_urls=True - ) + image_info = await self._handle_url(image_url, user, allow_data_urls=True) if _is_media(image_info.media_type): # TODO: make sure we don't choke on white-on-transparent images diff --git a/tests/rest/media/v1/test_html_preview.py b/tests/rest/media/v1/test_html_preview.py index 3fb37a2a5970..62e308814d2f 100644 --- a/tests/rest/media/v1/test_html_preview.py +++ b/tests/rest/media/v1/test_html_preview.py @@ -16,7 +16,6 @@ _get_html_media_encodings, decode_body, parse_html_to_open_graph, - rebase_url, summarize_paragraphs, ) @@ -161,7 +160,7 @@ def test_simple(self) -> None: """ tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") + og = parse_html_to_open_graph(tree) self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) @@ -177,7 +176,7 @@ def test_comment(self) -> None: """ tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") + og = parse_html_to_open_graph(tree) self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) @@ -196,7 +195,7 @@ def test_comment2(self) -> None: """ tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") + og = parse_html_to_open_graph(tree) self.assertEqual( og, @@ -218,7 +217,7 @@ def test_script(self) -> None: """ tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") + og = parse_html_to_open_graph(tree) self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) @@ -232,7 +231,7 @@ def test_missing_title(self) -> None: """ tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") + og = parse_html_to_open_graph(tree) self.assertEqual(og, {"og:title": None, "og:description": "Some text."}) @@ -247,7 +246,7 @@ def test_h1_as_title(self) -> None: """ tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") + og = parse_html_to_open_graph(tree) self.assertEqual(og, {"og:title": "Title", "og:description": "Some text."}) @@ -262,7 +261,7 @@ def test_missing_title_and_broken_h1(self) -> None: """ tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") + og = parse_html_to_open_graph(tree) self.assertEqual(og, {"og:title": None, "og:description": "Some text."}) @@ -290,7 +289,7 @@ def test_xml(self) -> None: FooSome text. 
""".strip() tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") + og = parse_html_to_open_graph(tree) self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) def test_invalid_encoding(self) -> None: @@ -304,7 +303,7 @@ def test_invalid_encoding(self) -> None: """ tree = decode_body(html, "http://example.com/test.html", "invalid-encoding") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") + og = parse_html_to_open_graph(tree) self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) def test_invalid_encoding2(self) -> None: @@ -319,7 +318,7 @@ def test_invalid_encoding2(self) -> None: """ tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") + og = parse_html_to_open_graph(tree) self.assertEqual(og, {"og:title": "ÿÿ Foo", "og:description": "Some text."}) def test_windows_1252(self) -> None: @@ -333,7 +332,7 @@ def test_windows_1252(self) -> None: """ tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") + og = parse_html_to_open_graph(tree) self.assertEqual(og, {"og:title": "ó", "og:description": "Some text."}) @@ -448,34 +447,3 @@ def test_unknown_invalid(self) -> None: 'text/html; charset="invalid"', ) self.assertEqual(list(encodings), ["utf-8", "cp1252"]) - - -class RebaseUrlTestCase(unittest.TestCase): - def test_relative(self) -> None: - """Relative URLs should be resolved based on the context of the base URL.""" - self.assertEqual( - rebase_url("subpage", "https://example.com/foo/"), - "https://example.com/foo/subpage", - ) - self.assertEqual( - rebase_url("sibling", "https://example.com/foo"), - "https://example.com/sibling", - ) - self.assertEqual( - rebase_url("/bar", "https://example.com/foo/"), - "https://example.com/bar", - ) - - def test_absolute(self) -> None: - """Absolute URLs should not be modified.""" - self.assertEqual( - rebase_url("https://alice.com/a/", "https://example.com/foo/"), - "https://alice.com/a/", - ) - - def test_data(self) -> None: - """Data URLs should not be modified.""" - self.assertEqual( - rebase_url("data:,Hello%2C%20World%21", "https://example.com/foo/"), - "data:,Hello%2C%20World%21", - ) From 1da0f79d5455b594f2aa989106a672786f5b990f Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 16 Mar 2022 09:20:57 -0400 Subject: [PATCH 069/230] Refactor relations tests (#12232) * Moves the relation pagination tests to a separate class. * Move the assertion of the response code into the `_send_relation` helper. * Moves some helpers into the base-class. --- changelog.d/12232.misc | 1 + tests/rest/client/test_relations.py | 1389 +++++++++++++-------------- 2 files changed, 674 insertions(+), 716 deletions(-) create mode 100644 changelog.d/12232.misc diff --git a/changelog.d/12232.misc b/changelog.d/12232.misc new file mode 100644 index 000000000000..4a4132edff2c --- /dev/null +++ b/changelog.d/12232.misc @@ -0,0 +1 @@ +Refactor relations tests to improve code re-use. 
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index 0cbe6c0cf754..3dbd1304a87d 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -79,6 +79,7 @@ def _send_relation( content: Optional[dict] = None, access_token: Optional[str] = None, parent_id: Optional[str] = None, + expected_response_code: int = 200, ) -> FakeChannel: """Helper function to send a relation pointing at `self.parent_id` @@ -115,16 +116,50 @@ def _send_relation( content, access_token=access_token, ) + self.assertEqual(expected_response_code, channel.code, channel.json_body) return channel + def _get_related_events(self) -> List[str]: + """ + Requests /relations on the parent ID and returns a list of event IDs. + """ + # Request the relations of the event. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + return [ev["event_id"] for ev in channel.json_body["chunk"]] + + def _get_bundled_aggregations(self) -> JsonDict: + """ + Requests /event on the parent ID and returns the m.relations field (from unsigned), if it exists. + """ + # Fetch the bundled aggregations of the event. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/event/{self.parent_id}", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + return channel.json_body["unsigned"].get("m.relations", {}) + + def _get_aggregations(self) -> List[JsonDict]: + """Request /aggregations on the parent ID and includes the returned chunk.""" + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + return channel.json_body["chunk"] + class RelationsTestCase(BaseRelationsTestCase): def test_send_relation(self) -> None: """Tests that sending a relation works.""" - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍") - self.assertEqual(200, channel.code, channel.json_body) - event_id = channel.json_body["event_id"] channel = self.make_request( @@ -151,13 +186,13 @@ def test_send_relation(self) -> None: def test_deny_invalid_event(self) -> None: """Test that we deny relations on non-existant events""" - channel = self._send_relation( + self._send_relation( RelationTypes.ANNOTATION, EventTypes.Message, parent_id="foo", content={"body": "foo", "msgtype": "m.text"}, + expected_response_code=400, ) - self.assertEqual(400, channel.code, channel.json_body) # Unless that event is referenced from another event! self.get_success( @@ -171,13 +206,12 @@ def test_deny_invalid_event(self) -> None: desc="test_deny_invalid_event", ) ) - channel = self._send_relation( + self._send_relation( RelationTypes.THREAD, EventTypes.Message, parent_id="foo", content={"body": "foo", "msgtype": "m.text"}, ) - self.assertEqual(200, channel.code, channel.json_body) def test_deny_invalid_room(self) -> None: """Test that we deny relations on non-existant events""" @@ -187,18 +221,20 @@ def test_deny_invalid_room(self) -> None: parent_id = res["event_id"] # Attempt to send an annotation to that event. 
- channel = self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", parent_id=parent_id, key="A" + self._send_relation( + RelationTypes.ANNOTATION, + "m.reaction", + parent_id=parent_id, + key="A", + expected_response_code=400, ) - self.assertEqual(400, channel.code, channel.json_body) def test_deny_double_react(self) -> None: """Test that we deny relations on membership events""" - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") - self.assertEqual(200, channel.code, channel.json_body) - - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self.assertEqual(400, channel.code, channel.json_body) + self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") + self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", "a", expected_response_code=400 + ) def test_deny_forked_thread(self) -> None: """It is invalid to start a thread off a thread.""" @@ -208,461 +244,160 @@ def test_deny_forked_thread(self) -> None: content={"msgtype": "m.text", "body": "foo"}, parent_id=self.parent_id, ) - self.assertEqual(200, channel.code, channel.json_body) parent_id = channel.json_body["event_id"] - channel = self._send_relation( + self._send_relation( RelationTypes.THREAD, "m.room.message", content={"msgtype": "m.text", "body": "foo"}, parent_id=parent_id, + expected_response_code=400, ) - self.assertEqual(400, channel.code, channel.json_body) - def test_basic_paginate_relations(self) -> None: - """Tests that calling pagination API correctly the latest relations.""" - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self.assertEqual(200, channel.code, channel.json_body) - first_annotation_id = channel.json_body["event_id"] + def test_aggregation(self) -> None: + """Test that annotations get correctly aggregated.""" - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") - self.assertEqual(200, channel.code, channel.json_body) - second_annotation_id = channel.json_body["event_id"] + self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token + ) + self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") channel = self.make_request( "GET", - f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1", + f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}", access_token=self.user_token, ) self.assertEqual(200, channel.code, channel.json_body) - # We expect to get back a single pagination result, which is the latest - # full relation event we sent above. - self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) - self.assert_dict( + self.assertEqual( + channel.json_body, { - "event_id": second_annotation_id, - "sender": self.user_id, - "type": "m.reaction", + "chunk": [ + {"type": "m.reaction", "key": "a", "count": 2}, + {"type": "m.reaction", "key": "b", "count": 1}, + ] }, - channel.json_body["chunk"][0], - ) - - # We also expect to get the original event (the id of which is self.parent_id) - self.assertEqual( - channel.json_body["original_event"]["event_id"], self.parent_id ) - # Make sure next_batch has something in it that looks like it could be a - # valid token. 
- self.assertIsInstance( - channel.json_body.get("next_batch"), str, channel.json_body - ) + def test_aggregation_must_be_annotation(self) -> None: + """Test that aggregations must be annotations.""" - # Request the relations again, but with a different direction. channel = self.make_request( "GET", - f"/_matrix/client/unstable/rooms/{self.room}/relations" - f"/{self.parent_id}?limit=1&org.matrix.msc3715.dir=f", + f"/_matrix/client/unstable/rooms/{self.room}/aggregations" + f"/{self.parent_id}/{RelationTypes.REPLACE}?limit=1", access_token=self.user_token, ) - self.assertEqual(200, channel.code, channel.json_body) + self.assertEqual(400, channel.code, channel.json_body) - # We expect to get back a single pagination result, which is the earliest - # full relation event we sent above. - self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) - self.assert_dict( - { - "event_id": first_annotation_id, - "sender": self.user_id, - "type": "m.reaction", - }, - channel.json_body["chunk"][0], - ) + @unittest.override_config({"experimental_features": {"msc3666_enabled": True}}) + def test_bundled_aggregations(self) -> None: + """ + Test that annotations, references, and threads get correctly bundled. - def test_repeated_paginate_relations(self) -> None: - """Test that if we paginate using a limit and tokens then we get the - expected events. + Note that this doesn't test against /relations since only thread relations + get bundled via that API. See test_aggregation_get_event_for_thread. + + See test_edit for a similar test for edits. """ + # Setup by sending a variety of relations. + self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token + ) + self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") - expected_event_ids = [] - for idx in range(10): - channel = self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", chr(ord("a") + idx) - ) - self.assertEqual(200, channel.code, channel.json_body) - expected_event_ids.append(channel.json_body["event_id"]) + channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test") + reply_1 = channel.json_body["event_id"] - prev_token = "" - found_event_ids: List[str] = [] - for _ in range(20): - from_token = "" - if prev_token: - from_token = "&from=" + prev_token + channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test") + reply_2 = channel.json_body["event_id"] - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) + self._send_relation(RelationTypes.THREAD, "m.room.test") - found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) - next_batch = channel.json_body.get("next_batch") + channel = self._send_relation(RelationTypes.THREAD, "m.room.test") + thread_2 = channel.json_body["event_id"] - self.assertNotEqual(prev_token, next_batch) - prev_token = next_batch + def assert_bundle(event_json: JsonDict) -> None: + """Assert the expected values of the bundled aggregations.""" + relations_dict = event_json["unsigned"].get("m.relations") - if not prev_token: - break + # Ensure the fields are as expected. 
+ self.assertCountEqual( + relations_dict.keys(), + ( + RelationTypes.ANNOTATION, + RelationTypes.REFERENCE, + RelationTypes.THREAD, + ), + ) - # We paginated backwards, so reverse - found_event_ids.reverse() - self.assertEqual(found_event_ids, expected_event_ids) + # Check the values of each field. + self.assertEqual( + { + "chunk": [ + {"type": "m.reaction", "key": "a", "count": 2}, + {"type": "m.reaction", "key": "b", "count": 1}, + ] + }, + relations_dict[RelationTypes.ANNOTATION], + ) - def test_pagination_from_sync_and_messages(self) -> None: - """Pagination tokens from /sync and /messages can be used to paginate /relations.""" - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "A") - self.assertEqual(200, channel.code, channel.json_body) - annotation_id = channel.json_body["event_id"] - # Send an event after the relation events. - self.helper.send(self.room, body="Latest event", tok=self.user_token) + self.assertEqual( + {"chunk": [{"event_id": reply_1}, {"event_id": reply_2}]}, + relations_dict[RelationTypes.REFERENCE], + ) - # Request /sync, limiting it such that only the latest event is returned - # (and not the relation). - filter = urllib.parse.quote_plus(b'{"room": {"timeline": {"limit": 1}}}') + self.assertEqual( + 2, + relations_dict[RelationTypes.THREAD].get("count"), + ) + self.assertTrue( + relations_dict[RelationTypes.THREAD].get("current_user_participated") + ) + # The latest thread event has some fields that don't matter. + self.assert_dict( + { + "content": { + "m.relates_to": { + "event_id": self.parent_id, + "rel_type": RelationTypes.THREAD, + } + }, + "event_id": thread_2, + "sender": self.user_id, + "type": "m.room.test", + }, + relations_dict[RelationTypes.THREAD].get("latest_event"), + ) + + # Request the event directly. channel = self.make_request( - "GET", f"/sync?filter={filter}", access_token=self.user_token + "GET", + f"/rooms/{self.room}/event/{self.parent_id}", + access_token=self.user_token, ) self.assertEqual(200, channel.code, channel.json_body) - room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] - sync_prev_batch = room_timeline["prev_batch"] - self.assertIsNotNone(sync_prev_batch) - # Ensure the relation event is not in the batch returned from /sync. - self.assertNotIn( - annotation_id, [ev["event_id"] for ev in room_timeline["events"]] - ) + assert_bundle(channel.json_body) - # Request /messages, limiting it such that only the latest event is - # returned (and not the relation). + # Request the room messages. channel = self.make_request( "GET", - f"/rooms/{self.room}/messages?dir=b&limit=1", + f"/rooms/{self.room}/messages?dir=b", access_token=self.user_token, ) self.assertEqual(200, channel.code, channel.json_body) - messages_end = channel.json_body["end"] - self.assertIsNotNone(messages_end) - # Ensure the relation event is not in the chunk returned from /messages. - self.assertNotIn( - annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]] + assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"])) + + # Request the room context. + channel = self.make_request( + "GET", + f"/rooms/{self.room}/context/{self.parent_id}", + access_token=self.user_token, ) - - # Request /relations with the pagination tokens received from both the - # /sync and /messages responses above, in turn. - # - # This is a tiny bit silly since the client wouldn't know the parent ID - # from the requests above; consider the parent ID to be known from a - # previous /sync. 
- for from_token in (sync_prev_batch, messages_end): - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?from={from_token}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - - # The relation should be in the returned chunk. - self.assertIn( - annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]] - ) - - def test_aggregation_pagination_groups(self) -> None: - """Test that we can paginate annotation groups correctly.""" - - # We need to create ten separate users to send each reaction. - access_tokens = [self.user_token, self.user2_token] - idx = 0 - while len(access_tokens) < 10: - user_id, token = self._create_user("test" + str(idx)) - idx += 1 - - self.helper.join(self.room, user=user_id, tok=token) - access_tokens.append(token) - - idx = 0 - sent_groups = {"👍": 10, "a": 7, "b": 5, "c": 3, "d": 2, "e": 1} - for key in itertools.chain.from_iterable( - itertools.repeat(key, num) for key, num in sent_groups.items() - ): - channel = self._send_relation( - RelationTypes.ANNOTATION, - "m.reaction", - key=key, - access_token=access_tokens[idx], - ) - self.assertEqual(200, channel.code, channel.json_body) - - idx += 1 - idx %= len(access_tokens) - - prev_token: Optional[str] = None - found_groups: Dict[str, int] = {} - for _ in range(20): - from_token = "" - if prev_token: - from_token = "&from=" + prev_token - - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}?limit=1{from_token}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - - self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) - - for groups in channel.json_body["chunk"]: - # We only expect reactions - self.assertEqual(groups["type"], "m.reaction", channel.json_body) - - # We should only see each key once - self.assertNotIn(groups["key"], found_groups, channel.json_body) - - found_groups[groups["key"]] = groups["count"] - - next_batch = channel.json_body.get("next_batch") - - self.assertNotEqual(prev_token, next_batch) - prev_token = next_batch - - if not prev_token: - break - - self.assertEqual(sent_groups, found_groups) - - def test_aggregation_pagination_within_group(self) -> None: - """Test that we can paginate within an annotation group.""" - - # We need to create ten separate users to send each reaction. 
- access_tokens = [self.user_token, self.user2_token] - idx = 0 - while len(access_tokens) < 10: - user_id, token = self._create_user("test" + str(idx)) - idx += 1 - - self.helper.join(self.room, user=user_id, tok=token) - access_tokens.append(token) - - idx = 0 - expected_event_ids = [] - for _ in range(10): - channel = self._send_relation( - RelationTypes.ANNOTATION, - "m.reaction", - key="👍", - access_token=access_tokens[idx], - ) - self.assertEqual(200, channel.code, channel.json_body) - expected_event_ids.append(channel.json_body["event_id"]) - - idx += 1 - - # Also send a different type of reaction so that we test we don't see it - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") - self.assertEqual(200, channel.code, channel.json_body) - - prev_token = "" - found_event_ids: List[str] = [] - encoded_key = urllib.parse.quote_plus("👍".encode()) - for _ in range(20): - from_token = "" - if prev_token: - from_token = "&from=" + prev_token - - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}" - f"/aggregations/{self.parent_id}/{RelationTypes.ANNOTATION}" - f"/m.reaction/{encoded_key}?limit=1{from_token}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - - self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) - - found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) - - next_batch = channel.json_body.get("next_batch") - - self.assertNotEqual(prev_token, next_batch) - prev_token = next_batch - - if not prev_token: - break - - # We paginated backwards, so reverse - found_event_ids.reverse() - self.assertEqual(found_event_ids, expected_event_ids) - - def test_aggregation(self) -> None: - """Test that annotations get correctly aggregated.""" - - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self.assertEqual(200, channel.code, channel.json_body) - - channel = self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token - ) - self.assertEqual(200, channel.code, channel.json_body) - - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") - self.assertEqual(200, channel.code, channel.json_body) - - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - - self.assertEqual( - channel.json_body, - { - "chunk": [ - {"type": "m.reaction", "key": "a", "count": 2}, - {"type": "m.reaction", "key": "b", "count": 1}, - ] - }, - ) - - def test_aggregation_must_be_annotation(self) -> None: - """Test that aggregations must be annotations.""" - - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/aggregations" - f"/{self.parent_id}/{RelationTypes.REPLACE}?limit=1", - access_token=self.user_token, - ) - self.assertEqual(400, channel.code, channel.json_body) - - @unittest.override_config({"experimental_features": {"msc3666_enabled": True}}) - def test_bundled_aggregations(self) -> None: - """ - Test that annotations, references, and threads get correctly bundled. - - Note that this doesn't test against /relations since only thread relations - get bundled via that API. See test_aggregation_get_event_for_thread. - - See test_edit for a similar test for edits. - """ - # Setup by sending a variety of relations. 
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self.assertEqual(200, channel.code, channel.json_body) - - channel = self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token - ) - self.assertEqual(200, channel.code, channel.json_body) - - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") - self.assertEqual(200, channel.code, channel.json_body) - - channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test") - self.assertEqual(200, channel.code, channel.json_body) - reply_1 = channel.json_body["event_id"] - - channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test") - self.assertEqual(200, channel.code, channel.json_body) - reply_2 = channel.json_body["event_id"] - - channel = self._send_relation(RelationTypes.THREAD, "m.room.test") - self.assertEqual(200, channel.code, channel.json_body) - - channel = self._send_relation(RelationTypes.THREAD, "m.room.test") - self.assertEqual(200, channel.code, channel.json_body) - thread_2 = channel.json_body["event_id"] - - def assert_bundle(event_json: JsonDict) -> None: - """Assert the expected values of the bundled aggregations.""" - relations_dict = event_json["unsigned"].get("m.relations") - - # Ensure the fields are as expected. - self.assertCountEqual( - relations_dict.keys(), - ( - RelationTypes.ANNOTATION, - RelationTypes.REFERENCE, - RelationTypes.THREAD, - ), - ) - - # Check the values of each field. - self.assertEqual( - { - "chunk": [ - {"type": "m.reaction", "key": "a", "count": 2}, - {"type": "m.reaction", "key": "b", "count": 1}, - ] - }, - relations_dict[RelationTypes.ANNOTATION], - ) - - self.assertEqual( - {"chunk": [{"event_id": reply_1}, {"event_id": reply_2}]}, - relations_dict[RelationTypes.REFERENCE], - ) - - self.assertEqual( - 2, - relations_dict[RelationTypes.THREAD].get("count"), - ) - self.assertTrue( - relations_dict[RelationTypes.THREAD].get("current_user_participated") - ) - # The latest thread event has some fields that don't matter. - self.assert_dict( - { - "content": { - "m.relates_to": { - "event_id": self.parent_id, - "rel_type": RelationTypes.THREAD, - } - }, - "event_id": thread_2, - "sender": self.user_id, - "type": "m.room.test", - }, - relations_dict[RelationTypes.THREAD].get("latest_event"), - ) - - # Request the event directly. - channel = self.make_request( - "GET", - f"/rooms/{self.room}/event/{self.parent_id}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - assert_bundle(channel.json_body) - - # Request the room messages. - channel = self.make_request( - "GET", - f"/rooms/{self.room}/messages?dir=b", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"])) - - # Request the room context. - channel = self.make_request( - "GET", - f"/rooms/{self.room}/context/{self.parent_id}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - assert_bundle(channel.json_body["event"]) + self.assertEqual(200, channel.code, channel.json_body) + assert_bundle(channel.json_body["event"]) # Request sync. channel = self.make_request("GET", "/sync", access_token=self.user_token) @@ -693,14 +428,12 @@ def test_aggregation_get_event_for_annotation(self) -> None: when directly requested. 
""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self.assertEqual(200, channel.code, channel.json_body) annotation_id = channel.json_body["event_id"] # Annotate the annotation. - channel = self._send_relation( + self._send_relation( RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=annotation_id ) - self.assertEqual(200, channel.code, channel.json_body) channel = self.make_request( "GET", @@ -713,14 +446,12 @@ def test_aggregation_get_event_for_annotation(self) -> None: def test_aggregation_get_event_for_thread(self) -> None: """Test that threads get bundled aggregations included when directly requested.""" channel = self._send_relation(RelationTypes.THREAD, "m.room.test") - self.assertEqual(200, channel.code, channel.json_body) thread_id = channel.json_body["event_id"] # Annotate the annotation. - channel = self._send_relation( + self._send_relation( RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=thread_id ) - self.assertEqual(200, channel.code, channel.json_body) channel = self.make_request( "GET", @@ -877,8 +608,6 @@ def test_edit(self) -> None: "m.room.message", content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, ) - self.assertEqual(200, channel.code, channel.json_body) - edit_event_id = channel.json_body["event_id"] def assert_bundle(event_json: JsonDict) -> None: @@ -954,7 +683,7 @@ def test_multi_edit(self) -> None: shouldn't be allowed, are correctly handled. """ - channel = self._send_relation( + self._send_relation( RelationTypes.REPLACE, "m.room.message", content={ @@ -963,309 +692,575 @@ def test_multi_edit(self) -> None: "m.new_content": {"msgtype": "m.text", "body": "First edit"}, }, ) + + new_body = {"msgtype": "m.text", "body": "I've been edited!"} + channel = self._send_relation( + RelationTypes.REPLACE, + "m.room.message", + content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, + ) + edit_event_id = channel.json_body["event_id"] + + self._send_relation( + RelationTypes.REPLACE, + "m.room.message.WRONG_TYPE", + content={ + "msgtype": "m.text", + "body": "Wibble", + "m.new_content": {"msgtype": "m.text", "body": "Edit, but wrong type"}, + }, + ) + + channel = self.make_request( + "GET", + f"/rooms/{self.room}/event/{self.parent_id}", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + + self.assertEqual(channel.json_body["content"], new_body) + + relations_dict = channel.json_body["unsigned"].get("m.relations") + self.assertIn(RelationTypes.REPLACE, relations_dict) + + m_replace_dict = relations_dict[RelationTypes.REPLACE] + for key in ["event_id", "sender", "origin_server_ts"]: + self.assertIn(key, m_replace_dict) + + self.assert_dict( + {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict + ) + + def test_edit_reply(self) -> None: + """Test that editing a reply works.""" + + # Create a reply to edit. 
+ channel = self._send_relation( + RelationTypes.REFERENCE, + "m.room.message", + content={"msgtype": "m.text", "body": "A reply!"}, + ) + reply = channel.json_body["event_id"] + + new_body = {"msgtype": "m.text", "body": "I've been edited!"} + channel = self._send_relation( + RelationTypes.REPLACE, + "m.room.message", + content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, + parent_id=reply, + ) + edit_event_id = channel.json_body["event_id"] + + channel = self.make_request( + "GET", + f"/rooms/{self.room}/event/{reply}", + access_token=self.user_token, + ) self.assertEqual(200, channel.code, channel.json_body) + # We expect to see the new body in the dict, as well as the reference + # metadata sill intact. + self.assertDictContainsSubset(new_body, channel.json_body["content"]) + self.assertDictContainsSubset( + { + "m.relates_to": { + "event_id": self.parent_id, + "rel_type": "m.reference", + } + }, + channel.json_body["content"], + ) + + # We expect that the edit relation appears in the unsigned relations + # section. + relations_dict = channel.json_body["unsigned"].get("m.relations") + self.assertIn(RelationTypes.REPLACE, relations_dict) + + m_replace_dict = relations_dict[RelationTypes.REPLACE] + for key in ["event_id", "sender", "origin_server_ts"]: + self.assertIn(key, m_replace_dict) + + self.assert_dict( + {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict + ) + + def test_edit_thread(self) -> None: + """Test that editing a thread works.""" + + # Create a thread and edit the last event. + channel = self._send_relation( + RelationTypes.THREAD, + "m.room.message", + content={"msgtype": "m.text", "body": "A threaded reply!"}, + ) + threaded_event_id = channel.json_body["event_id"] + new_body = {"msgtype": "m.text", "body": "I've been edited!"} channel = self._send_relation( RelationTypes.REPLACE, "m.room.message", content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, + parent_id=threaded_event_id, + ) + + # Fetch the thread root, to get the bundled aggregation for the thread. + channel = self.make_request( + "GET", + f"/rooms/{self.room}/event/{self.parent_id}", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + + # We expect that the edit message appears in the thread summary in the + # unsigned relations section. + relations_dict = channel.json_body["unsigned"].get("m.relations") + self.assertIn(RelationTypes.THREAD, relations_dict) + + thread_summary = relations_dict[RelationTypes.THREAD] + self.assertIn("latest_event", thread_summary) + latest_event_in_thread = thread_summary["latest_event"] + self.assertEqual(latest_event_in_thread["content"]["body"], "I've been edited!") + + def test_edit_edit(self) -> None: + """Test that an edit cannot be edited.""" + new_body = {"msgtype": "m.text", "body": "Initial edit"} + channel = self._send_relation( + RelationTypes.REPLACE, + "m.room.message", + content={ + "msgtype": "m.text", + "body": "Wibble", + "m.new_content": new_body, + }, + ) + edit_event_id = channel.json_body["event_id"] + + # Edit the edit event. + channel = self._send_relation( + RelationTypes.REPLACE, + "m.room.message", + content={ + "msgtype": "m.text", + "body": "foo", + "m.new_content": {"msgtype": "m.text", "body": "Ignored edit"}, + }, + parent_id=edit_event_id, + ) + + # Request the original event. 
+ channel = self.make_request( + "GET", + f"/rooms/{self.room}/event/{self.parent_id}", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + # The edit to the edit should be ignored. + self.assertEqual(channel.json_body["content"], new_body) + + # The relations information should not include the edit to the edit. + relations_dict = channel.json_body["unsigned"].get("m.relations") + self.assertIn(RelationTypes.REPLACE, relations_dict) + + m_replace_dict = relations_dict[RelationTypes.REPLACE] + for key in ["event_id", "sender", "origin_server_ts"]: + self.assertIn(key, m_replace_dict) + + self.assert_dict( + {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict + ) + + def test_unknown_relations(self) -> None: + """Unknown relations should be accepted.""" + channel = self._send_relation("m.relation.test", "m.room.test") + event_id = channel.json_body["event_id"] + + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + + # We expect to get back a single pagination result, which is the full + # relation event we sent above. + self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) + self.assert_dict( + {"event_id": event_id, "sender": self.user_id, "type": "m.room.test"}, + channel.json_body["chunk"][0], + ) + + # We also expect to get the original event (the id of which is self.parent_id) + self.assertEqual( + channel.json_body["original_event"]["event_id"], self.parent_id + ) + + # When bundling the unknown relation is not included. + channel = self.make_request( + "GET", + f"/rooms/{self.room}/event/{self.parent_id}", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + self.assertNotIn("m.relations", channel.json_body["unsigned"]) + + # But unknown relations can be directly queried. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}?limit=1", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + self.assertEqual(channel.json_body["chunk"], []) + + def _find_event_in_chunk(self, events: List[JsonDict]) -> JsonDict: + """ + Find the parent event in a chunk of events and assert that it has the proper bundled aggregations. + """ + for event in events: + if event["event_id"] == self.parent_id: + return event + + raise AssertionError(f"Event {self.parent_id} not found in chunk") + + def test_background_update(self) -> None: + """Test the event_arbitrary_relations background update.""" + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍") + annotation_event_id_good = channel.json_body["event_id"] + + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="A") + annotation_event_id_bad = channel.json_body["event_id"] + + channel = self._send_relation(RelationTypes.THREAD, "m.room.test") + thread_event_id = channel.json_body["event_id"] + + # Clean-up the table as if the inserts did not happen during event creation. + self.get_success( + self.store.db_pool.simple_delete_many( + table="event_relations", + column="event_id", + iterable=(annotation_event_id_bad, thread_event_id), + keyvalues={}, + desc="RelationsTestCase.test_background_update", + ) + ) + + # Only the "good" annotation should be found. 
+ channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=10", + access_token=self.user_token, ) self.assertEqual(200, channel.code, channel.json_body) + self.assertEqual( + [ev["event_id"] for ev in channel.json_body["chunk"]], + [annotation_event_id_good], + ) - edit_event_id = channel.json_body["event_id"] - - channel = self._send_relation( - RelationTypes.REPLACE, - "m.room.message.WRONG_TYPE", - content={ - "msgtype": "m.text", - "body": "Wibble", - "m.new_content": {"msgtype": "m.text", "body": "Edit, but wrong type"}, - }, + # Insert and run the background update. + self.get_success( + self.store.db_pool.simple_insert( + "background_updates", + {"update_name": "event_arbitrary_relations", "progress_json": "{}"}, + ) ) - self.assertEqual(200, channel.code, channel.json_body) + # Ugh, have to reset this flag + self.store.db_pool.updates._all_done = False + self.wait_for_background_updates() + + # The "good" annotation and the thread should be found, but not the "bad" + # annotation. channel = self.make_request( "GET", - f"/rooms/{self.room}/event/{self.parent_id}", + f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=10", access_token=self.user_token, ) self.assertEqual(200, channel.code, channel.json_body) + self.assertCountEqual( + [ev["event_id"] for ev in channel.json_body["chunk"]], + [annotation_event_id_good, thread_event_id], + ) - self.assertEqual(channel.json_body["content"], new_body) - relations_dict = channel.json_body["unsigned"].get("m.relations") - self.assertIn(RelationTypes.REPLACE, relations_dict) +class RelationPaginationTestCase(BaseRelationsTestCase): + def test_basic_paginate_relations(self) -> None: + """Tests that calling pagination API correctly the latest relations.""" + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + first_annotation_id = channel.json_body["event_id"] - m_replace_dict = relations_dict[RelationTypes.REPLACE] - for key in ["event_id", "sender", "origin_server_ts"]: - self.assertIn(key, m_replace_dict) + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") + second_annotation_id = channel.json_body["event_id"] - self.assert_dict( - {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1", + access_token=self.user_token, ) + self.assertEqual(200, channel.code, channel.json_body) - def test_edit_reply(self) -> None: - """Test that editing a reply works.""" - - # Create a reply to edit. - channel = self._send_relation( - RelationTypes.REFERENCE, - "m.room.message", - content={"msgtype": "m.text", "body": "A reply!"}, + # We expect to get back a single pagination result, which is the latest + # full relation event we sent above. 
+ self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) + self.assert_dict( + { + "event_id": second_annotation_id, + "sender": self.user_id, + "type": "m.reaction", + }, + channel.json_body["chunk"][0], ) - self.assertEqual(200, channel.code, channel.json_body) - reply = channel.json_body["event_id"] - new_body = {"msgtype": "m.text", "body": "I've been edited!"} - channel = self._send_relation( - RelationTypes.REPLACE, - "m.room.message", - content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, - parent_id=reply, + # We also expect to get the original event (the id of which is self.parent_id) + self.assertEqual( + channel.json_body["original_event"]["event_id"], self.parent_id ) - self.assertEqual(200, channel.code, channel.json_body) - edit_event_id = channel.json_body["event_id"] + # Make sure next_batch has something in it that looks like it could be a + # valid token. + self.assertIsInstance( + channel.json_body.get("next_batch"), str, channel.json_body + ) + # Request the relations again, but with a different direction. channel = self.make_request( "GET", - f"/rooms/{self.room}/event/{reply}", + f"/_matrix/client/unstable/rooms/{self.room}/relations" + f"/{self.parent_id}?limit=1&org.matrix.msc3715.dir=f", access_token=self.user_token, ) self.assertEqual(200, channel.code, channel.json_body) - # We expect to see the new body in the dict, as well as the reference - # metadata sill intact. - self.assertDictContainsSubset(new_body, channel.json_body["content"]) - self.assertDictContainsSubset( + # We expect to get back a single pagination result, which is the earliest + # full relation event we sent above. + self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) + self.assert_dict( { - "m.relates_to": { - "event_id": self.parent_id, - "rel_type": "m.reference", - } + "event_id": first_annotation_id, + "sender": self.user_id, + "type": "m.reaction", }, - channel.json_body["content"], + channel.json_body["chunk"][0], ) - # We expect that the edit relation appears in the unsigned relations - # section. - relations_dict = channel.json_body["unsigned"].get("m.relations") - self.assertIn(RelationTypes.REPLACE, relations_dict) + def test_repeated_paginate_relations(self) -> None: + """Test that if we paginate using a limit and tokens then we get the + expected events. + """ - m_replace_dict = relations_dict[RelationTypes.REPLACE] - for key in ["event_id", "sender", "origin_server_ts"]: - self.assertIn(key, m_replace_dict) + expected_event_ids = [] + for idx in range(10): + channel = self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", chr(ord("a") + idx) + ) + expected_event_ids.append(channel.json_body["event_id"]) - self.assert_dict( - {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict - ) + prev_token = "" + found_event_ids: List[str] = [] + for _ in range(20): + from_token = "" + if prev_token: + from_token = "&from=" + prev_token - def test_edit_thread(self) -> None: - """Test that editing a thread works.""" + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) - # Create a thread and edit the last event. 
- channel = self._send_relation( - RelationTypes.THREAD, - "m.room.message", - content={"msgtype": "m.text", "body": "A threaded reply!"}, - ) - self.assertEqual(200, channel.code, channel.json_body) - threaded_event_id = channel.json_body["event_id"] + found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) + next_batch = channel.json_body.get("next_batch") - new_body = {"msgtype": "m.text", "body": "I've been edited!"} - channel = self._send_relation( - RelationTypes.REPLACE, - "m.room.message", - content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, - parent_id=threaded_event_id, + self.assertNotEqual(prev_token, next_batch) + prev_token = next_batch + + if not prev_token: + break + + # We paginated backwards, so reverse + found_event_ids.reverse() + self.assertEqual(found_event_ids, expected_event_ids) + + def test_pagination_from_sync_and_messages(self) -> None: + """Pagination tokens from /sync and /messages can be used to paginate /relations.""" + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "A") + annotation_id = channel.json_body["event_id"] + # Send an event after the relation events. + self.helper.send(self.room, body="Latest event", tok=self.user_token) + + # Request /sync, limiting it such that only the latest event is returned + # (and not the relation). + filter = urllib.parse.quote_plus(b'{"room": {"timeline": {"limit": 1}}}') + channel = self.make_request( + "GET", f"/sync?filter={filter}", access_token=self.user_token ) self.assertEqual(200, channel.code, channel.json_body) + room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] + sync_prev_batch = room_timeline["prev_batch"] + self.assertIsNotNone(sync_prev_batch) + # Ensure the relation event is not in the batch returned from /sync. + self.assertNotIn( + annotation_id, [ev["event_id"] for ev in room_timeline["events"]] + ) - # Fetch the thread root, to get the bundled aggregation for the thread. + # Request /messages, limiting it such that only the latest event is + # returned (and not the relation). channel = self.make_request( "GET", - f"/rooms/{self.room}/event/{self.parent_id}", + f"/rooms/{self.room}/messages?dir=b&limit=1", access_token=self.user_token, ) self.assertEqual(200, channel.code, channel.json_body) + messages_end = channel.json_body["end"] + self.assertIsNotNone(messages_end) + # Ensure the relation event is not in the chunk returned from /messages. + self.assertNotIn( + annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]] + ) - # We expect that the edit message appears in the thread summary in the - # unsigned relations section. - relations_dict = channel.json_body["unsigned"].get("m.relations") - self.assertIn(RelationTypes.THREAD, relations_dict) + # Request /relations with the pagination tokens received from both the + # /sync and /messages responses above, in turn. + # + # This is a tiny bit silly since the client wouldn't know the parent ID + # from the requests above; consider the parent ID to be known from a + # previous /sync. 
+ for from_token in (sync_prev_batch, messages_end): + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?from={from_token}", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) - thread_summary = relations_dict[RelationTypes.THREAD] - self.assertIn("latest_event", thread_summary) - latest_event_in_thread = thread_summary["latest_event"] - self.assertEqual(latest_event_in_thread["content"]["body"], "I've been edited!") + # The relation should be in the returned chunk. + self.assertIn( + annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]] + ) + + def test_aggregation_pagination_groups(self) -> None: + """Test that we can paginate annotation groups correctly.""" + + # We need to create ten separate users to send each reaction. + access_tokens = [self.user_token, self.user2_token] + idx = 0 + while len(access_tokens) < 10: + user_id, token = self._create_user("test" + str(idx)) + idx += 1 + + self.helper.join(self.room, user=user_id, tok=token) + access_tokens.append(token) + + idx = 0 + sent_groups = {"👍": 10, "a": 7, "b": 5, "c": 3, "d": 2, "e": 1} + for key in itertools.chain.from_iterable( + itertools.repeat(key, num) for key, num in sent_groups.items() + ): + self._send_relation( + RelationTypes.ANNOTATION, + "m.reaction", + key=key, + access_token=access_tokens[idx], + ) + + idx += 1 + idx %= len(access_tokens) - def test_edit_edit(self) -> None: - """Test that an edit cannot be edited.""" - new_body = {"msgtype": "m.text", "body": "Initial edit"} - channel = self._send_relation( - RelationTypes.REPLACE, - "m.room.message", - content={ - "msgtype": "m.text", - "body": "Wibble", - "m.new_content": new_body, - }, - ) - self.assertEqual(200, channel.code, channel.json_body) - edit_event_id = channel.json_body["event_id"] + prev_token: Optional[str] = None + found_groups: Dict[str, int] = {} + for _ in range(20): + from_token = "" + if prev_token: + from_token = "&from=" + prev_token - # Edit the edit event. - channel = self._send_relation( - RelationTypes.REPLACE, - "m.room.message", - content={ - "msgtype": "m.text", - "body": "foo", - "m.new_content": {"msgtype": "m.text", "body": "Ignored edit"}, - }, - parent_id=edit_event_id, - ) - self.assertEqual(200, channel.code, channel.json_body) + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}?limit=1{from_token}", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) - # Request the original event. - channel = self.make_request( - "GET", - f"/rooms/{self.room}/event/{self.parent_id}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - # The edit to the edit should be ignored. - self.assertEqual(channel.json_body["content"], new_body) + self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) - # The relations information should not include the edit to the edit. 
- relations_dict = channel.json_body["unsigned"].get("m.relations") - self.assertIn(RelationTypes.REPLACE, relations_dict) + for groups in channel.json_body["chunk"]: + # We only expect reactions + self.assertEqual(groups["type"], "m.reaction", channel.json_body) - m_replace_dict = relations_dict[RelationTypes.REPLACE] - for key in ["event_id", "sender", "origin_server_ts"]: - self.assertIn(key, m_replace_dict) + # We should only see each key once + self.assertNotIn(groups["key"], found_groups, channel.json_body) - self.assert_dict( - {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict - ) + found_groups[groups["key"]] = groups["count"] - def test_unknown_relations(self) -> None: - """Unknown relations should be accepted.""" - channel = self._send_relation("m.relation.test", "m.room.test") - self.assertEqual(200, channel.code, channel.json_body) - event_id = channel.json_body["event_id"] + next_batch = channel.json_body.get("next_batch") - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) + self.assertNotEqual(prev_token, next_batch) + prev_token = next_batch - # We expect to get back a single pagination result, which is the full - # relation event we sent above. - self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) - self.assert_dict( - {"event_id": event_id, "sender": self.user_id, "type": "m.room.test"}, - channel.json_body["chunk"][0], - ) + if not prev_token: + break - # We also expect to get the original event (the id of which is self.parent_id) - self.assertEqual( - channel.json_body["original_event"]["event_id"], self.parent_id - ) + self.assertEqual(sent_groups, found_groups) - # When bundling the unknown relation is not included. - channel = self.make_request( - "GET", - f"/rooms/{self.room}/event/{self.parent_id}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - self.assertNotIn("m.relations", channel.json_body["unsigned"]) + def test_aggregation_pagination_within_group(self) -> None: + """Test that we can paginate within an annotation group.""" - # But unknown relations can be directly queried. - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}?limit=1", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - self.assertEqual(channel.json_body["chunk"], []) + # We need to create ten separate users to send each reaction. + access_tokens = [self.user_token, self.user2_token] + idx = 0 + while len(access_tokens) < 10: + user_id, token = self._create_user("test" + str(idx)) + idx += 1 - def _find_event_in_chunk(self, events: List[JsonDict]) -> JsonDict: - """ - Find the parent event in a chunk of events and assert that it has the proper bundled aggregations. 
- """ - for event in events: - if event["event_id"] == self.parent_id: - return event + self.helper.join(self.room, user=user_id, tok=token) + access_tokens.append(token) - raise AssertionError(f"Event {self.parent_id} not found in chunk") + idx = 0 + expected_event_ids = [] + for _ in range(10): + channel = self._send_relation( + RelationTypes.ANNOTATION, + "m.reaction", + key="👍", + access_token=access_tokens[idx], + ) + expected_event_ids.append(channel.json_body["event_id"]) - def test_background_update(self) -> None: - """Test the event_arbitrary_relations background update.""" - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍") - self.assertEqual(200, channel.code, channel.json_body) - annotation_event_id_good = channel.json_body["event_id"] + idx += 1 - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="A") - self.assertEqual(200, channel.code, channel.json_body) - annotation_event_id_bad = channel.json_body["event_id"] + # Also send a different type of reaction so that we test we don't see it + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") - channel = self._send_relation(RelationTypes.THREAD, "m.room.test") - self.assertEqual(200, channel.code, channel.json_body) - thread_event_id = channel.json_body["event_id"] + prev_token = "" + found_event_ids: List[str] = [] + encoded_key = urllib.parse.quote_plus("👍".encode()) + for _ in range(20): + from_token = "" + if prev_token: + from_token = "&from=" + prev_token - # Clean-up the table as if the inserts did not happen during event creation. - self.get_success( - self.store.db_pool.simple_delete_many( - table="event_relations", - column="event_id", - iterable=(annotation_event_id_bad, thread_event_id), - keyvalues={}, - desc="RelationsTestCase.test_background_update", + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}" + f"/aggregations/{self.parent_id}/{RelationTypes.ANNOTATION}" + f"/m.reaction/{encoded_key}?limit=1{from_token}", + access_token=self.user_token, ) - ) + self.assertEqual(200, channel.code, channel.json_body) - # Only the "good" annotation should be found. - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=10", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - self.assertEqual( - [ev["event_id"] for ev in channel.json_body["chunk"]], - [annotation_event_id_good], - ) + self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) - # Insert and run the background update. - self.get_success( - self.store.db_pool.simple_insert( - "background_updates", - {"update_name": "event_arbitrary_relations", "progress_json": "{}"}, - ) - ) + found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) - # Ugh, have to reset this flag - self.store.db_pool.updates._all_done = False - self.wait_for_background_updates() + next_batch = channel.json_body.get("next_batch") - # The "good" annotation and the thread should be found, but not the "bad" - # annotation. 
- channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=10", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - self.assertCountEqual( - [ev["event_id"] for ev in channel.json_body["chunk"]], - [annotation_event_id_good, thread_event_id], - ) + self.assertNotEqual(prev_token, next_batch) + prev_token = next_batch + + if not prev_token: + break + + # We paginated backwards, so reverse + found_event_ids.reverse() + self.assertEqual(found_event_ids, expected_event_ids) class RelationRedactionTestCase(BaseRelationsTestCase): @@ -1294,46 +1289,6 @@ def _redact(self, event_id: str) -> None: ) self.assertEqual(200, channel.code, channel.json_body) - def _make_relation_requests(self) -> Tuple[List[str], JsonDict]: - """ - Makes requests and ensures they result in a 200 response, returns a - tuple of results: - - 1. `/relations` -> Returns a list of event IDs. - 2. `/event` -> Returns the response's m.relations field (from unsigned), - if it exists. - """ - - # Request the relations of the event. - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}", - access_token=self.user_token, - ) - self.assertEquals(200, channel.code, channel.json_body) - event_ids = [ev["event_id"] for ev in channel.json_body["chunk"]] - - # Fetch the bundled aggregations of the event. - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/event/{self.parent_id}", - access_token=self.user_token, - ) - self.assertEquals(200, channel.code, channel.json_body) - bundled_relations = channel.json_body["unsigned"].get("m.relations", {}) - - return event_ids, bundled_relations - - def _get_aggregations(self) -> List[JsonDict]: - """Request /aggregations on the parent ID and includes the returned chunk.""" - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - return channel.json_body["chunk"] - def test_redact_relation_annotation(self) -> None: """ Test that annotations of an event are properly handled after the @@ -1343,17 +1298,16 @@ def test_redact_relation_annotation(self) -> None: the response to relations. """ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self.assertEqual(200, channel.code, channel.json_body) to_redact_event_id = channel.json_body["event_id"] channel = self._send_relation( RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token ) - self.assertEqual(200, channel.code, channel.json_body) unredacted_event_id = channel.json_body["event_id"] # Both relations should exist. - event_ids, relations = self._make_relation_requests() + event_ids = self._get_related_events() + relations = self._get_bundled_aggregations() self.assertCountEqual(event_ids, [to_redact_event_id, unredacted_event_id]) self.assertEquals( relations["m.annotation"], @@ -1368,7 +1322,8 @@ def test_redact_relation_annotation(self) -> None: self._redact(to_redact_event_id) # The unredacted relation should still exist. 
- event_ids, relations = self._make_relation_requests() + event_ids = self._get_related_events() + relations = self._get_bundled_aggregations() self.assertEquals(event_ids, [unredacted_event_id]) self.assertEquals( relations["m.annotation"], @@ -1391,7 +1346,6 @@ def test_redact_relation_thread(self) -> None: EventTypes.Message, content={"body": "reply 1", "msgtype": "m.text"}, ) - self.assertEqual(200, channel.code, channel.json_body) unredacted_event_id = channel.json_body["event_id"] # Note that the *last* event in the thread is redacted, as that gets @@ -1401,11 +1355,11 @@ def test_redact_relation_thread(self) -> None: EventTypes.Message, content={"body": "reply 2", "msgtype": "m.text"}, ) - self.assertEqual(200, channel.code, channel.json_body) to_redact_event_id = channel.json_body["event_id"] # Both relations exist. - event_ids, relations = self._make_relation_requests() + event_ids = self._get_related_events() + relations = self._get_bundled_aggregations() self.assertEquals(event_ids, [to_redact_event_id, unredacted_event_id]) self.assertDictContainsSubset( { @@ -1424,7 +1378,8 @@ def test_redact_relation_thread(self) -> None: self._redact(to_redact_event_id) # The unredacted relation should still exist. - event_ids, relations = self._make_relation_requests() + event_ids = self._get_related_events() + relations = self._get_bundled_aggregations() self.assertEquals(event_ids, [unredacted_event_id]) self.assertDictContainsSubset( { @@ -1444,7 +1399,7 @@ def test_redact_parent_edit(self) -> None: is redacted. """ # Add a relation - channel = self._send_relation( + self._send_relation( RelationTypes.REPLACE, "m.room.message", parent_id=self.parent_id, @@ -1454,10 +1409,10 @@ def test_redact_parent_edit(self) -> None: "m.new_content": {"msgtype": "m.text", "body": "First edit"}, }, ) - self.assertEqual(200, channel.code, channel.json_body) # Check the relation is returned - event_ids, relations = self._make_relation_requests() + event_ids = self._get_related_events() + relations = self._get_bundled_aggregations() self.assertEqual(len(event_ids), 1) self.assertIn(RelationTypes.REPLACE, relations) @@ -1465,7 +1420,8 @@ def test_redact_parent_edit(self) -> None: self._redact(self.parent_id) # The relations are not returned. - event_ids, relations = self._make_relation_requests() + event_ids = self._get_related_events() + relations = self._get_bundled_aggregations() self.assertEqual(len(event_ids), 0) self.assertEqual(relations, {}) @@ -1475,11 +1431,11 @@ def test_redact_parent_annotation(self) -> None: """ # Add a relation channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍") - self.assertEqual(200, channel.code, channel.json_body) related_event_id = channel.json_body["event_id"] # The relations should exist. - event_ids, relations = self._make_relation_requests() + event_ids = self._get_related_events() + relations = self._get_bundled_aggregations() self.assertEqual(len(event_ids), 1) self.assertIn(RelationTypes.ANNOTATION, relations) @@ -1491,7 +1447,8 @@ def test_redact_parent_annotation(self) -> None: self._redact(self.parent_id) # The relations are returned. 
- event_ids, relations = self._make_relation_requests() + event_ids = self._get_related_events() + relations = self._get_bundled_aggregations() self.assertEquals(event_ids, [related_event_id]) self.assertEquals( relations["m.annotation"], @@ -1512,14 +1469,14 @@ def test_redact_parent_thread(self) -> None: EventTypes.Message, content={"body": "reply 1", "msgtype": "m.text"}, ) - self.assertEqual(200, channel.code, channel.json_body) related_event_id = channel.json_body["event_id"] # Redact one of the reactions. self._redact(self.parent_id) # The unredacted relation should still exist. - event_ids, relations = self._make_relation_requests() + event_ids = self._get_related_events() + relations = self._get_bundled_aggregations() self.assertEquals(len(event_ids), 1) self.assertDictContainsSubset( { From 86965605a4688d80dc0a74ed4993a52f282e970a Mon Sep 17 00:00:00 2001 From: David Robertson Date: Wed, 16 Mar 2022 13:52:59 +0000 Subject: [PATCH 070/230] Fix dead link in spam checker warning (#12231) --- changelog.d/12231.doc | 1 + synapse/config/spam_checker.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) create mode 100644 changelog.d/12231.doc diff --git a/changelog.d/12231.doc b/changelog.d/12231.doc new file mode 100644 index 000000000000..16593d2b9226 --- /dev/null +++ b/changelog.d/12231.doc @@ -0,0 +1 @@ +Fix the link to the module documentation in the legacy spam checker warning message. diff --git a/synapse/config/spam_checker.py b/synapse/config/spam_checker.py index a233a9ce0388..4c52103b1c29 100644 --- a/synapse/config/spam_checker.py +++ b/synapse/config/spam_checker.py @@ -25,8 +25,8 @@ LEGACY_SPAM_CHECKER_WARNING = """ This server is using a spam checker module that is implementing the deprecated spam checker interface. Please check with the module's maintainer to see if a new version -supporting Synapse's generic modules system is available. -For more information, please see https://matrix-org.github.io/synapse/latest/modules.html +supporting Synapse's generic modules system is available. For more information, please +see https://matrix-org.github.io/synapse/latest/modules/index.html ---------------------------------------------------------------------------------------""" From c486fa5fd9082643e40a55ffa59d902aa6db4c2b Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 16 Mar 2022 10:37:04 -0400 Subject: [PATCH 071/230] Add some missing type hints to cache datastore. (#12216) --- changelog.d/12216.misc | 1 + synapse/storage/databases/main/cache.py | 57 ++++++++++++++++--------- 2 files changed, 37 insertions(+), 21 deletions(-) create mode 100644 changelog.d/12216.misc diff --git a/changelog.d/12216.misc b/changelog.d/12216.misc new file mode 100644 index 000000000000..dc398ac1e098 --- /dev/null +++ b/changelog.d/12216.misc @@ -0,0 +1 @@ +Add missing type hints for cache storage. 
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index d6a2df1afeb6..2d7511d61391 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -23,6 +23,7 @@ EventsStream, EventsStreamCurrentStateRow, EventsStreamEventRow, + EventsStreamRow, ) from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( @@ -31,6 +32,7 @@ LoggingTransaction, ) from synapse.storage.engines import PostgresEngine +from synapse.util.caches.descriptors import _CachedFunction from synapse.util.iterutils import batch_iter if TYPE_CHECKING: @@ -82,7 +84,9 @@ async def get_all_updated_caches( if last_id == current_id: return [], current_id, False - def get_all_updated_caches_txn(txn): + def get_all_updated_caches_txn( + txn: LoggingTransaction, + ) -> Tuple[List[Tuple[int, tuple]], int, bool]: # We purposefully don't bound by the current token, as we want to # send across cache invalidations as quickly as possible. Cache # invalidations are idempotent, so duplicates are fine. @@ -107,7 +111,9 @@ def get_all_updated_caches_txn(txn): "get_all_updated_caches", get_all_updated_caches_txn ) - def process_replication_rows(self, stream_name, instance_name, token, rows): + def process_replication_rows( + self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any] + ) -> None: if stream_name == EventsStream.NAME: for row in rows: self._process_event_stream_row(token, row) @@ -142,10 +148,11 @@ def process_replication_rows(self, stream_name, instance_name, token, rows): super().process_replication_rows(stream_name, instance_name, token, rows) - def _process_event_stream_row(self, token, row): + def _process_event_stream_row(self, token: int, row: EventsStreamRow) -> None: data = row.data if row.type == EventsStreamEventRow.TypeId: + assert isinstance(data, EventsStreamEventRow) self._invalidate_caches_for_event( token, data.event_id, @@ -157,9 +164,8 @@ def _process_event_stream_row(self, token, row): backfilled=False, ) elif row.type == EventsStreamCurrentStateRow.TypeId: - self._curr_state_delta_stream_cache.entity_has_changed( - row.data.room_id, token - ) + assert isinstance(data, EventsStreamCurrentStateRow) + self._curr_state_delta_stream_cache.entity_has_changed(data.room_id, token) if data.type == EventTypes.Member: self.get_rooms_for_user_with_stream_ordering.invalidate( @@ -170,15 +176,15 @@ def _process_event_stream_row(self, token, row): def _invalidate_caches_for_event( self, - stream_ordering, - event_id, - room_id, - etype, - state_key, - redacts, - relates_to, - backfilled, - ): + stream_ordering: int, + event_id: str, + room_id: str, + etype: str, + state_key: Optional[str], + redacts: Optional[str], + relates_to: Optional[str], + backfilled: bool, + ) -> None: self._invalidate_get_event_cache(event_id) self.have_seen_event.invalidate((room_id, event_id)) @@ -207,7 +213,9 @@ def _invalidate_caches_for_event( self.get_thread_summary.invalidate((relates_to,)) self.get_thread_participated.invalidate((relates_to,)) - async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]): + async def invalidate_cache_and_stream( + self, cache_name: str, keys: Tuple[Any, ...] + ) -> None: """Invalidates the cache and adds it to the cache stream so slaves will know to invalidate their caches. @@ -227,7 +235,12 @@ async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, .. 
keys, ) - def _invalidate_cache_and_stream(self, txn, cache_func, keys): + def _invalidate_cache_and_stream( + self, + txn: LoggingTransaction, + cache_func: _CachedFunction, + keys: Tuple[Any, ...], + ) -> None: """Invalidates the cache and adds it to the cache stream so slaves will know to invalidate their caches. @@ -238,7 +251,9 @@ def _invalidate_cache_and_stream(self, txn, cache_func, keys): txn.call_after(cache_func.invalidate, keys) self._send_invalidation_to_replication(txn, cache_func.__name__, keys) - def _invalidate_all_cache_and_stream(self, txn, cache_func): + def _invalidate_all_cache_and_stream( + self, txn: LoggingTransaction, cache_func: _CachedFunction + ) -> None: """Invalidates the entire cache and adds it to the cache stream so slaves will know to invalidate their caches. """ @@ -279,8 +294,8 @@ def _invalidate_state_caches_and_stream( ) def _send_invalidation_to_replication( - self, txn, cache_name: str, keys: Optional[Iterable[Any]] - ): + self, txn: LoggingTransaction, cache_name: str, keys: Optional[Iterable[Any]] + ) -> None: """Notifies replication that given cache has been invalidated. Note that this does *not* invalidate the cache locally. @@ -315,7 +330,7 @@ def _send_invalidation_to_replication( "instance_name": self._instance_name, "cache_func": cache_name, "keys": keys, - "invalidation_ts": self.clock.time_msec(), + "invalidation_ts": self._clock.time_msec(), }, ) From fc9bd620ce94b64af46737e25a524336967782a1 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 16 Mar 2022 10:39:15 -0400 Subject: [PATCH 072/230] Add a relations handler to avoid duplication. (#12227) Adds a handler layer between the REST and datastore layers for relations. --- changelog.d/12227.misc | 1 + synapse/handlers/pagination.py | 5 +- synapse/handlers/relations.py | 117 +++++++++++++++++++++++++++++++ synapse/rest/client/relations.py | 75 +++----------------- synapse/server.py | 5 ++ 5 files changed, 134 insertions(+), 69 deletions(-) create mode 100644 changelog.d/12227.misc create mode 100644 synapse/handlers/relations.py diff --git a/changelog.d/12227.misc b/changelog.d/12227.misc new file mode 100644 index 000000000000..41c9dcbd37f6 --- /dev/null +++ b/changelog.d/12227.misc @@ -0,0 +1 @@ +Refactor the relations endpoints to add a `RelationsHandler`. diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 60059fec3e0f..41679f7f866b 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Any, Collection, Dict, List, Optional, Set +from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set import attr @@ -422,7 +422,7 @@ async def get_messages( pagin_config: PaginationConfig, as_client_event: bool = True, event_filter: Optional[Filter] = None, - ) -> Dict[str, Any]: + ) -> JsonDict: """Get messages in a room. Args: @@ -431,6 +431,7 @@ async def get_messages( pagin_config: The pagination config rules to apply, if any. as_client_event: True to get events in client-server format. event_filter: Filter to apply to results or None + Returns: Pagination API results """ diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py new file mode 100644 index 000000000000..8e475475ad02 --- /dev/null +++ b/synapse/handlers/relations.py @@ -0,0 +1,117 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from typing import TYPE_CHECKING, Optional + +from synapse.api.errors import SynapseError +from synapse.types import JsonDict, Requester, StreamToken + +if TYPE_CHECKING: + from synapse.server import HomeServer + + +logger = logging.getLogger(__name__) + + +class RelationsHandler: + def __init__(self, hs: "HomeServer"): + self._main_store = hs.get_datastores().main + self._auth = hs.get_auth() + self._clock = hs.get_clock() + self._event_handler = hs.get_event_handler() + self._event_serializer = hs.get_event_client_serializer() + + async def get_relations( + self, + requester: Requester, + event_id: str, + room_id: str, + relation_type: Optional[str] = None, + event_type: Optional[str] = None, + aggregation_key: Optional[str] = None, + limit: int = 5, + direction: str = "b", + from_token: Optional[StreamToken] = None, + to_token: Optional[StreamToken] = None, + ) -> JsonDict: + """Get related events of a event, ordered by topological ordering. + + TODO Accept a PaginationConfig instead of individual pagination parameters. + + Args: + requester: The user requesting the relations. + event_id: Fetch events that relate to this event ID. + room_id: The room the event belongs to. + relation_type: Only fetch events with this relation type, if given. + event_type: Only fetch events with this event type, if given. + aggregation_key: Only fetch events with this aggregation key, if given. + limit: Only fetch the most recent `limit` events. + direction: Whether to fetch the most recent first (`"b"`) or the + oldest first (`"f"`). + from_token: Fetch rows from the given token, or from the start if None. + to_token: Fetch rows up to the given token, or up to the end if None. + + Returns: + The pagination chunk. + """ + + user_id = requester.user.to_string() + + await self._auth.check_user_in_room_or_world_readable( + room_id, user_id, allow_departed_users=True + ) + + # This gets the original event and checks that a) the event exists and + # b) the user is allowed to view it. + event = await self._event_handler.get_event(requester.user, room_id, event_id) + if event is None: + raise SynapseError(404, "Unknown parent event.") + + pagination_chunk = await self._main_store.get_relations_for_event( + event_id=event_id, + event=event, + room_id=room_id, + relation_type=relation_type, + event_type=event_type, + aggregation_key=aggregation_key, + limit=limit, + direction=direction, + from_token=from_token, + to_token=to_token, + ) + + events = await self._main_store.get_events_as_list( + [c["event_id"] for c in pagination_chunk.chunk] + ) + + now = self._clock.time_msec() + # Do not bundle aggregations when retrieving the original event because + # we want the content before relations are applied to it. + original_event = self._event_serializer.serialize_event( + event, now, bundle_aggregations=None + ) + # The relations returned for the requested event do include their + # bundled aggregations. 
+ aggregations = await self._main_store.get_bundled_aggregations( + events, requester.user.to_string() + ) + serialized_events = self._event_serializer.serialize_events( + events, now, bundle_aggregations=aggregations + ) + + return_value = await pagination_chunk.to_dict(self._main_store) + return_value["chunk"] = serialized_events + return_value["original_event"] = original_event + + return return_value diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index d9a6be43f793..c16078b187ee 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -51,9 +51,7 @@ def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastores().main - self.clock = hs.get_clock() - self._event_serializer = hs.get_event_client_serializer() - self.event_handler = hs.get_event_handler() + self._relations_handler = hs.get_relations_handler() async def on_GET( self, @@ -65,16 +63,6 @@ async def on_GET( ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) - await self.auth.check_user_in_room_or_world_readable( - room_id, requester.user.to_string(), allow_departed_users=True - ) - - # This gets the original event and checks that a) the event exists and - # b) the user is allowed to view it. - event = await self.event_handler.get_event(requester.user, room_id, parent_id) - if event is None: - raise SynapseError(404, "Unknown parent event.") - limit = parse_integer(request, "limit", default=5) direction = parse_string( request, "org.matrix.msc3715.dir", default="b", allowed_values=["f", "b"] @@ -90,9 +78,9 @@ async def on_GET( if to_token_str: to_token = await StreamToken.from_string(self.store, to_token_str) - pagination_chunk = await self.store.get_relations_for_event( + result = await self._relations_handler.get_relations( + requester=requester, event_id=parent_id, - event=event, room_id=room_id, relation_type=relation_type, event_type=event_type, @@ -102,30 +90,7 @@ async def on_GET( to_token=to_token, ) - events = await self.store.get_events_as_list( - [c["event_id"] for c in pagination_chunk.chunk] - ) - - now = self.clock.time_msec() - # Do not bundle aggregations when retrieving the original event because - # we want the content before relations are applied to it. - original_event = self._event_serializer.serialize_event( - event, now, bundle_aggregations=None - ) - # The relations returned for the requested event do include their - # bundled aggregations. 
- aggregations = await self.store.get_bundled_aggregations( - events, requester.user.to_string() - ) - serialized_events = self._event_serializer.serialize_events( - events, now, bundle_aggregations=aggregations - ) - - return_value = await pagination_chunk.to_dict(self.store) - return_value["chunk"] = serialized_events - return_value["original_event"] = original_event - - return 200, return_value + return 200, result class RelationAggregationPaginationServlet(RestServlet): @@ -245,9 +210,7 @@ def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastores().main - self.clock = hs.get_clock() - self._event_serializer = hs.get_event_client_serializer() - self.event_handler = hs.get_event_handler() + self._relations_handler = hs.get_relations_handler() async def on_GET( self, @@ -260,18 +223,6 @@ async def on_GET( ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) - await self.auth.check_user_in_room_or_world_readable( - room_id, - requester.user.to_string(), - allow_departed_users=True, - ) - - # This checks that a) the event exists and b) the user is allowed to - # view it. - event = await self.event_handler.get_event(requester.user, room_id, parent_id) - if event is None: - raise SynapseError(404, "Unknown parent event.") - if relation_type != RelationTypes.ANNOTATION: raise SynapseError(400, "Relation type must be 'annotation'") @@ -286,9 +237,9 @@ async def on_GET( if to_token_str: to_token = await StreamToken.from_string(self.store, to_token_str) - result = await self.store.get_relations_for_event( + result = await self._relations_handler.get_relations( + requester=requester, event_id=parent_id, - event=event, room_id=room_id, relation_type=relation_type, event_type=event_type, @@ -298,17 +249,7 @@ async def on_GET( to_token=to_token, ) - events = await self.store.get_events_as_list( - [c["event_id"] for c in result.chunk] - ) - - now = self.clock.time_msec() - serialized_events = self._event_serializer.serialize_events(events, now) - - return_value = await result.to_dict(self.store) - return_value["chunk"] = serialized_events - - return 200, return_value + return 200, result def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: diff --git a/synapse/server.py b/synapse/server.py index 2fcf18a7a69a..380369db923e 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -94,6 +94,7 @@ from synapse.handlers.read_marker import ReadMarkerHandler from synapse.handlers.receipts import ReceiptsHandler from synapse.handlers.register import RegistrationHandler +from synapse.handlers.relations import RelationsHandler from synapse.handlers.room import ( RoomContextHandler, RoomCreationHandler, @@ -719,6 +720,10 @@ def get_message_handler(self) -> MessageHandler: def get_pagination_handler(self) -> PaginationHandler: return PaginationHandler(self) + @cache_in_self + def get_relations_handler(self) -> RelationsHandler: + return RelationsHandler(self) + @cache_in_self def get_room_context_handler(self) -> RoomContextHandler: return RoomContextHandler(self) From 61210567405b1ac7efaa23d5513cc0b443da0a3a Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Wed, 16 Mar 2022 15:07:41 +0000 Subject: [PATCH 073/230] Handle cancellation in `DatabasePool.runInteraction()` (#12199) To handle cancellation, we ensure that `after_callback`s and `exception_callback`s are always run, since the transaction will complete on another thread regardless of cancellation. 
We also wait until everything is done before releasing the `CancelledError`, so that logging contexts won't get used after they have been finished. Signed-off-by: Sean Quah --- changelog.d/12199.misc | 1 + synapse/storage/database.py | 61 +++++++++++++++++++++------------- tests/storage/test_database.py | 58 ++++++++++++++++++++++++++++++++ 3 files changed, 96 insertions(+), 24 deletions(-) create mode 100644 changelog.d/12199.misc diff --git a/changelog.d/12199.misc b/changelog.d/12199.misc new file mode 100644 index 000000000000..16dec1d26d24 --- /dev/null +++ b/changelog.d/12199.misc @@ -0,0 +1 @@ +Handle cancellation in `DatabasePool.runInteraction()`. diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 99802228c9f7..9749f0c06e86 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -41,6 +41,7 @@ from typing_extensions import Literal from twisted.enterprise import adbapi +from twisted.internet import defer from synapse.api.errors import StoreError from synapse.config.database import DatabaseConnectionConfig @@ -55,6 +56,7 @@ from synapse.storage.background_updates import BackgroundUpdater from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine from synapse.storage.types import Connection, Cursor +from synapse.util.async_helpers import delay_cancellation from synapse.util.iterutils import batch_iter if TYPE_CHECKING: @@ -732,34 +734,45 @@ async def runInteraction( Returns: The result of func """ - after_callbacks: List[_CallbackListEntry] = [] - exception_callbacks: List[_CallbackListEntry] = [] - if not current_context(): - logger.warning("Starting db txn '%s' from sentinel context", desc) + async def _runInteraction() -> R: + after_callbacks: List[_CallbackListEntry] = [] + exception_callbacks: List[_CallbackListEntry] = [] - try: - with opentracing.start_active_span(f"db.{desc}"): - result = await self.runWithConnection( - self.new_transaction, - desc, - after_callbacks, - exception_callbacks, - func, - *args, - db_autocommit=db_autocommit, - isolation_level=isolation_level, - **kwargs, - ) + if not current_context(): + logger.warning("Starting db txn '%s' from sentinel context", desc) - for after_callback, after_args, after_kwargs in after_callbacks: - after_callback(*after_args, **after_kwargs) - except Exception: - for after_callback, after_args, after_kwargs in exception_callbacks: - after_callback(*after_args, **after_kwargs) - raise + try: + with opentracing.start_active_span(f"db.{desc}"): + result = await self.runWithConnection( + self.new_transaction, + desc, + after_callbacks, + exception_callbacks, + func, + *args, + db_autocommit=db_autocommit, + isolation_level=isolation_level, + **kwargs, + ) - return cast(R, result) + for after_callback, after_args, after_kwargs in after_callbacks: + after_callback(*after_args, **after_kwargs) + + return cast(R, result) + except Exception: + for after_callback, after_args, after_kwargs in exception_callbacks: + after_callback(*after_args, **after_kwargs) + raise + + # To handle cancellation, we ensure that `after_callback`s and + # `exception_callback`s are always run, since the transaction will complete + # on another thread regardless of cancellation. + # + # We also wait until everything above is done before releasing the + # `CancelledError`, so that logging contexts won't get used after they have been + # finished. 
+ return await delay_cancellation(defer.ensureDeferred(_runInteraction())) async def runWithConnection( self, diff --git a/tests/storage/test_database.py b/tests/storage/test_database.py index ae13bed08621..a40fc20ef990 100644 --- a/tests/storage/test_database.py +++ b/tests/storage/test_database.py @@ -15,6 +15,8 @@ from typing import Callable, Tuple from unittest.mock import Mock, call +from twisted.internet import defer +from twisted.internet.defer import CancelledError, Deferred from twisted.test.proto_helpers import MemoryReactor from synapse.server import HomeServer @@ -124,3 +126,59 @@ def test_successful_retry(self) -> None: ) self.assertEqual(after_callback.call_count, 2) # no additional calls exception_callback.assert_not_called() + + +class CancellationTestCase(unittest.HomeserverTestCase): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastores().main + self.db_pool: DatabasePool = self.store.db_pool + + def test_after_callback(self) -> None: + """Test that the after callback is called when a transaction succeeds.""" + d: "Deferred[None]" + after_callback = Mock() + exception_callback = Mock() + + def _test_txn(txn: LoggingTransaction) -> None: + txn.call_after(after_callback, 123, 456, extra=789) + txn.call_on_exception(exception_callback, 987, 654, extra=321) + d.cancel() + + d = defer.ensureDeferred( + self.db_pool.runInteraction("test_transaction", _test_txn) + ) + self.get_failure(d, CancelledError) + + after_callback.assert_called_once_with(123, 456, extra=789) + exception_callback.assert_not_called() + + def test_exception_callback(self) -> None: + """Test that the exception callback is called when a transaction fails.""" + d: "Deferred[None]" + after_callback = Mock() + exception_callback = Mock() + + def _test_txn(txn: LoggingTransaction) -> None: + txn.call_after(after_callback, 123, 456, extra=789) + txn.call_on_exception(exception_callback, 987, 654, extra=321) + d.cancel() + # Simulate a retryable failure on every attempt. + raise self.db_pool.engine.module.OperationalError() + + d = defer.ensureDeferred( + self.db_pool.runInteraction("test_transaction", _test_txn) + ) + self.get_failure(d, CancelledError) + + after_callback.assert_not_called() + exception_callback.assert_has_calls( + [ + call(987, 654, extra=321), + call(987, 654, extra=321), + call(987, 654, extra=321), + call(987, 654, extra=321), + call(987, 654, extra=321), + call(987, 654, extra=321), + ] + ) + self.assertEqual(exception_callback.call_count, 6) # no additional calls From 96274565ff0dbb7d21b02b04fcef115330426707 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 16 Mar 2022 12:17:39 -0400 Subject: [PATCH 074/230] Fix bundling aggregations if unsigned is not a returned event field. (#12234) An error occured if a filter was supplied with `event_fields` which did not include `unsigned`. In that case, bundled aggregations are still added as the spec states it is allowed for servers to add additional fields. 
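
A minimal sketch of the failure mode (illustrative only; plain dicts stand in for the serialized
event and its bundled aggregations, and the values are made up rather than taken from Synapse):
once an `event_fields` filter has stripped `unsigned`, indexing into it raises `KeyError`, whereas
the `setdefault` chain used in the fix below recreates the field before merging.

    # Illustrative sketch, not Synapse code: a serialized event whose "unsigned"
    # field was removed by an event_fields filter.
    serialized_event = {"event_id": "$parent", "content": {"body": "Hi!"}}
    serialized_aggregations = {
        "m.annotation": {"chunk": [{"type": "m.reaction", "key": "👍", "count": 1}]}
    }

    # Assuming "unsigned" is always present fails here with KeyError:
    #     serialized_event["unsigned"].setdefault("m.relations", {}).update(...)

    # Recreating "unsigned" (and "m.relations") before merging avoids that, and the
    # spec allows servers to return fields beyond what the filter requested.
    serialized_event.setdefault("unsigned", {}).setdefault("m.relations", {}).update(
        serialized_aggregations
    )
    assert "m.annotation" in serialized_event["unsigned"]["m.relations"]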
--- changelog.d/12234.bugfix | 1 + synapse/events/utils.py | 9 ++++++--- tests/rest/client/test_relations.py | 28 ++++++++++++++++++++++++++++ 3 files changed, 35 insertions(+), 3 deletions(-) create mode 100644 changelog.d/12234.bugfix diff --git a/changelog.d/12234.bugfix b/changelog.d/12234.bugfix new file mode 100644 index 000000000000..dbb77f36ff32 --- /dev/null +++ b/changelog.d/12234.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug when a `filter` argument with `event_fields` supplied but not including the `unsigned` field could result in a 500 error on `/sync`. diff --git a/synapse/events/utils.py b/synapse/events/utils.py index b2a237c1e04a..a0520068e0f9 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -530,9 +530,12 @@ def _inject_bundled_aggregations( # Include the bundled aggregations in the event. if serialized_aggregations: - serialized_event["unsigned"].setdefault("m.relations", {}).update( - serialized_aggregations - ) + # There is likely already an "unsigned" field, but a filter might + # have stripped it off (via the event_fields option). The server is + # allowed to return additional fields, so add it back. + serialized_event.setdefault("unsigned", {}).setdefault( + "m.relations", {} + ).update(serialized_aggregations) def serialize_events( self, diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index 0cbe6c0cf754..171f4e97c8fd 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -1267,6 +1267,34 @@ def test_background_update(self) -> None: [annotation_event_id_good, thread_event_id], ) + def test_bundled_aggregations_with_filter(self) -> None: + """ + If "unsigned" is an omitted field (due to filtering), adding the bundled + aggregations should not break. + + Note that the spec allows for a server to return additional fields beyond + what is specified. + """ + self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + + # Note that the sync filter does not include "unsigned" as a field. + filter = urllib.parse.quote_plus( + b'{"event_fields": ["content", "event_id"], "room": {"timeline": {"limit": 3}}}' + ) + channel = self.make_request( + "GET", f"/sync?filter={filter}", access_token=self.user_token + ) + self.assertEqual(200, channel.code, channel.json_body) + + # Ensure the timeline is limited, find the parent event. + room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] + self.assertTrue(room_timeline["limited"]) + parent_event = self._find_event_in_chunk(room_timeline["events"]) + + # Ensure there's bundled aggregations on it. + self.assertIn("unsigned", parent_event) + self.assertIn("m.relations", parent_event["unsigned"]) + class RelationRedactionTestCase(BaseRelationsTestCase): """ From f70afbd565f34cdc093e083b92376a1154b007a7 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 16 Mar 2022 12:20:05 -0400 Subject: [PATCH 075/230] Re-generate changelog. --- CHANGES.md | 1 + changelog.d/12234.bugfix | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) delete mode 100644 changelog.d/12234.bugfix diff --git a/CHANGES.md b/CHANGES.md index 60e7ecb1b9c1..78498c10b551 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -21,6 +21,7 @@ Bugfixes - Fix a bug introduced in Synapse 1.7.2 whereby background updates are never run with the default background batch size. ([\#12157](https://github.com/matrix-org/synapse/issues/12157)) - Fix a bug where non-standard information was returned from the `/hierarchy` API. Introduced in Synapse v1.41.0. 
([\#12175](https://github.com/matrix-org/synapse/issues/12175)) - Fix a bug introduced in Synapse 1.54.0 that broke background updates on sqlite homeservers while search was disabled. ([\#12215](https://github.com/matrix-org/synapse/issues/12215)) +- Fix a long-standing bug when a `filter` argument with `event_fields` which did not include the `unsigned` field could result in a 500 error on `/sync`. ([\#12234](https://github.com/matrix-org/synapse/issues/12234)) Improved Documentation diff --git a/changelog.d/12234.bugfix b/changelog.d/12234.bugfix deleted file mode 100644 index dbb77f36ff32..000000000000 --- a/changelog.d/12234.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a long-standing bug when a `filter` argument with `event_fields` supplied but not including the `unsigned` field could result in a 500 error on `/sync`. From 9e06e220649cc0139749c388a894bee0d65d5f4e Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Thu, 17 Mar 2022 12:25:50 +0100 Subject: [PATCH 076/230] Add type hints to more tests files. (#12240) --- changelog.d/12240.misc | 1 + mypy.ini | 4 ---- tests/handlers/test_cas.py | 19 +++++++++------ tests/handlers/test_federation.py | 36 ++++++++++++++++------------ tests/handlers/test_presence.py | 13 ++++++---- tests/push/test_http.py | 40 ++++++++++++++++++------------- 6 files changed, 66 insertions(+), 47 deletions(-) create mode 100644 changelog.d/12240.misc diff --git a/changelog.d/12240.misc b/changelog.d/12240.misc new file mode 100644 index 000000000000..c5b635679931 --- /dev/null +++ b/changelog.d/12240.misc @@ -0,0 +1 @@ +Add type hints to tests files. \ No newline at end of file diff --git a/mypy.ini b/mypy.ini index fe31bfb8bb37..51f47ff5be8f 100644 --- a/mypy.ini +++ b/mypy.ini @@ -66,9 +66,6 @@ exclude = (?x) |tests/federation/test_federation_server.py |tests/federation/transport/test_knocking.py |tests/federation/transport/test_server.py - |tests/handlers/test_cas.py - |tests/handlers/test_federation.py - |tests/handlers/test_presence.py |tests/handlers/test_typing.py |tests/http/federation/test_matrix_federation_agent.py |tests/http/federation/test_srv_resolver.py @@ -80,7 +77,6 @@ exclude = (?x) |tests/logging/test_terse_json.py |tests/module_api/test_api.py |tests/push/test_email.py - |tests/push/test_http.py |tests/push/test_presentable_names.py |tests/push/test_push_rule_evaluator.py |tests/rest/client/test_transactions.py diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py index a2672288465a..a54aa29cf177 100644 --- a/tests/handlers/test_cas.py +++ b/tests/handlers/test_cas.py @@ -11,9 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from typing import Any, Dict from unittest.mock import Mock +from twisted.test.proto_helpers import MemoryReactor + from synapse.handlers.cas import CasResponse +from synapse.server import HomeServer +from synapse.util import Clock from tests.test_utils import simple_async_mock from tests.unittest import HomeserverTestCase, override_config @@ -24,7 +29,7 @@ class CasHandlerTestCase(HomeserverTestCase): - def default_config(self): + def default_config(self) -> Dict[str, Any]: config = super().default_config() config["public_baseurl"] = BASE_URL cas_config = { @@ -40,7 +45,7 @@ def default_config(self): return config - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: hs = self.setup_test_homeserver() self.handler = hs.get_cas_handler() @@ -51,7 +56,7 @@ def make_homeserver(self, reactor, clock): return hs - def test_map_cas_user_to_user(self): + def test_map_cas_user_to_user(self) -> None: """Ensure that mapping the CAS user returned from a provider to an MXID works properly.""" # stub out the auth handler @@ -75,7 +80,7 @@ def test_map_cas_user_to_user(self): auth_provider_session_id=None, ) - def test_map_cas_user_to_existing_user(self): + def test_map_cas_user_to_existing_user(self) -> None: """Existing users can log in with CAS account.""" store = self.hs.get_datastores().main self.get_success( @@ -119,7 +124,7 @@ def test_map_cas_user_to_existing_user(self): auth_provider_session_id=None, ) - def test_map_cas_user_to_invalid_localpart(self): + def test_map_cas_user_to_invalid_localpart(self) -> None: """CAS automaps invalid characters to base-64 encoding.""" # stub out the auth handler @@ -150,7 +155,7 @@ def test_map_cas_user_to_invalid_localpart(self): } } ) - def test_required_attributes(self): + def test_required_attributes(self) -> None: """The required attributes must be met from the CAS response.""" # stub out the auth handler @@ -166,7 +171,7 @@ def test_required_attributes(self): auth_handler.complete_sso_login.assert_not_called() # The response doesn't have any department. - cas_response = CasResponse("test_user", {"userGroup": "staff"}) + cas_response = CasResponse("test_user", {"userGroup": ["staff"]}) request.reset_mock() self.get_success( self.handler._handle_cas_response(request, cas_response, "redirect_uri", "") diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index e8b4e39d1a32..89078fc6374c 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -12,9 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import logging -from typing import List +from typing import List, cast from unittest import TestCase +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.constants import EventTypes from synapse.api.errors import AuthError, Codes, LimitExceededError, SynapseError from synapse.api.room_versions import RoomVersions @@ -23,7 +25,9 @@ from synapse.logging.context import LoggingContext, run_in_background from synapse.rest import admin from synapse.rest.client import login, room +from synapse.server import HomeServer from synapse.types import create_requester +from synapse.util import Clock from synapse.util.stringutils import random_string from tests import unittest @@ -42,7 +46,7 @@ class FederationTestCase(unittest.HomeserverTestCase): room.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: hs = self.setup_test_homeserver(federation_http_client=None) self.handler = hs.get_federation_handler() self.store = hs.get_datastores().main @@ -50,7 +54,7 @@ def make_homeserver(self, reactor, clock): self._event_auth_handler = hs.get_event_auth_handler() return hs - def test_exchange_revoked_invite(self): + def test_exchange_revoked_invite(self) -> None: user_id = self.register_user("kermit", "test") tok = self.login("kermit", "test") @@ -96,7 +100,7 @@ def test_exchange_revoked_invite(self): self.assertEqual(failure.errcode, Codes.FORBIDDEN, failure) self.assertEqual(failure.msg, "You are not invited to this room.") - def test_rejected_message_event_state(self): + def test_rejected_message_event_state(self) -> None: """ Check that we store the state group correctly for rejected non-state events. @@ -126,7 +130,7 @@ def test_rejected_message_event_state(self): "content": {}, "room_id": room_id, "sender": "@yetanotheruser:" + OTHER_SERVER, - "depth": join_event["depth"] + 1, + "depth": cast(int, join_event["depth"]) + 1, "prev_events": [join_event.event_id], "auth_events": [], "origin_server_ts": self.clock.time_msec(), @@ -149,7 +153,7 @@ def test_rejected_message_event_state(self): self.assertEqual(sg, sg2) - def test_rejected_state_event_state(self): + def test_rejected_state_event_state(self) -> None: """ Check that we store the state group correctly for rejected state events. @@ -180,7 +184,7 @@ def test_rejected_state_event_state(self): "content": {}, "room_id": room_id, "sender": "@yetanotheruser:" + OTHER_SERVER, - "depth": join_event["depth"] + 1, + "depth": cast(int, join_event["depth"]) + 1, "prev_events": [join_event.event_id], "auth_events": [], "origin_server_ts": self.clock.time_msec(), @@ -203,7 +207,7 @@ def test_rejected_state_event_state(self): self.assertEqual(sg, sg2) - def test_backfill_with_many_backward_extremities(self): + def test_backfill_with_many_backward_extremities(self) -> None: """ Check that we can backfill with many backward extremities. 
The goal is to make sure that when we only use a portion @@ -262,7 +266,7 @@ def test_backfill_with_many_backward_extremities(self): ) self.get_success(d) - def test_backfill_floating_outlier_membership_auth(self): + def test_backfill_floating_outlier_membership_auth(self) -> None: """ As the local homeserver, check that we can properly process a federated event from the OTHER_SERVER with auth_events that include a floating @@ -377,7 +381,7 @@ async def get_event_auth( for ae in auth_events ] - self.handler.federation_client.get_event_auth = get_event_auth + self.handler.federation_client.get_event_auth = get_event_auth # type: ignore[assignment] with LoggingContext("receive_pdu"): # Fake the OTHER_SERVER federating the message event over to our local homeserver @@ -397,7 +401,7 @@ async def get_event_auth( @unittest.override_config( {"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}} ) - def test_invite_by_user_ratelimit(self): + def test_invite_by_user_ratelimit(self) -> None: """Tests that invites from federation to a particular user are actually rate-limited. """ @@ -446,7 +450,9 @@ def create_invite(): exc=LimitExceededError, ) - def _build_and_send_join_event(self, other_server, other_user, room_id): + def _build_and_send_join_event( + self, other_server: str, other_user: str, room_id: str + ) -> EventBase: join_event = self.get_success( self.handler.on_make_join_request(other_server, room_id, other_user) ) @@ -469,7 +475,7 @@ def _build_and_send_join_event(self, other_server, other_user, room_id): class EventFromPduTestCase(TestCase): - def test_valid_json(self): + def test_valid_json(self) -> None: """Valid JSON should be turned into an event.""" ev = event_from_pdu_json( { @@ -487,7 +493,7 @@ def test_valid_json(self): self.assertIsInstance(ev, EventBase) - def test_invalid_numbers(self): + def test_invalid_numbers(self) -> None: """Invalid values for an integer should be rejected, all floats should be rejected.""" for value in [ -(2 ** 53), @@ -512,7 +518,7 @@ def test_invalid_numbers(self): RoomVersions.V6, ) - def test_invalid_nested(self): + def test_invalid_nested(self) -> None: """List and dictionaries are recursively searched.""" with self.assertRaises(SynapseError): event_from_pdu_json( diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index 6ddec9ecf1fe..b2ed9cbe3775 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -331,11 +331,11 @@ def test_persisting_presence_updates(self): # Extract presence update user ID and state information into lists of tuples db_presence_states = [(ps[0], ps[1]) for _, ps in db_presence_states[0]] - presence_states = [(ps.user_id, ps.state) for ps in presence_states] + presence_states_compare = [(ps.user_id, ps.state) for ps in presence_states] # Compare what we put into the storage with what we got out. # They should be identical. 
- self.assertEqual(presence_states, db_presence_states) + self.assertEqual(presence_states_compare, db_presence_states) class PresenceTimeoutTestCase(unittest.TestCase): @@ -357,6 +357,7 @@ def test_idle_timer(self): new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now) self.assertIsNotNone(new_state) + assert new_state is not None self.assertEqual(new_state.state, PresenceState.UNAVAILABLE) self.assertEqual(new_state.status_msg, status_msg) @@ -380,6 +381,7 @@ def test_busy_no_idle(self): new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now) self.assertIsNotNone(new_state) + assert new_state is not None self.assertEqual(new_state.state, PresenceState.BUSY) self.assertEqual(new_state.status_msg, status_msg) @@ -399,6 +401,7 @@ def test_sync_timeout(self): new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now) self.assertIsNotNone(new_state) + assert new_state is not None self.assertEqual(new_state.state, PresenceState.OFFLINE) self.assertEqual(new_state.status_msg, status_msg) @@ -420,6 +423,7 @@ def test_sync_online(self): ) self.assertIsNotNone(new_state) + assert new_state is not None self.assertEqual(new_state.state, PresenceState.ONLINE) self.assertEqual(new_state.status_msg, status_msg) @@ -477,6 +481,7 @@ def test_federation_timeout(self): ) self.assertIsNotNone(new_state) + assert new_state is not None self.assertEqual(new_state.state, PresenceState.OFFLINE) self.assertEqual(new_state.status_msg, status_msg) @@ -653,13 +658,13 @@ def test_set_presence_with_status_msg_none(self): self._set_presencestate_with_status_msg(user_id, PresenceState.ONLINE, None) def _set_presencestate_with_status_msg( - self, user_id: str, state: PresenceState, status_msg: Optional[str] + self, user_id: str, state: str, status_msg: Optional[str] ): """Set a PresenceState and status_msg and check the result. Args: user_id: User for that the status is to be set. - PresenceState: The new PresenceState. + state: The new PresenceState. status_msg: Status message that is to be set. """ self.get_success( diff --git a/tests/push/test_http.py b/tests/push/test_http.py index 6691e0712896..ba158f5d93e4 100644 --- a/tests/push/test_http.py +++ b/tests/push/test_http.py @@ -11,15 +11,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import List, Tuple +from typing import Any, Dict, List, Optional, Tuple from unittest.mock import Mock from twisted.internet.defer import Deferred +from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin from synapse.logging.context import make_deferred_yieldable from synapse.push import PusherConfigException from synapse.rest.client import login, push_rule, receipts, room +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util import Clock from tests.unittest import HomeserverTestCase, override_config @@ -35,13 +39,13 @@ class HTTPPusherTests(HomeserverTestCase): user_id = True hijack_auth = False - def default_config(self): + def default_config(self) -> Dict[str, Any]: config = super().default_config() config["start_pushers"] = True return config - def make_homeserver(self, reactor, clock): - self.push_attempts: List[tuple[Deferred, str, dict]] = [] + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + self.push_attempts: List[Tuple[Deferred, str, dict]] = [] m = Mock() @@ -56,7 +60,7 @@ def post_json_get_json(url, body): return hs - def test_invalid_configuration(self): + def test_invalid_configuration(self) -> None: """Invalid push configurations should be rejected.""" # Register the user who gets notified user_id = self.register_user("user", "pass") @@ -68,7 +72,7 @@ def test_invalid_configuration(self): ) token_id = user_tuple.token_id - def test_data(data): + def test_data(data: Optional[JsonDict]) -> None: self.get_failure( self.hs.get_pusherpool().add_pusher( user_id=user_id, @@ -95,7 +99,7 @@ def test_data(data): # A url with an incorrect path isn't accepted. test_data({"url": "http://example.com/foo"}) - def test_sends_http(self): + def test_sends_http(self) -> None: """ The HTTP pusher will send pushes for each message to a HTTP endpoint when configured to do so. @@ -200,7 +204,7 @@ def test_sends_http(self): self.assertEqual(len(pushers), 1) self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering) - def test_sends_high_priority_for_encrypted(self): + def test_sends_high_priority_for_encrypted(self) -> None: """ The HTTP pusher will send pushes at high priority if they correspond to an encrypted message. @@ -321,7 +325,7 @@ def test_sends_high_priority_for_encrypted(self): ) self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "high") - def test_sends_high_priority_for_one_to_one_only(self): + def test_sends_high_priority_for_one_to_one_only(self) -> None: """ The HTTP pusher will send pushes at high priority if they correspond to a message in a one-to-one room. @@ -404,7 +408,7 @@ def test_sends_high_priority_for_one_to_one_only(self): # check that this is low-priority self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low") - def test_sends_high_priority_for_mention(self): + def test_sends_high_priority_for_mention(self) -> None: """ The HTTP pusher will send pushes at high priority if they correspond to a message containing the user's display name. @@ -480,7 +484,7 @@ def test_sends_high_priority_for_mention(self): # check that this is low-priority self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low") - def test_sends_high_priority_for_atroom(self): + def test_sends_high_priority_for_atroom(self) -> None: """ The HTTP pusher will send pushes at high priority if they correspond to a message that contains @room. 
@@ -563,7 +567,7 @@ def test_sends_high_priority_for_atroom(self): # check that this is low-priority self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low") - def test_push_unread_count_group_by_room(self): + def test_push_unread_count_group_by_room(self) -> None: """ The HTTP pusher will group unread count by number of unread rooms. """ @@ -576,7 +580,7 @@ def test_push_unread_count_group_by_room(self): self._check_push_attempt(6, 1) @override_config({"push": {"group_unread_count_by_room": False}}) - def test_push_unread_count_message_count(self): + def test_push_unread_count_message_count(self) -> None: """ The HTTP pusher will send the total unread message count. """ @@ -589,7 +593,7 @@ def test_push_unread_count_message_count(self): # last read receipt self._check_push_attempt(6, 3) - def _test_push_unread_count(self): + def _test_push_unread_count(self) -> None: """ Tests that the correct unread count appears in sent push notifications @@ -681,7 +685,7 @@ def _test_push_unread_count(self): self.helper.send(room_id, body="HELLO???", tok=other_access_token) - def _advance_time_and_make_push_succeed(self, expected_push_attempts): + def _advance_time_and_make_push_succeed(self, expected_push_attempts: int) -> None: self.pump() self.push_attempts[expected_push_attempts - 1][0].callback({}) @@ -708,7 +712,9 @@ def _check_push_attempt( expected_unread_count_last_push, ) - def _send_read_request(self, access_token, message_event_id, room_id): + def _send_read_request( + self, access_token: str, message_event_id: str, room_id: str + ) -> None: # Now set the user's read receipt position to the first event # # This will actually trigger a new notification to be sent out so that @@ -748,7 +754,7 @@ def _make_user_with_pusher(self, username: str) -> Tuple[str, str]: return user_id, access_token - def test_dont_notify_rule_overrides_message(self): + def test_dont_notify_rule_overrides_message(self) -> None: """ The override push rule will suppress notification """ From 12d1f82db213603972d60be3f46f6a36c3c2330f Mon Sep 17 00:00:00 2001 From: David Robertson Date: Thu, 17 Mar 2022 13:46:05 +0000 Subject: [PATCH 077/230] Generate announcement links in release script (#12242) --- changelog.d/12242.misc | 1 + scripts-dev/release.py | 41 ++++++++++++++++++++++++++++++++++++++++- 2 files changed, 41 insertions(+), 1 deletion(-) create mode 100644 changelog.d/12242.misc diff --git a/changelog.d/12242.misc b/changelog.d/12242.misc new file mode 100644 index 000000000000..38e7e0f7d1e8 --- /dev/null +++ b/changelog.d/12242.misc @@ -0,0 +1 @@ +Generate announcement links in the release script. diff --git a/scripts-dev/release.py b/scripts-dev/release.py index 046453e65ff8..685fa32b03f4 100755 --- a/scripts-dev/release.py +++ b/scripts-dev/release.py @@ -66,11 +66,15 @@ def cli(): ./scripts-dev/release.py tag - # ... wait for asssets to build ... + # ... wait for assets to build ... ./scripts-dev/release.py publish ./scripts-dev/release.py upload + # Optional: generate some nice links for the announcement + + ./scripts-dev/release.py upload + If the env var GH_TOKEN (or GITHUB_TOKEN) is set, or passed into the `tag`/`publish` command, then a new draft release will be created/published. """ @@ -415,6 +419,41 @@ def upload(): ) +@cli.command() +def announce(): + """Generate markdown to announce the release.""" + + current_version, _, _ = parse_version_from_module() + tag_name = f"v{current_version}" + + click.echo( + f""" +Hi everyone. Synapse {current_version} has just been released. 
+ +[notes](https://github.com/matrix-org/synapse/releases/tag/{tag_name}) |\ +[docker](https://hub.docker.com/r/matrixdotorg/synapse/tags?name={tag_name}) | \ +[debs](https://packages.matrix.org/debian/) | \ +[pypi](https://pypi.org/project/matrix-synapse/{current_version}/)""" + ) + + if "rc" in tag_name: + click.echo( + """ +Announce the RC in +- #homeowners:matrix.org (Synapse Announcements) +- #synapse-dev:matrix.org""" + ) + else: + click.echo( + """ +Announce the release in +- #homeowners:matrix.org (Synapse Announcements), bumping the version in the topic +- #synapse:matrix.org (Synapse Admins), bumping the version in the topic +- #synapse-dev:matrix.org +- #synapse-package-maintainers:matrix.org""" + ) + + def parse_version_from_module() -> Tuple[ version.Version, redbaron.RedBaron, redbaron.Node ]: From 872dbb0181714e201be082c4e8bd9b727c73f177 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Fri, 18 Mar 2022 13:51:41 +0000 Subject: [PATCH 078/230] Correct `check_username_for_spam` annotations and docs (#12246) * Formally type the UserProfile in user searches * export UserProfile in synapse.module_api * Update docs Co-authored-by: Sean Quah <8349537+squahtx@users.noreply.github.com> --- changelog.d/12246.doc | 1 + docs/modules/spam_checker_callbacks.md | 10 ++++---- synapse/events/spamcheck.py | 7 +++--- synapse/handlers/user_directory.py | 4 ++-- synapse/module_api/__init__.py | 2 ++ synapse/rest/client/user_directory.py | 4 ++-- .../storage/databases/main/user_directory.py | 23 +++++++++++++++---- synapse/types.py | 11 +++++++++ 8 files changed, 46 insertions(+), 16 deletions(-) create mode 100644 changelog.d/12246.doc diff --git a/changelog.d/12246.doc b/changelog.d/12246.doc new file mode 100644 index 000000000000..e7fcc1b99c78 --- /dev/null +++ b/changelog.d/12246.doc @@ -0,0 +1 @@ +Correct `check_username_for_spam` annotations and docs. \ No newline at end of file diff --git a/docs/modules/spam_checker_callbacks.md b/docs/modules/spam_checker_callbacks.md index 2b672b78f9ae..472d95718087 100644 --- a/docs/modules/spam_checker_callbacks.md +++ b/docs/modules/spam_checker_callbacks.md @@ -172,7 +172,7 @@ any of the subsequent implementations of this callback. _First introduced in Synapse v1.37.0_ ```python -async def check_username_for_spam(user_profile: Dict[str, str]) -> bool +async def check_username_for_spam(user_profile: synapse.module_api.UserProfile) -> bool ``` Called when computing search results in the user directory. The module must return a @@ -182,9 +182,11 @@ search results; otherwise return `False`. The profile is represented as a dictionary with the following keys: -* `user_id`: The Matrix ID for this user. -* `display_name`: The user's display name. -* `avatar_url`: The `mxc://` URL to the user's avatar. +* `user_id: str`. The Matrix ID for this user. +* `display_name: Optional[str]`. The user's display name, or `None` if this user + has not set a display name. +* `avatar_url: Optional[str]`. The `mxc://` URL to the user's avatar, or `None` + if this user has not set an avatar. The module is given a copy of the original dictionary, so modifying it from within the module cannot modify a user's profile when included in user directory search results. 
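As a worked illustration of the documented signature, a minimal module sketch follows. It assumes the `ModuleApi.register_spam_checker_callbacks` registration hook and the `UserProfile` type re-exported from `synapse.module_api`, and simply hides users whose display name contains a banned word; the class and module names are hypothetical.

```python
from synapse.module_api import ModuleApi, UserProfile


class ExampleSpamCheckerModule:
    def __init__(self, config: dict, api: ModuleApi):
        self._api = api
        # Register the check_username_for_spam callback described above.
        self._api.register_spam_checker_callbacks(
            check_username_for_spam=self.check_username_for_spam,
        )

    async def check_username_for_spam(self, user_profile: UserProfile) -> bool:
        # Exclude users whose display name mentions "spam" from directory results.
        display_name = user_profile["display_name"]
        return display_name is not None and "spam" in display_name.lower()
```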
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py index 60904a55f5c0..cd80fcf9d13a 100644 --- a/synapse/events/spamcheck.py +++ b/synapse/events/spamcheck.py @@ -21,7 +21,6 @@ Awaitable, Callable, Collection, - Dict, List, Optional, Tuple, @@ -31,7 +30,7 @@ from synapse.rest.media.v1._base import FileInfo from synapse.rest.media.v1.media_storage import ReadableFileWrapper from synapse.spam_checker_api import RegistrationBehaviour -from synapse.types import RoomAlias +from synapse.types import RoomAlias, UserProfile from synapse.util.async_helpers import maybe_awaitable if TYPE_CHECKING: @@ -50,7 +49,7 @@ USER_MAY_CREATE_ROOM_CALLBACK = Callable[[str], Awaitable[bool]] USER_MAY_CREATE_ROOM_ALIAS_CALLBACK = Callable[[str, RoomAlias], Awaitable[bool]] USER_MAY_PUBLISH_ROOM_CALLBACK = Callable[[str, str], Awaitable[bool]] -CHECK_USERNAME_FOR_SPAM_CALLBACK = Callable[[Dict[str, str]], Awaitable[bool]] +CHECK_USERNAME_FOR_SPAM_CALLBACK = Callable[[UserProfile], Awaitable[bool]] LEGACY_CHECK_REGISTRATION_FOR_SPAM_CALLBACK = Callable[ [ Optional[dict], @@ -383,7 +382,7 @@ async def user_may_publish_room(self, userid: str, room_id: str) -> bool: return True - async def check_username_for_spam(self, user_profile: Dict[str, str]) -> bool: + async def check_username_for_spam(self, user_profile: UserProfile) -> bool: """Checks if a user ID or display name are considered "spammy" by this server. If the server considers a username spammy, then it will not be included in diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index d27ed2be6a99..048fd4bb8225 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -19,8 +19,8 @@ from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules, Membership from synapse.handlers.state_deltas import MatchChange, StateDeltasHandler from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage.databases.main.user_directory import SearchResult from synapse.storage.roommember import ProfileInfo -from synapse.types import JsonDict from synapse.util.metrics import Measure if TYPE_CHECKING: @@ -78,7 +78,7 @@ def __init__(self, hs: "HomeServer"): async def search_users( self, user_id: str, search_term: str, limit: int - ) -> JsonDict: + ) -> SearchResult: """Searches for users in directory Returns: diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index d735c1d4616e..aa8256b36f1d 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -111,6 +111,7 @@ StateMap, UserID, UserInfo, + UserProfile, create_requester, ) from synapse.util import Clock @@ -150,6 +151,7 @@ "EventBase", "StateMap", "ProfileInfo", + "UserProfile", ] logger = logging.getLogger(__name__) diff --git a/synapse/rest/client/user_directory.py b/synapse/rest/client/user_directory.py index a47d9bd01da5..116c982ce637 100644 --- a/synapse/rest/client/user_directory.py +++ b/synapse/rest/client/user_directory.py @@ -19,7 +19,7 @@ from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.site import SynapseRequest -from synapse.types import JsonDict +from synapse.types import JsonMapping from ._base import client_patterns @@ -38,7 +38,7 @@ def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() self.user_directory_handler = hs.get_user_directory_handler() - async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + 
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonMapping]: """Searches for users in directory Returns: diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index e7fddd24262a..55cc9178f0b8 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -26,6 +26,8 @@ cast, ) +from typing_extensions import TypedDict + from synapse.api.errors import StoreError if TYPE_CHECKING: @@ -40,7 +42,12 @@ from synapse.storage.databases.main.state import StateFilter from synapse.storage.databases.main.state_deltas import StateDeltasStore from synapse.storage.engines import PostgresEngine, Sqlite3Engine -from synapse.types import JsonDict, get_domain_from_id, get_localpart_from_id +from synapse.types import ( + JsonDict, + UserProfile, + get_domain_from_id, + get_localpart_from_id, +) from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) @@ -591,6 +598,11 @@ async def update_user_directory_stream_pos(self, stream_id: Optional[int]) -> No ) +class SearchResult(TypedDict): + limited: bool + results: List[UserProfile] + + class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): # How many records do we calculate before sending it to # add_users_who_share_private_rooms? @@ -777,7 +789,7 @@ async def get_user_directory_stream_pos(self) -> Optional[int]: async def search_user_dir( self, user_id: str, search_term: str, limit: int - ) -> JsonDict: + ) -> SearchResult: """Searches for users in directory Returns: @@ -910,8 +922,11 @@ async def search_user_dir( # This should be unreachable. raise Exception("Unrecognized database engine") - results = await self.db_pool.execute( - "search_user_dir", self.db_pool.cursor_to_dict, sql, *args + results = cast( + List[UserProfile], + await self.db_pool.execute( + "search_user_dir", self.db_pool.cursor_to_dict, sql, *args + ), ) limited = len(results) > limit diff --git a/synapse/types.py b/synapse/types.py index 53be3583a013..5ce2a5b0a5ee 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -34,6 +34,7 @@ import attr from frozendict import frozendict from signedjson.key import decode_verify_key_bytes +from typing_extensions import TypedDict from unpaddedbase64 import decode_base64 from zope.interface import Interface @@ -63,6 +64,10 @@ # JSON types. These could be made stronger, but will do for now. # A JSON-serialisable dict. JsonDict = Dict[str, Any] +# A JSON-serialisable mapping; roughly speaking an immutable JSONDict. +# Useful when you have a TypedDict which isn't going to be mutated and you don't want +# to cast to JsonDict everywhere. +JsonMapping = Mapping[str, Any] # A JSON-serialisable object. 
JsonSerializable = object @@ -791,3 +796,9 @@ class UserInfo: is_deactivated: bool is_guest: bool is_shadow_banned: bool + + +class UserProfile(TypedDict): + user_id: str + display_name: Optional[str] + avatar_url: Optional[str] From c46065fa3d6ad000f5da6e196c769371e0e76ec5 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Fri, 18 Mar 2022 16:24:18 +0100 Subject: [PATCH 079/230] Add some type hints to datastore (#12248) * inherit `MonthlyActiveUsersStore` from `RegistrationWorkerStore` Co-authored-by: Patrick Cloke --- changelog.d/12248.misc | 1 + mypy.ini | 6 - .../storage/databases/main/group_server.py | 156 +++++++++++------- .../databases/main/monthly_active_users.py | 38 +++-- 4 files changed, 117 insertions(+), 84 deletions(-) create mode 100644 changelog.d/12248.misc diff --git a/changelog.d/12248.misc b/changelog.d/12248.misc new file mode 100644 index 000000000000..2b1290d1e143 --- /dev/null +++ b/changelog.d/12248.misc @@ -0,0 +1 @@ +Add missing type hints for storage. \ No newline at end of file diff --git a/mypy.ini b/mypy.ini index 51f47ff5be8f..d8b3b3f9e588 100644 --- a/mypy.ini +++ b/mypy.ini @@ -42,9 +42,6 @@ exclude = (?x) |synapse/storage/databases/main/cache.py |synapse/storage/databases/main/devices.py |synapse/storage/databases/main/event_federation.py - |synapse/storage/databases/main/group_server.py - |synapse/storage/databases/main/metrics.py - |synapse/storage/databases/main/monthly_active_users.py |synapse/storage/databases/main/push_rule.py |synapse/storage/databases/main/receipts.py |synapse/storage/databases/main/roommember.py @@ -87,9 +84,6 @@ exclude = (?x) |tests/state/test_v2.py |tests/storage/test_background_update.py |tests/storage/test_base.py - |tests/storage/test_client_ips.py - |tests/storage/test_database.py - |tests/storage/test_event_federation.py |tests/storage/test_id_generators.py |tests/storage/test_roommember.py |tests/test_metrics.py diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py index 3f6086050bb2..0aef121d8348 100644 --- a/synapse/storage/databases/main/group_server.py +++ b/synapse/storage/databases/main/group_server.py @@ -13,13 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, cast from typing_extensions import TypedDict from synapse.api.errors import SynapseError from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.database import DatabasePool, LoggingDatabaseConnection +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) from synapse.types import JsonDict from synapse.util import json_encoder @@ -75,7 +79,7 @@ async def get_users_in_group( ) -> List[Dict[str, Any]]: # TODO: Pagination - keyvalues = {"group_id": group_id} + keyvalues: JsonDict = {"group_id": group_id} if not include_private: keyvalues["is_public"] = True @@ -117,7 +121,7 @@ async def get_rooms_in_group( # TODO: Pagination - def _get_rooms_in_group_txn(txn): + def _get_rooms_in_group_txn(txn: LoggingTransaction) -> List[_RoomInGroup]: sql = """ SELECT room_id, is_public FROM group_rooms WHERE group_id = ? 
@@ -176,8 +180,10 @@ async def get_rooms_for_summary_by_category( * "order": int, the sort order of rooms in this category """ - def _get_rooms_for_summary_txn(txn): - keyvalues = {"group_id": group_id} + def _get_rooms_for_summary_txn( + txn: LoggingTransaction, + ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]: + keyvalues: JsonDict = {"group_id": group_id} if not include_private: keyvalues["is_public"] = True @@ -241,7 +247,7 @@ def _get_rooms_for_summary_txn(txn): "get_rooms_for_summary", _get_rooms_for_summary_txn ) - async def get_group_categories(self, group_id): + async def get_group_categories(self, group_id: str) -> JsonDict: rows = await self.db_pool.simple_select_list( table="group_room_categories", keyvalues={"group_id": group_id}, @@ -257,7 +263,7 @@ async def get_group_categories(self, group_id): for row in rows } - async def get_group_category(self, group_id, category_id): + async def get_group_category(self, group_id: str, category_id: str) -> JsonDict: category = await self.db_pool.simple_select_one( table="group_room_categories", keyvalues={"group_id": group_id, "category_id": category_id}, @@ -269,7 +275,7 @@ async def get_group_category(self, group_id, category_id): return category - async def get_group_roles(self, group_id): + async def get_group_roles(self, group_id: str) -> JsonDict: rows = await self.db_pool.simple_select_list( table="group_roles", keyvalues={"group_id": group_id}, @@ -285,7 +291,7 @@ async def get_group_roles(self, group_id): for row in rows } - async def get_group_role(self, group_id, role_id): + async def get_group_role(self, group_id: str, role_id: str) -> JsonDict: role = await self.db_pool.simple_select_one( table="group_roles", keyvalues={"group_id": group_id, "role_id": role_id}, @@ -311,15 +317,19 @@ async def get_local_groups_for_room(self, room_id: str) -> List[str]: desc="get_local_groups_for_room", ) - async def get_users_for_summary_by_role(self, group_id, include_private=False): + async def get_users_for_summary_by_role( + self, group_id: str, include_private: bool = False + ) -> Tuple[List[JsonDict], JsonDict]: """Get the users and roles that should be included in a summary request Returns: ([users], [roles]) """ - def _get_users_for_summary_txn(txn): - keyvalues = {"group_id": group_id} + def _get_users_for_summary_txn( + txn: LoggingTransaction, + ) -> Tuple[List[JsonDict], JsonDict]: + keyvalues: JsonDict = {"group_id": group_id} if not include_private: keyvalues["is_public"] = True @@ -406,7 +416,9 @@ async def is_user_invited_to_local_group( allow_none=True, ) - async def get_users_membership_info_in_group(self, group_id, user_id): + async def get_users_membership_info_in_group( + self, group_id: str, user_id: str + ) -> JsonDict: """Get a dict describing the membership of a user in a group. 
Example if joined: @@ -421,7 +433,7 @@ async def get_users_membership_info_in_group(self, group_id, user_id): An empty dict if the user is not join/invite/etc """ - def _get_users_membership_in_group_txn(txn): + def _get_users_membership_in_group_txn(txn: LoggingTransaction) -> JsonDict: row = self.db_pool.simple_select_one_txn( txn, table="group_users", @@ -463,10 +475,14 @@ async def get_publicised_groups_for_user(self, user_id: str) -> List[str]: desc="get_publicised_groups_for_user", ) - async def get_attestations_need_renewals(self, valid_until_ms): + async def get_attestations_need_renewals( + self, valid_until_ms: int + ) -> List[Dict[str, Any]]: """Get all attestations that need to be renewed until givent time""" - def _get_attestations_need_renewals_txn(txn): + def _get_attestations_need_renewals_txn( + txn: LoggingTransaction, + ) -> List[Dict[str, Any]]: sql = """ SELECT group_id, user_id FROM group_attestations_renewals WHERE valid_until_ms <= ? @@ -478,7 +494,9 @@ def _get_attestations_need_renewals_txn(txn): "get_attestations_need_renewals", _get_attestations_need_renewals_txn ) - async def get_remote_attestation(self, group_id, user_id): + async def get_remote_attestation( + self, group_id: str, user_id: str + ) -> Optional[JsonDict]: """Get the attestation that proves the remote agrees that the user is in the group. """ @@ -504,8 +522,8 @@ async def get_joined_groups(self, user_id: str) -> List[str]: desc="get_joined_groups", ) - async def get_all_groups_for_user(self, user_id, now_token): - def _get_all_groups_for_user_txn(txn): + async def get_all_groups_for_user(self, user_id, now_token) -> List[JsonDict]: + def _get_all_groups_for_user_txn(txn: LoggingTransaction) -> List[JsonDict]: sql = """ SELECT group_id, type, membership, u.content FROM local_group_updates AS u @@ -528,15 +546,16 @@ def _get_all_groups_for_user_txn(txn): "get_all_groups_for_user", _get_all_groups_for_user_txn ) - async def get_groups_changes_for_user(self, user_id, from_token, to_token): - from_token = int(from_token) - has_changed = self._group_updates_stream_cache.has_entity_changed( + async def get_groups_changes_for_user( + self, user_id: str, from_token: int, to_token: int + ) -> List[JsonDict]: + has_changed = self._group_updates_stream_cache.has_entity_changed( # type: ignore[attr-defined] user_id, from_token ) if not has_changed: return [] - def _get_groups_changes_for_user_txn(txn): + def _get_groups_changes_for_user_txn(txn: LoggingTransaction) -> List[JsonDict]: sql = """ SELECT group_id, membership, type, u.content FROM local_group_updates AS u @@ -583,12 +602,14 @@ async def get_all_groups_changes( """ last_id = int(last_id) - has_changed = self._group_updates_stream_cache.has_any_entity_changed(last_id) + has_changed = self._group_updates_stream_cache.has_any_entity_changed(last_id) # type: ignore[attr-defined] if not has_changed: return [], current_id, False - def _get_all_groups_changes_txn(txn): + def _get_all_groups_changes_txn( + txn: LoggingTransaction, + ) -> Tuple[List[Tuple[int, tuple]], int, bool]: sql = """ SELECT stream_id, group_id, user_id, type, content FROM local_group_updates @@ -596,10 +617,13 @@ def _get_all_groups_changes_txn(txn): LIMIT ? 
""" txn.execute(sql, (last_id, current_id, limit)) - updates = [ - (stream_id, (group_id, user_id, gtype, db_to_json(content_json))) - for stream_id, group_id, user_id, gtype, content_json in txn - ] + updates = cast( + List[Tuple[int, tuple]], + [ + (stream_id, (group_id, user_id, gtype, db_to_json(content_json))) + for stream_id, group_id, user_id, gtype, content_json in txn + ], + ) limited = False upto_token = current_id @@ -633,8 +657,8 @@ async def add_room_to_summary( self, group_id: str, room_id: str, - category_id: str, - order: int, + category_id: Optional[str], + order: Optional[int], is_public: Optional[bool], ) -> None: """Add (or update) room's entry in summary. @@ -661,11 +685,11 @@ async def add_room_to_summary( def _add_room_to_summary_txn( self, - txn, + txn: LoggingTransaction, group_id: str, room_id: str, - category_id: str, - order: int, + category_id: Optional[str], + order: Optional[int], is_public: Optional[bool], ) -> None: """Add (or update) room's entry in summary. @@ -750,7 +774,7 @@ def _add_room_to_summary_txn( WHERE group_id = ? AND category_id = ? """ txn.execute(sql, (group_id, category_id)) - (order,) = txn.fetchone() + (order,) = cast(Tuple[int], txn.fetchone()) if existing: to_update = {} @@ -766,7 +790,7 @@ def _add_room_to_summary_txn( "category_id": category_id, "room_id": room_id, }, - values=to_update, + updatevalues=to_update, ) else: if is_public is None: @@ -785,7 +809,7 @@ def _add_room_to_summary_txn( ) async def remove_room_from_summary( - self, group_id: str, room_id: str, category_id: str + self, group_id: str, room_id: str, category_id: Optional[str] ) -> int: if category_id is None: category_id = _DEFAULT_CATEGORY_ID @@ -808,8 +832,8 @@ async def upsert_group_category( is_public: Optional[bool], ) -> None: """Add/update room category for group""" - insertion_values = {} - update_values = {"category_id": category_id} # This cannot be empty + insertion_values: JsonDict = {} + update_values: JsonDict = {"category_id": category_id} # This cannot be empty if profile is None: insertion_values["profile"] = "{}" @@ -844,8 +868,8 @@ async def upsert_group_role( is_public: Optional[bool], ) -> None: """Add/remove user role""" - insertion_values = {} - update_values = {"role_id": role_id} # This cannot be empty + insertion_values: JsonDict = {} + update_values: JsonDict = {"role_id": role_id} # This cannot be empty if profile is None: insertion_values["profile"] = "{}" @@ -876,8 +900,8 @@ async def add_user_to_summary( self, group_id: str, user_id: str, - role_id: str, - order: int, + role_id: Optional[str], + order: Optional[int], is_public: Optional[bool], ) -> None: """Add (or update) user's entry in summary. @@ -904,13 +928,13 @@ async def add_user_to_summary( def _add_user_to_summary_txn( self, - txn, + txn: LoggingTransaction, group_id: str, user_id: str, - role_id: str, - order: int, + role_id: Optional[str], + order: Optional[int], is_public: Optional[bool], - ): + ) -> None: """Add (or update) user's entry in summary. Args: @@ -989,7 +1013,7 @@ def _add_user_to_summary_txn( WHERE group_id = ? AND role_id = ? 
""" txn.execute(sql, (group_id, role_id)) - (order,) = txn.fetchone() + (order,) = cast(Tuple[int], txn.fetchone()) if existing: to_update = {} @@ -1005,7 +1029,7 @@ def _add_user_to_summary_txn( "role_id": role_id, "user_id": user_id, }, - values=to_update, + updatevalues=to_update, ) else: if is_public is None: @@ -1024,7 +1048,7 @@ def _add_user_to_summary_txn( ) async def remove_user_from_summary( - self, group_id: str, user_id: str, role_id: str + self, group_id: str, user_id: str, role_id: Optional[str] ) -> int: if role_id is None: role_id = _DEFAULT_ROLE_ID @@ -1065,7 +1089,7 @@ async def add_user_to_group( Optional if the user and group are on the same server """ - def _add_user_to_group_txn(txn): + def _add_user_to_group_txn(txn: LoggingTransaction) -> None: self.db_pool.simple_insert_txn( txn, table="group_users", @@ -1108,7 +1132,7 @@ def _add_user_to_group_txn(txn): await self.db_pool.runInteraction("add_user_to_group", _add_user_to_group_txn) async def remove_user_from_group(self, group_id: str, user_id: str) -> None: - def _remove_user_from_group_txn(txn): + def _remove_user_from_group_txn(txn: LoggingTransaction) -> None: self.db_pool.simple_delete_txn( txn, table="group_users", @@ -1159,7 +1183,7 @@ async def update_room_in_group_visibility( ) async def remove_room_from_group(self, group_id: str, room_id: str) -> None: - def _remove_room_from_group_txn(txn): + def _remove_room_from_group_txn(txn: LoggingTransaction) -> None: self.db_pool.simple_delete_txn( txn, table="group_rooms", @@ -1216,7 +1240,9 @@ async def register_user_group_membership( content = content or {} - def _register_user_group_membership_txn(txn, next_id): + def _register_user_group_membership_txn( + txn: LoggingTransaction, next_id: int + ) -> int: # TODO: Upsert? self.db_pool.simple_delete_txn( txn, @@ -1249,7 +1275,7 @@ def _register_user_group_membership_txn(txn, next_id): ), }, ) - self._group_updates_stream_cache.entity_has_changed(user_id, next_id) + self._group_updates_stream_cache.entity_has_changed(user_id, next_id) # type: ignore[attr-defined] # TODO: Insert profile to ensure it comes down stream if its a join. 
@@ -1289,7 +1315,7 @@ def _register_user_group_membership_txn(txn, next_id): return next_id - async with self._group_updates_id_gen.get_next() as next_id: + async with self._group_updates_id_gen.get_next() as next_id: # type: ignore[attr-defined] res = await self.db_pool.runInteraction( "register_user_group_membership", _register_user_group_membership_txn, @@ -1298,7 +1324,13 @@ def _register_user_group_membership_txn(txn, next_id): return res async def create_group( - self, group_id, user_id, name, avatar_url, short_description, long_description + self, + group_id: str, + user_id: str, + name: str, + avatar_url: str, + short_description: str, + long_description: str, ) -> None: await self.db_pool.simple_insert( table="groups", @@ -1313,7 +1345,7 @@ async def create_group( desc="create_group", ) - async def update_group_profile(self, group_id, profile): + async def update_group_profile(self, group_id: str, profile: JsonDict) -> None: await self.db_pool.simple_update_one( table="groups", keyvalues={"group_id": group_id}, @@ -1361,8 +1393,8 @@ async def remove_attestation_renewal(self, group_id: str, user_id: str) -> int: desc="remove_attestation_renewal", ) - def get_group_stream_token(self): - return self._group_updates_id_gen.get_current_token() + def get_group_stream_token(self) -> int: + return self._group_updates_id_gen.get_current_token() # type: ignore[attr-defined] async def delete_group(self, group_id: str) -> None: """Deletes a group fully from the database. @@ -1371,7 +1403,7 @@ async def delete_group(self, group_id: str) -> None: group_id: The group ID to delete. """ - def _delete_group_txn(txn): + def _delete_group_txn(txn: LoggingTransaction) -> None: tables = [ "groups", "group_users", diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py index e9a0cdc6be94..216622964aa7 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py @@ -12,15 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import logging -from typing import TYPE_CHECKING, Dict, List, Optional +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( DatabasePool, LoggingDatabaseConnection, + LoggingTransaction, make_in_list_sql_clause, ) +from synapse.storage.databases.main.registration import RegistrationWorkerStore from synapse.util.caches.descriptors import cached from synapse.util.threepids import canonicalise_email @@ -56,7 +58,7 @@ async def get_monthly_active_count(self) -> int: Number of current monthly active users """ - def _count_users(txn): + def _count_users(txn: LoggingTransaction) -> int: # Exclude app service users sql = """ SELECT COUNT(*) @@ -66,7 +68,7 @@ def _count_users(txn): WHERE (users.appservice_id IS NULL OR users.appservice_id = ''); """ txn.execute(sql) - (count,) = txn.fetchone() + (count,) = cast(Tuple[int], txn.fetchone()) return count return await self.db_pool.runInteraction("count_users", _count_users) @@ -84,7 +86,7 @@ async def get_monthly_active_count_by_service(self) -> Dict[str, int]: """ - def _count_users_by_service(txn): + def _count_users_by_service(txn: LoggingTransaction) -> Dict[str, int]: sql = """ SELECT COALESCE(appservice_id, 'native'), COUNT(*) FROM monthly_active_users @@ -93,7 +95,7 @@ def _count_users_by_service(txn): """ txn.execute(sql) - result = txn.fetchall() + result = cast(List[Tuple[str, int]], txn.fetchall()) return dict(result) return await self.db_pool.runInteraction( @@ -141,12 +143,12 @@ async def user_last_seen_monthly_active(self, user_id: str) -> Optional[int]: ) @wrap_as_background_process("reap_monthly_active_users") - async def reap_monthly_active_users(self): + async def reap_monthly_active_users(self) -> None: """Cleans out monthly active user table to ensure that no stale entries exist. """ - def _reap_users(txn, reserved_users): + def _reap_users(txn: LoggingTransaction, reserved_users: List[str]) -> None: """ Args: reserved_users (tuple): reserved users to preserve @@ -210,10 +212,10 @@ def _reap_users(txn, reserved_users): # is racy. # Have resolved to invalidate the whole cache for now and do # something about it if and when the perf becomes significant - self._invalidate_all_cache_and_stream( + self._invalidate_all_cache_and_stream( # type: ignore[attr-defined] txn, self.user_last_seen_monthly_active ) - self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ()) + self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ()) # type: ignore[attr-defined] reserved_users = await self.get_registered_reserved_users() await self.db_pool.runInteraction( @@ -221,7 +223,7 @@ def _reap_users(txn, reserved_users): ) -class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): +class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore, RegistrationWorkerStore): def __init__( self, database: DatabasePool, @@ -242,13 +244,15 @@ def __init__( hs.config.server.mau_limits_reserved_threepids[: self._max_mau_value], ) - def _initialise_reserved_users(self, txn, threepids): + def _initialise_reserved_users( + self, txn: LoggingTransaction, threepids: List[dict] + ) -> None: """Ensures that reserved threepids are accounted for in the MAU table, should be called on start up. 
Args: - txn (cursor): - threepids (list[dict]): List of threepid dicts to reserve + txn: + threepids: List of threepid dicts to reserve """ # XXX what is this function trying to achieve? It upserts into @@ -299,7 +303,9 @@ async def upsert_monthly_active_user(self, user_id: str) -> None: "upsert_monthly_active_user", self.upsert_monthly_active_user_txn, user_id ) - def upsert_monthly_active_user_txn(self, txn, user_id): + def upsert_monthly_active_user_txn( + self, txn: LoggingTransaction, user_id: str + ) -> None: """Updates or inserts monthly active user member We consciously do not call is_support_txn from this method because it @@ -336,7 +342,7 @@ def upsert_monthly_active_user_txn(self, txn, user_id): txn, self.user_last_seen_monthly_active, (user_id,) ) - async def populate_monthly_active_users(self, user_id): + async def populate_monthly_active_users(self, user_id: str) -> None: """Checks on the state of monthly active user limits and optionally add the user to the monthly active tables @@ -345,7 +351,7 @@ async def populate_monthly_active_users(self, user_id): """ if self._limit_usage_by_mau or self._mau_stats_only: # Trial users and guests should not be included as part of MAU group - is_guest = await self.is_guest(user_id) + is_guest = await self.is_guest(user_id) # type: ignore[attr-defined] if is_guest: return is_trial = await self.is_trial_user(user_id) From 2177e356bc4b62e7a84c6fbee1d48de1abc940b5 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 18 Mar 2022 12:51:27 -0400 Subject: [PATCH 080/230] Sync more worker regexes in the documentation. (#12243) --- changelog.d/12243.doc | 1 + docs/workers.md | 30 ++++++++++++++---------------- 2 files changed, 15 insertions(+), 16 deletions(-) create mode 100644 changelog.d/12243.doc diff --git a/changelog.d/12243.doc b/changelog.d/12243.doc new file mode 100644 index 000000000000..b2031f0a4039 --- /dev/null +++ b/changelog.d/12243.doc @@ -0,0 +1 @@ +Remove incorrect prefixes in the worker documentation for some endpoints. diff --git a/docs/workers.md b/docs/workers.md index 8751134e654d..9eb4194e4dcd 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -185,8 +185,8 @@ worker: refer to the [stream writers](#stream-writers) section below for further information. # Sync requests - ^/_matrix/client/(v2_alpha|r0|v3)/sync$ - ^/_matrix/client/(api/v1|v2_alpha|r0|v3)/events$ + ^/_matrix/client/(r0|v3)/sync$ + ^/_matrix/client/(api/v1|r0|v3)/events$ ^/_matrix/client/(api/v1|r0|v3)/initialSync$ ^/_matrix/client/(api/v1|r0|v3)/rooms/[^/]+/initialSync$ @@ -200,13 +200,9 @@ information. ^/_matrix/federation/v1/query/ ^/_matrix/federation/v1/make_join/ ^/_matrix/federation/v1/make_leave/ - ^/_matrix/federation/v1/send_join/ - ^/_matrix/federation/v2/send_join/ - ^/_matrix/federation/v1/send_leave/ - ^/_matrix/federation/v2/send_leave/ - ^/_matrix/federation/v1/invite/ - ^/_matrix/federation/v2/invite/ - ^/_matrix/federation/v1/query_auth/ + ^/_matrix/federation/(v1|v2)/send_join/ + ^/_matrix/federation/(v1|v2)/send_leave/ + ^/_matrix/federation/(v1|v2)/invite/ ^/_matrix/federation/v1/event_auth/ ^/_matrix/federation/v1/exchange_third_party_invite/ ^/_matrix/federation/v1/user/devices/ @@ -274,6 +270,8 @@ information. Additionally, the following REST endpoints can be handled for GET requests: ^/_matrix/federation/v1/groups/ + ^/_matrix/client/(api/v1|r0|v3|unstable)/pushrules/ + ^/_matrix/client/(r0|v3|unstable)/groups/ Pagination requests can also be handled, but all requests for a given room must be routed to the same instance. 
Additionally, care must be taken to @@ -397,23 +395,23 @@ the stream writer for the `typing` stream: The following endpoints should be routed directly to the worker configured as the stream writer for the `to_device` stream: - ^/_matrix/client/(api/v1|r0|v3|unstable)/sendToDevice/ + ^/_matrix/client/(r0|v3|unstable)/sendToDevice/ ##### The `account_data` stream The following endpoints should be routed directly to the worker configured as the stream writer for the `account_data` stream: - ^/_matrix/client/(api/v1|r0|v3|unstable)/.*/tags - ^/_matrix/client/(api/v1|r0|v3|unstable)/.*/account_data + ^/_matrix/client/(r0|v3|unstable)/.*/tags + ^/_matrix/client/(r0|v3|unstable)/.*/account_data ##### The `receipts` stream The following endpoints should be routed directly to the worker configured as the stream writer for the `receipts` stream: - ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/receipt - ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/read_markers + ^/_matrix/client/(r0|v3|unstable)/rooms/.*/receipt + ^/_matrix/client/(r0|v3|unstable)/rooms/.*/read_markers ##### The `presence` stream @@ -528,7 +526,7 @@ Note that if a reverse proxy is used , then `/_matrix/media/` must be routed for Handles searches in the user directory. It can handle REST endpoints matching the following regular expressions: - ^/_matrix/client/(api/v1|r0|v3|unstable)/user_directory/search$ + ^/_matrix/client/(r0|v3|unstable)/user_directory/search$ When using this worker you must also set `update_user_directory: False` in the shared configuration file to stop the main synapse running background @@ -540,7 +538,7 @@ Proxies some frequently-requested client endpoints to add caching and remove load from the main synapse. It can handle REST endpoints matching the following regular expressions: - ^/_matrix/client/(api/v1|r0|v3|unstable)/keys/upload + ^/_matrix/client/(r0|v3|unstable)/keys/upload If `use_presence` is False in the homeserver config, it can also handle REST endpoints matching the following regular expressions: From 80e0e1f35e6b1cdfa0267f9c40a6f212b7d774de Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 18 Mar 2022 13:15:45 -0400 Subject: [PATCH 081/230] Only fetch thread participation for events with threads. (#12228) We fetch the thread summary in two phases: 1. The summary that is shared by all users (count of messages and latest event). 2. Whether the requesting user has participated in the thread. There's no use in attempting step 2 for events which did not return a summary from step 1. --- changelog.d/12228.bugfix | 1 + synapse/storage/databases/main/relations.py | 4 +- tests/rest/client/test_relations.py | 509 +++++++++++--------- tests/server.py | 20 +- 4 files changed, 289 insertions(+), 245 deletions(-) create mode 100644 changelog.d/12228.bugfix diff --git a/changelog.d/12228.bugfix b/changelog.d/12228.bugfix new file mode 100644 index 000000000000..47557771399e --- /dev/null +++ b/changelog.d/12228.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in v1.53.0 where an unnecessary query could be performed when fetching bundled aggregations for threads. diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index c4869d64e663..af2334a65ea0 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -857,7 +857,9 @@ async def get_bundled_aggregations( summaries = await self._get_thread_summaries(events_by_id.keys()) # Only fetch participated for a limited selection based on what had # summaries. 
- participated = await self._get_threads_participated(summaries.keys(), user_id) + participated = await self._get_threads_participated( + [event_id for event_id, summary in summaries.items() if summary], user_id + ) for event_id, summary in summaries.items(): if summary: thread_count, latest_thread_event, edit = summary diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index f3741b300185..329690f8f7c8 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -15,7 +15,7 @@ import itertools import urllib.parse -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple from unittest.mock import patch from twisted.test.proto_helpers import MemoryReactor @@ -155,6 +155,16 @@ def _get_aggregations(self) -> List[JsonDict]: self.assertEqual(200, channel.code, channel.json_body) return channel.json_body["chunk"] + def _find_event_in_chunk(self, events: List[JsonDict]) -> JsonDict: + """ + Find the parent event in a chunk of events and assert that it has the proper bundled aggregations. + """ + for event in events: + if event["event_id"] == self.parent_id: + return event + + raise AssertionError(f"Event {self.parent_id} not found in chunk") + class RelationsTestCase(BaseRelationsTestCase): def test_send_relation(self) -> None: @@ -291,202 +301,6 @@ def test_aggregation_must_be_annotation(self) -> None: ) self.assertEqual(400, channel.code, channel.json_body) - @unittest.override_config({"experimental_features": {"msc3666_enabled": True}}) - def test_bundled_aggregations(self) -> None: - """ - Test that annotations, references, and threads get correctly bundled. - - Note that this doesn't test against /relations since only thread relations - get bundled via that API. See test_aggregation_get_event_for_thread. - - See test_edit for a similar test for edits. - """ - # Setup by sending a variety of relations. - self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token - ) - self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") - - channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test") - reply_1 = channel.json_body["event_id"] - - channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test") - reply_2 = channel.json_body["event_id"] - - self._send_relation(RelationTypes.THREAD, "m.room.test") - - channel = self._send_relation(RelationTypes.THREAD, "m.room.test") - thread_2 = channel.json_body["event_id"] - - def assert_bundle(event_json: JsonDict) -> None: - """Assert the expected values of the bundled aggregations.""" - relations_dict = event_json["unsigned"].get("m.relations") - - # Ensure the fields are as expected. - self.assertCountEqual( - relations_dict.keys(), - ( - RelationTypes.ANNOTATION, - RelationTypes.REFERENCE, - RelationTypes.THREAD, - ), - ) - - # Check the values of each field. 
- self.assertEqual( - { - "chunk": [ - {"type": "m.reaction", "key": "a", "count": 2}, - {"type": "m.reaction", "key": "b", "count": 1}, - ] - }, - relations_dict[RelationTypes.ANNOTATION], - ) - - self.assertEqual( - {"chunk": [{"event_id": reply_1}, {"event_id": reply_2}]}, - relations_dict[RelationTypes.REFERENCE], - ) - - self.assertEqual( - 2, - relations_dict[RelationTypes.THREAD].get("count"), - ) - self.assertTrue( - relations_dict[RelationTypes.THREAD].get("current_user_participated") - ) - # The latest thread event has some fields that don't matter. - self.assert_dict( - { - "content": { - "m.relates_to": { - "event_id": self.parent_id, - "rel_type": RelationTypes.THREAD, - } - }, - "event_id": thread_2, - "sender": self.user_id, - "type": "m.room.test", - }, - relations_dict[RelationTypes.THREAD].get("latest_event"), - ) - - # Request the event directly. - channel = self.make_request( - "GET", - f"/rooms/{self.room}/event/{self.parent_id}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - assert_bundle(channel.json_body) - - # Request the room messages. - channel = self.make_request( - "GET", - f"/rooms/{self.room}/messages?dir=b", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"])) - - # Request the room context. - channel = self.make_request( - "GET", - f"/rooms/{self.room}/context/{self.parent_id}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - assert_bundle(channel.json_body["event"]) - - # Request sync. - channel = self.make_request("GET", "/sync", access_token=self.user_token) - self.assertEqual(200, channel.code, channel.json_body) - room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] - self.assertTrue(room_timeline["limited"]) - assert_bundle(self._find_event_in_chunk(room_timeline["events"])) - - # Request search. - channel = self.make_request( - "POST", - "/search", - # Search term matches the parent message. - content={"search_categories": {"room_events": {"search_term": "Hi"}}}, - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - chunk = [ - result["result"] - for result in channel.json_body["search_categories"]["room_events"][ - "results" - ] - ] - assert_bundle(self._find_event_in_chunk(chunk)) - - def test_aggregation_get_event_for_annotation(self) -> None: - """Test that annotations do not get bundled aggregations included - when directly requested. - """ - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - annotation_id = channel.json_body["event_id"] - - # Annotate the annotation. - self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=annotation_id - ) - - channel = self.make_request( - "GET", - f"/rooms/{self.room}/event/{annotation_id}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - self.assertIsNone(channel.json_body["unsigned"].get("m.relations")) - - def test_aggregation_get_event_for_thread(self) -> None: - """Test that threads get bundled aggregations included when directly requested.""" - channel = self._send_relation(RelationTypes.THREAD, "m.room.test") - thread_id = channel.json_body["event_id"] - - # Annotate the annotation. 
- self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=thread_id - ) - - channel = self.make_request( - "GET", - f"/rooms/{self.room}/event/{thread_id}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - self.assertEqual( - channel.json_body["unsigned"].get("m.relations"), - { - RelationTypes.ANNOTATION: { - "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}] - }, - }, - ) - - # It should also be included when the entire thread is requested. - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - self.assertEqual(len(channel.json_body["chunk"]), 1) - - thread_message = channel.json_body["chunk"][0] - self.assertEqual( - thread_message["unsigned"].get("m.relations"), - { - RelationTypes.ANNOTATION: { - "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}] - }, - }, - ) - def test_ignore_invalid_room(self) -> None: """Test that we ignore invalid relations over federation.""" # Create another room and send a message in it. @@ -796,7 +610,7 @@ def test_edit_thread(self) -> None: threaded_event_id = channel.json_body["event_id"] new_body = {"msgtype": "m.text", "body": "I've been edited!"} - channel = self._send_relation( + self._send_relation( RelationTypes.REPLACE, "m.room.message", content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, @@ -836,7 +650,7 @@ def test_edit_edit(self) -> None: edit_event_id = channel.json_body["event_id"] # Edit the edit event. - channel = self._send_relation( + self._send_relation( RelationTypes.REPLACE, "m.room.message", content={ @@ -912,16 +726,6 @@ def test_unknown_relations(self) -> None: self.assertEqual(200, channel.code, channel.json_body) self.assertEqual(channel.json_body["chunk"], []) - def _find_event_in_chunk(self, events: List[JsonDict]) -> JsonDict: - """ - Find the parent event in a chunk of events and assert that it has the proper bundled aggregations. - """ - for event in events: - if event["event_id"] == self.parent_id: - return event - - raise AssertionError(f"Event {self.parent_id} not found in chunk") - def test_background_update(self) -> None: """Test the event_arbitrary_relations background update.""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍") @@ -981,34 +785,6 @@ def test_background_update(self) -> None: [annotation_event_id_good, thread_event_id], ) - def test_bundled_aggregations_with_filter(self) -> None: - """ - If "unsigned" is an omitted field (due to filtering), adding the bundled - aggregations should not break. - - Note that the spec allows for a server to return additional fields beyond - what is specified. - """ - self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - - # Note that the sync filter does not include "unsigned" as a field. - filter = urllib.parse.quote_plus( - b'{"event_fields": ["content", "event_id"], "room": {"timeline": {"limit": 3}}}' - ) - channel = self.make_request( - "GET", f"/sync?filter={filter}", access_token=self.user_token - ) - self.assertEqual(200, channel.code, channel.json_body) - - # Ensure the timeline is limited, find the parent event. - room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] - self.assertTrue(room_timeline["limited"]) - parent_event = self._find_event_in_chunk(room_timeline["events"]) - - # Ensure there's bundled aggregations on it. 
- self.assertIn("unsigned", parent_event) - self.assertIn("m.relations", parent_event["unsigned"]) - class RelationPaginationTestCase(BaseRelationsTestCase): def test_basic_paginate_relations(self) -> None: @@ -1255,7 +1031,7 @@ def test_aggregation_pagination_within_group(self) -> None: idx += 1 # Also send a different type of reaction so that we test we don't see it - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") + self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") prev_token = "" found_event_ids: List[str] = [] @@ -1291,6 +1067,263 @@ def test_aggregation_pagination_within_group(self) -> None: self.assertEqual(found_event_ids, expected_event_ids) +class BundledAggregationsTestCase(BaseRelationsTestCase): + """ + See RelationsTestCase.test_edit for a similar test for edits. + + Note that this doesn't test against /relations since only thread relations + get bundled via that API. See test_aggregation_get_event_for_thread. + """ + + def _test_bundled_aggregations( + self, + relation_type: str, + assertion_callable: Callable[[JsonDict], None], + expected_db_txn_for_event: int, + ) -> None: + """ + Makes requests to various endpoints which should include bundled aggregations + and then calls an assertion function on the bundled aggregations. + + Args: + relation_type: The field to search for in the `m.relations` field in unsigned. + assertion_callable: Called with the contents of unsigned["m.relations"][relation_type] + for relation-specific assertions. + expected_db_txn_for_event: The number of database transactions which + are expected for a call to /event/. + """ + + def assert_bundle(event_json: JsonDict) -> None: + """Assert the expected values of the bundled aggregations.""" + relations_dict = event_json["unsigned"].get("m.relations") + + # Ensure the fields are as expected. + self.assertCountEqual(relations_dict.keys(), (relation_type,)) + assertion_callable(relations_dict[relation_type]) + + # Request the event directly. + channel = self.make_request( + "GET", + f"/rooms/{self.room}/event/{self.parent_id}", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + assert_bundle(channel.json_body) + assert channel.resource_usage is not None + self.assertEqual(channel.resource_usage.db_txn_count, expected_db_txn_for_event) + + # Request the room messages. + channel = self.make_request( + "GET", + f"/rooms/{self.room}/messages?dir=b", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"])) + + # Request the room context. + channel = self.make_request( + "GET", + f"/rooms/{self.room}/context/{self.parent_id}", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + assert_bundle(channel.json_body["event"]) + + # Request sync. + filter = urllib.parse.quote_plus(b'{"room": {"timeline": {"limit": 4}}}') + channel = self.make_request( + "GET", f"/sync?filter={filter}", access_token=self.user_token + ) + self.assertEqual(200, channel.code, channel.json_body) + room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] + self.assertTrue(room_timeline["limited"]) + assert_bundle(self._find_event_in_chunk(room_timeline["events"])) + + # Request search. + channel = self.make_request( + "POST", + "/search", + # Search term matches the parent message. 
+ content={"search_categories": {"room_events": {"search_term": "Hi"}}}, + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + chunk = [ + result["result"] + for result in channel.json_body["search_categories"]["room_events"][ + "results" + ] + ] + assert_bundle(self._find_event_in_chunk(chunk)) + + @unittest.override_config({"experimental_features": {"msc3666_enabled": True}}) + def test_annotation(self) -> None: + """ + Test that annotations get correctly bundled. + """ + # Setup by sending a variety of relations. + self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token + ) + self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") + + def assert_annotations(bundled_aggregations: JsonDict) -> None: + self.assertEqual( + { + "chunk": [ + {"type": "m.reaction", "key": "a", "count": 2}, + {"type": "m.reaction", "key": "b", "count": 1}, + ] + }, + bundled_aggregations, + ) + + self._test_bundled_aggregations(RelationTypes.ANNOTATION, assert_annotations, 7) + + @unittest.override_config({"experimental_features": {"msc3666_enabled": True}}) + def test_reference(self) -> None: + """ + Test that references get correctly bundled. + """ + channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test") + reply_1 = channel.json_body["event_id"] + + channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test") + reply_2 = channel.json_body["event_id"] + + def assert_annotations(bundled_aggregations: JsonDict) -> None: + self.assertEqual( + {"chunk": [{"event_id": reply_1}, {"event_id": reply_2}]}, + bundled_aggregations, + ) + + self._test_bundled_aggregations(RelationTypes.REFERENCE, assert_annotations, 7) + + @unittest.override_config({"experimental_features": {"msc3666_enabled": True}}) + def test_thread(self) -> None: + """ + Test that threads get correctly bundled. + """ + self._send_relation(RelationTypes.THREAD, "m.room.test") + channel = self._send_relation(RelationTypes.THREAD, "m.room.test") + thread_2 = channel.json_body["event_id"] + + def assert_annotations(bundled_aggregations: JsonDict) -> None: + self.assertEqual(2, bundled_aggregations.get("count")) + self.assertTrue(bundled_aggregations.get("current_user_participated")) + # The latest thread event has some fields that don't matter. + self.assert_dict( + { + "content": { + "m.relates_to": { + "event_id": self.parent_id, + "rel_type": RelationTypes.THREAD, + } + }, + "event_id": thread_2, + "sender": self.user_id, + "type": "m.room.test", + }, + bundled_aggregations.get("latest_event"), + ) + + self._test_bundled_aggregations(RelationTypes.THREAD, assert_annotations, 9) + + def test_aggregation_get_event_for_annotation(self) -> None: + """Test that annotations do not get bundled aggregations included + when directly requested. + """ + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + annotation_id = channel.json_body["event_id"] + + # Annotate the annotation. 
+ self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=annotation_id + ) + + channel = self.make_request( + "GET", + f"/rooms/{self.room}/event/{annotation_id}", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + self.assertIsNone(channel.json_body["unsigned"].get("m.relations")) + + def test_aggregation_get_event_for_thread(self) -> None: + """Test that threads get bundled aggregations included when directly requested.""" + channel = self._send_relation(RelationTypes.THREAD, "m.room.test") + thread_id = channel.json_body["event_id"] + + # Annotate the annotation. + self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=thread_id + ) + + channel = self.make_request( + "GET", + f"/rooms/{self.room}/event/{thread_id}", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + self.assertEqual( + channel.json_body["unsigned"].get("m.relations"), + { + RelationTypes.ANNOTATION: { + "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}] + }, + }, + ) + + # It should also be included when the entire thread is requested. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + self.assertEqual(len(channel.json_body["chunk"]), 1) + + thread_message = channel.json_body["chunk"][0] + self.assertEqual( + thread_message["unsigned"].get("m.relations"), + { + RelationTypes.ANNOTATION: { + "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}] + }, + }, + ) + + def test_bundled_aggregations_with_filter(self) -> None: + """ + If "unsigned" is an omitted field (due to filtering), adding the bundled + aggregations should not break. + + Note that the spec allows for a server to return additional fields beyond + what is specified. + """ + self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + + # Note that the sync filter does not include "unsigned" as a field. + filter = urllib.parse.quote_plus( + b'{"event_fields": ["content", "event_id"], "room": {"timeline": {"limit": 3}}}' + ) + channel = self.make_request( + "GET", f"/sync?filter={filter}", access_token=self.user_token + ) + self.assertEqual(200, channel.code, channel.json_body) + + # Ensure the timeline is limited, find the parent event. + room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] + self.assertTrue(room_timeline["limited"]) + parent_event = self._find_event_in_chunk(room_timeline["events"]) + + # Ensure there's bundled aggregations on it. + self.assertIn("unsigned", parent_event) + self.assertIn("m.relations", parent_event["unsigned"]) + + class RelationRedactionTestCase(BaseRelationsTestCase): """ Test the behaviour of relations when the parent or child event is redacted. 
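For reference, the assertions in the new BundledAggregationsTestCase all rely on the same event shape: bundled aggregations are returned under unsigned["m.relations"], keyed by relation type ("m.annotation" is the string value of RelationTypes.ANNOTATION). Below is a minimal, self-contained sketch of reading the annotation counts that test_annotation asserts on; the annotation_counts helper is purely illustrative and not part of the patch.

from typing import Dict

def annotation_counts(event: Dict) -> Dict[str, int]:
    """Return reaction key -> count from an event's bundled annotation aggregations."""
    relations = event.get("unsigned", {}).get("m.relations", {})
    chunk = relations.get("m.annotation", {}).get("chunk", [])
    # Each entry mirrors the shape asserted in the tests above:
    # {"type": "m.reaction", "key": ..., "count": ...}
    return {
        group["key"]: group["count"]
        for group in chunk
        if group.get("type") == "m.reaction"
    }

# Matches the aggregation asserted in test_annotation: two "a" reactions and one "b".
event = {
    "unsigned": {
        "m.relations": {
            "m.annotation": {
                "chunk": [
                    {"type": "m.reaction", "key": "a", "count": 2},
                    {"type": "m.reaction", "key": "b", "count": 1},
                ]
            }
        }
    }
}
assert annotation_counts(event) == {"a": 2, "b": 1}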
diff --git a/tests/server.py b/tests/server.py index 82990c2eb9df..6ce2a17bf442 100644 --- a/tests/server.py +++ b/tests/server.py @@ -54,13 +54,18 @@ ITransport, ) from twisted.python.failure import Failure -from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock +from twisted.test.proto_helpers import ( + AccumulatingProtocol, + MemoryReactor, + MemoryReactorClock, +) from twisted.web.http_headers import Headers from twisted.web.resource import IResource from twisted.web.server import Request, Site from synapse.config.database import DatabaseConnectionConfig from synapse.http.site import SynapseRequest +from synapse.logging.context import ContextResourceUsage from synapse.server import HomeServer from synapse.storage import DataStore from synapse.storage.engines import PostgresEngine, create_engine @@ -88,18 +93,19 @@ class TimedOutException(Exception): """ -@attr.s +@attr.s(auto_attribs=True) class FakeChannel: """ A fake Twisted Web Channel (the part that interfaces with the wire). """ - site = attr.ib(type=Union[Site, "FakeSite"]) - _reactor = attr.ib() - result = attr.ib(type=dict, default=attr.Factory(dict)) - _ip = attr.ib(type=str, default="127.0.0.1") + site: Union[Site, "FakeSite"] + _reactor: MemoryReactor + result: dict = attr.Factory(dict) + _ip: str = "127.0.0.1" _producer: Optional[Union[IPullProducer, IPushProducer]] = None + resource_usage: Optional[ContextResourceUsage] = None @property def json_body(self): @@ -168,6 +174,8 @@ def unregisterProducer(self): def requestDone(self, _self): self.result["done"] = True + if isinstance(_self, SynapseRequest): + self.resource_usage = _self.logcontext.get_resource_usage() def getPeer(self): # We give an address so that getClientIP returns a non null entry, From 8fe930c215f69913fbcd96d609ec6950644e4ec4 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 18 Mar 2022 13:49:32 -0400 Subject: [PATCH 082/230] Move get_bundled_aggregations to relations handler. (#12237) The get_bundled_aggregations code is fairly high-level and uses a lot of store methods, we move it into the handler as that seems like a better fit. --- changelog.d/12237.misc | 1 + synapse/events/utils.py | 2 +- synapse/handlers/pagination.py | 5 +- synapse/handlers/relations.py | 151 +++++++++++++++++++- synapse/handlers/room.py | 5 +- synapse/handlers/search.py | 3 +- synapse/handlers/sync.py | 9 +- synapse/rest/client/room.py | 3 +- synapse/storage/databases/main/relations.py | 151 +------------------- 9 files changed, 173 insertions(+), 157 deletions(-) create mode 100644 changelog.d/12237.misc diff --git a/changelog.d/12237.misc b/changelog.d/12237.misc new file mode 100644 index 000000000000..41c9dcbd37f6 --- /dev/null +++ b/changelog.d/12237.misc @@ -0,0 +1 @@ +Refactor the relations endpoints to add a `RelationsHandler`. diff --git a/synapse/events/utils.py b/synapse/events/utils.py index a0520068e0f9..71200621277e 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -38,8 +38,8 @@ from . import EventBase if TYPE_CHECKING: + from synapse.handlers.relations import BundledAggregations from synapse.server import HomeServer - from synapse.storage.databases.main.relations import BundledAggregations # Split strings on "." but not "\." 
This uses a negative lookbehind assertion for '\' diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 41679f7f866b..876b879483e7 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -134,6 +134,7 @@ def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() self._server_name = hs.hostname self._room_shutdown_handler = hs.get_room_shutdown_handler() + self._relations_handler = hs.get_relations_handler() self.pagination_lock = ReadWriteLock() # IDs of rooms in which there currently an active purge *or delete* operation. @@ -539,7 +540,9 @@ async def get_messages( state_dict = await self.store.get_events(list(state_ids.values())) state = state_dict.values() - aggregations = await self.store.get_bundled_aggregations(events, user_id) + aggregations = await self._relations_handler.get_bundled_aggregations( + events, user_id + ) time_now = self.clock.time_msec() diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index 8e475475ad02..57135d45197b 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -12,18 +12,53 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Dict, Iterable, Optional, cast +import attr +from frozendict import frozendict + +from synapse.api.constants import RelationTypes from synapse.api.errors import SynapseError +from synapse.events import EventBase from synapse.types import JsonDict, Requester, StreamToken if TYPE_CHECKING: from synapse.server import HomeServer + from synapse.storage.databases.main import DataStore logger = logging.getLogger(__name__) +@attr.s(slots=True, frozen=True, auto_attribs=True) +class _ThreadAggregation: + # The latest event in the thread. + latest_event: EventBase + # The latest edit to the latest event in the thread. + latest_edit: Optional[EventBase] + # The total number of events in the thread. + count: int + # True if the current user has sent an event to the thread. + current_user_participated: bool + + +@attr.s(slots=True, auto_attribs=True) +class BundledAggregations: + """ + The bundled aggregations for an event. + + Some values require additional processing during serialization. + """ + + annotations: Optional[JsonDict] = None + references: Optional[JsonDict] = None + replace: Optional[EventBase] = None + thread: Optional[_ThreadAggregation] = None + + def __bool__(self) -> bool: + return bool(self.annotations or self.references or self.replace or self.thread) + + class RelationsHandler: def __init__(self, hs: "HomeServer"): self._main_store = hs.get_datastores().main @@ -103,7 +138,7 @@ async def get_relations( ) # The relations returned for the requested event do include their # bundled aggregations. - aggregations = await self._main_store.get_bundled_aggregations( + aggregations = await self.get_bundled_aggregations( events, requester.user.to_string() ) serialized_events = self._event_serializer.serialize_events( @@ -115,3 +150,115 @@ async def get_relations( return_value["original_event"] = original_event return return_value + + async def _get_bundled_aggregation_for_event( + self, event: EventBase, user_id: str + ) -> Optional[BundledAggregations]: + """Generate bundled aggregations for an event. + + Note that this does not use a cache, but depends on cached methods. + + Args: + event: The event to calculate bundled aggregations for. 
+ user_id: The user requesting the bundled aggregations. + + Returns: + The bundled aggregations for an event, if bundled aggregations are + enabled and the event can have bundled aggregations. + """ + + # Do not bundle aggregations for an event which represents an edit or an + # annotation. It does not make sense for them to have related events. + relates_to = event.content.get("m.relates_to") + if isinstance(relates_to, (dict, frozendict)): + relation_type = relates_to.get("rel_type") + if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE): + return None + + event_id = event.event_id + room_id = event.room_id + + # The bundled aggregations to include, a mapping of relation type to a + # type-specific value. Some types include the direct return type here + # while others need more processing during serialization. + aggregations = BundledAggregations() + + annotations = await self._main_store.get_aggregation_groups_for_event( + event_id, room_id + ) + if annotations.chunk: + aggregations.annotations = await annotations.to_dict( + cast("DataStore", self) + ) + + references = await self._main_store.get_relations_for_event( + event_id, event, room_id, RelationTypes.REFERENCE, direction="f" + ) + if references.chunk: + aggregations.references = await references.to_dict(cast("DataStore", self)) + + # Store the bundled aggregations in the event metadata for later use. + return aggregations + + async def get_bundled_aggregations( + self, events: Iterable[EventBase], user_id: str + ) -> Dict[str, BundledAggregations]: + """Generate bundled aggregations for events. + + Args: + events: The iterable of events to calculate bundled aggregations for. + user_id: The user requesting the bundled aggregations. + + Returns: + A map of event ID to the bundled aggregation for the event. Not all + events may have bundled aggregations in the results. + """ + # De-duplicate events by ID to handle the same event requested multiple times. + # + # State events do not get bundled aggregations. + events_by_id = { + event.event_id: event for event in events if not event.is_state() + } + + # event ID -> bundled aggregation in non-serialized form. + results: Dict[str, BundledAggregations] = {} + + # Fetch other relations per event. + for event in events_by_id.values(): + event_result = await self._get_bundled_aggregation_for_event(event, user_id) + if event_result: + results[event.event_id] = event_result + + # Fetch any edits (but not for redacted events). + edits = await self._main_store.get_applicable_edits( + [ + event_id + for event_id, event in events_by_id.items() + if not event.internal_metadata.is_redacted() + ] + ) + for event_id, edit in edits.items(): + results.setdefault(event_id, BundledAggregations()).replace = edit + + # Fetch thread summaries. + summaries = await self._main_store.get_thread_summaries(events_by_id.keys()) + # Only fetch participated for a limited selection based on what had + # summaries. + participated = await self._main_store.get_threads_participated( + [event_id for event_id, summary in summaries.items() if summary], user_id + ) + for event_id, summary in summaries.items(): + if summary: + thread_count, latest_thread_event, edit = summary + results.setdefault( + event_id, BundledAggregations() + ).thread = _ThreadAggregation( + latest_event=latest_thread_event, + latest_edit=edit, + count=thread_count, + # If there's a thread summary it must also exist in the + # participated dictionary. 
+ current_user_participated=participated[event_id], + ) + + return results diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index b9735631fcd3..092e185c9950 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -60,8 +60,8 @@ from synapse.events.utils import copy_power_levels_contents from synapse.federation.federation_client import InvalidResponseError from synapse.handlers.federation import get_domains_from_state +from synapse.handlers.relations import BundledAggregations from synapse.rest.admin._base import assert_user_is_admin -from synapse.storage.databases.main.relations import BundledAggregations from synapse.storage.state import StateFilter from synapse.streams import EventSource from synapse.types import ( @@ -1118,6 +1118,7 @@ def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main self.storage = hs.get_storage() self.state_store = self.storage.state + self._relations_handler = hs.get_relations_handler() async def get_event_context( self, @@ -1190,7 +1191,7 @@ async def filter_evts(events: List[EventBase]) -> List[EventBase]: event = filtered[0] # Fetch the aggregations. - aggregations = await self.store.get_bundled_aggregations( + aggregations = await self._relations_handler.get_bundled_aggregations( itertools.chain(events_before, (event,), events_after), user.to_string(), ) diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index aa16e417eb30..30eddda65fc2 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -54,6 +54,7 @@ def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() self.hs = hs self._event_serializer = hs.get_event_client_serializer() + self._relations_handler = hs.get_relations_handler() self.storage = hs.get_storage() self.state_store = self.storage.state self.auth = hs.get_auth() @@ -354,7 +355,7 @@ async def _search( aggregations = None if self._msc3666_enabled: - aggregations = await self.store.get_bundled_aggregations( + aggregations = await self._relations_handler.get_bundled_aggregations( # Generate an iterable of EventBase for all the events that will be # returned, including contextual events. itertools.chain( diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index c9d6a18bd700..6c569cfb1c88 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -33,11 +33,11 @@ from synapse.api.presence import UserPresenceState from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import EventBase +from synapse.handlers.relations import BundledAggregations from synapse.logging.context import current_context from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span from synapse.push.clientformat import format_push_rules_for_user from synapse.storage.databases.main.event_push_actions import NotifCounts -from synapse.storage.databases.main.relations import BundledAggregations from synapse.storage.roommember import MemberSummary from synapse.storage.state import StateFilter from synapse.types import ( @@ -269,6 +269,7 @@ def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main self.notifier = hs.get_notifier() self.presence_handler = hs.get_presence_handler() + self._relations_handler = hs.get_relations_handler() self.event_sources = hs.get_event_sources() self.clock = hs.get_clock() self.state = hs.get_state_handler() @@ -638,8 +639,10 @@ async def _load_filtered_recents( # as clients will have all the necessary information. 
bundled_aggregations = None if limited or newly_joined_room: - bundled_aggregations = await self.store.get_bundled_aggregations( - recents, sync_config.user.to_string() + bundled_aggregations = ( + await self._relations_handler.get_bundled_aggregations( + recents, sync_config.user.to_string() + ) ) return TimelineBatch( diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 8a06ab8c5f05..47e152c8cc7a 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -645,6 +645,7 @@ def __init__(self, hs: "HomeServer"): self._store = hs.get_datastores().main self.event_handler = hs.get_event_handler() self._event_serializer = hs.get_event_client_serializer() + self._relations_handler = hs.get_relations_handler() self.auth = hs.get_auth() async def on_GET( @@ -663,7 +664,7 @@ async def on_GET( if event: # Ensure there are bundled aggregations available. - aggregations = await self._store.get_bundled_aggregations( + aggregations = await self._relations_handler.get_bundled_aggregations( [event], requester.user.to_string() ) diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index af2334a65ea0..b2295fd51f60 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -27,7 +27,6 @@ ) import attr -from frozendict import frozendict from synapse.api.constants import RelationTypes from synapse.events import EventBase @@ -41,45 +40,15 @@ from synapse.storage.databases.main.stream import generate_pagination_where_clause from synapse.storage.engines import PostgresEngine from synapse.storage.relations import AggregationPaginationToken, PaginationChunk -from synapse.types import JsonDict, RoomStreamToken, StreamToken +from synapse.types import RoomStreamToken, StreamToken from synapse.util.caches.descriptors import cached, cachedList if TYPE_CHECKING: from synapse.server import HomeServer - from synapse.storage.databases.main import DataStore logger = logging.getLogger(__name__) -@attr.s(slots=True, frozen=True, auto_attribs=True) -class _ThreadAggregation: - # The latest event in the thread. - latest_event: EventBase - # The latest edit to the latest event in the thread. - latest_edit: Optional[EventBase] - # The total number of events in the thread. - count: int - # True if the current user has sent an event to the thread. - current_user_participated: bool - - -@attr.s(slots=True, auto_attribs=True) -class BundledAggregations: - """ - The bundled aggregations for an event. - - Some values require additional processing during serialization. 
- """ - - annotations: Optional[JsonDict] = None - references: Optional[JsonDict] = None - replace: Optional[EventBase] = None - thread: Optional[_ThreadAggregation] = None - - def __bool__(self) -> bool: - return bool(self.annotations or self.references or self.replace or self.thread) - - class RelationsWorkerStore(SQLBaseStore): def __init__( self, @@ -384,7 +353,7 @@ def get_applicable_edit(self, event_id: str) -> Optional[EventBase]: raise NotImplementedError() @cachedList(cached_method_name="get_applicable_edit", list_name="event_ids") - async def _get_applicable_edits( + async def get_applicable_edits( self, event_ids: Collection[str] ) -> Dict[str, Optional[EventBase]]: """Get the most recent edit (if any) that has happened for the given @@ -473,7 +442,7 @@ def get_thread_summary(self, event_id: str) -> Optional[Tuple[int, EventBase]]: raise NotImplementedError() @cachedList(cached_method_name="get_thread_summary", list_name="event_ids") - async def _get_thread_summaries( + async def get_thread_summaries( self, event_ids: Collection[str] ) -> Dict[str, Optional[Tuple[int, EventBase, Optional[EventBase]]]]: """Get the number of threaded replies, the latest reply (if any), and the latest edit for that reply for the given event. @@ -587,7 +556,7 @@ def _get_thread_summaries_txn( latest_events = await self.get_events(latest_event_ids.values()) # type: ignore[attr-defined] # Check to see if any of those events are edited. - latest_edits = await self._get_applicable_edits(latest_event_ids.values()) + latest_edits = await self.get_applicable_edits(latest_event_ids.values()) # Map to the event IDs to the thread summary. # @@ -610,7 +579,7 @@ def get_thread_participated(self, event_id: str, user_id: str) -> bool: raise NotImplementedError() @cachedList(cached_method_name="get_thread_participated", list_name="event_ids") - async def _get_threads_participated( + async def get_threads_participated( self, event_ids: Collection[str], user_id: str ) -> Dict[str, bool]: """Get whether the requesting user participated in the given threads. @@ -766,116 +735,6 @@ def _get_if_user_has_annotated_event(txn: LoggingTransaction) -> bool: "get_if_user_has_annotated_event", _get_if_user_has_annotated_event ) - async def _get_bundled_aggregation_for_event( - self, event: EventBase, user_id: str - ) -> Optional[BundledAggregations]: - """Generate bundled aggregations for an event. - - Note that this does not use a cache, but depends on cached methods. - - Args: - event: The event to calculate bundled aggregations for. - user_id: The user requesting the bundled aggregations. - - Returns: - The bundled aggregations for an event, if bundled aggregations are - enabled and the event can have bundled aggregations. - """ - - # Do not bundle aggregations for an event which represents an edit or an - # annotation. It does not make sense for them to have related events. - relates_to = event.content.get("m.relates_to") - if isinstance(relates_to, (dict, frozendict)): - relation_type = relates_to.get("rel_type") - if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE): - return None - - event_id = event.event_id - room_id = event.room_id - - # The bundled aggregations to include, a mapping of relation type to a - # type-specific value. Some types include the direct return type here - # while others need more processing during serialization. 
- aggregations = BundledAggregations() - - annotations = await self.get_aggregation_groups_for_event(event_id, room_id) - if annotations.chunk: - aggregations.annotations = await annotations.to_dict( - cast("DataStore", self) - ) - - references = await self.get_relations_for_event( - event_id, event, room_id, RelationTypes.REFERENCE, direction="f" - ) - if references.chunk: - aggregations.references = await references.to_dict(cast("DataStore", self)) - - # Store the bundled aggregations in the event metadata for later use. - return aggregations - - async def get_bundled_aggregations( - self, events: Iterable[EventBase], user_id: str - ) -> Dict[str, BundledAggregations]: - """Generate bundled aggregations for events. - - Args: - events: The iterable of events to calculate bundled aggregations for. - user_id: The user requesting the bundled aggregations. - - Returns: - A map of event ID to the bundled aggregation for the event. Not all - events may have bundled aggregations in the results. - """ - # De-duplicate events by ID to handle the same event requested multiple times. - # - # State events do not get bundled aggregations. - events_by_id = { - event.event_id: event for event in events if not event.is_state() - } - - # event ID -> bundled aggregation in non-serialized form. - results: Dict[str, BundledAggregations] = {} - - # Fetch other relations per event. - for event in events_by_id.values(): - event_result = await self._get_bundled_aggregation_for_event(event, user_id) - if event_result: - results[event.event_id] = event_result - - # Fetch any edits (but not for redacted events). - edits = await self._get_applicable_edits( - [ - event_id - for event_id, event in events_by_id.items() - if not event.internal_metadata.is_redacted() - ] - ) - for event_id, edit in edits.items(): - results.setdefault(event_id, BundledAggregations()).replace = edit - - # Fetch thread summaries. - summaries = await self._get_thread_summaries(events_by_id.keys()) - # Only fetch participated for a limited selection based on what had - # summaries. - participated = await self._get_threads_participated( - [event_id for event_id, summary in summaries.items() if summary], user_id - ) - for event_id, summary in summaries.items(): - if summary: - thread_count, latest_thread_event, edit = summary - results.setdefault( - event_id, BundledAggregations() - ).thread = _ThreadAggregation( - latest_event=latest_thread_event, - latest_edit=edit, - count=thread_count, - # If there's a thread summary it must also exist in the - # participated dictionary. - current_user_participated=participated[event_id], - ) - - return results - class RelationsStore(RelationsWorkerStore): pass From bf9d549e3ad944e1e53a2ecc898640d690bf1eac Mon Sep 17 00:00:00 2001 From: David Robertson Date: Fri, 18 Mar 2022 19:03:46 +0000 Subject: [PATCH 083/230] Try to detect borked package installations. (#12244) * Try to detect borked package installations. Fixes #12223. Co-authored-by: Sean Quah <8349537+squahtx@users.noreply.github.com> --- changelog.d/12244.misc | 1 + synapse/util/check_dependencies.py | 24 +++++++++++++++++++++++- tests/util/test_check_dependencies.py | 15 ++++++++++++++- 3 files changed, 38 insertions(+), 2 deletions(-) create mode 100644 changelog.d/12244.misc diff --git a/changelog.d/12244.misc b/changelog.d/12244.misc new file mode 100644 index 000000000000..950d48e4c68e --- /dev/null +++ b/changelog.d/12244.misc @@ -0,0 +1 @@ +Improve error message when dependencies check finds a broken installation. 
\ No newline at end of file diff --git a/synapse/util/check_dependencies.py b/synapse/util/check_dependencies.py index 12cd8049392f..66f1da750289 100644 --- a/synapse/util/check_dependencies.py +++ b/synapse/util/check_dependencies.py @@ -128,6 +128,19 @@ def _incorrect_version( ) +def _no_reported_version(requirement: Requirement, extra: Optional[str] = None) -> str: + if extra: + return ( + f"Synapse {VERSION} needs {requirement} for {extra}, " + f"but can't determine {requirement.name}'s version" + ) + else: + return ( + f"Synapse {VERSION} needs {requirement}, " + f"but can't determine {requirement.name}'s version" + ) + + def check_requirements(extra: Optional[str] = None) -> None: """Check Synapse's dependencies are present and correctly versioned. @@ -163,8 +176,17 @@ def check_requirements(extra: Optional[str] = None) -> None: deps_unfulfilled.append(requirement.name) errors.append(_not_installed(requirement, extra)) else: + if dist.version is None: + # This shouldn't happen---it suggests a borked virtualenv. (See #12223) + # Try to give a vaguely helpful error message anyway. + # Type-ignore: the annotations don't reflect reality: see + # https://github.com/python/typeshed/issues/7513 + # https://bugs.python.org/issue47060 + deps_unfulfilled.append(requirement.name) # type: ignore[unreachable] + errors.append(_no_reported_version(requirement, extra)) + # We specify prereleases=True to allow prereleases such as RCs. - if not requirement.specifier.contains(dist.version, prereleases=True): + elif not requirement.specifier.contains(dist.version, prereleases=True): deps_unfulfilled.append(requirement.name) errors.append(_incorrect_version(requirement, dist.version, extra)) diff --git a/tests/util/test_check_dependencies.py b/tests/util/test_check_dependencies.py index 38e9f58ac6ca..5d1aa025d127 100644 --- a/tests/util/test_check_dependencies.py +++ b/tests/util/test_check_dependencies.py @@ -12,7 +12,7 @@ class DummyDistribution(metadata.Distribution): - def __init__(self, version: str): + def __init__(self, version: object): self._version = version @property @@ -30,6 +30,7 @@ def read_text(self, filename): old_release_candidate = DummyDistribution("0.1.2rc3") new = DummyDistribution("1.2.3") new_release_candidate = DummyDistribution("1.2.3rc4") +distribution_with_no_version = DummyDistribution(None) # could probably use stdlib TestCase --- no need for twisted here @@ -67,6 +68,18 @@ def test_mandatory_dependency(self) -> None: # should not raise check_requirements() + def test_version_reported_as_none(self) -> None: + """Complain if importlib.metadata.version() returns None. + + This shouldn't normally happen, but it was seen in the wild (#12223). + """ + with patch( + "synapse.util.check_dependencies.metadata.requires", + return_value=["dummypkg >= 1"], + ): + with self.mock_installed_package(distribution_with_no_version): + self.assertRaises(DependencyException, check_requirements) + def test_checks_ignore_dev_dependencies(self) -> None: """Bot generic and per-extra checks should ignore dev dependencies.""" with patch( From afa17f0eabf06087d53697eafc748f7c935fb13f Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 21 Mar 2022 11:23:32 +0000 Subject: [PATCH 084/230] Return a 404 from `/state` for an outlier (#12087) * Replace `get_state_for_pdu` with `get_state_ids_for_pdu` and `get_events_as_list`. 
* Return a 404 from `/state` and `/state_ids` for an outlier --- changelog.d/12087.bugfix | 1 + synapse/federation/federation_server.py | 7 ++- synapse/handlers/federation.py | 61 +++++++++---------------- 3 files changed, 25 insertions(+), 44 deletions(-) create mode 100644 changelog.d/12087.bugfix diff --git a/changelog.d/12087.bugfix b/changelog.d/12087.bugfix new file mode 100644 index 000000000000..6dacdddd0dcf --- /dev/null +++ b/changelog.d/12087.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug which caused the `/_matrix/federation/v1/state` and `.../state_ids` endpoints to return incorrect or invalid data when called for an event which we have stored as an "outlier". diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 482bbdd86744..af2d0f7d7932 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -22,7 +22,6 @@ Callable, Collection, Dict, - Iterable, List, Optional, Tuple, @@ -577,10 +576,10 @@ async def _on_state_ids_request_compute( async def _on_context_state_request_compute( self, room_id: str, event_id: Optional[str] ) -> Dict[str, list]: + pdus: Collection[EventBase] if event_id: - pdus: Iterable[EventBase] = await self.handler.get_state_for_pdu( - room_id, event_id - ) + event_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id) + pdus = await self.store.get_events_as_list(event_ids) else: pdus = (await self.state.get_current_state(room_id)).values() diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index db39aeabded6..350ec9c03af1 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -950,54 +950,35 @@ async def on_make_knock_request( return event - async def get_state_for_pdu(self, room_id: str, event_id: str) -> List[EventBase]: - """Returns the state at the event. i.e. not including said event.""" - - event = await self.store.get_event(event_id, check_room_id=room_id) - - state_groups = await self.state_store.get_state_groups(room_id, [event_id]) - - if state_groups: - _, state = list(state_groups.items()).pop() - results = {(e.type, e.state_key): e for e in state} - - if event.is_state(): - # Get previous state - if "replaces_state" in event.unsigned: - prev_id = event.unsigned["replaces_state"] - if prev_id != event.event_id: - prev_event = await self.store.get_event(prev_id) - results[(event.type, event.state_key)] = prev_event - else: - del results[(event.type, event.state_key)] - - res = list(results.values()) - return res - else: - return [] - async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]: """Returns the state at the event. i.e. 
not including said event.""" event = await self.store.get_event(event_id, check_room_id=room_id) + if event.internal_metadata.outlier: + raise NotFoundError("State not known at event %s" % (event_id,)) state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id]) - if state_groups: - _, state = list(state_groups.items()).pop() - results = state + # get_state_groups_ids should return exactly one result + assert len(state_groups) == 1 - if event.is_state(): - # Get previous state - if "replaces_state" in event.unsigned: - prev_id = event.unsigned["replaces_state"] - if prev_id != event.event_id: - results[(event.type, event.state_key)] = prev_id - else: - results.pop((event.type, event.state_key), None) + state_map = next(iter(state_groups.values())) - return list(results.values()) - else: - return [] + state_key = event.get_state_key() + if state_key is not None: + # the event was not rejected (get_event raises a NotFoundError for rejected + # events) so the state at the event should include the event itself. + assert ( + state_map.get((event.type, state_key)) == event.event_id + ), "State at event did not include event itself" + + # ... but we need the state *before* that event + if "replaces_state" in event.unsigned: + prev_id = event.unsigned["replaces_state"] + state_map[(event.type, state_key)] = prev_id + else: + del state_map[(event.type, state_key)] + + return list(state_map.values()) async def on_backfill_request( self, origin: str, room_id: str, pdu_list: List[str], limit: int From 1530cef19244e21d8b160bee2d925dcabbc0c4be Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Mon, 21 Mar 2022 11:52:10 +0000 Subject: [PATCH 085/230] Make it possible to enable compression for the metrics HTTP resource (#12258) * Make it possible to enable compression for the metrics HTTP resource This can provide significant bandwidth savings pulling metrics from synapse instances. * Add changelog file. * Fix type hint --- changelog.d/12258.misc | 1 + synapse/app/homeserver.py | 5 ++++- 2 files changed, 5 insertions(+), 1 deletion(-) create mode 100644 changelog.d/12258.misc diff --git a/changelog.d/12258.misc b/changelog.d/12258.misc new file mode 100644 index 000000000000..80024c8e91ee --- /dev/null +++ b/changelog.d/12258.misc @@ -0,0 +1 @@ +Compress metrics HTTP resource when enabled. Contributed by Nick @ Beeper. diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index e4dc04c0b40f..ad2b7c9515d8 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -261,7 +261,10 @@ def _configure_named_resource( resources[SERVER_KEY_V2_PREFIX] = KeyApiV2Resource(self) if name == "metrics" and self.config.metrics.enable_metrics: - resources[METRICS_PREFIX] = MetricsResource(RegistryProxy) + metrics_resource: Resource = MetricsResource(RegistryProxy) + if compress: + metrics_resource = gz_wrap(metrics_resource) + resources[METRICS_PREFIX] = metrics_resource if name == "replication": resources[REPLICATION_PREFIX] = ReplicationRestResource(self) From 6134b3079e954e15c5a92ad3b89050085197b851 Mon Sep 17 00:00:00 2001 From: reivilibre Date: Mon, 21 Mar 2022 12:16:46 +0000 Subject: [PATCH 086/230] Reword 'Choose your user name' as 'Choose your account name' in the SSO registration template, in order to comply with SIWA guidelines. 
(#12260) * Reword as 'Choose your account name' * Newsfile Signed-off-by: Olivier Wilkinson (reivilibre) --- changelog.d/12260.misc | 1 + synapse/res/templates/sso_auth_account_details.html | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 changelog.d/12260.misc diff --git a/changelog.d/12260.misc b/changelog.d/12260.misc new file mode 100644 index 000000000000..deacf034deae --- /dev/null +++ b/changelog.d/12260.misc @@ -0,0 +1 @@ +Reword 'Choose your user name' as 'Choose your account name' in the SSO registration template, in order to comply with SIWA guidelines. \ No newline at end of file diff --git a/synapse/res/templates/sso_auth_account_details.html b/synapse/res/templates/sso_auth_account_details.html index 41315e4fd4da..b231aace01e6 100644 --- a/synapse/res/templates/sso_auth_account_details.html +++ b/synapse/res/templates/sso_auth_account_details.html @@ -130,7 +130,7 @@
-			<h1>Choose your user name</h1>
+			<h1>Choose your account name</h1>
 			<p>This is required to create your account on {{ server_name }}, and you can't change this later.</p>
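For context on the metrics listener change in #12258 above: wrapping a Twisted resource so that its responses are gzip-compressed (what gz_wrap is used for in that hunk) can be done with stock Twisted APIs. A rough, standalone sketch follows; the Data resource and port are placeholders standing in for the real metrics resource and listener, not part of the patch.

from twisted.internet import reactor
from twisted.web.resource import EncodingResourceWrapper
from twisted.web.server import GzipEncoderFactory, Site
from twisted.web.static import Data

# A stand-in for the metrics resource: a static text/plain payload.
plain = Data(b"# TYPE example_metric counter\nexample_metric 1\n", "text/plain")

# Wrap it so responses are gzip-compressed when the client sends Accept-Encoding: gzip.
compressed = EncodingResourceWrapper(plain, [GzipEncoderFactory()])

reactor.listenTCP(8080, Site(compressed))
reactor.run()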
From 9d21ecf7ceab55bc19c4457b8b07401b0b1623a7 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Mon, 21 Mar 2022 14:43:16 +0100 Subject: [PATCH 087/230] Add type hints to tests files. (#12256) --- changelog.d/12256.misc | 1 + mypy.ini | 2 - tests/handlers/test_typing.py | 35 ++++++----- tests/push/test_push_rule_evaluator.py | 23 +++---- tests/storage/test_background_update.py | 48 ++++++++------- tests/storage/test_id_generators.py | 80 +++++++++++++------------ 6 files changed, 101 insertions(+), 88 deletions(-) create mode 100644 changelog.d/12256.misc diff --git a/changelog.d/12256.misc b/changelog.d/12256.misc new file mode 100644 index 000000000000..c5b635679931 --- /dev/null +++ b/changelog.d/12256.misc @@ -0,0 +1 @@ +Add type hints to tests files. \ No newline at end of file diff --git a/mypy.ini b/mypy.ini index d8b3b3f9e588..24d4ba15d452 100644 --- a/mypy.ini +++ b/mypy.ini @@ -82,9 +82,7 @@ exclude = (?x) |tests/server.py |tests/server_notices/test_resource_limits_server_notices.py |tests/state/test_v2.py - |tests/storage/test_background_update.py |tests/storage/test_base.py - |tests/storage/test_id_generators.py |tests/storage/test_roommember.py |tests/test_metrics.py |tests/test_phone_home.py diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index f91a80b9fa57..ffd5c4cb938c 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -18,11 +18,14 @@ from unittest.mock import ANY, Mock, call from twisted.internet import defer +from twisted.test.proto_helpers import MemoryReactor from twisted.web.resource import Resource from synapse.api.errors import AuthError from synapse.federation.transport.server import TransportLayerServer -from synapse.types import UserID, create_requester +from synapse.server import HomeServer +from synapse.types import JsonDict, UserID, create_requester +from synapse.util import Clock from tests import unittest from tests.test_utils import make_awaitable @@ -42,7 +45,9 @@ OTHER_ROOM_ID = "another-room" -def _expect_edu_transaction(edu_type, content, origin="test"): +def _expect_edu_transaction( + edu_type: str, content: JsonDict, origin: str = "test" +) -> JsonDict: return { "origin": origin, "origin_server_ts": 1000000, @@ -51,12 +56,12 @@ def _expect_edu_transaction(edu_type, content, origin="test"): } -def _make_edu_transaction_json(edu_type, content): +def _make_edu_transaction_json(edu_type: str, content: JsonDict) -> bytes: return json.dumps(_expect_edu_transaction(edu_type, content)).encode("utf8") class TypingNotificationsTestCase(unittest.HomeserverTestCase): - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: # we mock out the keyring so as to skip the authentication check on the # federation API call. 
mock_keyring = Mock(spec=["verify_json_for_server"]) @@ -83,7 +88,7 @@ def create_resource_dict(self) -> Dict[str, Resource]: d["/_matrix/federation"] = TransportLayerServer(self.hs) return d - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: mock_notifier = hs.get_notifier() self.on_new_event = mock_notifier.on_new_event @@ -111,24 +116,24 @@ def get_received_txn_response(*args): self.room_members = [] - async def check_user_in_room(room_id, user_id): + async def check_user_in_room(room_id: str, user_id: str) -> None: if user_id not in [u.to_string() for u in self.room_members]: raise AuthError(401, "User is not in the room") return None hs.get_auth().check_user_in_room = check_user_in_room - async def check_host_in_room(room_id, server_name): + async def check_host_in_room(room_id: str, server_name: str) -> bool: return room_id == ROOM_ID hs.get_event_auth_handler().check_host_in_room = check_host_in_room - def get_joined_hosts_for_room(room_id): + def get_joined_hosts_for_room(room_id: str): return {member.domain for member in self.room_members} self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room - async def get_users_in_room(room_id): + async def get_users_in_room(room_id: str): return {str(u) for u in self.room_members} self.datastore.get_users_in_room = get_users_in_room @@ -153,7 +158,7 @@ async def get_users_in_room(room_id): lambda *args, **kwargs: make_awaitable(None) ) - def test_started_typing_local(self): + def test_started_typing_local(self) -> None: self.room_members = [U_APPLE, U_BANANA] self.assertEqual(self.event_source.get_current_key(), 0) @@ -187,7 +192,7 @@ def test_started_typing_local(self): ) @override_config({"send_federation": True}) - def test_started_typing_remote_send(self): + def test_started_typing_remote_send(self) -> None: self.room_members = [U_APPLE, U_ONION] self.get_success( @@ -217,7 +222,7 @@ def test_started_typing_remote_send(self): try_trailing_slash_on_400=True, ) - def test_started_typing_remote_recv(self): + def test_started_typing_remote_recv(self) -> None: self.room_members = [U_APPLE, U_ONION] self.assertEqual(self.event_source.get_current_key(), 0) @@ -256,7 +261,7 @@ def test_started_typing_remote_recv(self): ], ) - def test_started_typing_remote_recv_not_in_room(self): + def test_started_typing_remote_recv_not_in_room(self) -> None: self.room_members = [U_APPLE, U_ONION] self.assertEqual(self.event_source.get_current_key(), 0) @@ -292,7 +297,7 @@ def test_started_typing_remote_recv_not_in_room(self): self.assertEqual(events[1], 0) @override_config({"send_federation": True}) - def test_stopped_typing(self): + def test_stopped_typing(self) -> None: self.room_members = [U_APPLE, U_BANANA, U_ONION] # Gut-wrenching @@ -343,7 +348,7 @@ def test_stopped_typing(self): [{"type": "m.typing", "room_id": ROOM_ID, "content": {"user_ids": []}}], ) - def test_typing_timeout(self): + def test_typing_timeout(self) -> None: self.room_members = [U_APPLE, U_BANANA] self.assertEqual(self.event_source.get_current_key(), 0) diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py index 3849beb9d6d6..5dba1870762e 100644 --- a/tests/push/test_push_rule_evaluator.py +++ b/tests/push/test_push_rule_evaluator.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Dict +from typing import Dict, Optional, Union import frozendict @@ -20,12 +20,13 @@ from synapse.events import FrozenEvent from synapse.push import push_rule_evaluator from synapse.push.push_rule_evaluator import PushRuleEvaluatorForEvent +from synapse.types import JsonDict from tests import unittest class PushRuleEvaluatorTestCase(unittest.TestCase): - def _get_evaluator(self, content): + def _get_evaluator(self, content: JsonDict) -> PushRuleEvaluatorForEvent: event = FrozenEvent( { "event_id": "$event_id", @@ -39,12 +40,12 @@ def _get_evaluator(self, content): ) room_member_count = 0 sender_power_level = 0 - power_levels = {} + power_levels: Dict[str, Union[int, Dict[str, int]]] = {} return PushRuleEvaluatorForEvent( event, room_member_count, sender_power_level, power_levels ) - def test_display_name(self): + def test_display_name(self) -> None: """Check for a matching display name in the body of the event.""" evaluator = self._get_evaluator({"body": "foo bar baz"}) @@ -71,20 +72,20 @@ def test_display_name(self): self.assertTrue(evaluator.matches(condition, "@user:test", "foo bar")) def _assert_matches( - self, condition: Dict[str, Any], content: Dict[str, Any], msg=None + self, condition: JsonDict, content: JsonDict, msg: Optional[str] = None ) -> None: evaluator = self._get_evaluator(content) self.assertTrue(evaluator.matches(condition, "@user:test", "display_name"), msg) def _assert_not_matches( - self, condition: Dict[str, Any], content: Dict[str, Any], msg=None + self, condition: JsonDict, content: JsonDict, msg: Optional[str] = None ) -> None: evaluator = self._get_evaluator(content) self.assertFalse( evaluator.matches(condition, "@user:test", "display_name"), msg ) - def test_event_match_body(self): + def test_event_match_body(self) -> None: """Check that event_match conditions on content.body work as expected""" # if the key is `content.body`, the pattern matches substrings. @@ -165,7 +166,7 @@ def test_event_match_body(self): r"? after \ should match any character", ) - def test_event_match_non_body(self): + def test_event_match_non_body(self) -> None: """Check that event_match conditions on other keys work as expected""" # if the key is anything other than 'content.body', the pattern must match the @@ -241,7 +242,7 @@ def test_event_match_non_body(self): "pattern should not match before a newline", ) - def test_no_body(self): + def test_no_body(self) -> None: """Not having a body shouldn't break the evaluator.""" evaluator = self._get_evaluator({}) @@ -250,7 +251,7 @@ def test_no_body(self): } self.assertFalse(evaluator.matches(condition, "@user:test", "foo")) - def test_invalid_body(self): + def test_invalid_body(self) -> None: """A non-string body should not break the evaluator.""" condition = { "kind": "contains_display_name", @@ -260,7 +261,7 @@ def test_invalid_body(self): evaluator = self._get_evaluator({"body": body}) self.assertFalse(evaluator.matches(condition, "@user:test", "foo")) - def test_tweaks_for_actions(self): + def test_tweaks_for_actions(self) -> None: """ This tests the behaviour of tweaks_for_actions. 
""" diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index 5cf18b690e48..fd619b64d4dd 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -17,8 +17,12 @@ import yaml from twisted.internet.defer import Deferred, ensureDeferred +from twisted.test.proto_helpers import MemoryReactor +from synapse.server import HomeServer from synapse.storage.background_updates import BackgroundUpdater +from synapse.types import JsonDict +from synapse.util import Clock from tests import unittest from tests.test_utils import make_awaitable, simple_async_mock @@ -26,7 +30,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): - def prepare(self, reactor, clock, homeserver): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.updates: BackgroundUpdater = self.hs.get_datastores().main.db_pool.updates # the base test class should have run the real bg updates for us self.assertTrue( @@ -39,7 +43,7 @@ def prepare(self, reactor, clock, homeserver): ) self.store = self.hs.get_datastores().main - async def update(self, progress, count): + async def update(self, progress: JsonDict, count: int) -> int: duration_ms = 10 await self.clock.sleep((count * duration_ms) / 1000) progress = {"my_key": progress["my_key"] + 1} @@ -51,7 +55,7 @@ async def update(self, progress, count): ) return count - def test_do_background_update(self): + def test_do_background_update(self) -> None: # the time we claim it takes to update one item when running the update duration_ms = 10 @@ -80,7 +84,7 @@ def test_do_background_update(self): # second step: complete the update # we should now get run with a much bigger number of items to update - async def update(progress, count): + async def update(progress: JsonDict, count: int) -> int: self.assertEqual(progress, {"my_key": 2}) self.assertAlmostEqual( count, @@ -110,7 +114,7 @@ async def update(progress, count): """ ) ) - def test_background_update_default_batch_set_by_config(self): + def test_background_update_default_batch_set_by_config(self) -> None: """ Test that the background update is run with the default_batch_size set by the config """ @@ -133,7 +137,7 @@ def test_background_update_default_batch_set_by_config(self): # on the first call, we should get run with the default background update size specified in the config self.update_handler.assert_called_once_with({"my_key": 1}, 20) - def test_background_update_default_sleep_behavior(self): + def test_background_update_default_sleep_behavior(self) -> None: """ Test default background update behavior, which is to sleep """ @@ -147,7 +151,7 @@ def test_background_update_default_sleep_behavior(self): self.update_handler.side_effect = self.update self.update_handler.reset_mock() - self.updates.start_doing_background_updates(), + self.updates.start_doing_background_updates() # 2: advance the reactor less than the default sleep duration (1000ms) self.reactor.pump([0.5]) @@ -167,7 +171,7 @@ def test_background_update_default_sleep_behavior(self): """ ) ) - def test_background_update_sleep_set_in_config(self): + def test_background_update_sleep_set_in_config(self) -> None: """ Test that changing the sleep time in the config changes how long it sleeps """ @@ -181,7 +185,7 @@ def test_background_update_sleep_set_in_config(self): self.update_handler.side_effect = self.update self.update_handler.reset_mock() - self.updates.start_doing_background_updates(), + self.updates.start_doing_background_updates() # 2: advance 
the reactor less than the configured sleep duration (500ms) self.reactor.pump([0.45]) @@ -201,7 +205,7 @@ def test_background_update_sleep_set_in_config(self): """ ) ) - def test_disabling_background_update_sleep(self): + def test_disabling_background_update_sleep(self) -> None: """ Test that disabling sleep in the config results in bg update not sleeping """ @@ -215,7 +219,7 @@ def test_disabling_background_update_sleep(self): self.update_handler.side_effect = self.update self.update_handler.reset_mock() - self.updates.start_doing_background_updates(), + self.updates.start_doing_background_updates() # 2: advance the reactor very little self.reactor.pump([0.025]) @@ -230,7 +234,7 @@ def test_disabling_background_update_sleep(self): """ ) ) - def test_background_update_duration_set_in_config(self): + def test_background_update_duration_set_in_config(self) -> None: """ Test that the desired duration set in the config is used in determining batch size """ @@ -254,7 +258,7 @@ def test_background_update_duration_set_in_config(self): # the first update was run with the default batch size, this should be run with 500ms as the # desired duration - async def update(progress, count): + async def update(progress: JsonDict, count: int) -> int: self.assertEqual(progress, {"my_key": 2}) self.assertAlmostEqual( count, @@ -275,7 +279,7 @@ async def update(progress, count): """ ) ) - def test_background_update_min_batch_set_in_config(self): + def test_background_update_min_batch_set_in_config(self) -> None: """ Test that the minimum batch size set in the config is used """ @@ -290,7 +294,7 @@ def test_background_update_min_batch_set_in_config(self): ) # Run the update with the long-running update item - async def update(progress, count): + async def update_long(progress: JsonDict, count: int) -> int: await self.clock.sleep((count * duration_ms) / 1000) progress = {"my_key": progress["my_key"] + 1} await self.store.db_pool.runInteraction( @@ -301,7 +305,7 @@ async def update(progress, count): ) return count - self.update_handler.side_effect = update + self.update_handler.side_effect = update_long self.update_handler.reset_mock() res = self.get_success( self.updates.do_next_background_update(False), @@ -311,25 +315,25 @@ async def update(progress, count): # the first update was run with the default batch size, this should be run with minimum batch size # as the first items took a very long time - async def update(progress, count): + async def update_short(progress: JsonDict, count: int) -> int: self.assertEqual(progress, {"my_key": 2}) self.assertEqual(count, 5) await self.updates._end_background_update("test_update") return count - self.update_handler.side_effect = update + self.update_handler.side_effect = update_short self.get_success(self.updates.do_next_background_update(False)) class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase): - def prepare(self, reactor, clock, homeserver): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.updates: BackgroundUpdater = self.hs.get_datastores().main.db_pool.updates # the base test class should have run the real bg updates for us self.assertTrue( self.get_success(self.updates.has_completed_background_updates()) ) - self.update_deferred = Deferred() + self.update_deferred: Deferred[int] = Deferred() self.update_handler = Mock(return_value=self.update_deferred) self.updates.register_background_update_handler( "test_update", self.update_handler @@ -358,7 +362,7 @@ class MockCM: ), ) - def test_controller(self): + 
def test_controller(self) -> None: store = self.hs.get_datastores().main self.get_success( store.db_pool.simple_insert( @@ -368,7 +372,7 @@ def test_controller(self): ) # Set the return value for the context manager. - enter_defer = Deferred() + enter_defer: Deferred[int] = Deferred() self._update_ctx_manager.__aenter__ = Mock(return_value=enter_defer) # Start the background update. diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py index 6ac4b93f981a..395396340bf2 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py @@ -13,9 +13,13 @@ # limitations under the License. from typing import List, Optional -from synapse.storage.database import DatabasePool +from twisted.test.proto_helpers import MemoryReactor + +from synapse.server import HomeServer +from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.engines import IncorrectDatabaseSetup from synapse.storage.util.id_generators import MultiWriterIdGenerator +from synapse.util import Clock from tests.unittest import HomeserverTestCase from tests.utils import USE_POSTGRES_FOR_TESTS @@ -25,13 +29,13 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): if not USE_POSTGRES_FOR_TESTS: skip = "Requires Postgres" - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main self.db_pool: DatabasePool = self.store.db_pool self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) - def _setup_db(self, txn): + def _setup_db(self, txn: LoggingTransaction) -> None: txn.execute("CREATE SEQUENCE foobar_seq") txn.execute( """ @@ -59,12 +63,12 @@ def _create(conn): return self.get_success_or_raise(self.db_pool.runWithConnection(_create)) - def _insert_rows(self, instance_name: str, number: int): + def _insert_rows(self, instance_name: str, number: int) -> None: """Insert N rows as the given instance, inserting with stream IDs pulled from the postgres sequence. """ - def _insert(txn): + def _insert(txn: LoggingTransaction) -> None: for _ in range(number): txn.execute( "INSERT INTO foobar VALUES (nextval('foobar_seq'), ?)", @@ -80,12 +84,12 @@ def _insert(txn): self.get_success(self.db_pool.runInteraction("_insert_rows", _insert)) - def _insert_row_with_id(self, instance_name: str, stream_id: int): + def _insert_row_with_id(self, instance_name: str, stream_id: int) -> None: """Insert one row as the given instance with given stream_id, updating the postgres sequence position to match. """ - def _insert(txn): + def _insert(txn: LoggingTransaction) -> None: txn.execute( "INSERT INTO foobar VALUES (?, ?)", ( @@ -104,7 +108,7 @@ def _insert(txn): self.get_success(self.db_pool.runInteraction("_insert_row_with_id", _insert)) - def test_empty(self): + def test_empty(self) -> None: """Test an ID generator against an empty database gives sensible current positions. """ @@ -114,7 +118,7 @@ def test_empty(self): # The table is empty so we expect an empty map for positions self.assertEqual(id_gen.get_positions(), {}) - def test_single_instance(self): + def test_single_instance(self) -> None: """Test that reads and writes from a single process are handled correctly. """ @@ -130,7 +134,7 @@ def test_single_instance(self): # Try allocating a new ID gen and check that we only see position # advanced after we leave the context manager. 
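# A minimal, self-contained sketch (not from the patch) of the behaviour these
# id_generators tests exercise: the generator's advertised position should only
# advance once the `async with` block that allocated the stream ID has exited.
# ToyIdGenerator is a hypothetical stand-in for Synapse's MultiWriterIdGenerator.
import asyncio
from contextlib import asynccontextmanager
from typing import AsyncIterator, Dict


class ToyIdGenerator:
    def __init__(self, current: int) -> None:
        self._current = current

    @asynccontextmanager
    async def get_next(self) -> AsyncIterator[int]:
        stream_id = self._current + 1
        try:
            # The caller persists its row while inside the block.
            yield stream_id
        finally:
            # Only advertise the new position once the write has finished.
            self._current = stream_id

    def get_positions(self) -> Dict[str, int]:
        return {"master": self._current}


async def main() -> None:
    id_gen = ToyIdGenerator(current=7)
    async with id_gen.get_next() as stream_id:
        assert stream_id == 8
        # Still 7: readers must not see the ID until it has been persisted.
        assert id_gen.get_positions() == {"master": 7}
    assert id_gen.get_positions() == {"master": 8}


asyncio.run(main())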
- async def _get_next_async(): + async def _get_next_async() -> None: async with id_gen.get_next() as stream_id: self.assertEqual(stream_id, 8) @@ -142,7 +146,7 @@ async def _get_next_async(): self.assertEqual(id_gen.get_positions(), {"master": 8}) self.assertEqual(id_gen.get_current_token_for_writer("master"), 8) - def test_out_of_order_finish(self): + def test_out_of_order_finish(self) -> None: """Test that IDs persisted out of order are correctly handled""" # Prefill table with 7 rows written by 'master' @@ -191,7 +195,7 @@ def test_out_of_order_finish(self): self.assertEqual(id_gen.get_positions(), {"master": 11}) self.assertEqual(id_gen.get_current_token_for_writer("master"), 11) - def test_multi_instance(self): + def test_multi_instance(self) -> None: """Test that reads and writes from multiple processes are handled correctly. """ @@ -215,7 +219,7 @@ def test_multi_instance(self): # Try allocating a new ID gen and check that we only see position # advanced after we leave the context manager. - async def _get_next_async(): + async def _get_next_async() -> None: async with first_id_gen.get_next() as stream_id: self.assertEqual(stream_id, 8) @@ -233,7 +237,7 @@ async def _get_next_async(): # ... but calling `get_next` on the second instance should give a unique # stream ID - async def _get_next_async(): + async def _get_next_async2() -> None: async with second_id_gen.get_next() as stream_id: self.assertEqual(stream_id, 9) @@ -241,7 +245,7 @@ async def _get_next_async(): second_id_gen.get_positions(), {"first": 3, "second": 7} ) - self.get_success(_get_next_async()) + self.get_success(_get_next_async2()) self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 9}) @@ -249,7 +253,7 @@ async def _get_next_async(): second_id_gen.advance("first", 8) self.assertEqual(second_id_gen.get_positions(), {"first": 8, "second": 9}) - def test_get_next_txn(self): + def test_get_next_txn(self) -> None: """Test that the `get_next_txn` function works correctly.""" # Prefill table with 7 rows written by 'master' @@ -263,7 +267,7 @@ def test_get_next_txn(self): # Try allocating a new ID gen and check that we only see position # advanced after we leave the context manager. - def _get_next_txn(txn): + def _get_next_txn(txn: LoggingTransaction) -> None: stream_id = id_gen.get_next_txn(txn) self.assertEqual(stream_id, 8) @@ -275,7 +279,7 @@ def _get_next_txn(txn): self.assertEqual(id_gen.get_positions(), {"master": 8}) self.assertEqual(id_gen.get_current_token_for_writer("master"), 8) - def test_get_persisted_upto_position(self): + def test_get_persisted_upto_position(self) -> None: """Test that `get_persisted_upto_position` correctly tracks updates to positions. """ @@ -317,7 +321,7 @@ def test_get_persisted_upto_position(self): id_gen.advance("second", 15) self.assertEqual(id_gen.get_persisted_upto_position(), 11) - def test_get_persisted_upto_position_get_next(self): + def test_get_persisted_upto_position_get_next(self) -> None: """Test that `get_persisted_upto_position` correctly tracks updates to positions when `get_next` is called. 
""" @@ -331,7 +335,7 @@ def test_get_persisted_upto_position_get_next(self): self.assertEqual(id_gen.get_persisted_upto_position(), 5) - async def _get_next_async(): + async def _get_next_async() -> None: async with id_gen.get_next() as stream_id: self.assertEqual(stream_id, 6) self.assertEqual(id_gen.get_persisted_upto_position(), 5) @@ -344,7 +348,7 @@ async def _get_next_async(): # `persisted_upto_position` in this case, then it will be correct in the # other cases that are tested above (since they'll hit the same code). - def test_restart_during_out_of_order_persistence(self): + def test_restart_during_out_of_order_persistence(self) -> None: """Test that restarting a process while another process is writing out of order updates are handled correctly. """ @@ -388,7 +392,7 @@ def test_restart_during_out_of_order_persistence(self): id_gen_worker.advance("master", 9) self.assertEqual(id_gen_worker.get_positions(), {"master": 9}) - def test_writer_config_change(self): + def test_writer_config_change(self) -> None: """Test that changing the writer config correctly works.""" self._insert_row_with_id("first", 3) @@ -421,7 +425,7 @@ def test_writer_config_change(self): # Check that we get a sane next stream ID with this new config. - async def _get_next_async(): + async def _get_next_async() -> None: async with id_gen_3.get_next() as stream_id: self.assertEqual(stream_id, 6) @@ -435,7 +439,7 @@ async def _get_next_async(): self.assertEqual(id_gen_5.get_current_token_for_writer("first"), 6) self.assertEqual(id_gen_5.get_current_token_for_writer("third"), 6) - def test_sequence_consistency(self): + def test_sequence_consistency(self) -> None: """Test that we error out if the table and sequence diverges.""" # Prefill with some rows @@ -458,13 +462,13 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): if not USE_POSTGRES_FOR_TESTS: skip = "Requires Postgres" - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main self.db_pool: DatabasePool = self.store.db_pool self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) - def _setup_db(self, txn): + def _setup_db(self, txn: LoggingTransaction) -> None: txn.execute("CREATE SEQUENCE foobar_seq") txn.execute( """ @@ -493,10 +497,10 @@ def _create(conn): return self.get_success(self.db_pool.runWithConnection(_create)) - def _insert_row(self, instance_name: str, stream_id: int): + def _insert_row(self, instance_name: str, stream_id: int) -> None: """Insert one row as the given instance with given stream_id.""" - def _insert(txn): + def _insert(txn: LoggingTransaction) -> None: txn.execute( "INSERT INTO foobar VALUES (?, ?)", ( @@ -514,13 +518,13 @@ def _insert(txn): self.get_success(self.db_pool.runInteraction("_insert_row", _insert)) - def test_single_instance(self): + def test_single_instance(self) -> None: """Test that reads and writes from a single process are handled correctly. 
""" id_gen = self._create_id_generator() - async def _get_next_async(): + async def _get_next_async() -> None: async with id_gen.get_next() as stream_id: self._insert_row("master", stream_id) @@ -530,7 +534,7 @@ async def _get_next_async(): self.assertEqual(id_gen.get_current_token_for_writer("master"), -1) self.assertEqual(id_gen.get_persisted_upto_position(), -1) - async def _get_next_async2(): + async def _get_next_async2() -> None: async with id_gen.get_next_mult(3) as stream_ids: for stream_id in stream_ids: self._insert_row("master", stream_id) @@ -548,14 +552,14 @@ async def _get_next_async2(): self.assertEqual(second_id_gen.get_current_token_for_writer("master"), -4) self.assertEqual(second_id_gen.get_persisted_upto_position(), -4) - def test_multiple_instance(self): + def test_multiple_instance(self) -> None: """Tests that having multiple instances that get advanced over federation works corretly. """ id_gen_1 = self._create_id_generator("first", writers=["first", "second"]) id_gen_2 = self._create_id_generator("second", writers=["first", "second"]) - async def _get_next_async(): + async def _get_next_async() -> None: async with id_gen_1.get_next() as stream_id: self._insert_row("first", stream_id) id_gen_2.advance("first", stream_id) @@ -567,7 +571,7 @@ async def _get_next_async(): self.assertEqual(id_gen_1.get_persisted_upto_position(), -1) self.assertEqual(id_gen_2.get_persisted_upto_position(), -1) - async def _get_next_async2(): + async def _get_next_async2() -> None: async with id_gen_2.get_next() as stream_id: self._insert_row("second", stream_id) id_gen_1.advance("second", stream_id) @@ -584,13 +588,13 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase): if not USE_POSTGRES_FOR_TESTS: skip = "Requires Postgres" - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main self.db_pool: DatabasePool = self.store.db_pool self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) - def _setup_db(self, txn): + def _setup_db(self, txn: LoggingTransaction) -> None: txn.execute("CREATE SEQUENCE foobar_seq") txn.execute( """ @@ -642,7 +646,7 @@ def _insert_rows( from the postgres sequence. """ - def _insert(txn): + def _insert(txn: LoggingTransaction) -> None: for _ in range(number): txn.execute( "INSERT INTO %s VALUES (nextval('foobar_seq'), ?)" % (table,), @@ -659,7 +663,7 @@ def _insert(txn): self.get_success(self.db_pool.runInteraction("_insert_rows", _insert)) - def test_load_existing_stream(self): + def test_load_existing_stream(self) -> None: """Test creating ID gens with multiple tables that have rows from after the position in `stream_positions` table. """ From d9bc65918e406b8ae15c42b2ea3680d2c9fb79c3 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Mon, 21 Mar 2022 17:27:59 +0000 Subject: [PATCH 088/230] Call out synctl change --- CHANGES.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/CHANGES.md b/CHANGES.md index 78498c10b551..06396c3f6f8b 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -3,6 +3,8 @@ Synapse 1.55.0rc1 (2022-03-15) This release removes a workaround introduced in Synapse 1.50.0 for Mjolnir compatibility. **This breaks compatibility with Mjolnir 1.3.1 and earlier. ([\#11700](https://github.com/matrix-org/synapse/issues/11700))**; Mjolnir users should upgrade Mjolnir before upgrading Synapse to this version. 
+This release also moves the location of the `synctl` script; see the [upgrade notes](https://github.com/matrix-org/synapse/blob/develop/docs/upgrade.md#synctl-script-has-been-moved) for more details.
+
 Features
 --------
@@ -38,6 +40,7 @@ Deprecations and Removals
 -------------------------
 - **Remove workaround introduced in Synapse 1.50.0 for Mjolnir compatibility. Breaks compatibility with Mjolnir 1.3.1 and earlier. ([\#11700](https://github.com/matrix-org/synapse/issues/11700))**
+- **`synctl` has been moved into `synapse._scripts` and is exposed as an entry point; see [upgrade notes](https://github.com/matrix-org/synapse/blob/develop/docs/upgrade.md#synctl-script-has-been-moved). ([\#12140](https://github.com/matrix-org/synapse/issues/12140))**
 - Remove backwards compatibility with pagination tokens from the `/relations` and `/aggregations` endpoints generated from Synapse < v1.52.0. ([\#12138](https://github.com/matrix-org/synapse/issues/12138))
 - The groups/communities feature in Synapse has been deprecated. ([\#12200](https://github.com/matrix-org/synapse/issues/12200))
@@ -56,7 +59,6 @@ Internal Changes
 - Fix CI not attaching source distributions and wheels to the GitHub releases. ([\#12131](https://github.com/matrix-org/synapse/issues/12131))
 - Remove unused mocks from `test_typing`. ([\#12136](https://github.com/matrix-org/synapse/issues/12136))
 - Give `scripts-dev` scripts suffixes for neater CI config. ([\#12137](https://github.com/matrix-org/synapse/issues/12137))
-- Move `synctl` into `synapse._scripts` and expose as an entry point. ([\#12140](https://github.com/matrix-org/synapse/issues/12140))
 - Move the snapcraft configuration file to `contrib`. ([\#12142](https://github.com/matrix-org/synapse/issues/12142))
 - Enable [MSC3030](https://github.com/matrix-org/matrix-doc/pull/3030) Complement tests in CI. ([\#12144](https://github.com/matrix-org/synapse/issues/12144))
 - Enable [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) Complement tests in CI. ([\#12145](https://github.com/matrix-org/synapse/issues/12145))

From 01211e0c16758f41883e42f1d3e6306b7a683e96 Mon Sep 17 00:00:00 2001
From: Michael Telatynski <7t3chguy@gmail.com>
Date: Tue, 22 Mar 2022 10:22:25 +0000
Subject: [PATCH 089/230] Tweak copy for sso account details template (#12265)

* Tweak copy for sso account details template

* Update sso footer copyright year

* Add newsfragment

Signed-off-by: Michael Telatynski <7t3chguy@gmail.com>
---
 changelog.d/12265.misc                               | 1 +
 synapse/res/templates/sso_auth_account_details.html  | 8 ++++----
 synapse/res/templates/sso_auth_account_details.js    | 2 +-
 synapse/res/templates/sso_footer.html                | 2 +-
 4 files changed, 7 insertions(+), 6 deletions(-)
 create mode 100644 changelog.d/12265.misc

diff --git a/changelog.d/12265.misc b/changelog.d/12265.misc
new file mode 100644
index 000000000000..4213f5855592
--- /dev/null
+++ b/changelog.d/12265.misc
@@ -0,0 +1 @@
+Tweak copy for default sso account details template to better adhere to mobile app store guidelines.
\ No newline at end of file
diff --git a/synapse/res/templates/sso_auth_account_details.html b/synapse/res/templates/sso_auth_account_details.html
index b231aace01e6..1ba850369a32 100644
--- a/synapse/res/templates/sso_auth_account_details.html
+++ b/synapse/res/templates/sso_auth_account_details.html
@@ -130,13 +130,13 @@
-Choose your account name
-This is required to create your account on {{ server_name }}, and you can't change this later.
+Create your account
+This is required. Continue to create your account on {{ server_name }}. You can't change this later.
 @
 :{{ server_name }}
@@ -145,7 +145,7 @@
 Choose your account name
 {% if user_attributes.avatar_url or user_attributes.display_name or user_attributes.emails %}
-{% if idp.idp_icon %}{% endif %}Information from {{ idp.idp_name }}
+{% if idp.idp_icon %}{% endif %}Optional data from {{ idp.idp_name }}
 {% if user_attributes.avatar_url %}
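The strings above are Jinja2 template copy: `{{ server_name }}` and the `{% if %}` guards are resolved when Synapse renders the page. A rough sketch of how the new wording expands, using plain `jinja2` with hypothetical example values (outside Synapse's real template loader and context):

from jinja2 import Template

copy = Template(
    "Create your account\n"
    "This is required. Continue to create your account on {{ server_name }}. "
    "You can't change this later.\n"
    "{% if idp_name %}Optional data from {{ idp_name }}{% endif %}"
)
# Render the user-facing strings with example values filled in; `idp_name` is a
# simplified stand-in for the template's `idp.idp_name` attribute.
print(copy.render(server_name="example.com", idp_name="Example SSO"))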