diff --git a/src/query/service/src/servers/http/v1/query/execute_state.rs b/src/query/service/src/servers/http/v1/query/execute_state.rs index 6d0e89d9deb50..10b76a0b13f93 100644 --- a/src/query/service/src/servers/http/v1/query/execute_state.rs +++ b/src/query/service/src/servers/http/v1/query/execute_state.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::collections::HashMap; use std::sync::Arc; use std::time::SystemTime; @@ -22,6 +23,7 @@ use databend_common_exception::ErrorCode; use databend_common_exception::Result; use databend_common_expression::DataBlock; use databend_common_expression::DataSchemaRef; +use databend_common_expression::Scalar; use databend_common_io::prelude::FormatSettings; use databend_common_settings::Settings; use databend_storages_common_txn::TxnManagerRef; @@ -147,6 +149,7 @@ pub struct ExecutorSessionState { pub secondary_roles: Option>, pub settings: Arc, pub txn_manager: TxnManagerRef, + pub variables: HashMap, } impl ExecutorSessionState { @@ -157,6 +160,7 @@ impl ExecutorSessionState { secondary_roles: session.get_secondary_roles(), settings: session.get_settings(), txn_manager: session.txn_mgr(), + variables: session.get_all_variables(), } } } diff --git a/src/query/service/src/servers/http/v1/query/http_query.rs b/src/query/service/src/servers/http/v1/query/http_query.rs index 270be4039101b..b81b12bdd829a 100644 --- a/src/query/service/src/servers/http/v1/query/http_query.rs +++ b/src/query/service/src/servers/http/v1/query/http_query.rs @@ -13,6 +13,7 @@ // limitations under the License. use std::collections::BTreeMap; +use std::collections::HashMap; use std::fmt::Debug; use std::sync::atomic::AtomicBool; use std::sync::atomic::Ordering; @@ -29,6 +30,7 @@ use databend_common_base::runtime::TrySpawn; use databend_common_catalog::table_context::StageAttachment; use databend_common_exception::ErrorCode; use databend_common_exception::Result; +use databend_common_expression::Scalar; use databend_common_io::prelude::FormatSettings; use databend_common_metrics::http::metrics_incr_http_response_errors_count; use databend_common_settings::ScopeLevel; @@ -39,7 +41,9 @@ use log::warn; use poem::web::Json; use poem::IntoResponse; use serde::Deserialize; +use serde::Deserializer; use serde::Serialize; +use serde::Serializer; use super::HttpQueryContext; use super::RemoveReason; @@ -181,6 +185,75 @@ pub struct ServerInfo { pub start_time: String, } +#[derive(Deserialize, Serialize, Debug, Default, Clone, Eq, PartialEq)] +pub struct HttpSessionStateInternal { + /// value is JSON of Scalar + variables: Vec<(String, String)>, +} + +impl HttpSessionStateInternal { + fn new(variables: &HashMap) -> Self { + let variables = variables + .iter() + .map(|(k, v)| { + ( + k.clone(), + serde_json::to_string(&v).expect("fail to serialize Scalar"), + ) + }) + .collect(); + Self { variables } + } + + pub fn get_variables(&self) -> Result> { + let mut vars = HashMap::with_capacity(self.variables.len()); + for (k, v) in self.variables.iter() { + match serde_json::from_str::(v) { + Ok(s) => { + vars.insert(k.to_string(), s); + } + Err(e) => { + return Err(ErrorCode::BadBytes(format!( + "fail decode scalar from string '{v}', error: {e}" + ))); + } + } + } + Ok(vars) + } +} + +fn serialize_as_json_string( + value: &Option, + serializer: S, +) -> Result +where + S: Serializer, +{ + match value { + Some(complex_value) => { + let json_string = + serde_json::to_string(complex_value).map_err(serde::ser::Error::custom)?; + serializer.serialize_some(&json_string) + } + None => serializer.serialize_none(), + } +} + +fn deserialize_from_json_string<'de, D>( + deserializer: D, +) -> Result, D::Error> +where D: Deserializer<'de> { + let json_string: Option = Option::deserialize(deserializer)?; + match json_string { + Some(s) => { + let complex_value = serde_json::from_str(&s).map_err(serde::de::Error::custom)?; + Ok(Some(complex_value)) + } + None => Ok(None), + } +} + #[derive(Deserialize, Serialize, Debug, Default, Clone, Eq, PartialEq)] pub struct HttpSessionConf { #[serde(skip_serializing_if = "Option::is_none")] @@ -189,6 +262,7 @@ pub struct HttpSessionConf { pub role: Option, #[serde(skip_serializing_if = "Option::is_none")] pub secondary_roles: Option>, + // todo: remove this later #[serde(skip_serializing_if = "Option::is_none")] pub keep_server_session_secs: Option, #[serde(skip_serializing_if = "Option::is_none")] @@ -198,9 +272,19 @@ pub struct HttpSessionConf { // used to check if the session is still on the same server #[serde(skip_serializing_if = "Option::is_none")] pub last_server_info: Option, - // last_query_ids[0] is the last query id, last_query_ids[1] is the second last query id, etc. + /// last_query_ids[0] is the last query id, last_query_ids[1] is the second last query id, etc. #[serde(default)] pub last_query_ids: Vec, + /// hide state not useful to clients + /// so client only need to know there is a String field `internal`, + /// which need to carry with session/conn + #[serde(default)] + #[serde(skip_serializing_if = "Option::is_none")] + #[serde( + serialize_with = "serialize_as_json_string", + deserialize_with = "deserialize_from_json_string" + )] + pub internal: Option, } impl HttpSessionConf {} @@ -360,6 +444,11 @@ impl HttpQuery { })?; } } + if let Some(state) = &session_conf.internal { + if !state.variables.is_empty() { + session.set_all_variables(state.get_variables()?) + } + } try_set_txn(&ctx.query_id, &session, session_conf, &http_query_manager)?; }; @@ -548,6 +637,11 @@ impl HttpQuery { let role = session_state.current_role.clone(); let secondary_roles = session_state.secondary_roles.clone(); let txn_state = session_state.txn_manager.lock().state(); + let internal = if !session_state.variables.is_empty() { + Some(HttpSessionStateInternal::new(&session_state.variables)) + } else { + None + }; if txn_state != TxnState::AutoCommit && !self.is_txn_mgr_saved.load(Ordering::Relaxed) && matches!(executor.state, ExecuteState::Stopped(_)) @@ -573,6 +667,7 @@ impl HttpQuery { txn_state: Some(txn_state), last_server_info: Some(HttpQueryManager::instance().server_info.clone()), last_query_ids: vec![self.id.clone()], + internal, } } diff --git a/src/query/service/src/sessions/session.rs b/src/query/service/src/sessions/session.rs index 348ea0c822fdb..c0fa9ac2e928f 100644 --- a/src/query/service/src/sessions/session.rs +++ b/src/query/service/src/sessions/session.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::collections::HashMap; use std::net::SocketAddr; use std::sync::Arc; @@ -20,6 +21,7 @@ use databend_common_catalog::cluster_info::Cluster; use databend_common_config::GlobalConfig; use databend_common_exception::ErrorCode; use databend_common_exception::Result; +use databend_common_expression::Scalar; use databend_common_io::prelude::FormatSettings; use databend_common_meta_app::principal::GrantObject; use databend_common_meta_app::principal::OwnershipObject; @@ -352,6 +354,14 @@ impl Session { Some(x) => x.get_query_profiles(), } } + + pub fn get_all_variables(&self) -> HashMap { + self.session_ctx.get_all_variables() + } + + pub fn set_all_variables(&self, variables: HashMap) { + self.session_ctx.set_all_variables(variables) + } } impl Drop for Session { diff --git a/src/query/service/src/sessions/session_ctx.rs b/src/query/service/src/sessions/session_ctx.rs index 3b089339c47db..c33694ca8d8b8 100644 --- a/src/query/service/src/sessions/session_ctx.rs +++ b/src/query/service/src/sessions/session_ctx.rs @@ -316,4 +316,10 @@ impl SessionContext { pub fn get_variable(&self, key: &str) -> Option { self.variables.read().get(key).cloned() } + pub fn get_all_variables(&self) -> HashMap { + self.variables.read().clone() + } + pub fn set_all_variables(&self, variables: HashMap) { + *self.variables.write() = variables + } } diff --git a/src/query/service/tests/it/servers/http/http_query_handlers.rs b/src/query/service/tests/it/servers/http/http_query_handlers.rs index 6ef9a024b32ba..c77f469dacc1e 100644 --- a/src/query/service/tests/it/servers/http/http_query_handlers.rs +++ b/src/query/service/tests/it/servers/http/http_query_handlers.rs @@ -1393,6 +1393,7 @@ async fn test_affect() -> Result<()> { txn_state: Some(TxnState::AutoCommit), last_server_info: None, last_query_ids: vec![], + internal: None, }), ), ( @@ -1415,6 +1416,7 @@ async fn test_affect() -> Result<()> { txn_state: Some(TxnState::AutoCommit), last_server_info: None, last_query_ids: vec![], + internal: None, }), ), ( @@ -1432,6 +1434,7 @@ async fn test_affect() -> Result<()> { txn_state: Some(TxnState::AutoCommit), last_server_info: None, last_query_ids: vec![], + internal: None, }), ), ( @@ -1451,6 +1454,7 @@ async fn test_affect() -> Result<()> { txn_state: Some(TxnState::AutoCommit), last_server_info: None, last_query_ids: vec![], + internal: None, }), ), ( @@ -1472,6 +1476,7 @@ async fn test_affect() -> Result<()> { txn_state: Some(TxnState::AutoCommit), last_server_info: None, last_query_ids: vec![], + internal: None, }), ), ]; diff --git a/tests/sqllogictests/src/util.rs b/tests/sqllogictests/src/util.rs index d72a6406f77c9..6c7eceb9de79e 100644 --- a/tests/sqllogictests/src/util.rs +++ b/tests/sqllogictests/src/util.rs @@ -42,6 +42,7 @@ pub struct HttpSessionConf { pub last_server_info: Option, #[serde(default)] pub last_query_ids: Vec, + pub internal: Option, } pub fn parser_rows(rows: &Value) -> Result>> { diff --git a/tests/sqllogictests/suites/query/set.test b/tests/sqllogictests/suites/query/set.test index 610000e7d8a53..5d7e0bdc0cc3d 100644 --- a/tests/sqllogictests/suites/query/set.test +++ b/tests/sqllogictests/suites/query/set.test @@ -7,11 +7,10 @@ select value, default = value from system.settings where name in ('max_threads' 4 0 56 0 -onlyif mysql statement ok set variable (a, b) = (select 3, 55) -onlyif mysql + statement ok SET GLOBAL (max_threads, storage_io_min_bytes_for_seek) = select $a + 1, $b + 1; @@ -30,25 +29,20 @@ select default = value from system.settings where name in ('max_threads', 'stor 1 1 -onlyif mysql statement ok set variable a = 1; -onlyif mysql statement ok set variable (b, c) = ('yy', 'zz'); -onlyif mysql query ITT select $a + getvariable('a') + $a, getvariable('b'), getvariable('c'), getvariable('d') ---- 3 yy zz NULL -onlyif mysql statement ok unset variable (a, b) -onlyif mysql query ITT select getvariable('a'), getvariable('b'), 'xx' || 'yy' || getvariable('c') , getvariable('d') ----