
Commit

add CloseStateful interface
xumanbu committed Jul 24, 2024
1 parent 1499396 commit 0bb986c
Showing 21 changed files with 260 additions and 176 deletions.
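Across the visible hunks, the change follows one pattern: fields, parameters, and return types that previously used the concrete ExpireCloseableSupplier<ShuffleManagerClient> now use the plain java.util.function.Supplier<ShuffleManagerClient> interface, while the concrete closeable supplier (ExpiringCloseableSupplier, replacing the earlier ExpireCloseableSupplier references) appears only where the supplier is constructed and closed. As a rough illustration of the idea behind such a supplier, here is a minimal, hypothetical sketch; the class name and details below are assumptions, not the actual org.apache.uniffle.common.util.ExpiringCloseableSupplier.

```java
import java.util.function.Supplier;

// Hypothetical sketch only: a supplier that lazily creates an AutoCloseable delegate,
// hands it out via get(), and can be closed explicitly by its owner.
public class LazyCloseableSupplier<T extends AutoCloseable> implements Supplier<T>, AutoCloseable {
  private final Supplier<T> factory;
  private T delegate;

  public LazyCloseableSupplier(Supplier<T> factory) {
    this.factory = factory;
  }

  @Override
  public synchronized T get() {
    if (delegate == null) {
      delegate = factory.get(); // create on first use, or re-create after close()
    }
    return delegate;
  }

  @Override
  public synchronized void close() throws Exception {
    if (delegate != null) {
      delegate.close(); // release the underlying resource (e.g. a gRPC client)
      delegate = null;
    }
  }
}
```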
---- changed file ----
@@ -24,6 +24,7 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Supplier;

import scala.Option;
import scala.reflect.ClassTag;
@@ -52,7 +53,6 @@
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.exception.RssFetchFailedException;
import org.apache.uniffle.common.util.Constants;
import org.apache.uniffle.common.util.ExpireCloseableSupplier;

import static org.apache.spark.shuffle.RssSparkConfig.RSS_RESUBMIT_STAGE_WITH_FETCH_FAILURE_ENABLED;

@@ -343,7 +343,7 @@ public static boolean isStageResubmitSupported() {
}

public static RssException reportRssFetchFailedException(
ExpireCloseableSupplier<ShuffleManagerClient> managerClientSupplier,
Supplier<ShuffleManagerClient> managerClientSupplier,
RssFetchFailedException rssFetchFailedException,
SparkConf sparkConf,
String appId,
---- changed file ----
@@ -17,8 +17,8 @@

package org.apache.spark.shuffle.reader;

import java.io.IOException;
import java.util.Objects;
import java.util.function.Supplier;

import scala.Product2;
import scala.collection.AbstractIterator;
@@ -34,7 +34,6 @@
import org.apache.uniffle.client.response.RssReportShuffleFetchFailureResponse;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.exception.RssFetchFailedException;
import org.apache.uniffle.common.util.ExpireCloseableSupplier;

public class RssFetchFailedIterator<K, C> extends AbstractIterator<Product2<K, C>> {
private static final Logger LOG = LoggerFactory.getLogger(RssFetchFailedIterator.class);
@@ -51,7 +50,7 @@ public static class Builder {
private int shuffleId;
private int partitionId;
private int stageAttemptId;
private ExpireCloseableSupplier<ShuffleManagerClient> managerClientSupplier;
private Supplier<ShuffleManagerClient> managerClientSupplier;

private Builder() {}

@@ -75,8 +74,7 @@ Builder stageAttemptId(int stageAttemptId) {
return this;
}

Builder managerClientSupplier(
ExpireCloseableSupplier<ShuffleManagerClient> managerClientSupplier) {
Builder managerClientSupplier(Supplier<ShuffleManagerClient> managerClientSupplier) {
this.managerClientSupplier = managerClientSupplier;
return this;
}
@@ -92,25 +90,22 @@ static Builder newBuilder() {
}

private RssException generateFetchFailedIfNecessary(RssFetchFailedException e) {
try (ShuffleManagerClient client = builder.managerClientSupplier.get()) {
RssReportShuffleFetchFailureRequest req =
new RssReportShuffleFetchFailureRequest(
builder.appId,
builder.shuffleId,
builder.stageAttemptId,
builder.partitionId,
e.getMessage());
RssReportShuffleFetchFailureResponse response = client.reportShuffleFetchFailure(req);
if (response.getReSubmitWholeStage()) {
// since we are going to roll out the whole stage, mapIndex shouldn't matter, hence -1 is
// provided.
FetchFailedException ffe =
RssSparkShuffleUtils.createFetchFailedException(
builder.shuffleId, -1, builder.partitionId, e);
return new RssException(ffe);
}
} catch (IOException ioe) {
LOG.info("Error closing shuffle manager client with error:", ioe);
ShuffleManagerClient client = builder.managerClientSupplier.get();
RssReportShuffleFetchFailureRequest req =
new RssReportShuffleFetchFailureRequest(
builder.appId,
builder.shuffleId,
builder.stageAttemptId,
builder.partitionId,
e.getMessage());
RssReportShuffleFetchFailureResponse response = client.reportShuffleFetchFailure(req);
if (response.getReSubmitWholeStage()) {
// since we are going to roll out the whole stage, mapIndex shouldn't matter, hence -1 is
// provided.
FetchFailedException ffe =
RssSparkShuffleUtils.createFetchFailedException(
builder.shuffleId, -1, builder.partitionId, e);
return new RssException(ffe);
}
return e;
}
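Note the ownership change in the hunk above: the iterator no longer wraps the client in try-with-resources or handles an IOException from close(); it simply borrows the client from the supplier. A hedged sketch of the resulting call-site pattern, using only the types and calls that appear in the diff:

```java
// Before: the iterator owned and closed the client itself:
//   try (ShuffleManagerClient client = builder.managerClientSupplier.get()) { ... }
// After: the supplier owns the client; the iterator only borrows it.
ShuffleManagerClient client = builder.managerClientSupplier.get();
RssReportShuffleFetchFailureResponse response =
    client.reportShuffleFetchFailure(
        new RssReportShuffleFetchFailureRequest(
            builder.appId, builder.shuffleId, builder.stageAttemptId, builder.partitionId, e.getMessage()));
// No client.close() here; the owning supplier closes it (e.g. in RssShuffleManagerBase.stop()).
```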
---- changed file ----
@@ -22,6 +22,7 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Supplier;
import java.util.stream.Collectors;

import org.roaringbitmap.longlong.Roaring64NavigableMap;
@@ -35,14 +36,13 @@
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.util.BlockIdLayout;
import org.apache.uniffle.common.util.ExpireCloseableSupplier;

/**
* This class delegates the blockIds reporting/getting operations from shuffleServer side to Spark
* driver side.
*/
public class BlockIdSelfManagedShuffleWriteClient extends ShuffleWriteClientImpl {
private ExpireCloseableSupplier<ShuffleManagerClient> managerClientSupplier;
private Supplier<ShuffleManagerClient> managerClientSupplier;

public BlockIdSelfManagedShuffleWriteClient(
RssShuffleClientFactory.ExtendWriteClientBuilder builder) {
---- changed file ----
@@ -17,11 +17,12 @@

package org.apache.uniffle.shuffle;

import java.util.function.Supplier;

import org.apache.uniffle.client.api.ShuffleManagerClient;
import org.apache.uniffle.client.api.ShuffleWriteClient;
import org.apache.uniffle.client.factory.ShuffleClientFactory;
import org.apache.uniffle.client.impl.ShuffleWriteClientImpl;
import org.apache.uniffle.common.util.ExpireCloseableSupplier;

public class RssShuffleClientFactory extends ShuffleClientFactory {

@@ -42,18 +43,17 @@ public static ExtendWriteClientBuilder<?> newWriteBuilder() {
public static class ExtendWriteClientBuilder<T extends ExtendWriteClientBuilder<T>>
extends WriteClientBuilder<T> {
private boolean blockIdSelfManagedEnabled;
private ExpireCloseableSupplier<ShuffleManagerClient> managerClientSupplier;
private Supplier<ShuffleManagerClient> managerClientSupplier;

public boolean isBlockIdSelfManagedEnabled() {
return blockIdSelfManagedEnabled;
}

public ExpireCloseableSupplier<ShuffleManagerClient> getManagerClientSupplier() {
public Supplier<ShuffleManagerClient> getManagerClientSupplier() {
return managerClientSupplier;
}

public T managerClientSupplier(
ExpireCloseableSupplier<ShuffleManagerClient> managerClientSupplier) {
public T managerClientSupplier(Supplier<ShuffleManagerClient> managerClientSupplier) {
this.managerClientSupplier = managerClientSupplier;
return self();
}
---- changed file ----
@@ -31,6 +31,7 @@
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;

import com.google.common.annotations.VisibleForTesting;
@@ -81,7 +82,7 @@
import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.rpc.StatusCode;
import org.apache.uniffle.common.util.ExpireCloseableSupplier;
import org.apache.uniffle.common.util.ExpiringCloseableSupplier;
import org.apache.uniffle.common.util.RetryUtils;
import org.apache.uniffle.shuffle.BlockIdManager;

@@ -104,7 +105,7 @@ public abstract class RssShuffleManagerBase implements RssShuffleManagerInterfac
protected String clientType;

protected SparkConf sparkConf;
protected ExpireCloseableSupplier<ShuffleManagerClient> managerClientSupplier;
protected Supplier<ShuffleManagerClient> managerClientSupplier;
protected boolean rssStageRetryEnabled;
protected boolean rssStageRetryForWriteFailureEnabled;
protected boolean rssStageRetryForFetchFailureEnabled;
@@ -589,7 +590,7 @@ protected synchronized StageAttemptShuffleHandleInfo getRemoteShuffleHandleInfoW
RssPartitionToShuffleServerRequest rssPartitionToShuffleServerRequest =
new RssPartitionToShuffleServerRequest(shuffleId);
RssReassignOnStageRetryResponse rpcPartitionToShufflerServer =
getOrCreateShuffleManagerClientWrapper()
getOrCreateShuffleManagerClientSupplier()
.get()
.getPartitionToShufflerServerWithStageRetry(rssPartitionToShuffleServerRequest);
StageAttemptShuffleHandleInfo shuffleHandleInfo =
@@ -609,23 +610,22 @@ protected synchronized MutableShuffleHandleInfo getRemoteShuffleHandleInfoWithBl
RssPartitionToShuffleServerRequest rssPartitionToShuffleServerRequest =
new RssPartitionToShuffleServerRequest(shuffleId);
RssReassignOnBlockSendFailureResponse rpcPartitionToShufflerServer =
getOrCreateShuffleManagerClientWrapper()
getOrCreateShuffleManagerClientSupplier()
.get()
.getPartitionToShufflerServerWithBlockRetry(rssPartitionToShuffleServerRequest);
MutableShuffleHandleInfo shuffleHandleInfo =
MutableShuffleHandleInfo.fromProto(rpcPartitionToShufflerServer.getHandle());
return shuffleHandleInfo;
}

protected synchronized ExpireCloseableSupplier<ShuffleManagerClient>
getOrCreateShuffleManagerClientWrapper() {
protected synchronized Supplier<ShuffleManagerClient> getOrCreateShuffleManagerClientSupplier() {
if (managerClientSupplier == null) {
RssConf rssConf = RssSparkConfig.toRssConf(sparkConf);
String driver = rssConf.getString("driver.host", "");
int port = rssConf.get(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT);
long rpcTimeout = rssConf.getLong(RssBaseConf.RSS_CLIENT_TYPE_GRPC_TIMEOUT_MS);
this.managerClientSupplier =
new ExpireCloseableSupplier<>(
new ExpiringCloseableSupplier<>(
() ->
ShuffleManagerClientFactory.getInstance()
.createShuffleManagerClient(ClientType.GRPC, driver, port, rpcTimeout));
@@ -815,8 +815,9 @@ public MutableShuffleHandleInfo reassignOnBlockSendFailure(

@Override
public void stop() {
if (managerClientSupplier != null) {
managerClientSupplier.forceClose();
if (managerClientSupplier != null
&& managerClientSupplier instanceof ExpiringCloseableSupplier) {
((ExpiringCloseableSupplier<ShuffleManagerClient>) managerClientSupplier).close();
}
}

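Taken together, the RssShuffleManagerBase hunks above build the supplier lazily in getOrCreateShuffleManagerClientSupplier() and close it in stop() only when it is actually the closeable variant. A condensed sketch of that lifecycle, using the names and calls from the changed lines (surrounding wiring assumed):

```java
// Construction: wrap the gRPC client factory in an ExpiringCloseableSupplier,
// but expose it to the rest of the code as a plain Supplier.
Supplier<ShuffleManagerClient> managerClientSupplier =
    new ExpiringCloseableSupplier<>(
        () ->
            ShuffleManagerClientFactory.getInstance()
                .createShuffleManagerClient(ClientType.GRPC, driver, port, rpcTimeout));

// Use: readers/writers call get() and never close the returned client themselves.
ShuffleManagerClient client = managerClientSupplier.get();

// Shutdown (as in stop() above): close only when the supplier is the closeable variant.
if (managerClientSupplier instanceof ExpiringCloseableSupplier) {
  ((ExpiringCloseableSupplier<ShuffleManagerClient>) managerClientSupplier).close();
}
```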
---- changed file ----
@@ -215,7 +215,7 @@ public RssShuffleManager(SparkConf sparkConf, boolean isDriver) {
}
}
if (shuffleManagerRpcServiceEnabled) {
getOrCreateShuffleManagerClientWrapper();
getOrCreateShuffleManagerClientSupplier();
}
this.shuffleWriteClient =
RssShuffleClientFactory.getInstance()
---- changed file ----
@@ -19,6 +19,7 @@

import java.util.List;
import java.util.Map;
import java.util.function.Supplier;

import scala.Function0;
import scala.Function2;
@@ -54,7 +55,6 @@
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.config.RssClientConf;
import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.util.ExpireCloseableSupplier;

import static org.apache.spark.shuffle.RssSparkConfig.RSS_RESUBMIT_STAGE_WITH_FETCH_FAILURE_ENABLED;

@@ -79,7 +79,7 @@ public class RssShuffleReader<K, C> implements ShuffleReader<K, C> {
private List<ShuffleServerInfo> shuffleServerInfoList;
private Configuration hadoopConf;
private RssConf rssConf;
private ExpireCloseableSupplier<ShuffleManagerClient> managerClientSupplier;
private Supplier<ShuffleManagerClient> managerClientSupplier;

public RssShuffleReader(
int startPartition,
@@ -94,7 +94,7 @@ public RssShuffleReader(
Roaring64NavigableMap taskIdBitmap,
RssConf rssConf,
Map<Integer, List<ShuffleServerInfo>> partitionToServers,
ExpireCloseableSupplier<ShuffleManagerClient> managerClientSupplier) {
Supplier<ShuffleManagerClient> managerClientSupplier) {
this.appId = rssShuffleHandle.getAppId();
this.startPartition = startPartition;
this.endPartition = endPartition;
---- changed file ----
@@ -30,6 +30,7 @@
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;

import scala.Function1;
@@ -73,7 +74,6 @@
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.exception.RssSendFailedException;
import org.apache.uniffle.common.exception.RssWaitFailedException;
import org.apache.uniffle.common.util.ExpireCloseableSupplier;
import org.apache.uniffle.storage.util.StorageType;

import static org.apache.spark.shuffle.RssSparkConfig.RSS_RESUBMIT_STAGE_WITH_WRITE_FAILURE_ENABLED;
@@ -109,7 +109,7 @@ public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
private final Set<Long> blockIds = Sets.newConcurrentHashSet();
private TaskContext taskContext;
private SparkConf sparkConf;
private ExpireCloseableSupplier<ShuffleManagerClient> managerClientSupplier;
private Supplier<ShuffleManagerClient> managerClientSupplier;

public RssShuffleWriter(
String appId,
@@ -121,7 +121,7 @@ public RssShuffleWriter(
RssShuffleManager shuffleManager,
SparkConf sparkConf,
ShuffleWriteClient shuffleWriteClient,
ExpireCloseableSupplier<ShuffleManagerClient> managerClientSupplier,
Supplier<ShuffleManagerClient> managerClientSupplier,
RssShuffleHandle<K, V, C> rssHandle,
SimpleShuffleHandleInfo shuffleHandleInfo,
TaskContext context) {
@@ -151,7 +151,7 @@ private RssShuffleWriter(
RssShuffleManager shuffleManager,
SparkConf sparkConf,
ShuffleWriteClient shuffleWriteClient,
ExpireCloseableSupplier<ShuffleManagerClient> managerClientSupplier,
Supplier<ShuffleManagerClient> managerClientSupplier,
RssShuffleHandle<K, V, C> rssHandle,
Function<String, Boolean> taskFailureCallback,
ShuffleHandleInfo shuffleHandleInfo,
@@ -190,7 +190,7 @@ public RssShuffleWriter(
RssShuffleManager shuffleManager,
SparkConf sparkConf,
ShuffleWriteClient shuffleWriteClient,
ExpireCloseableSupplier<ShuffleManagerClient> managerClientSupplier,
Supplier<ShuffleManagerClient> managerClientSupplier,
RssShuffleHandle<K, V, C> rssHandle,
Function<String, Boolean> taskFailureCallback,
TaskContext context,
---- changed file ----
@@ -39,7 +39,7 @@
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.config.RssClientConf;
import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.util.ExpireCloseableSupplier;
import org.apache.uniffle.common.util.ExpiringCloseableSupplier;
import org.apache.uniffle.storage.handler.impl.HadoopShuffleWriteHandler;
import org.apache.uniffle.storage.util.StorageType;

@@ -103,7 +103,7 @@ public void readTest() throws Exception {
taskIdBitmap,
rssConf,
partitionToServers,
new ExpireCloseableSupplier<>(() -> mockShuffleManagerClient)));
new ExpiringCloseableSupplier<>(() -> mockShuffleManagerClient)));

validateResult(rssShuffleReaderSpy.read(), expectedData, 10);
}
---- changed file ----
@@ -55,7 +55,7 @@
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.rpc.StatusCode;
import org.apache.uniffle.common.util.ExpireCloseableSupplier;
import org.apache.uniffle.common.util.ExpiringCloseableSupplier;
import org.apache.uniffle.storage.util.StorageType;

import static org.junit.jupiter.api.Assertions.assertEquals;
@@ -127,7 +127,7 @@ public void checkBlockSendResultTest() {
manager,
conf,
mockShuffleWriteClient,
new ExpireCloseableSupplier<>(() -> mockShuffleManagerClient),
new ExpiringCloseableSupplier<>(() -> mockShuffleManagerClient),
mockHandle,
mockShuffleHandleInfo,
contextMock);
@@ -304,7 +304,7 @@ public void writeTest() throws Exception {
manager,
conf,
mockShuffleWriteClient,
new ExpireCloseableSupplier<>(() -> mockShuffleManagerClient),
new ExpiringCloseableSupplier<>(() -> mockShuffleManagerClient),
mockHandle,
mockShuffleHandleInfo,
contextMock);
@@ -418,7 +418,7 @@ public void postBlockEventTest() throws Exception {
manager,
conf,
mockWriteClient,
new ExpireCloseableSupplier<>(() -> mockShuffleManagerClient),
new ExpiringCloseableSupplier<>(() -> mockShuffleManagerClient),
mockHandle,
mockShuffleHandleInfo,
contextMock);
---- changed file ----
@@ -239,7 +239,7 @@ public RssShuffleManager(SparkConf conf, boolean isDriver) {
}
}
if (shuffleManagerRpcServiceEnabled) {
getOrCreateShuffleManagerClientWrapper();
getOrCreateShuffleManagerClientSupplier();
}
int unregisterThreadPoolSize =
sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_THREAD_POOL_SIZE);
(remaining changed files not shown)
