Skip to content

Commit 0e5043f

Browse files
fraenkeldmitshur
authored andcommitted
[internal-branch.go1.16-vendor] http2: close the request body if needed
As per client.Do and Request.Body, the transport is responsible to close the request Body. If there was an error or non 1xx/2xx status code, the transport will wait for the body writer to complete. If there is no data available to read, the body writer will block indefinitely. To prevent this, the body will be closed if it hasn't already. If there was a 1xx/2xx status code, the body will be closed eventually. Updates golang/go#49076 Change-Id: I9a4a5f13658122c562baf915e2c0c8992a023278 Reviewed-on: https://go-review.googlesource.com/c/net/+/323689 Reviewed-by: Damien Neil <[email protected]> Trust: Damien Neil <[email protected]> Trust: Alexander Rakoczy <[email protected]> Run-TryBot: Damien Neil <[email protected]> TryBot-Result: Go Bot <[email protected]> Reviewed-on: https://go-review.googlesource.com/c/net/+/356976 Reviewed-by: Dmitri Shuralyov <[email protected]>
1 parent bb4ce86 commit 0e5043f

File tree

2 files changed

+74
-30
lines changed

2 files changed

+74
-30
lines changed

http2/transport.go

Lines changed: 29 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -385,8 +385,13 @@ func (cs *clientStream) abortRequestBodyWrite(err error) {
385385
}
386386
cc := cs.cc
387387
cc.mu.Lock()
388-
cs.stopReqBody = err
389-
cc.cond.Broadcast()
388+
if cs.stopReqBody == nil {
389+
cs.stopReqBody = err
390+
if cs.req.Body != nil {
391+
cs.req.Body.Close()
392+
}
393+
cc.cond.Broadcast()
394+
}
390395
cc.mu.Unlock()
391396
}
392397

@@ -1114,40 +1119,28 @@ func (cc *ClientConn) roundTrip(req *http.Request) (res *http.Response, gotErrAf
11141119
return res, false, nil
11151120
}
11161121

1122+
handleError := func(err error) (*http.Response, bool, error) {
1123+
if !hasBody || bodyWritten {
1124+
cc.writeStreamReset(cs.ID, ErrCodeCancel, nil)
1125+
} else {
1126+
bodyWriter.cancel()
1127+
cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel)
1128+
<-bodyWriter.resc
1129+
}
1130+
cc.forgetStreamID(cs.ID)
1131+
return nil, cs.getStartedWrite(), err
1132+
}
1133+
11171134
for {
11181135
select {
11191136
case re := <-readLoopResCh:
11201137
return handleReadLoopResponse(re)
11211138
case <-respHeaderTimer:
1122-
if !hasBody || bodyWritten {
1123-
cc.writeStreamReset(cs.ID, ErrCodeCancel, nil)
1124-
} else {
1125-
bodyWriter.cancel()
1126-
cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel)
1127-
<-bodyWriter.resc
1128-
}
1129-
cc.forgetStreamID(cs.ID)
1130-
return nil, cs.getStartedWrite(), errTimeout
1139+
return handleError(errTimeout)
11311140
case <-ctx.Done():
1132-
if !hasBody || bodyWritten {
1133-
cc.writeStreamReset(cs.ID, ErrCodeCancel, nil)
1134-
} else {
1135-
bodyWriter.cancel()
1136-
cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel)
1137-
<-bodyWriter.resc
1138-
}
1139-
cc.forgetStreamID(cs.ID)
1140-
return nil, cs.getStartedWrite(), ctx.Err()
1141+
return handleError(ctx.Err())
11411142
case <-req.Cancel:
1142-
if !hasBody || bodyWritten {
1143-
cc.writeStreamReset(cs.ID, ErrCodeCancel, nil)
1144-
} else {
1145-
bodyWriter.cancel()
1146-
cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel)
1147-
<-bodyWriter.resc
1148-
}
1149-
cc.forgetStreamID(cs.ID)
1150-
return nil, cs.getStartedWrite(), errRequestCanceled
1143+
return handleError(errRequestCanceled)
11511144
case <-cs.peerReset:
11521145
// processResetStream already removed the
11531146
// stream from the streams map; no need for
@@ -1294,7 +1287,13 @@ func (cs *clientStream) writeRequestBody(body io.Reader, bodyCloser io.Closer) (
12941287
// Request.Body is closed by the Transport,
12951288
// and in multiple cases: server replies <=299 and >299
12961289
// while still writing request body
1297-
cerr := bodyCloser.Close()
1290+
var cerr error
1291+
cc.mu.Lock()
1292+
if cs.stopReqBody == nil {
1293+
cs.stopReqBody = errStopReqBodyWrite
1294+
cerr = bodyCloser.Close()
1295+
}
1296+
cc.mu.Unlock()
12981297
if err == nil {
12991298
err = cerr
13001299
}

http2/transport_test.go

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4899,3 +4899,48 @@ func TestTransportServerResetStreamAtHeaders(t *testing.T) {
48994899
}
49004900
res.Body.Close()
49014901
}
4902+
4903+
type closeChecker struct {
4904+
io.ReadCloser
4905+
closed chan struct{}
4906+
}
4907+
4908+
func (rc *closeChecker) Close() error {
4909+
close(rc.closed)
4910+
return rc.ReadCloser.Close()
4911+
}
4912+
4913+
func TestTransportCloseRequestBody(t *testing.T) {
4914+
var statusCode int
4915+
st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
4916+
w.WriteHeader(statusCode)
4917+
}, optOnlyServer)
4918+
defer st.Close()
4919+
4920+
tr := &Transport{TLSClientConfig: tlsConfigInsecure}
4921+
defer tr.CloseIdleConnections()
4922+
ctx := context.Background()
4923+
cc, err := tr.dialClientConn(ctx, st.ts.Listener.Addr().String(), false)
4924+
if err != nil {
4925+
t.Fatal(err)
4926+
}
4927+
4928+
for _, status := range []int{200, 401} {
4929+
t.Run(fmt.Sprintf("status=%d", status), func(t *testing.T) {
4930+
statusCode = status
4931+
pr, pw := io.Pipe()
4932+
pipeClosed := make(chan struct{})
4933+
req, err := http.NewRequest("PUT", "https://dummy.tld/", &closeChecker{pr, pipeClosed})
4934+
if err != nil {
4935+
t.Fatal(err)
4936+
}
4937+
res, err := cc.RoundTrip(req)
4938+
if err != nil {
4939+
t.Fatal(err)
4940+
}
4941+
res.Body.Close()
4942+
pw.Close()
4943+
<-pipeClosed
4944+
})
4945+
}
4946+
}

0 commit comments

Comments
 (0)