Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions Sources/MLXServerKit/ChatCompletion.swift
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,10 @@ extension InferenceEngine {
.init(
index: 0,
message: message,
finishReason: hasToolCalls ? "tool_calls" : "stop"
finishReason: Self.finishReason(
hasToolCalls: hasToolCalls,
generatedTokens: info?.generationTokenCount ?? 0,
maxTokens: parameters.maxTokens)
)
],
usage: Usage(
Expand Down Expand Up @@ -89,6 +92,17 @@ extension InferenceEngine {
static func unixNow() -> Int {
Int(Date().timeIntervalSince1970)
}

/// OpenAI `finish_reason` for a completed generation: `length` when the
/// output was truncated at the token limit, `tool_calls` when the model
/// emitted tool calls, otherwise `stop`. Truncation is inferred by
/// comparing the generated token count against the requested limit,
/// since the generator does not surface a stop reason directly.
static func finishReason(hasToolCalls: Bool, generatedTokens: Int, maxTokens: Int?) -> String {
if hasToolCalls { return "tool_calls" }
if let maxTokens, generatedTokens >= maxTokens { return "length" }
return "stop"
}
}

/// Semantic events emitted by a streaming completion, before OpenAI framing.
Expand Down Expand Up @@ -158,7 +172,10 @@ extension InferenceEngine {

continuation.yield(
.finished(
reason: toolCallCount > 0 ? "tool_calls" : "stop",
reason: Self.finishReason(
hasToolCalls: toolCallCount > 0,
generatedTokens: info?.generationTokenCount ?? 0,
maxTokens: parameters.maxTokens),
usage: Usage(
promptTokens: info?.promptTokenCount ?? 0,
completionTokens: info?.generationTokenCount ?? 0)))
Expand Down
33 changes: 33 additions & 0 deletions Tests/MLXServerTests/FinishReasonTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import Testing

@testable import MLXServerKit

@Suite("finishReason")
struct FinishReasonTests {
@Test("tool calls take precedence over everything else")
func toolCalls() {
#expect(
InferenceEngine.finishReason(
hasToolCalls: true, generatedTokens: 999, maxTokens: 10) == "tool_calls")
}

@Test("output truncated at the token limit reports length")
func length() {
#expect(
InferenceEngine.finishReason(
hasToolCalls: false, generatedTokens: 512, maxTokens: 512) == "length")
#expect(
InferenceEngine.finishReason(
hasToolCalls: false, generatedTokens: 600, maxTokens: 512) == "length")
}

@Test("a natural stop reports stop")
func stop() {
#expect(
InferenceEngine.finishReason(
hasToolCalls: false, generatedTokens: 40, maxTokens: 512) == "stop")
#expect(
InferenceEngine.finishReason(
hasToolCalls: false, generatedTokens: 40, maxTokens: nil) == "stop")
}
}