diff --git a/apps/fortuna/Cargo.toml b/apps/fortuna/Cargo.toml index df4181d330..c90d26376a 100644 --- a/apps/fortuna/Cargo.toml +++ b/apps/fortuna/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "fortuna" -version = "9.0.0" +version = "9.1.0" edition = "2021" [lib] diff --git a/apps/fortuna/src/api.rs b/apps/fortuna/src/api.rs index cc39d416de..9c138cba00 100644 --- a/apps/fortuna/src/api.rs +++ b/apps/fortuna/src/api.rs @@ -39,21 +39,12 @@ mod revelation; pub type ChainId = String; pub type NetworkId = u64; -#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, utoipa::ToSchema, sqlx::Type)] +#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize, utoipa::ToSchema)] pub enum StateTag { Pending, - Completed, Failed, -} - -impl std::fmt::Display for StateTag { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - StateTag::Pending => write!(f, "Pending"), - StateTag::Completed => write!(f, "Completed"), - StateTag::Failed => write!(f, "Failed"), - } - } + Completed, + CallbackErrored, } #[derive(Clone, Debug, Hash, PartialEq, Eq, EncodeLabelSet)] diff --git a/apps/fortuna/src/history.rs b/apps/fortuna/src/history.rs index 3b4b2e5218..85c36bc97b 100644 --- a/apps/fortuna/src/history.rs +++ b/apps/fortuna/src/history.rs @@ -533,9 +533,15 @@ impl<'a> RequestQueryBuilder<'a> { sql.push_str(&format!(" AND network_id = ${param_count}")); } - if self.state.is_some() { + if let Some(state) = &self.state { param_count += 1; sql.push_str(&format!(" AND state = ${param_count}")); + + if *state == StateTag::Completed { + sql.push_str(" AND NOT callback_failed"); + } else if *state == StateTag::CallbackErrored { + sql.push_str(" AND callback_failed"); + } } sql.push_str(" ORDER BY created_at DESC"); @@ -570,7 +576,11 @@ impl<'a> RequestQueryBuilder<'a> { } if let Some(state) = &self.state { - query = query.bind(state.to_string()); + query = query.bind(match state { + StateTag::Pending => "Pending", + StateTag::Failed => "Failed", + StateTag::Completed | StateTag::CallbackErrored => "Completed", + }) } query = query.bind(self.limit).bind(self.offset); @@ -612,9 +622,15 @@ impl<'a> RequestQueryBuilder<'a> { sql.push_str(&format!(" AND network_id = ${param_count}")); } - if self.state.is_some() { + if let Some(state) = &self.state { param_count += 1; sql.push_str(&format!(" AND state = ${param_count}")); + + if *state == StateTag::Completed { + sql.push_str(" AND NOT callback_failed"); + } else if *state == StateTag::CallbackErrored { + sql.push_str(" AND callback_failed"); + } } // Now bind all parameters in order @@ -642,7 +658,11 @@ impl<'a> RequestQueryBuilder<'a> { } if let Some(state) = &self.state { - query = query.bind(state.to_string()); + query = query.bind(match state { + StateTag::Pending => "Pending", + StateTag::Failed => "Failed", + StateTag::Completed | StateTag::CallbackErrored => "Completed", + }) } query.fetch_one(self.pool).await.map_err(|err| err.into()) @@ -1088,6 +1108,90 @@ mod test { } } + #[tokio::test] + async fn test_history_state_filter() { + let history = History::new_in_memory().await.unwrap(); + let reveal_tx_hash = TxHash::random(); + + let pending_status = get_random_request_status(); + History::update_request_status(&history.pool, pending_status.clone()).await; + + let mut failed_status = get_random_request_status(); + History::update_request_status(&history.pool, failed_status.clone()).await; + failed_status.state = RequestEntryState::Failed { + reason: "Failed".to_string(), + provider_random_number: None, + }; + History::update_request_status(&history.pool, failed_status.clone()).await; + + let mut completed_status = get_random_request_status(); + History::update_request_status(&history.pool, completed_status.clone()).await; + completed_status.state = RequestEntryState::Completed { + reveal_block_number: 1, + reveal_tx_hash, + provider_random_number: [40; 32], + gas_used: U256::from(567890), + combined_random_number: RequestStatus::generate_combined_random_number( + &completed_status.user_random_number, + &[40; 32], + ), + callback_failed: false, + callback_return_value: Default::default(), + callback_gas_used: 100_000, + }; + History::update_request_status(&history.pool, completed_status.clone()).await; + + let reveal_tx_hash = TxHash::random(); + let mut callback_errored_status = get_random_request_status(); + History::update_request_status(&history.pool, callback_errored_status.clone()).await; + callback_errored_status.state = RequestEntryState::Completed { + reveal_block_number: 1, + reveal_tx_hash, + provider_random_number: [40; 32], + gas_used: U256::from(567890), + combined_random_number: RequestStatus::generate_combined_random_number( + &callback_errored_status.user_random_number, + &[40; 32], + ), + callback_failed: true, + callback_return_value: Default::default(), + callback_gas_used: 100_000, + }; + History::update_request_status(&history.pool, callback_errored_status.clone()).await; + + let logs = history + .query() + .state(StateTag::Pending) + .execute() + .await + .unwrap(); + assert_eq!(logs, vec![pending_status.clone()]); + + let logs = history + .query() + .state(StateTag::Failed) + .execute() + .await + .unwrap(); + assert_eq!(logs, vec![failed_status.clone()]); + + let logs = history + .query() + .state(StateTag::Completed) + .execute() + .await + .unwrap(); + assert_eq!(logs, vec![completed_status.clone()]); + + let logs = history + .query() + .state(StateTag::CallbackErrored) + .execute() + .await + .unwrap(); + assert_eq!(logs, vec![callback_errored_status.clone()]); + } + #[tokio::test(flavor = "multi_thread")] async fn test_writer_thread() { let history = History::new_in_memory().await.unwrap();