Skip to content

Commit 5b04ca9

Browse files
committed
Apply Lukas' suggestions
1 parent a158e8c commit 5b04ca9

2 files changed

Lines changed: 2 additions & 29 deletions

File tree

ext/TensorKitCUDAExt/cutensormap.jl

Lines changed: 1 addition & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -105,30 +105,6 @@ function TensorKit.scalar(t::CuTensorMap{T, S, 0, 0}) where {T, S}
105105
return isempty(inds) ? zero(scalartype(t)) : @allowscalar @inbounds t.data[only(inds)]
106106
end
107107

108-
function Base.convert(
109-
TT::Type{TensorMap{T, S, N₁, N₂, A}},
110-
t::TensorMap{T, S, N₁, N₂, AA}
111-
) where {T, S, N₁, N₂, A <: CuArray{T}, AA}
112-
if typeof(t) === TT
113-
return t
114-
else
115-
tnew = TT(undef, space(t))
116-
return copy!(tnew, t)
117-
end
118-
end
119-
120-
function Base.convert(
121-
TT::Type{TensorMap{T, S, N₁, N₂, A}},
122-
t::AdjointTensorMap
123-
) where {T, S, N₁, N₂, A <: CuArray{T}}
124-
if typeof(t) === TT
125-
return t
126-
else
127-
tnew = TT(undef, space(t))
128-
return copy!(tnew, t)
129-
end
130-
end
131-
132108
function LinearAlgebra.isposdef(t::CuTensorMap)
133109
domain(t) == codomain(t) ||
134110
throw(SpaceMismatch("`isposdef` requires domain and codomain to be the same"))
@@ -154,11 +130,8 @@ function Base.promote_rule(
154130
return CuTensorMap{T, S, N₁, N₂}
155131
end
156132

157-
TensorKit.promote_storage_rule(::Type{CuArray{T, N}}, ::Type{<:CuArray{T, N}}) where {T, N} =
133+
TensorKit.promote_storage_rule(::Type{<:CuArray{T, N}}, ::Type{<:CuArray{T, N}}) where {T, N} =
158134
CuArray{T, N, CUDA.default_memory}
159-
TensorKit.promote_storage_rule(::Type{<:CuArray{T, N}}, ::Type{CuArray{T, N}}) where {T, N} =
160-
CuArray{T, N, CUDA.default_memory}
161-
162135

163136
# CuTensorMap exponentation:
164137
function TensorKit.exp!(t::CuTensorMap)

src/tensors/braidingtensor.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ function add_transform!(
176176
fusiontreetransform,
177177
α::Number, β::Number, backend::AbstractBackend...
178178
) where {T, S}
179-
tsrc_map = TensorMapWithStorage{scalartype(tdst), storagetype(tdst)}(undef, (tsrc.V2 tsrc.V1) (tsrc.V1 tsrc.V2))
179+
tsrc_map = similar(tdst, storagetype(tdst), space(tsrc))
180180
copy!(tsrc_map, tsrc)
181181
return add_transform!(
182182
tdst, tsrc_map, (p₁, p₂), fusiontreetransform, α, β,

0 commit comments

Comments
 (0)