-
Notifications
You must be signed in to change notification settings - Fork 59
Allow BraidingTensor to have a custom storage type #393
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
0190c09
d2aaf55
fc0020f
b7339ec
b0aff17
3653185
fc3edc2
26e6959
b8c9d00
6d0049a
6577957
0fd01cf
99ddd05
a268464
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -168,3 +168,23 @@ for f in (:sqrt, :log, :asin, :acos, :acosh, :atanh, :acoth) | |
| return tf | ||
| end | ||
| end | ||
|
|
||
|
|
||
| function TensorKit.add_kernel_nonthreaded!( | ||
| ::TensorKit.FusionStyle, | ||
| tdst::CuTensorMap, tsrc::CuTensorMap, p, transformer::TensorKit.GenericTreeTransformer, α, β, backend... | ||
| ) | ||
| # preallocate buffers | ||
| buffers = TensorKit.allocate_buffers(tdst, tsrc, transformer) | ||
|
|
||
| for subtransformer in transformer.data | ||
| # Special case without intermediate buffers whenever there is only a single block | ||
| if length(subtransformer[1]) == 1 | ||
| TensorKit._add_transform_single!(tdst, tsrc, p, subtransformer, α, β, backend...) | ||
| else | ||
| cu_subtransformer = tuple(CUDA.adapt(CuArray, subtransformer[1]), subtransformer[2:end]...) | ||
| TensorKit._add_transform_multi!(tdst, tsrc, p, cu_subtransformer, buffers, α, β, backend...) | ||
| end | ||
| end | ||
| return nothing | ||
| end | ||
|
Comment on lines
+171
to
+190
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This probably is something separate from this PR right? (Also, might no longer be necessary of the
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No, this was needed for a test to work, let me look up which one in a moment... |
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,73 +2,85 @@ | |
| # special (2,2) tensor that implements a standard braiding operation | ||
| #====================================================================# | ||
| """ | ||
| struct BraidingTensor{T,S<:IndexSpace} <: AbstractTensorMap{T, S, 2, 2} | ||
| BraidingTensor(V1::S, V2::S, adjoint::Bool=false) where {S<:IndexSpace} | ||
| struct BraidingTensor{T,S<:IndexSpace,A<:DenseVector{T}} <: AbstractTensorMap{T, S, 2, 2} | ||
| BraidingTensor(V1::S, V2::S, ::Type{A}, adjoint::Bool=false) where {S<:IndexSpace, A <: DenseVector{<:Number}} | ||
|
|
||
| Specific subtype of [`AbstractTensorMap`](@ref) for representing the braiding tensor that | ||
| braids the first input over the second input; its inverse can be obtained as the adjoint. | ||
|
|
||
| It holds that `domain(BraidingTensor(V1, V2)) == V1 ⊗ V2` and | ||
| `codomain(BraidingTensor(V1, V2)) == V2 ⊗ V1`. | ||
| `codomain(BraidingTensor(V1, V2)) == V2 ⊗ V1`. The storage type `TA` | ||
| controls the array type of the braiding tensor used when indexing | ||
| and multiplying with other tensors. | ||
| """ | ||
| struct BraidingTensor{T, S} <: AbstractTensorMap{T, S, 2, 2} | ||
| struct BraidingTensor{T, S, A} <: AbstractTensorMap{T, S, 2, 2} | ||
| V1::S | ||
| V2::S | ||
| adjoint::Bool | ||
| function BraidingTensor{T, S}(V1::S, V2::S, adjoint::Bool = false) where {T, S <: IndexSpace} | ||
| for a in sectors(V1) | ||
| for b in sectors(V2) | ||
| for c in (a ⊗ b) | ||
| Nsymbol(a, b, c) == Nsymbol(b, a, c) || | ||
| throw(ArgumentError("Cannot define a braiding between $a and $b")) | ||
| end | ||
| end | ||
| function BraidingTensor{T, S, A}(V1::S, V2::S, ::Type{A}, adjoint::Bool = false) where {T, S <: IndexSpace, A <: DenseVector{T}} | ||
| for a in sectors(V1), b in sectors(V2), c in (a ⊗ b) | ||
| Nsymbol(a, b, c) == Nsymbol(b, a, c) || | ||
| throw(ArgumentError("Cannot define a braiding between $a and $b")) | ||
| end | ||
| return new{T, S}(V1, V2, adjoint) | ||
| return new{T, S, A}(V1, V2, adjoint) | ||
| # partial construction: only construct rowr and colr when needed | ||
| end | ||
| end | ||
| function BraidingTensor{T, S}(V1::S, V2::S, ::Type{A}, adjoint::Bool = false) where {T, S <: IndexSpace, A} | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think I'm not completely convinced by this constructor syntax, it feels like at this point we might as well just call |
||
| return BraidingTensor{T, S, A}(V1, V2, A, adjoint) | ||
| end | ||
| function BraidingTensor{T}(V1::S, V2::S, A, adjoint::Bool = false) where {T, S <: IndexSpace} | ||
| return BraidingTensor{T, S}(V1, V2, A, adjoint) | ||
| end | ||
| function BraidingTensor{T}(V1::S, V2::S, adjoint::Bool = false) where {T, S <: IndexSpace} | ||
| return BraidingTensor{T, S}(V1, V2, adjoint) | ||
| return BraidingTensor{T, S}(V1, V2, Vector{T}, adjoint) | ||
| end | ||
| function BraidingTensor{T}(V1::IndexSpace, V2::IndexSpace, A, adjoint::Bool = false) where {T} | ||
| return BraidingTensor{T}(promote(V1, V2)..., A, adjoint) | ||
| end | ||
| function BraidingTensor{T}(V1::IndexSpace, V2::IndexSpace, adjoint::Bool = false) where {T} | ||
| return BraidingTensor{T}(promote(V1, V2)..., adjoint) | ||
| return BraidingTensor{T}(V1, V2, Vector{T}, adjoint) | ||
| end | ||
| function BraidingTensor(V1::IndexSpace, V2::IndexSpace, ::Type{A}, adjoint::Bool = false) where {T, A <: DenseVector{T}} | ||
| return BraidingTensor{T}(promote(V1, V2)..., A, adjoint) | ||
| end | ||
| function BraidingTensor(V1::IndexSpace, V2::IndexSpace, ::Type{T}, adjoint::Bool = false) where {T} | ||
| return BraidingTensor{T}(promote(V1, V2)..., Vector{T}, adjoint) | ||
| end | ||
| function BraidingTensor(V1::IndexSpace, V2::IndexSpace, adjoint::Bool = false) | ||
| return BraidingTensor(promote(V1, V2)..., adjoint) | ||
| end | ||
| function BraidingTensor(V1::S, V2::S, adjoint::Bool = false) where {S <: IndexSpace} | ||
| T = BraidingStyle(sectortype(S)) isa SymmetricBraiding ? Float64 : ComplexF64 | ||
| return BraidingTensor{T, S}(V1, V2, adjoint) | ||
| return BraidingTensor{T, S}(V1, V2, Vector{T}, adjoint) | ||
| end | ||
| function BraidingTensor(V1::S, V2::S, ::Type{A}, adjoint::Bool = false) where {S <: IndexSpace, A <: AbstractArray} | ||
| T = BraidingStyle(sectortype(S)) isa SymmetricBraiding ? Float64 : ComplexF64 | ||
| A′ = similarstoragetype(A, T) | ||
| return BraidingTensor{T, S}(V1, V2, A′, adjoint) | ||
| end | ||
| function BraidingTensor(V::HomSpace, adjoint::Bool = false) | ||
| domain(V) == reverse(codomain(V)) || | ||
| throw(SpaceMismatch("Cannot define a braiding on $V")) | ||
| return BraidingTensor(V[2], V[1], adjoint) | ||
| end | ||
| function BraidingTensor(V::HomSpace, ::Type{A}, adjoint::Bool = false) where {A} | ||
| domain(V) == reverse(codomain(V)) || | ||
| throw(SpaceMismatch("Cannot define a braiding on $V")) | ||
| return BraidingTensor(V[2], V[1], A, adjoint) | ||
| end | ||
| function BraidingTensor{T}(V::HomSpace, adjoint::Bool = false) where {T} | ||
| domain(V) == reverse(codomain(V)) || | ||
| throw(SpaceMismatch("Cannot define a braiding on $V")) | ||
| return BraidingTensor{T}(V[2], V[1], adjoint) | ||
| end | ||
| function Base.adjoint(b::BraidingTensor{T, S}) where {T, S} | ||
| return BraidingTensor{T, S}(b.V1, b.V2, !b.adjoint) | ||
| function Base.adjoint(b::BraidingTensor{T, S, A}) where {T, S, A} | ||
| return BraidingTensor{T, S, A}(b.V1, b.V2, A, !b.adjoint) | ||
| end | ||
|
|
||
| storagetype(::Type{BraidingTensor{T, S, A}}) where {T, S, A} = A | ||
| space(b::BraidingTensor) = b.adjoint ? b.V1 ⊗ b.V2 ← b.V2 ⊗ b.V1 : b.V2 ⊗ b.V1 ← b.V1 ⊗ b.V2 | ||
|
|
||
| # specializations to ignore the storagetype of BraidingTensor | ||
| promote_storagetype(::Type{A}, ::Type{B}) where {A <: BraidingTensor, B <: AbstractTensorMap} = storagetype(B) | ||
| promote_storagetype(::Type{A}, ::Type{B}) where {A <: AbstractTensorMap, B <: BraidingTensor} = storagetype(A) | ||
| promote_storagetype(::Type{A}, ::Type{B}) where {A <: BraidingTensor, B <: BraidingTensor} = storagetype(A) | ||
|
|
||
| promote_storagetype(::Type{T}, ::Type{A}, ::Type{B}) where {T <: Number, A <: BraidingTensor, B <: AbstractTensorMap} = | ||
| similarstoragetype(B, T) | ||
| promote_storagetype(::Type{T}, ::Type{A}, ::Type{B}) where {T <: Number, A <: AbstractTensorMap, B <: BraidingTensor} = | ||
| similarstoragetype(A, T) | ||
| promote_storagetype(::Type{T}, ::Type{A}, ::Type{B}) where {T <: Number, A <: BraidingTensor, B <: BraidingTensor} = | ||
| similarstoragetype(A, T) | ||
|
|
||
| function Base.getindex(b::BraidingTensor) | ||
| sectortype(b) === Trivial || throw(SectorMismatch()) | ||
| (V1, V2) = domain(b) | ||
|
|
@@ -99,6 +111,12 @@ function _braiding_factor(f₁, f₂, inv::Bool = false) | |
| return r | ||
| end | ||
|
|
||
| function _set_subblock!(data, val) | ||
| f(I) = ((I[1] == I[4]) & (I[2] == I[3])) * val | ||
| data .= f.(CartesianIndices(data)) | ||
| end | ||
|
kshyatt marked this conversation as resolved.
|
||
|
|
||
|
|
||
| @inline function subblock( | ||
| b::BraidingTensor, (f₁, f₂)::Tuple{FusionTree{I, 2}, FusionTree{I, 2}} | ||
| ) where {I <: Sector} | ||
|
|
@@ -115,15 +133,12 @@ end | |
| d = (dims(codomain(b), f₁.uncoupled)..., dims(domain(b), f₂.uncoupled)...) | ||
| n1 = d[1] * d[2] | ||
| n2 = d[3] * d[4] | ||
| data = sreshape(StridedView(Matrix{eltype(b)}(undef, n1, n2)), d) | ||
| data_parent = storagetype(b)(undef, prod(d)) | ||
| data = sreshape(StridedView(data_parent), d) | ||
| fill!(data, zero(eltype(b))) | ||
|
|
||
| r = _braiding_factor(f₁, f₂, b.adjoint) | ||
| if !isnothing(r) | ||
| @inbounds for i in axes(data, 1), j in axes(data, 2) | ||
| data[i, j, j, i] = r | ||
| end | ||
| end | ||
| !isnothing(r) && _set_subblock!(data, r) | ||
| return data | ||
| end | ||
|
|
||
|
|
@@ -134,33 +149,20 @@ TensorMap(b::BraidingTensor) = copy!(similar(b), b) | |
| Base.convert(::Type{TensorMap}, b::BraidingTensor) = TensorMap(b) | ||
|
|
||
| Base.complex(b::BraidingTensor{<:Complex}) = b | ||
| function Base.complex(b::BraidingTensor) | ||
| return BraidingTensor{complex(scalartype(b))}(space(b), b.adjoint) | ||
| function Base.complex(b::BraidingTensor{T, S, A}) where {T, S, A} | ||
| Ac = similarstoragetype(A, complex(T)) | ||
| return BraidingTensor(space(b), Ac, b.adjoint) | ||
| end | ||
|
|
||
| function block(b::BraidingTensor, s::Sector) | ||
| I = sectortype(b) | ||
| I == typeof(s) || throw(SectorMismatch()) | ||
|
|
||
| # TODO: probably always square? | ||
| m = blockdim(codomain(b), s) | ||
| n = blockdim(domain(b), s) | ||
| data = Matrix{eltype(b)}(undef, (m, n)) | ||
|
kshyatt marked this conversation as resolved.
|
||
|
|
||
| length(data) == 0 && return data # s ∉ blocksectors(b) | ||
|
|
||
| data = fill!(data, zero(eltype(b))) | ||
|
|
||
| function _trivial_subblock!(data, b::BraidingTensor) | ||
| V1, V2 = codomain(b) | ||
| if sectortype(b) === Trivial | ||
| d1, d2 = dim(V1), dim(V2) | ||
| subblock = sreshape(StridedView(data), (d1, d2, d2, d1)) | ||
| @inbounds for i in axes(subblock, 1), j in axes(subblock, 2) | ||
| subblock[i, j, j, i] = one(eltype(b)) | ||
| end | ||
| return data | ||
| end | ||
| d1, d2 = dim(V1), dim(V2) | ||
| subblock = sreshape(StridedView(data), (d1, d2, d2, d1)) | ||
| _set_subblock!(subblock, one(eltype(b))) | ||
| return data | ||
| end | ||
|
|
||
| function _nontrivial_subblock!(data, b::BraidingTensor, s::Sector) | ||
| base_offset = first(blockstructure(b)[s][2]) - 1 | ||
|
|
||
| for ((f₁, f₂), (sz, str, off)) in pairs(subblockstructure(space(b))) | ||
|
|
@@ -169,14 +171,31 @@ function block(b::BraidingTensor, s::Sector) | |
| isnothing(r) && continue | ||
| # change offset to account for single block | ||
| subblock = StridedView(data, sz, str, off - base_offset) | ||
| @inbounds for i in axes(subblock, 1), j in axes(subblock, 2) | ||
| subblock[i, j, j, i] = r | ||
| end | ||
| _set_subblock!(subblock, r) | ||
| end | ||
|
|
||
| return data | ||
| end | ||
|
|
||
| function block(b::BraidingTensor, s::Sector) | ||
| I = sectortype(b) | ||
| I == typeof(s) || throw(SectorMismatch()) | ||
|
|
||
| # TODO: probably always square? | ||
| m = blockdim(codomain(b), s) | ||
| n = blockdim(domain(b), s) | ||
|
|
||
| data = reshape(storagetype(b)(undef, m * n), (m, n)) | ||
|
|
||
| m * n == 0 && return data # s ∉ blocksectors(b) | ||
| fill!(data, zero(eltype(b))) | ||
|
|
||
| if sectortype(b) === Trivial | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Mostly out of curiosity, is this an optimization or was this required?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sorry, what's "this"?
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. My bad, I meant the split between the trivial and non-trivial implementation.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh it was so that both could dispatch to |
||
| return _trivial_subblock!(data, b) | ||
| else | ||
| return _nontrivial_subblock!(data, b, s) | ||
| end | ||
| end | ||
|
|
||
| # Index manipulations | ||
| # ------------------- | ||
| has_shared_permute(t::BraidingTensor, ::Index2Tuple) = false | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This might be a merge error, and probably belongs in a different project.toml?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah no this was here because of a GPUArrays bug that's now fixed