Skip to content

Commit

Permalink
Fallback ip (#1147)
Browse files Browse the repository at this point in the history
Signed-off-by: Cody Littley <cody@eigenlabs.org>
  • Loading branch information
cody-littley authored Jan 23, 2025
1 parent 34ad649 commit cc009aa
Show file tree
Hide file tree
Showing 11 changed files with 356 additions and 94 deletions.
17 changes: 17 additions & 0 deletions common/pubip/mock_provider.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package pubip

import "context"

var _ Provider = (*mockProvider)(nil)

// mockProvider is a mock implementation of the Provider interface.
type mockProvider struct {
}

func (m mockProvider) Name() string {
return "mockip"
}

func (m mockProvider) PublicIPAddress(ctx context.Context) (string, error) {
return "localhost", nil
}
52 changes: 52 additions & 0 deletions common/pubip/multi_provider.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package pubip

import (
"context"
"fmt"
"github.com/Layr-Labs/eigensdk-go/logging"
"strings"
)

var _ Provider = (*multiProvider)(nil)

// An implementation of Provider that uses multiple providers. It attempts each provider in order until one succeeds.
type multiProvider struct {
logger logging.Logger
providers []Provider
}

func (m *multiProvider) Name() string {
sb := strings.Builder{}
sb.WriteString("multiProvider(")
for i, provider := range m.providers {
sb.WriteString(provider.Name())
if i < len(m.providers)-1 {
sb.WriteString(", ")
}
}
sb.WriteString(")")
return sb.String()
}

// NewMultiProvider creates a new multiProvider with the given providers.
func NewMultiProvider(
logger logging.Logger,
providers ...Provider) Provider {

return &multiProvider{
logger: logger,
providers: providers,
}
}

func (m *multiProvider) PublicIPAddress(ctx context.Context) (string, error) {
for _, provider := range m.providers {
ip, err := provider.PublicIPAddress(ctx)
if err == nil {
return ip, nil
}
m.logger.Warnf("failed to get public IP address from %s: %v", provider, err)
}

return "", fmt.Errorf("failed to get public IP address from any provider")
}
123 changes: 53 additions & 70 deletions common/pubip/pubip.go
Original file line number Diff line number Diff line change
@@ -1,97 +1,80 @@
package pubip

import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net/http"
"github.com/Layr-Labs/eigensdk-go/logging"
"strings"
)

const (
SeepIPProvider = "seeip"
IpifyProvider = "ipify"
MockIpProvider = "mockip"
)

var (
SeeIP = &SimpleProvider{Name: "seeip", URL: "https://api.seeip.org"}
Ipify = &SimpleProvider{Name: "ipify", URL: "https://api.ipify.org"}
MockIp = &SimpleProvider{Name: "mockip", URL: ""}
)

type RequestDoer interface {
Do(req *http.Request) (*http.Response, error)
}

type RequestDoerFunc func(req *http.Request) (*http.Response, error)
SeeIPURL = "https://api.seeip.org"

var _ RequestDoer = (RequestDoerFunc)(nil)
IpifyProvider = "ipify"
IpifyURL = "https://api.ipify.org"

func (f RequestDoerFunc) Do(req *http.Request) (*http.Response, error) {
return f(req)
}
MockIpProvider = "mockip"
)

// Provider is an interface for getting a machine's public IP address.
type Provider interface {
// Name returns the name of the provider
Name() string
// PublicIPAddress returns the public IP address of the node
PublicIPAddress(ctx context.Context) (string, error)
}

type SimpleProvider struct {
RequestDoer RequestDoer
Name string
URL string
}

var _ Provider = (*SimpleProvider)(nil)

func (s *SimpleProvider) PublicIPAddress(ctx context.Context) (string, error) {
if s.Name == MockIpProvider {
return "localhost", nil
// buildSimpleProviderByName returns a simple provider with the given name.
// Returns nil if the name is not recognized.
func buildSimpleProviderByName(name string) Provider {
if name == SeepIPProvider {
return NewSimpleProvider(SeepIPProvider, SeeIPURL)
} else if name == IpifyProvider {
return NewSimpleProvider(IpifyProvider, IpifyURL)
} else if name == MockIpProvider {
return &mockProvider{}
}
ip, err := s.doRequest(ctx, s.URL)
if err != nil {
return "", fmt.Errorf("%s: failed to retrieve public ip address: %w", s.Name, err)
}
return ip, nil
return nil
}

func (s *SimpleProvider) doRequest(ctx context.Context, url string) (string, error) {
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return "", err
}
// buildDefaultProviders returns a default provider.
func buildDefaultProvider(logger logging.Logger) Provider {
return NewMultiProvider(logger, buildSimpleProviderByName(SeepIPProvider), buildSimpleProviderByName(IpifyProvider))
}

if s.RequestDoer == nil {
s.RequestDoer = http.DefaultClient
}
resp, err := s.RequestDoer.Do(req)
if err != nil {
return "", err
}
defer func() { _ = resp.Body.Close() }()
func providerOrDefault(logger logging.Logger, names ...string) Provider {

if resp.StatusCode >= http.StatusBadRequest {
return "", errors.New(resp.Status)
for i, name := range names {
names[i] = strings.ToLower(strings.TrimSpace(name))
}

var b bytes.Buffer
_, err = io.Copy(&b, resp.Body)
if err != nil {
return "", err
if len(names) == 0 {
return buildDefaultProvider(logger)
} else if len(names) == 1 {
provider := buildSimpleProviderByName(names[0])
if provider == nil {
logger.Warnf("Unknown IP provider '%s'", names[0])
return buildDefaultProvider(logger)
}
return provider
} else {
providers := make([]Provider, len(names))
for i, name := range names {
providers[i] = buildSimpleProviderByName(name)
if providers[i] == nil {
logger.Warnf("Unknown IP provider '%s'", name)
return buildDefaultProvider(logger)
}
}

return NewMultiProvider(logger, providers...)
}
return strings.TrimSpace(b.String()), nil
}

func ProviderOrDefault(name string) Provider {
p := map[string]Provider{
SeepIPProvider: SeeIP,
IpifyProvider: Ipify,
MockIpProvider: MockIp,
}[name]
if p == nil {
p = SeeIP
}
return p
// ProviderOrDefault returns a provider with the provided name, or a default provider if the name is not recognized.
// Provider strings are not case-sensitive.
func ProviderOrDefault(logger logging.Logger, names ...string) Provider {
provider := providerOrDefault(logger, names...)
logger.Infof("Using IP provider '%s'", provider.Name())
return provider
}
147 changes: 136 additions & 11 deletions common/pubip/pubip_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,145 @@ package pubip

import (
"context"
"fmt"
"github.com/Layr-Labs/eigenda/common"
"github.com/Layr-Labs/eigenda/common/testutils/random"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"net/http"
"net/http/httptest"
"testing"
)

func TestProviderOrDefault(t *testing.T) {
p := ProviderOrDefault(SeepIPProvider)
assert.Equal(t, SeeIP, p)
p = ProviderOrDefault(IpifyProvider)
assert.Equal(t, Ipify, p)
p = ProviderOrDefault("test")
assert.Equal(t, SeeIP, p)
logger, err := common.NewLogger(common.DefaultLoggerConfig())
require.NoError(t, err)

provider := ProviderOrDefault(logger, SeepIPProvider)
require.Equal(t, SeepIPProvider, provider.Name())
seeIPProvider, ok := provider.(*simpleProvider)
require.True(t, ok)
require.Equal(t, SeeIPURL, seeIPProvider.URL)

provider = ProviderOrDefault(logger, IpifyProvider)
require.Equal(t, IpifyProvider, provider.Name())
ipifyProvider, ok := provider.(*simpleProvider)
require.True(t, ok)
require.Equal(t, IpifyURL, ipifyProvider.URL)

provider = ProviderOrDefault(logger, MockIpProvider)
require.Equal(t, MockIpProvider, provider.Name())
_, ok = provider.(*mockProvider)
require.True(t, ok)

// invalid provider, should yield default
provider = ProviderOrDefault(logger, "this is not a supported provider")
require.Equal(t, fmt.Sprintf("multiProvider(%s, %s)", SeepIPProvider, IpifyProvider), provider.Name())
multi, ok := provider.(*multiProvider)
require.True(t, ok)
require.Equal(t, 2, len(multi.providers))
require.Equal(t, SeepIPProvider, multi.providers[0].Name())
require.Equal(t, IpifyProvider, multi.providers[1].Name())

provider = providerOrDefault(logger, SeepIPProvider, IpifyProvider)
require.Equal(t, fmt.Sprintf("multiProvider(%s, %s)", SeepIPProvider, IpifyProvider), provider.Name())
multi, ok = provider.(*multiProvider)
require.True(t, ok)
require.Equal(t, 2, len(multi.providers))
require.Equal(t, SeepIPProvider, multi.providers[0].Name())
require.Equal(t, IpifyProvider, multi.providers[1].Name())

provider = providerOrDefault(logger, IpifyProvider, SeepIPProvider, MockIpProvider)
require.Equal(t, fmt.Sprintf("multiProvider(%s, %s, %s)",
IpifyProvider, SeepIPProvider, MockIpProvider), provider.Name())
multi, ok = provider.(*multiProvider)
require.True(t, ok)
require.Equal(t, 3, len(multi.providers))
require.Equal(t, IpifyProvider, multi.providers[0].Name())
require.Equal(t, SeepIPProvider, multi.providers[1].Name())
require.Equal(t, MockIpProvider, multi.providers[2].Name())

// invalid provider, should yield default
provider = providerOrDefault(logger, IpifyProvider, "not a real provider", MockIpProvider)
require.Equal(t, fmt.Sprintf("multiProvider(%s, %s)", SeepIPProvider, IpifyProvider), provider.Name())
multi, ok = provider.(*multiProvider)
require.True(t, ok)
require.Equal(t, 2, len(multi.providers))
require.Equal(t, SeepIPProvider, multi.providers[0].Name())
require.Equal(t, IpifyProvider, multi.providers[1].Name())
}

var _ Provider = (*testProvider)(nil)

type testProvider struct {
// if true then this PublicIPAddress will return an error
returnErr bool

// number of times PublicIPAddress was called
count int

// ip address to return when PublicIPAddress is called
ip string
}

func (t *testProvider) Name() string {
return "test"
}

func (t *testProvider) PublicIPAddress(ctx context.Context) (string, error) {
t.count++
if t.returnErr {
return "", fmt.Errorf("intentional error")
}
return t.ip, nil
}

func TestMultiProvider(t *testing.T) {
rand := random.NewTestRandom(t)
logger, err := common.NewLogger(common.DefaultLoggerConfig())
require.NoError(t, err)

provider1 := &testProvider{
ip: rand.String(10),
}
provider2 := &testProvider{
ip: rand.String(10),
}
provider3 := &testProvider{
ip: rand.String(10),
}
provider := NewMultiProvider(logger, provider1, provider2, provider3)

ip, err := provider.PublicIPAddress(context.Background())
require.NoError(t, err)
require.Equal(t, 1, provider1.count)
require.Equal(t, 0, provider2.count)
require.Equal(t, 0, provider3.count)
require.Equal(t, provider1.ip, ip)

provider1.returnErr = true
ip, err = provider.PublicIPAddress(context.Background())
require.NoError(t, err)
require.Equal(t, 2, provider1.count)
require.Equal(t, 1, provider2.count)
require.Equal(t, 0, provider3.count)
require.Equal(t, provider2.ip, ip)

provider2.returnErr = true
ip, err = provider.PublicIPAddress(context.Background())
require.NoError(t, err)
require.Equal(t, 3, provider1.count)
require.Equal(t, 2, provider2.count)
require.Equal(t, 1, provider3.count)
require.Equal(t, provider3.ip, ip)

provider3.returnErr = true
ip, err = provider.PublicIPAddress(context.Background())
require.Error(t, err)
require.Equal(t, 4, provider1.count)
require.Equal(t, 3, provider2.count)
require.Equal(t, 2, provider3.count)
require.Equal(t, "", ip)
}

func TestSimpleProvider_PublicIPAddress(t *testing.T) {
Expand Down Expand Up @@ -48,11 +174,10 @@ func TestSimpleProvider_PublicIPAddress(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := SimpleProvider{
RequestDoer: tt.requestDoer,
Name: "test",
URL: "https://api.seeip.org",
}
p := CustomProvider(
tt.requestDoer,
"test",
"https://api.seeip.org")

ip, err := p.PublicIPAddress(context.Background())
assert.Equal(t, tt.expected, ip)
Expand Down
Loading

0 comments on commit cc009aa

Please sign in to comment.