diff --git a/ext/TensorKitEnzymeExt/TensorKitEnzymeExt.jl b/ext/TensorKitEnzymeExt/TensorKitEnzymeExt.jl index 7f448f9e3..bacf256c3 100644 --- a/ext/TensorKitEnzymeExt/TensorKitEnzymeExt.jl +++ b/ext/TensorKitEnzymeExt/TensorKitEnzymeExt.jl @@ -12,5 +12,6 @@ using Random: AbstractRNG include("utility.jl") include("linalg.jl") +include("tensoroperations.jl") end diff --git a/ext/TensorKitEnzymeExt/tensoroperations.jl b/ext/TensorKitEnzymeExt/tensoroperations.jl new file mode 100644 index 000000000..e7170b6c5 --- /dev/null +++ b/ext/TensorKitEnzymeExt/tensoroperations.jl @@ -0,0 +1,208 @@ +# tensorcontract! +# --------------- +# TODO: it might be beneficial to compare here if it would make sense to simply compute the +# rrule of permute-permute-gemm-permute, rather than using the contractions directly. +# This could possibly save some permutations from being carried out twice, at the cost of having +# to store some more intermediate objects. +# For example, the combination `ΔC, pΔC, false` appears in the pullback for ΔA and ΔB, so effectively +# this permutation is done multiple times. + +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(TensorKit.blas_contract!)}, + ::Type{RT}, + C::Annotation{<:AbstractTensorMap}, + A::Annotation{<:AbstractTensorMap}, + pA::Const{<:Index2Tuple}, + B::Annotation{<:AbstractTensorMap}, + pB::Const{<:Index2Tuple}, + pAB::Const{<:Index2Tuple}, + α::Annotation{<:Number}, + β::Annotation{<:Number}, + backend::Const, + allocator::Const + ) where {RT} + Ccache = isa(β, Const) ? nothing : copy(C.val) # snapshot of C taken before it is overwritten; only kept when β is not Const + A_needs_cache = EnzymeRules.overwritten(config)[3] && !(typeof(B) <: Const) && !(typeof(C) <: Const) # A is only copied when it may be overwritten and both B and C are non-Const + Acache = A_needs_cache ? copy(A.val) : nothing + B_needs_cache = EnzymeRules.overwritten(config)[5] && !(typeof(A) <: Const) && !(typeof(C) <: Const) # mirror of the A case, with the roles of A and B swapped + Bcache = B_needs_cache ?
copy(B.val) : nothing + AB = if !isa(α, Const) + AB = TO.tensorcontract(A.val, pA.val, false, B.val, pB.val, false, pAB.val, One(), backend.val, allocator.val) + add!(C.val, AB, α.val, β.val) + AB + else + TensorKit.blas_contract!(C.val, A.val, pA.val, B.val, pB.val, pAB.val, α.val, β.val, backend.val, allocator.val) + nothing + end + primal = EnzymeRules.needs_primal(config) ? C.val : nothing + shadow = EnzymeRules.needs_shadow(config) ? C.dval : nothing + cache = (Ccache, Acache, Bcache, AB) + return EnzymeRules.AugmentedReturn(primal, shadow, cache) +end + +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(TensorKit.blas_contract!)}, + ::Type{RT}, + cache, + C::Annotation{<:AbstractTensorMap}, + A::Annotation{<:AbstractTensorMap}, + pA::Const{<:Index2Tuple}, + B::Annotation{<:AbstractTensorMap}, + pB::Const{<:Index2Tuple}, + pAB::Const{<:Index2Tuple}, + α::Annotation{<:Number}, + β::Annotation{<:Number}, + backend::Const, + allocator::Const + ) where {RT} + cacheC, cacheA, cacheB, AB = cache + Cval = cacheC + Aval = something(cacheA, A.val) + Bval = something(cacheB, B.val) + + Δα = pullback_dα(α, C, AB) + Δβ = pullback_dβ(β, C, Cval) + + if !isa(A, Const) && !isa(C, Const) # a Const C carries no shadow to pull back from (same guard as the trace_permute! rule) + TensorKit.blas_contract_pullback_ΔA!( + A.dval, C.dval, Aval, pA.val, Bval, pB.val, pAB.val, α.val, backend.val, allocator.val + ) # updates A.dval in place + end + if !isa(B, Const) && !isa(C, Const) # a Const C carries no shadow to pull back from (same guard as the trace_permute! rule) + TensorKit.blas_contract_pullback_ΔB!( + B.dval, C.dval, Aval, pA.val, Bval, pB.val, pAB.val, α.val, backend.val, allocator.val + ) # updates B.dval in place + end + !isa(C, Const) && pullback_dC!(C.dval, β.val) # must run after the ΔA/ΔB pullbacks, which still read the incoming C.dval + return nothing, nothing, nothing, nothing, nothing, nothing, Δα, Δβ, nothing, nothing +end + +function EnzymeRules.forward( + config::EnzymeRules.FwdConfigWidth{1}, + func::Const{typeof(TensorKit.blas_contract!)}, + ::Type{RT}, + C::Annotation{<:AbstractTensorMap}, + A::Annotation{<:AbstractTensorMap}, +
pA::Annotation{<:Index2Tuple}, + B::Annotation{<:AbstractTensorMap}, + pB::Annotation{<:Index2Tuple}, + pAB::Annotation{<:Index2Tuple}, + α::Annotation{<:Number}, + β::Annotation{<:Number}, + backend::Const, + allocator::Const + ) where {RT} + # ΔC′ = ΔC*β + C*Δβ + A*B*Δα + ΔA*B*α + A*ΔB*α + if !isa(C, Const) + if isa(β, Const) + scale!(C.dval, β.val) + else + add!(C.dval, C.val, β.dval, β.val) + end + !isa(α, Const) && TensorKit.blas_contract!(C.dval, A.val, pA.val, B.val, pB.val, pAB.val, α.dval, One(), backend.val, allocator.val) + !isa(A, Const) && TensorKit.blas_contract!(C.dval, A.dval, pA.val, B.val, pB.val, pAB.val, α.val, One(), backend.val, allocator.val) + !isa(B, Const) && TensorKit.blas_contract!(C.dval, A.val, pA.val, B.dval, pB.val, pAB.val, α.val, One(), backend.val, allocator.val) + end + TensorKit.blas_contract!(C.val, A.val, pA.val, B.val, pB.val, pAB.val, α.val, β.val, backend.val, allocator.val) + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) + return C + elseif EnzymeRules.needs_primal(config) + return C.val + elseif EnzymeRules.needs_shadow(config) + return C.dval + else + return nothing + end +end + +# tensortrace! +# ------------ + +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(TensorKit.trace_permute!)}, + ::Type{RT}, + C::Annotation{<:AbstractTensorMap}, + A::Annotation{<:AbstractTensorMap}, + p::Const{<:Index2Tuple}, + q::Const{<:Index2Tuple}, + α::Annotation{<:Number}, + β::Annotation{<:Number}, + backend::Const, + ) where {RT} + C_cache = !isa(β, Const) ? copy(C.val) : nothing + A_cache = EnzymeRules.overwritten(config)[3] ? copy(A.val) : nothing + At = if !isa(α, Const) + At = TO.tensortrace(A.val, p.val, q.val, false, One(), backend.val) + add!(C.val, At, α.val, β.val) + At + else + TensorKit.trace_permute!(C.val, A.val, p.val, q.val, α.val, β.val, backend.val) + nothing + end + primal = EnzymeRules.needs_primal(config) ? 
C.val : nothing + shadow = EnzymeRules.needs_shadow(config) ? C.dval : nothing + cache = (C_cache, A_cache, At) + return EnzymeRules.AugmentedReturn(primal, shadow, cache) +end + + +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(TensorKit.trace_permute!)}, + ::Type{RT}, + cache, + C::Annotation{<:AbstractTensorMap}, + A::Annotation{<:AbstractTensorMap}, + p::Const{<:Index2Tuple}, + q::Const{<:Index2Tuple}, + α::Annotation{<:Number}, + β::Annotation{<:Number}, + backend::Const, + ) where {RT} + C_cache, A_cache, At = cache + Aval = something(A_cache, A.val) + Cval = something(C_cache, C.val) + !isa(A, Const) && !isa(C, Const) && TensorKit.trace_permute_pullback_ΔA!(A.dval, C.dval, Aval, p.val, q.val, α.val, backend.val) + Δαr = pullback_dα(α, C, At) + Δβr = pullback_dβ(β, C, Cval) + !isa(C, Const) && pullback_dC!(C.dval, β.val) + return nothing, nothing, nothing, nothing, Δαr, Δβr, nothing +end + +function EnzymeRules.forward( + config::EnzymeRules.FwdConfigWidth{1}, + func::Const{typeof(TensorKit.trace_permute!)}, + ::Type{RT}, + C::Annotation{<:AbstractTensorMap}, + A::Annotation{<:AbstractTensorMap}, + p::Annotation{<:Index2Tuple}, + q::Annotation{<:Index2Tuple}, + α::Annotation{<:Number}, + β::Annotation{<:Number}, + backend::Const, + ) where {RT} + # dD = dα * tr(A) + α * tr(dA) + dβ * C + β * dC + # dC1 = dβ * C + β * dC + if !isa(C, Const) + if isa(β, Const) + scale!(C.dval, β.val) + else + add!(C.dval, C.val, β.dval, β.val) + end + !isa(α, Const) && TensorKit.trace_permute!(C.dval, A.val, p.val, q.val, α.dval, One(), backend.val) + !isa(A, Const) && TensorKit.trace_permute!(C.dval, A.dval, p.val, q.val, α.val, One(), backend.val) + end + TensorKit.trace_permute!(C.val, A.val, p.val, q.val, α.val, β.val, backend.val) + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) + return C + elseif EnzymeRules.needs_primal(config) + return C.val + elseif EnzymeRules.needs_shadow(config) + return C.dval + 
else + return nothing + end +end diff --git a/ext/TensorKitMooncakeExt/tensoroperations.jl b/ext/TensorKitMooncakeExt/tensoroperations.jl index 0e0b1cdf3..342dedef3 100644 --- a/ext/TensorKitMooncakeExt/tensoroperations.jl +++ b/ext/TensorKitMooncakeExt/tensoroperations.jl @@ -48,19 +48,19 @@ function Mooncake.rrule!!( function blas_contract_pullback(::NoRData) copy!(C, C_cache) - ΔAr = blas_contract_pullback_ΔA!( + ΔAr = TensorKit.blas_contract_pullback_ΔA!( ΔA, ΔC, A, pA, B, pB, pAB, α, backend, allocator ) # this typically returns NoRData() - ΔBr = blas_contract_pullback_ΔB!( + ΔBr = TensorKit.blas_contract_pullback_ΔB!( ΔB, ΔC, A, pA, B, pB, pAB, α, backend, allocator ) # this typically returns NoRData() Δαr = pullback_dα(α, ΔC, AB) Δβr = pullback_dβ(β, ΔC, C) ΔCr = pullback_dC!(ΔC, β) # this typically returns NoRData() - return NoRData(), ΔCr, - ΔAr, NoRData(), - ΔBr, NoRData(), + return NoRData(), NoRData(), + NoRData(), NoRData(), + NoRData(), NoRData(), NoRData(), Δαr, Δβr, NoRData(), NoRData() @@ -99,56 +99,6 @@ function Mooncake.frule!!( return C_ΔC end -function blas_contract_pullback_ΔA!( - ΔA, ΔC, A, pA, B, pB, pAB, α, backend, allocator - ) - ipAB = invperm(linearize(pAB)) - pΔC = TO.repartition(ipAB, pA) - ipA = TO.repartition(invperm(linearize(pA)), numout(A)) - - tB = twist( - B, - TupleTools.vcat( - filter(x -> !isdual(space(B, x)), pB[1]), - filter(x -> isdual(space(B, x)), pB[2]) - ); copy = false - ) - - TK.project_contract!( - ΔA, - ΔC, pΔC, false, - tB, reverse(pB), true, - ipA, conj(α), backend, allocator - ) - - return NoRData() -end - -function blas_contract_pullback_ΔB!( - ΔB, ΔC, A, pA, B, pB, pAB, α, backend, allocator - ) - ipAB = invperm(linearize(pAB)) - pΔC = TO.repartition(ipAB, pA) - ipB = TO.repartition(invperm(linearize(pB)), numout(B)) - - tA = twist( - A, - TupleTools.vcat( - filter(x -> isdual(space(A, x)), pA[1]), - filter(x -> !isdual(space(A, x)), pA[2]) - ); copy = false - ) - - TK.project_contract!( - ΔB, - tA, 
reverse(pA), true, - ΔC, pΔC, false, - ipB, conj(α), backend, allocator - ) - - return NoRData() -end - # tensortrace! # ------------ @is_primitive( @@ -191,14 +141,14 @@ function Mooncake.rrule!!( function trace_permute_pullback(::NoRData) copy!(C, C_cache) - ΔAr = trace_permute_pullback_ΔA!(ΔA, ΔC, A, p, q, α, backend) # this typically returns NoRData() + ΔAr = TensorKit.trace_permute_pullback_ΔA!(ΔA, ΔC, A, p, q, α, backend) # this typically returns NoRData() Δαr = pullback_dα(α, ΔC, At) Δβr = pullback_dβ(β, ΔC, C) ΔCr = pullback_dC!(ΔC, β) # this typically returns NoRData() return NoRData(), - ΔCr, ΔAr, NoRData(), NoRData(), + NoRData(), NoRData(), NoRData(), NoRData(), Δαr, Δβr, NoRData() end @@ -236,21 +186,6 @@ function Mooncake.frule!!( return C_ΔC end -function trace_permute_pullback_ΔA!( - ΔA, ΔC, A, p, q, α, backend - ) - ip = invperm((linearize(p)..., q[1]..., q[2]...)) - pdA = TO.repartition(ip, numout(A)) - E = one!(TO.tensoralloc_add(scalartype(A), A, q, false)) - twist!(E, filter(x -> !isdual(space(E, x)), codomainind(E))) - pE = ((), TO.trivialpermutation(TO.numind(q))) - pΔC = (TO.trivialpermutation(TO.numind(p)), ()) - TO.tensorproduct!( - ΔA, ΔC, pΔC, false, E, pE, false, pdA, conj(α), One(), backend - ) - return NoRData() -end - @is_primitive( DefaultCtx, Tuple{ diff --git a/src/TensorKit.jl b/src/TensorKit.jl index 0622d6e6a..b40ee5163 100644 --- a/src/TensorKit.jl +++ b/src/TensorKit.jl @@ -284,5 +284,6 @@ include("planar/planaroperations.jl") # once all types have been declared # ------------------------ include("auxiliary/ad.jl") +include("pullbacks/tensoroperations.jl") end diff --git a/src/auxiliary/ad.jl b/src/auxiliary/ad.jl index 14879768d..158f0f8de 100644 --- a/src/auxiliary/ad.jl +++ b/src/auxiliary/ad.jl @@ -9,8 +9,8 @@ function project_mul!(C, A, B, α, β = One()) end end function project_contract!(C, A, pA, conjA, B, pB, conjB, pAB, α, backend, allocator) - TA = TensorKit.promote_permute(A) - TB = TensorKit.promote_permute(B) + TA 
= promote_permute(A) + TB = promote_permute(B) TC = TO.promote_contract(TA, TB, scalartype(α)) return if scalartype(C) <: Real && !(TC <: Real) diff --git a/src/pullbacks/tensoroperations.jl b/src/pullbacks/tensoroperations.jl new file mode 100644 index 000000000..64507861f --- /dev/null +++ b/src/pullbacks/tensoroperations.jl @@ -0,0 +1,63 @@ +function blas_contract_pullback_ΔA!( + ΔA, ΔC, A, pA, B, pB, pAB, α, backend, allocator + ) + ipAB = invperm(linearize(pAB)) + pΔC = TO.repartition(ipAB, TO.numout(pA)) + ipA = TO.repartition(invperm(linearize(pA)), numout(A)) + tB = twist( + B, + vcat( + [i for i in pB[1] if !isdual(space(B, i))], + [i for i in pB[2] if isdual(space(B, i))] + ); copy = false + ) + + project_contract!( + ΔA, + ΔC, pΔC, false, + tB, reverse(pB), true, + ipA, conj(α), backend, allocator + ) + + return nothing +end + +function blas_contract_pullback_ΔB!( + ΔB, ΔC, A, pA, B, pB, pAB, α, backend, allocator + ) + ipAB = invperm(linearize(pAB)) + pΔC = TO.repartition(ipAB, TO.numout(pA)) + ipB = TO.repartition(invperm(linearize(pB)), numout(B)) + + tA = twist( + A, + vcat( + [i for i in pA[1] if isdual(space(A, i))], + [i for i in pA[2] if !isdual(space(A, i))] + ); copy = false + ) + + project_contract!( + ΔB, + tA, reverse(pA), true, + ΔC, pΔC, false, + ipB, conj(α), backend, allocator + ) + + return nothing +end + +function trace_permute_pullback_ΔA!( + ΔA, ΔC, A, p, q, α, backend + ) + ip = invperm((linearize(p)..., q[1]..., q[2]...)) + pdA = TO.repartition(ip, numout(A)) + E = one!(TO.tensoralloc_add(scalartype(A), A, q, false)) + twist!(E, filter(x -> !isdual(space(E, x)), codomainind(E))) + pE = ((), TO.trivialpermutation(TO.numind(q))) + pΔC = (TO.trivialpermutation(TO.numind(p)), ()) + TO.tensorproduct!( + ΔA, ΔC, pΔC, false, E, pE, false, pdA, conj(α), One(), backend + ) + return nothing +end diff --git a/test/enzyme-tensoroperations/contract.jl b/test/enzyme-tensoroperations/contract.jl new file mode 100644 index 000000000..77947c194 --- 
/dev/null +++ b/test/enzyme-tensoroperations/contract.jl @@ -0,0 +1,139 @@ +using Test, TestExtras +using TensorKit +using TensorOperations +using VectorInterface: One, Zero +using Enzyme, EnzymeTestUtils + +is_ci = get(ENV, "CI", "false") == "true" + +spacelist = ad_spacelist(fast_tests) +eltypes = (Float64, ComplexF64) + +@timedtestset "Enzyme - TensorOperations" begin + @timedtestset verbose = true "$(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes + println(TensorKit.type_repr(sectortype(eltype(V)))) # just some printing for test purposes + atol = default_tol(T) + rtol = default_tol(T) + symmetricbraiding = BraidingStyle(sectortype(eltype(V))) isa SymmetricBraiding + symmetricbraiding && @timedtestset "tensorcontract!" begin + d = 0 + local V1, V2, V3 + # retry a couple times to make sure there are at least some nonzero elements + for _ in 1:10 + k1 = rand(0:3) + k2 = rand(0:2) + k3 = rand(0:2) + V1 = prod(v -> rand(Bool) ? v' : v, rand(V, k1); init = one(V[1])) + V2 = prod(v -> rand(Bool) ? v' : v, rand(V, k2); init = one(V[1])) + V3 = prod(v -> rand(Bool) ? v' : v, rand(V, k3); init = one(V[1])) + d = min(dim(V1 ← V2), dim(V1' ← V2), dim(V2 ← V3), dim(V2' ← V3)) + d > 0 && break + end + ipA = randindextuple(length(V1) + length(V2)) + pA = _repartition(invperm(linearize(ipA)), length(V1)) + ipB = randindextuple(length(V2) + length(V3)) + pB = _repartition(invperm(linearize(ipB)), length(V2)) + pAB = randindextuple(length(V1) + length(V3)) + + α = randn(T) + β = randn(T) + V2_conj = prod(conj, V2; init = one(V[1])) + A = randn(T, permute(V1 ← V2, ipA)) + B = randn(T, permute(V2 ← V3, ipB)) + C = randn!( + TensorOperations.tensoralloc_contract( + T, A, pA, false, B, pB, false, pAB, Val(false) + ) + ) + + αβs = is_ci ? 
(((α, Active), (β, Active)),) : Iterators.product(((One(), Const), (α, Const), (α, Active)), ((Zero(), Const), (β, Const), (β, Active))) + for (α_, β_) in αβs + EnzymeTestUtils.test_reverse( + TensorKit.blas_contract!, Duplicated, + (copy(C), Duplicated), (A, Duplicated), (pA, Const), + (B, Duplicated), (pB, Const), (pAB, Const), + α_, β_, + (TensorOperations.DefaultBackend(), Const), + (TensorOperations.DefaultAllocator(), Const); + atol, rtol, + testset_name = "blas_contract! reverse α $α_ β $β_", + ) + end + αβs = is_ci ? (((α, Duplicated), (β, Duplicated)),) : Iterators.product(((One(), Const), (α, Const), (α, Duplicated)), ((Zero(), Const), (β, Const), (β, Duplicated))) + for (α_, β_) in αβs + EnzymeTestUtils.test_forward( + TensorKit.blas_contract!, Duplicated, + (copy(C), Duplicated), (A, Duplicated), (pA, Const), + (B, Duplicated), (pB, Const), (pAB, Const), + α_, β_, + (TensorOperations.DefaultBackend(), Const), + (TensorOperations.DefaultAllocator(), Const); + atol, rtol, + testset_name = "blas_contract! forward α $α_ β $β_", + ) + end + if !(T <: Real) && !is_ci + EnzymeTestUtils.test_reverse( + TensorKit.blas_contract!, Duplicated, + (copy(C), Duplicated), (A, Duplicated), (pA, Const), + (B, Duplicated), (pB, Const), (pAB, Const), + (real(α), Active), (real(β), Active), + (TensorOperations.DefaultBackend(), Const), + (TensorOperations.DefaultAllocator(), Const); + atol, rtol, + testset_name = "blas_contract! reverse real(α) real(β)", + ) + EnzymeTestUtils.test_reverse( + TensorKit.blas_contract!, Duplicated, + (copy(C), Duplicated), (real(A), Duplicated), (pA, Const), + (B, Duplicated), (pB, Const), (pAB, Const), + (real(α), Active), (real(β), Active), + (TensorOperations.DefaultBackend(), Const), + (TensorOperations.DefaultAllocator(), Const); + atol, rtol, + testset_name = "blas_contract! 
reverse real(A) real(α) real(β)", + ) + EnzymeTestUtils.test_reverse( + TensorKit.blas_contract!, Duplicated, + (copy(C), Duplicated), (A, Duplicated), (pA, Const), + (real(B), Duplicated), (pB, Const), (pAB, Const), + (real(α), Active), (real(β), Active), + (TensorOperations.DefaultBackend(), Const), + (TensorOperations.DefaultAllocator(), Const); + atol, rtol, + testset_name = "blas_contract! reverse real(B) real(α) real(β)", + ) + EnzymeTestUtils.test_forward( + TensorKit.blas_contract!, Duplicated, + (copy(C), Duplicated), (A, Duplicated), (pA, Const), + (B, Duplicated), (pB, Const), (pAB, Const), + (real(α), Duplicated), (real(β), Duplicated), # forward mode has no Active activity; scalars carry their tangent via Duplicated + (TensorOperations.DefaultBackend(), Const), + (TensorOperations.DefaultAllocator(), Const); + atol, rtol, + testset_name = "blas_contract! forward real(α) real(β)", + ) + EnzymeTestUtils.test_forward( + TensorKit.blas_contract!, Duplicated, + (copy(C), Duplicated), (real(A), Duplicated), (pA, Const), + (B, Duplicated), (pB, Const), (pAB, Const), + (real(α), Duplicated), (real(β), Duplicated), # forward mode has no Active activity; scalars carry their tangent via Duplicated + (TensorOperations.DefaultBackend(), Const), + (TensorOperations.DefaultAllocator(), Const); + atol, rtol, + testset_name = "blas_contract! forward real(A) real(α) real(β)", + ) + EnzymeTestUtils.test_forward( + TensorKit.blas_contract!, Duplicated, + (copy(C), Duplicated), (A, Duplicated), (pA, Const), + (real(B), Duplicated), (pB, Const), (pAB, Const), + (real(α), Duplicated), (real(β), Duplicated), # forward mode has no Active activity; scalars carry their tangent via Duplicated + (TensorOperations.DefaultBackend(), Const), + (TensorOperations.DefaultAllocator(), Const); + atol, rtol, + testset_name = "blas_contract!
forward real(B) real(α) real(β)", + ) + end + end + end +end diff --git a/test/enzyme-tensoroperations/trace.jl b/test/enzyme-tensoroperations/trace.jl new file mode 100644 index 000000000..3ff8fad43 --- /dev/null +++ b/test/enzyme-tensoroperations/trace.jl @@ -0,0 +1,61 @@ +using Test, TestExtras +using TensorKit +using TensorOperations +using VectorInterface: One, Zero +using Enzyme, EnzymeTestUtils + +spacelist = ad_spacelist(fast_tests) +eltypes = (Float64, ComplexF64) + +is_ci = get(ENV, "CI", "false") == "true" +rTαs = is_ci ? (Active,) : (Const, Active) +rTβs = is_ci ? (Active,) : (Const, Active) +fTαs = is_ci ? (Duplicated,) : (Const, Duplicated) +fTβs = is_ci ? (Duplicated,) : (Const, Duplicated) +TCs = is_ci ? (Duplicated,) : (Const, Duplicated) +TAs = is_ci ? (Duplicated,) : (Const, Duplicated) + +@timedtestset "Enzyme - TensorOperations (trace)" begin + @timedtestset verbose = true "$(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes + println(TensorKit.type_repr(sectortype(eltype(V)))) # just some printing for test purposes + atol = default_tol(T) + rtol = default_tol(T) + symmetricbraiding = BraidingStyle(sectortype(eltype(V))) isa SymmetricBraiding + symmetricbraiding && @timedtestset "trace_permute!" begin + k1 = rand(0:2) + k2 = rand(1:2) + V1 = map(v -> rand(Bool) ? v' : v, rand(V, k1)) + V2 = map(v -> rand(Bool) ? 
v' : v, rand(V, k2)) + + (_p, _q) = randindextuple(k1 + 2 * k2, k1) + p = _repartition(_p, rand(0:k1)) + q = _repartition(_q, k2) + ip = _repartition(invperm(linearize((_p, _q))), rand(0:(k1 + 2 * k2))) + A = randn(T, permute(prod(V1) ⊗ prod(V2) ← prod(V2), ip)) + + α = randn(T) + β = randn(T) + C = randn!(TensorOperations.tensoralloc_add(T, A, p, false, Val(false))) + for TC in TCs, TA in TAs + for Tα in rTαs, Tβ in rTβs + EnzymeTestUtils.test_reverse( + TensorKit.trace_permute!, TC, + (copy(C), TC), (A, TA), (p, Const), (q, Const), + (α, Tα), (β, Tβ), (TensorOperations.DefaultBackend(), Const); + atol, rtol, + testset_name = "trace_permute! reverse TC $TC TA $TA Tα $Tα Tβ $Tβ", + ) + end + for Tα in fTαs, Tβ in fTβs + EnzymeTestUtils.test_forward( + TensorKit.trace_permute!, TC, + (copy(C), TC), (A, TA), (p, Const), (q, Const), + (α, Tα), (β, Tβ), (TensorOperations.DefaultBackend(), Const); + atol, rtol, + testset_name = "trace_permute! forward TC $TC TA $TA Tα $Tα Tβ $Tβ", + ) + end + end + end + end +end