Skip to content

Commit

Permalink
fix: Set connect context to check RLS policies
Browse files Browse the repository at this point in the history
* Setups for tests closer to production based on what realtime/storage is doing
* Changes fixture to use supabase_admin to keep the 'closer to production' setup
* Creates a migration to properly set the GRANTS required for the roles supported
* Enables RLS policies on channels table
* Changes order of checks so we have access to decrypted JWT token
* If config of channel is public, will assign permissions to socket
* Permission check is done after we set various config values of the current transaction
* Select on channels determines if user has access to it, no rows means no access
  • Loading branch information
filipecabaco committed Dec 8, 2023
1 parent 597cb21 commit 0147ace
Show file tree
Hide file tree
Showing 6 changed files with 186 additions and 127 deletions.
14 changes: 0 additions & 14 deletions dev/postgres/00-setup.sql

This file was deleted.

31 changes: 27 additions & 4 deletions lib/realtime_web/channels/realtime_channel.ex
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,14 @@ defmodule RealtimeWeb.RealtimeChannel do
start_db_rate_counter(tenant)

with false <- SignalHandler.shutdown_in_progress?(),
{:ok, _} <- Connect.lookup_or_start_connection(tenant),
{:ok, claims, confirm_token_ref} <- confirm_token(socket),
{:ok, conn} <- Connect.lookup_or_start_connection(tenant),
:ok <- limit_joins(socket),
:ok <- limit_channels(socket),
:ok <- limit_max_users(socket),
{:ok, claims, confirm_token_ref} <- confirm_token(socket),
is_new_api <- is_new_api(params) do
is_new_api <- is_new_api(params),
check_params <- set_check_params(socket, params),
{:ok, socket} <- check_conn_authorization(socket, conn, check_params) do
Realtime.UsersCounter.add(transport_pid, tenant)

tenant_topic = tenant <> ":" <> sub_topic
Expand Down Expand Up @@ -510,7 +512,6 @@ defmodule RealtimeWeb.RealtimeChannel do

interval = min(@confirm_token_ms_interval, exp_diff * 1_000)
ref = Process.send_after(self(), :confirm_token, interval)

{:ok, claims, ref}
else
{:error, e} -> {:error, e}
Expand Down Expand Up @@ -647,4 +648,26 @@ defmodule RealtimeWeb.RealtimeChannel do
Logger.error("Start channel error: " <> error_msg)
{:error, %{reason: error_msg}}
end

defp set_check_params(socket, %{"config" => %{"public" => true}} = params) do
%{
channel_name: params["config"]["channel"],
headers: socket.assigns.headers |> Map.new() |> Jason.encode!(),
jwt: socket.assigns.access_token,
claims: %{
claims: Jason.encode!(socket.assigns.claims),
sub: Map.get(socket.assigns.claims, "sub"),
role: Map.get(socket.assigns.claims, "role")
},
role: Map.get(socket.assigns.claims, "role")
}
end

defp set_check_params(_, _params), do: nil

defp check_conn_authorization(socket, _, nil), do: {:ok, socket}

defp check_conn_authorization(socket, conn, check_params) do
Helpers.check_connection(socket, conn, check_params)
end
end
3 changes: 2 additions & 1 deletion lib/realtime_web/channels/realtime_channel/assign.ex
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ defmodule RealtimeWeb.RealtimeChannel.Assigns do
:tenant_token,
:access_token,
:postgres_cdc_module,
:channel_name
:channel_name,
:headers
]

@type t :: %__MODULE__{
Expand Down
3 changes: 2 additions & 1 deletion lib/realtime_web/channels/user_socket.ex
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ defmodule RealtimeWeb.UserSocket do
postgres_cdc_module: postgres_cdc_module,
tenant: external_id,
log_level: log_level,
tenant_token: token
tenant_token: token,
headers: headers
}

assigns = Map.from_struct(assigns)
Expand Down
260 changes: 153 additions & 107 deletions test/realtime_web/channels/realtime_channel_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,12 @@ defmodule RealtimeWeb.RealtimeChannelTest do
import Mock

alias Phoenix.Socket
alias RealtimeWeb.{ChannelsAuthorization, Joken.CurrentTime, UserSocket}

@tenant "dev_tenant"
alias Realtime.Tenants

alias RealtimeWeb.ChannelsAuthorization
alias RealtimeWeb.Joken.CurrentTime
alias RealtimeWeb.UserSocket

@default_limits %{
max_concurrent_users: 200,
Expand All @@ -17,120 +20,100 @@ defmodule RealtimeWeb.RealtimeChannelTest do
max_bytes_per_second: 100_000
}

@default_conn_opts [
connect_info: %{
uri: %{host: "#{@tenant}.localhost:4000/socket/websocket", query: ""},
x_headers: [{"x-api-key", "token123"}]
}
]

setup do
setup context do
start_supervised!(CurrentTime.Mock)
tenant = tenant_fixture()
settings = Realtime.PostgresCdc.filter_settings("postgres_cdc_rls", tenant.extensions)
settings = Map.put(settings, "id", tenant.external_id)
settings = Map.put(settings, "db_socket_opts", [:inet])

start_supervised!({Tenants.Migrations, settings})
{:ok, conn} = Tenants.Connect.lookup_or_start_connection(tenant.external_id)
truncate_table(conn, "realtime.channels")

case context do
%{rls: policy} ->
create_rls_policy(conn, policy)

on_exit(fn ->
Postgrex.query!(conn, "drop policy #{policy} on realtime.channels", [])
end)

_ ->
:ok
end

%{tenant: tenant_fixture(), conn: conn}
end

setup_with_mocks [
{
ChannelsAuthorization,
[],
[
authorize_conn: fn _, _ ->
{:ok, %{"exp" => Joken.current_time() + 1_000, "role" => "postgres"}}
end
]
}
] do
:ok
end

describe "maximum number of connected clients per tenant" do
test "not reached" do
with_mocks([
{ChannelsAuthorization, [],
[
authorize_conn: fn _, _ ->
{:ok, %{"exp" => Joken.current_time() + 1_000, "role" => "postgres"}}
end
]}
]) do
{:ok, %Socket{} = socket} = connect(UserSocket, %{}, @default_conn_opts)

socket = Socket.assign(socket, %{limits: %{@default_limits | max_concurrent_users: 1}})
assert {:ok, _, %Socket{}} = subscribe_and_join(socket, "realtime:test", %{})
end
test "not reached", %{tenant: tenant} do
{:ok, %Socket{} = socket} = connect(UserSocket, %{}, conn_opts(tenant))

socket = Socket.assign(socket, %{limits: %{@default_limits | max_concurrent_users: 1}})
assert {:ok, _, %Socket{}} = subscribe_and_join(socket, "realtime:test", %{})
end

test "reached" do
with_mocks([
{ChannelsAuthorization, [],
[
authorize_conn: fn _, _ ->
{:ok, %{"exp" => Joken.current_time() + 1_000, "role" => "postgres"}}
end
]}
]) do
{:ok, %Socket{} = socket} = connect(UserSocket, %{}, @default_conn_opts)

socket_at_capacity =
Socket.assign(socket, %{limits: %{@default_limits | max_concurrent_users: 0}})

socket_over_capacity =
Socket.assign(socket, %{limits: %{@default_limits | max_concurrent_users: -1}})

assert {:error, %{reason: "{:error, :too_many_connections}"}} =
subscribe_and_join(socket_at_capacity, "realtime:test", %{})

assert {:error, %{reason: "{:error, :too_many_connections}"}} =
subscribe_and_join(socket_over_capacity, "realtime:test", %{})
end
test "reached", %{tenant: tenant} do
{:ok, %Socket{} = socket} = connect(UserSocket, %{}, conn_opts(tenant))

socket_at_capacity =
Socket.assign(socket, %{limits: %{@default_limits | max_concurrent_users: 0}})

socket_over_capacity =
Socket.assign(socket, %{limits: %{@default_limits | max_concurrent_users: -1}})

assert {:error, %{reason: "{:error, :too_many_connections}"}} =
subscribe_and_join(socket_at_capacity, "realtime:test", %{})

assert {:error, %{reason: "{:error, :too_many_connections}"}} =
subscribe_and_join(socket_over_capacity, "realtime:test", %{})
end
end

describe "token expiration" do
test "valid" do
with_mocks([
{ChannelsAuthorization, [],
[
authorize_conn: fn _, _ ->
{:ok, %{"exp" => Joken.current_time() + 1, "role" => "postgres"}}
end
]}
]) do
{:ok, %Socket{} = socket} = connect(UserSocket, %{}, @default_conn_opts)

assert {:ok, _, %Socket{}} = subscribe_and_join(socket, "realtime:test", %{})
end
test "valid", %{tenant: tenant} do
{:ok, %Socket{} = socket} = connect(UserSocket, %{}, conn_opts(tenant))
assert {:ok, _, %Socket{}} = subscribe_and_join(socket, "realtime:test", %{})
end

test "invalid" do
with_mocks([
{ChannelsAuthorization, [],
[
authorize_conn: fn _, _ ->
{:ok, %{"exp" => Joken.current_time(), "role" => "postgres"}}
end
]}
]) do
{:ok, %Socket{} = socket} = connect(UserSocket, %{}, @default_conn_opts)

assert {:error, %{reason: "{:error, 0}"}} =
subscribe_and_join(socket, "realtime:test", %{})
end

with_mocks([
{ChannelsAuthorization, [],
[
authorize_conn: fn _, _ ->
{:ok, %{"exp" => Joken.current_time() - 1, "role" => "postgres"}}
end
]}
]) do
{:ok, %Socket{} = socket} = connect(UserSocket, %{}, @default_conn_opts)

assert {:error, %{reason: "{:error, -1}"}} =
subscribe_and_join(socket, "realtime:test", %{})
end
test_with_mock "token about to expire", %{tenant: tenant}, ChannelsAuthorization, [],
authorize_conn: fn _, _ ->
{:ok, %{"exp" => Joken.current_time(), "role" => "postgres"}}
end do
{:ok, %Socket{} = socket} = connect(UserSocket, %{}, conn_opts(tenant))

assert {:error, %{reason: "{:error, 0}"}} = subscribe_and_join(socket, "realtime:test", %{})
end
end

describe "checks tenant db connectivity" do
setup_with_mocks([
{ChannelsAuthorization, [],
authorize_conn: fn _, _ ->
{:ok, %{"exp" => Joken.current_time() + 1_000, "role" => "postgres"}}
end}
]) do
:ok
test_with_mock "token that has expired", %{tenant: tenant}, ChannelsAuthorization, [],
authorize_conn: fn _, _ ->
{:ok, %{"exp" => Joken.current_time() - 1, "role" => "postgres"}}
end do
{:ok, %Socket{} = socket} = connect(UserSocket, %{}, conn_opts(tenant))

assert {:error, %{reason: "{:error, -1}"}} =
subscribe_and_join(socket, "realtime:test", %{})
end
end

test "successful connection proceeds with join" do
{:ok, %Socket{} = socket} = connect(UserSocket, %{}, @default_conn_opts)
describe "checks tenant db connectivity" do
test "successful connection proceeds with join", %{tenant: tenant} do
{:ok, %Socket{} = socket} = connect(UserSocket, %{}, conn_opts(tenant))
assert {:ok, _, %Socket{}} = subscribe_and_join(socket, "realtime:test", %{})
end

Expand All @@ -155,17 +138,80 @@ defmodule RealtimeWeb.RealtimeChannelTest do

tenant = tenant_fixture(%{"extensions" => extensions})

conn_opts = [
connect_info: %{
uri: %{host: "#{tenant.external_id}.localhost:4000/socket/websocket", query: ""},
x_headers: [{"x-api-key", "token123"}]
}
]

{:ok, %Socket{} = socket} = connect(UserSocket, %{}, conn_opts)
{:ok, %Socket{} = socket} = connect(UserSocket, %{}, conn_opts(tenant))

assert {:error, %{reason: "{:error, :tenant_database_unavailable}"}} =
subscribe_and_join(socket, "realtime:test", %{})
end
end

describe "check authorization on connect" do
@tag role: "authenticated", rls: :select_authenticated_role
test_with_mock "authenticated user has read permissions",
%{tenant: tenant, role: role},
ChannelsAuthorization,
[],
authorize_conn: fn _, _ ->
{:ok,
%{
"exp" => Joken.current_time() + 1_000,
"role" => role,
"sub" => random_string()
}}
end do
channel_name = random_string()
channel_fixture(tenant, %{"name" => channel_name})
params = %{"config" => %{"channel" => channel_name, "public" => true}}

{:ok, %Socket{} = socket} = connect(UserSocket, %{}, conn_opts(tenant))

assert {:ok, _, %Socket{} = socket} = subscribe_and_join(socket, "realtime:test", params)
assert %{read: true} = socket.assigns.permissions
end

@tag role: "anon", rls: :select_authenticated_role
test_with_mock "anon user has no read permissions",
%{tenant: tenant, role: role},
ChannelsAuthorization,
[],
authorize_conn: fn _, _ ->
{:ok,
%{
"exp" => Joken.current_time() + 1_000,
"role" => role,
"sub" => random_string()
}}
end do
channel_name = random_string()
channel_fixture(tenant, %{"name" => channel_name})
params = %{"config" => %{"channel" => channel_name, "public" => true}}

{:ok, %Socket{} = socket} = connect(UserSocket, %{}, conn_opts(tenant))

assert {:ok, _, %Socket{} = socket} = subscribe_and_join(socket, "realtime:test", params)
assert %{read: false} = socket.assigns.permissions
end
end

defp conn_opts(tenant) do
[
connect_info: %{
uri: %{host: "#{tenant.external_id}.localhost:4000/socket/websocket", query: ""},
x_headers: [{"x-api-key", "token123"}]
}
]
end

defp create_rls_policy(conn, :select_authenticated_role) do
Postgrex.query!(
conn,
"""
create policy select_authenticated_role
on realtime.channels for select
to authenticated
using ( realtime.channel_name() = name );
""",
[]
)
end
end
2 changes: 2 additions & 0 deletions test/support/channel_case.ex
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ defmodule RealtimeWeb.ChannelCase do

# The default endpoint for testing
@endpoint RealtimeWeb.Endpoint
def truncate_table(db_conn, table),
do: Postgrex.query!(db_conn, "TRUNCATE TABLE #{table}", [])
end
end

Expand Down

0 comments on commit 0147ace

Please sign in to comment.