@@ -3244,6 +3244,166 @@ func TestTransportNoRaceOnRequestObjectAfterRequestComplete(t *testing.T) {
3244
3244
req .Header = http.Header {}
3245
3245
}
3246
3246
3247
+ func TestTransportCloseAfterLostPing (t * testing.T ) {
3248
+ clientDone := make (chan struct {})
3249
+ ct := newClientTester (t )
3250
+ ct .tr .PingTimeout = 1 * time .Second
3251
+ ct .tr .ReadIdleTimeout = 1 * time .Second
3252
+ ct .client = func () error {
3253
+ defer ct .cc .(* net.TCPConn ).CloseWrite ()
3254
+ defer close (clientDone )
3255
+ req , _ := http .NewRequest ("GET" , "https://dummy.tld/" , nil )
3256
+ _ , err := ct .tr .RoundTrip (req )
3257
+ if err == nil || ! strings .Contains (err .Error (), "client connection lost" ) {
3258
+ return fmt .Errorf ("expected to get error about \" connection lost\" , got %v" , err )
3259
+ }
3260
+ return nil
3261
+ }
3262
+ ct .server = func () error {
3263
+ ct .greet ()
3264
+ <- clientDone
3265
+ return nil
3266
+ }
3267
+ ct .run ()
3268
+ }
3269
+
3270
+ func TestTransportPingWhenReading (t * testing.T ) {
3271
+ testCases := []struct {
3272
+ name string
3273
+ readIdleTimeout time.Duration
3274
+ serverResponseInterval time.Duration
3275
+ expectedPingCount int
3276
+ }{
3277
+ {
3278
+ name : "two pings in each serverResponseInterval" ,
3279
+ readIdleTimeout : 400 * time .Millisecond ,
3280
+ serverResponseInterval : 1000 * time .Millisecond ,
3281
+ expectedPingCount : 4 ,
3282
+ },
3283
+ {
3284
+ name : "one ping in each serverResponseInterval" ,
3285
+ readIdleTimeout : 700 * time .Millisecond ,
3286
+ serverResponseInterval : 1000 * time .Millisecond ,
3287
+ expectedPingCount : 2 ,
3288
+ },
3289
+ {
3290
+ name : "zero ping in each serverResponseInterval" ,
3291
+ readIdleTimeout : 1000 * time .Millisecond ,
3292
+ serverResponseInterval : 500 * time .Millisecond ,
3293
+ expectedPingCount : 0 ,
3294
+ },
3295
+ {
3296
+ name : "0 readIdleTimeout means no ping" ,
3297
+ readIdleTimeout : 0 * time .Millisecond ,
3298
+ serverResponseInterval : 500 * time .Millisecond ,
3299
+ expectedPingCount : 0 ,
3300
+ },
3301
+ }
3302
+
3303
+ for _ , tc := range testCases {
3304
+ tc := tc // capture range variable
3305
+ t .Run (tc .name , func (t * testing.T ) {
3306
+ t .Parallel ()
3307
+ testTransportPingWhenReading (t , tc .readIdleTimeout , tc .serverResponseInterval , tc .expectedPingCount )
3308
+ })
3309
+ }
3310
+ }
3311
+
3312
+ func testTransportPingWhenReading (t * testing.T , readIdleTimeout , serverResponseInterval time.Duration , expectedPingCount int ) {
3313
+ var pingCount int
3314
+ clientDone := make (chan struct {})
3315
+ ct := newClientTester (t )
3316
+ ct .tr .PingTimeout = 10 * time .Millisecond
3317
+ ct .tr .ReadIdleTimeout = readIdleTimeout
3318
+ // guards the ct.fr.Write
3319
+ var wmu sync.Mutex
3320
+
3321
+ ct .client = func () error {
3322
+ defer ct .cc .(* net.TCPConn ).CloseWrite ()
3323
+ defer close (clientDone )
3324
+ req , _ := http .NewRequest ("GET" , "https://dummy.tld/" , nil )
3325
+ res , err := ct .tr .RoundTrip (req )
3326
+ if err != nil {
3327
+ return fmt .Errorf ("RoundTrip: %v" , err )
3328
+ }
3329
+ defer res .Body .Close ()
3330
+ if res .StatusCode != 200 {
3331
+ return fmt .Errorf ("status code = %v; want %v" , res .StatusCode , 200 )
3332
+ }
3333
+ _ , err = ioutil .ReadAll (res .Body )
3334
+ return err
3335
+ }
3336
+
3337
+ ct .server = func () error {
3338
+ ct .greet ()
3339
+ var buf bytes.Buffer
3340
+ enc := hpack .NewEncoder (& buf )
3341
+ for {
3342
+ f , err := ct .fr .ReadFrame ()
3343
+ if err != nil {
3344
+ select {
3345
+ case <- clientDone :
3346
+ // If the client's done, it
3347
+ // will have reported any
3348
+ // errors on its side.
3349
+ return nil
3350
+ default :
3351
+ return err
3352
+ }
3353
+ }
3354
+ switch f := f .(type ) {
3355
+ case * WindowUpdateFrame , * SettingsFrame :
3356
+ case * HeadersFrame :
3357
+ if ! f .HeadersEnded () {
3358
+ return fmt .Errorf ("headers should have END_HEADERS be ended: %v" , f )
3359
+ }
3360
+ enc .WriteField (hpack.HeaderField {Name : ":status" , Value : strconv .Itoa (200 )})
3361
+ ct .fr .WriteHeaders (HeadersFrameParam {
3362
+ StreamID : f .StreamID ,
3363
+ EndHeaders : true ,
3364
+ EndStream : false ,
3365
+ BlockFragment : buf .Bytes (),
3366
+ })
3367
+
3368
+ go func () {
3369
+ for i := 0 ; i < 2 ; i ++ {
3370
+ wmu .Lock ()
3371
+ if err := ct .fr .WriteData (f .StreamID , false , []byte (fmt .Sprintf ("hello, this is server data frame %d" , i ))); err != nil {
3372
+ wmu .Unlock ()
3373
+ t .Error (err )
3374
+ return
3375
+ }
3376
+ wmu .Unlock ()
3377
+ time .Sleep (serverResponseInterval )
3378
+ }
3379
+ wmu .Lock ()
3380
+ if err := ct .fr .WriteData (f .StreamID , true , []byte ("hello, this is last server data frame" )); err != nil {
3381
+ wmu .Unlock ()
3382
+ t .Error (err )
3383
+ return
3384
+ }
3385
+ wmu .Unlock ()
3386
+ }()
3387
+ case * PingFrame :
3388
+ pingCount ++
3389
+ wmu .Lock ()
3390
+ if err := ct .fr .WritePing (true , f .Data ); err != nil {
3391
+ wmu .Unlock ()
3392
+ return err
3393
+ }
3394
+ wmu .Unlock ()
3395
+ default :
3396
+ return fmt .Errorf ("Unexpected client frame %v" , f )
3397
+ }
3398
+ }
3399
+ }
3400
+ ct .run ()
3401
+ if e , a := expectedPingCount , pingCount ; e != a {
3402
+ t .Errorf ("expected receiving %d pings, got %d pings" , e , a )
3403
+
3404
+ }
3405
+ }
3406
+
3247
3407
func TestTransportRetryAfterGOAWAY (t * testing.T ) {
3248
3408
var dialer struct {
3249
3409
sync.Mutex
0 commit comments