diff --git a/README.md b/README.md index e305cfc..4a2e972 100644 --- a/README.md +++ b/README.md @@ -54,8 +54,14 @@ xcodebuild -scheme mlx-server -destination 'platform=macOS,arch=arm64' \ ``` `--model` takes a local MLX model directory or a HuggingFace id. Other flags: -`--host`, `--port`, `--max-slots`, `--tool-call-format` (e.g. `xml_function` -for Qwen3.5 / Qwen3-Coder; auto-inferred when unset). + +- `--host`, `--port`, `--max-slots` +- `--tool-call-format` — e.g. `xml_function` for Qwen3.5 / Qwen3-Coder; + auto-inferred when unset +- `--reasoning` — how thinking output is split into `reasoning_content` vs + `content`: `auto` (default; splits on a literal `<think>`/`</think>`), + `prefilled` (output starts mid-thought — use for Qwen3.5 / Qwen3.6), or + `off` ## Roadmap diff --git a/Sources/MLXServer/MLXServerCommand.swift b/Sources/MLXServer/MLXServerCommand.swift index 48a0699..47cde87 100644 --- a/Sources/MLXServer/MLXServerCommand.swift +++ b/Sources/MLXServer/MLXServerCommand.swift @@ -23,17 +23,24 @@ struct MLXServerCommand: AsyncParsableCommand { @Option(name: .long, help: "Tool-call format override (e.g. xml_function, json). Auto-inferred when unset.") var toolCallFormat: String? + @Option(name: .long, help: "Reasoning split mode: auto, prefilled, or off. 
Use 'prefilled' for Qwen3.5 / Qwen3.6.") + var reasoning: String = "auto" + func run() async throws { guard let model else { throw ValidationError("--model is required (HuggingFace ID or local directory path).") } + guard let reasoningMode = ReasoningMode(rawValue: reasoning) else { + throw ValidationError("--reasoning must be one of: auto, prefilled, off") + } let config = ServerConfig( model: model, host: host, port: port, maxSlots: maxSlots, - toolCallFormat: toolCallFormat + toolCallFormat: toolCallFormat, + reasoningMode: reasoningMode ) try await MLXServerKit.run(config: config) diff --git a/Sources/MLXServerKit/ChatCompletion.swift b/Sources/MLXServerKit/ChatCompletion.swift index df3886d..a840fb7 100644 --- a/Sources/MLXServerKit/ChatCompletion.swift +++ b/Sources/MLXServerKit/ChatCompletion.swift @@ -7,7 +7,9 @@ extension InferenceEngine { let input = try await prepareInput(for: request) let parameters = ChatMapping.resolveGenerateParameters(request) - var text = "" + var splitter = ReasoningSplitter(mode: reasoningMode) + var content = "" + var reasoning = "" var toolCalls: [ToolCallObject] = [] var info: GenerateCompletionInfo? @@ -16,7 +18,9 @@ extension InferenceEngine { for await generation in stream { switch generation { case .chunk(let chunk): - text += chunk + let split = splitter.push(chunk) + content += split.content + reasoning += split.reasoning case .toolCall(let call): toolCalls.append(Self.toolCallObject(call, index: toolCalls.count)) case .info(let completionInfo): @@ -26,10 +30,14 @@ extension InferenceEngine { } catch { throw ServerError.inferenceFailed(String(describing: error)) } + let tail = splitter.flush() + content += tail.content + reasoning += tail.reasoning let hasToolCalls = !toolCalls.isEmpty let message = ChatCompletionResponse.ResponseMessage( - content: hasToolCalls ? (text.isEmpty ? nil : text) : text, + content: hasToolCalls && content.isEmpty ? nil : content, + reasoningContent: reasoning.isEmpty ? 
nil : reasoning, toolCalls: hasToolCalls ? toolCalls : nil ) return ChatCompletionResponse( @@ -86,6 +94,7 @@ extension InferenceEngine { /// Semantic events emitted by a streaming completion, before OpenAI framing. enum StreamEvent: Sendable { case textDelta(String) + case reasoningDelta(String) case toolCall(ToolCallObject) case finished(reason: String, usage: Usage) } @@ -117,6 +126,7 @@ extension InferenceEngine { let input = try await prepareInput(for: request) let parameters = ChatMapping.resolveGenerateParameters(request) + var splitter = ReasoningSplitter(mode: reasoningMode) var toolCallCount = 0 var info: GenerateCompletionInfo? do { @@ -124,7 +134,13 @@ extension InferenceEngine { for await generation in generations { switch generation { case .chunk(let chunk): - continuation.yield(.textDelta(chunk)) + let split = splitter.push(chunk) + if !split.reasoning.isEmpty { + continuation.yield(.reasoningDelta(split.reasoning)) + } + if !split.content.isEmpty { + continuation.yield(.textDelta(split.content)) + } case .toolCall(let call): continuation.yield(.toolCall(Self.toolCallObject(call, index: toolCallCount))) toolCallCount += 1 @@ -136,6 +152,10 @@ extension InferenceEngine { throw ServerError.inferenceFailed(String(describing: error)) } + let tail = splitter.flush() + if !tail.reasoning.isEmpty { continuation.yield(.reasoningDelta(tail.reasoning)) } + if !tail.content.isEmpty { continuation.yield(.textDelta(tail.content)) } + continuation.yield( .finished( reason: toolCallCount > 0 ? 
"tool_calls" : "stop", diff --git a/Sources/MLXServerKit/ChatCompletionsHandler.swift b/Sources/MLXServerKit/ChatCompletionsHandler.swift index e3c5e02..372180f 100644 --- a/Sources/MLXServerKit/ChatCompletionsHandler.swift +++ b/Sources/MLXServerKit/ChatCompletionsHandler.swift @@ -64,6 +64,13 @@ enum ChatCompletionsHandler { roleSent = true try await writer.write( SSE.event(chunk(id, created, model, delta, finishReason: nil))) + case .reasoningDelta(let text): + let delta = ChatCompletionChunk.Delta( + role: roleSent ? nil : "assistant", content: nil, + reasoningContent: text, toolCalls: nil) + roleSent = true + try await writer.write( + SSE.event(chunk(id, created, model, delta, finishReason: nil))) case .toolCall(let call): let delta = ChatCompletionChunk.Delta( role: roleSent ? nil : "assistant", content: nil, toolCalls: [call]) diff --git a/Sources/MLXServerKit/InferenceEngine.swift b/Sources/MLXServerKit/InferenceEngine.swift index 37a2077..30d47e5 100644 --- a/Sources/MLXServerKit/InferenceEngine.swift +++ b/Sources/MLXServerKit/InferenceEngine.swift @@ -17,17 +17,21 @@ public actor InferenceEngine { public let modelID: String /// Resolved tool-call format, or `nil` to let the model config decide. let toolCallFormat: ToolCallFormat? + /// How thinking output is split into reasoning vs answer text. 
+ let reasoningMode: ReasoningMode let logger: Logger private init( container: ModelContainer, modelID: String, toolCallFormat: ToolCallFormat?, + reasoningMode: ReasoningMode, logger: Logger ) { self.container = container self.modelID = modelID self.toolCallFormat = toolCallFormat + self.reasoningMode = reasoningMode self.logger = logger } @@ -63,6 +67,7 @@ public actor InferenceEngine { container: container, modelID: config.model, toolCallFormat: format, + reasoningMode: config.reasoningMode, logger: logger) } diff --git a/Sources/MLXServerKit/OpenAITypes.swift b/Sources/MLXServerKit/OpenAITypes.swift index 8f7be2c..a7f6db0 100644 --- a/Sources/MLXServerKit/OpenAITypes.swift +++ b/Sources/MLXServerKit/OpenAITypes.swift @@ -212,10 +212,12 @@ public struct ChatCompletionResponse: Encodable, Sendable { public struct ResponseMessage: Encodable, Sendable { public var role: String = "assistant" public var content: String? + public var reasoningContent: String? public var toolCalls: [ToolCallObject]? enum CodingKeys: String, CodingKey { case role, content + case reasoningContent = "reasoning_content" case toolCalls = "tool_calls" } } @@ -262,10 +264,12 @@ public struct ChatCompletionChunk: Encodable, Sendable { public struct Delta: Encodable, Sendable { public var role: String? public var content: String? + public var reasoningContent: String? public var toolCalls: [ToolCallObject]? enum CodingKeys: String, CodingKey { case role, content + case reasoningContent = "reasoning_content" case toolCalls = "tool_calls" } } diff --git a/Sources/MLXServerKit/ReasoningSplitter.swift b/Sources/MLXServerKit/ReasoningSplitter.swift new file mode 100644 index 0000000..0507380 --- /dev/null +++ b/Sources/MLXServerKit/ReasoningSplitter.swift @@ -0,0 +1,125 @@ +import Foundation + +/// How the server separates a model's thinking output from its answer. +public enum ReasoningMode: String, Sendable { + /// Start in the answer; a literal `<think>` opens a reasoning block and + /// `</think>` closes it. 
Safe for non-reasoning models (no markers ever + /// appear, so everything stays in the answer). + case auto + /// Start already inside reasoning — the chat template prefilled the + /// opening `<think>`, so generated output begins mid-thought. `</think>` + /// switches to the answer. Use for Qwen3.5 / Qwen3.6. + case prefilled + /// No splitting; all output is the answer. + case off +} + +/// Streaming splitter that classifies model output into reasoning vs. answer +/// text by tracking `<think>` / `</think>` markers. +/// +/// Marker-safe across chunk boundaries: a partial marker at a chunk edge is +/// held back until the next chunk completes (or fails to complete) it. +struct ReasoningSplitter { + /// Text classified out of a `push` or `flush` call. + struct Split: Equatable { + var reasoning = "" + var content = "" + } + + private static let openMarker = "" + private static let closeMarker = "" + + private enum Phase { case reasoning, content } + private var phase: Phase + /// True while a `<think>` opener could still appear. + private var watchingForOpen: Bool + private let mode: ReasoningMode + /// Holds a possible partial marker straddling a chunk boundary. + private var pending = "" + + init(mode: ReasoningMode) { + self.mode = mode + switch mode { + case .auto: + phase = .content + watchingForOpen = true + case .prefilled: + phase = .reasoning + watchingForOpen = false + case .off: + phase = .content + watchingForOpen = false + } + } + + /// The marker currently being scanned for, or `nil` when none applies. + private var activeMarker: String? { + if watchingForOpen { return Self.openMarker } + if phase == .reasoning { return Self.closeMarker } + return nil + } + + /// Feed a chunk of model output; returns the text split by phase. 
+ mutating func push(_ text: String) -> Split { + var split = Split() + guard mode != .off else { + split.content = text + return split + } + + var work = pending + text + pending = "" + + while let marker = activeMarker { + if let range = work.range(of: marker) { + emit(String(work[work.startIndex.. Split { + var split = Split() + emit(pending, into: &split) + pending = "" + return split + } + + private func emit(_ text: String, into split: inout Split) { + guard !text.isEmpty else { return } + switch phase { + case .reasoning: split.reasoning += text + case .content: split.content += text + } + } + + private mutating func advancePhase(after marker: String) { + phase = (marker == Self.openMarker) ? .reasoning : .content + watchingForOpen = false + } + + /// Length of the longest suffix of `text` that is a proper prefix of `marker`. + private func partialMarkerSuffixLength(of text: String, marker: String) -> Int { + var length = min(text.count, marker.count - 1) + while length > 0 { + if marker.hasPrefix(text.suffix(length)) { + return length + } + length -= 1 + } + return 0 + } +} diff --git a/Sources/MLXServerKit/ServerConfig.swift b/Sources/MLXServerKit/ServerConfig.swift index c6aea25..c55280d 100644 --- a/Sources/MLXServerKit/ServerConfig.swift +++ b/Sources/MLXServerKit/ServerConfig.swift @@ -15,18 +15,22 @@ public struct ServerConfig: Sendable { /// Optional tool-call format override (e.g. `xml_function`, `json`). /// When `nil` the format is inferred from the model's `config.json`. public var toolCallFormat: String? + /// How thinking-model output is split into `reasoning_content` vs `content`. + public var reasoningMode: ReasoningMode public init( model: String, host: String, port: Int, maxSlots: Int, - toolCallFormat: String? = nil + toolCallFormat: String? 
= nil, + reasoningMode: ReasoningMode = .auto ) { self.model = model self.host = host self.port = port self.maxSlots = maxSlots self.toolCallFormat = toolCallFormat + self.reasoningMode = reasoningMode } } diff --git a/Tests/MLXServerTests/ReasoningSplitterTests.swift b/Tests/MLXServerTests/ReasoningSplitterTests.swift new file mode 100644 index 0000000..7e99727 --- /dev/null +++ b/Tests/MLXServerTests/ReasoningSplitterTests.swift @@ -0,0 +1,78 @@ +import Testing + +@testable import MLXServerKit + +@Suite("ReasoningSplitter") +struct ReasoningSplitterTests { + /// Feed all chunks through a splitter, then flush; return the accumulated split. + private func run(mode: ReasoningMode, _ chunks: [String]) -> ReasoningSplitter.Split { + var splitter = ReasoningSplitter(mode: mode) + var accumulated = ReasoningSplitter.Split() + for chunk in chunks { + let split = splitter.push(chunk) + accumulated.reasoning += split.reasoning + accumulated.content += split.content + } + let tail = splitter.flush() + accumulated.reasoning += tail.reasoning + accumulated.content += tail.content + return accumulated + } + + @Test("auto mode splits a block from the answer") + func autoSplit() { + let result = run(mode: .auto, ["reasoning herethe answer"]) + #expect(result.reasoning == "reasoning here") + #expect(result.content == "the answer") + } + + @Test("auto mode passes plain output through as content") + func autoNoMarkers() { + let result = run(mode: .auto, ["just a plain answer"]) + #expect(result.reasoning == "") + #expect(result.content == "just a plain answer") + } + + @Test("prefilled mode treats leading text as reasoning until ") + func prefilledSplit() { + let result = run(mode: .prefilled, ["thinking out loudfinal answer"]) + #expect(result.reasoning == "thinking out loud") + #expect(result.content == "final answer") + } + + @Test("off mode keeps everything as content") + func offMode() { + let result = run(mode: .off, ["xy"]) + #expect(result.reasoning == "") + 
#expect(result.content == "xy") + } + + @Test("a closing marker split across chunk boundaries is reassembled") + func partialCloseMarker() { + let result = run(mode: .prefilled, ["abcdef"]) + #expect(result.reasoning == "abc") + #expect(result.content == "def") + } + + @Test("an opening marker split across chunk boundaries is reassembled") + func partialOpenMarker() { + let result = run(mode: .auto, ["rc"]) + #expect(result.reasoning == "r") + #expect(result.content == "c") + } + + @Test("token-by-token streaming classifies each phase correctly") + func tokenByToken() { + let tokens = ["<", "think", ">", "deep ", "thoughts", "", "the ", "answer"] + let result = run(mode: .auto, tokens) + #expect(result.reasoning == "deep thoughts") + #expect(result.content == "the answer") + } + + @Test("flush emits a held-back partial marker that never completed") + func flushIncompletePartial() { + let result = run(mode: .prefilled, ["reasoning