@@ -55,6 +55,15 @@ var DefaultTransport RoundTripper = &Transport{
55
55
// MaxIdleConnsPerHost.
56
56
const DefaultMaxIdleConnsPerHost = 2
57
57
58
+ // connsPerHostClosedCh is a closed channel used by MaxConnsPerHost
59
+ // for the property that receives from a closed channel return the
60
+ // zero value.
61
+ var connsPerHostClosedCh = make (chan struct {})
62
+
63
+ func init () {
64
+ close (connsPerHostClosedCh )
65
+ }
66
+
58
67
// Transport is an implementation of RoundTripper that supports HTTP,
59
68
// HTTPS, and HTTP proxies (for either HTTP or HTTPS with CONNECT).
60
69
//
@@ -103,6 +112,10 @@ type Transport struct {
103
112
altMu sync.Mutex // guards changing altProto only
104
113
altProto atomic.Value // of nil or map[string]RoundTripper, key is URI scheme
105
114
115
+ connCountMu sync.Mutex
116
+ connPerHostCount map [connectMethodKey ]int
117
+ connPerHostAvailable map [connectMethodKey ]chan struct {}
118
+
106
119
// Proxy specifies a function to return a proxy for a given
107
120
// Request. If the function returns a non-nil error, the
108
121
// request is aborted with the provided error.
@@ -183,6 +196,18 @@ type Transport struct {
183
196
// DefaultMaxIdleConnsPerHost is used.
184
197
MaxIdleConnsPerHost int
185
198
199
+ // MaxConnsPerHost optionally limits the total number of
200
+ // connections per host, including connections in the dialing,
201
+ // active, and idle states. On limit violation, dials will block.
202
+ //
203
+ // Zero means no limit.
204
+ //
205
+ // For HTTP/2, this currently only controls the number of new
206
+ // connections being created at a time, instead of the total
207
+ // number. In practice, hosts using HTTP/2 only have about one
208
+ // idle connection, though.
209
+ MaxConnsPerHost int
210
+
186
211
// IdleConnTimeout is the maximum amount of time an idle
187
212
// (keep-alive) connection will remain idle before closing
188
213
// itself.
@@ -231,8 +256,6 @@ type Transport struct {
231
256
// h2transport (via onceSetNextProtoDefaults)
232
257
nextProtoOnce sync.Once
233
258
h2transport * http2Transport // non-nil if http2 wired up
234
-
235
- // TODO: tunable on max per-host TCP dials in flight (Issue 13957)
236
259
}
237
260
238
261
// onceSetNextProtoDefaults initializes TLSNextProto.
@@ -409,7 +432,8 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) {
409
432
var resp * Response
410
433
if pconn .alt != nil {
411
434
// HTTP/2 path.
412
- t .setReqCanceler (req , nil ) // not cancelable with CancelRequest
435
+ t .decHostConnCount (cm .key ()) // don't count cached http2 conns toward conns per host
436
+ t .setReqCanceler (req , nil ) // not cancelable with CancelRequest
413
437
resp , err = pconn .alt .RoundTrip (req )
414
438
} else {
415
439
resp , err = pconn .roundTrip (treq )
@@ -908,6 +932,7 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC
908
932
err error
909
933
}
910
934
dialc := make (chan dialRes )
935
+ cmKey := cm .key ()
911
936
912
937
// Copy these hooks so we don't race on the postPendingDial in
913
938
// the goroutine we launch. Issue 11136.
@@ -919,6 +944,8 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC
919
944
go func () {
920
945
if v := <- dialc ; v .err == nil {
921
946
t .putOrCloseIdleConn (v .pc )
947
+ } else {
948
+ t .decHostConnCount (cmKey )
922
949
}
923
950
testHookPostPendingDial ()
924
951
}()
@@ -927,6 +954,27 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC
927
954
cancelc := make (chan error , 1 )
928
955
t .setReqCanceler (req , func (err error ) { cancelc <- err })
929
956
957
+ if t .MaxConnsPerHost > 0 {
958
+ select {
959
+ case <- t .incHostConnCount (cmKey ):
960
+ // count below conn per host limit; proceed
961
+ case pc := <- t .getIdleConnCh (cm ):
962
+ if trace != nil && trace .GotConn != nil {
963
+ trace .GotConn (httptrace.GotConnInfo {Conn : pc .conn , Reused : pc .isReused ()})
964
+ }
965
+ return pc , nil
966
+ case <- req .Cancel :
967
+ return nil , errRequestCanceledConn
968
+ case <- req .Context ().Done ():
969
+ return nil , req .Context ().Err ()
970
+ case err := <- cancelc :
971
+ if err == errRequestCanceled {
972
+ err = errRequestCanceledConn
973
+ }
974
+ return nil , err
975
+ }
976
+ }
977
+
930
978
go func () {
931
979
pc , err := t .dialConn (ctx , cm )
932
980
dialc <- dialRes {pc , err }
@@ -944,6 +992,7 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC
944
992
}
945
993
// Our dial failed. See why to return a nicer error
946
994
// value.
995
+ t .decHostConnCount (cmKey )
947
996
select {
948
997
case <- req .Cancel :
949
998
// It was an error due to cancelation, so prioritize that
@@ -987,6 +1036,83 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC
987
1036
}
988
1037
}
989
1038
1039
+ // incHostConnCount increments the count of connections for a
1040
+ // given host. It returns an already-closed channel if the count
1041
+ // is not at its limit; otherwise it returns a channel which is
1042
+ // notified when the count is below the limit.
1043
+ func (t * Transport ) incHostConnCount (cmKey connectMethodKey ) <- chan struct {} {
1044
+ if t .MaxConnsPerHost <= 0 {
1045
+ return connsPerHostClosedCh
1046
+ }
1047
+ t .connCountMu .Lock ()
1048
+ defer t .connCountMu .Unlock ()
1049
+ if t .connPerHostCount [cmKey ] == t .MaxConnsPerHost {
1050
+ if t .connPerHostAvailable == nil {
1051
+ t .connPerHostAvailable = make (map [connectMethodKey ]chan struct {})
1052
+ }
1053
+ ch , ok := t .connPerHostAvailable [cmKey ]
1054
+ if ! ok {
1055
+ ch = make (chan struct {})
1056
+ t .connPerHostAvailable [cmKey ] = ch
1057
+ }
1058
+ return ch
1059
+ }
1060
+ if t .connPerHostCount == nil {
1061
+ t .connPerHostCount = make (map [connectMethodKey ]int )
1062
+ }
1063
+ t .connPerHostCount [cmKey ]++
1064
+ // return a closed channel to avoid race: if decHostConnCount is called
1065
+ // after incHostConnCount and during the nil check, decHostConnCount
1066
+ // will delete the channel since it's not being listened on yet.
1067
+ return connsPerHostClosedCh
1068
+ }
1069
+
1070
+ // decHostConnCount decrements the count of connections
1071
+ // for a given host.
1072
+ // See Transport.MaxConnsPerHost.
1073
+ func (t * Transport ) decHostConnCount (cmKey connectMethodKey ) {
1074
+ if t .MaxConnsPerHost <= 0 {
1075
+ return
1076
+ }
1077
+ t .connCountMu .Lock ()
1078
+ defer t .connCountMu .Unlock ()
1079
+ t .connPerHostCount [cmKey ]--
1080
+ select {
1081
+ case t .connPerHostAvailable [cmKey ] <- struct {}{}:
1082
+ default :
1083
+ // close channel before deleting avoids getConn waiting forever in
1084
+ // case getConn has reference to channel but hasn't started waiting.
1085
+ // This could lead to more than MaxConnsPerHost in the unlikely case
1086
+ // that > 1 go routine has fetched the channel but none started waiting.
1087
+ if t .connPerHostAvailable [cmKey ] != nil {
1088
+ close (t .connPerHostAvailable [cmKey ])
1089
+ }
1090
+ delete (t .connPerHostAvailable , cmKey )
1091
+ }
1092
+ if t .connPerHostCount [cmKey ] == 0 {
1093
+ delete (t .connPerHostCount , cmKey )
1094
+ }
1095
+ }
1096
+
1097
+ // connCloseListener wraps a connection, the transport that dialed it
1098
+ // and the connected-to host key so the host connection count can be
1099
+ // transparently decremented by whatever closes the embedded connection.
1100
+ type connCloseListener struct {
1101
+ net.Conn
1102
+ t * Transport
1103
+ cmKey connectMethodKey
1104
+ didClose int32
1105
+ }
1106
+
1107
+ func (c * connCloseListener ) Close () error {
1108
+ if atomic .AddInt32 (& c .didClose , 1 ) != 1 {
1109
+ return nil
1110
+ }
1111
+ err := c .Conn .Close ()
1112
+ c .t .decHostConnCount (c .cmKey )
1113
+ return err
1114
+ }
1115
+
990
1116
// The connect method and the transport can both specify a TLS
991
1117
// Host name. The transport's name takes precedence if present.
992
1118
func chooseTLSHost (cm connectMethod , t * Transport ) string {
@@ -1184,6 +1310,9 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (*persistCon
1184
1310
}
1185
1311
}
1186
1312
1313
+ if t .MaxConnsPerHost > 0 {
1314
+ pconn .conn = & connCloseListener {Conn : pconn .conn , t : t , cmKey : pconn .cacheKey }
1315
+ }
1187
1316
pconn .br = bufio .NewReader (pconn )
1188
1317
pconn .bw = bufio .NewWriter (persistConnWriter {pconn })
1189
1318
go pconn .readLoop ()
0 commit comments