diff --git a/coordinator/main.go b/coordinator/main.go index 100b4dc52e..eff66e51c5 100644 --- a/coordinator/main.go +++ b/coordinator/main.go @@ -14,7 +14,7 @@ import ( "github.com/edgelesssys/contrast/coordinator/history" "github.com/edgelesssys/contrast/coordinator/internal/authority" - "github.com/edgelesssys/contrast/internal/attestation" + "github.com/edgelesssys/contrast/internal/atls" "github.com/edgelesssys/contrast/internal/grpc/atlscredentials" "github.com/edgelesssys/contrast/internal/logger" "github.com/edgelesssys/contrast/internal/meshapi" @@ -135,7 +135,7 @@ func newServerMetrics(reg *prometheus.Registry) *grpcprometheus.ServerMetrics { } func newGRPCServer(serverMetrics *grpcprometheus.ServerMetrics, log *slog.Logger) (*grpc.Server, error) { - issuer, err := attestation.PlatformIssuer(log) + issuer, err := atls.PlatformIssuer(log) if err != nil { return nil, fmt.Errorf("creating issuer: %w", err) } diff --git a/initializer/main.go b/initializer/main.go index 075cd4d1af..5d423f7e65 100644 --- a/initializer/main.go +++ b/initializer/main.go @@ -17,7 +17,6 @@ import ( "time" "github.com/edgelesssys/contrast/internal/atls" - "github.com/edgelesssys/contrast/internal/attestation" "github.com/edgelesssys/contrast/internal/grpc/dialer" "github.com/edgelesssys/contrast/internal/logger" "github.com/edgelesssys/contrast/internal/meshapi" @@ -55,7 +54,7 @@ func run() (retErr error) { return fmt.Errorf("generating key: %w", err) } - issuer, err := attestation.PlatformIssuer(log) + issuer, err := atls.PlatformIssuer(log) if err != nil { return fmt.Errorf("creating issuer: %w", err) } diff --git a/internal/atls/atls.go b/internal/atls/atls.go index 250991dde2..4b9490343c 100644 --- a/internal/atls/atls.go +++ b/internal/atls/atls.go @@ -21,6 +21,7 @@ import ( "math/big" "time" + "github.com/edgelesssys/contrast/internal/attestation" "github.com/edgelesssys/contrast/internal/crypto" ) @@ -31,6 +32,11 @@ var ( NoValidator Validator // NoIssuer skips embedding the client's attestation document. NoIssuer Issuer + + // ErrNoValidAttestationExtensions is returned when no valid attestation document certificate extensions are found. + ErrNoValidAttestationExtensions = errors.New("no valid attestation document certificate extensions found") + // ErrNoMatchingValidators is returned when no validator matches the attestation document. + ErrNoMatchingValidators = errors.New("no matching validators found") ) // CreateAttestationServerTLSConfig creates a tls.Config object with a self-signed certificate and an embedded attestation document. @@ -205,19 +211,56 @@ func processCertificate(rawCerts [][]byte, _ [][]*x509.Certificate) (*x509.Certi } // verifyEmbeddedReport verifies an aTLS certificate by validating the attestation document embedded in the TLS certificate. -func verifyEmbeddedReport(validators []Validator, cert *x509.Certificate, peerPublicKey, nonce []byte) error { +// +// It will check against all applicable validator for the type of attestation document, and return success on the first match. +func verifyEmbeddedReport(validators []Validator, cert *x509.Certificate, peerPublicKey, nonce []byte) (retErr error) { + // For better error reporting, let's keep track of whether we've found a valid extension at all.. + var foundExtension bool + // .. and whether we've found a matching validator. + var foundMatchingValidator bool + + // We'll need to have a look at all extensions in the certificate to find the attestation document. for _, ex := range cert.Extensions { + // Optimization: Skip the extension early before heading into the m*n complexity of the validator check + // if the extension is not an attestation document. + if !attestation.IsAttestationDocumentExtension(ex.Id) { + continue + } + + // We have a valid attestation document. Let's check it against all applicable validators. + foundExtension = true for _, validator := range validators { - if ex.Id.Equal(validator.OID()) { - ctx, cancel := context.WithTimeout(context.Background(), attestationTimeout) - defer cancel() + // Optimization: Skip the validator if it doesn't match the attestation type of the document. + if !ex.Id.Equal(validator.OID()) { + continue + } + + // We've found a matching validator. Let's validate the document. + foundMatchingValidator = true - return validator.Validate(ctx, ex.Value, nonce, peerPublicKey) + ctx, cancel := context.WithTimeout(context.Background(), attestationTimeout) + defer cancel() + + validationErr := validator.Validate(ctx, ex.Value, nonce, peerPublicKey) + if validationErr == nil { + // The validator has successfully verified the document. We can exit. + return nil } + // Otherwise, we'll keep track of the error and continue with the next validator. + retErr = errors.Join(retErr, fmt.Errorf("validator %s failed: %w", validator.OID(), validationErr)) } } - return errors.New("certificate does not contain attestation document") + if !foundExtension { + return ErrNoValidAttestationExtensions + } + + if !foundMatchingValidator { + return ErrNoMatchingValidators + } + + // If we're here, an error must've happened during validation. + return retErr } // encodeNonceToCertPool returns a cert pool that contains a certificate whose CN is the base64-encoded nonce. diff --git a/internal/atls/atls_test.go b/internal/atls/atls_test.go new file mode 100644 index 0000000000..02c5ea8353 --- /dev/null +++ b/internal/atls/atls_test.go @@ -0,0 +1,131 @@ +// Copyright 2024 Edgeless Systems GmbH +// SPDX-License-Identifier: AGPL-3.0-only + +package atls + +import ( + "crypto/x509" + "crypto/x509/pkix" + "encoding/asn1" + "encoding/json" + "testing" + + "github.com/edgelesssys/contrast/internal/oid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestVerifyEmbeddedReport(t *testing.T) { + fakeAttDoc := FakeAttestationDoc{} + attDocBytes, err := json.Marshal(fakeAttDoc) + assert.NoError(t, err) + + testCases := map[string]struct { + cert *x509.Certificate + validators []Validator + wantErr bool + targetErr error + }{ + "success": { + cert: &x509.Certificate{ + Extensions: []pkix.Extension{ + { + Id: oid.RawTDXReport, + }, + { + Id: oid.RawSNPReport, + Value: attDocBytes, + }, + }, + }, + validators: NewFakeValidators(stubSNPValidator{}), + }, + "multiple matches": { + cert: &x509.Certificate{ + Extensions: []pkix.Extension{ + { + Id: oid.RawSNPReport, + Value: []byte("foo"), + }, + { + Id: oid.RawSNPReport, + Value: attDocBytes, + }, + }, + }, + validators: NewFakeValidators(stubSNPValidator{}), + }, + "skip non-matching validator": { + cert: &x509.Certificate{ + Extensions: []pkix.Extension{ + { + Id: []int{4, 5, 6}, + }, + { + Id: oid.RawSNPReport, + Value: attDocBytes, + }, + }, + }, + validators: append(NewFakeValidators(stubSNPValidator{}), NewFakeValidator(stubFooValidator{})), + }, + "match, error": { + cert: &x509.Certificate{ + Extensions: []pkix.Extension{ + { + Id: oid.RawSNPReport, + Value: []byte("foo"), + }, + }, + }, + validators: NewFakeValidators(stubSNPValidator{}), + wantErr: true, + }, + "no extensions": { + cert: &x509.Certificate{}, + validators: nil, + targetErr: ErrNoValidAttestationExtensions, + wantErr: true, + }, + "no matching validator": { + cert: &x509.Certificate{ + Extensions: []pkix.Extension{ + { + Id: oid.RawSNPReport, + }, + }, + }, + validators: nil, + targetErr: ErrNoMatchingValidators, + wantErr: true, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + assert := assert.New(t) + require := require.New(t) + err := verifyEmbeddedReport(tc.validators, tc.cert, nil, nil) + if tc.wantErr { + require.Error(err) + if tc.targetErr != nil { + assert.ErrorIs(err, tc.targetErr) + } + } else { + require.NoError(err) + } + }) + } +} + +type stubSNPValidator struct{} + +func (v stubSNPValidator) OID() asn1.ObjectIdentifier { + return oid.RawSNPReport +} + +type stubFooValidator struct{} + +func (v stubFooValidator) OID() asn1.ObjectIdentifier { + return []int{1, 2, 3} +} diff --git a/internal/attestation/issuer.go b/internal/atls/issuer.go similarity index 86% rename from internal/attestation/issuer.go rename to internal/atls/issuer.go index 76a1953bff..09ac4a192d 100644 --- a/internal/attestation/issuer.go +++ b/internal/atls/issuer.go @@ -1,13 +1,12 @@ // Copyright 2024 Edgeless Systems GmbH // SPDX-License-Identifier: AGPL-3.0-only -package attestation +package atls import ( "fmt" "log/slog" - "github.com/edgelesssys/contrast/internal/atls" "github.com/edgelesssys/contrast/internal/attestation/snp" "github.com/edgelesssys/contrast/internal/attestation/tdx" "github.com/edgelesssys/contrast/internal/logger" @@ -15,7 +14,7 @@ import ( ) // PlatformIssuer creates an attestation issuer for the current platform. -func PlatformIssuer(log *slog.Logger) (atls.Issuer, error) { +func PlatformIssuer(log *slog.Logger) (Issuer, error) { cpuid.Detect() switch { case cpuid.CPU.Supports(cpuid.SEV_SNP): diff --git a/internal/attestation/oid.go b/internal/attestation/oid.go new file mode 100644 index 0000000000..97b436f32a --- /dev/null +++ b/internal/attestation/oid.go @@ -0,0 +1,16 @@ +// Copyright 2024 Edgeless Systems GmbH +// SPDX-License-Identifier: AGPL-3.0-only + +package attestation + +import ( + "encoding/asn1" + + oids "github.com/edgelesssys/contrast/internal/oid" +) + +// IsAttestationDocumentExtension checks whether the given OID corresponds to an attestation document extension +// supported by Contrast (i.e. TDX or SNP). +func IsAttestationDocumentExtension(oid asn1.ObjectIdentifier) bool { + return oid.Equal(oids.RawTDXReport) || oid.Equal(oids.RawSNPReport) +}