diff --git a/backend/src/lib/websocket-relay-queries.js b/backend/src/lib/websocket-relay-queries.js new file mode 100644 index 00000000..87fdb5a3 --- /dev/null +++ b/backend/src/lib/websocket-relay-queries.js @@ -0,0 +1,138 @@ +/** + * websocket-relay-queries.js + * + * Optimized SQL query helpers for the WebSocket relay event store. + * + * Index recommendations (run once against your database): + * + * -- Covering index for paginated relay event reads, ordered by creation time + * CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_relay_events_created_at + * ON relay_events (created_at DESC); + * + * -- Composite index for queue dequeue queries (status + created_at) + * CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_relay_events_status_created + * ON relay_events (status, created_at ASC) + * WHERE status = 'pending'; + * + * -- Index for payment-scoped event lookups + * CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_relay_events_payment_id + * ON relay_events (payment_id, created_at DESC); + * + * These partial/covering indexes keep relay reads fast even when the table + * grows into the millions of rows. + */ + +/** + * Fetch the most recent relay events in descending creation order. + * + * Uses a paginated query so callers never pull unbounded result sets. + * The covering index on (created_at DESC) ensures an index-only scan + * on databases that support it. + * + * @param {object} db - pg Pool or pg Client with a .query() method + * @param {number} limit - Maximum rows to return (default 100, max 1000) + * @param {number} offset - Row offset for pagination (default 0) + * @returns {Promise} + */ +async function getRecentEvents(db, limit = 100, offset = 0) { + const parsedLimit = Number(limit); + const safeLimit = Math.min(Math.max(1, Number.isFinite(parsedLimit) ? parsedLimit : 100), 1000); + const parsedOffset = Number(offset); + const safeOffset = Math.max(0, Number.isFinite(parsedOffset) ? parsedOffset : 0); + + // SET statement_timeout guards against runaway read queries + const { rows } = await db.query( + ` + SET LOCAL statement_timeout = '5s'; + SELECT + id, + payment_id, + event_type, + payload, + status, + created_at + FROM relay_events + ORDER BY created_at DESC + LIMIT $1 + OFFSET $2 + `, + [safeLimit, safeOffset], + ); + + return rows; +} + +/** + * Batch-insert multiple relay events in a single round-trip. + * + * Building one multi-row VALUES clause is significantly faster than + * issuing N sequential INSERT statements, especially when N > 10, + * because it eliminates per-statement network and parse overhead. + * + * @param {object} db - pg Pool or pg Client + * @param {object[]} events - Array of event objects with: + * { payment_id, event_type, payload, status } + * @returns {Promise} Inserted rows + */ +async function batchInsertEvents(db, events) { + if (!Array.isArray(events) || events.length === 0) { + return []; + } + + const values = []; + const placeholders = events.map((event, i) => { + const base = i * 4; + values.push( + event.payment_id, + event.event_type, + event.payload != null ? JSON.stringify(event.payload) : null, + event.status || "pending", + ); + return `($${base + 1}, $${base + 2}, $${base + 3}, $${base + 4}, NOW())`; + }); + + const sql = ` + SET LOCAL statement_timeout = '5s'; + INSERT INTO relay_events (payment_id, event_type, payload, status, created_at) + VALUES ${placeholders.join(", ")} + RETURNING id, payment_id, event_type, status, created_at + `; + + const { rows } = await db.query(sql, values); + return rows; +} + +/** + * Dequeue the next pending relay event using SELECT … FOR UPDATE SKIP LOCKED. + * + * SKIP LOCKED lets multiple consumers pull from the queue concurrently + * without waiting on row locks held by other workers, which eliminates + * the "thundering herd" lock contention common in naive queue designs. + * The SET LOCAL statement_timeout prevents a worker from holding a lock + * indefinitely when the database is slow. + * + * @param {object} db - pg Pool or pg Client + * @returns {Promise} The dequeued event row, or null if the queue is empty + */ +async function dequeueNextEvent(db) { + const { rows } = await db.query(` + SET LOCAL statement_timeout = '5s'; + WITH next_event AS ( + SELECT id + FROM relay_events + WHERE status = 'pending' + ORDER BY created_at ASC + LIMIT 1 + FOR UPDATE SKIP LOCKED + ) + UPDATE relay_events + SET status = 'processing' + FROM next_event + WHERE relay_events.id = next_event.id + RETURNING relay_events.id, relay_events.payment_id, relay_events.event_type, relay_events.payload, relay_events.created_at + `); + + return rows[0] || null; +} + +export { getRecentEvents, batchInsertEvents, dequeueNextEvent }; diff --git a/backend/src/lib/websocket-relay-queries.test.js b/backend/src/lib/websocket-relay-queries.test.js new file mode 100644 index 00000000..379ece57 --- /dev/null +++ b/backend/src/lib/websocket-relay-queries.test.js @@ -0,0 +1,170 @@ +import { describe, it, expect, vi, beforeEach } from "vitest"; +import { + getRecentEvents, + batchInsertEvents, + dequeueNextEvent, +} from "./websocket-relay-queries.js"; + +function makeDb(rows = []) { + return { + query: vi.fn().mockResolvedValue({ rows }), + }; +} + +describe("getRecentEvents", () => { + let db; + + beforeEach(() => { + db = makeDb([ + { id: "1", payment_id: "p1", event_type: "payment.confirmed", status: "pending", created_at: new Date() }, + ]); + }); + + it("calls db.query with correct limit and offset", async () => { + await getRecentEvents(db, 10, 20); + expect(db.query).toHaveBeenCalledTimes(1); + const [sql, params] = db.query.mock.calls[0]; + expect(params).toEqual([10, 20]); + expect(sql).toContain("ORDER BY created_at DESC"); + expect(sql).toContain("LIMIT $1"); + expect(sql).toContain("OFFSET $2"); + }); + + it("uses default limit 100 and offset 0 when not provided", async () => { + await getRecentEvents(db); + const [, params] = db.query.mock.calls[0]; + expect(params).toEqual([100, 0]); + }); + + it("clamps limit to 1000 max", async () => { + await getRecentEvents(db, 99999, 0); + const [, params] = db.query.mock.calls[0]; + expect(params[0]).toBe(1000); + }); + + it("clamps limit to minimum of 1", async () => { + await getRecentEvents(db, 0, 0); + const [, params] = db.query.mock.calls[0]; + expect(params[0]).toBe(1); + }); + + it("returns rows from the query result", async () => { + const result = await getRecentEvents(db, 5, 0); + expect(Array.isArray(result)).toBe(true); + expect(result).toHaveLength(1); + expect(result[0].payment_id).toBe("p1"); + }); + + it("includes SET LOCAL statement_timeout in the query", async () => { + await getRecentEvents(db, 10, 0); + const [sql] = db.query.mock.calls[0]; + expect(sql).toContain("statement_timeout"); + }); +}); + +describe("batchInsertEvents", () => { + let db; + + beforeEach(() => { + db = makeDb([ + { id: "a", payment_id: "p1", event_type: "relay.send", status: "pending", created_at: new Date() }, + { id: "b", payment_id: "p2", event_type: "relay.send", status: "pending", created_at: new Date() }, + ]); + }); + + it("returns empty array when events array is empty", async () => { + const result = await batchInsertEvents(db, []); + expect(result).toEqual([]); + expect(db.query).not.toHaveBeenCalled(); + }); + + it("returns empty array when events is not an array", async () => { + const result = await batchInsertEvents(db, null); + expect(result).toEqual([]); + expect(db.query).not.toHaveBeenCalled(); + }); + + it("builds a single INSERT with multiple VALUE rows", async () => { + const events = [ + { payment_id: "p1", event_type: "relay.send", payload: { foo: 1 }, status: "pending" }, + { payment_id: "p2", event_type: "relay.ack", payload: null, status: "pending" }, + ]; + await batchInsertEvents(db, events); + expect(db.query).toHaveBeenCalledTimes(1); + const [sql, params] = db.query.mock.calls[0]; + expect(sql).toContain("INSERT INTO relay_events"); + // 2 events × 4 params each = 8 total + expect(params).toHaveLength(8); + expect(params[0]).toBe("p1"); + expect(params[1]).toBe("relay.send"); + expect(params[4]).toBe("p2"); + }); + + it("serializes payload objects to JSON strings", async () => { + const events = [ + { payment_id: "p1", event_type: "relay.send", payload: { amount: 100 }, status: "pending" }, + ]; + await batchInsertEvents(db, events); + const [, params] = db.query.mock.calls[0]; + expect(params[2]).toBe(JSON.stringify({ amount: 100 })); + }); + + it("defaults status to 'pending' if not provided", async () => { + const events = [{ payment_id: "p1", event_type: "relay.send", payload: null }]; + await batchInsertEvents(db, events); + const [, params] = db.query.mock.calls[0]; + expect(params[3]).toBe("pending"); + }); + + it("returns rows from the query result", async () => { + const events = [ + { payment_id: "p1", event_type: "relay.send", payload: null, status: "pending" }, + { payment_id: "p2", event_type: "relay.ack", payload: null, status: "pending" }, + ]; + const result = await batchInsertEvents(db, events); + expect(result).toHaveLength(2); + }); + + it("includes statement_timeout in the query", async () => { + const events = [{ payment_id: "p1", event_type: "e", payload: null, status: "pending" }]; + await batchInsertEvents(db, events); + const [sql] = db.query.mock.calls[0]; + expect(sql).toContain("statement_timeout"); + }); +}); + +describe("dequeueNextEvent", () => { + it("returns the first event row when queue has items", async () => { + const event = { id: "x", payment_id: "p1", event_type: "relay.send", payload: null, created_at: new Date() }; + const db = makeDb([event]); + const result = await dequeueNextEvent(db); + expect(result).toEqual(event); + }); + + it("returns null when the queue is empty", async () => { + const db = makeDb([]); + const result = await dequeueNextEvent(db); + expect(result).toBeNull(); + }); + + it("issues a SELECT FOR UPDATE SKIP LOCKED query", async () => { + const db = makeDb([]); + await dequeueNextEvent(db); + const [sql] = db.query.mock.calls[0]; + expect(sql).toContain("FOR UPDATE SKIP LOCKED"); + }); + + it("updates status to processing in the same query", async () => { + const db = makeDb([]); + await dequeueNextEvent(db); + const [sql] = db.query.mock.calls[0]; + expect(sql).toContain("status = 'processing'"); + }); + + it("includes statement_timeout guard", async () => { + const db = makeDb([]); + await dequeueNextEvent(db); + const [sql] = db.query.mock.calls[0]; + expect(sql).toContain("statement_timeout"); + }); +}); diff --git a/backend/src/lib/websocket-relay-recovery.js b/backend/src/lib/websocket-relay-recovery.js new file mode 100644 index 00000000..7a7248dd --- /dev/null +++ b/backend/src/lib/websocket-relay-recovery.js @@ -0,0 +1,249 @@ +/** + * websocket-relay-recovery.js + * + * Error-recovery primitives for the WebSocket relay: + * - Exponential backoff reconnection + * - Circuit breaker (CLOSED → OPEN → HALF_OPEN) + * - Dead letter queue (DLQ) for undeliverable messages + * - Health snapshot + */ + +// ─── Circuit breaker states ─────────────────────────────────────────────────── + +const CircuitState = Object.freeze({ + CLOSED: "CLOSED", + OPEN: "OPEN", + HALF_OPEN: "HALF_OPEN", +}); + +// ─── Exponential backoff reconnect ─────────────────────────────────────────── + +/** + * Attempt to (re)connect using an exponential backoff schedule. + * + * Delays are calculated as: + * delay = min(baseDelayMs * 2^attempt + jitter, maxDelayMs) + * + * A small random jitter (±10 % of the computed delay) is added so that + * multiple relay instances don't stampede the server at the same moment. + * + * @param {() => Promise} connectFn - Async function that establishes the connection. + * Resolves on success, throws on failure. + * @param {object} [opts] + * @param {number} [opts.maxRetries=5] - Maximum number of retries (0 = try once) + * @param {number} [opts.baseDelayMs=1000] - Initial delay in milliseconds + * @param {number} [opts.maxDelayMs=30000] - Maximum delay cap in milliseconds + * @param {Function}[opts.sleep] - Injectable sleep (defaults to setTimeout promise); useful for testing + * @returns {Promise} Resolves with the value returned by connectFn + * @throws {Error} Re-throws the last error when all retries are exhausted + */ +async function reconnectWithBackoff(connectFn, opts = {}) { + const { + maxRetries = 5, + baseDelayMs = 1000, + maxDelayMs = 30000, + sleep = (ms) => new Promise((resolve) => setTimeout(resolve, ms)), + } = opts; + + let lastError; + + for (let attempt = 0; attempt <= maxRetries; attempt++) { + try { + return await connectFn(); + } catch (err) { + lastError = err; + + if (attempt === maxRetries) { + break; + } + + const exponential = baseDelayMs * Math.pow(2, attempt); + const capped = Math.min(exponential, maxDelayMs); + // ±10 % jitter + const jitter = capped * 0.1 * (Math.random() * 2 - 1); + const delay = Math.max(0, Math.round(capped + jitter)); + + await sleep(delay); + } + } + + throw lastError; +} + +// ─── Circuit Breaker ───────────────────────────────────────────────────────── + +/** + * A three-state circuit breaker that protects downstream dependencies. + * + * CLOSED → normal operation; failures are counted. + * OPEN → fast-fail; calls are rejected immediately. + * HALF_OPEN → one probe call is allowed through to test recovery. + * + * @param {object} [opts] + * @param {number} [opts.failureThreshold=5] - Consecutive failures before the circuit opens + * @param {number} [opts.resetTimeoutMs=30000] - Time (ms) before trying HALF_OPEN + * @param {Function}[opts.now] - Injectable clock (defaults to Date.now); useful for testing + */ +class CircuitBreaker { + constructor(opts = {}) { + this._failureThreshold = opts.failureThreshold ?? 5; + this._resetTimeoutMs = opts.resetTimeoutMs ?? 30000; + this._now = opts.now ?? (() => Date.now()); + + this._state = CircuitState.CLOSED; + this._failureCount = 0; + this._openedAt = null; + this._lastError = null; + } + + get state() { + return this._state; + } + + /** + * Execute `fn` through the circuit breaker. + * + * @param {() => Promise} fn + * @returns {Promise} + * @throws {Error} When the circuit is OPEN, or when `fn` fails + */ + async call(fn) { + if (this._state === CircuitState.OPEN) { + // Check whether reset timeout has elapsed + if (this._now() - this._openedAt >= this._resetTimeoutMs) { + this._state = CircuitState.HALF_OPEN; + } else { + const err = new Error("Circuit breaker is OPEN — call rejected"); + err.circuitOpen = true; + throw err; + } + } + + try { + const result = await fn(); + this._onSuccess(); + return result; + } catch (err) { + this._onFailure(err); + throw err; + } + } + + _onSuccess() { + this._failureCount = 0; + this._lastError = null; + this._state = CircuitState.CLOSED; + } + + _onFailure(err) { + this._lastError = err; + this._failureCount += 1; + + if ( + this._state === CircuitState.HALF_OPEN || + this._failureCount >= this._failureThreshold + ) { + this._state = CircuitState.OPEN; + this._openedAt = this._now(); + } + } + + /** Return a plain-object snapshot suitable for health checks. */ + getStatus() { + return { + state: this._state, + failureCount: this._failureCount, + lastError: this._lastError ? this._lastError.message : null, + }; + } +} + +// ─── Dead Letter Queue ──────────────────────────────────────────────────────── + +/** + * In-memory dead letter queue for relay messages that could not be delivered. + * + * Each entry records the original message, the error that caused the failure, + * the retry count, and a timestamp for TTL-based cleanup. + */ +class DeadLetterQueue { + constructor() { + this._entries = []; + } + + /** + * Push a failed message onto the DLQ. + * + * @param {any} message - The original relay message + * @param {Error} error - The error that caused the failure + * @param {number}[retryCount=0] - How many delivery attempts have been made + */ + push(message, error, retryCount = 0) { + this._entries.push({ + message, + errorMessage: error instanceof Error ? error.message : String(error), + retryCount, + failedAt: new Date().toISOString(), + }); + } + + /** + * Return a shallow copy of all current DLQ entries. + * + * @returns {object[]} + */ + getAll() { + return [...this._entries]; + } + + /** + * Remove and return the oldest DLQ entry, or null if the queue is empty. + * + * @returns {object|null} + */ + shift() { + return this._entries.shift() || null; + } + + /** Number of entries currently in the queue. */ + get size() { + return this._entries.length; + } + + /** Remove all entries. */ + clear() { + this._entries = []; + } +} + +// ─── Relay Health ───────────────────────────────────────────────────────────── + +/** + * Aggregate health snapshot from the relay's circuit breaker and DLQ. + * + * @param {object} opts + * @param {CircuitBreaker} opts.circuitBreaker - The relay's CircuitBreaker instance + * @param {DeadLetterQueue} opts.dlq - The relay's DeadLetterQueue instance + * @param {number} [opts.reconnectCount=0] - How many reconnections have occurred + * @returns {{ status: string, reconnectCount: number, circuitState: string, dlqSize: number, lastError: string|null }} + */ +function getRelayHealth({ circuitBreaker, dlq, reconnectCount = 0 }) { + const cbStatus = circuitBreaker.getStatus(); + const isHealthy = cbStatus.state === CircuitState.CLOSED && dlq.size === 0; + + return { + status: isHealthy ? "healthy" : "degraded", + reconnectCount, + circuitState: cbStatus.state, + dlqSize: dlq.size, + lastError: cbStatus.lastError, + }; +} + +export { + CircuitState, + CircuitBreaker, + DeadLetterQueue, + getRelayHealth, + reconnectWithBackoff, +}; diff --git a/backend/src/lib/websocket-relay-recovery.test.js b/backend/src/lib/websocket-relay-recovery.test.js new file mode 100644 index 00000000..760a921c --- /dev/null +++ b/backend/src/lib/websocket-relay-recovery.test.js @@ -0,0 +1,278 @@ +import { describe, it, expect, vi, beforeEach } from "vitest"; +import { + CircuitBreaker, + CircuitState, + DeadLetterQueue, + getRelayHealth, + reconnectWithBackoff, +} from "./websocket-relay-recovery.js"; + +// ─── reconnectWithBackoff ────────────────────────────────────────────────────── + +describe("reconnectWithBackoff", () => { + it("resolves immediately when connectFn succeeds on first try", async () => { + const connectFn = vi.fn().mockResolvedValue("connected"); + const result = await reconnectWithBackoff(connectFn, { + maxRetries: 3, + sleep: vi.fn().mockResolvedValue(undefined), + }); + expect(result).toBe("connected"); + expect(connectFn).toHaveBeenCalledTimes(1); + }); + + it("retries after failure and succeeds on a later attempt", async () => { + const connectFn = vi + .fn() + .mockRejectedValueOnce(new Error("fail 1")) + .mockRejectedValueOnce(new Error("fail 2")) + .mockResolvedValue("connected"); + + const sleep = vi.fn().mockResolvedValue(undefined); + + const result = await reconnectWithBackoff(connectFn, { + maxRetries: 5, + baseDelayMs: 100, + maxDelayMs: 10000, + sleep, + }); + + expect(result).toBe("connected"); + expect(connectFn).toHaveBeenCalledTimes(3); + expect(sleep).toHaveBeenCalledTimes(2); + }); + + it("sleep delays increase across attempts (exponential backoff)", async () => { + const delays = []; + const sleep = vi.fn().mockImplementation((ms) => { + delays.push(ms); + return Promise.resolve(); + }); + + const connectFn = vi + .fn() + .mockRejectedValueOnce(new Error("f1")) + .mockRejectedValueOnce(new Error("f2")) + .mockRejectedValueOnce(new Error("f3")) + .mockResolvedValue("ok"); + + await reconnectWithBackoff(connectFn, { + maxRetries: 5, + baseDelayMs: 100, + maxDelayMs: 10000, + sleep, + }); + + // Each delay should be >= the previous (jitter may cause slight variance; + // we just assert delays[1] >= delays[0] on average with reasonable tolerance) + expect(delays).toHaveLength(3); + // Second delay should be at least roughly double the first (before jitter) + expect(delays[1]).toBeGreaterThanOrEqual(delays[0] * 0.8); + expect(delays[2]).toBeGreaterThanOrEqual(delays[1] * 0.8); + }); + + it("throws the last error when all retries are exhausted", async () => { + const error = new Error("permanent failure"); + const connectFn = vi.fn().mockRejectedValue(error); + const sleep = vi.fn().mockResolvedValue(undefined); + + await expect( + reconnectWithBackoff(connectFn, { maxRetries: 2, sleep }), + ).rejects.toThrow("permanent failure"); + + // Initial attempt + 2 retries = 3 total calls + expect(connectFn).toHaveBeenCalledTimes(3); + }); + + it("does not sleep after the final failed attempt", async () => { + const sleep = vi.fn().mockResolvedValue(undefined); + const connectFn = vi.fn().mockRejectedValue(new Error("fail")); + + await reconnectWithBackoff(connectFn, { maxRetries: 2, sleep }).catch(() => {}); + + // maxRetries=2 → 3 attempts → sleep called 2 times (not after the last) + expect(sleep).toHaveBeenCalledTimes(2); + }); +}); + +// ─── CircuitBreaker ─────────────────────────────────────────────────────────── + +describe("CircuitBreaker", () => { + let cb; + + beforeEach(() => { + cb = new CircuitBreaker({ failureThreshold: 3, resetTimeoutMs: 5000 }); + }); + + it("starts in CLOSED state", () => { + expect(cb.state).toBe(CircuitState.CLOSED); + }); + + it("stays CLOSED after fewer failures than the threshold", async () => { + const fn = vi.fn().mockRejectedValue(new Error("fail")); + await cb.call(fn).catch(() => {}); + await cb.call(fn).catch(() => {}); + expect(cb.state).toBe(CircuitState.CLOSED); + }); + + it("opens after reaching the failure threshold", async () => { + const fn = vi.fn().mockRejectedValue(new Error("fail")); + for (let i = 0; i < 3; i++) { + await cb.call(fn).catch(() => {}); + } + expect(cb.state).toBe(CircuitState.OPEN); + }); + + it("rejects calls immediately when OPEN without calling fn", async () => { + const fn = vi.fn().mockRejectedValue(new Error("fail")); + // Trip the breaker + for (let i = 0; i < 3; i++) { + await cb.call(fn).catch(() => {}); + } + fn.mockReset(); + + await expect(cb.call(fn)).rejects.toThrow("Circuit breaker is OPEN"); + expect(fn).not.toHaveBeenCalled(); + }); + + it("transitions to HALF_OPEN after reset timeout elapses", async () => { + let fakeNow = Date.now(); + const cb2 = new CircuitBreaker({ + failureThreshold: 1, + resetTimeoutMs: 1000, + now: () => fakeNow, + }); + + const fn = vi.fn().mockRejectedValue(new Error("fail")); + await cb2.call(fn).catch(() => {}); + expect(cb2.state).toBe(CircuitState.OPEN); + + // Advance fake clock past reset timeout + fakeNow += 2000; + + // Next call should probe (HALF_OPEN) + const successFn = vi.fn().mockResolvedValue("ok"); + await cb2.call(successFn); + expect(cb2.state).toBe(CircuitState.CLOSED); + }); + + it("closes after a successful call", async () => { + const failFn = vi.fn().mockRejectedValue(new Error("fail")); + const successFn = vi.fn().mockResolvedValue("ok"); + + await cb.call(failFn).catch(() => {}); + await cb.call(failFn).catch(() => {}); + await cb.call(successFn); + + expect(cb.state).toBe(CircuitState.CLOSED); + }); + + it("getStatus returns expected shape", () => { + const status = cb.getStatus(); + expect(status).toMatchObject({ + state: CircuitState.CLOSED, + failureCount: 0, + lastError: null, + }); + }); +}); + +// ─── DeadLetterQueue ────────────────────────────────────────────────────────── + +describe("DeadLetterQueue", () => { + let dlq; + + beforeEach(() => { + dlq = new DeadLetterQueue(); + }); + + it("starts empty", () => { + expect(dlq.size).toBe(0); + expect(dlq.getAll()).toEqual([]); + }); + + it("stores a failed message with error metadata", () => { + const msg = { type: "payment.confirmed", payload: {} }; + const err = new Error("delivery failed"); + dlq.push(msg, err, 1); + + expect(dlq.size).toBe(1); + const entry = dlq.getAll()[0]; + expect(entry.message).toEqual(msg); + expect(entry.errorMessage).toBe("delivery failed"); + expect(entry.retryCount).toBe(1); + expect(typeof entry.failedAt).toBe("string"); + }); + + it("accumulates multiple entries", () => { + dlq.push({ id: 1 }, new Error("e1"), 0); + dlq.push({ id: 2 }, new Error("e2"), 1); + expect(dlq.size).toBe(2); + }); + + it("shift removes and returns the oldest entry", () => { + dlq.push({ id: 1 }, new Error("e1")); + dlq.push({ id: 2 }, new Error("e2")); + const first = dlq.shift(); + expect(first.message).toEqual({ id: 1 }); + expect(dlq.size).toBe(1); + }); + + it("shift returns null when empty", () => { + expect(dlq.shift()).toBeNull(); + }); + + it("getAll returns a copy, not the internal array", () => { + dlq.push({ id: 1 }, new Error("e")); + const copy = dlq.getAll(); + copy.pop(); + expect(dlq.size).toBe(1); + }); + + it("clear empties the queue", () => { + dlq.push({ id: 1 }, new Error("e")); + dlq.clear(); + expect(dlq.size).toBe(0); + }); +}); + +// ─── getRelayHealth ─────────────────────────────────────────────────────────── + +describe("getRelayHealth", () => { + it("returns healthy when circuit is CLOSED and DLQ is empty", () => { + const cb = new CircuitBreaker(); + const dlq = new DeadLetterQueue(); + const health = getRelayHealth({ circuitBreaker: cb, dlq, reconnectCount: 0 }); + + expect(health).toMatchObject({ + status: "healthy", + reconnectCount: 0, + circuitState: CircuitState.CLOSED, + dlqSize: 0, + lastError: null, + }); + }); + + it("returns degraded when DLQ has entries", () => { + const cb = new CircuitBreaker(); + const dlq = new DeadLetterQueue(); + dlq.push({ id: 1 }, new Error("fail")); + + const health = getRelayHealth({ circuitBreaker: cb, dlq, reconnectCount: 2 }); + expect(health.status).toBe("degraded"); + expect(health.dlqSize).toBe(1); + expect(health.reconnectCount).toBe(2); + }); + + it("returns degraded when circuit is OPEN", async () => { + const cb = new CircuitBreaker({ failureThreshold: 1 }); + const dlq = new DeadLetterQueue(); + + const failFn = vi.fn().mockRejectedValue(new Error("down")); + await cb.call(failFn).catch(() => {}); + + const health = getRelayHealth({ circuitBreaker: cb, dlq, reconnectCount: 5 }); + expect(health.status).toBe("degraded"); + expect(health.circuitState).toBe(CircuitState.OPEN); + expect(health.lastError).toBe("down"); + }); +}); diff --git a/backend/src/lib/websocket-relay-security.js b/backend/src/lib/websocket-relay-security.js new file mode 100644 index 00000000..e0bd6406 --- /dev/null +++ b/backend/src/lib/websocket-relay-security.js @@ -0,0 +1,199 @@ +/** + * websocket-relay-security.js + * + * Security primitives for the WebSocket relay: + * - Origin validation (whitelist) + * - Message size enforcement + * - JWT verification for relay connections + * - Input sanitization (unknown-field stripping + shape validation) + * - Audit logging + */ + +import jwt from "jsonwebtoken"; + +// ─── Allowed message fields ─────────────────────────────────────────────────── + +/** + * Fields that are permitted in an inbound relay message. + * Any key not in this set is stripped by sanitizeRelayMessage(). + */ +const ALLOWED_MESSAGE_FIELDS = new Set([ + "type", + "payment_id", + "event_type", + "payload", + "timestamp", + "version", +]); + +// ─── Origin validation ──────────────────────────────────────────────────────── + +/** + * Validate that `origin` is present in `allowedOrigins`. + * + * Performs an exact, case-sensitive string match. Wildcards are intentionally + * not supported; each allowed origin must be listed explicitly to prevent + * subdomain-takeover bypasses. + * + * @param {string} origin - The "Origin" header value from the WebSocket handshake + * @param {string[]} allowedOrigins - Whitelist of permitted origins + * @returns {{ valid: boolean, reason?: string }} + */ +function validateOrigin(origin, allowedOrigins) { + if (typeof origin !== "string" || origin.trim() === "") { + return { valid: false, reason: "Origin header is missing or empty" }; + } + + if (!Array.isArray(allowedOrigins) || allowedOrigins.length === 0) { + return { valid: false, reason: "No allowed origins configured" }; + } + + if (allowedOrigins.includes(origin)) { + return { valid: true }; + } + + return { valid: false, reason: `Origin '${origin}' is not whitelisted` }; +} + +// ─── Message size limit ─────────────────────────────────────────────────────── + +/** + * Throw if the serialised byte-length of `msg` exceeds `maxBytes`. + * + * WebSocket frames can be arbitrarily large; enforcing a maximum prevents + * memory-exhaustion attacks from a single oversized message. + * + * @param {string|Buffer|object} msg - The raw WebSocket message + * @param {number} maxBytes - Maximum allowed size in bytes (default 64 KiB) + * @throws {Error} When the message exceeds the size limit + */ +function enforceMessageSizeLimit(msg, maxBytes = 65536) { + let byteLength; + + if (Buffer.isBuffer(msg)) { + byteLength = msg.length; + } else if (typeof msg === "string") { + byteLength = Buffer.byteLength(msg, "utf8"); + } else { + // Serialise objects so we measure the on-wire size + try { + byteLength = Buffer.byteLength(JSON.stringify(msg), "utf8"); + } catch { + throw new Error("Message cannot be serialised for size check"); + } + } + + if (byteLength > maxBytes) { + const err = new Error( + `Message size ${byteLength} bytes exceeds limit of ${maxBytes} bytes`, + ); + err.code = "MESSAGE_TOO_LARGE"; + err.byteLength = byteLength; + err.maxBytes = maxBytes; + throw err; + } +} + +// ─── JWT verification ───────────────────────────────────────────────────────── + +/** + * Verify a relay connection JWT. + * + * Wraps jsonwebtoken.verify() in a promise-friendly, error-normalising helper. + * The algorithm is fixed to HS256 to prevent algorithm-confusion attacks + * (e.g. 'none' or RS256-with-HMAC-public-key). + * + * @param {string} token - The raw JWT string from the WebSocket sub-protocol or query param + * @param {string} secret - HMAC secret used to sign the token + * @returns {{ valid: boolean, payload?: object, reason?: string }} + */ +function verifyRelayToken(token, secret) { + if (typeof token !== "string" || token.trim() === "") { + return { valid: false, reason: "Token is missing or empty" }; + } + + if (typeof secret !== "string" || secret.trim() === "") { + return { valid: false, reason: "Secret is missing or empty" }; + } + + try { + const payload = jwt.verify(token, secret, { algorithms: ["HS256"] }); + return { valid: true, payload }; + } catch (err) { + return { valid: false, reason: err.message }; + } +} + +// ─── Input sanitization ─────────────────────────────────────────────────────── + +/** + * Strip unknown fields and validate the shape of an inbound relay message. + * + * Returns a clean copy of the message containing only the fields listed in + * ALLOWED_MESSAGE_FIELDS, so that unexpected keys never propagate further + * into the relay pipeline. + * + * @param {any} msg - The parsed WebSocket message object + * @returns {{ sanitized: object, warnings: string[] }} + * @throws {Error} When `msg` is not a non-null object, or when required fields are missing + */ +function sanitizeRelayMessage(msg) { + if (msg === null || typeof msg !== "object" || Array.isArray(msg)) { + throw new Error("Relay message must be a non-null object"); + } + + const warnings = []; + const sanitized = {}; + + // Copy only allowed fields + for (const [key, value] of Object.entries(msg)) { + if (ALLOWED_MESSAGE_FIELDS.has(key)) { + sanitized[key] = value; + } else { + warnings.push(`Unknown field stripped: '${key}'`); + } + } + + // Require at minimum a 'type' field + if (typeof sanitized.type !== "string" || sanitized.type.trim() === "") { + throw new Error("Relay message must include a non-empty 'type' field"); + } + + return { sanitized, warnings }; +} + +// ─── Audit logging ──────────────────────────────────────────────────────────── + +/** + * Write a structured audit log entry for a relay security event. + * + * In production this would write to a dedicated audit sink (e.g. a DB table, + * a remote log aggregator, or a write-only append log). For now it emits a + * structured JSON line to stdout so that log shippers can ingest it. + * + * @param {string} event - Short event identifier, e.g. "connection.rejected" + * @param {object} [metadata] - Additional context to include in the log entry + * @param {object} [opts] + * @param {Function} [opts.emit] - Injectable emitter; defaults to console.log (for testing) + */ +function auditRelayEvent(event, metadata = {}, opts = {}) { + const emit = opts.emit || console.log; + + const entry = { + audit: true, + ts: new Date().toISOString(), + event, + ...metadata, + }; + + emit(JSON.stringify(entry)); +} + +export { + validateOrigin, + enforceMessageSizeLimit, + verifyRelayToken, + sanitizeRelayMessage, + auditRelayEvent, + ALLOWED_MESSAGE_FIELDS, +}; diff --git a/backend/src/lib/websocket-relay-security.test.js b/backend/src/lib/websocket-relay-security.test.js new file mode 100644 index 00000000..dc8857ec --- /dev/null +++ b/backend/src/lib/websocket-relay-security.test.js @@ -0,0 +1,209 @@ +import { describe, it, expect, vi } from "vitest"; +import jwt from "jsonwebtoken"; +import { + validateOrigin, + enforceMessageSizeLimit, + verifyRelayToken, + sanitizeRelayMessage, + auditRelayEvent, +} from "./websocket-relay-security.js"; + +// ─── validateOrigin ─────────────────────────────────────────────────────────── + +describe("validateOrigin", () => { + const allowed = ["https://app.example.com", "https://dashboard.example.com"]; + + it("accepts an origin that is in the whitelist", () => { + expect(validateOrigin("https://app.example.com", allowed)).toMatchObject({ + valid: true, + }); + }); + + it("rejects an origin not in the whitelist", () => { + const result = validateOrigin("https://evil.example.com", allowed); + expect(result.valid).toBe(false); + expect(result.reason).toContain("not whitelisted"); + }); + + it("rejects an empty origin string", () => { + expect(validateOrigin("", allowed).valid).toBe(false); + }); + + it("rejects a null/undefined origin", () => { + expect(validateOrigin(null, allowed).valid).toBe(false); + expect(validateOrigin(undefined, allowed).valid).toBe(false); + }); + + it("rejects when allowedOrigins is empty", () => { + expect(validateOrigin("https://app.example.com", []).valid).toBe(false); + }); + + it("is case-sensitive — does not accept wrong casing", () => { + expect(validateOrigin("HTTPS://APP.EXAMPLE.COM", allowed).valid).toBe(false); + }); + + it("rejects a subdomain not on the whitelist", () => { + expect(validateOrigin("https://sub.app.example.com", allowed).valid).toBe(false); + }); +}); + +// ─── enforceMessageSizeLimit ────────────────────────────────────────────────── + +describe("enforceMessageSizeLimit", () => { + it("does not throw for a message within the limit", () => { + expect(() => enforceMessageSizeLimit("hello", 100)).not.toThrow(); + }); + + it("throws when a string message exceeds the limit", () => { + const big = "x".repeat(10); + expect(() => enforceMessageSizeLimit(big, 5)).toThrow("exceeds limit"); + }); + + it("throws when a Buffer message exceeds the limit", () => { + const buf = Buffer.alloc(200); + expect(() => enforceMessageSizeLimit(buf, 100)).toThrow("exceeds limit"); + }); + + it("throws when a JSON-serialised object exceeds the limit", () => { + const obj = { data: "x".repeat(200) }; + expect(() => enforceMessageSizeLimit(obj, 10)).toThrow("exceeds limit"); + }); + + it("does not throw for exactly the limit size", () => { + const msg = "x".repeat(5); + expect(() => enforceMessageSizeLimit(msg, 5)).not.toThrow(); + }); + + it("attaches code and byteLength to the thrown error", () => { + try { + enforceMessageSizeLimit("hello world", 3); + } catch (err) { + expect(err.code).toBe("MESSAGE_TOO_LARGE"); + expect(err.byteLength).toBeGreaterThan(3); + expect(err.maxBytes).toBe(3); + } + }); + + it("uses 65536 as the default limit", () => { + const small = "small"; + expect(() => enforceMessageSizeLimit(small)).not.toThrow(); + }); +}); + +// ─── verifyRelayToken ───────────────────────────────────────────────────────── + +describe("verifyRelayToken", () => { + const secret = "test-secret-key"; + + it("returns valid: true with payload for a good token", () => { + const token = jwt.sign({ sub: "relay-client", role: "relay" }, secret, { + algorithm: "HS256", + expiresIn: "1h", + }); + const result = verifyRelayToken(token, secret); + expect(result.valid).toBe(true); + expect(result.payload.sub).toBe("relay-client"); + }); + + it("returns valid: false for a token signed with a wrong secret", () => { + const token = jwt.sign({ sub: "attacker" }, "wrong-secret", { algorithm: "HS256" }); + const result = verifyRelayToken(token, secret); + expect(result.valid).toBe(false); + expect(result.reason).toBeTruthy(); + }); + + it("returns valid: false for an expired token", () => { + const token = jwt.sign({ sub: "relay-client" }, secret, { + algorithm: "HS256", + expiresIn: -1, + }); + const result = verifyRelayToken(token, secret); + expect(result.valid).toBe(false); + expect(result.reason).toMatch(/expired/i); + }); + + it("returns valid: false for a malformed token string", () => { + const result = verifyRelayToken("not.a.real.jwt", secret); + expect(result.valid).toBe(false); + }); + + it("returns valid: false when token is empty", () => { + expect(verifyRelayToken("", secret).valid).toBe(false); + }); + + it("returns valid: false when secret is empty", () => { + const token = jwt.sign({ sub: "x" }, secret, { algorithm: "HS256" }); + expect(verifyRelayToken(token, "").valid).toBe(false); + }); +}); + +// ─── sanitizeRelayMessage ───────────────────────────────────────────────────── + +describe("sanitizeRelayMessage", () => { + it("returns the message unchanged when all fields are allowed", () => { + const msg = { type: "payment.confirmed", payment_id: "abc123", payload: {} }; + const { sanitized, warnings } = sanitizeRelayMessage(msg); + expect(sanitized).toMatchObject(msg); + expect(warnings).toHaveLength(0); + }); + + it("strips unknown fields", () => { + const msg = { + type: "payment.confirmed", + payment_id: "abc", + __proto__: "attack", + injected: "bad", + }; + const { sanitized, warnings } = sanitizeRelayMessage(msg); + expect(sanitized).not.toHaveProperty("injected"); + expect(warnings.some((w) => w.includes("injected"))).toBe(true); + }); + + it("throws when the message is not an object", () => { + expect(() => sanitizeRelayMessage("string")).toThrow(); + expect(() => sanitizeRelayMessage(42)).toThrow(); + expect(() => sanitizeRelayMessage(null)).toThrow(); + expect(() => sanitizeRelayMessage([])).toThrow(); + }); + + it("throws when the 'type' field is missing", () => { + expect(() => sanitizeRelayMessage({ payment_id: "abc" })).toThrow( + /type/i, + ); + }); + + it("throws when 'type' is an empty string", () => { + expect(() => sanitizeRelayMessage({ type: "" })).toThrow(); + }); + + it("includes warning messages for each stripped field", () => { + const msg = { type: "relay.send", unknownA: 1, unknownB: 2 }; + const { warnings } = sanitizeRelayMessage(msg); + expect(warnings).toHaveLength(2); + expect(warnings[0]).toContain("unknownA"); + expect(warnings[1]).toContain("unknownB"); + }); +}); + +// ─── auditRelayEvent ────────────────────────────────────────────────────────── + +describe("auditRelayEvent", () => { + it("calls the emit function with a JSON string", () => { + const emit = vi.fn(); + auditRelayEvent("connection.rejected", { origin: "https://evil.com" }, { emit }); + expect(emit).toHaveBeenCalledTimes(1); + const logLine = emit.mock.calls[0][0]; + const parsed = JSON.parse(logLine); + expect(parsed.audit).toBe(true); + expect(parsed.event).toBe("connection.rejected"); + expect(parsed.origin).toBe("https://evil.com"); + expect(typeof parsed.ts).toBe("string"); + }); + + it("works without metadata", () => { + const emit = vi.fn(); + auditRelayEvent("token.verified", undefined, { emit }); + const parsed = JSON.parse(emit.mock.calls[0][0]); + expect(parsed.event).toBe("token.verified"); + }); +}); diff --git a/backend/src/middleware/payment-rate-limiter.js b/backend/src/middleware/payment-rate-limiter.js new file mode 100644 index 00000000..31f3ea0b --- /dev/null +++ b/backend/src/middleware/payment-rate-limiter.js @@ -0,0 +1,94 @@ +/** + * payment-rate-limiter.js + * + * In-memory sliding-window rate limiter for the payment processor. + * + * Designed to be a zero-dependency fallback when Redis is not available. + * Each key (e.g. client IP) maintains a { count, windowStart } record. + * Stale entries from previous windows are evicted on every request so the + * Map does not grow unboundedly in long-running processes. + * + * Usage example: + * + * const { createPaymentRateLimiter } = require('./middleware/payment-rate-limiter') + * + * // Mount on the router (uncomment to activate): + * // router.use(createPaymentRateLimiter({ windowMs: 60000, maxRequests: 100 })) + * + * // Or on a specific route: + * // router.post('/create-payment', createPaymentRateLimiter({ windowMs: 60000, maxRequests: 10 }), handler) + */ + +/** + * Create an Express middleware that enforces a sliding-window rate limit. + * + * @param {object} [opts] + * @param {number} [opts.windowMs=60000] - Time window length in milliseconds (default: 1 minute) + * @param {number} [opts.maxRequests=100] - Maximum number of requests allowed per key per window + * @param {Function} [opts.keyFn] - Extract the rate-limit key from the request. + * Defaults to req.ip. Return null/undefined to skip limiting. + * @param {Function} [opts.now] - Injectable clock function returning a timestamp in ms; defaults to Date.now + * @returns {Function} Express middleware (req, res, next) + */ +function createPaymentRateLimiter(opts = {}) { + const { + windowMs = 60_000, + maxRequests = 100, + keyFn = (req) => req.ip, + now = Date.now, + } = opts; + + /** @type {Map} */ + const store = new Map(); + + return function paymentRateLimiterMiddleware(req, res, next) { + const key = keyFn(req); + + // If the key function returns nothing, skip limiting for this request + if (key == null) { + return next(); + } + + const currentTime = now(); + + // Evict all expired entries on every request to bound memory growth + for (const [k, entry] of store) { + if (currentTime - entry.windowStart >= windowMs) { + store.delete(k); + } + } + + const entry = store.get(key); + const windowStart = entry && currentTime - entry.windowStart < windowMs + ? entry.windowStart + : currentTime; + + const count = entry && currentTime - entry.windowStart < windowMs + ? entry.count + : 0; + + const newCount = count + 1; + store.set(key, { count: newCount, windowStart }); + + const remaining = Math.max(0, maxRequests - newCount); + const resetMs = windowStart + windowMs; + const resetSec = Math.ceil(resetMs / 1000); + + res.setHeader("X-RateLimit-Limit", String(maxRequests)); + res.setHeader("X-RateLimit-Remaining", String(remaining)); + res.setHeader("X-RateLimit-Reset", String(resetSec)); + + if (newCount > maxRequests) { + const retryAfterSec = Math.ceil((resetMs - currentTime) / 1000); + res.setHeader("Retry-After", String(Math.max(0, retryAfterSec))); + return res.status(429).json({ + error: "Too many requests", + retryAfter: Math.max(0, retryAfterSec), + }); + } + + return next(); + }; +} + +export { createPaymentRateLimiter }; diff --git a/backend/src/middleware/payment-rate-limiter.test.js b/backend/src/middleware/payment-rate-limiter.test.js new file mode 100644 index 00000000..f4db8867 --- /dev/null +++ b/backend/src/middleware/payment-rate-limiter.test.js @@ -0,0 +1,201 @@ +import { describe, it, expect, vi, beforeEach } from "vitest"; +import { createPaymentRateLimiter } from "./payment-rate-limiter.js"; + +/** + * Build a minimal mock Express request. + * + * @param {string} [ip="127.0.0.1"] + * @returns {object} + */ +function makeReq(ip = "127.0.0.1") { + return { ip }; +} + +/** + * Build a minimal mock Express response that captures status, headers, and json. + * + * @returns {{ status: Function, setHeader: Function, json: Function, _status: number|null, _headers: object, _body: any }} + */ +function makeRes() { + const res = { + _status: null, + _headers: {}, + _body: null, + }; + + res.status = vi.fn((code) => { + res._status = code; + return res; + }); + + res.setHeader = vi.fn((name, value) => { + res._headers[name] = value; + }); + + res.json = vi.fn((body) => { + res._body = body; + return res; + }); + + return res; +} + +describe("createPaymentRateLimiter", () => { + let fakeNow; + let now; + let middleware; + let next; + + beforeEach(() => { + fakeNow = Date.now(); + now = vi.fn(() => fakeNow); + next = vi.fn(); + middleware = createPaymentRateLimiter({ windowMs: 60_000, maxRequests: 3, now }); + }); + + // ── Basic pass-through ────────────────────────────────────────────────────── + + it("calls next() for requests within the limit", () => { + middleware(makeReq(), makeRes(), next); + expect(next).toHaveBeenCalledTimes(1); + }); + + it("allows exactly maxRequests requests before blocking", () => { + const ip = "10.0.0.1"; + for (let i = 0; i < 3; i++) { + middleware(makeReq(ip), makeRes(), next); + } + expect(next).toHaveBeenCalledTimes(3); + }); + + // ── 429 on limit exceeded ─────────────────────────────────────────────────── + + it("returns 429 on the request that exceeds maxRequests", () => { + const ip = "10.0.0.2"; + const res = makeRes(); + for (let i = 0; i < 3; i++) { + middleware(makeReq(ip), makeRes(), next); + } + // 4th request should be blocked + middleware(makeReq(ip), res, next); + + expect(res._status).toBe(429); + expect(res._body).toMatchObject({ error: "Too many requests" }); + expect(typeof res._body.retryAfter).toBe("number"); + // next should not be called for the blocked request + expect(next).toHaveBeenCalledTimes(3); + }); + + it("includes retryAfter > 0 in the 429 response body", () => { + const ip = "10.0.0.3"; + for (let i = 0; i < 3; i++) { + middleware(makeReq(ip), makeRes(), next); + } + const res = makeRes(); + middleware(makeReq(ip), res, next); + expect(res._body.retryAfter).toBeGreaterThan(0); + }); + + // ── Headers ───────────────────────────────────────────────────────────────── + + it("sets X-RateLimit-Limit header on every response", () => { + const res = makeRes(); + middleware(makeReq("10.0.0.4"), res, next); + expect(res._headers["X-RateLimit-Limit"]).toBe("3"); + }); + + it("decrements X-RateLimit-Remaining correctly", () => { + const ip = "10.0.0.5"; + const res1 = makeRes(); + middleware(makeReq(ip), res1, next); + expect(res1._headers["X-RateLimit-Remaining"]).toBe("2"); + + const res2 = makeRes(); + middleware(makeReq(ip), res2, next); + expect(res2._headers["X-RateLimit-Remaining"]).toBe("1"); + }); + + it("sets X-RateLimit-Reset header as a Unix timestamp in seconds", () => { + const res = makeRes(); + middleware(makeReq("10.0.0.6"), res, next); + const reset = Number(res._headers["X-RateLimit-Reset"]); + const expectedReset = Math.ceil((fakeNow + 60_000) / 1000); + expect(reset).toBe(expectedReset); + }); + + it("sets Retry-After header on 429 responses", () => { + const ip = "10.0.0.7"; + for (let i = 0; i < 3; i++) { + middleware(makeReq(ip), makeRes(), next); + } + const res = makeRes(); + middleware(makeReq(ip), res, next); + expect(res._headers["Retry-After"]).toBeDefined(); + expect(Number(res._headers["Retry-After"])).toBeGreaterThanOrEqual(0); + }); + + // ── Window reset ──────────────────────────────────────────────────────────── + + it("resets the counter after the window expires", () => { + const ip = "10.0.0.8"; + for (let i = 0; i < 3; i++) { + middleware(makeReq(ip), makeRes(), next); + } + + // Advance the fake clock past the window + fakeNow += 60_001; + now.mockReturnValue(fakeNow); + + const res = makeRes(); + middleware(makeReq(ip), res, next); + // Should be allowed again — counter reset + expect(res._status).toBeNull(); + expect(next).toHaveBeenCalledTimes(4); + }); + + // ── Independent key tracking ──────────────────────────────────────────────── + + it("tracks different IPs independently", () => { + const ipA = "192.168.1.1"; + const ipB = "192.168.1.2"; + + // Exhaust ipA's quota + for (let i = 0; i < 3; i++) { + middleware(makeReq(ipA), makeRes(), next); + } + + // ipB should still be allowed + const res = makeRes(); + middleware(makeReq(ipB), res, next); + expect(res._status).toBeNull(); + expect(next).toHaveBeenCalledTimes(4); // 3 for ipA + 1 for ipB + }); + + // ── keyFn returning null skips limiting ───────────────────────────────────── + + it("skips rate limiting when keyFn returns null", () => { + const noKeyMiddleware = createPaymentRateLimiter({ + windowMs: 60_000, + maxRequests: 1, + keyFn: () => null, + now, + }); + + for (let i = 0; i < 10; i++) { + noKeyMiddleware(makeReq(), makeRes(), next); + } + expect(next).toHaveBeenCalledTimes(10); + }); + + // ── X-RateLimit-Remaining never goes below 0 ─────────────────────────────── + + it("X-RateLimit-Remaining is 0 when limit is exceeded, not negative", () => { + const ip = "10.0.0.9"; + for (let i = 0; i < 3; i++) { + middleware(makeReq(ip), makeRes(), next); + } + const res = makeRes(); + middleware(makeReq(ip), res, next); + expect(Number(res._headers["X-RateLimit-Remaining"])).toBe(0); + }); +});