Skip to content

Commit 2aef6c2

Browse files
committed
Implement write and read for the resource manager
Adds `Writeable` and `ReadableArgs` implementations to persist and restore the DefaultResourceManager.
1 parent 30efca0 commit 2aef6c2

1 file changed

Lines changed: 271 additions & 2 deletions

File tree

lightning/src/ln/resource_manager.rs

Lines changed: 271 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,17 @@
99

1010
#![allow(dead_code)]
1111

12+
use bitcoin::io::Read;
1213
use core::{fmt::Display, time::Duration};
1314

1415
use crate::{
1516
crypto::chacha20::ChaCha20,
16-
ln::channel::TOTAL_BITCOIN_SUPPLY_SATOSHIS,
17+
io,
18+
ln::{channel::TOTAL_BITCOIN_SUPPLY_SATOSHIS, msgs::DecodeError},
1719
prelude::{hash_map::Entry, new_hash_map, HashMap},
1820
sign::EntropySource,
1921
sync::Mutex,
22+
util::ser::{CollectionLength, Readable, ReadableArgs, Writeable, Writer},
2023
};
2124

2225
/// Resolution time in seconds that is considered "good". HTLCs resolved within this period are
@@ -85,6 +88,14 @@ impl Default for ResourceManagerConfig {
8588
}
8689
}
8790

91+
// Persists the resource manager configuration. TLV type numbers are part of the
// on-disk format and must never be renumbered, or previously written data will
// fail to round-trip.
// NOTE(review): this file uses odd type numbers for `required` fields; confirm
// this is intentional given LDK's usual even-means-required TLV convention.
impl_writeable_tlv_based!(ResourceManagerConfig, {
	(1, general_allocation_pct, required),
	(3, congestion_allocation_pct, required),
	(5, resolution_period, required),
	(7, revenue_window, required),
	(9, reputation_multiplier, required),
});
8899
/// The outcome of an HTLC forwarding decision.
89100
#[derive(PartialEq, Eq, Debug)]
90101
pub enum ForwardingOutcome {
@@ -292,6 +303,47 @@ impl GeneralBucket {
292303
}
293304
}
294305

306+
impl Writeable for GeneralBucket {
	fn write<W: Writer>(&self, writer: &mut W) -> Result<(), io::Error> {
		// Only the per-channel salt is persisted, not the slot assignments:
		// on read, slots are re-derived from the salt via
		// `assign_slots_for_channel` (see the `ReadableArgs` impl).
		let channel_info: HashMap<u64, [u8; 32]> =
			self.channels_slots.iter().map(|(scid, (_slots, salt))| (*scid, *salt)).collect();

		write_tlv_fields!(writer, {
			(1, self.scid, required),
			(3, self.total_slots, required),
			(5, self.total_liquidity, required),
			(7, channel_info, required),
		});
		Ok(())
	}
}
impl<ES: EntropySource> ReadableArgs<&ES> for GeneralBucket {
	fn read<R: Read>(reader: &mut R, entropy_source: &ES) -> Result<Self, DecodeError> {
		_init_and_read_len_prefixed_tlv_fields!(reader, {
			(1, our_scid, required),
			(3, general_total_slots, required),
			(5, general_total_liquidity, required),
			(7, channel_info, required),
		});

		// `required` fields are guaranteed populated by the macro (it errors
		// out otherwise), so these unwraps cannot panic.
		let mut general_bucket = GeneralBucket::new(
			our_scid.0.unwrap(),
			general_total_slots.0.unwrap(),
			general_total_liquidity.0.unwrap(),
		);

		// Rebuild each channel's slot assignment from its persisted salt;
		// slots themselves were intentionally not serialized.
		let channel_info: HashMap<u64, [u8; 32]> = channel_info.0.unwrap();
		for (outgoing_scid, salt) in channel_info {
			general_bucket
				.assign_slots_for_channel(outgoing_scid, Some(salt), entropy_source)
				.map_err(|_| DecodeError::InvalidValue)?;
		}

		Ok(general_bucket)
	}
}
295347
struct BucketResources {
296348
slots_allocated: u16,
297349
slots_used: u16,
@@ -329,6 +381,13 @@ impl BucketResources {
329381
}
330382
}
331383

384+
// `slots_used` / `liquidity_used` are intentionally not persisted and reset to
// zero on read: pending HTLCs are not serialized, and usage is restored when
// the caller replays them via `DefaultResourceManager::replay_pending_htlcs`.
impl_writeable_tlv_based!(BucketResources, {
	(1, slots_allocated, required),
	(_unused, slots_used, (static_value, 0)),
	(3, liquidity_allocated, required),
	(_unused, liquidity_used, (static_value, 0)),
});
332391
#[derive(Debug, Clone)]
333392
struct PendingHTLC {
334393
incoming_amount_msat: u64,
@@ -522,6 +581,42 @@ impl Channel {
522581
}
523582
}
524583

584+
impl Writeable for Channel {
	fn write<W: Writer>(&self, writer: &mut W) -> Result<(), io::Error> {
		// `pending_htlcs` is deliberately omitted; in-flight HTLCs are
		// replayed on startup via
		// `DefaultResourceManager::replay_pending_htlcs`.
		write_tlv_fields!(writer, {
			(1, self.outgoing_reputation, required),
			(3, self.incoming_revenue, required),
			(5, self.general_bucket, required),
			(7, self.congestion_bucket, required),
			(9, self.last_congestion_misuse, required),
			(11, self.protected_bucket, required)
		});
		Ok(())
	}
}
impl<ES: EntropySource> ReadableArgs<&ES> for Channel {
	fn read<R: Read>(reader: &mut R, entropy_source: &ES) -> Result<Self, DecodeError> {
		_init_and_read_len_prefixed_tlv_fields!(reader, {
			(1, outgoing_reputation, required),
			(3, incoming_revenue, required),
			// The general bucket needs the entropy source to rebuild its
			// slot assignments, hence the explicit `ReadableArgs` form.
			(5, general_bucket, (required: ReadableArgs, entropy_source)),
			(7, congestion_bucket, required),
			(9, last_congestion_misuse, required),
			(11, protected_bucket, required)
		});
		// `required` fields are populated by the macro; unwraps cannot panic.
		Ok(Channel {
			outgoing_reputation: outgoing_reputation.0.unwrap(),
			incoming_revenue: incoming_revenue.0.unwrap(),
			general_bucket: general_bucket.0.unwrap(),
			// Pending HTLCs are not persisted: start empty and let the caller
			// replay them via `DefaultResourceManager::replay_pending_htlcs`.
			pending_htlcs: new_hash_map(),
			congestion_bucket: congestion_bucket.0.unwrap(),
			last_congestion_misuse: last_congestion_misuse.0.unwrap(),
			protected_bucket: protected_bucket.0.unwrap(),
		})
	}
}
525620
/// An implementation for managing channel resources and informing HTLC forwarding decisions. It
526621
/// implements the core of the mitigation as proposed in https://github.com/lightning/bolts/pull/1280.
527622
pub struct DefaultResourceManager {
@@ -815,6 +910,85 @@ impl DefaultResourceManager {
815910
}
816911
}
817912

913+
/// An HTLC that was pending when the manager was last persisted, to be fed to
/// [`DefaultResourceManager::replay_pending_htlcs`] on startup. Fields mirror
/// the arguments of `add_htlc`.
pub struct PendingHTLCReplay {
	// SCID of the channel the HTLC arrived on.
	pub incoming_channel_id: u64,
	// Amount received on the incoming link, in millisatoshis.
	pub incoming_amount_msat: u64,
	// The HTLC's id on the incoming channel.
	pub incoming_htlc_id: u64,
	// CLTV expiry height of the incoming HTLC.
	pub incoming_cltv_expiry: u32,
	// Whether the incoming HTLC was flagged accountable.
	pub incoming_accountable: bool,
	// SCID of the channel the HTLC is being forwarded over.
	pub outgoing_channel_id: u64,
	// Amount forwarded on the outgoing link, in millisatoshis.
	pub outgoing_amount_msat: u64,
	// Unix timestamp (seconds) at which the HTLC was originally added.
	pub added_at_unix_seconds: u64,
	// Block height at which the HTLC was originally added.
	pub height_added: u32,
}
925+
impl Writeable for DefaultResourceManager {
	fn write<W: Writer>(&self, writer: &mut W) -> Result<(), io::Error> {
		// Hold the channels lock for the full write so a consistent snapshot
		// of all channels is serialized.
		let channels = self.channels.lock().unwrap();
		write_tlv_fields!(writer, {
			(1, self.config, required),
			(3, channels, required),
		});
		Ok(())
	}
}
936+
impl<ES: EntropySource> ReadableArgs<&ES> for DefaultResourceManager {
	fn read<R: Read>(
		reader: &mut R, entropy_source: &ES,
	) -> Result<DefaultResourceManager, DecodeError> {
		_init_and_read_len_prefixed_tlv_fields!(reader, {
			(1, config, required),
			// Channels need the entropy source to rebuild general-bucket slot
			// assignments, so they are read via `ReadableArgs`.
			(3, channels, (required: ReadableArgs, entropy_source)),
		});
		// `required` fields are populated by the macro; unwraps cannot panic.
		let channels: HashMap<u64, Channel> = channels.0.unwrap();
		Ok(DefaultResourceManager { config: config.0.unwrap(), channels: Mutex::new(channels) })
	}
}
949+
// The channels map is read element-by-element because deserializing each
// `Channel` requires the entropy source argument.
impl<ES: EntropySource> ReadableArgs<&ES> for HashMap<u64, Channel> {
	fn read<R: Read>(r: &mut R, entropy_source: &ES) -> Result<Self, DecodeError> {
		let len: CollectionLength = Readable::read(r)?;
		let mut ret = new_hash_map();
		for _ in 0..len.0 {
			let k: u64 = Readable::read(r)?;
			let v = Channel::read(r, entropy_source)?;
			// A duplicate SCID key indicates corrupt or malformed data.
			if ret.insert(k, v).is_some() {
				return Err(DecodeError::InvalidValue);
			}
		}
		Ok(ret)
	}
}
964+
impl DefaultResourceManager {
	/// Replays HTLCs that were pending before shutdown by re-adding each one
	/// through `add_htlc`, restoring the manager's in-memory state for
	/// in-flight HTLCs (which are intentionally not persisted).
	///
	/// This should only be called once, during startup, after the manager has
	/// been deserialized. Returns the forwarding outcome produced for each
	/// replayed HTLC, in input order, or [`DecodeError::InvalidValue`] if any
	/// HTLC fails to be re-added.
	pub fn replay_pending_htlcs<ES: EntropySource>(
		&self, pending_htlcs: &[PendingHTLCReplay], entropy_source: &ES,
	) -> Result<Vec<ForwardingOutcome>, DecodeError> {
		let mut forwarding_outcomes = Vec::with_capacity(pending_htlcs.len());
		for htlc in pending_htlcs {
			forwarding_outcomes.push(
				self.add_htlc(
					htlc.incoming_channel_id,
					htlc.incoming_amount_msat,
					htlc.incoming_cltv_expiry,
					htlc.outgoing_channel_id,
					htlc.outgoing_amount_msat,
					htlc.incoming_accountable,
					htlc.incoming_htlc_id,
					htlc.height_added,
					htlc.added_at_unix_seconds,
					entropy_source,
				)
				// Surfaced as a decode error since replay happens as part of
				// restoring persisted state.
				.map_err(|_| DecodeError::InvalidValue)?,
			);
		}
		Ok(forwarding_outcomes)
	}
}
818992
/// A weighted average that decays over a specified window.
819993
///
820994
/// It enables tracking of historical behavior without storing individual data points.
@@ -858,6 +1032,16 @@ impl DecayingAverage {
8581032
}
8591033
}
8601034

1035+
// `half_life` is derived from `window` on read rather than persisted.
// NOTE(review): derivation is `window.as_secs_f64() * ln(2)` — confirm this
// matches how `half_life` is computed at construction time.
impl_writeable_tlv_based!(DecayingAverage, {
	(1, value, required),
	(3, last_updated_unix_secs, required),
	(5, window, required),
	(_unused, half_life, (static_value, {
		let w: Duration = window.0.unwrap();
		w.as_secs_f64() * 2_f64.ln()
	})),
});
8611045
/// Approximates an [`Self::avg_weeks`]-week average by tracking a decaying average over a larger
8621046
/// [`Self::window_weeks`] window to smooth out volatility.
8631047
struct AggregatedWindowAverage {
@@ -903,6 +1087,14 @@ impl AggregatedWindowAverage {
9031087
}
9041088
}
9051089

1090+
// All fields are persisted directly; TLV type numbers are part of the on-disk
// format and must not be renumbered.
impl_writeable_tlv_based!(AggregatedWindowAverage, {
	(1, start_timestamp_unix_secs, required),
	(3, avg_weeks, required),
	(5, window_weeks, required),
	(7, window_duration, required),
	(9, aggregated_revenue_decaying, required),
});
9061098
#[cfg(test)]
9071099
mod tests {
9081100
use std::time::{Duration, SystemTime, UNIX_EPOCH};
@@ -918,7 +1110,10 @@ mod tests {
9181110
},
9191111
},
9201112
sign::EntropySource,
921-
util::test_utils::TestKeysInterface,
1113+
util::{
1114+
ser::{ReadableArgs, Writeable},
1115+
test_utils::TestKeysInterface,
1116+
},
9221117
};
9231118
use bitcoin::Network;
9241119

@@ -1295,6 +1490,13 @@ mod tests {
12951490
outgoing_channel.outgoing_reputation.add_value(target_reputation, now).unwrap();
12961491
}
12971492

1493+
fn add_revenue(rm: &DefaultResourceManager, incoming_scid: u64, revenue: i64) {
1494+
let mut channels = rm.channels.lock().unwrap();
1495+
let channel = channels.get_mut(&incoming_scid).unwrap();
1496+
let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs();
1497+
channel.incoming_revenue.add_value(revenue, now).unwrap();
1498+
}
1499+
12981500
fn fill_general_bucket(rm: &DefaultResourceManager, incoming_scid: u64) {
12991501
let mut channels = rm.channels.lock().unwrap();
13001502
let incoming_channel = channels.get_mut(&incoming_scid).unwrap();
@@ -2189,6 +2391,73 @@ mod tests {
21892391
assert!(get_htlc_bucket(&rm, INCOMING_SCID, htlc_id, OUTGOING_SCID_2).is_none());
21902392
}
21912393

2394+
#[test]
fn test_simple_manager_serialize_deserialize() {
	// This is not a complete test of the serialization/deserialization of the resource
	// manager because the pending HTLCs will be replayed through `replay_pending_htlcs` by
	// the upstream, i.e. the ChannelManager.
	let rm = create_test_resource_manager_with_channels();
	let entropy_source = TestKeysInterface::new(&[0; 32], Network::Testnet);

	add_test_htlc(&rm, false, 0, None, &entropy_source).unwrap();

	// Seed some non-default state so the round-trip is meaningful.
	let reputation = 50_000_000;
	add_reputation(&rm, OUTGOING_SCID, reputation);

	let revenue = 70_000_000;
	add_revenue(&rm, INCOMING_SCID, revenue);

	let serialized_rm = rm.encode();

	// Capture expected values from the original manager before deserializing.
	let channels = rm.channels.lock().unwrap();
	let expected_incoming_channel = channels.get(&INCOMING_SCID).unwrap();
	let (expected_slots, expected_salt) = expected_incoming_channel
		.general_bucket
		.channels_slots
		.get(&OUTGOING_SCID)
		.unwrap()
		.clone();

	let deserialized_rm =
		DefaultResourceManager::read(&mut serialized_rm.as_slice(), &entropy_source).unwrap();
	let deserialized_channels = deserialized_rm.channels.lock().unwrap();
	assert_eq!(2, deserialized_channels.len());

	let outgoing_channel = deserialized_channels.get(&OUTGOING_SCID).unwrap();
	assert!(outgoing_channel.general_bucket.channels_slots.is_empty());

	assert_eq!(outgoing_channel.outgoing_reputation.value, reputation);

	let incoming_channel = deserialized_channels.get(&INCOMING_SCID).unwrap();
	assert_eq!(incoming_channel.incoming_revenue.aggregated_revenue_decaying.value, revenue);

	assert_eq!(incoming_channel.general_bucket.channels_slots.len(), 1);

	// Slot assignments are re-derived from the persisted salt, so both the
	// slots and the salt must match the pre-serialization values.
	let (slots, salt) =
		incoming_channel.general_bucket.channels_slots.get(&OUTGOING_SCID).unwrap().clone();
	assert_eq!(slots, expected_slots);
	assert_eq!(salt, expected_salt);

	// Allocations round-trip; `*_used` fields are reset on read and therefore
	// not compared here.
	let congestion_bucket = &incoming_channel.congestion_bucket;
	assert_eq!(
		congestion_bucket.slots_allocated,
		expected_incoming_channel.congestion_bucket.slots_allocated
	);
	assert_eq!(
		congestion_bucket.liquidity_allocated,
		expected_incoming_channel.congestion_bucket.liquidity_allocated
	);
	let protected_bucket = &incoming_channel.protected_bucket;
	assert_eq!(
		protected_bucket.slots_allocated,
		expected_incoming_channel.protected_bucket.slots_allocated
	);
	assert_eq!(
		protected_bucket.liquidity_allocated,
		expected_incoming_channel.protected_bucket.liquidity_allocated
	);
}
2460+
21922461
#[test]
21932462
fn test_decaying_average_error() {
21942463
let timestamp = 1000;

0 commit comments

Comments
 (0)