Skip to content

Commit

Permalink
add tests for present behavior
Browse files Browse the repository at this point in the history
  • Loading branch information
sukunrt committed Dec 30, 2024
1 parent fd76100 commit 5fcce02
Show file tree
Hide file tree
Showing 3 changed files with 251 additions and 78 deletions.
56 changes: 29 additions & 27 deletions p2p/host/basic/address_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (

"github.com/libp2p/go-libp2p/core/event"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/record"
"github.com/libp2p/go-libp2p/core/transport"
"github.com/libp2p/go-libp2p/p2p/host/basic/internal/backoff"
libp2pwebrtc "github.com/libp2p/go-libp2p/p2p/transport/webrtc"
Expand All @@ -19,8 +18,6 @@ import (
manet "github.com/multiformats/go-multiaddr/net"
)

type peerRecordFunc func([]ma.Multiaddr) (*record.Envelope, error)

type observedAddrsService interface {
OwnObservedAddrs() []ma.Multiaddr
ObservedAddrsFor(local ma.Multiaddr) []ma.Multiaddr
Expand All @@ -34,12 +31,13 @@ type addressService struct {
addrsChangeChan chan struct{}
addrsUpdated chan struct{}
autoRelayAddrsSub event.Subscription
autoRelayAddrs func() []ma.Multiaddr
reachability func() network.Reachability
ifaceAddrs *interfaceAddrsCache
wg sync.WaitGroup
ctx context.Context
ctxCancel context.CancelFunc
// There are wrapped in to functions for mocking
autoRelayAddrs func() []ma.Multiaddr
reachability func() network.Reachability
ifaceAddrs *interfaceAddrsCache
wg sync.WaitGroup
ctx context.Context
ctxCancel context.CancelFunc
}

func NewAddressService(h *BasicHost, natmgr func(network.Network) NATManager,
Expand Down Expand Up @@ -177,19 +175,22 @@ func (a *addressService) AllAddrs() []ma.Multiaddr {

finalAddrs := make([]ma.Multiaddr, 0, 8)
finalAddrs = a.appendInterfaceAddrs(finalAddrs, listenAddrs)

// use nat mappings if we have them
finalAddrs = a.appendNATAddrs(finalAddrs, listenAddrs)
finalAddrs = ma.Unique(finalAddrs)

// Remove /p2p-circuit addresses from the list.
// The p2p-circuit transport listener reports its address as just /p2p-circuit
// This is useless for dialing. Users need to manage their circuit addresses themselves,
// Remove "/p2p-circuit" addresses from the list.
// The p2p-circuit listener reports its address as just /p2p-circuit. This is
// useless for dialing. Users need to manage their circuit addresses themselves,
// or use AutoRelay.
finalAddrs = slices.DeleteFunc(finalAddrs, func(a ma.Multiaddr) bool {
return a.Equal(p2pCircuitAddr)
})

// Remove any unspecified address from the list
finalAddrs = slices.DeleteFunc(finalAddrs, func(a ma.Multiaddr) bool {
return manet.IsIPUnspecified(a)
})

// Add certhashes for /webrtc-direct, /webtransport, etc addresses discovered
// using identify.
finalAddrs = a.addCertHashes(finalAddrs)
Expand All @@ -208,19 +209,23 @@ func (a *addressService) appendInterfaceAddrs(result []ma.Multiaddr, listenAddrs
return result
}

// appendNATAddrs appends the NAT-ed addrs for the listenAddrs. For unspecified listen addrs it appends the
// public address for all the interfaces.
// This automatically infers addresses from other transport addresses. For example, it'll infer a webtransport
// address from a quic observed address.
//
// TODO: Merge the natmgr and identify.ObservedAddrManager in to one NatMapper module.
func (a *addressService) appendNATAddrs(result []ma.Multiaddr, listenAddrs []ma.Multiaddr) []ma.Multiaddr {
ifaceAddrs := a.ifaceAddrs.All()
// use nat mappings if we have them
if a.natmgr != nil && a.natmgr.HasDiscoveredNAT() {
// we have a NAT device
for _, listen := range listenAddrs {
extMaddr := a.natmgr.GetMapping(listen)
result = appendNATAddrsForListenAddrs(result, listen, extMaddr, a.observedAddrsService.ObservedAddrsFor, ifaceAddrs)
}
} else {
if a.natmgr == nil || !a.natmgr.HasDiscoveredNAT() {
if a.observedAddrsService != nil {
result = append(result, a.observedAddrsService.OwnObservedAddrs()...)
}
return result
}
for _, listen := range listenAddrs {
extMaddr := a.natmgr.GetMapping(listen)
result = appendNATAddrsForListenAddrs(result, listen, extMaddr, a.observedAddrsService.ObservedAddrsFor, ifaceAddrs)
}
return result
}
Expand All @@ -241,11 +246,6 @@ func (a *addressService) addCertHashes(addrs []ma.Multiaddr) []ma.Multiaddr {
return addrs
}

// Copy addrs slice since we'll be modifying it.
addrsOld := addrs
addrs = make([]ma.Multiaddr, len(addrsOld))
copy(addrs, addrsOld)

for i, addr := range addrs {
wtOK, wtN := libp2pwebtransport.IsWebtransportMultiaddr(addr)
webrtcOK, webrtcN := libp2pwebrtc.IsWebRTCDirectMultiaddr(addr)
Expand Down Expand Up @@ -411,6 +411,8 @@ func (i *interfaceAddrsCache) updateUnlocked() {
}
}

// getAllPossibleLocalAddrs gives all the possible address returned for `conn.LocalAddr` correspoinding
// to the `listenAddr`
func getAllPossibleLocalAddrs(listenAddr ma.Multiaddr, ifaceAddrs []ma.Multiaddr) []ma.Multiaddr {
// If the nat mapping fails, use the observed addrs
resolved, err := manet.ResolveUnspecifiedAddress(listenAddr, ifaceAddrs)
Expand Down
174 changes: 174 additions & 0 deletions p2p/host/basic/address_service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@ package basichost

import (
"testing"
"time"

"github.com/libp2p/go-libp2p/core/network"
swarmt "github.com/libp2p/go-libp2p/p2p/net/swarm/testing"
ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr/net"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -96,3 +99,174 @@ func TestAppendNATAddrs(t *testing.T) {
})
}
}

type mockNatManager struct {
GetMappingFunc func(addr ma.Multiaddr) ma.Multiaddr
HasDiscoveredNATFunc func() bool
}

func (m *mockNatManager) Close() error {
return nil
}

func (m *mockNatManager) GetMapping(addr ma.Multiaddr) ma.Multiaddr {
return m.GetMappingFunc(addr)
}

func (m *mockNatManager) HasDiscoveredNAT() bool {
return m.HasDiscoveredNATFunc()
}

var _ NATManager = &mockNatManager{}

type mockObservedAddrs struct {
OwnObservedAddrsFunc func() []ma.Multiaddr
ObservedAddrsForFunc func(ma.Multiaddr) []ma.Multiaddr
}

func (m *mockObservedAddrs) OwnObservedAddrs() []ma.Multiaddr {
return m.OwnObservedAddrsFunc()
}

func (m *mockObservedAddrs) ObservedAddrsFor(local ma.Multiaddr) []ma.Multiaddr {
return m.ObservedAddrsForFunc(local)
}

func TestAddressService(t *testing.T) {
getAddrService := func() *addressService {
h, err := NewHost(swarmt.GenSwarm(t), &HostOpts{DisableIdentifyAddressDiscovery: true})
require.NoError(t, err)
t.Cleanup(func() { h.Close() })

as := h.addressService
return as
}

t.Run("NAT Address", func(t *testing.T) {
as := getAddrService()
as.natmgr = &mockNatManager{
HasDiscoveredNATFunc: func() bool { return true },
GetMappingFunc: func(addr ma.Multiaddr) ma.Multiaddr {
if _, err := addr.ValueForProtocol(ma.P_UDP); err == nil {
return ma.StringCast("/ip4/1.2.3.4/udp/1/quic-v1")
}
return nil
},
}
require.Contains(t, as.Addrs(), ma.StringCast("/ip4/1.2.3.4/udp/1/quic-v1"))
})

t.Run("NAT And Observed Address", func(t *testing.T) {
as := getAddrService()
as.natmgr = &mockNatManager{
HasDiscoveredNATFunc: func() bool { return true },
GetMappingFunc: func(addr ma.Multiaddr) ma.Multiaddr {
if _, err := addr.ValueForProtocol(ma.P_UDP); err == nil {
return ma.StringCast("/ip4/1.2.3.4/udp/1/quic-v1")
}
return nil
},
}
as.observedAddrsService = &mockObservedAddrs{
ObservedAddrsForFunc: func(addr ma.Multiaddr) []ma.Multiaddr {
if _, err := addr.ValueForProtocol(ma.P_TCP); err == nil {
return []ma.Multiaddr{ma.StringCast("/ip4/2.2.2.2/tcp/1")}
}
return nil
},
}
require.Contains(t, as.Addrs(), ma.StringCast("/ip4/1.2.3.4/udp/1/quic-v1"))
require.Contains(t, as.Addrs(), ma.StringCast("/ip4/2.2.2.2/tcp/1"))
})
t.Run("Only Observed Address", func(t *testing.T) {
as := getAddrService()
as.natmgr = nil
as.observedAddrsService = &mockObservedAddrs{
ObservedAddrsForFunc: func(addr ma.Multiaddr) []ma.Multiaddr {
if _, err := addr.ValueForProtocol(ma.P_TCP); err == nil {
return []ma.Multiaddr{ma.StringCast("/ip4/2.2.2.2/tcp/1")}
}
return nil
},
OwnObservedAddrsFunc: func() []ma.Multiaddr {
return []ma.Multiaddr{ma.StringCast("/ip4/3.3.3.3/udp/1/quic-v1")}
},
}
require.NotContains(t, as.Addrs(), ma.StringCast("/ip4/2.2.2.2/tcp/1"))
require.Contains(t, as.Addrs(), ma.StringCast("/ip4/3.3.3.3/udp/1/quic-v1"))
})
t.Run("Public Addrs Removed When Private", func(t *testing.T) {
as := getAddrService()
as.natmgr = nil
as.observedAddrsService = &mockObservedAddrs{
OwnObservedAddrsFunc: func() []ma.Multiaddr {
return []ma.Multiaddr{ma.StringCast("/ip4/3.3.3.3/udp/1/quic-v1")}
},
}
as.reachability = func() network.Reachability {
return network.ReachabilityPrivate
}
relayAddr := ma.StringCast("/ip4/1.2.3.4/udp/1/quic-v1/p2p/QmdXGaeGiVA745XorV1jr11RHxB9z4fqykm6xCUPX1aTJo/p2p-circuit")
as.autoRelayAddrs = func() []ma.Multiaddr {
return []ma.Multiaddr{relayAddr}
}
require.NotContains(t, as.Addrs(), ma.StringCast("/ip4/3.3.3.3/udp/1/quic-v1"))
require.Contains(t, as.Addrs(), relayAddr)
require.Contains(t, as.AllAddrs(), ma.StringCast("/ip4/3.3.3.3/udp/1/quic-v1"))
})

t.Run("AddressFactory gets relay addresses", func(t *testing.T) {
as := getAddrService()
as.natmgr = nil
as.observedAddrsService = &mockObservedAddrs{
OwnObservedAddrsFunc: func() []ma.Multiaddr {
return []ma.Multiaddr{ma.StringCast("/ip4/3.3.3.3/udp/1/quic-v1")}
},
}
as.reachability = func() network.Reachability {
return network.ReachabilityPrivate
}
relayAddr := ma.StringCast("/ip4/1.2.3.4/udp/1/quic-v1/p2p/QmdXGaeGiVA745XorV1jr11RHxB9z4fqykm6xCUPX1aTJo/p2p-circuit")
as.autoRelayAddrs = func() []ma.Multiaddr {
return []ma.Multiaddr{relayAddr}
}
as.addrsFactory = func(addrs []ma.Multiaddr) []ma.Multiaddr {
for _, a := range addrs {
if a.Equal(relayAddr) {
return []ma.Multiaddr{ma.StringCast("/ip4/3.3.3.3/udp/1/quic-v1")}
}
}
return nil
}
require.Contains(t, as.Addrs(), ma.StringCast("/ip4/3.3.3.3/udp/1/quic-v1"))
require.NotContains(t, as.Addrs(), relayAddr)
})

t.Run("updates addresses on signaling", func(t *testing.T) {
as := getAddrService()
as.natmgr = nil
updateChan := make(chan struct{})
a1 := ma.StringCast("/ip4/1.1.1.1/udp/1/quic-v1")
a2 := ma.StringCast("/ip4/1.1.1.1/tcp/1")
as.addrsFactory = func(addrs []ma.Multiaddr) []ma.Multiaddr {
select {
case <-updateChan:
return []ma.Multiaddr{a2}
default:
return []ma.Multiaddr{a1}
}
}
as.Start()
require.Contains(t, as.Addrs(), a1)
require.NotContains(t, as.Addrs(), a2)
close(updateChan)
as.SignalAddressChange()
select {
case <-as.AddrsUpdated():
require.Contains(t, as.Addrs(), a2)
require.NotContains(t, as.Addrs(), a1)
case <-time.After(2 * time.Second):
t.Fatal("expected addrs to be updated")
}
})
}
Loading

0 comments on commit 5fcce02

Please sign in to comment.