diff --git a/lightning/src/ln/channelmanager.rs b/lightning/src/ln/channelmanager.rs
index a042d2f95ee..1f5a2a44c2c 100644
--- a/lightning/src/ln/channelmanager.rs
+++ b/lightning/src/ln/channelmanager.rs
@@ -70,7 +70,7 @@ use crate::prelude::*;
 use core::{cmp, mem};
 use core::cell::RefCell;
 use crate::io::Read;
-use crate::sync::{Arc, Mutex, RwLock, RwLockReadGuard, FairRwLock};
+use crate::sync::{Arc, Mutex, RwLock, RwLockReadGuard, FairRwLock, LockTestExt, LockHeldState};
 use core::sync::atomic::{AtomicUsize, Ordering};
 use core::time::Duration;
 use core::ops::Deref;
@@ -1218,13 +1218,10 @@ macro_rules! handle_error {
 		match $internal {
 			Ok(msg) => Ok(msg),
 			Err(MsgHandleErrInternal { err, chan_id, shutdown_finish }) => {
-				#[cfg(any(feature = "_test_utils", test))]
-				{
-					// In testing, ensure there are no deadlocks where the lock is already held upon
-					// entering the macro.
-					debug_assert!($self.pending_events.try_lock().is_ok());
-					debug_assert!($self.per_peer_state.try_write().is_ok());
-				}
+				// In testing, ensure there are no deadlocks where the lock is already held upon
+				// entering the macro.
+				debug_assert_ne!($self.pending_events.held_by_thread(), LockHeldState::HeldByThread);
+				debug_assert_ne!($self.per_peer_state.held_by_thread(), LockHeldState::HeldByThread);
 
 				let mut msg_events = Vec::with_capacity(2);
 
@@ -3743,17 +3740,12 @@ where
 	/// Fails an HTLC backwards to the sender of it to us.
 	/// Note that we do not assume that channels corresponding to failed HTLCs are still available.
 	fn fail_htlc_backwards_internal(&self, source: &HTLCSource, payment_hash: &PaymentHash, onion_error: &HTLCFailReason, destination: HTLCDestination) {
-		#[cfg(any(feature = "_test_utils", test))]
-		{
-			// Ensure that the peer state channel storage lock is not held when calling this
-			// function.
-			// This ensures that future code doesn't introduce a lock_order requirement for
-			// `forward_htlcs` to be locked after the `per_peer_state` peer locks, which calling
-			// this function with any `per_peer_state` peer lock aquired would.
-			let per_peer_state = self.per_peer_state.read().unwrap();
-			for (_, peer) in per_peer_state.iter() {
-				debug_assert!(peer.try_lock().is_ok());
-			}
+		// Ensure that no peer state channel storage lock is held when calling this function.
+		// This ensures that future code doesn't introduce a lock-order requirement for
+		// `forward_htlcs` to be locked after the `per_peer_state` peer locks, which calling
+		// this function with any `per_peer_state` peer lock acquired would create.
+		for (_, peer) in self.per_peer_state.read().unwrap().iter() {
+			debug_assert_ne!(peer.held_by_thread(), LockHeldState::HeldByThread);
 		}
 
 		//TODO: There is a timing attack here where if a node fails an HTLC back to us they can
@@ -7723,7 +7715,7 @@ where
 			inbound_payment_key: expanded_inbound_key,
 			pending_inbound_payments: Mutex::new(pending_inbound_payments),
-			pending_outbound_payments: OutboundPayments { pending_outbound_payments: Mutex::new(pending_outbound_payments.unwrap()) },
+			pending_outbound_payments: OutboundPayments { pending_outbound_payments: Mutex::new(pending_outbound_payments.unwrap()), retry_lock: Mutex::new(()), },
 			pending_intercepted_htlcs: Mutex::new(pending_intercepted_htlcs.unwrap()),
 			forward_htlcs: Mutex::new(forward_htlcs),
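
The swap from `try_lock()`-based assertions to `held_by_thread()` above is what lets these debug checks survive multi-threaded tests: `try_lock()` fails whenever *any* thread holds the lock, so the old assertions would trip spuriously once the threaded retry test below exercises the same `ChannelManager` from 16 threads, while `held_by_thread()` reports only locks taken by the calling thread, which is the actual deadlock condition. The contrast, restated using the same locks as the hunk above:

// Old shape: racy across threads. Another thread legitimately holding
// pending_events makes try_lock() fail and the debug assertion panic.
debug_assert!($self.pending_events.try_lock().is_ok());

// New shape: per-thread. This only trips if the *current* thread already
// holds the lock, i.e. an actual re-entrancy deadlock.
debug_assert_ne!($self.pending_events.held_by_thread(), LockHeldState::HeldByThread);
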
diff --git a/lightning/src/ln/functional_test_utils.rs b/lightning/src/ln/functional_test_utils.rs
index da8abcc108d..c42bbbe9466 100644
--- a/lightning/src/ln/functional_test_utils.rs
+++ b/lightning/src/ln/functional_test_utils.rs
@@ -350,6 +350,19 @@ impl<'a, 'b, 'c> Node<'a, 'b, 'c> {
 	}
 }
 
+/// If we need an unsafe pointer to a `Node` (i.e. to reference it in a thread spawned before
+/// `std::thread::scope` was available), this provides that pointer with `Sync`. Note that some of
+/// the fields in `Node` are not safe to access through it (i.e. the ones behind an `Rc`), but
+/// that's left to the caller to figure out.
+pub struct NodePtr(pub *const Node<'static, 'static, 'static>);
+impl NodePtr {
+	pub fn from_node<'a, 'b: 'a, 'c: 'b>(node: &Node<'a, 'b, 'c>) -> Self {
+		Self((node as *const Node<'a, 'b, 'c>).cast())
+	}
+}
+unsafe impl Send for NodePtr {}
+unsafe impl Sync for NodePtr {}
+
 impl<'a, 'b, 'c> Drop for Node<'a, 'b, 'c> {
 	fn drop(&mut self) {
 		if !panicking() {
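
`NodePtr` is only needed because the crate's MSRV predates Rust 1.63, where `std::thread::scope` was stabilized; scoped threads let closures borrow non-`'static` data and guarantee every thread is joined before the scope returns, which is exactly the property the raw-pointer workaround emulates. A sketch of the scoped equivalent (hypothetical helper, not part of this patch):

// Requires Rust >= 1.63. `shared` stands in for `&nodes[0]`.
fn fan_out<T: Sync>(shared: &T) {
	std::thread::scope(|s| {
		for _ in 0..16 {
			s.spawn(move || {
				let _borrowed = shared; // non-'static borrow crosses into the thread
			});
		}
	}); // all spawned threads are joined here, before `scope` returns
}
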
diff --git a/lightning/src/ln/outbound_payment.rs b/lightning/src/ln/outbound_payment.rs
index a9ced49f647..c94a4e9f60c 100644
--- a/lightning/src/ln/outbound_payment.rs
+++ b/lightning/src/ln/outbound_payment.rs
@@ -393,12 +393,14 @@ pub enum PaymentSendFailure {
 
 pub(super) struct OutboundPayments {
 	pub(super) pending_outbound_payments: Mutex<HashMap<PaymentId, PendingOutboundPayment>>,
+	pub(super) retry_lock: Mutex<()>,
 }
 
 impl OutboundPayments {
 	pub(super) fn new() -> Self {
 		Self {
-			pending_outbound_payments: Mutex::new(HashMap::new())
+			pending_outbound_payments: Mutex::new(HashMap::new()),
+			retry_lock: Mutex::new(()),
 		}
 	}
 
@@ -501,6 +503,7 @@ impl OutboundPayments {
 		FH: Fn() -> Vec<ChannelDetails>,
 		L::Target: Logger,
 	{
+		let _single_thread = self.retry_lock.lock().unwrap();
 		loop {
 			let mut outbounds = self.pending_outbound_payments.lock().unwrap();
 			let mut retry_id_route_params = None;
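
The bug this `retry_lock` closes is a check-then-act race: retry processing reads what still needs to be sent while holding `pending_outbound_payments`, drops that lock before actually sending (to avoid lock-order problems in the send path), and then sends, so two threads could both observe the same outstanding amount and both send it. Holding a dedicated `Mutex<()>` for the whole pass serializes retries without lengthening how long the data lock is held. In miniature (hypothetical names, not the PR's actual types):

use std::sync::Mutex;

struct Retrier {
	retry_lock: Mutex<()>,    // serializes entire retry passes
	pending_msat: Mutex<u64>, // hypothetical stand-in for pending_outbound_payments
}

impl Retrier {
	fn retry(&self) {
		// Without this guard, two threads could both read the same pending
		// amount below and both "retry" it, double-paying.
		let _single_thread = self.retry_lock.lock().unwrap();
		let amt = { *self.pending_msat.lock().unwrap() }; // data lock dropped here
		if amt > 0 {
			send_htlc(amt); // the send happens with no data lock held
			*self.pending_msat.lock().unwrap() -= amt;
		}
	}
}

fn send_htlc(_amt_msat: u64) { /* stand-in for the real send path */ }
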
diff --git a/lightning/src/ln/payment_tests.rs b/lightning/src/ln/payment_tests.rs
index 1c06c0b32cf..bc0910384f7 100644
--- a/lightning/src/ln/payment_tests.rs
+++ b/lightning/src/ln/payment_tests.rs
@@ -39,7 +39,7 @@ use crate::routing::gossip::NodeId;
 #[cfg(feature = "std")]
 use {
 	crate::util::time::tests::SinceEpoch,
-	std::time::{SystemTime, Duration}
+	std::time::{SystemTime, Instant, Duration}
 };
 
 #[test]
@@ -2616,3 +2616,165 @@ fn test_simple_partial_retry() {
 	expect_pending_htlcs_forwardable!(nodes[2]);
 	expect_payment_claimable!(nodes[2], payment_hash, payment_secret, amt_msat);
 }
+
+#[test]
+#[cfg(feature = "std")]
+fn test_threaded_payment_retries() {
+	// In the first version of the in-`ChannelManager` payment retries, retry processing wasn't
+	// limited to a single thread, and multiple threads could run retries at the same time. Because
+	// retries are done by first calculating the amount we need to retry, then dropping the
+	// relevant lock, then actually sending, multiple threads could retry the same amount at the
+	// same time, overpaying our original HTLC!
+	let chanmon_cfgs = create_chanmon_cfgs(4);
+	let node_cfgs = create_node_cfgs(4, &chanmon_cfgs);
+	let node_chanmgrs = create_node_chanmgrs(4, &node_cfgs, &[None, None, None, None]);
+	let nodes = create_network(4, &node_cfgs, &node_chanmgrs);
+
+	// There is one mitigating guardrail when retrying payments: we can never over-pay by more
+	// than 10% of the original value. Thus, we want all our retries to be below that. In order to
+	// keep things simple, we route one HTLC for 0.1% of the payment over channel 1 and the rest
+	// out over channels 3+4. This will let us ignore 99.9% of the payment value and deal with
+	// only our channel.
+	let chan_1_scid = create_announced_chan_between_nodes_with_value(&nodes, 0, 1, 10_000_000, 0).0.contents.short_channel_id;
+	create_announced_chan_between_nodes_with_value(&nodes, 1, 3, 10_000_000, 0);
+	let chan_3_scid = create_announced_chan_between_nodes_with_value(&nodes, 0, 2, 10_000_000, 0).0.contents.short_channel_id;
+	let chan_4_scid = create_announced_chan_between_nodes_with_value(&nodes, 2, 3, 10_000_000, 0).0.contents.short_channel_id;
+
+	let amt_msat = 100_000_000;
+	let (_, payment_hash, _, payment_secret) = get_route_and_payment_hash!(&nodes[0], nodes[2], amt_msat);
+	#[cfg(feature = "std")]
+	let payment_expiry_secs = SystemTime::UNIX_EPOCH.elapsed().unwrap().as_secs() + 60 * 60;
+	#[cfg(not(feature = "std"))]
+	let payment_expiry_secs = 60 * 60;
+	let mut invoice_features = InvoiceFeatures::empty();
+	invoice_features.set_variable_length_onion_required();
+	invoice_features.set_payment_secret_required();
+	invoice_features.set_basic_mpp_optional();
+	let payment_params = PaymentParameters::from_node_id(nodes[1].node.get_our_node_id(), TEST_FINAL_CLTV)
+		.with_expiry_time(payment_expiry_secs as u64)
+		.with_features(invoice_features);
+	let mut route_params = RouteParameters {
+		payment_params,
+		final_value_msat: amt_msat,
+		final_cltv_expiry_delta: TEST_FINAL_CLTV,
+	};
+
+	let mut route = Route {
+		paths: vec![
+			vec![RouteHop {
+				pubkey: nodes[1].node.get_our_node_id(),
+				node_features: nodes[1].node.node_features(),
+				short_channel_id: chan_1_scid,
+				channel_features: nodes[1].node.channel_features(),
+				fee_msat: 0,
+				cltv_expiry_delta: 100,
+			}, RouteHop {
+				pubkey: nodes[3].node.get_our_node_id(),
+				node_features: nodes[2].node.node_features(),
+				short_channel_id: 42, // Set a random SCID which nodes[1] will fail as unknown
+				channel_features: nodes[2].node.channel_features(),
+				fee_msat: amt_msat / 1000,
+				cltv_expiry_delta: 100,
+			}],
+			vec![RouteHop {
+				pubkey: nodes[2].node.get_our_node_id(),
+				node_features: nodes[2].node.node_features(),
+				short_channel_id: chan_3_scid,
+				channel_features: nodes[2].node.channel_features(),
+				fee_msat: 100_000,
+				cltv_expiry_delta: 100,
+			}, RouteHop {
+				pubkey: nodes[3].node.get_our_node_id(),
+				node_features: nodes[3].node.node_features(),
+				short_channel_id: chan_4_scid,
+				channel_features: nodes[3].node.channel_features(),
+				fee_msat: amt_msat - amt_msat / 1000,
+				cltv_expiry_delta: 100,
+			}]
+		],
+		payment_params: Some(PaymentParameters::from_node_id(nodes[2].node.get_our_node_id(), TEST_FINAL_CLTV)),
+	};
+	nodes[0].router.expect_find_route(route_params.clone(), Ok(route.clone()));
+
+	nodes[0].node.send_payment_with_retry(payment_hash, &Some(payment_secret), PaymentId(payment_hash.0), route_params.clone(), Retry::Attempts(0xdeadbeef)).unwrap();
+	check_added_monitors!(nodes[0], 2);
+	let mut send_msg_events = nodes[0].node.get_and_clear_pending_msg_events();
+	assert_eq!(send_msg_events.len(), 2);
+	send_msg_events.retain(|msg|
+		if let MessageSendEvent::UpdateHTLCs { node_id, .. } = msg {
+			// Drop the commitment update for nodes[2]; we can just let that one sit pending
+			// forever.
+			*node_id == nodes[1].node.get_our_node_id()
+		} else { panic!(); }
+	);
+
+	// From here on out, the retry `RouteParameters` amount will be amt / 1000
+	route_params.final_value_msat /= 1000;
+	route.paths.pop();
+
+	let end_time = Instant::now() + Duration::from_secs(1);
+	macro_rules! thread_body { () => { {
+		// We really want std::thread::scope, but it's not stable until 1.63. Until then, we get unsafe.
+		let node_ref = NodePtr::from_node(&nodes[0]);
+		move || {
+			let node_a = unsafe { &*node_ref.0 };
+			while Instant::now() < end_time {
+				node_a.node.get_and_clear_pending_events(); // wipe the PendingHTLCsForwardable
+				// Ignore if we have any pending events, just always pretend we just got a
+				// PendingHTLCsForwardable
+				node_a.node.process_pending_htlc_forwards();
+			}
+		}
+	} } }
+	let mut threads = Vec::new();
+	for _ in 0..16 { threads.push(std::thread::spawn(thread_body!())); }
+
+	// Back in the main thread, poll pending messages and make sure that we never have more than
+	// one HTLC pending at a time. Note that the commitment_signed_dance will fail horribly if
+	// there are HTLC messages shoved in while it's running. This allows us to test that we never
+	// generate an additional update_add_htlc until we've fully failed the first.
+	let mut previously_failed_channels = Vec::new();
+	loop {
+		assert_eq!(send_msg_events.len(), 1);
+		let send_event = SendEvent::from_event(send_msg_events.pop().unwrap());
+		assert_eq!(send_event.msgs.len(), 1);
+
+		nodes[1].node.handle_update_add_htlc(&nodes[0].node.get_our_node_id(), &send_event.msgs[0]);
+		commitment_signed_dance!(nodes[1], nodes[0], send_event.commitment_msg, false, true);
+
+		// Note that we only push one route into `expect_find_route` at a time, because that's all
+		// the retries (should) need. If the bug is reintroduced "real" routes may be selected, but
+		// we should still ultimately fail for the same reason - because we're trying to send too
+		// many HTLCs at once.
+		let mut new_route_params = route_params.clone();
+		previously_failed_channels.push(route.paths[0][1].short_channel_id);
+		new_route_params.payment_params.previously_failed_channels = previously_failed_channels.clone();
+		route.paths[0][1].short_channel_id += 1;
+		nodes[0].router.expect_find_route(new_route_params, Ok(route.clone()));
+
+		let bs_fail_updates = get_htlc_update_msgs!(nodes[1], nodes[0].node.get_our_node_id());
+		nodes[0].node.handle_update_fail_htlc(&nodes[1].node.get_our_node_id(), &bs_fail_updates.update_fail_htlcs[0]);
+		// The "normal" commitment_signed_dance delivers the final RAA and then calls
+		// `check_added_monitors` to ensure only the one RAA-generated monitor update was created.
+		// This races with our other threads which may generate an add-HTLCs commitment update via
+		// `process_pending_htlc_forwards`. Instead, we defer the monitor update check until after
+		// *we've* called `process_pending_htlc_forwards`, when it's guaranteed to have two updates.
+		let last_raa = commitment_signed_dance!(nodes[0], nodes[1], bs_fail_updates.commitment_signed, false, true, false, true);
+		nodes[0].node.handle_revoke_and_ack(&nodes[1].node.get_our_node_id(), &last_raa);
+
+		let cur_time = Instant::now();
+		if cur_time > end_time {
+			for thread in threads.drain(..) { thread.join().unwrap(); }
+		}
+
+		// Make sure we have some events to handle when we go around...
+		nodes[0].node.get_and_clear_pending_events(); // wipe the PendingHTLCsForwardable
+		nodes[0].node.process_pending_htlc_forwards();
+		send_msg_events = nodes[0].node.get_and_clear_pending_msg_events();
+		check_added_monitors!(nodes[0], 2);
+
+		if cur_time > end_time {
+			break;
+		}
+	}
+}
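
For reference, the arithmetic behind the test's 0.1% routing split: with `amt_msat = 100_000_000`, each retry over the racy path carries `amt_msat / 1000 = 100_000` msat, while the over-payment guardrail described in the test allows up to 10% (10_000_000 msat). Each duplicate retry therefore overpays by only 0.1%, so the race surfaces as extra `update_add_htlc` messages long before the guardrail could mask it:

const AMT_MSAT: u64 = 100_000_000;           // total payment in the test
const RETRY_MSAT: u64 = AMT_MSAT / 1000;     // 100_000 msat per retry attempt
const OVERPAY_CAP_MSAT: u64 = AMT_MSAT / 10; // 10% over-payment guardrail
// It would take OVERPAY_CAP_MSAT / RETRY_MSAT == 100 simultaneous duplicate
// retries before the guardrail kicked in, so the assertions on message counts
// catch the bug first.
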
diff --git a/lightning/src/sync/debug_sync.rs b/lightning/src/sync/debug_sync.rs
index 9f7caa2c180..56310937237 100644
--- a/lightning/src/sync/debug_sync.rs
+++ b/lightning/src/sync/debug_sync.rs
@@ -14,6 +14,8 @@ use std::sync::Condvar as StdCondvar;
 
 use crate::prelude::HashMap;
 
+use super::{LockTestExt, LockHeldState};
+
 #[cfg(feature = "backtrace")]
 use {crate::prelude::hash_map, backtrace::Backtrace, std::sync::Once};
 
@@ -168,6 +170,18 @@ impl LockMetadata {
 	fn pre_lock(this: &Arc<LockMetadata>) { Self::_pre_lock(this, false); }
 	fn pre_read_lock(this: &Arc<LockMetadata>) -> bool { Self::_pre_lock(this, true) }
 
+	fn held_by_thread(this: &Arc<LockMetadata>) -> LockHeldState {
+		let mut res = LockHeldState::NotHeldByThread;
+		LOCKS_HELD.with(|held| {
+			for (locked_idx, _locked) in held.borrow().iter() {
+				if *locked_idx == this.lock_idx {
+					res = LockHeldState::HeldByThread;
+				}
+			}
+		});
+		res
+	}
+
 	fn try_locked(this: &Arc<LockMetadata>) {
 		LOCKS_HELD.with(|held| {
 			// Since a try-lock will simply fail if the lock is held already, we do not
@@ -248,6 +262,13 @@ impl<T> Mutex<T> {
 	}
 }
 
+impl<T> LockTestExt for Mutex<T> {
+	#[inline]
+	fn held_by_thread(&self) -> LockHeldState {
+		LockMetadata::held_by_thread(&self.deps)
+	}
+}
+
 pub struct RwLock<T: Sized> {
 	inner: StdRwLock<T>,
 	deps: Arc<LockMetadata>,
 }
@@ -332,4 +353,11 @@ impl<T> RwLock<T> {
 	}
 }
 
+impl<T> LockTestExt for RwLock<T> {
+	#[inline]
+	fn held_by_thread(&self) -> LockHeldState {
+		LockMetadata::held_by_thread(&self.deps)
+	}
+}
+
 pub type FairRwLock<T> = RwLock<T>;
diff --git a/lightning/src/util/fairrwlock.rs b/lightning/src/sync/fairrwlock.rs
similarity index 89%
rename from lightning/src/util/fairrwlock.rs
rename to lightning/src/sync/fairrwlock.rs
index 5715a8cf646..a9519ac240c 100644
--- a/lightning/src/util/fairrwlock.rs
+++ b/lightning/src/sync/fairrwlock.rs
@@ -1,5 +1,6 @@
 use std::sync::{LockResult, RwLock, RwLockReadGuard, RwLockWriteGuard, TryLockResult};
 use std::sync::atomic::{AtomicUsize, Ordering};
+use super::{LockHeldState, LockTestExt};
 
 /// Rust libstd's RwLock does not provide any fairness guarantees (and, in fact, when used on
 /// Linux with pthreads under the hood, readers trivially and completely starve writers).
@@ -48,3 +49,11 @@ impl<T> FairRwLock<T> {
 		self.lock.try_write()
 	}
 }
+
+impl<T> LockTestExt for FairRwLock<T> {
+	#[inline]
+	fn held_by_thread(&self) -> LockHeldState {
+		// fairrwlock is only built in non-test modes, so we should never support tests.
+		LockHeldState::Unsupported
+	}
+}
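
`LockMetadata::held_by_thread` piggybacks on state `debug_sync` already keeps for lock-order checking: a `LOCKS_HELD` thread-local mapping each held lock's `lock_idx` to its metadata, which is what the loop above iterates. "Held by the calling thread" thus reduces to a membership test in that thread-local map. A simplified model of the mechanism (assumed shape, not the real internals):

use std::cell::RefCell;
use std::collections::HashMap;

thread_local! {
	// One entry per lock currently held by *this* thread, keyed by lock_idx.
	static LOCKS_HELD: RefCell<HashMap<u64, ()>> = RefCell::new(HashMap::new());
}

fn held_by_current_thread(lock_idx: u64) -> bool {
	LOCKS_HELD.with(|held| held.borrow().contains_key(&lock_idx))
}
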
diff --git a/lightning/src/sync/mod.rs b/lightning/src/sync/mod.rs
index f7226a5fa34..50ef40e295f 100644
--- a/lightning/src/sync/mod.rs
+++ b/lightning/src/sync/mod.rs
@@ -1,3 +1,16 @@
+#[allow(dead_code)] // Depending on the compilation flags some variants are never used
+#[derive(Debug, PartialEq, Eq)]
+pub(crate) enum LockHeldState {
+	HeldByThread,
+	NotHeldByThread,
+	#[cfg(any(feature = "_bench_unstable", not(test)))]
+	Unsupported,
+}
+
+pub(crate) trait LockTestExt {
+	fn held_by_thread(&self) -> LockHeldState;
+}
+
 #[cfg(all(feature = "std", not(feature = "_bench_unstable"), test))]
 mod debug_sync;
 #[cfg(all(feature = "std", not(feature = "_bench_unstable"), test))]
@@ -7,9 +20,22 @@ pub use debug_sync::*;
 mod test_lockorder_checks;
 
 #[cfg(all(feature = "std", any(feature = "_bench_unstable", not(test))))]
-pub use ::std::sync::{Arc, Mutex, Condvar, MutexGuard, RwLock, RwLockReadGuard, RwLockWriteGuard};
+pub(crate) mod fairrwlock;
+#[cfg(all(feature = "std", any(feature = "_bench_unstable", not(test))))]
+pub use {std::sync::{Arc, Mutex, Condvar, MutexGuard, RwLock, RwLockReadGuard, RwLockWriteGuard}, fairrwlock::FairRwLock};
+
 #[cfg(all(feature = "std", any(feature = "_bench_unstable", not(test))))]
-pub use crate::util::fairrwlock::FairRwLock;
+mod ext_impl {
+	use super::*;
+	impl<T> LockTestExt for Mutex<T> {
+		#[inline]
+		fn held_by_thread(&self) -> LockHeldState { LockHeldState::Unsupported }
+	}
+	impl<T> LockTestExt for RwLock<T> {
+		#[inline]
+		fn held_by_thread(&self) -> LockHeldState { LockHeldState::Unsupported }
+	}
+}
 
 #[cfg(not(feature = "std"))]
 mod nostd_sync;
diff --git a/lightning/src/sync/nostd_sync.rs b/lightning/src/sync/nostd_sync.rs
index caf88a7cc04..e17aa6ab15f 100644
--- a/lightning/src/sync/nostd_sync.rs
+++ b/lightning/src/sync/nostd_sync.rs
@@ -2,6 +2,7 @@ pub use ::alloc::sync::Arc;
 use core::ops::{Deref, DerefMut};
 use core::time::Duration;
 use core::cell::{RefCell, Ref, RefMut};
+use super::{LockTestExt, LockHeldState};
 
 pub type LockResult<Guard> = Result<Guard, ()>;
 
@@ -61,6 +62,14 @@ impl<T> Mutex<T> {
 	}
 }
 
+impl<T> LockTestExt for Mutex<T> {
+	#[inline]
+	fn held_by_thread(&self) -> LockHeldState {
+		if self.lock().is_err() { return LockHeldState::HeldByThread; }
+		else { return LockHeldState::NotHeldByThread; }
+	}
+}
+
 pub struct RwLock<T> {
 	inner: RefCell<T>
 }
@@ -116,4 +125,12 @@ impl<T> RwLock<T> {
 	}
 }
 
+impl<T> LockTestExt for RwLock<T> {
+	#[inline]
+	fn held_by_thread(&self) -> LockHeldState {
+		if self.write().is_err() { return LockHeldState::HeldByThread; }
+		else { return LockHeldState::NotHeldByThread; }
+	}
+}
+
 pub type FairRwLock<T> = RwLock<T>;
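
The no-std implementations above can answer `held_by_thread` by simply attempting the lock: the no-std `Mutex` and `RwLock` are `RefCell`-backed and there is only one thread, so "some thread holds this lock" and "this thread holds this lock" are the same question, and a failed re-borrow answers it. The underlying `RefCell` behavior, in isolation:

use core::cell::RefCell;

fn main() {
	let cell = RefCell::new(0u8);
	assert!(cell.try_borrow_mut().is_ok());  // nothing held: NotHeldByThread
	let _guard = cell.borrow_mut();
	assert!(cell.try_borrow_mut().is_err()); // held by the only thread: HeldByThread
}
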
diff --git a/lightning/src/sync/test_lockorder_checks.rs b/lightning/src/sync/test_lockorder_checks.rs
index f9f30e2cfa2..a3f746b11dc 100644
--- a/lightning/src/sync/test_lockorder_checks.rs
+++ b/lightning/src/sync/test_lockorder_checks.rs
@@ -1,5 +1,10 @@
 use crate::sync::debug_sync::{RwLock, Mutex};
 
+use super::{LockHeldState, LockTestExt};
+
+use std::sync::Arc;
+use std::thread;
+
 #[test]
 #[should_panic]
 #[cfg(not(feature = "backtrace"))]
@@ -92,3 +97,22 @@ fn read_write_lockorder_fail() {
 		let _a = a.write().unwrap();
 	}
 }
+
+#[test]
+fn test_thread_locked_state() {
+	let mtx = Arc::new(Mutex::new(()));
+	let mtx_ref = Arc::clone(&mtx);
+	assert_eq!(mtx.held_by_thread(), LockHeldState::NotHeldByThread);
+
+	let lck = mtx.lock().unwrap();
+	assert_eq!(mtx.held_by_thread(), LockHeldState::HeldByThread);
+
+	let thrd = std::thread::spawn(move || {
+		assert_eq!(mtx_ref.held_by_thread(), LockHeldState::NotHeldByThread);
+	});
+	thrd.join().unwrap();
+	assert_eq!(mtx.held_by_thread(), LockHeldState::HeldByThread);
+
+	std::mem::drop(lck);
+	assert_eq!(mtx.held_by_thread(), LockHeldState::NotHeldByThread);
+}
diff --git a/lightning/src/util/mod.rs b/lightning/src/util/mod.rs
index 1673bd07f69..7bcbc5a41fb 100644
--- a/lightning/src/util/mod.rs
+++ b/lightning/src/util/mod.rs
@@ -27,8 +27,6 @@ pub mod wakers;
 pub(crate) mod atomic_counter;
 pub(crate) mod byte_utils;
 pub(crate) mod chacha20;
-#[cfg(all(any(feature = "_bench_unstable", not(test)), feature = "std"))]
-pub(crate) mod fairrwlock;
 #[cfg(fuzzing)]
 pub mod zbase32;
 #[cfg(not(fuzzing))]
diff --git a/lightning/src/util/test_utils.rs b/lightning/src/util/test_utils.rs
index 83647a73385..2d68e74c2e5 100644
--- a/lightning/src/util/test_utils.rs
+++ b/lightning/src/util/test_utils.rs
@@ -114,11 +114,12 @@ impl<'a> Router for TestRouter<'a> {
 	fn notify_payment_probe_failed(&self, _path: &[&RouteHop], _short_channel_id: u64) {}
 }
 
-#[cfg(feature = "std")] // If we put this on the `if`, we get "attributes are not yet allowed on `if` expressions" on 1.41.1
 impl<'a> Drop for TestRouter<'a> {
 	fn drop(&mut self) {
-		if std::thread::panicking() {
-			return;
+		#[cfg(feature = "std")] {
+			if std::thread::panicking() {
+				return;
+			}
 		}
 		assert!(self.next_routes.lock().unwrap().is_empty());
 	}
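
The `TestRouter` change above relies on a small MSRV workaround: rustc 1.41.1 rejects `#[cfg]` directly on an `if` expression (the error quoted in the removed comment), but accepts it on a block statement, so the early return is wrapped in a cfg'd block that disappears entirely in no-std builds. The shape of the workaround, in isolation (sketch):

fn drop_checks() {
	// Rejected on rustc 1.41.1 ("attributes are not yet allowed on `if` expressions"):
	// #[cfg(feature = "std")]
	// if std::thread::panicking() { return; }

	// Accepted everywhere: put the cfg on an enclosing block instead.
	#[cfg(feature = "std")] {
		if std::thread::panicking() { return; }
	}
	// ... assertions that must hold on a non-panicking drop ...
}
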