diff --git a/cmd/server/command.go b/cmd/server/command.go index ac90645..1eb5bbc 100644 --- a/cmd/server/command.go +++ b/cmd/server/command.go @@ -27,6 +27,7 @@ func RegisterCommand(rootCmd *cobra.Command) error { serverCfg.RegisterTelemetryConfig(cmd) serverCfg.RegisterClusterConfig(cmd) serverCfg.RegisterMetricProviderConfig(cmd) + serverCfg.RegisterDebugConfig(cmd) logCfg.RegisterConfig(cmd) rootCmd.AddCommand(cmd) @@ -53,6 +54,7 @@ func runServer(_ *cobra.Command, _ []string) { } cfg := &server.Config{ + Debug: serverCfg.GetDebugEnabled(), Cluster: &clusterConfig, MetricProvider: metricProviderConfig, Server: &serverConfig, diff --git a/pkg/config/server/debug.go b/pkg/config/server/debug.go new file mode 100644 index 0000000..ab64f55 --- /dev/null +++ b/pkg/config/server/debug.go @@ -0,0 +1,30 @@ +package server + +import ( + "github.com/spf13/cobra" + "github.com/spf13/viper" +) + +const configKeyDebugEnable = "debug-enabled" + +// GetDebugEnabled is used to identify whether the operator has enabled the server debug API +// endpoints. +func GetDebugEnabled() bool { return viper.GetBool(configKeyDebugEnable) } + +// RegisterDebugConfig registers the CLI flags used to alter the server debug API endpoints. +func RegisterDebugConfig(cmd *cobra.Command) { + flags := cmd.PersistentFlags() + + { + const ( + key = configKeyDebugEnable + longOpt = "debug-enabled" + defaultValue = false + description = "Specifies if the debugging HTTP endpoints should be enabled" + ) + + flags.Bool(longOpt, defaultValue, description) + _ = viper.BindPFlag(key, flags.Lookup(longOpt)) + viper.SetDefault(key, defaultValue) + } +} diff --git a/pkg/config/server/debug_test.go b/pkg/config/server/debug_test.go new file mode 100644 index 0000000..c6574b5 --- /dev/null +++ b/pkg/config/server/debug_test.go @@ -0,0 +1,15 @@ +package server + +import ( + "testing" + + "github.com/spf13/cobra" + "github.com/stretchr/testify/assert" +) + +func Test_DebugConfig(t *testing.T) { + fakeCMD := &cobra.Command{} + RegisterDebugConfig(fakeCMD) + cfg := GetClusterConfig() + assert.False(t, GetDebugEnabled(), cfg.Addr) +} diff --git a/pkg/server/config.go b/pkg/server/config.go index 8821d78..e77e9ef 100644 --- a/pkg/server/config.go +++ b/pkg/server/config.go @@ -5,6 +5,7 @@ import ( ) type Config struct { + Debug bool Cluster *serverCfg.ClusterConfig MetricProvider *serverCfg.MetricProviderConfig Server *serverCfg.Config @@ -59,3 +60,17 @@ const ( routeSystemInfoName = "GetSystemInfo" routeSystemInfoPattern = "/v1/system/info" ) + +// Debug server routes. +const ( + routeGetDebugPPROFName = "GetDebugPPROF" + routeGetDebugPPROFPattern = "/debug/pprof/" + routeGetDebugPPROFCMDLineName = "GetDebugPPROFCMDLine" + routeGetDebugPPROFCMDLinePattern = "/debug/pprof/cmdline" + routeGetDebugPPROFProfileName = "GetDebugPPROFProfile" + routeGetDebugPPROFProfilePattern = "/debug/pprof/profile" + routeGetDebugPPROFSymbolName = "GetDebugPPROFSymbol" + routeGetDebugPPROFSymbolPattern = "/debug/pprof/symbol" + routeGetDebugPPROFTraceName = "GetDebugPPROFTrace" + routeGetDebugPPROFTracePattern = "/debug/pprof/trace" +) diff --git a/pkg/server/routes.go b/pkg/server/routes.go index f8c5fca..1d5e0e0 100644 --- a/pkg/server/routes.go +++ b/pkg/server/routes.go @@ -2,6 +2,7 @@ package server import ( "net/http" + "net/http/pprof" policyV1 "github.com/jrasell/sherpa/pkg/policy/v1" scaleV1 "github.com/jrasell/sherpa/pkg/scale/v1" @@ -33,6 +34,12 @@ func (h *HTTPServer) setupRoutes() *router.RouteTable { policyRoutes := h.setupPolicyRoutes() r = append(r, policyRoutes) + // Setup the server debug routes if enabled. + if h.cfg.Debug { + debugRoutes := h.setupDebugRoutes() + r = append(r, debugRoutes) + } + // Setup the UI routes if it is enabled. if h.cfg.Server.UI { uiRoutes := h.setupUIRoutes() @@ -213,3 +220,40 @@ func (h *HTTPServer) setupAPIPolicyRoutes() []router.Route { }, } } + +func (h *HTTPServer) setupDebugRoutes() []router.Route { + h.logger.Debug().Msg("setting up server Debug routes") + + return router.Routes{ + router.Route{ + Name: routeGetDebugPPROFName, + Method: http.MethodGet, + Pattern: routeGetDebugPPROFPattern, + HandlerFunc: pprof.Index, + }, + router.Route{ + Name: routeGetDebugPPROFCMDLineName, + Method: http.MethodGet, + Pattern: routeGetDebugPPROFCMDLinePattern, + HandlerFunc: pprof.Cmdline, + }, + router.Route{ + Name: routeGetDebugPPROFProfileName, + Method: http.MethodGet, + Pattern: routeGetDebugPPROFProfilePattern, + HandlerFunc: pprof.Profile, + }, + router.Route{ + Name: routeGetDebugPPROFSymbolName, + Method: http.MethodGet, + Pattern: routeGetDebugPPROFSymbolPattern, + HandlerFunc: pprof.Symbol, + }, + router.Route{ + Name: routeGetDebugPPROFTraceName, + Method: http.MethodGet, + Pattern: routeGetDebugPPROFTracePattern, + HandlerFunc: pprof.Trace, + }, + } +}