Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions packages/codingcode/src/client/direct.ts
Original file line number Diff line number Diff line change
Expand Up @@ -336,12 +336,12 @@ export async function createDirectClient(llm: LLMClient, rt: AppRuntime): Promis
return clients.sessions.getRollbackState({ sessionId: currentSessionId, cwd: cwd() });
},

async forkSession(atUuid?: string) {
async forkSession(atTurnId?: number) {
if (!currentSessionId) return '';
const result = await clients.sessions.forkSession({
sessionId: currentSessionId,
cwd: cwd(),
atUuid,
atTurnId,
});
return result.sessionId;
},
Expand Down
6 changes: 3 additions & 3 deletions packages/codingcode/src/client/direct/sessions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ export interface SessionClient {
forkSession(input: {
sessionId: string;
cwd: string;
atUuid?: string;
atTurnId?: number;
}): Promise<{ sessionId: string; turns: SessionEvent[] }>;
}

Expand Down Expand Up @@ -221,13 +221,13 @@ export function createDirectSessionClient(rt: AppRuntime): SessionClient {
code: { canUndoLast: false, lastEntry: null, revertedFiles: [], lastEntryId: null },
};
},
async forkSession({ sessionId, atUuid }) {
async forkSession({ sessionId, atTurnId }) {
const cwd = await getWorkspaceCwd(rt);
const newSessionId = await rt.runPromise(
Effect.gen(function* () {
const session = yield* SessionService;
const state = yield* session.create(cwd, 'unknown', sessionId);
return yield* session.forkSession(state, atUuid ?? '');
return yield* session.forkSession(state, atTurnId ?? 0);
})
);
return { sessionId: newSessionId, turns: [] as SessionEvent[] };
Expand Down
2 changes: 1 addition & 1 deletion packages/codingcode/src/client/http.ts
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ export async function createHttpClient(serverUrl: string): Promise<AgentClient>
code: { canUndoLast: false, lastEntry: null, revertedFiles: [], lastEntryId: null },
};
},
async forkSession() {
async forkSession(_atTurnId?: number) {
return '';
},

Expand Down
6 changes: 3 additions & 3 deletions packages/codingcode/src/client/http/sessions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ export interface SessionClient {
forkSession(input: {
sessionId: string;
cwd: string;
atUuid?: string;
atTurnId?: number;
}): Promise<{ sessionId: string; turns: SessionEvent[] }>;
}

Expand Down Expand Up @@ -140,8 +140,8 @@ export function createHttpSessionClient(
return apiGet(`/api/sessions/${sessionId}/rollback-state?cwd=${encodeURIComponent(cwd)}`);
},

async forkSession({ sessionId, cwd, atUuid }) {
return apiPost(`/api/sessions/${sessionId}/fork`, { cwd, atUuid });
async forkSession({ sessionId, cwd, atTurnId }) {
return apiPost(`/api/sessions/${sessionId}/fork`, { cwd, atTurnId });
},
};
}
2 changes: 1 addition & 1 deletion packages/codingcode/src/client/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ export interface AgentClient {
}>;
undoLastCodeRollback(force?: boolean, files?: string[]): Promise<CodeRollbackUndoResult>;
getRollbackState(): Promise<RollbackState>;
forkSession(atUuid?: string): Promise<string>;
forkSession(atTurnId?: number): Promise<string>;
compact(): Promise<void>;
getMemoryEnabled(): Promise<boolean>;
setMemoryEnabled(enabled: boolean): Promise<void>;
Expand Down
5 changes: 3 additions & 2 deletions packages/codingcode/src/server/routes/sessions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -454,18 +454,19 @@ export function createSessionsRouter(rt: ManagedRt): Hono {

router.post('/:id/fork', async (c) => {
const sessionId = c.req.param('id');
const body = (await c.req.json()) as { cwd: string; atUuid?: string };
const body = (await c.req.json()) as { cwd: string; atTurnId?: number };
const cwd = await rt.runPromise(
Effect.gen(function* () {
const ws = yield* WorkspaceService;
return ws.resolveWorkspaceCwd(body.cwd);
})
);
const atTurnId = body.atTurnId ?? 0;
const result = await runWithLayer(
Effect.gen(function* () {
const session = yield* SessionService;
const state = yield* session.create(cwd, 'unknown', sessionId);
const newSessionId = yield* session.forkSession(state, body.atUuid ?? '');
const newSessionId = yield* session.forkSession(state, atTurnId);
const turns = readUIHistory(newSessionId);
return { sessionId: newSessionId, turns };
}) as any
Expand Down
10 changes: 6 additions & 4 deletions packages/codingcode/src/session/store.ts
Original file line number Diff line number Diff line change
Expand Up @@ -350,10 +350,10 @@ export class SessionService extends Effect.Service<SessionService>()('Session',

const forkSession = (
state: SessionStoreState,
atUuid: string
atTurnId: number
): Effect.Effect<string, AgentError> =>
Effect.sync(() => {
return forkSessionImpl(state.sessionId, state.transcriptPath, atUuid);
return forkSessionImpl(state.sessionId, state.transcriptPath, atTurnId);
});

const renameSession = (
Expand Down Expand Up @@ -485,9 +485,11 @@ function initState(cwd: string, sessionId?: string, parentSessionId?: string): S
};
}

function forkSessionImpl(sourceSessionId: string, sourceJsonlPath: string, atUuid: string): string {
function forkSessionImpl(sourceSessionId: string, sourceJsonlPath: string, atTurnId: number): string {
const events = readHistory(sourceJsonlPath);
const atIdx = atUuid ? events.findIndex((e) => 'uuid' in e && (e as any).uuid === atUuid) : -1;
const atIdx = events.findIndex(
(e) => e.type === 'user' && (e as any).turnId === atTurnId
);

const chain = atIdx >= 0 ? events.slice(0, atIdx + 1) : events;
const newSessionId = randomUUID();
Expand Down
12 changes: 6 additions & 6 deletions packages/codingcode/test/session/fork.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ function run<T>(eff: Effect.Effect<T, any, any>): Promise<T> {
}

describe('forkSession', () => {
it('fork copies events from root to atUuid', async () => {
it('fork copies events from root to atTurnId', async () => {
const sessionId = randomUUID();
const slug = randomUUID();
const fx = makeFixture(sessionId, slug);
Expand All @@ -122,11 +122,11 @@ describe('forkSession', () => {
memorySnapshot: '',
};

// Fork at u2 (turn 2 start)
// Fork at turn 2 (user message "second")
const newSessionId = await run(
Effect.gen(function* () {
const svc = yield* SessionService;
return yield* svc.forkSession(state, 'u2');
return yield* svc.forkSession(state, 2);
})
);

Expand Down Expand Up @@ -169,7 +169,7 @@ describe('forkSession', () => {
const newSessionId = await run(
Effect.gen(function* () {
const svc = yield* SessionService;
return yield* svc.forkSession(state, 'u2');
return yield* svc.forkSession(state, 2);
})
);

Expand Down Expand Up @@ -214,7 +214,7 @@ describe('forkSession', () => {
const newSessionId = await run(
Effect.gen(function* () {
const svc = yield* SessionService;
return yield* svc.forkSession(state, 'u2');
return yield* svc.forkSession(state, 2);
})
);

Expand Down Expand Up @@ -277,7 +277,7 @@ describe('forkSession', () => {
const newSessionId = await run(
Effect.gen(function* () {
const svc = yield* SessionService;
return yield* svc.forkSession(state, 'a1');
return yield* svc.forkSession(state, 1);
})
);

Expand Down
4 changes: 2 additions & 2 deletions packages/codingcode/test/session/prompt-estimate.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ describe('promptEstimate', () => {
const newSessionId = await run(
Effect.gen(function* () {
const svc = yield* SessionService;
return yield* svc.forkSession(state, 'a1');
return yield* svc.forkSession(state, 2);
})
);
const newIndexPath = join(fx.dir, `${newSessionId}.index.json`);
Expand Down Expand Up @@ -253,7 +253,7 @@ describe('promptEstimate', () => {
const newSessionId = await run(
Effect.gen(function* () {
const svc = yield* SessionService;
return yield* svc.forkSession(state, 'u2');
return yield* svc.forkSession(state, 2);
})
);
const newIndexPath = join(fx.dir, `${newSessionId}.index.json`);
Expand Down
67 changes: 36 additions & 31 deletions packages/desktop/src/agent/MessageStream.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ export default function MessageStream({ threadId }: MessageStreamProps) {
const { copiedId, copy } = useCopyToClipboard();
const parentRef = useRef<HTMLDivElement>(null);
const didScrollToEndRef = useRef(false);
const loadedCheckpointRef = useRef<string | null>(null);
const markFileRestored = useGlobalStore((s) => s.markFileRestored);
const setPendingInput = useGlobalStore((s) => s.setPendingInput);

Expand Down Expand Up @@ -341,7 +342,7 @@ export default function MessageStream({ threadId }: MessageStreamProps) {
(i) => i.type === 'message' && (i as any).role === 'user'
);
const userContent = userMsg && 'content' in userMsg ? (userMsg as any).content : '';
const newSessionId = await forkThread(threadId, lastItem.id);
const newSessionId = await forkThread(threadId, Number(turn.id));
if (newSessionId) {
setCurrentThread(newSessionId);
if (userContent) setPendingInput(userContent);
Expand All @@ -363,16 +364,23 @@ export default function MessageStream({ threadId }: MessageStreamProps) {
setPendingInput,
]);

const getItemKey = useCallback(
(index: number) => renderEntries[index]?.key ?? `empty-${index}`,
[renderEntries]
);

const getScrollElement = useCallback(() => parentRef.current, []);

const virtualizer = useVirtualizer({
count: renderEntries.length,
getScrollElement: () => parentRef.current,
estimateSize: () => 60,
getItemKey: (index: number) => renderEntries[index]?.key ?? `empty-${index}`,
getScrollElement,
estimateSize: useCallback(() => 60, []),
getItemKey,
overscan: 5,
anchorTo: 'end',
followOnAppend: 'smooth',
scrollEndThreshold: 80,
initialOffset: () => Number.MAX_SAFE_INTEGER,
initialOffset: useCallback(() => Number.MAX_SAFE_INTEGER, []),
});

useLayoutEffect(() => {
Expand All @@ -384,34 +392,31 @@ export default function MessageStream({ threadId }: MessageStreamProps) {

const turnStatusKey = useMemo(() => turns.map((t) => `${t.id}:${t.status}`).join(','), [turns]);

const handleLoadDiff = useCallback(
async (uiTurnId: string) => {
const diff = await loadCheckpointDiff(threadId);
if (diff.turnId > 0) {
const state = useGlobalStore.getState();
const mapping = state.rollback.turnCheckpointMapping[threadId];
if (mapping?.[diff.turnId] !== uiTurnId) {
state.setTurnCheckpointMapping(threadId, diff.turnId, uiTurnId);
}
}
},
[threadId, loadCheckpointDiff]
);
useEffect(() => {
loadedCheckpointRef.current = null;
}, [threadId]);

useEffect(() => {
for (const turn of turns) {
if (turn.status !== 'completed' && turn.status !== 'error') continue;
const ckKey = getCheckpointKey(
threadId,
turn.id,
useGlobalStore.getState().rollback.checkpointDiffByTurnId,
useGlobalStore.getState().rollback.turnCheckpointMapping[threadId] ?? EMPTY_MAPPING
);
if (!ckKey) {
handleLoadDiff(turn.id);
}
}
}, [turnStatusKey, threadId, handleLoadDiff]);
const completedTurnIds = turns
.filter((t) => t.status === 'completed' || t.status === 'error')
.map((t) => t.id);
if (completedTurnIds.length === 0) return;

const loadKey = `${threadId}:${completedTurnIds.join(',')}`;
if (loadedCheckpointRef.current === loadKey) return;
loadedCheckpointRef.current = loadKey;

const state = useGlobalStore.getState();
const existingMapping = state.rollback.turnCheckpointMapping[threadId] ?? EMPTY_MAPPING;
const existingDiffs = state.rollback.checkpointDiffByTurnId;

const alreadyLoaded = completedTurnIds.some((id) =>
getCheckpointKey(threadId, id, existingDiffs, existingMapping) !== null
);
if (alreadyLoaded) return;

loadCheckpointDiff(threadId);
}, [turnStatusKey, threadId, loadCheckpointDiff]);

const handleRevertFile = useCallback(
async (uiTurnId: string, file: string, isReverted: boolean) => {
Expand Down
9 changes: 3 additions & 6 deletions packages/desktop/src/hooks/useAgent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import type {
CheckpointDiff,
CodeRollbackResult,
CodeRollbackUndoResult,
RollbackPreviewDiff,
SessionRollbackState,
} from '../lib/core-api';
import type { Item, Turn, Project } from '@shared/types';
Expand Down Expand Up @@ -356,7 +355,6 @@ export function useAgentRollback() {
const revertedFilesByTurnId = useGlobalStore((s) => s.rollback.revertedFilesByTurnId);
const setRollbackState = useGlobalStore((s) => s.setRollbackState);
const setCheckpointDiff = useGlobalStore((s) => s.setCheckpointDiff);
const setRollbackPreview = useGlobalStore((s) => s.setRollbackPreview);
const markFileReverted = useGlobalStore((s) => s.markFileReverted);
const markFileRestored = useGlobalStore((s) => s.markFileRestored);
const setTurnCheckpointMapping = useGlobalStore((s) => s.setTurnCheckpointMapping);
Expand Down Expand Up @@ -422,10 +420,9 @@ export function useAgentRollback() {
async (threadId: string, throughTurnId: number) => {
const cwd = useGlobalStore.getState().agent.threads[threadId]?.cwd ?? workspace.rootPath;
const preview = await previewRollbackDiff(threadId, cwd, throughTurnId);
setRollbackPreview(threadId, preview);
return preview;
},
[workspace.rootPath, setRollbackPreview]
[workspace.rootPath]
);

const rollbackCode = useCallback(
Expand Down Expand Up @@ -495,9 +492,9 @@ export function useAgentRollback() {
);

const forkThread = useCallback(
async (threadId: string, atUuid?: string) => {
async (threadId: string, atTurnId?: number) => {
const cwd = useGlobalStore.getState().agent.threads[threadId]?.cwd ?? workspace.rootPath;
const res = await forkSession(threadId, cwd, atUuid);
const res = await forkSession(threadId, cwd, atTurnId);
return res.sessionId;
},
[workspace.rootPath]
Expand Down
4 changes: 2 additions & 2 deletions packages/desktop/src/lib/core-api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -386,9 +386,9 @@ export function getRollbackState(sessionId: string, cwd: string): Promise<Sessio
export function forkSession(
sessionId: string,
cwd: string,
atUuid?: string
atTurnId?: number
): Promise<{ sessionId: string; turns: any[] }> {
return clients.sessions.forkSession({ sessionId, cwd, atUuid });
return clients.sessions.forkSession({ sessionId, cwd, atTurnId });
}

// ---- Automations ----
Expand Down
Loading
Loading