@@ -54,48 +54,43 @@ pub fn find_path<L: Deref, GL: Deref>(
54
54
let start = NodeId :: from_pubkey ( & our_node_pubkey) ;
55
55
let mut valid_first_hops = HashSet :: new ( ) ;
56
56
let mut frontier = BinaryHeap :: new ( ) ;
57
- frontier . push ( PathBuildingHop { cost : 0 , node_id : start , parent_node_id : start } ) ;
57
+ let mut visited = HashMap :: new ( ) ;
58
58
if let Some ( first_hops) = first_hops {
59
59
for hop in first_hops {
60
+ if & hop. counterparty . node_id == destination { return Ok ( vec ! [ * destination] ) }
60
61
if hop. counterparty . node_id == * our_node_pubkey { return Err ( Error :: InvalidFirstHop ) }
61
62
#[ cfg( not( feature = "_bench_unstable" ) ) ]
62
63
if !hop. counterparty . features . supports_onion_messages ( ) { continue ; }
63
64
let node_id = NodeId :: from_pubkey ( & hop. counterparty . node_id ) ;
64
- frontier. push ( PathBuildingHop { cost : 1 , node_id, parent_node_id : start } ) ;
65
+ match visited. entry ( node_id) {
66
+ hash_map:: Entry :: Occupied ( _) => continue ,
67
+ hash_map:: Entry :: Vacant ( e) => { e. insert ( start) ; } ,
68
+ } ;
69
+ if let Some ( node_info) = network_nodes. get ( & node_id) {
70
+ for scid in & node_info. channels {
71
+ if let Some ( chan_info) = network_channels. get ( & scid) {
72
+ if let Some ( ( directed_channel, successor) ) = chan_info. as_directed_from ( & node_id) {
73
+ if * successor == start { continue } // TODO: test
74
+ if directed_channel. direction ( ) . enabled {
75
+ frontier. push ( PathBuildingHop {
76
+ cost : 1 , scid : * scid, one_to_two : chan_info. node_one == node_id, parent_scid : hop. short_channel_id . unwrap_or ( 0 ) ,
77
+ } ) ;
78
+ }
79
+ }
80
+ }
81
+ }
82
+ }
65
83
valid_first_hops. insert ( node_id) ;
66
84
}
67
85
}
68
-
69
- let mut visited = HashMap :: new ( ) ;
70
- while let Some ( PathBuildingHop { cost, node_id, parent_node_id } ) = frontier. pop ( ) {
71
- match visited. entry ( node_id) {
72
- hash_map:: Entry :: Occupied ( _) => continue ,
73
- hash_map:: Entry :: Vacant ( e) => e. insert ( parent_node_id) ,
74
- } ;
75
- if node_id == dest_node_id {
76
- let path = reverse_path ( visited, our_node_id, dest_node_id) ?;
77
- log_info ! ( logger, "Got route to {:?}: {:?}" , destination, path) ;
78
- return Ok ( path)
79
- }
80
- if let Some ( node_info) = network_nodes. get ( & node_id) {
81
- // Only consider the network graph if first_hops does not override it.
82
- if valid_first_hops. contains ( & node_id) || node_id == our_node_id {
83
- } else if let Some ( node_ann) = & node_info. announcement_info {
84
- #[ cfg( not( feature = "_bench_unstable" ) ) ]
85
- if !node_ann. features . supports_onion_messages ( ) || node_ann. features . requires_unknown_bits ( )
86
- { continue ; }
87
- } else { continue ; }
86
+ if frontier. is_empty ( ) {
87
+ if let Some ( node_info) = network_nodes. get ( & start) {
88
88
for scid in & node_info. channels {
89
89
if let Some ( chan_info) = network_channels. get ( & scid) {
90
- if let Some ( ( directed_channel, successor) ) = chan_info. as_directed_from ( & node_id ) {
90
+ if let Some ( ( directed_channel, successor) ) = chan_info. as_directed_from ( & start ) {
91
91
if directed_channel. direction ( ) . enabled {
92
- // We may push a given successor multiple times, but the heap should sort its best
93
- // entry to the top. We do this because there is no way to adjust the priority of an
94
- // existing entry in `BinaryHeap`.
95
92
frontier. push ( PathBuildingHop {
96
- cost : cost + 1 ,
97
- node_id : * successor,
98
- parent_node_id : node_id,
93
+ cost : 1 , scid : * scid, one_to_two : chan_info. node_one == start, parent_scid : 0 ,
99
94
} ) ;
100
95
}
101
96
}
@@ -104,6 +99,45 @@ pub fn find_path<L: Deref, GL: Deref>(
104
99
}
105
100
}
106
101
102
+ while let Some ( PathBuildingHop { cost, scid, one_to_two, parent_scid } ) = frontier. pop ( ) {
103
+ if let Some ( chan_info) = network_channels. get ( & scid) {
104
+ let directed_from_node_id = if one_to_two { chan_info. node_one } else { chan_info. node_two } ;
105
+ let directed_to_node_id = if one_to_two { chan_info. node_two } else { chan_info. node_one } ;
106
+ match visited. entry ( directed_to_node_id) {
107
+ hash_map:: Entry :: Occupied ( _) => continue ,
108
+ hash_map:: Entry :: Vacant ( e) => e. insert ( directed_from_node_id) ,
109
+ } ;
110
+ if directed_to_node_id == dest_node_id {
111
+ let path = reverse_path ( visited, our_node_id, dest_node_id) ?;
112
+ log_info ! ( logger, "Got route to {:?}: {:?}" , destination, path) ;
113
+ return Ok ( path)
114
+ }
115
+ if let Some ( node_info) = network_nodes. get ( & directed_to_node_id) {
116
+ // Only consider the network graph if first_hops does not override it.
117
+ if valid_first_hops. contains ( & directed_to_node_id) || directed_to_node_id == our_node_id {
118
+ } else if let Some ( node_ann) = & node_info. announcement_info {
119
+ #[ cfg( not( feature = "_bench_unstable" ) ) ]
120
+ if !node_ann. features . supports_onion_messages ( ) || node_ann. features . requires_unknown_bits ( )
121
+ { continue ; }
122
+ } else { continue ; }
123
+ for scid_to_push in & node_info. channels {
124
+ if let Some ( chan_info) = network_channels. get ( & scid_to_push) {
125
+ if let Some ( ( directed_channel, successor) ) = chan_info. as_directed_from ( & directed_to_node_id) {
126
+ if directed_channel. direction ( ) . enabled {
127
+ let one_to_two = if let Some ( chan_info) = network_channels. get ( & scid_to_push) {
128
+ directed_to_node_id == chan_info. node_one
129
+ } else { continue } ;
130
+ frontier. push ( PathBuildingHop {
131
+ cost : cost + 1 , scid : * scid_to_push, parent_scid : scid, one_to_two,
132
+ } ) ;
133
+ }
134
+ }
135
+ }
136
+ }
137
+ }
138
+ }
139
+ }
140
+
107
141
Err ( Error :: PathNotFound )
108
142
}
109
143
@@ -138,8 +172,9 @@ impl std::error::Error for Error {}
138
172
#[ derive( Eq , PartialEq ) ]
139
173
struct PathBuildingHop {
140
174
cost : u64 ,
141
- node_id : NodeId ,
142
- parent_node_id : NodeId ,
175
+ scid : u64 ,
176
+ one_to_two : bool ,
177
+ parent_scid : u64 ,
143
178
}
144
179
145
180
impl PartialOrd for PathBuildingHop {
@@ -250,7 +285,7 @@ mod tests {
250
285
// Route to 1 via 2 and 3 because our channel to 1 is disabled
251
286
let path = super :: find_path ( & our_id, & node_pks[ 0 ] , & network_graph, None , Arc :: clone ( & logger) ) . unwrap ( ) ;
252
287
assert_eq ! ( path. len( ) , 3 ) ;
253
- assert_eq ! ( path[ 0 ] , node_pks[ 1 ] ) ;
288
+ assert ! ( ( path[ 0 ] == node_pks[ 1 ] ) || ( path [ 0 ] == node_pks [ 7 ] ) ) ;
254
289
assert_eq ! ( path[ 1 ] , node_pks[ 2 ] ) ;
255
290
assert_eq ! ( path[ 2 ] , node_pks[ 0 ] ) ;
256
291
0 commit comments