Skip to content

Commit f677d84

Browse files
committed
fix HCAT+Sum domain problem
1 parent 884bfb6 commit f677d84

2 files changed

Lines changed: 23 additions & 2 deletions

File tree

src/calculus/HCAT.jl

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,23 @@ function HCAT(A::Vararg{AbstractOperator})
8585
return HCAT(AA, buf)
8686
end
8787

88-
# compile-time domain ndoms for HCAT's sub-operators
89-
_ndoms_from_type(::Type{<:HCAT{N}}, dim::Int) where {N} = dim == 2 ? N : 1
88+
# Count actual domain slots from an HCAT's P (idxs) type.
89+
# Each entry in P is either an Int (1 slot) or a NTuple{n,Int} (n slots).
90+
_count_hcat_ndoms(::Type{<:Tuple{}}) = 0
91+
@generated function _count_hcat_ndoms(::Type{P}) where {P <: Tuple}
92+
K = 0
93+
for i in 1:fieldcount(P)
94+
Pi = fieldtype(P, i)
95+
K += Pi <: Integer ? 1 : fieldcount(Pi)
96+
end
97+
return :($K)
98+
end
99+
100+
# compile-time domain ndoms for HCAT's sub-operators:
101+
# use the index-tuple type P (not N which only counts sub-operators) so that
102+
# sub-operators with multi-component domains are accounted for correctly.
103+
_ndoms_from_type(::Type{<:HCAT{N, L, P}}, dim::Int) where {N, L, P} =
104+
dim == 2 ? _count_hcat_ndoms(P) : 1
90105

91106
@generated function HCAT(AA::NTuple{N, AbstractOperator}, buf::C) where {N, C}
92107
N == 1 && return :(AA[1])

src/calculus/Sum.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,12 @@ codomain_storage_type(S::Sum{K, C, D, L}) where {K, C <: Tuple, D, L} = codomain
162162
fun_domain(S::Sum) = fun_domain(S.A[1])
163163
fun_codomain(S::Sum) = fun_codomain(S.A[1])
164164

165+
# A Sum of multi-domain operators (e.g. Sum of HCATs) has the same domain arity
166+
# as its first constituent. This extends the HCAT machinery so that HCAT can
167+
# correctly assign indices when a Sum appears as one of its sub-operators.
168+
_ndoms_from_type(::Type{<:Sum{K, C, D, L}}, dim::Int) where {K, C, D, L} =
169+
_ndoms_from_type(fieldtype(L, 1), dim)
170+
165171
fun_name(S::Sum) = length(S.A) == 2 ? fun_name(S.A[1]) * "+" * fun_name(S.A[2]) : "Σ"
166172

167173
is_linear(L::Sum) = all(is_linear.(L.A))

0 commit comments

Comments
 (0)