Skip to content

Commit

Permalink
Add azure arc managed identity (#730)
Browse files Browse the repository at this point in the history
* Add azure arc managed identity

* Removed lenient from mock and merge conflicts

* Clear cache in unit tests

* Fix after manual testing

* Update the log message

* Service Fabric MSI (#729)

* Support for service fabric, most tests working

* TODOs and sonarlint recommendations

* Address PR comments

---------

Co-authored-by: Avery-Dunn <avdunn@microsoft.com>
Co-authored-by: Avery-Dunn <62066438+Avery-Dunn@users.noreply.github.com>
  • Loading branch information
3 people authored Nov 9, 2023
1 parent 1f7fa9d commit cc7f1c6
Show file tree
Hide file tree
Showing 13 changed files with 473 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ abstract class AbstractManagedIdentitySource {

protected final ManagedIdentityRequest managedIdentityRequest;
protected final ServiceBundle serviceBundle;
private ManagedIdentitySourceType managedIdentitySourceType;
ManagedIdentitySourceType managedIdentitySourceType;

@Getter
@Setter
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ AuthenticationResult execute() throws Exception {
scopes.add(this.managedIdentityParameters.resource);
SilentParameters parameters = SilentParameters
.builder(scopes)
.tenant(managedIdentityParameters.tenant())
.build();

RequestContext context = new RequestContext(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,16 @@ class AppServiceManagedIdentitySource extends AbstractManagedIdentitySource{
private static final String APP_SERVICE_MSI_API_VERSION = "2019-08-01";
private static final String SECRET_HEADER_NAME = "X-IDENTITY-HEADER";

private final URI MSI_ENDPOINT;
private final String SECRET;
private final URI msiEndpoint;
private final String identityHeader;

@Override
public void createManagedIdentityRequest(String resource) {
managedIdentityRequest.baseEndpoint = MSI_ENDPOINT;
managedIdentityRequest.baseEndpoint = msiEndpoint;
managedIdentityRequest.method = HttpMethod.GET;

managedIdentityRequest.headers = new HashMap<>();
managedIdentityRequest.headers.put(SECRET_HEADER_NAME, SECRET);
managedIdentityRequest.headers.put(SECRET_HEADER_NAME, identityHeader);

managedIdentityRequest.queryParameters = new HashMap<>();
managedIdentityRequest.queryParameters.put("api-version", Collections.singletonList(APP_SERVICE_MSI_API_VERSION));
Expand All @@ -50,8 +50,8 @@ public void createManagedIdentityRequest(String resource) {
private AppServiceManagedIdentitySource(MsalRequest msalRequest, ServiceBundle serviceBundle, URI msiEndpoint, String secret)
{
super(msalRequest, serviceBundle, ManagedIdentitySourceType.APP_SERVICE);
this.MSI_ENDPOINT = msiEndpoint;
this.SECRET = secret;
this.msiEndpoint = msiEndpoint;
this.identityHeader = secret;
}

static AbstractManagedIdentitySource create(MsalRequest msalRequest, ServiceBundle serviceBundle) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package com.microsoft.aad.msal4j;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.FileReader;
import java.io.IOException;
import java.net.HttpURLConnection;
import java.net.URI;
import java.net.URISyntaxException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Collections;
import java.util.HashMap;

class AzureArcManagedIdentitySource extends AbstractManagedIdentitySource{

private final static Logger LOG = LoggerFactory.getLogger(AzureArcManagedIdentitySource.class);
private static final String ARC_API_VERSION = "2019-11-01";
private static final String AZURE_ARC = "Azure Arc";

private final URI MSI_ENDPOINT;

static AbstractManagedIdentitySource create(MsalRequest msalRequest, ServiceBundle serviceBundle)
{
IEnvironmentVariables environmentVariables = getEnvironmentVariables((ManagedIdentityParameters) msalRequest.requestContext().apiParameters());
String identityEndpoint = environmentVariables.getEnvironmentVariable(Constants.IDENTITY_ENDPOINT);
String imdsEndpoint = environmentVariables.getEnvironmentVariable(Constants.IMDS_ENDPOINT);

URI validatedUri = validateAndGetUri(identityEndpoint, imdsEndpoint);
return validatedUri == null ? null : new AzureArcManagedIdentitySource(validatedUri, msalRequest, serviceBundle );
}

private static URI validateAndGetUri(String identityEndpoint, String imdsEndpoint) {

// if BOTH the env vars IDENTITY_ENDPOINT and IMDS_ENDPOINT are set the MsiType is Azure Arc
if (StringHelper.isNullOrBlank(identityEndpoint) || StringHelper.isNullOrBlank(imdsEndpoint))
{
LOG.info("[Managed Identity] Azure Arc managed identity is unavailable.");
return null;
}

URI endpointUri;
try {
endpointUri = new URI(identityEndpoint);
} catch (URISyntaxException e) {
throw new MsalManagedIdentityException(MsalError.INVALID_MANAGED_IDENTITY_ENDPOINT, String.format(
MsalErrorMessage.MANAGED_IDENTITY_ENDPOINT_INVALID_URI_ERROR, "IDENTITY_ENDPOINT", identityEndpoint, AZURE_ARC),
ManagedIdentitySourceType.AZURE_ARC);
}

LOG.info("[Managed Identity] Creating Azure Arc managed identity. Endpoint URI: " + endpointUri);
return endpointUri;
}

private AzureArcManagedIdentitySource(URI endpoint, MsalRequest msalRequest, ServiceBundle serviceBundle){
super(msalRequest, serviceBundle, ManagedIdentitySourceType.AZURE_ARC);
this.MSI_ENDPOINT = endpoint;

ManagedIdentityIdType idType =
((ManagedIdentityApplication) msalRequest.application()).getManagedIdentityId().getIdType();
if (idType != ManagedIdentityIdType.SYSTEM_ASSIGNED) {
throw new MsalManagedIdentityException(MsalError.USER_ASSIGNED_MANAGED_IDENTITY_NOT_SUPPORTED,
String.format(MsalErrorMessage.MANAGED_IDENTITY_USER_ASSIGNED_NOT_SUPPORTED, AZURE_ARC),
ManagedIdentitySourceType.AZURE_ARC);
}
}

@Override
public void createManagedIdentityRequest(String resource)
{
managedIdentityRequest.baseEndpoint = MSI_ENDPOINT;
managedIdentityRequest.method = HttpMethod.GET;

managedIdentityRequest.headers = new HashMap<>();
managedIdentityRequest.headers.put("Metadata", "true");

managedIdentityRequest.queryParameters = new HashMap<>();
managedIdentityRequest.queryParameters.put("api-version", Collections.singletonList(ARC_API_VERSION));
managedIdentityRequest.queryParameters.put("resource", Collections.singletonList(resource));
}

@Override
public ManagedIdentityResponse handleResponse(
ManagedIdentityParameters parameters,
IHttpResponse response) {

LOG.info("[Managed Identity] Response received. Status code: {response.StatusCode}");

if (response.statusCode() == HttpURLConnection.HTTP_UNAUTHORIZED) {
if(!response.headers().containsKey("Www-Authenticate")) {
LOG.error("[Managed Identity] WWW-Authenticate header is expected but not found.");
throw new MsalManagedIdentityException(MsalError.MANAGED_IDENTITY_REQUEST_FAILED,
MsalErrorMessage.MANAGED_IDENTITY_NO_CHALLENGE_ERROR,
ManagedIdentitySourceType.AZURE_ARC);
}

String challenge = response.headers().get("Www-Authenticate").get(0);
String[] splitChallenge = challenge.split("=");

if (splitChallenge.length != 2) {
LOG.error("[Managed Identity] The WWW-Authenticate header for Azure arc managed identity is not an expected format.");
throw new MsalManagedIdentityException(MsalError.MANAGED_IDENTITY_REQUEST_FAILED,
MsalErrorMessage.MANAGED_IDENTITY_INVALID_CHALLENGE,
ManagedIdentitySourceType.AZURE_ARC);
}

Path path = Paths.get(splitChallenge[1]);

String authHeaderValue = null;
try {
authHeaderValue = "Basic " + new String(Files.readAllBytes(path), StandardCharsets.UTF_8);
} catch (IOException e) {
throw new MsalManagedIdentityException(MsalError.MANAGED_IDENTITY_FILE_READ_ERROR, e.getMessage(), ManagedIdentitySourceType.AZURE_ARC);
}

createManagedIdentityRequest(parameters.resource);

LOG.info("[Managed Identity] Adding authorization header to the request.");

managedIdentityRequest.headers.put("Authorization", authHeaderValue);

try {
response = HttpHelper.executeHttpRequest(
new HttpRequest(HttpMethod.GET, managedIdentityRequest.computeURI().toString(),
managedIdentityRequest.headers),
managedIdentityRequest.requestContext(),
serviceBundle);
} catch (URISyntaxException e) {
throw new MsalManagedIdentityException(MsalError.INVALID_MANAGED_IDENTITY_ENDPOINT,
MsalErrorMessage.MANAGED_IDENTITY_ENDPOINT_INVALID_URI_ERROR,
managedIdentitySourceType);
}

return super.handleResponse(parameters, response);
}

return super.handleResponse(parameters, response);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@ class CloudShellManagedIdentitySource extends AbstractManagedIdentitySource{

private static final Logger LOG = LoggerFactory.getLogger(CloudShellManagedIdentitySource.class);

private final URI MSI_ENDPOINT;
private final URI msiEndpoint;

@Override
public void createManagedIdentityRequest(String resource) {
managedIdentityRequest.baseEndpoint = MSI_ENDPOINT;
managedIdentityRequest.baseEndpoint = msiEndpoint;
managedIdentityRequest.method = HttpMethod.POST;

managedIdentityRequest.headers = new HashMap<>();
Expand All @@ -33,7 +33,7 @@ public void createManagedIdentityRequest(String resource) {
private CloudShellManagedIdentitySource(MsalRequest msalRequest, ServiceBundle serviceBundle, URI msiEndpoint)
{
super(msalRequest, serviceBundle, ManagedIdentitySourceType.CLOUD_SHELL);
this.MSI_ENDPOINT = msiEndpoint;
this.msiEndpoint = msiEndpoint;

ManagedIdentityIdType idType =
((ManagedIdentityApplication) msalRequest.application()).getManagedIdentityId().getIdType();
Expand All @@ -57,28 +57,23 @@ static AbstractManagedIdentitySource create(MsalRequest msalRequest, ServiceBund
return null;
}

URI validatedUri = validateAndGetUri(msiEndpoint);
return validatedUri == null ? null
: new CloudShellManagedIdentitySource(msalRequest, serviceBundle, validatedUri);
return new CloudShellManagedIdentitySource(msalRequest, serviceBundle, validateAndGetUri(msiEndpoint));
}

private static URI validateAndGetUri(String msiEndpoint)
{
URI endpointUri = null;

try
{
endpointUri = new URI(msiEndpoint);
URI endpointUri = new URI(msiEndpoint);
LOG.info("[Managed Identity] Environment variables validation passed for cloud shell managed identity. Endpoint URI: " + endpointUri + ". Creating cloud shell managed identity.");
return endpointUri;
}
catch (URISyntaxException ex)
{
throw new MsalManagedIdentityException(MsalError.INVALID_MANAGED_IDENTITY_ENDPOINT, String.format(
MsalErrorMessage.MANAGED_IDENTITY_ENDPOINT_INVALID_URI_ERROR, "MSI_ENDPOINT", msiEndpoint, "Cloud Shell"),
ManagedIdentitySourceType.CLOUD_SHELL);
}

LOG.info("[Managed Identity] Environment variables validation passed for cloud shell managed identity. Endpoint URI: " + endpointUri + ". Creating cloud shell managed identity.");
return endpointUri;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ public ManagedIdentityResponse handleResponse(

message = message + " " + errorContentMessage;

LOG.error(String.format("Error message: %s Http status code: %s"), message, response.statusCode());
LOG.error(String.format("Error message: %s Http status code: %s", message, response.statusCode()));
throw new MsalManagedIdentityException(MsalError.MANAGED_IDENTITY_REQUEST_FAILED, message,
ManagedIdentitySourceType.IMDS);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,14 @@ ManagedIdentityResponse getManagedIdentityResponse(ManagedIdentityParameters par
private static AbstractManagedIdentitySource createManagedIdentitySource(MsalRequest msalRequest,
ServiceBundle serviceBundle) {
AbstractManagedIdentitySource managedIdentitySource;
if ((managedIdentitySource = AppServiceManagedIdentitySource.create(msalRequest, serviceBundle)) != null) {
if ((managedIdentitySource = ServiceFabricManagedIdentitySource.create(msalRequest, serviceBundle)) != null) {
return managedIdentitySource;
} else if ((managedIdentitySource = AppServiceManagedIdentitySource.create(msalRequest, serviceBundle)) != null) {
return managedIdentitySource;
} else if ((managedIdentitySource = CloudShellManagedIdentitySource.create(msalRequest, serviceBundle)) != null) {
return managedIdentitySource;
} else if ((managedIdentitySource = AzureArcManagedIdentitySource.create(msalRequest, serviceBundle)) != null) {
return managedIdentitySource;
} else {
return new IMDSManagedIdentitySource(msalRequest, serviceBundle);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
@AllArgsConstructor(access = AccessLevel.PRIVATE)
public class ManagedIdentityParameters implements IAcquireTokenParameters {

@Getter
String resource;

boolean forceRefresh;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,6 @@ public class MsalError {
* Managed Identity endpoint is not reachable.
*/
public static final String MANAGED_IDENTITY_UNREACHABLE_NETWORK = "managed_identity_unreachable_network";

public static final String MANAGED_IDENTITY_FILE_READ_ERROR = "managed_identity_file_read_error";
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package com.microsoft.aad.msal4j;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.net.URI;
import java.net.URISyntaxException;
import java.util.Collections;
import java.util.HashMap;

class ServiceFabricManagedIdentitySource extends AbstractManagedIdentitySource {

private static final Logger LOG = LoggerFactory.getLogger(ServiceFabricManagedIdentitySource.class);

private static final String SERVICE_FABRIC_MSI_API_VERSION = "2019-07-01-preview";

private final URI msiEndpoint;
private final String identityHeader;
private final ManagedIdentityIdType idType;
private final String userAssignedId;

@Override
public void createManagedIdentityRequest(String resource) {
managedIdentityRequest.baseEndpoint = msiEndpoint;
managedIdentityRequest.method = HttpMethod.GET;

managedIdentityRequest.headers = new HashMap<>();
managedIdentityRequest.headers.put("secret", identityHeader);

managedIdentityRequest.queryParameters = new HashMap<>();
managedIdentityRequest.queryParameters.put("resource", Collections.singletonList(resource));
managedIdentityRequest.queryParameters.put("api-version", Collections.singletonList(SERVICE_FABRIC_MSI_API_VERSION));

if (idType == ManagedIdentityIdType.CLIENT_ID) {
LOG.info("[Managed Identity] Adding user assigned client id to the request for Service Fabric Managed Identity.");
managedIdentityRequest.queryParameters.put(Constants.MANAGED_IDENTITY_CLIENT_ID, Collections.singletonList(userAssignedId));
} else if (idType == ManagedIdentityIdType.RESOURCE_ID) {
LOG.info("[Managed Identity] Adding user assigned resource id to the request for Service Fabric Managed Identity.");
managedIdentityRequest.queryParameters.put(Constants.MANAGED_IDENTITY_RESOURCE_ID, Collections.singletonList(userAssignedId));
}
}

private ServiceFabricManagedIdentitySource(MsalRequest msalRequest, ServiceBundle serviceBundle, URI msiEndpoint, String identityHeader)
{
super(msalRequest, serviceBundle, ManagedIdentitySourceType.SERVICE_FABRIC);
this.msiEndpoint = msiEndpoint;
this.identityHeader = identityHeader;

this.idType = ((ManagedIdentityApplication) msalRequest.application()).getManagedIdentityId().getIdType();
this.userAssignedId = ((ManagedIdentityApplication) msalRequest.application()).getManagedIdentityId().getUserAssignedId();
}

static AbstractManagedIdentitySource create(MsalRequest msalRequest, ServiceBundle serviceBundle) {

IEnvironmentVariables environmentVariables = getEnvironmentVariables((ManagedIdentityParameters) msalRequest.requestContext().apiParameters());
String msiEndpoint = environmentVariables.getEnvironmentVariable(Constants.MSI_ENDPOINT);
String identityHeader = environmentVariables.getEnvironmentVariable(Constants.IDENTITY_ENDPOINT);
String identityServerThumbprint = environmentVariables.getEnvironmentVariable(Constants.IDENTITY_SERVER_THUMBPRINT);


if (StringHelper.isNullOrBlank(msiEndpoint) || StringHelper.isNullOrBlank(identityHeader) || StringHelper.isNullOrBlank(identityServerThumbprint))
{
LOG.info("[Managed Identity] Service fabric managed identity is unavailable.");
return null;
}

return new ServiceFabricManagedIdentitySource(msalRequest, serviceBundle, validateAndGetUri(msiEndpoint), identityHeader);
}

private static URI validateAndGetUri(String msiEndpoint)
{
try
{
URI endpointUri = new URI(msiEndpoint);
LOG.info("[Managed Identity] Environment variables validation passed for Service Fabric Managed Identity. Endpoint URI: " + endpointUri);
return endpointUri;
}
catch (URISyntaxException ex)
{
throw new MsalManagedIdentityException(MsalError.INVALID_MANAGED_IDENTITY_ENDPOINT, String.format(
MsalErrorMessage.MANAGED_IDENTITY_ENDPOINT_INVALID_URI_ERROR, "MSI_ENDPOINT", msiEndpoint, "Service Fabric"),
ManagedIdentitySourceType.SERVICE_FABRIC);
}
}

}
Loading

0 comments on commit cc7f1c6

Please sign in to comment.