Skip to content

Commit 8fddacd

Browse files
committed
fix missing BroadCast adjoint reshaping
1 parent bf97079 commit 8fddacd

1 file changed

Lines changed: 6 additions & 0 deletions

File tree

src/calculus/BroadCast.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,9 @@ function mul!(y::AbstractArray, A::AdjointOperator{<:OperatorBroadCast{T, N, M,
149149
R = A.A
150150
for idx in R.idxs
151151
b_slice = get_input_slice(R, idx, b)
152+
if size(b_slice) != size(R.A, 1)
153+
b_slice = reshape(b_slice, size(R.A, 1))
154+
end
152155
mul!(R.bufD, R.A', b_slice)
153156
if idx == first(R.idxs)
154157
y .= R.bufD
@@ -176,6 +179,9 @@ function _threaded_broadcast_adjoint!(y, R::OperatorBroadCast, b)
176179
idx_end = min(t * chunk, length(R.idxs))
177180
for i in idx_start:idx_end
178181
b_slice = get_input_slice(R, R.idxs[i], b)
182+
if size(b_slice) != size(R.A[t], 1)
183+
b_slice = reshape(b_slice, size(R.A[t], 1))
184+
end
179185
mul!(R.bufD[t], R.A[t]', b_slice)
180186
@lock lock y .+= R.bufD[t]
181187
end

0 commit comments

Comments
 (0)