From 748bc87f1b3971fc6a768e6b4158b313f4857324 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 25 Mar 2021 17:01:58 +0000 Subject: [PATCH 01/10] Make ratelimiting funcs async and add requester. --- synapse/api/ratelimiting.py | 29 +++--- synapse/federation/federation_server.py | 4 +- synapse/handlers/_base.py | 7 +- synapse/handlers/auth.py | 24 ++--- synapse/handlers/devicemessage.py | 4 +- synapse/handlers/federation.py | 2 +- synapse/handlers/identity.py | 10 ++- synapse/handlers/register.py | 6 +- synapse/handlers/room_member.py | 19 ++-- synapse/replication/http/register.py | 2 +- synapse/rest/client/v1/login.py | 12 +-- synapse/rest/client/v2_alpha/account.py | 10 ++- synapse/rest/client/v2_alpha/register.py | 8 +- tests/api/test_ratelimiting.py | 110 ++++++++++++++--------- 14 files changed, 155 insertions(+), 92 deletions(-) diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index c3f07bc1a3e8..d7b94004ebb0 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -46,7 +46,7 @@ def __init__(self, clock: Clock, rate_hz: float, burst_count: int): OrderedDict() ) # type: OrderedDict[Hashable, Tuple[float, int, float]] - def can_requester_do_action( + async def can_requester_do_action( self, requester: Requester, rate_hz: Optional[float] = None, @@ -73,17 +73,18 @@ def can_requester_do_action( * The reactor timestamp for when the action can be performed next. -1 if rate_hz is less than or equal to zero """ - # Disable rate limiting of users belonging to any AS that is configured - # not to be rate limited in its registration file (rate_limited: true|false). - if requester.app_service and not requester.app_service.is_rate_limited(): - return True, -1.0 - - return self.can_do_action( - requester.user.to_string(), rate_hz, burst_count, update, _time_now_s + return await self.can_do_action( + requester, + requester.user.to_string(), + rate_hz, + burst_count, + update, + _time_now_s, ) - def can_do_action( + async def can_do_action( self, + requester: Optional[Requester], key: Hashable, rate_hz: Optional[float] = None, burst_count: Optional[int] = None, @@ -93,6 +94,8 @@ def can_do_action( """Can the entity (e.g. user or IP address) perform the action? Args: + requester: The requester that is doing the action, if any. Used to check for + ratelimit overrides. key: The key we should use when rate limiting. Can be a user ID (when sending events), an IP address, etc. rate_hz: The long term number of actions that can be performed in a second. @@ -175,8 +178,9 @@ def _prune_message_counts(self, time_now_s: int): else: del self.actions[key] - def ratelimit( + async def ratelimit( self, + requester: Optional[Requester], key: Hashable, rate_hz: Optional[float] = None, burst_count: Optional[int] = None, @@ -186,6 +190,8 @@ def ratelimit( """Checks if an action can be performed. If not, raises a LimitExceededError Args: + requester: The requester that is doing the action, if any. Used to check for + ratelimit overrides. key: An arbitrary key used to classify an action rate_hz: The long term number of actions that can be performed in a second. Overrides the value set during instantiation if set. @@ -201,7 +207,8 @@ def ratelimit( """ time_now_s = _time_now_s if _time_now_s is not None else self.clock.time() - allowed, time_allowed = self.can_do_action( + allowed, time_allowed = await self.can_do_action( + requester, key, rate_hz=rate_hz, burst_count=burst_count, diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index d84e362070d9..be1427aa6df9 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -930,7 +930,9 @@ async def on_edu(self, edu_type: str, origin: str, content: dict): # the limit, drop them. if ( edu_type == EduTypes.RoomKeyRequest - and not self._room_key_request_rate_limiter.can_do_action(origin) + and not await self._room_key_request_rate_limiter.can_do_action( + None, origin + ) ): return diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index aade2c4a3ad4..c86a616431f2 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -113,10 +113,13 @@ async def ratelimit(self, requester, update=True, is_admin_redaction=False): if is_admin_redaction and self.admin_redaction_ratelimiter: # If we have separate config for admin redactions, use a separate # ratelimiter as to not have user_ids clash - self.admin_redaction_ratelimiter.ratelimit(user_id, update=update) + await self.admin_redaction_ratelimiter.ratelimit( + requester, user_id, update=update + ) else: # Override rate and burst count per-user - self.request_ratelimiter.ratelimit( + await self.request_ratelimiter.ratelimit( + requester, user_id, rate_hz=messages_per_second, burst_count=burst_count, diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index d537ea813785..86c04d956ec3 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -352,7 +352,9 @@ async def validate_user_via_ui_auth( requester_user_id = requester.user.to_string() # Check if we should be ratelimited due to too many previous failed attempts - self._failed_uia_attempts_ratelimiter.ratelimit(requester_user_id, update=False) + await self._failed_uia_attempts_ratelimiter.ratelimit( + requester, requester_user_id, update=False + ) # build a list of supported flows supported_ui_auth_types = await self._get_available_ui_auth_types( @@ -373,7 +375,9 @@ def get_new_session_data() -> JsonDict: ) except LoginError: # Update the ratelimiter to say we failed (`can_do_action` doesn't raise). - self._failed_uia_attempts_ratelimiter.can_do_action(requester_user_id) + await self._failed_uia_attempts_ratelimiter.can_do_action( + requester, requester_user_id + ) raise # find the completed login type @@ -982,8 +986,8 @@ async def validate_login( # We also apply account rate limiting using the 3PID as a key, as # otherwise using 3PID bypasses the ratelimiting based on user ID. if ratelimit: - self._failed_login_attempts_ratelimiter.ratelimit( - (medium, address), update=False + await self._failed_login_attempts_ratelimiter.ratelimit( + None, (medium, address), update=False ) # Check for login providers that support 3pid login types @@ -1016,8 +1020,8 @@ async def validate_login( # this code path, which is fine as then the per-user ratelimit # will kick in below. if ratelimit: - self._failed_login_attempts_ratelimiter.can_do_action( - (medium, address) + await self._failed_login_attempts_ratelimiter.can_do_action( + None, (medium, address) ) raise LoginError(403, "", errcode=Codes.FORBIDDEN) @@ -1039,8 +1043,8 @@ async def validate_login( # Check if we've hit the failed ratelimit (but don't update it) if ratelimit: - self._failed_login_attempts_ratelimiter.ratelimit( - qualified_user_id.lower(), update=False + await self._failed_login_attempts_ratelimiter.ratelimit( + None, qualified_user_id.lower(), update=False ) try: @@ -1051,8 +1055,8 @@ async def validate_login( # exception and masking the LoginError. The actual ratelimiting # should have happened above. if ratelimit: - self._failed_login_attempts_ratelimiter.can_do_action( - qualified_user_id.lower() + await self._failed_login_attempts_ratelimiter.can_do_action( + None, qualified_user_id.lower() ) raise diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index eb547743be9f..9de20cac0140 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -191,8 +191,8 @@ async def send_device_message( if ( message_type == EduTypes.RoomKeyRequest and user_id != sender_user_id - and self._ratelimiter.can_do_action( - (sender_user_id, requester.device_id) + and await self._ratelimiter.can_do_action( + requester, (sender_user_id, requester.device_id) ) ): continue diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 598a66f74cf4..3ebee38ebe13 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1711,7 +1711,7 @@ async def on_invite_request( member_handler = self.hs.get_room_member_handler() # We don't rate limit based on room ID, as that should be done by # sending server. - member_handler.ratelimit_invite(None, event.state_key) + await member_handler.ratelimit_invite(None, None, event.state_key) # keep a record of the room version, if we don't yet know it. # (this may get overwritten if we later get a different room version in a diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 5f346f6d6d28..cf895587bf9b 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -71,7 +71,7 @@ def __init__(self, hs): burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count, ) - def ratelimit_request_token_requests( + async def ratelimit_request_token_requests( self, request: SynapseRequest, medium: str, @@ -85,8 +85,12 @@ def ratelimit_request_token_requests( address: The actual threepid ID, e.g. the phone number or email address """ - self._3pid_validation_ratelimiter_ip.ratelimit((medium, request.getClientIP())) - self._3pid_validation_ratelimiter_address.ratelimit((medium, address)) + await self._3pid_validation_ratelimiter_ip.ratelimit( + None, (medium, request.getClientIP()) + ) + await self._3pid_validation_ratelimiter_address.ratelimit( + None, (medium, address) + ) async def threepid_from_creds( self, id_server: str, creds: Dict[str, str] diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 0fc2bf15d520..9701b76d0f91 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -204,7 +204,7 @@ async def register_user( Raises: SynapseError if there was a problem registering. """ - self.check_registration_ratelimit(address) + await self.check_registration_ratelimit(address) result = await self.spam_checker.check_registration_for_spam( threepid, @@ -583,7 +583,7 @@ def check_user_id_not_appservice_exclusive( errcode=Codes.EXCLUSIVE, ) - def check_registration_ratelimit(self, address: Optional[str]) -> None: + async def check_registration_ratelimit(self, address: Optional[str]) -> None: """A simple helper method to check whether the registration rate limit has been hit for a given IP address @@ -597,7 +597,7 @@ def check_registration_ratelimit(self, address: Optional[str]) -> None: if not address: return - self.ratelimiter.ratelimit(address) + await self.ratelimiter.ratelimit(None, address) async def register_with_store( self, diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 4d20ed835764..036dfaeabde2 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -159,15 +159,20 @@ async def _user_left_room(self, target: UserID, room_id: str) -> None: async def forget(self, user: UserID, room_id: str) -> None: raise NotImplementedError() - def ratelimit_invite(self, room_id: Optional[str], invitee_user_id: str): + async def ratelimit_invite( + self, + requester: Optional[Requester], + room_id: Optional[str], + invitee_user_id: str, + ): """Ratelimit invites by room and by target user. If room ID is missing then we just rate limit by target user. """ if room_id: - self._invites_per_room_limiter.ratelimit(room_id) + await self._invites_per_room_limiter.ratelimit(requester, room_id) - self._invites_per_user_limiter.ratelimit(invitee_user_id) + await self._invites_per_user_limiter.ratelimit(requester, invitee_user_id) async def _local_membership_update( self, @@ -237,7 +242,9 @@ async def _local_membership_update( ( allowed, time_allowed, - ) = self._join_rate_limiter_local.can_requester_do_action(requester) + ) = await self._join_rate_limiter_local.can_requester_do_action( + requester + ) if not allowed: raise LimitExceededError( @@ -423,7 +430,7 @@ async def update_membership_locked( if ratelimit: # Don't ratelimit application services. if not requester.app_service or requester.app_service.is_rate_limited(): - self.ratelimit_invite(room_id, target_id) + await self.ratelimit_invite(requester, room_id, target_id) # block any attempts to invite the server notices mxid if target_id == self._server_notices_mxid: @@ -534,7 +541,7 @@ async def update_membership_locked( ( allowed, time_allowed, - ) = self._join_rate_limiter_remote.can_requester_do_action( + ) = await self._join_rate_limiter_remote.can_requester_do_action( requester, ) diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py index d005f3876717..73d747785420 100644 --- a/synapse/replication/http/register.py +++ b/synapse/replication/http/register.py @@ -77,7 +77,7 @@ async def _serialize_payload( async def _handle_request(self, request, user_id): content = parse_json_object_from_request(request) - self.registration_handler.check_registration_ratelimit(content["address"]) + await self.registration_handler.check_registration_ratelimit(content["address"]) await self.registration_handler.register_with_store( user_id=user_id, diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index e4c352f572a2..4020f1190831 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -141,20 +141,22 @@ async def on_POST(self, request: SynapseRequest): appservice = self.auth.get_appservice_by_req(request) if appservice.is_rate_limited(): - self._address_ratelimiter.ratelimit(request.getClientIP()) + await self._address_ratelimiter.ratelimit( + None, request.getClientIP() + ) result = await self._do_appservice_login(login_submission, appservice) elif self.jwt_enabled and ( login_submission["type"] == LoginRestServlet.JWT_TYPE or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED ): - self._address_ratelimiter.ratelimit(request.getClientIP()) + await self._address_ratelimiter.ratelimit(None, request.getClientIP()) result = await self._do_jwt_login(login_submission) elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE: - self._address_ratelimiter.ratelimit(request.getClientIP()) + await self._address_ratelimiter.ratelimit(None, request.getClientIP()) result = await self._do_token_login(login_submission) else: - self._address_ratelimiter.ratelimit(request.getClientIP()) + await self._address_ratelimiter.ratelimit(None, request.getClientIP()) result = await self._do_other_login(login_submission) except KeyError: raise SynapseError(400, "Missing JSON keys.") @@ -258,7 +260,7 @@ async def _complete_login( # too often. This happens here rather than before as we don't # necessarily know the user before now. if ratelimit: - self._account_ratelimiter.ratelimit(user_id.lower()) + await self._account_ratelimiter.ratelimit(None, user_id.lower()) if create_non_existent_users: canonical_uid = await self.auth_handler.check_user_exists(user_id) diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index c2ba790babde..411fb57c473d 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -103,7 +103,9 @@ async def on_POST(self, request): # Raise if the provided next_link value isn't valid assert_valid_next_link(self.hs, next_link) - self.identity_handler.ratelimit_request_token_requests(request, "email", email) + await self.identity_handler.ratelimit_request_token_requests( + request, "email", email + ) # The email will be sent to the stored address. # This avoids a potential account hijack by requesting a password reset to @@ -387,7 +389,9 @@ async def on_POST(self, request): Codes.THREEPID_DENIED, ) - self.identity_handler.ratelimit_request_token_requests(request, "email", email) + await self.identity_handler.ratelimit_request_token_requests( + request, "email", email + ) if next_link: # Raise if the provided next_link value isn't valid @@ -468,7 +472,7 @@ async def on_POST(self, request): Codes.THREEPID_DENIED, ) - self.identity_handler.ratelimit_request_token_requests( + await self.identity_handler.ratelimit_request_token_requests( request, "msisdn", msisdn ) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 8f68d8dfc8fa..c212da0cb29c 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -126,7 +126,9 @@ async def on_POST(self, request): Codes.THREEPID_DENIED, ) - self.identity_handler.ratelimit_request_token_requests(request, "email", email) + await self.identity_handler.ratelimit_request_token_requests( + request, "email", email + ) existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid( "email", email @@ -208,7 +210,7 @@ async def on_POST(self, request): Codes.THREEPID_DENIED, ) - self.identity_handler.ratelimit_request_token_requests( + await self.identity_handler.ratelimit_request_token_requests( request, "msisdn", msisdn ) @@ -406,7 +408,7 @@ async def on_POST(self, request): client_addr = request.getClientIP() - self.ratelimiter.ratelimit(client_addr, update=False) + await self.ratelimiter.ratelimit(None, client_addr, update=False) kind = b"user" if b"kind" in request.args: diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py index 483418192c4b..a701db6afbb2 100644 --- a/tests/api/test_ratelimiting.py +++ b/tests/api/test_ratelimiting.py @@ -5,38 +5,44 @@ from tests import unittest -class TestRatelimiter(unittest.TestCase): +class TestRatelimiter(unittest.HomeserverTestCase): def test_allowed_via_can_do_action(self): limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) - allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=0) + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action(None, key="test_id", _time_now_s=0) + ) self.assertTrue(allowed) self.assertEquals(10.0, time_allowed) - allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=5) + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action(None, key="test_id", _time_now_s=5) + ) self.assertFalse(allowed) self.assertEquals(10.0, time_allowed) - allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=10) + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action(None, key="test_id", _time_now_s=10) + ) self.assertTrue(allowed) self.assertEquals(20.0, time_allowed) def test_allowed_user_via_can_requester_do_action(self): user_requester = create_requester("@user:example.com") limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) - allowed, time_allowed = limiter.can_requester_do_action( - user_requester, _time_now_s=0 + allowed, time_allowed = self.get_success_or_raise( + limiter.can_requester_do_action(user_requester, _time_now_s=0) ) self.assertTrue(allowed) self.assertEquals(10.0, time_allowed) - allowed, time_allowed = limiter.can_requester_do_action( - user_requester, _time_now_s=5 + allowed, time_allowed = self.get_success_or_raise( + limiter.can_requester_do_action(user_requester, _time_now_s=5) ) self.assertFalse(allowed) self.assertEquals(10.0, time_allowed) - allowed, time_allowed = limiter.can_requester_do_action( - user_requester, _time_now_s=10 + allowed, time_allowed = self.get_success_or_raise( + limiter.can_requester_do_action(user_requester, _time_now_s=10) ) self.assertTrue(allowed) self.assertEquals(20.0, time_allowed) @@ -52,20 +58,20 @@ def test_allowed_appservice_ratelimited_via_can_requester_do_action(self): as_requester = create_requester("@user:example.com", app_service=appservice) limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) - allowed, time_allowed = limiter.can_requester_do_action( - as_requester, _time_now_s=0 + allowed, time_allowed = self.get_success_or_raise( + limiter.can_requester_do_action(as_requester, _time_now_s=0) ) self.assertTrue(allowed) self.assertEquals(10.0, time_allowed) - allowed, time_allowed = limiter.can_requester_do_action( - as_requester, _time_now_s=5 + allowed, time_allowed = self.get_success_or_raise( + limiter.can_requester_do_action(as_requester, _time_now_s=5) ) self.assertFalse(allowed) self.assertEquals(10.0, time_allowed) - allowed, time_allowed = limiter.can_requester_do_action( - as_requester, _time_now_s=10 + allowed, time_allowed = self.get_success_or_raise( + limiter.can_requester_do_action(as_requester, _time_now_s=10) ) self.assertTrue(allowed) self.assertEquals(20.0, time_allowed) @@ -81,20 +87,20 @@ def test_allowed_appservice_via_can_requester_do_action(self): as_requester = create_requester("@user:example.com", app_service=appservice) limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) - allowed, time_allowed = limiter.can_requester_do_action( - as_requester, _time_now_s=0 + allowed, time_allowed = self.get_success_or_raise( + limiter.can_requester_do_action(as_requester, _time_now_s=0) ) self.assertTrue(allowed) self.assertEquals(-1, time_allowed) - allowed, time_allowed = limiter.can_requester_do_action( - as_requester, _time_now_s=5 + allowed, time_allowed = self.get_success_or_raise( + limiter.can_requester_do_action(as_requester, _time_now_s=5) ) self.assertTrue(allowed) self.assertEquals(-1, time_allowed) - allowed, time_allowed = limiter.can_requester_do_action( - as_requester, _time_now_s=10 + allowed, time_allowed = self.get_success_or_raise( + limiter.can_requester_do_action(as_requester, _time_now_s=10) ) self.assertTrue(allowed) self.assertEquals(-1, time_allowed) @@ -103,15 +109,19 @@ def test_allowed_via_ratelimit(self): limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) # Shouldn't raise - limiter.ratelimit(key="test_id", _time_now_s=0) + self.get_success_or_raise(limiter.ratelimit(None, key="test_id", _time_now_s=0)) # Should raise with self.assertRaises(LimitExceededError) as context: - limiter.ratelimit(key="test_id", _time_now_s=5) + self.get_success_or_raise( + limiter.ratelimit(None, key="test_id", _time_now_s=5) + ) self.assertEqual(context.exception.retry_after_ms, 5000) # Shouldn't raise - limiter.ratelimit(key="test_id", _time_now_s=10) + self.get_success_or_raise( + limiter.ratelimit(None, key="test_id", _time_now_s=10) + ) def test_allowed_via_can_do_action_and_overriding_parameters(self): """Test that we can override options of can_do_action that would otherwise fail @@ -121,32 +131,38 @@ def test_allowed_via_can_do_action_and_overriding_parameters(self): limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) # First attempt should be allowed - allowed, time_allowed = limiter.can_do_action( - ("test_id",), - _time_now_s=0, + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action( + None, + ("test_id",), + _time_now_s=0, + ) ) self.assertTrue(allowed) self.assertEqual(10.0, time_allowed) # Second attempt, 1s later, will fail - allowed, time_allowed = limiter.can_do_action( - ("test_id",), - _time_now_s=1, + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action( + None, + ("test_id",), + _time_now_s=1, + ) ) self.assertFalse(allowed) self.assertEqual(10.0, time_allowed) # But, if we allow 10 actions/sec for this request, we should be allowed # to continue. - allowed, time_allowed = limiter.can_do_action( - ("test_id",), _time_now_s=1, rate_hz=10.0 + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action(None, ("test_id",), _time_now_s=1, rate_hz=10.0) ) self.assertTrue(allowed) self.assertEqual(1.1, time_allowed) # Similarly if we allow a burst of 10 actions - allowed, time_allowed = limiter.can_do_action( - ("test_id",), _time_now_s=1, burst_count=10 + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action(None, ("test_id",), _time_now_s=1, burst_count=10) ) self.assertTrue(allowed) self.assertEqual(1.0, time_allowed) @@ -159,26 +175,38 @@ def test_allowed_via_ratelimit_and_overriding_parameters(self): limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) # First attempt should be allowed - limiter.ratelimit(key=("test_id",), _time_now_s=0) + self.get_success_or_raise( + limiter.ratelimit(None, key=("test_id",), _time_now_s=0) + ) # Second attempt, 1s later, will fail with self.assertRaises(LimitExceededError) as context: - limiter.ratelimit(key=("test_id",), _time_now_s=1) + self.get_success_or_raise( + limiter.ratelimit(None, key=("test_id",), _time_now_s=1) + ) self.assertEqual(context.exception.retry_after_ms, 9000) # But, if we allow 10 actions/sec for this request, we should be allowed # to continue. - limiter.ratelimit(key=("test_id",), _time_now_s=1, rate_hz=10.0) + self.get_success_or_raise( + limiter.ratelimit(None, key=("test_id",), _time_now_s=1, rate_hz=10.0) + ) # Similarly if we allow a burst of 10 actions - limiter.ratelimit(key=("test_id",), _time_now_s=1, burst_count=10) + self.get_success_or_raise( + limiter.ratelimit(None, key=("test_id",), _time_now_s=1, burst_count=10) + ) def test_pruning(self): limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) - limiter.can_do_action(key="test_id_1", _time_now_s=0) + self.get_success_or_raise( + limiter.can_do_action(None, key="test_id_1", _time_now_s=0) + ) self.assertIn("test_id_1", limiter.actions) - limiter.can_do_action(key="test_id_2", _time_now_s=10) + self.get_success_or_raise( + limiter.can_do_action(None, key="test_id_2", _time_now_s=10) + ) self.assertNotIn("test_id_1", limiter.actions) From 5464f00768827f9a8b0ca9ab1c9c9186bf20f246 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 29 Mar 2021 15:14:47 +0100 Subject: [PATCH 02/10] Check requester for ratelimiting config --- synapse/api/ratelimiting.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index d7b94004ebb0..e7b5b0e70a86 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -112,6 +112,12 @@ async def can_do_action( * The reactor timestamp for when the action can be performed next. -1 if rate_hz is less than or equal to zero """ + if requester: + # Disable rate limiting of users belonging to any AS that is configured + # not to be rate limited in its registration file (rate_limited: true|false). + if requester.app_service and not requester.app_service.is_rate_limited(): + return True, -1.0 + # Override default values if set time_now_s = _time_now_s if _time_now_s is not None else self.clock.time() rate_hz = rate_hz if rate_hz is not None else self.rate_hz From e4e39302d0dda4cddcc6eef816221097cc5cb48e Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 26 Mar 2021 16:50:00 +0000 Subject: [PATCH 03/10] Pass DataStore to ratelimiter --- synapse/api/ratelimiting.py | 6 ++++- synapse/federation/federation_server.py | 1 + synapse/handlers/_base.py | 3 ++- synapse/handlers/auth.py | 2 ++ synapse/handlers/devicemessage.py | 1 + synapse/handlers/identity.py | 2 ++ synapse/handlers/room_member.py | 4 ++++ synapse/rest/client/v1/login.py | 2 ++ synapse/server.py | 1 + tests/api/test_ratelimiting.py | 32 ++++++++++++++++++------- 10 files changed, 44 insertions(+), 10 deletions(-) diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index e7b5b0e70a86..dd9d85f96c55 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -17,6 +17,7 @@ from typing import Hashable, Optional, Tuple from synapse.api.errors import LimitExceededError +from synapse.storage.databases.main import DataStore from synapse.types import Requester from synapse.util import Clock @@ -31,10 +32,13 @@ class Ratelimiter: burst_count: How many actions that can be performed before being limited. """ - def __init__(self, clock: Clock, rate_hz: float, burst_count: int): + def __init__( + self, store: DataStore, clock: Clock, rate_hz: float, burst_count: int + ): self.clock = clock self.rate_hz = rate_hz self.burst_count = burst_count + self.store = store # A ordered dictionary keeping track of actions, when they were last # performed and how often. Each entry is a mapping from a key of arbitrary type diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index be1427aa6df9..71cb120ef76b 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -870,6 +870,7 @@ def __init__(self, hs: "HomeServer"): # A rate limiter for incoming room key requests per origin. self._room_key_request_rate_limiter = Ratelimiter( + store=hs.get_datastore(), clock=self.clock, rate_hz=self.config.rc_key_requests.per_second, burst_count=self.config.rc_key_requests.burst_count, diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index c86a616431f2..b3351155c56b 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -49,7 +49,7 @@ def __init__(self, hs: "HomeServer"): # The rate_hz and burst_count are overridden on a per-user basis self.request_ratelimiter = Ratelimiter( - clock=self.clock, rate_hz=0, burst_count=0 + store=self.store, clock=self.clock, rate_hz=0, burst_count=0 ) self._rc_message = self.hs.config.rc_message @@ -57,6 +57,7 @@ def __init__(self, hs: "HomeServer"): # by the presence of rate limits in the config if self.hs.config.rc_admin_redaction: self.admin_redaction_ratelimiter = Ratelimiter( + store=self.store, clock=self.clock, rate_hz=self.hs.config.rc_admin_redaction.per_second, burst_count=self.hs.config.rc_admin_redaction.burst_count, diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 86c04d956ec3..40231a62755d 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -238,6 +238,7 @@ def __init__(self, hs: "HomeServer"): # Ratelimiter for failed auth during UIA. Uses same ratelimit config # as per `rc_login.failed_attempts`. self._failed_uia_attempts_ratelimiter = Ratelimiter( + store=self.store, clock=self.clock, rate_hz=self.hs.config.rc_login_failed_attempts.per_second, burst_count=self.hs.config.rc_login_failed_attempts.burst_count, @@ -248,6 +249,7 @@ def __init__(self, hs: "HomeServer"): # Ratelimitier for failed /login attempts self._failed_login_attempts_ratelimiter = Ratelimiter( + store=self.store, clock=hs.get_clock(), rate_hz=self.hs.config.rc_login_failed_attempts.per_second, burst_count=self.hs.config.rc_login_failed_attempts.burst_count, diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index 9de20cac0140..5ee48be6ffa7 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -81,6 +81,7 @@ def __init__(self, hs: "HomeServer"): ) self._ratelimiter = Ratelimiter( + store=self.store, clock=hs.get_clock(), rate_hz=hs.config.rc_key_requests.per_second, burst_count=hs.config.rc_key_requests.burst_count, diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index cf895587bf9b..d89fa5fb305d 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -61,11 +61,13 @@ def __init__(self, hs): # Ratelimiters for `/requestToken` endpoints. self._3pid_validation_ratelimiter_ip = Ratelimiter( + store=self.store, clock=hs.get_clock(), rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second, burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count, ) self._3pid_validation_ratelimiter_address = Ratelimiter( + store=self.store, clock=hs.get_clock(), rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second, burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count, diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 036dfaeabde2..c26429336119 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -75,22 +75,26 @@ def __init__(self, hs: "HomeServer"): self.allow_per_room_profiles = self.config.allow_per_room_profiles self._join_rate_limiter_local = Ratelimiter( + store=self.store, clock=self.clock, rate_hz=hs.config.ratelimiting.rc_joins_local.per_second, burst_count=hs.config.ratelimiting.rc_joins_local.burst_count, ) self._join_rate_limiter_remote = Ratelimiter( + store=self.store, clock=self.clock, rate_hz=hs.config.ratelimiting.rc_joins_remote.per_second, burst_count=hs.config.ratelimiting.rc_joins_remote.burst_count, ) self._invites_per_room_limiter = Ratelimiter( + store=self.store, clock=self.clock, rate_hz=hs.config.ratelimiting.rc_invites_per_room.per_second, burst_count=hs.config.ratelimiting.rc_invites_per_room.burst_count, ) self._invites_per_user_limiter = Ratelimiter( + store=self.store, clock=self.clock, rate_hz=hs.config.ratelimiting.rc_invites_per_user.per_second, burst_count=hs.config.ratelimiting.rc_invites_per_user.burst_count, diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 4020f1190831..3151e72d4f19 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -74,11 +74,13 @@ def __init__(self, hs: "HomeServer"): self._well_known_builder = WellKnownBuilder(hs) self._address_ratelimiter = Ratelimiter( + store=hs.get_datastore(), clock=hs.get_clock(), rate_hz=self.hs.config.rc_login_address.per_second, burst_count=self.hs.config.rc_login_address.burst_count, ) self._account_ratelimiter = Ratelimiter( + store=hs.get_datastore(), clock=hs.get_clock(), rate_hz=self.hs.config.rc_login_account.per_second, burst_count=self.hs.config.rc_login_account.burst_count, diff --git a/synapse/server.py b/synapse/server.py index 5e787e2281a8..c294fa4e21bc 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -329,6 +329,7 @@ def get_distributor(self) -> Distributor: @cache_in_self def get_registration_ratelimiter(self) -> Ratelimiter: return Ratelimiter( + store=self.get_datastore(), clock=self.get_clock(), rate_hz=self.config.rc_registration.per_second, burst_count=self.config.rc_registration.burst_count, diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py index a701db6afbb2..09e1afd44070 100644 --- a/tests/api/test_ratelimiting.py +++ b/tests/api/test_ratelimiting.py @@ -7,7 +7,9 @@ class TestRatelimiter(unittest.HomeserverTestCase): def test_allowed_via_can_do_action(self): - limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) + limiter = Ratelimiter( + store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + ) allowed, time_allowed = self.get_success_or_raise( limiter.can_do_action(None, key="test_id", _time_now_s=0) ) @@ -28,7 +30,9 @@ def test_allowed_via_can_do_action(self): def test_allowed_user_via_can_requester_do_action(self): user_requester = create_requester("@user:example.com") - limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) + limiter = Ratelimiter( + store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + ) allowed, time_allowed = self.get_success_or_raise( limiter.can_requester_do_action(user_requester, _time_now_s=0) ) @@ -57,7 +61,9 @@ def test_allowed_appservice_ratelimited_via_can_requester_do_action(self): ) as_requester = create_requester("@user:example.com", app_service=appservice) - limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) + limiter = Ratelimiter( + store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + ) allowed, time_allowed = self.get_success_or_raise( limiter.can_requester_do_action(as_requester, _time_now_s=0) ) @@ -86,7 +92,9 @@ def test_allowed_appservice_via_can_requester_do_action(self): ) as_requester = create_requester("@user:example.com", app_service=appservice) - limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) + limiter = Ratelimiter( + store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + ) allowed, time_allowed = self.get_success_or_raise( limiter.can_requester_do_action(as_requester, _time_now_s=0) ) @@ -106,7 +114,9 @@ def test_allowed_appservice_via_can_requester_do_action(self): self.assertEquals(-1, time_allowed) def test_allowed_via_ratelimit(self): - limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) + limiter = Ratelimiter( + store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + ) # Shouldn't raise self.get_success_or_raise(limiter.ratelimit(None, key="test_id", _time_now_s=0)) @@ -128,7 +138,9 @@ def test_allowed_via_can_do_action_and_overriding_parameters(self): an action """ # Create a Ratelimiter with a very low allowed rate_hz and burst_count - limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) + limiter = Ratelimiter( + store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + ) # First attempt should be allowed allowed, time_allowed = self.get_success_or_raise( @@ -172,7 +184,9 @@ def test_allowed_via_ratelimit_and_overriding_parameters(self): fail an action """ # Create a Ratelimiter with a very low allowed rate_hz and burst_count - limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) + limiter = Ratelimiter( + store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + ) # First attempt should be allowed self.get_success_or_raise( @@ -198,7 +212,9 @@ def test_allowed_via_ratelimit_and_overriding_parameters(self): ) def test_pruning(self): - limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) + limiter = Ratelimiter( + store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + ) self.get_success_or_raise( limiter.can_do_action(None, key="test_id_1", _time_now_s=0) ) From 56b0208f1cbc134bdc61db483c9ca29130de6b5e Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 29 Mar 2021 15:25:07 +0100 Subject: [PATCH 04/10] Check for ratelimiting override --- synapse/api/ratelimiting.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index dd9d85f96c55..edb7318221c2 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -122,6 +122,18 @@ async def can_do_action( if requester.app_service and not requester.app_service.is_rate_limited(): return True, -1.0 + # Check if ratelimiting has been disabled for the user. + # + # Note that we don't use the returned rate/burst count, as the table + # is specifically for the event sending ratelimiter. Instead, we + # only use it to (somewhat cheekily) infer whether the user should + # be subject to any rate limiting or not. + override = await self.store.get_ratelimit_for_user( + requester.authenticated_entity + ) + if override and not override.messages_per_second: + return True, -1.0 + # Override default values if set time_now_s = _time_now_s if _time_now_s is not None else self.clock.time() rate_hz = rate_hz if rate_hz is not None else self.rate_hz From 7c5a7d3dcf30d6c438c0b4dc9dc1ace30e602e53 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 29 Mar 2021 15:01:29 +0100 Subject: [PATCH 05/10] Remove redundant appservice ratelimit checks --- synapse/handlers/_base.py | 5 ----- synapse/handlers/room_member.py | 4 +--- 2 files changed, 1 insertion(+), 8 deletions(-) diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index b3351155c56b..b73b5951e096 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -92,11 +92,6 @@ async def ratelimit(self, requester, update=True, is_admin_redaction=False): if app_service is not None: return # do not ratelimit app service senders - # Disable rate limiting of users belonging to any AS that is configured - # not to be rate limited in its registration file (rate_limited: true|false). - if requester.app_service and not requester.app_service.is_rate_limited(): - return - messages_per_second = self._rc_message.per_second burst_count = self._rc_message.burst_count diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index c26429336119..9cac5d1e8167 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -432,9 +432,7 @@ async def update_membership_locked( if effective_membership_state == Membership.INVITE: target_id = target.to_string() if ratelimit: - # Don't ratelimit application services. - if not requester.app_service or requester.app_service.is_rate_limited(): - await self.ratelimit_invite(requester, room_id, target_id) + await self.ratelimit_invite(requester, room_id, target_id) # block any attempts to invite the server notices mxid if target_id == self._server_notices_mxid: From 5a71e501282bbdbab23f803a687abfc0b0fa1369 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 29 Mar 2021 15:30:43 +0100 Subject: [PATCH 06/10] Newsfile --- changelog.d/9711.bugfix | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/9711.bugfix diff --git a/changelog.d/9711.bugfix b/changelog.d/9711.bugfix new file mode 100644 index 000000000000..4ca3438d46b4 --- /dev/null +++ b/changelog.d/9711.bugfix @@ -0,0 +1 @@ +Fix recently added ratelimits to correctly honour the application service `rate_limited` flag. From 0da466b9d98de6a37d8a7ff17c110953c2f2c49a Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 30 Mar 2021 09:49:23 +0100 Subject: [PATCH 07/10] Docstrings --- synapse/api/ratelimiting.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index edb7318221c2..6474ee9edede 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -97,9 +97,14 @@ async def can_do_action( ) -> Tuple[bool, float]: """Can the entity (e.g. user or IP address) perform the action? + Checks if the user has ratelimiting disabled in the database by looking + for null/zero values in the `ratelimit_override` table. (Non-zero + values aren't honoured, as they're specific to the event sending + ratelimiter, rather than all ratelimiters) + Args: - requester: The requester that is doing the action, if any. Used to check for - ratelimit overrides. + requester: The requester that is doing the action, if any. Used to check + if the user has ratelimits disabled in the database. key: The key we should use when rate limiting. Can be a user ID (when sending events), an IP address, etc. rate_hz: The long term number of actions that can be performed in a second. @@ -211,9 +216,14 @@ async def ratelimit( ): """Checks if an action can be performed. If not, raises a LimitExceededError + Checks if the user has ratelimiting disabled in the database by looking + for null/zero values in the `ratelimit_override` table. (Non-zero + values aren't honoured, as they're specific to the event sending + ratelimiter, rather than all ratelimiters) + Args: requester: The requester that is doing the action, if any. Used to check for - ratelimit overrides. + if the user has ratelimits disabled. key: An arbitrary key used to classify an action rate_hz: The long term number of actions that can be performed in a second. Overrides the value set during instantiation if set. From d07a85e87a1c81ff430e4efb62a4397bdecd10e8 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 30 Mar 2021 09:54:24 +0100 Subject: [PATCH 08/10] Allow key to be None --- synapse/api/ratelimiting.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index 6474ee9edede..16fc499f4a0c 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -89,7 +89,7 @@ async def can_requester_do_action( async def can_do_action( self, requester: Optional[Requester], - key: Hashable, + key: Optional[Hashable] = None, rate_hz: Optional[float] = None, burst_count: Optional[int] = None, update: bool = True, @@ -105,8 +105,8 @@ async def can_do_action( Args: requester: The requester that is doing the action, if any. Used to check if the user has ratelimits disabled in the database. - key: The key we should use when rate limiting. Can be a user ID - (when sending events), an IP address, etc. + key: An arbitrary key used to classify an action. Defaults to the + requester's user ID. rate_hz: The long term number of actions that can be performed in a second. Overrides the value set during instantiation if set. burst_count: How many actions that can be performed before being limited. @@ -121,6 +121,12 @@ async def can_do_action( * The reactor timestamp for when the action can be performed next. -1 if rate_hz is less than or equal to zero """ + if key is None: + if not requester: + raise ValueError("Must supply at least one of `requester` or `key`") + + key = requester.user.to_string() + if requester: # Disable rate limiting of users belonging to any AS that is configured # not to be rate limited in its registration file (rate_limited: true|false). @@ -208,7 +214,7 @@ def _prune_message_counts(self, time_now_s: int): async def ratelimit( self, requester: Optional[Requester], - key: Hashable, + key: Optional[Hashable] = None, rate_hz: Optional[float] = None, burst_count: Optional[int] = None, update: bool = True, @@ -224,7 +230,8 @@ async def ratelimit( Args: requester: The requester that is doing the action, if any. Used to check for if the user has ratelimits disabled. - key: An arbitrary key used to classify an action + key: An arbitrary key used to classify an action. Defaults to the + requester's user ID. rate_hz: The long term number of actions that can be performed in a second. Overrides the value set during instantiation if set. burst_count: How many actions that can be performed before being limited. From c00d973b27faa47a3705f466b3e5ebae8c815baa Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 30 Mar 2021 10:00:26 +0100 Subject: [PATCH 09/10] Remove can_requester_do_action --- synapse/api/ratelimiting.py | 36 --------------------------------- synapse/handlers/_base.py | 5 +---- synapse/handlers/auth.py | 6 ++---- synapse/handlers/room_member.py | 6 ++---- tests/api/test_ratelimiting.py | 35 ++++++-------------------------- 5 files changed, 11 insertions(+), 77 deletions(-) diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index 16fc499f4a0c..2244b8a34062 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -50,42 +50,6 @@ def __init__( OrderedDict() ) # type: OrderedDict[Hashable, Tuple[float, int, float]] - async def can_requester_do_action( - self, - requester: Requester, - rate_hz: Optional[float] = None, - burst_count: Optional[int] = None, - update: bool = True, - _time_now_s: Optional[int] = None, - ) -> Tuple[bool, float]: - """Can the requester perform the action? - - Args: - requester: The requester to key off when rate limiting. The user property - will be used. - rate_hz: The long term number of actions that can be performed in a second. - Overrides the value set during instantiation if set. - burst_count: How many actions that can be performed before being limited. - Overrides the value set during instantiation if set. - update: Whether to count this check as performing the action - _time_now_s: The current time. Optional, defaults to the current time according - to self.clock. Only used by tests. - - Returns: - A tuple containing: - * A bool indicating if they can perform the action now - * The reactor timestamp for when the action can be performed next. - -1 if rate_hz is less than or equal to zero - """ - return await self.can_do_action( - requester, - requester.user.to_string(), - rate_hz, - burst_count, - update, - _time_now_s, - ) - async def can_do_action( self, requester: Optional[Requester], diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index b73b5951e096..fb899aa90d4d 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -109,14 +109,11 @@ async def ratelimit(self, requester, update=True, is_admin_redaction=False): if is_admin_redaction and self.admin_redaction_ratelimiter: # If we have separate config for admin redactions, use a separate # ratelimiter as to not have user_ids clash - await self.admin_redaction_ratelimiter.ratelimit( - requester, user_id, update=update - ) + await self.admin_redaction_ratelimiter.ratelimit(requester, update=update) else: # Override rate and burst count per-user await self.request_ratelimiter.ratelimit( requester, - user_id, rate_hz=messages_per_second, burst_count=burst_count, update=update, diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 40231a62755d..08e413bc98e0 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -354,9 +354,7 @@ async def validate_user_via_ui_auth( requester_user_id = requester.user.to_string() # Check if we should be ratelimited due to too many previous failed attempts - await self._failed_uia_attempts_ratelimiter.ratelimit( - requester, requester_user_id, update=False - ) + await self._failed_uia_attempts_ratelimiter.ratelimit(requester, update=False) # build a list of supported flows supported_ui_auth_types = await self._get_available_ui_auth_types( @@ -378,7 +376,7 @@ def get_new_session_data() -> JsonDict: except LoginError: # Update the ratelimiter to say we failed (`can_do_action` doesn't raise). await self._failed_uia_attempts_ratelimiter.can_do_action( - requester, requester_user_id + requester, ) raise diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 9cac5d1e8167..1cf12f3255bc 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -246,9 +246,7 @@ async def _local_membership_update( ( allowed, time_allowed, - ) = await self._join_rate_limiter_local.can_requester_do_action( - requester - ) + ) = await self._join_rate_limiter_local.can_do_action(requester) if not allowed: raise LimitExceededError( @@ -543,7 +541,7 @@ async def update_membership_locked( ( allowed, time_allowed, - ) = await self._join_rate_limiter_remote.can_requester_do_action( + ) = await self._join_rate_limiter_remote.can_do_action( requester, ) diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py index 09e1afd44070..4ecdfbc34b02 100644 --- a/tests/api/test_ratelimiting.py +++ b/tests/api/test_ratelimiting.py @@ -28,29 +28,6 @@ def test_allowed_via_can_do_action(self): self.assertTrue(allowed) self.assertEquals(20.0, time_allowed) - def test_allowed_user_via_can_requester_do_action(self): - user_requester = create_requester("@user:example.com") - limiter = Ratelimiter( - store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 - ) - allowed, time_allowed = self.get_success_or_raise( - limiter.can_requester_do_action(user_requester, _time_now_s=0) - ) - self.assertTrue(allowed) - self.assertEquals(10.0, time_allowed) - - allowed, time_allowed = self.get_success_or_raise( - limiter.can_requester_do_action(user_requester, _time_now_s=5) - ) - self.assertFalse(allowed) - self.assertEquals(10.0, time_allowed) - - allowed, time_allowed = self.get_success_or_raise( - limiter.can_requester_do_action(user_requester, _time_now_s=10) - ) - self.assertTrue(allowed) - self.assertEquals(20.0, time_allowed) - def test_allowed_appservice_ratelimited_via_can_requester_do_action(self): appservice = ApplicationService( None, @@ -65,19 +42,19 @@ def test_allowed_appservice_ratelimited_via_can_requester_do_action(self): store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 ) allowed, time_allowed = self.get_success_or_raise( - limiter.can_requester_do_action(as_requester, _time_now_s=0) + limiter.can_do_action(as_requester, _time_now_s=0) ) self.assertTrue(allowed) self.assertEquals(10.0, time_allowed) allowed, time_allowed = self.get_success_or_raise( - limiter.can_requester_do_action(as_requester, _time_now_s=5) + limiter.can_do_action(as_requester, _time_now_s=5) ) self.assertFalse(allowed) self.assertEquals(10.0, time_allowed) allowed, time_allowed = self.get_success_or_raise( - limiter.can_requester_do_action(as_requester, _time_now_s=10) + limiter.can_do_action(as_requester, _time_now_s=10) ) self.assertTrue(allowed) self.assertEquals(20.0, time_allowed) @@ -96,19 +73,19 @@ def test_allowed_appservice_via_can_requester_do_action(self): store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 ) allowed, time_allowed = self.get_success_or_raise( - limiter.can_requester_do_action(as_requester, _time_now_s=0) + limiter.can_do_action(as_requester, _time_now_s=0) ) self.assertTrue(allowed) self.assertEquals(-1, time_allowed) allowed, time_allowed = self.get_success_or_raise( - limiter.can_requester_do_action(as_requester, _time_now_s=5) + limiter.can_do_action(as_requester, _time_now_s=5) ) self.assertTrue(allowed) self.assertEquals(-1, time_allowed) allowed, time_allowed = self.get_success_or_raise( - limiter.can_requester_do_action(as_requester, _time_now_s=10) + limiter.can_do_action(as_requester, _time_now_s=10) ) self.assertTrue(allowed) self.assertEquals(-1, time_allowed) From 28e477233c9232de50f9da8465bec84791327ab0 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 30 Mar 2021 10:05:39 +0100 Subject: [PATCH 10/10] Add test for ratelimit override --- tests/api/test_ratelimiting.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py index 4ecdfbc34b02..fa96ba07a590 100644 --- a/tests/api/test_ratelimiting.py +++ b/tests/api/test_ratelimiting.py @@ -203,3 +203,30 @@ def test_pruning(self): ) self.assertNotIn("test_id_1", limiter.actions) + + def test_db_user_override(self): + """Test that users that have ratelimiting disabled in the DB aren't + ratelimited. + """ + store = self.hs.get_datastore() + + user_id = "@user:test" + requester = create_requester(user_id) + + self.get_success( + store.db_pool.simple_insert( + table="ratelimit_override", + values={ + "user_id": user_id, + "messages_per_second": None, + "burst_count": None, + }, + desc="test_db_user_override", + ) + ) + + limiter = Ratelimiter(store=store, clock=None, rate_hz=0.1, burst_count=1) + + # Shouldn't raise + for _ in range(20): + self.get_success_or_raise(limiter.ratelimit(requester, _time_now_s=0))