Skip to content
This repository has been archived by the owner on May 26, 2022. It is now read-only.

Commit

Permalink
Merge pull request #83 from libp2p/reuse-windows
Browse files Browse the repository at this point in the history
make reuse work on Windows
  • Loading branch information
Stebalien authored Nov 15, 2019
2 parents 0e79de7 + 72b9c53 commit df66b10
Show file tree
Hide file tree
Showing 7 changed files with 165 additions and 99 deletions.
2 changes: 1 addition & 1 deletion libp2pquic_suite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ var maxUnusedDurationOrig time.Duration
func isGarbageCollectorRunning() bool {
var b bytes.Buffer
pprof.Lookup("goroutine").WriteTo(&b, 1)
return strings.Contains(b.String(), "go-libp2p-quic-transport.(*reuse).runGarbageCollector")
return strings.Contains(b.String(), "go-libp2p-quic-transport.(*reuseBase).runGarbageCollector")
}

var _ = BeforeEach(func() {
Expand Down
3 changes: 2 additions & 1 deletion netlink_other.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
// +build !linux
// +build !windows

package libp2pquic

import "github.com/vishvananda/netlink/nl"

// nl.SupportedNlFamilies is the default netlink families used by the netlink package
// SupportedNlFamilies is the default netlink families used by the netlink package
var SupportedNlFamilies = nl.SupportedNlFamilies
68 changes: 9 additions & 59 deletions reuse.go → reuse_base.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,9 @@ import (
"net"
"sync"
"time"

"github.com/vishvananda/netlink"
)

// Constants. Defined as variables to simplify testing.
// Constant. Defined as variables to simplify testing.
var (
garbageCollectInterval = 30 * time.Second
maxUnusedDuration = 10 * time.Second
Expand Down Expand Up @@ -48,34 +46,24 @@ func (c *reuseConn) ShouldGarbageCollect(now time.Time) bool {
return !c.unusedSince.IsZero() && c.unusedSince.Add(maxUnusedDuration).Before(now)
}

type reuse struct {
type reuseBase struct {
mutex sync.Mutex

garbageCollectorRunning bool

handle *netlink.Handle // Only set on Linux. nil on other systems.

unicast map[string] /* IP.String() */ map[int] /* port */ *reuseConn
// global contains connections that are listening on 0.0.0.0 / ::
global map[int]*reuseConn
}

func newReuse() (*reuse, error) {
// On non-Linux systems, this will return ErrNotImplemented.
handle, err := netlink.NewHandle(SupportedNlFamilies...)
if err == netlink.ErrNotImplemented {
handle = nil
} else if err != nil {
return nil, err
}
return &reuse{
func newReuseBase() reuseBase {
return reuseBase{
unicast: make(map[string]map[int]*reuseConn),
global: make(map[int]*reuseConn),
handle: handle,
}, nil
}
}

func (r *reuse) runGarbageCollector() {
func (r *reuseBase) runGarbageCollector() {
ticker := time.NewTicker(garbageCollectInterval)
defer ticker.Stop()

Expand Down Expand Up @@ -114,52 +102,14 @@ func (r *reuse) runGarbageCollector() {
}

// must be called while holding the mutex
func (r *reuse) maybeStartGarbageCollector() {
func (r *reuseBase) maybeStartGarbageCollector() {
if !r.garbageCollectorRunning {
r.garbageCollectorRunning = true
go r.runGarbageCollector()
}
}

// Get the source IP that the kernel would use for dialing.
// This only works on Linux.
// On other systems, this returns an empty slice of IP addresses.
func (r *reuse) getSourceIPs(network string, raddr *net.UDPAddr) ([]net.IP, error) {
if r.handle == nil {
return nil, nil
}

routes, err := r.handle.RouteGet(raddr.IP)
if err != nil {
return nil, err
}

ips := make([]net.IP, 0, len(routes))
for _, route := range routes {
ips = append(ips, route.Src)
}
return ips, nil
}

func (r *reuse) Dial(network string, raddr *net.UDPAddr) (*reuseConn, error) {
ips, err := r.getSourceIPs(network, raddr)
if err != nil {
return nil, err
}

r.mutex.Lock()
defer r.mutex.Unlock()

conn, err := r.dialLocked(network, raddr, ips)
if err != nil {
return nil, err
}
conn.IncreaseCount()
r.maybeStartGarbageCollector()
return conn, nil
}

func (r *reuse) dialLocked(network string, raddr *net.UDPAddr, ips []net.IP) (*reuseConn, error) {
func (r *reuseBase) dialLocked(network string, raddr *net.UDPAddr, ips []net.IP) (*reuseConn, error) {
for _, ip := range ips {
// We already have at least one suitable connection...
if conns, ok := r.unicast[ip.String()]; ok {
Expand Down Expand Up @@ -194,7 +144,7 @@ func (r *reuse) dialLocked(network string, raddr *net.UDPAddr, ips []net.IP) (*r
return rconn, nil
}

func (r *reuse) Listen(network string, laddr *net.UDPAddr) (*reuseConn, error) {
func (r *reuseBase) Listen(network string, laddr *net.UDPAddr) (*reuseConn, error) {
conn, err := net.ListenUDP(network, laddr)
if err != nil {
return nil, err
Expand Down
42 changes: 42 additions & 0 deletions reuse_linux_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// +build linux

package libp2pquic

import (
"net"

. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)

var _ = Describe("Reuse (on Linux)", func() {
var reuse *reuse

BeforeEach(func() {
var err error
reuse, err = newReuse()
Expect(err).ToNot(HaveOccurred())
})

Context("creating and reusing connections", func() {
AfterEach(func() { closeAllConns(reuse) })

It("reuses a connection it created for listening on a specific interface", func() {
raddr, err := net.ResolveUDPAddr("udp4", "1.1.1.1:1234")
Expect(err).ToNot(HaveOccurred())
ips, err := reuse.getSourceIPs("udp4", raddr)
Expect(err).ToNot(HaveOccurred())
Expect(ips).ToNot(BeEmpty())
// listen
addr, err := net.ResolveUDPAddr("udp4", ips[0].String()+":0")
Expect(err).ToNot(HaveOccurred())
lconn, err := reuse.Listen("udp4", addr)
Expect(err).ToNot(HaveOccurred())
Expect(lconn.GetCount()).To(Equal(1))
// dial
conn, err := reuse.Dial("udp4", raddr)
Expect(err).ToNot(HaveOccurred())
Expect(conn.GetCount()).To(Equal(2))
})
})
})
66 changes: 66 additions & 0 deletions reuse_not_win.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// +build !windows

package libp2pquic

import (
"net"

"github.com/vishvananda/netlink"
)

type reuse struct {
reuseBase

handle *netlink.Handle // Only set on Linux. nil on other systems.
}

func newReuse() (*reuse, error) {
handle, err := netlink.NewHandle(SupportedNlFamilies...)
if err == netlink.ErrNotImplemented {
handle = nil
} else if err != nil {
return nil, err
}
return &reuse{
reuseBase: newReuseBase(),
handle: handle,
}, nil
}

// Get the source IP that the kernel would use for dialing.
// This only works on Linux.
// On other systems, this returns an empty slice of IP addresses.
func (r *reuse) getSourceIPs(network string, raddr *net.UDPAddr) ([]net.IP, error) {
if r.handle == nil {
return nil, nil
}

routes, err := r.handle.RouteGet(raddr.IP)
if err != nil {
return nil, err
}

ips := make([]net.IP, 0, len(routes))
for _, route := range routes {
ips = append(ips, route.Src)
}
return ips, nil
}

func (r *reuse) Dial(network string, raddr *net.UDPAddr) (*reuseConn, error) {
ips, err := r.getSourceIPs(network, raddr)
if err != nil {
return nil, err
}

r.mutex.Lock()
defer r.mutex.Unlock()

conn, err := r.dialLocked(network, raddr, ips)
if err != nil {
return nil, err
}
conn.IncreaseCount()
r.maybeStartGarbageCollector()
return conn, nil
}
57 changes: 19 additions & 38 deletions reuse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package libp2pquic

import (
"net"
"runtime"
"time"

. "github.com/onsi/ginkgo"
Expand All @@ -15,6 +14,24 @@ func (c *reuseConn) GetCount() int {
return c.refCount
}

func closeAllConns(reuse *reuse) {
reuse.mutex.Lock()
for _, conn := range reuse.global {
for conn.GetCount() > 0 {
conn.DecreaseCount()
}
}
for _, conns := range reuse.unicast {
for _, conn := range conns {
for conn.GetCount() > 0 {
conn.DecreaseCount()
}
}
}
reuse.mutex.Unlock()
Eventually(isGarbageCollectorRunning).Should(BeFalse())
}

var _ = Describe("Reuse", func() {
var reuse *reuse

Expand All @@ -25,23 +42,7 @@ var _ = Describe("Reuse", func() {
})

Context("creating and reusing connections", func() {
AfterEach(func() {
reuse.mutex.Lock()
for _, conn := range reuse.global {
for conn.GetCount() > 0 {
conn.DecreaseCount()
}
}
for _, conns := range reuse.unicast {
for _, conn := range conns {
for conn.GetCount() > 0 {
conn.DecreaseCount()
}
}
}
reuse.mutex.Unlock()
Eventually(isGarbageCollectorRunning).Should(BeFalse())
})
AfterEach(func() { closeAllConns(reuse) })

It("creates a new global connection when listening on 0.0.0.0", func() {
addr, err := net.ResolveUDPAddr("udp4", "0.0.0.0:0")
Expand Down Expand Up @@ -84,26 +85,6 @@ var _ = Describe("Reuse", func() {
Expect(err).ToNot(HaveOccurred())
Expect(conn.GetCount()).To(Equal(2))
})

if runtime.GOOS == "linux" {
It("reuses a connection it created for listening on a specific interface", func() {
raddr, err := net.ResolveUDPAddr("udp4", "1.1.1.1:1234")
Expect(err).ToNot(HaveOccurred())
ips, err := reuse.getSourceIPs("udp4", raddr)
Expect(err).ToNot(HaveOccurred())
Expect(ips).ToNot(BeEmpty())
// listen
addr, err := net.ResolveUDPAddr("udp4", ips[0].String()+":0")
Expect(err).ToNot(HaveOccurred())
lconn, err := reuse.Listen("udp4", addr)
Expect(err).ToNot(HaveOccurred())
Expect(lconn.GetCount()).To(Equal(1))
// dial
conn, err := reuse.Dial("udp4", raddr)
Expect(err).ToNot(HaveOccurred())
Expect(conn.GetCount()).To(Equal(2))
})
}
})

Context("garbage-collecting connections", func() {
Expand Down
26 changes: 26 additions & 0 deletions reuse_win.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// +build windows

package libp2pquic

import "net"

type reuse struct {
reuseBase
}

func newReuse() (*reuse, error) {
return &reuse{reuseBase: newReuseBase()}, nil
}

func (r *reuse) Dial(network string, raddr *net.UDPAddr) (*reuseConn, error) {
r.mutex.Lock()
defer r.mutex.Unlock()

conn, err := r.dialLocked(network, raddr, nil)
if err != nil {
return nil, err
}
conn.IncreaseCount()
r.maybeStartGarbageCollector()
return conn, nil
}

0 comments on commit df66b10

Please sign in to comment.