diff --git a/router/router.go b/router/router.go index 148a64be..a69c403e 100644 --- a/router/router.go +++ b/router/router.go @@ -46,6 +46,11 @@ type Router interface { // time it will inform the caller if the client is a remote client or a // local client via the isRemote return value. GetClient(key string) (client interface{}, isRemote bool, err error) + // GetNClients provides the caller with an ordered slice of clients for a + // given key. Each result is a struct with a reference to the actual client + // and a bool indicating whether or not that client is a remote client or a + // local client. + GetNClients(key string, n int) (clients []ClientResult, err error) } // A ClientFactory is able to provide an implementation of a TChan[Service] @@ -101,6 +106,36 @@ func (r *router) GetClient(key string) (client interface{}, isRemote bool, err e return nil, false, err } + return r.getClientByHost(dest) +} + +// ClientResult is a struct that contains a reference to the actual callable +// client and a bool indicating whether or not that client is local or remote. +type ClientResult struct { + HostPort string + Client interface{} + IsRemote bool +} + +func (r *router) GetNClients(key string, n int) ([]ClientResult, error) { + dests, err := r.ringpop.LookupN(key, n) + if err != nil { + return nil, err + } + + clients := make([]ClientResult, n, n) + + for i, dest := range dests { + client, isRemote, err := r.getClientByHost(dest) + if err != nil { + return nil, err + } + clients[i] = ClientResult{dest, client, isRemote} + } + return clients, nil +} + +func (r *router) getClientByHost(dest string) (client interface{}, isRemote bool, err error) { r.rw.RLock() cachedEntry, ok := r.clientCache[dest] r.rw.RUnlock() diff --git a/router/router_test.go b/router/router_test.go index 00fb6c42..23f2f69a 100644 --- a/router/router_test.go +++ b/router/router_test.go @@ -52,6 +52,9 @@ func (s *RouterTestSuite) SetupTest() { s.ringpop.On("Lookup", "remote2").Return("127.0.0.1:3001", nil) s.ringpop.On("Lookup", "error").Return("", errors.New("ringpop not ready")) + s.ringpop.On("LookupN", "localfirst", 2).Return([]string{"127.0.0.1:3000", "127.0.0.1:3001"}, nil) + s.ringpop.On("LookupN", "remotefirst", 2).Return([]string{"127.0.0.1:3001", "127.0.0.1:3000"}, nil) + ch, err := tchannel.NewChannel("remote", nil) s.NoError(err) @@ -67,6 +70,28 @@ func (s *RouterTestSuite) TestRingpopRouterGetLocalClient() { s.clientFactory.AssertCalled(s.T(), "GetLocalClient") } +func (s *RouterTestSuite) TestRingpopRouterGetNClientsLocalFirst() { + clients, err := s.router.GetNClients("localfirst", 2) + s.NoError(err) + + s.Equal("local client", clients[0].Client) + s.Equal("remote client", clients[1].Client) + + s.False(clients[0].IsRemote, "first client for localfirst key should not be a remote client") + s.True(clients[1].IsRemote, "second client for localfirst key should not be a local client") +} + +func (s *RouterTestSuite) TestRingpopRouterGetNClientsRemoteFirst() { + clients, err := s.router.GetNClients("remotefirst", 2) + s.NoError(err) + + s.Equal("remote client", clients[0].Client) + s.Equal("local client", clients[1].Client) + + s.True(clients[0].IsRemote, "first client for localfirst key should not be a local client") + s.False(clients[1].IsRemote, "second client for localfirst key should not be a remote client") +} + func (s *RouterTestSuite) TestRingpopRouterGetLocalClientCached() { client, isRemote, err := s.router.GetClient("local") s.NoError(err) @@ -168,7 +193,7 @@ func (s *RouterTestSuite) TestRingpopRouterRemoveClientOnSwimFaultyEvent() { s.NoError(err) s.internal.HandleEvent(swim.MemberlistChangesReceivedEvent{ Changes: []swim.Change{ - swim.Change{ + { Address: dest, Status: swim.Faulty, }, @@ -191,7 +216,7 @@ func (s *RouterTestSuite) TestRingpopRouterRemoveClientOnSwimLeaveEvent() { s.NoError(err) s.internal.HandleEvent(swim.MemberlistChangesReceivedEvent{ Changes: []swim.Change{ - swim.Change{ + { Address: dest, Status: swim.Leave, }, @@ -214,7 +239,7 @@ func (s *RouterTestSuite) TestRingpopRouterNotRemoveClientOnSwimSuspectEvent() { s.NoError(err) s.internal.HandleEvent(swim.MemberlistChangesReceivedEvent{ Changes: []swim.Change{ - swim.Change{ + { Address: dest, Status: swim.Suspect, },