diff --git a/server/server.go b/server/server.go index 46acff9f..abcf5389 100644 --- a/server/server.go +++ b/server/server.go @@ -73,8 +73,9 @@ type OpAMPServer interface { // accept connections. Start(settings StartSettings) error - // Stop accepting new connections and close all current connections. This should - // block until all connections are closed. + // Stop accepting new connections and close all current connections. + // This operation should block until both the server socket and all + // connections have been closed. Stop(ctx context.Context) error // Addr returns the network address Server is listening on. Nil if not started. diff --git a/server/serverimpl.go b/server/serverimpl.go index dcdd044c..5b7903c1 100644 --- a/server/serverimpl.go +++ b/server/serverimpl.go @@ -19,17 +19,17 @@ import ( serverTypes "github.com/open-telemetry/opamp-go/server/types" ) -var ( - errAlreadyStarted = errors.New("already started") +var errAlreadyStarted = errors.New("already started") + +const ( + defaultOpAMPPath = "/v1/opamp" + headerContentType = "Content-Type" + headerContentEncoding = "Content-Encoding" + headerAcceptEncoding = "Accept-Encoding" + contentEncodingGzip = "gzip" + contentTypeProtobuf = "application/x-protobuf" ) -const defaultOpAMPPath = "/v1/opamp" -const headerContentType = "Content-Type" -const headerContentEncoding = "Content-Encoding" -const headerAcceptEncoding = "Accept-Encoding" -const contentEncodingGzip = "gzip" -const contentTypeProtobuf = "application/x-protobuf" - type server struct { logger types.Logger settings Settings @@ -39,7 +39,8 @@ type server struct { // The listening HTTP Server after successful Start() call. Nil if Start() // is not called or was not successful. - httpServer *http.Server + httpServer *http.Server + httpServerServeWg *sync.WaitGroup // The network address Server is listening on. Nil if not started. addr net.Addr @@ -108,6 +109,9 @@ func (s *server) Start(settings StartSettings) error { ConnContext: contextWithConn, } s.httpServer = hs + httpServerServeWg := sync.WaitGroup{} + httpServerServeWg.Add(1) + s.httpServerServeWg = &httpServerServeWg listenAddr := s.httpServer.Addr @@ -118,7 +122,10 @@ func (s *server) Start(settings StartSettings) error { } err = s.startHttpServer( listenAddr, - func(l net.Listener) error { return hs.ServeTLS(l, "", "") }, + func(l net.Listener) error { + defer httpServerServeWg.Done() + return hs.ServeTLS(l, "", "") + }, ) } else { if listenAddr == "" { @@ -126,7 +133,10 @@ func (s *server) Start(settings StartSettings) error { } err = s.startHttpServer( listenAddr, - func(l net.Listener) error { return hs.Serve(l) }, + func(l net.Listener) error { + defer httpServerServeWg.Done() + return hs.Serve(l) + }, ) } return err @@ -159,7 +169,16 @@ func (s *server) Stop(ctx context.Context) error { defer func() { s.httpServer = nil }() // This stops accepting new connections. TODO: close existing // connections and wait them to be terminated. - return s.httpServer.Shutdown(ctx) + err := s.httpServer.Shutdown(ctx) + if err != nil { + return err + } + select { + case <-ctx.Done(): + return ctx.Err() + default: + s.httpServerServeWg.Wait() + } } return nil } @@ -366,7 +385,6 @@ func (s *server) handlePlainHTTPRequest(req *http.Request, w http.ResponseWriter w.Header().Set(headerContentEncoding, contentEncodingGzip) } _, err = w.Write(bodyBytes) - if err != nil { s.logger.Debugf(req.Context(), "Cannot send HTTP response: %v", err) } diff --git a/server/serverimpl_test.go b/server/serverimpl_test.go index ca7a5c33..90de3af7 100644 --- a/server/serverimpl_test.go +++ b/server/serverimpl_test.go @@ -57,6 +57,36 @@ func TestServerStartStop(t *testing.T) { assert.NoError(t, err) } +func TestServerStartStopWithCancel(t *testing.T) { + srv := startServer(t, &StartSettings{}) + + err := srv.Start(StartSettings{}) + assert.ErrorIs(t, err, errAlreadyStarted) + + canceledCtx, cancel := context.WithCancel(context.Background()) + cancel() + + err = srv.Stop(canceledCtx) + assert.ErrorIs(t, err, context.Canceled) +} + +func TestServerStartStopIdempotency(t *testing.T) { + endpoint := testhelpers.GetAvailableLocalAddress() + for i := 0; i < 10; i++ { + t.Run(fmt.Sprintf("Attempt #%d: ", i), func(t *testing.T) { + srv := startServer(t, &StartSettings{ + ListenEndpoint: endpoint, + }) + + err := srv.Start(StartSettings{}) + assert.ErrorIs(t, err, errAlreadyStarted) + + err = srv.Stop(context.Background()) + assert.NoError(t, err) + }) + } +} + func TestServerStartStopWithMiddleware(t *testing.T) { var addedMiddleware atomic.Bool assert.False(t, addedMiddleware.Load()) @@ -620,7 +650,6 @@ func TestServerAttachSendMessagePlainHTTP(t *testing.T) { } func TestServerHonoursClientRequestContentEncoding(t *testing.T) { - hc := http.Client{} var rcvMsg atomic.Value var onConnectedCalled, onCloseCalled int32 @@ -698,7 +727,6 @@ func TestServerHonoursClientRequestContentEncoding(t *testing.T) { } func TestServerHonoursAcceptEncoding(t *testing.T) { - hc := http.Client{} var rcvMsg atomic.Value var onConnectedCalled, onCloseCalled int32 @@ -985,7 +1013,6 @@ func BenchmarkSendToClient(b *testing.B) { } srv := New(&sharedinternal.NopLogger{}) err := srv.Start(*settings) - if err != nil { b.Error(err) } @@ -1017,7 +1044,6 @@ func BenchmarkSendToClient(b *testing.B) { for _, conn := range serverConnections { err := conn.Send(context.Background(), &protobufs.ServerToAgent{}) - if err != nil { b.Error(err) } @@ -1026,5 +1052,4 @@ func BenchmarkSendToClient(b *testing.B) { for _, conn := range clientConnections { conn.Close() } - }