Skip to content

Commit

Permalink
[apache#1608] refactor: Reuse ShuffleManagerClient in ShuffleReader and ShuffleWriter
Browse files Browse the repository at this point in the history
  • Loading branch information
xumanbu committed Jun 26, 2024
1 parent f8e4329 commit 9dd463f
Show file tree
Hide file tree
Showing 6 changed files with 81 additions and 79 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@

package org.apache.spark.shuffle.reader;

import java.io.IOException;
import java.util.Objects;
import java.util.function.Function;

import scala.Product2;
import scala.collection.AbstractIterator;
Expand All @@ -29,11 +29,8 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.uniffle.client.api.ShuffleManagerClient;
import org.apache.uniffle.client.factory.ShuffleManagerClientFactory;
import org.apache.uniffle.client.request.RssReportShuffleFetchFailureRequest;
import org.apache.uniffle.client.response.RssReportShuffleFetchFailureResponse;
import org.apache.uniffle.common.ClientType;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.exception.RssFetchFailedException;

Expand All @@ -52,8 +49,8 @@ public static class Builder {
private int shuffleId;
private int partitionId;
private int stageAttemptId;
private String reportServerHost;
private int reportServerPort;
private Function<RssReportShuffleFetchFailureRequest, RssReportShuffleFetchFailureResponse>
sendReportFunc;

private Builder() {}

Expand All @@ -77,19 +74,15 @@ Builder stageAttemptId(int stageAttemptId) {
return this;
}

/**
 * Records the host of the server that shuffle fetch failures are reported to.
 *
 * @param host value stored in {@code reportServerHost}
 * @return this builder, so calls can be chained
 */
Builder reportServerHost(String host) {
this.reportServerHost = host;
return this;
}

Builder port(int port) {
this.reportServerPort = port;
Builder doReportFun(
Function<RssReportShuffleFetchFailureRequest, RssReportShuffleFetchFailureResponse>
doReportFun) {
this.sendReportFunc = doReportFun;
return this;
}

/**
 * Creates an {@code RssFetchFailedIterator} that wraps {@code iter} and is configured from this
 * builder.
 *
 * @param iter the iterator to wrap
 * @return a new iterator holding a reference to this builder and to {@code iter}
 * @throws NullPointerException if {@code appId} or {@code reportServerHost} is null
 */
<K, C> RssFetchFailedIterator<K, C> build(Iterator<Product2<K, C>> iter) {
Objects.requireNonNull(this.appId);
Objects.requireNonNull(this.reportServerHost);
// NOTE(review): sendReportFunc is not null-checked here; confirm every caller sets it before
// calling build().
return new RssFetchFailedIterator<>(this, iter);
}
}
Expand All @@ -98,37 +91,25 @@ static Builder newBuilder() {
return new Builder();
}

/**
 * Creates a gRPC {@link ShuffleManagerClient} for the given endpoint.
 *
 * @param host host of the shuffle manager server to connect to
 * @param port port of the shuffle manager server to connect to
 * @return the client produced by {@link ShuffleManagerClientFactory}
 * @throws IOException declared for callers; propagated from client creation if it occurs
 */
private static ShuffleManagerClient createShuffleManagerClient(String host, int port)
    throws IOException {
  // host is passed from spark.driver.bindAddress, which would be set when SparkContext is
  // constructed.
  ShuffleManagerClientFactory clientFactory = ShuffleManagerClientFactory.getInstance();
  return clientFactory.createShuffleManagerClient(ClientType.GRPC, host, port);
}

private RssException generateFetchFailedIfNecessary(RssFetchFailedException e) {
String driver = builder.reportServerHost;
int port = builder.reportServerPort;
// todo: reuse this manager client if this is a bottleneck.
try (ShuffleManagerClient client = createShuffleManagerClient(driver, port)) {
RssReportShuffleFetchFailureRequest req =
new RssReportShuffleFetchFailureRequest(
builder.appId,
builder.shuffleId,
builder.stageAttemptId,
builder.partitionId,
e.getMessage());
RssReportShuffleFetchFailureResponse response = client.reportShuffleFetchFailure(req);
if (response.getReSubmitWholeStage()) {
// since we are going to roll out the whole stage, mapIndex shouldn't matter, hence -1 is
// provided.
FetchFailedException ffe =
RssSparkShuffleUtils.createFetchFailedException(
builder.shuffleId, -1, builder.partitionId, e);
return new RssException(ffe);
}
} catch (IOException ioe) {
LOG.info("Error closing shuffle manager client with error:", ioe);
/**
 * Reports a fetch failure through {@code doReportFun} and picks the exception the caller should
 * throw.
 *
 * <p>The request is built from the builder's appId, shuffleId, stageAttemptId and partitionId
 * plus the message of {@code e}. If the response asks for the whole stage to be resubmitted, a
 * {@code FetchFailedException} created by {@code RssSparkShuffleUtils} is wrapped in a new
 * {@code RssException}; otherwise {@code e} itself is returned.
 *
 * @param e the fetch failure raised by the wrapped iterator
 * @param doReportFun function that sends the failure report and returns the response
 * @return an {@code RssException} wrapping a {@code FetchFailedException} when the response
 *     requests a whole-stage resubmit, otherwise {@code e} unchanged
 */
private RssException generateFetchFailedIfNecessary(
RssFetchFailedException e,
Function<RssReportShuffleFetchFailureRequest, RssReportShuffleFetchFailureResponse>
doReportFun) {
RssReportShuffleFetchFailureRequest req =
new RssReportShuffleFetchFailureRequest(
builder.appId,
builder.shuffleId,
builder.stageAttemptId,
builder.partitionId,
e.getMessage());
// Any exception thrown by doReportFun is not caught here and propagates to the caller.
RssReportShuffleFetchFailureResponse response = doReportFun.apply(req);
if (response.getReSubmitWholeStage()) {
// since we are going to roll out the whole stage, mapIndex shouldn't matter, hence -1 is
// provided.
FetchFailedException ffe =
RssSparkShuffleUtils.createFetchFailedException(
builder.shuffleId, -1, builder.partitionId, e);
return new RssException(ffe);
}
// No whole-stage resubmit requested: hand the original exception back unchanged.
return e;
}
Expand All @@ -138,7 +119,7 @@ public boolean hasNext() {
try {
return this.iter.hasNext();
} catch (RssFetchFailedException e) {
throw generateFetchFailedIfNecessary(e);
throw generateFetchFailedIfNecessary(e, builder.sendReportFunc);
}
}

Expand All @@ -147,7 +128,7 @@ public Product2<K, C> next() {
try {
return this.iter.next();
} catch (RssFetchFailedException e) {
throw generateFetchFailedIfNecessary(e);
throw generateFetchFailedIfNecessary(e, builder.sendReportFunc);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,7 @@ public <K, V> ShuffleWriter<K, V> getWriter(
this,
sparkConf,
shuffleWriteClient,
shuffleManagerClient,
rssHandle,
this::markFailedTask,
context,
Expand Down Expand Up @@ -721,6 +722,7 @@ public <K, C> ShuffleReader<K, C> getReaderImpl(
blockIdBitmap, startPartition, endPartition, blockIdLayout),
taskIdBitmap,
readMetrics,
shuffleManagerClient,
RssSparkConfig.toRssConf(sparkConf),
dataDistributionType,
shuffleHandleInfo.getAllPartitionServersForReader());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.uniffle.client.api.ShuffleManagerClient;
import org.apache.uniffle.client.api.ShuffleReadClient;
import org.apache.uniffle.client.factory.ShuffleClientFactory;
import org.apache.uniffle.client.util.RssClientConfig;
Expand All @@ -58,7 +59,6 @@
import org.apache.uniffle.common.config.RssConf;

import static org.apache.spark.shuffle.RssSparkConfig.RSS_RESUBMIT_STAGE_WITH_FETCH_FAILURE_ENABLED;
import static org.apache.uniffle.common.util.Constants.DRIVER_HOST;

public class RssShuffleReader<K, C> implements ShuffleReader<K, C> {
private static final Logger LOG = LoggerFactory.getLogger(RssShuffleReader.class);
Expand All @@ -83,6 +83,7 @@ public class RssShuffleReader<K, C> implements ShuffleReader<K, C> {
private ShuffleReadMetrics readMetrics;
private RssConf rssConf;
private ShuffleDataDistributionType dataDistributionType;
private ShuffleManagerClient shuffleManagerClient;

public RssShuffleReader(
int startPartition,
Expand All @@ -97,6 +98,7 @@ public RssShuffleReader(
Map<Integer, Roaring64NavigableMap> partitionToExpectBlocks,
Roaring64NavigableMap taskIdBitmap,
ShuffleReadMetrics readMetrics,
ShuffleManagerClient shuffleManagerClient,
RssConf rssConf,
ShuffleDataDistributionType dataDistributionType,
Map<Integer, List<ShuffleServerInfo>> allPartitionToServers) {
Expand All @@ -120,6 +122,7 @@ public RssShuffleReader(
this.partitionToShuffleServers = allPartitionToServers;
this.rssConf = rssConf;
this.dataDistributionType = dataDistributionType;
this.shuffleManagerClient = shuffleManagerClient;
}

@Override
Expand Down Expand Up @@ -193,16 +196,13 @@ public Void apply(TaskContext context) {
// resubmit stage and shuffle manager server port are both set
if (rssConf.getBoolean(RSS_RESUBMIT_STAGE_WITH_FETCH_FAILURE_ENABLED)
&& rssConf.getInteger(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT, 0) > 0) {
String driver = rssConf.getString(DRIVER_HOST, "");
int port = rssConf.get(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT);
resultIter =
RssFetchFailedIterator.newBuilder()
.appId(appId)
.shuffleId(shuffleId)
.partitionId(startPartition)
.stageAttemptId(context.stageAttemptNumber())
.reportServerHost(driver)
.port(port)
.doReportFun(shuffleManagerClient::reportShuffleFetchFailure)
.build(resultIter);
}
return resultIter;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,8 @@ public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
private static final Set<StatusCode> STATUS_CODE_WITHOUT_BLOCK_RESEND =
Sets.newHashSet(StatusCode.NO_REGISTER);

private ShuffleManagerClient shuffleManagerClient;

// Only for tests
@VisibleForTesting
public RssShuffleWriter(
Expand All @@ -150,6 +152,7 @@ public RssShuffleWriter(
RssShuffleManager shuffleManager,
SparkConf sparkConf,
ShuffleWriteClient shuffleWriteClient,
ShuffleManagerClient shuffleManagerClient,
RssShuffleHandle<K, V, C> rssHandle,
ShuffleHandleInfo shuffleHandleInfo,
TaskContext context) {
Expand All @@ -162,6 +165,7 @@ public RssShuffleWriter(
shuffleManager,
sparkConf,
shuffleWriteClient,
shuffleManagerClient,
rssHandle,
(tid) -> true,
shuffleHandleInfo,
Expand All @@ -179,6 +183,7 @@ private RssShuffleWriter(
RssShuffleManager shuffleManager,
SparkConf sparkConf,
ShuffleWriteClient shuffleWriteClient,
ShuffleManagerClient shuffleManagerClient,
RssShuffleHandle<K, V, C> rssHandle,
Function<String, Boolean> taskFailureCallback,
ShuffleHandleInfo shuffleHandleInfo,
Expand Down Expand Up @@ -211,6 +216,7 @@ private RssShuffleWriter(
this.taskFailureCallback = taskFailureCallback;
this.taskContext = context;
this.sparkConf = sparkConf;
this.shuffleManagerClient = shuffleManagerClient;
this.blockFailSentRetryEnabled =
sparkConf.getBoolean(
RssSparkConfig.SPARK_RSS_CONFIG_PREFIX
Expand All @@ -229,6 +235,7 @@ public RssShuffleWriter(
RssShuffleManager shuffleManager,
SparkConf sparkConf,
ShuffleWriteClient shuffleWriteClient,
ShuffleManagerClient shuffleManagerClient,
RssShuffleHandle<K, V, C> rssHandle,
Function<String, Boolean> taskFailureCallback,
TaskContext context,
Expand All @@ -242,6 +249,7 @@ public RssShuffleWriter(
shuffleManager,
sparkConf,
shuffleWriteClient,
shuffleManagerClient,
rssHandle,
taskFailureCallback,
shuffleHandleInfo,
Expand Down Expand Up @@ -835,33 +843,26 @@ private void throwFetchFailedIfNecessary(Exception e) {
taskContext.stageAttemptNumber(),
shuffleServerInfos,
e.getMessage());
RssConf rssConf = RssSparkConfig.toRssConf(sparkConf);
String driver = rssConf.getString("driver.host", "");
int port = rssConf.get(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT);
try (ShuffleManagerClient shuffleManagerClient = createShuffleManagerClient(driver, port)) {
RssReportShuffleWriteFailureResponse response =
shuffleManagerClient.reportShuffleWriteFailure(req);
if (response.getReSubmitWholeStage()) {
RssReassignServersRequest rssReassignServersRequest =
new RssReassignServersRequest(
taskContext.stageId(),
taskContext.stageAttemptNumber(),
shuffleId,
partitioner.numPartitions());
RssReassignServersResponse rssReassignServersResponse =
shuffleManagerClient.reassignOnStageResubmit(rssReassignServersRequest);
LOG.info(
"Whether the reassignment is successful: {}",
rssReassignServersResponse.isNeedReassign());
// since we are going to roll out the whole stage, mapIndex shouldn't matter, hence -1 is
// provided.
FetchFailedException ffe =
RssSparkShuffleUtils.createFetchFailedException(
shuffleId, -1, taskContext.stageAttemptNumber(), e);
throw new RssException(ffe);
}
} catch (IOException ioe) {
LOG.info("Error closing shuffle manager client with error:", ioe);
RssReportShuffleWriteFailureResponse response =
shuffleManagerClient.reportShuffleWriteFailure(req);
if (response.getReSubmitWholeStage()) {
RssReassignServersRequest rssReassignServersRequest =
new RssReassignServersRequest(
taskContext.stageId(),
taskContext.stageAttemptNumber(),
shuffleId,
partitioner.numPartitions());
RssReassignServersResponse rssReassignServersResponse =
shuffleManagerClient.reassignOnStageResubmit(rssReassignServersRequest);
LOG.info(
"Whether the reassignment is successful: {}",
rssReassignServersResponse.isNeedReassign());
// since we are going to roll out the whole stage, mapIndex shouldn't matter, hence -1 is
// provided.
FetchFailedException ffe =
RssSparkShuffleUtils.createFetchFailedException(
shuffleId, -1, taskContext.stageAttemptNumber(), e);
throw new RssException(ffe);
}
}
throw new RssException(e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import org.junit.jupiter.api.Test;
import org.roaringbitmap.longlong.Roaring64NavigableMap;

import org.apache.uniffle.client.api.ShuffleManagerClient;
import org.apache.uniffle.common.ShuffleDataDistributionType;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.config.RssClientConf;
Expand Down Expand Up @@ -93,6 +94,7 @@ public void readTest() throws Exception {
rssConf.set(RssClientConf.RSS_STORAGE_TYPE, StorageType.HDFS.name());
rssConf.set(RssClientConf.RSS_INDEX_READ_LIMIT, 1000);
rssConf.set(RssClientConf.RSS_CLIENT_READ_BUFFER_SIZE, "1000");
ShuffleManagerClient mockShuffleManagerClient = mock(ShuffleManagerClient.class);
RssShuffleReader<String, String> rssShuffleReaderSpy =
spy(
new RssShuffleReader<>(
Expand All @@ -108,6 +110,7 @@ public void readTest() throws Exception {
partitionToExpectBlocks,
taskIdBitmap,
new ShuffleReadMetrics(),
mockShuffleManagerClient,
rssConf,
ShuffleDataDistributionType.NORMAL,
partitionToServers));
Expand All @@ -131,6 +134,7 @@ public void readTest() throws Exception {
partitionToExpectBlocks,
taskIdBitmap,
new ShuffleReadMetrics(),
mockShuffleManagerClient,
rssConf,
ShuffleDataDistributionType.NORMAL,
partitionToServers));
Expand All @@ -151,6 +155,7 @@ public void readTest() throws Exception {
partitionToExpectBlocks,
Roaring64NavigableMap.bitmapOf(),
new ShuffleReadMetrics(),
mockShuffleManagerClient,
rssConf,
ShuffleDataDistributionType.NORMAL,
partitionToServers));
Expand Down
Loading

0 comments on commit 9dd463f

Please sign in to comment.