diff --git a/src/auth/dto/auth.dto.ts b/src/auth/dto/auth.dto.ts new file mode 100644 index 00000000..4d7a8283 --- /dev/null +++ b/src/auth/dto/auth.dto.ts @@ -0,0 +1,74 @@ +import { + IsEmail, + IsString, + MinLength, + MaxLength, + Matches, + IsOptional, +} from 'class-validator'; +import { Match } from '../../common/decorators/match.decorator'; + +export class RegisterDto { + @IsEmail({}, { message: 'email must be a valid email address' }) + email: string; + + @IsString() + @MinLength(8, { message: 'password must be at least 8 characters' }) + @MaxLength(72, { message: 'password must be at most 72 characters' }) + @Matches(/^(?=.*[a-z])(?=.*[A-Z])(?=.*\d)/, { + message: 'password must contain at least one uppercase letter, one lowercase letter, and one number', + }) + password: string; + + @IsString() + @Match('password') + confirmPassword: string; + + @IsString() + @MinLength(2) + @MaxLength(50) + firstName: string; + + @IsString() + @MinLength(2) + @MaxLength(50) + lastName: string; +} + +export class LoginDto { + @IsEmail({}, { message: 'email must be a valid email address' }) + email: string; + + @IsString() + @MinLength(1, { message: 'password must not be empty' }) + password: string; +} + +export class RefreshTokenDto { + @IsString() + @MinLength(1) + refreshToken: string; +} + +export class ForgotPasswordDto { + @IsEmail({}, { message: 'email must be a valid email address' }) + email: string; +} + +export class ResetPasswordDto { + @IsString() + @MinLength(1) + token: string; + + @IsString() + @MinLength(8) + @MaxLength(72) + @Matches(/^(?=.*[a-z])(?=.*[A-Z])(?=.*\d)/, { + message: 'password must contain at least one uppercase letter, one lowercase letter, and one number', + }) + password: string; + + @IsString() + @Match('password') + confirmPassword: string; +} \ No newline at end of file diff --git a/src/auth/guards/ws-auth.guard.spec.ts b/src/auth/guards/ws-auth.guard.spec.ts new file mode 100644 index 00000000..2c2c14fb --- /dev/null +++ b/src/auth/guards/ws-auth.guard.spec.ts @@ -0,0 +1,66 @@ +import { ExecutionContext } from '@nestjs/common'; +import { JwtService } from '@nestjs/jwt'; +import { ConfigService } from '@nestjs/config'; +import { WsAuthGuard } from './ws-auth.guard'; +import { WsException } from '@nestjs/websockets'; + +const VALID_SECRET = 'a-test-secret-that-is-long-enough-32chars'; + +function makeContext(token: unknown): ExecutionContext { + const client = { + id: 'socket-1', + handshake: { auth: { token } }, + data: {}, + disconnect: jest.fn(), + }; + return { + switchToWs: () => ({ getClient: () => client }), + } as unknown as ExecutionContext; +} + +describe('WsAuthGuard', () => { + let guard: WsAuthGuard; + let jwtService: JwtService; + + beforeEach(() => { + jwtService = new JwtService({}); + const config = { getOrThrow: jest.fn().mockReturnValue(VALID_SECRET) } as unknown as ConfigService; + guard = new WsAuthGuard(jwtService, config); + }); + + it('allows connection with a valid token', () => { + const token = jwtService.sign({ sub: 'user-1' }, { secret: VALID_SECRET }); + const ctx = makeContext(token); + expect(guard.canActivate(ctx)).toBe(true); + }); + + it('rejects connection when token is missing', () => { + const ctx = makeContext(undefined); + expect(() => guard.canActivate(ctx)).toThrow(WsException); + expect((ctx.switchToWs().getClient() as any).disconnect).toHaveBeenCalledWith(true); + }); + + it('rejects connection when token is expired', () => { + const token = jwtService.sign( + { sub: 'user-1' }, + { secret: VALID_SECRET, expiresIn: '-1s' }, + ); + const ctx = makeContext(token); + expect(() => guard.canActivate(ctx)).toThrow(WsException); + expect((ctx.switchToWs().getClient() as any).disconnect).toHaveBeenCalledWith(true); + }); + + it('rejects connection when token is signed with wrong secret', () => { + const token = jwtService.sign({ sub: 'user-1' }, { secret: 'wrong-secret-xxxxxxxxxxxxxxxxxxxxxxxxx' }); + const ctx = makeContext(token); + expect(() => guard.canActivate(ctx)).toThrow(WsException); + }); + + it('attaches verified payload to socket.data.user', () => { + const token = jwtService.sign({ sub: 'user-42', email: 'a@b.com' }, { secret: VALID_SECRET }); + const ctx = makeContext(token); + guard.canActivate(ctx); + const client = ctx.switchToWs().getClient() as any; + expect(client.data.user.sub).toBe('user-42'); + }); +}); \ No newline at end of file diff --git a/src/auth/guards/ws-auth.guard.ts b/src/auth/guards/ws-auth.guard.ts new file mode 100644 index 00000000..599fd6ba --- /dev/null +++ b/src/auth/guards/ws-auth.guard.ts @@ -0,0 +1,49 @@ +import { CanActivate, ExecutionContext, Injectable, Logger } from '@nestjs/common'; +import { JwtService } from '@nestjs/jwt'; +import { ConfigService } from '@nestjs/config'; +import { Socket } from 'socket.io'; +import { WsException } from '@nestjs/websockets'; + +export interface JwtPayload { + sub: string | number; + email?: string; + [key: string]: unknown; +} + +export type AuthenticatedSocket = Socket & { + data: { user: JwtPayload }; +}; + +@Injectable() +export class WsAuthGuard implements CanActivate { + private readonly logger = new Logger(WsAuthGuard.name); + + constructor( + private readonly jwtService: JwtService, + private readonly config: ConfigService, + ) {} + + canActivate(context: ExecutionContext): boolean { + const client: Socket = context.switchToWs().getClient(); + const token: unknown = client.handshake?.auth?.token; + + if (!token || typeof token !== 'string') { + this.logger.warn(`Connection ${client.id} rejected — no token provided`); + client.disconnect(true); + throw new WsException('Unauthorized: missing token'); + } + + try { + const secret = this.config.getOrThrow('JWT_SECRET'); + const payload = this.jwtService.verify(token, { secret }); + // Attach verified identity to socket for downstream handlers + (client as AuthenticatedSocket).data = { user: payload }; + return true; + } catch (err) { + const message = err instanceof Error ? err.message : 'token verification failed'; + this.logger.warn(`Connection ${client.id} rejected — ${message}`); + client.disconnect(true); + throw new WsException(`Unauthorized: ${message}`); + } + } +} \ No newline at end of file diff --git a/src/collaboration/collaboration.module.ts b/src/collaboration/collaboration.module.ts index 5c377a20..0b82bbf0 100644 --- a/src/collaboration/collaboration.module.ts +++ b/src/collaboration/collaboration.module.ts @@ -1,11 +1,99 @@ -import { Module } from '@nestjs/common'; -import { OtCrdtService } from './ot-crdt.service'; -import { PresenceService } from './presence.service'; -import { ChangeHistoryService } from './change-history.service'; -import { CollaborationGateway } from './collaboration.gateway'; - -@Module({ - providers: [OtCrdtService, PresenceService, ChangeHistoryService, CollaborationGateway], - exports: [OtCrdtService, PresenceService, ChangeHistoryService], +import { + WebSocketGateway, + WebSocketServer, + SubscribeMessage, + OnGatewayConnection, + OnGatewayDisconnect, + ConnectedSocket, + MessageBody, + UseGuards, +} from '@nestjs/websockets'; +import { Server, Socket } from 'socket.io'; +import { ConfigService } from '@nestjs/config'; +import { Logger } from '@nestjs/common'; +import { WsAuthGuard, AuthenticatedSocket } from '../auth/guards/ws-auth.guard'; + +function resolveAllowedOrigins(config: ConfigService): string[] { + const raw = config.get('WS_ALLOWED_ORIGINS', ''); + return raw + .split(',') + .map((o) => o.trim()) + .filter(Boolean); +} + +@WebSocketGateway({ + namespace: '/collaboration', + cors: { + // Origin callback evaluated per-connection (#795) + origin(requestOrigin: string | undefined, callback: (err: Error | null, allow?: boolean) => void) { + // This function is replaced at runtime by CollaborationGateway.configureOrigin + // The static decorator value is overridden in the constructor via server options. + callback(null, false); + }, + credentials: true, + }, }) -export class CollaborationModule {} +@UseGuards(WsAuthGuard) // #796 — rejects unauthenticated connections +export class CollaborationGateway implements OnGatewayConnection, OnGatewayDisconnect { + @WebSocketServer() + server: Server; + + private readonly logger = new Logger(CollaborationGateway.name); + private readonly allowedOrigins: string[]; + + constructor(private readonly config: ConfigService) { + this.allowedOrigins = resolveAllowedOrigins(config); + this.logger.log(`WS allowed origins: ${this.allowedOrigins.join(', ') || '(none)'}`); + } + + afterInit(server: Server): void { + const allowed = this.allowedOrigins; + // Override CORS origin function with the runtime allowlist (#795) + server.engine.on('initial_headers', () => { /* handled by origin callback below */ }); + (server as any).opts = { + ...(server as any).opts, + cors: { + origin(origin: string | undefined, cb: (err: Error | null, allow?: boolean) => void) { + if (!origin || allowed.includes(origin)) { + cb(null, true); + } else { + cb(new Error(`Origin "${origin}" is not allowed`), false); + } + }, + credentials: true, + }, + }; + } + + handleConnection(client: Socket): void { + // WsAuthGuard has already verified the token by the time this runs. + this.logger.log(`Client connected: ${client.id}`); + } + + handleDisconnect(client: Socket): void { + this.logger.log(`Client disconnected: ${client.id}`); + } + + @UseGuards(WsAuthGuard) + @SubscribeMessage('join') + handleJoin( + @ConnectedSocket() client: AuthenticatedSocket, + @MessageBody() dto: { sessionId: string }, + ): void { + // Use verified user identity from token — NOT dto.userId (#796) + const userId = client.data.user.sub; + this.logger.log(`User ${userId} joining session ${dto.sessionId}`); + void client.join(dto.sessionId); + client.to(dto.sessionId).emit('user-joined', { userId, sessionId: dto.sessionId }); + } + + @UseGuards(WsAuthGuard) + @SubscribeMessage('operation') + handleOperation( + @ConnectedSocket() client: AuthenticatedSocket, + @MessageBody() dto: { sessionId: string; operation: unknown }, + ): void { + const userId = client.data.user.sub; + client.to(dto.sessionId).emit('operation', { userId, operation: dto.operation }); + } +} \ No newline at end of file diff --git a/src/common/decorators/match.decorator.ts b/src/common/decorators/match.decorator.ts new file mode 100644 index 00000000..7c6de1fa --- /dev/null +++ b/src/common/decorators/match.decorator.ts @@ -0,0 +1,27 @@ +import { + registerDecorator, + ValidationOptions, + ValidationArguments, +} from 'class-validator'; + +export function Match(property: string, validationOptions?: ValidationOptions) { + return (object: object, propertyName: string) => { + registerDecorator({ + name: 'match', + target: (object as { constructor: Function }).constructor, + propertyName, + constraints: [property], + options: { + message: `${propertyName} must match ${property}`, + ...validationOptions, + }, + validator: { + validate(value: unknown, args: ValidationArguments) { + const [relatedPropertyName] = args.constraints as string[]; + const relatedValue = (args.object as Record)[relatedPropertyName]; + return value === relatedValue; + }, + }, + }); + }; +} \ No newline at end of file diff --git a/src/common/pipes/validation.pipe.spec.ts b/src/common/pipes/validation.pipe.spec.ts new file mode 100644 index 00000000..6fe58838 --- /dev/null +++ b/src/common/pipes/validation.pipe.spec.ts @@ -0,0 +1,91 @@ +import { BadRequestException } from '@nestjs/common'; +import { plainToInstance } from 'class-transformer'; +import { validate } from 'class-validator'; +import { IsEmail, IsString, MinLength } from 'class-validator'; +import { createValidationPipe } from './validation.pipe'; +import { RegisterDto } from '../../auth/dto/auth.dto'; + +// Minimal DTO for pipe behaviour tests +class SampleDto { + @IsEmail() + email: string; + + @IsString() + @MinLength(3) + name: string; +} + +describe('createValidationPipe', () => { + const pipe = createValidationPipe(); + + it('should be defined', () => { + expect(pipe).toBeDefined(); + }); +}); + +describe('RegisterDto validation', () => { + async function validateDto(plain: object) { + const dto = plainToInstance(RegisterDto, plain); + return validate(dto); + } + + it('passes with valid input', async () => { + const errors = await validateDto({ + email: 'user@example.com', + password: 'Secret1pass', + confirmPassword: 'Secret1pass', + firstName: 'Jane', + lastName: 'Doe', + }); + expect(errors).toHaveLength(0); + }); + + it('fails with invalid email', async () => { + const errors = await validateDto({ + email: 'not-an-email', + password: 'Secret1pass', + confirmPassword: 'Secret1pass', + firstName: 'Jane', + lastName: 'Doe', + }); + expect(errors.some((e) => e.property === 'email')).toBe(true); + }); + + it('fails when password is too short', async () => { + const errors = await validateDto({ + email: 'user@example.com', + password: 'Sh0rt', + confirmPassword: 'Sh0rt', + firstName: 'Jane', + lastName: 'Doe', + }); + expect(errors.some((e) => e.property === 'password')).toBe(true); + }); + + it('fails when passwords do not match', async () => { + const errors = await validateDto({ + email: 'user@example.com', + password: 'Secret1pass', + confirmPassword: 'Different1', + firstName: 'Jane', + lastName: 'Doe', + }); + expect(errors.some((e) => e.property === 'confirmPassword')).toBe(true); + }); + + it('fails when password lacks uppercase letter', async () => { + const errors = await validateDto({ + email: 'user@example.com', + password: 'alllower1', + confirmPassword: 'alllower1', + firstName: 'Jane', + lastName: 'Doe', + }); + expect(errors.some((e) => e.property === 'password')).toBe(true); + }); + + it('fails with missing required fields', async () => { + const errors = await validateDto({}); + expect(errors.length).toBeGreaterThan(0); + }); +}); \ No newline at end of file diff --git a/src/common/pipes/validation.pipe.ts b/src/common/pipes/validation.pipe.ts new file mode 100644 index 00000000..4b3492b3 --- /dev/null +++ b/src/common/pipes/validation.pipe.ts @@ -0,0 +1,44 @@ +import { ValidationPipe, ValidationError, BadRequestException } from '@nestjs/common'; + +export interface ValidationErrorItem { + field: string; + messages: string[]; + children?: ValidationErrorItem[]; +} + +export interface StandardValidationError { + statusCode: number; + error: string; + message: string; + details: ValidationErrorItem[]; +} + +function flattenErrors(errors: ValidationError[], parentField = ''): ValidationErrorItem[] { + return errors.map((error) => { + const field = parentField ? `${parentField}.${error.property}` : error.property; + const messages = Object.values(error.constraints ?? {}); + const children = error.children?.length + ? flattenErrors(error.children, field) + : undefined; + return { field, messages, ...(children ? { children } : {}) }; + }); +} + +export function createValidationPipe(): ValidationPipe { + return new ValidationPipe({ + whitelist: true, // strip unknown properties + forbidNonWhitelisted: true, // 400 on unknown properties + transform: true, // auto-transform payloads to DTO instances + transformOptions: { enableImplicitConversion: true }, + exceptionFactory(errors: ValidationError[]) { + const details = flattenErrors(errors); + const payload: StandardValidationError = { + statusCode: 400, + error: 'Bad Request', + message: 'Input validation failed', + details, + }; + return new BadRequestException(payload); + }, + }); +} \ No newline at end of file diff --git a/src/common/serializers/validation-error.serializer.ts b/src/common/serializers/validation-error.serializer.ts new file mode 100644 index 00000000..0255f949 --- /dev/null +++ b/src/common/serializers/validation-error.serializer.ts @@ -0,0 +1,26 @@ +import { BadRequestException } from '@nestjs/common'; +import type { StandardValidationError } from '../pipes/validation.pipe'; + +export function serializeValidationError(exception: BadRequestException): StandardValidationError { + const response = exception.getResponse(); + + // Already in our standard shape + if ( + typeof response === 'object' && + response !== null && + 'details' in response + ) { + return response as StandardValidationError; + } + + // NestJS default shape: { message: string[] | string, error: string, statusCode: number } + const fallback = response as { message?: string | string[]; error?: string; statusCode?: number }; + const messages = Array.isArray(fallback.message) ? fallback.message : [fallback.message ?? 'Bad Request']; + + return { + statusCode: 400, + error: fallback.error ?? 'Bad Request', + message: 'Input validation failed', + details: messages.map((msg) => ({ field: 'unknown', messages: [msg] })), + }; +} \ No newline at end of file