diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2148a9e..7f0b2d9 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -11,7 +11,7 @@ jobs: steps: - uses: actions/checkout@v4 - - name: Select Xcode 16 + - name: Select Xcode uses: maxim-lobanov/setup-xcode@v1 with: xcode-version: latest-stable @@ -19,12 +19,30 @@ jobs: - name: Show Swift version run: swift --version - - name: Resolve dependencies - run: swift package resolve + # mlx-swift's Metal shaders + the MLX C++ sources are expensive to + # compile; cache the SwiftPM and Xcode build products across runs. + - name: Cache build + uses: actions/cache@v4 + with: + path: | + .build + ~/Library/Caches/org.swift.swiftpm + key: ${{ runner.os }}-build-${{ hashFiles('Package.swift') }} + restore-keys: ${{ runner.os }}-build- - - name: Build - run: swift build -v + # xcodebuild (not `swift build`) compiles the Metal shaders, so this + # verifies the artifact users actually run. -skipMacroValidation is + # required for the MLXHuggingFace macro plugin in non-interactive runs. + - name: Build (xcodebuild — compiles Metal shaders) + run: | + xcodebuild -scheme mlx-server \ + -destination 'platform=macOS' \ + -configuration Debug \ + -derivedDataPath .build/xcode \ + -skipMacroValidation \ + build - # TODO(phase-1): re-enable once a test target exists. - # - name: Test - # run: swift test + # The test suite is model-free (no GPU / no weights), so SwiftPM runs + # it directly and fast. + - name: Test + run: swift test diff --git a/.gitignore b/.gitignore index aed7527..721c320 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,4 @@ DerivedData/ .idea/ .vscode/ *.swp +*.profraw diff --git a/Package.swift b/Package.swift index b0d72dc..944fba7 100644 --- a/Package.swift +++ b/Package.swift @@ -18,26 +18,45 @@ let package = Package( .package(url: "https://github.com/apple/swift-log.git", from: "1.6.0"), // Metrics API (Prometheus backend wired up in Phase 2). .package(url: "https://github.com/apple/swift-metrics.git", from: "2.5.0"), - - // TODO(phase-1): add mlx-swift-lm once the upstream Package.swift either - // exposes mlx-swift as a remote URL dependency or we adopt a workspace - // / submodule strategy. mlx-swift-lm currently uses .package(path: "../mlx-swift") - // which blocks remote consumption. - // https://github.com/ekryski/mlx-swift-lm/blob/alpha/Package.swift + // MLX inference for Apple Silicon: LLMs/VLMs plus the chat-template + // tool-call parsers. Consumed remotely from the v3.32.1-alpha tag. + .package(url: "https://github.com/ekryski/mlx-swift-lm", exact: "3.32.1-alpha"), + // HuggingFace hub client + tokenizers. Required by the MLXHuggingFace + // macros that generate the model Downloader / TokenizerLoader. + .package(url: "https://github.com/huggingface/swift-transformers", from: "1.3.0"), + .package(url: "https://github.com/huggingface/swift-huggingface", from: "0.9.0"), ], targets: [ + // Thin executable: CLI parsing only. All logic lives in MLXServerKit. .executableTarget( name: "MLXServer", dependencies: [ - .product(name: "Hummingbird", package: "hummingbird"), + "MLXServerKit", .product(name: "ArgumentParser", package: "swift-argument-parser"), + ] + ), + // Library target: server, routing, inference engine, OpenAI types. + // Separated from the executable so it is unit-testable. + .target( + name: "MLXServerKit", + dependencies: [ + .product(name: "Hummingbird", package: "hummingbird"), .product(name: "Logging", package: "swift-log"), .product(name: "Metrics", package: "swift-metrics"), + .product(name: "MLXLLM", package: "mlx-swift-lm"), + .product(name: "MLXLMCommon", package: "mlx-swift-lm"), + .product(name: "MLXHuggingFace", package: "mlx-swift-lm"), + .product(name: "Tokenizers", package: "swift-transformers"), + .product(name: "HuggingFace", package: "swift-huggingface"), + ] + ), + // swift-testing (ships with the Swift 6 toolchain; needs a full Xcode). + .testTarget( + name: "MLXServerTests", + dependencies: [ + "MLXServerKit", + .product(name: "HummingbirdTesting", package: "hummingbird"), ] ), - // TODO(phase-1): re-add test target once we have real handlers to test. - // Will use swift-testing (built into Swift 6+) when targeting a full - // Xcode toolchain. Command Line Tools-only installs do not ship the - // Testing module. ] ) diff --git a/README.md b/README.md index 947fd9d..e305cfc 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,10 @@ OpenAI-compatible HTTP server for [mlx-swift-lm](https://github.com/ekryski/mlx- ## Status -**Phase 0: scaffolding only.** No model loading, no inference yet. The repository exists so that the design conversation can proceed against real code. See the [roadmap](#roadmap) for what is planned. +**Phase 1 + tool calling: working.** Loads an MLX model and serves +`/v1/chat/completions` (streaming and non-streaming), `/v1/models`, and +`/health`, with OpenAI-compatible tool calling. Validated end-to-end against +Qwen3-4B and Qwen3.6-35B-A3B (MoE). See the [roadmap](#roadmap) for what is next. ## Why this exists @@ -31,25 +34,37 @@ The end goal is to be a drop-in replacement for `llama-server` in [LLMKube](http [TheTom's MLXServer](https://github.com/ekryski/mlx-swift-lm/tree/ek/tom-eric-moe-tuning/Sources/MLXServer) (abandoned in favor of vllm-swift) was the proof-of-concept that an MLX-swift HTTP server is feasible. Several design decisions here, particularly around the slot manager and longest-prefix KV cache, are informed by his approach. The decision to rebuild rather than fork is mainly because his original used hand-rolled socket code; this repo uses [Hummingbird](https://github.com/hummingbird-project/hummingbird) for the HTTP layer. -## Build +## Build and run Requires: - macOS 14 (Sonoma) or later, Apple Silicon - Swift 6.0 or later (Xcode 16+) +`swift build` compiles the project (and is what CI runs), but **SwiftPM cannot +compile mlx-swift's Metal shaders** — a binary built that way fails at runtime +with `Failed to load the default metallib`. To run the server, build with +`xcodebuild`, which compiles and bundles the Metal library next to the binary: + ```bash -swift build -.build/debug/mlx-server --help +xcodebuild -scheme mlx-server -destination 'platform=macOS,arch=arm64' \ + -configuration Debug -derivedDataPath .build/xcode -skipMacroValidation build + +.build/xcode/Build/Products/Debug/mlx-server \ + --model /path/to/mlx-model-dir --port 8080 ``` +`--model` takes a local MLX model directory or a HuggingFace id. Other flags: +`--host`, `--port`, `--max-slots`, `--tool-call-format` (e.g. `xml_function` +for Qwen3.5 / Qwen3-Coder; auto-inferred when unset). + ## Roadmap | Phase | Scope | Status | |-------|-------|--------| -| 0 | Scaffolding, CI, `/health` endpoint, dependency wiring | In progress | -| 1 | `/v1/chat/completions` (streaming + non-streaming), `/v1/models`, single-slot model loading | Pending mlx-swift-lm Tier 1 release tag | +| 0 | Scaffolding, CI, `/health` endpoint, dependency wiring | Done | +| 1 | `/v1/chat/completions` (streaming + non-streaming), `/v1/models`, single-slot model loading | Done | | 2 | Multi-slot `SlotManager`, longest-prefix prompt cache, Prometheus `/metrics`, structured logging, graceful shutdown | | -| 3 | Tool calling, thinking-model support, vision-language models, speculative decoding knobs, `/v1/embeddings` | | +| 3 | Tool calling, thinking-model support, vision-language models, speculative decoding knobs, `/v1/embeddings` | Tool calling done | | 4 | LLMKube `runtime: mlx-server` integration | | ## License diff --git a/Sources/MLXServer/MLXServerCommand.swift b/Sources/MLXServer/MLXServerCommand.swift new file mode 100644 index 0000000..48a0699 --- /dev/null +++ b/Sources/MLXServer/MLXServerCommand.swift @@ -0,0 +1,41 @@ +import ArgumentParser +import MLXServerKit + +@main +struct MLXServerCommand: AsyncParsableCommand { + static let configuration = CommandConfiguration( + commandName: "mlx-server", + abstract: "OpenAI-compatible HTTP server for mlx-swift-lm on Apple Silicon." + ) + + @Option(name: .long, help: "Model identifier (HuggingFace ID or local directory path).") + var model: String? + + @Option(name: .long, help: "Bind address.") + var host: String = "127.0.0.1" + + @Option(name: .long, help: "Bind port.") + var port: Int = 8080 + + @Option(name: .long, help: "Maximum concurrent inference slots.") + var maxSlots: Int = 4 + + @Option(name: .long, help: "Tool-call format override (e.g. xml_function, json). Auto-inferred when unset.") + var toolCallFormat: String? + + func run() async throws { + guard let model else { + throw ValidationError("--model is required (HuggingFace ID or local directory path).") + } + + let config = ServerConfig( + model: model, + host: host, + port: port, + maxSlots: maxSlots, + toolCallFormat: toolCallFormat + ) + + try await MLXServerKit.run(config: config) + } +} diff --git a/Sources/MLXServer/main.swift b/Sources/MLXServer/main.swift deleted file mode 100644 index 6889eab..0000000 --- a/Sources/MLXServer/main.swift +++ /dev/null @@ -1,52 +0,0 @@ -import ArgumentParser -import Foundation -import Hummingbird -import Logging - -@main -struct MLXServerCommand: AsyncParsableCommand { - static let configuration = CommandConfiguration( - commandName: "mlx-server", - abstract: "OpenAI-compatible HTTP server for mlx-swift-lm on Apple Silicon." - ) - - @Option(name: .long, help: "Model identifier (HuggingFace ID or local path). Not yet wired up.") - var model: String? - - @Option(name: .long, help: "Bind address.") - var host: String = "127.0.0.1" - - @Option(name: .long, help: "Bind port.") - var port: Int = 8080 - - @Option(name: .long, help: "Maximum concurrent inference slots.") - var maxSlots: Int = 4 - - func run() async throws { - let logger = Logger(label: "mlx-server") - logger.info("mlx-server starting", metadata: [ - "host": .string(host), - "port": .stringConvertible(port), - "max_slots": .stringConvertible(maxSlots), - "model": .string(model ?? ""), - ]) - - let router = Router() - - // Phase 0: smoke-test endpoint. Real OpenAI handlers land in Phase 1. - router.get("/health") { _, _ -> String in - "ok" - } - - let app = Application( - router: router, - configuration: .init( - address: .hostname(host, port: port), - serverName: "mlx-server" - ), - logger: logger - ) - - try await app.runService() - } -} diff --git a/Sources/MLXServerKit/ChatCompletion.swift b/Sources/MLXServerKit/ChatCompletion.swift new file mode 100644 index 0000000..df3886d --- /dev/null +++ b/Sources/MLXServerKit/ChatCompletion.swift @@ -0,0 +1,146 @@ +import Foundation +import MLXLMCommon + +extension InferenceEngine { + /// Run a non-streaming chat completion and collect the full result. + func complete(_ request: ChatCompletionRequest) async throws -> ChatCompletionResponse { + let input = try await prepareInput(for: request) + let parameters = ChatMapping.resolveGenerateParameters(request) + + var text = "" + var toolCalls: [ToolCallObject] = [] + var info: GenerateCompletionInfo? + + do { + let stream = try await container.generate(input: input, parameters: parameters) + for await generation in stream { + switch generation { + case .chunk(let chunk): + text += chunk + case .toolCall(let call): + toolCalls.append(Self.toolCallObject(call, index: toolCalls.count)) + case .info(let completionInfo): + info = completionInfo + } + } + } catch { + throw ServerError.inferenceFailed(String(describing: error)) + } + + let hasToolCalls = !toolCalls.isEmpty + let message = ChatCompletionResponse.ResponseMessage( + content: hasToolCalls ? (text.isEmpty ? nil : text) : text, + toolCalls: hasToolCalls ? toolCalls : nil + ) + return ChatCompletionResponse( + id: Self.completionID(), + created: Self.unixNow(), + model: modelID, + choices: [ + .init( + index: 0, + message: message, + finishReason: hasToolCalls ? "tool_calls" : "stop" + ) + ], + usage: Usage( + promptTokens: info?.promptTokenCount ?? 0, + completionTokens: info?.generationTokenCount ?? 0 + ) + ) + } + + /// Build the prepared model input (chat template applied, tools injected) + /// from an OpenAI request. + func prepareInput(for request: ChatCompletionRequest) async throws -> sending LMInput { + let messages = ChatMapping.toChatMessages(request.messages) + let tools = ChatMapping.toToolSpecs(request.tools) + do { + return try await container.prepare(input: UserInput(chat: messages, tools: tools)) + } catch { + throw ServerError.inferenceFailed(String(describing: error)) + } + } + + /// Translate an mlx-swift-lm `ToolCall` into the OpenAI `tool_calls` shape. + static func toolCallObject(_ call: MLXLMCommon.ToolCall, index: Int) -> ToolCallObject { + ToolCallObject( + id: "call_" + UUID().uuidString, + function: FunctionCall( + name: call.function.name, + arguments: ChatMapping.argumentsJSONString(call.function.arguments) + ), + index: index + ) + } + + static func completionID() -> String { + "chatcmpl-" + UUID().uuidString + } + + static func unixNow() -> Int { + Int(Date().timeIntervalSince1970) + } +} + +/// Semantic events emitted by a streaming completion, before OpenAI framing. +enum StreamEvent: Sendable { + case textDelta(String) + case toolCall(ToolCallObject) + case finished(reason: String, usage: Usage) +} + +extension InferenceEngine { + /// Run a streaming chat completion, yielding ``StreamEvent``s as tokens + /// and tool calls arrive. `nonisolated` so the stream object is returned + /// synchronously; the backing `Task` hops onto the actor to generate. + nonisolated func stream( + _ request: ChatCompletionRequest + ) -> AsyncThrowingStream { + AsyncThrowingStream { continuation in + let task = Task { + do { + try await self.generateStream(request, into: continuation) + continuation.finish() + } catch { + continuation.finish(throwing: error) + } + } + continuation.onTermination = { _ in task.cancel() } + } + } + + private func generateStream( + _ request: ChatCompletionRequest, + into continuation: AsyncThrowingStream.Continuation + ) async throws { + let input = try await prepareInput(for: request) + let parameters = ChatMapping.resolveGenerateParameters(request) + + var toolCallCount = 0 + var info: GenerateCompletionInfo? + do { + let generations = try await container.generate(input: input, parameters: parameters) + for await generation in generations { + switch generation { + case .chunk(let chunk): + continuation.yield(.textDelta(chunk)) + case .toolCall(let call): + continuation.yield(.toolCall(Self.toolCallObject(call, index: toolCallCount))) + toolCallCount += 1 + case .info(let completionInfo): + info = completionInfo + } + } + } catch { + throw ServerError.inferenceFailed(String(describing: error)) + } + + continuation.yield( + .finished( + reason: toolCallCount > 0 ? "tool_calls" : "stop", + usage: Usage( + promptTokens: info?.promptTokenCount ?? 0, + completionTokens: info?.generationTokenCount ?? 0))) + } +} diff --git a/Sources/MLXServerKit/ChatCompletionsHandler.swift b/Sources/MLXServerKit/ChatCompletionsHandler.swift new file mode 100644 index 0000000..e3c5e02 --- /dev/null +++ b/Sources/MLXServerKit/ChatCompletionsHandler.swift @@ -0,0 +1,138 @@ +import Foundation +import HTTPTypes +import Hummingbird + +/// Handles `POST /v1/chat/completions` — streaming and non-streaming. +enum ChatCompletionsHandler { + static func handle( + request: Request, + context: BasicRequestContext, + engine: some Inferencing + ) async throws -> Response { + let completionRequest: ChatCompletionRequest + do { + completionRequest = try await request.decode( + as: ChatCompletionRequest.self, context: context) + } catch { + return errorResponse( + .badRequest("Malformed request body: \(error)"), status: .badRequest) + } + + guard !completionRequest.messages.isEmpty else { + return errorResponse(.badRequest("`messages` must not be empty"), status: .badRequest) + } + if let n = completionRequest.n, n != 1 { + return errorResponse(.unsupportedParameter("`n` must be 1"), status: .badRequest) + } + + if completionRequest.stream == true { + return streamingResponse(completionRequest, engine: engine) + } + + do { + let completion = try await engine.complete(completionRequest) + return try jsonResponse(completion) + } catch let error as ServerError { + return errorResponse(error, status: .internalServerError) + } catch { + return errorResponse( + .inferenceFailed(String(describing: error)), status: .internalServerError) + } + } + + // MARK: Streaming + + /// Build a `text/event-stream` response that relays the engine's events + /// as OpenAI `chat.completion.chunk` SSE frames. + private static func streamingResponse( + _ request: ChatCompletionRequest, + engine: some Inferencing + ) -> Response { + let id = InferenceEngine.completionID() + let created = InferenceEngine.unixNow() + let model = request.model + let events = engine.stream(request) + + let body = ResponseBody { writer in + var roleSent = false + do { + for try await event in events { + switch event { + case .textDelta(let text): + let delta = ChatCompletionChunk.Delta( + role: roleSent ? nil : "assistant", content: text, toolCalls: nil) + roleSent = true + try await writer.write( + SSE.event(chunk(id, created, model, delta, finishReason: nil))) + case .toolCall(let call): + let delta = ChatCompletionChunk.Delta( + role: roleSent ? nil : "assistant", content: nil, toolCalls: [call]) + roleSent = true + try await writer.write( + SSE.event(chunk(id, created, model, delta, finishReason: nil))) + case .finished(let reason, _): + let delta = ChatCompletionChunk.Delta( + role: nil, content: nil, toolCalls: nil) + try await writer.write( + SSE.event(chunk(id, created, model, delta, finishReason: reason))) + } + } + try await writer.write(SSE.done()) + } catch { + // Headers are already sent; surface the failure as a final + // SSE frame on a best-effort basis. + if let buffer = try? SSE.event( + OpenAIErrorResponse( + message: String(describing: error), type: "server_error")) + { + try? await writer.write(buffer) + } + } + try await writer.finish(nil) + } + + return Response( + status: .ok, + headers: [.contentType: "text/event-stream", .cacheControl: "no-cache"], + body: body) + } + + private static func chunk( + _ id: String, + _ created: Int, + _ model: String, + _ delta: ChatCompletionChunk.Delta, + finishReason: String? + ) -> ChatCompletionChunk { + ChatCompletionChunk( + id: id, + created: created, + model: model, + choices: [.init(index: 0, delta: delta, finishReason: finishReason)]) + } + + // MARK: Response helpers + + /// Encode an `Encodable` value as an `application/json` response. + static func jsonResponse( + _ value: some Encodable, + status: HTTPResponse.Status = .ok + ) throws -> Response { + var buffer = ByteBuffer() + buffer.writeBytes(try JSONEncoder().encode(value)) + return Response( + status: status, + headers: [.contentType: "application/json"], + body: .init(byteBuffer: buffer)) + } + + /// Build an OpenAI-style error response. + static func errorResponse(_ error: ServerError, status: HTTPResponse.Status) -> Response { + var buffer = ByteBuffer() + buffer.writeBytes((try? JSONEncoder().encode(OpenAIErrorResponse(error))) ?? Data()) + return Response( + status: status, + headers: [.contentType: "application/json"], + body: .init(byteBuffer: buffer)) + } +} diff --git a/Sources/MLXServerKit/ChatMapping.swift b/Sources/MLXServerKit/ChatMapping.swift new file mode 100644 index 0000000..1944c0e --- /dev/null +++ b/Sources/MLXServerKit/ChatMapping.swift @@ -0,0 +1,93 @@ +import Foundation +import MLXLMCommon + +/// Pure conversions between OpenAI wire types and mlx-swift-lm types. +/// No I/O — this is the most heavily unit-tested part of the kit. +enum ChatMapping { + static let defaultMaxTokens = 2048 + static let maxTokensCeiling = 32768 + + /// Map OpenAI messages onto mlx-swift-lm `Chat.Message`s. + /// + /// `Chat.Message` carries only role + text content. Tool-call metadata on + /// assistant messages and `tool_call_id` on tool results are not preserved + /// here; multi-turn tool-call history fidelity is refined alongside the + /// tool-calling step. + static func toChatMessages(_ messages: [ChatMessage]) -> [Chat.Message] { + messages.map { message in + let role = Chat.Message.Role(rawValue: message.role) ?? .user + return Chat.Message(role: role, content: message.content?.text ?? "") + } + } + + /// Convert OpenAI tool definitions into mlx-swift-lm `ToolSpec` dictionaries + /// (the raw schema the model's chat template renders). Returns `nil` for an + /// absent or empty tool set so the template advertises no tools. + static func toToolSpecs(_ tools: [ToolDefinition]?) -> [ToolSpec]? { + guard let tools, !tools.isEmpty else { return nil } + return tools.map { tool in + var function: [String: any Sendable] = ["name": tool.function.name] + if let description = tool.function.description { + function["description"] = description + } + if let parameters = tool.function.parameters { + function["parameters"] = jsonValueToSendable(parameters) + } + return ["type": tool.type, "function": function] + } + } + + /// Map OpenAI sampling parameters onto `GenerateParameters`, applying + /// defaults and clamping `max_tokens` to a sane ceiling. + static func resolveGenerateParameters(_ request: ChatCompletionRequest) -> GenerateParameters { + var parameters = GenerateParameters() + if let temperature = request.temperature { + parameters.temperature = Float(temperature) + } + if let topP = request.topP { + parameters.topP = Float(topP) + } + let requested = request.maxTokens ?? defaultMaxTokens + parameters.maxTokens = min(max(requested, 1), maxTokensCeiling) + return parameters + } + + /// Serialize mlx-swift-lm tool-call arguments into the JSON *string* that + /// OpenAI's `tool_calls[].function.arguments` field expects. + static func argumentsJSONString(_ arguments: [String: JSONValue]) -> String { + let object = arguments.mapValues(jsonValueToFoundation) + guard JSONSerialization.isValidJSONObject(object), + let data = try? JSONSerialization.data(withJSONObject: object), + let string = String(data: data, encoding: .utf8) + else { return "{}" } + return string + } + + /// Recursively project a `JSONValue` into `Sendable` Swift values suitable + /// as `ToolSpec` payloads. + static func jsonValueToSendable(_ value: JSONValue) -> any Sendable { + switch value { + case .null: return String?.none as any Sendable + case .bool(let bool): return bool + case .int(let int): return int + case .double(let double): return double + case .string(let string): return string + case .array(let array): return array.map(jsonValueToSendable) + case .object(let object): return object.mapValues(jsonValueToSendable) + } + } + + /// Recursively project a `JSONValue` into Foundation values for + /// `JSONSerialization`. + static func jsonValueToFoundation(_ value: JSONValue) -> Any { + switch value { + case .null: return NSNull() + case .bool(let bool): return bool + case .int(let int): return int + case .double(let double): return double + case .string(let string): return string + case .array(let array): return array.map(jsonValueToFoundation) + case .object(let object): return object.mapValues(jsonValueToFoundation) + } + } +} diff --git a/Sources/MLXServerKit/Errors.swift b/Sources/MLXServerKit/Errors.swift new file mode 100644 index 0000000..ea1ea1d --- /dev/null +++ b/Sources/MLXServerKit/Errors.swift @@ -0,0 +1,52 @@ +import Foundation + +/// Errors surfaced by the server, each carrying an OpenAI-style error `type`. +public enum ServerError: Error, Sendable { + /// Model weights or tokenizer failed to load at startup. + case modelLoadFailed(String) + /// Inference failed while generating a completion. + case inferenceFailed(String) + /// The request was malformed or semantically invalid. + case badRequest(String) + /// A request parameter is recognized but not yet supported. + case unsupportedParameter(String) + + /// Human-readable message for the error body. + public var message: String { + switch self { + case .modelLoadFailed(let detail): "Model load failed: \(detail)" + case .inferenceFailed(let detail): "Inference failed: \(detail)" + case .badRequest(let detail): detail + case .unsupportedParameter(let detail): "Unsupported parameter: \(detail)" + } + } + + /// OpenAI error `type` discriminator. + public var type: String { + switch self { + case .modelLoadFailed, .inferenceFailed: "server_error" + case .badRequest: "invalid_request_error" + case .unsupportedParameter: "invalid_request_error" + } + } +} + +/// OpenAI error envelope: `{ "error": { message, type, param?, code? } }`. +public struct OpenAIErrorResponse: Encodable, Sendable { + public struct Body: Encodable, Sendable { + public var message: String + public var type: String + public var param: String? + public var code: String? + } + + public var error: Body + + public init(message: String, type: String, param: String? = nil, code: String? = nil) { + self.error = Body(message: message, type: type, param: param, code: code) + } + + public init(_ error: ServerError) { + self.init(message: error.message, type: error.type) + } +} diff --git a/Sources/MLXServerKit/HTTP.swift b/Sources/MLXServerKit/HTTP.swift new file mode 100644 index 0000000..1bcbad7 --- /dev/null +++ b/Sources/MLXServerKit/HTTP.swift @@ -0,0 +1,6 @@ +import Hummingbird + +// Hummingbird response conformances for the OpenAI JSON types. Kept separate +// so OpenAITypes.swift stays free of the HTTP framework. +extension ModelListResponse: ResponseEncodable {} +extension ChatCompletionResponse: ResponseEncodable {} diff --git a/Sources/MLXServerKit/InferenceEngine.swift b/Sources/MLXServerKit/InferenceEngine.swift new file mode 100644 index 0000000..37a2077 --- /dev/null +++ b/Sources/MLXServerKit/InferenceEngine.swift @@ -0,0 +1,101 @@ +import Foundation +import HuggingFace +import Logging +import MLXHuggingFace +import MLXLLM +import MLXLMCommon +import Tokenizers + +/// Owns the loaded model and serializes inference through actor isolation. +/// +/// Phase 1 serves a single slot: actor isolation guarantees exactly one +/// generation runs at a time. The Phase 2 multi-slot pool replaces the single +/// `container` here without changing the public method shapes. +public actor InferenceEngine { + let container: ModelContainer + /// Model id reported by `/v1/models` and accepted in request `model` fields. + public let modelID: String + /// Resolved tool-call format, or `nil` to let the model config decide. + let toolCallFormat: ToolCallFormat? + let logger: Logger + + private init( + container: ModelContainer, + modelID: String, + toolCallFormat: ToolCallFormat?, + logger: Logger + ) { + self.container = container + self.modelID = modelID + self.toolCallFormat = toolCallFormat + self.logger = logger + } + + /// Load the model named by `config.model` — a local directory path or a + /// HuggingFace id. Blocks until weights are resident; throws + /// ``ServerError/modelLoadFailed(_:)`` so the process never comes up + /// half-ready. + public static func load(config: ServerConfig, logger: Logger) async throws -> InferenceEngine { + let format = resolveToolCallFormat(config.toolCallFormat, logger: logger) + let configuration = modelConfiguration(for: config.model, toolCallFormat: format) + let fromDirectory = directoryExists(config.model) + + logger.info( + "loading model", + metadata: [ + "model": .string(config.model), + "source": .string(fromDirectory ? "local-directory" : "huggingface-id"), + "tool_call_format": .string(format?.rawValue ?? "auto"), + ]) + + let started = ContinuousClock.now + let container: ModelContainer + do { + container = try await #huggingFaceLoadModelContainer(configuration: configuration) + } catch { + throw ServerError.modelLoadFailed(String(describing: error)) + } + logger.info( + "model loaded", + metadata: ["elapsed": .stringConvertible(started.duration(to: .now))]) + + return InferenceEngine( + container: container, + modelID: config.model, + toolCallFormat: format, + logger: logger) + } + + /// Build a ``ModelConfiguration`` from a local directory or a hub id. + private static func modelConfiguration( + for model: String, + toolCallFormat: ToolCallFormat? + ) -> ModelConfiguration { + if directoryExists(model) { + return ModelConfiguration( + directory: URL(fileURLWithPath: model, isDirectory: true), + toolCallFormat: toolCallFormat) + } + return ModelConfiguration(id: model, toolCallFormat: toolCallFormat) + } + + private static func directoryExists(_ path: String) -> Bool { + var isDirectory: ObjCBool = false + let exists = FileManager.default.fileExists(atPath: path, isDirectory: &isDirectory) + return exists && isDirectory.boolValue + } + + private static func resolveToolCallFormat( + _ raw: String?, + logger: Logger + ) -> ToolCallFormat? { + guard let raw else { return nil } + guard let format = ToolCallFormat(rawValue: raw) else { + logger.warning( + "unknown --tool-call-format; inferring from the model instead", + metadata: ["value": .string(raw)]) + return nil + } + return format + } +} diff --git a/Sources/MLXServerKit/Inferencing.swift b/Sources/MLXServerKit/Inferencing.swift new file mode 100644 index 0000000..369aba9 --- /dev/null +++ b/Sources/MLXServerKit/Inferencing.swift @@ -0,0 +1,14 @@ +/// The inference surface the HTTP layer depends on. +/// +/// `InferenceEngine` is the production implementation; tests substitute a +/// lightweight stub so routes can be exercised without loading a model. +protocol Inferencing: Sendable { + /// Model id reported by `/v1/models`. + var modelID: String { get } + /// Run a non-streaming completion. + func complete(_ request: ChatCompletionRequest) async throws -> ChatCompletionResponse + /// Run a streaming completion, yielding events as they arrive. + func stream(_ request: ChatCompletionRequest) -> AsyncThrowingStream +} + +extension InferenceEngine: Inferencing {} diff --git a/Sources/MLXServerKit/OpenAITypes.swift b/Sources/MLXServerKit/OpenAITypes.swift new file mode 100644 index 0000000..8f7be2c --- /dev/null +++ b/Sources/MLXServerKit/OpenAITypes.swift @@ -0,0 +1,291 @@ +import Foundation +import MLXLMCommon + +// OpenAI-compatible wire types for /v1/chat/completions and /v1/models. +// `JSONValue` (from MLXLMCommon) is reused as the arbitrary-JSON carrier for +// tool-call arguments and JSON-schema `parameters`. + +// MARK: - Request + +/// `POST /v1/chat/completions` request body. +public struct ChatCompletionRequest: Decodable, Sendable { + public var model: String + public var messages: [ChatMessage] + public var temperature: Double? + public var topP: Double? + public var maxTokens: Int? + public var stream: Bool? + public var tools: [ToolDefinition]? + public var toolChoice: ToolChoice? + public var n: Int? + + enum CodingKeys: String, CodingKey { + case model, messages, temperature, stream, tools, n + case topP = "top_p" + case maxTokens = "max_tokens" + case toolChoice = "tool_choice" + } +} + +/// A single message in the OpenAI `messages` array. +public struct ChatMessage: Codable, Sendable { + public var role: String + public var content: MessageContent? + public var name: String? + public var toolCalls: [ToolCallObject]? + public var toolCallId: String? + + public init( + role: String, + content: MessageContent? = nil, + name: String? = nil, + toolCalls: [ToolCallObject]? = nil, + toolCallId: String? = nil + ) { + self.role = role + self.content = content + self.name = name + self.toolCalls = toolCalls + self.toolCallId = toolCallId + } + + enum CodingKeys: String, CodingKey { + case role, content, name + case toolCalls = "tool_calls" + case toolCallId = "tool_call_id" + } +} + +/// OpenAI message `content`: a string, JSON null, or an array of content parts +/// (multimodal). Phase 1 flattens part arrays to their concatenated text. +public enum MessageContent: Codable, Sendable { + case text(String) + case null + + public var text: String? { + if case .text(let value) = self { return value } + return nil + } + + public init(from decoder: Decoder) throws { + let container = try decoder.singleValueContainer() + if container.decodeNil() { + self = .null + } else if let string = try? container.decode(String.self) { + self = .text(string) + } else if let parts = try? container.decode([ContentPart].self) { + self = .text(parts.compactMap(\.text).joined()) + } else { + self = .null + } + } + + public func encode(to encoder: Encoder) throws { + var container = encoder.singleValueContainer() + switch self { + case .text(let value): try container.encode(value) + case .null: try container.encodeNil() + } + } +} + +/// One element of a multimodal `content` array. Only `text` parts are used. +public struct ContentPart: Codable, Sendable { + public var type: String? + public var text: String? +} + +/// An OpenAI tool definition (`tools[]` in the request). +public struct ToolDefinition: Codable, Sendable { + public var type: String + public var function: FunctionSchema + + public init(type: String = "function", function: FunctionSchema) { + self.type = type + self.function = function + } +} + +/// The `function` schema inside a ``ToolDefinition``. +public struct FunctionSchema: Codable, Sendable { + public var name: String + public var description: String? + public var parameters: JSONValue? + + public init(name: String, description: String? = nil, parameters: JSONValue? = nil) { + self.name = name + self.description = description + self.parameters = parameters + } +} + +/// OpenAI `tool_choice`: `"none"` / `"auto"` / `"required"` or a forced function. +public enum ToolChoice: Codable, Sendable { + case none + case auto + case required + case function(name: String) + + public init(from decoder: Decoder) throws { + let container = try decoder.singleValueContainer() + if let string = try? container.decode(String.self) { + switch string { + case "none": self = .none + case "required": self = .required + default: self = .auto + } + return + } + struct Forced: Decodable { + struct Function: Decodable { let name: String } + let function: Function + } + let forced = try container.decode(Forced.self) + self = .function(name: forced.function.name) + } + + public func encode(to encoder: Encoder) throws { + var container = encoder.singleValueContainer() + switch self { + case .none: try container.encode("none") + case .auto: try container.encode("auto") + case .required: try container.encode("required") + case .function(let name): + struct Forced: Encodable { + struct Function: Encodable { let name: String } + let type = "function" + let function: Function + } + try container.encode(Forced(function: .init(name: name))) + } + } +} + +/// An OpenAI tool call (`tool_calls[]`) on a request or response message. +public struct ToolCallObject: Codable, Sendable { + public var id: String + public var type: String + public var function: FunctionCall + public var index: Int? + + public init(id: String, type: String = "function", function: FunctionCall, index: Int? = nil) { + self.id = id + self.type = type + self.function = function + self.index = index + } +} + +/// The `function` of a ``ToolCallObject``. `arguments` is a JSON *string*. +public struct FunctionCall: Codable, Sendable { + public var name: String + public var arguments: String + + public init(name: String, arguments: String) { + self.name = name + self.arguments = arguments + } +} + +// MARK: - Response + +/// `chat.completion` response (non-streaming). +public struct ChatCompletionResponse: Encodable, Sendable { + public var id: String + public var object: String = "chat.completion" + public var created: Int + public var model: String + public var choices: [Choice] + public var usage: Usage + + public struct Choice: Encodable, Sendable { + public var index: Int + public var message: ResponseMessage + public var finishReason: String + + enum CodingKeys: String, CodingKey { + case index, message + case finishReason = "finish_reason" + } + } + + public struct ResponseMessage: Encodable, Sendable { + public var role: String = "assistant" + public var content: String? + public var toolCalls: [ToolCallObject]? + + enum CodingKeys: String, CodingKey { + case role, content + case toolCalls = "tool_calls" + } + } +} + +/// Token accounting for a completion. +public struct Usage: Encodable, Sendable { + public var promptTokens: Int + public var completionTokens: Int + public var totalTokens: Int + + public init(promptTokens: Int, completionTokens: Int) { + self.promptTokens = promptTokens + self.completionTokens = completionTokens + self.totalTokens = promptTokens + completionTokens + } + + enum CodingKeys: String, CodingKey { + case promptTokens = "prompt_tokens" + case completionTokens = "completion_tokens" + case totalTokens = "total_tokens" + } +} + +/// `chat.completion.chunk` — one SSE event in a streamed completion. +public struct ChatCompletionChunk: Encodable, Sendable { + public var id: String + public var object: String = "chat.completion.chunk" + public var created: Int + public var model: String + public var choices: [ChunkChoice] + + public struct ChunkChoice: Encodable, Sendable { + public var index: Int + public var delta: Delta + public var finishReason: String? + + enum CodingKeys: String, CodingKey { + case index, delta + case finishReason = "finish_reason" + } + } + + public struct Delta: Encodable, Sendable { + public var role: String? + public var content: String? + public var toolCalls: [ToolCallObject]? + + enum CodingKeys: String, CodingKey { + case role, content + case toolCalls = "tool_calls" + } + } +} + +/// `GET /v1/models` response. +public struct ModelListResponse: Encodable, Sendable { + public var object: String = "list" + public var data: [ModelObject] +} + +/// One entry in a ``ModelListResponse``. +public struct ModelObject: Encodable, Sendable { + public var id: String + public var object: String = "model" + public var created: Int + public var ownedBy: String = "mlx-server" + + enum CodingKeys: String, CodingKey { + case id, object, created + case ownedBy = "owned_by" + } +} diff --git a/Sources/MLXServerKit/Routes.swift b/Sources/MLXServerKit/Routes.swift new file mode 100644 index 0000000..10f3e31 --- /dev/null +++ b/Sources/MLXServerKit/Routes.swift @@ -0,0 +1,22 @@ +import Foundation +import Hummingbird + +/// Register all HTTP routes on the router. +func registerRoutes(router: Router, engine: some Inferencing) { + // Liveness smoke endpoint. + router.get("/health") { _, _ -> String in + "ok" + } + + // OpenAI model list — reports the single loaded model. + router.get("/v1/models") { _, _ -> ModelListResponse in + ModelListResponse(data: [ + ModelObject(id: engine.modelID, created: Int(Date().timeIntervalSince1970)) + ]) + } + + // OpenAI chat completions. + router.post("/v1/chat/completions") { request, context -> Response in + try await ChatCompletionsHandler.handle(request: request, context: context, engine: engine) + } +} diff --git a/Sources/MLXServerKit/SSE.swift b/Sources/MLXServerKit/SSE.swift new file mode 100644 index 0000000..dccf3ef --- /dev/null +++ b/Sources/MLXServerKit/SSE.swift @@ -0,0 +1,28 @@ +import Foundation +import Hummingbird + +/// Server-Sent Events framing for streamed chat completions. +/// +/// Each event is `data: \n\n`; the stream terminates with +/// `data: [DONE]\n\n` per the OpenAI streaming convention. +enum SSE { + /// The terminating `[DONE]` line. + static let doneLine = "data: [DONE]\n\n" + + /// Wrap an already-encoded JSON payload in a `data:` frame. + static func dataLine(_ json: String) -> String { + "data: \(json)\n\n" + } + + /// Encode an event payload as a `data:` frame buffer. + static func event(_ value: some Encodable) throws -> ByteBuffer { + let data = try JSONEncoder().encode(value) + let json = String(decoding: data, as: UTF8.self) + return ByteBuffer(string: dataLine(json)) + } + + /// The terminating `[DONE]` frame buffer. + static func done() -> ByteBuffer { + ByteBuffer(string: doneLine) + } +} diff --git a/Sources/MLXServerKit/Server.swift b/Sources/MLXServerKit/Server.swift new file mode 100644 index 0000000..6d1cce4 --- /dev/null +++ b/Sources/MLXServerKit/Server.swift @@ -0,0 +1,34 @@ +import Hummingbird +import Logging + +/// Builds and runs the mlx-server HTTP service. +/// +/// The inference engine is loaded *before* the socket binds, so the process +/// never accepts traffic in a half-ready state. +public func run(config: ServerConfig) async throws { + let logger = Logger(label: "mlx-server") + logger.info( + "mlx-server starting", + metadata: [ + "host": .string(config.host), + "port": .stringConvertible(config.port), + "max_slots": .stringConvertible(config.maxSlots), + "model": .string(config.model), + ]) + + let engine = try await InferenceEngine.load(config: config, logger: logger) + + let router = Router() + registerRoutes(router: router, engine: engine) + + let app = Application( + router: router, + configuration: .init( + address: .hostname(config.host, port: config.port), + serverName: "mlx-server" + ), + logger: logger + ) + + try await app.runService() +} diff --git a/Sources/MLXServerKit/ServerConfig.swift b/Sources/MLXServerKit/ServerConfig.swift new file mode 100644 index 0000000..c6aea25 --- /dev/null +++ b/Sources/MLXServerKit/ServerConfig.swift @@ -0,0 +1,32 @@ +import Foundation + +/// Immutable server configuration assembled from CLI options by the +/// `MLXServer` executable and handed to ``run(config:)``. +public struct ServerConfig: Sendable { + /// HuggingFace model id or a local model directory path. + public var model: String + /// Bind address. + public var host: String + /// Bind port. + public var port: Int + /// Maximum concurrent inference slots. Phase 1 serves single-slot; + /// this is carried through for the Phase 2 slot pool. + public var maxSlots: Int + /// Optional tool-call format override (e.g. `xml_function`, `json`). + /// When `nil` the format is inferred from the model's `config.json`. + public var toolCallFormat: String? + + public init( + model: String, + host: String, + port: Int, + maxSlots: Int, + toolCallFormat: String? = nil + ) { + self.model = model + self.host = host + self.port = port + self.maxSlots = maxSlots + self.toolCallFormat = toolCallFormat + } +} diff --git a/Tests/MLXServerTests/ChatMappingTests.swift b/Tests/MLXServerTests/ChatMappingTests.swift new file mode 100644 index 0000000..d427db2 --- /dev/null +++ b/Tests/MLXServerTests/ChatMappingTests.swift @@ -0,0 +1,85 @@ +import Foundation +import MLXLMCommon +import Testing + +@testable import MLXServerKit + +@Suite("ChatMapping") +struct ChatMappingTests { + @Test("maps OpenAI roles to Chat.Message roles") + func toChatMessages() { + let mapped = ChatMapping.toChatMessages([ + ChatMessage(role: "system", content: .text("sys")), + ChatMessage(role: "user", content: .text("hello")), + ChatMessage(role: "assistant", content: .text("hi")), + ]) + + #expect(mapped.count == 3) + #expect(mapped[0].role == .system) + #expect(mapped[1].role == .user) + #expect(mapped[1].content == "hello") + #expect(mapped[2].role == .assistant) + } + + @Test("unknown role falls back to user") + func unknownRole() { + let mapped = ChatMapping.toChatMessages([ChatMessage(role: "weird", content: .text("x"))]) + #expect(mapped[0].role == .user) + } + + @Test("converts tool definitions to ToolSpec dictionaries") + func toToolSpecs() { + let tools = [ + ToolDefinition( + function: FunctionSchema( + name: "get_weather", + description: "look up weather", + parameters: .object(["type": .string("object")]) + ) + ) + ] + let specs = ChatMapping.toToolSpecs(tools) + + #expect(specs?.count == 1) + #expect(specs?.first?["type"] as? String == "function") + let function = specs?.first?["function"] as? [String: any Sendable] + #expect(function?["name"] as? String == "get_weather") + #expect(function?["description"] as? String == "look up weather") + } + + @Test("absent or empty tools yield nil") + func emptyTools() { + #expect(ChatMapping.toToolSpecs(nil) == nil) + #expect(ChatMapping.toToolSpecs([]) == nil) + } + + @Test("resolveGenerateParameters applies values and clamps max_tokens") + func generateParameters() { + let applied = ChatMapping.resolveGenerateParameters( + ChatCompletionRequest(model: "m", messages: [], temperature: 0.3, topP: 0.8, maxTokens: 100) + ) + #expect(applied.temperature == Float(0.3)) + #expect(applied.topP == Float(0.8)) + #expect(applied.maxTokens == 100) + + let huge = ChatMapping.resolveGenerateParameters( + ChatCompletionRequest(model: "m", messages: [], maxTokens: 9_000_000) + ) + #expect(huge.maxTokens == ChatMapping.maxTokensCeiling) + + let defaulted = ChatMapping.resolveGenerateParameters( + ChatCompletionRequest(model: "m", messages: []) + ) + #expect(defaulted.maxTokens == ChatMapping.defaultMaxTokens) + } + + @Test("argumentsJSONString produces valid JSON") + func argumentsJSON() throws { + let arguments: [String: JSONValue] = ["city": .string("Paris"), "days": .int(3)] + let string = ChatMapping.argumentsJSONString(arguments) + + let parsed = try JSONSerialization.jsonObject(with: Data(string.utf8)) as? [String: Any] + #expect(parsed?["city"] as? String == "Paris") + #expect(parsed?["days"] as? Int == 3) + } +} diff --git a/Tests/MLXServerTests/OpenAITypesTests.swift b/Tests/MLXServerTests/OpenAITypesTests.swift new file mode 100644 index 0000000..3942af9 --- /dev/null +++ b/Tests/MLXServerTests/OpenAITypesTests.swift @@ -0,0 +1,111 @@ +import Foundation +import Testing + +@testable import MLXServerKit + +@Suite("OpenAI types") +struct OpenAITypesTests { + @Test("decodes a chat completion request with tools and streaming") + func decodeRequest() throws { + let json = """ + { + "model": "qwen", + "messages": [ + {"role": "system", "content": "be brief"}, + {"role": "user", "content": "hi"} + ], + "temperature": 0.7, + "top_p": 0.9, + "max_tokens": 256, + "stream": true, + "tool_choice": "auto", + "tools": [ + {"type": "function", "function": {"name": "get_weather", "description": "w", "parameters": {"type": "object"}}} + ] + } + """ + let request = try JSONDecoder().decode(ChatCompletionRequest.self, from: Data(json.utf8)) + + #expect(request.model == "qwen") + #expect(request.messages.count == 2) + #expect(request.messages[1].content?.text == "hi") + #expect(request.temperature == 0.7) + #expect(request.topP == 0.9) + #expect(request.maxTokens == 256) + #expect(request.stream == true) + #expect(request.tools?.count == 1) + #expect(request.tools?.first?.function.name == "get_weather") + guard case .auto = request.toolChoice else { + Issue.record("expected tool_choice .auto") + return + } + } + + @Test("decodes tool_choice in forced-function object form") + func decodeForcedToolChoice() throws { + let json = """ + {"model":"m","messages":[],"tool_choice":{"type":"function","function":{"name":"f"}}} + """ + let request = try JSONDecoder().decode(ChatCompletionRequest.self, from: Data(json.utf8)) + guard case .function(let name) = request.toolChoice else { + Issue.record("expected tool_choice .function") + return + } + #expect(name == "f") + } + + @Test("decodes a tool-result message") + func decodeToolMessage() throws { + let json = #"{"role":"tool","content":"42","tool_call_id":"call_abc"}"# + let message = try JSONDecoder().decode(ChatMessage.self, from: Data(json.utf8)) + + #expect(message.role == "tool") + #expect(message.toolCallId == "call_abc") + #expect(message.content?.text == "42") + } + + @Test("decodes a null message content") + func decodeNullContent() throws { + let json = #"{"role":"assistant","content":null}"# + let message = try JSONDecoder().decode(ChatMessage.self, from: Data(json.utf8)) + #expect(message.content?.text == nil) + } + + @Test("encodes a chat completion response with snake_case keys") + func encodeResponse() throws { + let response = ChatCompletionResponse( + id: "chatcmpl-1", + created: 100, + model: "qwen", + choices: [ + .init( + index: 0, + message: .init(content: "hello", toolCalls: nil), + finishReason: "stop" + ) + ], + usage: Usage(promptTokens: 3, completionTokens: 5) + ) + let string = String(decoding: try JSONEncoder().encode(response), as: UTF8.self) + + #expect(string.contains(#""finish_reason":"stop""#)) + #expect(string.contains(#""prompt_tokens":3"#)) + #expect(string.contains(#""total_tokens":8"#)) + #expect(string.contains(#""object":"chat.completion""#)) + } + + @Test("encodes a streaming chunk as chat.completion.chunk") + func encodeChunk() throws { + let chunk = ChatCompletionChunk( + id: "chatcmpl-1", + created: 100, + model: "qwen", + choices: [ + .init(index: 0, delta: .init(role: "assistant", content: "hi"), finishReason: nil) + ] + ) + let string = String(decoding: try JSONEncoder().encode(chunk), as: UTF8.self) + #expect(string.contains(#""object":"chat.completion.chunk""#)) + #expect(string.contains(#""content":"hi""#)) + } +} diff --git a/Tests/MLXServerTests/RoutesTests.swift b/Tests/MLXServerTests/RoutesTests.swift new file mode 100644 index 0000000..5c62df7 --- /dev/null +++ b/Tests/MLXServerTests/RoutesTests.swift @@ -0,0 +1,131 @@ +import Foundation +import Hummingbird +import HummingbirdTesting +import Testing + +@testable import MLXServerKit + +/// In-memory ``Inferencing`` implementation for exercising routes without a model. +struct StubEngine: Inferencing { + var modelID: String = "stub-model" + var cannedText: String = "hello" + var cannedToolCall: ToolCallObject? + + func complete(_ request: ChatCompletionRequest) async throws -> ChatCompletionResponse { + let hasTool = cannedToolCall != nil + return ChatCompletionResponse( + id: "chatcmpl-stub", + created: 0, + model: request.model, + choices: [ + .init( + index: 0, + message: .init( + content: hasTool ? nil : cannedText, + toolCalls: cannedToolCall.map { [$0] }), + finishReason: hasTool ? "tool_calls" : "stop") + ], + usage: Usage(promptTokens: 1, completionTokens: 1)) + } + + func stream(_ request: ChatCompletionRequest) -> AsyncThrowingStream { + let text = cannedText + let tool = cannedToolCall + return AsyncThrowingStream { continuation in + if let tool { + continuation.yield(.toolCall(tool)) + } else { + continuation.yield(.textDelta(text)) + } + continuation.yield( + .finished( + reason: tool != nil ? "tool_calls" : "stop", + usage: Usage(promptTokens: 1, completionTokens: 1))) + continuation.finish() + } + } +} + +@Suite("Routes") +struct RoutesTests { + private func makeApp(engine: some Inferencing) -> some ApplicationProtocol { + let router = Router() + registerRoutes(router: router, engine: engine) + return Application(router: router) + } + + private let jsonHeaders: HTTPFields = [.contentType: "application/json"] + + @Test("GET /v1/models lists the loaded model") + func models() async throws { + let app = makeApp(engine: StubEngine(modelID: "my-model")) + try await app.test(.router) { client in + let response = try await client.execute(uri: "/v1/models", method: .get) + #expect(response.status == .ok) + let body = String(buffer: response.body) + #expect(body.contains("\"my-model\"")) + #expect(body.contains("\"object\":\"list\"")) + } + } + + @Test("non-streaming chat completion returns a chat.completion") + func nonStreaming() async throws { + let app = makeApp(engine: StubEngine(modelID: "m", cannedText: "pong")) + try await app.test(.router) { client in + let body = ByteBuffer( + string: #"{"model":"m","messages":[{"role":"user","content":"ping"}]}"#) + let response = try await client.execute( + uri: "/v1/chat/completions", method: .post, headers: jsonHeaders, body: body) + #expect(response.status == .ok) + let text = String(buffer: response.body) + #expect(text.contains("\"object\":\"chat.completion\"")) + #expect(text.contains("pong")) + #expect(text.contains("\"finish_reason\":\"stop\"")) + } + } + + @Test("streaming chat completion emits SSE frames terminated by [DONE]") + func streaming() async throws { + let app = makeApp(engine: StubEngine(modelID: "m", cannedText: "streamed")) + try await app.test(.router) { client in + let body = ByteBuffer( + string: #"{"model":"m","stream":true,"messages":[{"role":"user","content":"go"}]}"#) + let response = try await client.execute( + uri: "/v1/chat/completions", method: .post, headers: jsonHeaders, body: body) + #expect(response.status == .ok) + #expect(response.headers[.contentType] == "text/event-stream") + let text = String(buffer: response.body) + #expect(text.contains("chat.completion.chunk")) + #expect(text.contains("data: [DONE]")) + } + } + + @Test("tool-call result reports finish_reason tool_calls") + func toolCalls() async throws { + let toolCall = ToolCallObject( + id: "call_1", + function: FunctionCall(name: "get_weather", arguments: #"{"city":"Paris"}"#), + index: 0) + let app = makeApp(engine: StubEngine(modelID: "m", cannedToolCall: toolCall)) + try await app.test(.router) { client in + let body = ByteBuffer( + string: #"{"model":"m","messages":[{"role":"user","content":"weather?"}]}"#) + let response = try await client.execute( + uri: "/v1/chat/completions", method: .post, headers: jsonHeaders, body: body) + let text = String(buffer: response.body) + #expect(text.contains("\"finish_reason\":\"tool_calls\"")) + #expect(text.contains("get_weather")) + } + } + + @Test("empty messages array is rejected with 400") + func emptyMessages() async throws { + let app = makeApp(engine: StubEngine()) + try await app.test(.router) { client in + let body = ByteBuffer(string: #"{"model":"m","messages":[]}"#) + let response = try await client.execute( + uri: "/v1/chat/completions", method: .post, headers: jsonHeaders, body: body) + #expect(response.status == .badRequest) + } + } +} diff --git a/Tests/MLXServerTests/SSETests.swift b/Tests/MLXServerTests/SSETests.swift new file mode 100644 index 0000000..222e3af --- /dev/null +++ b/Tests/MLXServerTests/SSETests.swift @@ -0,0 +1,17 @@ +import Testing + +@testable import MLXServerKit + +@Suite("SSE") +struct SSETests { + @Test("dataLine frames a payload as data: \\n\\n") + func dataLineFraming() { + let line = SSE.dataLine(#"{"hello":"world"}"#) + #expect(line == "data: {\"hello\":\"world\"}\n\n") + } + + @Test("doneLine is the OpenAI stream terminator") + func doneTerminator() { + #expect(SSE.doneLine == "data: [DONE]\n\n") + } +} diff --git a/Tests/MLXServerTests/ServerConfigTests.swift b/Tests/MLXServerTests/ServerConfigTests.swift new file mode 100644 index 0000000..7dec769 --- /dev/null +++ b/Tests/MLXServerTests/ServerConfigTests.swift @@ -0,0 +1,22 @@ +import Testing + +@testable import MLXServerKit + +@Suite("ServerConfig") +struct ServerConfigTests { + @Test("stores the values it is constructed with") + func storesValues() { + let config = ServerConfig( + model: "/models/Qwen3-4B-4bit", + host: "0.0.0.0", + port: 8080, + maxSlots: 4 + ) + + #expect(config.model == "/models/Qwen3-4B-4bit") + #expect(config.host == "0.0.0.0") + #expect(config.port == 8080) + #expect(config.maxSlots == 4) + #expect(config.toolCallFormat == nil) + } +}