@@ -165,9 +165,9 @@ pub async fn run(chat: Arc<Mutex<Chat<'_>>>, renderer: &mut TerminalRenderer<'_>
165165 let user_messages = vec ! [ ChatMessage {
166166 text: line. to_string( ) ,
167167 sender: SenderType :: User ,
168- tools : vec! [ ] ,
168+ .. Default :: default ( )
169169 } ] ;
170- if !process_message ( chat. clone ( ) , renderer, user_messages, vec ! [ ] ) . await ? {
170+ if !process_message ( chat. clone ( ) , renderer, user_messages) . await ? {
171171 save_history_on_exit ( & mut rl) ?;
172172 return Ok ( ( ) ) ;
173173 }
@@ -206,8 +206,7 @@ fn clean_value(value: &mut serde_json::Value) {
206206async fn process_message (
207207 chat : Arc < Mutex < Chat < ' _ > > > ,
208208 renderer : & mut TerminalRenderer < ' _ > ,
209- user_messages : Vec < ChatMessage > ,
210- tool_messages : Vec < ChatMessage > ,
209+ messages : Vec < ChatMessage > , // User input or tool responses
211210) -> Result < bool > {
212211 let mut metrics = CompletionMetrics :: default ( ) ;
213212 let mut finish_reason: Option < String > = None ;
@@ -222,16 +221,19 @@ async fn process_message(
222221 // Clone for async block
223222 let chat_clone = chat. clone ( ) ;
224223
225- // Child tool messages are created if LLM requires a set of tools to be invoked for responding
226- // to a user message.
227- let mut child_tool_messages: Vec < ChatMessage > = vec ! [ ] ;
224+ // Store tool call responses if LLM requires a set of tools to be invoked for responding to a
225+ // user message.
226+ let mut assistant_message_text = String :: new ( ) ;
227+ let mut assistant_message_tools: Vec < ToolCall > = vec ! [ ] ;
228+ let mut assistant_tool_responses: Vec < ChatMessage > = vec ! [ ] ;
229+
228230 let mut stream_error = false ;
229231 let was_cancelled = {
230232 // Get stream response
231233 let mut chat_guard = chat_clone. lock ( ) . await ;
232234 let available_tools = chat_guard. available_tools . clone ( ) ;
233235 let mut stream = {
234- chat_guard. add_messages ( user_messages , tool_messages ) . await ;
236+ chat_guard. add_messages ( messages ) . await ;
235237 chat_guard. stream_response ( cancel_token. clone ( ) ) . await ?
236238 } ;
237239
@@ -268,6 +270,7 @@ async fn process_message(
268270 match response {
269271 Ok ( Completion :: Response ( chunk) ) => {
270272 if !& chunk. text. is_empty( ) {
273+ assistant_message_text. push_str( & chunk. text) ;
271274 renderer. render_markdown( & chunk. text) ?;
272275 }
273276
@@ -277,8 +280,8 @@ async fn process_message(
277280
278281 // Tool messages can come in chunks, we collate all
279282 if let Some ( tools) = & chunk. tool_calls {
280- child_tool_messages =
281- process_tools( & available_tools, tools) . await ?;
283+ assistant_message_tools . extend ( tools . clone ( ) ) ;
284+ assistant_tool_responses = process_tools( & available_tools, tools) . await ?;
282285 }
283286 }
284287 Ok ( Completion :: Metrics ( m) ) => {
@@ -309,16 +312,28 @@ async fn process_message(
309312 return Ok ( true ) ;
310313 }
311314
315+ {
316+ // Add the assistant message to the chat history
317+ let mut chat_guard = chat. lock ( ) . await ;
318+ chat_guard
319+ . add_messages ( vec ! [ ChatMessage {
320+ sender: SenderType :: Assistant ,
321+ text: assistant_message_text,
322+ tools: Some ( assistant_message_tools) ,
323+ metrics: Some ( metrics. clone( ) ) ,
324+ } ] )
325+ . await ;
326+ }
327+
312328 // After a successful stream, flush any remaining partial lines from the renderer.
313329 renderer. render_markdown ( "\n " ) ?;
314330
315331 // If the model produced tool calls, recursively call this function to process them.
316- if !child_tool_messages . is_empty ( ) {
332+ if !assistant_tool_responses . is_empty ( ) {
317333 return Box :: pin ( process_message (
318334 chat_clone,
319335 renderer,
320- vec ! [ ] ,
321- child_tool_messages,
336+ assistant_tool_responses,
322337 ) )
323338 . await ;
324339 }
@@ -448,7 +463,7 @@ async fn process_tools(
448463 tool_messages. push ( ChatMessage {
449464 text : serde_json:: to_string ( & result) ?,
450465 sender : SenderType :: Tool ,
451- tools : vec ! [ ] ,
466+ .. Default :: default ( )
452467 } ) ;
453468 }
454469
@@ -647,13 +662,14 @@ mod tests {
647662 let user_message = ChatMessage {
648663 sender : SenderType :: User ,
649664 text : "Hi" . to_string ( ) ,
650- tools : vec ! [ ] ,
665+ .. Default :: default ( )
651666 } ;
652- process_message ( chat_session, & mut renderer, vec ! [ user_message] , vec ! [ ] ) . await ?;
667+ process_message ( chat_session. clone ( ) , & mut renderer, vec ! [ user_message] ) . await ?;
653668
654669 // 3. Assert rendered output
655670 let output = String :: from_utf8 ( buffer) . unwrap ( ) ;
656671 assert ! ( output. contains( "Hello world" ) ) ;
672+ assert_eq ! ( 2 , chat_session. lock( ) . await . get_all_messages( ) . len( ) ) ;
657673 Ok ( ( ) )
658674 }
659675
@@ -678,12 +694,13 @@ mod tests {
678694 let user_message = ChatMessage {
679695 sender : SenderType :: User ,
680696 text : "Hi" . to_string ( ) ,
681- tools : vec ! [ ] ,
697+ .. Default :: default ( )
682698 } ;
683- process_message ( chat_session, & mut renderer, vec ! [ user_message] , vec ! [ ] ) . await ?;
699+ process_message ( chat_session. clone ( ) , & mut renderer, vec ! [ user_message] ) . await ?;
684700
685701 let output = String :: from_utf8 ( buffer) . unwrap ( ) ;
686702 assert ! ( output. contains( "Tool output is mock tool output" ) ) ;
703+ assert_eq ! ( 4 , chat_session. lock( ) . await . get_all_messages( ) . len( ) ) ;
687704 Ok ( ( ) )
688705 }
689706
@@ -699,9 +716,9 @@ mod tests {
699716 let user_message = ChatMessage {
700717 sender : SenderType :: User ,
701718 text : "Hi" . to_string ( ) ,
702- tools : vec ! [ ] ,
719+ .. Default :: default ( )
703720 } ;
704- process_message ( chat_session, & mut renderer, vec ! [ user_message] , vec ! [ ] ) . await ?;
721+ process_message ( chat_session. clone ( ) , & mut renderer, vec ! [ user_message] ) . await ?;
705722
706723 // Expect no output to renderer, error is printed to stderr
707724 let output = String :: from_utf8 ( buffer) . unwrap ( ) ;
@@ -710,6 +727,7 @@ mod tests {
710727 "Output should be empty. Output: {}" ,
711728 output
712729 ) ;
730+ assert_eq ! ( 1 , chat_session. lock( ) . await . get_all_messages( ) . len( ) ) ;
713731 Ok ( ( ) )
714732 }
715733
0 commit comments