Skip to content

Commit 844d7fd

Browse files
committed
Fix context cancellation racy handling
[why] Context cancellation goroutine is not in sync with Next() method lifetime. It leads to sql.ErrNoRows instead of context.Canceled often (easy to reproduce). It leads to interruption of next query executed on same connection (harder to reproduce). [how] Do query in goroutine, wait when interruption done. [testing] Add unit test that reproduces error cases.
1 parent d3c6909 commit 844d7fd

File tree

2 files changed

+136
-36
lines changed

2 files changed

+136
-36
lines changed

sqlite3.go

Lines changed: 53 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,7 @@ type SQLiteRows struct {
328328
decltype []string
329329
cls bool
330330
closed bool
331-
done chan struct{}
331+
ctx context.Context // no better alternative to pass context into Next() method
332332
}
333333

334334
type functionInfo struct {
@@ -1846,22 +1846,7 @@ func (s *SQLiteStmt) query(ctx context.Context, args []namedValue) (driver.Rows,
18461846
decltype: nil,
18471847
cls: s.cls,
18481848
closed: false,
1849-
done: make(chan struct{}),
1850-
}
1851-
1852-
if ctxdone := ctx.Done(); ctxdone != nil {
1853-
go func(db *C.sqlite3) {
1854-
select {
1855-
case <-ctxdone:
1856-
select {
1857-
case <-rows.done:
1858-
default:
1859-
C.sqlite3_interrupt(db)
1860-
rows.Close()
1861-
}
1862-
case <-rows.done:
1863-
}
1864-
}(s.c.db)
1849+
ctx: ctx,
18651850
}
18661851

18671852
return rows, nil
@@ -1890,28 +1875,40 @@ func (s *SQLiteStmt) Exec(args []driver.Value) (driver.Result, error) {
18901875
}
18911876

18921877
func (s *SQLiteStmt) exec(ctx context.Context, args []namedValue) (driver.Result, error) {
1878+
if ctx.Done() == nil {
1879+
return s.execSync(args)
1880+
}
1881+
1882+
type result struct {
1883+
r driver.Result
1884+
err error
1885+
}
1886+
resultCh := make(chan result)
1887+
go func() {
1888+
r, err := s.execSync(args)
1889+
resultCh <- result{r, err}
1890+
}()
1891+
select {
1892+
case rv := <- resultCh:
1893+
return rv.r, rv.err
1894+
case <-ctx.Done():
1895+
select {
1896+
case <-resultCh: // no need to interrupt
1897+
default:
1898+
C.sqlite3_interrupt(s.c.db)
1899+
<-resultCh // ensure goroutine completed
1900+
}
1901+
return nil, ctx.Err()
1902+
}
1903+
}
1904+
1905+
func (s *SQLiteStmt) execSync(args []namedValue) (driver.Result, error) {
18931906
if err := s.bind(args); err != nil {
18941907
C.sqlite3_reset(s.s)
18951908
C.sqlite3_clear_bindings(s.s)
18961909
return nil, err
18971910
}
18981911

1899-
if ctxdone := ctx.Done(); ctxdone != nil {
1900-
done := make(chan struct{})
1901-
defer close(done)
1902-
go func(db *C.sqlite3) {
1903-
select {
1904-
case <-done:
1905-
case <-ctxdone:
1906-
select {
1907-
case <-done:
1908-
default:
1909-
C.sqlite3_interrupt(db)
1910-
}
1911-
}
1912-
}(s.c.db)
1913-
}
1914-
19151912
var rowid, changes C.longlong
19161913
rv := C._sqlite3_step_row_internal(s.s, &rowid, &changes)
19171914
if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE {
@@ -1932,9 +1929,6 @@ func (rc *SQLiteRows) Close() error {
19321929
return nil
19331930
}
19341931
rc.closed = true
1935-
if rc.done != nil {
1936-
close(rc.done)
1937-
}
19381932
if rc.cls {
19391933
rc.s.mu.Unlock()
19401934
return rc.s.Close()
@@ -1980,8 +1974,31 @@ func (rc *SQLiteRows) DeclTypes() []string {
19801974

19811975
// Next move cursor to next.
19821976
func (rc *SQLiteRows) Next(dest []driver.Value) error {
1977+
if rc.ctx.Done() == nil {
1978+
return rc.nextSync(dest)
1979+
}
1980+
resultCh := make(chan error)
1981+
go func() {
1982+
resultCh <- rc.nextSync(dest)
1983+
}()
1984+
select {
1985+
case err := <- resultCh:
1986+
return err
1987+
case <-rc.ctx.Done():
1988+
select {
1989+
case <-resultCh: // no need to interrupt
1990+
default:
1991+
C.sqlite3_interrupt(rc.s.c.db)
1992+
<-resultCh // ensure goroutine completed
1993+
}
1994+
return rc.ctx.Err()
1995+
}
1996+
}
1997+
1998+
func (rc *SQLiteRows) nextSync(dest []driver.Value) error {
19831999
rc.s.mu.Lock()
19842000
defer rc.s.mu.Unlock()
2001+
19852002
if rc.s.closed {
19862003
return io.EOF
19872004
}

sqlite3_go18_test.go

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
"io/ioutil"
1515
"math/rand"
1616
"os"
17+
"sync"
1718
"testing"
1819
"time"
1920
)
@@ -135,6 +136,88 @@ func TestShortTimeout(t *testing.T) {
135136
}
136137
}
137138

139+
func TestQueryRowContextCancel(t *testing.T) {
140+
srcTempFilename := TempFilename(t)
141+
defer os.Remove(srcTempFilename)
142+
143+
db, err := sql.Open("sqlite3", srcTempFilename)
144+
if err != nil {
145+
t.Fatal(err)
146+
}
147+
defer db.Close()
148+
initDatabase(t, db, 100)
149+
150+
const query = `SELECT key_id FROM test_table ORDER BY key2 ASC`
151+
var keyID string
152+
unexpectedErrors := make(map[string]int)
153+
for i := 0; i < 10000; i++ {
154+
ctx, cancel := context.WithCancel(context.Background())
155+
row := db.QueryRowContext(ctx, query)
156+
157+
cancel()
158+
// it is fine to get "nil" as context cancellation can be handled with delay
159+
if err := row.Scan(&keyID); err != nil && err != context.Canceled {
160+
unexpectedErrors[err.Error()]++
161+
}
162+
}
163+
for errText, count := range unexpectedErrors {
164+
t.Error(errText, count)
165+
}
166+
}
167+
168+
func TestQueryRowContextCancelParallel(t *testing.T) {
169+
srcTempFilename := TempFilename(t)
170+
defer os.Remove(srcTempFilename)
171+
172+
db, err := sql.Open("sqlite3", srcTempFilename)
173+
if err != nil {
174+
t.Fatal(err)
175+
}
176+
db.SetMaxOpenConns(10)
177+
db.SetMaxIdleConns(5)
178+
179+
defer db.Close()
180+
initDatabase(t, db, 100)
181+
182+
const query = `SELECT key_id FROM test_table ORDER BY key2 ASC`
183+
wg := sync.WaitGroup{}
184+
defer wg.Wait()
185+
186+
testCtx, cancel := context.WithCancel(context.Background())
187+
defer cancel()
188+
189+
for i := 0; i < 50; i++ {
190+
wg.Add(1)
191+
go func() {
192+
defer wg.Done()
193+
194+
var keyID string
195+
for {
196+
select {
197+
case <-testCtx.Done():
198+
return
199+
default:
200+
}
201+
ctx, cancel := context.WithCancel(context.Background())
202+
row := db.QueryRowContext(ctx, query)
203+
204+
cancel()
205+
_ = row.Scan(&keyID) // see TestQueryRowContextCancel
206+
}
207+
}()
208+
}
209+
210+
var keyID string
211+
for i := 0; i < 10000; i++ {
212+
// NOTE: testCtx is not cancelled during query execution
213+
row := db.QueryRowContext(testCtx, query)
214+
215+
if err := row.Scan(&keyID); err != nil {
216+
t.Fatal(i, err)
217+
}
218+
}
219+
}
220+
138221
func TestExecCancel(t *testing.T) {
139222
db, err := sql.Open("sqlite3", ":memory:")
140223
if err != nil {

0 commit comments

Comments
 (0)