diff --git a/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/DownloadManager.java b/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/DownloadManager.java index e67aaf8adeade4..10b545ad0d1ea4 100644 --- a/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/DownloadManager.java +++ b/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/DownloadManager.java @@ -340,7 +340,7 @@ private Path downloadInExecutor( throw new IOException(getRewriterBlockedAllUrlsMessage(originalUrls)); } - for (int attempt = 0; attempt <= retries; ++attempt) { + for (int attempt = 0; ; ++attempt) { try { downloader.download( rewrittenUrls, @@ -353,12 +353,12 @@ private Path downloadInExecutor( clientEnv, type); break; - } catch (ContentLengthMismatchException e) { - if (attempt == retries) { - throw e; - } } catch (InterruptedIOException e) { throw new InterruptedException(e.getMessage()); + } catch (IOException e) { + if (!shouldRetryDownload(e, attempt)) { + throw e; + } } } @@ -372,6 +372,24 @@ private Path downloadInExecutor( return destination; } + private boolean shouldRetryDownload(IOException e, int attempt) { + if (attempt >= retries) { + return false; + } + + if (e instanceof ContentLengthMismatchException) { + return true; + } + + for (var suppressed : e.getSuppressed()) { + if (suppressed instanceof ContentLengthMismatchException) { + return true; + } + } + + return false; + } + /** * Downloads the contents of one URL and reads it into a byte array. * @@ -447,8 +465,8 @@ public byte[] downloadAndReadOneUrlForBzlmod( } HttpDownloader httpDownloader = new HttpDownloader(); - byte[] content = null; - for (int attempt = 0; attempt <= retries; ++attempt) { + byte[] content; + for (int attempt = 0; ; ++attempt) { try { content = httpDownloader.downloadAndReadOneUrl( @@ -458,12 +476,12 @@ public byte[] downloadAndReadOneUrlForBzlmod( eventHandler, clientEnv); break; - } catch (ContentLengthMismatchException e) { - if (attempt == retries) { - throw e; - } } catch (InterruptedIOException e) { throw new InterruptedException(e.getMessage()); + } catch (IOException e) { + if (!shouldRetryDownload(e, attempt)) { + throw e; + } } } if (content == null) { diff --git a/src/test/java/com/google/devtools/build/lib/bazel/repository/downloader/HttpDownloaderTest.java b/src/test/java/com/google/devtools/build/lib/bazel/repository/downloader/HttpDownloaderTest.java index 2a47760b36b7de..9d8199a5274e13 100644 --- a/src/test/java/com/google/devtools/build/lib/bazel/repository/downloader/HttpDownloaderTest.java +++ b/src/test/java/com/google/devtools/build/lib/bazel/repository/downloader/HttpDownloaderTest.java @@ -713,4 +713,49 @@ public void download_contentLengthMismatch_retries() throws Exception { String content = new String(ByteStreams.toByteArray(result.getInputStream()), UTF_8); assertThat(content).isEqualTo("content"); } + + @Test + public void download_contentLengthMismatchWithOtherErrors_retries() throws Exception { + Downloader downloader = mock(Downloader.class); + int retires = 5; + DownloadManager downloadManager = new DownloadManager(repositoryCache, downloader); + downloadManager.setRetries(retires); + AtomicInteger times = new AtomicInteger(0); + byte[] data = "content".getBytes(UTF_8); + doAnswer( + (Answer) + invocationOnMock -> { + if (times.getAndIncrement() < 3) { + IOException e = new IOException(); + e.addSuppressed(new ContentLengthMismatchException(0, data.length)); + e.addSuppressed(new IOException()); + throw e; + } + Path output = invocationOnMock.getArgument(5, Path.class); + try (OutputStream outputStream = output.getOutputStream()) { + ByteStreams.copy(new ByteArrayInputStream(data), outputStream); + } + + return null; + }) + .when(downloader) + .download(any(), any(), any(), any(), any(), any(), any(), any(), any()); + + Path result = + downloadManager.download( + ImmutableList.of(new URL("http://localhost")), + ImmutableMap.of(), + ImmutableMap.of(), + Optional.empty(), + "testCanonicalId", + Optional.empty(), + fs.getPath(workingDir.newFile().getAbsolutePath()), + eventHandler, + ImmutableMap.of(), + "testRepo"); + + assertThat(times.get()).isEqualTo(4); + String content = new String(result.getInputStream().readAllBytes(), UTF_8); + assertThat(content).isEqualTo("content"); + } }