Skip to content
4 changes: 4 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,16 @@ TensorKitMooncakeExt = "Mooncake"
[workspace]
projects = ["test"]

[extras]
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This might be a merge error, and probably belongs in a different project.toml?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah no, this was here because of a GPUArrays bug that's now fixed.


[compat]
Adapt = "4"
CUDA = "5.9"
ChainRulesCore = "1"
Dictionaries = "0.4"
FiniteDifferences = "0.12"
GPUArrays = "<11.5.0"
LRUCache = "1.0.2"
LinearAlgebra = "1"
MatrixAlgebraKit = "0.6.5"
Expand Down
7 changes: 5 additions & 2 deletions ext/TensorKitAdaptExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,11 @@ function Adapt.adapt_structure(to, x::DiagonalTensorMap)
data′ = adapt(to, x.data)
return DiagonalTensorMap(data′, x.domain)
end
function Adapt.adapt_structure(::Type{TorA}, x::BraidingTensor) where {TorA <: Union{Number, DenseArray{<:Number}}}
return BraidingTensor{scalartype(TorA)}(space(x), x.adjoint)
# Adapting a `BraidingTensor` to a scalar type `T` converts the storage eltype
# while keeping the same array family (e.g. `Vector` stays `Vector`, `CuVector`
# stays `CuVector`).
function Adapt.adapt_structure(::Type{T}, x::BraidingTensor{T′, S, A}) where {T <: Number, T′, S, A}
    A′ = TensorKit.similarstoragetype(A, T)
    return BraidingTensor(space(x), A′, x.adjoint)
end
# Adapting to a concrete dense array type swaps the storage type wholesale,
# leaving the spaces and the adjoint flag untouched.
function Adapt.adapt_structure(::Type{A′}, x::BraidingTensor{T, S, A}) where {A′ <: DenseArray{<:Number}, T, S, A}
    return BraidingTensor(space(x), A′, x.adjoint)
end

end
18 changes: 17 additions & 1 deletion ext/TensorKitCUDAExt/TensorKitCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,18 @@ module TensorKitCUDAExt
using CUDA, CUDA.CUBLAS, CUDA.CUSOLVER, LinearAlgebra
using CUDA: @allowscalar
using cuTENSOR: cuTENSOR
using Strided: StridedViews
import CUDA: rand as curand, rand! as curand!, randn as curandn, randn! as curandn!

using CUDA: KernelAbstractions
using CUDA.KernelAbstractions: @kernel, @index

using TensorKit
using TensorKit.Factorizations
using TensorKit.Strided
using TensorKit.Factorizations: AbstractAlgorithm
using TensorKit: SectorDict, tensormaptype, scalar, similarstoragetype, AdjointTensorMap, scalartype, project_symmetric_and_check
import TensorKit: randisometry, rand, randn
import TensorKit: randisometry, rand, randn, _set_subblock!

using TensorKit: MatrixAlgebraKit

Expand All @@ -19,4 +23,16 @@ using Random
include("cutensormap.jl")
include("truncation.jl")

# GPU specialization of `TensorKit._set_subblock!`: writes `val` into the
# "braiding diagonal" entries `data[i, j, j, i]` via a KernelAbstractions
# kernel, avoiding scalar indexing on CUDA arrays.
#
# NOTE(review): unlike the generic CPU method (which overwrites every entry,
# zeroing the off-diagonal ones), this writes ONLY the `[i, j, j, i]` entries —
# it relies on callers having zero-filled `data` beforehand. Confirm this
# invariant holds for any new call sites.
#
# NOTE(review): the kernel indexes 4 dimensions, but the signature also accepts
# a 2-D `CuMatrix`; for that branch this relies on trailing-singleton indexing
# being valid — TODO confirm the `CuMatrix` path is only hit with d2 == 1, or
# that only the 4-D `StridedView` path is exercised in practice.
function TensorKit._set_subblock!(data::TD, val) where {T, TD <: Union{<:CuMatrix{T}, <:StridedViews.StridedView{T, 4, <:CuArray{T}}}}
    # Each (i, j) work item sets the single entry data[i, j, j, i] = val.
    @kernel function fill_subblock_kernel!(subblock, val)
        idx = @index(Global, Cartesian)
        @inbounds subblock[idx[1], idx[2], idx[2], idx[1]] = val
    end
    # Instantiate the kernel for whatever backend `data` lives on.
    kernel = fill_subblock_kernel!(KernelAbstractions.get_backend(data))
    d1 = size(data, 1)
    d2 = size(data, 2)
    # Launch over the (d1, d2) index grid.
    kernel(data, val; ndrange = (d1, d2))
    return data
end

end
20 changes: 20 additions & 0 deletions ext/TensorKitCUDAExt/cutensormap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -168,3 +168,23 @@ for f in (:sqrt, :log, :asin, :acos, :acosh, :atanh, :acoth)
return tf
end
end


# Single-threaded tree-transform addition for CUDA-backed tensor maps.
# Mirrors the generic implementation, but moves each multi-block transform's
# coefficient array onto the GPU before dispatching.
function TensorKit.add_kernel_nonthreaded!(
        ::TensorKit.FusionStyle,
        tdst::CuTensorMap, tsrc::CuTensorMap, p, transformer::TensorKit.GenericTreeTransformer, α, β, backend...
    )
    # Scratch space shared by all multi-block transforms below.
    bufs = TensorKit.allocate_buffers(tdst, tsrc, transformer)

    for tf in transformer.data
        coeffs = tf[1]
        if length(coeffs) == 1
            # A single block needs no intermediate buffers.
            TensorKit._add_transform_single!(tdst, tsrc, p, tf, α, β, backend...)
        else
            # Upload the coefficient array so the transform runs on-device.
            tf_gpu = (CUDA.adapt(CuArray, coeffs), tf[2:end]...)
            TensorKit._add_transform_multi!(tdst, tsrc, p, tf_gpu, bufs, α, β, backend...)
        end
    end
    return nothing
end
Comment on lines +171 to +190
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This probably is something separate from this PR, right? (Also, this might no longer be necessary if the mul! specializations of StridedViews are correctly handled in later versions of Strided.jl.)

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, this was needed for a test to work, let me look up which one in a moment...

145 changes: 82 additions & 63 deletions src/tensors/braidingtensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,73 +2,85 @@
# special (2,2) tensor that implements a standard braiding operation
#====================================================================#
"""
struct BraidingTensor{T,S<:IndexSpace} <: AbstractTensorMap{T, S, 2, 2}
BraidingTensor(V1::S, V2::S, adjoint::Bool=false) where {S<:IndexSpace}
struct BraidingTensor{T,S<:IndexSpace,A<:DenseVector{T}} <: AbstractTensorMap{T, S, 2, 2}
BraidingTensor(V1::S, V2::S, ::Type{A}, adjoint::Bool=false) where {S<:IndexSpace, A <: DenseVector{<:Number}}

Specific subtype of [`AbstractTensorMap`](@ref) for representing the braiding tensor that
braids the first input over the second input; its inverse can be obtained as the adjoint.

It holds that `domain(BraidingTensor(V1, V2)) == V1 ⊗ V2` and
`codomain(BraidingTensor(V1, V2)) == V2 ⊗ V1`.
`codomain(BraidingTensor(V1, V2)) == V2 ⊗ V1`. The storage type `TA`
controls the array type of the braiding tensor used when indexing
and multiplying with other tensors.
"""
struct BraidingTensor{T, S} <: AbstractTensorMap{T, S, 2, 2}
struct BraidingTensor{T, S, A} <: AbstractTensorMap{T, S, 2, 2}
V1::S
V2::S
adjoint::Bool
function BraidingTensor{T, S}(V1::S, V2::S, adjoint::Bool = false) where {T, S <: IndexSpace}
for a in sectors(V1)
for b in sectors(V2)
for c in (a ⊗ b)
Nsymbol(a, b, c) == Nsymbol(b, a, c) ||
throw(ArgumentError("Cannot define a braiding between $a and $b"))
end
end
function BraidingTensor{T, S, A}(V1::S, V2::S, ::Type{A}, adjoint::Bool = false) where {T, S <: IndexSpace, A <: DenseVector{T}}
for a in sectors(V1), b in sectors(V2), c in (a ⊗ b)
Nsymbol(a, b, c) == Nsymbol(b, a, c) ||
throw(ArgumentError("Cannot define a braiding between $a and $b"))
end
return new{T, S}(V1, V2, adjoint)
return new{T, S, A}(V1, V2, adjoint)
# partial construction: only construct rowr and colr when needed
end
end
function BraidingTensor{T, S}(V1::S, V2::S, ::Type{A}, adjoint::Bool = false) where {T, S <: IndexSpace, A}
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I'm not completely convinced by this constructor syntax, it feels like at this point we might as well just call BraidingTensor{T, S, A} directly, rather than including the Type{A} as an argument. This also just cuts down slightly on the number of constructors we need to have, under the less-is-more strategy :)

return BraidingTensor{T, S, A}(V1, V2, A, adjoint)
end
# Forward to the fully specified `{T, S}` constructor, inferring `S` from the spaces.
BraidingTensor{T}(V1::S, V2::S, A, adjoint::Bool = false) where {T, S <: IndexSpace} =
    BraidingTensor{T, S}(V1, V2, A, adjoint)
# Default the storage to a plain `Vector{T}` when no array type is supplied.
function BraidingTensor{T}(V1::S, V2::S, adjoint::Bool = false) where {T, S <: IndexSpace}
    return BraidingTensor{T, S}(V1, V2, Vector{T}, adjoint)
end
# Heterogeneous space types: promote both spaces to a common `IndexSpace` first.
function BraidingTensor{T}(V1::IndexSpace, V2::IndexSpace, A, adjoint::Bool = false) where {T}
    W1, W2 = promote(V1, V2)
    return BraidingTensor{T}(W1, W2, A, adjoint)
end
# No storage type given: default to `Vector{T}` (space promotion happens in the
# four-argument method this forwards to).
function BraidingTensor{T}(V1::IndexSpace, V2::IndexSpace, adjoint::Bool = false) where {T}
    return BraidingTensor{T}(V1, V2, Vector{T}, adjoint)
end
# Storage type given as a dense vector type: read the scalar type `T` off the
# storage and delegate after promoting the spaces.
function BraidingTensor(V1::IndexSpace, V2::IndexSpace, ::Type{A}, adjoint::Bool = false) where {T, A <: DenseVector{T}}
    W1, W2 = promote(V1, V2)
    return BraidingTensor{T}(W1, W2, A, adjoint)
end
# Only a scalar type given: use `Vector{T}` as the default storage.
function BraidingTensor(V1::IndexSpace, V2::IndexSpace, ::Type{T}, adjoint::Bool = false) where {T}
    W1, W2 = promote(V1, V2)
    return BraidingTensor{T}(W1, W2, Vector{T}, adjoint)
end
# Fully default constructor: promote the spaces and let the `{S, S}` method
# pick the scalar type from the braiding style.
BraidingTensor(V1::IndexSpace, V2::IndexSpace, adjoint::Bool = false) =
    BraidingTensor(promote(V1, V2)..., adjoint)
# Default scalar type: real suffices for symmetric braidings, while general
# (e.g. anyonic) braidings need complex phases.
function BraidingTensor(V1::S, V2::S, adjoint::Bool = false) where {S <: IndexSpace}
    T = BraidingStyle(sectortype(S)) isa SymmetricBraiding ? Float64 : ComplexF64
    return BraidingTensor{T, S}(V1, V2, Vector{T}, adjoint)
end
# Storage family given without a scalar type: choose the scalar type from the
# braiding style, then rebuild the storage type with that eltype.
function BraidingTensor(V1::S, V2::S, ::Type{A}, adjoint::Bool = false) where {S <: IndexSpace, A <: AbstractArray}
    T = BraidingStyle(sectortype(S)) isa SymmetricBraiding ? Float64 : ComplexF64
    storage = similarstoragetype(A, T)
    return BraidingTensor{T, S}(V1, V2, storage, adjoint)
end
# Construct from a `HomSpace`; the codomain must be the reversed domain.
function BraidingTensor(V::HomSpace, adjoint::Bool = false)
    if domain(V) != reverse(codomain(V))
        throw(SpaceMismatch("Cannot define a braiding on $V"))
    end
    return BraidingTensor(V[2], V[1], adjoint)
end
# `HomSpace` constructor with an explicit storage (or scalar) type `A`.
function BraidingTensor(V::HomSpace, ::Type{A}, adjoint::Bool = false) where {A}
    if domain(V) != reverse(codomain(V))
        throw(SpaceMismatch("Cannot define a braiding on $V"))
    end
    return BraidingTensor(V[2], V[1], A, adjoint)
end
# `HomSpace` constructor with an explicit scalar type `T`.
function BraidingTensor{T}(V::HomSpace, adjoint::Bool = false) where {T}
    if domain(V) != reverse(codomain(V))
        throw(SpaceMismatch("Cannot define a braiding on $V"))
    end
    return BraidingTensor{T}(V[2], V[1], adjoint)
end
function Base.adjoint(b::BraidingTensor{T, S}) where {T, S}
return BraidingTensor{T, S}(b.V1, b.V2, !b.adjoint)
# The adjoint braiding is the inverse braiding: same spaces and storage,
# flipped `adjoint` flag.
Base.adjoint(b::BraidingTensor{T, S, A}) where {T, S, A} =
    BraidingTensor{T, S, A}(b.V1, b.V2, A, !b.adjoint)

# The storage array type is carried as the third type parameter of `BraidingTensor`.
storagetype(::Type{BraidingTensor{T, S, A}}) where {T, S, A} = A
# The braiding swaps the two factors: domain `V1 ⊗ V2`, codomain `V2 ⊗ V1`
# (and the reverse assignment when `adjoint` is set).
space(b::BraidingTensor) = b.adjoint ? b.V1 ⊗ b.V2 ← b.V2 ⊗ b.V1 : b.V2 ⊗ b.V1 ← b.V1 ⊗ b.V2

# Specializations so that a `BraidingTensor`'s (placeholder) storage type does
# not dictate the result of storage-type promotion: the non-braiding operand's
# storage wins, and when both operands are `BraidingTensor`s the first one's
# storage is used.
promote_storagetype(::Type{A}, ::Type{B}) where {A <: BraidingTensor, B <: AbstractTensorMap} = storagetype(B)
promote_storagetype(::Type{A}, ::Type{B}) where {A <: AbstractTensorMap, B <: BraidingTensor} = storagetype(A)
promote_storagetype(::Type{A}, ::Type{B}) where {A <: BraidingTensor, B <: BraidingTensor} = storagetype(A)

# Three-argument forms additionally convert the chosen storage to eltype `T`.
promote_storagetype(::Type{T}, ::Type{A}, ::Type{B}) where {T <: Number, A <: BraidingTensor, B <: AbstractTensorMap} =
    similarstoragetype(B, T)
promote_storagetype(::Type{T}, ::Type{A}, ::Type{B}) where {T <: Number, A <: AbstractTensorMap, B <: BraidingTensor} =
    similarstoragetype(A, T)
promote_storagetype(::Type{T}, ::Type{A}, ::Type{B}) where {T <: Number, A <: BraidingTensor, B <: BraidingTensor} =
    similarstoragetype(A, T)

function Base.getindex(b::BraidingTensor)
sectortype(b) === Trivial || throw(SectorMismatch())
(V1, V2) = domain(b)
Expand Down Expand Up @@ -99,6 +111,12 @@ function _braiding_factor(f₁, f₂, inv::Bool = false)
return r
end

"""
    _set_subblock!(data, val)

Overwrite the 4-dimensional array `data` so that `data[i, j, j, i] == val` and
every other entry is zero. Returns `data`.
"""
function _set_subblock!(data, val)
    for I in CartesianIndices(data)
        onbraid = (I[1] == I[4]) && (I[2] == I[3])
        data[I] = onbraid ? val : zero(val)
    end
    return data
end
Comment thread
kshyatt marked this conversation as resolved.


@inline function subblock(
b::BraidingTensor, (f₁, f₂)::Tuple{FusionTree{I, 2}, FusionTree{I, 2}}
) where {I <: Sector}
Expand All @@ -115,15 +133,12 @@ end
d = (dims(codomain(b), f₁.uncoupled)..., dims(domain(b), f₂.uncoupled)...)
n1 = d[1] * d[2]
n2 = d[3] * d[4]
data = sreshape(StridedView(Matrix{eltype(b)}(undef, n1, n2)), d)
data_parent = storagetype(b)(undef, prod(d))
data = sreshape(StridedView(data_parent), d)
fill!(data, zero(eltype(b)))

r = _braiding_factor(f₁, f₂, b.adjoint)
if !isnothing(r)
@inbounds for i in axes(data, 1), j in axes(data, 2)
data[i, j, j, i] = r
end
end
!isnothing(r) && _set_subblock!(data, r)
return data
end

Expand All @@ -134,33 +149,20 @@ TensorMap(b::BraidingTensor) = copy!(similar(b), b)
Base.convert(::Type{TensorMap}, b::BraidingTensor) = TensorMap(b)

Base.complex(b::BraidingTensor{<:Complex}) = b
function Base.complex(b::BraidingTensor)
return BraidingTensor{complex(scalartype(b))}(space(b), b.adjoint)
# Complexify a real braiding tensor: widen the scalar type to its complex
# counterpart while keeping the same storage family.
function Base.complex(b::BraidingTensor{T, S, A}) where {T, S, A}
    return BraidingTensor(space(b), similarstoragetype(A, complex(T)), b.adjoint)
end

function block(b::BraidingTensor, s::Sector)
I = sectortype(b)
I == typeof(s) || throw(SectorMismatch())

# TODO: probably always square?
m = blockdim(codomain(b), s)
n = blockdim(domain(b), s)
data = Matrix{eltype(b)}(undef, (m, n))
Comment thread
kshyatt marked this conversation as resolved.

length(data) == 0 && return data # s ∉ blocksectors(b)

data = fill!(data, zero(eltype(b)))

# Fill `data` with the trivial-sector braiding block: viewed as a
# (d1, d2, d2, d1) array, the entries [i, j, j, i] are one and all others zero.
# Assumes `data` was zero-filled by the caller; `_set_subblock!` writes the
# nonzero pattern.
function _trivial_subblock!(data, b::BraidingTensor)
    V1, V2 = codomain(b)
    d1, d2 = dim(V1), dim(V2)
    subblock = sreshape(StridedView(data), (d1, d2, d2, d1))
    _set_subblock!(subblock, one(eltype(b)))
    return data
end

function _nontrivial_subblock!(data, b::BraidingTensor, s::Sector)
base_offset = first(blockstructure(b)[s][2]) - 1

for ((f₁, f₂), (sz, str, off)) in pairs(subblockstructure(space(b)))
Expand All @@ -169,14 +171,31 @@ function block(b::BraidingTensor, s::Sector)
isnothing(r) && continue
# change offset to account for single block
subblock = StridedView(data, sz, str, off - base_offset)
@inbounds for i in axes(subblock, 1), j in axes(subblock, 2)
subblock[i, j, j, i] = r
end
_set_subblock!(subblock, r)
end

return data
end

function block(b::BraidingTensor, s::Sector)
I = sectortype(b)
I == typeof(s) || throw(SectorMismatch())

# TODO: probably always square?
m = blockdim(codomain(b), s)
n = blockdim(domain(b), s)

data = reshape(storagetype(b)(undef, m * n), (m, n))

m * n == 0 && return data # s ∉ blocksectors(b)
fill!(data, zero(eltype(b)))

if sectortype(b) === Trivial
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mostly out of curiosity, is this an optimization or was this required?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, what's "this"?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My bad, I meant the split between the trivial and non-trivial implementation.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, it was so that both could dispatch to _set_subblock! (which I could specialize) but do their own setup, and to clean up a quite long function.

return _trivial_subblock!(data, b)
else
return _nontrivial_subblock!(data, b, s)
end
end

# Index manipulations
# -------------------
has_shared_permute(t::BraidingTensor, ::Index2Tuple) = false
Expand Down
Loading
Loading