Skip to content

Commit 36607fe

Browse files
author
Chao Xu
committed
net/http2: perform connection health check
After the connection has been idle for a while, periodic pings are sent over the connection to check its health. An unhealthy connection is closed and removed from the connection pool. Fixes golang/go#31643
1 parent aa69164 commit 36607fe

File tree

2 files changed

+221
-5
lines changed

2 files changed

+221
-5
lines changed

http2/transport.go

Lines changed: 61 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,19 @@ type Transport struct {
108108
// waiting for their turn.
109109
StrictMaxConcurrentStreams bool
110110

111+
// ReadIdleTimeout is the timeout after which a health check using ping
112+
// frame will be carried out if no frame is received on the connection.
113+
// Note that a ping response is considered a received frame, so if
114+
// there is no other traffic on the connection, the health check will
115+
// be performed every ReadIdleTimeout interval.
116+
// If zero, no health check is performed.
117+
ReadIdleTimeout time.Duration
118+
119+
// PingTimeout is the timeout after which the connection will be closed
120+
// if a response to Ping is not received.
121+
// Defaults to 15s.
122+
PingTimeout time.Duration
123+
111124
// t1, if non-nil, is the standard library Transport using
112125
// this transport. Its settings are used (but not its
113126
// RoundTrip method, etc).
@@ -131,6 +144,14 @@ func (t *Transport) disableCompression() bool {
131144
return t.DisableCompression || (t.t1 != nil && t.t1.DisableCompression)
132145
}
133146

147+
func (t *Transport) pingTimeout() time.Duration {
148+
if t.PingTimeout == 0 {
149+
return 15 * time.Second
150+
}
151+
return t.PingTimeout
152+
153+
}
154+
134155
// ConfigureTransport configures a net/http HTTP/1 Transport to use HTTP/2.
135156
// It returns an error if t1 has already been HTTP/2-enabled.
136157
func ConfigureTransport(t1 *http.Transport) error {
@@ -674,6 +695,20 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro
674695
return cc, nil
675696
}
676697

698+
func (cc *ClientConn) healthCheck() {
699+
pingTimeout := cc.t.pingTimeout()
700+
// We don't need to periodically ping in the health check, because the readLoop of ClientConn will
701+
// trigger the healthCheck again if there is no frame received.
702+
ctx, cancel := context.WithTimeout(context.Background(), pingTimeout)
703+
defer cancel()
704+
err := cc.Ping(ctx)
705+
if err != nil {
706+
cc.closeForLostPing()
707+
cc.t.connPool().MarkDead(cc)
708+
return
709+
}
710+
}
711+
677712
func (cc *ClientConn) setGoAway(f *GoAwayFrame) {
678713
cc.mu.Lock()
679714
defer cc.mu.Unlock()
@@ -834,14 +869,12 @@ func (cc *ClientConn) sendGoAway() error {
834869
return nil
835870
}
836871

837-
// Close closes the client connection immediately.
838-
//
839-
// In-flight requests are interrupted. For a graceful shutdown, use Shutdown instead.
840-
func (cc *ClientConn) Close() error {
872+
// closeForError closes the client connection immediately. In-flight requests are interrupted.
873+
// err is sent to streams.
874+
func (cc *ClientConn) closeForError(err error) error {
841875
cc.mu.Lock()
842876
defer cc.cond.Broadcast()
843877
defer cc.mu.Unlock()
844-
err := errors.New("http2: client connection force closed via ClientConn.Close")
845878
for id, cs := range cc.streams {
846879
select {
847880
case cs.resc <- resAndError{err: err}:
@@ -854,6 +887,20 @@ func (cc *ClientConn) Close() error {
854887
return cc.tconn.Close()
855888
}
856889

890+
// Close closes the client connection immediately.
891+
//
892+
// In-flight requests are interrupted. For a graceful shutdown, use Shutdown instead.
893+
func (cc *ClientConn) Close() error {
894+
err := errors.New("http2: client connection force closed via ClientConn.Close")
895+
return cc.closeForError(err)
896+
}
897+
898+
// closes the client connection immediately. In-flight requests are interrupted.
899+
func (cc *ClientConn) closeForLostPing() error {
900+
err := errors.New("http2: client connection lost")
901+
return cc.closeForError(err)
902+
}
903+
857904
const maxAllocFrameSize = 512 << 10
858905

859906
// frameBuffer returns a scratch buffer suitable for writing DATA frames.
@@ -1706,8 +1753,17 @@ func (rl *clientConnReadLoop) run() error {
17061753
rl.closeWhenIdle = cc.t.disableKeepAlives() || cc.singleUse
17071754
gotReply := false // ever saw a HEADERS reply
17081755
gotSettings := false
1756+
readIdleTimeout := cc.t.ReadIdleTimeout
1757+
var t *time.Timer
1758+
if readIdleTimeout != 0 {
1759+
t = time.AfterFunc(readIdleTimeout, cc.healthCheck)
1760+
defer t.Stop()
1761+
}
17091762
for {
17101763
f, err := cc.fr.ReadFrame()
1764+
if t != nil {
1765+
t.Reset(readIdleTimeout)
1766+
}
17111767
if err != nil {
17121768
cc.vlogf("http2: Transport readFrame error on conn %p: (%T) %v", cc, err, err)
17131769
}

http2/transport_test.go

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3244,6 +3244,166 @@ func TestTransportNoRaceOnRequestObjectAfterRequestComplete(t *testing.T) {
32443244
req.Header = http.Header{}
32453245
}
32463246

3247+
func TestTransportCloseAfterLostPing(t *testing.T) {
3248+
clientDone := make(chan struct{})
3249+
ct := newClientTester(t)
3250+
ct.tr.PingTimeout = 1 * time.Second
3251+
ct.tr.ReadIdleTimeout = 1 * time.Second
3252+
ct.client = func() error {
3253+
defer ct.cc.(*net.TCPConn).CloseWrite()
3254+
defer close(clientDone)
3255+
req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
3256+
_, err := ct.tr.RoundTrip(req)
3257+
if err == nil || !strings.Contains(err.Error(), "client connection lost") {
3258+
return fmt.Errorf("expected to get error about \"connection lost\", got %v", err)
3259+
}
3260+
return nil
3261+
}
3262+
ct.server = func() error {
3263+
ct.greet()
3264+
<-clientDone
3265+
return nil
3266+
}
3267+
ct.run()
3268+
}
3269+
3270+
func TestTransportPingWhenReading(t *testing.T) {
3271+
testCases := []struct {
3272+
name string
3273+
readIdleTimeout time.Duration
3274+
serverResponseInterval time.Duration
3275+
expectedPingCount int
3276+
}{
3277+
{
3278+
name: "two pings in each serverResponseInterval",
3279+
readIdleTimeout: 400 * time.Millisecond,
3280+
serverResponseInterval: 1000 * time.Millisecond,
3281+
expectedPingCount: 4,
3282+
},
3283+
{
3284+
name: "one ping in each serverResponseInterval",
3285+
readIdleTimeout: 700 * time.Millisecond,
3286+
serverResponseInterval: 1000 * time.Millisecond,
3287+
expectedPingCount: 2,
3288+
},
3289+
{
3290+
name: "zero ping in each serverResponseInterval",
3291+
readIdleTimeout: 1000 * time.Millisecond,
3292+
serverResponseInterval: 500 * time.Millisecond,
3293+
expectedPingCount: 0,
3294+
},
3295+
{
3296+
name: "0 readIdleTimeout means no ping",
3297+
readIdleTimeout: 0 * time.Millisecond,
3298+
serverResponseInterval: 500 * time.Millisecond,
3299+
expectedPingCount: 0,
3300+
},
3301+
}
3302+
3303+
for _, tc := range testCases {
3304+
tc := tc // capture range variable
3305+
t.Run(tc.name, func(t *testing.T) {
3306+
t.Parallel()
3307+
testTransportPingWhenReading(t, tc.readIdleTimeout, tc.serverResponseInterval, tc.expectedPingCount)
3308+
})
3309+
}
3310+
}
3311+
3312+
func testTransportPingWhenReading(t *testing.T, readIdleTimeout, serverResponseInterval time.Duration, expectedPingCount int) {
3313+
var pingCount int
3314+
clientDone := make(chan struct{})
3315+
ct := newClientTester(t)
3316+
ct.tr.PingTimeout = 10 * time.Millisecond
3317+
ct.tr.ReadIdleTimeout = readIdleTimeout
3318+
// guards the ct.fr.Write
3319+
var wmu sync.Mutex
3320+
3321+
ct.client = func() error {
3322+
defer ct.cc.(*net.TCPConn).CloseWrite()
3323+
defer close(clientDone)
3324+
req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
3325+
res, err := ct.tr.RoundTrip(req)
3326+
if err != nil {
3327+
return fmt.Errorf("RoundTrip: %v", err)
3328+
}
3329+
defer res.Body.Close()
3330+
if res.StatusCode != 200 {
3331+
return fmt.Errorf("status code = %v; want %v", res.StatusCode, 200)
3332+
}
3333+
_, err = ioutil.ReadAll(res.Body)
3334+
return err
3335+
}
3336+
3337+
ct.server = func() error {
3338+
ct.greet()
3339+
var buf bytes.Buffer
3340+
enc := hpack.NewEncoder(&buf)
3341+
for {
3342+
f, err := ct.fr.ReadFrame()
3343+
if err != nil {
3344+
select {
3345+
case <-clientDone:
3346+
// If the client's done, it
3347+
// will have reported any
3348+
// errors on its side.
3349+
return nil
3350+
default:
3351+
return err
3352+
}
3353+
}
3354+
switch f := f.(type) {
3355+
case *WindowUpdateFrame, *SettingsFrame:
3356+
case *HeadersFrame:
3357+
if !f.HeadersEnded() {
3358+
return fmt.Errorf("headers should have END_HEADERS be ended: %v", f)
3359+
}
3360+
enc.WriteField(hpack.HeaderField{Name: ":status", Value: strconv.Itoa(200)})
3361+
ct.fr.WriteHeaders(HeadersFrameParam{
3362+
StreamID: f.StreamID,
3363+
EndHeaders: true,
3364+
EndStream: false,
3365+
BlockFragment: buf.Bytes(),
3366+
})
3367+
3368+
go func() {
3369+
for i := 0; i < 2; i++ {
3370+
wmu.Lock()
3371+
if err := ct.fr.WriteData(f.StreamID, false, []byte(fmt.Sprintf("hello, this is server data frame %d", i))); err != nil {
3372+
wmu.Unlock()
3373+
t.Error(err)
3374+
return
3375+
}
3376+
wmu.Unlock()
3377+
time.Sleep(serverResponseInterval)
3378+
}
3379+
wmu.Lock()
3380+
if err := ct.fr.WriteData(f.StreamID, true, []byte("hello, this is last server data frame")); err != nil {
3381+
wmu.Unlock()
3382+
t.Error(err)
3383+
return
3384+
}
3385+
wmu.Unlock()
3386+
}()
3387+
case *PingFrame:
3388+
pingCount++
3389+
wmu.Lock()
3390+
if err := ct.fr.WritePing(true, f.Data); err != nil {
3391+
wmu.Unlock()
3392+
return err
3393+
}
3394+
wmu.Unlock()
3395+
default:
3396+
return fmt.Errorf("Unexpected client frame %v", f)
3397+
}
3398+
}
3399+
}
3400+
ct.run()
3401+
if e, a := expectedPingCount, pingCount; e != a {
3402+
t.Errorf("expected receiving %d pings, got %d pings", e, a)
3403+
3404+
}
3405+
}
3406+
32473407
func TestTransportRetryAfterGOAWAY(t *testing.T) {
32483408
var dialer struct {
32493409
sync.Mutex

0 commit comments

Comments
 (0)