-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* new visualization * update * update CI * fix test * fix test
- Loading branch information
Showing
7 changed files
with
198 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
name: CI | ||
on: | ||
push: | ||
branches: | ||
- main | ||
tags: '*' | ||
pull_request: | ||
concurrency: | ||
# Skip intermediate builds: always. | ||
# Cancel intermediate builds: only if it is a pull request build. | ||
group: ${{ github.workflow }}-${{ github.ref }} | ||
cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} | ||
jobs: | ||
test: | ||
name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} | ||
runs-on: ${{ matrix.os }} | ||
strategy: | ||
fail-fast: false | ||
matrix: | ||
version: | ||
- '1' | ||
os: | ||
- ubuntu-latest | ||
arch: | ||
- x64 | ||
steps: | ||
- uses: actions/checkout@v3 | ||
- uses: julia-actions/setup-julia@v1 | ||
with: | ||
version: ${{ matrix.version }} | ||
arch: ${{ matrix.arch }} | ||
- uses: julia-actions/cache@v1 | ||
- uses: julia-actions/julia-buildpkg@v1 | ||
- uses: julia-actions/julia-runtest@v1 | ||
- uses: julia-actions/julia-processcoverage@v1 | ||
- uses: codecov/codecov-action@v3 | ||
with: | ||
files: lcov.info |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
name: CompatHelper | ||
on: | ||
schedule: | ||
- cron: 0 0 * * * | ||
workflow_dispatch: | ||
jobs: | ||
CompatHelper: | ||
runs-on: ubuntu-latest | ||
steps: | ||
- name: Pkg.add("CompatHelper") | ||
run: julia -e 'using Pkg; Pkg.add("CompatHelper")' | ||
- name: CompatHelper.main() | ||
env: | ||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} | ||
COMPATHELPER_PRIV: ${{ secrets.DOCUMENTER_KEY }} | ||
run: julia -e 'using CompatHelper; CompatHelper.main()' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
# SimpleTDVP | ||
|
||
[data:image/s3,"s3://crabby-images/0512e/0512eb83b0e793ea5d383e8834e280f67f235c9d" alt="Build Status"](https://github.com/GiggleLiu/SimpleTDVP.jl/actions/workflows/CI.yml?query=branch%3Amain) | ||
[data:image/s3,"s3://crabby-images/b0a02/b0a02f04d02b3511116e17adc31ee8d5df8a67e4" alt="Build Status"](https://github.com/CodingThrust/SimpleTDVP.jl/actions/workflows/CI.yml?query=branch%3Amain) | ||
[data:image/s3,"s3://crabby-images/25142/25142b8a7e10ab5c78848f18664fa9fb417e786b" alt="Coverage"](https://codecov.io/gh/CodingThrust/SimpleTDVP.jl) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,18 +1,22 @@ | ||
module SimpleTDVP | ||
|
||
using LinearAlgebra, OMEinsum, LinearOperators, KrylovKit | ||
using LuxorGraphPlot: Node, Connection, Luxor | ||
import LuxorGraphPlot | ||
|
||
export MPS, nsite, nflavor, vec2mps, IndexStore, newindex!, code_mps2vec, rand_mps, vec | ||
export left_move!, right_move!, canonical_move!, is_canonicalized, canonical_center, to_left_canonical!, to_right_canonical! | ||
export mat, MPO, code_mpo2mat, mat2mpo, rand_mpo | ||
export dot, sandwich, compress!, num_of_elements | ||
export dmrg!, dmrg | ||
export TensorLayout | ||
|
||
include("utils.jl") | ||
include("tensornetwork.jl") | ||
include("mps.jl") | ||
include("mpo.jl") | ||
include("mpsormpo.jl") | ||
include("dmrg.jl") | ||
include("visualize.jl") | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
struct TensorLayout{LT} | ||
nodes::Vector{Pair{Vector{LT}, Node}} # labels => nodes | ||
mask::Vector{Bool} | ||
ignored_labels::Vector{LT} | ||
end | ||
function TensorLayout(ixsv::AbstractVector, locs; ignored_labels=Int[], kwargs...) | ||
@assert length(ixsv) == length(locs) | ||
tvg = TensorLayout{Int}(Pair{Vector{Int}, Node}[], Bool[], ignored_labels) | ||
for (ix, loc) in zip(ixsv, locs) | ||
addnode!(tvg, loc..., ix; kwargs...) | ||
end | ||
return tvg | ||
end | ||
get_width(tvg::TensorLayout) = maximum([node.loc[1] for (_, node) in tvg.nodes]) + 50 | ||
get_height(tvg::TensorLayout) = maximum([node.loc[2] for (_, node) in tvg.nodes]) + 50 | ||
function addnode!(tvg::TensorLayout, x::Real, y::Real, ix::AbstractVector; kwargs...) | ||
loc = (x, y) | ||
node = if length(ix) == 0 | ||
Node(:dot, loc; kwargs...) | ||
elseif length(ix) == 1 | ||
Node(:circle, loc; radius=7, kwargs...) | ||
elseif length(ix) == 2 | ||
Node(:polygon, loc; relpath=[[-1, 0], [0, -1], [1, 0], [0, 1]], kwargs...) | ||
elseif length(ix) == 3 | ||
Node(:circle, loc; radius=12, kwargs...) | ||
elseif length(ix) == 4 | ||
Node(:box, loc; width=20, height=20, kwargs...) | ||
else | ||
Node(:circle, loc; radius=12, kwargs...) | ||
end | ||
push!(tvg.nodes, ix => node) | ||
push!(tvg.mask, true) | ||
tvg | ||
end | ||
|
||
function draw_einsum(gl::TensorLayout) | ||
Luxor.@layer begin | ||
for (_, node) in gl.nodes[gl.mask] | ||
Luxor.sethue("white") | ||
LuxorGraphPlot.fill(node) | ||
Luxor.sethue("black") | ||
LuxorGraphPlot.stroke(node) | ||
end | ||
Luxor.sethue("black") | ||
for i in 1:length(gl.nodes), j in i+1:length(gl.nodes) | ||
ix1, node1 = gl.nodes[i] | ||
ix2, node2 = gl.nodes[j] | ||
if (gl.mask[i] || gl.mask[j]) && !isempty(setdiff(ix1 ∩ ix2, gl.ignored_labels)) | ||
LuxorGraphPlot.stroke(Connection(node1, node2)) | ||
end | ||
end | ||
end | ||
end | ||
|
||
function default_draw(f, gl::TensorLayout) | ||
return Luxor.@drawsvg begin | ||
Luxor.origin(0, 0) | ||
Luxor.background("white") | ||
draw_einsum(gl) | ||
f(gl) | ||
Luxor.finish() | ||
end get_width(gl) get_height(gl) | ||
end | ||
default_draw(gl::TensorLayout) = default_draw(gl->nothing, gl) | ||
function default_draw(mps::MPS) | ||
n = nsite(mps) | ||
code = code_mps2vec(mps) | ||
ixs = OMEinsum.getixsv(OMEinsum.flatten(code)) | ||
phantoms = [[ix[2]] for ix in ixs] | ||
layout = TensorLayout([ixs..., phantoms...], | ||
[[(i*50, 90) for i in 1:n]... # real tensors | ||
[(i*50, 50) for i in 1:n]...]; # phantoms | ||
ignored_labels=[ixs[1][1]] # ignore the long-ancilla label | ||
) | ||
layout.mask[n+1:end] .= false | ||
return default_draw(layout) do layout | ||
Luxor.fontsize(14) | ||
for i = 1:n | ||
if i < n | ||
node = LuxorGraphPlot.midpoint(layout.nodes[i].second, layout.nodes[i+1].second) | ||
Luxor.text("$(size(mps.tensors[i], 3))", node.loc + Luxor.Point(0, -10); valign=:middle, halign=:center) | ||
end | ||
if canonical_center(mps) == i | ||
Luxor.sethue("gray") | ||
LuxorGraphPlot.fill(layout.nodes[i].second) | ||
end | ||
end | ||
end | ||
end | ||
|
||
function default_draw(mpo::MPO) | ||
n = nsite(mpo) | ||
code = code_mpo2mat(mpo) | ||
ixs = OMEinsum.getixsv(OMEinsum.flatten(code)) | ||
phantoms = [[[ix[2]] for ix in ixs]..., [[ix[3]] for ix in ixs]...] | ||
layout = TensorLayout([ixs..., phantoms...], | ||
[[(i*50, 90) for i in 1:n]... | ||
[(i*50, 50) for i in 1:n]... | ||
[(i*50, 130) for i in 1:n]...]; | ||
ignored_labels = [ixs[1][1]] # ignore the long-ancilla label | ||
) | ||
layout.mask[n+1:end] .= false | ||
return default_draw(layout) do layout | ||
Luxor.fontsize(14) | ||
for i = 1:n | ||
if i < n | ||
node = LuxorGraphPlot.midpoint(layout.nodes[i].second, layout.nodes[i+1].second) | ||
Luxor.text("$(size(mpo.tensors[i], 4))", node.loc + Luxor.Point(0, -10); valign=:middle, halign=:center) | ||
end | ||
if canonical_center(mpo) == i | ||
Luxor.sethue("gray") | ||
LuxorGraphPlot.fill(layout.nodes[i+n].second) | ||
end | ||
end | ||
end | ||
end | ||
|
||
for FORMAT in [MIME"text/html", MIME"image/svg+xml"] | ||
for OBJ in [MPS, MPO, TensorLayout] | ||
@eval function Base.show(io::IO, ::$FORMAT, obj::$OBJ) | ||
d = default_draw(obj) | ||
show(io, $FORMAT(), d) | ||
end | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
using SimpleTDVP, Test | ||
using LuxorGraphPlot.Luxor | ||
|
||
@testset "visualize" begin | ||
mps = SimpleTDVP.rand_mps([1,2,4,6,6,4,2,1]) | ||
@test SimpleTDVP.default_draw(mps) isa Luxor.Drawing | ||
mpo = SimpleTDVP.rand_mpo([1,4,6,4,1]) | ||
@test SimpleTDVP.default_draw(mpo) isa Luxor.Drawing | ||
end |