Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Implement MSC2858 support #9183

Merged
merged 8 commits into from
Jan 27, 2021
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/9183.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add support for allowing clients to pick an SSO Identity Provider ([MSC2858](https://github.com/matrix-org/matrix-doc/pull/2858).
richvdh marked this conversation as resolved.
Show resolved Hide resolved
5 changes: 5 additions & 0 deletions synapse/config/sso.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,11 @@ def read_config(self, config, **kwargs):
login_fallback_url = self.public_baseurl + "_matrix/static/client/login"
self.sso_client_whitelist.append(login_fallback_url)

# experimental support for MSC2858 (multiple SSO identity providers)
self.experimental_msc2858_support_enabled = config.get(
"experimental_msc2858_support_enabled", False
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know if we've reached any conclusions on what name this sort of setting should have. Hopefully it's temporary so doesn't matter.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if it should be put under a experimental_features section, or something, so that if we want to remove the options we can have a generic "Unknown experimental flag" warning in the logs? Or maybe we just remember that experimental_msc2858_support_enabled was an option after we remove support and log warnings if we see it.

I don't care much either way, twas just a thought.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍 let's go with experimental_features.

)

def generate_config_section(self, **kwargs):
return """\
# Additional settings to use with single-sign on systems such as OpenID Connect,
Expand Down
23 changes: 18 additions & 5 deletions synapse/handlers/sso.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from twisted.web.http import Request

from synapse.api.constants import LoginType
from synapse.api.errors import Codes, RedirectException, SynapseError
from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError
from synapse.handlers.ui_auth import UIAuthSessionDataConstants
from synapse.http import get_request_user_agent
from synapse.http.server import respond_with_html
Expand Down Expand Up @@ -235,14 +235,18 @@ def render_error(
respond_with_html(request, code, html)

async def handle_redirect_request(
self, request: SynapseRequest, client_redirect_url: bytes,
self,
request: SynapseRequest,
client_redirect_url: bytes,
idp_id: Optional[str],
) -> str:
"""Handle a request to /login/sso/redirect

Args:
request: incoming HTTP request
client_redirect_url: the URL that we should redirect the
client to after login.
idp_id: optional identity provider chosen by the client

Returns:
the URI to redirect to
Expand All @@ -252,10 +256,19 @@ async def handle_redirect_request(
400, "Homeserver not configured for SSO.", errcode=Codes.UNRECOGNIZED
)

# if the client chose an IdP, use that
idp = None # type: Optional[SsoIdentityProvider]
if idp_id:
idp = self._identity_providers.get(idp_id)
if not idp:
raise NotFoundError("Unknown identity provider")

# if we only have one auth provider, redirect to it directly
if len(self._identity_providers) == 1:
ap = next(iter(self._identity_providers.values()))
return await ap.handle_redirect_request(request, client_redirect_url)
elif len(self._identity_providers) == 1:
idp = next(iter(self._identity_providers.values()))

if idp:
return await idp.handle_redirect_request(request, client_redirect_url)

# otherwise, redirect to the IDP picker
return "/_synapse/client/pick_idp?" + urlencode(
Expand Down
44 changes: 36 additions & 8 deletions synapse/http/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,22 @@
import urllib
from http import HTTPStatus
from io import BytesIO
from typing import Any, Callable, Dict, Iterator, List, Tuple, Union
from typing import (
Any,
Awaitable,
Callable,
Dict,
Iterable,
Iterator,
List,
Pattern,
Tuple,
Union,
)

import jinja2
from canonicaljson import iterencode_canonical_json
from typing_extensions import Protocol
from zope.interface import implementer

from twisted.internet import defer, interfaces
Expand Down Expand Up @@ -168,24 +180,40 @@ async def wrapped_async_request_handler(self, request):
return preserve_fn(wrapped_async_request_handler)


class HttpServer:
# Type of a callback method for processing requests
# it is actually called with a SynapseRequest and a kwargs dict for the params,
# but I can't figure out how to represent that.
ServletCallback = Callable[
..., Union[None, Awaitable[None], Tuple[int, Any], Awaitable[Tuple[int, Any]]]
]


class HttpServer(Protocol):
""" Interface for registering callbacks on a HTTP server
"""

def register_paths(self, method, path_patterns, callback):
def register_paths(
self,
method: str,
path_patterns: Iterable[Pattern],
callback: ServletCallback,
servlet_classname: str,
) -> None:
""" Register a callback that gets fired if we receive a http request
with the given method for a path that matches the given regex.

If the regex contains groups these gets passed to the callback via
an unpacked tuple.

Args:
method (str): The method to listen to.
path_patterns (list<SRE_Pattern>): The regex used to match requests.
callback (function): The function to fire if we receive a matched
method: The HTTP method to listen to.
path_patterns: The regex used to match requests.
callback: The function to fire if we receive a matched
request. The first argument will be the request object and
subsequent arguments will be any matched groups from the regex.
This should return a tuple of (code, response).
This should return either tuple of (code, response), or None.
servlet_classname (str): The name of the handler to be used in prometheus
and opentracing logs.
"""
pass

Expand Down Expand Up @@ -354,7 +382,7 @@ def register_paths(self, method, path_patterns, callback, servlet_classname):

def _get_handler_for_request(
self, request: SynapseRequest
) -> Tuple[Callable, str, Dict[str, str]]:
) -> Tuple[ServletCallback, str, Dict[str, str]]:
"""Finds a callback method to handle the given request.

Returns:
Expand Down
55 changes: 49 additions & 6 deletions synapse/rest/client/v1/login.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.api.ratelimiting import Ratelimiter
from synapse.appservice import ApplicationService
from synapse.http.server import finish_request
from synapse.handlers.sso import SsoIdentityProvider
from synapse.http.server import HttpServer, finish_request
from synapse.http.servlet import (
RestServlet,
parse_json_object_from_request,
Expand Down Expand Up @@ -60,11 +61,14 @@ def __init__(self, hs: "HomeServer"):
self.saml2_enabled = hs.config.saml2_enabled
self.cas_enabled = hs.config.cas_enabled
self.oidc_enabled = hs.config.oidc_enabled
self._msc2858_enabled = hs.config.sso.experimental_msc2858_support_enabled

self.auth = hs.get_auth()

self.auth_handler = self.hs.get_auth_handler()
self.registration_handler = hs.get_registration_handler()
self._sso_handler = hs.get_sso_handler()

self._well_known_builder = WellKnownBuilder(hs)
self._address_ratelimiter = Ratelimiter(
clock=hs.get_clock(),
Expand All @@ -89,8 +93,17 @@ def on_GET(self, request: SynapseRequest):
flows.append({"type": LoginRestServlet.CAS_TYPE})

if self.cas_enabled or self.saml2_enabled or self.oidc_enabled:
flows.append({"type": LoginRestServlet.SSO_TYPE})
# While its valid for us to advertise this login type generally,
sso_flow = {"type": LoginRestServlet.SSO_TYPE} # type: JsonDict

if self._msc2858_enabled:
sso_flow["org.matrix.msc2858.identity_providers"] = [
_get_auth_flow_dict_for_idp(idp)
for idp in self._sso_handler.get_identity_providers().values()
]

flows.append(sso_flow)

# While it's valid for us to advertise this login type generally,
# synapse currently only gives out these tokens as part of the
# SSO login flow.
# Generally we don't want to advertise login flows that clients
Expand Down Expand Up @@ -311,8 +324,20 @@ async def _do_jwt_login(self, login_submission: JsonDict) -> Dict[str, str]:
return result


def _get_auth_flow_dict_for_idp(idp: SsoIdentityProvider) -> JsonDict:
"""Return an entry for the login flow dict

Returns an entry suitable for inclusion in "identity_providers" in the
response to GET /_matrix/client/r0/login
"""
e = {"id": idp.idp_id, "name": idp.idp_name} # type: JsonDict
if idp.idp_icon:
e["icon"] = idp.idp_icon
return e


class SsoRedirectServlet(RestServlet):
PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True)
PATTERNS = client_patterns("/login/(cas|sso)/redirect$", v1=True)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the absence of the $ was a minor bug.


def __init__(self, hs: "HomeServer"):
# make sure that the relevant handlers are instantiated, so that they
Expand All @@ -324,13 +349,31 @@ def __init__(self, hs: "HomeServer"):
if hs.config.oidc_enabled:
hs.get_oidc_handler()
self._sso_handler = hs.get_sso_handler()
self._msc2858_enabled = hs.config.sso.experimental_msc2858_support_enabled

def register(self, http_server: HttpServer) -> None:
super().register(http_server)
if self._msc2858_enabled:
# expose additional endpoint for MSC2858 support
http_server.register_paths(
"GET",
client_patterns(
"/org.matrix.msc2858/login/sso/redirect/(?P<idp_id>[A-Za-z0-9_.~-]+)$",
releases=(),
unstable=True,
),
self.on_GET,
self.__class__.__name__,
)

async def on_GET(self, request: SynapseRequest):
async def on_GET(
self, request: SynapseRequest, idp_id: Optional[str] = None
) -> None:
client_redirect_url = parse_string(
request, "redirectUrl", required=True, encoding=None
)
sso_url = await self._sso_handler.handle_redirect_request(
request, client_redirect_url
request, client_redirect_url, idp_id,
)
logger.info("Redirecting to %s", sso_url)
request.redirect(sso_url)
Expand Down
92 changes: 92 additions & 0 deletions tests/rest/client/v1/test_login.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@
# the query params in TEST_CLIENT_REDIRECT_URL
EXPECTED_CLIENT_REDIRECT_URL_PARAMS = [("<ab c>", ""), ('q" =+"', '"fö&=o"')]

# (possibly experimental) login flows we expect to appear in the list after the normal
# ones
ADDITIONAL_LOGIN_FLOWS = [{"type": "uk.half-shot.msc2778.login.application_service"}]


class LoginRestServletTestCase(unittest.HomeserverTestCase):

Expand Down Expand Up @@ -426,6 +430,57 @@ def create_resource_dict(self) -> Dict[str, Resource]:
d["/_synapse/oidc"] = OIDCResource(self.hs)
return d

def test_get_login_flows(self):
"""GET /login should return password and SSO flows"""
channel = self.make_request("GET", "/_matrix/client/r0/login")
self.assertEqual(channel.code, 200, channel.result)

expected_flows = [
{"type": "m.login.cas"},
{"type": "m.login.sso"},
{"type": "m.login.token"},
{"type": "m.login.password"},
] + ADDITIONAL_LOGIN_FLOWS

self.assertCountEqual(channel.json_body["flows"], expected_flows)

@override_config({"experimental_msc2858_support_enabled": True})
def test_get_msc2858_login_flows(self):
"""The SSO flow should include IdP info if MSC2858 is enabled"""
channel = self.make_request("GET", "/_matrix/client/r0/login")
self.assertEqual(channel.code, 200, channel.result)

# stick the flows results in a dict by type
flow_results = {}
for f in channel.json_body["flows"]:
flow_type = f["type"]
self.assertNotIn(
flow_type, flow_results, "duplicate flow type %s" % (flow_type,)
)
flow_results[flow_type] = f

self.assertIn("m.login.sso", flow_results, "m.login.sso was not returned")
sso_flow = flow_results.pop("m.login.sso")
# we should have a set of IdPs
self.assertCountEqual(
sso_flow["org.matrix.msc2858.identity_providers"],
[
{"id": "cas", "name": "CAS"},
{"id": "saml", "name": "SAML"},
{"id": "oidc-idp1", "name": "IDP1"},
{"id": "oidc", "name": "OIDC"},
],
)

# the rest of the flows are simple
expected_flows = [
{"type": "m.login.cas"},
{"type": "m.login.token"},
{"type": "m.login.password"},
] + ADDITIONAL_LOGIN_FLOWS

self.assertCountEqual(flow_results.values(), expected_flows)

def test_multi_sso_redirect(self):
"""/login/sso/redirect should redirect to an identity picker"""
# first hit the redirect url, which should redirect to our idp picker
Expand Down Expand Up @@ -564,6 +619,43 @@ def test_multi_sso_redirect_to_unknown(self):
)
self.assertEqual(channel.code, 400, channel.result)

def test_client_idp_redirect_msc2858_disabled(self):
"""If the client tries to pick an IdP but MSC2858 is disabled, return a 400"""
channel = self.make_request(
"GET",
"/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/oidc?redirectUrl="
+ urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL),
)
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")

@override_config({"experimental_msc2858_support_enabled": True})
def test_client_idp_redirect_to_unknown(self):
"""If the client tries to pick an unknown IdP, return a 404"""
channel = self.make_request(
"GET",
"/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/xxx?redirectUrl="
+ urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL),
)
self.assertEqual(channel.code, 404, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND")

@override_config({"experimental_msc2858_support_enabled": True})
def test_client_idp_redirect_to_oidc(self):
"""If the client pick a known IdP, redirect to it"""
channel = self.make_request(
"GET",
"/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/oidc?redirectUrl="
+ urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL),
)

self.assertEqual(channel.code, 302, channel.result)
oidc_uri = channel.headers.getRawHeaders("Location")[0]
oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1)

# it should redirect us to the auth page of the OIDC server
self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)

@staticmethod
def _get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str:
prefix = key + " = "
Expand Down
3 changes: 1 addition & 2 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
from synapse.config.database import DatabaseConnectionConfig
from synapse.config.homeserver import HomeServerConfig
from synapse.config.server import DEFAULT_ROOM_VERSION
from synapse.http.server import HttpServer
from synapse.logging.context import current_context, set_current_context
from synapse.server import HomeServer
from synapse.storage import DataStore
Expand Down Expand Up @@ -351,7 +350,7 @@ def getRawHeaders(name, default=None):


# This is a mock /resource/ not an entire server
class MockHttpResource(HttpServer):
class MockHttpResource:
def __init__(self, prefix=""):
self.callbacks = [] # 3-tuple of method/pattern/function
self.prefix = prefix
Expand Down