Skip to content

Commit

Permalink
Add authorization code redirect request handler & tests
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-dheyman committed Jan 2, 2025
1 parent db9949c commit 2b3328f
Show file tree
Hide file tree
Showing 4 changed files with 166 additions and 50 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
/*
* Copyright (c) 2024 Snowflake Computing Inc. All rights reserved.
*/

package net.snowflake.client.core.auth.oauth;

import com.amazonaws.util.StringUtils;
import com.nimbusds.oauth2.sdk.id.State;
import com.sun.net.httpserver.HttpExchange;
import com.sun.net.httpserver.HttpHandler;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;
import net.snowflake.client.core.SFException;
import net.snowflake.client.jdbc.ErrorCode;
import net.snowflake.client.log.SFLogger;
import net.snowflake.client.log.SFLoggerFactory;
import org.apache.http.NameValuePair;
import org.apache.http.client.utils.URLEncodedUtils;

class AuthorizationCodeRedirectRequestHandler implements HttpHandler {

private static final SFLogger logger =
SFLoggerFactory.getLogger(AuthorizationCodeRedirectRequestHandler.class);

private final CompletableFuture<String> authorizationCodeFuture;
private final State expectedState;

AuthorizationCodeRedirectRequestHandler(
CompletableFuture<String> authorizationCodeFuture, State expectedState) {
this.authorizationCodeFuture = authorizationCodeFuture;
this.expectedState = expectedState;
}

@Override
public void handle(HttpExchange exchange) throws IOException {
Map<String, String> urlParams =
URLEncodedUtils.parse(exchange.getRequestURI(), StandardCharsets.UTF_8).stream()
.collect(Collectors.toMap(NameValuePair::getName, NameValuePair::getValue));
String response = handleRedirectRequest(urlParams, authorizationCodeFuture, expectedState);
exchange.sendResponseHeaders(200, response.length());
exchange.getResponseBody().write(response.getBytes(StandardCharsets.UTF_8));
exchange.getResponseBody().close();
}

static String handleRedirectRequest(
Map<String, String> urlParams,
CompletableFuture<String> authorizationCodeFuture,
State expectedState) {
String response;
if (urlParams.containsKey("error")) {
response = "Authorization error: " + urlParams.get("error");
authorizationCodeFuture.completeExceptionally(
new SFException(
ErrorCode.OAUTH_AUTHORIZATION_CODE_FLOW_ERROR,
String.format(
"Error during authorization: %s, %s",
urlParams.get("error"), urlParams.get("error_description"))));
} else if (!expectedState.getValue().equals(urlParams.get("state"))) {
authorizationCodeFuture.completeExceptionally(
new SFException(
ErrorCode.OAUTH_AUTHORIZATION_CODE_FLOW_ERROR,
String.format(
"Invalid authorization request redirection state: %s, expected: %s",
urlParams.get("state"), expectedState.getValue())));
response = "Authorization error: invalid authorization request redirection state";
} else {
String authorizationCode = urlParams.get("code");
if (!StringUtils.isNullOrEmpty(authorizationCode)) {
logger.debug("Received authorization code on redirect URI");
response = "Authorization completed successfully.";
authorizationCodeFuture.complete(authorizationCode);
} else {
authorizationCodeFuture.completeExceptionally(
new SFException(
ErrorCode.OAUTH_AUTHORIZATION_CODE_FLOW_ERROR,
String.format(
"Authorization code redirect URI server received request without authorization code; queryParams: %s",
urlParams)));
response = "Authorization error: authorization code has not been returned to the driver.";
}
}
return response;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,9 @@
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.stream.Collectors;
import net.snowflake.client.core.HttpUtil;
import net.snowflake.client.core.SFException;
import net.snowflake.client.core.SFLoginInput;
Expand All @@ -40,8 +37,6 @@
import net.snowflake.client.jdbc.ErrorCode;
import net.snowflake.client.log.SFLogger;
import net.snowflake.client.log.SFLoggerFactory;
import org.apache.http.NameValuePair;
import org.apache.http.client.utils.URLEncodedUtils;

@SnowflakeJdbcInternalApi
public class OAuthAuthorizationCodeAccessTokenProvider implements AccessTokenProvider {
Expand Down Expand Up @@ -151,55 +146,13 @@ private AuthorizationCode letUserAuthorize(

private static CompletableFuture<String> setupRedirectURIServerForAuthorizationCode(
HttpServer httpServer, State expectedState) {
CompletableFuture<String> accessTokenFuture = new CompletableFuture<>();
CompletableFuture<String> authorizationCodeFuture = new CompletableFuture<>();
httpServer.createContext(
REDIRECT_URI_ENDPOINT,
exchange -> {
Map<String, String> urlParams =
URLEncodedUtils.parse(exchange.getRequestURI(), StandardCharsets.UTF_8).stream()
.collect(Collectors.toMap(NameValuePair::getName, NameValuePair::getValue));
String response = handleRedirectRequest(urlParams, accessTokenFuture, expectedState);
exchange.sendResponseHeaders(200, response.length());
exchange.getResponseBody().write(response.getBytes(StandardCharsets.UTF_8));
exchange.getResponseBody().close();
});
new AuthorizationCodeRedirectRequestHandler(authorizationCodeFuture, expectedState));
logger.debug("Starting OAuth redirect URI server @ {}", httpServer.getAddress());
httpServer.start();
return accessTokenFuture;
}

private static String handleRedirectRequest(
Map<String, String> urlParams,
CompletableFuture<String> accessTokenFuture,
State expectedState) {
String response;
if (urlParams.containsKey("error")) {
response = "Authorization error: " + urlParams.get("error");
accessTokenFuture.completeExceptionally(
new SFException(
ErrorCode.OAUTH_AUTHORIZATION_CODE_FLOW_ERROR,
String.format(
"Error during authorization: %s, %s",
urlParams.get("error"), urlParams.get("error_description"))));
} else if (!expectedState.getValue().equals(urlParams.get("state"))) {
accessTokenFuture.completeExceptionally(
new SFException(
ErrorCode.OAUTH_AUTHORIZATION_CODE_FLOW_ERROR,
String.format(
"Invalid authorization request redirection state: %s, expected: %s",
urlParams.get("state"), expectedState.getValue())));
response = "Authorization error: invalid authorization request redirection state";
} else {
String authorizationCode = urlParams.get("code");
if (!StringUtils.isNullOrEmpty(authorizationCode)) {
logger.debug("Received authorization code on redirect URI");
response = "Authorization completed successfully.";
accessTokenFuture.complete(authorizationCode);
} else {
response = "Authorization error: authorization code has not been returned to the driver.";
}
}
return response;
return authorizationCodeFuture;
}

private static HttpServer createHttpServer(SFOauthLoginInput loginInput) throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ public static void setUp() {
public static void tearDown() {
CredentialManager.resetSecureStorageManager();
}

@Test
public void shouldCreateHostBasedOnExternalIdpUrl() throws SFException {
SFLoginInput loginInput = createLoginInputWithExternalOAuth();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/*
* Copyright (c) 2024 Snowflake Computing Inc. All rights reserved.
*/

package net.snowflake.client.core.auth.oauth;

import com.nimbusds.oauth2.sdk.id.State;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import net.snowflake.client.core.SFException;
import org.junit.Test;
import org.junit.jupiter.api.Assertions;
import org.mockito.Mockito;

public class AuthorizationCodeRedirectRequestHandlerTest {

CompletableFuture<String> authorizationCodeFutureMock = Mockito.mock(CompletableFuture.class);

@Test
public void shouldReturnSuccessResponse() {
Map<String, String> params = new HashMap<>();
params.put("code", "some authorization code");
params.put("state", "abc");

String response =
AuthorizationCodeRedirectRequestHandler.handleRedirectRequest(
params, authorizationCodeFutureMock, new State("abc"));
Mockito.verify(authorizationCodeFutureMock).complete("some authorization code");
Assertions.assertEquals("Authorization completed successfully.", response);
}

@Test
public void shouldReturnRandomErrorResponse() {
Map<String, String> params = new HashMap<>();
params.put("error", "some random error");

String response =
AuthorizationCodeRedirectRequestHandler.handleRedirectRequest(
params, authorizationCodeFutureMock, new State("abc"));
Mockito.verify(authorizationCodeFutureMock)
.completeExceptionally(Mockito.any(SFException.class));
Assertions.assertEquals("Authorization error: some random error", response);
}

@Test
public void shouldReturnInvalidStateErrorResponse() {
Map<String, String> params = new HashMap<>();
params.put("authorization_code", "some authorization code");
params.put("state", "invalid state");

String response =
AuthorizationCodeRedirectRequestHandler.handleRedirectRequest(
params, authorizationCodeFutureMock, new State("abc"));
Mockito.verify(authorizationCodeFutureMock)
.completeExceptionally(Mockito.any(SFException.class));
Assertions.assertEquals(
"Authorization error: invalid authorization request redirection state", response);
}

@Test
public void shouldReturnAuthorizationCodeAbsentErrorResponse() {
Map<String, String> params = new HashMap<>();
params.put("state", "abc");
params.put("some-random-param", "some-value");

String response =
AuthorizationCodeRedirectRequestHandler.handleRedirectRequest(
params, authorizationCodeFutureMock, new State("abc"));
Mockito.verify(authorizationCodeFutureMock)
.completeExceptionally(Mockito.any(SFException.class));
Assertions.assertEquals(
"Authorization error: authorization code has not been returned to the driver.", response);
}
}

0 comments on commit 2b3328f

Please sign in to comment.