diff --git a/disttask/framework/BUILD.bazel b/disttask/framework/BUILD.bazel index eb442e93f8d94..604119b4b2322 100644 --- a/disttask/framework/BUILD.bazel +++ b/disttask/framework/BUILD.bazel @@ -11,7 +11,7 @@ go_test( ], flaky = True, race = "off", - shard_count = 22, + shard_count = 23, deps = [ "//disttask/framework/dispatcher", "//disttask/framework/mock", diff --git a/disttask/framework/framework_test.go b/disttask/framework/framework_test.go index c32a8953cddbd..5f2c0cb731922 100644 --- a/disttask/framework/framework_test.go +++ b/disttask/framework/framework_test.go @@ -17,6 +17,7 @@ package framework_test import ( "context" "errors" + "fmt" "sync" "testing" "time" @@ -131,6 +132,42 @@ func (e *testSubtaskExecutor1) Run(_ context.Context) error { return nil } +type testSubtaskExecutor2 struct { + m *sync.Map +} + +func (e *testSubtaskExecutor2) Run(_ context.Context) error { + e.m.Store("2", "2") + return nil +} + +type testSubtaskExecutor3 struct { + m *sync.Map +} + +func (e *testSubtaskExecutor3) Run(_ context.Context) error { + e.m.Store("3", "3") + return nil +} + +type testSubtaskExecutor4 struct { + m *sync.Map +} + +func (e *testSubtaskExecutor4) Run(_ context.Context) error { + e.m.Store("4", "4") + return nil +} + +type testSubtaskExecutor5 struct { + m *sync.Map +} + +func (e *testSubtaskExecutor5) Run(_ context.Context) error { + e.m.Store("5", "5") + return nil +} + func RegisterTaskMeta(t *testing.T, ctrl *gomock.Controller, m *sync.Map, dispatcherHandle dispatcher.Extension) { mockExtension := mock.NewMockExtension(ctrl) mockExtension.EXPECT().GetSubtaskExecutor(gomock.Any(), gomock.Any(), gomock.Any()).Return(&testScheduler{}, nil).AnyTimes() @@ -169,6 +206,70 @@ func registerTaskMetaInner(t *testing.T, mockExtension scheduler.Extension, disp ) } +func RegisterTaskMetaForExample2(t *testing.T, ctrl *gomock.Controller, m *sync.Map, dispatcherHandle dispatcher.Extension) { + mockExtension := mock.NewMockExtension(ctrl) + mockExtension.EXPECT().GetSubtaskExecutor(gomock.Any(), gomock.Any(), gomock.Any()).Return(&testScheduler{}, nil).AnyTimes() + mockExtension.EXPECT().GetMiniTaskExecutor(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( + func(minimalTask proto.MinimalTask, tp string, step int64) (execute.MiniTaskExecutor, error) { + switch step { + case proto.StepOne: + return &testSubtaskExecutor2{m: m}, nil + case proto.StepTwo: + return &testSubtaskExecutor3{m: m}, nil + } + panic("invalid step") + }).AnyTimes() + RegisterTaskMetaForExample2Inner(t, mockExtension, dispatcherHandle) +} + +func RegisterTaskMetaForExample2Inner(t *testing.T, mockExtension scheduler.Extension, dispatcherHandle dispatcher.Extension) { + dispatcher.RegisterDispatcherFactory(proto.TaskTypeExample2, + func(ctx context.Context, taskMgr *storage.TaskManager, serverID string, task *proto.Task) dispatcher.Dispatcher { + baseDispatcher := dispatcher.NewBaseDispatcher(ctx, taskMgr, serverID, task) + baseDispatcher.Extension = dispatcherHandle + return baseDispatcher + }) + scheduler.RegisterTaskType(proto.TaskTypeExample2, + func(ctx context.Context, id string, taskID int64, taskTable scheduler.TaskTable, pool scheduler.Pool) scheduler.Scheduler { + s := scheduler.NewBaseScheduler(ctx, id, taskID, taskTable, pool) + s.Extension = mockExtension + return s + }, + ) +} + +func RegisterTaskMetaForExample3(t *testing.T, ctrl *gomock.Controller, m *sync.Map, dispatcherHandle dispatcher.Extension) { + mockExtension := mock.NewMockExtension(ctrl) + mockExtension.EXPECT().GetSubtaskExecutor(gomock.Any(), gomock.Any(), gomock.Any()).Return(&testScheduler{}, nil).AnyTimes() + mockExtension.EXPECT().GetMiniTaskExecutor(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( + func(minimalTask proto.MinimalTask, tp string, step int64) (execute.MiniTaskExecutor, error) { + switch step { + case proto.StepOne: + return &testSubtaskExecutor4{m: m}, nil + case proto.StepTwo: + return &testSubtaskExecutor5{m: m}, nil + } + panic("invalid step") + }).AnyTimes() + RegisterTaskMetaForExample3Inner(t, mockExtension, dispatcherHandle) +} + +func RegisterTaskMetaForExample3Inner(t *testing.T, mockExtension scheduler.Extension, dispatcherHandle dispatcher.Extension) { + dispatcher.RegisterDispatcherFactory(proto.TaskTypeExample3, + func(ctx context.Context, taskMgr *storage.TaskManager, serverID string, task *proto.Task) dispatcher.Dispatcher { + baseDispatcher := dispatcher.NewBaseDispatcher(ctx, taskMgr, serverID, task) + baseDispatcher.Extension = dispatcherHandle + return baseDispatcher + }) + scheduler.RegisterTaskType(proto.TaskTypeExample3, + func(ctx context.Context, id string, taskID int64, taskTable scheduler.TaskTable, pool scheduler.Pool) scheduler.Scheduler { + s := scheduler.NewBaseScheduler(ctx, id, taskID, taskTable, pool) + s.Extension = mockExtension + return s + }, + ) +} + func DispatchTask(taskKey string, t *testing.T) *proto.Task { mgr, err := storage.GetTaskManager() require.NoError(t, err) @@ -227,6 +328,66 @@ func DispatchTaskAndCheckState(taskKey string, t *testing.T, m *sync.Map, state return true }) } +func DispatchMultiTasksAndOneFail(t *testing.T, num int, m []sync.Map) []*proto.Task { + var tasks []*proto.Task + var taskID []int64 + var start []time.Time + mgr, err := storage.GetTaskManager() + require.NoError(t, err) + taskID = make([]int64, num) + start = make([]time.Time, num) + tasks = make([]*proto.Task, num) + + for i := 0; i < num; i++ { + if i == 0 { + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/disttask/framework/scheduler/MockExecutorRunErr", "1*return(true)")) + taskID[0], err = mgr.AddNewGlobalTask("key0", "Example", 8, nil) + require.NoError(t, err) + start[0] = time.Now() + var task *proto.Task + for { + if time.Since(start[0]) > 2*time.Minute { + require.FailNow(t, "timeout") + } + time.Sleep(time.Second) + task, err = mgr.GetGlobalTaskByID(taskID[0]) + tasks[0] = task + require.NoError(t, err) + require.NotNil(t, task) + if task.State != proto.TaskStatePending && task.State != proto.TaskStateRunning && task.State != proto.TaskStateCancelling && task.State != proto.TaskStateReverting { + break + } + } + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/disttask/framework/scheduler/MockExecutorRunErr")) + } else { + taskID[i], err = mgr.AddNewGlobalTask(fmt.Sprintf("key%d", i), proto.Int2Type(i+2), 8, nil) + require.NoError(t, err) + start[i] = time.Now() + } + } + + for i := 1; i < num; i++ { + var task *proto.Task + for { + if time.Since(start[i]) > 2*time.Minute { + require.FailNow(t, "timeout") + } + time.Sleep(time.Second) + task, err = mgr.GetGlobalTaskByID(taskID[i]) + tasks[i] = task + require.NoError(t, err) + require.NotNil(t, task) + if task.State != proto.TaskStatePending && task.State != proto.TaskStateRunning && task.State != proto.TaskStateCancelling && task.State != proto.TaskStateReverting { + break + } + } + } + m[0].Range(func(key, value interface{}) bool { + m[0].Delete(key) + return true + }) + return tasks +} func TestFrameworkBasic(t *testing.T) { var m sync.Map @@ -424,3 +585,41 @@ func TestSchedulerDownManyNodes(t *testing.T) { distContext.Close() } + +func TestMultiTasks(t *testing.T) { + defer dispatcher.ClearDispatcherFactory() + defer scheduler.ClearSchedulers() + num := 3 + + m := make([]sync.Map, num) + ctrl := gomock.NewController(t) + defer ctrl.Finish() + RegisterTaskMeta(t, ctrl, &(m[0]), &testDispatcherExt{}) + RegisterTaskMetaForExample2(t, ctrl, &(m[1]), &testDispatcherExt{}) + RegisterTaskMetaForExample3(t, ctrl, &(m[2]), &testDispatcherExt{}) + + distContext := testkit.NewDistExecutionContext(t, 3) + tasks := DispatchMultiTasksAndOneFail(t, num, m) + require.Equal(t, proto.TaskStateReverted, tasks[0].State) + v, ok := m[0].Load("0") + require.Equal(t, false, ok) + require.Equal(t, nil, v) + v, ok = m[0].Load("1") + require.Equal(t, false, ok) + require.Equal(t, nil, v) + require.Equal(t, proto.TaskStateSucceed, tasks[1].State) + v, ok = m[1].Load("2") + require.Equal(t, true, ok) + require.Equal(t, "2", v) + v, ok = m[1].Load("3") + require.Equal(t, true, ok) + require.Equal(t, "3", v) + require.Equal(t, proto.TaskStateSucceed, tasks[2].State) + v, ok = m[2].Load("4") + require.Equal(t, true, ok) + require.Equal(t, "4", v) + v, ok = m[2].Load("5") + require.Equal(t, true, ok) + require.Equal(t, "5", v) + distContext.Close() +} diff --git a/disttask/framework/proto/task.go b/disttask/framework/proto/task.go index 12172a25f555c..307176c424cf3 100644 --- a/disttask/framework/proto/task.go +++ b/disttask/framework/proto/task.go @@ -118,6 +118,10 @@ type MinimalTask interface { const ( // TaskTypeExample is TaskType of Example. TaskTypeExample = "Example" + // TaskTypeExample2 is TaskType of Example. + TaskTypeExample2 = "Example1" + // TaskTypeExample3 is TaskType of Example. + TaskTypeExample3 = "Example2" // ImportInto is TaskType of ImportInto. ImportInto = "ImportInto" ) @@ -129,6 +133,10 @@ func Type2Int(t string) int { return 1 case ImportInto: return 2 + case TaskTypeExample2: + return 3 + case TaskTypeExample3: + return 4 default: return 0 } @@ -141,6 +149,10 @@ func Int2Type(i int) string { return TaskTypeExample case 2: return ImportInto + case 3: + return TaskTypeExample2 + case 4: + return TaskTypeExample3 default: return "" }