From 1697a80bde49adcefcd5c2a8e84ec8bcf8f38fcb Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Thu, 4 Apr 2024 11:14:09 +0200 Subject: [PATCH 01/16] Setup the basic Python -> Rust call for the rendezvous endpoint --- rust/src/lib.rs | 2 + rust/src/rendezvous/mod.rs | 54 ++++++++++++++++++++++++ synapse/rest/client/rendezvous.py | 65 ++++++++++++++++++++++++++++- synapse/synapse_rust/rendezvous.pyi | 15 +++++++ 4 files changed, 135 insertions(+), 1 deletion(-) create mode 100644 rust/src/rendezvous/mod.rs create mode 100644 synapse/synapse_rust/rendezvous.pyi diff --git a/rust/src/lib.rs b/rust/src/lib.rs index 36a3d645284..9bd1f17ad97 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -7,6 +7,7 @@ pub mod errors; pub mod events; pub mod http; pub mod push; +pub mod rendezvous; lazy_static! { static ref LOGGING_HANDLE: ResetHandle = pyo3_log::init(); @@ -45,6 +46,7 @@ fn synapse_rust(py: Python<'_>, m: &PyModule) -> PyResult<()> { acl::register_module(py, m)?; push::register_module(py, m)?; events::register_module(py, m)?; + rendezvous::register_module(py, m)?; Ok(()) } diff --git a/rust/src/rendezvous/mod.rs b/rust/src/rendezvous/mod.rs new file mode 100644 index 00000000000..6cdd27826bc --- /dev/null +++ b/rust/src/rendezvous/mod.rs @@ -0,0 +1,54 @@ +/* + * This file is licensed under the Affero General Public License (AGPL) version 3. + * + * Copyright (C) 2024 New Vector, Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * See the GNU Affero General Public License for more details: + * . 
+ * + */ + +use log::info; +use pyo3::{pyclass, pymethods, types::PyModule, PyResult, Python}; + +#[pyclass] +struct Rendezvous {} + +#[pymethods] +impl Rendezvous { + #[new] + fn new() -> Self { + Rendezvous {} + } + + fn store_session(&mut self, content_type: String, body: Vec) -> PyResult<()> { + info!( + "Received new rendezvous message: content_type: {}, len(body): {}", + content_type, + body.len() + ); + + Ok(()) + } +} + +pub fn register_module(py: Python<'_>, m: &PyModule) -> PyResult<()> { + let child_module = PyModule::new(py, "rendezvous")?; + + child_module.add_class::()?; + + m.add_submodule(child_module)?; + + // We need to manually add the module to sys.modules to make `from + // synapse.synapse_rust import rendezvous` work. + py.import("sys")? + .getattr("modules")? + .set_item("synapse.synapse_rust.rendezvous", child_module)?; + + Ok(()) +} diff --git a/synapse/rest/client/rendezvous.py b/synapse/rest/client/rendezvous.py index ed06a299870..bb079a60b01 100644 --- a/synapse/rest/client/rendezvous.py +++ b/synapse/rest/client/rendezvous.py @@ -23,10 +23,12 @@ from http.client import TEMPORARY_REDIRECT from typing import TYPE_CHECKING, Optional -from synapse.http.server import HttpServer, respond_with_redirect +from synapse.api.errors import Codes, SynapseError +from synapse.http.server import HttpServer, respond_with_json, respond_with_redirect from synapse.http.servlet import RestServlet from synapse.http.site import SynapseRequest from synapse.rest.client._base import client_patterns +from synapse.synapse_rust.rendezvous import Rendezvous if TYPE_CHECKING: from synapse.server import HomeServer @@ -97,9 +99,70 @@ async def on_POST(self, request: SynapseRequest) -> None: ) +class MSC4108RendezvousServlet(RestServlet): + PATTERNS = client_patterns( + "/org.matrix.msc4108/rendezvous$", releases=[], v1=False, unstable=True + ) + + def __init__(self, hs: "HomeServer") -> None: + super().__init__() + + self.max_upload_size = 100_000 + self._store = 
Rendezvous() + + async def on_POST(self, request: SynapseRequest) -> None: + content_type = request.getHeader("Content-Type") + if content_type is None: + raise SynapseError( + msg="Request must specify a Content-Type", + code=400, + errcode=Codes.MISSING_PARAM, + ) + + raw_content_length = request.getHeader("Content-Length") + if raw_content_length is None: + raise SynapseError( + msg="Request must specify a Content-Length", + code=400, + errcode=Codes.MISSING_PARAM, + ) + try: + content_length = int(raw_content_length) + except ValueError: + raise SynapseError(msg="Content-Length value is invalid", code=400) + if content_length > self.max_upload_size: + raise SynapseError( + msg="Upload request body is too large", + code=413, + errcode=Codes.TOO_LARGE, + ) + + if request.content is None: + raise SynapseError( + msg="Request must have a body", + code=400, + errcode=Codes.MISSING_PARAM, + ) + + body = request.content.read(content_length + 1) + if len(body) != content_length: + raise SynapseError( + msg="Request body does not match Content-Length", + code=400, + errcode=Codes.INVALID_PARAM, + ) + + self._store.store_session(content_type, body) + + respond_with_json(request, 200, {"success": True}) + + def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: if hs.config.experimental.msc3886_endpoint is not None: MSC3886RendezvousServlet(hs).register(http_server) + # TODO: gate this behind a feature flag + MSC4108RendezvousServlet(hs).register(http_server) + if hs.config.experimental.msc4108_delegation_endpoint is not None: MSC4108DelegationRendezvousServlet(hs).register(http_server) diff --git a/synapse/synapse_rust/rendezvous.pyi b/synapse/synapse_rust/rendezvous.pyi new file mode 100644 index 00000000000..43c7f06c847 --- /dev/null +++ b/synapse/synapse_rust/rendezvous.pyi @@ -0,0 +1,15 @@ +# This file is licensed under the Affero General Public License (AGPL) version 3. 
+# +# Copyright (C) 2024 New Vector, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# . + +class Rendezvous: + def __init__(self) -> None: ... + def store_session(self, content_type: str, body: bytes) -> str: ... From bb862133b0bcb41977f8c98345a6e5abdcff847d Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Fri, 5 Apr 2024 16:15:14 +0200 Subject: [PATCH 02/16] WIP rendezvous implementation --- Cargo.lock | 164 +++++++++++++++++++++++++++- rust/Cargo.toml | 4 + rust/src/rendezvous/mod.rs | 127 +++++++++++++++++++-- rust/src/rendezvous/session.rs | 79 ++++++++++++++ synapse/rest/client/rendezvous.py | 82 +++++--------- synapse/synapse_rust/rendezvous.pyi | 7 +- 6 files changed, 398 insertions(+), 65 deletions(-) create mode 100644 rust/src/rendezvous/session.rs diff --git a/Cargo.lock b/Cargo.lock index 65f4807c65a..67e0f24c42e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -59,6 +59,12 @@ dependencies = [ "generic-array", ] +[[package]] +name = "bumpalo" +version = "3.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" + [[package]] name = "bytes" version = "1.6.0" @@ -92,9 +98,9 @@ dependencies = [ [[package]] name = "digest" -version = "0.10.5" +version = "0.10.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "adfbc57365a37acbd2ebf2b64d7e69bb766e2fea813521ed536f5d0520dcf86c" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" dependencies = [ "block-buffer", "crypto-common", @@ -117,6 +123,19 @@ dependencies = [ "version_check", ] +[[package]] +name = "getrandom" +version = "0.2.14" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "94b22e06ecb0110981051723910cbf0b5f5e09a2062dd7663334ee79a9d1286c" +dependencies = [ + "cfg-if", + "js-sys", + "libc", + "wasi", + "wasm-bindgen", +] + [[package]] name = "headers" version = "0.4.0" @@ -182,6 +201,15 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4217ad341ebadf8d8e724e264f13e593e0648f5b3e94b3896a5df283be015ecc" +[[package]] +name = "js-sys" +version = "0.3.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29c15563dc2726973df627357ce0c9ddddbea194836909d655df6a75d2cf296d" +dependencies = [ + "wasm-bindgen", +] + [[package]] name = "lazy_static" version = "1.4.0" @@ -266,6 +294,12 @@ version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7170ef9988bc169ba16dd36a7fa041e5c4cbeb6a35b76d4c03daded371eae7c0" +[[package]] +name = "ppv-lite86" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" + [[package]] name = "proc-macro2" version = "1.0.76" @@ -369,6 +403,36 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom", +] + [[package]] name = 
"redox_syscall" version = "0.2.16" @@ -461,6 +525,17 @@ dependencies = [ "digest", ] +[[package]] +name = "sha2" +version = "0.10.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "793db75ad2bcafc3ffa7c68b215fee268f537982cd901d132f89c6343f3a3dc8" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "smallvec" version = "1.10.0" @@ -489,6 +564,7 @@ name = "synapse" version = "0.1.0" dependencies = [ "anyhow", + "base64", "blake2", "bytes", "headers", @@ -496,12 +572,15 @@ dependencies = [ "http", "lazy_static", "log", + "mime", "pyo3", "pyo3-log", "pythonize", "regex", "serde", "serde_json", + "sha2", + "ulid", ] [[package]] @@ -516,6 +595,17 @@ version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dcf81ac59edc17cc8697ff311e8f5ef2d99fcbd9817b34cec66f90b6c3dfd987" +[[package]] +name = "ulid" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34778c17965aa2a08913b57e1f34db9b4a63f5de31768b55bf20d2795f921259" +dependencies = [ + "getrandom", + "rand", + "web-time", +] + [[package]] name = "unicode-ident" version = "1.0.5" @@ -534,6 +624,76 @@ version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" +[[package]] +name = "wasi" +version = "0.11.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" + +[[package]] +name = "wasm-bindgen" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4be2531df63900aeb2bca0daaaddec08491ee64ceecbee5076636a3b026795a8" +dependencies = [ + "cfg-if", + "wasm-bindgen-macro", +] + +[[package]] +name = "wasm-bindgen-backend" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"614d787b966d3989fa7bb98a654e369c762374fd3213d212cfc0251257e747da" +dependencies = [ + "bumpalo", + "log", + "once_cell", + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-macro" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1f8823de937b71b9460c0c34e25f3da88250760bec0ebac694b49997550d726" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] + +[[package]] +name = "wasm-bindgen-macro-support" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-backend", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-shared" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af190c94f2773fdb3729c55b007a722abb5384da03bc0986df4c289bf5567e96" + +[[package]] +name = "web-time" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + [[package]] name = "windows-sys" version = "0.36.1" diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 9ac766182bc..d41a216d1cc 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -23,11 +23,13 @@ name = "synapse.synapse_rust" [dependencies] anyhow = "1.0.63" +base64 = "0.21.7" bytes = "1.6.0" headers = "0.4.0" http = "1.1.0" lazy_static = "1.4.0" log = "0.4.17" +mime = "0.3.17" pyo3 = { version = "0.20.0", features = [ "macros", "anyhow", @@ -37,8 +39,10 @@ pyo3 = { version = "0.20.0", features = [ pyo3-log = "0.9.0" pythonize = "0.20.0" regex = "1.6.0" +sha2 = "0.10.8" serde = { version = "1.0.144", features = ["derive"] } serde_json = "1.0.85" +ulid = "1.1.2" [features] extension-module = ["pyo3/extension-module"] diff --git 
a/rust/src/rendezvous/mod.rs b/rust/src/rendezvous/mod.rs index 6cdd27826bc..9e9b938d9d0 100644 --- a/rust/src/rendezvous/mod.rs +++ b/rust/src/rendezvous/mod.rs @@ -13,25 +13,132 @@ * */ -use log::info; -use pyo3::{pyclass, pymethods, types::PyModule, PyResult, Python}; +use std::{collections::HashMap, time::Duration}; +use bytes::Bytes; +use headers::{ContentLength, ContentType, HeaderMapExt}; +use http::{Response, StatusCode}; +use pyo3::{pyclass, pymethods, types::PyModule, PyAny, PyResult, Python}; +use ulid::Ulid; + +use crate::{ + errors::{NotFoundError, SynapseError}, + http::{http_request_from_twisted, http_response_to_twisted, HeaderMapPyExt}, +}; + +mod session; + +use self::session::Session; + +// TODO: handle eviction +#[derive(Default)] #[pyclass] -struct Rendezvous {} +struct Rendezvous { + sessions: HashMap, +} #[pymethods] impl Rendezvous { #[new] fn new() -> Self { - Rendezvous {} + Rendezvous::default() + } + + fn handle_post(&mut self, twisted_request: &PyAny) -> PyResult<()> { + let request = http_request_from_twisted(twisted_request)?; + + let ContentLength(content_length) = request.headers().typed_get_required()?; + + if content_length > 1024 * 100 { + return Err(SynapseError::new( + StatusCode::BAD_REQUEST, + "Content-Length too large".to_owned(), + "M_INVALID_PARAM", + None, + None, + )); + } + + let content_type: ContentType = request.headers().typed_get_required()?; + + let id = Ulid::new(); + + // XXX: this is lazy + let source_uri = request.uri(); + let uri = format!("{source_uri}/{id}"); + + let body = request.into_body(); + + let session = Session::new(body, content_type.into(), Duration::from_secs(5 * 60)); + + let response = serde_json::json!({ + "uri": uri, + }) + .to_string(); + + let mut response = Response::new(response.as_bytes()); + *response.status_mut() = StatusCode::CREATED; + response.headers_mut().typed_insert(ContentType::json()); + response.headers_mut().typed_insert(session.etag()); + 
response.headers_mut().typed_insert(session.expires()); + response.headers_mut().typed_insert(session.last_modified()); + http_response_to_twisted(twisted_request, response)?; + + self.sessions.insert(id, session); + + Ok(()) + } + + fn handle_get(&mut self, twisted_request: &PyAny, id: &str) -> PyResult<()> { + let _request = http_request_from_twisted(twisted_request)?; + + // TODO: handle If-None-Match + + let id: Ulid = id.parse().map_err(|_| NotFoundError::new())?; + let session = self.sessions.get(&id).ok_or_else(NotFoundError::new)?; + + let mut response = Response::new(session.data()); + *response.status_mut() = StatusCode::OK; + response.headers_mut().typed_insert(session.content_type()); + response.headers_mut().typed_insert(session.etag()); + response.headers_mut().typed_insert(session.expires()); + response.headers_mut().typed_insert(session.last_modified()); + http_response_to_twisted(twisted_request, response)?; + + Ok(()) + } + + fn handle_put(&mut self, twisted_request: &PyAny, id: &str) -> PyResult<()> { + let request = http_request_from_twisted(twisted_request)?; + + // TODO: handle If-Match + + let content_type: ContentType = request.headers().typed_get_required()?; + let data = request.into_body(); + + let id: Ulid = id.parse().map_err(|_| NotFoundError::new())?; + let session = self.sessions.get_mut(&id).ok_or_else(NotFoundError::new)?; + session.update(data, content_type.into()); + + let mut response = Response::new(Bytes::new()); + *response.status_mut() = StatusCode::ACCEPTED; + response.headers_mut().typed_insert(session.etag()); + response.headers_mut().typed_insert(session.expires()); + response.headers_mut().typed_insert(session.last_modified()); + http_response_to_twisted(twisted_request, response)?; + + Ok(()) } - fn store_session(&mut self, content_type: String, body: Vec) -> PyResult<()> { - info!( - "Received new rendezvous message: content_type: {}, len(body): {}", - content_type, - body.len() - ); + fn handle_delete(&mut self, 
twisted_request: &PyAny, id: &str) -> PyResult<()> { + let _request = http_request_from_twisted(twisted_request)?; + + let id: Ulid = id.parse().map_err(|_| NotFoundError::new())?; + let _session = self.sessions.remove(&id).ok_or_else(NotFoundError::new)?; + + let mut response = Response::new(Bytes::new()); + *response.status_mut() = StatusCode::NO_CONTENT; + http_response_to_twisted(twisted_request, response)?; Ok(()) } diff --git a/rust/src/rendezvous/session.rs b/rust/src/rendezvous/session.rs new file mode 100644 index 00000000000..cef92f2ecfb --- /dev/null +++ b/rust/src/rendezvous/session.rs @@ -0,0 +1,79 @@ +/* + * This file is licensed under the Affero General Public License (AGPL) version 3. + * + * Copyright (C) 2024 New Vector, Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * See the GNU Affero General Public License for more details: + * . + */ + +use std::time::{Duration, SystemTime}; + +use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _}; +use bytes::Bytes; +use headers::{ContentType, ETag, Expires, LastModified}; +use mime::Mime; +use sha2::{Digest, Sha256}; + +/// A single session, containing data, metadata, and expiry information. 
+pub struct Session { + hash: [u8; 32], + data: Bytes, + content_type: Mime, + last_modified: SystemTime, + expires: SystemTime, +} + +impl Session { + pub fn new(data: Bytes, content_type: Mime, ttl: Duration) -> Self { + let hash = Sha256::digest(&data).into(); + let now = SystemTime::now(); + Self { + hash, + data, + content_type, + expires: now + ttl, + last_modified: now, + } + } + + pub fn update(&mut self, data: Bytes, content_type: Mime) { + self.hash = Sha256::digest(&data).into(); + self.data = data; + self.content_type = content_type; + self.last_modified = SystemTime::now(); + } + + /// Returns the Content-Type header of the session. + pub fn content_type(&self) -> ContentType { + self.content_type.clone().into() + } + + pub fn etag(&self) -> ETag { + let encoded = URL_SAFE_NO_PAD.encode(self.hash); + // SAFETY: Base64 encoding is URL-safe, so ETag-safe + format!("\"{encoded}\"") + .parse() + .expect("base64-encoded hash should be URL-safe") + } + + /// Returns the Last-Modified header of the session. + pub fn last_modified(&self) -> LastModified { + self.last_modified.into() + } + + /// Returns the Expires header of the session. + pub fn expires(&self) -> Expires { + self.expires.into() + } + + /// Returns the current data stored in the session. 
+ pub fn data(&self) -> Bytes { + self.data.clone() + } +} diff --git a/synapse/rest/client/rendezvous.py b/synapse/rest/client/rendezvous.py index bb079a60b01..21bc92cf95f 100644 --- a/synapse/rest/client/rendezvous.py +++ b/synapse/rest/client/rendezvous.py @@ -23,8 +23,7 @@ from http.client import TEMPORARY_REDIRECT from typing import TYPE_CHECKING, Optional -from synapse.api.errors import Codes, SynapseError -from synapse.http.server import HttpServer, respond_with_json, respond_with_redirect +from synapse.http.server import HttpServer, respond_with_redirect from synapse.http.servlet import RestServlet from synapse.http.site import SynapseRequest from synapse.rest.client._base import client_patterns @@ -104,65 +103,44 @@ class MSC4108RendezvousServlet(RestServlet): "/org.matrix.msc4108/rendezvous$", releases=[], v1=False, unstable=True ) - def __init__(self, hs: "HomeServer") -> None: + def __init__(self, store: Rendezvous) -> None: super().__init__() + self._store = store - self.max_upload_size = 100_000 - self._store = Rendezvous() + def on_POST(self, request: SynapseRequest) -> None: + self._store.handle_post(request) - async def on_POST(self, request: SynapseRequest) -> None: - content_type = request.getHeader("Content-Type") - if content_type is None: - raise SynapseError( - msg="Request must specify a Content-Type", - code=400, - errcode=Codes.MISSING_PARAM, - ) - - raw_content_length = request.getHeader("Content-Length") - if raw_content_length is None: - raise SynapseError( - msg="Request must specify a Content-Length", - code=400, - errcode=Codes.MISSING_PARAM, - ) - try: - content_length = int(raw_content_length) - except ValueError: - raise SynapseError(msg="Content-Length value is invalid", code=400) - if content_length > self.max_upload_size: - raise SynapseError( - msg="Upload request body is too large", - code=413, - errcode=Codes.TOO_LARGE, - ) - - if request.content is None: - raise SynapseError( - msg="Request must have a body", - code=400, - 
errcode=Codes.MISSING_PARAM, - ) - - body = request.content.read(content_length + 1) - if len(body) != content_length: - raise SynapseError( - msg="Request body does not match Content-Length", - code=400, - errcode=Codes.INVALID_PARAM, - ) - - self._store.store_session(content_type, body) - - respond_with_json(request, 200, {"success": True}) + +class MSC4108RendezvousSessionServlet(RestServlet): + PATTERNS = client_patterns( + "/org.matrix.msc4108/rendezvous/(?P[^/]+)$", + releases=[], + v1=False, + unstable=True, + ) + + def __init__(self, store: Rendezvous) -> None: + super().__init__() + self._store = store + + def on_GET(self, request: SynapseRequest, session_id: str) -> None: + self._store.handle_get(request, session_id) + + def on_PUT(self, request: SynapseRequest, session_id: str) -> None: + self._store.handle_put(request, session_id) + + def on_DELETE(self, request: SynapseRequest, session_id: str) -> None: + self._store.handle_delete(request, session_id) def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: if hs.config.experimental.msc3886_endpoint is not None: MSC3886RendezvousServlet(hs).register(http_server) - # TODO: gate this behind a feature flag - MSC4108RendezvousServlet(hs).register(http_server) + # TODO: gate this behind a feature flag and store the rendezvous object in the HS + rendezvous = Rendezvous() + MSC4108RendezvousServlet(rendezvous).register(http_server) + MSC4108RendezvousSessionServlet(rendezvous).register(http_server) if hs.config.experimental.msc4108_delegation_endpoint is not None: MSC4108DelegationRendezvousServlet(hs).register(http_server) diff --git a/synapse/synapse_rust/rendezvous.pyi b/synapse/synapse_rust/rendezvous.pyi index 43c7f06c847..e15cbbfbc6c 100644 --- a/synapse/synapse_rust/rendezvous.pyi +++ b/synapse/synapse_rust/rendezvous.pyi @@ -10,6 +10,11 @@ # See the GNU Affero General Public License for more details: # . 
+from twisted.web.iweb import IRequest + class Rendezvous: def __init__(self) -> None: ... - def store_session(self, content_type: str, body: bytes) -> str: ... + def handle_post(self, request: IRequest) -> None: ... + def handle_get(self, request: IRequest, session_id: str) -> None: ... + def handle_put(self, request: IRequest, session_id: str) -> None: ... + def handle_delete(self, request: IRequest, session_id: str) -> None: ... From cd9809e932b1b394bbf413992644deeed6940f61 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Fri, 5 Apr 2024 18:30:10 +0200 Subject: [PATCH 03/16] Some more rendezvous implementation --- rust/src/rendezvous/mod.rs | 92 ++++++++++++++++++++++------- synapse/rest/client/rendezvous.py | 7 ++- synapse/synapse_rust/rendezvous.pyi | 2 +- 3 files changed, 77 insertions(+), 24 deletions(-) diff --git a/rust/src/rendezvous/mod.rs b/rust/src/rendezvous/mod.rs index 9e9b938d9d0..de7c6e5958d 100644 --- a/rust/src/rendezvous/mod.rs +++ b/rust/src/rendezvous/mod.rs @@ -16,9 +16,14 @@ use std::{collections::HashMap, time::Duration}; use bytes::Bytes; -use headers::{ContentLength, ContentType, HeaderMapExt}; -use http::{Response, StatusCode}; -use pyo3::{pyclass, pymethods, types::PyModule, PyAny, PyResult, Python}; +use headers::{ + AccessControlAllowOrigin, AccessControlExposeHeaders, ContentLength, ContentType, HeaderMapExt, + IfMatch, IfNoneMatch, +}; +use http::{header::ETAG, HeaderMap, Response, StatusCode, Uri}; +use pyo3::{ + exceptions::PyValueError, pyclass, pymethods, types::PyModule, PyAny, PyResult, Python, +}; use ulid::Ulid; use crate::{ @@ -30,18 +35,31 @@ mod session; use self::session::Session; +fn prepare_headers(headers: &mut HeaderMap, session: &Session) { + headers.typed_insert(AccessControlAllowOrigin::ANY); + headers.typed_insert(AccessControlExposeHeaders::from_iter([ETAG])); + headers.typed_insert(session.etag()); + headers.typed_insert(session.expires()); + headers.typed_insert(session.last_modified()); +} + // TODO: 
handle eviction -#[derive(Default)] #[pyclass] struct Rendezvous { + base: Uri, sessions: HashMap, } #[pymethods] impl Rendezvous { #[new] - fn new() -> Self { - Rendezvous::default() + fn new(base: &str) -> PyResult { + let base = Uri::try_from(base).map_err(|_| PyValueError::new_err("Invalid base URI"))?; + + Ok(Self { + base, + sessions: HashMap::new(), + }) } fn handle_post(&mut self, twisted_request: &PyAny) -> PyResult<()> { @@ -63,25 +81,21 @@ impl Rendezvous { let id = Ulid::new(); - // XXX: this is lazy - let source_uri = request.uri(); - let uri = format!("{source_uri}/{id}"); + let uri = format!("{base}/{id}", base = self.base); let body = request.into_body(); let session = Session::new(body, content_type.into(), Duration::from_secs(5 * 60)); let response = serde_json::json!({ - "uri": uri, + "url": uri, }) .to_string(); let mut response = Response::new(response.as_bytes()); *response.status_mut() = StatusCode::CREATED; response.headers_mut().typed_insert(ContentType::json()); - response.headers_mut().typed_insert(session.etag()); - response.headers_mut().typed_insert(session.expires()); - response.headers_mut().typed_insert(session.last_modified()); + prepare_headers(response.headers_mut(), &session); http_response_to_twisted(twisted_request, response)?; self.sessions.insert(id, session); @@ -90,19 +104,27 @@ impl Rendezvous { } fn handle_get(&mut self, twisted_request: &PyAny, id: &str) -> PyResult<()> { - let _request = http_request_from_twisted(twisted_request)?; + let request = http_request_from_twisted(twisted_request)?; - // TODO: handle If-None-Match + let if_none_match: Option = request.headers().typed_get(); let id: Ulid = id.parse().map_err(|_| NotFoundError::new())?; let session = self.sessions.get(&id).ok_or_else(NotFoundError::new)?; + if let Some(if_none_match) = if_none_match { + if !if_none_match.precondition_passes(&session.etag()) { + let mut response = Response::new(Bytes::new()); + *response.status_mut() = StatusCode::NOT_MODIFIED; + 
prepare_headers(response.headers_mut(), session); + http_response_to_twisted(twisted_request, response)?; + return Ok(()); + } + } + let mut response = Response::new(session.data()); *response.status_mut() = StatusCode::OK; + prepare_headers(response.headers_mut(), session); response.headers_mut().typed_insert(session.content_type()); - response.headers_mut().typed_insert(session.etag()); - response.headers_mut().typed_insert(session.expires()); - response.headers_mut().typed_insert(session.last_modified()); http_response_to_twisted(twisted_request, response)?; Ok(()) @@ -111,20 +133,43 @@ impl Rendezvous { fn handle_put(&mut self, twisted_request: &PyAny, id: &str) -> PyResult<()> { let request = http_request_from_twisted(twisted_request)?; - // TODO: handle If-Match + let ContentLength(content_length) = request.headers().typed_get_required()?; + + if content_length > 1024 * 100 { + return Err(SynapseError::new( + StatusCode::BAD_REQUEST, + "Content-Length too large".to_owned(), + "M_INVALID_PARAM", + None, + None, + )); + } let content_type: ContentType = request.headers().typed_get_required()?; + let if_match: IfMatch = request.headers().typed_get_required()?; + let data = request.into_body(); let id: Ulid = id.parse().map_err(|_| NotFoundError::new())?; let session = self.sessions.get_mut(&id).ok_or_else(NotFoundError::new)?; + + if !if_match.precondition_passes(&session.etag()) { + let mut headers = HeaderMap::new(); + prepare_headers(&mut headers, session); + return Err(SynapseError::new( + StatusCode::PRECONDITION_FAILED, + "ETag does not match".to_owned(), + "M_CONCURRENT_WRITE", + None, + Some(headers), + )); + } + session.update(data, content_type.into()); let mut response = Response::new(Bytes::new()); *response.status_mut() = StatusCode::ACCEPTED; - response.headers_mut().typed_insert(session.etag()); - response.headers_mut().typed_insert(session.expires()); - response.headers_mut().typed_insert(session.last_modified()); + 
prepare_headers(response.headers_mut(), session); http_response_to_twisted(twisted_request, response)?; Ok(()) @@ -138,6 +183,9 @@ impl Rendezvous { let mut response = Response::new(Bytes::new()); *response.status_mut() = StatusCode::NO_CONTENT; + response + .headers_mut() + .typed_insert(AccessControlAllowOrigin::ANY); http_response_to_twisted(twisted_request, response)?; Ok(()) diff --git a/synapse/rest/client/rendezvous.py b/synapse/rest/client/rendezvous.py index 21bc92cf95f..fd868383f8d 100644 --- a/synapse/rest/client/rendezvous.py +++ b/synapse/rest/client/rendezvous.py @@ -112,6 +112,7 @@ def on_POST(self, request: SynapseRequest) -> None: class MSC4108RendezvousSessionServlet(RestServlet): + # TODO: this should probably be mounted on the _synapse/client namespace PATTERNS = client_patterns( "/org.matrix.msc4108/rendezvous/(?P[^/]+)$", releases=[], @@ -138,7 +139,11 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: MSC3886RendezvousServlet(hs).register(http_server) # TODO: gate this behind a feature flag and store the rendezvous object in the HS - rendezvous = Rendezvous() + base = ( + hs.config.server.public_baseurl + + "_matrix/client/unstable/org.matrix.msc4108/rendezvous" + ) + rendezvous = Rendezvous(base) MSC4108RendezvousServlet(rendezvous).register(http_server) MSC4108RendezvousSessionServlet(rendezvous).register(http_server) diff --git a/synapse/synapse_rust/rendezvous.pyi b/synapse/synapse_rust/rendezvous.pyi index e15cbbfbc6c..cb213add5df 100644 --- a/synapse/synapse_rust/rendezvous.pyi +++ b/synapse/synapse_rust/rendezvous.pyi @@ -13,7 +13,7 @@ from twisted.web.iweb import IRequest class Rendezvous: - def __init__(self) -> None: ... + def __init__(self, base: str) -> None: ... def handle_post(self, request: IRequest) -> None: ... def handle_get(self, request: IRequest, session_id: str) -> None: ... def handle_put(self, request: IRequest, session_id: str) -> None: ... 
From 08b452584195085767d4c9301b51506bb7d50e0d Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Mon, 8 Apr 2024 11:43:22 +0200 Subject: [PATCH 04/16] Return better errors --- rust/src/rendezvous/mod.rs | 56 ++++++++++++++++++-------------------- 1 file changed, 27 insertions(+), 29 deletions(-) diff --git a/rust/src/rendezvous/mod.rs b/rust/src/rendezvous/mod.rs index de7c6e5958d..83fbbb47146 100644 --- a/rust/src/rendezvous/mod.rs +++ b/rust/src/rendezvous/mod.rs @@ -21,11 +21,13 @@ use headers::{ IfMatch, IfNoneMatch, }; use http::{header::ETAG, HeaderMap, Response, StatusCode, Uri}; +use mime::Mime; use pyo3::{ exceptions::PyValueError, pyclass, pymethods, types::PyModule, PyAny, PyResult, Python, }; use ulid::Ulid; +use self::session::Session; use crate::{ errors::{NotFoundError, SynapseError}, http::{http_request_from_twisted, http_response_to_twisted, HeaderMapPyExt}, @@ -33,7 +35,7 @@ use crate::{ mod session; -use self::session::Session; +const MAX_CONTENT_LENGTH: u64 = 1024 * 100; fn prepare_headers(headers: &mut HeaderMap, session: &Session) { headers.typed_insert(AccessControlAllowOrigin::ANY); @@ -43,6 +45,24 @@ fn prepare_headers(headers: &mut HeaderMap, session: &Session) { headers.typed_insert(session.last_modified()); } +fn check_input_headers(headers: &HeaderMap) -> PyResult { + let ContentLength(content_length) = headers.typed_get_required()?; + + if content_length > MAX_CONTENT_LENGTH { + return Err(SynapseError::new( + StatusCode::PAYLOAD_TOO_LARGE, + "Payload too large".to_owned(), + "M_TOO_LARGE", + None, + None, + )); + } + + let content_type: ContentType = headers.typed_get_required()?; + + Ok(content_type.into()) +} + // TODO: handle eviction #[pyclass] struct Rendezvous { @@ -65,19 +85,7 @@ impl Rendezvous { fn handle_post(&mut self, twisted_request: &PyAny) -> PyResult<()> { let request = http_request_from_twisted(twisted_request)?; - let ContentLength(content_length) = request.headers().typed_get_required()?; - - if content_length > 
1024 * 100 { - return Err(SynapseError::new( - StatusCode::BAD_REQUEST, - "Content-Length too large".to_owned(), - "M_INVALID_PARAM", - None, - None, - )); - } - - let content_type: ContentType = request.headers().typed_get_required()?; + let content_type = check_input_headers(request.headers())?; let id = Ulid::new(); @@ -85,7 +93,7 @@ impl Rendezvous { let body = request.into_body(); - let session = Session::new(body, content_type.into(), Duration::from_secs(5 * 60)); + let session = Session::new(body, content_type, Duration::from_secs(5 * 60)); let response = serde_json::json!({ "url": uri, @@ -106,7 +114,7 @@ impl Rendezvous { fn handle_get(&mut self, twisted_request: &PyAny, id: &str) -> PyResult<()> { let request = http_request_from_twisted(twisted_request)?; - let if_none_match: Option = request.headers().typed_get(); + let if_none_match: Option = request.headers().typed_get_optional()?; let id: Ulid = id.parse().map_err(|_| NotFoundError::new())?; let session = self.sessions.get(&id).ok_or_else(NotFoundError::new)?; @@ -133,19 +141,8 @@ impl Rendezvous { fn handle_put(&mut self, twisted_request: &PyAny, id: &str) -> PyResult<()> { let request = http_request_from_twisted(twisted_request)?; - let ContentLength(content_length) = request.headers().typed_get_required()?; - - if content_length > 1024 * 100 { - return Err(SynapseError::new( - StatusCode::BAD_REQUEST, - "Content-Length too large".to_owned(), - "M_INVALID_PARAM", - None, - None, - )); - } + let content_type = check_input_headers(request.headers())?; - let content_type: ContentType = request.headers().typed_get_required()?; let if_match: IfMatch = request.headers().typed_get_required()?; let data = request.into_body(); @@ -156,6 +153,7 @@ impl Rendezvous { if !if_match.precondition_passes(&session.etag()) { let mut headers = HeaderMap::new(); prepare_headers(&mut headers, session); + return Err(SynapseError::new( StatusCode::PRECONDITION_FAILED, "ETag does not match".to_owned(), @@ -165,7 +163,7 @@ 
impl Rendezvous { )); } - session.update(data, content_type.into()); + session.update(data, content_type); let mut response = Response::new(Bytes::new()); *response.status_mut() = StatusCode::ACCEPTED; From 73d3152abcdeb95c81567ad568af015b400e99f8 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Mon, 8 Apr 2024 14:55:50 +0200 Subject: [PATCH 05/16] Add config flag, rename to RendezVousHandler & other fixes --- rust/src/rendezvous/mod.rs | 27 +++++++++++++++++++-------- rust/src/rendezvous/session.rs | 8 +++++++- synapse/config/experimental.py | 12 +++++++++++- synapse/rest/client/rendezvous.py | 28 +++++++++++----------------- synapse/rest/client/versions.py | 9 +++++++-- synapse/server.py | 5 +++++ synapse/synapse_rust/rendezvous.pyi | 6 ++++-- 7 files changed, 64 insertions(+), 31 deletions(-) diff --git a/rust/src/rendezvous/mod.rs b/rust/src/rendezvous/mod.rs index 83fbbb47146..a1b68e6f0fd 100644 --- a/rust/src/rendezvous/mod.rs +++ b/rust/src/rendezvous/mod.rs @@ -18,7 +18,7 @@ use std::{collections::HashMap, time::Duration}; use bytes::Bytes; use headers::{ AccessControlAllowOrigin, AccessControlExposeHeaders, ContentLength, ContentType, HeaderMapExt, - IfMatch, IfNoneMatch, + IfMatch, IfNoneMatch, Pragma, }; use http::{header::ETAG, HeaderMap, Response, StatusCode, Uri}; use mime::Mime; @@ -40,6 +40,7 @@ const MAX_CONTENT_LENGTH: u64 = 1024 * 100; fn prepare_headers(headers: &mut HeaderMap, session: &Session) { headers.typed_insert(AccessControlAllowOrigin::ANY); headers.typed_insert(AccessControlExposeHeaders::from_iter([ETAG])); + headers.typed_insert(Pragma::no_cache()); headers.typed_insert(session.etag()); headers.typed_insert(session.expires()); headers.typed_insert(session.last_modified()); @@ -65,16 +66,24 @@ fn check_input_headers(headers: &HeaderMap) -> PyResult { // TODO: handle eviction #[pyclass] -struct Rendezvous { +struct RendezvousHandler { base: Uri, sessions: HashMap, } #[pymethods] -impl Rendezvous { +impl RendezvousHandler { #[new] - 
fn new(base: &str) -> PyResult { - let base = Uri::try_from(base).map_err(|_| PyValueError::new_err("Invalid base URI"))?; + fn new(homeserver: &PyAny) -> PyResult { + let base: String = homeserver + .getattr("config")? + .getattr("server")? + .getattr("public_baseurl")? + .extract()?; + let base = Uri::try_from(format!( + "{base}_matrix/client/unstable/org.matrix.msc4108/rendezvous" + )) + .map_err(|_| PyValueError::new_err("Invalid base URI"))?; Ok(Self { base, @@ -131,8 +140,10 @@ impl Rendezvous { let mut response = Response::new(session.data()); *response.status_mut() = StatusCode::OK; - prepare_headers(response.headers_mut(), session); - response.headers_mut().typed_insert(session.content_type()); + let headers = response.headers_mut(); + prepare_headers(headers, session); + headers.typed_insert(session.content_type()); + headers.typed_insert(session.content_length()); http_response_to_twisted(twisted_request, response)?; Ok(()) @@ -193,7 +204,7 @@ impl Rendezvous { pub fn register_module(py: Python<'_>, m: &PyModule) -> PyResult<()> { let child_module = PyModule::new(py, "rendezvous")?; - child_module.add_class::()?; + child_module.add_class::()?; m.add_submodule(child_module)?; diff --git a/rust/src/rendezvous/session.rs b/rust/src/rendezvous/session.rs index cef92f2ecfb..b1e0502aff0 100644 --- a/rust/src/rendezvous/session.rs +++ b/rust/src/rendezvous/session.rs @@ -16,7 +16,7 @@ use std::time::{Duration, SystemTime}; use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _}; use bytes::Bytes; -use headers::{ContentType, ETag, Expires, LastModified}; +use headers::{ContentLength, ContentType, ETag, Expires, LastModified}; use mime::Mime; use sha2::{Digest, Sha256}; @@ -54,6 +54,12 @@ impl Session { self.content_type.clone().into() } + /// Returns the Content-Length header of the session. + pub fn content_length(&self) -> ContentLength { + ContentLength(self.data.len() as _) + } + + /// Returns the ETag header of the session. 
pub fn etag(&self) -> ETag { let encoded = URL_SAFE_NO_PAD.encode(self.hash); // SAFETY: Base64 encoding is URL-safe, so ETag-safe diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 353ae23f910..baa3580f293 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -413,12 +413,22 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None: ) # MSC4108: Mechanism to allow OIDC sign in and E2EE set up via QR code + self.msc4108_enabled = experimental.get("msc4108_enabled", False) + self.msc4108_delegation_endpoint: Optional[str] = experimental.get( "msc4108_delegation_endpoint", None ) - if self.msc4108_delegation_endpoint is not None and not self.msc3861.enabled: + if ( + self.msc4108_enabled or self.msc4108_delegation_endpoint is not None + ) and not self.msc3861.enabled: raise ConfigError( "MSC4108 requires MSC3861 to be enabled", ("experimental", "msc4108_delegation_endpoint"), ) + + if self.msc4108_delegation_endpoint is not None and self.msc4108_enabled: + raise ConfigError( + "You cannot have MSC4108 both enabled and delegated at the same time", + ("experimental", "msc4108_delegation_endpoint"), + ) diff --git a/synapse/rest/client/rendezvous.py b/synapse/rest/client/rendezvous.py index fd868383f8d..443256cc264 100644 --- a/synapse/rest/client/rendezvous.py +++ b/synapse/rest/client/rendezvous.py @@ -27,7 +27,6 @@ from synapse.http.servlet import RestServlet from synapse.http.site import SynapseRequest from synapse.rest.client._base import client_patterns -from synapse.synapse_rust.rendezvous import Rendezvous if TYPE_CHECKING: from synapse.server import HomeServer @@ -103,12 +102,12 @@ class MSC4108RendezvousServlet(RestServlet): "/org.matrix.msc4108/rendezvous$", releases=[], v1=False, unstable=True ) - def __init__(self, store: Rendezvous) -> None: + def __init__(self, hs: "HomeServer") -> None: super().__init__() - self._store = store + self._handler = hs.get_rendezvous_handler() def 
on_POST(self, request: SynapseRequest) -> None: - self._store.handle_post(request) + self._handler.handle_post(request) class MSC4108RendezvousSessionServlet(RestServlet): @@ -120,32 +119,27 @@ class MSC4108RendezvousSessionServlet(RestServlet): unstable=True, ) - def __init__(self, store: Rendezvous) -> None: + def __init__(self, hs: "HomeServer") -> None: super().__init__() - self._store = store + self._handler = hs.get_rendezvous_handler() def on_GET(self, request: SynapseRequest, session_id: str) -> None: - self._store.handle_get(request, session_id) + self._handler.handle_get(request, session_id) def on_PUT(self, request: SynapseRequest, session_id: str) -> None: - self._store.handle_put(request, session_id) + self._handler.handle_put(request, session_id) def on_DELETE(self, request: SynapseRequest, session_id: str) -> None: - self._store.handle_delete(request, session_id) + self._handler.handle_delete(request, session_id) def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: if hs.config.experimental.msc3886_endpoint is not None: MSC3886RendezvousServlet(hs).register(http_server) - # TODO: gate this behind a feature flag and store the rendezvous object in the HS - base = ( - hs.config.server.public_baseurl - + "_matrix/client/unstable/org.matrix.msc4108/rendezvous" - ) - rendezvous = Rendezvous(base) - MSC4108RendezvousServlet(rendezvous).register(http_server) - MSC4108RendezvousSessionServlet(rendezvous).register(http_server) + if hs.config.experimental.msc4108_enabled: + MSC4108RendezvousServlet(hs).register(http_server) + MSC4108RendezvousSessionServlet(hs).register(http_server) if hs.config.experimental.msc4108_delegation_endpoint is not None: MSC4108DelegationRendezvousServlet(hs).register(http_server) diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index 638d4c45ae6..fa453a3b027 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -141,8 +141,13 @@ def on_GET(self, 
request: Request) -> Tuple[int, JsonDict]: # Allows clients to handle push for encrypted events. "org.matrix.msc4028": self.config.experimental.msc4028_push_encrypted_events, # MSC4108: Mechanism to allow OIDC sign in and E2EE set up via QR code - "org.matrix.msc4108": self.config.experimental.msc4108_delegation_endpoint - is not None, + "org.matrix.msc4108": ( + self.config.experimental.msc4108_enabled + or ( + self.config.experimental.msc4108_delegation_endpoint + is not None + ) + ), }, }, ) diff --git a/synapse/server.py b/synapse/server.py index 6d5a18fb1de..95e319d2e66 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -143,6 +143,7 @@ from synapse.storage import Databases from synapse.storage.controllers import StorageControllers from synapse.streams.events import EventSources +from synapse.synapse_rust.rendezvous import RendezvousHandler from synapse.types import DomainSpecificString, ISynapseReactor from synapse.util import Clock from synapse.util.distributor import Distributor @@ -859,6 +860,10 @@ def get_push_rules_handler(self) -> PushRulesHandler: def get_room_forgetter_handler(self) -> RoomForgetterHandler: return RoomForgetterHandler(self) + @cache_in_self + def get_rendezvous_handler(self) -> RendezvousHandler: + return RendezvousHandler(self) + @cache_in_self def get_outbound_redis_connection(self) -> "ConnectionHandler": """ diff --git a/synapse/synapse_rust/rendezvous.pyi b/synapse/synapse_rust/rendezvous.pyi index cb213add5df..b09d88c6574 100644 --- a/synapse/synapse_rust/rendezvous.pyi +++ b/synapse/synapse_rust/rendezvous.pyi @@ -12,8 +12,10 @@ from twisted.web.iweb import IRequest -class Rendezvous: - def __init__(self, base: str) -> None: ... +from synapse.server import HomeServer + +class RendezvousHandler: + def __init__(self, homeserver: HomeServer) -> None: ... def handle_post(self, request: IRequest) -> None: ... def handle_get(self, request: IRequest, session_id: str) -> None: ... 
def handle_put(self, request: IRequest, session_id: str) -> None: ... From a34c8f7bdb72f3fcdc1de533f52427bb9be4994f Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Mon, 8 Apr 2024 15:53:55 +0200 Subject: [PATCH 06/16] Prepare for looping eviction calls --- rust/src/rendezvous/mod.rs | 29 +++++++++++++++++++++++------ 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/rust/src/rendezvous/mod.rs b/rust/src/rendezvous/mod.rs index a1b68e6f0fd..345434ebdce 100644 --- a/rust/src/rendezvous/mod.rs +++ b/rust/src/rendezvous/mod.rs @@ -21,9 +21,10 @@ use headers::{ IfMatch, IfNoneMatch, Pragma, }; use http::{header::ETAG, HeaderMap, Response, StatusCode, Uri}; +use log::info; use mime::Mime; use pyo3::{ - exceptions::PyValueError, pyclass, pymethods, types::PyModule, PyAny, PyResult, Python, + exceptions::PyValueError, pyclass, pymethods, types::PyModule, Py, PyAny, PyResult, Python, }; use ulid::Ulid; @@ -74,7 +75,7 @@ struct RendezvousHandler { #[pymethods] impl RendezvousHandler { #[new] - fn new(homeserver: &PyAny) -> PyResult { + fn new(py: Python<'_>, homeserver: &PyAny) -> PyResult> { let base: String = homeserver .getattr("config")? .getattr("server")? @@ -85,10 +86,26 @@ impl RendezvousHandler { )) .map_err(|_| PyValueError::new_err("Invalid base URI"))?; - Ok(Self { - base, - sessions: HashMap::new(), - }) + // Construct a Python object so that we can get a reference to the + // evict method and schedule it to run. + let self_ = Py::new( + py, + Self { + base, + sessions: HashMap::new(), + }, + )?; + + let evict = self_.getattr(py, "_evict")?; + homeserver + .call_method0("get_clock")? 
+ .call_method("looping_call", (evict, 500), None)?; + + Ok(self_) + } + + fn _evict(&mut self) { + info!("Evicting sessions"); } fn handle_post(&mut self, twisted_request: &PyAny) -> PyResult<()> { From 317528d42b7f6d93dfb8ddd32551c765c3ccf77d Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Wed, 10 Apr 2024 09:14:43 +0200 Subject: [PATCH 07/16] Use Synapse's clock to generate timestamps --- rust/src/rendezvous/mod.rs | 29 ++++++++++++++++++++++------- rust/src/rendezvous/session.rs | 8 ++++---- 2 files changed, 26 insertions(+), 11 deletions(-) diff --git a/rust/src/rendezvous/mod.rs b/rust/src/rendezvous/mod.rs index 345434ebdce..f02fddb299d 100644 --- a/rust/src/rendezvous/mod.rs +++ b/rust/src/rendezvous/mod.rs @@ -13,7 +13,10 @@ * */ -use std::{collections::HashMap, time::Duration}; +use std::{ + collections::HashMap, + time::{Duration, SystemTime}, +}; use bytes::Bytes; use headers::{ @@ -24,7 +27,8 @@ use http::{header::ETAG, HeaderMap, Response, StatusCode, Uri}; use log::info; use mime::Mime; use pyo3::{ - exceptions::PyValueError, pyclass, pymethods, types::PyModule, Py, PyAny, PyResult, Python, + exceptions::PyValueError, pyclass, pymethods, types::PyModule, Py, PyAny, PyObject, PyResult, + Python, ToPyObject, }; use ulid::Ulid; @@ -69,6 +73,7 @@ fn check_input_headers(headers: &HeaderMap) -> PyResult { #[pyclass] struct RendezvousHandler { base: Uri, + clock: PyObject, sessions: HashMap, } @@ -86,12 +91,15 @@ impl RendezvousHandler { )) .map_err(|_| PyValueError::new_err("Invalid base URI"))?; + let clock = homeserver.call_method0("get_clock")?.to_object(py); + // Construct a Python object so that we can get a reference to the // evict method and schedule it to run. 
let self_ = Py::new( py, Self { base, + clock, sessions: HashMap::new(), }, )?; @@ -108,18 +116,22 @@ impl RendezvousHandler { info!("Evicting sessions"); } - fn handle_post(&mut self, twisted_request: &PyAny) -> PyResult<()> { + fn handle_post(&mut self, py: Python<'_>, twisted_request: &PyAny) -> PyResult<()> { let request = http_request_from_twisted(twisted_request)?; let content_type = check_input_headers(request.headers())?; - let id = Ulid::new(); + // Generate a new ULID for the session from the current time. + let clock = self.clock.as_ref(py); + let now: u64 = clock.call_method0("time_msec")?.extract()?; + let now = SystemTime::UNIX_EPOCH + Duration::from_millis(now); + let id = Ulid::from_datetime(now); let uri = format!("{base}/{id}", base = self.base); let body = request.into_body(); - let session = Session::new(body, content_type, Duration::from_secs(5 * 60)); + let session = Session::new(body, content_type, now, Duration::from_secs(5 * 60)); let response = serde_json::json!({ "url": uri, @@ -166,7 +178,7 @@ impl RendezvousHandler { Ok(()) } - fn handle_put(&mut self, twisted_request: &PyAny, id: &str) -> PyResult<()> { + fn handle_put(&mut self, py: Python<'_>, twisted_request: &PyAny, id: &str) -> PyResult<()> { let request = http_request_from_twisted(twisted_request)?; let content_type = check_input_headers(request.headers())?; @@ -175,6 +187,9 @@ impl RendezvousHandler { let data = request.into_body(); + let now: u64 = self.clock.call_method0(py, "time_msec")?.extract(py)?; + let now = SystemTime::UNIX_EPOCH + Duration::from_millis(now); + let id: Ulid = id.parse().map_err(|_| NotFoundError::new())?; let session = self.sessions.get_mut(&id).ok_or_else(NotFoundError::new)?; @@ -191,7 +206,7 @@ impl RendezvousHandler { )); } - session.update(data, content_type); + session.update(data, content_type, now); let mut response = Response::new(Bytes::new()); *response.status_mut() = StatusCode::ACCEPTED; diff --git a/rust/src/rendezvous/session.rs 
b/rust/src/rendezvous/session.rs index b1e0502aff0..50f76628340 100644 --- a/rust/src/rendezvous/session.rs +++ b/rust/src/rendezvous/session.rs @@ -30,9 +30,9 @@ pub struct Session { } impl Session { - pub fn new(data: Bytes, content_type: Mime, ttl: Duration) -> Self { + /// Create a new session with the given data, content type, and time-to-live. + pub fn new(data: Bytes, content_type: Mime, now: SystemTime, ttl: Duration) -> Self { let hash = Sha256::digest(&data).into(); - let now = SystemTime::now(); Self { hash, data, @@ -42,11 +42,11 @@ impl Session { } } - pub fn update(&mut self, data: Bytes, content_type: Mime) { + pub fn update(&mut self, data: Bytes, content_type: Mime, now: SystemTime) { self.hash = Sha256::digest(&data).into(); self.data = data; self.content_type = content_type; - self.last_modified = SystemTime::now(); + self.last_modified = now; } /// Returns the Content-Type header of the session. From 4bf190ff42c735dddf449b1e91eb3c198fbae4a2 Mon Sep 17 00:00:00 2001 From: Hugh Nimmo-Smith Date: Tue, 9 Apr 2024 18:07:36 +0100 Subject: [PATCH 08/16] Set rest of required CORS headers --- rust/src/rendezvous/mod.rs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/rust/src/rendezvous/mod.rs b/rust/src/rendezvous/mod.rs index f02fddb299d..fbe87ba0e57 100644 --- a/rust/src/rendezvous/mod.rs +++ b/rust/src/rendezvous/mod.rs @@ -20,8 +20,8 @@ use std::{ use bytes::Bytes; use headers::{ - AccessControlAllowOrigin, AccessControlExposeHeaders, ContentLength, ContentType, HeaderMapExt, - IfMatch, IfNoneMatch, Pragma, + AccessControlAllowOrigin, AccessControlExposeHeaders, CacheControl, ContentLength, ContentType, + HeaderMapExt, IfMatch, IfNoneMatch, Pragma, }; use http::{header::ETAG, HeaderMap, Response, StatusCode, Uri}; use log::info; @@ -42,10 +42,12 @@ mod session; const MAX_CONTENT_LENGTH: u64 = 1024 * 100; +// n.b. Because OPTIONS requests are handled by the Python code, we don't need to set Access-Control-Allow-Headers. 
fn prepare_headers(headers: &mut HeaderMap, session: &Session) { headers.typed_insert(AccessControlAllowOrigin::ANY); headers.typed_insert(AccessControlExposeHeaders::from_iter([ETAG])); headers.typed_insert(Pragma::no_cache()); + headers.typed_insert(CacheControl::new().with_no_store()); headers.typed_insert(session.etag()); headers.typed_insert(session.expires()); headers.typed_insert(session.last_modified()); From 60e0a48791004064d764d7711739e25fa737ba5f Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Wed, 10 Apr 2024 17:37:29 +0200 Subject: [PATCH 09/16] Implement session eviction --- rust/src/rendezvous/mod.rs | 38 +++++++++++++++++++++++++++------- rust/src/rendezvous/session.rs | 6 ++++++ 2 files changed, 36 insertions(+), 8 deletions(-) diff --git a/rust/src/rendezvous/mod.rs b/rust/src/rendezvous/mod.rs index fbe87ba0e57..9b04d79583a 100644 --- a/rust/src/rendezvous/mod.rs +++ b/rust/src/rendezvous/mod.rs @@ -14,7 +14,7 @@ */ use std::{ - collections::HashMap, + collections::BTreeMap, time::{Duration, SystemTime}, }; @@ -24,7 +24,6 @@ use headers::{ HeaderMapExt, IfMatch, IfNoneMatch, Pragma, }; use http::{header::ETAG, HeaderMap, Response, StatusCode, Uri}; -use log::info; use mime::Mime; use pyo3::{ exceptions::PyValueError, pyclass, pymethods, types::PyModule, Py, PyAny, PyObject, PyResult, @@ -41,6 +40,7 @@ use crate::{ mod session; const MAX_CONTENT_LENGTH: u64 = 1024 * 100; +const CAPACITY: usize = 100; // n.b. Because OPTIONS requests are handled by the Python code, we don't need to set Access-Control-Allow-Headers. 
fn prepare_headers(headers: &mut HeaderMap, session: &Session) { @@ -71,12 +71,23 @@ fn check_input_headers(headers: &HeaderMap) -> PyResult { Ok(content_type.into()) } -// TODO: handle eviction #[pyclass] struct RendezvousHandler { base: Uri, clock: PyObject, - sessions: HashMap, + sessions: BTreeMap, +} + +impl RendezvousHandler { + fn evict(&mut self, now: SystemTime, max_entries: usize) { + // First remove all the entries which expired + self.sessions.retain(|_, session| !session.expired(now)); + + // Then we remove the oldest entries until we're under the limit + while self.sessions.len() > max_entries { + self.sessions.pop_first(); + } + } } #[pymethods] @@ -102,7 +113,7 @@ impl RendezvousHandler { Self { base, clock, - sessions: HashMap::new(), + sessions: BTreeMap::new(), }, )?; @@ -114,8 +125,13 @@ impl RendezvousHandler { Ok(self_) } - fn _evict(&mut self) { - info!("Evicting sessions"); + fn _evict(&mut self, py: Python<'_>) -> PyResult<()> { + let clock = self.clock.as_ref(py); + let now: u64 = clock.call_method0("time_msec")?.extract()?; + let now = SystemTime::UNIX_EPOCH + Duration::from_millis(now); + self.evict(now, CAPACITY); + + Ok(()) } fn handle_post(&mut self, py: Python<'_>, twisted_request: &PyAny) -> PyResult<()> { @@ -123,10 +139,16 @@ impl RendezvousHandler { let content_type = check_input_headers(request.headers())?; - // Generate a new ULID for the session from the current time. let clock = self.clock.as_ref(py); let now: u64 = clock.call_method0("time_msec")?.extract()?; let now = SystemTime::UNIX_EPOCH + Duration::from_millis(now); + + // We trigger an immediate eviction if we're at 2x the capacity + if self.sessions.len() >= CAPACITY * 2 { + self.evict(now, CAPACITY); + } + + // Generate a new ULID for the session from the current time. 
let id = Ulid::from_datetime(now); let uri = format!("{base}/{id}", base = self.base); diff --git a/rust/src/rendezvous/session.rs b/rust/src/rendezvous/session.rs index 50f76628340..179304edfe6 100644 --- a/rust/src/rendezvous/session.rs +++ b/rust/src/rendezvous/session.rs @@ -42,6 +42,12 @@ impl Session { } } + /// Returns true if the session has expired at the given time. + pub fn expired(&self, now: SystemTime) -> bool { + self.expires <= now + } + + /// Update the session with new data, content type, and last modified time. pub fn update(&mut self, data: Bytes, content_type: Mime, now: SystemTime) { self.hash = Sha256::digest(&data).into(); self.data = data; From 5f9fe56113160b0d8535d5edebab33865e9f747d Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Wed, 10 Apr 2024 17:39:18 +0200 Subject: [PATCH 10/16] Do not allow GETting/PUTting expired sessions --- rust/src/rendezvous/mod.rs | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/rust/src/rendezvous/mod.rs b/rust/src/rendezvous/mod.rs index 9b04d79583a..168c5866bb2 100644 --- a/rust/src/rendezvous/mod.rs +++ b/rust/src/rendezvous/mod.rs @@ -173,13 +173,20 @@ impl RendezvousHandler { Ok(()) } - fn handle_get(&mut self, twisted_request: &PyAny, id: &str) -> PyResult<()> { + fn handle_get(&mut self, py: Python<'_>, twisted_request: &PyAny, id: &str) -> PyResult<()> { let request = http_request_from_twisted(twisted_request)?; let if_none_match: Option = request.headers().typed_get_optional()?; + let now: u64 = self.clock.call_method0(py, "time_msec")?.extract(py)?; + let now = SystemTime::UNIX_EPOCH + Duration::from_millis(now); + let id: Ulid = id.parse().map_err(|_| NotFoundError::new())?; - let session = self.sessions.get(&id).ok_or_else(NotFoundError::new)?; + let session = self + .sessions + .get(&id) + .filter(|s| !s.expired(now)) + .ok_or_else(NotFoundError::new)?; if let Some(if_none_match) = if_none_match { if !if_none_match.precondition_passes(&session.etag()) { @@ 
-215,7 +222,11 @@ impl RendezvousHandler { let now = SystemTime::UNIX_EPOCH + Duration::from_millis(now); let id: Ulid = id.parse().map_err(|_| NotFoundError::new())?; - let session = self.sessions.get_mut(&id).ok_or_else(NotFoundError::new)?; + let session = self + .sessions + .get_mut(&id) + .filter(|s| !s.expired(now)) + .ok_or_else(NotFoundError::new)?; if !if_match.precondition_passes(&session.etag()) { let mut headers = HeaderMap::new(); From 415715174117f7d89bb309ee695b81b2dbce756d Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Wed, 10 Apr 2024 18:14:14 +0200 Subject: [PATCH 11/16] Move the rendezvous session handler to the _synapse/ namespace --- rust/src/rendezvous/mod.rs | 6 +-- synapse/http/server.py | 5 +- synapse/rest/client/rendezvous.py | 24 ---------- synapse/rest/synapse/client/__init__.py | 4 ++ synapse/rest/synapse/client/rendezvous.py | 58 +++++++++++++++++++++++ 5 files changed, 67 insertions(+), 30 deletions(-) create mode 100644 synapse/rest/synapse/client/rendezvous.py diff --git a/rust/src/rendezvous/mod.rs b/rust/src/rendezvous/mod.rs index 168c5866bb2..5207825771b 100644 --- a/rust/src/rendezvous/mod.rs +++ b/rust/src/rendezvous/mod.rs @@ -99,10 +99,8 @@ impl RendezvousHandler { .getattr("server")? .getattr("public_baseurl")? 
.extract()?; - let base = Uri::try_from(format!( - "{base}_matrix/client/unstable/org.matrix.msc4108/rendezvous" - )) - .map_err(|_| PyValueError::new_err("Invalid base URI"))?; + let base = Uri::try_from(format!("{base}_synapse/client/rendezvous")) + .map_err(|_| PyValueError::new_err("Invalid base URI"))?; let clock = homeserver.call_method0("get_clock")?.to_object(py); diff --git a/synapse/http/server.py b/synapse/http/server.py index 45b2cbffcdb..211795dc396 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -909,8 +909,9 @@ def set_cors_headers(request: "SynapseRequest") -> None: request.setHeader( b"Access-Control-Allow-Methods", b"GET, HEAD, POST, PUT, DELETE, OPTIONS" ) - if request.path is not None and request.path.startswith( - b"/_matrix/client/unstable/org.matrix.msc4108/rendezvous" + if request.path is not None and ( + request.path == b"/_matrix/client/unstable/org.matrix.msc4108/rendezvous" + or request.path.startswith(b"/_synapse/client/rendezvous") ): request.setHeader( b"Access-Control-Allow-Headers", diff --git a/synapse/rest/client/rendezvous.py b/synapse/rest/client/rendezvous.py index 443256cc264..143f0576516 100644 --- a/synapse/rest/client/rendezvous.py +++ b/synapse/rest/client/rendezvous.py @@ -110,36 +110,12 @@ def on_POST(self, request: SynapseRequest) -> None: self._handler.handle_post(request) -class MSC4108RendezvousSessionServlet(RestServlet): - # TODO: this should probably be mounted on the _synapse/client namespace - PATTERNS = client_patterns( - "/org.matrix.msc4108/rendezvous/(?P[^/]+)$", - releases=[], - v1=False, - unstable=True, - ) - - def __init__(self, hs: "HomeServer") -> None: - super().__init__() - self._handler = hs.get_rendezvous_handler() - - def on_GET(self, request: SynapseRequest, session_id: str) -> None: - self._handler.handle_get(request, session_id) - - def on_PUT(self, request: SynapseRequest, session_id: str) -> None: - self._handler.handle_put(request, session_id) - - def on_DELETE(self, 
request: SynapseRequest, session_id: str) -> None: - self._handler.handle_delete(request, session_id) - - def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: if hs.config.experimental.msc3886_endpoint is not None: MSC3886RendezvousServlet(hs).register(http_server) if hs.config.experimental.msc4108_enabled: MSC4108RendezvousServlet(hs).register(http_server) - MSC4108RendezvousSessionServlet(hs).register(http_server) if hs.config.experimental.msc4108_delegation_endpoint is not None: MSC4108DelegationRendezvousServlet(hs).register(http_server) diff --git a/synapse/rest/synapse/client/__init__.py b/synapse/rest/synapse/client/__init__.py index 31544867d4a..ba6576d4db5 100644 --- a/synapse/rest/synapse/client/__init__.py +++ b/synapse/rest/synapse/client/__init__.py @@ -26,6 +26,7 @@ from synapse.rest.synapse.client.new_user_consent import NewUserConsentResource from synapse.rest.synapse.client.pick_idp import PickIdpResource from synapse.rest.synapse.client.pick_username import pick_username_resource +from synapse.rest.synapse.client.rendezvous import MSC4108RendezvousSessionResource from synapse.rest.synapse.client.sso_register import SsoRegisterResource from synapse.rest.synapse.client.unsubscribe import UnsubscribeResource @@ -76,6 +77,9 @@ def build_synapse_client_resource_tree(hs: "HomeServer") -> Mapping[str, Resourc # To be removed in Synapse v1.32.0. resources["/_matrix/saml2"] = res + if hs.config.experimental.msc4108_enabled: + resources["/_synapse/client/rendezvous"] = MSC4108RendezvousSessionResource(hs) + return resources diff --git a/synapse/rest/synapse/client/rendezvous.py b/synapse/rest/synapse/client/rendezvous.py new file mode 100644 index 00000000000..5216d30d1f4 --- /dev/null +++ b/synapse/rest/synapse/client/rendezvous.py @@ -0,0 +1,58 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. 
+# +# Copyright (C) 2024 New Vector, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# <https://www.gnu.org/licenses/agpl-3.0.html>. +# +# + +import logging +from typing import TYPE_CHECKING, List + +from synapse.api.errors import UnrecognizedRequestError +from synapse.http.server import DirectServeJsonResource +from synapse.http.site import SynapseRequest + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class MSC4108RendezvousSessionResource(DirectServeJsonResource): + isLeaf = True + + def __init__(self, hs: "HomeServer") -> None: + super().__init__() + self._handler = hs.get_rendezvous_handler() + + async def _async_render_GET(self, request: SynapseRequest) -> None: + postpath: List[bytes] = request.postpath # type: ignore + if len(postpath) != 1: + raise UnrecognizedRequestError() + session_id = postpath[0].decode("ascii") + + self._handler.handle_get(request, session_id) + + def _async_render_PUT(self, request: SynapseRequest) -> None: + postpath: List[bytes] = request.postpath # type: ignore + if len(postpath) != 1: + raise UnrecognizedRequestError() + session_id = postpath[0].decode("ascii") + + self._handler.handle_put(request, session_id) + + def _async_render_DELETE(self, request: SynapseRequest) -> None: + postpath: List[bytes] = request.postpath # type: ignore + if len(postpath) != 1: + raise UnrecognizedRequestError() + session_id = postpath[0].decode("ascii") + + self._handler.handle_delete(request, session_id) From 37af4f4d7e07577b44caf054996b5bcd06555b72 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Thu, 11 Apr 2024 10:41:16 +0200 Subject: [PATCH 12/16] Add tests for the rendezvous implementation --- tests/rest/client/test_rendezvous.py | 341 
++++++++++++++++++++++++++- 1 file changed, 340 insertions(+), 1 deletion(-) diff --git a/tests/rest/client/test_rendezvous.py b/tests/rest/client/test_rendezvous.py index c84704c0905..2254cc63ea1 100644 --- a/tests/rest/client/test_rendezvous.py +++ b/tests/rest/client/test_rendezvous.py @@ -2,7 +2,7 @@ # This file is licensed under the Affero General Public License (AGPL) version 3. # # Copyright 2022 The Matrix.org Foundation C.I.C. -# Copyright (C) 2023 New Vector, Ltd +# Copyright (C) 2023-2024 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as @@ -19,9 +19,14 @@ # # +from typing import Dict +from urllib.parse import urlparse + from twisted.test.proto_helpers import MemoryReactor +from twisted.web.resource import Resource from synapse.rest.client import rendezvous +from synapse.rest.synapse.client.rendezvous import MSC4108RendezvousSessionResource from synapse.server import HomeServer from synapse.util import Clock @@ -42,6 +47,12 @@ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.hs = self.setup_test_homeserver() return self.hs + def create_resource_dict(self) -> Dict[str, Resource]: + return { + **super().create_resource_dict(), + "/_synapse/client/rendezvous": MSC4108RendezvousSessionResource(self.hs), + } + def test_disabled(self) -> None: channel = self.make_request("POST", msc3886_endpoint, {}, access_token=None) self.assertEqual(channel.code, 404) @@ -75,3 +86,331 @@ def test_msc4108_delegation(self) -> None: channel = self.make_request("POST", msc4108_endpoint, {}, access_token=None) self.assertEqual(channel.code, 307) self.assertEqual(channel.headers.getRawHeaders("Location"), ["https://asd"]) + + @unittest.skip_unless(HAS_AUTHLIB, "requires authlib") + @override_config( + { + "disable_registration": True, + "experimental_features": { + "msc4108_enabled": True, + "msc3861": { + "enabled": True, + "issuer": 
"https://issuer", + "client_id": "client_id", + "client_auth_method": "client_secret_post", + "client_secret": "client_secret", + "admin_token": "admin_token_value", + }, + }, + } + ) + def test_msc4108(self) -> None: + """ + Test the MSC4108 rendezvous endpoint, including: + - Creating a session + - Getting the data back + - Updating the data + - Deleting the data + - ETag handling + """ + # We can post arbitrary data to the endpoint + channel = self.make_request( + "POST", + msc4108_endpoint, + "foo=bar", + # This sets the content type to application/x-www-form-urlencoded + content_is_form=True, + access_token=None, + ) + self.assertEqual(channel.code, 201) + self.assertSubstring("/_synapse/client/rendezvous/", channel.json_body["url"]) + headers = dict(channel.headers.getAllRawHeaders()) + self.assertIn(b"ETag", headers) + self.assertIn(b"Expires", headers) + self.assertEqual(headers[b"Content-Type"], [b"application/json"]) + self.assertEqual(headers[b"Access-Control-Allow-Origin"], [b"*"]) + self.assertEqual(headers[b"Access-Control-Expose-Headers"], [b"etag"]) + self.assertEqual(headers[b"Cache-Control"], [b"no-store"]) + self.assertEqual(headers[b"Pragma"], [b"no-cache"]) + self.assertIn("url", channel.json_body) + self.assertTrue(channel.json_body["url"].startswith("https://")) + + url = urlparse(channel.json_body["url"]) + session_endpoint = url.path + etag = headers[b"ETag"][0] + + # We can get the data back + channel = self.make_request( + "GET", + session_endpoint, + access_token=None, + ) + + self.assertEqual(channel.code, 200) + headers = dict(channel.headers.getAllRawHeaders()) + self.assertEqual(headers[b"ETag"], [etag]) + self.assertIn(b"Expires", headers) + self.assertEqual( + headers[b"Content-Type"], [b"application/x-www-form-urlencoded"] + ) + self.assertEqual(headers[b"Access-Control-Allow-Origin"], [b"*"]) + self.assertEqual(headers[b"Access-Control-Expose-Headers"], [b"etag"]) + self.assertEqual(headers[b"Cache-Control"], [b"no-store"]) + 
self.assertEqual(headers[b"Pragma"], [b"no-cache"]) + self.assertEqual(channel.text_body, "foo=bar") + + # We can make sure the data hasn't changed + channel = self.make_request( + "GET", + session_endpoint, + access_token=None, + custom_headers=[("If-None-Match", etag)], + ) + + self.assertEqual(channel.code, 304) + + # We can update the data + channel = self.make_request( + "PUT", + session_endpoint, + "foo=baz", + content_is_form=True, + access_token=None, + custom_headers=[("If-Match", etag)], + ) + + self.assertEqual(channel.code, 202) + headers = dict(channel.headers.getAllRawHeaders()) + old_etag = etag + new_etag = headers[b"ETag"][0] + + # If we try to update it again with the old etag, it should fail + channel = self.make_request( + "PUT", + session_endpoint, + "bar=baz", + content_is_form=True, + access_token=None, + custom_headers=[("If-Match", old_etag)], + ) + + self.assertEqual(channel.code, 412) + self.assertEqual(channel.json_body["errcode"], "M_CONCURRENT_WRITE") + + # If we try to get with the old etag, we should get the updated data + channel = self.make_request( + "GET", + session_endpoint, + access_token=None, + custom_headers=[("If-None-Match", old_etag)], + ) + + self.assertEqual(channel.code, 200) + headers = dict(channel.headers.getAllRawHeaders()) + self.assertEqual(headers[b"ETag"], [new_etag]) + self.assertEqual(channel.text_body, "foo=baz") + + # We can delete the data + channel = self.make_request( + "DELETE", + session_endpoint, + access_token=None, + ) + + self.assertEqual(channel.code, 204) + + # If we try to get the data again, it should fail + channel = self.make_request( + "GET", + session_endpoint, + access_token=None, + ) + + self.assertEqual(channel.code, 404) + self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND") + + @unittest.skip_unless(HAS_AUTHLIB, "requires authlib") + @override_config( + { + "disable_registration": True, + "experimental_features": { + "msc4108_enabled": True, + "msc3861": { + "enabled": True, + 
"issuer": "https://issuer", + "client_id": "client_id", + "client_auth_method": "client_secret_post", + "client_secret": "client_secret", + "admin_token": "admin_token_value", + }, + }, + } + ) + def test_msc4108_expiration(self) -> None: + """ + Test that entries are evicted after a TTL. + """ + # Start a new session + channel = self.make_request( + "POST", + msc4108_endpoint, + "foo=bar", + content_is_form=True, + access_token=None, + ) + self.assertEqual(channel.code, 201) + session_endpoint = urlparse(channel.json_body["url"]).path + + # Sanity check that we can get the data back + channel = self.make_request( + "GET", + session_endpoint, + access_token=None, + ) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.text_body, "foo=bar") + + # Advance the clock, TTL of entries is 5 minutes + self.reactor.advance(300) + + # Get the data back, it should be gone + channel = self.make_request( + "GET", + session_endpoint, + access_token=None, + ) + self.assertEqual(channel.code, 404) + + @unittest.skip_unless(HAS_AUTHLIB, "requires authlib") + @override_config( + { + "disable_registration": True, + "experimental_features": { + "msc4108_enabled": True, + "msc3861": { + "enabled": True, + "issuer": "https://issuer", + "client_id": "client_id", + "client_auth_method": "client_secret_post", + "client_secret": "client_secret", + "admin_token": "admin_token_value", + }, + }, + } + ) + def test_msc4108_capacity(self) -> None: + """ + Test that a capacity limit is enforced on the rendezvous sessions, as old + entries are evicted at an interval when the limit is reached. 
+ """ + # Start a new session + channel = self.make_request( + "POST", + msc4108_endpoint, + "foo=bar", + content_is_form=True, + access_token=None, + ) + self.assertEqual(channel.code, 201) + session_endpoint = urlparse(channel.json_body["url"]).path + + # Sanity check that we can get the data back + channel = self.make_request( + "GET", + session_endpoint, + access_token=None, + ) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.text_body, "foo=bar") + + # Start a lot of new sessions + for _ in range(100): + channel = self.make_request( + "POST", + msc4108_endpoint, + "foo=bar", + content_is_form=True, + access_token=None, + ) + self.assertEqual(channel.code, 201) + + # Get the data back, it should still be there, as the eviction hasn't run yet + channel = self.make_request( + "GET", + session_endpoint, + access_token=None, + ) + + self.assertEqual(channel.code, 200) + + # Advance the clock, as it will trigger the eviction + self.reactor.advance(1) + + # Get the data back, it should be gone + channel = self.make_request( + "GET", + session_endpoint, + access_token=None, + ) + + @unittest.skip_unless(HAS_AUTHLIB, "requires authlib") + @override_config( + { + "disable_registration": True, + "experimental_features": { + "msc4108_enabled": True, + "msc3861": { + "enabled": True, + "issuer": "https://issuer", + "client_id": "client_id", + "client_auth_method": "client_secret_post", + "client_secret": "client_secret", + "admin_token": "admin_token_value", + }, + }, + } + ) + def test_msc4108_hard_capacity(self) -> None: + """ + Test that a hard capacity limit is enforced on the rendezvous sessions, as old + entries are evicted immediately when the limit is reached. 
+ """ + # Start a new session + channel = self.make_request( + "POST", + msc4108_endpoint, + "foo=bar", + content_is_form=True, + access_token=None, + ) + self.assertEqual(channel.code, 201) + session_endpoint = urlparse(channel.json_body["url"]).path + + # Sanity check that we can get the data back + channel = self.make_request( + "GET", + session_endpoint, + access_token=None, + ) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.text_body, "foo=bar") + + # Start a lot of new sessions + for _ in range(200): + channel = self.make_request( + "POST", + msc4108_endpoint, + "foo=bar", + content_is_form=True, + access_token=None, + ) + self.assertEqual(channel.code, 201) + + # Get the data back, it should already be gone as we hit the hard limit + channel = self.make_request( + "GET", + session_endpoint, + access_token=None, + ) + + self.assertEqual(channel.code, 404) From bc5d2d07df623b40bf23544788bcfe3aa2cbf202 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Fri, 12 Apr 2024 14:20:10 +0200 Subject: [PATCH 13/16] Newsfile --- changelog.d/17056.feature | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/17056.feature diff --git a/changelog.d/17056.feature b/changelog.d/17056.feature new file mode 100644 index 00000000000..b4cbe849e4f --- /dev/null +++ b/changelog.d/17056.feature @@ -0,0 +1 @@ +Implement the rendezvous mechanism described by MSC4108. 
From b04f0437ff3686461b69308e8e9ff85eef75cd80 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Mon, 15 Apr 2024 10:57:07 +0200 Subject: [PATCH 14/16] Make the capacity, max payload size and eviction interval configurable This does not add the config bits, but does add the plumbing to set it from the Python side --- rust/src/rendezvous/mod.rs | 76 +++++++++++++++++------------ synapse/synapse_rust/rendezvous.pyi | 9 +++- 2 files changed, 52 insertions(+), 33 deletions(-) diff --git a/rust/src/rendezvous/mod.rs b/rust/src/rendezvous/mod.rs index 5207825771b..755aa46759d 100644 --- a/rust/src/rendezvous/mod.rs +++ b/rust/src/rendezvous/mod.rs @@ -39,9 +39,6 @@ use crate::{ mod session; -const MAX_CONTENT_LENGTH: u64 = 1024 * 100; -const CAPACITY: usize = 100; - // n.b. Because OPTIONS requests are handled by the Python code, we don't need to set Access-Control-Allow-Headers. fn prepare_headers(headers: &mut HeaderMap, session: &Session) { headers.typed_insert(AccessControlAllowOrigin::ANY); @@ -53,38 +50,42 @@ fn prepare_headers(headers: &mut HeaderMap, session: &Session) { headers.typed_insert(session.last_modified()); } -fn check_input_headers(headers: &HeaderMap) -> PyResult { - let ContentLength(content_length) = headers.typed_get_required()?; - - if content_length > MAX_CONTENT_LENGTH { - return Err(SynapseError::new( - StatusCode::PAYLOAD_TOO_LARGE, - "Payload too large".to_owned(), - "M_TOO_LARGE", - None, - None, - )); - } - - let content_type: ContentType = headers.typed_get_required()?; - - Ok(content_type.into()) -} - #[pyclass] struct RendezvousHandler { base: Uri, clock: PyObject, sessions: BTreeMap, + capacity: usize, + max_content_length: u64, } impl RendezvousHandler { - fn evict(&mut self, now: SystemTime, max_entries: usize) { + /// Check the input headers of a request which sets data for a session, and return the content type.
+ fn check_input_headers(&self, headers: &HeaderMap) -> PyResult { + let ContentLength(content_length) = headers.typed_get_required()?; + + if content_length > self.max_content_length { + return Err(SynapseError::new( + StatusCode::PAYLOAD_TOO_LARGE, + "Payload too large".to_owned(), + "M_TOO_LARGE", + None, + None, + )); + } + + let content_type: ContentType = headers.typed_get_required()?; + + Ok(content_type.into()) + } + + /// Evict expired sessions and remove the oldest sessions until we're under the capacity. + fn evict(&mut self, now: SystemTime) { // First remove all the entries which expired self.sessions.retain(|_, session| !session.expired(now)); // Then we remove the oldest entires until we're under the limit - while self.sessions.len() > max_entries { + while self.sessions.len() > self.capacity { self.sessions.pop_first(); } } @@ -93,7 +94,14 @@ impl RendezvousHandler { #[pymethods] impl RendezvousHandler { #[new] - fn new(py: Python<'_>, homeserver: &PyAny) -> PyResult> { + #[pyo3(signature = (homeserver, /, capacity=100, max_content_length=1024*1024, eviction_interval=60*1000))] + fn new( + py: Python<'_>, + homeserver: &PyAny, + capacity: usize, + max_content_length: u64, + eviction_interval: u64, + ) -> PyResult> { let base: String = homeserver .getattr("config")? .getattr("server")? @@ -112,13 +120,17 @@ impl RendezvousHandler { base, clock, sessions: BTreeMap::new(), + capacity, + max_content_length, }, )?; let evict = self_.getattr(py, "_evict")?; - homeserver - .call_method0("get_clock")? 
- .call_method("looping_call", (evict, 500), None)?; + homeserver.call_method0("get_clock")?.call_method( + "looping_call", + (evict, eviction_interval), + None, + )?; Ok(self_) } @@ -127,7 +139,7 @@ impl RendezvousHandler { let clock = self.clock.as_ref(py); let now: u64 = clock.call_method0("time_msec")?.extract()?; let now = SystemTime::UNIX_EPOCH + Duration::from_millis(now); - self.evict(now, CAPACITY); + self.evict(now); Ok(()) } @@ -135,15 +147,15 @@ impl RendezvousHandler { fn handle_post(&mut self, py: Python<'_>, twisted_request: &PyAny) -> PyResult<()> { let request = http_request_from_twisted(twisted_request)?; - let content_type = check_input_headers(request.headers())?; + let content_type = self.check_input_headers(request.headers())?; let clock = self.clock.as_ref(py); let now: u64 = clock.call_method0("time_msec")?.extract()?; let now = SystemTime::UNIX_EPOCH + Duration::from_millis(now); // We trigger an immediate eviction if we're at 2x the capacity - if self.sessions.len() >= CAPACITY * 2 { - self.evict(now, CAPACITY); + if self.sessions.len() >= self.capacity * 2 { + self.evict(now); } // Generate a new ULID for the session from the current time. @@ -210,7 +222,7 @@ impl RendezvousHandler { fn handle_put(&mut self, py: Python<'_>, twisted_request: &PyAny, id: &str) -> PyResult<()> { let request = http_request_from_twisted(twisted_request)?; - let content_type = check_input_headers(request.headers())?; + let content_type = self.check_input_headers(request.headers())?; let if_match: IfMatch = request.headers().typed_get_required()?; diff --git a/synapse/synapse_rust/rendezvous.pyi b/synapse/synapse_rust/rendezvous.pyi index b09d88c6574..49b784e63d1 100644 --- a/synapse/synapse_rust/rendezvous.pyi +++ b/synapse/synapse_rust/rendezvous.pyi @@ -15,7 +15,14 @@ from twisted.web.iweb import IRequest from synapse.server import HomeServer class RendezvousHandler: - def __init__(self, homeserver: HomeServer) -> None: ... 
+ def __init__( + self, + homeserver: HomeServer, + /, + capacity: int = 100, + max_content_length: int = 1024 * 1024, + eviction_interval: int = 60 * 1000, + ) -> None: ... def handle_post(self, request: IRequest) -> None: ... def handle_get(self, request: IRequest, session_id: str) -> None: ... def handle_put(self, request: IRequest, session_id: str) -> None: ... From de01983f32e40064be9c86f8c3828b4fee586693 Mon Sep 17 00:00:00 2001 From: Hugh Nimmo-Smith Date: Tue, 23 Apr 2024 09:55:57 +0100 Subject: [PATCH 15/16] MSC4108 4KB maximum + 60 second TTL (#17113) Co-authored-by: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> --- rust/src/rendezvous/mod.rs | 19 ++++++++++++++----- synapse/synapse_rust/rendezvous.pyi | 3 ++- tests/rest/client/test_rendezvous.py | 11 ++++++++--- 3 files changed, 24 insertions(+), 9 deletions(-) diff --git a/rust/src/rendezvous/mod.rs b/rust/src/rendezvous/mod.rs index 755aa46759d..82721d51b70 100644 --- a/rust/src/rendezvous/mod.rs +++ b/rust/src/rendezvous/mod.rs @@ -14,7 +14,7 @@ */ use std::{ - collections::BTreeMap, + collections::{BTreeMap, HashMap}, time::{Duration, SystemTime}, }; @@ -57,6 +57,7 @@ struct RendezvousHandler { sessions: BTreeMap, capacity: usize, max_content_length: u64, + ttl: Duration, } impl RendezvousHandler { @@ -94,13 +95,14 @@ impl RendezvousHandler { #[pymethods] impl RendezvousHandler { #[new] - #[pyo3(signature = (homeserver, /, capacity=100, max_content_length=1024*1024, eviction_interval=60*1000))] + #[pyo3(signature = (homeserver, /, capacity=100, max_content_length=4*1024, eviction_interval=60*1000, ttl=60*1000))] fn new( py: Python<'_>, homeserver: &PyAny, capacity: usize, max_content_length: u64, eviction_interval: u64, + ttl: u64, ) -> PyResult> { let base: String = homeserver .getattr("config")? 
@@ -122,6 +124,7 @@ impl RendezvousHandler { sessions: BTreeMap::new(), capacity, max_content_length, + ttl: Duration::from_millis(ttl), }, )?; @@ -165,7 +168,7 @@ impl RendezvousHandler { let body = request.into_body(); - let session = Session::new(body, content_type, now, Duration::from_secs(5 * 60)); + let session = Session::new(body, content_type, now, self.ttl); let response = serde_json::json!({ "url": uri, @@ -242,11 +245,17 @@ impl RendezvousHandler { let mut headers = HeaderMap::new(); prepare_headers(&mut headers, session); + let mut additional_fields = HashMap::with_capacity(1); + additional_fields.insert( + String::from("org.matrix.msc4108.errcode"), + String::from("M_CONCURRENT_WRITE"), + ); + return Err(SynapseError::new( StatusCode::PRECONDITION_FAILED, "ETag does not match".to_owned(), - "M_CONCURRENT_WRITE", - None, + "M_UNKNOWN", // Would be M_CONCURRENT_WRITE + Some(additional_fields), Some(headers), )); } diff --git a/synapse/synapse_rust/rendezvous.pyi b/synapse/synapse_rust/rendezvous.pyi index 49b784e63d1..03eae3a1964 100644 --- a/synapse/synapse_rust/rendezvous.pyi +++ b/synapse/synapse_rust/rendezvous.pyi @@ -20,8 +20,9 @@ class RendezvousHandler: homeserver: HomeServer, /, capacity: int = 100, - max_content_length: int = 1024 * 1024, + max_content_length: int = 4 * 1024, # MSC4108 specifies 4KB eviction_interval: int = 60 * 1000, + ttl: int = 60 * 1000, ) -> None: ... def handle_post(self, request: IRequest) -> None: ... def handle_get(self, request: IRequest, session_id: str) -> None: ... 
diff --git a/tests/rest/client/test_rendezvous.py b/tests/rest/client/test_rendezvous.py index 2254cc63ea1..1ad56a22859 100644 --- a/tests/rest/client/test_rendezvous.py +++ b/tests/rest/client/test_rendezvous.py @@ -195,7 +195,10 @@ def test_msc4108(self) -> None: ) self.assertEqual(channel.code, 412) - self.assertEqual(channel.json_body["errcode"], "M_CONCURRENT_WRITE") + self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN") + self.assertEqual( + channel.json_body["org.matrix.msc4108.errcode"], "M_CONCURRENT_WRITE" + ) # If we try to get with the old etag, we should get the updated data channel = self.make_request( @@ -270,8 +273,8 @@ def test_msc4108_expiration(self) -> None: self.assertEqual(channel.code, 200) self.assertEqual(channel.text_body, "foo=bar") - # Advance the clock, TTL of entries is 5 minutes - self.reactor.advance(300) + # Advance the clock, TTL of entries is 1 minute + self.reactor.advance(60) # Get the data back, it should be gone channel = self.make_request( @@ -385,6 +388,8 @@ def test_msc4108_hard_capacity(self) -> None: ) self.assertEqual(channel.code, 201) session_endpoint = urlparse(channel.json_body["url"]).path + # We advance the clock to make sure that this entry is the "lowest" in the session list + self.reactor.advance(1) # Sanity check that we can get the data back channel = self.make_request( From 86c77976920724efbbe66e0130fe6595faa1a999 Mon Sep 17 00:00:00 2001 From: Hugh Nimmo-Smith Date: Thu, 25 Apr 2024 12:33:02 +0100 Subject: [PATCH 16/16] Restrict MSC4108 content-type to text/plain (#17122) --- rust/src/rendezvous/mod.rs | 11 ++++ tests/rest/client/test_rendezvous.py | 79 +++++++++++++++++++++++----- tests/server.py | 7 ++- tests/unittest.py | 5 ++ 4 files changed, 89 insertions(+), 13 deletions(-) diff --git a/rust/src/rendezvous/mod.rs b/rust/src/rendezvous/mod.rs index 82721d51b70..c0f5d8b6000 100644 --- a/rust/src/rendezvous/mod.rs +++ b/rust/src/rendezvous/mod.rs @@ -77,6 +77,17 @@ impl RendezvousHandler { let 
content_type: ContentType = headers.typed_get_required()?; + // Content-Type must be text/plain + if content_type != ContentType::text() { + return Err(SynapseError::new( + StatusCode::BAD_REQUEST, + "Content-Type must be text/plain".to_owned(), + "M_INVALID_PARAM", + None, + None, + )); + } + Ok(content_type.into()) } diff --git a/tests/rest/client/test_rendezvous.py b/tests/rest/client/test_rendezvous.py index 1ad56a22859..0ab754a11aa 100644 --- a/tests/rest/client/test_rendezvous.py +++ b/tests/rest/client/test_rendezvous.py @@ -118,8 +118,7 @@ def test_msc4108(self) -> None: "POST", msc4108_endpoint, "foo=bar", - # This sets the content type to application/x-www-form-urlencoded - content_is_form=True, + content_type=b"text/plain", access_token=None, ) self.assertEqual(channel.code, 201) @@ -150,9 +149,7 @@ def test_msc4108(self) -> None: headers = dict(channel.headers.getAllRawHeaders()) self.assertEqual(headers[b"ETag"], [etag]) self.assertIn(b"Expires", headers) - self.assertEqual( - headers[b"Content-Type"], [b"application/x-www-form-urlencoded"] - ) + self.assertEqual(headers[b"Content-Type"], [b"text/plain"]) self.assertEqual(headers[b"Access-Control-Allow-Origin"], [b"*"]) self.assertEqual(headers[b"Access-Control-Expose-Headers"], [b"etag"]) self.assertEqual(headers[b"Cache-Control"], [b"no-store"]) @@ -174,7 +171,7 @@ def test_msc4108(self) -> None: "PUT", session_endpoint, "foo=baz", - content_is_form=True, + content_type=b"text/plain", access_token=None, custom_headers=[("If-Match", etag)], ) @@ -189,7 +186,7 @@ def test_msc4108(self) -> None: "PUT", session_endpoint, "bar=baz", - content_is_form=True, + content_type=b"text/plain", access_token=None, custom_headers=[("If-Match", old_etag)], ) @@ -258,7 +255,7 @@ def test_msc4108_expiration(self) -> None: "POST", msc4108_endpoint, "foo=bar", - content_is_form=True, + content_type=b"text/plain", access_token=None, ) self.assertEqual(channel.code, 201) @@ -311,7 +308,7 @@ def test_msc4108_capacity(self) 
-> None: "POST", msc4108_endpoint, "foo=bar", - content_is_form=True, + content_type=b"text/plain", access_token=None, ) self.assertEqual(channel.code, 201) @@ -332,7 +329,7 @@ def test_msc4108_capacity(self) -> None: "POST", msc4108_endpoint, "foo=bar", - content_is_form=True, + content_type=b"text/plain", access_token=None, ) self.assertEqual(channel.code, 201) @@ -383,7 +380,7 @@ def test_msc4108_hard_capacity(self) -> None: "POST", msc4108_endpoint, "foo=bar", - content_is_form=True, + content_type=b"text/plain", access_token=None, ) self.assertEqual(channel.code, 201) @@ -406,7 +403,7 @@ def test_msc4108_hard_capacity(self) -> None: "POST", msc4108_endpoint, "foo=bar", - content_is_form=True, + content_type=b"text/plain", access_token=None, ) self.assertEqual(channel.code, 201) @@ -419,3 +416,61 @@ def test_msc4108_hard_capacity(self) -> None: ) self.assertEqual(channel.code, 404) + + @unittest.skip_unless(HAS_AUTHLIB, "requires authlib") + @override_config( + { + "disable_registration": True, + "experimental_features": { + "msc4108_enabled": True, + "msc3861": { + "enabled": True, + "issuer": "https://issuer", + "client_id": "client_id", + "client_auth_method": "client_secret_post", + "client_secret": "client_secret", + "admin_token": "admin_token_value", + }, + }, + } + ) + def test_msc4108_content_type(self) -> None: + """ + Test that the content-type is restricted to text/plain. 
+ """ + # We cannot post invalid content-type arbitrary data to the endpoint + channel = self.make_request( + "POST", + msc4108_endpoint, + "foo=bar", + content_is_form=True, + access_token=None, + ) + self.assertEqual(channel.code, 400) + self.assertEqual(channel.json_body["errcode"], "M_INVALID_PARAM") + + # Make a valid request + channel = self.make_request( + "POST", + msc4108_endpoint, + "foo=bar", + content_type=b"text/plain", + access_token=None, + ) + self.assertEqual(channel.code, 201) + url = urlparse(channel.json_body["url"]) + session_endpoint = url.path + headers = dict(channel.headers.getAllRawHeaders()) + etag = headers[b"ETag"][0] + + # We can't update the data with invalid content-type + channel = self.make_request( + "PUT", + session_endpoint, + "foo=baz", + content_is_form=True, + access_token=None, + custom_headers=[("If-Match", etag)], + ) + self.assertEqual(channel.code, 400) + self.assertEqual(channel.json_body["errcode"], "M_INVALID_PARAM") diff --git a/tests/server.py b/tests/server.py index 4aaa91e956a..434be3d22c6 100644 --- a/tests/server.py +++ b/tests/server.py @@ -351,6 +351,7 @@ def make_request( request: Type[Request] = SynapseRequest, shorthand: bool = True, federation_auth_origin: Optional[bytes] = None, + content_type: Optional[bytes] = None, content_is_form: bool = False, await_result: bool = True, custom_headers: Optional[Iterable[CustomHeaderType]] = None, @@ -373,6 +374,8 @@ def make_request( with the usual REST API path, if it doesn't contain it. federation_auth_origin: if set to not-None, we will add a fake Authorization header pretenting to be the given server name. + content_type: The content-type to use for the request. If not set then will default to + application/json unless content_is_form is true. content_is_form: Whether the content is URL encoded form data. Adds the 'Content-Type': 'application/x-www-form-urlencoded' header. await_result: whether to wait for the request to complete rendering. 
If true, @@ -436,7 +439,9 @@ def make_request( ) if content: - if content_is_form: + if content_type is not None: + req.requestHeaders.addRawHeader(b"Content-Type", content_type) + elif content_is_form: req.requestHeaders.addRawHeader( b"Content-Type", b"application/x-www-form-urlencoded" ) diff --git a/tests/unittest.py b/tests/unittest.py index 6fe0cd4a2dc..e6aad9ed40b 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -523,6 +523,7 @@ def make_request( request: Type[Request] = SynapseRequest, shorthand: bool = True, federation_auth_origin: Optional[bytes] = None, + content_type: Optional[bytes] = None, content_is_form: bool = False, await_result: bool = True, custom_headers: Optional[Iterable[CustomHeaderType]] = None, @@ -541,6 +542,9 @@ def make_request( with the usual REST API path, if it doesn't contain it. federation_auth_origin: if set to not-None, we will add a fake Authorization header pretenting to be the given server name. + + content_type: The content-type to use for the request. If not set then will default to + application/json unless content_is_form is true. content_is_form: Whether the content is URL encoded form data. Adds the 'Content-Type': 'application/x-www-form-urlencoded' header. @@ -566,6 +570,7 @@ def make_request( request, shorthand, federation_auth_origin, + content_type, content_is_form, await_result, custom_headers,