Skip to content

Commit ed55d22

Browse files
feat(maths): add LU decomposition algorithm using Doolittle method
1 parent 4b8099c commit ed55d22

2 files changed

Lines changed: 235 additions & 0 deletions

File tree

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
package com.thealgorithms.maths;
2+
3+
/**
4+
* @brief Implementation of LU Decomposition using the Doolittle algorithm
5+
* @details Decomposes a square matrix A into a lower triangular matrix L and
6+
* an upper triangular matrix U such that A = L * U. The diagonal of L contains
7+
* all ones (Doolittle convention). This decomposition is useful for solving
8+
* systems of linear equations, computing determinants, and finding inverses.
9+
* @see <a href="https://en.wikipedia.org/wiki/LU_decomposition">LU Decomposition</a>
10+
*/
11+
public final class LUDecomposition {
12+
13+
private LUDecomposition() {
14+
}
15+
16+
/**
17+
* @brief Performs LU decomposition on a square matrix using the Doolittle algorithm
18+
* @param matrix a square matrix
19+
* @return a 2D array where the lower triangle (excluding diagonal) contains L
20+
* elements (with implicit 1s on the diagonal) and the upper triangle
21+
* (including diagonal) contains U elements
22+
* @throws IllegalArgumentException if the matrix is not square
23+
* @throws ArithmeticException if a zero pivot is encountered
24+
*/
25+
public static double[][] decompose(double[][] matrix) {
26+
int n = matrix.length;
27+
for (double[] row : matrix) {
28+
if (row.length != n) {
29+
throw new IllegalArgumentException("Matrix must be square.");
30+
}
31+
}
32+
33+
double[][] lu = new double[n][n];
34+
for (int i = 0; i < n; i++) {
35+
for (int j = 0; j < n; j++) {
36+
lu[i][j] = matrix[i][j];
37+
}
38+
}
39+
40+
for (int k = 0; k < n; k++) {
41+
for (int j = k; j < n; j++) {
42+
double sum = 0;
43+
for (int s = 0; s < k; s++) {
44+
sum += lu[k][s] * lu[s][j];
45+
}
46+
lu[k][j] -= sum;
47+
}
48+
49+
if (lu[k][k] == 0) {
50+
throw new ArithmeticException("Zero pivot encountered. Matrix may be singular.");
51+
}
52+
53+
for (int i = k + 1; i < n; i++) {
54+
double sum = 0;
55+
for (int s = 0; s < k; s++) {
56+
sum += lu[i][s] * lu[s][k];
57+
}
58+
lu[i][k] = (lu[i][k] - sum) / lu[k][k];
59+
}
60+
}
61+
62+
return lu;
63+
}
64+
65+
/**
66+
* @brief Extracts the lower triangular matrix L from the combined LU matrix
67+
* @param lu the combined LU matrix from decompose()
68+
* @return the lower triangular matrix L with 1s on the diagonal
69+
*/
70+
public static double[][] getLowerMatrix(double[][] lu) {
71+
int n = lu.length;
72+
double[][] lower = new double[n][n];
73+
for (int i = 0; i < n; i++) {
74+
lower[i][i] = 1.0;
75+
for (int j = 0; j < i; j++) {
76+
lower[i][j] = lu[i][j];
77+
}
78+
}
79+
return lower;
80+
}
81+
82+
/**
83+
* @brief Extracts the upper triangular matrix U from the combined LU matrix
84+
* @param lu the combined LU matrix from decompose()
85+
* @return the upper triangular matrix U
86+
*/
87+
public static double[][] getUpperMatrix(double[][] lu) {
88+
int n = lu.length;
89+
double[][] upper = new double[n][n];
90+
for (int i = 0; i < n; i++) {
91+
for (int j = i; j < n; j++) {
92+
upper[i][j] = lu[i][j];
93+
}
94+
}
95+
return upper;
96+
}
97+
98+
/**
99+
* @brief Solves a system of linear equations Ax = b using LU decomposition
100+
* @param lu the combined LU matrix from decompose()
101+
* @param b the right-hand side vector
102+
* @return the solution vector x
103+
*/
104+
public static double[] solve(double[][] lu, double[] b) {
105+
int n = lu.length;
106+
double[] y = new double[n];
107+
double[] x = new double[n];
108+
109+
// Forward substitution: solve Ly = b
110+
for (int i = 0; i < n; i++) {
111+
double sum = 0;
112+
for (int j = 0; j < i; j++) {
113+
sum += lu[i][j] * y[j];
114+
}
115+
y[i] = b[i] - sum;
116+
}
117+
118+
// Back substitution: solve Ux = y
119+
for (int i = n - 1; i >= 0; i--) {
120+
double sum = 0;
121+
for (int j = i + 1; j < n; j++) {
122+
sum += lu[i][j] * x[j];
123+
}
124+
x[i] = (y[i] - sum) / lu[i][i];
125+
}
126+
127+
return x;
128+
}
129+
}
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
package com.thealgorithms.maths;
2+
3+
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
4+
import static org.junit.jupiter.api.Assertions.assertThrows;
5+
6+
import org.junit.jupiter.api.Test;
7+
8+
public class LUDecompositionTest {
9+
10+
private static final double DELTA = 1e-9;
11+
12+
@Test
13+
public void testDecomposeSimpleMatrix() {
14+
double[][] matrix = {{2, 1, 1}, {4, 3, 3}, {8, 7, 9}};
15+
double[][] lu = LUDecomposition.decompose(matrix);
16+
17+
double[][] lower = LUDecomposition.getLowerMatrix(lu);
18+
double[][] upper = LUDecomposition.getUpperMatrix(lu);
19+
20+
assertArrayEquals(new double[] {1, 1, 1}, new double[] {lower[0][0], lower[1][1], lower[2][2]}, DELTA);
21+
22+
double[][] product = multiply(lower, upper);
23+
assertArrayEquals(new double[] {2, 1, 1, 4, 3, 3, 8, 7, 9}, flatten(product), DELTA);
24+
}
25+
26+
@Test
27+
public void testDecomposeTwoByTwo() {
28+
double[][] matrix = {{1, 2}, {3, 4}};
29+
double[][] lu = LUDecomposition.decompose(matrix);
30+
31+
double[][] lower = LUDecomposition.getLowerMatrix(lu);
32+
double[][] upper = LUDecomposition.getUpperMatrix(lu);
33+
34+
double[][] product = multiply(lower, upper);
35+
assertArrayEquals(new double[] {1, 2, 3, 4}, flatten(product), DELTA);
36+
}
37+
38+
@Test
39+
public void testDecomposeIdentityMatrix() {
40+
double[][] matrix = {{1, 0}, {0, 1}};
41+
double[][] lu = LUDecomposition.decompose(matrix);
42+
43+
double[][] lower = LUDecomposition.getLowerMatrix(lu);
44+
double[][] upper = LUDecomposition.getUpperMatrix(lu);
45+
46+
assertArrayEquals(new double[] {1, 0, 0, 1}, flatten(lower), DELTA);
47+
assertArrayEquals(new double[] {1, 0, 0, 1}, flatten(upper), DELTA);
48+
}
49+
50+
@Test
51+
public void testDecomposeNonSquareMatrixThrows() {
52+
double[][] matrix = {{1, 2, 3}, {4, 5, 6}};
53+
assertThrows(IllegalArgumentException.class, () -> LUDecomposition.decompose(matrix));
54+
}
55+
56+
@Test
57+
public void testDecomposeSingularMatrixThrows() {
58+
double[][] matrix = {{0, 1}, {1, 0}};
59+
assertThrows(ArithmeticException.class, () -> LUDecomposition.decompose(matrix));
60+
}
61+
62+
@Test
63+
public void testSolveLinearSystem() {
64+
double[][] matrix = {{2, 1, 1}, {4, 3, 3}, {8, 7, 9}};
65+
double[] b = {8, 20, 46};
66+
double[][] lu = LUDecomposition.decompose(matrix);
67+
double[] solution = LUDecomposition.solve(lu, b);
68+
69+
assertArrayEquals(new double[] {1, 3, 3}, solution, DELTA);
70+
}
71+
72+
@Test
73+
public void testSolveTwoByTwoSystem() {
74+
double[][] matrix = {{2, 1}, {1, 3}};
75+
double[] b = {5, 7};
76+
double[][] lu = LUDecomposition.decompose(matrix);
77+
double[] solution = LUDecomposition.solve(lu, b);
78+
79+
assertArrayEquals(new double[] {1.6, 1.8}, solution, DELTA);
80+
}
81+
82+
private static double[][] multiply(double[][] a, double[][] b) {
83+
int n = a.length;
84+
double[][] result = new double[n][n];
85+
for (int i = 0; i < n; i++) {
86+
for (int j = 0; j < n; j++) {
87+
for (int k = 0; k < n; k++) {
88+
result[i][j] += a[i][k] * b[k][j];
89+
}
90+
}
91+
}
92+
return result;
93+
}
94+
95+
private static double[] flatten(double[][] matrix) {
96+
int n = matrix.length;
97+
double[] result = new double[n * n];
98+
int idx = 0;
99+
for (double[] row : matrix) {
100+
for (double val : row) {
101+
result[idx++] = val;
102+
}
103+
}
104+
return result;
105+
}
106+
}

0 commit comments

Comments
 (0)