Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for context.Context #893

Merged
merged 2 commits into from
Feb 20, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 28 additions & 2 deletions command.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package cobra

import (
"bytes"
"context"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -143,9 +144,11 @@ type Command struct {
// TraverseChildren parses flags on all parents before executing child command.
TraverseChildren bool

//FParseErrWhitelist flag parse errors to be ignored
// FParseErrWhitelist flag parse errors to be ignored
FParseErrWhitelist FParseErrWhitelist

ctx context.Context

// commands is the list of commands supported by this program.
commands []*Command
// parent is a parent command for this command.
Expand Down Expand Up @@ -205,6 +208,12 @@ type Command struct {
errWriter io.Writer
}

// Context returns underlying command context. If command wasn't
// executed with ExecuteContext Context returns Background context.
func (c *Command) Context() context.Context {
return c.ctx
}

// SetArgs sets arguments for the command. It is set to os.Args[1:] by default, if desired, can be overridden
// particularly useful when testing.
func (c *Command) SetArgs(a []string) {
Expand Down Expand Up @@ -860,6 +869,13 @@ func (c *Command) preRun() {
}
}

// ExecuteContext is the same as Execute(), but sets the ctx on the command.
// Retrieve ctx by calling cmd.Context() inside your *Run lifecycle functions.
func (c *Command) ExecuteContext(ctx context.Context) error {
c.ctx = ctx
return c.Execute()
}

// Execute uses the args (os.Args[1:] by default)
// and run through the command tree finding appropriate matches
// for commands and then corresponding flags.
Expand All @@ -870,6 +886,10 @@ func (c *Command) Execute() error {

// ExecuteC executes the command.
func (c *Command) ExecuteC() (cmd *Command, err error) {
if c.ctx == nil {
c.ctx = context.Background()
}

// Regardless of what command execute is called on, run on Root only
if c.HasParent() {
return c.Root().ExecuteC()
Expand Down Expand Up @@ -914,6 +934,12 @@ func (c *Command) ExecuteC() (cmd *Command, err error) {
cmd.commandCalledAs.name = cmd.Name()
}

// We have to pass global context to children command
// if context is present on the parent command.
if cmd.ctx == nil {
cmd.ctx = c.ctx
}

err = cmd.execute(flags)
if err != nil {
// Always show help if requested, even if SilenceErrors is in
Expand Down Expand Up @@ -1558,7 +1584,7 @@ func (c *Command) ParseFlags(args []string) error {
beforeErrorBufLen := c.flagErrorBuf.Len()
c.mergePersistentFlags()

//do it here after merging all flags and just before parse
// do it here after merging all flags and just before parse
c.Flags().ParseErrorsWhitelist = flag.ParseErrorsWhitelist(c.FParseErrWhitelist)

err := c.Flags().Parse(args)
Expand Down
67 changes: 67 additions & 0 deletions command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package cobra

import (
"bytes"
"context"
"fmt"
"os"
"reflect"
Expand All @@ -18,6 +19,16 @@ func executeCommand(root *Command, args ...string) (output string, err error) {
return output, err
}

func executeCommandWithContext(ctx context.Context, root *Command, args ...string) (output string, err error) {
buf := new(bytes.Buffer)
root.SetOutput(buf)
root.SetArgs(args)

err = root.ExecuteContext(ctx)

return buf.String(), err
}

func executeCommandC(root *Command, args ...string) (c *Command, output string, err error) {
buf := new(bytes.Buffer)
root.SetOutput(buf)
Expand Down Expand Up @@ -135,6 +146,62 @@ func TestSubcommandExecuteC(t *testing.T) {
}
}

func TestExecuteContext(t *testing.T) {
ctx := context.TODO()

ctxRun := func(cmd *Command, args []string) {
if cmd.Context() != ctx {
t.Errorf("Command %q must have context when called with ExecuteContext", cmd.Use)
}
}

rootCmd := &Command{Use: "root", Run: ctxRun, PreRun: ctxRun}
childCmd := &Command{Use: "child", Run: ctxRun, PreRun: ctxRun}
granchildCmd := &Command{Use: "grandchild", Run: ctxRun, PreRun: ctxRun}

childCmd.AddCommand(granchildCmd)
rootCmd.AddCommand(childCmd)

if _, err := executeCommandWithContext(ctx, rootCmd, ""); err != nil {
t.Errorf("Root command must not fail: %+v", err)
}

if _, err := executeCommandWithContext(ctx, rootCmd, "child"); err != nil {
t.Errorf("Subcommand must not fail: %+v", err)
}

if _, err := executeCommandWithContext(ctx, rootCmd, "child", "grandchild"); err != nil {
t.Errorf("Command child must not fail: %+v", err)
}
}

func TestExecute_NoContext(t *testing.T) {
run := func(cmd *Command, args []string) {
if cmd.Context() != context.Background() {
t.Errorf("Command %s must have background context", cmd.Use)
}
}

rootCmd := &Command{Use: "root", Run: run, PreRun: run}
childCmd := &Command{Use: "child", Run: run, PreRun: run}
granchildCmd := &Command{Use: "grandchild", Run: run, PreRun: run}

childCmd.AddCommand(granchildCmd)
rootCmd.AddCommand(childCmd)

if _, err := executeCommand(rootCmd, ""); err != nil {
t.Errorf("Root command must not fail: %+v", err)
}

if _, err := executeCommand(rootCmd, "child"); err != nil {
t.Errorf("Subcommand must not fail: %+v", err)
}

if _, err := executeCommand(rootCmd, "child", "grandchild"); err != nil {
t.Errorf("Command child must not fail: %+v", err)
}
}

func TestRootUnknownCommandSilenced(t *testing.T) {
rootCmd := &Command{Use: "root", Run: emptyRun}
rootCmd.SilenceErrors = true
Expand Down