Skip to content

Commit

Permalink
Fixing caching on maptasks when using partials (#4344)
Browse files Browse the repository at this point in the history
* adding non-collection literals to statis input readers for cache key computation

Signed-off-by: Daniel Rammer <daniel@union.ai>

* fixed codespell and added unit tests

Signed-off-by: Daniel Rammer <daniel@union.ai>

---------

Signed-off-by: Daniel Rammer <daniel@union.ai>
  • Loading branch information
hamersaw authored Nov 6, 2023
1 parent 9fe34db commit eccf993
Show file tree
Hide file tree
Showing 4 changed files with 188 additions and 34 deletions.
42 changes: 15 additions & 27 deletions flyteplugins/go/tasks/plugins/array/catalog.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,21 +74,17 @@ func DetermineDiscoverability(ctx context.Context, tCtx core.TaskExecutionContex
return state, errors.Errorf(errors.MetadataAccessFailed, "Could not read inputs and therefore failed to determine array job size")
}

// identify and validate the size of the array job
size := -1
var literalCollection *idlCore.LiteralCollection
literals := make([][]*idlCore.Literal, 0)
discoveredInputNames := make([]string, 0)
for inputName, literal := range inputs.Literals {
for _, literal := range inputs.Literals {
if literalCollection = literal.GetCollection(); literalCollection != nil {
// validate length of input list
if size != -1 && size != len(literalCollection.Literals) {
state = state.SetPhase(arrayCore.PhasePermanentFailure, 0).SetReason("all maptask input lists must be the same length")
return state, nil
}

literals = append(literals, literalCollection.Literals)
discoveredInputNames = append(discoveredInputNames, inputName)

size = len(literalCollection.Literals)
}
}
Expand All @@ -110,7 +106,7 @@ func DetermineDiscoverability(ctx context.Context, tCtx core.TaskExecutionContex
arrayJobSize = int64(size)

// build input readers
inputReaders = ConstructStaticInputReaders(tCtx.InputReader(), literals, discoveredInputNames)
inputReaders = ConstructStaticInputReaders(tCtx.InputReader(), inputs.Literals, size)
}

if arrayJobSize > maxArrayJobSize {
Expand Down Expand Up @@ -246,18 +242,7 @@ func WriteToDiscovery(ctx context.Context, tCtx core.TaskExecutionContext, state
return state, externalResources, errors.Errorf(errors.MetadataAccessFailed, "Could not read inputs and therefore failed to determine array job size")
}

var literalCollection *idlCore.LiteralCollection
literals := make([][]*idlCore.Literal, 0)
discoveredInputNames := make([]string, 0)
for inputName, literal := range inputs.Literals {
if literalCollection = literal.GetCollection(); literalCollection != nil {
literals = append(literals, literalCollection.Literals)
discoveredInputNames = append(discoveredInputNames, inputName)
}
}

// build input readers
inputReaders = ConstructStaticInputReaders(tCtx.InputReader(), literals, discoveredInputNames)
inputReaders = ConstructStaticInputReaders(tCtx.InputReader(), inputs.Literals, arrayJobSize)
}

// output reader
Expand Down Expand Up @@ -476,16 +461,19 @@ func ConstructCatalogReaderWorkItems(ctx context.Context, taskReader core.TaskRe

// ConstructStaticInputReaders constructs input readers that comply with the io.InputReader interface but have their
// inputs already populated.
func ConstructStaticInputReaders(inputPaths io.InputFilePaths, inputs [][]*idlCore.Literal, inputNames []string) []io.InputReader {
inputReaders := make([]io.InputReader, 0, len(inputs))
if len(inputs) == 0 {
return inputReaders
}
func ConstructStaticInputReaders(inputPaths io.InputFilePaths, inputLiterals map[string]*idlCore.Literal, arrayJobSize int) []io.InputReader {
var literalCollection *idlCore.LiteralCollection

for i := 0; i < len(inputs[0]); i++ {
inputReaders := make([]io.InputReader, 0, arrayJobSize)
for i := 0; i < arrayJobSize; i++ {
literals := make(map[string]*idlCore.Literal)
for j := 0; j < len(inputNames); j++ {
literals[inputNames[j]] = inputs[j][i]
for inputName, inputLiteral := range inputLiterals {
if literalCollection = inputLiteral.GetCollection(); literalCollection != nil {
// if literal is a collection then we need to retrieve the specific literal for this subtask index
literals[inputName] = literalCollection.Literals[i]
} else {
literals[inputName] = inputLiteral
}
}

inputReaders = append(inputReaders, NewStaticInputReader(inputPaths, &idlCore.LiteralMap{Literals: literals}))
Expand Down
7 changes: 6 additions & 1 deletion flytepropeller/pkg/controller/nodes/array/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,12 @@ func (a *arrayNodeHandler) buildArrayNodeContext(ctx context.Context, nCtx inter
taskPhase := int(arrayNodeState.SubNodeTaskPhases.GetItem(subNodeIndex))

// need to initialize the inputReader every time to ensure TaskHandler can access for cache lookups / population
inputLiteralMap, err := constructLiteralMap(ctx, nCtx.InputReader(), subNodeIndex)
inputs, err := nCtx.InputReader().Get(ctx)
if err != nil {
return nil, nil, nil, nil, nil, nil, err
}

inputLiteralMap, err := constructLiteralMap(inputs, subNodeIndex)
if err != nil {
return nil, nil, nil, nil, nil, nil, err
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package array

import (
"context"
"fmt"

"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/io"
Expand All @@ -26,16 +27,16 @@ func newStaticInputReader(inputPaths io.InputFilePaths, input *core.LiteralMap)
}
}

func constructLiteralMap(ctx context.Context, inputReader io.InputReader, index int) (*core.LiteralMap, error) {
inputs, err := inputReader.Get(ctx)
if err != nil {
return nil, err
}

func constructLiteralMap(inputs *core.LiteralMap, index int) (*core.LiteralMap, error) {
literals := make(map[string]*core.Literal)
for name, literal := range inputs.Literals {
if literalCollection := literal.GetCollection(); literalCollection != nil {
if index >= len(literalCollection.Literals) {
return nil, fmt.Errorf("index %v out of bounds for literal collection %v", index, name)
}
literals[name] = literalCollection.Literals[index]
} else {
literals[name] = literal
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
package array

import (
"testing"

"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/assert"

"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core"
)

var (
literalOne = &core.Literal{
Value: &core.Literal_Scalar{
Scalar: &core.Scalar{
Value: &core.Scalar_Primitive{
Primitive: &core.Primitive{
Value: &core.Primitive_Integer{
Integer: 1,
},
},
},
},
},
}
literalTwo = &core.Literal{
Value: &core.Literal_Scalar{
Scalar: &core.Scalar{
Value: &core.Scalar_Primitive{
Primitive: &core.Primitive{
Value: &core.Primitive_Integer{
Integer: 2,
},
},
},
},
},
}
)

func TestConstructLiteralMap(t *testing.T) {
tests := []struct {
name string
inputLiteralMaps *core.LiteralMap
expectedLiteralMaps []*core.LiteralMap
}{
{
"SingleList",
&core.LiteralMap{
Literals: map[string]*core.Literal{
"foo": &core.Literal{
Value: &core.Literal_Collection{
Collection: &core.LiteralCollection{
Literals: []*core.Literal{
literalOne,
literalTwo,
},
},
},
},
},
},
[]*core.LiteralMap{
&core.LiteralMap{
Literals: map[string]*core.Literal{
"foo": literalOne,
},
},
&core.LiteralMap{
Literals: map[string]*core.Literal{
"foo": literalTwo,
},
},
},
},
{
"MultiList",
&core.LiteralMap{
Literals: map[string]*core.Literal{
"foo": &core.Literal{
Value: &core.Literal_Collection{
Collection: &core.LiteralCollection{
Literals: []*core.Literal{
literalOne,
literalTwo,
},
},
},
},
"bar": &core.Literal{
Value: &core.Literal_Collection{
Collection: &core.LiteralCollection{
Literals: []*core.Literal{
literalTwo,
literalOne,
},
},
},
},
},
},
[]*core.LiteralMap{
&core.LiteralMap{
Literals: map[string]*core.Literal{
"foo": literalOne,
"bar": literalTwo,
},
},
&core.LiteralMap{
Literals: map[string]*core.Literal{
"foo": literalTwo,
"bar": literalOne,
},
},
},
},
{
"Partial",
&core.LiteralMap{
Literals: map[string]*core.Literal{
"foo": &core.Literal{
Value: &core.Literal_Collection{
Collection: &core.LiteralCollection{
Literals: []*core.Literal{
literalOne,
literalTwo,
},
},
},
},
"bar": literalTwo,
},
},
[]*core.LiteralMap{
&core.LiteralMap{
Literals: map[string]*core.Literal{
"foo": literalOne,
"bar": literalTwo,
},
},
&core.LiteralMap{
Literals: map[string]*core.Literal{
"foo": literalTwo,
"bar": literalTwo,
},
},
},
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
for i := 0; i < len(test.expectedLiteralMaps); i++ {
outputLiteralMap, err := constructLiteralMap(test.inputLiteralMaps, i)
assert.NoError(t, err)
assert.True(t, proto.Equal(test.expectedLiteralMaps[i], outputLiteralMap))
}
})
}
}

0 comments on commit eccf993

Please sign in to comment.