Skip to content

Commit b07e707

Browse files
WIP: try pushing scids on the heap instead
1 parent 493c4c2 commit b07e707

File tree

1 file changed

+67
-32
lines changed

1 file changed

+67
-32
lines changed

lightning/src/routing/onion_message.rs

Lines changed: 67 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -54,48 +54,43 @@ pub fn find_path<L: Deref, GL: Deref>(
5454
let start = NodeId::from_pubkey(&our_node_pubkey);
5555
let mut valid_first_hops = HashSet::new();
5656
let mut frontier = BinaryHeap::new();
57-
frontier.push(PathBuildingHop { cost: 0, node_id: start, parent_node_id: start });
57+
let mut visited = HashMap::new();
5858
if let Some(first_hops) = first_hops {
5959
for hop in first_hops {
60+
if &hop.counterparty.node_id == destination { return Ok(vec![*destination]) }
6061
if hop.counterparty.node_id == *our_node_pubkey { return Err(Error::InvalidFirstHop) }
6162
#[cfg(not(feature = "_bench_unstable"))]
6263
if !hop.counterparty.features.supports_onion_messages() { continue; }
6364
let node_id = NodeId::from_pubkey(&hop.counterparty.node_id);
64-
frontier.push(PathBuildingHop { cost: 1, node_id, parent_node_id: start });
65+
match visited.entry(node_id) {
66+
hash_map::Entry::Occupied(_) => continue,
67+
hash_map::Entry::Vacant(e) => { e.insert(start); },
68+
};
69+
if let Some(node_info) = network_nodes.get(&node_id) {
70+
for scid in &node_info.channels {
71+
if let Some(chan_info) = network_channels.get(&scid) {
72+
if let Some((directed_channel, successor)) = chan_info.as_directed_from(&node_id) {
73+
if *successor == start { continue } // TODO: test
74+
if directed_channel.direction().enabled {
75+
frontier.push(PathBuildingHop {
76+
cost: 1, scid: *scid, one_to_two: chan_info.node_one == node_id, parent_scid: hop.short_channel_id.unwrap_or(0),
77+
});
78+
}
79+
}
80+
}
81+
}
82+
}
6583
valid_first_hops.insert(node_id);
6684
}
6785
}
68-
69-
let mut visited = HashMap::new();
70-
while let Some(PathBuildingHop { cost, node_id, parent_node_id }) = frontier.pop() {
71-
match visited.entry(node_id) {
72-
hash_map::Entry::Occupied(_) => continue,
73-
hash_map::Entry::Vacant(e) => e.insert(parent_node_id),
74-
};
75-
if node_id == dest_node_id {
76-
let path = reverse_path(visited, our_node_id, dest_node_id)?;
77-
log_info!(logger, "Got route to {:?}: {:?}", destination, path);
78-
return Ok(path)
79-
}
80-
if let Some(node_info) = network_nodes.get(&node_id) {
81-
// Only consider the network graph if first_hops does not override it.
82-
if valid_first_hops.contains(&node_id) || node_id == our_node_id {
83-
} else if let Some(node_ann) = &node_info.announcement_info {
84-
#[cfg(not(feature = "_bench_unstable"))]
85-
if !node_ann.features.supports_onion_messages() || node_ann.features.requires_unknown_bits()
86-
{ continue; }
87-
} else { continue; }
86+
if frontier.is_empty() {
87+
if let Some(node_info) = network_nodes.get(&start) {
8888
for scid in &node_info.channels {
8989
if let Some(chan_info) = network_channels.get(&scid) {
90-
if let Some((directed_channel, successor)) = chan_info.as_directed_from(&node_id) {
90+
if let Some((directed_channel, successor)) = chan_info.as_directed_from(&start) {
9191
if directed_channel.direction().enabled {
92-
// We may push a given successor multiple times, but the heap should sort its best
93-
// entry to the top. We do this because there is no way to adjust the priority of an
94-
// existing entry in `BinaryHeap`.
9592
frontier.push(PathBuildingHop {
96-
cost: cost + 1,
97-
node_id: *successor,
98-
parent_node_id: node_id,
93+
cost: 1, scid: *scid, one_to_two: chan_info.node_one == start, parent_scid: 0,
9994
});
10095
}
10196
}
@@ -104,6 +99,45 @@ pub fn find_path<L: Deref, GL: Deref>(
10499
}
105100
}
106101

102+
while let Some(PathBuildingHop { cost, scid, one_to_two, parent_scid }) = frontier.pop() {
103+
if let Some(chan_info) = network_channels.get(&scid) {
104+
let directed_from_node_id = if one_to_two { chan_info.node_one } else { chan_info.node_two };
105+
let directed_to_node_id = if one_to_two { chan_info.node_two } else { chan_info.node_one };
106+
match visited.entry(directed_to_node_id) {
107+
hash_map::Entry::Occupied(_) => continue,
108+
hash_map::Entry::Vacant(e) => e.insert(directed_from_node_id),
109+
};
110+
if directed_to_node_id == dest_node_id {
111+
let path = reverse_path(visited, our_node_id, dest_node_id)?;
112+
log_info!(logger, "Got route to {:?}: {:?}", destination, path);
113+
return Ok(path)
114+
}
115+
if let Some(node_info) = network_nodes.get(&directed_to_node_id) {
116+
// Only consider the network graph if first_hops does not override it.
117+
if valid_first_hops.contains(&directed_to_node_id) || directed_to_node_id == our_node_id {
118+
} else if let Some(node_ann) = &node_info.announcement_info {
119+
#[cfg(not(feature = "_bench_unstable"))]
120+
if !node_ann.features.supports_onion_messages() || node_ann.features.requires_unknown_bits()
121+
{ continue; }
122+
} else { continue; }
123+
for scid_to_push in &node_info.channels {
124+
if let Some(chan_info) = network_channels.get(&scid_to_push) {
125+
if let Some((directed_channel, successor)) = chan_info.as_directed_from(&directed_to_node_id) {
126+
if directed_channel.direction().enabled {
127+
let one_to_two = if let Some(chan_info) = network_channels.get(&scid_to_push) {
128+
directed_to_node_id == chan_info.node_one
129+
} else { continue };
130+
frontier.push(PathBuildingHop {
131+
cost: cost + 1, scid: *scid_to_push, parent_scid: scid, one_to_two,
132+
});
133+
}
134+
}
135+
}
136+
}
137+
}
138+
}
139+
}
140+
107141
Err(Error::PathNotFound)
108142
}
109143

@@ -138,8 +172,9 @@ impl std::error::Error for Error {}
138172
#[derive(Eq, PartialEq)]
139173
struct PathBuildingHop {
140174
cost: u64,
141-
node_id: NodeId,
142-
parent_node_id: NodeId,
175+
scid: u64,
176+
one_to_two: bool,
177+
parent_scid: u64,
143178
}
144179

145180
impl PartialOrd for PathBuildingHop {
@@ -250,7 +285,7 @@ mod tests {
250285
// Route to 1 via 2 and 3 because our channel to 1 is disabled
251286
let path = super::find_path(&our_id, &node_pks[0], &network_graph, None, Arc::clone(&logger)).unwrap();
252287
assert_eq!(path.len(), 3);
253-
assert_eq!(path[0], node_pks[1]);
288+
assert!((path[0] == node_pks[1]) || (path[0] == node_pks[7]));
254289
assert_eq!(path[1], node_pks[2]);
255290
assert_eq!(path[2], node_pks[0]);
256291

0 commit comments

Comments
 (0)