From 29918dd695564760e09f9c83bf82e47ba351c780 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Fri, 14 Jun 2024 16:25:09 +0900 Subject: [PATCH] Check allowed candidates in Dial/Listen methods --- client/internal/dns/server_test.go | 6 +-- client/internal/engine.go | 3 +- client/internal/engine_stdnet.go | 2 +- client/internal/peer/conn.go | 39 ++------------- client/internal/peer/stdnet.go | 5 +- client/internal/peer/stdnet_android.go | 7 ++- client/internal/relay/relay.go | 4 +- client/internal/stdnet/dialer.go | 18 +++++++ client/internal/stdnet/listener.go | 67 +++++++++++++++++++++++++- client/internal/stdnet/stdnet.go | 41 +++++++++++++++- util/net/listener_nonios.go | 16 +++--- 11 files changed, 151 insertions(+), 57 deletions(-) diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index 3709c32ce48..d69f511a303 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -261,7 +261,7 @@ func TestUpdateDNSServer(t *testing.T) { for n, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { privKey, _ := wgtypes.GenerateKey() - newNet, err := stdnet.NewNet(nil) + newNet, err := stdnet.NewNet(nil, nil) if err != nil { t.Fatal(err) } @@ -336,7 +336,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) { defer t.Setenv("NB_WG_KERNEL_DISABLED", ov) t.Setenv("NB_WG_KERNEL_DISABLED", "true") - newNet, err := stdnet.NewNet(nil) + newNet, err := stdnet.NewNet(nil, nil) if err != nil { t.Errorf("create stdnet: %v", err) return @@ -794,7 +794,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) { defer t.Setenv("NB_WG_KERNEL_DISABLED", ov) t.Setenv("NB_WG_KERNEL_DISABLED", "true") - newNet, err := stdnet.NewNet(nil) + newNet, err := stdnet.NewNet(nil, nil) if err != nil { t.Fatalf("create stdnet: %v", err) return nil, err diff --git a/client/internal/engine.go b/client/internal/engine.go index b62af05582d..3624da7c976 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -932,7 +932,8 @@ func (e *Engine) connWorker(conn *peer.Conn, peerKey string) { conn.UpdateStunTurn(append(e.STUNs, e.TURNs...)) e.syncMsgMux.Unlock() - err := conn.Open(e.ctx) + routes := e.GetClientRoutes() + err := conn.Open(e.ctx, routes) if err != nil { log.Debugf("connection to peer %s failed: %v", peerKey, err) var connectionClosedError *peer.ConnectionClosedError diff --git a/client/internal/engine_stdnet.go b/client/internal/engine_stdnet.go index 9e171b0b24b..02913593b32 100644 --- a/client/internal/engine_stdnet.go +++ b/client/internal/engine_stdnet.go @@ -7,5 +7,5 @@ import ( ) func (e *Engine) newStdNet() (*stdnet.Net, error) { - return stdnet.NewNet(e.config.IFaceBlackList) + return stdnet.NewNet(e.config.IFaceBlackList, nil) } diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 213bbed0bcb..c02fe5950b5 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -4,19 +4,16 @@ import ( "context" "fmt" "net" - "net/netip" "runtime" "strings" "sync" "time" - "github.com/davecgh/go-spew/spew" "github.com/pion/ice/v3" "github.com/pion/stun/v2" log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - "github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/wgproxy" "github.com/netbirdio/netbird/iface" @@ -171,14 +168,14 @@ func NewConn(config ConnConfig, statusRecorder *Status, wgProxyFactory *wgproxy. }, nil } -func (conn *Conn) reCreateAgent() error { +func (conn *Conn) reCreateAgent(routes route.HAMap) error { conn.mu.Lock() defer conn.mu.Unlock() failedTimeout := 6 * time.Second var err error - transportNet, err := conn.newStdNet() + transportNet, err := conn.newStdNet(routes) if err != nil { log.Errorf("failed to create pion's stdnet: %s", err) } @@ -255,7 +252,7 @@ func (conn *Conn) candidateTypes() []ice.CandidateType { // Open opens connection to the remote peer starting ICE candidate gathering process. // Blocks until connection has been closed or connection timeout. // ConnStatus will be set accordingly -func (conn *Conn) Open(ctx context.Context) error { +func (conn *Conn) Open(ctx context.Context, routes route.HAMap) error { log.Debugf("trying to connect to peer %s", conn.config.Key) peerState := State{ @@ -278,7 +275,7 @@ func (conn *Conn) Open(ctx context.Context) error { } }() - err = conn.reCreateAgent() + err = conn.reCreateAgent(routes) if err != nil { return err } @@ -764,10 +761,6 @@ func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMa return } - if candidateViaRoutes(candidate, haRoutes) { - return - } - err := conn.agent.AddRemoteCandidate(candidate) if err != nil { log.Errorf("error while handling remote candidate from peer %s", conn.config.Key) @@ -798,27 +791,3 @@ func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive RelPort: relatedAdd.Port, }) } - -func candidateViaRoutes(candidate ice.Candidate, clientRoutes route.HAMap) bool { - var vpnRoutes []netip.Prefix - log.Tracef("ICE: Client routes: %s", spew.Sdump(clientRoutes)) - log.Tracef("ICE: Candidate: %v", candidate) - for _, routes := range clientRoutes { - if len(routes) > 0 && routes[0] != nil { - vpnRoutes = append(vpnRoutes, routes[0].Network) - } - } - - addr, err := netip.ParseAddr(candidate.Address()) - if err != nil { - log.Errorf("Failed to parse IP address %s: %v", candidate.Address(), err) - return false - } - - if isVpn, prefix := systemops.IsAddrRouted(addr, vpnRoutes); isVpn { - log.Debugf("Ignoring candidate [%s], its address is routed to network %s", candidate.String(), prefix) - return true - } - - return false -} diff --git a/client/internal/peer/stdnet.go b/client/internal/peer/stdnet.go index 13f5886f5f2..c59937d09cf 100644 --- a/client/internal/peer/stdnet.go +++ b/client/internal/peer/stdnet.go @@ -4,8 +4,9 @@ package peer import ( "github.com/netbirdio/netbird/client/internal/stdnet" + "github.com/netbirdio/netbird/route" ) -func (conn *Conn) newStdNet() (*stdnet.Net, error) { - return stdnet.NewNet(conn.config.InterfaceBlackList) +func (conn *Conn) newStdNet(routes route.HAMap) (*stdnet.Net, error) { + return stdnet.NewNet(conn.config.InterfaceBlackList, routes) } diff --git a/client/internal/peer/stdnet_android.go b/client/internal/peer/stdnet_android.go index 8a245437138..c002387db52 100644 --- a/client/internal/peer/stdnet_android.go +++ b/client/internal/peer/stdnet_android.go @@ -1,7 +1,10 @@ package peer -import "github.com/netbirdio/netbird/client/internal/stdnet" +import ( + "github.com/netbirdio/netbird/client/internal/stdnet" + "github.com/netbirdio/netbird/route" +) -func (conn *Conn) newStdNet() (*stdnet.Net, error) { +func (conn *Conn) newStdNet(haMap route.HAMap) (*stdnet.Net, error) { return stdnet.NewNetWithDiscover(conn.iFaceDiscover, conn.config.InterfaceBlackList) } diff --git a/client/internal/relay/relay.go b/client/internal/relay/relay.go index 4542a37febb..0ca791b3bc6 100644 --- a/client/internal/relay/relay.go +++ b/client/internal/relay/relay.go @@ -30,7 +30,7 @@ func ProbeSTUN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) } }() - net, err := stdnet.NewNet(nil) + net, err := stdnet.NewNet(nil, nil) if err != nil { probeErr = fmt.Errorf("new net: %w", err) return @@ -119,7 +119,7 @@ func ProbeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) } }() - net, err := stdnet.NewNet(nil) + net, err := stdnet.NewNet(nil, nil) if err != nil { probeErr = fmt.Errorf("new net: %w", err) return diff --git a/client/internal/stdnet/dialer.go b/client/internal/stdnet/dialer.go index e80adb42b20..96c96c6d851 100644 --- a/client/internal/stdnet/dialer.go +++ b/client/internal/stdnet/dialer.go @@ -1,20 +1,38 @@ package stdnet import ( + "fmt" "net" "github.com/pion/transport/v3" + log "github.com/sirupsen/logrus" nbnet "github.com/netbirdio/netbird/util/net" ) // Dial connects to the address on the named network. func (n *Net) Dial(network, address string) (net.Conn, error) { + log.Tracef("ICE: Checking if address %s is routed", address) + isRouted, prefix, err := addrViaRoutes(address, n.routes) + + if err != nil { + log.Errorf("Failed to check if address %s is routed: %v", address, err) + } else if isRouted { + return nil, fmt.Errorf("[Dial] IP %s is part of routed network %s, refusing to dial", address, prefix) + } return nbnet.NewDialer().Dial(network, address) } // DialUDP connects to the address on the named UDP network. func (n *Net) DialUDP(network string, laddr, raddr *net.UDPAddr) (transport.UDPConn, error) { + log.Tracef("ICE: Checking if address %s is routed", raddr) + isRouted, prefix, err := addrViaRoutes(raddr.IP.String(), n.routes) + + if err != nil { + log.Errorf("Failed to check if address %s is routed: %v", raddr, err) + } else if isRouted { + return nil, fmt.Errorf("[Dial] IP %s is part of routed network %s, refusing to dial", raddr, prefix) + } return nbnet.DialUDP(network, laddr, raddr) } diff --git a/client/internal/stdnet/listener.go b/client/internal/stdnet/listener.go index 9ce0a555610..77d3824fda0 100644 --- a/client/internal/stdnet/listener.go +++ b/client/internal/stdnet/listener.go @@ -2,19 +2,82 @@ package stdnet import ( "context" + "fmt" "net" + "sync" "github.com/pion/transport/v3" + log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/route" nbnet "github.com/netbirdio/netbird/util/net" ) // ListenPacket listens for incoming packets on the given network and address. func (n *Net) ListenPacket(network, address string) (net.PacketConn, error) { - return nbnet.NewListener().ListenPacket(context.Background(), network, address) + listener := nbnet.NewListener() + pc, err := listener.ListenPacket(context.Background(), network, address) + if err != nil { + return nil, fmt.Errorf("listen packet: %w", err) + } + return &PacketConn{PacketConn: pc, routes: n.routes}, nil } // ListenUDP acts like ListenPacket for UDP networks. func (n *Net) ListenUDP(network string, locAddr *net.UDPAddr) (transport.UDPConn, error) { - return nbnet.ListenUDP(network, locAddr) + udpConn, err := nbnet.ListenUDP(network, locAddr) + if err != nil { + return nil, fmt.Errorf("listen udp: %w", err) + } + + return &UDPConn{UDPConn: udpConn, routes: n.routes}, nil +} + +type PacketConn struct { + net.PacketConn + routes route.HAMap + seenAddrs sync.Map +} + +func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { + isRouted, err := isRouted(addr, &c.seenAddrs, c.routes) + if err != nil { + log.Errorf("Failed to check if address %s is routed: %v", addr, err) + } else if isRouted { + return 0, fmt.Errorf("[PacketConn] IP %s is part of routed network, refusing to write", addr) + } + + return c.PacketConn.WriteTo(b, addr) +} + +type UDPConn struct { + transport.UDPConn + routes route.HAMap + seenAddrs sync.Map +} + +func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { + isRouted, err := isRouted(addr, &c.seenAddrs, c.routes) + if err != nil { + log.Errorf("Failed to check if address %s is routed: %v", addr, err) + } else if isRouted { + return 0, fmt.Errorf("[UDPConn] IP %s is part of routed network, refusing to write", addr) + } + + return c.UDPConn.WriteTo(b, addr) +} + +func isRouted(addr net.Addr, seenAddrs *sync.Map, routes route.HAMap) (bool, error) { + log.Tracef("ICE: Checking if address %s is routed", addr.String()) + if isRouted, ok := seenAddrs.Load(addr.String()); ok { + return isRouted.(bool), nil + } + + isRouted, _, err := addrViaRoutes(addr.String(), routes) + if err != nil { + return false, err + } + + seenAddrs.Store(addr.String(), isRouted) + return isRouted, nil } diff --git a/client/internal/stdnet/stdnet.go b/client/internal/stdnet/stdnet.go index 2e87475a53c..063125a630b 100644 --- a/client/internal/stdnet/stdnet.go +++ b/client/internal/stdnet/stdnet.go @@ -5,9 +5,16 @@ package stdnet import ( "fmt" + "net" + "net/netip" + "github.com/davecgh/go-spew/spew" "github.com/pion/transport/v3" "github.com/pion/transport/v3/stdnet" + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/routemanager/systemops" + "github.com/netbirdio/netbird/route" ) // Net is an implementation of the net.Net interface @@ -18,6 +25,7 @@ type Net struct { iFaceDiscover iFaceDiscover // interfaceFilter should return true if the given interfaceName is allowed interfaceFilter func(interfaceName string) bool + routes route.HAMap } // NewNetWithDiscover creates a new StdNet instance. @@ -30,10 +38,11 @@ func NewNetWithDiscover(iFaceDiscover ExternalIFaceDiscover, disallowList []stri } // NewNet creates a new StdNet instance. -func NewNet(disallowList []string) (*Net, error) { +func NewNet(disallowList []string, routes route.HAMap) (*Net, error) { n := &Net{ iFaceDiscover: pionDiscover{}, interfaceFilter: InterfaceFilter(disallowList), + routes: routes, } return n, n.UpdateInterfaces() } @@ -95,3 +104,33 @@ func (n *Net) filterInterfaces(interfaces []*transport.Interface) []*transport.I } return result } + +func addrViaRoutes(address string, routes route.HAMap) (bool, netip.Prefix, error) { + log.Tracef("ICE: Client routes: %s", spew.Sdump(routes)) + log.Tracef("ICE: addr %v", address) + + // TODO: resolve domain names + + addrStr, _, err := net.SplitHostPort(address) + if err != nil { + return false, netip.Prefix{}, fmt.Errorf("split host and port: %w", err) + } + + ipAddr, err := netip.ParseAddr(addrStr) + if err != nil { + return false, netip.Prefix{}, fmt.Errorf("parse address: %w", err) + } + + var vpnRoutes []netip.Prefix + for _, routes := range routes { + if len(routes) > 0 && routes[0] != nil { + vpnRoutes = append(vpnRoutes, routes[0].Network) + } + } + + if isVpn, prefix := systemops.IsAddrRouted(ipAddr, vpnRoutes); isVpn { + return true, prefix, nil + } + + return false, netip.Prefix{}, nil +} diff --git a/util/net/listener_nonios.go b/util/net/listener_nonios.go index ae4be34949b..060c80bbf74 100644 --- a/util/net/listener_nonios.go +++ b/util/net/listener_nonios.go @@ -62,25 +62,25 @@ func (l *ListenerConfig) ListenPacket(ctx context.Context, network, address stri return nil, fmt.Errorf("listen packet: %w", err) } connID := GenerateConnID() - return &PacketConn{PacketConn: pc, ID: connID, seenAddrs: &sync.Map{}}, nil + return &PacketConn{PacketConn: pc, ID: connID}, nil } // PacketConn wraps net.PacketConn to override its WriteTo and Close methods to include hook functionality. type PacketConn struct { net.PacketConn ID ConnectionID - seenAddrs *sync.Map + seenAddrs sync.Map } // WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand. func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { - callWriteHooks(c.ID, c.seenAddrs, b, addr) + callWriteHooks(c.ID, &c.seenAddrs, b, addr) return c.PacketConn.WriteTo(b, addr) } // Close overrides the net.PacketConn Close method to execute all registered hooks before closing the connection. func (c *PacketConn) Close() error { - c.seenAddrs = &sync.Map{} + c.seenAddrs = sync.Map{} return closeConn(c.ID, c.PacketConn) } @@ -88,18 +88,18 @@ func (c *PacketConn) Close() error { type UDPConn struct { *net.UDPConn ID ConnectionID - seenAddrs *sync.Map + seenAddrs sync.Map } // WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand. func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { - callWriteHooks(c.ID, c.seenAddrs, b, addr) + callWriteHooks(c.ID, &c.seenAddrs, b, addr) return c.UDPConn.WriteTo(b, addr) } // Close overrides the net.UDPConn Close method to execute all registered hooks before closing the connection. func (c *UDPConn) Close() error { - c.seenAddrs = &sync.Map{} + c.seenAddrs = sync.Map{} return closeConn(c.ID, c.UDPConn) } @@ -168,5 +168,5 @@ func ListenUDP(network string, laddr *net.UDPAddr) (transport.UDPConn, error) { return nil, fmt.Errorf("expected UDPConn, got different type: %T", udpConn) } - return &UDPConn{UDPConn: udpConn, ID: packetConn.ID, seenAddrs: &sync.Map{}}, nil + return &UDPConn{UDPConn: udpConn, ID: packetConn.ID}, nil }