From 561622d8e39c44d32764b0f62cbd4021c54a026e Mon Sep 17 00:00:00 2001
From: Himshikha Gupta <gupta.himshikha@gmail.com>
Date: Tue, 1 Oct 2024 14:48:10 +0530
Subject: [PATCH] Optimize checksum creation for remote cluster state (#16046)

* Support parallelisation in remote publication checksum computation

Signed-off-by: Himshikha Gupta <himshikh@amazon.com>
---
 .../gateway/remote/ClusterStateChecksum.java  | 150 ++++++++++++------
 .../remote/RemoteClusterStateService.java     |  14 +-
 .../org/opensearch/threadpool/ThreadPool.java |   7 +
 .../remote/ClusterMetadataManifestTests.java  |  15 +-
 .../remote/ClusterStateChecksumTests.java     |  32 ++--
 .../RemoteClusterStateServiceTests.java       |  28 ++--
 6 files changed, 164 insertions(+), 82 deletions(-)

diff --git a/server/src/main/java/org/opensearch/gateway/remote/ClusterStateChecksum.java b/server/src/main/java/org/opensearch/gateway/remote/ClusterStateChecksum.java
index d6739c4572d1a..aa007f5da15b3 100644
--- a/server/src/main/java/org/opensearch/gateway/remote/ClusterStateChecksum.java
+++ b/server/src/main/java/org/opensearch/gateway/remote/ClusterStateChecksum.java
@@ -12,8 +12,10 @@
 import org.apache.logging.log4j.Logger;
 import org.opensearch.cluster.ClusterState;
 import org.opensearch.cluster.metadata.DiffableStringMap;
+import org.opensearch.common.CheckedFunction;
 import org.opensearch.common.io.stream.BytesStreamOutput;
 import org.opensearch.common.settings.Settings;
+import org.opensearch.common.unit.TimeValue;
 import org.opensearch.core.common.io.stream.BufferedChecksumStreamOutput;
 import org.opensearch.core.common.io.stream.StreamInput;
 import org.opensearch.core.common.io.stream.StreamOutput;
@@ -22,11 +24,15 @@
 import org.opensearch.core.xcontent.XContentBuilder;
 import org.opensearch.core.xcontent.XContentParseException;
 import org.opensearch.core.xcontent.XContentParser;
+import org.opensearch.threadpool.ThreadPool;
 
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Objects;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutorService;
+import java.util.function.Consumer;
 
 import com.jcraft.jzlib.JZlib;
 
@@ -37,6 +43,7 @@
  */
 public class ClusterStateChecksum implements ToXContentFragment, Writeable {
 
+    public static final int COMPONENT_SIZE = 11;
     static final String ROUTING_TABLE_CS = "routing_table";
     static final String NODES_CS = "discovery_nodes";
     static final String BLOCKS_CS = "blocks";
@@ -65,62 +72,103 @@ public class ClusterStateChecksum implements ToXContentFragment, Writeable {
     long indicesChecksum;
     long clusterStateChecksum;
 
-    public ClusterStateChecksum(ClusterState clusterState) {
-        try (
-            BytesStreamOutput out = new BytesStreamOutput();
-            BufferedChecksumStreamOutput checksumOut = new BufferedChecksumStreamOutput(out)
-        ) {
-            clusterState.routingTable().writeVerifiableTo(checksumOut);
-            routingTableChecksum = checksumOut.getChecksum();
-
-            checksumOut.reset();
-            clusterState.nodes().writeVerifiableTo(checksumOut);
-            nodesChecksum = checksumOut.getChecksum();
-
-            checksumOut.reset();
-            clusterState.coordinationMetadata().writeVerifiableTo(checksumOut);
-            coordinationMetadataChecksum = checksumOut.getChecksum();
-
-            // Settings create sortedMap by default, so no explicit sorting required here.
-            checksumOut.reset();
-            Settings.writeSettingsToStream(clusterState.metadata().persistentSettings(), checksumOut);
-            settingMetadataChecksum = checksumOut.getChecksum();
-
-            checksumOut.reset();
-            Settings.writeSettingsToStream(clusterState.metadata().transientSettings(), checksumOut);
-            transientSettingsMetadataChecksum = checksumOut.getChecksum();
-
-            checksumOut.reset();
-            clusterState.metadata().templatesMetadata().writeVerifiableTo(checksumOut);
-            templatesMetadataChecksum = checksumOut.getChecksum();
-
-            checksumOut.reset();
-            checksumOut.writeStringCollection(clusterState.metadata().customs().keySet());
-            customMetadataMapChecksum = checksumOut.getChecksum();
-
-            checksumOut.reset();
-            ((DiffableStringMap) clusterState.metadata().hashesOfConsistentSettings()).writeTo(checksumOut);
-            hashesOfConsistentSettingsChecksum = checksumOut.getChecksum();
-
-            checksumOut.reset();
-            checksumOut.writeMapValues(
+    public ClusterStateChecksum(ClusterState clusterState, ThreadPool threadpool) {
+        long start = threadpool.relativeTimeInNanos();
+        ExecutorService executorService = threadpool.executor(ThreadPool.Names.REMOTE_STATE_CHECKSUM);
+        CountDownLatch latch = new CountDownLatch(COMPONENT_SIZE);
+
+        executeChecksumTask((stream) -> {
+            clusterState.routingTable().writeVerifiableTo(stream);
+            return null;
+        }, checksum -> routingTableChecksum = checksum, executorService, latch);
+
+        executeChecksumTask((stream) -> {
+            clusterState.nodes().writeVerifiableTo(stream);
+            return null;
+        }, checksum -> nodesChecksum = checksum, executorService, latch);
+
+        executeChecksumTask((stream) -> {
+            clusterState.coordinationMetadata().writeVerifiableTo(stream);
+            return null;
+        }, checksum -> coordinationMetadataChecksum = checksum, executorService, latch);
+
+        executeChecksumTask((stream) -> {
+            Settings.writeSettingsToStream(clusterState.metadata().persistentSettings(), stream);
+            return null;
+        }, checksum -> settingMetadataChecksum = checksum, executorService, latch);
+
+        executeChecksumTask((stream) -> {
+            Settings.writeSettingsToStream(clusterState.metadata().transientSettings(), stream);
+            return null;
+        }, checksum -> transientSettingsMetadataChecksum = checksum, executorService, latch);
+
+        executeChecksumTask((stream) -> {
+            clusterState.metadata().templatesMetadata().writeVerifiableTo(stream);
+            return null;
+        }, checksum -> templatesMetadataChecksum = checksum, executorService, latch);
+
+        executeChecksumTask((stream) -> {
+            stream.writeStringCollection(clusterState.metadata().customs().keySet());
+            return null;
+        }, checksum -> customMetadataMapChecksum = checksum, executorService, latch);
+
+        executeChecksumTask((stream) -> {
+            ((DiffableStringMap) clusterState.metadata().hashesOfConsistentSettings()).writeTo(stream);
+            return null;
+        }, checksum -> hashesOfConsistentSettingsChecksum = checksum, executorService, latch);
+
+        executeChecksumTask((stream) -> {
+            stream.writeMapValues(
                 clusterState.metadata().indices(),
-                (stream, value) -> value.writeVerifiableTo((BufferedChecksumStreamOutput) stream)
+                (checksumStream, value) -> value.writeVerifiableTo((BufferedChecksumStreamOutput) checksumStream)
             );
-            indicesChecksum = checksumOut.getChecksum();
-
-            checksumOut.reset();
-            clusterState.blocks().writeVerifiableTo(checksumOut);
-            blocksChecksum = checksumOut.getChecksum();
-
-            checksumOut.reset();
-            checksumOut.writeStringCollection(clusterState.customs().keySet());
-            clusterStateCustomsChecksum = checksumOut.getChecksum();
-        } catch (IOException e) {
-            logger.error("Failed to create checksum for cluster state.", e);
+            return null;
+        }, checksum -> indicesChecksum = checksum, executorService, latch);
+
+        executeChecksumTask((stream) -> {
+            clusterState.blocks().writeVerifiableTo(stream);
+            return null;
+        }, checksum -> blocksChecksum = checksum, executorService, latch);
+
+        executeChecksumTask((stream) -> {
+            stream.writeStringCollection(clusterState.customs().keySet());
+            return null;
+        }, checksum -> clusterStateCustomsChecksum = checksum, executorService, latch);
+
+        try {
+            latch.await();
+        } catch (InterruptedException e) {
             throw new RemoteStateTransferException("Failed to create checksum for cluster state.", e);
         }
         createClusterStateChecksum();
+        logger.debug("Checksum execution time {}", TimeValue.nsecToMSec(threadpool.relativeTimeInNanos() - start));
+    }
+
+    private void executeChecksumTask(
+        CheckedFunction<BufferedChecksumStreamOutput, Void, IOException> checksumTask,
+        Consumer<Long> checksumConsumer,
+        ExecutorService executorService,
+        CountDownLatch latch
+    ) {
+        executorService.execute(() -> {
+            try {
+                long checksum = createChecksum(checksumTask);
+                checksumConsumer.accept(checksum);
+                latch.countDown();
+            } catch (IOException e) {
+                throw new RemoteStateTransferException("Failed to execute checksum task", e);
+            }
+        });
+    }
+
+    private long createChecksum(CheckedFunction<BufferedChecksumStreamOutput, Void, IOException> task) throws IOException {
+        try (
+            BytesStreamOutput out = new BytesStreamOutput();
+            BufferedChecksumStreamOutput checksumOut = new BufferedChecksumStreamOutput(out)
+        ) {
+            task.apply(checksumOut);
+            return checksumOut.getChecksum();
+        }
     }
 
     private void createClusterStateChecksum() {
diff --git a/server/src/main/java/org/opensearch/gateway/remote/RemoteClusterStateService.java b/server/src/main/java/org/opensearch/gateway/remote/RemoteClusterStateService.java
index ece29180f9cf5..ce5e57b79dadb 100644
--- a/server/src/main/java/org/opensearch/gateway/remote/RemoteClusterStateService.java
+++ b/server/src/main/java/org/opensearch/gateway/remote/RemoteClusterStateService.java
@@ -332,7 +332,9 @@ public RemoteClusterStateManifestInfo writeFullMetadata(ClusterState clusterStat
             uploadedMetadataResults,
             previousClusterUUID,
             clusterStateDiffManifest,
-            !remoteClusterStateValidationMode.equals(RemoteClusterStateValidationMode.NONE) ? new ClusterStateChecksum(clusterState) : null,
+            !remoteClusterStateValidationMode.equals(RemoteClusterStateValidationMode.NONE)
+                ? new ClusterStateChecksum(clusterState, threadpool)
+                : null,
             false,
             codecVersion
         );
@@ -539,7 +541,9 @@ public RemoteClusterStateManifestInfo writeIncrementalMetadata(
             uploadedMetadataResults,
             previousManifest.getPreviousClusterUUID(),
             clusterStateDiffManifest,
-            !remoteClusterStateValidationMode.equals(RemoteClusterStateValidationMode.NONE) ? new ClusterStateChecksum(clusterState) : null,
+            !remoteClusterStateValidationMode.equals(RemoteClusterStateValidationMode.NONE)
+                ? new ClusterStateChecksum(clusterState, threadpool)
+                : null,
             false,
             previousManifest.getCodecVersion()
         );
@@ -1010,7 +1014,9 @@ public RemoteClusterStateManifestInfo markLastStateAsCommitted(
             uploadedMetadataResults,
             previousManifest.getPreviousClusterUUID(),
             previousManifest.getDiffManifest(),
-            !remoteClusterStateValidationMode.equals(RemoteClusterStateValidationMode.NONE) ? new ClusterStateChecksum(clusterState) : null,
+            !remoteClusterStateValidationMode.equals(RemoteClusterStateValidationMode.NONE)
+                ? new ClusterStateChecksum(clusterState, threadpool)
+                : null,
             true,
             previousManifest.getCodecVersion()
         );
@@ -1631,7 +1637,7 @@ void validateClusterStateFromChecksum(
         String localNodeId,
         boolean isFullStateDownload
     ) {
-        ClusterStateChecksum newClusterStateChecksum = new ClusterStateChecksum(clusterState);
+        ClusterStateChecksum newClusterStateChecksum = new ClusterStateChecksum(clusterState, threadpool);
         List<String> failedValidation = newClusterStateChecksum.getMismatchEntities(manifest.getClusterStateChecksum());
         if (failedValidation.isEmpty()) {
             return;
diff --git a/server/src/main/java/org/opensearch/threadpool/ThreadPool.java b/server/src/main/java/org/opensearch/threadpool/ThreadPool.java
index 81220ab171b34..d795fd252b7fc 100644
--- a/server/src/main/java/org/opensearch/threadpool/ThreadPool.java
+++ b/server/src/main/java/org/opensearch/threadpool/ThreadPool.java
@@ -53,6 +53,7 @@
 import org.opensearch.core.service.ReportingService;
 import org.opensearch.core.xcontent.ToXContentFragment;
 import org.opensearch.core.xcontent.XContentBuilder;
+import org.opensearch.gateway.remote.ClusterStateChecksum;
 import org.opensearch.node.Node;
 
 import java.io.IOException;
@@ -118,6 +119,7 @@ public static class Names {
         public static final String REMOTE_RECOVERY = "remote_recovery";
         public static final String REMOTE_STATE_READ = "remote_state_read";
         public static final String INDEX_SEARCHER = "index_searcher";
+        public static final String REMOTE_STATE_CHECKSUM = "remote_state_checksum";
     }
 
     /**
@@ -191,6 +193,7 @@ public static ThreadPoolType fromType(String type) {
         map.put(Names.REMOTE_RECOVERY, ThreadPoolType.SCALING);
         map.put(Names.REMOTE_STATE_READ, ThreadPoolType.SCALING);
         map.put(Names.INDEX_SEARCHER, ThreadPoolType.RESIZABLE);
+        map.put(Names.REMOTE_STATE_CHECKSUM, ThreadPoolType.FIXED);
         THREAD_POOL_TYPES = Collections.unmodifiableMap(map);
     }
 
@@ -307,6 +310,10 @@ public ThreadPool(
                 runnableTaskListener
             )
         );
+        builders.put(
+            Names.REMOTE_STATE_CHECKSUM,
+            new FixedExecutorBuilder(settings, Names.REMOTE_STATE_CHECKSUM, ClusterStateChecksum.COMPONENT_SIZE, 1000)
+        );
 
         for (final ExecutorBuilder<?> builder : customBuilders) {
             if (builders.containsKey(builder.name())) {
diff --git a/server/src/test/java/org/opensearch/gateway/remote/ClusterMetadataManifestTests.java b/server/src/test/java/org/opensearch/gateway/remote/ClusterMetadataManifestTests.java
index 3f9aa1245cab3..09c2933680be3 100644
--- a/server/src/test/java/org/opensearch/gateway/remote/ClusterMetadataManifestTests.java
+++ b/server/src/test/java/org/opensearch/gateway/remote/ClusterMetadataManifestTests.java
@@ -34,6 +34,9 @@
 import org.opensearch.gateway.remote.ClusterMetadataManifest.UploadedMetadataAttribute;
 import org.opensearch.test.EqualsHashCodeTestUtils;
 import org.opensearch.test.OpenSearchTestCase;
+import org.opensearch.threadpool.TestThreadPool;
+import org.opensearch.threadpool.ThreadPool;
+import org.junit.After;
 
 import java.io.IOException;
 import java.util.ArrayList;
@@ -64,6 +67,14 @@
 
 public class ClusterMetadataManifestTests extends OpenSearchTestCase {
 
+    private final ThreadPool threadPool = new TestThreadPool(getClass().getName());
+
+    @After
+    public void teardown() throws Exception {
+        super.tearDown();
+        threadPool.shutdown();
+    }
+
     public void testClusterMetadataManifestXContentV0() throws IOException {
         UploadedIndexMetadata uploadedIndexMetadata = new UploadedIndexMetadata("test-index", "test-uuid", "/test/upload/path", CODEC_V0);
         ClusterMetadataManifest originalManifest = ClusterMetadataManifest.builder()
@@ -214,7 +225,7 @@ public void testClusterMetadataManifestSerializationEqualsHashCode() {
                     "indicesRoutingDiffPath"
                 )
             )
-            .checksum(new ClusterStateChecksum(createClusterState()))
+            .checksum(new ClusterStateChecksum(createClusterState(), threadPool))
             .build();
         {  // Mutate Cluster Term
             EqualsHashCodeTestUtils.checkEqualsAndHashCode(
@@ -647,7 +658,7 @@ public void testClusterMetadataManifestXContentV4() throws IOException {
         UploadedIndexMetadata uploadedIndexMetadata = new UploadedIndexMetadata("test-index", "test-uuid", "/test/upload/path");
         UploadedMetadataAttribute uploadedMetadataAttribute = new UploadedMetadataAttribute("attribute_name", "testing_attribute");
         final StringKeyDiffProvider<IndexRoutingTable> routingTableIncrementalDiff = Mockito.mock(StringKeyDiffProvider.class);
-        ClusterStateChecksum checksum = new ClusterStateChecksum(createClusterState());
+        ClusterStateChecksum checksum = new ClusterStateChecksum(createClusterState(), threadPool);
         ClusterMetadataManifest originalManifest = ClusterMetadataManifest.builder()
             .clusterTerm(1L)
             .stateVersion(1L)
diff --git a/server/src/test/java/org/opensearch/gateway/remote/ClusterStateChecksumTests.java b/server/src/test/java/org/opensearch/gateway/remote/ClusterStateChecksumTests.java
index 0203e56dd2d5c..9b98187053a39 100644
--- a/server/src/test/java/org/opensearch/gateway/remote/ClusterStateChecksumTests.java
+++ b/server/src/test/java/org/opensearch/gateway/remote/ClusterStateChecksumTests.java
@@ -34,6 +34,9 @@
 import org.opensearch.core.xcontent.XContentBuilder;
 import org.opensearch.core.xcontent.XContentParser;
 import org.opensearch.test.OpenSearchTestCase;
+import org.opensearch.threadpool.TestThreadPool;
+import org.opensearch.threadpool.ThreadPool;
+import org.junit.After;
 
 import java.io.IOException;
 import java.util.EnumSet;
@@ -41,14 +44,21 @@
 import java.util.Map;
 
 public class ClusterStateChecksumTests extends OpenSearchTestCase {
+    private final ThreadPool threadPool = new TestThreadPool(getClass().getName());
+
+    @After
+    public void teardown() throws Exception {
+        super.tearDown();
+        threadPool.shutdown();
+    }
 
     public void testClusterStateChecksumEmptyClusterState() {
-        ClusterStateChecksum checksum = new ClusterStateChecksum(ClusterState.EMPTY_STATE);
+        ClusterStateChecksum checksum = new ClusterStateChecksum(ClusterState.EMPTY_STATE, threadPool);
         assertNotNull(checksum);
     }
 
     public void testClusterStateChecksum() {
-        ClusterStateChecksum checksum = new ClusterStateChecksum(generateClusterState());
+        ClusterStateChecksum checksum = new ClusterStateChecksum(generateClusterState(), threadPool);
         assertNotNull(checksum);
         assertTrue(checksum.routingTableChecksum != 0);
         assertTrue(checksum.nodesChecksum != 0);
@@ -65,8 +75,8 @@ public void testClusterStateChecksum() {
     }
 
     public void testClusterStateMatchChecksum() {
-        ClusterStateChecksum checksum = new ClusterStateChecksum(generateClusterState());
-        ClusterStateChecksum newChecksum = new ClusterStateChecksum(generateClusterState());
+        ClusterStateChecksum checksum = new ClusterStateChecksum(generateClusterState(), threadPool);
+        ClusterStateChecksum newChecksum = new ClusterStateChecksum(generateClusterState(), threadPool);
         assertNotNull(checksum);
         assertNotNull(newChecksum);
         assertEquals(checksum.routingTableChecksum, newChecksum.routingTableChecksum);
@@ -84,7 +94,7 @@ public void testClusterStateMatchChecksum() {
     }
 
     public void testXContentConversion() throws IOException {
-        ClusterStateChecksum checksum = new ClusterStateChecksum(generateClusterState());
+        ClusterStateChecksum checksum = new ClusterStateChecksum(generateClusterState(), threadPool);
         final XContentBuilder builder = JsonXContent.contentBuilder();
         builder.startObject();
         checksum.toXContent(builder, ToXContent.EMPTY_PARAMS);
@@ -97,7 +107,7 @@ public void testXContentConversion() throws IOException {
     }
 
     public void testSerialization() throws IOException {
-        ClusterStateChecksum checksum = new ClusterStateChecksum(generateClusterState());
+        ClusterStateChecksum checksum = new ClusterStateChecksum(generateClusterState(), threadPool);
         BytesStreamOutput output = new BytesStreamOutput();
         checksum.writeTo(output);
 
@@ -109,10 +119,10 @@ public void testSerialization() throws IOException {
 
     public void testGetMismatchEntities() {
         ClusterState clsState1 = generateClusterState();
-        ClusterStateChecksum checksum = new ClusterStateChecksum(clsState1);
+        ClusterStateChecksum checksum = new ClusterStateChecksum(clsState1, threadPool);
         assertTrue(checksum.getMismatchEntities(checksum).isEmpty());
 
-        ClusterStateChecksum checksum2 = new ClusterStateChecksum(clsState1);
+        ClusterStateChecksum checksum2 = new ClusterStateChecksum(clsState1, threadPool);
         assertTrue(checksum.getMismatchEntities(checksum2).isEmpty());
 
         ClusterState clsState2 = ClusterState.builder(ClusterName.DEFAULT)
@@ -122,7 +132,7 @@ public void testGetMismatchEntities() {
             .customs(Map.of())
             .metadata(Metadata.EMPTY_METADATA)
             .build();
-        ClusterStateChecksum checksum3 = new ClusterStateChecksum(clsState2);
+        ClusterStateChecksum checksum3 = new ClusterStateChecksum(clsState2, threadPool);
         List<String> mismatches = checksum.getMismatchEntities(checksum3);
         assertFalse(mismatches.isEmpty());
         assertEquals(11, mismatches.size());
@@ -151,8 +161,8 @@ public void testGetMismatchEntitiesUnorderedInput() {
         ClusterState state2 = ClusterState.builder(state1).nodes(nodes1).build();
         ClusterState state3 = ClusterState.builder(state1).nodes(nodes2).build();
 
-        ClusterStateChecksum checksum1 = new ClusterStateChecksum(state2);
-        ClusterStateChecksum checksum2 = new ClusterStateChecksum(state3);
+        ClusterStateChecksum checksum1 = new ClusterStateChecksum(state2, threadPool);
+        ClusterStateChecksum checksum2 = new ClusterStateChecksum(state3, threadPool);
         assertEquals(checksum2, checksum1);
     }
 
diff --git a/server/src/test/java/org/opensearch/gateway/remote/RemoteClusterStateServiceTests.java b/server/src/test/java/org/opensearch/gateway/remote/RemoteClusterStateServiceTests.java
index 56857285fa8d3..35a8ae16cacf7 100644
--- a/server/src/test/java/org/opensearch/gateway/remote/RemoteClusterStateServiceTests.java
+++ b/server/src/test/java/org/opensearch/gateway/remote/RemoteClusterStateServiceTests.java
@@ -3123,7 +3123,7 @@ public void testWriteFullMetadataSuccessWithChecksumValidationEnabled() throws I
             .previousClusterUUID("prev-cluster-uuid")
             .routingTableVersion(1L)
             .indicesRouting(List.of(uploadedIndiceRoutingMetadata))
-            .checksum(new ClusterStateChecksum(clusterState))
+            .checksum(new ClusterStateChecksum(clusterState, threadPool))
             .build();
 
         assertThat(manifest.getIndices().size(), is(1));
@@ -3193,7 +3193,7 @@ public void testWriteIncrementalMetadataSuccessWithChecksumValidationEnabled() t
 
         final ClusterMetadataManifest previousManifest = ClusterMetadataManifest.builder()
             .indices(Collections.emptyList())
-            .checksum(new ClusterStateChecksum(clusterState))
+            .checksum(new ClusterStateChecksum(clusterState, threadPool))
             .build();
         when((blobStoreRepository.basePath())).thenReturn(BlobPath.cleanPath().add("base-path"));
 
@@ -3219,7 +3219,7 @@ public void testWriteIncrementalMetadataSuccessWithChecksumValidationEnabled() t
             .previousClusterUUID("prev-cluster-uuid")
             .routingTableVersion(1)
             .indicesRouting(List.of(uploadedIndiceRoutingMetadata))
-            .checksum(new ClusterStateChecksum(clusterState))
+            .checksum(new ClusterStateChecksum(clusterState, threadPool))
             .build();
 
         assertThat(manifest.getIndices().size(), is(1));
@@ -3245,7 +3245,7 @@ public void testWriteIncrementalMetadataSuccessWithChecksumValidationModeNone()
 
         final ClusterMetadataManifest previousManifest = ClusterMetadataManifest.builder()
             .indices(Collections.emptyList())
-            .checksum(new ClusterStateChecksum(clusterState))
+            .checksum(new ClusterStateChecksum(clusterState, threadPool))
             .build();
         when((blobStoreRepository.basePath())).thenReturn(BlobPath.cleanPath().add("base-path"));
 
@@ -3271,7 +3271,7 @@ public void testWriteIncrementalMetadataSuccessWithChecksumValidationModeNone()
             .previousClusterUUID("prev-cluster-uuid")
             .routingTableVersion(1)
             .indicesRouting(List.of(uploadedIndiceRoutingMetadata))
-            .checksum(new ClusterStateChecksum(clusterState))
+            .checksum(new ClusterStateChecksum(clusterState, threadPool))
             .build();
 
         assertThat(manifest.getIndices().size(), is(1));
@@ -3349,7 +3349,7 @@ public void testGetClusterStateForManifestWithChecksumValidationEnabled() throws
         initializeWithChecksumEnabled(RemoteClusterStateService.RemoteClusterStateValidationMode.FAILURE);
         ClusterState clusterState = generateClusterStateWithAllAttributes().build();
         ClusterMetadataManifest manifest = generateClusterMetadataManifestWithAllAttributes().checksum(
-            new ClusterStateChecksum(clusterState)
+            new ClusterStateChecksum(clusterState, threadPool)
         ).build();
         remoteClusterStateService.start();
         RemoteClusterStateService mockService = spy(remoteClusterStateService);
@@ -3382,7 +3382,7 @@ public void testGetClusterStateForManifestWithChecksumValidationModeNone() throw
         initializeWithChecksumEnabled(RemoteClusterStateService.RemoteClusterStateValidationMode.NONE);
         ClusterState clusterState = generateClusterStateWithAllAttributes().build();
         ClusterMetadataManifest manifest = generateClusterMetadataManifestWithAllAttributes().checksum(
-            new ClusterStateChecksum(clusterState)
+            new ClusterStateChecksum(clusterState, threadPool)
         ).build();
         remoteClusterStateService.start();
         RemoteClusterStateService mockService = spy(remoteClusterStateService);
@@ -3415,7 +3415,7 @@ public void testGetClusterStateForManifestWithChecksumValidationEnabledWithMisma
         initializeWithChecksumEnabled(RemoteClusterStateService.RemoteClusterStateValidationMode.FAILURE);
         ClusterState clusterState = generateClusterStateWithAllAttributes().build();
         ClusterMetadataManifest manifest = generateClusterMetadataManifestWithAllAttributes().checksum(
-            new ClusterStateChecksum(clusterState)
+            new ClusterStateChecksum(clusterState, threadPool)
         ).build();
         remoteClusterStateService.start();
         RemoteClusterStateService mockService = spy(remoteClusterStateService);
@@ -3465,7 +3465,7 @@ public void testGetClusterStateForManifestWithChecksumValidationDebugWithMismatc
         );
         ClusterState clusterState = generateClusterStateWithAllAttributes().build();
         ClusterMetadataManifest manifest = generateClusterMetadataManifestWithAllAttributes().checksum(
-            new ClusterStateChecksum(clusterState)
+            new ClusterStateChecksum(clusterState, threadPool)
         ).build();
         remoteClusterStateService.start();
         RemoteClusterStateService mockService = spy(remoteClusterStateService);
@@ -3505,7 +3505,7 @@ public void testGetClusterStateUsingDiffWithChecksum() throws IOException {
         initializeWithChecksumEnabled(RemoteClusterStateService.RemoteClusterStateValidationMode.FAILURE);
         ClusterState clusterState = generateClusterStateWithAllAttributes().build();
         ClusterMetadataManifest manifest = generateClusterMetadataManifestWithAllAttributes().checksum(
-            new ClusterStateChecksum(clusterState)
+            new ClusterStateChecksum(clusterState, threadPool)
         ).diffManifest(ClusterStateDiffManifest.builder().build()).build();
 
         remoteClusterStateService.start();
@@ -3547,7 +3547,7 @@ public void testGetClusterStateUsingDiffWithChecksumModeNone() throws IOExceptio
         initializeWithChecksumEnabled(RemoteClusterStateService.RemoteClusterStateValidationMode.NONE);
         ClusterState clusterState = generateClusterStateWithAllAttributes().build();
         ClusterMetadataManifest manifest = generateClusterMetadataManifestWithAllAttributes().checksum(
-            new ClusterStateChecksum(clusterState)
+            new ClusterStateChecksum(clusterState, threadPool)
         ).diffManifest(ClusterStateDiffManifest.builder().build()).build();
 
         remoteClusterStateService.start();
@@ -3589,7 +3589,7 @@ public void testGetClusterStateUsingDiffWithChecksumModeDebugMismatch() throws I
         initializeWithChecksumEnabled(RemoteClusterStateService.RemoteClusterStateValidationMode.DEBUG);
         ClusterState clusterState = generateClusterStateWithAllAttributes().build();
         ClusterMetadataManifest manifest = generateClusterMetadataManifestWithAllAttributes().checksum(
-            new ClusterStateChecksum(clusterState)
+            new ClusterStateChecksum(clusterState, threadPool)
         ).diffManifest(ClusterStateDiffManifest.builder().build()).build();
 
         remoteClusterStateService.start();
@@ -3630,7 +3630,7 @@ public void testGetClusterStateUsingDiffWithChecksumModeTraceMismatch() throws I
         initializeWithChecksumEnabled(RemoteClusterStateService.RemoteClusterStateValidationMode.TRACE);
         ClusterState clusterState = generateClusterStateWithAllAttributes().build();
         ClusterMetadataManifest manifest = generateClusterMetadataManifestWithAllAttributes().checksum(
-            new ClusterStateChecksum(clusterState)
+            new ClusterStateChecksum(clusterState, threadPool)
         ).diffManifest(ClusterStateDiffManifest.builder().build()).build();
 
         remoteClusterStateService.start();
@@ -3692,7 +3692,7 @@ public void testGetClusterStateUsingDiffWithChecksumMismatch() throws IOExceptio
         initializeWithChecksumEnabled(RemoteClusterStateService.RemoteClusterStateValidationMode.FAILURE);
         ClusterState clusterState = generateClusterStateWithAllAttributes().build();
         ClusterMetadataManifest manifest = generateClusterMetadataManifestWithAllAttributes().checksum(
-            new ClusterStateChecksum(clusterState)
+            new ClusterStateChecksum(clusterState, threadPool)
         ).diffManifest(ClusterStateDiffManifest.builder().build()).build();
 
         remoteClusterStateService.start();