diff --git a/components/salsa-macro-rules/src/setup_accumulator_impl.rs b/components/salsa-macro-rules/src/setup_accumulator_impl.rs index 17ea50b04..bb41e61b3 100644 --- a/components/salsa-macro-rules/src/setup_accumulator_impl.rs +++ b/components/salsa-macro-rules/src/setup_accumulator_impl.rs @@ -23,13 +23,12 @@ macro_rules! setup_accumulator_impl { // Suppress the lint against `cfg(loom)`. #[allow(unexpected_cfgs)] - fn $ingredient(db: &dyn $zalsa::Database) -> &$zalsa_struct::IngredientImpl<$Struct> { + fn $ingredient(zalsa: &$zalsa::Zalsa) -> &$zalsa_struct::IngredientImpl<$Struct> { $zalsa::__maybe_lazy_static! { static $CACHE: $zalsa::IngredientCache<$zalsa_struct::IngredientImpl<$Struct>> = $zalsa::IngredientCache::new(); } - let zalsa = db.zalsa(); $CACHE.get_or_create(zalsa, || { zalsa.add_or_lookup_jar_by_type::<$zalsa_struct::JarImpl<$Struct>>() }) @@ -42,8 +41,8 @@ macro_rules! setup_accumulator_impl { where Db: ?Sized + $zalsa::Database, { - let db = db.as_dyn_database(); - $ingredient(db).push(db, self); + let (zalsa, zalsa_local) = db.zalsas(); + $ingredient(zalsa).push(zalsa_local, self); } } }; diff --git a/components/salsa-macro-rules/src/setup_input_struct.rs b/components/salsa-macro-rules/src/setup_input_struct.rs index 7ac8005c8..72e18343e 100644 --- a/components/salsa-macro-rules/src/setup_input_struct.rs +++ b/components/salsa-macro-rules/src/setup_input_struct.rs @@ -91,15 +91,18 @@ macro_rules! setup_input_struct { } impl $Configuration { + pub fn ingredient(db: &dyn $zalsa::Database) -> &$zalsa_struct::IngredientImpl { + Self::ingredient_(db.zalsa()) + } + // Suppress the lint against `cfg(loom)`. #[allow(unexpected_cfgs)] - pub fn ingredient(db: &dyn $zalsa::Database) -> &$zalsa_struct::IngredientImpl { + fn ingredient_(zalsa: &$zalsa::Zalsa) -> &$zalsa_struct::IngredientImpl { zalsa_::__maybe_lazy_static! { static CACHE: $zalsa::IngredientCache<$zalsa_struct::IngredientImpl<$Configuration>> = $zalsa::IngredientCache::new(); } - let zalsa = db.zalsa(); CACHE.get_or_create(zalsa, || { zalsa.add_or_lookup_jar_by_type::<$zalsa_struct::JarImpl<$Configuration>>() }) @@ -184,7 +187,7 @@ macro_rules! setup_input_struct { // FIXME(rust-lang/rust#65991): The `db` argument *should* have the type `dyn Database` $Db: ?Sized + $zalsa::Database, { - let fields = $Configuration::ingredient(db.as_dyn_database()).field( + let fields = $Configuration::ingredient_(db.zalsa()).field( db.as_dyn_database(), self, $field_index, @@ -221,7 +224,8 @@ macro_rules! setup_input_struct { // FIXME(rust-lang/rust#65991): The `db` argument *should* have the type `dyn Database` $Db: ?Sized + salsa::Database, { - $Configuration::ingredient(db.as_dyn_database()).get_singleton_input(db) + let zalsa = db.zalsa(); + $Configuration::ingredient_(zalsa).get_singleton_input(zalsa) } #[track_caller] @@ -271,8 +275,9 @@ macro_rules! setup_input_struct { // FIXME(rust-lang/rust#65991): The `db` argument *should* have the type `dyn Database` $Db: ?Sized + salsa::Database { - let current_revision = $zalsa::current_revision(db); - let ingredient = $Configuration::ingredient(db.as_dyn_database()); + let zalsa = db.zalsa(); + let current_revision = zalsa.current_revision(); + let ingredient = $Configuration::ingredient_(zalsa); let (fields, stamps) = builder::builder_into_inner(self, current_revision); ingredient.new_input(db.as_dyn_database(), fields, stamps) } diff --git a/components/salsa-macro-rules/src/setup_tracked_fn.rs b/components/salsa-macro-rules/src/setup_tracked_fn.rs index 55f7f80d2..6b2139963 100644 --- a/components/salsa-macro-rules/src/setup_tracked_fn.rs +++ b/components/salsa-macro-rules/src/setup_tracked_fn.rs @@ -213,7 +213,7 @@ macro_rules! setup_tracked_fn { $inner($db, $($input_id),*) } - fn cycle_initial<$db_lt>(db: &$db_lt dyn $Db, ($($input_id),*): ($($input_ty),*)) -> Self::Output<$db_lt> { + fn cycle_initial<$db_lt>(db: &$db_lt Self::DbView, ($($input_id),*): ($($input_ty),*)) -> Self::Output<$db_lt> { $($cycle_recovery_initial)*(db, $($input_id),*) } @@ -231,7 +231,7 @@ macro_rules! setup_tracked_fn { if $needs_interner { $Configuration::intern_ingredient(db).data(db.as_dyn_database(), key).clone() } else { - $zalsa::FromIdWithDb::from_id(key, db) + $zalsa::FromIdWithDb::from_id(key, db.zalsa()) } } } diff --git a/components/salsa-macro-rules/src/setup_tracked_struct.rs b/components/salsa-macro-rules/src/setup_tracked_struct.rs index 24f691761..34c740a02 100644 --- a/components/salsa-macro-rules/src/setup_tracked_struct.rs +++ b/components/salsa-macro-rules/src/setup_tracked_struct.rs @@ -169,15 +169,18 @@ macro_rules! setup_tracked_struct { } impl $Configuration { + pub fn ingredient(db: &dyn $zalsa::Database) -> &$zalsa_struct::IngredientImpl { + Self::ingredient_(db.zalsa()) + } + // Suppress the lint against `cfg(loom)`. #[allow(unexpected_cfgs)] - pub fn ingredient(db: &dyn $zalsa::Database) -> &$zalsa_struct::IngredientImpl<$Configuration> { + fn ingredient_(zalsa: &$zalsa::Zalsa) -> &$zalsa_struct::IngredientImpl { $zalsa::__maybe_lazy_static! { static CACHE: $zalsa::IngredientCache<$zalsa_struct::IngredientImpl<$Configuration>> = $zalsa::IngredientCache::new(); } - let zalsa = db.zalsa(); CACHE.get_or_create(zalsa, || { zalsa.add_or_lookup_jar_by_type::<$zalsa_struct::JarImpl<$Configuration>>() }) @@ -216,8 +219,8 @@ macro_rules! setup_tracked_struct { } impl $zalsa::TrackedStructInDb for $Struct<'_> { - fn database_key_index(db: &dyn $zalsa::Database, id: $zalsa::Id) -> $zalsa::DatabaseKeyIndex { - $Configuration::ingredient(db).database_key_index(id) + fn database_key_index(zalsa: &$zalsa::Zalsa, id: $zalsa::Id) -> $zalsa::DatabaseKeyIndex { + $Configuration::ingredient_(zalsa).database_key_index(id) } } diff --git a/components/salsa-macros/src/supertype.rs b/components/salsa-macros/src/supertype.rs index 4ffcf6c4d..5b433bd86 100644 --- a/components/salsa-macros/src/supertype.rs +++ b/components/salsa-macros/src/supertype.rs @@ -59,9 +59,8 @@ fn enum_impl(enum_item: syn::ItemEnum) -> syn::Result { impl #impl_generics zalsa::FromIdWithDb for #enum_name #type_generics #where_clause { #[inline] - fn from_id(__id: zalsa::Id, __db: &(impl ?Sized + zalsa::Database)) -> Self { - let __zalsa = __db.zalsa(); - let __type_id = __zalsa.lookup_page_type_id(__id); + fn from_id(__id: zalsa::Id, zalsa: &zalsa::Zalsa) -> Self { + let __type_id = zalsa.lookup_page_type_id(__id); ::cast(__id, __type_id).expect("invalid enum variant") } } diff --git a/examples/calc/db.rs b/examples/calc/db.rs index 924205c2f..63cc4fe12 100644 --- a/examples/calc/db.rs +++ b/examples/calc/db.rs @@ -1,14 +1,40 @@ +#[cfg(test)] use std::sync::{Arc, Mutex}; // ANCHOR: db_struct #[salsa::db] -#[derive(Default, Clone)] +#[derive(Clone)] +#[cfg_attr(not(test), derive(Default))] pub struct CalcDatabaseImpl { storage: salsa::Storage, // The logs are only used for testing and demonstrating reuse: + #[cfg(test)] logs: Arc>>>, } + +#[cfg(test)] +impl Default for CalcDatabaseImpl { + fn default() -> Self { + let logs = >>>>::default(); + Self { + storage: salsa::Storage::new(Some(Box::new({ + let logs = logs.clone(); + move |event| { + eprintln!("Event: {event:?}"); + // Log interesting events, if logging is enabled + if let Some(logs) = &mut *logs.lock().unwrap() { + // only log interesting events + if let salsa::EventKind::WillExecute { .. } = event.kind { + logs.push(format!("Event: {event:?}")); + } + } + } + }))), + logs, + } + } +} // ANCHOR_END: db_struct impl CalcDatabaseImpl { @@ -34,17 +60,5 @@ impl CalcDatabaseImpl { // ANCHOR: db_impl #[salsa::db] -impl salsa::Database for CalcDatabaseImpl { - fn salsa_event(&self, event: &dyn Fn() -> salsa::Event) { - let event = event(); - eprintln!("Event: {event:?}"); - // Log interesting events, if logging is enabled - if let Some(logs) = &mut *self.logs.lock().unwrap() { - // only log interesting events - if let salsa::EventKind::WillExecute { .. } = event.kind { - logs.push(format!("Event: {event:?}")); - } - } - } -} +impl salsa::Database for CalcDatabaseImpl {} // ANCHOR_END: db_impl diff --git a/examples/lazy-input/main.rs b/examples/lazy-input/main.rs index aa18bc775..7183eb2ac 100644 --- a/examples/lazy-input/main.rs +++ b/examples/lazy-input/main.rs @@ -87,9 +87,18 @@ struct LazyInputDatabase { impl LazyInputDatabase { fn new(tx: Sender) -> Self { + let logs: Arc>> = Default::default(); Self { - storage: Default::default(), - logs: Default::default(), + storage: Storage::new(Some(Box::new({ + let logs = logs.clone(); + move |event| { + // don't log boring events + if let salsa::EventKind::WillExecute { .. } = event.kind { + logs.lock().unwrap().push(format!("{event:?}")); + } + } + }))), + logs, files: DashMap::new(), file_watcher: Arc::new(Mutex::new( new_debouncer(Duration::from_secs(1), tx).unwrap(), @@ -99,15 +108,7 @@ impl LazyInputDatabase { } #[salsa::db] -impl salsa::Database for LazyInputDatabase { - fn salsa_event(&self, event: &dyn Fn() -> salsa::Event) { - // don't log boring events - let event = event(); - if let salsa::EventKind::WillExecute { .. } = event.kind { - self.logs.lock().unwrap().push(format!("{event:?}")); - } - } -} +impl salsa::Database for LazyInputDatabase {} #[salsa::db] impl Db for LazyInputDatabase { diff --git a/src/accumulator.rs b/src/accumulator.rs index 4bf6211f9..1e0c88d79 100644 --- a/src/accumulator.rs +++ b/src/accumulator.rs @@ -10,7 +10,7 @@ use accumulated::{Accumulated, AnyAccumulated}; use crate::function::VerifyResult; use crate::ingredient::{Ingredient, Jar}; use crate::loom::sync::Arc; -use crate::plumbing::IngredientIndices; +use crate::plumbing::{IngredientIndices, ZalsaLocal}; use crate::table::memo::MemoTableTypes; use crate::zalsa::{IngredientIndex, Zalsa}; use crate::{Database, Id, Revision}; @@ -62,11 +62,7 @@ pub struct IngredientImpl { impl IngredientImpl { /// Find the accumulator ingredient for `A` in the database, if any. - pub fn from_db(db: &Db) -> Option<&Self> - where - Db: ?Sized + Database, - { - let zalsa = db.zalsa(); + pub fn from_zalsa(zalsa: &Zalsa) -> Option<&Self> { let index = zalsa.add_or_lookup_jar_by_type::>(); let ingredient = zalsa.lookup_ingredient(index).assert_type::(); Some(ingredient) @@ -79,8 +75,7 @@ impl IngredientImpl { } } - pub fn push(&self, db: &dyn Database, value: A) { - let zalsa_local = db.zalsa_local(); + pub fn push(&self, zalsa_local: &ZalsaLocal, value: A) { if let Err(()) = zalsa_local.accumulate(self.index, value) { panic!("cannot accumulate values outside of an active tracked function"); } diff --git a/src/database.rs b/src/database.rs index 3276dc4f7..72204a582 100644 --- a/src/database.rs +++ b/src/database.rs @@ -2,20 +2,11 @@ use std::any::Any; use std::borrow::Cow; use crate::zalsa::{IngredientIndex, ZalsaDatabase}; -use crate::{Durability, Event, Revision}; +use crate::{Durability, Revision}; /// The trait implemented by all Salsa databases. /// You can create your own subtraits of this trait using the `#[salsa::db]`(`crate::db`) procedural macro. pub trait Database: Send + AsDynDatabase + Any + ZalsaDatabase { - /// This function is invoked by the salsa runtime at various points during execution. - /// You can customize what happens by implementing the [`UserData`][] trait. - /// By default, the event is logged at level debug using tracing facade. - /// - /// # Parameters - /// - /// * `event`, a fn that, if called, will create the event that occurred - fn salsa_event(&self, event: &dyn Fn() -> Event); - /// Enforces current LRU limits, evicting entries if necessary. /// /// **WARNING:** Just like an ordinary write, this method triggers @@ -24,7 +15,6 @@ pub trait Database: Send + AsDynDatabase + Any + ZalsaDatabase { /// is owned by the current thread, this could trigger deadlock. fn trigger_lru_eviction(&mut self) { let zalsa_mut = self.zalsa_mut(); - zalsa_mut.runtime_mut().reset_cancellation_flag(); zalsa_mut.evict_lru(); } @@ -77,7 +67,8 @@ pub trait Database: Send + AsDynDatabase + Any + ZalsaDatabase { /// `salsa_event` is emitted when this method is called, so that should be /// used instead. fn unwind_if_revision_cancelled(&self) { - self.zalsa().unwind_if_revision_cancelled(self); + let (zalsa, zalsa_local) = self.zalsas(); + zalsa.unwind_if_revision_cancelled(zalsa_local); } /// Execute `op` with the database in thread-local storage for debug print-outs. diff --git a/src/database_impl.rs b/src/database_impl.rs index 215a4aadf..c1eda125a 100644 --- a/src/database_impl.rs +++ b/src/database_impl.rs @@ -1,13 +1,30 @@ +use tracing::Level; + use crate::storage::HasStorage; -use crate::{Database, Event, Storage}; +use crate::{Database, Storage}; /// Default database implementation that you can use if you don't /// require any custom user data. -#[derive(Default, Clone)] +#[derive(Clone)] pub struct DatabaseImpl { storage: Storage, } +impl Default for DatabaseImpl { + fn default() -> Self { + Self { + // Default behavior: tracing debug log the event. + storage: Storage::new(if tracing::enabled!(Level::DEBUG) { + Some(Box::new(|event| { + tracing::debug!("salsa_event({:?})", event) + })) + } else { + None + }), + } + } +} + impl DatabaseImpl { /// Create a new database; equivalent to `Self::default`. pub fn new() -> Self { @@ -19,13 +36,7 @@ impl DatabaseImpl { } } -impl Database for DatabaseImpl { - /// Default behavior: tracing debug log the event. - #[inline(always)] - fn salsa_event(&self, event: &dyn Fn() -> Event) { - tracing::debug!("salsa_event({:?})", event()); - } -} +impl Database for DatabaseImpl {} // SAFETY: The `storage` and `storage_mut` fields return a reference to the same storage field owned by `self`. unsafe impl HasStorage for DatabaseImpl { diff --git a/src/function.rs b/src/function.rs index 5761f0848..0fe78d167 100644 --- a/src/function.rs +++ b/src/function.rs @@ -64,6 +64,7 @@ pub trait Configuration: Any { /// This invokes user code in form of the `Eq` impl. fn values_equal<'db>(old_value: &Self::Output<'db>, new_value: &Self::Output<'db>) -> bool; + // FIXME: This should take a `&Zalsa` /// Convert from the id used internally to the value that execute is expecting. /// This is a no-op if the input to the function is a salsa struct. fn id_to_input(db: &Self::DbView, key: Id) -> Self::Input<'_>; @@ -241,8 +242,7 @@ where /// True if the input `input` contains a memo that cites itself as a cycle head. /// This indicates an intermediate value for a cycle that has not yet reached a fixed point. - fn cycle_head_kind<'db>(&'db self, db: &'db dyn Database, input: Id) -> CycleHeadKind { - let zalsa = db.zalsa(); + fn cycle_head_kind(&self, zalsa: &Zalsa, input: Id) -> CycleHeadKind { let is_provisional = self .get_memo_from_table_for(zalsa, input, self.memo_ingredient_index(zalsa, input)) .is_some_and(|memo| { @@ -260,30 +260,29 @@ where } /// Attempts to claim `key_index`, returning `false` if a cycle occurs. - fn wait_for(&self, db: &dyn Database, key_index: Id) -> bool { - let zalsa = db.zalsa(); - match self.sync_table.try_claim(db, zalsa, key_index) { + fn wait_for(&self, zalsa: &Zalsa, key_index: Id) -> bool { + match self.sync_table.try_claim(zalsa, key_index) { ClaimResult::Retry | ClaimResult::Claimed(_) => true, ClaimResult::Cycle => false, } } - fn origin(&self, db: &dyn Database, key: Id) -> Option { - self.origin(db.zalsa(), key) + fn origin(&self, zalsa: &Zalsa, key: Id) -> Option { + self.origin(zalsa, key) } fn mark_validated_output( &self, - db: &dyn Database, + zalsa: &Zalsa, executor: DatabaseKeyIndex, output_key: crate::Id, ) { - self.validate_specified_value(db, executor, output_key); + self.validate_specified_value(zalsa, executor, output_key); } fn remove_stale_output( &self, - _db: &dyn Database, + _zalsa: &Zalsa, _executor: DatabaseKeyIndex, _stale_output_key: crate::Id, ) { diff --git a/src/function/accumulated.rs b/src/function/accumulated.rs index 8948b522d..216cdd8b5 100644 --- a/src/function/accumulated.rs +++ b/src/function/accumulated.rs @@ -31,7 +31,7 @@ where // are read from outside of salsa anyway so this is not a big deal. zalsa_local.report_untracked_read(zalsa.current_revision()); - let Some(accumulator) = >::from_db(db) else { + let Some(accumulator) = >::from_zalsa(zalsa) else { return vec![]; }; let mut output = vec![]; @@ -69,7 +69,7 @@ where // output vector, we want to push in execution order, so reverse order to // ensure the first child that was executed will be the first child popped // from the stack. - let Some(origin) = ingredient.origin(db, k.key_index()) else { + let Some(origin) = ingredient.origin(zalsa, k.key_index()) else { continue; }; diff --git a/src/function/diff_outputs.rs b/src/function/diff_outputs.rs index 6bc47b5fd..b00dc4143 100644 --- a/src/function/diff_outputs.rs +++ b/src/function/diff_outputs.rs @@ -3,7 +3,7 @@ use crate::function::{Configuration, IngredientImpl}; use crate::hash::FxIndexSet; use crate::zalsa::Zalsa; use crate::zalsa_local::QueryRevisions; -use crate::{AsDynDatabase as _, Database, DatabaseKeyIndex, Event, EventKind}; +use crate::{DatabaseKeyIndex, Event, EventKind}; impl IngredientImpl where @@ -20,7 +20,6 @@ where pub(super) fn diff_outputs( &self, zalsa: &Zalsa, - db: &C::DbView, key: DatabaseKeyIndex, old_memo: &Memo>, revisions: &mut QueryRevisions, @@ -49,22 +48,17 @@ where }); for old_output in old_outputs { - Self::report_stale_output(zalsa, db, key, old_output); + Self::report_stale_output(zalsa, key, old_output); } } - fn report_stale_output( - zalsa: &Zalsa, - db: &C::DbView, - key: DatabaseKeyIndex, - output: DatabaseKeyIndex, - ) { - db.salsa_event(&|| { + fn report_stale_output(zalsa: &Zalsa, key: DatabaseKeyIndex, output: DatabaseKeyIndex) { + zalsa.event(&|| { Event::new(EventKind::WillDiscardStaleOutput { execute_key: key, output_key: output, }) }); - output.remove_stale_output(zalsa, db.as_dyn_database(), key); + output.remove_stale_output(zalsa, key); } } diff --git a/src/function/execute.rs b/src/function/execute.rs index 65668e8b6..bcb43a8ce 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -4,7 +4,7 @@ use crate::function::{Configuration, IngredientImpl}; use crate::loom::sync::atomic::{AtomicBool, Ordering}; use crate::zalsa::{MemoIngredientIndex, Zalsa, ZalsaDatabase}; use crate::zalsa_local::{ActiveQueryGuard, QueryRevisions}; -use crate::{Database, Event, EventKind, Id, Revision}; +use crate::{Event, EventKind, Id, Revision}; impl IngredientImpl where @@ -30,14 +30,13 @@ where let id = database_key_index.key_index(); tracing::info!("{:?}: executing query", database_key_index); + let zalsa = db.zalsa(); - db.salsa_event(&|| { + zalsa.event(&|| { Event::new(EventKind::WillExecute { database_key: database_key_index, }) }); - - let zalsa = db.zalsa(); let memo_ingredient_index = self.memo_ingredient_index(zalsa, id); let (new_value, mut revisions) = match C::CYCLE_STRATEGY { @@ -104,7 +103,7 @@ where // Diff the new outputs with the old, to discard any no-longer-emitted // outputs and update the tracked struct IDs for seeding the next revision. - self.diff_outputs(zalsa, db, database_key_index, old_memo, &mut revisions); + self.diff_outputs(zalsa, database_key_index, old_memo, &mut revisions); } self.insert_memo( zalsa, @@ -209,7 +208,7 @@ where if iteration_count > MAX_ITERATIONS { panic!("{database_key_index:?}: execute: too many cycle iterations"); } - db.salsa_event(&|| { + zalsa.event(&|| { Event::new(EventKind::WillIterateCycle { database_key: database_key_index, iteration_count, diff --git a/src/function/fetch.rs b/src/function/fetch.rs index 2353dfe21..3519e6abb 100644 --- a/src/function/fetch.rs +++ b/src/function/fetch.rs @@ -13,7 +13,7 @@ where { pub fn fetch<'db>(&'db self, db: &'db C::DbView, id: Id) -> &'db C::Output<'db> { let (zalsa, zalsa_local) = db.zalsas(); - zalsa.unwind_if_revision_cancelled(db); + zalsa.unwind_if_revision_cancelled(zalsa_local); let memo = self.refresh_memo(db, zalsa, id); // SAFETY: We just refreshed the memo so it is guaranteed to contain a value now. @@ -43,7 +43,7 @@ where let memo_ingredient_index = self.memo_ingredient_index(zalsa, id); loop { if let Some(memo) = self - .fetch_hot(zalsa, db, id, memo_ingredient_index) + .fetch_hot(zalsa, id, memo_ingredient_index) .or_else(|| self.fetch_cold(zalsa, db, id, memo_ingredient_index)) { // If we get back a provisional cycle memo, and it's provisional on any cycle heads @@ -54,7 +54,7 @@ where // That is only correct for fixpoint cycles, though: `FallbackImmediate` cycles // never have provisional entries. if C::CYCLE_STRATEGY == CycleRecoveryStrategy::FallbackImmediate - || !memo.provisional_retry(db, zalsa, self.database_key_index(id)) + || !memo.provisional_retry(zalsa, self.database_key_index(id)) { return memo; } @@ -66,7 +66,6 @@ where fn fetch_hot<'db>( &'db self, zalsa: &'db Zalsa, - db: &'db C::DbView, id: Id, memo_ingredient_index: MemoIngredientIndex, ) -> Option<&'db Memo>> { @@ -76,10 +75,10 @@ where let database_key_index = self.database_key_index(id); - let shallow_update = self.shallow_verify_memo(zalsa, database_key_index, memo)?; + let can_shallow_update = self.shallow_verify_memo(zalsa, database_key_index, memo); - if !memo.may_be_provisional() { - self.update_shallow(db, zalsa, database_key_index, memo, shallow_update); + if can_shallow_update.yes() && !memo.may_be_provisional() { + self.update_shallow(zalsa, database_key_index, memo, can_shallow_update); // SAFETY: memo is present in memo_map and we have verified that it is // still valid for the current revision. @@ -98,7 +97,7 @@ where memo_ingredient_index: MemoIngredientIndex, ) -> Option<&'db Memo>> { // Try to claim this query: if someone else has claimed it already, go back and start again. - let _claim_guard = match self.sync_table.try_claim(db, zalsa, id) { + let _claim_guard = match self.sync_table.try_claim(zalsa, id) { ClaimResult::Retry => return None, ClaimResult::Cycle => { let database_key_index = self.database_key_index(id); @@ -110,15 +109,14 @@ where if memo.value.is_some() && memo.revisions.cycle_heads.contains(&database_key_index) { - if let Some(shallow_update) = - self.shallow_verify_memo(zalsa, database_key_index, memo) - { + let can_shallow_update = + self.shallow_verify_memo(zalsa, database_key_index, memo); + if can_shallow_update.yes() { self.update_shallow( - db, zalsa, database_key_index, memo, - shallow_update, + can_shallow_update, ); // SAFETY: memo is present in memo_map. return unsafe { Some(self.extend_memo_lifetime(memo)) }; diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs index 7c1a56258..de222adc5 100644 --- a/src/function/maybe_changed_after.rs +++ b/src/function/maybe_changed_after.rs @@ -5,6 +5,7 @@ use crate::function::sync::ClaimResult; use crate::function::{Configuration, IngredientImpl}; use crate::key::DatabaseKeyIndex; use crate::loom::sync::atomic::Ordering; +use crate::plumbing::ZalsaLocal; use crate::zalsa::{MemoIngredientIndex, Zalsa, ZalsaDatabase}; use crate::zalsa_local::{QueryEdge, QueryOrigin}; use crate::{AsDynDatabase as _, Id, Revision}; @@ -73,9 +74,9 @@ where revision: Revision, in_cycle: bool, ) -> VerifyResult { - let zalsa = db.zalsa(); + let (zalsa, zalsa_local) = db.zalsas(); let memo_ingredient_index = self.memo_ingredient_index(zalsa, id); - zalsa.unwind_if_revision_cancelled(db); + zalsa.unwind_if_revision_cancelled(zalsa_local); loop { let database_key_index = self.database_key_index(id); @@ -89,20 +90,18 @@ where return VerifyResult::changed(); }; - if let Some(shallow_update) = self.shallow_verify_memo(zalsa, database_key_index, memo) - { - if !memo.may_be_provisional() { - self.update_shallow(db, zalsa, database_key_index, memo, shallow_update); - - return if memo.revisions.changed_at > revision { - VerifyResult::changed() - } else { - VerifyResult::Unchanged( - memo.revisions.accumulated_inputs.load(), - CycleHeads::default(), - ) - }; - } + let can_shallow_update = self.shallow_verify_memo(zalsa, database_key_index, memo); + if can_shallow_update.yes() && !memo.may_be_provisional() { + self.update_shallow(zalsa, database_key_index, memo, can_shallow_update); + + return if memo.revisions.changed_at > revision { + VerifyResult::changed() + } else { + VerifyResult::Unchanged( + memo.revisions.accumulated_inputs.load(), + CycleHeads::default(), + ) + }; } if let Some(mcs) = self.maybe_changed_after_cold( @@ -132,7 +131,7 @@ where ) -> Option { let database_key_index = self.database_key_index(key_index); - let _claim_guard = match self.sync_table.try_claim(db, zalsa, key_index) { + let _claim_guard = match self.sync_table.try_claim(zalsa, key_index) { ClaimResult::Retry => return None, ClaimResult::Cycle => match C::CYCLE_STRATEGY { CycleRecoveryStrategy::Panic => db.zalsa_local().with_query_stack(|stack| { @@ -227,7 +226,7 @@ where zalsa: &Zalsa, database_key_index: DatabaseKeyIndex, memo: &Memo>, - ) -> Option { + ) -> ShallowUpdate { tracing::debug!( "{database_key_index:?}: shallow_verify_memo(memo = {memo:#?})", memo = memo.tracing_debug() @@ -237,7 +236,7 @@ where if verified_at == revision_now { // Already verified. - return Some(ShallowUpdate::Verified); + return ShallowUpdate::Verified; } let last_changed = zalsa.last_changed_revision(memo.revisions.durability); @@ -250,24 +249,23 @@ where ); if last_changed <= verified_at { // No input of the suitable durability has changed since last verified. - Some(ShallowUpdate::HigherDurability(revision_now)) + ShallowUpdate::HigherDurability(revision_now) } else { - None + ShallowUpdate::No } } #[inline] pub(super) fn update_shallow( &self, - db: &C::DbView, zalsa: &Zalsa, database_key_index: DatabaseKeyIndex, memo: &Memo>, update: ShallowUpdate, ) { if let ShallowUpdate::HigherDurability(revision_now) = update { - memo.mark_as_verified(db, revision_now, database_key_index); - memo.mark_outputs_as_verified(zalsa, db.as_dyn_database(), database_key_index); + memo.mark_as_verified(zalsa, revision_now, database_key_index); + memo.mark_outputs_as_verified(zalsa, database_key_index); } } @@ -279,14 +277,14 @@ where #[inline] pub(super) fn validate_may_be_provisional( &self, - db: &C::DbView, zalsa: &Zalsa, + zalsa_local: &ZalsaLocal, database_key_index: DatabaseKeyIndex, memo: &Memo>, ) -> bool { !memo.may_be_provisional() - || self.validate_provisional(db, zalsa, database_key_index, memo) - || self.validate_same_iteration(db, database_key_index, memo) + || self.validate_provisional(zalsa, database_key_index, memo) + || self.validate_same_iteration(zalsa_local, database_key_index, memo) } /// Check if this memo's cycle heads have all been finalized. If so, mark it verified final and @@ -294,7 +292,6 @@ where #[inline] fn validate_provisional( &self, - db: &C::DbView, zalsa: &Zalsa, database_key_index: DatabaseKeyIndex, memo: &Memo>, @@ -306,10 +303,7 @@ where for cycle_head in &memo.revisions.cycle_heads { let kind = zalsa .lookup_ingredient(cycle_head.database_key_index.ingredient_index()) - .cycle_head_kind( - db.as_dyn_database(), - cycle_head.database_key_index.key_index(), - ); + .cycle_head_kind(zalsa, cycle_head.database_key_index.key_index()); match kind { CycleHeadKind::Provisional => return false, CycleHeadKind::NotProvisional => { @@ -343,7 +337,7 @@ where /// runaway re-execution of the same queries within a fixpoint iteration. pub(super) fn validate_same_iteration( &self, - db: &C::DbView, + zalsa_local: &ZalsaLocal, database_key_index: DatabaseKeyIndex, memo: &Memo>, ) -> bool { @@ -358,7 +352,7 @@ where let cycle_heads = &memo.revisions.cycle_heads; - db.zalsa_local().with_query_stack(|stack| { + zalsa_local.with_query_stack(|stack| { cycle_heads.iter().all(|cycle_head| { stack.iter().rev().any(|query| { query.database_key_index == cycle_head.database_key_index @@ -387,14 +381,18 @@ where old_memo = old_memo.tracing_debug() ); - let shallow_update = self.shallow_verify_memo(zalsa, database_key_index, old_memo); - let shallow_update_possible = shallow_update.is_some(); - if let Some(shallow_update) = shallow_update { - if self.validate_may_be_provisional(db, zalsa, database_key_index, old_memo) { - self.update_shallow(db, zalsa, database_key_index, old_memo, shallow_update); - - return VerifyResult::unchanged(); - } + let can_shallow_update = self.shallow_verify_memo(zalsa, database_key_index, old_memo); + if can_shallow_update.yes() + && self.validate_may_be_provisional( + zalsa, + db.zalsa_local(), + database_key_index, + old_memo, + ) + { + self.update_shallow(zalsa, database_key_index, old_memo, can_shallow_update); + + return VerifyResult::unchanged(); } match &old_memo.revisions.origin { @@ -429,7 +427,7 @@ where // If the value is from the same revision but is still provisional, consider it changed // because we're now in a new iteration. - if shallow_update_possible && is_provisional { + if can_shallow_update.yes() && is_provisional { return VerifyResult::changed(); } @@ -449,6 +447,7 @@ where QueryEdge::Input(dependency_index) => { match dependency_index.maybe_changed_after( dyn_db, + zalsa, last_verified_at, !cycle_heads.is_empty(), ) { @@ -481,11 +480,7 @@ where // by this function cannot be read until this function is marked green, // so even if we mark them as valid here, the function will re-execute // and overwrite the contents. - dependency_index.mark_validated_output( - zalsa, - dyn_db, - database_key_index, - ); + dependency_index.mark_validated_output(zalsa, database_key_index); } } } @@ -519,7 +514,11 @@ where let in_heads = cycle_heads.remove(&database_key_index); if cycle_heads.is_empty() { - old_memo.mark_as_verified(db, zalsa.current_revision(), database_key_index); + old_memo.mark_as_verified( + zalsa, + zalsa.current_revision(), + database_key_index, + ); old_memo.revisions.accumulated_inputs.store(inputs); if is_provisional { @@ -548,4 +547,16 @@ pub(super) enum ShallowUpdate { /// The revision for the memo's durability hasn't changed. It can be marked as verified /// in this revision. HigherDurability(Revision), + + /// The memo requires a deep verification. + No, +} + +impl ShallowUpdate { + pub(super) fn yes(&self) -> bool { + matches!( + self, + ShallowUpdate::Verified | ShallowUpdate::HigherDurability(_) + ) + } } diff --git a/src/function/memo.rs b/src/function/memo.rs index c29696aa7..bde07d9d9 100644 --- a/src/function/memo.rs +++ b/src/function/memo.rs @@ -143,7 +143,6 @@ impl Memo { #[inline(always)] pub(super) fn provisional_retry( &self, - db: &(impl crate::Database + ?Sized), zalsa: &Zalsa, database_key_index: DatabaseKeyIndex, ) -> bool { @@ -153,30 +152,23 @@ impl Memo { if self.revisions.cycle_heads.is_empty() { return false; } - return provisional_retry_cold( - db.as_dyn_database(), - zalsa, - database_key_index, - &self.revisions.cycle_heads, - ); + return provisional_retry_cold(zalsa, database_key_index, &self.revisions.cycle_heads); #[inline(never)] fn provisional_retry_cold( - db: &dyn crate::Database, zalsa: &Zalsa, database_key_index: DatabaseKeyIndex, cycle_heads: &CycleHeads, ) -> bool { let mut retry = false; - let db = db.as_dyn_database(); let hit_cycle = cycle_heads .into_iter() .filter(|&head| head.database_key_index != database_key_index) .any(|head| { let head_index = head.database_key_index; let ingredient = zalsa.lookup_ingredient(head_index.ingredient_index()); - let cycle_head_kind = ingredient.cycle_head_kind(db, head_index.key_index()); + let cycle_head_kind = ingredient.cycle_head_kind(zalsa, head_index.key_index()); if matches!( cycle_head_kind, CycleHeadKind::NotProvisional | CycleHeadKind::FallbackImmediate @@ -185,7 +177,7 @@ impl Memo { // keep looping through cycle heads. retry = true; false - } else if ingredient.wait_for(db, head_index.key_index()) { + } else if ingredient.wait_for(zalsa, head_index.key_index()) { // There's a new memo available for the cycle head; fetch our own // updated memo and see if it's still provisional or if the cycle // has resolved. @@ -224,16 +216,16 @@ impl Memo { /// Mark memo as having been verified in the `revision_now`, which should /// be the current revision. - /// The caller is responsible to update the memo's `accumulated` state if heir accumulated + /// The caller is responsible to update the memo's `accumulated` state if their accumulated /// values have changed since. #[inline] - pub(super) fn mark_as_verified( + pub(super) fn mark_as_verified( &self, - db: &Db, + zalsa: &Zalsa, revision_now: Revision, database_key_index: DatabaseKeyIndex, ) { - db.salsa_event(&|| { + zalsa.event(&|| { Event::new(EventKind::DidValidateMemoizedValue { database_key: database_key_index, }) @@ -245,11 +237,10 @@ impl Memo { pub(super) fn mark_outputs_as_verified( &self, zalsa: &Zalsa, - db: &dyn crate::Database, database_key_index: DatabaseKeyIndex, ) { for output in self.revisions.origin.outputs() { - output.mark_validated_output(zalsa, db, database_key_index); + output.mark_validated_output(zalsa, database_key_index); } } diff --git a/src/function/specify.rs b/src/function/specify.rs index dddd383dc..e341efd05 100644 --- a/src/function/specify.rs +++ b/src/function/specify.rs @@ -4,9 +4,9 @@ use crate::function::{Configuration, IngredientImpl}; use crate::loom::sync::atomic::AtomicBool; use crate::revision::AtomicRevision; use crate::tracked_struct::TrackedStructInDb; -use crate::zalsa::ZalsaDatabase; +use crate::zalsa::{Zalsa, ZalsaDatabase}; use crate::zalsa_local::{QueryOrigin, QueryRevisions}; -use crate::{AsDynDatabase as _, Database, DatabaseKeyIndex, Id}; +use crate::{DatabaseKeyIndex, Id}; impl IngredientImpl where @@ -37,7 +37,7 @@ where // * Q4 invokes Q2 and then Q1 // // Now, if We invoke Q3 first, We get one result for Q2, but if We invoke Q4 first, We get a different value. That's no good. - let database_key_index = >::database_key_index(db.as_dyn_database(), key); + let database_key_index = >::database_key_index(zalsa, key); if !zalsa_local.is_output_of_active_query(database_key_index) { panic!("can only use `specify` on salsa structs created during the current tracked fn"); } @@ -61,7 +61,7 @@ where // - a result that is verified in the current revision, because it was set, which will use the set value // - a result that is NOT verified and has untracked inputs, which will re-execute (and likely panic) - let revision = db.zalsa().current_revision(); + let revision = zalsa.current_revision(); let mut revisions = QueryRevisions { changed_at: current_deps.changed_at, durability: current_deps.durability, @@ -76,7 +76,7 @@ where let memo_ingredient_index = self.memo_ingredient_index(zalsa, key); if let Some(old_memo) = self.get_memo_from_table_for(zalsa, key, memo_ingredient_index) { self.backdate_if_appropriate(old_memo, database_key_index, &mut revisions, &value); - self.diff_outputs(zalsa, db, database_key_index, old_memo, &mut revisions); + self.diff_outputs(zalsa, database_key_index, old_memo, &mut revisions); } let memo = Memo { @@ -103,11 +103,10 @@ where /// it would have specified `key` again. pub(super) fn validate_specified_value( &self, - db: &dyn Database, + zalsa: &Zalsa, executor: DatabaseKeyIndex, key: Id, ) { - let zalsa = db.zalsa(); let memo_ingredient_index = self.memo_ingredient_index(zalsa, key); let memo = match self.get_memo_from_table_for(zalsa, key, memo_ingredient_index) { @@ -126,7 +125,7 @@ where } let database_key_index = self.database_key_index(key); - memo.mark_as_verified(db, zalsa.current_revision(), database_key_index); + memo.mark_as_verified(zalsa, zalsa.current_revision(), database_key_index); memo.revisions .accumulated_inputs .store(InputAccumulatedValues::Empty); diff --git a/src/function/sync.rs b/src/function/sync.rs index 942b114b1..13c3285d2 100644 --- a/src/function/sync.rs +++ b/src/function/sync.rs @@ -5,7 +5,7 @@ use crate::loom::sync::Mutex; use crate::loom::thread::{self, ThreadId}; use crate::runtime::{BlockResult, WaitResult}; use crate::zalsa::Zalsa; -use crate::{Database, Id, IngredientIndex}; +use crate::{Id, IngredientIndex}; /// Tracks the keys that are currently being processed; used to coordinate between /// worker threads. @@ -36,12 +36,7 @@ impl SyncTable { } } - pub(crate) fn try_claim<'me>( - &'me self, - db: &'me (impl ?Sized + Database), - zalsa: &'me Zalsa, - key_index: Id, - ) -> ClaimResult<'me> { + pub(crate) fn try_claim<'me>(&'me self, zalsa: &'me Zalsa, key_index: Id) -> ClaimResult<'me> { let mut write = self.syncs.lock(); match write.entry(key_index) { std::collections::hash_map::Entry::Occupied(occupied_entry) => { @@ -57,7 +52,7 @@ impl SyncTable { // not to gate future atomic reads. *anyone_waiting = true; match zalsa.runtime().block_on( - db, + zalsa, DatabaseKeyIndex::new(self.ingredient, key_index), id, write, diff --git a/src/id.rs b/src/id.rs index 393397556..30c8ff95a 100644 --- a/src/id.rs +++ b/src/id.rs @@ -2,7 +2,7 @@ use std::fmt::Debug; use std::hash::Hash; use std::num::NonZeroU32; -use crate::Database; +use crate::zalsa::Zalsa; /// The `Id` of a salsa struct in the database [`Table`](`crate::table::Table`). /// @@ -85,12 +85,12 @@ impl FromId for Id { /// Enums cannot use [`FromId`] because they need access to the DB to tell the `TypeId` of the variant, /// so they use this trait instead, that has a blanket implementation for `FromId`. pub trait FromIdWithDb { - fn from_id(id: Id, db: &(impl ?Sized + Database)) -> Self; + fn from_id(id: Id, zalsa: &Zalsa) -> Self; } impl FromIdWithDb for T { #[inline] - fn from_id(id: Id, _db: &(impl ?Sized + Database)) -> Self { + fn from_id(id: Id, _zalsa: &Zalsa) -> Self { FromId::from_id(id) } } diff --git a/src/ingredient.rs b/src/ingredient.rs index a3f2a009e..33917a6f6 100644 --- a/src/ingredient.rs +++ b/src/ingredient.rs @@ -70,8 +70,8 @@ pub trait Ingredient: Any + std::fmt::Debug + Send + Sync { /// /// In the case of nested cycles, we are not asking here whether the value is provisional due /// to the outer cycle being unresolved, only whether its own cycle remains provisional. - fn cycle_head_kind<'db>(&'db self, db: &'db dyn Database, input: Id) -> CycleHeadKind { - _ = (db, input); + fn cycle_head_kind(&self, zalsa: &Zalsa, input: Id) -> CycleHeadKind { + _ = (zalsa, input); CycleHeadKind::NotProvisional } @@ -80,21 +80,21 @@ pub trait Ingredient: Any + std::fmt::Debug + Send + Sync { /// A return value of `true` indicates that a result is now available. A return value of /// `false` means that a cycle was encountered; the waited-on query is either already claimed /// by the current thread, or by a thread waiting on the current thread. - fn wait_for(&self, db: &dyn Database, key_index: Id) -> bool { - _ = (db, key_index); + fn wait_for(&self, zalsa: &Zalsa, key_index: Id) -> bool { + _ = (zalsa, key_index); true } /// Invoked when the value `output_key` should be marked as valid in the current revision. /// This occurs because the value for `executor`, which generated it, was marked as valid /// in the current revision. - fn mark_validated_output<'db>( - &'db self, - db: &'db dyn Database, + fn mark_validated_output( + &self, + zalsa: &Zalsa, executor: DatabaseKeyIndex, output_key: crate::Id, ) { - let _ = (db, executor, output_key); + let _ = (zalsa, executor, output_key); unreachable!("only tracked struct and function ingredients can have validatable outputs") } @@ -102,13 +102,8 @@ pub trait Ingredient: Any + std::fmt::Debug + Send + Sync { /// revision, but was NOT output in the current revision. /// /// This hook is used to clear out the stale value so others cannot read it. - fn remove_stale_output( - &self, - db: &dyn Database, - executor: DatabaseKeyIndex, - stale_output_key: Id, - ) { - let _ = (db, executor, stale_output_key); + fn remove_stale_output(&self, zalsa: &Zalsa, executor: DatabaseKeyIndex, stale_output_key: Id) { + let _ = (zalsa, executor, stale_output_key); unreachable!("only tracked struct ingredients can have stale outputs") } @@ -155,8 +150,8 @@ pub trait Ingredient: Any + std::fmt::Debug + Send + Sync { } /// What were the inputs (if any) that were used to create the value at `key_index`. - fn origin(&self, db: &dyn Database, key_index: Id) -> Option { - let _ = (db, key_index); + fn origin(&self, zalsa: &Zalsa, key_index: Id) -> Option { + let _ = (zalsa, key_index); unreachable!("only function ingredients have origins") } diff --git a/src/input.rs b/src/input.rs index dc479eb4d..be7447b86 100644 --- a/src/input.rs +++ b/src/input.rs @@ -112,7 +112,7 @@ impl IngredientImpl { }) }); - FromIdWithDb::from_id(id, db) + FromIdWithDb::from_id(id, zalsa) } /// Change the value of the field `field_index` to a new value. @@ -152,13 +152,14 @@ impl IngredientImpl { } /// Get the singleton input previously created (if any). - pub fn get_singleton_input(&self, db: &(impl ?Sized + Database)) -> Option + #[doc(hidden)] + pub fn get_singleton_input(&self, zalsa: &Zalsa) -> Option where C: Configuration, { self.singleton .index() - .map(|id| FromIdWithDb::from_id(id, db)) + .map(|id| FromIdWithDb::from_id(id, zalsa)) } /// Access field of an input. diff --git a/src/input/input_field.rs b/src/input/input_field.rs index d747f591d..cafab7a50 100644 --- a/src/input/input_field.rs +++ b/src/input/input_field.rs @@ -61,10 +61,6 @@ where VerifyResult::changed_if(value.stamps[self.field_index].changed_at > revision) } - fn wait_for(&self, _db: &dyn Database, _key_index: Id) -> bool { - true - } - fn fmt_index(&self, index: crate::Id, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { write!( fmt, diff --git a/src/interned.rs b/src/interned.rs index 76feef37f..6795c11ba 100644 --- a/src/interned.rs +++ b/src/interned.rs @@ -221,7 +221,7 @@ where // Sync the value's revision. if value.last_interned_at.load() < current_revision { value.last_interned_at.store(current_revision); - db.salsa_event(&|| { + zalsa.event(&|| { Event::new(EventKind::DidReinternValue { key: index, revision: current_revision, @@ -264,7 +264,7 @@ where // Sync the value's revision. if value.last_interned_at.load() < current_revision { value.last_interned_at.store(current_revision); - db.salsa_event(&|| { + zalsa.event(&|| { Event::new(EventKind::DidReinternValue { key: index, revision: current_revision, @@ -323,7 +323,7 @@ where let index = self.database_key_index(id); zalsa_local.report_tracked_read_simple(index, durability, value.first_interned_at); - db.salsa_event(&|| { + zalsa.event(&|| { Event::new(EventKind::DidInternValue { key: index, revision: current_revision, @@ -364,7 +364,8 @@ where } pub fn reset(&mut self, db: &mut dyn Database) { - db.zalsa_mut().new_revision(); + _ = db.zalsa_mut(); + // We can clear the key_map now that we have cancelled all other handles. self.key_map.clear(); } @@ -412,7 +413,7 @@ where value.last_interned_at.load(), )); - db.salsa_event(&|| { + zalsa.event(&|| { let index = self.database_key_index(input); Event::new(EventKind::DidReinternValue { diff --git a/src/key.rs b/src/key.rs index c3fa57ada..524c926ce 100644 --- a/src/key.rs +++ b/src/key.rs @@ -36,37 +36,32 @@ impl DatabaseKeyIndex { pub(crate) fn maybe_changed_after( &self, db: &dyn Database, + zalsa: &Zalsa, last_verified_at: crate::Revision, in_cycle: bool, ) -> VerifyResult { // SAFETY: The `db` belongs to the ingredient unsafe { - db.zalsa() + zalsa .lookup_ingredient(self.ingredient_index) .maybe_changed_after(db, self.key_index, last_verified_at, in_cycle) } } - pub(crate) fn remove_stale_output( - &self, - zalsa: &Zalsa, - db: &dyn Database, - executor: DatabaseKeyIndex, - ) { + pub(crate) fn remove_stale_output(&self, zalsa: &Zalsa, executor: DatabaseKeyIndex) { zalsa .lookup_ingredient(self.ingredient_index) - .remove_stale_output(db, executor, self.key_index) + .remove_stale_output(zalsa, executor, self.key_index) } pub(crate) fn mark_validated_output( &self, zalsa: &Zalsa, - db: &dyn Database, database_key_index: DatabaseKeyIndex, ) { zalsa .lookup_ingredient(self.ingredient_index) - .mark_validated_output(db, database_key_index, self.key_index) + .mark_validated_output(zalsa, database_key_index, self.key_index) } } diff --git a/src/loom.rs b/src/loom.rs index d32d606b9..cc944685d 100644 --- a/src/loom.rs +++ b/src/loom.rs @@ -4,12 +4,18 @@ pub use loom::{cell, thread, thread_local}; /// A helper macro to work around the fact that most loom types are not `const` constructable. #[doc(hidden)] #[macro_export] +#[cfg(loom)] macro_rules! __maybe_lazy_static { (static $name:ident: $t:ty = $init:expr $(;)?) => { - #[cfg(loom)] loom::lazy_static! { static ref $name: $t = $init; } - - #[cfg(not(loom))] + }; +} +/// A helper macro to work around the fact that most loom types are not `const` constructable. +#[doc(hidden)] +#[macro_export] +#[cfg(not(loom))] +macro_rules! __maybe_lazy_static { + (static $name:ident: $t:ty = $init:expr $(;)?) => { static $name: $t = $init; }; } diff --git a/src/runtime.rs b/src/runtime.rs index 47f16daed..84e1bae96 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -5,7 +5,7 @@ use crate::loom::sync::atomic::{AtomicBool, Ordering}; use crate::loom::sync::{AtomicMut, Mutex}; use crate::loom::thread::{self, ThreadId}; use crate::table::Table; -use crate::{Cancelled, Database, Event, EventKind, Revision}; +use crate::{Cancelled, Event, EventKind, Revision}; mod dependency_graph; @@ -167,7 +167,7 @@ impl Runtime { /// cancelled, so this function will panic with a `Cancelled` value. pub(crate) fn block_on( &self, - db: &(impl Database + ?Sized), + zalsa: &crate::zalsa::Zalsa, database_key: DatabaseKeyIndex, other_id: ThreadId, query_mutex_guard: QueryMutexGuard, @@ -179,7 +179,7 @@ impl Runtime { return BlockResult::Cycle; } - db.salsa_event(&|| { + zalsa.event(&|| { Event::new(EventKind::WillBlockOn { other_thread_id: other_id, database_key, diff --git a/src/storage.rs b/src/storage.rs index df0d989cd..ab2e4907e 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -36,8 +36,14 @@ impl Clone for StorageHandle { impl Default for StorageHandle { fn default() -> Self { + Self::new(None) + } +} + +impl StorageHandle { + pub fn new(event_callback: Option>) -> Self { Self { - zalsa_impl: Arc::new(Zalsa::new::()), + zalsa_impl: Arc::new(Zalsa::new::(event_callback)), coordinate: CoordinateDrop(Arc::new(Coordinate { clones: Mutex::new(1), cvar: Default::default(), @@ -45,9 +51,7 @@ impl Default for StorageHandle { phantom: PhantomData, } } -} -impl StorageHandle { pub fn into_storage(self) -> Storage { Storage { handle: self, @@ -96,14 +100,21 @@ impl RefUnwindSafe for Coordinate {} impl Default for Storage { fn default() -> Self { + Self::new(None) + } +} + +impl Storage { + /// Create a new database storage. + /// + /// The `event_callback` function is invoked by the salsa runtime at various points during execution. + pub fn new(event_callback: Option>) -> Self { Self { - handle: StorageHandle::default(), + handle: StorageHandle::new(event_callback), zalsa_local: ZalsaLocal::new(), } } -} -impl Storage { /// Convert this instance of [`Storage`] into a [`StorageHandle`]. /// /// This will discard the local state of this [`Storage`], thereby returning a value that @@ -131,15 +142,28 @@ impl Storage { /// same database! /// /// Needs to be paired with a call to `reset_cancellation_flag`. - fn cancel_others(&self, db: &Db) { + fn cancel_others(&mut self) -> &mut Zalsa { + debug_assert!( + self.zalsa_local + .try_with_query_stack(|stack| stack.is_empty()) + == Some(true), + "attempted to cancel within query computation, this is a deadlock" + ); self.handle.zalsa_impl.runtime().set_cancellation_flag(); - db.salsa_event(&|| Event::new(EventKind::DidSetCancellationFlag)); + self.handle + .zalsa_impl + .event(&|| Event::new(EventKind::DidSetCancellationFlag)); let mut clones = self.handle.coordinate.clones.lock(); while *clones != 1 { clones = self.handle.coordinate.cvar.wait(clones); } + // The ref count on the `Arc` should now be 1 + let zalsa = Arc::get_mut(&mut self.handle.zalsa_impl).unwrap(); + // cancellation is done, so reset the flag + zalsa.runtime_mut().reset_cancellation_flag(); + zalsa } // ANCHOR_END: cancel_other_workers } @@ -152,14 +176,7 @@ unsafe impl ZalsaDatabase for T { } fn zalsa_mut(&mut self) -> &mut Zalsa { - self.storage().cancel_others(self); - - let storage = self.storage_mut(); - // The ref count on the `Arc` should now be 1 - let zalsa = Arc::get_mut(&mut storage.handle.zalsa_impl).unwrap(); - // cancellation is done, so reset the flag - zalsa.runtime_mut().reset_cancellation_flag(); - zalsa + self.storage_mut().cancel_others() } #[inline(always)] diff --git a/src/tracked_struct.rs b/src/tracked_struct.rs index e383ca6d3..377021339 100644 --- a/src/tracked_struct.rs +++ b/src/tracked_struct.rs @@ -138,7 +138,7 @@ impl Jar for JarImpl { pub trait TrackedStructInDb: SalsaStructInDb { /// Converts the identifier for this tracked struct into a `DatabaseKeyIndex`. - fn database_key_index(db: &dyn Database, id: Id) -> DatabaseKeyIndex; + fn database_key_index(zalsa: &Zalsa, id: Id) -> DatabaseKeyIndex; } /// Created for each tracked struct. @@ -598,14 +598,13 @@ where /// Using this method on an entity id that MAY be used in the current revision will lead to /// unspecified results (but not UB). See [`InternedIngredient::delete_index`] for more /// discussion and important considerations. - pub(crate) fn delete_entity(&self, db: &dyn crate::Database, id: Id) { - db.salsa_event(&|| { + pub(crate) fn delete_entity(&self, zalsa: &Zalsa, id: Id) { + zalsa.event(&|| { Event::new(crate::EventKind::DidDiscard { key: self.database_key_index(id), }) }); - let zalsa = db.zalsa(); let current_revision = zalsa.current_revision(); let data_raw = Self::data_raw(zalsa.table(), id); @@ -657,10 +656,10 @@ where let executor = DatabaseKeyIndex::new(ingredient_index, id); - db.salsa_event(&|| Event::new(EventKind::DidDiscard { key: executor })); + zalsa.event(&|| Event::new(EventKind::DidDiscard { key: executor })); for stale_output in memo.origin().outputs() { - stale_output.remove_stale_output(zalsa, db, executor); + stale_output.remove_stale_output(zalsa, executor); } }) }; @@ -770,13 +769,9 @@ where VerifyResult::changed_if(data.created_at > revision) } - fn wait_for(&self, _db: &dyn Database, _key_index: Id) -> bool { - true - } - - fn mark_validated_output<'db>( - &'db self, - _db: &'db dyn Database, + fn mark_validated_output( + &self, + _zalsa: &Zalsa, _executor: DatabaseKeyIndex, _output_key: crate::Id, ) { @@ -787,7 +782,7 @@ where fn remove_stale_output( &self, - db: &dyn Database, + zalsa: &Zalsa, _executor: DatabaseKeyIndex, stale_output_key: crate::Id, ) { @@ -795,7 +790,7 @@ where // `executor` creates a tracked struct `salsa_output_key`, // but it did not in the current revision. // In that case, we can delete `stale_output_key` and any data associated with it. - self.delete_entity(db, stale_output_key); + self.delete_entity(zalsa, stale_output_key); } fn debug_name(&self) -> &'static str { diff --git a/src/zalsa.rs b/src/zalsa.rs index d1756a697..78b2fc44b 100644 --- a/src/zalsa.rs +++ b/src/zalsa.rs @@ -149,6 +149,8 @@ pub struct Zalsa { /// The runtime for this particular salsa database handle. /// Each handle gets its own runtime, but the runtimes have shared state between them. runtime: Runtime, + + event_callback: Option>, } /// All fields on Zalsa are locked behind [`Mutex`]es and [`RwLock`]s and cannot enter @@ -159,7 +161,9 @@ pub struct Zalsa { impl RefUnwindSafe for Zalsa {} impl Zalsa { - pub(crate) fn new() -> Self { + pub(crate) fn new( + event_callback: Option>, + ) -> Self { Self { views_of: Views::new::(), nonce: NONCE.nonce(), @@ -169,6 +173,7 @@ impl Zalsa { ingredients_requiring_reset: boxcar::Vec::new(), runtime: Runtime::default(), memo_ingredient_indices: Default::default(), + event_callback, } } @@ -224,10 +229,10 @@ impl Zalsa { /// Cancellation will automatically be triggered by salsa on any query /// invocation. #[inline] - pub(crate) fn unwind_if_revision_cancelled(&self, db: &(impl Database + ?Sized)) { - db.salsa_event(&|| crate::Event::new(crate::EventKind::WillCheckCancellation)); + pub(crate) fn unwind_if_revision_cancelled(&self, zalsa_local: &ZalsaLocal) { + self.event(&|| crate::Event::new(crate::EventKind::WillCheckCancellation)); if self.runtime().load_cancellation_flag() { - db.zalsa_local().unwind_cancelled(self.current_revision()); + zalsa_local.unwind_cancelled(self.current_revision()); } } @@ -382,6 +387,13 @@ impl Zalsa { pub fn ingredient_index(&self, id: Id) -> IngredientIndex { self.table().ingredient_index(id) } + + #[inline(always)] + pub fn event(&self, event: &dyn Fn() -> crate::Event) { + if let Some(event_callback) = &self.event_callback { + event_callback(event()); + } + } } /// Caches a pointer to an ingredient in a database. diff --git a/src/zalsa_local.rs b/src/zalsa_local.rs index aa020512e..03d41851f 100644 --- a/src/zalsa_local.rs +++ b/src/zalsa_local.rs @@ -305,6 +305,7 @@ impl ZalsaLocal { #[cold] pub(crate) fn unwind_cancelled(&self, current_revision: Revision) { + // Why is this reporting an untracked read? We do not store the query revisions on unwind do we? self.report_untracked_read(current_revision); Cancelled::PendingWrite.throw(); } diff --git a/tests/common/mod.rs b/tests/common/mod.rs index 10e540c21..2852c04bd 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -15,6 +15,12 @@ pub struct Logger { logs: Arc>>, } +impl Logger { + pub fn push_log(&self, string: String) { + self.logs.lock().unwrap().push(string); + } +} + /// Trait implemented by databases that lets them log events. pub trait HasLogger { /// Return a reference to the logger from the database. @@ -63,25 +69,32 @@ impl HasLogger for LoggerDatabase { } #[salsa::db] -impl Database for LoggerDatabase { - fn salsa_event(&self, _event: &dyn Fn() -> salsa::Event) {} -} +impl Database for LoggerDatabase {} /// Database that provides logging and logs salsa events. #[salsa::db] -#[derive(Clone, Default)] +#[derive(Clone)] pub struct EventLoggerDatabase { storage: Storage, logger: Logger, } -#[salsa::db] -impl Database for EventLoggerDatabase { - fn salsa_event(&self, event: &dyn Fn() -> salsa::Event) { - self.push_log(format!("{:?}", event().kind)); +impl Default for EventLoggerDatabase { + fn default() -> Self { + let logger = Logger::default(); + Self { + storage: Storage::new(Some(Box::new({ + let logger = logger.clone(); + move |event| logger.push_log(format!("{:?}", event.kind)) + }))), + logger, + } } } +#[salsa::db] +impl Database for EventLoggerDatabase {} + impl HasLogger for EventLoggerDatabase { fn logger(&self) -> &Logger { &self.logger @@ -89,26 +102,34 @@ impl HasLogger for EventLoggerDatabase { } #[salsa::db] -#[derive(Clone, Default)] +#[derive(Clone)] pub struct DiscardLoggerDatabase { storage: Storage, logger: Logger, } -#[salsa::db] -impl Database for DiscardLoggerDatabase { - fn salsa_event(&self, event: &dyn Fn() -> salsa::Event) { - let event = event(); - match event.kind { - salsa::EventKind::WillDiscardStaleOutput { .. } - | salsa::EventKind::DidDiscard { .. } => { - self.push_log(format!("salsa_event({:?})", event.kind)); - } - _ => {} +impl Default for DiscardLoggerDatabase { + fn default() -> Self { + let logger = Logger::default(); + Self { + storage: Storage::new(Some(Box::new({ + let logger = logger.clone(); + move |event| match event.kind { + salsa::EventKind::WillDiscardStaleOutput { .. } + | salsa::EventKind::DidDiscard { .. } => { + logger.push_log(format!("salsa_event({:?})", event.kind)); + } + _ => {} + } + }))), + logger, } } } +#[salsa::db] +impl Database for DiscardLoggerDatabase {} + impl HasLogger for DiscardLoggerDatabase { fn logger(&self) -> &Logger { &self.logger @@ -116,26 +137,32 @@ impl HasLogger for DiscardLoggerDatabase { } #[salsa::db] -#[derive(Clone, Default)] +#[derive(Clone)] pub struct ExecuteValidateLoggerDatabase { storage: Storage, logger: Logger, } -#[salsa::db] -impl Database for ExecuteValidateLoggerDatabase { - fn salsa_event(&self, event: &dyn Fn() -> salsa::Event) { - let event = event(); - match event.kind { - salsa::EventKind::WillExecute { .. } - | salsa::EventKind::WillIterateCycle { .. } - | salsa::EventKind::DidValidateMemoizedValue { .. } => { - self.push_log(format!("salsa_event({:?})", event.kind)); - } - _ => {} +impl Default for ExecuteValidateLoggerDatabase { + fn default() -> Self { + let logger = Logger::default(); + Self { + storage: Storage::new(Some(Box::new({ + let logger = logger.clone(); + move |event| match event.kind { + salsa::EventKind::WillExecute { .. } + | salsa::EventKind::WillIterateCycle { .. } + | salsa::EventKind::DidValidateMemoizedValue { .. } => { + logger.push_log(format!("salsa_event({:?})", event.kind)); + } + _ => {} + } + }))), + logger, } } } +impl Database for ExecuteValidateLoggerDatabase {} impl HasLogger for ExecuteValidateLoggerDatabase { fn logger(&self) -> &Logger { @@ -168,9 +195,7 @@ impl HasValue for DatabaseWithValue { } #[salsa::db] -impl Database for DatabaseWithValue { - fn salsa_event(&self, _event: &dyn Fn() -> salsa::Event) {} -} +impl Database for DatabaseWithValue {} impl DatabaseWithValue { pub fn new(value: u32) -> Self { diff --git a/tests/cycle_output.rs b/tests/cycle_output.rs index 7c5d77ddc..11a0c0399 100644 --- a/tests/cycle_output.rs +++ b/tests/cycle_output.rs @@ -2,7 +2,7 @@ mod common; use common::{HasLogger, LogDatabase, Logger}; use expect_test::expect; -use salsa::Setter; +use salsa::{Setter, Storage}; #[salsa::tracked] struct Output<'db> { @@ -70,7 +70,7 @@ trait HasOptionInput { trait Db: HasOptionInput + salsa::Database {} #[salsa::db] -#[derive(Clone, Default)] +#[derive(Clone)] struct Database { storage: salsa::Storage, logger: Logger, @@ -83,6 +83,29 @@ impl HasLogger for Database { } } +impl Default for Database { + fn default() -> Self { + let logger = Logger::default(); + Self { + storage: Storage::new(Some(Box::new({ + let logger = logger.clone(); + move |event| match event.kind { + salsa::EventKind::WillExecute { .. } + | salsa::EventKind::DidValidateMemoizedValue { .. } => { + logger.push_log(format!("salsa_event({:?})", event.kind)); + } + salsa::EventKind::WillCheckCancellation => {} + _ => { + logger.push_log(format!("salsa_event({:?})", event.kind)); + } + } + }))), + logger, + input: Default::default(), + } + } +} + impl HasOptionInput for Database { fn get_input(&self) -> Option { self.input @@ -94,21 +117,7 @@ impl HasOptionInput for Database { } #[salsa::db] -impl salsa::Database for Database { - fn salsa_event(&self, event: &dyn Fn() -> salsa::Event) { - let event = event(); - match event.kind { - salsa::EventKind::WillExecute { .. } - | salsa::EventKind::DidValidateMemoizedValue { .. } => { - self.push_log(format!("salsa_event({:?})", event.kind)); - } - salsa::EventKind::WillCheckCancellation => {} - _ => { - self.push_log(format!("salsa_event({:?})", event.kind)); - } - } - } -} +impl salsa::Database for Database {} #[salsa::db] impl Db for Database {} diff --git a/tests/parallel/setup.rs b/tests/parallel/setup.rs index 870e1025a..442467f86 100644 --- a/tests/parallel/setup.rs +++ b/tests/parallel/setup.rs @@ -1,7 +1,7 @@ use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; -use salsa::Database; +use salsa::{Database, Storage}; use crate::signal::Signal; @@ -20,7 +20,6 @@ pub(crate) trait KnobsDatabase: Database { /// behave on one specific thread. Note that this state is /// intentionally thread-local (apart from `signal`). #[salsa::db] -#[derive(Default)] pub(crate) struct Knobs { storage: salsa::Storage, @@ -29,10 +28,10 @@ pub(crate) struct Knobs { pub(crate) signal: Arc, /// When this database is about to block, send this signal. - signal_on_will_block: AtomicUsize, + signal_on_will_block: Arc, /// When this database has set the cancellation flag, send this signal. - signal_on_did_cancel: AtomicUsize, + signal_on_did_cancel: Arc, } impl Knobs { @@ -54,28 +53,43 @@ impl Clone for Knobs { Self { storage: self.storage.clone(), signal: self.signal.clone(), - signal_on_will_block: AtomicUsize::new(0), - signal_on_did_cancel: AtomicUsize::new(0), + signal_on_will_block: self.signal_on_will_block.clone(), + signal_on_did_cancel: self.signal_on_did_cancel.clone(), } } } -#[salsa::db] -impl salsa::Database for Knobs { - fn salsa_event(&self, event: &dyn Fn() -> salsa::Event) { - let event = event(); - match event.kind { - salsa::EventKind::WillBlockOn { .. } => { - self.signal(self.signal_on_will_block.load(Ordering::Acquire)); - } - salsa::EventKind::DidSetCancellationFlag => { - self.signal(self.signal_on_did_cancel.load(Ordering::Acquire)); - } - _ => {} +impl Default for Knobs { + fn default() -> Self { + let signal = >::default(); + let signal_on_will_block = Arc::new(AtomicUsize::new(0)); + let signal_on_did_cancel = Arc::new(AtomicUsize::new(0)); + + Self { + storage: Storage::new(Some(Box::new({ + let signal = signal.clone(); + let signal_on_will_block = signal_on_will_block.clone(); + let signal_on_did_cancel = signal_on_did_cancel.clone(); + move |event| match event.kind { + salsa::EventKind::WillBlockOn { .. } => { + signal.signal(signal_on_will_block.load(Ordering::Acquire)); + } + salsa::EventKind::DidSetCancellationFlag => { + signal.signal(signal_on_did_cancel.load(Ordering::Acquire)); + } + _ => {} + } + }))), + signal, + signal_on_will_block, + signal_on_did_cancel, } } } +#[salsa::db] +impl salsa::Database for Knobs {} + #[salsa::db] impl KnobsDatabase for Knobs { fn signal(&self, stage: usize) { diff --git a/tests/tracked_struct_durability.rs b/tests/tracked_struct_durability.rs index a8c1706fc..7dfd87284 100644 --- a/tests/tracked_struct_durability.rs +++ b/tests/tracked_struct_durability.rs @@ -94,9 +94,7 @@ fn execute() { } #[salsa::db] - impl salsa::Database for Database { - fn salsa_event(&self, _event: &dyn Fn() -> salsa::Event) {} - } + impl salsa::Database for Database {} #[salsa::db] impl Db for Database {