Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor multipart download to a more async model #10349

Merged
merged 2 commits into from
Oct 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -241,37 +241,23 @@ public void readBlobAsync(String blobName, ActionListener<ReadContext> listener)
return;
}

final List<CompletableFuture<InputStreamContainer>> blobPartInputStreamFutures = new ArrayList<>();
final List<ReadContext.StreamPartCreator> blobPartInputStreamFutures = new ArrayList<>();
final long blobSize = blobMetadata.objectSize();
final Integer numberOfParts = blobMetadata.objectParts() == null ? null : blobMetadata.objectParts().totalPartsCount();
final String blobChecksum = blobMetadata.checksum().checksumCRC32();

if (numberOfParts == null) {
blobPartInputStreamFutures.add(getBlobPartInputStreamContainer(s3AsyncClient, bucketName, blobKey, null));
blobPartInputStreamFutures.add(() -> getBlobPartInputStreamContainer(s3AsyncClient, bucketName, blobKey, null));
} else {
// S3 multipart files use 1 to n indexing
for (int partNumber = 1; partNumber <= numberOfParts; partNumber++) {
blobPartInputStreamFutures.add(getBlobPartInputStreamContainer(s3AsyncClient, bucketName, blobKey, partNumber));
final int innerPartNumber = partNumber;
blobPartInputStreamFutures.add(
() -> getBlobPartInputStreamContainer(s3AsyncClient, bucketName, blobKey, innerPartNumber)
);
}
andrross marked this conversation as resolved.
Show resolved Hide resolved
}

CompletableFuture.allOf(blobPartInputStreamFutures.toArray(CompletableFuture[]::new))
.whenComplete((unused, partThrowable) -> {
if (partThrowable == null) {
listener.onResponse(
new ReadContext(
blobSize,
blobPartInputStreamFutures.stream().map(CompletableFuture::join).collect(Collectors.toList()),
blobChecksum
)
);
} else {
Exception ex = partThrowable.getCause() instanceof Exception
? (Exception) partThrowable.getCause()
: new Exception(partThrowable.getCause());
listener.onFailure(ex);
}
});
listener.onResponse(new ReadContext(blobSize, blobPartInputStreamFutures, blobChecksum));
});
} catch (Exception ex) {
listener.onFailure(SdkException.create("Error occurred while fetching blob parts from the repository", ex));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -969,7 +969,7 @@ public void testReadBlobAsyncMultiPart() throws Exception {
assertEquals(objectSize, readContext.getBlobSize());

for (int partNumber = 1; partNumber < objectPartCount; partNumber++) {
InputStreamContainer inputStreamContainer = readContext.getPartStreams().get(partNumber);
InputStreamContainer inputStreamContainer = readContext.getPartStreams().get(partNumber).get().join();
final int offset = partNumber * partSize;
assertEquals(partSize, inputStreamContainer.getContentLength());
assertEquals(offset, inputStreamContainer.getOffset());
Expand Down Expand Up @@ -1024,7 +1024,7 @@ public void testReadBlobAsyncSinglePart() throws Exception {
assertEquals(checksum, readContext.getBlobChecksum());
assertEquals(objectSize, readContext.getBlobSize());

InputStreamContainer inputStreamContainer = readContext.getPartStreams().stream().findFirst().get();
InputStreamContainer inputStreamContainer = readContext.getPartStreams().stream().findFirst().get().get().join();
assertEquals(objectSize, inputStreamContainer.getContentLength());
assertEquals(0, inputStreamContainer.getOffset());
assertEquals(objectSize, inputStreamContainer.getInputStream().readAllBytes().length);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -712,7 +712,8 @@ public static final IndexShard newIndexShard(
null,
null,
() -> IndexSettings.DEFAULT_REMOTE_TRANSLOG_BUFFER_INTERVAL,
nodeId
nodeId,
null
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import java.nio.file.StandardOpenOption;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
Expand Down Expand Up @@ -124,11 +125,11 @@ public void readBlobAsync(String blobName, ActionListener<ReadContext> listener)
long contentLength = listBlobs().get(blobName).length();
long partSize = contentLength / 10;
int numberOfParts = (int) ((contentLength % partSize) == 0 ? contentLength / partSize : (contentLength / partSize) + 1);
List<InputStreamContainer> blobPartStreams = new ArrayList<>();
List<ReadContext.StreamPartCreator> blobPartStreams = new ArrayList<>();
for (int partNumber = 0; partNumber < numberOfParts; partNumber++) {
long offset = partNumber * partSize;
InputStreamContainer blobPartStream = new InputStreamContainer(readBlob(blobName, offset, partSize), partSize, offset);
blobPartStreams.add(blobPartStream);
blobPartStreams.add(() -> CompletableFuture.completedFuture(blobPartStream));
}
ReadContext blobReadContext = new ReadContext(contentLength, blobPartStreams, null);
listener.onResponse(blobReadContext);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,10 @@

import org.opensearch.common.annotation.ExperimentalApi;
import org.opensearch.common.blobstore.stream.read.ReadContext;
import org.opensearch.common.blobstore.stream.read.listener.ReadContextListener;
import org.opensearch.common.blobstore.stream.write.WriteContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.threadpool.ThreadPool;

import java.io.IOException;
import java.nio.file.Path;

/**
* An extension of {@link BlobContainer} that adds {@link AsyncMultiStreamBlobContainer#asyncBlobUpload} to allow
Expand Down Expand Up @@ -45,19 +42,6 @@ public interface AsyncMultiStreamBlobContainer extends BlobContainer {
@ExperimentalApi
void readBlobAsync(String blobName, ActionListener<ReadContext> listener);

/**
* Asynchronously downloads the blob to the specified location using an executor from the thread pool.
* @param blobName The name of the blob for which needs to be downloaded.
* @param fileLocation The path on local disk where the blob needs to be downloaded.
* @param threadPool The threadpool instance which will provide the executor for performing a multipart download.
* @param completionListener Listener which will be notified when the download is complete.
*/
@ExperimentalApi
default void asyncBlobDownload(String blobName, Path fileLocation, ThreadPool threadPool, ActionListener<String> completionListener) {
ReadContextListener readContextListener = new ReadContextListener(blobName, fileLocation, threadPool, completionListener);
readBlobAsync(blobName, readContextListener);
}

/*
* Wether underlying blobContainer can verify integrity of data after transfer. If true and if expected
* checksum is provided in WriteContext, then the checksum of transferred data is compared with expected checksum
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,10 @@ public long getBlobSize() {
}

@Override
public List<InputStreamContainer> getPartStreams() {
return super.getPartStreams().stream().map(this::decryptInputStreamContainer).collect(Collectors.toList());
public List<StreamPartCreator> getPartStreams() {
return super.getPartStreams().stream()
.map(supplier -> (StreamPartCreator) () -> supplier.get().thenApply(this::decryptInputStreamContainer))
.collect(Collectors.toUnmodifiableList());
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,27 @@
import org.opensearch.common.io.InputStreamContainer;

import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.function.Supplier;

/**
* ReadContext is used to encapsulate all data needed by <code>BlobContainer#readBlobAsync</code>
*/
@ExperimentalApi
public class ReadContext {
private final long blobSize;
private final List<InputStreamContainer> partStreams;
private final List<StreamPartCreator> asyncPartStreams;
private final String blobChecksum;

public ReadContext(long blobSize, List<InputStreamContainer> partStreams, String blobChecksum) {
public ReadContext(long blobSize, List<StreamPartCreator> asyncPartStreams, String blobChecksum) {
this.blobSize = blobSize;
this.partStreams = partStreams;
this.asyncPartStreams = asyncPartStreams;
this.blobChecksum = blobChecksum;
}

public ReadContext(ReadContext readContext) {
this.blobSize = readContext.blobSize;
this.partStreams = readContext.partStreams;
this.asyncPartStreams = readContext.asyncPartStreams;
this.blobChecksum = readContext.blobChecksum;
}

Expand All @@ -39,14 +41,30 @@ public String getBlobChecksum() {
}

public int getNumberOfParts() {
return partStreams.size();
return asyncPartStreams.size();
}

public long getBlobSize() {
return blobSize;
}

public List<InputStreamContainer> getPartStreams() {
return partStreams;
public List<StreamPartCreator> getPartStreams() {
return asyncPartStreams;
}

/**
* Functional interface defining an instance that can create an async action
* to create a part of an object represented as an InputStreamContainer.
*/
@FunctionalInterface
public interface StreamPartCreator extends Supplier<CompletableFuture<InputStreamContainer>> {
/**
* Kicks off a async process to start streaming.
*
* @return When the returned future is completed, streaming has
* just begun. Clients must fully consume the resulting stream.
*/
@Override
CompletableFuture<InputStreamContainer> get();
}
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -8,83 +8,37 @@

package org.opensearch.common.blobstore.stream.read.listener;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.common.annotation.InternalApi;
import org.opensearch.common.io.Channels;
import org.opensearch.common.io.InputStreamContainer;
import org.opensearch.core.action.ActionListener;

import java.io.IOException;
import java.io.InputStream;
import java.nio.channels.FileChannel;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.UnaryOperator;

/**
* FilePartWriter transfers the provided stream into the specified file path using a {@link FileChannel}
* instance. It performs offset based writes to the file and notifies the {@link FileCompletionListener} on completion.
* instance.
*/
@InternalApi
class FilePartWriter implements Runnable {

private final int partNumber;
private final InputStreamContainer blobPartStreamContainer;
private final Path fileLocation;
private final AtomicBoolean anyPartStreamFailed;
private final ActionListener<Integer> fileCompletionListener;
private static final Logger logger = LogManager.getLogger(FilePartWriter.class);

class FilePartWriter {
// 8 MB buffer for transfer
private static final int BUFFER_SIZE = 8 * 1024 * 2024;

public FilePartWriter(
int partNumber,
InputStreamContainer blobPartStreamContainer,
Path fileLocation,
AtomicBoolean anyPartStreamFailed,
ActionListener<Integer> fileCompletionListener
) {
this.partNumber = partNumber;
this.blobPartStreamContainer = blobPartStreamContainer;
this.fileLocation = fileLocation;
this.anyPartStreamFailed = anyPartStreamFailed;
this.fileCompletionListener = fileCompletionListener;
}

@Override
public void run() {
// Ensures no writes to the file if any stream fails.
if (anyPartStreamFailed.get() == false) {
try (FileChannel outputFileChannel = FileChannel.open(fileLocation, StandardOpenOption.WRITE, StandardOpenOption.CREATE)) {
try (InputStream inputStream = blobPartStreamContainer.getInputStream()) {
long streamOffset = blobPartStreamContainer.getOffset();
final byte[] buffer = new byte[BUFFER_SIZE];
int bytesRead;
while ((bytesRead = inputStream.read(buffer)) != -1) {
Channels.writeToChannel(buffer, 0, bytesRead, outputFileChannel, streamOffset);
streamOffset += bytesRead;
}
public static void write(Path fileLocation, InputStreamContainer stream, UnaryOperator<InputStream> rateLimiter) throws IOException {
try (FileChannel outputFileChannel = FileChannel.open(fileLocation, StandardOpenOption.WRITE, StandardOpenOption.CREATE)) {
try (InputStream inputStream = rateLimiter.apply(stream.getInputStream())) {
long streamOffset = stream.getOffset();
final byte[] buffer = new byte[BUFFER_SIZE];
int bytesRead;
while ((bytesRead = inputStream.read(buffer)) != -1) {
Channels.writeToChannel(buffer, 0, bytesRead, outputFileChannel, streamOffset);
kotwanikunal marked this conversation as resolved.
Show resolved Hide resolved
streamOffset += bytesRead;
}
} catch (IOException e) {
processFailure(e);
return;
}
fileCompletionListener.onResponse(partNumber);
}
}

void processFailure(Exception e) {
try {
Files.deleteIfExists(fileLocation);
} catch (IOException ex) {
// Die silently
logger.info("Failed to delete file {} on stream failure: {}", fileLocation, ex);
}
if (anyPartStreamFailed.getAndSet(true) == false) {
fileCompletionListener.onFailure(e);
}
}
}
Loading