Skip to content

Commit

Permalink
Check allowed candidates in Dial/Listen methods
Browse files Browse the repository at this point in the history
  • Loading branch information
lixmal committed Jun 14, 2024
1 parent e80df55 commit 3cee7b3
Show file tree
Hide file tree
Showing 11 changed files with 151 additions and 57 deletions.
6 changes: 3 additions & 3 deletions client/internal/dns/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ func TestUpdateDNSServer(t *testing.T) {
for n, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
privKey, _ := wgtypes.GenerateKey()
newNet, err := stdnet.NewNet(nil)
newNet, err := stdnet.NewNet(nil, nil)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -336,7 +336,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
defer t.Setenv("NB_WG_KERNEL_DISABLED", ov)

t.Setenv("NB_WG_KERNEL_DISABLED", "true")
newNet, err := stdnet.NewNet(nil)
newNet, err := stdnet.NewNet(nil, nil)
if err != nil {
t.Errorf("create stdnet: %v", err)
return
Expand Down Expand Up @@ -794,7 +794,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
defer t.Setenv("NB_WG_KERNEL_DISABLED", ov)

t.Setenv("NB_WG_KERNEL_DISABLED", "true")
newNet, err := stdnet.NewNet(nil)
newNet, err := stdnet.NewNet(nil, nil)
if err != nil {
t.Fatalf("create stdnet: %v", err)
return nil, err
Expand Down
3 changes: 2 additions & 1 deletion client/internal/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -972,7 +972,8 @@ func (e *Engine) connWorker(conn *peer.Conn, peerKey string) {
conn.UpdateStunTurn(append(e.STUNs, e.TURNs...))
e.syncMsgMux.Unlock()

err := conn.Open(e.ctx)
routes := e.GetClientRoutes()
err := conn.Open(e.ctx, routes)
if err != nil {
log.Debugf("connection to peer %s failed: %v", peerKey, err)
var connectionClosedError *peer.ConnectionClosedError
Expand Down
2 changes: 1 addition & 1 deletion client/internal/engine_stdnet.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@ import (
)

func (e *Engine) newStdNet() (*stdnet.Net, error) {
return stdnet.NewNet(e.config.IFaceBlackList)
return stdnet.NewNet(e.config.IFaceBlackList, nil)
}
39 changes: 4 additions & 35 deletions client/internal/peer/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,16 @@ import (
"context"
"fmt"
"net"
"net/netip"
"runtime"
"strings"
"sync"
"time"

"github.com/davecgh/go-spew/spew"
"github.com/pion/ice/v3"
"github.com/pion/stun/v2"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"

"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/internal/wgproxy"
"github.com/netbirdio/netbird/iface"
Expand Down Expand Up @@ -171,14 +168,14 @@ func NewConn(config ConnConfig, statusRecorder *Status, wgProxyFactory *wgproxy.
}, nil
}

func (conn *Conn) reCreateAgent() error {
func (conn *Conn) reCreateAgent(routes route.HAMap) error {
conn.mu.Lock()
defer conn.mu.Unlock()

failedTimeout := 6 * time.Second

var err error
transportNet, err := conn.newStdNet()
transportNet, err := conn.newStdNet(routes)
if err != nil {
log.Errorf("failed to create pion's stdnet: %s", err)
}
Expand Down Expand Up @@ -255,7 +252,7 @@ func (conn *Conn) candidateTypes() []ice.CandidateType {
// Open opens connection to the remote peer starting ICE candidate gathering process.
// Blocks until connection has been closed or connection timeout.
// ConnStatus will be set accordingly
func (conn *Conn) Open(ctx context.Context) error {
func (conn *Conn) Open(ctx context.Context, routes route.HAMap) error {
log.Debugf("trying to connect to peer %s", conn.config.Key)

peerState := State{
Expand All @@ -278,7 +275,7 @@ func (conn *Conn) Open(ctx context.Context) error {
}
}()

err = conn.reCreateAgent()
err = conn.reCreateAgent(routes)
if err != nil {
return err
}
Expand Down Expand Up @@ -764,10 +761,6 @@ func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMa
return
}

if candidateViaRoutes(candidate, haRoutes) {
return
}

err := conn.agent.AddRemoteCandidate(candidate)
if err != nil {
log.Errorf("error while handling remote candidate from peer %s", conn.config.Key)
Expand Down Expand Up @@ -798,27 +791,3 @@ func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive
RelPort: relatedAdd.Port,
})
}

func candidateViaRoutes(candidate ice.Candidate, clientRoutes route.HAMap) bool {
var vpnRoutes []netip.Prefix
log.Tracef("ICE: Client routes: %s", spew.Sdump(clientRoutes))
log.Tracef("ICE: Candidate: %v", candidate)
for _, routes := range clientRoutes {
if len(routes) > 0 && routes[0] != nil {
vpnRoutes = append(vpnRoutes, routes[0].Network)
}
}

addr, err := netip.ParseAddr(candidate.Address())
if err != nil {
log.Errorf("Failed to parse IP address %s: %v", candidate.Address(), err)
return false
}

if isVpn, prefix := systemops.IsAddrRouted(addr, vpnRoutes); isVpn {
log.Debugf("Ignoring candidate [%s], its address is routed to network %s", candidate.String(), prefix)
return true
}

return false
}
5 changes: 3 additions & 2 deletions client/internal/peer/stdnet.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@ package peer

import (
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/route"
)

func (conn *Conn) newStdNet() (*stdnet.Net, error) {
return stdnet.NewNet(conn.config.InterfaceBlackList)
func (conn *Conn) newStdNet(routes route.HAMap) (*stdnet.Net, error) {
return stdnet.NewNet(conn.config.InterfaceBlackList, routes)
}
7 changes: 5 additions & 2 deletions client/internal/peer/stdnet_android.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
package peer

import "github.com/netbirdio/netbird/client/internal/stdnet"
import (
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/route"
)

func (conn *Conn) newStdNet() (*stdnet.Net, error) {
func (conn *Conn) newStdNet(haMap route.HAMap) (*stdnet.Net, error) {
return stdnet.NewNetWithDiscover(conn.iFaceDiscover, conn.config.InterfaceBlackList)
}
4 changes: 2 additions & 2 deletions client/internal/relay/relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func ProbeSTUN(ctx context.Context, uri *stun.URI) (addr string, probeErr error)
}
}()

net, err := stdnet.NewNet(nil)
net, err := stdnet.NewNet(nil, nil)
if err != nil {
probeErr = fmt.Errorf("new net: %w", err)
return
Expand Down Expand Up @@ -119,7 +119,7 @@ func ProbeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error)
}
}()

net, err := stdnet.NewNet(nil)
net, err := stdnet.NewNet(nil, nil)
if err != nil {
probeErr = fmt.Errorf("new net: %w", err)
return
Expand Down
18 changes: 18 additions & 0 deletions client/internal/stdnet/dialer.go
Original file line number Diff line number Diff line change
@@ -1,20 +1,38 @@
package stdnet

import (
"fmt"
"net"

"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"

nbnet "github.com/netbirdio/netbird/util/net"
)

// Dial connects to the address on the named network.
func (n *Net) Dial(network, address string) (net.Conn, error) {
log.Tracef("ICE: Checking if address %s is routed", address)
isRouted, prefix, err := addrViaRoutes(address, n.routes)

if err != nil {
log.Errorf("Failed to check if address %s is routed: %v", address, err)
} else if isRouted {
return nil, fmt.Errorf("[Dial] IP %s is part of routed network %s, refusing to dial", address, prefix)
}
return nbnet.NewDialer().Dial(network, address)
}

// DialUDP connects to the address on the named UDP network.
func (n *Net) DialUDP(network string, laddr, raddr *net.UDPAddr) (transport.UDPConn, error) {
log.Tracef("ICE: Checking if address %s is routed", raddr)
isRouted, prefix, err := addrViaRoutes(raddr.IP.String(), n.routes)

if err != nil {
log.Errorf("Failed to check if address %s is routed: %v", raddr, err)
} else if isRouted {
return nil, fmt.Errorf("[Dial] IP %s is part of routed network %s, refusing to dial", raddr, prefix)
}
return nbnet.DialUDP(network, laddr, raddr)
}

Expand Down
67 changes: 65 additions & 2 deletions client/internal/stdnet/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,82 @@ package stdnet

import (
"context"
"fmt"
"net"
"sync"

"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"

"github.com/netbirdio/netbird/route"
nbnet "github.com/netbirdio/netbird/util/net"
)

// ListenPacket listens for incoming packets on the given network and address.
func (n *Net) ListenPacket(network, address string) (net.PacketConn, error) {
return nbnet.NewListener().ListenPacket(context.Background(), network, address)
listener := nbnet.NewListener()
pc, err := listener.ListenPacket(context.Background(), network, address)
if err != nil {
return nil, fmt.Errorf("listen packet: %w", err)
}
return &PacketConn{PacketConn: pc, routes: n.routes}, nil
}

// ListenUDP acts like ListenPacket for UDP networks.
func (n *Net) ListenUDP(network string, locAddr *net.UDPAddr) (transport.UDPConn, error) {
return nbnet.ListenUDP(network, locAddr)
udpConn, err := nbnet.ListenUDP(network, locAddr)
if err != nil {
return nil, fmt.Errorf("listen udp: %w", err)
}

return &UDPConn{UDPConn: udpConn, routes: n.routes}, nil
}

type PacketConn struct {
net.PacketConn
routes route.HAMap
seenAddrs sync.Map
}

func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
isRouted, err := isRouted(addr, &c.seenAddrs, c.routes)
if err != nil {
log.Errorf("Failed to check if address %s is routed: %v", addr, err)
} else if isRouted {
return 0, fmt.Errorf("[PacketConn] IP %s is part of routed network, refusing to write", addr)
}

return c.PacketConn.WriteTo(b, addr)
}

type UDPConn struct {
transport.UDPConn
routes route.HAMap
seenAddrs sync.Map
}

func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
isRouted, err := isRouted(addr, &c.seenAddrs, c.routes)
if err != nil {
log.Errorf("Failed to check if address %s is routed: %v", addr, err)
} else if isRouted {
return 0, fmt.Errorf("[UDPConn] IP %s is part of routed network, refusing to write", addr)
}

return c.UDPConn.WriteTo(b, addr)
}

func isRouted(addr net.Addr, seenAddrs *sync.Map, routes route.HAMap) (bool, error) {
log.Tracef("ICE: Checking if address %s is routed", addr.String())
if isRouted, ok := seenAddrs.Load(addr.String()); ok {
return isRouted.(bool), nil
}

isRouted, _, err := addrViaRoutes(addr.String(), routes)
if err != nil {
return false, err
}

seenAddrs.Store(addr.String(), isRouted)
return isRouted, nil
}
41 changes: 40 additions & 1 deletion client/internal/stdnet/stdnet.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,16 @@ package stdnet

import (
"fmt"
"net"
"net/netip"

"github.com/davecgh/go-spew/spew"
"github.com/pion/transport/v3"
"github.com/pion/transport/v3/stdnet"
log "github.com/sirupsen/logrus"

"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/route"
)

// Net is an implementation of the net.Net interface
Expand All @@ -18,6 +25,7 @@ type Net struct {
iFaceDiscover iFaceDiscover
// interfaceFilter should return true if the given interfaceName is allowed
interfaceFilter func(interfaceName string) bool
routes route.HAMap
}

// NewNetWithDiscover creates a new StdNet instance.
Expand All @@ -30,10 +38,11 @@ func NewNetWithDiscover(iFaceDiscover ExternalIFaceDiscover, disallowList []stri
}

// NewNet creates a new StdNet instance.
func NewNet(disallowList []string) (*Net, error) {
func NewNet(disallowList []string, routes route.HAMap) (*Net, error) {
n := &Net{
iFaceDiscover: pionDiscover{},
interfaceFilter: InterfaceFilter(disallowList),
routes: routes,
}
return n, n.UpdateInterfaces()
}
Expand Down Expand Up @@ -95,3 +104,33 @@ func (n *Net) filterInterfaces(interfaces []*transport.Interface) []*transport.I
}
return result
}

func addrViaRoutes(address string, routes route.HAMap) (bool, netip.Prefix, error) {
log.Tracef("ICE: Client routes: %s", spew.Sdump(routes))
log.Tracef("ICE: addr %v", address)

// TODO: resolve domain names

addrStr, _, err := net.SplitHostPort(address)
if err != nil {
return false, netip.Prefix{}, fmt.Errorf("split host and port: %w", err)
}

ipAddr, err := netip.ParseAddr(addrStr)
if err != nil {
return false, netip.Prefix{}, fmt.Errorf("parse address: %w", err)
}

var vpnRoutes []netip.Prefix
for _, routes := range routes {
if len(routes) > 0 && routes[0] != nil {
vpnRoutes = append(vpnRoutes, routes[0].Network)
}
}

if isVpn, prefix := systemops.IsAddrRouted(ipAddr, vpnRoutes); isVpn {
return true, prefix, nil
}

return false, netip.Prefix{}, nil
}
Loading

0 comments on commit 3cee7b3

Please sign in to comment.