From de95e9d10ff75451ca892ca568e1791b8f879f38 Mon Sep 17 00:00:00 2001 From: Jordan Liggitt Date: Wed, 2 Jun 2021 14:05:52 -0400 Subject: [PATCH] Fix closing of decorated watcher channel on timeout --- .../generic/registry/decorated_watcher.go | 22 ++-- .../registry/decorated_watcher_test.go | 73 ++++++++++-- .../pkg/registry/generic/registry/store.go | 4 +- test/integration/apimachinery/BUILD | 1 + .../apimachinery/watch_timeout_test.go | 107 ++++++++++++++++++ 5 files changed, 189 insertions(+), 18 deletions(-) diff --git a/staging/src/k8s.io/apiserver/pkg/registry/generic/registry/decorated_watcher.go b/staging/src/k8s.io/apiserver/pkg/registry/generic/registry/decorated_watcher.go index 005a376d404b0..fe677a43ece1c 100644 --- a/staging/src/k8s.io/apiserver/pkg/registry/generic/registry/decorated_watcher.go +++ b/staging/src/k8s.io/apiserver/pkg/registry/generic/registry/decorated_watcher.go @@ -31,8 +31,8 @@ type decoratedWatcher struct { resultCh chan watch.Event } -func newDecoratedWatcher(w watch.Interface, decorator ObjectFunc) *decoratedWatcher { - ctx, cancel := context.WithCancel(context.Background()) +func newDecoratedWatcher(ctx context.Context, w watch.Interface, decorator ObjectFunc) *decoratedWatcher { + ctx, cancel := context.WithCancel(ctx) d := &decoratedWatcher{ w: w, decorator: decorator, @@ -43,14 +43,18 @@ func newDecoratedWatcher(w watch.Interface, decorator ObjectFunc) *decoratedWatc return d } +// run decorates watch events from the underlying watcher until its result channel +// is closed or the passed in context is done. +// When run() returns, decoratedWatcher#resultCh is closed. func (d *decoratedWatcher) run(ctx context.Context) { var recv, send watch.Event var ok bool + defer close(d.resultCh) for { select { case recv, ok = <-d.w.ResultChan(): - // The underlying channel may be closed after timeout. if !ok { + // The underlying channel was closed, cancel our context d.cancel() return } @@ -67,20 +71,24 @@ func (d *decoratedWatcher) run(ctx context.Context) { } select { case d.resultCh <- send: - if send.Type == watch.Error { - d.cancel() - } + // propagated event successfully case <-ctx.Done(): + // context timed out or was cancelled, stop the underlying watcher + d.w.Stop() + return } case <-ctx.Done(): + // context timed out or was cancelled, stop the underlying watcher d.w.Stop() - close(d.resultCh) return } } } func (d *decoratedWatcher) Stop() { + // stop the underlying watcher + d.w.Stop() + // cancel our context d.cancel() } diff --git a/staging/src/k8s.io/apiserver/pkg/registry/generic/registry/decorated_watcher_test.go b/staging/src/k8s.io/apiserver/pkg/registry/generic/registry/decorated_watcher_test.go index 0afbd773f0781..198a0ca24685b 100644 --- a/staging/src/k8s.io/apiserver/pkg/registry/generic/registry/decorated_watcher_test.go +++ b/staging/src/k8s.io/apiserver/pkg/registry/generic/registry/decorated_watcher_test.go @@ -17,6 +17,7 @@ limitations under the License. package registry import ( + "context" "fmt" "testing" "time" @@ -31,26 +32,80 @@ import ( func TestDecoratedWatcher(t *testing.T) { w := watch.NewFake() decorator := func(obj runtime.Object) error { - pod := obj.(*example.Pod) - pod.Annotations = map[string]string{"decorated": "true"} + if pod, ok := obj.(*example.Pod); ok { + pod.Annotations = map[string]string{"decorated": "true"} + } return nil } - dw := newDecoratedWatcher(w, decorator) + ctx, cancel := context.WithCancel(context.Background()) + dw := newDecoratedWatcher(ctx, w, decorator) defer dw.Stop() - go w.Add(&example.Pod{ObjectMeta: metav1.ObjectMeta{Name: "foo"}}) + go func() { + w.Error(&metav1.Status{Status: "Failure"}) + w.Add(&example.Pod{ObjectMeta: metav1.ObjectMeta{Name: "foo"}}) + w.Error(&metav1.Status{Status: "Failure"}) + w.Modify(&example.Pod{ObjectMeta: metav1.ObjectMeta{Name: "foo"}}) + w.Error(&metav1.Status{Status: "Failure"}) + w.Delete(&example.Pod{ObjectMeta: metav1.ObjectMeta{Name: "foo"}}) + }() + + expectErrorEvent(t, dw) // expect error is plumbed and doesn't force close the watcher + expectPodEvent(t, dw, watch.Added) + expectErrorEvent(t, dw) // expect error is plumbed and doesn't force close the watcher + expectPodEvent(t, dw, watch.Modified) + expectErrorEvent(t, dw) // expect error is plumbed and doesn't force close the watcher + expectPodEvent(t, dw, watch.Deleted) + + // cancel the passed-in context to simulate request timeout + cancel() + + // expect the decorated channel to be closed + select { + case e, ok := <-dw.ResultChan(): + if ok { + t.Errorf("expected result chan closed, got %#v", e) + } + case <-time.After(wait.ForeverTestTimeout): + t.Errorf("timeout after %v", wait.ForeverTestTimeout) + } + + // expect the underlying watcher to have been stopped as a result of the context cancellation + if !w.IsStopped() { + t.Errorf("expected underlying watcher to be stopped") + } +} + +func expectPodEvent(t *testing.T, dw *decoratedWatcher, watchType watch.EventType) { select { case e := <-dw.ResultChan(): pod, ok := e.Object.(*example.Pod) if !ok { - t.Errorf("Should received object of type *api.Pod, get type (%T)", e.Object) - return + t.Fatalf("Should received object of type *api.Pod, get type (%T)", e.Object) } if pod.Annotations["decorated"] != "true" { - t.Errorf("pod.Annotations[\"decorated\"], want=%s, get=%s", "true", pod.Labels["decorated"]) + t.Fatalf("pod.Annotations[\"decorated\"], want=%s, get=%s", "true", pod.Labels["decorated"]) + } + if e.Type != watchType { + t.Fatalf("expected type %s, got %s", watchType, e.Type) } case <-time.After(wait.ForeverTestTimeout): - t.Errorf("timeout after %v", wait.ForeverTestTimeout) + t.Fatalf("timeout after %v", wait.ForeverTestTimeout) + } +} + +func expectErrorEvent(t *testing.T, dw *decoratedWatcher) { + select { + case e := <-dw.ResultChan(): + _, ok := e.Object.(*metav1.Status) + if !ok { + t.Fatalf("Should received object of type *metav1.Status, get type (%T)", e.Object) + } + if e.Type != watch.Error { + t.Fatalf("expected type %s, got %s", watch.Error, e.Type) + } + case <-time.After(wait.ForeverTestTimeout): + t.Fatalf("timeout after %v", wait.ForeverTestTimeout) } } @@ -60,7 +115,7 @@ func TestDecoratedWatcherError(t *testing.T) { decorator := func(obj runtime.Object) error { return expErr } - dw := newDecoratedWatcher(w, decorator) + dw := newDecoratedWatcher(context.Background(), w, decorator) defer dw.Stop() go w.Add(&example.Pod{ObjectMeta: metav1.ObjectMeta{Name: "foo"}}) diff --git a/staging/src/k8s.io/apiserver/pkg/registry/generic/registry/store.go b/staging/src/k8s.io/apiserver/pkg/registry/generic/registry/store.go index 7c2f4c390eba7..469a26858faa4 100644 --- a/staging/src/k8s.io/apiserver/pkg/registry/generic/registry/store.go +++ b/staging/src/k8s.io/apiserver/pkg/registry/generic/registry/store.go @@ -1136,7 +1136,7 @@ func (e *Store) WatchPredicate(ctx context.Context, p storage.SelectionPredicate return nil, err } if e.Decorator != nil { - return newDecoratedWatcher(w, e.Decorator), nil + return newDecoratedWatcher(ctx, w, e.Decorator), nil } return w, nil } @@ -1149,7 +1149,7 @@ func (e *Store) WatchPredicate(ctx context.Context, p storage.SelectionPredicate return nil, err } if e.Decorator != nil { - return newDecoratedWatcher(w, e.Decorator), nil + return newDecoratedWatcher(ctx, w, e.Decorator), nil } return w, nil } diff --git a/test/integration/apimachinery/BUILD b/test/integration/apimachinery/BUILD index d6dd4064e67da..49bd199ffc343 100644 --- a/test/integration/apimachinery/BUILD +++ b/test/integration/apimachinery/BUILD @@ -25,6 +25,7 @@ go_test( "//staging/src/k8s.io/client-go/tools/watch:go_default_library", "//staging/src/k8s.io/kubectl/pkg/proxy:go_default_library", "//test/integration/framework:go_default_library", + "//vendor/golang.org/x/net/websocket:go_default_library", ], ) diff --git a/test/integration/apimachinery/watch_timeout_test.go b/test/integration/apimachinery/watch_timeout_test.go index cabd37cc6dc38..98ac6fcdfbcfe 100644 --- a/test/integration/apimachinery/watch_timeout_test.go +++ b/test/integration/apimachinery/watch_timeout_test.go @@ -17,13 +17,19 @@ limitations under the License. package apimachinery import ( + "bytes" "context" + "io" + "log" "net/http/httptest" "net/http/httputil" "net/url" + "strings" "testing" "time" + "golang.org/x/net/websocket" + corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" @@ -35,6 +41,107 @@ import ( "k8s.io/kubernetes/test/integration/framework" ) +func TestWebsocketWatchClientTimeout(t *testing.T) { + // server setup + masterConfig := framework.NewIntegrationTestMasterConfig() + instance, s, closeFn := framework.RunAMaster(masterConfig) + defer closeFn() + + // object setup + service := &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{Name: "test"}, + Spec: corev1.ServiceSpec{ + Ports: []corev1.ServicePort{{Name: "http", Port: 80}}, + }, + } + configmap := &corev1.ConfigMap{ + ObjectMeta: metav1.ObjectMeta{Name: "test"}, + } + clientset, err := kubernetes.NewForConfig(instance.GenericAPIServer.LoopbackClientConfig) + if err != nil { + t.Fatal(err) + } + if _, err := clientset.CoreV1().Services("default").Create(context.TODO(), service, metav1.CreateOptions{}); err != nil { + t.Fatal(err) + } + if _, err := clientset.CoreV1().ConfigMaps("default").Create(context.TODO(), configmap, metav1.CreateOptions{}); err != nil { + t.Fatal(err) + } + + testcases := []struct { + name string + path string + timeout time.Duration + expectResult string + }{ + { + name: "configmaps", + path: "/api/v1/configmaps?watch=true&timeoutSeconds=5", + timeout: 10 * time.Second, + expectResult: `"name":"test"`, + }, + { + name: "services", + path: "/api/v1/services?watch=true&timeoutSeconds=5", + timeout: 10 * time.Second, + expectResult: `"name":"test"`, + }, + } + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + + u, _ := url.Parse(s.URL) + apiURL := "ws://" + u.Host + tc.path + wsc, err := websocket.NewConfig(apiURL, apiURL) + if err != nil { + log.Fatal(err) + } + + wsConn, err := websocket.DialConfig(wsc) + if err != nil { + t.Fatal(err) + } + defer wsConn.Close() + + resultCh := make(chan string) + go func() { + defer close(resultCh) + buf := &bytes.Buffer{} + for { + var msg []byte + if err := websocket.Message.Receive(wsConn, &msg); err != nil { + if err == io.EOF { + resultCh <- buf.String() + return + } + if !t.Failed() { + // if we didn't already fail, treat this as an error + t.Errorf("Failed to read completely from websocket %v", err) + } + return + } + if len(msg) == 0 { + t.Logf("zero-length message") + continue + } + t.Logf("Read %v %v", len(msg), string(msg)) + buf.Write(msg) + } + }() + + select { + case resultString := <-resultCh: + if !strings.Contains(resultString, tc.expectResult) { + t.Fatalf("Unexpected result:\n%s", resultString) + } + case <-time.After(tc.timeout): + t.Fatalf("hit timeout before connection closed") + } + }) + } +} + func TestWatchClientTimeout(t *testing.T) { masterConfig := framework.NewIntegrationTestMasterConfig() _, s, closeFn := framework.RunAMaster(masterConfig)