From e0b95b9a0a6c841190950f36eee61b58abb6e66c Mon Sep 17 00:00:00 2001
From: Eno Compton
Date: Thu, 28 Jul 2022 11:12:03 -0600
Subject: [PATCH] feat: add support for health-check flag (#85)

This is an adaptation of
https://github.com/GoogleCloudPlatform/cloudsql-proxy/pull/1271.
---
 README.md                                |   2 +-
 cmd/root.go                              |  98 +++++----
 cmd/root_test.go                         |   4 -
 internal/healthcheck/healthcheck.go      | 110 +++++++++++
 internal/healthcheck/healthcheck_test.go | 241 +++++++++++++++++++++++
 internal/proxy/proxy.go                  |  49 ++++-
 internal/proxy/proxy_test.go             | 143 +++++++++++++-
 7 files changed, 596 insertions(+), 51 deletions(-)
 create mode 100644 internal/healthcheck/healthcheck.go
 create mode 100644 internal/healthcheck/healthcheck_test.go

diff --git a/README.md b/README.md
index a6792a21..bdd99720 100644
--- a/README.md
+++ b/README.md
@@ -271,7 +271,7 @@ and from a AlloyDB instance.
 The `ALL_PROXY` environment variable supports `socks5h` protocol.
 
 The `HTTPS_PROXY` (or `HTTP_PROXY`) specifies the proxy for all HTTP(S) traffic
-to the SQL Admin API. Specifying `HTTPS_PROXY` or `HTTP_PROXY` is only necessary
+to the AlloyDB Admin API. Specifying `HTTPS_PROXY` or `HTTP_PROXY` is only necessary
 when you want to proxy this traffic. Otherwise, it is optional. See
 [`http.ProxyFromEnvironment`](https://pkg.go.dev/net/http@go1.17.3#ProxyFromEnvironment)
 for possible values.

diff --git a/cmd/root.go b/cmd/root.go
index cb5ef182..6453c1f6 100644
--- a/cmd/root.go
+++ b/cmd/root.go
@@ -32,6 +32,7 @@ import (
 	"contrib.go.opencensus.io/exporter/prometheus"
 	"contrib.go.opencensus.io/exporter/stackdriver"
 	"github.com/GoogleCloudPlatform/alloydb-auth-proxy/alloydb"
+	"github.com/GoogleCloudPlatform/alloydb-auth-proxy/internal/healthcheck"
 	"github.com/GoogleCloudPlatform/alloydb-auth-proxy/internal/log"
 	"github.com/GoogleCloudPlatform/alloydb-auth-proxy/internal/proxy"
 	"github.com/spf13/cobra"
@@ -88,6 +89,7 @@ type Command struct {
 	telemetryProject    string
 	telemetryPrefix     string
 	prometheusNamespace string
+	healthCheck         bool
 	httpPort            string
 }
@@ -186,6 +188,10 @@ the maximum time has passed. Defaults to 0s.`)
 		"Enable Prometheus for metric collection using the provided namespace")
 	cmd.PersistentFlags().StringVar(&c.httpPort, "http-port", "9090",
 		"Port for the Prometheus server to use")
+	cmd.PersistentFlags().BoolVar(&c.healthCheck, "health-check", false,
+		`Enables HTTP endpoints /startup, /liveness, and /readiness
+that report on the proxy's health. Endpoints are available on localhost
+only. Uses the port specified by the http-port flag.`)
 
 	// Global and per instance flags
 	cmd.PersistentFlags().StringVarP(&c.conf.Addr, "address", "a", "127.0.0.1",
@@ -241,18 +247,18 @@ func parseConfig(cmd *Command, conf *proxy.Config, args []string) error {
 		cmd.logger.Infof("Using API Endpoint %v", conf.APIEndpointURL)
 	}
 
-	if userHasSet("http-port") && !userHasSet("prometheus-namespace") {
-		return newBadCommandError("cannot specify --http-port without --prometheus-namespace")
+	if userHasSet("http-port") && !userHasSet("prometheus-namespace") && !userHasSet("health-check") {
+		cmd.logger.Infof("Ignoring --http-port because neither --prometheus-namespace nor --health-check was set")
 	}
 
 	if !userHasSet("telemetry-project") && userHasSet("telemetry-prefix") {
-		cmd.logger.Infof("Ignoring telementry-prefix as telemetry-project was not set")
+		cmd.logger.Infof("Ignoring --telemetry-prefix as --telemetry-project was not set")
 	}
 	if !userHasSet("telemetry-project") && userHasSet("disable-metrics") {
-		cmd.logger.Infof("Ignoring disable-metrics as telemetry-project was not set")
+		cmd.logger.Infof("Ignoring --disable-metrics as --telemetry-project was not set")
 	}
 	if !userHasSet("telemetry-project") && userHasSet("disable-traces") {
-		cmd.logger.Infof("Ignoring disable-traces as telemetry-project was not set")
+		cmd.logger.Infof("Ignoring --disable-traces as --telemetry-project was not set")
 	}
 
 	var ics []proxy.InstanceConnConfig
@@ -328,9 +334,8 @@ func runSignalWrapper(cmd *Command) error {
 	ctx, cancel := context.WithCancel(cmd.Context())
 	defer cancel()
 
-	// Configure Cloud Trace and/or Cloud Monitoring based on command
-	// invocation. If a project has not been enabled, no traces or metrics are
-	// enabled.
+	// Configure collectors before the proxy has started to ensure we are
+	// collecting metrics before *ANY* AlloyDB Admin API calls are made.
 	enableMetrics := !cmd.disableMetrics
 	enableTraces := !cmd.disableTraces
 	if cmd.telemetryProject != "" && (enableMetrics || enableTraces) {
@@ -358,40 +363,22 @@ func runSignalWrapper(cmd *Command) error {
 		}()
 	}
 
-	shutdownCh := make(chan error)
-
+	var (
+		needsHTTPServer bool
+		mux             = http.NewServeMux()
+	)
 	if cmd.prometheusNamespace != "" {
+		needsHTTPServer = true
 		e, err := prometheus.NewExporter(prometheus.Options{
 			Namespace: cmd.prometheusNamespace,
 		})
 		if err != nil {
 			return err
 		}
-		mux := http.NewServeMux()
 		mux.Handle("/metrics", e)
-		addr := fmt.Sprintf("localhost:%s", cmd.httpPort)
-		server := &http.Server{Addr: addr, Handler: mux}
-		go func() {
-			select {
-			case <-ctx.Done():
-				// Give the HTTP server a second to shutdown cleanly.
-				ctx2, _ := context.WithTimeout(context.Background(), time.Second)
-				if err := server.Shutdown(ctx2); err != nil {
-					cmd.logger.Errorf("failed to shutdown Prometheus HTTP server: %v\n", err)
-				}
-			}
-		}()
-		go func() {
-			err := server.ListenAndServe()
-			if err == http.ErrServerClosed {
-				return
-			}
-			if err != nil {
-				shutdownCh <- fmt.Errorf("failed to start prometheus HTTP server: %v", err)
-			}
-		}()
 	}
 
+	shutdownCh := make(chan error)
 	// watch for sigterm / sigint signals
 	signals := make(chan os.Signal, 1)
 	signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT)
@@ -429,18 +416,55 @@ func runSignalWrapper(cmd *Command) error {
 		cmd.logger.Errorf("The proxy has encountered a terminal error: %v", err)
 		return err
 	case p = <-startCh:
+		cmd.logger.Infof("The proxy has started successfully and is ready for new connections!")
 	}
-	cmd.logger.Infof("The proxy has started successfully and is ready for new connections!")
-	defer p.Close()
+	defer func() {
+		if cErr := p.Close(); cErr != nil {
+			cmd.logger.Errorf("error during shutdown: %v", cErr)
+		}
+	}()
 
-	go func() {
-		shutdownCh <- p.Serve(ctx)
-	}()
+	notify := func() {}
+	if cmd.healthCheck {
+		needsHTTPServer = true
+		hc := healthcheck.NewCheck(p, cmd.logger)
+		mux.HandleFunc("/startup", hc.HandleStartup)
+		mux.HandleFunc("/readiness", hc.HandleReadiness)
+		mux.HandleFunc("/liveness", hc.HandleLiveness)
+		notify = hc.NotifyStarted
+	}
+
+	// Start the HTTP server if anything requiring HTTP is specified.
+	if needsHTTPServer {
+		server := &http.Server{
+			Addr:    fmt.Sprintf("localhost:%s", cmd.httpPort),
+			Handler: mux,
+		}
+		// Start the HTTP server.
+		go func() {
+			err := server.ListenAndServe()
+			if err == http.ErrServerClosed {
+				return
+			}
+			if err != nil {
+				shutdownCh <- fmt.Errorf("failed to start HTTP server: %v", err)
+			}
+		}()
+		// Handle shutdown of the HTTP server gracefully.
+		go func() {
+			select {
+			case <-ctx.Done():
+				// Give the HTTP server a second to shut down cleanly.
+				ctx2, cancel := context.WithTimeout(context.Background(), time.Second)
+				defer cancel()
+				if err := server.Shutdown(ctx2); err != nil {
+					cmd.logger.Errorf("failed to shut down HTTP server: %v\n", err)
+				}
+			}
+		}()
+	}
+
+	go func() { shutdownCh <- p.Serve(ctx, notify) }()
 
 	err := <-shutdownCh
 	switch {
diff --git a/cmd/root_test.go b/cmd/root_test.go
index 311696d9..55497456 100644
--- a/cmd/root_test.go
+++ b/cmd/root_test.go
@@ -350,10 +350,6 @@ func TestNewCommandWithErrors(t *testing.T) {
 			desc: "using the unix socket and port query params",
 			args: []string{"projects/proj/locations/region/clusters/clust/instances/inst?unix-socket=/path&port=5000"},
 		},
-		{
-			desc: "enabling a Prometheus port without a namespace",
-			args: []string{"--http-port", "1111", "proj:region:inst"},
-		},
 		{
 			desc: "using an invalid url for host flag",
 			args: []string{"--host", "https://invalid:url[/]", "proj:region:inst"},
diff --git a/internal/healthcheck/healthcheck.go b/internal/healthcheck/healthcheck.go
new file mode 100644
index 00000000..ecb15d7f
--- /dev/null
+++ b/internal/healthcheck/healthcheck.go
@@ -0,0 +1,110 @@
+// Copyright 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package healthcheck tests and communicates the health of the AlloyDB Auth
+// Proxy.
+package healthcheck
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"net/http"
+	"sync"
+
+	"github.com/GoogleCloudPlatform/alloydb-auth-proxy/alloydb"
+	"github.com/GoogleCloudPlatform/alloydb-auth-proxy/internal/proxy"
+)
+
+// Check provides HTTP handlers for use as health checks, typically in a
+// Kubernetes context.
+type Check struct {
+	once    *sync.Once
+	started chan struct{}
+	proxy   *proxy.Client
+	logger  alloydb.Logger
+}
+
+// NewCheck is the initializer for Check.
+func NewCheck(p *proxy.Client, l alloydb.Logger) *Check {
+	return &Check{
+		once:    &sync.Once{},
+		started: make(chan struct{}),
+		proxy:   p,
+		logger:  l,
+	}
+}
+
+// NotifyStarted notifies the check that the proxy has started up successfully.
+func (c *Check) NotifyStarted() {
+	c.once.Do(func() { close(c.started) })
+}
+
+// HandleStartup reports whether the Check has been notified of startup.
+func (c *Check) HandleStartup(w http.ResponseWriter, _ *http.Request) {
+	select {
+	case <-c.started:
+		w.WriteHeader(http.StatusOK)
+		w.Write([]byte("ok"))
+	default:
+		w.WriteHeader(http.StatusServiceUnavailable)
+		w.Write([]byte("error"))
+	}
+}
+
+var errNotStarted = errors.New("proxy is not started")
+
+// HandleReadiness ensures the Check has been notified of successful startup,
+// that the proxy has not reached maximum connections, and that all connections
+// are healthy.
+func (c *Check) HandleReadiness(w http.ResponseWriter, _ *http.Request) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	select {
+	case <-c.started:
+	default:
+		c.logger.Errorf("[Health Check] Readiness failed: %v", errNotStarted)
+		w.WriteHeader(http.StatusServiceUnavailable)
+		w.Write([]byte(errNotStarted.Error()))
+		return
+	}
+
+	if open, max := c.proxy.ConnCount(); max > 0 && open == max {
+		err := fmt.Errorf("max connections reached (open = %v, max = %v)", open, max)
+		c.logger.Errorf("[Health Check] Readiness failed: %v", err)
+		w.WriteHeader(http.StatusServiceUnavailable)
+		w.Write([]byte(err.Error()))
+		return
+	}
+
+	err := c.proxy.CheckConnections(ctx)
+	if err != nil {
+		c.logger.Errorf("[Health Check] Readiness failed: %v", err)
+		w.WriteHeader(http.StatusServiceUnavailable)
+		w.Write([]byte(err.Error()))
+		return
+	}
+
+	w.WriteHeader(http.StatusOK)
+	w.Write([]byte("ok"))
+}
+
+// HandleLiveness indicates the process is up and responding to HTTP requests.
+// If this check fails (because it's not reachable), the process is in a bad
+// state and should be restarted.
+func (c *Check) HandleLiveness(w http.ResponseWriter, _ *http.Request) {
+	w.WriteHeader(http.StatusOK)
+	w.Write([]byte("ok"))
+}
diff --git a/internal/healthcheck/healthcheck_test.go b/internal/healthcheck/healthcheck_test.go
new file mode 100644
index 00000000..ed8f0286
--- /dev/null
+++ b/internal/healthcheck/healthcheck_test.go
@@ -0,0 +1,241 @@
+// Copyright 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package healthcheck_test + +import ( + "context" + "errors" + "fmt" + "io/ioutil" + "net" + "net/http" + "net/http/httptest" + "os" + "strings" + "testing" + "time" + + "cloud.google.com/go/alloydbconn" + "github.com/GoogleCloudPlatform/alloydb-auth-proxy/alloydb" + "github.com/GoogleCloudPlatform/alloydb-auth-proxy/internal/healthcheck" + "github.com/GoogleCloudPlatform/alloydb-auth-proxy/internal/log" + "github.com/GoogleCloudPlatform/alloydb-auth-proxy/internal/proxy" +) + +var ( + logger = log.NewStdLogger(os.Stdout, os.Stdout) + proxyHost = "127.0.0.1" + proxyPort = 9000 +) + +func proxyAddr() string { + return fmt.Sprintf("%s:%d", proxyHost, proxyPort) +} + +func dialTCP(t *testing.T, addr string) net.Conn { + for i := 0; i < 10; i++ { + conn, err := net.Dial("tcp", addr) + if err == nil { + return conn + } + time.Sleep(100 * time.Millisecond) + } + t.Fatalf("failed to dial %v", addr) + return nil +} + +type fakeDialer struct{} + +func (*fakeDialer) Dial(ctx context.Context, inst string, opts ...alloydbconn.DialOption) (net.Conn, error) { + conn, _ := net.Pipe() + return conn, nil +} + +func (*fakeDialer) EngineVersion(ctx context.Context, inst string) (string, error) { + return "POSTGRES_14", nil +} + +func (*fakeDialer) Close() error { + return nil +} + +type errorDialer struct { + fakeDialer +} + +func (*errorDialer) Dial(ctx context.Context, inst string, opts ...alloydbconn.DialOption) (net.Conn, error) { + return nil, errors.New("errorDialer always errors") +} + +func newProxyWithParams(t *testing.T, maxConns uint64, dialer alloydb.Dialer) *proxy.Client { + c := &proxy.Config{ + Addr: proxyHost, + Port: proxyPort, + Instances: []proxy.InstanceConnConfig{ + {Name: "proj:region:pg"}, + }, + MaxConnections: maxConns, + } + p, err := proxy.NewClient(context.Background(), dialer, logger, c) + if err != nil { + t.Fatalf("proxy.NewClient: %v", err) + } + return p +} + +func newTestProxyWithMaxConns(t *testing.T, maxConns uint64) *proxy.Client { + return newProxyWithParams(t, maxConns, &fakeDialer{}) +} + +func newTestProxyWithDialer(t *testing.T, d alloydb.Dialer) *proxy.Client { + return newProxyWithParams(t, 0, d) +} + +func newTestProxy(t *testing.T) *proxy.Client { + return newProxyWithParams(t, 0, &fakeDialer{}) +} + +func TestHandleStartupWhenNotNotified(t *testing.T) { + p := newTestProxy(t) + defer func() { + if err := p.Close(); err != nil { + t.Logf("failed to close proxy client: %v", err) + } + }() + check := healthcheck.NewCheck(p, logger) + + rec := httptest.NewRecorder() + check.HandleStartup(rec, &http.Request{}) + + // Startup is not complete because the Check has not been notified of the + // proxy's startup. 
+ resp := rec.Result() + if got, want := resp.StatusCode, http.StatusServiceUnavailable; got != want { + t.Fatalf("want = %v, got = %v", want, got) + } +} + +func TestHandleStartupWhenNotified(t *testing.T) { + p := newTestProxy(t) + defer func() { + if err := p.Close(); err != nil { + t.Logf("failed to close proxy client: %v", err) + } + }() + check := healthcheck.NewCheck(p, logger) + + check.NotifyStarted() + + rec := httptest.NewRecorder() + check.HandleStartup(rec, &http.Request{}) + + resp := rec.Result() + if got, want := resp.StatusCode, http.StatusOK; got != want { + t.Fatalf("want = %v, got = %v", want, got) + } +} + +func TestHandleReadinessWhenNotNotified(t *testing.T) { + p := newTestProxy(t) + defer func() { + if err := p.Close(); err != nil { + t.Logf("failed to close proxy client: %v", err) + } + }() + check := healthcheck.NewCheck(p, logger) + + rec := httptest.NewRecorder() + check.HandleReadiness(rec, &http.Request{}) + + resp := rec.Result() + if got, want := resp.StatusCode, http.StatusServiceUnavailable; got != want { + t.Fatalf("want = %v, got = %v", want, got) + } +} + +func TestHandleReadinessForMaxConns(t *testing.T) { + p := newTestProxyWithMaxConns(t, 1) + defer func() { + if err := p.Close(); err != nil { + t.Logf("failed to close proxy client: %v", err) + } + }() + started := make(chan struct{}) + check := healthcheck.NewCheck(p, logger) + go p.Serve(context.Background(), func() { + check.NotifyStarted() + close(started) + }) + select { + case <-started: + // proxy has started + case <-time.After(10 * time.Second): + t.Fatal("proxy has not started after 10 seconds") + } + + conn := dialTCP(t, proxyAddr()) + defer conn.Close() + + // The proxy calls the dialer in a separate goroutine. So wait for that + // goroutine to run before asserting on the readiness response. 
+	waitForConnect := func(t *testing.T, wantCode int) *http.Response {
+		for i := 0; i < 10; i++ {
+			rec := httptest.NewRecorder()
+			check.HandleReadiness(rec, &http.Request{})
+			resp := rec.Result()
+			if resp.StatusCode == wantCode {
+				return resp
+			}
+			time.Sleep(time.Second)
+		}
+		t.Fatalf("timed out waiting for status code = %v", wantCode)
+		return nil
+	}
+	resp := waitForConnect(t, http.StatusServiceUnavailable)
+
+	body, err := ioutil.ReadAll(resp.Body)
+	if err != nil {
+		t.Fatalf("failed to read response body: %v", err)
+	}
+	if !strings.Contains(string(body), "max connections") {
+		t.Fatalf("want max connections error, got = %v", string(body))
+	}
+}
+
+func TestHandleReadinessWithConnectionProblems(t *testing.T) {
+	p := newTestProxyWithDialer(t, &errorDialer{}) // error dialer will error on dial
+	defer func() {
+		if err := p.Close(); err != nil {
+			t.Logf("failed to close proxy client: %v", err)
+		}
+	}()
+	check := healthcheck.NewCheck(p, logger)
+	check.NotifyStarted()
+
+	rec := httptest.NewRecorder()
+	check.HandleReadiness(rec, &http.Request{})
+
+	resp := rec.Result()
+	if got, want := resp.StatusCode, http.StatusServiceUnavailable; got != want {
+		t.Fatalf("want = %v, got = %v", want, got)
+	}
+
+	body, err := ioutil.ReadAll(resp.Body)
+	if err != nil {
+		t.Fatalf("failed to read response body: %v", err)
+	}
+	if want := "errorDialer"; !strings.Contains(string(body), want) {
+		t.Fatalf("want substring with = %q, got = %v", want, string(body))
+	}
+}
diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go
index 79bb74b5..8bc17c94 100644
--- a/internal/proxy/proxy.go
+++ b/internal/proxy/proxy.go
@@ -196,7 +196,6 @@ func NewClient(ctx context.Context, d alloydb.Dialer, l alloydb.Logger, conf *Co
 	// Check if the caller has configured a dialer.
 	// Otherwise, initialize a new one.
 	if d == nil {
-		var err error
 		dialerOpts, err := conf.DialerOptions(l)
 		if err != nil {
 			return nil, fmt.Errorf("error initializing dialer: %v", err)
@@ -235,9 +234,54 @@ func NewClient(ctx context.Context, d alloydb.Dialer, l alloydb.Logger, conf *Co
 	return c, nil
 }
 
+// CheckConnections dials each registered instance and reports any errors that
+// may have occurred.
+func (c *Client) CheckConnections(ctx context.Context) error {
+	var (
+		wg    sync.WaitGroup
+		errCh = make(chan error, len(c.mnts))
+	)
+	for _, m := range c.mnts {
+		wg.Add(1)
+		go func(inst string) {
+			defer wg.Done()
+			conn, err := c.dialer.Dial(ctx, inst)
+			if err != nil {
+				errCh <- err
+				return
+			}
+			if cErr := conn.Close(); cErr != nil {
+				errCh <- fmt.Errorf("%v: %v", inst, cErr)
+			}
+		}(m.inst)
+	}
+	wg.Wait()
+
+	// Drain whatever errors the dial goroutines reported; the channel is
+	// buffered to hold at most one error per instance.
+	var mErr MultiErr
+	for i := 0; i < len(c.mnts); i++ {
+		select {
+		case err := <-errCh:
+			mErr = append(mErr, err)
+		default:
+			continue
+		}
+	}
+	if len(mErr) > 0 {
+		return mErr
+	}
+	return nil
+}
+
+// ConnCount returns the number of open connections and the maximum allowed
+// connections. The maximum is 0 when no limit has been configured.
+func (c *Client) ConnCount() (uint64, uint64) {
+	return atomic.LoadUint64(&c.connCount), c.maxConns
+}
+
 // Serve starts proxying connections for all configured instances using the
 // associated socket.
-func (c *Client) Serve(ctx context.Context) error {
+func (c *Client) Serve(ctx context.Context, notify func()) error {
 	ctx, cancel := context.WithCancel(ctx)
 	defer cancel()
 	exitCh := make(chan error)
@@ -258,6 +302,7 @@ func (c *Client) Serve(ctx context.Context) error {
 			}
 		}(m)
 	}
+	notify()
 	return <-exitCh
 }
diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go
index 2f11b377..f53d1acf 100644
--- a/internal/proxy/proxy_test.go
+++ b/internal/proxy/proxy_test.go
@@ -67,6 +67,10 @@ type errorDialer struct {
 	fakeDialer
 }
 
+func (*errorDialer) Dial(ctx context.Context, inst string, opts ...alloydbconn.DialOption) (net.Conn, error) {
+	return nil, errors.New("errorDialer returns error on Dial")
+}
+
 func (*errorDialer) Close() error {
 	return errors.New("errorDialer returns error on Close")
 }
@@ -217,7 +221,11 @@ func TestClientInitialization(t *testing.T) {
 			if err != nil {
 				t.Fatalf("want error = nil, got = %v", err)
 			}
-			defer c.Close()
+			defer func() {
+				if err := c.Close(); err != nil {
+					t.Logf("failed to close client: %v", err)
+				}
+			}()
 			for _, addr := range tc.wantTCPAddrs {
 				conn := tryTCPDial(t, addr)
 				err = conn.Close()
@@ -256,7 +264,7 @@ func TestClientLimitsMaxConnections(t *testing.T) {
 		t.Fatalf("proxy.NewClient error: %v", err)
 	}
 	defer c.Close()
-	go c.Serve(context.Background())
+	go c.Serve(context.Background(), func() {})
 
 	conn1, err1 := net.Dial("tcp", "127.0.0.1:5000")
 	if err1 != nil {
@@ -323,7 +331,7 @@ func TestClientCloseWaitsForActiveConnections(t *testing.T) {
 	if err != nil {
 		t.Fatalf("proxy.NewClient error: %v", err)
 	}
-	go c.Serve(context.Background())
+	go c.Serve(context.Background(), func() {})
 
 	conn := tryTCPDial(t, "127.0.0.1:5000")
 	_ = conn.Close()
@@ -338,7 +346,7 @@ func TestClientCloseWaitsForActiveConnections(t *testing.T) {
 	if err != nil {
 		t.Fatalf("proxy.NewClient error: %v", err)
 	}
-	go c.Serve(context.Background())
+	go c.Serve(context.Background(), func() {})
 
 	var open []net.Conn
 	for i := 0; i < 5; i++ {
@@ -369,7 +377,7 @@ func TestClientClosesCleanly(t *testing.T) {
 	if err != nil {
 		t.Fatalf("proxy.NewClient error want = nil, got = %v", err)
 	}
-	go c.Serve(context.Background())
+	go c.Serve(context.Background(), func() {})
 
 	conn := tryTCPDial(t, "127.0.0.1:5000")
 	_ = conn.Close()
@@ -392,7 +400,7 @@ func TestClosesWithError(t *testing.T) {
 	if err != nil {
 		t.Fatalf("proxy.NewClient error want = nil, got = %v", err)
 	}
-	go c.Serve(context.Background())
+	go c.Serve(context.Background(), func() {})
 
 	conn := tryTCPDial(t, "127.0.0.1:5000")
 	defer conn.Close()
@@ -496,7 +504,7 @@ func TestClientInitializationWithCustomHost(t *testing.T) {
 	}
 	defer c.Close()
 
-	go c.Serve(context.Background())
+	go c.Serve(context.Background(), func() {})
 
 	conn := tryTCPDial(t, "localhost:7000")
 	defer conn.Close()
@@ -520,3 +528,124 @@ func TestClientInitializationWithCustomHost(t *testing.T) {
 
 	spyWasCalled(t)
 }
+
+func TestClientNotifiesCallerOnServe(t *testing.T) {
+	ctx := context.Background()
+	in := &proxy.Config{
+		Instances: []proxy.InstanceConnConfig{
+			{Name: "proj:region:pg"},
+		},
+	}
+	logger := log.NewStdLogger(os.Stdout, os.Stdout)
+	c, err := proxy.NewClient(ctx, &fakeDialer{}, logger, in)
+	if err != nil {
+		t.Fatalf("want error = nil, got = %v", err)
+	}
+	done := make(chan struct{})
+	notify := func() { close(done) }
+
+	go c.Serve(ctx, notify)
+
+	verifyNotification := func(t *testing.T, ch <-chan struct{}) {
+		for i := 0; i < 10; i++ {
+			select {
+			case <-ch:
+				return
+			default:
+				time.Sleep(100 * time.Millisecond)
+			}
+		}
+		t.Fatal("channel should have been closed but was not")
+	}
+	verifyNotification(t, done)
+}
+
+func TestClientConnCount(t *testing.T) {
+	logger := log.NewStdLogger(os.Stdout, os.Stdout)
+	in := &proxy.Config{
+		Addr: "127.0.0.1",
+		Port: 5000,
+		Instances: []proxy.InstanceConnConfig{
+			{Name: "proj:region:pg"},
+		},
+		MaxConnections: 10,
+	}
+
+	c, err := proxy.NewClient(context.Background(), &fakeDialer{}, logger, in)
+	if err != nil {
+		t.Fatalf("proxy.NewClient error: %v", err)
+	}
+	defer c.Close()
+	go c.Serve(context.Background(), func() {})
+
+	gotOpen, gotMax := c.ConnCount()
+	if gotOpen != 0 {
+		t.Fatalf("want 0 open connections, got = %v", gotOpen)
+	}
+	if gotMax != 10 {
+		t.Fatalf("want 10 max connections, got = %v", gotMax)
+	}
+
+	conn := tryTCPDial(t, "127.0.0.1:5000")
+	defer conn.Close()
+
+	verifyOpen := func(t *testing.T, want uint64) {
+		var got uint64
+		for i := 0; i < 10; i++ {
+			got, _ = c.ConnCount()
+			if got == want {
+				return
+			}
+			time.Sleep(100 * time.Millisecond)
+		}
+		t.Fatalf("open connections, want = %v, got = %v", want, got)
+	}
+	verifyOpen(t, 1)
+}
+
+func TestCheckConnections(t *testing.T) {
+	logger := log.NewStdLogger(os.Stdout, os.Stdout)
+	in := &proxy.Config{
+		Addr: "127.0.0.1",
+		Port: 5000,
+		Instances: []proxy.InstanceConnConfig{
+			{Name: "proj:region:pg"},
+		},
+	}
+	d := &fakeDialer{}
+	c, err := proxy.NewClient(context.Background(), d, logger, in)
+	if err != nil {
+		t.Fatalf("proxy.NewClient error: %v", err)
+	}
+	defer c.Close()
+	go c.Serve(context.Background(), func() {})
+
+	if err = c.CheckConnections(context.Background()); err != nil {
+		t.Fatalf("CheckConnections failed: %v", err)
+	}
+
+	if want, got := 1, d.dialAttempts(); want != got {
+		t.Fatalf("dial attempts: want = %v, got = %v", want, got)
+	}
+
+	in = &proxy.Config{
+		Addr: "127.0.0.1",
+		Port: 6000,
+		Instances: []proxy.InstanceConnConfig{
+			{Name: "proj:region:pg1"},
+			{Name: "proj:region:pg2"},
+		},
+	}
+	ed := &errorDialer{}
+	c, err = proxy.NewClient(context.Background(), ed, logger, in)
+	if err != nil {
+		t.Fatalf("proxy.NewClient error: %v", err)
+	}
+	defer c.Close()
+	go c.Serve(context.Background(), func() {})
+
+	err = c.CheckConnections(context.Background())
+	if err == nil {
+		t.Fatal("CheckConnections should have failed, but did not")
+	}
+}
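
Note for reviewers: below is a minimal sketch of one way to exercise the new
endpoints against a locally running proxy. The endpoint paths and the 9090
default for --http-port come from this patch; everything else (the probe
helper, retry counts, output format) is illustrative only and not part of the
change. It assumes the proxy was started with --health-check.

package main

import (
	"fmt"
	"net/http"
	"time"
)

// probe polls one of the proxy's health endpoints until it returns
// 200 OK or the attempts are exhausted.
func probe(path string) error {
	url := "http://localhost:9090" + path // default --http-port is 9090
	for i := 0; i < 10; i++ {
		resp, err := http.Get(url)
		if err == nil {
			code := resp.StatusCode
			resp.Body.Close()
			if code == http.StatusOK {
				return nil
			}
		}
		time.Sleep(100 * time.Millisecond)
	}
	return fmt.Errorf("%v never reported healthy", path)
}

func main() {
	// /startup and /liveness report 200 once the proxy is serving;
	// /readiness additionally dials each configured instance.
	for _, p := range []string{"/startup", "/liveness", "/readiness"} {
		if err := probe(p); err != nil {
			fmt.Println("FAIL:", err)
			continue
		}
		fmt.Println("OK:", p)
	}
}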