diff --git a/src/indexnotation/binarytree.jl b/src/indexnotation/binarytree.jl new file mode 100644 index 00000000..9e6d2526 --- /dev/null +++ b/src/indexnotation/binarytree.jl @@ -0,0 +1,25 @@ +struct BinaryTreeNode + left + right +end + +Base.show(io::IO, blk::BinaryTreeNode) = show(io, "plain/text", blk) +function Base.show(io::IO, ::MIME"plain/text", blk::BinaryTreeNode) + print(io, "Contraction Tree: ") + print_tree(io, blk) +end + +function print_tree(io::IO, blk::BinaryTreeNode, print_level=0) + print(io, "(") + print_tree(io, blk.left, print_level+1) + print(io, " ↔ ") + print_tree(io, blk.right, print_level+1) + print(io, ")") +end + +function print_tree(io::IO, blk, print_level=0) + print(io, blk) +end + +Base.getindex(t::BinaryTreeNode, i::Int) = i==1 ? t.left : (i==2 ? t.right : throw(BoundsError(t, i))) +Base.iterate(t::BinaryTreeNode, args...) = iterate((t.left, t.right), args...) diff --git a/src/indexnotation/optimaltree.jl b/src/indexnotation/optimaltree.jl index 5992d400..82a20aa3 100644 --- a/src/indexnotation/optimaltree.jl +++ b/src/indexnotation/optimaltree.jl @@ -1,3 +1,5 @@ +include("binarytree.jl") + function optimaltree(network, optdata::Dict; verbose::Bool = false) numtensors = length(network) allindices = unique(vcat(network...)) @@ -165,7 +167,7 @@ function _optimaltree(::Type{T}, network, allindices, allcosts::Vector{S}, initi if cost <= get(costdict[n], s, currentcost) costdict[n][s] = cost indexdict[n][s] = _setdiff(_union(ind1,ind2), cind) - treedict[n][s] = (treedict[k][s1], treedict[n-k][s2]) + treedict[n][s] = BinaryTreeNode(treedict[k][s1], treedict[n-k][s2]) elseif currentcost < cost < nextcost nextcost = cost end @@ -192,7 +194,7 @@ function _optimaltree(::Type{T}, network, allindices, allcosts::Vector{S}, initi if cost <= get(costdict[n], s, currentcost) costdict[n][s] = cost indexdict[n][s] = _setdiff(_union(ind1,ind2), cind) - treedict[n][s] = (treedict[k][s1], treedict[k][s2]) + treedict[n][s] = BinaryTreeNode(treedict[k][s1], treedict[k][s2]) elseif currentcost < cost < nextcost nextcost = cost end @@ -226,7 +228,7 @@ function _optimaltree(::Type{T}, network, allindices, allcosts::Vector{S}, initi cost = costlist[p[1]] ind = indexlist[p[1]] for c = 2:numcomponent - tree = (tree, treelist[p[c]]) + tree = BinaryTreeNode(tree, treelist[p[c]]) cost = addcost(cost, costlist[p[c]], computecost(allcosts, ind, indexlist[p[c]])) ind = _union(ind, indexlist[p[c]]) end diff --git a/test/tensoropt.jl b/test/tensoropt.jl index 2a4e5833..6cadd0e1 100644 --- a/test/tensoropt.jl +++ b/test/tensoropt.jl @@ -1,3 +1,6 @@ +using TensorOperations, Test +using TensorOperations: BinaryTreeNode + @testset "Optimal contraction order" begin _,cost = @optimalcontractiontree A[-1,1,2,3]*B[2,4,5,6]*C[1,5,7,-3]*D[3,8,4,9]* E[6,9,7,10]*F[-2,8,11,12]*G[10,11,12,-4] @@ -22,3 +25,10 @@ aa[89,90,45,46,51,52,57,58,59,-1,64,65,66,67,68,69,70,71] @test cost == TensorOperations.Poly{:χ,Int}([0,0,0,0,0,0,0,4,4,0,0,0,1,1,1,0,1,0,0,0,3,0,3,2,0,2,4]) end + +@testset "BinaryTree" begin + t = BinaryTreeNode(BinaryTreeNode(3, 2), 4) + l, r = t + @test l == BinaryTreeNode(3, 2) == t[1] + @test r == 4 == t[2] +end