diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 8044b2f47..f0ae551dc 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -20,7 +20,7 @@ jobs: - uses: actions/checkout@v3 - uses: sfackler/actions/rustup@master - uses: sfackler/actions/rustfmt@master - + clippy: name: clippy runs-on: ubuntu-latest @@ -55,7 +55,7 @@ jobs: - run: docker compose up -d - uses: sfackler/actions/rustup@master with: - version: 1.64.0 + version: 1.77.0 - run: echo "version=$(rustc --version)" >> $GITHUB_OUTPUT id: rust-version - uses: actions/cache@v3 diff --git a/postgres-protocol/src/types/test.rs b/postgres-protocol/src/types/test.rs index 6f1851fc2..3e33b08f0 100644 --- a/postgres-protocol/src/types/test.rs +++ b/postgres-protocol/src/types/test.rs @@ -174,7 +174,7 @@ fn ltree_str() { let mut query = vec![1u8]; query.extend_from_slice("A.B.C".as_bytes()); - assert!(matches!(ltree_from_sql(query.as_slice()), Ok(_))) + assert!(ltree_from_sql(query.as_slice()).is_ok()) } #[test] @@ -182,7 +182,7 @@ fn ltree_wrong_version() { let mut query = vec![2u8]; query.extend_from_slice("A.B.C".as_bytes()); - assert!(matches!(ltree_from_sql(query.as_slice()), Err(_))) + assert!(ltree_from_sql(query.as_slice()).is_err()) } #[test] @@ -202,7 +202,7 @@ fn lquery_str() { let mut query = vec![1u8]; query.extend_from_slice("A.B.C".as_bytes()); - assert!(matches!(lquery_from_sql(query.as_slice()), Ok(_))) + assert!(lquery_from_sql(query.as_slice()).is_ok()) } #[test] @@ -210,7 +210,7 @@ fn lquery_wrong_version() { let mut query = vec![2u8]; query.extend_from_slice("A.B.C".as_bytes()); - assert!(matches!(lquery_from_sql(query.as_slice()), Err(_))) + assert!(lquery_from_sql(query.as_slice()).is_err()) } #[test] @@ -230,7 +230,7 @@ fn ltxtquery_str() { let mut query = vec![1u8]; query.extend_from_slice("a & b*".as_bytes()); - assert!(matches!(ltree_from_sql(query.as_slice()), Ok(_))) + assert!(ltree_from_sql(query.as_slice()).is_ok()) } #[test] @@ -238,5 +238,5 @@ fn ltxtquery_wrong_version() { let mut query = vec![2u8]; query.extend_from_slice("a & b*".as_bytes()); - assert!(matches!(ltree_from_sql(query.as_slice()), Err(_))) + assert!(ltree_from_sql(query.as_slice()).is_err()) } diff --git a/postgres-types/src/chrono_04.rs b/postgres-types/src/chrono_04.rs index 0ec92437d..b7f4f9a03 100644 --- a/postgres-types/src/chrono_04.rs +++ b/postgres-types/src/chrono_04.rs @@ -40,7 +40,7 @@ impl ToSql for NaiveDateTime { impl<'a> FromSql<'a> for DateTime { fn from_sql(type_: &Type, raw: &[u8]) -> Result, Box> { let naive = NaiveDateTime::from_sql(type_, raw)?; - Ok(DateTime::from_utc(naive, Utc)) + Ok(DateTime::from_naive_utc_and_offset(naive, Utc)) } accepts!(TIMESTAMPTZ); @@ -111,7 +111,7 @@ impl<'a> FromSql<'a> for NaiveDate { let jd = types::date_from_sql(raw)?; base() .date() - .checked_add_signed(Duration::days(i64::from(jd))) + .checked_add_signed(Duration::try_days(i64::from(jd)).unwrap()) .ok_or_else(|| "value too large to decode".into()) } diff --git a/postgres/src/config.rs b/postgres/src/config.rs index 5914363c9..2705e3593 100644 --- a/postgres/src/config.rs +++ b/postgres/src/config.rs @@ -6,6 +6,7 @@ use crate::connection::Connection; use crate::Client; use log::info; use std::fmt; +use std::net::IpAddr; use std::path::Path; use std::str::FromStr; use std::sync::Arc; @@ -43,6 +44,19 @@ use tokio_postgres::{Error, Socket}; /// path to the directory containing Unix domain sockets. Otherwise, it is treated as a hostname. Multiple hosts /// can be specified, separated by commas. Each host will be tried in turn when connecting. Required if connecting /// with the `connect` method. +/// * `hostaddr` - Numeric IP address of host to connect to. This should be in the standard IPv4 address format, +/// e.g., 172.28.40.9. If your machine supports IPv6, you can also use those addresses. +/// If this parameter is not specified, the value of `host` will be looked up to find the corresponding IP address, +/// - or if host specifies an IP address, that value will be used directly. +/// Using `hostaddr` allows the application to avoid a host name look-up, which might be important in applications +/// with time constraints. However, a host name is required for verify-full SSL certificate verification. +/// Specifically: +/// * If `hostaddr` is specified without `host`, the value for `hostaddr` gives the server network address. +/// The connection attempt will fail if the authentication method requires a host name; +/// * If `host` is specified without `hostaddr`, a host name lookup occurs; +/// * If both `host` and `hostaddr` are specified, the value for `hostaddr` gives the server network address. +/// The value for `host` is ignored unless the authentication method requires it, +/// in which case it will be used as the host name. /// * `port` - The port to connect to. Multiple ports can be specified, separated by commas. The number of ports must be /// either 1, in which case it will be used for all hosts, or the same as the number of hosts. Defaults to 5432 if /// omitted or the empty string. @@ -74,6 +88,10 @@ use tokio_postgres::{Error, Socket}; /// ``` /// /// ```not_rust +/// host=host1,host2,host3 port=1234,,5678 hostaddr=127.0.0.1,127.0.0.2,127.0.0.3 user=postgres target_session_attrs=read-write +/// ``` +/// +/// ```not_rust /// host=host1,host2,host3 port=1234,,5678 user=postgres target_session_attrs=read-write /// ``` /// @@ -250,6 +268,7 @@ impl Config { /// /// Multiple hosts can be specified by calling this method multiple times, and each will be tried in order. On Unix /// systems, a host starting with a `/` is interpreted as a path to a directory containing Unix domain sockets. + /// There must be either no hosts, or the same number of hosts as hostaddrs. pub fn host(&mut self, host: &str) -> &mut Config { self.config.host(host); self @@ -260,6 +279,11 @@ impl Config { self.config.get_hosts() } + /// Gets the hostaddrs that have been added to the configuration with `hostaddr`. + pub fn get_hostaddrs(&self) -> &[IpAddr] { + self.config.get_hostaddrs() + } + /// Adds a Unix socket host to the configuration. /// /// Unlike `host`, this method allows non-UTF8 paths. @@ -272,6 +296,15 @@ impl Config { self } + /// Adds a hostaddr to the configuration. + /// + /// Multiple hostaddrs can be specified by calling this method multiple times, and each will be tried in order. + /// There must be either no hostaddrs, or the same number of hostaddrs as hosts. + pub fn hostaddr(&mut self, hostaddr: IpAddr) -> &mut Config { + self.config.hostaddr(hostaddr); + self + } + /// Adds a port to the configuration. /// /// Multiple ports can be specified by calling this method multiple times. There must either be no ports, in which diff --git a/tokio-postgres/Cargo.toml b/tokio-postgres/Cargo.toml index ed881d2a5..762caa9b0 100644 --- a/tokio-postgres/Cargo.toml +++ b/tokio-postgres/Cargo.toml @@ -59,6 +59,7 @@ serde = { version = "1.0", optional = true } socket2 = { version = "0.5", features = ["all"] } tokio = { version = "1.27", features = ["io-util"] } tokio-util = { version = "0.7", features = ["codec"] } +rand = "0.8.5" [dev-dependencies] futures-executor = "0.3" @@ -78,7 +79,6 @@ eui48-04 = { version = "0.4", package = "eui48" } eui48-1 = { version = "1.0", package = "eui48" } geo-types-06 = { version = "0.6", package = "geo-types" } geo-types-07 = { version = "0.7", package = "geo-types" } -serde-1 = { version = "1.0", package = "serde" } serde_json-1 = { version = "1.0", package = "serde_json" } smol_str-01 = { version = "0.1", package = "smol_str" } uuid-08 = { version = "0.8", package = "uuid" } diff --git a/tokio-postgres/src/cancel_query.rs b/tokio-postgres/src/cancel_query.rs index d869b5824..078d4b8b6 100644 --- a/tokio-postgres/src/cancel_query.rs +++ b/tokio-postgres/src/cancel_query.rs @@ -1,5 +1,5 @@ use crate::client::SocketConfig; -use crate::config::{Host, SslMode}; +use crate::config::SslMode; use crate::tls::MakeTlsConnect; use crate::{cancel_query_raw, connect_socket, Error, Socket}; use std::io; @@ -24,18 +24,13 @@ where } }; - let hostname = match &config.host { - Host::Tcp(host) => &**host, - // postgres doesn't support TLS over unix sockets, so the choice here doesn't matter - #[cfg(unix)] - Host::Unix(_) => "", - }; let tls = tls - .make_tls_connect(hostname) + .make_tls_connect(config.hostname.as_deref().unwrap_or("")) .map_err(|e| Error::tls(e.into()))?; + let has_hostname = config.hostname.is_some(); let socket = connect_socket::connect_socket( - &config.host, + &config.addr, config.port, config.connect_timeout, config.tcp_user_timeout, @@ -43,5 +38,6 @@ where ) .await?; - cancel_query_raw::cancel_query_raw(socket, ssl_mode, tls, process_id, secret_key).await + cancel_query_raw::cancel_query_raw(socket, ssl_mode, tls, has_hostname, process_id, secret_key) + .await } diff --git a/tokio-postgres/src/cancel_query_raw.rs b/tokio-postgres/src/cancel_query_raw.rs index c89dc581f..41aafe7d9 100644 --- a/tokio-postgres/src/cancel_query_raw.rs +++ b/tokio-postgres/src/cancel_query_raw.rs @@ -9,6 +9,7 @@ pub async fn cancel_query_raw( stream: S, mode: SslMode, tls: T, + has_hostname: bool, process_id: i32, secret_key: i32, ) -> Result<(), Error> @@ -16,7 +17,7 @@ where S: AsyncRead + AsyncWrite + Unpin, T: TlsConnect, { - let mut stream = connect_tls::connect_tls(stream, mode, tls).await?; + let mut stream = connect_tls::connect_tls(stream, mode, tls, has_hostname).await?; let mut buf = BytesMut::new(); frontend::cancel_request(process_id, secret_key, &mut buf); diff --git a/tokio-postgres/src/cancel_token.rs b/tokio-postgres/src/cancel_token.rs index d048a3c82..c925ce0ca 100644 --- a/tokio-postgres/src/cancel_token.rs +++ b/tokio-postgres/src/cancel_token.rs @@ -55,6 +55,7 @@ impl CancelToken { stream, self.ssl_mode, tls, + true, self.process_id, self.secret_key, ) diff --git a/tokio-postgres/src/client.rs b/tokio-postgres/src/client.rs index a61605412..6b7067ee8 100644 --- a/tokio-postgres/src/client.rs +++ b/tokio-postgres/src/client.rs @@ -1,6 +1,4 @@ use crate::codec::{BackendMessages, FrontendMessage}; -#[cfg(feature = "runtime")] -use crate::config::Host; use crate::config::SslMode; use crate::connection::{Request, RequestMessages}; use crate::copy_both::CopyBothDuplex; @@ -29,6 +27,10 @@ use postgres_protocol::message::{backend::Message, frontend}; use postgres_types::BorrowToSql; use std::collections::HashMap; use std::fmt; +#[cfg(feature = "runtime")] +use std::net::IpAddr; +#[cfg(feature = "runtime")] +use std::path::PathBuf; use std::sync::Arc; use std::task::{Context, Poll}; #[cfg(feature = "runtime")] @@ -155,13 +157,22 @@ impl InnerClient { #[cfg(feature = "runtime")] #[derive(Clone)] pub(crate) struct SocketConfig { - pub host: Host, + pub addr: Addr, + pub hostname: Option, pub port: u16, pub connect_timeout: Option, pub tcp_user_timeout: Option, pub keepalive: Option, } +#[cfg(feature = "runtime")] +#[derive(Clone)] +pub(crate) enum Addr { + Tcp(IpAddr), + #[cfg(unix)] + Unix(PathBuf), +} + /// An asynchronous PostgreSQL client. /// /// The client is one half of what is returned when a connection is established. Users interact with the database diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs index 55e7595c3..e40ed3e07 100644 --- a/tokio-postgres/src/config.rs +++ b/tokio-postgres/src/config.rs @@ -13,6 +13,8 @@ use crate::{Client, Connection, Error}; use std::borrow::Cow; #[cfg(unix)] use std::ffi::OsStr; +use std::net::IpAddr; +use std::ops::Deref; #[cfg(unix)] use std::os::unix::ffi::OsStrExt; #[cfg(unix)] @@ -73,6 +75,16 @@ pub enum ReplicationMode { Logical, } +/// Load balancing configuration. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[non_exhaustive] +pub enum LoadBalanceHosts { + /// Make connection attempts to hosts in the order provided. + Disable, + /// Make connection attempts to hosts in a random order. + Random, +} + /// A host specification. #[derive(Debug, Clone, PartialEq, Eq)] pub enum Host { @@ -114,6 +126,19 @@ pub enum Host { /// path to the directory containing Unix domain sockets. Otherwise, it is treated as a hostname. Multiple hosts /// can be specified, separated by commas. Each host will be tried in turn when connecting. Required if connecting /// with the `connect` method. +/// * `hostaddr` - Numeric IP address of host to connect to. This should be in the standard IPv4 address format, +/// e.g., 172.28.40.9. If your machine supports IPv6, you can also use those addresses. +/// If this parameter is not specified, the value of `host` will be looked up to find the corresponding IP address, +/// or if host specifies an IP address, that value will be used directly. +/// Using `hostaddr` allows the application to avoid a host name look-up, which might be important in applications +/// with time constraints. However, a host name is required for TLS certificate verification. +/// Specifically: +/// * If `hostaddr` is specified without `host`, the value for `hostaddr` gives the server network address. +/// The connection attempt will fail if the authentication method requires a host name; +/// * If `host` is specified without `hostaddr`, a host name lookup occurs; +/// * If both `host` and `hostaddr` are specified, the value for `hostaddr` gives the server network address. +/// The value for `host` is ignored unless the authentication method requires it, +/// in which case it will be used as the host name. /// * `port` - The port to connect to. Multiple ports can be specified, separated by commas. The number of ports must be /// either 1, in which case it will be used for all hosts, or the same as the number of hosts. Defaults to 5432 if /// omitted or the empty string. @@ -136,6 +161,12 @@ pub enum Host { /// * `channel_binding` - Controls usage of channel binding in the authentication process. If set to `disable`, channel /// binding will not be used. If set to `prefer`, channel binding will be used if available, but not used otherwise. /// If set to `require`, the authentication process will fail if channel binding is not used. Defaults to `prefer`. +/// * `load_balance_hosts` - Controls the order in which the client tries to connect to the available hosts and +/// addresses. Once a connection attempt is successful no other hosts and addresses will be tried. This parameter +/// is typically used in combination with multiple host names or a DNS record that returns multiple IPs. If set to +/// `disable`, hosts and addresses will be tried in the order provided. If set to `random`, hosts will be tried +/// in a random order, and the IP addresses resolved from a hostname will also be tried in a random order. Defaults +/// to `disable`. /// /// ## Examples /// @@ -148,6 +179,10 @@ pub enum Host { /// ``` /// /// ```not_rust +/// host=host1,host2,host3 port=1234,,5678 hostaddr=127.0.0.1,127.0.0.2,127.0.0.3 user=postgres target_session_attrs=read-write +/// ``` +/// +/// ```not_rust /// host=host1,host2,host3 port=1234,,5678 user=postgres target_session_attrs=read-write /// ``` /// @@ -187,6 +222,7 @@ pub struct Config { pub(crate) ssl_mode: SslMode, pub(crate) ssl_root_cert: Option>, pub(crate) host: Vec, + pub(crate) hostaddr: Vec, pub(crate) port: Vec, pub(crate) connect_timeout: Option, pub(crate) tcp_user_timeout: Option, @@ -195,7 +231,7 @@ pub struct Config { pub(crate) target_session_attrs: TargetSessionAttrs, pub(crate) channel_binding: ChannelBinding, pub(crate) replication_mode: Option, - pub(crate) tls_verify_host: Option, + pub(crate) load_balance_hosts: LoadBalanceHosts, } impl Default for Config { @@ -223,6 +259,7 @@ impl Config { ssl_mode: SslMode::Prefer, ssl_root_cert: None, host: vec![], + hostaddr: vec![], port: vec![], connect_timeout: None, tcp_user_timeout: None, @@ -231,7 +268,7 @@ impl Config { target_session_attrs: TargetSessionAttrs::Any, channel_binding: ChannelBinding::Prefer, replication_mode: None, - tls_verify_host: None, + load_balance_hosts: LoadBalanceHosts::Disable, } } @@ -358,6 +395,7 @@ impl Config { /// /// Multiple hosts can be specified by calling this method multiple times, and each will be tried in order. On Unix /// systems, a host starting with a `/` is interpreted as a path to a directory containing Unix domain sockets. + /// There must be either no hosts, or the same number of hosts as hostaddrs. pub fn host(&mut self, host: &str) -> &mut Config { #[cfg(unix)] { @@ -375,22 +413,21 @@ impl Config { &self.host } - /// Gets a mutable view of the hosts that have been added to the configuration with `host`. + /// Gets a mutable view of the hosts that have been added to the + /// configuration with `host`. pub fn get_hosts_mut(&mut self) -> &mut [Host] { &mut self.host } - /// Sets the hostname used during TLS certificate verification, if enabled. - /// - /// This can be useful if you are connecting through an SSH tunnel. - pub fn tls_verify_host(&mut self, host: &str) -> &mut Config { - self.tls_verify_host = Some(host.to_string()); - self + /// Gets the hostaddrs that have been added to the configuration with `hostaddr`. + pub fn get_hostaddrs(&self) -> &[IpAddr] { + self.hostaddr.deref() } - /// Gets the host that has been added to the configuration with `tls_verify_host`. - pub fn get_tls_verify_host(&self) -> Option<&str> { - self.tls_verify_host.as_deref() + /// Gets a mutable view of the hostaddrs that have been added to the + /// configuration with `hostaddr`. + pub fn get_hostaddrs_mut(&mut self) -> &mut [IpAddr] { + &mut self.hostaddr } /// Adds a Unix socket host to the configuration. @@ -405,6 +442,15 @@ impl Config { self } + /// Adds a hostaddr to the configuration. + /// + /// Multiple hostaddrs can be specified by calling this method multiple times, and each will be tried in order. + /// There must be either no hostaddrs, or the same number of hostaddrs as hosts. + pub fn hostaddr(&mut self, hostaddr: IpAddr) -> &mut Config { + self.hostaddr.push(hostaddr); + self + } + /// Adds a port to the configuration. /// /// Multiple ports can be specified by calling this method multiple times. There must either be no ports, in which @@ -546,6 +592,19 @@ impl Config { self.replication_mode } + /// Sets the host load balancing behavior. + /// + /// Defaults to `disable`. + pub fn load_balance_hosts(&mut self, load_balance_hosts: LoadBalanceHosts) -> &mut Config { + self.load_balance_hosts = load_balance_hosts; + self + } + + /// Gets the host load balancing behavior. + pub fn get_load_balance_hosts(&self) -> LoadBalanceHosts { + self.load_balance_hosts + } + fn param(&mut self, key: &str, value: &str) -> Result<(), Error> { match key { "user" => { @@ -563,7 +622,7 @@ impl Config { "application_name" => { self.application_name(value); } - "sslcert" => match std::fs::read(&value) { + "sslcert" => match std::fs::read(value) { Ok(contents) => { self.ssl_cert(&contents); } @@ -574,7 +633,7 @@ impl Config { "sslcert_inline" => { self.ssl_cert(value.as_bytes()); } - "sslkey" => match std::fs::read(&value) { + "sslkey" => match std::fs::read(value) { Ok(contents) => { self.ssl_key(&contents); } @@ -596,7 +655,7 @@ impl Config { }; self.ssl_mode(mode); } - "sslrootcert" => match std::fs::read(&value) { + "sslrootcert" => match std::fs::read(value) { Ok(contents) => { self.ssl_root_cert(&contents); } @@ -612,6 +671,14 @@ impl Config { self.host(host); } } + "hostaddr" => { + for hostaddr in value.split(',') { + let addr = hostaddr + .parse() + .map_err(|_| Error::config_parse(Box::new(InvalidValue("hostaddr"))))?; + self.hostaddr(addr); + } + } "port" => { for port in value.split(',') { let port = if port.is_empty() { @@ -703,6 +770,18 @@ impl Config { self.replication_mode(mode); } } + "load_balance_hosts" => { + let load_balance_hosts = match value { + "disable" => LoadBalanceHosts::Disable, + "random" => LoadBalanceHosts::Random, + _ => { + return Err(Error::config_parse(Box::new(InvalidValue( + "load_balance_hosts", + )))) + } + }; + self.load_balance_hosts(load_balance_hosts); + } key => { return Err(Error::config_parse(Box::new(UnknownOption( key.to_string(), @@ -736,7 +815,7 @@ impl Config { S: AsyncRead + AsyncWrite + Unpin, T: TlsConnect, { - connect_raw(stream, tls, self).await + connect_raw(stream, tls, true, self).await } } @@ -772,6 +851,7 @@ impl fmt::Debug for Config { .field("ssl_mode", &self.ssl_mode) .field("ssl_root_cert", &self.ssl_root_cert) .field("host", &self.host) + .field("hostaddr", &self.hostaddr) .field("port", &self.port) .field("connect_timeout", &self.connect_timeout) .field("tcp_user_timeout", &self.tcp_user_timeout) @@ -1156,3 +1236,41 @@ impl<'a> UrlParser<'a> { .map_err(|e| Error::config_parse(e.into())) } } + +#[cfg(test)] +mod tests { + use std::net::IpAddr; + + use crate::{config::Host, Config}; + + #[test] + fn test_simple_parsing() { + let s = "user=pass_user dbname=postgres host=host1,host2 hostaddr=127.0.0.1,127.0.0.2 port=26257"; + let config = s.parse::().unwrap(); + assert_eq!(Some("pass_user"), config.get_user()); + assert_eq!(Some("postgres"), config.get_dbname()); + assert_eq!( + [ + Host::Tcp("host1".to_string()), + Host::Tcp("host2".to_string()) + ], + config.get_hosts(), + ); + + assert_eq!( + [ + "127.0.0.1".parse::().unwrap(), + "127.0.0.2".parse::().unwrap() + ], + config.get_hostaddrs(), + ); + + assert_eq!(1, 1); + } + + #[test] + fn test_invalid_hostaddr_parsing() { + let s = "user=pass_user dbname=postgres host=host1 hostaddr=127.0.0 port=26257"; + s.parse::().err().unwrap(); + } +} diff --git a/tokio-postgres/src/connect.rs b/tokio-postgres/src/connect.rs index 1eca1ec6d..ca57b9cdd 100644 --- a/tokio-postgres/src/connect.rs +++ b/tokio-postgres/src/connect.rs @@ -1,12 +1,14 @@ -use crate::client::SocketConfig; -use crate::config::{Host, TargetSessionAttrs}; +use crate::client::{Addr, SocketConfig}; +use crate::config::{Host, LoadBalanceHosts, TargetSessionAttrs}; use crate::connect_raw::connect_raw; use crate::connect_socket::connect_socket; -use crate::tls::{MakeTlsConnect, TlsConnect}; +use crate::tls::MakeTlsConnect; use crate::{Client, Config, Connection, Error, SimpleQueryMessage, Socket}; use futures_util::{future, pin_mut, Future, FutureExt, Stream}; -use std::io; +use rand::seq::SliceRandom; use std::task::Poll; +use std::{cmp, io}; +use tokio::net; pub async fn connect( mut tls: T, @@ -15,16 +17,40 @@ pub async fn connect( where T: MakeTlsConnect, { - if config.host.is_empty() { - return Err(Error::config("host missing".into())); + if config.host.is_empty() && config.hostaddr.is_empty() { + return Err(Error::config("both host and hostaddr are missing".into())); } - if config.port.len() > 1 && config.port.len() != config.host.len() { + if !config.host.is_empty() + && !config.hostaddr.is_empty() + && config.host.len() != config.hostaddr.len() + { + let msg = format!( + "number of hosts ({}) is different from number of hostaddrs ({})", + config.host.len(), + config.hostaddr.len(), + ); + return Err(Error::config(msg.into())); + } + + // At this point, either one of the following two scenarios could happen: + // (1) either config.host or config.hostaddr must be empty; + // (2) if both config.host and config.hostaddr are NOT empty; their lengths must be equal. + let num_hosts = cmp::max(config.host.len(), config.hostaddr.len()); + + if config.port.len() > 1 && config.port.len() != num_hosts { return Err(Error::config("invalid number of ports".into())); } + let mut indices = (0..num_hosts).collect::>(); + if config.load_balance_hosts == LoadBalanceHosts::Random { + indices.shuffle(&mut rand::thread_rng()); + } + let mut error = None; - for (i, host) in config.host.iter().enumerate() { + for i in indices { + let host = config.host.get(i); + let hostaddr = config.hostaddr.get(i); let port = config .port .get(i) @@ -32,19 +58,23 @@ where .copied() .unwrap_or(5432); - let hostname = match (config.tls_verify_host.as_deref(), host) { - (Some(tls_verify_host), Host::Tcp(_)) => tls_verify_host, - (None, Host::Tcp(host)) => host.as_str(), + // The value of host is used as the hostname for TLS validation, + let hostname = match host { + Some(Host::Tcp(host)) => Some(host.clone()), // postgres doesn't support TLS over unix sockets, so the choice here doesn't matter #[cfg(unix)] - (_, Host::Unix(_)) => "", + Some(Host::Unix(_)) => None, + None => None, }; - let tls = tls - .make_tls_connect(hostname) - .map_err(|e| Error::tls(e.into()))?; + // Try to use the value of hostaddr to establish the TCP connection, + // fallback to host if hostaddr is not present. + let addr = match hostaddr { + Some(ipaddr) => Host::Tcp(ipaddr.to_string()), + None => host.cloned().unwrap(), + }; - match connect_once(host, port, tls, config).await { + match connect_host(addr, hostname, port, &mut tls, config).await { Ok((client, connection)) => return Ok((client, connection)), Err(e) => error = Some(e), } @@ -53,17 +83,66 @@ where Err(error.unwrap()) } +async fn connect_host( + host: Host, + hostname: Option, + port: u16, + tls: &mut T, + config: &Config, +) -> Result<(Client, Connection), Error> +where + T: MakeTlsConnect, +{ + match host { + Host::Tcp(host) => { + let mut addrs = net::lookup_host((&*host, port)) + .await + .map_err(Error::connect)? + .collect::>(); + + if config.load_balance_hosts == LoadBalanceHosts::Random { + addrs.shuffle(&mut rand::thread_rng()); + } + + let mut last_err = None; + for addr in addrs { + match connect_once(Addr::Tcp(addr.ip()), hostname.as_deref(), port, tls, config) + .await + { + Ok(stream) => return Ok(stream), + Err(e) => { + last_err = Some(e); + continue; + } + }; + } + + Err(last_err.unwrap_or_else(|| { + Error::connect(io::Error::new( + io::ErrorKind::InvalidInput, + "could not resolve any addresses", + )) + })) + } + #[cfg(unix)] + Host::Unix(path) => { + connect_once(Addr::Unix(path), hostname.as_deref(), port, tls, config).await + } + } +} + async fn connect_once( - host: &Host, + addr: Addr, + hostname: Option<&str>, port: u16, - tls: T, + tls: &mut T, config: &Config, ) -> Result<(Client, Connection), Error> where - T: TlsConnect, + T: MakeTlsConnect, { let socket = connect_socket( - host, + &addr, port, config.connect_timeout, config.tcp_user_timeout, @@ -74,7 +153,12 @@ where }, ) .await?; - let (mut client, mut connection) = connect_raw(socket, tls, config).await?; + + let tls = tls + .make_tls_connect(hostname.unwrap_or("")) + .map_err(|e| Error::tls(e.into()))?; + let has_hostname = hostname.is_some(); + let (mut client, mut connection) = connect_raw(socket, tls, has_hostname, config).await?; if let TargetSessionAttrs::ReadWrite = config.target_session_attrs { let rows = client.simple_query_raw("SHOW transaction_read_only"); @@ -117,7 +201,8 @@ where } client.set_socket_config(SocketConfig { - host: host.clone(), + addr, + hostname: hostname.map(|s| s.to_string()), port, connect_timeout: config.connect_timeout, tcp_user_timeout: config.tcp_user_timeout, diff --git a/tokio-postgres/src/connect_raw.rs b/tokio-postgres/src/connect_raw.rs index f01b45607..1348828ba 100644 --- a/tokio-postgres/src/connect_raw.rs +++ b/tokio-postgres/src/connect_raw.rs @@ -81,13 +81,14 @@ where pub async fn connect_raw( stream: S, tls: T, + has_hostname: bool, config: &Config, ) -> Result<(Client, Connection), Error> where S: AsyncRead + AsyncWrite + Unpin, T: TlsConnect, { - let stream = connect_tls(stream, config.ssl_mode, tls).await?; + let stream = connect_tls(stream, config.ssl_mode, tls, has_hostname).await?; let mut stream = StartupStream { inner: Framed::new(stream, PostgresCodec), diff --git a/tokio-postgres/src/connect_socket.rs b/tokio-postgres/src/connect_socket.rs index 9b3d31d72..67add04ea 100644 --- a/tokio-postgres/src/connect_socket.rs +++ b/tokio-postgres/src/connect_socket.rs @@ -1,69 +1,48 @@ -use crate::config::Host; +use crate::client::Addr; use crate::keepalive::KeepaliveConfig; use crate::{Error, Socket}; use socket2::{SockRef, TcpKeepalive}; use std::future::Future; use std::io; use std::time::Duration; +use tokio::net::TcpStream; #[cfg(unix)] use tokio::net::UnixStream; -use tokio::net::{self, TcpStream}; use tokio::time; pub(crate) async fn connect_socket( - host: &Host, + addr: &Addr, port: u16, connect_timeout: Option, tcp_user_timeout: Option, keepalive_config: Option<&KeepaliveConfig>, ) -> Result { - match host { - Host::Tcp(host) => { - let addrs = net::lookup_host((&**host, port)) - .await - .map_err(Error::connect)?; + match addr { + Addr::Tcp(ip) => { + let stream = + connect_with_timeout(TcpStream::connect((*ip, port)), connect_timeout).await?; - let mut last_err = None; + stream.set_nodelay(true).map_err(Error::connect)?; - for addr in addrs { - let stream = - match connect_with_timeout(TcpStream::connect(addr), connect_timeout).await { - Ok(stream) => stream, - Err(e) => { - last_err = Some(e); - continue; - } - }; - - stream.set_nodelay(true).map_err(Error::connect)?; - - let sock_ref = SockRef::from(&stream); - #[cfg(target_os = "linux")] - { - sock_ref - .set_tcp_user_timeout(tcp_user_timeout) - .map_err(Error::connect)?; - } - - if let Some(keepalive_config) = keepalive_config { - sock_ref - .set_tcp_keepalive(&TcpKeepalive::from(keepalive_config)) - .map_err(Error::connect)?; - } + let sock_ref = SockRef::from(&stream); + #[cfg(target_os = "linux")] + { + sock_ref + .set_tcp_user_timeout(tcp_user_timeout) + .map_err(Error::connect)?; + } - return Ok(Socket::new_tcp(stream)); + if let Some(keepalive_config) = keepalive_config { + sock_ref + .set_tcp_keepalive(&TcpKeepalive::from(keepalive_config)) + .map_err(Error::connect)?; } - Err(last_err.unwrap_or_else(|| { - Error::connect(io::Error::new( - io::ErrorKind::InvalidInput, - "could not resolve any addresses", - )) - })) + Ok(Socket::new_tcp(stream)) } #[cfg(unix)] - Host::Unix(path) => { - let path = path.join(format!(".s.PGSQL.{}", port)); + Addr::Unix(dir) => { + let path = dir.join(format!(".s.PGSQL.{}", port)); let socket = connect_with_timeout(UnixStream::connect(path), connect_timeout).await?; Ok(Socket::new_unix(socket)) } diff --git a/tokio-postgres/src/connect_tls.rs b/tokio-postgres/src/connect_tls.rs index 25e913d1f..41b319c2b 100644 --- a/tokio-postgres/src/connect_tls.rs +++ b/tokio-postgres/src/connect_tls.rs @@ -11,6 +11,7 @@ pub async fn connect_tls( mut stream: S, mode: SslMode, tls: T, + has_hostname: bool, ) -> Result, Error> where S: AsyncRead + AsyncWrite + Unpin, @@ -40,6 +41,10 @@ where } } + if !has_hostname { + return Err(Error::tls("no hostname provided for TLS handshake".into())); + } + let stream = tls .connect(stream) .await diff --git a/tokio-postgres/tests/test/main.rs b/tokio-postgres/tests/test/main.rs index 8de2b75a2..cab185ae6 100644 --- a/tokio-postgres/tests/test/main.rs +++ b/tokio-postgres/tests/test/main.rs @@ -330,7 +330,7 @@ async fn simple_query() { } match &messages[2] { SimpleQueryMessage::Row(row) => { - assert_eq!(row.columns().get(0).map(|c| c.name()), Some("id")); + assert_eq!(row.columns().first().map(|c| c.name()), Some("id")); assert_eq!(row.columns().get(1).map(|c| c.name()), Some("name")); assert_eq!(row.get(0), Some("1")); assert_eq!(row.get(1), Some("steven")); @@ -339,7 +339,7 @@ async fn simple_query() { } match &messages[3] { SimpleQueryMessage::Row(row) => { - assert_eq!(row.columns().get(0).map(|c| c.name()), Some("id")); + assert_eq!(row.columns().first().map(|c| c.name()), Some("id")); assert_eq!(row.columns().get(1).map(|c| c.name()), Some("name")); assert_eq!(row.get(0), Some("2")); assert_eq!(row.get(1), Some("joe")); diff --git a/tokio-postgres/tests/test/runtime.rs b/tokio-postgres/tests/test/runtime.rs index 67b4ead8a..86c1f0701 100644 --- a/tokio-postgres/tests/test/runtime.rs +++ b/tokio-postgres/tests/test/runtime.rs @@ -66,6 +66,58 @@ async fn target_session_attrs_err() { .unwrap(); } +#[tokio::test] +async fn host_only_ok() { + let _ = tokio_postgres::connect( + "host=localhost port=5433 user=pass_user dbname=postgres password=password", + NoTls, + ) + .await + .unwrap(); +} + +#[tokio::test] +async fn hostaddr_only_ok() { + let _ = tokio_postgres::connect( + "hostaddr=127.0.0.1 port=5433 user=pass_user dbname=postgres password=password", + NoTls, + ) + .await + .unwrap(); +} + +#[tokio::test] +async fn hostaddr_and_host_ok() { + let _ = tokio_postgres::connect( + "hostaddr=127.0.0.1 host=localhost port=5433 user=pass_user dbname=postgres password=password", + NoTls, + ) + .await + .unwrap(); +} + +#[tokio::test] +async fn hostaddr_host_mismatch() { + let _ = tokio_postgres::connect( + "hostaddr=127.0.0.1,127.0.0.2 host=localhost port=5433 user=pass_user dbname=postgres password=password", + NoTls, + ) + .await + .err() + .unwrap(); +} + +#[tokio::test] +async fn hostaddr_host_both_missing() { + let _ = tokio_postgres::connect( + "port=5433 user=pass_user dbname=postgres password=password", + NoTls, + ) + .await + .err() + .unwrap(); +} + #[tokio::test] async fn cancel_query() { let client = connect("host=localhost port=5433 user=postgres").await; diff --git a/tokio-postgres/tests/test/types/chrono_04.rs b/tokio-postgres/tests/test/types/chrono_04.rs index a8e9e5afa..c325917aa 100644 --- a/tokio-postgres/tests/test/types/chrono_04.rs +++ b/tokio-postgres/tests/test/types/chrono_04.rs @@ -1,4 +1,4 @@ -use chrono_04::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, TimeZone, Utc}; +use chrono_04::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, Utc}; use std::fmt; use tokio_postgres::types::{Date, FromSqlOwned, Timestamp}; use tokio_postgres::Client; @@ -54,8 +54,9 @@ async fn test_date_time_params() { fn make_check(time: &str) -> (Option>, &str) { ( Some( - Utc.datetime_from_str(time, "'%Y-%m-%d %H:%M:%S.%f'") - .unwrap(), + NaiveDateTime::parse_from_str(time, "'%Y-%m-%d %H:%M:%S.%f'") + .unwrap() + .and_utc(), ), time, ) @@ -77,8 +78,9 @@ async fn test_with_special_date_time_params() { fn make_check(time: &str) -> (Timestamp>, &str) { ( Timestamp::Value( - Utc.datetime_from_str(time, "'%Y-%m-%d %H:%M:%S.%f'") - .unwrap(), + NaiveDateTime::parse_from_str(time, "'%Y-%m-%d %H:%M:%S.%f'") + .unwrap() + .and_utc(), ), time, )