diff --git a/k8s/exceptionless/templates/api.yaml b/k8s/exceptionless/templates/api.yaml index eb6979280a..37016f2779 100644 --- a/k8s/exceptionless/templates/api.yaml +++ b/k8s/exceptionless/templates/api.yaml @@ -28,6 +28,11 @@ spec: annotations: checksum/config: {{ include (print $.Template.BasePath "/config.yaml") . | sha256sum }} spec: + # SSE connections are long-lived; give the pod enough time to drain before SIGTERM. + # The preStop sleep lets the ALB/ingress controller deregister the pod before traffic stops, + # then the remaining window allows ASP.NET Core to cancel RequestAborted tokens and clean up. + # When push is eventually enabled behind a Gateway API RoutePolicy, revisit this value. + terminationGracePeriodSeconds: 60 topologySpreadConstraints: - maxSkew: 1 topologyKey: kubernetes.io/hostname @@ -39,6 +44,13 @@ spec: - name: {{ template "exceptionless.name" . }}-api image: "{{ .Values.api.image.repository }}:{{ .Values.version }}" imagePullPolicy: {{ .Values.api.image.pullPolicy }} + lifecycle: + preStop: + # Give the ALB ~15s to deregister this pod before SIGTERM fires. + # The total graceful window is terminationGracePeriodSeconds (60s) minus + # this sleep, leaving ~45s for ASP.NET Core to drain active SSE connections. + exec: + command: ["sleep", "15"] livenessProbe: httpGet: path: /health @@ -82,7 +94,10 @@ spec: {{- include "exceptionless.otel-env" . | indent 12 }} - name: RunJobsInProcess value: 'false' - - name: EnableWebSockets + # SSE rollout prerequisite: Azure Application Gateway for Containers Ingress API + # does not support the routeTimeout=0s override required for long-lived SSE streams. + # Keep push disabled here until this route moves to Gateway API + RoutePolicy. + - name: EnablePush value: 'false' {{- if (empty .Values.storage.connectionString) }} volumeMounts: @@ -163,6 +178,8 @@ metadata: alb.networking.azure.io/alb-namespace: {{ .Values.ingress.albNamespace }} alb.networking.azure.io/alb-frontend: {{ template "exceptionless.fullname" . }}-fe cert-manager.io/cluster-issuer: {{ .Values.ingress.clusterIssuer }} + # SSE is not safe to enable behind the current AGC Ingress API path. + # Migrate to Gateway API and attach a RoutePolicy with routeTimeout: 0s before enabling push. spec: ingressClassName: azure-alb-external tls: diff --git a/src/Exceptionless.Core/Bootstrapper.cs b/src/Exceptionless.Core/Bootstrapper.cs index 751d8e0880..b73401f988 100644 --- a/src/Exceptionless.Core/Bootstrapper.cs +++ b/src/Exceptionless.Core/Bootstrapper.cs @@ -225,8 +225,8 @@ public static void LogConfiguration(IServiceProvider serviceProvider, AppOptions if (String.IsNullOrEmpty(appOptions.StorageOptions.Provider)) logger.LogWarning("Distributed storage is NOT enabled on {MachineName}", Environment.MachineName); - if (!appOptions.EnableWebSockets) - logger.LogWarning("Web Sockets is NOT enabled on {MachineName}", Environment.MachineName); + if (!appOptions.EnablePush) + logger.LogWarning("Real-time push (SSE) is NOT enabled on {MachineName}", Environment.MachineName); if (String.IsNullOrEmpty(appOptions.EmailOptions.SmtpHost)) logger.LogWarning("Emails will NOT be sent until the SmtpHost is configured on {MachineName}", Environment.MachineName); diff --git a/src/Exceptionless.Core/Configuration/AppOptions.cs b/src/Exceptionless.Core/Configuration/AppOptions.cs index 818b27951d..05dd0f7c1b 100644 --- a/src/Exceptionless.Core/Configuration/AppOptions.cs +++ b/src/Exceptionless.Core/Configuration/AppOptions.cs @@ -55,7 +55,11 @@ public class AppOptions public bool EnableRepositoryNotifications { get; internal set; } - public bool EnableWebSockets { get; internal set; } + /// + /// Controls whether real-time push (SSE) is enabled. Reads from either 'EnablePush' + /// or legacy 'EnableWebSockets' config key for backward compatibility. + /// + public bool EnablePush { get; internal set; } public string? Version { get; internal set; } @@ -111,7 +115,8 @@ public static AppOptions ReadFromConfiguration(IConfiguration config) options.BulkBatchSize = config.GetValue(nameof(options.BulkBatchSize), 1000); options.EnableRepositoryNotifications = config.GetValue(nameof(options.EnableRepositoryNotifications), true); - options.EnableWebSockets = config.GetValue(nameof(options.EnableWebSockets), true); + // Support both new 'EnablePush' and legacy 'EnableWebSockets' config keys + options.EnablePush = config.GetValue(nameof(options.EnablePush), config.GetValue("EnableWebSockets", true)); try { diff --git a/src/Exceptionless.Core/Utility/AppDiagnostics.cs b/src/Exceptionless.Core/Utility/AppDiagnostics.cs index c081925b6c..9545fe4c74 100644 --- a/src/Exceptionless.Core/Utility/AppDiagnostics.cs +++ b/src/Exceptionless.Core/Utility/AppDiagnostics.cs @@ -125,6 +125,11 @@ public GaugeInfo(Meter meter, string name) internal static readonly Counter SavedViewsSize = Meter.CreateCounter("ex.savedviews.size", description: "Size of user saved views"); internal static readonly Counter SavedViewsViewTypeSize = Meter.CreateCounter("ex.savedviews.viewtype.size", description: "Size of user saved views by view type"); + + internal static readonly Counter PushSseConnectionsOpened = Meter.CreateCounter("ex.push.connections.sse.opened", description: "SSE push connections opened"); + internal static readonly Counter PushSseConnectionsClosed = Meter.CreateCounter("ex.push.connections.sse.closed", description: "SSE push connections closed"); + internal static readonly Counter PushWebSocketConnectionsOpened = Meter.CreateCounter("ex.push.connections.websocket.opened", description: "WebSocket push connections opened"); + internal static readonly Counter PushWebSocketConnectionsClosed = Meter.CreateCounter("ex.push.connections.websocket.closed", description: "WebSocket push connections closed"); } public static class MetricsClientExtensions diff --git a/src/Exceptionless.Core/Utility/IConnectionMapping.cs b/src/Exceptionless.Core/Utility/IConnectionMapping.cs index e8a22dcdb9..7adabf2b62 100644 --- a/src/Exceptionless.Core/Utility/IConnectionMapping.cs +++ b/src/Exceptionless.Core/Utility/IConnectionMapping.cs @@ -104,6 +104,7 @@ public static class ConnectionMappingExtensions { public const string UserIdPrefix = "u-"; public const string GroupPrefix = "g-"; + public const string ConnectionGroupPrefix = "cg-"; public static Task GroupAddAsync(this IConnectionMapping map, string group, string connectionId) { @@ -125,6 +126,21 @@ public static Task GetGroupConnectionCountAsync(this IConnectionMapping map return map.GetConnectionCountAsync(GroupPrefix + group); } + public static Task ConnectionGroupAddAsync(this IConnectionMapping map, string connectionId, string group) + { + return map.AddAsync(ConnectionGroupPrefix + connectionId, group); + } + + public static Task ConnectionGroupRemoveAsync(this IConnectionMapping map, string connectionId, string group) + { + return map.RemoveAsync(ConnectionGroupPrefix + connectionId, group); + } + + public static Task> GetConnectionGroupsAsync(this IConnectionMapping map, string connectionId) + { + return map.GetConnectionsAsync(ConnectionGroupPrefix + connectionId); + } + public static Task UserIdAddAsync(this IConnectionMapping map, string userId, string connectionId) { return map.AddAsync(UserIdPrefix + userId, connectionId); diff --git a/src/Exceptionless.Web/Bootstrapper.cs b/src/Exceptionless.Web/Bootstrapper.cs index f83edd3615..07d368d522 100644 --- a/src/Exceptionless.Web/Bootstrapper.cs +++ b/src/Exceptionless.Web/Bootstrapper.cs @@ -14,6 +14,7 @@ public class Bootstrapper { public static void RegisterServices(IServiceCollection services, AppOptions appOptions, ILoggerFactory loggerFactory) { + services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); diff --git a/src/Exceptionless.Web/ClientApp.angular/components/websocket/websocket-service.js b/src/Exceptionless.Web/ClientApp.angular/components/websocket/websocket-service.js index 0cee86b9a6..77fc55f16c 100644 --- a/src/Exceptionless.Web/ClientApp.angular/components/websocket/websocket-service.js +++ b/src/Exceptionless.Web/ClientApp.angular/components/websocket/websocket-service.js @@ -1,6 +1,7 @@ (function () { "use strict"; + // Deprecated: keep the legacy Angular client on WebSocket during the SSE rollout. angular .module("exceptionless.websocket", ["app.config", "exceptionless", "exceptionless.auth"]) .factory("websocketService", function ($ExceptionlessClient, $rootScope, $timeout, authService, BASE_URL) { diff --git a/src/Exceptionless.Web/ClientApp/src/lib/features/auth/api.test.ts b/src/Exceptionless.Web/ClientApp/src/lib/features/auth/api.test.ts index 862ebb85e1..fb2ab82661 100644 --- a/src/Exceptionless.Web/ClientApp/src/lib/features/auth/api.test.ts +++ b/src/Exceptionless.Web/ClientApp/src/lib/features/auth/api.test.ts @@ -3,6 +3,17 @@ import { beforeEach, describe, expect, it, vi } from 'vitest'; import { logout } from './api.svelte'; +vi.mock('@exceptionless/browser', () => ({ + Exceptionless: { + config: { + setUserIdentity: vi.fn() + }, + submitFeatureUsage: vi.fn(), + submitSessionEnd: vi.fn().mockResolvedValue(undefined), + submitSessionStart: vi.fn().mockResolvedValue(undefined) + } +})); + describe('logout', () => { beforeEach(() => { // Mock localStorage for server-side tests diff --git a/src/Exceptionless.Web/ClientApp/src/lib/features/saved-views/use-saved-views.test.ts b/src/Exceptionless.Web/ClientApp/src/lib/features/saved-views/use-saved-views.test.ts index fd59c89f0d..05806c5699 100644 --- a/src/Exceptionless.Web/ClientApp/src/lib/features/saved-views/use-saved-views.test.ts +++ b/src/Exceptionless.Web/ClientApp/src/lib/features/saved-views/use-saved-views.test.ts @@ -7,6 +7,17 @@ import type { SavedView } from './models'; import { invalidateSavedViewQueries, queryKeys, removeSavedViewFromCaches, SAVED_VIEW_REFRESH_DELAY_MS, syncSavedViewCaches } from './api.svelte'; import { type SavedViewQueryParams, setSortQueryParam, setTimeQueryParam, supportsSortQueryParam, supportsTimeQueryParam } from './use-saved-views.svelte'; +vi.mock('@exceptionless/browser', () => ({ + Exceptionless: { + config: { + setUserIdentity: vi.fn() + }, + submitFeatureUsage: vi.fn(), + submitSessionEnd: vi.fn().mockResolvedValue(undefined), + submitSessionStart: vi.fn().mockResolvedValue(undefined) + } +})); + const TEST_ORG_ID = '507f1f77bcf86cd799439011'; const TEST_USER_ID = '66a1b2c3d4e5f6a7b8c9d0e1'; diff --git a/src/Exceptionless.Web/ClientApp/src/lib/features/websockets/sse-client.svelte.test.ts b/src/Exceptionless.Web/ClientApp/src/lib/features/websockets/sse-client.svelte.test.ts new file mode 100644 index 0000000000..c8b0958181 --- /dev/null +++ b/src/Exceptionless.Web/ClientApp/src/lib/features/websockets/sse-client.svelte.test.ts @@ -0,0 +1,434 @@ +// @vitest-environment jsdom + +import { render } from '@testing-library/svelte'; +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; + +import { SSE_CLOSED, SSE_CONNECTING, SSE_OPEN, SseClient, type SseClientOptions } from './sse-client.svelte'; +import SseClientTestHarness from './sse-client.test-harness.svelte'; + +// Mock the auth module +const mockAccessToken = vi.hoisted(() => ({ + current: 'test-token-123' as null | string +})); + +vi.mock('../auth/index.svelte', () => ({ + accessToken: mockAccessToken +})); + +// Creates a response whose stream stays open indefinitely (for testing open connections) +function createOpenSseResponse(initialEvents: string[] = []) { + return new Response( + new ReadableStream({ + start(controller) { + for (const event of initialEvents) { + controller.enqueue(new TextEncoder().encode(event)); + } + // intentionally never close + } + }), + { + headers: { 'Content-Type': 'text/event-stream' }, + status: 200 + } + ); +} + +// Helper to create a mock fetch response that streams SSE data +function createSseResponse(events: string[] = [], options: { delay?: number; status?: number } = {}) { + const { delay = 0, status = 200 } = options; + + return new Response( + new ReadableStream({ + async start(controller) { + for (const event of events) { + if (delay > 0) { + await new Promise((resolve) => setTimeout(resolve, delay)); + } + + controller.enqueue(new TextEncoder().encode(event)); + } + + controller.close(); + } + }), + { + headers: { 'Content-Type': 'text/event-stream' }, + status + } + ); +} + +function setDocumentHidden(hidden: boolean) { + Object.defineProperty(document, 'hidden', { + configurable: true, + get: () => hidden + }); + + document.dispatchEvent(new Event('visibilitychange')); +} + +describe('SseClient', () => { + let fetchMock: ReturnType>; + let activeUnmounts: Array<() => void> = []; + + beforeEach(() => { + fetchMock = vi.fn(); + global.fetch = fetchMock as typeof fetch; + mockAccessToken.current = 'test-token-123'; + setDocumentHidden(false); + }); + + afterEach(() => { + for (const unmount of activeUnmounts) { + unmount(); + } + + activeUnmounts = []; + Reflect.deleteProperty(document, 'hidden'); + vi.useRealTimers(); + vi.restoreAllMocks(); + }); + + function trackedClient(options?: SseClientOptions): SseClient { + const currentToken = mockAccessToken.current; + mockAccessToken.current = null; + + let client: SseClient | undefined; + const { unmount } = render(SseClientTestHarness, { + props: { + onClient: (value: SseClient) => { + client = value; + }, + options: { + baseUrl: 'http://localhost:5200', + reconnectDelay: () => 50, + ...options + } + } + }); + + mockAccessToken.current = currentToken; + activeUnmounts.push(unmount); + + if (!client) { + throw new Error('Expected test harness to provide an SseClient instance'); + } + + return client; + } + + describe('Connection Lifecycle', () => { + it('should connect successfully and call onOpen', async () => { + const onOpen = vi.fn(); + fetchMock.mockImplementation(() => Promise.resolve(createOpenSseResponse([': keepalive\n\n']))); + + const client = trackedClient(); + client.onOpen = onOpen; + client.connect(); + + await new Promise((resolve) => setTimeout(resolve, 50)); + + expect(fetchMock).toHaveBeenCalledWith( + 'http://localhost:5200/api/v2/push', + expect.objectContaining({ + headers: expect.objectContaining({ + Accept: 'text/event-stream', + Authorization: 'Bearer test-token-123' + }) + }) + ); + expect(onOpen).toHaveBeenCalledWith(false); + + client.close(); + }); + + it('should set readyState to CONNECTING then OPEN', async () => { + fetchMock.mockImplementation(() => Promise.resolve(createOpenSseResponse([': keepalive\n\n']))); + + const client = trackedClient(); + client.connect(); + + expect(client.readyState).toBe(SSE_CONNECTING); + + await new Promise((resolve) => setTimeout(resolve, 50)); + expect(client.readyState).toBe(SSE_OPEN); + + client.close(); + }); + + it('should call onConnecting with isReconnect=false on first connection', async () => { + const onConnecting = vi.fn(); + fetchMock.mockImplementation(() => Promise.resolve(createSseResponse([]))); + + const client = trackedClient(); + client.onConnecting = onConnecting; + client.connect(); + + expect(onConnecting).toHaveBeenCalledWith(false); + client.close(); + }); + }); + + describe('Disconnection', () => { + it('should close when close() is called', async () => { + // Create a response that never closes + fetchMock.mockResolvedValue( + new Response( + new ReadableStream({ + start() { + // Never close - simulate long-lived connection + } + }), + { headers: { 'Content-Type': 'text/event-stream' }, status: 200 } + ) + ); + + const client = trackedClient(); + client.connect(); + + await new Promise((resolve) => setTimeout(resolve, 50)); + const result = client.close(); + + expect(result).toBe(true); + expect(client.readyState).toBe(SSE_CLOSED); + }); + + it('should return false when closing already closed connection', () => { + const client = trackedClient(); + const result = client.close(); + + expect(result).toBe(false); + }); + + it('should NOT reconnect after manual close', async () => { + // Use a stream that stays open (never closes) so we can test manual close + fetchMock.mockResolvedValue( + new Response( + new ReadableStream({ + start() { + // intentionally never close - stream stays open + } + }), + { headers: { 'Content-Type': 'text/event-stream' }, status: 200 } + ) + ); + + const client = trackedClient(); + client.connect(); + + await new Promise((resolve) => setTimeout(resolve, 50)); + client.close(); + + await new Promise((resolve) => setTimeout(resolve, 100)); + expect(client.readyState).toBe(SSE_CLOSED); + // fetch should only be called once (no reconnect) + expect(fetchMock).toHaveBeenCalledTimes(1); + }); + + it('should allow reconnect after internal close', async () => { + fetchMock.mockImplementation(() => Promise.resolve(createOpenSseResponse([': keepalive\n\n']))); + + const client = trackedClient(); + client.connect(); + + await new Promise((resolve) => setTimeout(resolve, 50)); + expect(client.close(false)).toBe(true); + + client.connect(); + await new Promise((resolve) => setTimeout(resolve, 50)); + + expect(client.readyState).toBe(SSE_OPEN); + expect(fetchMock).toHaveBeenCalledTimes(2); + }); + }); + + describe('Auth Failure Handling', () => { + it('should NOT reconnect on 401', async () => { + fetchMock.mockImplementation(() => Promise.resolve(new Response(null, { status: 401 }))); + + const onClose = vi.fn(); + const client = trackedClient(); + client.onClose = onClose; + client.connect(); + + await new Promise((resolve) => setTimeout(resolve, 100)); + + expect(onClose).toHaveBeenCalledTimes(1); + expect(client.readyState).toBe(SSE_CLOSED); + expect(fetchMock).toHaveBeenCalledTimes(1); + }); + + it('should NOT reconnect on 403', async () => { + fetchMock.mockImplementation(() => Promise.resolve(new Response(null, { status: 403 }))); + + const client = trackedClient(); + client.connect(); + + await new Promise((resolve) => setTimeout(resolve, 100)); + + expect(client.readyState).toBe(SSE_CLOSED); + expect(fetchMock).toHaveBeenCalledTimes(1); + }); + + it('should retry slowly when push endpoint is unavailable', async () => { + vi.useFakeTimers(); + const onClose = vi.fn(); + fetchMock.mockImplementation(() => Promise.resolve(new Response(null, { status: 404 }))); + + const client = trackedClient(); + client.onClose = onClose; + client.connect(); + + await vi.advanceTimersByTimeAsync(100); + + expect(onClose).toHaveBeenCalledTimes(1); + expect(client.readyState).toBe(SSE_CONNECTING); + expect(fetchMock).toHaveBeenCalledTimes(1); + + await vi.advanceTimersByTimeAsync(59000); + expect(fetchMock).toHaveBeenCalledTimes(1); + + await vi.advanceTimersByTimeAsync(1000); + expect(fetchMock).toHaveBeenCalledTimes(2); + }); + }); + + describe('Reconnection Logic', () => { + it('should reconnect when stream ends normally', async () => { + let callCount = 0; + fetchMock.mockImplementation(() => { + callCount++; + return Promise.resolve(createSseResponse([': keepalive\n\n'])); + }); + + const client = trackedClient({ baseUrl: 'http://localhost:5200', reconnectDelay: () => 10 }); + client.connect(); + + // Wait for initial connection + stream end + reconnect + await new Promise((resolve) => setTimeout(resolve, 200)); + + expect(callCount).toBeGreaterThan(1); + client.close(); + }); + + it('should use custom reconnectDelay', async () => { + const reconnectDelay = vi.fn(() => 50); + fetchMock.mockImplementation(() => Promise.resolve(createSseResponse([]))); + + const client = trackedClient({ baseUrl: 'http://localhost:5200', reconnectDelay }); + client.connect(); + + await new Promise((resolve) => setTimeout(resolve, 150)); + + expect(reconnectDelay).toHaveBeenCalled(); + client.close(); + }); + + it('should reconnect when tab becomes visible after being hidden during a pending reconnect', async () => { + vi.useFakeTimers(); + fetchMock.mockImplementation(() => Promise.resolve(createSseResponse([': keepalive\n\n']))); + + const client = trackedClient({ baseUrl: 'http://localhost:5200', reconnectDelay: () => 1000 }); + client.connect(); + + await vi.advanceTimersByTimeAsync(100); + expect(fetchMock).toHaveBeenCalledTimes(1); + expect(client.readyState).toBe(SSE_CONNECTING); + + setDocumentHidden(true); + await Promise.resolve(); + expect(client.readyState).toBe(SSE_CLOSED); + + await vi.advanceTimersByTimeAsync(1000); + expect(fetchMock).toHaveBeenCalledTimes(1); + + setDocumentHidden(false); + await Promise.resolve(); + await Promise.resolve(); + + expect(fetchMock).toHaveBeenCalledTimes(2); + }); + }); + + describe('Message Handling', () => { + it('should parse SSE data messages and call onMessage', async () => { + const onMessage = vi.fn(); + const sseData = 'data: {"type":"StackChanged","message":{"id":"123"}}\n\n'; + fetchMock.mockImplementation(() => Promise.resolve(createOpenSseResponse([sseData]))); + + const client = trackedClient(); + client.onMessage = onMessage; + client.connect(); + + await new Promise((resolve) => setTimeout(resolve, 100)); + + expect(onMessage).toHaveBeenCalledWith( + expect.objectContaining({ + data: '{"type":"StackChanged","message":{"id":"123"}}' + }) + ); + client.close(); + }); + + it('should ignore keep-alive comments', async () => { + const onMessage = vi.fn(); + const sseData = ': keepalive\n\ndata: {"type":"test","message":{}}\n\n'; + fetchMock.mockImplementation(() => Promise.resolve(createOpenSseResponse([sseData]))); + + const client = trackedClient(); + client.onMessage = onMessage; + client.connect(); + + await new Promise((resolve) => setTimeout(resolve, 100)); + + // Should only get the data message, not the keepalive + expect(onMessage).toHaveBeenCalledTimes(1); + expect(onMessage).toHaveBeenCalledWith( + expect.objectContaining({ + data: '{"type":"test","message":{}}' + }) + ); + client.close(); + }); + + it('should handle multiple messages in one chunk', async () => { + const onMessage = vi.fn(); + const sseData = 'data: {"type":"A","message":{}}\n\ndata: {"type":"B","message":{}}\n\n'; + fetchMock.mockImplementation(() => Promise.resolve(createOpenSseResponse([sseData]))); + + const client = trackedClient(); + client.onMessage = onMessage; + client.connect(); + + await new Promise((resolve) => setTimeout(resolve, 100)); + + expect(onMessage).toHaveBeenCalledTimes(2); + client.close(); + }); + + it('should handle messages split across chunks', async () => { + const onMessage = vi.fn(); + fetchMock.mockImplementation(() => Promise.resolve(createOpenSseResponse(['data: {"type":"Sp', 'lit","message":{}}\n\n']))); + + const client = trackedClient(); + client.onMessage = onMessage; + client.connect(); + + await new Promise((resolve) => setTimeout(resolve, 100)); + + expect(onMessage).toHaveBeenCalledWith( + expect.objectContaining({ + data: '{"type":"Split","message":{}}' + }) + ); + client.close(); + }); + }); + + describe('URL Construction', () => { + it('should construct correct SSE URL with base URL', () => { + const client = trackedClient({ baseUrl: 'http://localhost:5200' }); + expect(client.url).toBe('http://localhost:5200/api/v2/push'); + }); + }); +}); diff --git a/src/Exceptionless.Web/ClientApp/src/lib/features/websockets/sse-client.svelte.ts b/src/Exceptionless.Web/ClientApp/src/lib/features/websockets/sse-client.svelte.ts new file mode 100644 index 0000000000..45af5f613d --- /dev/null +++ b/src/Exceptionless.Web/ClientApp/src/lib/features/websockets/sse-client.svelte.ts @@ -0,0 +1,326 @@ +import { DocumentVisibility } from '$shared/document-visibility.svelte'; + +import { accessToken } from '../auth/index.svelte'; + +export interface SseClientOptions { + /** + * Base URL for SSE connection (e.g., 'http://localhost:5200') + * If not provided, constructs from window.location + */ + baseUrl?: string; + /** + * Connection timeout in milliseconds + * Default: 10000ms (10 seconds) + */ + connectionTimeout?: number; + /** + * Custom reconnection delay calculator + * Default uses exponential backoff: 1s, 2s, 4s, 8s, 16s, max 30s + * For testing, can return 0 to reconnect immediately + */ + reconnectDelay?: (attempt: number) => number; +} + +// SSE connection state constants (same values as EventSource.*) +export const SSE_CONNECTING = 0; +export const SSE_OPEN = 1; +export const SSE_CLOSED = 2; +const ENDPOINT_UNAVAILABLE_RETRY_DELAY_MS = 60000; + +// EventSource does not support custom Authorization headers, so the app uses fetch + +// ReadableStream to keep bearer tokens out of the query string. +export class SseClient { + public readyState = $state(SSE_CLOSED); + + /** + * Lazy getter for SSE URL. + */ + public get url(): string { + if (this._url === null) { + if (this._options.baseUrl) { + this._url = `${this._options.baseUrl}${this._path}`; + } else { + const { host, protocol } = window.location; + this._url = `${protocol}//${host}${this._path}`; + } + } + + return this._url; + } + + private _options: SseClientOptions; + private _path: string; + private _url: null | string = null; + private abortController: AbortController | null = null; + private accessToken: null | string = null; + private authFailed: boolean = false; + private connectionTimeoutId: null | ReturnType = null; + private forcedClose: boolean = false; + private hasConnectedBefore: boolean = false; + private pausedForVisibility: boolean = false; + private reconnectAttempts: number = 0; + private reconnectTimeoutId: null | ReturnType = null; + + private streamGeneration: number = 0; + + /** + * @param path - SSE endpoint path (default: '/api/v2/push') + * @param options - Optional configuration + */ + constructor(path: string = '/api/v2/push', options: SseClientOptions = {}) { + this._path = path; + this._options = options; + + const visibility = new DocumentVisibility(); + + $effect(() => { + if (this.accessToken !== accessToken.current) { + this.accessToken = accessToken.current; + this.reconnectAttempts = 0; + this.authFailed = false; + this.pausedForVisibility = false; + this.close(false); + } else if (!visibility.visible) { + this.pausedForVisibility = true; + this.close(false); + } else { + this.pausedForVisibility = false; + } + + if ( + this.accessToken && + visibility.visible && + this.readyState === SSE_CLOSED && + this.reconnectTimeoutId === null && + !this.authFailed && + !this.forcedClose + ) { + this.connect(); + } + }); + } + + public close(isManual: boolean = true): boolean { + const hadPendingReconnect = this.reconnectTimeoutId !== null; + const hadPendingConnectionTimeout = this.connectionTimeoutId !== null; + const hadActiveStream = this.abortController !== null; + + clearTimeout(this.reconnectTimeoutId!); + this.reconnectTimeoutId = null; + clearTimeout(this.connectionTimeoutId!); + this.connectionTimeoutId = null; + this.forcedClose = isManual; + + if (this.abortController) { + this.streamGeneration++; + this.abortController.abort(); + this.abortController = null; + } + + this.readyState = SSE_CLOSED; + return hadPendingReconnect || hadPendingConnectionTimeout || hadActiveStream; + } + + public connect() { + const isReconnect: boolean = this.hasConnectedBefore; + const generation = ++this.streamGeneration; + + this.readyState = SSE_CONNECTING; + this.forcedClose = false; + + this.abortController = new AbortController(); + const { signal } = this.abortController; + + this.onConnecting(isReconnect); + + // Connection timeout + clearTimeout(this.connectionTimeoutId!); + const timeout = this._options.connectionTimeout ?? 10000; + this.connectionTimeoutId = setTimeout(() => { + this.connectionTimeoutId = null; + if (this.readyState === SSE_CONNECTING) { + console.warn(`[SseClient] Connection timeout after ${timeout}ms`); + this.abortController?.abort(); + } + }, timeout); + + this.startStream(signal, isReconnect, generation); + } + + public onClose: () => void = () => {}; + public onConnecting: (isReconnect: boolean) => void = () => {}; + public onError: (error: unknown) => void = () => {}; + public onMessage: (ev: MessageEvent) => void = () => {}; + public onOpen: (isReconnect: boolean) => void = () => {}; + + /** + * Calculate reconnection delay using exponential backoff + */ + private getReconnectDelay(attempt: number): number { + if (this._options.reconnectDelay) { + return this._options.reconnectDelay(attempt); + } + + // Default: exponential backoff 1s, 2s, 4s, 8s, 16s, max 30s + return Math.min(1000 * Math.pow(2, attempt - 1), 30000); + } + + private scheduleReconnect(delayOverrideMs?: number, incrementAttempts: boolean = true) { + if (this.reconnectTimeoutId !== null || this.authFailed || this.forcedClose || this.pausedForVisibility || !(this.accessToken ?? accessToken.current)) { + this.readyState = SSE_CLOSED; + return; + } + + if (incrementAttempts) { + this.reconnectAttempts++; + } + + const delay = delayOverrideMs ?? this.getReconnectDelay(this.reconnectAttempts); + + this.readyState = SSE_CONNECTING; + this.onConnecting(true); + this.onClose(); + + clearTimeout(this.reconnectTimeoutId!); + this.reconnectTimeoutId = setTimeout(() => { + this.reconnectTimeoutId = null; + this.connect(); + }, delay); + } + + private async startStream(signal: AbortSignal, isReconnect: boolean, generation: number) { + try { + const token = this.accessToken ?? accessToken.current; + const response = await fetch(this.url, { + headers: { + Accept: 'text/event-stream', + Authorization: `Bearer ${token}` + }, + signal + }); + + clearTimeout(this.connectionTimeoutId!); + this.connectionTimeoutId = null; + + if (!response.ok) { + // Auth failures - don't reconnect + if (response.status === 401 || response.status === 403) { + console.warn('[SseClient] Auth failure, not reconnecting', { status: response.status }); + this.authFailed = true; + this.readyState = SSE_CLOSED; + this.onClose(); + return; + } + + if (response.status === 404) { + console.info('[SseClient] Push endpoint unavailable, retrying later'); + this.scheduleReconnect(ENDPOINT_UNAVAILABLE_RETRY_DELAY_MS, false); + return; + } + + // Rate limited + if (response.status === 429) { + console.warn('[SseClient] Rate limited, will retry'); + this.scheduleReconnect(); + return; + } + + throw new Error(`SSE connection failed: ${response.status}`); + } + + if (!response.body) { + throw new Error('SSE response has no body'); + } + + if (generation !== this.streamGeneration) { + this.readyState = SSE_CLOSED; + return; + } + + this.readyState = SSE_OPEN; + this.reconnectAttempts = 0; + this.hasConnectedBefore = true; + this.onOpen(isReconnect); + + // Read the stream + const reader = response.body.getReader(); + const decoder = new TextDecoder(); + let buffer = ''; + + while (true) { + const { done, value } = await reader.read(); + + if (done) { + break; + } + + if (generation !== this.streamGeneration) { + this.readyState = SSE_CLOSED; + return; + } + + buffer += decoder.decode(value, { stream: true }); + + // Process complete SSE messages (separated by double newline) + const messages = buffer.split('\n\n'); + buffer = messages.pop() ?? ''; + + for (const message of messages) { + if (!message.trim()) { + continue; + } + + // Parse SSE format + const lines = message.split('\n'); + let data = ''; + + for (const line of lines) { + if (line.startsWith('data: ')) { + data += line.slice(6); + } else if (line.startsWith('data:')) { + data += line.slice(5); + } else if (line.startsWith(':')) { + // Comment (keep-alive), ignore + continue; + } + } + + if (data) { + // Create a MessageEvent-like object for compatibility + const event = new MessageEvent('message', { data }); + this.onMessage(event); + } + } + } + } catch (error: unknown) { + clearTimeout(this.connectionTimeoutId!); + this.connectionTimeoutId = null; + + if (generation !== this.streamGeneration) { + this.readyState = SSE_CLOSED; + return; + } + + if (signal.aborted && (this.forcedClose || this.pausedForVisibility)) { + // Intentional close - don't reconnect + this.readyState = SSE_CLOSED; + this.onClose(); + return; + } + + if (signal.aborted) { + // Timeout or other abort - try reconnect + this.scheduleReconnect(); + return; + } + + console.error('[SseClient] Stream error', error); + this.onError(error); + } + + // Stream ended (server closed connection) - reconnect + if (generation === this.streamGeneration && !this.forcedClose && !this.pausedForVisibility) { + this.scheduleReconnect(); + } + } +} diff --git a/src/Exceptionless.Web/ClientApp/src/lib/features/websockets/sse-client.test-harness.svelte b/src/Exceptionless.Web/ClientApp/src/lib/features/websockets/sse-client.test-harness.svelte new file mode 100644 index 0000000000..aeb2633c54 --- /dev/null +++ b/src/Exceptionless.Web/ClientApp/src/lib/features/websockets/sse-client.test-harness.svelte @@ -0,0 +1,19 @@ + diff --git a/src/Exceptionless.Web/ClientApp/src/lib/features/websockets/web-socket-client.svelte.ts b/src/Exceptionless.Web/ClientApp/src/lib/features/websockets/web-socket-client.svelte.ts deleted file mode 100644 index c816c5bcb5..0000000000 --- a/src/Exceptionless.Web/ClientApp/src/lib/features/websockets/web-socket-client.svelte.ts +++ /dev/null @@ -1,218 +0,0 @@ -import { DocumentVisibility } from '$shared/document-visibility.svelte'; - -import { accessToken } from '../auth/index.svelte'; - -export interface WebSocketClientOptions { - /** - * Base URL for WebSocket connection (e.g., 'ws://localhost:1234') - * If not provided, constructs from window.location - */ - baseUrl?: string; - /** - * Connection timeout in milliseconds - * Default: 10000ms (10 seconds) - */ - connectionTimeout?: number; - /** - * Custom reconnection delay calculator - * Default uses exponential backoff: 1s, 2s, 4s, 8s, 16s, max 30s - * For testing, can return 0 to reconnect immediately - */ - reconnectDelay?: (attempt: number) => number; -} - -export class WebSocketClient { - public readyState = $state(WebSocket.CLOSED); - - /** - * Lazy getter for WebSocket URL. - * Constructed on first access. Uses baseUrl from options if provided, otherwise constructs from window.location. - */ - public get url(): string { - if (this._url === null) { - if (this._options.baseUrl) { - this._url = `${this._options.baseUrl}${this._path}`; - } else { - const { host, protocol } = window.location; - const wsProtocol = protocol === 'https:' ? 'wss://' : 'ws://'; - this._url = `${wsProtocol}${host}${this._path}`; - } - } - - return this._url; - } - - private _options: WebSocketClientOptions; - private _path: string; - private _url: null | string = null; - private accessToken: null | string = null; - private connectionTimeoutId: null | ReturnType = null; - private forcedClose: boolean = false; - private hasConnectedBefore: boolean = false; - private reconnectAttempts: number = 0; - private reconnectTimeoutId: null | ReturnType = null; - - private ws: null | WebSocket = null; - - /** - * @param path - WebSocket path (default: '/api/v2/push') - * @param options - Optional configuration - */ - constructor(path: string = '/api/v2/push', options: WebSocketClientOptions = {}) { - this._path = path; - this._options = options; - - const visibility = new DocumentVisibility(); - - $effect(() => { - if (this.accessToken !== accessToken.current) { - this.accessToken = accessToken.current; - this.reconnectAttempts = 0; // Reset backoff on token change - this.close(); - } else if (!visibility.visible) { - this.close(); - } - - // Only auto-connect if we're fully closed and don't have a pending reconnect attempt - // Don't try to connect if we're CONNECTING, OPEN, or CLOSING - if (this.accessToken && visibility.visible && this.readyState === WebSocket.CLOSED && this.reconnectTimeoutId === null) { - this.connect(); - } - }); - } - - public close(): boolean { - clearTimeout(this.reconnectTimeoutId!); - this.reconnectTimeoutId = null; - clearTimeout(this.connectionTimeoutId!); - this.connectionTimeoutId = null; - - if (this.ws) { - this.forcedClose = true; - this.ws.close(); - return true; - } - - return false; - } - - public connect() { - // isReconnect means: have we successfully connected before? - const isReconnect: boolean = this.hasConnectedBefore; - - // Reset state - this.readyState = WebSocket.CONNECTING; - this.forcedClose = false; - - try { - this.ws = new WebSocket(`${this.url}?access_token=${this.accessToken}`); - this.onConnecting(isReconnect); - } catch (error) { - console.error('[WebSocketClient] Failed to create WebSocket', error); - throw error; - } - - // Connection timeout: if we don't connect within configured timeout, force close - clearTimeout(this.connectionTimeoutId!); - const timeout = this._options.connectionTimeout ?? 10000; - this.connectionTimeoutId = setTimeout(() => { - this.connectionTimeoutId = null; - if (this.ws && this.readyState === WebSocket.CONNECTING) { - console.warn(`[WebSocketClient] Connection timeout after ${timeout}ms`); - this.ws.close(); - } - }, timeout); - - this.ws.onopen = (event: Event) => { - clearTimeout(this.connectionTimeoutId!); - this.connectionTimeoutId = null; - this.readyState = WebSocket.OPEN; - this.reconnectAttempts = 0; // Reset backoff on successful connection - this.hasConnectedBefore = true; // Mark that we've connected successfully - this.onOpen(event, isReconnect); - }; - - this.ws.onclose = (event: CloseEvent) => { - clearTimeout(this.connectionTimeoutId!); - this.connectionTimeoutId = null; - this.ws = null; - - if (this.forcedClose) { - this.readyState = WebSocket.CLOSED; - this.onClose(event); - return; - } - - // Don't retry on authentication/authorization failures - // Code 1008 (Policy Violation) is explicit auth failure - // Code 1006 (Abnormal Closure) during handshake could be 401/403 - // Codes 4xxx are custom application codes (e.g., 4401=401, 4403=403) - const isAuthFailure = event.code === 1008 || (event.code === 1006 && event.wasClean === false) || (event.code >= 4400 && event.code < 4500); - if (isAuthFailure) { - console.warn('[WebSocketClient] Auth failure detected, not reconnecting', { - code: event.code, - reason: event.reason - }); - this.readyState = WebSocket.CLOSED; - this.onClose(event); - return; // Let the auth system handle redirect to login - } - - // Calculate reconnection delay with exponential backoff - this.reconnectAttempts++; - const delay = this.getReconnectDelay(this.reconnectAttempts); - - this.onConnecting(true); // Always true when reconnecting after close - this.onClose(event); - - // Schedule reconnect - clear any existing timeout first - clearTimeout(this.reconnectTimeoutId!); - this.reconnectTimeoutId = setTimeout(() => { - this.reconnectTimeoutId = null; - this.connect(); - }, delay); - }; - - this.ws.onmessage = (event) => { - this.onMessage(event); - }; - - this.ws.onerror = (event) => { - console.error('[WebSocketClient] onerror triggered', { - event, - readyState: this.readyState, - reconnectAttempts: this.reconnectAttempts - }); - this.onError(event); - }; - } - - public onClose: (ev: CloseEvent) => void = () => {}; - - public onConnecting: (isReconnect: boolean) => void = () => {}; - public onError: (ev: Event) => void = () => {}; - public onMessage: (ev: MessageEvent) => void = () => {}; - - public onOpen: (ev: Event, isReconnect: boolean) => void = () => {}; - - public send(data: Parameters[0]) { - if (this.ws) { - return this.ws.send(data); - } else { - throw new Error('INVALID_STATE_ERR : Pausing to reconnect websocket'); - } - } - - /** - * Calculate reconnection delay using exponential backoff - * Can be overridden via options for testing - */ - private getReconnectDelay(attempt: number): number { - if (this._options.reconnectDelay) { - return this._options.reconnectDelay(attempt); - } - - // Default: exponential backoff 1s, 2s, 4s, 8s, 16s, max 30s - return Math.min(1000 * Math.pow(2, attempt - 1), 30000); - } -} diff --git a/src/Exceptionless.Web/ClientApp/src/lib/features/websockets/web-socket-client.test.ts b/src/Exceptionless.Web/ClientApp/src/lib/features/websockets/web-socket-client.test.ts deleted file mode 100644 index c7138d1206..0000000000 --- a/src/Exceptionless.Web/ClientApp/src/lib/features/websockets/web-socket-client.test.ts +++ /dev/null @@ -1,494 +0,0 @@ -import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; -import WS from 'vitest-websocket-mock'; - -import { WebSocketClient, type WebSocketClientOptions } from './web-socket-client.svelte'; - -// Mock the auth module -vi.mock('../auth/index.svelte', () => ({ - accessToken: { - current: 'test-token-123' - } -})); - -// Mock DocumentVisibility to always return visible -vi.mock('$shared/document-visibility.svelte', () => { - return { - DocumentVisibility: class { - visible = true; - } - }; -}); - -let server: WS; - -beforeEach(() => { - server = new WS('ws://localhost:1234/api/v2/push'); -}); - -afterEach(() => { - WS.clean(); -}); - -function createClient(path?: string, options?: WebSocketClientOptions): WebSocketClient { - return new WebSocketClient(path, { - baseUrl: 'ws://localhost:1234', - reconnectDelay: () => 0, - ...options - }); -} - -describe('WebSocketClient', () => { - describe('Connection Lifecycle', () => { - it('should connect successfully', async () => { - const client = createClient(); - client.connect(); - await server.connected; - - expect(client.readyState).toBe(WebSocket.OPEN); - client.close(); - }); - - it('should set readyState to CONNECTING then OPEN', async () => { - const client = createClient(); - client.connect(); - - expect(client.readyState).toBe(WebSocket.CONNECTING); - await server.connected; - expect(client.readyState).toBe(WebSocket.OPEN); - - client.close(); - }); - - it('should call onConnecting with isReconnect=false on first connection', async () => { - const onConnecting = vi.fn(); - const client = createClient(); - client.onConnecting = onConnecting; - - client.connect(); - expect(onConnecting).toHaveBeenCalledWith(false); - await server.connected; - - client.close(); - }); - - it('should call onOpen with isReconnect=false on first connection', async () => { - const onOpen = vi.fn(); - const client = createClient(); - client.onOpen = onOpen; - - client.connect(); - await server.connected; - - expect(onOpen).toHaveBeenCalledWith(expect.anything(), false); - client.close(); - }); - - it('should handle multiple connect calls gracefully', async () => { - const client = createClient(); - - client.connect(); - client.connect(); - client.connect(); - - await server.connected; - expect(client.readyState).toBe(WebSocket.OPEN); - - client.close(); - }); - - it('should use custom connectionTimeout option', async () => { - const onConnecting = vi.fn(); - const client = new WebSocketClient('/api/v2/push', { - baseUrl: 'ws://localhost:9999', - connectionTimeout: 75, // Very short timeout - reconnectDelay: () => 1000 // Prevent immediate reconnect - }); - client.onConnecting = onConnecting; - - client.connect(); - expect(client.readyState).toBe(WebSocket.CONNECTING); - - // Wait for custom timeout to expire and close to be triggered - await new Promise((resolve) => setTimeout(resolve, 150)); - - // onConnecting was called with isReconnect=false for initial connect - expect(onConnecting).toHaveBeenCalledWith(false); - }); - }); - - describe('Disconnection', () => { - it('should close WebSocket when close() is called', async () => { - const client = createClient(); - client.connect(); - await server.connected; - - const result = client.close(); - await new Promise((resolve) => setTimeout(resolve, 10)); - - expect(result).toBe(true); - expect(client.readyState).toBe(WebSocket.CLOSED); - }); - - it('should return false when closing already closed connection', () => { - const client = createClient(); - client.close(); - const result = client.close(); - - expect(result).toBe(false); - }); - - it('should call onClose callback', async () => { - const onClose = vi.fn(); - const client = createClient(); - client.onClose = onClose; - - client.connect(); - await server.connected; - - server.close({ code: 1000, reason: 'Test', wasClean: true }); - await new Promise((resolve) => setTimeout(resolve, 10)); - - expect(onClose).toHaveBeenCalledWith( - expect.objectContaining({ - code: 1000, - reason: 'Test', - wasClean: true - }) - ); - }); - - it('should NOT reconnect after manual close', async () => { - const client = createClient(); - client.connect(); - await server.connected; - - client.close(); - await new Promise((resolve) => setTimeout(resolve, 50)); - - expect(client.readyState).toBe(WebSocket.CLOSED); - }); - }); - - describe('Reconnection Logic', () => { - it('should NOT reconnect on policy violation (code 1008) - auth failure', async () => { - const client = createClient(); - const onClose = vi.fn(); - client.onClose = onClose; - - client.connect(); - await server.connected; - - server.close({ code: 1008, reason: 'Policy Violation', wasClean: false }); - await new Promise((resolve) => setTimeout(resolve, 50)); - - expect(onClose).toHaveBeenCalledWith( - expect.objectContaining({ - code: 1008, - reason: 'Policy Violation', - wasClean: false - }) - ); - expect(client.readyState).toBe(WebSocket.CLOSED); - }); - - it('should NOT reconnect on abnormal closure (code 1006, wasClean=false) - connection lost unexpectedly', async () => { - const client = createClient(); - const onClose = vi.fn(); - client.onClose = onClose; - - client.connect(); - await server.connected; - - server.close({ code: 1006, reason: 'Abnormal Closure', wasClean: false }); - await new Promise((resolve) => setTimeout(resolve, 50)); - - expect(onClose).toHaveBeenCalledWith( - expect.objectContaining({ - code: 1006, - reason: 'Abnormal Closure', - wasClean: false - }) - ); - expect(client.readyState).toBe(WebSocket.CLOSED); - }); - - it('should NOT reconnect on unauthorized (code 4401) - 401 HTTP equivalent', async () => { - const client = createClient(); - const onClose = vi.fn(); - client.onClose = onClose; - - client.connect(); - await server.connected; - - server.close({ code: 4401, reason: 'Unauthorized', wasClean: false }); - await new Promise((resolve) => setTimeout(resolve, 50)); - - expect(onClose).toHaveBeenCalledWith( - expect.objectContaining({ - code: 4401, - reason: 'Unauthorized', - wasClean: false - }) - ); - expect(client.readyState).toBe(WebSocket.CLOSED); - }); - - it('should NOT reconnect on forbidden (code 4403) - 403 HTTP equivalent', async () => { - const client = createClient(); - const onClose = vi.fn(); - client.onClose = onClose; - - client.connect(); - await server.connected; - - server.close({ code: 4403, reason: 'Forbidden', wasClean: false }); - await new Promise((resolve) => setTimeout(resolve, 50)); - - expect(onClose).toHaveBeenCalledWith( - expect.objectContaining({ - code: 4403, - reason: 'Forbidden', - wasClean: false - }) - ); - expect(client.readyState).toBe(WebSocket.CLOSED); - }); - - it('should reconnect on normal closure (code 1000) - server initiated graceful close', async () => { - const onConnecting = vi.fn(); - const client = createClient(); - client.onConnecting = onConnecting; - - client.connect(); - await server.connected; - onConnecting.mockClear(); - - server.close({ code: 1000, reason: 'Normal Closure', wasClean: true }); - await new Promise((resolve) => setTimeout(resolve, 10)); - await server.connected; - - expect(onConnecting).toHaveBeenCalledWith(true); - client.close(); - }); - - it('should reconnect on going away (code 1001) - server restart', async () => { - const onConnecting = vi.fn(); - const client = createClient(); - client.onConnecting = onConnecting; - - client.connect(); - await server.connected; - onConnecting.mockClear(); - - server.close({ code: 1001, reason: 'Going Away', wasClean: true }); - await new Promise((resolve) => setTimeout(resolve, 10)); - await server.connected; - - expect(onConnecting).toHaveBeenCalledWith(true); - client.close(); - }); - - it('should call onConnecting with isReconnect=true on reconnection', async () => { - const onConnecting = vi.fn(); - const client = createClient(); - client.onConnecting = onConnecting; - - client.connect(); - await server.connected; - expect(onConnecting).toHaveBeenCalledWith(false); - onConnecting.mockClear(); - - server.close({ code: 1000, reason: 'Test', wasClean: true }); - await new Promise((resolve) => setTimeout(resolve, 10)); - - expect(onConnecting).toHaveBeenCalledWith(true); - await server.connected; - client.close(); - }); - }); - - describe('Message Handling', () => { - it('should send messages when connected', async () => { - const client = createClient(); - client.connect(); - await server.connected; - - client.send('test message'); - await expect(server).toReceiveMessage('test message'); - - client.close(); - }); - - it('should throw error when sending while disconnected', () => { - const client = createClient(); - - expect(() => client.send('test')).toThrow('INVALID_STATE_ERR'); - }); - - it('should call onMessage callback when receiving messages', async () => { - const onMessage = vi.fn(); - const client = createClient(); - client.onMessage = onMessage; - - client.connect(); - await server.connected; - - server.send('test data'); - await new Promise((resolve) => setTimeout(resolve, 10)); - - expect(onMessage).toHaveBeenCalledWith( - expect.objectContaining({ - data: 'test data' - }) - ); - - client.close(); - }); - - it('should receive JSON messages', async () => { - const onMessage = vi.fn(); - const client = createClient(); - client.onMessage = onMessage; - - client.connect(); - await server.connected; - - const message = JSON.stringify({ data: 'hello', type: 'test' }); - server.send(message); - await new Promise((resolve) => setTimeout(resolve, 10)); - - expect(onMessage).toHaveBeenCalledWith( - expect.objectContaining({ - data: message - }) - ); - - client.close(); - }); - }); - - describe('Error Handling', () => { - it('should call onError callback', async () => { - const onError = vi.fn(); - const client = createClient(); - client.onError = onError; - - client.connect(); - await server.connected; - - server.error(); - await new Promise((resolve) => setTimeout(resolve, 10)); - - expect(onError).toHaveBeenCalled(); - client.close(); - }); - }); - - describe('URL Construction', () => { - it('should construct correct WebSocket URL', () => { - const client = createClient('/api/v2/push'); - - expect(client.url).toBe('ws://localhost:1234/api/v2/push'); - }); - - it('should use custom base URL', async () => { - const customClient = new WebSocketClient('/api/v2/push', { - baseUrl: 'ws://custom-host:5000', - reconnectDelay: () => 0 - }); - - const customServer = new WS('ws://custom-host:5000/api/v2/push'); - customClient.connect(); - await customServer.connected; - - expect(customClient.readyState).toBe(WebSocket.OPEN); - - customClient.close(); - customServer.close(); - }); - - it('should handle custom path', async () => { - const client = createClient('/custom/path'); - const customServer = new WS('ws://localhost:1234/custom/path'); - - client.connect(); - await customServer.connected; - - expect(client.readyState).toBe(WebSocket.OPEN); - - client.close(); - customServer.close(); - }); - }); - - describe('Options - getReconnectDelay', () => { - it('should use custom getReconnectDelay from options', async () => { - const getReconnectDelay = vi.fn(() => 100); - const client = new WebSocketClient('/api/v2/push', { - baseUrl: 'ws://localhost:1234', - reconnectDelay: getReconnectDelay - }); - - client.connect(); - await server.connected; - - server.close({ code: 1000, reason: 'Test', wasClean: true }); - await new Promise((resolve) => setTimeout(resolve, 10)); - await server.connected; - - expect(getReconnectDelay).toHaveBeenCalled(); - client.close(); - }); - - it('should use immediate reconnection with getReconnectDelay: () => 0', async () => { - const onConnecting = vi.fn(); - const client = createClient(); - client.onConnecting = onConnecting; - - client.connect(); - await server.connected; - onConnecting.mockClear(); - - const start = Date.now(); - server.close({ code: 1000, reason: 'Test', wasClean: true }); - - // Wait for reconnection attempt - await new Promise((resolve) => setTimeout(resolve, 50)); - - // Verify reconnection happened quickly (within 50ms) - const elapsed = Date.now() - start; - expect(onConnecting).toHaveBeenCalledWith(true); - expect(elapsed).toBeLessThan(100); - - client.close(); - }); - }); - - describe('Edge Cases', () => { - it('should handle rapid connect/disconnect cycles', async () => { - const client = createClient(); - - client.connect(); - client.close(); - client.connect(); - await server.connected; - - expect(client.readyState).toBe(WebSocket.OPEN); - client.close(); - }); - - it('should maintain connection state correctly', async () => { - const client = createClient(); - - expect(client.readyState).toBe(WebSocket.CLOSED); - - client.connect(); - await server.connected; - expect(client.readyState).toBe(WebSocket.OPEN); - - client.close(); - await new Promise((resolve) => setTimeout(resolve, 10)); - expect(client.readyState).toBe(WebSocket.CLOSED); - }); - }); -}); diff --git a/src/Exceptionless.Web/ClientApp/src/routes/(app)/+layout.svelte b/src/Exceptionless.Web/ClientApp/src/routes/(app)/+layout.svelte index dd60f5caf5..98afcf8838 100644 --- a/src/Exceptionless.Web/ClientApp/src/routes/(app)/+layout.svelte +++ b/src/Exceptionless.Web/ClientApp/src/routes/(app)/+layout.svelte @@ -29,8 +29,8 @@ import { getMeQuery, invalidateUserQueries } from '$features/users/api.svelte'; import { getGravatarFromCurrentUser } from '$features/users/gravatar.svelte'; import { invalidateWebhookQueries } from '$features/webhooks/api.svelte'; - import { isEntityChangedType, type WebSocketMessageType } from '$features/websockets/models'; - import { WebSocketClient } from '$features/websockets/web-socket-client.svelte'; + import { type EntityChanged, isEntityChangedType, type UserMembershipChanged, type WebSocketMessageType } from '$features/websockets/models'; + import { SseClient } from '$features/websockets/sse-client.svelte'; import { Telemetry } from '$lib/telemetry'; import { useMiddleware } from '@exceptionless/fetchclient'; import { useQueryClient } from '@tanstack/svelte-query'; @@ -155,11 +155,29 @@ } } - // This event is fired when a user is added or removed from an organization. - // if (data.type === "UserMembershipChanged" && data.message?.organization_id) { - // $rootScope.$emit("OrganizationChanged", data.message); - // $rootScope.$emit("ProjectChanged", data.message); - // } + // When a user is added or removed from an organization, invalidate org/project caches + // so the UI reflects the membership change without a manual reload. + if (data.type === 'UserMembershipChanged') { + const membershipMessage = data.message as UserMembershipChanged; + if (membershipMessage.organization_id) { + const organizationChangedMessage: EntityChanged = { + change_type: membershipMessage.change_type, + data: {}, + id: membershipMessage.organization_id, + organization_id: membershipMessage.organization_id, + type: 'Organization' + }; + const projectChangedMessage: EntityChanged = { + change_type: membershipMessage.change_type, + data: {}, + organization_id: membershipMessage.organization_id, + type: 'Project' + }; + + await invalidateOrganizationQueries(queryClient, organizationChangedMessage); + await invalidateProjectQueries(queryClient, projectChangedMessage); + } + } } // Close Sidebar on page change on mobile @@ -187,7 +205,7 @@ } }); - // WebSocket + keyboard shortcuts — only depends on token, not navigation + // SSE + keyboard shortcuts — only depends on token, not navigation $effect(() => { const currentToken = accessToken.current; @@ -248,15 +266,15 @@ document.addEventListener('keydown', handleKeydown, { capture: true }); - const ws = new WebSocketClient(); - ws.onMessage = onMessage; - ws.onOpen = (_, isReconnect) => { + const sse = new SseClient(); + sse.onMessage = onMessage; + sse.onOpen = (isReconnect) => { if (isReconnect) { queryClient.invalidateQueries(); document.dispatchEvent( new CustomEvent('refresh', { bubbles: true, - detail: 'WebSocket Connected' + detail: 'SSE Connected' }) ); } @@ -264,7 +282,7 @@ return () => { document.removeEventListener('keydown', handleKeydown, { capture: true }); - ws?.close(); + sse?.close(); }; }); diff --git a/src/Exceptionless.Web/Hubs/MessageBusBroker.cs b/src/Exceptionless.Web/Hubs/MessageBusBroker.cs index fe059935db..27689cd05b 100644 --- a/src/Exceptionless.Web/Hubs/MessageBusBroker.cs +++ b/src/Exceptionless.Web/Hubs/MessageBusBroker.cs @@ -13,15 +13,17 @@ public sealed class MessageBusBroker : IStartupAction { private static readonly string TokenTypeName = nameof(Token); private static readonly string UserTypeName = nameof(User); - private readonly WebSocketConnectionManager _connectionManager; + private readonly SseConnectionManager _sseConnectionManager; + private readonly WebSocketConnectionManager _webSocketConnectionManager; private readonly IConnectionMapping _connectionMapping; private readonly IMessageSubscriber _subscriber; private readonly AppOptions _options; private readonly ILogger _logger; - public MessageBusBroker(WebSocketConnectionManager connectionManager, IConnectionMapping connectionMapping, IMessageSubscriber subscriber, AppOptions options, ILogger logger) + public MessageBusBroker(SseConnectionManager sseConnectionManager, WebSocketConnectionManager webSocketConnectionManager, IConnectionMapping connectionMapping, IMessageSubscriber subscriber, AppOptions options, ILogger logger) { - _connectionManager = connectionManager; + _sseConnectionManager = sseConnectionManager; + _webSocketConnectionManager = webSocketConnectionManager; _connectionMapping = connectionMapping; _subscriber = subscriber; _options = options; @@ -30,7 +32,7 @@ public MessageBusBroker(WebSocketConnectionManager connectionManager, IConnectio public async Task RunAsync(CancellationToken shutdownToken = default) { - if (!_options.EnableWebSockets) + if (!_options.EnablePush) return; _logger.LogDebug("Subscribing to message bus notifications"); @@ -56,12 +58,21 @@ private async Task OnUserMembershipChangedAsync(UserMembershipChanged userMember // manage user organization group membership var userConnectionIds = await _connectionMapping.GetUserIdConnectionsAsync(userMembershipChanged.UserId); _logger.LogTrace("Attempting to update user {User} active groups for {UserConnectionCount} connections", userMembershipChanged.UserId, userConnectionIds.Count); + if (userMembershipChanged.ChangeType is ChangeType.Removed && userConnectionIds.Count > 0) + TypedSend(userConnectionIds, userMembershipChanged); + foreach (string connectionId in userConnectionIds) { if (userMembershipChanged.ChangeType is ChangeType.Added) + { await _connectionMapping.GroupAddAsync(userMembershipChanged.OrganizationId, connectionId); + await _connectionMapping.ConnectionGroupAddAsync(connectionId, userMembershipChanged.OrganizationId); + } else if (userMembershipChanged.ChangeType is ChangeType.Removed) + { await _connectionMapping.GroupRemoveAsync(userMembershipChanged.OrganizationId, connectionId); + await _connectionMapping.ConnectionGroupRemoveAsync(connectionId, userMembershipChanged.OrganizationId); + } } await GroupSendAsync(userMembershipChanged.OrganizationId, userMembershipChanged); @@ -91,7 +102,7 @@ internal async Task OnEntityChangedAsync(EntityChanged ec, CancellationToken can var userConnectionIds = await _connectionMapping.GetUserIdConnectionsAsync(entityChanged.Id); _logger.LogTrace("Sending {UserTypeName} message to user: {UserId} (to {UserConnectionCount} connections)", UserTypeName, entityChanged.Id, userConnectionIds.Count); foreach (string connectionId in userConnectionIds) - await TypedSendAsync(connectionId, entityChanged); + TypedSend(connectionId, entityChanged); return; } @@ -106,19 +117,26 @@ internal async Task OnEntityChangedAsync(EntityChanged ec, CancellationToken can { var userConnectionIds = await _connectionMapping.GetUserIdConnectionsAsync(userId); - // Auth token removed = logout. Close sockets immediately without sending; + // Auth token removed = logout. Close connections immediately without sending; // there is no point delivering a message to a connection we are about to tear down. if (isAuthToken && entityChanged.ChangeType is ChangeType.Removed) { - _logger.LogTrace("Auth token removed for user {UserId}; closing {ConnectionCount} WebSocket connection(s)", userId, userConnectionIds.Count); - string? organizationId = entityChanged.OrganizationId; + _logger.LogTrace("Auth token removed for user {UserId}; closing {ConnectionCount} push connection(s)", userId, userConnectionIds.Count); foreach (string connectionId in userConnectionIds) { - if (organizationId is { Length: > 0 }) + var organizationIds = await _connectionMapping.GetConnectionGroupsAsync(connectionId); + if (organizationIds.Count is 0 && entityChanged.OrganizationId is { Length: > 0 } fallbackOrganizationId) + organizationIds = [fallbackOrganizationId]; + + foreach (string organizationId in organizationIds) + { await _connectionMapping.GroupRemoveAsync(organizationId, connectionId); + await _connectionMapping.ConnectionGroupRemoveAsync(connectionId, organizationId); + } await _connectionMapping.UserIdRemoveAsync(userId, connectionId); - await _connectionManager.RemoveWebSocketAsync(connectionId); + await _sseConnectionManager.RemoveConnectionAsync(connectionId); + await _webSocketConnectionManager.RemoveConnectionAsync(connectionId); } return; @@ -126,7 +144,7 @@ internal async Task OnEntityChangedAsync(EntityChanged ec, CancellationToken can _logger.LogTrace("Sending {TokenTypeName} message for user: {UserId} (to {UserConnectionCount} connections)", TokenTypeName, userId, userConnectionIds.Count); foreach (string connectionId in userConnectionIds) - await TypedSendAsync(connectionId, entityChanged); + TypedSend(connectionId, entityChanged); return; } @@ -172,13 +190,15 @@ private Task OnPlanChangedAsync(PlanChanged planChanged, CancellationToken cance private Task OnReleaseNotificationAsync(ReleaseNotification notification, CancellationToken cancellationToken = default) { _logger.LogTrace("Sending release notification message: {Message}", notification.Message); - return TypedBroadcastAsync(notification); + TypedBroadcast(notification); + return Task.CompletedTask; } private Task OnSystemNotificationAsync(SystemNotification notification, CancellationToken cancellationToken = default) { _logger.LogTrace("Sending system notification message: {Message}", notification.Message); - return TypedBroadcastAsync(notification); + TypedBroadcast(notification); + return Task.CompletedTask; } private async Task GroupSendAsync(string group, object value) @@ -190,22 +210,31 @@ private async Task GroupSendAsync(string group, object value) return; } - await TypedSendAsync(connectionIds.ToList(), value); + TypedSend(connectionIds, value); } - public Task TypedSendAsync(string connectionId, object value) + public void TypedSend(string connectionId, object value) { - return _connectionManager.SendMessageAsync(connectionId, new TypedMessage { Type = GetMessageType(value), Message = value }); + var message = new TypedMessage { Type = GetMessageType(value), Message = value }; + bool canDrop = CanDrop(value); + _sseConnectionManager.SendMessage(connectionId, message, canDrop); + _webSocketConnectionManager.SendMessage(connectionId, message); } - public Task TypedSendAsync(IList connectionIds, object value) + public void TypedSend(IEnumerable connectionIds, object value) { - return _connectionManager.SendMessageAsync(connectionIds, new TypedMessage { Type = GetMessageType(value), Message = value }); + var message = new TypedMessage { Type = GetMessageType(value), Message = value }; + bool canDrop = CanDrop(value); + _sseConnectionManager.SendMessage(connectionIds, message, canDrop); + _webSocketConnectionManager.SendMessage(connectionIds, message); } - public Task TypedBroadcastAsync(object value) + public void TypedBroadcast(object value) { - return _connectionManager.SendMessageToAllAsync(new TypedMessage { Type = GetMessageType(value), Message = value }); + var message = new TypedMessage { Type = GetMessageType(value), Message = value }; + bool canDrop = CanDrop(value); + _sseConnectionManager.SendMessageToAll(message, canDrop); + _webSocketConnectionManager.SendMessageToAll(message); } private static string GetMessageType(object value) @@ -215,6 +244,11 @@ private static string GetMessageType(object value) return value.GetType().Name; } + + private static bool CanDrop(object value) + { + return value is not (PlanOverage or ReleaseNotification or SystemNotification); + } } public record TypedMessage diff --git a/src/Exceptionless.Web/Hubs/MessageBusBrokerMiddleware.cs b/src/Exceptionless.Web/Hubs/MessageBusBrokerMiddleware.cs deleted file mode 100644 index 227a5372ae..0000000000 --- a/src/Exceptionless.Web/Hubs/MessageBusBrokerMiddleware.cs +++ /dev/null @@ -1,154 +0,0 @@ -using System.Net.WebSockets; -using System.Text; -using Exceptionless.Core.Extensions; -using Exceptionless.Core.Utility; - -namespace Exceptionless.Web.Hubs; - -public class MessageBusBrokerMiddleware -{ - private readonly ILogger _logger; - private readonly WebSocketConnectionManager _connectionManager; - private readonly IConnectionMapping _connectionMapping; - private readonly RequestDelegate _next; - - public MessageBusBrokerMiddleware(RequestDelegate next, WebSocketConnectionManager connectionManager, IConnectionMapping connectionMapping, ILogger logger) - { - _next = next; - _connectionManager = connectionManager; - _connectionMapping = connectionMapping; - _logger = logger; - } - - public async Task Invoke(HttpContext context) - { - if (!context.WebSockets.IsWebSocketRequest || !context.User.IsAuthenticated()) - { - await _next(context); - return; - } - - using var socket = await context.WebSockets.AcceptWebSocketAsync(); - string connectionId = _connectionManager.AddWebSocket(socket); - await OnConnected(context, socket, connectionId); - bool disconnected = false; - - try - { - await ReceiveAsync(socket, async (result, data) => - { - switch (result.MessageType) - { - case WebSocketMessageType.Text: - _logger.LogTrace("WebSocket got message {ConnectionId}", connectionId); - // ignore incoming messages - return; - case WebSocketMessageType.Close: - await OnDisconnected(context, socket, connectionId); - await _connectionManager.RemoveWebSocketAsync(connectionId); - disconnected = true; - return; - } - }); - } - catch (WebSocketException ex) when (ex.WebSocketErrorCode == WebSocketError.ConnectionClosedPrematurely) { } - - // This will be hit when the connection is lost. - if (!disconnected) - { - await OnDisconnected(context, socket, connectionId); - await _connectionManager.RemoveWebSocketAsync(connectionId); - } - } - - private async Task OnConnected(HttpContext context, WebSocket socket, string connectionId) - { - _logger.LogTrace("WebSocket connected {ConnectionId} ({State})", connectionId, socket.State); - - try - { - foreach (string organizationId in context.User.GetOrganizationIds()) - await _connectionMapping.GroupAddAsync(organizationId, connectionId); - - string? userId = context.User.GetUserId(); - if (!String.IsNullOrEmpty(userId)) - await _connectionMapping.UserIdAddAsync(userId, connectionId); - } - catch (Exception ex) - { - _logger.LogError(ex, "OnConnected Error: {Message}", ex.Message); - throw; - } - } - - private async Task OnDisconnected(HttpContext context, WebSocket socket, string connectionId) - { - _logger.LogTrace("WebSocket disconnected {ConnectionId} ({State})", connectionId, socket.State); - - try - { - foreach (string organizationId in context.User.GetOrganizationIds()) - await _connectionMapping.GroupRemoveAsync(organizationId, connectionId); - - string? userId = context.User.GetUserId(); - if (!String.IsNullOrEmpty(userId)) - await _connectionMapping.UserIdRemoveAsync(userId, connectionId); - } - catch (Exception ex) - { - _logger.LogError(ex, "OnDisconnected Error: {Message}", ex.Message); - throw; - } - } - - private async Task ReceiveAsync(WebSocket socket, Func handleMessage) - { - var buffer = new ArraySegment(new byte[1024 * 4]); - var result = await socket.ReceiveAsync(buffer, CancellationToken.None); - LogFrame(result, buffer.Array); - - while (!result.CloseStatus.HasValue) - { - string data; - - using (var ms = new MemoryStream()) - { - do - { - result = await socket.ReceiveAsync(buffer, CancellationToken.None); - LogFrame(result, buffer.Array); - - if (buffer.Array is not null) - await ms.WriteAsync(buffer.Array, buffer.Offset, result.Count); - } while (!result.EndOfMessage); - - ms.Seek(0, SeekOrigin.Begin); - - using (var reader = new StreamReader(ms, Encoding.UTF8)) - data = await reader.ReadToEndAsync(); - } - - await handleMessage(result, data); - } - } - - private void LogFrame(WebSocketReceiveResult frame, byte[]? buffer) - { - if (!_logger.IsEnabled(LogLevel.Debug)) - return; - - if (frame.CloseStatus.HasValue) - { - _logger.LogDebug("Close: {CloseStatus} {CloseStatusDescription}", frame.CloseStatus.Value, frame.CloseStatusDescription); - } - else - { - string? content = "<>"; - if (frame.MessageType == WebSocketMessageType.Text) - content = buffer is not null ? Encoding.UTF8.GetString(buffer, 0, frame.Count) : null; - - _logger.LogDebug("Received Frame {MessageType}: length={FrameCount}, end={FrameEndOfMessage}: {Content}", frame.MessageType, frame.Count, frame.EndOfMessage, content); - } - - } -} diff --git a/src/Exceptionless.Web/Hubs/PushDisconnectCleanup.cs b/src/Exceptionless.Web/Hubs/PushDisconnectCleanup.cs new file mode 100644 index 0000000000..1442e5e942 --- /dev/null +++ b/src/Exceptionless.Web/Hubs/PushDisconnectCleanup.cs @@ -0,0 +1,32 @@ +using System.Security.Claims; +using Exceptionless.Core.Extensions; +using Exceptionless.Core.Models; +using Exceptionless.Core.Utility; +using Microsoft.Extensions.Logging; + +namespace Exceptionless.Web.Hubs; + +internal static class PushDisconnectCleanup +{ + public static async Task> GetOrganizationIdsAsync(ClaimsPrincipal user, string connectionId, IConnectionMapping connectionMapping, Func> getCurrentUserAsync, ILogger logger) + { + var organizationIds = new HashSet(await connectionMapping.GetConnectionGroupsAsync(connectionId).ConfigureAwait(false)); + organizationIds.UnionWith(user.GetOrganizationIds()); + string? userId = user.GetUserId(); + if (String.IsNullOrEmpty(userId)) + return organizationIds; + + try + { + var currentUser = await getCurrentUserAsync().ConfigureAwait(false); + if (currentUser is not null) + organizationIds.UnionWith(currentUser.OrganizationIds); + } + catch (Exception ex) + { + logger.LogWarning(ex, "Falling back to tracked push disconnect cleanup for user {UserId}", userId); + } + + return organizationIds; + } +} diff --git a/src/Exceptionless.Web/Hubs/SseConnection.cs b/src/Exceptionless.Web/Hubs/SseConnection.cs new file mode 100644 index 0000000000..e5398c8c9f --- /dev/null +++ b/src/Exceptionless.Web/Hubs/SseConnection.cs @@ -0,0 +1,304 @@ +using Foundatio.Serializer; + +namespace Exceptionless.Web.Hubs; + +/// +/// Represents a single SSE client connection. Owns a write loop that serializes +/// all sends through a bounded dedup queue, preventing concurrent writes to the +/// underlying HttpResponse stream. +/// +/// Design: delivery is best-effort. Under burst load, oldest unwritten events are +/// dropped. This is intentional — SSE push messages trigger client-side cache +/// invalidation refetches, so a dropped message results in stale cache until the +/// next push or manual refresh, not data loss. +/// +/// Deduplication: messages with the same serialized payload are coalesced — if an +/// identical message is already queued, the newer duplicate is skipped. This reduces +/// redundant client refreshes during burst scenarios (e.g., rapid entity updates). +/// +public sealed class SseConnection : IAsyncDisposable +{ + private static readonly byte[] KeepAliveBytes = ": keepalive\n\n"u8.ToArray(); + private readonly HttpResponse _response; + private readonly ITextSerializer _serializer; + private readonly DedupQueue _queue; + private readonly CancellationTokenSource _cts; + private readonly CancellationToken _connectionAborted; + private readonly Task _writeLoop; + private readonly ILogger _logger; + private long _droppedMessages; + private long _dedupedMessages; + private int _disposeState; + + public string ConnectionId { get; } + public CancellationToken ConnectionAborted => _connectionAborted; + + /// Number of messages dropped due to backpressure (queue full). + public long DroppedMessages => Interlocked.Read(ref _droppedMessages); + + /// Number of messages skipped due to deduplication. + public long DedupedMessages => Interlocked.Read(ref _dedupedMessages); + + public SseConnection(string connectionId, HttpResponse response, ITextSerializer serializer, CancellationToken requestAborted, ILogger logger, int capacity = 64) + { + ConnectionId = connectionId; + _response = response; + _serializer = serializer; + _logger = logger; + _queue = new DedupQueue(capacity); + + _cts = CancellationTokenSource.CreateLinkedTokenSource(requestAborted); + _connectionAborted = _cts.Token; + _writeLoop = Task.Run(() => WriteLoopAsync(_cts.Token)); + } + + /// + /// Enqueue a message to be written. Returns false if the connection is closed. + /// If an identical message (same serialized payload) is already queued, the new + /// one is skipped (deduped) and this returns true. + /// + public bool TryWrite(object message, bool canDrop = true) + { + if (_cts.IsCancellationRequested) + return false; + + string data = _serializer.SerializeToString(message); + var result = _queue.TryEnqueue(new SseEvent { Data = data, DedupeKey = canDrop ? data : null, CanDrop = canDrop }); + + if (result == EnqueueResult.Deduped) + { + Interlocked.Increment(ref _dedupedMessages); + return true; + } + + if (result == EnqueueResult.DroppedQueuedMessage) + Interlocked.Increment(ref _droppedMessages); + + return result != EnqueueResult.Skipped; + } + + /// + /// Send a keep-alive comment to prevent proxy/LB timeouts. + /// Keep-alives bypass dedup (always enqueued). + /// + public bool TryWriteKeepAlive() + { + if (_cts.IsCancellationRequested) + return false; + + return _queue.TryEnqueue(SseEvent.KeepAlive) != EnqueueResult.Skipped; + } + + /// + /// Abort the connection. The write loop will complete and the middleware will clean up. + /// + public void Abort() + { + try { _cts.Cancel(); } + catch (ObjectDisposedException ex) + { + _logger.LogDebug(ex, "SSE cancellation token source was already disposed for {ConnectionId}", ConnectionId); + } + + _queue.Complete(); + } + + public async ValueTask DisposeAsync() + { + if (Interlocked.Exchange(ref _disposeState, 1) != 0) + return; + Abort(); + using (_queue) + using (_cts) + { + try + { + await _writeLoop.ConfigureAwait(false); + } + catch (OperationCanceledException ex) + { + _logger.LogDebug(ex, "SSE dispose canceled for {ConnectionId}", ConnectionId); + } + } + } + + private async Task WriteLoopAsync(CancellationToken ct) + { + try + { + while (!ct.IsCancellationRequested) + { + var evt = await _queue.DequeueAsync(ct); + if (evt is null) + break; // Queue completed + + var bytes = evt.Value.IsKeepAlive + ? KeepAliveBytes + : System.Text.Encoding.UTF8.GetBytes($"data: {evt.Value.Data}\n\n"); + + await _response.Body.WriteAsync(bytes, ct); + await _response.Body.FlushAsync(ct); + } + } + catch (OperationCanceledException ex) + { + _logger.LogDebug(ex, "SSE write loop canceled for {ConnectionId}", ConnectionId); + } + catch (ObjectDisposedException ex) + { + _logger.LogDebug(ex, "SSE response was disposed for {ConnectionId}", ConnectionId); + } + catch (IOException ex) + { + _logger.LogDebug(ex, "SSE write failed for {ConnectionId}", ConnectionId); + } + finally + { + // Always signal ConnectionAborted so the middleware's Task.Delay unblocks + // and cleanup (IConnectionMapping removal) happens reliably. + _queue.Complete(); + if (!_cts.IsCancellationRequested) + { + try + { + _cts.Cancel(); + } + catch (ObjectDisposedException ex) + { + _logger.LogDebug(ex, "SSE cancellation token source was already disposed for {ConnectionId}", ConnectionId); + } + } + } + } + + internal readonly record struct SseEvent + { + public string? Data { get; init; } + + /// + /// Key used for deduplication. If null, no dedup is applied (e.g., keep-alive). + /// For data messages, this is the serialized payload — identical payloads trigger + /// the same client-side cache invalidation, so coalescing is safe. + /// + public string? DedupeKey { get; init; } + public bool CanDrop { get; init; } + + public bool IsKeepAlive { get; init; } + public static SseEvent KeepAlive => new() { IsKeepAlive = true, CanDrop = true }; + } + + internal enum EnqueueResult + { + Enqueued, + Deduped, + DroppedQueuedMessage, + BackpressureSkipped, + Skipped + } + + /// + /// Bounded FIFO queue with deduplication. Thread-safe for multiple writers and a single reader. + /// When full, drops the oldest item to make room (like BoundedChannelFullMode.DropOldest). + /// If an item with the same DedupeKey is already queued, the new item is skipped. + /// + internal sealed class DedupQueue : IDisposable + { + private readonly object _lock = new(); + private readonly LinkedList _list = new(); + private readonly Dictionary> _index = new(); + private readonly SemaphoreSlim _signal = new(0); + private readonly int _capacity; + private bool _completed; + + public DedupQueue(int capacity) + { + _capacity = capacity; + } + + public EnqueueResult TryEnqueue(SseEvent evt) + { + lock (_lock) + { + if (_completed) + return EnqueueResult.Skipped; + + // Dedup check: if same key is already queued, skip + if (evt.DedupeKey is not null && _index.ContainsKey(evt.DedupeKey)) + return EnqueueResult.Deduped; + + var result = EnqueueResult.Enqueued; + + // Enforce capacity: drop the oldest droppable message first so direct user + // notifications do not get crowded out by stale cache invalidations. + if (_list.Count >= _capacity) + { + if (evt.IsKeepAlive) + return EnqueueResult.BackpressureSkipped; + + var queuedToDrop = FindFirstDroppableNode(); + RemoveNode(queuedToDrop ?? _list.First!); + result = EnqueueResult.DroppedQueuedMessage; + } + + var node = _list.AddLast(evt); + if (evt.DedupeKey is not null) + _index[evt.DedupeKey] = node; + + _signal.Release(); + return result; + } + } + + public async Task DequeueAsync(CancellationToken ct) + { + await _signal.WaitAsync(ct); + + lock (_lock) + { + if (_list.Count == 0) + return null; // Completed + + var node = _list.First!; + RemoveNode(node); + return node.Value; + } + } + + public void Complete() + { + lock (_lock) + { + if (_completed) + return; + _completed = true; + _signal.Release(); // Wake up the reader so it sees null + } + } + + public void Dispose() + { + _signal.Dispose(); + } + + private LinkedListNode? FindFirstDroppableNode() + { + var current = _list.First; + while (current is not null) + { + if (current.Value.CanDrop) + return current; + + current = current.Next; + } + + return null; + } + + private void RemoveNode(LinkedListNode node) + { + _list.Remove(node); + if (node.Value.DedupeKey is not null) + _index.Remove(node.Value.DedupeKey); + } + } +} diff --git a/src/Exceptionless.Web/Hubs/SseConnectionManager.cs b/src/Exceptionless.Web/Hubs/SseConnectionManager.cs new file mode 100644 index 0000000000..df68d3fe67 --- /dev/null +++ b/src/Exceptionless.Web/Hubs/SseConnectionManager.cs @@ -0,0 +1,211 @@ +using System.Collections.Concurrent; +using Exceptionless.Core; +using Foundatio.Serializer; + +namespace Exceptionless.Web.Hubs; + +/// +/// Manages active SSE connections. Replaces WebSocketConnectionManager. +/// Sends keep-alive comments every 15 seconds to prevent proxy/LB disconnects. +/// Proactively prunes dead connections during keep-alive sweeps. +/// +public sealed class SseConnectionManager : IDisposable, IAsyncDisposable +{ + private readonly ConcurrentDictionary _connections = new(); + private readonly ConcurrentDictionary> _pendingDisposals = new(); + private readonly Timer? _timer; + private readonly ITextSerializer _serializer; + private readonly ILogger _logger; + + /// + /// Maximum number of concurrent connections per user to prevent resource exhaustion. + /// This is a soft limit — under concurrent connection bursts, a few extra connections + /// may be admitted briefly. This is acceptable because the alternative (distributed + /// locking) would add latency to every SSE connect without meaningful security benefit. + /// + public int MaxConnectionsPerUser { get; init; } = 10; + + public SseConnectionManager(AppOptions options, ITextSerializer serializer, ILoggerFactory loggerFactory) + { + _serializer = serializer; + _logger = loggerFactory.CreateLogger(); + + if (!options.EnablePush) + return; + + _timer = new Timer(SendKeepAlive, null, TimeSpan.FromSeconds(15), TimeSpan.FromSeconds(15)); + } + + private void SendKeepAlive(object? state) + { + if (_connections.IsEmpty) + return; + + int sent = 0; + int pruned = 0; + + foreach (var (connectionId, connection) in _connections) + { + if (connection.ConnectionAborted.IsCancellationRequested) + { + TryRemove(connectionId); + pruned++; + continue; + } + + if (!connection.TryWriteKeepAlive()) + { + // Write failed — connection is dead, prune it + TryRemove(connectionId); + pruned++; + } + else + { + sent++; + } + } + + if (_logger.IsEnabled(LogLevel.Trace)) + _logger.LogTrace("SSE keep-alive: sent={SentCount}, pruned={PrunedCount}, active={ActiveCount}", sent, pruned, _connections.Count); + } + + public SseConnection? GetConnectionById(string connectionId) + { + return _connections.TryGetValue(connectionId, out var connection) ? connection : null; + } + + public ICollection GetAll() + { + return _connections.Values; + } + + public int ConnectionCount => _connections.Count; + + public SseConnection AddConnection(string connectionId, HttpResponse response, CancellationToken requestAborted) + { + var connection = new SseConnection(connectionId, response, _serializer, requestAborted, _logger); + _connections.TryAdd(connectionId, connection); + AppDiagnostics.PushSseConnectionsOpened.Add(1); + AppDiagnostics.Gauge("push.connections.sse.active", _connections.Count); + return connection; + } + + public async Task RemoveConnectionAsync(string connectionId) + { + if (_connections.TryRemove(connectionId, out var connection)) + { + await DisposeConnectionAsync(connectionId, connection).ConfigureAwait(false); + return; + } + + if (_pendingDisposals.TryGetValue(connectionId, out var pendingDisposal)) + await pendingDisposal.Value.ConfigureAwait(false); + } + + private void TryRemove(string connectionId) + { + if (_connections.TryRemove(connectionId, out var connection)) + _ = ObserveDisposeAsync(connectionId, DisposeConnectionAsync(connectionId, connection)); + } + + public bool SendMessage(string connectionId, object message, bool canDrop = true) + { + if (!_connections.TryGetValue(connectionId, out var connection)) + return false; + + if (connection.ConnectionAborted.IsCancellationRequested) + { + TryRemove(connectionId); + return false; + } + + return connection.TryWrite(message, canDrop); + } + + public void SendMessage(IEnumerable connectionIds, object message, bool canDrop = true) + { + foreach (string connectionId in connectionIds) + SendMessage(connectionId, message, canDrop); + } + + public void SendMessageToAll(object message, bool canDrop = true) + { + foreach (var (connectionId, connection) in _connections) + { + if (connection.ConnectionAborted.IsCancellationRequested) + { + TryRemove(connectionId); + continue; + } + + connection.TryWrite(message, canDrop); + } + } + + public void Dispose() + { + // Synchronous disposal: used by test hosts and non-async disposal paths. + // For production host shutdown, the DI container will prefer DisposeAsync(). + DisposeAsync().AsTask().GetAwaiter().GetResult(); + } + + public async ValueTask DisposeAsync() + { + _timer?.Dispose(); + + var disposeTasks = new List(); + + foreach (var (connectionId, connection) in _connections) + { + if (_connections.TryRemove(connectionId, out var activeConnection)) + disposeTasks.Add(DisposeConnectionAsync(connectionId, activeConnection)); + } + + foreach (var pendingDisposal in _pendingDisposals.Values) + disposeTasks.Add(pendingDisposal.Value); + + if (disposeTasks.Count > 0) + await Task.WhenAll(disposeTasks).ConfigureAwait(false); + } + + private Task DisposeConnectionAsync(string connectionId, SseConnection connection) + { + var pendingDisposal = _pendingDisposals.GetOrAdd(connectionId, _ => new Lazy(() => DisposeConnectionCoreAsync(connectionId, connection))); + return pendingDisposal.Value; + } + + private async Task DisposeConnectionCoreAsync(string connectionId, SseConnection connection) + { + try + { + connection.Abort(); + await connection.DisposeAsync().ConfigureAwait(false); + } + finally + { + _pendingDisposals.TryRemove(connectionId, out _); + AppDiagnostics.PushSseConnectionsClosed.Add(1); + AppDiagnostics.Gauge("push.connections.sse.active", _connections.Count); + } + } + + private async Task ObserveDisposeAsync(string connectionId, Task disposeTask) + { + try + { + await disposeTask.ConfigureAwait(false); + } + catch (OperationCanceledException ex) + { + _logger.LogDebug(ex, "SSE connection cleanup canceled for {ConnectionId}", connectionId); + } + catch (ObjectDisposedException ex) + { + _logger.LogDebug(ex, "SSE connection cleanup raced with disposal for {ConnectionId}", connectionId); + } + catch (InvalidOperationException ex) + { + _logger.LogDebug(ex, "SSE connection cleanup failed for {ConnectionId}", connectionId); + } + } +} diff --git a/src/Exceptionless.Web/Hubs/SseMiddleware.cs b/src/Exceptionless.Web/Hubs/SseMiddleware.cs new file mode 100644 index 0000000000..3c1d767aa0 --- /dev/null +++ b/src/Exceptionless.Web/Hubs/SseMiddleware.cs @@ -0,0 +1,145 @@ +using Exceptionless.Core.Extensions; +using Exceptionless.Core.Repositories; +using Exceptionless.Core.Utility; + +namespace Exceptionless.Web.Hubs; + +/// +/// Handles SSE connections at /api/v2/push. Replaces MessageBusBrokerMiddleware (WebSocket). +/// Accepts authenticated GET requests, sets SSE response headers, registers the connection +/// with IConnectionMapping, and holds the response open until the client disconnects. +/// +public class SseMiddleware +{ + private static readonly PathString _sseEndpoint = new("/api/v2/push"); + private readonly ILogger _logger; + private readonly SseConnectionManager _connectionManager; + private readonly IConnectionMapping _connectionMapping; + private readonly IUserRepository _userRepository; + private readonly RequestDelegate _next; + + public SseMiddleware(RequestDelegate next, SseConnectionManager connectionManager, IConnectionMapping connectionMapping, IUserRepository userRepository, ILogger logger) + { + _next = next; + _connectionManager = connectionManager; + _connectionMapping = connectionMapping; + _userRepository = userRepository; + _logger = logger; + } + + public async Task Invoke(HttpContext context) + { + if (!context.Request.Path.StartsWithSegments(_sseEndpoint, StringComparison.Ordinal) + || !HttpMethods.IsGet(context.Request.Method) + || context.WebSockets.IsWebSocketRequest) + { + await _next(context); + return; + } + + if (!context.User.IsAuthenticated()) + { + context.Response.StatusCode = StatusCodes.Status401Unauthorized; + return; + } + + string? userId = context.User.GetUserId(); + if (String.IsNullOrEmpty(userId)) + { + context.Response.StatusCode = StatusCodes.Status401Unauthorized; + return; + } + + // Enforce per-user connection limit + var existingConnections = await _connectionMapping.GetUserIdConnectionsAsync(userId); + if (existingConnections.Count >= _connectionManager.MaxConnectionsPerUser) + { + _logger.LogWarning("User {UserId} exceeded max SSE connections ({Max})", userId, _connectionManager.MaxConnectionsPerUser); + context.Response.StatusCode = StatusCodes.Status429TooManyRequests; + return; + } + + // Set SSE response headers + context.Response.Headers.ContentType = "text/event-stream"; + context.Response.Headers.CacheControl = "no-cache, no-store"; + context.Response.Headers["X-Accel-Buffering"] = "no"; // nginx + + // Disable response buffering + var bufferingFeature = context.Features.Get(); + bufferingFeature?.DisableBuffering(); + + string connectionId = Guid.NewGuid().ToString("N"); + SseConnection? connection = null; + + try + { + connection = _connectionManager.AddConnection(connectionId, context.Response, context.RequestAborted); + await OnConnected(context, connectionId).ConfigureAwait(false); + + // Send initial connected event + connection.TryWrite(new { type = "Connected", message = new { connection_id = connectionId } }); + + // Hold the response open until the client disconnects or the connection is aborted + await Task.Delay(Timeout.Infinite, connection.ConnectionAborted).ConfigureAwait(false); + } + catch (OperationCanceledException ex) + { + _logger.LogDebug(ex, "SSE request ended for {ConnectionId}", connectionId); + } + finally + { + if (connection is not null) + { + try + { + await OnDisconnected(context, connectionId).ConfigureAwait(false); + } + finally + { + await _connectionManager.RemoveConnectionAsync(connectionId).ConfigureAwait(false); + } + } + } + } + + private async Task OnConnected(HttpContext context, string connectionId) + { + _logger.LogTrace("SSE connected {ConnectionId}", connectionId); + foreach (string organizationId in context.User.GetOrganizationIds()) + { + await _connectionMapping.GroupAddAsync(organizationId, connectionId).ConfigureAwait(false); + await _connectionMapping.ConnectionGroupAddAsync(connectionId, organizationId).ConfigureAwait(false); + } + + string? userId = context.User.GetUserId(); + if (!String.IsNullOrEmpty(userId)) + await _connectionMapping.UserIdAddAsync(userId, connectionId).ConfigureAwait(false); + } + + private async Task OnDisconnected(HttpContext context, string connectionId) + { + _logger.LogTrace("SSE disconnected {ConnectionId}", connectionId); + + try + { + foreach (string organizationId in await PushDisconnectCleanup.GetOrganizationIdsAsync(context.User, connectionId, _connectionMapping, () => _userRepository.GetByIdAsync(context.User.GetUserId()!), _logger).ConfigureAwait(false)) + { + await _connectionMapping.GroupRemoveAsync(organizationId, connectionId).ConfigureAwait(false); + await _connectionMapping.ConnectionGroupRemoveAsync(connectionId, organizationId).ConfigureAwait(false); + } + + string? userId = context.User.GetUserId(); + if (!String.IsNullOrEmpty(userId)) + await _connectionMapping.UserIdRemoveAsync(userId, connectionId).ConfigureAwait(false); + } + catch (OperationCanceledException ex) + { + _logger.LogDebug(ex, "SSE disconnect was canceled for {ConnectionId}", connectionId); + } + catch (ObjectDisposedException ex) + { + _logger.LogDebug(ex, "SSE disconnect raced with disposal for {ConnectionId}", connectionId); + } + } + +} diff --git a/src/Exceptionless.Web/Hubs/WebSocketConnectionManager.cs b/src/Exceptionless.Web/Hubs/WebSocketConnectionManager.cs index 662aabff39..08d87d0bb7 100644 --- a/src/Exceptionless.Web/Hubs/WebSocketConnectionManager.cs +++ b/src/Exceptionless.Web/Hubs/WebSocketConnectionManager.cs @@ -1,4 +1,4 @@ -using System.Collections.Concurrent; +using System.Collections.Concurrent; using System.Net.WebSockets; using System.Text; using Exceptionless.Core; @@ -6,188 +6,293 @@ namespace Exceptionless.Web.Hubs; -public class WebSocketConnectionManager : IDisposable +/// +/// Temporary WebSocket compatibility layer for the Angular rollout. Remove once the +/// SSE rollout is complete and the websocket active-connection gauge remains at zero. +/// +public sealed class WebSocketConnectionManager : IDisposable { - private static readonly ArraySegment _keepAliveMessage = new(Encoding.ASCII.GetBytes("{}"), 0, 2); - private readonly ConcurrentDictionary _connections = new(); + private static readonly ArraySegment KeepAliveMessage = new(Encoding.ASCII.GetBytes("{}"), 0, 2); + private readonly ConcurrentDictionary _connections = new(); private readonly Timer? _timer; private readonly ITextSerializer _serializer; private readonly ILogger _logger; + public int MaxConnectionsPerUser { get; init; } = 10; + public int ConnectionCount => _connections.Count; + public WebSocketConnectionManager(AppOptions options, ITextSerializer serializer, ILoggerFactory loggerFactory) { _serializer = serializer; _logger = loggerFactory.CreateLogger(); - if (!options.EnableWebSockets) + if (!options.EnablePush) return; - _timer = new Timer(KeepAlive, null, TimeSpan.FromSeconds(10), TimeSpan.FromSeconds(10)); + _timer = new Timer(SendKeepAlive, null, TimeSpan.FromSeconds(15), TimeSpan.FromSeconds(15)); } - private void KeepAlive(object? state) + private void SendKeepAlive(object? state) { - if (_connections is { IsEmpty: true, Count: 0 }) + if (_connections.IsEmpty) return; - Task.Factory.StartNew(async () => + foreach (var (connectionId, connection) in _connections) { - var sockets = GetAll(); - var openSockets = sockets.Where(s => s.State == WebSocketState.Open).ToArray(); - _logger.LogTrace("Sending web socket keep alive to {OpenSocketsCount} open connections of {SocketCount} total connections", openSockets.Length, sockets.Count); - - foreach (var socket in openSockets) + if (!CanSend(connection.Socket)) { - try - { - await socket.SendAsync(buffer: _keepAliveMessage, - messageType: WebSocketMessageType.Text, - endOfMessage: true, - cancellationToken: CancellationToken.None); - } - catch (WebSocketException ex) when (ex.WebSocketErrorCode == WebSocketError.ConnectionClosedPrematurely) - { - // NOTE: This will not remove it from the ConnectionMappings. - await RemoveWebSocketAsync(socket); - } - catch (Exception ex) - { - _logger.LogError(ex, "Error sending keep alive socket message: {Message}", ex.Message); - } + _ = RemoveConnectionAsync(connectionId); + continue; } - }); + + _ = SendKeepAliveAsync(connectionId, connection); + } + } + + public WebSocket? GetConnectionById(string connectionId) + { + return _connections.TryGetValue(connectionId, out var connection) ? connection.Socket : null; } public WebSocket? GetWebSocketById(string connectionId) { - return _connections.TryGetValue(connectionId, out var socket) ? socket : null; + return GetConnectionById(connectionId); } public ICollection GetAll() { - return _connections.Values; + return _connections.Values.Select(static connection => connection.Socket).ToArray(); } public string GetConnectionId(WebSocket socket) { - return _connections.FirstOrDefault(p => p.Value == socket).Key; + return _connections.FirstOrDefault(pair => pair.Value.Socket == socket).Key; } - public string AddWebSocket(WebSocket socket) + public string AddConnection(WebSocket socket) { string connectionId = Guid.NewGuid().ToString("N"); - _connections.TryAdd(connectionId, socket); + _connections.TryAdd(connectionId, new ManagedConnection(socket)); + AppDiagnostics.PushWebSocketConnectionsOpened.Add(1); + AppDiagnostics.Gauge("push.connections.websocket.active", _connections.Count); return connectionId; } - private Task RemoveWebSocketAsync(WebSocket socket) - { - string id = GetConnectionId(socket); - if (String.IsNullOrEmpty(id) || !_connections.TryRemove(id, out var _)) - return Task.CompletedTask; - - return CloseWebSocketAsync(socket); - } - - public Task RemoveWebSocketAsync(string id) + public string AddWebSocket(WebSocket socket) { - if (!_connections.TryRemove(id, out var socket)) - return Task.CompletedTask; - - return CloseWebSocketAsync(socket); + return AddConnection(socket); } - private async Task CloseWebSocketAsync(WebSocket socket) + public async Task RemoveConnectionAsync(string connectionId) { - if (!CanSendWebSocketMessage(socket)) + if (!_connections.TryRemove(connectionId, out var connection)) return; try { - await socket.CloseAsync(WebSocketCloseStatus.NormalClosure, "Closed by manager", CancellationToken.None); + await connection.CloseAsync(CancellationToken.None).ConfigureAwait(false); } catch (WebSocketException ex) when (ex.WebSocketErrorCode == WebSocketError.ConnectionClosedPrematurely) { - // Ignored + _logger.LogDebug(ex, "Websocket {ConnectionId} closed before manager shutdown completed", connectionId); + } + catch (ObjectDisposedException ex) + { + _logger.LogDebug(ex, "Websocket {ConnectionId} was already disposed during shutdown", connectionId); } catch (Exception ex) { - _logger.LogError(ex, "Error closing web socket: {Message}", ex.Message); + _logger.LogDebug(ex, "Error closing websocket {ConnectionId}", connectionId); } + finally + { + AppDiagnostics.PushWebSocketConnectionsClosed.Add(1); + AppDiagnostics.Gauge("push.connections.websocket.active", _connections.Count); + } + } + + public Task RemoveWebSocketAsync(string connectionId) + { + return RemoveConnectionAsync(connectionId); } - private Task SendMessageAsync(WebSocket socket, object message) + public bool SendMessage(string connectionId, object message) { - if (!CanSendWebSocketMessage(socket)) - return Task.CompletedTask; + if (!_connections.TryGetValue(connectionId, out var connection)) + return false; - string serializedMessage = _serializer.SerializeToString(message); - Task.Factory.StartNew(async () => + if (!CanSend(connection.Socket)) { - if (!CanSendWebSocketMessage(socket)) - return; + _ = RemoveConnectionAsync(connectionId); + return false; + } - try - { - await socket.SendAsync(buffer: new ArraySegment(Encoding.ASCII.GetBytes(serializedMessage), 0, serializedMessage.Length), - messageType: WebSocketMessageType.Text, - endOfMessage: true, - cancellationToken: CancellationToken.None); - } - catch (WebSocketException ex) when (ex.WebSocketErrorCode == WebSocketError.ConnectionClosedPrematurely) - { - // Ignored - } - catch (Exception ex) - { - _logger.LogError(ex, "Error sending socket message: {Message}", ex.Message); - } - }); + _ = SendMessageAsync(connectionId, connection, message); + return true; + } + public Task SendMessageAsync(string connectionId, object message) + { + SendMessage(connectionId, message); return Task.CompletedTask; } - public Task SendMessageAsync(string connectionId, object message) + public void SendMessage(IEnumerable connectionIds, object message) { - var socket = GetWebSocketById(connectionId); - return socket is not null ? SendMessageAsync(socket, message) : Task.CompletedTask; + foreach (var connectionId in connectionIds) + SendMessage(connectionId, message); } public Task SendMessageAsync(IEnumerable connectionIds, object message) { - return Task.WhenAll(connectionIds.Select(id => - { - var socket = GetWebSocketById(id); - return socket is not null ? SendMessageAsync(socket, message) : Task.CompletedTask; - })); + SendMessage(connectionIds, message); + return Task.CompletedTask; } - public async Task SendMessageToAllAsync(object message, bool throwOnError = true) + public void SendMessageToAll(object message) { - foreach (var socket in GetAll()) + foreach (var (connectionId, connection) in _connections) { - if (!CanSendWebSocketMessage(socket)) - continue; - - try + if (!CanSend(connection.Socket)) { - await SendMessageAsync(socket, message); - } - catch (Exception) - { - if (throwOnError) - throw; + _ = RemoveConnectionAsync(connectionId); + continue; } + + _ = SendMessageAsync(connectionId, connection, message); } } - private bool CanSendWebSocketMessage(WebSocket socket) + public Task SendMessageToAllAsync(object message, bool throwOnError = true) { - return socket.State != WebSocketState.Aborted && socket.State != WebSocketState.Closed && socket.State != WebSocketState.CloseSent; + SendMessageToAll(message); + return Task.CompletedTask; } public void Dispose() { _timer?.Dispose(); } + + private async Task SendKeepAliveAsync(string connectionId, ManagedConnection connection) + { + try + { + if (!await connection.SendAsync(KeepAliveMessage, WebSocketMessageType.Text, CancellationToken.None).ConfigureAwait(false)) + await RemoveConnectionAsync(connectionId).ConfigureAwait(false); + } + catch (WebSocketException ex) when (ex.WebSocketErrorCode == WebSocketError.ConnectionClosedPrematurely) + { + await RemoveConnectionAsync(connectionId).ConfigureAwait(false); + } + catch (ObjectDisposedException) + { + await RemoveConnectionAsync(connectionId).ConfigureAwait(false); + } + catch (WebSocketException ex) + { + _logger.LogDebug(ex, "Error sending websocket keepalive for {ConnectionId}", connectionId); + await RemoveConnectionAsync(connectionId).ConfigureAwait(false); + } + catch (InvalidOperationException ex) + { + _logger.LogDebug(ex, "Error sending websocket keepalive for {ConnectionId}", connectionId); + await RemoveConnectionAsync(connectionId).ConfigureAwait(false); + } + } + + private async Task SendMessageAsync(string connectionId, ManagedConnection connection, object message) + { + try + { + string serializedMessage = _serializer.SerializeToString(message); + byte[] bytes = Encoding.UTF8.GetBytes(serializedMessage); + if (!await connection.SendAsync(new ArraySegment(bytes, 0, bytes.Length), WebSocketMessageType.Text, CancellationToken.None).ConfigureAwait(false)) + await RemoveConnectionAsync(connectionId).ConfigureAwait(false); + } + catch (WebSocketException ex) when (ex.WebSocketErrorCode == WebSocketError.ConnectionClosedPrematurely) + { + await RemoveConnectionAsync(connectionId).ConfigureAwait(false); + } + catch (ObjectDisposedException) + { + await RemoveConnectionAsync(connectionId).ConfigureAwait(false); + } + catch (WebSocketException ex) + { + _logger.LogDebug(ex, "Error sending websocket message for {ConnectionId}", connectionId); + } + catch (InvalidOperationException ex) + { + _logger.LogDebug(ex, "Error sending websocket message for {ConnectionId}", connectionId); + } + catch (OperationCanceledException ex) + { + _logger.LogDebug(ex, "Error sending websocket message for {ConnectionId}", connectionId); + } + catch (NotSupportedException ex) + { + _logger.LogDebug(ex, "Error sending websocket message for {ConnectionId}", connectionId); + } + catch (EncoderFallbackException ex) + { + _logger.LogDebug(ex, "Error sending websocket message for {ConnectionId}", connectionId); + } + } + + private static bool CanSend(WebSocket socket) + { + return socket.State is WebSocketState.Open; + } + + private static bool CanClose(WebSocket socket) + { + return socket.State is WebSocketState.Open or WebSocketState.CloseReceived; + } + + private sealed class ManagedConnection + { + private readonly SemaphoreSlim _sendLock = new(1, 1); + + public ManagedConnection(WebSocket socket) + { + Socket = socket; + } + + public WebSocket Socket { get; } + + public async Task CloseAsync(CancellationToken cancellationToken) + { + await _sendLock.WaitAsync(cancellationToken).ConfigureAwait(false); + try + { + if (!CanClose(Socket)) + return false; + + await Socket.CloseAsync(WebSocketCloseStatus.NormalClosure, "Closed by manager", cancellationToken).ConfigureAwait(false); + return true; + } + finally + { + _sendLock.Release(); + } + } + + public async Task SendAsync(ArraySegment buffer, WebSocketMessageType messageType, CancellationToken cancellationToken) + { + await _sendLock.WaitAsync(cancellationToken).ConfigureAwait(false); + try + { + if (!CanSend(Socket)) + return false; + + await Socket.SendAsync(buffer, messageType, true, cancellationToken).ConfigureAwait(false); + return true; + } + finally + { + _sendLock.Release(); + } + } + } } diff --git a/src/Exceptionless.Web/Hubs/WebSocketPushMiddleware.cs b/src/Exceptionless.Web/Hubs/WebSocketPushMiddleware.cs new file mode 100644 index 0000000000..bd10f3577d --- /dev/null +++ b/src/Exceptionless.Web/Hubs/WebSocketPushMiddleware.cs @@ -0,0 +1,128 @@ +using System.Net.WebSockets; +using Exceptionless.Core.Extensions; +using Exceptionless.Core.Repositories; +using Exceptionless.Core.Utility; + +namespace Exceptionless.Web.Hubs; + +/// +/// Temporary WebSocket endpoint compatibility for the Angular rollout. Keep this in place +/// until all clients are on SSE and websocket active connections stay at zero. +/// +public sealed class WebSocketPushMiddleware +{ + private static readonly PathString PushEndpoint = new("/api/v2/push"); + private readonly ILogger _logger; + private readonly WebSocketConnectionManager _connectionManager; + private readonly IConnectionMapping _connectionMapping; + private readonly IUserRepository _userRepository; + private readonly RequestDelegate _next; + + public WebSocketPushMiddleware(RequestDelegate next, WebSocketConnectionManager connectionManager, IConnectionMapping connectionMapping, IUserRepository userRepository, ILogger logger) + { + _next = next; + _connectionManager = connectionManager; + _connectionMapping = connectionMapping; + _userRepository = userRepository; + _logger = logger; + } + + public async Task Invoke(HttpContext context) + { + if (!context.Request.Path.StartsWithSegments(PushEndpoint, StringComparison.Ordinal) + || !context.WebSockets.IsWebSocketRequest) + { + await _next(context); + return; + } + + if (!context.User.IsAuthenticated()) + { + context.Response.StatusCode = StatusCodes.Status401Unauthorized; + return; + } + + string? userId = context.User.GetUserId(); + if (String.IsNullOrEmpty(userId)) + { + context.Response.StatusCode = StatusCodes.Status401Unauthorized; + return; + } + + var existingConnections = await _connectionMapping.GetUserIdConnectionsAsync(userId); + if (existingConnections.Count >= _connectionManager.MaxConnectionsPerUser) + { + _logger.LogWarning("User {UserId} exceeded max websocket push connections ({Max})", userId, _connectionManager.MaxConnectionsPerUser); + context.Response.StatusCode = StatusCodes.Status429TooManyRequests; + return; + } + + using var socket = await context.WebSockets.AcceptWebSocketAsync(); + string connectionId = _connectionManager.AddConnection(socket); + + try + { + await OnConnected(context, connectionId).ConfigureAwait(false); + await ReceiveUntilCloseAsync(socket, context.RequestAborted).ConfigureAwait(false); + } + catch (WebSocketException ex) when (ex.WebSocketErrorCode == WebSocketError.ConnectionClosedPrematurely) { } + catch (OperationCanceledException) { } + finally + { + try + { + await OnDisconnected(context, connectionId).ConfigureAwait(false); + } + finally + { + await _connectionManager.RemoveConnectionAsync(connectionId).ConfigureAwait(false); + } + } + } + + private async Task OnConnected(HttpContext context, string connectionId) + { + _logger.LogTrace("WebSocket push connected {ConnectionId}", connectionId); + + foreach (string organizationId in context.User.GetOrganizationIds()) + { + await _connectionMapping.GroupAddAsync(organizationId, connectionId).ConfigureAwait(false); + await _connectionMapping.ConnectionGroupAddAsync(connectionId, organizationId).ConfigureAwait(false); + } + + string? userId = context.User.GetUserId(); + if (!String.IsNullOrEmpty(userId)) + await _connectionMapping.UserIdAddAsync(userId, connectionId).ConfigureAwait(false); + } + + private async Task OnDisconnected(HttpContext context, string connectionId) + { + _logger.LogTrace("WebSocket push disconnected {ConnectionId}", connectionId); + + foreach (string organizationId in await PushDisconnectCleanup.GetOrganizationIdsAsync(context.User, connectionId, _connectionMapping, () => _userRepository.GetByIdAsync(context.User.GetUserId()!), _logger).ConfigureAwait(false)) + { + await _connectionMapping.GroupRemoveAsync(organizationId, connectionId).ConfigureAwait(false); + await _connectionMapping.ConnectionGroupRemoveAsync(connectionId, organizationId).ConfigureAwait(false); + } + + string? userId = context.User.GetUserId(); + if (!String.IsNullOrEmpty(userId)) + await _connectionMapping.UserIdRemoveAsync(userId, connectionId).ConfigureAwait(false); + } + + private static async Task ReceiveUntilCloseAsync(WebSocket socket, CancellationToken cancellationToken) + { + var buffer = new byte[4096]; + + while (socket.State is WebSocketState.Open) + { + WebSocketReceiveResult result; + do + { + result = await socket.ReceiveAsync(new ArraySegment(buffer), cancellationToken).ConfigureAwait(false); + if (result.MessageType is WebSocketMessageType.Close) + return; + } while (!result.EndOfMessage); + } + } +} diff --git a/src/Exceptionless.Web/Program.cs b/src/Exceptionless.Web/Program.cs index db9dd79749..5d3a9ed460 100644 --- a/src/Exceptionless.Web/Program.cs +++ b/src/Exceptionless.Web/Program.cs @@ -73,6 +73,12 @@ public static IHostBuilder CreateHostBuilder(IConfigurationRoot config, string e var builder = Host.CreateDefaultBuilder() .UseEnvironment(environment) + .ConfigureHostOptions(o => + { + // Align with k8s terminationGracePeriodSeconds (60s) minus preStop sleep (15s). + // Gives ASP.NET Core 45s to drain active SSE connections before the pod is force-killed. + o.ShutdownTimeout = TimeSpan.FromSeconds(45); + }) .ConfigureLogging(b => b.ClearProviders()) // clears .net providers since we are telling serilog to write to providers we only want it to be the otel provider .UseSerilog((ctx, sp, c) => { diff --git a/src/Exceptionless.Web/Startup.cs b/src/Exceptionless.Web/Startup.cs index 4a37df1963..6cedc2ddb0 100644 --- a/src/Exceptionless.Web/Startup.cs +++ b/src/Exceptionless.Web/Startup.cs @@ -305,10 +305,11 @@ ApplicationException applicationException when applicationException.Message.Cont // Reject event posts in organizations over their max event limits. app.UseMiddleware(); - if (options.EnableWebSockets) + if (options.EnablePush) { app.UseWebSockets(); - app.UseMiddleware(); + app.UseMiddleware(); + app.UseMiddleware(); } app.UseEndpoints(endpoints => diff --git a/src/Exceptionless.Web/Utility/Handlers/ThrottlingMiddleware.cs b/src/Exceptionless.Web/Utility/Handlers/ThrottlingMiddleware.cs index ee7722b814..4c1d69ec1a 100644 --- a/src/Exceptionless.Web/Utility/Handlers/ThrottlingMiddleware.cs +++ b/src/Exceptionless.Web/Utility/Handlers/ThrottlingMiddleware.cs @@ -21,7 +21,7 @@ public class ThrottlingMiddleware private static readonly PathString _v1ProjectConfigPath = new("/api/v1/project/config"); private static readonly PathString _v2ProjectConfigPath = new("/api/v2/projects/config"); private static readonly PathString _heartbeatPath = new("/api/v2/events/session/heartbeat"); - private static readonly PathString _webSocketPath = new("/api/v2/push"); + private static readonly PathString _ssePath = new("/api/v2/push"); public ThrottlingMiddleware(RequestDelegate next, ICacheClient cacheClient, ThrottlingOptions options, TimeProvider timeProvider) @@ -111,7 +111,7 @@ private bool IsUnthrottledRoute(HttpContext context) return context.Request.Path.StartsWithSegments(_v2ProjectConfigPath, StringComparison.Ordinal) || context.Request.Path.StartsWithSegments(_heartbeatPath, StringComparison.Ordinal) - || context.Request.Path.StartsWithSegments(_webSocketPath, StringComparison.Ordinal) + || context.Request.Path.StartsWithSegments(_ssePath, StringComparison.Ordinal) || context.Request.Path.StartsWithSegments(_v1ProjectConfigPath, StringComparison.Ordinal); } } diff --git a/tests/Exceptionless.Tests/AppWebHostFactory.cs b/tests/Exceptionless.Tests/AppWebHostFactory.cs index 19aa9d17cf..d2f1d5bf11 100644 --- a/tests/Exceptionless.Tests/AppWebHostFactory.cs +++ b/tests/Exceptionless.Tests/AppWebHostFactory.cs @@ -1,5 +1,9 @@ using System.Collections.Concurrent; using System.Net; +using System.Net.Http.Json; +using System.Security.Cryptography; +using System.Text; +using System.Text.Json; using Aspire.Hosting; using Aspire.Hosting.ApplicationModel; using Exceptionless.Insulation.Configuration; @@ -12,9 +16,11 @@ namespace Exceptionless.Tests; public class AppWebHostFactory : WebApplicationFactory, IAsyncLifetime { + private static readonly string[] s_indexPrefixes = ["events", "migrations", "organizations", "projects", "saved-views", "stacks", "tokens", "users", "webhooks"]; + private static readonly string s_runScope = CreateRunScope(); private static int s_counter = -1; private static readonly ConcurrentQueue s_pool = new(); - private static readonly Lazy> s_sharedApplication = new(StartSharedApplicationAsync, LazyThreadSafetyMode.ExecutionAndPublication); + private static readonly Lazy> s_sharedApplication = new(StartSharedApplicationAsync, LazyThreadSafetyMode.ExecutionAndPublication); private bool _sliceReleased; public AppWebHostFactory() @@ -23,7 +29,7 @@ public AppWebHostFactory() instanceId = Interlocked.Increment(ref s_counter); InstanceId = instanceId; - AppScope = instanceId == 0 ? "test" : $"test-{instanceId}"; + AppScope = instanceId == 0 ? s_runScope : $"{s_runScope}-{instanceId}"; } public string AppScope { get; } @@ -32,10 +38,11 @@ public AppWebHostFactory() public async ValueTask InitializeAsync() { - _ = await s_sharedApplication.Value; + var sharedApplication = await s_sharedApplication.Value; + await CleanupElasticsearchSliceAsync(sharedApplication.ElasticsearchUri); } - private static async Task StartSharedApplicationAsync() + private static async Task StartSharedApplicationAsync() { var options = new DistributedApplicationOptions { AssemblyName = typeof(ElasticsearchResource).Assembly.FullName, DisableDashboard = true }; var builder = DistributedApplication.CreateBuilder(options); @@ -53,22 +60,60 @@ private static async Task StartSharedApplicationAsync() var connectionString = await elasticsearch.Resource.GetConnectionStringAsync() ?? throw new InvalidOperationException("Could not resolve Elasticsearch connection string."); - await WaitForElasticsearchAsync(new Uri(connectionString)); + var elasticsearchUri = new Uri(connectionString); + await WaitForElasticsearchAsync(elasticsearchUri); - return app; + return new SharedApplicationContext(app, elasticsearchUri); + } + + private static string CreateRunScope() + { + string workspacePath = GetWorkspaceRoot(); + byte[] hash = SHA256.HashData(Encoding.UTF8.GetBytes(workspacePath)); + return $"test-{Convert.ToHexString(hash)[..8].ToLowerInvariant()}"; + } + + private static string GetWorkspaceRoot() + { + for (var directory = new DirectoryInfo(AppContext.BaseDirectory); directory is not null; directory = directory.Parent) + { + if (File.Exists(Path.Combine(directory.FullName, ".git")) || directory.EnumerateFiles("*.slnx").Any()) + return directory.FullName; + } + + return Directory.GetCurrentDirectory(); } private static async Task WaitForElasticsearchAsync(Uri elasticsearchUri) { - using var client = new HttpClient { Timeout = TimeSpan.FromSeconds(1) }; + using var client = new HttpClient + { + BaseAddress = elasticsearchUri, + Timeout = TimeSpan.FromSeconds(2) + }; var deadline = TimeProvider.System.GetUtcNow() + TimeSpan.FromSeconds(60); while (TimeProvider.System.GetUtcNow() < deadline) { try { - using var response = await client.GetAsync(elasticsearchUri); - if (response.StatusCode == HttpStatusCode.OK) + using var pingRequest = new HttpRequestMessage(HttpMethod.Head, "/"); + using var pingResponse = await client.SendAsync(pingRequest); + if (!pingResponse.IsSuccessStatusCode) + { + await Task.Delay(TimeSpan.FromMilliseconds(250)); + continue; + } + + using var healthResponse = await client.GetAsync("/_cluster/health?wait_for_status=yellow&timeout=1s"); + if (!healthResponse.IsSuccessStatusCode) + { + await Task.Delay(TimeSpan.FromMilliseconds(250)); + continue; + } + + var health = await healthResponse.Content.ReadFromJsonAsync(); + if (health?.Status is "yellow" or "green") return; } catch (HttpRequestException) @@ -84,6 +129,40 @@ private static async Task WaitForElasticsearchAsync(Uri elasticsearchUri) throw new TimeoutException("Timed out waiting for Elasticsearch test container to be ready."); } + private async Task CleanupElasticsearchSliceAsync(Uri elasticsearchUri) + { + await WaitForElasticsearchAsync(elasticsearchUri); + + using var client = new HttpClient + { + BaseAddress = elasticsearchUri, + Timeout = TimeSpan.FromSeconds(10) + }; + + foreach (string pattern in s_indexPrefixes.Select(prefix => Uri.EscapeDataString($"{AppScope}-{prefix}*"))) + { + using var listResponse = await client.GetAsync($"/_cat/indices/{pattern}?h=index&format=json&expand_wildcards=all"); + if (listResponse.StatusCode == HttpStatusCode.NotFound) + continue; + + listResponse.EnsureSuccessStatusCode(); + + string payloadJson = await listResponse.Content.ReadAsStringAsync(); + var payload = JsonSerializer.Deserialize>(payloadJson, new JsonSerializerOptions + { + PropertyNameCaseInsensitive = true + }) + ?? []; + + foreach (string indexName in payload.Select(record => record.Index).Where(name => !String.IsNullOrEmpty(name)).Distinct()) + { + using var deleteResponse = await client.DeleteAsync($"/{Uri.EscapeDataString(indexName)}?ignore_unavailable=true"); + if (deleteResponse.StatusCode != HttpStatusCode.NotFound) + deleteResponse.EnsureSuccessStatusCode(); + } + } + } + protected override void ConfigureWebHost(IWebHostBuilder builder) { builder.UseSolutionRelativeContentRoot("src/Exceptionless.Web", "*.slnx"); @@ -103,14 +182,31 @@ protected override IHostBuilder CreateHostBuilder() return Web.Program.CreateHostBuilder(config, Environments.Development); } - public override ValueTask DisposeAsync() + public override async ValueTask DisposeAsync() { - if (!_sliceReleased) + try { - s_pool.Enqueue(InstanceId); - _sliceReleased = true; + await base.DisposeAsync(); } + finally + { + if (!_sliceReleased) + { + s_pool.Enqueue(InstanceId); + _sliceReleased = true; + } + } + } + + private sealed record SharedApplicationContext(DistributedApplication Application, Uri ElasticsearchUri); - return base.DisposeAsync(); + private sealed class CatIndexRecord + { + public string Index { get; set; } = String.Empty; + } + + private sealed class ClusterHealthResponse + { + public string Status { get; set; } = String.Empty; } } diff --git a/tests/Exceptionless.Tests/Hubs/FakeHttpResponse.cs b/tests/Exceptionless.Tests/Hubs/FakeHttpResponse.cs new file mode 100644 index 0000000000..3a48a8c642 --- /dev/null +++ b/tests/Exceptionless.Tests/Hubs/FakeHttpResponse.cs @@ -0,0 +1,45 @@ +using System.IO.Pipelines; +using System.Text; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Features; + +namespace Exceptionless.Tests.Hubs; + +/// +/// A minimal fake HttpResponse for testing SSE connections. +/// Captures written data in a MemoryStream. +/// The WriteAsync extension on HttpResponse writes to Body directly, +/// so the MemoryStream captures all output. +/// +internal sealed class FakeHttpResponse : HttpResponse, IDisposable +{ + private readonly MemoryStream _body = new(); + private readonly HeaderDictionary _headers = new(); + + public override HttpContext HttpContext => null!; + public override int StatusCode { get; set; } + public override IHeaderDictionary Headers => _headers; + public override Stream Body + { + get => _body; + set { } + } + public override long? ContentLength { get; set; } + public override string? ContentType { get; set; } + public override IResponseCookies Cookies => null!; + public override bool HasStarted => true; + + /// + /// Get all data written to this response as a string. + /// + public string WrittenData => Encoding.UTF8.GetString(_body.ToArray()); + + public override void OnCompleted(Func callback, object state) { } + public override void OnStarting(Func callback, object state) { } + public override void Redirect(string location, bool permanent) { } + + public void Dispose() + { + _body.Dispose(); + } +} diff --git a/tests/Exceptionless.Tests/Hubs/PushDisconnectCleanupTests.cs b/tests/Exceptionless.Tests/Hubs/PushDisconnectCleanupTests.cs new file mode 100644 index 0000000000..3bdaf5b1f9 --- /dev/null +++ b/tests/Exceptionless.Tests/Hubs/PushDisconnectCleanupTests.cs @@ -0,0 +1,107 @@ +using System.Security.Claims; +using Exceptionless.Core.Extensions; +using Exceptionless.Core.Models; +using Exceptionless.Core.Utility; +using Exceptionless.Web.Hubs; +using Microsoft.Extensions.Logging.Abstractions; +using Xunit; + +namespace Exceptionless.Tests.Hubs; + +public sealed class PushDisconnectCleanupTests +{ + [Fact] + public async Task GetOrganizationIdsAsync_UserAddedToOrganizationAfterConnect_IncludesCurrentMemberships() + { + // Arrange + var user = CreateUser("user1", "org-a"); + ClaimsPrincipal principal = new(user.ToIdentity()); + var connectionMapping = new ConnectionMapping(); + const string connectionId = "push-connection"; + await connectionMapping.ConnectionGroupAddAsync(connectionId, "org-a"); + + var currentUser = CreateUser("user1", "org-a", "org-b"); + + // Act + var organizationIds = await PushDisconnectCleanup.GetOrganizationIdsAsync(principal, connectionId, connectionMapping, () => Task.FromResult(currentUser), NullLogger.Instance); + + // Assert + Assert.Contains("org-a", organizationIds); + Assert.Contains("org-b", organizationIds); + } + + [Fact] + public async Task GetOrganizationIdsAsync_UserAddedToOrganizationAfterConnect_CleansUpAddedOrganizationMapping() + { + // Arrange + var user = CreateUser("user1", "org-a"); + ClaimsPrincipal principal = new(user.ToIdentity()); + var connectionMapping = new ConnectionMapping(); + const string connectionId = "push-connection"; + await connectionMapping.GroupAddAsync("org-a", connectionId); + await connectionMapping.ConnectionGroupAddAsync(connectionId, "org-a"); + await connectionMapping.GroupAddAsync("org-b", connectionId); + await connectionMapping.ConnectionGroupAddAsync(connectionId, "org-b"); + await connectionMapping.UserIdAddAsync(user.Id, connectionId); + + var currentUser = CreateUser("user1", "org-a", "org-b"); + + // Act + foreach (string organizationId in await PushDisconnectCleanup.GetOrganizationIdsAsync(principal, connectionId, connectionMapping, () => Task.FromResult(currentUser), NullLogger.Instance)) + { + await connectionMapping.GroupRemoveAsync(organizationId, connectionId); + await connectionMapping.ConnectionGroupRemoveAsync(connectionId, organizationId); + } + + await connectionMapping.UserIdRemoveAsync(user.Id, connectionId); + + // Assert + Assert.DoesNotContain(connectionId, await connectionMapping.GetGroupConnectionsAsync("org-a")); + Assert.DoesNotContain(connectionId, await connectionMapping.GetGroupConnectionsAsync("org-b")); + Assert.Empty(await connectionMapping.GetConnectionGroupsAsync(connectionId)); + Assert.DoesNotContain(connectionId, await connectionMapping.GetUserIdConnectionsAsync(user.Id)); + } + + [Fact] + public async Task GetOrganizationIdsAsync_WhenRepositoryLookupFails_FallsBackToTrackedConnectionGroups() + { + // Arrange + var user = CreateUser("user1", "org-a"); + ClaimsPrincipal principal = new(user.ToIdentity()); + var connectionMapping = new ConnectionMapping(); + const string connectionId = "push-connection"; + await connectionMapping.GroupAddAsync("org-a", connectionId); + await connectionMapping.ConnectionGroupAddAsync(connectionId, "org-a"); + await connectionMapping.GroupAddAsync("org-b", connectionId); + await connectionMapping.ConnectionGroupAddAsync(connectionId, "org-b"); + await connectionMapping.UserIdAddAsync(user.Id, connectionId); + + // Act + foreach (string organizationId in await PushDisconnectCleanup.GetOrganizationIdsAsync(principal, connectionId, connectionMapping, () => throw new InvalidOperationException("boom"), NullLogger.Instance)) + { + await connectionMapping.GroupRemoveAsync(organizationId, connectionId); + await connectionMapping.ConnectionGroupRemoveAsync(connectionId, organizationId); + } + + await connectionMapping.UserIdRemoveAsync(user.Id, connectionId); + + // Assert + Assert.DoesNotContain(connectionId, await connectionMapping.GetGroupConnectionsAsync("org-a")); + Assert.DoesNotContain(connectionId, await connectionMapping.GetGroupConnectionsAsync("org-b")); + Assert.Empty(await connectionMapping.GetConnectionGroupsAsync(connectionId)); + Assert.DoesNotContain(connectionId, await connectionMapping.GetUserIdConnectionsAsync(user.Id)); + } + + private static User CreateUser(string userId, params string[] organizationIds) + { + var user = new User { + Id = userId, + EmailAddress = $"{userId}@example.com" + }; + + foreach (string organizationId in organizationIds) + user.OrganizationIds.Add(organizationId); + + return user; + } +} diff --git a/tests/Exceptionless.Tests/Hubs/SseIntegrationTests.cs b/tests/Exceptionless.Tests/Hubs/SseIntegrationTests.cs new file mode 100644 index 0000000000..55f73f9bdc --- /dev/null +++ b/tests/Exceptionless.Tests/Hubs/SseIntegrationTests.cs @@ -0,0 +1,214 @@ +using System.Net; +using System.Text; +using Exceptionless.Core.Messaging.Models; +using Exceptionless.Core.Models; +using Exceptionless.Core.Utility; +using Exceptionless.Tests.Extensions; +using Exceptionless.Tests.Utility; +using Exceptionless.Web.Hubs; +using Exceptionless.Web.Models; +using Foundatio.Messaging; +using Foundatio.Repositories.Models; +using Xunit; + +namespace Exceptionless.Tests.Hubs; + +/// +/// Integration tests for the SSE endpoint (/api/v2/push). +/// These test the full HTTP pipeline including auth, middleware, and message delivery. +/// +public sealed class SseIntegrationTests : IntegrationTestsBase +{ + private readonly IMessagePublisher _messagePublisher; + + public SseIntegrationTests(ITestOutputHelper output, AppWebHostFactory factory) + : base(output, factory) + { + _messagePublisher = GetService(); + } + + protected override async Task ResetDataAsync() + { + await base.ResetDataAsync(); + await GetService().CreateDataAsync(); + } + + [Fact] + public async Task ConnectWithValidToken_ReturnsEventStream() + { + var token = await CreateTokenAsync(); + + using var client = _server.CreateClient(); + using var request = new HttpRequestMessage(HttpMethod.Get, "/api/v2/push"); + request.Headers.Add("Accept", "text/event-stream"); + request.Headers.Add("Authorization", $"Bearer {token}"); + + using var cts = CancellationTokenSource.CreateLinkedTokenSource(TestContext.Current.CancellationToken); + cts.CancelAfter(TimeSpan.FromSeconds(5)); + using var response = await client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, cts.Token); + + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + Assert.Equal("text/event-stream", response.Content.Headers.ContentType?.MediaType); + } + + [Fact] + public async Task ConnectWithoutAuth_Returns401() + { + using var client = _server.CreateClient(); + using var request = new HttpRequestMessage(HttpMethod.Get, "/api/v2/push"); + request.Headers.Add("Accept", "text/event-stream"); + + using var cts = CancellationTokenSource.CreateLinkedTokenSource(TestContext.Current.CancellationToken); + cts.CancelAfter(TimeSpan.FromSeconds(5)); + using var response = await client.SendAsync(request, cts.Token); + + Assert.Equal(HttpStatusCode.Unauthorized, response.StatusCode); + } + + [Fact] + public async Task ConnectWithInvalidToken_Returns401() + { + using var client = _server.CreateClient(); + using var request = new HttpRequestMessage(HttpMethod.Get, "/api/v2/push"); + request.Headers.Add("Accept", "text/event-stream"); + request.Headers.Add("Authorization", "Bearer invalid-token-xyz"); + + using var cts = CancellationTokenSource.CreateLinkedTokenSource(TestContext.Current.CancellationToken); + cts.CancelAfter(TimeSpan.FromSeconds(5)); + using var response = await client.SendAsync(request, cts.Token); + + Assert.Equal(HttpStatusCode.Unauthorized, response.StatusCode); + } + + [Fact] + public async Task ConnectWithAccessTokenQueryParam_Succeeds() + { + var token = await CreateTokenAsync(); + + using var client = _server.CreateClient(); + using var request = new HttpRequestMessage(HttpMethod.Get, $"/api/v2/push?access_token={token}"); + request.Headers.Add("Accept", "text/event-stream"); + + using var cts = CancellationTokenSource.CreateLinkedTokenSource(TestContext.Current.CancellationToken); + cts.CancelAfter(TimeSpan.FromSeconds(5)); + using var response = await client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, cts.Token); + + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + } + + [Fact] + public async Task ConnectedClient_ReceivesEntityChangedMessage() + { + var token = await CreateTokenAsync(); + var orgId = SampleDataService.TEST_ORG_ID; + + using var client = _server.CreateClient(); + using var request = new HttpRequestMessage(HttpMethod.Get, "/api/v2/push"); + request.Headers.Add("Accept", "text/event-stream"); + request.Headers.Add("Authorization", $"Bearer {token}"); + + using var cts = CancellationTokenSource.CreateLinkedTokenSource(TestContext.Current.CancellationToken); + cts.CancelAfter(TimeSpan.FromSeconds(10)); + using var response = await client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, cts.Token); + + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + + var stream = await response.Content.ReadAsStreamAsync(cts.Token); + using var reader = new StreamReader(stream, Encoding.UTF8); + + // Read the initial "Connected" event + string? connectedEvent = await ReadSseEventAsync(reader, cts.Token); + Assert.NotNull(connectedEvent); + Assert.Contains("Connected", connectedEvent); + + // Publish an EntityChanged message to the organization + var entityChanged = new EntityChanged + { + Id = "stack-123", + Type = "Stack", + ChangeType = ChangeType.Saved + }; + entityChanged.Data[ExtendedEntityChanged.KnownKeys.OrganizationId] = orgId; +#pragma warning disable xUnit1051 + await _messagePublisher.PublishAsync(entityChanged); +#pragma warning restore xUnit1051 + + // Wait for and read the message + string? receivedEvent = await ReadSseEventAsync(reader, cts.Token); + Assert.NotNull(receivedEvent); + Assert.Contains("StackChanged", receivedEvent); + Assert.Contains("stack-123", receivedEvent); + } + + [Fact] + public async Task SseEndpoint_IsExemptFromThrottling() + { + var token = await CreateTokenAsync(); + + using var client = _server.CreateClient(); + + // Make multiple SSE connection attempts - should not be throttled + for (int i = 0; i < 5; i++) + { + using var request = new HttpRequestMessage(HttpMethod.Get, "/api/v2/push"); + request.Headers.Add("Accept", "text/event-stream"); + request.Headers.Add("Authorization", $"Bearer {token}"); + + using var cts = CancellationTokenSource.CreateLinkedTokenSource(TestContext.Current.CancellationToken); + cts.CancelAfter(TimeSpan.FromSeconds(3)); + using var response = await client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, cts.Token); + + // Should never be 429 + Assert.NotEqual(HttpStatusCode.TooManyRequests, response.StatusCode); + } + } + + /// + /// Read a single SSE event (terminated by double newline) from the stream. + /// + private static async Task ReadSseEventAsync(StreamReader reader, CancellationToken ct) + { + var sb = new StringBuilder(); + int emptyLineCount = 0; + + while (!ct.IsCancellationRequested) + { + string? line = await reader.ReadLineAsync(ct); + if (line is null) + return sb.Length > 0 ? sb.ToString() : null; + + if (line.Length == 0) + { + emptyLineCount++; + if (emptyLineCount >= 1 && sb.Length > 0) + return sb.ToString(); + continue; + } + + // Skip comments (keep-alive) + if (line.StartsWith(':')) + continue; + + emptyLineCount = 0; + sb.AppendLine(line); + } + + return sb.Length > 0 ? sb.ToString() : null; + } + + private async Task CreateTokenAsync() + { + var result = await SendRequestAsAsync(r => r + .Post() + .AppendPath("auth/login") + .Content(new Login + { + Email = SampleDataService.TEST_USER_EMAIL, + Password = SampleDataService.TEST_USER_PASSWORD + }) + .StatusCodeShouldBeOk() + ); + + return result?.Token ?? throw new InvalidOperationException("Login did not return a token."); + } +} diff --git a/tests/Exceptionless.Tests/Hubs/SseTests.cs b/tests/Exceptionless.Tests/Hubs/SseTests.cs new file mode 100644 index 0000000000..f20f74b455 --- /dev/null +++ b/tests/Exceptionless.Tests/Hubs/SseTests.cs @@ -0,0 +1,683 @@ +using Exceptionless.Core; +using Exceptionless.Core.Messaging.Models; +using Exceptionless.Core.Models; +using Exceptionless.Core.Utility; +using Exceptionless.Web.Hubs; +using Foundatio.Repositories.Models; +using Foundatio.Serializer; +using System.Reflection; +using Xunit; + +namespace Exceptionless.Tests.Hubs; + +public sealed class SseConnectionManagerTests : TestWithServices +{ + public SseConnectionManagerTests(ITestOutputHelper output) : base(output) { } + + [Fact] + public void AddConnection_NewConnection_CanLookupAndEnumerate() + { + using var manager = CreateManager(); + using var response = new FakeHttpResponse(); + using var cts = new CancellationTokenSource(); + + string connectionId = "test-conn-1"; + var connection = manager.AddConnection(connectionId, response, cts.Token); + + Assert.NotNull(connection); + Assert.Same(connection, manager.GetConnectionById(connectionId)); + Assert.Equal(1, manager.ConnectionCount); + Assert.Contains(connection, manager.GetAll()); + } + + [Fact] + public async Task RemoveConnectionAsync_ExistingConnection_RemovesAndAborts() + { + using var manager = CreateManager(); + using var response = new FakeHttpResponse(); + using var cts = new CancellationTokenSource(); + + string connectionId = "test-conn-2"; + var connection = manager.AddConnection(connectionId, response, cts.Token); + + await manager.RemoveConnectionAsync(connectionId); + + Assert.Null(manager.GetConnectionById(connectionId)); + Assert.Equal(0, manager.ConnectionCount); + Assert.True(connection.ConnectionAborted.IsCancellationRequested); + } + + [Fact] + public async Task RemoveConnectionAsync_UnknownConnection_DoesNothing() + { + using var manager = CreateManager(); + + await manager.RemoveConnectionAsync("nonexistent"); + + Assert.Equal(0, manager.ConnectionCount); + } + + [Fact] + public void SendMessage_ValidConnection_EnqueuesMessage() + { + using var manager = CreateManager(); + using var response = new FakeHttpResponse(); + using var cts = new CancellationTokenSource(); + + string connectionId = "test-conn-3"; + manager.AddConnection(connectionId, response, cts.Token); + + bool sent = manager.SendMessage(connectionId, new { type = "test", message = "hello" }); + + Assert.True(sent); + } + + [Fact] + public void SendMessage_UnknownConnection_ReturnsFalse() + { + using var manager = CreateManager(); + + bool sent = manager.SendMessage("missing", new { type = "test" }); + + Assert.False(sent); + } + + [Fact] + public async Task SendMessage_AbortedConnection_ReturnsFalseAndRemoves() + { + using var manager = CreateManager(); + using var response = new FakeHttpResponse(); + using var cts = new CancellationTokenSource(); + + string connectionId = "test-conn-4"; + manager.AddConnection(connectionId, response, cts.Token); + + await cts.CancelAsync(); + + bool sent = manager.SendMessage(connectionId, new { type = "test" }); + + Assert.False(sent); + // Connection should be cleaned up + Assert.Null(manager.GetConnectionById(connectionId)); + } + + [Fact] + public void SendMessageToAll_MultipleConnections_SendsToAll() + { + using var manager = CreateManager(); + using var response1 = new FakeHttpResponse(); + using var response2 = new FakeHttpResponse(); + using var cts1 = new CancellationTokenSource(); + using var cts2 = new CancellationTokenSource(); + + manager.AddConnection("conn-1", response1, cts1.Token); + manager.AddConnection("conn-2", response2, cts2.Token); + + manager.SendMessageToAll(new { type = "broadcast" }); + + // Both connections should have received the message (enqueued) + Assert.Equal(2, manager.ConnectionCount); + } + + private SseConnectionManager CreateManager() + { + var options = new AppOptions { EnablePush = true }; + return new SseConnectionManager(options, GetService(), Log); + } +} + +/// +/// Tests for MessageBusBroker using SSE connections. +/// +public sealed class SseBrokerTests : TestWithServices +{ + private readonly MessageBusBroker _broker; + private readonly IConnectionMapping _connectionMapping; + private readonly SseConnectionManager _connectionManager; + + public SseBrokerTests(ITestOutputHelper output) : base(output) + { + _broker = GetService(); + _connectionMapping = GetService(); + _connectionManager = GetService(); + } + + [Fact] + public async Task OnEntityChangedAsync_AuthTokenRemoved_ClosesConnectionsAndClearsMapping() + { + const string userId = "test-user-id"; + const string organizationId = "test-org-id"; + using var response1 = new FakeHttpResponse(); + using var response2 = new FakeHttpResponse(); + using var unrelatedResponse = new FakeHttpResponse(); + using var cts1 = new CancellationTokenSource(); + using var cts2 = new CancellationTokenSource(); + using var ctsu = new CancellationTokenSource(); + + string connId1 = "conn-auth-1"; + string connId2 = "conn-auth-2"; + string unrelatedConnId = "conn-unrelated"; + + _connectionManager.AddConnection(connId1, response1, cts1.Token); + _connectionManager.AddConnection(connId2, response2, cts2.Token); + _connectionManager.AddConnection(unrelatedConnId, unrelatedResponse, ctsu.Token); + + try + { + await _connectionMapping.UserIdAddAsync(userId, connId1); + await _connectionMapping.UserIdAddAsync(userId, connId2); + await _connectionMapping.GroupAddAsync(organizationId, connId1); + await _connectionMapping.GroupAddAsync(organizationId, connId2); + await _connectionMapping.GroupAddAsync(organizationId, unrelatedConnId); + + var entityChanged = new EntityChanged + { + Id = "test-token-id", + Type = nameof(Token), + ChangeType = ChangeType.Removed + }; + entityChanged.Data[ExtendedEntityChanged.KnownKeys.OrganizationId] = organizationId; + entityChanged.Data[ExtendedEntityChanged.KnownKeys.UserId] = userId; + entityChanged.Data[ExtendedEntityChanged.KnownKeys.IsAuthenticationToken] = true; + + await _broker.OnEntityChangedAsync(entityChanged, CancellationToken.None); + + // Connections should be removed + Assert.Null(_connectionManager.GetConnectionById(connId1)); + Assert.Null(_connectionManager.GetConnectionById(connId2)); + Assert.NotNull(_connectionManager.GetConnectionById(unrelatedConnId)); + + // User mapping cleared + var remaining = await _connectionMapping.GetUserIdConnectionsAsync(userId); + Assert.Empty(remaining); + + // Org mapping only has unrelated connection + var orgConnections = await _connectionMapping.GetGroupConnectionsAsync(organizationId); + Assert.DoesNotContain(connId1, orgConnections); + Assert.DoesNotContain(connId2, orgConnections); + Assert.Contains(unrelatedConnId, orgConnections); + } + finally + { + await _connectionMapping.GroupRemoveAsync(organizationId, unrelatedConnId); + await _connectionManager.RemoveConnectionAsync(unrelatedConnId); + } + } + + [Fact] + public async Task OnEntityChangedAsync_NonAuthTokenRemoved_DoesNotCloseConnections() + { + const string userId = "test-user-id-2"; + using var response = new FakeHttpResponse(); + using var cts = new CancellationTokenSource(); + + string connectionId = "conn-nonauth"; + _connectionManager.AddConnection(connectionId, response, cts.Token); + + try + { + await _connectionMapping.UserIdAddAsync(userId, connectionId); + + var entityChanged = new EntityChanged + { + Id = "test-api-token-id", + Type = nameof(Token), + ChangeType = ChangeType.Removed + }; + entityChanged.Data[ExtendedEntityChanged.KnownKeys.UserId] = userId; + // IsAuthenticationToken intentionally omitted (defaults false) + + await _broker.OnEntityChangedAsync(entityChanged, CancellationToken.None); + + // Connection should NOT be closed + Assert.NotNull(_connectionManager.GetConnectionById(connectionId)); + } + finally + { + await _connectionMapping.UserIdRemoveAsync(userId, connectionId); + await _connectionManager.RemoveConnectionAsync(connectionId); + } + } + + [Fact] + public async Task OnEntityChangedAsync_OrganizationMessage_SentToGroupOnly() + { + const string orgId = "org-1"; + const string otherOrgId = "org-2"; + using var responseInOrg = new FakeHttpResponse(); + using var responseOutOrg = new FakeHttpResponse(); + using var cts1 = new CancellationTokenSource(); + using var cts2 = new CancellationTokenSource(); + + string inOrgConn = "conn-in-org"; + string outOrgConn = "conn-out-org"; + + _connectionManager.AddConnection(inOrgConn, responseInOrg, cts1.Token); + _connectionManager.AddConnection(outOrgConn, responseOutOrg, cts2.Token); + + try + { + await _connectionMapping.GroupAddAsync(orgId, inOrgConn); + await _connectionMapping.GroupAddAsync(otherOrgId, outOrgConn); + + var entityChanged = new EntityChanged + { + Id = "stack-123", + Type = "Stack", + ChangeType = ChangeType.Saved + }; + entityChanged.Data[ExtendedEntityChanged.KnownKeys.OrganizationId] = orgId; + + await _broker.OnEntityChangedAsync(entityChanged, CancellationToken.None); + + // Give write loop a moment to process + await Task.Delay(200, TestContext.Current.CancellationToken); + + // In-org connection should receive message, out-org should not + Assert.True(responseInOrg.WrittenData.Length > 0, "In-org connection should receive message"); + Assert.Equal(0, responseOutOrg.WrittenData.Length); + } + finally + { + await _connectionMapping.GroupRemoveAsync(orgId, inOrgConn); + await _connectionMapping.GroupRemoveAsync(otherOrgId, outOrgConn); + await _connectionManager.RemoveConnectionAsync(inOrgConn); + await _connectionManager.RemoveConnectionAsync(outOrgConn); + } + } + + [Fact] + public async Task OnEntityChangedAsync_UserMessage_SentToUserOnly() + { + const string userId = "user-target"; + const string otherUserId = "user-other"; + using var responseTarget = new FakeHttpResponse(); + using var responseOther = new FakeHttpResponse(); + using var cts1 = new CancellationTokenSource(); + using var cts2 = new CancellationTokenSource(); + + string targetConn = "conn-target-user"; + string otherConn = "conn-other-user"; + + _connectionManager.AddConnection(targetConn, responseTarget, cts1.Token); + _connectionManager.AddConnection(otherConn, responseOther, cts2.Token); + + try + { + await _connectionMapping.UserIdAddAsync(userId, targetConn); + await _connectionMapping.UserIdAddAsync(otherUserId, otherConn); + + var entityChanged = new EntityChanged + { + Id = userId, + Type = nameof(User), + ChangeType = ChangeType.Saved + }; + + await _broker.OnEntityChangedAsync(entityChanged, CancellationToken.None); + + await Task.Delay(200, TestContext.Current.CancellationToken); + + Assert.True(responseTarget.WrittenData.Length > 0, "Target user should receive message"); + Assert.Equal(0, responseOther.WrittenData.Length); + } + finally + { + await _connectionMapping.UserIdRemoveAsync(userId, targetConn); + await _connectionMapping.UserIdRemoveAsync(otherUserId, otherConn); + await _connectionManager.RemoveConnectionAsync(targetConn); + await _connectionManager.RemoveConnectionAsync(otherConn); + } + } + + [Fact] + public async Task OnUserMembershipChangedAsync_AddedAndRemoved_UpdatesForwardAndReverseMappings() + { + const string userId = "membership-user"; + const string organizationId = "membership-org"; + using var response1 = new FakeHttpResponse(); + using var response2 = new FakeHttpResponse(); + using var cts1 = new CancellationTokenSource(); + using var cts2 = new CancellationTokenSource(); + + string connectionId1 = "membership-conn-1"; + string connectionId2 = "membership-conn-2"; + + _connectionManager.AddConnection(connectionId1, response1, cts1.Token); + _connectionManager.AddConnection(connectionId2, response2, cts2.Token); + + try + { + await _connectionMapping.UserIdAddAsync(userId, connectionId1); + await _connectionMapping.UserIdAddAsync(userId, connectionId2); + + var addMessage = new UserMembershipChanged { + UserId = userId, + OrganizationId = organizationId, + ChangeType = ChangeType.Added + }; + + await InvokeUserMembershipChangedAsync(addMessage); + + var organizationConnections = await _connectionMapping.GetGroupConnectionsAsync(organizationId); + Assert.Contains(connectionId1, organizationConnections); + Assert.Contains(connectionId2, organizationConnections); + Assert.Contains(organizationId, await _connectionMapping.GetConnectionGroupsAsync(connectionId1)); + Assert.Contains(organizationId, await _connectionMapping.GetConnectionGroupsAsync(connectionId2)); + + var removeMessage = addMessage with { ChangeType = ChangeType.Removed }; + await InvokeUserMembershipChangedAsync(removeMessage); + + Assert.Empty(await _connectionMapping.GetGroupConnectionsAsync(organizationId)); + Assert.Empty(await _connectionMapping.GetConnectionGroupsAsync(connectionId1)); + Assert.Empty(await _connectionMapping.GetConnectionGroupsAsync(connectionId2)); + } + finally + { + await _connectionMapping.UserIdRemoveAsync(userId, connectionId1); + await _connectionMapping.UserIdRemoveAsync(userId, connectionId2); + await _connectionManager.RemoveConnectionAsync(connectionId1); + await _connectionManager.RemoveConnectionAsync(connectionId2); + } + } + + [Fact] + public async Task OnUserMembershipChangedAsync_Removed_SendsRefreshToRemovedUserAndRemainingOrganizationMembers() + { + const string removedUserId = "removed-user"; + const string remainingUserId = "remaining-user"; + const string organizationId = "shared-org"; + using var removedResponse = new FakeHttpResponse(); + using var remainingResponse = new FakeHttpResponse(); + using var removedCts = new CancellationTokenSource(); + using var remainingCts = new CancellationTokenSource(); + + string removedConnectionId = "removed-conn"; + string remainingConnectionId = "remaining-conn"; + + _connectionManager.AddConnection(removedConnectionId, removedResponse, removedCts.Token); + _connectionManager.AddConnection(remainingConnectionId, remainingResponse, remainingCts.Token); + + try + { + await _connectionMapping.UserIdAddAsync(removedUserId, removedConnectionId); + await _connectionMapping.UserIdAddAsync(remainingUserId, remainingConnectionId); + await _connectionMapping.GroupAddAsync(organizationId, removedConnectionId); + await _connectionMapping.GroupAddAsync(organizationId, remainingConnectionId); + await _connectionMapping.ConnectionGroupAddAsync(removedConnectionId, organizationId); + + var message = new UserMembershipChanged { + UserId = removedUserId, + OrganizationId = organizationId, + ChangeType = ChangeType.Removed + }; + + await InvokeUserMembershipChangedAsync(message); + await Task.Delay(200, TestContext.Current.CancellationToken); + + Assert.Contains(nameof(UserMembershipChanged), removedResponse.WrittenData); + Assert.Contains(nameof(UserMembershipChanged), remainingResponse.WrittenData); + Assert.DoesNotContain(removedConnectionId, await _connectionMapping.GetGroupConnectionsAsync(organizationId)); + Assert.Empty(await _connectionMapping.GetConnectionGroupsAsync(removedConnectionId)); + } + finally + { + await _connectionMapping.UserIdRemoveAsync(removedUserId, removedConnectionId); + await _connectionMapping.UserIdRemoveAsync(remainingUserId, remainingConnectionId); + await _connectionMapping.GroupRemoveAsync(organizationId, removedConnectionId); + await _connectionMapping.GroupRemoveAsync(organizationId, remainingConnectionId); + await _connectionMapping.ConnectionGroupRemoveAsync(removedConnectionId, organizationId); + await _connectionManager.RemoveConnectionAsync(removedConnectionId); + await _connectionManager.RemoveConnectionAsync(remainingConnectionId); + } + } + + [Fact] + public async Task OnEntityChangedAsync_AuthTokenRemoved_ClearsAllTrackedOrganizationMappings() + { + const string userId = "tracked-user"; + const string firstOrganizationId = "tracked-org-1"; + const string secondOrganizationId = "tracked-org-2"; + using var response = new FakeHttpResponse(); + using var cts = new CancellationTokenSource(); + + string connectionId = "tracked-conn"; + _connectionManager.AddConnection(connectionId, response, cts.Token); + + try + { + await _connectionMapping.UserIdAddAsync(userId, connectionId); + await _connectionMapping.GroupAddAsync(firstOrganizationId, connectionId); + await _connectionMapping.GroupAddAsync(secondOrganizationId, connectionId); + await _connectionMapping.ConnectionGroupAddAsync(connectionId, firstOrganizationId); + await _connectionMapping.ConnectionGroupAddAsync(connectionId, secondOrganizationId); + + var entityChanged = new EntityChanged + { + Id = "tracked-token-id", + Type = nameof(Token), + ChangeType = ChangeType.Removed + }; + entityChanged.Data[ExtendedEntityChanged.KnownKeys.OrganizationId] = firstOrganizationId; + entityChanged.Data[ExtendedEntityChanged.KnownKeys.UserId] = userId; + entityChanged.Data[ExtendedEntityChanged.KnownKeys.IsAuthenticationToken] = true; + + await _broker.OnEntityChangedAsync(entityChanged, CancellationToken.None); + + Assert.Null(_connectionManager.GetConnectionById(connectionId)); + Assert.Empty(await _connectionMapping.GetUserIdConnectionsAsync(userId)); + Assert.Empty(await _connectionMapping.GetConnectionGroupsAsync(connectionId)); + Assert.Empty(await _connectionMapping.GetGroupConnectionsAsync(firstOrganizationId)); + Assert.Empty(await _connectionMapping.GetGroupConnectionsAsync(secondOrganizationId)); + } + finally + { + await _connectionMapping.UserIdRemoveAsync(userId, connectionId); + await _connectionMapping.GroupRemoveAsync(firstOrganizationId, connectionId); + await _connectionMapping.GroupRemoveAsync(secondOrganizationId, connectionId); + await _connectionMapping.ConnectionGroupRemoveAsync(connectionId, firstOrganizationId); + await _connectionMapping.ConnectionGroupRemoveAsync(connectionId, secondOrganizationId); + } + } + + private Task InvokeUserMembershipChangedAsync(UserMembershipChanged message) + { + var method = typeof(MessageBusBroker).GetMethod("OnUserMembershipChangedAsync", BindingFlags.Instance | BindingFlags.NonPublic); + Assert.NotNull(method); + return Assert.IsAssignableFrom(method!.Invoke(_broker, [message, CancellationToken.None])); + } +} + +/// +/// Tests for the deduplication behavior of SseConnection. +/// Validates that identical messages queued in quick succession are coalesced. +/// +public sealed class SseDeduplicationTests : TestWithServices +{ + public SseDeduplicationTests(ITestOutputHelper output) : base(output) { } + + [Fact] + public async Task DuplicateMessages_AreDeduped_OnlyOneQueued() + { + var queue = new SseConnection.DedupQueue(8); + var evt = new SseConnection.SseEvent { Data = "{\"type\":\"StackChanged\",\"id\":\"stack-123\",\"change_type\":1}", DedupeKey = "stack-123" }; + int dedupedCount = 0; + + for (int i = 0; i < 5; i++) + { + if (queue.TryEnqueue(evt) == SseConnection.EnqueueResult.Deduped) + dedupedCount++; + } + + using var cts = new CancellationTokenSource(); + var queued = await queue.DequeueAsync(cts.Token); + + Assert.NotNull(queued); + Assert.Equal(evt.Data, queued!.Value.Data); + Assert.Equal(4, dedupedCount); + } + + [Fact] + public async Task DifferentMessages_AreNotDeduped() + { + using var response = new FakeHttpResponse(); + using var cts = new CancellationTokenSource(); + var serializer = GetService(); + + await using var connection = new SseConnection("dedup-test-2", response, serializer, cts.Token, Log.CreateLogger()); + + // Send 3 different messages + connection.TryWrite(new { type = "StackChanged", id = "stack-1" }); + connection.TryWrite(new { type = "StackChanged", id = "stack-2" }); + connection.TryWrite(new { type = "ProjectChanged", id = "proj-1" }); + + await Task.Delay(200, TestContext.Current.CancellationToken); + connection.Abort(); + await Task.Delay(50, TestContext.Current.CancellationToken); + + string output = response.WrittenData; + int dataLineCount = output.Split("data: ").Length - 1; + Assert.Equal(3, dataLineCount); + Assert.Equal(0, connection.DedupedMessages); + } + + [Fact] + public async Task SameMessage_AfterFirstIsConsumed_IsNotDeduped() + { + using var response = new FakeHttpResponse(); + using var cts = new CancellationTokenSource(); + var serializer = GetService(); + + await using var connection = new SseConnection("dedup-test-3", response, serializer, cts.Token, Log.CreateLogger()); + + var message = new { type = "StackChanged", id = "stack-repeat" }; + + // Send first message and wait for it to be consumed + connection.TryWrite(message); + await Task.Delay(200, TestContext.Current.CancellationToken); + + // Send same message again — should NOT be deduped because first was already consumed + connection.TryWrite(message); + await Task.Delay(200, TestContext.Current.CancellationToken); + + connection.Abort(); + await Task.Delay(50, TestContext.Current.CancellationToken); + + string output = response.WrittenData; + int dataLineCount = output.Split("data: ").Length - 1; + Assert.Equal(2, dataLineCount); + Assert.Equal(0, connection.DedupedMessages); + } + + [Fact] + public async Task KeepAlive_IsNeverDeduped() + { + using var response = new FakeHttpResponse(); + using var cts = new CancellationTokenSource(); + var serializer = GetService(); + + await using var connection = new SseConnection("dedup-test-4", response, serializer, cts.Token, Log.CreateLogger()); + + // Send multiple keep-alives — none should be deduped + connection.TryWriteKeepAlive(); + connection.TryWriteKeepAlive(); + connection.TryWriteKeepAlive(); + + await Task.Delay(200, TestContext.Current.CancellationToken); + connection.Abort(); + await Task.Delay(50, TestContext.Current.CancellationToken); + + string output = response.WrittenData; + int keepAliveCount = output.Split(": keepalive").Length - 1; + Assert.Equal(3, keepAliveCount); + } + + [Fact] + public async Task Capacity_Exceeded_DropsOldest() + { + // Test the DedupQueue directly to avoid racing with the write loop + var queue = new SseConnection.DedupQueue(3); + + // Enqueue 5 items with unique keys — first 2 should be dropped + for (int i = 0; i < 5; i++) + { + queue.TryEnqueue(new SseConnection.SseEvent { Data = $"msg-{i}", DedupeKey = $"key-{i}" }); + } + + // Dequeue and verify we get the last 3 items (oldest 2 were dropped) + using var cts = new CancellationTokenSource(); + var item1 = await queue.DequeueAsync(cts.Token); + var item2 = await queue.DequeueAsync(cts.Token); + var item3 = await queue.DequeueAsync(cts.Token); + + Assert.Equal("msg-2", item1!.Value.Data); + Assert.Equal("msg-3", item2!.Value.Data); + Assert.Equal("msg-4", item3!.Value.Data); + } + + [Fact] + public async Task CriticalMessage_WhenQueueFull_DropsOldestDroppableMessageFirst() + { + var queue = new SseConnection.DedupQueue(2); + queue.TryEnqueue(new SseConnection.SseEvent { Data = "lossy-1", DedupeKey = "lossy-1", CanDrop = true }); + queue.TryEnqueue(new SseConnection.SseEvent { Data = "critical-1", CanDrop = false }); + + var result = queue.TryEnqueue(new SseConnection.SseEvent { Data = "critical-2", CanDrop = false }); + + using var cts = new CancellationTokenSource(); + var item1 = await queue.DequeueAsync(cts.Token); + var item2 = await queue.DequeueAsync(cts.Token); + + Assert.Equal(SseConnection.EnqueueResult.DroppedQueuedMessage, result); + Assert.Equal("critical-1", item1!.Value.Data); + Assert.Equal("critical-2", item2!.Value.Data); + } + + [Fact] + public async Task KeepAlive_WhenQueueFull_DoesNotEvictCriticalMessage() + { + var queue = new SseConnection.DedupQueue(1); + queue.TryEnqueue(new SseConnection.SseEvent { Data = "critical-1", CanDrop = false }); + + var result = queue.TryEnqueue(SseConnection.SseEvent.KeepAlive); + + using var cts = new CancellationTokenSource(); + var item = await queue.DequeueAsync(cts.Token); + + Assert.Equal(SseConnection.EnqueueResult.BackpressureSkipped, result); + Assert.True(item.HasValue); + Assert.Equal("critical-1", item.Value.Data); + Assert.False(item.Value.IsKeepAlive); + } + + [Fact] + public async Task TryWriteKeepAlive_WhenQueueIsBackpressured_ReturnsTrue() + { + using var response = new FakeHttpResponse(); + using var cts = new CancellationTokenSource(); + var serializer = GetService(); + + await using var connection = new SseConnection("dedup-test-6", response, serializer, cts.Token, Log.CreateLogger(), capacity: 0); + + Assert.True(connection.TryWriteKeepAlive()); + } + + [Fact] + public async Task TryWrite_WhenQueueCompleted_ReturnsFalse() + { + using var response = new FakeHttpResponse(); + using var cts = new CancellationTokenSource(); + var serializer = GetService(); + + await using var connection = new SseConnection("dedup-test-5", response, serializer, cts.Token, Log.CreateLogger()); + + var queueField = typeof(SseConnection).GetField("_queue", BindingFlags.Instance | BindingFlags.NonPublic); + Assert.NotNull(queueField); + + var queue = Assert.IsType(queueField!.GetValue(connection)); + queue.Complete(); + + Assert.False(connection.TryWrite(new { type = "StackChanged", id = "stack-race" })); + Assert.False(connection.TryWriteKeepAlive()); + } +} diff --git a/tests/Exceptionless.Tests/Hubs/TestWebSocket.cs b/tests/Exceptionless.Tests/Hubs/TestWebSocket.cs index c8343c7b3a..54aabbe351 100644 --- a/tests/Exceptionless.Tests/Hubs/TestWebSocket.cs +++ b/tests/Exceptionless.Tests/Hubs/TestWebSocket.cs @@ -6,6 +6,7 @@ namespace Exceptionless.Tests.Hubs; internal sealed class TestWebSocket : WebSocket { private WebSocketState _state; + private int _closeCount; public TestWebSocket(WebSocketState state = WebSocketState.Open) { @@ -13,7 +14,6 @@ public TestWebSocket(WebSocketState state = WebSocketState.Open) } public int CloseCount => _closeCount; - private int _closeCount; public List SentMessages { get; } = []; public override WebSocketCloseStatus? CloseStatus { get; } = WebSocketCloseStatus.NormalClosure; public override string? CloseStatusDescription { get; } = "Closed"; @@ -47,7 +47,7 @@ public override Task ReceiveAsync(ArraySegment buf public override Task SendAsync(ArraySegment buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken) { - SentMessages.Add(Encoding.ASCII.GetString(buffer.Array!, buffer.Offset, buffer.Count)); + SentMessages.Add(Encoding.UTF8.GetString(buffer.Array!, buffer.Offset, buffer.Count)); return Task.CompletedTask; } } diff --git a/tests/Exceptionless.Tests/Hubs/WebSocketCompatibilityTests.cs b/tests/Exceptionless.Tests/Hubs/WebSocketCompatibilityTests.cs new file mode 100644 index 0000000000..09045c5259 --- /dev/null +++ b/tests/Exceptionless.Tests/Hubs/WebSocketCompatibilityTests.cs @@ -0,0 +1,156 @@ +using System.Net.WebSockets; +using Exceptionless.Core; +using Exceptionless.Core.Messaging.Models; +using Exceptionless.Core.Models; +using Exceptionless.Core.Utility; +using Exceptionless.Web.Hubs; +using Foundatio.Repositories.Models; +using Foundatio.Serializer; +using Xunit; + +namespace Exceptionless.Tests.Hubs; + +public sealed class WebSocketConnectionCompatibilityTests : TestWithServices +{ + public WebSocketConnectionCompatibilityTests(ITestOutputHelper output) : base(output) { } + + [Fact] + public void AddConnection_NewSocket_CanLookupAndEnumerateConnection() + { + using var manager = CreateManager(); + var socket = new TestWebSocket(); + + string connectionId = manager.AddConnection(socket); + + Assert.False(String.IsNullOrEmpty(connectionId)); + Assert.Same(socket, manager.GetConnectionById(connectionId)); + Assert.Same(socket, Assert.Single(manager.GetAll())); + } + + [Fact] + public async Task RemoveConnectionAsync_ExistingConnection_RemovesAndClosesSocket() + { + using var manager = CreateManager(); + var socket = new TestWebSocket(); + string connectionId = manager.AddConnection(socket); + + await manager.RemoveConnectionAsync(connectionId); + + Assert.Null(manager.GetConnectionById(connectionId)); + Assert.Empty(manager.GetAll()); + Assert.Equal(1, socket.CloseCount); + Assert.Equal(WebSocketState.Closed, socket.State); + } + + [Fact] + public void SendMessage_ClosedSocket_ReturnsFalseAndRemovesConnection() + { + using var manager = CreateManager(); + var socket = new TestWebSocket(WebSocketState.Closed); + string connectionId = manager.AddConnection(socket); + + bool sent = manager.SendMessage(connectionId, new { type = "test" }); + + Assert.False(sent); + Assert.Null(manager.GetConnectionById(connectionId)); + } + + private WebSocketConnectionManager CreateManager() + { + var options = new AppOptions { EnablePush = true }; + return new WebSocketConnectionManager(options, GetService(), Log); + } +} + +public sealed class PushCompatibilityBrokerTests : TestWithServices +{ + private readonly MessageBusBroker _broker; + private readonly IConnectionMapping _connectionMapping; + private readonly SseConnectionManager _sseConnectionManager; + private readonly WebSocketConnectionManager _webSocketConnectionManager; + + public PushCompatibilityBrokerTests(ITestOutputHelper output) : base(output) + { + _broker = GetService(); + _connectionMapping = GetService(); + _sseConnectionManager = GetService(); + _webSocketConnectionManager = GetService(); + } + + [Fact] + public async Task OnEntityChangedAsync_FansOutToSseAndWebSocketConnections() + { + const string organizationId = "compat-org"; + using var response = new FakeHttpResponse(); + using var cts = new CancellationTokenSource(); + var socket = new TestWebSocket(); + + string sseConnectionId = "compat-sse"; + string webSocketConnectionId = _webSocketConnectionManager.AddConnection(socket); + _sseConnectionManager.AddConnection(sseConnectionId, response, cts.Token); + + try + { + await _connectionMapping.GroupAddAsync(organizationId, sseConnectionId); + await _connectionMapping.GroupAddAsync(organizationId, webSocketConnectionId); + + var entityChanged = new EntityChanged + { + Id = "stack-compat", + Type = "Stack", + ChangeType = ChangeType.Saved + }; + entityChanged.Data[ExtendedEntityChanged.KnownKeys.OrganizationId] = organizationId; + + await _broker.OnEntityChangedAsync(entityChanged, CancellationToken.None); + await Task.Delay(200, TestContext.Current.CancellationToken); + + Assert.Contains("StackChanged", response.WrittenData); + Assert.Single(socket.SentMessages); + Assert.Contains("StackChanged", socket.SentMessages[0]); + } + finally + { + await _connectionMapping.GroupRemoveAsync(organizationId, sseConnectionId); + await _connectionMapping.GroupRemoveAsync(organizationId, webSocketConnectionId); + await _sseConnectionManager.RemoveConnectionAsync(sseConnectionId); + await _webSocketConnectionManager.RemoveConnectionAsync(webSocketConnectionId); + } + } + + [Fact] + public async Task OnEntityChangedAsync_AuthTokenRemoved_ClosesWebSocketConnectionsAndClearsMapping() + { + const string userId = "compat-user"; + const string organizationId = "compat-org"; + var socket = new TestWebSocket(); + string connectionId = _webSocketConnectionManager.AddConnection(socket); + + try + { + await _connectionMapping.UserIdAddAsync(userId, connectionId); + await _connectionMapping.GroupAddAsync(organizationId, connectionId); + + var entityChanged = new EntityChanged + { + Id = "compat-token", + Type = nameof(Token), + ChangeType = ChangeType.Removed + }; + entityChanged.Data[ExtendedEntityChanged.KnownKeys.OrganizationId] = organizationId; + entityChanged.Data[ExtendedEntityChanged.KnownKeys.UserId] = userId; + entityChanged.Data[ExtendedEntityChanged.KnownKeys.IsAuthenticationToken] = true; + + await _broker.OnEntityChangedAsync(entityChanged, CancellationToken.None); + + Assert.Null(_webSocketConnectionManager.GetConnectionById(connectionId)); + Assert.Equal(1, socket.CloseCount); + Assert.Empty(await _connectionMapping.GetUserIdConnectionsAsync(userId)); + } + finally + { + await _connectionMapping.GroupRemoveAsync(organizationId, connectionId); + await _connectionMapping.UserIdRemoveAsync(userId, connectionId); + } + } +} diff --git a/tests/Exceptionless.Tests/Hubs/WebSocketConnectionManagerTests.cs b/tests/Exceptionless.Tests/Hubs/WebSocketConnectionManagerTests.cs index 9018916fdc..cb84e7e22d 100644 --- a/tests/Exceptionless.Tests/Hubs/WebSocketConnectionManagerTests.cs +++ b/tests/Exceptionless.Tests/Hubs/WebSocketConnectionManagerTests.cs @@ -1,4 +1,5 @@ using System.Net.WebSockets; +using System.Reflection; using System.Text; using Exceptionless.Core; using Exceptionless.Web.Hubs; @@ -14,14 +15,11 @@ public WebSocketConnectionManagerTests(ITestOutputHelper output) : base(output) [Fact] public void AddWebSocket_NewSocket_CanLookupAndEnumerateConnection() { - // Arrange using var manager = CreateManager(); var socket = new TestWebSocket(); - // Act string connectionId = manager.AddWebSocket(socket); - // Assert Assert.False(String.IsNullOrEmpty(connectionId)); Assert.Same(socket, manager.GetWebSocketById(connectionId)); Assert.Equal(connectionId, manager.GetConnectionId(socket)); @@ -31,15 +29,12 @@ public void AddWebSocket_NewSocket_CanLookupAndEnumerateConnection() [Fact] public async Task RemoveWebSocketAsync_ExistingConnection_RemovesAndClosesSocket() { - // Arrange using var manager = CreateManager(); var socket = new TestWebSocket(); string connectionId = manager.AddWebSocket(socket); - // Act await manager.RemoveWebSocketAsync(connectionId); - // Assert Assert.Null(manager.GetWebSocketById(connectionId)); Assert.Empty(manager.GetAll()); Assert.Equal(1, socket.CloseCount); @@ -49,15 +44,12 @@ public async Task RemoveWebSocketAsync_ExistingConnection_RemovesAndClosesSocket [Fact] public async Task RemoveWebSocketAsync_ClosedSocket_RemovesWithoutClosingAgain() { - // Arrange using var manager = CreateManager(); var socket = new TestWebSocket(WebSocketState.Closed); string connectionId = manager.AddWebSocket(socket); - // Act await manager.RemoveWebSocketAsync(connectionId); - // Assert Assert.Null(manager.GetWebSocketById(connectionId)); Assert.Empty(manager.GetAll()); Assert.Equal(0, socket.CloseCount); @@ -66,34 +58,126 @@ public async Task RemoveWebSocketAsync_ClosedSocket_RemovesWithoutClosingAgain() [Fact] public async Task RemoveWebSocketAsync_UnknownConnection_DoesNothing() { - // Arrange using var manager = CreateManager(); - // Act await manager.RemoveWebSocketAsync("missing"); - // Assert Assert.Empty(manager.GetAll()); } [Fact] public async Task SendMessageToAllAsync_ClosedSockets_DoesNotSend() { - // Arrange using var manager = CreateManager(); var socket = new TestWebSocket(WebSocketState.Closed); manager.AddWebSocket(socket); - // Act await manager.SendMessageToAllAsync(new { type = "test" }); - // Assert Assert.Empty(socket.SentMessages); } + [Fact] + public async Task SendMessage_ConcurrentKeepAlive_DoesNotOverlapSocketSends() + { + using var manager = CreateManager(); + var socket = new ConcurrentSendDetectingWebSocket(); + string connectionId = manager.AddWebSocket(socket); + + Assert.True(manager.SendMessage(connectionId, new { type = "test" })); + await socket.FirstSendStarted.WaitAsync(TestContext.Current.CancellationToken); + + var sendKeepAlive = typeof(WebSocketConnectionManager).GetMethod("SendKeepAlive", BindingFlags.Instance | BindingFlags.NonPublic); + Assert.NotNull(sendKeepAlive); + sendKeepAlive!.Invoke(manager, [null]); + + await Task.Delay(100, TestContext.Current.CancellationToken); + Assert.Equal(0, socket.ConcurrentSendAttempts); + + socket.ReleaseFirstSend(); + await Task.Delay(100, TestContext.Current.CancellationToken); + + Assert.Equal(0, socket.ConcurrentSendAttempts); + Assert.Equal(2, socket.SentMessages.Count); + Assert.Contains("\"type\":\"test\"", socket.SentMessages[0]); + Assert.Equal("{}", socket.SentMessages[1]); + } + private WebSocketConnectionManager CreateManager() { - var options = new AppOptions { EnableWebSockets = false }; + var options = new AppOptions { EnablePush = false }; return new WebSocketConnectionManager(options, GetService(), Log); } + + private sealed class ConcurrentSendDetectingWebSocket : WebSocket + { + private readonly TaskCompletionSource _firstSendStarted = new(TaskCreationOptions.RunContinuationsAsynchronously); + private readonly TaskCompletionSource _releaseFirstSend = new(TaskCreationOptions.RunContinuationsAsynchronously); + private int _activeSendCount; + private int _concurrentSendAttempts; + private int _sendCount; + private WebSocketState _state = WebSocketState.Open; + + public Task FirstSendStarted => _firstSendStarted.Task; + public int ConcurrentSendAttempts => _concurrentSendAttempts; + public List SentMessages { get; } = []; + public override WebSocketCloseStatus? CloseStatus => WebSocketCloseStatus.NormalClosure; + public override string? CloseStatusDescription => "Closed"; + public override string? SubProtocol => null; + public override WebSocketState State => _state; + + public void ReleaseFirstSend() + { + _releaseFirstSend.TrySetResult(true); + } + + public override void Abort() + { + _state = WebSocketState.Aborted; + } + + public override Task CloseAsync(WebSocketCloseStatus closeStatus, string? statusDescription, CancellationToken cancellationToken) + { + _state = WebSocketState.Closed; + _releaseFirstSend.TrySetResult(true); + return Task.CompletedTask; + } + + public override Task CloseOutputAsync(WebSocketCloseStatus closeStatus, string? statusDescription, CancellationToken cancellationToken) + { + _state = WebSocketState.CloseSent; + return Task.CompletedTask; + } + + public override void Dispose() { } + + public override Task ReceiveAsync(ArraySegment buffer, CancellationToken cancellationToken) + { + return Task.FromResult(new WebSocketReceiveResult(0, WebSocketMessageType.Text, true)); + } + + public override async Task SendAsync(ArraySegment buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken) + { + if (Interlocked.Increment(ref _activeSendCount) != 1) + { + Interlocked.Decrement(ref _activeSendCount); + Interlocked.Increment(ref _concurrentSendAttempts); + throw new InvalidOperationException("Concurrent sends are not allowed"); + } + + try + { + SentMessages.Add(Encoding.UTF8.GetString(buffer.Array!, buffer.Offset, buffer.Count)); + if (Interlocked.Increment(ref _sendCount) == 1) + { + _firstSendStarted.TrySetResult(true); + await _releaseFirstSend.Task.WaitAsync(cancellationToken); + } + } + finally + { + Interlocked.Decrement(ref _activeSendCount); + } + } + } } diff --git a/tests/Exceptionless.Tests/Hubs/WebSocketTests.cs b/tests/Exceptionless.Tests/Hubs/WebSocketTests.cs index a34a33ddc7..1387058751 100644 --- a/tests/Exceptionless.Tests/Hubs/WebSocketTests.cs +++ b/tests/Exceptionless.Tests/Hubs/WebSocketTests.cs @@ -8,9 +8,9 @@ namespace Exceptionless.Tests.Hubs; /// -/// Tests for WebSocket behavior. Calls +/// Tests for WebSocket behavior. Calls /// directly so they do not depend on -/// message bus wiring or EnableWebSockets in test host configuration. +/// message bus wiring or EnablePush in test host configuration. /// public sealed class WebSocketTests : TestWithServices { @@ -28,7 +28,6 @@ public WebSocketTests(ITestOutputHelper output) : base(output) [Fact] public async Task OnEntityChangedAsync_AuthTokenRemoved_ClosesWebSocketsAndClearsUserMapping() { - // Arrange const string userId = "test-user-id"; const string organizationId = "test-organization-id"; var socket1 = new TestWebSocket(); @@ -57,10 +56,8 @@ public async Task OnEntityChangedAsync_AuthTokenRemoved_ClosesWebSocketsAndClear entityChanged.Data[ExtendedEntityChanged.KnownKeys.UserId] = userId; entityChanged.Data[ExtendedEntityChanged.KnownKeys.IsAuthenticationToken] = true; - // Act — call the broker directly; no message bus or EnableWebSockets dependency await _broker.OnEntityChangedAsync(entityChanged, CancellationToken.None); - // Assert – sockets closed and removed from manager Assert.Null(_connectionManager.GetWebSocketById(connectionId1)); Assert.Null(_connectionManager.GetWebSocketById(connectionId2)); Assert.Same(unrelatedSocket, _connectionManager.GetWebSocketById(unrelatedConnectionId)); @@ -69,7 +66,6 @@ public async Task OnEntityChangedAsync_AuthTokenRemoved_ClosesWebSocketsAndClear Assert.Equal(1, socket2.CloseCount); Assert.Equal(0, unrelatedSocket.CloseCount); - // Assert – user-id mapping removed by broker var remaining = await _connectionMapping.GetUserIdConnectionsAsync(userId); Assert.Empty(remaining); var organizationConnections = await _connectionMapping.GetGroupConnectionsAsync(organizationId); @@ -87,7 +83,6 @@ public async Task OnEntityChangedAsync_AuthTokenRemoved_ClosesWebSocketsAndClear [Fact] public async Task OnEntityChangedAsync_NonAuthTokenRemoved_DoesNotCloseWebSockets() { - // Arrange const string userId = "test-user-id-2"; var socket = new TestWebSocket(); string connectionId = _connectionManager.AddWebSocket(socket); @@ -103,12 +98,9 @@ public async Task OnEntityChangedAsync_NonAuthTokenRemoved_DoesNotCloseWebSocket ChangeType = ChangeType.Removed }; entityChanged.Data[ExtendedEntityChanged.KnownKeys.UserId] = userId; - // IsAuthenticationToken intentionally omitted (defaults false) - // Act await _broker.OnEntityChangedAsync(entityChanged, CancellationToken.None); - // Assert – socket should NOT be closed for a non-auth token removal Assert.Equal(0, socket.CloseCount); Assert.Same(socket, _connectionManager.GetWebSocketById(connectionId)); } diff --git a/tests/Exceptionless.Tests/appsettings.yml b/tests/Exceptionless.Tests/appsettings.yml index 5bdb00cb6a..f4c6494257 100644 --- a/tests/Exceptionless.Tests/appsettings.yml +++ b/tests/Exceptionless.Tests/appsettings.yml @@ -20,7 +20,7 @@ EnableDailySummary: false # Runs the jobs in the current website process RunJobsInProcess: false -EnableWebSockets: false +EnablePush: true Serilog: MinimumLevel: Warning diff --git a/tests/http/push.http b/tests/http/push.http new file mode 100644 index 0000000000..db874cbf25 --- /dev/null +++ b/tests/http/push.http @@ -0,0 +1,29 @@ +@url = http://localhost:7110 +@apiUrl = {{url}}/api/v2 +@email = admin@exceptionless.test +@password = tester + +### login to test account +# @name login +POST {{apiUrl}}/auth/login +Content-Type: application/json + +{ + "email": "{{email}}", + "password": "{{password}}" +} + +### + +@token = {{login.response.body.$.token}} + +### SSE push via bearer token +# This request intentionally stays open. Cancel it manually after verifying headers/events. +GET {{apiUrl}}/push +Accept: text/event-stream +Authorization: Bearer {{token}} + +### SSE push via query-string token +# This request intentionally stays open. Cancel it manually after verifying headers/events. +GET {{apiUrl}}/push?access_token={{token}} +Accept: text/event-stream