Skip to content

Commit c136a58

Browse files
committed
change blocktype of TensorMap to StridedView
1 parent 2631631 commit c136a58

2 files changed

Lines changed: 19 additions & 12 deletions

File tree

src/spaces/homspace.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,7 @@ const StridedStructure{N} = Tuple{NTuple{N, Int}, NTuple{N, Int}, Int}
299299

300300
struct FusionBlockStructure{I, N, F₁, F₂}
301301
totaldim::Int
302-
blockstructure::SectorDict{I, Tuple{Tuple{Int, Int}, UnitRange{Int}}}
302+
blockstructure::SectorDict{I, StridedStructure{2}}
303303
fusiontreelist::Vector{Tuple{F₁, F₂}}
304304
fusiontreestructure::Vector{StridedStructure{N}}
305305
fusiontreeindices::FusionTreeDict{Tuple{F₁, F₂}, Int}
@@ -325,9 +325,9 @@ end
325325
F₂ = fusiontreetype(I, N₂)
326326

327327
# output structure
328-
blockstructure = SectorDict{I, Tuple{Tuple{Int, Int}, UnitRange{Int}}}() # size, range
328+
blockstructure = SectorDict{I, StridedStructure{2}}() # size, strides, offset
329329
fusiontreelist = Vector{Tuple{F₁, F₂}}()
330-
fusiontreestructure = Vector{Tuple{NTuple{N₁ + N₂, Int}, NTuple{N₁ + N₂, Int}, Int}}() # size, strides, offset
330+
fusiontreestructure = Vector{StridedStructure{N₁ + N₂}}() # size, strides, offset
331331

332332
# temporary data structures
333333
splittingtrees = Vector{F₁}()
@@ -367,8 +367,8 @@ end
367367
blocksize = (blockdim₁, blockdim₂)
368368
blocklength = blockdim₁ * blockdim₂
369369
blockrange = (blockoffset + 1):(blockoffset + blocklength)
370+
blockstructure[c] = (blocksize, strides, blockoffset)
370371
blockoffset = last(blockrange)
371-
blockstructure[c] = (blocksize, blockrange)
372372
end
373373

374374
fusiontreeindices = sizehint!(

src/tensors/tensor.jl

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -453,28 +453,35 @@ blocks(t::TensorMap) = BlockIterator(t, fusionblockstructure(t).blockstructure)
453453
function blocktype(::Type{TT}) where {TT <: TensorMap}
454454
A = storagetype(TT)
455455
T = eltype(A)
456-
return Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}}
456+
@static if isdefined(Core, :Memory) # StridedViews normalizes parent types!
457+
if A <: Vector{T}
458+
A = GenericMemory{T}
459+
end
460+
end
461+
return StridedView{T, 2, A, typeof(identity)}
457462
end
458463

459464
function Base.iterate(iter::BlockIterator{<:TensorMap}, state...)
460465
next = iterate(iter.structure, state...)
461466
isnothing(next) && return next
462-
(c, (sz, r)), newstate = next
463-
return c => reshape(view(iter.t.data, r), sz), newstate
467+
(c, (sz, str, offset)), newstate = next
468+
return c => StridedView(iter.t.data, sz, str, offset), newstate
464469
end
465470

466471
function Base.getindex(iter::BlockIterator{<:TensorMap}, c::Sector)
467472
sectortype(iter.t) === typeof(c) || throw(SectorMismatch())
468-
(d₁, d₂), r = get(iter.structure, c) do
469-
# is s is not a key, at least one of the two dimensions will be zero:
473+
(d₁, d₂), (s₁, s₂), offset = get(iter.structure, c) do
474+
# is c is not a key, at least one of the two dimensions will be zero:
470475
# it then does not matter where exactly we construct a view in `t.data`,
471476
# as it will have length zero anyway
472477
d₁′ = blockdim(codomain(iter.t), c)
473478
d₂′ = blockdim(domain(iter.t), c)
474-
l = d₁′ * d₂′
475-
return (d₁′, d₂′), 1:l
479+
s₁ = 1
480+
s₂ = 0
481+
offset = 0
482+
return (d₁′, d₂′), (s₁, s₂), offset
476483
end
477-
return reshape(view(iter.t.data, r), (d₁, d₂))
484+
return StridedView(iter.t.data, (d₁, d₂), (s₁, s₂), offset)
478485
end
479486

480487
# Getting and setting the data at the subblock level

0 commit comments

Comments
 (0)