diff --git a/sqlite3.go b/sqlite3.go index e6b8c166..281bd485 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -2182,26 +2182,90 @@ func (s *SQLiteStmt) NumInput() int { var placeHolder = []byte{0} +func hasNamedArgs(args []driver.NamedValue) bool { + for _, v := range args { + if v.Name != "" { + return true + } + } + return false +} + func (s *SQLiteStmt) bind(args []driver.NamedValue) error { rv := C.sqlite3_reset(s.s) if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE { return s.c.lastError() } + if hasNamedArgs(args) { + return s.bindIndices(args) + } + + for _, arg := range args { + n := C.int(arg.Ordinal) + switch v := arg.Value.(type) { + case nil: + rv = C.sqlite3_bind_null(s.s, n) + case string: + p := stringData(v) + rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(p)), C.int(len(v))) + case int64: + rv = C.sqlite3_bind_int64(s.s, n, C.sqlite3_int64(v)) + case bool: + val := 0 + if v { + val = 1 + } + rv = C.sqlite3_bind_int(s.s, n, C.int(val)) + case float64: + rv = C.sqlite3_bind_double(s.s, n, C.double(v)) + case []byte: + if v == nil { + rv = C.sqlite3_bind_null(s.s, n) + } else { + ln := len(v) + if ln == 0 { + v = placeHolder + } + rv = C._sqlite3_bind_blob(s.s, n, unsafe.Pointer(&v[0]), C.int(ln)) + } + case time.Time: + ts := v.Format(SQLiteTimestampFormats[0]) + p := stringData(ts) + rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(p)), C.int(len(ts))) + } + if rv != C.SQLITE_OK { + return s.c.lastError() + } + } + return nil +} + +func (s *SQLiteStmt) bindIndices(args []driver.NamedValue) error { + // Find the longest named parameter name. + n := 0 + for _, v := range args { + if m := len(v.Name); m > n { + n = m + } + } + buf := make([]byte, 0, n+2) // +2 for placeholder and null terminator + bindIndices := make([][3]int, len(args)) - prefixes := []string{":", "@", "$"} for i, v := range args { bindIndices[i][0] = args[i].Ordinal if v.Name != "" { - for j := range prefixes { - cname := C.CString(prefixes[j] + v.Name) - bindIndices[i][j] = int(C.sqlite3_bind_parameter_index(s.s, cname)) - C.free(unsafe.Pointer(cname)) + for j, c := range []byte{':', '@', '$'} { + buf = append(buf[:0], c) + buf = append(buf, v.Name...) + buf = append(buf, 0) + bindIndices[i][j] = int(C.sqlite3_bind_parameter_index(s.s, (*C.char)(unsafe.Pointer(&buf[0])))) } args[i].Ordinal = bindIndices[i][0] } } + var rv C.int for i, arg := range args { for j := range bindIndices[i] { if bindIndices[i][j] == 0 { @@ -2212,20 +2276,16 @@ func (s *SQLiteStmt) bind(args []driver.NamedValue) error { case nil: rv = C.sqlite3_bind_null(s.s, n) case string: - if len(v) == 0 { - rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&placeHolder[0])), C.int(0)) - } else { - b := []byte(v) - rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b))) - } + p := stringData(v) + rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(p)), C.int(len(v))) case int64: rv = C.sqlite3_bind_int64(s.s, n, C.sqlite3_int64(v)) case bool: + val := 0 if v { - rv = C.sqlite3_bind_int(s.s, n, 1) - } else { - rv = C.sqlite3_bind_int(s.s, n, 0) + val = 1 } + rv = C.sqlite3_bind_int(s.s, n, C.int(val)) case float64: rv = C.sqlite3_bind_double(s.s, n, C.double(v)) case []byte: @@ -2239,8 +2299,9 @@ func (s *SQLiteStmt) bind(args []driver.NamedValue) error { rv = C._sqlite3_bind_blob(s.s, n, unsafe.Pointer(&v[0]), C.int(ln)) } case time.Time: - b := []byte(v.Format(SQLiteTimestampFormats[0])) - rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b))) + ts := v.Format(SQLiteTimestampFormats[0]) + p := stringData(ts) + rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(p)), C.int(len(ts))) } if rv != C.SQLITE_OK { return s.c.lastError()