diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java
index 353006cc75..ca03a784bb 100644
--- a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java
@@ -37,8 +37,13 @@
 import org.apache.uniffle.client.api.ShuffleWriteClient;
 import org.apache.uniffle.client.response.SendShuffleDataResult;
 import org.apache.uniffle.common.ShuffleBlockInfo;
+import org.apache.uniffle.common.exception.RssException;
 import org.apache.uniffle.common.util.ThreadUtils;
 
+/**
+ * A {@link DataPusher} that is responsible for sending data to remote
+ * shuffle servers asynchronously.
+ */
 public class DataPusher implements Closeable {
   private static final Logger LOGGER = LoggerFactory.getLogger(DataPusher.class);
 
@@ -49,7 +54,7 @@ public class DataPusher implements Closeable {
   private final Map<String, Set<Long>> taskToSuccessBlockIds;
   // Must be thread safe
   private final Map<String, Set<Long>> taskToFailedBlockIds;
-  private String appId;
+  private String rssAppId;
   // Must be thread safe
   private final Set<String> failedTaskIds;
@@ -74,13 +79,15 @@ public DataPusher(ShuffleWriteClient shuffleWriteClient,
   }
 
   public CompletableFuture<Long> send(AddBlockEvent event) {
-    assert appId != null;
+    if (rssAppId == null) {
+      throw new RssException("RssAppId should be set.");
+    }
     return CompletableFuture.supplyAsync(() -> {
       String taskId = event.getTaskId();
       List<ShuffleBlockInfo> shuffleBlockInfoList = event.getShuffleDataInfoList();
       try {
         SendShuffleDataResult result = shuffleWriteClient.sendShuffleData(
-            appId,
+            rssAppId,
             shuffleBlockInfoList,
             () -> !isValidTask(taskId)
         );
@@ -113,8 +120,8 @@ public boolean isValidTask(String taskId) {
     return !failedTaskIds.contains(taskId);
   }
 
-  public void setAppId(String appId) {
-    this.appId = appId;
+  public void setRssAppId(String rssAppId) {
+    this.rssAppId = rssAppId;
   }
 
   @Override
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
index 497832c0df..580fce699d 100644
--- a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
@@ -299,6 +299,7 @@ public List<AddBlockEvent> buildBlockEvents(List<ShuffleBlockInfo> shuffleBlockI
       if (totalSize > sendSizeLimit) {
         LOG.info("Build event with " + shuffleBlockInfosPerEvent.size()
             + " blocks and " + totalSize + " bytes");
+        // Use final temporary variables for closures
         final long _memoryUsed = memoryUsed;
         events.add(
             new AddBlockEvent(taskId, shuffleBlockInfosPerEvent, () -> freeAllocatedMemory(_memoryUsed))
@@ -311,6 +312,7 @@ public List<AddBlockEvent> buildBlockEvents(List<ShuffleBlockInfo> shuffleBlockI
     if (!shuffleBlockInfosPerEvent.isEmpty()) {
       LOG.info("Build event with " + shuffleBlockInfosPerEvent.size()
           + " blocks and " + totalSize + " bytes");
+      // Use final temporary variables for closures
      final long _memoryUsed = memoryUsed;
       events.add(
           new AddBlockEvent(taskId, shuffleBlockInfosPerEvent, () -> freeAllocatedMemory(_memoryUsed))
@@ -328,7 +330,8 @@ public long spill(long size, MemoryConsumer trigger) {
     try {
       allOfFutures.get(memorySpillTimeoutSec, TimeUnit.SECONDS);
     } catch (TimeoutException timeoutException) {
-      // ignore this.
+      // A best effort strategy to wait.
+      // If a timeout occurs, the underlying tasks won't be cancelled.
     } finally {
       long releasedSize = futures.stream().filter(x -> x.isDone()).mapToLong(x -> {
         try {
diff --git a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/DataPusherTest.java b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/DataPusherTest.java
index d0cf759dd7..20711dc053 100644
--- a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/DataPusherTest.java
+++ b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/DataPusherTest.java
@@ -94,7 +94,7 @@ public void testSendData() throws ExecutionException, InterruptedException {
         1,
         2
     );
-    dataPusher.setAppId("testSendData_appId");
+    dataPusher.setRssAppId("testSendData_appId");
 
     // sync send
     AddBlockEvent event = new AddBlockEvent("taskId", Arrays.asList(
diff --git a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java
index 1e9bc07b04..f7c6de3d74 100644
--- a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java
+++ b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java
@@ -249,10 +249,7 @@ public void spillTest() {
 
     // case1. all events are flushed within normal time.
     long releasedSize = spyManager.spill(1000, mock(WriteBufferManager.class));
-    assertEquals(
-        64,
-        releasedSize
-    );
+    assertEquals(64, releasedSize);
 
     // case2. partial events are not flushed within normal time.
     // when calling spill func, 2 events will be spilled.
@@ -272,10 +269,7 @@ public void spillTest() {
       return event.getShuffleDataInfoList().stream().mapToLong(x -> x.getFreeMemory()).sum();
     }));
     releasedSize = spyManager.spill(1000, mock(WriteBufferManager.class));
-    assertEquals(
-        32,
-        releasedSize
-    );
+    assertEquals(32, releasedSize);
     assertEquals(32, spyManager.getUsedBytes());
     Awaitility.await().timeout(3, TimeUnit.SECONDS).until(() -> spyManager.getUsedBytes() == 0);
   }
diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 1e8c8387cb..472240512b 100644
--- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -265,7 +265,7 @@ public <K, V, C> ShuffleHandle registerShuffle(int shuffleId, ShuffleDependency<
 
     if (id.get() == null) {
      id.compareAndSet(null, SparkEnv.get().conf().getAppId() + "_" + uuid);
-      dataPusher.setAppId(id.get());
+      dataPusher.setRssAppId(id.get());
     }
     LOG.info("Generate application id used in rss: " + id.get());
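
Note on the `assert` → `RssException` change in `DataPusher#send`: Java assertions are only evaluated when the JVM runs with `-ea`, so in a typical production deployment `assert appId != null` never fires and a missing app id would surface later as an opaque `NullPointerException`. An explicit unchecked exception enforces the precondition unconditionally and with a clear message. A minimal sketch of the difference, with hypothetical class and method names, and `IllegalStateException` standing in for Uniffle's `RssException`:

```java
public class PreconditionDemo {
  private String rssAppId; // set via a setter before the first send

  void sendWithAssert() {
    // Only checked when the JVM runs with -ea; silently skipped otherwise.
    assert rssAppId != null;
  }

  void sendWithException() {
    // Always checked, regardless of JVM flags, and fails with a clear message.
    if (rssAppId == null) {
      throw new IllegalStateException("rssAppId should be set.");
    }
  }
}
```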
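
The `// Use final temporary variables for closures` comments in `buildBlockEvents` reflect a Java language rule: a lambda may only capture local variables that are final or effectively final, and the capture happens by value when the lambda is created. Because `memoryUsed` keeps changing across iterations, copying it into a final local (`_memoryUsed`) pins the exact amount each `AddBlockEvent` callback should later free. A self-contained sketch of the pattern (all names hypothetical):

```java
import java.util.ArrayList;
import java.util.List;

public class ClosureCaptureDemo {
  public static void main(String[] args) {
    List<Runnable> freeCallbacks = new ArrayList<>();
    long memoryUsed = 0;
    for (int i = 0; i < 3; i++) {
      memoryUsed += 32;
      // memoryUsed is reassigned, so it is not effectively final and cannot
      // be captured by the lambda directly; snapshot it into a final local.
      final long _memoryUsed = memoryUsed;
      freeCallbacks.add(() -> System.out.println("free " + _memoryUsed + " bytes"));
    }
    // Each callback prints the value captured at creation time: 32, 64, 96.
    freeCallbacks.forEach(Runnable::run);
  }
}
```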
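
The new comment in `spill` documents a deliberate design choice: the spill path waits up to `memorySpillTimeoutSec` on `CompletableFuture.allOf(...)`, but a timeout is not treated as a failure. The in-flight sends keep running and release their memory whenever they complete, while the `finally` block counts only the futures that finished in time. A sketch of that pattern, not the actual `WriteBufferManager` code (method and variable names are hypothetical):

```java
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

public class BestEffortWaitDemo {
  // Wait up to timeoutSec for all futures, never cancelling them; return the
  // total released bytes reported by the futures that completed in time.
  static long waitAndSum(List<CompletableFuture<Long>> futures, long timeoutSec) {
    CompletableFuture<Void> all =
        CompletableFuture.allOf(futures.toArray(new CompletableFuture[0]));
    try {
      all.get(timeoutSec, TimeUnit.SECONDS);
    } catch (TimeoutException e) {
      // Best effort: slow futures keep running and free memory later.
    } catch (InterruptedException e) {
      Thread.currentThread().interrupt();
    } catch (ExecutionException e) {
      // A failed send is counted as releasing nothing below.
    }
    return futures.stream()
        .filter(CompletableFuture::isDone)
        .mapToLong(f -> {
          try {
            return f.get();
          } catch (Exception e) {
            return 0L;
          }
        })
        .sum();
  }
}
```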