diff --git a/cmd/root.go b/cmd/root.go index ac01a021..6383ff21 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -384,7 +384,11 @@ func runSignalWrapper(cmd *Command) error { case p = <-startCh: } cmd.Println("The proxy has started successfully and is ready for new connections!") - defer p.Close() + defer func() { + if cErr := p.Close(); cErr != nil { + cmd.PrintErrf("error during shutdown: %v\n", cErr) + } + }() go func() { shutdownCh <- p.Serve(ctx) diff --git a/cmd/root_test.go b/cmd/root_test.go index 19d30c59..f869a7f3 100644 --- a/cmd/root_test.go +++ b/cmd/root_test.go @@ -362,7 +362,7 @@ func (*spyDialer) Close() error { } func TestCommandWithCustomDialer(t *testing.T) { - want := "/projects/my-project/locations/my-region/clusters/my-cluster/instances/my-instance" + want := "projects/my-project/locations/my-region/clusters/my-cluster/instances/my-instance" s := &spyDialer{} c := NewCommand(WithDialer(s)) // Keep the test output quiet diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 1952b9b8..67661756 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -176,76 +176,21 @@ func NewClient(ctx context.Context, cmd *cobra.Command, conf *Config) (*Client, } } - pc := newPortConfig(conf.Port) var mnts []*socketMount + pc := newPortConfig(conf.Port) for _, inst := range conf.Instances { - var ( - // network is one of "tcp" or "unix" - network string - // address is either a TCP host port, or a Unix socket - address string - ) - // IF - // a global Unix socket directory is NOT set AND - // an instance-level Unix socket is NOT set - // (e.g., I didn't set a Unix socket globally or for this instance) - // OR - // an instance-level TCP address or port IS set - // (e.g., I'm overriding any global settings to use TCP for this - // instance) - // use a TCP listener. - // Otherwise, use a Unix socket. - if (conf.UnixSocket == "" && inst.UnixSocket == "") || - (inst.Addr != "" || inst.Port != 0) { - network = "tcp" - - a := conf.Addr - if inst.Addr != "" { - a = inst.Addr - } - - var np int - switch { - case inst.Port != 0: - np = inst.Port - case conf.Port != 0: - np = pc.nextPort() - default: - np = pc.nextPort() - } - - address = net.JoinHostPort(a, fmt.Sprint(np)) - } else { - network = "unix" - - dir := conf.UnixSocket - if dir == "" { - dir = inst.UnixSocket - } - ud, err := UnixSocketDir(dir, inst.Name) - if err != nil { - return nil, err - } - // Create the parent directory that will hold the socket. - if _, err := os.Stat(ud); err != nil { - if err = os.Mkdir(ud, 0777); err != nil { - return nil, err - } - } - // use the Postgres-specific socket name - address = filepath.Join(ud, ".s.PGSQL.5432") - } - - m := &socketMount{inst: inst.Name} - addr, err := m.listen(ctx, network, address) + m, err := newSocketMount(ctx, conf, pc, inst) if err != nil { for _, m := range mnts { - m.close() + mErr := m.Close() + if mErr != nil { + cmd.PrintErrf("failed to close mount: %v", mErr) + } } return nil, fmt.Errorf("[%v] Unable to mount socket: %v", inst.Name, err) } - cmd.Printf("[%s] Listening on %s\n", inst.Name, addr.String()) + cmd.Printf("[%s] Listening on %s\n", inst.Name, m.Addr()) mnts = append(mnts, m) } @@ -277,22 +222,45 @@ func (c *Client) Serve(ctx context.Context) error { return <-exitCh } -// Close triggers the proxyClient to shutdown. -func (c *Client) Close() { - defer c.dialer.Close() +// MultiErr is a group of errors wrapped into one. +type MultiErr []error + +// Error returns a single string representing one or more errors. +func (m MultiErr) Error() string { + l := len(m) + if l == 1 { + return m[0].Error() + } + var errs []string + for _, e := range m { + errs = append(errs, e.Error()) + } + return strings.Join(errs, ", ") +} + +func (c *Client) Close() error { + var mErr MultiErr for _, m := range c.mnts { - m.close() + err := m.Close() + if err != nil { + mErr = append(mErr, err) + } + } + cErr := c.dialer.Close() + if cErr != nil { + mErr = append(mErr, cErr) + } + if len(mErr) > 0 { + return mErr } + return nil } // serveSocketMount persistently listens to the socketMounts listener and proxies connections to a // given AlloyDB instance. func (c *Client) serveSocketMount(ctx context.Context, s *socketMount) error { - if s.listener == nil { - return fmt.Errorf("[%s] mount doesn't have a listener set", s.inst) - } for { - cConn, err := s.listener.Accept() + cConn, err := s.Accept() if err != nil { if nerr, ok := err.(net.Error); ok && nerr.Temporary() { c.cmd.PrintErrf("[%s] Error accepting connection: %v\n", s.inst, err) @@ -327,22 +295,82 @@ type socketMount struct { listener net.Listener } -// listen causes a socketMount to create a Listener at the specified network address. -func (s *socketMount) listen(ctx context.Context, network string, address string) (net.Addr, error) { +func newSocketMount(ctx context.Context, conf *Config, pc *portConfig, inst InstanceConnConfig) (*socketMount, error) { + var ( + // network is one of "tcp" or "unix" + network string + // address is either a TCP host port, or a Unix socket + address string + ) + // IF + // a global Unix socket directory is NOT set AND + // an instance-level Unix socket is NOT set + // (e.g., I didn't set a Unix socket globally or for this instance) + // OR + // an instance-level TCP address or port IS set + // (e.g., I'm overriding any global settings to use TCP for this + // instance) + // use a TCP listener. + // Otherwise, use a Unix socket. + if (conf.UnixSocket == "" && inst.UnixSocket == "") || + (inst.Addr != "" || inst.Port != 0) { + network = "tcp" + + a := conf.Addr + if inst.Addr != "" { + a = inst.Addr + } + + var np int + switch { + case inst.Port != 0: + np = inst.Port + default: + np = pc.nextPort() + } + + address = net.JoinHostPort(a, fmt.Sprint(np)) + } else { + network = "unix" + + dir := conf.UnixSocket + if dir == "" { + dir = inst.UnixSocket + } + ud, err := UnixSocketDir(dir, inst.Name) + if err != nil { + return nil, err + } + // Create the parent directory that will hold the socket. + if _, err := os.Stat(ud); err != nil { + if err = os.Mkdir(ud, 0777); err != nil { + return nil, err + } + } + // use the Postgres-specific socket name + address = filepath.Join(ud, ".s.PGSQL.5432") + } + lc := net.ListenConfig{KeepAlive: 30 * time.Second} - l, err := lc.Listen(ctx, network, address) + ln, err := lc.Listen(ctx, network, address) if err != nil { return nil, err } - s.listener = l - return s.listener.Addr(), nil + m := &socketMount{inst: inst.Name, listener: ln} + return m, nil +} + +func (s *socketMount) Addr() net.Addr { + return s.listener.Addr() +} + +func (s *socketMount) Accept() (net.Conn, error) { + return s.listener.Accept() } // close stops the mount from listening for any more connections -func (s *socketMount) close() error { - err := s.listener.Close() - s.listener = nil - return err +func (s *socketMount) Close() error { + return s.listener.Close() } // proxyConn sets up a bidirectional copy between two open connections diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index 60e046da..f924fe55 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -16,11 +16,13 @@ package proxy_test import ( "context" + "errors" "io/ioutil" "net" "os" "path/filepath" "testing" + "time" "cloud.google.com/go/alloydbconn" "github.com/GoogleCloudPlatform/alloydb-auth-proxy/internal/proxy" @@ -37,13 +39,22 @@ type testCase struct { } func (fakeDialer) Dial(ctx context.Context, inst string, opts ...alloydbconn.DialOption) (net.Conn, error) { - return nil, nil + conn, _ := net.Pipe() + return conn, nil } func (fakeDialer) Close() error { return nil } +type errorDialer struct { + fakeDialer +} + +func (errorDialer) Close() error { + return errors.New("errorDialer returns error on Close") +} + func createTempDir(t *testing.T) (string, func()) { testDir, err := ioutil.TempDir("", "*") if err != nil { @@ -216,6 +227,81 @@ func TestClientInitialization(t *testing.T) { } } +func TestClientClosesCleanly(t *testing.T) { + in := &proxy.Config{ + Addr: "127.0.0.1", + Port: 5000, + Instances: []proxy.InstanceConnConfig{ + {Name: "proj:reg:inst"}, + }, + Dialer: fakeDialer{}, + } + c, err := proxy.NewClient(context.Background(), &cobra.Command{}, in) + if err != nil { + t.Fatalf("proxy.NewClient error want = nil, got = %v", err) + } + go c.Serve(context.Background()) + time.Sleep(time.Second) // allow the socket to start listening + + conn, dErr := net.Dial("tcp", "127.0.0.1:5000") + if dErr != nil { + t.Fatalf("net.Dial error = %v", dErr) + } + _ = conn.Close() + + if err := c.Close(); err != nil { + t.Fatalf("c.Close() error = %v", err) + } +} + +func TestClosesWithError(t *testing.T) { + in := &proxy.Config{ + Addr: "127.0.0.1", + Port: 5000, + Instances: []proxy.InstanceConnConfig{ + {Name: "proj:reg:inst"}, + }, + Dialer: errorDialer{}, + } + c, err := proxy.NewClient(context.Background(), &cobra.Command{}, in) + if err != nil { + t.Fatalf("proxy.NewClient error want = nil, got = %v", err) + } + go c.Serve(context.Background()) + time.Sleep(time.Second) // allow the socket to start listening + + if err = c.Close(); err == nil { + t.Fatal("c.Close() should error, got nil") + } +} + +func TestMultiErrorFormatting(t *testing.T) { + tcs := []struct { + desc string + in proxy.MultiErr + want string + }{ + { + desc: "with one error", + in: proxy.MultiErr{errors.New("woops")}, + want: "woops", + }, + { + desc: "with many errors", + in: proxy.MultiErr{errors.New("woops"), errors.New("another error")}, + want: "woops, another error", + }, + } + + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + if got := tc.in.Error(); got != tc.want { + t.Errorf("want = %v, got = %v", tc.want, got) + } + }) + } +} + func TestClientInitializationWorksRepeatedly(t *testing.T) { // The client creates a Unix socket on initial startup and does not remove // it on shutdown. This test ensures the existing socket does not cause