From 59beec6e2fe1fd299a96d9182071dbb529b936fc Mon Sep 17 00:00:00 2001 From: Luke Shumaker Date: Wed, 23 Jun 2021 20:15:44 -0600 Subject: [PATCH] Allow a comma-separated list of local prefixes, like goimports (#33) Signed-off-by: Luke Shumaker --- main.go | 6 ++++-- pkg/gci/gci.go | 20 +++++++++++++++----- pkg/gci/gci_test.go | 18 +++++++++++++++++- 3 files changed, 36 insertions(+), 8 deletions(-) diff --git a/main.go b/main.go index 6884ae5..97df77a 100644 --- a/main.go +++ b/main.go @@ -13,7 +13,7 @@ var ( doWrite = flag.Bool("w", false, "doWrite result to (source) file instead of stdout") doDiff = flag.Bool("d", false, "display diffs instead of rewriting files") - localFlag string + localFlag []string exitCode = 0 ) @@ -27,9 +27,11 @@ func report(err error) { } func parseFlags() []string { - flag.StringVar(&localFlag, "local", "", "put imports beginning with this string after 3rd-party packages, only support one string") + var localFlagStr string + flag.StringVar(&localFlagStr, "local", "", "put imports beginning with this string after 3rd-party packages; comma-separated list") flag.Parse() + localFlag = gci.ParseLocalFlag(localFlagStr) return flag.Args() } diff --git a/pkg/gci/gci.go b/pkg/gci/gci.go index 47b777c..7efa576 100644 --- a/pkg/gci/gci.go +++ b/pkg/gci/gci.go @@ -32,7 +32,7 @@ import ( ) type FlagSet struct { - LocalFlag string + LocalFlag []string DoWrite, DoDiff *bool } @@ -42,7 +42,15 @@ type pkg struct { alias map[string]string } -func newPkg(data [][]byte, localFlag string) *pkg { +// ParseLocalFlag takes a comma-separated list of +// package-name-prefixes (as passed to the "-local" flag), and splits +// it in to a list. This is different than strings.Split in that it +// handles the empty string and empty entries in the list. +func ParseLocalFlag(str string) []string { + return strings.FieldsFunc(str, func(c rune) bool { return c == ',' }) +} + +func newPkg(data [][]byte, localFlag []string) *pkg { listMap := make(map[int][]string) commentMap := make(map[string]string) aliasMap := make(map[string]string) @@ -156,11 +164,13 @@ func getPkgInfo(line string, comment bool) (string, string, string) { } } -func getPkgType(line, localFlag string) int { +func getPkgType(line string, localFlag []string) int { pkgName := strings.Trim(line, "\"\\`") - if localFlag != "" && strings.HasPrefix(pkgName, localFlag) { - return local + for _, localPkg := range localFlag { + if strings.HasPrefix(pkgName, localPkg) { + return local + } } if isStandardPackage(pkgName) { diff --git a/pkg/gci/gci_test.go b/pkg/gci/gci_test.go index 27b2ed3..b8644b8 100644 --- a/pkg/gci/gci_test.go +++ b/pkg/gci/gci_test.go @@ -15,23 +15,39 @@ func TestGetPkgType(t *testing.T) { {Line: `"foo/pkg/bar"`, LocalFlag: "foo", ExpectedResult: local}, {Line: `"foo/pkg/bar"`, LocalFlag: "bar", ExpectedResult: remote}, {Line: `"foo/pkg/bar"`, LocalFlag: "github.com/foo/bar", ExpectedResult: remote}, + {Line: `"foo/pkg/bar"`, LocalFlag: "github.com/foo", ExpectedResult: remote}, + {Line: `"foo/pkg/bar"`, LocalFlag: "github.com/bar", ExpectedResult: remote}, + {Line: `"foo/pkg/bar"`, LocalFlag: "github.com/foo,github.com/bar", ExpectedResult: remote}, + {Line: `"foo/pkg/bar"`, LocalFlag: "github.com/foo,,github.com/bar", ExpectedResult: remote}, {Line: `"github.com/foo/bar"`, LocalFlag: "", ExpectedResult: remote}, {Line: `"github.com/foo/bar"`, LocalFlag: "foo", ExpectedResult: remote}, {Line: `"github.com/foo/bar"`, LocalFlag: "bar", ExpectedResult: remote}, {Line: `"github.com/foo/bar"`, LocalFlag: "github.com/foo/bar", ExpectedResult: local}, + {Line: `"github.com/foo/bar"`, LocalFlag: "github.com/foo", ExpectedResult: local}, + {Line: `"github.com/foo/bar"`, LocalFlag: "github.com/bar", ExpectedResult: remote}, + {Line: `"github.com/foo/bar"`, LocalFlag: "github.com/foo,github.com/bar", ExpectedResult: local}, + {Line: `"github.com/foo/bar"`, LocalFlag: "github.com/foo,,github.com/bar", ExpectedResult: local}, {Line: `"context"`, LocalFlag: "", ExpectedResult: standard}, {Line: `"context"`, LocalFlag: "context", ExpectedResult: local}, {Line: `"context"`, LocalFlag: "foo", ExpectedResult: standard}, {Line: `"context"`, LocalFlag: "bar", ExpectedResult: standard}, {Line: `"context"`, LocalFlag: "github.com/foo/bar", ExpectedResult: standard}, + {Line: `"context"`, LocalFlag: "github.com/foo", ExpectedResult: standard}, + {Line: `"context"`, LocalFlag: "github.com/bar", ExpectedResult: standard}, + {Line: `"context"`, LocalFlag: "github.com/foo,github.com/bar", ExpectedResult: standard}, + {Line: `"context"`, LocalFlag: "github.com/foo,,github.com/bar", ExpectedResult: standard}, {Line: `"os/signal"`, LocalFlag: "", ExpectedResult: standard}, {Line: `"os/signal"`, LocalFlag: "os/signal", ExpectedResult: local}, {Line: `"os/signal"`, LocalFlag: "foo", ExpectedResult: standard}, {Line: `"os/signal"`, LocalFlag: "bar", ExpectedResult: standard}, {Line: `"os/signal"`, LocalFlag: "github.com/foo/bar", ExpectedResult: standard}, + {Line: `"os/signal"`, LocalFlag: "github.com/foo", ExpectedResult: standard}, + {Line: `"os/signal"`, LocalFlag: "github.com/bar", ExpectedResult: standard}, + {Line: `"os/signal"`, LocalFlag: "github.com/foo,github.com/bar", ExpectedResult: standard}, + {Line: `"os/signal"`, LocalFlag: "github.com/foo,,github.com/bar", ExpectedResult: standard}, } for _, tc := range testCases { @@ -39,7 +55,7 @@ func TestGetPkgType(t *testing.T) { t.Run(fmt.Sprintf("%s:%s", tc.Line, tc.LocalFlag), func(t *testing.T) { t.Parallel() - result := getPkgType(tc.Line, tc.LocalFlag) + result := getPkgType(tc.Line, ParseLocalFlag(tc.LocalFlag)) if got, want := result, tc.ExpectedResult; got != want { t.Errorf("bad result: %d, expected: %d", got, want) }