Skip to content

Commit

Permalink
Added support to Microsoft.Extensions.Azure for configuring `managedI…
Browse files Browse the repository at this point in the history
…dentityObjectId` (#46909)
  • Loading branch information
christothes authored Oct 30, 2024
1 parent 2c437c9 commit 98b0c07
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 22 deletions.
4 changes: 2 additions & 2 deletions eng/Packages.Data.props
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@
<PackageReference Update="Azure.MixedReality.Authentication" version= "1.2.0" />
<PackageReference Update="Azure.Monitor.OpenTelemetry.Exporter" Version="1.4.0-beta.2" />
<PackageReference Update="Azure.Monitor.Query" Version="1.1.0" />
<PackageReference Update="Azure.Identity" Version="1.12.0" />
<PackageReference Update="Azure.Identity" Version="1.13.1" />
<PackageReference Update="Azure.Security.KeyVault.Secrets" Version="4.6.0" />
<PackageReference Update="Azure.Security.KeyVault.Keys" Version="4.6.0" />
<PackageReference Update="Azure.Security.KeyVault.Certificates" Version="4.6.0" />
Expand Down Expand Up @@ -261,7 +261,7 @@
<ItemGroup Condition="('$(IsTestProject)' == 'true') OR ('$(IsTestSupportProject)' == 'true') OR ('$(IsPerfProject)' == 'true') OR ('$(IsStressProject)' == 'true') OR ('$(IsSamplesProject)' == 'true')">
<PackageReference Update="ApprovalTests" Version="3.0.22" />
<PackageReference Update="ApprovalUtilities" Version="3.0.22" />
<PackageReference Update="Azure.Identity" Version="1.12.0" />
<PackageReference Update="Azure.Identity" Version="1.13.1" />
<PackageReference Update="Azure.Messaging.EventGrid" Version="4.17.0" />
<PackageReference Update="Azure.Messaging.EventHubs.Processor" Version="5.11.3" />
<PackageReference Update="Azure.Messaging.ServiceBus" Version="7.18.2" />
Expand Down
4 changes: 4 additions & 0 deletions sdk/extensions/Microsoft.Extensions.Azure/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,16 @@

### Features Added

- Added support for constructing a `ManagedIdentityCredential` from config by setting the `managedIdentityObjectId` key.

### Breaking Changes

### Bugs Fixed

### Other Changes

- Updated dependency `Azure.Identity` to version `1.13.1`.

## 1.7.6 (2024-10-04)

### Other Changes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ internal static TokenCredential CreateCredential(IConfiguration configuration)
var clientId = configuration["clientId"];
var tenantId = configuration["tenantId"];
var resourceId = configuration["managedIdentityResourceId"];
var objectId = configuration["managedIdentityObjectId"];
var clientSecret = configuration["clientSecret"];
var certificate = configuration["clientCertificate"];
var certificateStoreName = configuration["clientCertificateStoreName"];
Expand All @@ -114,16 +115,26 @@ internal static TokenCredential CreateCredential(IConfiguration configuration)

if (string.Equals(credentialType, "managedidentity", StringComparison.OrdinalIgnoreCase))
{
if (!string.IsNullOrWhiteSpace(clientId) && !string.IsNullOrWhiteSpace(resourceId))
int idCount = 0;
idCount += string.IsNullOrWhiteSpace(clientId) ? 0 : 1;
idCount += string.IsNullOrWhiteSpace(resourceId) ? 0 : 1;
idCount += string.IsNullOrWhiteSpace(objectId) ? 0 : 1;

if (idCount > 1)
{
throw new ArgumentException("Cannot specify both 'clientId' and 'managedIdentityResourceId'");
throw new ArgumentException("Only one of either 'clientId', 'managedIdentityResourceId', or 'managedIdentityObjectId' can be specified for managed identity.");
}

if (!string.IsNullOrWhiteSpace(resourceId))
{
return new ManagedIdentityCredential(new ResourceIdentifier(resourceId));
}

if (!string.IsNullOrWhiteSpace(objectId))
{
return new ManagedIdentityCredential(ManagedIdentityId.FromUserAssignedObjectId(objectId));
}

return new ManagedIdentityCredential(clientId);
}

Expand Down Expand Up @@ -215,6 +226,11 @@ internal static TokenCredential CreateCredential(IConfiguration configuration)

// TODO: More logging

if (!string.IsNullOrWhiteSpace(objectId))
{
throw new ArgumentException("Managed identity 'objectId' is only supported when the credential type is 'managedidentity'.");
}

if (additionallyAllowedTenantsList != null
|| !string.IsNullOrWhiteSpace(tenantId)
|| !string.IsNullOrWhiteSpace(clientId)
Expand Down
114 changes: 96 additions & 18 deletions sdk/extensions/Microsoft.Extensions.Azure/tests/ClientFactoryTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,7 @@ public void CreatesDefaultAzureCredential(
[Values(true, false)] bool additionalTenants,
[Values(true, false)] bool clientId,
[Values(true, false)] bool tenantId,
[Values(true, false)] bool objectId,
[Values(true, false)] bool resourceId)
{
List<KeyValuePair<string, string>> configEntries = new();
Expand All @@ -299,10 +300,16 @@ public void CreatesDefaultAzureCredential(
{
configEntries.Add(new KeyValuePair<string, string>("managedIdentityResourceId", resourceIdValue));
}
if (objectId)
{
configEntries.Add(new KeyValuePair<string, string>("managedIdentityObjectId", "objectId"));
}

IConfiguration configuration = new ConfigurationBuilder().AddInMemoryCollection(configEntries).Build();

// if both clientId and resourceId set, we expect an ArgumentException
if (clientId && resourceId)
// We also expect an exception if objectId is set for DefaultAzureCredential, as it is only supported for ManagedIdentityCredential
if ((clientId && resourceId) || objectId)
{
Assert.Throws<ArgumentException>(() => ClientFactory.CreateCredential(configuration));
return;
Expand Down Expand Up @@ -336,20 +343,27 @@ public void CreatesDefaultAzureCredential(
Assert.AreEqual("tenantId", pwshCredential.TenantId);
}

// TODO: Since these can't build with project reference, we have to comment them out for now.
// When we resolve https://github.com/Azure/azure-sdk-for-net/issues/45806, we can add them back.
//if (clientId)
//{
// Assert.AreEqual("clientId", miCredential.Client.ClientId);
//}
//if (resourceId)
//{
// Assert.AreEqual(resourceIdValue, miCredential.Client.ResourceIdentifier.ToString());
//}
string managedIdentityId;
int idType;
ReflectIdAndType(miCredential, out managedIdentityId, out idType);
if (clientId)
{
Assert.AreEqual("clientId", managedIdentityId);
Assert.AreEqual(1, idType); // 1 is the value for ClientId
}
if (resourceId)
{
Assert.AreEqual(resourceIdValue.ToString(), managedIdentityId);
Assert.AreEqual(2, idType); // 2 is the value for ResourceId
}
if (objectId)
{
Assert.AreEqual("objectId", managedIdentityId);
Assert.AreEqual(3, idType); // 3 is the value for ObjectId
}
}

[Test]
[Ignore("This test is failing, ignore it to pass CI. Tracking this in https://github.com/Azure/azure-sdk-for-net/issues/45806")]
public void CreatesManagedServiceIdentityCredentialsWithClientId()
{
IConfiguration configuration = GetConfiguration(
Expand All @@ -362,14 +376,15 @@ public void CreatesManagedServiceIdentityCredentialsWithClientId()
Assert.IsInstanceOf<ManagedIdentityCredential>(credential);
var managedIdentityCredential = (ManagedIdentityCredential)credential;

var client = (ManagedIdentityClient)typeof(ManagedIdentityCredential).GetProperty("Client", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(managedIdentityCredential);
var clientId = typeof(ManagedIdentityClient).GetProperty("ClientId", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(client);
string clientId;
int idType;
ReflectIdAndType(managedIdentityCredential, out clientId, out idType);

Assert.AreEqual("ConfigurationClientId", clientId);
Assert.AreEqual(1, idType); // 1 is the value for ClientId
}

[Test]
[Ignore("This test is failing, ignore it to pass CI. Tracking this in https://github.com/Azure/azure-sdk-for-net/issues/45806")]
public void CreatesManagedServiceIdentityCredentials()
{
IConfiguration configuration = GetConfiguration(
Expand All @@ -381,10 +396,12 @@ public void CreatesManagedServiceIdentityCredentials()
Assert.IsInstanceOf<ManagedIdentityCredential>(credential);
var managedIdentityCredential = (ManagedIdentityCredential)credential;

var client = (ManagedIdentityClient)typeof(ManagedIdentityCredential).GetProperty("Client", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(managedIdentityCredential);
var clientId = typeof(ManagedIdentityClient).GetProperty("ClientId", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(client);
string clientId;
int idType;
ReflectIdAndType(managedIdentityCredential, out clientId, out idType);

Assert.Null(clientId);
Assert.AreEqual(0, idType); // 0 is the value for SystemAssigned
}

[Test]
Expand All @@ -400,9 +417,33 @@ public void CreatesManagedServiceIdentityCredentialsWithResourceId()
Assert.IsInstanceOf<ManagedIdentityCredential>(credential);
var managedIdentityCredential = (ManagedIdentityCredential)credential;

var resourceId = (string)typeof(ManagedIdentityCredential).GetField("_clientId", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(managedIdentityCredential);
string resourceId;
int idType;
ReflectIdAndType(managedIdentityCredential, out resourceId, out idType);

Assert.AreEqual("ConfigurationResourceId", resourceId);
Assert.AreEqual(2, idType); // 2 is the value for ResourceId
}

[Test]
public void CreatesManagedServiceIdentityCredentialsWithObjectId()
{
IConfiguration configuration = GetConfiguration(
new KeyValuePair<string, string>("managedIdentityObjectId", "ConfigurationObjectId"),
new KeyValuePair<string, string>("credential", "managedidentity")
);

var credential = ClientFactory.CreateCredential(configuration);

Assert.IsInstanceOf<ManagedIdentityCredential>(credential);
var managedIdentityCredential = (ManagedIdentityCredential)credential;

string objectId;
int idType;
ReflectIdAndType(managedIdentityCredential, out objectId, out idType);

Assert.AreEqual("ConfigurationObjectId", objectId);
Assert.AreEqual(3, idType); // 3 is the value for ObjectId
}

[Test]
Expand All @@ -419,6 +460,34 @@ public void CreatesManagedServiceIdentityCredentialsThrowsWhenResourceIdAndClien
Throws.InstanceOf<ArgumentException>().With.Message.Contains("managedIdentityResourceId"));
}

[Test]
public void CreatesManagedServiceIdentityCredentialsThrowsWhenClientIdAndObjectIdSpecified()
{
IConfiguration configuration = GetConfiguration(
new KeyValuePair<string, string>("managedIdentityObjectId", "ConfigurationObjectId"),
new KeyValuePair<string, string>("clientId", "ConfigurationClientId"),
new KeyValuePair<string, string>("credential", "managedidentity")
);

Assert.That(
() => ClientFactory.CreateCredential(configuration),
Throws.InstanceOf<ArgumentException>().With.Message.Contains("managedIdentityResourceId"));
}

[Test]
public void CreatesManagedServiceIdentityCredentialsThrowsWhenResourceIdAndObjectIdSpecified()
{
IConfiguration configuration = GetConfiguration(
new KeyValuePair<string, string>("managedIdentityObjectId", "ConfigurationObjectId"),
new KeyValuePair<string, string>("managedIdentityResourceId", "ConfigurationResourceId"),
new KeyValuePair<string, string>("credential", "managedidentity")
);

Assert.That(
() => ClientFactory.CreateCredential(configuration),
Throws.InstanceOf<ArgumentException>().With.Message.Contains("managedIdentityResourceId"));
}

[Test]
public void CreatesWorkloadIdentityCredentialsWithOptions()
{
Expand Down Expand Up @@ -561,5 +630,14 @@ private IConfiguration GetConfiguration(params KeyValuePair<string, string>[] it
{
return new ConfigurationBuilder().AddInMemoryCollection(items).Build();
}

private static void ReflectIdAndType(ManagedIdentityCredential managedIdentityCredential, out string clientId, out int idType)
{
var managedIdentityClient = typeof(ManagedIdentityCredential).GetProperty("Client", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(managedIdentityCredential);
var managedIdentityClientOptions = managedIdentityClient.GetType().GetField("_options", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(managedIdentityClient);
var managedIdentityId = managedIdentityClientOptions.GetType().GetProperty("ManagedIdentityId").GetValue(managedIdentityClientOptions);
clientId = (string)typeof(ManagedIdentityId).GetField("_userAssignedId", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(managedIdentityId);
idType = (int)typeof(ManagedIdentityId).GetField("_idType", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(managedIdentityId);
}
}
}

0 comments on commit 98b0c07

Please sign in to comment.