Skip to content

Commit f5b5ab4

Browse files
authored
refactor: Replace connection pool (#138)
## Description Replaces the somewhat hackish connection pool with the one from n0-computer/iroh-experiments#36 that was battle tested more. ## Breaking Changes None ## Notes & open questions Q: Expose the conn pool here? Note: There is a nice list of possible extensions, but I think it is probably best to first get the basic version in. Extensions would be `async fn ban(node_id: NodeId, duration: Option<Duration>)`, bans the node for a time, or for as long as the conn pool lives if duration is set to None. A way to observe pool stats so users know when to schedule new downloads without having to try. ## Change checklist - [ ] Self-review. - [ ] Documentation updates following the [style guide](https://rust-lang.github.io/rfcs/1574-more-api-documentation-conventions.html#appendix-a-full-conventions-text), if relevant. - [ ] Tests if relevant. - [ ] All breaking changes documented.
1 parent 7925931 commit f5b5ab4

File tree

14 files changed

+597
-271
lines changed

14 files changed

+597
-271
lines changed

Cargo.lock

Lines changed: 27 additions & 6 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ bytes = { version = "1", features = ["serde"] }
1818
derive_more = { version = "2.0.1", features = ["from", "try_from", "into", "debug", "display", "deref", "deref_mut"] }
1919
futures-lite = "2.6.0"
2020
quinn = { package = "iroh-quinn", version = "0.14.0" }
21-
n0-future = "0.1.2"
21+
n0-future = "0.2.0"
2222
n0-snafu = "0.2.0"
2323
range-collections = { version = "0.4.6", features = ["serde"] }
2424
redb = { version = "=2.4" }

src/api/blobs/reader.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,14 +221,14 @@ mod tests {
221221

222222
use super::*;
223223
use crate::{
224+
protocol::ChunkRangesExt,
224225
store::{
225226
fs::{
226227
tests::{create_n0_bao, test_data, INTERESTING_SIZES},
227228
FsStore,
228229
},
229230
mem::MemStore,
230231
},
231-
util::ChunkRangesExt,
232232
};
233233

234234
async fn reader_smoke(blobs: &Blobs) -> TestResult<()> {

src/api/downloader.rs

Lines changed: 12 additions & 167 deletions
Original file line numberDiff line numberDiff line change
@@ -4,26 +4,26 @@ use std::{
44
fmt::Debug,
55
future::{Future, IntoFuture},
66
io,
7-
ops::Deref,
87
sync::Arc,
9-
time::{Duration, SystemTime},
108
};
119

1210
use anyhow::bail;
1311
use genawaiter::sync::Gen;
14-
use iroh::{endpoint::Connection, Endpoint, NodeId};
12+
use iroh::{Endpoint, NodeId};
1513
use irpc::{channel::mpsc, rpc_requests};
1614
use n0_future::{future, stream, BufferedStreamExt, Stream, StreamExt};
1715
use rand::seq::SliceRandom;
1816
use serde::{de::Error, Deserialize, Serialize};
19-
use tokio::{sync::Mutex, task::JoinSet};
20-
use tokio_util::time::FutureExt;
21-
use tracing::{info, instrument::Instrument, warn};
17+
use tokio::task::JoinSet;
18+
use tracing::instrument::Instrument;
2219

23-
use super::{remote::GetConnection, Store};
20+
use super::Store;
2421
use crate::{
2522
protocol::{GetManyRequest, GetRequest},
26-
util::sink::{Drain, IrpcSenderRefSink, Sink, TokioMpscSenderSink},
23+
util::{
24+
connection_pool::ConnectionPool,
25+
sink::{Drain, IrpcSenderRefSink, Sink, TokioMpscSenderSink},
26+
},
2727
BlobFormat, Hash, HashAndFormat,
2828
};
2929

@@ -69,7 +69,7 @@ impl DownloaderActor {
6969
fn new(store: Store, endpoint: Endpoint) -> Self {
7070
Self {
7171
store,
72-
pool: ConnectionPool::new(endpoint, crate::ALPN.to_vec()),
72+
pool: ConnectionPool::new(endpoint, crate::ALPN, Default::default()),
7373
tasks: JoinSet::new(),
7474
running: HashSet::new(),
7575
}
@@ -414,90 +414,6 @@ async fn split_request<'a>(
414414
})
415415
}
416416

417-
#[derive(Debug)]
418-
struct ConnectionPoolInner {
419-
alpn: Vec<u8>,
420-
endpoint: Endpoint,
421-
connections: Mutex<HashMap<NodeId, Arc<Mutex<SlotState>>>>,
422-
retry_delay: Duration,
423-
connect_timeout: Duration,
424-
}
425-
426-
#[derive(Debug, Clone)]
427-
struct ConnectionPool(Arc<ConnectionPoolInner>);
428-
429-
#[derive(Debug, Default)]
430-
enum SlotState {
431-
#[default]
432-
Initial,
433-
Connected(Connection),
434-
AttemptFailed(SystemTime),
435-
#[allow(dead_code)]
436-
Evil(String),
437-
}
438-
439-
impl ConnectionPool {
440-
fn new(endpoint: Endpoint, alpn: Vec<u8>) -> Self {
441-
Self(
442-
ConnectionPoolInner {
443-
endpoint,
444-
alpn,
445-
connections: Default::default(),
446-
retry_delay: Duration::from_secs(5),
447-
connect_timeout: Duration::from_secs(2),
448-
}
449-
.into(),
450-
)
451-
}
452-
453-
pub fn alpn(&self) -> &[u8] {
454-
&self.0.alpn
455-
}
456-
457-
pub fn endpoint(&self) -> &Endpoint {
458-
&self.0.endpoint
459-
}
460-
461-
pub fn retry_delay(&self) -> Duration {
462-
self.0.retry_delay
463-
}
464-
465-
fn dial(&self, id: NodeId) -> DialNode {
466-
DialNode {
467-
pool: self.clone(),
468-
id,
469-
}
470-
}
471-
472-
#[allow(dead_code)]
473-
async fn mark_evil(&self, id: NodeId, reason: String) {
474-
let slot = self
475-
.0
476-
.connections
477-
.lock()
478-
.await
479-
.entry(id)
480-
.or_default()
481-
.clone();
482-
let mut t = slot.lock().await;
483-
*t = SlotState::Evil(reason)
484-
}
485-
486-
#[allow(dead_code)]
487-
async fn mark_closed(&self, id: NodeId) {
488-
let slot = self
489-
.0
490-
.connections
491-
.lock()
492-
.await
493-
.entry(id)
494-
.or_default()
495-
.clone();
496-
let mut t = slot.lock().await;
497-
*t = SlotState::Initial
498-
}
499-
}
500-
501417
/// Execute a get request sequentially for multiple providers.
502418
///
503419
/// It will try each provider in order
@@ -526,13 +442,13 @@ async fn execute_get(
526442
request: request.clone(),
527443
})
528444
.await?;
529-
let mut conn = pool.dial(provider);
445+
let conn = pool.get_or_connect(provider);
530446
let local = remote.local_for_request(request.clone()).await?;
531447
if local.is_complete() {
532448
return Ok(());
533449
}
534450
let local_bytes = local.local_bytes();
535-
let Ok(conn) = conn.connection().await else {
451+
let Ok(conn) = conn.await else {
536452
progress
537453
.send(DownloadProgessItem::ProviderFailed {
538454
id: provider,
@@ -543,7 +459,7 @@ async fn execute_get(
543459
};
544460
match remote
545461
.execute_get_sink(
546-
conn,
462+
&conn,
547463
local.missing(),
548464
(&mut progress).with_map(move |x| DownloadProgessItem::Progress(x + local_bytes)),
549465
)
@@ -571,77 +487,6 @@ async fn execute_get(
571487
bail!("Unable to download {}", request.hash);
572488
}
573489

574-
#[derive(Debug, Clone)]
575-
struct DialNode {
576-
pool: ConnectionPool,
577-
id: NodeId,
578-
}
579-
580-
impl DialNode {
581-
async fn connection_impl(&self) -> anyhow::Result<Connection> {
582-
info!("Getting connection for node {}", self.id);
583-
let slot = self
584-
.pool
585-
.0
586-
.connections
587-
.lock()
588-
.await
589-
.entry(self.id)
590-
.or_default()
591-
.clone();
592-
info!("Dialing node {}", self.id);
593-
let mut guard = slot.lock().await;
594-
match guard.deref() {
595-
SlotState::Connected(conn) => {
596-
return Ok(conn.clone());
597-
}
598-
SlotState::AttemptFailed(time) => {
599-
let elapsed = time.elapsed().unwrap_or_default();
600-
if elapsed <= self.pool.retry_delay() {
601-
bail!(
602-
"Connection attempt failed {} seconds ago",
603-
elapsed.as_secs_f64()
604-
);
605-
}
606-
}
607-
SlotState::Evil(reason) => {
608-
bail!("Node is banned due to evil behavior: {reason}");
609-
}
610-
SlotState::Initial => {}
611-
}
612-
let res = self
613-
.pool
614-
.endpoint()
615-
.connect(self.id, self.pool.alpn())
616-
.timeout(self.pool.0.connect_timeout)
617-
.await;
618-
match res {
619-
Ok(Ok(conn)) => {
620-
info!("Connected to node {}", self.id);
621-
*guard = SlotState::Connected(conn.clone());
622-
Ok(conn)
623-
}
624-
Ok(Err(e)) => {
625-
warn!("Failed to connect to node {}: {}", self.id, e);
626-
*guard = SlotState::AttemptFailed(SystemTime::now());
627-
Err(e.into())
628-
}
629-
Err(e) => {
630-
warn!("Failed to connect to node {}: {}", self.id, e);
631-
*guard = SlotState::AttemptFailed(SystemTime::now());
632-
bail!("Failed to connect to node: {}", e);
633-
}
634-
}
635-
}
636-
}
637-
638-
impl GetConnection for DialNode {
639-
fn connection(&mut self) -> impl Future<Output = Result<Connection, anyhow::Error>> + '_ {
640-
let this = self.clone();
641-
async move { this.connection_impl().await }
642-
}
643-
}
644-
645490
/// Trait for pluggable content discovery strategies.
646491
pub trait ContentDiscovery: Debug + Send + Sync + 'static {
647492
fn find_providers(&self, hash: HashAndFormat) -> n0_future::stream::Boxed<NodeId>;

src/api/remote.rs

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -518,7 +518,7 @@ impl Remote {
518518
.connection()
519519
.await
520520
.map_err(|e| LocalFailureSnafu.into_error(e.into()))?;
521-
let stats = self.execute_get_sink(conn, request, progress).await?;
521+
let stats = self.execute_get_sink(&conn, request, progress).await?;
522522
Ok(stats)
523523
}
524524

@@ -637,7 +637,7 @@ impl Remote {
637637
.with_map_err(io::Error::other);
638638
let this = self.clone();
639639
let fut = async move {
640-
let res = this.execute_get_sink(conn, request, sink).await.into();
640+
let res = this.execute_get_sink(&conn, request, sink).await.into();
641641
tx2.send(res).await.ok();
642642
};
643643
GetProgress {
@@ -656,13 +656,15 @@ impl Remote {
656656
/// This will return the stats of the download.
657657
pub(crate) async fn execute_get_sink(
658658
&self,
659-
conn: Connection,
659+
conn: &Connection,
660660
request: GetRequest,
661661
mut progress: impl Sink<u64, Error = io::Error>,
662662
) -> GetResult<Stats> {
663663
let store = self.store();
664664
let root = request.hash;
665-
let start = crate::get::fsm::start(conn, request, Default::default());
665+
// I am cloning the connection, but it's fine because the original connection or ConnectionRef stays alive
666+
// for the duration of the operation.
667+
let start = crate::get::fsm::start(conn.clone(), request, Default::default());
666668
let connected = start.next().await?;
667669
trace!("Getting header");
668670
// read the header
@@ -1065,7 +1067,7 @@ mod tests {
10651067

10661068
use crate::{
10671069
api::blobs::Blobs,
1068-
protocol::{ChunkRangesSeq, GetRequest},
1070+
protocol::{ChunkRangesExt, ChunkRangesSeq, GetRequest},
10691071
store::{
10701072
fs::{
10711073
tests::{create_n0_bao, test_data, INTERESTING_SIZES},
@@ -1074,7 +1076,6 @@ mod tests {
10741076
mem::MemStore,
10751077
},
10761078
tests::{add_test_hash_seq, add_test_hash_seq_incomplete},
1077-
util::ChunkRangesExt,
10781079
};
10791080

10801081
#[tokio::test]

0 commit comments

Comments
 (0)