Skip to content

Commit 9d51cd0

Browse files
committed
net/http: don't reuse a server connection after any Write errors
Fixes #8534 LGTM=adg R=adg CC=golang-codereviews https://golang.org/cl/149340044
1 parent a681749 commit 9d51cd0

File tree

2 files changed

+126
-3
lines changed

2 files changed

+126
-3
lines changed

src/net/http/serve_test.go

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2659,6 +2659,103 @@ func TestCloseWrite(t *testing.T) {
26592659
}
26602660
}
26612661

2662+
// This verifies that a handler can Flush and then Hijack.
2663+
//
2664+
// An similar test crashed once during development, but it was only
2665+
// testing this tangentially and temporarily until another TODO was
2666+
// fixed.
2667+
//
2668+
// So add an explicit test for this.
2669+
func TestServerFlushAndHijack(t *testing.T) {
2670+
defer afterTest(t)
2671+
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
2672+
io.WriteString(w, "Hello, ")
2673+
w.(Flusher).Flush()
2674+
conn, buf, _ := w.(Hijacker).Hijack()
2675+
buf.WriteString("6\r\nworld!\r\n0\r\n\r\n")
2676+
if err := buf.Flush(); err != nil {
2677+
t.Error(err)
2678+
}
2679+
if err := conn.Close(); err != nil {
2680+
t.Error(err)
2681+
}
2682+
}))
2683+
defer ts.Close()
2684+
res, err := Get(ts.URL)
2685+
if err != nil {
2686+
t.Fatal(err)
2687+
}
2688+
defer res.Body.Close()
2689+
all, err := ioutil.ReadAll(res.Body)
2690+
if err != nil {
2691+
t.Fatal(err)
2692+
}
2693+
if want := "Hello, world!"; string(all) != want {
2694+
t.Errorf("Got %q; want %q", all, want)
2695+
}
2696+
}
2697+
2698+
// golang.org/issue/8534 -- the Server shouldn't reuse a connection
2699+
// for keep-alive after it's seen any Write error (e.g. a timeout) on
2700+
// that net.Conn.
2701+
//
2702+
// To test, verify we don't timeout or see fewer unique client
2703+
// addresses (== unique connections) than requests.
2704+
func TestServerKeepAliveAfterWriteError(t *testing.T) {
2705+
if testing.Short() {
2706+
t.Skip("skipping in -short mode")
2707+
}
2708+
defer afterTest(t)
2709+
const numReq = 3
2710+
addrc := make(chan string, numReq)
2711+
ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {
2712+
addrc <- r.RemoteAddr
2713+
time.Sleep(500 * time.Millisecond)
2714+
w.(Flusher).Flush()
2715+
}))
2716+
ts.Config.WriteTimeout = 250 * time.Millisecond
2717+
ts.Start()
2718+
defer ts.Close()
2719+
2720+
errc := make(chan error, numReq)
2721+
go func() {
2722+
defer close(errc)
2723+
for i := 0; i < numReq; i++ {
2724+
res, err := Get(ts.URL)
2725+
if res != nil {
2726+
res.Body.Close()
2727+
}
2728+
errc <- err
2729+
}
2730+
}()
2731+
2732+
timeout := time.NewTimer(numReq * 2 * time.Second) // 4x overkill
2733+
defer timeout.Stop()
2734+
addrSeen := map[string]bool{}
2735+
numOkay := 0
2736+
for {
2737+
select {
2738+
case v := <-addrc:
2739+
addrSeen[v] = true
2740+
case err, ok := <-errc:
2741+
if !ok {
2742+
if len(addrSeen) != numReq {
2743+
t.Errorf("saw %d unique client addresses; want %d", len(addrSeen), numReq)
2744+
}
2745+
if numOkay != 0 {
2746+
t.Errorf("got %d successful client requests; want 0", numOkay)
2747+
}
2748+
return
2749+
}
2750+
if err == nil {
2751+
numOkay++
2752+
}
2753+
case <-timeout.C:
2754+
t.Fatal("timeout waiting for requests to complete")
2755+
}
2756+
}
2757+
}
2758+
26622759
func BenchmarkClientServer(b *testing.B) {
26632760
b.ReportAllocs()
26642761
b.StopTimer()

src/net/http/server.go

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,8 @@ type conn struct {
114114
remoteAddr string // network address of remote side
115115
server *Server // the Server on which the connection arrived
116116
rwc net.Conn // i/o connection
117+
w io.Writer // checkConnErrorWriter's copy of wrc, not zeroed on Hijack
118+
werr error // any errors writing to w
117119
sr liveSwitchReader // where the LimitReader reads from; usually the rwc
118120
lr *io.LimitedReader // io.LimitReader(sr)
119121
buf *bufio.ReadWriter // buffered(lr,rwc), reading from bufio->limitReader->sr->rwc
@@ -432,13 +434,14 @@ func (srv *Server) newConn(rwc net.Conn) (c *conn, err error) {
432434
c.remoteAddr = rwc.RemoteAddr().String()
433435
c.server = srv
434436
c.rwc = rwc
437+
c.w = rwc
435438
if debugServerConnections {
436439
c.rwc = newLoggingConn("server", c.rwc)
437440
}
438441
c.sr = liveSwitchReader{r: c.rwc}
439442
c.lr = io.LimitReader(&c.sr, noLimit).(*io.LimitedReader)
440443
br := newBufioReader(c.lr)
441-
bw := newBufioWriterSize(c.rwc, 4<<10)
444+
bw := newBufioWriterSize(checkConnErrorWriter{c}, 4<<10)
442445
c.buf = bufio.NewReadWriter(br, bw)
443446
return c, nil
444447
}
@@ -956,8 +959,10 @@ func (w *response) bodyAllowed() bool {
956959
// 2. (*response).w, a *bufio.Writer of bufferBeforeChunkingSize bytes
957960
// 3. chunkWriter.Writer (whose writeHeader finalizes Content-Length/Type)
958961
// and which writes the chunk headers, if needed.
959-
// 4. conn.buf, a bufio.Writer of default (4kB) bytes
960-
// 5. the rwc, the net.Conn.
962+
// 4. conn.buf, a bufio.Writer of default (4kB) bytes, writing to ->
963+
// 5. checkConnErrorWriter{c}, which notes any non-nil error on Write
964+
// and populates c.werr with it if so. but otherwise writes to:
965+
// 6. the rwc, the net.Conn.
961966
//
962967
// TODO(bradfitz): short-circuit some of the buffering when the
963968
// initial header contains both a Content-Type and Content-Length.
@@ -1027,6 +1032,12 @@ func (w *response) finishRequest() {
10271032
// Did not write enough. Avoid getting out of sync.
10281033
w.closeAfterReply = true
10291034
}
1035+
1036+
// There was some error writing to the underlying connection
1037+
// during the request, so don't re-use this conn.
1038+
if w.conn.werr != nil {
1039+
w.closeAfterReply = true
1040+
}
10301041
}
10311042

10321043
func (w *response) Flush() {
@@ -2068,3 +2079,18 @@ func (c *loggingConn) Close() (err error) {
20682079
log.Printf("%s.Close() = %v", c.name, err)
20692080
return
20702081
}
2082+
2083+
// checkConnErrorWriter writes to c.rwc and records any write errors to c.werr.
2084+
// It only contains one field (and a pointer field at that), so it
2085+
// fits in an interface value without an extra allocation.
2086+
type checkConnErrorWriter struct {
2087+
c *conn
2088+
}
2089+
2090+
func (w checkConnErrorWriter) Write(p []byte) (n int, err error) {
2091+
n, err = w.c.w.Write(p) // c.w == c.rwc, except after a hijack, when rwc is nil.
2092+
if err != nil && w.c.werr == nil {
2093+
w.c.werr = err
2094+
}
2095+
return
2096+
}

0 commit comments

Comments
 (0)