Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[dotnet] Refactor WebSocket communication for BiDi #12614

Merged
merged 13 commits into from
Aug 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
198 changes: 57 additions & 141 deletions dotnet/src/webdriver/DevTools/DevToolsSession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,7 @@
using System;
using System.Collections.Concurrent;
using System.Globalization;
using System.IO;
using System.Net.Http;
using System.Net.WebSockets;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Newtonsoft.Json;
Expand Down Expand Up @@ -50,15 +47,14 @@ public class DevToolsSession : IDevToolsSession
private bool isDisposed = false;
private string attachedTargetId;

private ClientWebSocket sessionSocket;
private WebSocketConnection connection;
private ConcurrentDictionary<long, DevToolsCommandData> pendingCommands = new ConcurrentDictionary<long, DevToolsCommandData>();
private readonly BlockingCollection<string> messageQueue = new BlockingCollection<string>();
private readonly Task messageQueueMonitorTask;
private long currentCommandId = 0;

private DevToolsDomains domains;

private CancellationTokenSource receiveCancellationToken;
private Task receiveTask;

/// <summary>
/// Initializes a new instance of the DevToolsSession class, using the specified WebSocket endpoint.
/// </summary>
Expand All @@ -76,6 +72,8 @@ public DevToolsSession(string endpointAddress)
{
this.websocketAddress = endpointAddress;
}
this.messageQueueMonitorTask = Task.Run(() => this.MonitorMessageQueue());
this.messageQueueMonitorTask.ConfigureAwait(false);
}

/// <summary>
Expand Down Expand Up @@ -213,15 +211,13 @@ public T GetVersionSpecificDomains<T>() where T : DevToolsSessionDomains

var message = new DevToolsCommandData(Interlocked.Increment(ref this.currentCommandId), this.ActiveSessionId, commandName, commandParameters);

if (this.sessionSocket != null && this.sessionSocket.State == WebSocketState.Open)
if (this.connection != null && this.connection.IsActive)
{
LogTrace("Sending {0} {1}: {2}", message.CommandId, message.CommandName, commandParameters.ToString());

var contents = JsonConvert.SerializeObject(message);
var contentBuffer = Encoding.UTF8.GetBytes(contents);

string contents = JsonConvert.SerializeObject(message);
this.pendingCommands.TryAdd(message.CommandId, message);
await this.sessionSocket.SendAsync(new ArraySegment<byte>(contentBuffer), WebSocketMessageType.Text, true, cancellationToken);
await this.connection.SendData(contents);

var responseWasReceived = await Task.Run(() => message.SyncEvent.Wait(millisecondsTimeout.Value, cancellationToken));

Expand All @@ -230,8 +226,7 @@ public T GetVersionSpecificDomains<T>() where T : DevToolsSessionDomains
throw new InvalidOperationException($"A command response was not received: {commandName}");
}

DevToolsCommandData modified;
if (this.pendingCommands.TryRemove(message.CommandId, out modified))
if (this.pendingCommands.TryRemove(message.CommandId, out DevToolsCommandData modified))
{
if (modified.IsError)
{
Expand All @@ -256,10 +251,7 @@ public T GetVersionSpecificDomains<T>() where T : DevToolsSessionDomains
}
else
{
if (this.sessionSocket != null)
{
LogTrace("WebSocket is not connected (current state is {0}); not sending {1}", this.sessionSocket.State, message.CommandName);
}
LogTrace("WebSocket is not connected; not sending {0}", message.CommandName);
}

return null;
Expand Down Expand Up @@ -330,11 +322,7 @@ protected void Dispose(bool disposing)
{
this.Domains.Target.TargetDetached -= this.OnTargetDetached;
this.pendingCommands.Clear();
this.TerminateSocketConnection();

// Note: Canceling the receive task will dispose of
// the underlying ClientWebSocket instance.
this.CancelReceiveTask();
this.TerminateSocketConnection().GetAwaiter().GetResult();
}

this.isDisposed = true;
Expand Down Expand Up @@ -377,28 +365,6 @@ private async Task<int> InitializeProtocol(int requestedProtocolVersion)
return protocolVersion;
}

private async Task InitializeSocketConnection()
{
LogTrace("Creating WebSocket");
this.sessionSocket = new ClientWebSocket();
this.sessionSocket.Options.KeepAliveInterval = TimeSpan.Zero;

try
{
var timeoutTokenSource = new CancellationTokenSource(this.openConnectionWaitTimeSpan);
await this.sessionSocket.ConnectAsync(new Uri(this.websocketAddress), timeoutTokenSource.Token);
while (this.sessionSocket.State != WebSocketState.Open && !timeoutTokenSource.Token.IsCancellationRequested) ;
}
catch (OperationCanceledException e)
{
throw new WebDriverException(string.Format(CultureInfo.InvariantCulture, "Could not establish WebSocket connection within {0} seconds.", this.openConnectionWaitTimeSpan.TotalSeconds), e);
}

LogTrace("WebSocket created; starting message listener");
this.receiveCancellationToken = new CancellationTokenSource();
this.receiveTask = Task.Run(() => ReceiveMessage().ConfigureAwait(false));
}

private async Task InitializeSession()
{
LogTrace("Creating session");
Expand Down Expand Up @@ -445,116 +411,56 @@ private void OnTargetDetached(object sender, TargetDetachedEventArgs e)
}
}

private void TerminateSocketConnection()
private async Task InitializeSocketConnection()
{
if (this.sessionSocket != null && this.sessionSocket.State == WebSocketState.Open)
{
var closeConnectionTokenSource = new CancellationTokenSource(this.closeConnectionWaitTimeSpan);
try
{
// Since Chromium-based DevTools does not respond to the close
// request with a correctly echoed WebSocket close packet, but
// rather just terminates the socket connection, so we have to
// catch the exception thrown when the socket is terminated
// unexpectedly. Also, because we are using async, waiting for
// the task to complete might throw a TaskCanceledException,
// which we should also catch. Additiionally, there are times
// when mulitple failure modes can be seen, which will throw an
// AggregateException, consolidating several exceptions into one,
// and this too must be caught. Finally, the call to CloseAsync
// will hang even though the connection is already severed.
// Wait for the task to complete for a short time (since we're
// restricted to localhost, the default of 2 seconds should be
// plenty; if not, change the initialization of the timout),
// and if the task is still running, then we assume the connection
// is properly closed.
LogTrace("Sending socket close request");
Task closeTask = Task.Run(async () => await this.sessionSocket.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, string.Empty, closeConnectionTokenSource.Token));
closeTask.Wait();
}
catch (WebSocketException)
{
}
catch (TaskCanceledException)
{
}
catch (AggregateException)
{
}
}
LogTrace("Creating WebSocket");
this.connection = new WebSocketConnection(this.openConnectionWaitTimeSpan, this.closeConnectionWaitTimeSpan);
connection.DataReceived += OnConnectionDataReceived;
await connection.Start(this.websocketAddress);
LogTrace("WebSocket created");
}

private void CancelReceiveTask()
private async Task TerminateSocketConnection()
{
if (this.receiveTask != null)
LogTrace("Closing WebSocket");
if (this.connection != null && this.connection.IsActive)
{
// Wait for the recieve task to be completely exited (for
// whatever reason) before attempting to dispose it. Also
// note that canceling the receive task will dispose of the
// underlying WebSocket.
this.receiveCancellationToken.Cancel();
this.receiveTask.Wait();
this.receiveTask.Dispose();
this.receiveTask = null;
await this.connection.Stop();
await this.ShutdownMessageQueue();
}
LogTrace("WebSocket closed");
}

private async Task ReceiveMessage()
private async Task ShutdownMessageQueue()
{
var cancellationToken = this.receiveCancellationToken.Token;
try
{
var buffer = WebSocket.CreateClientBuffer(1024, 1024);
while (this.sessionSocket.State != WebSocketState.Closed && !cancellationToken.IsCancellationRequested)
{
WebSocketReceiveResult result = await this.sessionSocket.ReceiveAsync(buffer, cancellationToken);
if (!cancellationToken.IsCancellationRequested)
{
if (result.MessageType == WebSocketMessageType.Close && this.sessionSocket.State == WebSocketState.CloseReceived)
{
LogTrace("Got WebSocket close message from browser");
await this.sessionSocket.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, string.Empty, cancellationToken);
}
}

if (this.sessionSocket.State == WebSocketState.Open && result.MessageType != WebSocketMessageType.Close)
{
using (var stream = new MemoryStream())
{
stream.Write(buffer.Array, 0, result.Count);
while (!result.EndOfMessage)
{
result = await this.sessionSocket.ReceiveAsync(buffer, cancellationToken);
stream.Write(buffer.Array, 0, result.Count);
}

stream.Seek(0, SeekOrigin.Begin);
using (var reader = new StreamReader(stream, Encoding.UTF8))
{
string message = reader.ReadToEnd();

// fire and forget
// TODO: we need implement some kind of queue
Task.Run(() => ProcessIncomingMessage(message));
}
}
}
}
}
catch (OperationCanceledException)
// THe WebSockect connection is always closed before this method
// is called, so there will eventually be no more data written
// into the message queue, meaning this loop should be guaranteed
// to complete.
while (this.connection.IsActive)
{
await Task.Delay(TimeSpan.FromMilliseconds(10));
jimevans marked this conversation as resolved.
Show resolved Hide resolved
}
catch (WebSocketException)
{
}
finally

this.messageQueue.CompleteAdding();
await this.messageQueueMonitorTask;
}

private void MonitorMessageQueue()
{
// GetConsumingEnumerable blocks until if BlockingCollection.IsCompleted
// is false (i.e., is still able to be written to), and there are no items
// in the collection. Once any items are added to the collection, the method
// unblocks and we can process any items in the collection at that moment.
// Once IsCompleted is true, the method unblocks with no items in returned
// in the IEnumerable, meaning the foreach loop will terminate gracefully.
foreach (string message in this.messageQueue.GetConsumingEnumerable())
{
this.sessionSocket.Dispose();
this.sessionSocket = null;
this.ProcessMessage(message);
}
}

private void ProcessIncomingMessage(string message)
private void ProcessMessage(string message)
{
var messageObject = JObject.Parse(message);

Expand Down Expand Up @@ -594,7 +500,12 @@ private void ProcessIncomingMessage(string message)

LogTrace("Recieved Event {0}: {1}", method, eventData.ToString());

OnDevToolsEventReceived(new DevToolsEventReceivedEventArgs(methodParts[0], methodParts[1], eventData));
// Dispatch the event on a new thread so that any event handlers
// responding to the event will not block this thread from processing
// DevTools commands that may be sent in the body of the attached
// event handler. If thread pool starvation seems to become a problem,
// we can switch to a channel-based queue.
Task.Run(() => OnDevToolsEventReceived(new DevToolsEventReceivedEventArgs(methodParts[0], methodParts[1], eventData)));

return;
}
Expand All @@ -610,6 +521,11 @@ private void OnDevToolsEventReceived(DevToolsEventReceivedEventArgs e)
}
}

private void OnConnectionDataReceived(object sender, WebSocketConnectionDataReceivedEventArgs e)
{
this.messageQueue.Add(e.Data);
}

private void LogTrace(string message, params object[] args)
{
if (LogMessage != null)
Expand Down
Loading