Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ uuid = "07d1fe3e-3e46-537d-9eac-e9e13d0d4cec"
version = "0.17.0"
authors = ["Jutho Haegeman, Lukas Devos"]

[workspace]
projects = ["test", "docs"]

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4"
Expand Down Expand Up @@ -39,12 +42,9 @@ TensorKitFiniteDifferencesExt = "FiniteDifferences"
TensorKitGPUArraysExt = "GPUArrays"
TensorKitMooncakeExt = "Mooncake"

[workspace]
projects = ["test", "docs"]

[compat]
Adapt = "4"
AMDGPU = "2"
Adapt = "4"
CUDA = "6"
ChainRulesCore = "1"
Dictionaries = "0.4"
Expand All @@ -62,7 +62,7 @@ Random = "1"
ScopedValues = "1.3.0"
Strided = "2.6.1"
TensorKitSectors = "0.3.7"
TensorOperations = "5.5.2"
TensorOperations = "5.5.2, 5.6"
TupleTools = "1.5"
VectorInterface = "0.6"
julia = "1.10"
7 changes: 6 additions & 1 deletion test/enzyme-linalg/norm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,12 @@ fRTs = is_ci ? (Duplicated,) : (Const, Duplicated)
@timedtestset "$(TensorKit.type_repr(sectortype(eltype(V)))) ($T), TC $TC" for V in spacelist, T in eltypes, TC in (Const, Duplicated)
atol = default_tol(T)
rtol = default_tol(T)
C = randn(T, V[1] ⊗ V[2] ← (V[3] ⊗ V[4] ⊗ V[5])')
# see https://github.com/QuantumKitHub/TensorKit.jl/issues/457
@static if VERSION < v"1.11.0-rc"
C = randn(T, V[1] ⊗ V[2] ← (V[4] ⊗ V[5])')
else
C = randn(T, V[1] ⊗ V[2] ← (V[3] ⊗ V[4] ⊗ V[5])')
end
for RT in rRTs
EnzymeTestUtils.test_reverse(norm, RT, (C, TC), (2, Const); atol, rtol)
EnzymeTestUtils.test_reverse(norm, RT, (C', TC), (2, Const); atol, rtol)
Expand Down
51 changes: 51 additions & 0 deletions test/enzyme-vectorinterface/add.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
using Test, TestExtras
using TensorKit, Enzyme, EnzymeTestUtils
using TensorOperations
using Random

#spacelist = ad_spacelist(fast_tests)
spacelist = [ad_spacelist(fast_tests)[1]]
eltypes = (Float64, ComplexF64)

@testset "Enzyme - VectorInterface (add!) $(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes
atol = default_tol(T)
rtol = default_tol(T)

α = randn(T)
β = randn(T)

# see https://github.com/QuantumKitHub/TensorKit.jl/issues/457
if VERSION < v"1.11.0-rc" && sectortype(eltype(V)) == Trivial
CV = V[1] ⊗ V[2] ← V[4] ⊗ V[5]
else
CV = V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]
end
C = randn(T, CV)
A = randn(T, CV)
for TC in (Duplicated,), TA in (Duplicated,)
C = randn(T, CV)
A = randn(T, CV)
EnzymeTestUtils.test_reverse(add!, TC, (C, TC), (A, TA); atol, rtol, testset_name = "add! reverse TC $TC TA $TA no α no β")
EnzymeTestUtils.test_forward(add!, TC, (C, TC), (A, TA); atol, rtol, testset_name = "add! forward TC $TC TA $TA no α no β")
for Tα in (Active, Const)
C = randn(T, CV)
A = randn(T, CV)
EnzymeTestUtils.test_reverse(add!, TC, (C, TC), (A, TA), (α, Tα); atol, rtol, testset_name = "add! reverse TC $TC TA $TA Tα $Tα no β")
for Tβ in (Active, Const)
C = randn(T, CV)
A = randn(T, CV)
EnzymeTestUtils.test_reverse(add!, TC, (C, TC), (A, TA), (α, Tα), (β, Tβ); atol, rtol, testset_name = "add! reverse TC $TC TA $TA Tα $Tα Tβ $Tβ")
end
end
for Tα in (Duplicated, Const)
C = randn(T, CV)
A = randn(T, CV)
EnzymeTestUtils.test_forward(add!, TC, (C, TC), (A, TA), (α, Tα); atol, rtol, testset_name = "add! forward TC $TC TA $TA Tα $Tα no β")
for Tβ in (Duplicated, Const)
C = randn(T, CV)
A = randn(T, CV)
EnzymeTestUtils.test_forward(add!, TC, (C, TC), (A, TA), (α, Tα), (β, Tβ); atol, rtol, testset_name = "add! forward TC $TC TA $TA Tα $Tα Tβ $Tβ")
end
end
end
end
31 changes: 31 additions & 0 deletions test/enzyme-vectorinterface/inner.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
using Test, TestExtras
using TensorKit
using TensorOperations
using Enzyme, EnzymeTestUtils
using Random, FiniteDifferences

spacelist = ad_spacelist(fast_tests)
eltypes = (Float64, ComplexF64)

@testset "Enzyme - VectorInterface" begin
@timedtestset "$(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes
@testset for TC in (Duplicated,), TA in (Duplicated,), f in (identity, adjoint)
atol = default_tol(T)
rtol = default_tol(T)
# see https://github.com/QuantumKitHub/TensorKit.jl/issues/457
if VERSION < v"1.11.0-rc" && sectortype(eltype(V)) == Trivial
CV = V[1] ⊗ V[2] ← V[4] ⊗ V[5]
else
CV = V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]
end
C = randn(T, CV)
A = randn(T, CV)
for RT in (Active, Const)
EnzymeTestUtils.test_reverse(inner, RT, (f(C), TC), (f(A), TA); atol, rtol)
end
for RT in (Duplicated, Const)
EnzymeTestUtils.test_forward(inner, RT, (f(C), TC), (f(A), TA); atol, rtol)
end
end
end
end
46 changes: 46 additions & 0 deletions test/enzyme-vectorinterface/scale.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
using Test, TestExtras
using TensorKit
using TensorOperations
using Enzyme, EnzymeTestUtils
using Random

spacelist = ad_spacelist(fast_tests)
eltypes = (Float64, ComplexF64)

@testset "Enzyme - VectorInterface (scale!)" begin
@timedtestset "$(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes
atol = default_tol(T)
rtol = default_tol(T)
α = randn(T)
# see https://github.com/QuantumKitHub/TensorKit.jl/issues/457
if VERSION < v"1.11.0-rc" && sectortype(eltype(V)) == Trivial
CV = V[1] ⊗ V[2] ← V[4] ⊗ V[5]
else
CV = V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]
end
@testset for TC in (Duplicated,)
for Tα in (Active, Const)
C = randn(T, CV)
EnzymeTestUtils.test_reverse(scale!, TC, (C, TC), (α, Tα); atol, rtol)
C = randn(T, CV)
EnzymeTestUtils.test_reverse(scale!, TC, (C', TC), (α, Tα); atol, rtol)
@testset for TA in (Duplicated,), (fc, fa) in ((identity, identity), (adjoint, adjoint))
C = randn(T, CV)
A = randn(T, CV)
EnzymeTestUtils.test_reverse(scale!, TC, (fc(C), TC), (fa(A), TA), (α, Tα); atol, rtol)
end
end
for Tα in (Duplicated, Const)
C = randn(T, CV)
EnzymeTestUtils.test_forward(scale!, TC, (C, TC), (α, Tα); atol, rtol)
C = randn(T, CV)
EnzymeTestUtils.test_forward(scale!, TC, (C', TC), (α, Tα); atol, rtol)
@testset for TA in (Duplicated,), (fc, fa) in ((identity, identity), (adjoint, adjoint))
C = randn(T, CV)
A = randn(T, CV)
EnzymeTestUtils.test_forward(scale!, TC, (fc(C), TC), (fa(A), TA), (α, Tα); atol, rtol)
end
end
end
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ end
if (Sys.isapple() && get(ENV, "CI", "false") == "true") || !isempty(VERSION.prerelease)
filter!(!startswith("chainrules") ∘ first, testsuite)
filter!(!startswith("mooncake") ∘ first, testsuite)
filter!(!startswith("enzyme") ∘ first, testsuite)
end

args = parse_args(ARGS; custom = ["fast"])
Expand Down
Loading