diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssFetchFailedIterator.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssFetchFailedIterator.java
index c394f510bb..9146352dac 100644
--- a/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssFetchFailedIterator.java
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssFetchFailedIterator.java
@@ -17,8 +17,8 @@
 
 package org.apache.spark.shuffle.reader;
 
-import java.io.IOException;
 import java.util.Objects;
+import java.util.function.Function;
 
 import scala.Product2;
 import scala.collection.AbstractIterator;
@@ -29,11 +29,8 @@
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import org.apache.uniffle.client.api.ShuffleManagerClient;
-import org.apache.uniffle.client.factory.ShuffleManagerClientFactory;
 import org.apache.uniffle.client.request.RssReportShuffleFetchFailureRequest;
 import org.apache.uniffle.client.response.RssReportShuffleFetchFailureResponse;
-import org.apache.uniffle.common.ClientType;
 import org.apache.uniffle.common.exception.RssException;
 import org.apache.uniffle.common.exception.RssFetchFailedException;
 
@@ -52,8 +49,8 @@ public static class Builder {
     private int shuffleId;
     private int partitionId;
     private int stageAttemptId;
-    private String reportServerHost;
-    private int reportServerPort;
+    private Function<RssReportShuffleFetchFailureRequest, RssReportShuffleFetchFailureResponse>
+        sendReportFunc;
 
     private Builder() {}
 
@@ -77,19 +74,15 @@ Builder stageAttemptId(int stageAttemptId) {
       return this;
     }
 
-    Builder reportServerHost(String host) {
-      this.reportServerHost = host;
-      return this;
-    }
-
-    Builder port(int port) {
-      this.reportServerPort = port;
+    Builder doReportFun(
+        Function<RssReportShuffleFetchFailureRequest, RssReportShuffleFetchFailureResponse>
+            doReportFun) {
+      this.sendReportFunc = doReportFun;
       return this;
     }
 
     RssFetchFailedIterator<K, C> build(Iterator<Product2<K, C>> iter) {
       Objects.requireNonNull(this.appId);
-      Objects.requireNonNull(this.reportServerHost);
       return new RssFetchFailedIterator<>(this, iter);
     }
   }
@@ -98,37 +91,25 @@ static Builder newBuilder() {
     return new Builder();
   }
 
-  private static ShuffleManagerClient createShuffleManagerClient(String host, int port)
-      throws IOException {
-    ClientType grpc = ClientType.GRPC;
-    // host is passed from spark.driver.bindAddress, which would be set when SparkContext is
-    // constructed.
-    return ShuffleManagerClientFactory.getInstance().createShuffleManagerClient(grpc, host, port);
-  }
-
-  private RssException generateFetchFailedIfNecessary(RssFetchFailedException e) {
-    String driver = builder.reportServerHost;
-    int port = builder.reportServerPort;
-    // todo: reuse this manager client if this is a bottleneck.
-    try (ShuffleManagerClient client = createShuffleManagerClient(driver, port)) {
-      RssReportShuffleFetchFailureRequest req =
-          new RssReportShuffleFetchFailureRequest(
-              builder.appId,
-              builder.shuffleId,
-              builder.stageAttemptId,
-              builder.partitionId,
-              e.getMessage());
-      RssReportShuffleFetchFailureResponse response = client.reportShuffleFetchFailure(req);
-      if (response.getReSubmitWholeStage()) {
-        // since we are going to roll out the whole stage, mapIndex shouldn't matter, hence -1 is
-        // provided.
-        FetchFailedException ffe =
-            RssSparkShuffleUtils.createFetchFailedException(
-                builder.shuffleId, -1, builder.partitionId, e);
-        return new RssException(ffe);
-      }
-    } catch (IOException ioe) {
-      LOG.info("Error closing shuffle manager client with error:", ioe);
+  private RssException generateFetchFailedIfNecessary(
+      RssFetchFailedException e,
+      Function<RssReportShuffleFetchFailureRequest, RssReportShuffleFetchFailureResponse>
+          doReportFun) {
+    RssReportShuffleFetchFailureRequest req =
+        new RssReportShuffleFetchFailureRequest(
+            builder.appId,
+            builder.shuffleId,
+            builder.stageAttemptId,
+            builder.partitionId,
+            e.getMessage());
+    RssReportShuffleFetchFailureResponse response = doReportFun.apply(req);
+    if (response.getReSubmitWholeStage()) {
+      // since we are going to roll out the whole stage, mapIndex shouldn't matter, hence -1 is
+      // provided.
+      FetchFailedException ffe =
+          RssSparkShuffleUtils.createFetchFailedException(
+              builder.shuffleId, -1, builder.partitionId, e);
+      return new RssException(ffe);
     }
     return e;
   }
@@ -138,7 +119,7 @@ public boolean hasNext() {
     try {
       return this.iter.hasNext();
     } catch (RssFetchFailedException e) {
-      throw generateFetchFailedIfNecessary(e);
+      throw generateFetchFailedIfNecessary(e, builder.sendReportFunc);
     }
   }
 
@@ -147,7 +128,7 @@ public Product2 next() {
     try {
       return this.iter.next();
     } catch (RssFetchFailedException e) {
-      throw generateFetchFailedIfNecessary(e);
+      throw generateFetchFailedIfNecessary(e, builder.sendReportFunc);
     }
   }
 }
diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 1d50507904..2ecb514db1 100644
--- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -534,6 +534,7 @@ public ShuffleWriter getWriter(
         this,
         sparkConf,
         shuffleWriteClient,
+        shuffleManagerClient,
         rssHandle,
         this::markFailedTask,
         context,
@@ -721,6 +722,7 @@ public ShuffleReader getReaderImpl(
             blockIdBitmap, startPartition, endPartition, blockIdLayout),
         taskIdBitmap,
         readMetrics,
+        shuffleManagerClient,
         RssSparkConfig.toRssConf(sparkConf),
         dataDistributionType,
         shuffleHandleInfo.getAllPartitionServersForReader());
diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
index bf47ced6be..8df82eda14 100644
--- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
+++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
@@ -49,6 +49,7 @@
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import org.apache.uniffle.client.api.ShuffleManagerClient;
 import org.apache.uniffle.client.api.ShuffleReadClient;
 import org.apache.uniffle.client.factory.ShuffleClientFactory;
 import org.apache.uniffle.client.util.RssClientConfig;
@@ -58,7 +59,6 @@
 import org.apache.uniffle.common.config.RssConf;
 
 import static org.apache.spark.shuffle.RssSparkConfig.RSS_RESUBMIT_STAGE_WITH_FETCH_FAILURE_ENABLED;
-import static org.apache.uniffle.common.util.Constants.DRIVER_HOST;
 
 public class RssShuffleReader<K, C> implements ShuffleReader<K, C> {
   private static final Logger LOG = LoggerFactory.getLogger(RssShuffleReader.class);
@@ -83,6 +83,7 @@ public class RssShuffleReader implements ShuffleReader {
   private ShuffleReadMetrics readMetrics;
   private RssConf rssConf;
   private ShuffleDataDistributionType dataDistributionType;
+  private ShuffleManagerClient shuffleManagerClient;
 
   public RssShuffleReader(
       int startPartition,
@@ -97,6 +98,7 @@ public RssShuffleReader(
       Map<Integer, Roaring64NavigableMap> partitionToExpectBlocks,
       Roaring64NavigableMap taskIdBitmap,
       ShuffleReadMetrics readMetrics,
+      ShuffleManagerClient shuffleManagerClient,
       RssConf rssConf,
       ShuffleDataDistributionType dataDistributionType,
       Map<Integer, List<ShuffleServerInfo>> allPartitionToServers) {
@@ -120,6 +122,7 @@ public RssShuffleReader(
     this.partitionToShuffleServers = allPartitionToServers;
     this.rssConf = rssConf;
     this.dataDistributionType = dataDistributionType;
+    this.shuffleManagerClient = shuffleManagerClient;
   }
 
   @Override
@@ -193,16 +196,13 @@ public Void apply(TaskContext context) {
         // resubmit stage and shuffle manager server port are both set
         if (rssConf.getBoolean(RSS_RESUBMIT_STAGE_WITH_FETCH_FAILURE_ENABLED)
             && rssConf.getInteger(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT, 0) > 0) {
-          String driver = rssConf.getString(DRIVER_HOST, "");
-          int port = rssConf.get(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT);
           resultIter =
               RssFetchFailedIterator.newBuilder()
                   .appId(appId)
                   .shuffleId(shuffleId)
                   .partitionId(startPartition)
                   .stageAttemptId(context.stageAttemptNumber())
-                  .reportServerHost(driver)
-                  .port(port)
+                  .doReportFun(shuffleManagerClient::reportShuffleFetchFailure)
                   .build(resultIter);
         }
         return resultIter;
diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index 3dfc2fd620..a4b68c7f5b 100644
--- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -138,6 +138,8 @@ public class RssShuffleWriter extends ShuffleWriter {
   private static final Set<StatusCode> STATUS_CODE_WITHOUT_BLOCK_RESEND =
       Sets.newHashSet(StatusCode.NO_REGISTER);
 
+  private ShuffleManagerClient shuffleManagerClient;
+
   // Only for tests
   @VisibleForTesting
   public RssShuffleWriter(
@@ -150,6 +152,7 @@ public RssShuffleWriter(
       RssShuffleManager shuffleManager,
       SparkConf sparkConf,
       ShuffleWriteClient shuffleWriteClient,
+      ShuffleManagerClient shuffleManagerClient,
       RssShuffleHandle<K, V, C> rssHandle,
       ShuffleHandleInfo shuffleHandleInfo,
       TaskContext context) {
@@ -162,6 +165,7 @@ public RssShuffleWriter(
         shuffleManager,
         sparkConf,
         shuffleWriteClient,
+        shuffleManagerClient,
         rssHandle,
         (tid) -> true,
         shuffleHandleInfo,
@@ -179,6 +183,7 @@ private RssShuffleWriter(
       RssShuffleManager shuffleManager,
       SparkConf sparkConf,
       ShuffleWriteClient shuffleWriteClient,
+      ShuffleManagerClient shuffleManagerClient,
      RssShuffleHandle<K, V, C> rssHandle,
      Function<String, Boolean> taskFailureCallback,
      ShuffleHandleInfo shuffleHandleInfo,
@@ -211,6 +216,7 @@ private RssShuffleWriter(
     this.taskFailureCallback = taskFailureCallback;
     this.taskContext = context;
     this.sparkConf = sparkConf;
+    this.shuffleManagerClient = shuffleManagerClient;
     this.blockFailSentRetryEnabled =
         sparkConf.getBoolean(
             RssSparkConfig.SPARK_RSS_CONFIG_PREFIX
@@ -229,6 +235,7 @@ public RssShuffleWriter(
       RssShuffleManager shuffleManager,
       SparkConf sparkConf,
       ShuffleWriteClient shuffleWriteClient,
+      ShuffleManagerClient shuffleManagerClient,
       RssShuffleHandle<K, V, C> rssHandle,
       Function<String, Boolean> taskFailureCallback,
       TaskContext context,
@@ -242,6 +249,7 @@ public RssShuffleWriter(
         shuffleManager,
         sparkConf,
         shuffleWriteClient,
+        shuffleManagerClient,
         rssHandle,
         taskFailureCallback,
         shuffleHandleInfo,
@@ -835,33 +843,26 @@ private void throwFetchFailedIfNecessary(Exception e) {
               taskContext.stageAttemptNumber(),
               shuffleServerInfos,
               e.getMessage());
-      RssConf rssConf = RssSparkConfig.toRssConf(sparkConf);
-      String driver = rssConf.getString("driver.host", "");
-      int port = rssConf.get(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT);
-      try (ShuffleManagerClient shuffleManagerClient = createShuffleManagerClient(driver, port)) {
-        RssReportShuffleWriteFailureResponse response =
-            shuffleManagerClient.reportShuffleWriteFailure(req);
-        if (response.getReSubmitWholeStage()) {
-          RssReassignServersRequest rssReassignServersRequest =
-              new RssReassignServersRequest(
-                  taskContext.stageId(),
-                  taskContext.stageAttemptNumber(),
-                  shuffleId,
-                  partitioner.numPartitions());
-          RssReassignServersResponse rssReassignServersResponse =
-              shuffleManagerClient.reassignOnStageResubmit(rssReassignServersRequest);
-          LOG.info(
-              "Whether the reassignment is successful: {}",
-              rssReassignServersResponse.isNeedReassign());
-          // since we are going to roll out the whole stage, mapIndex shouldn't matter, hence -1 is
-          // provided.
-          FetchFailedException ffe =
-              RssSparkShuffleUtils.createFetchFailedException(
-                  shuffleId, -1, taskContext.stageAttemptNumber(), e);
-          throw new RssException(ffe);
-        }
-      } catch (IOException ioe) {
-        LOG.info("Error closing shuffle manager client with error:", ioe);
+      RssReportShuffleWriteFailureResponse response =
+          shuffleManagerClient.reportShuffleWriteFailure(req);
+      if (response.getReSubmitWholeStage()) {
+        RssReassignServersRequest rssReassignServersRequest =
+            new RssReassignServersRequest(
+                taskContext.stageId(),
+                taskContext.stageAttemptNumber(),
+                shuffleId,
+                partitioner.numPartitions());
+        RssReassignServersResponse rssReassignServersResponse =
+            shuffleManagerClient.reassignOnStageResubmit(rssReassignServersRequest);
+        LOG.info(
+            "Whether the reassignment is successful: {}",
+            rssReassignServersResponse.isNeedReassign());
+        // since we are going to roll out the whole stage, mapIndex shouldn't matter, hence -1 is
+        // provided.
+        FetchFailedException ffe =
+            RssSparkShuffleUtils.createFetchFailedException(
+                shuffleId, -1, taskContext.stageAttemptNumber(), e);
+        throw new RssException(ffe);
       }
     }
     throw new RssException(e);
diff --git a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java
index aaff4cb8e0..bda3a94b6c 100644
--- a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java
+++ b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java
@@ -36,6 +36,7 @@
 import org.junit.jupiter.api.Test;
 import org.roaringbitmap.longlong.Roaring64NavigableMap;
 
+import org.apache.uniffle.client.api.ShuffleManagerClient;
 import org.apache.uniffle.common.ShuffleDataDistributionType;
 import org.apache.uniffle.common.ShuffleServerInfo;
 import org.apache.uniffle.common.config.RssClientConf;
@@ -93,6 +94,7 @@ public void readTest() throws Exception {
     rssConf.set(RssClientConf.RSS_STORAGE_TYPE, StorageType.HDFS.name());
     rssConf.set(RssClientConf.RSS_INDEX_READ_LIMIT, 1000);
     rssConf.set(RssClientConf.RSS_CLIENT_READ_BUFFER_SIZE, "1000");
+    ShuffleManagerClient mockShuffleManagerClient = mock(ShuffleManagerClient.class);
     RssShuffleReader rssShuffleReaderSpy =
         spy(
             new RssShuffleReader<>(
@@ -108,6 +110,7 @@ public void readTest() throws Exception {
                 partitionToExpectBlocks,
                 taskIdBitmap,
                 new ShuffleReadMetrics(),
+                mockShuffleManagerClient,
                 rssConf,
                 ShuffleDataDistributionType.NORMAL,
                 partitionToServers));
@@ -131,6 +134,7 @@ public void readTest() throws Exception {
                 partitionToExpectBlocks,
                 taskIdBitmap,
                 new ShuffleReadMetrics(),
+                mockShuffleManagerClient,
                 rssConf,
                 ShuffleDataDistributionType.NORMAL,
                 partitionToServers));
@@ -151,6 +155,7 @@ public void readTest() throws Exception {
                 partitionToExpectBlocks,
                 Roaring64NavigableMap.bitmapOf(),
                 new ShuffleReadMetrics(),
+                mockShuffleManagerClient,
                 rssConf,
                 ShuffleDataDistributionType.NORMAL,
                 partitionToServers));
diff --git a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
index 53a8e71437..2ed13eb240 100644
--- a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
+++ b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
@@ -55,6 +55,7 @@
 import org.awaitility.Awaitility;
 import org.junit.jupiter.api.Test;
 
+import org.apache.uniffle.client.api.ShuffleManagerClient;
 import org.apache.uniffle.client.api.ShuffleWriteClient;
 import org.apache.uniffle.client.impl.FailedBlockSendTracker;
 import org.apache.uniffle.common.RemoteStorageInfo;
@@ -133,6 +134,7 @@ private RssShuffleWriter createMockWriter(MutableShuffleHandleInfo shuffleHandle
     Serializer kryoSerializer = new KryoSerializer(conf);
     Partitioner mockPartitioner = mock(Partitioner.class);
     final ShuffleWriteClient mockShuffleWriteClient = mock(ShuffleWriteClient.class);
+    final ShuffleManagerClient mockShuffleManagerClient = mock(ShuffleManagerClient.class);
     ShuffleDependency mockDependency = mock(ShuffleDependency.class);
     RssShuffleHandle mockHandle = mock(RssShuffleHandle.class);
     when(mockHandle.getDependency()).thenReturn(mockDependency);
@@ -179,6 +181,7 @@ private RssShuffleWriter createMockWriter(MutableShuffleHandleInfo shuffleHandle
             manager,
             conf,
             mockShuffleWriteClient,
+            mockShuffleManagerClient,
             mockHandle,
             shuffleHandle,
             contextMock);
@@ -385,6 +388,7 @@ public void blockFailureResendTest() {
     Serializer kryoSerializer = new KryoSerializer(conf);
     Partitioner mockPartitioner = mock(Partitioner.class);
     final ShuffleWriteClient mockShuffleWriteClient = mock(ShuffleWriteClient.class);
+    final ShuffleManagerClient mockShuffleManagerClient = mock(ShuffleManagerClient.class);
     ShuffleDependency mockDependency = mock(ShuffleDependency.class);
     RssShuffleHandle mockHandle = mock(RssShuffleHandle.class);
     when(mockHandle.getDependency()).thenReturn(mockDependency);
@@ -450,6 +454,7 @@ public void blockFailureResendTest() {
             manager,
             conf,
             mockShuffleWriteClient,
+            mockShuffleManagerClient,
             mockHandle,
             shuffleHandleInfo,
             contextMock);
@@ -552,6 +557,7 @@ public void checkBlockSendResultTest() {
             conf, false, null, successBlocks, taskToFailedBlockSendTracker);
 
     ShuffleWriteClient mockShuffleWriteClient = mock(ShuffleWriteClient.class);
+    ShuffleManagerClient mockShuffleManagerClient = mock(ShuffleManagerClient.class);
     Partitioner mockPartitioner = mock(Partitioner.class);
     RssShuffleHandle mockHandle = mock(RssShuffleHandle.class);
     ShuffleDependency mockDependency = mock(ShuffleDependency.class);
@@ -587,6 +593,7 @@ public void checkBlockSendResultTest() {
             manager,
             conf,
             mockShuffleWriteClient,
+            mockShuffleManagerClient,
             mockHandle,
             mockShuffleHandleInfo,
             contextMock);
@@ -714,6 +721,7 @@ public void dataConsistencyWhenSpillTriggeredTest() throws Exception {
     Serializer kryoSerializer = new KryoSerializer(conf);
     Partitioner mockPartitioner = mock(Partitioner.class);
     final ShuffleWriteClient mockShuffleWriteClient = mock(ShuffleWriteClient.class);
+    final ShuffleManagerClient mockShuffleManagerClient = mock(ShuffleManagerClient.class);
     ShuffleDependency mockDependency = mock(ShuffleDependency.class);
     RssShuffleHandle mockHandle = mock(RssShuffleHandle.class);
     when(mockHandle.getDependency()).thenReturn(mockDependency);
@@ -734,6 +742,7 @@ public void dataConsistencyWhenSpillTriggeredTest() throws Exception {
             manager,
             conf,
             mockShuffleWriteClient,
+            mockShuffleManagerClient,
             mockHandle,
             mockShuffleHandleInfo,
             contextMock);
@@ -794,6 +803,7 @@ public void writeTest() throws Exception {
     Serializer kryoSerializer = new KryoSerializer(conf);
     Partitioner mockPartitioner = mock(Partitioner.class);
     final ShuffleWriteClient mockShuffleWriteClient = mock(ShuffleWriteClient.class);
+    final ShuffleManagerClient mockShuffleManagerClient = mock(ShuffleManagerClient.class);
     ShuffleDependency mockDependency = mock(ShuffleDependency.class);
     RssShuffleHandle mockHandle = mock(RssShuffleHandle.class);
     when(mockHandle.getDependency()).thenReturn(mockDependency);
@@ -857,6 +867,7 @@ public void writeTest() throws Exception {
             manager,
             conf,
             mockShuffleWriteClient,
+            mockShuffleManagerClient,
             mockHandle,
             mockShuffleHandleInfo,
             contextMock);
@@ -958,6 +969,7 @@ public void postBlockEventTest() throws Exception {
     TaskContext contextMock = mock(TaskContext.class);
     SimpleShuffleHandleInfo mockShuffleHandleInfo = mock(SimpleShuffleHandleInfo.class);
     ShuffleWriteClient mockWriteClient = mock(ShuffleWriteClient.class);
+    ShuffleManagerClient mockShuffleManagerClient = mock(ShuffleManagerClient.class);
     List shuffleBlockInfoList = createShuffleBlockList(1, 31);
 
     RssShuffleWriter writer =
@@ -971,6 +983,7 @@
             mockShuffleManager,
             conf,
             mockWriteClient,
+            mockShuffleManagerClient,
             mockHandle,
             mockShuffleHandleInfo,
             contextMock);
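
Note (illustrative sketch, not part of the patch): the change replaces the per-call gRPC plumbing (driver host/port plus ShuffleManagerClientFactory) with an injected ShuffleManagerClient, and RssFetchFailedIterator now only sees a Function<RssReportShuffleFetchFailureRequest, RssReportShuffleFetchFailureResponse> callback. The sketch below shows the two wirings the diff enables; it assumes it lives in org.apache.spark.shuffle.reader (newBuilder() and the builder setters are package-private) and the literal builder values are hypothetical.

package org.apache.spark.shuffle.reader;

import static org.mockito.Mockito.mock;

import org.apache.uniffle.client.api.ShuffleManagerClient;

// Sketch only: shows how the new doReportFun hook is wired, and why it makes the
// fetch-failure report path testable without opening a gRPC channel to the driver.
class FetchFailureReportWiringSketch {

  // Production-style wiring, as RssShuffleReader now does: hand the iterator a method
  // reference onto the shared ShuffleManagerClient instead of a host/port to dial itself.
  static RssFetchFailedIterator.Builder production(ShuffleManagerClient shuffleManagerClient) {
    return RssFetchFailedIterator.newBuilder()
        .appId("application_123") // hypothetical values
        .shuffleId(0)
        .partitionId(0)
        .stageAttemptId(0)
        .doReportFun(shuffleManagerClient::reportShuffleFetchFailure);
  }

  // Test-style wiring: a Mockito mock (or any lambda with the Function shape above)
  // stands in for the report RPC, mirroring the mocks added to the tests in this patch.
  static RssFetchFailedIterator.Builder forTest() {
    ShuffleManagerClient mockClient = mock(ShuffleManagerClient.class);
    return production(mockClient);
  }
}

Because the iterator now depends only on the callback, the java.io.IOException, ClientType, and ShuffleManagerClientFactory imports could be dropped above, and callers decide how the report reaches the shuffle manager.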