diff --git a/agent/app/agent_windows.go b/agent/app/agent_windows.go index 78275dbbd08..1f789e8610d 100644 --- a/agent/app/agent_windows.go +++ b/agent/app/agent_windows.go @@ -135,20 +135,17 @@ func (h *handler) handleWindowsRequests(ctx context.Context, requests <-chan svc // runAgent runs the ECS agent inside a goroutine and waits to be told to exit. func (h *handler) runAgent(ctx context.Context) { agentCtx, cancel := context.WithCancel(ctx) - wg := sync.WaitGroup{} - running := true + indicator := newTermHandlerIndicator() + terminationHandler := func(saver statemanager.Saver, taskEngine engine.TaskEngine) { - // We're using a waitgroup, a context, and a simple flag here. The waitgroup gets added to as soon as this - // handler is invoked (agent.start() ultimately invokes it in a goroutine) so that at the end of the outer - // runAgent() function we know to wait for the handler to complete. We then block on the context being - // canceled; this is our signal that the handler should actually run (happens either when the parent context is - // canceled because Windows told us to exit, or because the agent goroutine below exited unexpectedly). The - // flag gets evaluated so that we know whether to actually save state; if the agent isn't properly running, we - // may not actually have any data to save. - wg.Add(1) - defer wg.Done() + // We're using a custom indicator to record that the handler is scheduled to be executed (has been invoked) and + // to determine whether it should run (we skip when the agent engine has already exited). After recording to + // the indicator that the handler has been invoked, we wait on the context. When we wake up, we determine + // whether to execute or not based on whether the agent is still running. + defer indicator.done() + indicator.setInvoked() <-agentCtx.Done() - if !running { + if !indicator.isAgentRunning() { return } @@ -163,7 +160,7 @@ func (h *handler) runAgent(ctx context.Context) { go func() { h.ecsAgent.start() // should block forever, unless there is an error // TODO: distinguish between recoverable and unrecoverable errors - running = false + indicator.agentStopped() cancel() }() @@ -172,7 +169,7 @@ func (h *handler) runAgent(ctx context.Context) { // wait for the termination handler to run. Once the termination handler runs, we can safely exit. If the agent // exits by itself, the termination handler doesn't need to do anything and skips. If the agent exits before the // termination handler is invoked, we can exit immediately. - wg.Wait() + indicator.wait() } // sleepCtx provides a cancelable sleep @@ -186,3 +183,49 @@ func sleepCtx(ctx context.Context, duration time.Duration) { case <-done: } } + +type termHandlerIndicator struct { + mu sync.Mutex + agentRunning bool + handlerInvoked bool + handlerDone chan struct{} +} + +func newTermHandlerIndicator() *termHandlerIndicator { + return &termHandlerIndicator{ + agentRunning: true, + handlerInvoked: false, + handlerDone: make(chan struct{}), + } +} + +func (t *termHandlerIndicator) isAgentRunning() bool { + t.mu.Lock() + defer t.mu.Unlock() + return t.agentRunning +} + +func (t *termHandlerIndicator) agentStopped() { + t.mu.Lock() + defer t.mu.Unlock() + t.agentRunning = false +} + +func (t *termHandlerIndicator) done() { + close(t.handlerDone) +} + +func (t *termHandlerIndicator) setInvoked() { + t.mu.Lock() + defer t.mu.Unlock() + t.handlerInvoked = true +} + +func (t *termHandlerIndicator) wait() { + t.mu.Lock() + invoked := t.handlerInvoked + t.mu.Unlock() + if invoked { + <-t.handlerDone + } +}