diff --git a/disttask/framework/dispatcher/BUILD.bazel b/disttask/framework/dispatcher/BUILD.bazel index 2cdc2c61945f7..54fc2033b7c1d 100644 --- a/disttask/framework/dispatcher/BUILD.bazel +++ b/disttask/framework/dispatcher/BUILD.bazel @@ -17,6 +17,7 @@ go_library( "//resourcemanager/util", "//sessionctx/variable", "//util", + "//util/disttask", "//util/logutil", "//util/syncutil", "@com_github_pingcap_errors//:errors", diff --git a/disttask/framework/dispatcher/dispatcher.go b/disttask/framework/dispatcher/dispatcher.go index 8078dab8c5cfd..eb6de9c96e854 100644 --- a/disttask/framework/dispatcher/dispatcher.go +++ b/disttask/framework/dispatcher/dispatcher.go @@ -27,6 +27,7 @@ import ( "github.com/pingcap/tidb/resourcemanager/util" "github.com/pingcap/tidb/sessionctx/variable" tidbutil "github.com/pingcap/tidb/util" + disttaskutil "github.com/pingcap/tidb/util/disttask" "github.com/pingcap/tidb/util/logutil" "github.com/pingcap/tidb/util/syncutil" "go.uber.org/zap" @@ -438,7 +439,8 @@ func GetEligibleInstance(serverNodes []*infosync.ServerInfo, pos int) (string, e return "", errors.New("no available TiDB node") } pos = pos % len(serverNodes) - return serverNodes[pos].ID, nil + serverID := disttaskutil.GenerateExecID(serverNodes[pos].IP, serverNodes[pos].Port) + return serverID, nil } // GenerateSchedulerNodes generate a eligible TiDB nodes. 
@@ -474,9 +476,19 @@ func (d *dispatcher) GetAllSchedulerIDs(ctx context.Context, gTaskID int64) ([]s } ids := make([]string, 0, len(schedulerIDs)) for _, id := range schedulerIDs { - if _, ok := serverInfos[id]; ok { + if ok := matchServerInfo(serverInfos, id); ok { ids = append(ids, id) } } return ids, nil } + +func matchServerInfo(serverInfos map[string]*infosync.ServerInfo, schedulerID string) bool { + for _, serverInfo := range serverInfos { + serverID := disttaskutil.GenerateExecID(serverInfo.IP, serverInfo.Port) + if serverID == schedulerID { + return true + } + } + return false +} diff --git a/disttask/framework/dispatcher/dispatcher_test.go b/disttask/framework/dispatcher/dispatcher_test.go index bff2d4825fcaa..882c3b5d5b7e3 100644 --- a/disttask/framework/dispatcher/dispatcher_test.go +++ b/disttask/framework/dispatcher/dispatcher_test.go @@ -81,12 +81,17 @@ func TestGetInstance(t *testing.T) { // server ids: uuid0, uuid1 // subtask instance ids: nil uuids := []string{"ddl_id_1", "ddl_id_2"} + serverIDs := []string{"10.123.124.10:32457", "[ABCD:EF01:2345:6789:ABCD:EF01:2345:6789]:65535"} mockedAllServerInfos = map[string]*infosync.ServerInfo{ uuids[0]: { - ID: uuids[0], + ID: uuids[0], + IP: "10.123.124.10", + Port: 32457, }, uuids[1]: { - ID: uuids[1], + ID: uuids[1], + IP: "ABCD:EF01:2345:6789:ABCD:EF01:2345:6789", + Port: 65535, }, } require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/domain/infosync/mockGetAllServerInfo", makeFailpointRes(mockedAllServerInfos))) @@ -94,7 +99,10 @@ func TestGetInstance(t *testing.T) { require.NoError(t, err) instanceID, err = dispatcher.GetEligibleInstance(serverNodes, 0) require.NoError(t, err) - if instanceID != uuids[0] && instanceID != uuids[1] { + require.Equal(t, serverIDs[0], instanceID) + instanceID, err = dispatcher.GetEligibleInstance(serverNodes, 1) + require.NoError(t, err) + if instanceID != serverIDs[0] && instanceID != serverIDs[1] { require.FailNowf(t, "expected uuids:%d,%d, actual uuid:%d", 
uuids[0], uuids[1], instanceID) } instanceIDs, err = dsp.GetAllSchedulerIDs(ctx, 1) @@ -107,26 +115,26 @@ func TestGetInstance(t *testing.T) { subtask := &proto.Subtask{ Type: proto.TaskTypeExample, TaskID: gTaskID, - SchedulerID: uuids[1], + SchedulerID: serverIDs[1], } err = mgr.AddNewSubTask(gTaskID, subtask.SchedulerID, nil, subtask.Type, true) require.NoError(t, err) instanceIDs, err = dsp.GetAllSchedulerIDs(ctx, gTaskID) require.NoError(t, err) - require.Equal(t, []string{uuids[1]}, instanceIDs) + require.Equal(t, []string{serverIDs[1]}, instanceIDs) // server ids: uuid0, uuid1 // subtask instance ids: uuid0, uuid1 subtask = &proto.Subtask{ Type: proto.TaskTypeExample, TaskID: gTaskID, - SchedulerID: uuids[0], + SchedulerID: serverIDs[0], } err = mgr.AddNewSubTask(gTaskID, subtask.SchedulerID, nil, subtask.Type, true) require.NoError(t, err) instanceIDs, err = dsp.GetAllSchedulerIDs(ctx, gTaskID) require.NoError(t, err) - require.Len(t, instanceIDs, len(uuids)) - require.ElementsMatch(t, instanceIDs, uuids) + require.Len(t, instanceIDs, len(serverIDs)) + require.ElementsMatch(t, instanceIDs, serverIDs) require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/domain/infosync/mockGetAllServerInfo")) } diff --git a/domain/BUILD.bazel b/domain/BUILD.bazel index 5d246e0ff46fa..68b076f87ff66 100644 --- a/domain/BUILD.bazel +++ b/domain/BUILD.bazel @@ -58,6 +58,7 @@ go_library( "//util", "//util/chunk", "//util/dbterror", + "//util/disttask", "//util/domainutil", "//util/engine", "//util/etcd", diff --git a/domain/domain.go b/domain/domain.go index cfaedaa50ec03..b679b9b918b57 100644 --- a/domain/domain.go +++ b/domain/domain.go @@ -69,6 +69,7 @@ import ( "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util" "github.com/pingcap/tidb/util/dbterror" + disttaskutil "github.com/pingcap/tidb/util/disttask" "github.com/pingcap/tidb/util/domainutil" "github.com/pingcap/tidb/util/engine" "github.com/pingcap/tidb/util/etcd" @@ -1363,7 +1364,12 @@ func (do 
*Domain) InitDistTaskLoop(ctx context.Context) error { }) taskManager := storage.NewTaskManager(ctx, do.resourcePool) - schedulerManager, err := scheduler.NewManagerBuilder().BuildManager(ctx, do.ddl.GetID(), taskManager) + serverID := generateSubtaskExecID(ctx, do.ddl.GetID()) + if serverID == "" { + errMsg := fmt.Sprintf("TiDB node ID( = %s ) not found in available TiDB nodes list", do.ddl.GetID()) + return errors.New(errMsg) + } + schedulerManager, err := scheduler.NewManagerBuilder().BuildManager(ctx, serverID, taskManager) if err != nil { return err } @@ -1378,6 +1384,17 @@ func (do *Domain) InitDistTaskLoop(ctx context.Context) error { return nil } +func generateSubtaskExecID(ctx context.Context, ID string) string { + serverInfos, err := infosync.GetAllServerInfo(ctx) + if err != nil || len(serverInfos) == 0 { + return "" + } + if serverNode, ok := serverInfos[ID]; ok { + return disttaskutil.GenerateExecID(serverNode.IP, serverNode.Port) + } + return "" +} + func (do *Domain) distTaskFrameworkLoop(ctx context.Context, taskManager *storage.TaskManager, schedulerManager *scheduler.Manager) { schedulerManager.Start() logutil.BgLogger().Info("dist task scheduler started") diff --git a/util/disttask/BUILD.bazel b/util/disttask/BUILD.bazel new file mode 100644 index 0000000000000..35b9b91141034 --- /dev/null +++ b/util/disttask/BUILD.bazel @@ -0,0 +1,17 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "disttask", + srcs = ["idservice.go"], + importpath = "github.com/pingcap/tidb/util/disttask", + visibility = ["//visibility:public"], +) + +go_test( + name = "disttask_test", + timeout = "short", + srcs = ["idservice_test.go"], + embed = [":disttask"], + flaky = True, + deps = ["@com_github_stretchr_testify//require"], +) diff --git a/util/disttask/idservice.go b/util/disttask/idservice.go new file mode 100644 index 0000000000000..39a44fe2b05c1 --- /dev/null +++ b/util/disttask/idservice.go @@ -0,0 +1,28 @@ +// Copyright 
// GenerateExecID builds the exec_id value for a TiDB node by joining ip and port
// with net.JoinHostPort.
// The distributed task framework uses the resulting string as the server ID that
// correlates a subtask with the TiDB node that executes it.
func GenerateExecID(ip string, port uint) string {
	return net.JoinHostPort(ip, fmt.Sprint(port))
}
+ +package disttaskutiltest + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +// This testCase show GenerateExecID only generate string by input parametas +func TestGenServerID(t *testing.T) { + var str string + serverIO := GenerateExecID("", 0) + require.Equal(t, serverIO, ":0") + serverIO = GenerateExecID("10.124.122.25", 3456) + require.Equal(t, serverIO, "10.124.122.25:3456") + serverIO = GenerateExecID("10.124", 3456) + require.Equal(t, serverIO, "10.124:3456") + serverIO = GenerateExecID(str, 65537) + require.Equal(t, serverIO, ":65537") + // IPv6 testcase + serverIO = GenerateExecID("ABCD:EF01:2345:6789:ABCD:EF01:2345:6789", 65537) + require.Equal(t, serverIO, "[ABCD:EF01:2345:6789:ABCD:EF01:2345:6789]:65537") +} diff --git a/util/stringutil/string_util.go b/util/stringutil/string_util.go index 2f74c51976821..f7e23876dcdfe 100644 --- a/util/stringutil/string_util.go +++ b/util/stringutil/string_util.go @@ -391,7 +391,7 @@ func TrimUtf8String(str *string, trimmedNum int64) int64 { totalLenTrimmed := int64(0) for ; trimmedNum > 0; trimmedNum-- { length := Utf8Len((*str)[0]) // character length - (*str) = (*str)[length:] + *str = (*str)[length:] totalLenTrimmed += int64(length) } return totalLenTrimmed