Skip to content

Commit 0ed1016

Browse files
net/http: return an error if Write is called after WriteTimeout
response.Write now returns an error if the called happened after the configured server WriteTimeout. Fixes #21389
1 parent 335569b commit 0ed1016

File tree

2 files changed

+68
-4
lines changed

2 files changed

+68
-4
lines changed

src/net/http/serve_test.go

+41
Original file line numberDiff line numberDiff line change
@@ -972,6 +972,47 @@ func TestOnlyWriteTimeout(t *testing.T) {
972972
}
973973
}
974974

975+
func TestErrorAfterWriteTimeout(t *testing.T) {
976+
setParallel(t)
977+
defer afterTest(t)
978+
writeTimeout := 200 * time.Millisecond
979+
var afterTimeoutErrc = make(chan error, 1)
980+
ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, req *Request) {
981+
time.Sleep(2 * writeTimeout)
982+
983+
_, err := w.Write([]byte("test"))
984+
afterTimeoutErrc <- err
985+
}))
986+
ts.Config.WriteTimeout = writeTimeout
987+
ts.Start()
988+
defer ts.Close()
989+
990+
c := ts.Client()
991+
992+
errc := make(chan error, 1)
993+
go func() {
994+
res, err := c.Get(ts.URL)
995+
if err != nil {
996+
errc <- err
997+
return
998+
}
999+
_, err = io.Copy(io.Discard, res.Body)
1000+
res.Body.Close()
1001+
errc <- err
1002+
}()
1003+
select {
1004+
case err := <-errc:
1005+
if err == nil {
1006+
t.Errorf("expected an error from Get request")
1007+
}
1008+
case <-time.After(10 * time.Second):
1009+
t.Fatal("timeout waiting for Get error")
1010+
}
1011+
if err := <-afterTimeoutErrc; err == nil {
1012+
t.Error("expected write error after timeout")
1013+
}
1014+
}
1015+
9751016
// trackLastConnListener tracks the last net.Conn that was accepted.
9761017
type trackLastConnListener struct {
9771018
net.Listener

src/net/http/server.go

+27-4
Original file line numberDiff line numberDiff line change
@@ -390,11 +390,11 @@ func (cw *chunkWriter) Write(p []byte) (n int, err error) {
390390
return
391391
}
392392

393-
func (cw *chunkWriter) flush() {
393+
func (cw *chunkWriter) flush() error {
394394
if !cw.wroteHeader {
395395
cw.writeHeader(nil)
396396
}
397-
cw.res.conn.bufw.Flush()
397+
return cw.res.conn.bufw.Flush()
398398
}
399399

400400
func (cw *chunkWriter) close() {
@@ -438,6 +438,9 @@ type response struct {
438438
w *bufio.Writer // buffers output in chunks to chunkWriter
439439
cw chunkWriter
440440

441+
writeTimeoutTimer *time.Timer // triggers on write timeout
442+
writeDeadline bool // request writes timed out
443+
441444
// handlerHeader is the Header that Handlers get access to,
442445
// which may be retained and mutated even after WriteHeader.
443446
// handlerHeader is copied into cw.header at WriteHeader
@@ -1052,6 +1055,9 @@ func (c *conn) readRequest(ctx context.Context) (w *response, err error) {
10521055
if isH2Upgrade {
10531056
w.closeAfterReply = true
10541057
}
1058+
if d := c.server.WriteTimeout; d > 0 {
1059+
w.setWriteTimeout(d)
1060+
}
10551061
w.cw.res = w
10561062
w.w = newBufioWriterSize(&w.cw, bufferBeforeChunkingSize)
10571063
return w, nil
@@ -1578,6 +1584,14 @@ func (w *response) WriteString(data string) (n int, err error) {
15781584
return w.write(len(data), nil, data)
15791585
}
15801586

1587+
// setWriteTimeout lets the response know if the write was supposed to be
1588+
// timed out, timed out requests will force be flushed on every write
1589+
func (w *response) setWriteTimeout(d time.Duration) {
1590+
w.writeTimeoutTimer = time.AfterFunc(d, func() {
1591+
w.writeDeadline = true
1592+
})
1593+
}
1594+
15811595
// either dataB or dataS is non-zero.
15821596
func (w *response) write(lenData int, dataB []byte, dataS string) (n int, err error) {
15831597
if w.conn.hijacked() {
@@ -1613,10 +1627,16 @@ func (w *response) write(lenData int, dataB []byte, dataS string) (n int, err er
16131627
return 0, ErrContentLength
16141628
}
16151629
if dataB != nil {
1616-
return w.w.Write(dataB)
1630+
n, err = w.w.Write(dataB)
16171631
} else {
1618-
return w.w.WriteString(dataS)
1632+
n, err = w.w.WriteString(dataS)
1633+
}
1634+
if err == nil && w.writeDeadline {
1635+
// r.Flush returns no errors, flush manually
1636+
w.w.Flush()
1637+
err = w.cw.flush()
16191638
}
1639+
return
16201640
}
16211641

16221642
func (w *response) finishRequest() {
@@ -1631,6 +1651,9 @@ func (w *response) finishRequest() {
16311651
w.cw.close()
16321652
w.conn.bufw.Flush()
16331653

1654+
if w.writeTimeoutTimer != nil {
1655+
w.writeTimeoutTimer.Stop()
1656+
}
16341657
w.conn.r.abortPendingRead()
16351658

16361659
// Close the body (regardless of w.closeAfterReply) so we can

0 commit comments

Comments
 (0)