diff --git a/Cargo.lock b/Cargo.lock index b966614a..4068354f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1683,7 +1683,7 @@ dependencies = [ "iroh-quinn-proto", "iroh-quinn-udp", "iroh-relay", - "n0-future", + "n0-future 0.1.3", "n0-snafu", "n0-watcher", "nested_enum_utils", @@ -1758,7 +1758,7 @@ dependencies = [ "iroh-quinn", "iroh-test", "irpc", - "n0-future", + "n0-future 0.2.0", "n0-snafu", "nested_enum_utils", "postcard", @@ -1900,7 +1900,7 @@ dependencies = [ "iroh-quinn", "iroh-quinn-proto", "lru", - "n0-future", + "n0-future 0.1.3", "n0-snafu", "nested_enum_utils", "num_enum", @@ -1951,7 +1951,7 @@ dependencies = [ "futures-util", "iroh-quinn", "irpc-derive", - "n0-future", + "n0-future 0.1.3", "postcard", "rcgen", "rustls", @@ -2173,6 +2173,27 @@ dependencies = [ "web-time", ] +[[package]] +name = "n0-future" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "89d7dd42bd0114c9daa9c4f2255d692a73bba45767ec32cf62892af6fe5d31f6" +dependencies = [ + "cfg_aliases", + "derive_more 1.0.0", + "futures-buffered", + "futures-lite", + "futures-util", + "js-sys", + "pin-project", + "send_wrapper", + "tokio", + "tokio-util", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-time", +] + [[package]] name = "n0-snafu" version = "0.2.1" @@ -2193,7 +2214,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c31462392a10d5ada4b945e840cbec2d5f3fee752b96c4b33eb41414d8f45c2a" dependencies = [ "derive_more 1.0.0", - "n0-future", + "n0-future 0.1.3", "snafu", ] @@ -2319,7 +2340,7 @@ dependencies = [ "iroh-quinn-udp", "js-sys", "libc", - "n0-future", + "n0-future 0.1.3", "n0-watcher", "nested_enum_utils", "netdev", @@ -3596,9 +3617,9 @@ dependencies = [ [[package]] name = "slab" -version = "0.4.10" +version = "0.4.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04dc19736151f35336d325007ac991178d504a119863a2fcb3758cdb5e52c50d" +checksum = "7a2ae44ef20feb57a68b23d846850f861394c2e02dc425a50098ae8c90267589" [[package]] name = "smallvec" diff --git a/Cargo.toml b/Cargo.toml index 3f9f47a9..bcd5f42d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,7 +18,7 @@ bytes = { version = "1", features = ["serde"] } derive_more = { version = "2.0.1", features = ["from", "try_from", "into", "debug", "display", "deref", "deref_mut"] } futures-lite = "2.6.0" quinn = { package = "iroh-quinn", version = "0.14.0" } -n0-future = "0.1.2" +n0-future = "0.2.0" n0-snafu = "0.2.0" range-collections = { version = "0.4.6", features = ["serde"] } redb = { version = "=2.4" } diff --git a/src/api/blobs/reader.rs b/src/api/blobs/reader.rs index e15e374d..9e337dae 100644 --- a/src/api/blobs/reader.rs +++ b/src/api/blobs/reader.rs @@ -221,6 +221,7 @@ mod tests { use super::*; use crate::{ + protocol::ChunkRangesExt, store::{ fs::{ tests::{create_n0_bao, test_data, INTERESTING_SIZES}, @@ -228,7 +229,6 @@ mod tests { }, mem::MemStore, }, - util::ChunkRangesExt, }; async fn reader_smoke(blobs: &Blobs) -> TestResult<()> { diff --git a/src/api/downloader.rs b/src/api/downloader.rs index 678a8c6a..a2abbd7e 100644 --- a/src/api/downloader.rs +++ b/src/api/downloader.rs @@ -4,26 +4,26 @@ use std::{ fmt::Debug, future::{Future, IntoFuture}, io, - ops::Deref, sync::Arc, - time::{Duration, SystemTime}, }; use anyhow::bail; use genawaiter::sync::Gen; -use iroh::{endpoint::Connection, Endpoint, NodeId}; +use iroh::{Endpoint, NodeId}; use irpc::{channel::mpsc, rpc_requests}; use n0_future::{future, stream, BufferedStreamExt, Stream, StreamExt}; use rand::seq::SliceRandom; use serde::{de::Error, Deserialize, Serialize}; -use tokio::{sync::Mutex, task::JoinSet}; -use tokio_util::time::FutureExt; -use tracing::{info, instrument::Instrument, warn}; +use tokio::task::JoinSet; +use tracing::instrument::Instrument; -use super::{remote::GetConnection, Store}; +use super::Store; use crate::{ protocol::{GetManyRequest, GetRequest}, - util::sink::{Drain, IrpcSenderRefSink, Sink, TokioMpscSenderSink}, + util::{ + connection_pool::ConnectionPool, + sink::{Drain, IrpcSenderRefSink, Sink, TokioMpscSenderSink}, + }, BlobFormat, Hash, HashAndFormat, }; @@ -69,7 +69,7 @@ impl DownloaderActor { fn new(store: Store, endpoint: Endpoint) -> Self { Self { store, - pool: ConnectionPool::new(endpoint, crate::ALPN.to_vec()), + pool: ConnectionPool::new(endpoint, crate::ALPN, Default::default()), tasks: JoinSet::new(), running: HashSet::new(), } @@ -414,90 +414,6 @@ async fn split_request<'a>( }) } -#[derive(Debug)] -struct ConnectionPoolInner { - alpn: Vec, - endpoint: Endpoint, - connections: Mutex>>>, - retry_delay: Duration, - connect_timeout: Duration, -} - -#[derive(Debug, Clone)] -struct ConnectionPool(Arc); - -#[derive(Debug, Default)] -enum SlotState { - #[default] - Initial, - Connected(Connection), - AttemptFailed(SystemTime), - #[allow(dead_code)] - Evil(String), -} - -impl ConnectionPool { - fn new(endpoint: Endpoint, alpn: Vec) -> Self { - Self( - ConnectionPoolInner { - endpoint, - alpn, - connections: Default::default(), - retry_delay: Duration::from_secs(5), - connect_timeout: Duration::from_secs(2), - } - .into(), - ) - } - - pub fn alpn(&self) -> &[u8] { - &self.0.alpn - } - - pub fn endpoint(&self) -> &Endpoint { - &self.0.endpoint - } - - pub fn retry_delay(&self) -> Duration { - self.0.retry_delay - } - - fn dial(&self, id: NodeId) -> DialNode { - DialNode { - pool: self.clone(), - id, - } - } - - #[allow(dead_code)] - async fn mark_evil(&self, id: NodeId, reason: String) { - let slot = self - .0 - .connections - .lock() - .await - .entry(id) - .or_default() - .clone(); - let mut t = slot.lock().await; - *t = SlotState::Evil(reason) - } - - #[allow(dead_code)] - async fn mark_closed(&self, id: NodeId) { - let slot = self - .0 - .connections - .lock() - .await - .entry(id) - .or_default() - .clone(); - let mut t = slot.lock().await; - *t = SlotState::Initial - } -} - /// Execute a get request sequentially for multiple providers. /// /// It will try each provider in order @@ -526,13 +442,13 @@ async fn execute_get( request: request.clone(), }) .await?; - let mut conn = pool.dial(provider); + let conn = pool.get_or_connect(provider); let local = remote.local_for_request(request.clone()).await?; if local.is_complete() { return Ok(()); } let local_bytes = local.local_bytes(); - let Ok(conn) = conn.connection().await else { + let Ok(conn) = conn.await else { progress .send(DownloadProgessItem::ProviderFailed { id: provider, @@ -543,7 +459,7 @@ async fn execute_get( }; match remote .execute_get_sink( - conn, + &conn, local.missing(), (&mut progress).with_map(move |x| DownloadProgessItem::Progress(x + local_bytes)), ) @@ -571,77 +487,6 @@ async fn execute_get( bail!("Unable to download {}", request.hash); } -#[derive(Debug, Clone)] -struct DialNode { - pool: ConnectionPool, - id: NodeId, -} - -impl DialNode { - async fn connection_impl(&self) -> anyhow::Result { - info!("Getting connection for node {}", self.id); - let slot = self - .pool - .0 - .connections - .lock() - .await - .entry(self.id) - .or_default() - .clone(); - info!("Dialing node {}", self.id); - let mut guard = slot.lock().await; - match guard.deref() { - SlotState::Connected(conn) => { - return Ok(conn.clone()); - } - SlotState::AttemptFailed(time) => { - let elapsed = time.elapsed().unwrap_or_default(); - if elapsed <= self.pool.retry_delay() { - bail!( - "Connection attempt failed {} seconds ago", - elapsed.as_secs_f64() - ); - } - } - SlotState::Evil(reason) => { - bail!("Node is banned due to evil behavior: {reason}"); - } - SlotState::Initial => {} - } - let res = self - .pool - .endpoint() - .connect(self.id, self.pool.alpn()) - .timeout(self.pool.0.connect_timeout) - .await; - match res { - Ok(Ok(conn)) => { - info!("Connected to node {}", self.id); - *guard = SlotState::Connected(conn.clone()); - Ok(conn) - } - Ok(Err(e)) => { - warn!("Failed to connect to node {}: {}", self.id, e); - *guard = SlotState::AttemptFailed(SystemTime::now()); - Err(e.into()) - } - Err(e) => { - warn!("Failed to connect to node {}: {}", self.id, e); - *guard = SlotState::AttemptFailed(SystemTime::now()); - bail!("Failed to connect to node: {}", e); - } - } - } -} - -impl GetConnection for DialNode { - fn connection(&mut self) -> impl Future> + '_ { - let this = self.clone(); - async move { this.connection_impl().await } - } -} - /// Trait for pluggable content discovery strategies. pub trait ContentDiscovery: Debug + Send + Sync + 'static { fn find_providers(&self, hash: HashAndFormat) -> n0_future::stream::Boxed; diff --git a/src/api/remote.rs b/src/api/remote.rs index 47c3eea2..62320090 100644 --- a/src/api/remote.rs +++ b/src/api/remote.rs @@ -518,7 +518,7 @@ impl Remote { .connection() .await .map_err(|e| LocalFailureSnafu.into_error(e.into()))?; - let stats = self.execute_get_sink(conn, request, progress).await?; + let stats = self.execute_get_sink(&conn, request, progress).await?; Ok(stats) } @@ -637,7 +637,7 @@ impl Remote { .with_map_err(io::Error::other); let this = self.clone(); let fut = async move { - let res = this.execute_get_sink(conn, request, sink).await.into(); + let res = this.execute_get_sink(&conn, request, sink).await.into(); tx2.send(res).await.ok(); }; GetProgress { @@ -656,13 +656,15 @@ impl Remote { /// This will return the stats of the download. pub(crate) async fn execute_get_sink( &self, - conn: Connection, + conn: &Connection, request: GetRequest, mut progress: impl Sink, ) -> GetResult { let store = self.store(); let root = request.hash; - let start = crate::get::fsm::start(conn, request, Default::default()); + // I am cloning the connection, but it's fine because the original connection or ConnectionRef stays alive + // for the duration of the operation. + let start = crate::get::fsm::start(conn.clone(), request, Default::default()); let connected = start.next().await?; trace!("Getting header"); // read the header @@ -1065,7 +1067,7 @@ mod tests { use crate::{ api::blobs::Blobs, - protocol::{ChunkRangesSeq, GetRequest}, + protocol::{ChunkRangesExt, ChunkRangesSeq, GetRequest}, store::{ fs::{ tests::{create_n0_bao, test_data, INTERESTING_SIZES}, @@ -1074,7 +1076,6 @@ mod tests { mem::MemStore, }, tests::{add_test_hash_seq, add_test_hash_seq_incomplete}, - util::ChunkRangesExt, }; #[tokio::test] diff --git a/src/get/request.rs b/src/get/request.rs index 86ffcabb..98563057 100644 --- a/src/get/request.rs +++ b/src/get/request.rs @@ -27,8 +27,7 @@ use super::{fsm, GetError, GetResult, Stats}; use crate::{ get::error::{BadRequestSnafu, LocalFailureSnafu}, hashseq::HashSeq, - protocol::{ChunkRangesSeq, GetRequest}, - util::ChunkRangesExt, + protocol::{ChunkRangesExt, ChunkRangesSeq, GetRequest}, Hash, HashAndFormat, }; diff --git a/src/lib.rs b/src/lib.rs index ed4f7850..521ba4f7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -43,7 +43,7 @@ pub mod ticket; #[doc(hidden)] pub mod test; -mod util; +pub mod util; #[cfg(test)] mod tests; diff --git a/src/protocol.rs b/src/protocol.rs index 85043199..74e0f986 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -373,13 +373,18 @@ //! a large existing system that has demonstrated performance issues. //! //! If in doubt, just use multiple requests and multiple connections. -use std::io; +use std::{ + io, + ops::{Bound, RangeBounds}, +}; +use bao_tree::{io::round_up_to_chunks, ChunkNum}; use builder::GetRequestBuilder; use derive_more::From; use iroh::endpoint::VarInt; use irpc::util::AsyncReadVarintExt; use postcard::experimental::max_size::MaxSize; +use range_collections::{range_set::RangeSetEntry, RangeSet2}; use serde::{Deserialize, Serialize}; mod range_spec; pub use bao_tree::ChunkRanges; @@ -387,7 +392,6 @@ pub use range_spec::{ChunkRangesSeq, NonEmptyRequestRangeSpecIter, RangeSpec}; use snafu::{GenerateImplicitData, Snafu}; use tokio::io::AsyncReadExt; -pub use crate::util::ChunkRangesExt; use crate::{api::blobs::Bitfield, provider::CountingReader, BlobFormat, Hash, HashAndFormat}; /// Maximum message size is limited to 100MiB for now. @@ -714,6 +718,73 @@ impl TryFrom for Closed { } } +pub trait ChunkRangesExt { + fn last_chunk() -> Self; + fn chunk(offset: u64) -> Self; + fn bytes(ranges: impl RangeBounds) -> Self; + fn chunks(ranges: impl RangeBounds) -> Self; + fn offset(offset: u64) -> Self; +} + +impl ChunkRangesExt for ChunkRanges { + fn last_chunk() -> Self { + ChunkRanges::from(ChunkNum(u64::MAX)..) + } + + /// Create a chunk range that contains a single chunk. + fn chunk(offset: u64) -> Self { + ChunkRanges::from(ChunkNum(offset)..ChunkNum(offset + 1)) + } + + /// Create a range of chunks that contains the given byte ranges. + /// The byte ranges are rounded up to the nearest chunk size. + fn bytes(ranges: impl RangeBounds) -> Self { + round_up_to_chunks(&bounds_from_range(ranges, |v| v)) + } + + /// Create a range of chunks from u64 chunk bounds. + /// + /// This is equivalent but more convenient than using the ChunkNum newtype. + fn chunks(ranges: impl RangeBounds) -> Self { + bounds_from_range(ranges, ChunkNum) + } + + /// Create a chunk range that contains a single byte offset. + fn offset(offset: u64) -> Self { + Self::bytes(offset..offset + 1) + } +} + +// todo: move to range_collections +pub(crate) fn bounds_from_range(range: R, f: F) -> RangeSet2 +where + R: RangeBounds, + T: RangeSetEntry, + F: Fn(u64) -> T, +{ + let from = match range.start_bound() { + Bound::Included(start) => Some(*start), + Bound::Excluded(start) => { + let Some(start) = start.checked_add(1) else { + return RangeSet2::empty(); + }; + Some(start) + } + Bound::Unbounded => None, + }; + let to = match range.end_bound() { + Bound::Included(end) => end.checked_add(1), + Bound::Excluded(end) => Some(*end), + Bound::Unbounded => None, + }; + match (from, to) { + (Some(from), Some(to)) => RangeSet2::from(f(from)..f(to)), + (Some(from), None) => RangeSet2::from(f(from)..), + (None, Some(to)) => RangeSet2::from(..f(to)), + (None, None) => RangeSet2::all(), + } +} + pub mod builder { use std::collections::BTreeMap; @@ -863,7 +934,7 @@ pub mod builder { use bao_tree::ChunkNum; use super::*; - use crate::{protocol::GetManyRequest, util::ChunkRangesExt}; + use crate::protocol::{ChunkRangesExt, GetManyRequest}; #[test] fn chunk_ranges_ext() { diff --git a/src/protocol/range_spec.rs b/src/protocol/range_spec.rs index 92cfe938..546dbe70 100644 --- a/src/protocol/range_spec.rs +++ b/src/protocol/range_spec.rs @@ -12,7 +12,7 @@ use bao_tree::{ChunkNum, ChunkRangesRef}; use serde::{Deserialize, Serialize}; use smallvec::{smallvec, SmallVec}; -pub use crate::util::ChunkRangesExt; +use crate::protocol::ChunkRangesExt; static CHUNK_RANGES_EMPTY: OnceLock = OnceLock::new(); @@ -511,7 +511,7 @@ mod tests { use proptest::prelude::*; use super::*; - use crate::util::ChunkRangesExt; + use crate::protocol::ChunkRangesExt; fn ranges(value_range: Range) -> impl Strategy { prop::collection::vec((value_range.clone(), value_range), 0..16).prop_map(|v| { diff --git a/src/store/fs.rs b/src/store/fs.rs index 9e11e098..e8a87ad6 100644 --- a/src/store/fs.rs +++ b/src/store/fs.rs @@ -111,6 +111,7 @@ use crate::{ }, ApiClient, }, + protocol::ChunkRangesExt, store::{ fs::{ bao_file::{ @@ -125,7 +126,6 @@ use crate::{ util::{ channel::oneshot, temp_tag::{TagDrop, TempTag, TempTagScope, TempTags}, - ChunkRangesExt, }, }; mod bao_file; diff --git a/src/store/mem.rs b/src/store/mem.rs index 6d022e0f..8a2a227b 100644 --- a/src/store/mem.rs +++ b/src/store/mem.rs @@ -56,14 +56,12 @@ use crate::{ tags::TagInfo, ApiClient, }, + protocol::ChunkRangesExt, store::{ util::{SizeInfo, SparseMemFile, Tag}, HashAndFormat, IROH_BLOCK_SIZE, }, - util::{ - temp_tag::{TagDrop, TempTagScope, TempTags}, - ChunkRangesExt, - }, + util::temp_tag::{TagDrop, TempTagScope, TempTags}, BlobFormat, Hash, }; diff --git a/src/store/readonly_mem.rs b/src/store/readonly_mem.rs index 42274b2e..0d9b1936 100644 --- a/src/store/readonly_mem.rs +++ b/src/store/readonly_mem.rs @@ -41,8 +41,8 @@ use crate::{ }, ApiClient, TempTag, }, + protocol::ChunkRangesExt, store::{mem::CompleteStorage, IROH_BLOCK_SIZE}, - util::ChunkRangesExt, Hash, }; diff --git a/src/util.rs b/src/util.rs index 7b9ad4e6..3fdaacbc 100644 --- a/src/util.rs +++ b/src/util.rs @@ -1,11 +1,9 @@ -use std::ops::{Bound, RangeBounds}; - -use bao_tree::{io::round_up_to_chunks, ChunkNum, ChunkRanges}; -use range_collections::{range_set::RangeSetEntry, RangeSet2}; - -pub mod channel; +//! Utilities +pub(crate) mod channel; +pub mod connection_pool; pub(crate) mod temp_tag; -pub mod serde { + +pub(crate) mod serde { // Module that handles io::Error serialization/deserialization pub mod io_error_serde { use std::{fmt, io}; @@ -216,74 +214,7 @@ pub mod serde { } } -pub trait ChunkRangesExt { - fn last_chunk() -> Self; - fn chunk(offset: u64) -> Self; - fn bytes(ranges: impl RangeBounds) -> Self; - fn chunks(ranges: impl RangeBounds) -> Self; - fn offset(offset: u64) -> Self; -} - -impl ChunkRangesExt for ChunkRanges { - fn last_chunk() -> Self { - ChunkRanges::from(ChunkNum(u64::MAX)..) - } - - /// Create a chunk range that contains a single chunk. - fn chunk(offset: u64) -> Self { - ChunkRanges::from(ChunkNum(offset)..ChunkNum(offset + 1)) - } - - /// Create a range of chunks that contains the given byte ranges. - /// The byte ranges are rounded up to the nearest chunk size. - fn bytes(ranges: impl RangeBounds) -> Self { - round_up_to_chunks(&bounds_from_range(ranges, |v| v)) - } - - /// Create a range of chunks from u64 chunk bounds. - /// - /// This is equivalent but more convenient than using the ChunkNum newtype. - fn chunks(ranges: impl RangeBounds) -> Self { - bounds_from_range(ranges, ChunkNum) - } - - /// Create a chunk range that contains a single byte offset. - fn offset(offset: u64) -> Self { - Self::bytes(offset..offset + 1) - } -} - -// todo: move to range_collections -pub(crate) fn bounds_from_range(range: R, f: F) -> RangeSet2 -where - R: RangeBounds, - T: RangeSetEntry, - F: Fn(u64) -> T, -{ - let from = match range.start_bound() { - Bound::Included(start) => Some(*start), - Bound::Excluded(start) => { - let Some(start) = start.checked_add(1) else { - return RangeSet2::empty(); - }; - Some(start) - } - Bound::Unbounded => None, - }; - let to = match range.end_bound() { - Bound::Included(end) => end.checked_add(1), - Bound::Excluded(end) => Some(*end), - Bound::Unbounded => None, - }; - match (from, to) { - (Some(from), Some(to)) => RangeSet2::from(f(from)..f(to)), - (Some(from), None) => RangeSet2::from(f(from)..), - (None, Some(to)) => RangeSet2::from(..f(to)), - (None, None) => RangeSet2::all(), - } -} - -pub mod outboard_with_progress { +pub(crate) mod outboard_with_progress { use std::io::{self, BufReader, Read}; use bao_tree::{ @@ -431,7 +362,7 @@ pub mod outboard_with_progress { } } -pub mod sink { +pub(crate) mod sink { use std::{future::Future, io}; use irpc::RpcMessage; diff --git a/src/util/connection_pool.rs b/src/util/connection_pool.rs new file mode 100644 index 00000000..7b283866 --- /dev/null +++ b/src/util/connection_pool.rs @@ -0,0 +1,460 @@ +//! A simple iroh connection pool +//! +//! Entry point is [`ConnectionPool`]. You create a connection pool for a specific +//! ALPN and [`Options`]. Then the pool will manage connections for you. +//! +//! Access to connections is via the [`ConnectionPool::get_or_connect`] method, which +//! gives you access to a connection via a [`ConnectionRef`] if possible. +//! +//! It is important that you keep the [`ConnectionRef`] alive while you are using +//! the connection. +use std::{ + collections::{HashMap, VecDeque}, + ops::Deref, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, + time::Duration, +}; + +use iroh::{endpoint::ConnectError, Endpoint, NodeId}; +use n0_future::{ + future::{self}, + FuturesUnordered, MaybeFuture, Stream, StreamExt, +}; +use snafu::Snafu; +use tokio::sync::{ + mpsc::{self, error::SendError as TokioSendError}, + oneshot, Notify, +}; +use tokio_util::time::FutureExt as TimeFutureExt; +use tracing::{debug, error, trace}; + +/// Configuration options for the connection pool +#[derive(Debug, Clone, Copy)] +pub struct Options { + pub idle_timeout: Duration, + pub connect_timeout: Duration, + pub max_connections: usize, +} + +impl Default for Options { + fn default() -> Self { + Self { + idle_timeout: Duration::from_secs(5), + connect_timeout: Duration::from_secs(1), + max_connections: 1024, + } + } +} + +/// A reference to a connection that is owned by a connection pool. +#[derive(Debug)] +pub struct ConnectionRef { + connection: iroh::endpoint::Connection, + _permit: OneConnection, +} + +impl Deref for ConnectionRef { + type Target = iroh::endpoint::Connection; + + fn deref(&self) -> &Self::Target { + &self.connection + } +} + +impl ConnectionRef { + fn new(connection: iroh::endpoint::Connection, counter: OneConnection) -> Self { + Self { + connection, + _permit: counter, + } + } +} + +/// Error when a connection can not be acquired +/// +/// This includes the normal iroh connection errors as well as pool specific +/// errors such as timeouts and connection limits. +#[derive(Debug, Clone, Snafu)] +#[snafu(module)] +pub enum PoolConnectError { + /// Connection pool is shut down + Shutdown, + /// Timeout during connect + Timeout, + /// Too many connections + TooManyConnections, + /// Error during connect + ConnectError { source: Arc }, +} + +impl From for PoolConnectError { + fn from(e: ConnectError) -> Self { + PoolConnectError::ConnectError { + source: Arc::new(e), + } + } +} + +/// Error when calling a fn on the [`ConnectionPool`]. +/// +/// The only thing that can go wrong is that the connection pool is shut down. +#[derive(Debug, Snafu)] +#[snafu(module)] +pub enum ConnectionPoolError { + /// The connection pool has been shut down + Shutdown, +} + +enum ActorMessage { + RequestRef(RequestRef), + ConnectionIdle { id: NodeId }, + ConnectionShutdown { id: NodeId }, +} + +struct RequestRef { + id: NodeId, + tx: oneshot::Sender>, +} + +struct Context { + options: Options, + endpoint: Endpoint, + owner: ConnectionPool, + alpn: Vec, +} + +impl Context { + async fn run_connection_actor( + self: Arc, + node_id: NodeId, + mut rx: mpsc::Receiver, + ) { + let context = self; + + // Connect to the node + let state = context + .endpoint + .connect(node_id, &context.alpn) + .timeout(context.options.connect_timeout) + .await + .map_err(|_| PoolConnectError::Timeout) + .and_then(|r| r.map_err(PoolConnectError::from)); + if let Err(e) = &state { + debug!(%node_id, "Failed to connect {e:?}, requesting shutdown"); + if context.owner.close(node_id).await.is_err() { + return; + } + } + let counter = ConnectionCounter::new(); + let idle_timer = MaybeFuture::default(); + let idle_stream = counter.clone().idle_stream(); + + tokio::pin!(idle_timer, idle_stream); + + loop { + tokio::select! { + biased; + + // Handle new work + handler = rx.recv() => { + match handler { + Some(RequestRef { id, tx }) => { + assert!(id == node_id, "Not for me!"); + match &state { + Ok(state) => { + let res = ConnectionRef::new(state.clone(), counter.get_one()); + + // clear the idle timer + idle_timer.as_mut().set_none(); + tx.send(Ok(res)).ok(); + } + Err(cause) => { + tx.send(Err(cause.clone())).ok(); + } + } + } + None => { + // Channel closed - finish remaining tasks and exit + break; + } + } + } + + _ = idle_stream.next() => { + if !counter.is_idle() { + continue; + }; + // notify the pool that we are idle. + trace!(%node_id, "Idle"); + if context.owner.idle(node_id).await.is_err() { + // If we can't notify the pool, we are shutting down + break; + } + // set the idle timer + idle_timer.as_mut().set_future(tokio::time::sleep(context.options.idle_timeout)); + } + + // Idle timeout - request shutdown + _ = &mut idle_timer => { + trace!(%node_id, "Idle timer expired, requesting shutdown"); + context.owner.close(node_id).await.ok(); + // Don't break here - wait for main actor to close our channel + } + } + } + + if let Ok(connection) = state { + let reason = if counter.is_idle() { b"idle" } else { b"drop" }; + connection.close(0u32.into(), reason); + } + + trace!(%node_id, "Connection actor shutting down"); + } +} + +struct Actor { + rx: mpsc::Receiver, + connections: HashMap>, + context: Arc, + // idle set (most recent last) + // todo: use a better data structure if this becomes a performance issue + idle: VecDeque, + // per connection tasks + tasks: FuturesUnordered>, +} + +impl Actor { + pub fn new( + endpoint: Endpoint, + alpn: &[u8], + options: Options, + ) -> (Self, mpsc::Sender) { + let (tx, rx) = mpsc::channel(100); + ( + Self { + rx, + connections: HashMap::new(), + idle: VecDeque::new(), + context: Arc::new(Context { + options, + alpn: alpn.to_vec(), + endpoint, + owner: ConnectionPool { tx: tx.clone() }, + }), + tasks: FuturesUnordered::new(), + }, + tx, + ) + } + + fn add_idle(&mut self, id: NodeId) { + self.remove_idle(id); + self.idle.push_back(id); + } + + fn remove_idle(&mut self, id: NodeId) { + self.idle.retain(|&x| x != id); + } + + fn pop_oldest_idle(&mut self) -> Option { + self.idle.pop_front() + } + + fn remove_connection(&mut self, id: NodeId) { + self.connections.remove(&id); + self.remove_idle(id); + } + + async fn handle_msg(&mut self, msg: ActorMessage) { + match msg { + ActorMessage::RequestRef(mut msg) => { + let id = msg.id; + self.remove_idle(id); + // Try to send to existing connection actor + if let Some(conn_tx) = self.connections.get(&id) { + if let Err(TokioSendError(e)) = conn_tx.send(msg).await { + msg = e; + } else { + return; + } + // Connection actor died, remove it + self.remove_connection(id); + } + + // No connection actor or it died - check limits + if self.connections.len() >= self.context.options.max_connections { + if let Some(idle) = self.pop_oldest_idle() { + // remove the oldest idle connection to make room for one more + trace!("removing oldest idle connection {}", idle); + self.connections.remove(&idle); + } else { + msg.tx.send(Err(PoolConnectError::TooManyConnections)).ok(); + return; + } + } + let (conn_tx, conn_rx) = mpsc::channel(100); + self.connections.insert(id, conn_tx.clone()); + + let context = self.context.clone(); + + self.tasks + .push(Box::pin(context.run_connection_actor(id, conn_rx))); + + // Send the handler to the new actor + if conn_tx.send(msg).await.is_err() { + error!(%id, "Failed to send handler to new connection actor"); + self.connections.remove(&id); + } + } + ActorMessage::ConnectionIdle { id } => { + self.add_idle(id); + trace!(%id, "connection idle"); + } + ActorMessage::ConnectionShutdown { id } => { + // Remove the connection from our map - this closes the channel + self.remove_connection(id); + trace!(%id, "removed connection"); + } + } + } + + pub async fn run(mut self) { + loop { + tokio::select! { + biased; + + msg = self.rx.recv() => { + if let Some(msg) = msg { + self.handle_msg(msg).await; + } else { + break; + } + } + + _ = self.tasks.next(), if !self.tasks.is_empty() => {} + } + } + } +} + +/// A connection pool +#[derive(Debug, Clone)] +pub struct ConnectionPool { + tx: mpsc::Sender, +} + +impl ConnectionPool { + pub fn new(endpoint: Endpoint, alpn: &[u8], options: Options) -> Self { + let (actor, tx) = Actor::new(endpoint, alpn, options); + + // Spawn the main actor + tokio::spawn(actor.run()); + + Self { tx } + } + + /// Returns either a fresh connection or a reference to an existing one. + /// + /// This is guaranteed to return after approximately [Options::connect_timeout] + /// with either an error or a connection. + pub async fn get_or_connect( + &self, + id: NodeId, + ) -> std::result::Result { + let (tx, rx) = oneshot::channel(); + self.tx + .send(ActorMessage::RequestRef(RequestRef { id, tx })) + .await + .map_err(|_| PoolConnectError::Shutdown)?; + rx.await.map_err(|_| PoolConnectError::Shutdown)? + } + + /// Close an existing connection, if it exists + /// + /// This will finish pending tasks and close the connection. New tasks will + /// get a new connection if they are submitted after this call + pub async fn close(&self, id: NodeId) -> std::result::Result<(), ConnectionPoolError> { + self.tx + .send(ActorMessage::ConnectionShutdown { id }) + .await + .map_err(|_| ConnectionPoolError::Shutdown)?; + Ok(()) + } + + /// Notify the connection pool that a connection is idle. + /// + /// Should only be called from connection handlers. + pub(crate) async fn idle(&self, id: NodeId) -> std::result::Result<(), ConnectionPoolError> { + self.tx + .send(ActorMessage::ConnectionIdle { id }) + .await + .map_err(|_| ConnectionPoolError::Shutdown)?; + Ok(()) + } +} + +#[derive(Debug)] +struct ConnectionCounterInner { + count: AtomicUsize, + notify: Notify, +} + +#[derive(Debug, Clone)] +struct ConnectionCounter { + inner: Arc, +} + +impl ConnectionCounter { + fn new() -> Self { + Self { + inner: Arc::new(ConnectionCounterInner { + count: Default::default(), + notify: Notify::new(), + }), + } + } + + /// Increase the connection count and return a guard for the new connection + fn get_one(&self) -> OneConnection { + self.inner.count.fetch_add(1, Ordering::SeqCst); + OneConnection { + inner: self.inner.clone(), + } + } + + fn is_idle(&self) -> bool { + self.inner.count.load(Ordering::SeqCst) == 0 + } + + /// Infinite stream that yields when the connection is briefly idle. + /// + /// Note that you still have to check if the connection is still idle when + /// you get the notification. + /// + /// Also note that this stream is triggered on [OneConnection::drop], so it + /// won't trigger initially even though a [ConnectionCounter] starts up as + /// idle. + fn idle_stream(self) -> impl Stream { + n0_future::stream::unfold(self, |c| async move { + c.inner.notify.notified().await; + Some(((), c)) + }) + } +} + +/// Guard for one connection +#[derive(Debug)] +struct OneConnection { + inner: Arc, +} + +impl Drop for OneConnection { + fn drop(&mut self) { + if self.inner.count.fetch_sub(1, Ordering::SeqCst) == 1 { + self.inner.notify.notify_waiters(); + } + } +}