Skip to content
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ Demouth <yuya at demouth.net>
Diego Dupin <diego.dupin at gmail.com>
Dirkjan Bussink <d.bussink at gmail.com>
DisposaBoy <disposaboy at dby.me>
Dmitry Zenovich <dzenovich at gmail.com>
Egor Smolyakov <egorsmkv at gmail.com>
Erwan Martin <hello at erwan.io>
Evan Elias <evan at skeema.net>
Expand Down
20 changes: 19 additions & 1 deletion connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,10 @@ func (mc *mysqlConn) finish() {

// Ping implements driver.Pinger interface
func (mc *mysqlConn) Ping(ctx context.Context) (err error) {
return mc.sendSimpleCommandOK(ctx, comPing)
}

func (mc *mysqlConn) sendSimpleCommandOK(ctx context.Context, cmd byte) (err error) {
if mc.closed.Load() {
return driver.ErrBadConn
}
Expand All @@ -513,7 +517,7 @@ func (mc *mysqlConn) Ping(ctx context.Context) (err error) {
defer mc.finish()

handleOk := mc.clearResult()
if err = mc.writeCommandPacket(comPing); err != nil {
if err = mc.writeCommandPacket(cmd); err != nil {
return mc.markBadConn(err)
}

Expand Down Expand Up @@ -681,6 +685,20 @@ func (mc *mysqlConn) startWatcher() {
}()
}

// Reset resets the server-side session state using COM_RESET_CONNECTION.
// It clears most per-session state (e.g., user variables, prepared statements)
// without re-authenticating.
// Usage hint: call via database/sql.Conn.Raw using a method assertion:
// conn.Raw(func(c any) error {
// if r, ok := c.(interface{ Reset(context.Context) error }); ok {
// return r.Reset(ctx)
// }
// return nil
// })
func (mc *mysqlConn) Reset(ctx context.Context) (err error) {
return mc.sendSimpleCommandOK(ctx, comResetConnection)
}

func (mc *mysqlConn) CheckNamedValue(nv *driver.NamedValue) (err error) {
nv.Value, err = converter{}.ConvertValue(nv.Value)
return
Expand Down
137 changes: 85 additions & 52 deletions connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,65 +129,98 @@ func TestCheckNamedValue(t *testing.T) {
}
}

// TestCleanCancel tests passed context is cancelled at start.
// TestSimpleCommandOKCleanCancel tests passed context is cancelled at start.
// No packet should be sent. Connection should keep current status.
func TestCleanCancel(t *testing.T) {
mc := &mysqlConn{
closech: make(chan struct{}),
}
mc.startWatcher()
defer mc.cleanup()

ctx, cancel := context.WithCancel(context.Background())
cancel()

for range 3 { // Repeat same behavior
err := mc.Ping(ctx)
if err != context.Canceled {
t.Errorf("expected context.Canceled, got %#v", err)
}

if mc.closed.Load() {
t.Error("expected mc is not closed, closed actually")
}

if mc.watching {
t.Error("expected watching is false, but true")
}
func TestSimpleCommandOKCleanCancel(t *testing.T) {
for _, test := range []struct {
name string
funcToCall func(ctx context.Context, mc *mysqlConn) error
} {
{name: "Ping", funcToCall: func(ctx context.Context, mc *mysqlConn) error { return mc.Ping(ctx) }},
{name: "Reset", funcToCall: func(ctx context.Context, mc *mysqlConn) error { return mc.Reset(ctx) }},
} {
test := test
t.Run(test.name, func(t *testing.T) {
mc := &mysqlConn{
closech: make(chan struct{}),
}
mc.startWatcher()
defer mc.cleanup()

ctx, cancel := context.WithCancel(context.Background())
cancel()

for range 3 { // Repeat same behavior
err := test.funcToCall(ctx, mc)
if err != context.Canceled {
t.Errorf("expected context.Canceled, got %#v", err)
}

if mc.closed.Load() {
t.Error("expected mc is not closed, closed actually")
}

if mc.watching {
t.Error("expected watching is false, but true")
}
}
})
}
}

func TestPingMarkBadConnection(t *testing.T) {
nc := badConnection{err: errors.New("boom")}
mc := &mysqlConn{
netConn: nc,
buf: newBuffer(),
maxAllowedPacket: defaultMaxAllowedPacket,
closech: make(chan struct{}),
cfg: NewConfig(),
}

err := mc.Ping(context.Background())

if err != driver.ErrBadConn {
t.Errorf("expected driver.ErrBadConn, got %#v", err)
func TestSimpleCommandOKMarkBadConnection(t *testing.T) {
for _, test := range []struct {
name string
funcToCall func(mc *mysqlConn) error
} {
{name: "Ping", funcToCall: func(mc *mysqlConn) error { return mc.Ping(context.Background()) }},
{name: "Reset", funcToCall: func(mc *mysqlConn) error { return mc.Reset(context.Background()) }},
} {
test := test
t.Run(test.name, func(t *testing.T) {
nc := badConnection{err: errors.New("boom")}
mc := &mysqlConn{
netConn: nc,
buf: newBuffer(),
maxAllowedPacket: defaultMaxAllowedPacket,
closech: make(chan struct{}),
cfg: NewConfig(),
}

err := test.funcToCall(mc)

if err != driver.ErrBadConn {
t.Errorf("expected driver.ErrBadConn, got %#v", err)
}
})
}
}

func TestPingErrInvalidConn(t *testing.T) {
nc := badConnection{err: errors.New("failed to write"), n: 10}
mc := &mysqlConn{
netConn: nc,
buf: newBuffer(),
maxAllowedPacket: defaultMaxAllowedPacket,
closech: make(chan struct{}),
cfg: NewConfig(),
}

err := mc.Ping(context.Background())

if err != nc.err {
t.Errorf("expected %#v, got %#v", nc.err, err)
func TestSimpleCommandOKErrInvalidConn(t *testing.T) {
for _, test := range []struct {
name string
funcToCall func(mc *mysqlConn) error
} {
{name: "Ping", funcToCall: func(mc *mysqlConn) error { return mc.Ping(context.Background()) }},
{name: "Reset", funcToCall: func(mc *mysqlConn) error { return mc.Reset(context.Background()) }},
} {
test := test
t.Run(test.name, func(t *testing.T) {
nc := badConnection{err: errors.New("failed to write"), n: 10}
mc := &mysqlConn{
netConn: nc,
buf: newBuffer(),
maxAllowedPacket: defaultMaxAllowedPacket,
closech: make(chan struct{}),
cfg: NewConfig(),
}

err := test.funcToCall(mc)

if err != nc.err {
t.Errorf("expected %#v, got %#v", nc.err, err)
}
})
}
}

Expand Down
3 changes: 3 additions & 0 deletions const.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,9 @@ const (
comStmtReset
comSetOption
comStmtFetch
comDaemon
comBinlogDumpGTID
comResetConnection
)

// https://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnType
Expand Down
118 changes: 97 additions & 21 deletions driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2322,48 +2322,124 @@ func TestRejectReadOnly(t *testing.T) {
}

func TestPing(t *testing.T) {
ctx := context.Background()
runTests(t, dsn, func(dbt *DBTest) {
if err := dbt.db.Ping(); err != nil {
dbt.fail("Ping", "Ping", err)
}
})
}

func TestSimpleCommandOK(t *testing.T) {
ctx := context.Background()
for _, test := range []struct{
method string
query string
funcToCall func(ctx context.Context, mc *mysqlConn) error
} {
{method: "Pinger", query: "Ping", funcToCall: func(ctx context.Context, mc *mysqlConn) error {return mc.Ping(ctx)}},
{method: "Conn", query: "Reset", funcToCall: func(ctx context.Context, mc *mysqlConn) error {return mc.Reset(ctx)}},
} {
test := test
t.Run(test.method+"_"+test.query, func(t *testing.T) {
runTests(t, dsn, func(dbt *DBTest) {
conn, err := dbt.db.Conn(ctx)
if err != nil {
dbt.fail("db", "Conn", err)
}
defer conn.Close()

// Check that affectedRows and insertIds are cleared after each call.
conn.Raw(func(conn any) error {
c := conn.(*mysqlConn)

// Issue a query that sets affectedRows and insertIds.
q, err := c.Query(`SELECT 1`, nil)
if err != nil {
dbt.fail("Conn", "Query", err)
}
if got, want := c.result.affectedRows, []int64{0}; !reflect.DeepEqual(got, want) {
dbt.Fatalf("bad affectedRows: got %v, want=%v", got, want)
}
if got, want := c.result.insertIds, []int64{0}; !reflect.DeepEqual(got, want) {
dbt.Fatalf("bad insertIds: got %v, want=%v", got, want)
}
q.Close()

// Verify that Ping()/Reset() clears both fields.
for range 2 {
if err := test.funcToCall(ctx, c); err != nil {
dbt.fail(test.method, test.query, err)
}
if got, want := c.result.affectedRows, []int64(nil); !reflect.DeepEqual(got, want) {
t.Errorf("bad affectedRows: got %v, want=%v", got, want)
}
if got, want := c.result.insertIds, []int64(nil); !reflect.DeepEqual(got, want) {
t.Errorf("bad insertIds: got %v, want=%v", got, want)
}
}
return nil
})
})
})
}
}

func TestReset(t *testing.T) {
ctx := context.Background()
runTests(t, dsn, func(dbt *DBTest) {
conn, err := dbt.db.Conn(ctx)
if err != nil {
dbt.fail("db", "Conn", err)
}
defer conn.Close()

// Check that affectedRows and insertIds are cleared after each call.
// Verify that COM_RESET_CONNECTION clears session state (e.g., user variables).
conn.Raw(func(conn any) error {
c := conn.(*mysqlConn)

// Issue a query that sets affectedRows and insertIds.
q, err := c.Query(`SELECT 1`, nil)
_, err = c.ExecContext(ctx, "SET @a := 1", nil)
if err != nil {
dbt.fail("Conn", "Query", err)
dbt.fail("Conn", "ExecContext", err)
}
if got, want := c.result.affectedRows, []int64{0}; !reflect.DeepEqual(got, want) {
dbt.Fatalf("bad affectedRows: got %v, want=%v", got, want)
var rows driver.Rows
rows, err = c.QueryContext(ctx, "SELECT @a", nil)
if err != nil {
dbt.fail("Conn", "QueryContext", err)
}
if got, want := c.result.insertIds, []int64{0}; !reflect.DeepEqual(got, want) {
dbt.Fatalf("bad insertIds: got %v, want=%v", got, want)
result := []driver.Value{0}
err = rows.Next(result)
if err != nil {
dbt.fail("Rows", "Next", err)
}
err = rows.Close()
if err != nil {
dbt.fail("Rows", "Close", err)
}
if !reflect.DeepEqual([]driver.Value{int64(1)}, result) {
dbt.Fatalf("failed to set @a to 1 with SET: got %v, want=%v", result, []driver.Value{int64(1)})
}
q.Close()

// Verify that Ping() clears both fields.
for range 2 {
if err := c.Ping(ctx); err != nil {
dbt.fail("Pinger", "Ping", err)
}
if got, want := c.result.affectedRows, []int64(nil); !reflect.DeepEqual(got, want) {
t.Errorf("bad affectedRows: got %v, want=%v", got, want)
}
if got, want := c.result.insertIds, []int64(nil); !reflect.DeepEqual(got, want) {
t.Errorf("bad affectedRows: got %v, want=%v", got, want)
}
err = c.Reset(ctx)
if err != nil {
dbt.fail("Conn", "Reset", err)
}

rows, err = c.QueryContext(ctx, "SELECT @a", nil)
if err != nil {
dbt.fail("Conn", "QueryContext", err)
}
err = rows.Next(result)
if err != nil {
dbt.fail("Rows", "Next", err)
}
err = rows.Close()
if err != nil {
dbt.fail("Rows", "Close", err)
}
if !reflect.DeepEqual([]driver.Value{nil}, result) {
dbt.Fatalf("Reset did not reset the session (@a is still set): got %v, want=%v", result, []driver.Value{nil})
}

return nil
})
})
Expand Down