Skip to content

Batch request and spec conformance #14

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jan 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ tokio-util = { version = "0.7.13", optional = true, features = ["io"] }
tokio-tungstenite = { version = "0.26.1", features = ["rustls-tls-webpki-roots"], optional = true }
futures-util = { version = "0.3.31", optional = true }

[dev-dependencies]
tempfile = "3.15.0"
tracing-subscriber = "0.3.19"

[features]
default = ["axum", "ws", "ipc"]
axum = ["dep:axum"]
Expand Down Expand Up @@ -66,7 +70,3 @@ inherits = "dev"
strip = true
debug = false
incremental = false

[dev-dependencies]
tempfile = "3.15.0"
tracing-subscriber = "0.3.19"
30 changes: 13 additions & 17 deletions src/axum.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
use crate::{
types::{Request, Response},
HandlerArgs,
};
use crate::types::{InboundData, Response};
use axum::{extract::FromRequest, response::IntoResponse};
use bytes::Bytes;
use std::{future::Future, pin::Pin};
Expand All @@ -18,20 +15,19 @@ where
return Box::<str>::from(Response::parse_error()).into_response();
};

let Ok(req) = Request::try_from(bytes) else {
return Box::<str>::from(Response::parse_error()).into_response();
};

let args = HandlerArgs {
ctx: Default::default(),
req,
};

// Default handler ctx does not allow for notifications, which is
// what we want over HTTP.
let response = unwrap_infallible!(self.call_with_state(args, state).await);
// If the inbound data is not currently parsable, we
// send an empty one it to the router, as the router enforces
// the specification.
let req = InboundData::try_from(bytes).unwrap_or_default();

Box::<str>::from(response).into_response()
if let Some(response) = self
.call_batch_with_state(Default::default(), req, state)
.await
{
Box::<str>::from(response).into_response()
} else {
().into_response()
}
})
}
}
8 changes: 5 additions & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,8 @@ pub mod pubsub;
pub use pubsub::ReadJsonStream;

mod routes;
pub use routes::{BatchFuture, Handler, HandlerArgs, HandlerCtx, NotifyError, RouteFuture};
pub(crate) use routes::{BoxedIntoRoute, ErasedIntoRoute, Method, Route};
pub use routes::{Handler, HandlerArgs, HandlerCtx, NotifyError, RouteFuture};

mod router;
pub use router::Router;
Expand Down Expand Up @@ -208,7 +208,8 @@ mod test {
(),
)
.await
.expect("infallible");
.expect("infallible")
.expect("request had ID, is not a notification");

assert_rv_eq(
&res,
Expand All @@ -226,7 +227,8 @@ mod test {
(),
)
.await
.expect("infallible");
.expect("infallible")
.expect("request had ID, is not a notification");

assert_rv_eq(&res2, r#"{"jsonrpc":"2.0","id":1,"result":"{}"}"#);
}
Expand Down
39 changes: 15 additions & 24 deletions src/pubsub/shared.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@ use core::fmt;

use crate::{
pubsub::{In, JsonSink, Listener, Out},
types::Request,
HandlerArgs,
types::InboundData,
};
use serde_json::value::RawValue;
use tokio::{
Expand Down Expand Up @@ -193,31 +192,25 @@ where
select! {
biased;
_ = write_task.closed() => {
debug!("IpcWriteTask has gone away");
debug!("WriteTask has gone away");
break;
}
item = requests.next() => {
let Some(item) = item else {
trace!("IPC read stream has closed");
trace!("inbound read stream has closed");
break;
};

let req = match Request::try_from(item) {
Ok(req) => req,
Err(err) => {
tracing::warn!(%err, "inbound request is malformatted");
continue
}
};
// If the inbound data is not currently parsable, we
// send an empty one it to the router, as the router
// enforces the specification.
let reqs = InboundData::try_from(item).unwrap_or_default();

let span = debug_span!("ipc request handling", id = req.id(), method = req.method());
let span = debug_span!("pubsub request handling", reqs = reqs.len());

let args = HandlerArgs {
ctx: write_task.clone().into(),
req,
};
let ctx = write_task.clone().into();

let fut = router.handle_request(args);
let fut = router.handle_request_batch(ctx, reqs);
let write_task = write_task.clone();

// Acquiring the permit before spawning the task means that
Expand All @@ -232,16 +225,14 @@ where
// Run the future in a new task.
tokio::spawn(
async move {
// Run the request handler and serialize the
// response.
let rv = fut.await.expect("infallible");

// Send the response to the write task.
// we don't care if the receiver has gone away,
// as the task is done regardless.
let _ = permit.send(
rv
);
if let Some(rv) = fut.await {
let _ = permit.send(
rv
);
}
}
.instrument(span)
);
Expand Down
70 changes: 48 additions & 22 deletions src/router.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
//! JSON-RPC router.

use crate::{
routes::{MakeErasedHandler, RouteFuture},
BoxedIntoRoute, ErasedIntoRoute, Handler, HandlerArgs, Method, MethodId, RegistrationError,
Route,
routes::{BatchFuture, MakeErasedHandler, RouteFuture},
types::InboundData,
BoxedIntoRoute, ErasedIntoRoute, Handler, HandlerArgs, HandlerCtx, Method, MethodId,
RegistrationError, Route,
};
use core::fmt;
use serde_json::value::RawValue;
Expand Down Expand Up @@ -193,7 +194,7 @@ where
where
T: Service<
HandlerArgs,
Response = Box<RawValue>,
Response = Option<Box<RawValue>>,
Error = Infallible,
Future: Send + 'static,
> + Clone
Expand Down Expand Up @@ -299,15 +300,35 @@ where
/// This is a convenience method, primarily for testing. Use in production
/// code is discouraged. Routers should not be left in incomplete states.
pub fn call_with_state(&self, args: HandlerArgs, state: S) -> RouteFuture {
let id = args.req.id_owned();
let method = args.req.method();
let id = args.req().id_owned();
let method = args.req().method();

let span = debug_span!("Router::call_with_state", %method, %id);
trace!(params = args.req.params());
let span = debug_span!("Router::call_with_state", %method, ?id);
trace!(params = args.req().params());

self.inner.call_with_state(args, state).with_span(span)
}

/// Call a method on the router, without providing state.
pub fn call_batch_with_state(
&self,
ctx: HandlerCtx,
inbound: InboundData,
state: S,
) -> BatchFuture {
let mut fut = BatchFuture::new_with_capacity(inbound.single(), inbound.len());
// According to spec, non-parsable requests should still receive a
// response.
for req in inbound.iter() {
let req = req.map(|req| {
let args = HandlerArgs::new(ctx.clone(), req);
self.call_with_state(args, state.clone())
});
fut.push_parse_result(req);
}
fut
}

/// Nest this router into a new Axum router, with the specified path.
#[cfg(feature = "axum")]
pub fn into_axum(self, path: &str) -> axum::Router<S> {
Expand All @@ -316,22 +337,27 @@ where
}

impl Router<()> {
// /// Serve the router over a connection. This method returns a
// /// [`ServerShutdown`], which will shut down the server when dropped.
// ///
// /// [`ServerShutdown`]: crate::pubsub::ServerShutdown
// #[cfg(feature = "pubsub")]
// pub async fn serve_pubsub<C: crate::pubsub::Connect>(
// self,
// connect: C,
// ) -> Result<crate::pubsub::ServerShutdown, C::Error> {
// connect.run(self).await
// }
/// Serve the router over a connection. This method returns a
/// [`ServerShutdown`], which will shut down the server when dropped.
///
/// [`ServerShutdown`]: crate::pubsub::ServerShutdown
#[cfg(feature = "pubsub")]
pub async fn serve_pubsub<C: crate::pubsub::Connect>(
self,
connect: C,
) -> Result<crate::pubsub::ServerShutdown, C::Error> {
connect.serve(self).await
}

/// Call a method on the router.
pub fn handle_request(&self, args: HandlerArgs) -> RouteFuture {
self.call_with_state(args, ())
}

/// Call a batch of methods on the router.
pub fn handle_request_batch(&self, ctx: HandlerCtx, batch: InboundData) -> BatchFuture {
self.call_batch_with_state(ctx, batch, ())
}
}

impl<S> fmt::Debug for Router<S> {
Expand All @@ -341,7 +367,7 @@ impl<S> fmt::Debug for Router<S> {
}

impl tower::Service<HandlerArgs> for Router<()> {
type Response = Box<RawValue>;
type Response = Option<Box<RawValue>>;
type Error = Infallible;
type Future = RouteFuture;

Expand All @@ -355,7 +381,7 @@ impl tower::Service<HandlerArgs> for Router<()> {
}

impl tower::Service<HandlerArgs> for &Router<()> {
type Response = Box<RawValue>;
type Response = Option<Box<RawValue>>;
type Error = Infallible;
type Future = RouteFuture;

Expand Down Expand Up @@ -517,7 +543,7 @@ impl<S> RouterInner<S> {
/// Call a method on the router, with the provided state.
#[track_caller]
pub(crate) fn call_with_state(&self, args: HandlerArgs, state: S) -> RouteFuture {
let method = args.req.method();
let method = args.req().method();
self.method_by_name(method)
.unwrap_or(&self.fallback)
.call_with_state(args, state)
Expand Down
21 changes: 19 additions & 2 deletions src/routes/ctx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,24 @@ impl HandlerCtx {
#[derive(Debug, Clone)]
pub struct HandlerArgs {
/// The handler context.
pub ctx: HandlerCtx,
pub(crate) ctx: HandlerCtx,
/// The JSON-RPC request.
pub req: Request,
pub(crate) req: Request,
}

impl HandlerArgs {
/// Create new handler arguments.
pub const fn new(ctx: HandlerCtx, req: Request) -> Self {
Self { ctx, req }
}

/// Get a reference to the handler context.
pub const fn ctx(&self) -> &HandlerCtx {
&self.ctx
}

/// Get a reference to the JSON-RPC request.
pub const fn req(&self) -> &Request {
&self.req
}
}
Loading
Loading