Skip to content

Commit 35e31ae

Browse files
committed
Optimize cython variant
1 parent 3178f69 commit 35e31ae

3 files changed

Lines changed: 107 additions & 39 deletions

File tree

.github/workflows/test_cython.yaml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,5 +30,4 @@ jobs:
3030
--executable='uv run main.py' \
3131
--cwd='../cython' \
3232
--strictness=4 \
33-
--shuffle=42 \
34-
--filter='o:{"lines": "0"}'
33+
--shuffle=42

cython/calculate.pyx

Lines changed: 105 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,12 @@
55
# cython: initializedcheck=False
66
# cython: cdivision=True
77

8-
98
from time import time
10-
from itertools import count
9+
10+
from libc.math cimport fabs
11+
12+
import numpy as np
13+
cimport numpy as np
1114

1215
from partdiff_common.parse_args import (
1316
Options,
@@ -20,45 +23,111 @@ from partdiff_common import (
2023
CalculationResults,
2124
)
2225

26+
cdef int METH_GAUSS_SEIDEL = 1
27+
cdef int METH_JACOBI = 2
28+
cdef int TERM_ACC = 1
29+
cdef int TERM_ITER = 2
2330

24-
def calculate(arguments: CalculationArguments, options: Options) -> CalculationResults:
25-
start_time = time()
26-
n = arguments.n
27-
tensor = arguments.tensor
28-
perturbation_matrix = arguments.perturbation_matrix
29-
stat_iteration = 0
30-
stat_accuracy = None
31-
matrix_out = tensor[0, :, :]
32-
matrix_in = matrix_out
33-
if options.method == CalculationMethod.JACOBI:
34-
matrix_in = tensor[1, :, :]
35-
for stat_iteration in count(start=1):
31+
cdef inline double tensor_get(double* tensor, int N, int m, int i, int j) noexcept nogil:
32+
return tensor[m * (N+1) * (N+1) + i * (N+1) + j]
33+
34+
cdef inline void tensor_set(double* tensor, int N, int m, int i, int j, double value) noexcept nogil:
35+
tensor[m * (N+1) * (N+1) + i * (N+1) + j] = value
36+
37+
cdef inline double matrix_get(double* matrix, int N, int i, int j) noexcept nogil:
38+
return matrix[i * (N+1) + j]
39+
40+
cdef tuple calculate_inner(
41+
int N,
42+
int term_iteration,
43+
double* tensor,
44+
double* perturbation_matrix,
45+
int method,
46+
int termination,
47+
double term_accuracy
48+
):
49+
cdef int m1, m2
50+
cdef int i, j
51+
cdef double star, residuum, maxresiduum
52+
cdef int stat_iteration = 0
53+
cdef double stat_accuracy = 0.0
54+
cdef int temp
55+
if method == METH_JACOBI:
56+
m1 = 0
57+
m2 = 1
58+
else:
59+
m1 = 0
60+
m2 = 0
61+
while term_iteration > 0:
3662
maxresiduum = 0.0
37-
for i in range(1, n):
38-
for j in range(1, n):
63+
for i in range(1, N):
64+
for j in range(1, N):
3965
star = 0.25 * (
40-
matrix_in[i - 1, j]
41-
+ matrix_in[i, j - 1]
42-
+ matrix_in[i, j + 1]
43-
+ matrix_in[i + 1, j]
66+
tensor_get(tensor, N, m2, i-1, j) +
67+
tensor_get(tensor, N, m2, i, j-1) +
68+
tensor_get(tensor, N, m2, i, j+1) +
69+
tensor_get(tensor, N, m2, i+1, j)
4470
)
45-
star += perturbation_matrix[i, j]
46-
if (
47-
options.termination == TerminationCondition.ACCURACY
48-
or stat_iteration == options.term_iteration
49-
):
50-
residuum = abs(matrix_in[i, j] - star)
51-
maxresiduum = max(maxresiduum, residuum)
52-
matrix_out[i, j] = star
71+
star += matrix_get(perturbation_matrix, N, i, j)
72+
if termination == TERM_ACC or term_iteration == 1:
73+
residuum = tensor_get(tensor, N, m2, i, j) - star
74+
residuum = fabs(residuum)
75+
if residuum > maxresiduum:
76+
maxresiduum = residuum
77+
tensor_set(tensor, N, m1, i, j, star)
78+
stat_iteration += 1
5379
stat_accuracy = maxresiduum
54-
matrix_in, matrix_out = matrix_out, matrix_in
55-
if options.termination == TerminationCondition.ACCURACY:
56-
if maxresiduum < options.term_accuracy:
57-
break
58-
else:
59-
if stat_iteration == options.term_iteration:
60-
break
80+
temp = m1
81+
m1 = m2
82+
m2 = temp
83+
if termination == TERM_ACC:
84+
if maxresiduum < term_accuracy:
85+
term_iteration = 0
86+
elif termination == TERM_ITER:
87+
term_iteration -= 1
88+
return m2, stat_iteration, stat_accuracy
89+
90+
def calculate_np(
91+
int N,
92+
int term_iteration,
93+
np.ndarray[np.float64_t, ndim=3, mode="c"] tensor,
94+
np.ndarray[np.float64_t, ndim=2, mode="c"] perturbation_matrix,
95+
int method,
96+
int termination,
97+
double term_accuracy
98+
):
99+
if not (1 <= tensor.shape[0] <= 2) or tensor.shape[1] != N+1 or tensor.shape[2] != N+1:
100+
raise ValueError("tensor must have shape (2, N+1, N+1)")
101+
if perturbation_matrix.shape[0] != N+1 or perturbation_matrix.shape[1] != N+1:
102+
raise ValueError("perturbation_matrix must have shape (N+1, N+1)")
103+
cdef double* tensor_ptr = <double*> tensor.data
104+
cdef double* matrix_ptr = <double*> perturbation_matrix.data
105+
cdef int m
106+
cdef int stat_iteration
107+
cdef double stat_accuracy
108+
m, stat_iteration, stat_accuracy = calculate_inner(
109+
N,
110+
term_iteration,
111+
tensor_ptr,
112+
matrix_ptr,
113+
method,
114+
termination,
115+
term_accuracy
116+
)
117+
return m, stat_iteration, stat_accuracy
118+
119+
def calculate(arguments: CalculationArguments, options: Options) -> CalculationResults:
120+
start_time = time()
121+
m, stat_iteration, stat_accuracy = calculate_np(
122+
arguments.n,
123+
options.term_iteration,
124+
arguments.tensor,
125+
arguments.perturbation_matrix,
126+
options.method.value,
127+
options.termination.value,
128+
options.term_accuracy,
129+
)
61130
end_time = time()
62131
duration = end_time - start_time
63-
final_matrix = matrix_in
132+
final_matrix = arguments.tensor[m, :, :]
64133
return CalculationResults(final_matrix, stat_iteration, stat_accuracy, duration)

cython/setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
"calculate",
88
["calculate.pyx"],
99
include_dirs=[np.get_include()],
10-
extra_compile_args=["-O3", "-march=native", "-ffast-math"],
10+
extra_compile_args=["-O3", "-march=native"],
1111
)
1212
]
1313

0 commit comments

Comments
 (0)