Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Convert tags and metrics databases to async/await (#8062)
Browse files Browse the repository at this point in the history
  • Loading branch information
clokep authored Aug 11, 2020
1 parent a0acdfa commit 04faa0b
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 65 deletions.
1 change: 1 addition & 0 deletions changelog.d/8062.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Convert various parts of the codebase to async/await.
20 changes: 6 additions & 14 deletions synapse/storage/databases/main/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
import typing
from collections import Counter

from twisted.internet import defer

from synapse.metrics import BucketCollector
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore
Expand Down Expand Up @@ -69,8 +67,7 @@ def fetch(txn):
res = await self.db_pool.runInteraction("read_forward_extremities", fetch)
self._current_forward_extremities_amount = Counter([x[0] for x in res])

@defer.inlineCallbacks
def count_daily_messages(self):
async def count_daily_messages(self):
"""
Returns an estimate of the number of messages sent in the last day.
Expand All @@ -88,11 +85,9 @@ def _count_messages(txn):
(count,) = txn.fetchone()
return count

ret = yield self.db_pool.runInteraction("count_messages", _count_messages)
return ret
return await self.db_pool.runInteraction("count_messages", _count_messages)

@defer.inlineCallbacks
def count_daily_sent_messages(self):
async def count_daily_sent_messages(self):
def _count_messages(txn):
# This is good enough as if you have silly characters in your own
# hostname then thats your own fault.
Expand All @@ -109,13 +104,11 @@ def _count_messages(txn):
(count,) = txn.fetchone()
return count

ret = yield self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"count_daily_sent_messages", _count_messages
)
return ret

@defer.inlineCallbacks
def count_daily_active_rooms(self):
async def count_daily_active_rooms(self):
def _count(txn):
sql = """
SELECT COALESCE(COUNT(DISTINCT room_id), 0) FROM events
Expand All @@ -126,5 +119,4 @@ def _count(txn):
(count,) = txn.fetchone()
return count

ret = yield self.db_pool.runInteraction("count_daily_active_rooms", _count)
return ret
return await self.db_pool.runInteraction("count_daily_active_rooms", _count)
103 changes: 53 additions & 50 deletions synapse/storage/databases/main/tags.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,45 +15,40 @@
# limitations under the License.

import logging
from typing import List, Tuple
from typing import Dict, List, Tuple

from canonicaljson import json

from twisted.internet import defer

from synapse.storage._base import db_to_json
from synapse.storage.databases.main.account_data import AccountDataWorkerStore
from synapse.types import JsonDict
from synapse.util.caches.descriptors import cached

logger = logging.getLogger(__name__)


class TagsWorkerStore(AccountDataWorkerStore):
@cached()
def get_tags_for_user(self, user_id):
async def get_tags_for_user(self, user_id: str) -> Dict[str, Dict[str, JsonDict]]:
"""Get all the tags for a user.
Args:
user_id(str): The user to get the tags for.
user_id: The user to get the tags for.
Returns:
A deferred dict mapping from room_id strings to dicts mapping from
tag strings to tag content.
A mapping from room_id strings to dicts mapping from tag strings to
tag content.
"""

deferred = self.db_pool.simple_select_list(
rows = await self.db_pool.simple_select_list(
"room_tags", {"user_id": user_id}, ["room_id", "tag", "content"]
)

@deferred.addCallback
def tags_by_room(rows):
tags_by_room = {}
for row in rows:
room_tags = tags_by_room.setdefault(row["room_id"], {})
room_tags[row["tag"]] = db_to_json(row["content"])
return tags_by_room

return deferred
tags_by_room = {}
for row in rows:
room_tags = tags_by_room.setdefault(row["room_id"], {})
room_tags[row["tag"]] = db_to_json(row["content"])
return tags_by_room

async def get_all_updated_tags(
self, instance_name: str, last_id: int, current_id: int, limit: int
Expand Down Expand Up @@ -127,17 +122,19 @@ def get_tag_content(txn, tag_ids):

return results, upto_token, limited

@defer.inlineCallbacks
def get_updated_tags(self, user_id, stream_id):
async def get_updated_tags(
self, user_id: str, stream_id: int
) -> Dict[str, List[str]]:
"""Get all the tags for the rooms where the tags have changed since the
given version
Args:
user_id(str): The user to get the tags for.
stream_id(int): The earliest update to get for the user.
Returns:
A deferred dict mapping from room_id strings to lists of tag
strings for all the rooms that changed since the stream_id token.
A mapping from room_id strings to lists of tag strings for all the
rooms that changed since the stream_id token.
"""

def get_updated_tags_txn(txn):
Expand All @@ -155,47 +152,53 @@ def get_updated_tags_txn(txn):
if not changed:
return {}

room_ids = yield self.db_pool.runInteraction(
room_ids = await self.db_pool.runInteraction(
"get_updated_tags", get_updated_tags_txn
)

results = {}
if room_ids:
tags_by_room = yield self.get_tags_for_user(user_id)
tags_by_room = await self.get_tags_for_user(user_id)
for room_id in room_ids:
results[room_id] = tags_by_room.get(room_id, {})

return results

def get_tags_for_room(self, user_id, room_id):
async def get_tags_for_room(
self, user_id: str, room_id: str
) -> Dict[str, JsonDict]:
"""Get all the tags for the given room
Args:
user_id(str): The user to get tags for
room_id(str): The room to get tags for
user_id: The user to get tags for
room_id: The room to get tags for
Returns:
A deferred list of string tags.
A mapping of tags to tag content.
"""
return self.db_pool.simple_select_list(
rows = await self.db_pool.simple_select_list(
table="room_tags",
keyvalues={"user_id": user_id, "room_id": room_id},
retcols=("tag", "content"),
desc="get_tags_for_room",
).addCallback(
lambda rows: {row["tag"]: db_to_json(row["content"]) for row in rows}
)
return {row["tag"]: db_to_json(row["content"]) for row in rows}


class TagsStore(TagsWorkerStore):
@defer.inlineCallbacks
def add_tag_to_room(self, user_id, room_id, tag, content):
async def add_tag_to_room(
self, user_id: str, room_id: str, tag: str, content: JsonDict
) -> int:
"""Add a tag to a room for a user.
Args:
user_id(str): The user to add a tag for.
room_id(str): The room to add a tag for.
tag(str): The tag name to add.
content(dict): A json object to associate with the tag.
user_id: The user to add a tag for.
room_id: The room to add a tag for.
tag: The tag name to add.
content: A json object to associate with the tag.
Returns:
A deferred that completes once the tag has been added.
The next account data ID.
"""
content_json = json.dumps(content)

Expand All @@ -209,18 +212,17 @@ def add_tag_txn(txn, next_id):
self._update_revision_txn(txn, user_id, room_id, next_id)

with self._account_data_id_gen.get_next() as next_id:
yield self.db_pool.runInteraction("add_tag", add_tag_txn, next_id)
await self.db_pool.runInteraction("add_tag", add_tag_txn, next_id)

self.get_tags_for_user.invalidate((user_id,))

result = self._account_data_id_gen.get_current_token()
return result
return self._account_data_id_gen.get_current_token()

@defer.inlineCallbacks
def remove_tag_from_room(self, user_id, room_id, tag):
async def remove_tag_from_room(self, user_id: str, room_id: str, tag: str) -> int:
"""Remove a tag from a room for a user.
Returns:
A deferred that completes once the tag has been removed
The next account data ID.
"""

def remove_tag_txn(txn, next_id):
Expand All @@ -232,21 +234,22 @@ def remove_tag_txn(txn, next_id):
self._update_revision_txn(txn, user_id, room_id, next_id)

with self._account_data_id_gen.get_next() as next_id:
yield self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id)
await self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id)

self.get_tags_for_user.invalidate((user_id,))

result = self._account_data_id_gen.get_current_token()
return result
return self._account_data_id_gen.get_current_token()

def _update_revision_txn(self, txn, user_id, room_id, next_id):
def _update_revision_txn(
self, txn, user_id: str, room_id: str, next_id: int
) -> None:
"""Update the latest revision of the tags for the given user and room.
Args:
txn: The database cursor
user_id(str): The ID of the user.
room_id(str): The ID of the room.
next_id(int): The the revision to advance to.
user_id: The ID of the user.
room_id: The ID of the room.
next_id: The the revision to advance to.
"""

txn.call_after(
Expand Down
5 changes: 4 additions & 1 deletion tests/server_notices/test_resource_limits_server_notices.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
)

from tests import unittest
from tests.test_utils import make_awaitable
from tests.unittest import override_config
from tests.utils import default_config

Expand Down Expand Up @@ -79,7 +80,9 @@ def prepare(self, reactor, clock, hs):
return_value=defer.succeed("!something:localhost")
)
self._rlsn._store.add_tag_to_room = Mock(return_value=defer.succeed(None))
self._rlsn._store.get_tags_for_room = Mock(return_value=defer.succeed({}))
self._rlsn._store.get_tags_for_room = Mock(
side_effect=lambda user_id, room_id: make_awaitable({})
)

@override_config({"hs_disabled": True})
def test_maybe_send_server_notice_disabled_hs(self):
Expand Down

0 comments on commit 04faa0b

Please sign in to comment.