Skip to content

Commit 0ea3cf0

Browse files
authored
Merge pull request #17 from DurieuxPol/fix/qr
Fixed pivot in QR decomposition
2 parents 4762f4c + 14f332c commit 0ea3cf0

2 files changed

Lines changed: 138 additions & 130 deletions

File tree

src/Math-Matrix-Tests/PMQRTest.class.st

Lines changed: 57 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -28,39 +28,6 @@ PMQRTest >> assert: inverse isMoorePenroseInverseOf: aMatrix [
2828
closeTo: aMatrix mpInverse transpose.
2929
]
3030

31-
{ #category : 'tests' }
32-
PMQRTest >> testDecompositionOfMatrixCausingErraticFailure [
33-
34-
| a qrDecomposition matricesAndPivot q r expectedMatrix pivot |
35-
a := PMSymmetricMatrix rows:
36-
#( #( 0.41929313699681925 0.05975350554089691
37-
0.2771676258543356 0.35628773381760703 )
38-
#( 0.05975350554089691 0.12794227252152854
39-
0.3257742693302102 0.28814463284245906 )
40-
#( 0.2771676258543356 0.3257742693302102 0.8468441832097453
41-
0.9101872061892353 )
42-
#( 0.35628773381760703 0.28814463284245906
43-
0.9101872061892353 0.5163744224777326 ) ).
44-
45-
qrDecomposition := PMQRDecomposition of: a.
46-
matricesAndPivot := qrDecomposition decomposeWithPivot.
47-
48-
expectedMatrix := PMMatrix rows:
49-
#( #( 0.2771676258543356 0.35628773381760703
50-
0.41929313699681925 0.05975350554089691 )
51-
#( 0.3257742693302102 0.28814463284245906
52-
0.05975350554089691 0.12794227252152854 )
53-
#( 0.8468441832097453 0.9101872061892353
54-
0.2771676258543356 0.3257742693302102 )
55-
#( 0.9101872061892353 0.5163744224777326
56-
0.35628773381760703 0.28814463284245906 ) ).
57-
q := matricesAndPivot at: 1.
58-
r := matricesAndPivot at: 2.
59-
pivot := matricesAndPivot at: 3.
60-
self assert: q * r closeTo: expectedMatrix.
61-
self assert: pivot equals: #( 3 4 3 nil )
62-
]
63-
6431
{ #category : 'tests' }
6532
PMQRTest >> testHorizontalRectangularMatrixCannotBeDecomposed [
6633

@@ -151,6 +118,51 @@ PMQRTest >> testOrthogonalize [
151118
i < 10 ] whileTrue
152119
]
153120

121+
{ #category : 'tests' }
122+
PMQRTest >> testQRDecompositionOnRankDeficientMatrix [
123+
124+
| a qrDecomposition reconstruction |
125+
a := PMMatrix rows: {
126+
{ 1. 2. 3 }.
127+
{ 4. 5. 6 }.
128+
{ 7. 8. 9 } }.
129+
130+
qrDecomposition := PMQRDecomposition of: a.
131+
qrDecomposition decompose.
132+
133+
self assert: qrDecomposition q rank equals: a rank.
134+
self assert: qrDecomposition r rank equals: a rank.
135+
136+
reconstruction := qrDecomposition q * qrDecomposition r.
137+
self assert: reconstruction closeTo: a
138+
]
139+
140+
{ #category : 'tests' }
141+
PMQRTest >> testQRDecompositionWithPivotOnRankDeficientMatrix [
142+
143+
| a qrDecomposition expectedQR expectedPivot reconstruction |
144+
a := PMMatrix rows: {
145+
{ 1. 2. 3 }.
146+
{ 4. 5. 6 }.
147+
{ 7. 8. 9 } }.
148+
expectedQR := PMMatrix rows: {
149+
{ 3. 1. 2 }.
150+
{ 6. 4. 5 }.
151+
{ 9. 7. 8 } }.
152+
expectedPivot := #( 3 1 2 ).
153+
154+
qrDecomposition := PMQRDecomposition of: a.
155+
qrDecomposition decomposeWithPivot.
156+
157+
self assert: qrDecomposition q rank equals: a rank.
158+
self assert: qrDecomposition r rank equals: a rank.
159+
self assert: qrDecomposition q * qrDecomposition r closeTo: expectedQR.
160+
self assert: qrDecomposition pivot equals: expectedPivot.
161+
162+
reconstruction := qrDecomposition q * qrDecomposition r * qrDecomposition permutationMatrixFromPivot inverse.
163+
self assert: reconstruction closeTo: a
164+
]
165+
154166
{ #category : 'tests' }
155167
PMQRTest >> testQRFactorization [
156168

@@ -225,24 +237,25 @@ PMQRTest >> testSimpleQRDecomposition [
225237
{ #category : 'tests' }
226238
PMQRTest >> testSimpleQRDecompositionWithPivot [
227239

228-
| a qrDecomposition decomposition expected |
229-
a := PMMatrix rows: {
240+
| a qrDecomposition expectedQR expectedPivot reconstruction |
241+
a := PMMatrix rows: {
230242
{ 12. -51. 4 }.
231243
{ 6. 167. -68 }.
232244
{ -4. 24. -41 } }.
245+
expectedQR := PMMatrix rows: {
246+
{ -51. 4. 12 }.
247+
{ 167. -68. 6 }.
248+
{ 24. -41. -4 } }.
249+
expectedPivot := #( 2 3 1 ).
233250

234251
qrDecomposition := PMQRDecomposition of: a.
252+
qrDecomposition decomposeWithPivot.
235253

236-
decomposition := qrDecomposition decomposeWithPivot.
237-
decomposition first * decomposition second.
254+
self assert: qrDecomposition q * qrDecomposition r closeTo: expectedQR.
255+
self assert: qrDecomposition pivot equals: expectedPivot.
238256

239-
expected := PMMatrix rows: {
240-
{ -51. 4. 12 }.
241-
{ 167. -68. 6 }.
242-
{ 24. -41. -4 } }.
243-
self
244-
assert: decomposition first * decomposition second
245-
closeTo: expected
257+
reconstruction := qrDecomposition q * qrDecomposition r * qrDecomposition permutationMatrixFromPivot inverse.
258+
self assert: reconstruction closeTo: a
246259
]
247260

248261
{ #category : 'tests' }

src/Math-Matrix/PMQRDecomposition.class.st

Lines changed: 81 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ Class {
99
'colSize',
1010
'r',
1111
'q',
12-
'comparisonPrecision'
12+
'pivot'
1313
],
1414
#category : 'Math-Matrix',
1515
#package : 'Math-Matrix'
@@ -22,18 +22,6 @@ PMQRDecomposition class >> of: matrix [
2222
^ self new of: matrix
2323
]
2424

25-
{ #category : 'constants' }
26-
PMQRDecomposition >> comparisonPrecision [
27-
28-
^ comparisonPrecision ifNil: [ self defaultComparisonPrecision ]
29-
]
30-
31-
{ #category : 'accessing' }
32-
PMQRDecomposition >> comparisonPrecision: anObject [
33-
34-
comparisonPrecision := anObject
35-
]
36-
3725
{ #category : 'arithmetic' }
3826
PMQRDecomposition >> decompose [
3927
"
@@ -73,91 +61,70 @@ https://en.wikipedia.org/wiki/QR_decomposition#Using_Householder_reflections
7361

7462
{ #category : 'arithmetic' }
7563
PMQRDecomposition >> decomposeWithPivot [
64+
"Variant of the decompose method that introduces a pivot. At the beginning of each step it takes the largest remaining column, thus introducing a pivot. For more info, look at https://en.wikipedia.org/wiki/QR_decomposition#Column_pivoting.
65+
Here the pivot is stored as an array containing the new order of the columns of the input matrix. It can be used to generate the proper permutation matrix with the permutationMatrixFromPivot method"
7666

77-
| i vectorOfNormSquareds rank positionOfMaximum pivot matrixOfMinor |
78-
vectorOfNormSquareds := matrixToDecompose columnsCollect: [
79-
:columnVector | columnVector * columnVector ].
67+
| i vectorOfNormSquareds rank positionOfMaximum matrixOfMinor |
68+
vectorOfNormSquareds := matrixToDecompose columnsCollect: [ :columnVector | columnVector * columnVector ].
8069
positionOfMaximum := vectorOfNormSquareds indexOf: vectorOfNormSquareds max.
81-
pivot := Array new: vectorOfNormSquareds size.
70+
pivot := (1 to: vectorOfNormSquareds size) asArray.
8271
rank := 0.
83-
[
84-
| householderReflection householderMatrix householderVector columnVectorFromRMatrix |
85-
rank := rank + 1.
86-
pivot at: rank put: positionOfMaximum.
87-
r swapColumn: rank withColumn: positionOfMaximum.
88-
vectorOfNormSquareds swap: rank with: positionOfMaximum.
89-
columnVectorFromRMatrix := r columnVectorAt: rank size: colSize.
90-
householderReflection := self
91-
householderReflectionOf:
92-
columnVectorFromRMatrix
93-
atColumnNumber: rank.
94-
householderVector := householderReflection at: 1.
95-
householderMatrix := householderReflection at: 2.
96-
q := q * householderMatrix.
97-
matrixOfMinor := r minor: rank - 1 and: rank - 1.
98-
matrixOfMinor := matrixOfMinor
99-
- ((householderVector at: 2) tensorProduct:
100-
(householderVector at: 1)
101-
* (householderVector at: 2) * matrixOfMinor).
102-
matrixOfMinor rowsWithIndexDo: [ :aRow :index |
103-
aRow withIndexDo: [ :element :column |
104-
| rowNumber columnNumber |
105-
rowNumber := rank + index - 1.
106-
columnNumber := rank + column - 1.
107-
r
108-
rowAt: rowNumber
109-
columnAt: columnNumber
110-
put: ((element closeTo: 0)
111-
ifTrue: [ 0 ]
112-
ifFalse: [ element ]) ] ].
113-
rank + 1 to: vectorOfNormSquareds size do: [ :ind |
114-
vectorOfNormSquareds
115-
at: ind
116-
put:
117-
(vectorOfNormSquareds at: ind)
118-
- (r rowAt: rank columnAt: ind) squared ].
119-
rank < vectorOfNormSquareds size
120-
ifTrue: [
121-
positionOfMaximum := (vectorOfNormSquareds
122-
copyFrom: rank + 1
123-
to: vectorOfNormSquareds size) max.
124-
(positionOfMaximum closeTo: 0 precision: self comparisonPrecision) ifTrue: [ positionOfMaximum := 0 ].
125-
positionOfMaximum := positionOfMaximum > 0
126-
ifTrue: [
127-
vectorOfNormSquareds indexOf: positionOfMaximum startingAt: rank + 1 ]
128-
ifFalse: [ 0 ] ]
129-
ifFalse: [ positionOfMaximum := 0 ].
130-
positionOfMaximum > 0 ] whileTrue.
72+
[
73+
| temp householderReflection householderMatrix householderVector columnVectorFromRMatrix |
74+
rank := rank + 1.
75+
temp := pivot at: rank.
76+
pivot at: rank put: (pivot at: positionOfMaximum).
77+
pivot at: positionOfMaximum put: temp.
78+
79+
r swapColumn: rank withColumn: positionOfMaximum.
80+
vectorOfNormSquareds swap: rank with: positionOfMaximum.
81+
columnVectorFromRMatrix := r columnVectorAt: rank size: colSize.
82+
householderReflection := self householderReflectionOf: columnVectorFromRMatrix atColumnNumber: rank.
83+
householderVector := householderReflection first.
84+
householderMatrix := householderReflection second.
85+
q := q * householderMatrix.
86+
matrixOfMinor := r minor: rank - 1 and: rank - 1.
87+
matrixOfMinor := matrixOfMinor
88+
- (householderVector second tensorProduct: householderVector first * householderVector second * matrixOfMinor).
89+
matrixOfMinor rowsWithIndexDo: [ :aRow :index |
90+
aRow withIndexDo: [ :element :column |
91+
| rowNumber columnNumber |
92+
rowNumber := rank + index - 1.
93+
columnNumber := rank + column - 1.
94+
r rowAt: rowNumber columnAt: columnNumber put: ((element closeTo: 0)
95+
ifTrue: [ 0 ]
96+
ifFalse: [ element ]) ] ].
97+
rank + 1 to: vectorOfNormSquareds size do: [ :ind |
98+
vectorOfNormSquareds at: ind put: (vectorOfNormSquareds at: ind) - (r rowAt: rank columnAt: ind) squared ].
99+
rank < vectorOfNormSquareds size
100+
ifTrue: [
101+
positionOfMaximum := (vectorOfNormSquareds copyFrom: rank + 1 to: vectorOfNormSquareds size) max.
102+
(positionOfMaximum closeTo: 0) ifTrue: [ positionOfMaximum := 0 ].
103+
positionOfMaximum := positionOfMaximum > 0
104+
ifTrue: [ vectorOfNormSquareds indexOf: positionOfMaximum startingAt: rank + 1 ]
105+
ifFalse: [ 0 ] ]
106+
ifFalse: [ positionOfMaximum := 0 ].
107+
positionOfMaximum > 0 ] whileTrue.
131108
i := 0.
132-
[ (r rowAt: colSize) isZero ] whileTrue: [
133-
i := i + 1.
134-
colSize := colSize - 1 ].
135-
i > 0 ifTrue: [
136-
r := self upperTriangularPartOf: r With: colSize.
137-
i := q numberOfColumns - i.
138-
pivot := pivot copyFrom: 1 to: i.
139-
q := PMMatrix rows:
140-
(q rowsCollect: [ :row | row copyFrom: 1 to: i ]) ].
109+
[ (r rowAt: colSize) isZero ] whileTrue: [
110+
i := i + 1.
111+
colSize := colSize - 1 ].
112+
i > 0 ifTrue: [
113+
r := self upperTriangularPartOf: r With: colSize.
114+
i := q numberOfColumns - i.
115+
q := PMMatrix rows: (q rowsCollect: [ :row | row copyFrom: 1 to: i ]) ].
141116
^ Array with: q with: r with: pivot
142117
]
143118

144-
{ #category : 'constants' }
145-
PMQRDecomposition >> defaultComparisonPrecision [
146-
147-
^ 0.0001
148-
]
149-
150119
{ #category : 'private' }
151120
PMQRDecomposition >> householderReflectionOf: columnVector atColumnNumber: columnNumber [
152121

153122
| householderVector v identityMatrix householderMatrix |
154123
householderVector := columnVector householder.
155124
v := (PMVector zeros: columnNumber - 1) , (householderVector at: 2).
156125
identityMatrix := PMSymmetricMatrix identity: colSize.
157-
householderMatrix := identityMatrix
158-
-
159-
((householderVector at: 1) * v tensorProduct: v).
160-
^ Array with: householderVector with: householderMatrix .
126+
householderMatrix := identityMatrix - (householderVector first * v tensorProduct: v).
127+
^ Array with: householderVector with: householderMatrix
161128
]
162129

163130
{ #category : 'private' }
@@ -183,8 +150,36 @@ PMQRDecomposition >> of: matrix [
183150

184151
matrixToDecompose := matrix.
185152
colSize := matrixToDecompose numberOfRows.
186-
r := self initialRMatrix.
187-
q := self initialQMatrix.
153+
r := self initialRMatrix.
154+
q := self initialQMatrix
155+
]
156+
157+
{ #category : 'accessing' }
158+
PMQRDecomposition >> permutationMatrixFromPivot [
159+
160+
| matrix |
161+
matrix := PMMatrix zerosRows: matrixToDecompose numberOfRows cols: matrixToDecompose numberOfColumns.
162+
pivot withIndexCollect: [ :column :index | matrix at: column at: index put: 1 ].
163+
164+
^ matrix
165+
]
166+
167+
{ #category : 'accessing' }
168+
PMQRDecomposition >> pivot [
169+
170+
^ pivot
171+
]
172+
173+
{ #category : 'accessing' }
174+
PMQRDecomposition >> q [
175+
176+
^ q
177+
]
178+
179+
{ #category : 'accessing' }
180+
PMQRDecomposition >> r [
181+
182+
^ r
188183
]
189184

190185
{ #category : 'private' }

0 commit comments

Comments
 (0)