Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 9e0498d

Browse files
committedMar 19, 2024
http2: use synthetic timers for ping timeouts in tests
Change-Id: I642890519b066937ade3c13e8387c31d29e912f4 Reviewed-on: https://go-review.googlesource.com/c/net/+/572377 LUCI-TryBot-Result: Go LUCI <[email protected]> Reviewed-by: Jonathan Amsterdam <[email protected]>
1 parent 31d9683 commit 9e0498d

File tree

4 files changed

+240
-70
lines changed

4 files changed

+240
-70
lines changed
 

‎http2/clientconn_test.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ func newTestClientConn(t *testing.T, opts ...func(*Transport)) *testClientConn {
123123
tc.fr.SetMaxReadFrameSize(10 << 20)
124124

125125
t.Cleanup(func() {
126+
tc.sync()
126127
if tc.rerr == nil {
127128
tc.rerr = io.EOF
128129
}
@@ -459,6 +460,14 @@ func (tc *testClientConn) writeContinuation(streamID uint32, endHeaders bool, he
459460
tc.sync()
460461
}
461462

463+
func (tc *testClientConn) writePing(ack bool, data [8]byte) {
464+
tc.t.Helper()
465+
if err := tc.fr.WritePing(ack, data); err != nil {
466+
tc.t.Fatal(err)
467+
}
468+
tc.sync()
469+
}
470+
462471
func (tc *testClientConn) writeGoAway(maxStreamID uint32, code ErrCode, debugData []byte) {
463472
tc.t.Helper()
464473
if err := tc.fr.WriteGoAway(maxStreamID, code, debugData); err != nil {

‎http2/testsync.go

Lines changed: 93 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
package http2
55

66
import (
7+
"context"
78
"sync"
89
"time"
910
)
@@ -173,25 +174,64 @@ func (h *testSyncHooks) condWait(cond *sync.Cond) {
173174
h.unlock()
174175
}
175176

176-
// newTimer creates a new timer: A time.Timer if h is nil, or a synthetic timer in tests.
177+
// newTimer creates a new fake timer.
177178
func (h *testSyncHooks) newTimer(d time.Duration) timer {
178179
h.lock()
179180
defer h.unlock()
180181
t := &fakeTimer{
181-
when: h.now.Add(d),
182-
c: make(chan time.Time),
182+
hooks: h,
183+
when: h.now.Add(d),
184+
c: make(chan time.Time),
183185
}
184186
h.timers = append(h.timers, t)
185187
return t
186188
}
187189

190+
// afterFunc creates a new fake AfterFunc timer.
191+
func (h *testSyncHooks) afterFunc(d time.Duration, f func()) timer {
192+
h.lock()
193+
defer h.unlock()
194+
t := &fakeTimer{
195+
hooks: h,
196+
when: h.now.Add(d),
197+
f: f,
198+
}
199+
h.timers = append(h.timers, t)
200+
return t
201+
}
202+
203+
func (h *testSyncHooks) contextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc) {
204+
ctx, cancel := context.WithCancel(ctx)
205+
t := h.afterFunc(d, cancel)
206+
return ctx, func() {
207+
t.Stop()
208+
cancel()
209+
}
210+
}
211+
212+
func (h *testSyncHooks) timeUntilEvent() time.Duration {
213+
h.lock()
214+
defer h.unlock()
215+
var next time.Time
216+
for _, t := range h.timers {
217+
if next.IsZero() || t.when.Before(next) {
218+
next = t.when
219+
}
220+
}
221+
if d := next.Sub(h.now); d > 0 {
222+
return d
223+
}
224+
return 0
225+
}
226+
188227
// advance advances time and causes synthetic timers to fire.
189228
func (h *testSyncHooks) advance(d time.Duration) {
190229
h.lock()
191230
defer h.unlock()
192231
h.now = h.now.Add(d)
193232
timers := h.timers[:0]
194233
for _, t := range h.timers {
234+
t := t // remove after go.mod depends on go1.22
195235
t.mu.Lock()
196236
switch {
197237
case t.when.After(h.now):
@@ -200,7 +240,20 @@ func (h *testSyncHooks) advance(d time.Duration) {
200240
// stopped timer
201241
default:
202242
t.when = time.Time{}
203-
close(t.c)
243+
if t.c != nil {
244+
close(t.c)
245+
}
246+
if t.f != nil {
247+
h.total++
248+
go func() {
249+
defer func() {
250+
h.lock()
251+
h.total--
252+
h.unlock()
253+
}()
254+
t.f()
255+
}()
256+
}
204257
}
205258
t.mu.Unlock()
206259
}
@@ -212,13 +265,16 @@ func (h *testSyncHooks) advance(d time.Duration) {
212265
type timer interface {
213266
C() <-chan time.Time
214267
Stop() bool
268+
Reset(d time.Duration) bool
215269
}
216270

271+
// timeTimer implements timer using real time.
217272
type timeTimer struct {
218273
t *time.Timer
219274
c chan time.Time
220275
}
221276

277+
// newTimeTimer creates a new timer using real time.
222278
func newTimeTimer(d time.Duration) timer {
223279
ch := make(chan time.Time)
224280
t := time.AfterFunc(d, func() {
@@ -227,20 +283,49 @@ func newTimeTimer(d time.Duration) timer {
227283
return &timeTimer{t, ch}
228284
}
229285

230-
func (t timeTimer) C() <-chan time.Time { return t.c }
231-
func (t timeTimer) Stop() bool { return t.t.Stop() }
286+
// newTimeAfterFunc creates an AfterFunc timer using real time.
287+
func newTimeAfterFunc(d time.Duration, f func()) timer {
288+
return &timeTimer{
289+
t: time.AfterFunc(d, f),
290+
}
291+
}
232292

293+
func (t timeTimer) C() <-chan time.Time { return t.c }
294+
func (t timeTimer) Stop() bool { return t.t.Stop() }
295+
func (t timeTimer) Reset(d time.Duration) bool { return t.t.Reset(d) }
296+
297+
// fakeTimer implements timer using fake time.
233298
type fakeTimer struct {
299+
hooks *testSyncHooks
300+
234301
mu sync.Mutex
235-
when time.Time
236-
c chan time.Time
302+
when time.Time // when the timer will fire
303+
c chan time.Time // closed when the timer fires; mutually exclusive with f
304+
f func() // called when the timer fires; mutually exclusive with c
237305
}
238306

239307
func (t *fakeTimer) C() <-chan time.Time { return t.c }
308+
240309
func (t *fakeTimer) Stop() bool {
241310
t.mu.Lock()
242311
defer t.mu.Unlock()
243312
stopped := t.when.IsZero()
244313
t.when = time.Time{}
245314
return stopped
246315
}
316+
317+
func (t *fakeTimer) Reset(d time.Duration) bool {
318+
if t.c != nil || t.f == nil {
319+
panic("fakeTimer only supports Reset on AfterFunc timers")
320+
}
321+
t.mu.Lock()
322+
defer t.mu.Unlock()
323+
t.hooks.lock()
324+
defer t.hooks.unlock()
325+
active := !t.when.IsZero()
326+
t.when = t.hooks.now.Add(d)
327+
if !active {
328+
t.hooks.timers = append(t.hooks.timers, t)
329+
}
330+
return active
331+
}

‎http2/transport.go

Lines changed: 56 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,21 @@ func (cc *ClientConn) newTimer(d time.Duration) timer {
391391
return newTimeTimer(d)
392392
}
393393

394+
// afterFunc creates a new time.AfterFunc timer, or a synthetic timer in tests.
395+
func (cc *ClientConn) afterFunc(d time.Duration, f func()) timer {
396+
if cc.syncHooks != nil {
397+
return cc.syncHooks.afterFunc(d, f)
398+
}
399+
return newTimeAfterFunc(d, f)
400+
}
401+
402+
func (cc *ClientConn) contextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc) {
403+
if cc.syncHooks != nil {
404+
return cc.syncHooks.contextWithTimeout(ctx, d)
405+
}
406+
return context.WithTimeout(ctx, d)
407+
}
408+
394409
// clientStream is the state for a single HTTP/2 stream. One of these
395410
// is created for each Transport.RoundTrip call.
396411
type clientStream struct {
@@ -875,7 +890,7 @@ func (cc *ClientConn) healthCheck() {
875890
pingTimeout := cc.t.pingTimeout()
876891
// We don't need to periodically ping in the health check, because the readLoop of ClientConn will
877892
// trigger the healthCheck again if there is no frame received.
878-
ctx, cancel := context.WithTimeout(context.Background(), pingTimeout)
893+
ctx, cancel := cc.contextWithTimeout(context.Background(), pingTimeout)
879894
defer cancel()
880895
cc.vlogf("http2: Transport sending health check")
881896
err := cc.Ping(ctx)
@@ -1432,6 +1447,21 @@ func (cs *clientStream) writeRequest(req *http.Request) (err error) {
14321447
if cc.reqHeaderMu == nil {
14331448
panic("RoundTrip on uninitialized ClientConn") // for tests
14341449
}
1450+
var newStreamHook func(*clientStream)
1451+
if cc.syncHooks != nil {
1452+
newStreamHook = cc.syncHooks.newstream
1453+
cc.syncHooks.blockUntil(func() bool {
1454+
select {
1455+
case cc.reqHeaderMu <- struct{}{}:
1456+
<-cc.reqHeaderMu
1457+
case <-cs.reqCancel:
1458+
case <-ctx.Done():
1459+
default:
1460+
return false
1461+
}
1462+
return true
1463+
})
1464+
}
14351465
select {
14361466
case cc.reqHeaderMu <- struct{}{}:
14371467
case <-cs.reqCancel:
@@ -1456,8 +1486,8 @@ func (cs *clientStream) writeRequest(req *http.Request) (err error) {
14561486
}
14571487
cc.mu.Unlock()
14581488

1459-
if cc.syncHooks != nil {
1460-
cc.syncHooks.newstream(cs)
1489+
if newStreamHook != nil {
1490+
newStreamHook(cs)
14611491
}
14621492

14631493
// TODO(bradfitz): this is a copy of the logic in net/http. Unify somewhere?
@@ -2369,10 +2399,9 @@ func (rl *clientConnReadLoop) run() error {
23692399
cc := rl.cc
23702400
gotSettings := false
23712401
readIdleTimeout := cc.t.ReadIdleTimeout
2372-
var t *time.Timer
2402+
var t timer
23732403
if readIdleTimeout != 0 {
2374-
t = time.AfterFunc(readIdleTimeout, cc.healthCheck)
2375-
defer t.Stop()
2404+
t = cc.afterFunc(readIdleTimeout, cc.healthCheck)
23762405
}
23772406
for {
23782407
f, err := cc.fr.ReadFrame()
@@ -3067,24 +3096,38 @@ func (cc *ClientConn) Ping(ctx context.Context) error {
30673096
}
30683097
cc.mu.Unlock()
30693098
}
3070-
errc := make(chan error, 1)
3099+
var pingError error
3100+
errc := make(chan struct{})
30713101
cc.goRun(func() {
30723102
cc.wmu.Lock()
30733103
defer cc.wmu.Unlock()
3074-
if err := cc.fr.WritePing(false, p); err != nil {
3075-
errc <- err
3104+
if pingError = cc.fr.WritePing(false, p); pingError != nil {
3105+
close(errc)
30763106
return
30773107
}
3078-
if err := cc.bw.Flush(); err != nil {
3079-
errc <- err
3108+
if pingError = cc.bw.Flush(); pingError != nil {
3109+
close(errc)
30803110
return
30813111
}
30823112
})
3113+
if cc.syncHooks != nil {
3114+
cc.syncHooks.blockUntil(func() bool {
3115+
select {
3116+
case <-c:
3117+
case <-errc:
3118+
case <-ctx.Done():
3119+
case <-cc.readerDone:
3120+
default:
3121+
return false
3122+
}
3123+
return true
3124+
})
3125+
}
30833126
select {
30843127
case <-c:
30853128
return nil
3086-
case err := <-errc:
3087-
return err
3129+
case <-errc:
3130+
return pingError
30883131
case <-ctx.Done():
30893132
return ctx.Err()
30903133
case <-cc.readerDone:

‎http2/transport_test.go

Lines changed: 82 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -3310,26 +3310,24 @@ func TestTransportNoRaceOnRequestObjectAfterRequestComplete(t *testing.T) {
33103310
}
33113311

33123312
func TestTransportCloseAfterLostPing(t *testing.T) {
3313-
clientDone := make(chan struct{})
3314-
ct := newClientTester(t)
3315-
ct.tr.PingTimeout = 1 * time.Second
3316-
ct.tr.ReadIdleTimeout = 1 * time.Second
3317-
ct.client = func() error {
3318-
defer ct.cc.(*net.TCPConn).CloseWrite()
3319-
defer close(clientDone)
3320-
req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
3321-
_, err := ct.tr.RoundTrip(req)
3322-
if err == nil || !strings.Contains(err.Error(), "client connection lost") {
3323-
return fmt.Errorf("expected to get error about \"connection lost\", got %v", err)
3324-
}
3325-
return nil
3326-
}
3327-
ct.server = func() error {
3328-
ct.greet()
3329-
<-clientDone
3330-
return nil
3313+
tc := newTestClientConn(t, func(tr *Transport) {
3314+
tr.PingTimeout = 1 * time.Second
3315+
tr.ReadIdleTimeout = 1 * time.Second
3316+
})
3317+
tc.greet()
3318+
3319+
req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
3320+
rt := tc.roundTrip(req)
3321+
tc.wantFrameType(FrameHeaders)
3322+
3323+
tc.advance(1 * time.Second)
3324+
tc.wantFrameType(FramePing)
3325+
3326+
tc.advance(1 * time.Second)
3327+
err := rt.err()
3328+
if err == nil || !strings.Contains(err.Error(), "client connection lost") {
3329+
t.Fatalf("expected to get error about \"connection lost\", got %v", err)
33313330
}
3332-
ct.run()
33333331
}
33343332

33353333
func TestTransportPingWriteBlocks(t *testing.T) {
@@ -3362,38 +3360,73 @@ func TestTransportPingWriteBlocks(t *testing.T) {
33623360
}
33633361
}
33643362

3365-
func TestTransportPingWhenReading(t *testing.T) {
3366-
testCases := []struct {
3367-
name string
3368-
readIdleTimeout time.Duration
3369-
deadline time.Duration
3370-
expectedPingCount int
3371-
}{
3372-
{
3373-
name: "two pings",
3374-
readIdleTimeout: 100 * time.Millisecond,
3375-
deadline: time.Second,
3376-
expectedPingCount: 2,
3377-
},
3378-
{
3379-
name: "zero ping",
3380-
readIdleTimeout: time.Second,
3381-
deadline: 200 * time.Millisecond,
3382-
expectedPingCount: 0,
3383-
},
3384-
{
3385-
name: "0 readIdleTimeout means no ping",
3386-
readIdleTimeout: 0 * time.Millisecond,
3387-
deadline: 500 * time.Millisecond,
3388-
expectedPingCount: 0,
3389-
},
3363+
func TestTransportPingWhenReadingMultiplePings(t *testing.T) {
3364+
tc := newTestClientConn(t, func(tr *Transport) {
3365+
tr.ReadIdleTimeout = 1000 * time.Millisecond
3366+
})
3367+
tc.greet()
3368+
3369+
ctx, cancel := context.WithCancel(context.Background())
3370+
req, _ := http.NewRequestWithContext(ctx, "GET", "https://dummy.tld/", nil)
3371+
rt := tc.roundTrip(req)
3372+
3373+
tc.wantFrameType(FrameHeaders)
3374+
tc.writeHeaders(HeadersFrameParam{
3375+
StreamID: rt.streamID(),
3376+
EndHeaders: true,
3377+
EndStream: false,
3378+
BlockFragment: tc.makeHeaderBlockFragment(
3379+
":status", "200",
3380+
),
3381+
})
3382+
3383+
for i := 0; i < 5; i++ {
3384+
// No ping yet...
3385+
tc.advance(999 * time.Millisecond)
3386+
if f := tc.readFrame(); f != nil {
3387+
t.Fatalf("unexpected frame: %v", f)
3388+
}
3389+
3390+
// ...ping now.
3391+
tc.advance(1 * time.Millisecond)
3392+
f := testClientConnReadFrame[*PingFrame](tc)
3393+
tc.writePing(true, f.Data)
33903394
}
33913395

3392-
for _, tc := range testCases {
3393-
tc := tc // capture range variable
3394-
t.Run(tc.name, func(t *testing.T) {
3395-
testTransportPingWhenReading(t, tc.readIdleTimeout, tc.deadline, tc.expectedPingCount)
3396-
})
3396+
// Cancel the request, Transport resets it and returns an error from body reads.
3397+
cancel()
3398+
tc.sync()
3399+
3400+
tc.wantFrameType(FrameRSTStream)
3401+
_, err := rt.readBody()
3402+
if err == nil {
3403+
t.Fatalf("Response.Body.Read() = %v, want error", err)
3404+
}
3405+
}
3406+
3407+
func TestTransportPingWhenReadingPingDisabled(t *testing.T) {
3408+
tc := newTestClientConn(t, func(tr *Transport) {
3409+
tr.ReadIdleTimeout = 0 // PINGs disabled
3410+
})
3411+
tc.greet()
3412+
3413+
req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
3414+
rt := tc.roundTrip(req)
3415+
3416+
tc.wantFrameType(FrameHeaders)
3417+
tc.writeHeaders(HeadersFrameParam{
3418+
StreamID: rt.streamID(),
3419+
EndHeaders: true,
3420+
EndStream: false,
3421+
BlockFragment: tc.makeHeaderBlockFragment(
3422+
":status", "200",
3423+
),
3424+
})
3425+
3426+
// No PING is sent, even after a long delay.
3427+
tc.advance(1 * time.Minute)
3428+
if f := tc.readFrame(); f != nil {
3429+
t.Fatalf("unexpected frame: %v", f)
33973430
}
33983431
}
33993432

0 commit comments

Comments
 (0)
Please sign in to comment.