-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
xmain: Add TestingState for testing CLIs fully virtually
Updates terrastruct/d2#903
- Loading branch information
Showing
5 changed files
with
377 additions
and
88 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,10 @@ | ||
package xcontext | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
) | ||
|
||
type Mutex struct { | ||
ch chan struct{} | ||
} | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
package xmain | ||
|
||
import ( | ||
"bytes" | ||
"strconv" | ||
) | ||
|
||
// Code here was copied from src/os/exec/exec.go. | ||
|
||
// prefixSuffixSaver is an io.Writer which retains the first N bytes | ||
// and the last N bytes written to it. The Bytes() methods reconstructs | ||
// it with a pretty error message. | ||
type prefixSuffixSaver struct { | ||
N int // max size of prefix or suffix | ||
prefix []byte | ||
suffix []byte // ring buffer once len(suffix) == N | ||
suffixOff int // offset to write into suffix | ||
skipped int64 | ||
} | ||
|
||
func (w *prefixSuffixSaver) Write(p []byte) (n int, err error) { | ||
lenp := len(p) | ||
p = w.fill(&w.prefix, p) | ||
|
||
// Only keep the last w.N bytes of suffix data. | ||
if overage := len(p) - w.N; overage > 0 { | ||
p = p[overage:] | ||
w.skipped += int64(overage) | ||
} | ||
p = w.fill(&w.suffix, p) | ||
|
||
// w.suffix is full now if p is non-empty. Overwrite it in a circle. | ||
for len(p) > 0 { // 0, 1, or 2 iterations. | ||
n := copy(w.suffix[w.suffixOff:], p) | ||
p = p[n:] | ||
w.skipped += int64(n) | ||
w.suffixOff += n | ||
if w.suffixOff == w.N { | ||
w.suffixOff = 0 | ||
} | ||
} | ||
return lenp, nil | ||
} | ||
|
||
// fill appends up to len(p) bytes of p to *dst, such that *dst does not | ||
// grow larger than w.N. It returns the un-appended suffix of p. | ||
func (w *prefixSuffixSaver) fill(dst *[]byte, p []byte) (pRemain []byte) { | ||
if remain := w.N - len(*dst); remain > 0 { | ||
add := minInt(len(p), remain) | ||
*dst = append(*dst, p[:add]...) | ||
p = p[add:] | ||
} | ||
return p | ||
} | ||
|
||
func (w *prefixSuffixSaver) Bytes() []byte { | ||
if w.suffix == nil { | ||
return w.prefix | ||
} | ||
if w.skipped == 0 { | ||
return append(w.prefix, w.suffix...) | ||
} | ||
var buf bytes.Buffer | ||
buf.Grow(len(w.prefix) + len(w.suffix) + 50) | ||
buf.Write(w.prefix) | ||
buf.WriteString("\n... omitting ") | ||
buf.WriteString(strconv.FormatInt(w.skipped, 10)) | ||
buf.WriteString(" bytes ...\n") | ||
buf.Write(w.suffix[w.suffixOff:]) | ||
buf.Write(w.suffix[:w.suffixOff]) | ||
return buf.Bytes() | ||
} | ||
|
||
func minInt(a, b int) int { | ||
if a < b { | ||
return a | ||
} | ||
return b | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,107 +1,215 @@ | ||
package xmain | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"fmt" | ||
"io" | ||
"io/fs" | ||
"os" | ||
"testing" | ||
"time" | ||
"context" | ||
|
||
"oss.terrastruct.com/util-go/assert" | ||
"oss.terrastruct.com/util-go/cmdlog" | ||
"oss.terrastruct.com/util-go/xcontext" | ||
"oss.terrastruct.com/util-go/xdefer" | ||
"oss.terrastruct.com/util-go/xos" | ||
) | ||
|
||
type TestingState struct { | ||
State *State | ||
Stdin io.Writer | ||
Stdout io.Reader | ||
Stderr io.Reader | ||
Run func(context.Context, *State) error | ||
Env *xos.Env | ||
Args []string | ||
Dir string | ||
FS fs.FS | ||
|
||
Stdin io.Reader | ||
Stdout io.WriteCloser | ||
Stderr io.WriteCloser | ||
|
||
mu *xcontext.Mutex | ||
ms *State | ||
sigs chan os.Signal | ||
done chan error | ||
doneErr *error | ||
} | ||
|
||
sigs chan os.Signal | ||
done chan error | ||
func (ts *TestingState) StdinPipe() (pw io.WriteCloser) { | ||
ts.Stdin, pw = io.Pipe() | ||
return pw | ||
} | ||
|
||
func (ts *TestingState) Signal(ctx context.Context, sig os.Signal) error { | ||
select { | ||
case <-ctx.Done(): | ||
return ctx.Err() | ||
case ts.sigs <- sig: | ||
return nil | ||
func (ts *TestingState) StdoutPipe() (pr io.Reader) { | ||
pr, ts.Stdout = io.Pipe() | ||
return pr | ||
} | ||
|
||
func (ts *TestingState) StderrPipe() (pr io.Reader) { | ||
pr, ts.Stderr = io.Pipe() | ||
return pr | ||
} | ||
|
||
func (ts *TestingState) Start(tb testing.TB, ctx context.Context) { | ||
tb.Helper() | ||
|
||
if ts.mu != nil { | ||
tb.Fatal("xmain.TestingState.Start cannot be called twice") | ||
} | ||
if ts.Env == nil { | ||
ts.Env = xos.NewEnv(nil) | ||
} | ||
|
||
ts.mu = xcontext.NewMutex() | ||
ts.sigs = make(chan os.Signal, 1) | ||
ts.done = make(chan error, 1) | ||
|
||
name := "" | ||
args := []string(nil) | ||
if len(args) > 0 { | ||
name = os.Args[0] | ||
args = os.Args[1:] | ||
} | ||
log := cmdlog.NewTB(ts.Env, tb) | ||
ts.ms = &State{ | ||
Name: name, | ||
|
||
Log: log, | ||
Env: ts.Env, | ||
Opts: NewOpts(ts.Env, log, args), | ||
Dir: ts.Dir, | ||
FS: ts.FS, | ||
} | ||
|
||
ts.ms.Stdin = ts.Stdin | ||
if ts.Stdin == nil { | ||
ts.ms.Stdin = io.LimitReader(nil, 0) | ||
} | ||
ts.ms.Stdout = ts.Stdout | ||
if ts.Stdout == nil { | ||
ts.ms.Stdout = nopWriterCloser{io.Discard} | ||
} | ||
ts.ms.Stderr = ts.Stderr | ||
if ts.Stderr == nil { | ||
ts.ms.Stderr = nopWriterCloser{&prefixSuffixSaver{N: 1 << 25}} | ||
} | ||
|
||
go func() { | ||
defer ts.Cleanup(tb) | ||
err := ts.ms.Main(ctx, ts.sigs, ts.Run) | ||
if err != nil { | ||
if ts.Stderr == nil { | ||
err = fmt.Errorf("%w; stderr: %s", err, ts.ms.Stderr.(nopWriterCloser).Writer.(*prefixSuffixSaver).Bytes()) | ||
} | ||
} | ||
ts.done <- err | ||
}() | ||
} | ||
|
||
func (ts *TestingState) Cleanup(tb testing.TB) { | ||
tb.Helper() | ||
|
||
if rc, ok := ts.Stdin.(io.ReadCloser); ok { | ||
err := rc.Close() | ||
if err != nil { | ||
tb.Errorf("failed to close stdin: %v", err) | ||
} | ||
} | ||
|
||
err, ok := ts.ExitError() | ||
if ok { | ||
// Already exited. | ||
return | ||
} | ||
|
||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute) | ||
defer cancel() | ||
err = ts.Signal(ctx, os.Interrupt) | ||
if err != nil { | ||
tb.Errorf("failed to os.Interrupt testing xmain: %v", err) | ||
} | ||
err = ts.Wait(ctx) | ||
if errors.Is(err, context.DeadlineExceeded) { | ||
err = ts.Signal(ctx, os.Kill) | ||
if err != nil { | ||
tb.Errorf("failed to kill testing xmain: %v", err) | ||
} | ||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) | ||
defer cancel() | ||
err = ts.Wait(ctx) | ||
} | ||
assert.Success(tb, err) | ||
} | ||
|
||
func (ts *TestingState) Wait(ctx context.Context) error { | ||
func (ts *TestingState) Signal(ctx context.Context, sig os.Signal) (err error) { | ||
defer xdefer.Errorf(&err, "failed to signal testing xmain: %v", ts.ms.Name) | ||
|
||
err = ts.mu.Lock(ctx) | ||
if err != nil { | ||
return err | ||
} | ||
defer ts.mu.Unlock() | ||
|
||
if ts.doneErr != nil { | ||
return fmt.Errorf("testing xmain done: %w", *ts.doneErr) | ||
} | ||
|
||
select { | ||
case <-ctx.Done(): | ||
return ctx.Err() | ||
case err := <-ts.done: | ||
ts.doneErr = &err | ||
return err | ||
case ts.sigs <- sig: | ||
return nil | ||
} | ||
} | ||
|
||
func Testing(tb testing.TB, ctx context.Context, env *xos.Env, run func(context.Context, *State) error, name string, args ...string) (ts *TestingState, cleanup func()) { | ||
stdinr, stdinw, err := os.Pipe() | ||
assert.Success(tb, err) | ||
stdoutr, stdoutw, err := os.Pipe() | ||
assert.Success(tb, err) | ||
stderrr, stderrw, err := os.Pipe() | ||
assert.Success(tb, err) | ||
|
||
ms := &State{ | ||
Name: name, | ||
func (ts *TestingState) Wait(ctx context.Context) (err error) { | ||
defer xdefer.Errorf(&err, "failed to wait testing xmain: %v", ts.ms.Name) | ||
|
||
Stdin: stdinr, | ||
Stdout: stdoutw, | ||
Stderr: stderrw, | ||
err = ts.mu.Lock(ctx) | ||
if err != nil { | ||
return err | ||
} | ||
defer ts.mu.Unlock() | ||
|
||
Env: env, | ||
Log: cmdlog.NewTB(env, tb), | ||
if ts.doneErr != nil { | ||
if *ts.doneErr == nil { | ||
return nil | ||
} | ||
return fmt.Errorf("testing xmain done: %w", *ts.doneErr) | ||
} | ||
ms.Opts = NewOpts(ms.Env, ms.Log, args) | ||
|
||
ts = &TestingState{ | ||
State: ms, | ||
select { | ||
case <-ctx.Done(): | ||
return ctx.Err() | ||
case err := <-ts.done: | ||
ts.doneErr = &err | ||
return err | ||
} | ||
} | ||
|
||
Stdin: stdinw, | ||
Stdout: stdoutr, | ||
Stderr: stderrr, | ||
func (ts *TestingState) ExitError() (error, bool) { | ||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) | ||
defer cancel() | ||
|
||
sigs: make(chan os.Signal, 1), | ||
done: make(chan error, 1), | ||
err := ts.mu.Lock(ctx) | ||
if err != nil { | ||
return nil, false | ||
} | ||
defer ts.mu.Unlock() | ||
|
||
cleanup = func() { | ||
stdinr.Close() | ||
stdinw.Close() | ||
stdoutr.Close() | ||
stdoutw.Close() | ||
stderrr.Close() | ||
stderrw.Close() | ||
|
||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute) | ||
defer cancel() | ||
err = ts.Signal(ctx, os.Interrupt) | ||
if err != nil { | ||
tb.Errorf("failed to os.Interrupt testing xmain: %v", err) | ||
} | ||
err := ts.Wait(ctx) | ||
if errors.Is(err, context.DeadlineExceeded) { | ||
err = ts.Signal(ctx, os.Kill) | ||
if err != nil { | ||
tb.Errorf("failed to kill testing xmain: %v", err) | ||
} | ||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) | ||
defer cancel() | ||
err = ts.Wait(ctx) | ||
} | ||
assert.Success(tb, err) | ||
if ts.doneErr != nil { | ||
return *ts.doneErr, true | ||
} | ||
return nil, false | ||
} | ||
|
||
go func() { | ||
ts.done <- ms.Main(ctx, ts.sigs, run) | ||
}() | ||
type nopWriterCloser struct { | ||
io.Writer | ||
} | ||
|
||
return ts, cleanup | ||
func (c nopWriterCloser) Close() error { | ||
return nil | ||
} |
Oops, something went wrong.