Skip to content

Commit

Permalink
[Feature] new feature abuut MySQL (datetime, timestamp) parse to time…
Browse files Browse the repository at this point in the history
….Time
  • Loading branch information
VarusHsu committed Sep 28, 2023
1 parent 12fc7d2 commit 96ed1e5
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 17 deletions.
45 changes: 41 additions & 4 deletions cursor.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"reflect"
"strconv"
"time"
)

// Cursor is the interface of a row cursor.
Expand Down Expand Up @@ -35,6 +36,10 @@ func preparePointers(val reflect.Value, scans *[]interface{}) error {
*scans = append(*scans, addr.Interface())
}
case reflect.Struct:
if val.Type() == reflect.TypeOf(time.Time{}) {
*scans = append(*scans, val.Addr().Interface())
return nil
}
for j := 0; j < val.NumField(); j++ {
field := val.Field(j)
if field.Kind() == reflect.Interface {
Expand Down Expand Up @@ -106,16 +111,27 @@ func (c cursor) Scan(dest ...interface{}) error {

pbs := make(map[int]*bool)
ppbs := make(map[int]**bool)
pts := make(map[int]*time.Time)
ppts := make(map[int]**time.Time)

for i, scan := range scans {
if pb, ok := scan.(*bool); ok {
switch scan.(type) {
case *bool:
var s []uint8
scans[i] = &s
pbs[i] = pb
} else if ppb, ok := scan.(**bool); ok {
pbs[i] = scan.(*bool)
case **bool:
var s *[]uint8
scans[i] = &s
ppbs[i] = ppb
ppbs[i] = scan.(**bool)
case *time.Time:
var s string
scans[i] = &s
pts[i] = scan.(*time.Time)
case **time.Time:
var s *string
scans[i] = &s
ppts[i] = scan.(**time.Time)
}
}

Expand Down Expand Up @@ -145,6 +161,27 @@ func (c cursor) Scan(dest ...interface{}) error {
*ppb = &b
}
}
for i, pt := range pts {
if *(scans[i].(*string)) == "" {
return fmt.Errorf("field %d is null", i)
}
t, err := time.Parse("2006-01-02 15:04:05", *(scans[i].(*string)))
if err != nil {
return err
}
*pt = t
}
for i, ppt := range ppts {
if *(scans[i].(**string)) == nil {
*ppt = nil
} else {
t, err := time.Parse("2006-01-02 15:04:05", **(scans[i].(**string)))
if err != nil {
return err
}
*ppt = &t
}
}

return err
}
Expand Down
29 changes: 25 additions & 4 deletions cursor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"io"
"strconv"
"testing"
"time"
)

type mockDriver struct{}
Expand All @@ -30,7 +31,7 @@ type mockRows struct {
}

func (m mockRows) Columns() []string {
return []string{"a", "b", "c", "d", "e", "f", "g"}[:m.columnCount]
return []string{"a", "b", "c", "d", "e", "f", "g", "h", "j", "k", "l"}[:m.columnCount]
}

func (m mockRows) Close() error {
Expand Down Expand Up @@ -58,6 +59,14 @@ func (m *mockRows) Next(dest []driver.Value) error {
dest[i] = dest[0]
case 6:
dest[i] = nil
case 7:
dest[i] = "2023-09-06 18:37:46.828"
case 8:
dest[i] = "2023-09-06 18:37:46.828"
case 9:
dest[i] = "2023-09-06 18:37:46"
case 10:
dest[i] = "2023-09-06 18:37:46"
}
}
return nil
Expand Down Expand Up @@ -98,12 +107,20 @@ func TestCursor(t *testing.T) {
var f ****int // deep pointer
var g *int // always null

var h *time.Time
var j time.Time
var k *time.Time
var l time.Time
tmh, _ := time.Parse("2006-01-02 15:04:05.000", "2023-09-06 18:37:46.828")
tmj, _ := time.Parse("2006-01-02 15:04:05.000", "2023-09-06 18:37:46.828")
tmk, _ := time.Parse("2006-01-02 15:04:05", "2023-09-06 18:37:46")
tml, _ := time.Parse("2006-01-02 15:04:05", "2023-09-06 18:37:46")
for i := 1; i <= 10; i++ {
if !cursor.Next() {
t.Error()
}
g = &i
if err := cursor.Scan(&a, &b, &cde, &f, &g); err != nil {
if err := cursor.Scan(&a, &b, &cde, &f, &g, &h, &j, &k, &l); err != nil {
t.Errorf("%v", err)
}
if a != i ||
Expand All @@ -112,7 +129,11 @@ func TestCursor(t *testing.T) {
cde.DE.D != (i%2 == 1) ||
cde.DE.E != cde.DE.D ||
****f != i ||
g != nil {
g != nil ||
*h != tmh ||
j != tmj ||
*k != tmk ||
l != tml {
t.Error(a, b, cde.C, cde.DE.D, cde.DE.E, ****f, g)
}
if err := cursor.Scan(); err != nil {
Expand All @@ -123,7 +144,7 @@ func TestCursor(t *testing.T) {
var b ****bool
var p *string
var bs []byte
if err := cursor.Scan(&s, &s, &s, &b, &s, &bs, &p); err != nil {
if err := cursor.Scan(&s, &s, &s, &b, &s, &bs, &p, &h, &j, &k, &l); err != nil {
t.Error(err)
}
if ****b != (i%2 == 1) ||
Expand Down
2 changes: 1 addition & 1 deletion database_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func (m *mockConn) Begin() (driver.Tx, error) {
}

var sharedMockConn = &mockConn{
columnCount: 7,
columnCount: 11,
rowCount: 10,
}

Expand Down
45 changes: 37 additions & 8 deletions generator/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ func getType(fieldDescriptor fieldDescriptor) (goType string, fieldClass string,
case "float", "double", "decimal", "real":
goType = "float64"
fieldClass = "NumberField"
case "char", "varchar", "text", "tinytext", "mediumtext", "longtext", "enum", "datetime", "date", "time", "timestamp", "json", "numeric", "character varying", "timestamp without time zone", "timestamp with time zone", "jsonb", "uuid":
case "char", "varchar", "text", "tinytext", "mediumtext", "longtext", "enum", "date", "time", "json", "numeric", "character varying", "timestamp without time zone", "timestamp with time zone", "jsonb", "uuid":
goType = "string"
fieldClass = "StringField"
case "year":
Expand All @@ -103,6 +103,10 @@ func getType(fieldDescriptor fieldDescriptor) (goType string, fieldClass string,
// TODO: Switch to specific type instead of interface.
goType = "[]interface{}"
fieldClass = "ArrayField"
case "datetime", "timestamp":
// todo import
goType = "time.Time"
fieldClass = "DateField"
case "geometry", "point", "linestring", "polygon", "multipoint", "multilinestring", "multipolygon", "geometrycollection":
goType = "sqlingo.WellKnownBinary"
fieldClass = "WellKnownBinaryField"
Expand Down Expand Up @@ -174,10 +178,38 @@ func Generate(driverName string, exampleDataSourceName string) (string, error) {
return "", errors.New("no database selected")
}

if len(options.tableNames) == 0 {
options.tableNames, err = schemaFetcher.GetTableNames()
if err != nil {
return "", err
}
}

needImportTime := false
for _, tableName := range options.tableNames {
fieldDescriptors, err := schemaFetcher.GetFieldDescriptors(tableName)
if err != nil {
return "", err
}
for _, fieldDescriptor := range fieldDescriptors {
if fieldDescriptor.Type == "datetime" || fieldDescriptor.Type == "timestamp" {
needImportTime = true
break
}
}
}

code := "// This file is generated by sqlingo (https://github.com/lqs/sqlingo)\n"
code += "// DO NOT EDIT.\n\n"
code += "package " + ensureIdentifier(dbName) + "_dsl\n"
code += "import \"github.com/lqs/sqlingo\"\n\n"
if needImportTime {
code += "import (\n"
code += "\t\"time\"\n"
code += "\t\"github.com/lqs/sqlingo\"\n"
code += ")\n\n"
} else {
code += "import \"github.com/lqs/sqlingo\"\n\n"
}

code += "type sqlingoRuntimeAndGeneratorVersionsShouldBeTheSame uint32\n\n"

Expand Down Expand Up @@ -205,12 +237,9 @@ func Generate(driverName string, exampleDataSourceName string) (string, error) {
code += "\tsqlingo.ArrayField\n"
code += "}\n\n"

if len(options.tableNames) == 0 {
options.tableNames, err = schemaFetcher.GetTableNames()
if err != nil {
return "", err
}
}
code += "type dateField interface {\n"
code += "\tsqlingo.DateField\n"
code += "}\n\n"

var wg sync.WaitGroup

Expand Down

0 comments on commit 96ed1e5

Please sign in to comment.