Skip to content

Commit

Permalink
Fix closing of decorated watcher channel on timeout
Browse files Browse the repository at this point in the history
  • Loading branch information
liggitt committed Jun 4, 2021
1 parent bd3ce3a commit de95e9d
Show file tree
Hide file tree
Showing 5 changed files with 189 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ type decoratedWatcher struct {
resultCh chan watch.Event
}

func newDecoratedWatcher(w watch.Interface, decorator ObjectFunc) *decoratedWatcher {
ctx, cancel := context.WithCancel(context.Background())
func newDecoratedWatcher(ctx context.Context, w watch.Interface, decorator ObjectFunc) *decoratedWatcher {
ctx, cancel := context.WithCancel(ctx)
d := &decoratedWatcher{
w: w,
decorator: decorator,
Expand All @@ -43,14 +43,18 @@ func newDecoratedWatcher(w watch.Interface, decorator ObjectFunc) *decoratedWatc
return d
}

// run decorates watch events from the underlying watcher until its result channel
// is closed or the passed in context is done.
// When run() returns, decoratedWatcher#resultCh is closed.
func (d *decoratedWatcher) run(ctx context.Context) {
var recv, send watch.Event
var ok bool
defer close(d.resultCh)
for {
select {
case recv, ok = <-d.w.ResultChan():
// The underlying channel may be closed after timeout.
if !ok {
// The underlying channel was closed, cancel our context
d.cancel()
return
}
Expand All @@ -67,20 +71,24 @@ func (d *decoratedWatcher) run(ctx context.Context) {
}
select {
case d.resultCh <- send:
if send.Type == watch.Error {
d.cancel()
}
// propagated event successfully
case <-ctx.Done():
// context timed out or was cancelled, stop the underlying watcher
d.w.Stop()
return
}
case <-ctx.Done():
// context timed out or was cancelled, stop the underlying watcher
d.w.Stop()
close(d.resultCh)
return
}
}
}

func (d *decoratedWatcher) Stop() {
// stop the underlying watcher
d.w.Stop()
// cancel our context
d.cancel()
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package registry

import (
"context"
"fmt"
"testing"
"time"
Expand All @@ -31,26 +32,80 @@ import (
func TestDecoratedWatcher(t *testing.T) {
w := watch.NewFake()
decorator := func(obj runtime.Object) error {
pod := obj.(*example.Pod)
pod.Annotations = map[string]string{"decorated": "true"}
if pod, ok := obj.(*example.Pod); ok {
pod.Annotations = map[string]string{"decorated": "true"}
}
return nil
}
dw := newDecoratedWatcher(w, decorator)
ctx, cancel := context.WithCancel(context.Background())
dw := newDecoratedWatcher(ctx, w, decorator)
defer dw.Stop()

go w.Add(&example.Pod{ObjectMeta: metav1.ObjectMeta{Name: "foo"}})
go func() {
w.Error(&metav1.Status{Status: "Failure"})
w.Add(&example.Pod{ObjectMeta: metav1.ObjectMeta{Name: "foo"}})
w.Error(&metav1.Status{Status: "Failure"})
w.Modify(&example.Pod{ObjectMeta: metav1.ObjectMeta{Name: "foo"}})
w.Error(&metav1.Status{Status: "Failure"})
w.Delete(&example.Pod{ObjectMeta: metav1.ObjectMeta{Name: "foo"}})
}()

expectErrorEvent(t, dw) // expect error is plumbed and doesn't force close the watcher
expectPodEvent(t, dw, watch.Added)
expectErrorEvent(t, dw) // expect error is plumbed and doesn't force close the watcher
expectPodEvent(t, dw, watch.Modified)
expectErrorEvent(t, dw) // expect error is plumbed and doesn't force close the watcher
expectPodEvent(t, dw, watch.Deleted)

// cancel the passed-in context to simulate request timeout
cancel()

// expect the decorated channel to be closed
select {
case e, ok := <-dw.ResultChan():
if ok {
t.Errorf("expected result chan closed, got %#v", e)
}
case <-time.After(wait.ForeverTestTimeout):
t.Errorf("timeout after %v", wait.ForeverTestTimeout)
}

// expect the underlying watcher to have been stopped as a result of the context cancellation
if !w.IsStopped() {
t.Errorf("expected underlying watcher to be stopped")
}
}

func expectPodEvent(t *testing.T, dw *decoratedWatcher, watchType watch.EventType) {
select {
case e := <-dw.ResultChan():
pod, ok := e.Object.(*example.Pod)
if !ok {
t.Errorf("Should received object of type *api.Pod, get type (%T)", e.Object)
return
t.Fatalf("Should received object of type *api.Pod, get type (%T)", e.Object)
}
if pod.Annotations["decorated"] != "true" {
t.Errorf("pod.Annotations[\"decorated\"], want=%s, get=%s", "true", pod.Labels["decorated"])
t.Fatalf("pod.Annotations[\"decorated\"], want=%s, get=%s", "true", pod.Labels["decorated"])
}
if e.Type != watchType {
t.Fatalf("expected type %s, got %s", watchType, e.Type)
}
case <-time.After(wait.ForeverTestTimeout):
t.Errorf("timeout after %v", wait.ForeverTestTimeout)
t.Fatalf("timeout after %v", wait.ForeverTestTimeout)
}
}

func expectErrorEvent(t *testing.T, dw *decoratedWatcher) {
select {
case e := <-dw.ResultChan():
_, ok := e.Object.(*metav1.Status)
if !ok {
t.Fatalf("Should received object of type *metav1.Status, get type (%T)", e.Object)
}
if e.Type != watch.Error {
t.Fatalf("expected type %s, got %s", watch.Error, e.Type)
}
case <-time.After(wait.ForeverTestTimeout):
t.Fatalf("timeout after %v", wait.ForeverTestTimeout)
}
}

Expand All @@ -60,7 +115,7 @@ func TestDecoratedWatcherError(t *testing.T) {
decorator := func(obj runtime.Object) error {
return expErr
}
dw := newDecoratedWatcher(w, decorator)
dw := newDecoratedWatcher(context.Background(), w, decorator)
defer dw.Stop()

go w.Add(&example.Pod{ObjectMeta: metav1.ObjectMeta{Name: "foo"}})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1136,7 +1136,7 @@ func (e *Store) WatchPredicate(ctx context.Context, p storage.SelectionPredicate
return nil, err
}
if e.Decorator != nil {
return newDecoratedWatcher(w, e.Decorator), nil
return newDecoratedWatcher(ctx, w, e.Decorator), nil
}
return w, nil
}
Expand All @@ -1149,7 +1149,7 @@ func (e *Store) WatchPredicate(ctx context.Context, p storage.SelectionPredicate
return nil, err
}
if e.Decorator != nil {
return newDecoratedWatcher(w, e.Decorator), nil
return newDecoratedWatcher(ctx, w, e.Decorator), nil
}
return w, nil
}
Expand Down
1 change: 1 addition & 0 deletions test/integration/apimachinery/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ go_test(
"//staging/src/k8s.io/client-go/tools/watch:go_default_library",
"//staging/src/k8s.io/kubectl/pkg/proxy:go_default_library",
"//test/integration/framework:go_default_library",
"//vendor/golang.org/x/net/websocket:go_default_library",
],
)

Expand Down
107 changes: 107 additions & 0 deletions test/integration/apimachinery/watch_timeout_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,19 @@ limitations under the License.
package apimachinery

import (
"bytes"
"context"
"io"
"log"
"net/http/httptest"
"net/http/httputil"
"net/url"
"strings"
"testing"
"time"

"golang.org/x/net/websocket"

corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
Expand All @@ -35,6 +41,107 @@ import (
"k8s.io/kubernetes/test/integration/framework"
)

func TestWebsocketWatchClientTimeout(t *testing.T) {
// server setup
masterConfig := framework.NewIntegrationTestMasterConfig()
instance, s, closeFn := framework.RunAMaster(masterConfig)
defer closeFn()

// object setup
service := &corev1.Service{
ObjectMeta: metav1.ObjectMeta{Name: "test"},
Spec: corev1.ServiceSpec{
Ports: []corev1.ServicePort{{Name: "http", Port: 80}},
},
}
configmap := &corev1.ConfigMap{
ObjectMeta: metav1.ObjectMeta{Name: "test"},
}
clientset, err := kubernetes.NewForConfig(instance.GenericAPIServer.LoopbackClientConfig)
if err != nil {
t.Fatal(err)
}
if _, err := clientset.CoreV1().Services("default").Create(context.TODO(), service, metav1.CreateOptions{}); err != nil {
t.Fatal(err)
}
if _, err := clientset.CoreV1().ConfigMaps("default").Create(context.TODO(), configmap, metav1.CreateOptions{}); err != nil {
t.Fatal(err)
}

testcases := []struct {
name string
path string
timeout time.Duration
expectResult string
}{
{
name: "configmaps",
path: "/api/v1/configmaps?watch=true&timeoutSeconds=5",
timeout: 10 * time.Second,
expectResult: `"name":"test"`,
},
{
name: "services",
path: "/api/v1/services?watch=true&timeoutSeconds=5",
timeout: 10 * time.Second,
expectResult: `"name":"test"`,
},
}

for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {

u, _ := url.Parse(s.URL)
apiURL := "ws://" + u.Host + tc.path
wsc, err := websocket.NewConfig(apiURL, apiURL)
if err != nil {
log.Fatal(err)
}

wsConn, err := websocket.DialConfig(wsc)
if err != nil {
t.Fatal(err)
}
defer wsConn.Close()

resultCh := make(chan string)
go func() {
defer close(resultCh)
buf := &bytes.Buffer{}
for {
var msg []byte
if err := websocket.Message.Receive(wsConn, &msg); err != nil {
if err == io.EOF {
resultCh <- buf.String()
return
}
if !t.Failed() {
// if we didn't already fail, treat this as an error
t.Errorf("Failed to read completely from websocket %v", err)
}
return
}
if len(msg) == 0 {
t.Logf("zero-length message")
continue
}
t.Logf("Read %v %v", len(msg), string(msg))
buf.Write(msg)
}
}()

select {
case resultString := <-resultCh:
if !strings.Contains(resultString, tc.expectResult) {
t.Fatalf("Unexpected result:\n%s", resultString)
}
case <-time.After(tc.timeout):
t.Fatalf("hit timeout before connection closed")
}
})
}
}

func TestWatchClientTimeout(t *testing.T) {
masterConfig := framework.NewIntegrationTestMasterConfig()
_, s, closeFn := framework.RunAMaster(masterConfig)
Expand Down

0 comments on commit de95e9d

Please sign in to comment.