[#1594] improvement(client): support generating larger block sizes during shuffle map tasks by spilling partial partition data #1670

Merged · 15 commits · May 14, 2024

Changes from all commits
RssSparkConfig.java
@@ -71,6 +71,12 @@ public class RssSparkConfig {
           .withDescription(
               "The memory spill switch triggered by Spark TaskMemoryManager, default value is false.");
 
+  public static final ConfigOption<Double> RSS_MEMORY_SPILL_RATIO =
+      ConfigOptions.key("rss.client.memory.spill.ratio")
+          .doubleType()
+          .defaultValue(1.0d)
+          .withDescription(
+              "The ratio of buffered data to spill when a spill is triggered by spark.rss.writer.buffer.spill.size");
   public static final ConfigOption<Integer> RSS_PARTITION_REASSIGN_MAX_REASSIGNMENT_SERVER_NUM =
       ConfigOptions.key("rss.client.reassign.maxReassignServerNum")
           .intType()
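For orientation, here is a minimal sketch of how a job could opt in to partial spills. The keys and the `spark.` prefix are taken from the spillPartial test further down; the wrapper class name is illustrative, not part of the PR:

```java
import org.apache.spark.SparkConf;

// Hypothetical helper showing the two client settings this PR cares about.
public class PartialSpillConf {
  public static SparkConf build() {
    SparkConf conf = new SparkConf();
    // Enable memory-pressure spills, but flush only about half of the buffered
    // data per spill instead of draining every partition buffer.
    conf.set("spark.rss.client.memory.spill.enabled", "true");
    conf.set("spark.rss.client.memory.spill.ratio", "0.5");
    return conf;
  }
}
```

Keeping the ratio below 1.0 is what lets hot partitions accumulate more data between flushes, which in turn produces the larger shuffle blocks this PR is after.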
BufferManagerOptions.java
@@ -35,6 +35,7 @@ public class BufferManagerOptions {
   private long preAllocatedBufferSize;
   private long requireMemoryInterval;
   private int requireMemoryRetryMax;
+  private double bufferSpillPercent;
 
   public BufferManagerOptions(SparkConf sparkConf) {
     bufferSize =
@@ -53,6 +54,10 @@ public BufferManagerOptions(SparkConf sparkConf) {
         sparkConf.getSizeAsBytes(
             RssSparkConfig.RSS_WRITER_BUFFER_SPILL_SIZE.key(),
             RssSparkConfig.RSS_WRITER_BUFFER_SPILL_SIZE.defaultValue().get());
+    bufferSpillPercent =
+        sparkConf.getDouble(
+            RssSparkConfig.RSS_MEMORY_SPILL_RATIO.key(),
+            RssSparkConfig.RSS_MEMORY_SPILL_RATIO.defaultValue());
     preAllocatedBufferSize =
         sparkConf.getSizeAsBytes(
             RssSparkConfig.RSS_WRITER_PRE_ALLOCATED_BUFFER_SIZE.key(),
@@ -119,6 +124,10 @@ public long getBufferSpillThreshold() {
     return bufferSpillThreshold;
   }
 
+  public double getBufferSpillPercent() {
+    return bufferSpillPercent;
+  }
+
   public long getRequireMemoryInterval() {
     return requireMemoryInterval;
   }
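A short sketch of how the option surfaces through BufferManagerOptions (assumed usage; imports from the client module are omitted, but the getter and the 1.0 fallback are exactly what the diff above adds):

```java
// BufferManagerOptions reads the ratio via sparkConf.getDouble() on the
// unprefixed ConfigOption key, falling back to the 1.0 default when unset.
SparkConf conf = new SparkConf();
conf.set(RssSparkConfig.RSS_MEMORY_SPILL_RATIO.key(), "0.5"); // "rss.client.memory.spill.ratio"
BufferManagerOptions options = new BufferManagerOptions(conf);
double ratio = options.getBufferSpillPercent(); // 0.5 here, 1.0 if the key is absent
```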
WriteBufferManager.java
@@ -19,10 +19,9 @@

 import java.util.ArrayList;
 import java.util.Collections;
-import java.util.Iterator;
+import java.util.Comparator;
 import java.util.List;
 import java.util.Map;
-import java.util.Map.Entry;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeoutException;
@@ -97,6 +96,7 @@ public class WriteBufferManager extends MemoryConsumer {
   private int memorySpillTimeoutSec;
   private boolean isRowBased;
   private BlockIdLayout blockIdLayout;
+  private double bufferSpillRatio;
   private Function<Integer, List<ShuffleServerInfo>> partitionAssignmentRetrieveFunc;
 
   public WriteBufferManager(
@@ -162,6 +162,7 @@ public WriteBufferManager(
     this.sendSizeLimit = rssConf.get(RssSparkConfig.RSS_CLIENT_SEND_SIZE_LIMITATION);
     this.memorySpillTimeoutSec = rssConf.get(RssSparkConfig.RSS_MEMORY_SPILL_TIMEOUT);
     this.memorySpillEnabled = rssConf.get(RssSparkConfig.RSS_MEMORY_SPILL_ENABLED);
+    this.bufferSpillRatio = rssConf.get(RssSparkConfig.RSS_MEMORY_SPILL_RATIO);
     this.blockIdLayout = BlockIdLayout.from(rssConf);
     this.partitionAssignmentRetrieveFunc = partitionAssignmentRetrieveFunc;
   }
@@ -204,13 +205,12 @@ public List<ShuffleBlockInfo> addPartitionData(
     // check buffer size > spill threshold
     if (usedBytes.get() - inSendListBytes.get() > spillSize) {
       LOG.info(
-          "ShuffleBufferManager spill for buffer size exceeding spill threshold,"
-              + "usedBytes[{}],inSendListBytes[{}],spillSize[{}]",
+          "ShuffleBufferManager spill for buffer size exceeding spill threshold, "
+              + "usedBytes[{}], inSendListBytes[{}], spill size threshold[{}]",
           usedBytes.get(),
           inSendListBytes.get(),
           spillSize);
-      List<ShuffleBlockInfo> multiSendingBlocks = clear();
-
+      List<ShuffleBlockInfo> multiSendingBlocks = clear(bufferSpillRatio);
       multiSendingBlocks.addAll(singleOrEmptySendingBlocks);
       writeTime += System.currentTimeMillis() - start;
       return multiSendingBlocks;
@@ -323,20 +323,34 @@ public List<ShuffleBlockInfo> addRecord(int partitionId, Object key, Object value) {
   }
 
   // transform all [partition, records] to [partition, ShuffleBlockInfo] and clear cache
-  public synchronized List<ShuffleBlockInfo> clear() {
+  public synchronized List<ShuffleBlockInfo> clear(double bufferSpillRatio) {
     List<ShuffleBlockInfo> result = Lists.newArrayList();
     long dataSize = 0;
     long memoryUsed = 0;
-    Iterator<Entry<Integer, WriterBuffer>> iterator = buffers.entrySet().iterator();
-    while (iterator.hasNext()) {
-      Entry<Integer, WriterBuffer> entry = iterator.next();
-      WriterBuffer wb = entry.getValue();
+    bufferSpillRatio = Math.max(0.1, Math.min(1.0, bufferSpillRatio));
+    List<Integer> partitionList = new ArrayList(buffers.keySet());
+    if (Double.compare(bufferSpillRatio, 1.0) < 0) {
+      partitionList.sort(
+          Comparator.comparingInt(o -> buffers.get(o) == null ? 0 : buffers.get(o).getMemoryUsed())
+              .reversed());
+    }
+    long targetSpillSize = (long) ((usedBytes.get() - inSendListBytes.get()) * bufferSpillRatio);
+    for (int partitionId : partitionList) {
+      WriterBuffer wb = buffers.get(partitionId);
+      if (wb == null) {
+        LOG.warn("Failed to get the partition buffer, this should not happen!");
+        continue;
+      }
       dataSize += wb.getDataLength();
       memoryUsed += wb.getMemoryUsed();
-      result.add(createShuffleBlock(entry.getKey(), wb));
+      result.add(createShuffleBlock(partitionId, wb));
       recordCounter.addAndGet(wb.getRecordCount());
-      iterator.remove();
       copyTime += wb.getCopyTime();
+      buffers.remove(partitionId);
+      // got enough buffer to spill
+      if (memoryUsed >= targetSpillSize) {
+        break;
+      }
     }
     LOG.info(
         "Flush total buffer for shuffleId["
@@ -349,6 +363,8 @@ public synchronized List<ShuffleBlockInfo> clear() {
             + memoryUsed
             + "], number of blocks["
             + result.size()
+            + "], flush ratio["
+            + bufferSpillRatio
             + "]");
     return result;
   }
@@ -491,7 +507,7 @@ public long spill(long size, MemoryConsumer trigger) {
       return 0L;
     }
 
-    List<CompletableFuture<Long>> futures = spillFunc.apply(clear());
+    List<CompletableFuture<Long>> futures = spillFunc.apply(clear(bufferSpillRatio));
     CompletableFuture<Void> allOfFutures =
         CompletableFuture.allOf(futures.toArray(new CompletableFuture[futures.size()]));
     try {
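The selection policy in clear() is compact enough to isolate. Below is a self-contained sketch of the same idea with plain byte counts standing in for WriterBuffer instances; the class and variable names are illustrative only:

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class PartialSpillSketch {
  public static void main(String[] args) {
    // partitionId -> buffered bytes (stand-in for Map<Integer, WriterBuffer>)
    Map<Integer, Long> buffers = new HashMap<>();
    buffers.put(0, 32L);
    buffers.put(1, 512L);
    buffers.put(2, 128L);

    double ratio = Math.max(0.1, Math.min(1.0, 0.5)); // clamped, as clear() does
    long total = buffers.values().stream().mapToLong(Long::longValue).sum();
    long target = (long) (total * ratio);

    // Largest partitions first: fewer, bigger blocks are flushed, and small
    // partitions keep accumulating data toward a later, larger block.
    List<Integer> order = new ArrayList<>(buffers.keySet());
    order.sort(Comparator.<Integer>comparingLong(buffers::get).reversed());

    long spilled = 0;
    for (int partitionId : order) {
      spilled += buffers.remove(partitionId);
      if (spilled >= target) {
        break; // got enough buffer to spill
      }
    }
    // Here: spills partition 1 (512 of 672 bytes); partitions 0 and 2 stay buffered.
    System.out.println("spilled " + spilled + " of " + total + " bytes");
  }
}
```

Note that the sort runs only when the ratio is strictly below 1.0, so the default configuration keeps the old behavior of draining every buffer without the extra ordering cost.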
WriteBufferManagerTest.java
@@ -185,7 +185,7 @@ private void addRecord(boolean compress, BlockIdLayout layout) throws IllegalAccessException {
     wbm.addRecord(0, testKey, testValue);
     wbm.addRecord(1, testKey, testValue);
     wbm.addRecord(2, testKey, testValue);
-    result = wbm.clear();
+    result = wbm.clear(1.0);
     assertEquals(3, result.size());
     assertEquals(224, wbm.getAllocatedBytes());
     assertEquals(96, wbm.getUsedBytes());
@@ -433,6 +433,56 @@ public void spillByOwnTest() {
     Awaitility.await().timeout(5, TimeUnit.SECONDS).until(() -> spyManager.getUsedBytes() == 0);
   }
 
+  @Test
+  public void spillPartial() {
+    SparkConf conf = getConf();
+    conf.set("spark.rss.client.send.size.limit", "1000");
+    conf.set("spark.rss.client.memory.spill.ratio", "0.5");
+    conf.set("spark.rss.client.memory.spill.enabled", "true");
+    TaskMemoryManager mockTaskMemoryManager = mock(TaskMemoryManager.class);
+    BufferManagerOptions bufferOptions = new BufferManagerOptions(conf);
+
+    WriteBufferManager wbm =
+        new WriteBufferManager(
+            0,
+            "taskId_spillPartialTest",
+            0,
+            bufferOptions,
+            new KryoSerializer(conf),
+            Maps.newHashMap(),
+            mockTaskMemoryManager,
+            new ShuffleWriteMetrics(),
+            RssSparkConfig.toRssConf(conf),
+            null);
+
+    Function<List<ShuffleBlockInfo>, List<CompletableFuture<Long>>> spillFunc =
+        blocks -> {
+          long sum = 0L;
+          List<AddBlockEvent> events = wbm.buildBlockEvents(blocks);
+          for (AddBlockEvent event : events) {
+            event.getProcessedCallbackChain().stream().forEach(x -> x.run());
+            sum += event.getShuffleDataInfoList().stream().mapToLong(x -> x.getFreeMemory()).sum();
+          }
+          return Arrays.asList(CompletableFuture.completedFuture(sum));
+        };
+    wbm.setSpillFunc(spillFunc);
+
+    when(wbm.acquireMemory(512)).thenReturn(512L);
+
+    String testKey = "Key";
+    String testValue = "Value";
+    wbm.addRecord(0, testKey, testValue);
+    wbm.addRecord(1, testKey, testValue);
+    wbm.addRecord(1, testKey, testValue);
+    wbm.addRecord(1, testKey, testValue);
+    wbm.addRecord(1, testKey, testValue);
+
+    long releasedSize = wbm.spill(1000, wbm);
+    assertEquals(64, releasedSize);
+    assertEquals(96, wbm.getUsedBytes());
+    assertEquals(0, wbm.getBuffers().keySet().toArray()[0]);
+  }
+
   public static class FakedTaskMemoryManager extends TaskMemoryManager {
     private static final Logger LOGGER = LoggerFactory.getLogger(FakedTaskMemoryManager.class);
     private int invokedCnt = 0;
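The assertions encode the intended behavior: partition 1 holds four of the five records, so with a ratio of 0.5 it is sorted first and flushed on its own, while partition 0 survives the spill, which is why the remaining buffer key is 0 and the used bytes stay non-zero.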
RssShuffleWriter.java (Spark 2 client)
@@ -264,7 +264,7 @@ private void writeImpl(Iterator<Product2<K, V>> records) {
     }
 
     final long start = System.currentTimeMillis();
-    shuffleBlockInfos = bufferManager.clear();
+    shuffleBlockInfos = bufferManager.clear(1.0);
     processShuffleBlockInfos(shuffleBlockInfos);
     long s = System.currentTimeMillis();
     checkSentRecordCount(recordCount);
RssShuffleWriter.java (Spark 3 client)
@@ -307,7 +307,7 @@ private void writeImpl(Iterator<Product2<K, V>> records) {
       }
     }
     final long start = System.currentTimeMillis();
-    shuffleBlockInfos = bufferManager.clear();
+    shuffleBlockInfos = bufferManager.clear(1.0);
     if (shuffleBlockInfos != null && !shuffleBlockInfos.isEmpty()) {
       processShuffleBlockInfos(shuffleBlockInfos);
     }
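Both RssShuffleWriter call sites pass 1.0 explicitly, so the end-of-task flush keeps its original semantics: every remaining partition buffer is drained, and the configured ratio only shapes memory-pressure spills taken mid-task.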