Skip to content

Implement QueryRow and Exec methods of sql driver interface #12

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 97 additions & 7 deletions chdb/driver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,60 @@ func init() {
sql.Register("chdb", Driver{})
}

// Row is the result of calling [DB.QueryRow] to select a single row.
type singleRow struct {
// One of these two will be non-nil:
err error // deferred error for easy chaining
rows driver.Rows
}

// Scan copies the columns from the matched row into the values
// pointed at by dest. See the documentation on [Rows.Scan] for details.
// If more than one row matches the query,
// Scan uses the first row and discards the rest. If no row matches
// the query, Scan returns [ErrNoRows].
func (r *singleRow) Scan(dest ...any) error {
if r.err != nil {
return r.err
}
vals := make([]driver.Value, 0)
for _, v := range dest {
vals = append(vals, v)
}
err := r.rows.Next(vals)
if err != nil {
return err
}
// Make sure the query can be processed to completion with no errors.
return r.rows.Close()
}

// Err provides a way for wrapping packages to check for
// query errors without calling [Row.Scan].
// Err returns the error, if any, that was encountered while running the query.
// If this error is not nil, this error will also be returned from [Row.Scan].
func (r *singleRow) Err() error {
return r.err
}

type execResult struct {
err error
}

func (e *execResult) LastInsertId() (int64, error) {
if e.err != nil {
return 0, e.err
}
return -1, fmt.Errorf("does not support LastInsertId")

}
func (e *execResult) RowsAffected() (int64, error) {
if e.err != nil {
return 0, e.err
}
return -1, fmt.Errorf("does not support RowsAffected")
}

type queryHandle func(string, ...string) (*chdbstable.LocalResult, error)

type connector struct {
Expand Down Expand Up @@ -192,6 +246,18 @@ type conn struct {
QueryFun queryHandle
}

func prepareValues(values []driver.Value) []driver.NamedValue {
namedValues := make([]driver.NamedValue, len(values))
for i, value := range values {
namedValues[i] = driver.NamedValue{
// nb: Name field is optional
Ordinal: i,
Value: value,
}
}
return namedValues
}

func (c *conn) Close() error {
return nil
}
Expand All @@ -204,15 +270,39 @@ func (c *conn) SetupQueryFun() {
}

func (c *conn) Query(query string, values []driver.Value) (driver.Rows, error) {
namedValues := make([]driver.NamedValue, len(values))
for i, value := range values {
namedValues[i] = driver.NamedValue{
// nb: Name field is optional
Ordinal: i,
Value: value,
return c.QueryContext(context.Background(), query, prepareValues(values))
}

func (c *conn) QueryRow(query string, values []driver.Value) *singleRow {
return c.QueryRowContext(context.Background(), query, values)
}

func (c *conn) Exec(query string, values []driver.Value) (sql.Result, error) {
return c.ExecContext(context.Background(), query, prepareValues(values))
}

func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
_, err := c.QueryContext(ctx, query, args)
if err != nil && err.Error() != "result is nil" {
return nil, err
}
return &execResult{
err: nil,
}, nil
}

func (c *conn) QueryRowContext(ctx context.Context, query string, values []driver.Value) *singleRow {

v, err := c.QueryContext(ctx, query, prepareValues(values))
if err != nil {
return &singleRow{
err: err,
rows: nil,
}
}
return c.QueryContext(context.Background(), query, namedValues)
return &singleRow{
rows: v,
}
}

func (c *conn) compileArguments(query string, args []driver.NamedValue) (string, error) {
Expand Down
90 changes: 90 additions & 0 deletions chdb/driver/driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,3 +167,93 @@ func TestDbWithSession(t *testing.T) {
count++
}
}

func TestQueryRow(t *testing.T) {
sessionDir, err := os.MkdirTemp("", "unittest-sessiondata")
if err != nil {
t.Fatalf("create temp directory fail, err: %s", err)
}
defer os.RemoveAll(sessionDir)
session, err := chdb.NewSession(sessionDir)
if err != nil {
t.Fatalf("new session fail, err: %s", err)
}
defer session.Cleanup()

session.Query("USE testdb; INSERT INTO testtable VALUES (1), (2), (3);")

ret, err := session.Query("SELECT * FROM testtable;")
if err != nil {
t.Fatalf("Query fail, err: %s", err)
}
if string(ret.Buf()) != "1\n2\n3\n" {
t.Errorf("Query result should be 1\n2\n3\n, got %s", string(ret.Buf()))
}
db, err := sql.Open("chdb", fmt.Sprintf("session=%s", sessionDir))
if err != nil {
t.Fatalf("open db fail, err: %s", err)
}
if db.Ping() != nil {
t.Fatalf("ping db fail, err: %s", err)
}
rows := db.QueryRow("select * from testtable;")

var bar = 0
var count = 1
err = rows.Scan(&bar)
if err != nil {
t.Fatalf("scan fail, err: %s", err)
}
if bar != count {
t.Fatalf("result is not match, want: %d actual: %d", count, bar)
}
err2 := rows.Scan(&bar)
if err2 == nil {
t.Fatalf("QueryRow method should return only one item")
}

}

func TestExec(t *testing.T) {
sessionDir, err := os.MkdirTemp("", "unittest-sessiondata")
if err != nil {
t.Fatalf("create temp directory fail, err: %s", err)
}
defer os.RemoveAll(sessionDir)
session, err := chdb.NewSession(sessionDir)
if err != nil {
t.Fatalf("new session fail, err: %s", err)
}
defer session.Cleanup()
session.Query("CREATE DATABASE IF NOT EXISTS testdb; " +
"CREATE TABLE IF NOT EXISTS testdb.testtable (id UInt32) ENGINE = MergeTree() ORDER BY id;")

db, err := sql.Open("chdb", fmt.Sprintf("session=%s", sessionDir))
if err != nil {
t.Fatalf("open db fail, err: %s", err)
}
if db.Ping() != nil {
t.Fatalf("ping db fail, err: %s", err)
}

_, err = db.Exec("INSERT INTO testdb.testtable VALUES (1), (2), (3);")
if err != nil {
t.Fatalf("exec failed, err: %s", err)
}
rows := db.QueryRow("select * from testdb.testtable;")

var bar = 0
var count = 1
err = rows.Scan(&bar)
if err != nil {
t.Fatalf("scan fail, err: %s", err)
}
if bar != count {
t.Fatalf("result is not match, want: %d actual: %d", count, bar)
}
err2 := rows.Scan(&bar)
if err2 == nil {
t.Fatalf("QueryRow method should return only one item")
}

}