forked from microsoft/mssql-jdbc
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request microsoft#8 from srnagar/keyvault-upgrade
Remove tenant Id and fix adal issues
- Loading branch information
Showing
7 changed files
with
407 additions
and
29 deletions.
There are no files selected for viewing
124 changes: 124 additions & 0 deletions
124
src/main/java/com/microsoft/sqlserver/jdbc/KeyVaultCredential.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
package com.microsoft.sqlserver.jdbc; | ||
|
||
import com.azure.core.annotation.Immutable; | ||
import com.azure.core.credential.AccessToken; | ||
import com.azure.core.credential.TokenRequestContext; | ||
import com.azure.core.util.logging.ClientLogger; | ||
import com.microsoft.aad.msal4j.ClientCredentialFactory; | ||
import com.microsoft.aad.msal4j.ClientCredentialParameters; | ||
import com.microsoft.aad.msal4j.ConfidentialClientApplication; | ||
import com.microsoft.aad.msal4j.IAuthenticationResult; | ||
import com.microsoft.aad.msal4j.IClientCredential; | ||
import com.microsoft.aad.msal4j.SilentParameters; | ||
import java.net.MalformedURLException; | ||
import java.time.OffsetDateTime; | ||
import java.time.ZoneOffset; | ||
import java.util.HashSet; | ||
import java.util.Objects; | ||
import java.util.concurrent.CompletableFuture; | ||
import reactor.core.publisher.Mono; | ||
|
||
|
||
/** | ||
* An AAD credential that acquires a token with a client secret for an AAD application. | ||
* | ||
* <p><strong>Sample: Construct a simple KeyVaultCredential</strong></p> | ||
* {@codesnippet com.azure.identity.credential.clientsecretcredential.construct} | ||
* | ||
* <p><strong>Sample: Construct a KeyVaultCredential behind a proxy</strong></p> | ||
* {@codesnippet com.azure.identity.credential.clientsecretcredential.constructwithproxy} | ||
*/ | ||
@Immutable | ||
class KeyVaultCredential { | ||
private final ClientLogger logger = new ClientLogger(KeyVaultCredential.class); | ||
private final String clientId; | ||
private final String clientSecret; | ||
private String authorization; | ||
private ConfidentialClientApplication confidentialClientApplication; | ||
|
||
/** | ||
* Creates a KeyVaultCredential with the given identity client options. | ||
* | ||
* @param clientId the client ID of the application | ||
* @param clientSecret the secret value of the AAD application. | ||
*/ | ||
KeyVaultCredential(String clientId, String clientSecret) { | ||
Objects.requireNonNull(clientSecret, "'clientSecret' cannot be null."); | ||
Objects.requireNonNull(clientSecret, "'clientId' cannot be null."); | ||
this.clientId = clientId; | ||
this.clientSecret = clientSecret; | ||
} | ||
|
||
public Mono<AccessToken> getToken(TokenRequestContext request) { | ||
return authenticateWithConfidentialClientCache(request) | ||
.onErrorResume(t -> Mono.empty()) | ||
.switchIfEmpty(Mono.defer(() -> authenticateWithConfidentialClient(request))); | ||
} | ||
|
||
public KeyVaultCredential setAuthorization(String authorization) { | ||
if (this.authorization != null && this.authorization.equals(authorization)) { | ||
return this; | ||
} | ||
this.authorization = authorization; | ||
confidentialClientApplication = getConfidentialClientApplication(); | ||
return this; | ||
} | ||
|
||
private ConfidentialClientApplication getConfidentialClientApplication() { | ||
if (clientId == null) { | ||
throw logger.logExceptionAsError(new IllegalArgumentException( | ||
"A non-null value for client ID must be provided for user authentication.")); | ||
} | ||
|
||
if (authorization == null) { | ||
throw logger.logExceptionAsError(new IllegalArgumentException( | ||
"A non-null value for authorization must be provided for user authentication.")); | ||
} | ||
|
||
IClientCredential credential; | ||
if (clientSecret != null) { | ||
credential = ClientCredentialFactory.create(clientSecret); | ||
} else { | ||
throw logger.logExceptionAsError( | ||
new IllegalArgumentException("Must provide client secret.")); | ||
} | ||
ConfidentialClientApplication.Builder applicationBuilder = | ||
ConfidentialClientApplication.builder(clientId, credential); | ||
try { | ||
applicationBuilder = applicationBuilder.authority(authorization); | ||
} catch (MalformedURLException e) { | ||
throw logger.logExceptionAsWarning(new IllegalStateException(e)); | ||
} | ||
return applicationBuilder.build(); | ||
} | ||
|
||
private Mono<AccessToken> authenticateWithConfidentialClientCache(TokenRequestContext request) { | ||
return Mono.fromFuture(() -> { | ||
SilentParameters.SilentParametersBuilder parametersBuilder = SilentParameters | ||
.builder(new HashSet<>(request.getScopes())); | ||
try { | ||
return confidentialClientApplication.acquireTokenSilently(parametersBuilder.build()); | ||
} catch (MalformedURLException e) { | ||
return getFailedCompletableFuture(logger.logExceptionAsError(new RuntimeException(e))); | ||
} | ||
}).map(ar -> new AccessToken(ar.accessToken(), | ||
OffsetDateTime.ofInstant(ar.expiresOnDate().toInstant(), ZoneOffset.UTC))) | ||
.filter(t -> !t.isExpired()); | ||
} | ||
|
||
private CompletableFuture<IAuthenticationResult> getFailedCompletableFuture(Exception e) { | ||
CompletableFuture<IAuthenticationResult> completableFuture = new CompletableFuture<>(); | ||
completableFuture.completeExceptionally(e); | ||
return completableFuture; | ||
} | ||
|
||
private Mono<AccessToken> authenticateWithConfidentialClient(TokenRequestContext request) { | ||
return Mono.fromFuture(() -> confidentialClientApplication | ||
.acquireToken(ClientCredentialParameters.builder(new HashSet<>(request.getScopes())).build())) | ||
.map(ar -> new AccessToken(ar.accessToken(), | ||
OffsetDateTime.ofInstant(ar.expiresOnDate().toInstant(), ZoneOffset.UTC))); | ||
} | ||
} |
105 changes: 105 additions & 0 deletions
105
src/main/java/com/microsoft/sqlserver/jdbc/KeyVaultCustomCredentialPolicy.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
package com.microsoft.sqlserver.jdbc; | ||
|
||
import com.azure.core.credential.TokenCredential; | ||
import com.azure.core.credential.TokenRequestContext; | ||
import com.azure.core.http.HttpPipelineCallContext; | ||
import com.azure.core.http.HttpPipelineNextPolicy; | ||
import com.azure.core.http.HttpResponse; | ||
import com.azure.core.http.policy.HttpPipelinePolicy; | ||
import com.azure.core.util.CoreUtils; | ||
import java.util.HashMap; | ||
import java.util.Locale; | ||
import java.util.Map; | ||
import java.util.Objects; | ||
import reactor.core.publisher.Mono; | ||
|
||
|
||
/** | ||
* A policy that authenticates requests with Azure Key Vault service. The content added by this policy | ||
* is leveraged in {@link TokenCredential} to get and set the correct "Authorization" header value. | ||
* | ||
* @see TokenCredential | ||
*/ | ||
class KeyVaultCustomCredentialPolicy implements HttpPipelinePolicy { | ||
private static final String WWW_AUTHENTICATE = "WWW-Authenticate"; | ||
private static final String BEARER_TOKEN_PREFIX = "Bearer "; | ||
private static final String AUTHORIZATION = "Authorization"; | ||
private final ScopeTokenCache cache; | ||
private final KeyVaultCredential keyVaultCredential; | ||
|
||
/** | ||
* Creates KeyVaultCustomCredentialPolicy. | ||
* | ||
* @param credential the token credential to authenticate the request | ||
*/ | ||
public KeyVaultCustomCredentialPolicy(KeyVaultCredential credential) { | ||
Objects.requireNonNull(credential, "'credential' cannot be null."); | ||
this.cache = new ScopeTokenCache(credential::getToken); | ||
this.keyVaultCredential = credential; | ||
} | ||
|
||
/** | ||
* Adds the required header to authenticate a request to Azure Key Vault service. | ||
* | ||
* @param context The request context | ||
* @param next The next HTTP pipeline policy to process the {@code context's} request after this policy completes. | ||
* @return A {@link Mono} representing the HTTP response that will arrive asynchronously. | ||
*/ | ||
@Override | ||
public Mono<HttpResponse> process(HttpPipelineCallContext context, HttpPipelineNextPolicy next) { | ||
if ("http".equals(context.getHttpRequest().getUrl().getProtocol())) { | ||
return Mono.error(new RuntimeException("Token credentials require a URL using the HTTPS protocol scheme")); | ||
} | ||
return next.clone().process() | ||
// Ignore body | ||
.doOnNext(HttpResponse::close) | ||
.map(res -> res.getHeaderValue(WWW_AUTHENTICATE)) | ||
.map(header -> extractChallenge(header, BEARER_TOKEN_PREFIX)) | ||
.flatMap(map -> { | ||
keyVaultCredential.setAuthorization(map.get("authorization")); | ||
cache.setRequest(new TokenRequestContext().addScopes(map.get("resource") + "/.default")); | ||
return cache.getToken(); | ||
}) | ||
.flatMap(token -> { | ||
context.getHttpRequest().setHeader(AUTHORIZATION, BEARER_TOKEN_PREFIX + token.getToken()); | ||
return next.process(); | ||
}); | ||
} | ||
|
||
/** | ||
* Extracts the challenge off the authentication header. | ||
* | ||
* @param authenticateHeader The authentication header containing all the challenges. | ||
* @param authChallengePrefix The authentication challenge name. | ||
* @return a challenge map. | ||
*/ | ||
private static Map<String, String> extractChallenge(String authenticateHeader, String authChallengePrefix) { | ||
if (!isValidChallenge(authenticateHeader, authChallengePrefix)) { | ||
return null; | ||
} | ||
authenticateHeader = authenticateHeader.toLowerCase(Locale.ROOT).replace(authChallengePrefix.toLowerCase(Locale.ROOT), ""); | ||
|
||
String[] challenges = authenticateHeader.split(", "); | ||
Map<String, String> challengeMap = new HashMap<>(); | ||
for (String pair : challenges) { | ||
String[] keyValue = pair.split("="); | ||
challengeMap.put(keyValue[0].replaceAll("\"", ""), keyValue[1].replaceAll("\"", "")); | ||
} | ||
return challengeMap; | ||
} | ||
|
||
/** | ||
* Verifies whether a challenge is bearer or not. | ||
* | ||
* @param authenticateHeader The authentication header containing all the challenges. | ||
* @param authChallengePrefix The authentication challenge name. | ||
* @return A boolean indicating tha challenge is valid or not. | ||
*/ | ||
private static boolean isValidChallenge(String authenticateHeader, String authChallengePrefix) { | ||
return (!CoreUtils.isNullOrEmpty(authenticateHeader) | ||
&& authenticateHeader.toLowerCase(Locale.ROOT).startsWith(authChallengePrefix.toLowerCase(Locale.ROOT))); | ||
} | ||
} |
87 changes: 87 additions & 0 deletions
87
src/main/java/com/microsoft/sqlserver/jdbc/KeyVaultHttpPipelineBuilder.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
package com.microsoft.sqlserver.jdbc; | ||
|
||
import com.azure.core.http.HttpPipeline; | ||
import com.azure.core.http.HttpPipelineBuilder; | ||
import com.azure.core.http.policy.HttpLogOptions; | ||
import com.azure.core.http.policy.HttpLoggingPolicy; | ||
import com.azure.core.http.policy.HttpPipelinePolicy; | ||
import com.azure.core.http.policy.HttpPolicyProviders; | ||
import com.azure.core.http.policy.RetryPolicy; | ||
import com.azure.core.http.policy.UserAgentPolicy; | ||
import com.azure.core.util.Configuration; | ||
import com.azure.core.util.CoreUtils; | ||
import com.azure.core.util.logging.ClientLogger; | ||
import java.util.ArrayList; | ||
import java.util.List; | ||
import java.util.Map; | ||
import java.util.Objects; | ||
|
||
|
||
final class KeyVaultHttpPipelineBuilder { | ||
private final ClientLogger logger = new ClientLogger(KeyVaultHttpPipelineBuilder.class); | ||
// This is properties file's name. | ||
private static final String AZURE_KEY_VAULT_SECRETS = "azure-key-vault-secrets.properties"; | ||
private static final String SDK_NAME = "name"; | ||
private static final String SDK_VERSION = "version"; | ||
|
||
private final List<HttpPipelinePolicy> policies; | ||
final Map<String, String> properties; | ||
private KeyVaultCredential credential; | ||
private HttpLogOptions httpLogOptions; | ||
private final RetryPolicy retryPolicy; | ||
|
||
/** | ||
* The constructor with defaults. | ||
*/ | ||
public KeyVaultHttpPipelineBuilder() { | ||
retryPolicy = new RetryPolicy(); | ||
httpLogOptions = new HttpLogOptions(); | ||
policies = new ArrayList<>(); | ||
properties = CoreUtils.getProperties(AZURE_KEY_VAULT_SECRETS); | ||
} | ||
|
||
public HttpPipeline buildPipeline() { | ||
Configuration buildConfiguration = Configuration.getGlobalConfiguration().clone(); | ||
|
||
if (credential == null) { | ||
throw logger.logExceptionAsError( | ||
new IllegalStateException( | ||
"Token Credential should be specified.")); | ||
} | ||
|
||
// Closest to API goes first, closest to wire goes last. | ||
final List<HttpPipelinePolicy> policies = new ArrayList<>(); | ||
|
||
String clientName = properties.getOrDefault(SDK_NAME, "UnknownName"); | ||
String clientVersion = properties.getOrDefault(SDK_VERSION, "UnknownVersion"); | ||
policies.add(new UserAgentPolicy(httpLogOptions.getApplicationId(), clientName, clientVersion, | ||
buildConfiguration)); | ||
HttpPolicyProviders.addBeforeRetryPolicies(policies); | ||
policies.add(retryPolicy); | ||
policies.add(new KeyVaultCustomCredentialPolicy(credential)); | ||
policies.addAll(this.policies); | ||
HttpPolicyProviders.addAfterRetryPolicies(policies); | ||
policies.add(new HttpLoggingPolicy(httpLogOptions)); | ||
|
||
HttpPipeline pipeline = new HttpPipelineBuilder() | ||
.policies(policies.toArray(new HttpPipelinePolicy[0])) | ||
.build(); | ||
|
||
return pipeline; | ||
} | ||
|
||
|
||
/** | ||
* Sets the credential to use when authenticating HTTP requests. | ||
* | ||
* @param credential The credential to use for authenticating HTTP requests. | ||
* @return the updated KVHttpPipelineBuilder object. | ||
* @throws NullPointerException if {@code credential} is {@code null}. | ||
*/ | ||
public KeyVaultHttpPipelineBuilder credential(KeyVaultCredential credential) { | ||
Objects.requireNonNull(credential); | ||
this.credential = credential; | ||
return this; | ||
} | ||
} | ||
|
Oops, something went wrong.