From ac97ba1fa2cb461be754febc7391f138b9cdbd9d Mon Sep 17 00:00:00 2001 From: Ankit Kala Date: Fri, 26 Aug 2022 14:06:14 +0530 Subject: [PATCH] Fix for missing ShardReplicationTasks on new nodes (#497) Signed-off-by: Ankit Kala Signed-off-by: Ankit Kala (cherry picked from commit 805f686a34393cf2cf26cf7011b71396c26f4fe3) --- .../task/index/IndexReplicationTask.kt | 53 ++++++++-------- .../task/index/IndexReplicationTaskTests.kt | 61 ++++++++++++++++++- 2 files changed, 85 insertions(+), 29 deletions(-) diff --git a/src/main/kotlin/org/opensearch/replication/task/index/IndexReplicationTask.kt b/src/main/kotlin/org/opensearch/replication/task/index/IndexReplicationTask.kt index e499813d9..7e3d3e562 100644 --- a/src/main/kotlin/org/opensearch/replication/task/index/IndexReplicationTask.kt +++ b/src/main/kotlin/org/opensearch/replication/task/index/IndexReplicationTask.kt @@ -183,7 +183,8 @@ open class IndexReplicationTask(id: Long, type: String, action: String, descript ReplicationState.INIT_FOLLOW -> { log.info("Starting shard tasks") addIndexBlockForReplication() - startShardFollowTasks(emptyMap()) + FollowingState(startNewOrMissingShardTasks()) + } ReplicationState.FOLLOWING -> { if (currentTaskState is FollowingState) { @@ -206,8 +207,8 @@ open class IndexReplicationTask(id: Long, type: String, action: String, descript // Tasks need to be started state } else { - state = pollShardTaskStatus((followingTaskState as FollowingState).shardReplicationTasks) - followingTaskState = startMissingShardTasks((followingTaskState as FollowingState).shardReplicationTasks) + state = pollShardTaskStatus() + followingTaskState = FollowingState(startNewOrMissingShardTasks()) when (state) { is MonitoringState -> { updateMetadata() @@ -284,24 +285,7 @@ open class IndexReplicationTask(id: Long, type: String, action: String, descript clusterService.addListener(this) } - private suspend fun startMissingShardTasks(shardTasks: Map>): IndexReplicationState { - val persistentTasks = clusterService.state().metadata.custom(PersistentTasksCustomMetadata.TYPE) - - val runningShardTasks = persistentTasks.findTasks(ShardReplicationExecutor.TASK_NAME, Predicate { true }).stream() - .map { task -> task.params as ShardReplicationParams } - .collect(Collectors.toList()) - - val runningTasksForCurrentIndex = shardTasks.filter { entry -> runningShardTasks.find { task -> task.followerShardId == entry.key } != null} - - val numMissingTasks = shardTasks.size - runningTasksForCurrentIndex.size - if (numMissingTasks > 0) { - log.info("Starting $numMissingTasks missing shard task(s)") - return startShardFollowTasks(runningTasksForCurrentIndex) - } - return FollowingState(shardTasks) - } - - private suspend fun pollShardTaskStatus(shardTasks: Map>): IndexReplicationState { + private suspend fun pollShardTaskStatus(): IndexReplicationState { val failedShardTasks = findAllReplicationFailedShardTasks(followerIndexName, clusterService.state()) if (failedShardTasks.isNotEmpty()) { log.info("Failed shard tasks - ", failedShardTasks) @@ -342,11 +326,16 @@ open class IndexReplicationTask(id: Long, type: String, action: String, descript registerCloseListeners() val clusterState = clusterService.state() val persistentTasks = clusterState.metadata.custom(PersistentTasksCustomMetadata.TYPE) - val runningShardTasks = persistentTasks.findTasks(ShardReplicationExecutor.TASK_NAME, Predicate { true }).stream() + + val followerShardIds = clusterService.state().routingTable.indicesRouting().get(followerIndexName).shards() + .map { shard -> shard.value.shardId } + .stream().collect(Collectors.toSet()) + val runningShardTasksForIndex = persistentTasks.findTasks(ShardReplicationExecutor.TASK_NAME, Predicate { true }).stream() .map { task -> task.params as ShardReplicationParams } + .filter {taskParam -> followerShardIds.contains(taskParam.followerShardId) } .collect(Collectors.toList()) - if (runningShardTasks.size == 0) { + if (runningShardTasksForIndex.size != followerShardIds.size) { return InitFollowState } @@ -690,19 +679,27 @@ open class IndexReplicationTask(id: Long, type: String, action: String, descript } } - private suspend fun - startShardFollowTasks(tasks: Map>): FollowingState { + suspend fun startNewOrMissingShardTasks(): Map> { assert(clusterService.state().routingTable.hasIndex(followerIndexName)) { "Can't find index $followerIndexName" } val shards = clusterService.state().routingTable.indicesRouting().get(followerIndexName).shards() - val newTasks = shards.map { + val persistentTasks = clusterService.state().metadata.custom(PersistentTasksCustomMetadata.TYPE) + val runningShardTasks = persistentTasks.findTasks(ShardReplicationExecutor.TASK_NAME, Predicate { true }).stream() + .map { task -> task as PersistentTask } + .filter { task -> task.params!!.followerShardId.indexName == followerIndexName} + .collect(Collectors.toMap( + {t: PersistentTask -> t.params!!.followerShardId}, + {t: PersistentTask -> t})) + + val tasks = shards.map { it.value.shardId }.associate { shardId -> - val task = tasks.getOrElse(shardId) { + val task = runningShardTasks.getOrElse(shardId) { startReplicationTask(ShardReplicationParams(leaderAlias, ShardId(leaderIndex, shardId.id), shardId)) } return@associate shardId to task } - return FollowingState(newTasks) + + return tasks } private suspend fun cancelRestore() { diff --git a/src/test/kotlin/org/opensearch/replication/task/index/IndexReplicationTaskTests.kt b/src/test/kotlin/org/opensearch/replication/task/index/IndexReplicationTaskTests.kt index 95b3f6f27..569eadf8a 100644 --- a/src/test/kotlin/org/opensearch/replication/task/index/IndexReplicationTaskTests.kt +++ b/src/test/kotlin/org/opensearch/replication/task/index/IndexReplicationTaskTests.kt @@ -55,7 +55,6 @@ import org.opensearch.tasks.TaskManager import org.opensearch.test.ClusterServiceUtils import org.opensearch.test.ClusterServiceUtils.setState import org.opensearch.test.OpenSearchTestCase -import org.opensearch.test.OpenSearchTestCase.assertBusy import org.opensearch.threadpool.TestThreadPool import java.util.* import java.util.concurrent.TimeUnit @@ -150,6 +149,66 @@ class IndexReplicationTaskTests : OpenSearchTestCase() { } + fun testStartNewShardTasks() = runBlocking { + val replicationTask: IndexReplicationTask = spy(createIndexReplicationTask()) + var taskManager = Mockito.mock(TaskManager::class.java) + replicationTask.setPersistent(taskManager) + var rc = ReplicationContext(followerIndex) + var rm = ReplicationMetadata(connectionName, ReplicationStoreMetadataType.INDEX.name, ReplicationOverallState.RUNNING.name, "reason", rc, rc, Settings.EMPTY) + replicationTask.setReplicationMetadata(rm) + + // Build cluster state + val indices: MutableList = ArrayList() + indices.add(followerIndex) + var metadata = Metadata.builder() + .put(IndexMetadata.builder(REPLICATION_CONFIG_SYSTEM_INDEX).settings(settings(Version.CURRENT)).numberOfShards(1).numberOfReplicas(0)) + .put(IndexMetadata.builder(followerIndex).settings(settings(Version.CURRENT)).numberOfShards(2).numberOfReplicas(0)) + .build() + var routingTableBuilder = RoutingTable.builder() + .addAsNew(metadata.index(REPLICATION_CONFIG_SYSTEM_INDEX)) + .addAsNew(metadata.index(followerIndex)) + var newClusterState = ClusterState.builder(clusterService.state()).routingTable(routingTableBuilder.build()).build() + setState(clusterService, newClusterState) + + // Try starting shard tasks + val shardTasks = replicationTask.startNewOrMissingShardTasks() + assertThat(shardTasks.size == 2).isTrue + } + + + fun testStartMissingShardTasks() = runBlocking { + val replicationTask: IndexReplicationTask = spy(createIndexReplicationTask()) + var taskManager = Mockito.mock(TaskManager::class.java) + replicationTask.setPersistent(taskManager) + var rc = ReplicationContext(followerIndex) + var rm = ReplicationMetadata(connectionName, ReplicationStoreMetadataType.INDEX.name, ReplicationOverallState.RUNNING.name, "reason", rc, rc, Settings.EMPTY) + replicationTask.setReplicationMetadata(rm) + + // Build cluster state + val indices: MutableList = ArrayList() + indices.add(followerIndex) + + val tasks = PersistentTasksCustomMetadata.builder() + var sId = ShardId(Index(followerIndex, "_na_"), 0) + tasks.addTask( "replication:0", ShardReplicationExecutor.TASK_NAME, ShardReplicationParams("remoteCluster", sId, sId), + PersistentTasksCustomMetadata.Assignment("other_node_", "test assignment on other node")) + + var metadata = Metadata.builder() + .put(IndexMetadata.builder(REPLICATION_CONFIG_SYSTEM_INDEX).settings(settings(Version.CURRENT)).numberOfShards(1).numberOfReplicas(0)) + .put(IndexMetadata.builder(followerIndex).settings(settings(Version.CURRENT)).numberOfShards(2).numberOfReplicas(0)) + .putCustom(PersistentTasksCustomMetadata.TYPE, tasks.build()) + .build() + var routingTableBuilder = RoutingTable.builder() + .addAsNew(metadata.index(REPLICATION_CONFIG_SYSTEM_INDEX)) + .addAsNew(metadata.index(followerIndex)) + var newClusterState = ClusterState.builder(clusterService.state()).routingTable(routingTableBuilder.build()).build() + setState(clusterService, newClusterState) + + // Try starting shard tasks + val shardTasks = replicationTask.startNewOrMissingShardTasks() + assertThat(shardTasks.size == 2).isTrue + } + private fun createIndexReplicationTask() : IndexReplicationTask { var threadPool = TestThreadPool("IndexReplicationTask") //Hack Alert : Though it is meant to force rejection , this is to make overallTaskScope not null