@@ -9,12 +9,61 @@ package sqlite3
9
9
10
10
import (
11
11
"database/sql"
12
+ "fmt"
12
13
"reflect"
13
14
"regexp"
14
15
"strings"
15
16
"testing"
16
17
)
17
18
19
+ func TestInvalidFunctionRegistration (t * testing.T ) {
20
+ afn := "func"
21
+ zeroArgsFn := func (a bool ) {}
22
+ nonErrorArgsFn := func (a bool ) (int , int ) { return 0 , 0 }
23
+
24
+ sql .Register (fmt .Sprintf ("sqlite3-%s-afn" , t .Name ()), & SQLiteDriver {
25
+ ConnectHook : func (conn * SQLiteConn ) error {
26
+ if err := conn .RegisterFunc ("afn" , afn , true ); err != nil {
27
+ return err
28
+ }
29
+
30
+ return nil
31
+ },
32
+ })
33
+
34
+ sql .Register (fmt .Sprintf ("sqlite3-%s-zeroArgsFn" , t .Name ()), & SQLiteDriver {
35
+ ConnectHook : func (conn * SQLiteConn ) error {
36
+ if err := conn .RegisterFunc ("zeroArgsFn" , zeroArgsFn , true ); err != nil {
37
+ return err
38
+ }
39
+
40
+ return nil
41
+ },
42
+ })
43
+
44
+ sql .Register (fmt .Sprintf ("sqlite3-%s-nonErrorArgsFn" , t .Name ()), & SQLiteDriver {
45
+ ConnectHook : func (conn * SQLiteConn ) error {
46
+ if err := conn .RegisterFunc ("nonErrorArgsFn" , nonErrorArgsFn , true ); err != nil {
47
+ return err
48
+ }
49
+
50
+ return nil
51
+ },
52
+ })
53
+
54
+ for _ , s := range []string {"sqlite3-%s-afn" , "sqlite3-%s-zeroArgsFn" , "sqlite3-%s-nonErrorArgsFn" } {
55
+ db , err := sql .Open (fmt .Sprintf (s , t .Name ()), ":memory:" )
56
+ if err != nil {
57
+ t .Fatal ("Failed to open database:" , err )
58
+ }
59
+ defer db .Close ()
60
+
61
+ if err := db .Ping (); err == nil {
62
+ t .Fatal ("Expected error from RegisterFunc" )
63
+ }
64
+ }
65
+ }
66
+
18
67
func TestFunctionRegistration (t * testing.T ) {
19
68
addi8_16_32 := func (a int8 , b int16 ) int32 { return int32 (a ) + int32 (b ) }
20
69
addi64 := func (a , b int64 ) int64 { return a + b }
0 commit comments