From 6ce6c4f81d7d7952b99dc21547e55231f2043991 Mon Sep 17 00:00:00 2001 From: davidshi Date: Wed, 7 Jun 2023 09:18:26 -0500 Subject: [PATCH 01/47] add core CEK parameters and types --- columnencryptionkey.go | 48 ++++++++++++++++++++++++++++++++++++++++++ msdsn/conn_str.go | 36 ++++++++++++++++++++++++------- msdsn/conn_str_test.go | 12 +++++------ 3 files changed, 82 insertions(+), 14 deletions(-) create mode 100644 columnencryptionkey.go diff --git a/columnencryptionkey.go b/columnencryptionkey.go new file mode 100644 index 00000000..532b3ab4 --- /dev/null +++ b/columnencryptionkey.go @@ -0,0 +1,48 @@ +package mssql + +// cek ==> Column Encryption Key +// Every row of an encrypted table has an associated list of keys used to decrypt its columns +type cekTable struct { + entries []cekTableEntry +} + +type encryptionKeyInfo struct { + encryptedKey []byte + databaseID int + cekID int + cekVersion int + cekMdVersion []byte + keyPath string + keyStoreName string + algorithmName string +} + +type cekTableEntry struct { + databaseID int + keyId int + keyVersion int + mdVersion []byte + valueCount int + cekValues []encryptionKeyInfo +} + +func newCekTable(size uint16) cekTable { + return cekTable{entries: make([]cekTableEntry, size)} +} + +// ColumnEncryptionKeyProvider is the interface for decrypting and encrypting column encryption keys. +// It is similar to .Net https://learn.microsoft.com/dotnet/api/microsoft.data.sqlclient.sqlcolumnencryptionkeystoreprovider. +type ColumnEncryptionKeyProvider interface { + // DecryptColumnEncryptionKey decrypts the specified encrypted value of a column encryption key. + // The encrypted value is expected to be encrypted using the column master key with the specified key path and using the specified algorithm. + DecryptColumnEncryptionKey(masterKeyPath string, encryptionAlgorithm string, encryptedCek []byte) []byte + // EncryptColumnEncryptionKey encrypts a column encryption key using the column master key with the specified key path and using the specified algorithm. + EncryptColumnEncryptionKey(masterKeyPath string, encryptionAlgorithm string, cek []byte) []byte + // SignColumnMasterKeyMetadata digitally signs the column master key metadata with the column master key + // referenced by the masterKeyPath parameter. The input values used to generate the signature should be the + // specified values of the masterKeyPath and allowEnclaveComputations parameters. May return an empty slice if not supported. + SignColumnMasterKeyMetadata(masterKeyPath string, allowEnclaveComputations bool) []byte + // VerifyColumnMasterKeyMetadata verifies the specified signature is valid for the column master key + // with the specified key path and the specified enclave behavior. Return nil if not supported. + VerifyColumnMasterKeyMetadata(masterKeyPath string, allowEnclaveComputations bool) *bool +} diff --git a/msdsn/conn_str.go b/msdsn/conn_str.go index 4f71453d..ee186e0d 100644 --- a/msdsn/conn_str.go +++ b/msdsn/conn_str.go @@ -88,6 +88,8 @@ type Config struct { ProtocolParameters map[string]interface{} // BrowserMsg is the message identifier to fetch instance data from SQL browser BrowserMessage BrowserMsg + //ColumnEncryption is true if the application needs to decrypt or encrypt Always Encrypted values + ColumnEncryption bool } // Build a tls.Config object from the supplied certificate. @@ -371,6 +373,19 @@ func Parse(dsn string) (Config, error) { return p, err } + if c, ok := params["columnencryption"]; ok { + columnEncryption, err := strconv.ParseBool(c) + if err != nil { + if strings.EqualFold(c, "Enabled") { + columnEncryption = true + } else if strings.EqualFold(c, "Disabled") { + columnEncryption = false + } else { + return p, fmt.Errorf("invalid columnencryption '%v' : %v", columnEncryption, err.Error()) + } + } + p.ColumnEncryption = columnEncryption + } return p, nil } @@ -421,6 +436,9 @@ func (p Config) URL() *url.URL { res.Path = p.Instance } q.Add("dial timeout", strconv.FormatFloat(float64(p.DialTimeout.Seconds()), 'f', 0, 64)) + if p.ColumnEncryption { + q.Add("columnencryption", "true") + } if len(q) > 0 { res.RawQuery = q.Encode() } @@ -428,15 +446,17 @@ func (p Config) URL() *url.URL { return &res } +// ADO connection string keywords at https://github.com/dotnet/SqlClient/blob/main/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Common/DbConnectionStringCommon.cs var adoSynonyms = map[string]string{ - "application name": "app name", - "data source": "server", - "address": "server", - "network address": "server", - "addr": "server", - "user": "user id", - "uid": "user id", - "initial catalog": "database", + "application name": "app name", + "data source": "server", + "address": "server", + "network address": "server", + "addr": "server", + "user": "user id", + "uid": "user id", + "initial catalog": "database", + "column encryption setting": "columnencryption", } func splitConnectionString(dsn string) (res map[string]string) { diff --git a/msdsn/conn_str_test.go b/msdsn/conn_str_test.go index 5fa1a0ed..1e001385 100644 --- a/msdsn/conn_str_test.go +++ b/msdsn/conn_str_test.go @@ -54,7 +54,7 @@ func TestValidConnectionString(t *testing.T) { {"server=server\\instance;database=testdb;user id=tester;password=pwd", func(p Config) bool { return p.Host == "server" && p.Instance == "instance" && p.User == "tester" && p.Password == "pwd" }}, - {"server=.", func(p Config) bool { return p.Host == "localhost" }}, + {"server=.", func(p Config) bool { return p.Host == "localhost" && !p.ColumnEncryption }}, {"server=(local)", func(p Config) bool { return p.Host == "localhost" }}, {"ServerSPN=serverspn;Workstation ID=workstid", func(p Config) bool { return p.ServerSPN == "serverspn" && p.Workstation == "workstid" }}, {"failoverpartner=fopartner;failoverport=2000", func(p Config) bool { return p.FailOverPartner == "fopartner" && p.FailOverPort == 2000 }}, @@ -68,8 +68,8 @@ func TestValidConnectionString(t *testing.T) { {"encrypt=false;tlsmin=1.0", func(p Config) bool { return p.Encryption == EncryptionOff && p.TLSConfig.MinVersion == tls.VersionTLS10 }}, - {"encrypt=true;tlsmin=1.1", func(p Config) bool { - return p.Encryption == EncryptionRequired && p.TLSConfig.MinVersion == tls.VersionTLS11 + {"encrypt=true;tlsmin=1.1;column encryption setting=enabled", func(p Config) bool { + return p.Encryption == EncryptionRequired && p.TLSConfig.MinVersion == tls.VersionTLS11 && p.ColumnEncryption }}, {"encrypt=true;tlsmin=1.2", func(p Config) bool { return p.Encryption == EncryptionRequired && p.TLSConfig.MinVersion == tls.VersionTLS12 @@ -174,10 +174,10 @@ func TestValidConnectionString(t *testing.T) { return p.Host == "somehost" && p.Port == 0 && p.Instance == "" && p.User == "someuser" && p.Password == "" && p.DisableRetry }}, {"sqlserver://someuser@somehost?connection+timeout=30&disableretry=1", func(p Config) bool { - return p.Host == "somehost" && p.Port == 0 && p.Instance == "" && p.User == "someuser" && p.Password == "" && p.ConnTimeout == 30*time.Second && p.DisableRetry + return p.Host == "somehost" && p.Port == 0 && p.Instance == "" && p.User == "someuser" && p.Password == "" && p.ConnTimeout == 30*time.Second && p.DisableRetry && !p.ColumnEncryption }}, - {"sqlserver://somehost?encrypt=true&tlsmin=1.1", func(p Config) bool { - return p.Host == "somehost" && p.Encryption == EncryptionRequired && p.TLSConfig.MinVersion == tls.VersionTLS11 + {"sqlserver://somehost?encrypt=true&tlsmin=1.1&columnencryption=1", func(p Config) bool { + return p.Host == "somehost" && p.Encryption == EncryptionRequired && p.TLSConfig.MinVersion == tls.VersionTLS11 && p.ColumnEncryption }}, } for _, ts := range connStrings { From 1ba7c9abd9f94b995828b2a6f063c38717b3320b Mon Sep 17 00:00:00 2001 From: davidshi Date: Wed, 7 Jun 2023 10:17:03 -0500 Subject: [PATCH 02/47] add column encryption featureext --- tds.go | 46 ++++++++++++++++++++++++++++++++++------------ tds_test.go | 6 ++++-- token.go | 3 ++- 3 files changed, 40 insertions(+), 15 deletions(-) diff --git a/tds.go b/tds.go index d10f9c6b..d82a3049 100644 --- a/tds.go +++ b/tds.go @@ -157,16 +157,17 @@ const ( ) type tdsSession struct { - buf *tdsBuffer - loginAck loginAckStruct - database string - partner string - columns []columnStruct - tranid uint64 - logFlags uint64 - logger ContextLogger - routedServer string - routedPort uint16 + buf *tdsBuffer + loginAck loginAckStruct + database string + partner string + columns []columnStruct + tranid uint64 + logFlags uint64 + logger ContextLogger + routedServer string + routedPort uint16 + alwaysEncrypted bool } const ( @@ -1047,6 +1048,9 @@ func prepareLogin(ctx context.Context, c *Connector, p msdsn.Config, logger Cont CtlIntName: "go-mssqldb", ClientProgVer: getDriverVersion(driverVersion), } + if p.ColumnEncryption { + _ = l.FeatureExt.Add(&featureExtColumnEncryption{}) + } switch { case fe.FedAuthLibrary == FedAuthLibrarySecurityToken: if uint64(p.LogFlags)&logDebug != 0 { @@ -1061,14 +1065,14 @@ func prepareLogin(ctx context.Context, c *Connector, p msdsn.Config, logger Cont return nil, err } - l.FeatureExt.Add(fe) + _ = l.FeatureExt.Add(fe) case fe.FedAuthLibrary == FedAuthLibraryADAL: if uint64(p.LogFlags)&logDebug != 0 { logger.Log(ctx, msdsn.LogDebug, "Starting federated authentication using ADAL") } - l.FeatureExt.Add(fe) + _ = l.FeatureExt.Add(fe) case auth != nil: if uint64(p.LogFlags)&logDebug != 0 { @@ -1317,3 +1321,21 @@ initiate_connection: } return &sess, nil } + +type featureExtColumnEncryption struct { +} + +func (f *featureExtColumnEncryption) featureID() byte { + return featExtCOLUMNENCRYPTION +} + +func (f *featureExtColumnEncryption) toBytes() []byte { + /* + 1 = The client supports column encryption without enclave computations. + 2 = The client SHOULD<25> support column encryption when encrypted data require enclave computations. + 3 = The client SHOULD<26> support column encryption when encrypted data require enclave computations + with the additional ability to cache column encryption keys that are to be sent to the enclave + and the ability to retry queries when the keys sent by the client do not match what is needed for the query to run. + */ + return []byte{0x02} +} diff --git a/tds_test.go b/tds_test.go index 94d1be65..837d51c2 100644 --- a/tds_test.go +++ b/tds_test.go @@ -126,12 +126,13 @@ func TestSendLoginWithFeatureExt(t *testing.T) { FedAuthLibrary: FedAuthLibrarySecurityToken, FedAuthToken: "fedauthtoken", }) + login.FeatureExt.Add(&featureExtColumnEncryption{}) err := sendLogin(buf, &login) if err != nil { t.Error("sendLogin should succeed") } ref := []byte{ - 16, 1, 0, 223, 0, 0, 1, 0, 215, 0, 0, 0, 4, 0, 0, 116, + 16, 1, 0, 0xe5, 0, 0, 1, 0, 0xdd, 0, 0, 0, 4, 0, 0, 116, 0, 16, 0, 0, 0, 1, 6, 1, 100, 0, 0, 0, 0, 0, 0, 0, 224, 0, 0, 24, 16, 255, 255, 255, 4, 2, 0, 0, 94, 0, 7, 0, 108, 0, 0, 0, 108, 0, 0, 0, 108, 0, 7, 0, 122, 0, 10, 0, @@ -144,7 +145,8 @@ func TestSendLoginWithFeatureExt(t *testing.T) { 114, 0, 121, 0, 101, 0, 110, 0, 100, 0, 97, 0, 116, 0, 97, 0, 98, 0, 97, 0, 115, 0, 101, 0, 180, 0, 0, 0, 2, 29, 0, 0, 0, 2, 24, 0, 0, 0, 102, 0, 101, 0, 100, 0, 97, 0, 117, 0, - 116, 0, 104, 0, 116, 0, 111, 0, 107, 0, 101, 0, 110, 0, 255} + 116, 0, 104, 0, 116, 0, 111, 0, 107, 0, 101, 0, 110, 0, 4, 1, + 0, 0, 0, 2, 255} out := memBuf.Bytes() if !bytes.Equal(ref, out) { t.Log("Expected:") diff --git a/token.go b/token.go index 76d4e025..ccfb515a 100644 --- a/token.go +++ b/token.go @@ -95,7 +95,8 @@ const ( // COLMETADATA flags // https://msdn.microsoft.com/en-us/library/dd357363.aspx const ( - colFlagNullable = 1 + colFlagNullable = 1 + colFlagEncrypted = 0x0800 // TODO implement more flags ) From 67d93fa6e52555fd2b3d286c33a11aacf7ec7e24 Mon Sep 17 00:00:00 2001 From: davidshi Date: Fri, 9 Jun 2023 16:59:31 -0500 Subject: [PATCH 03/47] Add parsing of always encrypted tokens --- columnencryptionkey.go | 39 ++++++ go.mod | 1 + tds.go | 35 ++++- tds_test.go | 74 ++++++++-- token.go | 299 +++++++++++++++++++++++++++++++++++++---- types.go | 40 ++++-- 6 files changed, 438 insertions(+), 50 deletions(-) diff --git a/columnencryptionkey.go b/columnencryptionkey.go index 532b3ab4..5220a8d3 100644 --- a/columnencryptionkey.go +++ b/columnencryptionkey.go @@ -1,5 +1,10 @@ package mssql +import ( + "fmt" + "time" +) + // cek ==> Column Encryption Key // Every row of an encrypted table has an associated list of keys used to decrypt its columns type cekTable struct { @@ -30,6 +35,27 @@ func newCekTable(size uint16) cekTable { return cekTable{entries: make([]cekTableEntry, size)} } +// ColumnEncryptionKeyLifetime is the default lifetime of decrypted Column Encryption Keys in the global cache. +// The default is 2 hours +var ColumnEncryptionKeyLifetime time.Duration = 2 * time.Hour + +type cekCacheEntry struct { + expiry time.Time + key []byte +} + +type cekCache map[string]cekCacheEntry + +type cekProvider struct { + provider ColumnEncryptionKeyProvider + decryptedKeys cekCache +} + +// no synchronization on this map. Providers register during init. +type columnEncryptionKeyProviderMap map[string]cekProvider + +var globalCekProviderFactoryMap = columnEncryptionKeyProviderMap{} + // ColumnEncryptionKeyProvider is the interface for decrypting and encrypting column encryption keys. // It is similar to .Net https://learn.microsoft.com/dotnet/api/microsoft.data.sqlclient.sqlcolumnencryptionkeystoreprovider. type ColumnEncryptionKeyProvider interface { @@ -45,4 +71,17 @@ type ColumnEncryptionKeyProvider interface { // VerifyColumnMasterKeyMetadata verifies the specified signature is valid for the column master key // with the specified key path and the specified enclave behavior. Return nil if not supported. VerifyColumnMasterKeyMetadata(masterKeyPath string, allowEnclaveComputations bool) *bool + // KeyLifetime is an optional Duration. Keys fetched by this provider will be discarded after their lifetime expires. + // If it returns nil, the keys will expire based on the value of ColumnEncryptionKeyLifetime. + // If it returns zero, the keys will not be cached. + KeyLifetime() *time.Duration +} + +func RegisterCekProvider(name string, provider ColumnEncryptionKeyProvider) error { + _, ok := globalCekProviderFactoryMap[name] + if ok { + return fmt.Errorf("CEK provider %s is already registered", name) + } + globalCekProviderFactoryMap[name] = cekProvider{provider: provider} + return nil } diff --git a/go.mod b/go.mod index 89969302..49183736 100644 --- a/go.mod +++ b/go.mod @@ -13,5 +13,6 @@ require ( github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 // indirect golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d golang.org/x/sys v0.0.0-20220224120231-95c6836cb0e7 // indirect + golang.org/x/text v0.3.7 // indirect gopkg.in/natefinch/npipe.v2 v2.0.0-20160621034901-c1b8fa8bdcce ) diff --git a/tds.go b/tds.go index d82a3049..1a22b241 100644 --- a/tds.go +++ b/tds.go @@ -168,6 +168,11 @@ type tdsSession struct { routedServer string routedPort uint16 alwaysEncrypted bool + aeSettings *alwaysEncryptedSettings +} + +type alwaysEncryptedSettings struct { + enclaveType string } const ( @@ -179,10 +184,19 @@ const ( ) type columnStruct struct { - UserType uint32 - Flags uint16 - ColName string - ti typeInfo + UserType uint32 + Flags uint16 + ColName string + ti typeInfo + cryptoMeta *cryptoMetadata +} + +func (c columnStruct) isEncrypted() bool { + return isEncryptedFlag(c.Flags) +} + +func isEncryptedFlag(flags uint16) bool { + return colFlagEncrypted == (flags & colFlagEncrypted) } type keySlice []uint8 @@ -1292,6 +1306,19 @@ initiate_connection: case loginAckStruct: sess.loginAck = token loginAck = true + case featureExtAck: + for _, v := range token { + switch v := v.(type) { + case colAckStruct: + if v.Version <= 2 && v.Version > 0 { + sess.alwaysEncrypted = true + sess.aeSettings = &alwaysEncryptedSettings{} + if len(v.EnclaveType) > 0 { + sess.aeSettings.enclaveType = string(v.EnclaveType) + } + } + } + } case doneStruct: if token.isError() { tokenErr := token.getError() diff --git a/tds_test.go b/tds_test.go index 837d51c2..6b1c6481 100644 --- a/tds_test.go +++ b/tds_test.go @@ -204,6 +204,59 @@ func TestSendSqlBatch(t *testing.T) { } } +func TestLoginWithColumnEncryption(t *testing.T) { + checkConnStr(t) + p, err := msdsn.Parse(makeConnStr(t).String()) + if err != nil { + t.Error("parseConnectParams failed:", err.Error()) + return + } + p.ColumnEncryption = true + tl := testLogger{t: t} + defer tl.StopLogging() + conn, err := connect(context.Background(), &Connector{params: p}, optionalLogger{loggerAdapter{&tl}}, p) + if err != nil { + t.Error("Open connection failed:", err.Error()) + return + } + defer conn.buf.transport.Close() + + headers := []headerStruct{ + {hdrtype: dataStmHdrTransDescr, + data: transDescrHdr{0, 1}.pack()}, + } + err = sendSqlBatch72(conn.buf, "select (@@microsoftversion / 0x1000000) & 0xff AS [VersionMajor]", headers, true) + if err != nil { + t.Error("Sending sql batch failed", err.Error()) + return + } + + reader := startReading(conn, context.Background(), outputs{}) + + err = reader.iterateResponse() + if err != nil { + t.Fatal(err) + } + + if len(reader.lastRow) == 0 { + t.Fatal("expected row but no row set") + } + + switch value := reader.lastRow[0].(type) { + case int64: + if value > 12 { + if !conn.alwaysEncrypted { + t.Fatalf("SQL Version %d should have alwaysEncrypted == true", value) + } + } else if conn.alwaysEncrypted { + t.Fatalf("SQL Version %d should have alwaysEncrypted == false", value) + } + + default: + t.Fatalf("Expected int64 return but got %v", value) + } +} + // returns parsed connection parameters derived from // environment variables func testConnParams(t testing.TB) msdsn.Config { @@ -914,19 +967,22 @@ func BenchmarkPacketSize(b *testing.B) { b.Run(bm.name, func(b *testing.B) { for i := 0; i < b.N; i++ { p.PacketSize = bm.packetSize - runBatch(b, p) + runBatch(b, "", p) } }) } } -func runBatch(t testing.TB, p msdsn.Config) { +func runBatch(t testing.TB, batch string, p msdsn.Config) int32 { + if len(batch) == 0 { + batch = "select 1" + } tl := testLogger{t: t} defer tl.StopLogging() conn, err := connect(context.Background(), &Connector{params: p}, optionalLogger{loggerAdapter{&tl}}, p) if err != nil { t.Error("Open connection failed:", err.Error()) - return + return 0 } defer conn.buf.transport.Close() @@ -934,10 +990,10 @@ func runBatch(t testing.TB, p msdsn.Config) { {hdrtype: dataStmHdrTransDescr, data: transDescrHdr{0, 1}.pack()}, } - err = sendSqlBatch72(conn.buf, "select 1", headers, true) + err = sendSqlBatch72(conn.buf, batch, headers, true) if err != nil { t.Error("Sending sql batch failed", err.Error()) - return + return 0 } reader := startReading(conn, context.Background(), outputs{}) @@ -953,11 +1009,11 @@ func runBatch(t testing.TB, p msdsn.Config) { switch value := reader.lastRow[0].(type) { case int32: - if value != 1 { - t.Error("Invalid value returned, should be 1", value) - return - } + return value + default: + t.Fatalf("expected an int32 return but got %v", value) } + return 0 } func TestGetDriverVersion(t *testing.T) { diff --git a/token.go b/token.go index ccfb515a..cc7f68ea 100644 --- a/token.go +++ b/token.go @@ -1,6 +1,7 @@ package mssql import ( + "bytes" "context" "encoding/binary" "fmt" @@ -11,6 +12,7 @@ import ( "github.com/golang-sql/sqlexp" "github.com/microsoft/go-mssqldb/msdsn" + "golang.org/x/text/encoding/unicode" ) //go:generate go run golang.org/x/tools/cmd/stringer -type token @@ -92,6 +94,10 @@ const ( fedAuthInfoSPN = 0x02 ) +const ( + cipherAlgCustom = 0x00 +) + // COLMETADATA flags // https://msdn.microsoft.com/en-us/library/dd357363.aspx const ( @@ -534,7 +540,14 @@ type fedAuthAckStruct struct { Signature []byte } -func parseFeatureExtAck(r *tdsBuffer) map[byte]interface{} { +type colAckStruct struct { + Version int + EnclaveType string +} + +type featureExtAck map[byte]interface{} + +func parseFeatureExtAck(r *tdsBuffer) featureExtAck { ack := map[byte]interface{}{} for feature := r.byte(); feature != featExtTERMINATOR; feature = r.byte() { @@ -556,7 +569,21 @@ func parseFeatureExtAck(r *tdsBuffer) map[byte]interface{} { length -= 32 } ack[feature] = fedAuthAck - + case featExtCOLUMNENCRYPTION: + colAck := colAckStruct{Version: int(r.byte())} + length-- + if length > 0 { + // enclave type is sent as utf16 le + enclaveLength := r.byte() * 2 + length-- + enclaveBytes := make([]byte, enclaveLength) + r.ReadFull(enclaveBytes) + // if the enclave type is malformed we'll just ignore it + colAck.EnclaveType, _ = ucs22str(enclaveBytes) + length -= uint32(enclaveLength) + + } + ack[feature] = colAck } // Skip unprocessed bytes @@ -569,34 +596,244 @@ func parseFeatureExtAck(r *tdsBuffer) map[byte]interface{} { } // http://msdn.microsoft.com/en-us/library/dd357363.aspx -func parseColMetadata72(r *tdsBuffer) (columns []columnStruct) { +func parseColMetadata72(r *tdsBuffer, s *tdsSession) (columns []columnStruct) { count := r.uint16() if count == 0xffff { // no metadata is sent return nil } columns = make([]columnStruct, count) + var cekTable *cekTable + if s.alwaysEncrypted { + // column encryption key list + cekTable = readCekTable(r) + } + for i := range columns { column := &columns[i] - column.UserType = r.uint32() - column.Flags = r.uint16() + baseTi := getBaseTypeInfo(r, true) + typeInfo := readTypeInfo(r, baseTi.TypeId, column.cryptoMeta) + typeInfo.UserType = baseTi.UserType + typeInfo.Flags = baseTi.Flags + typeInfo.TypeId = baseTi.TypeId + + column.Flags = baseTi.Flags + column.UserType = baseTi.UserType + column.ti = typeInfo + + if column.isEncrypted() && s.alwaysEncrypted { + // Read Crypto Metadata + cryptoMeta := parseCryptoMetadata(r, cekTable) + cryptoMeta.typeInfo.Flags = baseTi.Flags + column.cryptoMeta = &cryptoMeta + } else { + column.cryptoMeta = nil + } - // parsing TYPE_INFO structure - column.ti = readTypeInfo(r) column.ColName = r.BVarChar() } return columns } +func getBaseTypeInfo(r *tdsBuffer, parseFlags bool) typeInfo { + userType := r.uint32() + flags := uint16(0) + if parseFlags { + flags = r.uint16() + } + tId := r.byte() + + return typeInfo{ + UserType: userType, + Flags: flags, + TypeId: tId} +} + +type cryptoMetadata struct { + entry *cekTableEntry + ordinal uint16 + algorithmId byte + algorithmName *string + encType byte + normRuleVer byte + typeInfo typeInfo +} + +func parseCryptoMetadata(r *tdsBuffer, cekTable *cekTable) cryptoMetadata { + ordinal := uint16(0) + if cekTable != nil { + ordinal = r.uint16() + } + + typeInfo := getBaseTypeInfo(r, false) + ti := readTypeInfo(r, typeInfo.TypeId, nil) + ti.UserType = typeInfo.UserType + ti.Flags = typeInfo.Flags + ti.TypeId = typeInfo.TypeId + + algorithmId := r.byte() + var algName *string = nil + + if algorithmId == cipherAlgCustom { + // Read the name when a custom algorithm is used + nameLen := int(r.byte()) + var algNameUtf16 = make([]byte, nameLen*2) + r.ReadFull(algNameUtf16) + algNameBytes, _ := unicode.UTF16(unicode.LittleEndian, unicode.IgnoreBOM).NewDecoder().Bytes(algNameUtf16) + mAlgName := string(algNameBytes) + algName = &mAlgName + } + + encType := r.byte() + normRuleVer := r.byte() + + var entry *cekTableEntry = nil + + if cekTable != nil { + if int(ordinal) > len(cekTable.entries)-1 { + panic(fmt.Errorf("invalid ordinal, cekTable only has %d entries", len(cekTable.entries))) + } + entry = &cekTable.entries[ordinal] + } + + return cryptoMetadata{ + entry: entry, + ordinal: ordinal, + algorithmId: algorithmId, + algorithmName: algName, + encType: encType, + normRuleVer: normRuleVer, + typeInfo: ti, + } +} + +func readCekTable(r *tdsBuffer) *cekTable { + tableSize := r.uint16() + var cekTable *cekTable = nil + + if tableSize != 0 { + mCekTable := newCekTable(tableSize) + for i := uint16(0); i < tableSize; i++ { + mCekTable.entries[i] = readCekTableEntry(r) + } + cekTable = &mCekTable + } + + return cekTable +} + +func readCekTableEntry(r *tdsBuffer) cekTableEntry { + databaseId := r.int32() + cekID := r.int32() + cekVersion := r.int32() + var cekMdVersion = make([]byte, 8) + _, err := r.Read(cekMdVersion) + if err != nil { + panic("unable to read cekMdVersion") + } + + cekValueCount := uint(r.byte()) + // not using ucs22str because we already know the data is utf16 + enc := unicode.UTF16(unicode.LittleEndian, unicode.IgnoreBOM) + utf16dec := enc.NewDecoder() + cekValues := make([]encryptionKeyInfo, cekValueCount) + + for i := uint(0); i < cekValueCount; i++ { + encryptedCekLength := r.uint16() + encryptedCek := make([]byte, encryptedCekLength) + r.ReadFull(encryptedCek) + + keyStoreLength := r.byte() + keyStoreNameUtf16 := make([]byte, keyStoreLength*2) + r.ReadFull(keyStoreNameUtf16) + keyStoreName, _ := utf16dec.Bytes(keyStoreNameUtf16) + + keyPathLength := r.uint16() + keyPathUtf16 := make([]byte, keyPathLength*2) + r.ReadFull(keyPathUtf16) + keyPath, _ := utf16dec.Bytes(keyPathUtf16) + + algLength := r.byte() + algNameUtf16 := make([]byte, algLength*2) + r.ReadFull(algNameUtf16) + algName, _ := utf16dec.Bytes(algNameUtf16) + + cekValues[i] = encryptionKeyInfo{ + encryptedKey: encryptedCek, + databaseID: int(databaseId), + cekID: int(cekID), + cekVersion: int(cekVersion), + cekMdVersion: cekMdVersion, + keyPath: string(keyPath), + keyStoreName: string(keyStoreName), + algorithmName: string(algName), + } + } + + return cekTableEntry{ + databaseID: int(databaseId), + keyId: int(cekID), + keyVersion: int(cekVersion), + mdVersion: cekMdVersion, + valueCount: int(cekValueCount), + cekValues: cekValues, + } +} + // http://msdn.microsoft.com/en-us/library/dd357254.aspx -func parseRow(r *tdsBuffer, columns []columnStruct, row []interface{}) { +func parseRow(r *tdsBuffer, s *tdsSession, columns []columnStruct, row []interface{}) { for i, column := range columns { - row[i] = column.ti.Reader(&column.ti, r) + columnContent := column.ti.Reader(&column.ti, r, nil) + if columnContent == nil { + row[i] = columnContent + continue + } + + if column.isEncrypted() { + buffer := decryptColumn(column, s, columnContent) + // Decrypt + row[i] = column.cryptoMeta.typeInfo.Reader(&column.cryptoMeta.typeInfo, &buffer, column.cryptoMeta) + } else { + row[i] = columnContent + } } } +type RWCBuffer struct { + buffer *bytes.Reader +} + +func (R RWCBuffer) Read(p []byte) (n int, err error) { + return R.buffer.Read(p) +} + +func (R RWCBuffer) Write(p []byte) (n int, err error) { + return 0, nil +} + +func (R RWCBuffer) Close() error { + return nil +} + +func decryptColumn(column columnStruct, s *tdsSession, columnContent interface{}) tdsBuffer { + // Decrypt + cekValue := column.cryptoMeta.entry.cekValues[column.cryptoMeta.ordinal] + s.logger.Log(context.Background(), msdsn.LogMessages, fmt.Sprintf("Decrypting column %s. Key path: %s, Key store:%s, Algo: %s", column.ColName, cekValue.keyPath, cekValue.keyStoreName, cekValue.algorithmName)) + + // returning empty data for now + newBuff := make([]byte, 0) + + rwc := RWCBuffer{ + buffer: bytes.NewReader(newBuff), + } + + column.cryptoMeta.typeInfo.Buffer = make([]byte, 0) + buffer := tdsBuffer{rpos: 0, rsize: len(newBuff), rbuf: newBuff, transport: rwc} + return buffer +} + // http://msdn.microsoft.com/en-us/library/dd304783.aspx -func parseNbcRow(r *tdsBuffer, columns []columnStruct, row []interface{}) { +func parseNbcRow(r *tdsBuffer, s *tdsSession, columns []columnStruct, row []interface{}) { bitlen := (len(columns) + 7) / 8 pres := make([]byte, bitlen) r.ReadFull(pres) @@ -605,7 +842,15 @@ func parseNbcRow(r *tdsBuffer, columns []columnStruct, row []interface{}) { row[i] = nil continue } - row[i] = col.ti.Reader(&col.ti, r) + columnContent := col.ti.Reader(&col.ti, r, nil) + if col.isEncrypted() { + buffer := decryptColumn(col, s, columnContent) + // Decrypt + row[i] = col.cryptoMeta.typeInfo.Reader(&col.cryptoMeta.typeInfo, &buffer, col.cryptoMeta) + } else { + row[i] = columnContent + } + } } @@ -638,7 +883,7 @@ func parseInfo(r *tdsBuffer) (res Error) { } // https://msdn.microsoft.com/en-us/library/dd303881.aspx -func parseReturnValue(r *tdsBuffer) (nv namedValue) { +func parseReturnValue(r *tdsBuffer, s *tdsSession) (nv namedValue) { /* ParamOrdinal ParamName @@ -649,13 +894,21 @@ func parseReturnValue(r *tdsBuffer) (nv namedValue) { CryptoMetadata Value */ - r.uint16() - nv.Name = r.BVarChar() - r.byte() - r.uint32() // UserType (uint16 prior to 7.2) - r.uint16() - ti := readTypeInfo(r) - nv.Value = ti.Reader(&ti, r) + _ = r.uint16() // ParamOrdinal + nv.Name = r.BVarChar() // ParamName + _ = r.byte() // Status + + ti := getBaseTypeInfo(r, true) // UserType + Flags + TypeInfo + + var cryptoMetadata *cryptoMetadata = nil + if s.alwaysEncrypted { + cm := parseCryptoMetadata(r, nil) // CryptoMetadata + cryptoMetadata = &cm + } + + ti2 := readTypeInfo(r, ti.TypeId, cryptoMetadata) + nv.Value = ti2.Reader(&ti2, r, cryptoMetadata) + return } @@ -782,7 +1035,7 @@ func processSingleResponse(ctx context.Context, sess *tdsSession, ch chan tokenS return } case tokenColMetadata: - columns = parseColMetadata72(sess.buf) + columns = parseColMetadata72(sess.buf, sess) ch <- columns colsReceived = true if outs.msgq != nil { @@ -791,11 +1044,11 @@ func processSingleResponse(ctx context.Context, sess *tdsSession, ch chan tokenS case tokenRow: row := make([]interface{}, len(columns)) - parseRow(sess.buf, columns, row) + parseRow(sess.buf, sess, columns, row) ch <- row case tokenNbcRow: row := make([]interface{}, len(columns)) - parseNbcRow(sess.buf, columns, row) + parseNbcRow(sess.buf, sess, columns, row) ch <- row case tokenEnvChange: processEnvChg(ctx, sess) @@ -823,7 +1076,7 @@ func processSingleResponse(ctx context.Context, sess *tdsSession, ch chan tokenS _ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgNotice{Message: info}) } case tokenReturnValue: - nv := parseReturnValue(sess.buf) + nv := parseReturnValue(sess.buf, sess) if len(nv.Name) > 0 { name := nv.Name[1:] // Remove the leading "@". if ov, has := outs.params[name]; has { diff --git a/types.go b/types.go index 3b4760e3..1fd25a0d 100644 --- a/types.go +++ b/types.go @@ -89,6 +89,8 @@ const ( // http://msdn.microsoft.com/en-us/library/dd358284.aspx type typeInfo struct { TypeId uint8 + UserType uint32 + Flags uint16 Size int Scale uint8 Prec uint8 @@ -96,7 +98,7 @@ type typeInfo struct { Collation cp.Collation UdtInfo udtInfo XmlInfo xmlInfo - Reader func(ti *typeInfo, r *tdsBuffer) (res interface{}) + Reader func(ti *typeInfo, r *tdsBuffer, cryptoMeta *cryptoMetadata) (res interface{}) Writer func(w io.Writer, ti typeInfo, buf []byte) (err error) } @@ -119,9 +121,9 @@ type xmlInfo struct { XmlSchemaCollection string } -func readTypeInfo(r *tdsBuffer) (res typeInfo) { - res.TypeId = r.byte() - switch res.TypeId { +func readTypeInfo(r *tdsBuffer, typeId byte, c *cryptoMetadata) (res typeInfo) { + res.TypeId = typeId + switch typeId { case typeNull, typeInt1, typeBit, typeInt2, typeInt4, typeDateTim4, typeFlt4, typeMoney, typeDateTime, typeFlt8, typeMoney4, typeInt8: // those are fixed length types @@ -140,7 +142,7 @@ func readTypeInfo(r *tdsBuffer) (res typeInfo) { res.Reader = readFixedType res.Buffer = make([]byte, res.Size) default: // all others are VARLENTYPE - readVarLen(&res, r) + readVarLen(&res, r, c) } return } @@ -315,7 +317,7 @@ func decodeDateTime(buf []byte) time.Time { 0, 0, secs, ns, time.UTC) } -func readFixedType(ti *typeInfo, r *tdsBuffer) interface{} { +func readFixedType(ti *typeInfo, r *tdsBuffer, c *cryptoMetadata) interface{} { r.ReadFull(ti.Buffer) buf := ti.Buffer switch ti.TypeId { @@ -349,8 +351,13 @@ func readFixedType(ti *typeInfo, r *tdsBuffer) interface{} { panic("shoulnd't get here") } -func readByteLenType(ti *typeInfo, r *tdsBuffer) interface{} { - size := r.byte() +func readByteLenType(ti *typeInfo, r *tdsBuffer, c *cryptoMetadata) interface{} { + var size byte + if c != nil { + size = byte(r.rsize) + } else { + size = r.byte() + } if size == 0 { return nil } @@ -448,8 +455,13 @@ func writeByteLenType(w io.Writer, ti typeInfo, buf []byte) (err error) { return } -func readShortLenType(ti *typeInfo, r *tdsBuffer) interface{} { - size := r.uint16() +func readShortLenType(ti *typeInfo, r *tdsBuffer, c *cryptoMetadata) interface{} { + var size uint16 + if c != nil { + size = uint16(r.rsize) + } else { + size = r.uint16() + } if size == 0xffff { return nil } @@ -491,7 +503,7 @@ func writeShortLenType(w io.Writer, ti typeInfo, buf []byte) (err error) { return } -func readLongLenType(ti *typeInfo, r *tdsBuffer) interface{} { +func readLongLenType(ti *typeInfo, r *tdsBuffer, c *cryptoMetadata) interface{} { // information about this format can be found here: // http://msdn.microsoft.com/en-us/library/dd304783.aspx // and here: @@ -566,7 +578,7 @@ func writeCollation(w io.Writer, col cp.Collation) (err error) { // reads variant value // http://msdn.microsoft.com/en-us/library/dd303302.aspx -func readVariantType(ti *typeInfo, r *tdsBuffer) interface{} { +func readVariantType(ti *typeInfo, r *tdsBuffer, c *cryptoMetadata) interface{} { size := r.int32() if size == 0 { return nil @@ -658,7 +670,7 @@ func readVariantType(ti *typeInfo, r *tdsBuffer) interface{} { // partially length prefixed stream // http://msdn.microsoft.com/en-us/library/dd340469.aspx -func readPLPType(ti *typeInfo, r *tdsBuffer) interface{} { +func readPLPType(ti *typeInfo, r *tdsBuffer, c *cryptoMetadata) interface{} { size := r.uint64() var buf *bytes.Buffer switch size { @@ -719,7 +731,7 @@ func writePLPType(w io.Writer, ti typeInfo, buf []byte) (err error) { } } -func readVarLen(ti *typeInfo, r *tdsBuffer) { +func readVarLen(ti *typeInfo, r *tdsBuffer, c *cryptoMetadata) { switch ti.TypeId { case typeDateN: ti.Size = 3 From 0cc1a7e124ab2149b13925fd23a52f7178af797b Mon Sep 17 00:00:00 2001 From: davidshi Date: Mon, 12 Jun 2023 12:37:17 -0500 Subject: [PATCH 04/47] Add skeleton for AE test --- alwaysencrypted_windows_test.go | 96 +++++++++++++++++++++++++++++++++ go.mod | 1 + go.sum | 4 ++ 3 files changed, 101 insertions(+) create mode 100644 alwaysencrypted_windows_test.go diff --git a/alwaysencrypted_windows_test.go b/alwaysencrypted_windows_test.go new file mode 100644 index 00000000..02fbf478 --- /dev/null +++ b/alwaysencrypted_windows_test.go @@ -0,0 +1,96 @@ +package mssql + +import ( + "bytes" + "fmt" + "os/exec" + "strings" + "testing" + + "github.com/Microsoft/go-winio/pkg/guid" +) + +func TestAlwaysEncryptedE2E(t *testing.T) { + params := testConnParams(t) + if !params.ColumnEncryption { + t.Skip("Test is not running with column encryption enabled") + } + conn, _ := open(t) + defer conn.Close() + certPath := provisionMasterKeyInCertStore(t) + s := fmt.Sprintf(createColumnMasterKey, certPath, certPath) + if _, err := conn.Exec(s); err != nil { + t.Fatalf("Unable to create CMK: %s", err.Error()) + } + defer conn.Exec(fmt.Sprintf(dropColumnMasterKey, certPath)) + // TODO: Implement encryption and insert encrypted values into a table using custom CEK + rows, err := conn.Query("select top (1) col1, col2 from Table_1") + if err != nil { + t.Fatalf("Unable to query encrypted columns: %s", err.Error()) + } + if !rows.Next() { + rows.Close() + t.Fatalf("rows.Next returned false") + } + var col1 string + var col2 int32 + err = rows.Scan(&col1, &col2) + if err != nil { + rows.Close() + t.Fatalf("rows.Scan failed: %s", err.Error()) + } + rows.Close() + err = rows.Err() + if err != nil { + t.Fatalf("rows.Err() has non-nil value: %s", err.Error()) + } +} + +const ( + createUserCertScript = `New-SelfSignedCertificate -Subject "%s" -CertStoreLocation Cert:CurrentUser\My -KeyExportPolicy Exportable -Type DocumentEncryptionCert -KeyUsage KeyEncipherment -Keyspec KeyExchange -KeyLength 2048 | select {$_.Thumbprint}` + deleteUserCertScript = `Get-ChildItem Cert:\CurrentUser\My\%s | Remove-Item -DeleteKey` + createColumnMasterKey = `CREATE COLUMN MASTER KEY [%s] WITH (KEY_STORE_PROVIDER_NAME= 'MSSQL_CERTIFICATE_STORE', KEY_PATH='%s')` + dropColumnMasterKey = `DROP COLUMN MASTER KEY [%s]` +) + +func provisionMasterKeyInCertStore(t *testing.T) string { + t.Helper() + var g guid.GUID + var err error + if g, err = guid.NewV4(); err != nil { + t.Fatalf("Unable to allocate a guid %v", err.Error()) + } + subject := fmt.Sprintf(`gomssqltest-%s`, g.String()) + + cmd := exec.Command(`C:\Windows\System32\WindowsPowerShell\v1.0\powershell.exe`, `/ExecutionPolicy`, `Unrestricted`, fmt.Sprintf(createUserCertScript, subject)) + buf := &memoryBuffer{buf: new(bytes.Buffer)} + cmd.Stdout = buf + if err = cmd.Run(); err != nil { + t.Fatalf("Unable to create cert for encryption: %v", err.Error()) + } + out := buf.buf.String() + thumbPrint := strings.Trim(out[strings.LastIndex(out, "-"):], "\r\n") + return fmt.Sprintf(`CurrentUser/My/%s`, thumbPrint) +} + +func deleteMasterKeyCert(t *testing.T, thumbprint string) { + t.Helper() + cmd := exec.Command(`C:\Windows\System32\WindowsPowerShell\v1.0\powershell.exe`, `/ExecutionPolicy`, `Unrestricted`, fmt.Sprintf(deleteUserCertScript, thumbprint)) + if err := cmd.Run; err != nil { + t.Fatalf("Unable to delete user cert %s", thumbprint) + } +} + +type memoryBuffer struct { + buf *bytes.Buffer +} + +func (b *memoryBuffer) Write(p []byte) (n int, err error) { + return b.buf.Write(p) +} + +func (b *memoryBuffer) Close() error { + return nil +} + +// C:\Windows\System32\WindowsPowerShell\v1.0\powershell.exe /ExecutionPolicy Unrestricted New-SelfSignedCertificate -Subject "%s" -CertStoreLocation Cert:CurrentUser\My -KeyExportPolicy Exportable -Type DocumentEncryptionCert -KeyUsage KeyEncipherment -Keyspec KeyExchange -KeyLength 2048 | select {$_.Thumbprint} diff --git a/go.mod b/go.mod index f8d240b7..393ca36a 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.13 require ( github.com/Azure/azure-sdk-for-go/sdk/azcore v1.6.0 github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.3.0 + github.com/Microsoft/go-winio v0.6.1 github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 github.com/golang-sql/sqlexp v0.1.0 github.com/jcmturner/gokrb5/v8 v8.4.4 diff --git a/go.sum b/go.sum index f37d8343..235f38d6 100644 --- a/go.sum +++ b/go.sum @@ -6,6 +6,8 @@ github.com/Azure/azure-sdk-for-go/sdk/internal v1.3.0 h1:sXr+ck84g/ZlZUOZiNELInm github.com/Azure/azure-sdk-for-go/sdk/internal v1.3.0/go.mod h1:okt5dMMTOFjX/aovMlrjvvXoPMBVSPzk9185BT0+eZM= github.com/AzureAD/microsoft-authentication-library-for-go v1.0.0 h1:OBhqkivkhkMqLPymWEppkm7vgPQY2XsHoEkaMQ0AdZY= github.com/AzureAD/microsoft-authentication-library-for-go v1.0.0/go.mod h1:kgDmCTgBzIEPFElEF+FK0SdjAor06dRq2Go927dnQ6o= +github.com/Microsoft/go-winio v0.6.1 h1:9/kr64B9VUZrLm5YYwbGtUJnMgqWVOdUAXu6Migciow= +github.com/Microsoft/go-winio v0.6.1/go.mod h1:LRdKpFKfdobln8UmuiYcKPot9D2v6svN5+sAH+4kjUM= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -48,6 +50,7 @@ github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 h1:KoWmjvw+nsYOo29YJK9 github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8/go.mod h1:HKlIX3XHQyzLZPlr7++PzdhaXEj94dEiJgZDTsxEqUI= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/sirupsen/logrus v1.9.0/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= @@ -83,6 +86,7 @@ golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210616045830-e2b7044e8c71/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= From 282072d708fc3e4702aae848f44fd5c2bb05dc00 Mon Sep 17 00:00:00 2001 From: davidshi Date: Wed, 14 Jun 2023 15:16:45 -0500 Subject: [PATCH 05/47] implement local cert key provider --- aecmk/localcert/keyprovider.go | 169 ++++++++++++++++ aecmk/localcert/keyprovider_darwin.go | 5 + aecmk/localcert/keyprovider_linux.go | 5 + aecmk/localcert/keyprovider_test.go | 15 ++ aecmk/localcert/keyprovider_windows.go | 44 ++++ aecmk/localcert/keyprovider_windows_test.go | 30 +++ alwaysencrypted_windows_test.go | 56 +---- columnencryptionkey.go | 8 + go.mod | 3 + go.sum | 9 + internal/certs/certs.go | 56 +++++ internal/certs/certs_windows.go | 214 ++++++++++++++++++++ 12 files changed, 565 insertions(+), 49 deletions(-) create mode 100644 aecmk/localcert/keyprovider.go create mode 100644 aecmk/localcert/keyprovider_darwin.go create mode 100644 aecmk/localcert/keyprovider_linux.go create mode 100644 aecmk/localcert/keyprovider_test.go create mode 100644 aecmk/localcert/keyprovider_windows.go create mode 100644 aecmk/localcert/keyprovider_windows_test.go create mode 100644 internal/certs/certs.go create mode 100644 internal/certs/certs_windows.go diff --git a/aecmk/localcert/keyprovider.go b/aecmk/localcert/keyprovider.go new file mode 100644 index 00000000..10d366ea --- /dev/null +++ b/aecmk/localcert/keyprovider.go @@ -0,0 +1,169 @@ +package localcert + +import ( + "crypto/rsa" + "crypto/sha1" + "crypto/x509" + "fmt" + "io/ioutil" + "os" + "strconv" + "time" + + mssql "github.com/microsoft/go-mssqldb" + ae "github.com/swisscom/mssql-always-encrypted/pkg" + pkcs "software.sslmate.com/src/go-pkcs12" +) + +const ( + PfxKeyProviderName = "pfx" + wildcard = "*" +) + +// LocalCertProvider uses local certificates to decrypt CEKs +// It supports both 'MSSQL_CERTIFICATE_STORE' and 'pfx' key stores. +// MSSQL_CERTIFICATE_STORE key paths are of the form `storename/storepath/thumbprint` and only supported on Windows clients. +// pfx key paths are absolute file system paths that are operating system dependent. +type LocalCertProvider struct { + // Name identifies which key store the provider supports. + Name string + // AllowedLocations constrains which locations the provider will use to find certificates. If empty, all locations are allowed. + // When presented with a key store path not in the allowed list, the data will be returned still encrypted. + AllowedLocations []string + passwords map[string]string +} + +// SetCertificatePassword stores the password associated with the certificate at the given location. +// If location is empty the given password applies to all certificates that have not been explicitly assigned a value. +func (p LocalCertProvider) SetCertificatePassword(location string, password string) { + if location == "" { + location = wildcard + } + p.passwords[location] = password +} + +var PfxKeyProvider = LocalCertProvider{Name: PfxKeyProviderName, passwords: make(map[string]string), AllowedLocations: make([]string, 0)} + +func init() { + mssql.RegisterCekProvider(mssql.CertificateStoreKeyProvider, &PfxKeyProvider) +} + +// DecryptColumnEncryptionKey decrypts the specified encrypted value of a column encryption key. +// The encrypted value is expected to be encrypted using the column master key with the specified key path and using the specified algorithm. +func (p *LocalCertProvider) DecryptColumnEncryptionKey(masterKeyPath string, encryptionAlgorithm string, encryptedCek []byte) (decryptedKey []byte) { + decryptedKey = nil + allowed := len(p.AllowedLocations) == 0 + if !allowed { + loop: + for _, l := range p.AllowedLocations { + if l == masterKeyPath { + allowed = true + break loop + } + } + } + if !allowed { + return + } + var cert *x509.Certificate + var pk interface{} + switch p.Name { + case PfxKeyProviderName: + pk, cert = p.loadLocalCertificate(masterKeyPath) + case mssql.CertificateStoreKeyProvider: + pk, cert = p.loadWindowsCertStoreCertificate(masterKeyPath) + default: + return + } + cekv := ae.LoadCEKV(encryptedCek) + if !cekv.Verify(cert) { + panic(fmt.Errorf("Invalid certificate provided for decryption. Key Store Path: %s. <%s>-<%v>", masterKeyPath, cekv.KeyPath, fmt.Sprintf("%02x", sha1.Sum(cert.Raw)))) + } + + decryptedKey, err := cekv.Decrypt(pk.(*rsa.PrivateKey)) + if err != nil { + panic(err) + } + return +} + +func (p *LocalCertProvider) loadLocalCertificate(path string) (privateKey interface{}, cert *x509.Certificate) { + if f, err := os.Open(path); err == nil { + pfxBytes, err := ioutil.ReadAll(f) + if err != nil { + panic(invalidCertificatePath(path, err)) + } + pwd, ok := p.passwords[path] + if !ok { + pwd, ok = p.passwords[wildcard] + if !ok { + pwd = "" + } + } + privateKey, cert, err = pkcs.Decode(pfxBytes, pwd) + if err != nil { + panic(err) + } + } else { + panic(invalidCertificatePath(path, err)) + } + return +} + +// EncryptColumnEncryptionKey encrypts a column encryption key using the column master key with the specified key path and using the specified algorithm. +func (p *LocalCertProvider) EncryptColumnEncryptionKey(masterKeyPath string, encryptionAlgorithm string, cek []byte) []byte { + return nil +} + +// SignColumnMasterKeyMetadata digitally signs the column master key metadata with the column master key +// referenced by the masterKeyPath parameter. The input values used to generate the signature should be the +// specified values of the masterKeyPath and allowEnclaveComputations parameters. May return an empty slice if not supported. +func (p *LocalCertProvider) SignColumnMasterKeyMetadata(masterKeyPath string, allowEnclaveComputations bool) []byte { + return nil +} + +// VerifyColumnMasterKeyMetadata verifies the specified signature is valid for the column master key +// with the specified key path and the specified enclave behavior. Return nil if not supported. +func (p *LocalCertProvider) VerifyColumnMasterKeyMetadata(masterKeyPath string, allowEnclaveComputations bool) *bool { + return nil +} + +// KeyLifetime is an optional Duration. Keys fetched by this provider will be discarded after their lifetime expires. +// If it returns nil, the keys will expire based on the value of ColumnEncryptionKeyLifetime. +// If it returns zero, the keys will not be cached. +func (p *LocalCertProvider) KeyLifetime() *time.Duration { + return nil +} + +// InvalidCertificatePathError indicates the provided path could not be used to load a certificate +type InvalidCertificatePathError struct { + path string + innerErr error +} + +func (i *InvalidCertificatePathError) Error() string { + return fmt.Sprintf("Invalid certificate path: %s", i.path) +} + +func (i *InvalidCertificatePathError) Unwrap() error { + return i.innerErr +} + +func invalidCertificatePath(path string, err error) error { + return &InvalidCertificatePathError{path: path, innerErr: err} +} + +func thumbprintToByteArray(thumbprint string) []byte { + if len(thumbprint)%2 != 0 { + panic(fmt.Errorf("Thumbprint must have even length %s", thumbprint)) + } + bytes := make([]byte, len(thumbprint)/2) + for i := range bytes { + b, err := strconv.ParseInt(thumbprint[i*2:(i*2)+2], 16, 32) + if err != nil { + panic(err) + } + bytes[i] = byte(b) + } + return bytes +} diff --git a/aecmk/localcert/keyprovider_darwin.go b/aecmk/localcert/keyprovider_darwin.go new file mode 100644 index 00000000..5842c08d --- /dev/null +++ b/aecmk/localcert/keyprovider_darwin.go @@ -0,0 +1,5 @@ +package localcert + +func (p *LocalCertProvider) loadWindowsCertStoreCertificate(path string) (privateKey interface{}, cert *x509.Certificate) { + panic(fmt.Errorf("Windows cert store not supported on this OS")) +} diff --git a/aecmk/localcert/keyprovider_linux.go b/aecmk/localcert/keyprovider_linux.go new file mode 100644 index 00000000..5842c08d --- /dev/null +++ b/aecmk/localcert/keyprovider_linux.go @@ -0,0 +1,5 @@ +package localcert + +func (p *LocalCertProvider) loadWindowsCertStoreCertificate(path string) (privateKey interface{}, cert *x509.Certificate) { + panic(fmt.Errorf("Windows cert store not supported on this OS")) +} diff --git a/aecmk/localcert/keyprovider_test.go b/aecmk/localcert/keyprovider_test.go new file mode 100644 index 00000000..c02354af --- /dev/null +++ b/aecmk/localcert/keyprovider_test.go @@ -0,0 +1,15 @@ +package localcert + +import ( + "bytes" + "encoding/hex" + "testing" +) + +func TestThumbPrintToSignature(t *testing.T) { + thumbprint := "5e89a107f0ade0aed5f753ecc60378b1bbae3598" + signature := thumbprintToByteArray(thumbprint) + if !bytes.Equal(signature, []byte{0x5e, 0x89, 0xa1, 0x07, 0xf0, 0xad, 0xe0, 0xae, 0xd5, 0xf7, 0x53, 0xec, 0xc6, 0x03, 0x78, 0xb1, 0xbb, 0xae, 0x35, 0x98}) { + t.Fatalf("Incorrect signature bytes for %s. Got: %s", thumbprint, hex.Dump(signature)) + } +} diff --git a/aecmk/localcert/keyprovider_windows.go b/aecmk/localcert/keyprovider_windows.go new file mode 100644 index 00000000..6a95d08a --- /dev/null +++ b/aecmk/localcert/keyprovider_windows.go @@ -0,0 +1,44 @@ +package localcert + +import ( + "crypto/x509" + "fmt" + "strings" + "unsafe" + + "github.com/microsoft/go-mssqldb/internal/certs" + "golang.org/x/sys/windows" +) + +func (p *LocalCertProvider) loadWindowsCertStoreCertificate(path string) (privateKey interface{}, cert *x509.Certificate) { + privateKey = nil + cert = nil + pathParts := strings.Split(path, `/`) + if len(pathParts) != 3 { + panic(invalidCertificatePath(path, fmt.Errorf("key store path requires 3 segments"))) + } + + var storeId uint32 + switch strings.ToLower(pathParts[0]) { + case "localmachine": + storeId = windows.CERT_SYSTEM_STORE_LOCAL_MACHINE + case "currentuser": + storeId = windows.CERT_SYSTEM_STORE_CURRENT_USER + default: + panic(invalidCertificatePath(path, fmt.Errorf("Unknown certificate store"))) + } + system, err := windows.UTF16PtrFromString(pathParts[1]) + if err != nil { + panic(err) + } + h, err := windows.CertOpenStore(windows.CERT_STORE_PROV_SYSTEM, + windows.PKCS_7_ASN_ENCODING|windows.X509_ASN_ENCODING, + 0, + storeId, uintptr(unsafe.Pointer(system))) + if err != nil { + panic(err) + } + defer windows.CertCloseStore(h, 0) + signature := thumbprintToByteArray(pathParts[2]) + return certs.FindCertBySignatureHash(h, signature) +} diff --git a/aecmk/localcert/keyprovider_windows_test.go b/aecmk/localcert/keyprovider_windows_test.go new file mode 100644 index 00000000..e67443b2 --- /dev/null +++ b/aecmk/localcert/keyprovider_windows_test.go @@ -0,0 +1,30 @@ +package localcert + +import ( + "crypto/rsa" + "strings" + "testing" + + mssql "github.com/microsoft/go-mssqldb" + "github.com/microsoft/go-mssqldb/internal/certs" +) + +func TestLoadWindowsCertStoreCertificate(t *testing.T) { + thumbprint, err := certs.ProvisionMasterKeyInCertStore() + if err != nil { + t.Fatal(err) + } + defer certs.DeleteMasterKeyCert(thumbprint) + provider := &LocalCertProvider{Name: mssql.AzureKeyVaultKeyProvider} + pk, cert := provider.loadWindowsCertStoreCertificate("CurrentUser/My/" + thumbprint) + switch z := pk.(type) { + case *rsa.PrivateKey: + + t.Logf("Got an rsa.PrivateKey with size %d", z.Size()) + default: + t.Fatalf("Unexpected private key type: %v", z) + } + if !strings.HasPrefix(cert.Subject.String(), `CN=gomssqltest-`) { + t.Fatalf("Wrong cert loaded: %s", cert.Subject.String()) + } +} diff --git a/alwaysencrypted_windows_test.go b/alwaysencrypted_windows_test.go index 02fbf478..07994e09 100644 --- a/alwaysencrypted_windows_test.go +++ b/alwaysencrypted_windows_test.go @@ -1,13 +1,10 @@ package mssql import ( - "bytes" "fmt" - "os/exec" - "strings" "testing" - "github.com/Microsoft/go-winio/pkg/guid" + "github.com/microsoft/go-mssqldb/internal/certs" ) func TestAlwaysEncryptedE2E(t *testing.T) { @@ -17,7 +14,12 @@ func TestAlwaysEncryptedE2E(t *testing.T) { } conn, _ := open(t) defer conn.Close() - certPath := provisionMasterKeyInCertStore(t) + thumbprint, err := certs.ProvisionMasterKeyInCertStore() + if err != nil { + t.Fatal(err) + } + defer certs.DeleteMasterKeyCert(thumbprint) + certPath := fmt.Sprintf(`CurrentUser/My/%s`, thumbprint) s := fmt.Sprintf(createColumnMasterKey, certPath, certPath) if _, err := conn.Exec(s); err != nil { t.Fatalf("Unable to create CMK: %s", err.Error()) @@ -47,50 +49,6 @@ func TestAlwaysEncryptedE2E(t *testing.T) { } const ( - createUserCertScript = `New-SelfSignedCertificate -Subject "%s" -CertStoreLocation Cert:CurrentUser\My -KeyExportPolicy Exportable -Type DocumentEncryptionCert -KeyUsage KeyEncipherment -Keyspec KeyExchange -KeyLength 2048 | select {$_.Thumbprint}` - deleteUserCertScript = `Get-ChildItem Cert:\CurrentUser\My\%s | Remove-Item -DeleteKey` createColumnMasterKey = `CREATE COLUMN MASTER KEY [%s] WITH (KEY_STORE_PROVIDER_NAME= 'MSSQL_CERTIFICATE_STORE', KEY_PATH='%s')` dropColumnMasterKey = `DROP COLUMN MASTER KEY [%s]` ) - -func provisionMasterKeyInCertStore(t *testing.T) string { - t.Helper() - var g guid.GUID - var err error - if g, err = guid.NewV4(); err != nil { - t.Fatalf("Unable to allocate a guid %v", err.Error()) - } - subject := fmt.Sprintf(`gomssqltest-%s`, g.String()) - - cmd := exec.Command(`C:\Windows\System32\WindowsPowerShell\v1.0\powershell.exe`, `/ExecutionPolicy`, `Unrestricted`, fmt.Sprintf(createUserCertScript, subject)) - buf := &memoryBuffer{buf: new(bytes.Buffer)} - cmd.Stdout = buf - if err = cmd.Run(); err != nil { - t.Fatalf("Unable to create cert for encryption: %v", err.Error()) - } - out := buf.buf.String() - thumbPrint := strings.Trim(out[strings.LastIndex(out, "-"):], "\r\n") - return fmt.Sprintf(`CurrentUser/My/%s`, thumbPrint) -} - -func deleteMasterKeyCert(t *testing.T, thumbprint string) { - t.Helper() - cmd := exec.Command(`C:\Windows\System32\WindowsPowerShell\v1.0\powershell.exe`, `/ExecutionPolicy`, `Unrestricted`, fmt.Sprintf(deleteUserCertScript, thumbprint)) - if err := cmd.Run; err != nil { - t.Fatalf("Unable to delete user cert %s", thumbprint) - } -} - -type memoryBuffer struct { - buf *bytes.Buffer -} - -func (b *memoryBuffer) Write(p []byte) (n int, err error) { - return b.buf.Write(p) -} - -func (b *memoryBuffer) Close() error { - return nil -} - -// C:\Windows\System32\WindowsPowerShell\v1.0\powershell.exe /ExecutionPolicy Unrestricted New-SelfSignedCertificate -Subject "%s" -CertStoreLocation Cert:CurrentUser\My -KeyExportPolicy Exportable -Type DocumentEncryptionCert -KeyUsage KeyEncipherment -Keyspec KeyExchange -KeyLength 2048 | select {$_.Thumbprint} diff --git a/columnencryptionkey.go b/columnencryptionkey.go index 5220a8d3..cfbad2ca 100644 --- a/columnencryptionkey.go +++ b/columnencryptionkey.go @@ -5,6 +5,14 @@ import ( "time" ) +const ( + CertificateStoreKeyProvider = "MSSQL_CERTIFICATE_STORE" + CspKeyProvider = "MSSQL_CSP_PROVIDER" + CngKeyProvider = "MSSQL_CNG_STORE" + AzureKeyVaultKeyProvider = "AZURE_KEY_VAULT" + JavaKeyProvider = "MSSQL_JAVA_KEYSTORE" +) + // cek ==> Column Encryption Key // Every row of an encrypted table has an associated list of keys used to decrypt its columns type cekTable struct { diff --git a/go.mod b/go.mod index 393ca36a..370a3e7c 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,10 @@ require ( github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 github.com/golang-sql/sqlexp v0.1.0 github.com/jcmturner/gokrb5/v8 v8.4.4 + github.com/swisscom/mssql-always-encrypted v0.1.3 golang.org/x/crypto v0.9.0 + golang.org/x/sys v0.8.0 golang.org/x/text v0.9.0 gopkg.in/natefinch/npipe.v2 v2.0.0-20160621034901-c1b8fa8bdcce + software.sslmate.com/src/go-pkcs12 v0.2.0 ) diff --git a/go.sum b/go.sum index 235f38d6..92daeded 100644 --- a/go.sum +++ b/go.sum @@ -60,9 +60,12 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/swisscom/mssql-always-encrypted v0.1.3 h1:+Q7sa71G2taM4SmwyNfPIB1iB8750iKNJEJQvqtlB38= +github.com/swisscom/mssql-always-encrypted v0.1.3/go.mod h1:FlEWLI3+svdMFq2w7GVMvk7iVhwBEBi7E7llAHb4B20= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.0.0-20220331220935-ae2d96664a29/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58= golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU= golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g= @@ -72,6 +75,7 @@ golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= @@ -83,6 +87,7 @@ golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210616045830-e2b7044e8c71/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -99,6 +104,8 @@ golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= @@ -121,3 +128,5 @@ gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +software.sslmate.com/src/go-pkcs12 v0.2.0 h1:nlFkj7bTysH6VkC4fGphtjXRbezREPgrHuJG20hBGPE= +software.sslmate.com/src/go-pkcs12 v0.2.0/go.mod h1:23rNcYsMabIc1otwLpTkCCPwUq6kQsTyowttG/as0kQ= diff --git a/internal/certs/certs.go b/internal/certs/certs.go new file mode 100644 index 00000000..9ddbc519 --- /dev/null +++ b/internal/certs/certs.go @@ -0,0 +1,56 @@ +package certs + +import ( + "bytes" + "fmt" + "os/exec" + "strings" + + "github.com/Microsoft/go-winio/pkg/guid" +) + +const ( + createUserCertScript = `New-SelfSignedCertificate -Subject "%s" -CertStoreLocation Cert:CurrentUser\My -KeyExportPolicy Exportable -Type DocumentEncryptionCert -KeyUsage KeyEncipherment -Keyspec KeyExchange -KeyLength 2048 | select {$_.Thumbprint}` + deleteUserCertScript = `Get-ChildItem Cert:\CurrentUser\My\%s | Remove-Item -DeleteKey` +) + +func ProvisionMasterKeyInCertStore() (thumbprint string, err error) { + var g guid.GUID + if g, err = guid.NewV4(); err != nil { + return + } + subject := fmt.Sprintf(`gomssqltest-%s`, g.String()) + + cmd := exec.Command(`C:\Windows\System32\WindowsPowerShell\v1.0\powershell.exe`, `/ExecutionPolicy`, `Unrestricted`, fmt.Sprintf(createUserCertScript, subject)) + buf := &memoryBuffer{buf: new(bytes.Buffer)} + cmd.Stdout = buf + if err = cmd.Run(); err != nil { + err = fmt.Errorf("Unable to create cert for encryption: %v", err.Error()) + return + } + out := buf.buf.String() + thumbprint = strings.Trim(out[strings.LastIndex(out, "-")+1:], " \r\n") + return +} + +func DeleteMasterKeyCert(thumbprint string) error { + cmd := exec.Command(`C:\Windows\System32\WindowsPowerShell\v1.0\powershell.exe`, `/ExecutionPolicy`, `Unrestricted`, fmt.Sprintf(deleteUserCertScript, thumbprint)) + if err := cmd.Run(); err != nil { + return fmt.Errorf("Unable to delete user cert %s. %s", thumbprint, err.Error()) + } + return nil +} + +type memoryBuffer struct { + buf *bytes.Buffer +} + +func (b *memoryBuffer) Write(p []byte) (n int, err error) { + return b.buf.Write(p) +} + +func (b *memoryBuffer) Close() error { + return nil +} + +// C:\Windows\System32\WindowsPowerShell\v1.0\powershell.exe /ExecutionPolicy Unrestricted New-SelfSignedCertificate -Subject "%s" -CertStoreLocation Cert:CurrentUser\My -KeyExportPolicy Exportable -Type DocumentEncryptionCert -KeyUsage KeyEncipherment -Keyspec KeyExchange -KeyLength 2048 | select {$_.Thumbprint} diff --git a/internal/certs/certs_windows.go b/internal/certs/certs_windows.go new file mode 100644 index 00000000..d3ec2415 --- /dev/null +++ b/internal/certs/certs_windows.go @@ -0,0 +1,214 @@ +package certs + +import ( + "bytes" + "crypto/rsa" + "crypto/x509" + "encoding/binary" + "errors" + "fmt" + "math/big" + "reflect" + + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +func FindCertBySignatureHash(storeHandle windows.Handle, hash []byte) (interface{}, *x509.Certificate) { + var certContext *windows.CertContext + var prevCertContext *windows.CertContext + var err error + cryptoAPIBlob := windows.CryptHashBlob{ + Size: uint32(len(hash)), + Data: &hash[0], + } + + for { + certContext, err = windows.CertFindCertificateInStore( + storeHandle, + windows.X509_ASN_ENCODING|windows.PKCS_7_ASN_ENCODING, + 0, + windows.CERT_FIND_HASH, + unsafe.Pointer(&cryptoAPIBlob), + prevCertContext) + if certContext == nil || err != nil { + break + } + prevCertContext = certContext + } + + if prevCertContext == nil { + if err == nil { + err = syscall.GetLastError() + } + panic(fmt.Errorf("Unable to find certificate by signature hash. %s", err.Error())) + } + + pk, cert, err := certContextToX509(prevCertContext) + if err != nil { + panic(err) + } + + return pk, cert +} + +func certContextToX509(ctx *windows.CertContext) (pk interface{}, cert *x509.Certificate, err error) { + var der []byte + slice := (*reflect.SliceHeader)(unsafe.Pointer(&der)) + slice.Data = uintptr(unsafe.Pointer(ctx.EncodedCert)) + slice.Len = int(ctx.Length) + slice.Cap = int(ctx.Length) + cert, err = x509.ParseCertificate(der) + if err != nil { + return + } + var kh windows.Handle + var keySpec uint32 + var freeProvOrKey bool + err = windows.CryptAcquireCertificatePrivateKey(ctx, windows.CRYPT_ACQUIRE_ONLY_NCRYPT_KEY_FLAG, nil, &kh, &keySpec, &freeProvOrKey) + if err != nil { + return + } + + pkBytes, err := nCryptExportKey(kh, "RSAFULLPRIVATEBLOB") + if freeProvOrKey { + _, _, _ = procNCryptFreeObject.Call(uintptr(kh)) + } + if err != nil { + return + } + + pk, err = unmarshalRSA(pkBytes) + return +} + +var ( + nCrypt = windows.MustLoadDLL("ncrypt.dll") + procNCryptExportKey = nCrypt.MustFindProc("NCryptExportKey") + procNCryptFreeObject = nCrypt.MustFindProc("NCryptFreeObject") +) + +// wide returns a pointer to a uint16 representing the equivalent +// to a Windows LPCWSTR. +func wide(s string) *uint16 { + w, _ := windows.UTF16PtrFromString(s) + return w +} + +func nCryptExportKey(kh windows.Handle, blobType string) ([]byte, error) { + var size uint32 + // When obtaining the size of a public key, most parameters are not required + r, _, err := procNCryptExportKey.Call( + uintptr(kh), + 0, + uintptr(unsafe.Pointer(wide(blobType))), + 0, + 0, + 0, + uintptr(unsafe.Pointer(&size)), + 0) + if !errors.Is(err, windows.Errno(0)) { + return nil, fmt.Errorf("nCryptExportKey returned %w", err) + } + if r != 0 { + return nil, fmt.Errorf("NCryptExportKey returned 0x%X during size check", uint32(r)) + } + + // Place the exported key in buf now that we know the size required + buf := make([]byte, size) + r, _, err = procNCryptExportKey.Call( + uintptr(kh), + 0, + uintptr(unsafe.Pointer(wide(blobType))), + 0, + uintptr(unsafe.Pointer(&buf[0])), + uintptr(size), + uintptr(unsafe.Pointer(&size)), + 0) + if !errors.Is(err, windows.Errno(0)) { + return nil, fmt.Errorf("nCryptExportKey returned %w", err) + } + if r != 0 { + return nil, fmt.Errorf("NCryptExportKey returned 0x%X during export", uint32(r)) + } + return buf, nil +} + +// TODO: See if we can rewrite this to avoid copying the data from buf twice per field +func unmarshalRSA(buf []byte) (*rsa.PrivateKey, error) { + // BCRYPT_RSA_BLOB -- https://learn.microsoft.com/windows/win32/api/bcrypt/ns-bcrypt-bcrypt_rsakey_blob + header := struct { + Magic uint32 + BitLength uint32 + PublicExpSize uint32 + ModulusSize uint32 + Prime1Size uint32 + Prime2Size uint32 + }{} + + r := bytes.NewReader(buf) + if err := binary.Read(r, binary.LittleEndian, &header); err != nil { + return nil, err + } + + if header.Magic != 0x33415352 { // "RSA3" BCRYPT_RSAFULLPRIVATE_MAGIC + return nil, fmt.Errorf("invalid header magic %x", header.Magic) + } + + if header.PublicExpSize > 8 { + return nil, fmt.Errorf("unsupported public exponent size (%d bits)", header.PublicExpSize*8) + } + + // the exponent is in BigEndian format, so read the data into the right place in the buffer + exp := make([]byte, 8) + n, err := r.Read(exp[8-header.PublicExpSize:]) + + if err != nil { + return nil, fmt.Errorf("failed to read public exponent %w", err) + } + + if n != int(header.PublicExpSize) { + return nil, fmt.Errorf("failed to read correct public exponent size, read %d expected %d", n, int(header.PublicExpSize)) + } + + mod := make([]byte, header.ModulusSize) + n, err = r.Read(mod) + + if err != nil { + return nil, fmt.Errorf("failed to read modulus %w", err) + } + + if n != int(header.ModulusSize) { + return nil, fmt.Errorf("failed to read correct modulus size, read %d expected %d", n, int(header.ModulusSize)) + } + + pk := &rsa.PrivateKey{ + PublicKey: rsa.PublicKey{ + N: new(big.Int).SetBytes(mod), + E: int(binary.BigEndian.Uint64(exp)), + }, + D: new(big.Int), + Primes: make([]*big.Int, 2), + } + prime := make([]byte, header.Prime1Size) + n, err = r.Read(prime) + if err != nil { + return nil, fmt.Errorf("failed to read prime1 %w", err) + } + pk.Primes[0] = new(big.Int).SetBytes(prime) + prime = make([]byte, header.Prime2Size) + n, err = r.Read(prime) + if err != nil { + return nil, fmt.Errorf("failed to read prime2 %w", err) + } + pk.Primes[1] = new(big.Int).SetBytes(prime) + expBytes := make([]byte, 2*header.Prime1Size+header.Prime2Size+header.ModulusSize) + n, err = r.Read(expBytes) + if err != nil { + return nil, fmt.Errorf("Unable to read PrivateExponent %w", err) + } + pk.D = new(big.Int).SetBytes(expBytes[2*header.Prime1Size+header.Prime2Size:]) + return pk, nil +} From bcd9e1ce25110eb7b4681b51433d016e56c5ddbf Mon Sep 17 00:00:00 2001 From: davidshi Date: Wed, 14 Jun 2023 15:22:13 -0500 Subject: [PATCH 06/47] build fixes --- aecmk/localcert/keyprovider_darwin.go | 6 ++++++ aecmk/localcert/keyprovider_linux.go | 6 ++++++ 2 files changed, 12 insertions(+) diff --git a/aecmk/localcert/keyprovider_darwin.go b/aecmk/localcert/keyprovider_darwin.go index 5842c08d..a3a0e7d6 100644 --- a/aecmk/localcert/keyprovider_darwin.go +++ b/aecmk/localcert/keyprovider_darwin.go @@ -1,5 +1,11 @@ package localcert +import ( + "crypto/x509" + "fmt" +) + func (p *LocalCertProvider) loadWindowsCertStoreCertificate(path string) (privateKey interface{}, cert *x509.Certificate) { panic(fmt.Errorf("Windows cert store not supported on this OS")) + return } diff --git a/aecmk/localcert/keyprovider_linux.go b/aecmk/localcert/keyprovider_linux.go index 5842c08d..a3a0e7d6 100644 --- a/aecmk/localcert/keyprovider_linux.go +++ b/aecmk/localcert/keyprovider_linux.go @@ -1,5 +1,11 @@ package localcert +import ( + "crypto/x509" + "fmt" +) + func (p *LocalCertProvider) loadWindowsCertStoreCertificate(path string) (privateKey interface{}, cert *x509.Certificate) { panic(fmt.Errorf("Windows cert store not supported on this OS")) + return } From 1c7a2e319385a42f2e2ed552dbd72bca677967b5 Mon Sep 17 00:00:00 2001 From: davidshi Date: Wed, 14 Jun 2023 17:40:26 -0500 Subject: [PATCH 07/47] use key providers for decrypt --- aecmk/localcert/keyprovider.go | 8 +++---- aecmk/localcert/keyprovider_windows.go | 8 +++++++ columnencryptionkey.go | 4 ++-- mssql.go | 29 ++++++++++++++++---------- tds.go | 17 ++++++++++----- token.go | 21 ++++++++++++++++--- 6 files changed, 62 insertions(+), 25 deletions(-) diff --git a/aecmk/localcert/keyprovider.go b/aecmk/localcert/keyprovider.go index 10d366ea..e21e94cc 100644 --- a/aecmk/localcert/keyprovider.go +++ b/aecmk/localcert/keyprovider.go @@ -26,7 +26,7 @@ const ( // pfx key paths are absolute file system paths that are operating system dependent. type LocalCertProvider struct { // Name identifies which key store the provider supports. - Name string + name string // AllowedLocations constrains which locations the provider will use to find certificates. If empty, all locations are allowed. // When presented with a key store path not in the allowed list, the data will be returned still encrypted. AllowedLocations []string @@ -42,10 +42,10 @@ func (p LocalCertProvider) SetCertificatePassword(location string, password stri p.passwords[location] = password } -var PfxKeyProvider = LocalCertProvider{Name: PfxKeyProviderName, passwords: make(map[string]string), AllowedLocations: make([]string, 0)} +var PfxKeyProvider = LocalCertProvider{name: PfxKeyProviderName, passwords: make(map[string]string), AllowedLocations: make([]string, 0)} func init() { - mssql.RegisterCekProvider(mssql.CertificateStoreKeyProvider, &PfxKeyProvider) + mssql.RegisterCekProvider("pfx", &PfxKeyProvider) } // DecryptColumnEncryptionKey decrypts the specified encrypted value of a column encryption key. @@ -67,7 +67,7 @@ func (p *LocalCertProvider) DecryptColumnEncryptionKey(masterKeyPath string, enc } var cert *x509.Certificate var pk interface{} - switch p.Name { + switch p.name { case PfxKeyProviderName: pk, cert = p.loadLocalCertificate(masterKeyPath) case mssql.CertificateStoreKeyProvider: diff --git a/aecmk/localcert/keyprovider_windows.go b/aecmk/localcert/keyprovider_windows.go index 6a95d08a..9b599f5a 100644 --- a/aecmk/localcert/keyprovider_windows.go +++ b/aecmk/localcert/keyprovider_windows.go @@ -6,10 +6,18 @@ import ( "strings" "unsafe" + mssql "github.com/microsoft/go-mssqldb" + "github.com/microsoft/go-mssqldb/internal/certs" "golang.org/x/sys/windows" ) +var WindowsCertificateStoreKeyProvider = LocalCertProvider{name: mssql.CertificateStoreKeyProvider, passwords: make(map[string]string)} + +func init() { + mssql.RegisterCekProvider(mssql.CertificateStoreKeyProvider, &WindowsCertificateStoreKeyProvider) +} + func (p *LocalCertProvider) loadWindowsCertStoreCertificate(path string) (privateKey interface{}, cert *x509.Certificate) { privateKey = nil cert = nil diff --git a/columnencryptionkey.go b/columnencryptionkey.go index cfbad2ca..5d694c09 100644 --- a/columnencryptionkey.go +++ b/columnencryptionkey.go @@ -60,7 +60,7 @@ type cekProvider struct { } // no synchronization on this map. Providers register during init. -type columnEncryptionKeyProviderMap map[string]cekProvider +type columnEncryptionKeyProviderMap map[string]*cekProvider var globalCekProviderFactoryMap = columnEncryptionKeyProviderMap{} @@ -90,6 +90,6 @@ func RegisterCekProvider(name string, provider ColumnEncryptionKeyProvider) erro if ok { return fmt.Errorf("CEK provider %s is already registered", name) } - globalCekProviderFactoryMap[name] = cekProvider{provider: provider} + globalCekProviderFactoryMap[name] = &cekProvider{provider: provider, decryptedKeys: make(cekCache)} return nil } diff --git a/mssql.go b/mssql.go index a875e8a4..0771c139 100644 --- a/mssql.go +++ b/mssql.go @@ -69,10 +69,7 @@ func (d *Driver) OpenConnector(dsn string) (*Connector, error) { return nil, err } - return &Connector{ - params: params, - driver: d, - }, nil + return newConnector(params, d), nil } func (d *Driver) Open(dsn string) (driver.Conn, error) { @@ -122,10 +119,8 @@ func NewConnector(dsn string) (*Connector, error) { if err != nil { return nil, err } - c := &Connector{ - params: params, - driver: driverInstanceNoProcess, - } + c := newConnector(params, driverInstanceNoProcess) + return c, nil } @@ -146,9 +141,14 @@ func NewConnectorWithAccessTokenProvider(dsn string, tokenProvider func(ctx cont // NewConnectorConfig creates a new Connector for a DSN Config struct. // The returned connector may be used with sql.OpenDB. func NewConnectorConfig(config msdsn.Config) *Connector { + return newConnector(config, driverInstanceNoProcess) +} + +func newConnector(config msdsn.Config, driver *Driver) *Connector { return &Connector{ - params: config, - driver: driverInstanceNoProcess, + params: config, + driver: driver, + keyProviders: make(columnEncryptionKeyProviderMap), } } @@ -197,6 +197,8 @@ type Connector struct { // Dialer sets a custom dialer for all network operations. // If Dialer is not set, normal net dialers are used. Dialer Dialer + + keyProviders columnEncryptionKeyProviderMap } type Dialer interface { @@ -210,6 +212,11 @@ func (c *Connector) getDialer(p *msdsn.Config) Dialer { return createDialer(p) } +// RegisterCekProvider associated the given provider with the named key store. If an entry of the given name already exists, that entry is overwritten +func (c *Connector) RegisterCekProvider(name string, provider ColumnEncryptionKeyProvider) { + c.keyProviders[name] = &cekProvider{provider: provider, decryptedKeys: make(cekCache)} +} + type Conn struct { connector *Connector sess *tdsSession @@ -394,7 +401,7 @@ func (d *Driver) open(ctx context.Context, dsn string) (*Conn, error) { if err != nil { return nil, err } - c := &Connector{params: params} + c := newConnector(params, nil) return d.connect(ctx, c, params) } diff --git a/tds.go b/tds.go index 1a22b241..83336df2 100644 --- a/tds.go +++ b/tds.go @@ -172,7 +172,8 @@ type tdsSession struct { } type alwaysEncryptedSettings struct { - enclaveType string + enclaveType string + keyProviders columnEncryptionKeyProviderMap } const ( @@ -1154,11 +1155,17 @@ initiate_connection: outbuf := newTdsBuffer(packetSize, toconn) sess := tdsSession{ - buf: outbuf, - logger: logger, - logFlags: uint64(p.LogFlags), + buf: outbuf, + logger: logger, + logFlags: uint64(p.LogFlags), + aeSettings: &alwaysEncryptedSettings{keyProviders: make(columnEncryptionKeyProviderMap)}, + } + for i, p := range globalCekProviderFactoryMap { + sess.aeSettings.keyProviders[i] = p + } + for i, p := range c.keyProviders { + sess.aeSettings.keyProviders[i] = p } - fedAuth := &featureExtFedAuth{ FedAuthLibrary: FedAuthLibraryReserved, } diff --git a/token.go b/token.go index cc7f68ea..3c15e60d 100644 --- a/token.go +++ b/token.go @@ -12,6 +12,9 @@ import ( "github.com/golang-sql/sqlexp" "github.com/microsoft/go-mssqldb/msdsn" + "github.com/swisscom/mssql-always-encrypted/pkg/algorithms" + "github.com/swisscom/mssql-always-encrypted/pkg/encryption" + "github.com/swisscom/mssql-always-encrypted/pkg/keys" "golang.org/x/text/encoding/unicode" ) @@ -816,12 +819,24 @@ func (R RWCBuffer) Close() error { } func decryptColumn(column columnStruct, s *tdsSession, columnContent interface{}) tdsBuffer { - // Decrypt + encType := encryption.From(column.cryptoMeta.encType) cekValue := column.cryptoMeta.entry.cekValues[column.cryptoMeta.ordinal] s.logger.Log(context.Background(), msdsn.LogMessages, fmt.Sprintf("Decrypting column %s. Key path: %s, Key store:%s, Algo: %s", column.ColName, cekValue.keyPath, cekValue.keyStoreName, cekValue.algorithmName)) - // returning empty data for now - newBuff := make([]byte, 0) + cekProvider, ok := s.aeSettings.keyProviders[cekValue.keyStoreName] + if !ok { + panic(fmt.Errorf("Unable to find provider %s to decrypt CEK", cekValue.keyStoreName)) + } + cek := cekProvider.provider.DecryptColumnEncryptionKey(cekValue.keyPath, cekValue.algorithmName, column.cryptoMeta.entry.cekValues[0].encryptedKey) + k := keys.NewAeadAes256CbcHmac256(cek) + alg := algorithms.NewAeadAes256CbcHmac256Algorithm(k, encType, byte(cekValue.cekVersion)) + d, err := alg.Decrypt(columnContent.([]byte)) + if err != nil { + panic(err) + } + + var newBuff []byte + newBuff = append(newBuff, d...) rwc := RWCBuffer{ buffer: bytes.NewReader(newBuff), From fb7a0815495a5491a1c2bfce65fb02182b19667d Mon Sep 17 00:00:00 2001 From: davidshi Date: Wed, 14 Jun 2023 21:46:18 -0500 Subject: [PATCH 08/47] refactor packages to avoid cycle --- aecmk/keyprovider.go | 73 +++++++++++++++++++++ aecmk/localcert/keyprovider.go | 6 +- aecmk/localcert/keyprovider_windows.go | 7 +- aecmk/localcert/keyprovider_windows_test.go | 4 +- alwaysencrypted_windows_test.go | 3 + columnencryptionkey.go | 56 ---------------- internal/certs/certs.go | 2 +- internal/certs/certs_windows.go | 44 +++++-------- mssql.go | 9 +-- tds.go | 11 ++-- token.go | 4 +- 11 files changed, 113 insertions(+), 106 deletions(-) create mode 100644 aecmk/keyprovider.go diff --git a/aecmk/keyprovider.go b/aecmk/keyprovider.go new file mode 100644 index 00000000..69f784df --- /dev/null +++ b/aecmk/keyprovider.go @@ -0,0 +1,73 @@ +package aecmk + +import ( + "fmt" + "time" +) + +const ( + CertificateStoreKeyProvider = "MSSQL_CERTIFICATE_STORE" + CspKeyProvider = "MSSQL_CSP_PROVIDER" + CngKeyProvider = "MSSQL_CNG_STORE" + AzureKeyVaultKeyProvider = "AZURE_KEY_VAULT" + JavaKeyProvider = "MSSQL_JAVA_KEYSTORE" +) + +// ColumnEncryptionKeyLifetime is the default lifetime of decrypted Column Encryption Keys in the global cache. +// The default is 2 hours +var ColumnEncryptionKeyLifetime time.Duration = 2 * time.Hour + +type CekCacheEntry struct { + Expiry time.Time + Key []byte +} + +type CekCache map[string]CekCacheEntry + +type CekProvider struct { + Provider ColumnEncryptionKeyProvider + DecryptedKeys CekCache +} + +// no synchronization on this map. Providers register during init. +type ColumnEncryptionKeyProviderMap map[string]*CekProvider + +var globalCekProviderFactoryMap = ColumnEncryptionKeyProviderMap{} + +// ColumnEncryptionKeyProvider is the interface for decrypting and encrypting column encryption keys. +// It is similar to .Net https://learn.microsoft.com/dotnet/api/microsoft.data.sqlclient.sqlcolumnencryptionkeystoreprovider. +type ColumnEncryptionKeyProvider interface { + // DecryptColumnEncryptionKey decrypts the specified encrypted value of a column encryption key. + // The encrypted value is expected to be encrypted using the column master key with the specified key path and using the specified algorithm. + DecryptColumnEncryptionKey(masterKeyPath string, encryptionAlgorithm string, encryptedCek []byte) []byte + // EncryptColumnEncryptionKey encrypts a column encryption key using the column master key with the specified key path and using the specified algorithm. + EncryptColumnEncryptionKey(masterKeyPath string, encryptionAlgorithm string, cek []byte) []byte + // SignColumnMasterKeyMetadata digitally signs the column master key metadata with the column master key + // referenced by the masterKeyPath parameter. The input values used to generate the signature should be the + // specified values of the masterKeyPath and allowEnclaveComputations parameters. May return an empty slice if not supported. + SignColumnMasterKeyMetadata(masterKeyPath string, allowEnclaveComputations bool) []byte + // VerifyColumnMasterKeyMetadata verifies the specified signature is valid for the column master key + // with the specified key path and the specified enclave behavior. Return nil if not supported. + VerifyColumnMasterKeyMetadata(masterKeyPath string, allowEnclaveComputations bool) *bool + // KeyLifetime is an optional Duration. Keys fetched by this provider will be discarded after their lifetime expires. + // If it returns nil, the keys will expire based on the value of ColumnEncryptionKeyLifetime. + // If it returns zero, the keys will not be cached. + KeyLifetime() *time.Duration +} + +func RegisterCekProvider(name string, provider ColumnEncryptionKeyProvider) error { + _, ok := globalCekProviderFactoryMap[name] + if ok { + return fmt.Errorf("CEK provider %s is already registered", name) + } + globalCekProviderFactoryMap[name] = &CekProvider{Provider: provider, DecryptedKeys: CekCache{}} + return nil +} + +func GetGlobalCekProviders() (providers ColumnEncryptionKeyProviderMap) { + providers = make(ColumnEncryptionKeyProviderMap) + for i, p := range globalCekProviderFactoryMap { + providers[i] = p + } + return +} diff --git a/aecmk/localcert/keyprovider.go b/aecmk/localcert/keyprovider.go index e21e94cc..23453bdd 100644 --- a/aecmk/localcert/keyprovider.go +++ b/aecmk/localcert/keyprovider.go @@ -10,7 +10,7 @@ import ( "strconv" "time" - mssql "github.com/microsoft/go-mssqldb" + "github.com/microsoft/go-mssqldb/aecmk" ae "github.com/swisscom/mssql-always-encrypted/pkg" pkcs "software.sslmate.com/src/go-pkcs12" ) @@ -45,7 +45,7 @@ func (p LocalCertProvider) SetCertificatePassword(location string, password stri var PfxKeyProvider = LocalCertProvider{name: PfxKeyProviderName, passwords: make(map[string]string), AllowedLocations: make([]string, 0)} func init() { - mssql.RegisterCekProvider("pfx", &PfxKeyProvider) + aecmk.RegisterCekProvider("pfx", &PfxKeyProvider) } // DecryptColumnEncryptionKey decrypts the specified encrypted value of a column encryption key. @@ -70,7 +70,7 @@ func (p *LocalCertProvider) DecryptColumnEncryptionKey(masterKeyPath string, enc switch p.name { case PfxKeyProviderName: pk, cert = p.loadLocalCertificate(masterKeyPath) - case mssql.CertificateStoreKeyProvider: + case aecmk.CertificateStoreKeyProvider: pk, cert = p.loadWindowsCertStoreCertificate(masterKeyPath) default: return diff --git a/aecmk/localcert/keyprovider_windows.go b/aecmk/localcert/keyprovider_windows.go index 9b599f5a..cf03a397 100644 --- a/aecmk/localcert/keyprovider_windows.go +++ b/aecmk/localcert/keyprovider_windows.go @@ -6,16 +6,15 @@ import ( "strings" "unsafe" - mssql "github.com/microsoft/go-mssqldb" - + "github.com/microsoft/go-mssqldb/aecmk" "github.com/microsoft/go-mssqldb/internal/certs" "golang.org/x/sys/windows" ) -var WindowsCertificateStoreKeyProvider = LocalCertProvider{name: mssql.CertificateStoreKeyProvider, passwords: make(map[string]string)} +var WindowsCertificateStoreKeyProvider = LocalCertProvider{name: aecmk.CertificateStoreKeyProvider, passwords: make(map[string]string)} func init() { - mssql.RegisterCekProvider(mssql.CertificateStoreKeyProvider, &WindowsCertificateStoreKeyProvider) + aecmk.RegisterCekProvider(aecmk.CertificateStoreKeyProvider, &WindowsCertificateStoreKeyProvider) } func (p *LocalCertProvider) loadWindowsCertStoreCertificate(path string) (privateKey interface{}, cert *x509.Certificate) { diff --git a/aecmk/localcert/keyprovider_windows_test.go b/aecmk/localcert/keyprovider_windows_test.go index e67443b2..6e169564 100644 --- a/aecmk/localcert/keyprovider_windows_test.go +++ b/aecmk/localcert/keyprovider_windows_test.go @@ -5,7 +5,7 @@ import ( "strings" "testing" - mssql "github.com/microsoft/go-mssqldb" + "github.com/microsoft/go-mssqldb/aecmk" "github.com/microsoft/go-mssqldb/internal/certs" ) @@ -15,7 +15,7 @@ func TestLoadWindowsCertStoreCertificate(t *testing.T) { t.Fatal(err) } defer certs.DeleteMasterKeyCert(thumbprint) - provider := &LocalCertProvider{Name: mssql.AzureKeyVaultKeyProvider} + provider := aecmk.GetGlobalCekProviders()[aecmk.CertificateStoreKeyProvider].Provider.(*LocalCertProvider) pk, cert := provider.loadWindowsCertStoreCertificate("CurrentUser/My/" + thumbprint) switch z := pk.(type) { case *rsa.PrivateKey: diff --git a/alwaysencrypted_windows_test.go b/alwaysencrypted_windows_test.go index 07994e09..daf8fc56 100644 --- a/alwaysencrypted_windows_test.go +++ b/alwaysencrypted_windows_test.go @@ -4,6 +4,7 @@ import ( "fmt" "testing" + _ "github.com/microsoft/go-mssqldb/aecmk/localcert" "github.com/microsoft/go-mssqldb/internal/certs" ) @@ -26,6 +27,8 @@ func TestAlwaysEncryptedE2E(t *testing.T) { } defer conn.Exec(fmt.Sprintf(dropColumnMasterKey, certPath)) // TODO: Implement encryption and insert encrypted values into a table using custom CEK + // Currently this test only passes when run against a particular database whose + // columns are encrypted using a cert on a developer's machine. rows, err := conn.Query("select top (1) col1, col2 from Table_1") if err != nil { t.Fatalf("Unable to query encrypted columns: %s", err.Error()) diff --git a/columnencryptionkey.go b/columnencryptionkey.go index 5d694c09..257c1df0 100644 --- a/columnencryptionkey.go +++ b/columnencryptionkey.go @@ -1,10 +1,5 @@ package mssql -import ( - "fmt" - "time" -) - const ( CertificateStoreKeyProvider = "MSSQL_CERTIFICATE_STORE" CspKeyProvider = "MSSQL_CSP_PROVIDER" @@ -42,54 +37,3 @@ type cekTableEntry struct { func newCekTable(size uint16) cekTable { return cekTable{entries: make([]cekTableEntry, size)} } - -// ColumnEncryptionKeyLifetime is the default lifetime of decrypted Column Encryption Keys in the global cache. -// The default is 2 hours -var ColumnEncryptionKeyLifetime time.Duration = 2 * time.Hour - -type cekCacheEntry struct { - expiry time.Time - key []byte -} - -type cekCache map[string]cekCacheEntry - -type cekProvider struct { - provider ColumnEncryptionKeyProvider - decryptedKeys cekCache -} - -// no synchronization on this map. Providers register during init. -type columnEncryptionKeyProviderMap map[string]*cekProvider - -var globalCekProviderFactoryMap = columnEncryptionKeyProviderMap{} - -// ColumnEncryptionKeyProvider is the interface for decrypting and encrypting column encryption keys. -// It is similar to .Net https://learn.microsoft.com/dotnet/api/microsoft.data.sqlclient.sqlcolumnencryptionkeystoreprovider. -type ColumnEncryptionKeyProvider interface { - // DecryptColumnEncryptionKey decrypts the specified encrypted value of a column encryption key. - // The encrypted value is expected to be encrypted using the column master key with the specified key path and using the specified algorithm. - DecryptColumnEncryptionKey(masterKeyPath string, encryptionAlgorithm string, encryptedCek []byte) []byte - // EncryptColumnEncryptionKey encrypts a column encryption key using the column master key with the specified key path and using the specified algorithm. - EncryptColumnEncryptionKey(masterKeyPath string, encryptionAlgorithm string, cek []byte) []byte - // SignColumnMasterKeyMetadata digitally signs the column master key metadata with the column master key - // referenced by the masterKeyPath parameter. The input values used to generate the signature should be the - // specified values of the masterKeyPath and allowEnclaveComputations parameters. May return an empty slice if not supported. - SignColumnMasterKeyMetadata(masterKeyPath string, allowEnclaveComputations bool) []byte - // VerifyColumnMasterKeyMetadata verifies the specified signature is valid for the column master key - // with the specified key path and the specified enclave behavior. Return nil if not supported. - VerifyColumnMasterKeyMetadata(masterKeyPath string, allowEnclaveComputations bool) *bool - // KeyLifetime is an optional Duration. Keys fetched by this provider will be discarded after their lifetime expires. - // If it returns nil, the keys will expire based on the value of ColumnEncryptionKeyLifetime. - // If it returns zero, the keys will not be cached. - KeyLifetime() *time.Duration -} - -func RegisterCekProvider(name string, provider ColumnEncryptionKeyProvider) error { - _, ok := globalCekProviderFactoryMap[name] - if ok { - return fmt.Errorf("CEK provider %s is already registered", name) - } - globalCekProviderFactoryMap[name] = &cekProvider{provider: provider, decryptedKeys: make(cekCache)} - return nil -} diff --git a/internal/certs/certs.go b/internal/certs/certs.go index 9ddbc519..24e372a7 100644 --- a/internal/certs/certs.go +++ b/internal/certs/certs.go @@ -10,7 +10,7 @@ import ( ) const ( - createUserCertScript = `New-SelfSignedCertificate -Subject "%s" -CertStoreLocation Cert:CurrentUser\My -KeyExportPolicy Exportable -Type DocumentEncryptionCert -KeyUsage KeyEncipherment -Keyspec KeyExchange -KeyLength 2048 | select {$_.Thumbprint}` + createUserCertScript = `New-SelfSignedCertificate -Subject "%s" -CertStoreLocation Cert:CurrentUser\My -KeyExportPolicy Exportable -Type DocumentEncryptionCert -KeyUsage KeyEncipherment -Keyspec KeyExchange -KeyLength 2048 -HashAlgorithm 'SHA256' | select {$_.Thumbprint}` deleteUserCertScript = `Get-ChildItem Cert:\CurrentUser\My\%s | Remove-Item -DeleteKey` ) diff --git a/internal/certs/certs_windows.go b/internal/certs/certs_windows.go index d3ec2415..757db289 100644 --- a/internal/certs/certs_windows.go +++ b/internal/certs/certs_windows.go @@ -8,9 +8,7 @@ import ( "errors" "fmt" "math/big" - "reflect" - "syscall" "unsafe" "golang.org/x/sys/windows" @@ -18,35 +16,25 @@ import ( func FindCertBySignatureHash(storeHandle windows.Handle, hash []byte) (interface{}, *x509.Certificate) { var certContext *windows.CertContext - var prevCertContext *windows.CertContext var err error cryptoAPIBlob := windows.CryptHashBlob{ Size: uint32(len(hash)), Data: &hash[0], } - for { - certContext, err = windows.CertFindCertificateInStore( - storeHandle, - windows.X509_ASN_ENCODING|windows.PKCS_7_ASN_ENCODING, - 0, - windows.CERT_FIND_HASH, - unsafe.Pointer(&cryptoAPIBlob), - prevCertContext) - if certContext == nil || err != nil { - break - } - prevCertContext = certContext - } - - if prevCertContext == nil { - if err == nil { - err = syscall.GetLastError() - } + certContext, err = windows.CertFindCertificateInStore( + storeHandle, + windows.X509_ASN_ENCODING|windows.PKCS_7_ASN_ENCODING, + 0, + windows.CERT_FIND_HASH, + unsafe.Pointer(&cryptoAPIBlob), + nil) + + if err != nil { + panic(fmt.Errorf("Unable to find certificate by signature hash. %s", err.Error())) } - - pk, cert, err := certContextToX509(prevCertContext) + pk, cert, err := certContextToX509(certContext) if err != nil { panic(err) } @@ -55,11 +43,11 @@ func FindCertBySignatureHash(storeHandle windows.Handle, hash []byte) (interface } func certContextToX509(ctx *windows.CertContext) (pk interface{}, cert *x509.Certificate, err error) { - var der []byte - slice := (*reflect.SliceHeader)(unsafe.Pointer(&der)) - slice.Data = uintptr(unsafe.Pointer(ctx.EncodedCert)) - slice.Len = int(ctx.Length) - slice.Cap = int(ctx.Length) + // To ensure we don't mess with the cert context's memory, use a copy of it. + src := (*[1 << 20]byte)(unsafe.Pointer(ctx.EncodedCert))[:ctx.Length:ctx.Length] + der := make([]byte, int(ctx.Length)) + copy(der, src) + cert, err = x509.ParseCertificate(der) if err != nil { return diff --git a/mssql.go b/mssql.go index 0771c139..a9316e8d 100644 --- a/mssql.go +++ b/mssql.go @@ -17,6 +17,7 @@ import ( "unicode" "github.com/golang-sql/sqlexp" + "github.com/microsoft/go-mssqldb/aecmk" "github.com/microsoft/go-mssqldb/internal/querytext" "github.com/microsoft/go-mssqldb/msdsn" ) @@ -148,7 +149,7 @@ func newConnector(config msdsn.Config, driver *Driver) *Connector { return &Connector{ params: config, driver: driver, - keyProviders: make(columnEncryptionKeyProviderMap), + keyProviders: make(aecmk.ColumnEncryptionKeyProviderMap), } } @@ -198,7 +199,7 @@ type Connector struct { // If Dialer is not set, normal net dialers are used. Dialer Dialer - keyProviders columnEncryptionKeyProviderMap + keyProviders aecmk.ColumnEncryptionKeyProviderMap } type Dialer interface { @@ -213,8 +214,8 @@ func (c *Connector) getDialer(p *msdsn.Config) Dialer { } // RegisterCekProvider associated the given provider with the named key store. If an entry of the given name already exists, that entry is overwritten -func (c *Connector) RegisterCekProvider(name string, provider ColumnEncryptionKeyProvider) { - c.keyProviders[name] = &cekProvider{provider: provider, decryptedKeys: make(cekCache)} +func (c *Connector) RegisterCekProvider(name string, provider aecmk.ColumnEncryptionKeyProvider) { + c.keyProviders[name] = &aecmk.CekProvider{Provider: provider, DecryptedKeys: make(aecmk.CekCache)} } type Conn struct { diff --git a/tds.go b/tds.go index 83336df2..772c98de 100644 --- a/tds.go +++ b/tds.go @@ -15,6 +15,7 @@ import ( "unicode/utf16" "unicode/utf8" + "github.com/microsoft/go-mssqldb/aecmk" "github.com/microsoft/go-mssqldb/integratedauth" "github.com/microsoft/go-mssqldb/msdsn" ) @@ -173,7 +174,7 @@ type tdsSession struct { type alwaysEncryptedSettings struct { enclaveType string - keyProviders columnEncryptionKeyProviderMap + keyProviders aecmk.ColumnEncryptionKeyProviderMap } const ( @@ -1158,11 +1159,10 @@ initiate_connection: buf: outbuf, logger: logger, logFlags: uint64(p.LogFlags), - aeSettings: &alwaysEncryptedSettings{keyProviders: make(columnEncryptionKeyProviderMap)}, - } - for i, p := range globalCekProviderFactoryMap { - sess.aeSettings.keyProviders[i] = p + aeSettings: &alwaysEncryptedSettings{keyProviders: make(aecmk.ColumnEncryptionKeyProviderMap)}, } + sess.aeSettings.keyProviders = aecmk.GetGlobalCekProviders() + for i, p := range c.keyProviders { sess.aeSettings.keyProviders[i] = p } @@ -1319,7 +1319,6 @@ initiate_connection: case colAckStruct: if v.Version <= 2 && v.Version > 0 { sess.alwaysEncrypted = true - sess.aeSettings = &alwaysEncryptedSettings{} if len(v.EnclaveType) > 0 { sess.aeSettings.enclaveType = string(v.EnclaveType) } diff --git a/token.go b/token.go index 3c15e60d..15f2fda1 100644 --- a/token.go +++ b/token.go @@ -827,7 +827,7 @@ func decryptColumn(column columnStruct, s *tdsSession, columnContent interface{} if !ok { panic(fmt.Errorf("Unable to find provider %s to decrypt CEK", cekValue.keyStoreName)) } - cek := cekProvider.provider.DecryptColumnEncryptionKey(cekValue.keyPath, cekValue.algorithmName, column.cryptoMeta.entry.cekValues[0].encryptedKey) + cek := cekProvider.Provider.DecryptColumnEncryptionKey(cekValue.keyPath, cekValue.algorithmName, column.cryptoMeta.entry.cekValues[0].encryptedKey) k := keys.NewAeadAes256CbcHmac256(cek) alg := algorithms.NewAeadAes256CbcHmac256Algorithm(k, encType, byte(cekValue.cekVersion)) d, err := alg.Decrypt(columnContent.([]byte)) @@ -842,7 +842,7 @@ func decryptColumn(column columnStruct, s *tdsSession, columnContent interface{} buffer: bytes.NewReader(newBuff), } - column.cryptoMeta.typeInfo.Buffer = make([]byte, 0) + column.cryptoMeta.typeInfo.Buffer = d buffer := tdsBuffer{rpos: 0, rsize: len(newBuff), rbuf: newBuff, transport: rwc} return buffer } From 3083f568390859ba78d84d75fb47d128363a7016 Mon Sep 17 00:00:00 2001 From: davidshi Date: Thu, 22 Jun 2023 09:52:22 -0500 Subject: [PATCH 09/47] initial code for AE result set query --- encrypt.go | 55 +++++++++++++++++++++++++++++++++++++++ encrypt_test.go | 69 +++++++++++++++++++++++++++++++++++++++++++++++++ quoter.go | 40 ++++++++++++++++++++++++++++ 3 files changed, 164 insertions(+) create mode 100644 encrypt.go create mode 100644 encrypt_test.go create mode 100644 quoter.go diff --git a/encrypt.go b/encrypt.go new file mode 100644 index 00000000..2717c731 --- /dev/null +++ b/encrypt.go @@ -0,0 +1,55 @@ +package mssql + +import ( + "context" + "strings" +) + +// when Always Encrypted is turned on, we have to ask the server for metadata about how to encrypt input parameters. +func (s *Stmt) encryptArgs(ctx context.Context, args []namedValue) (encryptedArgs []namedValue, err error) { + // q := Stmt{c:s.c, + // paramCount:s.paramCount, + // query: "sp_describe_parameter_encryption", + // } + return args, nil +} + +func prepareEncryptionQuery(isProc bool, q string, args []namedValue) (query string, err error) { + return "", nil +} + +// Based on the .Net implementation at https://github.com/dotnet/SqlClient/blob/2b31810ce69b88d707450e2059ee8fbde63f774f/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs#L6040 +func buildStoredProcedureStatementForColumnEncryption(sproc string, args []namedValue) string { + b := new(strings.Builder) + _, _ = b.WriteString("EXEC ") + q := TSQLQuoter{} + sproc = q.ID(sproc) + + b.WriteString(sproc) + + // Unlike ADO.Net, go-mssqldb doesn't support ReturnValue named parameters + first := true + for _, a := range args { + if !first { + b.WriteRune(',') + } + first = false + b.WriteRune(' ') + appendPrefixedParameterName(b, a.Name) + b.WriteRune('=') + appendPrefixedParameterName(b, a.Name) + if isOutputValue(a.Value) { + b.WriteString(" OUTPUT") + } + } + return b.String() +} + +func appendPrefixedParameterName(b *strings.Builder, p string) { + if len(p) > 0 { + if p[0] != '@' { + b.WriteRune('@') + } + b.WriteString(p) + } +} diff --git a/encrypt_test.go b/encrypt_test.go new file mode 100644 index 00000000..06d1cb9a --- /dev/null +++ b/encrypt_test.go @@ -0,0 +1,69 @@ +package mssql + +import ( + "database/sql" + "testing" +) + +func TestSprocQueryForCE(t *testing.T) { + type test struct { + name string + proc string + args []namedValue + expected string + } + var out int + tests := []test{ + { + "Empty args", + "m]yproc", + make([]namedValue, 0), + "EXEC [m]]yproc]", + }, + { + "No OUT args", + "myproc", + []namedValue{ + { + "p1", + 0, + 5, + }, + { + "@p2", + 0, + "val", + }, + }, + "EXEC [myproc] @p1=@p1, @p2=@p2", + }, + { + "OUT args", + "myproc", + []namedValue{ + { + "pout", + 0, + sql.Out{ + Dest: &out, + In: false, + }, + }, + { + "pin", + 1, + "in", + }, + }, + "EXEC [myproc] @pout=@pout OUTPUT, @pin=@pin", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + q := buildStoredProcedureStatementForColumnEncryption(tc.proc, tc.args) + if q != tc.expected { + t.Fatalf("Incorrect query for %s: %s", tc.name, q) + } + }) + } +} diff --git a/quoter.go b/quoter.go new file mode 100644 index 00000000..1f8f4f38 --- /dev/null +++ b/quoter.go @@ -0,0 +1,40 @@ +package mssql + +import ( + "strings" +) + +// TSQLQuoter implements sqlexp.Quoter +type TSQLQuoter struct { +} + +// ID quotes identifiers such as schema, table, or column names. +// This implementation handles multi-part names. +func (TSQLQuoter) ID(name string) string { + return "[" + strings.Replace(name, "]", "]]", -1) + "]" +} + +// Value quotes database values such as string or []byte types as strings +// that are suitable and safe to embed in SQL text. The returned value +// of a string will include all surrounding quotes. +// +// If a value type is not supported it must panic. +func (TSQLQuoter) Value(v interface{}) string { + switch v := v.(type) { + default: + panic("unsupported value") + + case string: + return sqlString(v) + case VarChar: + return sqlString(string(v)) + case VarCharMax: + return sqlString(string(v)) + case NVarCharMax: + return sqlString(string(v)) + } +} + +func sqlString(v string) string { + return "'" + strings.Replace(string(v), "'", "''", -1) + "'" +} From fc53c14a6db3d836e53adbb42757172d6e9be8af Mon Sep 17 00:00:00 2001 From: davidshi Date: Wed, 28 Jun 2023 15:16:37 -0500 Subject: [PATCH 10/47] skeleton of parameter encryption --- encrypt.go | 137 +++++++++++++++++++++++++++++++++++++++++++++--- encrypt_test.go | 46 ++++++++++++++++ mssql.go | 3 ++ 3 files changed, 180 insertions(+), 6 deletions(-) diff --git a/encrypt.go b/encrypt.go index 2717c731..b596062b 100644 --- a/encrypt.go +++ b/encrypt.go @@ -2,20 +2,93 @@ package mssql import ( "context" + "database/sql/driver" + "fmt" + "io" "strings" ) +type ColumnEncryptionType int + +var ( + ColumnEncryptionPlainText ColumnEncryptionType = 0 + ColumnEncryptionDeterministic ColumnEncryptionType = 1 + ColumnEncryptionRandomized ColumnEncryptionType = 1 +) + +type cekData struct { + ordinal int + database_id int + id int + version int + metadataVersion []byte + encryptedValue []byte + cmkStoreName string + cmkPath string + algorithm string + byEnclave bool + cmkSignature string +} + +type parameterEncData struct { + ordinal int + name string + algorithm int + encType ColumnEncryptionType + cekOrdinal int + ruleVersion int +} + // when Always Encrypted is turned on, we have to ask the server for metadata about how to encrypt input parameters. func (s *Stmt) encryptArgs(ctx context.Context, args []namedValue) (encryptedArgs []namedValue, err error) { - // q := Stmt{c:s.c, - // paramCount:s.paramCount, - // query: "sp_describe_parameter_encryption", - // } + q := Stmt{c: s.c, + paramCount: s.paramCount, + query: "sp_describe_parameter_encryption", + } + newArgs, err := s.prepareEncryptionQuery(isProc(s.query), s.query, args) + if err != nil { + return + } + rows, err := q.queryContext(ctx, newArgs) + if err != nil { + return + } + cekInfo, paramsInfo, err := processDescribeParameterEncryption(rows) + if err != nil { + return + } + fmt.Printf("cekInfo: %v\nparamsInfo:%v\n", cekInfo, paramsInfo) return args, nil } -func prepareEncryptionQuery(isProc bool, q string, args []namedValue) (query string, err error) { - return "", nil +// returns the arguments to sp_describe_parameter_encryption +// sp_describe_parameter_encryption +// [ @tsql = ] N'Transact-SQL_batch' , +// [ @params = ] N'parameters' +// [ ;] +func (s *Stmt) prepareEncryptionQuery(isProc bool, q string, args []namedValue) (newArgs []namedValue, err error) { + if isProc { + newArgs = make([]namedValue, 1) + newArgs[0] = namedValue{Name: "tsql", Ordinal: 0, Value: buildStoredProcedureStatementForColumnEncryption(q, args)} + return + } + newArgs = make([]namedValue, 2) + newArgs[0] = namedValue{Name: "tsql", Ordinal: 0, Value: q} + params, err := s.buildParametersForColumnEncryption(args) + if err != nil { + return + } + newArgs[1] = namedValue{Name: "params", Ordinal: 1, Value: params} + return +} + +func (s *Stmt) buildParametersForColumnEncryption(args []namedValue) (parameters string, err error) { + _, decls, err := s.makeRPCParams(args, false) + if err != nil { + return + } + parameters = strings.Join(decls, ", ") + return } // Based on the .Net implementation at https://github.com/dotnet/SqlClient/blob/2b31810ce69b88d707450e2059ee8fbde63f774f/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs#L6040 @@ -53,3 +126,55 @@ func appendPrefixedParameterName(b *strings.Builder, p string) { b.WriteString(p) } } + +func processDescribeParameterEncryption(rows driver.Rows) (cekInfo []cekData, paramInfo []parameterEncData, err error) { + cekInfo = make([]cekData, 0) + values := make([]driver.Value, 9) + qerr := rows.Next(values) + for qerr == nil { + cekInfo = append(cekInfo, cekData{ordinal: int(values[0].(int64)), + database_id: int(values[1].(int64)), + id: int(values[2].(int64)), + version: int(values[3].(int64)), + metadataVersion: values[4].([]byte), + encryptedValue: values[5].([]byte), + cmkStoreName: values[6].(string), + cmkPath: values[7].(string), + algorithm: values[8].(string), + }) + qerr = rows.Next(values) + } + if len(cekInfo) == 0 || qerr != io.EOF { + if qerr != io.EOF { + err = qerr + } else { + err = fmt.Errorf("No column encryption key rows were returned from sp_describe_parameter_encryption") + } + return + } + r := rows.(driver.RowsNextResultSet) + err = r.NextResultSet() + if err != nil { + return + } + paramInfo = make([]parameterEncData, 0) + qerr = rows.Next(values[:6]) + for qerr == nil { + paramInfo = append(paramInfo, parameterEncData{ordinal: int(values[0].(int64)), + name: values[1].(string), + algorithm: int(values[2].(int64)), + encType: ColumnEncryptionType(values[3].(int64)), + cekOrdinal: int(values[4].(int64)), + ruleVersion: int(values[5].(int64)), + }) + qerr = rows.Next(values[:6]) + } + if len(paramInfo) == 0 || qerr != io.EOF { + if qerr != io.EOF { + err = qerr + } else { + err = fmt.Errorf("No parameter encryption rows were returned from sp_describe_parameter_encryption") + } + } + return +} diff --git a/encrypt_test.go b/encrypt_test.go index 06d1cb9a..7bd86450 100644 --- a/encrypt_test.go +++ b/encrypt_test.go @@ -2,9 +2,55 @@ package mssql import ( "database/sql" + "strings" "testing" ) +func TestBuildQueryParametersForCE(t *testing.T) { + type test struct { + name string + args []namedValue + expectedParams string + expectedError string + } + var outparam string + var tests = []test{ + { + "Single string", + []namedValue{ + {Name: "c1", Value: "somestring"}, + }, + `@c1 nvarchar(10)`, + "", + }, + { + "Input and Output params", + []namedValue{ + {Name: "", Ordinal: 0, Value: VarChar("somestring")}, + {Name: "c1", Value: 5}, + {Name: "pout", Value: sql.Out{Dest: outparam}}, + }, + `@p0 varchar(10), @c1 bigint, @pout nvarchar(max) output`, + "", + }, + } + s := &Stmt{} + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + actual, err := s.buildParametersForColumnEncryption(tc.args) + if len(tc.expectedError) > 0 { + if err == nil || strings.Compare(err.Error(), tc.expectedError) != 0 { + t.Fatalf("buildParametersForColumnEncryption should have failed with %s. Got: %v", tc.expectedError, err) + } + } else if err != nil { + t.Fatalf("buildParametersForColumnEncryption failed with %s", err.Error()) + } + if strings.Compare(tc.expectedParams, actual) != 0 { + t.Fatalf("Incorrect parameters. Expected: %s. Got: %s ", tc.expectedParams, actual) + } + }) + } +} func TestSprocQueryForCE(t *testing.T) { type test struct { name string diff --git a/mssql.go b/mssql.go index a9316e8d..b5b36eb5 100644 --- a/mssql.go +++ b/mssql.go @@ -686,6 +686,9 @@ func (s *Stmt) queryContext(ctx context.Context, args []namedValue) (rows driver if !s.c.connectionGood { return nil, driver.ErrBadConn } + if s.c.sess.alwaysEncrypted && len(args) > 0 { + args, err = s.encryptArgs(ctx, args) + } if err = s.sendQuery(ctx, args); err != nil { return nil, s.c.checkBadConn(ctx, err, true) } From ff797ce48c825a971792cfccf1e8bede4ce51748 Mon Sep 17 00:00:00 2001 From: davidshi Date: Fri, 30 Jun 2023 11:31:05 -0500 Subject: [PATCH 11/47] implement EncryptColumnEncryptionKey for local cert --- aecmk/localcert/keyprovider.go | 98 ++++++++++++++++++--- aecmk/localcert/keyprovider_windows_test.go | 16 ++++ 2 files changed, 101 insertions(+), 13 deletions(-) diff --git a/aecmk/localcert/keyprovider.go b/aecmk/localcert/keyprovider.go index 23453bdd..cb76205c 100644 --- a/aecmk/localcert/keyprovider.go +++ b/aecmk/localcert/keyprovider.go @@ -1,17 +1,23 @@ package localcert import ( + "crypto" + "crypto/rand" "crypto/rsa" "crypto/sha1" + "crypto/sha256" "crypto/x509" + "encoding/binary" "fmt" "io/ioutil" "os" "strconv" + "strings" "time" "github.com/microsoft/go-mssqldb/aecmk" ae "github.com/swisscom/mssql-always-encrypted/pkg" + "golang.org/x/text/encoding/unicode" pkcs "software.sslmate.com/src/go-pkcs12" ) @@ -52,21 +58,10 @@ func init() { // The encrypted value is expected to be encrypted using the column master key with the specified key path and using the specified algorithm. func (p *LocalCertProvider) DecryptColumnEncryptionKey(masterKeyPath string, encryptionAlgorithm string, encryptedCek []byte) (decryptedKey []byte) { decryptedKey = nil - allowed := len(p.AllowedLocations) == 0 - if !allowed { - loop: - for _, l := range p.AllowedLocations { - if l == masterKeyPath { - allowed = true - break loop - } - } - } + pk, cert, allowed := p.tryLoadCertificate(masterKeyPath) if !allowed { return } - var cert *x509.Certificate - var pk interface{} switch p.name { case PfxKeyProviderName: pk, cert = p.loadLocalCertificate(masterKeyPath) @@ -87,6 +82,29 @@ func (p *LocalCertProvider) DecryptColumnEncryptionKey(masterKeyPath string, enc return } +func (p *LocalCertProvider) tryLoadCertificate(masterKeyPath string) (privateKey interface{}, cert *x509.Certificate, allowed bool) { + allowed = len(p.AllowedLocations) == 0 + if !allowed { + loop: + for _, l := range p.AllowedLocations { + if l == masterKeyPath { + allowed = true + break loop + } + } + } + if !allowed { + return + } + switch p.name { + case PfxKeyProviderName: + privateKey, cert = p.loadLocalCertificate(masterKeyPath) + case aecmk.CertificateStoreKeyProvider: + privateKey, cert = p.loadWindowsCertStoreCertificate(masterKeyPath) + } + return +} + func (p *LocalCertProvider) loadLocalCertificate(path string) (privateKey interface{}, cert *x509.Certificate) { if f, err := os.Open(path); err == nil { pfxBytes, err := ioutil.ReadAll(f) @@ -112,7 +130,49 @@ func (p *LocalCertProvider) loadLocalCertificate(path string) (privateKey interf // EncryptColumnEncryptionKey encrypts a column encryption key using the column master key with the specified key path and using the specified algorithm. func (p *LocalCertProvider) EncryptColumnEncryptionKey(masterKeyPath string, encryptionAlgorithm string, cek []byte) []byte { - return nil + + validateEncryptionAlgorithm(encryptionAlgorithm) + validateKeyPathLength(masterKeyPath) + pk, cert, allowed := p.tryLoadCertificate(masterKeyPath) + if !allowed { + panic(fmt.Errorf("Key path not allowed for use in column key encryption")) + } + publicKey := cert.PublicKey.(*rsa.PublicKey) + keySizeInBytes := publicKey.Size() + + enc := unicode.UTF16(unicode.LittleEndian, unicode.IgnoreBOM).NewEncoder() + // Start with version byte == 1 + buf := []byte{byte(1)} + // EncryptedColumnEncryptionKey = version + keyPathLength + ciphertextLength + keyPath + ciphertext + signature + // version + keyPathBytes, err := enc.Bytes([]byte(strings.ToLower(masterKeyPath))) + if err != nil { + panic(fmt.Errorf("Unable to serialize key path %w", err)) + } + // keyPathLength + buf = binary.LittleEndian.AppendUint16(buf, uint16(len(keyPathBytes))) + + cipherText, err := rsa.EncryptOAEP(sha1.New(), rand.Reader, publicKey, cek, []byte{}) + if err != nil { + panic(fmt.Errorf("Unable to encrypt data %w", err)) + } + // ciphertextLength + buf = binary.LittleEndian.AppendUint16(buf, uint16(len(cipherText))) + // keypath + buf = append(buf, keyPathBytes...) + // ciphertext + buf = append(buf, cipherText...) + hash := sha256.Sum256(buf) + // signature is the signed hash of the current buf + sig, err := rsa.SignPKCS1v15(rand.Reader, pk.(*rsa.PrivateKey), crypto.SHA256, hash[:]) + if err != nil { + panic(err) + } + if len(sig) != keySizeInBytes { + panic("Signature length doesn't match certificate key size") + } + buf = append(buf, sig...) + return buf } // SignColumnMasterKeyMetadata digitally signs the column master key metadata with the column master key @@ -135,6 +195,18 @@ func (p *LocalCertProvider) KeyLifetime() *time.Duration { return nil } +func validateEncryptionAlgorithm(encryptionAlgorithm string) { + if !strings.EqualFold(encryptionAlgorithm, "RSA_OAEP") { + panic(fmt.Errorf("Unsupported encryption algorithm %s", encryptionAlgorithm)) + } +} + +func validateKeyPathLength(keyPath string) { + if len(keyPath) > 32767 { + panic(fmt.Errorf("Key path is too long. %d", len(keyPath))) + } +} + // InvalidCertificatePathError indicates the provided path could not be used to load a certificate type InvalidCertificatePathError struct { path string diff --git a/aecmk/localcert/keyprovider_windows_test.go b/aecmk/localcert/keyprovider_windows_test.go index 6e169564..95932f03 100644 --- a/aecmk/localcert/keyprovider_windows_test.go +++ b/aecmk/localcert/keyprovider_windows_test.go @@ -28,3 +28,19 @@ func TestLoadWindowsCertStoreCertificate(t *testing.T) { t.Fatalf("Wrong cert loaded: %s", cert.Subject.String()) } } + +func TestEncryptDecryptEncryptionKeyRoundTrip(t *testing.T) { + thumbprint, err := certs.ProvisionMasterKeyInCertStore() + if err != nil { + t.Fatal(err) + } + defer certs.DeleteMasterKeyCert(thumbprint) + bytesToEncrypt := []byte{1, 2, 3} + keyPath := "CurrentUser/My/" + thumbprint + provider := aecmk.GetGlobalCekProviders()[aecmk.CertificateStoreKeyProvider].Provider.(*LocalCertProvider) + encryptedBytes := provider.EncryptColumnEncryptionKey(keyPath, "RSA_OAEP", bytesToEncrypt) + decryptedBytes := provider.DecryptColumnEncryptionKey(keyPath, "RSA_OAEP", encryptedBytes) + if len(decryptedBytes) != 3 || decryptedBytes[0] != 1 || decryptedBytes[1] != 2 || decryptedBytes[2] != 3 { + t.Fatalf("Encrypt/Decrypt did not roundtrip. encryptedBytes:%v, decryptedBytes: %v", encryptedBytes, decryptedBytes) + } +} From 2e3ec3fbd8d9a13f6a180b7d2aacf0ce7c643bb2 Mon Sep 17 00:00:00 2001 From: davidshi Date: Fri, 30 Jun 2023 14:41:45 -0500 Subject: [PATCH 12/47] fix query for param encryption data --- alwaysencrypted_windows_test.go | 46 +++++++++++++++++++++++++++------ encrypt.go | 6 +++-- mssql.go | 28 +++++++++++++++----- mssql_go19.go | 14 +++++++++- 4 files changed, 76 insertions(+), 18 deletions(-) diff --git a/alwaysencrypted_windows_test.go b/alwaysencrypted_windows_test.go index daf8fc56..28118ce5 100644 --- a/alwaysencrypted_windows_test.go +++ b/alwaysencrypted_windows_test.go @@ -1,10 +1,12 @@ package mssql import ( + "crypto/rand" "fmt" + "math/big" "testing" - _ "github.com/microsoft/go-mssqldb/aecmk/localcert" + "github.com/microsoft/go-mssqldb/aecmk/localcert" "github.com/microsoft/go-mssqldb/internal/certs" ) @@ -26,12 +28,28 @@ func TestAlwaysEncryptedE2E(t *testing.T) { t.Fatalf("Unable to create CMK: %s", err.Error()) } defer conn.Exec(fmt.Sprintf(dropColumnMasterKey, certPath)) - // TODO: Implement encryption and insert encrypted values into a table using custom CEK - // Currently this test only passes when run against a particular database whose - // columns are encrypted using a cert on a developer's machine. - rows, err := conn.Query("select top (1) col1, col2 from Table_1") + r, _ := rand.Int(rand.Reader, big.NewInt(1000)) + cekName := fmt.Sprintf("mssqlCek%d", r.Int64()) + encryptedCek := localcert.WindowsCertificateStoreKeyProvider.EncryptColumnEncryptionKey(certPath, "RSA_OAEP", []byte(certPath)) + createCek := fmt.Sprintf(createColumnEncryptionKey, cekName, certPath, encryptedCek) + _, err = conn.Exec(createCek) if err != nil { - t.Fatalf("Unable to query encrypted columns: %s", err.Error()) + t.Fatalf("Unable to create CEK: %s", err.Error()) + } + defer conn.Exec(fmt.Sprintf(dropColumnEncryptionKey, cekName)) + _, _ = conn.Exec("DROP TABLE IF EXISTS mssqlAlwaysEncrypted") + _, err = conn.Exec(fmt.Sprintf(createEncryptedTable, cekName, cekName)) + if err != nil { + t.Fatalf("Failed to create encrypted table %s", err.Error()) + } + defer conn.Exec("DROP TABLE IF EXISTS mssqlAlwaysEncrypted") + _, err = conn.Exec("INSERT INTO mssqlAlwaysEncrypted VALUES (@p1, @p2)", int32(1), NChar("mycol2")) + if err != nil { + t.Fatalf("Failed to insert row in encrypted table %s", err.Error()) + } + rows, err := conn.Query("select top (1) col1, col2 from mssqlAlwaysEncrypted") + if err != nil { + t.Fatalf("Unable to query encrypted columns: %v", err.(Error).All) } if !rows.Next() { rows.Close() @@ -52,6 +70,18 @@ func TestAlwaysEncryptedE2E(t *testing.T) { } const ( - createColumnMasterKey = `CREATE COLUMN MASTER KEY [%s] WITH (KEY_STORE_PROVIDER_NAME= 'MSSQL_CERTIFICATE_STORE', KEY_PATH='%s')` - dropColumnMasterKey = `DROP COLUMN MASTER KEY [%s]` + createColumnMasterKey = `CREATE COLUMN MASTER KEY [%s] WITH (KEY_STORE_PROVIDER_NAME= 'MSSQL_CERTIFICATE_STORE', KEY_PATH='%s')` + dropColumnMasterKey = `DROP COLUMN MASTER KEY [%s]` + createColumnEncryptionKey = `CREATE COLUMN ENCRYPTION KEY [%s] WITH VALUES (COLUMN_MASTER_KEY = [%s], ALGORITHM = 'RSA_OAEP', ENCRYPTED_VALUE = 0x%x )` + dropColumnEncryptionKey = `DROP COLUMN ENCRYPTION KEY [%s]` + createEncryptedTable = `CREATE TABLE mssqlAlwaysEncrypted + (col1 int + ENCRYPTED WITH (ENCRYPTION_TYPE = RANDOMIZED, + ALGORITHM = 'AEAD_AES_256_CBC_HMAC_SHA_256', + COLUMN_ENCRYPTION_KEY = [%s]), + col2 nchar(10) COLLATE Latin1_General_BIN2 + ENCRYPTED WITH (ENCRYPTION_TYPE = DETERMINISTIC, + ALGORITHM = 'AEAD_AES_256_CBC_HMAC_SHA_256', + COLUMN_ENCRYPTION_KEY = [%s]) + )` ) diff --git a/encrypt.go b/encrypt.go index b596062b..8853223b 100644 --- a/encrypt.go +++ b/encrypt.go @@ -42,13 +42,15 @@ type parameterEncData struct { // when Always Encrypted is turned on, we have to ask the server for metadata about how to encrypt input parameters. func (s *Stmt) encryptArgs(ctx context.Context, args []namedValue) (encryptedArgs []namedValue, err error) { q := Stmt{c: s.c, - paramCount: s.paramCount, - query: "sp_describe_parameter_encryption", + paramCount: s.paramCount, + query: "sp_describe_parameter_encryption", + skipEncryption: true, } newArgs, err := s.prepareEncryptionQuery(isProc(s.query), s.query, args) if err != nil { return } + // TODO: Consider not using recursion rows, err := q.queryContext(ctx, newArgs) if err != nil { return diff --git a/mssql.go b/mssql.go index b5b36eb5..0ad3d4a1 100644 --- a/mssql.go +++ b/mssql.go @@ -444,10 +444,11 @@ func (c *Conn) Close() error { } type Stmt struct { - c *Conn - query string - paramCount int - notifSub *queryNotifSub + c *Conn + query string + paramCount int + notifSub *queryNotifSub + skipEncryption bool } type queryNotifSub struct { @@ -471,7 +472,7 @@ func (c *Conn) prepareContext(ctx context.Context, query string) (*Stmt, error) if c.processQueryText { query, paramCount = querytext.ParseParams(query) } - return &Stmt{c, query, paramCount, nil}, nil + return &Stmt{c, query, paramCount, nil, false}, nil } func (s *Stmt) Close() error { @@ -676,6 +677,10 @@ func convertOldArgs(args []driver.Value) []namedValue { return list } +func (s *Stmt) doEncryption() bool { + return !s.skipEncryption && s.c.sess.alwaysEncrypted +} + func (s *Stmt) Query(args []driver.Value) (driver.Rows, error) { defer s.c.clearOuts() @@ -686,9 +691,12 @@ func (s *Stmt) queryContext(ctx context.Context, args []namedValue) (rows driver if !s.c.connectionGood { return nil, driver.ErrBadConn } - if s.c.sess.alwaysEncrypted && len(args) > 0 { + if s.doEncryption() && len(args) > 0 { args, err = s.encryptArgs(ctx, args) } + if err != nil { + return nil, err + } if err = s.sendQuery(ctx, args); err != nil { return nil, s.c.checkBadConn(ctx, err, true) } @@ -756,6 +764,12 @@ func (s *Stmt) exec(ctx context.Context, args []namedValue) (res driver.Result, if !s.c.connectionGood { return nil, driver.ErrBadConn } + if s.doEncryption() && len(args) > 0 { + args, err = s.encryptArgs(ctx, args) + } + if err != nil { + return nil, err + } if err = s.sendQuery(ctx, args); err != nil { return nil, s.c.checkBadConn(ctx, err, true) } @@ -1045,7 +1059,7 @@ func (c *Conn) Ping(ctx context.Context) error { if !c.connectionGood { return driver.ErrBadConn } - stmt := &Stmt{c, `select 1;`, 0, nil} + stmt := &Stmt{c, `select 1;`, 0, nil, true} _, err := stmt.ExecContext(ctx, nil) return err } diff --git a/mssql_go19.go b/mssql_go19.go index 688b5c5d..6359c39a 100644 --- a/mssql_go19.go +++ b/mssql_go19.go @@ -29,12 +29,18 @@ type MssqlStmt = Stmt // Deprecated: users should transition to th var _ driver.NamedValueChecker = &Conn{} -// VarChar parameter types. +// VarChar is used to encode a string parameter as VarChar instead of a sized NVarChar type VarChar string +// NVarCharMax is used to encode a string parameter as NVarChar(max) instead of a sized NVarChar type NVarCharMax string + +// VarCharMax is used to encode a string parameter as VarChar(max) instead of a sized NVarChar type VarCharMax string +// NChar is used to encode a string parameter as NChar instead of a sized NVarChar +type NChar string + // DateTime1 encodes parameters to original DateTime SQL types. type DateTime1 time.Time @@ -51,6 +57,8 @@ func convertInputParameter(val interface{}) (interface{}, error) { return val, nil case VarCharMax: return val, nil + case NChar: + return val, nil case DateTime1: return val, nil case DateTimeOffset: @@ -144,6 +152,10 @@ func (s *Stmt) makeParamExtra(val driver.Value) (res param, err error) { res.ti.TypeId = typeNVarChar res.buffer = str2ucs2(string(val)) res.ti.Size = 0 // currently zero forces nvarchar(max) + case NChar: + res.ti.TypeId = typeNChar + res.buffer = str2ucs2(string(val)) + res.ti.Size = len(res.buffer) case DateTime1: t := time.Time(val) res.ti.TypeId = typeDateTimeN From 0fcb7ea110f6de0cd5619c38bf5e7396a2358e55 Mon Sep 17 00:00:00 2001 From: davidshi Date: Mon, 3 Jul 2023 15:36:55 -0500 Subject: [PATCH 13/47] add cipher data to parameters --- aecmk/keyprovider.go | 47 ++++++++++++-- alwaysencrypted_windows_test.go | 2 +- columnencryptionkey.go | 1 + encrypt.go | 108 ++++++++++++++++++++++++++++++-- encrypt_test.go | 4 ++ mssql.go | 32 ++++++++-- rpc.go | 21 +++++-- token.go | 5 +- 8 files changed, 200 insertions(+), 20 deletions(-) diff --git a/aecmk/keyprovider.go b/aecmk/keyprovider.go index 69f784df..7cdcb82c 100644 --- a/aecmk/keyprovider.go +++ b/aecmk/keyprovider.go @@ -2,6 +2,7 @@ package aecmk import ( "fmt" + "sync" "time" ) @@ -11,22 +12,60 @@ const ( CngKeyProvider = "MSSQL_CNG_STORE" AzureKeyVaultKeyProvider = "AZURE_KEY_VAULT" JavaKeyProvider = "MSSQL_JAVA_KEYSTORE" + KeyEncryptionAlgorithm = "RSA_OAEP" ) // ColumnEncryptionKeyLifetime is the default lifetime of decrypted Column Encryption Keys in the global cache. // The default is 2 hours var ColumnEncryptionKeyLifetime time.Duration = 2 * time.Hour -type CekCacheEntry struct { +type cekCacheEntry struct { Expiry time.Time Key []byte } -type CekCache map[string]CekCacheEntry +type cekCache map[string]cekCacheEntry type CekProvider struct { Provider ColumnEncryptionKeyProvider - DecryptedKeys CekCache + decryptedKeys cekCache + mutex sync.Mutex +} + +func NewCekProvider(provider ColumnEncryptionKeyProvider) *CekProvider { + return &CekProvider{Provider: provider, decryptedKeys: make(cekCache), mutex: sync.Mutex{}} +} + +func (cp *CekProvider) GetDecryptedKey(keyPath string, encryptedBytes []byte) (decryptedKey []byte, err error) { + cp.mutex.Lock() + ev, cachedKey := cp.decryptedKeys[keyPath] + if cachedKey { + if ev.Expiry.Before(time.Now()) { + delete(cp.decryptedKeys, keyPath) + cachedKey = false + } else { + decryptedKey = ev.Key + } + } + // decrypting a key can take a while, so let multiple callers race + // Key providers can choose to optimize their own concurrency. + // For example - there's probably minimal value in serializing access to a local certificate, + // but there'd be high value in having a queue of waiters for decrypting a key stored in the cloud. + cp.mutex.Unlock() + if !cachedKey { + decryptedKey = cp.Provider.DecryptColumnEncryptionKey(keyPath, KeyEncryptionAlgorithm, encryptedBytes) + } + if !cachedKey { + duration := cp.Provider.KeyLifetime() + if duration == nil { + duration = &ColumnEncryptionKeyLifetime + } + expiry := time.Now().Add(*duration) + cp.mutex.Lock() + cp.decryptedKeys[keyPath] = cekCacheEntry{Expiry: expiry, Key: decryptedKey} + cp.mutex.Unlock() + } + return } // no synchronization on this map. Providers register during init. @@ -60,7 +99,7 @@ func RegisterCekProvider(name string, provider ColumnEncryptionKeyProvider) erro if ok { return fmt.Errorf("CEK provider %s is already registered", name) } - globalCekProviderFactoryMap[name] = &CekProvider{Provider: provider, DecryptedKeys: CekCache{}} + globalCekProviderFactoryMap[name] = &CekProvider{Provider: provider, decryptedKeys: cekCache{}, mutex: sync.Mutex{}} return nil } diff --git a/alwaysencrypted_windows_test.go b/alwaysencrypted_windows_test.go index 28118ce5..2ec6fd92 100644 --- a/alwaysencrypted_windows_test.go +++ b/alwaysencrypted_windows_test.go @@ -30,7 +30,7 @@ func TestAlwaysEncryptedE2E(t *testing.T) { defer conn.Exec(fmt.Sprintf(dropColumnMasterKey, certPath)) r, _ := rand.Int(rand.Reader, big.NewInt(1000)) cekName := fmt.Sprintf("mssqlCek%d", r.Int64()) - encryptedCek := localcert.WindowsCertificateStoreKeyProvider.EncryptColumnEncryptionKey(certPath, "RSA_OAEP", []byte(certPath)) + encryptedCek := localcert.WindowsCertificateStoreKeyProvider.EncryptColumnEncryptionKey(certPath, KeyEncryptionAlgorithm, []byte(certPath)) createCek := fmt.Sprintf(createColumnEncryptionKey, cekName, certPath, encryptedCek) _, err = conn.Exec(createCek) if err != nil { diff --git a/columnencryptionkey.go b/columnencryptionkey.go index 257c1df0..1dd51068 100644 --- a/columnencryptionkey.go +++ b/columnencryptionkey.go @@ -6,6 +6,7 @@ const ( CngKeyProvider = "MSSQL_CNG_STORE" AzureKeyVaultKeyProvider = "AZURE_KEY_VAULT" JavaKeyProvider = "MSSQL_JAVA_KEYSTORE" + KeyEncryptionAlgorithm = "RSA_OAEP" ) // cek ==> Column Encryption Key diff --git a/encrypt.go b/encrypt.go index 8853223b..8471010d 100644 --- a/encrypt.go +++ b/encrypt.go @@ -3,9 +3,15 @@ package mssql import ( "context" "database/sql/driver" + "encoding/binary" "fmt" "io" "strings" + + "github.com/microsoft/go-mssqldb/msdsn" + "github.com/swisscom/mssql-always-encrypted/pkg/algorithms" + "github.com/swisscom/mssql-always-encrypted/pkg/encryption" + "github.com/swisscom/mssql-always-encrypted/pkg/keys" ) type ColumnEncryptionType int @@ -13,7 +19,7 @@ type ColumnEncryptionType int var ( ColumnEncryptionPlainText ColumnEncryptionType = 0 ColumnEncryptionDeterministic ColumnEncryptionType = 1 - ColumnEncryptionRandomized ColumnEncryptionType = 1 + ColumnEncryptionRandomized ColumnEncryptionType = 2 ) type cekData struct { @@ -28,6 +34,7 @@ type cekData struct { algorithm string byEnclave bool cmkSignature string + decryptedValue []byte } type parameterEncData struct { @@ -39,7 +46,14 @@ type parameterEncData struct { ruleVersion int } +type paramMapEntry struct { + cek *cekData + p *parameterEncData +} + // when Always Encrypted is turned on, we have to ask the server for metadata about how to encrypt input parameters. +// This function stores the relevant encryption parameters in a copy of the args so they can be +// encrypted just before being sent to the server func (s *Stmt) encryptArgs(ctx context.Context, args []namedValue) (encryptedArgs []namedValue, err error) { q := Stmt{c: s.c, paramCount: s.paramCount, @@ -59,8 +73,33 @@ func (s *Stmt) encryptArgs(ctx context.Context, args []namedValue) (encryptedArg if err != nil { return } - fmt.Printf("cekInfo: %v\nparamsInfo:%v\n", cekInfo, paramsInfo) - return args, nil + s.c.sess.logger.Log(ctx, msdsn.LogDebug, fmt.Sprintf("cekInfo: %v\nparamsInfo:%v\n", cekInfo, paramsInfo)) + err = s.decryptCek(cekInfo) + if err != nil { + return + } + paramMap := make(map[string]paramMapEntry) + for _, p := range paramsInfo { + paramMap[p.name] = paramMapEntry{cekInfo[p.cekOrdinal-1], &p} + } + encryptedArgs = make([]namedValue, len(args)) + for i, a := range args { + encryptedArgs[i] = a + name := "" + if len(a.Name) > 0 { + name = "@" + a.Name + } else { + name = fmt.Sprintf("@p%d", a.Ordinal) + } + info := paramMap[name] + + if info.p.encType == ColumnEncryptionPlainText { + continue + } + + encryptedArgs[i].encrypt = getEncryptor(info) + } + return encryptedArgs, nil } // returns the arguments to sp_describe_parameter_encryption @@ -93,6 +132,63 @@ func (s *Stmt) buildParametersForColumnEncryption(args []namedValue) (parameters return } +func (s *Stmt) decryptCek(cekInfo []*cekData) error { + for _, info := range cekInfo { + kp, ok := s.c.sess.aeSettings.keyProviders[info.cmkStoreName] + if !ok { + return fmt.Errorf("No provider found for key store %s", info.cmkStoreName) + } + dk, err := kp.GetDecryptedKey(info.cmkPath, info.encryptedValue) + if err != nil { + return err + } + info.decryptedValue = dk + } + return nil +} + +func getEncryptor(info paramMapEntry) valueEncryptor { + k := keys.NewAeadAes256CbcHmac256(info.cek.decryptedValue) + alg := algorithms.NewAeadAes256CbcHmac256Algorithm(k, encryption.From(byte(info.p.encType)), byte(info.cek.version)) + // Metadata to append to an encrypted parameter. Doesn't include original typeinfo + // https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/619c43b6-9495-4a58-9e49-a4950db245b3 + // ParamCipherInfo = TYPE_INFO + // EncryptionAlgo (byte) + // [AlgoName] (b_varchar) unused, no custom algorithm + // EncryptionType (byte) + // DatabaseId (ulong) + // CekId (ulong) + // CekVersion (ulong) + // CekMDVersion (ulonglong) - really a byte array + // NormVersion (byte) + // algo+ enctype+ dbid+ keyid+ keyver= normversion + metadataLen := 1 + 1 + 4 + 4 + 4 + 1 + metadataLen += len(info.cek.metadataVersion) + metadata := make([]byte, metadataLen) + offset := 0 + // AEAD_AES_256_CBC_HMAC_SHA256 + metadata[offset] = byte(info.p.algorithm) + offset++ + metadata[offset] = byte(info.p.encType) + offset++ + binary.LittleEndian.PutUint32(metadata[offset:], uint32(info.cek.database_id)) + offset += 4 + binary.LittleEndian.PutUint32(metadata[offset:], uint32(info.cek.id)) + offset += 4 + binary.LittleEndian.PutUint32(metadata[offset:], uint32(info.cek.version)) + offset += 4 + copy(metadata[offset:], info.cek.metadataVersion) + offset += len(info.cek.metadataVersion) + metadata[offset] = byte(info.p.ruleVersion) + return func(b []byte) ([]byte, []byte, error) { + encryptedData, err := alg.Encrypt(b) + if err != nil { + return nil, nil, err + } + return encryptedData, metadata, nil + } +} + // Based on the .Net implementation at https://github.com/dotnet/SqlClient/blob/2b31810ce69b88d707450e2059ee8fbde63f774f/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs#L6040 func buildStoredProcedureStatementForColumnEncryption(sproc string, args []namedValue) string { b := new(strings.Builder) @@ -129,12 +225,12 @@ func appendPrefixedParameterName(b *strings.Builder, p string) { } } -func processDescribeParameterEncryption(rows driver.Rows) (cekInfo []cekData, paramInfo []parameterEncData, err error) { - cekInfo = make([]cekData, 0) +func processDescribeParameterEncryption(rows driver.Rows) (cekInfo []*cekData, paramInfo []parameterEncData, err error) { + cekInfo = make([]*cekData, 0) values := make([]driver.Value, 9) qerr := rows.Next(values) for qerr == nil { - cekInfo = append(cekInfo, cekData{ordinal: int(values[0].(int64)), + cekInfo = append(cekInfo, &cekData{ordinal: int(values[0].(int64)), database_id: int(values[1].(int64)), id: int(values[2].(int64)), version: int(values[3].(int64)), diff --git a/encrypt_test.go b/encrypt_test.go index 7bd86450..7cd578b3 100644 --- a/encrypt_test.go +++ b/encrypt_test.go @@ -74,11 +74,13 @@ func TestSprocQueryForCE(t *testing.T) { "p1", 0, 5, + nil, }, { "@p2", 0, "val", + nil, }, }, "EXEC [myproc] @p1=@p1, @p2=@p2", @@ -94,11 +96,13 @@ func TestSprocQueryForCE(t *testing.T) { Dest: &out, In: false, }, + nil, }, { "pin", 1, "in", + nil, }, }, "EXEC [myproc] @pout=@pout OUTPUT, @pin=@pin", diff --git a/mssql.go b/mssql.go index 0ad3d4a1..dfc883e9 100644 --- a/mssql.go +++ b/mssql.go @@ -213,9 +213,9 @@ func (c *Connector) getDialer(p *msdsn.Config) Dialer { return createDialer(p) } -// RegisterCekProvider associated the given provider with the named key store. If an entry of the given name already exists, that entry is overwritten +// RegisterCekProvider associates the given provider with the named key store. If an entry of the given name already exists, that entry is overwritten func (c *Connector) RegisterCekProvider(name string, provider aecmk.ColumnEncryptionKeyProvider) { - c.keyProviders[name] = &aecmk.CekProvider{Provider: provider, DecryptedKeys: make(aecmk.CekCache)} + c.keyProviders[name] = aecmk.NewCekProvider(provider) } type Conn struct { @@ -654,16 +654,36 @@ func (s *Stmt) makeRPCParams(args []namedValue, isProc bool) ([]param, []string, if isOutputValue(val.Value) { output = outputSuffix } + if val.encrypt != nil { + // Encrypted parameters have a few requirements: + // 1. Copy original typeinfo to a block after the data + // 2. Set the parameter type to varbinary(max) + // 3. Append the crypto metadata bytes + params[i+offset].tiOriginal = params[i+offset].ti + params[i+offset].Flags |= fEncrypted + encryptedBytes, metadata, err := val.encrypt(params[i+offset].buffer) + if err != nil { + return nil, nil, err + } + params[i+offset].cipherInfo = metadata + params[i+offset].ti.TypeId = typeBigVarBin + params[i+offset].ti.Buffer = encryptedBytes + params[i+offset].ti.Size = 0 + } decls[i] = fmt.Sprintf("%s %s%s", name, makeDecl(params[i+offset].ti), output) } return params, decls, nil } +// Encrypts the input bytes. Returns the encrypted bytes followed by the encryption metadata to append to the packet. +type valueEncryptor func(bytes []byte) ([]byte, []byte, error) + type namedValue struct { Name string Ordinal int Value driver.Value + encrypt valueEncryptor } func convertOldArgs(args []driver.Value) []namedValue { @@ -1124,7 +1144,7 @@ func (s *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driv } list := make([]namedValue, len(args)) for i, nv := range args { - list[i] = namedValue(nv) + list[i] = namedValueFromDriverNamedValue(nv) } return s.queryContext(ctx, list) } @@ -1137,11 +1157,15 @@ func (s *Stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (drive } list := make([]namedValue, len(args)) for i, nv := range args { - list[i] = namedValue(nv) + list[i] = namedValueFromDriverNamedValue(nv) } return s.exec(ctx, list) } +func namedValueFromDriverNamedValue(v driver.NamedValue) namedValue { + return namedValue{Name: v.Name, Ordinal: v.Ordinal, Value: v.Value, encrypt: nil} +} + // Rowsq implements the sqlexp messages model for Query and QueryContext // Theory: We could also implement the non-experimental model this way type Rowsq struct { diff --git a/rpc.go b/rpc.go index f7d4c00e..17a4e5f0 100644 --- a/rpc.go +++ b/rpc.go @@ -13,13 +13,16 @@ type procId struct { const ( fByRevValue = 1 fDefaultValue = 2 + fEncrypted = 8 ) type param struct { - Name string - Flags uint8 - ti typeInfo - buffer []byte + Name string + Flags uint8 + ti typeInfo + buffer []byte + tiOriginal typeInfo + cipherInfo []byte } var ( @@ -78,6 +81,16 @@ func sendRpc(buf *tdsBuffer, headers []headerStruct, proc procId, flags uint16, if err != nil { return } + if (param.Flags & fEncrypted) == fEncrypted { + err = writeTypeInfo(buf, ¶m.tiOriginal) + if err != nil { + return + } + param.tiOriginal.Writer(buf, param.tiOriginal, param.buffer) + if _, err = buf.Write(param.cipherInfo); err != nil { + return + } + } } return buf.FinishPacket() } diff --git a/token.go b/token.go index 15f2fda1..a2bcc62b 100644 --- a/token.go +++ b/token.go @@ -827,7 +827,10 @@ func decryptColumn(column columnStruct, s *tdsSession, columnContent interface{} if !ok { panic(fmt.Errorf("Unable to find provider %s to decrypt CEK", cekValue.keyStoreName)) } - cek := cekProvider.Provider.DecryptColumnEncryptionKey(cekValue.keyPath, cekValue.algorithmName, column.cryptoMeta.entry.cekValues[0].encryptedKey) + cek, err := cekProvider.GetDecryptedKey(cekValue.keyPath, column.cryptoMeta.entry.cekValues[0].encryptedKey) + if err != nil { + panic(err) + } k := keys.NewAeadAes256CbcHmac256(cek) alg := algorithms.NewAeadAes256CbcHmac256Algorithm(k, encType, byte(cekValue.cekVersion)) d, err := alg.Decrypt(columnContent.([]byte)) From 2e75557aec7031a761772b341e20cdf2a5274ec9 Mon Sep 17 00:00:00 2001 From: davidshi Date: Wed, 5 Jul 2023 13:50:17 -0500 Subject: [PATCH 14/47] copy swisscom code locally --- aecmk/localcert/keyprovider.go | 2 +- encrypt.go | 6 +- go.mod | 3 +- go.sum | 5 - .../mssql-always-encrypted/LICENSE.txt | 20 +++ .../swisscom/mssql-always-encrypted/README.md | 5 + .../aead_aes_256_cbc_hmac_sha256.go | 98 +++++++++++++ .../aead_aes_256_cbc_hmac_sha256_test.go | 37 +++++ .../pkg/algorithms/algorithm.go | 6 + .../pkg/alwaysencrypted.go | 79 ++++++++++ .../pkg/alwaysencrypted_test.go | 138 ++++++++++++++++++ .../pkg/crypto/aes_cbc_pkcs5.go | 68 +++++++++ .../pkg/crypto/utils.go | 12 ++ .../pkg/encryption/type.go | 37 +++++ .../pkg/keys/aead_aes_256_cbc_hmac_256.go | 51 +++++++ .../mssql-always-encrypted/pkg/keys/key.go | 5 + .../mssql-always-encrypted/pkg/utils/utf16.go | 18 +++ .../test/always-encrypted.pem | 28 ++++ .../test/always-encrypted_pub.pem | 19 +++ .../mssql-always-encrypted/test/cekv.key | Bin 0 -> 627 bytes .../test/column_value.enc | 2 + .../test/decrypted_key.key | Bin 0 -> 627 bytes token.go | 6 +- 23 files changed, 631 insertions(+), 14 deletions(-) create mode 100644 internal/github.com/swisscom/mssql-always-encrypted/LICENSE.txt create mode 100644 internal/github.com/swisscom/mssql-always-encrypted/README.md create mode 100644 internal/github.com/swisscom/mssql-always-encrypted/pkg/algorithms/aead_aes_256_cbc_hmac_sha256.go create mode 100644 internal/github.com/swisscom/mssql-always-encrypted/pkg/algorithms/aead_aes_256_cbc_hmac_sha256_test.go create mode 100644 internal/github.com/swisscom/mssql-always-encrypted/pkg/algorithms/algorithm.go create mode 100644 internal/github.com/swisscom/mssql-always-encrypted/pkg/alwaysencrypted.go create mode 100644 internal/github.com/swisscom/mssql-always-encrypted/pkg/alwaysencrypted_test.go create mode 100644 internal/github.com/swisscom/mssql-always-encrypted/pkg/crypto/aes_cbc_pkcs5.go create mode 100644 internal/github.com/swisscom/mssql-always-encrypted/pkg/crypto/utils.go create mode 100644 internal/github.com/swisscom/mssql-always-encrypted/pkg/encryption/type.go create mode 100644 internal/github.com/swisscom/mssql-always-encrypted/pkg/keys/aead_aes_256_cbc_hmac_256.go create mode 100644 internal/github.com/swisscom/mssql-always-encrypted/pkg/keys/key.go create mode 100644 internal/github.com/swisscom/mssql-always-encrypted/pkg/utils/utf16.go create mode 100644 internal/github.com/swisscom/mssql-always-encrypted/test/always-encrypted.pem create mode 100644 internal/github.com/swisscom/mssql-always-encrypted/test/always-encrypted_pub.pem create mode 100644 internal/github.com/swisscom/mssql-always-encrypted/test/cekv.key create mode 100644 internal/github.com/swisscom/mssql-always-encrypted/test/column_value.enc create mode 100644 internal/github.com/swisscom/mssql-always-encrypted/test/decrypted_key.key diff --git a/aecmk/localcert/keyprovider.go b/aecmk/localcert/keyprovider.go index cb76205c..900f2111 100644 --- a/aecmk/localcert/keyprovider.go +++ b/aecmk/localcert/keyprovider.go @@ -16,7 +16,7 @@ import ( "time" "github.com/microsoft/go-mssqldb/aecmk" - ae "github.com/swisscom/mssql-always-encrypted/pkg" + ae "github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg" "golang.org/x/text/encoding/unicode" pkcs "software.sslmate.com/src/go-pkcs12" ) diff --git a/encrypt.go b/encrypt.go index 8471010d..7af7c5d6 100644 --- a/encrypt.go +++ b/encrypt.go @@ -8,10 +8,10 @@ import ( "io" "strings" + "github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/algorithms" + "github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/encryption" + "github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/keys" "github.com/microsoft/go-mssqldb/msdsn" - "github.com/swisscom/mssql-always-encrypted/pkg/algorithms" - "github.com/swisscom/mssql-always-encrypted/pkg/encryption" - "github.com/swisscom/mssql-always-encrypted/pkg/keys" ) type ColumnEncryptionType int diff --git a/go.mod b/go.mod index 57de1076..462a2650 100644 --- a/go.mod +++ b/go.mod @@ -9,10 +9,9 @@ require ( github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 github.com/golang-sql/sqlexp v0.1.0 github.com/jcmturner/gokrb5/v8 v8.4.4 - github.com/swisscom/mssql-always-encrypted v0.1.3 + github.com/stretchr/testify v1.8.1 golang.org/x/crypto v0.9.0 golang.org/x/sys v0.8.0 golang.org/x/text v0.9.0 software.sslmate.com/src/go-pkcs12 v0.2.0 ) - diff --git a/go.sum b/go.sum index 92daeded..6090bc1d 100644 --- a/go.sum +++ b/go.sum @@ -60,8 +60,6 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= -github.com/swisscom/mssql-always-encrypted v0.1.3 h1:+Q7sa71G2taM4SmwyNfPIB1iB8750iKNJEJQvqtlB38= -github.com/swisscom/mssql-always-encrypted v0.1.3/go.mod h1:FlEWLI3+svdMFq2w7GVMvk7iVhwBEBi7E7llAHb4B20= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= @@ -104,7 +102,6 @@ golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= @@ -118,8 +115,6 @@ golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/natefinch/npipe.v2 v2.0.0-20160621034901-c1b8fa8bdcce h1:+JknDZhAj8YMt7GC73Ei8pv4MzjDUNPHgQWJdtMAaDU= -gopkg.in/natefinch/npipe.v2 v2.0.0-20160621034901-c1b8fa8bdcce/go.mod h1:5AcXVHNjg+BDxry382+8OKon8SEWiKktQR07RKPsv1c= gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/internal/github.com/swisscom/mssql-always-encrypted/LICENSE.txt b/internal/github.com/swisscom/mssql-always-encrypted/LICENSE.txt new file mode 100644 index 00000000..3ece719c --- /dev/null +++ b/internal/github.com/swisscom/mssql-always-encrypted/LICENSE.txt @@ -0,0 +1,20 @@ +Copyright (c) 2021 Swisscom (Switzerland) Ltd + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software is furnished to do so, +subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + + diff --git a/internal/github.com/swisscom/mssql-always-encrypted/README.md b/internal/github.com/swisscom/mssql-always-encrypted/README.md new file mode 100644 index 00000000..c40de310 --- /dev/null +++ b/internal/github.com/swisscom/mssql-always-encrypted/README.md @@ -0,0 +1,5 @@ +# mssql-always-encrypted + +A library to interact with MSSQL's Always Encrypted features. +This library mostly handles the crpyto part to facilitate +the integration with [go-mssql](https://github.com/denisenkom/go-mssqldb). \ No newline at end of file diff --git a/internal/github.com/swisscom/mssql-always-encrypted/pkg/algorithms/aead_aes_256_cbc_hmac_sha256.go b/internal/github.com/swisscom/mssql-always-encrypted/pkg/algorithms/aead_aes_256_cbc_hmac_sha256.go new file mode 100644 index 00000000..7ccab4db --- /dev/null +++ b/internal/github.com/swisscom/mssql-always-encrypted/pkg/algorithms/aead_aes_256_cbc_hmac_sha256.go @@ -0,0 +1,98 @@ +package algorithms + +import ( + "bytes" + "fmt" + + "github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/crypto" + "github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/encryption" + "github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/keys" +) + +// https://tools.ietf.org/html/draft-mcgrew-aead-aes-cbc-hmac-sha2-05 +// https://winprotocoldoc.blob.core.windows.net/productionwindowsarchives/MS-TDS/%5bMS-TDS%5d.pdf + +var _ Algorithm = &AeadAes256CbcHmac256Algorithm{} + +type AeadAes256CbcHmac256Algorithm struct { + algorithmVersion byte + deterministic bool + blockSizeBytes int + keySizeBytes int + minimumCipherTextLengthBytesNoAuthTag int + minimumCipherTextLengthBytesWithAuthTag int + cek keys.AeadAes256CbcHmac256 + version []byte + versionSize []byte +} + +func NewAeadAes256CbcHmac256Algorithm(key keys.AeadAes256CbcHmac256, encType encryption.Type, algorithmVersion byte) AeadAes256CbcHmac256Algorithm { + const keySizeBytes = 256 / 8 + const blockSizeBytes = 16 + const minimumCipherTextLengthBytesNoAuthTag = 1 + 2*blockSizeBytes + const minimumCipherTextLengthBytesWithAuthTag = minimumCipherTextLengthBytesNoAuthTag + keySizeBytes + + a := AeadAes256CbcHmac256Algorithm{ + algorithmVersion: algorithmVersion, + deterministic: encType.Deterministic, + blockSizeBytes: blockSizeBytes, + keySizeBytes: keySizeBytes, + cek: key, + minimumCipherTextLengthBytesNoAuthTag: minimumCipherTextLengthBytesNoAuthTag, + minimumCipherTextLengthBytesWithAuthTag: minimumCipherTextLengthBytesWithAuthTag, + version: []byte{0x01}, + versionSize: []byte{1}, + } + + a.version[0] = algorithmVersion + return a +} + +func (a *AeadAes256CbcHmac256Algorithm) Encrypt(bytes []byte) ([]byte, error) { + panic("implement me") +} + +func (a *AeadAes256CbcHmac256Algorithm) Decrypt(ciphertext []byte) ([]byte, error) { + // This algorithm always has the auth tag! + minimumCiphertextLength := a.minimumCipherTextLengthBytesWithAuthTag + + if len(ciphertext) < minimumCiphertextLength { + return nil, fmt.Errorf("invalid ciphertext length: at least %v bytes expected", minimumCiphertextLength) + } + + idx := 0 + if ciphertext[idx] != a.algorithmVersion { + return nil, fmt.Errorf("invalid algorithm version used: %v found but %v expected", ciphertext[idx], + a.algorithmVersion) + } + + idx++ + authTag := ciphertext[idx : idx+a.keySizeBytes] + idx += a.keySizeBytes + + iv := ciphertext[idx : idx+a.blockSizeBytes] + idx += len(iv) + + realCiphertext := ciphertext[idx:] + ourAuthTag := a.prepareAuthTag(iv, realCiphertext) + + if bytes.Compare(ourAuthTag, authTag) != 0 { + return nil, fmt.Errorf("invalid auth tag") + } + + // decrypt + + aescdbc := crypto.NewAESCbcPKCS5(a.cek.EncryptionKey(), iv) + cleartext := aescdbc.Decrypt(realCiphertext) + + return cleartext, nil +} + +func (a *AeadAes256CbcHmac256Algorithm) prepareAuthTag(iv []byte, ciphertext []byte) []byte { + var input = make([]byte, 0) + input = append(input, a.algorithmVersion) + input = append(input, iv...) + input = append(input, ciphertext...) + input = append(input, a.versionSize...) + return crypto.Sha256Hmac(input, a.cek.MacKey()) +} diff --git a/internal/github.com/swisscom/mssql-always-encrypted/pkg/algorithms/aead_aes_256_cbc_hmac_sha256_test.go b/internal/github.com/swisscom/mssql-always-encrypted/pkg/algorithms/aead_aes_256_cbc_hmac_sha256_test.go new file mode 100644 index 00000000..ad59b292 --- /dev/null +++ b/internal/github.com/swisscom/mssql-always-encrypted/pkg/algorithms/aead_aes_256_cbc_hmac_sha256_test.go @@ -0,0 +1,37 @@ +package algorithms_test + +import ( + "encoding/hex" + "testing" + + "github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/algorithms" + "github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/encryption" + "github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/keys" + "github.com/stretchr/testify/assert" +) + +func TestAeadAes256CbcHmac256Algorithm_Decrypt(t *testing.T) { + expectedResult, err := hex.DecodeString("3100320033003400350020002000200020002000") + if err != nil { + t.Fatal(err) + } + + cipherText, err := hex.DecodeString("0181c4b77e1c50583c5e83a20afd4c98ce5acb39a636f00247b3a4d78a8be319c840e6970541a66723583def227eb774b4234cff209443b0209b75309532b527bdf9b2dfb326b4428840532a20460d06d4") + if err != nil { + t.Fatal(err) + } + + rootKey, err := hex.DecodeString("0ff9e45335df3dec7be0649f741e6ea870e9d49d16fe4be7437ce22489f48ead") + if err != nil { + t.Fatal(err) + } + + key := keys.NewAeadAes256CbcHmac256(rootKey) + alg := algorithms.NewAeadAes256CbcHmac256Algorithm(key, encryption.Deterministic, 1) + + result, err := alg.Decrypt(cipherText) + if err != nil { + t.Fatal(err) + } + assert.Equal(t, expectedResult, result) +} diff --git a/internal/github.com/swisscom/mssql-always-encrypted/pkg/algorithms/algorithm.go b/internal/github.com/swisscom/mssql-always-encrypted/pkg/algorithms/algorithm.go new file mode 100644 index 00000000..48a751da --- /dev/null +++ b/internal/github.com/swisscom/mssql-always-encrypted/pkg/algorithms/algorithm.go @@ -0,0 +1,6 @@ +package algorithms + +type Algorithm interface { + Encrypt([]byte) ([]byte, error) + Decrypt([]byte) ([]byte, error) +} \ No newline at end of file diff --git a/internal/github.com/swisscom/mssql-always-encrypted/pkg/alwaysencrypted.go b/internal/github.com/swisscom/mssql-always-encrypted/pkg/alwaysencrypted.go new file mode 100644 index 00000000..64ca57f6 --- /dev/null +++ b/internal/github.com/swisscom/mssql-always-encrypted/pkg/alwaysencrypted.go @@ -0,0 +1,79 @@ +package alwaysencrypted + +import ( + "crypto" + "crypto/rand" + "crypto/rsa" + "crypto/sha1" + "crypto/sha256" + "crypto/x509" + "encoding/binary" + "unicode/utf16" +) + +type CEKV struct { + Version int + KeyPath string + Ciphertext []byte + SignedHash []byte + DataToSign []byte + + Key []byte +} + +func (c *CEKV) Verify(cert *x509.Certificate) bool { + sha256Sum := sha256.Sum256(c.DataToSign) + err := rsa.VerifyPKCS1v15(cert.PublicKey.(*rsa.PublicKey), crypto.SHA256, sha256Sum[:], c.SignedHash) + + return err == nil +} + +func (c *CEKV) Decrypt(private *rsa.PrivateKey) ([]byte, error) { + decryptedData, decryptErr := rsa.DecryptOAEP(sha1.New(), rand.Reader, private, c.Ciphertext, nil) + if decryptErr != nil { + return nil, decryptErr + } + + return decryptedData, nil +} + +func LoadCEKV(bytes []byte) CEKV { + idx := 0 + version := int(bytes[idx]) + idx++ + + keyPathLengthBytes := bytes[idx : idx+2] + keyPathLength := binary.LittleEndian.Uint16(keyPathLengthBytes) + idx += 2 + + cipherTextLengthBytes := bytes[idx : idx+2] + cipherTextLength := binary.LittleEndian.Uint16(cipherTextLengthBytes) + idx += 2 + + keyPathBytes := bytes[idx : idx+int(keyPathLength)] + idx += int(keyPathLength) + + var uint16Bytes []uint16 + for i := range keyPathBytes { + if i%2 == 0 { + continue + } + uint16Value := binary.LittleEndian.Uint16([]byte{keyPathBytes[i-1], keyPathBytes[i]}) + uint16Bytes = append(uint16Bytes, uint16Value) + } + keyPath := string(utf16.Decode(uint16Bytes)) + + cipherText := bytes[idx : idx+int(cipherTextLength)] + idx += int(cipherTextLength) + + dataToSign := bytes[0:idx] + signedHash := bytes[idx:] + + return CEKV{ + Version: version, + KeyPath: keyPath, + DataToSign: dataToSign, + Ciphertext: cipherText, + SignedHash: signedHash, + } +} diff --git a/internal/github.com/swisscom/mssql-always-encrypted/pkg/alwaysencrypted_test.go b/internal/github.com/swisscom/mssql-always-encrypted/pkg/alwaysencrypted_test.go new file mode 100644 index 00000000..33e86125 --- /dev/null +++ b/internal/github.com/swisscom/mssql-always-encrypted/pkg/alwaysencrypted_test.go @@ -0,0 +1,138 @@ +package alwaysencrypted + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/sha1" + "crypto/x509" + "encoding/pem" + "fmt" + "io/ioutil" + "os" + "testing" + + "github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/algorithms" + "github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/encryption" + "github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/keys" + "github.com/stretchr/testify/assert" + "golang.org/x/text/encoding/unicode" +) + +func TestLoadCEKV(t *testing.T) { + certFile, err := os.Open("../test/always-encrypted_pub.pem") + if err != nil { + t.Fatal(err) + } + + certBytes, err := ioutil.ReadAll(certFile) + if err != nil { + t.Fatal(err) + } + pemB, _ := pem.Decode(certBytes) + cert, err := x509.ParseCertificate(pemB.Bytes) + if err != nil { + t.Fatal(nil) + } + + cekvFile, err := os.Open("../test/cekv.key") + if err != nil { + t.Fatal(err) + } + cekvBytes, err := ioutil.ReadAll(cekvFile) + + cekv := LoadCEKV(cekvBytes) + assert.Equal(t, 1, cekv.Version) + assert.True(t, cekv.Verify(cert)) +} +func TestDecrypt(t *testing.T) { + certFile, err := os.Open("../test/always-encrypted.pem") + if err != nil { + t.Fatal(err) + } + + certBytes, err := ioutil.ReadAll(certFile) + if err != nil { + t.Fatal(err) + } + pemB, _ := pem.Decode(certBytes) + privKey, err := x509.ParsePKCS8PrivateKey(pemB.Bytes) + if err != nil { + t.Fatal(err) + } + + rsaPrivKey := privKey.(*rsa.PrivateKey) + + cekvFile, err := os.Open("../test/cekv.key") + if err != nil { + t.Fatal(err) + } + cekvBytes, err := ioutil.ReadAll(cekvFile) + + cekv := LoadCEKV(cekvBytes) + rootKey, err := cekv.Decrypt(rsaPrivKey) + if err != nil { + t.Fatal(err) + } + assert.Equal(t, "0ff9e45335df3dec7be0649f741e6ea870e9d49d16fe4be7437ce22489f48ead", fmt.Sprintf("%02x", rootKey)) + assert.Equal(t, 1, cekv.Version) + assert.NotNil(t, rootKey) + + columnBytesFile, err := os.Open("../test/column_value.enc") + if err != nil { + t.Fatal(err) + } + + columnBytes, err := ioutil.ReadAll(columnBytesFile) + if err != nil { + t.Fatal(err) + } + + key := keys.NewAeadAes256CbcHmac256(rootKey) + alg := algorithms.NewAeadAes256CbcHmac256Algorithm(key, encryption.Deterministic, 1) + cleartext, err := alg.Decrypt(columnBytes) + + enc := unicode.UTF16(unicode.LittleEndian, unicode.IgnoreBOM) + decoder := enc.NewDecoder() + cleartextUtf8, err := decoder.Bytes(cleartext) + if err != nil { + t.Fatal(err) + } + fmt.Printf("column value: \"%02X\"", cleartextUtf8) + assert.Equal(t, "12345 ", string(cleartextUtf8)) +} +func TestDecryptCEK(t *testing.T) { + certFile, err := os.Open("../test/always-encrypted.pem") + if err != nil { + t.Fatal(err) + } + + certFileBytes, err := ioutil.ReadAll(certFile) + if err != nil { + t.Fatal(err) + } + + pemBlock, _ := pem.Decode(certFileBytes) + cert, err := x509.ParsePKCS8PrivateKey(pemBlock.Bytes) + if err != nil { + t.Fatal(err) + } + + cekvFile, err := os.Open("../test/cekv.key") + if err != nil { + t.Fatal(err) + } + + cekvBytes, err := ioutil.ReadAll(cekvFile) + if err != nil { + t.Fatal(err) + } + + cekv := LoadCEKV(cekvBytes) + fmt.Printf("Cert: %v\n", cert) + + rsaKey := cert.(*rsa.PrivateKey) + + // RSA/ECB/OAEPWithSHA-1AndMGF1Padding + bytes, err := rsa.DecryptOAEP(sha1.New(), rand.Reader, rsaKey, cekv.Ciphertext, nil) + fmt.Printf("Key: %02x\n", bytes) +} diff --git a/internal/github.com/swisscom/mssql-always-encrypted/pkg/crypto/aes_cbc_pkcs5.go b/internal/github.com/swisscom/mssql-always-encrypted/pkg/crypto/aes_cbc_pkcs5.go new file mode 100644 index 00000000..6562fca2 --- /dev/null +++ b/internal/github.com/swisscom/mssql-always-encrypted/pkg/crypto/aes_cbc_pkcs5.go @@ -0,0 +1,68 @@ +package crypto + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "fmt" +) + +// Inspired by: https://gist.github.com/hothero/7d085573f5cb7cdb5801d7adcf66dcf3 + +type AESCbcPKCS5 struct { + key []byte + iv []byte + block cipher.Block +} + +func NewAESCbcPKCS5(key []byte, iv []byte) AESCbcPKCS5 { + a := AESCbcPKCS5{ + key: key, + iv: iv, + block: nil, + } + a.initCipher() + return a +} + +func (a AESCbcPKCS5) Encrypt(cleartext []byte) { + if a.block == nil { + a.initCipher() + } + + blockMode := cipher.NewCBCEncrypter(a.block, a.iv) + paddedCleartext := PKCS5Padding(cleartext, blockMode.BlockSize()) + var cipherText = make([]byte, 0) + blockMode.CryptBlocks(cipherText, paddedCleartext) +} + +func (a AESCbcPKCS5) Decrypt(ciphertext []byte) []byte { + if a.block == nil { + a.initCipher() + } + + blockMode := cipher.NewCBCDecrypter(a.block, a.iv) + var cleartext = make([]byte, len(ciphertext)) + blockMode.CryptBlocks(cleartext, ciphertext) + return PKCS5Trim(cleartext) +} + +func PKCS5Padding(inArr []byte, blockSize int) []byte { + padding := blockSize - len(inArr)%blockSize + padText := bytes.Repeat([]byte{byte(padding)}, padding) + return append(inArr, padText...) +} + +func PKCS5Trim(inArr []byte) []byte { + padding := inArr[len(inArr)-1] + return inArr[:len(inArr)-int(padding)] +} + +func (a *AESCbcPKCS5) initCipher() { + block, err := aes.NewCipher(a.key) + if err != nil { + panic(fmt.Errorf("unable to create cipher: %v", err)) + } + + a.block = block +} diff --git a/internal/github.com/swisscom/mssql-always-encrypted/pkg/crypto/utils.go b/internal/github.com/swisscom/mssql-always-encrypted/pkg/crypto/utils.go new file mode 100644 index 00000000..b8f9319f --- /dev/null +++ b/internal/github.com/swisscom/mssql-always-encrypted/pkg/crypto/utils.go @@ -0,0 +1,12 @@ +package crypto + +import ( + "crypto/hmac" + "crypto/sha256" +) + +func Sha256Hmac(input []byte, key []byte) []byte { + sha256Hmac := hmac.New(sha256.New, key) + sha256Hmac.Write(input) + return sha256Hmac.Sum(nil) +} diff --git a/internal/github.com/swisscom/mssql-always-encrypted/pkg/encryption/type.go b/internal/github.com/swisscom/mssql-always-encrypted/pkg/encryption/type.go new file mode 100644 index 00000000..b38cccd6 --- /dev/null +++ b/internal/github.com/swisscom/mssql-always-encrypted/pkg/encryption/type.go @@ -0,0 +1,37 @@ +package encryption + +type Type struct { + Deterministic bool + Name string + Value byte +} + +var Plaintext = Type{ + Deterministic: false, + Name: "Plaintext", + Value: 0, +} + +var Deterministic = Type{ + Deterministic: true, + Name: "Deterministic", + Value: 1, +} + +var Randomized = Type{ + Deterministic: false, + Name: "Randomized", + Value: 2, +} + +func From(encType byte) Type { + switch encType { + case 0: + return Plaintext + case 1: + return Deterministic + case 2: + return Randomized + } + return Plaintext +} \ No newline at end of file diff --git a/internal/github.com/swisscom/mssql-always-encrypted/pkg/keys/aead_aes_256_cbc_hmac_256.go b/internal/github.com/swisscom/mssql-always-encrypted/pkg/keys/aead_aes_256_cbc_hmac_256.go new file mode 100644 index 00000000..4c1dba15 --- /dev/null +++ b/internal/github.com/swisscom/mssql-always-encrypted/pkg/keys/aead_aes_256_cbc_hmac_256.go @@ -0,0 +1,51 @@ +package keys + +import ( + "fmt" + + "github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/crypto" + "github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/utils" +) + +var _ Key = &AeadAes256CbcHmac256{} + +type AeadAes256CbcHmac256 struct { + rootKey []byte + encryptionKey []byte + macKey []byte + ivKey []byte +} + +func NewAeadAes256CbcHmac256(rootKey []byte) AeadAes256CbcHmac256 { + const keySize = 256 + const encryptionKeySaltFormat = "Microsoft SQL Server cell encryption key with encryption algorithm:%v and key length:%v" + const macKeySaltFormat = "Microsoft SQL Server cell MAC key with encryption algorithm:%v and key length:%v" + const ivKeySaltFormat = "Microsoft SQL Server cell IV key with encryption algorithm:%v and key length:%v" + const algorithmName = "AEAD_AES_256_CBC_HMAC_SHA256" + + encryptionKeySalt := utils.ProcessUTF16LE(fmt.Sprintf(encryptionKeySaltFormat, algorithmName, keySize)) + macKeySalt := utils.ProcessUTF16LE(fmt.Sprintf(macKeySaltFormat, algorithmName, keySize)) + ivKeySalt := utils.ProcessUTF16LE(fmt.Sprintf(ivKeySaltFormat, algorithmName, keySize)) + + return AeadAes256CbcHmac256{ + rootKey: rootKey, + encryptionKey: crypto.Sha256Hmac(encryptionKeySalt, rootKey), + macKey: crypto.Sha256Hmac(macKeySalt, rootKey), + ivKey: crypto.Sha256Hmac(ivKeySalt, rootKey)} +} + +func (a AeadAes256CbcHmac256) IvKey() []byte { + return a.ivKey +} + +func (a AeadAes256CbcHmac256) MacKey() []byte { + return a.macKey +} + +func (a AeadAes256CbcHmac256) EncryptionKey() []byte { + return a.encryptionKey +} + +func (a AeadAes256CbcHmac256) RootKey() []byte { + return a.rootKey +} diff --git a/internal/github.com/swisscom/mssql-always-encrypted/pkg/keys/key.go b/internal/github.com/swisscom/mssql-always-encrypted/pkg/keys/key.go new file mode 100644 index 00000000..9e6e0161 --- /dev/null +++ b/internal/github.com/swisscom/mssql-always-encrypted/pkg/keys/key.go @@ -0,0 +1,5 @@ +package keys + +type Key interface { + RootKey() []byte +} \ No newline at end of file diff --git a/internal/github.com/swisscom/mssql-always-encrypted/pkg/utils/utf16.go b/internal/github.com/swisscom/mssql-always-encrypted/pkg/utils/utf16.go new file mode 100644 index 00000000..52c2c792 --- /dev/null +++ b/internal/github.com/swisscom/mssql-always-encrypted/pkg/utils/utf16.go @@ -0,0 +1,18 @@ +package utils + +import ( + "encoding/binary" + "unicode/utf16" +) + +func ConvertUTF16ToLittleEndianBytes(u []uint16) []byte { + b := make([]byte, 2*len(u)) + for index, value := range u { + binary.LittleEndian.PutUint16(b[index*2:], value) + } + return b +} + +func ProcessUTF16LE(inputString string) []byte { + return ConvertUTF16ToLittleEndianBytes(utf16.Encode([]rune(inputString))) +} \ No newline at end of file diff --git a/internal/github.com/swisscom/mssql-always-encrypted/test/always-encrypted.pem b/internal/github.com/swisscom/mssql-always-encrypted/test/always-encrypted.pem new file mode 100644 index 00000000..382ab002 --- /dev/null +++ b/internal/github.com/swisscom/mssql-always-encrypted/test/always-encrypted.pem @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQDFkCQfKsTM0UMw +pRA4t/kWYWTFUMJyEu5i4Zhw4yL0emFUcDdLqmwqrgR2eC1RFb9CU4UnLowVoq6H +vXlYuHnlCQio1QwLio80WoHBezdU/TQBBbpm5D2FxOLhzen2Puby+PQd+ZIBVuQu +I920g5O0wke86RmWmNdu0jaftwNqoIqc/TAqjYNKB2/CwnPnHwsCHJjIhCoSGlCa +WsQZSptSeqLQ87eaVfJypJpxG5FJ+bOXjFdgpXY3XOQoeR+xsXs2AKZ+eKOaSmb+ +Hg+pvMGBCXuSwBIAwPUxlCQSe2dfcXTkF+stadfH6EvVyIvK0G8RZ9N0Ow5vyRaU +95Bxc+2BAgMBAAECggEBAKJmz9qy/J3lc5ccSQ5m5SJpoz20GnNNbproGbjKbiSM +KVARAtN3X31iGRcNySq7dsJeB7niwJLUbSX2MjclRkZpO64Vm9Ys63U85ScYU67Q +iZxBii4kdxJse5jk/OtIX+7hiULOsh/Zvq7TGt/VvWi8v93hvAAY2hcmRHLcLbnK +li9DLnN3dIJoFh3y2OHlFfvFcX04wNmyfv04/FZKliGwrONkTN1YvEclU3XSjdrH +JM2977u+rB216Y1jiIObFceKj573hBAwS+gU2kx7g9Fpq9SvwszxmHMWtJQvJxg+ +7ClBeB8aSu1wSydm/0hfmwFNBH9c4BDVo3P1+K37PQUCgYEA8Lnceo9S4NOog5ri +taSVUqoHjruRU2tqFFi1wni+dw0m99kd5h8p9K0aXwvvjP8cmpK/ultSVZb9NzEz +zA5ZXXxT83QZOmq4FJCl31tjhcA/oidD139dCpe3RQ08ToClJgOuG8obS0hgy9Xt +sa16HgYP4aDerEgXR2fg3TWW1icCgYEA0hkt2FXFTh8L9z3nb/a8TNGBgVlafxcV +d4m1HhDoJ+GF8yscvUq7kn4xG2BHA5GNnUn0hIfrci/A0CXNGVOeUufgOUBKw39V +5Wq26ryElDcQ7CyJ36yH8/zQ4jgUOVo+R+jSO0+L4H1T/vP9F1ARtORb0/Ga5JFq +pxh6Q5VB0BcCgYEAh/2Hd1lGSapolUhHcLP0g0l4kYKWu5h/ydS/gYgymRC+BeAK +yvip/AZaUn1sq6tm3k+urjluztlIXQiXqVwl0fEtf+gDZIPrT/rTKdX36BROHm2u +HqxdxGEm8IRkoDh+k3YawqovNx1BSYWmDOzigtmL2TvG726ecAFX/7+JYZsCgYAf +kHTYyZoI8JUlogFBSvpjOB6Sxk/YRCmPefrh93xJcZJkRBffQHkJuze5ey9wE9AI +z3GS77CpyQ7YtrUnlu50Wi3PrB8PW/QVsYClp4jrk5JRSSe1mQAb4eGn+vDe5PXy +a8IZ8wt6wJl79kAR3o+qc5xwLR4uNMKnNAA6YxQuJQKBgQCIjo++s0i1pxf60CaL +2Mph/sDztdv0nZMPZzN0j2HGGJ21tKi3O+V+VoHHIs2YYjTsFu5Iwc7LONiGN+SF +38ojT7uWyY4Jz+9Sr4uYTJvWLc9G4BCkco3RNowLK8tb6TfewajWXeAzlz/Eafmj +nlUFODdXG+URQ5tpDjdCd6zbpQ== +-----END PRIVATE KEY----- \ No newline at end of file diff --git a/internal/github.com/swisscom/mssql-always-encrypted/test/always-encrypted_pub.pem b/internal/github.com/swisscom/mssql-always-encrypted/test/always-encrypted_pub.pem new file mode 100644 index 00000000..b0b4a9e5 --- /dev/null +++ b/internal/github.com/swisscom/mssql-always-encrypted/test/always-encrypted_pub.pem @@ -0,0 +1,19 @@ +-----BEGIN CERTIFICATE----- +MIIDKjCCAhKgAwIBAgIQRlupjX13FaVC/c36tbVQxzANBgkqhkiG9w0BAQsFADAn +MSUwIwYDVQQDDBxBbHdheXMgRW5jcnlwdGVkIENlcnRpZmljYXRlMB4XDTIxMDEy +NjE1MDgyMloXDTIyMDEyNjE1MDgyMlowJzElMCMGA1UEAwwcQWx3YXlzIEVuY3J5 +cHRlZCBDZXJ0aWZpY2F0ZTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEB +AMWQJB8qxMzRQzClEDi3+RZhZMVQwnIS7mLhmHDjIvR6YVRwN0uqbCquBHZ4LVEV +v0JThScujBWiroe9eVi4eeUJCKjVDAuKjzRagcF7N1T9NAEFumbkPYXE4uHN6fY+ +5vL49B35kgFW5C4j3bSDk7TCR7zpGZaY127SNp+3A2qgipz9MCqNg0oHb8LCc+cf +CwIcmMiEKhIaUJpaxBlKm1J6otDzt5pV8nKkmnEbkUn5s5eMV2Cldjdc5Ch5H7Gx +ezYApn54o5pKZv4eD6m8wYEJe5LAEgDA9TGUJBJ7Z19xdOQX6y1p18foS9XIi8rQ +bxFn03Q7Dm/JFpT3kHFz7YECAwEAAaNSMFAwHwYDVR0lBBgwFgYIKwYBBQUIAgIG +CisGAQQBgjcKAwswHQYDVR0OBBYEFNQfS2liOJPsJuonIc0KPF4+CtFIMA4GA1Ud +DwEB/wQEAwIFIDANBgkqhkiG9w0BAQsFAAOCAQEAKMzuAfIv6uGxgx+SGgjDqk2O +oVdRul5xB/QlChdhzTrMwpIdul0+eLo46gqPdj/5kxWhQGNMuns+5/QrSfbaqAUz +ZWFsNAm+bhTBsgy9VSor3QUGedfQV3fP/8aZ/nvgLUe7PegmFBIiSALyjvCdayb5 +UZIxcBGQTmmpqGmL0hnRQwE2JvneOGEAiIIOTObCzgWyKhKuF2DWxinBtzyRlXfD +TV15+7v5kAdrjLevk57NOEshr0IDirD9auI61bqoxJZFyDqkdLZWED69pbCF8Ly5 +zbC8uUnDh3enxgmnUPXU/JZM1dbiPHZBxkUjVOoMYxycr0YgROJk7w5cfjrMYQ== +-----END CERTIFICATE----- diff --git a/internal/github.com/swisscom/mssql-always-encrypted/test/cekv.key b/internal/github.com/swisscom/mssql-always-encrypted/test/cekv.key new file mode 100644 index 0000000000000000000000000000000000000000..d26e9f9eedbdf4a3744f65c4a6b3939ec3ee01ed GIT binary patch literal 627 zcmZR~V_;xRW+-JS0>V^=Jcbe=yBNqSV$f&EWvB#_1`J6+z9oY>g9VV42qX=G{8WY% z1~VYb00<2kOc~OEe2~03Lo!fJDv)mq7BvK`G69Pk1Le&aJkq~zPu_lh`)s*?tmnV( z`Lc)C=4`A(#4&mQ!=In^O@4H$Xr|NkeHLM?&E?Y!g(?F+9lXzH1^G^rhY|n{(pvH2<*LH(I-wN9^{SckSwTSS7yWB+v8MhN$sT|Z}#xRJko zYQ`Q*(e7QF136p{mLD&ebI!bN+r-C@4SasC5pJ9nDsXUd>BBjbq<_8PTbktM#xs9z z(2bW_dWG_~1+l!^=7pR*`>r(!seJdb6LjJCjI>o@T9LMX7vsjADam)1w9mS~=!3!A z^JPL83pj)sy(NWxo}c(@){s1FHPc${3o9iqUfIx?{;bA%&AGm{ZejT!Y-a8@y!*4r zr~HNE&ThsBub5pDPVm$_N~gY0mbOTKU~6;bP~x*Arb&rs!?rEZ=sa6u7yV_mE4!=O zxuuQ!uUvAtW2AmXe`=F^(1a7hAJX_V&+<>PS!nLM`piMbYp(*BCLUCtf2i(%WnlMx JzsBF&WdU&e5_tds literal 0 HcmV?d00001 diff --git a/internal/github.com/swisscom/mssql-always-encrypted/test/column_value.enc b/internal/github.com/swisscom/mssql-always-encrypted/test/column_value.enc new file mode 100644 index 00000000..b3243a4c --- /dev/null +++ b/internal/github.com/swisscom/mssql-always-encrypted/test/column_value.enc @@ -0,0 +1,2 @@ +Ä·~PX<^ƒ¢ +ýL˜ÎZË9¦6ðG³¤×Š‹ãÈ@æ—A¦g#X=ï"~·t´#Lÿ ”C° ›u0•2µ'½ù²ß³&´Bˆ@S* F Ô \ No newline at end of file diff --git a/internal/github.com/swisscom/mssql-always-encrypted/test/decrypted_key.key b/internal/github.com/swisscom/mssql-always-encrypted/test/decrypted_key.key new file mode 100644 index 0000000000000000000000000000000000000000..d26e9f9eedbdf4a3744f65c4a6b3939ec3ee01ed GIT binary patch literal 627 zcmZR~V_;xRW+-JS0>V^=Jcbe=yBNqSV$f&EWvB#_1`J6+z9oY>g9VV42qX=G{8WY% z1~VYb00<2kOc~OEe2~03Lo!fJDv)mq7BvK`G69Pk1Le&aJkq~zPu_lh`)s*?tmnV( z`Lc)C=4`A(#4&mQ!=In^O@4H$Xr|NkeHLM?&E?Y!g(?F+9lXzH1^G^rhY|n{(pvH2<*LH(I-wN9^{SckSwTSS7yWB+v8MhN$sT|Z}#xRJko zYQ`Q*(e7QF136p{mLD&ebI!bN+r-C@4SasC5pJ9nDsXUd>BBjbq<_8PTbktM#xs9z z(2bW_dWG_~1+l!^=7pR*`>r(!seJdb6LjJCjI>o@T9LMX7vsjADam)1w9mS~=!3!A z^JPL83pj)sy(NWxo}c(@){s1FHPc${3o9iqUfIx?{;bA%&AGm{ZejT!Y-a8@y!*4r zr~HNE&ThsBub5pDPVm$_N~gY0mbOTKU~6;bP~x*Arb&rs!?rEZ=sa6u7yV_mE4!=O zxuuQ!uUvAtW2AmXe`=F^(1a7hAJX_V&+<>PS!nLM`piMbYp(*BCLUCtf2i(%WnlMx JzsBF&WdU&e5_tds literal 0 HcmV?d00001 diff --git a/token.go b/token.go index a2bcc62b..58c3f5d9 100644 --- a/token.go +++ b/token.go @@ -11,10 +11,10 @@ import ( "strconv" "github.com/golang-sql/sqlexp" + "github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/algorithms" + "github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/encryption" + "github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/keys" "github.com/microsoft/go-mssqldb/msdsn" - "github.com/swisscom/mssql-always-encrypted/pkg/algorithms" - "github.com/swisscom/mssql-always-encrypted/pkg/encryption" - "github.com/swisscom/mssql-always-encrypted/pkg/keys" "golang.org/x/text/encoding/unicode" ) From a98b1fd8ef8988107aeffc3688abf007282ae9e0 Mon Sep 17 00:00:00 2001 From: davidshi Date: Thu, 6 Jul 2023 14:29:53 -0500 Subject: [PATCH 15/47] implement Encrypt --- encrypt.go | 8 +++--- .../aead_aes_256_cbc_hmac_sha256.go | 26 +++++++++++++++++-- .../pkg/crypto/aes_cbc_pkcs5.go | 11 ++++---- mssql.go | 2 +- rpc.go | 1 - types.go | 2 +- 6 files changed, 36 insertions(+), 14 deletions(-) diff --git a/encrypt.go b/encrypt.go index 7af7c5d6..c45705ae 100644 --- a/encrypt.go +++ b/encrypt.go @@ -80,7 +80,7 @@ func (s *Stmt) encryptArgs(ctx context.Context, args []namedValue) (encryptedArg } paramMap := make(map[string]paramMapEntry) for _, p := range paramsInfo { - paramMap[p.name] = paramMapEntry{cekInfo[p.cekOrdinal-1], &p} + paramMap[p.name] = paramMapEntry{cekInfo[p.cekOrdinal-1], p} } encryptedArgs = make([]namedValue, len(args)) for i, a := range args { @@ -225,7 +225,7 @@ func appendPrefixedParameterName(b *strings.Builder, p string) { } } -func processDescribeParameterEncryption(rows driver.Rows) (cekInfo []*cekData, paramInfo []parameterEncData, err error) { +func processDescribeParameterEncryption(rows driver.Rows) (cekInfo []*cekData, paramInfo []*parameterEncData, err error) { cekInfo = make([]*cekData, 0) values := make([]driver.Value, 9) qerr := rows.Next(values) @@ -255,10 +255,10 @@ func processDescribeParameterEncryption(rows driver.Rows) (cekInfo []*cekData, p if err != nil { return } - paramInfo = make([]parameterEncData, 0) + paramInfo = make([]*parameterEncData, 0) qerr = rows.Next(values[:6]) for qerr == nil { - paramInfo = append(paramInfo, parameterEncData{ordinal: int(values[0].(int64)), + paramInfo = append(paramInfo, ¶meterEncData{ordinal: int(values[0].(int64)), name: values[1].(string), algorithm: int(values[2].(int64)), encType: ColumnEncryptionType(values[3].(int64)), diff --git a/internal/github.com/swisscom/mssql-always-encrypted/pkg/algorithms/aead_aes_256_cbc_hmac_sha256.go b/internal/github.com/swisscom/mssql-always-encrypted/pkg/algorithms/aead_aes_256_cbc_hmac_sha256.go index 7ccab4db..1f994332 100644 --- a/internal/github.com/swisscom/mssql-always-encrypted/pkg/algorithms/aead_aes_256_cbc_hmac_sha256.go +++ b/internal/github.com/swisscom/mssql-always-encrypted/pkg/algorithms/aead_aes_256_cbc_hmac_sha256.go @@ -2,6 +2,7 @@ package algorithms import ( "bytes" + "crypto/rand" "fmt" "github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/crypto" @@ -48,8 +49,29 @@ func NewAeadAes256CbcHmac256Algorithm(key keys.AeadAes256CbcHmac256, encType enc return a } -func (a *AeadAes256CbcHmac256Algorithm) Encrypt(bytes []byte) ([]byte, error) { - panic("implement me") +func (a *AeadAes256CbcHmac256Algorithm) Encrypt(cleartext []byte) ([]byte, error) { + buf := make([]byte, 0) + var iv []byte + if a.deterministic { + iv = crypto.Sha256Hmac(cleartext, a.cek.IvKey()) + if len(iv) > a.blockSizeBytes { + iv = iv[:a.blockSizeBytes] + } + } else { + iv = make([]byte, a.blockSizeBytes) + _, err := rand.Read(iv) + if err != nil { + panic(err) + } + } + buf = append(buf, a.algorithmVersion) + authTag := a.prepareAuthTag(iv, cleartext) + buf = append(buf, authTag...) + buf = append(buf, iv...) + aescdbc := crypto.NewAESCbcPKCS5(a.cek.EncryptionKey(), iv) + ciphertext := aescdbc.Encrypt(cleartext) + buf = append(buf, ciphertext...) + return buf, nil } func (a *AeadAes256CbcHmac256Algorithm) Decrypt(ciphertext []byte) ([]byte, error) { diff --git a/internal/github.com/swisscom/mssql-always-encrypted/pkg/crypto/aes_cbc_pkcs5.go b/internal/github.com/swisscom/mssql-always-encrypted/pkg/crypto/aes_cbc_pkcs5.go index 6562fca2..4ea2e5be 100644 --- a/internal/github.com/swisscom/mssql-always-encrypted/pkg/crypto/aes_cbc_pkcs5.go +++ b/internal/github.com/swisscom/mssql-always-encrypted/pkg/crypto/aes_cbc_pkcs5.go @@ -10,9 +10,9 @@ import ( // Inspired by: https://gist.github.com/hothero/7d085573f5cb7cdb5801d7adcf66dcf3 type AESCbcPKCS5 struct { - key []byte - iv []byte - block cipher.Block + key []byte + iv []byte + block cipher.Block } func NewAESCbcPKCS5(key []byte, iv []byte) AESCbcPKCS5 { @@ -25,15 +25,16 @@ func NewAESCbcPKCS5(key []byte, iv []byte) AESCbcPKCS5 { return a } -func (a AESCbcPKCS5) Encrypt(cleartext []byte) { +func (a AESCbcPKCS5) Encrypt(cleartext []byte) (cipherText []byte) { if a.block == nil { a.initCipher() } blockMode := cipher.NewCBCEncrypter(a.block, a.iv) paddedCleartext := PKCS5Padding(cleartext, blockMode.BlockSize()) - var cipherText = make([]byte, 0) + cipherText = make([]byte, len(paddedCleartext)) blockMode.CryptBlocks(cipherText, paddedCleartext) + return } func (a AESCbcPKCS5) Decrypt(ciphertext []byte) []byte { diff --git a/mssql.go b/mssql.go index dfc883e9..feac88d0 100644 --- a/mssql.go +++ b/mssql.go @@ -667,7 +667,7 @@ func (s *Stmt) makeRPCParams(args []namedValue, isProc bool) ([]param, []string, } params[i+offset].cipherInfo = metadata params[i+offset].ti.TypeId = typeBigVarBin - params[i+offset].ti.Buffer = encryptedBytes + params[i+offset].buffer = encryptedBytes params[i+offset].ti.Size = 0 } decls[i] = fmt.Sprintf("%s %s%s", name, makeDecl(params[i+offset].ti), output) diff --git a/rpc.go b/rpc.go index 17a4e5f0..8f1ef2b4 100644 --- a/rpc.go +++ b/rpc.go @@ -86,7 +86,6 @@ func sendRpc(buf *tdsBuffer, headers []headerStruct, proc procId, flags uint16, if err != nil { return } - param.tiOriginal.Writer(buf, param.tiOriginal, param.buffer) if _, err = buf.Write(param.cipherInfo); err != nil { return } diff --git a/types.go b/types.go index 1fd25a0d..e0af243c 100644 --- a/types.go +++ b/types.go @@ -440,7 +440,7 @@ func readByteLenType(ti *typeInfo, r *tdsBuffer, c *cryptoMetadata) interface{} default: badStreamPanicf("Invalid typeid") } - panic("shoulnd't get here") + panic("shouldn't get here") } func writeByteLenType(w io.Writer, ti typeInfo, buf []byte) (err error) { From 0dcb6026edcc1a1bac35180ce1c91f9a654a7cc6 Mon Sep 17 00:00:00 2001 From: davidshi Date: Thu, 6 Jul 2023 15:57:25 -0500 Subject: [PATCH 16/47] don't claim to support enclaves --- alwaysencrypted_windows_test.go | 2 +- tds.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/alwaysencrypted_windows_test.go b/alwaysencrypted_windows_test.go index 2ec6fd92..f773b104 100644 --- a/alwaysencrypted_windows_test.go +++ b/alwaysencrypted_windows_test.go @@ -76,7 +76,7 @@ const ( dropColumnEncryptionKey = `DROP COLUMN ENCRYPTION KEY [%s]` createEncryptedTable = `CREATE TABLE mssqlAlwaysEncrypted (col1 int - ENCRYPTED WITH (ENCRYPTION_TYPE = RANDOMIZED, + ENCRYPTED WITH (ENCRYPTION_TYPE = DETERMINISTIC, ALGORITHM = 'AEAD_AES_256_CBC_HMAC_SHA_256', COLUMN_ENCRYPTION_KEY = [%s]), col2 nchar(10) COLLATE Latin1_General_BIN2 diff --git a/tds.go b/tds.go index 772c98de..6292b682 100644 --- a/tds.go +++ b/tds.go @@ -1370,5 +1370,5 @@ func (f *featureExtColumnEncryption) toBytes() []byte { with the additional ability to cache column encryption keys that are to be sent to the enclave and the ability to retry queries when the keys sent by the client do not match what is needed for the query to run. */ - return []byte{0x02} + return []byte{0x01} } From 2346b5d5f92ed25dea186342ddc4ad8e794b086b Mon Sep 17 00:00:00 2001 From: davidshi Date: Mon, 10 Jul 2023 14:11:33 -0500 Subject: [PATCH 17/47] fix encrypt --- alwaysencrypted_windows_test.go | 27 ++++++++++------- encrypt.go | 10 +++---- .../aead_aes_256_cbc_hmac_sha256.go | 6 ++-- mssql.go | 4 ++- tds_test.go | 29 +++++++++++++++---- 5 files changed, 52 insertions(+), 24 deletions(-) diff --git a/alwaysencrypted_windows_test.go b/alwaysencrypted_windows_test.go index f773b104..cf868065 100644 --- a/alwaysencrypted_windows_test.go +++ b/alwaysencrypted_windows_test.go @@ -30,24 +30,27 @@ func TestAlwaysEncryptedE2E(t *testing.T) { defer conn.Exec(fmt.Sprintf(dropColumnMasterKey, certPath)) r, _ := rand.Int(rand.Reader, big.NewInt(1000)) cekName := fmt.Sprintf("mssqlCek%d", r.Int64()) - encryptedCek := localcert.WindowsCertificateStoreKeyProvider.EncryptColumnEncryptionKey(certPath, KeyEncryptionAlgorithm, []byte(certPath)) + tableName := fmt.Sprintf("mssqlAe%d", r.Int64()) + keyBytes := make([]byte, 32) + _, _ = rand.Read(keyBytes) + encryptedCek := localcert.WindowsCertificateStoreKeyProvider.EncryptColumnEncryptionKey(certPath, KeyEncryptionAlgorithm, keyBytes) createCek := fmt.Sprintf(createColumnEncryptionKey, cekName, certPath, encryptedCek) _, err = conn.Exec(createCek) if err != nil { t.Fatalf("Unable to create CEK: %s", err.Error()) } defer conn.Exec(fmt.Sprintf(dropColumnEncryptionKey, cekName)) - _, _ = conn.Exec("DROP TABLE IF EXISTS mssqlAlwaysEncrypted") - _, err = conn.Exec(fmt.Sprintf(createEncryptedTable, cekName, cekName)) + _, _ = conn.Exec("DROP TABLE IF EXISTS " + tableName) + _, err = conn.Exec(fmt.Sprintf(createEncryptedTable, tableName, cekName, cekName)) if err != nil { t.Fatalf("Failed to create encrypted table %s", err.Error()) } - defer conn.Exec("DROP TABLE IF EXISTS mssqlAlwaysEncrypted") - _, err = conn.Exec("INSERT INTO mssqlAlwaysEncrypted VALUES (@p1, @p2)", int32(1), NChar("mycol2")) + defer conn.Exec("DROP TABLE IF EXISTS " + tableName) + _, err = conn.Exec("INSERT INTO "+tableName+" VALUES (@p1, @p2)", int32(1), NChar("mycol2")) if err != nil { - t.Fatalf("Failed to insert row in encrypted table %s", err.Error()) + t.Logf("Failed to insert row in encrypted table %s", err.Error()) } - rows, err := conn.Query("select top (1) col1, col2 from mssqlAlwaysEncrypted") + rows, err := conn.Query("select top (1) col1, col2 from " + tableName) if err != nil { t.Fatalf("Unable to query encrypted columns: %v", err.(Error).All) } @@ -55,13 +58,17 @@ func TestAlwaysEncryptedE2E(t *testing.T) { rows.Close() t.Fatalf("rows.Next returned false") } - var col1 string - var col2 int32 + var col1 int32 + var col2 string err = rows.Scan(&col1, &col2) if err != nil { rows.Close() t.Fatalf("rows.Scan failed: %s", err.Error()) } + if col1 != 1 || col2 != "mycol2" { + rows.Close() + t.Fatalf("Got incorrect scan values %d and %s", col1, col2) + } rows.Close() err = rows.Err() if err != nil { @@ -74,7 +81,7 @@ const ( dropColumnMasterKey = `DROP COLUMN MASTER KEY [%s]` createColumnEncryptionKey = `CREATE COLUMN ENCRYPTION KEY [%s] WITH VALUES (COLUMN_MASTER_KEY = [%s], ALGORITHM = 'RSA_OAEP', ENCRYPTED_VALUE = 0x%x )` dropColumnEncryptionKey = `DROP COLUMN ENCRYPTION KEY [%s]` - createEncryptedTable = `CREATE TABLE mssqlAlwaysEncrypted + createEncryptedTable = `CREATE TABLE %s (col1 int ENCRYPTED WITH (ENCRYPTION_TYPE = DETERMINISTIC, ALGORITHM = 'AEAD_AES_256_CBC_HMAC_SHA_256', diff --git a/encrypt.go b/encrypt.go index c45705ae..8377b313 100644 --- a/encrypt.go +++ b/encrypt.go @@ -11,7 +11,6 @@ import ( "github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/algorithms" "github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/encryption" "github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/keys" - "github.com/microsoft/go-mssqldb/msdsn" ) type ColumnEncryptionType int @@ -73,7 +72,9 @@ func (s *Stmt) encryptArgs(ctx context.Context, args []namedValue) (encryptedArg if err != nil { return } - s.c.sess.logger.Log(ctx, msdsn.LogDebug, fmt.Sprintf("cekInfo: %v\nparamsInfo:%v\n", cekInfo, paramsInfo)) + if len(cekInfo) == 0 { + return args, nil + } err = s.decryptCek(cekInfo) if err != nil { return @@ -161,7 +162,7 @@ func getEncryptor(info paramMapEntry) valueEncryptor { // CekVersion (ulong) // CekMDVersion (ulonglong) - really a byte array // NormVersion (byte) - // algo+ enctype+ dbid+ keyid+ keyver= normversion + // algo+ enctype+ dbid+ keyid+ keyver+ normversion metadataLen := 1 + 1 + 4 + 4 + 4 + 1 metadataLen += len(info.cek.metadataVersion) metadata := make([]byte, metadataLen) @@ -245,9 +246,8 @@ func processDescribeParameterEncryption(rows driver.Rows) (cekInfo []*cekData, p if len(cekInfo) == 0 || qerr != io.EOF { if qerr != io.EOF { err = qerr - } else { - err = fmt.Errorf("No column encryption key rows were returned from sp_describe_parameter_encryption") } + // No encryption needed return } r := rows.(driver.RowsNextResultSet) diff --git a/internal/github.com/swisscom/mssql-always-encrypted/pkg/algorithms/aead_aes_256_cbc_hmac_sha256.go b/internal/github.com/swisscom/mssql-always-encrypted/pkg/algorithms/aead_aes_256_cbc_hmac_sha256.go index 1f994332..cb4def7c 100644 --- a/internal/github.com/swisscom/mssql-always-encrypted/pkg/algorithms/aead_aes_256_cbc_hmac_sha256.go +++ b/internal/github.com/swisscom/mssql-always-encrypted/pkg/algorithms/aead_aes_256_cbc_hmac_sha256.go @@ -65,11 +65,11 @@ func (a *AeadAes256CbcHmac256Algorithm) Encrypt(cleartext []byte) ([]byte, error } } buf = append(buf, a.algorithmVersion) - authTag := a.prepareAuthTag(iv, cleartext) - buf = append(buf, authTag...) - buf = append(buf, iv...) aescdbc := crypto.NewAESCbcPKCS5(a.cek.EncryptionKey(), iv) ciphertext := aescdbc.Encrypt(cleartext) + authTag := a.prepareAuthTag(iv, ciphertext) + buf = append(buf, authTag...) + buf = append(buf, iv...) buf = append(buf, ciphertext...) return buf, nil } diff --git a/mssql.go b/mssql.go index feac88d0..b1dd71cc 100644 --- a/mssql.go +++ b/mssql.go @@ -654,6 +654,7 @@ func (s *Stmt) makeRPCParams(args []namedValue, isProc bool) ([]param, []string, if isOutputValue(val.Value) { output = outputSuffix } + tiDecl := params[i+offset].ti if val.encrypt != nil { // Encrypted parameters have a few requirements: // 1. Copy original typeinfo to a block after the data @@ -670,7 +671,8 @@ func (s *Stmt) makeRPCParams(args []namedValue, isProc bool) ([]param, []string, params[i+offset].buffer = encryptedBytes params[i+offset].ti.Size = 0 } - decls[i] = fmt.Sprintf("%s %s%s", name, makeDecl(params[i+offset].ti), output) + + decls[i] = fmt.Sprintf("%s %s%s", name, makeDecl(tiDecl), output) } return params, decls, nil diff --git a/tds_test.go b/tds_test.go index 6b1c6481..5b540e8f 100644 --- a/tds_test.go +++ b/tds_test.go @@ -35,7 +35,7 @@ func TestConstantsDefined(t *testing.T) { // This test is just here to avoid complaints about unused code. // These constants are part of the spec but not yet used. for _, b := range []byte{ - featExtSESSIONRECOVERY, featExtCOLUMNENCRYPTION, featExtGLOBALTRANSACTIONS, + featExtSESSIONRECOVERY, featExtGLOBALTRANSACTIONS, featExtAZURESQLSUPPORT, featExtDATACLASSIFICATION, featExtUTF8SUPPORT, } { if b == 0 { @@ -131,7 +131,8 @@ func TestSendLoginWithFeatureExt(t *testing.T) { if err != nil { t.Error("sendLogin should succeed") } - ref := []byte{ + // featureext ordering is non-deterministic + ref1 := []byte{ 16, 1, 0, 0xe5, 0, 0, 1, 0, 0xdd, 0, 0, 0, 4, 0, 0, 116, 0, 16, 0, 0, 0, 1, 6, 1, 100, 0, 0, 0, 0, 0, 0, 0, 224, 0, 0, 24, 16, 255, 255, 255, 4, 2, 0, 0, 94, 0, 7, 0, @@ -146,11 +147,29 @@ func TestSendLoginWithFeatureExt(t *testing.T) { 98, 0, 97, 0, 115, 0, 101, 0, 180, 0, 0, 0, 2, 29, 0, 0, 0, 2, 24, 0, 0, 0, 102, 0, 101, 0, 100, 0, 97, 0, 117, 0, 116, 0, 104, 0, 116, 0, 111, 0, 107, 0, 101, 0, 110, 0, 4, 1, - 0, 0, 0, 2, 255} + 0, 0, 0, 1, 255} + ref2 := []byte{ + 16, 1, 0, 0xe5, 0, 0, 1, 0, 0xdd, 0, 0, 0, 4, 0, 0, 116, + 0, 16, 0, 0, 0, 1, 6, 1, 100, 0, 0, 0, 0, 0, 0, 0, + 224, 0, 0, 24, 16, 255, 255, 255, 4, 2, 0, 0, 94, 0, 7, 0, + 108, 0, 0, 0, 108, 0, 0, 0, 108, 0, 7, 0, 122, 0, 10, 0, + 176, 0, 4, 0, 142, 0, 7, 0, 156, 0, 2, 0, 160, 0, 8, 0, + 18, 52, 86, 120, 144, 171, 176, 0, 0, 0, 176, 0, 0, 0, 176, 0, + 0, 0, 0, 0, 0, 0, 115, 0, 117, 0, 98, 0, 100, 0, 101, 0, + 118, 0, 49, 0, 97, 0, 112, 0, 112, 0, 110, 0, 97, 0, 109, 0, + 101, 0, 115, 0, 101, 0, 114, 0, 118, 0, 101, 0, 114, 0, 110, 0, + 97, 0, 109, 0, 101, 0, 108, 0, 105, 0, 98, 0, 114, 0, 97, 0, + 114, 0, 121, 0, 101, 0, 110, 0, 100, 0, 97, 0, 116, 0, 97, 0, + 98, 0, 97, 0, 115, 0, 101, 0, 180, 0, 0, 0, 4, 1, + 0, 0, 0, 1, 2, 29, 0, 0, + 0, 2, 24, 0, 0, 0, 102, 0, 101, 0, 100, 0, 97, 0, 117, 0, + 116, 0, 104, 0, 116, 0, 111, 0, 107, 0, 101, 0, 110, 0, 255} out := memBuf.Bytes() - if !bytes.Equal(ref, out) { + if !bytes.Equal(ref1, out) && !bytes.Equal(ref2, out) { t.Log("Expected:") - t.Log(hex.Dump(ref)) + t.Log(hex.Dump(ref1)) + t.Log("Or:") + t.Log(hex.Dump(ref2)) t.Log("Returned:") t.Log(hex.Dump(out)) t.Fatal("input output don't match") From c4bd2b1a9b7d7c1d0bc341bb5aef21d93dbe08bb Mon Sep 17 00:00:00 2001 From: davidshi Date: Mon, 10 Jul 2023 14:40:12 -0500 Subject: [PATCH 18/47] close Rows when done --- encrypt.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/encrypt.go b/encrypt.go index 8377b313..da35e02a 100644 --- a/encrypt.go +++ b/encrypt.go @@ -69,6 +69,7 @@ func (s *Stmt) encryptArgs(ctx context.Context, args []namedValue) (encryptedArg return } cekInfo, paramsInfo, err := processDescribeParameterEncryption(rows) + rows.Close() if err != nil { return } @@ -109,13 +110,12 @@ func (s *Stmt) encryptArgs(ctx context.Context, args []namedValue) (encryptedArg // [ @params = ] N'parameters' // [ ;] func (s *Stmt) prepareEncryptionQuery(isProc bool, q string, args []namedValue) (newArgs []namedValue, err error) { + newArgs = make([]namedValue, 2) if isProc { - newArgs = make([]namedValue, 1) newArgs[0] = namedValue{Name: "tsql", Ordinal: 0, Value: buildStoredProcedureStatementForColumnEncryption(q, args)} - return + } else { + newArgs[0] = namedValue{Name: "tsql", Ordinal: 0, Value: q} } - newArgs = make([]namedValue, 2) - newArgs[0] = namedValue{Name: "tsql", Ordinal: 0, Value: q} params, err := s.buildParametersForColumnEncryption(args) if err != nil { return From 0d97e9e2039d515e061549e1f14ea8bd0ec39928 Mon Sep 17 00:00:00 2001 From: davidshi Date: Mon, 10 Jul 2023 17:15:28 -0500 Subject: [PATCH 19/47] fix bulk copy --- bulkcopy.go | 4 ++++ internal/certs/certs.go | 1 + tds.go | 3 +-- token.go | 2 +- 4 files changed, 7 insertions(+), 3 deletions(-) diff --git a/bulkcopy.go b/bulkcopy.go index 97edb7be..15512a9e 100644 --- a/bulkcopy.go +++ b/bulkcopy.go @@ -250,6 +250,10 @@ func (b *Bulk) createColMetadata() []byte { buf.WriteByte(byte(tokenColMetadata)) // token binary.Write(buf, binary.LittleEndian, uint16(len(b.bulkColumns))) // column count + // TODO: Write a valid CEK table if any parameters have cekTableEntry values + if b.cn.sess.alwaysEncrypted { + binary.Write(buf, binary.LittleEndian, uint16(0)) + } for i, col := range b.bulkColumns { if b.cn.sess.loginAck.TDSVersion >= verTDS72 { diff --git a/internal/certs/certs.go b/internal/certs/certs.go index 24e372a7..dfa9a969 100644 --- a/internal/certs/certs.go +++ b/internal/certs/certs.go @@ -9,6 +9,7 @@ import ( "github.com/Microsoft/go-winio/pkg/guid" ) +// TODO: Create a Linux equivalent. const ( createUserCertScript = `New-SelfSignedCertificate -Subject "%s" -CertStoreLocation Cert:CurrentUser\My -KeyExportPolicy Exportable -Type DocumentEncryptionCert -KeyUsage KeyEncipherment -Keyspec KeyExchange -KeyLength 2048 -HashAlgorithm 'SHA256' | select {$_.Thumbprint}` deleteUserCertScript = `Get-ChildItem Cert:\CurrentUser\My\%s | Remove-Item -DeleteKey` diff --git a/tds.go b/tds.go index 6292b682..dfa4a6aa 100644 --- a/tds.go +++ b/tds.go @@ -1159,9 +1159,8 @@ initiate_connection: buf: outbuf, logger: logger, logFlags: uint64(p.LogFlags), - aeSettings: &alwaysEncryptedSettings{keyProviders: make(aecmk.ColumnEncryptionKeyProviderMap)}, + aeSettings: &alwaysEncryptedSettings{keyProviders: aecmk.GetGlobalCekProviders()}, } - sess.aeSettings.keyProviders = aecmk.GetGlobalCekProviders() for i, p := range c.keyProviders { sess.aeSettings.keyProviders[i] = p diff --git a/token.go b/token.go index 58c3f5d9..8524a68e 100644 --- a/token.go +++ b/token.go @@ -1032,7 +1032,7 @@ func processSingleResponse(ctx context.Context, sess *tdsSession, ch chan tokenS ch <- done if done.Status&doneCount != 0 { if sess.logFlags&logRows != 0 { - sess.logger.Log(ctx, msdsn.LogRows, fmt.Sprintf("(%d row(s) affected)", done.RowCount)) + sess.logger.Log(ctx, msdsn.LogRows, fmt.Sprintf("(Rows affected: %d)", done.RowCount)) } if (colsReceived || done.CurCmd != cmdSelect) && outs.msgq != nil { From 954472a494dfa3a0aa75ddc3c8da09b55c6c4ae6 Mon Sep 17 00:00:00 2001 From: davidshi Date: Mon, 10 Jul 2023 19:12:39 -0500 Subject: [PATCH 20/47] fix return value --- encrypt.go | 7 ++++++- token.go | 2 +- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/encrypt.go b/encrypt.go index da35e02a..1cb81714 100644 --- a/encrypt.go +++ b/encrypt.go @@ -59,6 +59,9 @@ func (s *Stmt) encryptArgs(ctx context.Context, args []namedValue) (encryptedArg query: "sp_describe_parameter_encryption", skipEncryption: true, } + oldouts := s.c.outs + s.c.clearOuts() + newArgs, err := s.prepareEncryptionQuery(isProc(s.query), s.query, args) if err != nil { return @@ -66,10 +69,12 @@ func (s *Stmt) encryptArgs(ctx context.Context, args []namedValue) (encryptedArg // TODO: Consider not using recursion rows, err := q.queryContext(ctx, newArgs) if err != nil { + s.c.outs = oldouts return } cekInfo, paramsInfo, err := processDescribeParameterEncryption(rows) rows.Close() + s.c.outs = oldouts if err != nil { return } @@ -95,7 +100,7 @@ func (s *Stmt) encryptArgs(ctx context.Context, args []namedValue) (encryptedArg } info := paramMap[name] - if info.p.encType == ColumnEncryptionPlainText { + if info.p.encType == ColumnEncryptionPlainText || a.Value == nil { continue } diff --git a/token.go b/token.go index 8524a68e..694bd3e7 100644 --- a/token.go +++ b/token.go @@ -919,7 +919,7 @@ func parseReturnValue(r *tdsBuffer, s *tdsSession) (nv namedValue) { ti := getBaseTypeInfo(r, true) // UserType + Flags + TypeInfo var cryptoMetadata *cryptoMetadata = nil - if s.alwaysEncrypted { + if s.alwaysEncrypted && (ti.Flags&fEncrypted) == fEncrypted { cm := parseCryptoMetadata(r, nil) // CryptoMetadata cryptoMetadata = &cm } From fc4e1d88089caeb8d8fea0b94c9fca6f206286f0 Mon Sep 17 00:00:00 2001 From: davidshi Date: Tue, 11 Jul 2023 12:04:29 -0500 Subject: [PATCH 21/47] fix unnamed params to sprocs --- aecmk/localcert/keyprovider.go | 7 ++++--- appveyor.yml | 2 +- encrypt.go | 15 ++++++++++----- 3 files changed, 15 insertions(+), 9 deletions(-) diff --git a/aecmk/localcert/keyprovider.go b/aecmk/localcert/keyprovider.go index 900f2111..d2f8f201 100644 --- a/aecmk/localcert/keyprovider.go +++ b/aecmk/localcert/keyprovider.go @@ -7,7 +7,6 @@ import ( "crypto/sha1" "crypto/sha256" "crypto/x509" - "encoding/binary" "fmt" "io/ioutil" "os" @@ -149,15 +148,17 @@ func (p *LocalCertProvider) EncryptColumnEncryptionKey(masterKeyPath string, enc if err != nil { panic(fmt.Errorf("Unable to serialize key path %w", err)) } + k := uint16(len(keyPathBytes)) // keyPathLength - buf = binary.LittleEndian.AppendUint16(buf, uint16(len(keyPathBytes))) + buf = append(buf, byte(k), byte(k>>8)) cipherText, err := rsa.EncryptOAEP(sha1.New(), rand.Reader, publicKey, cek, []byte{}) if err != nil { panic(fmt.Errorf("Unable to encrypt data %w", err)) } + l := uint16(len(cipherText)) // ciphertextLength - buf = binary.LittleEndian.AppendUint16(buf, uint16(len(cipherText))) + buf = append(buf, byte(l), byte(l>>8)) // keypath buf = append(buf, keyPathBytes...) // ciphertext diff --git a/appveyor.yml b/appveyor.yml index dafa9729..4b90a402 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -70,7 +70,7 @@ install: - go get -u github.com/golang-sql/civil - go get -u github.com/golang-sql/sqlexp - go get -u golang.org/x/crypto/md4 - + - go get -u golang.org/x/text/encoding/unicode build_script: - go build diff --git a/encrypt.go b/encrypt.go index 1cb81714..86add1e8 100644 --- a/encrypt.go +++ b/encrypt.go @@ -61,12 +61,11 @@ func (s *Stmt) encryptArgs(ctx context.Context, args []namedValue) (encryptedArg } oldouts := s.c.outs s.c.clearOuts() - newArgs, err := s.prepareEncryptionQuery(isProc(s.query), s.query, args) if err != nil { return } - // TODO: Consider not using recursion + // TODO: Consider not using recursion. rows, err := q.queryContext(ctx, newArgs) if err != nil { s.c.outs = oldouts @@ -212,9 +211,15 @@ func buildStoredProcedureStatementForColumnEncryption(sproc string, args []named } first = false b.WriteRune(' ') - appendPrefixedParameterName(b, a.Name) - b.WriteRune('=') - appendPrefixedParameterName(b, a.Name) + name := a.Name + if len(name) == 0 { + name = fmt.Sprintf("@p%d", a.Ordinal) + } + appendPrefixedParameterName(b, name) + if len(a.Name) > 0 { + b.WriteRune('=') + appendPrefixedParameterName(b, a.Name) + } if isOutputValue(a.Value) { b.WriteString(" OUTPUT") } From 9c6c679939162787ace075345fc47ccd6127c1e6 Mon Sep 17 00:00:00 2001 From: davidshi Date: Tue, 11 Jul 2023 13:38:09 -0500 Subject: [PATCH 22/47] update readme --- README.md | 49 ++++++++++++++++++- .../pkg/encryption/type.go | 2 +- 2 files changed, 49 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 03162f3e..75ea0a20 100644 --- a/README.md +++ b/README.md @@ -64,6 +64,7 @@ Other supported formats are listed below. * `Workstation ID` - The workstation name (default is the host name) * `ApplicationIntent` - Can be given the value `ReadOnly` to initiate a read-only connection to an Availability Group listener. The `database` must be specified when connecting with `Application Intent` set to `ReadOnly`. * `protocol` - forces use of a protocol. Make sure the corresponding package is imported. +* `columnencryption` or `column encryption setting` - a boolean value indicating whether Always Encrypted should be enabled on the connection. ### Connection parameters for namedpipe package * `pipe` - If set, no Browser query is made and named pipe used will be `\\\pipe\` @@ -374,8 +375,51 @@ db.QueryContext(ctx, `select * from t2 where user_name = @p1;`, mssql.VarChar(na // Note: Mismatched data types on table and parameter may cause long running queries ``` +## Using Always Encrypted + +The protocol and cryptography details for AE are [detailed elsewhere](https://learn.microsoft.com/sql/relational-databases/security/encryption/always-encrypted-database-engine?view=sql-server-ver16). + +### Enablement + +To enable AE on a connection, set the `ColumnEncryption` value to true on a config or pass `columnencryption=true` in the connection string. + +Decryption and encryption won't succeed, however, without also including a decryption key provider. To avoid code size impacts on non-AE applications, key providers are not included by default. + +Include the local certificate providers: + +```go + import ( + "github.com/microsoft/go-mssqldb/aecmk/localcert" + ) + ``` + +You can also instantiate a key provider directly in code and hand it to a `Connector` instance. + +```go +c := mssql.NewConnectorConfig(myconfig) +c.RegisterCekProvider(providerName, MyProviderType{}) +``` + +### Decryption + +If the correct key provider is included in your application, decryption of encrypted cells happens automatically with no extra server round trips. + +### Encryption + +Encryption of parameters passed to `Exec` and `Query` variants requires an extra round trip per query to fetch the encryption metadata. If the error returned by a query attempt indicates a type mismatch between the parameter and the destination table, most likely your input type is not a strict match for the SQL Server data type of the destination. You may be using a Go `string` when you need to use one of the driver-specific aliases like `VarChar` or `NVarCharMax`. + +### Local certificate AE key provider + +Key provider configuration is managed separately without any properties in the connection string. +The `pfx` provider exposes its instance as the variable `PfxKeyProvider`. You can give it passwords for certificates using `SetCertificatePassword(pathToCertificate, path)`. Use an empty string or `"*"` as the path to use the same password for all certificates. + +The `MSSQL_CERTIFICATE_STORE` provider exposes its instance as the variable `WindowsCertificateStoreKeyProvider`. + +Both providers can be constrained to an allowed list of encryption key paths by appending paths to `provider.AllowedLocations`. + ## Important Notes + * [LastInsertId](https://golang.org/pkg/database/sql/#Result.LastInsertId) should not be used with this driver (or SQL Server) due to how the TDS protocol works. Please use the [OUTPUT Clause](https://docs.microsoft.com/en-us/sql/t-sql/queries/output-clause-transact-sql) @@ -406,7 +450,9 @@ db.QueryContext(ctx, `select * from t2 where user_name = @p1;`, mssql.VarChar(na * A `namedpipe` package to support connections using named pipes (np:) on Windows * A `sharedmemory` package to support connections using shared memory (lpc:) on Windows * Dedicated Administrator Connection (DAC) is supported using `admin` protocol - +* Always Encrypted + - `MSSQL_CERTIFICATE_STORE` provider on Windows + - `pfx` provider on Linux and Windows ## Tests `go test` is used for testing. A running instance of MSSQL server is required. @@ -446,6 +492,7 @@ To fix SQL Server 2008 R2 issue, install SQL Server 2008 R2 Service Pack 2. To fix SQL Server 2008 issue, install Microsoft SQL Server 2008 Service Pack 3 and Cumulative update package 3 for SQL Server 2008 SP3. More information: +* Bulk copy does not yet support encrypting column values using Always Encrypted. Tracked in [#127](https://github.com/microsoft/go-mssqldb/issues/127) # Contributing This project is a fork of [https://github.com/denisenkom/go-mssqldb](https://github.com/denisenkom/go-mssqldb) and welcomes new and previous contributors. For more informaton on contributing to this project, please see [Contributing](./CONTRIBUTING.md). diff --git a/internal/github.com/swisscom/mssql-always-encrypted/pkg/encryption/type.go b/internal/github.com/swisscom/mssql-always-encrypted/pkg/encryption/type.go index b38cccd6..a46dc3d7 100644 --- a/internal/github.com/swisscom/mssql-always-encrypted/pkg/encryption/type.go +++ b/internal/github.com/swisscom/mssql-always-encrypted/pkg/encryption/type.go @@ -34,4 +34,4 @@ func From(encType byte) Type { return Randomized } return Plaintext -} \ No newline at end of file +} From 45896d34abdc572ff138aa2499e26d29a8ffb1e7 Mon Sep 17 00:00:00 2001 From: davidshi Date: Tue, 11 Jul 2023 14:09:39 -0500 Subject: [PATCH 23/47] Remove allocations from unmarshalRSA --- internal/certs/certs_windows.go | 91 +++++++++++---------------------- 1 file changed, 31 insertions(+), 60 deletions(-) diff --git a/internal/certs/certs_windows.go b/internal/certs/certs_windows.go index 757db289..ea786174 100644 --- a/internal/certs/certs_windows.go +++ b/internal/certs/certs_windows.go @@ -1,10 +1,8 @@ package certs import ( - "bytes" "crypto/rsa" "crypto/x509" - "encoding/binary" "errors" "fmt" "math/big" @@ -124,23 +122,20 @@ func nCryptExportKey(kh windows.Handle, blobType string) ([]byte, error) { return buf, nil } -// TODO: See if we can rewrite this to avoid copying the data from buf twice per field +type RSA_HEADER struct { + Magic uint32 + BitLength uint32 + PublicExpSize uint32 + ModulusSize uint32 + Prime1Size uint32 + Prime2Size uint32 +} + func unmarshalRSA(buf []byte) (*rsa.PrivateKey, error) { // BCRYPT_RSA_BLOB -- https://learn.microsoft.com/windows/win32/api/bcrypt/ns-bcrypt-bcrypt_rsakey_blob - header := struct { - Magic uint32 - BitLength uint32 - PublicExpSize uint32 - ModulusSize uint32 - Prime1Size uint32 - Prime2Size uint32 - }{} - - r := bytes.NewReader(buf) - if err := binary.Read(r, binary.LittleEndian, &header); err != nil { - return nil, err - } - + cbHeader := uint32(unsafe.Sizeof(RSA_HEADER{})) + header := (*(*RSA_HEADER)(unsafe.Pointer(&buf[0]))) + buf = buf[cbHeader:] if header.Magic != 0x33415352 { // "RSA3" BCRYPT_RSAFULLPRIVATE_MAGIC return nil, fmt.Errorf("invalid header magic %x", header.Magic) } @@ -149,54 +144,30 @@ func unmarshalRSA(buf []byte) (*rsa.PrivateKey, error) { return nil, fmt.Errorf("unsupported public exponent size (%d bits)", header.PublicExpSize*8) } - // the exponent is in BigEndian format, so read the data into the right place in the buffer - exp := make([]byte, 8) - n, err := r.Read(exp[8-header.PublicExpSize:]) - - if err != nil { - return nil, fmt.Errorf("failed to read public exponent %w", err) - } - - if n != int(header.PublicExpSize) { - return nil, fmt.Errorf("failed to read correct public exponent size, read %d expected %d", n, int(header.PublicExpSize)) - } - - mod := make([]byte, header.ModulusSize) - n, err = r.Read(mod) - - if err != nil { - return nil, fmt.Errorf("failed to read modulus %w", err) - } - - if n != int(header.ModulusSize) { - return nil, fmt.Errorf("failed to read correct modulus size, read %d expected %d", n, int(header.ModulusSize)) + consumeBigInt := func(size uint32) *big.Int { + b := buf[:size] + buf = buf[size:] + return new(big.Int).SetBytes(b) } + E := consumeBigInt(header.PublicExpSize) + N := consumeBigInt(header.ModulusSize) + P := consumeBigInt(header.Prime1Size) + Q := consumeBigInt(header.Prime2Size) + Dp := consumeBigInt(header.Prime1Size) + Dq := consumeBigInt(header.Prime2Size) + Qinv := consumeBigInt(header.Prime1Size) + D := consumeBigInt(header.ModulusSize) pk := &rsa.PrivateKey{ PublicKey: rsa.PublicKey{ - N: new(big.Int).SetBytes(mod), - E: int(binary.BigEndian.Uint64(exp)), + N: N, + E: int(E.Int64()), + }, + D: D, + Primes: []*big.Int{P, Q}, + Precomputed: rsa.PrecomputedValues{Dp: Dp, + Dq: Dq, Qinv: Qinv, }, - D: new(big.Int), - Primes: make([]*big.Int, 2), - } - prime := make([]byte, header.Prime1Size) - n, err = r.Read(prime) - if err != nil { - return nil, fmt.Errorf("failed to read prime1 %w", err) - } - pk.Primes[0] = new(big.Int).SetBytes(prime) - prime = make([]byte, header.Prime2Size) - n, err = r.Read(prime) - if err != nil { - return nil, fmt.Errorf("failed to read prime2 %w", err) - } - pk.Primes[1] = new(big.Int).SetBytes(prime) - expBytes := make([]byte, 2*header.Prime1Size+header.Prime2Size+header.ModulusSize) - n, err = r.Read(expBytes) - if err != nil { - return nil, fmt.Errorf("Unable to read PrivateExponent %w", err) } - pk.D = new(big.Int).SetBytes(expBytes[2*header.Prime1Size+header.Prime2Size:]) return pk, nil } From 0dc231bc2074b6a40bbb4797778de2e244d08d5a Mon Sep 17 00:00:00 2001 From: davidshi Date: Tue, 11 Jul 2023 17:23:06 -0500 Subject: [PATCH 24/47] remove go-winio dependency --- go.mod | 1 - go.sum | 4 ---- internal/certs/certs.go | 10 ++++------ 3 files changed, 4 insertions(+), 11 deletions(-) diff --git a/go.mod b/go.mod index 462a2650..1f42165a 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,6 @@ go 1.13 require ( github.com/Azure/azure-sdk-for-go/sdk/azcore v1.6.0 github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.3.0 - github.com/Microsoft/go-winio v0.6.1 github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 github.com/golang-sql/sqlexp v0.1.0 github.com/jcmturner/gokrb5/v8 v8.4.4 diff --git a/go.sum b/go.sum index 6090bc1d..c99a24c5 100644 --- a/go.sum +++ b/go.sum @@ -6,8 +6,6 @@ github.com/Azure/azure-sdk-for-go/sdk/internal v1.3.0 h1:sXr+ck84g/ZlZUOZiNELInm github.com/Azure/azure-sdk-for-go/sdk/internal v1.3.0/go.mod h1:okt5dMMTOFjX/aovMlrjvvXoPMBVSPzk9185BT0+eZM= github.com/AzureAD/microsoft-authentication-library-for-go v1.0.0 h1:OBhqkivkhkMqLPymWEppkm7vgPQY2XsHoEkaMQ0AdZY= github.com/AzureAD/microsoft-authentication-library-for-go v1.0.0/go.mod h1:kgDmCTgBzIEPFElEF+FK0SdjAor06dRq2Go927dnQ6o= -github.com/Microsoft/go-winio v0.6.1 h1:9/kr64B9VUZrLm5YYwbGtUJnMgqWVOdUAXu6Migciow= -github.com/Microsoft/go-winio v0.6.1/go.mod h1:LRdKpFKfdobln8UmuiYcKPot9D2v6svN5+sAH+4kjUM= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -50,7 +48,6 @@ github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 h1:KoWmjvw+nsYOo29YJK9 github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8/go.mod h1:HKlIX3XHQyzLZPlr7++PzdhaXEj94dEiJgZDTsxEqUI= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/sirupsen/logrus v1.9.0/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= @@ -89,7 +86,6 @@ golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210616045830-e2b7044e8c71/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/internal/certs/certs.go b/internal/certs/certs.go index dfa9a969..8646266a 100644 --- a/internal/certs/certs.go +++ b/internal/certs/certs.go @@ -3,10 +3,11 @@ package certs import ( "bytes" "fmt" + "math/big" "os/exec" "strings" - "github.com/Microsoft/go-winio/pkg/guid" + "crypto/rand" ) // TODO: Create a Linux equivalent. @@ -16,11 +17,8 @@ const ( ) func ProvisionMasterKeyInCertStore() (thumbprint string, err error) { - var g guid.GUID - if g, err = guid.NewV4(); err != nil { - return - } - subject := fmt.Sprintf(`gomssqltest-%s`, g.String()) + x, _ := rand.Int(rand.Reader, big.NewInt(50000)) + subject := fmt.Sprintf(`gomssqltest-%d`, x) cmd := exec.Command(`C:\Windows\System32\WindowsPowerShell\v1.0\powershell.exe`, `/ExecutionPolicy`, `Unrestricted`, fmt.Sprintf(createUserCertScript, subject)) buf := &memoryBuffer{buf: new(bytes.Buffer)} From 83e98f16e41517922aaf068118b13f2dfa58978d Mon Sep 17 00:00:00 2001 From: davidshi Date: Wed, 12 Jul 2023 17:04:05 -0500 Subject: [PATCH 25/47] fix Scan to use correct data types --- alwaysencrypted_windows_test.go | 7 +++++ mssql.go | 16 +++++----- tds.go | 7 +++++ token.go | 19 +++++++++++- types.go | 54 ++++++++++++++++++--------------- 5 files changed, 70 insertions(+), 33 deletions(-) diff --git a/alwaysencrypted_windows_test.go b/alwaysencrypted_windows_test.go index cf868065..0c6e0a58 100644 --- a/alwaysencrypted_windows_test.go +++ b/alwaysencrypted_windows_test.go @@ -58,6 +58,13 @@ func TestAlwaysEncryptedE2E(t *testing.T) { rows.Close() t.Fatalf("rows.Next returned false") } + cols, err := rows.ColumnTypes() + if err != nil { + t.Fatalf("rows.ColumnTypes failed %s", err.Error()) + } + if cols[0].DatabaseTypeName() != "INT" { + t.Fatalf("Got wrong type name for intcol %s", cols[0].DatabaseTypeName()) + } var col1 int32 var col2 string err = rows.Scan(&col1, &col2) diff --git a/mssql.go b/mssql.go index b1dd71cc..34d02653 100644 --- a/mssql.go +++ b/mssql.go @@ -910,7 +910,7 @@ func (rc *Rows) NextResultSet() error { // the value type that can be used to scan types into. For example, the database // column type "bigint" this should return "reflect.TypeOf(int64(0))". func (r *Rows) ColumnTypeScanType(index int) reflect.Type { - return makeGoLangScanType(r.cols[index].ti) + return makeGoLangScanType(r.cols[index].originalTypeInfo()) } // RowsColumnTypeDatabaseTypeName may be implemented by Rows. It should return the @@ -919,7 +919,7 @@ func (r *Rows) ColumnTypeScanType(index int) reflect.Type { // "DECIMAL", "SMALLINT", "INT", "BIGINT", "BOOL", "[]BIGINT", "JSONB", "XML", // "TIMESTAMP". func (r *Rows) ColumnTypeDatabaseTypeName(index int) string { - return makeGoLangTypeName(r.cols[index].ti) + return makeGoLangTypeName(r.cols[index].originalTypeInfo()) } // RowsColumnTypeLength may be implemented by Rows. It should return the length @@ -935,7 +935,7 @@ func (r *Rows) ColumnTypeDatabaseTypeName(index int) string { // int (0, false) // bytea(30) (30, true) func (r *Rows) ColumnTypeLength(index int) (int64, bool) { - return makeGoLangTypeLength(r.cols[index].ti) + return makeGoLangTypeLength(r.cols[index].originalTypeInfo()) } // It should return @@ -946,7 +946,7 @@ func (r *Rows) ColumnTypeLength(index int) (int64, bool) { // int (0, 0, false) // decimal (math.MaxInt64, math.MaxInt64, true) func (r *Rows) ColumnTypePrecisionScale(index int) (int64, int64, bool) { - return makeGoLangTypePrecisionScale(r.cols[index].ti) + return makeGoLangTypePrecisionScale(r.cols[index].originalTypeInfo()) } // The nullable value should @@ -1358,7 +1358,7 @@ scan: // the value type that can be used to scan types into. For example, the database // column type "bigint" this should return "reflect.TypeOf(int64(0))". func (r *Rowsq) ColumnTypeScanType(index int) reflect.Type { - return makeGoLangScanType(r.cols[index].ti) + return makeGoLangScanType(r.cols[index].originalTypeInfo()) } // RowsColumnTypeDatabaseTypeName may be implemented by Rows. It should return the @@ -1367,7 +1367,7 @@ func (r *Rowsq) ColumnTypeScanType(index int) reflect.Type { // "DECIMAL", "SMALLINT", "INT", "BIGINT", "BOOL", "[]BIGINT", "JSONB", "XML", // "TIMESTAMP". func (r *Rowsq) ColumnTypeDatabaseTypeName(index int) string { - return makeGoLangTypeName(r.cols[index].ti) + return makeGoLangTypeName(r.cols[index].originalTypeInfo()) } // RowsColumnTypeLength may be implemented by Rows. It should return the length @@ -1383,7 +1383,7 @@ func (r *Rowsq) ColumnTypeDatabaseTypeName(index int) string { // int (0, false) // bytea(30) (30, true) func (r *Rowsq) ColumnTypeLength(index int) (int64, bool) { - return makeGoLangTypeLength(r.cols[index].ti) + return makeGoLangTypeLength(r.cols[index].originalTypeInfo()) } // It should return @@ -1394,7 +1394,7 @@ func (r *Rowsq) ColumnTypeLength(index int) (int64, bool) { // int (0, 0, false) // decimal (math.MaxInt64, math.MaxInt64, true) func (r *Rowsq) ColumnTypePrecisionScale(index int) (int64, int64, bool) { - return makeGoLangTypePrecisionScale(r.cols[index].ti) + return makeGoLangTypePrecisionScale(r.cols[index].originalTypeInfo()) } // The nullable value should diff --git a/tds.go b/tds.go index dfa4a6aa..891630c6 100644 --- a/tds.go +++ b/tds.go @@ -201,6 +201,13 @@ func isEncryptedFlag(flags uint16) bool { return colFlagEncrypted == (flags & colFlagEncrypted) } +func (c columnStruct) originalTypeInfo() typeInfo { + if c.isEncrypted() { + return c.cryptoMeta.typeInfo + } + return c.ti +} + type keySlice []uint8 func (p keySlice) Len() int { return len(p) } diff --git a/token.go b/token.go index 694bd3e7..323ddbd7 100644 --- a/token.go +++ b/token.go @@ -821,7 +821,9 @@ func (R RWCBuffer) Close() error { func decryptColumn(column columnStruct, s *tdsSession, columnContent interface{}) tdsBuffer { encType := encryption.From(column.cryptoMeta.encType) cekValue := column.cryptoMeta.entry.cekValues[column.cryptoMeta.ordinal] - s.logger.Log(context.Background(), msdsn.LogMessages, fmt.Sprintf("Decrypting column %s. Key path: %s, Key store:%s, Algo: %s", column.ColName, cekValue.keyPath, cekValue.keyStoreName, cekValue.algorithmName)) + if (s.logFlags & uint64(msdsn.LogDebug)) == uint64(msdsn.LogDebug) { + s.logger.Log(context.Background(), msdsn.LogDebug, fmt.Sprintf("Decrypting column %s. Key path: %s, Key store:%s, Algo: %s", column.ColName, cekValue.keyPath, cekValue.keyStoreName, cekValue.algorithmName)) + } cekProvider, ok := s.aeSettings.keyProviders[cekValue.keyStoreName] if !ok { @@ -838,6 +840,10 @@ func decryptColumn(column columnStruct, s *tdsSession, columnContent interface{} panic(err) } + // Decrypt returns a minimum of 8 bytes so truncate to the actual data size + if column.cryptoMeta.typeInfo.Size > 0 && column.cryptoMeta.typeInfo.Size < len(d) { + d = d[:column.cryptoMeta.typeInfo.Size] + } var newBuff []byte newBuff = append(newBuff, d...) @@ -936,6 +942,17 @@ func processSingleResponse(ctx context.Context, sess *tdsSession, ch chan tokenS if sess.logFlags&logErrors != 0 { sess.logger.Log(ctx, msdsn.LogErrors, fmt.Sprintf("Intercepted panic %v", err)) } + if outs.msgq != nil { + var derr error + switch e := err.(type) { + case error: + derr = e + default: + derr = fmt.Errorf("Unhandled session error %v", e) + } + _ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgError{Error: derr}) + + } ch <- err } close(ch) diff --git a/types.go b/types.go index e0af243c..24cc4077 100644 --- a/types.go +++ b/types.go @@ -671,40 +671,46 @@ func readVariantType(ti *typeInfo, r *tdsBuffer, c *cryptoMetadata) interface{} // partially length prefixed stream // http://msdn.microsoft.com/en-us/library/dd340469.aspx func readPLPType(ti *typeInfo, r *tdsBuffer, c *cryptoMetadata) interface{} { - size := r.uint64() - var buf *bytes.Buffer - switch size { - case _PLP_NULL: - // null - return nil - case _UNKNOWN_PLP_LEN: - // size unknown - buf = bytes.NewBuffer(make([]byte, 0, 1000)) - default: - buf = bytes.NewBuffer(make([]byte, 0, size)) - } - for { - chunksize := r.uint32() - if chunksize == 0 { - break + var bytesToDecode []byte + if c == nil { + size := r.uint64() + var buf *bytes.Buffer + switch size { + case _PLP_NULL: + // null + return nil + case _UNKNOWN_PLP_LEN: + // size unknown + buf = bytes.NewBuffer(make([]byte, 0, 1000)) + default: + buf = bytes.NewBuffer(make([]byte, 0, size)) } - if _, err := io.CopyN(buf, r, int64(chunksize)); err != nil { - badStreamPanicf("Reading PLP type failed: %s", err.Error()) + for { + chunksize := r.uint32() + if chunksize == 0 { + break + } + if _, err := io.CopyN(buf, r, int64(chunksize)); err != nil { + badStreamPanicf("Reading PLP type failed: %s", err.Error()) + } } + bytesToDecode = buf.Bytes() + } else { + bytesToDecode = r.rbuf } switch ti.TypeId { case typeXml: - return decodeXml(*ti, buf.Bytes()) + return decodeXml(*ti, bytesToDecode) case typeBigVarChar, typeBigChar, typeText: - return decodeChar(ti.Collation, buf.Bytes()) + return decodeChar(ti.Collation, bytesToDecode) case typeBigVarBin, typeBigBinary, typeImage: - return buf.Bytes() + return bytesToDecode case typeNVarChar, typeNChar, typeNText: - return decodeNChar(buf.Bytes()) + return decodeNChar(bytesToDecode) case typeUdt: - return decodeUdt(*ti, buf.Bytes()) + return decodeUdt(*ti, bytesToDecode) } - panic("shoulnd't get here") + panic("shouldn't get here") } func writePLPType(w io.Writer, ti typeInfo, buf []byte) (err error) { From 039dc2848acc901f31ac4ab4738d489a2c4bb22d Mon Sep 17 00:00:00 2001 From: davidshi Date: Thu, 13 Jul 2023 22:33:48 -0500 Subject: [PATCH 26/47] fix encryption of more types --- README.md | 5 ++ alwaysencrypted_windows_test.go | 140 +++++++++++++++++++++++++++++--- encrypt.go | 6 +- mssql.go | 10 ++- mssql_go19.go | 8 +- 5 files changed, 153 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index 75ea0a20..5493a128 100644 --- a/README.md +++ b/README.md @@ -408,6 +408,11 @@ If the correct key provider is included in your application, decryption of encry Encryption of parameters passed to `Exec` and `Query` variants requires an extra round trip per query to fetch the encryption metadata. If the error returned by a query attempt indicates a type mismatch between the parameter and the destination table, most likely your input type is not a strict match for the SQL Server data type of the destination. You may be using a Go `string` when you need to use one of the driver-specific aliases like `VarChar` or `NVarCharMax`. +*** NOTE *** - Currently `char` and `varchar` types do not include a collation parameter component so can't be used for inserting encrypted values. Also, using a nullable sql package type like `sql.NullableInt32` to pass a `NULL` value for an encrypted column will not work unless the encrypted column type is `nvarchar`. +https://github.com/microsoft/go-mssqldb/issues/129 +https://github.com/microsoft/go-mssqldb/issues/130 + + ### Local certificate AE key provider Key provider configuration is managed separately without any properties in the connection string. diff --git a/alwaysencrypted_windows_test.go b/alwaysencrypted_windows_test.go index 0c6e0a58..69462f7e 100644 --- a/alwaysencrypted_windows_test.go +++ b/alwaysencrypted_windows_test.go @@ -2,14 +2,47 @@ package mssql import ( "crypto/rand" + "database/sql" "fmt" "math/big" + "strings" "testing" + "time" + "github.com/golang-sql/civil" "github.com/microsoft/go-mssqldb/aecmk/localcert" "github.com/microsoft/go-mssqldb/internal/certs" ) +// Define phrases for create table for each enryptable data type along with sample data for insertion and validation +type aeColumnInfo struct { + queryPhrase string + sqlDataType string + encType ColumnEncryptionType + sampleValue interface{} +} + +var encryptableColumns = []aeColumnInfo{ + {"int", "INT", ColumnEncryptionDeterministic, int32(1)}, + {"nchar(10) COLLATE Latin1_General_BIN2", "NCHAR", ColumnEncryptionDeterministic, NChar("ncharval")}, + {"tinyint", "TINYINT", ColumnEncryptionRandomized, byte(2)}, + {"smallint", "SMALLINT", ColumnEncryptionDeterministic, int16(-3)}, + {"bigint", "BIGINT", ColumnEncryptionRandomized, int64(4)}, + // We can't use fractional float/real values due to rounding errors in the round trip + {"real", "REAL", ColumnEncryptionDeterministic, float32(5)}, + {"float", "FLOAT", ColumnEncryptionRandomized, float64(6)}, + {"varbinary(10)", "VARBINARY", ColumnEncryptionDeterministic, []byte{1, 2, 3, 4}}, + // TODO: Varchar support requires proper selection of a collation and conversion + // {"varchar(10) COLLATE Latin1_General_BIN2", "VARCHAR", ColumnEncryptionRandomized, VarChar("varcharval")}, + {"nvarchar(30)", "NVARCHAR", ColumnEncryptionRandomized, "nvarcharval"}, + {"bit", "BIT", ColumnEncryptionDeterministic, true}, + {"datetimeoffset(7)", "DATETIMEOFFSET", ColumnEncryptionRandomized, time.Now()}, + {"datetime2(7)", "DATETIME2", ColumnEncryptionDeterministic, civil.DateTimeOf(time.Now())}, + {"nvarchar(max)", "NVARCHAR", ColumnEncryptionRandomized, NVarCharMax("nvarcharmaxval")}, + // TODO: The driver throws away type information about Valuer implementations and sends nil as nvarchar(1). Fix that. + // {"int", "INT", ColumnEncryptionDeterministic, sql.NullInt32{Valid: false}}, +} + func TestAlwaysEncryptedE2E(t *testing.T) { params := testConnParams(t) if !params.ColumnEncryption { @@ -41,16 +74,47 @@ func TestAlwaysEncryptedE2E(t *testing.T) { } defer conn.Exec(fmt.Sprintf(dropColumnEncryptionKey, cekName)) _, _ = conn.Exec("DROP TABLE IF EXISTS " + tableName) - _, err = conn.Exec(fmt.Sprintf(createEncryptedTable, tableName, cekName, cekName)) + query := new(strings.Builder) + insert := new(strings.Builder) + sel := new(strings.Builder) + _, _ = query.WriteString(fmt.Sprintf("CREATE TABLE [%s] (", tableName)) + _, _ = insert.WriteString(fmt.Sprintf("INSERT INTO [%s] VALUES (", tableName)) + _, _ = sel.WriteString("select top(1) ") + insertArgs := make([]interface{}, len(encryptableColumns)+1) + for i, ec := range encryptableColumns { + encType := "RANDOMIZED" + null := "" + _, ok := ec.sampleValue.(sql.NullInt32) + if ok { + null = "NULL" + } + if ec.encType == ColumnEncryptionDeterministic { + encType = "DETERMINISTIC" + } + _, _ = query.WriteString(fmt.Sprintf(`col%d %s ENCRYPTED WITH (ENCRYPTION_TYPE = %s, + ALGORITHM = 'AEAD_AES_256_CBC_HMAC_SHA_256', + COLUMN_ENCRYPTION_KEY = [%s]) %s, + `, i, ec.queryPhrase, encType, cekName, null)) + + insertArgs[i] = ec.sampleValue + insert.WriteString(fmt.Sprintf("@p%d,", i+1)) + sel.WriteString(fmt.Sprintf("col%d,", i)) + } + _, _ = query.WriteString("unencryptedcolumn nvarchar(100)") + _, _ = query.WriteString(")") + insertArgs[len(encryptableColumns)] = "unencryptedvalue" + insert.WriteString(fmt.Sprintf("@p%d)", len(encryptableColumns)+1)) + sel.WriteString(fmt.Sprintf("unencryptedcolumn from [%s]", tableName)) + _, err = conn.Exec(query.String()) if err != nil { t.Fatalf("Failed to create encrypted table %s", err.Error()) } defer conn.Exec("DROP TABLE IF EXISTS " + tableName) - _, err = conn.Exec("INSERT INTO "+tableName+" VALUES (@p1, @p2)", int32(1), NChar("mycol2")) + _, err = conn.Exec(insert.String(), insertArgs...) if err != nil { - t.Logf("Failed to insert row in encrypted table %s", err.Error()) + t.Fatalf("Failed to insert row in encrypted table %s", err.Error()) } - rows, err := conn.Query("select top (1) col1, col2 from " + tableName) + rows, err := conn.Query(sel.String()) if err != nil { t.Fatalf("Unable to query encrypted columns: %v", err.(Error).All) } @@ -62,19 +126,48 @@ func TestAlwaysEncryptedE2E(t *testing.T) { if err != nil { t.Fatalf("rows.ColumnTypes failed %s", err.Error()) } - if cols[0].DatabaseTypeName() != "INT" { - t.Fatalf("Got wrong type name for intcol %s", cols[0].DatabaseTypeName()) + for i := range encryptableColumns { + + if cols[i].DatabaseTypeName() != encryptableColumns[i].sqlDataType { + t.Fatalf("Got wrong type name for col%d. Expected: %s, Got:%s", i, encryptableColumns[i].sqlDataType, cols[i].DatabaseTypeName()) + } + } + + var unencryptedColumnValue string + scanValues := make([]interface{}, len(encryptableColumns)+1) + for v := range scanValues { + if v < len(encryptableColumns) { + scanValues[v] = new(interface{}) + } } - var col1 int32 - var col2 string - err = rows.Scan(&col1, &col2) + scanValues[len(encryptableColumns)] = &unencryptedColumnValue + err = rows.Scan(scanValues...) if err != nil { rows.Close() t.Fatalf("rows.Scan failed: %s", err.Error()) } - if col1 != 1 || col2 != "mycol2" { - rows.Close() - t.Fatalf("Got incorrect scan values %d and %s", col1, col2) + + for i := range encryptableColumns { + var strVal string + var expectedStrVal string + if encryptableColumns[i].sampleValue == nil { + expectedStrVal = "NULL" + } else { + expectedStrVal = comparisonValueFromObject(encryptableColumns[i].sampleValue) + } + rawVal := scanValues[i].(*interface{}) + + if rawVal == nil { + strVal = "NULL" + } else { + strVal = comparisonValueFromObject(*rawVal) + } + if expectedStrVal != strVal { + t.Fatalf("Incorrect value for col%d. Expected:%s, Got:%s", i, expectedStrVal, strVal) + } + } + if unencryptedColumnValue != "unencryptedvalue" { + t.Fatalf("Got wrong value for unencrypted column: %s", unencryptedColumnValue) } rows.Close() err = rows.Err() @@ -83,6 +176,29 @@ func TestAlwaysEncryptedE2E(t *testing.T) { } } +func comparisonValueFromObject(object interface{}) string { + switch v := object.(type) { + case []byte: + { + return string(v) + } + case string: + return v + case time.Time: + return civil.DateTimeOf(v).String() + //return v.Format(time.RFC3339) + case fmt.Stringer: + return v.String() + case bool: + if v == true { + return "1" + } + return "0" + default: + return fmt.Sprintf("%v", v) + } +} + const ( createColumnMasterKey = `CREATE COLUMN MASTER KEY [%s] WITH (KEY_STORE_PROVIDER_NAME= 'MSSQL_CERTIFICATE_STORE', KEY_PATH='%s')` dropColumnMasterKey = `DROP COLUMN MASTER KEY [%s]` diff --git a/encrypt.go b/encrypt.go index 86add1e8..91dca378 100644 --- a/encrypt.go +++ b/encrypt.go @@ -86,7 +86,11 @@ func (s *Stmt) encryptArgs(ctx context.Context, args []namedValue) (encryptedArg } paramMap := make(map[string]paramMapEntry) for _, p := range paramsInfo { - paramMap[p.name] = paramMapEntry{cekInfo[p.cekOrdinal-1], p} + if p.encType == ColumnEncryptionPlainText { + paramMap[p.name] = paramMapEntry{nil, p} + } else { + paramMap[p.name] = paramMapEntry{cekInfo[p.cekOrdinal-1], p} + } } encryptedArgs = make([]namedValue, len(args)) for i, a := range args { diff --git a/mssql.go b/mssql.go index 34d02653..3af20bc3 100644 --- a/mssql.go +++ b/mssql.go @@ -1012,12 +1012,20 @@ func (s *Stmt) makeParam(val driver.Value) (res param, err error) { res.ti.TypeId = typeIntN res.ti.Size = 8 res.buffer = []byte{} - + case byte: + res.ti.TypeId = typeIntN + res.buffer = []byte{val} + res.ti.Size = 1 case float64: res.ti.TypeId = typeFltN res.ti.Size = 8 res.buffer = make([]byte, 8) binary.LittleEndian.PutUint64(res.buffer, math.Float64bits(val)) + case float32: + res.ti.TypeId = typeFltN + res.ti.Size = 4 + res.buffer = make([]byte, 4) + binary.LittleEndian.PutUint32(res.buffer, math.Float32bits(val)) case sql.NullFloat64: // only null values should be getting here res.ti.TypeId = typeFltN diff --git a/mssql_go19.go b/mssql_go19.go index 6359c39a..b0285eef 100644 --- a/mssql_go19.go +++ b/mssql_go19.go @@ -51,6 +51,8 @@ func convertInputParameter(val interface{}) (interface{}, error) { switch v := val.(type) { case int, int16, int32, int64, int8: return val, nil + case byte: + return val, nil case VarChar: return val, nil case NVarCharMax: @@ -69,8 +71,10 @@ func convertInputParameter(val interface{}) (interface{}, error) { return val, nil case civil.Time: return val, nil - // case *apd.Decimal: - // return nil + // case *apd.Decimal: + // return nil + case float32: + return val, nil default: return driver.DefaultParameterConverter.ConvertValue(v) } From 83edee4efb731d0313c549ce3d573cdb2176c40c Mon Sep 17 00:00:00 2001 From: davidshi Date: Mon, 17 Jul 2023 11:04:26 -0500 Subject: [PATCH 27/47] try to fix appveyor build --- appveyor.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/appveyor.yml b/appveyor.yml index 4b90a402..a43b8b4f 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -71,6 +71,7 @@ install: - go get -u github.com/golang-sql/sqlexp - go get -u golang.org/x/crypto/md4 - go get -u golang.org/x/text/encoding/unicode + - go get -u golang.org/x/sys/windows build_script: - go build From 0ae9b2dab51d170e14225899b1e4b5a1c670672d Mon Sep 17 00:00:00 2001 From: davidshi Date: Mon, 17 Jul 2023 11:09:55 -0500 Subject: [PATCH 28/47] mute test --- .../swisscom/mssql-always-encrypted/pkg/alwaysencrypted_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/github.com/swisscom/mssql-always-encrypted/pkg/alwaysencrypted_test.go b/internal/github.com/swisscom/mssql-always-encrypted/pkg/alwaysencrypted_test.go index 33e86125..093ec64e 100644 --- a/internal/github.com/swisscom/mssql-always-encrypted/pkg/alwaysencrypted_test.go +++ b/internal/github.com/swisscom/mssql-always-encrypted/pkg/alwaysencrypted_test.go @@ -97,7 +97,7 @@ func TestDecrypt(t *testing.T) { if err != nil { t.Fatal(err) } - fmt.Printf("column value: \"%02X\"", cleartextUtf8) + t.Logf("column value: \"%02X\"", cleartextUtf8) assert.Equal(t, "12345 ", string(cleartextUtf8)) } func TestDecryptCEK(t *testing.T) { From 689434d57ce04ee4b71ed7b0fe6221a603196a42 Mon Sep 17 00:00:00 2001 From: davidshi Date: Mon, 17 Jul 2023 14:32:55 -0500 Subject: [PATCH 29/47] make cert store provider go1.17+ --- aecmk/localcert/keyprovider_windows.go | 46 ++-------------- aecmk/localcert/keyprovider_windows_go117.go | 53 +++++++++++++++++++ ...t.go => keyprovider_windows_go117_test.go} | 0 3 files changed, 57 insertions(+), 42 deletions(-) create mode 100644 aecmk/localcert/keyprovider_windows_go117.go rename aecmk/localcert/{keyprovider_windows_test.go => keyprovider_windows_go117_test.go} (100%) diff --git a/aecmk/localcert/keyprovider_windows.go b/aecmk/localcert/keyprovider_windows.go index cf03a397..144cb5c6 100644 --- a/aecmk/localcert/keyprovider_windows.go +++ b/aecmk/localcert/keyprovider_windows.go @@ -1,51 +1,13 @@ +//go:build !go1.17 + package localcert import ( "crypto/x509" "fmt" - "strings" - "unsafe" - - "github.com/microsoft/go-mssqldb/aecmk" - "github.com/microsoft/go-mssqldb/internal/certs" - "golang.org/x/sys/windows" ) -var WindowsCertificateStoreKeyProvider = LocalCertProvider{name: aecmk.CertificateStoreKeyProvider, passwords: make(map[string]string)} - -func init() { - aecmk.RegisterCekProvider(aecmk.CertificateStoreKeyProvider, &WindowsCertificateStoreKeyProvider) -} - func (p *LocalCertProvider) loadWindowsCertStoreCertificate(path string) (privateKey interface{}, cert *x509.Certificate) { - privateKey = nil - cert = nil - pathParts := strings.Split(path, `/`) - if len(pathParts) != 3 { - panic(invalidCertificatePath(path, fmt.Errorf("key store path requires 3 segments"))) - } - - var storeId uint32 - switch strings.ToLower(pathParts[0]) { - case "localmachine": - storeId = windows.CERT_SYSTEM_STORE_LOCAL_MACHINE - case "currentuser": - storeId = windows.CERT_SYSTEM_STORE_CURRENT_USER - default: - panic(invalidCertificatePath(path, fmt.Errorf("Unknown certificate store"))) - } - system, err := windows.UTF16PtrFromString(pathParts[1]) - if err != nil { - panic(err) - } - h, err := windows.CertOpenStore(windows.CERT_STORE_PROV_SYSTEM, - windows.PKCS_7_ASN_ENCODING|windows.X509_ASN_ENCODING, - 0, - storeId, uintptr(unsafe.Pointer(system))) - if err != nil { - panic(err) - } - defer windows.CertCloseStore(h, 0) - signature := thumbprintToByteArray(pathParts[2]) - return certs.FindCertBySignatureHash(h, signature) + panic(fmt.Errorf("Windows cert store not supported until Go 1.17")) + return } diff --git a/aecmk/localcert/keyprovider_windows_go117.go b/aecmk/localcert/keyprovider_windows_go117.go new file mode 100644 index 00000000..7a67073b --- /dev/null +++ b/aecmk/localcert/keyprovider_windows_go117.go @@ -0,0 +1,53 @@ +//go:build go1.17 + +package localcert + +import ( + "crypto/x509" + "fmt" + "strings" + "unsafe" + + "github.com/microsoft/go-mssqldb/aecmk" + "github.com/microsoft/go-mssqldb/internal/certs" + "golang.org/x/sys/windows" +) + +var WindowsCertificateStoreKeyProvider = LocalCertProvider{name: aecmk.CertificateStoreKeyProvider, passwords: make(map[string]string)} + +func init() { + aecmk.RegisterCekProvider(aecmk.CertificateStoreKeyProvider, &WindowsCertificateStoreKeyProvider) +} + +func (p *LocalCertProvider) loadWindowsCertStoreCertificate(path string) (privateKey interface{}, cert *x509.Certificate) { + privateKey = nil + cert = nil + pathParts := strings.Split(path, `/`) + if len(pathParts) != 3 { + panic(invalidCertificatePath(path, fmt.Errorf("key store path requires 3 segments"))) + } + + var storeId uint32 + switch strings.ToLower(pathParts[0]) { + case "localmachine": + storeId = windows.CERT_SYSTEM_STORE_LOCAL_MACHINE + case "currentuser": + storeId = windows.CERT_SYSTEM_STORE_CURRENT_USER + default: + panic(invalidCertificatePath(path, fmt.Errorf("Unknown certificate store"))) + } + system, err := windows.UTF16PtrFromString(pathParts[1]) + if err != nil { + panic(err) + } + h, err := windows.CertOpenStore(windows.CERT_STORE_PROV_SYSTEM, + windows.PKCS_7_ASN_ENCODING|windows.X509_ASN_ENCODING, + 0, + storeId, uintptr(unsafe.Pointer(system))) + if err != nil { + panic(err) + } + defer windows.CertCloseStore(h, 0) + signature := thumbprintToByteArray(pathParts[2]) + return certs.FindCertBySignatureHash(h, signature) +} diff --git a/aecmk/localcert/keyprovider_windows_test.go b/aecmk/localcert/keyprovider_windows_go117_test.go similarity index 100% rename from aecmk/localcert/keyprovider_windows_test.go rename to aecmk/localcert/keyprovider_windows_go117_test.go From 396a5dde3b495a907e2860ba0032f56abc583fef Mon Sep 17 00:00:00 2001 From: davidshi Date: Mon, 17 Jul 2023 16:33:13 -0500 Subject: [PATCH 30/47] fix build directives --- ...yprovider_windows_go117.go => keyprovider_go117_windows.go} | 3 ++- ...windows_go117_test.go => keyprovider_go117_windows_test.go} | 3 +++ aecmk/localcert/keyprovider_windows.go | 1 + internal/certs/certs_windows.go | 3 +++ 4 files changed, 9 insertions(+), 1 deletion(-) rename aecmk/localcert/{keyprovider_windows_go117.go => keyprovider_go117_windows.go} (97%) rename aecmk/localcert/{keyprovider_windows_go117_test.go => keyprovider_go117_windows_test.go} (98%) diff --git a/aecmk/localcert/keyprovider_windows_go117.go b/aecmk/localcert/keyprovider_go117_windows.go similarity index 97% rename from aecmk/localcert/keyprovider_windows_go117.go rename to aecmk/localcert/keyprovider_go117_windows.go index 7a67073b..ebd63a40 100644 --- a/aecmk/localcert/keyprovider_windows_go117.go +++ b/aecmk/localcert/keyprovider_go117_windows.go @@ -1,4 +1,5 @@ -//go:build go1.17 +//go:build 1.17 +// +build 1.17 package localcert diff --git a/aecmk/localcert/keyprovider_windows_go117_test.go b/aecmk/localcert/keyprovider_go117_windows_test.go similarity index 98% rename from aecmk/localcert/keyprovider_windows_go117_test.go rename to aecmk/localcert/keyprovider_go117_windows_test.go index 95932f03..834f6de0 100644 --- a/aecmk/localcert/keyprovider_windows_go117_test.go +++ b/aecmk/localcert/keyprovider_go117_windows_test.go @@ -1,3 +1,6 @@ +//go:build 1.17 +// +build 1.17 + package localcert import ( diff --git a/aecmk/localcert/keyprovider_windows.go b/aecmk/localcert/keyprovider_windows.go index 144cb5c6..1d6bdee9 100644 --- a/aecmk/localcert/keyprovider_windows.go +++ b/aecmk/localcert/keyprovider_windows.go @@ -1,4 +1,5 @@ //go:build !go1.17 +// +build !go1.17 package localcert diff --git a/internal/certs/certs_windows.go b/internal/certs/certs_windows.go index ea786174..5577dd1a 100644 --- a/internal/certs/certs_windows.go +++ b/internal/certs/certs_windows.go @@ -1,3 +1,6 @@ +//go:build go1.17 +// +build go1.17 + package certs import ( From b4ab997711ae3d96434539eb17f762cb25e3a9d7 Mon Sep 17 00:00:00 2001 From: davidshi Date: Mon, 17 Jul 2023 16:36:18 -0500 Subject: [PATCH 31/47] rename files for clarity --- aecmk/localcert/keyprovider_go117_windows.go | 54 ------------------- .../localcert/keyprovider_prego117_windows.go | 14 +++++ aecmk/localcert/keyprovider_windows.go | 48 +++++++++++++++-- 3 files changed, 58 insertions(+), 58 deletions(-) delete mode 100644 aecmk/localcert/keyprovider_go117_windows.go create mode 100644 aecmk/localcert/keyprovider_prego117_windows.go diff --git a/aecmk/localcert/keyprovider_go117_windows.go b/aecmk/localcert/keyprovider_go117_windows.go deleted file mode 100644 index ebd63a40..00000000 --- a/aecmk/localcert/keyprovider_go117_windows.go +++ /dev/null @@ -1,54 +0,0 @@ -//go:build 1.17 -// +build 1.17 - -package localcert - -import ( - "crypto/x509" - "fmt" - "strings" - "unsafe" - - "github.com/microsoft/go-mssqldb/aecmk" - "github.com/microsoft/go-mssqldb/internal/certs" - "golang.org/x/sys/windows" -) - -var WindowsCertificateStoreKeyProvider = LocalCertProvider{name: aecmk.CertificateStoreKeyProvider, passwords: make(map[string]string)} - -func init() { - aecmk.RegisterCekProvider(aecmk.CertificateStoreKeyProvider, &WindowsCertificateStoreKeyProvider) -} - -func (p *LocalCertProvider) loadWindowsCertStoreCertificate(path string) (privateKey interface{}, cert *x509.Certificate) { - privateKey = nil - cert = nil - pathParts := strings.Split(path, `/`) - if len(pathParts) != 3 { - panic(invalidCertificatePath(path, fmt.Errorf("key store path requires 3 segments"))) - } - - var storeId uint32 - switch strings.ToLower(pathParts[0]) { - case "localmachine": - storeId = windows.CERT_SYSTEM_STORE_LOCAL_MACHINE - case "currentuser": - storeId = windows.CERT_SYSTEM_STORE_CURRENT_USER - default: - panic(invalidCertificatePath(path, fmt.Errorf("Unknown certificate store"))) - } - system, err := windows.UTF16PtrFromString(pathParts[1]) - if err != nil { - panic(err) - } - h, err := windows.CertOpenStore(windows.CERT_STORE_PROV_SYSTEM, - windows.PKCS_7_ASN_ENCODING|windows.X509_ASN_ENCODING, - 0, - storeId, uintptr(unsafe.Pointer(system))) - if err != nil { - panic(err) - } - defer windows.CertCloseStore(h, 0) - signature := thumbprintToByteArray(pathParts[2]) - return certs.FindCertBySignatureHash(h, signature) -} diff --git a/aecmk/localcert/keyprovider_prego117_windows.go b/aecmk/localcert/keyprovider_prego117_windows.go new file mode 100644 index 00000000..1d6bdee9 --- /dev/null +++ b/aecmk/localcert/keyprovider_prego117_windows.go @@ -0,0 +1,14 @@ +//go:build !go1.17 +// +build !go1.17 + +package localcert + +import ( + "crypto/x509" + "fmt" +) + +func (p *LocalCertProvider) loadWindowsCertStoreCertificate(path string) (privateKey interface{}, cert *x509.Certificate) { + panic(fmt.Errorf("Windows cert store not supported until Go 1.17")) + return +} diff --git a/aecmk/localcert/keyprovider_windows.go b/aecmk/localcert/keyprovider_windows.go index 1d6bdee9..ebd63a40 100644 --- a/aecmk/localcert/keyprovider_windows.go +++ b/aecmk/localcert/keyprovider_windows.go @@ -1,14 +1,54 @@ -//go:build !go1.17 -// +build !go1.17 +//go:build 1.17 +// +build 1.17 package localcert import ( "crypto/x509" "fmt" + "strings" + "unsafe" + + "github.com/microsoft/go-mssqldb/aecmk" + "github.com/microsoft/go-mssqldb/internal/certs" + "golang.org/x/sys/windows" ) +var WindowsCertificateStoreKeyProvider = LocalCertProvider{name: aecmk.CertificateStoreKeyProvider, passwords: make(map[string]string)} + +func init() { + aecmk.RegisterCekProvider(aecmk.CertificateStoreKeyProvider, &WindowsCertificateStoreKeyProvider) +} + func (p *LocalCertProvider) loadWindowsCertStoreCertificate(path string) (privateKey interface{}, cert *x509.Certificate) { - panic(fmt.Errorf("Windows cert store not supported until Go 1.17")) - return + privateKey = nil + cert = nil + pathParts := strings.Split(path, `/`) + if len(pathParts) != 3 { + panic(invalidCertificatePath(path, fmt.Errorf("key store path requires 3 segments"))) + } + + var storeId uint32 + switch strings.ToLower(pathParts[0]) { + case "localmachine": + storeId = windows.CERT_SYSTEM_STORE_LOCAL_MACHINE + case "currentuser": + storeId = windows.CERT_SYSTEM_STORE_CURRENT_USER + default: + panic(invalidCertificatePath(path, fmt.Errorf("Unknown certificate store"))) + } + system, err := windows.UTF16PtrFromString(pathParts[1]) + if err != nil { + panic(err) + } + h, err := windows.CertOpenStore(windows.CERT_STORE_PROV_SYSTEM, + windows.PKCS_7_ASN_ENCODING|windows.X509_ASN_ENCODING, + 0, + storeId, uintptr(unsafe.Pointer(system))) + if err != nil { + panic(err) + } + defer windows.CertCloseStore(h, 0) + signature := thumbprintToByteArray(pathParts[2]) + return certs.FindCertBySignatureHash(h, signature) } From 25a5ebf71a5c27e546d8c8411b8dd5667f50fae6 Mon Sep 17 00:00:00 2001 From: davidshi Date: Mon, 17 Jul 2023 16:58:46 -0500 Subject: [PATCH 32/47] fix typo --- aecmk/localcert/keyprovider_windows.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/aecmk/localcert/keyprovider_windows.go b/aecmk/localcert/keyprovider_windows.go index ebd63a40..cce35f29 100644 --- a/aecmk/localcert/keyprovider_windows.go +++ b/aecmk/localcert/keyprovider_windows.go @@ -1,5 +1,5 @@ -//go:build 1.17 -// +build 1.17 +//go:build go1.17 +// +build go1.17 package localcert From 643b7f1b61e010b4ab7ffcc12b6114211622600e Mon Sep 17 00:00:00 2001 From: davidshi Date: Mon, 17 Jul 2023 17:12:42 -0500 Subject: [PATCH 33/47] fix test file directive --- aecmk/localcert/keyprovider_go117_windows_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/aecmk/localcert/keyprovider_go117_windows_test.go b/aecmk/localcert/keyprovider_go117_windows_test.go index 834f6de0..28d2d2fa 100644 --- a/aecmk/localcert/keyprovider_go117_windows_test.go +++ b/aecmk/localcert/keyprovider_go117_windows_test.go @@ -1,5 +1,5 @@ -//go:build 1.17 -// +build 1.17 +//go:build go1.17 +// +build go1.17 package localcert From 2c946e455b8508a4cb2d8232a4291e73f6d85dad Mon Sep 17 00:00:00 2001 From: davidshi Date: Mon, 17 Jul 2023 17:21:05 -0500 Subject: [PATCH 34/47] skip windows package get in pipeline --- appveyor.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/appveyor.yml b/appveyor.yml index a43b8b4f..0347ab96 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -71,7 +71,7 @@ install: - go get -u github.com/golang-sql/sqlexp - go get -u golang.org/x/crypto/md4 - go get -u golang.org/x/text/encoding/unicode - - go get -u golang.org/x/sys/windows + build_script: - go build From 522314e8aa5e5f9ee109686617a5e2b23088eb88 Mon Sep 17 00:00:00 2001 From: davidshi Date: Tue, 18 Jul 2023 09:23:54 -0500 Subject: [PATCH 35/47] fix build --- appveyor.yml | 1 + encrypt_test.go | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/appveyor.yml b/appveyor.yml index 0347ab96..87993bd6 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -71,6 +71,7 @@ install: - go get -u github.com/golang-sql/sqlexp - go get -u golang.org/x/crypto/md4 - go get -u golang.org/x/text/encoding/unicode + - go get -u software.sslmate.com/src/go-pkcs12 build_script: - go build diff --git a/encrypt_test.go b/encrypt_test.go index 7cd578b3..6e54a0a9 100644 --- a/encrypt_test.go +++ b/encrypt_test.go @@ -27,7 +27,7 @@ func TestBuildQueryParametersForCE(t *testing.T) { "Input and Output params", []namedValue{ {Name: "", Ordinal: 0, Value: VarChar("somestring")}, - {Name: "c1", Value: 5}, + {Name: "c1", Value: int64(5)}, {Name: "pout", Value: sql.Out{Dest: outparam}}, }, `@p0 varchar(10), @c1 bigint, @pout nvarchar(max) output`, From 5c0dfc9144737598cdaee7569e4a714e48ee2286 Mon Sep 17 00:00:00 2001 From: davidshi Date: Tue, 18 Jul 2023 17:58:35 -0500 Subject: [PATCH 36/47] fix build breaks --- aecmk/localcert/keyprovider.go | 3 +++ aecmk/localcert/keyprovider_darwin.go | 3 +++ aecmk/localcert/keyprovider_linux.go | 3 +++ aecmk/localcert/keyprovider_prego117_windows.go | 14 -------------- aecmk/localcert/keyprovider_test.go | 3 +++ alwaysencrypted_windows_test.go | 3 +++ 6 files changed, 15 insertions(+), 14 deletions(-) delete mode 100644 aecmk/localcert/keyprovider_prego117_windows.go diff --git a/aecmk/localcert/keyprovider.go b/aecmk/localcert/keyprovider.go index d2f8f201..d189a8ed 100644 --- a/aecmk/localcert/keyprovider.go +++ b/aecmk/localcert/keyprovider.go @@ -1,3 +1,6 @@ +//go:build go1.17 +// +build go1.17 + package localcert import ( diff --git a/aecmk/localcert/keyprovider_darwin.go b/aecmk/localcert/keyprovider_darwin.go index a3a0e7d6..03943489 100644 --- a/aecmk/localcert/keyprovider_darwin.go +++ b/aecmk/localcert/keyprovider_darwin.go @@ -1,3 +1,6 @@ +//go:build go1.17 +// +build go1.17 + package localcert import ( diff --git a/aecmk/localcert/keyprovider_linux.go b/aecmk/localcert/keyprovider_linux.go index a3a0e7d6..03943489 100644 --- a/aecmk/localcert/keyprovider_linux.go +++ b/aecmk/localcert/keyprovider_linux.go @@ -1,3 +1,6 @@ +//go:build go1.17 +// +build go1.17 + package localcert import ( diff --git a/aecmk/localcert/keyprovider_prego117_windows.go b/aecmk/localcert/keyprovider_prego117_windows.go deleted file mode 100644 index 1d6bdee9..00000000 --- a/aecmk/localcert/keyprovider_prego117_windows.go +++ /dev/null @@ -1,14 +0,0 @@ -//go:build !go1.17 -// +build !go1.17 - -package localcert - -import ( - "crypto/x509" - "fmt" -) - -func (p *LocalCertProvider) loadWindowsCertStoreCertificate(path string) (privateKey interface{}, cert *x509.Certificate) { - panic(fmt.Errorf("Windows cert store not supported until Go 1.17")) - return -} diff --git a/aecmk/localcert/keyprovider_test.go b/aecmk/localcert/keyprovider_test.go index c02354af..8b4237cf 100644 --- a/aecmk/localcert/keyprovider_test.go +++ b/aecmk/localcert/keyprovider_test.go @@ -1,3 +1,6 @@ +//go:build go1.17 +// +build go1.17 + package localcert import ( diff --git a/alwaysencrypted_windows_test.go b/alwaysencrypted_windows_test.go index 69462f7e..20f95c73 100644 --- a/alwaysencrypted_windows_test.go +++ b/alwaysencrypted_windows_test.go @@ -1,3 +1,6 @@ +//go:build go1.17 +// +build go1.17 + package mssql import ( From 44af0a549385694f690c2b2a5b3b1b5e9578fdd9 Mon Sep 17 00:00:00 2001 From: davidshi Date: Wed, 19 Jul 2023 10:27:00 -0500 Subject: [PATCH 37/47] update dependencies and min Go version --- .github/workflows/pr-validation.yml | 4 +- .pipelines/TestSql2017.yml | 2 +- aecmk/localcert/keyprovider.go | 2 +- appveyor.yml | 40 +++++-------------- go.mod | 3 +- go.sum | 6 --- .../aead_aes_256_cbc_hmac_sha256.go | 6 +-- tds_test.go | 3 ++ 8 files changed, 22 insertions(+), 44 deletions(-) diff --git a/.github/workflows/pr-validation.yml b/.github/workflows/pr-validation.yml index 8c3a92cb..7963a660 100644 --- a/.github/workflows/pr-validation.yml +++ b/.github/workflows/pr-validation.yml @@ -10,8 +10,8 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - go: ['1.16','1.17', '1.18'] - sqlImage: ['2017-latest','2019-latest'] + go: ['1.19','1.20'] + sqlImage: ['2019-latest','2022-latest'] steps: - uses: actions/checkout@v2 - name: Setup go diff --git a/.pipelines/TestSql2017.yml b/.pipelines/TestSql2017.yml index 046e3e98..9bdb5303 100644 --- a/.pipelines/TestSql2017.yml +++ b/.pipelines/TestSql2017.yml @@ -8,7 +8,7 @@ variables: steps: - task: GoTool@0 inputs: - version: '1.16.5' + version: '1.19' - task: Go@0 displayName: 'Go: get sources' inputs: diff --git a/aecmk/localcert/keyprovider.go b/aecmk/localcert/keyprovider.go index d189a8ed..ec18419c 100644 --- a/aecmk/localcert/keyprovider.go +++ b/aecmk/localcert/keyprovider.go @@ -19,8 +19,8 @@ import ( "github.com/microsoft/go-mssqldb/aecmk" ae "github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg" + pkcs "golang.org/x/crypto/pkcs12" "golang.org/x/text/encoding/unicode" - pkcs "software.sslmate.com/src/go-pkcs12" ) const ( diff --git a/appveyor.yml b/appveyor.yml index 87993bd6..0a96e1ad 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -11,52 +11,34 @@ environment: SQLUSER: sa SQLPASSWORD: Password12! DATABASE: test - GOVERSION: 113 + GOVERSION: 116 + COLUMNENCRYPTION: + # Go 1.14+ and SQL2019 are available on the Visual Studio 2019 image only + APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2019 RACE: -race -cpu 4 TAGS: matrix: - - GOVERSION: 110 - SQLINSTANCE: SQL2017 - - GOVERSION: 111 - SQLINSTANCE: SQL2017 - - GOVERSION: 112 - SQLINSTANCE: SQL2017 - SQLINSTANCE: SQL2017 - SQLINSTANCE: SQL2016 - SQLINSTANCE: SQL2014 - SQLINSTANCE: SQL2012SP1 - SQLINSTANCE: SQL2008R2SP2 - - # Go 1.14+ and SQL2019 are available on the Visual Studio 2019 image only - - APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2019 - GOVERSION: 114 - SQLINSTANCE: SQL2019 - - APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2019 - GOVERSION: 115 - SQLINSTANCE: SQL2019 - - APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2019 - GOVERSION: 115 - SQLINSTANCE: SQL2017 - - APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2019 - GOVERSION: 116 + - GOVERSION: 117 SQLINSTANCE: SQL2017 - - APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2019 - GOVERSION: 117 - SQLINSTANCE: SQL2017 - - APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2019 - GOVERSION: 118 + - GOVERSION: 118 SQLINSTANCE: SQL2017 + - GOVERSION: 120 + SQLINSTANCE: SQL2019 + COLUMNENCRYPTION: 1 # Cover 32bit and named pipes protocol - - APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2019 - GOVERSION: 118-x86 + - GOVERSION: 119-x86 SQLINSTANCE: SQL2017 GOARCH: 386 RACE: PROTOCOL: np TAGS: -tags np # Cover SSPI and lpc protocol - - APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2019 - GOVERSION: 118 + - GOVERSION: 120 SQLINSTANCE: SQL2019 PROTOCOL: lpc TAGS: -tags sm diff --git a/go.mod b/go.mod index 1f42165a..4c3ea17a 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/microsoft/go-mssqldb -go 1.13 +go 1.16 require ( github.com/Azure/azure-sdk-for-go/sdk/azcore v1.6.0 @@ -12,5 +12,4 @@ require ( golang.org/x/crypto v0.9.0 golang.org/x/sys v0.8.0 golang.org/x/text v0.9.0 - software.sslmate.com/src/go-pkcs12 v0.2.0 ) diff --git a/go.sum b/go.sum index c99a24c5..69dce41d 100644 --- a/go.sum +++ b/go.sum @@ -60,7 +60,6 @@ github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.0.0-20220331220935-ae2d96664a29/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58= golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU= golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g= @@ -70,7 +69,6 @@ golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= @@ -82,7 +80,6 @@ golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210616045830-e2b7044e8c71/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -98,7 +95,6 @@ golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= @@ -119,5 +115,3 @@ gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -software.sslmate.com/src/go-pkcs12 v0.2.0 h1:nlFkj7bTysH6VkC4fGphtjXRbezREPgrHuJG20hBGPE= -software.sslmate.com/src/go-pkcs12 v0.2.0/go.mod h1:23rNcYsMabIc1otwLpTkCCPwUq6kQsTyowttG/as0kQ= diff --git a/internal/github.com/swisscom/mssql-always-encrypted/pkg/algorithms/aead_aes_256_cbc_hmac_sha256.go b/internal/github.com/swisscom/mssql-always-encrypted/pkg/algorithms/aead_aes_256_cbc_hmac_sha256.go index cb4def7c..d4267ab3 100644 --- a/internal/github.com/swisscom/mssql-always-encrypted/pkg/algorithms/aead_aes_256_cbc_hmac_sha256.go +++ b/internal/github.com/swisscom/mssql-always-encrypted/pkg/algorithms/aead_aes_256_cbc_hmac_sha256.go @@ -1,8 +1,8 @@ package algorithms import ( - "bytes" "crypto/rand" + "crypto/subtle" "fmt" "github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/crypto" @@ -98,12 +98,12 @@ func (a *AeadAes256CbcHmac256Algorithm) Decrypt(ciphertext []byte) ([]byte, erro realCiphertext := ciphertext[idx:] ourAuthTag := a.prepareAuthTag(iv, realCiphertext) - if bytes.Compare(ourAuthTag, authTag) != 0 { + // bytes.Compare is subject to timing attacks + if subtle.ConstantTimeCompare(ourAuthTag, authTag) != 1 { return nil, fmt.Errorf("invalid auth tag") } // decrypt - aescdbc := crypto.NewAESCbcPKCS5(a.cek.EncryptionKey(), iv) cleartext := aescdbc.Decrypt(realCiphertext) diff --git a/tds_test.go b/tds_test.go index 5b540e8f..d6103475 100644 --- a/tds_test.go +++ b/tds_test.go @@ -326,6 +326,9 @@ func GetConnParams() (*msdsn.Config, error) { if os.Getenv("PIPE") != "" { c.Parameters["pipe"] = os.Getenv("PIPE") } + if os.Getenv("COLUMNENCRYPTION") != "" { + c.ColumnEncryption = true + } return c, nil } // try loading connection string from file From 3b17cd60da9d70670d5642e03ee8349ef9d34c2c Mon Sep 17 00:00:00 2001 From: davidshi Date: Wed, 19 Jul 2023 10:57:41 -0500 Subject: [PATCH 38/47] update appveyor --- appveyor.yml | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/appveyor.yml b/appveyor.yml index 0a96e1ad..5085db33 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -13,16 +13,20 @@ environment: DATABASE: test GOVERSION: 116 COLUMNENCRYPTION: - # Go 1.14+ and SQL2019 are available on the Visual Studio 2019 image only APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2019 RACE: -race -cpu 4 + CGO_ENABLED: 1 TAGS: matrix: - SQLINSTANCE: SQL2017 - - SQLINSTANCE: SQL2016 - - SQLINSTANCE: SQL2014 - - SQLINSTANCE: SQL2012SP1 - - SQLINSTANCE: SQL2008R2SP2 + - APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2017 + SQLINSTANCE: SQL2016 + - APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2017 + SQLINSTANCE: SQL2014 + - APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2015 + SQLINSTANCE: SQL2012SP1 + - APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2015 + SQLINSTANCE: SQL2008R2SP2 - GOVERSION: 117 SQLINSTANCE: SQL2017 - GOVERSION: 118 @@ -53,7 +57,6 @@ install: - go get -u github.com/golang-sql/sqlexp - go get -u golang.org/x/crypto/md4 - go get -u golang.org/x/text/encoding/unicode - - go get -u software.sslmate.com/src/go-pkcs12 build_script: - go build From 64a993f7963a4af6e1279952195bb8439b89e3ab Mon Sep 17 00:00:00 2001 From: davidshi Date: Wed, 19 Jul 2023 11:22:09 -0500 Subject: [PATCH 39/47] try older appveyor image --- appveyor.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/appveyor.yml b/appveyor.yml index 5085db33..03892e9e 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -19,13 +19,13 @@ environment: TAGS: matrix: - SQLINSTANCE: SQL2017 - - APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2017 + - APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2015 SQLINSTANCE: SQL2016 - - APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2017 - SQLINSTANCE: SQL2014 - APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2015 + SQLINSTANCE: SQL2014 + - APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2013 SQLINSTANCE: SQL2012SP1 - - APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2015 + - APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2013 SQLINSTANCE: SQL2008R2SP2 - GOVERSION: 117 SQLINSTANCE: SQL2017 From 4c42fbf4123f5e313a4f90ec2990a3778867cdef Mon Sep 17 00:00:00 2001 From: davidshi Date: Wed, 19 Jul 2023 11:45:17 -0500 Subject: [PATCH 40/47] no race on go 1.20 --- appveyor.yml | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/appveyor.yml b/appveyor.yml index 03892e9e..2c2bcc43 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -15,23 +15,23 @@ environment: COLUMNENCRYPTION: APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2019 RACE: -race -cpu 4 - CGO_ENABLED: 1 TAGS: matrix: - SQLINSTANCE: SQL2017 - - APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2015 + - APPVEYOR_BUILD_WORKER_IMAGE: SQLINSTANCE: SQL2016 - - APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2015 + - APPVEYOR_BUILD_WORKER_IMAGE: SQLINSTANCE: SQL2014 - - APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2013 + - APPVEYOR_BUILD_WORKER_IMAGE: SQLINSTANCE: SQL2012SP1 - - APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2013 + - APPVEYOR_BUILD_WORKER_IMAGE: SQLINSTANCE: SQL2008R2SP2 - GOVERSION: 117 SQLINSTANCE: SQL2017 - GOVERSION: 118 SQLINSTANCE: SQL2017 - GOVERSION: 120 + RACE: SQLINSTANCE: SQL2019 COLUMNENCRYPTION: 1 # Cover 32bit and named pipes protocol @@ -43,6 +43,7 @@ environment: TAGS: -tags np # Cover SSPI and lpc protocol - GOVERSION: 120 + RACE: SQLINSTANCE: SQL2019 PROTOCOL: lpc TAGS: -tags sm From 8e171eb9a06469cd6eac819a3f330c8fcf5b9418 Mon Sep 17 00:00:00 2001 From: davidshi Date: Wed, 19 Jul 2023 12:44:01 -0500 Subject: [PATCH 41/47] update reviewdog --- .github/workflows/reviewdog.yml | 4 ++-- README.md | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/reviewdog.yml b/.github/workflows/reviewdog.yml index 97b897a4..143ab8b1 100644 --- a/.github/workflows/reviewdog.yml +++ b/.github/workflows/reviewdog.yml @@ -6,9 +6,9 @@ jobs: runs-on: ubuntu-latest steps: - name: Check out code into the Go module directory - uses: actions/checkout@v2 + uses: actions/checkout@v3 - name: golangci-lint - uses: reviewdog/action-golangci-lint@v1 + uses: reviewdog/action-golangci-lint@v2 with: level: warning reporter: github-pr-review diff --git a/README.md b/README.md index 120aa48b..b424dd87 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ ## Install -Requires Go 1.10 or above. +Requires Go 1.16 or above. Install with `go install github.com/microsoft/go-mssqldb@latest`. From e2e907a669d213500124935b8d715e647e2cbe81 Mon Sep 17 00:00:00 2001 From: davidshi Date: Wed, 19 Jul 2023 12:55:09 -0500 Subject: [PATCH 42/47] fix linter warnings --- aecmk/localcert/keyprovider.go | 27 ++++++++++--------- aecmk/localcert/keyprovider_darwin.go | 2 +- .../keyprovider_go117_windows_test.go | 4 +-- aecmk/localcert/keyprovider_linux.go | 2 +- aecmk/localcert/keyprovider_windows.go | 9 ++++--- .../mssql-always-encrypted/pkg/keys/key.go | 2 +- .../mssql-always-encrypted/pkg/utils/utf16.go | 2 +- tds_test.go | 4 +-- 8 files changed, 29 insertions(+), 23 deletions(-) diff --git a/aecmk/localcert/keyprovider.go b/aecmk/localcert/keyprovider.go index ec18419c..98d1ca10 100644 --- a/aecmk/localcert/keyprovider.go +++ b/aecmk/localcert/keyprovider.go @@ -28,11 +28,11 @@ const ( wildcard = "*" ) -// LocalCertProvider uses local certificates to decrypt CEKs +// Provider uses local certificates to decrypt CEKs // It supports both 'MSSQL_CERTIFICATE_STORE' and 'pfx' key stores. // MSSQL_CERTIFICATE_STORE key paths are of the form `storename/storepath/thumbprint` and only supported on Windows clients. // pfx key paths are absolute file system paths that are operating system dependent. -type LocalCertProvider struct { +type Provider struct { // Name identifies which key store the provider supports. name string // AllowedLocations constrains which locations the provider will use to find certificates. If empty, all locations are allowed. @@ -43,22 +43,25 @@ type LocalCertProvider struct { // SetCertificatePassword stores the password associated with the certificate at the given location. // If location is empty the given password applies to all certificates that have not been explicitly assigned a value. -func (p LocalCertProvider) SetCertificatePassword(location string, password string) { +func (p Provider) SetCertificatePassword(location string, password string) { if location == "" { location = wildcard } p.passwords[location] = password } -var PfxKeyProvider = LocalCertProvider{name: PfxKeyProviderName, passwords: make(map[string]string), AllowedLocations: make([]string, 0)} +var PfxKeyProvider = Provider{name: PfxKeyProviderName, passwords: make(map[string]string), AllowedLocations: make([]string, 0)} func init() { - aecmk.RegisterCekProvider("pfx", &PfxKeyProvider) + err := aecmk.RegisterCekProvider("pfx", &PfxKeyProvider) + if err != nil { + panic(err) + } } // DecryptColumnEncryptionKey decrypts the specified encrypted value of a column encryption key. // The encrypted value is expected to be encrypted using the column master key with the specified key path and using the specified algorithm. -func (p *LocalCertProvider) DecryptColumnEncryptionKey(masterKeyPath string, encryptionAlgorithm string, encryptedCek []byte) (decryptedKey []byte) { +func (p *Provider) DecryptColumnEncryptionKey(masterKeyPath string, encryptionAlgorithm string, encryptedCek []byte) (decryptedKey []byte) { decryptedKey = nil pk, cert, allowed := p.tryLoadCertificate(masterKeyPath) if !allowed { @@ -84,7 +87,7 @@ func (p *LocalCertProvider) DecryptColumnEncryptionKey(masterKeyPath string, enc return } -func (p *LocalCertProvider) tryLoadCertificate(masterKeyPath string) (privateKey interface{}, cert *x509.Certificate, allowed bool) { +func (p *Provider) tryLoadCertificate(masterKeyPath string) (privateKey interface{}, cert *x509.Certificate, allowed bool) { allowed = len(p.AllowedLocations) == 0 if !allowed { loop: @@ -107,7 +110,7 @@ func (p *LocalCertProvider) tryLoadCertificate(masterKeyPath string) (privateKey return } -func (p *LocalCertProvider) loadLocalCertificate(path string) (privateKey interface{}, cert *x509.Certificate) { +func (p *Provider) loadLocalCertificate(path string) (privateKey interface{}, cert *x509.Certificate) { if f, err := os.Open(path); err == nil { pfxBytes, err := ioutil.ReadAll(f) if err != nil { @@ -131,7 +134,7 @@ func (p *LocalCertProvider) loadLocalCertificate(path string) (privateKey interf } // EncryptColumnEncryptionKey encrypts a column encryption key using the column master key with the specified key path and using the specified algorithm. -func (p *LocalCertProvider) EncryptColumnEncryptionKey(masterKeyPath string, encryptionAlgorithm string, cek []byte) []byte { +func (p *Provider) EncryptColumnEncryptionKey(masterKeyPath string, encryptionAlgorithm string, cek []byte) []byte { validateEncryptionAlgorithm(encryptionAlgorithm) validateKeyPathLength(masterKeyPath) @@ -182,20 +185,20 @@ func (p *LocalCertProvider) EncryptColumnEncryptionKey(masterKeyPath string, enc // SignColumnMasterKeyMetadata digitally signs the column master key metadata with the column master key // referenced by the masterKeyPath parameter. The input values used to generate the signature should be the // specified values of the masterKeyPath and allowEnclaveComputations parameters. May return an empty slice if not supported. -func (p *LocalCertProvider) SignColumnMasterKeyMetadata(masterKeyPath string, allowEnclaveComputations bool) []byte { +func (p *Provider) SignColumnMasterKeyMetadata(masterKeyPath string, allowEnclaveComputations bool) []byte { return nil } // VerifyColumnMasterKeyMetadata verifies the specified signature is valid for the column master key // with the specified key path and the specified enclave behavior. Return nil if not supported. -func (p *LocalCertProvider) VerifyColumnMasterKeyMetadata(masterKeyPath string, allowEnclaveComputations bool) *bool { +func (p *Provider) VerifyColumnMasterKeyMetadata(masterKeyPath string, allowEnclaveComputations bool) *bool { return nil } // KeyLifetime is an optional Duration. Keys fetched by this provider will be discarded after their lifetime expires. // If it returns nil, the keys will expire based on the value of ColumnEncryptionKeyLifetime. // If it returns zero, the keys will not be cached. -func (p *LocalCertProvider) KeyLifetime() *time.Duration { +func (p *Provider) KeyLifetime() *time.Duration { return nil } diff --git a/aecmk/localcert/keyprovider_darwin.go b/aecmk/localcert/keyprovider_darwin.go index 03943489..c3a7564a 100644 --- a/aecmk/localcert/keyprovider_darwin.go +++ b/aecmk/localcert/keyprovider_darwin.go @@ -8,7 +8,7 @@ import ( "fmt" ) -func (p *LocalCertProvider) loadWindowsCertStoreCertificate(path string) (privateKey interface{}, cert *x509.Certificate) { +func (p *Provider) loadWindowsCertStoreCertificate(path string) (privateKey interface{}, cert *x509.Certificate) { panic(fmt.Errorf("Windows cert store not supported on this OS")) return } diff --git a/aecmk/localcert/keyprovider_go117_windows_test.go b/aecmk/localcert/keyprovider_go117_windows_test.go index 28d2d2fa..9212d4b8 100644 --- a/aecmk/localcert/keyprovider_go117_windows_test.go +++ b/aecmk/localcert/keyprovider_go117_windows_test.go @@ -18,7 +18,7 @@ func TestLoadWindowsCertStoreCertificate(t *testing.T) { t.Fatal(err) } defer certs.DeleteMasterKeyCert(thumbprint) - provider := aecmk.GetGlobalCekProviders()[aecmk.CertificateStoreKeyProvider].Provider.(*LocalCertProvider) + provider := aecmk.GetGlobalCekProviders()[aecmk.CertificateStoreKeyProvider].Provider.(*Provider) pk, cert := provider.loadWindowsCertStoreCertificate("CurrentUser/My/" + thumbprint) switch z := pk.(type) { case *rsa.PrivateKey: @@ -40,7 +40,7 @@ func TestEncryptDecryptEncryptionKeyRoundTrip(t *testing.T) { defer certs.DeleteMasterKeyCert(thumbprint) bytesToEncrypt := []byte{1, 2, 3} keyPath := "CurrentUser/My/" + thumbprint - provider := aecmk.GetGlobalCekProviders()[aecmk.CertificateStoreKeyProvider].Provider.(*LocalCertProvider) + provider := aecmk.GetGlobalCekProviders()[aecmk.CertificateStoreKeyProvider].Provider.(*Provider) encryptedBytes := provider.EncryptColumnEncryptionKey(keyPath, "RSA_OAEP", bytesToEncrypt) decryptedBytes := provider.DecryptColumnEncryptionKey(keyPath, "RSA_OAEP", encryptedBytes) if len(decryptedBytes) != 3 || decryptedBytes[0] != 1 || decryptedBytes[1] != 2 || decryptedBytes[2] != 3 { diff --git a/aecmk/localcert/keyprovider_linux.go b/aecmk/localcert/keyprovider_linux.go index 03943489..c3a7564a 100644 --- a/aecmk/localcert/keyprovider_linux.go +++ b/aecmk/localcert/keyprovider_linux.go @@ -8,7 +8,7 @@ import ( "fmt" ) -func (p *LocalCertProvider) loadWindowsCertStoreCertificate(path string) (privateKey interface{}, cert *x509.Certificate) { +func (p *Provider) loadWindowsCertStoreCertificate(path string) (privateKey interface{}, cert *x509.Certificate) { panic(fmt.Errorf("Windows cert store not supported on this OS")) return } diff --git a/aecmk/localcert/keyprovider_windows.go b/aecmk/localcert/keyprovider_windows.go index cce35f29..25c7fa20 100644 --- a/aecmk/localcert/keyprovider_windows.go +++ b/aecmk/localcert/keyprovider_windows.go @@ -14,13 +14,16 @@ import ( "golang.org/x/sys/windows" ) -var WindowsCertificateStoreKeyProvider = LocalCertProvider{name: aecmk.CertificateStoreKeyProvider, passwords: make(map[string]string)} +var WindowsCertificateStoreKeyProvider = Provider{name: aecmk.CertificateStoreKeyProvider, passwords: make(map[string]string)} func init() { - aecmk.RegisterCekProvider(aecmk.CertificateStoreKeyProvider, &WindowsCertificateStoreKeyProvider) + err := aecmk.RegisterCekProvider(aecmk.CertificateStoreKeyProvider, &WindowsCertificateStoreKeyProvider) + if err != nil { + panic(err) + } } -func (p *LocalCertProvider) loadWindowsCertStoreCertificate(path string) (privateKey interface{}, cert *x509.Certificate) { +func (p *Provider) loadWindowsCertStoreCertificate(path string) (privateKey interface{}, cert *x509.Certificate) { privateKey = nil cert = nil pathParts := strings.Split(path, `/`) diff --git a/internal/github.com/swisscom/mssql-always-encrypted/pkg/keys/key.go b/internal/github.com/swisscom/mssql-always-encrypted/pkg/keys/key.go index 9e6e0161..f778e902 100644 --- a/internal/github.com/swisscom/mssql-always-encrypted/pkg/keys/key.go +++ b/internal/github.com/swisscom/mssql-always-encrypted/pkg/keys/key.go @@ -2,4 +2,4 @@ package keys type Key interface { RootKey() []byte -} \ No newline at end of file +} diff --git a/internal/github.com/swisscom/mssql-always-encrypted/pkg/utils/utf16.go b/internal/github.com/swisscom/mssql-always-encrypted/pkg/utils/utf16.go index 52c2c792..4eb13390 100644 --- a/internal/github.com/swisscom/mssql-always-encrypted/pkg/utils/utf16.go +++ b/internal/github.com/swisscom/mssql-always-encrypted/pkg/utils/utf16.go @@ -15,4 +15,4 @@ func ConvertUTF16ToLittleEndianBytes(u []uint16) []byte { func ProcessUTF16LE(inputString string) []byte { return ConvertUTF16ToLittleEndianBytes(utf16.Encode([]rune(inputString))) -} \ No newline at end of file +} diff --git a/tds_test.go b/tds_test.go index 1904f1d2..daabd714 100644 --- a/tds_test.go +++ b/tds_test.go @@ -122,11 +122,11 @@ func TestSendLoginWithFeatureExt(t *testing.T) { Database: "database", ClientLCID: 0x204, } - login.FeatureExt.Add(&featureExtFedAuth{ + _ = login.FeatureExt.Add(&featureExtFedAuth{ FedAuthLibrary: FedAuthLibrarySecurityToken, FedAuthToken: "fedauthtoken", }) - login.FeatureExt.Add(&featureExtColumnEncryption{}) + _ = login.FeatureExt.Add(&featureExtColumnEncryption{}) err := sendLogin(buf, &login) if err != nil { t.Error("sendLogin should succeed") From a66a566cd2c3522a5f01da2ccf188e797c4d0d74 Mon Sep 17 00:00:00 2001 From: davidshi Date: Wed, 19 Jul 2023 13:30:14 -0500 Subject: [PATCH 43/47] more linter fixes --- aecmk/localcert/keyprovider.go | 8 -------- encrypt.go | 6 +++--- .../mssql-always-encrypted/pkg/algorithms/algorithm.go | 2 +- .../mssql-always-encrypted/pkg/alwaysencrypted_test.go | 4 +++- 4 files changed, 7 insertions(+), 13 deletions(-) diff --git a/aecmk/localcert/keyprovider.go b/aecmk/localcert/keyprovider.go index 98d1ca10..24c1bcae 100644 --- a/aecmk/localcert/keyprovider.go +++ b/aecmk/localcert/keyprovider.go @@ -67,14 +67,6 @@ func (p *Provider) DecryptColumnEncryptionKey(masterKeyPath string, encryptionAl if !allowed { return } - switch p.name { - case PfxKeyProviderName: - pk, cert = p.loadLocalCertificate(masterKeyPath) - case aecmk.CertificateStoreKeyProvider: - pk, cert = p.loadWindowsCertStoreCertificate(masterKeyPath) - default: - return - } cekv := ae.LoadCEKV(encryptedCek) if !cekv.Verify(cert) { panic(fmt.Errorf("Invalid certificate provided for decryption. Key Store Path: %s. <%s>-<%v>", masterKeyPath, cekv.KeyPath, fmt.Sprintf("%02x", sha1.Sum(cert.Raw)))) diff --git a/encrypt.go b/encrypt.go index 91dca378..0e04837a 100644 --- a/encrypt.go +++ b/encrypt.go @@ -31,9 +31,9 @@ type cekData struct { cmkStoreName string cmkPath string algorithm string - byEnclave bool - cmkSignature string - decryptedValue []byte + //byEnclave bool + //cmkSignature string + decryptedValue []byte } type parameterEncData struct { diff --git a/internal/github.com/swisscom/mssql-always-encrypted/pkg/algorithms/algorithm.go b/internal/github.com/swisscom/mssql-always-encrypted/pkg/algorithms/algorithm.go index 48a751da..ea1ca6b8 100644 --- a/internal/github.com/swisscom/mssql-always-encrypted/pkg/algorithms/algorithm.go +++ b/internal/github.com/swisscom/mssql-always-encrypted/pkg/algorithms/algorithm.go @@ -3,4 +3,4 @@ package algorithms type Algorithm interface { Encrypt([]byte) ([]byte, error) Decrypt([]byte) ([]byte, error) -} \ No newline at end of file +} diff --git a/internal/github.com/swisscom/mssql-always-encrypted/pkg/alwaysencrypted_test.go b/internal/github.com/swisscom/mssql-always-encrypted/pkg/alwaysencrypted_test.go index 093ec64e..efc6525d 100644 --- a/internal/github.com/swisscom/mssql-always-encrypted/pkg/alwaysencrypted_test.go +++ b/internal/github.com/swisscom/mssql-always-encrypted/pkg/alwaysencrypted_test.go @@ -39,7 +39,9 @@ func TestLoadCEKV(t *testing.T) { t.Fatal(err) } cekvBytes, err := ioutil.ReadAll(cekvFile) - + if err != nil { + t.Fatal(err) + } cekv := LoadCEKV(cekvBytes) assert.Equal(t, 1, cekv.Version) assert.True(t, cekv.Verify(cert)) From cdaec238fe62dea52beb83456fd18e25372bea81 Mon Sep 17 00:00:00 2001 From: davidshi Date: Wed, 19 Jul 2023 13:33:36 -0500 Subject: [PATCH 44/47] check err in test --- .../mssql-always-encrypted/pkg/alwaysencrypted_test.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/internal/github.com/swisscom/mssql-always-encrypted/pkg/alwaysencrypted_test.go b/internal/github.com/swisscom/mssql-always-encrypted/pkg/alwaysencrypted_test.go index efc6525d..d44cee84 100644 --- a/internal/github.com/swisscom/mssql-always-encrypted/pkg/alwaysencrypted_test.go +++ b/internal/github.com/swisscom/mssql-always-encrypted/pkg/alwaysencrypted_test.go @@ -69,7 +69,9 @@ func TestDecrypt(t *testing.T) { t.Fatal(err) } cekvBytes, err := ioutil.ReadAll(cekvFile) - + if err != nil { + t.Fatal(err) + } cekv := LoadCEKV(cekvBytes) rootKey, err := cekv.Decrypt(rsaPrivKey) if err != nil { From 9e0b61e445475addccacff57b860bf6c1a5abc0a Mon Sep 17 00:00:00 2001 From: davidshi Date: Wed, 19 Jul 2023 16:19:12 -0500 Subject: [PATCH 45/47] remove old SQL versions from PR build --- appveyor.yml | 8 -------- 1 file changed, 8 deletions(-) diff --git a/appveyor.yml b/appveyor.yml index 2c2bcc43..fdeeedf3 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -18,14 +18,6 @@ environment: TAGS: matrix: - SQLINSTANCE: SQL2017 - - APPVEYOR_BUILD_WORKER_IMAGE: - SQLINSTANCE: SQL2016 - - APPVEYOR_BUILD_WORKER_IMAGE: - SQLINSTANCE: SQL2014 - - APPVEYOR_BUILD_WORKER_IMAGE: - SQLINSTANCE: SQL2012SP1 - - APPVEYOR_BUILD_WORKER_IMAGE: - SQLINSTANCE: SQL2008R2SP2 - GOVERSION: 117 SQLINSTANCE: SQL2017 - GOVERSION: 118 From 2a41c82b467a5d900f8149ebdd202138efbb4cfc Mon Sep 17 00:00:00 2001 From: davidshi Date: Mon, 24 Jul 2023 18:20:22 -0500 Subject: [PATCH 46/47] check err in test --- .../mssql-always-encrypted/pkg/alwaysencrypted_test.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/internal/github.com/swisscom/mssql-always-encrypted/pkg/alwaysencrypted_test.go b/internal/github.com/swisscom/mssql-always-encrypted/pkg/alwaysencrypted_test.go index d44cee84..860f3e02 100644 --- a/internal/github.com/swisscom/mssql-always-encrypted/pkg/alwaysencrypted_test.go +++ b/internal/github.com/swisscom/mssql-always-encrypted/pkg/alwaysencrypted_test.go @@ -94,7 +94,9 @@ func TestDecrypt(t *testing.T) { key := keys.NewAeadAes256CbcHmac256(rootKey) alg := algorithms.NewAeadAes256CbcHmac256Algorithm(key, encryption.Deterministic, 1) cleartext, err := alg.Decrypt(columnBytes) - + if err != nil { + t.Fatalf("Decrypt failed! %s", err.Error()) + } enc := unicode.UTF16(unicode.LittleEndian, unicode.IgnoreBOM) decoder := enc.NewDecoder() cleartextUtf8, err := decoder.Bytes(cleartext) From c53676e1151cc93e05d7b00cdd4304d290898839 Mon Sep 17 00:00:00 2001 From: davidshi Date: Tue, 25 Jul 2023 09:09:09 -0500 Subject: [PATCH 47/47] fix unit tests --- .../pkg/alwaysencrypted_test.go | 86 ++++++------------- 1 file changed, 24 insertions(+), 62 deletions(-) diff --git a/internal/github.com/swisscom/mssql-always-encrypted/pkg/alwaysencrypted_test.go b/internal/github.com/swisscom/mssql-always-encrypted/pkg/alwaysencrypted_test.go index 860f3e02..11e4c237 100644 --- a/internal/github.com/swisscom/mssql-always-encrypted/pkg/alwaysencrypted_test.go +++ b/internal/github.com/swisscom/mssql-always-encrypted/pkg/alwaysencrypted_test.go @@ -20,125 +20,87 @@ import ( func TestLoadCEKV(t *testing.T) { certFile, err := os.Open("../test/always-encrypted_pub.pem") - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) certBytes, err := ioutil.ReadAll(certFile) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) pemB, _ := pem.Decode(certBytes) cert, err := x509.ParseCertificate(pemB.Bytes) - if err != nil { - t.Fatal(nil) - } + assert.NoError(t, err) cekvFile, err := os.Open("../test/cekv.key") - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) cekvBytes, err := ioutil.ReadAll(cekvFile) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) cekv := LoadCEKV(cekvBytes) assert.Equal(t, 1, cekv.Version) assert.True(t, cekv.Verify(cert)) } func TestDecrypt(t *testing.T) { certFile, err := os.Open("../test/always-encrypted.pem") - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) certBytes, err := ioutil.ReadAll(certFile) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) pemB, _ := pem.Decode(certBytes) privKey, err := x509.ParsePKCS8PrivateKey(pemB.Bytes) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) rsaPrivKey := privKey.(*rsa.PrivateKey) cekvFile, err := os.Open("../test/cekv.key") - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) cekvBytes, err := ioutil.ReadAll(cekvFile) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) cekv := LoadCEKV(cekvBytes) rootKey, err := cekv.Decrypt(rsaPrivKey) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) assert.Equal(t, "0ff9e45335df3dec7be0649f741e6ea870e9d49d16fe4be7437ce22489f48ead", fmt.Sprintf("%02x", rootKey)) assert.Equal(t, 1, cekv.Version) assert.NotNil(t, rootKey) columnBytesFile, err := os.Open("../test/column_value.enc") - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) columnBytes, err := ioutil.ReadAll(columnBytesFile) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) key := keys.NewAeadAes256CbcHmac256(rootKey) alg := algorithms.NewAeadAes256CbcHmac256Algorithm(key, encryption.Deterministic, 1) cleartext, err := alg.Decrypt(columnBytes) - if err != nil { - t.Fatalf("Decrypt failed! %s", err.Error()) - } + assert.NoErrorf(t, err, "Decrypt failed! %v", err) + enc := unicode.UTF16(unicode.LittleEndian, unicode.IgnoreBOM) decoder := enc.NewDecoder() cleartextUtf8, err := decoder.Bytes(cleartext) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) t.Logf("column value: \"%02X\"", cleartextUtf8) assert.Equal(t, "12345 ", string(cleartextUtf8)) } func TestDecryptCEK(t *testing.T) { certFile, err := os.Open("../test/always-encrypted.pem") - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) certFileBytes, err := ioutil.ReadAll(certFile) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) pemBlock, _ := pem.Decode(certFileBytes) cert, err := x509.ParsePKCS8PrivateKey(pemBlock.Bytes) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) cekvFile, err := os.Open("../test/cekv.key") - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) cekvBytes, err := ioutil.ReadAll(cekvFile) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) cekv := LoadCEKV(cekvBytes) - fmt.Printf("Cert: %v\n", cert) + t.Logf("Cert: %v\n", cert) rsaKey := cert.(*rsa.PrivateKey) // RSA/ECB/OAEPWithSHA-1AndMGF1Padding bytes, err := rsa.DecryptOAEP(sha1.New(), rand.Reader, rsaKey, cekv.Ciphertext, nil) - fmt.Printf("Key: %02x\n", bytes) + assert.NoError(t, err) + t.Logf("Key: %02x\n", bytes) }