Skip to content

Commit c53e762

Browse files
committed
add projection mooncake rules
1 parent ab9e777 commit c53e762

1 file changed

Lines changed: 45 additions & 0 deletions

File tree

ext/TensorKitMooncakeExt/linalg.jl

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,3 +86,48 @@ function Mooncake.rrule!!(::CoDual{typeof(inv)}, A_ΔA::CoDual{<:AbstractTensorM
8686

8787
return Ainv_ΔAinv, inv_pullback
8888
end
89+
90+
# single-output projections: project_hermitian!, project_antihermitian!
91+
for (f!, f, adj) in (
92+
(:project_hermitian!, :project_hermitian, :project_hermitian_adjoint),
93+
(:project_antihermitian!, :project_antihermitian, :project_antihermitian_adjoint),
94+
)
95+
@eval begin
96+
function Mooncake.rrule!!(f_df::CoDual{typeof($f!)}, A_dA::CoDual{<:AbstractTensorMap}, arg_darg::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm})
97+
A, dA = arrayify(A_dA)
98+
arg, darg = A_dA === arg_darg ? (A, dA) : arrayify(arg_darg)
99+
100+
# don't need to copy/restore A since projections don't mutate input
101+
argc = copy(arg)
102+
arg = $f!(A, arg, Mooncake.primal(alg_dalg))
103+
104+
function $adj(::NoRData)
105+
$f!(darg)
106+
if dA !== darg
107+
add!(dA, darg)
108+
MatrixAlgebraKit.zero!(darg)
109+
end
110+
copy!(arg, argc)
111+
return ntuple(Returns(NoRData()), 4)
112+
end
113+
114+
return arg_darg, $adj
115+
end
116+
117+
function Mooncake.rrule!!(f_df::CoDual{typeof($f)}, A_dA::CoDual{<:AbstractTensorMap}, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm})
118+
A, dA = arrayify(A_dA)
119+
output = $f(A, Mooncake.primal(alg_dalg))
120+
output_doutput = Mooncake.zero_fcodual(output)
121+
122+
doutput = last(arrayify(output_doutput))
123+
function $adj(::NoRData)
124+
# TODO: need accumulating projection to avoid intermediate here
125+
add!(dA, $f(doutput))
126+
MatrixAlgebraKit.zero!(doutput)
127+
return ntuple(Returns(NoRData()), 3)
128+
end
129+
130+
return output_doutput, $adj
131+
end
132+
end
133+
end

0 commit comments

Comments
 (0)