diff --git a/components/salsa-macros/src/accumulator.rs b/components/salsa-macros/src/accumulator.rs index 926fdb547..220d0e941 100644 --- a/components/salsa-macros/src/accumulator.rs +++ b/components/salsa-macros/src/accumulator.rs @@ -40,6 +40,7 @@ impl AllowedOptions for Accumulator { const DB: bool = false; const CYCLE_FN: bool = false; const CYCLE_INITIAL: bool = false; + const CYCLE_RESULT: bool = false; const LRU: bool = false; const CONSTRUCTOR_NAME: bool = false; const ID: bool = false; diff --git a/components/salsa-macros/src/input.rs b/components/salsa-macros/src/input.rs index 0e5f92c7c..38241a959 100644 --- a/components/salsa-macros/src/input.rs +++ b/components/salsa-macros/src/input.rs @@ -55,6 +55,8 @@ impl crate::options::AllowedOptions for InputStruct { const CYCLE_INITIAL: bool = false; + const CYCLE_RESULT: bool = false; + const LRU: bool = false; const CONSTRUCTOR_NAME: bool = true; diff --git a/components/salsa-macros/src/interned.rs b/components/salsa-macros/src/interned.rs index 7067d3010..646470fce 100644 --- a/components/salsa-macros/src/interned.rs +++ b/components/salsa-macros/src/interned.rs @@ -55,6 +55,8 @@ impl crate::options::AllowedOptions for InternedStruct { const CYCLE_INITIAL: bool = false; + const CYCLE_RESULT: bool = false; + const LRU: bool = false; const CONSTRUCTOR_NAME: bool = true; diff --git a/components/salsa-macros/src/options.rs b/components/salsa-macros/src/options.rs index ded09017e..e69e70a12 100644 --- a/components/salsa-macros/src/options.rs +++ b/components/salsa-macros/src/options.rs @@ -61,6 +61,11 @@ pub(crate) struct Options { /// If this is `Some`, the value is the ``. pub cycle_initial: Option, + /// The `cycle_result = ` option is the result for non-fixpoint cycle. + /// + /// If this is `Some`, the value is the ``. + pub cycle_result: Option, + /// The `data = ` option is used to define the name of the data type for an interned /// struct. /// @@ -100,6 +105,7 @@ impl Default for Options { db_path: Default::default(), cycle_fn: Default::default(), cycle_initial: Default::default(), + cycle_result: Default::default(), data: Default::default(), constructor_name: Default::default(), phantom: Default::default(), @@ -123,6 +129,7 @@ pub(crate) trait AllowedOptions { const DB: bool; const CYCLE_FN: bool; const CYCLE_INITIAL: bool; + const CYCLE_RESULT: bool; const LRU: bool; const CONSTRUCTOR_NAME: bool; const ID: bool; @@ -274,6 +281,22 @@ impl syn::parse::Parse for Options { "`cycle_initial` option not allowed here", )); } + } else if ident == "cycle_result" { + if A::CYCLE_RESULT { + let _eq = Equals::parse(input)?; + let expr = syn::Expr::parse(input)?; + if let Some(old) = options.cycle_result.replace(expr) { + return Err(syn::Error::new( + old.span(), + "option `cycle_result` provided twice", + )); + } + } else { + return Err(syn::Error::new( + ident.span(), + "`cycle_result` option not allowed here", + )); + } } else if ident == "data" { if A::DATA { let _eq = Equals::parse(input)?; diff --git a/components/salsa-macros/src/tracked_fn.rs b/components/salsa-macros/src/tracked_fn.rs index 66f1fc82c..5ea037390 100644 --- a/components/salsa-macros/src/tracked_fn.rs +++ b/components/salsa-macros/src/tracked_fn.rs @@ -48,6 +48,8 @@ impl crate::options::AllowedOptions for TrackedFn { const CYCLE_INITIAL: bool = true; + const CYCLE_RESULT: bool = true; + const LRU: bool = true; const CONSTRUCTOR_NAME: bool = false; @@ -201,25 +203,38 @@ impl Macro { fn cycle_recovery(&self) -> syn::Result<(TokenStream, TokenStream, TokenStream)> { // TODO should we ask the user to specify a struct that impls a trait with two methods, // rather than asking for two methods separately? - match (&self.args.cycle_fn, &self.args.cycle_initial) { - (Some(cycle_fn), Some(cycle_initial)) => Ok(( + match ( + &self.args.cycle_fn, + &self.args.cycle_initial, + &self.args.cycle_result, + ) { + (Some(cycle_fn), Some(cycle_initial), None) => Ok(( quote!((#cycle_fn)), quote!((#cycle_initial)), quote!(Fixpoint), )), - (None, None) => Ok(( + (None, None, None) => Ok(( quote!((salsa::plumbing::unexpected_cycle_recovery!)), quote!((salsa::plumbing::unexpected_cycle_initial!)), quote!(Panic), )), - (Some(_), None) => Err(syn::Error::new_spanned( + (Some(_), None, None) => Err(syn::Error::new_spanned( self.args.cycle_fn.as_ref().unwrap(), "must provide `cycle_initial` along with `cycle_fn`", )), - (None, Some(_)) => Err(syn::Error::new_spanned( + (None, Some(_), None) => Err(syn::Error::new_spanned( self.args.cycle_initial.as_ref().unwrap(), "must provide `cycle_fn` along with `cycle_initial`", )), + (None, None, Some(cycle_result)) => Ok(( + quote!((salsa::plumbing::unexpected_cycle_recovery!)), + quote!((#cycle_result)), + quote!(FallbackImmediate), + )), + (_, _, Some(_)) => Err(syn::Error::new_spanned( + self.args.cycle_initial.as_ref().unwrap(), + "must provide either `cycle_result` or `cycle_fn` & `cycle_initial`, not both", + )), } } diff --git a/components/salsa-macros/src/tracked_struct.rs b/components/salsa-macros/src/tracked_struct.rs index d72b2ae78..f2a4ab9ab 100644 --- a/components/salsa-macros/src/tracked_struct.rs +++ b/components/salsa-macros/src/tracked_struct.rs @@ -50,6 +50,8 @@ impl crate::options::AllowedOptions for TrackedStruct { const CYCLE_INITIAL: bool = false; + const CYCLE_RESULT: bool = false; + const LRU: bool = false; const CONSTRUCTOR_NAME: bool = true; diff --git a/src/cycle.rs b/src/cycle.rs index e1c63d653..2d09a7dfb 100644 --- a/src/cycle.rs +++ b/src/cycle.rs @@ -83,6 +83,11 @@ pub enum CycleRecoveryStrategy { /// This choice is computed by the query's `cycle_recovery` /// function and initial value. Fixpoint, + + /// Recovers from cycles by inserting a fallback value for all + /// queries that have a fallback, and ignoring any other query + /// in the cycle (as if they were not computed). + FallbackImmediate, } /// A "cycle head" is the query at which we encounter a cycle; that is, if A -> B -> C -> A, then A @@ -91,8 +96,8 @@ pub enum CycleRecoveryStrategy { /// cycle until it converges. #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] pub struct CycleHead { - pub database_key_index: DatabaseKeyIndex, - pub iteration_count: u32, + pub(crate) database_key_index: DatabaseKeyIndex, + pub(crate) iteration_count: u32, } /// Any provisional value generated by any query in a cycle will track the cycle head(s) (can be @@ -190,3 +195,10 @@ impl From for CycleHeads { pub(crate) static EMPTY_CYCLE_HEADS: std::sync::LazyLock = std::sync::LazyLock::new(|| CycleHeads(ThinVec::new())); + +#[derive(Debug, PartialEq, Eq)] +pub enum CycleHeadKind { + Provisional, + NotProvisional, + FallbackImmediate, +} diff --git a/src/function.rs b/src/function.rs index d483323c2..d5edb03f8 100644 --- a/src/function.rs +++ b/src/function.rs @@ -5,7 +5,7 @@ use std::ptr::NonNull; pub(crate) use maybe_changed_after::VerifyResult; use crate::accumulator::accumulated_map::{AccumulatedMap, InputAccumulatedValues}; -use crate::cycle::{CycleRecoveryAction, CycleRecoveryStrategy}; +use crate::cycle::{CycleHeadKind, CycleRecoveryAction, CycleRecoveryStrategy}; use crate::function::delete::DeletedEntries; use crate::ingredient::{fmt_index, Ingredient}; use crate::key::DatabaseKeyIndex; @@ -241,14 +241,22 @@ where /// True if the input `input` contains a memo that cites itself as a cycle head. /// This indicates an intermediate value for a cycle that has not yet reached a fixed point. - fn is_provisional_cycle_head<'db>(&'db self, db: &'db dyn Database, input: Id) -> bool { + fn cycle_head_kind<'db>(&'db self, db: &'db dyn Database, input: Id) -> CycleHeadKind { let zalsa = db.zalsa(); - self.get_memo_from_table_for(zalsa, input, self.memo_ingredient_index(zalsa, input)) + let is_provisional = self + .get_memo_from_table_for(zalsa, input, self.memo_ingredient_index(zalsa, input)) .is_some_and(|memo| { memo.cycle_heads() .into_iter() .any(|head| head.database_key_index == self.database_key_index(input)) - }) + }); + if is_provisional { + CycleHeadKind::Provisional + } else if C::CYCLE_STRATEGY == CycleRecoveryStrategy::FallbackImmediate { + CycleHeadKind::FallbackImmediate + } else { + CycleHeadKind::NotProvisional + } } /// Attempts to claim `key_index`, returning `false` if a cycle occurs. diff --git a/src/function/execute.rs b/src/function/execute.rs index 557463817..7bd025457 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -1,3 +1,5 @@ +use std::sync::atomic::Ordering; + use crate::cycle::{CycleRecoveryStrategy, MAX_ITERATIONS}; use crate::function::memo::Memo; use crate::function::{Configuration, IngredientImpl}; @@ -60,100 +62,138 @@ where let mut revisions = active_query.pop(); // Did the new result we got depend on our own provisional value, in a cycle? - if C::CYCLE_STRATEGY == CycleRecoveryStrategy::Fixpoint - && revisions.cycle_heads.contains(&database_key_index) - { - let last_provisional_value = if let Some(last_provisional) = opt_last_provisional { - // We have a last provisional value from our previous time around the loop. - last_provisional.value.as_ref() - } else { - // This is our first time around the loop; a provisional value must have been - // inserted into the memo table when the cycle was hit, so let's pull our - // initial provisional value from there. - let memo = - self.get_memo_from_table_for(zalsa, id, memo_ingredient_index) - .unwrap_or_else(|| panic!("{database_key_index:#?} is a cycle head, but no provisional memo found")); - debug_assert!(memo.may_be_provisional()); - memo.value.as_ref() - }; - // SAFETY: The `LRU` does not run mid-execution, so the value remains filled - let last_provisional_value = unsafe { last_provisional_value.unwrap_unchecked() }; - tracing::debug!( - "{database_key_index:?}: execute: \ - I am a cycle head, comparing last provisional value with new value" - ); - // If the new result is equal to the last provisional result, the cycle has - // converged and we are done. - if !C::values_equal(&new_value, last_provisional_value) { - if fell_back { - // We fell back to a value last iteration, but the fallback didn't result - // in convergence. We only have bad options here: continue iterating - // (ignoring the request to fall back), or forcibly use the fallback and - // leave the cycle in an inconsistent state (we'll be using a value for - // this query that it doesn't evaluate to, given its inputs). Maybe we'll - // have to go with the latter, but for now let's panic and see if real use - // cases need non-converging fallbacks. - panic!("{database_key_index:?}: execute: fallback did not converge"); - } - // We are in a cycle that hasn't converged; ask the user's - // cycle-recovery function what to do: - match C::recover_from_cycle( - db, - &new_value, - iteration_count, - C::id_to_input(db, id), - ) { - crate::CycleRecoveryAction::Iterate => { - tracing::debug!("{database_key_index:?}: execute: iterate again"); + if revisions.cycle_heads.contains(&database_key_index) { + if C::CYCLE_STRATEGY == CycleRecoveryStrategy::FallbackImmediate { + // Ignore the computed value, leave the fallback value there. + let memo = self + .get_memo_from_table_for(zalsa, id, memo_ingredient_index) + .unwrap_or_else(|| { + unreachable!( + "{database_key_index:#?} is a `FallbackImmediate` cycle head, \ + but no memo found" + ) + }); + // We need to mark the memo as finalized so other cycle participants that have fallbacks + // will be verified (participants that don't have fallbacks will not be verified). + memo.revisions.verified_final.store(true, Ordering::Release); + // SAFETY: This is ours memo. + return unsafe { self.extend_memo_lifetime(memo) }; + } else if C::CYCLE_STRATEGY == CycleRecoveryStrategy::Fixpoint { + let last_provisional_value = + if let Some(last_provisional) = opt_last_provisional { + // We have a last provisional value from our previous time around the loop. + last_provisional.value.as_ref() + } else { + // This is our first time around the loop; a provisional value must have been + // inserted into the memo table when the cycle was hit, so let's pull our + // initial provisional value from there. + let memo = self + .get_memo_from_table_for(zalsa, id, memo_ingredient_index) + .unwrap_or_else(|| { + unreachable!( + "{database_key_index:#?} is a cycle head, \ + but no provisional memo found" + ) + }); + debug_assert!(memo.may_be_provisional()); + memo.value.as_ref() + }; + // SAFETY: The `LRU` does not run mid-execution, so the value remains filled + let last_provisional_value = + unsafe { last_provisional_value.unwrap_unchecked() }; + tracing::debug!( + "{database_key_index:?}: execute: \ + I am a cycle head, comparing last provisional value with new value" + ); + // If the new result is equal to the last provisional result, the cycle has + // converged and we are done. + if !C::values_equal(&new_value, last_provisional_value) { + if fell_back { + // We fell back to a value last iteration, but the fallback didn't result + // in convergence. We only have bad options here: continue iterating + // (ignoring the request to fall back), or forcibly use the fallback and + // leave the cycle in an inconsistent state (we'll be using a value for + // this query that it doesn't evaluate to, given its inputs). Maybe we'll + // have to go with the latter, but for now let's panic and see if real use + // cases need non-converging fallbacks. + panic!("{database_key_index:?}: execute: fallback did not converge"); } - crate::CycleRecoveryAction::Fallback(fallback_value) => { - tracing::debug!( + // We are in a cycle that hasn't converged; ask the user's + // cycle-recovery function what to do: + match C::recover_from_cycle( + db, + &new_value, + iteration_count, + C::id_to_input(db, id), + ) { + crate::CycleRecoveryAction::Iterate => { + tracing::debug!("{database_key_index:?}: execute: iterate again"); + } + crate::CycleRecoveryAction::Fallback(fallback_value) => { + tracing::debug!( "{database_key_index:?}: execute: user cycle_fn says to fall back" ); - new_value = fallback_value; - // We have to insert the fallback value for this query and then iterate - // one more time to fill in correct values for everything else in the - // cycle based on it; then we'll re-insert it as final value. - fell_back = true; + new_value = fallback_value; + // We have to insert the fallback value for this query and then iterate + // one more time to fill in correct values for everything else in the + // cycle based on it; then we'll re-insert it as final value. + fell_back = true; + } } + // `iteration_count` can't overflow as we check it against `MAX_ITERATIONS` + // which is less than `u32::MAX`. + iteration_count += 1; + if iteration_count > MAX_ITERATIONS { + panic!("{database_key_index:?}: execute: too many cycle iterations"); + } + db.salsa_event(&|| { + Event::new(EventKind::WillIterateCycle { + database_key: database_key_index, + iteration_count, + fell_back, + }) + }); + revisions + .cycle_heads + .update_iteration_count(database_key_index, iteration_count); + opt_last_provisional = Some(self.insert_memo( + zalsa, + id, + Memo::new(Some(new_value), revision_now, revisions), + memo_ingredient_index, + )); + + active_query = db + .zalsa_local() + .push_query(database_key_index, iteration_count); + + continue; } - // `iteration_count` can't overflow as we check it against `MAX_ITERATIONS` - // which is less than `u32::MAX`. - iteration_count += 1; - if iteration_count > MAX_ITERATIONS { - panic!("{database_key_index:?}: execute: too many cycle iterations"); - } - db.salsa_event(&|| { - Event::new(EventKind::WillIterateCycle { - database_key: database_key_index, - iteration_count, - fell_back, - }) - }); - revisions - .cycle_heads - .update_iteration_count(database_key_index, iteration_count); - opt_last_provisional = Some(self.insert_memo( - zalsa, - id, - Memo::new(Some(new_value), revision_now, revisions), - memo_ingredient_index, - )); - - active_query = db - .zalsa_local() - .push_query(database_key_index, iteration_count); - - continue; + tracing::debug!( + "{database_key_index:?}: execute: fixpoint iteration has a final value" + ); + revisions.cycle_heads.remove(&database_key_index); } - tracing::debug!( - "{database_key_index:?}: execute: fixpoint iteration has a final value" - ); - revisions.cycle_heads.remove(&database_key_index); } tracing::debug!("{database_key_index:?}: execute: result.revisions = {revisions:#?}"); + if !revisions.cycle_heads.is_empty() + && C::CYCLE_STRATEGY == CycleRecoveryStrategy::FallbackImmediate + { + // If we're in the middle of a cycle and we have a fallback, use it instead. + // Cycle participants that don't have a fallback will be discarded in + // `validate_provisional()`. + let cycle_heads = revisions.cycle_heads; + let active_query = db.zalsa_local().push_query(database_key_index, 0); + new_value = C::cycle_initial(db, C::id_to_input(db, id)); + revisions = active_query.pop(); + // We need to set `cycle_heads` and `verified_final` because it needs to propagate to the callers. + // When verifying this, we will see we have fallback and mark ourselves verified. + revisions.cycle_heads = cycle_heads; + *revisions.verified_final.get_mut() = false; + } + if let Some(old_memo) = opt_old_memo { // If the new value is equal to the old one, then it didn't // really change, even if some of its inputs have. So we can diff --git a/src/function/fetch.rs b/src/function/fetch.rs index 19a4845b0..20110ca94 100644 --- a/src/function/fetch.rs +++ b/src/function/fetch.rs @@ -1,3 +1,4 @@ +use crate::cycle::{CycleHeads, CycleRecoveryStrategy}; use crate::function::memo::Memo; use crate::function::{Configuration, IngredientImpl, VerifyResult}; use crate::table::sync::ClaimResult; @@ -49,7 +50,11 @@ where // any further (it could escape outside the cycle); we need to block on the other // thread completing fixpoint iteration of the cycle, and then we can re-query for // our no-longer-provisional memo. - if !memo.provisional_retry(db, zalsa, self.database_key_index(id)) { + // That is only correct for fixpoint cycles, though: `FallbackImmediate` cycles + // never have provisional entries. + if C::CYCLE_STRATEGY == CycleRecoveryStrategy::FallbackImmediate + || !memo.provisional_retry(db, zalsa, self.database_key_index(id)) + { return memo; } } @@ -100,7 +105,9 @@ where database_key_index, memo_ingredient_index, ) { - ClaimResult::Retry => return None, + ClaimResult::Retry => { + return None; + } ClaimResult::Cycle => { // check if there's a provisional value for this query // Note we don't `validate_may_be_provisional` the memo here as we want to reuse an @@ -126,37 +133,60 @@ where } } // no provisional value; create/insert/return initial provisional value - return self - .initial_value(db, database_key_index.key_index()) - .map(|initial_value| { + return match C::CYCLE_STRATEGY { + CycleRecoveryStrategy::Panic => db.zalsa_local().with_query_stack(|stack| { + panic!( + "dependency graph cycle when querying {database_key_index:#?}, \ + set cycle_fn/cycle_initial to fixpoint iterate.\n\ + Query stack:\n{:#?}", + stack, + ); + }), + CycleRecoveryStrategy::Fixpoint => { tracing::debug!( "hit cycle at {database_key_index:#?}, \ inserting and returning fixpoint initial value" ); - self.insert_memo( + let revisions = QueryRevisions::fixpoint_initial( + database_key_index, + zalsa.current_revision(), + ); + let initial_value = self + .initial_value(db, database_key_index.key_index()) + .expect( + "`CycleRecoveryStrategy::Fixpoint` \ + should have initial_value", + ); + Some(self.insert_memo( zalsa, id, - Memo::new( - Some(initial_value), - zalsa.current_revision(), - QueryRevisions::fixpoint_initial( - database_key_index, - zalsa.current_revision(), - ), - ), + Memo::new(Some(initial_value), zalsa.current_revision(), revisions), memo_ingredient_index, - ) - }) - .or_else(|| { - db.zalsa_local().with_query_stack(|stack| { - panic!( - "dependency graph cycle when querying {database_key_index:#?}, \ - set cycle_fn/cycle_initial to fixpoint iterate.\n\ - Query stack:\n{:#?}", - stack, + )) + } + CycleRecoveryStrategy::FallbackImmediate => { + tracing::debug!( + "hit a `FallbackImmediate` cycle at {database_key_index:#?}" + ); + let active_query = db.zalsa_local().push_query(database_key_index, 0); + let fallback_value = self + .initial_value(db, database_key_index.key_index()) + .expect( + "`CycleRecoveryStrategy::FallbackImmediate` \ + should have initial_value", ); - }) - }); + let mut revisions = active_query.pop(); + revisions.cycle_heads = CycleHeads::initial(database_key_index); + // We need this for `cycle_heads()` to work. We will unset this in the outer `execute()`. + *revisions.verified_final.get_mut() = false; + Some(self.insert_memo( + zalsa, + id, + Memo::new(Some(fallback_value), zalsa.current_revision(), revisions), + memo_ingredient_index, + )) + } + }; } ClaimResult::Claimed(guard) => guard, }; diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs index 3bfecd806..c6f03bcf6 100644 --- a/src/function/maybe_changed_after.rs +++ b/src/function/maybe_changed_after.rs @@ -1,7 +1,7 @@ use std::sync::atomic::Ordering; use crate::accumulator::accumulated_map::InputAccumulatedValues; -use crate::cycle::{CycleHeads, CycleRecoveryStrategy}; +use crate::cycle::{CycleHeadKind, CycleHeads, CycleRecoveryStrategy}; use crate::function::memo::Memo; use crate::function::{Configuration, IngredientImpl}; use crate::key::DatabaseKeyIndex; @@ -118,6 +118,9 @@ where stack, ); }), + CycleRecoveryStrategy::FallbackImmediate => { + return Some(VerifyResult::unchanged()); + } CycleRecoveryStrategy::Fixpoint => { return Some(VerifyResult::Unchanged( InputAccumulatedValues::Empty, @@ -263,15 +266,34 @@ where "{database_key_index:?}: validate_provisional(memo = {memo:#?})", memo = memo.tracing_debug() ); - if (&memo.revisions.cycle_heads).into_iter().any(|cycle_head| { - zalsa + for cycle_head in &memo.revisions.cycle_heads { + let kind = zalsa .lookup_ingredient(cycle_head.database_key_index.ingredient_index()) - .is_provisional_cycle_head( + .cycle_head_kind( db.as_dyn_database(), cycle_head.database_key_index.key_index(), - ) - }) { - return false; + ); + match kind { + CycleHeadKind::Provisional => return false, + CycleHeadKind::NotProvisional => { + // FIXME: We can ignore this, I just don't have a use-case for this. + if C::CYCLE_STRATEGY == CycleRecoveryStrategy::FallbackImmediate { + panic!("cannot mix `cycle_fn` and `cycle_result` in cycles") + } + } + CycleHeadKind::FallbackImmediate => match C::CYCLE_STRATEGY { + CycleRecoveryStrategy::Panic => { + // Queries without fallback are not considered when inside a cycle. + return false; + } + // FIXME: We can do the same as with `CycleRecoveryStrategy::Panic` here, I just don't have + // a use-case for this. + CycleRecoveryStrategy::Fixpoint => { + panic!("cannot mix `cycle_fn` and `cycle_result` in cycles") + } + CycleRecoveryStrategy::FallbackImmediate => {} + }, + } } // Relaxed is sufficient here because there are no other writes we need to ensure have // happened before marking this memo as verified-final. diff --git a/src/function/memo.rs b/src/function/memo.rs index eaa315cb2..bfd50005e 100644 --- a/src/function/memo.rs +++ b/src/function/memo.rs @@ -5,7 +5,7 @@ use std::fmt::{Debug, Formatter}; use std::ptr::NonNull; use std::sync::atomic::Ordering; -use crate::cycle::{CycleHeads, CycleRecoveryStrategy, EMPTY_CYCLE_HEADS}; +use crate::cycle::{CycleHeadKind, CycleHeads, CycleRecoveryStrategy, EMPTY_CYCLE_HEADS}; use crate::function::{Configuration, IngredientImpl}; use crate::key::DatabaseKeyIndex; use crate::revision::AtomicRevision; @@ -113,7 +113,9 @@ impl IngredientImpl { key: Id, ) -> Option> { match C::CYCLE_STRATEGY { - CycleRecoveryStrategy::Fixpoint => Some(C::cycle_initial(db, C::id_to_input(db, key))), + CycleRecoveryStrategy::Fixpoint | CycleRecoveryStrategy::FallbackImmediate => { + Some(C::cycle_initial(db, C::id_to_input(db, key))) + } CycleRecoveryStrategy::Panic => None, } } @@ -198,7 +200,11 @@ impl Memo { .any(|head| { let head_index = head.database_key_index; let ingredient = zalsa.lookup_ingredient(head_index.ingredient_index()); - if !ingredient.is_provisional_cycle_head(db, head_index.key_index()) { + let cycle_head_kind = ingredient.cycle_head_kind(db, head_index.key_index()); + if matches!( + cycle_head_kind, + CycleHeadKind::NotProvisional | CycleHeadKind::FallbackImmediate + ) { // This cycle is already finalized, so we don't need to wait on it; // keep looping through cycle heads. retry = true; diff --git a/src/ingredient.rs b/src/ingredient.rs index a62e660de..ee90bce21 100644 --- a/src/ingredient.rs +++ b/src/ingredient.rs @@ -2,7 +2,7 @@ use std::any::{Any, TypeId}; use std::fmt; use crate::accumulator::accumulated_map::{AccumulatedMap, InputAccumulatedValues}; -use crate::cycle::CycleRecoveryStrategy; +use crate::cycle::{CycleHeadKind, CycleRecoveryStrategy}; use crate::function::VerifyResult; use crate::plumbing::IngredientIndices; use crate::table::Table; @@ -61,9 +61,9 @@ pub trait Ingredient: Any + std::fmt::Debug + Send + Sync { /// /// In the case of nested cycles, we are not asking here whether the value is provisional due /// to the outer cycle being unresolved, only whether its own cycle remains provisional. - fn is_provisional_cycle_head<'db>(&'db self, db: &'db dyn Database, input: Id) -> bool { + fn cycle_head_kind<'db>(&'db self, db: &'db dyn Database, input: Id) -> CycleHeadKind { _ = (db, input); - false + CycleHeadKind::NotProvisional } /// Invoked when the current thread needs to wait for a result for the given `key_index`. diff --git a/src/input/input_field.rs b/src/input/input_field.rs index 17aef7044..3e54574ef 100644 --- a/src/input/input_field.rs +++ b/src/input/input_field.rs @@ -54,10 +54,6 @@ where VerifyResult::changed_if(value.stamps[self.field_index].changed_at > revision) } - fn is_provisional_cycle_head<'db>(&'db self, _db: &'db dyn Database, _input: Id) -> bool { - false - } - fn wait_for(&self, _db: &dyn Database, _key_index: Id) -> bool { true } diff --git a/src/tracked_struct.rs b/src/tracked_struct.rs index dc835daad..1b03e8cc6 100644 --- a/src/tracked_struct.rs +++ b/src/tracked_struct.rs @@ -750,10 +750,6 @@ where VerifyResult::changed_if(data.created_at > revision) } - fn is_provisional_cycle_head<'db>(&'db self, _db: &'db dyn Database, _input: Id) -> bool { - false - } - fn wait_for(&self, _db: &dyn Database, _key_index: Id) -> bool { true } diff --git a/tests/cycle_fallback_immediate.rs b/tests/cycle_fallback_immediate.rs new file mode 100644 index 000000000..b22767202 --- /dev/null +++ b/tests/cycle_fallback_immediate.rs @@ -0,0 +1,49 @@ +//! It is possible to omit the `cycle_fn`, only specifying `cycle_result` in which case +//! an immediate fallback value is used as the cycle handling opposed to doing a fixpoint resolution. + +use std::sync::atomic::{AtomicI32, Ordering}; + +#[salsa::tracked(cycle_result=cycle_result)] +fn one_o_one(db: &dyn salsa::Database) -> u32 { + let val = one_o_one(db); + val + 1 +} + +fn cycle_result(_db: &dyn salsa::Database) -> u32 { + 100 +} + +#[test_log::test] +fn simple() { + let db = salsa::DatabaseImpl::default(); + + assert_eq!(one_o_one(&db), 100); +} + +#[salsa::tracked(cycle_result=two_queries_cycle_result)] +fn two_queries1(db: &dyn salsa::Database) -> i32 { + two_queries2(db); + 0 +} + +#[salsa::tracked] +fn two_queries2(db: &dyn salsa::Database) -> i32 { + two_queries1(db); + // This is horribly against Salsa's rules, but we want to test that + // the value from within the cycle is not considered, and this is + // the only way I found. + static CALLS_COUNT: AtomicI32 = AtomicI32::new(0); + CALLS_COUNT.fetch_add(1, Ordering::Relaxed) +} + +fn two_queries_cycle_result(_db: &dyn salsa::Database) -> i32 { + 1 +} + +#[test] +fn two_queries() { + let db = salsa::DatabaseImpl::default(); + + assert_eq!(two_queries1(&db), 1); + assert_eq!(two_queries2(&db), 1); +} diff --git a/tests/cycle_result_dependencies.rs b/tests/cycle_result_dependencies.rs new file mode 100644 index 000000000..e7071a029 --- /dev/null +++ b/tests/cycle_result_dependencies.rs @@ -0,0 +1,25 @@ +use salsa::{Database, Setter}; + +#[salsa::input] +struct Input { + value: i32, +} + +#[salsa::tracked(cycle_result=cycle_result)] +fn has_cycle(db: &dyn Database, input: Input) -> i32 { + has_cycle(db, input) +} + +fn cycle_result(db: &dyn Database, input: Input) -> i32 { + input.value(db) +} + +#[test] +fn cycle_result_dependencies_are_recorded() { + let mut db = salsa::DatabaseImpl::default(); + let input = Input::new(&db, 123); + assert_eq!(has_cycle(&db, input), 123); + + input.set_value(&mut db).to(456); + assert_eq!(has_cycle(&db, input), 456); +} diff --git a/tests/parallel/cycle_a_t1_b_t2_fallback.rs b/tests/parallel/cycle_a_t1_b_t2_fallback.rs new file mode 100644 index 000000000..faa4c39f4 --- /dev/null +++ b/tests/parallel/cycle_a_t1_b_t2_fallback.rs @@ -0,0 +1,67 @@ +//! Test a specific cycle scenario: +//! +//! ```text +//! Thread T1 Thread T2 +//! --------- --------- +//! | | +//! v | +//! query_a() | +//! ^ | v +//! | +------------> query_b() +//! | | +//! +--------------------+ +//! ``` +use crate::setup::{Knobs, KnobsDatabase}; + +const FALLBACK_A: u32 = 0b01; +const FALLBACK_B: u32 = 0b10; +const OFFSET_A: u32 = 0b0100; +const OFFSET_B: u32 = 0b1000; + +// Signal 1: T1 has entered `query_a` +// Signal 2: T2 has entered `query_b` + +#[salsa::tracked(cycle_result=cycle_result_a)] +fn query_a(db: &dyn KnobsDatabase) -> u32 { + db.signal(1); + + // Wait for Thread T2 to enter `query_b` before we continue. + db.wait_for(2); + + query_b(db) | OFFSET_A +} + +#[salsa::tracked(cycle_result=cycle_result_b)] +fn query_b(db: &dyn KnobsDatabase) -> u32 { + // Wait for Thread T1 to enter `query_a` before we continue. + db.wait_for(1); + + db.signal(2); + + query_a(db) | OFFSET_B +} + +fn cycle_result_a(_db: &dyn KnobsDatabase) -> u32 { + FALLBACK_A +} + +fn cycle_result_b(_db: &dyn KnobsDatabase) -> u32 { + FALLBACK_B +} + +#[test_log::test] +fn the_test() { + std::thread::scope(|scope| { + let db_t1 = Knobs::default(); + let db_t2 = db_t1.clone(); + + let t1 = scope.spawn(move || query_a(&db_t1)); + let t2 = scope.spawn(move || query_b(&db_t2)); + + let (r_t1, r_t2) = (t1.join(), t2.join()); + + assert_eq!((r_t1?, r_t2?), (FALLBACK_A, FALLBACK_B)); + Ok(()) + }) + .unwrap_or_else(|e| std::panic::resume_unwind(e)); +} diff --git a/tests/parallel/main.rs b/tests/parallel/main.rs index 613e43e1d..7507a71cf 100644 --- a/tests/parallel/main.rs +++ b/tests/parallel/main.rs @@ -1,6 +1,7 @@ mod setup; mod cycle_a_t1_b_t2; +mod cycle_a_t1_b_t2_fallback; mod cycle_ab_peeping_c; mod cycle_nested_three_threads; mod cycle_panic;