Skip to content

fix(http1): only send 100 Continue if request body is polled #2119

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 29, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 151 additions & 22 deletions src/body/body.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use futures_util::TryStreamExt;
use http::HeaderMap;
use http_body::{Body as HttpBody, SizeHint};

use crate::common::{task, Future, Never, Pin, Poll};
use crate::common::{task, watch, Future, Never, Pin, Poll};
use crate::proto::DecodedLength;
use crate::upgrade::OnUpgrade;

Expand All @@ -33,7 +33,7 @@ enum Kind {
Once(Option<Bytes>),
Chan {
content_length: DecodedLength,
abort_rx: oneshot::Receiver<()>,
want_tx: watch::Sender,
rx: mpsc::Receiver<Result<Bytes, crate::Error>>,
},
H2 {
Expand Down Expand Up @@ -79,12 +79,14 @@ enum DelayEof {
/// Useful when wanting to stream chunks from another thread. See
/// [`Body::channel`](Body::channel) for more.
#[must_use = "Sender does nothing unless sent on"]
#[derive(Debug)]
pub struct Sender {
abort_tx: oneshot::Sender<()>,
want_rx: watch::Receiver,
tx: BodySender,
}

const WANT_PENDING: usize = 1;
const WANT_READY: usize = 2;

impl Body {
/// Create an empty `Body` stream.
///
Expand All @@ -106,17 +108,22 @@ impl Body {
/// Useful when wanting to stream chunks from another thread.
#[inline]
pub fn channel() -> (Sender, Body) {
Self::new_channel(DecodedLength::CHUNKED)
Self::new_channel(DecodedLength::CHUNKED, /*wanter =*/ false)
}

pub(crate) fn new_channel(content_length: DecodedLength) -> (Sender, Body) {
pub(crate) fn new_channel(content_length: DecodedLength, wanter: bool) -> (Sender, Body) {
let (tx, rx) = mpsc::channel(0);
let (abort_tx, abort_rx) = oneshot::channel();

let tx = Sender { abort_tx, tx };
// If wanter is true, `Sender::poll_ready()` won't becoming ready
// until the `Body` has been polled for data once.
let want = if wanter { WANT_PENDING } else { WANT_READY };

let (want_tx, want_rx) = watch::channel(want);

let tx = Sender { want_rx, tx };
let rx = Body::new(Kind::Chan {
content_length,
abort_rx,
want_tx,
rx,
});

Expand Down Expand Up @@ -236,11 +243,9 @@ impl Body {
Kind::Chan {
content_length: ref mut len,
ref mut rx,
ref mut abort_rx,
ref mut want_tx,
} => {
if let Poll::Ready(Ok(())) = Pin::new(abort_rx).poll(cx) {
return Poll::Ready(Some(Err(crate::Error::new_body_write_aborted())));
}
want_tx.send(WANT_READY);

match ready!(Pin::new(rx).poll_next(cx)?) {
Some(chunk) => {
Expand Down Expand Up @@ -460,19 +465,29 @@ impl From<Cow<'static, str>> for Body {
impl Sender {
/// Check to see if this `Sender` can send more data.
pub fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> {
match self.abort_tx.poll_canceled(cx) {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems this was important for detecting when the receiver is dropped while the connection is waiting on the socket.

Poll::Ready(()) => return Poll::Ready(Err(crate::Error::new_closed())),
Poll::Pending => (), // fallthrough
}

// Check if the receiver end has tried polling for the body yet
ready!(self.poll_want(cx)?);
self.tx
.poll_ready(cx)
.map_err(|_| crate::Error::new_closed())
}

fn poll_want(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> {
match self.want_rx.load(cx) {
WANT_READY => Poll::Ready(Ok(())),
WANT_PENDING => Poll::Pending,
watch::CLOSED => Poll::Ready(Err(crate::Error::new_closed())),
unexpected => unreachable!("want_rx value: {}", unexpected),
}
}

async fn ready(&mut self) -> crate::Result<()> {
futures_util::future::poll_fn(|cx| self.poll_ready(cx)).await
}

/// Send data on this channel when it is ready.
pub async fn send_data(&mut self, chunk: Bytes) -> crate::Result<()> {
futures_util::future::poll_fn(|cx| self.poll_ready(cx)).await?;
self.ready().await?;
self.tx
.try_send(Ok(chunk))
.map_err(|_| crate::Error::new_closed())
Expand All @@ -498,20 +513,41 @@ impl Sender {

/// Aborts the body in an abnormal fashion.
pub fn abort(self) {
// TODO(sean): this can just be `self.tx.clone().try_send()`
let _ = self.abort_tx.send(());
let _ = self
.tx
// clone so the send works even if buffer is full
.clone()
.try_send(Err(crate::Error::new_body_write_aborted()));
}

pub(crate) fn send_error(&mut self, err: crate::Error) {
let _ = self.tx.try_send(Err(err));
}
}

impl fmt::Debug for Sender {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
#[derive(Debug)]
struct Open;
#[derive(Debug)]
struct Closed;

let mut builder = f.debug_tuple("Sender");
match self.want_rx.peek() {
watch::CLOSED => builder.field(&Closed),
_ => builder.field(&Open),
};

builder.finish()
}
}

#[cfg(test)]
mod tests {
use std::mem;
use std::task::Poll;

use super::{Body, Sender};
use super::{Body, DecodedLength, HttpBody, Sender};

#[test]
fn test_size_of() {
Expand Down Expand Up @@ -541,4 +577,97 @@ mod tests {
"Option<Sender>"
);
}

#[tokio::test]
async fn channel_abort() {
let (tx, mut rx) = Body::channel();

tx.abort();

let err = rx.data().await.unwrap().unwrap_err();
assert!(err.is_body_write_aborted(), "{:?}", err);
}

#[tokio::test]
async fn channel_abort_when_buffer_is_full() {
let (mut tx, mut rx) = Body::channel();

tx.try_send_data("chunk 1".into()).expect("send 1");
// buffer is full, but can still send abort
tx.abort();

let chunk1 = rx.data().await.expect("item 1").expect("chunk 1");
assert_eq!(chunk1, "chunk 1");

let err = rx.data().await.unwrap().unwrap_err();
assert!(err.is_body_write_aborted(), "{:?}", err);
}

#[test]
fn channel_buffers_one() {
let (mut tx, _rx) = Body::channel();

tx.try_send_data("chunk 1".into()).expect("send 1");

// buffer is now full
let chunk2 = tx.try_send_data("chunk 2".into()).expect_err("send 2");
assert_eq!(chunk2, "chunk 2");
}

#[tokio::test]
async fn channel_empty() {
let (_, mut rx) = Body::channel();

assert!(rx.data().await.is_none());
}

#[test]
fn channel_ready() {
let (mut tx, _rx) = Body::new_channel(DecodedLength::CHUNKED, /*wanter = */ false);

let mut tx_ready = tokio_test::task::spawn(tx.ready());

assert!(tx_ready.poll().is_ready(), "tx is ready immediately");
}

#[test]
fn channel_wanter() {
let (mut tx, mut rx) = Body::new_channel(DecodedLength::CHUNKED, /*wanter = */ true);

let mut tx_ready = tokio_test::task::spawn(tx.ready());
let mut rx_data = tokio_test::task::spawn(rx.data());

assert!(
tx_ready.poll().is_pending(),
"tx isn't ready before rx has been polled"
);

assert!(rx_data.poll().is_pending(), "poll rx.data");
assert!(tx_ready.is_woken(), "rx poll wakes tx");

assert!(
tx_ready.poll().is_ready(),
"tx is ready after rx has been polled"
);
}

#[test]
fn channel_notices_closure() {
let (mut tx, rx) = Body::new_channel(DecodedLength::CHUNKED, /*wanter = */ true);

let mut tx_ready = tokio_test::task::spawn(tx.ready());

assert!(
tx_ready.poll().is_pending(),
"tx isn't ready before rx has been polled"
);

drop(rx);
assert!(tx_ready.is_woken(), "dropping rx wakes tx");

match tx_ready.poll() {
Poll::Ready(Err(ref e)) if e.is_closed() => (),
unexpected => panic!("tx poll ready unexpected: {:?}", unexpected),
}
}
}
1 change: 1 addition & 0 deletions src/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ pub(crate) mod io;
mod lazy;
mod never;
pub(crate) mod task;
pub(crate) mod watch;

pub use self::exec::Executor;
pub(crate) use self::exec::{BoxSendFuture, Exec};
Expand Down
73 changes: 73 additions & 0 deletions src/common/watch.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
//! An SPSC broadcast channel.
//!
//! - The value can only be a `usize`.
//! - The consumer is only notified if the value is different.
//! - The value `0` is reserved for closed.

use futures_util::task::AtomicWaker;
use std::sync::{
atomic::{AtomicUsize, Ordering},
Arc,
};
use std::task;

type Value = usize;

pub(crate) const CLOSED: usize = 0;

pub(crate) fn channel(initial: Value) -> (Sender, Receiver) {
debug_assert!(
initial != CLOSED,
"watch::channel initial state of 0 is reserved"
);

let shared = Arc::new(Shared {
value: AtomicUsize::new(initial),
waker: AtomicWaker::new(),
});

(
Sender {
shared: shared.clone(),
},
Receiver { shared },
)
}

pub(crate) struct Sender {
shared: Arc<Shared>,
}

pub(crate) struct Receiver {
shared: Arc<Shared>,
}

struct Shared {
value: AtomicUsize,
waker: AtomicWaker,
}

impl Sender {
pub(crate) fn send(&mut self, value: Value) {
if self.shared.value.swap(value, Ordering::SeqCst) != value {
self.shared.waker.wake();
}
}
}

impl Drop for Sender {
fn drop(&mut self) {
self.send(CLOSED);
}
}

impl Receiver {
pub(crate) fn load(&mut self, cx: &mut task::Context<'_>) -> Value {
self.shared.waker.register(cx.waker());
self.shared.value.load(Ordering::SeqCst)
}

pub(crate) fn peek(&self) -> Value {
self.shared.value.load(Ordering::Relaxed)
}
}
Loading