From db41da3b261432f2150b4f6b035b13d9dcb6778c Mon Sep 17 00:00:00 2001 From: Marco Munizaga Date: Thu, 1 Aug 2024 07:36:46 -0700 Subject: [PATCH] feat: WebRTC reuse QUIC conn (#2889) * feat: WebRTC reuse QUIC conn * Fix transport constructor in test * Move provide to where the transports are --- config/config.go | 26 ++++++ libp2p_test.go | 86 +++++++++++++++++++ p2p/net/swarm/swarm_listen.go | 40 ++++++++- p2p/transport/quic/transport.go | 6 ++ p2p/transport/quicreuse/connmgr.go | 22 +++++ p2p/transport/quicreuse/nonquic_packetconn.go | 74 ++++++++++++++++ p2p/transport/webrtc/transport.go | 16 +++- p2p/transport/webrtc/transport_test.go | 8 +- 8 files changed, 269 insertions(+), 9 deletions(-) create mode 100644 p2p/transport/quicreuse/nonquic_packetconn.go diff --git a/config/config.go b/config/config.go index 3743979e22..4b08c076ff 100644 --- a/config/config.go +++ b/config/config.go @@ -5,6 +5,7 @@ import ( "crypto/rand" "errors" "fmt" + "net" "time" "github.com/libp2p/go-libp2p/core/connmgr" @@ -35,10 +36,12 @@ import ( relayv2 "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/relay" "github.com/libp2p/go-libp2p/p2p/protocol/holepunch" "github.com/libp2p/go-libp2p/p2p/transport/quicreuse" + libp2pwebrtc "github.com/libp2p/go-libp2p/p2p/transport/webrtc" "github.com/prometheus/client_golang/prometheus" ma "github.com/multiformats/go-multiaddr" madns "github.com/multiformats/go-multiaddr-dns" + manet "github.com/multiformats/go-multiaddr/net" "github.com/quic-go/quic-go" "go.uber.org/fx" "go.uber.org/fx/fxevent" @@ -284,6 +287,29 @@ func (cfg *Config) addTransports() ([]fx.Option, error) { fx.Provide(func() pnet.PSK { return cfg.PSK }), fx.Provide(func() network.ResourceManager { return cfg.ResourceManager }), fx.Provide(func() *madns.Resolver { return cfg.MultiaddrResolver }), + fx.Provide(func(cm *quicreuse.ConnManager, sw *swarm.Swarm) libp2pwebrtc.ListenUDPFn { + hasQuicAddrPortFor := func(network string, laddr *net.UDPAddr) bool { + quicAddrPorts := map[string]struct{}{} + for _, addr := range sw.ListenAddresses() { + if _, err := addr.ValueForProtocol(ma.P_QUIC_V1); err == nil { + netw, addr, err := manet.DialArgs(addr) + if err != nil { + return false + } + quicAddrPorts[netw+"_"+addr] = struct{}{} + } + } + _, ok := quicAddrPorts[network+"_"+laddr.String()] + return ok + } + + return func(network string, laddr *net.UDPAddr) (net.PacketConn, error) { + if hasQuicAddrPortFor(network, laddr) { + return cm.SharedNonQUICPacketConn(network, laddr) + } + return net.ListenUDP(network, laddr) + } + }), } fxopts = append(fxopts, cfg.Transports...) if cfg.Insecure { diff --git a/libp2p_test.go b/libp2p_test.go index 54a1b6688a..676af7696e 100644 --- a/libp2p_test.go +++ b/libp2p_test.go @@ -6,6 +6,8 @@ import ( "errors" "fmt" "regexp" + "strconv" + "strings" "testing" "time" @@ -24,7 +26,9 @@ import ( "github.com/libp2p/go-libp2p/p2p/security/noise" tls "github.com/libp2p/go-libp2p/p2p/security/tls" quic "github.com/libp2p/go-libp2p/p2p/transport/quic" + "github.com/libp2p/go-libp2p/p2p/transport/quicreuse" "github.com/libp2p/go-libp2p/p2p/transport/tcp" + libp2pwebrtc "github.com/libp2p/go-libp2p/p2p/transport/webrtc" webtransport "github.com/libp2p/go-libp2p/p2p/transport/webtransport" "go.uber.org/goleak" @@ -416,6 +420,7 @@ func TestMain(m *testing.M) { m, // This will return eventually (5s timeout) but doesn't take a context. goleak.IgnoreAnyFunction("github.com/koron/go-ssdp.Search"), + goleak.IgnoreAnyFunction("github.com/pion/sctp.(*Stream).SetReadDeadline.func1"), // Logging & Stats goleak.IgnoreTopFunction("github.com/ipfs/go-log/v2/writer.(*MirrorWriter).logRoutine"), goleak.IgnoreTopFunction("go.opencensus.io/stats/view.(*worker).start"), @@ -490,3 +495,84 @@ func TestHostAddrsFactoryAddsCerthashes(t *testing.T) { }, 5*time.Second, 50*time.Millisecond) h.Close() } + +func TestWebRTCReuseAddrWithQUIC(t *testing.T) { + order := [][]string{ + {"/ip4/127.0.0.1/udp/54322/quic-v1", "/ip4/127.0.0.1/udp/54322/webrtc-direct"}, + {"/ip4/127.0.0.1/udp/54322/webrtc-direct", "/ip4/127.0.0.1/udp/54322/quic-v1"}, + // We do not support WebRTC automatically reusing QUIC addresses if port is not specified, yet. + // {"/ip4/127.0.0.1/udp/0/webrtc-direct", "/ip4/127.0.0.1/udp/0/quic-v1"}, + } + for i, addrs := range order { + t.Run("Order "+strconv.Itoa(i), func(t *testing.T) { + h1, err := New(ListenAddrStrings(addrs...), Transport(quic.NewTransport), Transport(libp2pwebrtc.New)) + require.NoError(t, err) + defer h1.Close() + + seenPorts := make(map[string]struct{}) + for _, addr := range h1.Addrs() { + s, err := addr.ValueForProtocol(ma.P_UDP) + require.NoError(t, err) + seenPorts[s] = struct{}{} + } + require.Len(t, seenPorts, 1) + + quicClient, err := New(NoListenAddrs, Transport(quic.NewTransport)) + require.NoError(t, err) + defer quicClient.Close() + + webrtcClient, err := New(NoListenAddrs, Transport(libp2pwebrtc.New)) + require.NoError(t, err) + defer webrtcClient.Close() + + for _, client := range []host.Host{quicClient, webrtcClient} { + err := client.Connect(context.Background(), peer.AddrInfo{ID: h1.ID(), Addrs: h1.Addrs()}) + require.NoError(t, err) + } + + t.Run("quic client can connect", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + p := ping.NewPingService(quicClient) + resCh := p.Ping(ctx, h1.ID()) + res := <-resCh + require.NoError(t, res.Error) + }) + + t.Run("webrtc client can connect", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + p := ping.NewPingService(webrtcClient) + resCh := p.Ping(ctx, h1.ID()) + res := <-resCh + require.NoError(t, res.Error) + }) + }) + } + + swapPort := func(addrStrs []string, newPort string) []string { + out := make([]string, 0, len(addrStrs)) + for _, addrStr := range addrStrs { + out = append(out, strings.Replace(addrStr, "54322", newPort, 1)) + } + return out + } + + t.Run("setup with no reuseport. Should fail", func(t *testing.T) { + h1, err := New(ListenAddrStrings(swapPort(order[0], "54323")...), Transport(quic.NewTransport), Transport(libp2pwebrtc.New), QUICReuse(quicreuse.NewConnManager, quicreuse.DisableReuseport())) + require.NoError(t, err) // It's a bug/feature that swarm.Listen does not error if at least one transport succeeds in listening. + defer h1.Close() + // Check that webrtc did fail to listen + require.Equal(t, 1, len(h1.Addrs())) + require.Contains(t, h1.Addrs()[0].String(), "quic-v1") + }) + + t.Run("setup with autonat", func(t *testing.T) { + h1, err := New(EnableAutoNATv2(), ListenAddrStrings(swapPort(order[0], "54324")...), Transport(quic.NewTransport), Transport(libp2pwebrtc.New), QUICReuse(quicreuse.NewConnManager, quicreuse.DisableReuseport())) + require.NoError(t, err) // It's a bug/feature that swarm.Listen does not error if at least one transport succeeds in listening. + defer h1.Close() + // Check that webrtc did fail to listen + require.Equal(t, 1, len(h1.Addrs())) + require.Contains(t, h1.Addrs()[0].String(), "quic-v1") + }) +} diff --git a/p2p/net/swarm/swarm_listen.go b/p2p/net/swarm/swarm_listen.go index 0905e84513..e94db44a42 100644 --- a/p2p/net/swarm/swarm_listen.go +++ b/p2p/net/swarm/swarm_listen.go @@ -3,6 +3,7 @@ package swarm import ( "errors" "fmt" + "slices" "time" "github.com/libp2p/go-libp2p/core/canonicallog" @@ -12,13 +13,44 @@ import ( ma "github.com/multiformats/go-multiaddr" ) +type OrderedListener interface { + // Transports optionally implement this interface to indicate the relative + // ordering that listeners should be setup. Some transports may optionally + // make use of other listeners if they are setup. e.g. WebRTC may reuse the + // same UDP port as QUIC, but only when QUIC is setup first. + // lower values are setup first. + ListenOrder() int +} + // Listen sets up listeners for all of the given addresses. // It returns as long as we successfully listen on at least *one* address. func (s *Swarm) Listen(addrs ...ma.Multiaddr) error { errs := make([]error, len(addrs)) var succeeded int - for i, a := range addrs { - if err := s.AddListenAddr(a); err != nil { + + type addrAndListener struct { + addr ma.Multiaddr + lTpt transport.Transport + } + sortedAddrsAndTpts := make([]addrAndListener, 0, len(addrs)) + for _, a := range addrs { + t := s.TransportForListening(a) + sortedAddrsAndTpts = append(sortedAddrsAndTpts, addrAndListener{addr: a, lTpt: t}) + } + slices.SortFunc(sortedAddrsAndTpts, func(a, b addrAndListener) int { + aOrder := 0 + bOrder := 0 + if l, ok := a.lTpt.(OrderedListener); ok { + aOrder = l.ListenOrder() + } + if l, ok := b.lTpt.(OrderedListener); ok { + bOrder = l.ListenOrder() + } + return aOrder - bOrder + }) + + for i, a := range sortedAddrsAndTpts { + if err := s.AddListenAddr(a.addr); err != nil { errs[i] = err } else { succeeded++ @@ -27,11 +59,11 @@ func (s *Swarm) Listen(addrs ...ma.Multiaddr) error { for i, e := range errs { if e != nil { - log.Warnw("listening failed", "on", addrs[i], "error", errs[i]) + log.Warnw("listening failed", "on", sortedAddrsAndTpts[i].addr, "error", errs[i]) } } - if succeeded == 0 && len(addrs) > 0 { + if succeeded == 0 && len(sortedAddrsAndTpts) > 0 { return fmt.Errorf("failed to listen on any addresses: %s", errs) } diff --git a/p2p/transport/quic/transport.go b/p2p/transport/quic/transport.go index 18d198bbea..04b5e4d6fe 100644 --- a/p2p/transport/quic/transport.go +++ b/p2p/transport/quic/transport.go @@ -26,6 +26,8 @@ import ( "github.com/quic-go/quic-go" ) +const ListenOrder = 1 + var log = logging.Logger("quic-transport") var ErrHolePunching = errors.New("hole punching attempted; no active dial") @@ -103,6 +105,10 @@ func NewTransport(key ic.PrivKey, connManager *quicreuse.ConnManager, psk pnet.P }, nil } +func (t *transport) ListenOrder() int { + return ListenOrder +} + // Dial dials a new QUIC connection func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (_c tpt.CapableConn, _err error) { if ok, isClient, _ := network.GetSimultaneousConnect(ctx); ok && !isClient { diff --git a/p2p/transport/quicreuse/connmgr.go b/p2p/transport/quicreuse/connmgr.go index 500a9c9760..01e1bcda5e 100644 --- a/p2p/transport/quicreuse/connmgr.go +++ b/p2p/transport/quicreuse/connmgr.go @@ -154,6 +154,28 @@ func (c *ConnManager) onListenerClosed(key string) { } } +func (c *ConnManager) SharedNonQUICPacketConn(network string, laddr *net.UDPAddr) (net.PacketConn, error) { + c.quicListenersMu.Lock() + defer c.quicListenersMu.Unlock() + key := laddr.String() + entry, ok := c.quicListeners[key] + if !ok { + return nil, errors.New("expected to be able to share with a QUIC listener, but no QUIC listener found. The QUIC listener should start first") + } + t := entry.ln.transport + if t, ok := t.(*refcountedTransport); ok { + t.IncreaseCount() + ctx, cancel := context.WithCancel(context.Background()) + return &nonQUICPacketConn{ + ctx: ctx, + ctxCancel: cancel, + owningTransport: t, + tr: &t.Transport, + }, nil + } + return nil, errors.New("expected to be able to share with a QUIC listener, but the QUIC listener is not using a refcountedTransport. `DisableReuseport` should not be set") +} + func (c *ConnManager) transportForListen(network string, laddr *net.UDPAddr) (refCountedQuicTransport, error) { if c.enableReuseport { reuse, err := c.getReuse(network) diff --git a/p2p/transport/quicreuse/nonquic_packetconn.go b/p2p/transport/quicreuse/nonquic_packetconn.go new file mode 100644 index 0000000000..2f950e76a1 --- /dev/null +++ b/p2p/transport/quicreuse/nonquic_packetconn.go @@ -0,0 +1,74 @@ +package quicreuse + +import ( + "context" + "net" + "time" + + "github.com/quic-go/quic-go" +) + +// nonQUICPacketConn is a net.PacketConn that can be used to read and write +// non-QUIC packets on a quic.Transport. This lets us reuse this UDP port for +// other transports like WebRTC. +type nonQUICPacketConn struct { + owningTransport refCountedQuicTransport + tr *quic.Transport + ctx context.Context + ctxCancel context.CancelFunc + readCtx context.Context + readCancel context.CancelFunc +} + +// Close implements net.PacketConn. +func (n *nonQUICPacketConn) Close() error { + n.ctxCancel() + + // Don't actually close the underlying transport since someone else might be using it. + // reuse has it's own gc to close unused transports. + n.owningTransport.DecreaseCount() + return nil +} + +// LocalAddr implements net.PacketConn. +func (n *nonQUICPacketConn) LocalAddr() net.Addr { + return n.tr.Conn.LocalAddr() +} + +// ReadFrom implements net.PacketConn. +func (n *nonQUICPacketConn) ReadFrom(p []byte) (int, net.Addr, error) { + ctx := n.readCtx + if ctx == nil { + ctx = n.ctx + } + return n.tr.ReadNonQUICPacket(ctx, p) +} + +// SetDeadline implements net.PacketConn. +func (n *nonQUICPacketConn) SetDeadline(t time.Time) error { + // Only used for reads. + return n.SetReadDeadline(t) +} + +// SetReadDeadline implements net.PacketConn. +func (n *nonQUICPacketConn) SetReadDeadline(t time.Time) error { + if t.IsZero() && n.readCtx != nil { + n.readCancel() + n.readCtx = nil + } + n.readCtx, n.readCancel = context.WithDeadline(n.ctx, t) + return nil +} + +// SetWriteDeadline implements net.PacketConn. +func (n *nonQUICPacketConn) SetWriteDeadline(t time.Time) error { + // Unused. quic-go doesn't support deadlines for writes. + return nil +} + +// WriteTo implements net.PacketConn. +func (n *nonQUICPacketConn) WriteTo(p []byte, addr net.Addr) (int, error) { + return n.tr.WriteTo(p, addr) +} + +var _ net.PacketConn = &nonQUICPacketConn{} diff --git a/p2p/transport/webrtc/transport.go b/p2p/transport/webrtc/transport.go index db605bcc20..8aaa93e43a 100644 --- a/p2p/transport/webrtc/transport.go +++ b/p2p/transport/webrtc/transport.go @@ -26,6 +26,7 @@ import ( "github.com/libp2p/go-libp2p/core/sec" tpt "github.com/libp2p/go-libp2p/core/transport" "github.com/libp2p/go-libp2p/p2p/security/noise" + libp2pquic "github.com/libp2p/go-libp2p/p2p/transport/quic" "github.com/libp2p/go-libp2p/p2p/transport/webrtc/pb" "github.com/libp2p/go-msgio" @@ -78,6 +79,8 @@ type WebRTCTransport struct { noiseTpt *noise.Transport localPeerId peer.ID + listenUDP func(network string, laddr *net.UDPAddr) (net.PacketConn, error) + // timeouts peerConnectionTimeouts iceTimeouts @@ -95,7 +98,9 @@ type iceTimeouts struct { Keepalive time.Duration } -func New(privKey ic.PrivKey, psk pnet.PSK, gater connmgr.ConnectionGater, rcmgr network.ResourceManager, opts ...Option) (*WebRTCTransport, error) { +type ListenUDPFn func(network string, laddr *net.UDPAddr) (net.PacketConn, error) + +func New(privKey ic.PrivKey, psk pnet.PSK, gater connmgr.ConnectionGater, rcmgr network.ResourceManager, listenUDP ListenUDPFn, opts ...Option) (*WebRTCTransport, error) { if psk != nil { log.Error("WebRTC doesn't support private networks yet.") return nil, fmt.Errorf("WebRTC doesn't support private networks yet") @@ -141,6 +146,7 @@ func New(privKey ic.PrivKey, psk pnet.PSK, gater connmgr.ConnectionGater, rcmgr noiseTpt: noiseTpt, localPeerId: localPeerID, + listenUDP: listenUDP, peerConnectionTimeouts: iceTimeouts{ Disconnect: DefaultDisconnectedTimeout, Failed: DefaultFailedTimeout, @@ -157,6 +163,10 @@ func New(privKey ic.PrivKey, psk pnet.PSK, gater connmgr.ConnectionGater, rcmgr return transport, nil } +func (t *WebRTCTransport) ListenOrder() int { + return libp2pquic.ListenOrder + 1 // We want to listen after QUIC listens so we can possibly reuse the same port. +} + func (t *WebRTCTransport) Protocols() []int { return []int{ma.P_WEBRTC_DIRECT} } @@ -190,7 +200,7 @@ func (t *WebRTCTransport) Listen(addr ma.Multiaddr) (tpt.Listener, error) { return nil, fmt.Errorf("listener could not resolve udp address: %w", err) } - socket, err := net.ListenUDP(nw, udpAddr) + socket, err := t.listenUDP(nw, udpAddr) if err != nil { return nil, fmt.Errorf("listen on udp: %w", err) } @@ -203,7 +213,7 @@ func (t *WebRTCTransport) Listen(addr ma.Multiaddr) (tpt.Listener, error) { return listener, nil } -func (t *WebRTCTransport) listenSocket(socket *net.UDPConn) (tpt.Listener, error) { +func (t *WebRTCTransport) listenSocket(socket net.PacketConn) (tpt.Listener, error) { listenerMultiaddr, err := manet.FromNetAddr(socket.LocalAddr()) if err != nil { return nil, err diff --git a/p2p/transport/webrtc/transport_test.go b/p2p/transport/webrtc/transport_test.go index e457c13fdd..e6a60420be 100644 --- a/p2p/transport/webrtc/transport_test.go +++ b/p2p/transport/webrtc/transport_test.go @@ -29,12 +29,16 @@ import ( "golang.org/x/crypto/sha3" ) +var netListenUDP ListenUDPFn = func(network string, laddr *net.UDPAddr) (net.PacketConn, error) { + return net.ListenUDP(network, laddr) +} + func getTransport(t *testing.T, opts ...Option) (*WebRTCTransport, peer.ID) { t.Helper() privKey, _, err := crypto.GenerateKeyPair(crypto.Ed25519, -1) require.NoError(t, err) rcmgr := &network.NullResourceManager{} - transport, err := New(privKey, nil, nil, rcmgr, opts...) + transport, err := New(privKey, nil, nil, rcmgr, netListenUDP, opts...) require.NoError(t, err) peerID, err := peer.IDFromPrivateKey(privKey) require.NoError(t, err) @@ -45,7 +49,7 @@ func getTransport(t *testing.T, opts ...Option) (*WebRTCTransport, peer.ID) { func TestNullRcmgrTransport(t *testing.T) { privKey, _, err := crypto.GenerateKeyPair(crypto.Ed25519, -1) require.NoError(t, err) - transport, err := New(privKey, nil, nil, nil) + transport, err := New(privKey, nil, nil, nil, netListenUDP) require.NoError(t, err) listenTransport, pid := getTransport(t)