From 55f6f7ac086ec089acf228fe208b3dd1c65bd71d Mon Sep 17 00:00:00 2001 From: Erick Date: Sat, 20 Jun 2026 19:27:33 -0700 Subject: [PATCH] feat(memos): chunk batch reflection scoring --- .../core/capture/ALGORITHMS.md | 13 +- .../core/capture/batch-scorer.ts | 21 +- .../core/capture/capture.ts | 75 +++++-- .../tests/helpers/fake-llm.ts | 4 +- .../tests/unit/capture/capture-batch.test.ts | 183 +++++++++++++++--- 5 files changed, 238 insertions(+), 58 deletions(-) diff --git a/apps/memos-local-plugin/core/capture/ALGORITHMS.md b/apps/memos-local-plugin/core/capture/ALGORITHMS.md index 15a15ac64..49c98a105 100644 --- a/apps/memos-local-plugin/core/capture/ALGORITHMS.md +++ b/apps/memos-local-plugin/core/capture/ALGORITHMS.md @@ -78,7 +78,8 @@ priority once reward arrives. ## V7 §3.2 batched variant — `batch-scorer.ts` The per-step path (`reflection-synth.ts` + `alpha-scorer.ts`) issues 2N -LLM calls per N-step episode. `batch-scorer.ts` collapses them into ONE: +LLM calls per N-step episode. `batch-scorer.ts` collapses up to +`batchThreshold` steps into one call: ``` inputs = [{idx, state, action, outcome, reflection, synth_allowed}, …] @@ -91,8 +92,8 @@ Dispatch (in `capture.ts`): | `cfg.batchMode` | `cfg.batchThreshold` | behavior | |-------------------|----------------------|----------| | `per_step` | (ignored) | legacy: 2N calls | -| `per_episode` | (ignored) | always batch | -| `auto` (default) | `12` | batch when `N ≤ 12`; else per-step | +| `per_episode` | chunk size | batch when `N ≤ threshold`; else chunk-batch | +| `auto` (default) | `12` | batch when `N ≤ 12`; else chunk-batch | The dispatcher also refuses to batch when no LLM is wired — same fallback path as missing-LLM in per-step mode. @@ -107,15 +108,15 @@ Failure handling: - LLM throws / facade gives up after `malformedRetries=1` → capture catches in `runBatchScoring`, surfaces a `{stage: "batch"}` warning, - and the per-step path runs as a fallback. + and the per-step path runs as a fallback for that chunk. - Validator rejects on length mismatch, missing/non-numeric `alpha`, non-boolean `usable`, non-string `reflection_text`. Same fallback. Bookkeeping (`CaptureResult.llmCalls`): -- `batchedReflection`: 0 or 1 per episode (1 on a successful batch). +- `batchedReflection`: number of successful batch/chunk calls. - `reflectionSynth` / `alphaScoring`: only nonzero when the per-step path - ran (either selected directly, or as fallback after a batch failure). + ran (either selected directly, or as fallback after a chunk failure). Stable prompt fingerprint: diff --git a/apps/memos-local-plugin/core/capture/batch-scorer.ts b/apps/memos-local-plugin/core/capture/batch-scorer.ts index e7b8ab50f..da434c3b2 100644 --- a/apps/memos-local-plugin/core/capture/batch-scorer.ts +++ b/apps/memos-local-plugin/core/capture/batch-scorer.ts @@ -16,11 +16,12 @@ * `transferability` axes benefit directly. * * Trade-offs (encoded in capture.ts dispatch): - * - Prompt grows linearly with N steps. Capped via `batchThreshold`; - * long episodes degrade to the per-step path automatically. - * - One bad output value forces a single batched retry instead of N - * isolated retries — but the facade already does `malformedRetries` - * for us, and on hard failure capture.ts falls back to per-step. + * - Prompt grows linearly with N steps. Each call is capped at + * `batchThreshold`; long episodes run as several bounded chunks. + * - One bad chunk forces a single batched retry for that chunk instead + * of N isolated retries — but the facade already does + * `malformedRetries` for us, and on hard failure capture.ts falls + * back to per-step for that chunk only. * * Wire format ↔ prompt: * Send `{ host_context?, task_context?, steps: [{idx, state, action, outcome, reflection, synth_allowed}] }`. @@ -170,6 +171,7 @@ export async function batchScoreReflections( validate: (v) => validateBatchPayload(v, inputs.length), malformedRetries: 1, temperature: 0, + maxTokens: batchMaxTokens(inputs.length), }, ); @@ -321,6 +323,15 @@ function validateBatchPayload(v: unknown, expected: number): void { } } +function batchMaxTokens(stepCount: number): number { + // Batch output scales with step count; keep a per-step budget but cap below + // the 16k range that triggered avoidable reasoning spend on mimo replay. + const perStepOutputBudget = 512; + const baseBudget = 768; + const ceiling = 8_192; + return Math.min(ceiling, baseBudget + Math.max(1, stepCount) * perStepOutputBudget); +} + function lastToolOutcome(step: NormalizedStep, max: number): string { const last = step.toolCalls[step.toolCalls.length - 1]; if (!last) return "(assistant-only step)"; diff --git a/apps/memos-local-plugin/core/capture/capture.ts b/apps/memos-local-plugin/core/capture/capture.ts index 9d52f749e..8c26276ac 100644 --- a/apps/memos-local-plugin/core/capture/capture.ts +++ b/apps/memos-local-plugin/core/capture/capture.ts @@ -435,14 +435,14 @@ export function createCaptureRunner(deps: CaptureDeps): CaptureRunner { } // Batch reflection + α across every step of the now-closed - // episode. Falls back to per-step scoring when over the threshold - // or when batching fails / no LLM is wired. The reflect pass uses + // episode. Long episodes are chunk-batched at `batchThreshold`; + // failed chunks fall back to per-step scoring. The reflect pass uses // `reflectLlm` (skill-evolver model when configured) for higher // quality reflections; per-turn lite capture still uses `llm`. const reflectStart = now(); const rLlm = deps.reflectLlm ?? deps.llm; - const useBatch = shouldBatch(deps.cfg, normalized.length, rLlm !== null); - const contextEnabled = contextModeFor(deps.cfg, useBatch, normalized.length); + const scoringPlan = planScoring(deps.cfg, normalized.length, rLlm !== null); + const contextEnabled = contextModeFor(deps.cfg, scoringPlan, normalized.length); const taskSummary = contextEnabled.includeTask ? buildTaskReflectionSummary(input.episode, normalized, deps.cfg.taskContextMaxChars) : null; @@ -453,7 +453,10 @@ export function createCaptureRunner(deps: CaptureDeps): CaptureRunner { episodeId: input.episode.id, sessionId: input.episode.sessionId, steps: normalized.length, - mode: useBatch ? "batch" : contextEnabled.includeDownstream ? "per_step_downstream" : "per_step", + mode: scoringPlan === "per_step" && contextEnabled.includeDownstream ? "per_step_downstream" : scoringPlan, + chunks: scoringPlan === "chunk_batch" + ? Math.ceil(normalized.length / Math.max(1, deps.cfg.batchThreshold)) + : undefined, reflectionContextMode: deps.cfg.reflectionContextMode, downstreamPreview: contextEnabled.includeDownstream, provider: rLlm?.provider ?? "none", @@ -461,10 +464,13 @@ export function createCaptureRunner(deps: CaptureDeps): CaptureRunner { taskSummary: taskSummary ? taskSummary.slice(0, 240) : null, }); let scored: ScoredStep[] = []; - if (useBatch) { + if (scoringPlan === "batch") { scored = await runBatchScoring(normalized, rLlm!, deps, warnings, llmCalls, input.episode.id, taskSummary); } - if (!useBatch || scored.length === 0) { + if (scoringPlan === "chunk_batch") { + scored = await runChunkedBatchScoring(normalized, rLlm!, deps, warnings, llmCalls, input.episode.id, taskSummary); + } + if (scoringPlan === "per_step" || scored.length === 0) { scored = await runPerStepScoring( normalized, rLlm, @@ -1018,30 +1024,30 @@ export function createCaptureRunner(deps: CaptureDeps): CaptureRunner { // ─── helpers ──────────────────────────────────────────────────────────────── /** - * Decide whether to use the batched reflection+α path. + * Decide which reflection+α path to use. * * `per_step` → never (legacy path). - * `per_episode` → always, when an LLM is available. - * `auto` → batch when step count fits inside `batchThreshold`. + * `per_episode` → batch up to threshold, then chunk-batch. + * `auto` → batch up to threshold, then chunk-batch. */ -function shouldBatch(cfg: CaptureConfig, stepCount: number, hasLlm: boolean): boolean { - if (!hasLlm) return false; - if (stepCount === 0) return false; - if (cfg.batchMode === "per_step") return false; - if (cfg.batchMode === "per_episode") return true; - // "auto" - return stepCount <= cfg.batchThreshold; +type ScoringPlan = "per_step" | "batch" | "chunk_batch"; + +function planScoring(cfg: CaptureConfig, stepCount: number, hasLlm: boolean): ScoringPlan { + if (!hasLlm) return "per_step"; + if (stepCount === 0) return "per_step"; + if (cfg.batchMode === "per_step") return "per_step"; + return stepCount <= Math.max(1, cfg.batchThreshold) ? "batch" : "chunk_batch"; } function contextModeFor( cfg: CaptureConfig, - useBatch: boolean, + scoringPlan: ScoringPlan, stepCount: number, ): { includeTask: boolean; includeDownstream: boolean } { const mode = cfg.reflectionContextMode; const includeTask = mode === "task" || mode === "task_downstream"; const wantsDownstream = mode === "downstream" || mode === "task_downstream"; - const longPerStep = !useBatch && stepCount > cfg.batchThreshold; + const longPerStep = scoringPlan === "per_step" && stepCount > cfg.batchThreshold; const includeDownstream = wantsDownstream && cfg.longEpisodeReflectMode === "per_step_downstream" && @@ -1101,6 +1107,37 @@ async function runBatchScoring( } } +async function runChunkedBatchScoring( + normalized: NormalizedStep[], + llm: LlmClient, + deps: CaptureDeps, + warnings: CaptureResult["warnings"], + llmCalls: { reflectionSynth: number; alphaScoring: number; batchedReflection: number }, + episodeId: string, + taskSummary: string | null, +): Promise { + const chunkSize = Math.max(1, deps.cfg.batchThreshold); + const chunks: NormalizedStep[][] = []; + for (let start = 0; start < normalized.length; start += chunkSize) { + chunks.push(normalized.slice(start, start + chunkSize)); + } + const concurrency = Math.max(1, deps.cfg.llmConcurrency); + const scoredChunks = await runConcurrently(chunks, concurrency, async (chunk): Promise => { + const scored = await runBatchScoring(chunk, llm, deps, warnings, llmCalls, episodeId, taskSummary); + if (scored.length > 0) return scored; + return runPerStepScoring( + chunk, + llm, + deps, + warnings, + llmCalls, + episodeId, + buildReflectionContexts(chunk, taskSummary, chunk.map(() => [])), + ); + }); + return scoredChunks.flat(); +} + async function runPerStepScoring( normalized: NormalizedStep[], llm: LlmClient | null, diff --git a/apps/memos-local-plugin/tests/helpers/fake-llm.ts b/apps/memos-local-plugin/tests/helpers/fake-llm.ts index 22d9fc1a3..ec2c00261 100644 --- a/apps/memos-local-plugin/tests/helpers/fake-llm.ts +++ b/apps/memos-local-plugin/tests/helpers/fake-llm.ts @@ -20,7 +20,7 @@ export interface FakeLlmScript { complete?: Record string | Promise)>; completeJson?: Record< string, - unknown | ((input: unknown) => unknown | Promise) + unknown | ((input: unknown, opts?: unknown) => unknown | Promise) >; /** Override the served-by identifier. */ servedBy?: LlmProviderName | "host_fallback"; @@ -64,7 +64,7 @@ export function fakeLlm(script: FakeLlmScript = {}): LlmClient { throw new Error(`fakeLlm: no completeJson mock for op="${op}"`); } const value = (typeof entry === "function" - ? await (entry as (x: unknown) => unknown)(input) + ? await (entry as (x: unknown, o?: unknown) => unknown)(input, opts) : entry) as T; if (o?.validate) o.validate(value); return { diff --git a/apps/memos-local-plugin/tests/unit/capture/capture-batch.test.ts b/apps/memos-local-plugin/tests/unit/capture/capture-batch.test.ts index d86290517..bc4b76f28 100644 --- a/apps/memos-local-plugin/tests/unit/capture/capture-batch.test.ts +++ b/apps/memos-local-plugin/tests/unit/capture/capture-batch.test.ts @@ -7,14 +7,15 @@ * 2. existing reflections are preserved verbatim; * 3. synth-disabled steps stay at α=0 even when the LLM tries to write * one for them; - * 4. `auto` mode falls back to per-step when stepCount > batchThreshold; - * 5. a malformed batched response degrades into the per-step path - * instead of crashing capture. + * 4. `auto` mode chunk-batches when stepCount > batchThreshold; + * 5. a malformed chunk degrades only that chunk into the per-step path + * instead of dropping the whole episode to per-step. */ import { afterEach, beforeAll, beforeEach, describe, expect, it } from "vitest"; import { createCaptureRunner, type CaptureRunner } from "../../../core/capture/capture.js"; +import { batchScoreReflections } from "../../../core/capture/batch-scorer.js"; import { createCaptureEventBus } from "../../../core/capture/events.js"; import { BATCH_REFLECTION_PROMPT, @@ -312,20 +313,31 @@ describe("capture/pipeline (batched ρ+α path)", () => { expect(t.alpha).toBe(0); // V7 disabledScore semantics }); - it("auto mode falls back to per-step when stepCount > batchThreshold", async () => { + it("auto mode chunk-batches when stepCount > batchThreshold", async () => { + const batchStates: string[][] = []; const llm = fakeLlm({ completeJson: { - // ONLY per-step alpha mock; if batched gets called, the test fails - // with "no completeJson mock for op=...batch...". - [alphaOp]: { alpha: 0.5, usable: true, reason: "ok" }, - }, - complete: { - "capture.reflection.synth": "I made this decision deliberately.", + [batchOp]: (input) => { + const messages = input as Array<{ role: string; content: string }>; + const payload = JSON.parse(messages[messages.length - 1]!.content) as { + steps: Array<{ idx: number; state: string }>; + }; + batchStates.push(payload.steps.map((s) => s.state)); + return { + scores: payload.steps.map((step) => ({ + idx: step.idx, + reflection_text: `reflection ${step.state}`, + alpha: step.idx === 0 ? 0.2 : 0.4, + usable: true, + reason: "ok", + })), + }; + }, }, }); const runner = buildRunner({ batchMode: "auto", batchThreshold: 2 }, llm); - // 3 steps → above threshold → per-step path. + // 3 steps → above threshold → two bounded batch chunks. const ep = episodeSnapshot({ id: "ep_1", sessionId: "se_1", @@ -341,10 +353,17 @@ describe("capture/pipeline (batched ρ+α path)", () => { const result = await runCapture(runner, ep); expect(result.traceIds).toHaveLength(3); - expect(result.llmCalls.batchedReflection).toBe(0); - // 3 synth + 3 alpha calls in per-step mode. - expect(result.llmCalls.reflectionSynth).toBe(3); - expect(result.llmCalls.alphaScoring).toBe(3); + expect(batchStates).toEqual([["a", "b"], ["c"]]); + expect(result.llmCalls.batchedReflection).toBe(2); + expect(result.llmCalls.reflectionSynth).toBe(0); + expect(result.llmCalls.alphaScoring).toBe(0); + + const rows = result.traceIds.map((id) => tmp.repos.traces.getById(id)!); + expect(rows.map((row) => row.reflection)).toEqual([ + "reflection a", + "reflection b", + "reflection c", + ]); }); it("long per-step downstream mode injects up to three following steps", async () => { @@ -368,7 +387,7 @@ describe("capture/pipeline (batched ρ+α path)", () => { }); const runner = buildRunner( { - batchMode: "auto", + batchMode: "per_step", batchThreshold: 2, reflectionContextMode: "task_downstream", longEpisodeReflectMode: "per_step_downstream", @@ -415,16 +434,27 @@ describe("capture/pipeline (batched ρ+α path)", () => { expect(step3Prompt).not.toContain("[step+3]"); }); - it("per_episode mode batches even when step count is large", async () => { - const scores = Array.from({ length: 5 }, (_, i) => ({ - idx: i, - reflection_text: `reflection #${i}`, - alpha: 0.4, - usable: true, - reason: "ok", - })); + it("per_episode mode chunk-batches when step count is large", async () => { + const chunkSizes: number[] = []; const llm = fakeLlm({ - completeJson: { [batchOp]: { scores } }, + completeJson: { + [batchOp]: (input) => { + const messages = input as Array<{ role: string; content: string }>; + const payload = JSON.parse(messages[messages.length - 1]!.content) as { + steps: Array<{ idx: number; state: string }>; + }; + chunkSizes.push(payload.steps.length); + return { + scores: payload.steps.map((step) => ({ + idx: step.idx, + reflection_text: `reflection ${step.state}`, + alpha: 0.4, + usable: true, + reason: "ok", + })), + }; + }, + }, }); const runner = buildRunner({ batchMode: "per_episode", batchThreshold: 2 }, llm); @@ -436,10 +466,111 @@ describe("capture/pipeline (batched ρ+α path)", () => { const ep = episodeSnapshot({ id: "ep_1", sessionId: "se_1", turns }); const result = await runCapture(runner, ep); expect(result.traceIds).toHaveLength(5); - expect(result.llmCalls.batchedReflection).toBe(1); + expect(chunkSizes).toEqual([2, 2, 1]); + expect(result.llmCalls.batchedReflection).toBe(3); expect(result.llmCalls.alphaScoring).toBe(0); }); + it("chunk-batch falls back to per-step only for the failed chunk", async () => { + const llm = fakeLlm({ + completeJson: { + [batchOp]: (input) => { + const messages = input as Array<{ role: string; content: string }>; + const payload = JSON.parse(messages[messages.length - 1]!.content) as { + steps: Array<{ idx: number; state: string }>; + }; + if (payload.steps[0]?.state === "q2") { + throw new Error("chunk failed"); + } + return { + scores: payload.steps.map((step) => ({ + idx: step.idx, + reflection_text: `batch ${step.state}`, + alpha: step.state === "q4" ? 0.5 : 0.2, + usable: true, + reason: "ok", + })), + }; + }, + [alphaOp]: { alpha: 0.9, usable: true, reason: "fallback" }, + }, + complete: { + "capture.reflection.synth": "per-step fallback reflection", + }, + }); + const runner = buildRunner({ batchMode: "auto", batchThreshold: 2 }, llm); + + const turns: EpisodeTurn[] = []; + for (let i = 0; i < 5; i++) { + turns.push(turn("user", `q${i}`, 1_000 + i * 100)); + turns.push(turn("assistant", `a${i}`, 1_050 + i * 100)); + } + const ep = episodeSnapshot({ id: "ep_1", sessionId: "se_1", turns }); + + const result = await runCapture(runner, ep); + expect(result.traceIds).toHaveLength(5); + expect(result.llmCalls.batchedReflection).toBe(2); + expect(result.llmCalls.reflectionSynth).toBe(2); + expect(result.llmCalls.alphaScoring).toBe(2); + expect(result.warnings.filter((w) => w.stage === "batch")).toHaveLength(1); + + const rows = result.traceIds.map((id) => tmp.repos.traces.getById(id)!); + expect(rows.map((row) => row.reflection)).toEqual([ + "batch q0", + "batch q1", + "per-step fallback reflection", + "per-step fallback reflection", + "batch q4", + ]); + expect(rows.map((row) => row.alpha)).toEqual([0.2, 0.2, 0.9, 0.9, 0.5]); + }); + + it("batch scorer passes an explicit maxTokens budget", async () => { + let seenMaxTokens: number | undefined; + const llm = fakeLlm({ + completeJson: { + [batchOp]: (_input, opts) => { + seenMaxTokens = (opts as { maxTokens?: number }).maxTokens; + return { + scores: [ + { + idx: 0, + reflection_text: "I made a useful choice.", + alpha: 0.5, + usable: true, + reason: "ok", + }, + ], + }; + }, + }, + }); + + await batchScoreReflections( + llm, + [ + { + step: { + key: "s1", + ts: 1_000 as EpochMs, + type: "text", + userText: "q", + agentText: "a", + agentThinking: null, + toolCalls: [], + rawReflection: null, + meta: {}, + }, + existingReflection: null, + }, + ], + { synthReflections: true }, + ); + + expect(seenMaxTokens).toBeGreaterThan(0); + expect(seenMaxTokens).toBeLessThan(16_384); + }); + it("malformed batched response → falls back to per-step + emits warning", async () => { const llm = fakeLlm({ completeJson: {