diff --git a/sqlite3.go b/sqlite3.go index cbde900c..161eb220 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -400,14 +400,18 @@ func (c *SQLiteConn) AutoCommit() bool { } func (c *SQLiteConn) lastError() error { - rv := C.sqlite3_errcode(c.db) + return lastError(c.db) +} + +func lastError(db *C.sqlite3) error { + rv := C.sqlite3_errcode(db) if rv == C.SQLITE_OK { return nil } return Error{ Code: ErrNo(rv), - ExtendedCode: ErrNoExtended(C.sqlite3_extended_errcode(c.db)), - err: C.GoString(C.sqlite3_errmsg(c.db)), + ExtendedCode: ErrNoExtended(C.sqlite3_extended_errcode(db)), + err: C.GoString(C.sqlite3_errmsg(db)), } } @@ -537,6 +541,8 @@ func errorString(err Error) string { // _txlock=XXX // Specify locking behavior for transactions. XXX can be "immediate", // "deferred", "exclusive". +// _foreign_keys=X +// Enable or disable enforcement of foreign keys. X can be 1 or 0. func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) { if C.sqlite3_threadsafe() == 0 { return nil, errors.New("sqlite library was not compiled for thread-safe operation") @@ -545,6 +551,7 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) { var loc *time.Location txlock := "BEGIN" busyTimeout := 5000 + foreignKeys := -1 pos := strings.IndexRune(dsn, '?') if pos >= 1 { params, err := url.ParseQuery(dsn[pos+1:]) @@ -587,6 +594,18 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) { } } + // _foreign_keys + if val := params.Get("_foreign_keys"); val != "" { + switch val { + case "1": + foreignKeys = 1 + case "0": + foreignKeys = 0 + default: + return nil, fmt.Errorf("Invalid _foreign_keys: %v", val) + } + } + if !strings.HasPrefix(dsn, "file:") { dsn = dsn[:pos] } @@ -612,6 +631,27 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) { return nil, Error{Code: ErrNo(rv)} } + exec := func(s string) error { + cs := C.CString(s) + rv := C.sqlite3_exec(db, cs, nil, nil, nil) + C.free(unsafe.Pointer(cs)) + if rv != C.SQLITE_OK { + return lastError(db) + } + return nil + } + if foreignKeys == 0 { + if err := exec("PRAGMA foreign_keys = OFF;"); err != nil { + C.sqlite3_close_v2(db) + return nil, err + } + } else if foreignKeys == 1 { + if err := exec("PRAGMA foreign_keys = ON;"); err != nil { + C.sqlite3_close_v2(db) + return nil, err + } + } + conn := &SQLiteConn{db: db, loc: loc, txlock: txlock} if len(d.Extensions) > 0 { diff --git a/sqlite3_test.go b/sqlite3_test.go index e844f82e..03b678d3 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -107,6 +107,35 @@ func TestReadonly(t *testing.T) { } } +func TestForeignKeys(t *testing.T) { + cases := map[string]bool{ + "?_foreign_keys=1": true, + "?_foreign_keys=0": false, + } + for option, want := range cases { + fname := TempFilename(t) + uri := "file:" + fname + option + db, err := sql.Open("sqlite3", uri) + if err != nil { + os.Remove(fname) + t.Errorf("sql.Open(\"sqlite3\", %q): %v", uri, err) + continue + } + var enabled bool + err = db.QueryRow("PRAGMA foreign_keys;").Scan(&enabled) + db.Close() + os.Remove(fname) + if err != nil { + t.Errorf("query foreign_keys for %s: %v", uri, err) + continue + } + if enabled != want { + t.Errorf("\"PRAGMA foreign_keys;\" for %q = %t; want %t", uri, enabled, want) + continue + } + } +} + func TestClose(t *testing.T) { tempFilename := TempFilename(t) defer os.Remove(tempFilename)