From f1cb90048d7c1fec8b10cccee86b7d541901df33 Mon Sep 17 00:00:00 2001 From: Xianjin Date: Wed, 24 Jul 2024 20:26:30 +0800 Subject: [PATCH] Refine ExpiringCloseableSupplier --- .../manager/RssShuffleManagerBase.java | 3 +- .../shuffle/reader/RssShuffleReaderTest.java | 6 +- .../shuffle/writer/RssShuffleWriterTest.java | 12 ++-- .../util/ExpiringCloseableSupplier.java | 63 +++++++++-------- .../util/ExpiringCloseableSupplierTest.java | 70 +++++++++++++++++-- 5 files changed, 107 insertions(+), 47 deletions(-) diff --git a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java index 186ae5b233..e9b1e57931 100644 --- a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java +++ b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java @@ -624,8 +624,7 @@ protected synchronized Supplier getOrCreateShuffleManagerC String driver = rssConf.getString("driver.host", ""); int port = rssConf.get(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT); long rpcTimeout = rssConf.getLong(RssBaseConf.RSS_CLIENT_TYPE_GRPC_TIMEOUT_MS); - this.managerClientSupplier = - new ExpiringCloseableSupplier<>( + this.managerClientSupplier = ExpiringCloseableSupplier.of( () -> ShuffleManagerClientFactory.getInstance() .createShuffleManagerClient(ClientType.GRPC, driver, port, rpcTimeout)); diff --git a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java index 08c3189ba6..bc77f71920 100644 --- a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java +++ b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java @@ -111,7 +111,7 @@ public void readTest() throws Exception { partitionToExpectBlocks, taskIdBitmap, new ShuffleReadMetrics(), - new ExpiringCloseableSupplier<>(() -> mockShuffleManagerClient), + ExpiringCloseableSupplier.of(() -> mockShuffleManagerClient), rssConf, ShuffleDataDistributionType.NORMAL, partitionToServers)); @@ -135,7 +135,7 @@ public void readTest() throws Exception { partitionToExpectBlocks, taskIdBitmap, new ShuffleReadMetrics(), - new ExpiringCloseableSupplier<>(() -> mockShuffleManagerClient), + ExpiringCloseableSupplier.of(() -> mockShuffleManagerClient), rssConf, ShuffleDataDistributionType.NORMAL, partitionToServers)); @@ -156,7 +156,7 @@ public void readTest() throws Exception { partitionToExpectBlocks, Roaring64NavigableMap.bitmapOf(), new ShuffleReadMetrics(), - new ExpiringCloseableSupplier<>(() -> mockShuffleManagerClient), + ExpiringCloseableSupplier.of(() -> mockShuffleManagerClient), rssConf, ShuffleDataDistributionType.NORMAL, partitionToServers)); diff --git a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java index 069ef2c9cb..a4317aae85 100644 --- a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java +++ b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java @@ -182,7 +182,7 @@ private RssShuffleWriter createMockWriter(MutableShuffleHandleInfo shuffleHandle manager, conf, mockShuffleWriteClient, - new ExpiringCloseableSupplier<>(() -> mockShuffleManagerClient), + ExpiringCloseableSupplier.of(() -> mockShuffleManagerClient), mockHandle, shuffleHandle, contextMock); @@ -455,7 +455,7 @@ public void blockFailureResendTest() { manager, conf, mockShuffleWriteClient, - new ExpiringCloseableSupplier<>(() -> mockShuffleManagerClient), + ExpiringCloseableSupplier.of(() -> mockShuffleManagerClient), mockHandle, shuffleHandleInfo, contextMock); @@ -594,7 +594,7 @@ public void checkBlockSendResultTest() { manager, conf, mockShuffleWriteClient, - new ExpiringCloseableSupplier<>(() -> mockShuffleManagerClient), + ExpiringCloseableSupplier.of(() -> mockShuffleManagerClient), mockHandle, mockShuffleHandleInfo, contextMock); @@ -743,7 +743,7 @@ public void dataConsistencyWhenSpillTriggeredTest() throws Exception { manager, conf, mockShuffleWriteClient, - new ExpiringCloseableSupplier<>(() -> mockShuffleManagerClient), + ExpiringCloseableSupplier.of(() -> mockShuffleManagerClient), mockHandle, mockShuffleHandleInfo, contextMock); @@ -868,7 +868,7 @@ public void writeTest() throws Exception { manager, conf, mockShuffleWriteClient, - new ExpiringCloseableSupplier<>(() -> mockShuffleManagerClient), + ExpiringCloseableSupplier.of(() -> mockShuffleManagerClient), mockHandle, mockShuffleHandleInfo, contextMock); @@ -984,7 +984,7 @@ public void postBlockEventTest() throws Exception { mockShuffleManager, conf, mockWriteClient, - new ExpiringCloseableSupplier<>(() -> mockShuffleManagerClient), + ExpiringCloseableSupplier.of(() -> mockShuffleManagerClient), mockHandle, mockShuffleHandleInfo, contextMock); diff --git a/common/src/main/java/org/apache/uniffle/common/util/ExpiringCloseableSupplier.java b/common/src/main/java/org/apache/uniffle/common/util/ExpiringCloseableSupplier.java index f9cfaed88b..336509bea6 100644 --- a/common/src/main/java/org/apache/uniffle/common/util/ExpiringCloseableSupplier.java +++ b/common/src/main/java/org/apache/uniffle/common/util/ExpiringCloseableSupplier.java @@ -17,8 +17,8 @@ package org.apache.uniffle.common.util; +import java.io.IOException; import java.io.Serializable; -import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledFuture; import java.util.concurrent.TimeUnit; @@ -28,7 +28,7 @@ import org.slf4j.LoggerFactory; /** - * A Supplier for T cacheable and autocloseable with delay By using ExpiringCloseableSupplier to + * A Supplier for T cacheable and autocloseable with delay by using ExpiringCloseableSupplier to * obtain an object, manual closure may not be necessary. */ public class ExpiringCloseableSupplier @@ -36,27 +36,26 @@ public class ExpiringCloseableSupplier private static final long serialVersionUID = 0; private static final Logger LOG = LoggerFactory.getLogger(ExpiringCloseableSupplier.class); private static final int DEFAULT_DELAY_CLOSE_INTERVAL = 60000; - private ScheduledExecutorService executor = Executors.newSingleThreadScheduledExecutor(); - private ScheduledFuture future; - private volatile T t; + private static final ScheduledExecutorService executor = + ThreadUtils.getDaemonSingleThreadScheduledExecutor("ExpiringCloseableSupplier"); + private final Supplier delegate; - private transient volatile long freshTime; private final long delayCloseInterval; - public ExpiringCloseableSupplier(Supplier delegate) { - this(delegate, DEFAULT_DELAY_CLOSE_INTERVAL); - } + private transient volatile ScheduledFuture future; + private transient volatile long accessTime; + private transient volatile T t; - public ExpiringCloseableSupplier(Supplier delegate, long delayCloseInterval) { + private ExpiringCloseableSupplier(Supplier delegate, long delayCloseInterval) { this.delegate = delegate; this.delayCloseInterval = delayCloseInterval; } public synchronized T get() { - freshTime = System.currentTimeMillis(); + accessTime = System.currentTimeMillis(); if (t == null || t.isClosed()) { - t = delegate.get(); - startDelayCloseScheduler(); + this.t = delegate.get(); + ensureCloseFutureScheduled(); } return t; } @@ -66,35 +65,41 @@ public synchronized void close() { if (t != null && !t.isClosed()) { t.close(); } - } catch (Exception e) { - LOG.warn("Failed to close {} the resource", t.getClass().getName(), e); + } catch (IOException ioe) { + LOG.warn("Failed to close {} the resource", t.getClass().getName(), ioe); } finally { - t = null; - freshTime = 0; - shutdownDelayCloseScheduler(); + this.t = null; + this.accessTime = System.currentTimeMillis(); + cancelCloseFuture(); } } - public void tryClose() { - if (System.currentTimeMillis() - freshTime > delayCloseInterval) { - this.close(); + private void tryClose() { + if (System.currentTimeMillis() - accessTime > delayCloseInterval) { + close(); } } - private void startDelayCloseScheduler() { - shutdownDelayCloseScheduler(); - executor = Executors.newSingleThreadScheduledExecutor(); - future = + private void ensureCloseFutureScheduled() { + cancelCloseFuture(); + this.future = executor.scheduleAtFixedRate( this::tryClose, delayCloseInterval, delayCloseInterval, TimeUnit.MILLISECONDS); } - private void shutdownDelayCloseScheduler() { + private void cancelCloseFuture() { if (future != null && !future.isDone()) { future.cancel(false); + this.future = null; } - if (executor != null && !executor.isShutdown()) { - executor.shutdown(); - } + } + + public static ExpiringCloseableSupplier of(Supplier delegate) { + return new ExpiringCloseableSupplier<>(delegate, DEFAULT_DELAY_CLOSE_INTERVAL); + } + + public static ExpiringCloseableSupplier of( + Supplier delegate, long delayCloseInterval) { + return new ExpiringCloseableSupplier<>(delegate, delayCloseInterval); } } diff --git a/common/src/test/java/org/apache/uniffle/common/util/ExpiringCloseableSupplierTest.java b/common/src/test/java/org/apache/uniffle/common/util/ExpiringCloseableSupplierTest.java index 4e32f2ea64..eab01bd296 100644 --- a/common/src/test/java/org/apache/uniffle/common/util/ExpiringCloseableSupplierTest.java +++ b/common/src/test/java/org/apache/uniffle/common/util/ExpiringCloseableSupplierTest.java @@ -18,13 +18,20 @@ package org.apache.uniffle.common.util; import java.io.IOException; +import java.io.Serializable; +import java.util.List; +import java.util.Random; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Supplier; +import com.google.common.collect.Lists; import com.google.common.util.concurrent.Uninterruptibles; +import org.apache.commons.lang3.SerializationUtils; +import org.awaitility.Awaitility; import org.junit.jupiter.api.Test; +import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNotSame; import static org.junit.jupiter.api.Assertions.assertSame; @@ -35,7 +42,7 @@ class ExpiringCloseableSupplierTest { @Test void testCacheable() { Supplier cf = () -> new MockClient(false); - ExpiringCloseableSupplier mockClientSupplier = new ExpiringCloseableSupplier<>(cf); + ExpiringCloseableSupplier mockClientSupplier = ExpiringCloseableSupplier.of(cf); MockClient mockClient = mockClientSupplier.get(); MockClient mockClient2 = mockClientSupplier.get(); @@ -47,8 +54,7 @@ void testCacheable() { @Test void testAutoCloseable() { Supplier cf = () -> new MockClient(true); - ExpiringCloseableSupplier mockClientSupplier = - new ExpiringCloseableSupplier<>(cf, 10); + ExpiringCloseableSupplier mockClientSupplier = ExpiringCloseableSupplier.of(cf, 10); MockClient mockClient1 = mockClientSupplier.get(); assertNotNull(mockClient1); Uninterruptibles.sleepUninterruptibly(30, TimeUnit.MILLISECONDS); @@ -61,7 +67,7 @@ void testAutoCloseable() { @Test void testRenew() { Supplier cf = () -> new MockClient(true); - ExpiringCloseableSupplier mockClientSupplier = new ExpiringCloseableSupplier<>(cf); + ExpiringCloseableSupplier mockClientSupplier = ExpiringCloseableSupplier.of(cf); MockClient mockClient = mockClientSupplier.get(); mockClientSupplier.close(); MockClient mockClient2 = mockClientSupplier.get(); @@ -71,7 +77,7 @@ void testRenew() { @Test void testReClose() { Supplier cf = () -> new MockClient(true); - ExpiringCloseableSupplier mockClientSupplier = new ExpiringCloseableSupplier<>(cf); + ExpiringCloseableSupplier mockClientSupplier = ExpiringCloseableSupplier.of(cf); mockClientSupplier.get(); mockClientSupplier.close(); mockClientSupplier.close(); @@ -80,7 +86,7 @@ void testReClose() { @Test void testDelegateExtendClose() throws IOException { Supplier cf = () -> new MockClient(false); - ExpiringCloseableSupplier mockClientSupplier = new ExpiringCloseableSupplier<>(cf); + ExpiringCloseableSupplier mockClientSupplier = ExpiringCloseableSupplier.of(cf); MockClient mockClient = mockClientSupplier.get(); mockClient.close(); assertTrue(mockClient.isClosed()); @@ -92,7 +98,57 @@ void testDelegateExtendClose() throws IOException { mockClientSupplier.close(); } - static class MockClient implements CloseStateful { + @Test + public void testSerialization() { + Supplier cf = (Supplier & Serializable) () -> new MockClient(true); + ExpiringCloseableSupplier mockClientSupplier = ExpiringCloseableSupplier.of(cf, 10); + MockClient mockClient = mockClientSupplier.get(); + + ExpiringCloseableSupplier mockClientSupplier2 = + SerializationUtils.roundtrip(mockClientSupplier); + MockClient mockClient2 = mockClientSupplier2.get(); + assertFalse(mockClient2.isClosed()); + assertNotSame(mockClient, mockClient2); + Uninterruptibles.sleepUninterruptibly(30, TimeUnit.MILLISECONDS); + assertTrue(mockClient.isClosed()); + assertTrue(mockClient2.isClosed()); + } + + @Test + public void testMultipleSupplierShouldNotInterfere() { + Supplier cf = () -> new MockClient(true); + ExpiringCloseableSupplier mockClientSupplier = ExpiringCloseableSupplier.of(cf, 10); + ExpiringCloseableSupplier mockClientSupplier2 = + ExpiringCloseableSupplier.of(cf, 10); + MockClient mockClient = mockClientSupplier.get(); + MockClient mockClient2 = mockClientSupplier2.get(); + Uninterruptibles.sleepUninterruptibly(30, TimeUnit.MILLISECONDS); + assertTrue(mockClient.isClosed()); + assertTrue(mockClient2.isClosed()); + mockClientSupplier.close(); + mockClientSupplier.close(); + mockClientSupplier2.close(); + mockClientSupplier2.close(); + } + + @Test + public void stressingTestManySuppliers() { + int num = 100000; // this should be sufficient for most production use cases + Supplier cf = () -> new MockClient(true); + List clients = Lists.newArrayList(); + Random random = new Random(42); + for (int i = 0; i < num; i++) { + int delayCloseInterval = random.nextInt(1000) + 1; + ExpiringCloseableSupplier mockClientSupplier = + ExpiringCloseableSupplier.of(cf, delayCloseInterval); + MockClient mockClient = mockClientSupplier.get(); + clients.add(mockClient); + } + Awaitility.waitAtMost(5, TimeUnit.SECONDS) + .until(() -> clients.stream().allMatch(MockClient::isClosed)); + } + + private static class MockClient implements CloseStateful, Serializable { boolean withException; AtomicBoolean closed = new AtomicBoolean(false);