Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Introduce shared function to suspend task roots #2650

Merged
merged 6 commits into from
Apr 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 33 additions & 82 deletions pkg/resources/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -348,24 +348,16 @@ func CreateTask(d *schema.ResourceData, meta interface{}) error {
precedingTasks := make([]sdk.SchemaObjectIdentifier, 0)
for _, dep := range after {
precedingTaskId := sdk.NewSchemaObjectIdentifier(databaseName, schemaName, dep)
rootTasks, err := sdk.GetRootTasks(client.Tasks, ctx, precedingTaskId)
tasksToResume, err := client.Tasks.SuspendRootTasks(ctx, precedingTaskId, taskId)
defer func() {
if err := client.Tasks.ResumeTasks(ctx, tasksToResume); err != nil {
log.Printf("[WARN] failed to resume tasks: %s", err)
}
}()
if err != nil {
return err
}
for _, rootTask := range rootTasks {
// if a root task is started, then it needs to be suspended before the child tasks can be created
if rootTask.IsStarted() {
err := suspendTask(ctx, client, rootTask.ID())
if err != nil {
return err
}

// resume the task after modifications are complete as long as it is not a standalone task
sfc-gh-asawicki marked this conversation as resolved.
Show resolved Hide resolved
if !(rootTask.Name == name) {
defer func(identifier sdk.SchemaObjectIdentifier) { _ = resumeTask(ctx, client, identifier) }(rootTask.ID())
}
}
}
precedingTasks = append(precedingTasks, precedingTaskId)
}
createRequest.WithAfter(precedingTasks)
Expand All @@ -392,7 +384,7 @@ func CreateTask(d *schema.ResourceData, meta interface{}) error {
}

func waitForTaskStart(ctx context.Context, client *sdk.Client, id sdk.SchemaObjectIdentifier) error {
err := resumeTask(ctx, client, id)
err := client.Tasks.Alter(ctx, sdk.NewAlterTaskRequest(id).WithResume(sdk.Bool(true)))
if err != nil {
return fmt.Errorf("error starting task %s err = %w", id.FullyQualifiedName(), err)
}
Expand All @@ -408,47 +400,22 @@ func waitForTaskStart(ctx context.Context, client *sdk.Client, id sdk.SchemaObje
})
}

// suspendTask sends an ALTER TASK request with the suspend flag set for the task
// identified by id. On failure a warning naming the task is logged and the
// error from Alter is returned unchanged; on success nil is returned.
func suspendTask(ctx context.Context, client *sdk.Client, id sdk.SchemaObjectIdentifier) error {
	request := sdk.NewAlterTaskRequest(id).WithSuspend(sdk.Bool(true))
	if alterErr := client.Tasks.Alter(ctx, request); alterErr != nil {
		log.Printf("[WARN] failed to suspend task %s", id.FullyQualifiedName())
		return alterErr
	}
	return nil
}

// resumeTask sends an ALTER TASK request with the resume flag set for the task
// identified by id. On failure a warning naming the task is logged and the
// error from Alter is returned unchanged; on success nil is returned.
func resumeTask(ctx context.Context, client *sdk.Client, id sdk.SchemaObjectIdentifier) error {
	request := sdk.NewAlterTaskRequest(id).WithResume(sdk.Bool(true))
	if alterErr := client.Tasks.Alter(ctx, request); alterErr != nil {
		log.Printf("[WARN] failed to resume task %s", id.FullyQualifiedName())
		return alterErr
	}
	return nil
}

// UpdateTask implements schema.UpdateFunc.
func UpdateTask(d *schema.ResourceData, meta interface{}) error {
client := meta.(*provider.Context).Client
ctx := context.Background()

taskId := helpers.DecodeSnowflakeID(d.Id()).(sdk.SchemaObjectIdentifier)

rootTasks, err := sdk.GetRootTasks(client.Tasks, ctx, taskId)
tasksToResume, err := client.Tasks.SuspendRootTasks(ctx, taskId, taskId)
defer func() {
if err := client.Tasks.ResumeTasks(ctx, tasksToResume); err != nil {
log.Printf("[WARN] failed to resume tasks: %s", err)
}
}()
if err != nil {
return err
}
for _, rootTask := range rootTasks {
// if a root task is started, then it needs to be suspended before the child tasks can be created
if rootTask.IsStarted() {
err := suspendTask(ctx, client, rootTask.ID())
if err != nil {
return err
}

// resume the task after modifications are complete as long as it is not a standalone task
if !(rootTask.Name == taskId.Name()) {
defer func(identifier sdk.SchemaObjectIdentifier) { _ = resumeTask(ctx, client, identifier) }(rootTask.ID())
}
}
}

if d.HasChange("warehouse") {
newWarehouse := d.Get("warehouse")
Expand Down Expand Up @@ -497,7 +464,9 @@ func UpdateTask(d *schema.ResourceData, meta interface{}) error {

if d.HasChange("after") {
// making changes to after require suspending the current task
if err := suspendTask(ctx, client, taskId); err != nil {
// (the task will be brought up to the correct running state in the "enabled" check at the bottom of Update function).
err := client.Tasks.Alter(ctx, sdk.NewAlterTaskRequest(taskId).WithSuspend(sdk.Bool(true)))
if err != nil {
return fmt.Errorf("error suspending task %s, err: %w", taskId.FullyQualifiedName(), err)
}

Expand Down Expand Up @@ -532,29 +501,19 @@ func UpdateTask(d *schema.ResourceData, meta interface{}) error {
toAdd = append(toAdd, sdk.NewSchemaObjectIdentifier(taskId.DatabaseName(), taskId.SchemaName(), dep))
}
}
// TODO [SNOW-1007541]: for now leaving old copy-pasted implementation; extract function for task suspension in following change
if len(toAdd) > 0 {
// need to suspend any new root tasks from dependencies before adding them
for _, dep := range toAdd {
rootTasks, err := sdk.GetRootTasks(client.Tasks, ctx, dep)
for _, depId := range toAdd {
tasksToResume, err := client.Tasks.SuspendRootTasks(ctx, depId, taskId)
defer func() {
if err := client.Tasks.ResumeTasks(ctx, tasksToResume); err != nil {
log.Printf("[WARN] failed to resume tasks: %s", err)
}
}()
if err != nil {
return err
}
for _, rootTask := range rootTasks {
// if a root task is started, then it needs to be suspended before the child tasks can be created
if rootTask.IsStarted() {
err := suspendTask(ctx, client, rootTask.ID())
if err != nil {
return err
}

// resume the task after modifications are complete as long as it is not a standalone task
if !(rootTask.Name == taskId.Name()) {
defer func(identifier sdk.SchemaObjectIdentifier) { _ = resumeTask(ctx, client, identifier) }(rootTask.ID())
}
}
}
}

if err := client.Tasks.Alter(ctx, sdk.NewAlterTaskRequest(taskId).WithAddAfter(toAdd)); err != nil {
return fmt.Errorf("error adding after dependencies from task %s", taskId.FullyQualifiedName())
}
Expand Down Expand Up @@ -702,10 +661,11 @@ func UpdateTask(d *schema.ResourceData, meta interface{}) error {
log.Printf("[WARN] failed to resume task %s", taskId.FullyQualifiedName())
}
} else {
if suspendTask(ctx, client, taskId) != nil {
return fmt.Errorf("[WARN] failed to suspend task %s", taskId.FullyQualifiedName())
if err := client.Tasks.Alter(ctx, sdk.NewAlterTaskRequest(taskId).WithSuspend(sdk.Bool(true))); err != nil {
return fmt.Errorf("failed to suspend task %s", taskId.FullyQualifiedName())
}
}

return ReadTask(d, meta)
}

Expand All @@ -716,24 +676,15 @@ func DeleteTask(d *schema.ResourceData, meta interface{}) error {

taskId := helpers.DecodeSnowflakeID(d.Id()).(sdk.SchemaObjectIdentifier)

rootTasks, err := sdk.GetRootTasks(client.Tasks, ctx, taskId)
tasksToResume, err := client.Tasks.SuspendRootTasks(ctx, taskId, taskId)
defer func() {
if err := client.Tasks.ResumeTasks(ctx, tasksToResume); err != nil {
log.Printf("[WARN] failed to resume tasks: %s", err)
}
}()
if err != nil {
return err
}
for _, rootTask := range rootTasks {
// if a root task is started, then it needs to be suspended before the child tasks can be created
if rootTask.IsStarted() {
err := suspendTask(ctx, client, rootTask.ID())
if err != nil {
return err
}

// resume the task after modifications are complete as long as it is not a standalone task
if !(rootTask.Name == taskId.Name()) {
defer func(identifier sdk.SchemaObjectIdentifier) { _ = resumeTask(ctx, client, identifier) }(rootTask.ID())
}
}
}

dropRequest := sdk.NewDropTaskRequest(taskId)
err = client.Tasks.Drop(ctx, dropRequest)
Expand Down
2 changes: 2 additions & 0 deletions pkg/sdk/tasks_gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ type Tasks interface {
ShowByID(ctx context.Context, id SchemaObjectIdentifier) (*Task, error)
Describe(ctx context.Context, id SchemaObjectIdentifier) (*Task, error)
Execute(ctx context.Context, request *ExecuteTaskRequest) error
SuspendRootTasks(ctx context.Context, taskId SchemaObjectIdentifier, id SchemaObjectIdentifier) ([]SchemaObjectIdentifier, error)
ResumeTasks(ctx context.Context, ids []SchemaObjectIdentifier) error
}

// CreateTaskOptions is based on https://docs.snowflake.com/en/sql-reference/sql/create-task.
Expand Down
44 changes: 44 additions & 0 deletions pkg/sdk/tasks_impl_gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package sdk
import (
"context"
"encoding/json"
"errors"
"log"
"slices"
"strings"

Expand Down Expand Up @@ -65,6 +67,48 @@ func (v *tasks) Execute(ctx context.Context, request *ExecuteTaskRequest) error
return validateAndExec(v.client, ctx, opts)
}

// TODO(SNOW-1277135): See if the taskId parameter (callers may pass a dependency id, "depId", here) is necessary or could be removed
sfc-gh-asawicki marked this conversation as resolved.
Show resolved Hide resolved
// SuspendRootTasks looks up the root tasks of taskId via GetRootTasks and
// suspends every one of them that reports IsStarted.
//
// It returns the identifiers of the started roots that the caller should resume
// once its modifications are done. A started root whose name equals id.Name()
// is suspended but left out of that list. If the root lookup fails, a nil list
// and the lookup error are returned. Individual suspension failures are logged,
// collected and returned joined into a single error; the list is still
// returned in that case (and still includes the roots whose suspension
// failed), so callers can resume whatever was touched.
func (v *tasks) SuspendRootTasks(ctx context.Context, taskId SchemaObjectIdentifier, id SchemaObjectIdentifier) ([]SchemaObjectIdentifier, error) {
	roots, err := GetRootTasks(v.client.Tasks, ctx, taskId)
	if err != nil {
		return nil, err
	}

	var failures []error
	toResume := make([]SchemaObjectIdentifier, 0)

	for _, root := range roots {
		// Only started roots have to be suspended before tasks in their graph can be modified.
		if !root.IsStarted() {
			continue
		}

		rootId := root.ID()
		request := NewAlterTaskRequest(rootId).WithSuspend(Bool(true))
		if alterErr := v.client.Tasks.Alter(ctx, request); alterErr != nil {
			log.Printf("[WARN] failed to suspend task %s", rootId.FullyQualifiedName())
			failures = append(failures, alterErr)
		}

		// A root that shares its name with id is treated as a standalone task and is not scheduled for resuming.
		// TODO(SNOW-1277135): Document the purpose of this check and why GetRootTasks needs a different value (taskId) than this comparison (id).
		if root.Name == id.Name() {
			continue
		}
		toResume = append(toResume, rootId)
	}

	return toResume, errors.Join(failures...)
}

// ResumeTasks sends an ALTER TASK request with the resume flag set for each
// identifier in ids, in order. A failure for one task does not stop the loop:
// it is logged as a warning, collected, and all collected failures are
// returned joined into one error. The result is nil when every call succeeds
// or ids is empty.
func (v *tasks) ResumeTasks(ctx context.Context, ids []SchemaObjectIdentifier) error {
	var failures []error
	for i := range ids {
		taskId := ids[i]
		request := NewAlterTaskRequest(taskId).WithResume(Bool(true))
		if alterErr := v.client.Tasks.Alter(ctx, request); alterErr != nil {
			log.Printf("[WARN] failed to resume task %s", taskId.FullyQualifiedName())
			failures = append(failures, alterErr)
		}
	}
	return errors.Join(failures...)
}

// GetRootTasks is a way to get all root tasks for the given tasks.
// Snowflake does not have (yet) a method to do it without traversing the task graph manually.
// Task DAG should have a single root but this is checked when the root task is being resumed; that's why we return here multiple roots.
Expand Down
131 changes: 131 additions & 0 deletions pkg/sdk/testint/tasks_gen_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -592,4 +592,135 @@ func TestInt_Tasks(t *testing.T) {
err := client.Tasks.Execute(ctx, executeRequest)
require.NoError(t, err)
})

t.Run("temporarily suspend root tasks", func(t *testing.T) {
sfc-gh-asawicki marked this conversation as resolved.
Show resolved Hide resolved
rootTaskId := sdk.NewSchemaObjectIdentifier(testDb(t).Name, testSchema(t).Name, random.String())
rootTask := createTaskWithRequest(t, sdk.NewCreateTaskRequest(rootTaskId, sql).WithSchedule(sdk.String("60 minutes")))

id := sdk.NewSchemaObjectIdentifier(testDb(t).Name, testSchema(t).Name, random.String())
task := createTaskWithRequest(t, sdk.NewCreateTaskRequest(id, sql).WithAfter([]sdk.SchemaObjectIdentifier{rootTask.ID()}))

require.NoError(t, client.Tasks.Alter(ctx, sdk.NewAlterTaskRequest(rootTask.ID()).WithResume(sdk.Bool(true))))
t.Cleanup(func() {
require.NoError(t, client.Tasks.Alter(ctx, sdk.NewAlterTaskRequest(rootTask.ID()).WithSuspend(sdk.Bool(true))))
})

tasksToResume, err := client.Tasks.SuspendRootTasks(ctx, task.ID(), task.ID())
require.NoError(t, err)
require.NotEmpty(t, tasksToResume)

rootTaskStatus, err := client.Tasks.ShowByID(ctx, rootTask.ID())
require.NoError(t, err)
require.Equal(t, sdk.TaskStateSuspended, rootTaskStatus.State)

require.NoError(t, client.Tasks.ResumeTasks(ctx, tasksToResume))

rootTaskStatus, err = client.Tasks.ShowByID(ctx, rootTask.ID())
require.NoError(t, err)
require.Equal(t, sdk.TaskStateStarted, rootTaskStatus.State)
})

t.Run("resume root tasks within a graph containing more than one root task", func(t *testing.T) {
rootTaskId := sdk.NewSchemaObjectIdentifier(testDb(t).Name, testSchema(t).Name, random.String())
rootTask := createTaskWithRequest(t, sdk.NewCreateTaskRequest(rootTaskId, sql).WithSchedule(sdk.String("60 minutes")))

secondRootTaskId := sdk.NewSchemaObjectIdentifier(testDb(t).Name, testSchema(t).Name, random.String())
secondRootTask := createTaskWithRequest(t, sdk.NewCreateTaskRequest(secondRootTaskId, sql).WithSchedule(sdk.String("60 minutes")))

id := sdk.NewSchemaObjectIdentifier(testDb(t).Name, testSchema(t).Name, random.String())
_ = createTaskWithRequest(t, sdk.NewCreateTaskRequest(id, sql).WithAfter([]sdk.SchemaObjectIdentifier{rootTask.ID(), secondRootTask.ID()}))

require.ErrorContains(t, client.Tasks.Alter(ctx, sdk.NewAlterTaskRequest(rootTask.ID()).WithResume(sdk.Bool(true))), "The graph has more than one root task (one without predecessors)")
require.ErrorContains(t, client.Tasks.Alter(ctx, sdk.NewAlterTaskRequest(secondRootTask.ID()).WithResume(sdk.Bool(true))), "The graph has more than one root task (one without predecessors)")
})

t.Run("suspend root tasks temporarily with three sequentially connected tasks - last in DAG", func(t *testing.T) {
rootTaskId := sdk.NewSchemaObjectIdentifier(testDb(t).Name, testSchema(t).Name, random.String())
rootTask := createTaskWithRequest(t, sdk.NewCreateTaskRequest(rootTaskId, sql).WithSchedule(sdk.String("60 minutes")))

middleTaskId := sdk.NewSchemaObjectIdentifier(testDb(t).Name, testSchema(t).Name, random.String())
middleTask := createTaskWithRequest(t, sdk.NewCreateTaskRequest(middleTaskId, sql).WithAfter([]sdk.SchemaObjectIdentifier{rootTask.ID()}))

id := sdk.NewSchemaObjectIdentifier(testDb(t).Name, testSchema(t).Name, random.String())
task := createTaskWithRequest(t, sdk.NewCreateTaskRequest(id, sql).WithAfter([]sdk.SchemaObjectIdentifier{middleTask.ID()}))

require.NoError(t, client.Tasks.Alter(ctx, sdk.NewAlterTaskRequest(middleTask.ID()).WithResume(sdk.Bool(true))))
t.Cleanup(func() {
require.NoError(t, client.Tasks.Alter(ctx, sdk.NewAlterTaskRequest(middleTask.ID()).WithSuspend(sdk.Bool(true))))
})

require.NoError(t, client.Tasks.Alter(ctx, sdk.NewAlterTaskRequest(rootTask.ID()).WithResume(sdk.Bool(true))))
t.Cleanup(func() {
require.NoError(t, client.Tasks.Alter(ctx, sdk.NewAlterTaskRequest(rootTask.ID()).WithSuspend(sdk.Bool(true))))
})

tasksToResume, err := client.Tasks.SuspendRootTasks(ctx, task.ID(), task.ID())
require.NoError(t, err)
require.NotEmpty(t, tasksToResume)
require.Contains(t, tasksToResume, rootTask.ID())

rootTaskStatus, err := client.Tasks.ShowByID(ctx, rootTask.ID())
require.NoError(t, err)
require.Equal(t, sdk.TaskStateSuspended, rootTaskStatus.State)

middleTaskStatus, err := client.Tasks.ShowByID(ctx, middleTask.ID())
require.NoError(t, err)
require.Equal(t, sdk.TaskStateStarted, middleTaskStatus.State)

require.NoError(t, client.Tasks.ResumeTasks(ctx, tasksToResume))

rootTaskStatus, err = client.Tasks.ShowByID(ctx, rootTask.ID())
require.NoError(t, err)
require.Equal(t, sdk.TaskStateStarted, rootTaskStatus.State)

middleTaskStatus, err = client.Tasks.ShowByID(ctx, middleTask.ID())
require.NoError(t, err)
require.Equal(t, sdk.TaskStateStarted, middleTaskStatus.State)
})

t.Run("suspend root tasks temporarily with three sequentially connected tasks - middle in DAG", func(t *testing.T) {
rootTaskId := sdk.NewSchemaObjectIdentifier(testDb(t).Name, testSchema(t).Name, random.String())
rootTask := createTaskWithRequest(t, sdk.NewCreateTaskRequest(rootTaskId, sql).WithSchedule(sdk.String("60 minutes")))

middleTaskId := sdk.NewSchemaObjectIdentifier(testDb(t).Name, testSchema(t).Name, random.String())
middleTask := createTaskWithRequest(t, sdk.NewCreateTaskRequest(middleTaskId, sql).WithAfter([]sdk.SchemaObjectIdentifier{rootTask.ID()}))

childTaskId := sdk.NewSchemaObjectIdentifier(testDb(t).Name, testSchema(t).Name, random.String())
childTask := createTaskWithRequest(t, sdk.NewCreateTaskRequest(childTaskId, sql).WithAfter([]sdk.SchemaObjectIdentifier{middleTask.ID()}))

require.NoError(t, client.Tasks.Alter(ctx, sdk.NewAlterTaskRequest(childTask.ID()).WithResume(sdk.Bool(true))))
t.Cleanup(func() {
sfc-gh-asawicki marked this conversation as resolved.
Show resolved Hide resolved
require.NoError(t, client.Tasks.Alter(ctx, sdk.NewAlterTaskRequest(childTask.ID()).WithSuspend(sdk.Bool(true))))
})

require.NoError(t, client.Tasks.Alter(ctx, sdk.NewAlterTaskRequest(rootTask.ID()).WithResume(sdk.Bool(true))))
t.Cleanup(func() {
require.NoError(t, client.Tasks.Alter(ctx, sdk.NewAlterTaskRequest(rootTask.ID()).WithSuspend(sdk.Bool(true))))
})

tasksToResume, err := client.Tasks.SuspendRootTasks(ctx, middleTask.ID(), middleTask.ID())
sfc-gh-asawicki marked this conversation as resolved.
Show resolved Hide resolved
require.NoError(t, err)
require.NotEmpty(t, tasksToResume)

rootTaskStatus, err := client.Tasks.ShowByID(ctx, rootTask.ID())
require.NoError(t, err)
require.Equal(t, sdk.TaskStateSuspended, rootTaskStatus.State)

childTaskStatus, err := client.Tasks.ShowByID(ctx, childTask.ID())
require.NoError(t, err)
require.Equal(t, sdk.TaskStateStarted, childTaskStatus.State)
sfc-gh-asawicki marked this conversation as resolved.
Show resolved Hide resolved

require.NoError(t, client.Tasks.ResumeTasks(ctx, tasksToResume))

rootTaskStatus, err = client.Tasks.ShowByID(ctx, rootTask.ID())
require.NoError(t, err)
require.Equal(t, sdk.TaskStateStarted, rootTaskStatus.State)

childTaskStatus, err = client.Tasks.ShowByID(ctx, childTask.ID())
require.NoError(t, err)
require.Equal(t, sdk.TaskStateStarted, childTaskStatus.State)
})

// TODO(SNOW-1277135): Create more tests with different sets of roots/children and see if the current implementation
// acts correctly in certain situations/edge cases.
}
Loading