diff --git a/debug_runner.go b/debug_runner.go index 42f68f2..e7cbb7a 100644 --- a/debug_runner.go +++ b/debug_runner.go @@ -52,12 +52,12 @@ func (r *DebugRunner) Run(state StateBag) { pauseFn = DebugPauseDefault } - // Rebuild the steps so that we insert the pause step after each - steps := make([]Step, len(r.Steps)*2) + // Wrap steps to call PauseFn after each run and before each cleanup + steps := make([]Step, len(r.Steps)) for i, step := range r.Steps { - steps[i*2] = step - steps[(i*2)+1] = &debugStepPause{ + steps[i] = &debugStepPause{ reflect.Indirect(reflect.ValueOf(step)).Type().Name(), + step, pauseFn, } } @@ -97,14 +97,17 @@ func DebugPauseDefault(loc DebugLocation, name string, state StateBag) { type debugStepPause struct { StepName string + Step Step PauseFn DebugPauseFn } func (s *debugStepPause) Run(state StateBag) StepAction { + action := s.Step.Run(state) s.PauseFn(DebugLocationAfterRun, s.StepName, state) - return ActionContinue + return action } func (s *debugStepPause) Cleanup(state StateBag) { s.PauseFn(DebugLocationBeforeCleanup, s.StepName, state) + s.Step.Cleanup(state) } diff --git a/debug_runner_test.go b/debug_runner_test.go index a071da6..57476e3 100644 --- a/debug_runner_test.go +++ b/debug_runner_test.go @@ -172,3 +172,40 @@ func TestDebugPauseDefault(t *testing.T) { t.Fatal("didn't complete") } } + +// confirm that a halting step is debuggable before cleanup +func TestDebugRunner_Halt(t *testing.T) { + data := new(BasicStateBag) + stepA := &TestStepAcc{Data: "a"} + stepB := &TestStepAcc{Data: "b", Halt: true} + stepC := &TestStepAcc{Data: "c"} + + key := "data" + pauseFn := func(loc DebugLocation, name string, state StateBag) { + direction := "Run" + if loc == DebugLocationBeforeCleanup { + direction = "Cleanup" + } + + if _, ok := state.GetOk(key); !ok { + state.Put(key, []string{}) + } + + data := state.Get(key).([]string) + state.Put(key, append(data, direction)) + } + + r := &DebugRunner{ + Steps: []Step{stepA, stepB, stepC}, + PauseFn: pauseFn, + } + + r.Run(data) + + // Test data + expected := []string{"a", "Run", "b", "Run", "Cleanup", "Cleanup"} + results := data.Get("data").([]string) + if !reflect.DeepEqual(results, expected) { + t.Errorf("unexpected results: %#v", results) + } +}