From 60725ad3a163167f0f26b276de4709edfd5079e9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filipe=20Caba=C3=A7o?= Date: Mon, 10 Feb 2025 16:44:23 +0000 Subject: [PATCH] fix: Broadcast Handler refactor * Skips message counting in case of bad auth * Simplifies approach and code in both Broadcast Handler and Realtime Channel * Adds more tests to Broadcast Handler use cases apart from rt_channel integration tests * Adds RateCounter stop to reduce test flakiness --- lib/realtime/rate_counter/rate_counter.ex | 15 ++ lib/realtime_web/channels/realtime_channel.ex | 28 +- .../realtime_channel/broadcast_handler.ex | 61 +++-- test/integration/rt_channel_test.exs | 42 +-- .../rate_counter/rate_counter_test.exs | 28 ++ .../tenants/replication_connection_test.exs | 1 + .../broadcast_handler_test.exs | 247 ++++++++++++++++++ 7 files changed, 350 insertions(+), 72 deletions(-) create mode 100644 test/realtime_web/channels/realtime_channel/broadcast_handler_test.exs diff --git a/lib/realtime/rate_counter/rate_counter.ex b/lib/realtime/rate_counter/rate_counter.ex index 958b8b516..0148feeb8 100644 --- a/lib/realtime/rate_counter/rate_counter.ex +++ b/lib/realtime/rate_counter/rate_counter.ex @@ -66,6 +66,21 @@ defmodule Realtime.RateCounter do ) end + @spec stop(term()) :: :ok + def stop(tenant_id) do + keys = + Registry.select(Realtime.Registry.Unique, [ + {{{:"$1", :_, {:_, :_, :"$2"}}, :"$3", :_}, [{:==, :"$1", __MODULE__}, {:==, :"$2", tenant_id}], [:"$_"]} + ]) + + Enum.each(keys, fn {{_, _, key}, {pid, _}} -> + GenServer.stop(pid) + Cachex.del!(@cache, key) + end) + + :ok + end + @doc """ Starts a new RateCounter under a DynamicSupervisor """ diff --git a/lib/realtime_web/channels/realtime_channel.ex b/lib/realtime_web/channels/realtime_channel.ex index 25adf3290..69aa1242b 100644 --- a/lib/realtime_web/channels/realtime_channel.ex +++ b/lib/realtime_web/channels/realtime_channel.ex @@ -49,7 +49,7 @@ defmodule RealtimeWeb.RealtimeChannel do |> assign_access_token(params) |> assign_counter() |> assign(:using_broadcast?, !!params["config"]["broadcast"]) - |> assign(:check_authorization?, !!params["config"]["private"]) + |> assign(:private?, !!params["config"]["private"]) |> assign(:policies, nil) start_db_rate_counter(tenant_id) @@ -63,13 +63,12 @@ defmodule RealtimeWeb.RealtimeChannel do {:ok, db_conn} <- Connect.lookup_or_start_connection(tenant_id), socket = assign_authorization_context(socket, sub_topic, access_token, claims), {:ok, socket} <- maybe_assign_policies(sub_topic, db_conn, socket) do - public? = !socket.assigns.check_authorization? - is_new_api = new_api?(params) - tenant_topic = Tenants.tenant_topic(tenant_id, sub_topic, public?) + tenant_topic = Tenants.tenant_topic(tenant_id, sub_topic, !socket.assigns.private?) Realtime.UsersCounter.add(transport_pid, tenant_id) RealtimeWeb.Endpoint.subscribe(tenant_topic) Phoenix.PubSub.subscribe(Realtime.PubSub, "realtime:operations:" <> tenant_id) + is_new_api = new_api?(params) pg_change_params = pg_change_params(is_new_api, params, channel_pid, claims, sub_topic) opts = %{ @@ -314,12 +313,7 @@ defmodule RealtimeWeb.RealtimeChannel do case confirm_token(socket) do {:ok, claims, confirm_token_ref, _, _} -> pg_change_params = Enum.map(pg_change_params, &Map.put(&1, :claims, claims)) - - {:noreply, - assign(socket, %{ - confirm_token_ref: confirm_token_ref, - pg_change_params: pg_change_params - })} + {:noreply, assign(socket, %{confirm_token_ref: confirm_token_ref, pg_change_params: pg_change_params})} {:error, :missing_claims} -> shutdown_response(socket, "Fields `role` and `exp` are required in JWT") @@ -570,12 +564,16 @@ defmodule RealtimeWeb.RealtimeChannel do access_token: access_token } = assigns + topic = Map.get(assigns, :topic) + db_conn = Map.get(assigns, :db_conn) + socket = Map.put(socket, :policies, nil) jwt_jwks = Map.get(assigns, :jwt_jwks) with jwt_secret_dec <- Crypto.decrypt!(jwt_secret), {:ok, %{"exp" => exp} = claims} when is_integer(exp) <- ChannelsAuthorization.authorize_conn(access_token, jwt_secret_dec, jwt_jwks), - exp_diff when exp_diff > 0 <- exp - Joken.current_time() do + exp_diff when exp_diff > 0 <- exp - Joken.current_time(), + {:ok, socket} <- maybe_assign_policies(topic, db_conn, socket) do if ref = assigns[:confirm_token_ref], do: Helpers.cancel_timer(ref) interval = min(@confirm_token_ms_interval, exp_diff * 1_000) @@ -733,9 +731,9 @@ defmodule RealtimeWeb.RealtimeChannel do defp maybe_assign_policies( topic, db_conn, - %{assigns: %{check_authorization?: true}} = socket + %{assigns: %{private?: true}} = socket ) - when not is_nil(topic) do + when not is_nil(topic) and not is_nil(db_conn) do %{using_broadcast?: using_broadcast?} = socket.assigns authorization_context = socket.assigns.authorization_context @@ -771,10 +769,10 @@ defmodule RealtimeWeb.RealtimeChannel do {:ok, assign(socket, policies: nil)} end - defp only_private?(tenant_id, %{assigns: %{check_authorization?: check_authorization?}}) do + defp only_private?(tenant_id, %{assigns: %{private?: private?}}) do tenant = Tenants.Cache.get_tenant_by_external_id(tenant_id) - if tenant.private_only and !check_authorization?, + if tenant.private_only and !private?, do: {:error, :private_only}, else: :ok end diff --git a/lib/realtime_web/channels/realtime_channel/broadcast_handler.ex b/lib/realtime_web/channels/realtime_channel/broadcast_handler.ex index 0b6653e4b..bf1facfc7 100644 --- a/lib/realtime_web/channels/realtime_channel/broadcast_handler.ex +++ b/lib/realtime_web/channels/realtime_channel/broadcast_handler.ex @@ -17,49 +17,54 @@ defmodule RealtimeWeb.RealtimeChannel.BroadcastHandler do @event_type "broadcast" @spec call(map(), Phoenix.Socket.t()) :: {:reply, :ok, Phoenix.Socket.t()} | {:noreply, Phoenix.Socket.t()} - def call( - payload, - %{ - assigns: %{ - is_new_api: true, - ack_broadcast: ack_broadcast, - self_broadcast: self_broadcast, - tenant_topic: tenant_topic, - authorization_context: authorization_context, - db_conn: db_conn - } - } = socket - ) do - with {:ok, %{assigns: %{policies: policies}}} <- - run_authorization_check(socket, db_conn, authorization_context) do - case policies do - %Policies{broadcast: %BroadcastPolicies{write: false}} -> - Logger.info("Broadcast message ignored on #{tenant_topic}") + def call(payload, %{assigns: %{private?: true}} = socket) do + %{ + assigns: %{ + self_broadcast: self_broadcast, + tenant_topic: tenant_topic, + authorization_context: authorization_context, + db_conn: db_conn + } + } = socket + + case run_authorization_check(socket, db_conn, authorization_context) do + {:ok, + %{assigns: %{ack_broadcast: ack_broadcast, policies: %Policies{broadcast: %BroadcastPolicies{write: true}}}} = + socket} -> + socket = increment_rate_counter(socket) + send_message(self_broadcast, tenant_topic, payload) + if ack_broadcast, do: {:reply, :ok, socket}, else: {:noreply, socket} + + {:ok, socket} -> + {:noreply, socket} - _ -> - send_message(self_broadcast, tenant_topic, payload) - end - else {:error, :increase_connection_pool} -> log_error("IncreaseConnectionPool", "Please increase your connection pool size") - {:error, :unable_to_set_policies} + {:noreply, socket} {:error, error} -> log_error("UnableToSetPolicies", error) - {:error, :unable_to_set_policies} + {:noreply, socket} end + end + + def call(payload, %{assigns: %{private?: false}} = socket) do + %{ + assigns: %{ + tenant_topic: tenant_topic, + self_broadcast: self_broadcast, + ack_broadcast: ack_broadcast + } + } = socket socket = increment_rate_counter(socket) + send_message(self_broadcast, tenant_topic, payload) if ack_broadcast, do: {:reply, :ok, socket}, else: {:noreply, socket} end - def call(_payload, socket) do - {:noreply, socket} - end - defp send_message(self_broadcast, tenant_topic, payload) do if self_broadcast, do: Endpoint.broadcast(tenant_topic, @event_type, payload), diff --git a/test/integration/rt_channel_test.exs b/test/integration/rt_channel_test.exs index 1f3270792..623d2f787 100644 --- a/test/integration/rt_channel_test.exs +++ b/test/integration/rt_channel_test.exs @@ -10,20 +10,20 @@ defmodule Realtime.Integration.RtChannelTest do require Logger - alias __MODULE__.Endpoint - alias Extensions.PostgresCdcRls, as: Rls + alias Extensions.PostgresCdcRls alias Phoenix.Socket.Message alias Phoenix.Socket.V1 - alias Postgrex, as: P + alias Postgrex alias Realtime.Api.Tenant alias Realtime.Database + alias Realtime.Integration.RtChannelTest.Endpoint alias Realtime.Integration.WebsocketClient + alias Realtime.RateCounter alias Realtime.Repo alias Realtime.Tenants - alias Realtime.Tenants.Cache alias Realtime.Tenants.Authorization + alias Realtime.Tenants.Cache alias Realtime.Tenants.Migrations - @moduletag :capture_log @port 4002 @serializer V1.JSONSerializer @@ -75,6 +75,7 @@ defmodule Realtime.Integration.RtChannelTest do end setup do + RateCounter.stop(@external_id) Cache.invalidate_tenant_cache(@external_id) Process.sleep(500) [tenant] = Tenant |> Repo.all() |> Repo.preload(:extensions) @@ -127,8 +128,8 @@ defmodule Realtime.Integration.RtChannelTest do }, 8000 - {:ok, _, conn} = Rls.get_manager_conn(@external_id) - P.query!(conn, "insert into test (details) values ('test')", []) + {:ok, _, conn} = PostgresCdcRls.get_manager_conn(@external_id) + Postgrex.query!(conn, "insert into test (details) values ('test')", []) assert_receive %Message{ event: "postgres_changes", @@ -152,7 +153,7 @@ defmodule Realtime.Integration.RtChannelTest do }, 500 - P.query!(conn, "update test set details = 'test' where id = #{id}", []) + Postgrex.query!(conn, "update test set details = 'test' where id = #{id}", []) assert_receive %Message{ event: "postgres_changes", @@ -177,7 +178,7 @@ defmodule Realtime.Integration.RtChannelTest do }, 500 - P.query!(conn, "delete from test where id = #{id}", []) + Postgrex.query!(conn, "delete from test where id = #{id}", []) assert_receive %Message{ event: "postgres_changes", @@ -216,30 +217,13 @@ defmodule Realtime.Integration.RtChannelTest do topic = "realtime:any" WebsocketClient.join(socket, topic, %{config: config}) - assert_receive %Message{ - event: "phx_reply", - payload: %{ - "response" => %{ - "postgres_changes" => [] - }, - "status" => "ok" - }, - ref: "1", - topic: ^topic - }, - 500 - + assert_receive %Message{event: "phx_reply", topic: ^topic}, 500 assert_receive %Message{} payload = %{"event" => "TEST", "payload" => %{"msg" => 1}, "type" => "broadcast"} WebsocketClient.send_event(socket, topic, "broadcast", payload) - assert_receive %Message{ - event: "broadcast", - payload: ^payload, - topic: ^topic - }, - 500 + assert_receive %Message{event: "broadcast", payload: ^payload, topic: ^topic}, 500 end @tag policies: [ @@ -1449,7 +1433,7 @@ defmodule Realtime.Integration.RtChannelTest do for _ <- 1..1000 do WebsocketClient.join(socket, realtime_topic, %{config: config}) - 1..10 |> Enum.random() |> Process.sleep() + 1..5 |> Enum.random() |> Process.sleep() end assert_receive %Message{ diff --git a/test/realtime/rate_counter/rate_counter_test.exs b/test/realtime/rate_counter/rate_counter_test.exs index 514f0c07b..0bad7c254 100644 --- a/test/realtime/rate_counter/rate_counter_test.exs +++ b/test/realtime/rate_counter/rate_counter_test.exs @@ -52,6 +52,34 @@ defmodule Realtime.RateCounterTest do end end + describe "stop/1" do + test "stops rate counters for a given entity" do + entity_id = Ecto.UUID.generate() + fake_terms = Enum.map(1..10, fn _ -> {:domain, :"metric_#{random_string()}", Ecto.UUID.generate()} end) + terms = Enum.map(1..10, fn _ -> {:domain, :"metric_#{random_string()}", entity_id} end) + + for term <- terms do + {:ok, _} = RateCounter.new(term) + assert {:ok, %RateCounter{}} = RateCounter.get(term) + end + + for term <- fake_terms do + {:ok, _} = RateCounter.new(term) + assert {:ok, %RateCounter{}} = RateCounter.get(term) + end + + assert :ok = RateCounter.stop(entity_id) + + for term <- terms do + assert {:error, _} = RateCounter.get(term) + end + + for term <- fake_terms do + assert {:ok, %RateCounter{}} = RateCounter.get(term) + end + end + end + test "handle handles counter shutdown and dies" do term = {:domain, :metric, Ecto.UUID.generate()} {:ok, pid} = RateCounter.new(term) diff --git a/test/realtime/tenants/replication_connection_test.exs b/test/realtime/tenants/replication_connection_test.exs index 79ccd2bf3..7ea887bff 100644 --- a/test/realtime/tenants/replication_connection_test.exs +++ b/test/realtime/tenants/replication_connection_test.exs @@ -12,6 +12,7 @@ defmodule Realtime.Tenants.ReplicationConnectionTest do alias Realtime.Tenants.Migrations setup do + Cleanup.ensure_no_replication_slot() slot = Application.get_env(:realtime, :slot_name_suffix) Application.put_env(:realtime, :slot_name_suffix, "test") start_supervised(Realtime.Tenants.CacheSupervisor) diff --git a/test/realtime_web/channels/realtime_channel/broadcast_handler_test.exs b/test/realtime_web/channels/realtime_channel/broadcast_handler_test.exs new file mode 100644 index 000000000..bd4b612d0 --- /dev/null +++ b/test/realtime_web/channels/realtime_channel/broadcast_handler_test.exs @@ -0,0 +1,247 @@ +defmodule RealtimeWeb.RealtimeChannel.BroadcastHandlerTest do + # async: false as we are using the database to test RLS policies + use Realtime.DataCase, async: false + import Generators + import Mock + + alias Realtime.GenCounter + alias Realtime.RateCounter + alias Realtime.RateCounter + alias Realtime.Tenants + alias Realtime.Tenants.Authorization + alias Realtime.Tenants.Authorization.Policies + alias Realtime.Tenants.Authorization.Policies.BroadcastPolicies + alias Realtime.Tenants.Connect + alias RealtimeWeb.Endpoint + alias RealtimeWeb.Joken.CurrentTime + alias RealtimeWeb.RealtimeChannel.BroadcastHandler + + setup [:initiate_tenant] + setup %{topic: topic}, do: Endpoint.subscribe("realtime:#{topic}") + + describe "call/2" do + test "with write true policy, user is able to send message", %{topic: topic, tenant: tenant, db_conn: db_conn} do + socket = socket_fixture(tenant, topic, db_conn, %Policies{broadcast: %BroadcastPolicies{write: true}}) + + for _ <- 1..100, reduce: socket do + socket -> + {:reply, :ok, socket} = BroadcastHandler.call(%{}, socket) + topic = "realtime:#{topic}" + assert_receive %Phoenix.Socket.Broadcast{topic: ^topic, event: "broadcast", payload: %{}} + socket + end + + Process.sleep(1200) + {:ok, %{avg: avg}} = RateCounter.get(Tenants.events_per_second_key(tenant)) + assert avg > 0 + end + + test "with write false policy, user is not able to send message", %{topic: topic, tenant: tenant, db_conn: db_conn} do + socket = socket_fixture(tenant, topic, db_conn, %Policies{broadcast: %BroadcastPolicies{write: false}}) + + for _ <- 1..100, reduce: socket do + socket -> + {:noreply, socket} = BroadcastHandler.call(%{}, socket) + topic = "realtime:#{topic}" + refute_receive %Phoenix.Socket.Broadcast{topic: ^topic, event: "broadcast", payload: %{}} + socket + end + + Process.sleep(1200) + {:ok, %{avg: avg}} = RateCounter.get(Tenants.events_per_second_key(tenant)) + assert avg == 0.0 + end + + @tag policies: [:authenticated_read_broadcast, :authenticated_write_broadcast] + test "with nil policy but valid user, is able to send message", %{ + topic: topic, + tenant: tenant, + db_conn: db_conn + } do + socket = socket_fixture(tenant, topic, db_conn) + + for _ <- 1..100, reduce: socket do + socket -> + {:reply, :ok, socket} = BroadcastHandler.call(%{}, socket) + topic = "realtime:#{topic}" + assert_receive %Phoenix.Socket.Broadcast{topic: ^topic, event: "broadcast", payload: %{}} + socket + end + + Process.sleep(1000) + {:ok, %{avg: avg}} = RateCounter.get(Tenants.events_per_second_key(tenant)) + assert avg > 0.0 + end + + test "with nil policy and invalid user, is not able to send message", %{ + topic: topic, + tenant: tenant, + db_conn: db_conn + } do + socket = socket_fixture(tenant, topic, db_conn) + + for _ <- 1..100, reduce: socket do + socket -> + {:noreply, socket} = BroadcastHandler.call(%{}, socket) + topic = "realtime:#{topic}" + refute_receive %Phoenix.Socket.Broadcast{topic: ^topic, event: "broadcast", payload: %{}} + socket + end + + Process.sleep(1200) + {:ok, %{avg: avg}} = RateCounter.get(Tenants.events_per_second_key(tenant)) + assert avg == 0.0 + end + + @tag policies: [:authenticated_read_broadcast, :authenticated_write_broadcast] + + test "validation only runs once on nil and valid policies", %{ + topic: topic, + tenant: tenant, + db_conn: db_conn + } do + socket = socket_fixture(tenant, topic, db_conn) + + with_mock Authorization, [:passthrough], [] do + for _ <- 1..100, reduce: socket do + socket -> + {:reply, :ok, socket} = BroadcastHandler.call(%{}, socket) + topic = "realtime:#{topic}" + assert_receive %Phoenix.Socket.Broadcast{topic: ^topic, event: "broadcast", payload: %{}} + socket + end + + assert_called_exactly(Authorization.get_write_authorizations(:_, :_, :_), 1) + end + end + + test "validation only runs once on nil and blocking policies", %{ + topic: topic, + tenant: tenant, + db_conn: db_conn + } do + socket = socket_fixture(tenant, topic, db_conn) + + with_mock Authorization, [:passthrough], [] do + for _ <- 1..100, reduce: socket do + socket -> + {:noreply, socket} = BroadcastHandler.call(%{}, socket) + topic = "realtime:#{topic}" + refute_receive %Phoenix.Socket.Broadcast{topic: ^topic, event: "broadcast", payload: %{}} + socket + end + + assert_called_exactly(Authorization.get_write_authorizations(:_, :_, :_), 1) + end + end + + test "no ack still sends message", %{ + topic: topic, + tenant: tenant, + db_conn: db_conn + } do + socket = socket_fixture(tenant, topic, db_conn, %Policies{broadcast: %BroadcastPolicies{write: true}}, false) + + for _ <- 1..100, reduce: socket do + socket -> + {:noreply, socket} = BroadcastHandler.call(%{}, socket) + topic = "realtime:#{topic}" + assert_receive %Phoenix.Socket.Broadcast{topic: ^topic, event: "broadcast", payload: %{}} + socket + end + end + + test "public channels are able to send messages", %{topic: topic, tenant: tenant, db_conn: db_conn} do + socket = socket_fixture(tenant, topic, db_conn, nil, false, false) + + for _ <- 1..100, reduce: socket do + socket -> + {:noreply, socket} = BroadcastHandler.call(%{}, socket) + topic = "realtime:#{topic}" + assert_receive %Phoenix.Socket.Broadcast{topic: ^topic, event: "broadcast", payload: %{}} + socket + end + + Process.sleep(1000) + {:ok, %{avg: avg}} = RateCounter.get(Tenants.events_per_second_key(tenant)) + assert avg > 0.0 + end + + test "public channels are able to send messages and ack", %{topic: topic, tenant: tenant, db_conn: db_conn} do + socket = socket_fixture(tenant, topic, db_conn, nil, true, false) + + for _ <- 1..100, reduce: socket do + socket -> + {:reply, :ok, socket} = BroadcastHandler.call(%{}, socket) + topic = "realtime:#{topic}" + assert_receive %Phoenix.Socket.Broadcast{topic: ^topic, event: "broadcast", payload: %{}} + socket + end + + Process.sleep(1000) + {:ok, %{avg: avg}} = RateCounter.get(Tenants.events_per_second_key(tenant)) + assert avg > 0.0 + end + end + + defp initiate_tenant(context) do + start_supervised(Realtime.GenCounter) + start_supervised(Realtime.RateCounter) + start_supervised(CurrentTime.Mock) + + tenant = tenant_fixture() + {:ok, db_conn} = Connect.lookup_or_start_connection(tenant.external_id) + topic = random_string() + + if policies = context[:policies] do + create_rls_policies(db_conn, policies, %{topic: topic}) + end + + on_exit(fn -> Connect.shutdown(tenant.external_id) end) + {:ok, tenant: tenant, db_conn: db_conn, topic: topic} + end + + defp socket_fixture( + tenant, + topic, + db_conn, + policies \\ %Policies{broadcast: %BroadcastPolicies{write: nil, read: true}}, + ack_broadcast \\ true, + private? \\ true + ) do + claims = %{sub: random_string(), role: "authenticated", exp: Joken.current_time() + 1_000} + signer = Joken.Signer.create("HS256", "secret") + + jwt = Joken.generate_and_sign!(%{}, claims, signer) + + authorization_context = + Authorization.build_authorization_params(%{ + topic: topic, + jwt: jwt, + claims: claims, + headers: [{"header-1", "value-1"}], + role: claims.role + }) + + key = Tenants.events_per_second_key(tenant) + GenCounter.new(key) + RateCounter.new(key) + {:ok, rate_counter} = RateCounter.get(key) + + tenant_topic = "realtime:#{topic}" + self_broadcast = true + + %Phoenix.Socket{ + assigns: %{ + tenant_topic: tenant_topic, + ack_broadcast: ack_broadcast, + self_broadcast: self_broadcast, + policies: policies, + authorization_context: authorization_context, + rate_counter: rate_counter, + private?: private?, + db_conn: db_conn + } + } + end +end