Skip to content

Commit 9094319

Browse files
committed
impl write and read for resource manager
Adds write and read implementations to persist the DefaultResourceManager.
1 parent 0886f09 commit 9094319

1 file changed

Lines changed: 267 additions & 3 deletions

File tree

lightning/src/ln/resource_manager.rs

Lines changed: 267 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,17 @@
99

1010
#![allow(dead_code)]
1111

12+
use bitcoin::io::Read;
1213
use core::{fmt::Display, time::Duration};
1314

1415
use crate::{
1516
crypto::chacha20::ChaCha20,
16-
ln::channel::TOTAL_BITCOIN_SUPPLY_SATOSHIS,
17+
io,
18+
ln::{channel::TOTAL_BITCOIN_SUPPLY_SATOSHIS, msgs::DecodeError},
1719
prelude::{hash_map::Entry, new_hash_map, HashMap},
1820
sign::EntropySource,
1921
sync::Mutex,
22+
util::ser::{CollectionLength, Readable, ReadableArgs, Writeable, Writer},
2023
};
2124

2225
/// A trait for managing channel resources and making HTLC forwarding decisions.
@@ -129,6 +132,14 @@ impl Default for ResourceManagerConfig {
129132
}
130133
}
131134

135+
impl_writeable_tlv_based!(ResourceManagerConfig, {
136+
(1, general_allocation_pct, required),
137+
(3, congestion_allocation_pct, required),
138+
(5, resolution_period, required),
139+
(7, revenue_window, required),
140+
(9, reputation_multiplier, required),
141+
});
142+
132143
/// The outcome of an HTLC forwarding decision.
133144
#[derive(PartialEq, Eq, Debug)]
134145
pub enum ForwardingOutcome {
@@ -337,6 +348,47 @@ impl GeneralBucket {
337348
}
338349
}
339350

351+
impl Writeable for GeneralBucket {
352+
fn write<W: Writer>(&self, writer: &mut W) -> Result<(), io::Error> {
353+
let channel_info: HashMap<u64, [u8; 32]> =
354+
self.channels_slots.iter().map(|(scid, (_slots, salt))| (*scid, *salt)).collect();
355+
356+
write_tlv_fields!(writer, {
357+
(1, self.scid, required),
358+
(3, self.total_slots, required),
359+
(5, self.total_liquidity, required),
360+
(7, channel_info, required),
361+
});
362+
Ok(())
363+
}
364+
}
365+
366+
impl<ES: EntropySource> ReadableArgs<&ES> for GeneralBucket {
367+
fn read<R: Read>(reader: &mut R, entropy_source: &ES) -> Result<Self, DecodeError> {
368+
_init_and_read_len_prefixed_tlv_fields!(reader, {
369+
(1, our_scid, required),
370+
(3, general_total_slots, required),
371+
(5, general_total_liquidity, required),
372+
(7, channel_info, required),
373+
});
374+
375+
let mut general_bucket = GeneralBucket::new(
376+
our_scid.0.unwrap(),
377+
general_total_slots.0.unwrap(),
378+
general_total_liquidity.0.unwrap(),
379+
);
380+
381+
let channel_info: HashMap<u64, [u8; 32]> = channel_info.0.unwrap();
382+
for (outgoing_scid, salt) in channel_info {
383+
general_bucket
384+
.assign_slots_for_channel(outgoing_scid, Some(salt), entropy_source)
385+
.map_err(|_| DecodeError::InvalidValue)?;
386+
}
387+
388+
Ok(general_bucket)
389+
}
390+
}
391+
340392
struct BucketResources {
341393
slots_allocated: u16,
342394
slots_used: u16,
@@ -374,7 +426,13 @@ impl BucketResources {
374426
}
375427
}
376428

377-
#[derive(Debug, Clone)]
429+
impl_writeable_tlv_based!(BucketResources, {
430+
(1, slots_allocated, required),
431+
(_unused, slots_used, (static_value, 0)),
432+
(3, liquidity_allocated, required),
433+
(_unused, liquidity_used, (static_value, 0)),
434+
});
435+
378436
struct PendingHTLC {
379437
incoming_amount_msat: u64,
380438
fee: u64,
@@ -565,6 +623,42 @@ impl Channel {
565623
}
566624
}
567625

626+
impl Writeable for Channel {
627+
fn write<W: Writer>(&self, writer: &mut W) -> Result<(), io::Error> {
628+
write_tlv_fields!(writer, {
629+
(1, self.outgoing_reputation, required),
630+
(3, self.incoming_revenue, required),
631+
(5, self.general_bucket, required),
632+
(7, self.congestion_bucket, required),
633+
(9, self.last_congestion_misuse, required),
634+
(11, self.protected_bucket, required)
635+
});
636+
Ok(())
637+
}
638+
}
639+
640+
impl<ES: EntropySource> ReadableArgs<&ES> for Channel {
641+
fn read<R: Read>(reader: &mut R, entropy_source: &ES) -> Result<Self, DecodeError> {
642+
_init_and_read_len_prefixed_tlv_fields!(reader, {
643+
(1, outgoing_reputation, required),
644+
(3, incoming_revenue, required),
645+
(5, general_bucket, (required: ReadableArgs, entropy_source)),
646+
(7, congestion_bucket, required),
647+
(9, last_congestion_misuse, required),
648+
(11, protected_bucket, required)
649+
});
650+
Ok(Channel {
651+
outgoing_reputation: outgoing_reputation.0.unwrap(),
652+
incoming_revenue: incoming_revenue.0.unwrap(),
653+
general_bucket: general_bucket.0.unwrap(),
654+
pending_htlcs: new_hash_map(),
655+
congestion_bucket: congestion_bucket.0.unwrap(),
656+
last_congestion_misuse: last_congestion_misuse.0.unwrap(),
657+
protected_bucket: protected_bucket.0.unwrap(),
658+
})
659+
}
660+
}
661+
568662
/// An implementation of [`ResourceManager`] for managing channel resources and informing HTLC
569663
/// forwarding decisions. It implements the core of the mitigation as proposed in
570664
/// https://github.com/lightning/bolts/pull/1280.
@@ -818,6 +912,82 @@ impl ResourceManager for DefaultResourceManager {
818912
}
819913
}
820914

915+
pub struct PendingHTLCReplay {
916+
pub incoming_channel_id: u64,
917+
pub incoming_amount_msat: u64,
918+
pub incoming_htlc_id: u64,
919+
pub incoming_cltv_expiry: u32,
920+
pub incoming_accountable: bool,
921+
pub outgoing_channel_id: u64,
922+
pub outgoing_amount_msat: u64,
923+
pub added_at_unix_seconds: u64,
924+
pub height_added: u32,
925+
}
926+
927+
impl Writeable for DefaultResourceManager {
928+
fn write<W: Writer>(&self, writer: &mut W) -> Result<(), io::Error> {
929+
let channels = self.channels.lock().unwrap();
930+
write_tlv_fields!(writer, {
931+
(1, self.config, required),
932+
(3, channels, required),
933+
});
934+
Ok(())
935+
}
936+
}
937+
938+
impl<ES: EntropySource> ReadableArgs<&ES> for DefaultResourceManager {
939+
fn read<R: Read>(
940+
reader: &mut R, entropy_source: &ES,
941+
) -> Result<DefaultResourceManager, DecodeError> {
942+
_init_and_read_len_prefixed_tlv_fields!(reader, {
943+
(1, config, required),
944+
(3, channels, (required: ReadableArgs, entropy_source)),
945+
});
946+
let channels: HashMap<u64, Channel> = channels.0.unwrap();
947+
Ok(DefaultResourceManager { config: config.0.unwrap(), channels: Mutex::new(channels) })
948+
}
949+
}
950+
951+
impl<ES: EntropySource> ReadableArgs<&ES> for HashMap<u64, Channel> {
952+
fn read<R: Read>(r: &mut R, entropy_source: &ES) -> Result<Self, DecodeError> {
953+
let len: CollectionLength = Readable::read(r)?;
954+
let mut ret = new_hash_map();
955+
for _ in 0..len.0 {
956+
let k: u64 = Readable::read(r)?;
957+
let v = Channel::read(r, entropy_source)?;
958+
if ret.insert(k, v).is_some() {
959+
return Err(DecodeError::InvalidValue);
960+
}
961+
}
962+
Ok(ret)
963+
}
964+
}
965+
966+
impl DefaultResourceManager {
967+
// This should only be called once during startup to replay pending HTLCs we had before
968+
// shutdown.
969+
pub fn replay_pending_htlcs<ES: EntropySource>(
970+
&self, pending_htlcs: Vec<PendingHTLCReplay>, entropy_source: &ES,
971+
) -> Result<(), DecodeError> {
972+
for htlc in pending_htlcs {
973+
self.add_htlc(
974+
htlc.incoming_channel_id,
975+
htlc.incoming_amount_msat,
976+
htlc.incoming_cltv_expiry,
977+
htlc.outgoing_channel_id,
978+
htlc.outgoing_amount_msat,
979+
htlc.incoming_accountable,
980+
htlc.incoming_htlc_id,
981+
htlc.height_added,
982+
htlc.added_at_unix_seconds,
983+
entropy_source,
984+
)
985+
.map_err(|_| DecodeError::InvalidValue)?;
986+
}
987+
Ok(())
988+
}
989+
}
990+
821991
/// A weighted average that decays over a specified window.
822992
///
823993
/// It enables tracking of historical behavior without storing individual data points.
@@ -865,6 +1035,16 @@ impl DecayingAverage {
8651035
}
8661036
}
8671037

1038+
impl_writeable_tlv_based!(DecayingAverage, {
1039+
(1, value, required),
1040+
(3, last_updated_unix_secs, required),
1041+
(5, window, required),
1042+
(_unused, decay_rate, (static_value, {
1043+
let w: Duration = window.0.unwrap();
1044+
0.5_f64.powf(2.0 / w.as_secs_f64())
1045+
})),
1046+
});
1047+
8681048
/// Tracks an average value over multiple rolling windows to smooth out volatility.
8691049
///
8701050
/// It tracks the average value using a single window duration but extends observation over
@@ -929,6 +1109,13 @@ impl AggregatedWindowAverage {
9291109
}
9301110
}
9311111

1112+
impl_writeable_tlv_based!(AggregatedWindowAverage, {
1113+
(1, start_timestamp_unix_secs, required),
1114+
(3, window_count, required),
1115+
(5, window_duration, required),
1116+
(7, aggregated_revenue_decaying, required),
1117+
});
1118+
9321119
#[cfg(test)]
9331120
mod tests {
9341121
use std::time::{Duration, SystemTime, UNIX_EPOCH};
@@ -945,7 +1132,10 @@ mod tests {
9451132
},
9461133
},
9471134
sign::EntropySource,
948-
util::test_utils::TestKeysInterface,
1135+
util::{
1136+
ser::{ReadableArgs, Writeable},
1137+
test_utils::TestKeysInterface,
1138+
},
9491139
};
9501140

9511141
const WINDOW: Duration = Duration::from_secs(2016 * 10 * 60);
@@ -1315,6 +1505,13 @@ mod tests {
13151505
outgoing_channel.outgoing_reputation.add_value(target_reputation, now).unwrap();
13161506
}
13171507

1508+
fn add_revenue(rm: &DefaultResourceManager, incoming_scid: u64, revenue: i64) {
1509+
let mut channels = rm.channels.lock().unwrap();
1510+
let channel = channels.get_mut(&incoming_scid).unwrap();
1511+
let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs();
1512+
channel.incoming_revenue.add_value(revenue, now).unwrap();
1513+
}
1514+
13181515
fn fill_general_bucket(rm: &DefaultResourceManager, incoming_scid: u64) {
13191516
let mut channels = rm.channels.lock().unwrap();
13201517
let incoming_channel = channels.get_mut(&incoming_scid).unwrap();
@@ -2210,6 +2407,73 @@ mod tests {
22102407
assert!(get_htlc_bucket(&rm, INCOMING_SCID, htlc_id, OUTGOING_SCID_2).is_none());
22112408
}
22122409

2410+
#[test]
2411+
fn test_simple_manager_serialize_deserialize() {
2412+
// This is not a complete test of the serialization/deserialization of the resource
2413+
// manager because the pending HTLCs will be replayed through `replay_pending_htlcs` by
2414+
// the upstream i.e ChannelManager.
2415+
let rm = create_test_resource_manager_with_channels();
2416+
let entropy_source = TestKeysInterface::new(&[0; 32], Network::Testnet);
2417+
2418+
add_test_htlc(&rm, false, 0, None, &entropy_source).unwrap();
2419+
2420+
let reputation = 50_000_000;
2421+
add_reputation(&rm, OUTGOING_SCID, reputation);
2422+
2423+
let revenue = 70_000_000;
2424+
add_revenue(&rm, INCOMING_SCID, revenue);
2425+
2426+
let serialized_rm = rm.encode();
2427+
2428+
let channels = rm.channels.lock().unwrap();
2429+
let expected_incoming_channel = channels.get(&INCOMING_SCID).unwrap();
2430+
let (expected_slots, expected_salt) = expected_incoming_channel
2431+
.general_bucket
2432+
.channels_slots
2433+
.get(&OUTGOING_SCID)
2434+
.unwrap()
2435+
.clone();
2436+
2437+
let deserialized_rm =
2438+
DefaultResourceManager::read(&mut serialized_rm.as_slice(), &entropy_source).unwrap();
2439+
let deserialized_channels = deserialized_rm.channels.lock().unwrap();
2440+
assert_eq!(2, deserialized_channels.len());
2441+
2442+
let outgoing_channel = deserialized_channels.get(&OUTGOING_SCID).unwrap();
2443+
assert!(outgoing_channel.general_bucket.channels_slots.is_empty());
2444+
2445+
assert_eq!(outgoing_channel.outgoing_reputation.value, reputation);
2446+
2447+
let incoming_channel = deserialized_channels.get(&INCOMING_SCID).unwrap();
2448+
assert_eq!(incoming_channel.incoming_revenue.aggregated_revenue_decaying.value, revenue);
2449+
2450+
assert_eq!(incoming_channel.general_bucket.channels_slots.len(), 1);
2451+
2452+
let (slots, salt) =
2453+
incoming_channel.general_bucket.channels_slots.get(&OUTGOING_SCID).unwrap().clone();
2454+
assert_eq!(slots, expected_slots);
2455+
assert_eq!(salt, expected_salt);
2456+
2457+
let congestion_bucket = &incoming_channel.congestion_bucket;
2458+
assert_eq!(
2459+
congestion_bucket.slots_allocated,
2460+
expected_incoming_channel.congestion_bucket.slots_allocated
2461+
);
2462+
assert_eq!(
2463+
congestion_bucket.liquidity_allocated,
2464+
expected_incoming_channel.congestion_bucket.liquidity_allocated
2465+
);
2466+
let protected_bucket = &incoming_channel.protected_bucket;
2467+
assert_eq!(
2468+
protected_bucket.slots_allocated,
2469+
expected_incoming_channel.protected_bucket.slots_allocated
2470+
);
2471+
assert_eq!(
2472+
protected_bucket.liquidity_allocated,
2473+
expected_incoming_channel.protected_bucket.liquidity_allocated
2474+
);
2475+
}
2476+
22132477
#[test]
22142478
fn test_decaying_average_values() {
22152479
// Test average decay at different timestamps. The values we are asserting have been

0 commit comments

Comments
 (0)