From 2b3328fc7a1ec961068016ad108aea3dd70af24a Mon Sep 17 00:00:00 2001 From: Dawid Heyman Date: Thu, 2 Jan 2025 12:26:06 +0100 Subject: [PATCH] Add authorization code redirect request handler & tests --- ...thorizationCodeRedirectRequestHandler.java | 87 +++++++++++++++++++ ...hAuthorizationCodeAccessTokenProvider.java | 53 +---------- .../client/core/CredentialManagerTest.java | 1 + ...izationCodeRedirectRequestHandlerTest.java | 75 ++++++++++++++++ 4 files changed, 166 insertions(+), 50 deletions(-) create mode 100644 src/main/java/net/snowflake/client/core/auth/oauth/AuthorizationCodeRedirectRequestHandler.java create mode 100644 src/test/java/net/snowflake/client/core/auth/oauth/AuthorizationCodeRedirectRequestHandlerTest.java diff --git a/src/main/java/net/snowflake/client/core/auth/oauth/AuthorizationCodeRedirectRequestHandler.java b/src/main/java/net/snowflake/client/core/auth/oauth/AuthorizationCodeRedirectRequestHandler.java new file mode 100644 index 000000000..e890ebce5 --- /dev/null +++ b/src/main/java/net/snowflake/client/core/auth/oauth/AuthorizationCodeRedirectRequestHandler.java @@ -0,0 +1,87 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + +package net.snowflake.client.core.auth.oauth; + +import com.amazonaws.util.StringUtils; +import com.nimbusds.oauth2.sdk.id.State; +import com.sun.net.httpserver.HttpExchange; +import com.sun.net.httpserver.HttpHandler; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.stream.Collectors; +import net.snowflake.client.core.SFException; +import net.snowflake.client.jdbc.ErrorCode; +import net.snowflake.client.log.SFLogger; +import net.snowflake.client.log.SFLoggerFactory; +import org.apache.http.NameValuePair; +import org.apache.http.client.utils.URLEncodedUtils; + +class AuthorizationCodeRedirectRequestHandler implements HttpHandler { + + private static final SFLogger logger = + SFLoggerFactory.getLogger(AuthorizationCodeRedirectRequestHandler.class); + + private final CompletableFuture authorizationCodeFuture; + private final State expectedState; + + AuthorizationCodeRedirectRequestHandler( + CompletableFuture authorizationCodeFuture, State expectedState) { + this.authorizationCodeFuture = authorizationCodeFuture; + this.expectedState = expectedState; + } + + @Override + public void handle(HttpExchange exchange) throws IOException { + Map urlParams = + URLEncodedUtils.parse(exchange.getRequestURI(), StandardCharsets.UTF_8).stream() + .collect(Collectors.toMap(NameValuePair::getName, NameValuePair::getValue)); + String response = handleRedirectRequest(urlParams, authorizationCodeFuture, expectedState); + exchange.sendResponseHeaders(200, response.length()); + exchange.getResponseBody().write(response.getBytes(StandardCharsets.UTF_8)); + exchange.getResponseBody().close(); + } + + static String handleRedirectRequest( + Map urlParams, + CompletableFuture authorizationCodeFuture, + State expectedState) { + String response; + if (urlParams.containsKey("error")) { + response = "Authorization error: " + urlParams.get("error"); + authorizationCodeFuture.completeExceptionally( + new SFException( + ErrorCode.OAUTH_AUTHORIZATION_CODE_FLOW_ERROR, + String.format( + "Error during authorization: %s, %s", + urlParams.get("error"), urlParams.get("error_description")))); + } else if (!expectedState.getValue().equals(urlParams.get("state"))) { + authorizationCodeFuture.completeExceptionally( + new SFException( + ErrorCode.OAUTH_AUTHORIZATION_CODE_FLOW_ERROR, + String.format( + "Invalid authorization request redirection state: %s, expected: %s", + urlParams.get("state"), expectedState.getValue()))); + response = "Authorization error: invalid authorization request redirection state"; + } else { + String authorizationCode = urlParams.get("code"); + if (!StringUtils.isNullOrEmpty(authorizationCode)) { + logger.debug("Received authorization code on redirect URI"); + response = "Authorization completed successfully."; + authorizationCodeFuture.complete(authorizationCode); + } else { + authorizationCodeFuture.completeExceptionally( + new SFException( + ErrorCode.OAUTH_AUTHORIZATION_CODE_FLOW_ERROR, + String.format( + "Authorization code redirect URI server received request without authorization code; queryParams: %s", + urlParams))); + response = "Authorization error: authorization code has not been returned to the driver."; + } + } + return response; + } +} diff --git a/src/main/java/net/snowflake/client/core/auth/oauth/OAuthAuthorizationCodeAccessTokenProvider.java b/src/main/java/net/snowflake/client/core/auth/oauth/OAuthAuthorizationCodeAccessTokenProvider.java index 2b08303c6..9da067cab 100644 --- a/src/main/java/net/snowflake/client/core/auth/oauth/OAuthAuthorizationCodeAccessTokenProvider.java +++ b/src/main/java/net/snowflake/client/core/auth/oauth/OAuthAuthorizationCodeAccessTokenProvider.java @@ -26,12 +26,9 @@ import java.io.IOException; import java.net.InetSocketAddress; import java.net.URI; -import java.nio.charset.StandardCharsets; -import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; -import java.util.stream.Collectors; import net.snowflake.client.core.HttpUtil; import net.snowflake.client.core.SFException; import net.snowflake.client.core.SFLoginInput; @@ -40,8 +37,6 @@ import net.snowflake.client.jdbc.ErrorCode; import net.snowflake.client.log.SFLogger; import net.snowflake.client.log.SFLoggerFactory; -import org.apache.http.NameValuePair; -import org.apache.http.client.utils.URLEncodedUtils; @SnowflakeJdbcInternalApi public class OAuthAuthorizationCodeAccessTokenProvider implements AccessTokenProvider { @@ -151,55 +146,13 @@ private AuthorizationCode letUserAuthorize( private static CompletableFuture setupRedirectURIServerForAuthorizationCode( HttpServer httpServer, State expectedState) { - CompletableFuture accessTokenFuture = new CompletableFuture<>(); + CompletableFuture authorizationCodeFuture = new CompletableFuture<>(); httpServer.createContext( REDIRECT_URI_ENDPOINT, - exchange -> { - Map urlParams = - URLEncodedUtils.parse(exchange.getRequestURI(), StandardCharsets.UTF_8).stream() - .collect(Collectors.toMap(NameValuePair::getName, NameValuePair::getValue)); - String response = handleRedirectRequest(urlParams, accessTokenFuture, expectedState); - exchange.sendResponseHeaders(200, response.length()); - exchange.getResponseBody().write(response.getBytes(StandardCharsets.UTF_8)); - exchange.getResponseBody().close(); - }); + new AuthorizationCodeRedirectRequestHandler(authorizationCodeFuture, expectedState)); logger.debug("Starting OAuth redirect URI server @ {}", httpServer.getAddress()); httpServer.start(); - return accessTokenFuture; - } - - private static String handleRedirectRequest( - Map urlParams, - CompletableFuture accessTokenFuture, - State expectedState) { - String response; - if (urlParams.containsKey("error")) { - response = "Authorization error: " + urlParams.get("error"); - accessTokenFuture.completeExceptionally( - new SFException( - ErrorCode.OAUTH_AUTHORIZATION_CODE_FLOW_ERROR, - String.format( - "Error during authorization: %s, %s", - urlParams.get("error"), urlParams.get("error_description")))); - } else if (!expectedState.getValue().equals(urlParams.get("state"))) { - accessTokenFuture.completeExceptionally( - new SFException( - ErrorCode.OAUTH_AUTHORIZATION_CODE_FLOW_ERROR, - String.format( - "Invalid authorization request redirection state: %s, expected: %s", - urlParams.get("state"), expectedState.getValue()))); - response = "Authorization error: invalid authorization request redirection state"; - } else { - String authorizationCode = urlParams.get("code"); - if (!StringUtils.isNullOrEmpty(authorizationCode)) { - logger.debug("Received authorization code on redirect URI"); - response = "Authorization completed successfully."; - accessTokenFuture.complete(authorizationCode); - } else { - response = "Authorization error: authorization code has not been returned to the driver."; - } - } - return response; + return authorizationCodeFuture; } private static HttpServer createHttpServer(SFOauthLoginInput loginInput) throws IOException { diff --git a/src/test/java/net/snowflake/client/core/CredentialManagerTest.java b/src/test/java/net/snowflake/client/core/CredentialManagerTest.java index 796019ad9..7a41a7cf4 100644 --- a/src/test/java/net/snowflake/client/core/CredentialManagerTest.java +++ b/src/test/java/net/snowflake/client/core/CredentialManagerTest.java @@ -42,6 +42,7 @@ public static void setUp() { public static void tearDown() { CredentialManager.resetSecureStorageManager(); } + @Test public void shouldCreateHostBasedOnExternalIdpUrl() throws SFException { SFLoginInput loginInput = createLoginInputWithExternalOAuth(); diff --git a/src/test/java/net/snowflake/client/core/auth/oauth/AuthorizationCodeRedirectRequestHandlerTest.java b/src/test/java/net/snowflake/client/core/auth/oauth/AuthorizationCodeRedirectRequestHandlerTest.java new file mode 100644 index 000000000..5e80fb328 --- /dev/null +++ b/src/test/java/net/snowflake/client/core/auth/oauth/AuthorizationCodeRedirectRequestHandlerTest.java @@ -0,0 +1,75 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + +package net.snowflake.client.core.auth.oauth; + +import com.nimbusds.oauth2.sdk.id.State; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import net.snowflake.client.core.SFException; +import org.junit.Test; +import org.junit.jupiter.api.Assertions; +import org.mockito.Mockito; + +public class AuthorizationCodeRedirectRequestHandlerTest { + + CompletableFuture authorizationCodeFutureMock = Mockito.mock(CompletableFuture.class); + + @Test + public void shouldReturnSuccessResponse() { + Map params = new HashMap<>(); + params.put("code", "some authorization code"); + params.put("state", "abc"); + + String response = + AuthorizationCodeRedirectRequestHandler.handleRedirectRequest( + params, authorizationCodeFutureMock, new State("abc")); + Mockito.verify(authorizationCodeFutureMock).complete("some authorization code"); + Assertions.assertEquals("Authorization completed successfully.", response); + } + + @Test + public void shouldReturnRandomErrorResponse() { + Map params = new HashMap<>(); + params.put("error", "some random error"); + + String response = + AuthorizationCodeRedirectRequestHandler.handleRedirectRequest( + params, authorizationCodeFutureMock, new State("abc")); + Mockito.verify(authorizationCodeFutureMock) + .completeExceptionally(Mockito.any(SFException.class)); + Assertions.assertEquals("Authorization error: some random error", response); + } + + @Test + public void shouldReturnInvalidStateErrorResponse() { + Map params = new HashMap<>(); + params.put("authorization_code", "some authorization code"); + params.put("state", "invalid state"); + + String response = + AuthorizationCodeRedirectRequestHandler.handleRedirectRequest( + params, authorizationCodeFutureMock, new State("abc")); + Mockito.verify(authorizationCodeFutureMock) + .completeExceptionally(Mockito.any(SFException.class)); + Assertions.assertEquals( + "Authorization error: invalid authorization request redirection state", response); + } + + @Test + public void shouldReturnAuthorizationCodeAbsentErrorResponse() { + Map params = new HashMap<>(); + params.put("state", "abc"); + params.put("some-random-param", "some-value"); + + String response = + AuthorizationCodeRedirectRequestHandler.handleRedirectRequest( + params, authorizationCodeFutureMock, new State("abc")); + Mockito.verify(authorizationCodeFutureMock) + .completeExceptionally(Mockito.any(SFException.class)); + Assertions.assertEquals( + "Authorization error: authorization code has not been returned to the driver.", response); + } +}