Skip to content

Commit cae5c7f

Browse files
meirfbradfitz
authored andcommitted
net/http: add Transport.MaxConnsPerHost knob
Add field to http.Transport which limits connections per host, including dial-in-progress, in-use and idle (keep-alive) connections. For HTTP/2, this field only controls the number of dials in progress. Fixes #13957 Change-Id: I7a5e045b4d4793c6b5b1a7191e1342cd7df78e6c Reviewed-on: https://go-review.googlesource.com/71272 Reviewed-by: Brad Fitzpatrick <[email protected]> Run-TryBot: Brad Fitzpatrick <[email protected]> TryBot-Result: Gobot Gobot <[email protected]>
1 parent 58d287e commit cae5c7f

File tree

3 files changed

+220
-15
lines changed

3 files changed

+220
-15
lines changed

src/net/http/export_test.go

+11-3
Original file line numberDiff line numberDiff line change
@@ -133,9 +133,11 @@ func (t *Transport) IdleConnStrsForTesting_h2() []string {
133133
return ret
134134
}
135135

136-
func (t *Transport) IdleConnCountForTesting(cacheKey string) int {
136+
func (t *Transport) IdleConnCountForTesting(scheme, addr string) int {
137137
t.idleMu.Lock()
138138
defer t.idleMu.Unlock()
139+
key := connectMethodKey{"", scheme, addr}
140+
cacheKey := key.String()
139141
for k, conns := range t.idleConn {
140142
if k.String() == cacheKey {
141143
return len(conns)
@@ -160,13 +162,19 @@ func (t *Transport) RequestIdleConnChForTesting() {
160162
t.getIdleConnCh(connectMethod{nil, "http", "example.com"})
161163
}
162164

163-
func (t *Transport) PutIdleTestConn() bool {
165+
func (t *Transport) PutIdleTestConn(scheme, addr string) bool {
164166
c, _ := net.Pipe()
167+
key := connectMethodKey{"", scheme, addr}
168+
select {
169+
case <-t.incHostConnCount(key):
170+
default:
171+
return false
172+
}
165173
return t.tryPutIdleConn(&persistConn{
166174
t: t,
167175
conn: c, // dummy
168176
closech: make(chan struct{}), // so it can be closed
169-
cacheKey: connectMethodKey{"", "http", "example.com"},
177+
cacheKey: key,
170178
}) == nil
171179
}
172180

src/net/http/transport.go

+132-3
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,15 @@ var DefaultTransport RoundTripper = &Transport{
5555
// MaxIdleConnsPerHost.
5656
const DefaultMaxIdleConnsPerHost = 2
5757

58+
// connsPerHostClosedCh is a closed channel used by MaxConnsPerHost
59+
// for the property that receives from a closed channel return the
60+
// zero value.
61+
var connsPerHostClosedCh = make(chan struct{})
62+
63+
func init() {
64+
close(connsPerHostClosedCh)
65+
}
66+
5867
// Transport is an implementation of RoundTripper that supports HTTP,
5968
// HTTPS, and HTTP proxies (for either HTTP or HTTPS with CONNECT).
6069
//
@@ -103,6 +112,10 @@ type Transport struct {
103112
altMu sync.Mutex // guards changing altProto only
104113
altProto atomic.Value // of nil or map[string]RoundTripper, key is URI scheme
105114

115+
connCountMu sync.Mutex
116+
connPerHostCount map[connectMethodKey]int
117+
connPerHostAvailable map[connectMethodKey]chan struct{}
118+
106119
// Proxy specifies a function to return a proxy for a given
107120
// Request. If the function returns a non-nil error, the
108121
// request is aborted with the provided error.
@@ -183,6 +196,18 @@ type Transport struct {
183196
// DefaultMaxIdleConnsPerHost is used.
184197
MaxIdleConnsPerHost int
185198

199+
// MaxConnsPerHost optionally limits the total number of
200+
// connections per host, including connections in the dialing,
201+
// active, and idle states. On limit violation, dials will block.
202+
//
203+
// Zero means no limit.
204+
//
205+
// For HTTP/2, this currently only controls the number of new
206+
// connections being created at a time, instead of the total
207+
// number. In practice, hosts using HTTP/2 only have about one
208+
// idle connection, though.
209+
MaxConnsPerHost int
210+
186211
// IdleConnTimeout is the maximum amount of time an idle
187212
// (keep-alive) connection will remain idle before closing
188213
// itself.
@@ -231,8 +256,6 @@ type Transport struct {
231256
// h2transport (via onceSetNextProtoDefaults)
232257
nextProtoOnce sync.Once
233258
h2transport *http2Transport // non-nil if http2 wired up
234-
235-
// TODO: tunable on max per-host TCP dials in flight (Issue 13957)
236259
}
237260

238261
// onceSetNextProtoDefaults initializes TLSNextProto.
@@ -409,7 +432,8 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) {
409432
var resp *Response
410433
if pconn.alt != nil {
411434
// HTTP/2 path.
412-
t.setReqCanceler(req, nil) // not cancelable with CancelRequest
435+
t.decHostConnCount(cm.key()) // don't count cached http2 conns toward conns per host
436+
t.setReqCanceler(req, nil) // not cancelable with CancelRequest
413437
resp, err = pconn.alt.RoundTrip(req)
414438
} else {
415439
resp, err = pconn.roundTrip(treq)
@@ -908,6 +932,7 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC
908932
err error
909933
}
910934
dialc := make(chan dialRes)
935+
cmKey := cm.key()
911936

912937
// Copy these hooks so we don't race on the postPendingDial in
913938
// the goroutine we launch. Issue 11136.
@@ -919,6 +944,8 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC
919944
go func() {
920945
if v := <-dialc; v.err == nil {
921946
t.putOrCloseIdleConn(v.pc)
947+
} else {
948+
t.decHostConnCount(cmKey)
922949
}
923950
testHookPostPendingDial()
924951
}()
@@ -927,6 +954,27 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC
927954
cancelc := make(chan error, 1)
928955
t.setReqCanceler(req, func(err error) { cancelc <- err })
929956

957+
if t.MaxConnsPerHost > 0 {
958+
select {
959+
case <-t.incHostConnCount(cmKey):
960+
// count below conn per host limit; proceed
961+
case pc := <-t.getIdleConnCh(cm):
962+
if trace != nil && trace.GotConn != nil {
963+
trace.GotConn(httptrace.GotConnInfo{Conn: pc.conn, Reused: pc.isReused()})
964+
}
965+
return pc, nil
966+
case <-req.Cancel:
967+
return nil, errRequestCanceledConn
968+
case <-req.Context().Done():
969+
return nil, req.Context().Err()
970+
case err := <-cancelc:
971+
if err == errRequestCanceled {
972+
err = errRequestCanceledConn
973+
}
974+
return nil, err
975+
}
976+
}
977+
930978
go func() {
931979
pc, err := t.dialConn(ctx, cm)
932980
dialc <- dialRes{pc, err}
@@ -944,6 +992,7 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC
944992
}
945993
// Our dial failed. See why to return a nicer error
946994
// value.
995+
t.decHostConnCount(cmKey)
947996
select {
948997
case <-req.Cancel:
949998
// It was an error due to cancelation, so prioritize that
@@ -987,6 +1036,83 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC
9871036
}
9881037
}
9891038

1039+
// incHostConnCount increments the count of connections for a
1040+
// given host. It returns an already-closed channel if the count
1041+
// is not at its limit; otherwise it returns a channel which is
1042+
// notified when the count is below the limit.
1043+
func (t *Transport) incHostConnCount(cmKey connectMethodKey) <-chan struct{} {
1044+
if t.MaxConnsPerHost <= 0 {
1045+
return connsPerHostClosedCh
1046+
}
1047+
t.connCountMu.Lock()
1048+
defer t.connCountMu.Unlock()
1049+
if t.connPerHostCount[cmKey] == t.MaxConnsPerHost {
1050+
if t.connPerHostAvailable == nil {
1051+
t.connPerHostAvailable = make(map[connectMethodKey]chan struct{})
1052+
}
1053+
ch, ok := t.connPerHostAvailable[cmKey]
1054+
if !ok {
1055+
ch = make(chan struct{})
1056+
t.connPerHostAvailable[cmKey] = ch
1057+
}
1058+
return ch
1059+
}
1060+
if t.connPerHostCount == nil {
1061+
t.connPerHostCount = make(map[connectMethodKey]int)
1062+
}
1063+
t.connPerHostCount[cmKey]++
1064+
// return a closed channel to avoid race: if decHostConnCount is called
1065+
// after incHostConnCount and during the nil check, decHostConnCount
1066+
// will delete the channel since it's not being listened on yet.
1067+
return connsPerHostClosedCh
1068+
}
1069+
1070+
// decHostConnCount decrements the count of connections
1071+
// for a given host.
1072+
// See Transport.MaxConnsPerHost.
1073+
func (t *Transport) decHostConnCount(cmKey connectMethodKey) {
1074+
if t.MaxConnsPerHost <= 0 {
1075+
return
1076+
}
1077+
t.connCountMu.Lock()
1078+
defer t.connCountMu.Unlock()
1079+
t.connPerHostCount[cmKey]--
1080+
select {
1081+
case t.connPerHostAvailable[cmKey] <- struct{}{}:
1082+
default:
1083+
// close channel before deleting avoids getConn waiting forever in
1084+
// case getConn has reference to channel but hasn't started waiting.
1085+
// This could lead to more than MaxConnsPerHost in the unlikely case
1086+
// that > 1 go routine has fetched the channel but none started waiting.
1087+
if t.connPerHostAvailable[cmKey] != nil {
1088+
close(t.connPerHostAvailable[cmKey])
1089+
}
1090+
delete(t.connPerHostAvailable, cmKey)
1091+
}
1092+
if t.connPerHostCount[cmKey] == 0 {
1093+
delete(t.connPerHostCount, cmKey)
1094+
}
1095+
}
1096+
1097+
// connCloseListener wraps a connection, the transport that dialed it
1098+
// and the connected-to host key so the host connection count can be
1099+
// transparently decremented by whatever closes the embedded connection.
1100+
type connCloseListener struct {
1101+
net.Conn
1102+
t *Transport
1103+
cmKey connectMethodKey
1104+
didClose int32
1105+
}
1106+
1107+
func (c *connCloseListener) Close() error {
1108+
if atomic.AddInt32(&c.didClose, 1) != 1 {
1109+
return nil
1110+
}
1111+
err := c.Conn.Close()
1112+
c.t.decHostConnCount(c.cmKey)
1113+
return err
1114+
}
1115+
9901116
// The connect method and the transport can both specify a TLS
9911117
// Host name. The transport's name takes precedence if present.
9921118
func chooseTLSHost(cm connectMethod, t *Transport) string {
@@ -1184,6 +1310,9 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (*persistCon
11841310
}
11851311
}
11861312

1313+
if t.MaxConnsPerHost > 0 {
1314+
pconn.conn = &connCloseListener{Conn: pconn.conn, t: t, cmKey: pconn.cacheKey}
1315+
}
11871316
pconn.br = bufio.NewReader(pconn)
11881317
pconn.bw = bufio.NewWriter(persistConnWriter{pconn})
11891318
go pconn.readLoop()

src/net/http/transport_test.go

+77-9
Original file line numberDiff line numberDiff line change
@@ -446,27 +446,95 @@ func TestTransportMaxPerHostIdleConns(t *testing.T) {
446446
if e, g := 1, len(keys); e != g {
447447
t.Fatalf("after first response, expected %d idle conn cache keys; got %d", e, g)
448448
}
449-
cacheKey := "|http|" + ts.Listener.Addr().String()
449+
addr := ts.Listener.Addr().String()
450+
cacheKey := "|http|" + addr
450451
if keys[0] != cacheKey {
451452
t.Fatalf("Expected idle cache key %q; got %q", cacheKey, keys[0])
452453
}
453-
if e, g := 1, tr.IdleConnCountForTesting(cacheKey); e != g {
454+
if e, g := 1, tr.IdleConnCountForTesting("http", addr); e != g {
454455
t.Errorf("after first response, expected %d idle conns; got %d", e, g)
455456
}
456457

457458
resch <- "res2"
458459
<-donech
459-
if g, w := tr.IdleConnCountForTesting(cacheKey), 2; g != w {
460+
if g, w := tr.IdleConnCountForTesting("http", addr), 2; g != w {
460461
t.Errorf("after second response, idle conns = %d; want %d", g, w)
461462
}
462463

463464
resch <- "res3"
464465
<-donech
465-
if g, w := tr.IdleConnCountForTesting(cacheKey), maxIdleConnsPerHost; g != w {
466+
if g, w := tr.IdleConnCountForTesting("http", addr), maxIdleConnsPerHost; g != w {
466467
t.Errorf("after third response, idle conns = %d; want %d", g, w)
467468
}
468469
}
469470

471+
func TestTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T) {
472+
defer afterTest(t)
473+
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
474+
_, err := w.Write([]byte("foo"))
475+
if err != nil {
476+
t.Fatalf("Write: %v", err)
477+
}
478+
}))
479+
defer ts.Close()
480+
c := ts.Client()
481+
tr := c.Transport.(*Transport)
482+
dialStarted := make(chan struct{})
483+
stallDial := make(chan struct{})
484+
tr.Dial = func(network, addr string) (net.Conn, error) {
485+
dialStarted <- struct{}{}
486+
<-stallDial
487+
return net.Dial(network, addr)
488+
}
489+
490+
tr.DisableKeepAlives = true
491+
tr.MaxConnsPerHost = 1
492+
493+
preDial := make(chan struct{})
494+
reqComplete := make(chan struct{})
495+
doReq := func(reqId string) {
496+
req, _ := NewRequest("GET", ts.URL, nil)
497+
trace := &httptrace.ClientTrace{
498+
GetConn: func(hostPort string) {
499+
preDial <- struct{}{}
500+
},
501+
}
502+
req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
503+
resp, err := tr.RoundTrip(req)
504+
if err != nil {
505+
t.Errorf("unexpected error for request %s: %v", reqId, err)
506+
}
507+
_, err = ioutil.ReadAll(resp.Body)
508+
if err != nil {
509+
t.Errorf("unexpected error for request %s: %v", reqId, err)
510+
}
511+
reqComplete <- struct{}{}
512+
}
513+
// get req1 to dial-in-progress
514+
go doReq("req1")
515+
<-preDial
516+
<-dialStarted
517+
518+
// get req2 to waiting on conns per host to go down below max
519+
go doReq("req2")
520+
<-preDial
521+
select {
522+
case <-dialStarted:
523+
t.Error("req2 dial started while req1 dial in progress")
524+
return
525+
default:
526+
}
527+
528+
// let req1 complete
529+
stallDial <- struct{}{}
530+
<-reqComplete
531+
532+
// let req2 complete
533+
<-dialStarted
534+
stallDial <- struct{}{}
535+
<-reqComplete
536+
}
537+
470538
func TestTransportRemovesDeadIdleConnections(t *testing.T) {
471539
setParallel(t)
472540
defer afterTest(t)
@@ -3118,18 +3186,18 @@ func TestRoundTripReturnsProxyError(t *testing.T) {
31183186
func TestTransportCloseIdleConnsThenReturn(t *testing.T) {
31193187
tr := &Transport{}
31203188
wantIdle := func(when string, n int) bool {
3121-
got := tr.IdleConnCountForTesting("|http|example.com") // key used by PutIdleTestConn
3189+
got := tr.IdleConnCountForTesting("http", "example.com") // key used by PutIdleTestConn
31223190
if got == n {
31233191
return true
31243192
}
31253193
t.Errorf("%s: idle conns = %d; want %d", when, got, n)
31263194
return false
31273195
}
31283196
wantIdle("start", 0)
3129-
if !tr.PutIdleTestConn() {
3197+
if !tr.PutIdleTestConn("http", "example.com") {
31303198
t.Fatal("put failed")
31313199
}
3132-
if !tr.PutIdleTestConn() {
3200+
if !tr.PutIdleTestConn("http", "example.com") {
31333201
t.Fatal("second put failed")
31343202
}
31353203
wantIdle("after put", 2)
@@ -3138,7 +3206,7 @@ func TestTransportCloseIdleConnsThenReturn(t *testing.T) {
31383206
t.Error("should be idle after CloseIdleConnections")
31393207
}
31403208
wantIdle("after close idle", 0)
3141-
if tr.PutIdleTestConn() {
3209+
if tr.PutIdleTestConn("http", "example.com") {
31423210
t.Fatal("put didn't fail")
31433211
}
31443212
wantIdle("after second put", 0)
@@ -3147,7 +3215,7 @@ func TestTransportCloseIdleConnsThenReturn(t *testing.T) {
31473215
if tr.IsIdleForTesting() {
31483216
t.Error("shouldn't be idle after RequestIdleConnChForTesting")
31493217
}
3150-
if !tr.PutIdleTestConn() {
3218+
if !tr.PutIdleTestConn("http", "example.com") {
31513219
t.Fatal("after re-activation")
31523220
}
31533221
wantIdle("after final put", 1)

0 commit comments

Comments
 (0)