Skip to content

Commit 4c113bd

Browse files
authored
Produce Thought/ThoughtChunk for Fireworks reasoning_content (tensorzero#3006)
* Produce Thought/ThoughtChunk for Fireworks `reasoning_content` * Fix clippy
1 parent 7ca3dee commit 4c113bd

4 files changed

Lines changed: 130 additions & 2 deletions

File tree

tensorzero-core/src/providers/fireworks/mod.rs

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use std::{borrow::Cow, sync::OnceLock};
22

3-
use crate::tool::Tool;
3+
use crate::{providers::helpers_thinking_block::THINK_CHUNK_ID, tool::Tool};
44
use futures::StreamExt;
55
use lazy_static::lazy_static;
66
use reqwest_eventsource::{Event, EventSource};
@@ -470,6 +470,8 @@ struct FireworksResponseMessage {
470470
#[serde(skip_serializing_if = "Option::is_none")]
471471
content: Option<String>,
472472
#[serde(skip_serializing_if = "Option::is_none")]
473+
reasoning_content: Option<String>,
474+
#[serde(skip_serializing_if = "Option::is_none")]
473475
tool_calls: Option<Vec<FireworksResponseToolCall>>,
474476
}
475477

@@ -531,6 +533,8 @@ struct FireworksDelta {
531533
#[serde(skip_serializing_if = "Option::is_none")]
532534
content: Option<String>,
533535
#[serde(skip_serializing_if = "Option::is_none")]
536+
reasoning_content: Option<String>,
537+
#[serde(skip_serializing_if = "Option::is_none")]
534538
tool_calls: Option<Vec<FireworksToolCallChunk>>,
535539
}
536540

@@ -630,6 +634,14 @@ fn fireworks_to_tensorzero_chunk(
630634
if let Some(reason) = choice.finish_reason {
631635
finish_reason = Some(reason.into());
632636
}
637+
if let Some(reasoning) = choice.delta.reasoning_content {
638+
content.push(ContentBlockChunk::Thought(ThoughtChunk {
639+
text: Some(reasoning),
640+
signature: None,
641+
id: THINK_CHUNK_ID.to_string(),
642+
provider_type: Some(PROVIDER_TYPE.to_string()),
643+
}));
644+
}
633645
if let Some(text) = choice.delta.content {
634646
if parse_think_blocks {
635647
if !thinking_state.update(&text, PROVIDER_TYPE)? {
@@ -745,6 +757,13 @@ impl<'a> TryFrom<FireworksResponseWithMetadata<'a>> for ProviderInferenceRespons
745757
}
746758
))?;
747759
let mut content: Vec<ContentBlockOutput> = Vec::new();
760+
if let Some(reasoning) = message.reasoning_content {
761+
content.push(ContentBlockOutput::Thought(Thought {
762+
text: Some(reasoning),
763+
signature: None,
764+
provider_type: Some(PROVIDER_TYPE.to_string()),
765+
}));
766+
}
748767
if let Some(raw_text) = message.content {
749768
let (clean_text, extracted_reasoning) =
750769
process_think_blocks(&raw_text, parse_think_blocks, PROVIDER_TYPE)?;
@@ -826,6 +845,7 @@ mod tests {
826845
index: 0,
827846
finish_reason: Some(FireworksFinishReason::Stop),
828847
message: FireworksResponseMessage {
848+
reasoning_content: None,
829849
content: Some(test_response_with_thinking.to_string()),
830850
tool_calls: None,
831851
},
@@ -1008,6 +1028,7 @@ mod tests {
10081028
index: 0,
10091029
finish_reason: Some(FireworksFinishReason::Stop),
10101030
message: FireworksResponseMessage {
1031+
reasoning_content: None,
10111032
content: Some("Hello, world!".to_string()),
10121033
tool_calls: None,
10131034
},
@@ -1077,6 +1098,7 @@ mod tests {
10771098
choices: vec![FireworksChatChunkChoice {
10781099
delta: FireworksDelta {
10791100
content: Some("Hello".to_string()),
1101+
reasoning_content: None,
10801102
tool_calls: None,
10811103
},
10821104
finish_reason: Some(FireworksFinishReason::Stop),
@@ -1109,6 +1131,7 @@ mod tests {
11091131
choices: vec![FireworksChatChunkChoice {
11101132
delta: FireworksDelta {
11111133
content: None,
1134+
reasoning_content: None,
11121135
tool_calls: Some(vec![FireworksToolCallChunk {
11131136
index: 0,
11141137
id: None,
@@ -1178,6 +1201,7 @@ mod tests {
11781201
choices: vec![FireworksChatChunkChoice {
11791202
delta: FireworksDelta {
11801203
content: Some("<think>".to_string()),
1204+
reasoning_content: None,
11811205
tool_calls: None,
11821206
},
11831207
finish_reason: None,
@@ -1209,6 +1233,7 @@ mod tests {
12091233
choices: vec![FireworksChatChunkChoice {
12101234
delta: FireworksDelta {
12111235
content: Some("reasoning".to_string()),
1236+
reasoning_content: None,
12121237
tool_calls: None,
12131238
},
12141239
finish_reason: None,
@@ -1241,6 +1266,7 @@ mod tests {
12411266
choices: vec![FireworksChatChunkChoice {
12421267
delta: FireworksDelta {
12431268
content: Some("</think>".to_string()),
1269+
reasoning_content: None,
12441270
tool_calls: None,
12451271
},
12461272
finish_reason: None,
@@ -1268,6 +1294,7 @@ mod tests {
12681294
choices: vec![FireworksChatChunkChoice {
12691295
delta: FireworksDelta {
12701296
content: Some("Final answer".to_string()),
1297+
reasoning_content: None,
12711298
tool_calls: None,
12721299
},
12731300
finish_reason: None,
@@ -1302,6 +1329,7 @@ mod tests {
13021329
choices: vec![FireworksChatChunkChoice {
13031330
delta: FireworksDelta {
13041331
content: Some("Hello <think>should not parse</think>".to_string()),
1332+
reasoning_content: None,
13051333
tool_calls: None,
13061334
},
13071335
finish_reason: Some(FireworksFinishReason::Stop),
@@ -1335,6 +1363,7 @@ mod tests {
13351363
choices: vec![FireworksChatChunkChoice {
13361364
delta: FireworksDelta {
13371365
content: None,
1366+
reasoning_content: None,
13381367
tool_calls: Some(vec![FireworksToolCallChunk {
13391368
index: 0,
13401369
id: Some("new_id".to_string()),
@@ -1379,6 +1408,7 @@ mod tests {
13791408
choices: vec![FireworksChatChunkChoice {
13801409
delta: FireworksDelta {
13811410
content: None,
1411+
reasoning_content: None,
13821412
tool_calls: Some(vec![FireworksToolCallChunk {
13831413
index: 0,
13841414
id: None,

tensorzero-core/src/providers/helpers_thinking_block.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ pub enum ThinkingState {
5858
Finished,
5959
}
6060

61+
/// Numeric identifier used for thought chunks.
///
/// The Fireworks provider stringifies this value (`THINK_CHUNK_ID.to_string()`)
/// to fill the `id` of each `ThoughtChunk` it builds from a streamed
/// `reasoning_content` delta.
// NOTE(review): presumably this must line up with the chunk id that
// `ThinkingState::get_id` reports for thinking content — confirm.
pub const THINK_CHUNK_ID: u64 = 1;
62+
6163
impl ThinkingState {
6264
pub fn get_id(&self) -> String {
6365
match self {

tensorzero-core/tests/e2e/providers/fireworks.rs

Lines changed: 88 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,15 @@
1+
#![expect(clippy::print_stderr)]
12
use std::collections::HashMap;
23

3-
use crate::providers::common::{E2ETestProvider, E2ETestProviders};
4+
use http::StatusCode;
5+
use reqwest::Client;
6+
use reqwest_eventsource::{Event, RequestBuilderExt};
7+
use tokio_stream::StreamExt;
8+
9+
use crate::{
10+
common::get_gateway_endpoint,
11+
providers::common::{E2ETestProvider, E2ETestProviders},
12+
};
413

514
crate::generate_provider_tests!(get_providers);
615
crate::generate_batch_inference_tests!(get_providers);
@@ -117,3 +126,81 @@ async fn get_providers() -> E2ETestProviders {
117126
shorthand_inference: shorthand_providers,
118127
}
119128
}
129+
130+
/// Non-streaming end-to-end check for Fireworks `reasoning_content`.
///
/// POSTs a single user message to the gateway's `/inference` endpoint for the
/// `gpt-oss-20b-fireworks` model and asserts that the JSON response's
/// `content` array contains a block whose `type` is `"thought"` and whose
/// `text` is a non-empty string.
///
/// # Panics
///
/// Fails the test if the request cannot be sent, the status is not `200 OK`,
/// the body is not JSON, `content` is not an array, no thought block is
/// present, or the thought's `text` is missing or empty.
#[tokio::test]
131+
async fn test_fireworks_reasoning_content_non_stream() {
132+
// Send the inference request through the gateway.
let response = Client::new()
133+
.post(get_gateway_endpoint("/inference"))
134+
.json(&serde_json::json!({
135+
"model_name": "gpt-oss-20b-fireworks",
136+
"input": {
137+
"messages": [
138+
{
139+
"role": "user",
140+
"content": "What is the capital of France? Think before responding."
141+
}
142+
]
143+
}
144+
}))
145+
.send()
146+
.await
147+
.unwrap();
148+
149+
assert_eq!(response.status(), StatusCode::OK);
150+
let response_body = response.json::<serde_json::Value>().await.unwrap();
151+
152+
// Dump the whole body to stderr so a failing run shows what came back.
eprintln!("API response: {response_body}");
153+
let content = response_body["content"].as_array().unwrap();
154+
// Check that the response contains a thought block
155+
// Only the first block with `type == "thought"` is inspected.
let thought = content.iter().find(|c| c["type"] == "thought").unwrap();
156+
assert!(
157+
!thought["text"].as_str().unwrap().is_empty(),
158+
"Thought block was empty: {thought:?}",
159+
);
160+
}
161+
162+
/// Streaming end-to-end check for Fireworks `reasoning_content`.
///
/// Opens a server-sent-event stream against the gateway's `/inference`
/// endpoint (same prompt and model as the non-streaming test, with
/// `"stream": true`), collects every message payload up to the `[DONE]`
/// sentinel, and asserts that at least one collected chunk has a `"thought"`
/// block as the first element of its `content` array.
///
/// # Panics
///
/// Fails the test if the event source cannot be created, any stream event is
/// an error, a collected chunk is not valid JSON, or no thought chunk is found.
#[tokio::test]
163+
async fn test_fireworks_reasoning_content_stream() {
164+
let mut event_source = Client::new()
165+
.post(get_gateway_endpoint("/inference"))
166+
.json(&serde_json::json!({
167+
"model_name": "gpt-oss-20b-fireworks",
168+
"input": {
169+
"messages": [
170+
{
171+
"role": "user",
172+
"content": "What is the capital of France? Think before responding."
173+
}
174+
]
175+
},
176+
"stream": true,
177+
}))
178+
.eventsource()
179+
.unwrap();
180+
181+
// Raw `data` payloads of every message received before `[DONE]`.
let mut chunks = vec![];
182+
while let Some(event) = event_source.next().await {
183+
// Any stream-level error aborts the test here.
let event = event.unwrap();
184+
match event {
185+
Event::Open => continue,
186+
Event::Message(message) => {
187+
eprintln!("API chunk: {message:?}");
188+
// `[DONE]` is the end-of-stream sentinel; it is not stored as a chunk.
if message.data == "[DONE]" {
189+
break;
190+
}
191+
chunks.push(message.data.to_string());
192+
}
193+
}
194+
}
195+
196+
let mut found_thought = false;
197+
for chunk in chunks {
198+
let chunk_json: serde_json::Value = serde_json::from_str(&chunk).unwrap();
199+
eprintln!("Chunk: {chunk_json}");
200+
// Only the first content block of each chunk is examined; indexing a
// missing key or element on a `serde_json::Value` yields `Null`, so
// chunks without content simply do not match.
if chunk_json["content"][0]["type"] == "thought" {
201+
found_thought = true;
202+
break;
203+
}
204+
}
205+
assert!(found_thought, "Thought chunk not found");
206+
}

tensorzero-core/tests/e2e/tensorzero.toml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,15 @@ extra_body = [
3939
{ pointer = "/frequency_penalty", value = 1.42 },
4040
]
4141

42+
# Model requested by the Fireworks `reasoning_content` e2e tests
# (they send "model_name": "gpt-oss-20b-fireworks"); routed only to Fireworks.
[models."gpt-oss-20b-fireworks"]
43+
routing = ["fireworks"]
44+
45+
[models."gpt-oss-20b-fireworks".providers.fireworks]
46+
type = "fireworks"
47+
model_name = "accounts/fireworks/models/gpt-oss-20b"
48+
# NOTE(review): think-block parsing is off, presumably so thought blocks in
# those tests can only originate from `reasoning_content` — confirm.
parse_think_blocks = false
49+
50+
4251
[models."gpt-4o-mini-2024-07-18-dynamic"]
4352
routing = ["openai"]
4453

0 commit comments

Comments
 (0)