Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix issue where x5c header is not sent for OnBehalfOfCredential when SendCertificateChain option is set #27721

Merged
merged 4 commits into from
Mar 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion sdk/identity/Azure.Identity/src/MsalConfidentialClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,9 @@ public virtual async ValueTask<AuthenticationResult> AcquireTokenOnBehalfOf(
{
IConfidentialClientApplication client = await GetClientAsync(async, cancellationToken).ConfigureAwait(false);

var builder = client.AcquireTokenOnBehalfOf(scopes, userAssertionValue);
var builder = client
.AcquireTokenOnBehalfOf(scopes, userAssertionValue)
.WithSendX5C(_includeX5CClaimHeader);

if (!string.IsNullOrEmpty(tenantId))
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using System.Security.Cryptography.X509Certificates;
using System.Threading.Tasks;
using Azure.Core;
using Azure.Core.Pipeline;
using Azure.Core.TestFramework;
using Azure.Identity.Tests.Mock;
using NUnit.Framework;
Expand Down Expand Up @@ -159,5 +160,32 @@ public async Task UsesTenantIdHint(

Assert.AreEqual(token.Token, expectedToken, "Should be the expected token value");
}

[Test]
public async Task SendCertificateChain([Values(true, false)] bool usePemFile, [Values(true)] bool sendCertChain)
{
TestSetup();
var _transport = Createx5cValidatingTransport(sendCertChain);
var _pipeline = new HttpPipeline(_transport, new[] {new BearerTokenAuthenticationPolicy(new MockCredential(), "scope")});
var context = new TokenRequestContext(new[] { Scope }, tenantId: TenantId);
expectedTenantId = TenantIdResolver.Resolve(TenantId, context);
var certificatePath = Path.Combine(TestContext.CurrentContext.TestDirectory, "Data", "cert.pfx");
var certificatePathPem = Path.Combine(TestContext.CurrentContext.TestDirectory, "Data", "cert.pem");
var mockCert = new X509Certificate2(certificatePath);
options = new ClientCertificateCredentialOptions();
((ClientCertificateCredentialOptions)options).SendCertificateChain = sendCertChain;

ClientCertificateCredential credential = InstrumentClient(
usePemFile
? new ClientCertificateCredential(TenantId, ClientId, certificatePathPem, options,
new CredentialPipeline(new Uri("https://localhost"), _pipeline, new ClientDiagnostics(options)), null)
: new ClientCertificateCredential(TenantId, ClientId, mockCert, options,
new CredentialPipeline(new Uri("https://localhost"), _pipeline, new ClientDiagnostics(options)), null)
);

var token = await credential.GetTokenAsync(context);

Assert.AreEqual(token.Token, expectedToken, "Should be the expected token value");
}
}
}
94 changes: 83 additions & 11 deletions sdk/identity/Azure.Identity/tests/CredentialTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
using System.Globalization;
using System.IO;
using System.Text;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using Azure.Core;
using Azure.Core.TestFramework;
using Azure.Identity.Tests.Mock;
using Microsoft.Identity.Client;
Expand Down Expand Up @@ -36,8 +38,12 @@ public class CredentialTestBase : ClientTestBase
protected string expectedCode;
protected DeviceCodeResult deviceCodeResult;

protected const string DiscoveryResponseBody =
"{\"tenant_discovery_endpoint\": \"https://login.microsoftonline.com/c54fac88-3dd3-461f-a7c4-8a368e0340b3/v2.0/.well-known/openid-configuration\",\"api-version\": \"1.1\",\"metadata\":[{\"preferred_network\": \"login.microsoftonline.com\",\"preferred_cache\": \"login.windows.net\",\"aliases\":[\"login.microsoftonline.com\",\"login.windows.net\",\"login.microsoft.com\",\"sts.windows.net\"]},{\"preferred_network\": \"login.partner.microsoftonline.cn\",\"preferred_cache\": \"login.partner.microsoftonline.cn\",\"aliases\":[\"login.partner.microsoftonline.cn\",\"login.chinacloudapi.cn\"]},{\"preferred_network\": \"login.microsoftonline.de\",\"preferred_cache\": \"login.microsoftonline.de\",\"aliases\":[\"login.microsoftonline.de\"]},{\"preferred_network\": \"login.microsoftonline.us\",\"preferred_cache\": \"login.microsoftonline.us\",\"aliases\":[\"login.microsoftonline.us\",\"login.usgovcloudapi.net\"]},{\"preferred_network\": \"login-us.microsoftonline.com\",\"preferred_cache\": \"login-us.microsoftonline.com\",\"aliases\":[\"login-us.microsoftonline.com\"]}]}";

public CredentialTestBase(bool isAsync) : base(isAsync)
{ }
{
}

public void TestSetup()
{
Expand All @@ -57,7 +63,7 @@ public void TestSetup()
TenantId,
new MockAccount("username"),
null,
new[] { Scope },
new[] {Scope},
Guid.NewGuid(),
null,
"Bearer");
Expand Down Expand Up @@ -103,7 +109,7 @@ public void TestSetup()
TenantId,
new MockAccount("username"),
null,
new[] { Scope },
new[] {Scope},
Guid.NewGuid(),
null,
"Bearer");
Expand Down Expand Up @@ -150,6 +156,7 @@ protected async Task<string> ReadMockRequestContent(MockRequest request)
{
return null;
}

using var memoryStream = new MemoryStream();
request.Content.WriteTo(memoryStream, CancellationToken.None);
memoryStream.Position = 0;
Expand All @@ -159,7 +166,8 @@ protected async Task<string> ReadMockRequestContent(MockRequest request)
}
}

protected MockResponse CreateMockMsalTokenResponse(int responseCode, string token, string tenantId, string userName)
protected MockResponse CreateMockMsalTokenResponse(int responseCode, string token, string tenantId,
string userName)
{
var response = new MockResponse(responseCode);
var idToken = CreateMsalIdToken(Guid.NewGuid().ToString(), userName, tenantId);
Expand Down Expand Up @@ -190,7 +198,7 @@ public static string CreateMsalIdToken(string uniqueId, string displayableId, st
return string.Format(CultureInfo.InvariantCulture, "someheader.{0}.somesignature", MsalEncode(id));
}

private const char base64PadCharacter = '=';
private const char base64PadCharacter = '=';
#if NET45
private const string doubleBase64PadCharacter = "==";
#endif
Expand All @@ -204,11 +212,9 @@ public static string CreateMsalIdToken(string uniqueId, string displayableId, st
/// </summary>
internal static readonly char[] s_base64Table =
{
'A','B','C','D','E','F','G','H','I','J','K','L','M','N','O','P','Q','R','S','T','U','V','W','X','Y','Z',
'a','b','c','d','e','f','g','h','i','j','k','l','m','n','o','p','q','r','s','t','u','v','w','x','y','z',
'0','1','2','3','4','5','6','7','8','9',
base64UrlCharacter62,
base64UrlCharacter63
'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y',
'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x',
'y', 'z', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', base64UrlCharacter62, base64UrlCharacter63
};

/// <summary>
Expand Down Expand Up @@ -302,7 +308,7 @@ private static string MsalEncode(byte[] inArray, int offset, int length)
}
break;

//default or case 0: no further operations are needed.
//default or case 0: no further operations are needed.
}

return new string(output, 0, j);
Expand All @@ -323,5 +329,71 @@ public static string MsalEncode(byte[] inArray)

return MsalEncode(inArray, 0, inArray.Length);
}

protected bool RequestBodyHasUserAssertionWithHeader(Request req, string headerName)
{
req.Content.TryComputeLength(out var len);
byte[] content = new byte[len];
var stream = new MemoryStream((int)len);
req.Content.WriteTo(stream, default);
var body = Encoding.UTF8.GetString(stream.GetBuffer(), 0, (int)stream.Length);
var parts = body.Split('&');
foreach (var part in parts)
{
if (part.StartsWith("client_assertion="))
{
var assertion = part.AsSpan();
int start = assertion.IndexOf('=') + 1;
assertion = assertion.Slice(start);
int end = assertion.IndexOf('.');
var jwt = assertion.Slice(0, end);
string convertedToken = jwt.ToString().Replace('_', '/').Replace('-', '+');
switch (jwt.Length % 4)
{
case 2:
convertedToken += "==";
break;
case 3:
convertedToken += "=";
break;
}

Utf8JsonReader reader = new Utf8JsonReader(Convert.FromBase64String(convertedToken));
while (reader.Read())
{
if (reader.TokenType == JsonTokenType.PropertyName)
{
var header = reader.GetString();
if (header == headerName)
{
return true;
}

reader.Read();
}
}
}
}

return false;
}

protected MockTransport Createx5cValidatingTransport(bool sendCertChain) => new MockTransport((req) =>
{
// respond to tenant discovery
if (req.Uri.Path.StartsWith("/common/discovery"))
{
return new MockResponse(200).SetContent(DiscoveryResponseBody);
}

// respond to token request
if (req.Uri.Path.EndsWith("/token"))
{
Assert.That(sendCertChain, Is.EqualTo(RequestBodyHasUserAssertionWithHeader(req, "x5c")));
return new MockResponse(200).WithContent(
$"{{\"token_type\": \"Bearer\",\"expires_in\": 9999,\"ext_expires_in\": 9999,\"access_token\": \"{expectedToken}\" }}");
}
return new MockResponse(200);
});
}
}
73 changes: 59 additions & 14 deletions sdk/identity/Azure.Identity/tests/OnBehalfOfCredentialTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,16 @@
// Licensed under the MIT License.

using System;
using System.IO;
using System.Security.Cryptography.X509Certificates;
using System.Text;
using System.Text.Json;
using System.Threading.Tasks;
using Azure.Core;
using Azure.Core.Pipeline;
using Azure.Core.TestFramework;
using Azure.Identity.Tests.Mock;
using Microsoft.Diagnostics.Tracing.Parsers.AspNet;
using NUnit.Framework;

namespace Azure.Identity.Tests
Expand All @@ -25,27 +31,41 @@ public void CtorValidation()
string userAssertion = Guid.NewGuid().ToString();
string clientSecret = Guid.NewGuid().ToString();

Assert.Throws<ArgumentNullException>(() => new OnBehalfOfCredential(null, ClientId, clientSecret, userAssertion, null));
Assert.Throws<ArgumentNullException>(() => new OnBehalfOfCredential(TenantId, null, clientSecret, userAssertion, null));
Assert.Throws<ArgumentNullException>(() => new OnBehalfOfCredential(TenantId, ClientId, default(string), userAssertion));
Assert.Throws<ArgumentNullException>(() => new OnBehalfOfCredential(TenantId, ClientId, clientSecret, null, null));
Assert.Throws<ArgumentNullException>(() =>
new OnBehalfOfCredential(null, ClientId, clientSecret, userAssertion, null));
Assert.Throws<ArgumentNullException>(() =>
new OnBehalfOfCredential(TenantId, null, clientSecret, userAssertion, null));
Assert.Throws<ArgumentNullException>(() =>
new OnBehalfOfCredential(TenantId, ClientId, default(string), userAssertion));
Assert.Throws<ArgumentNullException>(() =>
new OnBehalfOfCredential(TenantId, ClientId, clientSecret, null, null));
cred = new OnBehalfOfCredential(TenantId, ClientId, clientSecret, userAssertion, null);
// Assert
Assert.AreEqual(clientSecret, cred._client._clientSecret);

Assert.Throws<ArgumentNullException>(() => new OnBehalfOfCredential(null, ClientId, _mockCertificate, userAssertion));
Assert.Throws<ArgumentNullException>(() => new OnBehalfOfCredential(TenantId, null, _mockCertificate, userAssertion));
Assert.Throws<ArgumentNullException>(() => new OnBehalfOfCredential(TenantId, ClientId, default(string), userAssertion));
Assert.Throws<ArgumentNullException>(() => new OnBehalfOfCredential(TenantId, ClientId, _mockCertificate, null));
Assert.Throws<ArgumentNullException>(() =>
new OnBehalfOfCredential(null, ClientId, _mockCertificate, userAssertion));
Assert.Throws<ArgumentNullException>(() =>
new OnBehalfOfCredential(TenantId, null, _mockCertificate, userAssertion));
Assert.Throws<ArgumentNullException>(() =>
new OnBehalfOfCredential(TenantId, ClientId, default(string), userAssertion));
Assert.Throws<ArgumentNullException>(() =>
new OnBehalfOfCredential(TenantId, ClientId, _mockCertificate, null));
cred = new OnBehalfOfCredential(TenantId, ClientId, _mockCertificate, userAssertion);
// Assert
Assert.NotNull(cred._client._certificateProvider);

Assert.Throws<ArgumentNullException>(() => new OnBehalfOfCredential(null, ClientId, _mockCertificate, userAssertion, new OnBehalfOfCredentialOptions()));
Assert.Throws<ArgumentNullException>(() => new OnBehalfOfCredential(TenantId, null, _mockCertificate, userAssertion, new OnBehalfOfCredentialOptions()));
Assert.Throws<ArgumentNullException>(() => new OnBehalfOfCredential(TenantId, ClientId, default(X509Certificate2), userAssertion, new OnBehalfOfCredentialOptions()));
Assert.Throws<ArgumentNullException>(() => new OnBehalfOfCredential(TenantId, ClientId, _mockCertificate, null, new OnBehalfOfCredentialOptions()));
cred = new OnBehalfOfCredential(TenantId, ClientId, _mockCertificate, userAssertion, new OnBehalfOfCredentialOptions());
Assert.Throws<ArgumentNullException>(() => new OnBehalfOfCredential(null, ClientId, _mockCertificate,
userAssertion, new OnBehalfOfCredentialOptions()));
Assert.Throws<ArgumentNullException>(() => new OnBehalfOfCredential(TenantId, null, _mockCertificate,
userAssertion, new OnBehalfOfCredentialOptions()));
Assert.Throws<ArgumentNullException>(() => new OnBehalfOfCredential(TenantId, ClientId,
default(X509Certificate2), userAssertion, new OnBehalfOfCredentialOptions()));
Assert.Throws<ArgumentNullException>(() =>
new OnBehalfOfCredential(TenantId, ClientId, _mockCertificate, null,
new OnBehalfOfCredentialOptions()));
cred = new OnBehalfOfCredential(TenantId, ClientId, _mockCertificate, userAssertion,
new OnBehalfOfCredentialOptions());
// Assert
Assert.NotNull(cred._client._certificateProvider);
}
Expand All @@ -58,7 +78,7 @@ public async Task UsesTenantIdHint(
{
TestSetup();
options = new OnBehalfOfCredentialOptions();
var context = new TokenRequestContext(new[] { Scope }, tenantId: tenantId);
var context = new TokenRequestContext(new[] {Scope}, tenantId: tenantId);
expectedTenantId = TenantIdResolver.Resolve(explicitTenantId, context);
OnBehalfOfCredential client = InstrumentClient(
new OnBehalfOfCredential(
Expand All @@ -73,5 +93,30 @@ public async Task UsesTenantIdHint(
var token = await client.GetTokenAsync(new TokenRequestContext(MockScopes.Default), default);
Assert.AreEqual(token.Token, expectedToken, "Should be the expected token value");
}

[Test]
public async Task SendCertificateChain([Values(true, false)] bool sendCertChain)
{
TestSetup();
var _transport = Createx5cValidatingTransport(sendCertChain);
var _pipeline = new HttpPipeline(_transport, new[] {new BearerTokenAuthenticationPolicy(new MockCredential(), "scope")});
var certificatePath = Path.Combine(TestContext.CurrentContext.TestDirectory, "Data", "cert.pfx");
var mockCert = new X509Certificate2(certificatePath);

options = new OnBehalfOfCredentialOptions();
((OnBehalfOfCredentialOptions)options).SendCertificateChain = sendCertChain;
OnBehalfOfCredential client = InstrumentClient(
new OnBehalfOfCredential(
TenantId,
ClientId,
mockCert,
expectedUserAssertion,
options as OnBehalfOfCredentialOptions,
new CredentialPipeline(new Uri("https://localhost"), _pipeline, new ClientDiagnostics(options)),
null));

var token = await client.GetTokenAsync(new TokenRequestContext(MockScopes.Default), default);
Assert.AreEqual(token.Token, expectedToken, "Should be the expected token value");
}
}
}