
a stab at lean4 #145


@jmikedupont2

Here is a draft Lean 4 stub for LemmaScript.
This formalizes the LemmaScript "pi" context-compaction cut-point selector (the Dafny-verified case study in the uploaded midspiral-pi-lemmascript repo) in Lean 4, in a new file RequestProject/Compaction.lean. The full project builds with no sorry, no axioms beyond the standard three (propext, Classical.choice, Quot.sound), and no linter warnings.
What was modeled and proved (namespace LemmaScript):
Data model mirroring the Dafny datatypes: Role, AgentMessage, SessionTreeEntry (all 11 constructors kept), CutPointResult, and the predicates isToolResultMessage, isTurnStarter, isMessage and isValidCutEntry. A session history is modeled as a total index function entries : ℕ → SessionTreeEntry, and the opaque estimateTokens as an arbitrary parameter. Both choices, and the dropped turnStartIndex < entries.length conjunct of the Dafny post-condition, are documented in the file header.
The three selector methods were ported as Lean functions and their Dafny ensures clauses proved as theorems:

  • findValidCutPoints, with findValidCutPoints_mem_range (every returned index is in [startIndex, endIndex)) and findValidCutPoints_not_toolResult (no returned index is a tool-result message: the "no orphan at the cut" property).
  • findTurnStartIndex, with findTurnStartIndex_spec (a returned index is in range and is a genuine turn starter).
  • findCutPoint, built from a token-accumulation scan (forwardScan) and a backward snap over metadata entries (snapBackward), with its three post-conditions:
    • findCutPoint_firstKept_spec: the chosen firstKeptEntryIndex is in range and not a tool-result message (the snap can't undo safety), or equals startIndex in the degenerate no-cut-point case;
    • findCutPoint_noOrphan_spec: under the explicit ordering precondition (within range, every tool result is immediately preceded by a message), every retained tool result's preceding message is retained too, or the cut is at startIndex;
    • findCutPoint_splitTurn_spec: a split turn always names a real in-range turn boundary at or before the cut.
  • Supporting lemmas (forwardScan_mem, snapBackward_le/_ge/_not_toolResult, firstCutGe_mem_or, not_toolResult_of_validCut) were also proved.
import Mathlib
/-!
# A Lean 4 formalization of the LemmaScript "pi" context-compaction cut-point selector
This file ports the Dafny-verified compaction cut-point selector from the
`midspiral/pi-lemmascript` case study to Lean 4, reproducing the three methods
of the selector and proving their `ensures` (post-condition) clauses.
When pi's context window fills, history is compacted: a cut point is chosen and
everything before it is discarded.  A provider API rejects a retained message
list that begins with an *orphaned* `toolResult` (a tool result whose preceding
tool-use turn was dropped).  The selector must therefore never
cut so that the kept suffix begins with a tool result, and never split a
tool-use / tool-result run.  These are the properties proven below; the second
is proven under an explicit ordering precondition on the history.
## Modeling choices
* A session history `seq<SessionTreeEntry>` is modeled as a *total* function
  `entries : ℕ → SessionTreeEntry`.  The Dafny code only ever accesses indices
  inside `[startIndex, endIndex)` and carries `endIndex ≤ |entries|` as a
  precondition purely to keep those accesses in bounds; with a total index
  function the bounds obligations vanish while the mathematical content is
  unchanged.  Consequently the `\result.turnStartIndex < entries.length`
  conjunct of the Dafny `findCutPoint` post-condition (a pure in-bounds fact)
  has no counterpart here.
* `SessionTreeEntry` keeps all eleven Dafny constructors, but their payloads —
  which the selector never inspects beyond a `message`'s `role` — are dropped.
* The opaque `estimateTokens` becomes an arbitrary parameter `estTok`; as in the
  Dafny proof, none of the safety properties depend on its value.
* The `-1` integer sentinel for "no turn start" becomes `Option.none`.
-/
open scoped BigOperators
namespace LemmaScript
/-- Message roles. Mirrors the Dafny `Role` enum. -/
inductive Role
  | bashExecution
  | custom
  | branchSummary
  | compactionSummary
  | user
  | assistant
  | toolResult
deriving DecidableEq
/-- A chat message, carrying only the field the selector reads. -/
structure AgentMessage where
  role : Role
deriving DecidableEq
/-- Session-tree entries.  All Dafny constructors are kept; irrelevant payloads
are dropped, except a `message` carries its `AgentMessage`. -/
inductive SessionTreeEntry
  | message (msg : AgentMessage)
  | thinkingLevelChange
  | modelChange
  | activeToolsChange
  | compaction
  | branchSummary
  | custom
  | customMessage
  | label
  | sessionInfo
  | leaf
/-- The result of `findCutPoint`. Mirrors the Dafny `CutPointResult`, with the
`-1` turn-start sentinel replaced by `Option.none`. -/
structure CutPointResult where
  firstKeptEntryIndex : ℕ
  turnStartIndex : Option ℕ
  isSplitTurn : Bool
/-- An entry is a tool-result message. -/
def isToolResultMessage : SessionTreeEntry → Bool
  | .message ⟨.toolResult⟩ => true
  | _ => false
/-- An entry starts a turn: a branch summary, a custom message, or a
`user`/`bashExecution` message. -/
def isTurnStarter : SessionTreeEntry → Bool
  | .branchSummary => true
  | .customMessage => true
  | .message ⟨.user⟩ => true
  | .message ⟨.bashExecution⟩ => true
  | _ => false
/-- An entry is a `message`. -/
def isMessage : SessionTreeEntry → Bool
  | .message _ => true
  | _ => false
/-- An entry is a *valid cut point*: a non-`toolResult` message, a branch
summary, or a custom message.  (This is exactly the set of indices pushed by the
`findValidCutPoints` switch + trailing `if`.) -/
def isValidCutEntry : SessionTreeEntry → Bool
  | .message ⟨.toolResult⟩ => false
  | .message _ => true
  | .branchSummary => true
  | .customMessage => true
  | _ => false
/-
A valid cut entry is never a tool-result message.
-/
theorem not_toolResult_of_validCut {e : SessionTreeEntry}
    (h : isValidCutEntry e = true) : isToolResultMessage e = false := by
  cases e <;> simp_all +decide [ isValidCutEntry ];
  rename_i msg; rcases msg with ⟨ _ | _ | _ | _ | _ | _ | _ ⟩ <;> tauto;
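-- Sanity check (an illustrative addition, not part of the original port): how the
-- three predicates classify a few sample entries. Each triple is
-- `(isToolResultMessage, isTurnStarter, isValidCutEntry)`.
#eval ([.message ⟨.user⟩, .message ⟨.toolResult⟩, .branchSummary, .modelChange] :
    List SessionTreeEntry).map fun e => (isToolResultMessage e, isTurnStarter e, isValidCutEntry e)
-- expected: [(false, true, true), (true, false, false), (false, true, true), (false, false, false)]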
/-! ## `findValidCutPoints` -/
/-- Scan `entries[startIndex, endIndex)` and return the indices that are safe
places to start a retained suffix (in increasing order). -/
def findValidCutPoints (entries : ℕ → SessionTreeEntry) (startIndex endIndex : ℕ) : List ℕ :=
  (List.range' startIndex (endIndex - startIndex)).filter (fun i => isValidCutEntry (entries i))
/-
**In range.** Every returned index lies in `[startIndex, endIndex)`.
-/
theorem findValidCutPoints_mem_range (entries : ℕ → SessionTreeEntry) (startIndex endIndex : ℕ)
    {x : ℕ} (hx : x ∈ findValidCutPoints entries startIndex endIndex) :
    startIndex ≤ x ∧ x < endIndex := by
  unfold findValidCutPoints at hx;
  grind
/-
**No orphan at the cut.** No returned index points at a tool-result message,
so the retained suffix can never *begin* with an orphaned tool result.
-/
theorem findValidCutPoints_not_toolResult (entries : ℕ → SessionTreeEntry)
    (startIndex endIndex : ℕ) {x : ℕ}
    (hx : x ∈ findValidCutPoints entries startIndex endIndex) :
    isToolResultMessage (entries x) = false := by
  exact not_toolResult_of_validCut ( List.mem_filter.mp hx |>.2 )
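-- Illustrative check on a made-up history (not taken from the Dafny case study):
-- indices 0, 1 and 4 are non-tool-result messages, index 2 is a tool result and
-- index 3 is metadata, so only 0, 1 and 4 are valid cut points. Restricting the
-- range to `[1, 4)` leaves only index 1.
#eval
  let es : ℕ → SessionTreeEntry := fun i =>
    ([.message ⟨.user⟩, .message ⟨.assistant⟩, .message ⟨.toolResult⟩, .modelChange,
      .message ⟨.assistant⟩] : List SessionTreeEntry).getD i .leaf
  (findValidCutPoints es 0 5, findValidCutPoints es 1 4)
-- expected: ([0, 1, 4], [1])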
/-! ## `findTurnStartIndex` -/
/-- Walk back from `entryIndex` to `startIndex` (both inclusive) and return the
index of the first turn-starting entry encountered, or `none` if there is none. -/
def findTurnStartIndex (entries : ℕ → SessionTreeEntry) (entryIndex startIndex : ℕ) : Option ℕ :=
  ((List.range' startIndex (entryIndex + 1 - startIndex)).reverse).find?
    (fun i => isTurnStarter (entries i))
/-
Post-condition of `findTurnStartIndex`: a returned index is in range and is a
genuine turn starter.
-/
theorem findTurnStartIndex_spec (entries : ℕ → SessionTreeEntry) (entryIndex startIndex : ℕ)
    {r : ℕ} (hr : findTurnStartIndex entries entryIndex startIndex = some r) :
    startIndex ≤ r ∧ r ≤ entryIndex ∧ isTurnStarter (entries r) = true := by
  have := List.find?_some hr;
  have := List.mem_of_find?_eq_some hr; simp_all +decide [ List.mem_reverse, List.mem_range' ] ;
  omega
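-- Illustrative check on a made-up history: scanning back from index 4 reaches the
-- `user` message at index 0. If the scan must stop at index 1, no turn starter is
-- found and the result is `none`, which replaces the Dafny `-1` sentinel.
#eval
  let es : ℕ → SessionTreeEntry := fun i =>
    ([.message ⟨.user⟩, .message ⟨.assistant⟩, .message ⟨.toolResult⟩, .modelChange,
      .message ⟨.assistant⟩] : List SessionTreeEntry).getD i .leaf
  (findTurnStartIndex es 4 0, findTurnStartIndex es 4 1)
-- expected: (some 0, none)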
/-! ## `findCutPoint` -/
/-- The first element of `cutPoints` that is `≥ i`, falling back to `dflt`.
Models the inner `for c` loop of `findCutPoint`. -/
def firstCutGe (cutPoints : List ℕ) (i dflt : ℕ) : ℕ :=
  (cutPoints.find? (fun c => decide (i ≤ c))).getD dflt
/-
`firstCutGe` returns either a member of `cutPoints` or the default.
-/
theorem firstCutGe_mem_or (cutPoints : List ℕ) (i dflt : ℕ) :
    firstCutGe cutPoints i dflt ∈ cutPoints ∨ firstCutGe cutPoints i dflt = dflt := by
  unfold firstCutGe;
  grind +suggestions
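-- Illustrative check: with cut points `[0, 1, 4]`, the first one at or after index 2
-- is 4. No cut point is at or after index 5, so the default (here 0) is returned.
#eval (firstCutGe [0, 1, 4] 2 0, firstCutGe [0, 1, 4] 5 0)
-- expected: (4, 0)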
/-- One step of the backward token-accumulation scan of `findCutPoint`.
State is `(found cut index?, accumulated tokens)`. -/
def fwdStep (entries : ℕ → SessionTreeEntry) (estTok : AgentMessage → ℕ) (cutPoints : List ℕ)
    (dflt keepRecentTokens : ℕ) : (Option ℕ × ℕ) → ℕ → (Option ℕ × ℕ)
  | (some r, acc), _ => (some r, acc)
  | (none, acc), i =>
      match entries i with
      | .message msg =>
          let acc' := acc + estTok msg
          if keepRecentTokens ≤ acc' then (some (firstCutGe cutPoints i dflt), acc')
          else (none, acc')
      | _ => (none, acc)
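/-- The token-accumulation scan.  Despite the name `forwardScan`, it walks indices
from `endIndex - 1` down to `startIndex`, accumulating message tokens; once the
budget is reached it snaps to the first cut point at or after the current index.
If no such cut point exists, or the budget is never reached, it falls back to
`cutPoints[0]` (`startIndex` when `cutPoints` is empty). -/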
def forwardScan (entries : ℕ → SessionTreeEntry) (estTok : AgentMessage → ℕ)
    (cutPoints : List ℕ) (startIndex endIndex keepRecentTokens : ℕ) : ℕ :=
  let dflt := cutPoints.headD startIndex
  let indices := (List.range' startIndex (endIndex - startIndex)).reverse
  ((indices.foldl (fwdStep entries estTok cutPoints dflt keepRecentTokens) (none, 0)).1).getD dflt
/-
The chosen index of the forward scan is always one of the valid cut points.
-/
theorem forwardScan_mem (entries : ℕ → SessionTreeEntry) (estTok : AgentMessage → ℕ)
    (cutPoints : List ℕ) (startIndex endIndex keepRecentTokens : ℕ) (h : cutPoints ≠ []) :
    forwardScan entries estTok cutPoints startIndex endIndex keepRecentTokens ∈ cutPoints := by
  unfold forwardScan; simp +decide ;
  induction' ( List.range' startIndex ( endIndex - startIndex ) ) with x xs ih <;> simp +decide [ * ] at *;
  · cases cutPoints <;> aesop;
  · cases h : List.foldr ( fun x y => fwdStep entries estTok cutPoints ( cutPoints.head?.getD startIndex ) keepRecentTokens y x ) ( none, 0 ) xs ; simp +decide [ h ] at *;
    cases ‹Option ℕ› <;> simp_all +decide [ fwdStep ];
    cases h : entries x <;> simp_all +decide;
    exact Classical.not_not.1 fun h' => by have := firstCutGe_mem_or cutPoints x ( cutPoints.head?.getD startIndex ) ; aesop;
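-- Illustrative trace, assuming (purely for the example) that every message costs one
-- token. On the made-up history below the valid cut points are `[0, 1, 4]`.
-- Budget 2 is reached at index 2 (a tool result), which snaps forward to cut point 4;
-- budget 3 is reached at index 1, itself a cut point; budget 10 is never reached, so
-- the scan falls back to the first cut point, 0.
#eval
  let es : ℕ → SessionTreeEntry := fun i =>
    ([.message ⟨.user⟩, .message ⟨.assistant⟩, .message ⟨.toolResult⟩, .modelChange,
      .message ⟨.assistant⟩] : List SessionTreeEntry).getD i .leaf
  [2, 3, 10].map fun k => forwardScan es (fun _ => 1) (findValidCutPoints es 0 5) 0 5 k
-- expected: [4, 1, 0]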
/-- The backward metadata snap: move the cut earlier while the preceding entry is
neither a `compaction` nor a `message`. -/
def snapBackward (entries : ℕ → SessionTreeEntry) (startIndex : ℕ) (cutIndex : ℕ) : ℕ :=
  if startIndex < cutIndex then
    match entries (cutIndex - 1) with
    | .compaction => cutIndex
    | .message _ => cutIndex
    | _ => snapBackward entries startIndex (cutIndex - 1)
  else cutIndex
termination_by cutIndex
decreasing_by omega
/-- The backward snap never moves the cut later. -/
theorem snapBackward_le (entries : ℕ → SessionTreeEntry) (startIndex cutIndex : ℕ) :
    snapBackward entries startIndex cutIndex ≤ cutIndex := by
  induction' cutIndex using Nat.strong_induction_on with cutIndex ih;
  unfold snapBackward;
  rcases h : entries ( cutIndex - 1 ) with ( _ | _ | _ | _ | _ | _ | _ | _ | _ | _ | _ ) <;> simp_all +decide;
  all_goals split_ifs <;> [ exact le_trans ( ih _ ( Nat.sub_lt ( Nat.pos_of_ne_zero ( by aesop_cat ) ) zero_lt_one ) ) ( Nat.sub_le _ _ ) ; exact le_rfl ] ;
/-- The backward snap never moves the cut before `startIndex`. -/
theorem snapBackward_ge (entries : ℕ → SessionTreeEntry) (startIndex cutIndex : ℕ)
    (h : startIndex ≤ cutIndex) :
    startIndex ≤ snapBackward entries startIndex cutIndex := by
  induction' n : cutIndex - startIndex using Nat.strong_induction_on with n ih generalizing startIndex cutIndex;
  unfold snapBackward;
  split_ifs;
  · rcases h : entries ( cutIndex - 1 ) with ( _ | _ | _ | _ | _ | _ | _ | _ | _ | _ | _ ) <;> simp_all +decide;
    all_goals exact ih _ ( by omega ) _ _ ( Nat.le_sub_one_of_lt ‹_› ) rfl;
  · linarith
/-
The backward snap can never land on a tool-result message: it only steps over
non-`message` entries, so it preserves the non-`toolResult` status of the cut.
-/
theorem snapBackward_not_toolResult (entries : ℕ → SessionTreeEntry) (startIndex cutIndex : ℕ)
    (h : isToolResultMessage (entries cutIndex) = false) :
    isToolResultMessage (entries (snapBackward entries startIndex cutIndex)) = false := by
  induction' n : cutIndex - startIndex using Nat.strong_induction_on with n ih generalizing startIndex cutIndex;
  unfold snapBackward;
  split_ifs;
  · rcases h : entries ( cutIndex - 1 ) with ( _ | _ | _ | _ | _ | _ | _ | _ | _ | _ | _ ) <;> simp_all +decide;
    all_goals exact ih _ ( by omega ) _ _ ( by aesop ) rfl;
  · assumption
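-- Illustrative check on a made-up history in which two metadata entries
-- (`modelChange`, `label`) sit between an `assistant` message and a `user` message.
-- From cut 4 the snap steps back over both and stops at 2, just after the message at
-- index 1; from cut 4 with `startIndex = 3` it stops at the bound; from cut 1 it does
-- not move, because the preceding entry is a message. Index 2 is not itself a valid
-- cut entry: the snap only preserves "not a tool result", as
-- `snapBackward_not_toolResult` states.
#eval
  let es : ℕ → SessionTreeEntry := fun i =>
    ([.message ⟨.user⟩, .message ⟨.assistant⟩, .modelChange, .label,
      .message ⟨.user⟩] : List SessionTreeEntry).getD i .leaf
  [(0, 4), (3, 4), (0, 1)].map fun (s, c) => snapBackward es s c
-- expected: [2, 3, 1]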
/-- Whether the entry at index `i` is a `user` message. -/
def isUserMsg (entries : ℕ → SessionTreeEntry) (i : ℕ) : Bool :=
  match entries i with
  | .message ⟨.user⟩ => true
  | _ => false
/-- Walk back from the end accumulating an (opaque) token estimate until the
budget is reached, snap to the first valid cut point at or after that index, then
snap backward over metadata entries; report the turn start when the cut splits a turn. -/
def findCutPoint (entries : ℕ → SessionTreeEntry) (estTok : AgentMessage → ℕ)
    (startIndex endIndex keepRecentTokens : ℕ) : CutPointResult :=
  let cutPoints := findValidCutPoints entries startIndex endIndex
  if cutPoints = [] then
    ⟨startIndex, none, false⟩
  else
    let c0 := forwardScan entries estTok cutPoints startIndex endIndex keepRecentTokens
    let cutIndex := snapBackward entries startIndex c0
    let isUser := isUserMsg entries cutIndex
    let tsi := findTurnStartIndex entries cutIndex startIndex
    let turnStartIndex := if isUser then none else tsi
    ⟨cutIndex, turnStartIndex, (!isUser) && turnStartIndex.isSome⟩
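-- Illustrative end-to-end run, assuming (only for the example) one token per message,
-- on a made-up history: user, assistant, modelChange, label, user. Each triple is
-- `(firstKeptEntryIndex, turnStartIndex, isSplitTurn)` for budgets 1, 2 and 10.
-- With budget 1 the scan picks cut point 4 and the metadata snap moves it back to 2;
-- with budget 2 it cuts at the assistant message at index 1. Both cuts fall inside
-- the turn started by the user message at 0. With budget 10 the whole range is kept
-- and no turn is split.
#eval
  let es : ℕ → SessionTreeEntry := fun i =>
    ([.message ⟨.user⟩, .message ⟨.assistant⟩, .modelChange, .label,
      .message ⟨.user⟩] : List SessionTreeEntry).getD i .leaf
  [1, 2, 10].map fun k =>
    let r := findCutPoint es (fun _ => 1) 0 5 k
    (r.firstKeptEntryIndex, r.turnStartIndex, r.isSplitTurn)
-- expected: [(2, some 0, true), (1, some 0, true), (0, none, false)]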
/-
**The snap can't undo safety.** `firstKeptEntryIndex` is in `[startIndex,
endIndex)` and is not a tool-result message — or, in the degenerate no-cut-point
case, equals `startIndex`.
-/
theorem findCutPoint_firstKept_spec (entries : ℕ → SessionTreeEntry) (estTok : AgentMessage → ℕ)
    (startIndex endIndex keepRecentTokens : ℕ) :
    ∀ r, r = findCutPoint entries estTok startIndex endIndex keepRecentTokens →
      (startIndex ≤ r.firstKeptEntryIndex ∧ r.firstKeptEntryIndex < endIndex ∧
        isToolResultMessage (entries r.firstKeptEntryIndex) = false)
      ∨ r.firstKeptEntryIndex = startIndex := by
  by_cases h : findValidCutPoints entries startIndex endIndex = [] <;> simp_all +decide [ findCutPoint ];
  refine' Or.inl ⟨ _, _, _ ⟩;
  · apply snapBackward_ge;
    have := forwardScan_mem entries estTok ( findValidCutPoints entries startIndex endIndex ) startIndex endIndex keepRecentTokens h; exact findValidCutPoints_mem_range entries startIndex endIndex this |>.1;
  · refine' lt_of_le_of_lt ( snapBackward_le _ _ _ ) _;
    exact findValidCutPoints_mem_range _ _ _ ( forwardScan_mem _ _ _ _ _ _ h ) |>.2;
  · apply snapBackward_not_toolResult;
    exact findValidCutPoints_not_toolResult _ _ _ ( forwardScan_mem _ _ _ _ _ _ h )
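-- Usage sketch (an illustrative corollary, not part of the original port): the
-- `∀ r, r = findCutPoint ... →` form is consumed by passing the call itself and `rfl`.
-- If the entry at `startIndex` is not a tool result (e.g. the range opens with a user
-- message), the first kept entry is never a tool result, even in the degenerate case.
example (entries : ℕ → SessionTreeEntry) (estTok : AgentMessage → ℕ) (s e k : ℕ)
    (hs : isToolResultMessage (entries s) = false) :
    isToolResultMessage (entries (findCutPoint entries estTok s e k).firstKeptEntryIndex) = false := by
  obtain h | h :=
    findCutPoint_firstKept_spec entries estTok s e k (findCutPoint entries estTok s e k) rfl
  · exact h.2.2
  · rw [h]; exact hs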
/-
**No retained `toolResult` is orphaned.** Under the ordering precondition
(within range, a tool result is immediately preceded by a message), either every
tool result in the kept suffix is immediately preceded by a message that is also
kept, or the cut is at `startIndex`.
-/
theorem findCutPoint_noOrphan_spec (entries : ℕ → SessionTreeEntry) (estTok : AgentMessage → ℕ)
    (startIndex endIndex keepRecentTokens : ℕ)
    (hpre : ∀ j, startIndex < j → j < endIndex → isToolResultMessage (entries j) = true →
      isMessage (entries (j - 1)) = true) :
    ∀ r, r = findCutPoint entries estTok startIndex endIndex keepRecentTokens →
      (∀ j, r.firstKeptEntryIndex ≤ j → j < endIndex → isToolResultMessage (entries j) = true →
        r.firstKeptEntryIndex ≤ j - 1 ∧ isMessage (entries (j - 1)) = true)
      ∨ r.firstKeptEntryIndex = startIndex := by
  intro r hr;
  -- Unless the cut index is `startIndex`, `findCutPoint_firstKept_spec` places it in
  -- `[startIndex, endIndex)` and shows it is not a tool result.
  by_cases h_cut : r.firstKeptEntryIndex = startIndex;
  · exact Or.inr h_cut;
  · have h_cut_range : startIndex ≤ r.firstKeptEntryIndex ∧ r.firstKeptEntryIndex < endIndex ∧ isToolResultMessage (entries r.firstKeptEntryIndex) = false := by
      exact Or.resolve_right ( findCutPoint_firstKept_spec entries estTok startIndex endIndex keepRecentTokens r hr ) h_cut;
    grind
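-- Illustrative check (not part of the original port): on a concrete history the
-- ordering precondition `hpre` of `findCutPoint_noOrphan_spec` can be evaluated as a
-- Boolean over `j ∈ (startIndex, endIndex)`. The local helper `check` and both
-- made-up histories are hypothetical. `good` satisfies it; `bad` does not, because
-- its tool result at index 2 follows a `modelChange` entry.
#eval
  let check (es : ℕ → SessionTreeEntry) (s e : ℕ) : Bool :=
    (List.range' (s + 1) (e - (s + 1))).all fun j =>
      !isToolResultMessage (es j) || isMessage (es (j - 1))
  let good : ℕ → SessionTreeEntry := fun i =>
    ([.message ⟨.user⟩, .message ⟨.assistant⟩, .message ⟨.toolResult⟩] :
      List SessionTreeEntry).getD i .leaf
  let bad : ℕ → SessionTreeEntry := fun i =>
    ([.message ⟨.user⟩, .modelChange, .message ⟨.toolResult⟩] :
      List SessionTreeEntry).getD i .leaf
  (check good 0 3, check bad 0 3)
-- expected: (true, false)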
/-
**A split turn names a real boundary.** When the cut falls mid-turn, the
reported `turnStartIndex` is an in-range turn boundary at or before the cut.
-/
theorem findCutPoint_splitTurn_spec (entries : ℕ → SessionTreeEntry) (estTok : AgentMessage → ℕ)
    (startIndex endIndex keepRecentTokens : ℕ) :
    ∀ r, r = findCutPoint entries estTok startIndex endIndex keepRecentTokens →
      r.isSplitTurn = true →
        ∃ t, r.turnStartIndex = some t ∧ startIndex ≤ t ∧ t ≤ r.firstKeptEntryIndex ∧
          isTurnStarter (entries t) = true := by
  by_cases h : findValidCutPoints entries startIndex endIndex = [] <;> simp_all +decide [ findCutPoint ];
  exact fun h1 h2 => by obtain ⟨ t, ht ⟩ := Option.isSome_iff_exists.mp h2; exact ⟨ t, ht, findTurnStartIndex_spec _ _ _ ht |>.1, findTurnStartIndex_spec _ _ _ ht |>.2.1, findTurnStartIndex_spec _ _ _ ht |>.2.2 ⟩ ;
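-- Usage sketch (an illustrative corollary, not part of the original port): a split
-- turn always reports some turn start, read off the existential in
-- `findCutPoint_splitTurn_spec`.
example (entries : ℕ → SessionTreeEntry) (estTok : AgentMessage → ℕ) (s e k : ℕ)
    (h : (findCutPoint entries estTok s e k).isSplitTurn = true) :
    (findCutPoint entries estTok s e k).turnStartIndex.isSome = true := by
  obtain ⟨t, ht, -⟩ :=
    findCutPoint_splitTurn_spec entries estTok s e k (findCutPoint entries estTok s e k) rfl h
  simp [ht]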
end LemmaScript
