diff --git a/comptplus.go b/comptplus.go index 6940b83..d6eba46 100644 --- a/comptplus.go +++ b/comptplus.go @@ -55,10 +55,10 @@ type CobraPrompt struct { OnErrorFunc func(err error) // HookAfter is a hook that will be executed every time after a command has been executed - HookAfter func(cmd *cobra.Command, input string) + HookAfter func(cmd *cobra.Command, input string) error // HookBefore is a hook that will be executed every time before a command is executed - HookBefore func(cmd *cobra.Command, input string) + HookBefore func(cmd *cobra.Command, input string) error // InArgsParser adds a custom parser for the command line arguments (default: strings.Fields) InArgsParser func(args string) []string @@ -83,11 +83,11 @@ func (co *CobraPrompt) RunContext(ctx context.Context) { } if co.HookBefore == nil { - co.HookBefore = func(_ *cobra.Command, _ string) {} + co.HookBefore = func(_ *cobra.Command, _ string) error { return nil } } if co.HookAfter == nil { - co.HookAfter = func(_ *cobra.Command, _ string) {} + co.HookAfter = func(_ *cobra.Command, _ string) error { return nil } } if co.CustomFlagResetBehaviour == nil { @@ -149,20 +149,34 @@ func (co *CobraPrompt) executeCommand(ctx context.Context) func(string) { os.Args = append([]string{os.Args[0]}, args...) executedCmd, _, _ := co.RootCmd.Find(os.Args[1:]) - co.HookBefore(executedCmd, input) + if err := co.HookBefore(executedCmd, input); err != nil { + co.handleUserError(err) + return + } if err := co.RootCmd.ExecuteContext(ctx); err != nil { - if co.OnErrorFunc != nil { - co.OnErrorFunc(err) - } else { - co.RootCmd.PrintErrln(err) - os.Exit(1) - } + co.handleUserError(err) + return } + if !co.PersistFlagValues { co.resetFlagsToDefault(executedCmd) } - co.HookAfter(executedCmd, input) + + if err := co.HookAfter(executedCmd, input); err != nil { + co.handleUserError(err) + return + } + } +} + +// handleUserError is a utility function to handle errors. +func (co *CobraPrompt) handleUserError(err error) { + if co.OnErrorFunc != nil { + co.OnErrorFunc(err) + } else { + co.RootCmd.PrintErrln(err) + os.Exit(1) } }