Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added code to handle claims in authentication challenges #41814

Merged
merged 19 commits into from
Oct 15, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@

import com.azure.core.credential.TokenCredential;
import com.azure.core.credential.TokenRequestContext;
import com.azure.core.http.HttpHeaderName;
import com.azure.core.http.HttpPipelineCallContext;
import com.azure.core.http.HttpPipelineNextPolicy;
import com.azure.core.http.HttpPipelineNextSyncPolicy;
import com.azure.core.http.HttpRequest;
import com.azure.core.http.HttpResponse;
import com.azure.core.http.policy.BearerTokenAuthenticationPolicy;
import com.azure.core.util.Base64Util;
import com.azure.core.util.BinaryData;
import com.azure.core.util.CoreUtils;
import com.azure.core.util.logging.ClientLogger;
Expand All @@ -27,6 +29,9 @@
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;

import static com.azure.core.http.HttpHeaderName.CONTENT_LENGTH;
import static com.azure.core.http.HttpHeaderName.WWW_AUTHENTICATE;

/**
* A policy that authenticates requests with the Azure Key Vault service. The content added by this policy is
* leveraged in {@link TokenCredential} to get and set the correct "Authorization" header value.
Expand Down Expand Up @@ -67,16 +72,20 @@ private static Map<String, String> extractChallengeAttributes(String authenticat
return Collections.emptyMap();
}

authenticateHeader =
authenticateHeader.toLowerCase(Locale.ROOT).replace(authChallengePrefix.toLowerCase(Locale.ROOT), "");

String[] attributes = authenticateHeader.split(", ");
String[] attributes = authenticateHeader.substring(authChallengePrefix.length()).split(", ");
Map<String, String> attributeMap = new HashMap<>();

for (String pair : attributes) {
String[] keyValue = pair.split("=");
// This is ugly, but we need to trim here because currently the 'claims' attribute comes after two spaces.
pair = pair.trim();

attributeMap.put(keyValue[0].replaceAll("\"", ""), keyValue[1].replaceAll("\"", ""));
if (pair.startsWith("claims=")) {
attributeMap.put("claims", pair.substring("claims=".length()).replaceAll("\"", ""));
} else {
String[] keyValue = pair.split("=");

attributeMap.put(keyValue[0].replaceAll("\"", ""), keyValue[1].replaceAll("\"", ""));
}
}

return attributeMap;
Expand All @@ -102,8 +111,7 @@ public Mono<Void> authorizeRequest(HttpPipelineCallContext context) {

// If this policy doesn't have challenge parameters cached try to get it from the static challenge cache.
if (this.challenge == null) {
String authority = getRequestAuthority(request);
this.challenge = CHALLENGE_CACHE.get(authority);
this.challenge = CHALLENGE_CACHE.get(getRequestAuthority(request));
}

if (this.challenge != null) {
Expand All @@ -115,18 +123,18 @@ public Mono<Void> authorizeRequest(HttpPipelineCallContext context) {
return setAuthorizationHeader(context, tokenRequestContext);
}

// The body is removed from the initial request because Key Vault supports other authentication schemes which
// also protect the body of the request. As a result, before we know the auth scheme we need to avoid sending
// an unprotected body to Key Vault. We don't currently support this enhanced auth scheme in the SDK but we
// still don't want to send any unprotected data to vaults which require it.
// The body is removed from the initial request because Key Vault supports other authentication schemes
// which also protect the body of the request. As a result, before we know the auth scheme we need to
// avoid sending an unprotected body to Key Vault. We don't currently support this enhanced auth scheme
// in the SDK, but we still don't want to send any unprotected data to vaults which require it.

// Do not overwrite previous contents if retrying after initial request failed (e.g. timeout).
if (!context.getData(KEY_VAULT_STASHED_CONTENT_KEY).isPresent()) {
if (request.getBody() != null) {
context.setData(KEY_VAULT_STASHED_CONTENT_KEY, request.getBody());
context.setData(KEY_VAULT_STASHED_CONTENT_LENGTH_KEY,
request.getHeaders().getValue(HttpHeaderName.CONTENT_LENGTH));
request.setHeader(HttpHeaderName.CONTENT_LENGTH, "0");
request.getHeaders().getValue(CONTENT_LENGTH));
request.setHeader(CONTENT_LENGTH, "0");
request.setBody((Flux<ByteBuffer>) null);
}
}
Expand All @@ -145,13 +153,12 @@ public Mono<Boolean> authorizeRequestOnChallenge(HttpPipelineCallContext context

if (request.getBody() == null && contentOptional.isPresent() && contentLengthOptional.isPresent()) {
request.setBody((Flux<ByteBuffer>) contentOptional.get());
request.setHeader(HttpHeaderName.CONTENT_LENGTH, (String) contentLengthOptional.get());
request.setHeader(CONTENT_LENGTH, (String) contentLengthOptional.get());
}

String authority = getRequestAuthority(request);
Map<String, String> challengeAttributes =
extractChallengeAttributes(response.getHeaderValue(HttpHeaderName.WWW_AUTHENTICATE),
BEARER_TOKEN_PREFIX);
extractChallengeAttributes(response.getHeaderValue(WWW_AUTHENTICATE), BEARER_TOKEN_PREFIX);
String scope = challengeAttributes.get("resource");

if (scope != null) {
Expand Down Expand Up @@ -203,6 +210,22 @@ public Mono<Boolean> authorizeRequestOnChallenge(HttpPipelineCallContext context
.addScopes(this.challenge.getScopes())
.setTenantId(this.challenge.getTenantId());

String error = challengeAttributes.get("error");

if (error != null) {
LOGGER.verbose(String.format("The challenge response contained an error: %s", error));

if ("insufficient_claims".equalsIgnoreCase(error)) {
String claims = challengeAttributes.get("claims");

if (claims != null) {
tokenRequestContext
.setCaeEnabled(true)
.setClaims(new String(Base64Util.decodeString(claims)));
}
}
}

return setAuthorizationHeader(context, tokenRequestContext)
.then(Mono.just(true));
});
Expand All @@ -214,8 +237,7 @@ public void authorizeRequestSync(HttpPipelineCallContext context) {

// If this policy doesn't have challenge parameters cached try to get it from the static challenge cache.
if (this.challenge == null) {
String authority = getRequestAuthority(request);
this.challenge = CHALLENGE_CACHE.get(authority);
this.challenge = CHALLENGE_CACHE.get(getRequestAuthority(request));
}

if (this.challenge != null) {
Expand All @@ -225,27 +247,27 @@ public void authorizeRequestSync(HttpPipelineCallContext context) {
.setTenantId(this.challenge.getTenantId());

setAuthorizationHeaderSync(context, tokenRequestContext);

return;
}

// The body is removed from the initial request because Key Vault supports other authentication schemes which
// also protect the body of the request. As a result, before we know the auth scheme we need to avoid sending
// an unprotected body to Key Vault. We don't currently support this enhanced auth scheme in the SDK but we
// still don't want to send any unprotected data to vaults which require it.
// also protect the body of the request. As a result, before we know the auth scheme we need to avoid sending an
// unprotected body to Key Vault. We don't currently support this enhanced auth scheme in the SDK, but we still
// don't want to send any unprotected data to vaults which require it.

// Do not overwrite previous contents if retrying after initial request failed (e.g. timeout).
if (!context.getData(KEY_VAULT_STASHED_CONTENT_KEY).isPresent()) {
if (request.getBodyAsBinaryData() != null) {
context.setData(KEY_VAULT_STASHED_CONTENT_KEY, request.getBodyAsBinaryData());
context.setData(KEY_VAULT_STASHED_CONTENT_LENGTH_KEY,
request.getHeaders().getValue(HttpHeaderName.CONTENT_LENGTH));
request.setHeader(HttpHeaderName.CONTENT_LENGTH, "0");
request.getHeaders().getValue(CONTENT_LENGTH));
request.setHeader(CONTENT_LENGTH, "0");
request.setBody((BinaryData) null);
}
}
}

@SuppressWarnings("unchecked")
@Override
public boolean authorizeRequestOnChallengeSync(HttpPipelineCallContext context, HttpResponse response) {
HttpRequest request = context.getHttpRequest();
Expand All @@ -254,12 +276,12 @@ public boolean authorizeRequestOnChallengeSync(HttpPipelineCallContext context,

if (request.getBody() == null && contentOptional.isPresent() && contentLengthOptional.isPresent()) {
request.setBody((BinaryData) (contentOptional.get()));
request.setHeader(HttpHeaderName.CONTENT_LENGTH, (String) contentLengthOptional.get());
request.setHeader(CONTENT_LENGTH, (String) contentLengthOptional.get());
}

String authority = getRequestAuthority(request);
Map<String, String> challengeAttributes =
extractChallengeAttributes(response.getHeaderValue(HttpHeaderName.WWW_AUTHENTICATE), BEARER_TOKEN_PREFIX);
extractChallengeAttributes(response.getHeaderValue(WWW_AUTHENTICATE), BEARER_TOKEN_PREFIX);
String scope = challengeAttributes.get("resource");

if (scope != null) {
Expand Down Expand Up @@ -311,10 +333,130 @@ public boolean authorizeRequestOnChallengeSync(HttpPipelineCallContext context,
.addScopes(this.challenge.getScopes())
.setTenantId(this.challenge.getTenantId());

String error = challengeAttributes.get("error");

if (error != null) {
LOGGER.verbose(String.format("The challenge response contained an error: %s", error));

if ("insufficient_claims".equalsIgnoreCase(error)) {
String claims = challengeAttributes.get("claims");

if (claims != null) {
tokenRequestContext
.setCaeEnabled(true)
.setClaims(new String(Base64Util.decodeString(claims)));
}
}
}

setAuthorizationHeaderSync(context, tokenRequestContext);

return true;
}

@Override
public Mono<HttpResponse> process(HttpPipelineCallContext context, HttpPipelineNextPolicy next) {
if (!"https".equals(context.getHttpRequest().getUrl().getProtocol())) {
return Mono.error(new RuntimeException("Token credentials require a URL using the HTTPS protocol scheme."));
}

HttpPipelineNextPolicy nextPolicy = next.clone();

return authorizeRequest(context).then(Mono.defer(next::process)).flatMap(httpResponse -> {
String authHeader = httpResponse.getHeaderValue(WWW_AUTHENTICATE);

if (httpResponse.getStatusCode() == 401 && authHeader != null) {
return handleChallenge(context, httpResponse, nextPolicy);
}

return Mono.just(httpResponse);
});
}

@Override
public HttpResponse processSync(HttpPipelineCallContext context, HttpPipelineNextSyncPolicy next) {
if (!"https".equals(context.getHttpRequest().getUrl().getProtocol())) {
throw LOGGER.logExceptionAsError(
new RuntimeException("Token credentials require a URL using the HTTPS protocol scheme."));
}

HttpPipelineNextSyncPolicy nextPolicy = next.clone();

authorizeRequestSync(context);

HttpResponse httpResponse = next.processSync();
String authHeader = httpResponse.getHeaderValue(WWW_AUTHENTICATE);

if (httpResponse.getStatusCode() == 401 && authHeader != null) {
return handleChallengeSync(context, httpResponse, nextPolicy);
}

return httpResponse;
}

private Mono<HttpResponse> handleChallenge(HttpPipelineCallContext context, HttpResponse httpResponse,
HttpPipelineNextPolicy next) {
return authorizeRequestOnChallenge(context, httpResponse).flatMap(authorized -> {
if (authorized) {
// The body needs to be closed or read to the end to release the connection.
httpResponse.close();

HttpPipelineNextPolicy nextPolicy = next.clone();

return next.process().flatMap(newResponse -> {
String authHeader = newResponse.getHeaderValue(WWW_AUTHENTICATE);

if (newResponse.getStatusCode() == 401 && authHeader != null
&& !(isClaimsPresent(httpResponse) && isClaimsPresent(newResponse))) {

return handleChallenge(context, newResponse, nextPolicy);
} else {
return Mono.just(newResponse);
}
});
}

return Mono.just(httpResponse);
});
}

private HttpResponse handleChallengeSync(HttpPipelineCallContext context, HttpResponse httpResponse,
HttpPipelineNextSyncPolicy next) {
if (authorizeRequestOnChallengeSync(context, httpResponse)) {
// The body needs to be closed or read to the end to release the connection.
httpResponse.close();

HttpPipelineNextSyncPolicy nextPolicy = next.clone();
HttpResponse newResponse = next.processSync();
String authHeader = newResponse.getHeaderValue(WWW_AUTHENTICATE);

if (newResponse.getStatusCode() == 401 && authHeader != null
&& !(isClaimsPresent(httpResponse) && isClaimsPresent(newResponse))) {

return handleChallengeSync(context, newResponse, nextPolicy);
}

return newResponse;
}

return httpResponse;
}

private boolean isClaimsPresent(HttpResponse httpResponse) {
Map<String, String> challengeAttributes =
extractChallengeAttributes(httpResponse.getHeaderValue(WWW_AUTHENTICATE), BEARER_TOKEN_PREFIX);

String error = challengeAttributes.get("error");

if (error != null) {
String base64Claims = challengeAttributes.get("claims");

return "insufficient_claims".equalsIgnoreCase(error) && base64Claims != null;
}

return false;
}

private static class ChallengeParameters {
private final URI authorizationUri;
private final String tenantId;
Expand Down
Loading