Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refine ExpiringCloseableSupplier #1

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -624,8 +624,7 @@ protected synchronized Supplier<ShuffleManagerClient> getOrCreateShuffleManagerC
String driver = rssConf.getString("driver.host", "");
int port = rssConf.get(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT);
long rpcTimeout = rssConf.getLong(RssBaseConf.RSS_CLIENT_TYPE_GRPC_TIMEOUT_MS);
this.managerClientSupplier =
new ExpiringCloseableSupplier<>(
this.managerClientSupplier = ExpiringCloseableSupplier.of(
() ->
ShuffleManagerClientFactory.getInstance()
.createShuffleManagerClient(ClientType.GRPC, driver, port, rpcTimeout));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ public void readTest() throws Exception {
partitionToExpectBlocks,
taskIdBitmap,
new ShuffleReadMetrics(),
new ExpiringCloseableSupplier<>(() -> mockShuffleManagerClient),
ExpiringCloseableSupplier.of(() -> mockShuffleManagerClient),
rssConf,
ShuffleDataDistributionType.NORMAL,
partitionToServers));
Expand All @@ -135,7 +135,7 @@ public void readTest() throws Exception {
partitionToExpectBlocks,
taskIdBitmap,
new ShuffleReadMetrics(),
new ExpiringCloseableSupplier<>(() -> mockShuffleManagerClient),
ExpiringCloseableSupplier.of(() -> mockShuffleManagerClient),
rssConf,
ShuffleDataDistributionType.NORMAL,
partitionToServers));
Expand All @@ -156,7 +156,7 @@ public void readTest() throws Exception {
partitionToExpectBlocks,
Roaring64NavigableMap.bitmapOf(),
new ShuffleReadMetrics(),
new ExpiringCloseableSupplier<>(() -> mockShuffleManagerClient),
ExpiringCloseableSupplier.of(() -> mockShuffleManagerClient),
rssConf,
ShuffleDataDistributionType.NORMAL,
partitionToServers));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ private RssShuffleWriter createMockWriter(MutableShuffleHandleInfo shuffleHandle
manager,
conf,
mockShuffleWriteClient,
new ExpiringCloseableSupplier<>(() -> mockShuffleManagerClient),
ExpiringCloseableSupplier.of(() -> mockShuffleManagerClient),
mockHandle,
shuffleHandle,
contextMock);
Expand Down Expand Up @@ -455,7 +455,7 @@ public void blockFailureResendTest() {
manager,
conf,
mockShuffleWriteClient,
new ExpiringCloseableSupplier<>(() -> mockShuffleManagerClient),
ExpiringCloseableSupplier.of(() -> mockShuffleManagerClient),
mockHandle,
shuffleHandleInfo,
contextMock);
Expand Down Expand Up @@ -594,7 +594,7 @@ public void checkBlockSendResultTest() {
manager,
conf,
mockShuffleWriteClient,
new ExpiringCloseableSupplier<>(() -> mockShuffleManagerClient),
ExpiringCloseableSupplier.of(() -> mockShuffleManagerClient),
mockHandle,
mockShuffleHandleInfo,
contextMock);
Expand Down Expand Up @@ -743,7 +743,7 @@ public void dataConsistencyWhenSpillTriggeredTest() throws Exception {
manager,
conf,
mockShuffleWriteClient,
new ExpiringCloseableSupplier<>(() -> mockShuffleManagerClient),
ExpiringCloseableSupplier.of(() -> mockShuffleManagerClient),
mockHandle,
mockShuffleHandleInfo,
contextMock);
Expand Down Expand Up @@ -868,7 +868,7 @@ public void writeTest() throws Exception {
manager,
conf,
mockShuffleWriteClient,
new ExpiringCloseableSupplier<>(() -> mockShuffleManagerClient),
ExpiringCloseableSupplier.of(() -> mockShuffleManagerClient),
mockHandle,
mockShuffleHandleInfo,
contextMock);
Expand Down Expand Up @@ -984,7 +984,7 @@ public void postBlockEventTest() throws Exception {
mockShuffleManager,
conf,
mockWriteClient,
new ExpiringCloseableSupplier<>(() -> mockShuffleManagerClient),
ExpiringCloseableSupplier.of(() -> mockShuffleManagerClient),
mockHandle,
mockShuffleHandleInfo,
contextMock);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@

package org.apache.uniffle.common.util;

import java.io.IOException;
import java.io.Serializable;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
Expand All @@ -28,35 +28,34 @@
import org.slf4j.LoggerFactory;

/**
* A Supplier for T cacheable and autocloseable with delay By using ExpiringCloseableSupplier to
* A Supplier for T cacheable and autocloseable with delay by using ExpiringCloseableSupplier to
* obtain an object, manual closure may not be necessary.
*/
public class ExpiringCloseableSupplier<T extends CloseStateful>
implements Supplier<T>, Serializable {
private static final long serialVersionUID = 0;
private static final Logger LOG = LoggerFactory.getLogger(ExpiringCloseableSupplier.class);
private static final int DEFAULT_DELAY_CLOSE_INTERVAL = 60000;
private ScheduledExecutorService executor = Executors.newSingleThreadScheduledExecutor();
private ScheduledFuture<?> future;
private volatile T t;
private static final ScheduledExecutorService executor =
ThreadUtils.getDaemonSingleThreadScheduledExecutor("ExpiringCloseableSupplier");

private final Supplier<T> delegate;
private transient volatile long freshTime;
private final long delayCloseInterval;

public ExpiringCloseableSupplier(Supplier<T> delegate) {
this(delegate, DEFAULT_DELAY_CLOSE_INTERVAL);
}
private transient volatile ScheduledFuture<?> future;
private transient volatile long accessTime;
private transient volatile T t;

public ExpiringCloseableSupplier(Supplier<T> delegate, long delayCloseInterval) {
private ExpiringCloseableSupplier(Supplier<T> delegate, long delayCloseInterval) {
this.delegate = delegate;
this.delayCloseInterval = delayCloseInterval;
}

public synchronized T get() {
freshTime = System.currentTimeMillis();
accessTime = System.currentTimeMillis();
if (t == null || t.isClosed()) {
t = delegate.get();
startDelayCloseScheduler();
this.t = delegate.get();
ensureCloseFutureScheduled();
}
return t;
}
Expand All @@ -66,35 +65,41 @@ public synchronized void close() {
if (t != null && !t.isClosed()) {
t.close();
}
} catch (Exception e) {
LOG.warn("Failed to close {} the resource", t.getClass().getName(), e);
} catch (IOException ioe) {
LOG.warn("Failed to close {} the resource", t.getClass().getName(), ioe);
} finally {
t = null;
freshTime = 0;
shutdownDelayCloseScheduler();
this.t = null;
this.accessTime = System.currentTimeMillis();
cancelCloseFuture();
}
}

public void tryClose() {
if (System.currentTimeMillis() - freshTime > delayCloseInterval) {
this.close();
private void tryClose() {
if (System.currentTimeMillis() - accessTime > delayCloseInterval) {
close();
}
}

private void startDelayCloseScheduler() {
shutdownDelayCloseScheduler();
executor = Executors.newSingleThreadScheduledExecutor();
future =
private void ensureCloseFutureScheduled() {
cancelCloseFuture();
this.future =
executor.scheduleAtFixedRate(
this::tryClose, delayCloseInterval, delayCloseInterval, TimeUnit.MILLISECONDS);
}

private void shutdownDelayCloseScheduler() {
private void cancelCloseFuture() {
if (future != null && !future.isDone()) {
future.cancel(false);
this.future = null;
}
if (executor != null && !executor.isShutdown()) {
executor.shutdown();
}
}

public static <T extends CloseStateful> ExpiringCloseableSupplier<T> of(Supplier<T> delegate) {
return new ExpiringCloseableSupplier<>(delegate, DEFAULT_DELAY_CLOSE_INTERVAL);
}

public static <T extends CloseStateful> ExpiringCloseableSupplier<T> of(
Supplier<T> delegate, long delayCloseInterval) {
return new ExpiringCloseableSupplier<>(delegate, delayCloseInterval);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,20 @@
package org.apache.uniffle.common.util;

import java.io.IOException;
import java.io.Serializable;
import java.util.List;
import java.util.Random;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Supplier;

import com.google.common.collect.Lists;
import com.google.common.util.concurrent.Uninterruptibles;
import org.apache.commons.lang3.SerializationUtils;
import org.awaitility.Awaitility;
import org.junit.jupiter.api.Test;

import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNotSame;
import static org.junit.jupiter.api.Assertions.assertSame;
Expand All @@ -35,7 +42,7 @@ class ExpiringCloseableSupplierTest {
@Test
void testCacheable() {
Supplier<MockClient> cf = () -> new MockClient(false);
ExpiringCloseableSupplier<MockClient> mockClientSupplier = new ExpiringCloseableSupplier<>(cf);
ExpiringCloseableSupplier<MockClient> mockClientSupplier = ExpiringCloseableSupplier.of(cf);

MockClient mockClient = mockClientSupplier.get();
MockClient mockClient2 = mockClientSupplier.get();
Expand All @@ -47,8 +54,7 @@ void testCacheable() {
@Test
void testAutoCloseable() {
Supplier<MockClient> cf = () -> new MockClient(true);
ExpiringCloseableSupplier<MockClient> mockClientSupplier =
new ExpiringCloseableSupplier<>(cf, 10);
ExpiringCloseableSupplier<MockClient> mockClientSupplier = ExpiringCloseableSupplier.of(cf, 10);
MockClient mockClient1 = mockClientSupplier.get();
assertNotNull(mockClient1);
Uninterruptibles.sleepUninterruptibly(30, TimeUnit.MILLISECONDS);
Expand All @@ -61,7 +67,7 @@ void testAutoCloseable() {
@Test
void testRenew() {
Supplier<MockClient> cf = () -> new MockClient(true);
ExpiringCloseableSupplier<MockClient> mockClientSupplier = new ExpiringCloseableSupplier<>(cf);
ExpiringCloseableSupplier<MockClient> mockClientSupplier = ExpiringCloseableSupplier.of(cf);
MockClient mockClient = mockClientSupplier.get();
mockClientSupplier.close();
MockClient mockClient2 = mockClientSupplier.get();
Expand All @@ -71,7 +77,7 @@ void testRenew() {
@Test
void testReClose() {
Supplier<MockClient> cf = () -> new MockClient(true);
ExpiringCloseableSupplier<MockClient> mockClientSupplier = new ExpiringCloseableSupplier<>(cf);
ExpiringCloseableSupplier<MockClient> mockClientSupplier = ExpiringCloseableSupplier.of(cf);
mockClientSupplier.get();
mockClientSupplier.close();
mockClientSupplier.close();
Expand All @@ -80,7 +86,7 @@ void testReClose() {
@Test
void testDelegateExtendClose() throws IOException {
Supplier<MockClient> cf = () -> new MockClient(false);
ExpiringCloseableSupplier<MockClient> mockClientSupplier = new ExpiringCloseableSupplier<>(cf);
ExpiringCloseableSupplier<MockClient> mockClientSupplier = ExpiringCloseableSupplier.of(cf);
MockClient mockClient = mockClientSupplier.get();
mockClient.close();
assertTrue(mockClient.isClosed());
Expand All @@ -92,7 +98,57 @@ void testDelegateExtendClose() throws IOException {
mockClientSupplier.close();
}

static class MockClient implements CloseStateful {
@Test
public void testSerialization() {
Supplier<MockClient> cf = (Supplier<MockClient> & Serializable) () -> new MockClient(true);
ExpiringCloseableSupplier<MockClient> mockClientSupplier = ExpiringCloseableSupplier.of(cf, 10);
MockClient mockClient = mockClientSupplier.get();

ExpiringCloseableSupplier<MockClient> mockClientSupplier2 =
SerializationUtils.roundtrip(mockClientSupplier);
MockClient mockClient2 = mockClientSupplier2.get();
assertFalse(mockClient2.isClosed());
assertNotSame(mockClient, mockClient2);
Uninterruptibles.sleepUninterruptibly(30, TimeUnit.MILLISECONDS);
assertTrue(mockClient.isClosed());
assertTrue(mockClient2.isClosed());
}

@Test
public void testMultipleSupplierShouldNotInterfere() {
Supplier<MockClient> cf = () -> new MockClient(true);
ExpiringCloseableSupplier<MockClient> mockClientSupplier = ExpiringCloseableSupplier.of(cf, 10);
ExpiringCloseableSupplier<MockClient> mockClientSupplier2 =
ExpiringCloseableSupplier.of(cf, 10);
MockClient mockClient = mockClientSupplier.get();
MockClient mockClient2 = mockClientSupplier2.get();
Uninterruptibles.sleepUninterruptibly(30, TimeUnit.MILLISECONDS);
assertTrue(mockClient.isClosed());
assertTrue(mockClient2.isClosed());
mockClientSupplier.close();
mockClientSupplier.close();
mockClientSupplier2.close();
mockClientSupplier2.close();
}

@Test
public void stressingTestManySuppliers() {
int num = 100000; // this should be sufficient for most production use cases
Supplier<MockClient> cf = () -> new MockClient(true);
List<MockClient> clients = Lists.newArrayList();
Random random = new Random(42);
for (int i = 0; i < num; i++) {
int delayCloseInterval = random.nextInt(1000) + 1;
ExpiringCloseableSupplier<MockClient> mockClientSupplier =
ExpiringCloseableSupplier.of(cf, delayCloseInterval);
MockClient mockClient = mockClientSupplier.get();
clients.add(mockClient);
}
Awaitility.waitAtMost(5, TimeUnit.SECONDS)
.until(() -> clients.stream().allMatch(MockClient::isClosed));
}

private static class MockClient implements CloseStateful, Serializable {
boolean withException;
AtomicBoolean closed = new AtomicBoolean(false);

Expand Down
Loading