diff --git a/p2p/transport/webrtc/transport_test.go b/p2p/transport/webrtc/transport_test.go index 6587ddd7ba..93fece20b5 100644 --- a/p2p/transport/webrtc/transport_test.go +++ b/p2p/transport/webrtc/transport_test.go @@ -9,7 +9,6 @@ import ( "io" "net" "os" - "runtime" "strings" "sync" "sync/atomic" @@ -865,18 +864,26 @@ func TestMaxInFlightRequests(t *testing.T) { } func TestManyConnections(t *testing.T) { - const N = 200 - errCh := make(chan error, 200) - successCh := make(chan struct{}, 1) - - tr, lp := getTransport(t) - ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/webrtc-direct")) - require.NoError(t, err) - defer ln.Close() + var listeners []tpt.Listener + var listenerPeerIDs []peer.ID + + const numListeners = 5 + const dialersPerListener = 5 + const connsPerDialer = 10 + errCh := make(chan error, 10*numListeners*dialersPerListener*connsPerDialer) + successCh := make(chan struct{}, 10*numListeners*dialersPerListener*connsPerDialer) + + for i := 0; i < numListeners; i++ { + tr, lp := getTransport(t) + listenerPeerIDs = append(listenerPeerIDs, lp) + ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/webrtc-direct")) + require.NoError(t, err) + defer ln.Close() + listeners = append(listeners, ln) + } runListenConn := func(conn tpt.CapableConn) { defer conn.Close() - s, err := conn.AcceptStream() if err != nil { t.Errorf("accept stream failed for listener: %s", err) @@ -892,64 +899,81 @@ func TestManyConnections(t *testing.T) { s.Write(b[:]) _, err = s.Read(b[:]) // peer will close the connection after read if !assert.Error(t, err) { - errCh <- errors.New("expected peer to close connection") + err = errors.New("invalid read: expected conn to close") + errCh <- err return } + successCh <- struct{}{} } - runDialConn := func(conn tpt.CapableConn) error { + runDialConn := func(conn tpt.CapableConn) { defer conn.Close() s, err := conn.OpenStream(context.Background()) if err != nil { t.Errorf("accept stream failed for listener: %s", err) - return err + errCh <- err + return } var b [4]byte if _, err := s.Write(b[:]); err != nil { t.Errorf("write stream failed for dialer: %s", err) - return err + errCh <- err + return } if _, err := s.Read(b[:]); err != nil { t.Errorf("read stream failed for dialer: %s", err) - return err + errCh <- err + return } - return nil + s.Close() } - go func() { - for i := 0; i < N; i++ { + runListener := func(ln tpt.Listener) { + for i := 0; i < dialersPerListener*connsPerDialer; i++ { conn, err := ln.Accept() if err != nil { - t.Errorf("listener failed to accept conneciton: %s %d", err, runtime.NumGoroutine()) + t.Errorf("listener failed to accept conneciton: %s", err) return } - runListenConn(conn) - successCh <- struct{}{} + go runListenConn(conn) } - }() + } - tp, _ := getTransport(t) - for i := 0; i < N; i++ { - // This test aims to check for deadlocks. So keep a high timeout - ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) - conn, err := tp.Dial(ctx, ln.Multiaddr(), lp) - if err != nil { - t.Errorf("dial failed: %s %d", err, runtime.NumGoroutine()) + runDialer := func(ln tpt.Listener, lp peer.ID) { + tp, _ := getTransport(t) + for i := 0; i < connsPerDialer; i++ { + // We want to test for deadlocks, set a high timeout + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + conn, err := tp.Dial(ctx, ln.Multiaddr(), lp) + if err != nil { + t.Errorf("dial failed: %s", err) + errCh <- err + cancel() + return + } + runDialConn(conn) cancel() - return } - err = runDialConn(conn) - require.NoError(t, err) - cancel() + } + + for i := 0; i < numListeners; i++ { + go runListener(listeners[i]) + } + for i := 0; i < numListeners; i++ { + for j := 0; j < dialersPerListener; j++ { + go runDialer(listeners[i], listenerPeerIDs[i]) + } + } + + for i := 0; i < numListeners*dialersPerListener*connsPerDialer; i++ { select { - case <-time.After(120 * time.Second): - t.Fatalf("timed out %d", runtime.NumGoroutine()) - case <-errCh: - t.Fatal("listener error:", err, runtime.NumGoroutine()) case <-successCh: + t.Log("completed conn: ", i) + case err := <-errCh: + t.Fatalf("failed: %s", err) + case <-time.After(300 * time.Second): + t.Fatalf("timed out") } - t.Log("completed conn:", i, runtime.NumGoroutine()) } - }