diff --git a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index fdcc167b53..3c0d739edc 100644
--- a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -24,9 +24,11 @@
 import java.util.Set;
 import java.util.concurrent.BlockingQueue;
 import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.LinkedBlockingDeque;
 import java.util.concurrent.ScheduledExecutorService;
 import java.util.concurrent.TimeUnit;
 
+import com.google.common.collect.Maps;
 import scala.Option;
 import scala.Tuple2;
 import scala.collection.Iterator;
@@ -658,6 +660,13 @@ public void addFailedBlockIds(String taskId, Set<Long> blockIds) {
     taskToFailedBlockIds.get(taskId).addAll(blockIds);
   }
 
+  @VisibleForTesting
+  public void addTaskToFailedBlockIdsAndServer(String taskId, Long blockId, ShuffleServerInfo shuffleServerInfo) {
+    taskToFailedBlockIdsAndServer.putIfAbsent(taskId, Maps.newHashMap());
+    taskToFailedBlockIdsAndServer.get(taskId).putIfAbsent(blockId, new LinkedBlockingDeque<>());
+    taskToFailedBlockIdsAndServer.get(taskId).get(blockId).add(shuffleServerInfo);
+  }
+
   @VisibleForTesting
   public void addSuccessBlockIds(String taskId, Set<Long> blockIds) {
     if (taskToSuccessBlockIds.get(taskId) == null) {
diff --git a/client-spark/spark2/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java b/client-spark/spark2/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
index 13fb93f7a6..8b150f9ddc 100644
--- a/client-spark/spark2/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
+++ b/client-spark/spark2/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
@@ -146,6 +146,8 @@ public void checkBlockSendResultTest() {
     // case 3: partial blocks are sent failed, Runtime exception will be thrown
     manager.addSuccessBlockIds(taskId, Sets.newHashSet(1L, 2L));
     manager.addFailedBlockIds(taskId, Sets.newHashSet(3L));
+    ShuffleServerInfo shuffleServerInfo = new ShuffleServerInfo("127.0.0.1", 20001);
+    manager.addTaskToFailedBlockIdsAndServer(taskId, 3L, shuffleServerInfo);
     Throwable e3 =
         assertThrows(
             RuntimeException.class,
diff --git a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
index 206ceb33f9..9944082814 100644
--- a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
+++ b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
@@ -26,6 +26,7 @@
 import java.util.Set;
 import java.util.concurrent.BlockingQueue;
 import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.LinkedBlockingQueue;
 import java.util.function.Function;
 import java.util.stream.Collectors;
 
@@ -85,10 +86,11 @@ public void checkBlockSendResultTest() {
         .set(RssSparkConfig.RSS_COORDINATOR_QUORUM.key(), "127.0.0.1:12345,127.0.0.1:12346");
     Map<String, Set<Long>> failBlocks = JavaUtils.newConcurrentMap();
     Map<String, Set<Long>> successBlocks = JavaUtils.newConcurrentMap();
+    Map<String, Map<Long, BlockingQueue<ShuffleServerInfo>>> taskToFailedBlockIdsAndServer = JavaUtils.newConcurrentMap();
     Serializer kryoSerializer = new KryoSerializer(conf);
     RssShuffleManager manager =
         TestUtils.createShuffleManager(
-            conf, false, null, successBlocks, failBlocks, JavaUtils.newConcurrentMap());
+            conf, false, null, successBlocks, failBlocks, taskToFailedBlockIdsAndServer);
 
     ShuffleWriteClient mockShuffleWriteClient = mock(ShuffleWriteClient.class);
     Partitioner mockPartitioner = mock(Partitioner.class);
@@ -149,6 +151,12 @@ public void checkBlockSendResultTest() {
     // case 3: partial blocks are sent failed, Runtime exception will be thrown
     successBlocks.put("taskId", Sets.newHashSet(1L, 2L));
     failBlocks.put("taskId", Sets.newHashSet(3L));
+    Map<Long, BlockingQueue<ShuffleServerInfo>> blockIdToShuffleServerInfoMap = JavaUtils.newConcurrentMap();
+    BlockingQueue<ShuffleServerInfo> blockingQueue = new LinkedBlockingQueue<>();
+    ShuffleServerInfo shuffleServerInfo = new ShuffleServerInfo("127.0.0.1", 20001);
+    blockingQueue.add(shuffleServerInfo);
+    blockIdToShuffleServerInfoMap.put(3L, blockingQueue);
+    taskToFailedBlockIdsAndServer.put("taskId", blockIdToShuffleServerInfoMap);
     Throwable e3 =
         assertThrows(
             RuntimeException.class,
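
Note (not part of the patch): the sketch below is a minimal, self-contained approximation of the structure the new test helper populates, taskId -> failed blockId -> queue of servers that block failed on. It mirrors names from the diff (addTaskToFailedBlockIdsAndServer, taskToFailedBlockIdsAndServer) but uses only JDK types; the String stand-in for ShuffleServerInfo and the class name are illustrative assumptions, not the real RssShuffleManager code.

import java.util.Map;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.LinkedBlockingQueue;

// Illustrative stand-in; the actual patch implements this inside RssShuffleManager.
public class FailedBlockTrackingSketch {

  // taskId -> failed blockId -> servers the block failed on
  private final Map<String, Map<Long, BlockingQueue<String>>> taskToFailedBlockIdsAndServer =
      new ConcurrentHashMap<>();

  // Same shape as the spark2 helper added in the diff, with a String in place of ShuffleServerInfo.
  public void addTaskToFailedBlockIdsAndServer(String taskId, Long blockId, String server) {
    taskToFailedBlockIdsAndServer
        .computeIfAbsent(taskId, k -> new ConcurrentHashMap<>())
        .computeIfAbsent(blockId, k -> new LinkedBlockingQueue<>())
        .add(server);
  }

  public static void main(String[] args) {
    FailedBlockTrackingSketch sketch = new FailedBlockTrackingSketch();
    // Same scenario as the spark2/spark3 tests: block 3 of "taskId" failed on 127.0.0.1:20001.
    sketch.addTaskToFailedBlockIdsAndServer("taskId", 3L, "127.0.0.1:20001");
    System.out.println(sketch.taskToFailedBlockIdsAndServer); // {taskId={3=[127.0.0.1:20001]}}
  }
}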