Skip to content

Commit 2449bbb

Browse files
Shang Dinggopherbot
Shang Ding
authored andcommitted
net/http/httputil: use response controller in reverse proxy
Previously, the reverse proxy is unable to detect the support for hijack or flush if those things are residing in the response writer in a wrapped manner. The reverse proxy now makes use of the new http response controller as the means to discover the underlying flusher and hijacker associated with the response writer, allowing wrapped flusher and hijacker become discoverable. Change-Id: I53acbb12315c3897be068e8c00598ef42fc74649 Reviewed-on: https://go-review.googlesource.com/c/go/+/468755 Run-TryBot: Damien Neil <[email protected]> Auto-Submit: Damien Neil <[email protected]> TryBot-Result: Gopher Robot <[email protected]> Reviewed-by: Damien Neil <[email protected]> Reviewed-by: Cherry Mui <[email protected]>
1 parent 602e6aa commit 2449bbb

File tree

2 files changed

+84
-33
lines changed

2 files changed

+84
-33
lines changed

src/net/http/httputil/reverseproxy.go

Lines changed: 28 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -524,9 +524,7 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
524524
// Force chunking if we saw a response trailer.
525525
// This prevents net/http from calculating the length for short
526526
// bodies and adding a Content-Length.
527-
if fl, ok := rw.(http.Flusher); ok {
528-
fl.Flush()
529-
}
527+
http.NewResponseController(rw).Flush()
530528
}
531529

532530
if len(res.Trailer) == announcedTrailers {
@@ -601,29 +599,30 @@ func (p *ReverseProxy) flushInterval(res *http.Response) time.Duration {
601599
return p.FlushInterval
602600
}
603601

604-
func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader, flushInterval time.Duration) error {
602+
func (p *ReverseProxy) copyResponse(dst http.ResponseWriter, src io.Reader, flushInterval time.Duration) error {
603+
var w io.Writer = dst
604+
605605
if flushInterval != 0 {
606-
if wf, ok := dst.(writeFlusher); ok {
607-
mlw := &maxLatencyWriter{
608-
dst: wf,
609-
latency: flushInterval,
610-
}
611-
defer mlw.stop()
606+
mlw := &maxLatencyWriter{
607+
dst: dst,
608+
flush: http.NewResponseController(dst).Flush,
609+
latency: flushInterval,
610+
}
611+
defer mlw.stop()
612612

613-
// set up initial timer so headers get flushed even if body writes are delayed
614-
mlw.flushPending = true
615-
mlw.t = time.AfterFunc(flushInterval, mlw.delayedFlush)
613+
// set up initial timer so headers get flushed even if body writes are delayed
614+
mlw.flushPending = true
615+
mlw.t = time.AfterFunc(flushInterval, mlw.delayedFlush)
616616

617-
dst = mlw
618-
}
617+
w = mlw
619618
}
620619

621620
var buf []byte
622621
if p.BufferPool != nil {
623622
buf = p.BufferPool.Get()
624623
defer p.BufferPool.Put(buf)
625624
}
626-
_, err := p.copyBuffer(dst, src, buf)
625+
_, err := p.copyBuffer(w, src, buf)
627626
return err
628627
}
629628

@@ -668,13 +667,9 @@ func (p *ReverseProxy) logf(format string, args ...any) {
668667
}
669668
}
670669

671-
type writeFlusher interface {
672-
io.Writer
673-
http.Flusher
674-
}
675-
676670
type maxLatencyWriter struct {
677-
dst writeFlusher
671+
dst io.Writer
672+
flush func() error
678673
latency time.Duration // non-zero; negative means to flush immediately
679674

680675
mu sync.Mutex // protects t, flushPending, and dst.Flush
@@ -687,7 +682,7 @@ func (m *maxLatencyWriter) Write(p []byte) (n int, err error) {
687682
defer m.mu.Unlock()
688683
n, err = m.dst.Write(p)
689684
if m.latency < 0 {
690-
m.dst.Flush()
685+
m.flush()
691686
return
692687
}
693688
if m.flushPending {
@@ -708,7 +703,7 @@ func (m *maxLatencyWriter) delayedFlush() {
708703
if !m.flushPending { // if stop was called but AfterFunc already started this goroutine
709704
return
710705
}
711-
m.dst.Flush()
706+
m.flush()
712707
m.flushPending = false
713708
}
714709

@@ -739,17 +734,19 @@ func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.R
739734
return
740735
}
741736

742-
hj, ok := rw.(http.Hijacker)
743-
if !ok {
744-
p.getErrorHandler()(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw))
745-
return
746-
}
747737
backConn, ok := res.Body.(io.ReadWriteCloser)
748738
if !ok {
749739
p.getErrorHandler()(rw, req, fmt.Errorf("internal error: 101 switching protocols response with non-writable body"))
750740
return
751741
}
752742

743+
rc := http.NewResponseController(rw)
744+
conn, brw, hijackErr := rc.Hijack()
745+
if errors.Is(hijackErr, http.ErrNotSupported) {
746+
p.getErrorHandler()(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw))
747+
return
748+
}
749+
753750
backConnCloseCh := make(chan bool)
754751
go func() {
755752
// Ensure that the cancellation of a request closes the backend.
@@ -760,12 +757,10 @@ func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.R
760757
}
761758
backConn.Close()
762759
}()
763-
764760
defer close(backConnCloseCh)
765761

766-
conn, brw, err := hj.Hijack()
767-
if err != nil {
768-
p.getErrorHandler()(rw, req, fmt.Errorf("Hijack failed on protocol switch: %v", err))
762+
if hijackErr != nil {
763+
p.getErrorHandler()(rw, req, fmt.Errorf("Hijack failed on protocol switch: %v", hijackErr))
769764
return
770765
}
771766
defer conn.Close()

src/net/http/httputil/reverseproxy_test.go

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -478,6 +478,62 @@ func TestReverseProxyFlushInterval(t *testing.T) {
478478
}
479479
}
480480

481+
type mockFlusher struct {
482+
http.ResponseWriter
483+
flushed bool
484+
}
485+
486+
func (m *mockFlusher) Flush() {
487+
m.flushed = true
488+
}
489+
490+
type wrappedRW struct {
491+
http.ResponseWriter
492+
}
493+
494+
func (w *wrappedRW) Unwrap() http.ResponseWriter {
495+
return w.ResponseWriter
496+
}
497+
498+
func TestReverseProxyResponseControllerFlushInterval(t *testing.T) {
499+
const expected = "hi"
500+
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
501+
w.Write([]byte(expected))
502+
}))
503+
defer backend.Close()
504+
505+
backendURL, err := url.Parse(backend.URL)
506+
if err != nil {
507+
t.Fatal(err)
508+
}
509+
510+
mf := &mockFlusher{}
511+
proxyHandler := NewSingleHostReverseProxy(backendURL)
512+
proxyHandler.FlushInterval = -1 // flush immediately
513+
proxyWithMiddleware := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
514+
mf.ResponseWriter = w
515+
w = &wrappedRW{mf}
516+
proxyHandler.ServeHTTP(w, r)
517+
})
518+
519+
frontend := httptest.NewServer(proxyWithMiddleware)
520+
defer frontend.Close()
521+
522+
req, _ := http.NewRequest("GET", frontend.URL, nil)
523+
req.Close = true
524+
res, err := frontend.Client().Do(req)
525+
if err != nil {
526+
t.Fatalf("Get: %v", err)
527+
}
528+
defer res.Body.Close()
529+
if bodyBytes, _ := io.ReadAll(res.Body); string(bodyBytes) != expected {
530+
t.Errorf("got body %q; expected %q", bodyBytes, expected)
531+
}
532+
if !mf.flushed {
533+
t.Errorf("response writer was not flushed")
534+
}
535+
}
536+
481537
func TestReverseProxyFlushIntervalHeaders(t *testing.T) {
482538
const expected = "hi"
483539
stopCh := make(chan struct{})

0 commit comments

Comments
 (0)