From 07dcd9e716d00fee9d6bc3ed31348f7b20257a0a Mon Sep 17 00:00:00 2001 From: Tiffany Jernigan Date: Fri, 15 Apr 2016 17:21:05 -0700 Subject: [PATCH] Fixes 821: Watch not closed by server when snapd exits --- mgmt/rest/client/client_func_test.go | 3 +- mgmt/rest/client/client_tribe_func_test.go | 3 +- mgmt/rest/rest_func_test.go | 3 +- mgmt/rest/server.go | 54 +++++++++++++------ mgmt/rest/task.go | 14 +++++ mgmt/rest/tribe_test.go | 3 +- snapd.go | 62 +++++++++++----------- 7 files changed, 92 insertions(+), 50 deletions(-) diff --git a/mgmt/rest/client/client_func_test.go b/mgmt/rest/client/client_func_test.go index 03d4e697a..45592d3ef 100644 --- a/mgmt/rest/client/client_func_test.go +++ b/mgmt/rest/client/client_func_test.go @@ -91,7 +91,8 @@ func startAPI() string { } log.Fatal(err) }(r.Err()) - r.Start("127.0.0.1:0") + r.SetAddress("127.0.0.1:0") + r.Start() time.Sleep(100 * time.Millisecond) return fmt.Sprintf("http://localhost:%d", r.Port()) } diff --git a/mgmt/rest/client/client_tribe_func_test.go b/mgmt/rest/client/client_tribe_func_test.go index ccaeeab34..ad02bc714 100644 --- a/mgmt/rest/client/client_tribe_func_test.go +++ b/mgmt/rest/client/client_tribe_func_test.go @@ -201,7 +201,8 @@ func startTribes(count int) []int { r.BindMetricManager(c) r.BindTaskManager(s) r.BindTribeManager(t) - r.Start(":" + strconv.Itoa(mgtPort)) + r.SetAddress(":" + strconv.Itoa(mgtPort)) + r.Start() wg.Add(1) timer := time.After(10 * time.Second) go func(port int) { diff --git a/mgmt/rest/rest_func_test.go b/mgmt/rest/rest_func_test.go index b222591ab..6c6e0579d 100644 --- a/mgmt/rest/rest_func_test.go +++ b/mgmt/rest/rest_func_test.go @@ -470,7 +470,8 @@ func startAPI(cfg *mockConfig) *restAPIInstance { } log.Fatal(err) }(r.Err()) - r.Start("127.0.0.1:0") + r.SetAddress("127.0.0.1:0") + r.Start() time.Sleep(time.Millisecond * 100) return &restAPIInstance{ port: r.Port(), diff --git a/mgmt/rest/server.go b/mgmt/rest/server.go index 65a92ee54..d94e1cdec 100644 --- a/mgmt/rest/server.go +++ b/mgmt/rest/server.go @@ -28,6 +28,7 @@ import ( "net" "net/http" "strings" + "sync" "time" log "github.com/Sirupsen/logrus" @@ -120,27 +121,31 @@ type managesConfig interface { } type Server struct { - mm managesMetrics - mt managesTasks - tr managesTribe - mc managesConfig - n *negroni.Negroni - r *httprouter.Router - tls *tls - auth bool - authpwd string - addr net.Addr - err chan error + mm managesMetrics + mt managesTasks + tr managesTribe + mc managesConfig + n *negroni.Negroni + r *httprouter.Router + tls *tls + auth bool + authpwd string + addrString string + addr net.Addr + wg sync.WaitGroup + killChan chan struct{} + err chan error } -// func New(https bool, cpath, kpath string) (*Server, error) { +// New creates a REST API server with a given config func New(cfg *Config) (*Server, error) { // pull a few parameters from the configuration passed in by snapd https := cfg.HTTPS cpath := cfg.RestCertificate kpath := cfg.RestKey s := &Server{ - err: make(chan error), + err: make(chan error), + killChan: make(chan struct{}), } if https { var err error @@ -163,7 +168,7 @@ func New(cfg *Config) (*Server, error) { return s, nil } -// get the default snapd configuration +// GetDefaultConfig gets the default snapd configuration func GetDefaultConfig() *Config { return &Config{ Enable: defaultEnable, @@ -204,9 +209,26 @@ func (s *Server) authMiddleware(rw http.ResponseWriter, r *http.Request, next ht } } -func (s *Server) Start(addrString string) { +func (s *Server) Name() string { + return "REST" +} + +func (s *Server) SetAddress(addrString string) { + s.addrString = addrString +} + +func (s *Server) Start() error { s.addRoutes() - s.run(addrString) + s.run(s.addrString) + restLogger.WithFields(log.Fields{ + "_block": "start", + }).Info("REST started") + return nil +} + +func (s *Server) Stop() { + close(s.killChan) + s.wg.Wait() } func (s *Server) Err() <-chan error { diff --git a/mgmt/rest/task.go b/mgmt/rest/task.go index 707249562..671e2a372 100644 --- a/mgmt/rest/task.go +++ b/mgmt/rest/task.go @@ -142,6 +142,8 @@ func (s *Server) getTask(w http.ResponseWriter, r *http.Request, p httprouter.Pa } func (s *Server) watchTask(w http.ResponseWriter, r *http.Request, p httprouter.Params) { + s.wg.Add(1) + defer s.wg.Done() logger := log.WithFields(log.Fields{ "_module": "api", "_block": "watch-task", @@ -229,6 +231,18 @@ func (s *Server) watchTask(w http.ResponseWriter, r *http.Request, p httprouter. tc.Close() // exit since this client is no longer listening respond(200, &rbody.ScheduledTaskWatchingEnded{}, w) + return + case <-s.killChan: + logger.WithFields(log.Fields{ + "task-id": id, + }).Debug("snapd exiting; disconnecting client") + // Flush since we are sending nothing new + flusher.Flush() + // Close out watcher removing it from the scheduler + tc.Close() + // exit since this client is no longer listening + respond(200, &rbody.ScheduledTaskWatchingEnded{}, w) + return } } } diff --git a/mgmt/rest/tribe_test.go b/mgmt/rest/tribe_test.go index ba0095fdc..44bbd2b79 100644 --- a/mgmt/rest/tribe_test.go +++ b/mgmt/rest/tribe_test.go @@ -701,7 +701,8 @@ func startTribes(count int, seed string) ([]int, int) { r.BindMetricManager(c) r.BindTaskManager(s) r.BindTribeManager(t) - r.Start(":" + strconv.Itoa(mgtPort)) + r.SetAddress(":" + strconv.Itoa(mgtPort)) + r.Start() wg.Add(1) timer := time.After(10 * time.Second) go func(port int) { diff --git a/snapd.go b/snapd.go index 016961f77..2d957139e 100644 --- a/snapd.go +++ b/snapd.go @@ -289,6 +289,38 @@ func action(ctx *cli.Context) { tr = t } + //Setup RESTful API if it was enabled in the configuration + if cfg.RestAPI.Enable { + r, err := rest.New(cfg.RestAPI) + if err != nil { + log.Fatal(err) + } + r.BindMetricManager(c) + r.BindConfigManager(c.Config) + r.BindTaskManager(s) + + //Rest Authentication + if cfg.RestAPI.RestAuth { + log.Info("REST API authentication is enabled") + r.SetAPIAuth(cfg.RestAPI.RestAuth) + log.Info("REST API authentication password is set") + r.SetAPIAuthPwd(cfg.RestAPI.RestAuthPassword) + if !cfg.RestAPI.HTTPS { + log.Warning("Using REST API authentication without HTTPS enabled.") + } + } + + if tr != nil { + r.BindTribeManager(tr) + } + go monitorErrors(r.Err()) + r.SetAddress(fmt.Sprintf(":%d", cfg.RestAPI.Port)) + coreModules = append(coreModules, r) + log.Info("REST API is enabled") + } else { + log.Info("REST API is disabled") + } + // Set interrupt handling so we can die gracefully. startInterruptHandling(coreModules...) @@ -469,36 +501,6 @@ func action(ctx *cli.Context) { log.Info("auto discover path is disabled") } - //Setup RESTful API if it was enbled in th configuration - if cfg.RestAPI.Enable { - r, err := rest.New(cfg.RestAPI) - if err != nil { - log.Fatal(err) - } - r.BindMetricManager(c) - r.BindConfigManager(c.Config) - r.BindTaskManager(s) - //Rest Authentication - if cfg.RestAPI.RestAuth { - log.Info("REST API authentication is enabled") - r.SetAPIAuth(cfg.RestAPI.RestAuth) - log.Info("REST API authentication password is set") - r.SetAPIAuthPwd(cfg.RestAPI.RestAuthPassword) - if !cfg.RestAPI.HTTPS { - log.Warning("Using REST API authentication without HTTPS enabled.") - } - } - - if tr != nil { - r.BindTribeManager(tr) - } - go monitorErrors(r.Err()) - r.Start(fmt.Sprintf(":%d", cfg.RestAPI.Port)) - log.Info("REST API is enabled") - } else { - log.Info("REST API is disabled") - } - log.WithFields( log.Fields{ "block": "main",