Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[#731] feat(spark): Make blockid layout configurable for Spark clients #1528

Merged
merged 29 commits into from
Mar 7, 2024
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
b76ee2d
Move blockId field bit lengths from Constants into BlockIdLayout
EnricoMi Feb 12, 2024
7179361
Create BlockIdLayout from config, use everywhere
EnricoMi Feb 12, 2024
9b5cd42
Use 'Bits' rather than 'Length'
EnricoMi Feb 20, 2024
4c8f5f2
Add block id config to docs, add spark specific considerations
EnricoMi Feb 20, 2024
dac4d97
Remove map config type, simplify block id layout config
EnricoMi Feb 20, 2024
e984f58
Fix after rebase
EnricoMi Feb 20, 2024
fc5820c
Use layout to log block id in ShuffleReadClientImpl
EnricoMi Feb 23, 2024
7c137d4
Add tests
EnricoMi Feb 26, 2024
757ef79
Make DefaultIdHelper not fall back to default BlockIdLayout
EnricoMi Feb 28, 2024
5ba16c2
Remove BlockIdLayout from ShuffleReadClient builder
EnricoMi Feb 29, 2024
59f077a
Fix test after rebase
EnricoMi Feb 28, 2024
0a55688
Remove leftovers, have tez sort WriteBufferManager use layout from conf
EnricoMi Feb 29, 2024
7e3c86a
Fix code style
EnricoMi Feb 29, 2024
c1c2fbb
Move block id bits docs from general client to spark guide
EnricoMi Feb 29, 2024
81a33a3
Fix ShuffleReadClientImplTest
EnricoMi Feb 29, 2024
6cf1458
Rename getTaskAttemptId to getTaskAttemptIdForBlockId, fix compile error
EnricoMi Mar 1, 2024
d360754
Fix code styles
EnricoMi Mar 1, 2024
6583ce4
Better use of parametrized tests
EnricoMi Mar 1, 2024
3703c4f
Test coordinator dynamic client conf
EnricoMi Mar 1, 2024
6700f36
Log block id where partition used to be logged
EnricoMi Mar 4, 2024
ed7b3b2
Remove default int values, use DEFAULT instance instead
EnricoMi Mar 4, 2024
a419387
Handle and test BlockIdLayout not given in getShuffleResult* GRPC req…
EnricoMi Mar 4, 2024
f3bae27
Test RssShuffleDataIterator with various layouts
EnricoMi Mar 4, 2024
24b153b
Revert removed empty line
EnricoMi Mar 4, 2024
f1fc15d
Add more details on carefully configuring block id layout for Spark
EnricoMi Mar 4, 2024
2ab119a
Fail fast when registering shuffle with unsupported number of partitions
EnricoMi Mar 5, 2024
fe3156b
Add more errors to spark client guide blockid section
EnricoMi Mar 5, 2024
727129e
Move fail-fast test to Spark3
EnricoMi Mar 5, 2024
bf0902d
Remove default int values (complements ed7b3b2f)
EnricoMi Mar 6, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -37,23 +37,24 @@
import org.apache.uniffle.client.factory.ShuffleClientFactory;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.util.BlockId;
import org.apache.uniffle.common.util.BlockIdLayout;
import org.apache.uniffle.common.util.Constants;

public class RssMRUtils {

private static final Logger LOG = LoggerFactory.getLogger(RssMRUtils.class);
private static final BlockIdLayout LAYOUT = BlockIdLayout.DEFAULT;
private static final int MAX_ATTEMPT_LENGTH = 6;
private static final int MAX_ATTEMPT_ID = (1 << MAX_ATTEMPT_LENGTH) - 1;
private static final int MAX_SEQUENCE_NO =
(1 << (Constants.ATOMIC_INT_MAX_LENGTH - MAX_ATTEMPT_LENGTH)) - 1;
(1 << (LAYOUT.sequenceNoBits - MAX_ATTEMPT_LENGTH)) - 1;

// The class TaskAttemptId has two fields, id and mapId. The rss taskAttemptID has 21 bits:
// mapId uses 19 bits and id uses 2 bits. MR has a quirk: the taskAttemptId is increased by
// 1000 * (appAttemptId - 1), so we subtract that offset again.
public static long convertTaskAttemptIdToLong(TaskAttemptID taskAttemptID, int appAttemptId) {
int lowBytes = taskAttemptID.getTaskID().getId();
if (lowBytes > Constants.MAX_TASK_ATTEMPT_ID) {
if (lowBytes > LAYOUT.maxTaskAttemptId) {
throw new RssException("TaskAttempt " + taskAttemptID + " low bytes " + lowBytes + " exceed");
}
if (appAttemptId < 1) {
Expand All @@ -64,16 +65,16 @@ public static long convertTaskAttemptIdToLong(TaskAttemptID taskAttemptID, int a
throw new RssException(
"TaskAttempt " + taskAttemptID + " high bytes " + highBytes + " exceed");
}
return BlockId.getBlockId(highBytes, 0, lowBytes);
return LAYOUT.getBlockId(highBytes, 0, lowBytes);
}

/**
 * Builds a Hadoop MR {@code TaskAttemptID} from an rss task attempt id.
 *
 * <p>The task id is read from the task-attempt-id field of {@code rssTaskAttemptId}, and the
 * attempt id from its sequence-number field plus {@code 1000 * (appAttemptId - 1)}.
 *
 * @param jobID job the resulting task id belongs to
 * @param taskType task type of the resulting task id
 * @param rssTaskAttemptId packed id whose fields are decoded
 * @param appAttemptId application attempt number, must be at least 1
 * @return the reconstructed MR task attempt id
 * @throws RssException if {@code appAttemptId} is less than 1
 */
public static TaskAttemptID createMRTaskAttemptId(
JobID jobID, TaskType taskType, long rssTaskAttemptId, int appAttemptId) {
if (appAttemptId < 1) {
throw new RssException("appAttemptId " + appAttemptId + " is wrong");
}
// NOTE(review): this is a diff view without +/- markers. The two BlockId.* lines presumably
// are the removed version and the two LAYOUT.* lines their replacement -- confirm against the
// pull request before treating this text as compilable source.
TaskID taskID = new TaskID(jobID, taskType, BlockId.getTaskAttemptId(rssTaskAttemptId));
int id = BlockId.getSequenceNo(rssTaskAttemptId) + 1000 * (appAttemptId - 1);
TaskID taskID = new TaskID(jobID, taskType, LAYOUT.getTaskAttemptId(rssTaskAttemptId));
int id = LAYOUT.getSequenceNo(rssTaskAttemptId) + 1000 * (appAttemptId - 1);
return new TaskAttemptID(taskID, id);
}

Expand Down Expand Up @@ -227,8 +228,7 @@ public static String getString(Configuration rssJobConf, String key, String defa
}

public static long getBlockId(int partitionId, long taskAttemptId, int nextSeqNo) {
long attemptId =
taskAttemptId >> (Constants.PARTITION_ID_MAX_LENGTH + Constants.TASK_ATTEMPT_ID_MAX_LENGTH);
long attemptId = taskAttemptId >> (LAYOUT.partitionIdBits + LAYOUT.taskAttemptIdBits);
if (attemptId < 0 || attemptId > MAX_ATTEMPT_ID) {
throw new RssException(
"Can't support attemptId [" + attemptId + "], the max value should be " + MAX_ATTEMPT_ID);
Expand All @@ -240,17 +240,15 @@ public static long getBlockId(int partitionId, long taskAttemptId, int nextSeqNo

int atomicInt = (int) ((nextSeqNo << MAX_ATTEMPT_LENGTH) + attemptId);
long taskId =
taskAttemptId
- (attemptId
<< (Constants.PARTITION_ID_MAX_LENGTH + Constants.TASK_ATTEMPT_ID_MAX_LENGTH));
taskAttemptId - (attemptId << (LAYOUT.partitionIdBits + LAYOUT.taskAttemptIdBits));

return BlockId.getBlockId(atomicInt, partitionId, taskId);
return LAYOUT.getBlockId(atomicInt, partitionId, taskId);
}

/**
 * Derives a packed task attempt id from a block id.
 *
 * <p>Takes the task-attempt-id field of {@code blockId} and the low {@code MAX_ATTEMPT_LENGTH}
 * bits of its sequence-number field, then packs them again with partition id 0.
 *
 * @param blockId block id to decode
 * @return id packed from (masked sequence number, 0, task-attempt-id field)
 */
public static long getTaskAttemptId(long blockId) {
// NOTE(review): this is a diff view without +/- markers. The three BlockId.* lines presumably
// are the removed version and the three LAYOUT.* lines their replacement -- confirm against
// the pull request.
int mapId = BlockId.getTaskAttemptId(blockId);
int attemptId = BlockId.getSequenceNo(blockId) & MAX_ATTEMPT_ID;
return BlockId.getBlockId(attemptId, 0, mapId);
int mapId = LAYOUT.getTaskAttemptId(blockId);
// MAX_ATTEMPT_ID is (1 << MAX_ATTEMPT_LENGTH) - 1, so this keeps only the low bits.
int attemptId = LAYOUT.getSequenceNo(blockId) & MAX_ATTEMPT_ID;
return LAYOUT.getBlockId(attemptId, 0, mapId);
}

public static int estimateTaskConcurrency(JobConf jobConf) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

import org.apache.uniffle.client.util.RssClientConfig;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.util.BlockIdLayout;
import org.apache.uniffle.common.util.Constants;
import org.apache.uniffle.storage.util.StorageType;

Expand Down Expand Up @@ -80,16 +81,16 @@ public void blockConvertTest() {

/**
 * Checks that the partition id can be recovered from block ids created by
 * {@code RssMRUtils.getBlockId}, for partition ids 0..3000 and sequence numbers 0..10.
 */
@Test
public void partitionIdConvertBlockTest() {
BlockIdLayout layout = BlockIdLayout.DEFAULT;
JobID jobID = new JobID();
TaskID taskId = new TaskID(jobID, TaskType.MAP, 233);
TaskAttemptID taskAttemptID = new TaskAttemptID(taskId, 1);
long taskAttemptId = RssMRUtils.convertTaskAttemptIdToLong(taskAttemptID, 1);
// NOTE(review): diff view without +/- markers. The Constants.* lines below presumably are the
// removed version and the layout.* lines their replacement -- confirm against the pull request.
long mask = (1L << Constants.PARTITION_ID_MAX_LENGTH) - 1;
// Mask with one bit set for every bit of the partition id field.
long mask = (1L << layout.partitionIdBits) - 1;
for (int partitionId = 0; partitionId <= 3000; partitionId++) {
for (int seqNo = 0; seqNo <= 10; seqNo++) {
long blockId = RssMRUtils.getBlockId(partitionId, taskAttemptId, seqNo);
int newPartitionId =
Math.toIntExact((blockId >> Constants.TASK_ATTEMPT_ID_MAX_LENGTH) & mask);
// Shift the task attempt id bits out, then mask to keep only the partition id bits.
int newPartitionId = Math.toIntExact((blockId >> layout.taskAttemptIdBits) & mask);
assertEquals(partitionId, newPartitionId);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
import org.apache.uniffle.common.compression.Codec;
import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.util.BlockId;
import org.apache.uniffle.common.util.BlockIdLayout;
import org.apache.uniffle.common.util.ChecksumUtils;

public class WriteBufferManager extends MemoryConsumer {
Expand Down Expand Up @@ -95,6 +95,7 @@ public class WriteBufferManager extends MemoryConsumer {
private boolean memorySpillEnabled;
private int memorySpillTimeoutSec;
private boolean isRowBased;
private BlockIdLayout blockIdLayout;

public WriteBufferManager(
int shuffleId,
Expand Down Expand Up @@ -160,6 +161,7 @@ public WriteBufferManager(
this.sendSizeLimit = rssConf.get(RssSparkConfig.RSS_CLIENT_SEND_SIZE_LIMITATION);
this.memorySpillTimeoutSec = rssConf.get(RssSparkConfig.RSS_MEMORY_SPILL_TIMEOUT);
this.memorySpillEnabled = rssConf.get(RssSparkConfig.RSS_MEMORY_SPILL_ENABLED);
this.blockIdLayout = BlockIdLayout.from(rssConf);
}

/** add serialized columnar data directly when integrate with gluten */
Expand Down Expand Up @@ -325,7 +327,8 @@ protected ShuffleBlockInfo createShuffleBlock(int partitionId, WriterBuffer wb)
compressTime += System.currentTimeMillis() - start;
}
final long crc32 = ChecksumUtils.getCrc32(compressed);
final long blockId = BlockId.getBlockId(getNextSeqNo(partitionId), partitionId, taskAttemptId);
final long blockId =
blockIdLayout.getBlockId(getNextSeqNo(partitionId), partitionId, taskAttemptId);
uncompressedDataLen += data.length;
shuffleWriteMetrics.incBytesWritten(compressed.length);
// add memory to indicate bytes which will be sent to shuffle server
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import java.util.Optional;
import java.util.concurrent.atomic.AtomicBoolean;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Maps;
import org.apache.hadoop.conf.Configuration;
import org.apache.spark.MapOutputTracker;
Expand Down Expand Up @@ -51,8 +50,11 @@ public abstract class RssShuffleManagerBase implements RssShuffleManagerInterfac
private Method unregisterAllMapOutputMethod;
private Method registerShuffleMethod;

/** See static overload of this method. */
public abstract long getTaskAttemptIdForBlockId(int mapIndex, int attemptNo);

/**
* Provides a task attempt id that is unique for a shuffle stage.
* Provides a task attempt id to be used in the block id, that is unique for a shuffle stage.
*
* <p>We are not using context.taskAttemptId() here as this is a monotonically increasing number
* that is unique across the entire Spark app which can reach very large numbers, which can
Expand All @@ -64,8 +66,7 @@ public abstract class RssShuffleManagerBase implements RssShuffleManagerInterfac
*
* @return a task attempt id unique for a shuffle stage
*/
@VisibleForTesting
protected static long getTaskAttemptId(
protected static long getTaskAttemptIdForBlockId(
int mapIndex, int attemptNo, int maxFailures, boolean speculation, int maxTaskAttemptIdBits) {
// attempt number is zero based: 0, 1, …, maxFailures-1
// max maxFailures < 1 is not allowed but for safety, we interpret that as maxFailures == 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
import org.apache.uniffle.common.ShufflePartitionedBlock;
import org.apache.uniffle.common.compression.Codec;
import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.util.BlockId;
import org.apache.uniffle.common.util.BlockIdLayout;
import org.apache.uniffle.common.util.ChecksumUtils;
import org.apache.uniffle.storage.HadoopTestBase;
import org.apache.uniffle.storage.handler.api.ShuffleWriteHandler;
Expand Down Expand Up @@ -95,6 +95,7 @@ protected void writeTestData(
int partitionID,
boolean compress)
throws Exception {
BlockIdLayout layout = BlockIdLayout.DEFAULT;
List<ShufflePartitionedBlock> blocks = Lists.newArrayList();
SerializerInstance serializerInstance = serializer.newInstance();
for (int i = 0; i < blockNum; i++) {
Expand All @@ -106,7 +107,7 @@ protected void writeTestData(
expectedData.put(key, value);
writeData(serializeStream, key, value);
}
long blockId = BlockId.getBlockId(atomicInteger.getAndIncrement(), partitionID, 0);
long blockId = layout.getBlockId(atomicInteger.getAndIncrement(), partitionID, 0);
blockIdBitmap.add(blockId);
blocks.add(createShuffleBlock(output.toBytes(), blockId, compress));
serializeStream.close();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,8 @@
import org.apache.uniffle.client.impl.ShuffleReadClientImpl;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.util.BlockId;
import org.apache.uniffle.common.util.BlockIdLayout;
import org.apache.uniffle.common.util.ChecksumUtils;
import org.apache.uniffle.common.util.Constants;
import org.apache.uniffle.storage.handler.impl.HadoopShuffleWriteHandler;
import org.apache.uniffle.storage.util.StorageType;

Expand Down Expand Up @@ -81,7 +80,8 @@ public void readTest1() throws Exception {

validateResult(rssShuffleDataIterator, expectedData, 10);

blockIdBitmap.add(BlockId.getBlockId(Constants.MAX_SEQUENCE_NO, 0, 0));
BlockIdLayout layout = BlockIdLayout.DEFAULT;
blockIdBitmap.add(layout.getBlockId(layout.maxSequenceNo, 0, 0));
rssShuffleDataIterator =
getDataIterator(basePath, blockIdBitmap, taskIdBitmap, Lists.newArrayList(ssi1));
int recNum = 0;
Expand Down Expand Up @@ -270,7 +270,9 @@ public void readTest7() throws Exception {
}
fail(EXPECTED_EXCEPTION_MESSAGE);
} catch (Exception e) {
assertTrue(e.getMessage().startsWith("Unexpected crc value"));
assertTrue(
e.getMessage()
.startsWith("Unexpected crc value for blockId[0 (seq: 0, part: 0, task: 0)]"));
}

try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import java.util.stream.Stream;

import com.google.common.collect.Maps;
import org.apache.commons.lang3.reflect.FieldUtils;
Expand All @@ -38,11 +39,16 @@
import org.awaitility.Awaitility;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.uniffle.common.ShuffleBlockInfo;
import org.apache.uniffle.common.config.RssClientConf;
import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.util.BlockIdLayout;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
Expand Down Expand Up @@ -84,18 +90,34 @@ private SparkConf getConf() {
return conf;
}

@Test
public void addRecordCompressedTest() throws Exception {
addRecord(true);
public static Stream<Arguments> testBlockIdLayouts() {
return Stream.of(
Arguments.of(BlockIdLayout.DEFAULT), Arguments.of(BlockIdLayout.from(20, 21, 22)));
}

@Test
public void addRecordUnCompressedTest() throws Exception {
addRecord(false);
@ParameterizedTest
@MethodSource("testBlockIdLayouts")
public void addRecordCompressedTest(BlockIdLayout layout) throws Exception {
addRecord(true, layout);
}

@ParameterizedTest
@MethodSource("testBlockIdLayouts")
public void addRecordUnCompressedTest(BlockIdLayout layout) throws Exception {
addRecord(false, layout);
}

private void addRecord(boolean compress) throws IllegalAccessException {
private void addRecord(boolean compress, BlockIdLayout layout) throws IllegalAccessException {
SparkConf conf = getConf();
conf.set(
RssSparkConfig.SPARK_RSS_CONFIG_PREFIX + RssClientConf.BLOCKID_SEQUENCE_NO_BITS.key(),
String.valueOf(layout.sequenceNoBits));
conf.set(
RssSparkConfig.SPARK_RSS_CONFIG_PREFIX + RssClientConf.BLOCKID_PARTITION_ID_BITS.key(),
String.valueOf(layout.partitionIdBits));
conf.set(
RssSparkConfig.SPARK_RSS_CONFIG_PREFIX + RssClientConf.BLOCKID_TASK_ATTEMPT_ID_BITS.key(),
String.valueOf(layout.taskAttemptIdBits));
if (!compress) {
conf.set(RssSparkConfig.SPARK_SHUFFLE_COMPRESS_KEY, String.valueOf(false));
}
Expand All @@ -122,6 +144,7 @@ private void addRecord(boolean compress) throws IllegalAccessException {
result = wbm.addRecord(0, testKey, testValue);
// single buffer is full
assertEquals(1, result.size());
assertEquals(layout.asBlockId(0, 0, 0), layout.asBlockId(result.get(0).getBlockId()));
assertEquals(512, wbm.getAllocatedBytes());
assertEquals(96, wbm.getUsedBytes());
assertEquals(96, wbm.getInSendListBytes());
Expand All @@ -139,6 +162,12 @@ private void addRecord(boolean compress) throws IllegalAccessException {
wbm.addRecord(4, testKey, testValue);
result = wbm.addRecord(5, testKey, testValue);
assertEquals(6, result.size());
assertEquals(layout.asBlockId(1, 0, 0), layout.asBlockId(result.get(0).getBlockId()));
assertEquals(layout.asBlockId(0, 1, 0), layout.asBlockId(result.get(1).getBlockId()));
assertEquals(layout.asBlockId(0, 2, 0), layout.asBlockId(result.get(2).getBlockId()));
assertEquals(layout.asBlockId(0, 3, 0), layout.asBlockId(result.get(3).getBlockId()));
assertEquals(layout.asBlockId(0, 4, 0), layout.asBlockId(result.get(4).getBlockId()));
assertEquals(layout.asBlockId(0, 5, 0), layout.asBlockId(result.get(5).getBlockId()));
assertEquals(512, wbm.getAllocatedBytes());
assertEquals(288, wbm.getUsedBytes());
assertEquals(288, wbm.getInSendListBytes());
Expand Down Expand Up @@ -221,28 +250,40 @@ public void addPartitionDataTest() {
assertEquals(0, spyManager.getShuffleWriteMetrics().recordsWritten());
}

// Checks the block ids that WriteBufferManager.createShuffleBlock assigns on consecutive calls
// for partitions 0 and 1: the expected sequence number goes 0, 1 for each partition.
// NOTE(review): this is a diff view without +/- markers. The plain @Test header and the
// assertions against numeric literals presumably are the removed version; the parameterized
// header, the conf.set(...) calls and the layout.asBlockId(...) assertions presumably are the
// added version -- confirm against the pull request.
@Test
public void createBlockIdTest() {
@ParameterizedTest
@MethodSource("testBlockIdLayouts")
public void createBlockIdTest(BlockIdLayout layout) {
SparkConf conf = getConf();
// Pass the bit widths of the layout under test through the Spark configuration.
conf.set(
RssSparkConfig.SPARK_RSS_CONFIG_PREFIX + RssClientConf.BLOCKID_SEQUENCE_NO_BITS.key(),
String.valueOf(layout.sequenceNoBits));
conf.set(
RssSparkConfig.SPARK_RSS_CONFIG_PREFIX + RssClientConf.BLOCKID_PARTITION_ID_BITS.key(),
String.valueOf(layout.partitionIdBits));
conf.set(
RssSparkConfig.SPARK_RSS_CONFIG_PREFIX + RssClientConf.BLOCKID_TASK_ATTEMPT_ID_BITS.key(),
String.valueOf(layout.taskAttemptIdBits));

WriteBufferManager wbm = createManager(conf);
// The mocked buffer reports empty data and zero memory used.
WriterBuffer mockWriterBuffer = mock(WriterBuffer.class);
when(mockWriterBuffer.getData()).thenReturn(new byte[] {});
when(mockWriterBuffer.getMemoryUsed()).thenReturn(0);
ShuffleBlockInfo sbi = wbm.createShuffleBlock(0, mockWriterBuffer);

// seqNo = 0, partitionId = 0, taskId = 0
assertEquals(0L, sbi.getBlockId());
assertEquals(layout.asBlockId(0, 0, 0), layout.asBlockId(sbi.getBlockId()));

// seqNo = 1, partitionId = 0, taskId = 0
sbi = wbm.createShuffleBlock(0, mockWriterBuffer);
// 35184372088832L == 1L << 45
assertEquals(35184372088832L, sbi.getBlockId());
assertEquals(layout.asBlockId(1, 0, 0), layout.asBlockId(sbi.getBlockId()));

// seqNo = 0, partitionId = 1, taskId = 0
sbi = wbm.createShuffleBlock(1, mockWriterBuffer);
// 2097152L == 1L << 21
assertEquals(2097152L, sbi.getBlockId());
assertEquals(layout.asBlockId(0, 1, 0), layout.asBlockId(sbi.getBlockId()));

// seqNo = 1, partitionId = 1, taskId = 0
sbi = wbm.createShuffleBlock(1, mockWriterBuffer);
// 35184374185984L == (1L << 45) + (1L << 21)
assertEquals(35184374185984L, sbi.getBlockId());
assertEquals(layout.asBlockId(1, 1, 0), layout.asBlockId(sbi.getBlockId()));
}

@Test
Expand Down
Loading
Loading