Skip to content

Commit

Permalink
Use per-user feature flags for MSC3881
Browse files Browse the repository at this point in the history
  • Loading branch information
erikjohnston committed Jul 3, 2024
1 parent 4c795f8 commit 487f856
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 17 deletions.
29 changes: 18 additions & 11 deletions synapse/rest/client/pusher.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
)
from synapse.http.site import SynapseRequest
from synapse.push import PusherConfigException
from synapse.rest.admin.experimental_features import ExperimentalFeature
from synapse.rest.client._base import client_patterns
from synapse.rest.synapse.client.unsubscribe import UnsubscribeResource
from synapse.types import JsonDict
Expand All @@ -49,20 +50,22 @@ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self._msc3881_enabled = self.hs.config.experimental.msc3881_enabled
self._store = hs.get_datastores().main

async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
user = requester.user
user_id = requester.user.to_string()

pushers = await self.hs.get_datastores().main.get_pushers_by_user_id(
user.to_string()
msc3881_enabled = await self._store.is_feature_enabled(
user_id, ExperimentalFeature.MSC3881
)

pushers = await self.hs.get_datastores().main.get_pushers_by_user_id(user_id)

pusher_dicts = [p.as_dict() for p in pushers]

for pusher in pusher_dicts:
if self._msc3881_enabled:
if msc3881_enabled:
pusher["org.matrix.msc3881.enabled"] = pusher["enabled"]
pusher["org.matrix.msc3881.device_id"] = pusher["device_id"]
del pusher["enabled"]
Expand All @@ -80,11 +83,15 @@ def __init__(self, hs: "HomeServer"):
self.auth = hs.get_auth()
self.notifier = hs.get_notifier()
self.pusher_pool = self.hs.get_pusherpool()
self._msc3881_enabled = self.hs.config.experimental.msc3881_enabled
self._store = hs.get_datastores().main

async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
user = requester.user
user_id = requester.user.to_string()

msc3881_enabled = await self._store.is_feature_enabled(
user_id, ExperimentalFeature.MSC3881
)

content = parse_json_object_from_request(request)

Expand All @@ -95,7 +102,7 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
and content["kind"] is None
):
await self.pusher_pool.remove_pusher(
content["app_id"], content["pushkey"], user_id=user.to_string()
content["app_id"], content["pushkey"], user_id=user_id
)
return 200, {}

Expand All @@ -120,19 +127,19 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
append = content["append"]

enabled = True
if self._msc3881_enabled and "org.matrix.msc3881.enabled" in content:
if msc3881_enabled and "org.matrix.msc3881.enabled" in content:
enabled = content["org.matrix.msc3881.enabled"]

if not append:
await self.pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user(
app_id=content["app_id"],
pushkey=content["pushkey"],
not_user_id=user.to_string(),
not_user_id=user_id,
)

try:
await self.pusher_pool.add_or_update_pusher(
user_id=user.to_string(),
user_id=user_id,
kind=content["kind"],
app_id=content["app_id"],
app_display_name=content["app_display_name"],
Expand Down
17 changes: 12 additions & 5 deletions synapse/rest/client/versions.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@
import re
from typing import TYPE_CHECKING, Tuple

from twisted.web.server import Request

from synapse.api.constants import RoomCreationPreset
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet
from synapse.http.site import SynapseRequest
from synapse.rest.admin.experimental_features import ExperimentalFeature
from synapse.types import JsonDict

if TYPE_CHECKING:
Expand All @@ -46,6 +46,7 @@ def __init__(self, hs: "HomeServer"):
super().__init__()
self.config = hs.config
self.auth = hs.get_auth()
self.store = hs.get_datastores().main

# Calculate these once since they shouldn't change after start-up.
self.e2ee_forced_public = (
Expand All @@ -61,10 +62,16 @@ def __init__(self, hs: "HomeServer"):
in self.config.room.encryption_enabled_by_default_for_room_presets
)

async def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
requester = None
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
msc3881_enabled = self.config.experimental.msc3881_enabled

if self.auth.has_access_token(request):
requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()

msc3881_enabled = await self.store.is_feature_enabled(
user_id, ExperimentalFeature.MSC3881
)

return (
200,
Expand Down Expand Up @@ -129,7 +136,7 @@ async def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
# TODO: this is no longer needed once unstable MSC3882 does not need to be supported:
"org.matrix.msc3882": self.config.auth.login_via_existing_enabled,
# Adds support for remotely enabling/disabling pushers, as per MSC3881
"org.matrix.msc3881": self.config.experimental.msc3881_enabled,
"org.matrix.msc3881": msc3881_enabled,
# Adds support for filtering /messages by event relation.
"org.matrix.msc3874": self.config.experimental.msc3874_enabled,
# Adds support for simple HTTP rendezvous as per MSC3886
Expand Down
82 changes: 81 additions & 1 deletion tests/push/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@
import synapse.rest.admin
from synapse.logging.context import make_deferred_yieldable
from synapse.push import PusherConfig, PusherConfigException
from synapse.rest.client import login, push_rule, pusher, receipts, room
from synapse.rest.admin.experimental_features import ExperimentalFeature
from synapse.rest.client import login, push_rule, pusher, receipts, room, versions
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock
Expand All @@ -42,6 +43,7 @@ class HTTPPusherTests(HomeserverTestCase):
receipts.register_servlets,
push_rule.register_servlets,
pusher.register_servlets,
versions.register_servlets,
]
user_id = True
hijack_auth = False
Expand Down Expand Up @@ -969,6 +971,84 @@ def test_device_id(self) -> None:
lookup_result.device_id,
)

def test_device_id_feature_flag(self) -> None:
"""Tests that a pusher created with a given device ID shows that device ID in
GET /pushers requests when feature is enabled for the user
"""
user_id = self.register_user("user", "pass")
access_token = self.login("user", "pass")

# We create the pusher with an HTTP request rather than with
# _make_user_with_pusher so that we can test the device ID is correctly set when
# creating a pusher via an API call.
self.make_request(
method="POST",
path="/pushers/set",
content={
"kind": "http",
"app_id": "m.http",
"app_display_name": "HTTP Push Notifications",
"device_display_name": "pushy push",
"pushkey": "a@example.com",
"lang": "en",
"data": {"url": "http://example.com/_matrix/push/v1/notify"},
},
access_token=access_token,
)

# Look up the user info for the access token so we can compare the device ID.
store = self.hs.get_datastores().main
lookup_result = self.get_success(store.get_user_by_access_token(access_token))
assert lookup_result is not None

# Check field is not there before we enable the feature flag
channel = self.make_request("GET", "/pushers", access_token=access_token)
self.assertEqual(channel.code, 200)
self.assertEqual(len(channel.json_body["pushers"]), 1)
self.assertNotIn(
"org.matrix.msc3881.device_id", channel.json_body["pushers"][0]
)

self.get_success(
store.set_features_for_user(user_id, {ExperimentalFeature.MSC3881: True})
)

# Get the user's devices and check it has the correct device ID.
channel = self.make_request("GET", "/pushers", access_token=access_token)
self.assertEqual(channel.code, 200)
self.assertEqual(len(channel.json_body["pushers"]), 1)
self.assertEqual(
channel.json_body["pushers"][0]["org.matrix.msc3881.device_id"],
lookup_result.device_id,
)

def test_msc3881_client_versions_flag(self) -> None:
"""Tests that MSC3881 only appears in /versions if user has it enabled."""

user_id = self.register_user("user", "pass")
access_token = self.login("user", "pass")

# Check feature is disabled in /versions
channel = self.make_request(
"GET", "/_matrix/client/versions", access_token=access_token
)
self.assertEqual(channel.code, 200)
self.assertFalse(channel.json_body["unstable_features"]["org.matrix.msc3881"])

# Enable feature for user
self.get_success(
self.hs.get_datastores().main.set_features_for_user(
user_id, {ExperimentalFeature.MSC3881: True}
)
)

# Check feature is now enabled in /versions for user
channel = self.make_request(
"GET", "/_matrix/client/versions", access_token=access_token
)
self.assertEqual(channel.code, 200)
self.assertTrue(channel.json_body["unstable_features"]["org.matrix.msc3881"])

@override_config({"push": {"jitter_delay": "10s"}})
def test_jitter(self) -> None:
"""Tests that enabling jitter actually delays sending push."""
Expand Down

0 comments on commit 487f856

Please sign in to comment.