Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Consolidate logic for parsing relations. #12693

Merged
merged 4 commits into from
May 16, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/12693.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Consolidate parsing of relation information from events.
44 changes: 44 additions & 0 deletions synapse/events/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# limitations under the License.

import abc
import collections.abc
import os
from typing import (
TYPE_CHECKING,
Expand All @@ -32,9 +33,11 @@
overload,
)

import attr
from typing_extensions import Literal
from unpaddedbase64 import encode_base64

from synapse.api.constants import RelationTypes
from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions
from synapse.types import JsonDict, RoomStreamToken
from synapse.util.caches import intern_dict
Expand Down Expand Up @@ -287,6 +290,17 @@ def is_historical(self) -> bool:
return self._dict.get("historical", False)


@attr.s(slots=True, frozen=True, auto_attribs=True)
class _EventRelation:
# The target event of the relation.
parent_id: str
# The relation type.
rel_type: str
# The aggregation key. Will be None if the rel_type is not m.annotation or is
# not a string.
aggregation_key: Optional[str]


class EventBase(metaclass=abc.ABCMeta):
@property
@abc.abstractmethod
Expand Down Expand Up @@ -415,6 +429,36 @@ def auth_event_ids(self) -> Sequence[str]:
"""
return [e for e, _ in self._dict["auth_events"]]

def relation(self) -> Optional[_EventRelation]:
"""
Parse the event's relation information.

Returns:
The event relation information, if it is valid. None, otherwise.
"""
relation = self.content.get("m.relates_to")
if not relation or not isinstance(relation, collections.abc.Mapping):
# No relation information.
return None

# Relations must have a type and parent event ID.
rel_type = relation.get("rel_type")
if not isinstance(rel_type, str):
return None

parent_id = relation.get("event_id")
if not isinstance(parent_id, str):
return None

# Annotations have a key field.
aggregation_key = None
if rel_type == RelationTypes.ANNOTATION:
aggregation_key = relation.get("key")
if not isinstance(aggregation_key, str):
aggregation_key = None

return _EventRelation(parent_id, rel_type, aggregation_key)
clokep marked this conversation as resolved.
Show resolved Hide resolved

def freeze(self) -> None:
"""'Freeze' the event dict, so it cannot be modified by accident"""

Expand Down
28 changes: 11 additions & 17 deletions synapse/handlers/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -1060,20 +1060,11 @@ async def _validate_event_relation(self, event: EventBase) -> None:
SynapseError if the event is invalid.
"""

relation = event.content.get("m.relates_to")
relation = event.relation()
if not relation:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This didn't properly handle non-dict m.relates_to.

return

relation_type = relation.get("rel_type")
if not relation_type:
return

# Ensure the parent is real.
relates_to = relation.get("event_id")
if not relates_to:
return
Comment on lines -1067 to -1074
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This didn't properly ignore non-string fields.


parent_event = await self.store.get_event(relates_to, allow_none=True)
parent_event = await self.store.get_event(relation.parent_id, allow_none=True)
if parent_event:
# And in the same room.
if parent_event.room_id != event.room_id:
Expand All @@ -1082,28 +1073,31 @@ async def _validate_event_relation(self, event: EventBase) -> None:
else:
# There must be some reason that the client knows the event exists,
# see if there are existing relations. If so, assume everything is fine.
if not await self.store.event_is_target_of_relation(relates_to):
if not await self.store.event_is_target_of_relation(relation.parent_id):
# Otherwise, the client can't know about the parent event!
raise SynapseError(400, "Can't send relation to unknown event")

# If this event is an annotation then we check that that the sender
# can't annotate the same way twice (e.g. stops users from liking an
# event multiple times).
if relation_type == RelationTypes.ANNOTATION:
aggregation_key = relation["key"]
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would raise a 500 on a missing key, now it will 400.

if relation.rel_type == RelationTypes.ANNOTATION:
aggregation_key = relation.aggregation_key

if aggregation_key is None:
raise SynapseError(400, "Missing aggregation key")

if len(aggregation_key) > 500:
raise SynapseError(400, "Aggregation key is too long")

already_exists = await self.store.has_user_annotated_event(
relates_to, event.type, aggregation_key, event.sender
relation.parent_id, event.type, aggregation_key, event.sender
)
if already_exists:
raise SynapseError(400, "Can't send same reaction twice")

# Don't attempt to start a thread if the parent event is a relation.
elif relation_type == RelationTypes.THREAD:
if await self.store.event_includes_relation(relates_to):
elif relation.rel_type == RelationTypes.THREAD:
if await self.store.event_includes_relation(relation.parent_id):
raise SynapseError(
400, "Cannot start threads from an event with a relation"
)
Expand Down
18 changes: 9 additions & 9 deletions synapse/handlers/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections.abc
import logging
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -373,20 +372,21 @@ async def get_bundled_aggregations(
if event.is_state():
continue

relates_to = event.content.get("m.relates_to")
relation_type = None
if isinstance(relates_to, collections.abc.Mapping):
relation_type = relates_to.get("rel_type")
relates_to = event.relation()
if relates_to:
# An event which is a replacement (ie edit) or annotation (ie,
# reaction) may not have any other event related to it.
if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE):
if relates_to.rel_type in (
RelationTypes.ANNOTATION,
RelationTypes.REPLACE,
):
continue

# Track the event's relation information for later.
relations_by_id[event.event_id] = relates_to.rel_type

# The event should get bundled aggregations.
events_by_id[event.event_id] = event
# Track the event's relation information for later.
if isinstance(relation_type, str):
relations_by_id[event.event_id] = relation_type

# event ID -> bundled aggregation in non-serialized form.
results: Dict[str, BundledAggregations] = {}
Expand Down
4 changes: 2 additions & 2 deletions synapse/push/bulk_push_rule_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@ def _should_count_as_unread(event: EventBase, context: EventContext) -> bool:
return False

# Exclude edits.
relates_to = event.content.get("m.relates_to", {})
if relates_to.get("rel_type") == RelationTypes.REPLACE:
Comment on lines -81 to -82
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This didn't properly handle a non-dict m.relates_to.

relates_to = event.relation()
if relates_to and relates_to.rel_type == RelationTypes.REPLACE:
return False

# Mark events that have a non-empty string body as unread.
Expand Down
45 changes: 19 additions & 26 deletions synapse/storage/databases/main/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -1807,52 +1807,45 @@ def _handle_event_relations(
txn: The current database transaction.
event: The event which might have relations.
"""
relation = event.content.get("m.relates_to")
relation = event.relation()
if not relation:
# No relations
# No relation, nothing to do.
return

# Relations must have a type and parent event ID.
rel_type = relation.get("rel_type")
if not isinstance(rel_type, str):
return

parent_id = relation.get("event_id")
if not isinstance(parent_id, str):
return

# Annotations have a key field.
aggregation_key = None
if rel_type == RelationTypes.ANNOTATION:
aggregation_key = relation.get("key")
Comment on lines -1824 to -1827
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't properly handle non-string keys.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, even after this change it would allow for invalid keys (but just set them to null). I'm unsure if that's "OK" or not. (It seems better than trying to persist a dictionary or other odd data types though!)


self.db_pool.simple_insert_txn(
txn,
table="event_relations",
values={
"event_id": event.event_id,
"relates_to_id": parent_id,
"relation_type": rel_type,
"aggregation_key": aggregation_key,
"relates_to_id": relation.parent_id,
"relation_type": relation.rel_type,
"aggregation_key": relation.aggregation_key,
},
)

txn.call_after(self.store.get_relations_for_event.invalidate, (parent_id,))
txn.call_after(
self.store.get_aggregation_groups_for_event.invalidate, (parent_id,)
self.store.get_relations_for_event.invalidate, (relation.parent_id,)
)
txn.call_after(
self.store.get_aggregation_groups_for_event.invalidate,
(relation.parent_id,),
)

if rel_type == RelationTypes.REPLACE:
txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,))
if relation.rel_type == RelationTypes.REPLACE:
txn.call_after(
self.store.get_applicable_edit.invalidate, (relation.parent_id,)
)

if rel_type == RelationTypes.THREAD:
txn.call_after(self.store.get_thread_summary.invalidate, (parent_id,))
if relation.rel_type == RelationTypes.THREAD:
txn.call_after(
self.store.get_thread_summary.invalidate, (relation.parent_id,)
)
# It should be safe to only invalidate the cache if the user has not
# previously participated in the thread, but that's difficult (and
# potentially error-prone) so it is always invalidated.
txn.call_after(
self.store.get_thread_participated.invalidate,
(parent_id, event.sender),
(relation.parent_id, event.sender),
)

def _handle_insertion_event(
Expand Down
8 changes: 6 additions & 2 deletions tests/rest/client/test_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,12 +678,13 @@ def test_unread_counts(self) -> None:
self._check_unread_count(3)

# Check that custom events with a body increase the unread counter.
self.helper.send_event(
result = self.helper.send_event(
self.room_id,
"org.matrix.custom_type",
{"body": "hello"},
tok=self.tok2,
)
event_id = result["event_id"]
self._check_unread_count(4)

# Check that edits don't increase the unread counter.
Expand All @@ -693,7 +694,10 @@ def test_unread_counts(self) -> None:
content={
"body": "hello",
"msgtype": "m.text",
"m.relates_to": {"rel_type": RelationTypes.REPLACE},
"m.relates_to": {
"rel_type": RelationTypes.REPLACE,
"event_id": event_id,
},
},
tok=self.tok2,
)
Expand Down