From 96ed1e543e4cf92e060a2667668335192180816c Mon Sep 17 00:00:00 2001 From: XuShuo Date: Thu, 28 Sep 2023 19:06:55 +0800 Subject: [PATCH] [Feature] new feature abuut MySQL (datetime, timestamp) parse to time.Time --- cursor.go | 45 ++++++++++++++++++++++++++++++++++++++---- cursor_test.go | 29 +++++++++++++++++++++++---- database_test.go | 2 +- generator/generator.go | 45 ++++++++++++++++++++++++++++++++++-------- 4 files changed, 104 insertions(+), 17 deletions(-) diff --git a/cursor.go b/cursor.go index 059c264..b94f872 100644 --- a/cursor.go +++ b/cursor.go @@ -5,6 +5,7 @@ import ( "fmt" "reflect" "strconv" + "time" ) // Cursor is the interface of a row cursor. @@ -35,6 +36,10 @@ func preparePointers(val reflect.Value, scans *[]interface{}) error { *scans = append(*scans, addr.Interface()) } case reflect.Struct: + if val.Type() == reflect.TypeOf(time.Time{}) { + *scans = append(*scans, val.Addr().Interface()) + return nil + } for j := 0; j < val.NumField(); j++ { field := val.Field(j) if field.Kind() == reflect.Interface { @@ -106,16 +111,27 @@ func (c cursor) Scan(dest ...interface{}) error { pbs := make(map[int]*bool) ppbs := make(map[int]**bool) + pts := make(map[int]*time.Time) + ppts := make(map[int]**time.Time) for i, scan := range scans { - if pb, ok := scan.(*bool); ok { + switch scan.(type) { + case *bool: var s []uint8 scans[i] = &s - pbs[i] = pb - } else if ppb, ok := scan.(**bool); ok { + pbs[i] = scan.(*bool) + case **bool: var s *[]uint8 scans[i] = &s - ppbs[i] = ppb + ppbs[i] = scan.(**bool) + case *time.Time: + var s string + scans[i] = &s + pts[i] = scan.(*time.Time) + case **time.Time: + var s *string + scans[i] = &s + ppts[i] = scan.(**time.Time) } } @@ -145,6 +161,27 @@ func (c cursor) Scan(dest ...interface{}) error { *ppb = &b } } + for i, pt := range pts { + if *(scans[i].(*string)) == "" { + return fmt.Errorf("field %d is null", i) + } + t, err := time.Parse("2006-01-02 15:04:05", *(scans[i].(*string))) + if err != nil { + return err + } + *pt = t + } + for i, ppt := range ppts { + if *(scans[i].(**string)) == nil { + *ppt = nil + } else { + t, err := time.Parse("2006-01-02 15:04:05", **(scans[i].(**string))) + if err != nil { + return err + } + *ppt = &t + } + } return err } diff --git a/cursor_test.go b/cursor_test.go index c5a0149..d853ab8 100644 --- a/cursor_test.go +++ b/cursor_test.go @@ -5,6 +5,7 @@ import ( "io" "strconv" "testing" + "time" ) type mockDriver struct{} @@ -30,7 +31,7 @@ type mockRows struct { } func (m mockRows) Columns() []string { - return []string{"a", "b", "c", "d", "e", "f", "g"}[:m.columnCount] + return []string{"a", "b", "c", "d", "e", "f", "g", "h", "j", "k", "l"}[:m.columnCount] } func (m mockRows) Close() error { @@ -58,6 +59,14 @@ func (m *mockRows) Next(dest []driver.Value) error { dest[i] = dest[0] case 6: dest[i] = nil + case 7: + dest[i] = "2023-09-06 18:37:46.828" + case 8: + dest[i] = "2023-09-06 18:37:46.828" + case 9: + dest[i] = "2023-09-06 18:37:46" + case 10: + dest[i] = "2023-09-06 18:37:46" } } return nil @@ -98,12 +107,20 @@ func TestCursor(t *testing.T) { var f ****int // deep pointer var g *int // always null + var h *time.Time + var j time.Time + var k *time.Time + var l time.Time + tmh, _ := time.Parse("2006-01-02 15:04:05.000", "2023-09-06 18:37:46.828") + tmj, _ := time.Parse("2006-01-02 15:04:05.000", "2023-09-06 18:37:46.828") + tmk, _ := time.Parse("2006-01-02 15:04:05", "2023-09-06 18:37:46") + tml, _ := time.Parse("2006-01-02 15:04:05", "2023-09-06 18:37:46") for i := 1; i <= 10; i++ { if !cursor.Next() { t.Error() } g = &i - if err := cursor.Scan(&a, &b, &cde, &f, &g); err != nil { + if err := cursor.Scan(&a, &b, &cde, &f, &g, &h, &j, &k, &l); err != nil { t.Errorf("%v", err) } if a != i || @@ -112,7 +129,11 @@ func TestCursor(t *testing.T) { cde.DE.D != (i%2 == 1) || cde.DE.E != cde.DE.D || ****f != i || - g != nil { + g != nil || + *h != tmh || + j != tmj || + *k != tmk || + l != tml { t.Error(a, b, cde.C, cde.DE.D, cde.DE.E, ****f, g) } if err := cursor.Scan(); err != nil { @@ -123,7 +144,7 @@ func TestCursor(t *testing.T) { var b ****bool var p *string var bs []byte - if err := cursor.Scan(&s, &s, &s, &b, &s, &bs, &p); err != nil { + if err := cursor.Scan(&s, &s, &s, &b, &s, &bs, &p, &h, &j, &k, &l); err != nil { t.Error(err) } if ****b != (i%2 == 1) || diff --git a/database_test.go b/database_test.go index 37bf183..bf4364a 100644 --- a/database_test.go +++ b/database_test.go @@ -32,7 +32,7 @@ func (m *mockConn) Begin() (driver.Tx, error) { } var sharedMockConn = &mockConn{ - columnCount: 7, + columnCount: 11, rowCount: 10, } diff --git a/generator/generator.go b/generator/generator.go index ed7b748..e8cb013 100644 --- a/generator/generator.go +++ b/generator/generator.go @@ -88,7 +88,7 @@ func getType(fieldDescriptor fieldDescriptor) (goType string, fieldClass string, case "float", "double", "decimal", "real": goType = "float64" fieldClass = "NumberField" - case "char", "varchar", "text", "tinytext", "mediumtext", "longtext", "enum", "datetime", "date", "time", "timestamp", "json", "numeric", "character varying", "timestamp without time zone", "timestamp with time zone", "jsonb", "uuid": + case "char", "varchar", "text", "tinytext", "mediumtext", "longtext", "enum", "date", "time", "json", "numeric", "character varying", "timestamp without time zone", "timestamp with time zone", "jsonb", "uuid": goType = "string" fieldClass = "StringField" case "year": @@ -103,6 +103,10 @@ func getType(fieldDescriptor fieldDescriptor) (goType string, fieldClass string, // TODO: Switch to specific type instead of interface. goType = "[]interface{}" fieldClass = "ArrayField" + case "datetime", "timestamp": + // todo import + goType = "time.Time" + fieldClass = "DateField" case "geometry", "point", "linestring", "polygon", "multipoint", "multilinestring", "multipolygon", "geometrycollection": goType = "sqlingo.WellKnownBinary" fieldClass = "WellKnownBinaryField" @@ -174,10 +178,38 @@ func Generate(driverName string, exampleDataSourceName string) (string, error) { return "", errors.New("no database selected") } + if len(options.tableNames) == 0 { + options.tableNames, err = schemaFetcher.GetTableNames() + if err != nil { + return "", err + } + } + + needImportTime := false + for _, tableName := range options.tableNames { + fieldDescriptors, err := schemaFetcher.GetFieldDescriptors(tableName) + if err != nil { + return "", err + } + for _, fieldDescriptor := range fieldDescriptors { + if fieldDescriptor.Type == "datetime" || fieldDescriptor.Type == "timestamp" { + needImportTime = true + break + } + } + } + code := "// This file is generated by sqlingo (https://github.com/lqs/sqlingo)\n" code += "// DO NOT EDIT.\n\n" code += "package " + ensureIdentifier(dbName) + "_dsl\n" - code += "import \"github.com/lqs/sqlingo\"\n\n" + if needImportTime { + code += "import (\n" + code += "\t\"time\"\n" + code += "\t\"github.com/lqs/sqlingo\"\n" + code += ")\n\n" + } else { + code += "import \"github.com/lqs/sqlingo\"\n\n" + } code += "type sqlingoRuntimeAndGeneratorVersionsShouldBeTheSame uint32\n\n" @@ -205,12 +237,9 @@ func Generate(driverName string, exampleDataSourceName string) (string, error) { code += "\tsqlingo.ArrayField\n" code += "}\n\n" - if len(options.tableNames) == 0 { - options.tableNames, err = schemaFetcher.GetTableNames() - if err != nil { - return "", err - } - } + code += "type dateField interface {\n" + code += "\tsqlingo.DateField\n" + code += "}\n\n" var wg sync.WaitGroup