diff --git a/changelog.d/14856.misc b/changelog.d/14856.misc new file mode 100644 index 000000000000..3731d6cbf184 --- /dev/null +++ b/changelog.d/14856.misc @@ -0,0 +1 @@ +Fix `wait_for_stream_position` to correctly wait for the right instance to advance its token. diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 7263bb2796da..74cf2d036f4b 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -325,7 +325,7 @@ async def wait_for_stream_position( # anyway in that case we don't need to wait. return - current_position = self._streams[stream_name].current_token(self._instance_name) + current_position = self._streams[stream_name].current_token(instance_name) if position <= current_position: # We're already past the position return diff --git a/tests/replication/tcp/test_handler.py b/tests/replication/tcp/test_handler.py index 1e299d2d67ea..555922409d13 100644 --- a/tests/replication/tcp/test_handler.py +++ b/tests/replication/tcp/test_handler.py @@ -12,6 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from twisted.internet import defer + +from synapse.replication.tcp.commands import PositionCommand, RdataCommand + from tests.replication._base import BaseMultiWorkerStreamTestCase @@ -71,3 +75,77 @@ def test_non_background_worker_not_subscribed_to_user_ip(self) -> None: self.assertEqual( len(self._redis_server._subscribers_by_channel[b"test/USER_IP"]), 1 ) + + def test_wait_for_stream_position(self) -> None: + """Check that wait for stream position correctly waits for an update from the + correct instance. + """ + store = self.hs.get_datastores().main + cmd_handler = self.hs.get_replication_command_handler() + data_handler = self.hs.get_replication_data_handler() + + worker1 = self.make_worker_hs( + "synapse.app.generic_worker", + extra_config={ + "worker_name": "worker1", + "run_background_tasks_on": "worker1", + "redis": {"enabled": True}, + }, + ) + + cache_id_gen = worker1.get_datastores().main._cache_id_gen + assert cache_id_gen is not None + + self.replicate() + + # First, make sure the master knows that `worker1` exists. + initial_token = cache_id_gen.get_current_token() + cmd_handler.send_command( + PositionCommand("caches", "worker1", initial_token, initial_token) + ) + self.replicate() + + # Next send out a normal RDATA, and check that waiting for that stream + # ID returns immediately. + ctx = cache_id_gen.get_next() + next_token = self.get_success(ctx.__aenter__()) + self.get_success(ctx.__aexit__(None, None, None)) + + cmd_handler.send_command( + RdataCommand("caches", "worker1", next_token, ("func_name", [], 0)) + ) + self.replicate() + + self.get_success( + data_handler.wait_for_stream_position("worker1", "caches", next_token) + ) + + # `wait_for_stream_position` should only return once master receives an + # RDATA from the worker + ctx = cache_id_gen.get_next() + next_token = self.get_success(ctx.__aenter__()) + self.get_success(ctx.__aexit__(None, None, None)) + + d = defer.ensureDeferred( + data_handler.wait_for_stream_position("worker1", "caches", next_token) + ) + self.assertFalse(d.called) + + # ... updating the cache ID gen on the master still shouldn't cause the + # deferred to wake up. + ctx = store._cache_id_gen.get_next() + self.get_success(ctx.__aenter__()) + self.get_success(ctx.__aexit__(None, None, None)) + + d = defer.ensureDeferred( + data_handler.wait_for_stream_position("worker1", "caches", next_token) + ) + self.assertFalse(d.called) + + # ... but receiving the RDATA should + cmd_handler.send_command( + RdataCommand("caches", "worker1", next_token, ("func_name", [], 0)) + ) + self.replicate() + + self.assertTrue(d.called)