Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Finish up work to allow per-user feature flags #17392

Merged
merged 6 commits into from
Jul 5, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/17392.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Finish up work to allow per-user feature flags.
11 changes: 9 additions & 2 deletions synapse/rest/admin/experimental_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,24 @@
from synapse.types import JsonDict, UserID

if TYPE_CHECKING:
from synapse.server import HomeServer
from typing_extensions import assert_never

from synapse.server import HomeServer, HomeServerConfig


class ExperimentalFeature(str, Enum):
"""
Currently supported per-user features
"""

MSC3026 = "msc3026"
MSC3881 = "msc3881"

def is_globally_enabled(self, config: "HomeServerConfig") -> bool:
if self is ExperimentalFeature.MSC3881:
return config.experimental.msc3881_enabled

assert_never(self)


class ExperimentalFeaturesRestServlet(RestServlet):
"""
Expand Down
29 changes: 18 additions & 11 deletions synapse/rest/client/pusher.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
)
from synapse.http.site import SynapseRequest
from synapse.push import PusherConfigException
from synapse.rest.admin.experimental_features import ExperimentalFeature
from synapse.rest.client._base import client_patterns
from synapse.rest.synapse.client.unsubscribe import UnsubscribeResource
from synapse.types import JsonDict
Expand All @@ -49,20 +50,22 @@ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self._msc3881_enabled = self.hs.config.experimental.msc3881_enabled
self._store = hs.get_datastores().main

async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
user = requester.user
user_id = requester.user.to_string()

pushers = await self.hs.get_datastores().main.get_pushers_by_user_id(
user.to_string()
msc3881_enabled = await self._store.is_feature_enabled(
user_id, ExperimentalFeature.MSC3881
)

pushers = await self.hs.get_datastores().main.get_pushers_by_user_id(user_id)

pusher_dicts = [p.as_dict() for p in pushers]

for pusher in pusher_dicts:
if self._msc3881_enabled:
if msc3881_enabled:
pusher["org.matrix.msc3881.enabled"] = pusher["enabled"]
pusher["org.matrix.msc3881.device_id"] = pusher["device_id"]
del pusher["enabled"]
Expand All @@ -80,11 +83,15 @@ def __init__(self, hs: "HomeServer"):
self.auth = hs.get_auth()
self.notifier = hs.get_notifier()
self.pusher_pool = self.hs.get_pusherpool()
self._msc3881_enabled = self.hs.config.experimental.msc3881_enabled
self._store = hs.get_datastores().main

async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
user = requester.user
user_id = requester.user.to_string()

msc3881_enabled = await self._store.is_feature_enabled(
user_id, ExperimentalFeature.MSC3881
)

content = parse_json_object_from_request(request)

Expand All @@ -95,7 +102,7 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
and content["kind"] is None
):
await self.pusher_pool.remove_pusher(
content["app_id"], content["pushkey"], user_id=user.to_string()
content["app_id"], content["pushkey"], user_id=user_id
)
return 200, {}

Expand All @@ -120,19 +127,19 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
append = content["append"]

enabled = True
if self._msc3881_enabled and "org.matrix.msc3881.enabled" in content:
if msc3881_enabled and "org.matrix.msc3881.enabled" in content:
enabled = content["org.matrix.msc3881.enabled"]

if not append:
await self.pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user(
app_id=content["app_id"],
pushkey=content["pushkey"],
not_user_id=user.to_string(),
not_user_id=user_id,
)

try:
await self.pusher_pool.add_or_update_pusher(
user_id=user.to_string(),
user_id=user_id,
kind=content["kind"],
app_id=content["app_id"],
app_display_name=content["app_display_name"],
Expand Down
20 changes: 16 additions & 4 deletions synapse/rest/client/versions.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@
import re
from typing import TYPE_CHECKING, Tuple

from twisted.web.server import Request

from synapse.api.constants import RoomCreationPreset
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet
from synapse.http.site import SynapseRequest
from synapse.rest.admin.experimental_features import ExperimentalFeature
from synapse.types import JsonDict

if TYPE_CHECKING:
Expand All @@ -45,6 +45,8 @@ class VersionsRestServlet(RestServlet):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.config = hs.config
self.auth = hs.get_auth()
self.store = hs.get_datastores().main

# Calculate these once since they shouldn't change after start-up.
self.e2ee_forced_public = (
Expand All @@ -60,7 +62,17 @@ def __init__(self, hs: "HomeServer"):
in self.config.room.encryption_enabled_by_default_for_room_presets
)

def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
msc3881_enabled = self.config.experimental.msc3881_enabled

if self.auth.has_access_token(request):
requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()

msc3881_enabled = await self.store.is_feature_enabled(
user_id, ExperimentalFeature.MSC3881
)

return (
200,
{
Expand Down Expand Up @@ -124,7 +136,7 @@ def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
# TODO: this is no longer needed once unstable MSC3882 does not need to be supported:
"org.matrix.msc3882": self.config.auth.login_via_existing_enabled,
# Adds support for remotely enabling/disabling pushers, as per MSC3881
"org.matrix.msc3881": self.config.experimental.msc3881_enabled,
"org.matrix.msc3881": msc3881_enabled,
# Adds support for filtering /messages by event relation.
"org.matrix.msc3874": self.config.experimental.msc3874_enabled,
# Adds support for simple HTTP rendezvous as per MSC3886
Expand Down
64 changes: 55 additions & 9 deletions synapse/storage/databases/main/experimental_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,11 @@

from typing import TYPE_CHECKING, Dict, FrozenSet, List, Tuple, cast

from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.databases.main import CacheInvalidationWorkerStore
from synapse.util.caches.descriptors import cached

Expand Down Expand Up @@ -73,12 +77,54 @@ async def set_features_for_user(
features:
pairs of features and True/False for whether the feature should be enabled
"""
for feature, enabled in features.items():
await self.db_pool.simple_upsert(
table="per_user_experimental_features",
keyvalues={"feature": feature, "user_id": user},
values={"enabled": enabled},
insertion_values={"user_id": user, "feature": feature},
)

await self.invalidate_cache_and_stream("list_enabled_features", (user,))
def set_features_for_user_txn(txn: LoggingTransaction) -> None:
for feature, enabled in features.items():
self.db_pool.simple_upsert_txn(
txn,
table="per_user_experimental_features",
keyvalues={"feature": feature, "user_id": user},
values={"enabled": enabled},
insertion_values={"user_id": user, "feature": feature},
)

self._invalidate_cache_and_stream(
txn, self.is_feature_enabled, (user, feature)
)

self._invalidate_cache_and_stream(txn, self.list_enabled_features, (user,))

return await self.db_pool.runInteraction(
"set_features_for_user", set_features_for_user_txn
)

@cached()
async def is_feature_enabled(
self, user_id: str, feature: "ExperimentalFeature"
) -> bool:
"""
Checks to see if a given feature is enabled for the user
Args:
user_id: the user to be queried on
feature: the feature in question
Returns:
True if the feature is enabled, False if it is not or if the feature was
not found.
"""

if feature.is_globally_enabled(self.hs.config):
return True

# if it's not enabled globally, check if it is enabled per-user
res = await self.db_pool.simple_select_one_onecol(
table="per_user_experimental_features",
keyvalues={"user_id": user_id, "feature": feature},
retcol="enabled",
allow_none=True,
desc="get_feature_enabled",
)

# None and false are treated the same
db_enabled = bool(res)

return db_enabled
82 changes: 81 additions & 1 deletion tests/push/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@
import synapse.rest.admin
from synapse.logging.context import make_deferred_yieldable
from synapse.push import PusherConfig, PusherConfigException
from synapse.rest.client import login, push_rule, pusher, receipts, room
from synapse.rest.admin.experimental_features import ExperimentalFeature
from synapse.rest.client import login, push_rule, pusher, receipts, room, versions
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock
Expand All @@ -42,6 +43,7 @@ class HTTPPusherTests(HomeserverTestCase):
receipts.register_servlets,
push_rule.register_servlets,
pusher.register_servlets,
versions.register_servlets,
]
user_id = True
hijack_auth = False
Expand Down Expand Up @@ -969,6 +971,84 @@ def test_device_id(self) -> None:
lookup_result.device_id,
)

def test_device_id_feature_flag(self) -> None:
"""Tests that a pusher created with a given device ID shows that device ID in
GET /pushers requests when feature is enabled for the user
"""
user_id = self.register_user("user", "pass")
access_token = self.login("user", "pass")

# We create the pusher with an HTTP request rather than with
# _make_user_with_pusher so that we can test the device ID is correctly set when
# creating a pusher via an API call.
self.make_request(
method="POST",
path="/pushers/set",
content={
"kind": "http",
"app_id": "m.http",
"app_display_name": "HTTP Push Notifications",
"device_display_name": "pushy push",
"pushkey": "a@example.com",
"lang": "en",
"data": {"url": "http://example.com/_matrix/push/v1/notify"},
},
access_token=access_token,
)

# Look up the user info for the access token so we can compare the device ID.
store = self.hs.get_datastores().main
lookup_result = self.get_success(store.get_user_by_access_token(access_token))
assert lookup_result is not None

# Check field is not there before we enable the feature flag
channel = self.make_request("GET", "/pushers", access_token=access_token)
self.assertEqual(channel.code, 200)
self.assertEqual(len(channel.json_body["pushers"]), 1)
self.assertNotIn(
"org.matrix.msc3881.device_id", channel.json_body["pushers"][0]
)

self.get_success(
store.set_features_for_user(user_id, {ExperimentalFeature.MSC3881: True})
)

# Get the user's devices and check it has the correct device ID.
channel = self.make_request("GET", "/pushers", access_token=access_token)
self.assertEqual(channel.code, 200)
self.assertEqual(len(channel.json_body["pushers"]), 1)
self.assertEqual(
channel.json_body["pushers"][0]["org.matrix.msc3881.device_id"],
lookup_result.device_id,
)

def test_msc3881_client_versions_flag(self) -> None:
"""Tests that MSC3881 only appears in /versions if user has it enabled."""

user_id = self.register_user("user", "pass")
access_token = self.login("user", "pass")

# Check feature is disabled in /versions
channel = self.make_request(
"GET", "/_matrix/client/versions", access_token=access_token
)
self.assertEqual(channel.code, 200)
self.assertFalse(channel.json_body["unstable_features"]["org.matrix.msc3881"])

# Enable feature for user
self.get_success(
self.hs.get_datastores().main.set_features_for_user(
user_id, {ExperimentalFeature.MSC3881: True}
)
)

# Check feature is now enabled in /versions for user
channel = self.make_request(
"GET", "/_matrix/client/versions", access_token=access_token
)
self.assertEqual(channel.code, 200)
self.assertTrue(channel.json_body["unstable_features"]["org.matrix.msc3881"])

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could also have a test for the feature being globally enable - yet disabled for a specific user - if we want to preserve that behaviour.

@override_config({"push": {"jitter_delay": "10s"}})
def test_jitter(self) -> None:
"""Tests that enabling jitter actually delays sending push."""
Expand Down
14 changes: 3 additions & 11 deletions tests/rest/admin/test_admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ def test_enable_and_disable(self) -> None:
"PUT",
url,
content={
"features": {"msc3026": True, "msc3881": True},
"features": {"msc3881": True},
},
access_token=self.admin_user_tok,
)
Expand All @@ -399,10 +399,6 @@ def test_enable_and_disable(self) -> None:
access_token=self.admin_user_tok,
)
self.assertEqual(channel.code, 200)
self.assertEqual(
True,
channel.json_body["features"]["msc3026"],
)
self.assertEqual(
True,
channel.json_body["features"]["msc3881"],
Expand All @@ -413,7 +409,7 @@ def test_enable_and_disable(self) -> None:
channel = self.make_request(
"PUT",
url,
content={"features": {"msc3026": False}},
content={"features": {"msc3881": False}},
access_token=self.admin_user_tok,
)
self.assertEqual(channel.code, 200)
Expand All @@ -429,10 +425,6 @@ def test_enable_and_disable(self) -> None:
self.assertEqual(channel.code, 200)
self.assertEqual(
False,
channel.json_body["features"]["msc3026"],
)
self.assertEqual(
True,
channel.json_body["features"]["msc3881"],
)

Expand All @@ -441,7 +433,7 @@ def test_enable_and_disable(self) -> None:
channel = self.make_request(
"PUT",
url,
content={"features": {"msc3026": False}},
content={"features": {"msc3881": False}},
access_token=self.admin_user_tok,
)
self.assertEqual(channel.code, 200)
Expand Down
Loading