Skip to content

Commit 6ae0098

Browse files
committed
feat(net): AsyncWrite and AsyncRead for WebSocket
Implements `AsyncWrite` and `AsyncRead` on `WebSocket` behind the `io-util` feature flag.
1 parent afbacda commit 6ae0098

File tree

4 files changed

+185
-1
lines changed

4 files changed

+185
-1
lines changed

crates/net/Cargo.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ gloo-utils = { version = "0.2", path = "../utils", default-features = false }
2727
wasm-bindgen-futures = "0.4"
2828
futures-core = { version = "0.3", optional = true }
2929
futures-sink = { version = "0.3", optional = true }
30+
futures-io = { version = "0.3", optional = true }
3031

3132
thiserror = "1.0"
3233

@@ -45,7 +46,7 @@ serde = { version = "1.0", features = ["derive"] }
4546
once_cell = "1"
4647

4748
[features]
48-
default = ["json", "websocket", "http", "eventsource"]
49+
default = ["json", "websocket", "http", "eventsource", "io-util"]
4950

5051
# Enables `.json()` on `Response`
5152
json = ["serde", "serde_json", "gloo-utils/serde"]
@@ -99,3 +100,5 @@ eventsource = [
99100
'web-sys/EventSource',
100101
'web-sys/MessageEvent',
101102
]
103+
# As of now, only implements `AsyncRead` and `AsyncWrite` on `WebSocket`
104+
io-util = ["futures-io"]

crates/net/src/websocket/futures.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,11 @@ pub struct WebSocket {
5757
Closure<dyn FnMut(web_sys::Event)>,
5858
Closure<dyn FnMut(web_sys::CloseEvent)>,
5959
),
60+
/// Leftover bytes when using `AsyncRead`.
61+
///
62+
/// These bytes are drained and returned in subsequent calls to `poll_read`.
63+
#[cfg(feature = "io-util")]
64+
pub(super) read_pending_bytes: Option<Vec<u8>>, // Same size as `Vec<u8>` alone thanks to niche optimization
6065
}
6166

6267
impl WebSocket {
@@ -196,6 +201,8 @@ impl WebSocket {
196201
error_callback,
197202
close_callback,
198203
),
204+
#[cfg(feature = "io-util")]
205+
read_pending_bytes: None,
199206
})
200207
}
201208

crates/net/src/websocket/io_util.rs

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
use core::cmp;
2+
use core::pin::Pin;
3+
use core::task::{Context, Poll};
4+
use std::io;
5+
6+
use futures_core::{ready, Stream as _};
7+
use futures_io::{AsyncRead, AsyncWrite};
8+
use futures_sink::Sink;
9+
10+
use crate::websocket::futures::WebSocket;
11+
use crate::websocket::{Message as WebSocketMessage, WebSocketError};
12+
13+
impl WebSocket {
14+
/// Returns whether there are pending bytes left after calling [`AsyncRead::poll_read`] on this WebSocket.
15+
///
16+
/// When calling [`AsyncRead::poll_read`], [`Stream::poll_next`](futures_core::Stream::poll_next) is called
17+
/// under the hood, and when the received item is too big to fit into the provided buffer, leftover bytes are
18+
/// stored. These leftover bytes are returned by subsequent calls to [`AsyncRead::poll_read`].
19+
#[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
20+
pub fn has_pending_bytes(&self) -> bool {
21+
self.read_pending_bytes.is_some()
22+
}
23+
}
24+
25+
#[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
26+
impl AsyncRead for WebSocket {
27+
fn poll_read(
28+
mut self: Pin<&mut Self>,
29+
cx: &mut Context<'_>,
30+
buf: &mut [u8],
31+
) -> Poll<io::Result<usize>> {
32+
let mut data = if let Some(data) = self.as_mut().get_mut().read_pending_bytes.take() {
33+
data
34+
} else {
35+
match ready!(self.as_mut().poll_next(cx)) {
36+
Some(Ok(m)) => match m {
37+
WebSocketMessage::Text(s) => s.into_bytes(),
38+
WebSocketMessage::Bytes(data) => data,
39+
},
40+
Some(Err(WebSocketError::ConnectionClose(event))) if event.was_clean == true => {
41+
return Poll::Ready(Ok(0));
42+
}
43+
Some(Err(e)) => return Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, e))),
44+
None => return Poll::Ready(Ok(0)),
45+
}
46+
};
47+
48+
let bytes_to_copy = cmp::min(buf.len(), data.len());
49+
buf[..bytes_to_copy].copy_from_slice(&data[..bytes_to_copy]);
50+
51+
if data.len() > bytes_to_copy {
52+
data.drain(..bytes_to_copy);
53+
self.get_mut().read_pending_bytes = Some(data);
54+
}
55+
56+
Poll::Ready(Ok(bytes_to_copy))
57+
}
58+
}
59+
60+
#[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
61+
impl AsyncWrite for WebSocket {
62+
fn poll_write(
63+
mut self: Pin<&mut Self>,
64+
cx: &mut Context<'_>,
65+
buf: &[u8],
66+
) -> Poll<io::Result<usize>> {
67+
macro_rules! try_in_poll {
68+
($expr:expr) => {{
69+
match $expr {
70+
Ok(o) => o,
71+
// When using `AsyncWriteExt::write_all`, `io::ErrorKind::WriteZero` will be raised.
72+
// In this case it means "attempted to write on a closed socket".
73+
Err(WebSocketError::ConnectionClose(_)) => return Poll::Ready(Ok(0)),
74+
Err(e) => return Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, e))),
75+
}
76+
}};
77+
}
78+
79+
// try flushing preemptively
80+
let _ = AsyncWrite::poll_flush(self.as_mut(), cx);
81+
82+
// make sure sink is ready to send
83+
try_in_poll!(ready!(self.as_mut().poll_ready(cx)));
84+
85+
// actually submit new item
86+
try_in_poll!(self.start_send(WebSocketMessage::Bytes(buf.to_vec())));
87+
// ^ if no error occurred, message is accepted and queued when calling `start_send`
88+
// (i.e.: `to_vec` is called only once)
89+
90+
Poll::Ready(Ok(buf.len()))
91+
}
92+
93+
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
94+
let res = ready!(Sink::poll_flush(self, cx));
95+
Poll::Ready(ws_result_to_io_result(res))
96+
}
97+
98+
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
99+
let res = ready!(Sink::poll_close(self, cx));
100+
Poll::Ready(ws_result_to_io_result(res))
101+
}
102+
}
103+
104+
fn ws_result_to_io_result(res: Result<(), WebSocketError>) -> io::Result<()> {
105+
match res {
106+
Ok(()) => Ok(()),
107+
Err(WebSocketError::ConnectionClose(_)) => Ok(()),
108+
Err(e) => Err(io::Error::new(io::ErrorKind::Other, e)),
109+
}
110+
}
111+
112+
#[cfg(test)]
113+
mod tests {
114+
use super::*;
115+
use futures::{AsyncReadExt, AsyncWriteExt, StreamExt};
116+
use wasm_bindgen_futures::spawn_local;
117+
use wasm_bindgen_test::*;
118+
119+
wasm_bindgen_test_configure!(run_in_browser);
120+
121+
#[wasm_bindgen_test]
122+
async fn check_read_write() {
123+
let ws_echo_server_url =
124+
option_env!("WS_ECHO_SERVER_URL").expect("Did you set WS_ECHO_SERVER_URL?");
125+
126+
let mut ws = WebSocket::open(ws_echo_server_url).unwrap();
127+
128+
// ignore first message
129+
// the echo-server uses it to send it's info in the first message
130+
let _ = ws.next().await.unwrap();
131+
132+
let (mut reader, mut writer) = AsyncReadExt::split(ws);
133+
134+
spawn_local(async move {
135+
writer.write_all(b"test 1").await.unwrap();
136+
writer.write_all(b"test 2").await.unwrap();
137+
});
138+
139+
spawn_local(async move {
140+
let mut buf = [0u8; 6];
141+
reader.read_exact(&mut buf).await.unwrap();
142+
assert_eq!(&buf, b"test 1");
143+
reader.read_exact(&mut buf).await.unwrap();
144+
assert_eq!(&buf, b"test 2");
145+
});
146+
}
147+
148+
#[wasm_bindgen_test]
149+
async fn with_pending_bytes() {
150+
let ws_echo_server_url =
151+
option_env!("WS_ECHO_SERVER_URL").expect("Did you set WS_ECHO_SERVER_URL?");
152+
153+
let mut ws = WebSocket::open(ws_echo_server_url).unwrap();
154+
155+
// ignore first message
156+
// the echo-server uses it to send it's info in the first message
157+
let _ = ws.next().await.unwrap();
158+
159+
ws.write_all(b"1234567890").await.unwrap();
160+
161+
let mut buf = [0u8; 5];
162+
163+
ws.read_exact(&mut buf).await.unwrap();
164+
assert_eq!(&buf, b"12345");
165+
assert!(ws.has_pending_bytes());
166+
167+
ws.read_exact(&mut buf).await.unwrap();
168+
assert_eq!(&buf, b"67890");
169+
assert!(!ws.has_pending_bytes());
170+
}
171+
}

crates/net/src/websocket/mod.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66
pub mod events;
77
pub mod futures;
88

9+
#[cfg(feature = "io-util")]
10+
mod io_util;
11+
912
use events::CloseEvent;
1013
use gloo_utils::errors::JsError;
1114
use std::fmt;

0 commit comments

Comments
 (0)