Skip to content

Commit eabfce9

Browse files
committed
Even more small tweaks
1 parent 8665c4a commit eabfce9

2 files changed

Lines changed: 9 additions & 17 deletions

File tree

ext/TensorKitCUDAExt/cutensormap.jl

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,6 @@ function CuTensorMap(t::TensorMap{T, S, N₁, N₂, A}) where {T, S, N₁, N₂,
77
return CuTensorMap{T, S, N₁, N₂}(CuArray{T}(t.data), space(t))
88
end
99

10-
#=function TensorKit.TensorMap{T, S₁, N₁, N₂, A}(
11-
::UndefInitializer, space::TensorMapSpace{S₂, N₁, N₂}
12-
) where {T, S₁, S₂ <: TensorKit.ElementarySpace, N₁, N₂, A <: CuVector{T}}
13-
d = TensorKit.fusionblockstructure(space).totaldim
14-
data = A(undef, d)
15-
if !isbitstype(T)
16-
zerovector!(data)
17-
end
18-
return TensorKit.TensorMap{T, S₂, A}(data, space)
19-
end=#
20-
2110
# project_symmetric! doesn't yet work for GPU types, so do this on the host, then copy
2211
function TensorKit.project_symmetric_and_check(::Type{T}, ::Type{A}, data::AbstractArray, V::TensorMapSpace; tol = sqrt(eps(real(float(eltype(data)))))) where {T, A <: CuVector{T}}
2312
h_t = TensorKit.TensorMapWithStorage{T, Vector{T}}(undef, V)
@@ -29,7 +18,7 @@ function TensorKit.project_symmetric_and_check(::Type{T}, ::Type{A}, data::Abstr
2918
end
3019

3120
function TensorKit.blocktype(::Type{<:CuTensorMap{T, S}}) where {T, S}
32-
return SubArray{T, 1, CuVector{T, CUDA.DeviceMemory}, Tuple{UnitRange{Int}}, true}
21+
return CuMatrix{T, CUDA.DeviceMemory}
3322
end
3423

3524
for (fname, felt) in ((:zeros, :zero), (:ones, :one))

src/tensors/abstracttensor.jl

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,11 @@ storagetype(t) = storagetype(typeof(t))
5353
function storagetype(::Type{T}) where {T <: AbstractTensorMap}
5454
if T isa Union
5555
# attempt to be slightly more specific by promoting unions
56-
Ma = storagetype(T.a)
57-
Mb = storagetype(T.b)
58-
return promote_storagetype(Ma, Mb)
56+
return promote_storagetype(T.a, T.b)
57+
elseif storagetype(T) isa Union
58+
# attempt to be slightly more specific by promoting unions
59+
TU = storagetype(T)
60+
return promote_storagetype(TU.a, TU.b)
5961
else
6062
# fallback definition by using scalartype
6163
return similarstoragetype(scalartype(T))
@@ -103,8 +105,9 @@ similarstoragetype(X::Type, ::Type{T}) where {T <: Number} =
103105

104106
# implement on tensors
105107
similarstoragetype(::Type{TT}) where {TT <: AbstractTensorMap} = similarstoragetype(storagetype(TT))
106-
similarstoragetype(::Type{TT}, ::Type{T}) where {TT <: AbstractTensorMap, T <: Number} =
107-
similarstoragetype(storagetype(TT), T)
108+
function similarstoragetype(::Type{TT}, ::Type{T}) where {TT <: AbstractTensorMap, T <: Number}
109+
return similarstoragetype(storagetype(TT), T)
110+
end
108111

109112
# implement on arrays
110113
similarstoragetype(::Type{A}) where {A <: DenseVector{<:Number}} = A

0 commit comments

Comments
 (0)