Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Refactor the CAS handler in prep for using the abstracted SSO code. (#…
Browse files Browse the repository at this point in the history
…8958)

This makes the CAS handler look more like the SAML/OIDC handlers:

* Render errors to users instead of throwing JSON errors.
* Internal reorganization.
  • Loading branch information
clokep authored Dec 18, 2020
1 parent 56e00ca commit 4218473
Show file tree
Hide file tree
Showing 4 changed files with 162 additions and 69 deletions.
1 change: 1 addition & 0 deletions changelog.d/8958.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Properly store the mapping of external ID to Matrix ID for CAS users.
6 changes: 3 additions & 3 deletions docs/dev/cas.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ easy to run CAS implementation built on top of Django.
You should now have a Django project configured to serve CAS authentication with
a single user created.

## Configure Synapse (and Riot) to use CAS
## Configure Synapse (and Element) to use CAS

1. Modify your `homeserver.yaml` to enable CAS and point it to your locally
running Django test server:
Expand All @@ -51,9 +51,9 @@ and that the CAS server is on port 8000, both on localhost.

## Testing the configuration

Then in Riot:
Then in Element:

1. Visit the login page with a Riot pointing at your homeserver.
1. Visit the login page with a Element pointing at your homeserver.
2. Click the Single Sign-On button.
3. Login using the credentials created with `createsuperuser`.
4. You should be logged in.
Expand Down
215 changes: 151 additions & 64 deletions synapse/handlers/cas_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import urllib
from typing import TYPE_CHECKING, Dict, Optional, Tuple
import urllib.parse
from typing import TYPE_CHECKING, Dict, Optional
from xml.etree import ElementTree as ET

import attr

from twisted.web.client import PartialDownloadError

from synapse.api.errors import Codes, LoginError
from synapse.api.errors import HttpResponseException
from synapse.http.site import SynapseRequest
from synapse.types import UserID, map_username_to_mxid_localpart

Expand All @@ -29,6 +31,26 @@
logger = logging.getLogger(__name__)


class CasError(Exception):
"""Used to catch errors when validating the CAS ticket.
"""

def __init__(self, error, error_description=None):
self.error = error
self.error_description = error_description

def __str__(self):
if self.error_description:
return "{}: {}".format(self.error, self.error_description)
return self.error


@attr.s(slots=True, frozen=True)
class CasResponse:
username = attr.ib(type=str)
attributes = attr.ib(type=Dict[str, Optional[str]])


class CasHandler:
"""
Utility class for to handle the response from a CAS SSO service.
Expand All @@ -50,6 +72,8 @@ def __init__(self, hs: "HomeServer"):

self._http_client = hs.get_proxied_http_client()

self._sso_handler = hs.get_sso_handler()

def _build_service_param(self, args: Dict[str, str]) -> str:
"""
Generates a value to use as the "service" parameter when redirecting or
Expand All @@ -69,14 +93,20 @@ def _build_service_param(self, args: Dict[str, str]) -> str:

async def _validate_ticket(
self, ticket: str, service_args: Dict[str, str]
) -> Tuple[str, Optional[str]]:
) -> CasResponse:
"""
Validate a CAS ticket with the server, parse the response, and return the user and display name.
Validate a CAS ticket with the server, and return the parsed the response.
Args:
ticket: The CAS ticket from the client.
service_args: Additional arguments to include in the service URL.
Should be the same as those passed to `get_redirect_url`.
Raises:
CasError: If there's an error parsing the CAS response.
Returns:
The parsed CAS response.
"""
uri = self._cas_server_url + "/proxyValidate"
args = {
Expand All @@ -89,66 +119,65 @@ async def _validate_ticket(
# Twisted raises this error if the connection is closed,
# even if that's being used old-http style to signal end-of-data
body = pde.response
except HttpResponseException as e:
description = (
(
'Authorization server responded with a "{status}" error '
"while exchanging the authorization code."
).format(status=e.code),
)
raise CasError("server_error", description) from e

user, attributes = self._parse_cas_response(body)
displayname = attributes.pop(self._cas_displayname_attribute, None)

for required_attribute, required_value in self._cas_required_attributes.items():
# If required attribute was not in CAS Response - Forbidden
if required_attribute not in attributes:
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)

# Also need to check value
if required_value is not None:
actual_value = attributes[required_attribute]
# If required attribute value does not match expected - Forbidden
if required_value != actual_value:
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)

return user, displayname
return self._parse_cas_response(body)

def _parse_cas_response(
self, cas_response_body: bytes
) -> Tuple[str, Dict[str, Optional[str]]]:
def _parse_cas_response(self, cas_response_body: bytes) -> CasResponse:
"""
Retrieve the user and other parameters from the CAS response.
Args:
cas_response_body: The response from the CAS query.
Raises:
CasError: If there's an error parsing the CAS response.
Returns:
A tuple of the user and a mapping of other attributes.
The parsed CAS response.
"""

# Ensure the response is valid.
root = ET.fromstring(cas_response_body)
if not root.tag.endswith("serviceResponse"):
raise CasError(
"missing_service_response",
"root of CAS response is not serviceResponse",
)

success = root[0].tag.endswith("authenticationSuccess")
if not success:
raise CasError("unsucessful_response", "Unsuccessful CAS response")

# Iterate through the nodes and pull out the user and any extra attributes.
user = None
attributes = {}
try:
root = ET.fromstring(cas_response_body)
if not root.tag.endswith("serviceResponse"):
raise Exception("root of CAS response is not serviceResponse")
success = root[0].tag.endswith("authenticationSuccess")
for child in root[0]:
if child.tag.endswith("user"):
user = child.text
if child.tag.endswith("attributes"):
for attribute in child:
# ElementTree library expands the namespace in
# attribute tags to the full URL of the namespace.
# We don't care about namespace here and it will always
# be encased in curly braces, so we remove them.
tag = attribute.tag
if "}" in tag:
tag = tag.split("}")[1]
attributes[tag] = attribute.text
if user is None:
raise Exception("CAS response does not contain user")
except Exception:
logger.exception("Error parsing CAS response")
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
if not success:
raise LoginError(
401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED
)
return user, attributes
for child in root[0]:
if child.tag.endswith("user"):
user = child.text
if child.tag.endswith("attributes"):
for attribute in child:
# ElementTree library expands the namespace in
# attribute tags to the full URL of the namespace.
# We don't care about namespace here and it will always
# be encased in curly braces, so we remove them.
tag = attribute.tag
if "}" in tag:
tag = tag.split("}")[1]
attributes[tag] = attribute.text

# Ensure a user was found.
if user is None:
raise CasError("no_user", "CAS response does not contain user")

return CasResponse(user, attributes)

def get_redirect_url(self, service_args: Dict[str, str]) -> str:
"""
Expand Down Expand Up @@ -201,15 +230,76 @@ async def handle_ticket(
args["redirectUrl"] = client_redirect_url
if session:
args["session"] = session
username, user_display_name = await self._validate_ticket(ticket, args)

try:
cas_response = await self._validate_ticket(ticket, args)
except CasError as e:
logger.exception("Could not validate ticket")
self._sso_handler.render_error(request, e.error, e.error_description, 401)
return

await self._handle_cas_response(
request, cas_response, client_redirect_url, session
)

async def _handle_cas_response(
self,
request: SynapseRequest,
cas_response: CasResponse,
client_redirect_url: Optional[str],
session: Optional[str],
) -> None:
"""Handle a CAS response to a ticket request.
Assumes that the response has been validated. Maps the user onto an MXID,
registering them if necessary, and returns a response to the browser.
Args:
request: the incoming request from the browser. We'll respond to it with an
HTML page or a redirect
cas_response: The parsed CAS response.
client_redirect_url: the redirectUrl parameter from the `/cas/ticket` HTTP request, if given.
This should be the same as the redirectUrl from the original `/login/sso/redirect` request.
session: The session parameter from the `/cas/ticket` HTTP request, if given.
This should be the UI Auth session id.
"""

# Ensure that the attributes of the logged in user meet the required
# attributes.
for required_attribute, required_value in self._cas_required_attributes.items():
# If required attribute was not in CAS Response - Forbidden
if required_attribute not in cas_response.attributes:
self._sso_handler.render_error(
request,
"unauthorised",
"You are not authorised to log in here.",
401,
)
return

# Also need to check value
if required_value is not None:
actual_value = cas_response.attributes[required_attribute]
# If required attribute value does not match expected - Forbidden
if required_value != actual_value:
self._sso_handler.render_error(
request,
"unauthorised",
"You are not authorised to log in here.",
401,
)
return

# Pull out the user-agent and IP from the request.
user_agent = request.get_user_agent("")
ip_address = self.hs.get_ip_from_request(request)

# Get the matrix ID from the CAS username.
user_id = await self._map_cas_user_to_matrix_user(
username, user_display_name, user_agent, ip_address
cas_response, user_agent, ip_address
)

if session:
Expand All @@ -225,34 +315,31 @@ async def handle_ticket(
)

async def _map_cas_user_to_matrix_user(
self,
remote_user_id: str,
display_name: Optional[str],
user_agent: str,
ip_address: str,
self, cas_response: CasResponse, user_agent: str, ip_address: str,
) -> str:
"""
Given a CAS username, retrieve the user ID for it and possibly register the user.
Args:
remote_user_id: The username from the CAS response.
display_name: The display name from the CAS response.
cas_response: The parsed CAS response.
user_agent: The user agent of the client making the request.
ip_address: The IP address of the client making the request.
Returns:
The user ID associated with this response.
"""

localpart = map_username_to_mxid_localpart(remote_user_id)
localpart = map_username_to_mxid_localpart(cas_response.username)
user_id = UserID(localpart, self._hostname).to_string()
registered_user_id = await self._auth_handler.check_user_exists(user_id)

displayname = cas_response.attributes.get(self._cas_displayname_attribute, None)

# If the user does not exist, register it.
if not registered_user_id:
registered_user_id = await self._registration_handler.register_user(
localpart=localpart,
default_display_name=display_name,
default_display_name=displayname,
user_agent_ips=[(user_agent, ip_address)],
)

Expand Down
9 changes: 7 additions & 2 deletions synapse/handlers/sso.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,11 @@ def __init__(self, hs: "HomeServer"):
self._username_mapping_sessions = {} # type: Dict[str, UsernameMappingSession]

def render_error(
self, request, error: str, error_description: Optional[str] = None
self,
request: Request,
error: str,
error_description: Optional[str] = None,
code: int = 400,
) -> None:
"""Renders the error template and responds with it.
Expand All @@ -113,11 +117,12 @@ def render_error(
We'll respond with an HTML page describing the error.
error: A technical identifier for this error.
error_description: A human-readable description of the error.
code: The integer error code (an HTTP response code)
"""
html = self._error_template.render(
error=error, error_description=error_description
)
respond_with_html(request, 400, html)
respond_with_html(request, code, html)

async def get_sso_user_by_remote_user_id(
self, auth_provider_id: str, remote_user_id: str
Expand Down

0 comments on commit 4218473

Please sign in to comment.