Skip to content

Commit

Permalink
cleanup cancellation handling in SocketAsyncContext (#53479)
Browse files Browse the repository at this point in the history
* cleanup cancellation handling in SocketAsyncContext

* fix MacOS failback

* Apply suggestions from code review

Co-authored-by: Stephen Toub <stoub@microsoft.com>

* Update src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncContext.Unix.cs

Co-authored-by: Stephen Toub <stoub@microsoft.com>

* address feedback

Co-authored-by: Geoffrey Kizer <geoffrek@windows.microsoft.com>
Co-authored-by: Stephen Toub <stoub@microsoft.com>
  • Loading branch information
3 people authored Jun 9, 2021
1 parent 14343bd commit ac87f00
Showing 1 changed file with 90 additions and 113 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,10 @@ private abstract class AsyncOperation : IThreadPoolWorkItem
private enum State
{
Waiting = 0,
Running = 1,
Complete = 2,
Cancelled = 3
Running,
RunningWithPendingCancellation,
Complete,
Canceled
}

private int _state; // Actually AsyncOperation.State.
Expand Down Expand Up @@ -149,92 +150,103 @@ public void Reset()
#endif
}

public bool TryComplete(SocketAsyncContext context)
public OperationResult TryComplete(SocketAsyncContext context)
{
TraceWithContext(context, "Enter");

bool result = DoTryComplete(context);

TraceWithContext(context, $"Exit, result={result}");
// Set state to Running, unless we've been canceled
int oldState = Interlocked.CompareExchange(ref _state, (int)State.Running, (int)State.Waiting);
if (oldState == (int)State.Canceled)
{
TraceWithContext(context, "Exit, Previously canceled");
return OperationResult.Cancelled;
}

return result;
}
Debug.Assert(oldState == (int)State.Waiting, $"Unexpected operation state: {(State)oldState}");

public bool TrySetRunning()
{
State oldState = (State)Interlocked.CompareExchange(ref _state, (int)State.Running, (int)State.Waiting);
if (oldState == State.Cancelled)
// Try to perform the IO
if (DoTryComplete(context))
{
// This operation has already been cancelled, and had its completion processed.
// Simply return false to indicate no further processing is needed.
return false;
Debug.Assert((State)Volatile.Read(ref _state) is State.Running or State.RunningWithPendingCancellation, "Unexpected operation state");

Volatile.Write(ref _state, (int)State.Complete);

TraceWithContext(context, "Exit, Completed");
return OperationResult.Completed;
}

Debug.Assert(oldState == (int)State.Waiting);
return true;
}
// Set state back to Waiting, unless we were canceled, in which case we have to process cancellation now
int newState;
while (true)
{
int state = Volatile.Read(ref _state);
Debug.Assert(state is (int)State.Running or (int)State.RunningWithPendingCancellation, $"Unexpected operation state: {(State)state}");

public void SetComplete()
{
Debug.Assert(Volatile.Read(ref _state) == (int)State.Running);
newState = (state == (int)State.Running ? (int)State.Waiting : (int)State.Canceled);
if (state == Interlocked.CompareExchange(ref _state, newState, state))
{
break;
}

Volatile.Write(ref _state, (int)State.Complete);
}
// Race to update the state. Loop and try again.
}

public void SetWaiting()
{
Debug.Assert(Volatile.Read(ref _state) == (int)State.Running);
if (newState == (int)State.Canceled)
{
ProcessCancellation();
TraceWithContext(context, "Exit, Newly cancelled");
return OperationResult.Cancelled;
}

Volatile.Write(ref _state, (int)State.Waiting);
TraceWithContext(context, "Exit, Pending");
return OperationResult.Pending;
}

public bool TryCancel()
{
Trace("Enter");

// We're already canceling, so we don't need to still be hooked up to listen to cancellation.
// The cancellation request could also be caused by something other than the token, so it's
// important we clean it up, regardless.
// Note we could be cancelling because of socket close. Regardless, we don't need the registration anymore.
CancellationRegistration.Dispose();

// Try to transition from Waiting to Cancelled
SpinWait spinWait = default;
bool keepWaiting = true;
while (keepWaiting)
int newState;
while (true)
{
int state = Interlocked.CompareExchange(ref _state, (int)State.Cancelled, (int)State.Waiting);
switch ((State)state)
int state = Volatile.Read(ref _state);
if (state is (int)State.Complete or (int)State.Canceled or (int)State.RunningWithPendingCancellation)
{
case State.Running:
// A completion attempt is in progress. Keep busy-waiting.
Trace("Busy wait");
spinWait.SpinOnce();
break;
return false;
}

case State.Complete:
// A completion attempt succeeded. Consider this operation as having completed within the timeout.
Trace("Exit, previously completed");
return false;
newState = (state == (int)State.Waiting ? (int)State.Canceled : (int)State.RunningWithPendingCancellation);
if (state == Interlocked.CompareExchange(ref _state, newState, state))
{
break;
}

case State.Waiting:
// This operation was successfully cancelled.
// Break out of the loop to handle the cancellation
keepWaiting = false;
break;
// Race to update the state. Loop and try again.
}

case State.Cancelled:
// Someone else cancelled the operation.
// The previous canceller will have fired the completion, etc.
Trace("Exit, previously cancelled");
return false;
}
if (newState == (int)State.RunningWithPendingCancellation)
{
// TryComplete will either succeed, or it will see the pending cancellation and deal with it.
return false;
}

Trace("Cancelled, processing completion");
ProcessCancellation();

// The operation successfully cancelled.
// It's our responsibility to set the error code and queue the completion.
DoAbort();
// Note, we leave the operation in the OperationQueue.
// When we get around to processing it, we'll see it's cancelled and skip it.
return true;
}

public void ProcessCancellation()
{
Trace("Enter");

Debug.Assert(_state == (int)State.Canceled);

ErrorCode = SocketError.OperationAborted;

ManualResetEventSlim? e = Event;
if (e != null)
Expand All @@ -252,12 +264,6 @@ public bool TryCancel()
// to do further processing on the item that's still in the list.
ThreadPool.UnsafeQueueUserWorkItem(o => ((AsyncOperation)o!).InvokeCallback(allowPooling: false), this);
}

Trace("Exit");

// Note, we leave the operation in the OperationQueue.
// When we get around to processing it, we'll see it's cancelled and skip it.
return true;
}

public void Dispatch()
Expand Down Expand Up @@ -306,12 +312,9 @@ void IThreadPoolWorkItem.Execute()
// Called when op is not in the queue yet, so can't be otherwise executing
public void DoAbort()
{
Abort();
ErrorCode = SocketError.OperationAborted;
}

protected abstract void Abort();

protected abstract bool DoTryComplete(SocketAsyncContext context);

public abstract void InvokeCallback(bool allowPooling);
Expand Down Expand Up @@ -354,8 +357,6 @@ private abstract class SendOperation : WriteOperation

public SendOperation(SocketAsyncContext context) : base(context) { }

protected sealed override void Abort() { }

public Action<int, byte[]?, int, SocketFlags, SocketError>? Callback { get; set; }

public override void InvokeCallback(bool allowPooling) =>
Expand Down Expand Up @@ -442,8 +443,6 @@ private abstract class ReceiveOperation : ReadOperation

public ReceiveOperation(SocketAsyncContext context) : base(context) { }

protected sealed override void Abort() { }

public Action<int, byte[]?, int, SocketFlags, SocketError>? Callback { get; set; }

public override void InvokeCallback(bool allowPooling) =>
Expand Down Expand Up @@ -554,8 +553,6 @@ private sealed class ReceiveMessageFromOperation : ReadOperation

public ReceiveMessageFromOperation(SocketAsyncContext context) : base(context) { }

protected sealed override void Abort() { }

public Action<int, byte[], int, SocketFlags, IPPacketInformation, SocketError>? Callback { get; set; }

protected override bool DoTryComplete(SocketAsyncContext context) =>
Expand All @@ -579,8 +576,6 @@ private sealed unsafe class BufferPtrReceiveMessageFromOperation : ReadOperation

public BufferPtrReceiveMessageFromOperation(SocketAsyncContext context) : base(context) { }

protected sealed override void Abort() { }

public Action<int, byte[], int, SocketFlags, IPPacketInformation, SocketError>? Callback { get; set; }

protected override bool DoTryComplete(SocketAsyncContext context) =>
Expand All @@ -598,9 +593,6 @@ public AcceptOperation(SocketAsyncContext context) : base(context) { }

public Action<IntPtr, byte[], int, SocketError>? Callback { get; set; }

protected override void Abort() =>
AcceptedFileDescriptor = (IntPtr)(-1);

protected override bool DoTryComplete(SocketAsyncContext context)
{
bool completed = SocketPal.TryCompleteAccept(context._socket, SocketAddress!, ref SocketAddressLen, out AcceptedFileDescriptor, out ErrorCode);
Expand Down Expand Up @@ -631,8 +623,6 @@ public ConnectOperation(SocketAsyncContext context) : base(context) { }

public Action<SocketError>? Callback { get; set; }

protected override void Abort() { }

protected override bool DoTryComplete(SocketAsyncContext context)
{
bool result = SocketPal.TryCompleteConnect(context._socket, SocketAddressLen, out ErrorCode);
Expand All @@ -653,8 +643,6 @@ private sealed class SendFileOperation : WriteOperation

public SendFileOperation(SocketAsyncContext context) : base(context) { }

protected override void Abort() { }

public Action<long, SocketError>? Callback { get; set; }

public override void InvokeCallback(bool allowPooling) =>
Expand Down Expand Up @@ -694,6 +682,13 @@ public void Dispose()
}
}

public enum OperationResult
{
Pending = 0,
Completed = 1,
Cancelled = 2
}

private struct OperationQueue<TOperation>
where TOperation : AsyncOperation
{
Expand Down Expand Up @@ -864,7 +859,7 @@ public bool StartAsyncOperation(SocketAsyncContext context, TOperation operation
}

// Retry the operation.
if (operation.TryComplete(context))
if (operation.TryComplete(context) != OperationResult.Pending)
{
Trace(context, $"Leave, retry succeeded");
return false;
Expand All @@ -880,7 +875,7 @@ static void HandleFailedRegistration(SocketAsyncContext context, TOperation oper
{
// Because the other end close, we expect the operation to complete when we retry it.
// If it doesn't, we fall through and throw an Exception.
if (operation.TryComplete(context))
if (operation.TryComplete(context) != OperationResult.Pending)
{
return;
}
Expand Down Expand Up @@ -979,13 +974,6 @@ internal void ProcessAsyncOperation(TOperation op)
}
}

public enum OperationResult
{
Pending = 0,
Completed = 1,
Cancelled = 2
}

public OperationResult ProcessQueuedOperation(TOperation op)
{
SocketAsyncContext context = op.AssociatedContext;
Expand All @@ -1010,27 +998,15 @@ public OperationResult ProcessQueuedOperation(TOperation op)
}
}

bool wasCompleted = false;
OperationResult result;
while (true)
{
// Try to change the op state to Running.
// If this fails, it means the operation was previously cancelled,
// and we should just remove it from the queue without further processing.
if (!op.TrySetRunning())
{
break;
}

// Try to perform the IO
if (op.TryComplete(context))
result = op.TryComplete(context);
if (result != OperationResult.Pending)
{
op.SetComplete();
wasCompleted = true;
break;
}

op.SetWaiting();

// Check for retry and reset queue state.

using (Lock())
Expand Down Expand Up @@ -1097,7 +1073,8 @@ public OperationResult ProcessQueuedOperation(TOperation op)

nextOp?.Dispatch();

return (wasCompleted ? OperationResult.Completed : OperationResult.Cancelled);
Debug.Assert(result != OperationResult.Pending);
return result;
}

public void CancelAndContinueProcessing(TOperation op)
Expand Down Expand Up @@ -1360,9 +1337,9 @@ private void PerformSyncOperation<TOperation>(ref OperationQueue<TOperation> que
e.Reset();

// We've been signalled to try to process the operation.
OperationQueue<TOperation>.OperationResult result = queue.ProcessQueuedOperation(operation);
if (result == OperationQueue<TOperation>.OperationResult.Completed ||
result == OperationQueue<TOperation>.OperationResult.Cancelled)
OperationResult result = queue.ProcessQueuedOperation(operation);
if (result == OperationResult.Completed ||
result == OperationResult.Cancelled)
{
break;
}
Expand Down

0 comments on commit ac87f00

Please sign in to comment.