Skip to content

Commit

Permalink
Merge pull request microsoft#6 from lilgreenbird/msal
Browse files Browse the repository at this point in the history
Msal
  • Loading branch information
srnagar authored Oct 6, 2020
2 parents ca36f34 + 2bf012b commit 9e38cc3
Show file tree
Hide file tree
Showing 11 changed files with 85 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class KeyVaultCustomCredentialPolicy implements HttpPipelinePolicy {
*/
@Override
public Mono<HttpResponse> process(HttpPipelineCallContext context, HttpPipelineNextPolicy next) {
if ("http".equals(context.getHttpRequest().getUrl().getProtocol())) {
if ("https".equals(context.getHttpRequest().getUrl().getProtocol())) {
return Mono.error(new RuntimeException(SQLServerException.getErrString("R_TokenRequireUrl")));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ class KeyVaultTokenCredential implements TokenCredential {
/**
* Creates a KeyVaultCredential with the given identity client options.
*
* @param authenticationCallback The authentication callback that gets invoked when an access token is requested.
* @param authenticationCallback
* The authentication callback that gets invoked when an access token is requested.
*/
KeyVaultTokenCredential(SQLServerKeyVaultAuthenticationCallback authenticationCallback) {
this.authenticationCallback = authenticationCallback;
Expand All @@ -78,9 +79,11 @@ class KeyVaultTokenCredential implements TokenCredential {
@Override
public Mono<AccessToken> getToken(TokenRequestContext request) {
if (null != authenticationCallback) {
// If the callback is not null, invoke the callback to get the token. This gets invoked each time
// this method is called and will not cache the token. It's the callback's responsibility to return a valid
// token each time it's invoked.
/*
* If the callback is not null, invoke the callback to get the token. This gets invoked each time this
* method is called and will not cache the token. It's the callback's responsibility to return a valid token
* each time it's invoked.
*/
String accessToken = authenticationCallback.getAccessToken(this.authorization, this.resource, this.scope);
return Mono.just(new AccessToken(accessToken, OffsetDateTime.MIN));
}
Expand All @@ -93,7 +96,8 @@ public Mono<AccessToken> getToken(TokenRequestContext request) {
/**
* Sets the authority that will be used for authentication.
*
* @param authorization The name of the authorization.
* @param authorization
* The name of the authorization.
* @return The updated {@link KeyVaultTokenCredential} instance.
*/
KeyVaultTokenCredential setAuthorization(String authorization) {
Expand Down Expand Up @@ -143,9 +147,11 @@ private ConfidentialClientApplication getConfidentialClientApplication() {
}

/**
* Attempts to get the access token from the client cache if it's not expired. If it's expired this returns an
* empty response.
* @param request The context for requesting the token including the scope.
* Attempts to get the access token from the client cache if it's not expired. If it's expired this returns an empty
* response.
*
* @param request
* The context for requesting the token including the scope.
* @return The cached access token if it's not expired.
*/
private Mono<AccessToken> authenticateWithConfidentialClientCache(TokenRequestContext request) {
Expand All @@ -163,7 +169,9 @@ private Mono<AccessToken> authenticateWithConfidentialClientCache(TokenRequestCo

/**
* If fetching the token resulted in an error, this method returns the error wrapped in a completable future.
* @param e The original exception.
*
* @param e
* The original exception.
* @return A {@link CompletableFuture} that completes with an error.
*/
private CompletableFuture<IAuthenticationResult> getFailedCompletableFuture(Exception e) {
Expand All @@ -174,7 +182,9 @@ private CompletableFuture<IAuthenticationResult> getFailedCompletableFuture(Exce

/**
* Attempts to get the access token from the {@link ConfidentialClientApplication} for the requested scope.
* @param request The context for requesting the token that includes the scope.
*
* @param request
* The context for requesting the token that includes the scope.
* @return The access token.
*/
private Mono<AccessToken> authenticateWithConfidentialClient(TokenRequestContext request) {
Expand All @@ -187,15 +197,19 @@ private Mono<AccessToken> authenticateWithConfidentialClient(TokenRequestContext

/**
* Sets the resource name.
* @param resource The resource name.
*
* @param resource
* The resource name.
*/
void setResource(String resource) {
this.resource = resource;
}

/**
* Sets the scope for the access token.
* @param scope The scope for the access token.
*
* @param scope
* The scope for the access token.
*/
void setScope(String scope) {
this.scope = scope;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import java.util.Map;
import java.util.Properties;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.logging.Level;

import com.azure.core.credential.TokenCredential;
Expand Down Expand Up @@ -199,8 +198,11 @@ public SQLServerColumnEncryptionAzureKeyVaultProvider(

/**
* Sets the credential that will be used for authenticating requests to Key Vault service.
* @param credential A credential of type {@link TokenCredential}.
* @throws SQLServerException If the credential is null.
*
* @param credential
* A credential of type {@link TokenCredential}.
* @throws SQLServerException
* If the credential is null.
*/
private void setCredential(TokenCredential credential) throws SQLServerException {
if (null == credential) {
Expand Down Expand Up @@ -670,11 +672,14 @@ private int getAKVKeySize(String masterKeyPath) throws SQLServerException {
}

/**
* Fetches the key from Azure Key Vault for given key path. If the key path includes a version, then that
* specific version of the key is retrieved, otherwise the latest key will be retrieved.
* @param masterKeyPath The key path associated with the key
* Fetches the key from Azure Key Vault for given key path. If the key path includes a version, then that specific
* version of the key is retrieved, otherwise the latest key will be retrieved.
*
* @param masterKeyPath
* The key path associated with the key
* @return The Key Vault key.
* @throws SQLServerException If there was an error retrieving the key from Key Vault.
* @throws SQLServerException
* If there was an error retrieving the key from Key Vault.
*/
private KeyVaultKey getKeyVaultKey(String masterKeyPath) throws SQLServerException {
String[] keyTokens = masterKeyPath.split(KEY_URL_DELIMITER);
Expand Down Expand Up @@ -707,10 +712,11 @@ private KeyVaultKey getKeyVaultKey(String masterKeyPath) throws SQLServerExcepti
}

/**
* Creates a new {@link KeyClient} if one does not exist for the given key path. If the client already exists,
* the client is returned from the cache. As the client is stateless, it's safe to cache the client for each key
* path.
* @param masterKeyPath The key path for which the {@link KeyClient} will be created, if it does not exist.
* Creates a new {@link KeyClient} if one does not exist for the given key path. If the client already exists, the
* client is returned from the cache. As the client is stateless, it's safe to cache the client for each key path.
*
* @param masterKeyPath
* The key path for which the {@link KeyClient} will be created, if it does not exist.
* @return The {@link KeyClient} associated with the key path.
*/
private KeyClient getKeyClient(String masterKeyPath) {
Expand All @@ -731,7 +737,9 @@ private KeyClient getKeyClient(String masterKeyPath) {

/**
* Returns the vault url extracted from the master key path.
* @param masterKeyPath The master key path.
*
* @param masterKeyPath
* The master key path.
* @return The vault url.
*/
private static String getVaultUrl(String masterKeyPath) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,11 @@ static SqlFedAuthToken getSqlFedAuthToken(SqlFedAuthInfo fedAuthInfo, String use
} catch (MalformedURLException | InterruptedException e) {
throw new SQLServerException(e.getMessage(), e);
} catch (ExecutionException e) {
MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_ADALExecution"));
if (logger.isLoggable(Level.SEVERE)) {
logger.fine(logger.toString() + " MSAL exception:" + e.getMessage());
}

MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_MSALExecution"));
Object[] msgArgs = {user, authenticationString};

/*
Expand Down Expand Up @@ -98,7 +102,11 @@ static SqlFedAuthToken getSqlFedAuthTokenIntegrated(SqlFedAuthInfo fedAuthInfo,
} catch (InterruptedException | IOException e) {
throw new SQLServerException(e.getMessage(), e);
} catch (ExecutionException e) {
MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_ADALExecution"));
if (logger.isLoggable(Level.SEVERE)) {
logger.fine(logger.toString() + " MSAL exception:" + e.getMessage());
}

MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_MSALExecution"));
Object[] msgArgs = {"", authenticationString};

if (null == e.getCause() || null == e.getCause().getMessage()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,7 @@ protected Object[][] getContents() {
"FEDAUTHINFO token stream is not long enough ({0}) to contain the data it claims to."},
{"R_FedAuthInfoDoesNotContainStsurlAndSpn",
"FEDAUTHINFO token stream does not contain both STSURL and SPN."},
{"R_ADALExecution", "Failed to authenticate the user {0} in Active Directory (Authentication={1})."},
{"R_MSALExecution", "Failed to authenticate the user {0} in Active Directory (Authentication={1})."},
{"R_UnrequestedFeatureAckReceived", "Unrequested feature acknowledge is received. Feature ID: {0}."},
{"R_FedAuthFeatureAckContainsExtraData",
"Federated authentication feature extension ack for ADAL and Security Token includes extra data."},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ public void testConcurrentLogin() throws Exception {

t1.start();
t2.start();
if (isWindows && enableADIntegrated) {
if (enableADIntegrated) {
Thread t3 = new Thread(r3);
t3.setUncaughtExceptionHandler(handler);
t3.start();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ public void testAccessTokenExpiredThenCreateNewStatementADPassword() throws SQLE

@Test
public void testAccessTokenExpiredThenCreateNewStatementADIntegrated() throws SQLException {
org.junit.Assume.assumeTrue(isWindows && enableADIntegrated);
org.junit.Assume.assumeTrue(enableADIntegrated);

testAccessTokenExpiredThenCreateNewStatement(SqlAuthentication.ActiveDirectoryIntegrated);
}
Expand Down Expand Up @@ -107,7 +107,7 @@ public void testAccessTokenExpiredThenExecuteUsingSameStatementADPassword() thro

@Test
public void testAccessTokenExpiredThenExecuteUsingSameStatementADIntegrated() throws SQLException {
org.junit.Assume.assumeTrue(isWindows && enableADIntegrated);
org.junit.Assume.assumeTrue(enableADIntegrated);

testAccessTokenExpiredThenExecuteUsingSameStatement(SqlAuthentication.ActiveDirectoryIntegrated);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
@Tag(Constants.fedAuth)
public class ErrorMessageTest extends FedauthCommon {

String userName = "abc" + azureUserName;
String badUserName = "abc" + azureUserName;
String connectionUrl = "jdbc:sqlserver://" + azureServer + ";database=" + azureDatabase;

@Test
Expand Down Expand Up @@ -214,13 +214,14 @@ public void testSQLPasswordWithUntrustedSqlDB() throws SQLException {

@Test
public void testADPasswordUnregisteredUserWithConnectionStringUserName() throws SQLException {
try (Connection connection = DriverManager.getConnection(connectionUrl + ";userName=" + userName + ";password="
+ azurePassword + ";Authentication=" + SqlAuthentication.ActiveDirectoryPassword.toString())) {
try (Connection connection = DriverManager
.getConnection(connectionUrl + ";userName=" + badUserName + ";password=" + azurePassword
+ ";Authentication=" + SqlAuthentication.ActiveDirectoryPassword.toString())) {
fail(EXPECTED_EXCEPTION_NOT_THROWN);
} catch (SQLServerException e) {
assertTrue(INVALID_EXCEPION_MSG + ": " + e.getMessage(),
e.getMessage()
.contains(ERR_MSG_FAILED_AUTHENTICATE + " the user " + userName
.contains(ERR_MSG_FAILED_AUTHENTICATE + " the user " + badUserName
+ " in Active Directory (Authentication=ActiveDirectoryPassword).")
&& e.getCause().getCause().getMessage().contains(ERR_MSG_SIGNIN_ADD));
}
Expand All @@ -232,7 +233,7 @@ public void testADPasswordUnregisteredUserWithDatasource() throws SQLException {
SQLServerDataSource ds = new SQLServerDataSource();
ds.setServerName(azureServer);
ds.setDatabaseName(azureDatabase);
ds.setUser(userName);
ds.setUser(badUserName);
ds.setPassword(azurePassword);
ds.setAuthentication(SqlAuthentication.ActiveDirectoryPassword.toString());

Expand All @@ -241,21 +242,21 @@ public void testADPasswordUnregisteredUserWithDatasource() throws SQLException {
} catch (SQLServerException e) {
assertTrue(INVALID_EXCEPION_MSG + ": " + e.getMessage(),
e.getMessage()
.contains(ERR_MSG_FAILED_AUTHENTICATE + " the user " + userName
.contains(ERR_MSG_FAILED_AUTHENTICATE + " the user " + badUserName
+ " in Active Directory (Authentication=ActiveDirectoryPassword).")
&& e.getCause().getCause().getMessage().contains(ERR_MSG_SIGNIN_ADD));
}
}

@Test
public void testADPasswordUnregisteredUserWithConnectionStringUser() throws SQLException {
try (Connection connection = DriverManager.getConnection(connectionUrl + ";user=" + userName + ";password="
try (Connection connection = DriverManager.getConnection(connectionUrl + ";user=" + badUserName + ";password="
+ azurePassword + ";Authentication=" + SqlAuthentication.ActiveDirectoryPassword.toString())) {
fail(EXPECTED_EXCEPTION_NOT_THROWN);
} catch (SQLServerException e) {
assertTrue(INVALID_EXCEPION_MSG + ": " + e.getMessage(),
e.getMessage()
.contains(ERR_MSG_FAILED_AUTHENTICATE + " the user " + userName
.contains(ERR_MSG_FAILED_AUTHENTICATE + " the user " + badUserName
+ " in Active Directory (Authentication=ActiveDirectoryPassword).")
&& e.getCause().getCause().getMessage().contains(ERR_MSG_SIGNIN_ADD));
}
Expand All @@ -268,20 +269,20 @@ public void testAuthenticationAgainstSQLServerWithActivedirectorypassword() thro
info.put("Authentication", SqlAuthentication.ActiveDirectoryPassword.toString());

try (Connection connection = DriverManager
.getConnection(connectionUrl + ";user=" + userName + ";password=" + azurePassword, info)) {
.getConnection(connectionUrl + ";user=" + badUserName + ";password=" + azurePassword, info)) {
fail(EXPECTED_EXCEPTION_NOT_THROWN);
} catch (Exception e) {
if (!(e instanceof SQLServerException)) {
fail(EXPECTED_EXCEPTION_NOT_THROWN);
}
assertTrue(INVALID_EXCEPION_MSG + ": " + e.getMessage(), e.getMessage().contains(ERR_MSG_FAILED_AUTHENTICATE
+ " the user " + userName + " in Active Directory (Authentication=ActiveDirectoryPassword)."));
+ " the user " + badUserName + " in Active Directory (Authentication=ActiveDirectoryPassword)."));
}
}

@Test
public void testAuthenticationAgainstSQLServerWithActivedirectoryIntegrated() throws SQLException {
org.junit.Assume.assumeTrue(isWindows && enableADIntegrated);
org.junit.Assume.assumeTrue(enableADIntegrated);

java.util.Properties info = new Properties();
info.put("TrustServerCertificate", "true");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ public static void getConfigs() throws Exception {
azureGroupUserName = getConfiguredProperty("azureGroupUserName");

String prop = getConfiguredProperty("enableADIntegrated");
enableADIntegrated = (isWindows && null != prop && prop.equalsIgnoreCase("true")) ? true : false;
enableADIntegrated = (null != prop && prop.equalsIgnoreCase("true")) ? true : false;

adPasswordConnectionStr = "jdbc:sqlserver://" + azureServer + ";database=" + azureDatabase + ";user="
+ azureUserName + ";password=" + azurePassword + ";Authentication="
Expand Down Expand Up @@ -169,7 +169,12 @@ void testUserName(Connection conn, String user, SqlAuthentication authentication
if (SqlAuthentication.ActiveDirectoryIntegrated != authentication) {
assertTrue(user.equals(rs.getString(1)));
} else {
assertTrue(rs.getString(1).contains(System.getProperty("user.name")));
if (isWindows) {
assertTrue(rs.getString(1).contains(System.getProperty("user.name")));
} else {
// cannot verify user in kerberos tickets so just check it's not empty
assertTrue(!rs.getString(1).isEmpty());
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ public void testActiveDirectoryPasswordDS() throws Exception {

@Test
public void testActiveDirectoryIntegratedDS() throws Exception {
org.junit.Assume.assumeTrue(isWindows && enableADIntegrated);
org.junit.Assume.assumeTrue(enableADIntegrated);

SQLServerDataSource ds = new SQLServerDataSource();
ds.setServerName(azureServer);
Expand Down Expand Up @@ -173,7 +173,7 @@ public void testNotValidSqlPassword() throws SQLException {

@Test
public void testNotValidActiveDirectoryIntegrated() throws SQLException {
org.junit.Assume.assumeTrue(isWindows && enableADIntegrated);
org.junit.Assume.assumeTrue(enableADIntegrated);

testNotValid(SqlAuthentication.ActiveDirectoryIntegrated.toString(), false, true);
testNotValid(SqlAuthentication.ActiveDirectoryIntegrated.toString(), true, true);
Expand All @@ -200,7 +200,7 @@ public void testValidSqlPassword() throws SQLException {

@Test
public void testValidActiveDirectoryIntegrated() throws SQLException {
org.junit.Assume.assumeTrue(isWindows && enableADIntegrated);
org.junit.Assume.assumeTrue(enableADIntegrated);

testValid(SqlAuthentication.ActiveDirectoryIntegrated.toString(), false, true);
testValid(SqlAuthentication.ActiveDirectoryIntegrated.toString(), true, true);
Expand Down
Loading

0 comments on commit 9e38cc3

Please sign in to comment.