diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index f27a2ccf863..340b7f898d9 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -242,19 +242,19 @@ jobs: id: cache-graph uses: actions/cache@v3 with: - path: lightning/net_graph-2021-05-31.bin - key: ldk-net_graph-v0.0.15-2021-05-31.bin + path: lightning/net_graph-2023-01-18.bin + key: ldk-net_graph-v0.0.113-2023-01-18.bin - name: Fetch routing graph snapshot if: steps.cache-graph.outputs.cache-hit != 'true' run: | - curl --verbose -L -o lightning/net_graph-2021-05-31.bin https://bitcoin.ninja/ldk-net_graph-v0.0.15-2021-05-31.bin - echo "Sha sum: $(sha256sum lightning/net_graph-2021-05-31.bin | awk '{ print $1 }')" - if [ "$(sha256sum lightning/net_graph-2021-05-31.bin | awk '{ print $1 }')" != "${EXPECTED_ROUTING_GRAPH_SNAPSHOT_SHASUM}" ]; then + curl --verbose -L -o lightning/net_graph-2023-01-18.bin https://bitcoin.ninja/ldk-net_graph-v0.0.113-2023-01-18.bin + echo "Sha sum: $(sha256sum lightning/net_graph-2023-01-18.bin | awk '{ print $1 }')" + if [ "$(sha256sum lightning/net_graph-2023-01-18.bin | awk '{ print $1 }')" != "${EXPECTED_ROUTING_GRAPH_SNAPSHOT_SHASUM}" ]; then echo "Bad hash" exit 1 fi env: - EXPECTED_ROUTING_GRAPH_SNAPSHOT_SHASUM: 05a5361278f68ee2afd086cc04a1f927a63924be451f3221d380533acfacc303 + EXPECTED_ROUTING_GRAPH_SNAPSHOT_SHASUM: da6066f2bddcddbe7d8a6debbd53545697137b310bbb8c4911bc8c81fc5ff48c - name: Fetch rapid graph sync reference input run: | curl --verbose -L -o lightning-rapid-gossip-sync/res/full_graph.lngossip https://bitcoin.ninja/ldk-compressed_graph-285cb27df79-2022-07-21.bin diff --git a/fuzz/src/bin/gen_target.sh b/fuzz/src/bin/gen_target.sh index 95e65695eb8..fa29540f96b 100755 --- a/fuzz/src/bin/gen_target.sh +++ b/fuzz/src/bin/gen_target.sh @@ -14,6 +14,7 @@ GEN_TEST peer_crypt GEN_TEST process_network_graph GEN_TEST router GEN_TEST zbase32 +GEN_TEST indexedmap GEN_TEST msg_accept_channel msg_targets:: GEN_TEST msg_announcement_signatures msg_targets:: diff --git a/fuzz/src/bin/indexedmap_target.rs b/fuzz/src/bin/indexedmap_target.rs new file mode 100644 index 00000000000..238566d5465 --- /dev/null +++ b/fuzz/src/bin/indexedmap_target.rs @@ -0,0 +1,113 @@ +// This file is Copyright its original authors, visible in version control +// history. +// +// This file is licensed under the Apache License, Version 2.0 or the MIT license +// , at your option. +// You may not use this file except in accordance with one or both of these +// licenses. + +// This file is auto-generated by gen_target.sh based on target_template.txt +// To modify it, modify target_template.txt and run gen_target.sh instead. + +#![cfg_attr(feature = "libfuzzer_fuzz", no_main)] + +#[cfg(not(fuzzing))] +compile_error!("Fuzz targets need cfg=fuzzing"); + +extern crate lightning_fuzz; +use lightning_fuzz::indexedmap::*; + +#[cfg(feature = "afl")] +#[macro_use] extern crate afl; +#[cfg(feature = "afl")] +fn main() { + fuzz!(|data| { + indexedmap_run(data.as_ptr(), data.len()); + }); +} + +#[cfg(feature = "honggfuzz")] +#[macro_use] extern crate honggfuzz; +#[cfg(feature = "honggfuzz")] +fn main() { + loop { + fuzz!(|data| { + indexedmap_run(data.as_ptr(), data.len()); + }); + } +} + +#[cfg(feature = "libfuzzer_fuzz")] +#[macro_use] extern crate libfuzzer_sys; +#[cfg(feature = "libfuzzer_fuzz")] +fuzz_target!(|data: &[u8]| { + indexedmap_run(data.as_ptr(), data.len()); +}); + +#[cfg(feature = "stdin_fuzz")] +fn main() { + use std::io::Read; + + let mut data = Vec::with_capacity(8192); + std::io::stdin().read_to_end(&mut data).unwrap(); + indexedmap_run(data.as_ptr(), data.len()); +} + +#[test] +fn run_test_cases() { + use std::fs; + use std::io::Read; + use lightning_fuzz::utils::test_logger::StringBuffer; + + use std::sync::{atomic, Arc}; + { + let data: Vec = vec![0]; + indexedmap_run(data.as_ptr(), data.len()); + } + let mut threads = Vec::new(); + let threads_running = Arc::new(atomic::AtomicUsize::new(0)); + if let Ok(tests) = fs::read_dir("test_cases/indexedmap") { + for test in tests { + let mut data: Vec = Vec::new(); + let path = test.unwrap().path(); + fs::File::open(&path).unwrap().read_to_end(&mut data).unwrap(); + threads_running.fetch_add(1, atomic::Ordering::AcqRel); + + let thread_count_ref = Arc::clone(&threads_running); + let main_thread_ref = std::thread::current(); + threads.push((path.file_name().unwrap().to_str().unwrap().to_string(), + std::thread::spawn(move || { + let string_logger = StringBuffer::new(); + + let panic_logger = string_logger.clone(); + let res = if ::std::panic::catch_unwind(move || { + indexedmap_test(&data, panic_logger); + }).is_err() { + Some(string_logger.into_string()) + } else { None }; + thread_count_ref.fetch_sub(1, atomic::Ordering::AcqRel); + main_thread_ref.unpark(); + res + }) + )); + while threads_running.load(atomic::Ordering::Acquire) > 32 { + std::thread::park(); + } + } + } + let mut failed_outputs = Vec::new(); + for (test, thread) in threads.drain(..) { + if let Some(output) = thread.join().unwrap() { + println!("\nOutput of {}:\n{}\n", test, output); + failed_outputs.push(test); + } + } + if !failed_outputs.is_empty() { + println!("Test cases which failed: "); + for case in failed_outputs { + println!("{}", case); + } + panic!(); + } +} diff --git a/fuzz/src/bin/msg_channel_details_target.rs b/fuzz/src/bin/msg_channel_details_target.rs new file mode 100644 index 00000000000..cb5021aedfa --- /dev/null +++ b/fuzz/src/bin/msg_channel_details_target.rs @@ -0,0 +1,113 @@ +// This file is Copyright its original authors, visible in version control +// history. +// +// This file is licensed under the Apache License, Version 2.0 or the MIT license +// , at your option. +// You may not use this file except in accordance with one or both of these +// licenses. + +// This file is auto-generated by gen_target.sh based on target_template.txt +// To modify it, modify target_template.txt and run gen_target.sh instead. + +#![cfg_attr(feature = "libfuzzer_fuzz", no_main)] + +#[cfg(not(fuzzing))] +compile_error!("Fuzz targets need cfg=fuzzing"); + +extern crate lightning_fuzz; +use lightning_fuzz::msg_targets::msg_channel_details::*; + +#[cfg(feature = "afl")] +#[macro_use] extern crate afl; +#[cfg(feature = "afl")] +fn main() { + fuzz!(|data| { + msg_channel_details_run(data.as_ptr(), data.len()); + }); +} + +#[cfg(feature = "honggfuzz")] +#[macro_use] extern crate honggfuzz; +#[cfg(feature = "honggfuzz")] +fn main() { + loop { + fuzz!(|data| { + msg_channel_details_run(data.as_ptr(), data.len()); + }); + } +} + +#[cfg(feature = "libfuzzer_fuzz")] +#[macro_use] extern crate libfuzzer_sys; +#[cfg(feature = "libfuzzer_fuzz")] +fuzz_target!(|data: &[u8]| { + msg_channel_details_run(data.as_ptr(), data.len()); +}); + +#[cfg(feature = "stdin_fuzz")] +fn main() { + use std::io::Read; + + let mut data = Vec::with_capacity(8192); + std::io::stdin().read_to_end(&mut data).unwrap(); + msg_channel_details_run(data.as_ptr(), data.len()); +} + +#[test] +fn run_test_cases() { + use std::fs; + use std::io::Read; + use lightning_fuzz::utils::test_logger::StringBuffer; + + use std::sync::{atomic, Arc}; + { + let data: Vec = vec![0]; + msg_channel_details_run(data.as_ptr(), data.len()); + } + let mut threads = Vec::new(); + let threads_running = Arc::new(atomic::AtomicUsize::new(0)); + if let Ok(tests) = fs::read_dir("test_cases/msg_channel_details") { + for test in tests { + let mut data: Vec = Vec::new(); + let path = test.unwrap().path(); + fs::File::open(&path).unwrap().read_to_end(&mut data).unwrap(); + threads_running.fetch_add(1, atomic::Ordering::AcqRel); + + let thread_count_ref = Arc::clone(&threads_running); + let main_thread_ref = std::thread::current(); + threads.push((path.file_name().unwrap().to_str().unwrap().to_string(), + std::thread::spawn(move || { + let string_logger = StringBuffer::new(); + + let panic_logger = string_logger.clone(); + let res = if ::std::panic::catch_unwind(move || { + msg_channel_details_test(&data, panic_logger); + }).is_err() { + Some(string_logger.into_string()) + } else { None }; + thread_count_ref.fetch_sub(1, atomic::Ordering::AcqRel); + main_thread_ref.unpark(); + res + }) + )); + while threads_running.load(atomic::Ordering::Acquire) > 32 { + std::thread::park(); + } + } + } + let mut failed_outputs = Vec::new(); + for (test, thread) in threads.drain(..) { + if let Some(output) = thread.join().unwrap() { + println!("\nOutput of {}:\n{}\n", test, output); + failed_outputs.push(test); + } + } + if !failed_outputs.is_empty() { + println!("Test cases which failed: "); + for case in failed_outputs { + println!("{}", case); + } + panic!(); + } +} diff --git a/fuzz/src/indexedmap.rs b/fuzz/src/indexedmap.rs new file mode 100644 index 00000000000..795d6175bb5 --- /dev/null +++ b/fuzz/src/indexedmap.rs @@ -0,0 +1,137 @@ +// This file is Copyright its original authors, visible in version control +// history. +// +// This file is licensed under the Apache License, Version 2.0 or the MIT license +// , at your option. +// You may not use this file except in accordance with one or both of these +// licenses. + +use lightning::util::indexed_map::{IndexedMap, self}; +use std::collections::{BTreeMap, btree_map}; +use hashbrown::HashSet; + +use crate::utils::test_logger; + +fn check_eq(btree: &BTreeMap, indexed: &IndexedMap) { + assert_eq!(btree.len(), indexed.len()); + assert_eq!(btree.is_empty(), indexed.is_empty()); + + let mut btree_clone = btree.clone(); + assert!(btree_clone == *btree); + let mut indexed_clone = indexed.clone(); + assert!(indexed_clone == *indexed); + + for k in 0..=255 { + assert_eq!(btree.contains_key(&k), indexed.contains_key(&k)); + assert_eq!(btree.get(&k), indexed.get(&k)); + + let btree_entry = btree_clone.entry(k); + let indexed_entry = indexed_clone.entry(k); + match btree_entry { + btree_map::Entry::Occupied(mut bo) => { + if let indexed_map::Entry::Occupied(mut io) = indexed_entry { + assert_eq!(bo.get(), io.get()); + assert_eq!(bo.get_mut(), io.get_mut()); + } else { panic!(); } + }, + btree_map::Entry::Vacant(_) => { + if let indexed_map::Entry::Vacant(_) = indexed_entry { + } else { panic!(); } + } + } + } + + const STRIDE: u8 = 16; + for k in 0..=255/STRIDE { + let lower_bound = k * STRIDE; + let upper_bound = lower_bound + (STRIDE - 1); + let mut btree_iter = btree.range(lower_bound..=upper_bound); + let mut indexed_iter = indexed.range(lower_bound..=upper_bound); + loop { + let b_v = btree_iter.next(); + let i_v = indexed_iter.next(); + assert_eq!(b_v, i_v); + if b_v.is_none() { break; } + } + } + + let mut key_set = HashSet::with_capacity(256); + for k in indexed.unordered_keys() { + assert!(key_set.insert(*k)); + assert!(btree.contains_key(k)); + } + assert_eq!(key_set.len(), btree.len()); + + key_set.clear(); + for (k, v) in indexed.unordered_iter() { + assert!(key_set.insert(*k)); + assert_eq!(btree.get(k).unwrap(), v); + } + assert_eq!(key_set.len(), btree.len()); + + key_set.clear(); + for (k, v) in indexed_clone.unordered_iter_mut() { + assert!(key_set.insert(*k)); + assert_eq!(btree.get(k).unwrap(), v); + } + assert_eq!(key_set.len(), btree.len()); +} + +#[inline] +pub fn do_test(data: &[u8]) { + if data.len() % 2 != 0 { return; } + let mut btree = BTreeMap::new(); + let mut indexed = IndexedMap::new(); + + // Read in k-v pairs from the input and insert them into the maps then check that the maps are + // equivalent in every way we can read them. + for tuple in data.windows(2) { + let prev_value_b = btree.insert(tuple[0], tuple[1]); + let prev_value_i = indexed.insert(tuple[0], tuple[1]); + assert_eq!(prev_value_b, prev_value_i); + } + check_eq(&btree, &indexed); + + // Now, modify the maps in all the ways we have to do so, checking that the maps remain + // equivalent as we go. + for (k, v) in indexed.unordered_iter_mut() { + *v = *k; + *btree.get_mut(k).unwrap() = *k; + } + check_eq(&btree, &indexed); + + for k in 0..=255 { + match btree.entry(k) { + btree_map::Entry::Occupied(mut bo) => { + if let indexed_map::Entry::Occupied(mut io) = indexed.entry(k) { + if k < 64 { + *io.get_mut() ^= 0xff; + *bo.get_mut() ^= 0xff; + } else if k < 128 { + *io.into_mut() ^= 0xff; + *bo.get_mut() ^= 0xff; + } else { + assert_eq!(bo.remove_entry(), io.remove_entry()); + } + } else { panic!(); } + }, + btree_map::Entry::Vacant(bv) => { + if let indexed_map::Entry::Vacant(iv) = indexed.entry(k) { + bv.insert(k); + iv.insert(k); + } else { panic!(); } + }, + } + } + check_eq(&btree, &indexed); +} + +pub fn indexedmap_test(data: &[u8], _out: Out) { + do_test(data); +} + +#[no_mangle] +pub extern "C" fn indexedmap_run(data: *const u8, datalen: usize) { + do_test(unsafe { std::slice::from_raw_parts(data, datalen) }); +} diff --git a/fuzz/src/lib.rs b/fuzz/src/lib.rs index 2238a9702a9..462307d55b4 100644 --- a/fuzz/src/lib.rs +++ b/fuzz/src/lib.rs @@ -17,6 +17,7 @@ pub mod utils; pub mod chanmon_deser; pub mod chanmon_consistency; pub mod full_stack; +pub mod indexedmap; pub mod onion_message; pub mod peer_crypt; pub mod process_network_graph; diff --git a/fuzz/targets.h b/fuzz/targets.h index cff3f9bdbb5..5bfee07dafb 100644 --- a/fuzz/targets.h +++ b/fuzz/targets.h @@ -7,6 +7,7 @@ void peer_crypt_run(const unsigned char* data, size_t data_len); void process_network_graph_run(const unsigned char* data, size_t data_len); void router_run(const unsigned char* data, size_t data_len); void zbase32_run(const unsigned char* data, size_t data_len); +void indexedmap_run(const unsigned char* data, size_t data_len); void msg_accept_channel_run(const unsigned char* data, size_t data_len); void msg_announcement_signatures_run(const unsigned char* data, size_t data_len); void msg_channel_reestablish_run(const unsigned char* data, size_t data_len); diff --git a/lightning/src/routing/gossip.rs b/lightning/src/routing/gossip.rs index 1a5b978502c..24a21d795b2 100644 --- a/lightning/src/routing/gossip.rs +++ b/lightning/src/routing/gossip.rs @@ -32,11 +32,11 @@ use crate::util::logger::{Logger, Level}; use crate::util::events::{MessageSendEvent, MessageSendEventsProvider}; use crate::util::scid_utils::{block_from_scid, scid_from_parts, MAX_SCID_BLOCK}; use crate::util::string::PrintableString; +use crate::util::indexed_map::{IndexedMap, Entry as IndexedMapEntry}; use crate::io; use crate::io_extras::{copy, sink}; use crate::prelude::*; -use alloc::collections::{BTreeMap, btree_map::Entry as BtreeEntry}; use core::{cmp, fmt}; use crate::sync::{RwLock, RwLockReadGuard}; #[cfg(feature = "std")] @@ -133,8 +133,8 @@ pub struct NetworkGraph where L::Target: Logger { genesis_hash: BlockHash, logger: L, // Lock order: channels -> nodes - channels: RwLock>, - nodes: RwLock>, + channels: RwLock>, + nodes: RwLock>, // Lock order: removed_channels -> removed_nodes // // NOTE: In the following `removed_*` maps, we use seconds since UNIX epoch to track time instead @@ -158,8 +158,8 @@ pub struct NetworkGraph where L::Target: Logger { /// A read-only view of [`NetworkGraph`]. pub struct ReadOnlyNetworkGraph<'a> { - channels: RwLockReadGuard<'a, BTreeMap>, - nodes: RwLockReadGuard<'a, BTreeMap>, + channels: RwLockReadGuard<'a, IndexedMap>, + nodes: RwLockReadGuard<'a, IndexedMap>, } /// Update to the [`NetworkGraph`] based on payment failure information conveyed via the Onion @@ -1054,10 +1054,6 @@ impl Readable for NodeAlias { pub struct NodeInfo { /// All valid channels a node has announced pub channels: Vec, - /// Lowest fees enabling routing via any of the enabled, known channels to a node. - /// The two fields (flat and proportional fee) are independent, - /// meaning they don't have to refer to the same channel. - pub lowest_inbound_channel_fees: Option, /// More information about a node from node_announcement. /// Optional because we store a Node entry after learning about it from /// a channel announcement, but before receiving a node announcement. @@ -1066,8 +1062,8 @@ pub struct NodeInfo { impl fmt::Display for NodeInfo { fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { - write!(f, "lowest_inbound_channel_fees: {:?}, channels: {:?}, announcement_info: {:?}", - self.lowest_inbound_channel_fees, &self.channels[..], self.announcement_info)?; + write!(f, " channels: {:?}, announcement_info: {:?}", + &self.channels[..], self.announcement_info)?; Ok(()) } } @@ -1075,7 +1071,7 @@ impl fmt::Display for NodeInfo { impl Writeable for NodeInfo { fn write(&self, writer: &mut W) -> Result<(), io::Error> { write_tlv_fields!(writer, { - (0, self.lowest_inbound_channel_fees, option), + // Note that older versions of LDK wrote the lowest inbound fees here at type 0 (2, self.announcement_info, option), (4, self.channels, vec_type), }); @@ -1103,18 +1099,22 @@ impl MaybeReadable for NodeAnnouncementInfoDeserWrapper { impl Readable for NodeInfo { fn read(reader: &mut R) -> Result { - _init_tlv_field_var!(lowest_inbound_channel_fees, option); + // Historically, we tracked the lowest inbound fees for any node in order to use it as an + // A* heuristic when routing. Sadly, these days many, many nodes have at least one channel + // with zero inbound fees, causing that heuristic to provide little gain. Worse, because it + // requires additional complexity and lookups during routing, it ends up being a + // performance loss. Thus, we simply ignore the old field here and no longer track it. + let mut _lowest_inbound_channel_fees: Option = None; let mut announcement_info_wrap: Option = None; _init_tlv_field_var!(channels, vec_type); read_tlv_fields!(reader, { - (0, lowest_inbound_channel_fees, option), + (0, _lowest_inbound_channel_fees, option), (2, announcement_info_wrap, ignorable), (4, channels, vec_type), }); Ok(NodeInfo { - lowest_inbound_channel_fees: _init_tlv_based_struct_field!(lowest_inbound_channel_fees, option), announcement_info: announcement_info_wrap.map(|w| w.0), channels: _init_tlv_based_struct_field!(channels, vec_type), }) @@ -1131,13 +1131,13 @@ impl Writeable for NetworkGraph where L::Target: Logger { self.genesis_hash.write(writer)?; let channels = self.channels.read().unwrap(); (channels.len() as u64).write(writer)?; - for (ref chan_id, ref chan_info) in channels.iter() { + for (ref chan_id, ref chan_info) in channels.unordered_iter() { (*chan_id).write(writer)?; chan_info.write(writer)?; } let nodes = self.nodes.read().unwrap(); (nodes.len() as u64).write(writer)?; - for (ref node_id, ref node_info) in nodes.iter() { + for (ref node_id, ref node_info) in nodes.unordered_iter() { node_id.write(writer)?; node_info.write(writer)?; } @@ -1156,14 +1156,14 @@ impl ReadableArgs for NetworkGraph where L::Target: Logger { let genesis_hash: BlockHash = Readable::read(reader)?; let channels_count: u64 = Readable::read(reader)?; - let mut channels = BTreeMap::new(); + let mut channels = IndexedMap::new(); for _ in 0..channels_count { let chan_id: u64 = Readable::read(reader)?; let chan_info = Readable::read(reader)?; channels.insert(chan_id, chan_info); } let nodes_count: u64 = Readable::read(reader)?; - let mut nodes = BTreeMap::new(); + let mut nodes = IndexedMap::new(); for _ in 0..nodes_count { let node_id = Readable::read(reader)?; let node_info = Readable::read(reader)?; @@ -1191,11 +1191,11 @@ impl ReadableArgs for NetworkGraph where L::Target: Logger { impl fmt::Display for NetworkGraph where L::Target: Logger { fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { writeln!(f, "Network map\n[Channels]")?; - for (key, val) in self.channels.read().unwrap().iter() { + for (key, val) in self.channels.read().unwrap().unordered_iter() { writeln!(f, " {}: {}", key, val)?; } writeln!(f, "[Nodes]")?; - for (&node_id, val) in self.nodes.read().unwrap().iter() { + for (&node_id, val) in self.nodes.read().unwrap().unordered_iter() { writeln!(f, " {}: {}", log_bytes!(node_id.as_slice()), val)?; } Ok(()) @@ -1218,8 +1218,8 @@ impl NetworkGraph where L::Target: Logger { secp_ctx: Secp256k1::verification_only(), genesis_hash, logger, - channels: RwLock::new(BTreeMap::new()), - nodes: RwLock::new(BTreeMap::new()), + channels: RwLock::new(IndexedMap::new()), + nodes: RwLock::new(IndexedMap::new()), last_rapid_gossip_sync_timestamp: Mutex::new(None), removed_channels: Mutex::new(HashMap::new()), removed_nodes: Mutex::new(HashMap::new()), @@ -1252,7 +1252,7 @@ impl NetworkGraph where L::Target: Logger { /// purposes. #[cfg(test)] pub fn clear_nodes_announcement_info(&self) { - for node in self.nodes.write().unwrap().iter_mut() { + for node in self.nodes.write().unwrap().unordered_iter_mut() { node.1.announcement_info = None; } } @@ -1382,7 +1382,7 @@ impl NetworkGraph where L::Target: Logger { let node_id_b = channel_info.node_two.clone(); match channels.entry(short_channel_id) { - BtreeEntry::Occupied(mut entry) => { + IndexedMapEntry::Occupied(mut entry) => { //TODO: because asking the blockchain if short_channel_id is valid is only optional //in the blockchain API, we need to handle it smartly here, though it's unclear //exactly how... @@ -1401,20 +1401,19 @@ impl NetworkGraph where L::Target: Logger { return Err(LightningError{err: "Already have knowledge of channel".to_owned(), action: ErrorAction::IgnoreDuplicateGossip}); } }, - BtreeEntry::Vacant(entry) => { + IndexedMapEntry::Vacant(entry) => { entry.insert(channel_info); } }; for current_node_id in [node_id_a, node_id_b].iter() { match nodes.entry(current_node_id.clone()) { - BtreeEntry::Occupied(node_entry) => { + IndexedMapEntry::Occupied(node_entry) => { node_entry.into_mut().channels.push(short_channel_id); }, - BtreeEntry::Vacant(node_entry) => { + IndexedMapEntry::Vacant(node_entry) => { node_entry.insert(NodeInfo { channels: vec!(short_channel_id), - lowest_inbound_channel_fees: None, announcement_info: None, }); } @@ -1586,7 +1585,7 @@ impl NetworkGraph where L::Target: Logger { for scid in node.channels.iter() { if let Some(chan_info) = channels.remove(scid) { let other_node_id = if node_id == chan_info.node_one { chan_info.node_two } else { chan_info.node_one }; - if let BtreeEntry::Occupied(mut other_node_entry) = nodes.entry(other_node_id) { + if let IndexedMapEntry::Occupied(mut other_node_entry) = nodes.entry(other_node_id) { other_node_entry.get_mut().channels.retain(|chan_id| { *scid != *chan_id }); @@ -1645,7 +1644,7 @@ impl NetworkGraph where L::Target: Logger { // Sadly BTreeMap::retain was only stabilized in 1.53 so we can't switch to it for some // time. let mut scids_to_remove = Vec::new(); - for (scid, info) in channels.iter_mut() { + for (scid, info) in channels.unordered_iter_mut() { if info.one_to_two.is_some() && info.one_to_two.as_ref().unwrap().last_update < min_time_unix { info.one_to_two = None; } @@ -1715,9 +1714,7 @@ impl NetworkGraph where L::Target: Logger { } fn update_channel_intern(&self, msg: &msgs::UnsignedChannelUpdate, full_msg: Option<&msgs::ChannelUpdate>, sig: Option<&secp256k1::ecdsa::Signature>) -> Result<(), LightningError> { - let dest_node_id; let chan_enabled = msg.flags & (1 << 1) != (1 << 1); - let chan_was_enabled; #[cfg(all(feature = "std", not(test), not(feature = "_test_utils")))] { @@ -1765,9 +1762,6 @@ impl NetworkGraph where L::Target: Logger { } else if existing_chan_info.last_update == msg.timestamp { return Err(LightningError{err: "Update had same timestamp as last processed update".to_owned(), action: ErrorAction::IgnoreDuplicateGossip}); } - chan_was_enabled = existing_chan_info.enabled; - } else { - chan_was_enabled = false; } } } @@ -1795,7 +1789,6 @@ impl NetworkGraph where L::Target: Logger { let msg_hash = hash_to_message!(&Sha256dHash::hash(&msg.encode()[..])[..]); if msg.flags & 1 == 1 { - dest_node_id = channel.node_one.clone(); check_update_latest!(channel.two_to_one); if let Some(sig) = sig { secp_verify_sig!(self.secp_ctx, &msg_hash, &sig, &PublicKey::from_slice(channel.node_two.as_slice()).map_err(|_| LightningError{ @@ -1805,7 +1798,6 @@ impl NetworkGraph where L::Target: Logger { } channel.two_to_one = get_new_channel_info!(); } else { - dest_node_id = channel.node_two.clone(); check_update_latest!(channel.one_to_two); if let Some(sig) = sig { secp_verify_sig!(self.secp_ctx, &msg_hash, &sig, &PublicKey::from_slice(channel.node_one.as_slice()).map_err(|_| LightningError{ @@ -1818,51 +1810,13 @@ impl NetworkGraph where L::Target: Logger { } } - let mut nodes = self.nodes.write().unwrap(); - if chan_enabled { - let node = nodes.get_mut(&dest_node_id).unwrap(); - let mut base_msat = msg.fee_base_msat; - let mut proportional_millionths = msg.fee_proportional_millionths; - if let Some(fees) = node.lowest_inbound_channel_fees { - base_msat = cmp::min(base_msat, fees.base_msat); - proportional_millionths = cmp::min(proportional_millionths, fees.proportional_millionths); - } - node.lowest_inbound_channel_fees = Some(RoutingFees { - base_msat, - proportional_millionths - }); - } else if chan_was_enabled { - let node = nodes.get_mut(&dest_node_id).unwrap(); - let mut lowest_inbound_channel_fees = None; - - for chan_id in node.channels.iter() { - let chan = channels.get(chan_id).unwrap(); - let chan_info_opt; - if chan.node_one == dest_node_id { - chan_info_opt = chan.two_to_one.as_ref(); - } else { - chan_info_opt = chan.one_to_two.as_ref(); - } - if let Some(chan_info) = chan_info_opt { - if chan_info.enabled { - let fees = lowest_inbound_channel_fees.get_or_insert(RoutingFees { - base_msat: u32::max_value(), proportional_millionths: u32::max_value() }); - fees.base_msat = cmp::min(fees.base_msat, chan_info.fees.base_msat); - fees.proportional_millionths = cmp::min(fees.proportional_millionths, chan_info.fees.proportional_millionths); - } - } - } - - node.lowest_inbound_channel_fees = lowest_inbound_channel_fees; - } - Ok(()) } - fn remove_channel_in_nodes(nodes: &mut BTreeMap, chan: &ChannelInfo, short_channel_id: u64) { + fn remove_channel_in_nodes(nodes: &mut IndexedMap, chan: &ChannelInfo, short_channel_id: u64) { macro_rules! remove_from_node { ($node_id: expr) => { - if let BtreeEntry::Occupied(mut entry) = nodes.entry($node_id) { + if let IndexedMapEntry::Occupied(mut entry) = nodes.entry($node_id) { entry.get_mut().channels.retain(|chan_id| { short_channel_id != *chan_id }); @@ -1883,8 +1837,8 @@ impl NetworkGraph where L::Target: Logger { impl ReadOnlyNetworkGraph<'_> { /// Returns all known valid channels' short ids along with announced channel info. /// - /// (C-not exported) because we have no mapping for `BTreeMap`s - pub fn channels(&self) -> &BTreeMap { + /// (C-not exported) because we don't want to return lifetime'd references + pub fn channels(&self) -> &IndexedMap { &*self.channels } @@ -1896,13 +1850,13 @@ impl ReadOnlyNetworkGraph<'_> { #[cfg(c_bindings)] // Non-bindings users should use `channels` /// Returns the list of channels in the graph pub fn list_channels(&self) -> Vec { - self.channels.keys().map(|c| *c).collect() + self.channels.unordered_keys().map(|c| *c).collect() } /// Returns all known nodes' public keys along with announced node info. /// - /// (C-not exported) because we have no mapping for `BTreeMap`s - pub fn nodes(&self) -> &BTreeMap { + /// (C-not exported) because we don't want to return lifetime'd references + pub fn nodes(&self) -> &IndexedMap { &*self.nodes } @@ -1914,7 +1868,7 @@ impl ReadOnlyNetworkGraph<'_> { #[cfg(c_bindings)] // Non-bindings users should use `nodes` /// Returns the list of nodes in the graph pub fn list_nodes(&self) -> Vec { - self.nodes.keys().map(|n| *n).collect() + self.nodes.unordered_keys().map(|n| *n).collect() } /// Get network addresses by node id. @@ -3275,7 +3229,6 @@ mod tests { // 2. Check we can read a NodeInfo anyways, but set the NodeAnnouncementInfo to None if invalid let valid_node_info = NodeInfo { channels: Vec::new(), - lowest_inbound_channel_fees: None, announcement_info: Some(valid_node_ann_info), }; diff --git a/lightning/src/routing/router.rs b/lightning/src/routing/router.rs index c15b612d939..8543956ac65 100644 --- a/lightning/src/routing/router.rs +++ b/lightning/src/routing/router.rs @@ -582,7 +582,6 @@ impl_writeable_tlv_based!(RouteHintHop, { #[derive(Eq, PartialEq)] struct RouteGraphNode { node_id: NodeId, - lowest_fee_to_peer_through_node: u64, lowest_fee_to_node: u64, total_cltv_delta: u32, // The maximum value a yet-to-be-constructed payment path might flow through this node. @@ -603,9 +602,9 @@ struct RouteGraphNode { impl cmp::Ord for RouteGraphNode { fn cmp(&self, other: &RouteGraphNode) -> cmp::Ordering { - let other_score = cmp::max(other.lowest_fee_to_peer_through_node, other.path_htlc_minimum_msat) + let other_score = cmp::max(other.lowest_fee_to_node, other.path_htlc_minimum_msat) .saturating_add(other.path_penalty_msat); - let self_score = cmp::max(self.lowest_fee_to_peer_through_node, self.path_htlc_minimum_msat) + let self_score = cmp::max(self.lowest_fee_to_node, self.path_htlc_minimum_msat) .saturating_add(self.path_penalty_msat); other_score.cmp(&self_score).then_with(|| other.node_id.cmp(&self.node_id)) } @@ -729,8 +728,6 @@ struct PathBuildingHop<'a> { candidate: CandidateRouteHop<'a>, fee_msat: u64, - /// Minimal fees required to route to the source node of the current hop via any of its inbound channels. - src_lowest_inbound_fees: RoutingFees, /// All the fees paid *after* this channel on the way to the destination next_hops_fee_msat: u64, /// Fee paid for the use of the current channel (see candidate.fees()). @@ -888,18 +885,20 @@ impl<'a> PaymentPath<'a> { } } +#[inline(always)] +/// Calculate the fees required to route the given amount over a channel with the given fees. fn compute_fees(amount_msat: u64, channel_fees: RoutingFees) -> Option { - let proportional_fee_millions = - amount_msat.checked_mul(channel_fees.proportional_millionths as u64); - if let Some(new_fee) = proportional_fee_millions.and_then(|part| { - (channel_fees.base_msat as u64).checked_add(part / 1_000_000) }) { + amount_msat.checked_mul(channel_fees.proportional_millionths as u64) + .and_then(|part| (channel_fees.base_msat as u64).checked_add(part / 1_000_000)) +} - Some(new_fee) - } else { - // This function may be (indirectly) called without any verification, - // with channel_fees provided by a caller. We should handle it gracefully. - None - } +#[inline(always)] +/// Calculate the fees required to route the given amount over a channel with the given fees, +/// saturating to [`u64::max_value`]. +fn compute_fees_saturating(amount_msat: u64, channel_fees: RoutingFees) -> u64 { + amount_msat.checked_mul(channel_fees.proportional_millionths as u64) + .map(|prop| prop / 1_000_000).unwrap_or(u64::max_value()) + .saturating_add(channel_fees.base_msat as u64) } /// The default `features` we assume for a node in a route, when no `features` are known about that @@ -1007,9 +1006,8 @@ where L::Target: Logger { // 8. If our maximum channel saturation limit caused us to pick two identical paths, combine // them so that we're not sending two HTLCs along the same path. - // As for the actual search algorithm, - // we do a payee-to-payer pseudo-Dijkstra's sorting by each node's distance from the payee - // plus the minimum per-HTLC fee to get from it to another node (aka "shitty pseudo-A*"). + // As for the actual search algorithm, we do a payee-to-payer Dijkstra's sorting by each node's + // distance from the payee // // We are not a faithful Dijkstra's implementation because we can change values which impact // earlier nodes while processing later nodes. Specifically, if we reach a channel with a lower @@ -1044,10 +1042,6 @@ where L::Target: Logger { // runtime for little gain. Specifically, the current algorithm rather efficiently explores the // graph for candidate paths, calculating the maximum value which can realistically be sent at // the same time, remaining generic across different payment values. - // - // TODO: There are a few tweaks we could do, including possibly pre-calculating more stuff - // to use as the A* heuristic beyond just the cost to get one node further than the current - // one. let network_channels = network_graph.channels(); let network_nodes = network_graph.nodes(); @@ -1097,7 +1091,7 @@ where L::Target: Logger { } } - // The main heap containing all candidate next-hops sorted by their score (max(A* fee, + // The main heap containing all candidate next-hops sorted by their score (max(fee, // htlc_minimum)). Ideally this would be a heap which allowed cheap score reduction instead of // adding duplicate entries when we find a better path to a given node. let mut targets: BinaryHeap = BinaryHeap::new(); @@ -1262,10 +1256,10 @@ where L::Target: Logger { // might violate htlc_minimum_msat on the hops which are next along the // payment path (upstream to the payee). To avoid that, we recompute // path fees knowing the final path contribution after constructing it. - let path_htlc_minimum_msat = compute_fees($next_hops_path_htlc_minimum_msat, $candidate.fees()) - .and_then(|fee_msat| fee_msat.checked_add($next_hops_path_htlc_minimum_msat)) - .map(|fee_msat| cmp::max(fee_msat, $candidate.htlc_minimum_msat())) - .unwrap_or_else(|| u64::max_value()); + let path_htlc_minimum_msat = cmp::max( + compute_fees_saturating($next_hops_path_htlc_minimum_msat, $candidate.fees()) + .saturating_add($next_hops_path_htlc_minimum_msat), + $candidate.htlc_minimum_msat()); let hm_entry = dist.entry($src_node_id); let old_entry = hm_entry.or_insert_with(|| { // If there was previously no known way to access the source node @@ -1273,20 +1267,10 @@ where L::Target: Logger { // semi-dummy record just to compute the fees to reach the source node. // This will affect our decision on selecting short_channel_id // as a way to reach the $dest_node_id. - let mut fee_base_msat = 0; - let mut fee_proportional_millionths = 0; - if let Some(Some(fees)) = network_nodes.get(&$src_node_id).map(|node| node.lowest_inbound_channel_fees) { - fee_base_msat = fees.base_msat; - fee_proportional_millionths = fees.proportional_millionths; - } PathBuildingHop { node_id: $dest_node_id.clone(), candidate: $candidate.clone(), fee_msat: 0, - src_lowest_inbound_fees: RoutingFees { - base_msat: fee_base_msat, - proportional_millionths: fee_proportional_millionths, - }, next_hops_fee_msat: u64::max_value(), hop_use_fee_msat: u64::max_value(), total_fee_msat: u64::max_value(), @@ -1309,38 +1293,15 @@ where L::Target: Logger { if should_process { let mut hop_use_fee_msat = 0; - let mut total_fee_msat = $next_hops_fee_msat; + let mut total_fee_msat: u64 = $next_hops_fee_msat; // Ignore hop_use_fee_msat for channel-from-us as we assume all channels-from-us // will have the same effective-fee if $src_node_id != our_node_id { - match compute_fees(amount_to_transfer_over_msat, $candidate.fees()) { - // max_value means we'll always fail - // the old_entry.total_fee_msat > total_fee_msat check - None => total_fee_msat = u64::max_value(), - Some(fee_msat) => { - hop_use_fee_msat = fee_msat; - total_fee_msat += hop_use_fee_msat; - // When calculating the lowest inbound fees to a node, we - // calculate fees here not based on the actual value we think - // will flow over this channel, but on the minimum value that - // we'll accept flowing over it. The minimum accepted value - // is a constant through each path collection run, ensuring - // consistent basis. Otherwise we may later find a - // different path to the source node that is more expensive, - // but which we consider to be cheaper because we are capacity - // constrained and the relative fee becomes lower. - match compute_fees(minimal_value_contribution_msat, old_entry.src_lowest_inbound_fees) - .map(|a| a.checked_add(total_fee_msat)) { - Some(Some(v)) => { - total_fee_msat = v; - }, - _ => { - total_fee_msat = u64::max_value(); - } - }; - } - } + // Note that `u64::max_value` means we'll always fail the + // `old_entry.total_fee_msat > total_fee_msat` check below + hop_use_fee_msat = compute_fees_saturating(amount_to_transfer_over_msat, $candidate.fees()); + total_fee_msat = total_fee_msat.saturating_add(hop_use_fee_msat); } let channel_usage = ChannelUsage { @@ -1355,8 +1316,7 @@ where L::Target: Logger { .saturating_add(channel_penalty_msat); let new_graph_node = RouteGraphNode { node_id: $src_node_id, - lowest_fee_to_peer_through_node: total_fee_msat, - lowest_fee_to_node: $next_hops_fee_msat as u64 + hop_use_fee_msat, + lowest_fee_to_node: total_fee_msat, total_cltv_delta: hop_total_cltv_delta, value_contribution_msat, path_htlc_minimum_msat, @@ -5544,9 +5504,9 @@ mod tests { 'load_endpoints: for _ in 0..10 { loop { seed = seed.overflowing_mul(0xdeadbeef).0; - let src = &PublicKey::from_slice(nodes.keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap(); + let src = &PublicKey::from_slice(nodes.unordered_keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap(); seed = seed.overflowing_mul(0xdeadbeef).0; - let dst = PublicKey::from_slice(nodes.keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap(); + let dst = PublicKey::from_slice(nodes.unordered_keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap(); let payment_params = PaymentParameters::from_node_id(dst); let amt = seed as u64 % 200_000_000; let params = ProbabilisticScoringParameters::default(); @@ -5582,9 +5542,9 @@ mod tests { 'load_endpoints: for _ in 0..10 { loop { seed = seed.overflowing_mul(0xdeadbeef).0; - let src = &PublicKey::from_slice(nodes.keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap(); + let src = &PublicKey::from_slice(nodes.unordered_keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap(); seed = seed.overflowing_mul(0xdeadbeef).0; - let dst = PublicKey::from_slice(nodes.keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap(); + let dst = PublicKey::from_slice(nodes.unordered_keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap(); let payment_params = PaymentParameters::from_node_id(dst).with_features(channelmanager::provided_invoice_features(&config)); let amt = seed as u64 % 200_000_000; let params = ProbabilisticScoringParameters::default(); @@ -5639,8 +5599,8 @@ pub(crate) mod bench_utils { use std::fs::File; /// Tries to open a network graph file, or panics with a URL to fetch it. pub(crate) fn get_route_file() -> Result { - let res = File::open("net_graph-2021-05-31.bin") // By default we're run in RL/lightning - .or_else(|_| File::open("lightning/net_graph-2021-05-31.bin")) // We may be run manually in RL/ + let res = File::open("net_graph-2023-01-18.bin") // By default we're run in RL/lightning + .or_else(|_| File::open("lightning/net_graph-2023-01-18.bin")) // We may be run manually in RL/ .or_else(|_| { // Fall back to guessing based on the binary location // path is likely something like .../rust-lightning/target/debug/deps/lightning-... let mut path = std::env::current_exe().unwrap(); @@ -5649,11 +5609,11 @@ pub(crate) mod bench_utils { path.pop(); // debug path.pop(); // target path.push("lightning"); - path.push("net_graph-2021-05-31.bin"); + path.push("net_graph-2023-01-18.bin"); eprintln!("{}", path.to_str().unwrap()); File::open(path) }) - .map_err(|_| "Please fetch https://bitcoin.ninja/ldk-net_graph-v0.0.15-2021-05-31.bin and place it at lightning/net_graph-2021-05-31.bin"); + .map_err(|_| "Please fetch https://bitcoin.ninja/ldk-net_graph-v0.0.113-2023-01-18.bin and place it at lightning/net_graph-2023-01-18.bin"); #[cfg(require_route_graph_test)] return Ok(res.unwrap()); #[cfg(not(require_route_graph_test))] @@ -5782,9 +5742,9 @@ mod benches { 'load_endpoints: for _ in 0..150 { loop { seed *= 0xdeadbeef; - let src = PublicKey::from_slice(nodes.keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap(); + let src = PublicKey::from_slice(nodes.unordered_keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap(); seed *= 0xdeadbeef; - let dst = PublicKey::from_slice(nodes.keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap(); + let dst = PublicKey::from_slice(nodes.unordered_keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap(); let params = PaymentParameters::from_node_id(dst).with_features(features.clone()); let first_hop = first_hop(src); let amt = seed as u64 % 1_000_000; diff --git a/lightning/src/util/indexed_map.rs b/lightning/src/util/indexed_map.rs new file mode 100644 index 00000000000..cccbfe7bc7a --- /dev/null +++ b/lightning/src/util/indexed_map.rs @@ -0,0 +1,203 @@ +//! This module has a map which can be iterated in a deterministic order. See the [`IndexedMap`]. + +use crate::prelude::{HashMap, hash_map}; +use alloc::collections::{BTreeSet, btree_set}; +use core::hash::Hash; +use core::cmp::Ord; +use core::ops::RangeBounds; + +/// A map which can be iterated in a deterministic order. +/// +/// This would traditionally be accomplished by simply using a [`BTreeMap`], however B-Trees +/// generally have very slow lookups. Because we use a nodes+channels map while finding routes +/// across the network graph, our network graph backing map must be as performant as possible. +/// However, because peers expect to sync the network graph from us (and we need to support that +/// without holding a lock on the graph for the duration of the sync or dumping the entire graph +/// into our outbound message queue), we need an iterable map with a consistent iteration order we +/// can jump to a starting point on. +/// +/// Thus, we have a custom data structure here - its API mimics that of Rust's [`BTreeMap`], but is +/// actually backed by a [`HashMap`], with some additional tracking to ensure we can iterate over +/// keys in the order defined by [`Ord`]. +/// +/// [`BTreeMap`]: alloc::collections::BTreeMap +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct IndexedMap { + map: HashMap, + // TODO: Explore swapping this for a sorted vec (that is only sorted on first range() call) + keys: BTreeSet, +} + +impl IndexedMap { + /// Constructs a new, empty map + pub fn new() -> Self { + Self { + map: HashMap::new(), + keys: BTreeSet::new(), + } + } + + #[inline(always)] + /// Fetches the element with the given `key`, if one exists. + pub fn get(&self, key: &K) -> Option<&V> { + self.map.get(key) + } + + /// Fetches a mutable reference to the element with the given `key`, if one exists. + pub fn get_mut(&mut self, key: &K) -> Option<&mut V> { + self.map.get_mut(key) + } + + #[inline] + /// Returns true if an element with the given `key` exists in the map. + pub fn contains_key(&self, key: &K) -> bool { + self.map.contains_key(key) + } + + /// Removes the element with the given `key`, returning it, if one exists. + pub fn remove(&mut self, key: &K) -> Option { + let ret = self.map.remove(key); + if let Some(_) = ret { + assert!(self.keys.remove(key), "map and keys must be consistent"); + } + ret + } + + /// Inserts the given `key`/`value` pair into the map, returning the element that was + /// previously stored at the given `key`, if one exists. + pub fn insert(&mut self, key: K, value: V) -> Option { + let ret = self.map.insert(key.clone(), value); + if ret.is_none() { + assert!(self.keys.insert(key), "map and keys must be consistent"); + } + ret + } + + /// Returns an [`Entry`] for the given `key` in the map, allowing access to the value. + pub fn entry(&mut self, key: K) -> Entry<'_, K, V> { + match self.map.entry(key.clone()) { + hash_map::Entry::Vacant(entry) => { + Entry::Vacant(VacantEntry { + underlying_entry: entry, + key, + keys: &mut self.keys, + }) + }, + hash_map::Entry::Occupied(entry) => { + Entry::Occupied(OccupiedEntry { + underlying_entry: entry, + keys: &mut self.keys, + }) + } + } + } + + /// Returns an iterator which iterates over the keys in the map, in a random order. + pub fn unordered_keys(&self) -> impl Iterator { + self.map.keys() + } + + /// Returns an iterator which iterates over the `key`/`value` pairs in a random order. + pub fn unordered_iter(&self) -> impl Iterator { + self.map.iter() + } + + /// Returns an iterator which iterates over the `key`s and mutable references to `value`s in a + /// random order. + pub fn unordered_iter_mut(&mut self) -> impl Iterator { + self.map.iter_mut() + } + + /// Returns an iterator which iterates over the `key`/`value` pairs in a given range. + pub fn range>(&self, range: R) -> Range { + Range { + inner_range: self.keys.range(range), + map: &self.map, + } + } + + /// Returns the number of `key`/`value` pairs in the map + pub fn len(&self) -> usize { + self.map.len() + } + + /// Returns true if there are no elements in the map + pub fn is_empty(&self) -> bool { + self.map.is_empty() + } +} + +/// An iterator over a range of values in an [`IndexedMap`] +pub struct Range<'a, K: Hash + Ord, V> { + inner_range: btree_set::Range<'a, K>, + map: &'a HashMap, +} +impl<'a, K: Hash + Ord, V: 'a> Iterator for Range<'a, K, V> { + type Item = (&'a K, &'a V); + fn next(&mut self) -> Option<(&'a K, &'a V)> { + self.inner_range.next().map(|k| { + (k, self.map.get(k).expect("map and keys must be consistent")) + }) + } +} + +/// An [`Entry`] for a key which currently has no value +pub struct VacantEntry<'a, K: Hash + Ord, V> { + #[cfg(feature = "hashbrown")] + underlying_entry: hash_map::VacantEntry<'a, K, V, hash_map::DefaultHashBuilder>, + #[cfg(not(feature = "hashbrown"))] + underlying_entry: hash_map::VacantEntry<'a, K, V>, + key: K, + keys: &'a mut BTreeSet, +} + +/// An [`Entry`] for an existing key-value pair +pub struct OccupiedEntry<'a, K: Hash + Ord, V> { + #[cfg(feature = "hashbrown")] + underlying_entry: hash_map::OccupiedEntry<'a, K, V, hash_map::DefaultHashBuilder>, + #[cfg(not(feature = "hashbrown"))] + underlying_entry: hash_map::OccupiedEntry<'a, K, V>, + keys: &'a mut BTreeSet, +} + +/// A mutable reference to a position in the map. This can be used to reference, add, or update the +/// value at a fixed key. +pub enum Entry<'a, K: Hash + Ord, V> { + /// A mutable reference to a position within the map where there is no value. + Vacant(VacantEntry<'a, K, V>), + /// A mutable reference to a position within the map where there is currently a value. + Occupied(OccupiedEntry<'a, K, V>), +} + +impl<'a, K: Hash + Ord, V> VacantEntry<'a, K, V> { + /// Insert a value into the position described by this entry. + pub fn insert(self, value: V) -> &'a mut V { + assert!(self.keys.insert(self.key), "map and keys must be consistent"); + self.underlying_entry.insert(value) + } +} + +impl<'a, K: Hash + Ord, V> OccupiedEntry<'a, K, V> { + /// Remove the value at the position described by this entry. + pub fn remove_entry(self) -> (K, V) { + let res = self.underlying_entry.remove_entry(); + assert!(self.keys.remove(&res.0), "map and keys must be consistent"); + res + } + + /// Get a reference to the value at the position described by this entry. + pub fn get(&self) -> &V { + self.underlying_entry.get() + } + + /// Get a mutable reference to the value at the position described by this entry. + pub fn get_mut(&mut self) -> &mut V { + self.underlying_entry.get_mut() + } + + /// Consume this entry, returning a mutable reference to the value at the position described by + /// this entry. + pub fn into_mut(self) -> &'a mut V { + self.underlying_entry.into_mut() + } +} diff --git a/lightning/src/util/mod.rs b/lightning/src/util/mod.rs index 1d46865b601..1673bd07f69 100644 --- a/lightning/src/util/mod.rs +++ b/lightning/src/util/mod.rs @@ -40,6 +40,8 @@ pub(crate) mod transaction_utils; pub(crate) mod scid_utils; pub(crate) mod time; +pub mod indexed_map; + /// Logging macro utilities. #[macro_use] pub(crate) mod macro_logger;