From 10f45d85bb355c66110b9221b0aa09010d2d03ad Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 26 Oct 2020 14:17:31 -0400 Subject: [PATCH] Add type hints for account validity handler (#8620) This also fixes a bug by fixing handling of an account which doesn't expire. --- changelog.d/8620.bugfix | 1 + mypy.ini | 1 + synapse/handlers/account_validity.py | 29 +++++++++++++++---- synapse/handlers/profile.py | 4 +-- synapse/storage/databases/main/profile.py | 4 +-- .../storage/databases/main/registration.py | 4 +-- 6 files changed, 31 insertions(+), 12 deletions(-) create mode 100644 changelog.d/8620.bugfix diff --git a/changelog.d/8620.bugfix b/changelog.d/8620.bugfix new file mode 100644 index 000000000000..c1078a3fb571 --- /dev/null +++ b/changelog.d/8620.bugfix @@ -0,0 +1 @@ +Fix a bug where the account validity endpoint would silently fail if the user ID did not have an expiration time. It now returns a 400 error. diff --git a/mypy.ini b/mypy.ini index 59d9074c3b06..1fbd8decf843 100644 --- a/mypy.ini +++ b/mypy.ini @@ -17,6 +17,7 @@ files = synapse/federation, synapse/handlers/_base.py, synapse/handlers/account_data.py, + synapse/handlers/account_validity.py, synapse/handlers/appservice.py, synapse/handlers/auth.py, synapse/handlers/cas_handler.py, diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py index fd4f762f333d..664d09da1c28 100644 --- a/synapse/handlers/account_validity.py +++ b/synapse/handlers/account_validity.py @@ -18,19 +18,22 @@ import logging from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText -from typing import List +from typing import TYPE_CHECKING, List -from synapse.api.errors import StoreError +from synapse.api.errors import StoreError, SynapseError from synapse.logging.context import make_deferred_yieldable from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.types import UserID from synapse.util import stringutils +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) class AccountValidityHandler: - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.config = hs.config self.store = self.hs.get_datastore() @@ -67,7 +70,7 @@ def __init__(self, hs): self.clock.looping_call(self._send_renewal_emails, 30 * 60 * 1000) @wrap_as_background_process("send_renewals") - async def _send_renewal_emails(self): + async def _send_renewal_emails(self) -> None: """Gets the list of users whose account is expiring in the amount of time configured in the ``renew_at`` parameter from the ``account_validity`` configuration, and sends renewal emails to all of these users as long as they @@ -81,11 +84,25 @@ async def _send_renewal_emails(self): user_id=user["user_id"], expiration_ts=user["expiration_ts_ms"] ) - async def send_renewal_email_to_user(self, user_id: str): + async def send_renewal_email_to_user(self, user_id: str) -> None: + """ + Send a renewal email for a specific user. + + Args: + user_id: The user ID to send a renewal email for. + + Raises: + SynapseError if the user is not set to renew. + """ expiration_ts = await self.store.get_expiration_ts_for_user(user_id) + + # If this user isn't set to be expired, raise an error. + if expiration_ts is None: + raise SynapseError(400, "User has no expiration time: %s" % (user_id,)) + await self._send_renewal_email(user_id, expiration_ts) - async def _send_renewal_email(self, user_id: str, expiration_ts: int): + async def _send_renewal_email(self, user_id: str, expiration_ts: int) -> None: """Sends out a renewal email to every email address attached to the given user with a unique link allowing them to renew their account. diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 3875e53c08da..14348faaf3f0 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -131,7 +131,7 @@ async def get_profile_from_cache(self, user_id: str) -> JsonDict: profile = await self.store.get_from_remote_profile_cache(user_id) return profile or {} - async def get_displayname(self, target_user: UserID) -> str: + async def get_displayname(self, target_user: UserID) -> Optional[str]: if self.hs.is_mine(target_user): try: displayname = await self.store.get_profile_displayname( @@ -218,7 +218,7 @@ async def set_displayname( await self._update_join_states(requester, target_user) - async def get_avatar_url(self, target_user: UserID) -> str: + async def get_avatar_url(self, target_user: UserID) -> Optional[str]: if self.hs.is_mine(target_user): try: avatar_url = await self.store.get_profile_avatar_url( diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py index a6d1eb908a5f..0e25ca3d7aa8 100644 --- a/synapse/storage/databases/main/profile.py +++ b/synapse/storage/databases/main/profile.py @@ -39,7 +39,7 @@ async def get_profileinfo(self, user_localpart: str) -> ProfileInfo: avatar_url=profile["avatar_url"], display_name=profile["displayname"] ) - async def get_profile_displayname(self, user_localpart: str) -> str: + async def get_profile_displayname(self, user_localpart: str) -> Optional[str]: return await self.db_pool.simple_select_one_onecol( table="profiles", keyvalues={"user_id": user_localpart}, @@ -47,7 +47,7 @@ async def get_profile_displayname(self, user_localpart: str) -> str: desc="get_profile_displayname", ) - async def get_profile_avatar_url(self, user_localpart: str) -> str: + async def get_profile_avatar_url(self, user_localpart: str) -> Optional[str]: return await self.db_pool.simple_select_one_onecol( table="profiles", keyvalues={"user_id": user_localpart}, diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index b0329e17ec6e..e7b17a738525 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -240,13 +240,13 @@ async def get_renewal_token_for_user(self, user_id: str) -> str: desc="get_renewal_token_for_user", ) - async def get_users_expiring_soon(self) -> List[Dict[str, int]]: + async def get_users_expiring_soon(self) -> List[Dict[str, Any]]: """Selects users whose account will expire in the [now, now + renew_at] time window (see configuration for account_validity for information on what renew_at refers to). Returns: - A list of dictionaries mapping user ID to expiration time (in milliseconds). + A list of dictionaries, each with a user ID and expiration time (in milliseconds). """ def select_users_txn(txn, now_ms, renew_at):