From f61f973ad5a67fc59ed42f278f71e696b1aa2653 Mon Sep 17 00:00:00 2001 From: Chayim Refael Friedman Date: Fri, 18 Apr 2025 04:37:12 +0300 Subject: [PATCH] Add a third cycle mode, equivalent to old Salsa cycle behavior That is, directly set a value for all queries that have fallbacks, and ignore all other queries in the cycle. Unlike old Salsa, we still need all cycle heads to be marked, and we still execute the queries to completion, but we throw their result. --- components/salsa-macros/src/accumulator.rs | 1 + components/salsa-macros/src/input.rs | 2 + components/salsa-macros/src/interned.rs | 2 + components/salsa-macros/src/options.rs | 23 ++ components/salsa-macros/src/tracked_fn.rs | 25 ++- components/salsa-macros/src/tracked_struct.rs | 2 + src/cycle.rs | 16 +- src/function.rs | 16 +- src/function/execute.rs | 208 +++++++++++------- src/function/fetch.rs | 80 ++++--- src/function/maybe_changed_after.rs | 36 ++- src/function/memo.rs | 12 +- src/ingredient.rs | 6 +- src/input/input_field.rs | 4 - src/tracked_struct.rs | 4 - tests/cycle_fallback_immediate.rs | 49 +++++ tests/cycle_result_dependencies.rs | 25 +++ tests/parallel/cycle_a_t1_b_t2_fallback.rs | 67 ++++++ tests/parallel/main.rs | 1 + 19 files changed, 438 insertions(+), 141 deletions(-) create mode 100644 tests/cycle_fallback_immediate.rs create mode 100644 tests/cycle_result_dependencies.rs create mode 100644 tests/parallel/cycle_a_t1_b_t2_fallback.rs diff --git a/components/salsa-macros/src/accumulator.rs b/components/salsa-macros/src/accumulator.rs index 926fdb547..220d0e941 100644 --- a/components/salsa-macros/src/accumulator.rs +++ b/components/salsa-macros/src/accumulator.rs @@ -40,6 +40,7 @@ impl AllowedOptions for Accumulator { const DB: bool = false; const CYCLE_FN: bool = false; const CYCLE_INITIAL: bool = false; + const CYCLE_RESULT: bool = false; const LRU: bool = false; const CONSTRUCTOR_NAME: bool = false; const ID: bool = false; diff --git a/components/salsa-macros/src/input.rs b/components/salsa-macros/src/input.rs index 0e5f92c7c..38241a959 100644 --- a/components/salsa-macros/src/input.rs +++ b/components/salsa-macros/src/input.rs @@ -55,6 +55,8 @@ impl crate::options::AllowedOptions for InputStruct { const CYCLE_INITIAL: bool = false; + const CYCLE_RESULT: bool = false; + const LRU: bool = false; const CONSTRUCTOR_NAME: bool = true; diff --git a/components/salsa-macros/src/interned.rs b/components/salsa-macros/src/interned.rs index 7067d3010..646470fce 100644 --- a/components/salsa-macros/src/interned.rs +++ b/components/salsa-macros/src/interned.rs @@ -55,6 +55,8 @@ impl crate::options::AllowedOptions for InternedStruct { const CYCLE_INITIAL: bool = false; + const CYCLE_RESULT: bool = false; + const LRU: bool = false; const CONSTRUCTOR_NAME: bool = true; diff --git a/components/salsa-macros/src/options.rs b/components/salsa-macros/src/options.rs index ded09017e..e69e70a12 100644 --- a/components/salsa-macros/src/options.rs +++ b/components/salsa-macros/src/options.rs @@ -61,6 +61,11 @@ pub(crate) struct Options { /// If this is `Some`, the value is the ``. pub cycle_initial: Option, + /// The `cycle_result = ` option is the result for non-fixpoint cycle. + /// + /// If this is `Some`, the value is the ``. + pub cycle_result: Option, + /// The `data = ` option is used to define the name of the data type for an interned /// struct. /// @@ -100,6 +105,7 @@ impl Default for Options { db_path: Default::default(), cycle_fn: Default::default(), cycle_initial: Default::default(), + cycle_result: Default::default(), data: Default::default(), constructor_name: Default::default(), phantom: Default::default(), @@ -123,6 +129,7 @@ pub(crate) trait AllowedOptions { const DB: bool; const CYCLE_FN: bool; const CYCLE_INITIAL: bool; + const CYCLE_RESULT: bool; const LRU: bool; const CONSTRUCTOR_NAME: bool; const ID: bool; @@ -274,6 +281,22 @@ impl syn::parse::Parse for Options { "`cycle_initial` option not allowed here", )); } + } else if ident == "cycle_result" { + if A::CYCLE_RESULT { + let _eq = Equals::parse(input)?; + let expr = syn::Expr::parse(input)?; + if let Some(old) = options.cycle_result.replace(expr) { + return Err(syn::Error::new( + old.span(), + "option `cycle_result` provided twice", + )); + } + } else { + return Err(syn::Error::new( + ident.span(), + "`cycle_result` option not allowed here", + )); + } } else if ident == "data" { if A::DATA { let _eq = Equals::parse(input)?; diff --git a/components/salsa-macros/src/tracked_fn.rs b/components/salsa-macros/src/tracked_fn.rs index 66f1fc82c..5ea037390 100644 --- a/components/salsa-macros/src/tracked_fn.rs +++ b/components/salsa-macros/src/tracked_fn.rs @@ -48,6 +48,8 @@ impl crate::options::AllowedOptions for TrackedFn { const CYCLE_INITIAL: bool = true; + const CYCLE_RESULT: bool = true; + const LRU: bool = true; const CONSTRUCTOR_NAME: bool = false; @@ -201,25 +203,38 @@ impl Macro { fn cycle_recovery(&self) -> syn::Result<(TokenStream, TokenStream, TokenStream)> { // TODO should we ask the user to specify a struct that impls a trait with two methods, // rather than asking for two methods separately? - match (&self.args.cycle_fn, &self.args.cycle_initial) { - (Some(cycle_fn), Some(cycle_initial)) => Ok(( + match ( + &self.args.cycle_fn, + &self.args.cycle_initial, + &self.args.cycle_result, + ) { + (Some(cycle_fn), Some(cycle_initial), None) => Ok(( quote!((#cycle_fn)), quote!((#cycle_initial)), quote!(Fixpoint), )), - (None, None) => Ok(( + (None, None, None) => Ok(( quote!((salsa::plumbing::unexpected_cycle_recovery!)), quote!((salsa::plumbing::unexpected_cycle_initial!)), quote!(Panic), )), - (Some(_), None) => Err(syn::Error::new_spanned( + (Some(_), None, None) => Err(syn::Error::new_spanned( self.args.cycle_fn.as_ref().unwrap(), "must provide `cycle_initial` along with `cycle_fn`", )), - (None, Some(_)) => Err(syn::Error::new_spanned( + (None, Some(_), None) => Err(syn::Error::new_spanned( self.args.cycle_initial.as_ref().unwrap(), "must provide `cycle_fn` along with `cycle_initial`", )), + (None, None, Some(cycle_result)) => Ok(( + quote!((salsa::plumbing::unexpected_cycle_recovery!)), + quote!((#cycle_result)), + quote!(FallbackImmediate), + )), + (_, _, Some(_)) => Err(syn::Error::new_spanned( + self.args.cycle_initial.as_ref().unwrap(), + "must provide either `cycle_result` or `cycle_fn` & `cycle_initial`, not both", + )), } } diff --git a/components/salsa-macros/src/tracked_struct.rs b/components/salsa-macros/src/tracked_struct.rs index d72b2ae78..f2a4ab9ab 100644 --- a/components/salsa-macros/src/tracked_struct.rs +++ b/components/salsa-macros/src/tracked_struct.rs @@ -50,6 +50,8 @@ impl crate::options::AllowedOptions for TrackedStruct { const CYCLE_INITIAL: bool = false; + const CYCLE_RESULT: bool = false; + const LRU: bool = false; const CONSTRUCTOR_NAME: bool = true; diff --git a/src/cycle.rs b/src/cycle.rs index e1c63d653..2d09a7dfb 100644 --- a/src/cycle.rs +++ b/src/cycle.rs @@ -83,6 +83,11 @@ pub enum CycleRecoveryStrategy { /// This choice is computed by the query's `cycle_recovery` /// function and initial value. Fixpoint, + + /// Recovers from cycles by inserting a fallback value for all + /// queries that have a fallback, and ignoring any other query + /// in the cycle (as if they were not computed). + FallbackImmediate, } /// A "cycle head" is the query at which we encounter a cycle; that is, if A -> B -> C -> A, then A @@ -91,8 +96,8 @@ pub enum CycleRecoveryStrategy { /// cycle until it converges. #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] pub struct CycleHead { - pub database_key_index: DatabaseKeyIndex, - pub iteration_count: u32, + pub(crate) database_key_index: DatabaseKeyIndex, + pub(crate) iteration_count: u32, } /// Any provisional value generated by any query in a cycle will track the cycle head(s) (can be @@ -190,3 +195,10 @@ impl From for CycleHeads { pub(crate) static EMPTY_CYCLE_HEADS: std::sync::LazyLock = std::sync::LazyLock::new(|| CycleHeads(ThinVec::new())); + +#[derive(Debug, PartialEq, Eq)] +pub enum CycleHeadKind { + Provisional, + NotProvisional, + FallbackImmediate, +} diff --git a/src/function.rs b/src/function.rs index d483323c2..d5edb03f8 100644 --- a/src/function.rs +++ b/src/function.rs @@ -5,7 +5,7 @@ use std::ptr::NonNull; pub(crate) use maybe_changed_after::VerifyResult; use crate::accumulator::accumulated_map::{AccumulatedMap, InputAccumulatedValues}; -use crate::cycle::{CycleRecoveryAction, CycleRecoveryStrategy}; +use crate::cycle::{CycleHeadKind, CycleRecoveryAction, CycleRecoveryStrategy}; use crate::function::delete::DeletedEntries; use crate::ingredient::{fmt_index, Ingredient}; use crate::key::DatabaseKeyIndex; @@ -241,14 +241,22 @@ where /// True if the input `input` contains a memo that cites itself as a cycle head. /// This indicates an intermediate value for a cycle that has not yet reached a fixed point. - fn is_provisional_cycle_head<'db>(&'db self, db: &'db dyn Database, input: Id) -> bool { + fn cycle_head_kind<'db>(&'db self, db: &'db dyn Database, input: Id) -> CycleHeadKind { let zalsa = db.zalsa(); - self.get_memo_from_table_for(zalsa, input, self.memo_ingredient_index(zalsa, input)) + let is_provisional = self + .get_memo_from_table_for(zalsa, input, self.memo_ingredient_index(zalsa, input)) .is_some_and(|memo| { memo.cycle_heads() .into_iter() .any(|head| head.database_key_index == self.database_key_index(input)) - }) + }); + if is_provisional { + CycleHeadKind::Provisional + } else if C::CYCLE_STRATEGY == CycleRecoveryStrategy::FallbackImmediate { + CycleHeadKind::FallbackImmediate + } else { + CycleHeadKind::NotProvisional + } } /// Attempts to claim `key_index`, returning `false` if a cycle occurs. diff --git a/src/function/execute.rs b/src/function/execute.rs index 557463817..7bd025457 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -1,3 +1,5 @@ +use std::sync::atomic::Ordering; + use crate::cycle::{CycleRecoveryStrategy, MAX_ITERATIONS}; use crate::function::memo::Memo; use crate::function::{Configuration, IngredientImpl}; @@ -60,100 +62,138 @@ where let mut revisions = active_query.pop(); // Did the new result we got depend on our own provisional value, in a cycle? - if C::CYCLE_STRATEGY == CycleRecoveryStrategy::Fixpoint - && revisions.cycle_heads.contains(&database_key_index) - { - let last_provisional_value = if let Some(last_provisional) = opt_last_provisional { - // We have a last provisional value from our previous time around the loop. - last_provisional.value.as_ref() - } else { - // This is our first time around the loop; a provisional value must have been - // inserted into the memo table when the cycle was hit, so let's pull our - // initial provisional value from there. - let memo = - self.get_memo_from_table_for(zalsa, id, memo_ingredient_index) - .unwrap_or_else(|| panic!("{database_key_index:#?} is a cycle head, but no provisional memo found")); - debug_assert!(memo.may_be_provisional()); - memo.value.as_ref() - }; - // SAFETY: The `LRU` does not run mid-execution, so the value remains filled - let last_provisional_value = unsafe { last_provisional_value.unwrap_unchecked() }; - tracing::debug!( - "{database_key_index:?}: execute: \ - I am a cycle head, comparing last provisional value with new value" - ); - // If the new result is equal to the last provisional result, the cycle has - // converged and we are done. - if !C::values_equal(&new_value, last_provisional_value) { - if fell_back { - // We fell back to a value last iteration, but the fallback didn't result - // in convergence. We only have bad options here: continue iterating - // (ignoring the request to fall back), or forcibly use the fallback and - // leave the cycle in an inconsistent state (we'll be using a value for - // this query that it doesn't evaluate to, given its inputs). Maybe we'll - // have to go with the latter, but for now let's panic and see if real use - // cases need non-converging fallbacks. - panic!("{database_key_index:?}: execute: fallback did not converge"); - } - // We are in a cycle that hasn't converged; ask the user's - // cycle-recovery function what to do: - match C::recover_from_cycle( - db, - &new_value, - iteration_count, - C::id_to_input(db, id), - ) { - crate::CycleRecoveryAction::Iterate => { - tracing::debug!("{database_key_index:?}: execute: iterate again"); + if revisions.cycle_heads.contains(&database_key_index) { + if C::CYCLE_STRATEGY == CycleRecoveryStrategy::FallbackImmediate { + // Ignore the computed value, leave the fallback value there. + let memo = self + .get_memo_from_table_for(zalsa, id, memo_ingredient_index) + .unwrap_or_else(|| { + unreachable!( + "{database_key_index:#?} is a `FallbackImmediate` cycle head, \ + but no memo found" + ) + }); + // We need to mark the memo as finalized so other cycle participants that have fallbacks + // will be verified (participants that don't have fallbacks will not be verified). + memo.revisions.verified_final.store(true, Ordering::Release); + // SAFETY: This is ours memo. + return unsafe { self.extend_memo_lifetime(memo) }; + } else if C::CYCLE_STRATEGY == CycleRecoveryStrategy::Fixpoint { + let last_provisional_value = + if let Some(last_provisional) = opt_last_provisional { + // We have a last provisional value from our previous time around the loop. + last_provisional.value.as_ref() + } else { + // This is our first time around the loop; a provisional value must have been + // inserted into the memo table when the cycle was hit, so let's pull our + // initial provisional value from there. + let memo = self + .get_memo_from_table_for(zalsa, id, memo_ingredient_index) + .unwrap_or_else(|| { + unreachable!( + "{database_key_index:#?} is a cycle head, \ + but no provisional memo found" + ) + }); + debug_assert!(memo.may_be_provisional()); + memo.value.as_ref() + }; + // SAFETY: The `LRU` does not run mid-execution, so the value remains filled + let last_provisional_value = + unsafe { last_provisional_value.unwrap_unchecked() }; + tracing::debug!( + "{database_key_index:?}: execute: \ + I am a cycle head, comparing last provisional value with new value" + ); + // If the new result is equal to the last provisional result, the cycle has + // converged and we are done. + if !C::values_equal(&new_value, last_provisional_value) { + if fell_back { + // We fell back to a value last iteration, but the fallback didn't result + // in convergence. We only have bad options here: continue iterating + // (ignoring the request to fall back), or forcibly use the fallback and + // leave the cycle in an inconsistent state (we'll be using a value for + // this query that it doesn't evaluate to, given its inputs). Maybe we'll + // have to go with the latter, but for now let's panic and see if real use + // cases need non-converging fallbacks. + panic!("{database_key_index:?}: execute: fallback did not converge"); } - crate::CycleRecoveryAction::Fallback(fallback_value) => { - tracing::debug!( + // We are in a cycle that hasn't converged; ask the user's + // cycle-recovery function what to do: + match C::recover_from_cycle( + db, + &new_value, + iteration_count, + C::id_to_input(db, id), + ) { + crate::CycleRecoveryAction::Iterate => { + tracing::debug!("{database_key_index:?}: execute: iterate again"); + } + crate::CycleRecoveryAction::Fallback(fallback_value) => { + tracing::debug!( "{database_key_index:?}: execute: user cycle_fn says to fall back" ); - new_value = fallback_value; - // We have to insert the fallback value for this query and then iterate - // one more time to fill in correct values for everything else in the - // cycle based on it; then we'll re-insert it as final value. - fell_back = true; + new_value = fallback_value; + // We have to insert the fallback value for this query and then iterate + // one more time to fill in correct values for everything else in the + // cycle based on it; then we'll re-insert it as final value. + fell_back = true; + } } + // `iteration_count` can't overflow as we check it against `MAX_ITERATIONS` + // which is less than `u32::MAX`. + iteration_count += 1; + if iteration_count > MAX_ITERATIONS { + panic!("{database_key_index:?}: execute: too many cycle iterations"); + } + db.salsa_event(&|| { + Event::new(EventKind::WillIterateCycle { + database_key: database_key_index, + iteration_count, + fell_back, + }) + }); + revisions + .cycle_heads + .update_iteration_count(database_key_index, iteration_count); + opt_last_provisional = Some(self.insert_memo( + zalsa, + id, + Memo::new(Some(new_value), revision_now, revisions), + memo_ingredient_index, + )); + + active_query = db + .zalsa_local() + .push_query(database_key_index, iteration_count); + + continue; } - // `iteration_count` can't overflow as we check it against `MAX_ITERATIONS` - // which is less than `u32::MAX`. - iteration_count += 1; - if iteration_count > MAX_ITERATIONS { - panic!("{database_key_index:?}: execute: too many cycle iterations"); - } - db.salsa_event(&|| { - Event::new(EventKind::WillIterateCycle { - database_key: database_key_index, - iteration_count, - fell_back, - }) - }); - revisions - .cycle_heads - .update_iteration_count(database_key_index, iteration_count); - opt_last_provisional = Some(self.insert_memo( - zalsa, - id, - Memo::new(Some(new_value), revision_now, revisions), - memo_ingredient_index, - )); - - active_query = db - .zalsa_local() - .push_query(database_key_index, iteration_count); - - continue; + tracing::debug!( + "{database_key_index:?}: execute: fixpoint iteration has a final value" + ); + revisions.cycle_heads.remove(&database_key_index); } - tracing::debug!( - "{database_key_index:?}: execute: fixpoint iteration has a final value" - ); - revisions.cycle_heads.remove(&database_key_index); } tracing::debug!("{database_key_index:?}: execute: result.revisions = {revisions:#?}"); + if !revisions.cycle_heads.is_empty() + && C::CYCLE_STRATEGY == CycleRecoveryStrategy::FallbackImmediate + { + // If we're in the middle of a cycle and we have a fallback, use it instead. + // Cycle participants that don't have a fallback will be discarded in + // `validate_provisional()`. + let cycle_heads = revisions.cycle_heads; + let active_query = db.zalsa_local().push_query(database_key_index, 0); + new_value = C::cycle_initial(db, C::id_to_input(db, id)); + revisions = active_query.pop(); + // We need to set `cycle_heads` and `verified_final` because it needs to propagate to the callers. + // When verifying this, we will see we have fallback and mark ourselves verified. + revisions.cycle_heads = cycle_heads; + *revisions.verified_final.get_mut() = false; + } + if let Some(old_memo) = opt_old_memo { // If the new value is equal to the old one, then it didn't // really change, even if some of its inputs have. So we can diff --git a/src/function/fetch.rs b/src/function/fetch.rs index 19a4845b0..20110ca94 100644 --- a/src/function/fetch.rs +++ b/src/function/fetch.rs @@ -1,3 +1,4 @@ +use crate::cycle::{CycleHeads, CycleRecoveryStrategy}; use crate::function::memo::Memo; use crate::function::{Configuration, IngredientImpl, VerifyResult}; use crate::table::sync::ClaimResult; @@ -49,7 +50,11 @@ where // any further (it could escape outside the cycle); we need to block on the other // thread completing fixpoint iteration of the cycle, and then we can re-query for // our no-longer-provisional memo. - if !memo.provisional_retry(db, zalsa, self.database_key_index(id)) { + // That is only correct for fixpoint cycles, though: `FallbackImmediate` cycles + // never have provisional entries. + if C::CYCLE_STRATEGY == CycleRecoveryStrategy::FallbackImmediate + || !memo.provisional_retry(db, zalsa, self.database_key_index(id)) + { return memo; } } @@ -100,7 +105,9 @@ where database_key_index, memo_ingredient_index, ) { - ClaimResult::Retry => return None, + ClaimResult::Retry => { + return None; + } ClaimResult::Cycle => { // check if there's a provisional value for this query // Note we don't `validate_may_be_provisional` the memo here as we want to reuse an @@ -126,37 +133,60 @@ where } } // no provisional value; create/insert/return initial provisional value - return self - .initial_value(db, database_key_index.key_index()) - .map(|initial_value| { + return match C::CYCLE_STRATEGY { + CycleRecoveryStrategy::Panic => db.zalsa_local().with_query_stack(|stack| { + panic!( + "dependency graph cycle when querying {database_key_index:#?}, \ + set cycle_fn/cycle_initial to fixpoint iterate.\n\ + Query stack:\n{:#?}", + stack, + ); + }), + CycleRecoveryStrategy::Fixpoint => { tracing::debug!( "hit cycle at {database_key_index:#?}, \ inserting and returning fixpoint initial value" ); - self.insert_memo( + let revisions = QueryRevisions::fixpoint_initial( + database_key_index, + zalsa.current_revision(), + ); + let initial_value = self + .initial_value(db, database_key_index.key_index()) + .expect( + "`CycleRecoveryStrategy::Fixpoint` \ + should have initial_value", + ); + Some(self.insert_memo( zalsa, id, - Memo::new( - Some(initial_value), - zalsa.current_revision(), - QueryRevisions::fixpoint_initial( - database_key_index, - zalsa.current_revision(), - ), - ), + Memo::new(Some(initial_value), zalsa.current_revision(), revisions), memo_ingredient_index, - ) - }) - .or_else(|| { - db.zalsa_local().with_query_stack(|stack| { - panic!( - "dependency graph cycle when querying {database_key_index:#?}, \ - set cycle_fn/cycle_initial to fixpoint iterate.\n\ - Query stack:\n{:#?}", - stack, + )) + } + CycleRecoveryStrategy::FallbackImmediate => { + tracing::debug!( + "hit a `FallbackImmediate` cycle at {database_key_index:#?}" + ); + let active_query = db.zalsa_local().push_query(database_key_index, 0); + let fallback_value = self + .initial_value(db, database_key_index.key_index()) + .expect( + "`CycleRecoveryStrategy::FallbackImmediate` \ + should have initial_value", ); - }) - }); + let mut revisions = active_query.pop(); + revisions.cycle_heads = CycleHeads::initial(database_key_index); + // We need this for `cycle_heads()` to work. We will unset this in the outer `execute()`. + *revisions.verified_final.get_mut() = false; + Some(self.insert_memo( + zalsa, + id, + Memo::new(Some(fallback_value), zalsa.current_revision(), revisions), + memo_ingredient_index, + )) + } + }; } ClaimResult::Claimed(guard) => guard, }; diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs index 3bfecd806..c6f03bcf6 100644 --- a/src/function/maybe_changed_after.rs +++ b/src/function/maybe_changed_after.rs @@ -1,7 +1,7 @@ use std::sync::atomic::Ordering; use crate::accumulator::accumulated_map::InputAccumulatedValues; -use crate::cycle::{CycleHeads, CycleRecoveryStrategy}; +use crate::cycle::{CycleHeadKind, CycleHeads, CycleRecoveryStrategy}; use crate::function::memo::Memo; use crate::function::{Configuration, IngredientImpl}; use crate::key::DatabaseKeyIndex; @@ -118,6 +118,9 @@ where stack, ); }), + CycleRecoveryStrategy::FallbackImmediate => { + return Some(VerifyResult::unchanged()); + } CycleRecoveryStrategy::Fixpoint => { return Some(VerifyResult::Unchanged( InputAccumulatedValues::Empty, @@ -263,15 +266,34 @@ where "{database_key_index:?}: validate_provisional(memo = {memo:#?})", memo = memo.tracing_debug() ); - if (&memo.revisions.cycle_heads).into_iter().any(|cycle_head| { - zalsa + for cycle_head in &memo.revisions.cycle_heads { + let kind = zalsa .lookup_ingredient(cycle_head.database_key_index.ingredient_index()) - .is_provisional_cycle_head( + .cycle_head_kind( db.as_dyn_database(), cycle_head.database_key_index.key_index(), - ) - }) { - return false; + ); + match kind { + CycleHeadKind::Provisional => return false, + CycleHeadKind::NotProvisional => { + // FIXME: We can ignore this, I just don't have a use-case for this. + if C::CYCLE_STRATEGY == CycleRecoveryStrategy::FallbackImmediate { + panic!("cannot mix `cycle_fn` and `cycle_result` in cycles") + } + } + CycleHeadKind::FallbackImmediate => match C::CYCLE_STRATEGY { + CycleRecoveryStrategy::Panic => { + // Queries without fallback are not considered when inside a cycle. + return false; + } + // FIXME: We can do the same as with `CycleRecoveryStrategy::Panic` here, I just don't have + // a use-case for this. + CycleRecoveryStrategy::Fixpoint => { + panic!("cannot mix `cycle_fn` and `cycle_result` in cycles") + } + CycleRecoveryStrategy::FallbackImmediate => {} + }, + } } // Relaxed is sufficient here because there are no other writes we need to ensure have // happened before marking this memo as verified-final. diff --git a/src/function/memo.rs b/src/function/memo.rs index eaa315cb2..bfd50005e 100644 --- a/src/function/memo.rs +++ b/src/function/memo.rs @@ -5,7 +5,7 @@ use std::fmt::{Debug, Formatter}; use std::ptr::NonNull; use std::sync::atomic::Ordering; -use crate::cycle::{CycleHeads, CycleRecoveryStrategy, EMPTY_CYCLE_HEADS}; +use crate::cycle::{CycleHeadKind, CycleHeads, CycleRecoveryStrategy, EMPTY_CYCLE_HEADS}; use crate::function::{Configuration, IngredientImpl}; use crate::key::DatabaseKeyIndex; use crate::revision::AtomicRevision; @@ -113,7 +113,9 @@ impl IngredientImpl { key: Id, ) -> Option> { match C::CYCLE_STRATEGY { - CycleRecoveryStrategy::Fixpoint => Some(C::cycle_initial(db, C::id_to_input(db, key))), + CycleRecoveryStrategy::Fixpoint | CycleRecoveryStrategy::FallbackImmediate => { + Some(C::cycle_initial(db, C::id_to_input(db, key))) + } CycleRecoveryStrategy::Panic => None, } } @@ -198,7 +200,11 @@ impl Memo { .any(|head| { let head_index = head.database_key_index; let ingredient = zalsa.lookup_ingredient(head_index.ingredient_index()); - if !ingredient.is_provisional_cycle_head(db, head_index.key_index()) { + let cycle_head_kind = ingredient.cycle_head_kind(db, head_index.key_index()); + if matches!( + cycle_head_kind, + CycleHeadKind::NotProvisional | CycleHeadKind::FallbackImmediate + ) { // This cycle is already finalized, so we don't need to wait on it; // keep looping through cycle heads. retry = true; diff --git a/src/ingredient.rs b/src/ingredient.rs index a62e660de..ee90bce21 100644 --- a/src/ingredient.rs +++ b/src/ingredient.rs @@ -2,7 +2,7 @@ use std::any::{Any, TypeId}; use std::fmt; use crate::accumulator::accumulated_map::{AccumulatedMap, InputAccumulatedValues}; -use crate::cycle::CycleRecoveryStrategy; +use crate::cycle::{CycleHeadKind, CycleRecoveryStrategy}; use crate::function::VerifyResult; use crate::plumbing::IngredientIndices; use crate::table::Table; @@ -61,9 +61,9 @@ pub trait Ingredient: Any + std::fmt::Debug + Send + Sync { /// /// In the case of nested cycles, we are not asking here whether the value is provisional due /// to the outer cycle being unresolved, only whether its own cycle remains provisional. - fn is_provisional_cycle_head<'db>(&'db self, db: &'db dyn Database, input: Id) -> bool { + fn cycle_head_kind<'db>(&'db self, db: &'db dyn Database, input: Id) -> CycleHeadKind { _ = (db, input); - false + CycleHeadKind::NotProvisional } /// Invoked when the current thread needs to wait for a result for the given `key_index`. diff --git a/src/input/input_field.rs b/src/input/input_field.rs index 17aef7044..3e54574ef 100644 --- a/src/input/input_field.rs +++ b/src/input/input_field.rs @@ -54,10 +54,6 @@ where VerifyResult::changed_if(value.stamps[self.field_index].changed_at > revision) } - fn is_provisional_cycle_head<'db>(&'db self, _db: &'db dyn Database, _input: Id) -> bool { - false - } - fn wait_for(&self, _db: &dyn Database, _key_index: Id) -> bool { true } diff --git a/src/tracked_struct.rs b/src/tracked_struct.rs index dc835daad..1b03e8cc6 100644 --- a/src/tracked_struct.rs +++ b/src/tracked_struct.rs @@ -750,10 +750,6 @@ where VerifyResult::changed_if(data.created_at > revision) } - fn is_provisional_cycle_head<'db>(&'db self, _db: &'db dyn Database, _input: Id) -> bool { - false - } - fn wait_for(&self, _db: &dyn Database, _key_index: Id) -> bool { true } diff --git a/tests/cycle_fallback_immediate.rs b/tests/cycle_fallback_immediate.rs new file mode 100644 index 000000000..b22767202 --- /dev/null +++ b/tests/cycle_fallback_immediate.rs @@ -0,0 +1,49 @@ +//! It is possible to omit the `cycle_fn`, only specifying `cycle_result` in which case +//! an immediate fallback value is used as the cycle handling opposed to doing a fixpoint resolution. + +use std::sync::atomic::{AtomicI32, Ordering}; + +#[salsa::tracked(cycle_result=cycle_result)] +fn one_o_one(db: &dyn salsa::Database) -> u32 { + let val = one_o_one(db); + val + 1 +} + +fn cycle_result(_db: &dyn salsa::Database) -> u32 { + 100 +} + +#[test_log::test] +fn simple() { + let db = salsa::DatabaseImpl::default(); + + assert_eq!(one_o_one(&db), 100); +} + +#[salsa::tracked(cycle_result=two_queries_cycle_result)] +fn two_queries1(db: &dyn salsa::Database) -> i32 { + two_queries2(db); + 0 +} + +#[salsa::tracked] +fn two_queries2(db: &dyn salsa::Database) -> i32 { + two_queries1(db); + // This is horribly against Salsa's rules, but we want to test that + // the value from within the cycle is not considered, and this is + // the only way I found. + static CALLS_COUNT: AtomicI32 = AtomicI32::new(0); + CALLS_COUNT.fetch_add(1, Ordering::Relaxed) +} + +fn two_queries_cycle_result(_db: &dyn salsa::Database) -> i32 { + 1 +} + +#[test] +fn two_queries() { + let db = salsa::DatabaseImpl::default(); + + assert_eq!(two_queries1(&db), 1); + assert_eq!(two_queries2(&db), 1); +} diff --git a/tests/cycle_result_dependencies.rs b/tests/cycle_result_dependencies.rs new file mode 100644 index 000000000..e7071a029 --- /dev/null +++ b/tests/cycle_result_dependencies.rs @@ -0,0 +1,25 @@ +use salsa::{Database, Setter}; + +#[salsa::input] +struct Input { + value: i32, +} + +#[salsa::tracked(cycle_result=cycle_result)] +fn has_cycle(db: &dyn Database, input: Input) -> i32 { + has_cycle(db, input) +} + +fn cycle_result(db: &dyn Database, input: Input) -> i32 { + input.value(db) +} + +#[test] +fn cycle_result_dependencies_are_recorded() { + let mut db = salsa::DatabaseImpl::default(); + let input = Input::new(&db, 123); + assert_eq!(has_cycle(&db, input), 123); + + input.set_value(&mut db).to(456); + assert_eq!(has_cycle(&db, input), 456); +} diff --git a/tests/parallel/cycle_a_t1_b_t2_fallback.rs b/tests/parallel/cycle_a_t1_b_t2_fallback.rs new file mode 100644 index 000000000..faa4c39f4 --- /dev/null +++ b/tests/parallel/cycle_a_t1_b_t2_fallback.rs @@ -0,0 +1,67 @@ +//! Test a specific cycle scenario: +//! +//! ```text +//! Thread T1 Thread T2 +//! --------- --------- +//! | | +//! v | +//! query_a() | +//! ^ | v +//! | +------------> query_b() +//! | | +//! +--------------------+ +//! ``` +use crate::setup::{Knobs, KnobsDatabase}; + +const FALLBACK_A: u32 = 0b01; +const FALLBACK_B: u32 = 0b10; +const OFFSET_A: u32 = 0b0100; +const OFFSET_B: u32 = 0b1000; + +// Signal 1: T1 has entered `query_a` +// Signal 2: T2 has entered `query_b` + +#[salsa::tracked(cycle_result=cycle_result_a)] +fn query_a(db: &dyn KnobsDatabase) -> u32 { + db.signal(1); + + // Wait for Thread T2 to enter `query_b` before we continue. + db.wait_for(2); + + query_b(db) | OFFSET_A +} + +#[salsa::tracked(cycle_result=cycle_result_b)] +fn query_b(db: &dyn KnobsDatabase) -> u32 { + // Wait for Thread T1 to enter `query_a` before we continue. + db.wait_for(1); + + db.signal(2); + + query_a(db) | OFFSET_B +} + +fn cycle_result_a(_db: &dyn KnobsDatabase) -> u32 { + FALLBACK_A +} + +fn cycle_result_b(_db: &dyn KnobsDatabase) -> u32 { + FALLBACK_B +} + +#[test_log::test] +fn the_test() { + std::thread::scope(|scope| { + let db_t1 = Knobs::default(); + let db_t2 = db_t1.clone(); + + let t1 = scope.spawn(move || query_a(&db_t1)); + let t2 = scope.spawn(move || query_b(&db_t2)); + + let (r_t1, r_t2) = (t1.join(), t2.join()); + + assert_eq!((r_t1?, r_t2?), (FALLBACK_A, FALLBACK_B)); + Ok(()) + }) + .unwrap_or_else(|e| std::panic::resume_unwind(e)); +} diff --git a/tests/parallel/main.rs b/tests/parallel/main.rs index 613e43e1d..7507a71cf 100644 --- a/tests/parallel/main.rs +++ b/tests/parallel/main.rs @@ -1,6 +1,7 @@ mod setup; mod cycle_a_t1_b_t2; +mod cycle_a_t1_b_t2_fallback; mod cycle_ab_peeping_c; mod cycle_nested_three_threads; mod cycle_panic;