forked from leanprover/lean4
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathSplit.lean
More file actions
346 lines (323 loc) · 16.7 KB
/
Split.lean
File metadata and controls
346 lines (323 loc) · 16.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
/-
Copyright (c) 2021 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Lean.Meta.Match.MatcherApp.Basic
import Lean.Meta.Tactic.Simp.Main
import Lean.Meta.Tactic.SplitIf
import Lean.Meta.Tactic.Apply
import Lean.Meta.Tactic.Generalize
namespace Lean.Meta
namespace Split
/--
Builds the `Simp.Context` used by the `match`-simplification helpers in this file:
no simp theorems, the currently registered congruence theorems, and
`Simp.neutralConfig` with `dsimp` switched off.
-/
def getSimpMatchContext : MetaM Simp.Context := do
  let congrThms ← getSimpCongrTheorems
  return {
    simpTheorems  := {}
    congrTheorems := congrThms
    config        := { Simp.neutralConfig with dsimp := false }
  }
/--
Simplifies the matcher applications occurring in `e`.
Runs `Simp.main` with the context from `getSimpMatchContext`, the discharger from
`SplitIf.mkDischarge?`, and `pre` as the pre-visit method; only the first component
of the value returned by `Simp.main` is kept.
-/
def simpMatch (e : Expr) : MetaM Simp.Result := do
  let discharge? ← SplitIf.mkDischarge?
  let ctx ← getSimpMatchContext
  let out ← Simp.main e ctx (methods := { pre, discharge? })
  return out.1
where
  /-- Pre-visit step: leaves non-matcher terms alone, otherwise reduces or rewrites the matcher. -/
  pre (e : Expr) : SimpM Simp.Step := do
    if !(← isMatcherApp e) then
      return Simp.Step.continue
    let matcherDeclName := e.getAppFn.constName!
    -- Reduction of the matcher is attempted before falling back to `Simp.simpMatchCore`.
    if let some e' ← reduceRecMatcher? e then
      return Simp.Step.done { expr := e' }
    else
      Simp.simpMatchCore matcherDeclName e
/--
Applies `simpMatch` to the (metavariable-instantiated) target of `mvarId` and
returns the goal produced by `applySimpResultToTarget`.
-/
def simpMatchTarget (mvarId : MVarId) : MetaM MVarId :=
  mvarId.withContext do
    let goalType ← instantiateMVars (← mvarId.getType)
    let simpResult ← simpMatch goalType
    applySimpResultToTarget mvarId goalType simpResult
/--
Simplifies applications of the matcher `matchDeclName` inside `e`, using only the
single equation theorem `matchEqDeclName` (after first trying to reduce the matcher).
Uses the same context and discharger as `simpMatch`.
-/
private def simpMatchCore (matchDeclName : Name) (matchEqDeclName : Name) (e : Expr) : MetaM Simp.Result := do
  let discharge? ← SplitIf.mkDischarge?
  let ctx ← getSimpMatchContext
  let out ← Simp.main e ctx (methods := { pre, discharge? })
  return out.1
where
  /-- Pre-visit step: only applications of `matchDeclName` are touched. -/
  pre (e : Expr) : SimpM Simp.Step := do
    unless e.isAppOf matchDeclName do
      return .continue
    -- First try to reduce matcher
    if let some e' ← reduceRecMatcher? e then
      return .done { expr := e' }
    -- Try lemma
    let isRfl ← isRflTheorem matchEqDeclName
    let simpTheorem := {
      origin := .decl matchEqDeclName
      proof  := mkConst matchEqDeclName
      rfl    := isRfl
    }
    match (← withReducible <| Simp.tryTheorem? e simpTheorem) with
    | some r => return .done r
    | none   => return .continue
/--
Runs `simpMatchCore` on the target of `mvarId` and replaces the target with the result:
via `replaceTargetEq` when a proof was produced, via `replaceTargetDefEq` otherwise.
-/
private def simpMatchTargetCore (mvarId : MVarId) (matchDeclName : Name) (matchEqDeclName : Name) : MetaM MVarId :=
  mvarId.withContext do
    let goalType ← instantiateMVars (← mvarId.getType)
    let res ← simpMatchCore matchDeclName matchEqDeclName goalType
    if let some prf := res.proof? then
      mvarId.replaceTargetEq res.expr prf
    else
      mvarId.replaceTargetDefEq res.expr
/--
For every index `i < lhs.size`, introduces a local hypothesis named with a fresh `heq`
user name whose type is `mkEqHEq lhs[i] rhs[i]`, and builds the matching reflexivity
proof for `lhs[i]` (`Eq.refl` when the hypothesis is an `Eq`, `HEq.refl` otherwise).
Then runs `k hs rfls` with the hypotheses and the reflexivity proofs, in index order.
-/
private partial def withEqs (lhs rhs : Array Expr) (k : Array Expr → Array Expr → MetaM α) : MetaM α := do
  go 0 #[] #[]
where
  go (i : Nat) (hs : Array Expr) (rfls : Array Expr) : MetaM α := do
    if i >= lhs.size then
      k hs rfls
    else
      let a := lhs[i]!
      let b := rhs[i]!
      let hName ← mkFreshUserName `heq
      let hType ← mkEqHEq a b
      withLocalDeclD hName hType fun h => do
        let hTy ← inferType h
        let reflPrf ← if hTy.isEq then mkEqRefl a else mkHEqRefl a
        go (i+1) (hs.push h) (rfls.push reflPrf)
/--
  Internal exception for discriminant generalization failures due to type errors.

  It is thrown by `generalizeMatchDiscrs` when the generalized target fails `isTypeCorrect`,
  and recognized with `isDiscrGenException`.
-/
builtin_initialize discrGenExId : InternalExceptionId ←
  registerInternalExceptionId `discrGeneralizationFailure
/--
Returns `true` iff `ex` is an internal exception whose id is `discrGenExId`,
i.e. the discriminant-generalization failure registered in this file.
-/
def isDiscrGenException (ex : Exception) : Bool :=
  if let .internal id := ex then
    id == discrGenExId
  else
    false
/--
  This method makes sure each discriminant is a free variable.
  Return the tuple `(discrsNew, discrEqs, mvarId)`. `discrsNew` is an array representing the new discriminants, `discrEqs` is an array of auxiliary equality hypotheses
  that connect the new discriminants to the original terms they represent.
  Remark: `discrEqs.size ≤ discrsNew.size`

  Remark:
    We should only generalize `discrs` occurrences as `match`-expression discriminants.
    For example, given the following goal.
    ```
    x : Nat
    ⊢ (match g x with
       | 0 => 1
       | Nat.succ y => g x) =
      2 * x + 1
    ```
    we should not generalize the `g x` in the rhs of the second alternative, and the two resulting goals
    for the `split` tactic should be
    ```
    case h_1
    x x✝ : Nat
    h✝ : g x = 0
    ⊢ 1 = 2 * x + 1

    case h_2
    x x✝ y✝ : Nat
    h✝ : g x = Nat.succ y✝
    ⊢ g x = 2 * x + 1
    ```

  Failure modes visible in this definition:
  - throws the internal exception `discrGenExId` when the generalized target is not type correct;
  - throws an ordinary error when no matching matcher application is found in the target,
    or when an alternative has fewer binders than expected.
-/
private partial def generalizeMatchDiscrs (mvarId : MVarId) (matcherDeclName : Name) (motiveType : Expr) (discrs : Array Expr) : MetaM (Array FVarId × Array FVarId × MVarId) := mvarId.withContext do
  -- Fast path: nothing to generalize, and no auxiliary equalities are produced.
  if discrs.all (·.isFVar) then
    return (discrs.map (·.fvarId!), #[], mvarId)
  let some matcherInfo ← getMatcherInfo? matcherDeclName | unreachable!
  let numDiscrEqs := matcherInfo.getNumDiscrEqs -- Number of `h : discr = pattern` equations
  -- Build the generalized target `∀ discrVars eqs, target[discrVars/discrs]` together with the
  -- reflexivity proofs later used to instantiate `eqs`.
  let (targetNew, rfls) ←
    forallTelescope motiveType fun discrVars _ =>
    withEqs discrs discrVars fun eqs rfls => do
      -- Records whether at least one matcher application with the expected discriminants was rewritten.
      let foundRef ← IO.mkRef false
      let rec mkNewTarget (e : Expr) : MetaM Expr := do
        let pre (e : Expr) : MetaM TransformStep := do
          if !e.isAppOfArity matcherDeclName matcherInfo.arity then
            return .continue
          let some matcherApp ← matchMatcherApp? e | return .continue
          -- Only rewrite matcher applications whose discriminants are exactly `discrs`.
          for matcherDiscr in matcherApp.discrs, discr in discrs do
            unless matcherDiscr == discr do
              trace[split.debug] "discr mismatch {matcherDiscr} != {discr}"
              return .continue
          let matcherApp := { matcherApp with discrs := discrVars }
          foundRef.set true
          let mut altsNew := #[]
          for i in [:matcherApp.alts.size] do
            let alt := matcherApp.alts[i]!
            let altNumParams := matcherApp.altNumParams[i]!
            let altNew ← lambdaTelescope alt fun xs body => do
              if xs.size < altNumParams || xs.size < numDiscrEqs then
                throwError "internal error in `split` tactic: encountered an unexpected `match` expression alternative\nthis error typically occurs when the `match` expression has been constructed using meta-programming."
              -- Recurse into the alternative's body, then re-abstract the binders beyond `altNumParams`.
              let body ← mkLambdaFVars xs[altNumParams:] (← mkNewTarget body)
              -- `ys`: the alternative's parameters that precede its `numDiscrEqs` trailing parameters.
              let ys  := xs[:altNumParams - numDiscrEqs]
              if numDiscrEqs == 0 then
                mkLambdaFVars ys body
              else
                -- The last `numDiscrEqs` parameters are the `h : discr = pattern` hypotheses; they are
                -- replaced by hypotheses stated over the new discriminant variables.
                let altEqs := xs[altNumParams - numDiscrEqs : altNumParams]
                withNewAltEqs matcherInfo eqs altEqs fun altEqsNew subst => do
                  let body := body.replaceFVars altEqs subst
                  mkLambdaFVars (ys++altEqsNew) body
            altsNew := altsNew.push altNew
          return .done { matcherApp with alts := altsNew }.toExpr
        transform (← instantiateMVars e) pre
      let targetNew ← mkNewTarget (← mvarId.getType)
      unless (← foundRef.get) do
        throwError "internal error in `split` tactic: failed to find match-expression discriminants\nthis error typically occurs when the `split` tactic internal functions have been used in a new meta-program"
      let targetNew ← mkForallFVars (discrVars ++ eqs) targetNew
      -- Generalization may produce an ill-typed term; this is reported with the dedicated internal exception.
      unless (← isTypeCorrect targetNew) do
        throw <| Exception.internal discrGenExId
      return (targetNew, rfls)
  let mvarNew ← mkFreshExprSyntheticOpaqueMVar targetNew (← mvarId.getTag)
  trace[split.debug] "targetNew:\n{mvarNew.mvarId!}"
  -- Close the original goal by instantiating the generalized one with the original discriminants
  -- and the reflexivity proofs.
  mvarId.assign (mkAppN (mkAppN mvarNew discrs) rfls)
  -- Introduce the new discriminants, then the same number of equality hypotheses.
  let (discrs', mvarId')  ← mvarNew.mvarId!.introNP discrs.size
  let (discrEqs, mvarId') ← mvarId'.introNP discrs.size
  return (discrs', discrEqs, mvarId')
where
  /--
    - `eqs` are free variables `h_eq : discr = discrVar`. `eqs.size == discrs.size`
    - `altEqs` are free variables of the form `h_altEq : discr = pattern`. `altEqs.size = numDiscrEqs ≤ discrs.size`
    This method executes `k altEqsNew subst` where
    - `altEqsNew` are fresh free variables of the form `h_altEqNew : discrVar = pattern`
    - `subst` are terms of the form `h_eq.trans h_altEqNew : discr = pattern`. We use `subst` later to replace occurrences of `h_altEq` with `h_eq.trans h_altEqNew`.
  -/
  withNewAltEqs (matcherInfo : MatcherInfo) (eqs : Array Expr) (altEqs : Array Expr) (k : Array Expr → Array Expr → MetaM Expr) : MetaM Expr := do
    -- Keep only the `eqs` whose discriminant info carries a hypothesis name (`hName?`).
    let eqs' := (eqs.zip matcherInfo.discrInfos).filterMap fun (eq, info) => if info.hName?.isNone then none else some eq
    -- `eqs'.size == altEqs.size ≤ eqs.size`
    let rec go (i : Nat) (altEqsNew : Array Expr) (subst : Array Expr) : MetaM Expr := do
      if i < altEqs.size then
        let altEqDecl ← getFVarLocalDecl altEqs[i]!
        let eq := eqs'[i]!
        let eqType ← inferType eq
        let altEqType := altEqDecl.type
        -- Homogeneous case first (`Eq`), then the heterogeneous one (`HEq`).
        match eqType.eq?, altEqType.eq? with
        | some (_, _, discrVar), some (_, _ /- discr -/, pattern) =>
          withLocalDeclD altEqDecl.userName (← mkEq discrVar pattern) fun altEqNew => do
            go (i+1) (altEqsNew.push altEqNew) (subst.push (← mkEqTrans eq altEqNew))
        | _, _ =>
          match eqType.heq?, altEqType.heq? with
          | some (_, _, _, discrVar), some (_, _ /- discr -/, _, pattern) =>
            withLocalDeclD altEqDecl.userName (← mkHEq discrVar pattern) fun altEqNew => do
              go (i+1) (altEqsNew.push altEqNew) (subst.push (← mkHEqTrans eq altEqNew))
          | _, _ =>
            throwError "internal error in `split` tactic: encountered unexpected auxiliary equalities created to generalize `match`-expression discriminant\nthis error typically occurs when the `split` tactic internal functions have been used in a new meta-program"
      else
        k altEqsNew subst
    go 0 #[] #[]
/--
Tries to eliminate each hypothesis in `discrEqs` from the goal by substitution.

For every id in `discrEqs`, the current substitution is applied; if the result is still a
free variable, the hypothesis is passed through `heqToEq` and `substCore?` is tried with
`symm := false` and then, if that yields nothing, with `symm := true`. On success both the
goal and the substitution are updated; otherwise only the goal returned by `heqToEq` is kept.
Returns the final goal.
-/
private def substDiscrEqs (mvarId : MVarId) (fvarSubst : FVarSubst) (discrEqs : Array FVarId) : MetaM MVarId := mvarId.withContext do
  let mut goal := mvarId
  let mut subst := fvarSubst
  for discrEq in discrEqs do
    if let .fvar eqId := subst.apply (mkFVar discrEq) then
      let (eqId, goal') ← heqToEq goal eqId
      -- Try the non-symmetric direction first, and the symmetric one only if that fails.
      let mut res? ← substCore? goal' eqId (symm := false) subst
      if res?.isNone then
        res? ← substCore? goal' eqId (symm := true) subst
      match res? with
      | some (substNew, goalNew) =>
        goal := goalNew
        subst := substNew
      | none =>
        goal := goal'
  return goal
/--
Core of `applyMatchSplitter`.

Applies the splitter theorem of `matcherDeclName` to `mvarId` and returns the goals that
remain, each paired with the index of the splitter alternative it comes from.
Goals closed by `Cases.unifyEqs?` are omitted, so the returned indices may have gaps;
callers that need per-alternative data (e.g. `matchEqns.eqnNames`) must use the returned
index rather than the position in the list.
-/
private def applyMatchSplitterCore (mvarId : MVarId) (matcherDeclName : Name) (us : Array Level) (params : Array Expr) (discrs : Array Expr) : MetaM (List (Nat × MVarId)) := do
  let some info ← getMatcherInfo? matcherDeclName | throwError "internal error in `split` tactic: `{matcherDeclName}` is not an auxiliary declaration used to encode `match`-expressions\nthis error typically occurs when the `split` tactic internal functions have been used in a new meta-program"
  let matchEqns ← Match.getEquationsFor matcherDeclName
  -- splitterPre does not have the correct universe elimination level, but this is fine, we only use it to compute the `motiveType`,
  -- and we only care about the `motiveType` arguments, and not the resulting `Sort u`.
  let splitterPre := mkAppN (mkConst matchEqns.splitterName us.toList) params
  let motiveType := (← whnfForall (← inferType splitterPre)).bindingDomain!
  trace[split.debug] "applyMatchSplitter\n{mvarId}"
  let (discrFVarIds, discrEqs, mvarId) ← generalizeMatchDiscrs mvarId matcherDeclName motiveType discrs
  trace[split.debug] "after generalizeMatchDiscrs\n{mvarId}"
  let mvarId ← generalizeTargetsEq mvarId motiveType (discrFVarIds.map mkFVar)
  mvarId.withContext do trace[split.debug] "discrEqs after generalizeTargetsEq: {discrEqs.map mkFVar}"
  trace[split.debug] "after generalize\n{mvarId}"
  let numEqs := discrs.size
  let (discrFVarIdsNew, mvarId) ← mvarId.introN discrs.size
  trace[split.debug] "after introN\n{mvarId}"
  let discrsNew := discrFVarIdsNew.map mkFVar
  let mvarType ← mvarId.getType
  let elimUniv ← mvarId.withContext <| getLevel mvarType
  -- Patch the elimination universe of the splitter, or insist on `Prop` when it has none.
  let us ← if let some uElimPos := info.uElimPos? then
    pure <| us.set! uElimPos elimUniv
  else
    unless elimUniv.isZero do
      throwError "`split` tactic failed to split a match-expression: the splitter auxiliary theorem `{matchEqns.splitterName}` can only eliminate into `Prop`"
    pure us
  let splitter := mkAppN (mkConst matchEqns.splitterName us.toList) params
  mvarId.withContext do
    let motive ← mkLambdaFVars discrsNew mvarType
    let splitter := mkAppN (mkApp splitter motive) discrsNew
    check splitter
    trace[split.debug] "after check splitter"
    let mvarIds ← mvarId.apply splitter
    unless mvarIds.length == matchEqns.size do
      throwError "internal error in `split` tactic: unexpected number of goals created after applying splitter auxiliary theorem `{matchEqns.splitterName}` for `{matcherDeclName}`"
    -- `i` is the alternative index; it advances for every alternative, solved or not.
    let (_, goals) ← mvarIds.foldlM (init := ((0 : Nat), ([] : List (Nat × MVarId)))) fun (i, goals) mvarId => do
      let numParams := matchEqns.splitterAltNumParams[i]!
      let (_, mvarId) ← mvarId.introN numParams
      trace[split.debug] "before unifyEqs\n{mvarId}"
      match (← Cases.unifyEqs? (numEqs + info.getNumDiscrEqs) mvarId {}) with
      | none   => return (i+1, goals) -- case was solved
      | some (mvarId, fvarSubst) =>
        trace[split.debug] "after unifyEqs\n{mvarId}"
        let mvarId ← substDiscrEqs mvarId fvarSubst discrEqs
        -- Remember which alternative this goal belongs to.
        return (i+1, (i, mvarId)::goals)
    return goals.reverse

/--
Applies the splitter theorem associated with the matcher `matcherDeclName` to the goal `mvarId`.

- `us`, `params`: universe levels and parameters of the matcher application being split.
- `discrs`: its discriminants; they are generalized with `generalizeMatchDiscrs` first.

Returns the goals that remain, in alternative order. Alternatives whose goal is closed
while unifying the introduced equations are not included in the result.
Throws if `matcherDeclName` is not a matcher, if the splitter can only eliminate into `Prop`
but the goal is not a proposition, or if the splitter produces an unexpected number of goals.
-/
def applyMatchSplitter (mvarId : MVarId) (matcherDeclName : Name) (us : Array Level) (params : Array Expr) (discrs : Array Expr) : MetaM (List MVarId) := do
  let goals ← applyMatchSplitterCore mvarId matcherDeclName us params discrs
  return goals.map (·.2)

/-- Message reported when generalizing the discriminants of the `match`-expression `e` fails. -/
def mkDiscrGenErrorMsg (e : Expr) : MessageData :=
  m!"`split` tactic failed to generalize discriminant(s) at{indentExpr e}\nresulting expression was not type correct\npossible solution: generalize discriminant(s) manually before using `split`"

/-- Throws an error carrying `mkDiscrGenErrorMsg e`. -/
def throwDiscrGenError (e : Expr) : MetaM α :=
  throwError (mkDiscrGenErrorMsg e)

/--
Splits the goal `mvarId` on the matcher application `e`.

The splitter is applied first, and then each remaining goal is simplified with the
equation theorem of the alternative it originates from. The alternative index is taken
from `applyMatchSplitterCore`, so the right equation is used even when earlier
alternatives were closed and therefore do not appear among the resulting goals.
Throws if `e` is not a matcher application.
-/
def splitMatch (mvarId : MVarId) (e : Expr) : MetaM (List MVarId) := mvarId.withContext do
  let some app ← matchMatcherApp? e | throwError "internal error in `split` tactic: match application expected{indentExpr e}\nthis error typically occurs when the `split` tactic internal functions have been used in a new meta-program"
  let matchEqns ← Match.getEquationsFor app.matcherName
  let goals ← applyMatchSplitterCore mvarId app.matcherName app.matcherLevels app.params app.discrs
  goals.mapM fun (altIdx, mvarId) =>
    simpMatchTargetCore mvarId app.matcherName matchEqns.eqnNames[altIdx]!
end Split
open Split
/--
Tries to split the target of `mvarId` on an `if-then-else` or `match` term.

- `splitIte`: selects the argument passed to `findSplit?` (`.both` when `true`, `.match` otherwise).

Returns `some goals` on success and `none` when `findSplit?` finds no candidate.
When `splitMatch` fails on a candidate, the failure is only traced (`split.failure`) and the
search is retried with that candidate added to `badCases`.
-/
partial def splitTarget? (mvarId : MVarId) (splitIte := true) : MetaM (Option (List MVarId)) := commitWhenSome? do mvarId.withContext do
  let target ← instantiateMVars (← mvarId.getType)
  -- `badCases` collects the candidates on which `splitMatch` already failed; it is handed to
  -- `findSplit?` (presumably so that they are skipped — confirm against `findSplit?`).
  let rec go (badCases : ExprSet) : MetaM (Option (List MVarId)) := do
    if let some e ← findSplit? target (if splitIte then .both else .match) badCases then
      if e.isIte || e.isDIte then
        return (← splitIfTarget? mvarId).map fun (s₁, s₂) => [s₁.mvarId, s₂.mvarId]
      else
        try
          splitMatch mvarId e
        catch ex =>
          -- Do not fail here: record why this candidate could not be split and look for another one.
          if isDiscrGenException ex then
            trace[split.failure] mkDiscrGenErrorMsg e
          else
            trace[split.failure] "`split` tactic failed at{indentExpr e}\n{ex.toMessageData}"
          go (badCases.insert e)
    else
      trace[split.debug] "did not find term to split\n{MessageData.ofGoal mvarId}"
      return none
  go {}
/--
Tries to split the hypothesis `fvarId` of goal `mvarId` on an `if-then-else` or `match` term
found in its type.

Returns `none` when `findSplit?` finds no candidate in the hypothesis type.
For a `match` candidate, the hypothesis is first reverted and the resulting target is split.
If that fails with a discriminant-generalization exception and the hypothesis is a let-declaration,
occurs in the target, or has forward dependencies, a copy is asserted and the split is retried.
If generalization still fails (or the retry does not apply), the error from
`throwDiscrGenError` is thrown. Any other exception is propagated unchanged.
-/
def splitLocalDecl? (mvarId : MVarId) (fvarId : FVarId) : MetaM (Option (List MVarId)) := commitWhenSome? do
  mvarId.withContext do
    if let some e ← findSplit? (← instantiateMVars (← inferType (mkFVar fvarId))) then
      if e.isIte || e.isDIte then
        return (← splitIfLocalDecl? mvarId fvarId).map fun (mvarId₁, mvarId₂) => [mvarId₁, mvarId₂]
      else
        -- First attempt: revert the hypothesis (and whatever is reverted with it), split the
        -- target, and re-introduce the same number of hypotheses in every resulting goal.
        let result? ← commitWhenSome? do try
          let (fvarIds, mvarId) ← mvarId.revert #[fvarId]
          let num := fvarIds.size
          let mvarIds ← splitMatch mvarId e
          let mvarIds ← mvarIds.mapM fun mvarId => return (← mvarId.introNP num).2
          return some mvarIds
        catch ex =>
          -- Only generalization failures are recoverable here.
          if isDiscrGenException ex then
            return none
          else
            throw ex
        if result?.isSome then
          return result?
        -- Generalization failed, if `fvarId` is a let-decl or has forward dependencies, we try to `assert` a copy and try again
        let localDecl ← fvarId.getDecl
        if (← pure localDecl.isLet <||> exprDependsOn (← mvarId.getType) fvarId <||> fvarId.hasForwardDeps) then
          try
            let mvarId ← mvarId.assert localDecl.userName localDecl.type localDecl.toExpr
            let mvarIds ← splitMatch mvarId e
            let mvarIds ← mvarIds.mapM fun mvarId => return (← mvarId.intro1P).2
            return some mvarIds
          catch ex =>
            if isDiscrGenException ex then
              throwDiscrGenError e
            else
              throw ex
        -- Reached only when the retry above does not apply.
        throwDiscrGenError e
    else
      return none
-- Trace classes used by the `trace[split.debug]` and `trace[split.failure]` messages in this file.
builtin_initialize
  registerTraceClass `split.debug
  registerTraceClass `split.failure
end Lean.Meta