From 80639541f33e3c92edf68fa6136ac69c140d53a1 Mon Sep 17 00:00:00 2001 From: Andrea Magnetto Date: Wed, 24 Jul 2024 14:07:55 +0200 Subject: [PATCH] Fix GUID conversion --- alwaysencrypted_test.go | 2 +- bulkcopy.go | 2 +- bulkcopy_test.go | 12 ++++- msdsn/conn_str.go | 25 +++++++++++ msdsn/conn_str_test.go | 5 ++- mssql.go | 2 +- mssql_go19.go | 2 +- queries_test.go | 22 ++++++--- rpc.go | 8 ++-- session.go | 1 + tds.go | 1 + tds_go110_test.go | 14 ++++-- tds_go110pre_test.go | 9 +++- tds_test.go | 6 +++ token.go | 12 ++--- tvp_go19.go | 6 ++- tvp_go19_db_test.go | 37 ++++++++++++--- tvp_go19_test.go | 13 +++++- types.go | 99 +++++++++++++++++++++++++++++++++-------- 19 files changed, 222 insertions(+), 56 deletions(-) diff --git a/alwaysencrypted_test.go b/alwaysencrypted_test.go index 518055b7..9ff7ff86 100644 --- a/alwaysencrypted_test.go +++ b/alwaysencrypted_test.go @@ -213,7 +213,7 @@ func TestAlwaysEncryptedE2E(t *testing.T) { func testProviderErrorHandling(t *testing.T, name string, provider aecmk.ColumnEncryptionKeyProvider, sel string, insert string, insertArgs []interface{}) { t.Helper() testProvider := &testKeyProvider{fallback: provider} - connector, _ := getTestConnector(t) + connector, _ := getTestConnector(t, false /*guidConversion*/) connector.RegisterCekProvider(name, testProvider) conn := sql.OpenDB(connector) defer conn.Close() diff --git a/bulkcopy.go b/bulkcopy.go index 3008359f..be8398b1 100644 --- a/bulkcopy.go +++ b/bulkcopy.go @@ -264,7 +264,7 @@ func (b *Bulk) createColMetadata() []byte { } binary.Write(buf, binary.LittleEndian, uint16(col.Flags)) - writeTypeInfo(buf, &b.bulkColumns[i].ti, false) + writeTypeInfo(buf, &b.bulkColumns[i].ti, false, b.cn.sess.encoding) if col.ti.TypeId == typeNText || col.ti.TypeId == typeText || diff --git a/bulkcopy_test.go b/bulkcopy_test.go index 2dc6c6ae..559f4f41 100644 --- a/bulkcopy_test.go +++ b/bulkcopy_test.go @@ -111,9 +111,9 @@ func TestBulkcopyWithInvalidNullableType(t *testing.T) { } } -func TestBulkcopy(t *testing.T) { +func testBulkcopy(t *testing.T, guidConversion bool) { // TDS level Bulk Insert is not supported on Azure SQL Server. - if dsn := makeConnStr(t); strings.HasSuffix(strings.Split(dsn.Host, ":")[0], ".database.windows.net") { + if dsn := makeConnStrSettingGuidConversion(t, guidConversion); strings.HasSuffix(strings.Split(dsn.Host, ":")[0], ".database.windows.net") { t.Skip("TDS level bulk copy is not supported on Azure SQL Server") } type testValue struct { @@ -300,6 +300,14 @@ func TestBulkcopy(t *testing.T) { } } +func TestBulkcopyWithGuidConversion(t *testing.T) { + testBulkcopy(t, true /*guidConversion*/) +} + +func TestBulkcopyWithoutGuidConversion(t *testing.T) { + testBulkcopy(t, false /*guidConversion*/) +} + func compareValue(a interface{}, expected interface{}) bool { if got, ok := a.([]uint8); ok { if _, ok := expected.([]uint8); !ok { diff --git a/msdsn/conn_str.go b/msdsn/conn_str.go index 04653c27..a4e31506 100644 --- a/msdsn/conn_str.go +++ b/msdsn/conn_str.go @@ -84,8 +84,14 @@ const ( Pipe = "pipe" MultiSubnetFailover = "multisubnetfailover" NoTraceID = "notraceid" + GuidConversion = "guid conversion" ) +type EncodeParameters struct { + // Properly convert GUIDs, using correct byte endianness + GuidConversion bool +} + type Config struct { Port uint64 Host string @@ -141,6 +147,8 @@ type Config struct { // When true, no connection id or trace id value is sent in the prelogin packet. // Some cloud servers may block connections that lack such values. NoTraceID bool + // Parameters related to type encoding + Encoding EncodeParameters } func readDERFile(filename string) ([]byte, error) { @@ -525,6 +533,20 @@ func Parse(dsn string) (Config, error) { p.NoTraceID = notraceid } } + + guidConversion, ok := params[GuidConversion] + if ok { + var err error + p.Encoding.GuidConversion, err = strconv.ParseBool(guidConversion) + if err != nil { + f := "invalid guid conversion '%s': %s" + return p, fmt.Errorf(f, guidConversion, err.Error()) + } + } else { + // set to false for backward compatibility + p.Encoding.GuidConversion = false + } + return p, nil } @@ -585,6 +607,9 @@ func (p Config) URL() *url.URL { if p.ColumnEncryption { q.Add("columnencryption", "true") } + + q.Add(GuidConversion, strconv.FormatBool(p.Encoding.GuidConversion)) + if len(q) > 0 { res.RawQuery = q.Encode() } diff --git a/msdsn/conn_str_test.go b/msdsn/conn_str_test.go index 232845df..80db9c85 100644 --- a/msdsn/conn_str_test.go +++ b/msdsn/conn_str_test.go @@ -190,7 +190,10 @@ func TestValidConnectionString(t *testing.T) { return p.Host == "somehost" && p.Port == 0 && p.Instance == "" && p.User == "someuser" && p.Password == "" && p.ConnTimeout == 30*time.Second && p.DisableRetry && !p.ColumnEncryption }}, {"sqlserver://somehost?encrypt=true&tlsmin=1.1&columnencryption=1", func(p Config) bool { - return p.Host == "somehost" && p.Encryption == EncryptionRequired && p.TLSConfig.MinVersion == tls.VersionTLS11 && p.ColumnEncryption + return p.Host == "somehost" && p.Encryption == EncryptionRequired && p.TLSConfig.MinVersion == tls.VersionTLS11 && p.ColumnEncryption && !p.Encoding.GuidConversion + }}, + {"sqlserver://somehost?encrypt=true&tlsmin=1.1&columnencryption=1&guid+conversion=true", func(p Config) bool { + return p.Host == "somehost" && p.Encryption == EncryptionRequired && p.TLSConfig.MinVersion == tls.VersionTLS11 && p.ColumnEncryption && p.Encoding.GuidConversion }}, } for _, ts := range connStrings { diff --git a/mssql.go b/mssql.go index 62aaa437..58be982a 100644 --- a/mssql.go +++ b/mssql.go @@ -554,7 +554,7 @@ func (s *Stmt) sendQuery(ctx context.Context, args []namedValue) (err error) { params[0] = makeStrParam(s.query) params[1] = makeStrParam(strings.Join(decls, ",")) } - if err = sendRpc(conn.sess.buf, headers, proc, 0, params, reset); err != nil { + if err = sendRpc(conn.sess.buf, headers, proc, 0, params, reset, conn.sess.encoding); err != nil { conn.sess.LogF(ctx, msdsn.LogErrors, "Failed to send Rpc with %v", err) conn.connectionGood = false return fmt.Errorf("failed to send RPC: %v", err) diff --git a/mssql_go19.go b/mssql_go19.go index 6435f67e..9b55d1b9 100644 --- a/mssql_go19.go +++ b/mssql_go19.go @@ -206,7 +206,7 @@ func (s *Stmt) makeParamExtra(val driver.Value) (res param, err error) { err = errCalTypes return } - res.buffer, err = val.encode(schema, name, columnStr, tvpFieldIndexes) + res.buffer, err = val.encode(schema, name, columnStr, tvpFieldIndexes, s.c.sess.encoding) if err != nil { return } diff --git a/queries_test.go b/queries_test.go index 533d3c4e..25b0b58e 100644 --- a/queries_test.go +++ b/queries_test.go @@ -27,8 +27,8 @@ func driverWithProcess(t *testing.T, tl Logger) *Driver { } } -func TestSelect(t *testing.T) { - conn, logger := open(t) +func testSelect(t *testing.T, guidConversion bool) { + conn, logger := openSettingGuidConversion(t, guidConversion) defer conn.Close() defer logger.StopLogging() @@ -39,6 +39,10 @@ func TestSelect(t *testing.T) { } longstr := strings.Repeat("x", 10000) + expectedGuid := []byte{0x6F, 0x96, 0x19, 0xFF, 0x8B, 0x86, 0xD0, 0x11, 0xB4, 0x2D, 0x00, 0xC0, 0x4F, 0xC9, 0x64, 0xFF} + if guidConversion { + expectedGuid = []byte{0xFF, 0x19, 0x96, 0x6F, 0x86, 0x8B, 0x11, 0xD0, 0xB4, 0x2D, 0x00, 0xC0, 0x4F, 0xC9, 0x64, 0xFF} + } values := []testStruct{ {"1", int64(1)}, @@ -83,8 +87,7 @@ func TestSelect(t *testing.T) { {"cast('2079-06-06T23:59:00' as smalldatetime)", time.Date(2079, 6, 6, 23, 59, 0, 0, time.UTC)}, {"cast(NULL as smalldatetime)", nil}, - {"cast(0x6F9619FF8B86D011B42D00C04FC964FF as uniqueidentifier)", - []byte{0x6F, 0x96, 0x19, 0xFF, 0x8B, 0x86, 0xD0, 0x11, 0xB4, 0x2D, 0x00, 0xC0, 0x4F, 0xC9, 0x64, 0xFF}}, + {"cast(0x6F9619FF8B86D011B42D00C04FC964FF as uniqueidentifier)", expectedGuid}, {"cast(NULL as uniqueidentifier)", nil}, {"cast(0x1234 as varbinary(2))", []byte{0x12, 0x34}}, {"cast(N'abc' as nvarchar(max))", "abc"}, @@ -114,8 +117,7 @@ func TestSelect(t *testing.T) { {"cast(cast(N'chào' as nvarchar(max)) collate Vietnamese_CI_AI as varchar(max))", "chào"}, // cp1258 {fmt.Sprintf("cast(N'%s' as nvarchar(max))", longstr), longstr}, {"cast(NULL as sql_variant)", nil}, - {"cast(cast(0x6F9619FF8B86D011B42D00C04FC964FF as uniqueidentifier) as sql_variant)", - []byte{0x6F, 0x96, 0x19, 0xFF, 0x8B, 0x86, 0xD0, 0x11, 0xB4, 0x2D, 0x00, 0xC0, 0x4F, 0xC9, 0x64, 0xFF}}, + {"cast(cast(0x6F9619FF8B86D011B42D00C04FC964FF as uniqueidentifier) as sql_variant)", expectedGuid}, {"cast(cast(1 as bit) as sql_variant)", true}, {"cast(cast(10 as tinyint) as sql_variant)", int64(10)}, {"cast(cast(-10 as smallint) as sql_variant)", int64(-10)}, @@ -214,6 +216,14 @@ func TestSelect(t *testing.T) { }) } +func TestSelectWithGuidConversion(t *testing.T) { + testSelect(t, true /*guidConversion*/) +} + +func TestSelectWithoutGuidConversion(t *testing.T) { + testSelect(t, false /*guidConversion*/) +} + func TestSelectDateTimeOffset(t *testing.T) { type testStruct struct { sql string diff --git a/rpc.go b/rpc.go index afda1309..527fd7d7 100644 --- a/rpc.go +++ b/rpc.go @@ -2,6 +2,8 @@ package mssql import ( "encoding/binary" + + "github.com/microsoft/go-mssqldb/msdsn" ) type procId struct { @@ -43,7 +45,7 @@ var ( ) // http://msdn.microsoft.com/en-us/library/dd357576.aspx -func sendRpc(buf *tdsBuffer, headers []headerStruct, proc procId, flags uint16, params []param, resetSession bool) (err error) { +func sendRpc(buf *tdsBuffer, headers []headerStruct, proc procId, flags uint16, params []param, resetSession bool, encoding msdsn.EncodeParameters) (err error) { buf.BeginPacket(packRPCRequest, resetSession) writeAllHeaders(buf, headers) if len(proc.name) == 0 { @@ -73,7 +75,7 @@ func sendRpc(buf *tdsBuffer, headers []headerStruct, proc procId, flags uint16, if err = binary.Write(buf, binary.LittleEndian, param.Flags); err != nil { return } - err = writeTypeInfo(buf, ¶m.ti, (param.Flags&fByRevValue) != 0) + err = writeTypeInfo(buf, ¶m.ti, (param.Flags&fByRevValue) != 0, encoding) if err != nil { return } @@ -82,7 +84,7 @@ func sendRpc(buf *tdsBuffer, headers []headerStruct, proc procId, flags uint16, return } if (param.Flags & fEncrypted) == fEncrypted { - err = writeTypeInfo(buf, ¶m.tiOriginal, false) + err = writeTypeInfo(buf, ¶m.tiOriginal, false, encoding) if err != nil { return } diff --git a/session.go b/session.go index 33b06d89..ff4839ad 100644 --- a/session.go +++ b/session.go @@ -15,6 +15,7 @@ func newSession(outbuf *tdsBuffer, logger ContextLogger, p msdsn.Config) *tdsSes logger: logger, logFlags: uint64(p.LogFlags), aeSettings: &alwaysEncryptedSettings{keyProviders: aecmk.GetGlobalCekProviders()}, + encoding: p.Encoding, } _ = sess.activityid.Scan(p.ActivityID) // generating a guid has a small chance of failure. Make a best effort diff --git a/tds.go b/tds.go index 852c4d9f..9ddc2ce7 100644 --- a/tds.go +++ b/tds.go @@ -175,6 +175,7 @@ type tdsSession struct { aeSettings *alwaysEncryptedSettings connid UniqueIdentifier activityid UniqueIdentifier + encoding msdsn.EncodeParameters } type alwaysEncryptedSettings struct { diff --git a/tds_go110_test.go b/tds_go110_test.go index 76ecfc66..58eb4741 100644 --- a/tds_go110_test.go +++ b/tds_go110_test.go @@ -8,16 +8,22 @@ import ( "testing" ) -func open(t testing.TB) (*sql.DB, *testLogger) { - connector, logger := getTestConnector(t) +func openSettingGuidConversion(t testing.TB, guidConversion bool) (*sql.DB, *testLogger) { + connector, logger := getTestConnector(t, guidConversion) conn := sql.OpenDB(connector) return conn, logger } -func getTestConnector(t testing.TB) (*Connector, *testLogger) { +func open(t testing.TB) (*sql.DB, *testLogger) { + return openSettingGuidConversion(t, false /*guidConversion*/) +} + +func getTestConnector(t testing.TB, guidConversion bool) (*Connector, *testLogger) { tl := testLogger{t: t} SetLogger(&tl) - connector, err := NewConnector(makeConnStr(t).String()) + + connectionString := makeConnStrSettingGuidConversion(t, guidConversion).String() + connector, err := NewConnector(connectionString) if err != nil { t.Error("Open connection failed:", err.Error()) return nil, &tl diff --git a/tds_go110pre_test.go b/tds_go110pre_test.go index 6a42d3ae..2d7ff4e9 100644 --- a/tds_go110pre_test.go +++ b/tds_go110pre_test.go @@ -1,3 +1,4 @@ +//go:build !go1.10 // +build !go1.10 package mssql @@ -7,14 +8,18 @@ import ( "testing" ) -func open(t *testing.T) (*sql.DB, *testLogger) { +func openSettingGuidConversion(t *testing.T, guidConversion bool) (*sql.DB, *testLogger) { tl := testLogger{t: t} SetLogger(&tl) checkConnStr(t) - conn, err := sql.Open("sqlserver", makeConnStr(t).String()) + conn, err := sql.Open("sqlserver", makeConnStrSettingGuidConversion(t, guidConversion).String()) if err != nil { t.Error("Open connection failed:", err.Error()) return nil, &tl } return conn, &tl } + +func open(t *testing.T) (*sql.DB, *testLogger) { + return openSettingGuidConversion(t, false /*guidConversion*/) +} diff --git a/tds_test.go b/tds_test.go index 499ee084..38372804 100644 --- a/tds_test.go +++ b/tds_test.go @@ -364,6 +364,12 @@ func makeConnStr(t testing.TB) *url.URL { return testConnParams(t).URL() } +func makeConnStrSettingGuidConversion(t testing.TB, guidConversion bool) *url.URL { + config := testConnParams(t) + config.Encoding.GuidConversion = guidConversion + return config.URL() +} + type testLogger struct { t testing.TB mu sync.Mutex diff --git a/token.go b/token.go index 04a3e421..8926ca58 100644 --- a/token.go +++ b/token.go @@ -610,7 +610,7 @@ func parseColMetadata72(r *tdsBuffer, s *tdsSession) (columns []columnStruct) { for i := range columns { column := &columns[i] baseTi := getBaseTypeInfo(r, true) - typeInfo := readTypeInfo(r, baseTi.TypeId, column.cryptoMeta) + typeInfo := readTypeInfo(r, baseTi.TypeId, column.cryptoMeta, s.encoding) typeInfo.UserType = baseTi.UserType typeInfo.Flags = baseTi.Flags typeInfo.TypeId = baseTi.TypeId @@ -621,7 +621,7 @@ func parseColMetadata72(r *tdsBuffer, s *tdsSession) (columns []columnStruct) { if column.isEncrypted() && s.alwaysEncrypted { // Read Crypto Metadata - cryptoMeta := parseCryptoMetadata(r, cekTable) + cryptoMeta := parseCryptoMetadata(r, cekTable, s.encoding) cryptoMeta.typeInfo.Flags = baseTi.Flags column.cryptoMeta = &cryptoMeta } else { @@ -657,14 +657,14 @@ type cryptoMetadata struct { typeInfo typeInfo } -func parseCryptoMetadata(r *tdsBuffer, cekTable *cekTable) cryptoMetadata { +func parseCryptoMetadata(r *tdsBuffer, cekTable *cekTable, encoding msdsn.EncodeParameters) cryptoMetadata { ordinal := uint16(0) if cekTable != nil { ordinal = r.uint16() } typeInfo := getBaseTypeInfo(r, false) - ti := readTypeInfo(r, typeInfo.TypeId, nil) + ti := readTypeInfo(r, typeInfo.TypeId, nil, encoding) ti.UserType = typeInfo.UserType ti.Flags = typeInfo.Flags ti.TypeId = typeInfo.TypeId @@ -929,11 +929,11 @@ func parseReturnValue(r *tdsBuffer, s *tdsSession) (nv namedValue) { var cryptoMetadata *cryptoMetadata = nil if s.alwaysEncrypted && (ti.Flags&fEncrypted) == fEncrypted { - cm := parseCryptoMetadata(r, nil) // CryptoMetadata + cm := parseCryptoMetadata(r, nil, s.encoding) // CryptoMetadata cryptoMetadata = &cm } - ti2 := readTypeInfo(r, ti.TypeId, cryptoMetadata) + ti2 := readTypeInfo(r, ti.TypeId, cryptoMetadata, s.encoding) nv.Value = ti2.Reader(&ti2, r, cryptoMetadata) return diff --git a/tvp_go19.go b/tvp_go19.go index cc5dbfe4..5dba1abb 100644 --- a/tvp_go19.go +++ b/tvp_go19.go @@ -12,6 +12,8 @@ import ( "reflect" "strings" "time" + + "github.com/microsoft/go-mssqldb/msdsn" ) const ( @@ -62,7 +64,7 @@ func (tvp TVP) check() error { return nil } -func (tvp TVP) encode(schema, name string, columnStr []columnStruct, tvpFieldIndexes []int) ([]byte, error) { +func (tvp TVP) encode(schema, name string, columnStr []columnStruct, tvpFieldIndexes []int, encoding msdsn.EncodeParameters) ([]byte, error) { if len(columnStr) != len(tvpFieldIndexes) { return nil, ErrorWrongTyping } @@ -80,7 +82,7 @@ func (tvp TVP) encode(schema, name string, columnStr []columnStruct, tvpFieldInd for i, column := range columnStr { binary.Write(buf, binary.LittleEndian, column.UserType) binary.Write(buf, binary.LittleEndian, column.Flags) - writeTypeInfo(buf, &columnStr[i].ti, false) + writeTypeInfo(buf, &columnStr[i].ti, false, encoding) writeBVarChar(buf, "") } // The returned error is always nil diff --git a/tvp_go19_db_test.go b/tvp_go19_db_test.go index da80c19c..8b5b8291 100644 --- a/tvp_go19_db_test.go +++ b/tvp_go19_db_test.go @@ -416,13 +416,22 @@ func TestTVPGoSQLTypes(t *testing.T) { } } -func TestTVP(t *testing.T) { +func getNullableUniqueIdentifier(guidConversion bool) *UniqueIdentifier { + result := UniqueIdentifier{} + if !guidConversion { + // if guidConversion is enabled nullable UUIDs can only be accessed using NullUniqueIdentifier + result = UniqueIdentifier{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF} + } + return &result +} + +func testTVP(t *testing.T, guidConversion bool) { checkConnStr(t) tl := testLogger{t: t} defer tl.StopLogging() SetLogger(&tl) - c := makeConnStr(t).String() + c := makeConnStrSettingGuidConversion(t, guidConversion).String() db, err := sql.Open("sqlserver", c) if err != nil { t.Fatalf("failed to open driver sqlserver") @@ -571,7 +580,7 @@ func TestTVP(t *testing.T) { PBinary: nil, PVarcharNull: &varcharNull, PNvarcharNull: &nvarchar, - PIDNull: &UniqueIdentifier{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF}, + PIDNull: getNullableUniqueIdentifier(guidConversion), PTinyintNull: &i8, PSmallintNull: &i16, PIntNull: &i32, @@ -699,13 +708,21 @@ func TestTVP(t *testing.T) { } } -func TestTVP_WithTag(t *testing.T) { +func TestTVP_WithGuidConversion(t *testing.T) { + testTVP(t, true /*guidConversion*/) +} + +func TestTVP_WithoutGuidConversion(t *testing.T) { + testTVP(t, false /*guidConversion*/) +} + +func testTVP_WithTag(t *testing.T, guidConversion bool) { checkConnStr(t) tl := testLogger{t: t} defer tl.StopLogging() SetLogger(&tl) - db, err := sql.Open("sqlserver", makeConnStr(t).String()) + db, err := sql.Open("sqlserver", makeConnStrSettingGuidConversion(t, guidConversion).String()) if err != nil { t.Fatalf("failed to open driver sqlserver") } @@ -884,7 +901,7 @@ func TestTVP_WithTag(t *testing.T) { PBinary: nil, PVarcharNull: &varcharNull, PNvarcharNull: &nvarchar, - PIDNull: &UniqueIdentifier{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF}, + PIDNull: getNullableUniqueIdentifier(guidConversion), PTinyintNull: &i8, PSmallintNull: &i16, PIntNull: &i32, @@ -1114,6 +1131,14 @@ func TestTVPSchema(t *testing.T) { log.Println(tvpResult) } +func TestTVP_WithTagAndGuidConversion(t *testing.T) { + testTVP_WithTag(t, true /*guidConversion*/) +} + +func TestTVP_WithTagWithoutGuidConversion(t *testing.T) { + testTVP_WithTag(t, false /*guidConversion*/) +} + func TestTVPObject(t *testing.T) { checkConnStr(t) tl := testLogger{t: t} diff --git a/tvp_go19_test.go b/tvp_go19_test.go index 5a990a74..958546b4 100644 --- a/tvp_go19_test.go +++ b/tvp_go19_test.go @@ -6,6 +6,8 @@ import ( "reflect" "testing" "time" + + "github.com/microsoft/go-mssqldb/msdsn" ) type TestFields struct { @@ -503,7 +505,7 @@ func Test_getSchemeAndName(t *testing.T) { } } -func TestTVP_encode(t *testing.T) { +func testTVP_encode(t *testing.T, guidConversion bool) { type fields struct { TypeName string Value interface{} @@ -566,7 +568,7 @@ func TestTVP_encode(t *testing.T) { TypeName: tt.fields.TypeName, Value: tt.fields.Value, } - got, err := tvp.encode(tt.args.schema, tt.args.name, tt.args.columnStr, tt.args.tvpFieldIndexes) + got, err := tvp.encode(tt.args.schema, tt.args.name, tt.args.columnStr, tt.args.tvpFieldIndexes, msdsn.EncodeParameters{GuidConversion: guidConversion}) if (err != nil) != tt.wantErr { t.Errorf("TVP.encode() error = %v, wantErr %v", err, tt.wantErr) return @@ -577,3 +579,10 @@ func TestTVP_encode(t *testing.T) { }) } } +func TestTVP_encode_WithGuidConversion(t *testing.T) { + testTVP_encode(t, true /*guidConversion*/) +} + +func TestTVP_encode_WithoutGuidConversion(t *testing.T) { + testTVP_encode(t, false /*guidConversion*/) +} diff --git a/types.go b/types.go index 8f5ad9b9..41be1134 100644 --- a/types.go +++ b/types.go @@ -13,6 +13,7 @@ import ( "github.com/microsoft/go-mssqldb/internal/cp" "github.com/microsoft/go-mssqldb/internal/decimal" + "github.com/microsoft/go-mssqldb/msdsn" ) // fixed-length data types @@ -122,7 +123,7 @@ type xmlInfo struct { XmlSchemaCollection string } -func readTypeInfo(r *tdsBuffer, typeId byte, c *cryptoMetadata) (res typeInfo) { +func readTypeInfo(r *tdsBuffer, typeId byte, c *cryptoMetadata, encoding msdsn.EncodeParameters) (res typeInfo) { res.TypeId = typeId switch typeId { case typeNull, typeInt1, typeBit, typeInt2, typeInt4, typeDateTim4, @@ -143,13 +144,13 @@ func readTypeInfo(r *tdsBuffer, typeId byte, c *cryptoMetadata) (res typeInfo) { res.Reader = readFixedType res.Buffer = make([]byte, res.Size) default: // all others are VARLENTYPE - readVarLen(&res, r, c) + readVarLen(&res, r, c, encoding) } return } // https://msdn.microsoft.com/en-us/library/dd358284.aspx -func writeTypeInfo(w io.Writer, ti *typeInfo, out bool) (err error) { +func writeTypeInfo(w io.Writer, ti *typeInfo, out bool, encoding msdsn.EncodeParameters) (err error) { err = binary.Write(w, binary.LittleEndian, ti.TypeId) if err != nil { return @@ -163,7 +164,7 @@ func writeTypeInfo(w io.Writer, ti *typeInfo, out bool) (err error) { case typeTvp: ti.Writer = writeFixedType default: // all others are VARLENTYPE - err = writeVarLen(w, ti, out) + err = writeVarLen(w, ti, out, encoding) if err != nil { return } @@ -177,7 +178,7 @@ func writeFixedType(w io.Writer, ti typeInfo, buf []byte) (err error) { } // https://msdn.microsoft.com/en-us/library/dd358341.aspx -func writeVarLen(w io.Writer, ti *typeInfo, out bool) (err error) { +func writeVarLen(w io.Writer, ti *typeInfo, out bool, encoding msdsn.EncodeParameters) (err error) { switch ti.TypeId { case typeDateN: @@ -194,7 +195,7 @@ func writeVarLen(w io.Writer, ti *typeInfo, out bool) (err error) { // byle len types if ti.Size > 0xff { - panic("Invalid size for BYLELEN_TYPE") + panic("Invalid size for BYTELEN_TYPE") } if err = binary.Write(w, binary.LittleEndian, uint8(ti.Size)); err != nil { return @@ -213,12 +214,16 @@ func writeVarLen(w io.Writer, ti *typeInfo, out bool) (err error) { ti.Writer = writeByteLenType case typeGuid: if !(ti.Size == 0x10 || ti.Size == 0x00) { - panic("Invalid size for BYLELEN_TYPE") + panic("Invalid size for UNIQUEIDENTIFIER") } if err = binary.Write(w, binary.LittleEndian, uint8(ti.Size)); err != nil { return } - ti.Writer = writeByteLenType + if encoding.GuidConversion { + ti.Writer = writeGuidTypeWithConversion + } else { + ti.Writer = writeGuidTypeWithoutConversion + } case typeBigVarBin, typeBigVarChar, typeBigBinary, typeBigChar, typeNVarChar, typeNChar, typeXml, typeUdt: @@ -352,7 +357,7 @@ func readFixedType(ti *typeInfo, r *tdsBuffer, c *cryptoMetadata) interface{} { panic("shoulnd't get here") } -func readByteLenType(ti *typeInfo, r *tdsBuffer, c *cryptoMetadata) interface{} { +func readByteLenType(ti *typeInfo, r *tdsBuffer, c *cryptoMetadata, encoding msdsn.EncodeParameters) interface{} { var size byte if c != nil { size = byte(r.rsize) @@ -377,7 +382,7 @@ func readByteLenType(ti *typeInfo, r *tdsBuffer, c *cryptoMetadata) interface{} case typeDateTimeOffsetN: return decodeDateTimeOffset(ti.Scale, buf) case typeGuid: - return decodeGuid(buf) + return decodeGuid(buf, encoding) case typeIntN: switch len(buf) { case 1: @@ -444,6 +449,14 @@ func readByteLenType(ti *typeInfo, r *tdsBuffer, c *cryptoMetadata) interface{} panic("shouldn't get here") } +func readByteLenTypeWithGuidConversion(ti *typeInfo, r *tdsBuffer, c *cryptoMetadata) interface{} { + return readByteLenType(ti, r, c, msdsn.EncodeParameters{GuidConversion: true}) +} + +func readByteLenTypeWithoutGuidConversion(ti *typeInfo, r *tdsBuffer, c *cryptoMetadata) interface{} { + return readByteLenType(ti, r, c, msdsn.EncodeParameters{GuidConversion: false}) +} + func writeByteLenType(w io.Writer, ti typeInfo, buf []byte) (err error) { if ti.Size > 0xff { panic("Invalid size for BYTELEN_TYPE") @@ -456,6 +469,35 @@ func writeByteLenType(w io.Writer, ti typeInfo, buf []byte) (err error) { return } +func writeGuidType(w io.Writer, ti typeInfo, buf []byte, encoding msdsn.EncodeParameters) (err error) { + if !(ti.Size == 0x10 || ti.Size == 0x00) { + panic("Invalid size for UNIQUEIDENTIFIER") + } + err = binary.Write(w, binary.LittleEndian, uint8(len(buf))) + if err != nil { + return + } + if ti.Size == 0x10 { + res := make([]byte, 0x10) + copy(res, buf) + if encoding.GuidConversion { + binary.BigEndian.PutUint32(res[0:4], binary.LittleEndian.Uint32(res[0:4])) + binary.BigEndian.PutUint16(res[4:6], binary.LittleEndian.Uint16(res[4:6])) + binary.BigEndian.PutUint16(res[6:8], binary.LittleEndian.Uint16(res[6:8])) + } + _, err = w.Write(res) + } + return +} + +func writeGuidTypeWithConversion(w io.Writer, ti typeInfo, buf []byte) (err error) { + return writeGuidType(w, ti, buf, msdsn.EncodeParameters{GuidConversion: true}) +} + +func writeGuidTypeWithoutConversion(w io.Writer, ti typeInfo, buf []byte) (err error) { + return writeGuidType(w, ti, buf, msdsn.EncodeParameters{GuidConversion: false}) +} + func readShortLenType(ti *typeInfo, r *tdsBuffer, c *cryptoMetadata) interface{} { var size uint16 if c != nil { @@ -579,7 +621,7 @@ func writeCollation(w io.Writer, col cp.Collation) (err error) { // reads variant value // http://msdn.microsoft.com/en-us/library/dd303302.aspx -func readVariantType(ti *typeInfo, r *tdsBuffer, c *cryptoMetadata) interface{} { +func readVariantType(ti *typeInfo, r *tdsBuffer, c *cryptoMetadata, encoding msdsn.EncodeParameters) interface{} { size := r.int32() if size == 0 { return nil @@ -590,7 +632,7 @@ func readVariantType(ti *typeInfo, r *tdsBuffer, c *cryptoMetadata) interface{} case typeGuid: buf := make([]byte, size-2-propbytes) r.ReadFull(buf) - return buf + return decodeGuid(buf, encoding) case typeBit: return r.byte() != 0 case typeInt1: @@ -669,6 +711,14 @@ func readVariantType(ti *typeInfo, r *tdsBuffer, c *cryptoMetadata) interface{} panic("shoulnd't get here") } +func readVariantTypeWithGuidConversion(ti *typeInfo, r *tdsBuffer, c *cryptoMetadata) interface{} { + return readVariantType(ti, r, c, msdsn.EncodeParameters{GuidConversion: true}) +} + +func readVariantTypeWithoutGuidConversion(ti *typeInfo, r *tdsBuffer, c *cryptoMetadata) interface{} { + return readVariantType(ti, r, c, msdsn.EncodeParameters{GuidConversion: false}) +} + // partially length prefixed stream // http://msdn.microsoft.com/en-us/library/dd340469.aspx func readPLPType(ti *typeInfo, r *tdsBuffer, c *cryptoMetadata) interface{} { @@ -738,11 +788,11 @@ func writePLPType(w io.Writer, ti typeInfo, buf []byte) (err error) { } } -func readVarLen(ti *typeInfo, r *tdsBuffer, c *cryptoMetadata) { +func readVarLen(ti *typeInfo, r *tdsBuffer, c *cryptoMetadata, encoding msdsn.EncodeParameters) { switch ti.TypeId { case typeDateN: ti.Size = 3 - ti.Reader = readByteLenType + ti.Reader = readByteLenTypeWithoutGuidConversion ti.Buffer = make([]byte, ti.Size) case typeTimeN, typeDateTime2N, typeDateTimeOffsetN: ti.Scale = r.byte() @@ -762,7 +812,7 @@ func readVarLen(ti *typeInfo, r *tdsBuffer, c *cryptoMetadata) { case typeDateTimeOffsetN: ti.Size += 5 } - ti.Reader = readByteLenType + ti.Reader = readByteLenTypeWithoutGuidConversion ti.Buffer = make([]byte, ti.Size) case typeGuid, typeIntN, typeDecimal, typeNumeric, typeBitN, typeDecimalN, typeNumericN, typeFltN, @@ -776,7 +826,11 @@ func readVarLen(ti *typeInfo, r *tdsBuffer, c *cryptoMetadata) { ti.Prec = r.byte() ti.Scale = r.byte() } - ti.Reader = readByteLenType + if encoding.GuidConversion { + ti.Reader = readByteLenTypeWithGuidConversion + } else { + ti.Reader = readByteLenTypeWithoutGuidConversion + } case typeXml: ti.XmlInfo.SchemaPresent = r.byte() if ti.XmlInfo.SchemaPresent != 0 { @@ -831,7 +885,11 @@ func readVarLen(ti *typeInfo, r *tdsBuffer, c *cryptoMetadata) { } ti.Reader = readLongLenType case typeVariant: - ti.Reader = readVariantType + if encoding.GuidConversion { + ti.Reader = readVariantTypeWithGuidConversion + } else { + ti.Reader = readVariantTypeWithoutGuidConversion + } } default: badStreamPanicf("Invalid type %d", ti.TypeId) @@ -855,9 +913,14 @@ func decodeMoney4(buf []byte) []byte { return decimal.ScaleBytes(strconv.FormatInt(int64(money), 10), 4) } -func decodeGuid(buf []byte) []byte { +func decodeGuid(buf []byte, encoding msdsn.EncodeParameters) []byte { res := make([]byte, 16) copy(res, buf) + if encoding.GuidConversion { + binary.LittleEndian.PutUint32(res[0:4], binary.BigEndian.Uint32(res[0:4])) + binary.LittleEndian.PutUint16(res[4:6], binary.BigEndian.Uint16(res[4:6])) + binary.LittleEndian.PutUint16(res[6:8], binary.BigEndian.Uint16(res[6:8])) + } return res }