Skip to content

Commit 644f46f

Browse files
committed
add fortran tests
1 parent 044afcb commit 644f46f

6 files changed

Lines changed: 104 additions & 66 deletions

File tree

.github/workflows/ci.yml

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,26 @@ on:
44
push:
55
paths:
66
- "**/*.py"
7+
- "**/*.f90"
78
pull_request:
89
paths:
910
- "**/*.py"
11+
- "**/*.f90"
1012

1113
jobs:
1214

15+
fortran:
16+
runs-on: ubuntu-latest
17+
steps:
18+
- uses: actions/checkout@v1
19+
- uses: actions/setup-python@v1
20+
with:
21+
python-version: '3.x'
22+
- run: sudo apt install -yq --no-install-recommends gfortran ninja-build
23+
- run: pip install meson
24+
- run: meson setup build
25+
- run: meson test -C build
26+
1327
linux:
1428
runs-on: ubuntu-latest
1529
steps:

airtools/logmart.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ def logmart(A: np.ndarray, b: np.ndarray,
5252
raise ValueError('b must be all non-negative')
5353

5454
b = b.copy() # needed to avoid modifying outside this function!
55+
# %% make sure there are no 0's in b
56+
b[b <= 1e-8] = 1e-8
5557
# %% set defaults
5658
if x0 is None: # backproject
5759
x = A.T @ b / A.sum()
@@ -62,8 +64,7 @@ def logmart(A: np.ndarray, b: np.ndarray,
6264
x = x0
6365
if not x.size == A.shape[1]:
6466
raise ValueError('x0 must be scalar or match Ncolumns of A')
65-
# %% make sure there are no 0's in b
66-
b[b <= 1e-8] = 1e-8
67+
6768
x[x < 1e-8] = 1e-8
6869
# W=sigma;
6970
# W=linspace(1,0,size(A,1))';

fortran/logmart.f90

Lines changed: 31 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
module art
22

3+
use iso_fortran_env, only: wp=>real64
4+
35
implicit none
46

57
contains
68

7-
pure subroutine logmart(A,b,relax,x0,sigma,max_iter,x)
9+
pure subroutine logmart(A,b,relax,x0,sigma,max_iter, x)
810
! delta Chisquare.
911
! stopped if Chisquare increases.
1012
!
@@ -30,45 +32,38 @@ pure subroutine logmart(A,b,relax,x0,sigma,max_iter,x)
3032
! x = [1,2,3]
3133
! b = A*x
3234

33-
use iso_fortran_env, only: wp=>real64
34-
35-
3635
! --- parameter check
3736
real(wp), intent(in) :: A(:,:), b(:)
38-
real(wp), optional, value :: relax
39-
real(wp), intent(in),optional :: x0(:), sigma(:)
37+
real(wp), optional, value :: relax, sigma
38+
real(wp), intent(in),optional :: x0(:)
4039
integer, optional, value :: max_iter
4140
real(wp), intent(out) :: x(:)
4241

43-
real(wp), dimension(size(b)) :: xA, op_sigma, W, op_b, arg,xold,c
42+
real(wp), dimension(size(b)) :: W(size(b)), x_prev,c, op_b
4443
integer :: i
45-
logical :: done
4644
real(wp) :: t,chi2,chiold
4745

48-
op_b = b
4946

5047
if (.not.size(A,1) == size(b)) error stop 'A and b row numbers must match'
48+
if (any(A<0)) error stop 'A must be non-negative'
49+
if (any(b<0)) error stop 'b must be non-negative'
50+
op_b = b
51+
! --- make sure there are no 0's in b
52+
where(op_b <= 1e-8) op_b = 1e-8_wp
5153

5254
! --- set defaults
53-
if (.not.present(relax)) relax = 1._wp
55+
if (.not.present(relax)) relax = 1
5456
if (.not.present(max_iter)) max_iter = 200
57+
if (.not.present(sigma)) sigma = 1
5558

5659
if (.not.present(x0)) then
57-
x = matmul(transpose(A), b) / sum(A)
58-
xA = matmul(A, x)
59-
x = x * maxval(b) / maxval(xA)
60+
x = matmul(transpose(A), op_b) / sum(A)
61+
x = x * maxval(op_b) / maxval(matmul(A, x))
6062
else
6163
x = x0
6264
endif
6365

64-
if (.not.present(sigma)) then
65-
op_sigma = 1._wp
66-
else
67-
op_sigma = sigma
68-
endif
6966

70-
! --- make sure there are no 0's in b
71-
where(op_b<=1e-8) op_b = 1e-8_wp
7267

7368
! W=sigma;
7469
! W=linspace(1,0,size(A,1))';
@@ -77,27 +72,28 @@ pure subroutine logmart(A,b,relax,x0,sigma,max_iter,x)
7772
W = W / sum(W)
7873

7974
! --- iterate solution
80-
i=0
81-
done=.false.
82-
arg= ((matmul(A,x) - op_b) / op_sigma)**2
83-
chi2 = sqrt(sum(arg))
84-
85-
do while (.not.done)
86-
i = i+1
87-
xold = x
88-
xA = matmul(A,x)
89-
t = minval(1/xA)
90-
C = relax*t*(1-(xA/b))
75+
chi2 = chi_squared(A, op_b, x, sigma)
76+
77+
do i = 1, max_iter
78+
x_prev = x
79+
t = minval(1/matmul(A,x))
80+
C = relax*t*(1-(matmul(A,x)/op_b))
9181
x = x / (1-x*matmul(transpose(A),W*C))
9282
! monitor solution
9383
chiold = chi2
94-
chi2 = sqrt( sum(((xA - b)/op_sigma)**2) )
95-
! dchi2=(chi2-chiold)
96-
done = ((chi2>chiold) .and. (i>2)) .or. (i==max_iter) .or. (chi2<0.7)
84+
chi2 = chi_squared(A, op_b, x, sigma)
85+
if (chi2 > chiold .and. i > 2) exit
9786
enddo
9887

99-
x = xold
88+
x = x_prev
10089

10190
end subroutine logmart
10291

92+
93+
pure real(wp) function chi_squared(A, b, x, sigma)
94+
real(wp), intent(in) :: A(:,:), b(:), x(:), sigma
95+
chi_squared = sqrt(sum(((matmul(A,x) - b) / sigma)**2))
96+
97+
end function chi_squared
98+
10399
end module art

fortran/random_utils.f90

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,15 @@ module random_utils
88

99
contains
1010

11-
subroutine randn(noise)
11+
impure elemental subroutine randn(noise)
1212
! implements Box-Muller Transform
1313
! https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform
1414
!
1515
! Output:
1616
! noise: Gaussian noise vector
1717

18-
real(wp),intent(out) :: noise(:)
19-
real(wp),dimension(size(noise)) :: u1, u2
18+
real(wp),intent(out) :: noise
19+
real(wp) :: u1, u2
2020

2121
call random_number(u1)
2222
call random_number(u2)

fortran/test_logmart.f90

Lines changed: 49 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -8,21 +8,51 @@
88

99
implicit none
1010

11-
integer, parameter :: N=3
12-
real(wp) :: A(N,N)
13-
real(wp), parameter :: x_true(N)=[5,5,5]
14-
real(wp), parameter :: errtol=0.05_wp
15-
real(wp), dimension(N) :: x, noise, b,bias
16-
logical :: add_bias, add_noise
17-
18-
add_bias = .false.
19-
add_noise = .true.
11+
integer, parameter :: N=4, M=4
12+
real(wp), dimension(N,M) :: A1, A2
13+
real(wp) :: x(M)
14+
logical :: ok=.true., add_noise = .true.
2015

2116
call init_random_seed()
2217

23-
A = reshape([1,0,0, &
24-
0,1,0, &
25-
0,0,1], shape(A), order=[2,1])
18+
A1 = reshape([5,0,0,0, &
19+
0,5,0,0, &
20+
0,0,5,0, &
21+
0,0,0,5], shape(A1), order=[2,1])
22+
23+
A2 = reshape([0,1,2,3, &
24+
1,0,1,2, &
25+
2,1,0,1, &
26+
3,2,1,0], shape(A2), order=[2,1])
27+
28+
x = [1._wp, 3._wp, 0.5_wp, 2._wp]
29+
30+
if (.not. run_test(A1, x, 20,add_noise)) then
31+
ok = .false.
32+
write(stderr,*) 'failed on identity test'
33+
endif
34+
35+
if (.not. run_test(A2, x, 2000, add_noise)) then
36+
ok = .false.
37+
write(stderr,*) 'failed on Fiedler test'
38+
endif
39+
40+
if (.not. ok) error stop
41+
42+
print *, 'OK: logmart'
43+
44+
contains
45+
46+
logical function run_test(A, x, max_iter, add_noise)
47+
48+
real(wp), intent(in) :: A(:,:), x(:)
49+
logical, intent(in) :: add_noise
50+
integer, intent(in) :: max_iter
51+
52+
real(wp), parameter :: errtol=0.05_wp
53+
real(wp), dimension(size(A,1)) :: noise, x_est, b
54+
55+
run_test = .true.
2656

2757
block
2858
integer :: i
@@ -33,30 +63,25 @@
3363
end block
3464

3565
! ---- noisy observation
36-
if (add_bias) then
37-
call randn(bias)
38-
bias = 0.01_wp * bias
39-
print '(/,A,3F10.3)','bias',bias
40-
A = A * spread(bias,2,N)
41-
endif
66+
b = matmul(A,x)
4267

4368
if (add_noise) then
4469
call randn(noise)
4570
noise = 0.01_wp * noise
4671
print '(/,A,3F10.3)', 'noise',noise
47-
b = matmul(A,x_true) + noise
72+
b = b + noise
4873
endif
4974

5075
! ---- inversion
51-
call logmart(A,b,x=x)
76+
call logmart(A,b, max_iter=max_iter, x=x_est)
5277

5378
! --- check estimate
54-
if (any(abs(x-x_true) > errtol*maxval(x_true))) then
55-
print *,x
79+
if (any(abs(x_est-x) > errtol*maxval(x))) then
80+
print *,x_est
5681
write (stderr,*) 'larger than',errtol*100,' % error'
57-
stop 1
82+
run_test = .false.
5883
endif
5984

60-
print '(/,A)','OK: logmart'
85+
end function run_test
6186

6287
end program

matlab/logmart.m

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@
2626
%% set defaults
2727
if (nargin<6), max_iter=200.; end
2828
if (nargin<5), sigma=1.; end
29+
%% make sure there are no 0's in y
30+
y(y<=1e-8)=1e-8;
31+
2932
if (nargin<4) || isempty(x0)
3033
x=(A'*y)./sum(A(:));
3134
xA=A*x;
@@ -40,8 +43,7 @@
4043
validateattributes(relax, {'numeric'}, {'scalar', 'positive'})
4144
validateattributes(sigma, {'numeric'}, {'scalar', 'positive'})
4245
validateattributes(max_iter, {'numeric'}, {'scalar', 'positive'})
43-
%% make sure there are no 0's in y
44-
y(y<=1e-8)=1e-8;
46+
4547

4648
% W=sigma;
4749
% W=linspace(1,0,size(A,1))';

0 commit comments

Comments
 (0)