diff --git a/raft/node.go b/raft/node.go index ebbe23277dc..f3ba250b9af 100644 --- a/raft/node.go +++ b/raft/node.go @@ -325,7 +325,9 @@ func (n *node) run(r *raft) { case cc := <-n.confc: if cc.NodeID == None { select { - case n.confstatec <- pb.ConfState{Nodes: r.nodes()}: + case n.confstatec <- pb.ConfState{ + Nodes: r.nodes(), + Learners: r.learnerNodes()}: case <-n.done: } break @@ -347,7 +349,9 @@ func (n *node) run(r *raft) { panic("unexpected conf type") } select { - case n.confstatec <- pb.ConfState{Nodes: r.nodes()}: + case n.confstatec <- pb.ConfState{ + Nodes: r.nodes(), + Learners: r.learnerNodes()}: case <-n.done: } case <-n.tickc: diff --git a/raft/node_test.go b/raft/node_test.go index ef0c92ad145..0ccceb8a2fd 100644 --- a/raft/node_test.go +++ b/raft/node_test.go @@ -732,3 +732,55 @@ func TestIsHardStateEqual(t *testing.T) { } } } + +func TestNodeProposeAddLearnerNode(t *testing.T) { + ticker := time.NewTicker(time.Millisecond * 100) + defer ticker.Stop() + n := newNode() + s := NewMemoryStorage() + r := newTestRaft(1, []uint64{1}, 10, 1, s) + go n.run(r) + n.Campaign(context.TODO()) + stop := make(chan struct{}) + done := make(chan struct{}) + applyConfChan := make(chan struct{}) + go func() { + defer close(done) + for { + select { + case <-stop: + return + case <-ticker.C: + n.Tick() + case rd := <-n.Ready(): + s.Append(rd.Entries) + t.Logf("raft: %v", rd.Entries) + for _, ent := range rd.Entries { + if ent.Type != raftpb.EntryConfChange { + continue + } + var cc raftpb.ConfChange + cc.Unmarshal(ent.Data) + state := n.ApplyConfChange(cc) + if len(state.Learners) == 0 || + state.Learners[0] != cc.NodeID || + cc.NodeID != 2 { + t.Errorf("apply conf change should return new added learner: %v", state.String()) + } + + if len(state.Nodes) != 1 { + t.Errorf("add learner should not change the nodes: %v", state.String()) + } + t.Logf("apply raft conf %v changed to: %v", cc, state.String()) + applyConfChan <- struct{}{} + } + n.Advance() + } + } + }() + cc := raftpb.ConfChange{Type: raftpb.ConfChangeAddLearnerNode, NodeID: 2} + n.ProposeConfChange(context.TODO(), cc) + <-applyConfChan + close(stop) + <-done +} diff --git a/raft/raft.go b/raft/raft.go index b9939fae092..f1fafb1c23d 100644 --- a/raft/raft.go +++ b/raft/raft.go @@ -377,10 +377,16 @@ func (r *raft) hardState() pb.HardState { func (r *raft) quorum() int { return len(r.prs)/2 + 1 } func (r *raft) nodes() []uint64 { - nodes := make([]uint64, 0, len(r.prs)+len(r.learnerPrs)) + nodes := make([]uint64, 0, len(r.prs)) for id := range r.prs { nodes = append(nodes, id) } + sort.Sort(uint64Slice(nodes)) + return nodes +} + +func (r *raft) learnerNodes() []uint64 { + nodes := make([]uint64, 0, len(r.learnerPrs)) for id := range r.learnerPrs { nodes = append(nodes, id) } diff --git a/raft/raft_test.go b/raft/raft_test.go index b8ef596ea57..858fb4a1c53 100644 --- a/raft/raft_test.go +++ b/raft/raft_test.go @@ -2475,8 +2475,12 @@ func TestRestoreWithLearner(t *testing.T) { t.Errorf("log.lastTerm = %d, want %d", mustTerm(sm.raftLog.term(s.Metadata.Index)), s.Metadata.Term) } sg := sm.nodes() - if len(sg) != len(s.Metadata.ConfState.Nodes)+len(s.Metadata.ConfState.Learners) { - t.Errorf("sm.Nodes = %+v, length not equal with %+v", sg, s.Metadata.ConfState) + if len(sg) != len(s.Metadata.ConfState.Nodes) { + t.Errorf("sm.Nodes = %+v, length not equal with %+v", sg, s.Metadata.ConfState.Nodes) + } + lns := sm.learnerNodes() + if len(lns) != len(s.Metadata.ConfState.Learners) { + t.Errorf("sm.LearnerNodes = %+v, length not equal with %+v", sg, s.Metadata.ConfState.Learners) } for _, n := range s.Metadata.ConfState.Nodes { if sm.prs[n].IsLearner { @@ -2805,8 +2809,8 @@ func TestAddNode(t *testing.T) { func TestAddLearner(t *testing.T) { r := newTestRaft(1, []uint64{1}, 10, 1, NewMemoryStorage()) r.addLearner(2) - nodes := r.nodes() - wnodes := []uint64{1, 2} + nodes := r.learnerNodes() + wnodes := []uint64{2} if !reflect.DeepEqual(nodes, wnodes) { t.Errorf("nodes = %v, want %v", nodes, wnodes) } @@ -2877,9 +2881,13 @@ func TestRemoveLearner(t *testing.T) { t.Errorf("nodes = %v, want %v", g, w) } + w = []uint64{} + if g := r.learnerNodes(); !reflect.DeepEqual(g, w) { + t.Errorf("nodes = %v, want %v", g, w) + } + // remove all nodes from cluster r.removeNode(1) - w = []uint64{} if g := r.nodes(); !reflect.DeepEqual(g, w) { t.Errorf("nodes = %v, want %v", g, w) } diff --git a/raft/rawnode.go b/raft/rawnode.go index b289cd60f27..fbd7a49e854 100644 --- a/raft/rawnode.go +++ b/raft/rawnode.go @@ -169,7 +169,7 @@ func (rn *RawNode) ProposeConfChange(cc pb.ConfChange) error { // ApplyConfChange applies a config change to the local node. func (rn *RawNode) ApplyConfChange(cc pb.ConfChange) *pb.ConfState { if cc.NodeID == None { - return &pb.ConfState{Nodes: rn.raft.nodes()} + return &pb.ConfState{Nodes: rn.raft.nodes(), Learners: rn.raft.learnerNodes()} } switch cc.Type { case pb.ConfChangeAddNode: @@ -182,7 +182,7 @@ func (rn *RawNode) ApplyConfChange(cc pb.ConfChange) *pb.ConfState { default: panic("unexpected conf type") } - return &pb.ConfState{Nodes: rn.raft.nodes()} + return &pb.ConfState{Nodes: rn.raft.nodes(), Learners: rn.raft.learnerNodes()} } // Step advances the state machine using the given message.