From d17087deb30e9dd68fdc1bcf93b681e7aec2a292 Mon Sep 17 00:00:00 2001 From: Adrian Hesketh Date: Sun, 29 Oct 2023 19:37:41 +0000 Subject: [PATCH] chore: improve cmd unit testing --- .version | 2 +- cmd/templ/fmtcmd/main.go | 14 +-- cmd/templ/generatecmd/main.go | 47 ++++---- cmd/templ/lspcmd/main.go | 9 +- cmd/templ/main.go | 221 +++++++++++++++++++++++++--------- cmd/templ/main_test.go | 42 ++++++- cmd/templ/migratecmd/main.go | 17 +-- 7 files changed, 249 insertions(+), 103 deletions(-) diff --git a/.version b/.version index 5d54ff2cf..9d84d8d8f 100644 --- a/.version +++ b/.version @@ -1 +1 @@ -0.2.428 \ No newline at end of file +0.2.431 \ No newline at end of file diff --git a/cmd/templ/fmtcmd/main.go b/cmd/templ/fmtcmd/main.go index 2517eb1bb..c1948c038 100644 --- a/cmd/templ/fmtcmd/main.go +++ b/cmd/templ/fmtcmd/main.go @@ -15,16 +15,16 @@ import ( const workerCount = 4 -func Run(args []string) (err error) { +func Run(w io.Writer, args []string) (err error) { if len(args) > 0 { - return formatDir(args[0]) + return formatDir(w, args[0]) } - return formatStdin() + return formatReader(w, os.Stdin) } -func formatStdin() (err error) { +func formatReader(w io.Writer, r io.Reader) (err error) { var bytes []byte - bytes, err = io.ReadAll(os.Stdin) + bytes, err = io.ReadAll(r) if err != nil { return } @@ -32,14 +32,14 @@ func formatStdin() (err error) { if err != nil { return fmt.Errorf("parsing error: %w", err) } - err = t.Write(os.Stdout) + err = t.Write(w) if err != nil { return fmt.Errorf("formatting error: %w", err) } return nil } -func formatDir(dir string) (err error) { +func formatDir(w io.Writer, dir string) (err error) { start := time.Now() results := make(chan processor.Result) go processor.Process(dir, format, workerCount, results) diff --git a/cmd/templ/generatecmd/main.go b/cmd/templ/generatecmd/main.go index ee96143a8..234c6433d 100644 --- a/cmd/templ/generatecmd/main.go +++ b/cmd/templ/generatecmd/main.go @@ -8,6 +8,7 @@ import ( "errors" "fmt" "go/format" + "io" "net/http" "net/url" "os" @@ -48,7 +49,7 @@ type Arguments struct { var defaultWorkerCount = runtime.NumCPU() -func Run(args Arguments) (err error) { +func Run(w io.Writer, args Arguments) (err error) { ctx, cancel := context.WithCancel(context.Background()) signalChan := make(chan os.Signal, 1) signal.Notify(signalChan, os.Interrupt) @@ -64,10 +65,10 @@ func Run(args Arguments) (err error) { go func() { select { case <-signalChan: // First signal, cancel context. - fmt.Println("\nCancelling...") + fmt.Fprintln(w, "\nCancelling...") err = run.Stop() if err != nil { - fmt.Printf("Error killing command: %v\n", err) + fmt.Fprintf(w, "Error killing command: %v\n", err) } cancel() case <-ctx.Done(): @@ -75,14 +76,14 @@ func Run(args Arguments) (err error) { <-signalChan // Second signal, hard exit. os.Exit(2) }() - err = runCmd(ctx, args) + err = runCmd(ctx, w, args) if errors.Is(err, context.Canceled) { return nil } return err } -func runCmd(ctx context.Context, args Arguments) (err error) { +func runCmd(ctx context.Context, w io.Writer, args Arguments) (err error) { start := time.Now() if args.Watch && args.FileName != "" { return fmt.Errorf("cannot watch a single file, remove the -f or -watch flag") @@ -95,7 +96,7 @@ func runCmd(ctx context.Context, args Arguments) (err error) { opts = append(opts, generator.WithTimestamp(time.Now())) } if args.FileName != "" { - return processSingleFile(ctx, args.FileName, args.GenerateSourceMapVisualisations, opts) + return processSingleFile(ctx, w, args.FileName, args.GenerateSourceMapVisualisations, opts) } var target *url.URL if args.Proxy != "" { @@ -123,14 +124,14 @@ func runCmd(ctx context.Context, args Arguments) (err error) { p = proxy.New(args.ProxyPort, target) } - fmt.Println("Processing path:", args.Path) + fmt.Fprintln(w, "Processing path:", args.Path) bo := backoff.NewExponentialBackOff() bo.InitialInterval = time.Millisecond * 500 bo.MaxInterval = time.Second * 3 var firstRunComplete bool fileNameToLastModTime := make(map[string]time.Time) for !firstRunComplete || args.Watch { - changesFound, errs := processChanges(ctx, fileNameToLastModTime, args.Path, args.GenerateSourceMapVisualisations, opts, args.WorkerCount) + changesFound, errs := processChanges(ctx, w, fileNameToLastModTime, args.Path, args.GenerateSourceMapVisualisations, opts, args.WorkerCount) if len(errs) > 0 { if errors.Is(errs[0], context.Canceled) { return errs[0] @@ -138,14 +139,14 @@ func runCmd(ctx context.Context, args Arguments) (err error) { if !args.Watch { return fmt.Errorf("failed to process path: %v", errors.Join(errs...)) } - fmt.Printf("Error processing path: %v\n", errors.Join(errs...)) + fmt.Fprintf(w, "Error processing path: %v\n", errors.Join(errs...)) } if changesFound > 0 { - fmt.Printf("Generated code for %d templates with %d errors in %s\n", changesFound, len(errs), time.Since(start)) + fmt.Fprintf(w, "Generated code for %d templates with %d errors in %s\n", changesFound, len(errs), time.Since(start)) if args.Command != "" { - fmt.Printf("Executing command: %s\n", args.Command) + fmt.Fprintf(w, "Executing command: %s\n", args.Command) if _, err := run.Run(ctx, args.Path, args.Command); err != nil { - fmt.Printf("Error starting command: %v\n", err) + fmt.Fprintf(w, "Error starting command: %v\n", err) } } // Send server-sent event. @@ -155,15 +156,15 @@ func runCmd(ctx context.Context, args Arguments) (err error) { if !firstRunComplete && p != nil { go func() { - fmt.Printf("Proxying from %s to target: %s\n", p.URL, p.Target.String()) + fmt.Fprintf(w, "Proxying from %s to target: %s\n", p.URL, p.Target.String()) if err := http.ListenAndServe(fmt.Sprintf("127.0.0.1:%d", args.ProxyPort), p); err != nil { - fmt.Printf("Error starting proxy: %v\n", err) + fmt.Fprintf(w, "Error starting proxy: %v\n", err) } }() go func() { - fmt.Printf("Opening URL: %s\n", p.Target.String()) - if err := openURL(p.URL); err != nil { - fmt.Printf("Error opening URL: %v\n", err) + fmt.Fprintf(w, "Opening URL: %s\n", p.Target.String()) + if err := openURL(w, p.URL); err != nil { + fmt.Fprintf(w, "Error opening URL: %v\n", err) } }() } @@ -195,7 +196,7 @@ func shouldSkipDir(dir string) bool { return false } -func processChanges(ctx context.Context, fileNameToLastModTime map[string]time.Time, path string, generateSourceMapVisualisations bool, opts []generator.GenerateOpt, maxWorkerCount int) (changesFound int, errs []error) { +func processChanges(ctx context.Context, w io.Writer, fileNameToLastModTime map[string]time.Time, path string, generateSourceMapVisualisations bool, opts []generator.GenerateOpt, maxWorkerCount int) (changesFound int, errs []error) { sem := make(chan struct{}, maxWorkerCount) var wg sync.WaitGroup @@ -227,7 +228,7 @@ func processChanges(ctx context.Context, fileNameToLastModTime map[string]time.T wg.Add(1) go func() { defer wg.Done() - if err := processSingleFile(ctx, path, generateSourceMapVisualisations, opts); err != nil { + if err := processSingleFile(ctx, w, path, generateSourceMapVisualisations, opts); err != nil { errs = append(errs, err) } <-sem @@ -245,7 +246,7 @@ func processChanges(ctx context.Context, fileNameToLastModTime map[string]time.T return changesFound, errs } -func openURL(url string) error { +func openURL(w io.Writer, url string) error { backoff := backoff.NewExponentialBackOff() backoff.InitialInterval = time.Second var client http.Client @@ -255,19 +256,19 @@ func openURL(url string) error { break } d := backoff.NextBackOff() - fmt.Printf("Server not ready. Retrying in %v...\n", d) + fmt.Fprintf(w, "Server not ready. Retrying in %v...\n", d) time.Sleep(d) } return browser.OpenURL(url) } -func processSingleFile(ctx context.Context, fileName string, generateSourceMapVisualisations bool, opts []generator.GenerateOpt) error { +func processSingleFile(ctx context.Context, w io.Writer, fileName string, generateSourceMapVisualisations bool, opts []generator.GenerateOpt) error { start := time.Now() err := compile(ctx, fileName, generateSourceMapVisualisations, opts) if err != nil { return err } - fmt.Printf("Generated code for %q in %s\n", fileName, time.Since(start)) + fmt.Fprintf(w, "Generated code for %q in %s\n", fileName, time.Since(start)) return err } diff --git a/cmd/templ/lspcmd/main.go b/cmd/templ/lspcmd/main.go index ea6259b44..a3111bb6e 100644 --- a/cmd/templ/lspcmd/main.go +++ b/cmd/templ/lspcmd/main.go @@ -3,6 +3,7 @@ package lspcmd import ( "context" "fmt" + "io" "net/http" "os" "os/signal" @@ -28,7 +29,7 @@ type Arguments struct { HTTPDebug string } -func Run(args Arguments) error { +func Run(w io.Writer, args Arguments) error { ctx := context.Background() ctx, cancel := context.WithCancel(ctx) signalChan := make(chan os.Signal, 1) @@ -51,10 +52,10 @@ func Run(args Arguments) error { <-signalChan // Second signal, hard exit. os.Exit(2) }() - return run(ctx, args) + return run(ctx, w, args) } -func run(ctx context.Context, args Arguments) (err error) { +func run(ctx context.Context, w io.Writer, args Arguments) (err error) { log := zap.NewNop() if args.Log != "" { cfg := zap.NewProductionConfig() @@ -64,7 +65,7 @@ func run(ctx context.Context, args Arguments) (err error) { } log, err = cfg.Build() if err != nil { - _, _ = fmt.Printf("failed to create logger: %v\n", err) + _, _ = fmt.Fprintf(w, "failed to create logger: %v\n", err) os.Exit(1) } } diff --git a/cmd/templ/main.go b/cmd/templ/main.go index 68841e78a..4bdbcf02b 100644 --- a/cmd/templ/main.go +++ b/cmd/templ/main.go @@ -18,6 +18,20 @@ func main() { run(os.Stdout, os.Args) } +const usageText = `usage: templ [...] + +templ - build HTML UIs with Go + +See docs at https://templ.guide + +commands: + generate Generates Go code from templ files + fmt Formats templ files + lsp Starts a language server for templ files + migrate Migrates v1 templ files to v2 format + version Prints the version +` + func run(w io.Writer, args []string) (code int) { if len(args) < 2 { fmt.Fprint(w, usageText) @@ -25,16 +39,16 @@ func run(w io.Writer, args []string) (code int) { } switch args[1] { case "generate": - generateCmd(args[2:]) + generateCmd(w, args[2:]) return case "migrate": - migrateCmd(args[2:]) + migrateCmd(w, args[2:]) return case "fmt": - fmtCmd(args[2:]) + fmtCmd(w, args[2:]) return case "lsp": - lspCmd(args[2:]) + lspCmd(w, args[2:]) return case "version": fmt.Fprintln(w, templ.Version) @@ -47,37 +61,72 @@ func run(w io.Writer, args []string) (code int) { return 0 } -const usageText = `usage: templ [parameters] -To see help text, you can run: - templ generate --help - templ fmt --help - templ lsp --help - templ migrate --help - templ version -examples: - templ generate +const generateUsageText = `usage: templ generate [...] + +Generates Go code from templ files. + +Args: + -path + Generates code for all files in path. (default .) + -f + Optionally generates code for a single file, e.g. -f header.templ + -sourceMapVisualisations + Set to true to generate HTML files to visualise the templ code and its corresponding Go code. + -include-version + Set to false to skip inclusion of the templ version in the generated code. (default true) + -include-timestamp + Set to true to include the current time in the generated code. + -watch + Set to true to watch the path for changes and regenerate code. + -cmd + Set the command to run after generating code. + -proxy + Set the URL to proxy after generating code and executing the command. + -proxyport + The port the proxy will listen on. (default 7331) + -w + Number of workers to use when generating code. (default runtime.NumCPUs) + -pprof + Port to run the pprof server on. + -help + Print help and exit. + +Examples: + + Generate code for all files in the current directory and subdirectories: + + templ generate + + Generate code for a single file: + + templ generate -f header.templ + + Watch the current directory and subdirectories for changes and regenerate code: + + templ generate -watch ` -func generateCmd(args []string) { +func generateCmd(w io.Writer, args []string) (code int) { cmd := flag.NewFlagSet("generate", flag.ExitOnError) - fileNameFlag := cmd.String("f", "", "Optionally generates code for a single file, e.g. -f header.templ") - pathFlag := cmd.String("path", ".", "Generates code for all files in path.") - sourceMapVisualisations := cmd.Bool("sourceMapVisualisations", false, "Set to true to generate HTML files to visualise the templ code and its corresponding Go code.") - includeVersionFlag := cmd.Bool("include-version", true, "Set to false to skip inclusion of the templ version in the generated code.") - includeTimestampFlag := cmd.Bool("include-timestamp", false, "Set to true to include the current time in the generated code.") - watchFlag := cmd.Bool("watch", false, "Set to true to watch the path for changes and regenerate code.") - cmdFlag := cmd.String("cmd", "", "Set the command to run after generating code.") - proxyFlag := cmd.String("proxy", "", "Set the URL to proxy after generating code and executing the command.") - proxyPortFlag := cmd.Int("proxyport", 7331, "The port the proxy will listen on.") - workerCountFlag := cmd.Int("w", runtime.NumCPU(), "Number of workers to run in parallel.") - pprofPortFlag := cmd.Int("pprof", 0, "Port to start pprof web server on.") - helpFlag := cmd.Bool("help", false, "Print help and exit.") + cmd.SetOutput(w) + fileNameFlag := cmd.String("f", "", "") + pathFlag := cmd.String("path", ".", "") + sourceMapVisualisations := cmd.Bool("sourceMapVisualisations", false, "") + includeVersionFlag := cmd.Bool("include-version", true, "") + includeTimestampFlag := cmd.Bool("include-timestamp", false, "") + watchFlag := cmd.Bool("watch", false, "") + cmdFlag := cmd.String("cmd", "", "") + proxyFlag := cmd.String("proxy", "", "") + proxyPortFlag := cmd.Int("proxyport", 7331, "") + workerCountFlag := cmd.Int("w", runtime.NumCPU(), "") + pprofPortFlag := cmd.Int("pprof", 0, "") + helpFlag := cmd.Bool("help", false, "") err := cmd.Parse(args) if err != nil || *helpFlag { - cmd.PrintDefaults() + fmt.Fprint(w, generateUsageText) return } - err = generatecmd.Run(generatecmd.Arguments{ + err = generatecmd.Run(w, generatecmd.Arguments{ FileName: *fileNameFlag, Path: *pathFlag, Watch: *watchFlag, @@ -91,60 +140,119 @@ func generateCmd(args []string) { PPROFPort: *pprofPortFlag, }) if err != nil { - fmt.Println(err.Error()) - os.Exit(1) + fmt.Fprintln(w, err.Error()) + return 1 } + return 0 } -func migrateCmd(args []string) { +const migrateUsageText = `usage: templ migrate [ ...] + +Migrates v1 templ files to v2 format. + +Args: + -f string + Optionally migrate a single file, e.g. -f header.templ + -help + Print help and exit. + -path string + Migrates code for all files in path. +` + +func migrateCmd(w io.Writer, args []string) (code int) { cmd := flag.NewFlagSet("migrate", flag.ExitOnError) - fileName := cmd.String("f", "", "Optionally migrate a single file, e.g. -f header.templ") - path := cmd.String("path", ".", "Migrates code for all files in path.") - helpFlag := cmd.Bool("help", false, "Print help and exit.") + cmd.SetOutput(w) + fileName := cmd.String("f", "", "") + path := cmd.String("path", "", "") + helpFlag := cmd.Bool("help", false, "") + cmd.Usage = func() { + fmt.Fprint(w, migrateUsageText) + } err := cmd.Parse(args) - if err != nil || *helpFlag { - cmd.PrintDefaults() + if err != nil || *helpFlag || (*path == "" && *fileName == "") { + cmd.Usage() return } - err = migratecmd.Run(migratecmd.Arguments{ + err = migratecmd.Run(w, migratecmd.Arguments{ FileName: *fileName, Path: *path, }) if err != nil { - fmt.Println(err.Error()) - os.Exit(1) + fmt.Fprintln(w, err.Error()) + return 1 } + return 0 } -func fmtCmd(args []string) { +const fmtUsageText = `usage: templ fmt [ ...] + +Format all files in directory: + + templ fmt . + +Format stdin to stdout: + + templ fmt < header.templ + +Args: + -help + Print help and exit. +` + +func fmtCmd(w io.Writer, args []string) (code int) { cmd := flag.NewFlagSet("fmt", flag.ExitOnError) - helpFlag := cmd.Bool("help", false, "Print help and exit.") + cmd.SetOutput(w) + cmd.Usage = func() { + fmt.Fprint(w, fmtUsageText) + } + helpFlag := cmd.Bool("help", false, "") err := cmd.Parse(args) if err != nil || *helpFlag { - cmd.PrintDefaults() + cmd.Usage() return } - err = fmtcmd.Run(args) + err = fmtcmd.Run(w, args) if err != nil { - fmt.Println(err.Error()) - os.Exit(1) + fmt.Fprintln(w, err.Error()) + return 1 } + return 0 } -func lspCmd(args []string) { +const lspUsageText = `usage: templ lsp [ ...] + +Starts a language server for templ. + +Args: + -log string + The file to log templ LSP output to, or leave empty to disable logging. + -goplsLog string + The file to log gopls output, or leave empty to disable logging. + -goplsRPCTrace + Set gopls to log input and output messages. + -help + Print help and exit. + -pprof + Enable pprof web server (default address is localhost:9999) + -http string + Enable http debug server by setting a listen address (e.g. localhost:7474) +` + +func lspCmd(w io.Writer, args []string) (code int) { cmd := flag.NewFlagSet("lsp", flag.ExitOnError) - log := cmd.String("log", "", "The file to log templ LSP output to, or leave empty to disable logging.") - goplsLog := cmd.String("goplsLog", "", "The file to log gopls output, or leave empty to disable logging.") - goplsRPCTrace := cmd.Bool("goplsRPCTrace", false, "Set gopls to log input and output messages.") - helpFlag := cmd.Bool("help", false, "Print help and exit.") - pprofFlag := cmd.Bool("pprof", false, "Enable pprof web server (default address is localhost:9999)") - httpDebugFlag := cmd.String("http", "", "Enable http debug server by setting a listen address (e.g. localhost:7474)") + cmd.SetOutput(w) + log := cmd.String("log", "", "") + goplsLog := cmd.String("goplsLog", "", "") + goplsRPCTrace := cmd.Bool("goplsRPCTrace", false, "") + helpFlag := cmd.Bool("help", false, "") + pprofFlag := cmd.Bool("pprof", false, "") + httpDebugFlag := cmd.String("http", "", "") err := cmd.Parse(args) if err != nil || *helpFlag { - cmd.PrintDefaults() + fmt.Fprint(w, lspUsageText) return } - err = lspcmd.Run(lspcmd.Arguments{ + err = lspcmd.Run(w, lspcmd.Arguments{ Log: *log, GoplsLog: *goplsLog, GoplsRPCTrace: *goplsRPCTrace, @@ -152,7 +260,8 @@ func lspCmd(args []string) { HTTPDebug: *httpDebugFlag, }) if err != nil { - fmt.Println(err.Error()) - os.Exit(1) + fmt.Fprintln(w, err.Error()) + return 1 } + return 0 } diff --git a/cmd/templ/main_test.go b/cmd/templ/main_test.go index 37c88d82f..93e49e1bb 100644 --- a/cmd/templ/main_test.go +++ b/cmd/templ/main_test.go @@ -22,29 +22,59 @@ func TestMain(t *testing.T) { expectedCode: 0, }, { - name: "templ help prints help", + name: `"templ help" prints help`, args: []string{"templ", "help"}, expected: usageText, expectedCode: 0, }, { - name: "templ --help prints help", + name: `"templ --help" prints help`, args: []string{"templ", "--help"}, expected: usageText, expectedCode: 0, }, { - name: "templ version prints version", + name: `"templ version" prints version`, args: []string{"templ", "version"}, expected: templ.Version + "\n", expectedCode: 0, }, { - name: "templ --version prints version", + name: `"templ --version" prints version`, args: []string{"templ", "--version"}, expected: templ.Version + "\n", expectedCode: 0, }, + { + name: `"templ migrate" prints usage`, + args: []string{"templ", "migrate"}, + expected: migrateUsageText, + expectedCode: 0, + }, + { + name: `"templ migrate --help" prints usage`, + args: []string{"templ", "migrate", "--help"}, + expected: migrateUsageText, + expectedCode: 0, + }, + { + name: `"templ fmt --help" prints usage`, + args: []string{"templ", "fmt", "--help"}, + expected: fmtUsageText, + expectedCode: 0, + }, + { + name: `"templ generate --help" prints usage`, + args: []string{"templ", "generate", "--help"}, + expected: generateUsageText, + expectedCode: 0, + }, + { + name: `"templ lsp --help" prints usage`, + args: []string{"templ", "lsp", "--help"}, + expected: lspUsageText, + expectedCode: 0, + }, } for _, test := range tests { @@ -57,6 +87,10 @@ func TestMain(t *testing.T) { } if diff := cmp.Diff(test.expected, actual.String()); diff != "" { t.Error(diff) + t.Error("expected:") + t.Error(test.expected) + t.Error("actual:") + t.Error(actual.String()) } }) } diff --git a/cmd/templ/migratecmd/main.go b/cmd/templ/migratecmd/main.go index 9cfbe4f81..1d8a8b17b 100644 --- a/cmd/templ/migratecmd/main.go +++ b/cmd/templ/migratecmd/main.go @@ -4,6 +4,7 @@ import ( "bytes" "errors" "fmt" + "io" "reflect" "strings" "time" @@ -21,21 +22,21 @@ type Arguments struct { Path string } -func Run(args Arguments) (err error) { +func Run(w io.Writer, args Arguments) (err error) { if args.FileName != "" { - return processSingleFile(args.FileName) + return processSingleFile(w, args.FileName) } - return processPath(args.Path) + return processPath(w, args.Path) } -func processSingleFile(fileName string) error { +func processSingleFile(w io.Writer, fileName string) error { start := time.Now() err := migrate(fileName) - fmt.Printf("Migrated code for %q in %s\n", fileName, time.Since(start)) + fmt.Fprintf(w, "Migrated code for %q in %s\n", fileName, time.Since(start)) return err } -func processPath(path string) (err error) { +func processPath(w io.Writer, path string) (err error) { start := time.Now() results := make(chan processor.Result) go processor.Process(path, migrate, workerCount, results) @@ -47,9 +48,9 @@ func processPath(path string) (err error) { continue } successCount++ - fmt.Printf("%s complete in %v\n", r.FileName, r.Duration) + fmt.Fprintf(w, "%s complete in %v\n", r.FileName, r.Duration) } - fmt.Printf("Migrated code for %d templates with %d errors in %s\n", successCount+errorCount, errorCount, time.Since(start)) + fmt.Fprintf(w, "Migrated code for %d templates with %d errors in %s\n", successCount+errorCount, errorCount, time.Since(start)) return err }