Skip to content

Commit

Permalink
Tests for memcache hosts via SRV (envoyproxy#298)
Browse files Browse the repository at this point in the history
Signed-off-by: Peter Marsh <pete.d.marsh@gmail.com>
  • Loading branch information
petedmarsh authored and timcovar committed Jan 16, 2024
1 parent c96394f commit 8844fe8
Show file tree
Hide file tree
Showing 6 changed files with 165 additions and 12 deletions.
16 changes: 8 additions & 8 deletions src/memcached/cache_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,13 +177,13 @@ func (this *rateLimitMemcacheImpl) Flush() {
this.waitGroup.Wait()
}

func refreshServersPeriodically(serverList *memcache.ServerList, srv string, d time.Duration, finish <-chan struct{}) {
func refreshServersPeriodically(serverList *memcache.ServerList, srv string, d time.Duration, resolver srv.SrvResolver, finish <-chan struct{}) {
t := time.NewTicker(d)
defer t.Stop()
for {
select {
case <-t.C:
err := refreshServers(serverList, srv)
err := refreshServers(serverList, srv, resolver)
if err != nil {
logger.Warn("failed to refresh memcahce hosts")
} else {
Expand All @@ -195,8 +195,8 @@ func refreshServersPeriodically(serverList *memcache.ServerList, srv string, d t
}
}

func refreshServers(serverList *memcache.ServerList, srv_ string) error {
servers, err := srv.ServerStringsFromSrv(srv_)
func refreshServers(serverList *memcache.ServerList, srv string, resolver srv.SrvResolver) error {
servers, err := resolver.ServerStringsFromSrv(srv)
if err != nil {
return err
}
Expand All @@ -207,9 +207,9 @@ func refreshServers(serverList *memcache.ServerList, srv_ string) error {
return nil
}

func newMemcachedFromSrv(srv_ string, d time.Duration) Client {
func newMemcachedFromSrv(srv string, d time.Duration, resolver srv.SrvResolver) Client {
serverList := new(memcache.ServerList)
err := refreshServers(serverList, srv_)
err := refreshServers(serverList, srv, resolver)
if err != nil {
errorText := "Unable to fetch servers from SRV"
logger.Errorf(errorText)
Expand All @@ -219,7 +219,7 @@ func newMemcachedFromSrv(srv_ string, d time.Duration) Client {
if d > 0 {
logger.Infof("refreshing memcache hosts every: %v milliseconds", d.Milliseconds())
finish := make(chan struct{})
go refreshServersPeriodically(serverList, srv_, d, finish)
go refreshServersPeriodically(serverList, srv, d, resolver, finish)
} else {
logger.Debugf("not periodically refreshing memcached hosts")
}
Expand All @@ -233,7 +233,7 @@ func newMemcacheFromSettings(s settings.Settings) Client {
}
if s.MemcacheSrv != "" {
logger.Debugf("Using MEMCACHE_SRV: %v", s.MemcacheSrv)
return newMemcachedFromSrv(s.MemcacheSrv, s.MemcacheSrvRefresh)
return newMemcachedFromSrv(s.MemcacheSrv, s.MemcacheSrvRefresh, new(srv.DnsSrvResolver))
}
logger.Debugf("Usng MEMCACHE_HOST_PORT:: %v", s.MemcacheHostPort)
client := memcache.New(s.MemcacheHostPort...)
Expand Down
94 changes: 94 additions & 0 deletions src/memcached/cache_impl_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package memcached

import (
"errors"
"net"
"testing"

"github.com/bradfitz/gomemcache/memcache"
"github.com/stretchr/testify/assert"

"github.com/golang/mock/gomock"

mock_srv "github.com/envoyproxy/ratelimit/test/mocks/srv"
)

func TestRefreshServersSetsServersOnEmptyServerList(t *testing.T) {
assert := assert.New(t)

mockSrv := "_memcache._tcp.example.org"
mockMemcacheHostPort := "127.0.0.1:11211"
ctrl := gomock.NewController(t)
defer ctrl.Finish()

mockSrvResolver := mock_srv.NewMockSrvResolver(ctrl)

mockSrvResolver.EXPECT().ServerStringsFromSrv(gomock.Eq(mockSrv)).Return([]string{mockMemcacheHostPort}, nil)

serverList := new(memcache.ServerList)

refreshServers(serverList, mockSrv, mockSrvResolver)

actualMemcacheHosts := []string{}

serverList.Each(func(addr net.Addr) error {
actualMemcacheHosts = append(actualMemcacheHosts, addr.String())
return nil
})

assert.Equal([]string{mockMemcacheHostPort}, actualMemcacheHosts)
}

func TestRefreshServersOverridesServersOnNonEmptyServerList(t *testing.T) {
assert := assert.New(t)

mockSrv := "_memcache._tcp.example.org"
mockMemcacheHostPort := "127.0.0.1:11211"
ctrl := gomock.NewController(t)
defer ctrl.Finish()

mockSrvResolver := mock_srv.NewMockSrvResolver(ctrl)

mockSrvResolver.EXPECT().ServerStringsFromSrv(gomock.Eq(mockSrv)).Return([]string{mockMemcacheHostPort}, nil)

serverList := new(memcache.ServerList)
serverList.SetServers("127.0.0.2:11211", "127.0.0.3:11211")

refreshServers(serverList, mockSrv, mockSrvResolver)

actualMemcacheHosts := []string{}

serverList.Each(func(addr net.Addr) error {
actualMemcacheHosts = append(actualMemcacheHosts, addr.String())
return nil
})

assert.Equal([]string{mockMemcacheHostPort}, actualMemcacheHosts)
}

func TestRefreshServerSetsServersDoesNotChangeAnythingIfThereIsAnError(t *testing.T) {
assert := assert.New(t)

mockSrv := "_memcache._tcp.example.org"
ctrl := gomock.NewController(t)
defer ctrl.Finish()

mockSrvResolver := mock_srv.NewMockSrvResolver(ctrl)

mockSrvResolver.EXPECT().ServerStringsFromSrv(gomock.Eq(mockSrv)).Return(nil, errors.New("some error"))

originalServers := []string{"127.0.0.2:11211", "127.0.0.3:11211"}
serverList := new(memcache.ServerList)
serverList.SetServers(originalServers...)

refreshServers(serverList, mockSrv, mockSrvResolver)

actualMemcacheHosts := []string{}

serverList.Each(func(addr net.Addr) error {
actualMemcacheHosts = append(actualMemcacheHosts, addr.String())
return nil
})

assert.Equal(originalServers, actualMemcacheHosts)
}
8 changes: 7 additions & 1 deletion src/srv/srv.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@ import (

var srvRegex = regexp.MustCompile(`^_(.+?)\._(.+?)\.(.+)$`)

type SrvResolver interface {
ServerStringsFromSrv(srv string) ([]string, error)
}

type DnsSrvResolver struct{}

func ParseSrv(srv string) (string, string, string, error) {
matches := srvRegex.FindStringSubmatch(srv)
if matches == nil {
Expand All @@ -21,7 +27,7 @@ func ParseSrv(srv string) (string, string, string, error) {
return matches[1], matches[2], matches[3], nil
}

func ServerStringsFromSrv(srv string) ([]string, error) {
func (dnsSrvResolver DnsSrvResolver) ServerStringsFromSrv(srv string) ([]string, error) {
service, proto, name, err := ParseSrv(srv)
if err != nil {
logger.Errorf("failed to parse SRV: %s", err)
Expand Down
1 change: 1 addition & 0 deletions test/mocks/mocks.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ package mocks
//go:generate go run github.com/golang/mock/mockgen -destination ./utils/utils.go github.com/envoyproxy/ratelimit/src/utils TimeSource,JitterRandSource
//go:generate go run github.com/golang/mock/mockgen -destination ./memcached/client.go github.com/envoyproxy/ratelimit/src/memcached Client
//go:generate go run github.com/golang/mock/mockgen -destination ./rls/rls.go github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3 RateLimitServiceServer
//go:generate go run github.com/golang/mock/mockgen -destination ./srv/srv.go github.com/envoyproxy/ratelimit/src/srv SrvResolver
49 changes: 49 additions & 0 deletions test/mocks/srv/srv.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 6 additions & 3 deletions test/srv/srv_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,14 @@ func TestParseSrv(t *testing.T) {
}

func TestServerStringsFromSrvWhenSrvIsNotWellFormed(t *testing.T) {
_, err := srv.ServerStringsFromSrv("example.org")
srvResolver := srv.DnsSrvResolver{}
_, err := srvResolver.ServerStringsFromSrv("example.org")
assert.Equal(t, err, errors.New("could not parse example.org to SRV parts"))
}

func TestServerStringsFromSevWhenSrvIsWellFormedButNotLookupable(t *testing.T) {
_, err := srv.ServerStringsFromSrv("_something._tcp.example.invalid")
srvResolver := srv.DnsSrvResolver{}
_, err := srvResolver.ServerStringsFromSrv("_something._tcp.example.invalid")
var e *net.DNSError
if errors.As(err, &e) {
assert.Equal(t, e.Err, "no such host")
Expand All @@ -48,7 +50,8 @@ func TestServerStringsFromSevWhenSrvIsWellFormedButNotLookupable(t *testing.T) {

func TestServerStrings(t *testing.T) {
// it seems reasonable to think _xmpp-server._tcp.gmail.com will be available for a long time!
servers, err := srv.ServerStringsFromSrv("_xmpp-server._tcp.gmail.com.")
srvResolver := srv.DnsSrvResolver{}
servers, err := srvResolver.ServerStringsFromSrv("_xmpp-server._tcp.gmail.com.")
assert.True(t, len(servers) > 0)
for _, s := range servers {
assert.Regexp(t, `^.*xmpp-server.*google.com.:\d+$`, s)
Expand Down

0 comments on commit 8844fe8

Please sign in to comment.