From bbf49a7c532fa9fc3caa6cbc6503303c094edc71 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jinguo=20Liu=20=28=E5=88=98=E9=87=91=E5=9B=BD=29?= Date: Fri, 5 Apr 2024 00:43:17 +0800 Subject: [PATCH] Visualization (#1) * new visualization * update * update CI * fix test * fix test --- .github/workflows/CI.yml | 38 +++++++++ .github/workflows/CompatHelper.yml | 16 ++++ Project.toml | 5 +- README.md | 3 +- src/SimpleTDVP.jl | 4 + src/visualize.jl | 125 +++++++++++++++++++++++++++++ test/visualize.jl | 9 +++ 7 files changed, 198 insertions(+), 2 deletions(-) create mode 100644 .github/workflows/CI.yml create mode 100644 .github/workflows/CompatHelper.yml create mode 100644 src/visualize.jl create mode 100644 test/visualize.jl diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml new file mode 100644 index 0000000..d84626b --- /dev/null +++ b/.github/workflows/CI.yml @@ -0,0 +1,38 @@ +name: CI +on: + push: + branches: + - main + tags: '*' + pull_request: +concurrency: + # Skip intermediate builds: always. + # Cancel intermediate builds: only if it is a pull request build. + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} +jobs: + test: + name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + version: + - '1' + os: + - ubuntu-latest + arch: + - x64 + steps: + - uses: actions/checkout@v3 + - uses: julia-actions/setup-julia@v1 + with: + version: ${{ matrix.version }} + arch: ${{ matrix.arch }} + - uses: julia-actions/cache@v1 + - uses: julia-actions/julia-buildpkg@v1 + - uses: julia-actions/julia-runtest@v1 + - uses: julia-actions/julia-processcoverage@v1 + - uses: codecov/codecov-action@v3 + with: + files: lcov.info diff --git a/.github/workflows/CompatHelper.yml b/.github/workflows/CompatHelper.yml new file mode 100644 index 0000000..cba9134 --- /dev/null +++ b/.github/workflows/CompatHelper.yml @@ -0,0 +1,16 @@ +name: CompatHelper +on: + schedule: + - cron: 0 0 * * * + workflow_dispatch: +jobs: + CompatHelper: + runs-on: ubuntu-latest + steps: + - name: Pkg.add("CompatHelper") + run: julia -e 'using Pkg; Pkg.add("CompatHelper")' + - name: CompatHelper.main() + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + COMPATHELPER_PRIV: ${{ secrets.DOCUMENTER_KEY }} + run: julia -e 'using CompatHelper; CompatHelper.main()' diff --git a/Project.toml b/Project.toml index 6f20da9..f83a931 100644 --- a/Project.toml +++ b/Project.toml @@ -7,17 +7,20 @@ version = "1.0.0-DEV" KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LinearOperators = "5c8ed15e-5a4c-59e4-a42b-c7e8811fb125" +LuxorGraphPlot = "1f49bdf2-22a7-4bc4-978b-948dc219fbbc" OMEinsum = "ebe7aa44-baf0-506c-a96f-8464559b3922" [compat] KrylovKit = "0.7" LinearAlgebra = "1" LinearOperators = "2" +LuxorGraphPlot = "0.3" OMEinsum = "0.8" julia = "1" [extras] Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" [targets] -test = ["Test"] +test = ["Test", "Random"] diff --git a/README.md b/README.md index 3203519..4d34509 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,4 @@ # SimpleTDVP -[![Build Status](https://github.com/GiggleLiu/SimpleTDVP.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/GiggleLiu/SimpleTDVP.jl/actions/workflows/CI.yml?query=branch%3Amain) +[![Build Status](https://github.com/CodingThrust/SimpleTDVP.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/CodingThrust/SimpleTDVP.jl/actions/workflows/CI.yml?query=branch%3Amain) +[![Coverage](https://codecov.io/gh/CodingThrust/SimpleTDVP.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/CodingThrust/SimpleTDVP.jl) diff --git a/src/SimpleTDVP.jl b/src/SimpleTDVP.jl index df93428..511f7f5 100644 --- a/src/SimpleTDVP.jl +++ b/src/SimpleTDVP.jl @@ -1,12 +1,15 @@ module SimpleTDVP using LinearAlgebra, OMEinsum, LinearOperators, KrylovKit +using LuxorGraphPlot: Node, Connection, Luxor +import LuxorGraphPlot export MPS, nsite, nflavor, vec2mps, IndexStore, newindex!, code_mps2vec, rand_mps, vec export left_move!, right_move!, canonical_move!, is_canonicalized, canonical_center, to_left_canonical!, to_right_canonical! export mat, MPO, code_mpo2mat, mat2mpo, rand_mpo export dot, sandwich, compress!, num_of_elements export dmrg!, dmrg +export TensorLayout include("utils.jl") include("tensornetwork.jl") @@ -14,5 +17,6 @@ include("mps.jl") include("mpo.jl") include("mpsormpo.jl") include("dmrg.jl") +include("visualize.jl") end diff --git a/src/visualize.jl b/src/visualize.jl new file mode 100644 index 0000000..9415523 --- /dev/null +++ b/src/visualize.jl @@ -0,0 +1,125 @@ +struct TensorLayout{LT} + nodes::Vector{Pair{Vector{LT}, Node}} # labels => nodes + mask::Vector{Bool} + ignored_labels::Vector{LT} +end +function TensorLayout(ixsv::AbstractVector, locs; ignored_labels=Int[], kwargs...) + @assert length(ixsv) == length(locs) + tvg = TensorLayout{Int}(Pair{Vector{Int}, Node}[], Bool[], ignored_labels) + for (ix, loc) in zip(ixsv, locs) + addnode!(tvg, loc..., ix; kwargs...) + end + return tvg +end +get_width(tvg::TensorLayout) = maximum([node.loc[1] for (_, node) in tvg.nodes]) + 50 +get_height(tvg::TensorLayout) = maximum([node.loc[2] for (_, node) in tvg.nodes]) + 50 +function addnode!(tvg::TensorLayout, x::Real, y::Real, ix::AbstractVector; kwargs...) + loc = (x, y) + node = if length(ix) == 0 + Node(:dot, loc; kwargs...) + elseif length(ix) == 1 + Node(:circle, loc; radius=7, kwargs...) + elseif length(ix) == 2 + Node(:polygon, loc; relpath=[[-1, 0], [0, -1], [1, 0], [0, 1]], kwargs...) + elseif length(ix) == 3 + Node(:circle, loc; radius=12, kwargs...) + elseif length(ix) == 4 + Node(:box, loc; width=20, height=20, kwargs...) + else + Node(:circle, loc; radius=12, kwargs...) + end + push!(tvg.nodes, ix => node) + push!(tvg.mask, true) + tvg +end + +function draw_einsum(gl::TensorLayout) + Luxor.@layer begin + for (_, node) in gl.nodes[gl.mask] + Luxor.sethue("white") + LuxorGraphPlot.fill(node) + Luxor.sethue("black") + LuxorGraphPlot.stroke(node) + end + Luxor.sethue("black") + for i in 1:length(gl.nodes), j in i+1:length(gl.nodes) + ix1, node1 = gl.nodes[i] + ix2, node2 = gl.nodes[j] + if (gl.mask[i] || gl.mask[j]) && !isempty(setdiff(ix1 ∩ ix2, gl.ignored_labels)) + LuxorGraphPlot.stroke(Connection(node1, node2)) + end + end + end +end + +function default_draw(f, gl::TensorLayout) + return Luxor.@drawsvg begin + Luxor.origin(0, 0) + Luxor.background("white") + draw_einsum(gl) + f(gl) + Luxor.finish() + end get_width(gl) get_height(gl) +end +default_draw(gl::TensorLayout) = default_draw(gl->nothing, gl) +function default_draw(mps::MPS) + n = nsite(mps) + code = code_mps2vec(mps) + ixs = OMEinsum.getixsv(OMEinsum.flatten(code)) + phantoms = [[ix[2]] for ix in ixs] + layout = TensorLayout([ixs..., phantoms...], + [[(i*50, 90) for i in 1:n]... # real tensors + [(i*50, 50) for i in 1:n]...]; # phantoms + ignored_labels=[ixs[1][1]] # ignore the long-ancilla label + ) + layout.mask[n+1:end] .= false + return default_draw(layout) do layout + Luxor.fontsize(14) + for i = 1:n + if i < n + node = LuxorGraphPlot.midpoint(layout.nodes[i].second, layout.nodes[i+1].second) + Luxor.text("$(size(mps.tensors[i], 3))", node.loc + Luxor.Point(0, -10); valign=:middle, halign=:center) + end + if canonical_center(mps) == i + Luxor.sethue("gray") + LuxorGraphPlot.fill(layout.nodes[i].second) + end + end + end +end + +function default_draw(mpo::MPO) + n = nsite(mpo) + code = code_mpo2mat(mpo) + ixs = OMEinsum.getixsv(OMEinsum.flatten(code)) + phantoms = [[[ix[2]] for ix in ixs]..., [[ix[3]] for ix in ixs]...] + layout = TensorLayout([ixs..., phantoms...], + [[(i*50, 90) for i in 1:n]... + [(i*50, 50) for i in 1:n]... + [(i*50, 130) for i in 1:n]...]; + ignored_labels = [ixs[1][1]] # ignore the long-ancilla label + ) + layout.mask[n+1:end] .= false + return default_draw(layout) do layout + Luxor.fontsize(14) + for i = 1:n + if i < n + node = LuxorGraphPlot.midpoint(layout.nodes[i].second, layout.nodes[i+1].second) + Luxor.text("$(size(mpo.tensors[i], 4))", node.loc + Luxor.Point(0, -10); valign=:middle, halign=:center) + end + if canonical_center(mpo) == i + Luxor.sethue("gray") + LuxorGraphPlot.fill(layout.nodes[i+n].second) + end + end + end +end + +for FORMAT in [MIME"text/html", MIME"image/svg+xml"] + for OBJ in [MPS, MPO, TensorLayout] + @eval function Base.show(io::IO, ::$FORMAT, obj::$OBJ) + d = default_draw(obj) + show(io, $FORMAT(), d) + end + end +end \ No newline at end of file diff --git a/test/visualize.jl b/test/visualize.jl new file mode 100644 index 0000000..89ff917 --- /dev/null +++ b/test/visualize.jl @@ -0,0 +1,9 @@ +using SimpleTDVP, Test +using LuxorGraphPlot.Luxor + +@testset "visualize" begin + mps = SimpleTDVP.rand_mps([1,2,4,6,6,4,2,1]) + @test SimpleTDVP.default_draw(mps) isa Luxor.Drawing + mpo = SimpleTDVP.rand_mpo([1,4,6,4,1]) + @test SimpleTDVP.default_draw(mpo) isa Luxor.Drawing +end \ No newline at end of file