From 6f847bdfba52178564b208858ef247343a235b8d Mon Sep 17 00:00:00 2001 From: Rain Date: Mon, 20 Mar 2023 19:42:45 -0700 Subject: [PATCH 1/3] =?UTF-8?q?[=F0=9D=98=80=F0=9D=97=BD=F0=9D=97=BF]=20in?= =?UTF-8?q?itial=20version?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Created using spr 1.3.4 --- Cargo.lock | 43 +++- dropshot/Cargo.toml | 4 +- dropshot/src/extractor/body.rs | 281 +++++++++++++++++++--- dropshot/src/extractor/mod.rs | 1 + dropshot/src/http_util.rs | 51 ---- dropshot/src/lib.rs | 15 +- dropshot/tests/fail/bad_endpoint1.stderr | 3 +- dropshot/tests/fail/bad_endpoint11.stderr | 3 +- dropshot/tests/fail/bad_endpoint13.stderr | 3 +- dropshot/tests/fail/bad_endpoint2.stderr | 3 +- dropshot/tests/fail/bad_endpoint8.stderr | 3 +- dropshot/tests/test_demo.rs | 74 ++++++ dropshot_endpoint/src/lib.rs | 3 +- 13 files changed, 381 insertions(+), 106 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ab7b7f325..c45934d03 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -57,9 +57,9 @@ dependencies = [ [[package]] name = "autocfg" -version = "1.0.1" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cdb031dd78e28731d87d56cc8ffef4a8f36ca26c38fe2de700543e627f8a464a" +checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" [[package]] name = "base64" @@ -118,6 +118,15 @@ dependencies = [ "byte-tools", ] +[[package]] +name = "buf-list" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f462e45b27db47403356859af1cb4bfbbd0021cb7b7d10db6ea40958bb4e2c48" +dependencies = [ + "bytes", +] + [[package]] name = "bumpalo" version = "3.12.0" @@ -304,6 +313,7 @@ dependencies = [ "async-stream", "async-trait", "base64 0.21.0", + "buf-list", "bytes", "camino", "chrono", @@ -344,6 +354,7 @@ dependencies = [ "tokio", "tokio-rustls", "tokio-tungstenite", + "tokio-util 0.7.7", "toml", "trybuild", "usdt", @@ -588,7 +599,7 @@ dependencies = [ "indexmap", "slab", "tokio", - "tokio-util", + "tokio-util 0.6.8", "tracing", ] @@ -1039,9 +1050,9 @@ dependencies = [ [[package]] name = "pin-project-lite" -version = "0.2.7" +version = "0.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d31d11c69a6b52a174b42bdc0c30e5e11670f90788b2c471c31c1d17d449443" +checksum = "e0a7ae3ac2f1173085d398531c705756c94a4c56843785df85a60c1a0afac116" [[package]] name = "pin-utils" @@ -1710,22 +1721,22 @@ checksum = "cda74da7e1a664f795bb1f8a87ec406fb89a02522cf6e50620d016add6dbbf5c" [[package]] name = "tokio" -version = "1.19.2" +version = "1.25.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c51a52ed6686dd62c320f9b89299e9dfb46f730c7a48e635c19f21d116cb1439" +checksum = "c8e00990ebabbe4c14c08aca901caed183ecd5c09562a12c824bb53d3c3fd3af" dependencies = [ + "autocfg", "bytes", "libc", "memchr", "mio", "num_cpus", - "once_cell", "parking_lot", "pin-project-lite", "signal-hook-registry", "socket2", "tokio-macros", - "winapi", + "windows-sys", ] [[package]] @@ -1776,6 +1787,20 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-util" +version = "0.7.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5427d89453009325de0d8f342c9490009f76e999cb7672d77e46267448f7e6b2" +dependencies = [ + "bytes", + "futures-core", + "futures-sink", + "pin-project-lite", + "tokio", + "tracing", +] + [[package]] name = "toml" version = "0.7.3" diff --git a/dropshot/Cargo.toml b/dropshot/Cargo.toml index 
63ede80ff..10a18a013 100644 --- a/dropshot/Cargo.toml +++ b/dropshot/Cargo.toml @@ -76,6 +76,7 @@ version = "0.8.12" features = [ "uuid1" ] [dev-dependencies] +buf-list = "1.0.0" expectorate = "1.0.6" hyper-rustls = "0.23.2" hyper-staticfile = "0.9" @@ -84,6 +85,7 @@ libc = "0.2.140" mime_guess = "2.0.4" subprocess = "0.2.9" tempfile = "3.4" +tokio-util = { version = "0.7.7", features = ["codec"] } trybuild = "1.0.79" # Used by the https examples and tests pem = "1.1" @@ -111,4 +113,4 @@ features = [ "max_level_trace", "release_max_level_debug" ] version_check = "0.9.4" [features] -usdt-probes = [ "usdt/asm" ] +usdt-probes = ["usdt/asm"] diff --git a/dropshot/src/extractor/body.rs b/dropshot/src/extractor/body.rs index 1ffa566ba..82b457f38 100644 --- a/dropshot/src/extractor/body.rs +++ b/dropshot/src/extractor/body.rs @@ -6,7 +6,7 @@ use crate::api_description::ApiEndpointParameter; use crate::api_description::ApiSchemaGenerator; use crate::api_description::{ApiEndpointBodyContentType, ExtensionMode}; use crate::error::HttpError; -use crate::http_util::http_read_body; +use crate::http_util::http_dump_body; use crate::http_util::CONTENT_TYPE_JSON; use crate::schema_util::make_subschema_for; use crate::server::ServerContext; @@ -14,11 +14,17 @@ use crate::ExclusiveExtractor; use crate::ExtractorMetadata; use crate::RequestContext; use async_trait::async_trait; +use bytes::BufMut; use bytes::Bytes; +use bytes::BytesMut; +use futures::Stream; +use futures::TryStreamExt; +use hyper::body::HttpBody; use schemars::schema::InstanceType; use schemars::schema::SchemaObject; use schemars::JsonSchema; use serde::de::DeserializeOwned; +use std::convert::Infallible; use std::fmt::Debug; // TypedBody: body extractor for formats that can be deserialized to a specific @@ -46,23 +52,22 @@ impl /// to the content type, and deserialize it to an instance of `BodyType`. async fn http_request_load_body( rqctx: &RequestContext, - mut request: hyper::Request, + request: hyper::Request, ) -> Result, HttpError> where BodyType: JsonSchema + DeserializeOwned + Send + Sync, { let server = &rqctx.server; - let body = http_read_body( - request.body_mut(), - server.config.request_body_max_bytes, - ) - .await?; + let (parts, body) = request.into_parts(); + let body = StreamingBody::new(body, server.config.request_body_max_bytes) + .into_bytes_mut() + .await?; // RFC 7231 ยง3.1.1.1: media types are case insensitive and may // be followed by whitespace and/or a parameter (e.g., charset), // which we currently ignore. 
- let content_type = request - .headers() + let content_type = parts + .headers .get(http::header::CONTENT_TYPE) .map(|hv| { hv.to_str().map_err(|e| { @@ -184,38 +189,246 @@ impl UntypedBody { impl ExclusiveExtractor for UntypedBody { async fn from_request( rqctx: &RequestContext, - mut request: hyper::Request, + request: hyper::Request, ) -> Result { let server = &rqctx.server; - let body_bytes = http_read_body( - request.body_mut(), - server.config.request_body_max_bytes, - ) - .await?; - Ok(UntypedBody { content: body_bytes }) + let body = request.into_body(); + let body_bytes = + StreamingBody::new(body, server.config.request_body_max_bytes) + .into_bytes_mut() + .await?; + Ok(UntypedBody { content: body_bytes.freeze() }) } fn metadata( _content_type: ApiEndpointBodyContentType, ) -> ExtractorMetadata { - ExtractorMetadata { - parameters: vec![ApiEndpointParameter::new_body( - ApiEndpointBodyContentType::Bytes, - true, - ApiSchemaGenerator::Static { - schema: Box::new( - SchemaObject { - instance_type: Some(InstanceType::String.into()), - format: Some(String::from("binary")), - ..Default::default() - } - .into(), - ), - dependencies: indexmap::IndexMap::default(), - }, - vec![], - )], - extension_mode: ExtensionMode::None, + untyped_metadata() + } +} + +// StreamingBody: body extractor that provides a streaming representation of the body. + +/// An extractor for streaming the contents of the HTTP request body, making the +/// raw bytes available to the consumer. +#[derive(Debug)] +pub struct StreamingBody { + body: hyper::Body, + cap: usize, +} + +impl StreamingBody { + fn new(body: hyper::Body, cap: usize) -> Self { + Self { body, cap } + } + + /// Not part of the public API. Used only for doctests. + #[doc(hidden)] + pub fn __from_bytes(data: Bytes) -> Self { + let cap = data.len(); + let stream = futures::stream::iter([Ok::<_, Infallible>(data)]); + let body = hyper::Body::wrap_stream(stream); + Self { body, cap } + } + + /// Converts `self` into a [`BytesMut`], buffering the entire response in memory. + /// + /// If payloads are expected to be large, consider using [`Self::into_stream`] to + /// avoid buffering in memory if possible. + /// + /// # Errors + /// + /// Returns an [`HttpError`] if any of the following cases occur: + /// + /// * A network error occurred. + /// * `request_body_max_bytes` was exceeded for this request. + pub async fn into_bytes_mut(self) -> Result { + self.into_stream() + .try_fold(BytesMut::new(), |mut out, chunk| { + out.put(chunk); + futures::future::ok(out) + }) + .await + } + + /// Converts `self` into a stream. + /// + /// The `Stream` produces values of type `Result`. + /// + /// # Errors + /// + /// The stream produces an [`HttpError`] if any of the following cases occur: + /// + /// * A network error occurred. + /// * `request_body_max_bytes` was exceeded for this request. + /// + /// # Examples + /// + /// Buffer a `StreamingBody` in-memory, into a + /// [`BufList`](https://docs.rs/buf-list/latest/buf_list/struct.BufList.html) + /// (a segmented list of [`Bytes`] chunks). This is similar to + /// [`Self::into_bytes_mut`], except it avoids copying memory into a single + /// large allocation. 
+ /// + /// ``` + /// use buf_list::BufList; + /// use dropshot::{HttpError, StreamingBody}; + /// use futures::prelude::*; + /// # use std::iter::FromIterator; + /// + /// async fn into_buf_list(body: StreamingBody) -> Result { + /// body.into_stream().try_collect().await + /// } + /// + /// # #[tokio::main] + /// # async fn main() { + /// # let body = StreamingBody::__from_bytes(bytes::Bytes::from("foobar")); + /// # assert_eq!( + /// # into_buf_list(body).await.unwrap().into_iter().next(), + /// # Some(bytes::Bytes::from("foobar")), + /// # ); + /// # } + /// ``` + /// + /// --- + /// + /// Write a `StreamingBody` to an [`AsyncWrite`](tokio::io::AsyncWrite), + /// for example a [`tokio::fs::File`], without buffering it into memory: + /// + /// ``` + /// use dropshot::{HttpError, StreamingBody}; + /// use futures::prelude::*; + /// use tokio::io::{AsyncWrite, AsyncWriteExt}; + /// + /// async fn write_all( + /// body: StreamingBody, + /// writer: &mut W, + /// ) -> Result<(), HttpError> { + /// let stream = body.into_stream(); + /// tokio::pin!(stream); + /// + /// while let Some(res) = stream.next().await { + /// let mut data = res?; + /// writer.write_all_buf(&mut data).await.map_err(|error| { + /// HttpError::for_unavail(None, format!("write failed: {error}")) + /// })?; + /// } + /// + /// Ok(()) + /// } + /// + /// # #[tokio::main] + /// # async fn main() { + /// # let body = StreamingBody::__from_bytes(bytes::Bytes::from("foobar")); + /// # let mut writer = vec![]; + /// # write_all(body, &mut writer).await.unwrap(); + /// # assert_eq!(writer, &b"foobar"[..]); + /// # } + /// ``` + /// + /// --- + /// + /// An alternative way to write data to an `AsyncWrite`, using + /// `tokio-util`'s + /// [codecs](https://docs.rs/tokio-util/latest/tokio_util/codec/index.html): + /// + /// ``` + /// use bytes::Bytes; + /// use dropshot::{HttpError, StreamingBody}; + /// use futures::{prelude::*, SinkExt}; + /// use tokio::io::AsyncWrite; + /// use tokio_util::codec::{BytesCodec, FramedWrite}; + /// + /// async fn write_all_sink( + /// body: StreamingBody, + /// writer: &mut W, + /// ) -> Result<(), HttpError> { + /// let stream = body.into_stream(); + /// // This type annotation is required for Rust to compile this code. + /// let sink = SinkExt::::sink_map_err( + /// FramedWrite::new(writer, BytesCodec::new()), + /// |error| HttpError::for_unavail(None, format!("write failed: {error}")), + /// ); + /// + /// stream.forward(sink).await + /// } + /// + /// # #[tokio::main] + /// # async fn main() { + /// # let body = StreamingBody::__from_bytes(Bytes::from("foobar")); + /// # let mut writer = vec![]; + /// # write_all_sink(body, &mut writer).await.unwrap(); + /// # assert_eq!(writer, &b"foobar"[..]); + /// # } + /// ``` + pub fn into_stream( + mut self, + ) -> impl Stream> + Send { + async_stream::try_stream! { + let mut bytes_read: usize = 0; + while let Some(buf_res) = self.body.data().await { + let buf = buf_res?; + let len = buf.len(); + + if bytes_read + len > self.cap { + http_dump_body(&mut self.body).await?; + // TODO-correctness check status code + Err(HttpError::for_bad_request( + None, + format!("request body exceeded maximum size of {} bytes", self.cap), + ))?; + } + + bytes_read += len; + yield buf; + } + + // Read the trailers as well, even though we're not going to do anything + // with them. 
+ self.body.trailers().await?; } } } + +#[async_trait] +impl ExclusiveExtractor for StreamingBody { + async fn from_request( + rqctx: &RequestContext, + request: hyper::Request, + ) -> Result { + let server = &rqctx.server; + + Ok(Self { + body: request.into_body(), + cap: server.config.request_body_max_bytes, + }) + } + + fn metadata( + _content_type: ApiEndpointBodyContentType, + ) -> ExtractorMetadata { + untyped_metadata() + } +} + +fn untyped_metadata() -> ExtractorMetadata { + ExtractorMetadata { + parameters: vec![ApiEndpointParameter::new_body( + ApiEndpointBodyContentType::Bytes, + true, + ApiSchemaGenerator::Static { + schema: Box::new( + SchemaObject { + instance_type: Some(InstanceType::String.into()), + format: Some(String::from("binary")), + ..Default::default() + } + .into(), + ), + dependencies: indexmap::IndexMap::default(), + }, + vec![], + )], + extension_mode: ExtensionMode::None, + } +} diff --git a/dropshot/src/extractor/mod.rs b/dropshot/src/extractor/mod.rs index 103141c6d..a97401038 100644 --- a/dropshot/src/extractor/mod.rs +++ b/dropshot/src/extractor/mod.rs @@ -11,6 +11,7 @@ pub use common::RequestExtractor; pub use common::SharedExtractor; mod body; +pub use body::StreamingBody; pub use body::TypedBody; pub use body::UntypedBody; diff --git a/dropshot/src/http_util.rs b/dropshot/src/http_util.rs index 91ce1b728..23b477edf 100644 --- a/dropshot/src/http_util.rs +++ b/dropshot/src/http_util.rs @@ -1,7 +1,6 @@ // Copyright 2020 Oxide Computer Company //! General-purpose HTTP-related facilities -use bytes::BufMut; use bytes::Bytes; use hyper::body::HttpBody; use serde::de::DeserializeOwned; @@ -21,56 +20,6 @@ pub const CONTENT_TYPE_NDJSON: &str = "application/x-ndjson"; /// MIME type for form/urlencoded data pub const CONTENT_TYPE_URL_ENCODED: &str = "application/x-www-form-urlencoded"; -/// Reads the rest of the body from the request up to the given number of bytes. -/// If the body fits within the specified cap, a buffer is returned with all the -/// bytes read. If not, an error is returned. -pub async fn http_read_body( - body: &mut T, - cap: usize, -) -> Result -where - T: HttpBody + std::marker::Unpin, -{ - // This looks a lot like the implementation of hyper::body::to_bytes(), but - // applies the requested cap. We've skipped the optimization for the - // 1-buffer case for now, as it seems likely this implementation will change - // anyway. - // TODO should this use some Stream interface instead? - // TODO why does this look so different in type signature (Data=Bytes, - // std::marker::Unpin, &mut T) - // TODO Error type shouldn't have to be hyper Error -- Into should - // work too? - // TODO do we need to use saturating_add() here? - let mut parts = std::vec::Vec::new(); - let mut nbytesread: usize = 0; - while let Some(maybebuf) = body.data().await { - let buf = maybebuf?; - let bufsize = buf.len(); - - if nbytesread + bufsize > cap { - http_dump_body(body).await?; - // TODO-correctness check status code - return Err(HttpError::for_bad_request( - None, - format!("request body exceeded maximum size of {} bytes", cap), - )); - } - - nbytesread += bufsize; - parts.put(buf); - } - - // Read the trailers as well, even though we're not going to do anything - // with them. - body.trailers().await?; - // TODO-correctness why does the is_end_stream() assertion fail and the next - // one panic? 
- // assert!(body.is_end_stream()); - // assert!(body.data().await.is_none()); - // assert!(body.trailers().await?.is_none()); - Ok(parts.into()) -} - /// Reads the rest of the body from the request, dropping all the bytes. This is /// useful after encountering error conditions. pub async fn http_dump_body(body: &mut T) -> Result diff --git a/dropshot/src/lib.rs b/dropshot/src/lib.rs index d9700f4f4..c19d8744c 100644 --- a/dropshot/src/lib.rs +++ b/dropshot/src/lib.rs @@ -212,7 +212,8 @@ //! [query_params: Query,] //! [path_params: Path
<P>
,] //! [body_param: TypedBody,] -//! [body_param: UntypedBody,] +//! [body_param: UntypedBody,] +//! [body_param: StreamingBody,] //! [raw_request: RawRequest,] //! ) -> Result //! ``` @@ -234,14 +235,17 @@ //! body as JSON (or form/url-encoded) and deserializing it into an instance //! of type `J`. `J` must implement `serde::Deserialize` and `schemars::JsonSchema`. //! * [`UntypedBody`] extracts the raw bytes of the request body. +//! * [`StreamingBody`] provides the raw bytes of the request body as a +//! [`Stream`](futures::Stream) of [`Bytes`](bytes::Bytes) chunks. //! * [`RawRequest`] provides access to the underlying [`hyper::Request`]. The //! hope is that this would generally not be needed. It can be useful to //! implement functionality not provided by Dropshot. //! -//! `Query` and `Path` impl `SharedExtractor`. `TypedBody`, `UntypedBody`, and -//! `RawRequest` impl `ExclusiveExtractor`. Your function may accept 0-3 -//! extractors, but only one can be `ExclusiveExtractor`, and it must be the -//! last one. Otherwise, the order of extractor arguments does not matter. +//! `Query` and `Path` impl `SharedExtractor`. `TypedBody`, `UntypedBody`, +//! `StreamingBody`, and `RawRequest` impl `ExclusiveExtractor`. Your function +//! may accept 0-3 extractors, but only one can be `ExclusiveExtractor`, and it +//! must be the last one. Otherwise, the order of extractor arguments does not +//! matter. //! //! If the handler accepts any extractors and the corresponding extraction //! cannot be completed, the request fails with status code 400 and an error @@ -603,6 +607,7 @@ pub use extractor::Path; pub use extractor::Query; pub use extractor::RawRequest; pub use extractor::SharedExtractor; +pub use extractor::StreamingBody; pub use extractor::TypedBody; pub use extractor::UntypedBody; pub use handler::http_response_found; diff --git a/dropshot/tests/fail/bad_endpoint1.stderr b/dropshot/tests/fail/bad_endpoint1.stderr index 7de7a3992..d365cb7b2 100644 --- a/dropshot/tests/fail/bad_endpoint1.stderr +++ b/dropshot/tests/fail/bad_endpoint1.stderr @@ -4,7 +4,8 @@ error: Endpoint handlers must have the following signature: [query_params: Query,] [path_params: Path
<P>
,] [body_param: TypedBody,] - [body_param: UntypedBody,] + [body_param: UntypedBody,] + [body_param: StreamingBody,] [raw_request: RawRequest,] ) -> Result --> tests/fail/bad_endpoint1.rs:20:1 diff --git a/dropshot/tests/fail/bad_endpoint11.stderr b/dropshot/tests/fail/bad_endpoint11.stderr index 79748382c..bc581eeff 100644 --- a/dropshot/tests/fail/bad_endpoint11.stderr +++ b/dropshot/tests/fail/bad_endpoint11.stderr @@ -4,7 +4,8 @@ error: Endpoint handlers must have the following signature: [query_params: Query,] [path_params: Path
<P>
,] [body_param: TypedBody,] - [body_param: UntypedBody,] + [body_param: UntypedBody,] + [body_param: StreamingBody,] [raw_request: RawRequest,] ) -> Result --> tests/fail/bad_endpoint11.rs:12:1 diff --git a/dropshot/tests/fail/bad_endpoint13.stderr b/dropshot/tests/fail/bad_endpoint13.stderr index 2e41d0953..5fd2e58f4 100644 --- a/dropshot/tests/fail/bad_endpoint13.stderr +++ b/dropshot/tests/fail/bad_endpoint13.stderr @@ -4,7 +4,8 @@ error: Endpoint handlers must have the following signature: [query_params: Query,] [path_params: Path
<P>
,] [body_param: TypedBody,] - [body_param: UntypedBody,] + [body_param: UntypedBody,] + [body_param: StreamingBody,] [raw_request: RawRequest,] ) -> Result --> tests/fail/bad_endpoint13.rs:18:1 diff --git a/dropshot/tests/fail/bad_endpoint2.stderr b/dropshot/tests/fail/bad_endpoint2.stderr index eb1e80c1d..f83dafd8b 100644 --- a/dropshot/tests/fail/bad_endpoint2.stderr +++ b/dropshot/tests/fail/bad_endpoint2.stderr @@ -4,7 +4,8 @@ error: Endpoint handlers must have the following signature: [query_params: Query,] [path_params: Path
<P>
,] [body_param: TypedBody,] - [body_param: UntypedBody,] + [body_param: UntypedBody,] + [body_param: StreamingBody,] [raw_request: RawRequest,] ) -> Result --> tests/fail/bad_endpoint2.rs:13:1 diff --git a/dropshot/tests/fail/bad_endpoint8.stderr b/dropshot/tests/fail/bad_endpoint8.stderr index 255180375..7bd680aef 100644 --- a/dropshot/tests/fail/bad_endpoint8.stderr +++ b/dropshot/tests/fail/bad_endpoint8.stderr @@ -4,7 +4,8 @@ error: Endpoint handlers must have the following signature: [query_params: Query,] [path_params: Path
<P>
,] [body_param: TypedBody,] - [body_param: UntypedBody,] + [body_param: UntypedBody,] + [body_param: StreamingBody,] [raw_request: RawRequest,] ) -> Result --> tests/fail/bad_endpoint8.rs:19:1 diff --git a/dropshot/tests/test_demo.rs b/dropshot/tests/test_demo.rs index 9aa8e302c..a07bc80c6 100644 --- a/dropshot/tests/test_demo.rs +++ b/dropshot/tests/test_demo.rs @@ -36,6 +36,7 @@ use dropshot::Path; use dropshot::Query; use dropshot::RawRequest; use dropshot::RequestContext; +use dropshot::StreamingBody; use dropshot::TypedBody; use dropshot::UntypedBody; use dropshot::WebsocketChannelResult; @@ -43,6 +44,7 @@ use dropshot::WebsocketConnection; use dropshot::CONTENT_TYPE_JSON; use futures::stream::StreamExt; use futures::SinkExt; +use futures::TryStreamExt; use http::StatusCode; use hyper::Body; use hyper::Method; @@ -71,6 +73,7 @@ fn demo_api() -> ApiDescription { api.register(demo_handler_path_param_uuid).unwrap(); api.register(demo_handler_path_param_u32).unwrap(); api.register(demo_handler_untyped_body).unwrap(); + api.register(demo_handler_streaming_body).unwrap(); api.register(demo_handler_raw_request).unwrap(); api.register(demo_handler_delete).unwrap(); api.register(demo_handler_headers).unwrap(); @@ -727,6 +730,57 @@ async fn test_untyped_body() { testctx.teardown().await; } +// Test `StreamingBody`. +#[tokio::test] +async fn test_streaming_body() { + let api = demo_api(); + let testctx = common::test_setup("test_untyped_body", api); + let client = &testctx.client_testctx; + + // Success case: empty body + let mut response = client + .make_request_with_body( + Method::PUT, + "/testing/streaming_body", + "".into(), + StatusCode::OK, + ) + .await + .unwrap(); + let json: DemoStreaming = read_json(&mut response).await; + assert_eq!(json.nbytes, 0); + + // Success case: non-empty content + let body = vec![0u8; 1024]; + let mut response = client + .make_request_with_body( + Method::PUT, + "/testing/streaming_body", + body.into(), + StatusCode::OK, + ) + .await + .unwrap(); + let json: DemoStreaming = read_json(&mut response).await; + assert_eq!(json.nbytes, 1024); + + // Error case: body too large. + let big_body = vec![0u8; 1025]; + let error = client + .make_request_with_body( + Method::PUT, + "/testing/untyped_body", + big_body.into(), + StatusCode::BAD_REQUEST, + ) + .await + .unwrap_err(); + assert_eq!( + error.message, + "request body exceeded maximum size of 1024 bytes" + ); +} + // Test `RawRequest`. #[tokio::test] async fn test_raw_request() { @@ -1096,6 +1150,26 @@ async fn demo_handler_untyped_body( Ok(HttpResponseOk(DemoUntyped { nbytes, as_utf8 })) } +#[derive(Deserialize, Serialize, JsonSchema)] +pub struct DemoStreaming { + pub nbytes: usize, +} +#[endpoint { + method = PUT, + path = "/testing/streaming_body" +}] +async fn demo_handler_streaming_body( + _rqctx: RequestContext, + body: StreamingBody, +) -> Result, HttpError> { + let nbytes = body + .into_stream() + .try_fold(0, |acc, v| futures::future::ok(acc + v.len())) + .await?; + + Ok(HttpResponseOk(DemoStreaming { nbytes })) +} + #[derive(Deserialize, Serialize, JsonSchema)] pub struct DemoRaw { pub nbytes: usize, diff --git a/dropshot_endpoint/src/lib.rs b/dropshot_endpoint/src/lib.rs index 0d7fd586a..ad3c9b45a 100644 --- a/dropshot_endpoint/src/lib.rs +++ b/dropshot_endpoint/src/lib.rs @@ -85,7 +85,8 @@ const USAGE: &str = "Endpoint handlers must have the following signature: [query_params: Query,] [path_params: Path
<P>
,] [body_param: TypedBody,] - [body_param: UntypedBody,] + [body_param: UntypedBody,] + [body_param: StreamingBody,] [raw_request: RawRequest,] ) -> Result"; From bb21a2af5a24d1c101e0aa8bd1d70e08431f7200 Mon Sep 17 00:00:00 2001 From: Rain Date: Thu, 23 Mar 2023 17:23:25 -0700 Subject: [PATCH 2/3] Address review feedback Created using spr 1.3.4 --- Cargo.lock | 17 +------- dropshot/Cargo.toml | 1 - dropshot/src/extractor/body.rs | 72 ++++++---------------------------- 3 files changed, 14 insertions(+), 76 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a0ff7e65c..6451d01b7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -354,7 +354,6 @@ dependencies = [ "tokio", "tokio-rustls", "tokio-tungstenite", - "tokio-util 0.7.7", "toml", "trybuild", "usdt", @@ -599,7 +598,7 @@ dependencies = [ "indexmap", "slab", "tokio", - "tokio-util 0.6.8", + "tokio-util", "tracing", ] @@ -1787,20 +1786,6 @@ dependencies = [ "tokio", ] -[[package]] -name = "tokio-util" -version = "0.7.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5427d89453009325de0d8f342c9490009f76e999cb7672d77e46267448f7e6b2" -dependencies = [ - "bytes", - "futures-core", - "futures-sink", - "pin-project-lite", - "tokio", - "tracing", -] - [[package]] name = "toml" version = "0.7.3" diff --git a/dropshot/Cargo.toml b/dropshot/Cargo.toml index 89bb71218..0f7946b0c 100644 --- a/dropshot/Cargo.toml +++ b/dropshot/Cargo.toml @@ -85,7 +85,6 @@ libc = "0.2.140" mime_guess = "2.0.4" subprocess = "0.2.9" tempfile = "3.4" -tokio-util = { version = "0.7.7", features = ["codec"] } trybuild = "1.0.79" # Used by the https examples and tests pem = "1.1" diff --git a/dropshot/src/extractor/body.rs b/dropshot/src/extractor/body.rs index 82b457f38..0a50b79ce 100644 --- a/dropshot/src/extractor/body.rs +++ b/dropshot/src/extractor/body.rs @@ -231,26 +231,6 @@ impl StreamingBody { Self { body, cap } } - /// Converts `self` into a [`BytesMut`], buffering the entire response in memory. - /// - /// If payloads are expected to be large, consider using [`Self::into_stream`] to - /// avoid buffering in memory if possible. - /// - /// # Errors - /// - /// Returns an [`HttpError`] if any of the following cases occur: - /// - /// * A network error occurred. - /// * `request_body_max_bytes` was exceeded for this request. - pub async fn into_bytes_mut(self) -> Result { - self.into_stream() - .try_fold(BytesMut::new(), |mut out, chunk| { - out.put(chunk); - futures::future::ok(out) - }) - .await - } - /// Converts `self` into a stream. /// /// The `Stream` produces values of type `Result`. @@ -266,9 +246,7 @@ impl StreamingBody { /// /// Buffer a `StreamingBody` in-memory, into a /// [`BufList`](https://docs.rs/buf-list/latest/buf_list/struct.BufList.html) - /// (a segmented list of [`Bytes`] chunks). This is similar to - /// [`Self::into_bytes_mut`], except it avoids copying memory into a single - /// large allocation. + /// (a segmented list of [`Bytes`] chunks). 
/// /// ``` /// use buf_list::BufList; @@ -325,42 +303,6 @@ impl StreamingBody { /// # assert_eq!(writer, &b"foobar"[..]); /// # } /// ``` - /// - /// --- - /// - /// An alternative way to write data to an `AsyncWrite`, using - /// `tokio-util`'s - /// [codecs](https://docs.rs/tokio-util/latest/tokio_util/codec/index.html): - /// - /// ``` - /// use bytes::Bytes; - /// use dropshot::{HttpError, StreamingBody}; - /// use futures::{prelude::*, SinkExt}; - /// use tokio::io::AsyncWrite; - /// use tokio_util::codec::{BytesCodec, FramedWrite}; - /// - /// async fn write_all_sink( - /// body: StreamingBody, - /// writer: &mut W, - /// ) -> Result<(), HttpError> { - /// let stream = body.into_stream(); - /// // This type annotation is required for Rust to compile this code. - /// let sink = SinkExt::::sink_map_err( - /// FramedWrite::new(writer, BytesCodec::new()), - /// |error| HttpError::for_unavail(None, format!("write failed: {error}")), - /// ); - /// - /// stream.forward(sink).await - /// } - /// - /// # #[tokio::main] - /// # async fn main() { - /// # let body = StreamingBody::__from_bytes(Bytes::from("foobar")); - /// # let mut writer = vec![]; - /// # write_all_sink(body, &mut writer).await.unwrap(); - /// # assert_eq!(writer, &b"foobar"[..]); - /// # } - /// ``` pub fn into_stream( mut self, ) -> impl Stream> + Send { @@ -388,6 +330,18 @@ impl StreamingBody { self.body.trailers().await?; } } + + /// Converts `self` into a [`BytesMut`], buffering the entire response in + /// memory. Not public API because most users of this should use + /// `UntypedBody` instead. + async fn into_bytes_mut(self) -> Result { + self.into_stream() + .try_fold(BytesMut::new(), |mut out, chunk| { + out.put(chunk); + futures::future::ok(out) + }) + .await + } } #[async_trait] From 8e7645e287889f1a267aaf1b62854e11c64fe900 Mon Sep 17 00:00:00 2001 From: Rain Date: Thu, 23 Mar 2023 17:24:52 -0700 Subject: [PATCH 3/3] fix test Created using spr 1.3.4 --- dropshot/tests/test_demo.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dropshot/tests/test_demo.rs b/dropshot/tests/test_demo.rs index 9e802d169..41702b9cb 100644 --- a/dropshot/tests/test_demo.rs +++ b/dropshot/tests/test_demo.rs @@ -769,7 +769,7 @@ async fn test_streaming_body() { let error = client .make_request_with_body( Method::PUT, - "/testing/untyped_body", + "/testing/streaming_body", big_body.into(), StatusCode::BAD_REQUEST, )