Skip to content

Commit

Permalink
[#808] improvement(spark): Verify the number of written records to en…
Browse files Browse the repository at this point in the history
…sure data correctness (#1558)

### What changes were proposed in this pull request?

Verify the number of written records to enhance data accuracy.
Make sure all data records are sent by clients.
Make sure bugs like #714 will never be introduced into the code.

### Why are the changes needed?

A follow-up PR for #848.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Existing UTs.
  • Loading branch information
rickyma authored Mar 7, 2024
1 parent b36d461 commit ec4251d
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ public class WriteBufferManager extends MemoryConsumer {
private AtomicLong usedBytes = new AtomicLong(0);
// bytes of shuffle data which is in send list
private AtomicLong inSendListBytes = new AtomicLong(0);
/** An atomic counter used to keep track of the number of records */
private AtomicLong recordCounter = new AtomicLong(0);
// it's part of blockId
private Map<Integer, Integer> partitionToSeqNo = Maps.newHashMap();
private long askExecutorMemory;
Expand Down Expand Up @@ -236,6 +238,7 @@ private List<ShuffleBlockInfo> insertIntoBuffer(
if (wb.getMemoryUsed() > bufferSize) {
List<ShuffleBlockInfo> sentBlocks = new ArrayList<>(1);
sentBlocks.add(createShuffleBlock(partitionId, wb));
recordCounter.addAndGet(wb.getRecordCount());
copyTime += wb.getCopyTime();
buffers.remove(partitionId);
if (LOG.isDebugEnabled()) {
Expand Down Expand Up @@ -298,6 +301,7 @@ public synchronized List<ShuffleBlockInfo> clear() {
dataSize += wb.getDataLength();
memoryUsed += wb.getMemoryUsed();
result.add(createShuffleBlock(entry.getKey(), wb));
recordCounter.addAndGet(wb.getRecordCount());
iterator.remove();
copyTime += wb.getCopyTime();
}
Expand Down Expand Up @@ -509,6 +513,10 @@ protected long getInSendListBytes() {
return inSendListBytes.get();
}

protected long getRecordCount() {
return recordCounter.get();
}

public void freeAllocatedMemory(long freeMemory) {
freeMemory(freeMemory);
allocatedBytes.addAndGet(-freeMemory);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ public class WriterBuffer {
private List<WrappedBuffer> buffers = Lists.newArrayList();
private int dataLength = 0;
private int memoryUsed = 0;
private long recordCount = 0;

public WriterBuffer(int bufferSize) {
this.bufferSize = bufferSize;
Expand Down Expand Up @@ -66,6 +67,7 @@ public void addRecord(byte[] recordBuffer, int length) {

nextOffset += length;
dataLength += length;
recordCount++;
}

public boolean askForMemory(long length) {
Expand Down Expand Up @@ -98,6 +100,10 @@ public int getMemoryUsed() {
return memoryUsed;
}

public long getRecordCount() {
return recordCount;
}

private static final class WrappedBuffer {

byte[] buffer;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,9 @@ public void write(Iterator<Product2<K, V>> records) {

private void writeImpl(Iterator<Product2<K, V>> records) {
List<ShuffleBlockInfo> shuffleBlockInfos;
long recordCount = 0;
while (records.hasNext()) {
recordCount++;
Product2<K, V> record = records.next();
int partition = getPartition(record._1());
if (shuffleDependency.mapSideCombine()) {
Expand All @@ -264,6 +266,7 @@ private void writeImpl(Iterator<Product2<K, V>> records) {
shuffleBlockInfos = bufferManager.clear();
processShuffleBlockInfos(shuffleBlockInfos);
long s = System.currentTimeMillis();
checkSentRecordCount(recordCount);
checkBlockSendResult(blockIds);
final long checkDuration = System.currentTimeMillis() - s;
long commitDuration = 0;
Expand Down Expand Up @@ -291,6 +294,16 @@ private void writeImpl(Iterator<Product2<K, V>> records) {
+ bufferManager.getManagerCostInfo());
}

private void checkSentRecordCount(long recordCount) {
if (recordCount != bufferManager.getRecordCount()) {
String errorMsg =
"Potential record loss may have occurred while preparing to send blocks for task["
+ taskId
+ "]";
throw new RssSendFailedException(errorMsg);
}
}

/**
* ShuffleBlock will be added to queue and send to shuffle server maintenance the following
* information: 1. add blockId to set, check if it is send later 2. update shuffle server info,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,9 @@ private void writeImpl(Iterator<Product2<K, V>> records) {
if (isCombine) {
createCombiner = shuffleDependency.aggregator().get().createCombiner();
}
long recordCount = 0;
while (records.hasNext()) {
recordCount++;
// Task should fast fail when sending data failed
checkIfBlocksFailed();

Expand All @@ -285,6 +287,7 @@ private void writeImpl(Iterator<Product2<K, V>> records) {
processShuffleBlockInfos(shuffleBlockInfos);
}
long checkStartTs = System.currentTimeMillis();
checkSentRecordCount(recordCount);
checkBlockSendResult(blockIds);
long commitStartTs = System.currentTimeMillis();
long checkDuration = commitStartTs - checkStartTs;
Expand All @@ -310,6 +313,16 @@ private void writeImpl(Iterator<Product2<K, V>> records) {
+ bufferManager.getManagerCostInfo());
}

private void checkSentRecordCount(long recordCount) {
if (recordCount != bufferManager.getRecordCount()) {
String errorMsg =
"Potential record loss may have occurred while preparing to send blocks for task["
+ taskId
+ "]";
throw new RssSendFailedException(errorMsg);
}
}

// only push-based shuffle use this interface, but rss won't be used when push-based shuffle is
// enabled.
public long[] getPartitionLengths() {
Expand Down

0 comments on commit ec4251d

Please sign in to comment.