@@ -11,7 +11,7 @@ use futures_util::TryStreamExt;
11
11
use http:: HeaderMap ;
12
12
use http_body:: { Body as HttpBody , SizeHint } ;
13
13
14
- use crate :: common:: { task, Future , Never , Pin , Poll } ;
14
+ use crate :: common:: { task, watch , Future , Never , Pin , Poll } ;
15
15
use crate :: proto:: DecodedLength ;
16
16
use crate :: upgrade:: OnUpgrade ;
17
17
@@ -33,7 +33,7 @@ enum Kind {
33
33
Once ( Option < Bytes > ) ,
34
34
Chan {
35
35
content_length : DecodedLength ,
36
- abort_rx : oneshot :: Receiver < ( ) > ,
36
+ want_tx : watch :: Sender ,
37
37
rx : mpsc:: Receiver < Result < Bytes , crate :: Error > > ,
38
38
} ,
39
39
H2 {
@@ -79,12 +79,14 @@ enum DelayEof {
79
79
/// Useful when wanting to stream chunks from another thread. See
80
80
/// [`Body::channel`](Body::channel) for more.
81
81
#[ must_use = "Sender does nothing unless sent on" ]
82
- #[ derive( Debug ) ]
83
82
pub struct Sender {
84
- abort_tx : oneshot :: Sender < ( ) > ,
83
+ want_rx : watch :: Receiver ,
85
84
tx : BodySender ,
86
85
}
87
86
87
+ const WANT_PENDING : usize = 1 ;
88
+ const WANT_READY : usize = 2 ;
89
+
88
90
impl Body {
89
91
/// Create an empty `Body` stream.
90
92
///
@@ -106,17 +108,22 @@ impl Body {
106
108
/// Useful when wanting to stream chunks from another thread.
107
109
#[ inline]
108
110
pub fn channel ( ) -> ( Sender , Body ) {
109
- Self :: new_channel ( DecodedLength :: CHUNKED )
111
+ Self :: new_channel ( DecodedLength :: CHUNKED , /*wanter =*/ false )
110
112
}
111
113
112
- pub ( crate ) fn new_channel ( content_length : DecodedLength ) -> ( Sender , Body ) {
114
+ pub ( crate ) fn new_channel ( content_length : DecodedLength , wanter : bool ) -> ( Sender , Body ) {
113
115
let ( tx, rx) = mpsc:: channel ( 0 ) ;
114
- let ( abort_tx, abort_rx) = oneshot:: channel ( ) ;
115
116
116
- let tx = Sender { abort_tx, tx } ;
117
+ // If wanter is true, `Sender::poll_ready()` won't becoming ready
118
+ // until the `Body` has been polled for data once.
119
+ let want = if wanter { WANT_PENDING } else { WANT_READY } ;
120
+
121
+ let ( want_tx, want_rx) = watch:: channel ( want) ;
122
+
123
+ let tx = Sender { want_rx, tx } ;
117
124
let rx = Body :: new ( Kind :: Chan {
118
125
content_length,
119
- abort_rx ,
126
+ want_tx ,
120
127
rx,
121
128
} ) ;
122
129
@@ -236,11 +243,9 @@ impl Body {
236
243
Kind :: Chan {
237
244
content_length : ref mut len,
238
245
ref mut rx,
239
- ref mut abort_rx ,
246
+ ref mut want_tx ,
240
247
} => {
241
- if let Poll :: Ready ( Ok ( ( ) ) ) = Pin :: new ( abort_rx) . poll ( cx) {
242
- return Poll :: Ready ( Some ( Err ( crate :: Error :: new_body_write_aborted ( ) ) ) ) ;
243
- }
248
+ want_tx. send ( WANT_READY ) ;
244
249
245
250
match ready ! ( Pin :: new( rx) . poll_next( cx) ?) {
246
251
Some ( chunk) => {
@@ -460,19 +465,29 @@ impl From<Cow<'static, str>> for Body {
460
465
impl Sender {
461
466
/// Check to see if this `Sender` can send more data.
462
467
pub fn poll_ready ( & mut self , cx : & mut task:: Context < ' _ > ) -> Poll < crate :: Result < ( ) > > {
463
- match self . abort_tx . poll_canceled ( cx) {
464
- Poll :: Ready ( ( ) ) => return Poll :: Ready ( Err ( crate :: Error :: new_closed ( ) ) ) ,
465
- Poll :: Pending => ( ) , // fallthrough
466
- }
467
-
468
+ // Check if the receiver end has tried polling for the body yet
469
+ ready ! ( self . poll_want( cx) ?) ;
468
470
self . tx
469
471
. poll_ready ( cx)
470
472
. map_err ( |_| crate :: Error :: new_closed ( ) )
471
473
}
472
474
475
+ fn poll_want ( & mut self , cx : & mut task:: Context < ' _ > ) -> Poll < crate :: Result < ( ) > > {
476
+ match self . want_rx . load ( cx) {
477
+ WANT_READY => Poll :: Ready ( Ok ( ( ) ) ) ,
478
+ WANT_PENDING => Poll :: Pending ,
479
+ watch:: CLOSED => Poll :: Ready ( Err ( crate :: Error :: new_closed ( ) ) ) ,
480
+ unexpected => unreachable ! ( "want_rx value: {}" , unexpected) ,
481
+ }
482
+ }
483
+
484
+ async fn ready ( & mut self ) -> crate :: Result < ( ) > {
485
+ futures_util:: future:: poll_fn ( |cx| self . poll_ready ( cx) ) . await
486
+ }
487
+
473
488
/// Send data on this channel when it is ready.
474
489
pub async fn send_data ( & mut self , chunk : Bytes ) -> crate :: Result < ( ) > {
475
- futures_util :: future :: poll_fn ( |cx| self . poll_ready ( cx ) ) . await ?;
490
+ self . ready ( ) . await ?;
476
491
self . tx
477
492
. try_send ( Ok ( chunk) )
478
493
. map_err ( |_| crate :: Error :: new_closed ( ) )
@@ -498,20 +513,30 @@ impl Sender {
498
513
499
514
/// Aborts the body in an abnormal fashion.
500
515
pub fn abort ( self ) {
501
- // TODO(sean): this can just be `self.tx.clone().try_send()`
502
- let _ = self . abort_tx . send ( ( ) ) ;
516
+ let _ = self
517
+ . tx
518
+ // clone so the send works even if buffer is full
519
+ . clone ( )
520
+ . try_send ( Err ( crate :: Error :: new_body_write_aborted ( ) ) ) ;
503
521
}
504
522
505
523
pub ( crate ) fn send_error ( & mut self , err : crate :: Error ) {
506
524
let _ = self . tx . try_send ( Err ( err) ) ;
507
525
}
508
526
}
509
527
528
+ impl fmt:: Debug for Sender {
529
+ fn fmt ( & self , f : & mut fmt:: Formatter < ' _ > ) -> fmt:: Result {
530
+ f. debug_struct ( "Sender" ) . finish ( )
531
+ }
532
+ }
533
+
510
534
#[ cfg( test) ]
511
535
mod tests {
512
536
use std:: mem;
537
+ use std:: task:: Poll ;
513
538
514
- use super :: { Body , Sender } ;
539
+ use super :: { Body , DecodedLength , HttpBody , Sender } ;
515
540
516
541
#[ test]
517
542
fn test_size_of ( ) {
@@ -541,4 +566,97 @@ mod tests {
541
566
"Option<Sender>"
542
567
) ;
543
568
}
569
+
570
+ #[ tokio:: test]
571
+ async fn channel_abort ( ) {
572
+ let ( tx, mut rx) = Body :: channel ( ) ;
573
+
574
+ tx. abort ( ) ;
575
+
576
+ let err = rx. data ( ) . await . unwrap ( ) . unwrap_err ( ) ;
577
+ assert ! ( err. is_body_write_aborted( ) , "{:?}" , err) ;
578
+ }
579
+
580
+ #[ tokio:: test]
581
+ async fn channel_abort_when_buffer_is_full ( ) {
582
+ let ( mut tx, mut rx) = Body :: channel ( ) ;
583
+
584
+ tx. try_send_data ( "chunk 1" . into ( ) ) . expect ( "send 1" ) ;
585
+ // buffer is full, but can still send abort
586
+ tx. abort ( ) ;
587
+
588
+ let chunk1 = rx. data ( ) . await . expect ( "item 1" ) . expect ( "chunk 1" ) ;
589
+ assert_eq ! ( chunk1, "chunk 1" ) ;
590
+
591
+ let err = rx. data ( ) . await . unwrap ( ) . unwrap_err ( ) ;
592
+ assert ! ( err. is_body_write_aborted( ) , "{:?}" , err) ;
593
+ }
594
+
595
+ #[ test]
596
+ fn channel_buffers_one ( ) {
597
+ let ( mut tx, _rx) = Body :: channel ( ) ;
598
+
599
+ tx. try_send_data ( "chunk 1" . into ( ) ) . expect ( "send 1" ) ;
600
+
601
+ // buffer is now full
602
+ let chunk2 = tx. try_send_data ( "chunk 2" . into ( ) ) . expect_err ( "send 2" ) ;
603
+ assert_eq ! ( chunk2, "chunk 2" ) ;
604
+ }
605
+
606
+ #[ tokio:: test]
607
+ async fn channel_empty ( ) {
608
+ let ( _, mut rx) = Body :: channel ( ) ;
609
+
610
+ assert ! ( rx. data( ) . await . is_none( ) ) ;
611
+ }
612
+
613
+ #[ test]
614
+ fn channel_ready ( ) {
615
+ let ( mut tx, _rx) = Body :: new_channel ( DecodedLength :: CHUNKED , /*wanter = */ false ) ;
616
+
617
+ let mut tx_ready = tokio_test:: task:: spawn ( tx. ready ( ) ) ;
618
+
619
+ assert ! ( tx_ready. poll( ) . is_ready( ) , "tx is ready immediately" ) ;
620
+ }
621
+
622
+ #[ test]
623
+ fn channel_wanter ( ) {
624
+ let ( mut tx, mut rx) = Body :: new_channel ( DecodedLength :: CHUNKED , /*wanter = */ true ) ;
625
+
626
+ let mut tx_ready = tokio_test:: task:: spawn ( tx. ready ( ) ) ;
627
+ let mut rx_data = tokio_test:: task:: spawn ( rx. data ( ) ) ;
628
+
629
+ assert ! (
630
+ tx_ready. poll( ) . is_pending( ) ,
631
+ "tx isn't ready before rx has been polled"
632
+ ) ;
633
+
634
+ assert ! ( rx_data. poll( ) . is_pending( ) , "poll rx.data" ) ;
635
+ assert ! ( tx_ready. is_woken( ) , "rx poll wakes tx" ) ;
636
+
637
+ assert ! (
638
+ tx_ready. poll( ) . is_ready( ) ,
639
+ "tx is ready after rx has been polled"
640
+ ) ;
641
+ }
642
+
643
+ #[ test]
644
+ fn channel_notices_closure ( ) {
645
+ let ( mut tx, rx) = Body :: new_channel ( DecodedLength :: CHUNKED , /*wanter = */ true ) ;
646
+
647
+ let mut tx_ready = tokio_test:: task:: spawn ( tx. ready ( ) ) ;
648
+
649
+ assert ! (
650
+ tx_ready. poll( ) . is_pending( ) ,
651
+ "tx isn't ready before rx has been polled"
652
+ ) ;
653
+
654
+ drop ( rx) ;
655
+ assert ! ( tx_ready. is_woken( ) , "dropping rx wakes tx" ) ;
656
+
657
+ match tx_ready. poll ( ) {
658
+ Poll :: Ready ( Err ( ref e) ) if e. is_closed ( ) => ( ) ,
659
+ unexpected => panic ! ( "tx poll ready unexpected: {:?}" , unexpected) ,
660
+ }
661
+ }
544
662
}
0 commit comments