Skip to content

GraphQL-WS crate and Warp subscriptions update #721

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Jul 29, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -15,6 +15,7 @@ members = [
"juniper_rocket",
"juniper_rocket_async",
"juniper_subscriptions",
"juniper_graphql_ws",
"juniper_warp",
"juniper_actix",
]
6 changes: 3 additions & 3 deletions examples/warp_subscriptions/Cargo.toml
Original file line number Diff line number Diff line change
@@ -13,6 +13,6 @@ serde_json = "1.0"
tokio = { version = "0.2", features = ["rt-core", "macros"] }
warp = "0.2.1"

juniper = { git = "https://github.com/graphql-rust/juniper" }
juniper_subscriptions = { git = "https://github.com/graphql-rust/juniper" }
juniper_warp = { git = "https://github.com/graphql-rust/juniper", features = ["subscriptions"] }
juniper = { path = "../../juniper" }
juniper_graphql_ws = { path = "../../juniper_graphql_ws" }
juniper_warp = { path = "../../juniper_warp", features = ["subscriptions"] }
38 changes: 16 additions & 22 deletions examples/warp_subscriptions/src/main.rs
Original file line number Diff line number Diff line change
@@ -2,10 +2,10 @@

use std::{env, pin::Pin, sync::Arc, time::Duration};

use futures::{Future, FutureExt as _, Stream};
use futures::{FutureExt as _, Stream};
use juniper::{DefaultScalarValue, EmptyMutation, FieldError, RootNode};
use juniper_subscriptions::Coordinator;
use juniper_warp::{playground_filter, subscriptions::graphql_subscriptions};
use juniper_graphql_ws::ConnectionConfig;
use juniper_warp::{playground_filter, subscriptions::serve_graphql_ws};
use warp::{http::Response, Filter};

#[derive(Clone)]
@@ -151,30 +151,24 @@ async fn main() {
let qm_state = warp::any().map(move || Context {});
let qm_graphql_filter = juniper_warp::make_graphql_filter(qm_schema, qm_state.boxed());

let sub_state = warp::any().map(move || Context {});
let coordinator = Arc::new(juniper_subscriptions::Coordinator::new(schema()));
let root_node = Arc::new(schema());

log::info!("Listening on 127.0.0.1:8080");

let routes = (warp::path("subscriptions")
.and(warp::ws())
.and(sub_state.clone())
.and(warp::any().map(move || Arc::clone(&coordinator)))
.map(
|ws: warp::ws::Ws,
ctx: Context,
coordinator: Arc<Coordinator<'static, _, _, _, _, _>>| {
ws.on_upgrade(|websocket| -> Pin<Box<dyn Future<Output = ()> + Send>> {
graphql_subscriptions(websocket, coordinator, ctx)
.map(|r| {
if let Err(e) = r {
println!("Websocket error: {}", e);
}
})
.boxed()
})
},
))
.map(move |ws: warp::ws::Ws| {
let root_node = root_node.clone();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it make sense to put this on the route?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seemed simpler to me, but if someone with more Warp expertise tells me it would be better for whatever reason (demonstration purposes?), I'm happy to change it.

ws.on_upgrade(move |websocket| async move {
serve_graphql_ws(websocket, root_node, ConnectionConfig::new(Context {}))
.map(|r| {
if let Err(e) = r {
println!("Websocket error: {}", e);
}
})
.await
})
}))
.map(|reply| {
// TODO#584: remove this workaround
warp::reply::with_header(reply, "Sec-WebSocket-Protocol", "graphql-ws")
2 changes: 2 additions & 0 deletions juniper/release.toml
Original file line number Diff line number Diff line change
@@ -30,6 +30,8 @@ pre-release-replacements = [
{file="../juniper_warp/Cargo.toml", search="\\[dev-dependencies\\.juniper\\]\nversion = \"[^\"]+\"", replace="[dev-dependencies.juniper]\nversion = \"{{version}}\""},
# Subscriptions
{file="../juniper_subscriptions/Cargo.toml", search="juniper = \\{ version = \"[^\"]+\"", replace="juniper = { version = \"{{version}}\""},
# GraphQL-WS
{file="../juniper_graphql_ws/Cargo.toml", search="juniper = \\{ version = \"[^\"]+\"", replace="juniper = { version = \"{{version}}\""},
# Actix-Web
{file="../juniper_actix/Cargo.toml", search="juniper = \\{ version = \"[^\"]+\"", replace="juniper = { version = \"{{version}}\""},
{file="../juniper_actix/Cargo.toml", search="\\[dev-dependencies\\.juniper\\]\nversion = \"[^\"]+\"", replace="[dev-dependencies.juniper]\nversion = \"{{version}}\""},
3 changes: 3 additions & 0 deletions juniper_graphql_ws/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# master

- Initial Release
19 changes: 19 additions & 0 deletions juniper_graphql_ws/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
[package]
name = "juniper_graphql_ws"
version = "0.1.0"
authors = ["Christopher Brown <[email protected]>"]
license = "BSD-2-Clause"
description = "Graphql-ws protocol implementation for Juniper"
documentation = "https://docs.rs/juniper_graphql_ws"
repository = "https://github.com/graphql-rust/juniper"
keywords = ["graphql-ws", "juniper", "graphql", "apollo"]
edition = "2018"

[dependencies]
juniper = { version = "0.14.2", path = "../juniper", default-features = false }
juniper_subscriptions = { path = "../juniper_subscriptions" }
serde = { version = "1.0.8", features = ["derive"] }
tokio = { version = "0.2", features = ["macros", "rt-core", "time"] }

[dev-dependencies]
serde_json = { version = "1.0.2" }
20 changes: 20 additions & 0 deletions juniper_graphql_ws/Makefile.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
[env]
CARGO_MAKE_CARGO_ALL_FEATURES = ""

[tasks.build-verbose]
condition = { rust_version = { min = "1.29.0" } }

[tasks.build-verbose.windows]
condition = { rust_version = { min = "1.29.0" }, env = { "TARGET" = "x86_64-pc-windows-msvc" } }

[tasks.test-verbose]
condition = { rust_version = { min = "1.29.0" } }

[tasks.test-verbose.windows]
condition = { rust_version = { min = "1.29.0" }, env = { "TARGET" = "x86_64-pc-windows-msvc" } }

[tasks.ci-coverage-flow]
condition = { rust_version = { min = "1.29.0" } }

[tasks.ci-coverage-flow.windows]
disabled = true
8 changes: 8 additions & 0 deletions juniper_graphql_ws/release.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
no-dev-version = true
pre-release-commit-message = "Release {{crate_name}} {{version}}"
pro-release-commit-message = "Bump {{crate_name}} version to {{next_version}}"
tag-message = "Release {{crate_name}} {{version}}"
upload-doc = false
pre-release-replacements = [
{file="src/lib.rs", search="docs.rs/juniper_graphql_ws/[a-z0-9\\.-]+", replace="docs.rs/juniper_graphql_ws/{{version}}"},
]
131 changes: 131 additions & 0 deletions juniper_graphql_ws/src/client_message.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
use juniper::{ScalarValue, Variables};

/// The payload for a client's "start" message. This triggers execution of a query, mutation, or
/// subscription.
#[derive(Debug, Deserialize, PartialEq)]
#[serde(bound(deserialize = "S: ScalarValue"))]
#[serde(rename_all = "camelCase")]
pub struct StartPayload<S: ScalarValue> {
/// The document body.
pub query: String,

/// The optional variables.
#[serde(default)]
pub variables: Variables<S>,

/// The optional operation name (required if the document contains multiple operations).
pub operation_name: Option<String>,
}

/// ClientMessage defines the message types that clients can send.
#[derive(Debug, Deserialize, PartialEq)]
#[serde(bound(deserialize = "S: ScalarValue"))]
#[serde(rename_all = "snake_case")]
#[serde(tag = "type")]
pub enum ClientMessage<S: ScalarValue> {
/// ConnectionInit is sent by the client upon connecting.
ConnectionInit {
/// Optional parameters of any type sent from the client. These are often used for
/// authentication.
#[serde(default)]
payload: Variables<S>,
},
/// Start messages are used to execute a GraphQL operation.
Start {
/// The id of the operation. This can be anything, but must be unique. If there are other
/// in-flight operations with the same id, the message will be ignored or cause an error.
id: String,

/// The query, variables, and operation name.
payload: StartPayload<S>,
},
/// Stop messages are used to unsubscribe from a subscription.
Stop {
/// The id of the operation to stop.
id: String,
},
/// ConnectionTerminate is used to terminate the connection.
ConnectionTerminate,
}

#[cfg(test)]
mod test {
use super::*;
use juniper::{DefaultScalarValue, InputValue};

#[test]
fn test_deserialization() {
type ClientMessage = super::ClientMessage<DefaultScalarValue>;

assert_eq!(
ClientMessage::ConnectionInit {
payload: [("foo".to_string(), InputValue::scalar("bar"))]
.iter()
.cloned()
.collect(),
},
serde_json::from_str(r##"{"type": "connection_init", "payload": {"foo": "bar"}}"##)
.unwrap(),
);

assert_eq!(
ClientMessage::ConnectionInit {
payload: Variables::default(),
},
serde_json::from_str(r##"{"type": "connection_init"}"##).unwrap(),
);

assert_eq!(
ClientMessage::Start {
id: "foo".to_string(),
payload: StartPayload {
query: "query MyQuery { __typename }".to_string(),
variables: [("foo".to_string(), InputValue::scalar("bar"))]
.iter()
.cloned()
.collect(),
operation_name: Some("MyQuery".to_string()),
},
},
serde_json::from_str(
r##"{"type": "start", "id": "foo", "payload": {
"query": "query MyQuery { __typename }",
"variables": {
"foo": "bar"
},
"operationName": "MyQuery"
}}"##
)
.unwrap(),
);

assert_eq!(
ClientMessage::Start {
id: "foo".to_string(),
payload: StartPayload {
query: "query MyQuery { __typename }".to_string(),
variables: Variables::default(),
operation_name: None,
},
},
serde_json::from_str(
r##"{"type": "start", "id": "foo", "payload": {
"query": "query MyQuery { __typename }"
}}"##
)
.unwrap(),
);

assert_eq!(
ClientMessage::Stop {
id: "foo".to_string()
},
serde_json::from_str(r##"{"type": "stop", "id": "foo"}"##).unwrap(),
);

assert_eq!(
ClientMessage::ConnectionTerminate,
serde_json::from_str(r##"{"type": "connection_terminate"}"##).unwrap(),
);
}
}
1,073 changes: 1,073 additions & 0 deletions juniper_graphql_ws/src/lib.rs

Large diffs are not rendered by default.

131 changes: 131 additions & 0 deletions juniper_graphql_ws/src/schema.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
use juniper::{GraphQLSubscriptionType, GraphQLTypeAsync, RootNode, ScalarValue};
use std::sync::Arc;

/// Schema defines the requirements for schemas that can be used for operations. Typically this is
/// just an `Arc<RootNode<...>>` and you should not have to implement it yourself.
pub trait Schema: Unpin + Clone + Send + Sync + 'static {
/// The context type.
type Context: Unpin + Send + Sync;

/// The scalar value type.
type ScalarValue: ScalarValue + Send + Sync;

/// The query type info.
type QueryTypeInfo: Send + Sync;

/// The query type.
type Query: GraphQLTypeAsync<Self::ScalarValue, Context = Self::Context, TypeInfo = Self::QueryTypeInfo>
+ Send;

/// The mutation type info.
type MutationTypeInfo: Send + Sync;

/// The mutation type.
type Mutation: GraphQLTypeAsync<
Self::ScalarValue,
Context = Self::Context,
TypeInfo = Self::MutationTypeInfo,
> + Send;

/// The subscription type info.
type SubscriptionTypeInfo: Send + Sync;

/// The subscription type.
type Subscription: GraphQLSubscriptionType<
Self::ScalarValue,
Context = Self::Context,
TypeInfo = Self::SubscriptionTypeInfo,
> + Send;

/// Returns the root node for the schema.
fn root_node(
&self,
) -> &RootNode<'static, Self::Query, Self::Mutation, Self::Subscription, Self::ScalarValue>;
}

/// This exists as a work-around for this issue: https://github.com/rust-lang/rust/issues/64552
///
/// It can be used in generators where using Arc directly would result in an error.
// TODO: Remove this once that issue is resolved.
#[doc(hidden)]
pub struct ArcSchema<QueryT, MutationT, SubscriptionT, CtxT, S>(
pub Arc<RootNode<'static, QueryT, MutationT, SubscriptionT, S>>,
)
where
QueryT: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
QueryT::TypeInfo: Send + Sync,
MutationT: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
MutationT::TypeInfo: Send + Sync,
SubscriptionT: GraphQLSubscriptionType<S, Context = CtxT> + Send + 'static,
SubscriptionT::TypeInfo: Send + Sync,
CtxT: Unpin + Send + Sync,
S: ScalarValue + Send + Sync + 'static;

impl<QueryT, MutationT, SubscriptionT, CtxT, S> Clone
for ArcSchema<QueryT, MutationT, SubscriptionT, CtxT, S>
where
QueryT: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
QueryT::TypeInfo: Send + Sync,
MutationT: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
MutationT::TypeInfo: Send + Sync,
SubscriptionT: GraphQLSubscriptionType<S, Context = CtxT> + Send + 'static,
SubscriptionT::TypeInfo: Send + Sync,
CtxT: Unpin + Send + Sync,
S: ScalarValue + Send + Sync + 'static,
{
fn clone(&self) -> Self {
Self(self.0.clone())
}
}

impl<QueryT, MutationT, SubscriptionT, CtxT, S> Schema
for ArcSchema<QueryT, MutationT, SubscriptionT, CtxT, S>
where
QueryT: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
QueryT::TypeInfo: Send + Sync,
MutationT: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
MutationT::TypeInfo: Send + Sync,
SubscriptionT: GraphQLSubscriptionType<S, Context = CtxT> + Send + 'static,
SubscriptionT::TypeInfo: Send + Sync,
CtxT: Unpin + Send + Sync + 'static,
S: ScalarValue + Send + Sync + 'static,
{
type Context = CtxT;
type ScalarValue = S;
type QueryTypeInfo = QueryT::TypeInfo;
type Query = QueryT;
type MutationTypeInfo = MutationT::TypeInfo;
type Mutation = MutationT;
type SubscriptionTypeInfo = SubscriptionT::TypeInfo;
type Subscription = SubscriptionT;

fn root_node(&self) -> &RootNode<'static, QueryT, MutationT, SubscriptionT, S> {
&self.0
}
}

impl<QueryT, MutationT, SubscriptionT, CtxT, S> Schema
for Arc<RootNode<'static, QueryT, MutationT, SubscriptionT, S>>
where
QueryT: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
QueryT::TypeInfo: Send + Sync,
MutationT: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
MutationT::TypeInfo: Send + Sync,
SubscriptionT: GraphQLSubscriptionType<S, Context = CtxT> + Send + 'static,
SubscriptionT::TypeInfo: Send + Sync,
CtxT: Unpin + Send + Sync,
S: ScalarValue + Send + Sync + 'static,
{
type Context = CtxT;
type ScalarValue = S;
type QueryTypeInfo = QueryT::TypeInfo;
type Query = QueryT;
type MutationTypeInfo = MutationT::TypeInfo;
type Mutation = MutationT;
type SubscriptionTypeInfo = SubscriptionT::TypeInfo;
type Subscription = SubscriptionT;

fn root_node(&self) -> &RootNode<'static, QueryT, MutationT, SubscriptionT, S> {
self
}
}
191 changes: 191 additions & 0 deletions juniper_graphql_ws/src/server_message.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
use juniper::{ExecutionError, GraphQLError, ScalarValue, Value};
use serde::{Serialize, Serializer};
use std::{any::Any, fmt, marker::PhantomPinned};

/// The payload for errors that are not associated with a GraphQL operation.
#[derive(Debug, Serialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct ConnectionErrorPayload {
/// The error message.
pub message: String,
}

/// Sent after execution of an operation. For queries and mutations, this is sent to the client
/// once. For subscriptions, this is sent for every event in the event stream.
#[derive(Debug, Serialize, PartialEq)]
#[serde(bound(serialize = "S: ScalarValue"))]
#[serde(rename_all = "camelCase")]
pub struct DataPayload<S> {
/// The result data.
pub data: Value<S>,

/// The errors that have occurred during execution. Note that parse and validation errors are
/// not included here. They are sent via Error messages.
pub errors: Vec<ExecutionError<S>>,
}

/// A payload for errors that can happen before execution. Errors that happen during execution are
/// instead sent to the client via `DataPayload`. `ErrorPayload` is a wrapper for an owned
/// `GraphQLError`.
// XXX: Think carefully before deriving traits. This is self-referential (error references
// _execution_params).
pub struct ErrorPayload {
_execution_params: Option<Box<dyn Any + Send>>,
error: GraphQLError<'static>,
_marker: PhantomPinned,
}

impl ErrorPayload {
/// For this to be okay, the caller must guarantee that the error can only reference data from
/// execution_params and that execution_params has not been modified or moved.
pub(crate) unsafe fn new_unchecked<'a>(
execution_params: Box<dyn Any + Send>,
error: GraphQLError<'a>,
) -> Self {
Self {
_execution_params: Some(execution_params),
error: std::mem::transmute(error),
_marker: PhantomPinned,
}
}

/// Returns the contained GraphQLError.
pub fn graphql_error<'a>(&'a self) -> &GraphQLError<'a> {
&self.error
}
}

impl fmt::Debug for ErrorPayload {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.error.fmt(f)
}
}

impl PartialEq for ErrorPayload {
fn eq(&self, other: &Self) -> bool {
self.error.eq(&other.error)
}
}

impl Serialize for ErrorPayload {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
self.error.serialize(serializer)
}
}

impl From<GraphQLError<'static>> for ErrorPayload {
fn from(error: GraphQLError<'static>) -> Self {
Self {
_execution_params: None,
error,
_marker: PhantomPinned,
}
}
}

/// ServerMessage defines the message types that servers can send.
#[derive(Debug, Serialize, PartialEq)]
#[serde(bound(serialize = "S: ScalarValue"))]
#[serde(rename_all = "snake_case")]
#[serde(tag = "type")]
pub enum ServerMessage<S: ScalarValue> {
/// ConnectionError is used for errors that are not associated with a GraphQL operation. For
/// example, this will be used when:
///
/// * The server is unable to parse a client's message.
/// * The client's initialization parameters are rejected.
ConnectionError {
/// The error that occurred.
payload: ConnectionErrorPayload,
},
/// ConnectionAck is sent in response to a client's ConnectionInit message if the server accepted a
/// connection.
ConnectionAck,
/// Data contains the result of a query, mutation, or subscription event.
Data {
/// The id of the operation that the data is for.
id: String,

/// The data and errors that occurred during execution.
payload: DataPayload<S>,
},
/// Error contains an error that occurs before execution, such as validation errors.
Error {
/// The id of the operation that triggered this error.
id: String,

/// The error(s).
payload: ErrorPayload,
},
/// Complete indicates that no more data will be sent for the given operation.
Complete {
/// The id of the operation that has completed.
id: String,
},
/// ConnectionKeepAlive is sent periodically after accepting a connection.
#[serde(rename = "ka")]
ConnectionKeepAlive,
}

#[cfg(test)]
mod test {
use super::*;
use juniper::DefaultScalarValue;

#[test]
fn test_serialization() {
type ServerMessage = super::ServerMessage<DefaultScalarValue>;

assert_eq!(
serde_json::to_string(&ServerMessage::ConnectionError {
payload: ConnectionErrorPayload {
message: "foo".to_string(),
},
})
.unwrap(),
r##"{"type":"connection_error","payload":{"message":"foo"}}"##,
);

assert_eq!(
serde_json::to_string(&ServerMessage::ConnectionAck).unwrap(),
r##"{"type":"connection_ack"}"##,
);

assert_eq!(
serde_json::to_string(&ServerMessage::Data {
id: "foo".to_string(),
payload: DataPayload {
data: Value::null(),
errors: vec![],
},
})
.unwrap(),
r##"{"type":"data","id":"foo","payload":{"data":null,"errors":[]}}"##,
);

assert_eq!(
serde_json::to_string(&ServerMessage::Error {
id: "foo".to_string(),
payload: GraphQLError::UnknownOperationName.into(),
})
.unwrap(),
r##"{"type":"error","id":"foo","payload":[{"message":"Unknown operation"}]}"##,
);

assert_eq!(
serde_json::to_string(&ServerMessage::Complete {
id: "foo".to_string(),
})
.unwrap(),
r##"{"type":"complete","id":"foo"}"##,
);

assert_eq!(
serde_json::to_string(&ServerMessage::ConnectionKeepAlive).unwrap(),
r##"{"type":"ka"}"##,
);
}
}
16 changes: 11 additions & 5 deletions juniper_subscriptions/src/lib.rs
Original file line number Diff line number Diff line change
@@ -222,19 +222,25 @@ where
}

if filled_count == obj_len {
let mut errors = vec![];
filled_count = 0;
let new_vec = (0..obj_len).map(|_| None).collect::<Vec<_>>();
let ready_vec = std::mem::replace(&mut ready_vec, new_vec);
let ready_vec_iterator = ready_vec.into_iter().map(|el| {
let (name, val) = el.unwrap();
if let Ok(value) = val {
(name, value)
} else {
(name, Value::Null)
match val {
Ok(value) => (name, value),
Err(e) => {
errors.push(e);
(name, Value::Null)
}
}
});
let obj = Object::from_iter(ready_vec_iterator);
Poll::Ready(Some(ExecutionOutput::from_data(Value::Object(obj))))
Poll::Ready(Some(ExecutionOutput {
data: Value::Object(obj),
errors,
}))
} else {
Poll::Pending
}
4 changes: 2 additions & 2 deletions juniper_warp/Cargo.toml
Original file line number Diff line number Diff line change
@@ -9,15 +9,15 @@ repository = "https://github.com/graphql-rust/juniper"
edition = "2018"

[features]
subscriptions = ["juniper_subscriptions"]
subscriptions = ["juniper_graphql_ws"]

[dependencies]
bytes = "0.5"
anyhow = "1.0"
thiserror = "1.0"
futures = "0.3.1"
juniper = { version = "0.14.2", path = "../juniper", default-features = false }
juniper_subscriptions = { path = "../juniper_subscriptions", optional = true }
juniper_graphql_ws = { path = "../juniper_graphql_ws", optional = true }
serde = { version = "1.0.75", features = ["derive"] }
serde_json = "1.0.24"
tokio = { version = "0.2", features = ["blocking", "rt-core"] }
289 changes: 84 additions & 205 deletions juniper_warp/src/lib.rs
Original file line number Diff line number Diff line change
@@ -393,224 +393,103 @@ fn playground_response(
/// [1]: https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md
#[cfg(feature = "subscriptions")]
pub mod subscriptions {
use std::{
collections::HashMap,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
use juniper::{
futures::{
future::{self, Either},
sink::SinkExt,
stream::StreamExt,
},
GraphQLSubscriptionType, GraphQLTypeAsync, RootNode, ScalarValue,
};
use juniper_graphql_ws::{ArcSchema, ClientMessage, Connection, Init};
use std::{convert::Infallible, fmt, sync::Arc};

use anyhow::anyhow;
use futures::{channel::mpsc, Future, StreamExt as _, TryFutureExt as _, TryStreamExt as _};
use juniper::{http::GraphQLRequest, InputValue, ScalarValue, SubscriptionCoordinator as _};
use juniper_subscriptions::Coordinator;
use serde::{Deserialize, Serialize};
use warp::ws::Message;

/// Listen to incoming messages and do one of the following:
/// - execute subscription and return values from stream
/// - stop stream and close ws connection
#[allow(dead_code)]
pub fn graphql_subscriptions<Query, Mutation, Subscription, CtxT, S>(
websocket: warp::ws::WebSocket,
coordinator: Arc<Coordinator<'static, Query, Mutation, Subscription, CtxT, S>>,
context: CtxT,
) -> impl Future<Output = Result<(), anyhow::Error>> + Send
where
Query: juniper::GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
Query::TypeInfo: Send + Sync,
Mutation: juniper::GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
Mutation::TypeInfo: Send + Sync,
Subscription: juniper::GraphQLSubscriptionType<S, Context = CtxT> + Send + 'static,
Subscription::TypeInfo: Send + Sync,
CtxT: Send + Sync + 'static,
S: ScalarValue + Send + Sync + 'static,
{
let (sink_tx, sink_rx) = websocket.split();
let (ws_tx, ws_rx) = mpsc::unbounded();
tokio::task::spawn(
ws_rx
.take_while(|v: &Option<_>| futures::future::ready(v.is_some()))
.map(|x| x.unwrap())
.forward(sink_tx),
);
struct Message(warp::ws::Message);

let context = Arc::new(context);
let got_close_signal = Arc::new(AtomicBool::new(false));
let got_close_signal2 = got_close_signal.clone();
impl<S: ScalarValue> std::convert::TryFrom<Message> for ClientMessage<S> {
type Error = serde_json::Error;

struct SubscriptionState {
should_stop: AtomicBool,
fn try_from(msg: Message) -> serde_json::Result<Self> {
serde_json::from_slice(msg.0.as_bytes())
}
let subscription_states = HashMap::<String, Arc<SubscriptionState>>::new();
}

sink_rx
.map_err(move |e| {
got_close_signal2.store(true, Ordering::Relaxed);
anyhow!("Websocket error: {}", e)
})
.try_fold(subscription_states, move |mut subscription_states, msg| {
let coordinator = coordinator.clone();
let context = context.clone();
let got_close_signal = got_close_signal.clone();
let ws_tx = ws_tx.clone();

async move {
if msg.is_close() {
return Ok(subscription_states);
}

let msg = msg
.to_str()
.map_err(|_| anyhow!("Non-text messages are not accepted"))?;
let request: WsPayload<S> = serde_json::from_str(msg)
.map_err(|e| anyhow!("Invalid WsPayload: {}", e))?;

match request.type_name.as_str() {
"connection_init" => {}
"start" => {
if got_close_signal.load(Ordering::Relaxed) {
return Ok(subscription_states);
}

let request_id = request.id.clone().unwrap_or("1".to_owned());

if let Some(existing) = subscription_states.get(&request_id) {
existing.should_stop.store(true, Ordering::Relaxed);
}
let state = Arc::new(SubscriptionState {
should_stop: AtomicBool::new(false),
});
subscription_states.insert(request_id.clone(), state.clone());

let ws_tx = ws_tx.clone();

if let Some(ref payload) = request.payload {
if payload.query.is_none() {
return Err(anyhow!("Query not found"));
}
} else {
return Err(anyhow!("Payload not found"));
}

tokio::task::spawn(async move {
let payload = request.payload.unwrap();

let graphql_request = GraphQLRequest::<S>::new(
payload.query.unwrap(),
None,
payload.variables,
);

let values_stream = match coordinator
.subscribe(&graphql_request, &context)
.await
{
Ok(s) => s,
Err(err) => {
let _ =
ws_tx.unbounded_send(Some(Ok(Message::text(format!(
r#"{{"type":"error","id":"{}","payload":{}}}"#,
request_id,
serde_json::ser::to_string(&err).unwrap_or(
"Error deserializing GraphQLError".to_owned()
)
)))));

let close_message = format!(
r#"{{"type":"complete","id":"{}","payload":null}}"#,
request_id
);
let _ = ws_tx
.unbounded_send(Some(Ok(Message::text(close_message))));
// close channel
let _ = ws_tx.unbounded_send(None);
return;
}
};

values_stream
.take_while(move |response| {
let request_id = request_id.clone();
let should_stop = state.should_stop.load(Ordering::Relaxed)
|| got_close_signal.load(Ordering::Relaxed);
if !should_stop {
let mut response_text = serde_json::to_string(
&response,
)
.unwrap_or("Error deserializing response".to_owned());

response_text = format!(
r#"{{"type":"data","id":"{}","payload":{} }}"#,
request_id, response_text
);

let _ = ws_tx.unbounded_send(Some(Ok(Message::text(
response_text,
))));
}

async move { !should_stop }
})
.for_each(|_| async {})
.await;
});
}
"stop" => {
let request_id = request.id.unwrap_or("1".to_owned());
if let Some(existing) = subscription_states.get(&request_id) {
existing.should_stop.store(true, Ordering::Relaxed);
subscription_states.remove(&request_id);
}

let close_message = format!(
r#"{{"type":"complete","id":"{}","payload":null}}"#,
request_id
);
let _ = ws_tx.unbounded_send(Some(Ok(Message::text(close_message))));

// close channel
let _ = ws_tx.unbounded_send(None);
}
_ => {}
}

Ok(subscription_states)
}
})
.map_ok(|_| ())
/// Errors that can happen while serving a connection.
#[derive(Debug)]
pub enum Error {
/// Errors that can happen in Warp while serving a connection.
Warp(warp::Error),

/// Errors that can happen while serializing outgoing messages. Note that errors that occur
/// while deserializing internal messages are handled internally by the protocol.
Serde(serde_json::Error),
}

#[derive(Deserialize)]
#[serde(bound = "GraphQLPayload<S>: Deserialize<'de>")]
struct WsPayload<S>
where
S: ScalarValue + Send + Sync,
{
id: Option<String>,
#[serde(rename(deserialize = "type"))]
type_name: String,
payload: Option<GraphQLPayload<S>>,
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Warp(e) => write!(f, "warp error: {}", e),
Self::Serde(e) => write!(f, "serde error: {}", e),
}
}
}

#[derive(Debug, Deserialize)]
#[serde(bound = "InputValue<S>: Deserialize<'de>")]
struct GraphQLPayload<S>
where
S: ScalarValue + Send + Sync,
{
variables: Option<InputValue<S>>,
extensions: Option<HashMap<String, String>>,
#[serde(rename(deserialize = "operationName"))]
operaton_name: Option<String>,
query: Option<String>,
impl std::error::Error for Error {}

impl From<warp::Error> for Error {
fn from(err: warp::Error) -> Self {
Self::Warp(err)
}
}

#[derive(Serialize)]
struct Output {
data: String,
variables: String,
impl From<Infallible> for Error {
fn from(_err: Infallible) -> Self {
unreachable!()
}
}

/// Serves the graphql-ws protocol over a WebSocket connection.
///
/// The `init` argument is used to provide the context and additional configuration for
/// connections. This can be a `juniper_graphql_ws::ConnectionConfig` if the context and
/// configuration are already known, or it can be a closure that gets executed asynchronously
/// when the client sends the ConnectionInit message. Using a closure allows you to perform
/// authentication based on the parameters provided by the client.
pub async fn serve_graphql_ws<Query, Mutation, Subscription, CtxT, S, I>(
websocket: warp::ws::WebSocket,
root_node: Arc<RootNode<'static, Query, Mutation, Subscription, S>>,
init: I,
) -> Result<(), Error>
where
Query: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
Query::TypeInfo: Send + Sync,
Mutation: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
Mutation::TypeInfo: Send + Sync,
Subscription: GraphQLSubscriptionType<S, Context = CtxT> + Send + 'static,
Subscription::TypeInfo: Send + Sync,
CtxT: Unpin + Send + Sync + 'static,
S: ScalarValue + Send + Sync + 'static,
I: Init<S, CtxT> + Send,
{
let (ws_tx, ws_rx) = websocket.split();
let (s_tx, s_rx) = Connection::new(ArcSchema(root_node), init).split();

let ws_rx = ws_rx.map(|r| r.map(|msg| Message(msg)));
let s_rx = s_rx.map(|msg| {
serde_json::to_string(&msg)
.map(|t| warp::ws::Message::text(t))
.map_err(|e| Error::Serde(e))
});

match future::select(
ws_rx.forward(s_tx.sink_err_into()),
s_rx.forward(ws_tx.sink_err_into()),
)
.await
{
Either::Left((r, _)) => r.map_err(|e| e.into()),
Either::Right((r, _)) => r,
}
}
}