import Mathlib
/-!
# A Lean 4 formalization of the LemmaScript "pi" context-compaction cut-point selector
This file ports the Dafny-verified compaction cut-point selector from the
`midspiral/pi-lemmascript` case study to Lean 4, reproducing the three methods
of the selector and proving their `ensures` (post-condition) clauses.
When pi's context window fills, history is compacted: a cut point is chosen and
everything before it is discarded. A provider API rejects a retained message
list that starts with an *orphaned* `toolResult` (a tool result whose preceding
tool-use turn was dropped). The selector must therefore never cut so that the
kept suffix begins with a tool result, and never split a tool-use / tool-result
run. These are exactly the properties proven below.
## Modeling choices
* A session history `seq<SessionTreeEntry>` is modeled as a *total* function
`entries : ℕ → SessionTreeEntry`. The Dafny code only ever accesses indices
inside `[startIndex, endIndex)` and carries `endIndex ≤ |entries|` as a
precondition purely to keep those accesses in bounds; with a total index
function the bounds obligations vanish while the mathematical content is
unchanged. Consequently the `\result.turnStartIndex < entries.length`
conjunct of the Dafny `findCutPoint` post-condition (a pure in-bounds fact)
has no counterpart here.
* `SessionTreeEntry` keeps all eleven Dafny constructors, but their payloads —
which the selector never inspects beyond a `message`'s `role` — are dropped.
* The opaque `estimateTokens` becomes an arbitrary parameter `estTok`; as in the
Dafny proof, none of the safety properties depend on its value.
* The `-1` integer sentinel for "no turn start" becomes `Option.none`.
-/
open scoped BigOperators
namespace LemmaScript
/-- Message roles. Mirrors the Dafny `Role` enum. -/
inductive Role
| bashExecution
| custom
| branchSummary
| compactionSummary
| user
| assistant
| toolResult
deriving DecidableEq
/-- A chat message, carrying only the field the selector reads. -/
structure AgentMessage where
role : Role
deriving DecidableEq
/-- Session-tree entries. All Dafny constructors are kept; irrelevant payloads
are dropped, except a `message` carries its `AgentMessage`. -/
inductive SessionTreeEntry
| message (msg : AgentMessage)
| thinkingLevelChange
| modelChange
| activeToolsChange
| compaction
| branchSummary
| custom
| customMessage
| label
| sessionInfo
| leaf
/-- The result of `findCutPoint`. Mirrors the Dafny `CutPointResult`, with the
`-1` turn-start sentinel replaced by `Option.none`. -/
structure CutPointResult where
firstKeptEntryIndex : ℕ
turnStartIndex : Option ℕ
isSplitTurn : Bool
/-- An entry is a tool-result message. -/
def isToolResultMessage : SessionTreeEntry → Bool
| .message ⟨.toolResult⟩ => true
| _ => false
/-- An entry starts a turn: a branch summary, a custom message, or a
`user`/`bashExecution` message. -/
def isTurnStarter : SessionTreeEntry → Bool
| .branchSummary => true
| .customMessage => true
| .message ⟨.user⟩ => true
| .message ⟨.bashExecution⟩ => true
| _ => false
/-- An entry is a `message`. -/
def isMessage : SessionTreeEntry → Bool
| .message _ => true
| _ => false
/-- An entry is a *valid cut point*: a non-`toolResult` message, a branch
summary, or a custom message. (This is exactly the set of indices pushed by the
`switch` and the trailing `if` in the Dafny `findValidCutPoints`.) -/
def isValidCutEntry : SessionTreeEntry → Bool
| .message ⟨.toolResult⟩ => false
| .message _ => true
| .branchSummary => true
| .customMessage => true
| _ => false
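/- Illustration only (hypothetical checks, not part of the Dafny port): how the
classifiers above treat a few sample entries. `#guard` evaluates a `Bool` at
elaboration time and fails the build if it is `false`. An `assistant` message is
a valid cut point but does not start a turn, a `toolResult` message is neither,
a `user` message is both, and a metadata entry such as `label` is none of them. -/
#guard isValidCutEntry (.message ⟨.assistant⟩) && !isTurnStarter (.message ⟨.assistant⟩)
#guard !isValidCutEntry (.message ⟨.toolResult⟩) && !isTurnStarter (.message ⟨.toolResult⟩)
#guard isValidCutEntry (.message ⟨.user⟩) && isTurnStarter (.message ⟨.user⟩)
#guard !isValidCutEntry .label && !isTurnStarter .label && !isMessage .label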
/-
A valid cut entry is never a tool-result message.
-/
theorem not_toolResult_of_validCut {e : SessionTreeEntry}
(h : isValidCutEntry e = true) : isToolResultMessage e = false := by
cases e <;> simp_all +decide [ isValidCutEntry ];
rename_i msg; rcases msg with ⟨ _ | _ | _ | _ | _ | _ | _ ⟩ <;> tauto;
/-! ## `findValidCutPoints` -/
/-- Scan `entries[startIndex, endIndex)` and return the indices that are safe
places to start a retained suffix (in increasing order). -/
def findValidCutPoints (entries : ℕ → SessionTreeEntry) (startIndex endIndex : ℕ) : List ℕ :=
(List.range' startIndex (endIndex - startIndex)).filter (fun i => isValidCutEntry (entries i))
/-
**In range.** Every returned index lies in `[startIndex, endIndex)`.
-/
theorem findValidCutPoints_mem_range (entries : ℕ → SessionTreeEntry) (startIndex endIndex : ℕ)
{x : ℕ} (hx : x ∈ findValidCutPoints entries startIndex endIndex) :
startIndex ≤ x ∧ x < endIndex := by
unfold findValidCutPoints at hx;
grind
/-
**No orphan at the cut.** No returned index points at a tool-result message,
so the retained suffix can never *begin* with an orphaned tool result.
-/
theorem findValidCutPoints_not_toolResult (entries : ℕ → SessionTreeEntry)
(startIndex endIndex : ℕ) {x : ℕ}
(hx : x ∈ findValidCutPoints entries startIndex endIndex) :
isToolResultMessage (entries x) = false := by
exact not_toolResult_of_validCut ( List.mem_filter.mp hx |>.2 )
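/- Illustration only: `demoCutHistory` is a hypothetical five-entry history
(indices `0` to `4`, padded with `leaf` beyond). The tool result at index 2 and
the `modelChange` at index 3 are skipped, so the valid cut points of `[0, 5)` are
`[0, 1, 4]`; narrowing or emptying the range narrows the result accordingly. -/
private def demoCutHistory : ℕ → SessionTreeEntry
  | 0 => .message ⟨.user⟩
  | 1 => .message ⟨.assistant⟩
  | 2 => .message ⟨.toolResult⟩
  | 3 => .modelChange
  | 4 => .message ⟨.assistant⟩
  | _ => .leaf

#guard findValidCutPoints demoCutHistory 0 5 == [0, 1, 4]
#guard findValidCutPoints demoCutHistory 1 5 == [1, 4]
-- An inverted range yields no cut points, since `2 - 4 = 0` in `ℕ`.
#guard findValidCutPoints demoCutHistory 4 2 == []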
/-! ## `findTurnStartIndex` -/
/-- Walk back from `entryIndex` down to `startIndex` (both inclusive) and return
the index of the first turn-starting entry found, or `none` if there is none. -/
def findTurnStartIndex (entries : ℕ → SessionTreeEntry) (entryIndex startIndex : ℕ) : Option ℕ :=
((List.range' startIndex (entryIndex + 1 - startIndex)).reverse).find?
(fun i => isTurnStarter (entries i))
/-
Post-condition of `findTurnStartIndex`: a returned index is in range and is a
genuine turn starter.
-/
theorem findTurnStartIndex_spec (entries : ℕ → SessionTreeEntry) (entryIndex startIndex : ℕ)
{r : ℕ} (hr : findTurnStartIndex entries entryIndex startIndex = some r) :
startIndex ≤ r ∧ r ≤ entryIndex ∧ isTurnStarter (entries r) = true := by
have := List.find?_some hr;
have := List.mem_of_find?_eq_some hr; simp_all +decide [ List.mem_reverse, List.mem_range' ] ;
omega
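/- Illustration only, on a hypothetical history: a user turn (0) with an
assistant message (1), a tool result (2) and a follow-up (3), then a second user
turn (4, 5). The search is inclusive at both ends and returns the nearest turn
starter at or before `entryIndex`, or `none` if the window
`[startIndex, entryIndex]` contains none. -/
private def demoTurnHistory : ℕ → SessionTreeEntry
  | 0 => .message ⟨.user⟩
  | 1 => .message ⟨.assistant⟩
  | 2 => .message ⟨.toolResult⟩
  | 3 => .message ⟨.assistant⟩
  | 4 => .message ⟨.user⟩
  | 5 => .message ⟨.assistant⟩
  | _ => .leaf

#guard findTurnStartIndex demoTurnHistory 3 0 == some 0
#guard findTurnStartIndex demoTurnHistory 5 0 == some 4
-- `entryIndex` itself counts when it is a turn starter.
#guard findTurnStartIndex demoTurnHistory 4 4 == some 4
-- The window `[1, 3]` holds no turn starter.
#guard findTurnStartIndex demoTurnHistory 3 1 == none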
/-! ## `findCutPoint` -/
/-- The first element of `cutPoints` that is `≥ i`, falling back to `dflt`.
Models the inner `for c` loop of `findCutPoint`. -/
def firstCutGe (cutPoints : List ℕ) (i dflt : ℕ) : ℕ :=
(cutPoints.find? (fun c => decide (i ≤ c))).getD dflt
/-
`firstCutGe` returns either a member of `cutPoints` or the default.
-/
theorem firstCutGe_mem_or (cutPoints : List ℕ) (i dflt : ℕ) :
firstCutGe cutPoints i dflt ∈ cutPoints ∨ firstCutGe cutPoints i dflt = dflt := by
unfold firstCutGe;
grind +suggestions
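/- Illustration only: with hypothetical cut points `[1, 4, 7]`, `firstCutGe`
returns the first one that is `≥ i` (inclusive), and the default when none is. -/
#guard firstCutGe [1, 4, 7] 3 0 == 4
#guard firstCutGe [1, 4, 7] 4 0 == 4
#guard firstCutGe [1, 4, 7] 8 1 == 1
#guard firstCutGe [] 2 9 == 9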
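/-- One step of the token-accumulation scan (`forwardScan`) of `findCutPoint`,
which visits indices newest-first. State is `(found cut index?, accumulated
tokens)`; once a cut index is found it is carried through unchanged. -/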
def fwdStep (entries : ℕ → SessionTreeEntry) (estTok : AgentMessage → ℕ) (cutPoints : List ℕ)
(dflt keepRecentTokens : ℕ) : (Option ℕ × ℕ) → ℕ → (Option ℕ × ℕ)
| (some r, acc), _ => (some r, acc)
| (none, acc), i =>
match entries i with
| .message msg =>
let acc' := acc + estTok msg
if keepRecentTokens ≤ acc' then (some (firstCutGe cutPoints i dflt), acc')
else (none, acc')
| _ => (none, acc)
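/-- The token-accumulation scan of `findCutPoint`. Despite the name, it visits
indices newest-first, from `endIndex - 1` down to `startIndex`, summing message
tokens; once the budget `keepRecentTokens` is reached it snaps to the first cut
point at or after the current index. If the budget is never reached, or no such
cut point exists, it falls back to `cutPoints[0]` (or `startIndex` when
`cutPoints` is empty). -/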
def forwardScan (entries : ℕ → SessionTreeEntry) (estTok : AgentMessage → ℕ)
(cutPoints : List ℕ) (startIndex endIndex keepRecentTokens : ℕ) : ℕ :=
let dflt := cutPoints.headD startIndex
let indices := (List.range' startIndex (endIndex - startIndex)).reverse
((indices.foldl (fwdStep entries estTok cutPoints dflt keepRecentTokens) (none, 0)).1).getD dflt
/-
The chosen index of the forward scan is always one of the valid cut points.
-/
theorem forwardScan_mem (entries : ℕ → SessionTreeEntry) (estTok : AgentMessage → ℕ)
(cutPoints : List ℕ) (startIndex endIndex keepRecentTokens : ℕ) (h : cutPoints ≠ []) :
forwardScan entries estTok cutPoints startIndex endIndex keepRecentTokens ∈ cutPoints := by
unfold forwardScan; simp +decide ;
induction' ( List.range' startIndex ( endIndex - startIndex ) ) with x xs ih <;> simp +decide [ * ] at *;
· cases cutPoints <;> aesop;
· cases h : List.foldr ( fun x y => fwdStep entries estTok cutPoints ( cutPoints.head?.getD startIndex ) keepRecentTokens y x ) ( none, 0 ) xs ; simp +decide [ h ] at *;
cases ‹Option ℕ› <;> simp_all +decide [ fwdStep ];
cases h : entries x <;> simp_all +decide;
exact Classical.not_not.1 fun h' => by have := firstCutGe_mem_or cutPoints x ( cutPoints.head?.getD startIndex ) ; aesop;
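/- Illustration only, assuming a flat (hypothetical) cost of 10 tokens per
message. On the six-entry history below the valid cut points are
`[0, 1, 3, 4, 5]`. With a budget of 25 the scan crosses it at index 3, which is a
valid cut. With 35 it crosses at the tool result at index 2 and snaps forward to
3. With 1000 it never crosses and falls back to the first cut point, 0. -/
private def demoScanHistory : ℕ → SessionTreeEntry
  | 0 => .message ⟨.user⟩
  | 1 => .message ⟨.assistant⟩
  | 2 => .message ⟨.toolResult⟩
  | 3 => .message ⟨.assistant⟩
  | 4 => .message ⟨.user⟩
  | 5 => .message ⟨.assistant⟩
  | _ => .leaf

#guard findValidCutPoints demoScanHistory 0 6 == [0, 1, 3, 4, 5]
#guard forwardScan demoScanHistory (fun _ => 10) (findValidCutPoints demoScanHistory 0 6) 0 6 25 == 3
#guard forwardScan demoScanHistory (fun _ => 10) (findValidCutPoints demoScanHistory 0 6) 0 6 35 == 3
#guard forwardScan demoScanHistory (fun _ => 10) (findValidCutPoints demoScanHistory 0 6) 0 6 1000 == 0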
/-- The backward metadata snap: move the cut earlier while the preceding entry is
neither a `compaction` nor a `message`, never moving below `startIndex`. -/
def snapBackward (entries : ℕ → SessionTreeEntry) (startIndex : ℕ) (cutIndex : ℕ) : ℕ :=
if startIndex < cutIndex then
match entries (cutIndex - 1) with
| .compaction => cutIndex
| .message _ => cutIndex
| _ => snapBackward entries startIndex (cutIndex - 1)
else cutIndex
termination_by cutIndex
decreasing_by omega
/-
The backward snap never moves the cut later.
-/
theorem snapBackward_le (entries : ℕ → SessionTreeEntry) (startIndex cutIndex : ℕ) :
snapBackward entries startIndex cutIndex ≤ cutIndex := by
induction' cutIndex using Nat.strong_induction_on with cutIndex ih;
unfold snapBackward;
rcases h : entries ( cutIndex - 1 ) with ( _ | _ | _ | _ | _ | _ | _ | _ | _ | _ | _ ) <;> simp_all +decide;
all_goals split_ifs <;> [ exact le_trans ( ih _ ( Nat.sub_lt ( Nat.pos_of_ne_zero ( by aesop_cat ) ) zero_lt_one ) ) ( Nat.sub_le _ _ ) ; exact le_rfl ] ;
/-
Starting at or above `startIndex`, the backward snap never moves the cut below
`startIndex`.
-/
theorem snapBackward_ge (entries : ℕ → SessionTreeEntry) (startIndex cutIndex : ℕ)
(h : startIndex ≤ cutIndex) :
startIndex ≤ snapBackward entries startIndex cutIndex := by
induction' n : cutIndex - startIndex using Nat.strong_induction_on with n ih generalizing startIndex cutIndex;
unfold snapBackward;
split_ifs;
· rcases h : entries ( cutIndex - 1 ) with ( _ | _ | _ | _ | _ | _ | _ | _ | _ | _ | _ ) <;> simp_all +decide;
all_goals exact ih _ ( by omega ) _ _ ( Nat.le_sub_one_of_lt ‹_› ) rfl;
· linarith
/-
The backward snap can never land on a tool-result message: it only steps over
non-`message` entries, so it preserves the non-`toolResult` status of the cut.
-/
theorem snapBackward_not_toolResult (entries : ℕ → SessionTreeEntry) (startIndex cutIndex : ℕ)
(h : isToolResultMessage (entries cutIndex) = false) :
isToolResultMessage (entries (snapBackward entries startIndex cutIndex)) = false := by
induction' n : cutIndex - startIndex using Nat.strong_induction_on with n ih generalizing startIndex cutIndex;
unfold snapBackward;
split_ifs;
· rcases h : entries ( cutIndex - 1 ) with ( _ | _ | _ | _ | _ | _ | _ | _ | _ | _ | _ ) <;> simp_all +decide;
all_goals exact ih _ ( by omega ) _ _ ( by aesop ) rfl;
· assumption
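/- Illustration only, on a hypothetical history in which a `compaction` (2) is
followed by two metadata entries (3, 4) before an assistant message (5). Snapping
back from 5 absorbs the metadata and stops just after the compaction. A cut whose
predecessor is a message does not move, and the snap never goes below
`startIndex`. -/
private def demoSnapHistory : ℕ → SessionTreeEntry
  | 0 => .message ⟨.user⟩
  | 1 => .message ⟨.assistant⟩
  | 2 => .compaction
  | 3 => .modelChange
  | 4 => .label
  | 5 => .message ⟨.assistant⟩
  | _ => .leaf

#guard snapBackward demoSnapHistory 0 5 == 3
#guard snapBackward demoSnapHistory 0 2 == 2
#guard snapBackward demoSnapHistory 4 5 == 4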
/-- Whether the entry at index `i` is a `user` message. -/
def isUserMsg (entries : ℕ → SessionTreeEntry) (i : ℕ) : Bool :=
match entries i with
| .message ⟨.user⟩ => true
| _ => false
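/-- Walk back from the end accumulating an (opaque) token estimate until the
budget is reached, snap forward to the first valid cut point at or after that
index, then snap backward over metadata entries. When the cut is not at a `user`
message and an in-range turn starter exists at or before it, report that turn
start and flag the turn as split. With no valid cut points, return `startIndex`
and no turn start. -/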
def findCutPoint (entries : ℕ → SessionTreeEntry) (estTok : AgentMessage → ℕ)
(startIndex endIndex keepRecentTokens : ℕ) : CutPointResult :=
let cutPoints := findValidCutPoints entries startIndex endIndex
if cutPoints = [] then
⟨startIndex, none, false⟩
else
let c0 := forwardScan entries estTok cutPoints startIndex endIndex keepRecentTokens
let cutIndex := snapBackward entries startIndex c0
let isUser := isUserMsg entries cutIndex
let tsi := findTurnStartIndex entries cutIndex startIndex
let turnStartIndex := if isUser then none else tsi
⟨cutIndex, turnStartIndex, (!isUser) && turnStartIndex.isSome⟩
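/- Illustration only, with a hypothetical flat cost of 10 tokens per message and
the six-entry history below. A budget of 15 lands the cut on the `user` message
at 4, so no turn is split. A budget of 25 lands it on the assistant message at 3,
mid-turn, so the result reports turn start 0 and `isSplitTurn = true`. The range
`[2, 3)` holds only a tool result, so it has no valid cut point and the
degenerate `startIndex` result is returned. `demoCutView` is a hypothetical
helper that flattens a result into a tuple for comparison. -/
private def demoSelectorHistory : ℕ → SessionTreeEntry
  | 0 => .message ⟨.user⟩
  | 1 => .message ⟨.assistant⟩
  | 2 => .message ⟨.toolResult⟩
  | 3 => .message ⟨.assistant⟩
  | 4 => .message ⟨.user⟩
  | 5 => .message ⟨.assistant⟩
  | _ => .leaf

private def demoCutView (r : CutPointResult) : ℕ × Option ℕ × Bool :=
  (r.firstKeptEntryIndex, r.turnStartIndex, r.isSplitTurn)

#guard demoCutView (findCutPoint demoSelectorHistory (fun _ => 10) 0 6 15) == (4, none, false)
#guard demoCutView (findCutPoint demoSelectorHistory (fun _ => 10) 0 6 25) == (3, some 0, true)
#guard demoCutView (findCutPoint demoSelectorHistory (fun _ => 10) 2 3 25) == (2, none, false)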
/-
**The snap can't undo safety.** `firstKeptEntryIndex` is in `[startIndex,
endIndex)` and is not a tool-result message — or, in the degenerate no-cut-point
case, equals `startIndex`.
-/
theorem findCutPoint_firstKept_spec (entries : ℕ → SessionTreeEntry) (estTok : AgentMessage → ℕ)
(startIndex endIndex keepRecentTokens : ℕ) :
∀ r, r = findCutPoint entries estTok startIndex endIndex keepRecentTokens →
(startIndex ≤ r.firstKeptEntryIndex ∧ r.firstKeptEntryIndex < endIndex ∧
isToolResultMessage (entries r.firstKeptEntryIndex) = false)
∨ r.firstKeptEntryIndex = startIndex := by
by_cases h : findValidCutPoints entries startIndex endIndex = [] <;> simp_all +decide [ findCutPoint ];
refine' Or.inl ⟨ _, _, _ ⟩;
· apply snapBackward_ge;
have := forwardScan_mem entries estTok ( findValidCutPoints entries startIndex endIndex ) startIndex endIndex keepRecentTokens h; exact findValidCutPoints_mem_range entries startIndex endIndex this |>.1;
· refine' lt_of_le_of_lt ( snapBackward_le _ _ _ ) _;
exact findValidCutPoints_mem_range _ _ _ ( forwardScan_mem _ _ _ _ _ _ h ) |>.2;
· apply snapBackward_not_toolResult;
exact findValidCutPoints_not_toolResult _ _ _ ( forwardScan_mem _ _ _ _ _ _ h )
/-
**No retained `toolResult` is orphaned.** Under the ordering precondition `hpre`
(within range, a tool result is immediately preceded by a message), every tool
result in the kept suffix has its immediately preceding message retained too,
or else the cut lies at `startIndex` (the second disjunct).
-/
theorem findCutPoint_noOrphan_spec (entries : ℕ → SessionTreeEntry) (estTok : AgentMessage → ℕ)
(startIndex endIndex keepRecentTokens : ℕ)
(hpre : ∀ j, startIndex < j → j < endIndex → isToolResultMessage (entries j) = true →
isMessage (entries (j - 1)) = true) :
∀ r, r = findCutPoint entries estTok startIndex endIndex keepRecentTokens →
(∀ j, r.firstKeptEntryIndex ≤ j → j < endIndex → isToolResultMessage (entries j) = true →
r.firstKeptEntryIndex ≤ j - 1 ∧ isMessage (entries (j - 1)) = true)
∨ r.firstKeptEntryIndex = startIndex := by
intro r hr;
-- Either the cut is at `startIndex` (right disjunct), or
-- `findCutPoint_firstKept_spec` puts it in range and off any tool-result message.
by_cases h_cut : r.firstKeptEntryIndex = startIndex;
· exact Or.inr h_cut;
· have h_cut_range : startIndex ≤ r.firstKeptEntryIndex ∧ r.firstKeptEntryIndex < endIndex ∧ isToolResultMessage (entries r.firstKeptEntryIndex) = false := by
exact Or.resolve_right ( findCutPoint_firstKept_spec entries estTok startIndex endIndex keepRecentTokens r hr ) h_cut;
grind
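/- Illustration only: why `hpre` is needed. In the hypothetical history below a
`label` sits between the assistant message at 1 and the tool result at 3, so
`hpre` fails at `j = 3`. With a flat (hypothetical) cost of 10 tokens per message
and a budget of 30, the cut lands at index 1, which is not `startIndex`, yet the
entry just before the retained tool result is not a message; both disjuncts of
the theorem's conclusion therefore fail on this input. -/
private def demoUnorderedHistory : ℕ → SessionTreeEntry
  | 0 => .message ⟨.user⟩
  | 1 => .message ⟨.assistant⟩
  | 2 => .label
  | 3 => .message ⟨.toolResult⟩
  | 4 => .message ⟨.assistant⟩
  | _ => .leaf

#guard (findCutPoint demoUnorderedHistory (fun _ => 10) 0 5 30).firstKeptEntryIndex == 1
#guard isToolResultMessage (demoUnorderedHistory 3) && !isMessage (demoUnorderedHistory 2)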
/-
**A split turn names a real boundary.** When the cut falls mid-turn, the
reported `turnStartIndex` is an in-range turn boundary at or before the cut.
-/
theorem findCutPoint_splitTurn_spec (entries : ℕ → SessionTreeEntry) (estTok : AgentMessage → ℕ)
(startIndex endIndex keepRecentTokens : ℕ) :
∀ r, r = findCutPoint entries estTok startIndex endIndex keepRecentTokens →
r.isSplitTurn = true →
∃ t, r.turnStartIndex = some t ∧ startIndex ≤ t ∧ t ≤ r.firstKeptEntryIndex ∧
isTurnStarter (entries t) = true := by
by_cases h : findValidCutPoints entries startIndex endIndex = [] <;> simp_all +decide [ findCutPoint ];
exact fun h1 h2 => by obtain ⟨ t, ht ⟩ := Option.isSome_iff_exists.mp h2; exact ⟨ t, ht, findTurnStartIndex_spec _ _ _ ht |>.1, findTurnStartIndex_spec _ _ _ ht |>.2.1, findTurnStartIndex_spec _ _ _ ht |>.2.2 ⟩ ;
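/- Optional axiom audit (an addition, not part of the Dafny port). `#print axioms`
reports, as an informational message that does not fail the build, every axiom a
declaration depends on. The accompanying build report states that only
`propext`, `Classical.choice` and `Quot.sound` appear. -/
#print axioms findCutPoint_firstKept_spec
#print axioms findCutPoint_noOrphan_spec
#print axioms findCutPoint_splitTurn_spec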
end LemmaScript
Here is a draft Lean 4 stub for LemmaScript.

Formalized the LemmaScript "pi" context-compaction cut-point selector (the Dafny-verified case study in the uploaded `midspiral-pi-lemmascript` repo) in Lean 4, in a new file `RequestProject/Compaction.lean`. The full project builds with no `sorry`, no extra axioms (only `propext`, `Classical.choice`, `Quot.sound`), and no linter warnings.

What was modeled and proved (namespace `LemmaScript`):

* Data model mirroring the Dafny datatypes: `Role`, `AgentMessage`, `SessionTreeEntry` (all 11 constructors kept), `CutPointResult`, and the predicates `isToolResultMessage`, `isTurnStarter`, `isMessage`, `isValidCutEntry`. A session history is modeled as a total index function `entries : ℕ → SessionTreeEntry`, and the opaque `estimateTokens` as an arbitrary parameter. Both choices, and the dropped `< entries.length` conjunct, are documented in the file header.
* The three selector methods were ported as Lean functions, and their Dafny `ensures` clauses were proved as theorems:
  * `findValidCutPoints`, with `findValidCutPoints_mem_range` (every returned index is in `[startIndex, endIndex)`) and `findValidCutPoints_not_toolResult` (no returned index is a tool-result message: the "no orphan at the cut" property).
  * `findTurnStartIndex`, with `findTurnStartIndex_spec` (a returned index is in range and is a genuine turn starter).
  * `findCutPoint` (token-accumulation forward scan `forwardScan`, metadata backward snap `snapBackward`), with its three post-conditions:
    * `findCutPoint_firstKept_spec`: the chosen `firstKeptEntryIndex` is in range and not a tool-result message (the snap can't undo safety), or equals `startIndex` in the degenerate no-cut-point case;
    * `findCutPoint_noOrphan_spec`: under the explicit ordering precondition, no retained tool result is orphaned, or else the cut lies at `startIndex`;
    * `findCutPoint_splitTurn_spec`: a split turn always names a real in-range turn boundary at or before the cut.
* Supporting lemmas (`forwardScan_mem`, `snapBackward_le`/`_ge`/`_not_toolResult`, `firstCutGe_mem_or`, `not_toolResult_of_validCut`) were also proved.