Skip to content

Commit

Permalink
Allow a comma-separated list of local prefixes, like goimports (#33)
Browse files Browse the repository at this point in the history
Signed-off-by: Luke Shumaker <lukeshu@datawire.io>
  • Loading branch information
LukeShu authored Jun 24, 2021
1 parent 9b479ee commit 59beec6
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 8 deletions.
6 changes: 4 additions & 2 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ var (
doWrite = flag.Bool("w", false, "doWrite result to (source) file instead of stdout")
doDiff = flag.Bool("d", false, "display diffs instead of rewriting files")

localFlag string
localFlag []string

exitCode = 0
)
Expand All @@ -27,9 +27,11 @@ func report(err error) {
}

func parseFlags() []string {
flag.StringVar(&localFlag, "local", "", "put imports beginning with this string after 3rd-party packages, only support one string")
var localFlagStr string
flag.StringVar(&localFlagStr, "local", "", "put imports beginning with this string after 3rd-party packages; comma-separated list")

flag.Parse()
localFlag = gci.ParseLocalFlag(localFlagStr)
return flag.Args()
}

Expand Down
20 changes: 15 additions & 5 deletions pkg/gci/gci.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ import (
)

type FlagSet struct {
LocalFlag string
LocalFlag []string
DoWrite, DoDiff *bool
}

Expand All @@ -42,7 +42,15 @@ type pkg struct {
alias map[string]string
}

func newPkg(data [][]byte, localFlag string) *pkg {
// ParseLocalFlag takes a comma-separated list of
// package-name-prefixes (as passed to the "-local" flag), and splits
// it in to a list. This is different than strings.Split in that it
// handles the empty string and empty entries in the list.
func ParseLocalFlag(str string) []string {
return strings.FieldsFunc(str, func(c rune) bool { return c == ',' })
}

func newPkg(data [][]byte, localFlag []string) *pkg {
listMap := make(map[int][]string)
commentMap := make(map[string]string)
aliasMap := make(map[string]string)
Expand Down Expand Up @@ -156,11 +164,13 @@ func getPkgInfo(line string, comment bool) (string, string, string) {
}
}

func getPkgType(line, localFlag string) int {
func getPkgType(line string, localFlag []string) int {
pkgName := strings.Trim(line, "\"\\`")

if localFlag != "" && strings.HasPrefix(pkgName, localFlag) {
return local
for _, localPkg := range localFlag {
if strings.HasPrefix(pkgName, localPkg) {
return local
}
}

if isStandardPackage(pkgName) {
Expand Down
18 changes: 17 additions & 1 deletion pkg/gci/gci_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,31 +15,47 @@ func TestGetPkgType(t *testing.T) {
{Line: `"foo/pkg/bar"`, LocalFlag: "foo", ExpectedResult: local},
{Line: `"foo/pkg/bar"`, LocalFlag: "bar", ExpectedResult: remote},
{Line: `"foo/pkg/bar"`, LocalFlag: "github.com/foo/bar", ExpectedResult: remote},
{Line: `"foo/pkg/bar"`, LocalFlag: "github.com/foo", ExpectedResult: remote},
{Line: `"foo/pkg/bar"`, LocalFlag: "github.com/bar", ExpectedResult: remote},
{Line: `"foo/pkg/bar"`, LocalFlag: "github.com/foo,github.com/bar", ExpectedResult: remote},
{Line: `"foo/pkg/bar"`, LocalFlag: "github.com/foo,,github.com/bar", ExpectedResult: remote},

{Line: `"github.com/foo/bar"`, LocalFlag: "", ExpectedResult: remote},
{Line: `"github.com/foo/bar"`, LocalFlag: "foo", ExpectedResult: remote},
{Line: `"github.com/foo/bar"`, LocalFlag: "bar", ExpectedResult: remote},
{Line: `"github.com/foo/bar"`, LocalFlag: "github.com/foo/bar", ExpectedResult: local},
{Line: `"github.com/foo/bar"`, LocalFlag: "github.com/foo", ExpectedResult: local},
{Line: `"github.com/foo/bar"`, LocalFlag: "github.com/bar", ExpectedResult: remote},
{Line: `"github.com/foo/bar"`, LocalFlag: "github.com/foo,github.com/bar", ExpectedResult: local},
{Line: `"github.com/foo/bar"`, LocalFlag: "github.com/foo,,github.com/bar", ExpectedResult: local},

{Line: `"context"`, LocalFlag: "", ExpectedResult: standard},
{Line: `"context"`, LocalFlag: "context", ExpectedResult: local},
{Line: `"context"`, LocalFlag: "foo", ExpectedResult: standard},
{Line: `"context"`, LocalFlag: "bar", ExpectedResult: standard},
{Line: `"context"`, LocalFlag: "github.com/foo/bar", ExpectedResult: standard},
{Line: `"context"`, LocalFlag: "github.com/foo", ExpectedResult: standard},
{Line: `"context"`, LocalFlag: "github.com/bar", ExpectedResult: standard},
{Line: `"context"`, LocalFlag: "github.com/foo,github.com/bar", ExpectedResult: standard},
{Line: `"context"`, LocalFlag: "github.com/foo,,github.com/bar", ExpectedResult: standard},

{Line: `"os/signal"`, LocalFlag: "", ExpectedResult: standard},
{Line: `"os/signal"`, LocalFlag: "os/signal", ExpectedResult: local},
{Line: `"os/signal"`, LocalFlag: "foo", ExpectedResult: standard},
{Line: `"os/signal"`, LocalFlag: "bar", ExpectedResult: standard},
{Line: `"os/signal"`, LocalFlag: "github.com/foo/bar", ExpectedResult: standard},
{Line: `"os/signal"`, LocalFlag: "github.com/foo", ExpectedResult: standard},
{Line: `"os/signal"`, LocalFlag: "github.com/bar", ExpectedResult: standard},
{Line: `"os/signal"`, LocalFlag: "github.com/foo,github.com/bar", ExpectedResult: standard},
{Line: `"os/signal"`, LocalFlag: "github.com/foo,,github.com/bar", ExpectedResult: standard},
}

for _, tc := range testCases {
tc := tc
t.Run(fmt.Sprintf("%s:%s", tc.Line, tc.LocalFlag), func(t *testing.T) {
t.Parallel()

result := getPkgType(tc.Line, tc.LocalFlag)
result := getPkgType(tc.Line, ParseLocalFlag(tc.LocalFlag))
if got, want := result, tc.ExpectedResult; got != want {
t.Errorf("bad result: %d, expected: %d", got, want)
}
Expand Down

0 comments on commit 59beec6

Please sign in to comment.