Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 18 additions & 8 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion dropshot/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ version = "0.8.12"
features = [ "uuid1" ]

[dev-dependencies]
buf-list = "1.0.0"
expectorate = "1.0.6"
hyper-rustls = "0.23.2"
hyper-staticfile = "0.9"
Expand Down Expand Up @@ -111,4 +112,4 @@ features = [ "max_level_trace", "release_max_level_debug" ]
version_check = "0.9.4"

[features]
usdt-probes = [ "usdt/asm" ]
usdt-probes = ["usdt/asm"]
235 changes: 201 additions & 34 deletions dropshot/src/extractor/body.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,25 @@ use crate::api_description::ApiEndpointParameter;
use crate::api_description::ApiSchemaGenerator;
use crate::api_description::{ApiEndpointBodyContentType, ExtensionMode};
use crate::error::HttpError;
use crate::http_util::http_read_body;
use crate::http_util::http_dump_body;
use crate::http_util::CONTENT_TYPE_JSON;
use crate::schema_util::make_subschema_for;
use crate::server::ServerContext;
use crate::ExclusiveExtractor;
use crate::ExtractorMetadata;
use crate::RequestContext;
use async_trait::async_trait;
use bytes::BufMut;
use bytes::Bytes;
use bytes::BytesMut;
use futures::Stream;
use futures::TryStreamExt;
use hyper::body::HttpBody;
use schemars::schema::InstanceType;
use schemars::schema::SchemaObject;
use schemars::JsonSchema;
use serde::de::DeserializeOwned;
use std::convert::Infallible;
use std::fmt::Debug;

// TypedBody: body extractor for formats that can be deserialized to a specific
Expand Down Expand Up @@ -46,23 +52,22 @@ impl<BodyType: JsonSchema + DeserializeOwned + Send + Sync>
/// to the content type, and deserialize it to an instance of `BodyType`.
async fn http_request_load_body<Context: ServerContext, BodyType>(
rqctx: &RequestContext<Context>,
mut request: hyper::Request<hyper::Body>,
request: hyper::Request<hyper::Body>,
) -> Result<TypedBody<BodyType>, HttpError>
where
BodyType: JsonSchema + DeserializeOwned + Send + Sync,
{
let server = &rqctx.server;
let body = http_read_body(
request.body_mut(),
server.config.request_body_max_bytes,
)
.await?;
let (parts, body) = request.into_parts();
let body = StreamingBody::new(body, server.config.request_body_max_bytes)
.into_bytes_mut()
.await?;

// RFC 7231 §3.1.1.1: media types are case insensitive and may
// be followed by whitespace and/or a parameter (e.g., charset),
// which we currently ignore.
let content_type = request
.headers()
let content_type = parts
.headers
.get(http::header::CONTENT_TYPE)
.map(|hv| {
hv.to_str().map_err(|e| {
Expand Down Expand Up @@ -184,38 +189,200 @@ impl UntypedBody {
impl ExclusiveExtractor for UntypedBody {
async fn from_request<Context: ServerContext>(
rqctx: &RequestContext<Context>,
mut request: hyper::Request<hyper::Body>,
request: hyper::Request<hyper::Body>,
) -> Result<UntypedBody, HttpError> {
let server = &rqctx.server;
let body_bytes = http_read_body(
request.body_mut(),
server.config.request_body_max_bytes,
)
.await?;
Ok(UntypedBody { content: body_bytes })
let body = request.into_body();
let body_bytes =
StreamingBody::new(body, server.config.request_body_max_bytes)
.into_bytes_mut()
.await?;
Ok(UntypedBody { content: body_bytes.freeze() })
}

fn metadata(
_content_type: ApiEndpointBodyContentType,
) -> ExtractorMetadata {
ExtractorMetadata {
parameters: vec![ApiEndpointParameter::new_body(
ApiEndpointBodyContentType::Bytes,
true,
ApiSchemaGenerator::Static {
schema: Box::new(
SchemaObject {
instance_type: Some(InstanceType::String.into()),
format: Some(String::from("binary")),
..Default::default()
}
.into(),
),
dependencies: indexmap::IndexMap::default(),
},
vec![],
)],
extension_mode: ExtensionMode::None,
untyped_metadata()
}
}

// StreamingBody: body extractor that provides a streaming representation of the body.

/// An extractor for streaming the contents of the HTTP request body, making the
/// raw bytes available to the consumer.
///
/// The body is not read when the extractor is constructed; bytes are pulled
/// from the underlying `hyper::Body` only as the stream returned by
/// `into_stream` is polled.
#[derive(Debug)]
pub struct StreamingBody {
/// The not-yet-consumed request body.
body: hyper::Body,
/// Maximum total number of body bytes that may be read before the stream
/// fails. When built by the extractor this is the server's
/// `request_body_max_bytes` configuration value.
cap: usize,
}

impl StreamingBody {
    /// Wraps `body`, allowing at most `cap` bytes in total to be read from it.
    fn new(body: hyper::Body, cap: usize) -> Self {
        Self { body, cap }
    }

    /// Not part of the public API. Used only for doctests.
    ///
    /// Builds a `StreamingBody` whose body is the single chunk `data` and
    /// whose size cap is exactly `data.len()`.
    #[doc(hidden)]
    pub fn __from_bytes(data: Bytes) -> Self {
        let cap = data.len();
        let stream = futures::stream::iter([Ok::<_, Infallible>(data)]);
        let body = hyper::Body::wrap_stream(stream);
        Self { body, cap }
    }

    /// Converts `self` into a stream.
    ///
    /// The `Stream` produces values of type `Result<Bytes, HttpError>`.
    ///
    /// # Errors
    ///
    /// The stream produces an [`HttpError`] if any of the following cases occur:
    ///
    /// * A network error occurred.
    /// * `request_body_max_bytes` was exceeded for this request.
    ///
    /// # Examples
    ///
    /// Buffer a `StreamingBody` in-memory, into a
    /// [`BufList`](https://docs.rs/buf-list/latest/buf_list/struct.BufList.html)
    /// (a segmented list of [`Bytes`] chunks).
    ///
    /// ```
    /// use buf_list::BufList;
    /// use dropshot::{HttpError, StreamingBody};
    /// use futures::prelude::*;
    /// # use std::iter::FromIterator;
    ///
    /// async fn into_buf_list(body: StreamingBody) -> Result<BufList, HttpError> {
    ///     body.into_stream().try_collect().await
    /// }
    ///
    /// # #[tokio::main]
    /// # async fn main() {
    /// #    let body = StreamingBody::__from_bytes(bytes::Bytes::from("foobar"));
    /// #    assert_eq!(
    /// #        into_buf_list(body).await.unwrap().into_iter().next(),
    /// #        Some(bytes::Bytes::from("foobar")),
    /// #    );
    /// # }
    /// ```
    ///
    /// ---
    ///
    /// Write a `StreamingBody` to an [`AsyncWrite`](tokio::io::AsyncWrite),
    /// for example a [`tokio::fs::File`], without buffering it into memory:
    ///
    /// ```
    /// use dropshot::{HttpError, StreamingBody};
    /// use futures::prelude::*;
    /// use tokio::io::{AsyncWrite, AsyncWriteExt};
    ///
    /// async fn write_all<W: AsyncWrite + Unpin>(
    ///     body: StreamingBody,
    ///     writer: &mut W,
    /// ) -> Result<(), HttpError> {
    ///     let stream = body.into_stream();
    ///     tokio::pin!(stream);
    ///
    ///     while let Some(res) = stream.next().await {
    ///         let mut data = res?;
    ///         writer.write_all_buf(&mut data).await.map_err(|error| {
    ///             HttpError::for_unavail(None, format!("write failed: {error}"))
    ///         })?;
    ///     }
    ///
    ///     Ok(())
    /// }
    ///
    /// # #[tokio::main]
    /// # async fn main() {
    /// #    let body = StreamingBody::__from_bytes(bytes::Bytes::from("foobar"));
    /// #    let mut writer = vec![];
    /// #    write_all(body, &mut writer).await.unwrap();
    /// #    assert_eq!(writer, &b"foobar"[..]);
    /// # }
    /// ```
    pub fn into_stream(
        mut self,
    ) -> impl Stream<Item = Result<Bytes, HttpError>> + Send {
        async_stream::try_stream! {
            // Invariant: `bytes_read <= self.cap` at the top of every
            // iteration, because a chunk is only counted after it has been
            // shown to fit. The subtraction below therefore cannot underflow.
            let mut bytes_read: usize = 0;
            while let Some(buf_res) = self.body.data().await {
                let buf = buf_res?;
                let len = buf.len();

                // Same condition as `bytes_read + len > self.cap`, written as
                // a subtraction so the check itself can never overflow.
                if len > self.cap - bytes_read {
                    // Hand the remainder of the body to `http_dump_body`
                    // before reporting the failure; the chunk that crossed
                    // the limit is never yielded to the consumer.
                    http_dump_body(&mut self.body).await?;
                    // TODO-correctness check status code
                    Err(HttpError::for_bad_request(
                        None,
                        format!("request body exceeded maximum size of {} bytes", self.cap),
                    ))?;
                }

                bytes_read += len;
                yield buf;
            }

            // Read the trailers as well, even though we're not going to do anything
            // with them.
            self.body.trailers().await?;
        }
    }

    /// Converts `self` into a [`BytesMut`], buffering the entire request body
    /// in memory. Not public API because most users of this should use
    /// `UntypedBody` instead.
    ///
    /// # Errors
    ///
    /// Fails with the first [`HttpError`] produced by the underlying stream
    /// (see `into_stream`); any bytes accumulated so far are discarded.
    async fn into_bytes_mut(self) -> Result<BytesMut, HttpError> {
        self.into_stream()
            .try_fold(BytesMut::new(), |mut out, chunk| {
                out.put(chunk);
                futures::future::ok(out)
            })
            .await
    }
}

#[async_trait]
impl ExclusiveExtractor for StreamingBody {
async fn from_request<Context: ServerContext>(
rqctx: &RequestContext<Context>,
request: hyper::Request<hyper::Body>,
) -> Result<Self, HttpError> {
let server = &rqctx.server;

Ok(Self {
body: request.into_body(),
cap: server.config.request_body_max_bytes,
})
}

fn metadata(
_content_type: ApiEndpointBodyContentType,
) -> ExtractorMetadata {
untyped_metadata()
}
}

/// Builds the extractor metadata used by the raw-body extractors in this
/// module.
///
/// The result has exactly one body parameter with content type
/// `ApiEndpointBodyContentType::Bytes`, whose static schema is a string with
/// format `"binary"` and no schema dependencies. No extension mode is set.
fn untyped_metadata() -> ExtractorMetadata {
    // Schema for the body: `{ "type": "string", "format": "binary" }`.
    let binary_schema = SchemaObject {
        instance_type: Some(InstanceType::String.into()),
        format: Some(String::from("binary")),
        ..Default::default()
    };

    let schema_generator = ApiSchemaGenerator::Static {
        schema: Box::new(binary_schema.into()),
        dependencies: indexmap::IndexMap::default(),
    };

    // NOTE(review): the `true` flag is presumably "required" -- confirm
    // against `ApiEndpointParameter::new_body`.
    let body_parameter = ApiEndpointParameter::new_body(
        ApiEndpointBodyContentType::Bytes,
        true,
        schema_generator,
        vec![],
    );

    ExtractorMetadata {
        parameters: vec![body_parameter],
        extension_mode: ExtensionMode::None,
    }
}
1 change: 1 addition & 0 deletions dropshot/src/extractor/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ pub use common::RequestExtractor;
pub use common::SharedExtractor;

mod body;
pub use body::StreamingBody;
pub use body::TypedBody;
pub use body::UntypedBody;

Expand Down
Loading