Skip to content

Commit 3ec69fa

Browse files
authored
Merge pull request #96 from phantom5099/memory
refactor: Refactor the conversation branching feature, using turnId i…
2 parents aee7273 + f0d73e9 commit 3ec69fa

18 files changed

Lines changed: 176 additions & 196 deletions

packages/codingcode/src/client/direct.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -336,12 +336,12 @@ export async function createDirectClient(llm: LLMClient, rt: AppRuntime): Promis
336336
return clients.sessions.getRollbackState({ sessionId: currentSessionId, cwd: cwd() });
337337
},
338338

339-
async forkSession(atUuid?: string) {
339+
async forkSession(atTurnId?: number) {
340340
if (!currentSessionId) return '';
341341
const result = await clients.sessions.forkSession({
342342
sessionId: currentSessionId,
343343
cwd: cwd(),
344-
atUuid,
344+
atTurnId,
345345
});
346346
return result.sessionId;
347347
},

packages/codingcode/src/client/direct/sessions.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ export interface SessionClient {
6565
forkSession(input: {
6666
sessionId: string;
6767
cwd: string;
68-
atUuid?: string;
68+
atTurnId?: number;
6969
}): Promise<{ sessionId: string; turns: SessionEvent[] }>;
7070
}
7171

@@ -221,13 +221,13 @@ export function createDirectSessionClient(rt: AppRuntime): SessionClient {
221221
code: { canUndoLast: false, lastEntry: null, revertedFiles: [], lastEntryId: null },
222222
};
223223
},
224-
async forkSession({ sessionId, atUuid }) {
224+
async forkSession({ sessionId, atTurnId }) {
225225
const cwd = await getWorkspaceCwd(rt);
226226
const newSessionId = await rt.runPromise(
227227
Effect.gen(function* () {
228228
const session = yield* SessionService;
229229
const state = yield* session.create(cwd, 'unknown', sessionId);
230-
return yield* session.forkSession(state, atUuid ?? '');
230+
return yield* session.forkSession(state, atTurnId ?? 0);
231231
})
232232
);
233233
return { sessionId: newSessionId, turns: [] as SessionEvent[] };

packages/codingcode/src/client/http.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ export async function createHttpClient(serverUrl: string): Promise<AgentClient>
210210
code: { canUndoLast: false, lastEntry: null, revertedFiles: [], lastEntryId: null },
211211
};
212212
},
213-
async forkSession() {
213+
async forkSession(_atTurnId?: number) {
214214
return '';
215215
},
216216

packages/codingcode/src/client/http/sessions.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ export interface SessionClient {
6161
forkSession(input: {
6262
sessionId: string;
6363
cwd: string;
64-
atUuid?: string;
64+
atTurnId?: number;
6565
}): Promise<{ sessionId: string; turns: SessionEvent[] }>;
6666
}
6767

@@ -140,8 +140,8 @@ export function createHttpSessionClient(
140140
return apiGet(`/api/sessions/${sessionId}/rollback-state?cwd=${encodeURIComponent(cwd)}`);
141141
},
142142

143-
async forkSession({ sessionId, cwd, atUuid }) {
144-
return apiPost(`/api/sessions/${sessionId}/fork`, { cwd, atUuid });
143+
async forkSession({ sessionId, cwd, atTurnId }) {
144+
return apiPost(`/api/sessions/${sessionId}/fork`, { cwd, atTurnId });
145145
},
146146
};
147147
}

packages/codingcode/src/client/types.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ export interface AgentClient {
5050
}>;
5151
undoLastCodeRollback(force?: boolean, files?: string[]): Promise<CodeRollbackUndoResult>;
5252
getRollbackState(): Promise<RollbackState>;
53-
forkSession(atUuid?: string): Promise<string>;
53+
forkSession(atTurnId?: number): Promise<string>;
5454
compact(): Promise<void>;
5555
getMemoryEnabled(): Promise<boolean>;
5656
setMemoryEnabled(enabled: boolean): Promise<void>;

packages/codingcode/src/server/routes/sessions.ts

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -454,18 +454,19 @@ export function createSessionsRouter(rt: ManagedRt): Hono {
454454

455455
router.post('/:id/fork', async (c) => {
456456
const sessionId = c.req.param('id');
457-
const body = (await c.req.json()) as { cwd: string; atUuid?: string };
457+
const body = (await c.req.json()) as { cwd: string; atTurnId?: number };
458458
const cwd = await rt.runPromise(
459459
Effect.gen(function* () {
460460
const ws = yield* WorkspaceService;
461461
return ws.resolveWorkspaceCwd(body.cwd);
462462
})
463463
);
464+
const atTurnId = body.atTurnId ?? 0;
464465
const result = await runWithLayer(
465466
Effect.gen(function* () {
466467
const session = yield* SessionService;
467468
const state = yield* session.create(cwd, 'unknown', sessionId);
468-
const newSessionId = yield* session.forkSession(state, body.atUuid ?? '');
469+
const newSessionId = yield* session.forkSession(state, atTurnId);
469470
const turns = readUIHistory(newSessionId);
470471
return { sessionId: newSessionId, turns };
471472
}) as any

packages/codingcode/src/session/store.ts

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -350,10 +350,10 @@ export class SessionService extends Effect.Service<SessionService>()('Session',
350350

351351
const forkSession = (
352352
state: SessionStoreState,
353-
atUuid: string
353+
atTurnId: number
354354
): Effect.Effect<string, AgentError> =>
355355
Effect.sync(() => {
356-
return forkSessionImpl(state.sessionId, state.transcriptPath, atUuid);
356+
return forkSessionImpl(state.sessionId, state.transcriptPath, atTurnId);
357357
});
358358

359359
const renameSession = (
@@ -485,9 +485,11 @@ function initState(cwd: string, sessionId?: string, parentSessionId?: string): S
485485
};
486486
}
487487

488-
function forkSessionImpl(sourceSessionId: string, sourceJsonlPath: string, atUuid: string): string {
488+
function forkSessionImpl(sourceSessionId: string, sourceJsonlPath: string, atTurnId: number): string {
489489
const events = readHistory(sourceJsonlPath);
490-
const atIdx = atUuid ? events.findIndex((e) => 'uuid' in e && (e as any).uuid === atUuid) : -1;
490+
const atIdx = events.findIndex(
491+
(e) => e.type === 'user' && (e as any).turnId === atTurnId
492+
);
491493

492494
const chain = atIdx >= 0 ? events.slice(0, atIdx + 1) : events;
493495
const newSessionId = randomUUID();

packages/codingcode/test/session/fork.test.ts

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ function run<T>(eff: Effect.Effect<T, any, any>): Promise<T> {
102102
}
103103

104104
describe('forkSession', () => {
105-
it('fork copies events from root to atUuid', async () => {
105+
it('fork copies events from root to atTurnId', async () => {
106106
const sessionId = randomUUID();
107107
const slug = randomUUID();
108108
const fx = makeFixture(sessionId, slug);
@@ -122,11 +122,11 @@ describe('forkSession', () => {
122122
memorySnapshot: '',
123123
};
124124

125-
// Fork at u2 (turn 2 start)
125+
// Fork at turn 2 (user message "second")
126126
const newSessionId = await run(
127127
Effect.gen(function* () {
128128
const svc = yield* SessionService;
129-
return yield* svc.forkSession(state, 'u2');
129+
return yield* svc.forkSession(state, 2);
130130
})
131131
);
132132

@@ -169,7 +169,7 @@ describe('forkSession', () => {
169169
const newSessionId = await run(
170170
Effect.gen(function* () {
171171
const svc = yield* SessionService;
172-
return yield* svc.forkSession(state, 'u2');
172+
return yield* svc.forkSession(state, 2);
173173
})
174174
);
175175

@@ -214,7 +214,7 @@ describe('forkSession', () => {
214214
const newSessionId = await run(
215215
Effect.gen(function* () {
216216
const svc = yield* SessionService;
217-
return yield* svc.forkSession(state, 'u2');
217+
return yield* svc.forkSession(state, 2);
218218
})
219219
);
220220

@@ -277,7 +277,7 @@ describe('forkSession', () => {
277277
const newSessionId = await run(
278278
Effect.gen(function* () {
279279
const svc = yield* SessionService;
280-
return yield* svc.forkSession(state, 'a1');
280+
return yield* svc.forkSession(state, 1);
281281
})
282282
);
283283

packages/codingcode/test/session/prompt-estimate.test.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ describe('promptEstimate', () => {
219219
const newSessionId = await run(
220220
Effect.gen(function* () {
221221
const svc = yield* SessionService;
222-
return yield* svc.forkSession(state, 'a1');
222+
return yield* svc.forkSession(state, 2);
223223
})
224224
);
225225
const newIndexPath = join(fx.dir, `${newSessionId}.index.json`);
@@ -253,7 +253,7 @@ describe('promptEstimate', () => {
253253
const newSessionId = await run(
254254
Effect.gen(function* () {
255255
const svc = yield* SessionService;
256-
return yield* svc.forkSession(state, 'u2');
256+
return yield* svc.forkSession(state, 2);
257257
})
258258
);
259259
const newIndexPath = join(fx.dir, `${newSessionId}.index.json`);

packages/desktop/src/agent/MessageStream.tsx

Lines changed: 36 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,7 @@ export default function MessageStream({ threadId }: MessageStreamProps) {
230230
const { copiedId, copy } = useCopyToClipboard();
231231
const parentRef = useRef<HTMLDivElement>(null);
232232
const didScrollToEndRef = useRef(false);
233+
const loadedCheckpointRef = useRef<string | null>(null);
233234
const markFileRestored = useGlobalStore((s) => s.markFileRestored);
234235
const setPendingInput = useGlobalStore((s) => s.setPendingInput);
235236

@@ -341,7 +342,7 @@ export default function MessageStream({ threadId }: MessageStreamProps) {
341342
(i) => i.type === 'message' && (i as any).role === 'user'
342343
);
343344
const userContent = userMsg && 'content' in userMsg ? (userMsg as any).content : '';
344-
const newSessionId = await forkThread(threadId, lastItem.id);
345+
const newSessionId = await forkThread(threadId, Number(turn.id));
345346
if (newSessionId) {
346347
setCurrentThread(newSessionId);
347348
if (userContent) setPendingInput(userContent);
@@ -363,16 +364,23 @@ export default function MessageStream({ threadId }: MessageStreamProps) {
363364
setPendingInput,
364365
]);
365366

367+
const getItemKey = useCallback(
368+
(index: number) => renderEntries[index]?.key ?? `empty-${index}`,
369+
[renderEntries]
370+
);
371+
372+
const getScrollElement = useCallback(() => parentRef.current, []);
373+
366374
const virtualizer = useVirtualizer({
367375
count: renderEntries.length,
368-
getScrollElement: () => parentRef.current,
369-
estimateSize: () => 60,
370-
getItemKey: (index: number) => renderEntries[index]?.key ?? `empty-${index}`,
376+
getScrollElement,
377+
estimateSize: useCallback(() => 60, []),
378+
getItemKey,
371379
overscan: 5,
372380
anchorTo: 'end',
373381
followOnAppend: 'smooth',
374382
scrollEndThreshold: 80,
375-
initialOffset: () => Number.MAX_SAFE_INTEGER,
383+
initialOffset: useCallback(() => Number.MAX_SAFE_INTEGER, []),
376384
});
377385

378386
useLayoutEffect(() => {
@@ -384,34 +392,31 @@ export default function MessageStream({ threadId }: MessageStreamProps) {
384392

385393
const turnStatusKey = useMemo(() => turns.map((t) => `${t.id}:${t.status}`).join(','), [turns]);
386394

387-
const handleLoadDiff = useCallback(
388-
async (uiTurnId: string) => {
389-
const diff = await loadCheckpointDiff(threadId);
390-
if (diff.turnId > 0) {
391-
const state = useGlobalStore.getState();
392-
const mapping = state.rollback.turnCheckpointMapping[threadId];
393-
if (mapping?.[diff.turnId] !== uiTurnId) {
394-
state.setTurnCheckpointMapping(threadId, diff.turnId, uiTurnId);
395-
}
396-
}
397-
},
398-
[threadId, loadCheckpointDiff]
399-
);
395+
useEffect(() => {
396+
loadedCheckpointRef.current = null;
397+
}, [threadId]);
400398

401399
useEffect(() => {
402-
for (const turn of turns) {
403-
if (turn.status !== 'completed' && turn.status !== 'error') continue;
404-
const ckKey = getCheckpointKey(
405-
threadId,
406-
turn.id,
407-
useGlobalStore.getState().rollback.checkpointDiffByTurnId,
408-
useGlobalStore.getState().rollback.turnCheckpointMapping[threadId] ?? EMPTY_MAPPING
409-
);
410-
if (!ckKey) {
411-
handleLoadDiff(turn.id);
412-
}
413-
}
414-
}, [turnStatusKey, threadId, handleLoadDiff]);
400+
const completedTurnIds = turns
401+
.filter((t) => t.status === 'completed' || t.status === 'error')
402+
.map((t) => t.id);
403+
if (completedTurnIds.length === 0) return;
404+
405+
const loadKey = `${threadId}:${completedTurnIds.join(',')}`;
406+
if (loadedCheckpointRef.current === loadKey) return;
407+
loadedCheckpointRef.current = loadKey;
408+
409+
const state = useGlobalStore.getState();
410+
const existingMapping = state.rollback.turnCheckpointMapping[threadId] ?? EMPTY_MAPPING;
411+
const existingDiffs = state.rollback.checkpointDiffByTurnId;
412+
413+
const alreadyLoaded = completedTurnIds.some((id) =>
414+
getCheckpointKey(threadId, id, existingDiffs, existingMapping) !== null
415+
);
416+
if (alreadyLoaded) return;
417+
418+
loadCheckpointDiff(threadId);
419+
}, [turnStatusKey, threadId, loadCheckpointDiff]);
415420

416421
const handleRevertFile = useCallback(
417422
async (uiTurnId: string, file: string, isReverted: boolean) => {

0 commit comments

Comments
 (0)