diff --git a/server/src/main/java/org/opensearch/index/store/remote/utils/TransferManager.java b/server/src/main/java/org/opensearch/index/store/remote/utils/TransferManager.java index 2bb004028d015..cea397ccd0333 100644 --- a/server/src/main/java/org/opensearch/index/store/remote/utils/TransferManager.java +++ b/server/src/main/java/org/opensearch/index/store/remote/utils/TransferManager.java @@ -20,10 +20,14 @@ import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; +import java.io.UncheckedIOException; import java.nio.file.Files; import java.nio.file.Path; import java.security.AccessController; import java.security.PrivilegedAction; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.CountDownLatch; /** * This acts as entry point to fetch {@link BlobFetchRequest} and return actual {@link IndexInput}. Utilizes the BlobContainer interface to @@ -34,6 +38,7 @@ public class TransferManager { private static final Logger logger = LogManager.getLogger(TransferManager.class); + private final ConcurrentMap<Path, CountDownLatch> latchMap = new ConcurrentHashMap<>(); private final BlobContainer blobContainer; private final FileCache fileCache; @@ -48,41 +53,67 @@ public TransferManager(final BlobContainer blobContainer, final FileCache fileCa * @return future of IndexInput augmented with internal caching maintenance tasks */ public IndexInput fetchBlob(BlobFetchRequest blobFetchRequest) throws IOException { - final Path key = blobFetchRequest.getFilePath(); - // We need to do a privileged action here in order to fetch from remote // and write to the local file cache in case this is invoked as a side // effect of a plugin (such as a scripted search) that doesn't have the // necessary permissions. 
- final IndexInput origin = AccessController.doPrivileged( - (PrivilegedAction<IndexInput>) () -> fileCache.compute(key, (path, cachedIndexInput) -> { - if (cachedIndexInput == null) { - try { - return new FileCachedIndexInput(fileCache, blobFetchRequest.getFilePath(), downloadBlockLocally(blobFetchRequest)); - } catch (IOException e) { - logger.warn("Failed to download " + blobFetchRequest.getFilePath(), e); - return null; - } - } else { - if (cachedIndexInput.isClosed()) { - // if it's already in the file cache, but closed, open it and replace the original one - try { - final IndexInput luceneIndexInput = blobFetchRequest.getDirectory() - .openInput(blobFetchRequest.getFileName(), IOContext.READ); - return new FileCachedIndexInput(fileCache, blobFetchRequest.getFilePath(), luceneIndexInput); - } catch (IOException e) { - logger.warn("Failed to open existing file for " + blobFetchRequest.getFilePath(), e); - return null; - } - } - // already in the cache and ready to be used (open) - return cachedIndexInput; + try { + return AccessController.doPrivileged((PrivilegedAction<IndexInput>) () -> fetchBlobInternal(blobFetchRequest)); + } catch (UncheckedIOException e) { + throw e.getCause(); + } + } + + private IndexInput fetchBlobInternal(BlobFetchRequest blobFetchRequest) { + final Path key = blobFetchRequest.getFilePath(); + // check if the origin is already in block cache + IndexInput origin = fileCache.compute(key, (path, cachedIndexInput) -> { + if (cachedIndexInput != null && cachedIndexInput.isClosed()) { + // if it's already in the file cache, but closed, open it and replace the original one + try { + IndexInput luceneIndexInput = blobFetchRequest.getDirectory().openInput(blobFetchRequest.getFileName(), IOContext.READ); + return new FileCachedIndexInput(fileCache, blobFetchRequest.getFilePath(), luceneIndexInput); + } catch (IOException ioe) { + logger.warn("Open index input " + blobFetchRequest.getFilePath() + " got error ", ioe); + // open failed so return null to download the file 
again + return null; } - }) - ); + + } + // already in the cache and ready to be used (open) + return cachedIndexInput; + }); if (origin == null) { - throw new IOException("Failed to create IndexInput for " + blobFetchRequest.getFileName()); + final CountDownLatch existingLatch = latchMap.putIfAbsent(key, new CountDownLatch(1)); + if (existingLatch != null) { + // Another thread is downloading the same resource. Wait for it + // to complete then make a recursive call to fetch it from the + // cache. + try { + existingLatch.await(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new IllegalStateException("Interrupted while waiting on block download at " + key, e); + } + return fetchBlobInternal(blobFetchRequest); + } else { + // Origin is not in file cache, download origin and put in cache + // We've effectively taken a lock for this key by inserting a + // latch into the concurrent map, so we must be sure to remove it + // and count it down before leaving. 
+ try { + IndexInput downloaded = downloadBlockLocally(blobFetchRequest); + FileCachedIndexInput newOrigin = + new FileCachedIndexInput(fileCache, blobFetchRequest.getFilePath(), downloaded); + fileCache.put(key, newOrigin); + origin = newOrigin; + } catch (IOException e) { + throw new UncheckedIOException(e); + } finally { + latchMap.remove(key).countDown(); + } + } } // Origin was either retrieved from the cache or newly added, either diff --git a/server/src/test/java/org/opensearch/index/store/remote/utils/TransferManagerTests.java b/server/src/test/java/org/opensearch/index/store/remote/utils/TransferManagerTests.java index 804101038fbed..e0c0086412e7c 100644 --- a/server/src/test/java/org/opensearch/index/store/remote/utils/TransferManagerTests.java +++ b/server/src/test/java/org/opensearch/index/store/remote/utils/TransferManagerTests.java @@ -12,6 +12,7 @@ import java.io.IOException; import java.util.ArrayList; import java.util.List; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; @@ -178,6 +179,45 @@ public void testDownloadFails() throws Exception { MatcherAssert.assertThat(fileCache.usage().usage(), equalTo(0L)); } + public void testFetchesToDifferentBlobsDoNotBlockOnEachOther() throws Exception { + // Mock a call for a blob that will block until the latch is released, + // then start the fetch for that blob on a separate thread + final CountDownLatch latch = new CountDownLatch(1); + doAnswer(i -> { + latch.await(); + return new ByteArrayInputStream(createData()); + }).when(blobContainer).readBlob(eq("blocking-blob"), anyLong(), anyLong()); + final Thread blockingThread = new Thread( + () -> { + try { + transferManager.fetchBlob( + BlobFetchRequest.builder() + .blobName("blocking-blob") + .position(0) + .fileName("blocking-file") + .directory(directory) + .length(EIGHT_MB) + .build() + ); + } catch (IOException e) { + throw new 
RuntimeException(e); + } + } + ); + blockingThread.start(); + + // Assert that a different blob can be fetched and will not block on the other blob + try (IndexInput i = fetchBlobWithName("file")) { + assertIndexInputIsFunctional(i); + MatcherAssert.assertThat(fileCache.usage().activeUsage(), equalTo((long) EIGHT_MB)); + } + + assertTrue(blockingThread.isAlive()); + latch.countDown(); + blockingThread.join(5_000); + assertFalse(blockingThread.isAlive()); + } + private IndexInput fetchBlobWithName(String blobname) throws IOException { return transferManager.fetchBlob( BlobFetchRequest.builder().blobName("blob").position(0).fileName(blobname).directory(directory).length(EIGHT_MB).build()