Skip to content

Commit af69cc1

Browse files
authored
Do not re-verify already verified memoized value in cycle verification (#851)
1 parent 00dbf01 commit af69cc1

File tree

3 files changed

+37
-38
lines changed

3 files changed

+37
-38
lines changed

src/function/maybe_changed_after.rs

Lines changed: 31 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,6 @@ pub enum VerifyResult {
1919
///
2020
/// The inner value tracks whether the memo or any of its dependencies have an
2121
/// accumulated value.
22-
///
23-
/// Don't mark memos verified until we've iterated the full cycle to ensure no inputs changed
24-
/// when encountering this variant.
2522
Unchanged(InputAccumulatedValues),
2623
}
2724

@@ -37,10 +34,6 @@ impl VerifyResult {
3734
pub(crate) fn unchanged() -> Self {
3835
Self::Unchanged(InputAccumulatedValues::Empty)
3936
}
40-
41-
pub(crate) const fn is_unchanged(&self) -> bool {
42-
matches!(self, Self::Unchanged(_))
43-
}
4437
}
4538

4639
impl<C> IngredientImpl<C>
@@ -146,11 +139,11 @@ where
146139
// Check if the inputs are still valid. We can just compare `changed_at`.
147140
let deep_verify =
148141
self.deep_verify_memo(db, zalsa, old_memo, database_key_index, cycle_heads);
149-
if deep_verify.is_unchanged() {
142+
if let VerifyResult::Unchanged(accumulated_inputs) = deep_verify {
150143
return Some(if old_memo.revisions.changed_at > revision {
151144
VerifyResult::Changed
152145
} else {
153-
VerifyResult::Unchanged(old_memo.revisions.accumulated_inputs.load())
146+
VerifyResult::Unchanged(accumulated_inputs)
154147
});
155148
}
156149

@@ -316,18 +309,18 @@ where
316309
memo = memo.tracing_debug()
317310
);
318311

319-
if memo.revisions.cycle_heads.is_empty() {
312+
let cycle_heads = &memo.revisions.cycle_heads;
313+
if cycle_heads.is_empty() {
320314
return true;
321315
}
322316

323-
let cycle_heads = &memo.revisions.cycle_heads;
324-
325317
zalsa_local.with_query_stack(|stack| {
326318
cycle_heads.iter().all(|cycle_head| {
327-
stack.iter().rev().any(|query| {
328-
query.database_key_index == cycle_head.database_key_index
329-
&& query.iteration_count() == cycle_head.iteration_count
330-
})
319+
stack
320+
.iter()
321+
.rev()
322+
.find(|query| query.database_key_index == cycle_head.database_key_index)
323+
.is_some_and(|query| query.iteration_count() == cycle_head.iteration_count)
331324
})
332325
})
333326
}
@@ -402,16 +395,18 @@ where
402395
return VerifyResult::Changed;
403396
}
404397

398+
let dyn_db = db.as_dyn_database();
399+
400+
let mut last_verified_at = old_memo.verified_at.load();
401+
let mut first_iteration = true;
405402
'cycle: loop {
403+
let mut inputs = InputAccumulatedValues::Empty;
406404
// Fully tracked inputs? Iterate over the inputs and check them, one by one.
407405
//
408406
// NB: It's important here that we are iterating the inputs in the order that
409407
// they executed. It's possible that if the value of some input I0 is no longer
410408
// valid, then some later input I1 might never have executed at all, so verifying
411409
// it is still up to date is meaningless.
412-
let last_verified_at = old_memo.verified_at.load();
413-
let mut inputs = InputAccumulatedValues::Empty;
414-
let dyn_db = db.as_dyn_database();
415410
for &edge in edges.input_outputs.iter() {
416411
match edge {
417412
QueryEdge::Input(dependency_index) => {
@@ -421,9 +416,7 @@ where
421416
last_verified_at,
422417
cycle_heads,
423418
) {
424-
VerifyResult::Changed => {
425-
break 'cycle VerifyResult::Changed;
426-
}
419+
VerifyResult::Changed => break 'cycle VerifyResult::Changed,
427420
VerifyResult::Unchanged(input_accumulated) => {
428421
inputs |= input_accumulated;
429422
}
@@ -477,9 +470,17 @@ where
477470
// from cycle heads. We will handle our own memo (and the rest of our cycle) on a
478471
// future iteration; first the outer cycle head needs to verify itself.
479472

480-
let in_heads = cycle_heads.remove(&database_key_index);
473+
let was_in_heads = cycle_heads.remove(&database_key_index);
474+
let heads_non_empty = !cycle_heads.is_empty();
475+
if heads_non_empty {
476+
// case 2 / 4
477+
break 'cycle VerifyResult::Unchanged(inputs);
478+
} else if !first_iteration {
479+
// 3 (second loop turn)
480+
break 'cycle VerifyResult::Unchanged(inputs);
481+
} else {
482+
last_verified_at = zalsa.current_revision();
481483

482-
if cycle_heads.is_empty() {
483484
old_memo.mark_as_verified(zalsa, database_key_index);
484485
old_memo.revisions.accumulated_inputs.store(inputs);
485486

@@ -490,11 +491,15 @@ where
490491
.store(true, Ordering::Relaxed);
491492
}
492493

493-
if in_heads {
494+
if was_in_heads {
495+
first_iteration = false;
496+
// case 3
494497
continue 'cycle;
498+
} else {
499+
// case 1
500+
break 'cycle VerifyResult::Unchanged(inputs);
495501
}
496502
}
497-
break 'cycle VerifyResult::Unchanged(inputs);
498503
}
499504
}
500505
}

tests/cycle.rs

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -882,7 +882,6 @@ fn cycle_unchanged() {
882882
[
883883
"salsa_event(DidValidateMemoizedValue { database_key: min_iterate(Id(1)) })",
884884
"salsa_event(DidValidateMemoizedValue { database_key: min_panic(Id(2)) })",
885-
"salsa_event(DidValidateMemoizedValue { database_key: min_iterate(Id(1)) })",
886885
]"#]]);
887886

888887
a.assert_value(&db, 45);
@@ -929,9 +928,7 @@ fn cycle_unchanged_nested() {
929928
"salsa_event(DidValidateMemoizedValue { database_key: min_iterate(Id(1)) })",
930929
"salsa_event(DidValidateMemoizedValue { database_key: min_iterate(Id(3)) })",
931930
"salsa_event(DidValidateMemoizedValue { database_key: min_panic(Id(4)) })",
932-
"salsa_event(DidValidateMemoizedValue { database_key: min_iterate(Id(3)) })",
933931
"salsa_event(DidValidateMemoizedValue { database_key: min_panic(Id(2)) })",
934-
"salsa_event(DidValidateMemoizedValue { database_key: min_iterate(Id(1)) })",
935932
]"#]]);
936933

937934
a.assert_value(&db, 45);
@@ -992,14 +989,12 @@ fn cycle_unchanged_nested_intertwined() {
992989
b.assert_value(&db, 60);
993990

994991
db.assert_logs(expect![[r#"
995-
[
996-
"salsa_event(DidValidateMemoizedValue { database_key: min_iterate(Id(1)) })",
997-
"salsa_event(DidValidateMemoizedValue { database_key: min_iterate(Id(3)) })",
998-
"salsa_event(DidValidateMemoizedValue { database_key: min_iterate(Id(4)) })",
999-
"salsa_event(DidValidateMemoizedValue { database_key: min_iterate(Id(3)) })",
1000-
"salsa_event(DidValidateMemoizedValue { database_key: min_panic(Id(2)) })",
1001-
"salsa_event(DidValidateMemoizedValue { database_key: min_iterate(Id(1)) })",
1002-
]"#]]);
992+
[
993+
"salsa_event(DidValidateMemoizedValue { database_key: min_iterate(Id(1)) })",
994+
"salsa_event(DidValidateMemoizedValue { database_key: min_iterate(Id(3)) })",
995+
"salsa_event(DidValidateMemoizedValue { database_key: min_iterate(Id(4)) })",
996+
"salsa_event(DidValidateMemoizedValue { database_key: min_panic(Id(2)) })",
997+
]"#]]);
1003998

1004999
a.assert_value(&db, 45);
10051000
}

tests/cycle_output.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,6 @@ fn revalidate_no_changes() {
158158
"salsa_event(DidValidateMemoizedValue { database_key: read_value(Id(402)) })",
159159
"salsa_event(DidValidateMemoizedValue { database_key: read_value(Id(403)) })",
160160
"salsa_event(DidValidateMemoizedValue { database_key: query_a(Id(0)) })",
161-
"salsa_event(DidValidateMemoizedValue { database_key: query_b(Id(0)) })",
162161
]"#]]);
163162
}
164163

0 commit comments

Comments
 (0)