diff --git a/CHANGELOG.md b/CHANGELOG.md
index 493e4b9a6c6db..9e3a69d2f1c8c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -23,6 +23,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 - Add stats for remote publication failure and move download failure stats to remote methods([#16682](https://github.com/opensearch-project/OpenSearch/pull/16682/))
 - Added a precaution to handle extreme date values during sorting to prevent `arithmetic_exception: long overflow` ([#16812](https://github.com/opensearch-project/OpenSearch/pull/16812)).
 - Add `verbose_pipeline` parameter to output each processor's execution details ([#14745](https://github.com/opensearch-project/OpenSearch/pull/14745)).
+- Add search replica stats to segment replication stats API ([#16678](https://github.com/opensearch-project/OpenSearch/pull/16678))
+
 
 ### Dependencies
 - Bump `com.google.cloud:google-cloud-core-http` from 2.23.0 to 2.47.0 ([#16504](https://github.com/opensearch-project/OpenSearch/pull/16504))
@@ -67,6 +69,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 - Bound the size of cache in deprecation logger ([16702](https://github.com/opensearch-project/OpenSearch/issues/16702))
 - Ensure consistency of system flag on IndexMetadata after diff is applied ([#16644](https://github.com/opensearch-project/OpenSearch/pull/16644))
 - Skip remote-repositories validations for node-joins when RepositoriesService is not in sync with cluster-state ([#16763](https://github.com/opensearch-project/OpenSearch/pull/16763))
+- Fix _list/shards API failing when closed indices are present ([#16606](https://github.com/opensearch-project/OpenSearch/pull/16606))
 
 ### Security
 
diff --git a/server/src/internalClusterTest/java/org/opensearch/action/admin/cluster/shards/TransportCatShardsActionIT.java b/server/src/internalClusterTest/java/org/opensearch/action/admin/cluster/shards/TransportCatShardsActionIT.java
index 32d5b3db85629..a7cb4847b45e5 100644
--- a/server/src/internalClusterTest/java/org/opensearch/action/admin/cluster/shards/TransportCatShardsActionIT.java
+++ b/server/src/internalClusterTest/java/org/opensearch/action/admin/cluster/shards/TransportCatShardsActionIT.java
@@ -8,9 +8,15 @@
 
 package org.opensearch.action.admin.cluster.shards;
 
+import org.opensearch.action.admin.indices.alias.IndicesAliasesRequest;
+import org.opensearch.action.admin.indices.datastream.DataStreamTestCase;
 import org.opensearch.action.admin.indices.stats.IndicesStatsResponse;
+import org.opensearch.action.admin.indices.stats.ShardStats;
+import org.opensearch.action.pagination.PageParams;
+import org.opensearch.client.Requests;
 import org.opensearch.cluster.metadata.IndexMetadata;
 import org.opensearch.cluster.routing.ShardRouting;
+import org.opensearch.common.action.ActionFuture;
 import org.opensearch.common.settings.Settings;
 import org.opensearch.common.unit.TimeValue;
 import org.opensearch.core.action.ActionListener;
@@ -20,15 +26,19 @@
 import org.opensearch.test.OpenSearchIntegTestCase;
 
 import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.List;
 import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutionException;
 
 import static org.opensearch.cluster.routing.UnassignedInfo.INDEX_DELAYED_NODE_LEFT_TIMEOUT_SETTING;
 import static org.opensearch.common.unit.TimeValue.timeValueMillis;
 import static org.opensearch.search.SearchService.NO_TIMEOUT;
+import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertAcked;
 
 @OpenSearchIntegTestCase.ClusterScope(numDataNodes = 0, scope = OpenSearchIntegTestCase.Scope.TEST)
-public class TransportCatShardsActionIT extends OpenSearchIntegTestCase {
+public class TransportCatShardsActionIT extends DataStreamTestCase {
 
     public void testCatShardsWithSuccessResponse() throws InterruptedException {
         internalCluster().startClusterManagerOnlyNodes(1);
@@ -125,4 +135,334 @@ public void onFailure(Exception e) {
         latch.await();
     }
 
+    public void testListShardsWithHiddenIndex() throws Exception {
+        final int numShards = 1;
+        final int numReplicas = 1;
+        internalCluster().startClusterManagerOnlyNodes(1);
+        internalCluster().startDataOnlyNodes(2);
+        createIndex(
+            "test-hidden-idx",
+            Settings.builder()
+                .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, numShards)
+                .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, numReplicas)
+                .put(IndexMetadata.SETTING_INDEX_HIDDEN, true)
+                .build()
+        );
+        ensureGreen();
+
+        // Verify result for a default query: "_list/shards"
+        CatShardsRequest listShardsRequest = getListShardsTransportRequest(Strings.EMPTY_ARRAY, 100);
+        ActionFuture<CatShardsResponse> listShardsResponse = client().execute(CatShardsAction.INSTANCE, listShardsRequest);
+        assertSingleIndexResponseShards(listShardsResponse.get(), "test-hidden-idx", 2, true);
+
+        // Verify result when hidden index is explicitly queried: "_list/shards"
+        listShardsRequest = getListShardsTransportRequest(new String[] { "test-hidden-idx" }, 100);
+        listShardsResponse = client().execute(CatShardsAction.INSTANCE, listShardsRequest);
+        assertSingleIndexResponseShards(listShardsResponse.get(), "test-hidden-idx", 2, true);
+
+        // Verify result when hidden index is queried with wildcard: "_list/shards*"
+        // Since the ClusterStateAction underneath is invoked with lenientExpandOpen IndicesOptions,
+        // Wildcards for hidden indices should not get resolved.
+        listShardsRequest = getListShardsTransportRequest(new String[] { "test-hidden-idx*" }, 100);
+        listShardsResponse = client().execute(CatShardsAction.INSTANCE, listShardsRequest);
+        assertEquals(0, listShardsResponse.get().getResponseShards().size());
+        assertSingleIndexResponseShards(listShardsResponse.get(), "test-hidden-idx", 0, false);
+    }
+
+    public void testListShardsWithClosedIndex() throws Exception {
+        final int numShards = 1;
+        final int numReplicas = 1;
+        internalCluster().startClusterManagerOnlyNodes(1);
+        internalCluster().startDataOnlyNodes(2);
+        createIndex(
+            "test-closed-idx",
+            Settings.builder()
+                .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, numShards)
+                .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, numReplicas)
+                .build()
+        );
+        ensureGreen();
+
+        // close index "test-closed-idx"
+        client().admin().indices().close(Requests.closeIndexRequest("test-closed-idx")).get();
+        ensureGreen();
+
+        // Verify result for a default query: "_list/shards"
+        CatShardsRequest listShardsRequest = getListShardsTransportRequest(Strings.EMPTY_ARRAY, 100);
+        ActionFuture<CatShardsResponse> listShardsResponse = client().execute(CatShardsAction.INSTANCE, listShardsRequest);
+        assertSingleIndexResponseShards(listShardsResponse.get(), "test-closed-idx", 2, false);
+
+        // Verify result when closed index is explicitly queried: "_list/shards"
+        listShardsRequest = getListShardsTransportRequest(new String[] { "test-closed-idx" }, 100);
+        listShardsResponse = client().execute(CatShardsAction.INSTANCE, listShardsRequest);
+        assertSingleIndexResponseShards(listShardsResponse.get(), "test-closed-idx", 2, false);
+
+        // Verify result when closed index is queried with wildcard: "_list/shards*"
+        // Since the ClusterStateAction underneath is invoked with lenientExpandOpen IndicesOptions,
+        // Wildcards for closed indices should not get resolved.
+        listShardsRequest = getListShardsTransportRequest(new String[] { "test-closed-idx*" }, 100);
+        listShardsResponse = client().execute(CatShardsAction.INSTANCE, listShardsRequest);
+        assertSingleIndexResponseShards(listShardsResponse.get(), "test-closed-idx", 0, false);
+    }
+
+    public void testListShardsWithClosedAndHiddenIndices() throws InterruptedException, ExecutionException {
+        final int numIndices = 4;
+        final int numShards = 1;
+        final int numReplicas = 2;
+        final int pageSize = 100;
+        internalCluster().startClusterManagerOnlyNodes(1);
+        internalCluster().startDataOnlyNodes(3);
+        createIndex(
+            "test",
+            Settings.builder()
+                .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, numShards)
+                .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, numReplicas)
+                .build()
+        );
+        createIndex(
+            "test-2",
+            Settings.builder()
+                .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, numShards)
+                .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, numReplicas)
+                .build()
+        );
+        createIndex(
+            "test-closed-idx",
+            Settings.builder()
+                .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, numShards)
+                .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, numReplicas)
+                .build()
+        );
+        createIndex(
+            "test-hidden-idx",
+            Settings.builder()
+                .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, numShards)
+                .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, numReplicas)
+                .put(IndexMetadata.SETTING_INDEX_HIDDEN, true)
+                .build()
+        );
+        // close index "test-closed-idx"
+        client().admin().indices().close(Requests.closeIndexRequest("test-closed-idx")).get();
+        ensureGreen();
+
+        // Verifying response for default queries: /_list/shards
+        // all the shards should be part of response, however stats should not be displayed for closed index
+        CatShardsRequest listShardsRequest = getListShardsTransportRequest(Strings.EMPTY_ARRAY, pageSize);
+        ActionFuture<CatShardsResponse> listShardsResponse = client().execute(CatShardsAction.INSTANCE, listShardsRequest);
+        assertTrue(listShardsResponse.get().getResponseShards().stream().anyMatch(shard -> shard.getIndexName().equals("test-closed-idx")));
+        assertTrue(listShardsResponse.get().getResponseShards().stream().anyMatch(shard -> shard.getIndexName().equals("test-hidden-idx")));
+        assertEquals(numIndices * numShards * (numReplicas + 1), listShardsResponse.get().getResponseShards().size());
+        assertFalse(
+            Arrays.stream(listShardsResponse.get().getIndicesStatsResponse().getShards())
+                .anyMatch(shardStats -> shardStats.getShardRouting().getIndexName().equals("test-closed-idx"))
+        );
+        assertEquals(
+            (numIndices - 1) * numShards * (numReplicas + 1),
+            listShardsResponse.get().getIndicesStatsResponse().getShards().length
+        );
+
+        // Verifying responses when hidden indices are explicitly queried: /_list/shards/test-hidden-idx
+        // Shards for hidden index should appear in response along with stats
+        listShardsRequest.setIndices(List.of("test-hidden-idx").toArray(new String[0]));
+        listShardsResponse = client().execute(CatShardsAction.INSTANCE, listShardsRequest);
+        assertTrue(listShardsResponse.get().getResponseShards().stream().allMatch(shard -> shard.getIndexName().equals("test-hidden-idx")));
+        assertTrue(
+            Arrays.stream(listShardsResponse.get().getIndicesStatsResponse().getShards())
+                .allMatch(shardStats -> shardStats.getShardRouting().getIndexName().equals("test-hidden-idx"))
+        );
+        assertEquals(
+            listShardsResponse.get().getResponseShards().size(),
+            listShardsResponse.get().getIndicesStatsResponse().getShards().length
+        );
+
+        // Verifying responses when hidden indices are queried with wildcards: /_list/shards/test-hidden-idx*
+        // Shards for hidden index should not appear in response with stats.
+        listShardsRequest.setIndices(List.of("test-hidden-idx*").toArray(new String[0]));
+        listShardsResponse = client().execute(CatShardsAction.INSTANCE, listShardsRequest);
+        assertEquals(0, listShardsResponse.get().getResponseShards().size());
+        assertEquals(0, listShardsResponse.get().getIndicesStatsResponse().getShards().length);
+
+        // Explicitly querying for closed index: /_list/shards/test-closed-idx
+        // should output closed shards without stats.
+        listShardsRequest.setIndices(List.of("test-closed-idx").toArray(new String[0]));
+        listShardsResponse = client().execute(CatShardsAction.INSTANCE, listShardsRequest);
+        assertTrue(listShardsResponse.get().getResponseShards().stream().anyMatch(shard -> shard.getIndexName().equals("test-closed-idx")));
+        assertEquals(0, listShardsResponse.get().getIndicesStatsResponse().getShards().length);
+
+        // Querying for closed index with wildcards: /_list/shards/test-closed-idx*
+        // should not output any closed shards.
+        listShardsRequest.setIndices(List.of("test-closed-idx*").toArray(new String[0]));
+        listShardsResponse = client().execute(CatShardsAction.INSTANCE, listShardsRequest);
+        assertEquals(0, listShardsResponse.get().getResponseShards().size());
+        assertEquals(0, listShardsResponse.get().getIndicesStatsResponse().getShards().length);
+    }
+
+    public void testListShardsWithClosedIndicesAcrossPages() throws InterruptedException, ExecutionException {
+        final int numIndices = 4;
+        final int numShards = 1;
+        final int numReplicas = 2;
+        final int pageSize = numShards * (numReplicas + 1);
+        internalCluster().startClusterManagerOnlyNodes(1);
+        internalCluster().startDataOnlyNodes(3);
+        createIndex(
+            "test-open-idx-1",
+            Settings.builder()
+                .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, numShards)
+                .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, numReplicas)
+                .build()
+        );
+        createIndex(
+            "test-closed-idx-1",
+            Settings.builder()
+                .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, numShards)
+                .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, numReplicas)
+                .build()
+        );
+        createIndex(
+            "test-open-idx-2",
+            Settings.builder()
+                .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, numShards)
+                .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, numReplicas)
+                .build()
+        );
+        createIndex(
+            "test-closed-idx-2",
+            Settings.builder()
+                .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, numShards)
+                .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, numReplicas)
+                .put(IndexMetadata.SETTING_INDEX_HIDDEN, true)
+                .build()
+        );
+        // close index "test-closed-idx-1"
+        client().admin().indices().close(Requests.closeIndexRequest("test-closed-idx-1")).get();
+        ensureGreen();
+        // close index "test-closed-idx-2"
+        client().admin().indices().close(Requests.closeIndexRequest("test-closed-idx-2")).get();
+        ensureGreen();
+
+        // Verifying response for default queries: /_list/shards
+        List<ShardRouting> responseShardRouting = new ArrayList<>();
+        List<ShardStats> responseShardStats = new ArrayList<>();
+        String nextToken = null;
+        CatShardsRequest listShardsRequest;
+        ActionFuture<CatShardsResponse> listShardsResponse;
+        do {
+            listShardsRequest = getListShardsTransportRequest(Strings.EMPTY_ARRAY, nextToken, pageSize);
+            listShardsResponse = client().execute(CatShardsAction.INSTANCE, listShardsRequest);
+            nextToken = listShardsResponse.get().getPageToken().getNextToken();
+            responseShardRouting.addAll(listShardsResponse.get().getResponseShards());
+            responseShardStats.addAll(List.of(listShardsResponse.get().getIndicesStatsResponse().getShards()));
+        } while (nextToken != null);
+
+        assertTrue(responseShardRouting.stream().anyMatch(shard -> shard.getIndexName().equals("test-closed-idx-1")));
+        assertTrue(responseShardRouting.stream().anyMatch(shard -> shard.getIndexName().equals("test-closed-idx-2")));
+        assertEquals(numIndices * numShards * (numReplicas + 1), responseShardRouting.size());
+        // ShardsStats should only appear for 2 open indices
+        assertFalse(
+            responseShardStats.stream().anyMatch(shardStats -> shardStats.getShardRouting().getIndexName().contains("test-closed-idx"))
+        );
+        assertEquals(2 * numShards * (numReplicas + 1), responseShardStats.size());
+    }
+
+    public void testListShardsWithDataStream() throws Exception {
+        final int numDataNodes = 3;
+        String dataStreamName = "logs-test";
+        internalCluster().startClusterManagerOnlyNodes(1);
+        internalCluster().startDataOnlyNodes(numDataNodes);
+        // Create an index template for data streams.
+        createDataStreamIndexTemplate("data-stream-template", List.of("logs-*"));
+        // Create data streams matching the "logs-*" index pattern.
+        createDataStream(dataStreamName);
+        ensureGreen();
+        // Verifying default query's result. Data stream should have created a hidden backing index in the
+        // background and all the corresponding shards should appear in the response along with stats.
+        CatShardsRequest listShardsRequest = getListShardsTransportRequest(Strings.EMPTY_ARRAY, numDataNodes * numDataNodes);
+        ActionFuture<CatShardsResponse> listShardsResponse = client().execute(CatShardsAction.INSTANCE, listShardsRequest);
+        assertSingleIndexResponseShards(listShardsResponse.get(), dataStreamName, numDataNodes + 1, true);
+        // Verifying result when data stream is directly queried. Again, all the shards with stats should appear
+        listShardsRequest = getListShardsTransportRequest(new String[] { dataStreamName }, numDataNodes * numDataNodes);
+        listShardsResponse = client().execute(CatShardsAction.INSTANCE, listShardsRequest);
+        assertSingleIndexResponseShards(listShardsResponse.get(), dataStreamName, numDataNodes + 1, true);
+    }
+
+    public void testListShardsWithAliases() throws Exception {
+        final int numShards = 1;
+        final int numReplicas = 1;
+        final String aliasName = "test-alias";
+        internalCluster().startClusterManagerOnlyNodes(1);
+        internalCluster().startDataOnlyNodes(3);
+        createIndex(
+            "test-closed-idx",
+            Settings.builder()
+                .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, numShards)
+                .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, numReplicas)
+                .build()
+        );
+        createIndex(
+            "test-hidden-idx",
+            Settings.builder()
+                .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, numShards)
+                .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, numReplicas)
+                .put(IndexMetadata.SETTING_INDEX_HIDDEN, true)
+                .build()
+        );
+        ensureGreen();
+
+        // Point test alias to both the indices (one being hidden while the other is closed)
+        final IndicesAliasesRequest request = new IndicesAliasesRequest().origin("allowed");
+        request.addAliasAction(IndicesAliasesRequest.AliasActions.add().index("test-closed-idx").alias(aliasName));
+        assertAcked(client().admin().indices().aliases(request).actionGet());
+
+        request.addAliasAction(IndicesAliasesRequest.AliasActions.add().index("test-hidden-idx").alias(aliasName));
+        assertAcked(client().admin().indices().aliases(request).actionGet());
+
+        // close index "test-closed-idx"
+        client().admin().indices().close(Requests.closeIndexRequest("test-closed-idx")).get();
+        ensureGreen();
+
+        // Verifying result when an alias is explicitly queried.
+        CatShardsRequest listShardsRequest = getListShardsTransportRequest(new String[] { aliasName }, 100);
+        ActionFuture<CatShardsResponse> listShardsResponse = client().execute(CatShardsAction.INSTANCE, listShardsRequest);
+        assertTrue(
+            listShardsResponse.get()
+                .getResponseShards()
+                .stream()
+                .allMatch(shard -> shard.getIndexName().equals("test-hidden-idx") || shard.getIndexName().equals("test-closed-idx"))
+        );
+        assertTrue(
+            Arrays.stream(listShardsResponse.get().getIndicesStatsResponse().getShards())
+                .allMatch(shardStats -> shardStats.getShardRouting().getIndexName().equals("test-hidden-idx"))
+        );
+        assertEquals(4, listShardsResponse.get().getResponseShards().size());
+        assertEquals(2, listShardsResponse.get().getIndicesStatsResponse().getShards().length);
+    }
+
+    private void assertSingleIndexResponseShards(
+        CatShardsResponse catShardsResponse,
+        String indexNamePattern,
+        final int totalNumShards,
+        boolean shardStatsExist
+    ) {
+        assertTrue(catShardsResponse.getResponseShards().stream().allMatch(shard -> shard.getIndexName().contains(indexNamePattern)));
+        assertEquals(totalNumShards, catShardsResponse.getResponseShards().size());
+        if (shardStatsExist) {
+            assertTrue(
+                Arrays.stream(catShardsResponse.getIndicesStatsResponse().getShards())
+                    .allMatch(shardStats -> shardStats.getShardRouting().getIndexName().contains(indexNamePattern))
+            );
+        }
+        assertEquals(shardStatsExist ? totalNumShards : 0, catShardsResponse.getIndicesStatsResponse().getShards().length);
+    }
+
+    private CatShardsRequest getListShardsTransportRequest(String[] indices, final int pageSize) {
+        return getListShardsTransportRequest(indices, null, pageSize);
+    }
+
+    private CatShardsRequest getListShardsTransportRequest(String[] indices, String nextToken, final int pageSize) {
+        CatShardsRequest listShardsRequest = new CatShardsRequest();
+        listShardsRequest.setCancelAfterTimeInterval(NO_TIMEOUT);
+        listShardsRequest.setIndices(indices);
+        listShardsRequest.setPageParams(new PageParams(nextToken, PageParams.PARAM_ASC_SORT_VALUE, pageSize));
+        return listShardsRequest;
+    }
 }
diff --git a/server/src/internalClusterTest/java/org/opensearch/indices/replication/SearchReplicaReplicationIT.java b/server/src/internalClusterTest/java/org/opensearch/indices/replication/SearchReplicaReplicationIT.java
index a1b512c326ac5..f660695af9965 100644
--- a/server/src/internalClusterTest/java/org/opensearch/indices/replication/SearchReplicaReplicationIT.java
+++ b/server/src/internalClusterTest/java/org/opensearch/indices/replication/SearchReplicaReplicationIT.java
@@ -8,14 +8,20 @@
 
 package org.opensearch.indices.replication;
 
+import org.opensearch.action.admin.indices.replication.SegmentReplicationStatsResponse;
 import org.opensearch.cluster.metadata.IndexMetadata;
 import org.opensearch.common.settings.Settings;
 import org.opensearch.common.util.FeatureFlags;
+import org.opensearch.index.SegmentReplicationPerGroupStats;
+import org.opensearch.index.SegmentReplicationShardStats;
+import org.opensearch.indices.replication.common.ReplicationType;
 import org.opensearch.test.OpenSearchIntegTestCase;
 import org.junit.After;
 import org.junit.Before;
 
 import java.nio.file.Path;
+import java.util.List;
+import java.util.Set;
 
 @OpenSearchIntegTestCase.ClusterScope(scope = OpenSearchIntegTestCase.Scope.TEST, numDataNodes = 0)
 public class SearchReplicaReplicationIT extends SegmentReplicationBaseIT {
@@ -82,4 +88,47 @@ public void testReplication() throws Exception {
         waitForSearchableDocs(docCount, primary, replica);
     }
 
+    public void testSegmentReplicationStatsResponseWithSearchReplica() throws Exception {
+        internalCluster().startClusterManagerOnlyNode();
+        final List<String> nodes = internalCluster().startDataOnlyNodes(2);
+        createIndex(
+            INDEX_NAME,
+            Settings.builder()
+                .put("number_of_shards", 1)
+                .put("number_of_replicas", 0)
+                .put("number_of_search_only_replicas", 1)
+                .put(IndexMetadata.SETTING_REPLICATION_TYPE, ReplicationType.SEGMENT)
+                .build()
+        );
+        ensureGreen(INDEX_NAME);
+
+        final int docCount = 5;
+        for (int i = 0; i < docCount; i++) {
+            client().prepareIndex(INDEX_NAME).setId(Integer.toString(i)).setSource("field", "value" + i).execute().get();
+        }
+        refresh(INDEX_NAME);
+        waitForSearchableDocs(docCount, nodes);
+
+        SegmentReplicationStatsResponse segmentReplicationStatsResponse = dataNodeClient().admin()
+            .indices()
+            .prepareSegmentReplicationStats(INDEX_NAME)
+            .setDetailed(true)
+            .execute()
+            .actionGet();
+
+        // Verify the number of indices
+        assertEquals(1, segmentReplicationStatsResponse.getReplicationStats().size());
+        // Verify total shards
+        assertEquals(2, segmentReplicationStatsResponse.getTotalShards());
+        // Verify the number of primary shards
+        assertEquals(1, segmentReplicationStatsResponse.getReplicationStats().get(INDEX_NAME).size());
+
+        SegmentReplicationPerGroupStats perGroupStats = segmentReplicationStatsResponse.getReplicationStats().get(INDEX_NAME).get(0);
+        Set<SegmentReplicationShardStats> replicaStats = perGroupStats.getReplicaStats();
+        // Verify the number of replica stats
+        assertEquals(1, replicaStats.size());
+        for (SegmentReplicationShardStats replicaStat : replicaStats) {
+            assertNotNull(replicaStat.getCurrentReplicationState());
+        }
+    }
 }
diff --git a/server/src/main/java/org/opensearch/action/admin/cluster/shards/TransportCatShardsAction.java b/server/src/main/java/org/opensearch/action/admin/cluster/shards/TransportCatShardsAction.java
index 7b36b7a10f4f2..01efa96a7369e 100644
--- a/server/src/main/java/org/opensearch/action/admin/cluster/shards/TransportCatShardsAction.java
+++ b/server/src/main/java/org/opensearch/action/admin/cluster/shards/TransportCatShardsAction.java
@@ -18,6 +18,8 @@
 import org.opensearch.action.support.HandledTransportAction;
 import org.opensearch.action.support.TimeoutTaskCancellationUtility;
 import org.opensearch.client.node.NodeClient;
+import org.opensearch.cluster.ClusterState;
+import org.opensearch.cluster.metadata.IndexMetadata;
 import org.opensearch.common.breaker.ResponseLimitBreachedException;
 import org.opensearch.common.breaker.ResponseLimitSettings;
 import org.opensearch.common.inject.Inject;
@@ -27,6 +29,7 @@
 import org.opensearch.tasks.Task;
 import org.opensearch.transport.TransportService;
 
+import java.util.List;
 import java.util.Objects;
 
 import static org.opensearch.common.breaker.ResponseLimitSettings.LimitEntity.SHARDS;
@@ -98,9 +101,6 @@ public void onResponse(ClusterStateResponse clusterStateResponse) {
                             shardsRequest.getPageParams(),
                             clusterStateResponse
                         );
-                        String[] indices = Objects.isNull(paginationStrategy)
-                            ? shardsRequest.getIndices()
-                            : paginationStrategy.getRequestedIndices().toArray(new String[0]);
                         catShardsResponse.setNodes(clusterStateResponse.getState().getNodes());
                         catShardsResponse.setResponseShards(
                             Objects.isNull(paginationStrategy)
@@ -108,8 +108,12 @@ public void onResponse(ClusterStateResponse clusterStateResponse) {
                                 : paginationStrategy.getRequestedEntities()
                         );
                         catShardsResponse.setPageToken(Objects.isNull(paginationStrategy) ? null : paginationStrategy.getResponseToken());
+
+                        String[] indices = Objects.isNull(paginationStrategy)
+                            ? shardsRequest.getIndices()
+                            : filterClosedIndices(clusterStateResponse.getState(), paginationStrategy.getRequestedIndices());
                         // For paginated queries, if strategy outputs no shards to be returned, avoid fetching IndicesStats.
-                        if (shouldSkipIndicesStatsRequest(paginationStrategy)) {
+                        if (shouldSkipIndicesStatsRequest(paginationStrategy, indices)) {
                             catShardsResponse.setIndicesStatsResponse(IndicesStatsResponse.getEmptyResponse());
                             cancellableListener.onResponse(catShardsResponse);
                             return;
@@ -166,7 +170,19 @@ private void validateRequestLimit(
         }
     }
 
-    private boolean shouldSkipIndicesStatsRequest(ShardPaginationStrategy paginationStrategy) {
-        return Objects.nonNull(paginationStrategy) && paginationStrategy.getRequestedEntities().isEmpty();
+    private boolean shouldSkipIndicesStatsRequest(ShardPaginationStrategy paginationStrategy, String[] indices) {
+        return Objects.nonNull(paginationStrategy) && (indices == null || indices.length == 0);
+    }
+
+    /**
+     * Will be used by paginated query (_list/shards) to filter out closed indices (only consider OPEN) before fetching
+     * IndicesStats. Since pagination strategy always passes concrete indices to TransportIndicesStatsAction,
+     * the default behaviour of StrictExpandOpenAndForbidClosed leads to errors if closed indices are encountered.
+     */
+    private String[] filterClosedIndices(ClusterState clusterState, List<String> strategyIndices) {
+        return strategyIndices.stream().filter(index -> {
+            IndexMetadata metadata = clusterState.metadata().indices().get(index);
+            return metadata != null && metadata.getState().equals(IndexMetadata.State.CLOSE) == false;
+        }).toArray(String[]::new);
     }
 }
diff --git a/server/src/main/java/org/opensearch/action/admin/indices/replication/TransportSegmentReplicationStatsAction.java b/server/src/main/java/org/opensearch/action/admin/indices/replication/TransportSegmentReplicationStatsAction.java
index fc97d67c6c3af..44408c5043fcf 100644
--- a/server/src/main/java/org/opensearch/action/admin/indices/replication/TransportSegmentReplicationStatsAction.java
+++ b/server/src/main/java/org/opensearch/action/admin/indices/replication/TransportSegmentReplicationStatsAction.java
@@ -21,7 +21,6 @@
 import org.opensearch.core.action.support.DefaultShardOperationFailedException;
 import org.opensearch.core.common.io.stream.StreamInput;
 import org.opensearch.core.index.shard.ShardId;
-import org.opensearch.index.IndexService;
 import org.opensearch.index.SegmentReplicationPerGroupStats;
 import org.opensearch.index.SegmentReplicationPressureService;
 import org.opensearch.index.SegmentReplicationShardStats;
@@ -38,7 +37,9 @@
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 import java.util.stream.Collectors;
+import java.util.stream.Stream;
 
 /**
  * Transport action for shard segment replication operation. This transport action does not actually
@@ -96,11 +97,11 @@ protected SegmentReplicationStatsResponse newResponse(
     ) {
         String[] shards = request.shards();
         final List<Integer> shardsToFetch = Arrays.stream(shards).map(Integer::valueOf).collect(Collectors.toList());
-
         // organize replica responses by allocationId.
         final Map<String, SegmentReplicationState> replicaStats = new HashMap<>();
         // map of index name to list of replication group stats.
         final Map<String, List<SegmentReplicationPerGroupStats>> primaryStats = new HashMap<>();
+
         for (SegmentReplicationShardStatsResponse response : responses) {
             if (response != null) {
                 if (response.getReplicaStats() != null) {
@@ -109,6 +110,7 @@ protected SegmentReplicationStatsResponse newResponse(
                         replicaStats.putIfAbsent(shardRouting.allocationId().getId(), response.getReplicaStats());
                     }
                 }
+
                 if (response.getPrimaryStats() != null) {
                     final ShardId shardId = response.getPrimaryStats().getShardId();
                     if (shardsToFetch.isEmpty() || shardsToFetch.contains(shardId.getId())) {
@@ -126,15 +128,20 @@ protected SegmentReplicationStatsResponse newResponse(
                 }
             }
         }
-        // combine the replica stats to the shard stat entry in each group.
-        for (Map.Entry<String, List<SegmentReplicationPerGroupStats>> entry : primaryStats.entrySet()) {
-            for (SegmentReplicationPerGroupStats group : entry.getValue()) {
-                for (SegmentReplicationShardStats replicaStat : group.getReplicaStats()) {
-                    replicaStat.setCurrentReplicationState(replicaStats.getOrDefault(replicaStat.getAllocationId(), null));
-                }
-            }
-        }
-        return new SegmentReplicationStatsResponse(totalShards, successfulShards, failedShards, primaryStats, shardFailures);
+
+        Map<String, List<SegmentReplicationPerGroupStats>> replicationStats = primaryStats.entrySet()
+            .stream()
+            .collect(
+                Collectors.toMap(
+                    Map.Entry::getKey,
+                    entry -> entry.getValue()
+                        .stream()
+                        .map(groupStats -> updateGroupStats(groupStats, replicaStats))
+                        .collect(Collectors.toList())
+                )
+            );
+
+        return new SegmentReplicationStatsResponse(totalShards, successfulShards, failedShards, replicationStats, shardFailures);
     }
 
     @Override
@@ -144,9 +151,8 @@ protected SegmentReplicationStatsRequest readRequestFrom(StreamInput in) throws
 
     @Override
     protected SegmentReplicationShardStatsResponse shardOperation(SegmentReplicationStatsRequest request, ShardRouting shardRouting) {
-        IndexService indexService = indicesService.indexServiceSafe(shardRouting.shardId().getIndex());
-        IndexShard indexShard = indexService.getShard(shardRouting.shardId().id());
         ShardId shardId = shardRouting.shardId();
+        IndexShard indexShard = indicesService.indexServiceSafe(shardId.getIndex()).getShard(shardId.id());
 
         if (indexShard.indexSettings().isSegRepEnabledOrRemoteNode() == false) {
             return null;
@@ -156,11 +162,7 @@ protected SegmentReplicationShardStatsResponse shardOperation(SegmentReplication
             return new SegmentReplicationShardStatsResponse(pressureService.getStatsForShard(indexShard));
         }
 
-        // return information about only on-going segment replication events.
-        if (request.activeOnly()) {
-            return new SegmentReplicationShardStatsResponse(targetService.getOngoingEventSegmentReplicationState(shardId));
-        }
-        return new SegmentReplicationShardStatsResponse(targetService.getSegmentReplicationState(shardId));
+        return new SegmentReplicationShardStatsResponse(getSegmentReplicationState(shardId, request.activeOnly()));
     }
 
     @Override
@@ -181,4 +183,83 @@ protected ClusterBlockException checkRequestBlock(
     ) {
         return state.blocks().indicesBlockedException(ClusterBlockLevel.METADATA_READ, concreteIndices);
     }
+
+    private SegmentReplicationPerGroupStats updateGroupStats(
+        SegmentReplicationPerGroupStats groupStats,
+        Map<String, SegmentReplicationState> replicaStats
+    ) {
+        // Update the SegmentReplicationState for each of the replicas
+        Set<SegmentReplicationShardStats> updatedReplicaStats = groupStats.getReplicaStats()
+            .stream()
+            .peek(replicaStat -> replicaStat.setCurrentReplicationState(replicaStats.getOrDefault(replicaStat.getAllocationId(), null)))
+            .collect(Collectors.toSet());
+
+        // Compute search replica stats
+        Set<SegmentReplicationShardStats> searchReplicaStats = computeSearchReplicaStats(groupStats.getShardId(), replicaStats);
+
+        // Combine ReplicaStats and SearchReplicaStats
+        Set<SegmentReplicationShardStats> combinedStats = Stream.concat(updatedReplicaStats.stream(), searchReplicaStats.stream())
+            .collect(Collectors.toSet());
+
+        return new SegmentReplicationPerGroupStats(groupStats.getShardId(), combinedStats, groupStats.getRejectedRequestCount());
+    }
+
+    private Set<SegmentReplicationShardStats> computeSearchReplicaStats(
+        ShardId shardId,
+        Map<String, SegmentReplicationState> replicaStats
+    ) {
+        return replicaStats.values()
+            .stream()
+            .filter(segmentReplicationState -> segmentReplicationState.getShardRouting().shardId().equals(shardId))
+            .filter(segmentReplicationState -> segmentReplicationState.getShardRouting().isSearchOnly())
+            .map(segmentReplicationState -> {
+                ShardRouting shardRouting = segmentReplicationState.getShardRouting();
+                SegmentReplicationShardStats segmentReplicationStats = computeSegmentReplicationShardStats(shardRouting);
+                segmentReplicationStats.setCurrentReplicationState(segmentReplicationState);
+                return segmentReplicationStats;
+            })
+            .collect(Collectors.toSet());
+    }
+
+    SegmentReplicationShardStats computeSegmentReplicationShardStats(ShardRouting shardRouting) {
+        ShardId shardId = shardRouting.shardId();
+        SegmentReplicationState completedSegmentReplicationState = targetService.getlatestCompletedEventSegmentReplicationState(shardId);
+        SegmentReplicationState ongoingSegmentReplicationState = targetService.getOngoingEventSegmentReplicationState(shardId);
+
+        return new SegmentReplicationShardStats(
+            shardRouting.allocationId().getId(),
+            0,
+            calculateBytesRemainingToReplicate(ongoingSegmentReplicationState),
+            0,
+            getCurrentReplicationLag(ongoingSegmentReplicationState),
+            getLastCompletedReplicationLag(completedSegmentReplicationState)
+        );
+    }
+
+    private SegmentReplicationState getSegmentReplicationState(ShardId shardId, boolean isActiveOnly) {
+        if (isActiveOnly) {
+            return targetService.getOngoingEventSegmentReplicationState(shardId);
+        } else {
+            return targetService.getSegmentReplicationState(shardId);
+        }
+    }
+
+    private long calculateBytesRemainingToReplicate(SegmentReplicationState ongoingSegmentReplicationState) {
+        if (ongoingSegmentReplicationState == null) {
+            return 0;
+        }
+        return ongoingSegmentReplicationState.getIndex()
+            .fileDetails()
+            .stream()
+            .mapToLong(index -> index.length() - index.recovered())
+            .sum();
+    }
+
+    private long getCurrentReplicationLag(SegmentReplicationState ongoingSegmentReplicationState) {
+        return ongoingSegmentReplicationState != null ? ongoingSegmentReplicationState.getTimer().time() : 0;
+    }
+
+    private long getLastCompletedReplicationLag(SegmentReplicationState completedSegmentReplicationState) {
+        return completedSegmentReplicationState != null ? completedSegmentReplicationState.getTimer().time() : 0;
+    }
 }
diff --git a/server/src/test/java/org/opensearch/action/admin/indices/replication/TransportSegmentReplicationStatsActionTests.java b/server/src/test/java/org/opensearch/action/admin/indices/replication/TransportSegmentReplicationStatsActionTests.java
new file mode 100644
index 0000000000000..ea455d607f058
--- /dev/null
+++ b/server/src/test/java/org/opensearch/action/admin/indices/replication/TransportSegmentReplicationStatsActionTests.java
@@ -0,0 +1,595 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.action.admin.indices.replication;
+
+import org.opensearch.Version;
+import org.opensearch.action.support.ActionFilters;
+import org.opensearch.cluster.ClusterState;
+import org.opensearch.cluster.block.ClusterBlock;
+import org.opensearch.cluster.block.ClusterBlockLevel;
+import org.opensearch.cluster.block.ClusterBlocks;
+import org.opensearch.cluster.metadata.IndexMetadata;
+import org.opensearch.cluster.metadata.IndexNameExpressionResolver;
+import org.opensearch.cluster.routing.AllocationId;
+import org.opensearch.cluster.routing.RoutingTable;
+import org.opensearch.cluster.routing.ShardIterator;
+import org.opensearch.cluster.routing.ShardRouting;
+import org.opensearch.cluster.routing.ShardsIterator;
+import org.opensearch.cluster.service.ClusterService;
+import org.opensearch.common.settings.Settings;
+import org.opensearch.core.action.support.DefaultShardOperationFailedException;
+import org.opensearch.core.index.Index;
+import org.opensearch.core.index.shard.ShardId;
+import org.opensearch.core.rest.RestStatus;
+import org.opensearch.index.IndexService;
+import org.opensearch.index.IndexSettings;
+import org.opensearch.index.SegmentReplicationPerGroupStats;
+import org.opensearch.index.SegmentReplicationPressureService;
+import org.opensearch.index.SegmentReplicationShardStats;
+import org.opensearch.index.shard.IndexShard;
+import org.opensearch.indices.IndicesService;
+import org.opensearch.indices.replication.SegmentReplicationState;
+import org.opensearch.indices.replication.SegmentReplicationTargetService;
+import org.opensearch.indices.replication.common.ReplicationLuceneIndex;
+import org.opensearch.indices.replication.common.ReplicationTimer;
+import org.opensearch.indices.replication.common.ReplicationType;
+import org.opensearch.test.OpenSearchTestCase;
+import org.opensearch.transport.TransportService;
+import org.junit.Before;
+
+import java.util.ArrayList;
+import java.util.EnumSet;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+public class TransportSegmentReplicationStatsActionTests extends OpenSearchTestCase {
+    @Mock
+    private ClusterService clusterService;
+    @Mock
+    private TransportService transportService;
+    @Mock
+    private IndicesService indicesService;
+    @Mock
+    private SegmentReplicationTargetService targetService;
+    @Mock
+    private ActionFilters actionFilters;
+    @Mock
+    private IndexNameExpressionResolver indexNameExpressionResolver;
+    @Mock
+    private SegmentReplicationPressureService pressureService;
+    @Mock
+    private IndexShard indexShard;
+    @Mock
+    private IndexService indexService;
+
+    private TransportSegmentReplicationStatsAction action;
+
+    @Before
+    public void setUp() throws Exception {
+        MockitoAnnotations.openMocks(this);
+        super.setUp();
+        action = new TransportSegmentReplicationStatsAction(
+            clusterService,
+            transportService,
+            indicesService,
+            targetService,
+            actionFilters,
+            indexNameExpressionResolver,
+            pressureService
+        );
+    }
+
+    public void testShardReturnsAllTheShardsForTheIndex() {
+        SegmentReplicationStatsRequest segmentReplicationStatsRequest = mock(SegmentReplicationStatsRequest.class);
+        String[] concreteIndices = new String[] { "test-index" };
+        ClusterState clusterState = mock(ClusterState.class);
+        RoutingTable routingTables = mock(RoutingTable.class);
+        ShardsIterator shardsIterator = mock(ShardIterator.class);
+
+        when(clusterState.routingTable()).thenReturn(routingTables);
+        when(routingTables.allShardsIncludingRelocationTargets(any())).thenReturn(shardsIterator);
+        assertEquals(shardsIterator, action.shards(clusterState, segmentReplicationStatsRequest, concreteIndices));
+    }
+
+    public void testShardOperationWithPrimaryShard() {
+        ShardRouting shardRouting = mock(ShardRouting.class);
+        ShardId shardId = new ShardId(new Index("test-index", "test-uuid"), 0);
+        SegmentReplicationStatsRequest request = new SegmentReplicationStatsRequest();
+
+        when(shardRouting.shardId()).thenReturn(shardId);
+        when(shardRouting.primary()).thenReturn(true);
+        when(indicesService.indexServiceSafe(shardId.getIndex())).thenReturn(indexService);
+        when(indexService.getShard(shardId.id())).thenReturn(indexShard);
+        when(indexShard.indexSettings()).thenReturn(createIndexSettingsWithSegRepEnabled());
+
+        SegmentReplicationShardStatsResponse response = action.shardOperation(request, shardRouting);
+
+        assertNotNull(response);
+        verify(pressureService).getStatsForShard(any());
+    }
+
+    public void testShardOperationWithReplicaShard() {
+        ShardRouting shardRouting = mock(ShardRouting.class);
+        ShardId shardId = new ShardId(new Index("test-index", "test-uuid"), 0);
+        SegmentReplicationStatsRequest request = new SegmentReplicationStatsRequest();
+        request.activeOnly(false);
+        SegmentReplicationState completedSegmentReplicationState = mock(SegmentReplicationState.class);
+
+        when(shardRouting.shardId()).thenReturn(shardId);
+        when(shardRouting.primary()).thenReturn(false);
+        when(indicesService.indexServiceSafe(shardId.getIndex())).thenReturn(indexService);
+        when(indexService.getShard(shardId.id())).thenReturn(indexShard);
+        when(indexShard.indexSettings()).thenReturn(createIndexSettingsWithSegRepEnabled());
+        when(targetService.getSegmentReplicationState(shardId)).thenReturn(completedSegmentReplicationState);
+
+        SegmentReplicationShardStatsResponse response = action.shardOperation(request, shardRouting);
+
+        assertNotNull(response);
+        assertNull(response.getPrimaryStats());
+        assertNotNull(response.getReplicaStats());
+        verify(targetService).getSegmentReplicationState(shardId);
+    }
+
+    public void testShardOperationWithReplicaShardActiveOnly() {
+        ShardRouting shardRouting = mock(ShardRouting.class);
+        ShardId shardId = new ShardId(new Index("test-index", "test-uuid"), 0);
+        SegmentReplicationStatsRequest request = new SegmentReplicationStatsRequest();
+        request.activeOnly(true);
+        SegmentReplicationState onGoingSegmentReplicationState = mock(SegmentReplicationState.class);
+
+        when(shardRouting.shardId()).thenReturn(shardId);
+        when(shardRouting.primary()).thenReturn(false);
+        when(indicesService.indexServiceSafe(shardId.getIndex())).thenReturn(indexService);
+        when(indexService.getShard(shardId.id())).thenReturn(indexShard);
+        when(indexShard.indexSettings()).thenReturn(createIndexSettingsWithSegRepEnabled());
+        when(targetService.getOngoingEventSegmentReplicationState(shardId)).thenReturn(onGoingSegmentReplicationState);
+
+        SegmentReplicationShardStatsResponse response = action.shardOperation(request, shardRouting);
+
+        assertNotNull(response);
+        assertNull(response.getPrimaryStats());
+        assertNotNull(response.getReplicaStats());
+        verify(targetService).getOngoingEventSegmentReplicationState(shardId);
+    }
+
+    public void testComputeBytesRemainingToReplicateWhenCompletedAndOngoingStateNotNull() {
+        ShardRouting shardRouting = mock(ShardRouting.class);
+        SegmentReplicationState completedSegmentReplicationState = mock(SegmentReplicationState.class);
+        SegmentReplicationState onGoingSegmentReplicationState = mock(SegmentReplicationState.class);
+        ShardId shardId = new ShardId(new Index("test-index", "test-uuid"), 0);
+        AllocationId allocationId = AllocationId.newInitializing();
+        ReplicationTimer replicationTimerCompleted = mock(ReplicationTimer.class);
+        ReplicationTimer replicationTimerOngoing = mock(ReplicationTimer.class);
+        long time1 = 10;
+        long time2 = 15;
+        ReplicationLuceneIndex replicationLuceneIndex = new ReplicationLuceneIndex();
+        replicationLuceneIndex.addFileDetail("name1", 10, false);
+        replicationLuceneIndex.addFileDetail("name2", 15, false);
+
+        when(shardRouting.shardId()).thenReturn(shardId);
+        when(shardRouting.allocationId()).thenReturn(allocationId);
+        when(targetService.getlatestCompletedEventSegmentReplicationState(shardId)).thenReturn(completedSegmentReplicationState);
+        when(targetService.getOngoingEventSegmentReplicationState(shardId)).thenReturn(onGoingSegmentReplicationState);
+        when(completedSegmentReplicationState.getTimer()).thenReturn(replicationTimerCompleted);
+        when(onGoingSegmentReplicationState.getTimer()).thenReturn(replicationTimerOngoing);
+        when(replicationTimerOngoing.time()).thenReturn(time1);
+        when(replicationTimerCompleted.time()).thenReturn(time2);
+        when(onGoingSegmentReplicationState.getIndex()).thenReturn(replicationLuceneIndex);
+
+        SegmentReplicationShardStats segmentReplicationShardStats = action.computeSegmentReplicationShardStats(shardRouting);
+
+        assertNotNull(segmentReplicationShardStats);
+        assertEquals(25, segmentReplicationShardStats.getBytesBehindCount());
+        assertEquals(10, segmentReplicationShardStats.getCurrentReplicationLagMillis());
+        assertEquals(15, segmentReplicationShardStats.getLastCompletedReplicationTimeMillis());
+
+        verify(targetService).getlatestCompletedEventSegmentReplicationState(shardId);
+        verify(targetService).getOngoingEventSegmentReplicationState(shardId);
+    }
+
+    public void testCalculateBytesRemainingToReplicateWhenNoCompletedState() {
+        ShardRouting shardRouting = mock(ShardRouting.class);
+        SegmentReplicationState onGoingSegmentReplicationState = mock(SegmentReplicationState.class);
+        ShardId shardId = new ShardId(new Index("test-index", "test-uuid"), 0);
+        AllocationId allocationId = AllocationId.newInitializing();
+        ReplicationTimer replicationTimerOngoing = mock(ReplicationTimer.class);
+        long time1 = 10;
+        ReplicationLuceneIndex replicationLuceneIndex = new ReplicationLuceneIndex();
+        replicationLuceneIndex.addFileDetail("name1", 10, false);
+        replicationLuceneIndex.addFileDetail("name2", 15, false);
+
+        when(shardRouting.shardId()).thenReturn(shardId);
+        when(shardRouting.allocationId()).thenReturn(allocationId);
+        when(targetService.getOngoingEventSegmentReplicationState(shardId)).thenReturn(onGoingSegmentReplicationState);
+        when(onGoingSegmentReplicationState.getTimer()).thenReturn(replicationTimerOngoing);
+        when(replicationTimerOngoing.time()).thenReturn(time1);
+        when(onGoingSegmentReplicationState.getIndex()).thenReturn(replicationLuceneIndex);
+
+        SegmentReplicationShardStats segmentReplicationShardStats = action.computeSegmentReplicationShardStats(shardRouting);
+
+        assertNotNull(segmentReplicationShardStats);
+        assertEquals(25, segmentReplicationShardStats.getBytesBehindCount());
+        assertEquals(10, segmentReplicationShardStats.getCurrentReplicationLagMillis());
+        assertEquals(0, segmentReplicationShardStats.getLastCompletedReplicationTimeMillis());
+
+        verify(targetService).getlatestCompletedEventSegmentReplicationState(shardId);
+        verify(targetService).getOngoingEventSegmentReplicationState(shardId);
+    }
+
+    public void testCalculateBytesRemainingToReplicateWhenNoOnGoingState() {
+        ShardRouting shardRouting = mock(ShardRouting.class);
+        SegmentReplicationState completedSegmentReplicationState = mock(SegmentReplicationState.class);
+        ShardId shardId = new ShardId(new Index("test-index", "test-uuid"), 0);
+        AllocationId allocationId = AllocationId.newInitializing();
+        ReplicationTimer replicationTimerCompleted = mock(ReplicationTimer.class);
+        long time2 = 15;
+
+        when(shardRouting.shardId()).thenReturn(shardId);
+        when(shardRouting.allocationId()).thenReturn(allocationId);
+        when(targetService.getlatestCompletedEventSegmentReplicationState(shardId)).thenReturn(completedSegmentReplicationState);
+        when(completedSegmentReplicationState.getTimer()).thenReturn(replicationTimerCompleted);
+        when(replicationTimerCompleted.time()).thenReturn(time2);
+
+        SegmentReplicationShardStats segmentReplicationShardStats = action.computeSegmentReplicationShardStats(shardRouting);
+
+        assertNotNull(segmentReplicationShardStats);
+        assertEquals(0, segmentReplicationShardStats.getBytesBehindCount());
+        assertEquals(0, segmentReplicationShardStats.getCurrentReplicationLagMillis());
+        assertEquals(15, segmentReplicationShardStats.getLastCompletedReplicationTimeMillis());
+
+        verify(targetService).getlatestCompletedEventSegmentReplicationState(shardId);
+        verify(targetService).getOngoingEventSegmentReplicationState(shardId);
+    }
+
+    public void testCalculateBytesRemainingToReplicateWhenNoCompletedAndOngoingState() {
+        ShardRouting shardRouting = mock(ShardRouting.class);
+        ShardId shardId = new ShardId(new Index("test-index", "test-uuid"), 0);
+        AllocationId allocationId = AllocationId.newInitializing();
+        when(shardRouting.shardId()).thenReturn(shardId);
+        when(shardRouting.allocationId()).thenReturn(allocationId);
+
+        SegmentReplicationShardStats segmentReplicationShardStats = action.computeSegmentReplicationShardStats(shardRouting);
+
+        assertNotNull(segmentReplicationShardStats);
+        assertEquals(0, segmentReplicationShardStats.getBytesBehindCount());
+        assertEquals(0, segmentReplicationShardStats.getCurrentReplicationLagMillis());
+        assertEquals(0, segmentReplicationShardStats.getLastCompletedReplicationTimeMillis());
+
+        verify(targetService).getlatestCompletedEventSegmentReplicationState(shardId);
+        verify(targetService).getOngoingEventSegmentReplicationState(shardId);
+    }
+
+    public void testNewResponseWhenAllReplicasReturnResponseCombinesTheResults() {
+        SegmentReplicationStatsRequest request = new SegmentReplicationStatsRequest();
+        List<DefaultShardOperationFailedException> shardFailures = new ArrayList<>();
+        String[] shards = { "0", "1" };
+        request.shards(shards);
+
+        int totalShards = 6;
+        int successfulShards = 6;
+        int failedShard = 0;
+        String allocIdOne = "allocIdOne";
+        String allocIdTwo = "allocIdTwo";
+        String allocIdThree = "allocIdThree";
+        String allocIdFour = "allocIdFour";
+        String allocIdFive = "allocIdFive";
+        String allocIdSix = "allocIdSix";
+
+        ShardId shardId0 = mock(ShardId.class);
+        ShardRouting primary0 = mock(ShardRouting.class);
+        ShardRouting replica0 = mock(ShardRouting.class);
+        ShardRouting searchReplica0 = mock(ShardRouting.class);
+
+        ShardId shardId1 = mock(ShardId.class);
+        ShardRouting primary1 = mock(ShardRouting.class);
+        ShardRouting replica1 = mock(ShardRouting.class);
+        ShardRouting searchReplica1 = mock(ShardRouting.class);
+
+        when(shardId0.getId()).thenReturn(0);
+        when(shardId0.getIndexName()).thenReturn("test-index-1");
+        when(primary0.shardId()).thenReturn(shardId0);
+        when(replica0.shardId()).thenReturn(shardId0);
+        when(searchReplica0.shardId()).thenReturn(shardId0);
+
+        when(shardId1.getId()).thenReturn(1);
+        when(shardId1.getIndexName()).thenReturn("test-index-1");
+        when(primary1.shardId()).thenReturn(shardId1);
+        when(replica1.shardId()).thenReturn(shardId1);
+        when(searchReplica1.shardId()).thenReturn(shardId1);
+
+        AllocationId allocationIdOne = mock(AllocationId.class);
+        AllocationId allocationIdTwo = mock(AllocationId.class);
+        AllocationId allocationIdThree = mock(AllocationId.class);
+        AllocationId allocationIdFour = mock(AllocationId.class);
+        AllocationId allocationIdFive = mock(AllocationId.class);
+        AllocationId allocationIdSix = mock(AllocationId.class);
+
+        when(allocationIdOne.getId()).thenReturn(allocIdOne);
+        when(allocationIdTwo.getId()).thenReturn(allocIdTwo);
+        when(allocationIdThree.getId()).thenReturn(allocIdThree);
+        when(allocationIdFour.getId()).thenReturn(allocIdFour);
+        when(allocationIdFive.getId()).thenReturn(allocIdFive);
+        when(allocationIdSix.getId()).thenReturn(allocIdSix);
+        when(primary0.allocationId()).thenReturn(allocationIdOne);
+        when(replica0.allocationId()).thenReturn(allocationIdTwo);
+        when(searchReplica0.allocationId()).thenReturn(allocationIdThree);
+        when(primary1.allocationId()).thenReturn(allocationIdFour);
+        when(replica1.allocationId()).thenReturn(allocationIdFive);
+        when(searchReplica1.allocationId()).thenReturn(allocationIdSix);
+
+        when(primary0.isSearchOnly()).thenReturn(false);
+        when(replica0.isSearchOnly()).thenReturn(false);
+        when(searchReplica0.isSearchOnly()).thenReturn(true);
+        when(primary1.isSearchOnly()).thenReturn(false);
+        when(replica1.isSearchOnly()).thenReturn(false);
+        when(searchReplica1.isSearchOnly()).thenReturn(true);
+
+        Set<SegmentReplicationShardStats> segmentReplicationShardStats0 = new HashSet<>();
+        SegmentReplicationShardStats segmentReplicationShardStatsOfReplica0 = new SegmentReplicationShardStats(allocIdTwo, 0, 0, 0, 0, 0);
+        segmentReplicationShardStats0.add(segmentReplicationShardStatsOfReplica0);
+
+        Set<SegmentReplicationShardStats> segmentReplicationShardStats1 = new HashSet<>();
+        SegmentReplicationShardStats segmentReplicationShardStatsOfReplica1 = new SegmentReplicationShardStats(allocIdFive, 0, 0, 0, 0, 0);
+        segmentReplicationShardStats1.add(segmentReplicationShardStatsOfReplica1);
+
+        SegmentReplicationPerGroupStats segmentReplicationPerGroupStats0 = new SegmentReplicationPerGroupStats(
+            shardId0,
+            segmentReplicationShardStats0,
+            0
+        );
+
+        SegmentReplicationPerGroupStats segmentReplicationPerGroupStats1 = new SegmentReplicationPerGroupStats(
+            shardId1,
+            segmentReplicationShardStats1,
+            0
+        );
+
+        SegmentReplicationState segmentReplicationState0 = mock(SegmentReplicationState.class);
+        SegmentReplicationState searchReplicaSegmentReplicationState0 = mock(SegmentReplicationState.class);
+        SegmentReplicationState segmentReplicationState1 = mock(SegmentReplicationState.class);
+        SegmentReplicationState searchReplicaSegmentReplicationState1 = mock(SegmentReplicationState.class);
+
+        when(segmentReplicationState0.getShardRouting()).thenReturn(replica0);
+        when(searchReplicaSegmentReplicationState0.getShardRouting()).thenReturn(searchReplica0);
+        when(segmentReplicationState1.getShardRouting()).thenReturn(replica1);
+        when(searchReplicaSegmentReplicationState1.getShardRouting()).thenReturn(searchReplica1);
+
+        List<SegmentReplicationShardStatsResponse> responses = List.of(
+            new SegmentReplicationShardStatsResponse(segmentReplicationPerGroupStats0),
+            new SegmentReplicationShardStatsResponse(segmentReplicationState0),
+            new SegmentReplicationShardStatsResponse(searchReplicaSegmentReplicationState0),
+            new SegmentReplicationShardStatsResponse(segmentReplicationPerGroupStats1),
+            new SegmentReplicationShardStatsResponse(segmentReplicationState1),
+            new SegmentReplicationShardStatsResponse(searchReplicaSegmentReplicationState1)
+        );
+
+        SegmentReplicationStatsResponse response = action.newResponse(
+            request,
+            totalShards,
+            successfulShards,
+            failedShard,
+            responses,
+            shardFailures,
+            ClusterState.EMPTY_STATE
+        );
+
+        List<SegmentReplicationPerGroupStats> responseStats = response.getReplicationStats().get("test-index-1");
+        SegmentReplicationPerGroupStats primStats0 = responseStats.get(0);
+        Set<SegmentReplicationShardStats> replicaStats0 = primStats0.getReplicaStats();
+        assertEquals(2, replicaStats0.size());
+        for (SegmentReplicationShardStats replicaStat : replicaStats0) {
+            if (replicaStat.getAllocationId().equals(allocIdTwo)) {
+                assertEquals(segmentReplicationState0, replicaStat.getCurrentReplicationState());
+            }
+
+            if (replicaStat.getAllocationId().equals(allocIdThree)) {
+                assertEquals(searchReplicaSegmentReplicationState0, replicaStat.getCurrentReplicationState());
+            }
+        }
+
+        SegmentReplicationPerGroupStats primStats1 = responseStats.get(1);
+        Set<SegmentReplicationShardStats> replicaStats1 = primStats1.getReplicaStats();
+        assertEquals(2, replicaStats1.size());
+        for (SegmentReplicationShardStats replicaStat : replicaStats1) {
+            if (replicaStat.getAllocationId().equals(allocIdFive)) {
+                assertEquals(segmentReplicationState1, replicaStat.getCurrentReplicationState());
+            }
+
+            if (replicaStat.getAllocationId().equals(allocIdSix)) {
+                assertEquals(searchReplicaSegmentReplicationState1, replicaStat.getCurrentReplicationState());
+            }
+        }
+    }
+
+    public void testNewResponseWhenShardsToFetchEmptyAndResponsesContainsNull() {
+        SegmentReplicationStatsRequest request = new SegmentReplicationStatsRequest();
+        List<DefaultShardOperationFailedException> shardFailures = new ArrayList<>();
+        String[] shards = {};
+        request.shards(shards);
+
+        int totalShards = 3;
+        int successfulShards = 3;
+        int failedShard = 0;
+        String allocIdOne = "allocIdOne";
+        String allocIdTwo = "allocIdTwo";
+        ShardId shardIdOne = mock(ShardId.class);
+        ShardId shardIdTwo = mock(ShardId.class);
+        ShardId shardIdThree = mock(ShardId.class);
+        ShardRouting shardRoutingOne = mock(ShardRouting.class);
+        ShardRouting shardRoutingTwo = mock(ShardRouting.class);
+        ShardRouting shardRoutingThree = mock(ShardRouting.class);
+        when(shardIdOne.getId()).thenReturn(1);
+        when(shardIdTwo.getId()).thenReturn(2);
+        when(shardIdThree.getId()).thenReturn(3);
+        when(shardRoutingOne.shardId()).thenReturn(shardIdOne);
+        when(shardRoutingTwo.shardId()).thenReturn(shardIdTwo);
+        when(shardRoutingThree.shardId()).thenReturn(shardIdThree);
+        AllocationId allocationId = mock(AllocationId.class);
+        when(allocationId.getId()).thenReturn(allocIdOne);
+        when(shardRoutingTwo.allocationId()).thenReturn(allocationId);
+        when(shardIdOne.getIndexName()).thenReturn("test-index");
+
+        Set<SegmentReplicationShardStats> segmentReplicationShardStats = new HashSet<>();
+        SegmentReplicationShardStats segmentReplicationShardStatsOfReplica = new SegmentReplicationShardStats(allocIdOne, 0, 0, 0, 0, 0);
+        segmentReplicationShardStats.add(segmentReplicationShardStatsOfReplica);
+        SegmentReplicationPerGroupStats segmentReplicationPerGroupStats = new SegmentReplicationPerGroupStats(
+            shardIdOne,
+            segmentReplicationShardStats,
+            0
+        );
+
+        SegmentReplicationState segmentReplicationState = mock(SegmentReplicationState.class);
+        SegmentReplicationShardStats segmentReplicationShardStatsFromSearchReplica = mock(SegmentReplicationShardStats.class);
+        when(segmentReplicationShardStatsFromSearchReplica.getAllocationId()).thenReturn("alloc2");
+        when(segmentReplicationState.getShardRouting()).thenReturn(shardRoutingTwo);
+
+        List<SegmentReplicationShardStatsResponse> responses = new ArrayList<>();
+        responses.add(null);
+        responses.add(new SegmentReplicationShardStatsResponse(segmentReplicationPerGroupStats));
+        responses.add(new SegmentReplicationShardStatsResponse(segmentReplicationState));
+
+        SegmentReplicationStatsResponse response = action.newResponse(
+            request,
+            totalShards,
+            successfulShards,
+            failedShard,
+            responses,
+            shardFailures,
+            ClusterState.EMPTY_STATE
+        );
+
+        List<SegmentReplicationPerGroupStats> responseStats = response.getReplicationStats().get("test-index");
+        SegmentReplicationPerGroupStats primStats = responseStats.get(0);
+        Set<SegmentReplicationShardStats> segRpShardStatsSet = primStats.getReplicaStats();
+
+        for (SegmentReplicationShardStats segRpShardStats : segRpShardStatsSet) {
+            if (segRpShardStats.getAllocationId().equals(allocIdOne)) {
+                assertEquals(segmentReplicationState, segRpShardStats.getCurrentReplicationState());
+            }
+
+            if (segRpShardStats.getAllocationId().equals(allocIdTwo)) {
+                assertEquals(segmentReplicationShardStatsFromSearchReplica, segRpShardStats);
+            }
+        }
+    }
+
+    public void testShardOperationWithSegRepDisabled() {
+        ShardRouting shardRouting = mock(ShardRouting.class);
+        ShardId shardId = new ShardId(new Index("test-index", "test-uuid"), 0);
+        SegmentReplicationStatsRequest request = new SegmentReplicationStatsRequest();
+
+        when(shardRouting.shardId()).thenReturn(shardId);
+        when(indicesService.indexServiceSafe(shardId.getIndex())).thenReturn(indexService);
+        when(indexService.getShard(shardId.id())).thenReturn(indexShard);
+        when(indexShard.indexSettings()).thenReturn(createIndexSettingsWithSegRepDisabled());
+
+        SegmentReplicationShardStatsResponse response = action.shardOperation(request, shardRouting);
+
+        assertNull(response);
+    }
+
+    public void testGlobalBlockCheck() {
+        ClusterBlock writeClusterBlock = new ClusterBlock(
+            1,
+            "uuid",
+            "",
+            true,
+            true,
+            true,
+            RestStatus.OK,
+            EnumSet.of(ClusterBlockLevel.METADATA_WRITE)
+        );
+
+        ClusterBlock readClusterBlock = new ClusterBlock(
+            1,
+            "uuid",
+            "",
+            true,
+            true,
+            true,
+            RestStatus.OK,
+            EnumSet.of(ClusterBlockLevel.METADATA_READ)
+        );
+
+        ClusterBlocks.Builder builder = ClusterBlocks.builder();
+        builder.addGlobalBlock(writeClusterBlock);
+        ClusterState metadataWriteBlockedState = ClusterState.builder(ClusterState.EMPTY_STATE).blocks(builder).build();
+        assertNull(action.checkGlobalBlock(metadataWriteBlockedState, new SegmentReplicationStatsRequest()));
+
+        builder = ClusterBlocks.builder();
+        builder.addGlobalBlock(readClusterBlock);
+        ClusterState metadataReadBlockedState = ClusterState.builder(ClusterState.EMPTY_STATE).blocks(builder).build();
+        assertNotNull(action.checkGlobalBlock(metadataReadBlockedState, new SegmentReplicationStatsRequest()));
+    }
+
+    public void testIndexBlockCheck() {
+        ClusterBlock writeClusterBlock = new ClusterBlock(
+            1,
+            "uuid",
+            "",
+            true,
+            true,
+            true,
+            RestStatus.OK,
+            EnumSet.of(ClusterBlockLevel.METADATA_WRITE)
+        );
+
+        ClusterBlock readClusterBlock = new ClusterBlock(
+            1,
+            "uuid",
+            "",
+            true,
+            true,
+            true,
+            RestStatus.OK,
+            EnumSet.of(ClusterBlockLevel.METADATA_READ)
+        );
+
+        String indexName = "test";
+        ClusterBlocks.Builder builder = ClusterBlocks.builder();
+        builder.addIndexBlock(indexName, writeClusterBlock);
+        ClusterState metadataWriteBlockedState = ClusterState.builder(ClusterState.EMPTY_STATE).blocks(builder).build();
+        assertNull(action.checkRequestBlock(metadataWriteBlockedState, new SegmentReplicationStatsRequest(), new String[] { indexName }));
+
+        builder = ClusterBlocks.builder();
+        builder.addIndexBlock(indexName, readClusterBlock);
+        ClusterState metadataReadBlockedState = ClusterState.builder(ClusterState.EMPTY_STATE).blocks(builder).build();
+        assertNotNull(action.checkRequestBlock(metadataReadBlockedState, new SegmentReplicationStatsRequest(), new String[] { indexName }));
+    }
+
+    private IndexSettings createIndexSettingsWithSegRepEnabled() {
+        Settings settings = Settings.builder()
+            .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 2)
+            .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 2)
+            .put(IndexMetadata.SETTING_REPLICATION_TYPE, ReplicationType.SEGMENT)
+            .put(IndexMetadata.SETTING_VERSION_CREATED, Version.CURRENT)
+            .build();
+
+        return new IndexSettings(IndexMetadata.builder("test").settings(settings).build(), settings);
+    }
+
+    private IndexSettings createIndexSettingsWithSegRepDisabled() {
+        Settings settings = Settings.builder()
+            .put(IndexMetadata.SETTING_REPLICATION_TYPE, ReplicationType.DOCUMENT)
+            .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 2)
+            .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 2)
+            .put(IndexMetadata.SETTING_VERSION_CREATED, Version.CURRENT)
+            .build();
+        return new IndexSettings(IndexMetadata.builder("test").settings(settings).build(), settings);
+    }
+}