diff --git a/src/config.rs b/src/config.rs index 893f5b70..e5c366c9 100644 --- a/src/config.rs +++ b/src/config.rs @@ -4,7 +4,7 @@ use log::{error, info}; use once_cell::sync::Lazy; use serde_derive::{Deserialize, Serialize}; use std::collections::{HashMap, HashSet}; -use std::hash::Hash; +use std::hash::{Hash, Hasher}; use std::path::Path; use std::sync::Arc; use tokio::fs::File; @@ -122,7 +122,7 @@ impl Address { } /// PostgreSQL user. -#[derive(Clone, PartialEq, Hash, std::cmp::Eq, Serialize, Deserialize, Debug)] +#[derive(Clone, PartialEq, Hash, Eq, Serialize, Deserialize, Debug)] pub struct User { pub username: String, pub password: String, @@ -232,7 +232,7 @@ impl Default for General { /// Pool mode: /// - transaction: server serves one transaction, /// - session: server is attached to the client. -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Copy)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Copy, Hash)] pub enum PoolMode { #[serde(alias = "transaction", alias = "Transaction")] Transaction, @@ -250,7 +250,7 @@ impl ToString for PoolMode { } } -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] pub struct Pool { #[serde(default = "Pool::default_pool_mode")] pub pool_mode: PoolMode, @@ -263,11 +263,35 @@ pub struct Pool { #[serde(default)] // False pub primary_reads_enabled: bool, + #[serde(default = "General::default_connect_timeout")] + pub connect_timeout: u64, + pub sharding_function: String, pub shards: HashMap, pub users: HashMap, } +impl Hash for Pool { + fn hash(&self, state: &mut H) { + self.pool_mode.hash(state); + self.default_role.hash(state); + self.query_parser_enabled.hash(state); + self.primary_reads_enabled.hash(state); + self.sharding_function.hash(state); + self.connect_timeout.hash(state); + + for (key, value) in &self.shards { + key.hash(state); + value.hash(state); + } + + for (key, value) in &self.users { + key.hash(state); + value.hash(state); + } + } +} + impl Pool { fn default_pool_mode() -> PoolMode { PoolMode::Transaction @@ -284,6 +308,7 @@ impl Default for Pool { query_parser_enabled: false, primary_reads_enabled: false, sharding_function: "pg_bigint_hash".to_string(), + connect_timeout: General::default_connect_timeout(), } } } @@ -296,7 +321,7 @@ pub struct ServerConfig { } /// Shard configuration. -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Hash, Eq)] pub struct Shard { pub database: String, pub servers: Vec, @@ -575,7 +600,10 @@ pub async fn parse(path: &str) -> Result<(), Error> { None => (), }; - for (pool_name, pool) in &config.pools { + for (pool_name, mut pool) in &mut config.pools { + // Copy the connect timeout over for hashing. + pool.connect_timeout = config.general.connect_timeout; + match pool.sharding_function.as_ref() { "pg_bigint_hash" => (), "sha1" => (), @@ -666,7 +694,7 @@ pub async fn reload_config(client_server_map: ClientServerMap) -> Result; /// This is atomic and safe and read-optimized. /// The pool is recreated dynamically when the config is reloaded. pub static POOLS: Lazy> = Lazy::new(|| ArcSwap::from_pointee(HashMap::default())); +static POOLS_HASH: Lazy>> = + Lazy::new(|| ArcSwap::from_pointee(HashSet::default())); /// Pool settings. #[derive(Clone, Debug)] @@ -101,9 +103,23 @@ impl ConnectionPool { let mut new_pools = HashMap::new(); let mut address_id = 0; + let mut pools_hash = (*(*POOLS_HASH.load())).clone(); + for (pool_name, pool_config) in &config.pools { + let changed = pools_hash.insert(pool_config.clone()); + + if !changed { + info!("[db: {}] has not changed", pool_name); + continue; + } + // There is one pool per database/user pair. for (_, user) in &pool_config.users { + info!( + "[pool: {}][user: {}] creating new pool", + pool_name, user.username + ); + let mut shards = Vec::new(); let mut addresses = Vec::new(); let mut banlist = Vec::new(); @@ -156,7 +172,7 @@ impl ConnectionPool { let pool = Pool::builder() .max_size(user.pool_size) .connection_timeout(std::time::Duration::from_millis( - config.general.connect_timeout, + pool_config.connect_timeout, )) .test_on_check_out(false) .build(manager) @@ -217,6 +233,7 @@ impl ConnectionPool { } POOLS.store(Arc::new(new_pools.clone())); + POOLS_HASH.store(Arc::new(pools_hash.clone())); Ok(()) }