From ee9a18e9c69274eb4beae2f3c5c1e65804ed0f2a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?James=20=E2=80=98Twey=E2=80=99=20Kay?= Date: Wed, 4 Jun 2025 17:05:12 +0100 Subject: [PATCH] fix(tonic): make `Streaming` `Sync` again The boxed `Decoder` inside `Streaming` need not be `Sync` since https://github.com/hyperium/tonic/pull/804. Unfortunately, that makes `Streaming` non-`Sync`, meaning that all the generated `tonic` futures cannot be awaited in `Sync` futures. In fact, the only times we use the `Decoder`, we have a `&mut` unique reference to it, so we are guaranteed not to require synchronization. The `sync_wrapper` crate encodes this reasoning, allowing us to safely make the `Streaming` type `Sync` regardless of whether the contained `Decoder` is `Sync` or not. --- tonic/Cargo.toml | 1 + tonic/src/codec/decode.rs | 22 +++++++++++++--------- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/tonic/Cargo.toml b/tonic/Cargo.toml index a9dca6f24..350877cd9 100644 --- a/tonic/Cargo.toml +++ b/tonic/Cargo.toml @@ -94,6 +94,7 @@ zstd = { version = "0.13.0", optional = true } # channel hyper-timeout = {version = "0.5", optional = true} +sync_wrapper = "1.0.2" [dev-dependencies] bencher = "0.1.5" diff --git a/tonic/src/codec/decode.rs b/tonic/src/codec/decode.rs index 60c0c9d35..a221a5c93 100644 --- a/tonic/src/codec/decode.rs +++ b/tonic/src/codec/decode.rs @@ -11,6 +11,7 @@ use std::{ task::ready, task::{Context, Poll}, }; +use sync_wrapper::SyncWrapper; use tokio_stream::Stream; use tracing::{debug, trace}; @@ -19,12 +20,12 @@ use tracing::{debug, trace}; /// This will wrap some inner [`Body`] and [`Decoder`] and provide an interface /// to fetch the message stream and trailing metadata pub struct Streaming { - decoder: Box + Send + 'static>, + decoder: SyncWrapper + Send + 'static>>, inner: StreamingInner, } struct StreamingInner { - body: Body, + body: SyncWrapper, state: State, direction: Direction, buf: BytesMut, @@ -123,14 +124,14 @@ impl Streaming { { let buffer_size = decoder.buffer_settings().buffer_size; Self { - decoder: Box::new(decoder), + decoder: SyncWrapper::new(Box::new(decoder)), inner: StreamingInner { - body: Body::new( + body: SyncWrapper::new(Body::new( body.map_frame(|frame| { frame.map_data(|mut buf| buf.copy_to_bytes(buf.remaining())) }) .map_err(|err| Status::map_error(err.into())), - ), + )), state: State::ReadHeader, direction, buf: BytesMut::with_capacity(buffer_size), @@ -243,7 +244,7 @@ impl StreamingInner { // Returns Some(()) if data was found or None if the loop in `poll_next` should break fn poll_frame(&mut self, cx: &mut Context<'_>) -> Poll, Status>> { - let frame = match ready!(Pin::new(&mut self.body).poll_frame(cx)) { + let frame = match ready!(Pin::new(self.body.get_mut()).poll_frame(cx)) { Some(Ok(frame)) => frame, Some(Err(status)) => { if self.direction == Direction::Request && status.code() == Code::Cancelled { @@ -367,8 +368,11 @@ impl Streaming { } fn decode_chunk(&mut self) -> Result, Status> { - match self.inner.decode_chunk(self.decoder.buffer_settings())? { - Some(mut decode_buf) => match self.decoder.decode(&mut decode_buf)? { + match self + .inner + .decode_chunk(self.decoder.get_mut().buffer_settings())? + { + Some(mut decode_buf) => match self.decoder.get_mut().decode(&mut decode_buf)? { Some(msg) => { self.inner.state = State::ReadHeader; Ok(Some(msg)) @@ -413,4 +417,4 @@ impl fmt::Debug for Streaming { } #[cfg(test)] -static_assertions::assert_impl_all!(Streaming<()>: Send); +static_assertions::assert_impl_all!(Streaming<()>: Send, Sync);