Skip to content

Add _foreign_keys connection parameter #407

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 2, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 43 additions & 3 deletions sqlite3.go
Original file line number Diff line number Diff line change
Expand Up @@ -400,14 +400,18 @@ func (c *SQLiteConn) AutoCommit() bool {
}

func (c *SQLiteConn) lastError() error {
rv := C.sqlite3_errcode(c.db)
return lastError(c.db)
}

func lastError(db *C.sqlite3) error {
rv := C.sqlite3_errcode(db)
if rv == C.SQLITE_OK {
return nil
}
return Error{
Code: ErrNo(rv),
ExtendedCode: ErrNoExtended(C.sqlite3_extended_errcode(c.db)),
err: C.GoString(C.sqlite3_errmsg(c.db)),
ExtendedCode: ErrNoExtended(C.sqlite3_extended_errcode(db)),
err: C.GoString(C.sqlite3_errmsg(db)),
}
}

Expand Down Expand Up @@ -537,6 +541,8 @@ func errorString(err Error) string {
// _txlock=XXX
// Specify locking behavior for transactions. XXX can be "immediate",
// "deferred", "exclusive".
// _foreign_keys=X
// Enable or disable enforcement of foreign keys. X can be 1 or 0.
func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
if C.sqlite3_threadsafe() == 0 {
return nil, errors.New("sqlite library was not compiled for thread-safe operation")
Expand All @@ -545,6 +551,7 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
var loc *time.Location
txlock := "BEGIN"
busyTimeout := 5000
foreignKeys := -1
pos := strings.IndexRune(dsn, '?')
if pos >= 1 {
params, err := url.ParseQuery(dsn[pos+1:])
Expand Down Expand Up @@ -587,6 +594,18 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
}
}

// _foreign_keys
if val := params.Get("_foreign_keys"); val != "" {
switch val {
case "1":
foreignKeys = 1
case "0":
foreignKeys = 0
default:
return nil, fmt.Errorf("Invalid _foreign_keys: %v", val)
}
}

if !strings.HasPrefix(dsn, "file:") {
dsn = dsn[:pos]
}
Expand All @@ -612,6 +631,27 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
return nil, Error{Code: ErrNo(rv)}
}

exec := func(s string) error {
cs := C.CString(s)
rv := C.sqlite3_exec(db, cs, nil, nil, nil)
C.free(unsafe.Pointer(cs))
if rv != C.SQLITE_OK {
return lastError(db)
}
return nil
}
if foreignKeys == 0 {
if err := exec("PRAGMA foreign_keys = OFF;"); err != nil {
C.sqlite3_close_v2(db)
return nil, err
}
} else if foreignKeys == 1 {
if err := exec("PRAGMA foreign_keys = ON;"); err != nil {
C.sqlite3_close_v2(db)
return nil, err
}
}

conn := &SQLiteConn{db: db, loc: loc, txlock: txlock}

if len(d.Extensions) > 0 {
Expand Down
29 changes: 29 additions & 0 deletions sqlite3_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,35 @@ func TestReadonly(t *testing.T) {
}
}

func TestForeignKeys(t *testing.T) {
cases := map[string]bool{
"?_foreign_keys=1": true,
"?_foreign_keys=0": false,
}
for option, want := range cases {
fname := TempFilename(t)
uri := "file:" + fname + option
db, err := sql.Open("sqlite3", uri)
if err != nil {
os.Remove(fname)
t.Errorf("sql.Open(\"sqlite3\", %q): %v", uri, err)
continue
}
var enabled bool
err = db.QueryRow("PRAGMA foreign_keys;").Scan(&enabled)
db.Close()
os.Remove(fname)
if err != nil {
t.Errorf("query foreign_keys for %s: %v", uri, err)
continue
}
if enabled != want {
t.Errorf("\"PRAGMA foreign_keys;\" for %q = %t; want %t", uri, enabled, want)
continue
}
}
}

func TestClose(t *testing.T) {
tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
Expand Down