Skip to content

Commit

Permalink
Refactor to strongly typed async replies
Browse files Browse the repository at this point in the history
  • Loading branch information
Matthew-Davey committed Jan 30, 2024
1 parent 41dffd1 commit 26d5a10
Show file tree
Hide file tree
Showing 38 changed files with 578 additions and 677 deletions.
32 changes: 25 additions & 7 deletions Lapine.Core/Agents/Agent.cs
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
namespace Lapine.Agents;

using System.Runtime.CompilerServices;
using System.Threading.Channels;

interface IAgent<in TProtocol> {
ValueTask PostAsync(TProtocol message, CancellationToken cancellationToken = default);
ValueTask<Object> PostAndReplyAsync(Func<AsyncReplyChannel, TProtocol> messageFactory);
ValueTask PostAndReplyAsync(Func<AsyncReplyChannel, TProtocol> messageFactory);
ValueTask<TReply> PostAndReplyAsync<TReply>(Func<AsyncReplyChannel<TReply>, TProtocol> messageFactory);
ValueTask StopAsync();
}

class AsyncReplyChannel(Action<Object> reply) {
public void Reply(Object response) => reply(response);
class AsyncReplyChannel(TaskCompletionSource promise) {
public void Complete() => promise.SetResult();
public void Fault(Exception fault) => promise.SetException(fault);
}

class AsyncReplyChannel<TReply>(TaskCompletionSource<TReply> promise) {
public void Reply(TReply response) => promise.SetResult(response);
public void Fault(Exception fault) => promise.SetException(fault);
}

class Agent<TProtocol> : IAgent<TProtocol> {
Expand Down Expand Up @@ -39,11 +45,23 @@ static public IAgent<TProtocol> StartNew(Behaviour<TProtocol> initialBehaviour)
public async ValueTask PostAsync(TProtocol message, CancellationToken cancellationToken = default) =>
await _mailbox.Writer.WriteAsync(message, cancellationToken);

public async ValueTask<Object> PostAndReplyAsync(Func<AsyncReplyChannel, TProtocol> messageFactory) {
public async ValueTask PostAndReplyAsync(Func<AsyncReplyChannel, TProtocol> messageFactory) {
ArgumentNullException.ThrowIfNull(messageFactory);

var promise = new TaskCompletionSource();
var replyChannel = new AsyncReplyChannel(promise);
var message = messageFactory(replyChannel);

await _mailbox.Writer.WriteAsync(message);

await promise.Task;
}

public async ValueTask<TReply> PostAndReplyAsync<TReply>(Func<AsyncReplyChannel<TReply>, TProtocol> messageFactory) {
ArgumentNullException.ThrowIfNull(messageFactory);

var promise = AsyncValueTaskMethodBuilder<Object>.Create();
var replyChannel = new AsyncReplyChannel(reply => promise.SetResult(reply));
var promise = new TaskCompletionSource<TReply>();
var replyChannel = new AsyncReplyChannel<TReply>(promise);
var message = messageFactory(replyChannel);

await _mailbox.Writer.WriteAsync(message);
Expand Down
179 changes: 88 additions & 91 deletions Lapine.Core/Agents/AmqpClientAgent.Behaviours.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,16 @@ namespace Lapine.Agents;
using Lapine.Protocol;

static partial class AmqpClientAgent {
record State(
ConnectionConfiguration ConnectionConfiguration,
ISocketAgent SocketAgent,
IHeartbeatAgent HeartbeatAgent,
IObservable<RawFrame> ReceivedFrames,
IObservable<ConnectionEvent> ConnectionEvents,
IDispatcherAgent Dispatcher,
IImmutableList<UInt16> AvailableChannelIds
);

static Behaviour<Protocol> Disconnected() =>
async context => {
switch (context.Message) {
Expand All @@ -17,7 +27,7 @@ static Behaviour<Protocol> Disconnected() =>
var remainingEndpoints = new Queue<IPEndPoint>(connectionConfiguration.GetConnectionSequence());

if (remainingEndpoints.Count == 0) {
replyChannel.Reply(new Exception("No endpoints specified in connection configuration"));
replyChannel.Fault(new Exception("No endpoints specified in connection configuration"));
return context;
}

Expand All @@ -27,123 +37,110 @@ static Behaviour<Protocol> Disconnected() =>
var endpoint = remainingEndpoints.Dequeue();
var socketAgent = SocketAgent.Create();

switch (await socketAgent.ConnectAsync(endpoint, cts.Token)) {
case ConnectionFailed(var fault) when remainingEndpoints.Any(): {
accumulatedFailures.Add(fault);
continue;
}
case ConnectionFailed(var fault): {
accumulatedFailures.Add(fault);
replyChannel.Reply(new AggregateException("Could not connect to any of the configured endpoints", accumulatedFailures));
return context;
try {
var (connectionEvents, receivedFrames) = await socketAgent.ConnectAsync(endpoint, cts.Token);

var dispatcher = DispatcherAgent.Create();
await dispatcher.DispatchTo(socketAgent, 0);

var handshakeAgent = HandshakeAgent.Create(
receivedFrames : receivedFrames,
connectionEvents : connectionEvents,
dispatcher : dispatcher,
cancellationToken: cts.Token
);

var connectionAgreement = await handshakeAgent.StartHandshake(connectionConfiguration);
await socketAgent.Tune(connectionAgreement.MaxFrameSize);

var heartbeatAgent = HeartbeatAgent.Create();

if (connectionConfiguration.ConnectionIntegrityStrategy.HeartbeatFrequency.HasValue) {
var heartbeatEvents = await heartbeatAgent.Start(receivedFrames, dispatcher, connectionConfiguration.ConnectionIntegrityStrategy.HeartbeatFrequency.Value);
heartbeatEvents.Subscribe(onNext: message => context.Self.PostAsync(new HeartbeatEventEventReceived(message)));
}
case Connected(var connectionEvents, var receivedFrames): {
var dispatcher = DispatcherAgent.Create();
await dispatcher.DispatchTo(socketAgent, 0);

var handshakeAgent = HandshakeAgent.Create(
receivedFrames : receivedFrames,
connectionEvents : connectionEvents,
dispatcher : dispatcher,
cancellationToken: cts.Token
);

switch (await handshakeAgent.StartHandshake(connectionConfiguration)) {
case ConnectionAgreed(var connectionAgreement): {
await socketAgent.Tune(connectionAgreement.MaxFrameSize);

var heartbeatAgent = HeartbeatAgent.Create();

if (connectionConfiguration.ConnectionIntegrityStrategy.HeartbeatFrequency.HasValue) {
var heartbeatEvents = await heartbeatAgent.Start(receivedFrames, dispatcher, connectionConfiguration.ConnectionIntegrityStrategy.HeartbeatFrequency.Value);
heartbeatEvents.Subscribe(onNext: message => context.Self.PostAsync(new HeartbeatEventEventReceived(message)));
}

// If tcp keepalives are enabled, configure the socket...
if (connectionConfiguration.ConnectionIntegrityStrategy.KeepAliveSettings.HasValue) {
var (probeTime, retryInterval, retryCount) = connectionConfiguration.ConnectionIntegrityStrategy.KeepAliveSettings.Value;

await socketAgent.EnableTcpKeepAlives(probeTime, retryInterval, retryCount);
}

replyChannel.Reply(true);

return context with {
Behaviour = Connected(
connectionConfiguration: connectionConfiguration,
socketAgent : socketAgent,
heartbeatAgent : heartbeatAgent,
receivedFrames : receivedFrames,
connectionEvents : connectionEvents,
dispatcher : dispatcher,
availableChannelIds : Enumerable.Range(1, connectionAgreement.MaxChannelCount)
.Select(channelId => (UInt16) channelId)
.ToImmutableList()
)
};
}
case HandshakeFailed(var fault): {
replyChannel.Reply(fault);
return context;
}
}

break;

// If tcp keepalives are enabled, configure the socket...
if (connectionConfiguration.ConnectionIntegrityStrategy.KeepAliveSettings.HasValue) {
var (probeTime, retryInterval, retryCount) = connectionConfiguration.ConnectionIntegrityStrategy.KeepAliveSettings.Value;

await socketAgent.EnableTcpKeepAlives(probeTime, retryInterval, retryCount);
}

replyChannel.Complete();

var state = new State(
ConnectionConfiguration: connectionConfiguration,
SocketAgent : socketAgent,
HeartbeatAgent : heartbeatAgent,
ReceivedFrames : receivedFrames,
ConnectionEvents : connectionEvents,
Dispatcher : dispatcher,
AvailableChannelIds : Enumerable.Range(1, connectionAgreement.MaxChannelCount)
.Select(channelId => (UInt16)channelId)
.ToImmutableList()
);

return context with {
Behaviour = Connected(state)
};
}
catch (Exception fault) {
accumulatedFailures.Add(fault);
}
}
replyChannel.Fault(new AggregateException("Could not connect to any of the configured endpoints", accumulatedFailures));
return context;
}
case Disconnect(var replyChannel): {
replyChannel.Reply(true);
replyChannel.Complete();
return context;
}
default: throw new Exception($"Unexpected message '{context.Message.GetType().FullName}' in '{nameof(Disconnected)}' behaviour.");
}
};

static Behaviour<Protocol> Connected(ConnectionConfiguration connectionConfiguration, ISocketAgent socketAgent, IHeartbeatAgent heartbeatAgent, IObservable<RawFrame> receivedFrames, IObservable<Object> connectionEvents, IDispatcherAgent dispatcher, IImmutableList<UInt16> availableChannelIds) =>
static Behaviour<Protocol> Connected(State state) =>
async context => {
switch (context.Message) {
case HeartbeatEventEventReceived(RemoteFlatline): {
await heartbeatAgent.Stop();
await dispatcher.Stop();
await socketAgent.Disconnect();
await socketAgent.StopAsync();
await state.HeartbeatAgent.Stop();
await state.Dispatcher.Stop();
await state.SocketAgent.Disconnect();
await state.SocketAgent.StopAsync();

return context with { Behaviour = Disconnected() };
}
case Disconnect(var replyChannel): {
await heartbeatAgent.Stop();
await dispatcher.Stop();
await socketAgent.Disconnect();
await socketAgent.StopAsync();
await state.HeartbeatAgent.Stop();
await state.Dispatcher.Stop();
await state.SocketAgent.Disconnect();
await state.SocketAgent.StopAsync();

replyChannel.Reply(true);
replyChannel.Complete();

return context;
}
case OpenChannel(var replyChannel, var cancellationToken): {
var channelId = availableChannelIds[0];
var channelAgent = ChannelAgent.Create(connectionConfiguration.MaximumFrameSize);
var channelId = state.AvailableChannelIds[0];
var channelAgent = ChannelAgent.Create(state.ConnectionConfiguration.MaximumFrameSize);

using var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
cts.CancelAfter(connectionConfiguration.CommandTimeout);

switch (await channelAgent.Open(channelId,receivedFrames.Where(frame => frame.Channel == channelId), connectionEvents, socketAgent, cts.Token)) {
case true: {
replyChannel.Reply(channelAgent);
return context with {
Behaviour = Connected(connectionConfiguration, socketAgent, heartbeatAgent, receivedFrames, connectionEvents, dispatcher, availableChannelIds.Remove(channelId))
};
}
case Exception fault: {
replyChannel.Reply(fault);
break;
}
}

return context;
cts.CancelAfter(state.ConnectionConfiguration.CommandTimeout);

return await channelAgent.Open(channelId, state.ReceivedFrames.Where(frame => frame.Channel == channelId), state.ConnectionEvents, state.SocketAgent, cts.Token)
.ContinueWith(
onCompleted: () => {
replyChannel.Reply(channelAgent);
return context with {
Behaviour = Connected(state with { AvailableChannelIds = state.AvailableChannelIds.Remove(channelId) })
};
},
onFaulted: fault => {
replyChannel.Fault(fault);
return context;
}
);
}
default: throw new Exception($"Unexpected message '{context.Message.GetType().FullName}' in '{nameof(Connected)}' behaviour.");
}
Expand Down
2 changes: 1 addition & 1 deletion Lapine.Core/Agents/AmqpClientAgent.Protocol.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ record EstablishConnection(
) : Protocol;

record OpenChannel(
AsyncReplyChannel ReplyChannel,
AsyncReplyChannel<IChannelAgent> ReplyChannel,
CancellationToken CancellationToken = default
) : Protocol;

Expand Down
8 changes: 4 additions & 4 deletions Lapine.Core/Agents/AmqpClientAgent.Wrapper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@ namespace Lapine.Agents;

static partial class AmqpClientAgent {
class Wrapper(IAgent<Protocol> agent) : IAmqpClientAgent {
async Task<Object> IAmqpClientAgent.EstablishConnection(ConnectionConfiguration configuration, CancellationToken cancellationToken) =>
async Task IAmqpClientAgent.EstablishConnection(ConnectionConfiguration configuration, CancellationToken cancellationToken) =>
await agent.PostAndReplyAsync(replyChannel => new EstablishConnection(configuration, replyChannel, cancellationToken));

async Task<Object> IAmqpClientAgent.OpenChannel(CancellationToken cancellationToken) =>
await agent.PostAndReplyAsync(replyChannel => new OpenChannel(replyChannel, cancellationToken));
async Task<IChannelAgent> IAmqpClientAgent.OpenChannel(CancellationToken cancellationToken) =>
await agent.PostAndReplyAsync<IChannelAgent>(replyChannel => new OpenChannel(replyChannel, cancellationToken));

async Task<Object> IAmqpClientAgent.Disconnect() =>
async Task IAmqpClientAgent.Disconnect() =>
await agent.PostAndReplyAsync(replyChannel => new Disconnect(replyChannel));

async Task IAmqpClientAgent.Stop() =>
Expand Down
6 changes: 3 additions & 3 deletions Lapine.Core/Agents/AmqpClientAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ namespace Lapine.Agents;
using Lapine.Client;

interface IAmqpClientAgent {
Task<Object> EstablishConnection(ConnectionConfiguration configuration, CancellationToken cancellationToken = default);
Task<Object> OpenChannel(CancellationToken cancellationToken = default);
Task<Object> Disconnect();
Task EstablishConnection(ConnectionConfiguration configuration, CancellationToken cancellationToken = default);
Task<IChannelAgent> OpenChannel(CancellationToken cancellationToken = default);
Task Disconnect();
Task Stop();
}

Expand Down
Loading

0 comments on commit 26d5a10

Please sign in to comment.