@@ -86,3 +86,48 @@ function Mooncake.rrule!!(::CoDual{typeof(inv)}, A_ΔA::CoDual{<:AbstractTensorM
8686
8787 return Ainv_ΔAinv, inv_pullback
8888end
89+
90+ # single-output projections: project_hermitian!, project_antihermitian!
91+ for (f!, f, adj) in (
92+ (:project_hermitian! , :project_hermitian , :project_hermitian_adjoint ),
93+ (:project_antihermitian! , :project_antihermitian , :project_antihermitian_adjoint ),
94+ )
95+ @eval begin
96+ function Mooncake. rrule!! (f_df:: CoDual{typeof($f!)} , A_dA:: CoDual{<:AbstractTensorMap} , arg_darg:: CoDual , alg_dalg:: CoDual{<:MatrixAlgebraKit.AbstractAlgorithm} )
97+ A, dA = arrayify (A_dA)
98+ arg, darg = A_dA === arg_darg ? (A, dA) : arrayify (arg_darg)
99+
100+ # don't need to copy/restore A since projections don't mutate input
101+ argc = copy (arg)
102+ arg = $ f! (A, arg, Mooncake. primal (alg_dalg))
103+
104+ function $adj (:: NoRData )
105+ $ f! (darg)
106+ if dA != = darg
107+ add! (dA, darg)
108+ MatrixAlgebraKit. zero! (darg)
109+ end
110+ copy! (arg, argc)
111+ return ntuple (Returns (NoRData ()), 4 )
112+ end
113+
114+ return arg_darg, $ adj
115+ end
116+
117+ function Mooncake. rrule!! (f_df:: CoDual{typeof($f)} , A_dA:: CoDual{<:AbstractTensorMap} , alg_dalg:: CoDual{<:MatrixAlgebraKit.AbstractAlgorithm} )
118+ A, dA = arrayify (A_dA)
119+ output = $ f (A, Mooncake. primal (alg_dalg))
120+ output_doutput = Mooncake. zero_fcodual (output)
121+
122+ doutput = last (arrayify (output_doutput))
123+ function $adj (:: NoRData )
124+ # TODO : need accumulating projection to avoid intermediate here
125+ add! (dA, $ f (doutput))
126+ MatrixAlgebraKit. zero! (doutput)
127+ return ntuple (Returns (NoRData ()), 3 )
128+ end
129+
130+ return output_doutput, $ adj
131+ end
132+ end
133+ end
0 commit comments