diff --git a/bitswap.go b/bitswap.go index 73ca266e..100ce859 100644 --- a/bitswap.go +++ b/bitswap.go @@ -303,7 +303,7 @@ func New(parent context.Context, network bsnet.BitSwapNetwork, bs.engine.SetSendDontHaves(bs.engineSetSendDontHaves) bs.pqm.Startup() - network.SetDelegate(bs) + network.Start(bs) // Start up bitswaps async worker routines bs.startWorkers(ctx, px) @@ -316,6 +316,7 @@ func New(parent context.Context, network bsnet.BitSwapNetwork, sm.Shutdown() cancelFunc() notif.Shutdown() + network.Stop() }() procctx.CloseAfterContext(px, ctx) // parent cancelled first diff --git a/network/connecteventmanager.go b/network/connecteventmanager.go index bbde7af2..a9053ba6 100644 --- a/network/connecteventmanager.go +++ b/network/connecteventmanager.go @@ -11,96 +11,203 @@ type ConnectionListener interface { PeerDisconnected(peer.ID) } +type state byte + +const ( + stateDisconnected = iota + stateResponsive + stateUnresponsive +) + type connectEventManager struct { connListener ConnectionListener lk sync.RWMutex - conns map[peer.ID]*connState + cond sync.Cond + peers map[peer.ID]*peerState + + changeQueue []peer.ID + stop bool + done chan struct{} } -type connState struct { - refs int - responsive bool +type peerState struct { + newState, curState state + pending bool } func newConnectEventManager(connListener ConnectionListener) *connectEventManager { - return &connectEventManager{ + evtManager := &connectEventManager{ connListener: connListener, - conns: make(map[peer.ID]*connState), + peers: make(map[peer.ID]*peerState), + done: make(chan struct{}), } + evtManager.cond = sync.Cond{L: &evtManager.lk} + return evtManager } -func (c *connectEventManager) Connected(p peer.ID) { +func (c *connectEventManager) Start() { + go c.worker() +} + +func (c *connectEventManager) Stop() { c.lk.Lock() - defer c.lk.Unlock() + c.stop = true + c.lk.Unlock() + c.cond.Broadcast() - state, ok := c.conns[p] + <-c.done +} + +func (c *connectEventManager) getState(p peer.ID) state { + 
if state, ok := c.peers[p]; ok { + return state.newState + } else { + return stateDisconnected + } +} + +func (c *connectEventManager) setState(p peer.ID, newState state) { + state, ok := c.peers[p] if !ok { - state = &connState{responsive: true} - c.conns[p] = state + state = new(peerState) + c.peers[p] = state + } + state.newState = newState + if !state.pending && state.newState != state.curState { + state.pending = true + c.changeQueue = append(c.changeQueue, p) + c.cond.Broadcast() } - state.refs++ +} - if state.refs == 1 && state.responsive { - c.connListener.PeerConnected(p) +// Waits for a change to be enqueued, or for the event manager to be stopped. Returns false if the +// connect event manager has been stopped. +func (c *connectEventManager) waitChange() bool { + for !c.stop && len(c.changeQueue) == 0 { + c.cond.Wait() } + return !c.stop } -func (c *connectEventManager) Disconnected(p peer.ID) { +func (c *connectEventManager) worker() { c.lk.Lock() defer c.lk.Unlock() + defer close(c.done) + + for c.waitChange() { + pid := c.changeQueue[0] + c.changeQueue[0] = peer.ID("") // free the peer ID (slicing won't do that) + c.changeQueue = c.changeQueue[1:] + + state, ok := c.peers[pid] + // If we've disconnected and forgotten, continue. + if !ok { + // This shouldn't be possible because _this_ thread is responsible for + // removing peers from this map, and we shouldn't get duplicate entries in + // the change queue. + log.Error("a change was enqueued for a peer we're not tracking") + continue + } - state, ok := c.conns[p] - if !ok { - // Should never happen + // Record the fact that this "state" is no longer in the queue. + state.pending = false + + // Then, if there's nothing to do, continue. + if state.curState == state.newState { + continue + } + + // Or record the state update, then apply it. 
+ oldState := state.curState + state.curState = state.newState + + switch state.newState { + case stateDisconnected: + delete(c.peers, pid) + fallthrough + case stateUnresponsive: + // Only trigger a disconnect event if the peer was responsive. + // We could be transitioning from unresponsive to disconnected. + if oldState == stateResponsive { + c.lk.Unlock() + c.connListener.PeerDisconnected(pid) + c.lk.Lock() + } + case stateResponsive: + c.lk.Unlock() + c.connListener.PeerConnected(pid) + c.lk.Lock() + } + } +} + +// Called whenever we receive a new connection. May be called many times. +func (c *connectEventManager) Connected(p peer.ID) { + c.lk.Lock() + defer c.lk.Unlock() + + // !responsive -> responsive + + if c.getState(p) == stateResponsive { return } - state.refs-- + c.setState(p, stateResponsive) +} - if state.refs == 0 { - if state.responsive { - c.connListener.PeerDisconnected(p) - } - delete(c.conns, p) +// Called when we drop the final connection to a peer. +func (c *connectEventManager) Disconnected(p peer.ID) { + c.lk.Lock() + defer c.lk.Unlock() + + // !disconnected -> disconnected + + if c.getState(p) == stateDisconnected { + return } + + c.setState(p, stateDisconnected) } +// Called whenever a peer is unresponsive. func (c *connectEventManager) MarkUnresponsive(p peer.ID) { c.lk.Lock() defer c.lk.Unlock() - state, ok := c.conns[p] - if !ok || !state.responsive { + // responsive -> unresponsive + + if c.getState(p) != stateResponsive { return } - state.responsive = false - c.connListener.PeerDisconnected(p) + c.setState(p, stateUnresponsive) } +// Called whenever we receive a message from a peer. +// +// - When we're connected to the peer, this will mark the peer as responsive (from unresponsive). +// - When not connected, we ignore this call. Unfortunately, a peer may disconnect before we process +// the "on message" event, so we can't treat this as evidence of a connection. 
func (c *connectEventManager) OnMessage(p peer.ID) { - // This is a frequent operation so to avoid different message arrivals - // getting blocked by a write lock, first take a read lock to check if - // we need to modify state c.lk.RLock() - state, ok := c.conns[p] - responsive := ok && state.responsive + unresponsive := c.getState(p) == stateUnresponsive c.lk.RUnlock() - if !ok || responsive { + // Only continue if both connected, and unresponsive. + if !unresponsive { return } + // unresponsive -> responsive + // We need to make a modification so now take a write lock c.lk.Lock() defer c.lk.Unlock() // Note: state may have changed in the time between when read lock // was released and write lock taken, so check again - state, ok = c.conns[p] - if !ok || state.responsive { + if c.getState(p) != stateUnresponsive { return } - state.responsive = true - c.connListener.PeerConnected(p) + c.setState(p, stateResponsive) } diff --git a/network/connecteventmanager_test.go b/network/connecteventmanager_test.go index fb81abee..4ed7edd7 100644 --- a/network/connecteventmanager_test.go +++ b/network/connecteventmanager_test.go @@ -1,144 +1,168 @@ package network import ( + "sync" "testing" + "time" "github.com/ipfs/go-bitswap/internal/testutil" "github.com/libp2p/go-libp2p-core/peer" + "github.com/stretchr/testify/require" ) +type mockConnEvent struct { + connected bool + peer peer.ID +} + type mockConnListener struct { - conns map[peer.ID]int + sync.Mutex + events []mockConnEvent } func newMockConnListener() *mockConnListener { - return &mockConnListener{ - conns: make(map[peer.ID]int), - } + return new(mockConnListener) } func (cl *mockConnListener) PeerConnected(p peer.ID) { - cl.conns[p]++ + cl.Lock() + defer cl.Unlock() + cl.events = append(cl.events, mockConnEvent{connected: true, peer: p}) } func (cl *mockConnListener) PeerDisconnected(p peer.ID) { - cl.conns[p]-- + cl.Lock() + defer cl.Unlock() + cl.events = append(cl.events, mockConnEvent{connected: false, peer: p}) 
+} + +func wait(t *testing.T, c *connectEventManager) { + require.Eventually(t, func() bool { + c.lk.RLock() + defer c.lk.RUnlock() + return len(c.changeQueue) == 0 + }, time.Second, time.Millisecond, "connection event manager never processed events") } -func TestConnectEventManagerConnectionCount(t *testing.T) { +func TestConnectEventManagerConnectDisconnect(t *testing.T) { connListener := newMockConnListener() peers := testutil.GeneratePeers(2) cem := newConnectEventManager(connListener) + cem.Start() + t.Cleanup(cem.Stop) - // Peer A: 1 Connection - cem.Connected(peers[0]) - if connListener.conns[peers[0]] != 1 { - t.Fatal("Expected Connected event") - } + var expectedEvents []mockConnEvent - // Peer A: 2 Connections + // Connect A twice, should only see one event + cem.Connected(peers[0]) cem.Connected(peers[0]) - if connListener.conns[peers[0]] != 1 { - t.Fatal("Unexpected no Connected event for the same peer") - } + expectedEvents = append(expectedEvents, mockConnEvent{ + peer: peers[0], + connected: true, + }) - // Peer A: 2 Connections - // Peer B: 1 Connection + // Flush the event queue. + wait(t, cem) + require.Equal(t, expectedEvents, connListener.events) + + // Block up the event loop. + connListener.Lock() cem.Connected(peers[1]) - if connListener.conns[peers[1]] != 1 { - t.Fatal("Expected Connected event") - } - - // Peer A: 2 Connections - // Peer B: 0 Connections - cem.Disconnected(peers[1]) - if connListener.conns[peers[1]] != 0 { - t.Fatal("Expected Disconnected event") - } - - // Peer A: 1 Connection - // Peer B: 0 Connections - cem.Disconnected(peers[0]) - if connListener.conns[peers[0]] != 1 { - t.Fatal("Expected no Disconnected event for peer with one remaining conn") - } + expectedEvents = append(expectedEvents, mockConnEvent{ + peer: peers[1], + connected: true, + }) - // Peer A: 0 Connections - // Peer B: 0 Connections + // We don't expect this to show up. 
cem.Disconnected(peers[0]) - if connListener.conns[peers[0]] != 0 { - t.Fatal("Expected Disconnected event") - } + cem.Connected(peers[0]) + + connListener.Unlock() + + wait(t, cem) + require.Equal(t, expectedEvents, connListener.events) } func TestConnectEventManagerMarkUnresponsive(t *testing.T) { connListener := newMockConnListener() p := testutil.GeneratePeers(1)[0] cem := newConnectEventManager(connListener) + cem.Start() + t.Cleanup(cem.Stop) - // Peer A: 1 Connection - cem.Connected(p) - if connListener.conns[p] != 1 { - t.Fatal("Expected Connected event") - } + var expectedEvents []mockConnEvent - // Peer A: 1 Connection - cem.MarkUnresponsive(p) - if connListener.conns[p] != 0 { - t.Fatal("Expected Disconnected event") - } + // Don't mark as connected when we receive a message (could have been delayed). + cem.OnMessage(p) + wait(t, cem) + require.Equal(t, expectedEvents, connListener.events) - // Peer A: 2 Connections + // Handle connected event. cem.Connected(p) - if connListener.conns[p] != 0 { - t.Fatal("Expected no Connected event for unresponsive peer") - } + wait(t, cem) - // Peer A: 2 Connections - cem.OnMessage(p) - if connListener.conns[p] != 1 { - t.Fatal("Expected Connected event for newly responsive peer") - } + expectedEvents = append(expectedEvents, mockConnEvent{ + peer: p, + connected: true, + }) + require.Equal(t, expectedEvents, connListener.events) - // Peer A: 2 Connections - cem.OnMessage(p) - if connListener.conns[p] != 1 { - t.Fatal("Expected no further Connected event for subsequent messages") - } + // Becomes unresponsive. 
+ cem.MarkUnresponsive(p) + wait(t, cem) - // Peer A: 1 Connection - cem.Disconnected(p) - if connListener.conns[p] != 1 { - t.Fatal("Expected no Disconnected event for peer with one remaining conn") - } + expectedEvents = append(expectedEvents, mockConnEvent{ + peer: p, + connected: false, + }) + require.Equal(t, expectedEvents, connListener.events) - // Peer A: 0 Connections - cem.Disconnected(p) - if connListener.conns[p] != 0 { - t.Fatal("Expected Disconnected event") - } + // We have a new connection, mark them responsive. + cem.Connected(p) + wait(t, cem) + expectedEvents = append(expectedEvents, mockConnEvent{ + peer: p, + connected: true, + }) + require.Equal(t, expectedEvents, connListener.events) + + // No duplicate event. + cem.OnMessage(p) + wait(t, cem) + require.Equal(t, expectedEvents, connListener.events) } func TestConnectEventManagerDisconnectAfterMarkUnresponsive(t *testing.T) { connListener := newMockConnListener() p := testutil.GeneratePeers(1)[0] cem := newConnectEventManager(connListener) + cem.Start() + t.Cleanup(cem.Stop) - // Peer A: 1 Connection + var expectedEvents []mockConnEvent + + // Handle connected event. cem.Connected(p) - if connListener.conns[p] != 1 { - t.Fatal("Expected Connected event") - } + wait(t, cem) + + expectedEvents = append(expectedEvents, mockConnEvent{ + peer: p, + connected: true, + }) + require.Equal(t, expectedEvents, connListener.events) - // Peer A: 1 Connection + // Becomes unresponsive. 
cem.MarkUnresponsive(p) - if connListener.conns[p] != 0 { - t.Fatal("Expected Disconnected event") - } + wait(t, cem) + + expectedEvents = append(expectedEvents, mockConnEvent{ + peer: p, + connected: false, + }) + require.Equal(t, expectedEvents, connListener.events) - // Peer A: 0 Connections cem.Disconnected(p) - if connListener.conns[p] != 0 { - t.Fatal("Expected not to receive a second Disconnected event") - } + wait(t, cem) + require.Empty(t, cem.peers) // all disconnected + require.Equal(t, expectedEvents, connListener.events) } diff --git a/network/interface.go b/network/interface.go index a350d525..8648f8dd 100644 --- a/network/interface.go +++ b/network/interface.go @@ -35,9 +35,10 @@ type BitSwapNetwork interface { peer.ID, bsmsg.BitSwapMessage) error - // SetDelegate registers the Reciver to handle messages received from the - // network. - SetDelegate(Receiver) + // Start registers the Receiver and starts handling new messages, connectivity events, etc. + Start(Receiver) + // Stop stops the network service. + Stop() ConnectTo(context.Context, peer.ID) error DisconnectFrom(context.Context, peer.ID) error diff --git a/network/ipfs_impl.go b/network/ipfs_impl.go index 7457aeb8..6f69b26a 100644 --- a/network/ipfs_impl.go +++ b/network/ipfs_impl.go @@ -349,17 +349,22 @@ func (bsnet *impl) newStreamToPeer(ctx context.Context, p peer.ID) (network.Stre return bsnet.host.NewStream(ctx, p, bsnet.supportedProtocols...) } -func (bsnet *impl) SetDelegate(r Receiver) { +func (bsnet *impl) Start(r Receiver) { bsnet.receiver = r bsnet.connectEvtMgr = newConnectEventManager(r) for _, proto := range bsnet.supportedProtocols { bsnet.host.SetStreamHandler(proto, bsnet.handleNewStream) } bsnet.host.Network().Notify((*netNotifiee)(bsnet)) - // TODO: StopNotify.
+ bsnet.connectEvtMgr.Start() } +func (bsnet *impl) Stop() { + bsnet.connectEvtMgr.Stop() + bsnet.host.Network().StopNotify((*netNotifiee)(bsnet)) +} + func (bsnet *impl) ConnectTo(ctx context.Context, p peer.ID) error { return bsnet.host.Connect(ctx, peer.AddrInfo{ID: p}) } @@ -450,8 +455,8 @@ func (nn *netNotifiee) Connected(n network.Network, v network.Conn) { nn.impl().connectEvtMgr.Connected(v.RemotePeer()) } func (nn *netNotifiee) Disconnected(n network.Network, v network.Conn) { - // ignore transient connections - if v.Stat().Transient { + // Only record a "disconnect" when we actually disconnect. + if n.Connectedness(v.RemotePeer()) == network.Connected { return } diff --git a/network/ipfs_impl_test.go b/network/ipfs_impl_test.go index 0d7968ec..9e069489 100644 --- a/network/ipfs_impl_test.go +++ b/network/ipfs_impl_test.go @@ -38,7 +38,8 @@ func newReceiver() *receiver { return &receiver{ peers: make(map[peer.ID]struct{}), messageReceived: make(chan struct{}), - connectionEvent: make(chan bool, 1), + // Avoid blocking. 100 is good enough for tests. 
+ connectionEvent: make(chan bool, 100), } } @@ -169,8 +170,10 @@ func TestMessageSendAndReceive(t *testing.T) { bsnet2 := streamNet.Adapter(p2) r1 := newReceiver() r2 := newReceiver() - bsnet1.SetDelegate(r1) - bsnet2.SetDelegate(r2) + bsnet1.Start(r1) + t.Cleanup(bsnet1.Stop) + bsnet2.Start(r2) + t.Cleanup(bsnet2.Stop) err = mn.LinkAll() if err != nil { @@ -268,7 +271,8 @@ func prepareNetwork(t *testing.T, ctx context.Context, p1 tnet.Identity, r1 *rec eh1 := &ErrHost{Host: h1} routing1 := mr.ClientWithDatastore(context.TODO(), p1, ds.NewMapDatastore()) bsnet1 := bsnet.NewFromIpfsHost(eh1, routing1) - bsnet1.SetDelegate(r1) + bsnet1.Start(r1) + t.Cleanup(bsnet1.Stop) if r1.listener != nil { eh1.Network().Notify(r1.listener) } @@ -281,7 +285,8 @@ func prepareNetwork(t *testing.T, ctx context.Context, p1 tnet.Identity, r1 *rec eh2 := &ErrHost{Host: h2} routing2 := mr.ClientWithDatastore(context.TODO(), p2, ds.NewMapDatastore()) bsnet2 := bsnet.NewFromIpfsHost(eh2, routing2) - bsnet2.SetDelegate(r2) + bsnet2.Start(r2) + t.Cleanup(bsnet2.Stop) if r2.listener != nil { eh2.Network().Notify(r2.listener) } @@ -454,28 +459,32 @@ func TestSupportsHave(t *testing.T) { } for _, tc := range testCases { - p1 := tnet.RandIdentityOrFatal(t) - bsnet1 := streamNet.Adapter(p1) - bsnet1.SetDelegate(newReceiver()) - - p2 := tnet.RandIdentityOrFatal(t) - bsnet2 := streamNet.Adapter(p2, bsnet.SupportedProtocols([]protocol.ID{tc.proto})) - bsnet2.SetDelegate(newReceiver()) - - err = mn.LinkAll() - if err != nil { - t.Fatal(err) - } + t.Run(fmt.Sprintf("%s-%v", tc.proto, tc.expSupportsHave), func(t *testing.T) { + p1 := tnet.RandIdentityOrFatal(t) + bsnet1 := streamNet.Adapter(p1) + bsnet1.Start(newReceiver()) + t.Cleanup(bsnet1.Stop) + + p2 := tnet.RandIdentityOrFatal(t) + bsnet2 := streamNet.Adapter(p2, bsnet.SupportedProtocols([]protocol.ID{tc.proto})) + bsnet2.Start(newReceiver()) + t.Cleanup(bsnet2.Stop) + + err = mn.LinkAll() + if err != nil { + t.Fatal(err) + } - senderCurrent, 
err := bsnet1.NewMessageSender(ctx, p2.ID(), &bsnet.MessageSenderOpts{}) - if err != nil { - t.Fatal(err) - } - defer senderCurrent.Close() + senderCurrent, err := bsnet1.NewMessageSender(ctx, p2.ID(), &bsnet.MessageSenderOpts{}) + if err != nil { + t.Fatal(err) + } + defer senderCurrent.Close() - if senderCurrent.SupportsHave() != tc.expSupportsHave { - t.Fatal("Expected sender HAVE message support", tc.proto, tc.expSupportsHave) - } + if senderCurrent.SupportsHave() != tc.expSupportsHave { + t.Fatal("Expected sender HAVE message support", tc.proto, tc.expSupportsHave) + } + }) } } diff --git a/testnet/network_test.go b/testnet/network_test.go index 89f3d68f..fbd1fa41 100644 --- a/testnet/network_test.go +++ b/testnet/network_test.go @@ -28,7 +28,7 @@ func TestSendMessageAsyncButWaitForResponse(t *testing.T) { expectedStr := "received async" - responder.SetDelegate(lambda(func( + responder.Start(lambda(func( ctx context.Context, fromWaiter peer.ID, msgFromWaiter bsmsg.BitSwapMessage) { @@ -40,8 +40,9 @@ func TestSendMessageAsyncButWaitForResponse(t *testing.T) { t.Error(err) } })) + t.Cleanup(responder.Stop) - waiter.SetDelegate(lambda(func( + waiter.Start(lambda(func( ctx context.Context, fromResponder peer.ID, msgFromResponder bsmsg.BitSwapMessage) { @@ -59,6 +60,7 @@ func TestSendMessageAsyncButWaitForResponse(t *testing.T) { t.Fatal("Message not received from the responder") } })) + t.Cleanup(waiter.Stop) messageSentAsync := bsmsg.New(true) messageSentAsync.AddBlock(blocks.NewBlock([]byte("data"))) diff --git a/testnet/virtual.go b/testnet/virtual.go index 66f5e821..b5405841 100644 --- a/testnet/virtual.go +++ b/testnet/virtual.go @@ -300,10 +300,13 @@ func (nc *networkClient) Provide(ctx context.Context, k cid.Cid) error { return nc.routing.Provide(ctx, k, true) } -func (nc *networkClient) SetDelegate(r bsnet.Receiver) { +func (nc *networkClient) Start(r bsnet.Receiver) { nc.Receiver = r } +func (nc *networkClient) Stop() { +} + func (nc *networkClient) 
ConnectTo(_ context.Context, p peer.ID) error { nc.network.mu.Lock() otherClient, ok := nc.network.clients[p]