diff --git a/Cargo.toml b/Cargo.toml index b98b56fa..fa25f55c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,6 +19,10 @@ hyper = "0.14.19" futures-channel = "0.3" futures-util = { version = "0.3", default-features = false } http = "0.2" + +# Necessary to overcome msrv check of rust 1.49, as 1.15.0 failed +once_cell = "=1.14" + pin-project-lite = "0.2.4" socket2 = "0.4" tracing = { version = "0.1", default-features = false, features = ["std"] } diff --git a/src/client/connect/http.rs b/src/client/connect/http.rs index 97a0b340..9d5f3ee8 100644 --- a/src/client/connect/http.rs +++ b/src/client/connect/http.rs @@ -12,6 +12,7 @@ use std::time::Duration; use futures_util::future::Either; use http::uri::{Scheme, Uri}; use pin_project_lite::pin_project; +use socket2::TcpKeepalive; use tokio::net::{TcpSocket, TcpStream}; use tokio::time::Sleep; use tracing::{debug, trace, warn}; @@ -67,7 +68,7 @@ struct Config { connect_timeout: Option, enforce_http: bool, happy_eyeballs_timeout: Option, - keep_alive_timeout: Option, + tcp_keepalive_config: TcpKeepaliveConfig, local_address_ipv4: Option, local_address_ipv6: Option, nodelay: bool, @@ -76,6 +77,68 @@ struct Config { recv_buffer_size: Option, } +#[derive(Default, Debug, Clone, Copy)] +struct TcpKeepaliveConfig { + time: Option, + interval: Option, + retries: Option, +} + +impl TcpKeepaliveConfig { + /// Converts into a `socket2::TcpKeealive` if there is any keep alive configuration. + fn into_tcpkeepalive(self) -> Option { + let mut dirty = false; + let mut ka = TcpKeepalive::new(); + if let Some(time) = self.time { + ka = ka.with_time(time); + dirty = true + } + if let Some(interval) = self.interval { + ka = Self::ka_with_interval(ka, interval, &mut dirty) + }; + if let Some(retries) = self.retries { + ka = Self::ka_with_retries(ka, retries, &mut dirty) + }; + if dirty { + Some(ka) + } else { + None + } + } + + #[cfg(not(any(target_os = "openbsd", target_os = "redox", target_os = "solaris")))] + fn ka_with_interval(ka: TcpKeepalive, interval: Duration, dirty: &mut bool) -> TcpKeepalive { + *dirty = true; + ka.with_interval(interval) + } + + #[cfg(any(target_os = "openbsd", target_os = "redox", target_os = "solaris"))] + fn ka_with_interval(ka: TcpKeepalive, _: Duration, _: &mut bool) -> TcpKeepalive { + ka // no-op as keepalive interval is not supported on this platform + } + + #[cfg(not(any( + target_os = "openbsd", + target_os = "redox", + target_os = "solaris", + target_os = "windows" + )))] + fn ka_with_retries(ka: TcpKeepalive, retries: u32, dirty: &mut bool) -> TcpKeepalive { + *dirty = true; + ka.with_retries(retries) + } + + #[cfg(any( + target_os = "openbsd", + target_os = "redox", + target_os = "solaris", + target_os = "windows" + ))] + fn ka_with_retries(ka: TcpKeepalive, _: u32, _: &mut bool) -> TcpKeepalive { + ka // no-op as keepalive retries is not supported on this platform + } +} + // ===== impl HttpConnector ===== impl HttpConnector { @@ -95,7 +158,7 @@ impl HttpConnector { connect_timeout: None, enforce_http: true, happy_eyeballs_timeout: Some(Duration::from_millis(300)), - keep_alive_timeout: None, + tcp_keepalive_config: TcpKeepaliveConfig::default(), local_address_ipv4: None, local_address_ipv6: None, nodelay: false, @@ -115,14 +178,28 @@ impl HttpConnector { self.config_mut().enforce_http = is_enforced; } - /// Set that all sockets have `SO_KEEPALIVE` set with the supplied duration. + /// Set that all sockets have `SO_KEEPALIVE` set with the supplied duration + /// to remain idle before sending TCP keepalive probes. /// - /// If `None`, the option will not be set. + /// If `None`, keepalive is disabled. /// /// Default is `None`. #[inline] - pub fn set_keepalive(&mut self, dur: Option) { - self.config_mut().keep_alive_timeout = dur; + pub fn set_keepalive(&mut self, time: Option) { + self.config_mut().tcp_keepalive_config.time = time; + } + + /// Set the duration between two successive TCP keepalive retransmissions, + /// if acknowledgement to the previous keepalive transmission is not received. + #[inline] + pub fn set_keepalive_interval(&mut self, interval: Option) { + self.config_mut().tcp_keepalive_config.interval = interval; + } + + /// Set the number of retransmissions to be carried out before declaring that remote end is not available. + #[inline] + pub fn set_keepalive_retries(&mut self, retries: Option) { + self.config_mut().tcp_keepalive_config.retries = retries; } /// Set that all sockets have `SO_NODELAY` set to the supplied value `nodelay`. @@ -577,7 +654,7 @@ fn connect( // TODO(eliza): if Tokio's `TcpSocket` gains support for setting the // keepalive timeout, it would be nice to use that instead of socket2, // and avoid the unsafe `into_raw_fd`/`from_raw_fd` dance... - use socket2::{Domain, Protocol, Socket, TcpKeepalive, Type}; + use socket2::{Domain, Protocol, Socket, Type}; use std::convert::TryInto; let domain = Domain::for_address(*addr); @@ -590,9 +667,8 @@ fn connect( .set_nonblocking(true) .map_err(ConnectError::m("tcp set_nonblocking error"))?; - if let Some(dur) = config.keep_alive_timeout { - let conf = TcpKeepalive::new().with_time(dur); - if let Err(e) = socket.set_tcp_keepalive(&conf) { + if let Some(tcp_keepalive) = &config.tcp_keepalive_config.into_tcpkeepalive() { + if let Err(e) = socket.set_tcp_keepalive(tcp_keepalive) { warn!("tcp set_keepalive error: {}", e); } } @@ -701,6 +777,8 @@ mod tests { use ::http::Uri; + use crate::client::connect::http::TcpKeepaliveConfig; + use super::super::sealed::{Connect, ConnectSvc}; use super::{Config, ConnectError, HttpConnector}; @@ -920,7 +998,7 @@ mod tests { local_address_ipv4: None, local_address_ipv6: None, connect_timeout: None, - keep_alive_timeout: None, + tcp_keepalive_config: TcpKeepaliveConfig::default(), happy_eyeballs_timeout: Some(fallback_timeout), nodelay: false, reuse_address: false, @@ -989,4 +1067,51 @@ mod tests { (reachable, duration) } } + + use std::time::Duration; + + #[test] + fn no_tcp_keepalive_config() { + assert!(TcpKeepaliveConfig::default().into_tcpkeepalive().is_none()); + } + + #[test] + fn tcp_keepalive_time_config() { + let mut kac = TcpKeepaliveConfig::default(); + kac.time = Some(Duration::from_secs(60)); + if let Some(tcp_keepalive) = kac.into_tcpkeepalive() { + assert!(format!("{tcp_keepalive:?}").contains("time: Some(60s)")); + } else { + panic!("test failed"); + } + } + + #[cfg(not(any(target_os = "openbsd", target_os = "redox", target_os = "solaris")))] + #[test] + fn tcp_keepalive_interval_config() { + let mut kac = TcpKeepaliveConfig::default(); + kac.interval = Some(Duration::from_secs(1)); + if let Some(tcp_keepalive) = kac.into_tcpkeepalive() { + assert!(format!("{tcp_keepalive:?}").contains("interval: Some(1s)")); + } else { + panic!("test failed"); + } + } + + #[cfg(not(any( + target_os = "openbsd", + target_os = "redox", + target_os = "solaris", + target_os = "windows" + )))] + #[test] + fn tcp_keepalive_retries_config() { + let mut kac = TcpKeepaliveConfig::default(); + kac.retries = Some(3); + if let Some(tcp_keepalive) = kac.into_tcpkeepalive() { + assert!(format!("{tcp_keepalive:?}").contains("retries: Some(3)")); + } else { + panic!("test failed"); + } + } }