Skip to content

Commit

Permalink
Update multipart download path to first write to temp files (opensear…
Browse files Browse the repository at this point in the history
…ch-project#10347)

* Update multipart download path to write to temp files.

This change updates ReadContextListener to first write parts to a temp location
until all parts have been received.

Signed-off-by: Marc Handalian <handalm@amazon.com>

* Suppress forbidden IOUtils.fsync

Signed-off-by: Marc Handalian <handalm@amazon.com>

* Remove unnecessary logging format

Signed-off-by: Marc Handalian <handalm@amazon.com>

* sync directory after file rename

Signed-off-by: Marc Handalian <handalm@amazon.com>

* Remove flaky threadpool terminate test

Signed-off-by: Marc Handalian <handalm@amazon.com>

---------

Signed-off-by: Marc Handalian <handalm@amazon.com>
  • Loading branch information
mch2 authored and vikasvb90 committed Oct 10, 2023
1 parent 35d499d commit 47f396c
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.util.IOUtils;
import org.opensearch.action.support.GroupedActionListener;
import org.opensearch.common.SuppressForbidden;
import org.opensearch.common.UUIDs;
import org.opensearch.common.annotation.InternalApi;
import org.opensearch.common.blobstore.stream.read.ReadContext;
import org.opensearch.core.action.ActionListener;
Expand All @@ -20,6 +23,8 @@
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardCopyOption;
import java.util.Collection;
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.Executor;
Expand All @@ -33,9 +38,11 @@
@InternalApi
public class ReadContextListener implements ActionListener<ReadContext> {
private static final Logger logger = LogManager.getLogger(ReadContextListener.class);

private static final String DOWNLOAD_PREFIX = "download.";
private final String blobName;
private final Path fileLocation;
private final String tmpFileName;
private final Path tmpFileLocation;
private final ActionListener<String> completionListener;
private final ThreadPool threadPool;
private final UnaryOperator<InputStream> rateLimiter;
Expand All @@ -55,22 +62,21 @@ public ReadContextListener(
this.threadPool = threadPool;
this.rateLimiter = rateLimiter;
this.maxConcurrentStreams = maxConcurrentStreams;
this.tmpFileName = DOWNLOAD_PREFIX + UUIDs.randomBase64UUID() + "." + blobName;
this.tmpFileLocation = fileLocation.getParent().resolve(tmpFileName);
}

@Override
public void onResponse(ReadContext readContext) {
logger.debug("Received {} parts for blob {}", readContext.getNumberOfParts(), blobName);
final int numParts = readContext.getNumberOfParts();
final AtomicBoolean anyPartStreamFailed = new AtomicBoolean(false);
final GroupedActionListener<String> groupedListener = new GroupedActionListener<>(
ActionListener.wrap(r -> completionListener.onResponse(blobName), completionListener::onFailure),
numParts
);
final GroupedActionListener<String> groupedListener = new GroupedActionListener<>(getFileCompletionListener(), numParts);
final Queue<ReadContext.StreamPartCreator> queue = new ConcurrentLinkedQueue<>(readContext.getPartStreams());
final StreamPartProcessor processor = new StreamPartProcessor(
queue,
anyPartStreamFailed,
fileLocation,
tmpFileLocation,
groupedListener,
threadPool.executor(ThreadPool.Names.REMOTE_RECOVERY),
rateLimiter
Expand All @@ -80,6 +86,37 @@ public void onResponse(ReadContext readContext) {
}
}

@SuppressForbidden(reason = "need to fsync once all parts received")
private ActionListener<Collection<String>> getFileCompletionListener() {
return ActionListener.wrap(response -> {
logger.trace("renaming temp file [{}] to [{}]", tmpFileLocation, fileLocation);
try {
IOUtils.fsync(tmpFileLocation, false);
Files.move(tmpFileLocation, fileLocation, StandardCopyOption.ATOMIC_MOVE);
// sync parent dir metadata
IOUtils.fsync(fileLocation.getParent(), true);
completionListener.onResponse(blobName);
} catch (IOException e) {
logger.error("Unable to rename temp file + " + tmpFileLocation, e);
completionListener.onFailure(e);
}
}, e -> {
try {
Files.deleteIfExists(tmpFileLocation);
} catch (IOException ex) {
logger.warn("Unable to clean temp file {}", tmpFileLocation);
}
completionListener.onFailure(e);
});
}

/*
* For Tests
*/
Path getTmpFileLocation() {
return tmpFileLocation;
}

@Override
public void onFailure(Exception e) {
completionListener.onFailure(e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ public int available() {

countDownLatch.await();
assertFalse(Files.exists(fileLocation));
assertFalse(Files.exists(readContextListener.getTmpFileLocation()));
}

public void testReadContextListenerException() {
Expand All @@ -149,6 +150,68 @@ public void testReadContextListenerException() {
assertEquals(exception, listener.getException());
}

public void testWriteToTempFile() throws Exception {
final String fileName = UUID.randomUUID().toString();
Path fileLocation = path.resolve(fileName);
List<ReadContext.StreamPartCreator> blobPartStreams = initializeBlobPartStreams();
CountDownLatch countDownLatch = new CountDownLatch(1);
ActionListener<String> completionListener = new LatchedActionListener<>(new PlainActionFuture<>(), countDownLatch);
ReadContextListener readContextListener = new ReadContextListener(
TEST_SEGMENT_FILE,
fileLocation,
completionListener,
threadPool,
UnaryOperator.identity(),
MAX_CONCURRENT_STREAMS
);
ByteArrayInputStream assertingStream = new ByteArrayInputStream(randomByteArrayOfLength(PART_SIZE)) {
@Override
public int read(byte[] b) throws IOException {
assertTrue("parts written to temp file location", Files.exists(readContextListener.getTmpFileLocation()));
return super.read(b);
}
};
blobPartStreams.add(
NUMBER_OF_PARTS,
() -> CompletableFuture.supplyAsync(
() -> new InputStreamContainer(assertingStream, PART_SIZE, PART_SIZE * NUMBER_OF_PARTS),
threadPool.generic()
)
);
ReadContext readContext = new ReadContext((long) (PART_SIZE + 1) * NUMBER_OF_PARTS + 1, blobPartStreams, null);
readContextListener.onResponse(readContext);

countDownLatch.await();
assertTrue(Files.exists(fileLocation));
assertFalse(Files.exists(readContextListener.getTmpFileLocation()));
}

public void testWriteToTempFile_alreadyExists_replacesFile() throws Exception {
final String fileName = UUID.randomUUID().toString();
Path fileLocation = path.resolve(fileName);
// create an empty file at location.
Files.createFile(fileLocation);
assertEquals(0, Files.readAllBytes(fileLocation).length);
List<ReadContext.StreamPartCreator> blobPartStreams = initializeBlobPartStreams();
CountDownLatch countDownLatch = new CountDownLatch(1);
ActionListener<String> completionListener = new LatchedActionListener<>(new PlainActionFuture<>(), countDownLatch);
ReadContextListener readContextListener = new ReadContextListener(
TEST_SEGMENT_FILE,
fileLocation,
completionListener,
threadPool,
UnaryOperator.identity(),
MAX_CONCURRENT_STREAMS
);
ReadContext readContext = new ReadContext((long) (PART_SIZE + 1) * NUMBER_OF_PARTS, blobPartStreams, null);
readContextListener.onResponse(readContext);

countDownLatch.await();
assertTrue(Files.exists(fileLocation));
assertEquals(50, Files.readAllBytes(fileLocation).length);
assertFalse(Files.exists(readContextListener.getTmpFileLocation()));
}

private List<ReadContext.StreamPartCreator> initializeBlobPartStreams() {
List<ReadContext.StreamPartCreator> blobPartStreams = new ArrayList<>();
for (int partNumber = 0; partNumber < NUMBER_OF_PARTS; partNumber++) {
Expand Down

0 comments on commit 47f396c

Please sign in to comment.