@@ -230,6 +230,7 @@ export default function MessageStream({ threadId }: MessageStreamProps) {
230230 const { copiedId, copy } = useCopyToClipboard ( ) ;
231231 const parentRef = useRef < HTMLDivElement > ( null ) ;
232232 const didScrollToEndRef = useRef ( false ) ;
233+ const loadedCheckpointRef = useRef < string | null > ( null ) ;
233234 const markFileRestored = useGlobalStore ( ( s ) => s . markFileRestored ) ;
234235 const setPendingInput = useGlobalStore ( ( s ) => s . setPendingInput ) ;
235236
@@ -341,7 +342,7 @@ export default function MessageStream({ threadId }: MessageStreamProps) {
341342 ( i ) => i . type === 'message' && ( i as any ) . role === 'user'
342343 ) ;
343344 const userContent = userMsg && 'content' in userMsg ? ( userMsg as any ) . content : '' ;
344- const newSessionId = await forkThread ( threadId , lastItem . id ) ;
345+ const newSessionId = await forkThread ( threadId , Number ( turn . id ) ) ;
345346 if ( newSessionId ) {
346347 setCurrentThread ( newSessionId ) ;
347348 if ( userContent ) setPendingInput ( userContent ) ;
@@ -363,16 +364,23 @@ export default function MessageStream({ threadId }: MessageStreamProps) {
363364 setPendingInput ,
364365 ] ) ;
365366
367+ const getItemKey = useCallback (
368+ ( index : number ) => renderEntries [ index ] ?. key ?? `empty-${ index } ` ,
369+ [ renderEntries ]
370+ ) ;
371+
372+ const getScrollElement = useCallback ( ( ) => parentRef . current , [ ] ) ;
373+
366374 const virtualizer = useVirtualizer ( {
367375 count : renderEntries . length ,
368- getScrollElement : ( ) => parentRef . current ,
369- estimateSize : ( ) => 60 ,
370- getItemKey : ( index : number ) => renderEntries [ index ] ?. key ?? `empty- ${ index } ` ,
376+ getScrollElement,
377+ estimateSize : useCallback ( ( ) => 60 , [ ] ) ,
378+ getItemKey,
371379 overscan : 5 ,
372380 anchorTo : 'end' ,
373381 followOnAppend : 'smooth' ,
374382 scrollEndThreshold : 80 ,
375- initialOffset : ( ) => Number . MAX_SAFE_INTEGER ,
383+ initialOffset : useCallback ( ( ) => Number . MAX_SAFE_INTEGER , [ ] ) ,
376384 } ) ;
377385
378386 useLayoutEffect ( ( ) => {
@@ -384,34 +392,31 @@ export default function MessageStream({ threadId }: MessageStreamProps) {
384392
385393 const turnStatusKey = useMemo ( ( ) => turns . map ( ( t ) => `${ t . id } :${ t . status } ` ) . join ( ',' ) , [ turns ] ) ;
386394
387- const handleLoadDiff = useCallback (
388- async ( uiTurnId : string ) => {
389- const diff = await loadCheckpointDiff ( threadId ) ;
390- if ( diff . turnId > 0 ) {
391- const state = useGlobalStore . getState ( ) ;
392- const mapping = state . rollback . turnCheckpointMapping [ threadId ] ;
393- if ( mapping ?. [ diff . turnId ] !== uiTurnId ) {
394- state . setTurnCheckpointMapping ( threadId , diff . turnId , uiTurnId ) ;
395- }
396- }
397- } ,
398- [ threadId , loadCheckpointDiff ]
399- ) ;
395+ useEffect ( ( ) => {
396+ loadedCheckpointRef . current = null ;
397+ } , [ threadId ] ) ;
400398
401399 useEffect ( ( ) => {
402- for ( const turn of turns ) {
403- if ( turn . status !== 'completed' && turn . status !== 'error' ) continue ;
404- const ckKey = getCheckpointKey (
405- threadId ,
406- turn . id ,
407- useGlobalStore . getState ( ) . rollback . checkpointDiffByTurnId ,
408- useGlobalStore . getState ( ) . rollback . turnCheckpointMapping [ threadId ] ?? EMPTY_MAPPING
409- ) ;
410- if ( ! ckKey ) {
411- handleLoadDiff ( turn . id ) ;
412- }
413- }
414- } , [ turnStatusKey , threadId , handleLoadDiff ] ) ;
400+ const completedTurnIds = turns
401+ . filter ( ( t ) => t . status === 'completed' || t . status === 'error' )
402+ . map ( ( t ) => t . id ) ;
403+ if ( completedTurnIds . length === 0 ) return ;
404+
405+ const loadKey = `${ threadId } :${ completedTurnIds . join ( ',' ) } ` ;
406+ if ( loadedCheckpointRef . current === loadKey ) return ;
407+ loadedCheckpointRef . current = loadKey ;
408+
409+ const state = useGlobalStore . getState ( ) ;
410+ const existingMapping = state . rollback . turnCheckpointMapping [ threadId ] ?? EMPTY_MAPPING ;
411+ const existingDiffs = state . rollback . checkpointDiffByTurnId ;
412+
413+ const alreadyLoaded = completedTurnIds . some ( ( id ) =>
414+ getCheckpointKey ( threadId , id , existingDiffs , existingMapping ) !== null
415+ ) ;
416+ if ( alreadyLoaded ) return ;
417+
418+ loadCheckpointDiff ( threadId ) ;
419+ } , [ turnStatusKey , threadId , loadCheckpointDiff ] ) ;
415420
416421 const handleRevertFile = useCallback (
417422 async ( uiTurnId : string , file : string , isReverted : boolean ) => {
0 commit comments