Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

correctly handle WebTransport addresses without certhashes #2239

Merged
merged 15 commits into from
Apr 6, 2023
39 changes: 39 additions & 0 deletions libp2p_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
tls "github.com/libp2p/go-libp2p/p2p/security/tls"
quic "github.com/libp2p/go-libp2p/p2p/transport/quic"
"github.com/libp2p/go-libp2p/p2p/transport/tcp"
webtransport "github.com/libp2p/go-libp2p/p2p/transport/webtransport"

ma "github.com/multiformats/go-multiaddr"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -289,3 +290,41 @@ func TestSecurityConstructor(t *testing.T) {
require.Contains(t, err.Error(), "failed to negotiate security protocol")
require.NoError(t, h2.Connect(context.Background(), ai))
}

func TestTransportConstructorWebTransport(t *testing.T) {
h, err := New(
Transport(webtransport.New),
DisableRelay(),
)
require.NoError(t, err)
defer h.Close()
require.NoError(t, h.Network().Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport")))
err = h.Network().Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/"))
require.Error(t, err)
require.Contains(t, err.Error(), swarm.ErrNoTransport.Error())
}

func TestTransportCustomAddressWebTransport(t *testing.T) {
customAddr, err := ma.NewMultiaddr("/ip4/1.2.3.4/udp/1234/quic-v1/webtransport")
if err != nil {
t.Fatal(err)
}
h, err := New(
Transport(webtransport.New),
DisableRelay(),
AddrsFactory(func(multiaddrs []ma.Multiaddr) []ma.Multiaddr {
return []ma.Multiaddr{customAddr}
}),
)
require.NoError(t, err)
defer h.Close()
require.NoError(t, h.Network().Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport")))
addrs := h.Addrs()
require.Len(t, addrs, 1)
require.NotEqual(t, addrs[0], customAddr)
restOfAddr, lastComp := ma.SplitLast(addrs[0])
restOfAddr, secondToLastComp := ma.SplitLast(restOfAddr)
require.Equal(t, lastComp.Protocol().Code, ma.P_CERTHASH)
require.Equal(t, secondToLastComp.Protocol().Code, ma.P_CERTHASH)
require.True(t, restOfAddr.Equal(customAddr))
}
1 change: 1 addition & 0 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,7 @@ func AddrsFactory(factory config.AddrsFactory) Option {
if cfg.AddrsFactory != nil {
return fmt.Errorf("cannot specify multiple address factories")
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change

cfg.AddrsFactory = factory
return nil
}
Expand Down
46 changes: 45 additions & 1 deletion p2p/host/basic/basic_host.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"github.com/libp2p/go-libp2p/core/peerstore"
"github.com/libp2p/go-libp2p/core/protocol"
"github.com/libp2p/go-libp2p/core/record"
"github.com/libp2p/go-libp2p/core/transport"
"github.com/libp2p/go-libp2p/p2p/host/autonat"
"github.com/libp2p/go-libp2p/p2p/host/eventbus"
"github.com/libp2p/go-libp2p/p2p/host/pstoremanager"
Expand All @@ -27,11 +28,13 @@ import (
"github.com/libp2p/go-libp2p/p2p/protocol/holepunch"
"github.com/libp2p/go-libp2p/p2p/protocol/identify"
"github.com/libp2p/go-libp2p/p2p/protocol/ping"
libp2pwebtransport "github.com/libp2p/go-libp2p/p2p/transport/webtransport"
"github.com/prometheus/client_golang/prometheus"

"github.com/libp2p/go-netroute"

logging "github.com/ipfs/go-log/v2"
"github.com/multiformats/go-multiaddr"
ma "github.com/multiformats/go-multiaddr"
madns "github.com/multiformats/go-multiaddr-dns"
manet "github.com/multiformats/go-multiaddr/net"
Expand Down Expand Up @@ -760,7 +763,48 @@ func (h *BasicHost) ConnManager() connmgr.ConnManager {
// Addrs returns listening addresses that are safe to announce to the network.
// The output is the same as AllAddrs, but processed by AddrsFactory.
func (h *BasicHost) Addrs() []ma.Multiaddr {
return h.AddrsFactory(h.AllAddrs())
type transportForListeninger interface {
MarcoPolo marked this conversation as resolved.
Show resolved Hide resolved
TransportForListening(a ma.Multiaddr) transport.Transport
}
type addCertHasher interface {
AddCertHashes(m ma.Multiaddr) ma.Multiaddr
}

addrs := h.AddrsFactory(h.AllAddrs())

s, ok := h.Network().(transportForListeninger)
if !ok {
return addrs
}

for i, addr := range addrs {
if ok, n := libp2pwebtransport.IsWebtransportMultiaddr(addr); ok && n == 0 {
t := s.TransportForListening(addr)
tpt, ok := t.(addCertHasher)
if !ok {
continue
}
addrs[i] = tpt.AddCertHashes(addr)
}
}
return addrs
}

// NormalizeMultiaddr returns a multiaddr suitable for equality checks.
// If the multiaddr is a webtransport component, it removes the certhashes.
func (h *BasicHost) NormalizeMultiaddr(addr ma.Multiaddr) ma.Multiaddr {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume this would move into the address pipeline, once we have it, right?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure! Anywhere that isn't adding another method to BasicHost

if ok, n := libp2pwebtransport.IsWebtransportMultiaddr(addr); ok && n > 0 {
var firstCerthash ma.Multiaddr
MarcoPolo marked this conversation as resolved.
Show resolved Hide resolved
multiaddr.ForEach(addr, func(c ma.Component) bool {
if c.Protocol().Code == ma.P_CERTHASH {
firstCerthash = &c
return false
}
return true
})
return addr.Decapsulate(firstCerthash)
}
return addr
}

// mergeAddrs merges input address lists, leave only unique addresses
Expand Down
32 changes: 30 additions & 2 deletions p2p/protocol/identify/obsaddr.go
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,10 @@ func (oas *ObservedAddrManager) removeConn(conn network.Conn) {
oas.activeConnsMu.Unlock()
}

type normalizeMultiaddrer interface {
NormalizeMultiaddr(addr ma.Multiaddr) ma.Multiaddr
}

func (oas *ObservedAddrManager) maybeRecordObservation(conn network.Conn, observed ma.Multiaddr) {
// First, determine if this observation is even worth keeping...

Expand All @@ -375,16 +379,40 @@ func (oas *ObservedAddrManager) maybeRecordObservation(conn network.Conn, observ
return
}

normalizer, canNormalize := oas.host.(normalizeMultiaddrer)

if canNormalize {
for i, a := range ifaceaddrs {
ifaceaddrs[i] = normalizer.NormalizeMultiaddr(a)
}
}

local := conn.LocalMultiaddr()
if canNormalize {
local = normalizer.NormalizeMultiaddr(local)
}
if !ma.Contains(ifaceaddrs, local) && !ma.Contains(oas.host.Network().ListenAddresses(), local) {
// not in our list
return
}

hostAddrs := oas.host.Addrs()
if canNormalize {
for i, a := range hostAddrs {
hostAddrs[i] = normalizer.NormalizeMultiaddr(a)
}
}
listenAddrs := oas.host.Network().ListenAddresses()
if canNormalize {
for i, a := range listenAddrs {
listenAddrs[i] = normalizer.NormalizeMultiaddr(a)
}
}

// We should reject the connection if the observation doesn't match the
// transports of one of our advertised addresses.
if !HasConsistentTransport(observed, oas.host.Addrs()) &&
!HasConsistentTransport(observed, oas.host.Network().ListenAddresses()) {
if !HasConsistentTransport(observed, hostAddrs) &&
!HasConsistentTransport(observed, listenAddrs) {
log.Debugw(
"observed multiaddr doesn't match the transports of any announced addresses",
"from", conn.RemoteMultiaddr(),
Expand Down
28 changes: 27 additions & 1 deletion p2p/protocol/identify/obsaddr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,14 @@ func (h *harness) observeInbound(observed ma.Multiaddr, observer peer.ID) networ
}

func newHarness(t *testing.T) harness {
return newHarnessWithMa(t, ma.StringCast("/ip4/127.0.0.1/tcp/10086"))
}

func newHarnessWithMa(t *testing.T, listenAddr ma.Multiaddr) harness {
mn := mocknet.New()
sk, _, err := ic.GenerateECDSAKeyPair(rand.Reader)
require.NoError(t, err)
h, err := mn.AddPeer(sk, ma.StringCast("/ip4/127.0.0.1/tcp/10086"))
h, err := mn.AddPeer(sk, listenAddr)
require.NoError(t, err)
oas, err := identify.NewObservedAddrManager(h)
require.NoError(t, err)
Expand Down Expand Up @@ -408,3 +412,25 @@ func TestEmitNATDeviceTypeCone(t *testing.T) {
t.Fatal("did not get Cone NAT event")
}
}

func TestObserveWebtransport(t *testing.T) {
listenAddr := ma.StringCast("/ip4/1.2.3.4/udp/9999/quic-v1/webtransport/certhash/uEgNmb28")
observedAddr := ma.StringCast("/ip4/1.2.3.4/udp/1231/quic-v1/webtransport")

harness := newHarnessWithMa(t, listenAddr)

pb1 := harness.add(ma.StringCast("/ip4/1.2.3.6/udp/1236/quic-v1/webtransport"))
pb2 := harness.add(ma.StringCast("/ip4/1.2.3.7/udp/1237/quic-v1/webtransport"))
pb3 := harness.add(ma.StringCast("/ip4/1.2.3.8/udp/1237/quic-v1/webtransport"))
pb4 := harness.add(ma.StringCast("/ip4/1.2.3.9/udp/1237/quic-v1/webtransport"))
pb5 := harness.add(ma.StringCast("/ip4/1.2.3.10/udp/1237/quic-v1/webtransport"))

harness.observe(observedAddr, pb1)
harness.observe(observedAddr, pb2)
harness.observe(observedAddr, pb3)
harness.observe(observedAddr, pb4)
harness.observe(observedAddr, pb5)

require.Equal(t, 1, len(harness.oas.Addrs()))
require.Equal(t, "/ip4/1.2.3.4/udp/1231/quic-v1/webtransport", harness.oas.Addrs()[0].String())
}
40 changes: 37 additions & 3 deletions p2p/transport/webtransport/multiaddr.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,13 @@ import (
"strconv"

ma "github.com/multiformats/go-multiaddr"
mafmt "github.com/multiformats/go-multiaddr-fmt"
manet "github.com/multiformats/go-multiaddr/net"
"github.com/multiformats/go-multibase"
"github.com/multiformats/go-multihash"
)

var webtransportMA = ma.StringCast("/quic-v1/webtransport")

var webtransportMatcher = mafmt.And(mafmt.IP, mafmt.Base(ma.P_UDP), mafmt.Base(ma.P_QUIC_V1), mafmt.Base(ma.P_WEBTRANSPORT))

func toWebtransportMultiaddr(na net.Addr) (ma.Multiaddr, error) {
addr, err := manet.FromNetAddr(na)
if err != nil {
Expand Down Expand Up @@ -78,3 +75,40 @@ func addrComponentForCert(hash []byte) (ma.Multiaddr, error) {
}
return ma.NewComponent(ma.ProtocolWithCode(ma.P_CERTHASH).Name, certStr)
}

// IsWebtransportMultiaddrWithCerthash returns true if the given multiaddr is a
// well formed webtransport multiaddr. Returns the number of certhashes found.
func IsWebtransportMultiaddr(multiaddr ma.Multiaddr) (bool, int) {
const (
init = iota
foundUDP
foundQuicV1
foundWebTransport
)
state := init
certhashCount := 0

ma.ForEach(multiaddr, func(c ma.Component) bool {
if c.Protocol().Code == ma.P_QUIC_V1 && state == init {
state = foundUDP
}
if c.Protocol().Code == ma.P_QUIC_V1 && state == foundUDP {
state = foundQuicV1
}
if c.Protocol().Code == ma.P_WEBTRANSPORT && state == foundQuicV1 {
state = foundWebTransport
}
if c.Protocol().Code == ma.P_CERTHASH && state == foundWebTransport {
certhashCount++
}
return true
})
return state == foundWebTransport, certhashCount
}

// IsWebtransportMultiaddrWithCerthash returns true if the given multiaddr is a
// well formed webtransport multiaddr with a certificate hash.
func IsWebtransportMultiaddrWithCerthash(multiaddr ma.Multiaddr) bool {
ok, n := IsWebtransportMultiaddr(multiaddr)
return ok && n > 0
}
26 changes: 26 additions & 0 deletions p2p/transport/webtransport/multiaddr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,29 @@ func TestWebtransportResolve(t *testing.T) {
require.Error(t, err)
})
}

func TestIsWebtransportMultiaddrWithCerthash(t *testing.T) {
fooHash := encodeCertHash(t, []byte("foo"), multihash.SHA2_256, multibase.Base58BTC)
barHash := encodeCertHash(t, []byte("bar"), multihash.SHA2_256, multibase.Base58BTC)

testCases := []struct {
addr string
want bool
certhashCount int
}{
{addr: "/ip4/1.2.3.4/udp/60042/quic-v1/webtransport", want: true},
{addr: "/ip4/1.2.3.4/udp/60042/quic-v1/webtransport/certhash/" + fooHash, want: true, certhashCount: 1},
{addr: "/ip4/1.2.3.4/udp/60042/quic-v1/webtransport/certhash/" + fooHash + "/certhash/" + barHash, want: true, certhashCount: 2},
{addr: "/dns4/example.com/udp/60042/quic-v1/webtransport/certhash/" + fooHash, want: true, certhashCount: 1},
{addr: "/dns4/example.com/udp/60042/webrtc/certhash/" + fooHash, want: false},
}

for _, tc := range testCases {
t.Run(tc.addr, func(t *testing.T) {
got, n := IsWebtransportMultiaddr(ma.StringCast(tc.addr))
require.Equal(t, tc.want, got)
require.Equal(t, tc.certhashCount, n)
require.Equal(t, tc.want && n > 0, IsWebtransportMultiaddrWithCerthash(ma.StringCast(tc.addr)))
})
}
}
25 changes: 11 additions & 14 deletions p2p/transport/webtransport/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -286,23 +286,13 @@ func decodeCertHashesFromProtobuf(b [][]byte) ([]multihash.DecodedMultihash, err
}

func (t *transport) CanDial(addr ma.Multiaddr) bool {
var numHashes int
ma.ForEach(addr, func(c ma.Component) bool {
if c.Protocol().Code == ma.P_CERTHASH {
numHashes++
}
return true
})
// Remove the /certhash components from the multiaddr.
// If the multiaddr doesn't contain any certhashes, the node might have a CA-signed certificate.
for i := 0; i < numHashes; i++ {
addr, _ = ma.SplitLast(addr)
}
return webtransportMatcher.Matches(addr)
ok, _ := IsWebtransportMultiaddr(addr)
return ok
}

func (t *transport) Listen(laddr ma.Multiaddr) (tpt.Listener, error) {
if !webtransportMatcher.Matches(laddr) {
isWebTransport, _ := IsWebtransportMultiaddr(laddr)
if !isWebTransport {
return nil, fmt.Errorf("cannot listen on non-WebTransport addr: %s", laddr)
}
if t.staticTLSConf == nil {
Expand Down Expand Up @@ -410,3 +400,10 @@ func (t *transport) Resolve(_ context.Context, maddr ma.Multiaddr) ([]ma.Multiad
}
return []ma.Multiaddr{beforeQuicMA.Encapsulate(quicComponent).Encapsulate(sniComponent).Encapsulate(afterQuicMA)}, nil
}

func (t *transport) AddCertHashes(m ma.Multiaddr) ma.Multiaddr {
if t.certManager == nil {
return m
}
return m.Encapsulate(t.certManager.AddrComponent())
}