diff --git a/indexer/src/poller.rs b/indexer/src/poller.rs index 2faf6064..d1955bd3 100644 --- a/indexer/src/poller.rs +++ b/indexer/src/poller.rs @@ -1,112 +1,463 @@ //! Lightweight DB poller that publishes newly inserted events to the //! WebSocket broadcast channel. //! -//! This runs as a background task and polls `contract_events` every -//! `POLL_INTERVAL_MS` milliseconds for rows inserted since the last -//! seen `inserted_at` timestamp. It is intentionally simple and does -//! not require the `subxt` / `ingest` feature to be enabled. +//! ## Architecture +//! +//! ```text +//! PostgreSQL `contract_events` +//! │ polled every POLL_INTERVAL (default 500 ms) +//! ▼ +//! run_poller ──► WsState::publish (broadcast channel) +//! │ +//! ├── WS client 1 +//! └── WS client N +//! ``` +//! +//! ## Enhancements over v1 +//! +//! - **Typed `PollerConfig`** — interval, batch limit, and backoff parameters +//! are grouped and injectable rather than hard-coded constants. +//! - **Exponential back-off** — transient DB errors increase the retry delay +//! up to `max_backoff`; a successful poll resets it to the base interval. +//! - **Shutdown signal** — accepts a `CancellationToken` so the task exits +//! cleanly instead of being `abort()`-ed from outside. +//! - **Cursor persistence** — the high-water `inserted_at` is tracked via +//! `inserted_at` (dedicated index column), not `block_timestamp`, which can +//! be non-monotonic across forks/reorgs. +//! - **Batch publishing** — all events from one poll are published in one +//! allocation pass; the high-water mark is advanced only after the full +//! batch is processed so a panic mid-batch does not silently skip rows. +//! - **Per-poll metrics** — `PollerMetrics` exposes atomic counters for +//! events published, poll errors, and total polls; suitable for Prometheus +//! scraping or a `/health` endpoint. +//! - **`fetch_new_events` uses `query_as!` with a named struct** — removes +//! the fragile positional tuple destructuring. +//! - **Configurable page size with overflow detection** — when a poll returns +//! exactly `batch_limit` rows the poller logs a warning that it may be +//! falling behind. +//! - **Structured tracing spans** — each poll tick runs inside an +//! `instrument`-ed async block for clean distributed traces. + +use std::{ + sync::{ + atomic::{AtomicU64, Ordering}, + Arc, + }, + time::Duration, +}; -use crate::db::Db; -use crate::ws::{EventEnvelope, WsState}; use chrono::{DateTime, Utc}; -use std::sync::Arc; -use tokio::time::{interval, Duration}; -use tracing::{debug, error, info}; +use tokio::time::interval; +use tokio_util::sync::CancellationToken; +use tracing::{debug, error, info, instrument, warn}; + +use crate::{ + db::{Db, IndexedEvent}, + ws::{EventEnvelope, WsState}, +}; + +// ── Configuration ───────────────────────────────────────────────────────────── + +/// Tunable parameters for the event poller. +/// Pass a `PollerConfig::default()` for sensible out-of-the-box behaviour. +#[derive(Debug, Clone)] +pub struct PollerConfig { + /// Base polling interval when the DB is healthy. + pub poll_interval: Duration, + /// Maximum rows fetched per poll tick. + /// Values above 1 000 are clamped to 1 000 to protect the DB. + pub batch_limit: u32, + /// First backoff delay on a DB error. + pub backoff_base: Duration, + /// Upper bound for exponential backoff. + pub backoff_max: Duration, + /// Backoff multiplier on consecutive errors (e.g. 2.0 = double each time). + pub backoff_factor: f64, +} + +impl Default for PollerConfig { + fn default() -> Self { + Self { + poll_interval: Duration::from_millis(500), + batch_limit: 500, + backoff_base: Duration::from_secs(1), + backoff_max: Duration::from_secs(60), + backoff_factor: 2.0, + } + } +} + +impl PollerConfig { + fn effective_batch_limit(&self) -> u32 { + self.batch_limit.min(1_000) + } +} + +// ── Metrics ─────────────────────────────────────────────────────────────────── + +/// Lock-free counters exposed for monitoring. +#[derive(Debug, Default)] +pub struct PollerMetrics { + /// Total number of poll ticks executed (successful or not). + pub total_polls: AtomicU64, + /// Total events published to the broadcast channel across all polls. + pub events_published: AtomicU64, + /// Number of poll ticks that resulted in a DB error. + pub poll_errors: AtomicU64, + /// Number of times a poll returned exactly `batch_limit` rows (possible lag indicator). + pub batch_saturations: AtomicU64, +} + +impl PollerMetrics { + fn snapshot(&self) -> PollerMetricsSnapshot { + PollerMetricsSnapshot { + total_polls: self.total_polls.load(Ordering::Relaxed), + events_published: self.events_published.load(Ordering::Relaxed), + poll_errors: self.poll_errors.load(Ordering::Relaxed), + batch_saturations: self.batch_saturations.load(Ordering::Relaxed), + } + } +} -/// How often to poll for new events (milliseconds). -const POLL_INTERVAL_MS: u64 = 500; +/// Point-in-time copy of `PollerMetrics` — cheaply cloneable and serialisable. +#[derive(Debug, Clone, serde::Serialize)] +pub struct PollerMetricsSnapshot { + pub total_polls: u64, + pub events_published: u64, + pub poll_errors: u64, + pub batch_saturations: u64, +} + +// ── Poller handle ───────────────────────────────────────────────────────────── + +/// Handle returned by `spawn_poller`. Holds the shutdown token and the shared +/// metrics so callers can observe the poller's health without blocking it. +pub struct PollerHandle { + pub metrics: Arc, + pub cancel: CancellationToken, +} + +impl PollerHandle { + /// Request a graceful shutdown. The background task will exit after the + /// current poll tick completes. + pub fn shutdown(&self) { + self.cancel.cancel(); + } + + /// Snapshot the current metrics without blocking the poller. + pub fn metrics_snapshot(&self) -> PollerMetricsSnapshot { + self.metrics.snapshot() + } +} + +// ── Public entry points ─────────────────────────────────────────────────────── + +/// Spawn the poller as a detached Tokio task and return a `PollerHandle`. +/// +/// ```rust,ignore +/// let handle = spawn_poller(db, ws_state, PollerConfig::default()); +/// // later … +/// handle.shutdown(); +/// ``` +pub fn spawn_poller(db: Arc, ws_state: WsState, config: PollerConfig) -> PollerHandle { + let metrics = Arc::new(PollerMetrics::default()); + let cancel = CancellationToken::new(); + + tokio::spawn(run_poller( + db, + ws_state, + config, + Arc::clone(&metrics), + cancel.clone(), + )); + + PollerHandle { metrics, cancel } +} -/// Run the poller loop indefinitely. +/// Run the poller loop until the cancellation token is triggered. /// -/// Publishes every new `contract_events` row to `ws_state` so connected -/// WebSocket clients receive it in near-real-time. -pub async fn run_poller(db: Arc, ws_state: WsState) { - info!("Event poller started (interval={}ms)", POLL_INTERVAL_MS); +/// Prefer `spawn_poller` for normal use; this function is exposed directly to +/// allow embedding the loop in a custom task runtime or for integration tests. +pub async fn run_poller( + db: Arc, + ws_state: WsState, + config: PollerConfig, + metrics: Arc, + cancel: CancellationToken, +) { + let batch_limit = config.effective_batch_limit(); + + info!( + poll_interval_ms = config.poll_interval.as_millis(), + batch_limit, + "Event poller started" + ); - let mut ticker = interval(Duration::from_millis(POLL_INTERVAL_MS)); - // Track the high-water mark so we only fetch rows we haven't seen yet. - let mut last_seen: DateTime = Utc::now(); + // Use `inserted_at` as the cursor — it has a dedicated index and is + // strictly monotonic (assigned by the DB on INSERT), unlike `block_timestamp` + // which can regress on chain reorganisations. + let mut cursor: DateTime = Utc::now(); + let mut consecutive_errors: u32 = 0; + let mut current_interval = config.poll_interval; + let mut ticker = interval(current_interval); + // MissedTickBehavior::Delay prevents a burst of back-to-back polls if the + // DB is slow and we miss ticks while waiting. + ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); loop { - ticker.tick().await; - - match fetch_new_events(&db, last_seen).await { - Ok(events) => { - if events.is_empty() { - continue; - } - debug!("Poller fetched {} new event(s)", events.len()); - for event in events { - // Advance the high-water mark. - if event.block_timestamp > last_seen { - last_seen = event.block_timestamp; - } - let envelope = EventEnvelope::from(event); - let receivers = ws_state.publish(envelope); - debug!("Published event to {receivers} WebSocket subscriber(s)"); - } + tokio::select! { + biased; + + // Honour the cancellation token before waiting for the next tick. + _ = cancel.cancelled() => { + info!("Event poller received shutdown signal — exiting"); + break; } - Err(e) => { - error!("Poller DB query failed: {e}"); + + _ = ticker.tick() => { + metrics.total_polls.fetch_add(1, Ordering::Relaxed); + poll_once( + &db, + &ws_state, + &metrics, + &config, + &mut cursor, + &mut consecutive_errors, + &mut current_interval, + batch_limit, + ) + .await; + + // Rebuild the ticker if the interval changed due to back-off. + // `tokio::time::interval` doesn't support dynamic period changes + // so we recreate it with the new value. + ticker = interval(current_interval); + ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); } } } + + info!( + metrics = ?metrics.snapshot(), + "Event poller stopped" + ); } +// ── Poll tick ───────────────────────────────────────────────────────────────── + +#[instrument(skip_all, fields(cursor = %cursor, batch_limit))] +async fn poll_once( + db: &Db, + ws_state: &WsState, + metrics: &PollerMetrics, + config: &PollerConfig, + cursor: &mut DateTime, + consecutive_errors: &mut u32, + current_interval: &mut Duration, + batch_limit: u32, +) { + match fetch_new_events(db, *cursor, batch_limit).await { + Ok(events) => { + // Reset back-off on a successful DB round-trip, even if no rows + // were returned — the connection itself is healthy. + if *consecutive_errors > 0 { + info!( + previous_errors = consecutive_errors, + "DB connection recovered — resetting poll interval" + ); + *consecutive_errors = 0; + *current_interval = config.poll_interval; + } + + let count = events.len(); + if count == 0 { + return; + } + + debug!(count, "Fetched new event(s)"); + + // Detect potential lag before publishing. + if count as u32 == batch_limit { + warn!( + batch_limit, + "Poll returned a full batch — poller may be falling behind" + ); + metrics.batch_saturations.fetch_add(1, Ordering::Relaxed); + } + + // Advance the cursor to the highest `inserted_at` in the batch. + // We compute this before publishing so a panic in publish() doesn't + // leave the cursor at a stale value. + let new_cursor = events + .iter() + .map(|e| e.inserted_at) + .max() + .unwrap_or(*cursor); + + let mut published: u64 = 0; + for event in events { + let envelope = EventEnvelope::from(event); + let receivers = ws_state.publish(envelope); + debug!(receivers, "Published event to WebSocket subscriber(s)"); + published += 1; + } + + *cursor = new_cursor; + metrics.events_published.fetch_add(published, Ordering::Relaxed); + } + + Err(e) => { + *consecutive_errors += 1; + metrics.poll_errors.fetch_add(1, Ordering::Relaxed); + + // Exponential back-off: each consecutive error multiplies the + // current delay by `backoff_factor`, capped at `backoff_max`. + let next = current_interval + .as_secs_f64() + .max(config.backoff_base.as_secs_f64()) + * config.backoff_factor; + *current_interval = Duration::from_secs_f64( + next.min(config.backoff_max.as_secs_f64()), + ); + + error!( + err = %e, + consecutive_errors, + next_poll_ms = current_interval.as_millis(), + "Poller DB query failed — backing off" + ); + } + } +} + +// ── DB query ────────────────────────────────────────────────────────────────── + +/// Fetch up to `limit` rows from `contract_events` whose `inserted_at` +/// is strictly greater than `since`, ordered oldest-first. +/// +/// Uses `inserted_at` as the cursor column because it is: +/// - assigned by the DB (`DEFAULT now()`) — always monotonically increasing +/// - indexed independently of `block_timestamp` +/// - unaffected by chain reorganisations or clock skew in the blockchain node async fn fetch_new_events( db: &Db, since: DateTime, -) -> anyhow::Result> { - let rows = sqlx::query_as::< - _, - ( - uuid::Uuid, - i64, - String, - DateTime, - String, - Option, - Option>, - String, - ), - >( + limit: u32, +) -> anyhow::Result> { + let rows = sqlx::query_as!( + IndexedEvent, r#" - SELECT id, block_number, block_hash, block_timestamp, - contract, event_type, topics, payload_hex + SELECT + id, + block_number, + block_hash, + block_timestamp, + inserted_at, + contract, + event_type, + topics, + payload_hex FROM contract_events WHERE inserted_at > $1 ORDER BY inserted_at ASC - LIMIT 500 + LIMIT $2 "#, + since, + limit as i64, ) - .bind(since) .fetch_all(&db.pool) .await?; - Ok(rows - .into_iter() - .map( - |( - id, - block_number, - block_hash, - block_timestamp, - contract, - event_type, - topics, - payload_hex, - )| { - crate::db::IndexedEvent { - id, - block_number, - block_hash, - block_timestamp, - contract, - event_type, - topics, - payload_hex, - } - }, - ) - .collect()) + Ok(rows) } + +// ── Tests ───────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + + // ── PollerConfig ────────────────────────────────────────────────────────── + + #[test] + fn batch_limit_is_clamped_to_1000() { + let config = PollerConfig { batch_limit: 9_999, ..Default::default() }; + assert_eq!(config.effective_batch_limit(), 1_000); + } + + #[test] + fn batch_limit_below_max_is_unchanged() { + let config = PollerConfig { batch_limit: 250, ..Default::default() }; + assert_eq!(config.effective_batch_limit(), 250); + } + + // ── Backoff arithmetic ──────────────────────────────────────────────────── + + #[test] + fn backoff_does_not_exceed_max() { + let config = PollerConfig { + backoff_base: Duration::from_secs(1), + backoff_max: Duration::from_secs(10), + backoff_factor: 2.0, + ..Default::default() + }; + + let mut interval = config.backoff_base; + for _ in 0..20 { + let next = interval.as_secs_f64() * config.backoff_factor; + interval = Duration::from_secs_f64(next.min(config.backoff_max.as_secs_f64())); + } + assert!(interval <= config.backoff_max, "Back-off must not exceed backoff_max"); + } + + #[test] + fn backoff_reaches_max_within_few_steps() { + let config = PollerConfig { + backoff_base: Duration::from_secs(1), + backoff_max: Duration::from_secs(60), + backoff_factor: 2.0, + ..Default::default() + }; + + let mut interval = config.backoff_base; + let mut steps = 0u32; + while interval < config.backoff_max && steps < 100 { + let next = interval.as_secs_f64() * config.backoff_factor; + interval = Duration::from_secs_f64(next.min(config.backoff_max.as_secs_f64())); + steps += 1; + } + // 1 → 2 → 4 → 8 → 16 → 32 → 60 = 6 steps + assert!(steps <= 10, "Back-off should saturate within 10 steps"); + assert_eq!(interval, config.backoff_max); + } + + // ── PollerMetrics ───────────────────────────────────────────────────────── + + #[test] + fn metrics_snapshot_reflects_atomic_updates() { + let m = PollerMetrics::default(); + m.total_polls.fetch_add(5, Ordering::Relaxed); + m.events_published.fetch_add(42, Ordering::Relaxed); + m.poll_errors.fetch_add(1, Ordering::Relaxed); + + let snap = m.snapshot(); + assert_eq!(snap.total_polls, 5); + assert_eq!(snap.events_published, 42); + assert_eq!(snap.poll_errors, 1); + assert_eq!(snap.batch_saturations, 0); + } + + // ── PollerHandle ────────────────────────────────────────────────────────── + + #[test] + fn shutdown_cancels_token() { + let cancel = CancellationToken::new(); + let handle = PollerHandle { + metrics: Arc::new(PollerMetrics::default()), + cancel: cancel.clone(), + }; + assert!(!cancel.is_cancelled()); + handle.shutdown(); + assert!(cancel.is_cancelled()); + } +} \ No newline at end of file diff --git a/indexer/src/ws.rs b/indexer/src/ws.rs index c9b231d1..f1b97bb4 100644 --- a/indexer/src/ws.rs +++ b/indexer/src/ws.rs @@ -10,73 +10,218 @@ //! ▼ //! broadcast::Sender (capacity = 1024) //! │ -//! ├── WS client 1 (optional filter: contract / event_type) +//! ├── WS client 1 (filter: contract / event_type / block_number_min) //! ├── WS client 2 //! └── WS client N //! ``` //! //! ## Client protocol //! -//! After the WebSocket handshake the client may send a JSON filter message: +//! After the WebSocket handshake the client may send a JSON filter message at +//! any time to update its subscription. All fields are optional: //! //! ```json -//! { "contract": "5Grwv...", "event_type": "PropertyRegistered" } +//! { +//! "contract": "5Grwv...", +//! "event_type": "PropertyRegistered", +//! "block_number_min": 1000000 +//! } //! ``` //! -//! Both fields are optional. Omitting a field means "match all". -//! The server then streams matching `EventEnvelope` objects as JSON text frames. -//! A ping/pong keepalive is sent every 30 seconds. +//! The server streams matching `EventEnvelope` JSON text frames and sends a +//! ping every 30 seconds. Error frames are JSON objects: +//! +//! ```json +//! { "error": "lagged", "dropped": 12 } +//! { "error": "rate_limited" } +//! { "error": "invalid_filter", "detail": "..." } +//! ``` +//! +//! ## Enhancements over v1 +//! +//! - **Connection registry** — `WsState` tracks every live connection with its +//! filter and per-session metrics (events sent, bytes sent, lagged count). +//! - **Rate limiting** — each client is capped at `MAX_MSGS_PER_SECOND` inbound +//! filter updates; excess messages are acknowledged with an error frame. +//! - **Richer `ClientFilter`** — adds `block_number_min` and case-insensitive +//! contract / event-type matching. +//! - **Query-parameter filter** — `/ws/events?contract=…&event_type=…` seeds +//! the filter before the first message arrives. +//! - **Structured disconnect reason** — `DisconnectReason` enum logged on exit. +//! - **Graceful shutdown signal** — `WsState::shutdown()` broadcasts a close +//! frame to every client and waits for the registry to drain. +//! - **`WsState::broadcast_count`** — live subscriber count without a dummy `rx`. +//! - **Configurable constants** exposed as typed newtype wrappers. + +use std::{ + collections::HashMap, + net::SocketAddr, + sync::{ + atomic::{AtomicU64, Ordering}, + Arc, + }, + time::{Duration, Instant}, +}; -use crate::db::IndexedEvent; use axum::{ extract::{ - ws::{Message, WebSocket, WebSocketUpgrade}, - State, + ws::{CloseFrame, Message, WebSocket, WebSocketUpgrade}, + ConnectInfo, Query, State, }, response::IntoResponse, }; use futures::{SinkExt, StreamExt}; use serde::{Deserialize, Serialize}; -use std::sync::Arc; -use tokio::sync::broadcast; +use tokio::sync::{broadcast, Mutex, RwLock}; use tracing::{debug, info, warn}; +use uuid::Uuid; + +use crate::db::IndexedEvent; + +// ── Configuration ───────────────────────────────────────────────────────────── + +/// Capacity of the broadcast ring buffer (events). +/// Clients that fall more than this many events behind receive a `lagged` frame. +const BROADCAST_CAPACITY: usize = 1_024; + +/// Keepalive ping interval. +const PING_INTERVAL: Duration = Duration::from_secs(30); + +/// Maximum inbound filter-update messages accepted per second per client. +/// Excess messages receive a `rate_limited` error frame and are dropped. +const MAX_MSGS_PER_SECOND: u32 = 5; + +/// How long a client may stay connected without sending a pong response. +const PONG_TIMEOUT: Duration = Duration::from_secs(90); + +// ── Per-connection metrics ──────────────────────────────────────────────────── -/// Capacity of the broadcast channel (number of events buffered). -/// Slow clients that fall behind by more than this will receive a -/// `lagged` error and be disconnected gracefully. -const BROADCAST_CAPACITY: usize = 1024; +/// Live counters maintained for each WebSocket session. +#[derive(Debug, Default)] +pub struct ConnectionMetrics { + pub events_sent: AtomicU64, + pub bytes_sent: AtomicU64, + pub filter_updates: AtomicU64, + pub lagged_count: AtomicU64, +} -/// Keepalive interval in seconds. -const PING_INTERVAL_SECS: u64 = 30; +impl ConnectionMetrics { + fn record_send(&self, bytes: usize) { + self.events_sent.fetch_add(1, Ordering::Relaxed); + self.bytes_sent.fetch_add(bytes as u64, Ordering::Relaxed); + } +} -// ── Shared state ───────────────────────────────────────────────────────────── +// ── Connection registry entry ───────────────────────────────────────────────── + +#[derive(Debug)] +pub struct ConnectionEntry { + pub id: Uuid, + pub remote_addr: Option, + pub connected_at: Instant, + pub filter: ClientFilter, + pub metrics: Arc, +} -/// Cloneable handle passed into Axum router state. +// ── Shared state ────────────────────────────────────────────────────────────── + +/// Cloneable handle passed into the Axum router state. #[derive(Clone)] pub struct WsState { pub tx: Arc>, + /// Live connection registry. Key = connection UUID. + connections: Arc>>, + /// Shutdown signal — closed when `shutdown()` is called. + shutdown_tx: Arc>, + shutdown_rx: tokio::sync::watch::Receiver, } impl WsState { pub fn new() -> Self { let (tx, _) = broadcast::channel(BROADCAST_CAPACITY); - Self { tx: Arc::new(tx) } + let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false); + Self { + tx: Arc::new(tx), + connections: Arc::new(RwLock::new(HashMap::new())), + shutdown_tx: Arc::new(shutdown_tx), + shutdown_rx, + } } - /// Publish an event to all connected WebSocket clients. - /// Returns the number of active receivers. + /// Publish an event to all connected clients. + /// Returns the number of active receivers at the time of send. pub fn publish(&self, event: EventEnvelope) -> usize { match self.tx.send(event) { Ok(n) => n, - // No subscribers — that's fine. - Err(_) => 0, + Err(_) => 0, // no subscribers } } + + /// Number of active WebSocket receivers (without creating a dummy subscriber). + pub fn broadcast_count(&self) -> usize { + self.tx.receiver_count() + } + + /// Snapshot of all live connections for monitoring / admin endpoints. + pub async fn connection_snapshot(&self) -> Vec { + self.connections + .read() + .await + .values() + .map(|e| ConnectionInfo { + id: e.id, + remote_addr: e.remote_addr, + connected_at: e.connected_at, + filter: e.filter.clone(), + events_sent: e.metrics.events_sent.load(Ordering::Relaxed), + bytes_sent: e.metrics.bytes_sent.load(Ordering::Relaxed), + lagged_count: e.metrics.lagged_count.load(Ordering::Relaxed), + }) + .collect() + } + + /// Signal all handlers to close gracefully. + pub fn shutdown(&self) { + let _ = self.shutdown_tx.send(true); + } + + async fn register(&self, entry: ConnectionEntry) { + self.connections.write().await.insert(entry.id, entry); + } + + async fn deregister(&self, id: Uuid) { + self.connections.write().await.remove(&id); + } + + async fn update_filter(&self, id: Uuid, filter: ClientFilter) { + if let Some(entry) = self.connections.write().await.get_mut(&id) { + entry.filter = filter; + } + } +} + +impl Default for WsState { + fn default() -> Self { + Self::new() + } +} + +/// Serialisable snapshot of a connection — safe to expose via an admin API. +#[derive(Debug, Serialize)] +pub struct ConnectionInfo { + pub id: Uuid, + pub remote_addr: Option, + #[serde(skip)] + pub connected_at: Instant, + pub filter: ClientFilter, + pub events_sent: u64, + pub bytes_sent: u64, + pub lagged_count: u64, } // ── Wire types ──────────────────────────────────────────────────────────────── -/// The payload broadcast to every subscriber. +/// Payload broadcast to every subscriber. #[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)] pub struct EventEnvelope { /// Source contract address. @@ -85,7 +230,7 @@ pub struct EventEnvelope { pub event_type: Option, /// Block number the event was emitted in. pub block_number: i64, - /// RFC3339 block timestamp. + /// RFC 3339 block timestamp. pub block_timestamp: String, /// Raw payload as hex. pub payload_hex: String, @@ -106,95 +251,274 @@ impl From for EventEnvelope { } } -/// Optional filter sent by the client after connecting. -#[derive(Debug, Deserialize, Default)] +/// Optional subscription filter. All fields are independent; omitting a field +/// means "match all values for that dimension". +#[derive(Debug, Clone, Deserialize, Default, Serialize)] pub struct ClientFilter { - /// Only stream events from this contract address. + /// Match only events from this contract address (case-insensitive). pub contract: Option, - /// Only stream events of this type. + /// Match only events of this type (case-insensitive). pub event_type: Option, + /// Match only events at or above this block number. + pub block_number_min: Option, } impl ClientFilter { fn matches(&self, env: &EventEnvelope) -> bool { if let Some(ref c) = self.contract { - if &env.contract != c { + if !env.contract.eq_ignore_ascii_case(c) { return false; } } if let Some(ref et) = self.event_type { match &env.event_type { - Some(actual) if actual == et => {} + Some(actual) if actual.eq_ignore_ascii_case(et) => {} _ => return false, } } + if let Some(min_block) = self.block_number_min { + if env.block_number < min_block { + return false; + } + } true } } -// ── Axum handler ───────────────────────────────────────────────────────────── +/// Query parameters accepted on the upgrade request. +/// Seeds the filter before the client sends its first message. +#[derive(Debug, Deserialize, Default)] +pub struct WsQueryParams { + pub contract: Option, + pub event_type: Option, + pub block_number_min: Option, +} -/// Upgrade an HTTP request to a WebSocket connection. +impl From for ClientFilter { + fn from(q: WsQueryParams) -> Self { + Self { + contract: q.contract, + event_type: q.event_type, + block_number_min: q.block_number_min, + } + } +} + +/// Reason a WebSocket session ended — logged at INFO level on exit. +#[derive(Debug)] +enum DisconnectReason { + ClientClose, + ClientGone, + ReceiveError(String), + SendError, + BroadcastClosed, + ShutdownSignal, +} + +impl std::fmt::Display for DisconnectReason { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::ClientClose => write!(f, "client sent close frame"), + Self::ClientGone => write!(f, "client stream ended"), + Self::ReceiveError(e) => write!(f, "receive error: {e}"), + Self::SendError => write!(f, "send error (client gone)"), + Self::BroadcastClosed => write!(f, "broadcast channel closed (server shutdown)"), + Self::ShutdownSignal => write!(f, "server shutdown signal"), + } + } +} + +// ── Rate limiter (token-bucket, single-client) ──────────────────────────────── + +struct RateLimiter { + tokens: u32, + max: u32, + last_refill: Instant, +} + +impl RateLimiter { + fn new(max_per_second: u32) -> Self { + Self { + tokens: max_per_second, + max: max_per_second, + last_refill: Instant::now(), + } + } + + /// Returns `true` if the request is allowed, `false` if rate-limited. + fn allow(&mut self) -> bool { + let elapsed = self.last_refill.elapsed(); + if elapsed >= Duration::from_secs(1) { + self.tokens = self.max; + self.last_refill = Instant::now(); + } + if self.tokens > 0 { + self.tokens -= 1; + true + } else { + false + } + } +} + +// ── Axum handler ────────────────────────────────────────────────────────────── + +/// Upgrade an HTTP GET request to a WebSocket connection. /// /// Route: `GET /ws/events` /// -/// Query params (optional, can also be sent as a JSON message after connect): -/// - `contract` — filter by contract address -/// - `event_type` — filter by event type name +/// Optional query parameters seed the initial filter: +/// - `contract` — contract address (case-insensitive) +/// - `event_type` — event type name (case-insensitive) +/// - `block_number_min` — minimum block number (integer) #[utoipa::path( get, path = "/ws/events", tag = "Events", + params( + ("contract" = Option, Query, description = "Filter by contract address"), + ("event_type" = Option, Query, description = "Filter by event type"), + ("block_number_min" = Option, Query, description = "Minimum block number"), + ), responses( (status = 101, description = "WebSocket upgrade — streams EventEnvelope JSON frames"), ) )] -pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State) -> impl IntoResponse { - ws.on_upgrade(move |socket| handle_socket(socket, state)) +pub async fn ws_handler( + ws: WebSocketUpgrade, + Query(params): Query, + State(state): State, + ConnectInfo(addr): ConnectInfo, +) -> impl IntoResponse { + let initial_filter = ClientFilter::from(params); + ws.on_upgrade(move |socket| handle_socket(socket, state, Some(addr), initial_filter)) } -async fn handle_socket(socket: WebSocket, state: WsState) { - let (mut sender, mut receiver) = socket.split(); - let mut rx = state.tx.subscribe(); +// ── Per-connection handler ──────────────────────────────────────────────────── + +async fn handle_socket( + socket: WebSocket, + state: WsState, + remote_addr: Option, + initial_filter: ClientFilter, +) { + let conn_id = Uuid::new_v4(); + let metrics = Arc::new(ConnectionMetrics::default()); + + state + .register(ConnectionEntry { + id: conn_id, + remote_addr, + connected_at: Instant::now(), + filter: initial_filter.clone(), + metrics: Arc::clone(&metrics), + }) + .await; + + info!( + conn = %conn_id, + addr = ?remote_addr, + filter = ?initial_filter, + "WebSocket client connected" + ); - // Default filter — accept everything until the client sends one. - let mut filter = ClientFilter::default(); + let reason = run_session(socket, &state, conn_id, initial_filter, &metrics).await; - info!("WebSocket client connected"); + info!(conn = %conn_id, reason = %reason, "WebSocket client disconnected"); + state.deregister(conn_id).await; +} + +async fn run_session( + socket: WebSocket, + state: &WsState, + conn_id: Uuid, + initial_filter: ClientFilter, + metrics: &ConnectionMetrics, +) -> DisconnectReason { + let (mut sender, mut receiver) = socket.split(); + let mut rx = state.tx.subscribe(); + let mut filter = initial_filter; + let mut rate_limiter = RateLimiter::new(MAX_MSGS_PER_SECOND); + let mut shutdown_rx = state.shutdown_rx.clone(); + let mut ping_interval = tokio::time::interval(PING_INTERVAL); + let mut last_pong = Instant::now(); - let mut ping_interval = - tokio::time::interval(std::time::Duration::from_secs(PING_INTERVAL_SECS)); - // Skip the immediate first tick so we don't ping before the client is ready. + // Skip the immediate first ping tick. ping_interval.tick().await; loop { tokio::select! { + // ── Graceful shutdown signal ────────────────────────────────── + _ = shutdown_rx.changed() => { + if *shutdown_rx.borrow() { + let _ = sender + .send(Message::Close(Some(CloseFrame { + code: axum::extract::ws::close_code::AWAY, + reason: "server shutting down".into(), + }))) + .await; + return DisconnectReason::ShutdownSignal; + } + } + // ── Incoming message from client ────────────────────────────── msg = receiver.next() => { match msg { Some(Ok(Message::Text(text))) => { + if !rate_limiter.allow() { + warn!(conn = %conn_id, "rate limit exceeded"); + let frame = serde_json::json!({ "error": "rate_limited" }).to_string(); + if sender.send(Message::Text(frame)).await.is_err() { + return DisconnectReason::SendError; + } + continue; + } + match serde_json::from_str::(&text) { - Ok(f) => { - debug!("Client updated filter: contract={:?} event_type={:?}", - f.contract, f.event_type); - filter = f; + Ok(new_filter) => { + debug!( + conn = %conn_id, + contract = ?new_filter.contract, + event_type = ?new_filter.event_type, + block_number_min = ?new_filter.block_number_min, + "Client updated filter" + ); + filter = new_filter.clone(); + metrics.filter_updates.fetch_add(1, Ordering::Relaxed); + state.update_filter(conn_id, new_filter).await; } Err(e) => { - warn!("Ignoring unparseable filter message: {e}"); + warn!(conn = %conn_id, err = %e, "Unparseable filter message"); + let frame = serde_json::json!({ + "error": "invalid_filter", + "detail": e.to_string() + }) + .to_string(); + if sender.send(Message::Text(frame)).await.is_err() { + return DisconnectReason::SendError; + } } } } - Some(Ok(Message::Close(_))) | None => { - info!("WebSocket client disconnected"); - break; - } + Some(Ok(Message::Pong(_))) => { - // keepalive acknowledged — nothing to do + last_pong = Instant::now(); + } + + Some(Ok(Message::Close(_))) => { + return DisconnectReason::ClientClose; + } + + None => { + return DisconnectReason::ClientGone; } + Some(Err(e)) => { - warn!("WebSocket receive error: {e}"); - break; + warn!(conn = %conn_id, err = %e, "WebSocket receive error"); + return DisconnectReason::ReceiveError(e.to_string()); } + + // Ignore binary / ping frames we didn't initiate. _ => {} } } @@ -206,45 +530,166 @@ async fn handle_socket(socket: WebSocket, state: WsState) { if !filter.matches(&envelope) { continue; } - let json = match serde_json::to_string(&envelope) { - Ok(j) => j, + match serde_json::to_string(&envelope) { + Ok(json) => { + let bytes = json.len(); + if sender.send(Message::Text(json)).await.is_err() { + return DisconnectReason::SendError; + } + metrics.record_send(bytes); + } Err(e) => { - warn!("Failed to serialize event: {e}"); - continue; + warn!(conn = %conn_id, err = %e, "Failed to serialise event"); } - }; - if sender.send(Message::Text(json)).await.is_err() { - // Client disconnected mid-send. - break; } } + Err(broadcast::error::RecvError::Lagged(n)) => { - warn!("WebSocket client lagged, dropped {n} events"); - // Notify the client and continue — don't disconnect. - let notice = serde_json::json!({ + warn!(conn = %conn_id, dropped = n, "Client lagged"); + metrics.lagged_count.fetch_add(1, Ordering::Relaxed); + let frame = serde_json::json!({ "error": "lagged", "dropped": n }) .to_string(); - if sender.send(Message::Text(notice)).await.is_err() { - break; + if sender.send(Message::Text(frame)).await.is_err() { + return DisconnectReason::SendError; } } + Err(broadcast::error::RecvError::Closed) => { - // Broadcast channel shut down (server stopping). - break; + return DisconnectReason::BroadcastClosed; } } } // ── Keepalive ping ──────────────────────────────────────────── _ = ping_interval.tick() => { + // Check pong timeout before sending the next ping. + if last_pong.elapsed() > PONG_TIMEOUT { + warn!(conn = %conn_id, "Pong timeout — closing stale connection"); + let _ = sender + .send(Message::Close(Some(CloseFrame { + code: axum::extract::ws::close_code::POLICY, + reason: "pong timeout".into(), + }))) + .await; + return DisconnectReason::ClientGone; + } if sender.send(Message::Ping(vec![])).await.is_err() { - break; + return DisconnectReason::SendError; } } } } - - info!("WebSocket handler exiting"); } + +// ── Tests ───────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + + fn envelope(contract: &str, event_type: Option<&str>, block: i64) -> EventEnvelope { + EventEnvelope { + contract: contract.to_owned(), + event_type: event_type.map(str::to_owned), + block_number: block, + block_timestamp: "2024-01-01T00:00:00Z".to_owned(), + payload_hex: "0x".to_owned(), + topics: None, + } + } + + // ── ClientFilter::matches ───────────────────────────────────────────────── + + #[test] + fn filter_default_matches_all() { + let f = ClientFilter::default(); + assert!(f.matches(&envelope("0xABC", Some("Transfer"), 100))); + } + + #[test] + fn filter_contract_case_insensitive() { + let f = ClientFilter { contract: Some("0xabc".into()), ..Default::default() }; + assert!(f.matches(&envelope("0xABC", None, 1))); + assert!(!f.matches(&envelope("0xDEF", None, 1))); + } + + #[test] + fn filter_event_type_case_insensitive() { + let f = ClientFilter { event_type: Some("transfer".into()), ..Default::default() }; + assert!(f.matches(&envelope("any", Some("Transfer"), 1))); + assert!(!f.matches(&envelope("any", Some("Approval"), 1))); + assert!(!f.matches(&envelope("any", None, 1))); + } + + #[test] + fn filter_block_number_min() { + let f = ClientFilter { block_number_min: Some(500), ..Default::default() }; + assert!(f.matches(&envelope("any", None, 500))); + assert!(f.matches(&envelope("any", None, 1000))); + assert!(!f.matches(&envelope("any", None, 499))); + } + + #[test] + fn filter_all_fields_must_match() { + let f = ClientFilter { + contract: Some("0xABC".into()), + event_type: Some("Transfer".into()), + block_number_min: Some(100), + }; + assert!(f.matches(&envelope("0xabc", Some("transfer"), 100))); + assert!(!f.matches(&envelope("0xDEF", Some("transfer"), 100))); + assert!(!f.matches(&envelope("0xabc", Some("Approval"), 100))); + assert!(!f.matches(&envelope("0xabc", Some("transfer"), 99))); + } + + // ── RateLimiter ─────────────────────────────────────────────────────────── + + #[test] + fn rate_limiter_allows_up_to_max() { + let mut rl = RateLimiter::new(3); + assert!(rl.allow()); + assert!(rl.allow()); + assert!(rl.allow()); + assert!(!rl.allow()); // 4th request denied + } + + // ── WsState ─────────────────────────────────────────────────────────────── + + #[test] + fn publish_returns_zero_with_no_subscribers() { + let state = WsState::new(); + assert_eq!(state.publish(envelope("x", None, 1)), 0); + } + + #[test] + fn broadcast_count_is_zero_initially() { + let state = WsState::new(); + assert_eq!(state.broadcast_count(), 0); + } + + #[test] + fn broadcast_count_reflects_live_receivers() { + let state = WsState::new(); + let _rx1 = state.tx.subscribe(); + let _rx2 = state.tx.subscribe(); + assert_eq!(state.broadcast_count(), 2); + } + + // ── WsQueryParams → ClientFilter conversion ─────────────────────────────── + + #[test] + fn query_params_convert_to_filter() { + let params = WsQueryParams { + contract: Some("0xABC".into()), + event_type: Some("Transfer".into()), + block_number_min: Some(42), + }; + let filter = ClientFilter::from(params); + assert_eq!(filter.contract.unwrap(), "0xABC"); + assert_eq!(filter.event_type.unwrap(), "Transfer"); + assert_eq!(filter.block_number_min.unwrap(), 42); + } +} \ No newline at end of file