From 0bb986cd2f8a1fa43cb0029774e4df462e40fd3e Mon Sep 17 00:00:00 2001 From: "jam.xu" Date: Wed, 24 Jul 2024 14:11:40 +0800 Subject: [PATCH] add CloseStateful interface --- .../spark/shuffle/RssSparkShuffleUtils.java | 4 +- .../reader/RssFetchFailedIterator.java | 43 +++---- .../BlockIdSelfManagedShuffleWriteClient.java | 4 +- .../shuffle/RssShuffleClientFactory.java | 10 +- .../manager/RssShuffleManagerBase.java | 19 +-- .../spark/shuffle/RssShuffleManager.java | 2 +- .../shuffle/reader/RssShuffleReader.java | 6 +- .../shuffle/writer/RssShuffleWriter.java | 10 +- .../shuffle/reader/RssShuffleReaderTest.java | 4 +- .../shuffle/writer/RssShuffleWriterTest.java | 8 +- .../spark/shuffle/RssShuffleManager.java | 2 +- .../shuffle/reader/RssShuffleReader.java | 6 +- .../shuffle/writer/RssShuffleWriter.java | 10 +- .../shuffle/reader/RssShuffleReaderTest.java | 8 +- .../shuffle/writer/RssShuffleWriterTest.java | 14 +-- .../uniffle/common/util/CloseStateful.java | 25 ++++ ...er.java => ExpiringCloseableSupplier.java} | 53 +++++--- .../util/ExpireCloseableSupplierTest.java | 82 ------------- .../util/ExpiringCloseableSupplierTest.java | 116 ++++++++++++++++++ .../client/api/ShuffleManagerClient.java | 5 +- .../impl/grpc/ShuffleManagerGrpcClient.java | 5 + 21 files changed, 260 insertions(+), 176 deletions(-) create mode 100644 common/src/main/java/org/apache/uniffle/common/util/CloseStateful.java rename common/src/main/java/org/apache/uniffle/common/util/{ExpireCloseableSupplier.java => ExpiringCloseableSupplier.java} (57%) delete mode 100644 common/src/test/java/org/apache/uniffle/common/util/ExpireCloseableSupplierTest.java create mode 100644 common/src/test/java/org/apache/uniffle/common/util/ExpiringCloseableSupplierTest.java diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkShuffleUtils.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkShuffleUtils.java index 43f2d0aaf3..feee2a3312 100644 --- a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkShuffleUtils.java +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkShuffleUtils.java @@ -24,6 +24,7 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.function.Supplier; import scala.Option; import scala.reflect.ClassTag; @@ -52,7 +53,6 @@ import org.apache.uniffle.common.exception.RssException; import org.apache.uniffle.common.exception.RssFetchFailedException; import org.apache.uniffle.common.util.Constants; -import org.apache.uniffle.common.util.ExpireCloseableSupplier; import static org.apache.spark.shuffle.RssSparkConfig.RSS_RESUBMIT_STAGE_WITH_FETCH_FAILURE_ENABLED; @@ -343,7 +343,7 @@ public static boolean isStageResubmitSupported() { } public static RssException reportRssFetchFailedException( - ExpireCloseableSupplier managerClientSupplier, + Supplier managerClientSupplier, RssFetchFailedException rssFetchFailedException, SparkConf sparkConf, String appId, diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssFetchFailedIterator.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssFetchFailedIterator.java index 9873b07df9..1bc61dc746 100644 --- a/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssFetchFailedIterator.java +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssFetchFailedIterator.java @@ -17,8 +17,8 @@ package org.apache.spark.shuffle.reader; -import java.io.IOException; import java.util.Objects; +import java.util.function.Supplier; import scala.Product2; import scala.collection.AbstractIterator; @@ -34,7 +34,6 @@ import org.apache.uniffle.client.response.RssReportShuffleFetchFailureResponse; import org.apache.uniffle.common.exception.RssException; import org.apache.uniffle.common.exception.RssFetchFailedException; -import org.apache.uniffle.common.util.ExpireCloseableSupplier; public class RssFetchFailedIterator extends AbstractIterator> { private static final Logger LOG = LoggerFactory.getLogger(RssFetchFailedIterator.class); @@ -51,7 +50,7 @@ public static class Builder { private int shuffleId; private int partitionId; private int stageAttemptId; - private ExpireCloseableSupplier managerClientSupplier; + private Supplier managerClientSupplier; private Builder() {} @@ -75,8 +74,7 @@ Builder stageAttemptId(int stageAttemptId) { return this; } - Builder managerClientSupplier( - ExpireCloseableSupplier managerClientSupplier) { + Builder managerClientSupplier(Supplier managerClientSupplier) { this.managerClientSupplier = managerClientSupplier; return this; } @@ -92,25 +90,22 @@ static Builder newBuilder() { } private RssException generateFetchFailedIfNecessary(RssFetchFailedException e) { - try (ShuffleManagerClient client = builder.managerClientSupplier.get()) { - RssReportShuffleFetchFailureRequest req = - new RssReportShuffleFetchFailureRequest( - builder.appId, - builder.shuffleId, - builder.stageAttemptId, - builder.partitionId, - e.getMessage()); - RssReportShuffleFetchFailureResponse response = client.reportShuffleFetchFailure(req); - if (response.getReSubmitWholeStage()) { - // since we are going to roll out the whole stage, mapIndex shouldn't matter, hence -1 is - // provided. - FetchFailedException ffe = - RssSparkShuffleUtils.createFetchFailedException( - builder.shuffleId, -1, builder.partitionId, e); - return new RssException(ffe); - } - } catch (IOException ioe) { - LOG.info("Error closing shuffle manager client with error:", ioe); + ShuffleManagerClient client = builder.managerClientSupplier.get(); + RssReportShuffleFetchFailureRequest req = + new RssReportShuffleFetchFailureRequest( + builder.appId, + builder.shuffleId, + builder.stageAttemptId, + builder.partitionId, + e.getMessage()); + RssReportShuffleFetchFailureResponse response = client.reportShuffleFetchFailure(req); + if (response.getReSubmitWholeStage()) { + // since we are going to roll out the whole stage, mapIndex shouldn't matter, hence -1 is + // provided. + FetchFailedException ffe = + RssSparkShuffleUtils.createFetchFailedException( + builder.shuffleId, -1, builder.partitionId, e); + return new RssException(ffe); } return e; } diff --git a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/BlockIdSelfManagedShuffleWriteClient.java b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/BlockIdSelfManagedShuffleWriteClient.java index 3100fa9ca0..93aa3f0fc0 100644 --- a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/BlockIdSelfManagedShuffleWriteClient.java +++ b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/BlockIdSelfManagedShuffleWriteClient.java @@ -22,6 +22,7 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.function.Supplier; import java.util.stream.Collectors; import org.roaringbitmap.longlong.Roaring64NavigableMap; @@ -35,14 +36,13 @@ import org.apache.uniffle.common.ShuffleServerInfo; import org.apache.uniffle.common.exception.RssException; import org.apache.uniffle.common.util.BlockIdLayout; -import org.apache.uniffle.common.util.ExpireCloseableSupplier; /** * This class delegates the blockIds reporting/getting operations from shuffleServer side to Spark * driver side. */ public class BlockIdSelfManagedShuffleWriteClient extends ShuffleWriteClientImpl { - private ExpireCloseableSupplier managerClientSupplier; + private Supplier managerClientSupplier; public BlockIdSelfManagedShuffleWriteClient( RssShuffleClientFactory.ExtendWriteClientBuilder builder) { diff --git a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/RssShuffleClientFactory.java b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/RssShuffleClientFactory.java index 9323c7aec4..bad10ab72a 100644 --- a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/RssShuffleClientFactory.java +++ b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/RssShuffleClientFactory.java @@ -17,11 +17,12 @@ package org.apache.uniffle.shuffle; +import java.util.function.Supplier; + import org.apache.uniffle.client.api.ShuffleManagerClient; import org.apache.uniffle.client.api.ShuffleWriteClient; import org.apache.uniffle.client.factory.ShuffleClientFactory; import org.apache.uniffle.client.impl.ShuffleWriteClientImpl; -import org.apache.uniffle.common.util.ExpireCloseableSupplier; public class RssShuffleClientFactory extends ShuffleClientFactory { @@ -42,18 +43,17 @@ public static ExtendWriteClientBuilder newWriteBuilder() { public static class ExtendWriteClientBuilder> extends WriteClientBuilder { private boolean blockIdSelfManagedEnabled; - private ExpireCloseableSupplier managerClientSupplier; + private Supplier managerClientSupplier; public boolean isBlockIdSelfManagedEnabled() { return blockIdSelfManagedEnabled; } - public ExpireCloseableSupplier getManagerClientSupplier() { + public Supplier getManagerClientSupplier() { return managerClientSupplier; } - public T managerClientSupplier( - ExpireCloseableSupplier managerClientSupplier) { + public T managerClientSupplier(Supplier managerClientSupplier) { this.managerClientSupplier = managerClientSupplier; return self(); } diff --git a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java index 2c42a6ac8a..186ae5b233 100644 --- a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java +++ b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java @@ -31,6 +31,7 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; +import java.util.function.Supplier; import java.util.stream.Collectors; import com.google.common.annotations.VisibleForTesting; @@ -81,7 +82,7 @@ import org.apache.uniffle.common.config.RssConf; import org.apache.uniffle.common.exception.RssException; import org.apache.uniffle.common.rpc.StatusCode; -import org.apache.uniffle.common.util.ExpireCloseableSupplier; +import org.apache.uniffle.common.util.ExpiringCloseableSupplier; import org.apache.uniffle.common.util.RetryUtils; import org.apache.uniffle.shuffle.BlockIdManager; @@ -104,7 +105,7 @@ public abstract class RssShuffleManagerBase implements RssShuffleManagerInterfac protected String clientType; protected SparkConf sparkConf; - protected ExpireCloseableSupplier managerClientSupplier; + protected Supplier managerClientSupplier; protected boolean rssStageRetryEnabled; protected boolean rssStageRetryForWriteFailureEnabled; protected boolean rssStageRetryForFetchFailureEnabled; @@ -589,7 +590,7 @@ protected synchronized StageAttemptShuffleHandleInfo getRemoteShuffleHandleInfoW RssPartitionToShuffleServerRequest rssPartitionToShuffleServerRequest = new RssPartitionToShuffleServerRequest(shuffleId); RssReassignOnStageRetryResponse rpcPartitionToShufflerServer = - getOrCreateShuffleManagerClientWrapper() + getOrCreateShuffleManagerClientSupplier() .get() .getPartitionToShufflerServerWithStageRetry(rssPartitionToShuffleServerRequest); StageAttemptShuffleHandleInfo shuffleHandleInfo = @@ -609,7 +610,7 @@ protected synchronized MutableShuffleHandleInfo getRemoteShuffleHandleInfoWithBl RssPartitionToShuffleServerRequest rssPartitionToShuffleServerRequest = new RssPartitionToShuffleServerRequest(shuffleId); RssReassignOnBlockSendFailureResponse rpcPartitionToShufflerServer = - getOrCreateShuffleManagerClientWrapper() + getOrCreateShuffleManagerClientSupplier() .get() .getPartitionToShufflerServerWithBlockRetry(rssPartitionToShuffleServerRequest); MutableShuffleHandleInfo shuffleHandleInfo = @@ -617,15 +618,14 @@ protected synchronized MutableShuffleHandleInfo getRemoteShuffleHandleInfoWithBl return shuffleHandleInfo; } - protected synchronized ExpireCloseableSupplier - getOrCreateShuffleManagerClientWrapper() { + protected synchronized Supplier getOrCreateShuffleManagerClientSupplier() { if (managerClientSupplier == null) { RssConf rssConf = RssSparkConfig.toRssConf(sparkConf); String driver = rssConf.getString("driver.host", ""); int port = rssConf.get(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT); long rpcTimeout = rssConf.getLong(RssBaseConf.RSS_CLIENT_TYPE_GRPC_TIMEOUT_MS); this.managerClientSupplier = - new ExpireCloseableSupplier<>( + new ExpiringCloseableSupplier<>( () -> ShuffleManagerClientFactory.getInstance() .createShuffleManagerClient(ClientType.GRPC, driver, port, rpcTimeout)); @@ -815,8 +815,9 @@ public MutableShuffleHandleInfo reassignOnBlockSendFailure( @Override public void stop() { - if (managerClientSupplier != null) { - managerClientSupplier.forceClose(); + if (managerClientSupplier != null + && managerClientSupplier instanceof ExpiringCloseableSupplier) { + ((ExpiringCloseableSupplier) managerClientSupplier).close(); } } diff --git a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java index d668181445..4990e23dc1 100644 --- a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java +++ b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java @@ -215,7 +215,7 @@ public RssShuffleManager(SparkConf sparkConf, boolean isDriver) { } } if (shuffleManagerRpcServiceEnabled) { - getOrCreateShuffleManagerClientWrapper(); + getOrCreateShuffleManagerClientSupplier(); } this.shuffleWriteClient = RssShuffleClientFactory.getInstance() diff --git a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java index d1074986ff..4b4ec32c59 100644 --- a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java +++ b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java @@ -19,6 +19,7 @@ import java.util.List; import java.util.Map; +import java.util.function.Supplier; import scala.Function0; import scala.Function2; @@ -54,7 +55,6 @@ import org.apache.uniffle.common.ShuffleServerInfo; import org.apache.uniffle.common.config.RssClientConf; import org.apache.uniffle.common.config.RssConf; -import org.apache.uniffle.common.util.ExpireCloseableSupplier; import static org.apache.spark.shuffle.RssSparkConfig.RSS_RESUBMIT_STAGE_WITH_FETCH_FAILURE_ENABLED; @@ -79,7 +79,7 @@ public class RssShuffleReader implements ShuffleReader { private List shuffleServerInfoList; private Configuration hadoopConf; private RssConf rssConf; - private ExpireCloseableSupplier managerClientSupplier; + private Supplier managerClientSupplier; public RssShuffleReader( int startPartition, @@ -94,7 +94,7 @@ public RssShuffleReader( Roaring64NavigableMap taskIdBitmap, RssConf rssConf, Map> partitionToServers, - ExpireCloseableSupplier managerClientSupplier) { + Supplier managerClientSupplier) { this.appId = rssShuffleHandle.getAppId(); this.startPartition = startPartition; this.endPartition = endPartition; diff --git a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java index 7ad2bab2a0..19963415c1 100644 --- a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java +++ b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java @@ -30,6 +30,7 @@ import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; import java.util.function.Function; +import java.util.function.Supplier; import java.util.stream.Collectors; import scala.Function1; @@ -73,7 +74,6 @@ import org.apache.uniffle.common.exception.RssException; import org.apache.uniffle.common.exception.RssSendFailedException; import org.apache.uniffle.common.exception.RssWaitFailedException; -import org.apache.uniffle.common.util.ExpireCloseableSupplier; import org.apache.uniffle.storage.util.StorageType; import static org.apache.spark.shuffle.RssSparkConfig.RSS_RESUBMIT_STAGE_WITH_WRITE_FAILURE_ENABLED; @@ -109,7 +109,7 @@ public class RssShuffleWriter extends ShuffleWriter { private final Set blockIds = Sets.newConcurrentHashSet(); private TaskContext taskContext; private SparkConf sparkConf; - private ExpireCloseableSupplier managerClientSupplier; + private Supplier managerClientSupplier; public RssShuffleWriter( String appId, @@ -121,7 +121,7 @@ public RssShuffleWriter( RssShuffleManager shuffleManager, SparkConf sparkConf, ShuffleWriteClient shuffleWriteClient, - ExpireCloseableSupplier managerClientSupplier, + Supplier managerClientSupplier, RssShuffleHandle rssHandle, SimpleShuffleHandleInfo shuffleHandleInfo, TaskContext context) { @@ -151,7 +151,7 @@ private RssShuffleWriter( RssShuffleManager shuffleManager, SparkConf sparkConf, ShuffleWriteClient shuffleWriteClient, - ExpireCloseableSupplier managerClientSupplier, + Supplier managerClientSupplier, RssShuffleHandle rssHandle, Function taskFailureCallback, ShuffleHandleInfo shuffleHandleInfo, @@ -190,7 +190,7 @@ public RssShuffleWriter( RssShuffleManager shuffleManager, SparkConf sparkConf, ShuffleWriteClient shuffleWriteClient, - ExpireCloseableSupplier managerClientSupplier, + Supplier managerClientSupplier, RssShuffleHandle rssHandle, Function taskFailureCallback, TaskContext context, diff --git a/client-spark/spark2/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java b/client-spark/spark2/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java index b0878f1421..05883efce2 100644 --- a/client-spark/spark2/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java +++ b/client-spark/spark2/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java @@ -39,7 +39,7 @@ import org.apache.uniffle.common.ShuffleServerInfo; import org.apache.uniffle.common.config.RssClientConf; import org.apache.uniffle.common.config.RssConf; -import org.apache.uniffle.common.util.ExpireCloseableSupplier; +import org.apache.uniffle.common.util.ExpiringCloseableSupplier; import org.apache.uniffle.storage.handler.impl.HadoopShuffleWriteHandler; import org.apache.uniffle.storage.util.StorageType; @@ -103,7 +103,7 @@ public void readTest() throws Exception { taskIdBitmap, rssConf, partitionToServers, - new ExpireCloseableSupplier<>(() -> mockShuffleManagerClient))); + new ExpiringCloseableSupplier<>(() -> mockShuffleManagerClient))); validateResult(rssShuffleReaderSpy.read(), expectedData, 10); } diff --git a/client-spark/spark2/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java b/client-spark/spark2/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java index 4c72a5a1b4..f847d89323 100644 --- a/client-spark/spark2/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java +++ b/client-spark/spark2/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java @@ -55,7 +55,7 @@ import org.apache.uniffle.common.ShuffleServerInfo; import org.apache.uniffle.common.config.RssConf; import org.apache.uniffle.common.rpc.StatusCode; -import org.apache.uniffle.common.util.ExpireCloseableSupplier; +import org.apache.uniffle.common.util.ExpiringCloseableSupplier; import org.apache.uniffle.storage.util.StorageType; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -127,7 +127,7 @@ public void checkBlockSendResultTest() { manager, conf, mockShuffleWriteClient, - new ExpireCloseableSupplier<>(() -> mockShuffleManagerClient), + new ExpiringCloseableSupplier<>(() -> mockShuffleManagerClient), mockHandle, mockShuffleHandleInfo, contextMock); @@ -304,7 +304,7 @@ public void writeTest() throws Exception { manager, conf, mockShuffleWriteClient, - new ExpireCloseableSupplier<>(() -> mockShuffleManagerClient), + new ExpiringCloseableSupplier<>(() -> mockShuffleManagerClient), mockHandle, mockShuffleHandleInfo, contextMock); @@ -418,7 +418,7 @@ public void postBlockEventTest() throws Exception { manager, conf, mockWriteClient, - new ExpireCloseableSupplier<>(() -> mockShuffleManagerClient), + new ExpiringCloseableSupplier<>(() -> mockShuffleManagerClient), mockHandle, mockShuffleHandleInfo, contextMock); diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java index eb13cc4649..f01c230c0a 100644 --- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java +++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java @@ -239,7 +239,7 @@ public RssShuffleManager(SparkConf conf, boolean isDriver) { } } if (shuffleManagerRpcServiceEnabled) { - getOrCreateShuffleManagerClientWrapper(); + getOrCreateShuffleManagerClientSupplier(); } int unregisterThreadPoolSize = sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_THREAD_POOL_SIZE); diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java index 4b6c8c6763..19682bd654 100644 --- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java +++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java @@ -19,6 +19,7 @@ import java.util.List; import java.util.Map; +import java.util.function.Supplier; import scala.Function0; import scala.Function1; @@ -57,7 +58,6 @@ import org.apache.uniffle.common.ShuffleServerInfo; import org.apache.uniffle.common.config.RssClientConf; import org.apache.uniffle.common.config.RssConf; -import org.apache.uniffle.common.util.ExpireCloseableSupplier; import static org.apache.spark.shuffle.RssSparkConfig.RSS_RESUBMIT_STAGE_WITH_FETCH_FAILURE_ENABLED; @@ -84,7 +84,7 @@ public class RssShuffleReader implements ShuffleReader { private ShuffleReadMetrics readMetrics; private RssConf rssConf; private ShuffleDataDistributionType dataDistributionType; - private ExpireCloseableSupplier managerClientSupplier; + private Supplier managerClientSupplier; public RssShuffleReader( int startPartition, @@ -99,7 +99,7 @@ public RssShuffleReader( Map partitionToExpectBlocks, Roaring64NavigableMap taskIdBitmap, ShuffleReadMetrics readMetrics, - ExpireCloseableSupplier managerClientSupplier, + Supplier managerClientSupplier, RssConf rssConf, ShuffleDataDistributionType dataDistributionType, Map> allPartitionToServers) { diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java index b1ee18aeca..46e485efbd 100644 --- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java +++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java @@ -35,6 +35,7 @@ import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; import java.util.function.Function; +import java.util.function.Supplier; import java.util.stream.Collectors; import scala.Function1; @@ -85,7 +86,6 @@ import org.apache.uniffle.common.exception.RssSendFailedException; import org.apache.uniffle.common.exception.RssWaitFailedException; import org.apache.uniffle.common.rpc.StatusCode; -import org.apache.uniffle.common.util.ExpireCloseableSupplier; import org.apache.uniffle.storage.util.StorageType; import static org.apache.spark.shuffle.RssSparkConfig.RSS_PARTITION_REASSIGN_BLOCK_RETRY_MAX_TIMES; @@ -135,7 +135,7 @@ public class RssShuffleWriter extends ShuffleWriter { private static final Set STATUS_CODE_WITHOUT_BLOCK_RESEND = Sets.newHashSet(StatusCode.NO_REGISTER); - private ExpireCloseableSupplier managerClientSupplier; + private final Supplier managerClientSupplier; // Only for tests @VisibleForTesting @@ -149,7 +149,7 @@ public RssShuffleWriter( RssShuffleManager shuffleManager, SparkConf sparkConf, ShuffleWriteClient shuffleWriteClient, - ExpireCloseableSupplier managerClientSupplier, + Supplier managerClientSupplier, RssShuffleHandle rssHandle, ShuffleHandleInfo shuffleHandleInfo, TaskContext context) { @@ -180,7 +180,7 @@ private RssShuffleWriter( RssShuffleManager shuffleManager, SparkConf sparkConf, ShuffleWriteClient shuffleWriteClient, - ExpireCloseableSupplier managerClientSupplier, + Supplier managerClientSupplier, RssShuffleHandle rssHandle, Function taskFailureCallback, ShuffleHandleInfo shuffleHandleInfo, @@ -232,7 +232,7 @@ public RssShuffleWriter( RssShuffleManager shuffleManager, SparkConf sparkConf, ShuffleWriteClient shuffleWriteClient, - ExpireCloseableSupplier managerClientSupplier, + Supplier managerClientSupplier, RssShuffleHandle rssHandle, Function taskFailureCallback, TaskContext context, diff --git a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java index 2b724fbf34..08c3189ba6 100644 --- a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java +++ b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java @@ -41,7 +41,7 @@ import org.apache.uniffle.common.ShuffleServerInfo; import org.apache.uniffle.common.config.RssClientConf; import org.apache.uniffle.common.config.RssConf; -import org.apache.uniffle.common.util.ExpireCloseableSupplier; +import org.apache.uniffle.common.util.ExpiringCloseableSupplier; import org.apache.uniffle.storage.handler.impl.HadoopShuffleWriteHandler; import org.apache.uniffle.storage.util.StorageType; @@ -111,7 +111,7 @@ public void readTest() throws Exception { partitionToExpectBlocks, taskIdBitmap, new ShuffleReadMetrics(), - new ExpireCloseableSupplier<>(() -> mockShuffleManagerClient), + new ExpiringCloseableSupplier<>(() -> mockShuffleManagerClient), rssConf, ShuffleDataDistributionType.NORMAL, partitionToServers)); @@ -135,7 +135,7 @@ public void readTest() throws Exception { partitionToExpectBlocks, taskIdBitmap, new ShuffleReadMetrics(), - new ExpireCloseableSupplier<>(() -> mockShuffleManagerClient), + new ExpiringCloseableSupplier<>(() -> mockShuffleManagerClient), rssConf, ShuffleDataDistributionType.NORMAL, partitionToServers)); @@ -156,7 +156,7 @@ public void readTest() throws Exception { partitionToExpectBlocks, Roaring64NavigableMap.bitmapOf(), new ShuffleReadMetrics(), - new ExpireCloseableSupplier<>(() -> mockShuffleManagerClient), + new ExpiringCloseableSupplier<>(() -> mockShuffleManagerClient), rssConf, ShuffleDataDistributionType.NORMAL, partitionToServers)); diff --git a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java index fb135b74c0..069ef2c9cb 100644 --- a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java +++ b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java @@ -62,7 +62,7 @@ import org.apache.uniffle.common.ShuffleBlockInfo; import org.apache.uniffle.common.ShuffleServerInfo; import org.apache.uniffle.common.rpc.StatusCode; -import org.apache.uniffle.common.util.ExpireCloseableSupplier; +import org.apache.uniffle.common.util.ExpiringCloseableSupplier; import org.apache.uniffle.common.util.JavaUtils; import org.apache.uniffle.storage.util.StorageType; @@ -182,7 +182,7 @@ private RssShuffleWriter createMockWriter(MutableShuffleHandleInfo shuffleHandle manager, conf, mockShuffleWriteClient, - new ExpireCloseableSupplier<>(() -> mockShuffleManagerClient), + new ExpiringCloseableSupplier<>(() -> mockShuffleManagerClient), mockHandle, shuffleHandle, contextMock); @@ -455,7 +455,7 @@ public void blockFailureResendTest() { manager, conf, mockShuffleWriteClient, - new ExpireCloseableSupplier<>(() -> mockShuffleManagerClient), + new ExpiringCloseableSupplier<>(() -> mockShuffleManagerClient), mockHandle, shuffleHandleInfo, contextMock); @@ -594,7 +594,7 @@ public void checkBlockSendResultTest() { manager, conf, mockShuffleWriteClient, - new ExpireCloseableSupplier<>(() -> mockShuffleManagerClient), + new ExpiringCloseableSupplier<>(() -> mockShuffleManagerClient), mockHandle, mockShuffleHandleInfo, contextMock); @@ -743,7 +743,7 @@ public void dataConsistencyWhenSpillTriggeredTest() throws Exception { manager, conf, mockShuffleWriteClient, - new ExpireCloseableSupplier<>(() -> mockShuffleManagerClient), + new ExpiringCloseableSupplier<>(() -> mockShuffleManagerClient), mockHandle, mockShuffleHandleInfo, contextMock); @@ -868,7 +868,7 @@ public void writeTest() throws Exception { manager, conf, mockShuffleWriteClient, - new ExpireCloseableSupplier<>(() -> mockShuffleManagerClient), + new ExpiringCloseableSupplier<>(() -> mockShuffleManagerClient), mockHandle, mockShuffleHandleInfo, contextMock); @@ -984,7 +984,7 @@ public void postBlockEventTest() throws Exception { mockShuffleManager, conf, mockWriteClient, - new ExpireCloseableSupplier<>(() -> mockShuffleManagerClient), + new ExpiringCloseableSupplier<>(() -> mockShuffleManagerClient), mockHandle, mockShuffleHandleInfo, contextMock); diff --git a/common/src/main/java/org/apache/uniffle/common/util/CloseStateful.java b/common/src/main/java/org/apache/uniffle/common/util/CloseStateful.java new file mode 100644 index 0000000000..a1dbbd10db --- /dev/null +++ b/common/src/main/java/org/apache/uniffle/common/util/CloseStateful.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.uniffle.common.util; + +import java.io.Closeable; + +/** CloseStateful is an interface that utilizes the ExpiringCloseableSupplier delegate. */ +public interface CloseStateful extends Closeable { + boolean isClosed(); +} diff --git a/common/src/main/java/org/apache/uniffle/common/util/ExpireCloseableSupplier.java b/common/src/main/java/org/apache/uniffle/common/util/ExpiringCloseableSupplier.java similarity index 57% rename from common/src/main/java/org/apache/uniffle/common/util/ExpireCloseableSupplier.java rename to common/src/main/java/org/apache/uniffle/common/util/ExpiringCloseableSupplier.java index 93cc3f39f8..f9cfaed88b 100644 --- a/common/src/main/java/org/apache/uniffle/common/util/ExpireCloseableSupplier.java +++ b/common/src/main/java/org/apache/uniffle/common/util/ExpiringCloseableSupplier.java @@ -17,59 +17,84 @@ package org.apache.uniffle.common.util; -import java.io.Closeable; import java.io.Serializable; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledFuture; import java.util.concurrent.TimeUnit; import java.util.function.Supplier; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class ExpireCloseableSupplier implements Supplier, Serializable { +/** + * A Supplier for T cacheable and autocloseable with delay By using ExpiringCloseableSupplier to + * obtain an object, manual closure may not be necessary. + */ +public class ExpiringCloseableSupplier + implements Supplier, Serializable { private static final long serialVersionUID = 0; - private static final Logger LOG = LoggerFactory.getLogger(ExpireCloseableSupplier.class); - private final ScheduledExecutorService executor = Executors.newSingleThreadScheduledExecutor(); + private static final Logger LOG = LoggerFactory.getLogger(ExpiringCloseableSupplier.class); + private static final int DEFAULT_DELAY_CLOSE_INTERVAL = 60000; + private ScheduledExecutorService executor = Executors.newSingleThreadScheduledExecutor(); + private ScheduledFuture future; private volatile T t; private final Supplier delegate; private transient volatile long freshTime; private final long delayCloseInterval; - public ExpireCloseableSupplier(Supplier delegate) { - this(delegate, 10000); + public ExpiringCloseableSupplier(Supplier delegate) { + this(delegate, DEFAULT_DELAY_CLOSE_INTERVAL); } - public ExpireCloseableSupplier(Supplier delegate, long delayCloseInterval) { + public ExpiringCloseableSupplier(Supplier delegate, long delayCloseInterval) { this.delegate = delegate; this.delayCloseInterval = delayCloseInterval; } public synchronized T get() { freshTime = System.currentTimeMillis(); - if (t == null) { + if (t == null || t.isClosed()) { t = delegate.get(); + startDelayCloseScheduler(); } - executor.schedule(this::close, delayCloseInterval, TimeUnit.MILLISECONDS); return t; } - public synchronized void forceClose() { + public synchronized void close() { try { - if (t != null) { + if (t != null && !t.isClosed()) { t.close(); } } catch (Exception e) { LOG.warn("Failed to close {} the resource", t.getClass().getName(), e); } finally { t = null; + freshTime = 0; + shutdownDelayCloseScheduler(); } } - public synchronized void close() { + public void tryClose() { if (System.currentTimeMillis() - freshTime > delayCloseInterval) { - this.forceClose(); - freshTime = 0; + this.close(); + } + } + + private void startDelayCloseScheduler() { + shutdownDelayCloseScheduler(); + executor = Executors.newSingleThreadScheduledExecutor(); + future = + executor.scheduleAtFixedRate( + this::tryClose, delayCloseInterval, delayCloseInterval, TimeUnit.MILLISECONDS); + } + + private void shutdownDelayCloseScheduler() { + if (future != null && !future.isDone()) { + future.cancel(false); + } + if (executor != null && !executor.isShutdown()) { + executor.shutdown(); } } } diff --git a/common/src/test/java/org/apache/uniffle/common/util/ExpireCloseableSupplierTest.java b/common/src/test/java/org/apache/uniffle/common/util/ExpireCloseableSupplierTest.java deleted file mode 100644 index ff695b35cd..0000000000 --- a/common/src/test/java/org/apache/uniffle/common/util/ExpireCloseableSupplierTest.java +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.uniffle.common.util; - -import java.io.Closeable; -import java.io.IOException; -import java.util.concurrent.TimeUnit; -import java.util.function.Supplier; - -import com.google.common.util.concurrent.Uninterruptibles; -import org.junit.jupiter.api.Test; - -import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertTrue; - -class ExpireCloseableSupplierTest { - - @Test - void test1() { - Supplier cf = () -> new MockClient(false); - ExpireCloseableSupplier mockClientExpireCloseableSupplier = - new ExpireCloseableSupplier<>(cf); - - MockClient mockClient = mockClientExpireCloseableSupplier.get(); - MockClient mockClient2 = mockClientExpireCloseableSupplier.get(); - assertTrue(mockClient == mockClient2); - mockClientExpireCloseableSupplier.forceClose(); - mockClientExpireCloseableSupplier.forceClose(); - } - - @Test - void test2() { - Supplier cf = () -> new MockClient(true); - ExpireCloseableSupplier mockClientExpireCloseableSupplier = - new ExpireCloseableSupplier<>(cf, 10); - MockClient mockClient1 = mockClientExpireCloseableSupplier.get(); - assertNotNull(mockClient1); - Uninterruptibles.sleepUninterruptibly(30, TimeUnit.MILLISECONDS); - mockClientExpireCloseableSupplier.forceClose(); - } - - @Test - void forceClose() { - Supplier cf = () -> new MockClient(true); - ExpireCloseableSupplier mockClientExpireCloseableSupplier = - new ExpireCloseableSupplier<>(cf); - MockClient mockClient = mockClientExpireCloseableSupplier.get(); - mockClientExpireCloseableSupplier.forceClose(); - MockClient mockClient2 = mockClientExpireCloseableSupplier.get(); - assertTrue(mockClient != mockClient2); - } - - static class MockClient implements Closeable { - boolean withException; - - MockClient(boolean withException) { - this.withException = withException; - } - - @Override - public void close() throws IOException { - if (withException) { - throw new IOException("test exception!"); - } - } - } -} diff --git a/common/src/test/java/org/apache/uniffle/common/util/ExpiringCloseableSupplierTest.java b/common/src/test/java/org/apache/uniffle/common/util/ExpiringCloseableSupplierTest.java new file mode 100644 index 0000000000..4e32f2ea64 --- /dev/null +++ b/common/src/test/java/org/apache/uniffle/common/util/ExpiringCloseableSupplierTest.java @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.uniffle.common.util; + +import java.io.IOException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Supplier; + +import com.google.common.util.concurrent.Uninterruptibles; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNotSame; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertTrue; + +class ExpiringCloseableSupplierTest { + + @Test + void testCacheable() { + Supplier cf = () -> new MockClient(false); + ExpiringCloseableSupplier mockClientSupplier = new ExpiringCloseableSupplier<>(cf); + + MockClient mockClient = mockClientSupplier.get(); + MockClient mockClient2 = mockClientSupplier.get(); + assertSame(mockClient, mockClient2); + mockClientSupplier.close(); + mockClientSupplier.close(); + } + + @Test + void testAutoCloseable() { + Supplier cf = () -> new MockClient(true); + ExpiringCloseableSupplier mockClientSupplier = + new ExpiringCloseableSupplier<>(cf, 10); + MockClient mockClient1 = mockClientSupplier.get(); + assertNotNull(mockClient1); + Uninterruptibles.sleepUninterruptibly(30, TimeUnit.MILLISECONDS); + assertTrue(mockClient1.isClosed()); + MockClient mockClient2 = mockClientSupplier.get(); + assertNotSame(mockClient1, mockClient2); + mockClientSupplier.close(); + } + + @Test + void testRenew() { + Supplier cf = () -> new MockClient(true); + ExpiringCloseableSupplier mockClientSupplier = new ExpiringCloseableSupplier<>(cf); + MockClient mockClient = mockClientSupplier.get(); + mockClientSupplier.close(); + MockClient mockClient2 = mockClientSupplier.get(); + assertNotSame(mockClient, mockClient2); + } + + @Test + void testReClose() { + Supplier cf = () -> new MockClient(true); + ExpiringCloseableSupplier mockClientSupplier = new ExpiringCloseableSupplier<>(cf); + mockClientSupplier.get(); + mockClientSupplier.close(); + mockClientSupplier.close(); + } + + @Test + void testDelegateExtendClose() throws IOException { + Supplier cf = () -> new MockClient(false); + ExpiringCloseableSupplier mockClientSupplier = new ExpiringCloseableSupplier<>(cf); + MockClient mockClient = mockClientSupplier.get(); + mockClient.close(); + assertTrue(mockClient.isClosed()); + + MockClient mockClient1 = mockClientSupplier.get(); + assertNotSame(mockClient, mockClient1); + MockClient mockClient2 = mockClientSupplier.get(); + assertSame(mockClient1, mockClient2); + mockClientSupplier.close(); + } + + static class MockClient implements CloseStateful { + boolean withException; + AtomicBoolean closed = new AtomicBoolean(false); + + MockClient(boolean withException) { + this.withException = withException; + } + + @Override + public void close() throws IOException { + closed.set(true); + if (withException) { + throw new IOException("test exception!"); + } + } + + @Override + public boolean isClosed() { + return closed.get(); + } + } +} diff --git a/internal-client/src/main/java/org/apache/uniffle/client/api/ShuffleManagerClient.java b/internal-client/src/main/java/org/apache/uniffle/client/api/ShuffleManagerClient.java index c5b412a9e6..13fb61eaeb 100644 --- a/internal-client/src/main/java/org/apache/uniffle/client/api/ShuffleManagerClient.java +++ b/internal-client/src/main/java/org/apache/uniffle/client/api/ShuffleManagerClient.java @@ -17,8 +17,6 @@ package org.apache.uniffle.client.api; -import java.io.Closeable; - import org.apache.uniffle.client.request.RssGetShuffleResultForMultiPartRequest; import org.apache.uniffle.client.request.RssGetShuffleResultRequest; import org.apache.uniffle.client.request.RssPartitionToShuffleServerRequest; @@ -34,8 +32,9 @@ import org.apache.uniffle.client.response.RssReportShuffleFetchFailureResponse; import org.apache.uniffle.client.response.RssReportShuffleResultResponse; import org.apache.uniffle.client.response.RssReportShuffleWriteFailureResponse; +import org.apache.uniffle.common.util.CloseStateful; -public interface ShuffleManagerClient extends Closeable { +public interface ShuffleManagerClient extends CloseStateful { RssReportShuffleFetchFailureResponse reportShuffleFetchFailure( RssReportShuffleFetchFailureRequest request); diff --git a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleManagerGrpcClient.java b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleManagerGrpcClient.java index f042df8f10..78886d4157 100644 --- a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleManagerGrpcClient.java +++ b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleManagerGrpcClient.java @@ -168,4 +168,9 @@ public RssReportShuffleResultResponse reportShuffleResult(RssReportShuffleResult getBlockingStub().reportShuffleResult(request.toProto()); return RssReportShuffleResultResponse.fromProto(response); } + + @Override + public boolean isClosed() { + return channel.isShutdown() || channel.isTerminated(); + } }