diff --git a/src/payment/asynchronous/om_mailbox.rs b/src/payment/asynchronous/om_mailbox.rs index b7b789cdb4..be9d0099e4 100644 --- a/src/payment/asynchronous/om_mailbox.rs +++ b/src/payment/asynchronous/om_mailbox.rs @@ -19,22 +19,15 @@ impl OnionMessageMailbox { pub(crate) fn onion_message_intercepted(&self, peer_node_id: PublicKey, message: OnionMessage) { let mut map = self.map.lock().expect("lock"); + if !map.contains_key(&peer_node_id) && map.len() >= Self::MAX_PEERS { + return; + } + let queue = map.entry(peer_node_id).or_insert_with(VecDeque::new); if queue.len() >= Self::MAX_MESSAGES_PER_PEER { queue.pop_front(); } queue.push_back(message); - - // Enforce a peers limit. If exceeded, evict the peer with the longest queue. - if map.len() > Self::MAX_PEERS { - let peer_to_remove = map - .iter() - .max_by_key(|(_, queue)| queue.len()) - .map(|(peer, _)| *peer) - .expect("map is non-empty"); - - map.remove(&peer_to_remove); - } } pub(crate) fn onion_message_peer_connected( @@ -68,18 +61,66 @@ mod tests { fn onion_message_mailbox() { let mailbox = OnionMessageMailbox::new(); - let secp = Secp256k1::new(); - let sk_bytes = [12; 32]; - let sk = SecretKey::from_slice(&sk_bytes).unwrap(); - let peer_node_id = PublicKey::from_secret_key(&secp, &sk); + let peer_node_id = peer_node_id(12); + let message = onion_message(13); + mailbox.onion_message_intercepted(peer_node_id, message.clone()); + + let messages = mailbox.onion_message_peer_connected(peer_node_id); + assert_eq!(messages.len(), 1); + assert_eq!(messages[0], message); + + assert!(mailbox.is_empty()); + + let messages = mailbox.onion_message_peer_connected(peer_node_id); + assert_eq!(messages.len(), 0); + } + + #[test] + fn onion_message_mailbox_keeps_existing_peer_at_capacity() { + let mailbox = OnionMessageMailbox::new(); + let victim = peer_node_id(1); + + for seed in 0..OnionMessageMailbox::MAX_MESSAGES_PER_PEER { + mailbox.onion_message_intercepted(victim, onion_message(seed as u64 + 1)); + } + + for peer in 2..(OnionMessageMailbox::MAX_PEERS as u64 + 2) { + mailbox.onion_message_intercepted(peer_node_id(peer), onion_message(peer)); + } - let blinding_sk = SecretKey::from_slice(&[13; 32]).unwrap(); - let blinding_point = PublicKey::from_secret_key(&secp, &blinding_sk); + let messages = mailbox.onion_message_peer_connected(victim); + assert_eq!(messages.len(), OnionMessageMailbox::MAX_MESSAGES_PER_PEER); + } + + #[test] + fn onion_message_mailbox_drops_new_peer_when_full() { + let mailbox = OnionMessageMailbox::new(); + + for peer in 1..=OnionMessageMailbox::MAX_PEERS as u64 { + mailbox.onion_message_intercepted(peer_node_id(peer), onion_message(peer)); + } + + let new_peer = peer_node_id(OnionMessageMailbox::MAX_PEERS as u64 + 1); + mailbox.onion_message_intercepted(new_peer, onion_message(1)); + assert!(mailbox.onion_message_peer_connected(new_peer).is_empty()); + + let existing_peer = peer_node_id(1); + mailbox.onion_message_intercepted(existing_peer, onion_message(2)); + assert_eq!(mailbox.onion_message_peer_connected(existing_peer).len(), 2); + } + + fn peer_node_id(seed: u64) -> PublicKey { + let secp = Secp256k1::new(); + let sk = secret_key(seed); + PublicKey::from_secret_key(&secp, &sk) + } - let message_sk = SecretKey::from_slice(&[13; 32]).unwrap(); - let message_point = PublicKey::from_secret_key(&secp, &message_sk); + fn onion_message(seed: u64) -> lightning::ln::msgs::OnionMessage { + let secp = Secp256k1::new(); + let blinding_point = PublicKey::from_secret_key(&secp, &secret_key(seed)); + let message_point = PublicKey::from_secret_key(&secp, &secret_key(seed + 1)); - let message = lightning::ln::msgs::OnionMessage { + lightning::ln::msgs::OnionMessage { blinding_point, onion_routing_packet: onion_message::packet::Packet { version: 0, @@ -87,16 +128,12 @@ mod tests { hop_data: vec![1, 2, 3], hmac: [0; 32], }, - }; - mailbox.onion_message_intercepted(peer_node_id, message.clone()); - - let messages = mailbox.onion_message_peer_connected(peer_node_id); - assert_eq!(messages.len(), 1); - assert_eq!(messages[0], message); - - assert!(mailbox.is_empty()); + } + } - let messages = mailbox.onion_message_peer_connected(peer_node_id); - assert_eq!(messages.len(), 0); + fn secret_key(seed: u64) -> SecretKey { + let mut bytes = [0; 32]; + bytes[24..].copy_from_slice(&seed.to_be_bytes()); + SecretKey::from_slice(&bytes).unwrap() } } diff --git a/src/payment/asynchronous/rate_limiter.rs b/src/payment/asynchronous/rate_limiter.rs index bf12508927..aa257a52bd 100644 --- a/src/payment/asynchronous/rate_limiter.rs +++ b/src/payment/asynchronous/rate_limiter.rs @@ -14,8 +14,8 @@ use std::time::{Duration, Instant}; /// and the max idle duration. /// /// For every passing of the refill interval, one token is added to the bucket, up to the maximum capacity. When the -/// bucket has remained at the maximum capacity for longer than the max idle duration, it is removed to prevent memory -/// leakage. +/// bucket has remained unused for longer than the max idle duration, it is removed to prevent +/// memory leakage. pub(crate) struct RateLimiter { users: HashMap, Bucket>, capacity: u32, @@ -28,6 +28,7 @@ const MAX_USERS: usize = 10_000; struct Bucket { tokens: u32, last_refill: Instant, + last_seen: Instant, } impl RateLimiter { @@ -43,20 +44,22 @@ impl RateLimiter { if is_new_user { self.garbage_collect(self.max_idle); if self.users.len() >= MAX_USERS { - return false; + self.evict_least_recently_seen(); } } - let bucket = self - .users - .entry(user_id.to_vec()) - .or_insert(Bucket { tokens: self.capacity, last_refill: now }); + let bucket = self.users.entry(user_id.to_vec()).or_insert(Bucket { + tokens: self.capacity, + last_refill: now, + last_seen: now, + }); + bucket.last_seen = now; let elapsed = now.duration_since(bucket.last_refill); let tokens_to_add = (elapsed.as_secs_f64() / self.refill_interval.as_secs_f64()) as u32; if tokens_to_add > 0 { - bucket.tokens = (bucket.tokens + tokens_to_add).min(self.capacity); + bucket.tokens = bucket.tokens.saturating_add(tokens_to_add).min(self.capacity); bucket.last_refill = now; } @@ -72,7 +75,18 @@ impl RateLimiter { fn garbage_collect(&mut self, max_idle: Duration) { let now = Instant::now(); - self.users.retain(|_, bucket| now.duration_since(bucket.last_refill) < max_idle); + self.users.retain(|_, bucket| now.duration_since(bucket.last_seen) < max_idle); + } + + fn evict_least_recently_seen(&mut self) { + if let Some(user_to_remove) = self + .users + .iter() + .min_by_key(|(_, bucket)| bucket.last_seen) + .map(|(user, _)| user.clone()) + { + self.users.remove(&user_to_remove); + } } } @@ -99,4 +113,16 @@ mod tests { assert!(rate_limiter.allow(b"user1")); assert!(rate_limiter.allow(b"user2")); } + + #[test] + fn rate_limiter_admits_new_user_at_capacity() { + let mut rate_limiter = + RateLimiter::new(3, Duration::from_millis(100), Duration::from_secs(600)); + + for user in 0..super::MAX_USERS { + assert!(rate_limiter.allow(&user.to_be_bytes())); + } + + assert!(rate_limiter.allow(b"legit")); + } }