Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Convert identity handler to async/await. #7561

Merged
merged 1 commit into from
May 26, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/7561.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Convert the identity handler to async/await.
94 changes: 39 additions & 55 deletions synapse/handlers/identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from signedjson.sign import verify_signed_json
from unpaddedbase64 import decode_base64

from twisted.internet import defer
from twisted.internet.error import TimeoutError

from synapse.api.errors import (
Expand Down Expand Up @@ -60,8 +59,7 @@ def __init__(self, hs):
self.federation_http_client = hs.get_http_client()
self.hs = hs

@defer.inlineCallbacks
def threepid_from_creds(self, id_server, creds):
async def threepid_from_creds(self, id_server, creds):
"""
Retrieve and validate a threepid identifier from a "credentials" dictionary against a
given identity server
Expand Down Expand Up @@ -97,7 +95,7 @@ def threepid_from_creds(self, id_server, creds):
url = id_server + "/_matrix/identity/api/v1/3pid/getValidated3pid"

try:
data = yield self.http_client.get_json(url, query_params)
data = await self.http_client.get_json(url, query_params)
except TimeoutError:
raise SynapseError(500, "Timed out contacting identity server")
except HttpResponseException as e:
Expand All @@ -120,8 +118,7 @@ def threepid_from_creds(self, id_server, creds):
logger.info("%s reported non-validated threepid: %s", id_server, creds)
return None

@defer.inlineCallbacks
def bind_threepid(
async def bind_threepid(
self, client_secret, sid, mxid, id_server, id_access_token=None, use_v2=True
):
"""Bind a 3PID to an identity server
Expand Down Expand Up @@ -161,12 +158,12 @@ def bind_threepid(
try:
# Use the blacklisting http client as this call is only to identity servers
# provided by a client
data = yield self.blacklisting_http_client.post_json_get_json(
data = await self.blacklisting_http_client.post_json_get_json(
bind_url, bind_data, headers=headers
)

# Remember where we bound the threepid
yield self.store.add_user_bound_threepid(
await self.store.add_user_bound_threepid(
user_id=mxid,
medium=data["medium"],
address=data["address"],
Expand All @@ -185,13 +182,12 @@ def bind_threepid(
return data

logger.info("Got 404 when POSTing JSON %s, falling back to v1 URL", bind_url)
res = yield self.bind_threepid(
res = await self.bind_threepid(
client_secret, sid, mxid, id_server, id_access_token, use_v2=False
)
return res

@defer.inlineCallbacks
def try_unbind_threepid(self, mxid, threepid):
async def try_unbind_threepid(self, mxid, threepid):
"""Attempt to remove a 3PID from an identity server, or if one is not provided, all
identity servers we're aware the binding is present on

Expand All @@ -211,7 +207,7 @@ def try_unbind_threepid(self, mxid, threepid):
if threepid.get("id_server"):
id_servers = [threepid["id_server"]]
else:
id_servers = yield self.store.get_id_servers_user_bound(
id_servers = await self.store.get_id_servers_user_bound(
user_id=mxid, medium=threepid["medium"], address=threepid["address"]
)

Expand All @@ -221,14 +217,13 @@ def try_unbind_threepid(self, mxid, threepid):

changed = True
for id_server in id_servers:
changed &= yield self.try_unbind_threepid_with_id_server(
changed &= await self.try_unbind_threepid_with_id_server(
mxid, threepid, id_server
)

return changed

@defer.inlineCallbacks
def try_unbind_threepid_with_id_server(self, mxid, threepid, id_server):
async def try_unbind_threepid_with_id_server(self, mxid, threepid, id_server):
"""Removes a binding from an identity server

Args:
Expand Down Expand Up @@ -266,7 +261,7 @@ def try_unbind_threepid_with_id_server(self, mxid, threepid, id_server):
try:
# Use the blacklisting http client as this call is only to identity servers
# provided by a client
yield self.blacklisting_http_client.post_json_get_json(
await self.blacklisting_http_client.post_json_get_json(
url, content, headers
)
changed = True
Expand All @@ -281,7 +276,7 @@ def try_unbind_threepid_with_id_server(self, mxid, threepid, id_server):
except TimeoutError:
raise SynapseError(500, "Timed out contacting identity server")

yield self.store.remove_user_bound_threepid(
await self.store.remove_user_bound_threepid(
user_id=mxid,
medium=threepid["medium"],
address=threepid["address"],
Expand Down Expand Up @@ -376,8 +371,7 @@ async def send_threepid_validation(

return session_id

@defer.inlineCallbacks
def requestEmailToken(
async def requestEmailToken(
self, id_server, email, client_secret, send_attempt, next_link=None
):
"""
Expand Down Expand Up @@ -412,7 +406,7 @@ def requestEmailToken(
)

try:
data = yield self.http_client.post_json_get_json(
data = await self.http_client.post_json_get_json(
id_server + "/_matrix/identity/api/v1/validate/email/requestToken",
params,
)
Expand All @@ -423,8 +417,7 @@ def requestEmailToken(
except TimeoutError:
raise SynapseError(500, "Timed out contacting identity server")

@defer.inlineCallbacks
def requestMsisdnToken(
async def requestMsisdnToken(
self,
id_server,
country,
Expand Down Expand Up @@ -466,7 +459,7 @@ def requestMsisdnToken(
)

try:
data = yield self.http_client.post_json_get_json(
data = await self.http_client.post_json_get_json(
id_server + "/_matrix/identity/api/v1/validate/msisdn/requestToken",
params,
)
Expand All @@ -487,8 +480,7 @@ def requestMsisdnToken(
)
return data

@defer.inlineCallbacks
def validate_threepid_session(self, client_secret, sid):
async def validate_threepid_session(self, client_secret, sid):
"""Validates a threepid session with only the client secret and session ID
Tries validating against any configured account_threepid_delegates as well as locally.

Expand All @@ -510,12 +502,12 @@ def validate_threepid_session(self, client_secret, sid):
# Try to validate as email
if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
# Ask our delegated email identity server
validation_session = yield self.threepid_from_creds(
validation_session = await self.threepid_from_creds(
self.hs.config.account_threepid_delegate_email, threepid_creds
)
elif self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
# Get a validated session matching these details
validation_session = yield self.store.get_threepid_validation_session(
validation_session = await self.store.get_threepid_validation_session(
"email", client_secret, sid=sid, validated=True
)

Expand All @@ -525,14 +517,13 @@ def validate_threepid_session(self, client_secret, sid):
# Try to validate as msisdn
if self.hs.config.account_threepid_delegate_msisdn:
# Ask our delegated msisdn identity server
validation_session = yield self.threepid_from_creds(
validation_session = await self.threepid_from_creds(
self.hs.config.account_threepid_delegate_msisdn, threepid_creds
)

return validation_session

@defer.inlineCallbacks
def proxy_msisdn_submit_token(self, id_server, client_secret, sid, token):
async def proxy_msisdn_submit_token(self, id_server, client_secret, sid, token):
"""Proxy a POST submitToken request to an identity server for verification purposes

Args:
Expand All @@ -553,20 +544,17 @@ def proxy_msisdn_submit_token(self, id_server, client_secret, sid, token):
body = {"client_secret": client_secret, "sid": sid, "token": token}

try:
return (
yield self.http_client.post_json_get_json(
id_server + "/_matrix/identity/api/v1/validate/msisdn/submitToken",
body,
)
return await self.http_client.post_json_get_json(
id_server + "/_matrix/identity/api/v1/validate/msisdn/submitToken",
body,
)
except TimeoutError:
raise SynapseError(500, "Timed out contacting identity server")
except HttpResponseException as e:
logger.warning("Error contacting msisdn account_threepid_delegate: %s", e)
raise SynapseError(400, "Error contacting the identity server")

@defer.inlineCallbacks
def lookup_3pid(self, id_server, medium, address, id_access_token=None):
async def lookup_3pid(self, id_server, medium, address, id_access_token=None):
"""Looks up a 3pid in the passed identity server.

Args:
Expand All @@ -582,7 +570,7 @@ def lookup_3pid(self, id_server, medium, address, id_access_token=None):
"""
if id_access_token is not None:
try:
results = yield self._lookup_3pid_v2(
results = await self._lookup_3pid_v2(
id_server, id_access_token, medium, address
)
return results
Expand All @@ -601,10 +589,9 @@ def lookup_3pid(self, id_server, medium, address, id_access_token=None):
logger.warning("Error when looking up hashing details: %s", e)
return None

return (yield self._lookup_3pid_v1(id_server, medium, address))
return await self._lookup_3pid_v1(id_server, medium, address)

@defer.inlineCallbacks
def _lookup_3pid_v1(self, id_server, medium, address):
async def _lookup_3pid_v1(self, id_server, medium, address):
"""Looks up a 3pid in the passed identity server using v1 lookup.

Args:
Expand All @@ -617,15 +604,15 @@ def _lookup_3pid_v1(self, id_server, medium, address):
str: the matrix ID of the 3pid, or None if it is not recognized.
"""
try:
data = yield self.blacklisting_http_client.get_json(
data = await self.blacklisting_http_client.get_json(
"%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server),
{"medium": medium, "address": address},
)

if "mxid" in data:
if "signatures" not in data:
raise AuthError(401, "No signatures on 3pid binding")
yield self._verify_any_signature(data, id_server)
await self._verify_any_signature(data, id_server)
return data["mxid"]
except TimeoutError:
raise SynapseError(500, "Timed out contacting identity server")
Expand All @@ -634,8 +621,7 @@ def _lookup_3pid_v1(self, id_server, medium, address):

return None

@defer.inlineCallbacks
def _lookup_3pid_v2(self, id_server, id_access_token, medium, address):
async def _lookup_3pid_v2(self, id_server, id_access_token, medium, address):
"""Looks up a 3pid in the passed identity server using v2 lookup.

Args:
Expand All @@ -650,7 +636,7 @@ def _lookup_3pid_v2(self, id_server, id_access_token, medium, address):
"""
# Check what hashing details are supported by this identity server
try:
hash_details = yield self.blacklisting_http_client.get_json(
hash_details = await self.blacklisting_http_client.get_json(
"%s%s/_matrix/identity/v2/hash_details" % (id_server_scheme, id_server),
{"access_token": id_access_token},
)
Expand Down Expand Up @@ -717,7 +703,7 @@ def _lookup_3pid_v2(self, id_server, id_access_token, medium, address):
headers = {"Authorization": create_id_access_token_header(id_access_token)}

try:
lookup_results = yield self.blacklisting_http_client.post_json_get_json(
lookup_results = await self.blacklisting_http_client.post_json_get_json(
"%s%s/_matrix/identity/v2/lookup" % (id_server_scheme, id_server),
{
"addresses": [lookup_value],
Expand Down Expand Up @@ -745,13 +731,12 @@ def _lookup_3pid_v2(self, id_server, id_access_token, medium, address):
mxid = lookup_results["mappings"].get(lookup_value)
return mxid

@defer.inlineCallbacks
def _verify_any_signature(self, data, server_hostname):
async def _verify_any_signature(self, data, server_hostname):
if server_hostname not in data["signatures"]:
raise AuthError(401, "No signature from server %s" % (server_hostname,))
for key_name, signature in data["signatures"][server_hostname].items():
try:
key_data = yield self.blacklisting_http_client.get_json(
key_data = await self.blacklisting_http_client.get_json(
"%s%s/_matrix/identity/api/v1/pubkey/%s"
% (id_server_scheme, server_hostname, key_name)
)
Expand All @@ -770,8 +755,7 @@ def _verify_any_signature(self, data, server_hostname):
)
return

@defer.inlineCallbacks
def ask_id_server_for_third_party_invite(
async def ask_id_server_for_third_party_invite(
self,
requester,
id_server,
Expand Down Expand Up @@ -844,7 +828,7 @@ def ask_id_server_for_third_party_invite(
# Attempt a v2 lookup
url = base_url + "/v2/store-invite"
try:
data = yield self.blacklisting_http_client.post_json_get_json(
data = await self.blacklisting_http_client.post_json_get_json(
url,
invite_config,
{"Authorization": create_id_access_token_header(id_access_token)},
Expand All @@ -864,7 +848,7 @@ def ask_id_server_for_third_party_invite(
url = base_url + "/api/v1/store-invite"

try:
data = yield self.blacklisting_http_client.post_json_get_json(
data = await self.blacklisting_http_client.post_json_get_json(
url, invite_config
)
except TimeoutError:
Expand All @@ -882,7 +866,7 @@ def ask_id_server_for_third_party_invite(
# types. This is especially true with old instances of Sydent, see
# https://github.com/matrix-org/sydent/pull/170
try:
data = yield self.blacklisting_http_client.post_urlencoded_get_json(
data = await self.blacklisting_http_client.post_urlencoded_get_json(
url, invite_config
)
except HttpResponseException as e:
Expand Down
15 changes: 7 additions & 8 deletions synapse/handlers/ui_auth/checkers.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,7 @@ def __init__(self, hs):
self.hs = hs
self.store = hs.get_datastore()

@defer.inlineCallbacks
def _check_threepid(self, medium, authdict):
async def _check_threepid(self, medium, authdict):
if "threepid_creds" not in authdict:
raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM)

Expand All @@ -155,18 +154,18 @@ def _check_threepid(self, medium, authdict):
raise SynapseError(
400, "Phone number verification is not enabled on this homeserver"
)
threepid = yield identity_handler.threepid_from_creds(
threepid = await identity_handler.threepid_from_creds(
self.hs.config.account_threepid_delegate_msisdn, threepid_creds
)
elif medium == "email":
if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
assert self.hs.config.account_threepid_delegate_email
threepid = yield identity_handler.threepid_from_creds(
threepid = await identity_handler.threepid_from_creds(
self.hs.config.account_threepid_delegate_email, threepid_creds
)
elif self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
threepid = None
row = yield self.store.get_threepid_validation_session(
row = await self.store.get_threepid_validation_session(
medium,
threepid_creds["client_secret"],
sid=threepid_creds["sid"],
Expand All @@ -181,7 +180,7 @@ def _check_threepid(self, medium, authdict):
}

# Valid threepid returned, delete from the db
yield self.store.delete_threepid_session(threepid_creds["sid"])
await self.store.delete_threepid_session(threepid_creds["sid"])
else:
raise SynapseError(
400, "Email address verification is not enabled on this homeserver"
Expand Down Expand Up @@ -220,7 +219,7 @@ def is_enabled(self):
)

def check_auth(self, authdict, clientip):
return self._check_threepid("email", authdict)
return defer.ensureDeferred(self._check_threepid("email", authdict))


class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker):
Expand All @@ -234,7 +233,7 @@ def is_enabled(self):
return bool(self.hs.config.account_threepid_delegate_msisdn)

def check_auth(self, authdict, clientip):
return self._check_threepid("msisdn", authdict)
return defer.ensureDeferred(self._check_threepid("msisdn", authdict))
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I left these as Deferreds just to limit the scope of this PR.



INTERACTIVE_AUTH_CHECKERS = [
Expand Down