diff --git a/CompPoly/Univariate/NTT/Domain.lean b/CompPoly/Univariate/NTT/Domain.lean index f00e952e..59bd5842 100644 --- a/CompPoly/Univariate/NTT/Domain.lean +++ b/CompPoly/Univariate/NTT/Domain.lean @@ -59,11 +59,31 @@ def inverse (D : Domain R) : Domain R where section RawHelpers -variable [BEq R] [LawfulBEq R] +variable [BEq R] /-- Required convolution length for multiplying `p` and `q`. -/ def requiredLength (p q : CPolynomial.Raw R) : Nat := - p.trim.size + q.trim.size - 1 + if p.trim.size = 0 ∨ q.trim.size = 0 then + 0 + else + p.trim.size + q.trim.size - 1 + +@[simp] theorem requiredLength_eq_zero_of_left_trim_size_zero + (p q : CPolynomial.Raw R) (hp : p.trim.size = 0) : + requiredLength p q = 0 := by + simp [requiredLength, hp] + +@[simp] theorem requiredLength_eq_zero_of_right_trim_size_zero + (p q : CPolynomial.Raw R) (hq : q.trim.size = 0) : + requiredLength p q = 0 := by + simp [requiredLength, hq] + +theorem requiredLength_eq_of_trim_size_pos + (p q : CPolynomial.Raw R) (hp : 0 < p.trim.size) (hq : 0 < q.trim.size) : + requiredLength p q = p.trim.size + q.trim.size - 1 := by + have hp0 : p.trim.size ≠ 0 := Nat.ne_of_gt hp + have hq0 : q.trim.size ≠ 0 := Nat.ne_of_gt hq + simp [requiredLength, hp0, hq0] /-- Whether domain `D` is large enough for multiplying `p` and `q`. -/ def fits (D : Domain R) (p q : CPolynomial.Raw R) : Prop := diff --git a/CompPoly/Univariate/NTT/FastMul.lean b/CompPoly/Univariate/NTT/FastMul.lean index fe879bd3..1b0a92ac 100644 --- a/CompPoly/Univariate/NTT/FastMul.lean +++ b/CompPoly/Univariate/NTT/FastMul.lean @@ -206,8 +206,11 @@ private theorem coeff_truncate (m : Nat) (a : CPolynomial.Raw R) (i : Nat) : · simp [hi] private theorem mul_coeff_eq_zero_of_requiredLength_le - (p q : CPolynomial.Raw R) {i : Nat} (hi : Domain.requiredLength p q ≤ i) : + (p q : CPolynomial.Raw R) (hppos : 0 < p.trim.size) (hqpos : 0 < q.trim.size) + {i : Nat} (hi : Domain.requiredLength p q ≤ i) : (p * q).coeff i = 0 := by + have hreq : p.trim.size + q.trim.size - 1 ≤ i := by + simpa [Domain.requiredLength_eq_of_trim_size_pos p q hppos hqpos] using hi rw [CPolynomial.Raw.mul_coeff] apply Finset.sum_eq_zero intro x hx @@ -216,7 +219,6 @@ private theorem mul_coeff_eq_zero_of_requiredLength_le simp [hp0] · have hxlt : x < p.trim.size := Nat.lt_of_not_ge hpx have hqle : q.trim.size ≤ i - x := by - simp [Domain.requiredLength] at hi omega have hq0 : q.coeff (i - x) = 0 := coeff_zero_of_trim_size_le q hqle simp [hq0] @@ -239,20 +241,6 @@ private theorem mul_coeff_eq_zero_of_right_trim_size_zero have hq0 : q.coeff (i - x) = 0 := coeff_zero_of_trim_size_le q (by omega) simp [hq0] -private theorem natDegree_toPoly_lt_of_trim_size_le - (D : Domain R) (a : CPolynomial.Raw R) (ha : a.trim.size ≤ D.n) : - a.toPoly.natDegree < D.n := by - by_cases hzero : a.toPoly = 0 - · rw [hzero] - exact D.n_pos - · have hround := CPolynomial.Raw.toImpl_toPoly (R := R) a - have hsize : a.toPoly.toImpl.size = a.trim.size := congrArg Array.size hround - rcases CPolynomial.Raw.toImpl_elim a.toPoly with ⟨hz, _himpl⟩ | ⟨_hnz, himpl⟩ - · exact (hzero hz).elim - · have himpl_size : a.toPoly.toImpl.size = a.toPoly.natDegree + 1 := by - simp [himpl] - omega - private theorem natDegree_toPoly_lt_trim_size_of_pos (a : CPolynomial.Raw R) (ha : 0 < a.trim.size) : a.toPoly.natDegree < a.trim.size := by @@ -295,10 +283,7 @@ private theorem raw_eval_mul (x : R) (p q : CPolynomial.Raw R) : rw [← CPolynomial.Raw.eval_toPoly_eq_eval x (p * q)] rw [← CPolynomial.Raw.eval_toPoly_eq_eval x p] rw [← CPolynomial.Raw.eval_toPoly_eq_eval x q] - have hpoly : (p * q).toPoly = p.toPoly * q.toPoly := by - ext i - exact CPolynomial.Raw.toPoly_mul_coeff p q i - rw [hpoly] + rw [CPolynomial.Raw.toPoly_mul p q] simp private theorem pointwise_forwardSpec_eq_forwardSpec_mul_of_natDegree_lt @@ -314,7 +299,8 @@ private theorem pointwise_forwardSpec_eq_forwardSpec_mul_of_natDegree_lt let k : D.Idx := ⟨i, hiD⟩ have hpsize : i < (Forward.forwardSpec D p).size := by simpa [Forward.forwardSpec] using hiD have hqsize : i < (Forward.forwardSpec D q).size := by simpa [Forward.forwardSpec] using hiD - have hpqsize : i < (Forward.forwardSpec D (p * q)).size := by simpa [Forward.forwardSpec] using hiD + have hpqsize : i < (Forward.forwardSpec D (p * q)).size := by + simpa [Forward.forwardSpec] using hiD have hpget : (Forward.forwardSpec D p)[i]'hpsize = p.eval (D.node k) := by simpa [k] using forwardSpec_get_eq_eval_of_natDegree_lt D p hpdeg k have hqget : (Forward.forwardSpec D q)[i]'hqsize = q.eval (D.node k) := by @@ -324,71 +310,6 @@ private theorem pointwise_forwardSpec_eq_forwardSpec_mul_of_natDegree_lt simp [pointwiseMul] rw [hpget, hqget, hpqget, raw_eval_mul] -private theorem forwardSpec_getD_eq_zero_of_trim_size_zero - (D : Domain R) (p : CPolynomial.Raw R) (hp : p.trim.size = 0) (i : Nat) : - (Forward.forwardSpec D p).getD i 0 = 0 := by - by_cases hi : i < (Forward.forwardSpec D p).size - · rw [Array.getD_eq_getD_getElem?, Array.getElem?_eq_getElem hi] - simp only [Option.getD_some] - have hiD : i < D.n := by simpa [Forward.forwardSpec] using hi - let k : D.Idx := ⟨i, hiD⟩ - change (Forward.forwardSpec D p)[k.1] = 0 - simp [Forward.forwardSpec, Forward.nttAt] - apply Finset.sum_eq_zero - intro j _ - have hp0 : p.coeff j.1 = 0 := coeff_zero_of_trim_size_le p (by omega) - simp [CPolynomial.Raw.coeff] at hp0 - simp [hp0] - · rw [Array.getD_eq_getD_getElem?, Array.getElem?_eq_none (Nat.le_of_not_lt hi)] - simp - -omit [BEq R] [LawfulBEq R] in -private theorem inverseSpec_getD_eq_zero_of_getD_zero - (D : Domain R) (a : Array R) (ha : ∀ i : Nat, a.getD i 0 = 0) (i : Nat) : - (Inverse.inverseSpec D a).getD i 0 = 0 := by - by_cases hi : i < (Inverse.inverseSpec D a).size - · rw [Array.getD_eq_getD_getElem?, Array.getElem?_eq_getElem hi] - simp only [Option.getD_some] - simp only [Inverse.inverseSpec, Inverse.inttAt, Array.getElem_ofFn] - have hsum : (∑ j : D.Idx, a.getD j.1 0 * D.omegaInv ^ (i * j.1)) = 0 := by - apply Finset.sum_eq_zero - intro j _ - simp [ha j.1] - rw [hsum] - simp - · rw [Array.getD_eq_getD_getElem?, Array.getElem?_eq_none (Nat.le_of_not_lt hi)] - simp - -private theorem pointwise_getD_eq_zero_of_left_trim_size_zero - (D : Domain R) (p q : CPolynomial.Raw R) (hp : p.trim.size = 0) (i : Nat) : - (pointwiseMul D (Forward.forwardSpec D p) (Forward.forwardSpec D q)).getD i 0 = 0 := by - by_cases hi : i < (pointwiseMul D (Forward.forwardSpec D p) (Forward.forwardSpec D q)).size - · rw [Array.getD_eq_getD_getElem?, Array.getElem?_eq_getElem hi] - simp [pointwiseMul] - left - have hpsize : i < (Forward.forwardSpec D p).size := by - simpa [Forward.forwardSpec, pointwiseMul] using hi - have hpget := forwardSpec_getD_eq_zero_of_trim_size_zero D p hp i - rw [Array.getD_eq_getD_getElem?, Array.getElem?_eq_getElem hpsize] at hpget - simpa using hpget - · rw [Array.getD_eq_getD_getElem?, Array.getElem?_eq_none (Nat.le_of_not_lt hi)] - simp - -private theorem pointwise_getD_eq_zero_of_right_trim_size_zero - (D : Domain R) (p q : CPolynomial.Raw R) (hq : q.trim.size = 0) (i : Nat) : - (pointwiseMul D (Forward.forwardSpec D p) (Forward.forwardSpec D q)).getD i 0 = 0 := by - by_cases hi : i < (pointwiseMul D (Forward.forwardSpec D p) (Forward.forwardSpec D q)).size - · rw [Array.getD_eq_getD_getElem?, Array.getElem?_eq_getElem hi] - simp [pointwiseMul] - right - have hqsize : i < (Forward.forwardSpec D q).size := by - simpa [Forward.forwardSpec, pointwiseMul] using hi - have hqget := forwardSpec_getD_eq_zero_of_trim_size_zero D q hq i - rw [Array.getD_eq_getD_getElem?, Array.getElem?_eq_getElem hqsize] at hqget - simpa using hqget - · rw [Array.getD_eq_getD_getElem?, Array.getElem?_eq_none (Nat.le_of_not_lt hi)] - simp - /-- Spec pipeline for NTT-based multiplication. -/ @[inline] def fastMulSpec (D : Domain R) (p q : CPolynomial.Raw R) : CPolynomial.Raw R := let pHat := Forward.forwardSpec D p @@ -403,13 +324,8 @@ private theorem fastMulSpec_coeff_eq_zero_of_left_trim_size_zero rw [fastMulSpec] rw [CPolynomial.Raw.Trim.coeff_eq_coeff] rw [coeff_truncate] - by_cases hi : i < Domain.requiredLength p q - · rw [if_pos hi] - rw [CPolynomial.Raw.coeff] - exact inverseSpec_getD_eq_zero_of_getD_zero D - (pointwiseMul D (Forward.forwardSpec D p) (Forward.forwardSpec D q)) - (pointwise_getD_eq_zero_of_left_trim_size_zero D p q hp) i - · rw [if_neg hi] + rw [Domain.requiredLength_eq_zero_of_left_trim_size_zero p q hp] + simp private theorem fastMulSpec_coeff_eq_zero_of_right_trim_size_zero (D : Domain R) (p q : CPolynomial.Raw R) (hq : q.trim.size = 0) (i : Nat) : @@ -417,15 +333,15 @@ private theorem fastMulSpec_coeff_eq_zero_of_right_trim_size_zero rw [fastMulSpec] rw [CPolynomial.Raw.Trim.coeff_eq_coeff] rw [coeff_truncate] - by_cases hi : i < Domain.requiredLength p q - · rw [if_pos hi] - rw [CPolynomial.Raw.coeff] - exact inverseSpec_getD_eq_zero_of_getD_zero D - (pointwiseMul D (Forward.forwardSpec D p) (Forward.forwardSpec D q)) - (pointwise_getD_eq_zero_of_right_trim_size_zero D p q hq) i - · rw [if_neg hi] - -/-- Implementation pipeline for NTT-based multiplication. -/ + rw [Domain.requiredLength_eq_zero_of_right_trim_size_zero p q hq] + simp + +/-- +Implementation pipeline for NTT-based multiplication. + +Callers must provide a domain satisfying `Domain.fits D p q`; otherwise the +result is truncated to the domain-supported convolution length. +-/ @[inline] def fastMulImpl (D : Domain R) (p q : CPolynomial.Raw R) : CPolynomial.Raw R := let pHat := Forward.forwardImpl D p let qHat := Forward.forwardImpl D q @@ -452,7 +368,7 @@ theorem fastMulSpec_coeff (D : Domain R) (p q : CPolynomial.Raw R) have hfit' : Domain.requiredLength p q ≤ D.n := by simpa [Domain.fits] using hfit have hfitLen : p.trim.size + q.trim.size - 1 ≤ D.n := by - simpa [Domain.requiredLength] using hfit' + simpa [Domain.requiredLength_eq_of_trim_size_pos p q hppos hqpos] using hfit' have hpdeg_lt_trim := natDegree_toPoly_lt_trim_size_of_pos p hppos have hqdeg_lt_trim := natDegree_toPoly_lt_trim_size_of_pos q hqpos have hpdeg : p.toPoly.natDegree < D.n := by @@ -460,10 +376,7 @@ theorem fastMulSpec_coeff (D : Domain R) (p q : CPolynomial.Raw R) have hqdeg : q.toPoly.natDegree < D.n := by omega have hpqdeg : (p * q).toPoly.natDegree < D.n := by - have hpoly : (p * q).toPoly = p.toPoly * q.toPoly := by - ext j - exact CPolynomial.Raw.toPoly_mul_coeff p q j - rw [hpoly] + rw [CPolynomial.Raw.toPoly_mul p q] refine lt_of_le_of_lt Polynomial.natDegree_mul_le ?_ omega rw [fastMulSpec] @@ -478,7 +391,8 @@ theorem fastMulSpec_coeff (D : Domain R) (p q : CPolynomial.Raw R) rw [hpoint] exact inverse_forwardSpec_coeff_of_lt D (p * q) hiD · rw [if_neg hi] - exact (mul_coeff_eq_zero_of_requiredLength_le p q (Nat.le_of_not_lt hi)).symm + exact (mul_coeff_eq_zero_of_requiredLength_le p q hppos hqpos + (Nat.le_of_not_lt hi)).symm theorem fastMulSpec_eq_mul (D : Domain R) (p q : CPolynomial.Raw R) (hfit : Domain.fits D p q) : fastMulSpec D p q = p * q := by diff --git a/CompPoly/Univariate/NTT/Forward.lean b/CompPoly/Univariate/NTT/Forward.lean index 64718a57..8741d9c5 100644 --- a/CompPoly/Univariate/NTT/Forward.lean +++ b/CompPoly/Univariate/NTT/Forward.lean @@ -136,7 +136,8 @@ private def forwardMathPairsSpec simp [forwardStageSpec, bitRevPermute] | succ completed ih => rw [forwardStageSpec_succ] - exact size_butterflyStage D completed (forwardStageSpec D completed a) ih + rw [size_butterflyStage] + exact ih @[simp] theorem size_forwardStagePureSpec (D : Domain R) (completed : Nat) (a : Array R) : (forwardStagePureSpec D completed a).size = D.n := by @@ -145,7 +146,8 @@ private def forwardMathPairsSpec simp [forwardStagePureSpec, bitRevPermute] | succ completed ih => rw [forwardStagePureSpec_succ] - exact size_butterflyStageSpec D completed (forwardStagePureSpec D completed a) ih + rw [size_butterflyStageSpec] + exact ih /-- The algorithmic stage recursion agrees with the pure stage recursion. @@ -162,7 +164,6 @@ theorem forwardStageSpec_eq_forwardStagePureSpec (D : Domain R) (a : Array R) : | succ completed ih => rw [forwardStageSpec_succ, forwardStagePureSpec_succ, ih] rw [butterflyStage_eq_butterflyStageSpec D completed (forwardStagePureSpec D completed a)] - exact size_forwardStagePureSpec D completed a /-- Base case of the mathematical stage invariant: before any butterflies, the @@ -371,12 +372,6 @@ private theorem omega_pow_domain_half_eq_neg_one exact IsPrimitiveRoot.pow D.n_pos D.primitive hprod exact IsPrimitiveRoot.eq_neg_one_of_two_right hprim2 -private theorem omega_pow_stage_stride_eq_neg_one - (D : Domain R) (stage : Nat) (hstage : stage < D.logN) : - D.omega ^ (2 ^ stage * 2 ^ (D.logN - (stage + 1))) = -1 := by - rw [stage_stride_half_eq_domain_half D stage hstage] - exact omega_pow_domain_half_eq_neg_one D (by omega) - private theorem forwardMathPairsSpec_get_lower_current (D : Domain R) (stage block j : Nat) (a : Array R) (hj : j < 2 ^ stage) (hi : block * 2 ^ (stage + 1) + j < (forwardMathPairsSpec D stage block j a).size) : @@ -540,7 +535,7 @@ private theorem forwardMathValueAt_succ_upper ring private theorem eq_lower_or_upper_of_block_pair - (stage block j i : Nat) (_hj : j < 2 ^ stage) + (stage block j i : Nat) (hblock : i / 2 ^ (stage + 1) = block) (hpair : i % 2 ^ stage = j) : i = block * 2 ^ (stage + 1) + j ∨ i = block * 2 ^ (stage + 1) + j + 2 ^ stage := by @@ -583,7 +578,7 @@ private theorem eq_lower_or_upper_of_block_pair omega private theorem forwardMathPairsSpec_get_unchanged - (D : Domain R) (stage block j : Nat) (a : Array R) (hj : j < 2 ^ stage) + (D : Domain R) (stage block j : Nat) (a : Array R) {i : Nat} (hiOld : i < (forwardMathPairsSpec D stage block j a).size) (hiNew : i < (forwardMathPairsSpec D stage block (j + 1) a).size) @@ -603,7 +598,7 @@ private theorem forwardMathPairsSpec_get_unchanged rw [if_neg hltPair] by_cases hltPairNext : i % 2 ^ stage < j + 1 · have hpair : i % 2 ^ stage = j := by omega - rcases eq_lower_or_upper_of_block_pair stage block j i hj hEqBlock hpair with h | h + rcases eq_lower_or_upper_of_block_pair stage block j i hEqBlock hpair with h | h · exact (hneLower h.symm).elim · exact (hneUpper h.symm).elim · rw [if_neg hltPairNext] @@ -708,7 +703,7 @@ private theorem butterflyInnerStep_forwardMathPairsSpec_succ rw [forwardMathPairsSpec_get_lower_next D stage block donePairs a hdonePairs hi₂] exact (forwardMathValueAt_succ_lower D stage block donePairs a hstage hdonePairs).symm · rw [if_neg hLower] - exact forwardMathPairsSpec_get_unchanged D stage block donePairs a hdonePairs + exact forwardMathPairsSpec_get_unchanged D stage block donePairs a hi₁ hi₂ hLower hUpper · rw [pow_succ] diff --git a/CompPoly/Univariate/NTT/Transform.lean b/CompPoly/Univariate/NTT/Transform.lean index 9690b232..2493099e 100644 --- a/CompPoly/Univariate/NTT/Transform.lean +++ b/CompPoly/Univariate/NTT/Transform.lean @@ -132,9 +132,8 @@ specification. This is where the local `set!` bookkeeping for one stage belongs. -/ theorem butterflyStage_eq_butterflyStageSpec - (D : Domain R) (stage : Nat) (a : Array R) (ha : a.size = D.n) : + (D : Domain R) (stage : Nat) (a : Array R) : butterflyStage D stage a = butterflyStageSpec D stage a := by - have _ := ha let blockSize : Nat := 2 ^ (stage + 1) let half : Nat := 2 ^ stage let wm := D.omega ^ (D.n / blockSize) @@ -209,19 +208,14 @@ private theorem size_butterflyStageSpec_aux | n + 1, acc => by simp [size_butterflyStageSpec_aux blockSize half wm n acc] -theorem size_butterflyStageSpec (D : Domain R) (stage : Nat) (a : Array R) (ha : a.size = D.n) : - (butterflyStageSpec D stage a).size = D.n := by - let blockSize : Nat := 2 ^ (stage + 1) - let half : Nat := 2 ^ stage - let wm := D.omega ^ (D.n / blockSize) - rw [show (butterflyStageSpec D stage a).size = a.size by - simp [butterflyStageSpec, size_butterflyStageSpec_aux]] - exact ha - -theorem size_butterflyStage (D : Domain R) (stage : Nat) (a : Array R) (ha : a.size = D.n) : - (butterflyStage D stage a).size = D.n := by - rw [butterflyStage_eq_butterflyStageSpec D stage a ha] - exact size_butterflyStageSpec D stage a ha +theorem size_butterflyStageSpec (D : Domain R) (stage : Nat) (a : Array R) : + (butterflyStageSpec D stage a).size = a.size := by + simp [butterflyStageSpec, size_butterflyStageSpec_aux] + +theorem size_butterflyStage (D : Domain R) (stage : Nat) (a : Array R) : + (butterflyStage D stage a).size = a.size := by + rw [butterflyStage_eq_butterflyStageSpec D stage a] + exact size_butterflyStageSpec D stage a /-- Run all radix-2 butterfly stages (complexity: `O(n log n)`). -/ def runStages (D : Domain R) (a : Array R) : Array R := Id.run do diff --git a/CompPoly/Univariate/ToPoly/Equiv.lean b/CompPoly/Univariate/ToPoly/Equiv.lean index 61f8aa8a..1ebe5212 100644 --- a/CompPoly/Univariate/ToPoly/Equiv.lean +++ b/CompPoly/Univariate/ToPoly/Equiv.lean @@ -61,6 +61,12 @@ lemma Raw.toPoly_mul_coeff [LawfulBEq R] (p q : CPolynomial.Raw R) (i : ℕ) : rcases h_coeff (i - x) with ⟨_, hq⟩ simp [hp, hq] +@[grind =] +lemma Raw.toPoly_mul [LawfulBEq R] (p q : CPolynomial.Raw R) : + (p * q).toPoly = p.toPoly * q.toPoly := by + ext i + exact Raw.toPoly_mul_coeff p q i + @[grind =] lemma toPoly_mul_coeffC [LawfulBEq R] (p q : CPolynomial R) (i : ℕ) : (p.val * q.val).toPoly.coeff i = (p.val.toPoly * q.val.toPoly).coeff i := by diff --git a/tests/CompPolyTests/Univariate/NTT/FastMul.lean b/tests/CompPolyTests/Univariate/NTT/FastMul.lean index b4797b6f..542c9e6d 100644 --- a/tests/CompPolyTests/Univariate/NTT/FastMul.lean +++ b/tests/CompPolyTests/Univariate/NTT/FastMul.lean @@ -9,7 +9,7 @@ import CompPolyTests.Univariate.NTT.Common /-! # Univariate NTT FastMul Tests - Concrete executable checks for the temporary spec-backed NTT multiplication path. + Concrete executable checks for the iterative butterfly NTT multiplication path. -/ namespace CompPoly diff --git a/tests/CompPolyTests/Univariate/NTT/Forward.lean b/tests/CompPolyTests/Univariate/NTT/Forward.lean index 997e07c6..892e0ca2 100644 --- a/tests/CompPolyTests/Univariate/NTT/Forward.lean +++ b/tests/CompPolyTests/Univariate/NTT/Forward.lean @@ -9,7 +9,7 @@ import CompPolyTests.Univariate.NTT.Common /-! # Univariate NTT Forward Tests - Concrete executable checks for the temporary spec-backed forward NTT path. + Concrete executable checks for the iterative butterfly forward NTT path. -/ namespace CompPoly diff --git a/tests/CompPolyTests/Univariate/NTT/Inverse.lean b/tests/CompPolyTests/Univariate/NTT/Inverse.lean index 7d3d247e..07e62a3a 100644 --- a/tests/CompPolyTests/Univariate/NTT/Inverse.lean +++ b/tests/CompPolyTests/Univariate/NTT/Inverse.lean @@ -10,7 +10,7 @@ import CompPolyTests.Univariate.NTT.Common /-! # Univariate NTT Inverse Tests - Concrete executable checks for the temporary spec-backed inverse NTT path. + Concrete executable checks for the iterative butterfly inverse NTT path. -/ namespace CompPoly