diff --git a/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListener.java b/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListener.java index 4338bddb3fbe7..82a5687a49d4c 100644 --- a/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListener.java +++ b/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListener.java @@ -10,12 +10,17 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.common.UUIDs; import org.opensearch.common.annotation.InternalApi; import org.opensearch.common.blobstore.stream.read.ReadContext; import org.opensearch.core.action.ActionListener; import org.opensearch.threadpool.ThreadPool; +import java.io.IOException; +import java.nio.file.Files; import java.nio.file.Path; +import java.nio.file.StandardCopyOption; import java.util.concurrent.atomic.AtomicBoolean; /** @@ -26,8 +31,11 @@ @InternalApi public class ReadContextListener implements ActionListener { + private static final String DOWNLOAD_PREFIX = "download."; private final String fileName; private final Path fileLocation; + private final String tmpFileName; + private final Path tmpFileLocation; private final ThreadPool threadPool; private final ActionListener completionListener; private static final Logger logger = LogManager.getLogger(ReadContextListener.class); @@ -37,6 +45,8 @@ public ReadContextListener(String fileName, Path fileLocation, ThreadPool thread this.fileLocation = fileLocation; this.threadPool = threadPool; this.completionListener = completionListener; + this.tmpFileName = DOWNLOAD_PREFIX + UUIDs.randomBase64UUID() + "." + fileName; + this.tmpFileLocation = fileLocation.getParent().resolve(tmpFileName); } @Override @@ -44,13 +54,13 @@ public void onResponse(ReadContext readContext) { logger.trace("Streams received for blob {}", fileName); final int numParts = readContext.getNumberOfParts(); final AtomicBoolean anyPartStreamFailed = new AtomicBoolean(); - FileCompletionListener fileCompletionListener = new FileCompletionListener(numParts, fileName, completionListener); + final FileCompletionListener fileCompletionListener = getFileCompletionListener(numParts); for (int partNumber = 0; partNumber < numParts; partNumber++) { FilePartWriter filePartWriter = new FilePartWriter( partNumber, readContext.getPartStreams().get(partNumber), - fileLocation, + tmpFileLocation, anyPartStreamFailed, fileCompletionListener ); @@ -58,6 +68,34 @@ public void onResponse(ReadContext readContext) { } } + private FileCompletionListener getFileCompletionListener(int numParts) { + ActionListener wrappedListener = ActionListener.wrap(response -> { + logger.trace(() -> new ParameterizedMessage("renaming temp file [{}] to [{}]", tmpFileLocation, fileLocation)); + try { + Files.move(tmpFileLocation, fileLocation, StandardCopyOption.ATOMIC_MOVE); + completionListener.onResponse(fileName); + } catch (IOException e) { + logger.error("Unable to rename temp file + " + tmpFileLocation, e); + completionListener.onFailure(e); + } + }, e -> { + try { + Files.deleteIfExists(tmpFileLocation); + } catch (IOException ex) { + logger.warn("Unable to clean temp file {}", tmpFileLocation); + } + completionListener.onFailure(e); + }); + return new FileCompletionListener(numParts, tmpFileName, wrappedListener); + } + + /* + * For Tests + */ + Path getTmpFileLocation() { + return tmpFileLocation; + } + @Override public void onFailure(Exception e) { completionListener.onFailure(e); diff --git a/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListenerTests.java b/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListenerTests.java index 21b7b47390a9b..f4955470c8ec9 100644 --- a/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListenerTests.java +++ b/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListenerTests.java @@ -17,9 +17,8 @@ import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; -import org.junit.AfterClass; +import org.junit.After; import org.junit.Before; -import org.junit.BeforeClass; import java.io.ByteArrayInputStream; import java.io.IOException; @@ -42,18 +41,18 @@ public class ReadContextListenerTests extends OpenSearchTestCase { private Path path; - private static ThreadPool threadPool; + private ThreadPool threadPool; private static final int NUMBER_OF_PARTS = 5; private static final int PART_SIZE = 10; private static final String TEST_SEGMENT_FILE = "test_segment_file"; - @BeforeClass - public static void setup() { + @Before + public void setup() { threadPool = new TestThreadPool(ReadContextListenerTests.class.getName()); } - @AfterClass - public static void cleanup() { + @After + public void cleanup() { threadPool.shutdown(); } @@ -107,6 +106,7 @@ public int available() { countDownLatch.await(); assertFalse(Files.exists(fileLocation)); + assertFalse(Files.exists(readContextListener.getTmpFileLocation())); } public void testReadContextListenerException() { @@ -119,6 +119,75 @@ public void testReadContextListenerException() { assertEquals(exception, listener.getException()); } + public void testWriteToTempFile() throws Exception { + final String fileName = UUID.randomUUID().toString(); + Path fileLocation = path.resolve(fileName); + List blobPartStreams = initializeBlobPartStreams(); + CountDownLatch countDownLatch = new CountDownLatch(1); + ActionListener completionListener = new LatchedActionListener<>(new PlainActionFuture<>(), countDownLatch); + ReadContextListener readContextListener = new ReadContextListener(TEST_SEGMENT_FILE, fileLocation, threadPool, completionListener); + ByteArrayInputStream assertingStream = new ByteArrayInputStream(randomByteArrayOfLength(PART_SIZE)) { + @Override + public int read(byte[] b) throws IOException { + assertTrue("parts written to temp file location", Files.exists(readContextListener.getTmpFileLocation())); + return super.read(b); + } + }; + blobPartStreams.add(NUMBER_OF_PARTS, new InputStreamContainer(assertingStream, PART_SIZE, PART_SIZE * NUMBER_OF_PARTS + 1)); + ReadContext readContext = new ReadContext((long) (PART_SIZE + 1) * NUMBER_OF_PARTS + 1, blobPartStreams, null); + readContextListener.onResponse(readContext); + + countDownLatch.await(); + assertTrue(Files.exists(fileLocation)); + assertFalse(Files.exists(readContextListener.getTmpFileLocation())); + } + + public void testWriteToTempFile_alreadyExists_replacesFile() throws Exception { + final String fileName = UUID.randomUUID().toString(); + Path fileLocation = path.resolve(fileName); + // create an empty file at location. + Files.createFile(fileLocation); + assertEquals(0, Files.readAllBytes(fileLocation).length); + List blobPartStreams = initializeBlobPartStreams(); + CountDownLatch countDownLatch = new CountDownLatch(1); + ActionListener completionListener = new LatchedActionListener<>(new PlainActionFuture<>(), countDownLatch); + ReadContextListener readContextListener = new ReadContextListener(TEST_SEGMENT_FILE, fileLocation, threadPool, completionListener); + ReadContext readContext = new ReadContext((long) (PART_SIZE + 1) * NUMBER_OF_PARTS, blobPartStreams, null); + readContextListener.onResponse(readContext); + + countDownLatch.await(); + assertTrue(Files.exists(fileLocation)); + assertEquals(50, Files.readAllBytes(fileLocation).length); + assertFalse(Files.exists(readContextListener.getTmpFileLocation())); + } + + /** + * Simulate a node drop by invoking shutDownNow on the thread pool while writing a chunk. + */ + public void testTerminateThreadsWhileWritingParts() throws Exception { + final String fileName = UUID.randomUUID().toString(); + Path fileLocation = path.resolve(fileName); + List blobPartStreams = initializeBlobPartStreams(); + CountDownLatch countDownLatch = new CountDownLatch(1); + ActionListener completionListener = new LatchedActionListener<>(new PlainActionFuture<>(), countDownLatch); + ReadContextListener readContextListener = new ReadContextListener(TEST_SEGMENT_FILE, fileLocation, threadPool, completionListener); + ByteArrayInputStream assertingStream = new ByteArrayInputStream(randomByteArrayOfLength(PART_SIZE)) { + @Override + public int read(byte[] b) throws IOException { + assertTrue("parts written to temp file location", Files.exists(readContextListener.getTmpFileLocation())); + threadPool.shutdownNow(); + return super.read(b); + } + }; + blobPartStreams.add(NUMBER_OF_PARTS, new InputStreamContainer(assertingStream, PART_SIZE, PART_SIZE * NUMBER_OF_PARTS + 1)); + ReadContext readContext = new ReadContext((long) (PART_SIZE + 1) * NUMBER_OF_PARTS + 1, blobPartStreams, null); + readContextListener.onResponse(readContext); + countDownLatch.await(); + assertTrue(terminate(threadPool)); + assertFalse(Files.exists(fileLocation)); + assertFalse(Files.exists(readContextListener.getTmpFileLocation())); + } + private List initializeBlobPartStreams() { List blobPartStreams = new ArrayList<>(); for (int partNumber = 0; partNumber < NUMBER_OF_PARTS; partNumber++) {