@@ -51,37 +51,48 @@ fn do_test_htlc_interception_flags(
5151 let node_chanmgrs = create_node_chanmgrs ( 3 , & node_cfgs, & [ None , Some ( intercept_config) , None ] ) ;
5252 let nodes = create_network ( 3 , & node_cfgs, & node_chanmgrs) ;
5353
54- create_announced_chan_between_nodes ( & nodes, 0 , 1 ) ;
54+ let inbound_private = match flag {
55+ Flag :: FromPrivateChannels => {
56+ create_unannounced_chan_between_nodes_with_value ( & nodes, 0 , 1 , 100000 , 0 ) ;
57+ true
58+ } ,
59+ _ => {
60+ create_announced_chan_between_nodes ( & nodes, 0 , 1 ) ;
61+ false
62+ } ,
63+ } ;
5564
5665 let node_0_id = nodes[ 0 ] . node . get_our_node_id ( ) ;
5766 let node_1_id = nodes[ 1 ] . node . get_our_node_id ( ) ;
5867 let node_2_id = nodes[ 2 ] . node . get_our_node_id ( ) ;
5968
6069 // First open the right type of channel (and get it in the right state) for the bit we're
6170 // testing.
62- let ( target_scid, target_chan_id) = match flag {
63- Flag :: ToOfflinePrivateChannels | Flag :: ToOnlinePrivateChannels => {
71+ let ( target_scid, target_chan_id, outbound_private_for_known_scids) = match flag {
72+ Flag :: ToOfflinePrivateChannels
73+ | Flag :: ToOnlinePrivateChannels
74+ | Flag :: FromPublicToPrivateChannels => {
6475 create_unannounced_chan_between_nodes_with_value ( & nodes, 1 , 2 , 100000 , 0 ) ;
6576 let chan_id = nodes[ 2 ] . node . list_channels ( ) [ 0 ] . channel_id ;
6677 let scid = nodes[ 2 ] . node . list_channels ( ) [ 0 ] . short_channel_id . unwrap ( ) ;
6778 if flag == Flag :: ToOfflinePrivateChannels {
6879 nodes[ 1 ] . node . peer_disconnected ( node_2_id) ;
6980 nodes[ 2 ] . node . peer_disconnected ( node_1_id) ;
70- } else {
71- assert_eq ! ( flag, Flag :: ToOnlinePrivateChannels ) ;
7281 }
73- ( scid, chan_id)
82+ ( scid, chan_id, Some ( true ) )
7483 } ,
75- Flag :: ToInterceptSCIDs | Flag :: ToPublicChannels | Flag :: ToUnknownSCIDs => {
84+ Flag :: ToInterceptSCIDs
85+ | Flag :: ToPublicChannels
86+ | Flag :: FromPrivateChannels
87+ | Flag :: FromPublicToPublicChannels
88+ | Flag :: ToUnknownSCIDs => {
7689 let ( chan_upd, _, chan_id, _) = create_announced_chan_between_nodes ( & nodes, 1 , 2 ) ;
7790 if flag == Flag :: ToInterceptSCIDs {
78- ( nodes[ 1 ] . node . get_intercept_scid ( ) , chan_id)
79- } else if flag == Flag :: ToPublicChannels {
80- ( chan_upd. contents . short_channel_id , chan_id)
91+ ( nodes[ 1 ] . node . get_intercept_scid ( ) , chan_id, None )
8192 } else if flag == Flag :: ToUnknownSCIDs {
82- ( 42424242 , chan_id)
93+ ( 42424242 , chan_id, None )
8394 } else {
84- panic ! ( ) ;
95+ ( chan_upd . contents . short_channel_id , chan_id , Some ( false ) )
8596 }
8697 } ,
8798 _ => panic ! ( "Combined flags aren't allowed" ) ,
@@ -101,21 +112,50 @@ fn do_test_htlc_interception_flags(
101112 get_route_and_payment_hash ! ( nodes[ 0 ] , nodes[ 2 ] , pay_params, amt_msat) ;
102113 route. paths [ 0 ] . hops [ 1 ] . short_channel_id = target_scid;
103114
104- let interception_bit_match = ( flags_bitmask & ( flag as u8 ) ) != 0 ;
115+ let mut should_intercept = false ;
116+ for a_flag in ALL_FLAGS {
117+ if flags_bitmask & ( a_flag as u8 ) != 0 {
118+ match a_flag {
119+ Flag :: ToInterceptSCIDs => {
120+ should_intercept |= flag == Flag :: ToInterceptSCIDs ;
121+ } ,
122+ Flag :: ToOfflinePrivateChannels => {
123+ should_intercept |= flag == Flag :: ToOfflinePrivateChannels ;
124+ } ,
125+ Flag :: ToOnlinePrivateChannels => {
126+ should_intercept |= flag != Flag :: ToOfflinePrivateChannels
127+ && outbound_private_for_known_scids == Some ( true ) ;
128+ } ,
129+ Flag :: ToPublicChannels => {
130+ should_intercept |= outbound_private_for_known_scids == Some ( false ) ;
131+ } ,
132+ Flag :: ToUnknownSCIDs => {
133+ should_intercept |= flag == Flag :: ToUnknownSCIDs ;
134+ } ,
135+ Flag :: FromPrivateChannels => {
136+ should_intercept |= inbound_private;
137+ } ,
138+ Flag :: FromPublicToPrivateChannels => {
139+ should_intercept |=
140+ !inbound_private && outbound_private_for_known_scids == Some ( true ) ;
141+ } ,
142+ Flag :: FromPublicToPublicChannels => {
143+ should_intercept |=
144+ !inbound_private && outbound_private_for_known_scids == Some ( false ) ;
145+ } ,
146+ _ => panic ! ( "Combined flags aren't allowed" ) ,
147+ }
148+ }
149+ }
150+
105151 match modification {
106152 Some ( ForwardingMod :: FeeTooLow ) => {
107- assert ! (
108- interception_bit_match,
109- "No reason to test failing if we aren't trying to intercept" ,
110- ) ;
153+ assert ! ( should_intercept, "No reason to test failing if we aren't trying to intercept" ) ;
111154 route. paths [ 0 ] . hops [ 0 ] . fee_msat = 500 ;
112155 } ,
113156 Some ( ForwardingMod :: CLTVBelowConfig ) => {
114157 route. paths [ 0 ] . hops [ 0 ] . cltv_expiry_delta = 6 * 12 ;
115- assert ! (
116- interception_bit_match,
117- "No reason to test failing if we aren't trying to intercept" ,
118- ) ;
158+ assert ! ( should_intercept, "No reason to test failing if we aren't trying to intercept" ) ;
119159 } ,
120160 Some ( ForwardingMod :: CLTVBelowMin ) => {
121161 route. paths [ 0 ] . hops [ 0 ] . cltv_expiry_delta = 6 ;
@@ -133,7 +173,7 @@ fn do_test_htlc_interception_flags(
133173 do_commitment_signed_dance ( & nodes[ 1 ] , & nodes[ 0 ] , & payment_event. commitment_msg , false , true ) ;
134174 expect_and_process_pending_htlcs ( & nodes[ 1 ] , false ) ;
135175
136- if interception_bit_match && modification. is_none ( ) {
176+ if should_intercept && modification. is_none ( ) {
137177 // If we were set to intercept, check that we got an interception event then
138178 // forward the HTLC on to nodes[2] and claim the payment.
139179 let intercept_id;
@@ -172,7 +212,14 @@ fn do_test_htlc_interception_flags(
172212 // If we were not set to intercept, check that the HTLC either failed or was
173213 // automatically forwarded as appropriate.
174214 match ( modification, flag) {
175- ( None , Flag :: ToOnlinePrivateChannels | Flag :: ToPublicChannels ) => {
215+ (
216+ None ,
217+ Flag :: ToOnlinePrivateChannels
218+ | Flag :: ToPublicChannels
219+ | Flag :: FromPrivateChannels
220+ | Flag :: FromPublicToPrivateChannels
221+ | Flag :: FromPublicToPublicChannels ,
222+ ) => {
176223 check_added_monitors ( & nodes[ 1 ] , 1 ) ;
177224
178225 let forward_ev = SendEvent :: from_node ( & nodes[ 1 ] ) ;
@@ -241,31 +288,55 @@ fn do_test_htlc_interception_flags(
241288}
242289
243290const MAX_BITMASK : u8 = HTLCInterceptionFlags :: AllValidHTLCs as u8 ;
244- const ALL_FLAGS : [ HTLCInterceptionFlags ; 5 ] = [
291+ const ALL_FLAGS : [ HTLCInterceptionFlags ; 8 ] = [
245292 HTLCInterceptionFlags :: ToInterceptSCIDs ,
246293 HTLCInterceptionFlags :: ToOfflinePrivateChannels ,
247294 HTLCInterceptionFlags :: ToOnlinePrivateChannels ,
248295 HTLCInterceptionFlags :: ToPublicChannels ,
249296 HTLCInterceptionFlags :: ToUnknownSCIDs ,
297+ HTLCInterceptionFlags :: FromPrivateChannels ,
298+ HTLCInterceptionFlags :: FromPublicToPrivateChannels ,
299+ HTLCInterceptionFlags :: FromPublicToPublicChannels ,
250300] ;
251-
252301#[ test]
253- fn test_htlc_interception_flags ( ) {
302+ fn check_all_flags ( ) {
254303 let mut all_flag_bits = 0 ;
255304 for flag in ALL_FLAGS {
256305 all_flag_bits |= flag as isize ;
257306 }
258307 assert_eq ! ( all_flag_bits, MAX_BITMASK as isize , "all flags must test all bits" ) ;
308+ }
259309
310+ fn test_htlc_interception_flags_subrange < I : Iterator < Item = u8 > > ( r : I ) {
260311 // Test all 2^5 = 32 combinations of the HTLCInterceptionFlags bitmask
261312 // For each combination, test 5 different HTLC forwards and verify correct interception behavior
262- for flags_bitmask in 0 ..= MAX_BITMASK {
313+ for flags_bitmask in r {
263314 for flag in ALL_FLAGS {
264315 do_test_htlc_interception_flags ( flags_bitmask, flag, None ) ;
265316 }
266317 }
267318}
268319
320+ #[ test]
321+ fn test_htlc_interception_flags_a ( ) {
322+ test_htlc_interception_flags_subrange ( 0 ..MAX_BITMASK / 4 ) ;
323+ }
324+
325+ #[ test]
326+ fn test_htlc_interception_flags_b ( ) {
327+ test_htlc_interception_flags_subrange ( MAX_BITMASK / 4 ..MAX_BITMASK / 2 ) ;
328+ }
329+
330+ #[ test]
331+ fn test_htlc_interception_flags_c ( ) {
332+ test_htlc_interception_flags_subrange ( MAX_BITMASK / 2 ..MAX_BITMASK / 4 * 3 ) ;
333+ }
334+
335+ #[ test]
336+ fn test_htlc_interception_flags_d ( ) {
337+ test_htlc_interception_flags_subrange ( MAX_BITMASK / 4 * 3 ..=MAX_BITMASK ) ;
338+ }
339+
269340#[ test]
270341fn test_htlc_bad_for_chan_config ( ) {
271342 // Test that interception won't be done if an HTLC fails to meet the target channel's channel
@@ -274,6 +345,9 @@ fn test_htlc_bad_for_chan_config() {
274345 HTLCInterceptionFlags :: ToOfflinePrivateChannels ,
275346 HTLCInterceptionFlags :: ToOnlinePrivateChannels ,
276347 HTLCInterceptionFlags :: ToPublicChannels ,
348+ HTLCInterceptionFlags :: FromPrivateChannels ,
349+ HTLCInterceptionFlags :: FromPublicToPrivateChannels ,
350+ HTLCInterceptionFlags :: FromPublicToPublicChannels ,
277351 ] ;
278352 for flag in have_chan_flags {
279353 do_test_htlc_interception_flags ( flag as u8 , flag, Some ( ForwardingMod :: FeeTooLow ) ) ;
0 commit comments