Skip to content

Commit

Permalink
Refactor to strongly typed actor protocols
Browse files Browse the repository at this point in the history
  • Loading branch information
Matthew-Davey committed Jan 27, 2024
1 parent 89c673d commit 2f4ebbc
Show file tree
Hide file tree
Showing 23 changed files with 844 additions and 657 deletions.
Empty file modified Lapine.Core.IntegrationTests/BrokerProxy.cs
100755 → 100644
Empty file.
Empty file modified Lapine.Core.IntegrationTests/Client/ConnectionTests.cs
100755 → 100644
Empty file.
38 changes: 18 additions & 20 deletions Lapine.Core/Agents/Agent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,29 +3,24 @@ namespace Lapine.Agents;
using System.Runtime.CompilerServices;
using System.Threading.Channels;

interface IAgent {
ValueTask PostAsync(Object message, CancellationToken cancellationToken = default);
ValueTask<Object> PostAndReplyAsync(Object message);
interface IAgent<in TProtocol> {
ValueTask PostAsync(TProtocol message, CancellationToken cancellationToken = default);
ValueTask<Object> PostAndReplyAsync(Func<AsyncReplyChannel, TProtocol> messageFactory);
ValueTask StopAsync();
}

class AsyncReplyChannel {
readonly Action<Object> _reply;

public AsyncReplyChannel(Action<Object> reply) =>
_reply = reply ?? throw new ArgumentNullException(nameof(reply));

public void Reply(Object response) => _reply(response);
class AsyncReplyChannel(Action<Object> reply) {
public void Reply(Object response) => reply(response);
}

class Agent : IAgent {
readonly Channel<Object> _mailbox;
class Agent<TProtocol> : IAgent<TProtocol> {
readonly Channel<TProtocol> _mailbox;
readonly Task _messageLoop;

Agent(Channel<Object> mailbox, Behaviour initialBehaviour) {
Agent(Channel<TProtocol> mailbox, Behaviour<TProtocol> initialBehaviour) {
_mailbox = mailbox;
_messageLoop = Task.Factory.StartNew(async () => {
var context = new MessageContext(this, initialBehaviour, null!);
var context = new MessageContext<TProtocol>(this, initialBehaviour, default!);

while (await _mailbox.Reader.WaitToReadAsync()) {
var message = await _mailbox.Reader.ReadAsync();
Expand All @@ -34,21 +29,24 @@ class Agent : IAgent {
});
}

static public IAgent StartNew(Behaviour initialBehaviour) {
var mailbox = Channel.CreateUnbounded<Object>(new UnboundedChannelOptions {
static public IAgent<TProtocol> StartNew(Behaviour<TProtocol> initialBehaviour) {
var mailbox = Channel.CreateUnbounded<TProtocol>(new UnboundedChannelOptions {
SingleReader = true
});
return new Agent(mailbox, initialBehaviour);
return new Agent<TProtocol>(mailbox, initialBehaviour);
}

public async ValueTask PostAsync(Object message, CancellationToken cancellationToken = default) =>
public async ValueTask PostAsync(TProtocol message, CancellationToken cancellationToken = default) =>
await _mailbox.Writer.WriteAsync(message, cancellationToken);

public async ValueTask<Object> PostAndReplyAsync(Object message) {
public async ValueTask<Object> PostAndReplyAsync(Func<AsyncReplyChannel, TProtocol> messageFactory) {
ArgumentNullException.ThrowIfNull(messageFactory);

var promise = AsyncValueTaskMethodBuilder<Object>.Create();
var replyChannel = new AsyncReplyChannel(reply => promise.SetResult(reply));
var message = messageFactory(replyChannel);

await _mailbox.Writer.WriteAsync((message, replyChannel));
await _mailbox.Writer.WriteAsync(message);

return await promise.Task;
}
Expand Down
113 changes: 57 additions & 56 deletions Lapine.Core/Agents/AmqpClientAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,32 @@ namespace Lapine.Agents;
using Lapine.Client;
using Lapine.Protocol;

using static Lapine.Agents.ChannelAgent.Protocol;
using static Lapine.Agents.DispatcherAgent.Protocol;
using static Lapine.Agents.HandshakeAgent.Protocol;
using static Lapine.Agents.HeartbeatAgent.Protocol;
using static Lapine.Agents.AmqpClientAgent.Protocol;
using static Lapine.Agents.SocketAgent.Protocol;

static class AmqpClientAgent {
static public class Protocol {
public record EstablishConnection(ConnectionConfiguration Configuration, CancellationToken CancellationToken = default);
public record OpenChannel(CancellationToken CancellationToken = default);
public record Disconnect;
}

static public IAgent Create() =>
Agent.StartNew(Disconnected());

static Behaviour Disconnected() =>
interface IAmqpClientAgent {
Task<Object> EstablishConnection(ConnectionConfiguration configuration, CancellationToken cancellationToken = default);
Task<Object> OpenChannel(CancellationToken cancellationToken = default);
Task<Object> Disconnect();
Task Stop();
}

class AmqpClientAgent : IAmqpClientAgent {
readonly IAgent<Protocol> _agent;

AmqpClientAgent(IAgent<Protocol> agent) =>
_agent = agent ?? throw new ArgumentNullException(nameof(agent));

abstract record Protocol;
record EstablishConnection(ConnectionConfiguration Configuration, AsyncReplyChannel ReplyChannel, CancellationToken CancellationToken = default) : Protocol;
record OpenChannel(AsyncReplyChannel ReplyChannel, CancellationToken CancellationToken = default) : Protocol;
record Disconnect(AsyncReplyChannel ReplyChannel) : Protocol;
record HeartbeatEventEventReceived(Object Message) : Protocol;

static public IAmqpClientAgent Create() =>
new AmqpClientAgent(Agent<Protocol>.StartNew(Disconnected()));

static Behaviour<Protocol> Disconnected() =>
async context => {
switch (context.Message) {
case (EstablishConnection(var connectionConfiguration, var cancellationToken), AsyncReplyChannel replyChannel): {
case EstablishConnection(var connectionConfiguration, var replyChannel, var cancellationToken): {
var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
cts.CancelAfter(connectionConfiguration.ConnectionTimeout);

Expand All @@ -43,7 +48,7 @@ static Behaviour Disconnected() =>
var endpoint = remainingEndpoints.Dequeue();
var socketAgent = SocketAgent.Create();

switch (await socketAgent.PostAndReplyAsync(new Connect(endpoint, cts.Token))) {
switch (await socketAgent.ConnectAsync(endpoint, cts.Token)) {
case ConnectionFailed(var fault) when remainingEndpoints.Any(): {
accumulatedFailures.Add(fault);
continue;
Expand All @@ -55,7 +60,7 @@ static Behaviour Disconnected() =>
}
case Connected(var connectionEvents, var receivedFrames): {
var dispatcher = DispatcherAgent.Create();
await dispatcher.PostAsync(new DispatchTo(socketAgent, 0));
await dispatcher.DispatchTo(socketAgent, 0);

var handshakeAgent = HandshakeAgent.Create(
receivedFrames : receivedFrames,
Expand All @@ -64,30 +69,22 @@ static Behaviour Disconnected() =>
cancellationToken: cts.Token
);

switch (await handshakeAgent.PostAndReplyAsync(new StartHandshake(connectionConfiguration))) {
case ConnectionAgreement connectionAgreement: {
await socketAgent.PostAsync(new Tune(connectionAgreement.MaxFrameSize));
switch (await handshakeAgent.StartHandshake(connectionConfiguration)) {
case HandshakeAgent.ConnectionAgreed(var connectionAgreement): {
await socketAgent.Tune(connectionAgreement.MaxFrameSize);

var heartbeatAgent = HeartbeatAgent.Create();

if (connectionConfiguration.ConnectionIntegrityStrategy.HeartbeatFrequency.HasValue) {
var heartbeatEvents = (IObservable<Object>) await heartbeatAgent.PostAndReplyAsync(new StartHeartbeat(
ReceivedFrames: receivedFrames,
Dispatcher : dispatcher,
Frequency : connectionConfiguration.ConnectionIntegrityStrategy.HeartbeatFrequency.Value
));
heartbeatEvents.Subscribe(onNext: message => context.Self.PostAndReplyAsync(message));
var heartbeatEvents = await heartbeatAgent.Start(receivedFrames, dispatcher, connectionConfiguration.ConnectionIntegrityStrategy.HeartbeatFrequency.Value);
heartbeatEvents.Subscribe(onNext: message => context.Self.PostAsync(new HeartbeatEventEventReceived(message)));
}

// If tcp keepalives are enabled, configure the socket...
if (connectionConfiguration.ConnectionIntegrityStrategy.KeepAliveSettings.HasValue) {
var (probeTime, retryInterval, retryCount) = connectionConfiguration.ConnectionIntegrityStrategy.KeepAliveSettings.Value;

await socketAgent.PostAsync(new EnableTcpKeepAlives(
ProbeTime : probeTime,
RetryInterval: retryInterval,
RetryCount : retryCount
), cts.Token);
await socketAgent.EnableTcpKeepAlives(probeTime, retryInterval, retryCount);
}

replyChannel.Reply(true);
Expand All @@ -106,7 +103,7 @@ await socketAgent.PostAsync(new EnableTcpKeepAlives(
)
};
}
case Exception fault: {
case HandshakeAgent.HandshakeFailed(var fault): {
replyChannel.Reply(fault);
return context;
}
Expand All @@ -118,51 +115,43 @@ await socketAgent.PostAsync(new EnableTcpKeepAlives(
}
return context;
}
case (Protocol.Disconnect, AsyncReplyChannel replyChannel): {
case Disconnect(var replyChannel): {
replyChannel.Reply(true);
return context;
}
default: throw new Exception($"Unexpected message '{context.Message.GetType().FullName}' in '{nameof(Disconnected)}' behaviour.");
}
};

static Behaviour Connected(ConnectionConfiguration connectionConfiguration, IAgent socketAgent, IAgent heartbeatAgent, IObservable<RawFrame> receivedFrames, IObservable<Object> connectionEvents, IAgent dispatcher, IImmutableList<UInt16> availableChannelIds) =>
static Behaviour<Protocol> Connected(ConnectionConfiguration connectionConfiguration, ISocketAgent socketAgent, IHeartbeatAgent heartbeatAgent, IObservable<RawFrame> receivedFrames, IObservable<Object> connectionEvents, IDispatcherAgent dispatcher, IImmutableList<UInt16> availableChannelIds) =>
async context => {
switch (context.Message) {
case RemoteFlatline: {
await heartbeatAgent.StopAsync();
await dispatcher.StopAsync();
await socketAgent.PostAsync(new SocketAgent.Protocol.Disconnect());
case HeartbeatEventEventReceived(RemoteFlatline): {
await heartbeatAgent.Stop();
await dispatcher.Stop();
await socketAgent.Disconnect();
await socketAgent.StopAsync();

return context with { Behaviour = Disconnected() };
}
case (AmqpClientAgent.Protocol.Disconnect, AsyncReplyChannel replyChannel): {
await heartbeatAgent.StopAsync();
await dispatcher.StopAsync();
await socketAgent.PostAsync(new SocketAgent.Protocol.Disconnect());
case Disconnect(var replyChannel): {
await heartbeatAgent.Stop();
await dispatcher.Stop();
await socketAgent.Disconnect();
await socketAgent.StopAsync();

replyChannel.Reply(true);

return context;
}
case (OpenChannel(var cancellationToken), AsyncReplyChannel replyChannel): {
case OpenChannel(var replyChannel, var cancellationToken): {
var channelId = availableChannelIds[0];
var channelAgent = ChannelAgent.Create(connectionConfiguration.MaximumFrameSize);

using var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
cts.CancelAfter(connectionConfiguration.CommandTimeout);

var command = new Open(
ChannelId : channelId,
ReceivedFrames : receivedFrames.Where(frame => frame.Channel == channelId),
ConnectionEvents : connectionEvents,
SocketAgent : socketAgent,
CancellationToken: cts.Token
);

switch (await channelAgent.PostAndReplyAsync(command)) {
switch (await channelAgent.Open(channelId,receivedFrames.Where(frame => frame.Channel == channelId), connectionEvents, socketAgent, cts.Token)) {
case true: {
replyChannel.Reply(channelAgent);
return context with {
Expand All @@ -180,4 +169,16 @@ static Behaviour Connected(ConnectionConfiguration connectionConfiguration, IAge
default: throw new Exception($"Unexpected message '{context.Message.GetType().FullName}' in '{nameof(Connected)}' behaviour.");
}
};

async Task<Object> IAmqpClientAgent.EstablishConnection(ConnectionConfiguration configuration, CancellationToken cancellationToken) =>
await _agent.PostAndReplyAsync(replyChannel => new EstablishConnection(configuration, replyChannel, cancellationToken));

async Task<Object> IAmqpClientAgent.OpenChannel(CancellationToken cancellationToken) =>
await _agent.PostAndReplyAsync(replyChannel => new OpenChannel(replyChannel, cancellationToken));

async Task<Object> IAmqpClientAgent.Disconnect() =>
await _agent.PostAndReplyAsync(replyChannel => new Disconnect(replyChannel));

async Task IAmqpClientAgent.Stop() =>
await _agent.StopAsync();
}
2 changes: 1 addition & 1 deletion Lapine.Core/Agents/Behaviour.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
namespace Lapine.Agents;

delegate ValueTask<MessageContext> Behaviour(MessageContext context);
delegate ValueTask<MessageContext<TProtocol>> Behaviour<TProtocol>(MessageContext<TProtocol> context);
Loading

0 comments on commit 2f4ebbc

Please sign in to comment.