diff --git a/libp2p_test.go b/libp2p_test.go index c5f667610c..0b81e33bb5 100644 --- a/libp2p_test.go +++ b/libp2p_test.go @@ -17,6 +17,7 @@ import ( tls "github.com/libp2p/go-libp2p/p2p/security/tls" quic "github.com/libp2p/go-libp2p/p2p/transport/quic" "github.com/libp2p/go-libp2p/p2p/transport/tcp" + webtransport "github.com/libp2p/go-libp2p/p2p/transport/webtransport" ma "github.com/multiformats/go-multiaddr" "github.com/stretchr/testify/require" @@ -289,3 +290,42 @@ func TestSecurityConstructor(t *testing.T) { require.Contains(t, err.Error(), "failed to negotiate security protocol") require.NoError(t, h2.Connect(context.Background(), ai)) } + +func TestTransportConstructorWebTransport(t *testing.T) { + h, err := New( + Transport(webtransport.New), + DisableRelay(), + ) + require.NoError(t, err) + defer h.Close() + require.NoError(t, h.Network().Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport"))) + err = h.Network().Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/")) + require.Error(t, err) + require.Contains(t, err.Error(), swarm.ErrNoTransport.Error()) +} + +func TestTransportCustomAddressWebTransport(t *testing.T) { + customAddr, err := ma.NewMultiaddr("/ip4/127.0.0.1/udp/0/quic-v1/webtransport") + if err != nil { + t.Fatal(err) + } + h, err := New( + Transport(webtransport.New), + ListenAddrs(customAddr), + DisableRelay(), + AddrsFactory(func(multiaddrs []ma.Multiaddr) []ma.Multiaddr { + return []ma.Multiaddr{customAddr} + }), + ) + require.NoError(t, err) + defer h.Close() + require.NoError(t, h.Network().Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport"))) + addrs := h.Addrs() + require.Len(t, addrs, 1) + require.NotEqual(t, addrs[0], customAddr) + restOfAddr, lastComp := ma.SplitLast(addrs[0]) + restOfAddr, secondToLastComp := ma.SplitLast(restOfAddr) + require.Equal(t, lastComp.Protocol().Code, ma.P_CERTHASH) + require.Equal(t, secondToLastComp.Protocol().Code, ma.P_CERTHASH) + require.True(t, restOfAddr.Equal(customAddr)) +} diff --git a/p2p/host/basic/basic_host.go b/p2p/host/basic/basic_host.go index bee068fd9c..6db431e81e 100644 --- a/p2p/host/basic/basic_host.go +++ b/p2p/host/basic/basic_host.go @@ -18,6 +18,7 @@ import ( "github.com/libp2p/go-libp2p/core/peerstore" "github.com/libp2p/go-libp2p/core/protocol" "github.com/libp2p/go-libp2p/core/record" + "github.com/libp2p/go-libp2p/core/transport" "github.com/libp2p/go-libp2p/p2p/host/autonat" "github.com/libp2p/go-libp2p/p2p/host/eventbus" "github.com/libp2p/go-libp2p/p2p/host/pstoremanager" @@ -27,6 +28,7 @@ import ( "github.com/libp2p/go-libp2p/p2p/protocol/holepunch" "github.com/libp2p/go-libp2p/p2p/protocol/identify" "github.com/libp2p/go-libp2p/p2p/protocol/ping" + libp2pwebtransport "github.com/libp2p/go-libp2p/p2p/transport/webtransport" "github.com/prometheus/client_golang/prometheus" "github.com/libp2p/go-netroute" @@ -760,7 +762,51 @@ func (h *BasicHost) ConnManager() connmgr.ConnManager { // Addrs returns listening addresses that are safe to announce to the network. // The output is the same as AllAddrs, but processed by AddrsFactory. func (h *BasicHost) Addrs() []ma.Multiaddr { - return h.AddrsFactory(h.AllAddrs()) + // This is a temporary workaround/hack that fixes #2233. Once we have a + // proper address pipeline, rework this. See the issue for more context. + type transportForListeninger interface { + TransportForListening(a ma.Multiaddr) transport.Transport + } + type addCertHasher interface { + AddCertHashes(m ma.Multiaddr) ma.Multiaddr + } + + addrs := h.AddrsFactory(h.AllAddrs()) + + s, ok := h.Network().(transportForListeninger) + if !ok { + return addrs + } + + // Copy addrs slice since we'll be modifying it. + addrsOld := addrs + addrs = make([]ma.Multiaddr, len(addrsOld)) + copy(addrs, addrsOld) + + for i, addr := range addrs { + if ok, n := libp2pwebtransport.IsWebtransportMultiaddr(addr); ok && n == 0 { + t := s.TransportForListening(addr) + tpt, ok := t.(addCertHasher) + if !ok { + continue + } + addrs[i] = tpt.AddCertHashes(addr) + } + } + return addrs +} + +// NormalizeMultiaddr returns a multiaddr suitable for equality checks. +// If the multiaddr is a webtransport component, it removes the certhashes. +func (h *BasicHost) NormalizeMultiaddr(addr ma.Multiaddr) ma.Multiaddr { + if ok, n := libp2pwebtransport.IsWebtransportMultiaddr(addr); ok && n > 0 { + out := addr + for i := 0; i < n; i++ { + out, _ = ma.SplitLast(out) + } + return out + } + return addr } // mergeAddrs merges input address lists, leave only unique addresses diff --git a/p2p/host/basic/basic_host_test.go b/p2p/host/basic/basic_host_test.go index 12a997c3b4..ac49aad007 100644 --- a/p2p/host/basic/basic_host_test.go +++ b/p2p/host/basic/basic_host_test.go @@ -816,3 +816,11 @@ func peerRecordFromEnvelope(t *testing.T, ev *record.Envelope) *peer.PeerRecord } return peerRec } + +func TestNormalizeMultiaddr(t *testing.T) { + h1, err := NewHost(swarmt.GenSwarm(t), nil) + require.NoError(t, err) + defer h1.Close() + + require.Equal(t, "/ip4/1.2.3.4/udp/9999/quic-v1/webtransport", h1.NormalizeMultiaddr(ma.StringCast("/ip4/1.2.3.4/udp/9999/quic-v1/webtransport/certhash/uEgNmb28")).String()) +} diff --git a/p2p/protocol/identify/obsaddr.go b/p2p/protocol/identify/obsaddr.go index 9b20ee4f61..deec772bd8 100644 --- a/p2p/protocol/identify/obsaddr.go +++ b/p2p/protocol/identify/obsaddr.go @@ -356,6 +356,10 @@ func (oas *ObservedAddrManager) removeConn(conn network.Conn) { oas.activeConnsMu.Unlock() } +type normalizeMultiaddrer interface { + NormalizeMultiaddr(addr ma.Multiaddr) ma.Multiaddr +} + func (oas *ObservedAddrManager) maybeRecordObservation(conn network.Conn, observed ma.Multiaddr) { // First, determine if this observation is even worth keeping... @@ -375,16 +379,40 @@ func (oas *ObservedAddrManager) maybeRecordObservation(conn network.Conn, observ return } + normalizer, canNormalize := oas.host.(normalizeMultiaddrer) + + if canNormalize { + for i, a := range ifaceaddrs { + ifaceaddrs[i] = normalizer.NormalizeMultiaddr(a) + } + } + local := conn.LocalMultiaddr() + if canNormalize { + local = normalizer.NormalizeMultiaddr(local) + } if !ma.Contains(ifaceaddrs, local) && !ma.Contains(oas.host.Network().ListenAddresses(), local) { // not in our list return } + hostAddrs := oas.host.Addrs() + if canNormalize { + for i, a := range hostAddrs { + hostAddrs[i] = normalizer.NormalizeMultiaddr(a) + } + } + listenAddrs := oas.host.Network().ListenAddresses() + if canNormalize { + for i, a := range listenAddrs { + listenAddrs[i] = normalizer.NormalizeMultiaddr(a) + } + } + // We should reject the connection if the observation doesn't match the // transports of one of our advertised addresses. - if !HasConsistentTransport(observed, oas.host.Addrs()) && - !HasConsistentTransport(observed, oas.host.Network().ListenAddresses()) { + if !HasConsistentTransport(observed, hostAddrs) && + !HasConsistentTransport(observed, listenAddrs) { log.Debugw( "observed multiaddr doesn't match the transports of any announced addresses", "from", conn.RemoteMultiaddr(), diff --git a/p2p/protocol/identify/obsaddr_test.go b/p2p/protocol/identify/obsaddr_test.go index e807dd08be..2987b01bd8 100644 --- a/p2p/protocol/identify/obsaddr_test.go +++ b/p2p/protocol/identify/obsaddr_test.go @@ -83,10 +83,14 @@ func (h *harness) observeInbound(observed ma.Multiaddr, observer peer.ID) networ } func newHarness(t *testing.T) harness { + return newHarnessWithMa(t, ma.StringCast("/ip4/127.0.0.1/tcp/10086")) +} + +func newHarnessWithMa(t *testing.T, listenAddr ma.Multiaddr) harness { mn := mocknet.New() sk, _, err := ic.GenerateECDSAKeyPair(rand.Reader) require.NoError(t, err) - h, err := mn.AddPeer(sk, ma.StringCast("/ip4/127.0.0.1/tcp/10086")) + h, err := mn.AddPeer(sk, listenAddr) require.NoError(t, err) oas, err := identify.NewObservedAddrManager(h) require.NoError(t, err) @@ -408,3 +412,25 @@ func TestEmitNATDeviceTypeCone(t *testing.T) { t.Fatal("did not get Cone NAT event") } } + +func TestObserveWebtransport(t *testing.T) { + listenAddr := ma.StringCast("/ip4/1.2.3.4/udp/9999/quic-v1/webtransport/certhash/uEgNmb28") + observedAddr := ma.StringCast("/ip4/1.2.3.4/udp/1231/quic-v1/webtransport") + + harness := newHarnessWithMa(t, listenAddr) + + pb1 := harness.add(ma.StringCast("/ip4/1.2.3.6/udp/1236/quic-v1/webtransport")) + pb2 := harness.add(ma.StringCast("/ip4/1.2.3.7/udp/1237/quic-v1/webtransport")) + pb3 := harness.add(ma.StringCast("/ip4/1.2.3.8/udp/1237/quic-v1/webtransport")) + pb4 := harness.add(ma.StringCast("/ip4/1.2.3.9/udp/1237/quic-v1/webtransport")) + pb5 := harness.add(ma.StringCast("/ip4/1.2.3.10/udp/1237/quic-v1/webtransport")) + + harness.observe(observedAddr, pb1) + harness.observe(observedAddr, pb2) + harness.observe(observedAddr, pb3) + harness.observe(observedAddr, pb4) + harness.observe(observedAddr, pb5) + + require.Equal(t, 1, len(harness.oas.Addrs())) + require.Equal(t, "/ip4/1.2.3.4/udp/1231/quic-v1/webtransport", harness.oas.Addrs()[0].String()) +} diff --git a/p2p/transport/webtransport/multiaddr.go b/p2p/transport/webtransport/multiaddr.go index b6b79336c8..d6930af363 100644 --- a/p2p/transport/webtransport/multiaddr.go +++ b/p2p/transport/webtransport/multiaddr.go @@ -7,7 +7,6 @@ import ( "strconv" ma "github.com/multiformats/go-multiaddr" - mafmt "github.com/multiformats/go-multiaddr-fmt" manet "github.com/multiformats/go-multiaddr/net" "github.com/multiformats/go-multibase" "github.com/multiformats/go-multihash" @@ -15,8 +14,6 @@ import ( var webtransportMA = ma.StringCast("/quic-v1/webtransport") -var webtransportMatcher = mafmt.And(mafmt.IP, mafmt.Base(ma.P_UDP), mafmt.Base(ma.P_QUIC_V1), mafmt.Base(ma.P_WEBTRANSPORT)) - func toWebtransportMultiaddr(na net.Addr) (ma.Multiaddr, error) { addr, err := manet.FromNetAddr(na) if err != nil { @@ -78,3 +75,33 @@ func addrComponentForCert(hash []byte) (ma.Multiaddr, error) { } return ma.NewComponent(ma.ProtocolWithCode(ma.P_CERTHASH).Name, certStr) } + +// IsWebtransportMultiaddr returns true if the given multiaddr is a well formed +// webtransport multiaddr. Returns the number of certhashes found. +func IsWebtransportMultiaddr(multiaddr ma.Multiaddr) (bool, int) { + const ( + init = iota + foundUDP + foundQuicV1 + foundWebTransport + ) + state := init + certhashCount := 0 + + ma.ForEach(multiaddr, func(c ma.Component) bool { + if c.Protocol().Code == ma.P_QUIC_V1 && state == init { + state = foundUDP + } + if c.Protocol().Code == ma.P_QUIC_V1 && state == foundUDP { + state = foundQuicV1 + } + if c.Protocol().Code == ma.P_WEBTRANSPORT && state == foundQuicV1 { + state = foundWebTransport + } + if c.Protocol().Code == ma.P_CERTHASH && state == foundWebTransport { + certhashCount++ + } + return true + }) + return state == foundWebTransport, certhashCount +} diff --git a/p2p/transport/webtransport/multiaddr_test.go b/p2p/transport/webtransport/multiaddr_test.go index ae3ebc4a3e..4e9d7c336c 100644 --- a/p2p/transport/webtransport/multiaddr_test.go +++ b/p2p/transport/webtransport/multiaddr_test.go @@ -103,3 +103,28 @@ func TestWebtransportResolve(t *testing.T) { require.Error(t, err) }) } + +func TestIsWebtransportMultiaddr(t *testing.T) { + fooHash := encodeCertHash(t, []byte("foo"), multihash.SHA2_256, multibase.Base58BTC) + barHash := encodeCertHash(t, []byte("bar"), multihash.SHA2_256, multibase.Base58BTC) + + testCases := []struct { + addr string + want bool + certhashCount int + }{ + {addr: "/ip4/1.2.3.4/udp/60042/quic-v1/webtransport", want: true}, + {addr: "/ip4/1.2.3.4/udp/60042/quic-v1/webtransport/certhash/" + fooHash, want: true, certhashCount: 1}, + {addr: "/ip4/1.2.3.4/udp/60042/quic-v1/webtransport/certhash/" + fooHash + "/certhash/" + barHash, want: true, certhashCount: 2}, + {addr: "/dns4/example.com/udp/60042/quic-v1/webtransport/certhash/" + fooHash, want: true, certhashCount: 1}, + {addr: "/dns4/example.com/udp/60042/webrtc/certhash/" + fooHash, want: false}, + } + + for _, tc := range testCases { + t.Run(tc.addr, func(t *testing.T) { + got, n := IsWebtransportMultiaddr(ma.StringCast(tc.addr)) + require.Equal(t, tc.want, got) + require.Equal(t, tc.certhashCount, n) + }) + } +} diff --git a/p2p/transport/webtransport/transport.go b/p2p/transport/webtransport/transport.go index aef58a28c5..e44ac33252 100644 --- a/p2p/transport/webtransport/transport.go +++ b/p2p/transport/webtransport/transport.go @@ -68,11 +68,12 @@ type transport struct { rcmgr network.ResourceManager gater connmgr.ConnectionGater - listenOnce sync.Once - listenOnceErr error - certManager *certManager - staticTLSConf *tls.Config - tlsClientConf *tls.Config + listenOnce sync.Once + listenOnceErr error + certManager *certManager + certManagerReady chan struct{} // Closed when the certManager has been instantiated. + staticTLSConf *tls.Config + tlsClientConf *tls.Config noise *noise.Transport @@ -97,13 +98,14 @@ func New(key ic.PrivKey, psk pnet.PSK, connManager *quicreuse.ConnManager, gater return nil, err } t := &transport{ - pid: id, - privKey: key, - rcmgr: rcmgr, - gater: gater, - clock: clock.New(), - connManager: connManager, - conns: map[uint64]*conn{}, + pid: id, + privKey: key, + rcmgr: rcmgr, + gater: gater, + clock: clock.New(), + connManager: connManager, + conns: map[uint64]*conn{}, + certManagerReady: make(chan struct{}), } for _, opt := range opts { if err := opt(t); err != nil { @@ -286,33 +288,25 @@ func decodeCertHashesFromProtobuf(b [][]byte) ([]multihash.DecodedMultihash, err } func (t *transport) CanDial(addr ma.Multiaddr) bool { - var numHashes int - ma.ForEach(addr, func(c ma.Component) bool { - if c.Protocol().Code == ma.P_CERTHASH { - numHashes++ - } - return true - }) - // Remove the /certhash components from the multiaddr. - // If the multiaddr doesn't contain any certhashes, the node might have a CA-signed certificate. - for i := 0; i < numHashes; i++ { - addr, _ = ma.SplitLast(addr) - } - return webtransportMatcher.Matches(addr) + ok, _ := IsWebtransportMultiaddr(addr) + return ok } func (t *transport) Listen(laddr ma.Multiaddr) (tpt.Listener, error) { - if !webtransportMatcher.Matches(laddr) { + isWebTransport, _ := IsWebtransportMultiaddr(laddr) + if !isWebTransport { return nil, fmt.Errorf("cannot listen on non-WebTransport addr: %s", laddr) } if t.staticTLSConf == nil { t.listenOnce.Do(func() { t.certManager, t.listenOnceErr = newCertManager(t.privKey, t.clock) + close(t.certManagerReady) }) if t.listenOnceErr != nil { return nil, t.listenOnceErr } } else { + close(t.certManagerReady) return nil, errors.New("static TLS config not supported on WebTransport") } tlsConf := t.staticTLSConf.Clone() @@ -410,3 +404,11 @@ func (t *transport) Resolve(_ context.Context, maddr ma.Multiaddr) ([]ma.Multiad } return []ma.Multiaddr{beforeQuicMA.Encapsulate(quicComponent).Encapsulate(sniComponent).Encapsulate(afterQuicMA)}, nil } + +func (t *transport) AddCertHashes(m ma.Multiaddr) ma.Multiaddr { + <-t.certManagerReady + if t.certManager == nil { + return m + } + return m.Encapsulate(t.certManager.AddrComponent()) +}