Skip to content

Commit

Permalink
chore: improve cmd unit testing
Browse files Browse the repository at this point in the history
  • Loading branch information
a-h committed Oct 29, 2023
1 parent 1d3e08b commit d17087d
Show file tree
Hide file tree
Showing 7 changed files with 249 additions and 103 deletions.
2 changes: 1 addition & 1 deletion .version
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.2.428
0.2.431
14 changes: 7 additions & 7 deletions cmd/templ/fmtcmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,31 +15,31 @@ import (

const workerCount = 4

func Run(args []string) (err error) {
func Run(w io.Writer, args []string) (err error) {
if len(args) > 0 {
return formatDir(args[0])
return formatDir(w, args[0])
}
return formatStdin()
return formatReader(w, os.Stdin)
}

func formatStdin() (err error) {
func formatReader(w io.Writer, r io.Reader) (err error) {
var bytes []byte
bytes, err = io.ReadAll(os.Stdin)
bytes, err = io.ReadAll(r)
if err != nil {
return
}
t, err := parser.ParseString(string(bytes))
if err != nil {
return fmt.Errorf("parsing error: %w", err)
}
err = t.Write(os.Stdout)
err = t.Write(w)
if err != nil {
return fmt.Errorf("formatting error: %w", err)
}
return nil
}

func formatDir(dir string) (err error) {
func formatDir(w io.Writer, dir string) (err error) {
start := time.Now()
results := make(chan processor.Result)
go processor.Process(dir, format, workerCount, results)
Expand Down
47 changes: 24 additions & 23 deletions cmd/templ/generatecmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"errors"
"fmt"
"go/format"
"io"
"net/http"
"net/url"
"os"
Expand Down Expand Up @@ -48,7 +49,7 @@ type Arguments struct {

var defaultWorkerCount = runtime.NumCPU()

func Run(args Arguments) (err error) {
func Run(w io.Writer, args Arguments) (err error) {
ctx, cancel := context.WithCancel(context.Background())
signalChan := make(chan os.Signal, 1)
signal.Notify(signalChan, os.Interrupt)
Expand All @@ -64,25 +65,25 @@ func Run(args Arguments) (err error) {
go func() {
select {
case <-signalChan: // First signal, cancel context.
fmt.Println("\nCancelling...")
fmt.Fprintln(w, "\nCancelling...")
err = run.Stop()
if err != nil {
fmt.Printf("Error killing command: %v\n", err)
fmt.Fprintf(w, "Error killing command: %v\n", err)
}
cancel()
case <-ctx.Done():
}
<-signalChan // Second signal, hard exit.
os.Exit(2)
}()
err = runCmd(ctx, args)
err = runCmd(ctx, w, args)
if errors.Is(err, context.Canceled) {
return nil
}
return err
}

func runCmd(ctx context.Context, args Arguments) (err error) {
func runCmd(ctx context.Context, w io.Writer, args Arguments) (err error) {
start := time.Now()
if args.Watch && args.FileName != "" {
return fmt.Errorf("cannot watch a single file, remove the -f or -watch flag")
Expand All @@ -95,7 +96,7 @@ func runCmd(ctx context.Context, args Arguments) (err error) {
opts = append(opts, generator.WithTimestamp(time.Now()))
}
if args.FileName != "" {
return processSingleFile(ctx, args.FileName, args.GenerateSourceMapVisualisations, opts)
return processSingleFile(ctx, w, args.FileName, args.GenerateSourceMapVisualisations, opts)
}
var target *url.URL
if args.Proxy != "" {
Expand Down Expand Up @@ -123,29 +124,29 @@ func runCmd(ctx context.Context, args Arguments) (err error) {
p = proxy.New(args.ProxyPort, target)
}

fmt.Println("Processing path:", args.Path)
fmt.Fprintln(w, "Processing path:", args.Path)
bo := backoff.NewExponentialBackOff()
bo.InitialInterval = time.Millisecond * 500
bo.MaxInterval = time.Second * 3
var firstRunComplete bool
fileNameToLastModTime := make(map[string]time.Time)
for !firstRunComplete || args.Watch {
changesFound, errs := processChanges(ctx, fileNameToLastModTime, args.Path, args.GenerateSourceMapVisualisations, opts, args.WorkerCount)
changesFound, errs := processChanges(ctx, w, fileNameToLastModTime, args.Path, args.GenerateSourceMapVisualisations, opts, args.WorkerCount)
if len(errs) > 0 {
if errors.Is(errs[0], context.Canceled) {
return errs[0]
}
if !args.Watch {
return fmt.Errorf("failed to process path: %v", errors.Join(errs...))
}
fmt.Printf("Error processing path: %v\n", errors.Join(errs...))
fmt.Fprintf(w, "Error processing path: %v\n", errors.Join(errs...))
}
if changesFound > 0 {
fmt.Printf("Generated code for %d templates with %d errors in %s\n", changesFound, len(errs), time.Since(start))
fmt.Fprintf(w, "Generated code for %d templates with %d errors in %s\n", changesFound, len(errs), time.Since(start))
if args.Command != "" {
fmt.Printf("Executing command: %s\n", args.Command)
fmt.Fprintf(w, "Executing command: %s\n", args.Command)
if _, err := run.Run(ctx, args.Path, args.Command); err != nil {
fmt.Printf("Error starting command: %v\n", err)
fmt.Fprintf(w, "Error starting command: %v\n", err)
}
}
// Send server-sent event.
Expand All @@ -155,15 +156,15 @@ func runCmd(ctx context.Context, args Arguments) (err error) {

if !firstRunComplete && p != nil {
go func() {
fmt.Printf("Proxying from %s to target: %s\n", p.URL, p.Target.String())
fmt.Fprintf(w, "Proxying from %s to target: %s\n", p.URL, p.Target.String())
if err := http.ListenAndServe(fmt.Sprintf("127.0.0.1:%d", args.ProxyPort), p); err != nil {
fmt.Printf("Error starting proxy: %v\n", err)
fmt.Fprintf(w, "Error starting proxy: %v\n", err)
}
}()
go func() {
fmt.Printf("Opening URL: %s\n", p.Target.String())
if err := openURL(p.URL); err != nil {
fmt.Printf("Error opening URL: %v\n", err)
fmt.Fprintf(w, "Opening URL: %s\n", p.Target.String())
if err := openURL(w, p.URL); err != nil {
fmt.Fprintf(w, "Error opening URL: %v\n", err)
}
}()
}
Expand Down Expand Up @@ -195,7 +196,7 @@ func shouldSkipDir(dir string) bool {
return false
}

func processChanges(ctx context.Context, fileNameToLastModTime map[string]time.Time, path string, generateSourceMapVisualisations bool, opts []generator.GenerateOpt, maxWorkerCount int) (changesFound int, errs []error) {
func processChanges(ctx context.Context, w io.Writer, fileNameToLastModTime map[string]time.Time, path string, generateSourceMapVisualisations bool, opts []generator.GenerateOpt, maxWorkerCount int) (changesFound int, errs []error) {
sem := make(chan struct{}, maxWorkerCount)
var wg sync.WaitGroup

Expand Down Expand Up @@ -227,7 +228,7 @@ func processChanges(ctx context.Context, fileNameToLastModTime map[string]time.T
wg.Add(1)
go func() {
defer wg.Done()
if err := processSingleFile(ctx, path, generateSourceMapVisualisations, opts); err != nil {
if err := processSingleFile(ctx, w, path, generateSourceMapVisualisations, opts); err != nil {
errs = append(errs, err)
}
<-sem
Expand All @@ -245,7 +246,7 @@ func processChanges(ctx context.Context, fileNameToLastModTime map[string]time.T
return changesFound, errs
}

func openURL(url string) error {
func openURL(w io.Writer, url string) error {
backoff := backoff.NewExponentialBackOff()
backoff.InitialInterval = time.Second
var client http.Client
Expand All @@ -255,19 +256,19 @@ func openURL(url string) error {
break
}
d := backoff.NextBackOff()
fmt.Printf("Server not ready. Retrying in %v...\n", d)
fmt.Fprintf(w, "Server not ready. Retrying in %v...\n", d)
time.Sleep(d)
}
return browser.OpenURL(url)
}

func processSingleFile(ctx context.Context, fileName string, generateSourceMapVisualisations bool, opts []generator.GenerateOpt) error {
func processSingleFile(ctx context.Context, w io.Writer, fileName string, generateSourceMapVisualisations bool, opts []generator.GenerateOpt) error {
start := time.Now()
err := compile(ctx, fileName, generateSourceMapVisualisations, opts)
if err != nil {
return err
}
fmt.Printf("Generated code for %q in %s\n", fileName, time.Since(start))
fmt.Fprintf(w, "Generated code for %q in %s\n", fileName, time.Since(start))
return err
}

Expand Down
9 changes: 5 additions & 4 deletions cmd/templ/lspcmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package lspcmd
import (
"context"
"fmt"
"io"
"net/http"
"os"
"os/signal"
Expand All @@ -28,7 +29,7 @@ type Arguments struct {
HTTPDebug string
}

func Run(args Arguments) error {
func Run(w io.Writer, args Arguments) error {
ctx := context.Background()
ctx, cancel := context.WithCancel(ctx)
signalChan := make(chan os.Signal, 1)
Expand All @@ -51,10 +52,10 @@ func Run(args Arguments) error {
<-signalChan // Second signal, hard exit.
os.Exit(2)
}()
return run(ctx, args)
return run(ctx, w, args)
}

func run(ctx context.Context, args Arguments) (err error) {
func run(ctx context.Context, w io.Writer, args Arguments) (err error) {
log := zap.NewNop()
if args.Log != "" {
cfg := zap.NewProductionConfig()
Expand All @@ -64,7 +65,7 @@ func run(ctx context.Context, args Arguments) (err error) {
}
log, err = cfg.Build()
if err != nil {
_, _ = fmt.Printf("failed to create logger: %v\n", err)
_, _ = fmt.Fprintf(w, "failed to create logger: %v\n", err)
os.Exit(1)
}
}
Expand Down
Loading

0 comments on commit d17087d

Please sign in to comment.