diff --git a/services/api/src/audit.rs b/services/api/src/audit.rs index 62684c7..e839301 100644 --- a/services/api/src/audit.rs +++ b/services/api/src/audit.rs @@ -90,7 +90,10 @@ impl AuditLogger { Ok(id) } - /// Query audit log entries with filters + /// Query audit log entries with filters. + /// + /// Uses `sqlx::QueryBuilder` so every filter value is bound as a typed + /// parameter — the SQL string never contains user-supplied data directly. pub async fn query( &self, actor: Option<&str>, @@ -101,52 +104,7 @@ impl AuditLogger { limit: i64, offset: i64, ) -> anyhow::Result> { - let mut query = String::from( - r#" - SELECT id, timestamp, actor, actor_ip, action, resource_type, resource_id, - details, status, error_message, request_id, user_agent - FROM audit_log - WHERE 1=1 - "#, - ); - - let mut bind_count = 0; - let mut bindings: Vec + Send>> = Vec::new(); - - if let Some(a) = actor { - bind_count += 1; - query.push_str(&format!(" AND actor = ${}", bind_count)); - } - - if let Some(a) = action { - bind_count += 1; - query.push_str(&format!(" AND action = ${}", bind_count)); - } - - if let Some(rt) = resource_type { - bind_count += 1; - query.push_str(&format!(" AND resource_type = ${}", bind_count)); - } - - if let Some(f) = from { - bind_count += 1; - query.push_str(&format!(" AND timestamp >= ${}", bind_count)); - } - - if let Some(t) = to { - bind_count += 1; - query.push_str(&format!(" AND timestamp <= ${}", bind_count)); - } - - query.push_str(" ORDER BY timestamp DESC"); - - bind_count += 1; - query.push_str(&format!(" LIMIT ${}", bind_count)); - - bind_count += 1; - query.push_str(&format!(" OFFSET ${}", bind_count)); - - let mut q = sqlx::query_as::<_, ( + type Row = ( i64, DateTime, String, @@ -159,26 +117,34 @@ impl AuditLogger { Option, Option, Option, - )>(&query); + ); + + let mut qb = sqlx::QueryBuilder::::new( + "SELECT id, timestamp, actor, actor_ip, action, resource_type, resource_id, \ + details, status, error_message, request_id, user_agent \ + FROM audit_log WHERE 1=1", + ); if let Some(a) = actor { - q = q.bind(a); + qb.push(" AND actor = ").push_bind(a); } if let Some(a) = action { - q = q.bind(a); + qb.push(" AND action = ").push_bind(a); } if let Some(rt) = resource_type { - q = q.bind(rt); + qb.push(" AND resource_type = ").push_bind(rt); } if let Some(f) = from { - q = q.bind(f); + qb.push(" AND timestamp >= ").push_bind(f); } if let Some(t) = to { - q = q.bind(t); + qb.push(" AND timestamp <= ").push_bind(t); } - q = q.bind(limit).bind(offset); - let rows = q.fetch_all(&self.pool).await?; + qb.push(" ORDER BY timestamp DESC LIMIT ").push_bind(limit); + qb.push(" OFFSET ").push_bind(offset); + + let rows: Vec = qb.build_query_as().fetch_all(&self.pool).await?; Ok(rows .into_iter() @@ -196,26 +162,24 @@ impl AuditLogger { error_message, request_id, user_agent, - )| { - AuditLogEntry { - id: Some(id), - timestamp, - actor, - actor_ip: actor_ip_str.and_then(|s| s.parse().ok()), - action, - resource_type, - resource_id, - details, - status: match status.as_str() { - "success" => AuditStatus::Success, - "failure" => AuditStatus::Failure, - "partial" => AuditStatus::Partial, - _ => AuditStatus::Success, - }, - error_message, - request_id, - user_agent, - } + )| AuditLogEntry { + id: Some(id), + timestamp, + actor, + actor_ip: actor_ip_str.and_then(|s| s.parse().ok()), + action, + resource_type, + resource_id, + details, + status: match status.as_str() { + "success" => AuditStatus::Success, + "failure" => AuditStatus::Failure, + "partial" => AuditStatus::Partial, + _ => AuditStatus::Success, + }, + error_message, + request_id, + user_agent, }, ) .collect()) @@ -257,6 +221,70 @@ pub struct AuditStatistics { pub failed: i64, } +#[cfg(test)] +mod tests { + use super::*; + + fn make_entry(actor: &str, action: &str, status: AuditStatus) -> AuditLogEntry { + AuditLogEntry { + id: None, + timestamp: Utc::now(), + actor: actor.to_string(), + actor_ip: None, + action: action.to_string(), + resource_type: "market".to_string(), + resource_id: None, + details: None, + status, + error_message: None, + request_id: None, + user_agent: None, + } + } + + #[test] + fn audit_status_display_success() { + assert_eq!(AuditStatus::Success.to_string(), "success"); + } + + #[test] + fn audit_status_display_failure() { + assert_eq!(AuditStatus::Failure.to_string(), "failure"); + } + + #[test] + fn audit_status_display_partial() { + assert_eq!(AuditStatus::Partial.to_string(), "partial"); + } + + #[test] + fn create_audit_entry_sets_expected_fields() { + let entry = create_audit_entry( + "api_key_123".to_string(), + None, + "resolve_market".to_string(), + "market".to_string(), + Some("42".to_string()), + None, + None, + None, + ); + assert_eq!(entry.actor, "api_key_123"); + assert_eq!(entry.action, "resolve_market"); + assert_eq!(entry.resource_type, "market"); + assert_eq!(entry.resource_id, Some("42".to_string())); + assert!(matches!(entry.status, AuditStatus::Success)); + assert!(entry.id.is_none()); + } + + #[test] + fn make_entry_helper_sets_status() { + let e = make_entry("admin", "delete", AuditStatus::Failure); + assert!(matches!(e.status, AuditStatus::Failure)); + assert_eq!(e.actor, "admin"); + } +} + /// Helper to create audit log entry from request context pub fn create_audit_entry( actor: String, diff --git a/services/api/src/db.rs b/services/api/src/db.rs index 37cd084..c97a98f 100644 --- a/services/api/src/db.rs +++ b/services/api/src/db.rs @@ -14,7 +14,14 @@ use crate::{ /// Errors that can be returned by [`Database`] methods. #[derive(Debug)] pub enum DbError { + /// A query exceeded the per-operation timeout. Timeout, + /// The connection pool had no connections available within the acquire timeout. + PoolExhausted, + /// A database constraint was violated (unique, foreign-key, not-null, check). + /// The inner string is the database error message for logging. + ConstraintViolation(String), + /// Any other database error. Other(anyhow::Error), } @@ -22,6 +29,10 @@ impl std::fmt::Display for DbError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { DbError::Timeout => write!(f, "database query timed out"), + DbError::PoolExhausted => write!(f, "database connection pool exhausted"), + DbError::ConstraintViolation(msg) => { + write!(f, "database constraint violation: {msg}") + } DbError::Other(e) => write!(f, "{e}"), } } @@ -31,7 +42,19 @@ impl std::error::Error for DbError {} impl From for DbError { fn from(e: sqlx::Error) -> Self { - DbError::Other(anyhow::Error::from(e)) + match &e { + sqlx::Error::PoolTimedOut => DbError::PoolExhausted, + sqlx::Error::Database(db_err) => { + // PostgreSQL constraint violation SQLSTATE codes start with "23" + // (23000 integrity constraint, 23505 unique violation, etc.). + if db_err.code().map(|c| c.starts_with("23")).unwrap_or(false) { + DbError::ConstraintViolation(db_err.message().to_string()) + } else { + DbError::Other(anyhow::Error::from(e)) + } + } + _ => DbError::Other(anyhow::Error::from(e)), + } } } @@ -725,3 +748,40 @@ impl Database { Ok(count > 0) } } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn db_error_timeout_display() { + let e = DbError::Timeout; + assert_eq!(e.to_string(), "database query timed out"); + } + + #[test] + fn db_error_pool_exhausted_display() { + let e = DbError::PoolExhausted; + assert_eq!(e.to_string(), "database connection pool exhausted"); + } + + #[test] + fn db_error_constraint_violation_display() { + let e = DbError::ConstraintViolation("duplicate key value".to_string()); + assert!(e.to_string().contains("constraint violation")); + assert!(e.to_string().contains("duplicate key value")); + } + + #[test] + fn from_sqlx_pool_timed_out_maps_to_pool_exhausted() { + let e = DbError::from(sqlx::Error::PoolTimedOut); + assert!(matches!(e, DbError::PoolExhausted)); + } + + #[test] + fn from_sqlx_other_maps_to_other() { + let e = DbError::from(sqlx::Error::RowNotFound); + assert!(matches!(e, DbError::Other(_))); + } +} diff --git a/services/api/src/email/queue.rs b/services/api/src/email/queue.rs index e38c3bc..44c78e0 100644 --- a/services/api/src/email/queue.rs +++ b/services/api/src/email/queue.rs @@ -267,7 +267,15 @@ impl EmailQueue { .collect() } + /// Minimum delay (seconds) before a requeued dead-letter job is eligible + /// for processing. Prevents immediate re-failure loops on persistent errors. + const DEAD_LETTER_REQUEUE_DELAY_SECS: i64 = 60; + /// Move a job from the dead-letter set back to the main queue for reprocessing. + /// + /// The job is scheduled `DEAD_LETTER_REQUEUE_DELAY_SECS` seconds in the future + /// so a persistent failure does not cause a tight retry loop. The attempts counter + /// is also reset to 0 so the job gets its full retry budget again. pub async fn requeue_dead_letter(&self, job_id: Uuid) -> Result { let mut conn = self.cache.get_connection().await?; @@ -280,18 +288,28 @@ impl EmailQueue { return Ok(false); } - // Reset DB status so the worker will pick it up again. + // Reset attempts to 0 so the job gets its full retry budget. + self.db + .email_update_job_attempts(job_id, 0, None) + .await?; + + // Reset status to pending. self.db .email_update_job_status(job_id, crate::email::types::EmailJobStatus::Pending.as_str(), None) .await?; - let score = chrono::Utc::now().timestamp() as f64; + // Schedule processing after the cooling-off delay to prevent tight loops. + let eligible_at = chrono::Utc::now().timestamp() + Self::DEAD_LETTER_REQUEUE_DELAY_SECS; let _: () = conn - .zadd(EMAIL_QUEUE_KEY, job_id.to_string(), score) + .zadd(EMAIL_QUEUE_KEY, job_id.to_string(), eligible_at as f64) .await .context("Failed to re-enqueue dead-letter job")?; - tracing::info!("Requeued dead-letter email job: {}", job_id); + tracing::info!( + job_id = %job_id, + delay_secs = Self::DEAD_LETTER_REQUEUE_DELAY_SECS, + "Requeued dead-letter email job with cooling-off delay" + ); Ok(true) } @@ -518,6 +536,12 @@ pub struct QueueStats { mod tests { use super::*; + #[test] + fn dead_letter_requeue_delay_is_positive() { + assert!(EmailQueue::DEAD_LETTER_REQUEUE_DELAY_SECS > 0, + "cooling-off delay must be positive to prevent immediate re-failure loops"); + } + /// Test that recover_orphaned_jobs correctly identifies stale jobs. /// /// Acceptance criteria for #472: diff --git a/services/api/src/handlers.rs b/services/api/src/handlers.rs index 1863a93..bbea700 100644 --- a/services/api/src/handlers.rs +++ b/services/api/src/handlers.rs @@ -84,8 +84,17 @@ impl IntoResponse for ApiError { fn into_api_error(err: anyhow::Error) -> ApiError { if let Some(db_err) = err.downcast_ref::() { - if matches!(db_err, DbError::Timeout) { - return ApiError::service_unavailable("database query timed out"); + match db_err { + DbError::Timeout => { + return ApiError::service_unavailable("database query timed out"); + } + DbError::PoolExhausted => { + return ApiError::service_unavailable("database connection pool exhausted"); + } + DbError::ConstraintViolation(msg) => { + return ApiError::conflict(msg.clone()); + } + DbError::Other(_) => {} } } ApiError::internal(err) diff --git a/services/api/src/metrics.rs b/services/api/src/metrics.rs index 8ed873b..61f800e 100644 --- a/services/api/src/metrics.rs +++ b/services/api/src/metrics.rs @@ -17,6 +17,7 @@ pub struct Metrics { db_pool_connections_active: IntGaugeVec, db_pool_connections_idle: IntGaugeVec, db_pool_acquire_duration: HistogramVec, + rate_limit_rejections: IntCounterVec, } impl Metrics { @@ -110,6 +111,15 @@ impl Metrics { ) .context("db_pool_acquire_duration metric")?; + let rate_limit_rejections = IntCounterVec::new( + prometheus::Opts::new( + "rate_limit_rejections_total", + "Requests rejected by the rate limiter, by route", + ), + &["route"], + ) + .context("rate_limit_rejections metric")?; + registry.register(Box::new(cache_hits.clone()))?; registry.register(Box::new(cache_misses.clone()))?; registry.register(Box::new(invalidations.clone()))?; @@ -121,6 +131,7 @@ impl Metrics { registry.register(Box::new(db_pool_connections_active.clone()))?; registry.register(Box::new(db_pool_connections_idle.clone()))?; registry.register(Box::new(db_pool_acquire_duration.clone()))?; + registry.register(Box::new(rate_limit_rejections.clone()))?; Ok(Self { registry, @@ -135,6 +146,7 @@ impl Metrics { db_pool_connections_active, db_pool_connections_idle, db_pool_acquire_duration, + rate_limit_rejections, }) } @@ -204,6 +216,14 @@ impl Metrics { .observe(duration.as_secs_f64()); } + /// Increment the rate-limit rejection counter for a route. + /// Call this whenever a request is rejected with 429 Too Many Requests. + pub fn observe_rate_limit_rejection(&self, route: &str) { + self.rate_limit_rejections + .with_label_values(&[route]) + .inc(); + } + pub fn render(&self) -> anyhow::Result { let mut buffer = vec![]; let encoder = TextEncoder::new(); diff --git a/services/api/src/rate_limit.rs b/services/api/src/rate_limit.rs index 4faf5ad..6a33277 100644 --- a/services/api/src/rate_limit.rs +++ b/services/api/src/rate_limit.rs @@ -43,8 +43,11 @@ impl Default for RateLimitConfig { #[derive(Clone)] pub struct RateLimitState { - pub redis: Arc, - pub config: RateLimitConfig, + pub redis: Arc, + pub config: RateLimitConfig, + /// Optional metrics sink. When present, rejections are counted under + /// the `rate_limit_rejections_total` Prometheus counter. + pub metrics: Option, } #[derive(Serialize)] @@ -122,6 +125,15 @@ pub async fn rate_limit_middleware( match check_rate_limit(&state.redis, &state.config, &client_key).await { Ok(_count) => next.run(req).await, Err(retry_after) => { + if let Some(m) = &state.metrics { + m.observe_rate_limit_rejection(&state.config.key_prefix); + } + tracing::warn!( + client_key = %client_key, + route = %state.config.key_prefix, + retry_after, + "rate limit exceeded" + ); let body = RateLimitError { error: "rate_limit_exceeded", message: format!( @@ -173,4 +185,15 @@ mod tests { assert_eq!(cfg.window_seconds, 60); assert!(!cfg.key_prefix.is_empty()); } + + #[test] + fn rate_limit_state_metrics_field_is_optional() { + let state = RateLimitState { + redis: std::sync::Arc::new(deadpool_redis::Config::from_url("redis://127.0.0.1") + .create_pool(Some(deadpool_redis::Runtime::Tokio1)).unwrap()), + config: RateLimitConfig::default(), + metrics: None, + }; + assert!(state.metrics.is_none()); + } } diff --git a/services/api/src/security.rs b/services/api/src/security.rs index c98066b..c2f6fc8 100644 --- a/services/api/src/security.rs +++ b/services/api/src/security.rs @@ -480,7 +480,13 @@ pub async fn sendgrid_webhook_middleware( .and_then(|h| h.to_str().ok()) .unwrap_or(""); - // Replay protection: reject stale timestamps (> config.replay_window_secs) + // Replay protection: reject stale AND future-dated timestamps. + // + // The previous check used .abs() which accepted future-dated timestamps + // within the replay window. An attacker could pre-sign a request with + // timestamp = now + window - 1 and replay it for up to 2 * window seconds. + // The fix rejects any timestamp that is in the future at all, and any that + // is more than replay_window_secs old. let ts_str = headers .get("x-twilio-email-event-webhook-timestamp") .and_then(|h| h.to_str().ok()) @@ -490,7 +496,14 @@ pub async fn sendgrid_webhook_middleware( .duration_since(std::time::UNIX_EPOCH) .map(|d| d.as_secs() as i64) .unwrap_or(0); - if (now - ts).abs() > config.replay_window_secs as i64 { + let age_secs = now - ts; + if age_secs < 0 || age_secs > config.replay_window_secs as i64 { + tracing::warn!( + ts, + now, + age_secs, + "sendgrid webhook rejected: timestamp out of bounds" + ); return Err(StatusCode::UNAUTHORIZED); } @@ -1214,6 +1227,45 @@ mod tests { grace_period )); } + + // ── webhook timestamp validation tests ─────────────────────────────────── + + fn check_age(now: i64, ts: i64, window: u64) -> bool { + let age_secs = now - ts; + !(age_secs < 0 || age_secs > window as i64) + } + + #[test] + fn webhook_timestamp_within_window_accepted() { + let now = 1_700_000_000i64; + assert!(check_age(now, now - 100, 300)); + } + + #[test] + fn webhook_timestamp_exactly_at_window_edge_accepted() { + let now = 1_700_000_000i64; + assert!(check_age(now, now - 300, 300)); + } + + #[test] + fn webhook_timestamp_beyond_window_rejected() { + let now = 1_700_000_000i64; + assert!(!check_age(now, now - 301, 300)); + } + + #[test] + fn webhook_future_timestamp_rejected() { + let now = 1_700_000_000i64; + assert!(!check_age(now, now + 1, 300)); + } + + #[test] + fn webhook_future_timestamp_within_old_window_rejected() { + // Under the old .abs() logic, now + 299 would have been accepted + // because abs(now - (now+299)) = 299 < 300. The new logic rejects it. + let now = 1_700_000_000i64; + assert!(!check_age(now, now + 299, 300)); + } } // ── Password hashing (Argon2id) ───────────────────────────────────────────────