Skip to content

Commit 122d6c8

Browse files
authored
fix: capture assistant message in conversation history (#79)
Add assistant messages to the conversation history. Add a heuristic to trim tool responses when the context size is exceeded. Always preserve the last user message block.
1 parent 3359476 commit 122d6c8

12 files changed

Lines changed: 337 additions & 279 deletions

File tree

crates/arey/src/cli/chat/commands.rs

Lines changed: 25 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -431,9 +431,11 @@ fn format_message_block(messages: &[arey_core::completion::ChatMessage]) -> Resu
431431
out.push_str(&format!("{} {}\n", sender_tag, content));
432432

433433
// Show tool calls if any
434-
if !msg.tools.is_empty() {
434+
if let Some(tools) = msg.tools.as_ref()
435+
&& !tools.is_empty()
436+
{
435437
out.push_str(" Tools:\n");
436-
for tool in &msg.tools {
438+
for tool in tools {
437439
out.push_str(&format!(" - {}: {}\n", tool.name, tool.arguments));
438440
}
439441
}
@@ -511,7 +513,7 @@ mod tests {
511513
let messages = vec![ChatMessage {
512514
sender: SenderType::User,
513515
text: "Test".to_string(),
514-
tools: vec![],
516+
..Default::default()
515517
}];
516518
let result = format_message_block(&messages).unwrap();
517519
let expected = r#"
@@ -528,22 +530,24 @@ USER: Test
528530
ChatMessage {
529531
sender: SenderType::User,
530532
text: "First".to_string(),
531-
tools: vec![],
533+
..Default::default()
532534
},
533535
ChatMessage {
534536
sender: SenderType::Assistant,
535537
text: "First Response".to_string(),
536-
tools: vec![],
538+
metrics: Some(Default::default()),
539+
..Default::default()
537540
},
538541
ChatMessage {
539542
sender: SenderType::User,
540543
text: "Second".to_string(),
541-
tools: vec![],
544+
..Default::default()
542545
},
543546
ChatMessage {
544547
sender: SenderType::Assistant,
545548
text: "Second Response".to_string(),
546-
tools: vec![],
549+
metrics: Some(Default::default()),
550+
..Default::default()
547551
},
548552
];
549553
let result = format_message_block(&messages).unwrap();
@@ -563,7 +567,7 @@ ASSISTANT: Second Response
563567
let messages = vec![ChatMessage {
564568
sender: SenderType::User,
565569
text: long_text,
566-
tools: vec![],
570+
..Default::default()
567571
}];
568572
let result = format_message_block(&messages).unwrap();
569573
let truncated_part = "a".repeat(500) + "\n... [truncated]";
@@ -576,11 +580,12 @@ ASSISTANT: Second Response
576580
let messages = vec![ChatMessage {
577581
sender: SenderType::User,
578582
text: "Run tool".to_string(),
579-
tools: vec![ToolCall {
583+
tools: Some(vec![ToolCall {
580584
id: "id1".to_string(),
581585
name: "tool1".to_string(),
582586
arguments: "{\"arg\":1}".to_string(),
583-
}],
587+
}]),
588+
..Default::default()
584589
}];
585590
let result = format_message_block(&messages).unwrap();
586591
let expected = r#"
@@ -797,14 +802,11 @@ USER: Run tool
797802
chat_session
798803
.lock()
799804
.await
800-
.add_messages(
801-
vec![ChatMessage {
802-
sender: SenderType::User,
803-
text: "hello".to_string(),
804-
tools: vec![],
805-
}],
806-
vec![],
807-
)
805+
.add_messages(vec![ChatMessage {
806+
sender: SenderType::User,
807+
text: "hello".to_string(),
808+
..Default::default()
809+
}])
808810
.await;
809811
assert!(!chat_session.lock().await.get_all_messages().is_empty());
810812

@@ -826,14 +828,11 @@ USER: Run tool
826828
chat_session
827829
.lock()
828830
.await
829-
.add_messages(
830-
vec![ChatMessage {
831-
sender: SenderType::User,
832-
text: "log this".to_string(),
833-
tools: vec![],
834-
}],
835-
vec![],
836-
)
831+
.add_messages(vec![ChatMessage {
832+
sender: SenderType::User,
833+
text: "log this".to_string(),
834+
..Default::default()
835+
}])
837836
.await;
838837

839838
// Test log command

crates/arey/src/cli/chat/repl.rs

Lines changed: 38 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -165,9 +165,9 @@ pub async fn run(chat: Arc<Mutex<Chat<'_>>>, renderer: &mut TerminalRenderer<'_>
165165
let user_messages = vec![ChatMessage {
166166
text: line.to_string(),
167167
sender: SenderType::User,
168-
tools: vec![],
168+
..Default::default()
169169
}];
170-
if !process_message(chat.clone(), renderer, user_messages, vec![]).await? {
170+
if !process_message(chat.clone(), renderer, user_messages).await? {
171171
save_history_on_exit(&mut rl)?;
172172
return Ok(());
173173
}
@@ -206,8 +206,7 @@ fn clean_value(value: &mut serde_json::Value) {
206206
async fn process_message(
207207
chat: Arc<Mutex<Chat<'_>>>,
208208
renderer: &mut TerminalRenderer<'_>,
209-
user_messages: Vec<ChatMessage>,
210-
tool_messages: Vec<ChatMessage>,
209+
messages: Vec<ChatMessage>, // User input or tool responses
211210
) -> Result<bool> {
212211
let mut metrics = CompletionMetrics::default();
213212
let mut finish_reason: Option<String> = None;
@@ -222,16 +221,19 @@ async fn process_message(
222221
// Clone for async block
223222
let chat_clone = chat.clone();
224223

225-
// Child tool messages are created if LLM requires a set of tools to be invoked for responding
226-
// to a user message.
227-
let mut child_tool_messages: Vec<ChatMessage> = vec![];
224+
// Store tool call responses if LLM requires a set of tools to be invoked for responding to a
225+
// user message.
226+
let mut assistant_message_text = String::new();
227+
let mut assistant_message_tools: Vec<ToolCall> = vec![];
228+
let mut assistant_tool_responses: Vec<ChatMessage> = vec![];
229+
228230
let mut stream_error = false;
229231
let was_cancelled = {
230232
// Get stream response
231233
let mut chat_guard = chat_clone.lock().await;
232234
let available_tools = chat_guard.available_tools.clone();
233235
let mut stream = {
234-
chat_guard.add_messages(user_messages, tool_messages).await;
236+
chat_guard.add_messages(messages).await;
235237
chat_guard.stream_response(cancel_token.clone()).await?
236238
};
237239

@@ -268,6 +270,7 @@ async fn process_message(
268270
match response {
269271
Ok(Completion::Response(chunk)) => {
270272
if !&chunk.text.is_empty() {
273+
assistant_message_text.push_str(&chunk.text);
271274
renderer.render_markdown(&chunk.text)?;
272275
}
273276

@@ -277,8 +280,8 @@ async fn process_message(
277280

278281
// Tool messages can come in chunks, we collate all
279282
if let Some(tools) = &chunk.tool_calls {
280-
child_tool_messages =
281-
process_tools(&available_tools, tools).await?;
283+
assistant_message_tools.extend(tools.clone());
284+
assistant_tool_responses = process_tools(&available_tools, tools).await?;
282285
}
283286
}
284287
Ok(Completion::Metrics(m)) => {
@@ -309,16 +312,28 @@ async fn process_message(
309312
return Ok(true);
310313
}
311314

315+
{
316+
// Add the assistant message to the chat history
317+
let mut chat_guard = chat.lock().await;
318+
chat_guard
319+
.add_messages(vec![ChatMessage {
320+
sender: SenderType::Assistant,
321+
text: assistant_message_text,
322+
tools: Some(assistant_message_tools),
323+
metrics: Some(metrics.clone()),
324+
}])
325+
.await;
326+
}
327+
312328
// After a successful stream, flush any remaining partial lines from the renderer.
313329
renderer.render_markdown("\n")?;
314330

315331
// If the model produced tool calls, recursively call this function to process them.
316-
if !child_tool_messages.is_empty() {
332+
if !assistant_tool_responses.is_empty() {
317333
return Box::pin(process_message(
318334
chat_clone,
319335
renderer,
320-
vec![],
321-
child_tool_messages,
336+
assistant_tool_responses,
322337
))
323338
.await;
324339
}
@@ -448,7 +463,7 @@ async fn process_tools(
448463
tool_messages.push(ChatMessage {
449464
text: serde_json::to_string(&result)?,
450465
sender: SenderType::Tool,
451-
tools: vec![],
466+
..Default::default()
452467
});
453468
}
454469

@@ -647,13 +662,14 @@ mod tests {
647662
let user_message = ChatMessage {
648663
sender: SenderType::User,
649664
text: "Hi".to_string(),
650-
tools: vec![],
665+
..Default::default()
651666
};
652-
process_message(chat_session, &mut renderer, vec![user_message], vec![]).await?;
667+
process_message(chat_session.clone(), &mut renderer, vec![user_message]).await?;
653668

654669
// 3. Assert rendered output
655670
let output = String::from_utf8(buffer).unwrap();
656671
assert!(output.contains("Hello world"));
672+
assert_eq!(2, chat_session.lock().await.get_all_messages().len());
657673
Ok(())
658674
}
659675

@@ -678,12 +694,13 @@ mod tests {
678694
let user_message = ChatMessage {
679695
sender: SenderType::User,
680696
text: "Hi".to_string(),
681-
tools: vec![],
697+
..Default::default()
682698
};
683-
process_message(chat_session, &mut renderer, vec![user_message], vec![]).await?;
699+
process_message(chat_session.clone(), &mut renderer, vec![user_message]).await?;
684700

685701
let output = String::from_utf8(buffer).unwrap();
686702
assert!(output.contains("Tool output is mock tool output"));
703+
assert_eq!(4, chat_session.lock().await.get_all_messages().len());
687704
Ok(())
688705
}
689706

@@ -699,9 +716,9 @@ mod tests {
699716
let user_message = ChatMessage {
700717
sender: SenderType::User,
701718
text: "Hi".to_string(),
702-
tools: vec![],
719+
..Default::default()
703720
};
704-
process_message(chat_session, &mut renderer, vec![user_message], vec![]).await?;
721+
process_message(chat_session.clone(), &mut renderer, vec![user_message]).await?;
705722

706723
// Expect no output to renderer, error is printed to stderr
707724
let output = String::from_utf8(buffer).unwrap();
@@ -710,6 +727,7 @@ mod tests {
710727
"Output should be empty. Output: {}",
711728
output
712729
);
730+
assert_eq!(1, chat_session.lock().await.get_all_messages().len());
713731
Ok(())
714732
}
715733

crates/arey/src/cli/mod.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,11 @@ enum Commands {
5656

5757
/// Runs the main CLI application.
5858
pub async fn run() -> Result<()> {
59-
let cli = Cli::parse();
60-
// let cli = Cli {
61-
// command: Commands::Chat { model: None },
62-
// verbose: true,
63-
// };
59+
// let cli = Cli::parse();
60+
let cli = Cli {
61+
command: Commands::Chat { model: None },
62+
verbose: true,
63+
};
6464

6565
if cli.verbose {
6666
setup_logging().context("Failed to set up logging")?;

0 commit comments

Comments
 (0)