Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add logging for failure paths as well as some various improvements to the code base. #71

Merged
merged 2 commits into from
Jan 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -90,21 +90,21 @@

if (!@this.ProtectedHeaders.TryGetValue(CoseHeaderLabel.ContentType, out CoseHeaderValue contentTypeValue))
{
Trace.TraceError($"{nameof(TryGetDetachedSignatureAlgorithm)} was called on a CoseSign1Message object({@this.GetHashCode()}) without the ContentType protected header present.");
Trace.TraceWarning($"{nameof(TryGetDetachedSignatureAlgorithm)} was called on a CoseSign1Message object({@this.GetHashCode()}) without the ContentType protected header present.");
return false;
}

string contentType = contentTypeValue.GetValueAsString();
if (string.IsNullOrEmpty(contentType))
{
Trace.TraceError($"{nameof(TryGetDetachedSignatureAlgorithm)} was called on a CoseSign1Message object({@this.GetHashCode()}) without the ContentType protected header being a string value.");
Trace.TraceWarning($"{nameof(TryGetDetachedSignatureAlgorithm)} was called on a CoseSign1Message object({@this.GetHashCode()}) without the ContentType protected header being a string value.");
return false;
}

Match mimeMatch = HashMimeTypeExtension.Match(contentType);
if (!mimeMatch.Success)
{
Trace.TraceError($"{nameof(TryGetDetachedSignatureAlgorithm)} was called on a CoseSign1Message object({@this.GetHashCode()}) with the ContentType protected header being \"{contentType}\" however it did not match the regex pattern \"{HashMimeTypeExtension}\".");
Trace.TraceWarning($"{nameof(TryGetDetachedSignatureAlgorithm)} was called on a CoseSign1Message object({@this.GetHashCode()}) with the ContentType protected header being \"{contentType}\" however it did not match the regex pattern \"{HashMimeTypeExtension}\".");
return false;
}

Expand Down Expand Up @@ -172,23 +172,23 @@
{
hasher = null;

if (!@this.TryGetDetachedSignatureAlgorithm(out HashAlgorithmName algorithmName))

Check warning

Code scanning / CodeQL

Dereferenced variable may be null Warning

Variable
this
may be null at this access as suggested by
this
null check.
{
Trace.TraceError($"{nameof(TryGetHashAlgorithm)} was called on a CoseSign1Message[{@this?.GetHashCode()}] object which did not have a valid hashing algorithm defined");
Trace.TraceWarning($"{nameof(TryGetHashAlgorithm)} was called on a CoseSign1Message[{@this?.GetHashCode()}] object which did not have a valid hashing algorithm defined");
return false;
}

if (!@this!.Content.HasValue)
{
Trace.TraceError($"{nameof(TryGetHashAlgorithm)} was called on a CoseSign1Message object which did not have a content value, unable to compute signature match.");
Trace.TraceWarning($"{nameof(TryGetHashAlgorithm)} was called on a CoseSign1Message object which did not have a content value, unable to compute signature match.");
return false;
}

// note that HashAlgorithm.Create() does not throw for names which do not properly map, null is returned.
hasher = CreateHashAlgorithmFromName(algorithmName);
if (hasher == null)
{
Trace.TraceError($"{nameof(TryGetHashAlgorithm)} was called on a CoseSign1Message object which did not have a hashing algorithm ({algorithmName.Name}) which could be instantiated.");
Trace.TraceWarning($"{nameof(TryGetHashAlgorithm)} was called on a CoseSign1Message object which did not have a hashing algorithm ({algorithmName.Name}) which could be instantiated.");
return false;
}
Debug.WriteLine($"{nameof(TryGetHashAlgorithm)} created a HashAlgorithm from Hash Algorithm Name: {algorithmName.Name}");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,10 @@ public void TestGetECDsaSigningKeyMethod()
[Test]
public void TestGetProtectedHeadersSuccess()
{
X509Certificate2 testCert = TestCertificateUtils.CreateCertificate(nameof(TestGetProtectedHeadersSuccess));
X509Certificate2Collection testChain = TestCertificateUtils.CreateTestChain(nameof(TestGetProtectedHeadersSuccess));
X509Certificate2Collection testChain = TestCertificateUtils.CreateTestChain(nameof(TestGetProtectedHeadersSuccess), leafFirst: true);
X509Certificate2 testCert = testChain[0];

Mock<CertificateCoseSigningKeyProvider> testObj = new(MockBehavior.Strict)
Mock <CertificateCoseSigningKeyProvider> testObj = new(MockBehavior.Strict)
{
CallBase = true
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,69 +106,4 @@ public void X509Certificate2MessageValidatorValidates()
invokedCertChain.Select(c => c.Thumbprint).SequenceEqual(testChain.Reverse().Select(c => c.Thumbprint)).Should().BeTrue();
invokedExtraCerts.Should().BeNull();
}

/// <summary>
/// Run through some basic validator tests.
/// </summary>
[Test]
public void X509Certificate2MessageValidatorFailsWhenCertFailsToBeFound()
{
// setup
X509Certificate2Collection testChain = TestCertificateUtils.CreateTestChain(nameof(X509Certificate2MessageValidatorValidates));
Mock<X509Certificate2MessageValidator> mockValidator = new(MockBehavior.Strict)
{
CallBase = true
};
Mock<ICertificateChainBuilder> mockBuilder = new(MockBehavior.Strict);
ICoseSign1MessageFactory factory = new CoseSign1MessageFactory();
X509Certificate2CoseSigningKeyProvider keyProvider = new(mockBuilder.Object, testChain.Last());
byte[] testArray = new byte[] { 1, 2, 3, 4 };
mockBuilder.Setup(x => x.Build(It.IsAny<X509Certificate2>())).Returns(true);
mockBuilder.Setup(x => x.ChainElements).Returns<List<X509Certificate2>>(null);
X509Certificate2? invokedCert = null;
List<X509Certificate2>? invokedCertChain = null;
List<X509Certificate2>? invokedExtraCerts = null;
List<CoseSign1ValidationResult>? validationResults = null;

mockValidator.Setup(m => m.TryValidate(It.IsAny<CoseSign1Message>(), out validationResults)).CallBase();
mockValidator.Setup(m => m.Validate(It.IsAny<CoseSign1Message>())).CallBase();
mockValidator.Setup(m => m.NextElement).CallBase();
mockValidator.Protected()
.Setup<CoseSign1ValidationResult>(
"ValidateMessage",
ItExpr.IsAny<CoseSign1Message>())
.CallBase();
mockValidator.Protected()
.Setup<CoseSign1ValidationResult>(
"ValidateCertificate",
ItExpr.IsAny<X509Certificate2>(),
ItExpr.IsAny<List<X509Certificate2>>(),
ItExpr.IsAny<List<X509Certificate2>>())
.Callback<X509Certificate2, List<X509Certificate2>, List<X509Certificate2>>(
(cert, certChain, extraCerts) =>
{
invokedCert = cert;
invokedCertChain = certChain;
invokedExtraCerts = extraCerts;
})
.Returns(new CoseSign1ValidationResult(typeof(X509Certificate2MessageValidator)) { PassedValidation = true });

CoseSign1Message message = factory.CreateCoseSign1Message(testArray, keyProvider, embedPayload: true, ContentTypeConstants.Cose);

mockValidator.Object.TryValidate(message, out var results).Should().BeFalse();
results.Should().NotBeNull();
results.Count.Should().Be(1);
results[0].PassedValidation.Should().BeFalse();
mockValidator.Protected()
.Verify<CoseSign1ValidationResult>(
"ValidateCertificate",
Times.Never(),
ItExpr.IsAny<X509Certificate2>(),
ItExpr.IsAny<List<X509Certificate2>>(),
ItExpr.IsAny<List<X509Certificate2>>());

invokedCert.Should().BeNull();
invokedCertChain.Should().BeNull();
invokedExtraCerts.Should().BeNull();
}
}
7 changes: 7 additions & 0 deletions CoseSign1.Certificates/CertificateCoseSigningKeyProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ public CertificateCoseSigningKeyProvider(ICertificateChainBuilder? certificateCh
}

/// <inheritdoc/>
/// <exception cref="CoseSign1CertificateException">Thrown if the signing certificate thumbprint does not match the first element in the certificate chain returned by <see cref="GetCertificateChain(X509ChainSortOrder)"/>.</exception>
public CoseHeaderMap GetProtectedHeaders()
{
CoseHeaderMap protectedHeaders = new();
Expand All @@ -102,6 +103,12 @@ signingCertificate is null
//X509ChainSortOrder is based on x5Chain elements order suggested here <see cref="https://datatracker.ietf.org/doc/rfc9360/"/>.
IEnumerable<X509Certificate2> chain = GetCertificateChain(X509ChainSortOrder.LeafFirst);

// ensure the first chain element thumbprint matches the signing certificate otherwise this message will not be processable.
if (!signingCertificate.Thumbprint.Equals(chain.FirstOrDefault()?.Thumbprint ?? string.Empty))
{
throw new CoseSign1CertificateException($"The signing certificate thumprint: \"{signingCertificate.Thumbprint}\" must match the first item in the signing certificate chain list, which is found to be: \"{chain.FirstOrDefault()?.Thumbprint}\".");
}

// Encode signing cert chain
cborWriter.EncodeCertList(chain);
value = CoseHeaderValue.FromEncodedValue(cborWriter.Encode());
Expand Down
6 changes: 6 additions & 0 deletions CoseSign1.Certificates/CoseX509Thumbprint.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

namespace CoseSign1.Certificates;

using System.Buffers.Text;

/// <summary>
/// Represents a COSE X509 thumbprint, which corresponds to the x5t header in a COSE signature structure.
/// This is different from an X509 certificate thumbprint, which is the SHA1 hash of the certificate.
Expand Down Expand Up @@ -63,6 +65,10 @@ public CoseX509Thumprint(X509Certificate2 cert, HashAlgorithmName hashAlgorithm)
}

#region Public Methods
private string? ToStringCache = null;
/// <inheritdoc />
public override string ToString() => ToStringCache ??= Convert.ToBase64String(Thumbprint.ToArray());

/// <summary>
/// Checks if a certificate matches this thumbprint
/// </summary>
Expand Down
89 changes: 69 additions & 20 deletions CoseSign1.Certificates/Extensions/CborReaderExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,14 @@

namespace CoseSign1.Certificates.Extensions;

using System.Diagnostics;

/// <summary>
/// Extensions for the <see cref="CborReader"/> class.
/// </summary>
/// <remarks>
/// Logging is done through Trace.TraceError and Debug.WriteLine.
/// </remarks>
public static class CborReaderExtensions
{
/// <summary>
Expand All @@ -20,6 +25,17 @@ public static bool TryReadCertificateSet(
ref List<X509Certificate2> certificates,
out CoseX509FormatException? ex)
{
if (reader == null)
{
throw new ArgumentNullException(nameof(reader));
}

if(certificates == null)
{
throw new ArgumentNullException(nameof(certificates));
}


ex = null;
try
{
Expand All @@ -28,6 +44,7 @@ public static bool TryReadCertificateSet(
catch (CoseX509FormatException e)
{
ex = e;
Trace.TraceWarning($"Encountered exception: {e} in {nameof(ReadCertificateSet)}, returning false.");
return false;
}
return true;
Expand All @@ -37,40 +54,72 @@ public static bool TryReadCertificateSet(
/// Loads a collection of certificates into the current CborReader.
/// </summary>
/// <param name="reader">The current CborReader.</param>
/// <param name="certificates">The certificates to read.</param>
/// <param name="certificates">The list of certificates to load with certificates from this certificate set.</param>
/// <exception cref="CoseX509FormatException">The certificate collection was not in a valid CBOR-supported format.</exception>
/// <exception cref="ArgumentNullException">Thrown if <paramref name="reader"/> or <paramref name="certificates"/> is null.</exception>
public static void ReadCertificateSet(this CborReader reader, ref List<X509Certificate2> certificates)
{
CborReaderState peekState = reader.PeekState();
if (peekState == CborReaderState.ByteString)
if (reader == null)
{
byte[] certBytes = reader.ReadByteString();
if (certBytes.Length > 0)
{
certificates.Add(new X509Certificate2(certBytes));
}
throw new ArgumentNullException(nameof(reader));
}
else if (peekState == CborReaderState.StartArray)

if (certificates == null)
{
throw new ArgumentNullException(nameof(certificates));
}

try
{
int? certCount = reader.ReadStartArray();
for (int i = 0; i < certCount; i++)
CborReaderState peekState = reader.PeekState();
if (peekState == CborReaderState.ByteString)
{
if (reader.PeekState() != CborReaderState.ByteString)
try
{
certificates.Add(reader.ReadByteStringAsCertificate());
}
catch(CoseX509FormatException ex)
{
throw new CoseX509FormatException("Certificate array must only contain ByteString");
Trace.TraceWarning($"Failed to read certificates from CborReader: {reader.GetHashCode()} with exception: {ex}, unable to read certificate set.");
}
byte[] certBytes = reader.ReadByteString();
if (certBytes.Length > 0)
}
else if (peekState == CborReaderState.StartArray)
{
int? certCount = reader.ReadStartArray();
for (int i = 0; i < certCount; i++)
{
certificates.Add(new X509Certificate2(certBytes));
certificates.Add(reader.ReadByteStringAsCertificate());
}
reader.ReadEndArray();
}
else
{
throw new CoseX509FormatException(
"Certificate collections must be ByteString for single certificate or Array for multiple certificates");
}
reader.ReadEndArray();
}
else
catch(Exception ex) when (ex is not CoseX509FormatException)
{
throw new CoseX509FormatException(
"Certificate collections must be ByteString for single certificate or Array for multiple certificates");
throw new CoseX509FormatException(ex.Message, ex);
}
}

/// <summary>
/// Extracts a certificate from the ByteString on this <see cref="CborReader"/>.
/// </summary>
/// <param name="reader">The <see cref="CborReader"/> to extract a certificate from presuming it's on a ByteString.</param>
/// <returns>A <see cref="X509Certificate2"/> extracted from the ByteString.</returns>
/// <exception cref="CoseX509FormatException">Thrown if the <paramref name="reader"/> is not on a ByteString, or if the extract ByteString cannot be converted into a <see cref="X509Certificate2"/>.</exception>
private static X509Certificate2 ReadByteStringAsCertificate(this CborReader reader)
{
if (reader.PeekState() != CborReaderState.ByteString)
{
throw new CoseX509FormatException($"Certificate array must only contain ByteString on reader: {reader.GetHashCode()}");
}
byte[] certBytes = reader.ReadByteString();

return certBytes.Length > 0
? new X509Certificate2(certBytes)
: throw new CoseX509FormatException($"Failed to read certificate bytes from ByteString on CborReader: {reader.GetHashCode()} and convert to a certificate.");
}
}
Loading
Loading