From 071c0c156c50a9dcba3cf2392ea3903f07f3b801 Mon Sep 17 00:00:00 2001 From: Sergey Kacheev Date: Mon, 19 Jul 2021 01:31:21 +0700 Subject: [PATCH] netutil: add url comparison without resolver to URLStringsEqual If one of the nodes in the cluster has lost a dns record, restarting the second node will break it. This PR makes an attempt to add a comparison without using a resolver, which allows to protect cluster from dns errors and does not break the current logic of comparing urls in the URLStringsEqual function. You can read more in the issue #7798 Fixes #7798 --- pkg/netutil/netutil.go | 42 ++++++++++++++++++++++++------------- pkg/netutil/netutil_test.go | 33 ++++++++++++++++++++++++----- 2 files changed, 55 insertions(+), 20 deletions(-) diff --git a/pkg/netutil/netutil.go b/pkg/netutil/netutil.go index faef6466eeb..f6a87317b65 100644 --- a/pkg/netutil/netutil.go +++ b/pkg/netutil/netutil.go @@ -174,21 +174,13 @@ func URLStringsEqual(ctx context.Context, lg *zap.Logger, a []string, b []string if len(a) != len(b) { return false, fmt.Errorf("len(%q) != len(%q)", a, b) } - urlsA := make([]url.URL, 0) - for _, str := range a { - u, err := url.Parse(str) - if err != nil { - return false, fmt.Errorf("failed to parse %q", str) - } - urlsA = append(urlsA, *u) + urlsA, err := stringsToURLs(a) + if err != nil { + return false, err } - urlsB := make([]url.URL, 0) - for _, str := range b { - u, err := url.Parse(str) - if err != nil { - return false, fmt.Errorf("failed to parse %q", str) - } - urlsB = append(urlsB, *u) + urlsB, err := stringsToURLs(b) + if err != nil { + return false, err } if lg == nil { lg, _ = zap.NewProduction() @@ -196,7 +188,15 @@ func URLStringsEqual(ctx context.Context, lg *zap.Logger, a []string, b []string lg = zap.NewExample() } } - return urlsEqual(ctx, lg, urlsA, urlsB) + sort.Sort(types.URLs(urlsA)) + sort.Sort(types.URLs(urlsB)) + for i := range urlsA { + if !reflect.DeepEqual(urlsA[i], urlsB[i]) { + // If urls are not equal, try to resolve it and compare again. + return urlsEqual(ctx, lg, urlsA, urlsB) + } + } + return true, nil } func urlsToStrings(us []url.URL) []string { @@ -207,6 +207,18 @@ func urlsToStrings(us []url.URL) []string { return rs } +func stringsToURLs(us []string) ([]url.URL, error) { + urls := make([]url.URL, 0, len(us)) + for _, str := range us { + u, err := url.Parse(str) + if err != nil { + return nil, fmt.Errorf("failed to parse %q", str) + } + urls = append(urls, *u) + } + return urls, nil +} + func IsNetworkTimeoutError(err error) bool { nerr, ok := err.(net.Error) return ok && nerr.Timeout() diff --git a/pkg/netutil/netutil_test.go b/pkg/netutil/netutil_test.go index 42b05ca295a..7d1d17aa269 100644 --- a/pkg/netutil/netutil_test.go +++ b/pkg/netutil/netutil_test.go @@ -17,6 +17,7 @@ package netutil import ( "context" "errors" + "fmt" "net" "net/url" "reflect" @@ -292,11 +293,33 @@ func TestURLsEqual(t *testing.T) { } } func TestURLStringsEqual(t *testing.T) { - result, err := URLStringsEqual(context.TODO(), zap.NewExample(), []string{"http://127.0.0.1:8080"}, []string{"http://127.0.0.1:8080"}) - if !result { - t.Errorf("unexpected result %v", result) + defer func() { resolveTCPAddr = resolveTCPAddrDefault }() + errOnResolve := func(ctx context.Context, addr string) (*net.TCPAddr, error) { + return nil, fmt.Errorf("unexpected attempt to resolve: %q", addr) + } + cases := []struct { + urlsA []string + urlsB []string + resolver func(ctx context.Context, addr string) (*net.TCPAddr, error) + }{ + {[]string{"http://127.0.0.1:8080"}, []string{"http://127.0.0.1:8080"}, resolveTCPAddrDefault}, + {[]string{ + "http://host1:8080", + "http://host2:8080", + }, []string{ + "http://host1:8080", + "http://host2:8080", + }, errOnResolve}, } - if err != nil { - t.Errorf("unexpected error %v", err) + for idx, c := range cases { + t.Logf("TestURLStringsEqual, case #%d", idx) + resolveTCPAddr = c.resolver + result, err := URLStringsEqual(context.TODO(), zap.NewExample(), c.urlsA, c.urlsB) + if !result { + t.Errorf("unexpected result %v", result) + } + if err != nil { + t.Errorf("unexpected error %v", err) + } } }