Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Include whether the requesting user has participated in a thread. (#1…
Browse files Browse the repository at this point in the history
…1577)

Per updates to MSC3440.

This is implement as a separate method since it needs to be cached
on a per-user basis, instead of a per-thread basis.
  • Loading branch information
clokep authored Jan 18, 2022
1 parent 251b556 commit 68acb0a
Show file tree
Hide file tree
Showing 9 changed files with 85 additions and 18 deletions.
1 change: 1 addition & 0 deletions changelog.d/11577.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Include whether the requesting user has participated in a thread when generating a summary for [MSC3440](https://github.com/matrix-org/matrix-doc/pull/3440).
2 changes: 1 addition & 1 deletion synapse/handlers/pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,7 @@ async def get_messages(
state_dict = await self.store.get_events(list(state_ids.values()))
state = state_dict.values()

aggregations = await self.store.get_bundled_aggregations(events)
aggregations = await self.store.get_bundled_aggregations(events, user_id)

time_now = self.clock.time_msec()

Expand Down
12 changes: 9 additions & 3 deletions synapse/handlers/room.py
Original file line number Diff line number Diff line change
Expand Up @@ -1182,12 +1182,18 @@ async def filter_evts(events: List[EventBase]) -> List[EventBase]:
results["event"] = filtered[0]

# Fetch the aggregations.
aggregations = await self.store.get_bundled_aggregations([results["event"]])
aggregations = await self.store.get_bundled_aggregations(
[results["event"]], user.to_string()
)
aggregations.update(
await self.store.get_bundled_aggregations(results["events_before"])
await self.store.get_bundled_aggregations(
results["events_before"], user.to_string()
)
)
aggregations.update(
await self.store.get_bundled_aggregations(results["events_after"])
await self.store.get_bundled_aggregations(
results["events_after"], user.to_string()
)
)
results["aggregations"] = aggregations

Expand Down
4 changes: 3 additions & 1 deletion synapse/handlers/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,7 +637,9 @@ async def _load_filtered_recents(
# as clients will have all the necessary information.
bundled_aggregations = None
if limited or newly_joined_room:
bundled_aggregations = await self.store.get_bundled_aggregations(recents)
bundled_aggregations = await self.store.get_bundled_aggregations(
recents, sync_config.user.to_string()
)

return TimelineBatch(
events=recents,
Expand Down
4 changes: 3 additions & 1 deletion synapse/rest/client/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,9 @@ async def on_GET(
)
# The relations returned for the requested event do include their
# bundled aggregations.
aggregations = await self.store.get_bundled_aggregations(events)
aggregations = await self.store.get_bundled_aggregations(
events, requester.user.to_string()
)
serialized_events = self._event_serializer.serialize_events(
events, now, bundle_aggregations=aggregations
)
Expand Down
4 changes: 3 additions & 1 deletion synapse/rest/client/room.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,7 +663,9 @@ async def on_GET(

if event:
# Ensure there are bundled aggregations available.
aggregations = await self._store.get_bundled_aggregations([event])
aggregations = await self._store.get_bundled_aggregations(
[event], requester.user.to_string()
)

time_now = self.clock.time_msec()
event_dict = self._event_serializer.serialize_event(
Expand Down
7 changes: 7 additions & 0 deletions synapse/storage/databases/main/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -1793,6 +1793,13 @@ def _handle_event_relations(
txn.call_after(
self.store.get_thread_summary.invalidate, (parent_id, event.room_id)
)
# It should be safe to only invalidate the cache if the user has not
# previously participated in the thread, but that's difficult (and
# potentially error-prone) so it is always invalidated.
txn.call_after(
self.store.get_thread_participated.invalidate,
(parent_id, event.room_id, event.sender),
)

def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase):
"""Handles keeping track of insertion events and edges/connections.
Expand Down
66 changes: 55 additions & 11 deletions synapse/storage/databases/main/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,8 +384,7 @@ def _get_applicable_edit_txn(txn: LoggingTransaction) -> Optional[str]:
async def get_thread_summary(
self, event_id: str, room_id: str
) -> Tuple[int, Optional[EventBase]]:
"""Get the number of threaded replies, the senders of those replies, and
the latest reply (if any) for the given event.
"""Get the number of threaded replies and the latest reply (if any) for the given event.
Args:
event_id: Summarize the thread related to this event ID.
Expand All @@ -398,7 +397,7 @@ async def get_thread_summary(
def _get_thread_summary_txn(
txn: LoggingTransaction,
) -> Tuple[int, Optional[str]]:
# Fetch the count of threaded events and the latest event ID.
# Fetch the latest event ID in the thread.
# TODO Should this only allow m.room.message events.
sql = """
SELECT event_id
Expand All @@ -419,6 +418,7 @@ def _get_thread_summary_txn(

latest_event_id = row[0]

# Fetch the number of threaded replies.
sql = """
SELECT COUNT(event_id)
FROM event_relations
Expand All @@ -443,6 +443,44 @@ def _get_thread_summary_txn(

return count, latest_event

@cached()
async def get_thread_participated(
self, event_id: str, room_id: str, user_id: str
) -> bool:
"""Get whether the requesting user participated in a thread.
This is separate from get_thread_summary since that can be cached across
all users while this value is specific to the requeser.
Args:
event_id: The thread related to this event ID.
room_id: The room the event belongs to.
user_id: The user requesting the summary.
Returns:
True if the requesting user participated in the thread, otherwise false.
"""

def _get_thread_summary_txn(txn: LoggingTransaction) -> bool:
# Fetch whether the requester has participated or not.
sql = """
SELECT 1
FROM event_relations
INNER JOIN events USING (event_id)
WHERE
relates_to_id = ?
AND room_id = ?
AND relation_type = ?
AND sender = ?
"""

txn.execute(sql, (event_id, room_id, RelationTypes.THREAD, user_id))
return bool(txn.fetchone())

return await self.db_pool.runInteraction(
"get_thread_summary", _get_thread_summary_txn
)

async def events_have_relations(
self,
parent_ids: List[str],
Expand Down Expand Up @@ -546,14 +584,15 @@ def _get_if_user_has_annotated_event(txn: LoggingTransaction) -> bool:
)

async def _get_bundled_aggregation_for_event(
self, event: EventBase
self, event: EventBase, user_id: str
) -> Optional[Dict[str, Any]]:
"""Generate bundled aggregations for an event.
Note that this does not use a cache, but depends on cached methods.
Args:
event: The event to calculate bundled aggregations for.
user_id: The user requesting the bundled aggregations.
Returns:
The bundled aggregations for an event, if bundled aggregations are
Expand Down Expand Up @@ -598,27 +637,32 @@ async def _get_bundled_aggregation_for_event(

# If this event is the start of a thread, include a summary of the replies.
if self._msc3440_enabled:
(
thread_count,
latest_thread_event,
) = await self.get_thread_summary(event_id, room_id)
thread_count, latest_thread_event = await self.get_thread_summary(
event_id, room_id
)
participated = await self.get_thread_participated(
event_id, room_id, user_id
)
if latest_thread_event:
aggregations[RelationTypes.THREAD] = {
# Don't bundle aggregations as this could recurse forever.
"latest_event": latest_thread_event,
"count": thread_count,
"current_user_participated": participated,
}

# Store the bundled aggregations in the event metadata for later use.
return aggregations

async def get_bundled_aggregations(
self, events: Iterable[EventBase]
self,
events: Iterable[EventBase],
user_id: str,
) -> Dict[str, Dict[str, Any]]:
"""Generate bundled aggregations for events.
Args:
events: The iterable of events to calculate bundled aggregations for.
user_id: The user requesting the bundled aggregations.
Returns:
A map of event ID to the bundled aggregation for the event. Not all
Expand All @@ -631,7 +675,7 @@ async def get_bundled_aggregations(
# TODO Parallelize.
results = {}
for event in events:
event_result = await self._get_bundled_aggregation_for_event(event)
event_result = await self._get_bundled_aggregation_for_event(event, user_id)
if event_result is not None:
results[event.event_id] = event_result

Expand Down
3 changes: 3 additions & 0 deletions tests/rest/client/test_relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,9 @@ def assert_bundle(actual):
2,
actual[RelationTypes.THREAD].get("count"),
)
self.assertTrue(
actual[RelationTypes.THREAD].get("current_user_participated")
)
# The latest thread event has some fields that don't matter.
self.assert_dict(
{
Expand Down

0 comments on commit 68acb0a

Please sign in to comment.