diff --git a/.pipelines/TestSql2017.yml b/.pipelines/TestSql2017.yml index 9633308a..05b370ac 100644 --- a/.pipelines/TestSql2017.yml +++ b/.pipelines/TestSql2017.yml @@ -34,8 +34,6 @@ steps: arguments: 'github.com/AlekSi/gocov-xml@latest' workingDirectory: '$(System.DefaultWorkingDirectory)' -#Your build pipeline references an undefined variables named SQLPASSWORD and AZURESERVER_DSN. -#Create or edit the build pipeline for this YAML file, define the variable on the Variables tab. See https://go.microsoft.com/fwlink/?linkid=865972 - task: Docker@2 displayName: 'Run SQL 2017 docker image' @@ -54,6 +52,11 @@ steps: SQLPASSWORD: $(SQLPASSWORD) AZURESERVER_DSN: $(AZURESERVER_DSN) SQLSERVER_DSN: $(SQLSERVER_DSN) + AZURE_CLIENT_SECRET: $(AZURE_CLIENT_SECRET) + KEY_VAULT_NAME: $(KEY_VAULT_NAME) + AZURE_TENANT_ID: $(AZURE_TENANT_ID) + AZURE_CLIENT_ID: $(AZURE_CLIENT_ID) + COLUMNENCRYPTION: 1 continueOnError: true - task: PublishTestResults@2 displayName: "Publish junit-style results" diff --git a/CHANGELOG.md b/CHANGELOG.md index 5e814692..ceabdbb8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,17 @@ # Changelog +## 1.6.0 + +### Changed + +* Go.mod updated to Go 1.17 +* Azure SDK for Go dependencies updated + +### Features + +* Always Encrypted encryption and decryption with 2 hour key cache (#116) +* 'pfx', 'MSSQL_CERTIFICATE_STORE', and 'AZURE_KEY_VAULT' encryption key providers + ## 1.5.0 ### Features diff --git a/README.md b/README.md index 191512db..5b5f363d 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ ## Install -Requires Go 1.16 or above. +Requires Go 1.17 or above. Install with `go install github.com/microsoft/go-mssqldb@latest`. @@ -425,6 +425,13 @@ The `MSSQL_CERTIFICATE_STORE` provider exposes its instance as the variable `Win Both providers can be constrained to an allowed list of encryption key paths by appending paths to `provider.AllowedLocations`. + +### Azure Key Vault (AZURE_KEY_VAULT) key provider + +Import this provider using `github.com/microsoft/go-mssqldb/aecmk/akv` + +Constrain the provider to an allowed list of key vaults by appending vault host strings like "mykeyvault.vault.azure.net" to `akv.KeyProvider.AllowedLocations`. + ## Important Notes diff --git a/aecmk/akv/keyprovider.go b/aecmk/akv/keyprovider.go new file mode 100644 index 00000000..cecd40c0 --- /dev/null +++ b/aecmk/akv/keyprovider.go @@ -0,0 +1,264 @@ +//go:build go1.18 +// +build go1.18 + +package akv + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "crypto/sha1" + "crypto/sha256" + "fmt" + "math/big" + "net/url" + "strings" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azidentity" + "github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys" + "github.com/microsoft/go-mssqldb/aecmk" + ae "github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg" + "golang.org/x/text/encoding/unicode" +) + +const ( + wildcard = "*" +) + +// Provider implements a column encryption key provider backed by Azure Key Vault +type Provider struct { + // AllowedLocations constrains which locations the provider will use to find certificates. If empty, all locations are allowed. + // When presented with a key store path whose endpoint not in the allowed list, the data will be returned still encrypted. + AllowedLocations []string + credentials map[string]azcore.TokenCredential +} + +type keyData struct { + publicKey *rsa.PublicKey + endpoint string + name string + version string +} + +// SetCertificateCredential stores the AzureCredential associated with the given AKV endpoint. +// If endpoint is empty the given credential applies to all endpoints that have not been explicitly assigned a value. +// If SetCertificateCredential is never called, the provider uses azidentity.DefaultAzureCredential. +func (p Provider) SetCertificateCredential(endpoint string, credential azcore.TokenCredential) { + if endpoint == "" { + endpoint = wildcard + } + p.credentials[endpoint] = credential +} + +var KeyProvider = Provider{credentials: make(map[string]azcore.TokenCredential), AllowedLocations: make([]string, 0)} + +func init() { + err := aecmk.RegisterCekProvider(aecmk.AzureKeyVaultKeyProvider, &KeyProvider) + if err != nil { + panic(err) + } +} + +// DecryptColumnEncryptionKey decrypts the specified encrypted value of a column encryption key. +// The encrypted value is expected to be encrypted using the column master key with the specified key path and using the specified algorithm. +func (p *Provider) DecryptColumnEncryptionKey(masterKeyPath string, encryptionAlgorithm string, encryptedCek []byte) (decryptedKey []byte) { + decryptedKey = nil + keyData := p.getKeyData(masterKeyPath) + if keyData == nil { + return + } + keySize := keyData.publicKey.Size() + cekv := ae.LoadCEKV(encryptedCek) + if cekv.Version != 1 { + panic(fmt.Errorf("Invalid version byte in encrypted key")) + } + if keySize != len(cekv.Ciphertext) { + panic(fmt.Errorf("Encrypted key has wrong ciphertext length")) + } + if keySize != len(cekv.SignedHash) { + panic(fmt.Errorf("Encrypted key signature length mismatch")) + } + if !cekv.VerifySignature(keyData.publicKey) { + panic(fmt.Errorf("Invalid signature hash")) + } + + client := p.getAKVClient(keyData.endpoint) + algorithm := getAlgorithm(encryptionAlgorithm) + parameters := azkeys.KeyOperationParameters{ + Algorithm: &algorithm, + Value: cekv.Ciphertext, + } + r, err := client.UnwrapKey(context.Background(), keyData.name, keyData.version, parameters, nil) + if err != nil { + panic(fmt.Errorf("Unable to decrypt key %s: %w", masterKeyPath, err)) + } + decryptedKey = r.Result + return +} + +// EncryptColumnEncryptionKey encrypts a column encryption key using the column master key with the specified key path and using the specified algorithm. +func (p *Provider) EncryptColumnEncryptionKey(masterKeyPath string, encryptionAlgorithm string, cek []byte) []byte { + keyData := p.getKeyData(masterKeyPath) + // just validate the algorith + _ = getAlgorithm(encryptionAlgorithm) + keySize := keyData.publicKey.Size() + enc := unicode.UTF16(unicode.LittleEndian, unicode.IgnoreBOM).NewEncoder() + // Start with version byte == 1 + buf := []byte{byte(1)} + // EncryptedColumnEncryptionKey = version + keyPathLength + ciphertextLength + keyPath + ciphertext + signature + // version + keyPathBytes, err := enc.Bytes([]byte(strings.ToLower(masterKeyPath))) + if err != nil { + panic(fmt.Errorf("Unable to serialize key path %w", err)) + } + k := uint16(len(keyPathBytes)) + // keyPathLength + buf = append(buf, byte(k), byte(k>>8)) + + cipherText, err := rsa.EncryptOAEP(sha1.New(), rand.Reader, keyData.publicKey, cek, []byte{}) + if err != nil { + panic(fmt.Errorf("Unable to encrypt data %w", err)) + } + l := uint16(len(cipherText)) + // ciphertextLength + buf = append(buf, byte(l), byte(l>>8)) + // keypath + buf = append(buf, keyPathBytes...) + // ciphertext + buf = append(buf, cipherText...) + hash := sha256.Sum256(buf) + client := p.getAKVClient(keyData.endpoint) + signAlgorithm := azkeys.SignatureAlgorithmRS256 + parameters := azkeys.SignParameters{ + Algorithm: &signAlgorithm, + Value: hash[:], + } + r, err := client.Sign(context.Background(), keyData.name, keyData.version, parameters, nil) + if err != nil { + panic(err) + } + if len(r.Result) != keySize { + panic("Signature length doesn't match certificate key size") + } + // signature + buf = append(buf, r.Result...) + return buf +} + +// SignColumnMasterKeyMetadata digitally signs the column master key metadata with the column master key +// referenced by the masterKeyPath parameter. The input values used to generate the signature should be the +// specified values of the masterKeyPath and allowEnclaveComputations parameters. May return an empty slice if not supported. +func (p *Provider) SignColumnMasterKeyMetadata(masterKeyPath string, allowEnclaveComputations bool) []byte { + return nil +} + +// VerifyColumnMasterKeyMetadata verifies the specified signature is valid for the column master key +// with the specified key path and the specified enclave behavior. Return nil if not supported. +func (p *Provider) VerifyColumnMasterKeyMetadata(masterKeyPath string, allowEnclaveComputations bool) *bool { + return nil +} + +// KeyLifetime is an optional Duration. Keys fetched by this provider will be discarded after their lifetime expires. +// If it returns nil, the keys will expire based on the value of ColumnEncryptionKeyLifetime. +// If it returns zero, the keys will not be cached. +func (p *Provider) KeyLifetime() *time.Duration { + return nil +} + +func getAlgorithm(encryptionAlgorithm string) (algorithm azkeys.EncryptionAlgorithm) { + // support both RSA_OAEP and RSA-OAEP + if strings.EqualFold(encryptionAlgorithm, aecmk.KeyEncryptionAlgorithm) { + encryptionAlgorithm = string(azkeys.EncryptionAlgorithmRSAOAEP) + } + if !strings.EqualFold(encryptionAlgorithm, string(azkeys.EncryptionAlgorithmRSAOAEP)) { + panic(fmt.Errorf("Unsupported encryption algorithm %s", encryptionAlgorithm)) + } + return azkeys.EncryptionAlgorithmRSAOAEP +} + +// masterKeyPath is a full URL. The AKV client requires it broken down into endpoint, name, and version +// The URL has format '{endpoint}/{host}/keys/{name}/[{version}/]' +func (p *Provider) getKeyData(masterKeyPath string) *keyData { + endpoint, keypath, allowed := p.allowedPathAndEndpoint(masterKeyPath) + if !(allowed) { + return nil + } + k := &keyData{ + endpoint: endpoint, + name: keypath[0], + } + if len(keypath) > 1 { + k.version = keypath[1] + } + client := p.getAKVClient(endpoint) + r, err := client.GetKey(context.Background(), k.name, k.version, nil) + if err != nil { + panic(fmt.Errorf("Unable to get key from AKV %w", err)) + } + if r.Key.Kty == nil || (*r.Key.Kty != azkeys.KeyTypeRSA && *r.Key.Kty != azkeys.KeyTypeRSAHSM) { + panic(fmt.Errorf("Key type not supported for Always Encrypted")) + } + k.publicKey = &rsa.PublicKey{ + N: new(big.Int).SetBytes(r.Key.N), + E: int(new(big.Int).SetBytes(r.Key.E).Int64()), + } + return k +} + +func (p *Provider) allowedPathAndEndpoint(masterKeyPath string) (endpoint string, keypath []string, allowed bool) { + allowed = len(p.AllowedLocations) == 0 + url, err := url.Parse(masterKeyPath) + if err != nil { + panic(fmt.Errorf("Invalid URL for master key path %s: %w", masterKeyPath, err)) + } + if !allowed { + + loop: + for _, l := range p.AllowedLocations { + if strings.HasSuffix(strings.ToLower(url.Host), strings.ToLower(l)) { + allowed = true + break loop + } + } + } + if allowed { + pathParts := strings.Split(strings.TrimLeft(url.Path, "/"), "/") + if len(pathParts) < 2 || len(pathParts) > 3 || pathParts[0] != "keys" { + panic(fmt.Errorf("Invalid URL for master key path %s", masterKeyPath)) + } + keypath = pathParts[1:] + url.Path = "" + url.RawQuery = "" + url.Fragment = "" + endpoint = url.String() + } + return +} + +func (p *Provider) getAKVClient(endpoint string) (client *azkeys.Client) { + client, err := azkeys.NewClient(endpoint, p.getCredential(endpoint), nil) + if err != nil { + panic(fmt.Errorf("Unable to create AKV client %w", err)) + } + return +} + +func (p *Provider) getCredential(endpoint string) azcore.TokenCredential { + if len(p.credentials) == 0 { + credential, err := azidentity.NewDefaultAzureCredential(nil) + if err != nil { + panic(fmt.Errorf("Unable to create a default credential: %w", err)) + } + p.credentials[wildcard] = credential + return credential + } + if credential, ok := p.credentials[endpoint]; ok { + return credential + } + if credential, ok := p.credentials[wildcard]; ok { + return credential + } + panic(fmt.Errorf("No credential available for AKV path %s", endpoint)) +} diff --git a/aecmk/akv/keyprovider_test.go b/aecmk/akv/keyprovider_test.go new file mode 100644 index 00000000..f16f826a --- /dev/null +++ b/aecmk/akv/keyprovider_test.go @@ -0,0 +1,34 @@ +//go:build go1.18 +// +build go1.18 + +package akv + +import ( + "crypto/rand" + "net/url" + "testing" + + "github.com/microsoft/go-mssqldb/aecmk" + "github.com/microsoft/go-mssqldb/internal/akvkeys" + "github.com/stretchr/testify/assert" +) + +func TestEncryptDecryptRoundTrip(t *testing.T) { + client, vaultURL, err := akvkeys.GetTestAKV() + if err != nil { + t.Skip("No access to AKV") + } + name, err := akvkeys.CreateRSAKey(client) + assert.NoError(t, err, "CreateRSAKey") + defer akvkeys.DeleteRSAKey(client, name) + keyPath, _ := url.JoinPath(vaultURL, name) + p := &KeyProvider + plainKey := make([]byte, 32) + _, _ = rand.Read(plainKey) + t.Log("Plainkey:", plainKey) + encryptedKey := p.EncryptColumnEncryptionKey(keyPath, aecmk.KeyEncryptionAlgorithm, plainKey) + t.Log("Encryptedkey:", encryptedKey) + assert.NotEqualValues(t, plainKey, encryptedKey, "encryptedKey is the same as plainKey") + decryptedKey := p.DecryptColumnEncryptionKey(keyPath, aecmk.KeyEncryptionAlgorithm, encryptedKey) + assert.Equalf(t, plainKey, decryptedKey, "decryptedkey doesn't match plainKey. %v : %v", decryptedKey, plainKey) +} diff --git a/alwaysencrypted_akv_test.go b/alwaysencrypted_akv_test.go new file mode 100644 index 00000000..72252069 --- /dev/null +++ b/alwaysencrypted_akv_test.go @@ -0,0 +1,54 @@ +//go:build go1.18 +// +build go1.18 + +package mssql + +import ( + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys" + + "github.com/microsoft/go-mssqldb/aecmk" + "github.com/microsoft/go-mssqldb/aecmk/akv" + "github.com/microsoft/go-mssqldb/internal/akvkeys" + "github.com/stretchr/testify/assert" +) + +type akvProviderTest struct { + client *azkeys.Client + keyName string +} + +func (p *akvProviderTest) ProvisionMasterKey(t *testing.T) string { + t.Helper() + client, vaultURL, err := akvkeys.GetTestAKV() + if err != nil { + t.Skip("Unable to access AKV") + } + name, err := akvkeys.CreateRSAKey(client) + assert.NoError(t, err, "CreateRSAKey") + keyPath := vaultURL + "/" + name + p.client = client + p.keyName = name + return keyPath +} + +func (p *akvProviderTest) DeleteMasterKey(t *testing.T) { + t.Helper() + if !akvkeys.DeleteRSAKey(p.client, p.keyName) { + assert.Fail(t, "DeleteRSAKey failed") + } +} + +func (p *akvProviderTest) GetProvider(t *testing.T) aecmk.ColumnEncryptionKeyProvider { + t.Helper() + return &akv.KeyProvider +} + +func (p *akvProviderTest) Name() string { + return aecmk.AzureKeyVaultKeyProvider +} + +func init() { + addProviderTest(&akvProviderTest{}) +} diff --git a/alwaysencrypted_test.go b/alwaysencrypted_test.go new file mode 100644 index 00000000..05260c43 --- /dev/null +++ b/alwaysencrypted_test.go @@ -0,0 +1,223 @@ +package mssql + +import ( + "crypto/rand" + "database/sql" + "fmt" + "math/big" + "strings" + "testing" + "time" + + "github.com/golang-sql/civil" + "github.com/microsoft/go-mssqldb/aecmk" + "github.com/stretchr/testify/assert" +) + +type providerTest interface { + // ProvisionMasterKey creates a master key in the key storage and returns the path of the key + ProvisionMasterKey(t *testing.T) string + // DeleteMasterKey deletes the master key + DeleteMasterKey(t *testing.T) + // GetProvider returns the appropriate ColumnEncryptionKeyProvider instance + GetProvider(t *testing.T) aecmk.ColumnEncryptionKeyProvider + // Name is the name of the key provider + Name() string +} + +var providerTests []providerTest = make([]providerTest, 0, 2) + +func addProviderTest(p providerTest) { + providerTests = append(providerTests, p) +} + +// Define phrases for create table for each enryptable data type along with sample data for insertion and validation +type aeColumnInfo struct { + queryPhrase string + sqlDataType string + encType ColumnEncryptionType + sampleValue interface{} +} + +func TestAlwaysEncryptedE2E(t *testing.T) { + params := testConnParams(t) + if !params.ColumnEncryption { + t.Skip("Test is not running with column encryption enabled") + } + // civil.DateTime has 9 digit precision while SQL only has 7, so we can't use time.Now + dt, err := time.Parse("2006-01-02T15:04:05.9999999", "2023-08-21T18:33:36.5315137") + assert.NoError(t, err, "time.Parse") + encryptableColumns := []aeColumnInfo{ + {"int", "INT", ColumnEncryptionDeterministic, int32(1)}, + {"nchar(10) COLLATE Latin1_General_BIN2", "NCHAR", ColumnEncryptionDeterministic, NChar("ncharval")}, + {"tinyint", "TINYINT", ColumnEncryptionRandomized, byte(2)}, + {"smallint", "SMALLINT", ColumnEncryptionDeterministic, int16(-3)}, + {"bigint", "BIGINT", ColumnEncryptionRandomized, int64(4)}, + // We can't use fractional float/real values due to rounding errors in the round trip + {"real", "REAL", ColumnEncryptionDeterministic, float32(5)}, + {"float", "FLOAT", ColumnEncryptionRandomized, float64(6)}, + {"varbinary(10)", "VARBINARY", ColumnEncryptionDeterministic, []byte{1, 2, 3, 4}}, + // TODO: Varchar support requires proper selection of a collation and conversion + // {"varchar(10) COLLATE Latin1_General_BIN2", "VARCHAR", ColumnEncryptionRandomized, VarChar("varcharval")}, + {"nvarchar(30)", "NVARCHAR", ColumnEncryptionRandomized, "nvarcharval"}, + {"bit", "BIT", ColumnEncryptionDeterministic, true}, + {"datetimeoffset(7)", "DATETIMEOFFSET", ColumnEncryptionRandomized, dt}, + {"datetime2(7)", "DATETIME2", ColumnEncryptionDeterministic, civil.DateTimeOf(dt)}, + {"nvarchar(max)", "NVARCHAR", ColumnEncryptionRandomized, NVarCharMax("nvarcharmaxval")}, + // TODO: The driver throws away type information about Valuer implementations and sends nil as nvarchar(1). Fix that. + // {"int", "INT", ColumnEncryptionDeterministic, sql.NullInt32{Valid: false}}, + } + for _, test := range providerTests { + t.Run(test.Name(), func(t *testing.T) { + conn, _ := open(t) + defer conn.Close() + certPath := test.ProvisionMasterKey(t) + defer test.DeleteMasterKey(t) + s := fmt.Sprintf(createColumnMasterKey, certPath, test.Name(), certPath) + if _, err := conn.Exec(s); err != nil { + t.Fatalf("Unable to create CMK: %s", err.Error()) + } + defer func() { + _, err := conn.Exec(fmt.Sprintf(dropColumnMasterKey, certPath)) + assert.NoError(t, err, "dropColumnMasterKey") + }() + r, _ := rand.Int(rand.Reader, big.NewInt(1000)) + cekName := fmt.Sprintf("mssqlCek%d", r.Int64()) + tableName := fmt.Sprintf("mssqlAe%d", r.Int64()) + keyBytes := make([]byte, 32) + _, _ = rand.Read(keyBytes) + encryptedCek := test.GetProvider(t).EncryptColumnEncryptionKey(certPath, KeyEncryptionAlgorithm, keyBytes) + createCek := fmt.Sprintf(createColumnEncryptionKey, cekName, certPath, encryptedCek) + _, err := conn.Exec(createCek) + assert.NoError(t, err, "Unable to create CEK") + defer func() { + _, err := conn.Exec(fmt.Sprintf(dropColumnEncryptionKey, cekName)) + assert.NoError(t, err, "dropColumnEncryptionKey") + }() + _, _ = conn.Exec("DROP TABLE IF EXISTS " + tableName) + query := new(strings.Builder) + insert := new(strings.Builder) + sel := new(strings.Builder) + _, _ = query.WriteString(fmt.Sprintf("CREATE TABLE [%s] (", tableName)) + _, _ = insert.WriteString(fmt.Sprintf("INSERT INTO [%s] VALUES (", tableName)) + _, _ = sel.WriteString("select top(1) ") + insertArgs := make([]interface{}, len(encryptableColumns)+1) + for i, ec := range encryptableColumns { + encType := "RANDOMIZED" + null := "" + _, ok := ec.sampleValue.(sql.NullInt32) + if ok { + null = "NULL" + } + if ec.encType == ColumnEncryptionDeterministic { + encType = "DETERMINISTIC" + } + _, _ = query.WriteString(fmt.Sprintf(`col%d %s ENCRYPTED WITH (ENCRYPTION_TYPE = %s, + ALGORITHM = 'AEAD_AES_256_CBC_HMAC_SHA_256', + COLUMN_ENCRYPTION_KEY = [%s]) %s, + `, i, ec.queryPhrase, encType, cekName, null)) + + insertArgs[i] = ec.sampleValue + insert.WriteString(fmt.Sprintf("@p%d,", i+1)) + sel.WriteString(fmt.Sprintf("col%d,", i)) + } + _, _ = query.WriteString("unencryptedcolumn nvarchar(100)") + _, _ = query.WriteString(")") + insertArgs[len(encryptableColumns)] = "unencryptedvalue" + insert.WriteString(fmt.Sprintf("@p%d)", len(encryptableColumns)+1)) + sel.WriteString(fmt.Sprintf("unencryptedcolumn from [%s]", tableName)) + _, err = conn.Exec(query.String()) + assert.NoError(t, err, "Failed to create encrypted table") + defer func() { _, _ = conn.Exec("DROP TABLE IF EXISTS " + tableName) }() + _, err = conn.Exec(insert.String(), insertArgs...) + assert.NoError(t, err, "Failed to insert row in encrypted table") + rows, err := conn.Query(sel.String()) + assert.NoErrorf(t, err, "Unable to query encrypted columns") + if !rows.Next() { + rows.Close() + assert.FailNow(t, "rows.Next returned false") + } + cols, err := rows.ColumnTypes() + assert.NoError(t, err, "rows.ColumnTypes failed") + for i := range encryptableColumns { + assert.Equalf(t, encryptableColumns[i].sqlDataType, cols[i].DatabaseTypeName(), + "Got wrong type name for col%d.", i) + } + + var unencryptedColumnValue string + scanValues := make([]interface{}, len(encryptableColumns)+1) + for v := range scanValues { + if v < len(encryptableColumns) { + scanValues[v] = new(interface{}) + } + } + scanValues[len(encryptableColumns)] = &unencryptedColumnValue + err = rows.Scan(scanValues...) + defer rows.Close() + if err != nil { + assert.FailNow(t, "Scan failed ", err) + } + for i := range encryptableColumns { + var strVal string + var expectedStrVal string + if encryptableColumns[i].sampleValue == nil { + expectedStrVal = "NULL" + } else { + expectedStrVal = comparisonValueFromObject(encryptableColumns[i].sampleValue) + } + rawVal := scanValues[i].(*interface{}) + + if rawVal == nil { + strVal = "NULL" + } else { + strVal = comparisonValueFromObject(*rawVal) + } + assert.Equalf(t, expectedStrVal, strVal, "Incorrect value for col%d. ", i) + } + assert.Equalf(t, "unencryptedvalue", unencryptedColumnValue, "Got wrong value for unencrypted column") + _ = rows.Next() + err = rows.Err() + assert.NoError(t, err, "rows.Err() has non-nil values") + }) + } +} + +func comparisonValueFromObject(object interface{}) string { + switch v := object.(type) { + case []byte: + { + return string(v) + } + case string: + return v + case time.Time: + return civil.DateTimeOf(v).String() + //return v.Format(time.RFC3339) + case fmt.Stringer: + return v.String() + case bool: + if v == true { + return "1" + } + return "0" + default: + return fmt.Sprintf("%v", v) + } +} + +const ( + createColumnMasterKey = `CREATE COLUMN MASTER KEY [%s] WITH (KEY_STORE_PROVIDER_NAME= '%s', KEY_PATH='%s')` + dropColumnMasterKey = `DROP COLUMN MASTER KEY [%s]` + createColumnEncryptionKey = `CREATE COLUMN ENCRYPTION KEY [%s] WITH VALUES (COLUMN_MASTER_KEY = [%s], ALGORITHM = 'RSA_OAEP', ENCRYPTED_VALUE = 0x%x )` + dropColumnEncryptionKey = `DROP COLUMN ENCRYPTION KEY [%s]` + createEncryptedTable = `CREATE TABLE %s + (col1 int + ENCRYPTED WITH (ENCRYPTION_TYPE = DETERMINISTIC, + ALGORITHM = 'AEAD_AES_256_CBC_HMAC_SHA_256', + COLUMN_ENCRYPTION_KEY = [%s]), + col2 nchar(10) COLLATE Latin1_General_BIN2 + ENCRYPTED WITH (ENCRYPTION_TYPE = DETERMINISTIC, + ALGORITHM = 'AEAD_AES_256_CBC_HMAC_SHA_256', + COLUMN_ENCRYPTION_KEY = [%s]) + )` +) diff --git a/alwaysencrypted_windows_test.go b/alwaysencrypted_windows_test.go index 20f95c73..c3a6e926 100644 --- a/alwaysencrypted_windows_test.go +++ b/alwaysencrypted_windows_test.go @@ -4,217 +4,42 @@ package mssql import ( - "crypto/rand" - "database/sql" "fmt" - "math/big" - "strings" "testing" - "time" - "github.com/golang-sql/civil" + "github.com/microsoft/go-mssqldb/aecmk" "github.com/microsoft/go-mssqldb/aecmk/localcert" "github.com/microsoft/go-mssqldb/internal/certs" + "github.com/stretchr/testify/assert" ) -// Define phrases for create table for each enryptable data type along with sample data for insertion and validation -type aeColumnInfo struct { - queryPhrase string - sqlDataType string - encType ColumnEncryptionType - sampleValue interface{} +type certStoreProviderTest struct { + thumbprint string } -var encryptableColumns = []aeColumnInfo{ - {"int", "INT", ColumnEncryptionDeterministic, int32(1)}, - {"nchar(10) COLLATE Latin1_General_BIN2", "NCHAR", ColumnEncryptionDeterministic, NChar("ncharval")}, - {"tinyint", "TINYINT", ColumnEncryptionRandomized, byte(2)}, - {"smallint", "SMALLINT", ColumnEncryptionDeterministic, int16(-3)}, - {"bigint", "BIGINT", ColumnEncryptionRandomized, int64(4)}, - // We can't use fractional float/real values due to rounding errors in the round trip - {"real", "REAL", ColumnEncryptionDeterministic, float32(5)}, - {"float", "FLOAT", ColumnEncryptionRandomized, float64(6)}, - {"varbinary(10)", "VARBINARY", ColumnEncryptionDeterministic, []byte{1, 2, 3, 4}}, - // TODO: Varchar support requires proper selection of a collation and conversion - // {"varchar(10) COLLATE Latin1_General_BIN2", "VARCHAR", ColumnEncryptionRandomized, VarChar("varcharval")}, - {"nvarchar(30)", "NVARCHAR", ColumnEncryptionRandomized, "nvarcharval"}, - {"bit", "BIT", ColumnEncryptionDeterministic, true}, - {"datetimeoffset(7)", "DATETIMEOFFSET", ColumnEncryptionRandomized, time.Now()}, - {"datetime2(7)", "DATETIME2", ColumnEncryptionDeterministic, civil.DateTimeOf(time.Now())}, - {"nvarchar(max)", "NVARCHAR", ColumnEncryptionRandomized, NVarCharMax("nvarcharmaxval")}, - // TODO: The driver throws away type information about Valuer implementations and sends nil as nvarchar(1). Fix that. - // {"int", "INT", ColumnEncryptionDeterministic, sql.NullInt32{Valid: false}}, -} - -func TestAlwaysEncryptedE2E(t *testing.T) { - params := testConnParams(t) - if !params.ColumnEncryption { - t.Skip("Test is not running with column encryption enabled") - } - conn, _ := open(t) - defer conn.Close() +func (p *certStoreProviderTest) ProvisionMasterKey(t *testing.T) string { + t.Helper() thumbprint, err := certs.ProvisionMasterKeyInCertStore() - if err != nil { - t.Fatal(err) - } - defer certs.DeleteMasterKeyCert(thumbprint) + assert.NoError(t, err, "Create cert in cert store") certPath := fmt.Sprintf(`CurrentUser/My/%s`, thumbprint) - s := fmt.Sprintf(createColumnMasterKey, certPath, certPath) - if _, err := conn.Exec(s); err != nil { - t.Fatalf("Unable to create CMK: %s", err.Error()) - } - defer conn.Exec(fmt.Sprintf(dropColumnMasterKey, certPath)) - r, _ := rand.Int(rand.Reader, big.NewInt(1000)) - cekName := fmt.Sprintf("mssqlCek%d", r.Int64()) - tableName := fmt.Sprintf("mssqlAe%d", r.Int64()) - keyBytes := make([]byte, 32) - _, _ = rand.Read(keyBytes) - encryptedCek := localcert.WindowsCertificateStoreKeyProvider.EncryptColumnEncryptionKey(certPath, KeyEncryptionAlgorithm, keyBytes) - createCek := fmt.Sprintf(createColumnEncryptionKey, cekName, certPath, encryptedCek) - _, err = conn.Exec(createCek) - if err != nil { - t.Fatalf("Unable to create CEK: %s", err.Error()) - } - defer conn.Exec(fmt.Sprintf(dropColumnEncryptionKey, cekName)) - _, _ = conn.Exec("DROP TABLE IF EXISTS " + tableName) - query := new(strings.Builder) - insert := new(strings.Builder) - sel := new(strings.Builder) - _, _ = query.WriteString(fmt.Sprintf("CREATE TABLE [%s] (", tableName)) - _, _ = insert.WriteString(fmt.Sprintf("INSERT INTO [%s] VALUES (", tableName)) - _, _ = sel.WriteString("select top(1) ") - insertArgs := make([]interface{}, len(encryptableColumns)+1) - for i, ec := range encryptableColumns { - encType := "RANDOMIZED" - null := "" - _, ok := ec.sampleValue.(sql.NullInt32) - if ok { - null = "NULL" - } - if ec.encType == ColumnEncryptionDeterministic { - encType = "DETERMINISTIC" - } - _, _ = query.WriteString(fmt.Sprintf(`col%d %s ENCRYPTED WITH (ENCRYPTION_TYPE = %s, - ALGORITHM = 'AEAD_AES_256_CBC_HMAC_SHA_256', - COLUMN_ENCRYPTION_KEY = [%s]) %s, - `, i, ec.queryPhrase, encType, cekName, null)) - - insertArgs[i] = ec.sampleValue - insert.WriteString(fmt.Sprintf("@p%d,", i+1)) - sel.WriteString(fmt.Sprintf("col%d,", i)) - } - _, _ = query.WriteString("unencryptedcolumn nvarchar(100)") - _, _ = query.WriteString(")") - insertArgs[len(encryptableColumns)] = "unencryptedvalue" - insert.WriteString(fmt.Sprintf("@p%d)", len(encryptableColumns)+1)) - sel.WriteString(fmt.Sprintf("unencryptedcolumn from [%s]", tableName)) - _, err = conn.Exec(query.String()) - if err != nil { - t.Fatalf("Failed to create encrypted table %s", err.Error()) - } - defer conn.Exec("DROP TABLE IF EXISTS " + tableName) - _, err = conn.Exec(insert.String(), insertArgs...) - if err != nil { - t.Fatalf("Failed to insert row in encrypted table %s", err.Error()) - } - rows, err := conn.Query(sel.String()) - if err != nil { - t.Fatalf("Unable to query encrypted columns: %v", err.(Error).All) - } - if !rows.Next() { - rows.Close() - t.Fatalf("rows.Next returned false") - } - cols, err := rows.ColumnTypes() - if err != nil { - t.Fatalf("rows.ColumnTypes failed %s", err.Error()) - } - for i := range encryptableColumns { - - if cols[i].DatabaseTypeName() != encryptableColumns[i].sqlDataType { - t.Fatalf("Got wrong type name for col%d. Expected: %s, Got:%s", i, encryptableColumns[i].sqlDataType, cols[i].DatabaseTypeName()) - } - } - - var unencryptedColumnValue string - scanValues := make([]interface{}, len(encryptableColumns)+1) - for v := range scanValues { - if v < len(encryptableColumns) { - scanValues[v] = new(interface{}) - } - } - scanValues[len(encryptableColumns)] = &unencryptedColumnValue - err = rows.Scan(scanValues...) - if err != nil { - rows.Close() - t.Fatalf("rows.Scan failed: %s", err.Error()) - } + p.thumbprint = thumbprint + return certPath +} - for i := range encryptableColumns { - var strVal string - var expectedStrVal string - if encryptableColumns[i].sampleValue == nil { - expectedStrVal = "NULL" - } else { - expectedStrVal = comparisonValueFromObject(encryptableColumns[i].sampleValue) - } - rawVal := scanValues[i].(*interface{}) +func (p *certStoreProviderTest) DeleteMasterKey(t *testing.T) { + t.Helper() + certs.DeleteMasterKeyCert(p.thumbprint) +} - if rawVal == nil { - strVal = "NULL" - } else { - strVal = comparisonValueFromObject(*rawVal) - } - if expectedStrVal != strVal { - t.Fatalf("Incorrect value for col%d. Expected:%s, Got:%s", i, expectedStrVal, strVal) - } - } - if unencryptedColumnValue != "unencryptedvalue" { - t.Fatalf("Got wrong value for unencrypted column: %s", unencryptedColumnValue) - } - rows.Close() - err = rows.Err() - if err != nil { - t.Fatalf("rows.Err() has non-nil value: %s", err.Error()) - } +func (p *certStoreProviderTest) GetProvider(t *testing.T) aecmk.ColumnEncryptionKeyProvider { + t.Helper() + return &localcert.WindowsCertificateStoreKeyProvider } -func comparisonValueFromObject(object interface{}) string { - switch v := object.(type) { - case []byte: - { - return string(v) - } - case string: - return v - case time.Time: - return civil.DateTimeOf(v).String() - //return v.Format(time.RFC3339) - case fmt.Stringer: - return v.String() - case bool: - if v == true { - return "1" - } - return "0" - default: - return fmt.Sprintf("%v", v) - } +func (p *certStoreProviderTest) Name() string { + return aecmk.CertificateStoreKeyProvider } -const ( - createColumnMasterKey = `CREATE COLUMN MASTER KEY [%s] WITH (KEY_STORE_PROVIDER_NAME= 'MSSQL_CERTIFICATE_STORE', KEY_PATH='%s')` - dropColumnMasterKey = `DROP COLUMN MASTER KEY [%s]` - createColumnEncryptionKey = `CREATE COLUMN ENCRYPTION KEY [%s] WITH VALUES (COLUMN_MASTER_KEY = [%s], ALGORITHM = 'RSA_OAEP', ENCRYPTED_VALUE = 0x%x )` - dropColumnEncryptionKey = `DROP COLUMN ENCRYPTION KEY [%s]` - createEncryptedTable = `CREATE TABLE %s - (col1 int - ENCRYPTED WITH (ENCRYPTION_TYPE = DETERMINISTIC, - ALGORITHM = 'AEAD_AES_256_CBC_HMAC_SHA_256', - COLUMN_ENCRYPTION_KEY = [%s]), - col2 nchar(10) COLLATE Latin1_General_BIN2 - ENCRYPTED WITH (ENCRYPTION_TYPE = DETERMINISTIC, - ALGORITHM = 'AEAD_AES_256_CBC_HMAC_SHA_256', - COLUMN_ENCRYPTION_KEY = [%s]) - )` -) +func init() { + addProviderTest(&certStoreProviderTest{}) +} diff --git a/appveyor.yml b/appveyor.yml index b3bcc5c2..ba39e314 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -11,15 +11,13 @@ environment: SQLUSER: sa SQLPASSWORD: Password12! DATABASE: test - GOVERSION: 116 + GOVERSION: 117 COLUMNENCRYPTION: APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2019 RACE: -race -cpu 4 TAGS: matrix: - SQLINSTANCE: SQL2017 - - GOVERSION: 117 - SQLINSTANCE: SQL2017 - GOVERSION: 118 SQLINSTANCE: SQL2017 - GOVERSION: 120 @@ -46,11 +44,6 @@ install: - set PATH=%GOPATH%\bin;%GOROOT%\bin;%PATH% - go version - go env - - go get -u github.com/golang-sql/civil - - go get -u github.com/golang-sql/sqlexp - - go get -u golang.org/x/crypto/md4 - - go get github.com/stretchr/testify/assert@v1.8.1 - - go get -u golang.org/x/text/encoding/unicode build_script: - go build diff --git a/go.mod b/go.mod index 4c3ea17a..84044794 100644 --- a/go.mod +++ b/go.mod @@ -1,15 +1,36 @@ module github.com/microsoft/go-mssqldb -go 1.16 +go 1.17 require ( - github.com/Azure/azure-sdk-for-go/sdk/azcore v1.6.0 + github.com/Azure/azure-sdk-for-go/sdk/azcore v1.7.1 github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.3.0 + github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.0.0 github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 github.com/golang-sql/sqlexp v0.1.0 github.com/jcmturner/gokrb5/v8 v8.4.4 - github.com/stretchr/testify v1.8.1 - golang.org/x/crypto v0.9.0 - golang.org/x/sys v0.8.0 - golang.org/x/text v0.9.0 + github.com/stretchr/testify v1.8.4 + golang.org/x/crypto v0.12.0 + golang.org/x/sys v0.11.0 + golang.org/x/text v0.12.0 +) + +require ( + github.com/Azure/azure-sdk-for-go/sdk/internal v1.3.0 // indirect + github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v0.8.0 // indirect + github.com/AzureAD/microsoft-authentication-library-for-go v1.1.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/golang-jwt/jwt/v5 v5.0.0 // indirect + github.com/google/uuid v1.3.0 // indirect + github.com/hashicorp/go-uuid v1.0.3 // indirect + github.com/jcmturner/aescts/v2 v2.0.0 // indirect + github.com/jcmturner/dnsutils/v2 v2.0.0 // indirect + github.com/jcmturner/gofork v1.7.6 // indirect + github.com/jcmturner/goidentity/v6 v6.0.1 // indirect + github.com/jcmturner/rpc/v2 v2.0.3 // indirect + github.com/kylelemons/godebug v1.1.0 // indirect + github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + golang.org/x/net v0.14.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 69dce41d..05c59b71 100644 --- a/go.sum +++ b/go.sum @@ -1,11 +1,21 @@ -github.com/Azure/azure-sdk-for-go/sdk/azcore v1.6.0 h1:8kDqDngH+DmVBiCtIjCFTGa7MBnsIOkF9IccInFEbjk= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.4.0/go.mod h1:ON4tFdPTwRcgWEaVDrN3584Ef+b7GgSJaXxe5fW9t4M= github.com/Azure/azure-sdk-for-go/sdk/azcore v1.6.0/go.mod h1:bjGvMhVMb+EEm3VRNQawDMUyMMjo+S5ewNjflkep/0Q= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.6.1/go.mod h1:bjGvMhVMb+EEm3VRNQawDMUyMMjo+S5ewNjflkep/0Q= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.7.1 h1:/iHxaJhsFr0+xVFfbMr5vxz848jyiWuIEDhYq3y5odY= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.7.1/go.mod h1:bjGvMhVMb+EEm3VRNQawDMUyMMjo+S5ewNjflkep/0Q= github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.3.0 h1:vcYCAze6p19qBW7MhZybIsqD8sMV8js0NyQM8JDnVtg= github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.3.0/go.mod h1:OQeznEEkTZ9OrhHJoDD8ZDq51FHgXjqtP9z6bEwBq9U= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.1.2/go.mod h1:eWRD7oawr1Mu1sLCawqVc0CUiF43ia3qQMxLscsKQ9w= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.2.0/go.mod h1:eWRD7oawr1Mu1sLCawqVc0CUiF43ia3qQMxLscsKQ9w= github.com/Azure/azure-sdk-for-go/sdk/internal v1.3.0 h1:sXr+ck84g/ZlZUOZiNELInmMgOsuGwdjjVkEIde0OtY= github.com/Azure/azure-sdk-for-go/sdk/internal v1.3.0/go.mod h1:okt5dMMTOFjX/aovMlrjvvXoPMBVSPzk9185BT0+eZM= -github.com/AzureAD/microsoft-authentication-library-for-go v1.0.0 h1:OBhqkivkhkMqLPymWEppkm7vgPQY2XsHoEkaMQ0AdZY= +github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.0.0 h1:yfJe15aSwEQ6Oo6J+gdfdulPNoZ3TEhmbhLIoxZcA+U= +github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.0.0/go.mod h1:Q28U+75mpCaSCDowNEmhIo/rmgdkqmkmzI7N6TGR4UY= +github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v0.8.0 h1:T028gtTPiYt/RMUfs8nVsAL7FDQrfLlrm/NnRG/zcC4= +github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v0.8.0/go.mod h1:cw4zVQgBby0Z5f2v0itn6se2dDP17nTjbZFXW5uPyHA= github.com/AzureAD/microsoft-authentication-library-for-go v1.0.0/go.mod h1:kgDmCTgBzIEPFElEF+FK0SdjAor06dRq2Go927dnQ6o= +github.com/AzureAD/microsoft-authentication-library-for-go v1.1.0 h1:HCc0+LpPfpCKs6LGGLAhwBARt9632unrVcI6i8s/8os= +github.com/AzureAD/microsoft-authentication-library-for-go v1.1.0/go.mod h1:wP83P5OoQ5p6ip3ScPr0BAq0BvuPAvacpEuSzyouqAI= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -15,6 +25,8 @@ github.com/dnaeon/go-vcr v1.2.0/go.mod h1:R4UdLID7HZT3taECzJs4YgbbH6PIGXB6W/sc5O github.com/golang-jwt/jwt/v4 v4.4.3/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg= github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= +github.com/golang-jwt/jwt/v5 v5.0.0 h1:1n1XNM9hk7O9mnQoNBGolZvzebBQ7p93ULHRc28XJUE= +github.com/golang-jwt/jwt/v5 v5.0.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 h1:au07oEsX2xN0ktxqI+Sida1w446QrXBRJ0nee3SNZlA= github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei6A= @@ -55,30 +67,39 @@ github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81P github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= -github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58= golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU= -golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g= golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0= +golang.org/x/crypto v0.12.0 h1:tFM/ta59kqch6LlvYnPa0yx5a83cL2nHflFhYKvv9Yk= +golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20201010224723-4f7140c49acb/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc= -golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= +golang.org/x/net v0.14.0 h1:BONx9s002vGdD9umnlX1Po8vOZmrgH34qlHcD1MfK14= +golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210616045830-e2b7044e8c71/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -86,20 +107,23 @@ golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.11.0 h1:eG7RXZHdqOJ1i+0lgLgCpSXAp6M3LYlAo6osgSi0xOM= +golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= +golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= -golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/text v0.12.0 h1:k+n5B8goJNdU7hSvEtMUz3d1Q6D/XW4COJSJR6fN0mc= +golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= diff --git a/internal/akvkeys/utils.go b/internal/akvkeys/utils.go new file mode 100644 index 00000000..bd8e2c9a --- /dev/null +++ b/internal/akvkeys/utils.go @@ -0,0 +1,53 @@ +//go:build go1.18 +// +build go1.18 + +package akvkeys + +import ( + "context" + "crypto/rand" + "fmt" + "math/big" + "net/url" + "os" + + "github.com/Azure/azure-sdk-for-go/sdk/azidentity" + "github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys" +) + +func GetTestAKV() (client *azkeys.Client, u string, err error) { + vaultName := os.Getenv("KEY_VAULT_NAME") + if len(vaultName) == 0 { + err = fmt.Errorf("KEY_VAULT_NAME is not set in the environment") + return + } + vaultURL := fmt.Sprintf("https://%s.vault.azure.net/", url.PathEscape(vaultName)) + cred, err := azidentity.NewDefaultAzureCredential(nil) + if err != nil { + return + } + client, err = azkeys.NewClient(vaultURL, cred, nil) + if err != nil { + return + } + u = vaultURL + "keys" + return +} + +func CreateRSAKey(client *azkeys.Client) (name string, err error) { + kt := azkeys.KeyTypeRSA + ks := int32(2048) + rsaKeyParams := azkeys.CreateKeyParameters{ + Kty: &kt, + KeySize: &ks, + } + i, _ := rand.Int(rand.Reader, big.NewInt(1000)) + name = fmt.Sprintf("go-mssqlkey%d", i) + _, err = client.CreateKey(context.TODO(), name, rsaKeyParams, nil) + return +} + +func DeleteRSAKey(client *azkeys.Client, name string) bool { + _, err := client.DeleteKey(context.TODO(), name, nil) + return err == nil +} diff --git a/internal/github.com/swisscom/mssql-always-encrypted/pkg/alwaysencrypted.go b/internal/github.com/swisscom/mssql-always-encrypted/pkg/alwaysencrypted.go index 64ca57f6..db16b16a 100644 --- a/internal/github.com/swisscom/mssql-always-encrypted/pkg/alwaysencrypted.go +++ b/internal/github.com/swisscom/mssql-always-encrypted/pkg/alwaysencrypted.go @@ -21,13 +21,17 @@ type CEKV struct { Key []byte } -func (c *CEKV) Verify(cert *x509.Certificate) bool { +func (c *CEKV) VerifySignature(key *rsa.PublicKey) bool { sha256Sum := sha256.Sum256(c.DataToSign) - err := rsa.VerifyPKCS1v15(cert.PublicKey.(*rsa.PublicKey), crypto.SHA256, sha256Sum[:], c.SignedHash) + err := rsa.VerifyPKCS1v15(key, crypto.SHA256, sha256Sum[:], c.SignedHash) return err == nil } +func (c *CEKV) Verify(cert *x509.Certificate) bool { + return c.VerifySignature(cert.PublicKey.(*rsa.PublicKey)) +} + func (c *CEKV) Decrypt(private *rsa.PrivateKey) ([]byte, error) { decryptedData, decryptErr := rsa.DecryptOAEP(sha1.New(), rand.Reader, private, c.Ciphertext, nil) if decryptErr != nil { diff --git a/tds_test.go b/tds_test.go index daabd714..dca94788 100644 --- a/tds_test.go +++ b/tds_test.go @@ -305,6 +305,9 @@ func GetConnParams() (*msdsn.Config, error) { return nil, err } params.LogFlags = logFlags + if os.Getenv("COLUMNENCRYPTION") != "" { + params.ColumnEncryption = true + } return ¶ms, nil } if len(os.Getenv("HOST")) > 0 && len(os.Getenv("DATABASE")) > 0 { diff --git a/version.go b/version.go index 58fd4619..256e9b4e 100644 --- a/version.go +++ b/version.go @@ -4,7 +4,7 @@ import "fmt" // Update this variable with the release tag before pushing the tag // This value is written to the prelogin and login7 packets during a new connection -const driverVersion = "v1.5.0" +const driverVersion = "v1.6.0" func getDriverVersion(ver string) uint32 { var majorVersion uint32