Skip to content

Commit 18aa158

Browse files
committed
Ensure blocktype is correctly inferred for CuArray
1 parent c53e762 commit 18aa158

2 files changed

Lines changed: 2 additions & 5 deletions

File tree

src/tensors/abstracttensor.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -313,8 +313,6 @@ end
313313
#------------------------------------------------------------
314314
InnerProductStyle(t::AbstractTensorMap) = InnerProductStyle(typeof(t))
315315

316-
blocktype(t::AbstractTensorMap) = blocktype(typeof(t))
317-
318316
numout(t::AbstractTensorMap) = numout(typeof(t))
319317
numin(t::AbstractTensorMap) = numin(typeof(t))
320318
numind(t::AbstractTensorMap) = numind(typeof(t))
@@ -441,6 +439,7 @@ See also [`blocks`](@ref), [`blocksectors`](@ref), [`blockdim`](@ref) and [`hasb
441439
442440
Return the type of the matrix blocks of a tensor.
443441
""" blocktype
442+
blocktype(t::AbstractTensorMap) = blocktype(typeof(t))
444443
function blocktype(::Type{T}) where {T <: AbstractTensorMap}
445444
return Core.Compiler.return_type(block, Tuple{T, sectortype(T)})
446445
end

src/tensors/tensor.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -455,9 +455,7 @@ block(t::TensorMap, c::Sector) = blocks(t)[c]
455455

456456
blocks(t::TensorMap) = BlockIterator(t, fusionblockstructure(t).blockstructure)
457457

458-
function blocktype(::Type{TT}) where {TT <: TensorMap}
459-
A = storagetype(TT)
460-
T = eltype(A)
458+
function blocktype(::Type{TensorMap{T, S, N₁, N₂, A}}) where {T, S, N₁, N₂, A <: Vector{T}}
461459
return Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}}
462460
end
463461

0 commit comments

Comments
 (0)