Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/warp_subscriptions/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ async fn main() {
ctx: Context,
coordinator: Arc<Coordinator<'static, _, _, _, _, _>>| {
ws.on_upgrade(|websocket| -> Pin<Box<dyn Future<Output = ()> + Send>> {
graphql_subscriptions(websocket, coordinator, ctx)
graphql_subscriptions(websocket, coordinator, ctx, None)
.map(|r| {
if let Err(e) = r {
println!("Websocket error: {}", e);
Expand Down
1 change: 1 addition & 0 deletions juniper_warp/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
- Update `playground_filter` to support subscription endpoint URLs
- Update `warp` to 0.2
- Rename synchronous `execute` to `execute_sync`, add asynchronous `execute`
- Add `graphql_subscriptions` on_connect handler

# [[0.5.2] 2019-12-16](https://github.com/graphql-rust/juniper/releases/tag/juniper_warp-0.5.2)

Expand Down
40 changes: 34 additions & 6 deletions juniper_warp/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -459,16 +459,18 @@ pub mod subscriptions {
use juniper::{http::GraphQLRequest, InputValue, ScalarValue, SubscriptionCoordinator as _};
use juniper_subscriptions::Coordinator;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use warp::ws::Message;

/// Listen to incoming messages and do one of the following:
/// - execute subscription and return values from stream
/// - stop stream and close ws connection
#[allow(dead_code)]
pub fn graphql_subscriptions<Query, Mutation, Subscription, Context, S>(
pub fn graphql_subscriptions<Query, Mutation, Subscription, Context, S, Connect>(
websocket: warp::ws::WebSocket,
coordinator: Arc<Coordinator<'static, Query, Mutation, Subscription, Context, S>>,
context: Context,
on_connect: Option<Connect>,
) -> impl Future<Output = Result<(), failure::Error>> + Send
where
S: ScalarValue + Send + Sync + 'static,
Expand All @@ -480,6 +482,7 @@ pub mod subscriptions {
Subscription:
juniper::GraphQLSubscriptionType<S, Context = Context> + Send + Sync + 'static,
Subscription::TypeInfo: Send + Sync,
Connect: Fn(&Value) -> Result<(), String> + Clone + Send + Sync + 'static,
{
let (sink_tx, sink_rx) = websocket.split();
let (ws_tx, ws_rx) = mpsc::unbounded();
Expand All @@ -500,6 +503,7 @@ pub mod subscriptions {
let running = running.clone();
let got_close_signal = got_close_signal.clone();
let ws_tx = ws_tx.clone();
let on_connect = on_connect.clone();

async move {
let msg = match msg {
Expand All @@ -517,12 +521,30 @@ pub mod subscriptions {
let msg = msg
.to_str()
.map_err(|_| failure::format_err!("Non-text messages are not accepted"))?;
let request: WsPayload<S> = serde_json::from_str(msg)
let raw_request: Value = serde_json::from_str(msg)
.map_err(|e| failure::format_err!("Invalid WsPayload: {}", e))?;

match request.type_name.as_str() {
"connection_init" => {}
"start" => {
match raw_request["type"].as_str() {
Some("connection_init") => {
if let Some(callback) = on_connect {
if let Err(err) = callback(&raw_request["payload"]) {
let _ = ws_tx.unbounded_send(Some(Ok(Message::text(format!(
r#"{{"type":"connection_error","payload":"{}"}}"#,
err,
)))));

// close channel
let _ = ws_tx.unbounded_send(None);
} else {
let _ = ws_tx.unbounded_send(Some(Ok(Message::text(
r#"{"type":"connection_ack","payload":null}"#,
))));
}
}
}
Some("start") => {
let request: WsPayload<S> = serde_json::from_value(raw_request)
.map_err(|e| failure::format_err!("Invalid WsPayload: {}", e))?;
{
let closed = got_close_signal.load(Ordering::Relaxed);
if closed {
Expand Down Expand Up @@ -604,7 +626,9 @@ pub mod subscriptions {
.await;
});
}
"stop" => {
Some("stop") => {
let request: WsPayload<S> = serde_json::from_value(raw_request)
.map_err(|e| failure::format_err!("Invalid WsPayload: {}", e))?;
got_close_signal.store(true, Ordering::Relaxed);

let request_id = request.id.unwrap_or("1".to_owned());
Expand All @@ -617,6 +641,9 @@ pub mod subscriptions {
// close channel
let _ = ws_tx.unbounded_send(None);
}
None => {
return Err(failure::err_msg("Invalid WsPayload: No type specified"));
}
_ => {}
}

Expand All @@ -625,6 +652,7 @@ pub mod subscriptions {
})
}

#[allow(dead_code)]
#[derive(Deserialize)]
#[serde(bound = "GraphQLPayload<S>: Deserialize<'de>")]
struct WsPayload<S>
Expand Down