From 79443c162860bd3e9f4be68bc79b16634ab68a81 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Cie=C5=9Blak?= Date: Wed, 27 Mar 2024 11:55:14 +0100 Subject: [PATCH] changes after review --- pkg/resources/task.go | 36 ++++-- pkg/sdk/tasks_gen.go | 3 +- pkg/sdk/tasks_impl_gen.go | 35 +++--- pkg/sdk/testint/tasks_gen_integration_test.go | 114 +++++++++++++++++- 4 files changed, 149 insertions(+), 39 deletions(-) diff --git a/pkg/resources/task.go b/pkg/resources/task.go index 8a11de4b2b6..6199ae00ef5 100644 --- a/pkg/resources/task.go +++ b/pkg/resources/task.go @@ -2,7 +2,6 @@ package resources import ( "context" - "errors" "fmt" "log" "slices" @@ -285,7 +284,7 @@ func ReadTask(d *schema.ResourceData, meta interface{}) error { } // CreateTask implements schema.CreateFunc. -func CreateTask(d *schema.ResourceData, meta interface{}) (returnedErr error) { +func CreateTask(d *schema.ResourceData, meta interface{}) error { client := meta.(*provider.Context).Client ctx := context.Background() @@ -349,11 +348,16 @@ func CreateTask(d *schema.ResourceData, meta interface{}) (returnedErr error) { precedingTasks := make([]sdk.SchemaObjectIdentifier, 0) for _, dep := range after { precedingTaskId := sdk.NewSchemaObjectIdentifier(databaseName, schemaName, dep) - resumeSuspended, err := client.Tasks.TemporarilySuspendRootTasks(ctx, precedingTaskId, taskId) - defer func() { returnedErr = errors.Join(returnedErr, resumeSuspended()) }() + tasksToResume, err := client.Tasks.SuspendRootTasks(ctx, precedingTaskId, taskId) + defer func() { + if err := client.Tasks.ResumeTasks(ctx, tasksToResume); err != nil { + log.Printf("[WARN] failed to resume tasks: %s", err) + } + }() if err != nil { return err } + precedingTasks = append(precedingTasks, precedingTaskId) } createRequest.WithAfter(precedingTasks) @@ -397,14 +401,18 @@ func waitForTaskStart(ctx context.Context, client *sdk.Client, id sdk.SchemaObje } // UpdateTask implements schema.UpdateFunc. -func UpdateTask(d *schema.ResourceData, meta interface{}) (returnedErr error) { +func UpdateTask(d *schema.ResourceData, meta interface{}) error { client := meta.(*provider.Context).Client ctx := context.Background() taskId := helpers.DecodeSnowflakeID(d.Id()).(sdk.SchemaObjectIdentifier) - resumeSuspended, err := client.Tasks.TemporarilySuspendRootTasks(ctx, taskId, taskId) - defer func() { returnedErr = errors.Join(returnedErr, resumeSuspended()) }() + tasksToResume, err := client.Tasks.SuspendRootTasks(ctx, taskId, taskId) + defer func() { + if err := client.Tasks.ResumeTasks(ctx, tasksToResume); err != nil { + log.Printf("[WARN] failed to resume tasks: %s", err) + } + }() if err != nil { return err } @@ -495,8 +503,12 @@ func UpdateTask(d *schema.ResourceData, meta interface{}) (returnedErr error) { } if len(toAdd) > 0 { for _, depId := range toAdd { - resumeSuspended, err := client.Tasks.TemporarilySuspendRootTasks(ctx, depId, taskId) - defer func() { returnedErr = errors.Join(returnedErr, resumeSuspended()) }() + tasksToResume, err := client.Tasks.SuspendRootTasks(ctx, depId, taskId) + defer func() { + if err := client.Tasks.ResumeTasks(ctx, tasksToResume); err != nil { + log.Printf("[WARN] failed to resume tasks: %s", err) + } + }() if err != nil { return err } @@ -664,10 +676,10 @@ func DeleteTask(d *schema.ResourceData, meta interface{}) error { taskId := helpers.DecodeSnowflakeID(d.Id()).(sdk.SchemaObjectIdentifier) - resumeSuspended, err := client.Tasks.TemporarilySuspendRootTasks(ctx, taskId, taskId) + tasksToResume, err := client.Tasks.SuspendRootTasks(ctx, taskId, taskId) defer func() { - if err := resumeSuspended(); err != nil { - log.Printf("[WARN] failed to resume suspended task: %s", err) + if err := client.Tasks.ResumeTasks(ctx, tasksToResume); err != nil { + log.Printf("[WARN] failed to resume tasks: %s", err) } }() if err != nil { diff --git a/pkg/sdk/tasks_gen.go b/pkg/sdk/tasks_gen.go index 2134fe94824..41ab5d271ab 100644 --- a/pkg/sdk/tasks_gen.go +++ b/pkg/sdk/tasks_gen.go @@ -14,7 +14,8 @@ type Tasks interface { ShowByID(ctx context.Context, id SchemaObjectIdentifier) (*Task, error) Describe(ctx context.Context, id SchemaObjectIdentifier) (*Task, error) Execute(ctx context.Context, request *ExecuteTaskRequest) error - TemporarilySuspendRootTasks(ctx context.Context, depId SchemaObjectIdentifier, id SchemaObjectIdentifier) (func() error, error) + SuspendRootTasks(ctx context.Context, depId SchemaObjectIdentifier, id SchemaObjectIdentifier) ([]SchemaObjectIdentifier, error) + ResumeTasks(ctx context.Context, ids []SchemaObjectIdentifier) error } // CreateTaskOptions is based on https://docs.snowflake.com/en/sql-reference/sql/create-task. diff --git a/pkg/sdk/tasks_impl_gen.go b/pkg/sdk/tasks_impl_gen.go index 798fbcfed46..13169d1fa7c 100644 --- a/pkg/sdk/tasks_impl_gen.go +++ b/pkg/sdk/tasks_impl_gen.go @@ -67,17 +67,10 @@ func (v *tasks) Execute(ctx context.Context, request *ExecuteTaskRequest) error return validateAndExec(v.client, ctx, opts) } -// TemporarilySuspendRootTasks takes in the depId for which root tasks will be searched. Then, for all root tasks, -// check if the task is started. If it is, then suspend it and add to the list of tasks to resume only if the root task name -// is not that same as taskId name. -// -// Returns: -// A callback function to resume all the suspended root tasks. -// An error joined from all the suspend calls, nil if no error was returned during by task suspending calls. -func (v *tasks) TemporarilySuspendRootTasks(ctx context.Context, depId SchemaObjectIdentifier, taskId SchemaObjectIdentifier) (func() error, error) { +func (v *tasks) SuspendRootTasks(ctx context.Context, depId SchemaObjectIdentifier, id SchemaObjectIdentifier) ([]SchemaObjectIdentifier, error) { rootTasks, err := GetRootTasks(v.client.Tasks, ctx, depId) if err != nil { - return func() error { return nil }, err + return nil, err } tasksToResume := make([]SchemaObjectIdentifier, 0) @@ -92,24 +85,26 @@ func (v *tasks) TemporarilySuspendRootTasks(ctx context.Context, depId SchemaObj } // Resume the task after modifications are complete as long as it is not a standalone task - if rootTask.Name != taskId.Name() { + if rootTask.Name != id.Name() { tasksToResume = append(tasksToResume, rootTask.ID()) } suspendErrs = append(suspendErrs, err) } } - return func() error { - resumeErrs := make([]error, 0) - for _, taskId := range tasksToResume { - err := v.client.Tasks.Alter(ctx, NewAlterTaskRequest(taskId).WithResume(Bool(true))) - if err != nil { - log.Printf("[WARN] failed to resume task %s", taskId.FullyQualifiedName()) - } - resumeErrs = append(resumeErrs, err) + return tasksToResume, errors.Join(suspendErrs...) +} + +func (v *tasks) ResumeTasks(ctx context.Context, ids []SchemaObjectIdentifier) error { + resumeErrs := make([]error, 0) + for _, id := range ids { + err := v.client.Tasks.Alter(ctx, NewAlterTaskRequest(id).WithResume(Bool(true))) + if err != nil { + log.Printf("[WARN] failed to resume task %s", id.FullyQualifiedName()) } - return errors.Join(resumeErrs...) - }, errors.Join(suspendErrs...) + resumeErrs = append(resumeErrs, err) + } + return errors.Join(resumeErrs...) } // GetRootTasks is a way to get all root tasks for the given tasks. diff --git a/pkg/sdk/testint/tasks_gen_integration_test.go b/pkg/sdk/testint/tasks_gen_integration_test.go index 2302afabf23..434dff5b3e0 100644 --- a/pkg/sdk/testint/tasks_gen_integration_test.go +++ b/pkg/sdk/testint/tasks_gen_integration_test.go @@ -605,17 +605,119 @@ func TestInt_Tasks(t *testing.T) { require.NoError(t, client.Tasks.Alter(ctx, sdk.NewAlterTaskRequest(rootTask.ID()).WithSuspend(sdk.Bool(true)))) }) - resumeRoots, err := client.Tasks.TemporarilySuspendRootTasks(ctx, task.ID(), task.ID()) + tasksToResume, err := client.Tasks.SuspendRootTasks(ctx, task.ID(), task.ID()) require.NoError(t, err) + require.NotEmpty(t, tasksToResume) - rt, err := client.Tasks.ShowByID(ctx, rootTask.ID()) + rootTaskStatus, err := client.Tasks.ShowByID(ctx, rootTask.ID()) require.NoError(t, err) - require.Equal(t, sdk.TaskStateSuspended, rt.State) + require.Equal(t, sdk.TaskStateSuspended, rootTaskStatus.State) - require.NoError(t, resumeRoots()) + require.NoError(t, client.Tasks.ResumeTasks(ctx, tasksToResume)) - rt, err = client.Tasks.ShowByID(ctx, rootTask.ID()) + rootTaskStatus, err = client.Tasks.ShowByID(ctx, rootTask.ID()) require.NoError(t, err) - require.Equal(t, sdk.TaskStateStarted, rt.State) + require.Equal(t, sdk.TaskStateStarted, rootTaskStatus.State) + }) + + t.Run("resume root tasks within a graph containing more than one root task", func(t *testing.T) { + rootTaskId := sdk.NewSchemaObjectIdentifier(testDb(t).Name, testSchema(t).Name, random.String()) + rootTask := createTaskWithRequest(t, sdk.NewCreateTaskRequest(rootTaskId, sql).WithSchedule(sdk.String("60 minutes"))) + + secondRootTaskId := sdk.NewSchemaObjectIdentifier(testDb(t).Name, testSchema(t).Name, random.String()) + secondRootTask := createTaskWithRequest(t, sdk.NewCreateTaskRequest(secondRootTaskId, sql).WithSchedule(sdk.String("60 minutes"))) + + id := sdk.NewSchemaObjectIdentifier(testDb(t).Name, testSchema(t).Name, random.String()) + _ = createTaskWithRequest(t, sdk.NewCreateTaskRequest(id, sql).WithAfter([]sdk.SchemaObjectIdentifier{rootTask.ID(), secondRootTask.ID()})) + + require.ErrorContains(t, client.Tasks.Alter(ctx, sdk.NewAlterTaskRequest(rootTask.ID()).WithResume(sdk.Bool(true))), "The graph has more than one root task (one without predecessors)") + require.ErrorContains(t, client.Tasks.Alter(ctx, sdk.NewAlterTaskRequest(secondRootTask.ID()).WithResume(sdk.Bool(true))), "The graph has more than one root task (one without predecessors)") + }) + + t.Run("suspend root tasks temporarily with three sequentially connected tasks - last in DAG", func(t *testing.T) { + rootTaskId := sdk.NewSchemaObjectIdentifier(testDb(t).Name, testSchema(t).Name, random.String()) + rootTask := createTaskWithRequest(t, sdk.NewCreateTaskRequest(rootTaskId, sql).WithSchedule(sdk.String("60 minutes"))) + + middleTaskId := sdk.NewSchemaObjectIdentifier(testDb(t).Name, testSchema(t).Name, random.String()) + middleTask := createTaskWithRequest(t, sdk.NewCreateTaskRequest(middleTaskId, sql).WithAfter([]sdk.SchemaObjectIdentifier{rootTask.ID()})) + + id := sdk.NewSchemaObjectIdentifier(testDb(t).Name, testSchema(t).Name, random.String()) + task := createTaskWithRequest(t, sdk.NewCreateTaskRequest(id, sql).WithAfter([]sdk.SchemaObjectIdentifier{middleTask.ID()})) + + require.NoError(t, client.Tasks.Alter(ctx, sdk.NewAlterTaskRequest(middleTask.ID()).WithResume(sdk.Bool(true)))) + t.Cleanup(func() { + require.NoError(t, client.Tasks.Alter(ctx, sdk.NewAlterTaskRequest(middleTask.ID()).WithSuspend(sdk.Bool(true)))) + }) + + require.NoError(t, client.Tasks.Alter(ctx, sdk.NewAlterTaskRequest(rootTask.ID()).WithResume(sdk.Bool(true)))) + t.Cleanup(func() { + require.NoError(t, client.Tasks.Alter(ctx, sdk.NewAlterTaskRequest(rootTask.ID()).WithSuspend(sdk.Bool(true)))) + }) + + tasksToResume, err := client.Tasks.SuspendRootTasks(ctx, task.ID(), task.ID()) + require.NoError(t, err) + require.NotEmpty(t, tasksToResume) + require.Contains(t, tasksToResume, rootTask.ID()) + + rootTaskStatus, err := client.Tasks.ShowByID(ctx, rootTask.ID()) + require.NoError(t, err) + require.Equal(t, sdk.TaskStateSuspended, rootTaskStatus.State) + + middleTaskStatus, err := client.Tasks.ShowByID(ctx, middleTask.ID()) + require.NoError(t, err) + require.Equal(t, sdk.TaskStateStarted, middleTaskStatus.State) + + require.NoError(t, client.Tasks.ResumeTasks(ctx, tasksToResume)) + + rootTaskStatus, err = client.Tasks.ShowByID(ctx, rootTask.ID()) + require.NoError(t, err) + require.Equal(t, sdk.TaskStateStarted, rootTaskStatus.State) + + middleTaskStatus, err = client.Tasks.ShowByID(ctx, middleTask.ID()) + require.NoError(t, err) + require.Equal(t, sdk.TaskStateStarted, middleTaskStatus.State) + }) + + t.Run("suspend root tasks temporarily with three sequentially connected tasks - middle in DAG", func(t *testing.T) { + rootTaskId := sdk.NewSchemaObjectIdentifier(testDb(t).Name, testSchema(t).Name, random.String()) + rootTask := createTaskWithRequest(t, sdk.NewCreateTaskRequest(rootTaskId, sql).WithSchedule(sdk.String("60 minutes"))) + + middleTaskId := sdk.NewSchemaObjectIdentifier(testDb(t).Name, testSchema(t).Name, random.String()) + middleTask := createTaskWithRequest(t, sdk.NewCreateTaskRequest(middleTaskId, sql).WithAfter([]sdk.SchemaObjectIdentifier{rootTask.ID()})) + + childTaskId := sdk.NewSchemaObjectIdentifier(testDb(t).Name, testSchema(t).Name, random.String()) + childTask := createTaskWithRequest(t, sdk.NewCreateTaskRequest(childTaskId, sql).WithAfter([]sdk.SchemaObjectIdentifier{middleTask.ID()})) + + require.NoError(t, client.Tasks.Alter(ctx, sdk.NewAlterTaskRequest(childTask.ID()).WithResume(sdk.Bool(true)))) + t.Cleanup(func() { + require.NoError(t, client.Tasks.Alter(ctx, sdk.NewAlterTaskRequest(childTask.ID()).WithSuspend(sdk.Bool(true)))) + }) + + require.NoError(t, client.Tasks.Alter(ctx, sdk.NewAlterTaskRequest(rootTask.ID()).WithResume(sdk.Bool(true)))) + t.Cleanup(func() { + require.NoError(t, client.Tasks.Alter(ctx, sdk.NewAlterTaskRequest(rootTask.ID()).WithSuspend(sdk.Bool(true)))) + }) + + tasksToResume, err := client.Tasks.SuspendRootTasks(ctx, middleTask.ID(), middleTask.ID()) + require.NoError(t, err) + require.NotEmpty(t, tasksToResume) + + rootTaskStatus, err := client.Tasks.ShowByID(ctx, rootTask.ID()) + require.NoError(t, err) + require.Equal(t, sdk.TaskStateSuspended, rootTaskStatus.State) + + childTaskStatus, err := client.Tasks.ShowByID(ctx, childTask.ID()) + require.NoError(t, err) + require.Equal(t, sdk.TaskStateStarted, childTaskStatus.State) + + require.NoError(t, client.Tasks.ResumeTasks(ctx, tasksToResume)) + + rootTaskStatus, err = client.Tasks.ShowByID(ctx, rootTask.ID()) + require.NoError(t, err) + require.Equal(t, sdk.TaskStateStarted, rootTaskStatus.State) + + childTaskStatus, err = client.Tasks.ShowByID(ctx, childTask.ID()) + require.NoError(t, err) + require.Equal(t, sdk.TaskStateStarted, childTaskStatus.State) }) }