From 57fe2b77e40e416d79a7135af8b0c35d06a646c5 Mon Sep 17 00:00:00 2001 From: Phil Eaton Date: Fri, 13 May 2022 19:49:23 +0000 Subject: [PATCH 1/4] Support returning any from callbacks Add test Fix test --- callback.go | 14 ++++++++++++++ callback_test.go | 12 ++++++++++++ 2 files changed, 26 insertions(+) diff --git a/callback.go b/callback.go index b020fe37..808160d4 100644 --- a/callback.go +++ b/callback.go @@ -353,6 +353,15 @@ func callbackRetNil(ctx *C.sqlite3_context, v reflect.Value) error { return nil } +func callbackRetAny(ctx *C.sqlite3_context, v reflect.Value) error { + cb, err := callbackRet(v.Elem().Type()) + if err != nil { + return err + } + + return cb(ctx, v.Elem()) +} + func callbackRet(typ reflect.Type) (callbackRetConverter, error) { switch typ.Kind() { case reflect.Interface: @@ -360,6 +369,11 @@ func callbackRet(typ reflect.Type) (callbackRetConverter, error) { if typ.Implements(errorInterface) { return callbackRetNil, nil } + + if typ.NumMethod() == 0 { + return callbackRetAny, nil + } + fallthrough case reflect.Slice: if typ.Elem().Kind() != reflect.Uint8 { diff --git a/callback_test.go b/callback_test.go index 714ed607..b09122ae 100644 --- a/callback_test.go +++ b/callback_test.go @@ -102,3 +102,15 @@ func TestCallbackConverters(t *testing.T) { } } } + +func TestCallbackReturnAny(t *testing.T) { + udf := func() interface{} { + return 1 + } + + typ := reflect.TypeOf(udf) + _, err := callbackRet(typ.Out(0)) + if err != nil { + t.Errorf("Expected valid callback for any return type, got: %s", err) + } +} From 46e57d52dc00fb4f7a57300ebdea4ac13062498b Mon Sep 17 00:00:00 2001 From: Phil Eaton Date: Fri, 13 May 2022 21:08:35 +0000 Subject: [PATCH 2/4] Rename to generic --- callback.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/callback.go b/callback.go index 808160d4..7db9af0d 100644 --- a/callback.go +++ b/callback.go @@ -353,7 +353,7 @@ func callbackRetNil(ctx *C.sqlite3_context, v reflect.Value) error { return nil } -func callbackRetAny(ctx *C.sqlite3_context, v reflect.Value) error { +func callbackRetGeneric(ctx *C.sqlite3_context, v reflect.Value) error { cb, err := callbackRet(v.Elem().Type()) if err != nil { return err @@ -371,7 +371,7 @@ func callbackRet(typ reflect.Type) (callbackRetConverter, error) { } if typ.NumMethod() == 0 { - return callbackRetAny, nil + return callbackRetGeneric, nil } fallthrough From 6a893017960af965a6ad3bca7c975e18a3513489 Mon Sep 17 00:00:00 2001 From: Phil Eaton Date: Fri, 13 May 2022 21:09:23 +0000 Subject: [PATCH 3/4] Add generic return test --- sqlite3_test.go | 57 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) diff --git a/sqlite3_test.go b/sqlite3_test.go index c86aba4b..9ee87e7e 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -1449,6 +1449,63 @@ func TestAggregatorRegistration(t *testing.T) { } } +type mode struct { + counts map[interface{}]int + top interface{} + topCount int +} + +func newMode() *mode { + return &mode{ + counts: map[interface{}]int{}, + } +} + +func (m *mode) Step(x interface{}) { + m.counts[x]++ + c := m.counts[x] + if c > m.topCount { + m.top = x + m.topCount = c + } +} + +func (m *mode) Done() interface{} { + return m.top +} + +func TestAggregatorRegistration_GenericReturn(t *testing.T) { + sql.Register("sqlite3_AggregatorRegistration_GenericReturn", &SQLiteDriver{ + ConnectHook: func(conn *SQLiteConn) error { + return conn.RegisterAggregator("mode", newMode, true) + }, + }) + db, err := sql.Open("sqlite3_AggregatorRegistration_GenericReturn", ":memory:") + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer db.Close() + + _, err = db.Exec("create table foo (department integer, profits integer)") + if err != nil { + t.Fatal("Failed to create table:", err) + } + _, err = db.Exec("insert into foo values (1, 10), (1, 20), (1, 45), (2, 42), (2, 115), (2, 20)") + if err != nil { + t.Fatal("Failed to insert records:", err) + } + + var mode int + err = db.QueryRow("select mode(profits) from foo").Scan(&mode) + if err != nil { + t.Fatal("MODE query error:", err) + } + + if mode != 20 { + t.Fatal("Got incorrect mode. Wanted 20, got: ", mode) + } +} + func rot13(r rune) rune { switch { case r >= 'A' && r <= 'Z': From 12637a65d5d7eb85f8420dd735745430c26e96c2 Mon Sep 17 00:00:00 2001 From: Phil Eaton Date: Fri, 13 May 2022 21:32:03 +0000 Subject: [PATCH 4/4] Check isnil --- callback.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/callback.go b/callback.go index 7db9af0d..d3056910 100644 --- a/callback.go +++ b/callback.go @@ -354,7 +354,12 @@ func callbackRetNil(ctx *C.sqlite3_context, v reflect.Value) error { } func callbackRetGeneric(ctx *C.sqlite3_context, v reflect.Value) error { - cb, err := callbackRet(v.Elem().Type()) + if v.IsNil() { + C.sqlite3_result_null(ctx) + return nil + } + + cb, err := callbackRet(v.Elem().Type()) if err != nil { return err }