From ed10ca13f4d2f815299daecba4fd79e7d8ab26dd Mon Sep 17 00:00:00 2001 From: Kafuu Chino Date: Tue, 2 Aug 2022 18:55:41 +0800 Subject: [PATCH] *: avoid closing a watch with ID 0 incorrectly Signed-off-by: Kafuu Chino add test 1 1 1 --- clientv3/watch.go | 21 ++++++++----- etcdserver/api/v3rpc/watch.go | 16 ++++++++-- integration/v3_auth_test.go | 55 +++++++++++++++++++++++++++++++++++ integration/v3_watch_test.go | 5 ++-- mvcc/watcher.go | 7 ++--- proxy/grpcproxy/watch.go | 4 +-- 6 files changed, 89 insertions(+), 19 deletions(-) diff --git a/clientv3/watch.go b/clientv3/watch.go index c064858c6b3..c011177baf6 100644 --- a/clientv3/watch.go +++ b/clientv3/watch.go @@ -38,6 +38,13 @@ const ( EventTypePut = mvccpb.PUT closeSendErrTimeout = 250 * time.Millisecond + + // AutoWatchID is the watcher ID passed in WatchStream.Watch when no + // user-provided ID is available. If pass, an ID will automatically be assigned. + AutoWatchID = 0 + + // InvalidWatchID represents an invalid watch ID and prevents duplication with an existing watch. + InvalidWatchID = -1 ) type Event mvccpb.Event @@ -444,7 +451,7 @@ func (w *watcher) closeStream(wgs *watchGrpcStream) { func (w *watchGrpcStream) addSubstream(resp *pb.WatchResponse, ws *watcherStream) { // check watch ID for backward compatibility (<= v3.3) - if resp.WatchId == -1 || (resp.Canceled && resp.CancelReason != "") { + if resp.WatchId == InvalidWatchID || (resp.Canceled && resp.CancelReason != "") { w.closeErr = v3rpc.Error(errors.New(resp.CancelReason)) // failed; no channel close(ws.recvc) @@ -475,7 +482,7 @@ func (w *watchGrpcStream) closeSubstream(ws *watcherStream) { } else if ws.outc != nil { close(ws.outc) } - if ws.id != -1 { + if ws.id != InvalidWatchID { delete(w.substreams, ws.id) return } @@ -537,7 +544,7 @@ func (w *watchGrpcStream) run() { // TODO: pass custom watch ID? ws := &watcherStream{ initReq: *wreq, - id: -1, + id: InvalidWatchID, outc: outc, // unbuffered so resumes won't cause repeat events recvc: make(chan *WatchResponse), @@ -687,7 +694,7 @@ func (w *watchGrpcStream) run() { return case ws := <-w.closingc: - if ws.id != -1 { + if ws.id != InvalidWatchID { // client is closing an established watch; close it on the server proactively instead of waiting // to close when the next message arrives cancelSet[ws.id] = struct{}{} @@ -749,9 +756,9 @@ func (w *watchGrpcStream) dispatchEvent(pbresp *pb.WatchResponse) bool { cancelReason: pbresp.CancelReason, } - // watch IDs are zero indexed, so request notify watch responses are assigned a watch ID of -1 to + // watch IDs are zero indexed, so request notify watch responses are assigned a watch ID of InvalidWatchID to // indicate they should be broadcast. - if wr.IsProgressNotify() && pbresp.WatchId == -1 { + if wr.IsProgressNotify() && pbresp.WatchId == InvalidWatchID { return w.broadcastResponse(wr) } @@ -906,7 +913,7 @@ func (w *watchGrpcStream) newWatchClient() (pb.Watch_WatchClient, error) { w.resumec = make(chan struct{}) w.joinSubstreams() for _, ws := range w.substreams { - ws.id = -1 + ws.id = InvalidWatchID w.resuming = append(w.resuming, ws) } // strip out nils, if any diff --git a/etcdserver/api/v3rpc/watch.go b/etcdserver/api/v3rpc/watch.go index dbbe754141f..c33654dfaa9 100644 --- a/etcdserver/api/v3rpc/watch.go +++ b/etcdserver/api/v3rpc/watch.go @@ -16,12 +16,14 @@ package v3rpc import ( "context" + "fmt" "io" "math/rand" "sync" "time" "go.etcd.io/etcd/auth" + "go.etcd.io/etcd/clientv3" "go.etcd.io/etcd/etcdserver" "go.etcd.io/etcd/etcdserver/api/v3rpc/rpctypes" pb "go.etcd.io/etcd/etcdserver/etcdserverpb" @@ -294,7 +296,7 @@ func (sws *serverWatchStream) recvLoop() error { wr := &pb.WatchResponse{ Header: sws.newResponseHeader(sws.watchStream.Rev()), - WatchId: creq.WatchId, + WatchId: clientv3.InvalidWatchID, Canceled: true, Created: true, CancelReason: cancelReason, @@ -328,7 +330,10 @@ func (sws *serverWatchStream) recvLoop() error { sws.fragment[id] = true } sws.mu.Unlock() + } else { + id = clientv3.InvalidWatchID } + wr := &pb.WatchResponse{ Header: sws.newResponseHeader(wsrev), WatchId: int64(id), @@ -365,7 +370,7 @@ func (sws *serverWatchStream) recvLoop() error { if uv.ProgressRequest != nil { sws.ctrlStream <- &pb.WatchResponse{ Header: sws.newResponseHeader(sws.watchStream.Rev()), - WatchId: -1, // response is not associated with any WatchId and will be broadcast to all watch channels + WatchId: clientv3.InvalidWatchID, // response is not associated with any WatchId and will be broadcast to all watch channels } } default: @@ -504,7 +509,12 @@ func (sws *serverWatchStream) sendLoop() { // track id creation wid := mvcc.WatchID(c.WatchId) - if c.Canceled { + + if !(!(c.Canceled && c.Created) || wid == clientv3.InvalidWatchID) { + panic(fmt.Sprintf("unexpected watchId: %d, wanted: %d, since both 'Canceled' and 'Created' are true", wid, clientv3.InvalidWatchID)) + } + + if c.Canceled && wid != clientv3.InvalidWatchID { delete(ids, wid) continue } diff --git a/integration/v3_auth_test.go b/integration/v3_auth_test.go index bccc77ed55a..bcc22ddb30f 100644 --- a/integration/v3_auth_test.go +++ b/integration/v3_auth_test.go @@ -504,3 +504,58 @@ func TestV3AuthWatchAndTokenExpire(t *testing.T) { watchResponse = <-wChan testutil.AssertNil(t, watchResponse.Err()) } + +func TestV3AuthWatchErrorAndWatchId0(t *testing.T) { + defer testutil.AfterTest(t) + clus := NewClusterV3(t, &ClusterConfig{Size: 3}) + defer clus.Terminate(t) + + ctx, cancel := context.WithTimeout(context.TODO(), 10*time.Second) + defer cancel() + + users := []user{ + { + name: "user1", + password: "user1-123", + role: "role1", + key: "k1", + end: "k2", + }, + } + + authSetupUsers(t, toGRPC(clus.Client(0)).Auth, users) + + authSetupRoot(t, toGRPC(clus.Client(0)).Auth) + + c, cerr := clientv3.New(clientv3.Config{Endpoints: clus.Client(0).Endpoints(), Username: "user1", Password: "user1-123"}) + if cerr != nil { + t.Fatal(cerr) + } + defer c.Close() + + watchStartCh, watchEndCh := make(chan interface{}), make(chan interface{}) + + go func() { + wChan := c.Watch(ctx, "k1", clientv3.WithRev(1)) + watchStartCh <- struct{}{} + watchResponse := <-wChan + t.Logf("watch response from k1: %v", watchResponse) + testutil.AssertTrue(t, len(watchResponse.Events) != 0) + watchEndCh <- struct{}{} + }() + + // Chan for making sure that the above goroutine invokes Watch() + // So the above Watch() can get watch ID = 0 + <-watchStartCh + + wChan := c.Watch(ctx, "non-allowed-key", clientv3.WithRev(1)) + watchResponse := <-wChan + testutil.AssertNotNil(t, watchResponse.Err()) // permission denied + + _, err := c.Put(ctx, "k1", "val") + if err != nil { + t.Fatalf("Unexpected error from Put: %v", err) + } + + <-watchEndCh +} diff --git a/integration/v3_watch_test.go b/integration/v3_watch_test.go index ec2139a9a6c..4ad7b2fcb82 100644 --- a/integration/v3_watch_test.go +++ b/integration/v3_watch_test.go @@ -24,6 +24,7 @@ import ( "testing" "time" + "go.etcd.io/etcd/clientv3" "go.etcd.io/etcd/etcdserver/api/v3rpc" pb "go.etcd.io/etcd/etcdserver/etcdserverpb" "go.etcd.io/etcd/mvcc/mvccpb" @@ -396,8 +397,8 @@ func TestV3WatchWrongRange(t *testing.T) { if cresp.Canceled != tt.canceled { t.Fatalf("#%d: canceled %v, want %v", i, tt.canceled, cresp.Canceled) } - if tt.canceled && cresp.WatchId != -1 { - t.Fatalf("#%d: canceled watch ID %d, want -1", i, cresp.WatchId) + if tt.canceled && cresp.WatchId != clientv3.InvalidWatchID { + t.Fatalf("#%d: canceled watch ID %d, want %d", i, cresp.WatchId, clientv3.InvalidWatchID) } } } diff --git a/mvcc/watcher.go b/mvcc/watcher.go index 2846d62a5d4..515c278b0e1 100644 --- a/mvcc/watcher.go +++ b/mvcc/watcher.go @@ -19,13 +19,10 @@ import ( "errors" "sync" + "go.etcd.io/etcd/clientv3" "go.etcd.io/etcd/mvcc/mvccpb" ) -// AutoWatchID is the watcher ID passed in WatchStream.Watch when no -// user-provided ID is available. If pass, an ID will automatically be assigned. -const AutoWatchID WatchID = 0 - var ( ErrWatcherNotExist = errors.New("mvcc: watcher does not exist") ErrEmptyWatcherRange = errors.New("mvcc: watcher range is empty") @@ -118,7 +115,7 @@ func (ws *watchStream) Watch(id WatchID, key, end []byte, startRev int64, fcs .. return -1, ErrEmptyWatcherRange } - if id == AutoWatchID { + if id == clientv3.AutoWatchID { for ws.watchers[ws.nextID] != nil { ws.nextID++ } diff --git a/proxy/grpcproxy/watch.go b/proxy/grpcproxy/watch.go index 4493319109e..c62f07416a4 100644 --- a/proxy/grpcproxy/watch.go +++ b/proxy/grpcproxy/watch.go @@ -233,7 +233,7 @@ func (wps *watchProxyStream) recvLoop() error { if err := wps.checkPermissionForWatch(cr.Key, cr.RangeEnd); err != nil { wps.watchCh <- &pb.WatchResponse{ Header: &pb.ResponseHeader{}, - WatchId: -1, + WatchId: clientv3.InvalidWatchID, Created: true, Canceled: true, CancelReason: err.Error(), @@ -252,7 +252,7 @@ func (wps *watchProxyStream) recvLoop() error { filters: v3rpc.FiltersFromRequest(cr), } if !w.wr.valid() { - w.post(&pb.WatchResponse{WatchId: -1, Created: true, Canceled: true}) + w.post(&pb.WatchResponse{WatchId: clientv3.InvalidWatchID, Created: true, Canceled: true}) continue } wps.nextWatcherID++