diff --git a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java index b993243016..9e64b2fd5a 100644 --- a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java +++ b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java @@ -247,7 +247,7 @@ public void write(Iterator> records) { private void writeImpl(Iterator> records) { List shuffleBlockInfos; - int recordCount = 0; + long recordCount = 0; while (records.hasNext()) { recordCount++; Product2 record = records.next(); @@ -266,7 +266,7 @@ private void writeImpl(Iterator> records) { shuffleBlockInfos = bufferManager.clear(); processShuffleBlockInfos(shuffleBlockInfos); long s = System.currentTimeMillis(); - assert recordCount == bufferManager.getRecordCount(); + checkSentRecordCount(recordCount); checkBlockSendResult(blockIds); final long checkDuration = System.currentTimeMillis() - s; long commitDuration = 0; @@ -294,6 +294,16 @@ private void writeImpl(Iterator> records) { + bufferManager.getManagerCostInfo()); } + private void checkSentRecordCount(long recordCount) { + if (recordCount != bufferManager.getRecordCount()) { + String errorMsg = + "Potential record loss may have occurred while preparing to send blocks for task[" + + taskId + + "]"; + throw new RssSendFailedException(errorMsg); + } + } + /** * ShuffleBlock will be added to queue and send to shuffle server maintenance the following * information: 1. add blockId to set, check if it is send later 2. update shuffle server info, diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java index 093fde95ce..2fc0340510 100644 --- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java +++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java @@ -262,7 +262,7 @@ private void writeImpl(Iterator> records) { if (isCombine) { createCombiner = shuffleDependency.aggregator().get().createCombiner(); } - int recordCount = 0; + long recordCount = 0; while (records.hasNext()) { recordCount++; // Task should fast fail when sending data failed @@ -287,7 +287,7 @@ private void writeImpl(Iterator> records) { processShuffleBlockInfos(shuffleBlockInfos); } long checkStartTs = System.currentTimeMillis(); - assert recordCount == bufferManager.getRecordCount(); + checkSentRecordCount(recordCount); checkBlockSendResult(blockIds); long commitStartTs = System.currentTimeMillis(); long checkDuration = commitStartTs - checkStartTs; @@ -313,6 +313,16 @@ private void writeImpl(Iterator> records) { + bufferManager.getManagerCostInfo()); } + private void checkSentRecordCount(long recordCount) { + if (recordCount != bufferManager.getRecordCount()) { + String errorMsg = + "Potential record loss may have occurred while preparing to send blocks for task[" + + taskId + + "]"; + throw new RssSendFailedException(errorMsg); + } + } + // only push-based shuffle use this interface, but rss won't be used when push-based shuffle is // enabled. public long[] getPartitionLengths() {