Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Use Pydantic to validate /devices endpoints #14054

Merged
merged 4 commits into from
Oct 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/14054.feature
Original file line number Original file line Diff line number Diff line change
@@ -0,0 +1 @@
Improve validation of request bodies for the [Device Management](https://spec.matrix.org/v1.4/client-server-api/#device-management) and [MSC2697 Device Dehyrdation](https://github.com/matrix-org/matrix-spec-proposals/pull/2697) client-server API endpoints.
98 changes: 52 additions & 46 deletions synapse/rest/client/devices.py
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -14,18 +14,21 @@
# limitations under the License. # limitations under the License.


import logging import logging
from typing import TYPE_CHECKING, Tuple from typing import TYPE_CHECKING, List, Optional, Tuple

from pydantic import Extra, StrictStr


from synapse.api import errors from synapse.api import errors
from synapse.api.errors import NotFoundError from synapse.api.errors import NotFoundError
from synapse.http.server import HttpServer from synapse.http.server import HttpServer
from synapse.http.servlet import ( from synapse.http.servlet import (
RestServlet, RestServlet,
assert_params_in_dict, parse_and_validate_json_object_from_request,
parse_json_object_from_request,
) )
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns, interactive_auth_handler from synapse.rest.client._base import client_patterns, interactive_auth_handler
from synapse.rest.client.models import AuthenticationData
from synapse.rest.models import RequestBodyModel
from synapse.types import JsonDict from synapse.types import JsonDict


if TYPE_CHECKING: if TYPE_CHECKING:
Expand Down Expand Up @@ -80,35 +83,37 @@ def __init__(self, hs: "HomeServer"):
self.device_handler = hs.get_device_handler() self.device_handler = hs.get_device_handler()
self.auth_handler = hs.get_auth_handler() self.auth_handler = hs.get_auth_handler()


class PostBody(RequestBodyModel):
auth: Optional[AuthenticationData]
devices: List[StrictStr]

@interactive_auth_handler @interactive_auth_handler
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)


try: try:
body = parse_json_object_from_request(request) body = parse_and_validate_json_object_from_request(request, self.PostBody)
except errors.SynapseError as e: except errors.SynapseError as e:
if e.errcode == errors.Codes.NOT_JSON: if e.errcode == errors.Codes.NOT_JSON:
# DELETE # TODO: Can/should we remove this fallback now?
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wonder if the haproxy logs include the count of bytes in the request body. If so, might be able to see if there are still any clients relying on this ...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The httplog format (which we use) doesn't include this. But there is a field we could opt into:

| | %U | bytes_uploaded (from client to server) | numeric |

I would guess this includes the HTTP request line and headers: would have to check.

# deal with older clients which didn't pass a JSON dict # deal with older clients which didn't pass a JSON dict
# the same as those that pass an empty dict # the same as those that pass an empty dict
body = {} body = self.PostBody.parse_obj({})
else: else:
raise e raise e


assert_params_in_dict(body, ["devices"])

await self.auth_handler.validate_user_via_ui_auth( await self.auth_handler.validate_user_via_ui_auth(
requester, requester,
request, request,
body, body.dict(exclude_unset=True),
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
"remove device(s) from your account", "remove device(s) from your account",
# Users might call this multiple times in a row while cleaning up # Users might call this multiple times in a row while cleaning up
# devices, allow a single UI auth session to be re-used. # devices, allow a single UI auth session to be re-used.
can_skip_ui_auth=True, can_skip_ui_auth=True,
) )


await self.device_handler.delete_devices( await self.device_handler.delete_devices(
requester.user.to_string(), body["devices"] requester.user.to_string(), body.devices
) )
return 200, {} return 200, {}


Expand Down Expand Up @@ -147,27 +152,31 @@ async def on_GET(


return 200, device return 200, device


class DeleteBody(RequestBodyModel):
auth: Optional[AuthenticationData]

@interactive_auth_handler @interactive_auth_handler
async def on_DELETE( async def on_DELETE(
self, request: SynapseRequest, device_id: str self, request: SynapseRequest, device_id: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)


try: try:
body = parse_json_object_from_request(request) body = parse_and_validate_json_object_from_request(request, self.DeleteBody)


except errors.SynapseError as e: except errors.SynapseError as e:
if e.errcode == errors.Codes.NOT_JSON: if e.errcode == errors.Codes.NOT_JSON:
# TODO: can/should we remove this fallback now?
# deal with older clients which didn't pass a JSON dict # deal with older clients which didn't pass a JSON dict
# the same as those that pass an empty dict # the same as those that pass an empty dict
body = {} body = self.DeleteBody.parse_obj({})
else: else:
raise raise


await self.auth_handler.validate_user_via_ui_auth( await self.auth_handler.validate_user_via_ui_auth(
requester, requester,
request, request,
body, body.dict(exclude_unset=True),
"remove a device from your account", "remove a device from your account",
# Users might call this multiple times in a row while cleaning up # Users might call this multiple times in a row while cleaning up
# devices, allow a single UI auth session to be re-used. # devices, allow a single UI auth session to be re-used.
Expand All @@ -179,18 +188,33 @@ async def on_DELETE(
) )
return 200, {} return 200, {}


class PutBody(RequestBodyModel):
display_name: Optional[StrictStr]

async def on_PUT( async def on_PUT(
self, request: SynapseRequest, device_id: str self, request: SynapseRequest, device_id: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)


body = parse_json_object_from_request(request) body = parse_and_validate_json_object_from_request(request, self.PutBody)
await self.device_handler.update_device( await self.device_handler.update_device(
requester.user.to_string(), device_id, body requester.user.to_string(), device_id, body.dict()
) )
return 200, {} return 200, {}




class DehydratedDeviceDataModel(RequestBodyModel):
"""JSON blob describing a dehydrated device to be stored.

Expects other freeform fields. Use .dict() to access them.
"""

class Config:
extra = Extra.allow

algorithm: StrictStr


class DehydratedDeviceServlet(RestServlet): class DehydratedDeviceServlet(RestServlet):
"""Retrieve or store a dehydrated device. """Retrieve or store a dehydrated device.


Expand Down Expand Up @@ -246,27 +270,19 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
else: else:
raise errors.NotFoundError("No dehydrated device available") raise errors.NotFoundError("No dehydrated device available")


class PutBody(RequestBodyModel):
device_id: StrictStr
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

device_data: DehydratedDeviceDataModel
initial_device_display_name: Optional[StrictStr]

async def on_PUT(self, request: SynapseRequest) -> Tuple[int, JsonDict]: async def on_PUT(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
submission = parse_json_object_from_request(request) submission = parse_and_validate_json_object_from_request(request, self.PutBody)
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)


if "device_data" not in submission:
raise errors.SynapseError(
400,
"device_data missing",
errcode=errors.Codes.MISSING_PARAM,
)
elif not isinstance(submission["device_data"], dict):
raise errors.SynapseError(
400,
"device_data must be an object",
errcode=errors.Codes.INVALID_PARAM,
)

device_id = await self.device_handler.store_dehydrated_device( device_id = await self.device_handler.store_dehydrated_device(
requester.user.to_string(), requester.user.to_string(),
submission["device_data"], submission.device_data,
submission.get("initial_device_display_name", None), submission.initial_device_display_name,
) )
return 200, {"device_id": device_id} return 200, {"device_id": device_id}


Expand Down Expand Up @@ -300,28 +316,18 @@ def __init__(self, hs: "HomeServer"):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler() self.device_handler = hs.get_device_handler()


class PostBody(RequestBodyModel):
device_id: StrictStr

async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)


submission = parse_json_object_from_request(request) submission = parse_and_validate_json_object_from_request(request, self.PostBody)

if "device_id" not in submission:
raise errors.SynapseError(
400,
"device_id missing",
errcode=errors.Codes.MISSING_PARAM,
)
elif not isinstance(submission["device_id"], str):
raise errors.SynapseError(
400,
"device_id must be a string",
errcode=errors.Codes.INVALID_PARAM,
)


result = await self.device_handler.rehydrate_device( result = await self.device_handler.rehydrate_device(
requester.user.to_string(), requester.user.to_string(),
self.auth.get_access_token_from_request(request), self.auth.get_access_token_from_request(request),
submission["device_id"], submission.device_id,
) )


return 200, result return 200, result
Expand Down