From 30c97520d0739cbe01451e207172a5525e8fe02d Mon Sep 17 00:00:00 2001 From: Marc Handalian Date: Tue, 3 Oct 2023 15:24:25 -0700 Subject: [PATCH] Update multipart download path to write to temp files. This change updates ReadContextListener to first write parts to a temp location until all parts have been received. Signed-off-by: Marc Handalian --- .../read/listener/ReadContextListener.java | 46 ++++++- .../listener/ReadContextListenerTests.java | 117 ++++++++++++++++-- 2 files changed, 150 insertions(+), 13 deletions(-) diff --git a/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListener.java b/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListener.java index 2914fd0c440fa..c1756d5dddf02 100644 --- a/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListener.java +++ b/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListener.java @@ -10,7 +10,10 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.apache.lucene.util.IOUtils; import org.opensearch.action.support.GroupedActionListener; +import org.opensearch.common.UUIDs; import org.opensearch.common.annotation.InternalApi; import org.opensearch.common.blobstore.stream.read.ReadContext; import org.opensearch.core.action.ActionListener; @@ -20,6 +23,8 @@ import java.io.InputStream; import java.nio.file.Files; import java.nio.file.Path; +import java.nio.file.StandardCopyOption; +import java.util.Collection; import java.util.Queue; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.Executor; @@ -33,9 +38,11 @@ @InternalApi public class ReadContextListener implements ActionListener { private static final Logger logger = LogManager.getLogger(ReadContextListener.class); - + private static final String DOWNLOAD_PREFIX = "download."; private final String blobName; private final Path fileLocation; + private final String tmpFileName; + private final Path tmpFileLocation; private final ActionListener completionListener; private final ThreadPool threadPool; private final UnaryOperator rateLimiter; @@ -55,6 +62,8 @@ public ReadContextListener( this.threadPool = threadPool; this.rateLimiter = rateLimiter; this.maxConcurrentStreams = maxConcurrentStreams; + this.tmpFileName = DOWNLOAD_PREFIX + UUIDs.randomBase64UUID() + "." + blobName; + this.tmpFileLocation = fileLocation.getParent().resolve(tmpFileName); } @Override @@ -62,15 +71,12 @@ public void onResponse(ReadContext readContext) { logger.debug("Received {} parts for blob {}", readContext.getNumberOfParts(), blobName); final int numParts = readContext.getNumberOfParts(); final AtomicBoolean anyPartStreamFailed = new AtomicBoolean(false); - final GroupedActionListener groupedListener = new GroupedActionListener<>( - ActionListener.wrap(r -> completionListener.onResponse(blobName), completionListener::onFailure), - numParts - ); + final GroupedActionListener groupedListener = new GroupedActionListener<>(getFileCompletionListener(), numParts); final Queue queue = new ConcurrentLinkedQueue<>(readContext.getPartStreams()); final StreamPartProcessor processor = new StreamPartProcessor( queue, anyPartStreamFailed, - fileLocation, + tmpFileLocation, groupedListener, threadPool.executor(ThreadPool.Names.REMOTE_RECOVERY), rateLimiter @@ -80,6 +86,34 @@ public void onResponse(ReadContext readContext) { } } + private ActionListener> getFileCompletionListener() { + return ActionListener.wrap(response -> { + logger.trace(() -> new ParameterizedMessage("renaming temp file [{}] to [{}]", tmpFileLocation, fileLocation)); + try { + IOUtils.fsync(tmpFileLocation, false); + Files.move(tmpFileLocation, fileLocation, StandardCopyOption.ATOMIC_MOVE); + completionListener.onResponse(blobName); + } catch (IOException e) { + logger.error("Unable to rename temp file + " + tmpFileLocation, e); + completionListener.onFailure(e); + } + }, e -> { + try { + Files.deleteIfExists(tmpFileLocation); + } catch (IOException ex) { + logger.warn("Unable to clean temp file {}", tmpFileLocation); + } + completionListener.onFailure(e); + }); + } + + /* + * For Tests + */ + Path getTmpFileLocation() { + return tmpFileLocation; + } + @Override public void onFailure(Exception e) { completionListener.onFailure(e); diff --git a/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListenerTests.java b/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListenerTests.java index 7e4c96cbadcda..3d4fab992d93e 100644 --- a/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListenerTests.java +++ b/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListenerTests.java @@ -17,9 +17,8 @@ import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; -import org.junit.AfterClass; +import org.junit.After; import org.junit.Before; -import org.junit.BeforeClass; import java.io.ByteArrayInputStream; import java.io.IOException; @@ -31,6 +30,7 @@ import java.util.UUID; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import java.util.function.UnaryOperator; import static org.opensearch.common.blobstore.stream.read.listener.ListenerTestUtils.CountingCompletionListener; @@ -44,19 +44,19 @@ public class ReadContextListenerTests extends OpenSearchTestCase { private Path path; - private static ThreadPool threadPool; + private ThreadPool threadPool; private static final int NUMBER_OF_PARTS = 5; private static final int PART_SIZE = 10; private static final String TEST_SEGMENT_FILE = "test_segment_file"; private static final int MAX_CONCURRENT_STREAMS = 10; - @BeforeClass - public static void setup() { + @Before + public void setup() { threadPool = new TestThreadPool(ReadContextListenerTests.class.getName()); } - @AfterClass - public static void cleanup() { + @After + public void cleanup() { threadPool.shutdown(); } @@ -130,6 +130,7 @@ public int available() { countDownLatch.await(); assertFalse(Files.exists(fileLocation)); + assertFalse(Files.exists(readContextListener.getTmpFileLocation())); } public void testReadContextListenerException() { @@ -149,6 +150,108 @@ public void testReadContextListenerException() { assertEquals(exception, listener.getException()); } + public void testWriteToTempFile() throws Exception { + final String fileName = UUID.randomUUID().toString(); + Path fileLocation = path.resolve(fileName); + List blobPartStreams = initializeBlobPartStreams(); + CountDownLatch countDownLatch = new CountDownLatch(1); + ActionListener completionListener = new LatchedActionListener<>(new PlainActionFuture<>(), countDownLatch); + ReadContextListener readContextListener = new ReadContextListener( + TEST_SEGMENT_FILE, + fileLocation, + completionListener, + threadPool, + UnaryOperator.identity(), + MAX_CONCURRENT_STREAMS + ); + ByteArrayInputStream assertingStream = new ByteArrayInputStream(randomByteArrayOfLength(PART_SIZE)) { + @Override + public int read(byte[] b) throws IOException { + assertTrue("parts written to temp file location", Files.exists(readContextListener.getTmpFileLocation())); + return super.read(b); + } + }; + blobPartStreams.add( + NUMBER_OF_PARTS, + () -> CompletableFuture.supplyAsync( + () -> new InputStreamContainer(assertingStream, PART_SIZE, PART_SIZE * NUMBER_OF_PARTS), + threadPool.generic() + ) + ); + ReadContext readContext = new ReadContext((long) (PART_SIZE + 1) * NUMBER_OF_PARTS + 1, blobPartStreams, null); + readContextListener.onResponse(readContext); + + countDownLatch.await(); + assertTrue(Files.exists(fileLocation)); + assertFalse(Files.exists(readContextListener.getTmpFileLocation())); + } + + public void testWriteToTempFile_alreadyExists_replacesFile() throws Exception { + final String fileName = UUID.randomUUID().toString(); + Path fileLocation = path.resolve(fileName); + // create an empty file at location. + Files.createFile(fileLocation); + assertEquals(0, Files.readAllBytes(fileLocation).length); + List blobPartStreams = initializeBlobPartStreams(); + CountDownLatch countDownLatch = new CountDownLatch(1); + ActionListener completionListener = new LatchedActionListener<>(new PlainActionFuture<>(), countDownLatch); + ReadContextListener readContextListener = new ReadContextListener( + TEST_SEGMENT_FILE, + fileLocation, + completionListener, + threadPool, + UnaryOperator.identity(), + MAX_CONCURRENT_STREAMS + ); + ReadContext readContext = new ReadContext((long) (PART_SIZE + 1) * NUMBER_OF_PARTS, blobPartStreams, null); + readContextListener.onResponse(readContext); + + countDownLatch.await(); + assertTrue(Files.exists(fileLocation)); + assertEquals(50, Files.readAllBytes(fileLocation).length); + assertFalse(Files.exists(readContextListener.getTmpFileLocation())); + } + + /** + * Simulate a node drop by invoking shutDownNow on the thread pool while writing a part. + */ + public void testTerminateThreadsWhileWritingParts() throws Exception { + final String fileName = UUID.randomUUID().toString(); + Path fileLocation = path.resolve(fileName); + List blobPartStreams = initializeBlobPartStreams(); + CountDownLatch countDownLatch = new CountDownLatch(1); + ActionListener completionListener = new LatchedActionListener<>(new PlainActionFuture<>(), countDownLatch); + ReadContextListener readContextListener = new ReadContextListener( + TEST_SEGMENT_FILE, + fileLocation, + completionListener, + threadPool, + UnaryOperator.identity(), + MAX_CONCURRENT_STREAMS + ); + ByteArrayInputStream assertingStream = new ByteArrayInputStream(randomByteArrayOfLength(PART_SIZE)) { + @Override + public int read(byte[] b) throws IOException { + assertTrue("parts written to temp file location", Files.exists(readContextListener.getTmpFileLocation())); + threadPool.shutdownNow(); + return super.read(b); + } + }; + blobPartStreams.add( + NUMBER_OF_PARTS, + () -> CompletableFuture.supplyAsync( + () -> new InputStreamContainer(assertingStream, PART_SIZE, PART_SIZE * NUMBER_OF_PARTS + 1), + threadPool.generic() + ) + ); + ReadContext readContext = new ReadContext((long) (PART_SIZE + 1) * NUMBER_OF_PARTS + 1, blobPartStreams, null); + readContextListener.onResponse(readContext); + countDownLatch.await(5, TimeUnit.SECONDS); + assertTrue(terminate(threadPool)); + assertFalse(Files.exists(fileLocation)); + assertFalse(Files.exists(readContextListener.getTmpFileLocation())); + } + private List initializeBlobPartStreams() { List blobPartStreams = new ArrayList<>(); for (int partNumber = 0; partNumber < NUMBER_OF_PARTS; partNumber++) {