Skip to content

Commit 1823327

Browse files
author
Katharine Hyatt
committed
Some CUDA fixes
1 parent bf7a376 commit 1823327

1 file changed

Lines changed: 7 additions & 11 deletions

File tree

test/cuda/tensors.jl

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -532,16 +532,10 @@ for V in spacelist
532532
# TODO
533533
@timedtestset "Tensor product: test via norm preservation" begin
534534
for T in (ComplexF64,) # Float32 case broken because of cuTENSOR
535-
@time "Construction" begin
536-
t1 = CUDA.rand(T, V1, V5')
537-
t2 = CUDA.rand(T, V2 V3, V4')
538-
end
539-
@time "Product" begin
540-
t = @constinferred (t1 t2)
541-
end
542-
@time "Norm" begin
543-
@test norm(t) norm(t1) * norm(t2)
544-
end
535+
t1 = CUDA.rand(T, V1, V5')
536+
t2 = CUDA.rand(T, V2 V3, V4')
537+
t = @constinferred (t1 t2)
538+
@test norm(t) norm(t1) * norm(t2)
545539
end
546540
end
547541
symmetricbraiding && @timedtestset "Tensor product: test via conversion" begin
@@ -562,7 +556,9 @@ for V in spacelist
562556
t1 = CUDA.rand(T, V1, V5')
563557
t2 = CUDA.rand(T, V2 V3, V4')
564558
t = @constinferred (t1 t2)
565-
@tensor t′[1 2 3; 4 5] := t1[1; 4] * t2[2 3; 5]
559+
CUDA.@allowscalar begin
560+
@tensor t′[1 2 3; 4 5] := t1[1; 4] * t2[2 3; 5]
561+
end
566562
@test t t′ # This should really not be broken
567563
end
568564
end

0 commit comments

Comments
 (0)