Skip to content

Commit

Permalink
[client] rewrite socket handler logic to be OTP like
Browse files Browse the repository at this point in the history
  • Loading branch information
RoadRunnr committed Apr 16, 2024
1 parent 6d7f3ea commit 600d97d
Show file tree
Hide file tree
Showing 5 changed files with 234 additions and 220 deletions.
106 changes: 59 additions & 47 deletions src/eradius_client.erl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
%% parameter. Changing it currently requires a restart. It can be given as a string or ip address tuple,
%% or the atom ``undefined'' (the default), which uses whatever address the OS selects.
-module(eradius_client).

-export([start_link/0, send_request/2, send_request/3, send_remote_request/3, send_remote_request/4]).

%% internal
-export([reconfigure/0, send_remote_request_loop/8, find_suitable_peer/1,
restore_upstream_server/1, store_radius_server_from_pool/3,
Expand All @@ -22,6 +24,10 @@

-import(eradius_lib, [printable_peer/2]).

-ifdef(TEST).
-export([get_state/0]).
-endif.

-include_lib("stdlib/include/ms_transform.hrl").
-include_lib("kernel/include/logger.hrl").
-include("eradius_dict.hrl").
Expand Down Expand Up @@ -230,37 +236,48 @@ handle_failed_request(Request, {ServerIP, Port} = _FailedServer, UpstreamServers
end.

%% @private
%% send_remote_request_loop/8
send_remote_request_loop(ReplyPid, Socket, ReqId, Peer, EncRequest, Retries, Timeout, MetricsInfo) ->
ReplyPid ! {self(), send_request_loop(Socket, ReqId, Peer, EncRequest, Retries, Timeout, MetricsInfo)}.

send_request_loop(Socket, ReqId, Peer, Request = #radius_request{}, Retries, Timeout, undefined) ->
%% send_remote_request_loop/7
send_request_loop(Socket, ReqId, Peer, Request = #radius_request{},
Retries, Timeout, undefined) ->
send_request_loop(Socket, ReqId, Peer, Request, Retries, Timeout, eradius_lib:make_addr_info(Peer));
send_request_loop(Socket, ReqId, Peer, Request, Retries, Timeout, MetricsInfo) ->
send_request_loop(Socket, ReqId, Peer, Request,
Retries, Timeout, MetricsInfo) ->
{Authenticator, EncRequest} = eradius_lib:encode_request(Request),
SMon = erlang:monitor(process, Socket),
send_request_loop(Socket, SMon, Peer, ReqId, Authenticator, EncRequest, Timeout, Retries, MetricsInfo, Request#radius_request.secret, Request).
send_request_loop(Socket, Peer, ReqId, Authenticator, EncRequest,
Timeout, Retries, MetricsInfo, Request#radius_request.secret, Request).

send_request_loop(_Socket, SMon, _Peer, _ReqId, _Authenticator, _EncRequest, Timeout, 0, MetricsInfo, _Secret, Request) ->
%% send_remote_request_loop/10
send_request_loop(_Socket, _Peer, _ReqId, _Authenticator, _EncRequest,
Timeout, 0, MetricsInfo, _Secret, Request) ->
TS = erlang:convert_time_unit(Timeout, millisecond, native),
update_client_request(timeout, MetricsInfo, TS, Request),
erlang:demonitor(SMon, [flush]),
{error, timeout};
send_request_loop(Socket, SMon, Peer = {_ServerName, {IP, Port}}, ReqId, Authenticator, EncRequest, Timeout, RetryN, MetricsInfo, Secret, Request) ->
Socket ! {self(), send_request, {IP, Port}, ReqId, EncRequest},
update_client_request(pending, MetricsInfo, 1, Request),
receive
{Socket, response, ReqId, Response} ->
update_client_request(pending, MetricsInfo, -1, Request),
send_request_loop(Socket, Peer = {_ServerName, {IP, Port}}, ReqId, Authenticator, EncRequest,
Timeout, RetryN, MetricsInfo, Secret, Request) ->
Result =
try
update_client_request(pending, MetricsInfo, 1, Request),
eradius_client_socket:send_request(Socket, {IP, Port}, ReqId, EncRequest, Timeout)
after
update_client_request(pending, MetricsInfo, -1, Request)
end,

case Result of
{response, ReqId, Response} ->
{ok, Response, Secret, Authenticator};
{'DOWN', SMon, process, Socket, _} ->
{error, close} ->
{error, socket_down};
{Socket, error, Error} ->
{error, Error}
after
Timeout ->
{error, timeout} ->
TS = erlang:convert_time_unit(Timeout, millisecond, native),
update_client_request(retransmission, MetricsInfo, TS, Request),
send_request_loop(Socket, SMon, Peer, ReqId, Authenticator, EncRequest, Timeout, RetryN - 1, MetricsInfo, Secret, Request)
send_request_loop(Socket, Peer, ReqId, Authenticator, EncRequest,
Timeout, RetryN - 1, MetricsInfo, Secret, Request);
{error, _} = Error ->
Error
end.

%% @private
Expand Down Expand Up @@ -329,7 +346,7 @@ reconfigure() ->

%% @private
init([]) ->
{ok, Sup} = eradius_client_sup:start(),
{ok, Sup} = eradius_client_sup:start_link(),
case configure(#state{socket_ip = null, sup = Sup}) of
{error, Error} -> {stop, Error};
Else -> Else
Expand All @@ -355,10 +372,6 @@ handle_call(reconfigure, _From, State) ->
{ok, NState} -> {reply, ok, NState}
end;

%% @private
handle_call(debug, _From, State) ->
{reply, {ok, State}, State};

%% @private
handle_call(_OtherCall, _From, State) ->
{noreply, State}.
Expand Down Expand Up @@ -402,6 +415,16 @@ configure(State) ->
{error, {bad_client_ip, ClientIP}}
end.

-ifdef(TEST).

get_state() ->
State = sys:get_state(?SERVER),
Keys = record_info(fields, state),
Values = tl(tuple_to_list(State)),
maps:from_list(lists:zip(Keys, Values)).

-endif.

%% private
prepare_pools() ->
ets:new(?MODULE, [ordered_set, public, named_table, {keypos, 1}, {write_concurrency,true}]),
Expand Down Expand Up @@ -456,14 +479,14 @@ configure_address(State = #state{socket_ip = OAdd, sockets = Sockts}, NPorts, NA
{ok, State#state{socket_ip = NAdd, no_ports = NPorts}};
NAdd ->
configure_ports(State, NPorts);
_ ->
_ ->
?LOG(info, "Reopening RADIUS client sockets (client_ip changed to ~s)", [inet:ntoa(NAdd)]),
array:map( fun(_PortIdx, Pid) ->
case Pid of
undefined -> done;
_ -> Pid ! close
end
end, Sockts),
array:map(
fun(_PortIdx, undefined) ->
ok;
(_PortIdx, Socket) ->
eradius_client_socket:close(Socket)
end, Sockts),
{ok, State#state{sockets = array:new(), socket_ip = NAdd, no_ports = NPorts}}
end.

Expand All @@ -490,11 +513,8 @@ close_sockets(NPorts, Sockets) ->
List = array:to_list(Sockets),
{_, Rest} = lists:split(NPorts, List),
lists:map(
fun(Pid) ->
case Pid of
undefined -> done;
_ -> Pid ! close
end
fun(undefined) -> ok;
(Socket) -> eradius_client_socket:close(Socket)
end, Rest),
array:resize(NPorts, Sockets)
end.
Expand All @@ -516,18 +536,10 @@ next_port_and_req_id(Peer, NumberOfPorts, Counters) ->
find_socket_process(PortIdx, Sockets, SocketIP, Sup) ->
case array:get(PortIdx, Sockets) of
undefined ->
Res = supervisor:start_child(Sup, {PortIdx,
{eradius_client_socket, start, [SocketIP, self(), PortIdx]},
transient, brutal_kill, worker, [eradius_client_socket]}),
Pid = case Res of
{ok, P} -> P;
{error, already_present} ->
{ok, P} = supervisor:restart_child(Sup, PortIdx),
P
end,
{Pid, array:set(PortIdx, Pid, Sockets)};
Pid when is_pid(Pid) ->
{Pid, Sockets}
{ok, Socket} = eradius_client_socket:new(Sup, SocketIP),
{Socket, array:set(PortIdx, Socket, Sockets)};
Socket ->
{Socket, Sockets}
end.

update_socket_process(PortIdx, Sockets, Pid) ->
Expand Down
142 changes: 100 additions & 42 deletions src/eradius_client_socket.erl
Original file line number Diff line number Diff line change
@@ -1,76 +1,125 @@
%% Copyright (c) 2002-2007, Martin Björklund and Torbjörn Törnkvist
%% Copyright (c) 2011, Travelping GmbH <info@travelping.com>
%%
%% SPDX-License-Identifier: MIT
%%
-module(eradius_client_socket).

-behaviour(gen_server).

-export([start/3]).
%% API
-export([new/2, start_link/1, send_request/5, close/1]).

%% gen_server callbacks
-export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3]).

-record(state, {client, socket, pending, mode, counter}).
-record(state, {socket, active_n, pending, mode, counter}).

%%%=========================================================================
%%% API
%%%=========================================================================

new(Sup, SocketIP) ->
eradius_client_sup:new(Sup, SocketIP).

start_link(SocketIP) ->
gen_server:start_link(?MODULE, [SocketIP], []).

start(SocketIP, Client, PortIdx) ->
gen_server:start_link(?MODULE, [SocketIP, Client, PortIdx], []).
send_request(Socket, Peer, ReqId, Request, Timeout) ->
try
gen_server:call(Socket, {send_request, Peer, ReqId, Request}, Timeout)
catch
exit:{timeout, _} ->
{error, timeout};
exit:{noproc, _} ->
{error, closed};
{nodedown, _} ->
{error, closed}
end.

init([SocketIP, Client, PortIdx]) ->
close(Socket) ->
gen_server:cast(Socket, close).

%%%===================================================================
%%% gen_server callbacks
%%%===================================================================

init([SocketIP]) ->
case SocketIP of
undefined ->
ExtraOptions = [];
SocketIP when is_tuple(SocketIP) ->
ExtraOptions = [{ip, SocketIP}]
end,
ActiveN = application:get_env(eradius, active_n, 100),
RecBuf = application:get_env(eradius, recbuf, 8192),
SndBuf = application:get_env(eradius, sndbuf, 131072),
{ok, Socket} = gen_udp:open(0, [{active, once}, binary, {recbuf, RecBuf}, {sndbuf, SndBuf} | ExtraOptions]),
{ok, #state{client = Client, socket = Socket, pending = maps:new(), mode = active, counter = 0}}.
Opts = [{active, ActiveN}, binary, {recbuf, RecBuf}, {sndbuf, SndBuf} | ExtraOptions],
{ok, Socket} = gen_udp:open(0, Opts),

State = #state{
socket = Socket,
active_n = ActiveN,
pending = #{},
mode = active
},
{ok, State}.

handle_call({send_request, {IP, Port}, ReqId, Request}, From,
#state{socket = Socket, pending = Pending} = State) ->
case gen_udp:send(Socket, IP, Port, Request) of
ok ->
ReqKey = {IP, Port, ReqId},
NPending = Pending#{ReqKey => From},
{noreply, State#state{pending = NPending}};
{error, Reason} ->
{reply, {error, Reason}, State}
end;

handle_call(_Request, _From, State) ->
{noreply, State}.

handle_cast(close, #state{pending = Pending} = State)
when map_size(Pending) =:= 0 ->
{stop, normal, State};
handle_cast(close, State) ->
{noreply, State#state{mode = inactive}};

handle_cast(_Msg, State) ->
{noreply, State}.

handle_info({SenderPid, send_request, {IP, Port}, ReqId, EncRequest},
State = #state{socket = Socket, pending = Pending, counter = Counter}) ->
case gen_udp:send(Socket, IP, Port, EncRequest) of
ok ->
ReqKey = {IP, Port, ReqId},
NPending = maps:put(ReqKey, SenderPid, Pending),
{noreply, State#state{pending = NPending, counter = Counter+1}};
{error, Reason} ->
SenderPid ! {error, Reason},
{noreply, State}
end;
handle_info({udp_passive, _Socket}, #state{socket = Socket, active_n = ActiveN} = State) ->
inet:setopts(Socket, [{active, ActiveN}]),
{noreply, State};

handle_info({udp, Socket, FromIP, FromPort, EncRequest},
State = #state{socket = Socket, pending = Pending, mode = Mode, counter = Counter}) ->
case eradius_lib:decode_request_id(EncRequest) of
{ReqId, EncRequest} ->
case maps:find({FromIP, FromPort, ReqId}, Pending) of
error ->
%% discard reply because we didn't expect it
inet:setopts(Socket, [{active, once}]),
{noreply, State};
{ok, WaitingSender} ->
WaitingSender ! {self(), response, ReqId, EncRequest},
inet:setopts(Socket, [{active, once}]),
handle_info({udp, Socket, FromIP, FromPort, Request},
State = #state{socket = Socket, pending = Pending, mode = Mode}) ->
case eradius_lib:decode_request_id(Request) of
{ReqId, Request} ->
case Pending of
#{{FromIP, FromPort, ReqId} := From} ->
gen_server:reply(From, {response, ReqId, Request}),

flow_control(State),
NPending = maps:remove({FromIP, FromPort, ReqId}, Pending),
NState = State#state{pending = NPending, counter = Counter-1},
case {Mode, Counter-1} of
{inactive, 0} -> {stop, normal, NState};
_ -> {noreply, NState}
end
NState = State#state{pending = NPending},
case Mode of
inactive when map_size(NPending) =:= 0 ->
{stop, normal, NState};
_ ->
{noreply, NState}
end;
_ ->
%% discard reply because we didn't expect it
flow_control(State),
{noreply, State}
end;
{bad_pdu, _} ->
%% discard reply because it was malformed
inet:setopts(Socket, [{active, once}]),
flow_control(State),
{noreply, State}
end;

handle_info(close, State = #state{counter = Counter}) ->
case Counter of
0 -> {stop, normal, State};
_ -> {noreply, State#state{mode = inactive}}
end;

handle_info(_Info, State) ->
{noreply, State}.

Expand All @@ -79,3 +128,12 @@ terminate(_Reason, _State) ->

code_change(_OldVsn, State, _Extra) ->
{ok, State}.

%%%=========================================================================
%%% internal functions
%%%=========================================================================

flow_control(#state{socket = Socket, active_n = once}) ->
inet:setopts(Socket, [{active, once}]);
flow_control(_) ->
ok.
Loading

0 comments on commit 600d97d

Please sign in to comment.