diff --git a/clientv3/integration/dial_test.go b/clientv3/integration/dial_test.go index 17af8e5cbe1e..368c7f306718 100644 --- a/clientv3/integration/dial_test.go +++ b/clientv3/integration/dial_test.go @@ -15,7 +15,11 @@ package integration import ( + "io" + "io/ioutil" "math/rand" + "os" + "path/filepath" "testing" "time" @@ -66,6 +70,89 @@ func TestDialTLSExpired(t *testing.T) { } } +// TestDialTLSExpiredReload ensures server reloads expired certs, +// rejecting client requests, and vice versa. +func TestDialTLSExpiredReload(t *testing.T) { + defer testutil.AfterTest(t) + + ts, err := copyTLSFiles(testTLSInfo) + if err != nil { + t.Fatal(err) + } + certsDir := filepath.Dir(ts.KeyFile) + defer os.RemoveAll(certsDir) + + tse, err := copyTLSFiles(testTLSInfoExpired) + if err != nil { + t.Fatal(err) + } + dir2 := filepath.Dir(tse.KeyFile) + defer os.RemoveAll(dir2) + + var tmpDir string + tmpDir, err = ioutil.TempDir(os.TempDir(), "fixtures") + if err != nil { + t.Fatal(err) + } + os.RemoveAll(tmpDir) + defer os.RemoveAll(tmpDir) + + // start with valid certs + clus := integration.NewClusterV3(t, &integration.ClusterConfig{Size: 1, PeerTLS: &ts, ClientTLS: &ts}) + defer clus.Terminate(t) + + // replace certs directory with expired ones + if err = os.Rename(certsDir, tmpDir); err != nil { + t.Fatal(err) + } + if err = os.Rename(dir2, certsDir); err != nil { + t.Fatal(err) + } + + // 'tmpDir' now has valid certs + // 'certsDir' now has expired certs; 'dir2' does not exist + + // now server expects 'tls: bad certificate' + // on incoming client requests + tls, err := ts.ClientConfig() + if err != nil { + t.Fatal(err) + } + _, err = clientv3.New(clientv3.Config{ + Endpoints: []string{clus.Members[0].GRPCAddr()}, + DialTimeout: 3 * time.Second, + TLS: tls, + }) + if err != grpc.ErrClientConnTimeout { + t.Fatalf("expected %v, got %v", grpc.ErrClientConnTimeout, err) + } + + // swap expired certs back with valid ones + if err = os.Rename(tmpDir, dir2); err != nil { + t.Fatal(err) + } + if err = os.Rename(certsDir, tmpDir); err != nil { + t.Fatal(err) + } + if err = os.Rename(dir2, certsDir); err != nil { + t.Fatal(err) + } + tls, err = ts.ClientConfig() + if err != nil { + t.Fatal(err) + } + var cl *clientv3.Client + cl, err = clientv3.New(clientv3.Config{ + Endpoints: []string{clus.Members[0].GRPCAddr()}, + DialTimeout: 3 * time.Second, + TLS: tls, + }) + defer cl.Close() + if err != nil { + t.Fatalf("expected no error, got %v", err) + } +} + // TestDialSetEndpoints ensures SetEndpoints can replace unavailable endpoints with available ones. func TestDialSetEndpointsBeforeFail(t *testing.T) { testDialSetEndpoints(t, true) @@ -173,3 +260,52 @@ func TestDialForeignEndpoint(t *testing.T) { t.Fatal(err) } } + +// copyTLSFiles clones certs files to temp directory. +func copyTLSFiles(ti transport.TLSInfo) (transport.TLSInfo, error) { + tmpdir, err := ioutil.TempDir(os.TempDir(), "fixtures") + if err != nil { + return transport.TLSInfo{}, err + } + ci := transport.TLSInfo{ + KeyFile: filepath.Join(tmpdir, "server-key.pem"), + CertFile: filepath.Join(tmpdir, "server.pem"), + TrustedCAFile: filepath.Join(tmpdir, "etcd-root-ca.pem"), + ClientCertAuth: ti.ClientCertAuth, + } + if err = copyFile(ti.KeyFile, ci.KeyFile); err != nil { + return transport.TLSInfo{}, err + } + if err = copyFile(ti.CertFile, ci.CertFile); err != nil { + return transport.TLSInfo{}, err + } + if err = copyFile(ti.TrustedCAFile, ci.TrustedCAFile); err != nil { + return transport.TLSInfo{}, err + } + return ci, nil +} + +func copyFile(src, dst string) error { + f, err := os.Open(src) + if err != nil { + return err + } + defer f.Close() + + w, err := os.Create(dst) + if err != nil { + return err + } + defer w.Close() + + if _, err = io.Copy(w, f); err != nil { + return err + } + if err = w.Sync(); err != nil { + return err + } + if _, err = w.Seek(0, io.SeekStart); err != nil { + return err + } + return nil +}