Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Broadcast Handler refactor #1293

Merged
merged 1 commit into from
Feb 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions lib/realtime/rate_counter/rate_counter.ex
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,21 @@ defmodule Realtime.RateCounter do
)
end

@spec stop(term()) :: :ok
def stop(tenant_id) do
keys =
Registry.select(Realtime.Registry.Unique, [
{{{:"$1", :_, {:_, :_, :"$2"}}, :"$3", :_}, [{:==, :"$1", __MODULE__}, {:==, :"$2", tenant_id}], [:"$_"]}
])

Enum.each(keys, fn {{_, _, key}, {pid, _}} ->
GenServer.stop(pid)
Cachex.del!(@cache, key)
end)

:ok
end

@doc """
Starts a new RateCounter under a DynamicSupervisor
"""
Expand Down
28 changes: 13 additions & 15 deletions lib/realtime_web/channels/realtime_channel.ex
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ defmodule RealtimeWeb.RealtimeChannel do
|> assign_access_token(params)
|> assign_counter()
|> assign(:using_broadcast?, !!params["config"]["broadcast"])
|> assign(:check_authorization?, !!params["config"]["private"])
|> assign(:private?, !!params["config"]["private"])
|> assign(:policies, nil)

start_db_rate_counter(tenant_id)
Expand All @@ -63,13 +63,12 @@ defmodule RealtimeWeb.RealtimeChannel do
{:ok, db_conn} <- Connect.lookup_or_start_connection(tenant_id),
socket = assign_authorization_context(socket, sub_topic, access_token, claims),
{:ok, socket} <- maybe_assign_policies(sub_topic, db_conn, socket) do
public? = !socket.assigns.check_authorization?
is_new_api = new_api?(params)
tenant_topic = Tenants.tenant_topic(tenant_id, sub_topic, public?)
tenant_topic = Tenants.tenant_topic(tenant_id, sub_topic, !socket.assigns.private?)
Realtime.UsersCounter.add(transport_pid, tenant_id)
RealtimeWeb.Endpoint.subscribe(tenant_topic)
Phoenix.PubSub.subscribe(Realtime.PubSub, "realtime:operations:" <> tenant_id)

is_new_api = new_api?(params)
pg_change_params = pg_change_params(is_new_api, params, channel_pid, claims, sub_topic)

opts = %{
Expand Down Expand Up @@ -314,12 +313,7 @@ defmodule RealtimeWeb.RealtimeChannel do
case confirm_token(socket) do
{:ok, claims, confirm_token_ref, _, _} ->
pg_change_params = Enum.map(pg_change_params, &Map.put(&1, :claims, claims))

{:noreply,
assign(socket, %{
confirm_token_ref: confirm_token_ref,
pg_change_params: pg_change_params
})}
{:noreply, assign(socket, %{confirm_token_ref: confirm_token_ref, pg_change_params: pg_change_params})}

{:error, :missing_claims} ->
shutdown_response(socket, "Fields `role` and `exp` are required in JWT")
Expand Down Expand Up @@ -570,12 +564,16 @@ defmodule RealtimeWeb.RealtimeChannel do
access_token: access_token
} = assigns

topic = Map.get(assigns, :topic)
db_conn = Map.get(assigns, :db_conn)
socket = Map.put(socket, :policies, nil)
jwt_jwks = Map.get(assigns, :jwt_jwks)

with jwt_secret_dec <- Crypto.decrypt!(jwt_secret),
{:ok, %{"exp" => exp} = claims} when is_integer(exp) <-
ChannelsAuthorization.authorize_conn(access_token, jwt_secret_dec, jwt_jwks),
exp_diff when exp_diff > 0 <- exp - Joken.current_time() do
exp_diff when exp_diff > 0 <- exp - Joken.current_time(),
{:ok, socket} <- maybe_assign_policies(topic, db_conn, socket) do
if ref = assigns[:confirm_token_ref], do: Helpers.cancel_timer(ref)

interval = min(@confirm_token_ms_interval, exp_diff * 1_000)
Expand Down Expand Up @@ -733,9 +731,9 @@ defmodule RealtimeWeb.RealtimeChannel do
defp maybe_assign_policies(
topic,
db_conn,
%{assigns: %{check_authorization?: true}} = socket
%{assigns: %{private?: true}} = socket
)
when not is_nil(topic) do
when not is_nil(topic) and not is_nil(db_conn) do
%{using_broadcast?: using_broadcast?} = socket.assigns

authorization_context = socket.assigns.authorization_context
Expand Down Expand Up @@ -771,10 +769,10 @@ defmodule RealtimeWeb.RealtimeChannel do
{:ok, assign(socket, policies: nil)}
end

defp only_private?(tenant_id, %{assigns: %{check_authorization?: check_authorization?}}) do
defp only_private?(tenant_id, %{assigns: %{private?: private?}}) do
tenant = Tenants.Cache.get_tenant_by_external_id(tenant_id)

if tenant.private_only and !check_authorization?,
if tenant.private_only and !private?,
do: {:error, :private_only},
else: :ok
end
Expand Down
61 changes: 33 additions & 28 deletions lib/realtime_web/channels/realtime_channel/broadcast_handler.ex
Original file line number Diff line number Diff line change
Expand Up @@ -17,49 +17,54 @@ defmodule RealtimeWeb.RealtimeChannel.BroadcastHandler do
@event_type "broadcast"
@spec call(map(), Phoenix.Socket.t()) ::
{:reply, :ok, Phoenix.Socket.t()} | {:noreply, Phoenix.Socket.t()}
def call(
payload,
%{
assigns: %{
is_new_api: true,
ack_broadcast: ack_broadcast,
self_broadcast: self_broadcast,
tenant_topic: tenant_topic,
authorization_context: authorization_context,
db_conn: db_conn
}
} = socket
) do
with {:ok, %{assigns: %{policies: policies}}} <-
run_authorization_check(socket, db_conn, authorization_context) do
case policies do
%Policies{broadcast: %BroadcastPolicies{write: false}} ->
Logger.info("Broadcast message ignored on #{tenant_topic}")
def call(payload, %{assigns: %{private?: true}} = socket) do
%{
assigns: %{
self_broadcast: self_broadcast,
tenant_topic: tenant_topic,
authorization_context: authorization_context,
db_conn: db_conn
}
} = socket

case run_authorization_check(socket, db_conn, authorization_context) do
{:ok,
%{assigns: %{ack_broadcast: ack_broadcast, policies: %Policies{broadcast: %BroadcastPolicies{write: true}}}} =
socket} ->
socket = increment_rate_counter(socket)
send_message(self_broadcast, tenant_topic, payload)
if ack_broadcast, do: {:reply, :ok, socket}, else: {:noreply, socket}

{:ok, socket} ->
{:noreply, socket}

_ ->
send_message(self_broadcast, tenant_topic, payload)
end
else
{:error, :increase_connection_pool} ->
log_error("IncreaseConnectionPool", "Please increase your connection pool size")
{:error, :unable_to_set_policies}
{:noreply, socket}

{:error, error} ->
log_error("UnableToSetPolicies", error)
{:error, :unable_to_set_policies}
{:noreply, socket}
end
end

def call(payload, %{assigns: %{private?: false}} = socket) do
%{
assigns: %{
tenant_topic: tenant_topic,
self_broadcast: self_broadcast,
ack_broadcast: ack_broadcast
}
} = socket

socket = increment_rate_counter(socket)
send_message(self_broadcast, tenant_topic, payload)

if ack_broadcast,
do: {:reply, :ok, socket},
else: {:noreply, socket}
end

def call(_payload, socket) do
{:noreply, socket}
end

defp send_message(self_broadcast, tenant_topic, payload) do
if self_broadcast,
do: Endpoint.broadcast(tenant_topic, @event_type, payload),
Expand Down
42 changes: 13 additions & 29 deletions test/integration/rt_channel_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,20 @@ defmodule Realtime.Integration.RtChannelTest do

require Logger

alias __MODULE__.Endpoint
alias Extensions.PostgresCdcRls, as: Rls
alias Extensions.PostgresCdcRls
alias Phoenix.Socket.Message
alias Phoenix.Socket.V1
alias Postgrex, as: P
alias Postgrex
alias Realtime.Api.Tenant
alias Realtime.Database
alias Realtime.Integration.RtChannelTest.Endpoint
alias Realtime.Integration.WebsocketClient
alias Realtime.RateCounter
alias Realtime.Repo
alias Realtime.Tenants
alias Realtime.Tenants.Cache
alias Realtime.Tenants.Authorization
alias Realtime.Tenants.Cache
alias Realtime.Tenants.Migrations

@moduletag :capture_log
@port 4002
@serializer V1.JSONSerializer
Expand Down Expand Up @@ -75,6 +75,7 @@ defmodule Realtime.Integration.RtChannelTest do
end

setup do
RateCounter.stop(@external_id)
Cache.invalidate_tenant_cache(@external_id)
Process.sleep(500)
[tenant] = Tenant |> Repo.all() |> Repo.preload(:extensions)
Expand Down Expand Up @@ -127,8 +128,8 @@ defmodule Realtime.Integration.RtChannelTest do
},
8000

{:ok, _, conn} = Rls.get_manager_conn(@external_id)
P.query!(conn, "insert into test (details) values ('test')", [])
{:ok, _, conn} = PostgresCdcRls.get_manager_conn(@external_id)
Postgrex.query!(conn, "insert into test (details) values ('test')", [])

assert_receive %Message{
event: "postgres_changes",
Expand All @@ -152,7 +153,7 @@ defmodule Realtime.Integration.RtChannelTest do
},
500

P.query!(conn, "update test set details = 'test' where id = #{id}", [])
Postgrex.query!(conn, "update test set details = 'test' where id = #{id}", [])

assert_receive %Message{
event: "postgres_changes",
Expand All @@ -177,7 +178,7 @@ defmodule Realtime.Integration.RtChannelTest do
},
500

P.query!(conn, "delete from test where id = #{id}", [])
Postgrex.query!(conn, "delete from test where id = #{id}", [])

assert_receive %Message{
event: "postgres_changes",
Expand Down Expand Up @@ -216,30 +217,13 @@ defmodule Realtime.Integration.RtChannelTest do
topic = "realtime:any"
WebsocketClient.join(socket, topic, %{config: config})

assert_receive %Message{
event: "phx_reply",
payload: %{
"response" => %{
"postgres_changes" => []
},
"status" => "ok"
},
ref: "1",
topic: ^topic
},
500

assert_receive %Message{event: "phx_reply", topic: ^topic}, 500
assert_receive %Message{}

payload = %{"event" => "TEST", "payload" => %{"msg" => 1}, "type" => "broadcast"}
WebsocketClient.send_event(socket, topic, "broadcast", payload)

assert_receive %Message{
event: "broadcast",
payload: ^payload,
topic: ^topic
},
500
assert_receive %Message{event: "broadcast", payload: ^payload, topic: ^topic}, 500
end

@tag policies: [
Expand Down Expand Up @@ -1449,7 +1433,7 @@ defmodule Realtime.Integration.RtChannelTest do

for _ <- 1..1000 do
WebsocketClient.join(socket, realtime_topic, %{config: config})
1..10 |> Enum.random() |> Process.sleep()
1..5 |> Enum.random() |> Process.sleep()
end

assert_receive %Message{
Expand Down
28 changes: 28 additions & 0 deletions test/realtime/rate_counter/rate_counter_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,34 @@ defmodule Realtime.RateCounterTest do
end
end

describe "stop/1" do
test "stops rate counters for a given entity" do
entity_id = Ecto.UUID.generate()
fake_terms = Enum.map(1..10, fn _ -> {:domain, :"metric_#{random_string()}", Ecto.UUID.generate()} end)
terms = Enum.map(1..10, fn _ -> {:domain, :"metric_#{random_string()}", entity_id} end)

for term <- terms do
{:ok, _} = RateCounter.new(term)
assert {:ok, %RateCounter{}} = RateCounter.get(term)
end

for term <- fake_terms do
{:ok, _} = RateCounter.new(term)
assert {:ok, %RateCounter{}} = RateCounter.get(term)
end

assert :ok = RateCounter.stop(entity_id)

for term <- terms do
assert {:error, _} = RateCounter.get(term)
end

for term <- fake_terms do
assert {:ok, %RateCounter{}} = RateCounter.get(term)
end
end
end

test "handle handles counter shutdown and dies" do
term = {:domain, :metric, Ecto.UUID.generate()}
{:ok, pid} = RateCounter.new(term)
Expand Down
1 change: 1 addition & 0 deletions test/realtime/tenants/replication_connection_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ defmodule Realtime.Tenants.ReplicationConnectionTest do
alias Realtime.Tenants.Migrations

setup do
Cleanup.ensure_no_replication_slot()
slot = Application.get_env(:realtime, :slot_name_suffix)
Application.put_env(:realtime, :slot_name_suffix, "test")
start_supervised(Realtime.Tenants.CacheSupervisor)
Expand Down
Loading
Loading