Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add a cache around server ACL checking #16360

Merged
merged 13 commits into from
Sep 26, 2023
1 change: 1 addition & 0 deletions changelog.d/16360.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Cache server ACL checking.
100 changes: 100 additions & 0 deletions rust/src/acl/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! An implementation of Matrix server ACL rules.

use crate::push::utils::{glob_to_regex, GlobMatchType};
use anyhow::Error;
use pyo3::prelude::*;
use regex::Regex;
use std::net::Ipv4Addr;
use std::str::FromStr;

/// Registers the `acl` submodule and its classes with the parent Python
/// module.
///
/// Called when registering modules with python.
pub fn register_module(py: Python<'_>, m: &PyModule) -> PyResult<()> {
    let acl_module = PyModule::new(py, "acl")?;
    acl_module.add_class::<ServerAclEvaluator>()?;
    m.add_submodule(acl_module)?;

    // Adding a submodule does not by itself make
    // `from synapse.synapse_rust import acl` work, so the module is also
    // inserted into `sys.modules` under its fully qualified name.
    let sys_modules = py.import("sys")?.getattr("modules")?;
    sys_modules.set_item("synapse.synapse_rust.acl", acl_module)?;

    Ok(())
}

/// Decides whether a server name is permitted by a server ACL.
///
/// The allow and deny globs are compiled to regexes once, at construction
/// time, so that each check only has to run the pre-built matchers.
#[derive(Debug, Clone)]
#[pyclass(frozen)]
pub struct ServerAclEvaluator {
    /// When `false`, server names that look like IP literals are rejected
    /// before either list is consulted.
    allow_ip_literals: bool,
    /// Compiled `allow` globs; a server must match at least one to be
    /// permitted.
    allow: Vec<Regex>,
    /// Compiled `deny` globs; a match rejects the server and is checked
    /// before the allow list.
    deny: Vec<Regex>,
}

#[pymethods]
impl ServerAclEvaluator {
#[new]
pub fn py_new(
allow_ip_literals: bool,
allow: Vec<String>,
deny: Vec<String>,
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
) -> Result<Self, Error> {
let allow = allow
.iter()
.map(|s| glob_to_regex(s, GlobMatchType::Whole).unwrap())
clokep marked this conversation as resolved.
Show resolved Hide resolved
.collect();
let deny = deny
.iter()
.map(|s| glob_to_regex(s, GlobMatchType::Whole).unwrap())
.collect();

Ok(ServerAclEvaluator {
allow_ip_literals,
allow,
deny,
})
}

pub fn server_matches_acl_event(&self, server_name: &str) -> bool {
// first of all, check if literal IPs are blocked, and if so, whether the
// server name is a literal IP
if !self.allow_ip_literals {
// check for ipv6 literals. These start with '['.
if server_name.starts_with("[") {
return false;
}
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This attempts to stay close to the Python code, maybe it would make more sense to do Ipv6Addr. 🤷


// check for ipv4 literals. We can just lift the routine from std::net.
if let Ok(_) = Ipv4Addr::from_str(server_name) {
return false;
}
}

// next, check the deny list
if self.deny.iter().any(|e| e.is_match(server_name)) {
return false;
}

// then the allow list.
if self.allow.iter().any(|e| e.is_match(server_name)) {
return true;
}

// everything else should be rejected.
false
}
}
2 changes: 2 additions & 0 deletions rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use lazy_static::lazy_static;
use pyo3::prelude::*;
use pyo3_log::ResetHandle;

pub mod acl;
pub mod push;

lazy_static! {
Expand Down Expand Up @@ -38,6 +39,7 @@ fn synapse_rust(py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(get_rust_file_digest, m)?)?;
m.add_function(wrap_pyfunction!(reset_logging_config, m)?)?;

acl::register_module(py, m)?;
push::register_module(py, m)?;

Ok(())
Expand Down
21 changes: 21 additions & 0 deletions stubs/synapse/synapse_rust/acl.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Copyright 2023 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List

# Typing stub for the Rust-implemented server ACL evaluator
# (see rust/src/acl/mod.rs).
class ServerAclEvaluator:
    # allow_ip_literals: when False, server names that look like IP literals
    #   are rejected before the lists are consulted.
    # allow / deny: glob patterns, compiled once at construction time and
    #   matched against the whole server name.
    def __init__(
        self, allow_ip_literals: bool, allow: List[str], deny: List[str]
    ) -> None: ...
    # Returns True if server_name is permitted: the deny list is checked
    # before the allow list, and a name matching neither is rejected.
    def server_matches_acl_event(self, server_name: str) -> bool: ...
7 changes: 5 additions & 2 deletions synapse/events/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
CANONICALJSON_MIN_INT,
validate_canonicaljson,
)
from synapse.federation.federation_server import server_matches_acl_event
from synapse.federation.federation_server import server_acl_evaluator_from_event
from synapse.http.servlet import validate_json_object
from synapse.rest.models import RequestBodyModel
from synapse.types import EventID, JsonDict, RoomID, StrCollection, UserID
Expand Down Expand Up @@ -100,7 +100,10 @@ def validate_new(self, event: EventBase, config: HomeServerConfig) -> None:
self._validate_retention(event)

elif event.type == EventTypes.ServerACL:
if not server_matches_acl_event(config.server.server_name, event):
server_acl_evaluator = server_acl_evaluator_from_event(event)
if not server_acl_evaluator.server_matches_acl_event(
config.server.server_name
):
raise SynapseError(
400, "Can't create an ACL event that denies the local server"
)
Expand Down
67 changes: 24 additions & 43 deletions synapse/federation/federation_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,8 @@
Union,
)

from matrix_common.regex import glob_to_regex
from prometheus_client import Counter, Gauge, Histogram

from twisted.internet.abstract import isIPAddress
from twisted.python import failure

from synapse.api.constants import (
Expand Down Expand Up @@ -86,9 +84,11 @@
from synapse.storage.databases.main.lock import Lock
from synapse.storage.databases.main.roommember import extract_heroes_from_room_summary
from synapse.storage.roommember import MemberSummary
from synapse.synapse_rust.acl import ServerAclEvaluator
from synapse.types import JsonDict, StateMap, get_domain_from_id
from synapse.util import unwrapFirstError
from synapse.util.async_helpers import Linearizer, concurrently_execute, gather_results
from synapse.util.caches.descriptors import cached
from synapse.util.caches.response_cache import ResponseCache
from synapse.util.stringutils import parse_server_name

Expand Down Expand Up @@ -1327,72 +1327,53 @@ async def check_server_matches_acl(self, server_name: str, room_id: str) -> None
acl_event = await self._storage_controllers.state.get_current_state_event(
room_id, EventTypes.ServerACL, ""
)
if not acl_event or server_matches_acl_event(server_name, acl_event):
return

raise AuthError(code=403, msg="Server is banned from room")
if acl_event:
server_acl_evaluator = await self._get_server_acl_evaluator(
acl_event.event_id, acl_event
)
if not server_acl_evaluator.server_matches_acl_event(server_name):
raise AuthError(code=403, msg="Server is banned from room")

@cached(uncached_args=("acl_event",))
def _get_server_acl_evaluator(
self, event_id: str, acl_event: EventBase
) -> ServerAclEvaluator:
clokep marked this conversation as resolved.
Show resolved Hide resolved
"""Create a ServerAclEvaluator from an event, but cached on the event ID."""
return server_acl_evaluator_from_event(acl_event)
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved

def server_matches_acl_event(server_name: str, acl_event: EventBase) -> bool:
"""Check if the given server is allowed by the ACL event

Args:
server_name: name of server, without any port part
acl_event: m.room.server_acl event
def server_acl_evaluator_from_event(acl_event: EventBase) -> "ServerAclEvaluator":
"""
Create a ServerAclEvaluator from a m.room.server_acl event's content.

Returns:
True if this server is allowed by the ACLs
This does up-front parsing of the content to ignore bad data and pre-compile
regular expressions.
"""
logger.debug("Checking %s against acl %s", server_name, acl_event.content)

# first of all, check if literal IPs are blocked, and if so, whether the
# server name is a literal IP
allow_ip_literals = acl_event.content.get("allow_ip_literals", True)
if not isinstance(allow_ip_literals, bool):
logger.warning("Ignoring non-bool allow_ip_literals flag")
allow_ip_literals = True
if not allow_ip_literals:
# check for ipv6 literals. These start with '['.
if server_name[0] == "[":
return False

# check for ipv4 literals. We can just lift the routine from twisted.
if isIPAddress(server_name):
return False

# next, check the deny list
deny = acl_event.content.get("deny", [])
if not isinstance(deny, (list, tuple)):
logger.warning("Ignoring non-list deny ACL %s", deny)
deny = []
for e in deny:
if _acl_entry_matches(server_name, e):
# logger.info("%s matched deny rule %s", server_name, e)
return False
else:
deny = [s for s in deny if isinstance(s, str)]

# then the allow list.
allow = acl_event.content.get("allow", [])
if not isinstance(allow, (list, tuple)):
logger.warning("Ignoring non-list allow ACL %s", allow)
allow = []
for e in allow:
if _acl_entry_matches(server_name, e):
# logger.info("%s matched allow rule %s", server_name, e)
return True

# everything else should be rejected.
# logger.info("%s fell through", server_name)
return False
else:
allow = [s for s in allow if isinstance(s, str)]


def _acl_entry_matches(server_name: str, acl_entry: Any) -> bool:
if not isinstance(acl_entry, str):
logger.warning(
"Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry)
)
return False
regex = glob_to_regex(acl_entry)
return bool(regex.match(server_name))
return ServerAclEvaluator(allow_ip_literals, allow, deny)


class FederationHandlerRegistry:
Expand Down
35 changes: 22 additions & 13 deletions tests/federation/test_federation_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.config.server import DEFAULT_ROOM_VERSION
from synapse.events import EventBase, make_event_from_dict
from synapse.federation.federation_server import server_matches_acl_event
from synapse.federation.federation_server import server_acl_evaluator_from_event
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
Expand Down Expand Up @@ -67,37 +67,46 @@ def test_blocked_server(self) -> None:
e = _create_acl_event({"allow": ["*"], "deny": ["evil.com"]})
logging.info("ACL event: %s", e.content)

self.assertFalse(server_matches_acl_event("evil.com", e))
self.assertFalse(server_matches_acl_event("EVIL.COM", e))
server_acl_evalutor = server_acl_evaluator_from_event(e)

self.assertTrue(server_matches_acl_event("evil.com.au", e))
self.assertTrue(server_matches_acl_event("honestly.not.evil.com", e))
self.assertFalse(server_acl_evalutor.server_matches_acl_event("evil.com"))
self.assertFalse(server_acl_evalutor.server_matches_acl_event("EVIL.COM"))

self.assertTrue(server_acl_evalutor.server_matches_acl_event("evil.com.au"))
self.assertTrue(
server_acl_evalutor.server_matches_acl_event("honestly.not.evil.com")
)

def test_block_ip_literals(self) -> None:
e = _create_acl_event({"allow_ip_literals": False, "allow": ["*"]})
logging.info("ACL event: %s", e.content)

self.assertFalse(server_matches_acl_event("1.2.3.4", e))
self.assertTrue(server_matches_acl_event("1a.2.3.4", e))
self.assertFalse(server_matches_acl_event("[1:2::]", e))
self.assertTrue(server_matches_acl_event("1:2:3:4", e))
server_acl_evalutor = server_acl_evaluator_from_event(e)

self.assertFalse(server_acl_evalutor.server_matches_acl_event("1.2.3.4"))
self.assertTrue(server_acl_evalutor.server_matches_acl_event("1a.2.3.4"))
self.assertFalse(server_acl_evalutor.server_matches_acl_event("[1:2::]"))
self.assertTrue(server_acl_evalutor.server_matches_acl_event("1:2:3:4"))

def test_wildcard_matching(self) -> None:
e = _create_acl_event({"allow": ["good*.com"]})

server_acl_evalutor = server_acl_evaluator_from_event(e)

self.assertTrue(
server_matches_acl_event("good.com", e),
server_acl_evalutor.server_matches_acl_event("good.com"),
"* matches 0 characters",
)
self.assertTrue(
server_matches_acl_event("GOOD.COM", e),
server_acl_evalutor.server_matches_acl_event("GOOD.COM"),
"pattern is case-insensitive",
)
self.assertTrue(
server_matches_acl_event("good.aa.com", e),
server_acl_evalutor.server_matches_acl_event("good.aa.com"),
"* matches several characters, including '.'",
)
self.assertFalse(
server_matches_acl_event("ishgood.com", e),
server_acl_evalutor.server_matches_acl_event("ishgood.com"),
"pattern does not allow prefixes",
)

Expand Down