extension InferenceEngine {
    /// Current wall-clock time as whole seconds since the Unix epoch.
    ///
    /// The fractional part of the timestamp is dropped by the `Int` conversion.
    static func unixNow() -> Int {
        let secondsSinceEpoch = Date().timeIntervalSince1970
        return Int(secondsSinceEpoch)
    }

    /// Maps the outcome of a finished generation to an OpenAI-style
    /// `finish_reason` string.
    ///
    /// The generator does not report why it stopped, so truncation is inferred
    /// by comparing the number of generated tokens with the requested limit.
    /// As a consequence, a generation that produced exactly `maxTokens` tokens
    /// is always reported as `length`.
    ///
    /// - Parameters:
    ///   - hasToolCalls: Whether the model emitted at least one tool call.
    ///   - generatedTokens: Number of tokens produced by the generation.
    ///   - maxTokens: Requested token limit, or `nil` when no limit was set.
    /// - Returns: `"tool_calls"` when `hasToolCalls` is true; otherwise
    ///   `"length"` when a limit is set and `generatedTokens` has reached or
    ///   passed it; otherwise `"stop"`.
    static func finishReason(hasToolCalls: Bool, generatedTokens: Int, maxTokens: Int?) -> String {
        // Tool calls outrank every other outcome, including a reached limit.
        guard !hasToolCalls else { return "tool_calls" }

        // Without a requested limit the output can never count as truncated.
        let reachedLimit = maxTokens.map { generatedTokens >= $0 } ?? false
        return reachedLimit ? "length" : "stop"
    }
}
import Testing

@testable import MLXServerKit

/// Checks the string returned by
/// `InferenceEngine.finishReason(hasToolCalls:generatedTokens:maxTokens:)`
/// for each of its three outcomes.
@Suite("finishReason")
struct FinishReasonTests {
    @Test("tool calls take precedence over everything else")
    func toolCalls() {
        // The token count is far past the limit, yet the tool-call outcome wins.
        let reason = InferenceEngine.finishReason(
            hasToolCalls: true, generatedTokens: 999, maxTokens: 10)
        #expect(reason == "tool_calls")
    }

    @Test("output truncated at the token limit reports length")
    func length() {
        // First exactly at the limit, then beyond it.
        for generated in [512, 600] {
            let reason = InferenceEngine.finishReason(
                hasToolCalls: false, generatedTokens: generated, maxTokens: 512)
            #expect(reason == "length")
        }
    }

    @Test("a natural stop reports stop")
    func stop() {
        // First comfortably under a limit, then with no limit at all.
        let limits: [Int?] = [512, nil]
        for limit in limits {
            let reason = InferenceEngine.finishReason(
                hasToolCalls: false, generatedTokens: 40, maxTokens: limit)
            #expect(reason == "stop")
        }
    }
}