Skip to content

Commit bfd818f

Browse files
committed
Add override for blinded path creation
Allow tests to provide a override that receives the caller's , enabling custom blinded-path generation while preserving valid bindings. Co-Authored-By: HAL 9000
1 parent 3ca1f1c commit bfd818f

1 file changed

Lines changed: 34 additions & 1 deletion

File tree

lightning/src/util/test_utils.rs

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,23 @@ impl chaininterface::FeeEstimator for TestFeeEstimator {
165165
}
166166
}
167167

168+
/// Override closure type for [`TestRouter::override_create_blinded_payment_paths`].
169+
///
170+
/// This closure is called instead of the default [`Router::create_blinded_payment_paths`]
171+
/// implementation when set, receiving the actual [`ReceiveTlvs`] so tests can construct custom
172+
/// blinded payment paths using the same TLVs the caller generated.
173+
pub type BlindedPaymentPathOverrideFn = Box<
174+
dyn Fn(
175+
PublicKey,
176+
ReceiveAuthKey,
177+
Vec<ChannelDetails>,
178+
ReceiveTlvs,
179+
Option<u64>,
180+
) -> Result<Vec<BlindedPaymentPath>, ()>
181+
+ Send
182+
+ Sync,
183+
>;
184+
168185
pub struct TestRouter<'a> {
169186
pub router: DefaultRouter<
170187
Arc<NetworkGraph<&'a TestLogger>>,
@@ -177,6 +194,7 @@ pub struct TestRouter<'a> {
177194
pub network_graph: Arc<NetworkGraph<&'a TestLogger>>,
178195
pub next_routes: Mutex<VecDeque<(RouteParameters, Option<Result<Route, &'static str>>)>>,
179196
pub next_blinded_payment_paths: Mutex<Vec<BlindedPaymentPath>>,
197+
pub override_create_blinded_payment_paths: Mutex<Option<BlindedPaymentPathOverrideFn>>,
180198
pub scorer: &'a RwLock<TestScorer>,
181199
}
182200

@@ -188,6 +206,7 @@ impl<'a> TestRouter<'a> {
188206
let entropy_source = Arc::new(RandomBytes::new([42; 32]));
189207
let next_routes = Mutex::new(VecDeque::new());
190208
let next_blinded_payment_paths = Mutex::new(Vec::new());
209+
let override_create_blinded_payment_paths = Mutex::new(None);
191210
Self {
192211
router: DefaultRouter::new(
193212
Arc::clone(&network_graph),
@@ -199,6 +218,7 @@ impl<'a> TestRouter<'a> {
199218
network_graph,
200219
next_routes,
201220
next_blinded_payment_paths,
221+
override_create_blinded_payment_paths,
202222
scorer,
203223
}
204224
}
@@ -321,6 +341,12 @@ impl<'a> Router for TestRouter<'a> {
321341
first_hops: Vec<ChannelDetails>, tlvs: ReceiveTlvs, amount_msats: Option<u64>,
322342
secp_ctx: &Secp256k1<T>,
323343
) -> Result<Vec<BlindedPaymentPath>, ()> {
344+
if let Some(override_fn) =
345+
self.override_create_blinded_payment_paths.lock().unwrap().as_ref()
346+
{
347+
return override_fn(recipient, local_node_receive_key, first_hops, tlvs, amount_msats);
348+
}
349+
324350
let mut expected_paths = self.next_blinded_payment_paths.lock().unwrap();
325351
if expected_paths.is_empty() {
326352
self.router.create_blinded_payment_paths(
@@ -366,6 +392,7 @@ pub enum TestMessageRouterInternal<'a> {
366392
pub struct TestMessageRouter<'a> {
367393
pub inner: TestMessageRouterInternal<'a>,
368394
pub peers_override: Mutex<Vec<PublicKey>>,
395+
pub forward_node_scid_override: Mutex<HashMap<PublicKey, u64>>,
369396
}
370397

371398
impl<'a> TestMessageRouter<'a> {
@@ -378,6 +405,7 @@ impl<'a> TestMessageRouter<'a> {
378405
entropy_source,
379406
)),
380407
peers_override: Mutex::new(Vec::new()),
408+
forward_node_scid_override: Mutex::new(new_hash_map()),
381409
}
382410
}
383411

@@ -390,6 +418,7 @@ impl<'a> TestMessageRouter<'a> {
390418
entropy_source,
391419
)),
392420
peers_override: Mutex::new(Vec::new()),
421+
forward_node_scid_override: Mutex::new(new_hash_map()),
393422
}
394423
}
395424
}
@@ -421,9 +450,13 @@ impl<'a> MessageRouter for TestMessageRouter<'a> {
421450
{
422451
let peers_override = self.peers_override.lock().unwrap();
423452
if !peers_override.is_empty() {
453+
let scid_override = self.forward_node_scid_override.lock().unwrap();
424454
let peer_override_nodes: Vec<_> = peers_override
425455
.iter()
426-
.map(|pk| MessageForwardNode { node_id: *pk, short_channel_id: None })
456+
.map(|pk| MessageForwardNode {
457+
node_id: *pk,
458+
short_channel_id: scid_override.get(pk).copied(),
459+
})
427460
.collect();
428461
peers = peer_override_nodes;
429462
}

0 commit comments

Comments
 (0)