Skip to content

Commit bbd867f

Browse files
johanbrandhorstbradfitz
authored andcommitted
http2: use (*tls.Dialer).DialContext in dialTLS
This lets us propagate the request context into the TLS handshake. Related to CL 295370 Updates golang/go#32406 Change-Id: Ie10c301be19b57b4b3e46ac31bbe87679e1eebc7 Reviewed-on: https://go-review.googlesource.com/c/net/+/295173 Trust: Johan Brandhorst-Satzkorn <[email protected]> Run-TryBot: Johan Brandhorst-Satzkorn <[email protected]> TryBot-Result: Go Bot <[email protected]> Reviewed-by: Brad Fitzpatrick <[email protected]> Reviewed-by: Filippo Valsorda <[email protected]>
1 parent 7fd8e65 commit bbd867f

File tree

4 files changed

+250
-49
lines changed

4 files changed

+250
-49
lines changed

http2/client_conn_pool.go

Lines changed: 57 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77
package http2
88

99
import (
10+
"context"
1011
"crypto/tls"
12+
"errors"
1113
"net/http"
1214
"sync"
1315
)
@@ -78,61 +80,69 @@ func (p *clientConnPool) getClientConn(req *http.Request, addr string, dialOnMis
7880
// It gets its own connection.
7981
traceGetConn(req, addr)
8082
const singleUse = true
81-
cc, err := p.t.dialClientConn(addr, singleUse)
83+
cc, err := p.t.dialClientConn(req.Context(), addr, singleUse)
8284
if err != nil {
8385
return nil, err
8486
}
8587
return cc, nil
8688
}
87-
p.mu.Lock()
88-
for _, cc := range p.conns[addr] {
89-
if st := cc.idleState(); st.canTakeNewRequest {
90-
if p.shouldTraceGetConn(st) {
91-
traceGetConn(req, addr)
89+
for {
90+
p.mu.Lock()
91+
for _, cc := range p.conns[addr] {
92+
if st := cc.idleState(); st.canTakeNewRequest {
93+
if p.shouldTraceGetConn(st) {
94+
traceGetConn(req, addr)
95+
}
96+
p.mu.Unlock()
97+
return cc, nil
9298
}
99+
}
100+
if !dialOnMiss {
93101
p.mu.Unlock()
94-
return cc, nil
102+
return nil, ErrNoCachedConn
95103
}
96-
}
97-
if !dialOnMiss {
104+
traceGetConn(req, addr)
105+
call := p.getStartDialLocked(req.Context(), addr)
98106
p.mu.Unlock()
99-
return nil, ErrNoCachedConn
107+
<-call.done
108+
if shouldRetryDial(call, req) {
109+
continue
110+
}
111+
return call.res, call.err
100112
}
101-
traceGetConn(req, addr)
102-
call := p.getStartDialLocked(addr)
103-
p.mu.Unlock()
104-
<-call.done
105-
return call.res, call.err
106113
}
107114

108115
// dialCall is an in-flight Transport dial call to a host.
109116
type dialCall struct {
110-
_ incomparable
111-
p *clientConnPool
117+
_ incomparable
118+
p *clientConnPool
119+
// the context associated with the request
120+
// that created this dialCall
121+
ctx context.Context
112122
done chan struct{} // closed when done
113123
res *ClientConn // valid after done is closed
114124
err error // valid after done is closed
115125
}
116126

117127
// requires p.mu is held.
118-
func (p *clientConnPool) getStartDialLocked(addr string) *dialCall {
128+
func (p *clientConnPool) getStartDialLocked(ctx context.Context, addr string) *dialCall {
119129
if call, ok := p.dialing[addr]; ok {
120130
// A dial is already in-flight. Don't start another.
121131
return call
122132
}
123-
call := &dialCall{p: p, done: make(chan struct{})}
133+
call := &dialCall{p: p, done: make(chan struct{}), ctx: ctx}
124134
if p.dialing == nil {
125135
p.dialing = make(map[string]*dialCall)
126136
}
127137
p.dialing[addr] = call
128-
go call.dial(addr)
138+
go call.dial(call.ctx, addr)
129139
return call
130140
}
131141

132142
// run in its own goroutine.
133-
func (c *dialCall) dial(addr string) {
143+
func (c *dialCall) dial(ctx context.Context, addr string) {
134144
const singleUse = false // shared conn
135-
c.res, c.err = c.p.t.dialClientConn(addr, singleUse)
145+
c.res, c.err = c.p.t.dialClientConn(ctx, addr, singleUse)
136146
close(c.done)
137147

138148
c.p.mu.Lock()
@@ -276,3 +286,28 @@ type noDialClientConnPool struct{ *clientConnPool }
276286
func (p noDialClientConnPool) GetClientConn(req *http.Request, addr string) (*ClientConn, error) {
277287
return p.getClientConn(req, addr, noDialOnMiss)
278288
}
289+
290+
// shouldRetryDial reports whether the current request should
291+
// retry dialing after the call finished unsuccessfully, for example
292+
// if the dial was canceled because of a context cancellation or
293+
// deadline expiry.
294+
func shouldRetryDial(call *dialCall, req *http.Request) bool {
295+
if call.err == nil {
296+
// No error, no need to retry
297+
return false
298+
}
299+
if call.ctx == req.Context() {
300+
// If the call has the same context as the request, the dial
301+
// should not be retried, since any cancellation will have come
302+
// from this request.
303+
return false
304+
}
305+
if !errors.Is(call.err, context.Canceled) && !errors.Is(call.err, context.DeadlineExceeded) {
306+
// If the call error is not because of a context cancellation or a deadline expiry,
307+
// the dial should not be retried.
308+
return false
309+
}
310+
// Only retry if the error is a context cancellation error or deadline expiry
311+
// and the context associated with the call was canceled or expired.
312+
return call.ctx.Err() != nil
313+
}

http2/transport.go

Lines changed: 18 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -564,12 +564,12 @@ func canRetryError(err error) bool {
564564
return false
565565
}
566566

567-
func (t *Transport) dialClientConn(addr string, singleUse bool) (*ClientConn, error) {
567+
func (t *Transport) dialClientConn(ctx context.Context, addr string, singleUse bool) (*ClientConn, error) {
568568
host, _, err := net.SplitHostPort(addr)
569569
if err != nil {
570570
return nil, err
571571
}
572-
tconn, err := t.dialTLS()("tcp", addr, t.newTLSConfig(host))
572+
tconn, err := t.dialTLS(ctx)("tcp", addr, t.newTLSConfig(host))
573573
if err != nil {
574574
return nil, err
575575
}
@@ -590,34 +590,28 @@ func (t *Transport) newTLSConfig(host string) *tls.Config {
590590
return cfg
591591
}
592592

593-
func (t *Transport) dialTLS() func(string, string, *tls.Config) (net.Conn, error) {
593+
func (t *Transport) dialTLS(ctx context.Context) func(string, string, *tls.Config) (net.Conn, error) {
594594
if t.DialTLS != nil {
595595
return t.DialTLS
596596
}
597-
return t.dialTLSDefault
598-
}
599-
600-
func (t *Transport) dialTLSDefault(network, addr string, cfg *tls.Config) (net.Conn, error) {
601-
cn, err := tls.Dial(network, addr, cfg)
602-
if err != nil {
603-
return nil, err
604-
}
605-
if err := cn.Handshake(); err != nil {
606-
return nil, err
607-
}
608-
if !cfg.InsecureSkipVerify {
609-
if err := cn.VerifyHostname(cfg.ServerName); err != nil {
597+
return func(network, addr string, cfg *tls.Config) (net.Conn, error) {
598+
dialer := &tls.Dialer{
599+
Config: cfg,
600+
}
601+
cn, err := dialer.DialContext(ctx, network, addr)
602+
if err != nil {
610603
return nil, err
611604
}
605+
tlsCn := cn.(*tls.Conn) // DialContext comment promises this will always succeed
606+
state := tlsCn.ConnectionState()
607+
if p := state.NegotiatedProtocol; p != NextProtoTLS {
608+
return nil, fmt.Errorf("http2: unexpected ALPN protocol %q; want %q", p, NextProtoTLS)
609+
}
610+
if !state.NegotiatedProtocolIsMutual {
611+
return nil, errors.New("http2: could not negotiate protocol mutually")
612+
}
613+
return cn, nil
612614
}
613-
state := cn.ConnectionState()
614-
if p := state.NegotiatedProtocol; p != NextProtoTLS {
615-
return nil, fmt.Errorf("http2: unexpected ALPN protocol %q; want %q", p, NextProtoTLS)
616-
}
617-
if !state.NegotiatedProtocolIsMutual {
618-
return nil, errors.New("http2: could not negotiate protocol mutually")
619-
}
620-
return cn, nil
621615
}
622616

623617
// disableKeepAlives reports whether connections should be closed as

http2/transport_go117_test.go

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
// Copyright 2021 The Go Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
//go:build go1.17
6+
// +build go1.17
7+
8+
package http2
9+
10+
import (
11+
"context"
12+
"crypto/tls"
13+
"errors"
14+
"net/http"
15+
"net/http/httptest"
16+
17+
"testing"
18+
)
19+
20+
func TestTransportDialTLSContext(t *testing.T) {
21+
blockCh := make(chan struct{})
22+
serverTLSConfigFunc := func(ts *httptest.Server) {
23+
ts.Config.TLSConfig = &tls.Config{
24+
// Triggers the server to request the clients certificate
25+
// during TLS handshake.
26+
ClientAuth: tls.RequestClientCert,
27+
}
28+
}
29+
ts := newServerTester(t,
30+
func(w http.ResponseWriter, r *http.Request) {},
31+
optOnlyServer,
32+
serverTLSConfigFunc,
33+
)
34+
defer ts.Close()
35+
tr := &Transport{
36+
TLSClientConfig: &tls.Config{
37+
GetClientCertificate: func(cri *tls.CertificateRequestInfo) (*tls.Certificate, error) {
38+
// Tests that the context provided to `req` is
39+
// passed into this function.
40+
close(blockCh)
41+
<-cri.Context().Done()
42+
return nil, cri.Context().Err()
43+
},
44+
InsecureSkipVerify: true,
45+
},
46+
}
47+
defer tr.CloseIdleConnections()
48+
req, err := http.NewRequest(http.MethodGet, ts.ts.URL, nil)
49+
if err != nil {
50+
t.Fatal(err)
51+
}
52+
ctx, cancel := context.WithCancel(context.Background())
53+
defer cancel()
54+
req = req.WithContext(ctx)
55+
errCh := make(chan error)
56+
go func() {
57+
defer close(errCh)
58+
res, err := tr.RoundTrip(req)
59+
if err != nil {
60+
errCh <- err
61+
return
62+
}
63+
res.Body.Close()
64+
}()
65+
// Wait for GetClientCertificate handler to be called
66+
<-blockCh
67+
// Cancel the context
68+
cancel()
69+
// Expect the cancellation error here
70+
err = <-errCh
71+
if err == nil {
72+
t.Fatal("cancelling context during client certificate fetch did not error as expected")
73+
return
74+
}
75+
if !errors.Is(err, context.Canceled) {
76+
t.Fatalf("unexpected error returned after cancellation: %v", err)
77+
}
78+
}
79+
80+
// TestDialRaceResumesDial tests that, given two concurrent requests
81+
// to the same address, when the first Dial is interrupted because
82+
// the first request's context is cancelled, the second request
83+
// resumes the dial automatically.
84+
func TestDialRaceResumesDial(t *testing.T) {
85+
blockCh := make(chan struct{})
86+
serverTLSConfigFunc := func(ts *httptest.Server) {
87+
ts.Config.TLSConfig = &tls.Config{
88+
// Triggers the server to request the clients certificate
89+
// during TLS handshake.
90+
ClientAuth: tls.RequestClientCert,
91+
}
92+
}
93+
ts := newServerTester(t,
94+
func(w http.ResponseWriter, r *http.Request) {},
95+
optOnlyServer,
96+
serverTLSConfigFunc,
97+
)
98+
defer ts.Close()
99+
tr := &Transport{
100+
TLSClientConfig: &tls.Config{
101+
GetClientCertificate: func(cri *tls.CertificateRequestInfo) (*tls.Certificate, error) {
102+
select {
103+
case <-blockCh:
104+
// If we already errored, return without error.
105+
return &tls.Certificate{}, nil
106+
default:
107+
}
108+
close(blockCh)
109+
<-cri.Context().Done()
110+
return nil, cri.Context().Err()
111+
},
112+
InsecureSkipVerify: true,
113+
},
114+
}
115+
defer tr.CloseIdleConnections()
116+
req, err := http.NewRequest(http.MethodGet, ts.ts.URL, nil)
117+
if err != nil {
118+
t.Fatal(err)
119+
}
120+
// Create two requests with independent cancellation.
121+
ctx1, cancel1 := context.WithCancel(context.Background())
122+
defer cancel1()
123+
req1 := req.WithContext(ctx1)
124+
ctx2, cancel2 := context.WithCancel(context.Background())
125+
defer cancel2()
126+
req2 := req.WithContext(ctx2)
127+
errCh := make(chan error)
128+
go func() {
129+
res, err := tr.RoundTrip(req1)
130+
if err != nil {
131+
errCh <- err
132+
return
133+
}
134+
res.Body.Close()
135+
}()
136+
successCh := make(chan struct{})
137+
go func() {
138+
// Don't start request until first request
139+
// has initiated the handshake.
140+
<-blockCh
141+
res, err := tr.RoundTrip(req2)
142+
if err != nil {
143+
errCh <- err
144+
return
145+
}
146+
res.Body.Close()
147+
// Close successCh to indicate that the second request
148+
// made it to the server successfully.
149+
close(successCh)
150+
}()
151+
// Wait for GetClientCertificate handler to be called
152+
<-blockCh
153+
// Cancel the context first
154+
cancel1()
155+
// Expect the cancellation error here
156+
err = <-errCh
157+
if err == nil {
158+
t.Fatal("cancelling context during client certificate fetch did not error as expected")
159+
return
160+
}
161+
if !errors.Is(err, context.Canceled) {
162+
t.Fatalf("unexpected error returned after cancellation: %v", err)
163+
}
164+
select {
165+
case err := <-errCh:
166+
t.Fatalf("unexpected second error: %v", err)
167+
case <-successCh:
168+
}
169+
}

http2/transport_test.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3276,7 +3276,8 @@ func TestClientConnPing(t *testing.T) {
32763276
defer st.Close()
32773277
tr := &Transport{TLSClientConfig: tlsConfigInsecure}
32783278
defer tr.CloseIdleConnections()
3279-
cc, err := tr.dialClientConn(st.ts.Listener.Addr().String(), false)
3279+
ctx := context.Background()
3280+
cc, err := tr.dialClientConn(ctx, st.ts.Listener.Addr().String(), false)
32803281
if err != nil {
32813282
t.Fatal(err)
32823283
}
@@ -4278,7 +4279,8 @@ func testClientConnClose(t *testing.T, closeMode closeMode) {
42784279
defer st.Close()
42794280
tr := &Transport{TLSClientConfig: tlsConfigInsecure}
42804281
defer tr.CloseIdleConnections()
4281-
cc, err := tr.dialClientConn(st.ts.Listener.Addr().String(), false)
4282+
ctx := context.Background()
4283+
cc, err := tr.dialClientConn(ctx, st.ts.Listener.Addr().String(), false)
42824284
req, err := http.NewRequest("GET", st.ts.URL, nil)
42834285
if err != nil {
42844286
t.Fatal(err)
@@ -4788,7 +4790,8 @@ func TestTransportRoundtripCloseOnWriteError(t *testing.T) {
47884790

47894791
tr := &Transport{TLSClientConfig: tlsConfigInsecure}
47904792
defer tr.CloseIdleConnections()
4791-
cc, err := tr.dialClientConn(st.ts.Listener.Addr().String(), false)
4793+
ctx := context.Background()
4794+
cc, err := tr.dialClientConn(ctx, st.ts.Listener.Addr().String(), false)
47924795
if err != nil {
47934796
t.Fatal(err)
47944797
}

0 commit comments

Comments
 (0)