@@ -524,9 +524,7 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
524
524
// Force chunking if we saw a response trailer.
525
525
// This prevents net/http from calculating the length for short
526
526
// bodies and adding a Content-Length.
527
- if fl , ok := rw .(http.Flusher ); ok {
528
- fl .Flush ()
529
- }
527
+ http .NewResponseController (rw ).Flush ()
530
528
}
531
529
532
530
if len (res .Trailer ) == announcedTrailers {
@@ -601,29 +599,30 @@ func (p *ReverseProxy) flushInterval(res *http.Response) time.Duration {
601
599
return p .FlushInterval
602
600
}
603
601
604
- func (p * ReverseProxy ) copyResponse (dst io.Writer , src io.Reader , flushInterval time.Duration ) error {
602
+ func (p * ReverseProxy ) copyResponse (dst http.ResponseWriter , src io.Reader , flushInterval time.Duration ) error {
603
+ var w io.Writer = dst
604
+
605
605
if flushInterval != 0 {
606
- if wf , ok := dst .( writeFlusher ); ok {
607
- mlw := & maxLatencyWriter {
608
- dst : wf ,
609
- latency : flushInterval ,
610
- }
611
- defer mlw .stop ()
606
+ mlw := & maxLatencyWriter {
607
+ dst : dst ,
608
+ flush : http . NewResponseController ( dst ). Flush ,
609
+ latency : flushInterval ,
610
+ }
611
+ defer mlw .stop ()
612
612
613
- // set up initial timer so headers get flushed even if body writes are delayed
614
- mlw .flushPending = true
615
- mlw .t = time .AfterFunc (flushInterval , mlw .delayedFlush )
613
+ // set up initial timer so headers get flushed even if body writes are delayed
614
+ mlw .flushPending = true
615
+ mlw .t = time .AfterFunc (flushInterval , mlw .delayedFlush )
616
616
617
- dst = mlw
618
- }
617
+ w = mlw
619
618
}
620
619
621
620
var buf []byte
622
621
if p .BufferPool != nil {
623
622
buf = p .BufferPool .Get ()
624
623
defer p .BufferPool .Put (buf )
625
624
}
626
- _ , err := p .copyBuffer (dst , src , buf )
625
+ _ , err := p .copyBuffer (w , src , buf )
627
626
return err
628
627
}
629
628
@@ -668,13 +667,9 @@ func (p *ReverseProxy) logf(format string, args ...any) {
668
667
}
669
668
}
670
669
671
- type writeFlusher interface {
672
- io.Writer
673
- http.Flusher
674
- }
675
-
676
670
type maxLatencyWriter struct {
677
- dst writeFlusher
671
+ dst io.Writer
672
+ flush func () error
678
673
latency time.Duration // non-zero; negative means to flush immediately
679
674
680
675
mu sync.Mutex // protects t, flushPending, and dst.Flush
@@ -687,7 +682,7 @@ func (m *maxLatencyWriter) Write(p []byte) (n int, err error) {
687
682
defer m .mu .Unlock ()
688
683
n , err = m .dst .Write (p )
689
684
if m .latency < 0 {
690
- m .dst . Flush ()
685
+ m .flush ()
691
686
return
692
687
}
693
688
if m .flushPending {
@@ -708,7 +703,7 @@ func (m *maxLatencyWriter) delayedFlush() {
708
703
if ! m .flushPending { // if stop was called but AfterFunc already started this goroutine
709
704
return
710
705
}
711
- m .dst . Flush ()
706
+ m .flush ()
712
707
m .flushPending = false
713
708
}
714
709
@@ -739,17 +734,19 @@ func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.R
739
734
return
740
735
}
741
736
742
- hj , ok := rw .(http.Hijacker )
743
- if ! ok {
744
- p .getErrorHandler ()(rw , req , fmt .Errorf ("can't switch protocols using non-Hijacker ResponseWriter type %T" , rw ))
745
- return
746
- }
747
737
backConn , ok := res .Body .(io.ReadWriteCloser )
748
738
if ! ok {
749
739
p .getErrorHandler ()(rw , req , fmt .Errorf ("internal error: 101 switching protocols response with non-writable body" ))
750
740
return
751
741
}
752
742
743
+ rc := http .NewResponseController (rw )
744
+ conn , brw , hijackErr := rc .Hijack ()
745
+ if errors .Is (hijackErr , http .ErrNotSupported ) {
746
+ p .getErrorHandler ()(rw , req , fmt .Errorf ("can't switch protocols using non-Hijacker ResponseWriter type %T" , rw ))
747
+ return
748
+ }
749
+
753
750
backConnCloseCh := make (chan bool )
754
751
go func () {
755
752
// Ensure that the cancellation of a request closes the backend.
@@ -760,12 +757,10 @@ func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.R
760
757
}
761
758
backConn .Close ()
762
759
}()
763
-
764
760
defer close (backConnCloseCh )
765
761
766
- conn , brw , err := hj .Hijack ()
767
- if err != nil {
768
- p .getErrorHandler ()(rw , req , fmt .Errorf ("Hijack failed on protocol switch: %v" , err ))
762
+ if hijackErr != nil {
763
+ p .getErrorHandler ()(rw , req , fmt .Errorf ("Hijack failed on protocol switch: %v" , hijackErr ))
769
764
return
770
765
}
771
766
defer conn .Close ()
0 commit comments