diff --git a/src/main/java/net/snowflake/client/core/SFBaseSession.java b/src/main/java/net/snowflake/client/core/SFBaseSession.java index 1613b51a3..134a23ad2 100644 --- a/src/main/java/net/snowflake/client/core/SFBaseSession.java +++ b/src/main/java/net/snowflake/client/core/SFBaseSession.java @@ -840,6 +840,8 @@ public SFConnectionHandler getSfConnectionHandler() { public abstract int getAuthTimeout(); + public abstract int getMaxHttpRetries(); + public abstract SnowflakeConnectString getSnowflakeConnectionString(); public abstract boolean isAsyncSession(); diff --git a/src/main/java/net/snowflake/client/jdbc/ChunkDownloadContext.java b/src/main/java/net/snowflake/client/jdbc/ChunkDownloadContext.java index 804f516b5..33ab76879 100644 --- a/src/main/java/net/snowflake/client/jdbc/ChunkDownloadContext.java +++ b/src/main/java/net/snowflake/client/jdbc/ChunkDownloadContext.java @@ -54,6 +54,7 @@ public SFBaseSession getSession() { private final int networkTimeoutInMilli; private final int authTimeout; private final int socketTimeout; + private final int maxHttpRetries; private final SFBaseSession session; public ChunkDownloadContext( @@ -65,6 +66,7 @@ public ChunkDownloadContext( int networkTimeoutInMilli, int authTimeout, int socketTimeout, + int maxHttpRetries, SFBaseSession session) { this.chunkDownloader = chunkDownloader; this.resultChunk = resultChunk; @@ -74,6 +76,7 @@ public ChunkDownloadContext( this.networkTimeoutInMilli = networkTimeoutInMilli; this.authTimeout = authTimeout; this.socketTimeout = socketTimeout; + this.maxHttpRetries = maxHttpRetries; this.session = session; } } diff --git a/src/main/java/net/snowflake/client/jdbc/DefaultResultStreamProvider.java b/src/main/java/net/snowflake/client/jdbc/DefaultResultStreamProvider.java index 7fbf6502c..7f9280dc5 100644 --- a/src/main/java/net/snowflake/client/jdbc/DefaultResultStreamProvider.java +++ b/src/main/java/net/snowflake/client/jdbc/DefaultResultStreamProvider.java @@ -132,6 +132,7 @@ else if (context.getQrmk() != null) { false, // no retry parameters in url false, // no request_guid true, // retry on HTTP403 for AWS S3 + true, // no retry on http request new ExecTimeTelemetryData()); SnowflakeResultSetSerializableV1.logger.debug( diff --git a/src/main/java/net/snowflake/client/jdbc/SnowflakeChunkDownloader.java b/src/main/java/net/snowflake/client/jdbc/SnowflakeChunkDownloader.java index 030eac3c6..00c94b2ab 100644 --- a/src/main/java/net/snowflake/client/jdbc/SnowflakeChunkDownloader.java +++ b/src/main/java/net/snowflake/client/jdbc/SnowflakeChunkDownloader.java @@ -85,6 +85,7 @@ public class SnowflakeChunkDownloader implements ChunkDownloader { private final int socketTimeout; + private final int maxHttpRetries; private long memoryLimit; // the current memory usage across JVM @@ -120,7 +121,6 @@ static long getCurrentMemoryUsage() { private static final long downloadedConditionTimeoutInSeconds = HttpUtil.getDownloadedConditionTimeoutInSeconds(); - private static final int MAX_NUM_OF_RETRY = 10; private static final int MAX_RETRY_JITTER = 1000; // milliseconds // Only controls the max retry number when prefetch runs out of memory @@ -194,6 +194,7 @@ public SnowflakeChunkDownloader(SnowflakeResultSetSerializableV1 resultSetSerial this.networkTimeoutInMilli = resultSetSerializable.getNetworkTimeoutInMilli(); this.authTimeout = resultSetSerializable.getAuthTimeout(); this.socketTimeout = resultSetSerializable.getSocketTimeout(); + this.maxHttpRetries = resultSetSerializable.getMaxHttpRetries(); this.prefetchSlots = resultSetSerializable.getResultPrefetchThreads() * 2; this.queryResultFormat = resultSetSerializable.getQueryResultFormat(); logger.debug("qrmk = {}", this.qrmk); @@ -406,6 +407,7 @@ private void startNextDownloaders() throws SnowflakeSQLException { networkTimeoutInMilli, authTimeout, socketTimeout, + maxHttpRetries, this.session)); downloaderFutures.put(nextChunkToDownload, downloaderFuture); // increment next chunk to download @@ -634,7 +636,7 @@ public SnowflakeResultChunk getNextChunkToConsume() private void waitForChunkReady(SnowflakeResultChunk currentChunk) throws InterruptedException { int retry = 0; long startTime = System.currentTimeMillis(); - while (currentChunk.getDownloadState() != DownloadState.SUCCESS && retry < MAX_NUM_OF_RETRY) { + while (currentChunk.getDownloadState() != DownloadState.SUCCESS && retry < maxHttpRetries) { logger.debug( "Thread {} is waiting for #chunk{} to be ready, current" + "chunk state is: {}, retry={}", Thread.currentThread().getId(), @@ -704,6 +706,7 @@ private void waitForChunkReady(SnowflakeResultChunk currentChunk) throws Interru networkTimeoutInMilli, authTimeout, socketTimeout, + maxHttpRetries, session)); downloaderFutures.put(nextChunkToConsume, downloaderFuture); // Only when prefetch fails due to internal memory limitation, nextChunkToDownload @@ -715,7 +718,7 @@ private void waitForChunkReady(SnowflakeResultChunk currentChunk) throws Interru } if (currentChunk.getDownloadState() == DownloadState.SUCCESS) { logger.debug("ready to consume #chunk{}, succeed retry={}", nextChunkToConsume, retry); - } else if (retry >= MAX_NUM_OF_RETRY) { + } else if (retry >= maxHttpRetries) { // stop retrying and report failure currentChunk.setDownloadState(DownloadState.FAILURE); currentChunk.setDownloadError( @@ -859,6 +862,7 @@ private static Callable getDownloadChunkCallable( final int networkTimeoutInMilli, final int authTimeout, final int socketTimeout, + final int maxHttpRetries, final SFBaseSession session) { ChunkDownloadContext downloadContext = new ChunkDownloadContext( @@ -870,6 +874,7 @@ private static Callable getDownloadChunkCallable( networkTimeoutInMilli, authTimeout, socketTimeout, + maxHttpRetries, session); return new Callable() { diff --git a/src/main/java/net/snowflake/client/jdbc/SnowflakeResultSetSerializableV1.java b/src/main/java/net/snowflake/client/jdbc/SnowflakeResultSetSerializableV1.java index 2fd529d5e..27d24b44d 100644 --- a/src/main/java/net/snowflake/client/jdbc/SnowflakeResultSetSerializableV1.java +++ b/src/main/java/net/snowflake/client/jdbc/SnowflakeResultSetSerializableV1.java @@ -119,6 +119,7 @@ public String toString() { int networkTimeoutInMilli; int authTimeout; int socketTimeout; + int maxHttpRetries; boolean isResultColumnCaseInsensitive; int resultSetType; int resultSetConcurrency; @@ -195,6 +196,7 @@ private SnowflakeResultSetSerializableV1(SnowflakeResultSetSerializableV1 toCopy this.networkTimeoutInMilli = toCopy.networkTimeoutInMilli; this.authTimeout = toCopy.authTimeout; this.socketTimeout = toCopy.socketTimeout; + this.maxHttpRetries = toCopy.maxHttpRetries; this.isResultColumnCaseInsensitive = toCopy.isResultColumnCaseInsensitive; this.resultSetType = toCopy.resultSetType; this.resultSetConcurrency = toCopy.resultSetConcurrency; @@ -317,6 +319,10 @@ public int getSocketTimeout() { return socketTimeout; } + public int getMaxHttpRetries() { + return maxHttpRetries; + } + public int getResultPrefetchThreads() { return resultPrefetchThreads; } @@ -662,6 +668,7 @@ public static SnowflakeResultSetSerializableV1 create( resultSetSerializable.snowflakeConnectionString = sfSession.getSnowflakeConnectionString(); resultSetSerializable.networkTimeoutInMilli = sfSession.getNetworkTimeoutInMilli(); resultSetSerializable.authTimeout = 0; + resultSetSerializable.maxHttpRetries = sfSession.getMaxHttpRetries(); resultSetSerializable.isResultColumnCaseInsensitive = sfSession.isResultColumnCaseInsensitive(); resultSetSerializable.treatNTZAsUTC = sfSession.getTreatNTZAsUTC(); resultSetSerializable.formatDateWithTimezone = sfSession.getFormatDateWithTimezone(); diff --git a/src/test/java/net/snowflake/client/jdbc/ChunkDownloaderS3RetryUrlLatestIT.java b/src/test/java/net/snowflake/client/jdbc/ChunkDownloaderS3RetryUrlLatestIT.java index 9cc5b67b7..1961ed14f 100644 --- a/src/test/java/net/snowflake/client/jdbc/ChunkDownloaderS3RetryUrlLatestIT.java +++ b/src/test/java/net/snowflake/client/jdbc/ChunkDownloaderS3RetryUrlLatestIT.java @@ -53,7 +53,7 @@ public void setup() throws SQLException, InterruptedException { ((SnowflakeResultSetSerializableV1) resultSetSerializable).getChunkHeadersMap(); sfContext = new ChunkDownloadContext( - downloader, chunk, qrmk, 0, chunkHeadersMap, 0, 0, 0, sfBaseSession); + downloader, chunk, qrmk, 0, chunkHeadersMap, 0, 0, 0, 7, sfBaseSession); } /** diff --git a/src/test/java/net/snowflake/client/jdbc/MockConnectionTest.java b/src/test/java/net/snowflake/client/jdbc/MockConnectionTest.java index 09a0dfb9d..f3f7d0405 100644 --- a/src/test/java/net/snowflake/client/jdbc/MockConnectionTest.java +++ b/src/test/java/net/snowflake/client/jdbc/MockConnectionTest.java @@ -711,6 +711,14 @@ public int getAuthTimeout() { return 0; } + /** + * @return + */ + @Override + public int getMaxHttpRetries() { + return 7; + } + public SnowflakeConnectString getSnowflakeConnectionString() { return null; } diff --git a/src/test/java/net/snowflake/client/jdbc/SnowflakeChunkDownloaderLatestIT.java b/src/test/java/net/snowflake/client/jdbc/SnowflakeChunkDownloaderLatestIT.java new file mode 100644 index 000000000..76c3b0466 --- /dev/null +++ b/src/test/java/net/snowflake/client/jdbc/SnowflakeChunkDownloaderLatestIT.java @@ -0,0 +1,60 @@ +/* + * Copyright (c) 2022 Snowflake Computing Inc. All right reserved. + */ +package net.snowflake.client.jdbc; + +import static org.junit.Assert.assertTrue; + +import java.sql.Connection; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.List; +import java.util.Properties; +import org.junit.Test; +import org.mockito.Mockito; + +public class SnowflakeChunkDownloaderLatestIT extends BaseJDBCTest { + + /** + * Tests that the chunk downloader uses the maxHttpRetries and doesn't enter and infinite loop of + * retries. + * + * @throws SQLException + * @throws InterruptedException + */ + @Test + public void testChunkDownloaderRetry() throws SQLException, InterruptedException { + // set proxy to invalid host and bypass the snowflakecomputing.com domain + // this will cause connection issues to the internal stage on fetching + System.setProperty("https.proxyHost", "127.0.0.1"); + System.setProperty("https.proxyPort", "8080"); + System.setProperty("http.nonProxyHosts", "*snowflakecomputing.com"); + + // set max retries + Properties properties = new Properties(); + properties.put("maxHttpRetries", 2); + + SnowflakeChunkDownloader snowflakeChunkDownloaderSpy = null; + + try (Connection connection = getConnection(properties)) { + Statement statement = connection.createStatement(); + // execute a query that will require chunk downloading + ResultSet resultSet = + statement.executeQuery( + "select seq8(), randstr(1000, random()) from table(generator(rowcount => 10000))"); + List resultSetSerializables = + ((SnowflakeResultSet) resultSet).getResultSetSerializables(100 * 1024 * 1024); + SnowflakeResultSetSerializable resultSetSerializable = resultSetSerializables.get(0); + SnowflakeChunkDownloader downloader = + new SnowflakeChunkDownloader((SnowflakeResultSetSerializableV1) resultSetSerializable); + snowflakeChunkDownloaderSpy = Mockito.spy(downloader); + snowflakeChunkDownloaderSpy.getNextChunkToConsume(); + } catch (SnowflakeSQLException exception) { + // verify that request was retried twice before reaching max retries + Mockito.verify(snowflakeChunkDownloaderSpy, Mockito.times(2)).getResultStreamProvider(); + assertTrue(exception.getMessage().contains("Max retry reached for the download of #chunk0")); + assertTrue(exception.getMessage().contains("retry=2")); + } + } +}