diff --git a/tonic/Cargo.toml b/tonic/Cargo.toml index a9dca6f24..350877cd9 100644 --- a/tonic/Cargo.toml +++ b/tonic/Cargo.toml @@ -94,6 +94,7 @@ zstd = { version = "0.13.0", optional = true } # channel hyper-timeout = {version = "0.5", optional = true} +sync_wrapper = "1.0.2" [dev-dependencies] bencher = "0.1.5" diff --git a/tonic/src/codec/decode.rs b/tonic/src/codec/decode.rs index 60c0c9d35..a221a5c93 100644 --- a/tonic/src/codec/decode.rs +++ b/tonic/src/codec/decode.rs @@ -11,6 +11,7 @@ use std::{ task::ready, task::{Context, Poll}, }; +use sync_wrapper::SyncWrapper; use tokio_stream::Stream; use tracing::{debug, trace}; @@ -19,12 +20,12 @@ use tracing::{debug, trace}; /// This will wrap some inner [`Body`] and [`Decoder`] and provide an interface /// to fetch the message stream and trailing metadata pub struct Streaming { - decoder: Box + Send + 'static>, + decoder: SyncWrapper + Send + 'static>>, inner: StreamingInner, } struct StreamingInner { - body: Body, + body: SyncWrapper, state: State, direction: Direction, buf: BytesMut, @@ -123,14 +124,14 @@ impl Streaming { { let buffer_size = decoder.buffer_settings().buffer_size; Self { - decoder: Box::new(decoder), + decoder: SyncWrapper::new(Box::new(decoder)), inner: StreamingInner { - body: Body::new( + body: SyncWrapper::new(Body::new( body.map_frame(|frame| { frame.map_data(|mut buf| buf.copy_to_bytes(buf.remaining())) }) .map_err(|err| Status::map_error(err.into())), - ), + )), state: State::ReadHeader, direction, buf: BytesMut::with_capacity(buffer_size), @@ -243,7 +244,7 @@ impl StreamingInner { // Returns Some(()) if data was found or None if the loop in `poll_next` should break fn poll_frame(&mut self, cx: &mut Context<'_>) -> Poll, Status>> { - let frame = match ready!(Pin::new(&mut self.body).poll_frame(cx)) { + let frame = match ready!(Pin::new(self.body.get_mut()).poll_frame(cx)) { Some(Ok(frame)) => frame, Some(Err(status)) => { if self.direction == Direction::Request && status.code() == Code::Cancelled { @@ -367,8 +368,11 @@ impl Streaming { } fn decode_chunk(&mut self) -> Result, Status> { - match self.inner.decode_chunk(self.decoder.buffer_settings())? { - Some(mut decode_buf) => match self.decoder.decode(&mut decode_buf)? { + match self + .inner + .decode_chunk(self.decoder.get_mut().buffer_settings())? + { + Some(mut decode_buf) => match self.decoder.get_mut().decode(&mut decode_buf)? { Some(msg) => { self.inner.state = State::ReadHeader; Ok(Some(msg)) @@ -413,4 +417,4 @@ impl fmt::Debug for Streaming { } #[cfg(test)] -static_assertions::assert_impl_all!(Streaming<()>: Send); +static_assertions::assert_impl_all!(Streaming<()>: Send, Sync);