From 9fafb5dc8fc302746405d0b5b1586113ea5a29a6 Mon Sep 17 00:00:00 2001 From: Enrico Minack Date: Mon, 26 Feb 2024 20:37:41 +0100 Subject: [PATCH] Add tests --- .../manager/RssShuffleManagerBase.java | 9 + .../writer/WriteBufferManagerTest.java | 213 +++++++++++------- .../spark/shuffle/RssShuffleManager.java | 8 +- .../spark/shuffle/RssShuffleManager.java | 26 ++- .../client/impl/ShuffleReadClientImpl.java | 2 +- .../impl/ShuffleReadClientImplTest.java | 83 ++++++- .../impl/ShuffleWriteClientImplTest.java | 64 ++++++ .../apache/uniffle/common/util/BlockId.java | 19 ++ .../uniffle/common/util/BlockIdLayout.java | 25 +- .../common/util/BlockIdLayoutTest.java | 22 +- .../uniffle/common/util/BlockIdTest.java | 35 ++- .../uniffle/common/util/RssUtilsTest.java | 53 +++-- .../uniffle/test/ShuffleServerGrpcTest.java | 144 ++++++------ .../uniffle/test/RssShuffleManagerTest.java | 51 ++++- 14 files changed, 543 insertions(+), 211 deletions(-) diff --git a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java index b70ab6933a..69f553ce01 100644 --- a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java +++ b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java @@ -23,11 +23,13 @@ import java.util.Optional; import java.util.concurrent.atomic.AtomicBoolean; +import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.Maps; import org.apache.hadoop.conf.Configuration; import org.apache.spark.MapOutputTracker; import org.apache.spark.MapOutputTrackerMaster; import org.apache.spark.SparkConf; +import org.apache.spark.TaskContext; import org.apache.spark.SparkEnv; import org.apache.spark.SparkException; import org.apache.spark.shuffle.RssSparkConfig; @@ -50,6 +52,13 @@ public abstract class RssShuffleManagerBase implements RssShuffleManagerInterfac private Method unregisterAllMapOutputMethod; private Method registerShuffleMethod; + /** + * Provides a task attempt id that is unique for a shuffle stage. + * + * @return a task attempt id unique for a shuffle stage + */ + public abstract long getTaskAttemptId(int mapIndex, int attemptNo, long taskAttemptId); + @Override public void unregisterAllMapOutput(int shuffleId) throws SparkException { if (!RssSparkShuffleUtils.isStageResubmitSupported()) { diff --git a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java index c0fb191dcf..5c0eb7d87f 100644 --- a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java +++ b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java @@ -35,6 +35,10 @@ import org.apache.spark.serializer.KryoSerializer; import org.apache.spark.serializer.Serializer; import org.apache.spark.shuffle.RssSparkConfig; +import org.apache.uniffle.client.util.RssClientConfig; +import org.apache.uniffle.common.config.RssClientConf; +import org.apache.uniffle.common.util.BlockId; +import org.apache.uniffle.common.util.BlockIdLayout; import org.awaitility.Awaitility; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; @@ -44,8 +48,7 @@ import org.apache.uniffle.common.ShuffleBlockInfo; import org.apache.uniffle.common.config.RssConf; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.*; import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; @@ -53,6 +56,7 @@ import static org.mockito.Mockito.when; public class WriteBufferManagerTest { + // test blockid config is considered private WriteBufferManager createManager(SparkConf conf) { Serializer kryoSerializer = new KryoSerializer(conf); @@ -95,72 +99,91 @@ public void addRecordUnCompressedTest() throws Exception { } private void addRecord(boolean compress) throws IllegalAccessException { - SparkConf conf = getConf(); - if (!compress) { - conf.set(RssSparkConfig.SPARK_SHUFFLE_COMPRESS_KEY, String.valueOf(false)); - } - WriteBufferManager wbm = createManager(conf); - Object codec = FieldUtils.readField(wbm, "codec", true); - if (compress) { - Assertions.assertNotNull(codec); - } else { - Assertions.assertNull(codec); + // test with different block id layouts + for (BlockIdLayout layout: new BlockIdLayout[] { + BlockIdLayout.DEFAULT, BlockIdLayout.from(20, 21, 22) + }) { + // we should also test layouts that are different to the default + if (layout != BlockIdLayout.DEFAULT) { + assertNotEquals(layout, BlockIdLayout.DEFAULT); + } + SparkConf conf = getConf(); + conf.set(RssSparkConfig.SPARK_RSS_CONFIG_PREFIX + RssClientConf.BLOCKID_SEQUENCE_NO_BITS.key(), String.valueOf(layout.sequenceNoBits)); + conf.set(RssSparkConfig.SPARK_RSS_CONFIG_PREFIX + RssClientConf.BLOCKID_PARTITION_ID_BITS.key(), String.valueOf(layout.partitionIdBits)); + conf.set(RssSparkConfig.SPARK_RSS_CONFIG_PREFIX + RssClientConf.BLOCKID_TASK_ATTEMPT_ID_BITS.key(), String.valueOf(layout.taskAttemptIdBits)); + if (!compress) { + conf.set(RssSparkConfig.SPARK_SHUFFLE_COMPRESS_KEY, String.valueOf(false)); + } + WriteBufferManager wbm = createManager(conf); + Object codec = FieldUtils.readField(wbm, "codec", true); + if (compress) { + Assertions.assertNotNull(codec); + } else { + Assertions.assertNull(codec); + } + wbm.setShuffleWriteMetrics(new ShuffleWriteMetrics()); + String testKey = "Key"; + String testValue = "Value"; + List result = wbm.addRecord(0, testKey, testValue); + // single buffer is not full, there is no data return + assertEquals(0, result.size()); + assertEquals(512, wbm.getAllocatedBytes()); + assertEquals(32, wbm.getUsedBytes()); + assertEquals(0, wbm.getInSendListBytes()); + assertEquals(1, wbm.getBuffers().size()); + wbm.addRecord(0, testKey, testValue); + wbm.addRecord(0, testKey, testValue); + wbm.addRecord(0, testKey, testValue); + result = wbm.addRecord(0, testKey, testValue); + // single buffer is full + assertEquals(1, result.size()); + assertEquals(layout.asBlockId(0, 0, 0), layout.asBlockId(result.get(0).getBlockId())); + assertEquals(512, wbm.getAllocatedBytes()); + assertEquals(96, wbm.getUsedBytes()); + assertEquals(96, wbm.getInSendListBytes()); + assertEquals(0, wbm.getBuffers().size()); + wbm.addRecord(0, testKey, testValue); + wbm.addRecord(1, testKey, testValue); + wbm.addRecord(2, testKey, testValue); + // single buffer is not full, and less than spill size + assertEquals(512, wbm.getAllocatedBytes()); + assertEquals(192, wbm.getUsedBytes()); + assertEquals(96, wbm.getInSendListBytes()); + assertEquals(3, wbm.getBuffers().size()); + // all buffer size > spill size + wbm.addRecord(3, testKey, testValue); + wbm.addRecord(4, testKey, testValue); + result = wbm.addRecord(5, testKey, testValue); + assertEquals(6, result.size()); + assertEquals(layout.asBlockId(1, 0, 0), layout.asBlockId(result.get(0).getBlockId())); + assertEquals(layout.asBlockId(0, 1, 0), layout.asBlockId(result.get(1).getBlockId())); + assertEquals(layout.asBlockId(0, 2, 0), layout.asBlockId(result.get(2).getBlockId())); + assertEquals(layout.asBlockId(0, 3, 0), layout.asBlockId(result.get(3).getBlockId())); + assertEquals(layout.asBlockId(0, 4, 0), layout.asBlockId(result.get(4).getBlockId())); + assertEquals(layout.asBlockId(0, 5, 0), layout.asBlockId(result.get(5).getBlockId())); + assertEquals(512, wbm.getAllocatedBytes()); + assertEquals(288, wbm.getUsedBytes()); + assertEquals(288, wbm.getInSendListBytes()); + assertEquals(0, wbm.getBuffers().size()); + // free memory + wbm.freeAllocatedMemory(96); + assertEquals(416, wbm.getAllocatedBytes()); + assertEquals(192, wbm.getUsedBytes()); + assertEquals(192, wbm.getInSendListBytes()); + + assertEquals(11, wbm.getShuffleWriteMetrics().recordsWritten()); + assertTrue(wbm.getShuffleWriteMetrics().bytesWritten() > 0); + + wbm.freeAllocatedMemory(192); + wbm.addRecord(0, testKey, testValue); + wbm.addRecord(1, testKey, testValue); + wbm.addRecord(2, testKey, testValue); + result = wbm.clear(); + assertEquals(3, result.size()); + assertEquals(224, wbm.getAllocatedBytes()); + assertEquals(96, wbm.getUsedBytes()); + assertEquals(96, wbm.getInSendListBytes()); } - wbm.setShuffleWriteMetrics(new ShuffleWriteMetrics()); - String testKey = "Key"; - String testValue = "Value"; - List result = wbm.addRecord(0, testKey, testValue); - // single buffer is not full, there is no data return - assertEquals(0, result.size()); - assertEquals(512, wbm.getAllocatedBytes()); - assertEquals(32, wbm.getUsedBytes()); - assertEquals(0, wbm.getInSendListBytes()); - assertEquals(1, wbm.getBuffers().size()); - wbm.addRecord(0, testKey, testValue); - wbm.addRecord(0, testKey, testValue); - wbm.addRecord(0, testKey, testValue); - result = wbm.addRecord(0, testKey, testValue); - // single buffer is full - assertEquals(1, result.size()); - assertEquals(512, wbm.getAllocatedBytes()); - assertEquals(96, wbm.getUsedBytes()); - assertEquals(96, wbm.getInSendListBytes()); - assertEquals(0, wbm.getBuffers().size()); - wbm.addRecord(0, testKey, testValue); - wbm.addRecord(1, testKey, testValue); - wbm.addRecord(2, testKey, testValue); - // single buffer is not full, and less than spill size - assertEquals(512, wbm.getAllocatedBytes()); - assertEquals(192, wbm.getUsedBytes()); - assertEquals(96, wbm.getInSendListBytes()); - assertEquals(3, wbm.getBuffers().size()); - // all buffer size > spill size - wbm.addRecord(3, testKey, testValue); - wbm.addRecord(4, testKey, testValue); - result = wbm.addRecord(5, testKey, testValue); - assertEquals(6, result.size()); - assertEquals(512, wbm.getAllocatedBytes()); - assertEquals(288, wbm.getUsedBytes()); - assertEquals(288, wbm.getInSendListBytes()); - assertEquals(0, wbm.getBuffers().size()); - // free memory - wbm.freeAllocatedMemory(96); - assertEquals(416, wbm.getAllocatedBytes()); - assertEquals(192, wbm.getUsedBytes()); - assertEquals(192, wbm.getInSendListBytes()); - - assertEquals(11, wbm.getShuffleWriteMetrics().recordsWritten()); - assertTrue(wbm.getShuffleWriteMetrics().bytesWritten() > 0); - - wbm.freeAllocatedMemory(192); - wbm.addRecord(0, testKey, testValue); - wbm.addRecord(1, testKey, testValue); - wbm.addRecord(2, testKey, testValue); - result = wbm.clear(); - assertEquals(3, result.size()); - assertEquals(224, wbm.getAllocatedBytes()); - assertEquals(96, wbm.getUsedBytes()); - assertEquals(96, wbm.getInSendListBytes()); } @Test @@ -223,26 +246,42 @@ public void addPartitionDataTest() { @Test public void createBlockIdTest() { - SparkConf conf = getConf(); - WriteBufferManager wbm = createManager(conf); - WriterBuffer mockWriterBuffer = mock(WriterBuffer.class); - when(mockWriterBuffer.getData()).thenReturn(new byte[] {}); - when(mockWriterBuffer.getMemoryUsed()).thenReturn(0); - ShuffleBlockInfo sbi = wbm.createShuffleBlock(0, mockWriterBuffer); - // seqNo = 0, partitionId = 0, taskId = 0 - assertEquals(0L, sbi.getBlockId()); - - // seqNo = 1, partitionId = 0, taskId = 0 - sbi = wbm.createShuffleBlock(0, mockWriterBuffer); - assertEquals(35184372088832L, sbi.getBlockId()); - - // seqNo = 0, partitionId = 1, taskId = 0 - sbi = wbm.createShuffleBlock(1, mockWriterBuffer); - assertEquals(2097152L, sbi.getBlockId()); - - // seqNo = 1, partitionId = 1, taskId = 0 - sbi = wbm.createShuffleBlock(1, mockWriterBuffer); - assertEquals(35184374185984L, sbi.getBlockId()); + // test with different block id layouts + for (BlockIdLayout layout: new BlockIdLayout[] { + BlockIdLayout.DEFAULT, BlockIdLayout.from(20, 21, 22) + }) { + // we should also test layouts that are different to the default + if (layout != BlockIdLayout.DEFAULT) { + assertNotEquals(layout, BlockIdLayout.DEFAULT); + } + SparkConf conf = getConf(); + conf.set(RssSparkConfig.SPARK_RSS_CONFIG_PREFIX + RssClientConf.BLOCKID_SEQUENCE_NO_BITS.key(), String.valueOf(layout.sequenceNoBits)); + conf.set(RssSparkConfig.SPARK_RSS_CONFIG_PREFIX + RssClientConf.BLOCKID_PARTITION_ID_BITS.key(), String.valueOf(layout.partitionIdBits)); + conf.set(RssSparkConfig.SPARK_RSS_CONFIG_PREFIX + RssClientConf.BLOCKID_TASK_ATTEMPT_ID_BITS.key(), String.valueOf(layout.taskAttemptIdBits)); + + WriteBufferManager wbm = createManager(conf); + WriterBuffer mockWriterBuffer = mock(WriterBuffer.class); + when(mockWriterBuffer.getData()).thenReturn(new byte[]{}); + when(mockWriterBuffer.getMemoryUsed()).thenReturn(0); + ShuffleBlockInfo sbi = wbm.createShuffleBlock(0, mockWriterBuffer); + + String layoutString = layout.toString(); + + // seqNo = 0, partitionId = 0, taskId = 0 + assertEquals(layout.asBlockId(0, 0, 0), layout.asBlockId(sbi.getBlockId()), layoutString); + + // seqNo = 1, partitionId = 0, taskId = 0 + sbi = wbm.createShuffleBlock(0, mockWriterBuffer); + assertEquals(layout.asBlockId(1, 0, 0), layout.asBlockId(sbi.getBlockId()), layoutString); + + // seqNo = 0, partitionId = 1, taskId = 0 + sbi = wbm.createShuffleBlock(1, mockWriterBuffer); + assertEquals(layout.asBlockId(0, 1, 0), layout.asBlockId(sbi.getBlockId()), layoutString); + + // seqNo = 1, partitionId = 1, taskId = 0 + sbi = wbm.createShuffleBlock(1, mockWriterBuffer); + assertEquals(layout.asBlockId(1, 1, 0), layout.asBlockId(sbi.getBlockId()), layoutString); + } } @Test diff --git a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java index ed3f340618..9c87db54cc 100644 --- a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java +++ b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java @@ -465,7 +465,7 @@ public ShuffleWriter getWriter( rssHandle.getAppId(), shuffleId, taskId, - context.taskAttemptId(), + getTaskAttemptId(context.partitionId(), context.attemptNumber(), context.taskAttemptId()), writeMetrics, this, sparkConf, @@ -479,6 +479,12 @@ public ShuffleWriter getWriter( } } + @Override + @VisibleForTesting + public long getTaskAttemptId(int mapIndex, int attemptNo, long taskAttemptId) { + return taskAttemptId; + } + // This method is called in Spark executor, // getting information from Spark driver via the ShuffleHandle. @Override diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java index d643a8174c..4578fd43ed 100644 --- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java +++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java @@ -510,18 +510,11 @@ public ShuffleWriter getWriter( } String taskId = "" + context.taskAttemptId() + "_" + context.attemptNumber(); LOG.info("RssHandle appId {} shuffleId {} ", rssHandle.getAppId(), rssHandle.getShuffleId()); - long taskAttemptId = - getTaskAttemptId( - context.partitionId(), - context.attemptNumber(), - maxFailures, - speculation, - blockIdLayout.taskAttemptIdBits); return new RssShuffleWriter<>( rssHandle.getAppId(), shuffleId, taskId, - taskAttemptId, + getTaskAttemptId(context.partitionId(), context.attemptNumber(), context.taskAttemptId()), writeMetrics, this, sparkConf, @@ -532,6 +525,23 @@ public ShuffleWriter getWriter( shuffleHandleInfo); } + /** + * Provides a task attempt id that is unique for a shuffle stage. + * + * For details see overloaded method getTaskAttemptId. + * + * @return a task attempt id unique for a shuffle stage + */ + @VisibleForTesting + public long getTaskAttemptId(int mapIndex, int attemptNo, long taskAttemptId) { + return getTaskAttemptId( + mapIndex, + attemptNo, + maxFailures, + speculation, + blockIdLayout.taskAttemptIdBits); + } + /** * Provides a task attempt id that is unique for a shuffle stage. * diff --git a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java index 0aee06ec22..062cca4385 100644 --- a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java +++ b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java @@ -26,7 +26,6 @@ import com.google.common.collect.Lists; import com.google.common.collect.Queues; import org.apache.hadoop.conf.Configuration; -import org.apache.uniffle.common.util.BlockIdLayout; import org.roaringbitmap.longlong.Roaring64NavigableMap; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -42,6 +41,7 @@ import org.apache.uniffle.common.config.RssClientConf; import org.apache.uniffle.common.config.RssConf; import org.apache.uniffle.common.exception.RssFetchFailedException; +import org.apache.uniffle.common.util.BlockIdLayout; import org.apache.uniffle.common.util.ChecksumUtils; import org.apache.uniffle.common.util.IdHelper; import org.apache.uniffle.common.util.RssUtils; diff --git a/client/src/test/java/org/apache/uniffle/client/impl/ShuffleReadClientImplTest.java b/client/src/test/java/org/apache/uniffle/client/impl/ShuffleReadClientImplTest.java index 56d4933146..377d9a4ec1 100644 --- a/client/src/test/java/org/apache/uniffle/client/impl/ShuffleReadClientImplTest.java +++ b/client/src/test/java/org/apache/uniffle/client/impl/ShuffleReadClientImplTest.java @@ -46,6 +46,7 @@ import org.apache.uniffle.storage.util.StorageType; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; @@ -323,6 +324,7 @@ public void readTest8() throws Exception { .basePath(basePath) .blockIdBitmap(blockIdBitmap) .taskIdBitmap(taskIdBitmap) + .blockIdLayout(BlockIdLayout.DEFAULT) .build(); ShuffleReadClientImpl readClient2 = baseReadBuilder() @@ -330,6 +332,7 @@ public void readTest8() throws Exception { .basePath(basePath) .blockIdBitmap(blockIdBitmap) .taskIdBitmap(taskIdBitmap) + .blockIdLayout(BlockIdLayout.DEFAULT) .shuffleServerInfoList(Lists.newArrayList(ssi1, ssi2)) .build(); // crc32 is incorrect @@ -342,7 +345,11 @@ public void readTest8() throws Exception { } fail(EXPECTED_EXCEPTION_MESSAGE); } catch (Exception e) { - assertTrue(e.getMessage().startsWith("Unexpected crc value")); + assertTrue( + e.getMessage() + .startsWith( + "Unexpected crc value for blockId[5800000000000 (seq: 44, part: 0, task: 0)]"), + e.getMessage()); } CompressedShuffleBlock block = readClient2.readShuffleBlockData(); @@ -497,31 +504,69 @@ public void readTest12() throws Exception { @Test public void readTest13() throws Exception { - String basePath = HDFS_URI + "clientReadTest13"; + doReadTest13(BlockIdLayout.DEFAULT); + } + + @Test + public void readTest13b() throws Exception { + // This test is identical to readTest13, except that it does not use the default BlockIdLayout + // the layout is only used by IdHelper that extracts the task attempt id from the block id + // the partition id has to be larger than 0, so that it can leak into the task attempt id + // if the default layout is being used + BlockIdLayout layout = BlockIdLayout.from(22, 21, 20); + assertNotEquals(layout, BlockIdLayout.DEFAULT); + doReadTest13(layout); + } + + public void doReadTest13(BlockIdLayout layout) throws Exception { + String basePath = HDFS_URI + "clientReadTest13-" + layout.hashCode(); HadoopShuffleWriteHandler writeHandler = new HadoopShuffleWriteHandler("appId", 0, 1, 1, basePath, ssi1.getId(), conf); Map expectedData = Maps.newHashMap(); final Roaring64NavigableMap blockIdBitmap = Roaring64NavigableMap.bitmapOf(); final Roaring64NavigableMap taskIdBitmap = Roaring64NavigableMap.bitmapOf(0, 3); - writeTestData(writeHandler, 5, 30, 0, expectedData, blockIdBitmap); + writeTestData(writeHandler, 5, 30, 1, 0, expectedData, blockIdBitmap, layout); // test case: data generated by speculation task without report result - writeTestData(writeHandler, 5, 30, 1, Maps.newHashMap(), Roaring64NavigableMap.bitmapOf()); + writeTestData( + writeHandler, 5, 30, 1, 1, Maps.newHashMap(), Roaring64NavigableMap.bitmapOf(), layout); // test case: data generated by speculation task with report result - writeTestData(writeHandler, 5, 30, 2, Maps.newHashMap(), blockIdBitmap); - writeTestData(writeHandler, 5, 30, 3, expectedData, blockIdBitmap); + writeTestData(writeHandler, 5, 30, 1, 2, Maps.newHashMap(), blockIdBitmap, layout); + writeTestData(writeHandler, 5, 30, 1, 3, expectedData, blockIdBitmap, layout); // unexpected taskAttemptId should be filtered + assertEquals(15, blockIdBitmap.getIntCardinality()); ShuffleReadClientImpl readClient = baseReadBuilder() .basePath(basePath) .blockIdBitmap(blockIdBitmap) + .partitionId(1) .taskIdBitmap(taskIdBitmap) + .blockIdLayout(layout) .build(); + // note that skipped block ids in blockIdBitmap will be removed by `build()` + assertEquals(10, blockIdBitmap.getIntCardinality()); TestUtils.validateResult(readClient, expectedData); assertEquals(20, readClient.getProcessedBlockIds().getLongCardinality()); readClient.checkProcessedBlockIds(); readClient.close(); + + if (!layout.equals(BlockIdLayout.DEFAULT)) { + // creating a reader with a wrong block id layout will skip all blocks where task attempt id + // is not in taskIdBitmap + // the particular layout that created the block ids is incompatible with default layout, so + // all block ids will be skipped + // note that skipped block ids in blockIdBitmap will be removed by `build()` + baseReadBuilder() + .basePath(basePath) + .blockIdBitmap(blockIdBitmap) + .partitionId(1) + .taskIdBitmap(taskIdBitmap) + .blockIdLayout(BlockIdLayout.DEFAULT) + .build(); + // note that skipped block ids in blockIdBitmap will be removed by `build()` + assertEquals(0, blockIdBitmap.getIntCardinality()); + } } @Test @@ -582,16 +627,17 @@ private void writeTestData( HadoopShuffleWriteHandler writeHandler, int num, int length, + int partitionId, long taskAttemptId, Map expectedData, - Roaring64NavigableMap blockIdBitmap) + Roaring64NavigableMap blockIdBitmap, + BlockIdLayout layout) throws Exception { - BlockIdLayout layout = BlockIdLayout.DEFAULT; List blocks = Lists.newArrayList(); for (int i = 0; i < num; i++) { byte[] buf = new byte[length]; new Random().nextBytes(buf); - long blockId = layout.getBlockId(ATOMIC_INT.getAndIncrement(), 0, taskAttemptId); + long blockId = layout.getBlockId(ATOMIC_INT.getAndIncrement(), partitionId, taskAttemptId); blocks.add( new ShufflePartitionedBlock( length, length, ChecksumUtils.getCrc32(buf), blockId, taskAttemptId, buf)); @@ -601,6 +647,25 @@ private void writeTestData( writeHandler.write(blocks); } + private void writeTestData( + HadoopShuffleWriteHandler writeHandler, + int num, + int length, + long taskAttemptId, + Map expectedData, + Roaring64NavigableMap blockIdBitmap) + throws Exception { + writeTestData( + writeHandler, + num, + length, + 0, + taskAttemptId, + expectedData, + blockIdBitmap, + BlockIdLayout.DEFAULT); + } + private void writeDuplicatedData( HadoopShuffleWriteHandler writeHandler, int num, diff --git a/client/src/test/java/org/apache/uniffle/client/impl/ShuffleWriteClientImplTest.java b/client/src/test/java/org/apache/uniffle/client/impl/ShuffleWriteClientImplTest.java index f3008e9fe7..09efd8b315 100644 --- a/client/src/test/java/org/apache/uniffle/client/impl/ShuffleWriteClientImplTest.java +++ b/client/src/test/java/org/apache/uniffle/client/impl/ShuffleWriteClientImplTest.java @@ -19,17 +19,21 @@ import java.util.ArrayList; import java.util.List; +import java.util.Set; import java.util.concurrent.TimeUnit; import com.google.common.collect.Lists; import com.google.common.collect.Maps; +import com.google.common.collect.Sets; import org.awaitility.Awaitility; import org.junit.jupiter.api.Test; import org.mockito.Mockito; import org.mockito.stubbing.Answer; +import org.roaringbitmap.longlong.Roaring64NavigableMap; import org.apache.uniffle.client.api.ShuffleServerClient; import org.apache.uniffle.client.factory.ShuffleClientFactory; +import org.apache.uniffle.client.response.RssGetShuffleResultResponse; import org.apache.uniffle.client.response.RssSendShuffleDataResponse; import org.apache.uniffle.client.response.SendShuffleDataResult; import org.apache.uniffle.common.ClientType; @@ -37,14 +41,21 @@ import org.apache.uniffle.common.ShuffleServerInfo; import org.apache.uniffle.common.config.RssClientConf; import org.apache.uniffle.common.config.RssConf; +import org.apache.uniffle.common.exception.RssException; import org.apache.uniffle.common.netty.IOMode; import org.apache.uniffle.common.rpc.StatusCode; +import org.apache.uniffle.common.util.BlockIdLayout; +import org.apache.uniffle.common.util.RssUtils; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; public class ShuffleWriteClientImplTest { @@ -317,4 +328,57 @@ public void testSettingRssClientConfigs() { client.close(); assertEquals(IOMode.EPOLL, ioMode); } + + @Test + public void testGetShuffleResult() { + // test with different block id layouts + for (BlockIdLayout layout : + new BlockIdLayout[] {BlockIdLayout.DEFAULT, BlockIdLayout.from(20, 21, 22)}) { + // we should also test layouts that are different to the default + if (layout != BlockIdLayout.DEFAULT) { + assertNotEquals(layout, BlockIdLayout.DEFAULT); + } + RssConf rssConf = new RssConf(); + rssConf.set(RssClientConf.BLOCKID_SEQUENCE_NO_BITS, layout.sequenceNoBits); + rssConf.set(RssClientConf.BLOCKID_PARTITION_ID_BITS, layout.partitionIdBits); + rssConf.set(RssClientConf.BLOCKID_TASK_ATTEMPT_ID_BITS, layout.taskAttemptIdBits); + ShuffleWriteClientImpl shuffleWriteClient = + ShuffleClientFactory.newWriteBuilder() + .clientType(ClientType.GRPC.name()) + .retryMax(3) + .retryIntervalMax(2000) + .heartBeatThreadNum(4) + .replica(1) + .replicaWrite(1) + .replicaRead(1) + .replicaSkipEnabled(true) + .dataTransferPoolSize(1) + .dataCommitPoolSize(1) + .unregisterThreadPoolSize(10) + .unregisterRequestTimeSec(10) + .rssConf(rssConf) + .build(); + ShuffleServerClient mockShuffleServerClient = mock(ShuffleServerClient.class); + ShuffleWriteClientImpl spyClient = Mockito.spy(shuffleWriteClient); + doReturn(mockShuffleServerClient).when(spyClient).getShuffleServerClient(any()); + RssGetShuffleResultResponse response; + try { + Roaring64NavigableMap res = Roaring64NavigableMap.bitmapOf(1L, 2L, 5L); + response = + new RssGetShuffleResultResponse(StatusCode.SUCCESS, RssUtils.serializeBitMap(res)); + } catch (Exception e) { + throw new RssException(e); + } + when(mockShuffleServerClient.getShuffleResult(any())).thenReturn(response); + + Set shuffleServerInfoSet = + Sets.newHashSet(new ShuffleServerInfo("id", "host", 0)); + Roaring64NavigableMap result = + spyClient.getShuffleResult("GRPC", shuffleServerInfoSet, "appId", 1, 2); + + verify(mockShuffleServerClient) + .getShuffleResult(argThat(request -> request.getBlockIdLayout().equals(layout))); + assertArrayEquals(result.stream().sorted().toArray(), new long[] {1L, 2L, 5L}); + } + } } diff --git a/common/src/main/java/org/apache/uniffle/common/util/BlockId.java b/common/src/main/java/org/apache/uniffle/common/util/BlockId.java index 0ca3a18432..36025f66bd 100644 --- a/common/src/main/java/org/apache/uniffle/common/util/BlockId.java +++ b/common/src/main/java/org/apache/uniffle/common/util/BlockId.java @@ -17,6 +17,8 @@ package org.apache.uniffle.common.util; +import java.util.Objects; + /** * This represents a block id and all its constituents. This is particularly useful for logging and * debugging block ids. @@ -53,4 +55,21 @@ public String toString() { + taskAttemptId + ")]"; } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + BlockId blockId1 = (BlockId) o; + return blockId == blockId1.blockId && Objects.equals(layout, blockId1.layout); + } + + @Override + public int hashCode() { + return Objects.hash(blockId, layout); + } } diff --git a/common/src/main/java/org/apache/uniffle/common/util/BlockIdLayout.java b/common/src/main/java/org/apache/uniffle/common/util/BlockIdLayout.java index ef5ee94ad4..066f3d969a 100644 --- a/common/src/main/java/org/apache/uniffle/common/util/BlockIdLayout.java +++ b/common/src/main/java/org/apache/uniffle/common/util/BlockIdLayout.java @@ -17,6 +17,8 @@ package org.apache.uniffle.common.util; +import java.util.Objects; + import org.apache.uniffle.common.config.RssClientConf; import org.apache.uniffle.common.config.RssConf; @@ -121,6 +123,25 @@ public String toString() { + " bits]"; } + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + BlockIdLayout that = (BlockIdLayout) o; + return sequenceNoBits == that.sequenceNoBits + && partitionIdBits == that.partitionIdBits + && taskAttemptIdBits == that.taskAttemptIdBits; + } + + @Override + public int hashCode() { + return Objects.hash(sequenceNoBits, partitionIdBits, taskAttemptIdBits); + } + public long getBlockId(int sequenceNo, int partitionId, long taskAttemptId) { if (sequenceNo < 0 || sequenceNo > maxSequenceNo) { throw new IllegalArgumentException( @@ -163,13 +184,13 @@ public BlockId asBlockId(long blockId) { blockId, this, getSequenceNo(blockId), getPartitionId(blockId), getTaskAttemptId(blockId)); } - public BlockId asBlockId(int sequenceNo, int partitionId, int taskAttemptId) { + public BlockId asBlockId(int sequenceNo, int partitionId, long taskAttemptId) { return new BlockId( getBlockId(sequenceNo, partitionId, taskAttemptId), this, sequenceNo, partitionId, - taskAttemptId); + (int) taskAttemptId); } public static BlockIdLayout from(RssConf rssConf) { diff --git a/common/src/test/java/org/apache/uniffle/common/util/BlockIdLayoutTest.java b/common/src/test/java/org/apache/uniffle/common/util/BlockIdLayoutTest.java index cc67328c64..7a82239daa 100644 --- a/common/src/test/java/org/apache/uniffle/common/util/BlockIdLayoutTest.java +++ b/common/src/test/java/org/apache/uniffle/common/util/BlockIdLayoutTest.java @@ -20,6 +20,7 @@ import org.junit.jupiter.api.Test; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; import static org.junit.jupiter.api.Assertions.assertThrows; public class BlockIdLayoutTest { @@ -92,7 +93,7 @@ public void testFromLengthsErrors() { @Test public void testLayoutGetBlockId() { - BlockIdLayout[] layouts = + for (BlockIdLayout layout : new BlockIdLayout[] { BlockIdLayout.DEFAULT, BlockIdLayout.from(21, 21, 21), @@ -100,9 +101,7 @@ public void testLayoutGetBlockId() { BlockIdLayout.from(1, 31, 31), BlockIdLayout.from(31, 1, 31), BlockIdLayout.from(31, 31, 1), - }; - - for (BlockIdLayout layout : layouts) { + }) { // max value of blockId assertEquals( (long) layout.maxSequenceNo << layout.sequenceNoOffset @@ -172,4 +171,19 @@ public void testLayoutGetBlockId() { e3.getMessage()); } } + + @Test + public void testEquals() { + BlockIdLayout layout1 = BlockIdLayout.from(20, 21, 22); + BlockIdLayout layout2 = BlockIdLayout.from(20, 21, 22); + BlockIdLayout layout3 = BlockIdLayout.from(18, 22, 23); + + assertEquals(layout1, layout1); + assertEquals(layout1, layout2); + assertNotEquals(layout1, layout3); + + BlockIdLayout layout4 = BlockIdLayout.from(18, 24, 21); + assertNotEquals(layout1, layout4); + assertEquals(layout4, BlockIdLayout.DEFAULT); + } } diff --git a/common/src/test/java/org/apache/uniffle/common/util/BlockIdTest.java b/common/src/test/java/org/apache/uniffle/common/util/BlockIdTest.java index 857cae9d4a..50dc43f3e9 100644 --- a/common/src/test/java/org/apache/uniffle/common/util/BlockIdTest.java +++ b/common/src/test/java/org/apache/uniffle/common/util/BlockIdTest.java @@ -20,20 +20,39 @@ import org.junit.jupiter.api.Test; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNotSame; +import static org.junit.jupiter.api.Assertions.assertSame; public class BlockIdTest { + BlockIdLayout layout1 = BlockIdLayout.DEFAULT; + BlockId blockId1 = layout1.asBlockId(1, 2, 3); + BlockId blockId2 = layout1.asBlockId(15, 30, 63); + + BlockIdLayout layout2 = BlockIdLayout.from(31, 16, 16); + BlockId blockId3 = layout2.asBlockId(1, 2, 3); + BlockId blockId4 = layout2.asBlockId(15, 30, 63); + @Test - public void toStringTest() { - BlockIdLayout layout1 = BlockIdLayout.DEFAULT; - BlockId blockId1 = layout1.asBlockId(1, 2, 3); + public void testToString() { assertEquals("blockId[200000400003 (seq: 1, part: 2, task: 3)]", blockId1.toString()); - BlockId blockId2 = layout1.asBlockId(15, 30, 63); assertEquals("blockId[1e00003c0003f (seq: 15, part: 30, task: 63)]", blockId2.toString()); - - BlockIdLayout layout2 = BlockIdLayout.from(31, 16, 16); - BlockId blockId3 = layout2.asBlockId(1, 2, 3); assertEquals("blockId[100020003 (seq: 1, part: 2, task: 3)]", blockId3.toString()); - BlockId blockId4 = layout2.asBlockId(15, 30, 63); assertEquals("blockId[f001e003f (seq: 15, part: 30, task: 63)]", blockId4.toString()); } + + @Test + public void testEquals() { + assertEquals(blockId1, blockId1); + assertSame(blockId1, blockId1); + + assertNotEquals(blockId1, blockId2); + assertNotEquals(blockId1, blockId3); + assertNotEquals(blockId1, blockId4); + assertNotEquals(blockId2, blockId4); + + BlockId blockId4 = layout1.asBlockId(1, 2, 3); + assertEquals(blockId1, blockId4); + assertNotSame(blockId1, blockId4); + } } diff --git a/common/src/test/java/org/apache/uniffle/common/util/RssUtilsTest.java b/common/src/test/java/org/apache/uniffle/common/util/RssUtilsTest.java index 1cb6123501..d4a3feec7b 100644 --- a/common/src/test/java/org/apache/uniffle/common/util/RssUtilsTest.java +++ b/common/src/test/java/org/apache/uniffle/common/util/RssUtilsTest.java @@ -45,6 +45,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNotSame; import static org.junit.jupiter.api.Assertions.assertNull; @@ -227,28 +228,36 @@ public void testLoadExtentions() { @Test public void testShuffleBitmapToPartitionBitmap() { - BlockIdLayout layout = BlockIdLayout.DEFAULT; - Roaring64NavigableMap partition1Bitmap = - Roaring64NavigableMap.bitmapOf( - layout.getBlockId(0, 0, 0), - layout.getBlockId(1, 0, 0), - layout.getBlockId(0, 0, 1), - layout.getBlockId(1, 0, 1)); - Roaring64NavigableMap partition2Bitmap = - Roaring64NavigableMap.bitmapOf( - layout.getBlockId(0, 1, 0), - layout.getBlockId(1, 1, 0), - layout.getBlockId(0, 1, 1), - layout.getBlockId(1, 1, 1)); - Roaring64NavigableMap shuffleBitmap = Roaring64NavigableMap.bitmapOf(); - shuffleBitmap.or(partition1Bitmap); - shuffleBitmap.or(partition2Bitmap); - assertEquals(8, shuffleBitmap.getLongCardinality()); - Map toPartitionBitmap = - RssUtils.generatePartitionToBitmap(shuffleBitmap, 0, 2, layout); - assertEquals(2, toPartitionBitmap.size()); - assertEquals(partition1Bitmap, toPartitionBitmap.get(0)); - assertEquals(partition2Bitmap, toPartitionBitmap.get(1)); + // test with different block id layouts + for (BlockIdLayout layout : + new BlockIdLayout[] {BlockIdLayout.DEFAULT, BlockIdLayout.from(20, 21, 22)}) { + // we should also test layouts that are different to the default + if (layout != BlockIdLayout.DEFAULT) { + assertNotEquals(layout, BlockIdLayout.DEFAULT); + } + + Roaring64NavigableMap partition1Bitmap = + Roaring64NavigableMap.bitmapOf( + layout.getBlockId(0, 0, 0), + layout.getBlockId(1, 0, 0), + layout.getBlockId(0, 0, 1), + layout.getBlockId(1, 0, 1)); + Roaring64NavigableMap partition2Bitmap = + Roaring64NavigableMap.bitmapOf( + layout.getBlockId(0, 1, 0), + layout.getBlockId(1, 1, 0), + layout.getBlockId(0, 1, 1), + layout.getBlockId(1, 1, 1)); + Roaring64NavigableMap shuffleBitmap = Roaring64NavigableMap.bitmapOf(); + shuffleBitmap.or(partition1Bitmap); + shuffleBitmap.or(partition2Bitmap); + assertEquals(8, shuffleBitmap.getLongCardinality()); + Map toPartitionBitmap = + RssUtils.generatePartitionToBitmap(shuffleBitmap, 0, 2, layout); + assertEquals(2, toPartitionBitmap.size()); + assertEquals(partition1Bitmap, toPartitionBitmap.get(0)); + assertEquals(partition2Bitmap, toPartitionBitmap.get(1)); + } } @Test diff --git a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerGrpcTest.java b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerGrpcTest.java index 53e3e73720..fce5902b94 100644 --- a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerGrpcTest.java +++ b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerGrpcTest.java @@ -671,74 +671,83 @@ public void sendDataWithoutRequirePreAllocation() { @Test public void multipleShuffleResultTest() throws Exception { - Set expectedBlockIds = Sets.newConcurrentHashSet(); - RssRegisterShuffleRequest rrsr = - new RssRegisterShuffleRequest( - "multipleShuffleResultTest", 100, Lists.newArrayList(new PartitionRange(0, 1)), ""); - grpcShuffleServerClient.registerShuffle(rrsr); + // test with different block id layouts + for (BlockIdLayout layout : + new BlockIdLayout[] {BlockIdLayout.DEFAULT, BlockIdLayout.from(20, 21, 22)}) { + // we should also test layouts that are different to the default + if (layout != BlockIdLayout.DEFAULT) { + assertNotEquals(layout, BlockIdLayout.DEFAULT); + } - Runnable r1 = - () -> { - for (int i = 0; i < 100; i++) { - Map> ptbs = Maps.newHashMap(); - List blockIds = Lists.newArrayList(); - Long blockId = layout.getBlockId(i, 1, 0); - expectedBlockIds.add(blockId); - blockIds.add(blockId); - ptbs.put(1, blockIds); - RssReportShuffleResultRequest req1 = - new RssReportShuffleResultRequest("multipleShuffleResultTest", 1, 0, ptbs, 1); - grpcShuffleServerClient.reportShuffleResult(req1); - } - }; - Runnable r2 = - () -> { - for (int i = 100; i < 200; i++) { - Map> ptbs = Maps.newHashMap(); - List blockIds = Lists.newArrayList(); - Long blockId = layout.getBlockId(i, 1, 1); - expectedBlockIds.add(blockId); - blockIds.add(blockId); - ptbs.put(1, blockIds); - RssReportShuffleResultRequest req1 = - new RssReportShuffleResultRequest("multipleShuffleResultTest", 1, 1, ptbs, 1); - grpcShuffleServerClient.reportShuffleResult(req1); - } - }; - Runnable r3 = - () -> { - for (int i = 200; i < 300; i++) { - Map> ptbs = Maps.newHashMap(); - List blockIds = Lists.newArrayList(); - Long blockId = layout.getBlockId(i, 1, 2); - expectedBlockIds.add(blockId); - blockIds.add(blockId); - ptbs.put(1, blockIds); - RssReportShuffleResultRequest req1 = - new RssReportShuffleResultRequest("multipleShuffleResultTest", 1, 2, ptbs, 1); - grpcShuffleServerClient.reportShuffleResult(req1); - } - }; - Thread t1 = new Thread(r1); - Thread t2 = new Thread(r2); - Thread t3 = new Thread(r3); - t1.start(); - t2.start(); - t3.start(); - t1.join(); - t2.join(); - t3.join(); - - Roaring64NavigableMap blockIdBitmap = Roaring64NavigableMap.bitmapOf(); - for (Long blockId : expectedBlockIds) { - blockIdBitmap.addLong(blockId); - } + Set expectedBlockIds = Sets.newConcurrentHashSet(); + RssRegisterShuffleRequest rrsr = + new RssRegisterShuffleRequest( + "multipleShuffleResultTest", 100, Lists.newArrayList(new PartitionRange(0, 1)), ""); + grpcShuffleServerClient.registerShuffle(rrsr); + + Runnable r1 = + () -> { + for (int i = 0; i < 100; i++) { + Map> ptbs = Maps.newHashMap(); + List blockIds = Lists.newArrayList(); + Long blockId = layout.getBlockId(i, 1, 0); + expectedBlockIds.add(blockId); + blockIds.add(blockId); + ptbs.put(1, blockIds); + RssReportShuffleResultRequest req1 = + new RssReportShuffleResultRequest("multipleShuffleResultTest", 1, 0, ptbs, 1); + grpcShuffleServerClient.reportShuffleResult(req1); + } + }; + Runnable r2 = + () -> { + for (int i = 100; i < 200; i++) { + Map> ptbs = Maps.newHashMap(); + List blockIds = Lists.newArrayList(); + Long blockId = layout.getBlockId(i, 1, 1); + expectedBlockIds.add(blockId); + blockIds.add(blockId); + ptbs.put(1, blockIds); + RssReportShuffleResultRequest req1 = + new RssReportShuffleResultRequest("multipleShuffleResultTest", 1, 1, ptbs, 1); + grpcShuffleServerClient.reportShuffleResult(req1); + } + }; + Runnable r3 = + () -> { + for (int i = 200; i < 300; i++) { + Map> ptbs = Maps.newHashMap(); + List blockIds = Lists.newArrayList(); + Long blockId = layout.getBlockId(i, 1, 2); + expectedBlockIds.add(blockId); + blockIds.add(blockId); + ptbs.put(1, blockIds); + RssReportShuffleResultRequest req1 = + new RssReportShuffleResultRequest("multipleShuffleResultTest", 1, 2, ptbs, 1); + grpcShuffleServerClient.reportShuffleResult(req1); + } + }; + Thread t1 = new Thread(r1); + Thread t2 = new Thread(r2); + Thread t3 = new Thread(r3); + t1.start(); + t2.start(); + t3.start(); + t1.join(); + t2.join(); + t3.join(); + + Roaring64NavigableMap blockIdBitmap = Roaring64NavigableMap.bitmapOf(); + for (Long blockId : expectedBlockIds) { + blockIdBitmap.addLong(blockId); + } - RssGetShuffleResultRequest req = - new RssGetShuffleResultRequest("multipleShuffleResultTest", 1, 1, layout); - RssGetShuffleResultResponse result = grpcShuffleServerClient.getShuffleResult(req); - Roaring64NavigableMap actualBlockIdBitmap = result.getBlockIdBitmap(); - assertEquals(blockIdBitmap, actualBlockIdBitmap); + RssGetShuffleResultRequest req = + new RssGetShuffleResultRequest("multipleShuffleResultTest", 1, 1, layout); + RssGetShuffleResultResponse result = grpcShuffleServerClient.getShuffleResult(req); + Roaring64NavigableMap actualBlockIdBitmap = result.getBlockIdBitmap(); + assertEquals(blockIdBitmap, actualBlockIdBitmap, layout.toString()); + } } @Disabled("flaky test") @@ -958,7 +967,8 @@ public void rpcMetricsTest() throws Exception { .getCounterMap() .get(ShuffleServerGrpcMetrics.GET_SHUFFLE_RESULT_METHOD) .get(); - grpcShuffleServerClient.getShuffleResult(new RssGetShuffleResultRequest(appId, shuffleId, 1, layout)); + grpcShuffleServerClient.getShuffleResult( + new RssGetShuffleResultRequest(appId, shuffleId, 1, layout)); newValue = grpcShuffleServers .get(0) diff --git a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RssShuffleManagerTest.java b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RssShuffleManagerTest.java index 30aed50392..88f53c6c25 100644 --- a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RssShuffleManagerTest.java +++ b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RssShuffleManagerTest.java @@ -19,8 +19,22 @@ import java.util.Arrays; import java.util.HashMap; +import java.util.List; +import java.util.Collection; import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; +import org.apache.spark.shuffle.ShuffleHandleInfo; +import org.apache.uniffle.client.api.ShuffleWriteClient; +import org.apache.uniffle.client.factory.ShuffleClientFactory; +import org.apache.uniffle.common.ClientType; +import org.apache.uniffle.common.ShuffleServerInfo; +import org.apache.uniffle.common.config.RssConf; +import org.apache.uniffle.common.util.BlockId; +import org.apache.uniffle.common.util.BlockIdLayout; +import org.apache.uniffle.shuffle.manager.RssShuffleManagerBase; +import org.roaringbitmap.longlong.Roaring64NavigableMap; import scala.Option; import com.google.common.collect.Maps; @@ -28,7 +42,6 @@ import org.apache.spark.SparkConf; import org.apache.spark.SparkEnv; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.shuffle.RssShuffleManager; import org.apache.spark.shuffle.RssSparkConfig; import org.apache.spark.sql.SparkSession; import org.junit.jupiter.api.BeforeAll; @@ -66,6 +79,8 @@ Map runTest(SparkSession spark, String fileName) throws Exception { @Test public void testRssShuffleManager() throws Exception { + BlockIdLayout layout = BlockIdLayout.DEFAULT; + SparkConf conf = createSparkConf(); updateSparkConfWithRss(conf); // enable stage recompute @@ -84,7 +99,39 @@ public void testRssShuffleManager() throws Exception { // create a rdd that triggers shuffle registration long count = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2).groupBy(x -> x).count(); assertEquals(5, count); - RssShuffleManager shuffleManager = (RssShuffleManager) SparkEnv.get().shuffleManager(); + RssShuffleManagerBase shuffleManager = (RssShuffleManagerBase) SparkEnv.get().shuffleManager(); + + // get written block ids (we know there is one shuffle where two task attempts wrote two partitions) + RssConf rssConf = RssSparkConfig.toRssConf(conf); + ShuffleWriteClient shuffleWriteClient = + ShuffleClientFactory.newWriteBuilder() + .clientType(ClientType.GRPC.name()) + .retryMax(3) + .retryIntervalMax(2000) + .heartBeatThreadNum(4) + .replica(1) + .replicaWrite(1) + .replicaRead(1) + .replicaSkipEnabled(true) + .dataTransferPoolSize(1) + .dataCommitPoolSize(1) + .unregisterThreadPoolSize(10) + .unregisterRequestTimeSec(10) + .rssConf(rssConf) + .build(); + ShuffleHandleInfo handle = shuffleManager.getShuffleHandleInfoByShuffleId(0); + Set servers = handle.getPartitionToServers().values().stream().flatMap(Collection::stream).collect(Collectors.toSet()); + + for (int partitionId: new int[]{0, 1}) { + Roaring64NavigableMap blockIdLongs = shuffleWriteClient.getShuffleResult(ClientType.GRPC.name(), servers, shuffleManager.getAppId(), 0, partitionId); + List blockIds = blockIdLongs.stream().sorted().mapToObj(layout::asBlockId).collect(Collectors.toList()); + assertEquals(2, blockIds.size()); + long taskAttemptId0 = shuffleManager.getTaskAttemptId(0, 0, 0); + long taskAttemptId1 = shuffleManager.getTaskAttemptId(1, 0, 1); + assertEquals(layout.asBlockId(0, partitionId, taskAttemptId0), blockIds.get(0), layout.toString()); + assertEquals(layout.asBlockId(0, partitionId, taskAttemptId1), blockIds.get(1), layout.toString()); + } + shuffleManager.unregisterAllMapOutput(0); MapOutputTrackerMaster master = (MapOutputTrackerMaster) SparkEnv.get().mapOutputTracker(); assertTrue(master.containsShuffle(0));