Skip to content

Commit

Permalink
[ISSUE-468] Put unavailable servers to the end of the list when sendi…
Browse files Browse the repository at this point in the history
…ng shuffle data (#470)

### What changes were proposed in this pull request?
Put unavailable shuffle servers to the end of the server list when sending shuffle data if replica=1.

### Why are the changes needed?

If we use multiple replicas and the first shuffle server becomes unavailable, sending data will take a lot of time. Because the client will always send to the first shuffle server firstly. #468 

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
UT
  • Loading branch information
xianjingfeng authored Jan 31, 2023
1 parent 6367b36 commit ebbe2db
Show file tree
Hide file tree
Showing 2 changed files with 133 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import org.apache.commons.collections.CollectionUtils;
import org.apache.hadoop.security.UserGroupInformation;
import org.roaringbitmap.longlong.Roaring64NavigableMap;
import org.slf4j.Logger;
Expand Down Expand Up @@ -105,6 +106,7 @@ public class ShuffleWriteClientImpl implements ShuffleWriteClient {
private final ExecutorService dataTransferPool;
private final int unregisterThreadPoolSize;
private final int unregisterRequestTimeSec;
private Set<ShuffleServerInfo> defectiveServers;

public ShuffleWriteClientImpl(
String clientType,
Expand Down Expand Up @@ -133,6 +135,9 @@ public ShuffleWriteClientImpl(
this.dataCommitPoolSize = dataCommitPoolSize;
this.unregisterThreadPoolSize = unregisterThreadPoolSize;
this.unregisterRequestTimeSec = unregisterRequestTimeSec;
if (replica > 1) {
defectiveServers = Sets.newConcurrentHashSet();
}
}

private boolean sendShuffleDataAsync(
Expand Down Expand Up @@ -170,12 +175,21 @@ private boolean sendShuffleDataAsync(
if (response.getStatusCode() == ResponseStatusCode.SUCCESS) {
// mark a replica of block that has been sent
serverToBlockIds.get(ssi).forEach(block -> blockIdsTracker.get(block).incrementAndGet());
if (defectiveServers != null) {
defectiveServers.remove(ssi);
}
LOG.info("{} successfully.", logMsg);
} else {
if (defectiveServers != null) {
defectiveServers.add(ssi);
}
LOG.warn("{}, it failed wth statusCode[{}]", logMsg, response.getStatusCode());
return false;
}
} catch (Exception e) {
if (defectiveServers != null) {
defectiveServers.add(ssi);
}
LOG.warn("Send: " + serverToBlockIds.get(ssi).size() + " blocks to [" + ssi.getId() + "] failed.", e);
return false;
}
Expand All @@ -192,12 +206,27 @@ private boolean sendShuffleDataAsync(
return result;
}

private void genServerToBlocks(ShuffleBlockInfo sbi, List<ShuffleServerInfo> serverList,
void genServerToBlocks(
ShuffleBlockInfo sbi,
List<ShuffleServerInfo> serverList,
int replicaNum,
List<ShuffleServerInfo> excludeServers,
Map<ShuffleServerInfo, Map<Integer, Map<Integer, List<ShuffleBlockInfo>>>> serverToBlocks,
Map<ShuffleServerInfo, List<Long>> serverToBlockIds) {
Map<ShuffleServerInfo, List<Long>> serverToBlockIds,
boolean excludeDefectiveServers) {
if (replicaNum <= 0) {
return;
}
int partitionId = sbi.getPartitionId();
int shuffleId = sbi.getShuffleId();
int assignedNum = 0;
for (ShuffleServerInfo ssi : serverList) {
if (excludeDefectiveServers && replica > 1 && defectiveServers.contains(ssi)) {
continue;
}
if (CollectionUtils.isNotEmpty(excludeServers) && excludeServers.contains(ssi)) {
continue;
}
if (!serverToBlockIds.containsKey(ssi)) {
serverToBlockIds.put(ssi, Lists.newArrayList());
}
Expand All @@ -216,6 +245,18 @@ private void genServerToBlocks(ShuffleBlockInfo sbi, List<ShuffleServerInfo> ser
partitionToBlocks.put(partitionId, Lists.newArrayList());
}
partitionToBlocks.get(partitionId).add(sbi);
if (excludeServers != null) {
excludeServers.add(ssi);
}
assignedNum++;
if (assignedNum >= replicaNum) {
break;
}
}

if (assignedNum < replicaNum && excludeDefectiveServers) {
genServerToBlocks(sbi, serverList, replicaNum - assignedNum,
excludeServers, serverToBlocks, serverToBlockIds, false);
}
}

Expand Down Expand Up @@ -247,14 +288,15 @@ public SendShuffleDataResult sendShuffleData(String appId, List<ShuffleBlockInfo
for (ShuffleBlockInfo sbi : shuffleBlockInfoList) {
List<ShuffleServerInfo> allServers = sbi.getShuffleServerInfos();
if (replicaSkipEnabled) {
genServerToBlocks(sbi, allServers.subList(0, replicaWrite),
primaryServerToBlocks, primaryServerToBlockIds);
genServerToBlocks(sbi, allServers.subList(replicaWrite, replica),
secondaryServerToBlocks, secondaryServerToBlockIds);
List<ShuffleServerInfo> excludeServers = new ArrayList<>();
genServerToBlocks(sbi, allServers, replicaWrite, excludeServers,
primaryServerToBlocks, primaryServerToBlockIds, true);
genServerToBlocks(sbi, allServers,replica - replicaWrite,
excludeServers, secondaryServerToBlocks, secondaryServerToBlockIds, false);
} else {
// When replicaSkip is disabled, we send data to all replicas within one round.
genServerToBlocks(sbi, allServers,
primaryServerToBlocks, primaryServerToBlockIds);
genServerToBlocks(sbi, allServers, allServers.size(),
null, primaryServerToBlocks, primaryServerToBlockIds, false);
}
}

Expand Down Expand Up @@ -756,6 +798,11 @@ public ShuffleServerClient getShuffleServerClient(ShuffleServerInfo shuffleServe
return ShuffleServerClientFactory.getInstance().getShuffleServerClient(clientType, shuffleServerInfo);
}

@VisibleForTesting
Set<ShuffleServerInfo> getDefectiveServers() {
return defectiveServers;
}

void addShuffleServer(String appId, int shuffleId, ShuffleServerInfo serverInfo) {
Map<Integer, Set<ShuffleServerInfo>> appServerMap = shuffleServerInfoMap.get(appId);
if (appServerMap == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@

package org.apache.uniffle.client.impl;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.TimeUnit;

import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import org.awaitility.Awaitility;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;
Expand Down Expand Up @@ -104,4 +106,80 @@ public void testRegisterAndUnRegisterShuffleServer() {
shuffleWriteClient.unregisterShuffle(appId1, 1);
assertEquals(1, shuffleWriteClient.getAllShuffleServers(appId1).size());
}

@Test
public void testSendDataWithDefectiveServers() {
ShuffleWriteClientImpl shuffleWriteClient =
new ShuffleWriteClientImpl("GRPC", 3, 2000, 4, 3, 2, 2, true, 1, 1, 10, 10);
ShuffleServerClient mockShuffleServerClient = mock(ShuffleServerClient.class);
ShuffleWriteClientImpl spyClient = Mockito.spy(shuffleWriteClient);
doReturn(mockShuffleServerClient).when(spyClient).getShuffleServerClient(any());
when(mockShuffleServerClient.sendShuffleData(any())).thenReturn(
new RssSendShuffleDataResponse(ResponseStatusCode.NO_BUFFER),
new RssSendShuffleDataResponse(ResponseStatusCode.SUCCESS),
new RssSendShuffleDataResponse(ResponseStatusCode.SUCCESS));

String appId = "testSendDataWithDefectiveServers_appId";
ShuffleServerInfo ssi1 = new ShuffleServerInfo("127.0.0.1", 0);
ShuffleServerInfo ssi2 = new ShuffleServerInfo("127.0.0.1", 1);
ShuffleServerInfo ssi3 = new ShuffleServerInfo("127.0.0.1", 2);
List<ShuffleServerInfo> shuffleServerInfoList =
Lists.newArrayList(ssi1, ssi2, ssi3);
List<ShuffleBlockInfo> shuffleBlockInfoList = Lists.newArrayList(new ShuffleBlockInfo(
0, 0, 10, 10, 10, new byte[]{1}, shuffleServerInfoList, 10, 100, 0));
SendShuffleDataResult result = spyClient.sendShuffleData(appId, shuffleBlockInfoList, () -> false);
assertEquals(0, result.getFailedBlockIds().size());

// Send data for the second time, the first shuffle server will be moved to the last.
when(mockShuffleServerClient.sendShuffleData(any())).thenReturn(
new RssSendShuffleDataResponse(ResponseStatusCode.SUCCESS),
new RssSendShuffleDataResponse(ResponseStatusCode.SUCCESS));
List<ShuffleServerInfo> excludeServers = new ArrayList<>();
spyClient.genServerToBlocks(shuffleBlockInfoList.get(0), shuffleServerInfoList,
2, excludeServers, Maps.newHashMap(), Maps.newHashMap(), true);
assertEquals(2, excludeServers.size());
assertEquals(ssi2, excludeServers.get(0));
assertEquals(ssi3, excludeServers.get(1));
spyClient.genServerToBlocks(shuffleBlockInfoList.get(0), shuffleServerInfoList,
1, excludeServers, Maps.newHashMap(), Maps.newHashMap(), false);
assertEquals(3, excludeServers.size());
assertEquals(ssi1, excludeServers.get(2));
result = spyClient.sendShuffleData(appId, shuffleBlockInfoList, () -> false);
assertEquals(0, result.getFailedBlockIds().size());

// Send data for the third time, the first server will be removed from the defectiveServers
// and the second server will be added to the defectiveServers.
when(mockShuffleServerClient.sendShuffleData(any())).thenReturn(
new RssSendShuffleDataResponse(ResponseStatusCode.NO_BUFFER),
new RssSendShuffleDataResponse(ResponseStatusCode.SUCCESS),
new RssSendShuffleDataResponse(ResponseStatusCode.SUCCESS));
List<ShuffleServerInfo> shuffleServerInfoList2 = Lists.newArrayList(ssi2, ssi1, ssi3);
List<ShuffleBlockInfo> shuffleBlockInfoList2 = Lists.newArrayList(new ShuffleBlockInfo(0, 0, 10, 10, 10,
new byte[]{1}, shuffleServerInfoList2, 10, 100, 0));
result = spyClient.sendShuffleData(appId, shuffleBlockInfoList2, () -> false);
assertEquals(0, result.getFailedBlockIds().size());
assertEquals(1, spyClient.getDefectiveServers().size());
assertEquals(ssi2, spyClient.getDefectiveServers().toArray()[0]);
excludeServers = new ArrayList<>();
spyClient.genServerToBlocks(shuffleBlockInfoList.get(0), shuffleServerInfoList,
2, excludeServers, Maps.newHashMap(), Maps.newHashMap(), true);
assertEquals(2, excludeServers.size());
assertEquals(ssi1, excludeServers.get(0));
assertEquals(ssi3, excludeServers.get(1));
spyClient.genServerToBlocks(shuffleBlockInfoList.get(0), shuffleServerInfoList,
1, excludeServers, Maps.newHashMap(), Maps.newHashMap(), false);
assertEquals(3, excludeServers.size());
assertEquals(ssi2, excludeServers.get(2));

// Check whether it is normal when two shuffle servers in defectiveServers
spyClient.getDefectiveServers().add(ssi1);
assertEquals(2, spyClient.getDefectiveServers().size());
excludeServers = new ArrayList<>();
spyClient.genServerToBlocks(shuffleBlockInfoList.get(0), shuffleServerInfoList,
2, excludeServers, Maps.newHashMap(), Maps.newHashMap(), true);
assertEquals(2, excludeServers.size());
assertEquals(ssi3, excludeServers.get(0));
assertEquals(ssi1, excludeServers.get(1));
}

}

0 comments on commit ebbe2db

Please sign in to comment.