@@ -624,6 +624,7 @@ type stream struct {
624
624
wroteHeaders bool // whether we wrote headers (not status 100)
625
625
readDeadline * time.Timer // nil if unused
626
626
writeDeadline * time.Timer // nil if unused
627
+ closeErr error // set before cw is closed
627
628
628
629
trailer http.Header // accumulated trailers
629
630
reqTrailer http.Header // handler's Request.Trailer
@@ -1608,6 +1609,14 @@ func (sc *serverConn) closeStream(st *stream, err error) {
1608
1609
1609
1610
p .CloseWithError (err )
1610
1611
}
1612
+ if e , ok := err .(StreamError ); ok {
1613
+ if e .Cause != nil {
1614
+ err = e .Cause
1615
+ } else {
1616
+ err = errStreamClosed
1617
+ }
1618
+ }
1619
+ st .closeErr = err
1611
1620
st .cw .Close () // signals Handler's CloseNotifier, unblocks writes, etc
1612
1621
sc .writeSched .CloseStream (st .id )
1613
1622
}
@@ -1857,7 +1866,11 @@ func (st *stream) onReadTimeout() {
1857
1866
// onWriteTimeout is run on its own goroutine (from time.AfterFunc)
1858
1867
// when the stream's WriteTimeout has fired.
1859
1868
func (st * stream ) onWriteTimeout () {
1860
- st .sc .writeFrameFromHandler (FrameWriteRequest {write : streamError (st .id , ErrCodeInternal )})
1869
+ st .sc .writeFrameFromHandler (FrameWriteRequest {write : StreamError {
1870
+ StreamID : st .id ,
1871
+ Code : ErrCodeInternal ,
1872
+ Cause : os .ErrDeadlineExceeded ,
1873
+ }})
1861
1874
}
1862
1875
1863
1876
func (sc * serverConn ) processHeaders (f * MetaHeadersFrame ) error {
@@ -2471,7 +2484,15 @@ type responseWriterState struct {
2471
2484
2472
2485
type chunkWriter struct { rws * responseWriterState }
2473
2486
2474
- func (cw chunkWriter ) Write (p []byte ) (n int , err error ) { return cw .rws .writeChunk (p ) }
2487
+ func (cw chunkWriter ) Write (p []byte ) (n int , err error ) {
2488
+ n , err = cw .rws .writeChunk (p )
2489
+ if err == errStreamClosed {
2490
+ // If writing failed because the stream has been closed,
2491
+ // return the reason it was closed.
2492
+ err = cw .rws .stream .closeErr
2493
+ }
2494
+ return n , err
2495
+ }
2475
2496
2476
2497
func (rws * responseWriterState ) hasTrailers () bool { return len (rws .trailers ) > 0 }
2477
2498
0 commit comments