Skip to content

Commit 1f99e1d

Browse files
committed
remove type instabilities
1 parent 8e946db commit 1f99e1d

8 files changed

Lines changed: 27 additions & 8 deletions

File tree

FFTWOperators/src/IRDFT.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,4 +82,7 @@ is_invertible(L::IRDFT) = true
8282
is_full_row_rank(L::IRDFT) = true
8383

8484
has_fast_opnorm(::IRDFT) = true
85-
LinearAlgebra.opnorm(L::IRDFT{T}) where {T} = sqrt(prod(L.dim_out) / 2)
85+
function LinearAlgebra.opnorm(L::IRDFT{T}) where {T}
86+
@assert length(L.dim_out) > 0
87+
return sqrt(prod(L.dim_out) / 2)
88+
end

NFFTOperators/src/NFFTOp.jl

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,19 @@ function create_plan(trajectory, image_size, threaded; kwargs...)
169169
end
170170
end
171171

172+
# Helper to create matched forward/backward FFT plans so that JET can track
173+
# that both plans have the same element type T and dimension D.
174+
struct _MatchedFFTPlans{T, D}
175+
forward::FFTW.cFFTWPlan{Complex{T}, -1, true, D, UnitRange{Int64}}
176+
backward::FFTW.cFFTWPlan{Complex{T}, 1, true, D, UnitRange{Int64}}
177+
end
178+
179+
function _make_matched_fft_plans(tmpVec::Array{Complex{T}, D}, dims_; kwargs...) where {T, D}
180+
FP = FFTW.plan_fft!(tmpVec, dims_; kwargs...)::FFTW.cFFTWPlan{Complex{T}, -1, true, D, UnitRange{Int64}}
181+
BP = FFTW.plan_bfft!(tmpVec, dims_; kwargs...)::FFTW.cFFTWPlan{Complex{T}, 1, true, D, UnitRange{Int64}}
182+
return _MatchedFFTPlans{T, D}(FP, BP)
183+
end
184+
172185
function NFFTPlan(
173186
k::Matrix{T},
174187
N::NTuple{D, Int};
@@ -187,8 +200,9 @@ function NFFTPlan(
187200
tmpVec = Array{Complex{T}, D}(undef, Ñ)
188201

189202
fftflags_ = (fftflags !== nothing) ? (flags = fftflags,) : NamedTuple()
190-
FP = FFTW.plan_fft!(tmpVec, dims_; num_threads = FFTW.get_num_threads(), fftflags_...)
191-
BP = FFTW.plan_bfft!(tmpVec, dims_; num_threads = FFTW.get_num_threads(), fftflags_...)
203+
plans = _make_matched_fft_plans(tmpVec, dims_; num_threads = FFTW.get_num_threads(), fftflags_...)
204+
FP = plans.forward
205+
BP = plans.backward
192206

193207
calcBlocks =
194208
(

src/calculus/Ax_mul_Bx.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ struct Ax_mul_Bx{
3333
bufB::C
3434
bufC::C
3535
bufD::D
36-
function Ax_mul_Bx(A::L1, B::L2, bufA::C, bufB::C, bufC::C, bufD::D) where {L1, L2, C, D}
36+
function Ax_mul_Bx(A::L1, B::L2, bufA::C, bufB::C, bufC::C, bufD::D) where {L1 <: AbstractOperator, L2 <: AbstractOperator, C <: AbstractArray, D <: AbstractArray}
3737
if ndims(A, 1) != 2 || size(A, 2) != size(B, 2) || size(A, 1)[2] != size(B, 1)[1]
3838
throw(DimensionMismatch("Cannot compose operators"))
3939
end

src/calculus/Ax_mul_Bxt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ struct Ax_mul_Bxt{
3333
bufB::C
3434
bufC::C
3535
bufD::D
36-
function Ax_mul_Bxt(A::L1, B::L2, bufA::C, bufB::C, bufC::C, bufD::D) where {L1, L2, C, D}
36+
function Ax_mul_Bxt(A::L1, B::L2, bufA::C, bufB::C, bufC::C, bufD::D) where {L1 <: AbstractOperator, L2 <: AbstractOperator, C <: AbstractArray, D <: AbstractArray}
3737
if ndims(A, 1) == 1
3838
if size(A) != size(B)
3939
throw(DimensionMismatch("Cannot compose operators"))

src/calculus/Axt_mul_Bx.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ struct Axt_mul_Bx{
3333
bufB::C
3434
bufC::C
3535
bufD::D
36-
function Axt_mul_Bx(A::L1, B::L2, bufA::C, bufB::C, bufC::C, bufD::D) where {L1, L2, C, D}
36+
function Axt_mul_Bx(A::L1, B::L2, bufA::C, bufB::C, bufC::C, bufD::D) where {L1 <: AbstractOperator, L2 <: AbstractOperator, C <: AbstractArray, D <: AbstractArray}
3737
if ndims(A, 1) == 1
3838
if size(A) != size(B)
3939
throw(DimensionMismatch("Cannot compose operators"))

src/calculus/Compose.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ struct Compose{N, M, L <: Tuple, T <: Tuple} <: AbstractOperator
3131
end
3232
# check for adjacent operators that can be combined
3333
i = 1
34+
local new_op::AbstractOperator = A[1] # placeholder; set below when should_be_combined is true
3435
while i < length(A)
3536
should_be_combined = false
3637
triple_combination = false
@@ -292,6 +293,7 @@ end
292293
# utils
293294
function permute(C::Compose, p::AbstractVector{Int})
294295
i = findfirst(x -> ndoms(x, 2) > 1, C.A)
296+
i === nothing && throw(ArgumentError("No operator with multiple domain dimensions found in Compose"))
295297
P = permute(C.A[i], p)
296298
AA = (C.A[1:(i - 1)]..., P, C.A[(i + 1):end]...)
297299
return Compose(AA, C.buf)

src/calculus/HCAT.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ function HCAT(A::Vararg{AbstractOperator})
7575
end
7676
end
7777
# use buffer from HCAT in A
78-
buf = A[findfirst((<:).(typeof.(A), HCAT))].buf
78+
buf = A[findfirst((<:).(typeof.(A), HCAT))::Int].buf
7979
else
8080
AA = A
8181
# generate buffer

src/calculus/VCAT.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ function VCAT(A::Vararg{AbstractOperator})
6565
end
6666
end
6767
# use buffer from VCAT in A
68-
buf = A[findfirst((<:).(typeof.(A), VCAT))].buf
68+
buf = A[findfirst((<:).(typeof.(A), VCAT))::Int].buf
6969
else
7070
AA = A
7171
# generate buffer

0 commit comments

Comments
 (0)