Skip to content

Commit

Permalink
Visualization (#1)
Browse files Browse the repository at this point in the history
* new visualization

* update

* update CI

* fix test

* fix test
  • Loading branch information
GiggleLiu authored Apr 4, 2024
1 parent 7716448 commit bbf49a7
Show file tree
Hide file tree
Showing 7 changed files with 198 additions and 2 deletions.
38 changes: 38 additions & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
name: CI
on:
push:
branches:
- main
tags: '*'
pull_request:
concurrency:
# Skip intermediate builds: always.
# Cancel intermediate builds: only if it is a pull request build.
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }}
jobs:
test:
name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }}
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
version:
- '1'
os:
- ubuntu-latest
arch:
- x64
steps:
- uses: actions/checkout@v3
- uses: julia-actions/setup-julia@v1
with:
version: ${{ matrix.version }}
arch: ${{ matrix.arch }}
- uses: julia-actions/cache@v1
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-runtest@v1
- uses: julia-actions/julia-processcoverage@v1
- uses: codecov/codecov-action@v3
with:
files: lcov.info
16 changes: 16 additions & 0 deletions .github/workflows/CompatHelper.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
name: CompatHelper
on:
schedule:
- cron: 0 0 * * *
workflow_dispatch:
jobs:
CompatHelper:
runs-on: ubuntu-latest
steps:
- name: Pkg.add("CompatHelper")
run: julia -e 'using Pkg; Pkg.add("CompatHelper")'
- name: CompatHelper.main()
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
COMPATHELPER_PRIV: ${{ secrets.DOCUMENTER_KEY }}
run: julia -e 'using CompatHelper; CompatHelper.main()'
5 changes: 4 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,20 @@ version = "1.0.0-DEV"
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearOperators = "5c8ed15e-5a4c-59e4-a42b-c7e8811fb125"
LuxorGraphPlot = "1f49bdf2-22a7-4bc4-978b-948dc219fbbc"
OMEinsum = "ebe7aa44-baf0-506c-a96f-8464559b3922"

[compat]
KrylovKit = "0.7"
LinearAlgebra = "1"
LinearOperators = "2"
LuxorGraphPlot = "0.3"
OMEinsum = "0.8"
julia = "1"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[targets]
test = ["Test"]
test = ["Test", "Random"]
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# SimpleTDVP

[![Build Status](https://github.com/GiggleLiu/SimpleTDVP.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/GiggleLiu/SimpleTDVP.jl/actions/workflows/CI.yml?query=branch%3Amain)
[![Build Status](https://github.com/CodingThrust/SimpleTDVP.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/CodingThrust/SimpleTDVP.jl/actions/workflows/CI.yml?query=branch%3Amain)
[![Coverage](https://codecov.io/gh/CodingThrust/SimpleTDVP.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/CodingThrust/SimpleTDVP.jl)
4 changes: 4 additions & 0 deletions src/SimpleTDVP.jl
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
module SimpleTDVP

using LinearAlgebra, OMEinsum, LinearOperators, KrylovKit
using LuxorGraphPlot: Node, Connection, Luxor
import LuxorGraphPlot

export MPS, nsite, nflavor, vec2mps, IndexStore, newindex!, code_mps2vec, rand_mps, vec
export left_move!, right_move!, canonical_move!, is_canonicalized, canonical_center, to_left_canonical!, to_right_canonical!
export mat, MPO, code_mpo2mat, mat2mpo, rand_mpo
export dot, sandwich, compress!, num_of_elements
export dmrg!, dmrg
export TensorLayout

include("utils.jl")
include("tensornetwork.jl")
include("mps.jl")
include("mpo.jl")
include("mpsormpo.jl")
include("dmrg.jl")
include("visualize.jl")

end
125 changes: 125 additions & 0 deletions src/visualize.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
struct TensorLayout{LT}
nodes::Vector{Pair{Vector{LT}, Node}} # labels => nodes
mask::Vector{Bool}
ignored_labels::Vector{LT}
end
function TensorLayout(ixsv::AbstractVector, locs; ignored_labels=Int[], kwargs...)
@assert length(ixsv) == length(locs)
tvg = TensorLayout{Int}(Pair{Vector{Int}, Node}[], Bool[], ignored_labels)
for (ix, loc) in zip(ixsv, locs)
addnode!(tvg, loc..., ix; kwargs...)
end
return tvg
end
get_width(tvg::TensorLayout) = maximum([node.loc[1] for (_, node) in tvg.nodes]) + 50
get_height(tvg::TensorLayout) = maximum([node.loc[2] for (_, node) in tvg.nodes]) + 50
function addnode!(tvg::TensorLayout, x::Real, y::Real, ix::AbstractVector; kwargs...)
loc = (x, y)
node = if length(ix) == 0
Node(:dot, loc; kwargs...)
elseif length(ix) == 1
Node(:circle, loc; radius=7, kwargs...)
elseif length(ix) == 2
Node(:polygon, loc; relpath=[[-1, 0], [0, -1], [1, 0], [0, 1]], kwargs...)
elseif length(ix) == 3
Node(:circle, loc; radius=12, kwargs...)
elseif length(ix) == 4
Node(:box, loc; width=20, height=20, kwargs...)
else
Node(:circle, loc; radius=12, kwargs...)
end
push!(tvg.nodes, ix => node)
push!(tvg.mask, true)
tvg
end

function draw_einsum(gl::TensorLayout)
Luxor.@layer begin
for (_, node) in gl.nodes[gl.mask]
Luxor.sethue("white")
LuxorGraphPlot.fill(node)
Luxor.sethue("black")
LuxorGraphPlot.stroke(node)
end
Luxor.sethue("black")
for i in 1:length(gl.nodes), j in i+1:length(gl.nodes)
ix1, node1 = gl.nodes[i]
ix2, node2 = gl.nodes[j]
if (gl.mask[i] || gl.mask[j]) && !isempty(setdiff(ix1 ix2, gl.ignored_labels))
LuxorGraphPlot.stroke(Connection(node1, node2))
end
end
end
end

function default_draw(f, gl::TensorLayout)
return Luxor.@drawsvg begin
Luxor.origin(0, 0)
Luxor.background("white")
draw_einsum(gl)
f(gl)
Luxor.finish()
end get_width(gl) get_height(gl)
end
default_draw(gl::TensorLayout) = default_draw(gl->nothing, gl)
function default_draw(mps::MPS)
n = nsite(mps)
code = code_mps2vec(mps)
ixs = OMEinsum.getixsv(OMEinsum.flatten(code))
phantoms = [[ix[2]] for ix in ixs]
layout = TensorLayout([ixs..., phantoms...],
[[(i*50, 90) for i in 1:n]... # real tensors
[(i*50, 50) for i in 1:n]...]; # phantoms
ignored_labels=[ixs[1][1]] # ignore the long-ancilla label
)
layout.mask[n+1:end] .= false
return default_draw(layout) do layout
Luxor.fontsize(14)
for i = 1:n
if i < n
node = LuxorGraphPlot.midpoint(layout.nodes[i].second, layout.nodes[i+1].second)
Luxor.text("$(size(mps.tensors[i], 3))", node.loc + Luxor.Point(0, -10); valign=:middle, halign=:center)
end
if canonical_center(mps) == i
Luxor.sethue("gray")
LuxorGraphPlot.fill(layout.nodes[i].second)
end
end
end
end

function default_draw(mpo::MPO)
n = nsite(mpo)
code = code_mpo2mat(mpo)
ixs = OMEinsum.getixsv(OMEinsum.flatten(code))
phantoms = [[[ix[2]] for ix in ixs]..., [[ix[3]] for ix in ixs]...]
layout = TensorLayout([ixs..., phantoms...],
[[(i*50, 90) for i in 1:n]...
[(i*50, 50) for i in 1:n]...
[(i*50, 130) for i in 1:n]...];
ignored_labels = [ixs[1][1]] # ignore the long-ancilla label
)
layout.mask[n+1:end] .= false
return default_draw(layout) do layout
Luxor.fontsize(14)
for i = 1:n
if i < n
node = LuxorGraphPlot.midpoint(layout.nodes[i].second, layout.nodes[i+1].second)
Luxor.text("$(size(mpo.tensors[i], 4))", node.loc + Luxor.Point(0, -10); valign=:middle, halign=:center)
end
if canonical_center(mpo) == i
Luxor.sethue("gray")
LuxorGraphPlot.fill(layout.nodes[i+n].second)
end
end
end
end

for FORMAT in [MIME"text/html", MIME"image/svg+xml"]
for OBJ in [MPS, MPO, TensorLayout]
@eval function Base.show(io::IO, ::$FORMAT, obj::$OBJ)
d = default_draw(obj)
show(io, $FORMAT(), d)
end
end
end
9 changes: 9 additions & 0 deletions test/visualize.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
using SimpleTDVP, Test
using LuxorGraphPlot.Luxor

@testset "visualize" begin
mps = SimpleTDVP.rand_mps([1,2,4,6,6,4,2,1])
@test SimpleTDVP.default_draw(mps) isa Luxor.Drawing
mpo = SimpleTDVP.rand_mpo([1,4,6,4,1])
@test SimpleTDVP.default_draw(mpo) isa Luxor.Drawing
end

0 comments on commit bbf49a7

Please sign in to comment.