Skip to content

Commit d6400de

Browse files
committed
introduce a new method mysqlConn.Reset() allowing to reset the mysql connection (implemented as "return mc.sendNoArgsCommandWithResultOK(ctx, comConnReset)" where comConnReset=31); add tests
1 parent a3f2c66 commit d6400de

File tree

4 files changed

+185
-72
lines changed

4 files changed

+185
-72
lines changed

connection.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -685,6 +685,11 @@ func (mc *mysqlConn) startWatcher() {
685685
}()
686686
}
687687

688+
// Reset resets the MySQL connection.
689+
func (mc *mysqlConn) Reset(ctx context.Context) (err error) {
690+
return mc.sendNoArgsCommandWithResultOK(ctx, comConnReset)
691+
}
692+
688693
func (mc *mysqlConn) CheckNamedValue(nv *driver.NamedValue) (err error) {
689694
nv.Value, err = converter{}.ConvertValue(nv.Value)
690695
return

connection_test.go

Lines changed: 85 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -129,65 +129,98 @@ func TestCheckNamedValue(t *testing.T) {
129129
}
130130
}
131131

132-
// TestCleanCancel tests passed context is cancelled at start.
132+
// TestNoArgsCommandCleanCancel tests passed context is cancelled at start.
133133
// No packet should be sent. Connection should keep current status.
134-
func TestCleanCancel(t *testing.T) {
135-
mc := &mysqlConn{
136-
closech: make(chan struct{}),
137-
}
138-
mc.startWatcher()
139-
defer mc.cleanup()
140-
141-
ctx, cancel := context.WithCancel(context.Background())
142-
cancel()
143-
144-
for range 3 { // Repeat same behavior
145-
err := mc.Ping(ctx)
146-
if err != context.Canceled {
147-
t.Errorf("expected context.Canceled, got %#v", err)
148-
}
149-
150-
if mc.closed.Load() {
151-
t.Error("expected mc is not closed, closed actually")
152-
}
153-
154-
if mc.watching {
155-
t.Error("expected watching is false, but true")
156-
}
134+
func TestNoArgsCommandCleanCancel(t *testing.T) {
135+
for _, test := range []struct {
136+
name string
137+
funcToCall func(ctx context.Context, mc *mysqlConn) error
138+
} {
139+
{name: "Ping", funcToCall: func(ctx context.Context, mc *mysqlConn) error { return mc.Ping(ctx) }},
140+
{name: "Reset", funcToCall: func(ctx context.Context, mc *mysqlConn) error { return mc.Reset(ctx) }},
141+
} {
142+
test := test
143+
t.Run(test.name, func(t *testing.T) {
144+
mc := &mysqlConn{
145+
closech: make(chan struct{}),
146+
}
147+
mc.startWatcher()
148+
defer mc.cleanup()
149+
150+
ctx, cancel := context.WithCancel(context.Background())
151+
cancel()
152+
153+
for range 3 { // Repeat same behavior
154+
err := test.funcToCall(ctx, mc)
155+
if err != context.Canceled {
156+
t.Errorf("expected context.Canceled, got %#v", err)
157+
}
158+
159+
if mc.closed.Load() {
160+
t.Error("expected mc is not closed, closed actually")
161+
}
162+
163+
if mc.watching {
164+
t.Error("expected watching is false, but true")
165+
}
166+
}
167+
})
157168
}
158169
}
159170

160-
func TestPingMarkBadConnection(t *testing.T) {
161-
nc := badConnection{err: errors.New("boom")}
162-
mc := &mysqlConn{
163-
netConn: nc,
164-
buf: newBuffer(),
165-
maxAllowedPacket: defaultMaxAllowedPacket,
166-
closech: make(chan struct{}),
167-
cfg: NewConfig(),
168-
}
169-
170-
err := mc.Ping(context.Background())
171-
172-
if err != driver.ErrBadConn {
173-
t.Errorf("expected driver.ErrBadConn, got %#v", err)
171+
func TestNoArgsCommandMarkBadConnection(t *testing.T) {
172+
for _, test := range []struct {
173+
name string
174+
funcToCall func(mc *mysqlConn) error
175+
} {
176+
{name: "Ping", funcToCall: func(mc *mysqlConn) error { return mc.Ping(context.Background()) }},
177+
{name: "Reset", funcToCall: func(mc *mysqlConn) error { return mc.Reset(context.Background()) }},
178+
} {
179+
test := test
180+
t.Run(test.name, func(t *testing.T) {
181+
nc := badConnection{err: errors.New("boom")}
182+
mc := &mysqlConn{
183+
netConn: nc,
184+
buf: newBuffer(),
185+
maxAllowedPacket: defaultMaxAllowedPacket,
186+
closech: make(chan struct{}),
187+
cfg: NewConfig(),
188+
}
189+
190+
err := test.funcToCall(mc)
191+
192+
if err != driver.ErrBadConn {
193+
t.Errorf("expected driver.ErrBadConn, got %#v", err)
194+
}
195+
})
174196
}
175197
}
176198

177-
func TestPingErrInvalidConn(t *testing.T) {
178-
nc := badConnection{err: errors.New("failed to write"), n: 10}
179-
mc := &mysqlConn{
180-
netConn: nc,
181-
buf: newBuffer(),
182-
maxAllowedPacket: defaultMaxAllowedPacket,
183-
closech: make(chan struct{}),
184-
cfg: NewConfig(),
185-
}
186-
187-
err := mc.Ping(context.Background())
188-
189-
if err != nc.err {
190-
t.Errorf("expected %#v, got %#v", nc.err, err)
199+
func TestNoArgsCommandErrInvalidConn(t *testing.T) {
200+
for _, test := range []struct {
201+
name string
202+
funcToCall func(mc *mysqlConn) error
203+
} {
204+
{name: "Ping", funcToCall: func(mc *mysqlConn) error { return mc.Ping(context.Background()) }},
205+
{name: "Reset", funcToCall: func(mc *mysqlConn) error { return mc.Reset(context.Background()) }},
206+
} {
207+
test := test
208+
t.Run(test.name, func(t *testing.T) {
209+
nc := badConnection{err: errors.New("failed to write"), n: 10}
210+
mc := &mysqlConn{
211+
netConn: nc,
212+
buf: newBuffer(),
213+
maxAllowedPacket: defaultMaxAllowedPacket,
214+
closech: make(chan struct{}),
215+
cfg: NewConfig(),
216+
}
217+
218+
err := test.funcToCall(mc)
219+
220+
if err != nc.err {
221+
t.Errorf("expected %#v, got %#v", nc.err, err)
222+
}
223+
})
191224
}
192225
}
193226

const.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ const (
115115
comStmtReset
116116
comSetOption
117117
comStmtFetch
118+
comConnReset = 31
118119
)
119120

120121
// https://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnType

driver_test.go

Lines changed: 94 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2322,13 +2322,69 @@ func TestRejectReadOnly(t *testing.T) {
23222322
}
23232323

23242324
func TestPing(t *testing.T) {
2325-
ctx := context.Background()
23262325
runTests(t, dsn, func(dbt *DBTest) {
23272326
if err := dbt.db.Ping(); err != nil {
23282327
dbt.fail("Ping", "Ping", err)
23292328
}
23302329
})
2330+
}
2331+
2332+
func TestNoArgsCommand(t *testing.T) {
2333+
ctx := context.Background()
2334+
for _, test := range []struct{
2335+
method string
2336+
query string
2337+
funcToCall func(ctx context.Context, mc *mysqlConn) error
2338+
} {
2339+
{method: "Ping", query: "Ping", funcToCall: func(ctx context.Context, mc *mysqlConn) error {return mc.Ping(ctx)}},
2340+
{method: "Conn", query: "Reset", funcToCall: func(ctx context.Context, mc *mysqlConn) error {return mc.Reset(ctx)}},
2341+
} {
2342+
test := test
2343+
t.Run(test.method+"_"+test.query, func(t *testing.T) {
2344+
runTests(t, dsn, func(dbt *DBTest) {
2345+
conn, err := dbt.db.Conn(ctx)
2346+
if err != nil {
2347+
dbt.fail("db", "Conn", err)
2348+
}
2349+
2350+
// Check that affectedRows and insertIds are cleared after each call.
2351+
conn.Raw(func(conn any) error {
2352+
c := conn.(*mysqlConn)
2353+
2354+
// Issue a query that sets affectedRows and insertIds.
2355+
q, err := c.Query(`SELECT 1`, nil)
2356+
if err != nil {
2357+
dbt.fail("Conn", "Query", err)
2358+
}
2359+
if got, want := c.result.affectedRows, []int64{0}; !reflect.DeepEqual(got, want) {
2360+
dbt.Fatalf("bad affectedRows: got %v, want=%v", got, want)
2361+
}
2362+
if got, want := c.result.insertIds, []int64{0}; !reflect.DeepEqual(got, want) {
2363+
dbt.Fatalf("bad insertIds: got %v, want=%v", got, want)
2364+
}
2365+
q.Close()
2366+
2367+
// Verify that Ping()/Reset() clears both fields.
2368+
for range 2 {
2369+
if err := test.funcToCall(ctx, c); err != nil {
2370+
dbt.fail(test.method, test.query, err)
2371+
}
2372+
if got, want := c.result.affectedRows, []int64(nil); !reflect.DeepEqual(got, want) {
2373+
t.Errorf("bad affectedRows: got %v, want=%v", got, want)
2374+
}
2375+
if got, want := c.result.insertIds, []int64(nil); !reflect.DeepEqual(got, want) {
2376+
t.Errorf("bad affectedRows: got %v, want=%v", got, want)
2377+
}
2378+
}
2379+
return nil
2380+
})
2381+
})
2382+
})
2383+
}
2384+
}
23312385

2386+
func TestReset(t *testing.T) {
2387+
ctx := context.Background()
23322388
runTests(t, dsn, func(dbt *DBTest) {
23332389
conn, err := dbt.db.Conn(ctx)
23342390
if err != nil {
@@ -2339,31 +2395,49 @@ func TestPing(t *testing.T) {
23392395
conn.Raw(func(conn any) error {
23402396
c := conn.(*mysqlConn)
23412397

2342-
// Issue a query that sets affectedRows and insertIds.
2343-
q, err := c.Query(`SELECT 1`, nil)
2398+
_, err = c.ExecContext(ctx, "SET @a := 1", nil)
2399+
if err != nil {
2400+
dbt.fail("Conn", "ExecContext", err)
2401+
}
2402+
var rows driver.Rows
2403+
rows, err = c.QueryContext(ctx, "SELECT @a", nil)
23442404
if err != nil {
2345-
dbt.fail("Conn", "Query", err)
2405+
dbt.fail("Conn", "QueryContext", err)
23462406
}
2347-
if got, want := c.result.affectedRows, []int64{0}; !reflect.DeepEqual(got, want) {
2348-
dbt.Fatalf("bad affectedRows: got %v, want=%v", got, want)
2407+
result := []driver.Value{0}
2408+
err = rows.Next(result)
2409+
if err != nil {
2410+
dbt.fail("Rows", "Next", err)
23492411
}
2350-
if got, want := c.result.insertIds, []int64{0}; !reflect.DeepEqual(got, want) {
2351-
dbt.Fatalf("bad insertIds: got %v, want=%v", got, want)
2412+
err = rows.Close()
2413+
if err != nil {
2414+
dbt.fail("Rows", "Close", err)
2415+
}
2416+
if !reflect.DeepEqual([]driver.Value{int64(1)}, result) {
2417+
dbt.Fatalf("failed to set @a to 1 with SET: got %v, want=%v", result, []driver.Value{int64(1)})
23522418
}
2353-
q.Close()
23542419

2355-
// Verify that Ping() clears both fields.
2356-
for range 2 {
2357-
if err := c.Ping(ctx); err != nil {
2358-
dbt.fail("Pinger", "Ping", err)
2359-
}
2360-
if got, want := c.result.affectedRows, []int64(nil); !reflect.DeepEqual(got, want) {
2361-
t.Errorf("bad affectedRows: got %v, want=%v", got, want)
2362-
}
2363-
if got, want := c.result.insertIds, []int64(nil); !reflect.DeepEqual(got, want) {
2364-
t.Errorf("bad affectedRows: got %v, want=%v", got, want)
2365-
}
2420+
err = c.Reset(ctx)
2421+
if err != nil {
2422+
dbt.fail("Conn", "Reset", err)
2423+
}
2424+
2425+
rows, err = c.QueryContext(ctx, "SELECT @a", nil)
2426+
if err != nil {
2427+
dbt.fail("Conn", "QueryContext", err)
2428+
}
2429+
err = rows.Next(result)
2430+
if err != nil {
2431+
dbt.fail("Rows", "Next", err)
2432+
}
2433+
err = rows.Close()
2434+
if err != nil {
2435+
dbt.fail("Rows", "Close", err)
23662436
}
2437+
if !reflect.DeepEqual([]driver.Value{nil}, result) {
2438+
dbt.Fatalf("Reset did not reset the session (@a is still set): got %v, want=%v", result, []driver.Value{nil})
2439+
}
2440+
23672441
return nil
23682442
})
23692443
})

0 commit comments

Comments
 (0)