Skip to content

Commit b7608ef

Browse files
committed
Even more fixes
1 parent 1b03607 commit b7608ef

4 files changed

Lines changed: 19 additions & 17 deletions

File tree

.github/workflows/CI.yml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,9 @@ jobs:
5252
- os: macos-latest
5353
group: enzyme/vectorinterface
5454
- os: macos-latest
55-
group: enzyme/indexmanipulations
55+
group: enzyme/indexmanipulations/add
56+
- os: macos-latest
57+
group: enzyme/indexmanipulations/twist_flip_unit
5658
- os: windows-latest
5759
group: enzyme/factorizations
5860
- os: windows-latest
@@ -62,7 +64,9 @@ jobs:
6264
- os: windows-latest
6365
group: enzyme/vectorinterface
6466
- os: windows-latest
65-
group: enzyme/indexmanipulations
67+
group: enzyme/indexmanipulations/add
68+
- os: windows-latest
69+
group: enzyme/indexmanipulations/twist_flip_unit
6670
uses: "QuantumKitHub/QuantumKitHubActions/.github/workflows/Tests.yml@main"
6771
with:
6872
group: "${{ matrix.group }}"

ext/TensorKitEnzymeExt/vectorinterface.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ function EnzymeRules.reverse(
5454
Aval = something(cache, A.val)
5555
Δα = pullback_dα(α, C, Aval)
5656
!isa(A, Const) && !isa(C, Const) && add!(A.dval, C.dval, conj.val))
57-
!isa(C, Const) && zerovector!(C.dval)
57+
!isa(C, Const) && make_zero!(C.dval)
5858
return (nothing, nothing, Δα)
5959
end
6060

test/enzyme/vectorinterface/inner.jl

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,12 @@ eltypes = (Float64, ComplexF64)
3535

3636
@testset "Enzyme - VectorInterface" begin
3737
@timedtestset "$(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes
38-
atol = default_tol(T)
39-
rtol = default_tol(T)
40-
C = randn(T, V[1] V[2] V[3] V[4] V[5])
41-
A = randn(T, V[1] V[2] V[3] V[4] V[5])
42-
@testset for RT in (Active,), TC in (Duplicated, Const), TA in (Duplicated, Const)
43-
EnzymeTestUtils.test_reverse(inner, RT, (C, TC), (A, TA); atol, rtol)
44-
EnzymeTestUtils.test_reverse(inner, RT, (C', TC), (A', TA); atol, rtol)
38+
@testset for RT in (Active,), TC in (Duplicated, Const), TA in (Duplicated, Const), f in (identity, adjoint)
39+
atol = default_tol(T)
40+
rtol = default_tol(T)
41+
C = randn(T, V[1] V[2] V[3] V[4] V[5])
42+
A = randn(T, V[1] V[2] V[3] V[4] V[5])
43+
EnzymeTestUtils.test_reverse(inner, RT, (f(C), TC), (f(A), TA); atol, rtol)
4544
end
4645
end
4746
end

test/enzyme/vectorinterface/scale.jl

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,17 +37,16 @@ eltypes = (Float64, ComplexF64)
3737
@timedtestset "$(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes
3838
atol = default_tol(T)
3939
rtol = default_tol(T)
40-
C = randn(T, V[1] V[2] V[3] V[4] V[5])
41-
A = randn(T, V[1] V[2] V[3] V[4] V[5])
4240
α = randn(T)
4341
@testset for TC in (Duplicated,), Tα in (Active, Const)
42+
C = randn(T, V[1] V[2] V[3] V[4] V[5])
4443
EnzymeTestUtils.test_reverse(scale!, TC, (C, TC), (α, Tα); atol, rtol)
44+
C = randn(T, V[1] V[2] V[3] V[4] V[5])
4545
EnzymeTestUtils.test_reverse(scale!, TC, (C', TC), (α, Tα); atol, rtol)
46-
@testset for TA in (Duplicated,)
47-
EnzymeTestUtils.test_reverse(scale!, TC, (C, TC), (A, TA), (α, Tα); atol, rtol)
48-
EnzymeTestUtils.test_reverse(scale!, TC, (C', TC), (A', TA), (α, Tα); atol, rtol)
49-
EnzymeTestUtils.test_reverse(scale!, TC, (copy(C'), TC), (A', TA), (α, Tα); atol, rtol)
50-
EnzymeTestUtils.test_reverse(scale!, TC, (C', TC), (copy(A'), TA), (α, Tα); atol, rtol)
46+
@testset for TA in (Duplicated,), (fc, fa) in ((identity, identity), (adjoint, adjoint))
47+
C = randn(T, V[1] V[2] V[3] V[4] V[5])
48+
A = randn(T, V[1] V[2] V[3] V[4] V[5])
49+
EnzymeTestUtils.test_reverse(scale!, TC, (fc(C), TC), (fa(A), TA), (α, Tα); atol, rtol)
5150
end
5251
end
5352
end

0 commit comments

Comments
 (0)