Commit

Add tests

EnricoMi committed Feb 26, 2024
1 parent 4d2415e commit 9fafb5d
Showing 14 changed files with 543 additions and 211 deletions.
@@ -23,11 +23,13 @@
import java.util.Optional;
import java.util.concurrent.atomic.AtomicBoolean;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Maps;
import org.apache.hadoop.conf.Configuration;
import org.apache.spark.MapOutputTracker;
import org.apache.spark.MapOutputTrackerMaster;
import org.apache.spark.SparkConf;
import org.apache.spark.TaskContext;
import org.apache.spark.SparkEnv;
import org.apache.spark.SparkException;
import org.apache.spark.shuffle.RssSparkConfig;
@@ -50,6 +52,13 @@ public abstract class RssShuffleManagerBase implements RssShuffleManagerInterface
private Method unregisterAllMapOutputMethod;
private Method registerShuffleMethod;

/**
 * Provides a task attempt id that is unique for a shuffle stage.
 *
 * @param mapIndex the map index (partition id) of the task
 * @param attemptNo the attempt number of the task
 * @param taskAttemptId the task attempt id provided by Spark
 * @return a task attempt id unique for a shuffle stage
 */
public abstract long getTaskAttemptId(int mapIndex, int attemptNo, long taskAttemptId);

@Override
public void unregisterAllMapOutput(int shuffleId) throws SparkException {
if (!RssSparkShuffleUtils.isStageResubmitSupported()) {
@@ -35,6 +35,10 @@
import org.apache.spark.serializer.KryoSerializer;
import org.apache.spark.serializer.Serializer;
import org.apache.spark.shuffle.RssSparkConfig;
import org.apache.uniffle.client.util.RssClientConfig;
import org.apache.uniffle.common.config.RssClientConf;
import org.apache.uniffle.common.util.BlockId;
import org.apache.uniffle.common.util.BlockIdLayout;
import org.awaitility.Awaitility;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
@@ -44,15 +48,15 @@
import org.apache.uniffle.common.ShuffleBlockInfo;
import org.apache.uniffle.common.config.RssConf;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.when;

public class WriteBufferManagerTest {
// the tests below verify that the configured block id layout is considered

private WriteBufferManager createManager(SparkConf conf) {
Serializer kryoSerializer = new KryoSerializer(conf);
@@ -95,72 +99,91 @@ public void addRecordUnCompressedTest() throws Exception {
}

private void addRecord(boolean compress) throws IllegalAccessException {
SparkConf conf = getConf();
if (!compress) {
conf.set(RssSparkConfig.SPARK_SHUFFLE_COMPRESS_KEY, String.valueOf(false));
}
WriteBufferManager wbm = createManager(conf);
Object codec = FieldUtils.readField(wbm, "codec", true);
if (compress) {
Assertions.assertNotNull(codec);
} else {
Assertions.assertNull(codec);
// test with different block id layouts
for (BlockIdLayout layout : new BlockIdLayout[] {
    BlockIdLayout.DEFAULT, BlockIdLayout.from(20, 21, 22)
}) {
// we should also test layouts that differ from the default
if (layout != BlockIdLayout.DEFAULT) {
assertNotEquals(layout, BlockIdLayout.DEFAULT);
}
SparkConf conf = getConf();
conf.set(
    RssSparkConfig.SPARK_RSS_CONFIG_PREFIX + RssClientConf.BLOCKID_SEQUENCE_NO_BITS.key(),
    String.valueOf(layout.sequenceNoBits));
conf.set(
    RssSparkConfig.SPARK_RSS_CONFIG_PREFIX + RssClientConf.BLOCKID_PARTITION_ID_BITS.key(),
    String.valueOf(layout.partitionIdBits));
conf.set(
    RssSparkConfig.SPARK_RSS_CONFIG_PREFIX + RssClientConf.BLOCKID_TASK_ATTEMPT_ID_BITS.key(),
    String.valueOf(layout.taskAttemptIdBits));
if (!compress) {
conf.set(RssSparkConfig.SPARK_SHUFFLE_COMPRESS_KEY, String.valueOf(false));
}
WriteBufferManager wbm = createManager(conf);
Object codec = FieldUtils.readField(wbm, "codec", true);
if (compress) {
Assertions.assertNotNull(codec);
} else {
Assertions.assertNull(codec);
}
wbm.setShuffleWriteMetrics(new ShuffleWriteMetrics());
String testKey = "Key";
String testValue = "Value";
List<ShuffleBlockInfo> result = wbm.addRecord(0, testKey, testValue);
// single buffer is not full, there is no data return
assertEquals(0, result.size());
assertEquals(512, wbm.getAllocatedBytes());
assertEquals(32, wbm.getUsedBytes());
assertEquals(0, wbm.getInSendListBytes());
assertEquals(1, wbm.getBuffers().size());
wbm.addRecord(0, testKey, testValue);
wbm.addRecord(0, testKey, testValue);
wbm.addRecord(0, testKey, testValue);
result = wbm.addRecord(0, testKey, testValue);
// single buffer is full
assertEquals(1, result.size());
assertEquals(layout.asBlockId(0, 0, 0), layout.asBlockId(result.get(0).getBlockId()));
assertEquals(512, wbm.getAllocatedBytes());
assertEquals(96, wbm.getUsedBytes());
assertEquals(96, wbm.getInSendListBytes());
assertEquals(0, wbm.getBuffers().size());
wbm.addRecord(0, testKey, testValue);
wbm.addRecord(1, testKey, testValue);
wbm.addRecord(2, testKey, testValue);
// single buffer is not full, and less than spill size
assertEquals(512, wbm.getAllocatedBytes());
assertEquals(192, wbm.getUsedBytes());
assertEquals(96, wbm.getInSendListBytes());
assertEquals(3, wbm.getBuffers().size());
// all buffer size > spill size
wbm.addRecord(3, testKey, testValue);
wbm.addRecord(4, testKey, testValue);
result = wbm.addRecord(5, testKey, testValue);
assertEquals(6, result.size());
assertEquals(layout.asBlockId(1, 0, 0), layout.asBlockId(result.get(0).getBlockId()));
assertEquals(layout.asBlockId(0, 1, 0), layout.asBlockId(result.get(1).getBlockId()));
assertEquals(layout.asBlockId(0, 2, 0), layout.asBlockId(result.get(2).getBlockId()));
assertEquals(layout.asBlockId(0, 3, 0), layout.asBlockId(result.get(3).getBlockId()));
assertEquals(layout.asBlockId(0, 4, 0), layout.asBlockId(result.get(4).getBlockId()));
assertEquals(layout.asBlockId(0, 5, 0), layout.asBlockId(result.get(5).getBlockId()));
assertEquals(512, wbm.getAllocatedBytes());
assertEquals(288, wbm.getUsedBytes());
assertEquals(288, wbm.getInSendListBytes());
assertEquals(0, wbm.getBuffers().size());
// free memory
wbm.freeAllocatedMemory(96);
assertEquals(416, wbm.getAllocatedBytes());
assertEquals(192, wbm.getUsedBytes());
assertEquals(192, wbm.getInSendListBytes());

assertEquals(11, wbm.getShuffleWriteMetrics().recordsWritten());
assertTrue(wbm.getShuffleWriteMetrics().bytesWritten() > 0);

wbm.freeAllocatedMemory(192);
wbm.addRecord(0, testKey, testValue);
wbm.addRecord(1, testKey, testValue);
wbm.addRecord(2, testKey, testValue);
result = wbm.clear();
assertEquals(3, result.size());
assertEquals(224, wbm.getAllocatedBytes());
assertEquals(96, wbm.getUsedBytes());
assertEquals(96, wbm.getInSendListBytes());
}
wbm.setShuffleWriteMetrics(new ShuffleWriteMetrics());
String testKey = "Key";
String testValue = "Value";
List<ShuffleBlockInfo> result = wbm.addRecord(0, testKey, testValue);
// single buffer is not full, there is no data return
assertEquals(0, result.size());
assertEquals(512, wbm.getAllocatedBytes());
assertEquals(32, wbm.getUsedBytes());
assertEquals(0, wbm.getInSendListBytes());
assertEquals(1, wbm.getBuffers().size());
wbm.addRecord(0, testKey, testValue);
wbm.addRecord(0, testKey, testValue);
wbm.addRecord(0, testKey, testValue);
result = wbm.addRecord(0, testKey, testValue);
// single buffer is full
assertEquals(1, result.size());
assertEquals(512, wbm.getAllocatedBytes());
assertEquals(96, wbm.getUsedBytes());
assertEquals(96, wbm.getInSendListBytes());
assertEquals(0, wbm.getBuffers().size());
wbm.addRecord(0, testKey, testValue);
wbm.addRecord(1, testKey, testValue);
wbm.addRecord(2, testKey, testValue);
// single buffer is not full, and less than spill size
assertEquals(512, wbm.getAllocatedBytes());
assertEquals(192, wbm.getUsedBytes());
assertEquals(96, wbm.getInSendListBytes());
assertEquals(3, wbm.getBuffers().size());
// all buffer size > spill size
wbm.addRecord(3, testKey, testValue);
wbm.addRecord(4, testKey, testValue);
result = wbm.addRecord(5, testKey, testValue);
assertEquals(6, result.size());
assertEquals(512, wbm.getAllocatedBytes());
assertEquals(288, wbm.getUsedBytes());
assertEquals(288, wbm.getInSendListBytes());
assertEquals(0, wbm.getBuffers().size());
// free memory
wbm.freeAllocatedMemory(96);
assertEquals(416, wbm.getAllocatedBytes());
assertEquals(192, wbm.getUsedBytes());
assertEquals(192, wbm.getInSendListBytes());

assertEquals(11, wbm.getShuffleWriteMetrics().recordsWritten());
assertTrue(wbm.getShuffleWriteMetrics().bytesWritten() > 0);

wbm.freeAllocatedMemory(192);
wbm.addRecord(0, testKey, testValue);
wbm.addRecord(1, testKey, testValue);
wbm.addRecord(2, testKey, testValue);
result = wbm.clear();
assertEquals(3, result.size());
assertEquals(224, wbm.getAllocatedBytes());
assertEquals(96, wbm.getUsedBytes());
assertEquals(96, wbm.getInSendListBytes());
}
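For orientation, here is a minimal, self-contained sketch (not part of this commit; the class name LayoutSketch is hypothetical) of the bit arithmetic such a layout implies. It assumes a block id packs the sequence number, partition id, and task attempt id into 63 bits in that order from high to low, an assumption consistent with the constants the old createBlockIdTest below asserted:

public class LayoutSketch {
  public static void main(String[] args) {
    // the custom layout exercised above: 20 sequence, 21 partition, 22 task attempt bits
    int partitionIdBits = 21;
    int taskAttemptIdBits = 22;
    long seqNo = 1;
    long partitionId = 2;
    long taskAttemptId = 3;
    // compose: sequence number in the highest bits, task attempt id in the lowest
    long blockId = (seqNo << (partitionIdBits + taskAttemptIdBits))
        | (partitionId << taskAttemptIdBits)
        | taskAttemptId;
    // decompose again by shifting and masking
    long seq = blockId >> (partitionIdBits + taskAttemptIdBits);
    long part = (blockId >> taskAttemptIdBits) & ((1L << partitionIdBits) - 1);
    long task = blockId & ((1L << taskAttemptIdBits) - 1);
    System.out.printf("seq=%d part=%d task=%d%n", seq, part, task); // seq=1 part=2 task=3
  }
}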

@Test
Expand Down Expand Up @@ -223,26 +246,42 @@ public void addPartitionDataTest() {

@Test
public void createBlockIdTest() {
SparkConf conf = getConf();
WriteBufferManager wbm = createManager(conf);
WriterBuffer mockWriterBuffer = mock(WriterBuffer.class);
when(mockWriterBuffer.getData()).thenReturn(new byte[] {});
when(mockWriterBuffer.getMemoryUsed()).thenReturn(0);
ShuffleBlockInfo sbi = wbm.createShuffleBlock(0, mockWriterBuffer);
// seqNo = 0, partitionId = 0, taskId = 0
assertEquals(0L, sbi.getBlockId());

// seqNo = 1, partitionId = 0, taskId = 0
sbi = wbm.createShuffleBlock(0, mockWriterBuffer);
assertEquals(35184372088832L, sbi.getBlockId());

// seqNo = 0, partitionId = 1, taskId = 0
sbi = wbm.createShuffleBlock(1, mockWriterBuffer);
assertEquals(2097152L, sbi.getBlockId());

// seqNo = 1, partitionId = 1, taskId = 0
sbi = wbm.createShuffleBlock(1, mockWriterBuffer);
assertEquals(35184374185984L, sbi.getBlockId());
// test with different block id layouts
for (BlockIdLayout layout : new BlockIdLayout[] {
    BlockIdLayout.DEFAULT, BlockIdLayout.from(20, 21, 22)
}) {
// we should also test layouts that differ from the default
if (layout != BlockIdLayout.DEFAULT) {
assertNotEquals(layout, BlockIdLayout.DEFAULT);
}
SparkConf conf = getConf();
conf.set(
    RssSparkConfig.SPARK_RSS_CONFIG_PREFIX + RssClientConf.BLOCKID_SEQUENCE_NO_BITS.key(),
    String.valueOf(layout.sequenceNoBits));
conf.set(
    RssSparkConfig.SPARK_RSS_CONFIG_PREFIX + RssClientConf.BLOCKID_PARTITION_ID_BITS.key(),
    String.valueOf(layout.partitionIdBits));
conf.set(
    RssSparkConfig.SPARK_RSS_CONFIG_PREFIX + RssClientConf.BLOCKID_TASK_ATTEMPT_ID_BITS.key(),
    String.valueOf(layout.taskAttemptIdBits));

WriteBufferManager wbm = createManager(conf);
WriterBuffer mockWriterBuffer = mock(WriterBuffer.class);
when(mockWriterBuffer.getData()).thenReturn(new byte[] {});
when(mockWriterBuffer.getMemoryUsed()).thenReturn(0);
ShuffleBlockInfo sbi = wbm.createShuffleBlock(0, mockWriterBuffer);

String layoutString = layout.toString();

// seqNo = 0, partitionId = 0, taskId = 0
assertEquals(layout.asBlockId(0, 0, 0), layout.asBlockId(sbi.getBlockId()), layoutString);

// seqNo = 1, partitionId = 0, taskId = 0
sbi = wbm.createShuffleBlock(0, mockWriterBuffer);
assertEquals(layout.asBlockId(1, 0, 0), layout.asBlockId(sbi.getBlockId()), layoutString);

// seqNo = 0, partitionId = 1, taskId = 0
sbi = wbm.createShuffleBlock(1, mockWriterBuffer);
assertEquals(layout.asBlockId(0, 1, 0), layout.asBlockId(sbi.getBlockId()), layoutString);

// seqNo = 1, partitionId = 1, taskId = 0
sbi = wbm.createShuffleBlock(1, mockWriterBuffer);
assertEquals(layout.asBlockId(1, 1, 0), layout.asBlockId(sbi.getBlockId()), layoutString);
}
}
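The hard-coded long constants that this test previously asserted follow directly from the default layout's bit widths. A quick check of that arithmetic (the class name DefaultLayoutConstants is hypothetical; it assumes BlockIdLayout.DEFAULT uses 18 sequence-number bits, 24 partition-id bits, and 21 task-attempt-id bits, which is what the old constants imply):

public class DefaultLayoutConstants {
  public static void main(String[] args) {
    int partitionIdBits = 24;
    int taskAttemptIdBits = 21;
    // seqNo = 1, partitionId = 0, taskId = 0: shifted past partition and task attempt bits
    System.out.println(1L << (partitionIdBits + taskAttemptIdBits)); // 35184372088832
    // seqNo = 0, partitionId = 1, taskId = 0: shifted past task attempt bits only
    System.out.println(1L << taskAttemptIdBits); // 2097152
    // seqNo = 1, partitionId = 1, taskId = 0: the sum of both
    System.out.println((1L << 45) + (1L << 21)); // 35184374185984
  }
}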

@Test
@@ -465,7 +465,7 @@ public <K, V> ShuffleWriter<K, V> getWriter(
rssHandle.getAppId(),
shuffleId,
taskId,
context.taskAttemptId(),
getTaskAttemptId(context.partitionId(), context.attemptNumber(), context.taskAttemptId()),
writeMetrics,
this,
sparkConf,
@@ -479,6 +479,12 @@ public <K, V> ShuffleWriter<K, V> getWriter(
}
}

@Override
@VisibleForTesting
public long getTaskAttemptId(int mapIndex, int attemptNo, long taskAttemptId) {
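// Spark's task attempt id is already unique within the application,
// hence unique within a shuffle stage, so it is passed through unchanged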
return taskAttemptId;
}

// This method is called in Spark executor,
// getting information from Spark driver via the ShuffleHandle.
@Override
@@ -510,18 +510,11 @@ public <K, V> ShuffleWriter<K, V> getWriter(
}
String taskId = "" + context.taskAttemptId() + "_" + context.attemptNumber();
LOG.info("RssHandle appId {} shuffleId {} ", rssHandle.getAppId(), rssHandle.getShuffleId());
long taskAttemptId =
getTaskAttemptId(
context.partitionId(),
context.attemptNumber(),
maxFailures,
speculation,
blockIdLayout.taskAttemptIdBits);
return new RssShuffleWriter<>(
rssHandle.getAppId(),
shuffleId,
taskId,
taskAttemptId,
getTaskAttemptId(context.partitionId(), context.attemptNumber(), context.taskAttemptId()),
writeMetrics,
this,
sparkConf,
@@ -532,6 +525,23 @@ public <K, V> ShuffleWriter<K, V> getWriter(
shuffleHandleInfo);
}

/**
 * Provides a task attempt id that is unique for a shuffle stage.
 *
 * <p>For details, see the overloaded method {@code getTaskAttemptId(int, int, int, boolean, int)}.
 * The task attempt id provided by Spark is not used here; a stage-unique id is derived from the
 * map index and attempt number instead.
 *
 * @param mapIndex the map index (partition id) of the task
 * @param attemptNo the attempt number of the task
 * @param taskAttemptId the task attempt id provided by Spark (unused)
 * @return a task attempt id unique for a shuffle stage
 */
@VisibleForTesting
public long getTaskAttemptId(int mapIndex, int attemptNo, long taskAttemptId) {
return getTaskAttemptId(
    mapIndex, attemptNo, maxFailures, speculation, blockIdLayout.taskAttemptIdBits);
}
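The body of the five-argument overload is not fully shown in this hunk. As a rough sketch of the idea only (an assumption about, not a copy of, the actual implementation in RssShuffleManagerBase; the class name TaskAttemptIdSketch is hypothetical): reserve enough low bits for the highest possible attempt number, determined by maxFailures and speculation, place the map index above them, and fail if the result does not fit into the layout's task attempt id bits.

public class TaskAttemptIdSketch {
  static long taskAttemptId(
      int mapIndex, int attemptNo, int maxFailures, boolean speculation, int maxBits) {
    // the highest attempt number any task can reach under this configuration
    int maxAttemptNo = Math.max(0, maxFailures - 1) + (speculation ? 1 : 0);
    // number of low bits reserved for the attempt number
    int attemptBits = 32 - Integer.numberOfLeadingZeros(Math.max(1, maxAttemptNo));
    long id = ((long) mapIndex << attemptBits) | attemptNo;
    // the id must fit into the bits the block id layout reserves for it
    if (64 - Long.numberOfLeadingZeros(id) > maxBits) {
      throw new IllegalArgumentException("task attempt id does not fit into " + maxBits + " bits");
    }
    return id;
  }

  public static void main(String[] args) {
    // mapIndex = 5, attemptNo = 1, maxFailures = 4, no speculation:
    // 2 attempt bits, so id = (5 << 2) | 1 = 21
    System.out.println(taskAttemptId(5, 1, 4, false, 21));
  }
}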

/**
* Provides a task attempt id that is unique for a shuffle stage.
*
@@ -26,7 +26,6 @@
import com.google.common.collect.Lists;
import com.google.common.collect.Queues;
import org.apache.hadoop.conf.Configuration;
import org.apache.uniffle.common.util.BlockIdLayout;
import org.roaringbitmap.longlong.Roaring64NavigableMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -42,6 +41,7 @@
import org.apache.uniffle.common.config.RssClientConf;
import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.exception.RssFetchFailedException;
import org.apache.uniffle.common.util.BlockIdLayout;
import org.apache.uniffle.common.util.ChecksumUtils;
import org.apache.uniffle.common.util.IdHelper;
import org.apache.uniffle.common.util.RssUtils;
(diffs of the remaining changed files not shown)