Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use child dir for branch taken #6120

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 22 additions & 9 deletions flytepropeller/pkg/controller/nodes/branch/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import (
"context"
"fmt"
"strconv"

"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flyte/flytepropeller/pkg/apis/flyteworkflow/v1alpha1"
Expand All @@ -15,6 +16,7 @@
stdErrors "github.com/flyteorg/flyte/flytestdlib/errors"
"github.com/flyteorg/flyte/flytestdlib/logger"
"github.com/flyteorg/flyte/flytestdlib/promutils"
"github.com/flyteorg/flyte/flytestdlib/storage"
)

type metrics struct {
Expand Down Expand Up @@ -74,8 +76,7 @@
childNodeStatus.SetParentNodeID(&i)

logger.Debugf(ctx, "Recursively executing branchNode's chosen path")
nodeStatus := nl.GetNodeExecutionStatus(ctx, nCtx.NodeID())
return b.recurseDownstream(ctx, nCtx, nodeStatus, finalNode)
return b.recurseDownstream(ctx, nCtx, finalNode)

Check warning on line 79 in flytepropeller/pkg/controller/nodes/branch/handler.go

View check run for this annotation

Codecov / codecov/patch

flytepropeller/pkg/controller/nodes/branch/handler.go#L79

Added line #L79 was not covered by tests
}

// If the branchNodestatus was already evaluated i.e, Node is in Running status
Expand All @@ -99,8 +100,7 @@
}

// Recurse downstream
nodeStatus := nl.GetNodeExecutionStatus(ctx, nCtx.NodeID())
return b.recurseDownstream(ctx, nCtx, nodeStatus, branchTakenNode)
return b.recurseDownstream(ctx, nCtx, branchTakenNode)

Check warning on line 103 in flytepropeller/pkg/controller/nodes/branch/handler.go

View check run for this annotation

Codecov / codecov/patch

flytepropeller/pkg/controller/nodes/branch/handler.go#L103

Added line #L103 was not covered by tests
}

func (b *branchHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecutionContext) (handler.Transition, error) {
Expand All @@ -123,7 +123,7 @@
return executors.NewExecutionContextWithParentInfo(nCtx.ExecutionContext(), newParentInfo), nil
}

func (b *branchHandler) recurseDownstream(ctx context.Context, nCtx interfaces.NodeExecutionContext, nodeStatus v1alpha1.ExecutableNodeStatus, branchTakenNode v1alpha1.ExecutableNode) (handler.Transition, error) {
func (b *branchHandler) recurseDownstream(ctx context.Context, nCtx interfaces.NodeExecutionContext, branchTakenNode v1alpha1.ExecutableNode) (handler.Transition, error) {
// TODO we should replace the call to RecursiveNodeHandler with a call to SingleNode Handler. The inputs are also already known ahead of time
// There is no DAGStructure for the branch nodes, the branch taken node is the leaf node. The node itself may be arbitrarily complex, but in that case the node should reference a subworkflow etc
// The parent of the BranchTaken Node is the actual Branch Node and all the data is just forwarded from the Branch to the executed node.
Expand All @@ -134,8 +134,16 @@
}

childNodeStatus := nl.GetNodeExecutionStatus(ctx, branchTakenNode.GetID())
childNodeStatus.SetDataDir(nodeStatus.GetDataDir())
childNodeStatus.SetOutputDir(nodeStatus.GetOutputDir())
childDataDir, err := nCtx.DataStore().ConstructReference(ctx, nCtx.NodeStatus().GetOutputDir(), branchTakenNode.GetID())
if err != nil {
return handler.UnknownTransition, err
}

Check warning on line 140 in flytepropeller/pkg/controller/nodes/branch/handler.go

View check run for this annotation

Codecov / codecov/patch

flytepropeller/pkg/controller/nodes/branch/handler.go#L139-L140

Added lines #L139 - L140 were not covered by tests
childOutputDir, err := nCtx.DataStore().ConstructReference(ctx, childDataDir, strconv.Itoa(int(childNodeStatus.GetAttempts())))
if err != nil {
return handler.UnknownTransition, err
}

Check warning on line 144 in flytepropeller/pkg/controller/nodes/branch/handler.go

View check run for this annotation

Codecov / codecov/patch

flytepropeller/pkg/controller/nodes/branch/handler.go#L143-L144

Added lines #L143 - L144 were not covered by tests
childNodeStatus.SetDataDir(childDataDir)
childNodeStatus.SetOutputDir(childOutputDir)
upstreamNodeIds, err := nCtx.ContextualNodeLookup().ToNode(branchTakenNode.GetID())
if err != nil {
return handler.UnknownTransition, err
Expand All @@ -151,9 +159,14 @@
}

if downstreamStatus.IsComplete() {
// For branch node we set the output node to be the same as the child nodes output
childOutputsPath := v1alpha1.GetOutputsFile(childOutputDir)
outputsPath := v1alpha1.GetOutputsFile(nCtx.NodeStatus().GetOutputDir())
if err := nCtx.DataStore().CopyRaw(ctx, childOutputsPath, outputsPath, storage.Options{}); err != nil {
errMsg := fmt.Sprintf("Failed to copy child node outputs from [%v] to [%v]", childOutputsPath, outputsPath)
return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure(core.ExecutionError_SYSTEM, errors.OutputsNotFoundError, errMsg, nil)), nil
}

Check warning on line 167 in flytepropeller/pkg/controller/nodes/branch/handler.go

View check run for this annotation

Codecov / codecov/patch

flytepropeller/pkg/controller/nodes/branch/handler.go#L165-L167

Added lines #L165 - L167 were not covered by tests
phase := handler.PhaseInfoSuccess(&handler.ExecutionInfo{
OutputInfo: &handler.OutputInfo{OutputURI: v1alpha1.GetOutputsFile(childNodeStatus.GetOutputDir())},
OutputInfo: &handler.OutputInfo{OutputURI: outputsPath},
})

return handler.DoTransition(handler.TransitionTypeEphemeral, phase), nil
Expand Down
31 changes: 16 additions & 15 deletions flytepropeller/pkg/controller/nodes/branch/handler_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package branch

import (
"bytes"
"context"
"fmt"
"testing"
Expand Down Expand Up @@ -110,6 +111,7 @@ func createNodeContext(phase v1alpha1.BranchNodePhase, childNodeID *v1alpha1.Nod

ns := &mocks2.ExecutableNodeStatus{}
ns.OnGetDataDir().Return(storage.DataReference("data-dir"))
ns.OnGetOutputDir().Return(storage.DataReference("output-dir"))
ns.OnGetPhase().Return(v1alpha1.NodePhaseNotYetStarted)

ir := &mocks3.InputReader{}
Expand Down Expand Up @@ -162,25 +164,24 @@ func TestBranchHandler_RecurseDownstream(t *testing.T) {
name string
ns interfaces.NodeStatus
err error
nodeStatus *mocks2.ExecutableNodeStatus
branchTakenNode v1alpha1.ExecutableNode
isErr bool
expectedPhase handler.EPhase
childPhase v1alpha1.NodePhase
upstreamNodeID string
}{
{"upstreamNodeExists", interfaces.NodeStatusPending, nil,
&mocks2.ExecutableNodeStatus{}, bn, false, handler.EPhaseRunning, v1alpha1.NodePhaseQueued, "n2"},
bn, false, handler.EPhaseRunning, v1alpha1.NodePhaseQueued, "n2"},
{"childNodeError", interfaces.NodeStatusUndefined, fmt.Errorf("err"),
&mocks2.ExecutableNodeStatus{}, bn, true, handler.EPhaseUndefined, v1alpha1.NodePhaseFailed, ""},
bn, true, handler.EPhaseUndefined, v1alpha1.NodePhaseFailed, ""},
{"childPending", interfaces.NodeStatusPending, nil,
&mocks2.ExecutableNodeStatus{}, bn, false, handler.EPhaseRunning, v1alpha1.NodePhaseQueued, ""},
bn, false, handler.EPhaseRunning, v1alpha1.NodePhaseQueued, ""},
{"childStillRunning", interfaces.NodeStatusRunning, nil,
&mocks2.ExecutableNodeStatus{}, bn, false, handler.EPhaseRunning, v1alpha1.NodePhaseRunning, ""},
bn, false, handler.EPhaseRunning, v1alpha1.NodePhaseRunning, ""},
{"childFailure", interfaces.NodeStatusFailed(expectedError), nil,
&mocks2.ExecutableNodeStatus{}, bn, false, handler.EPhaseFailed, v1alpha1.NodePhaseFailed, ""},
bn, false, handler.EPhaseFailed, v1alpha1.NodePhaseFailed, ""},
{"childComplete", interfaces.NodeStatusComplete, nil,
&mocks2.ExecutableNodeStatus{}, bn, false, handler.EPhaseSuccess, v1alpha1.NodePhaseSucceeded, ""},
bn, false, handler.EPhaseSuccess, v1alpha1.NodePhaseSucceeded, ""},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
Expand Down Expand Up @@ -222,16 +223,16 @@ func TestBranchHandler_RecurseDownstream(t *testing.T) {
).Return(test.ns, test.err)

childNodeStatus := &mocks2.ExecutableNodeStatus{}
if mockNodeLookup != nil {
childNodeStatus.OnGetOutputDir().Return("parent-output-dir")
test.nodeStatus.OnGetDataDir().Return("parent-data-dir")
test.nodeStatus.OnGetOutputDir().Return("parent-output-dir")
mockNodeLookup.OnGetNodeExecutionStatus(ctx, childNodeID).Return(childNodeStatus)
childNodeStatus.On("SetDataDir", storage.DataReference("parent-data-dir")).Once()
childNodeStatus.On("SetOutputDir", storage.DataReference("parent-output-dir")).Once()
childNodeStatus.OnGetAttempts().Return(0)
childNodeStatus.On("SetDataDir", storage.DataReference("/output-dir/child")).Once()
childNodeStatus.On("SetOutputDir", storage.DataReference("/output-dir/child/0")).Once()
mockNodeLookup.OnGetNodeExecutionStatus(ctx, childNodeID).Return(childNodeStatus)
if test.childPhase == v1alpha1.NodePhaseSucceeded {
_ = nCtx.DataStore().WriteRaw(ctx, storage.DataReference("/output-dir/child/0/outputs.pb"), 0, storage.Options{}, bytes.NewReader([]byte{}))
}

branch := New(mockNodeExecutor, eventConfig, promutils.NewTestScope()).(*branchHandler)
h, err := branch.recurseDownstream(ctx, nCtx, test.nodeStatus, test.branchTakenNode)
h, err := branch.recurseDownstream(ctx, nCtx, test.branchTakenNode)
if test.isErr {
assert.Error(t, err)
} else {
Expand Down
Loading