diff --git a/internal/imports/imports_linux.go b/internal/imports/imports_linux.go index 02eb8e2..c105508 100644 --- a/internal/imports/imports_linux.go +++ b/internal/imports/imports_linux.go @@ -13,6 +13,7 @@ import ( _ "github.com/networkservicemesh/sdk/pkg/registry/common/authorize" _ "github.com/networkservicemesh/sdk/pkg/tools/debug" _ "github.com/networkservicemesh/sdk/pkg/tools/grpcutils" + _ "github.com/networkservicemesh/sdk/pkg/tools/listenonurl" _ "github.com/networkservicemesh/sdk/pkg/tools/log" _ "github.com/networkservicemesh/sdk/pkg/tools/log/logruslogger" _ "github.com/networkservicemesh/sdk/pkg/tools/opentelemetry" diff --git a/main.go b/main.go index 2b4fc5e..f88eece 100644 --- a/main.go +++ b/main.go @@ -22,7 +22,6 @@ package main import ( "context" "crypto/tls" - "fmt" "net" "net/url" "os" @@ -50,6 +49,7 @@ import ( "github.com/networkservicemesh/sdk/pkg/tools/debug" "github.com/networkservicemesh/sdk/pkg/tools/grpcutils" + "github.com/networkservicemesh/sdk/pkg/tools/listenonurl" "github.com/networkservicemesh/sdk/pkg/tools/log" "github.com/networkservicemesh/sdk/pkg/tools/log/logruslogger" ) @@ -159,7 +159,7 @@ func main() { grpcfd.WithChainUnaryInterceptor(), ) - listenURL := getPublicURL(defaultURL(config)) + listenURL := getPublishableURL(config.ListenOn, log.FromContext(ctx)) log.FromContext(ctx).Infof("Listening url: %v", listenURL) @@ -200,32 +200,21 @@ func exitOnErr(ctx context.Context, cancel context.CancelFunc, errCh <-chan erro }(ctx, errCh) } -func defaultURL(c *Config) *url.URL { - for i := 0; i < len(c.ListenOn); i++ { - u := &c.ListenOn[i] - if u.Scheme == "tcp" { - return u - } - } - return &c.ListenOn[0] -} - -func getPublicURL(u *url.URL) *url.URL { - if u.Port() == "" || len(u.Host) != len(":")+len(u.Port()) { - return u - } +func getPublishableURL(listenOn []url.URL, logger log.Logger) *url.URL { + u := defaultURL(listenOn) addrs, err := net.InterfaceAddrs() if err != nil { - logrus.Warn(err.Error()) + logger.Warn(err.Error()) return u } - for _, a := range addrs { - if ipnet, ok := a.(*net.IPNet); ok && !ipnet.IP.IsLoopback() { - if ipnet.IP.To4() != nil { - r, _ := url.Parse(fmt.Sprintf("tcp://%v:%v", ipnet.IP.String(), u.Port())) - return r - } + return listenonurl.GetPublicURL(addrs, u) +} +func defaultURL(listenOn []url.URL) *url.URL { + for i := 0; i < len(listenOn); i++ { + u := &listenOn[i] + if u.Scheme == "tcp" { + return u } } - return u + return &listenOn[0] }