Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adjust ZX coords #61

Merged
merged 4 commits into from
Feb 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/qcir/qcir_cmd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ namespace qsyn::qcir {

std::function<bool(size_t const&)> valid_qcir_id(QCirMgr const& qcir_mgr) {
return [&](size_t const& id) {
if (qcir_mgr.is_id(id)) return true;
if (qcir_mgr.get() && qcir_mgr.is_id(id)) return true;
spdlog::error("QCir {} does not exist!!", id);
return false;
};
Expand All @@ -45,7 +45,7 @@ std::function<bool(size_t const&)> valid_qcir_id(QCirMgr const& qcir_mgr) {
std::function<bool(size_t const&)> valid_qcir_gate_id(QCirMgr const& qcir_mgr) {
return [&](size_t const& id) {
if (!dvlab::utils::mgr_has_data(qcir_mgr)) return false;
if (qcir_mgr.get()->get_gate(id) != nullptr) return true;
if (qcir_mgr.get() && qcir_mgr.get()->get_gate(id) != nullptr) return true;
spdlog::error("Gate ID {} does not exist!!", id);
return false;
};
Expand All @@ -54,7 +54,7 @@ std::function<bool(size_t const&)> valid_qcir_gate_id(QCirMgr const& qcir_mgr) {
std::function<bool(QubitIdType const&)> valid_qcir_qubit_id(QCirMgr const& qcir_mgr) {
return [&](QubitIdType const& id) {
if (!dvlab::utils::mgr_has_data(qcir_mgr)) return false;
if (qcir_mgr.get()->get_qubit(id) != nullptr) return true;
if (qcir_mgr.get() && qcir_mgr.get()->get_qubit(id) != nullptr) return true;
spdlog::error("Qubit ID {} does not exist!!", id);
return false;
};
Expand Down
2 changes: 1 addition & 1 deletion src/util/data_structure_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class DataStructureManager { // NOLINT(hicpp-special-member-functions, cppcoreg

size_t get_next_id() const { return _next_id; }

T* get() const { return _list.at(_focused_id).get(); }
T* get() const { return size() ? _list.at(_focused_id).get() : nullptr; }

void set_by_id(size_t id, std::unique_ptr<T> t) {
if (_list.contains(id)) {
Expand Down
36 changes: 36 additions & 0 deletions src/zx/simplifier/rules/pivot_boundary_rule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
****************************************************************************/

#include "./zx_rules_template.hpp"
#include "zx/zxgraph.hpp"

using namespace qsyn::zx;

Expand Down Expand Up @@ -83,3 +84,38 @@ void PivotBoundaryRule::apply(ZXGraph& graph, std::vector<MatchType> const& matc

PivotRuleInterface::apply(graph, matches);
}

bool PivotBoundaryRule::is_candidate(ZXGraph& graph, ZXVertex* vb, ZXVertex* vn) {
if (!graph.is_graph_like()) {
spdlog::error("The graph is not graph like!");
return false;
}
if (!vb->is_z()) {
spdlog::error("Vertex {} is not a Z vertex", vb->get_id());
return false;
}
bool has_boundary = false;
for (const auto& [nb, etype] : graph.get_neighbors(vb)) {
if (nb->is_boundary()) {
has_boundary = true;
break;
}
}
if (!has_boundary) {
spdlog::error("Vertex {} is not connected to a boundary", vb->get_id());
return false;
}
if (!vn->has_n_pi_phase()) {
spdlog::error("Vertex {} is not a Z vertex with phase n π", vn->get_id());
return false;
}
if (!graph.is_neighbor(vb, vn)) {
spdlog::error("Vertices {} and {} are not connected", vb->get_id(), vn->get_id());
return false;
}
// if (graph.has_dangling_neighbors(vn)) {
// spdlog::error("Vertex {} is the axel of a phase gadget", vn->get_id());
// return false;
// }
return true;
}
1 change: 1 addition & 0 deletions src/zx/simplifier/rules/zx_rules_template.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ class PivotBoundaryRule : public PivotRuleInterface {

std::vector<MatchType> find_matches(ZXGraph const& graph) const override;
void apply(ZXGraph& graph, std::vector<MatchType> const& matches) const override;
bool is_candidate(ZXGraph& graph, ZXVertex* v0, ZXVertex* v1);
};

class SpiderFusionRule : public ZXRuleTemplate<std::pair<ZXVertex*, ZXVertex*>> {
Expand Down
41 changes: 41 additions & 0 deletions src/zx/simplifier/simp_cmd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,11 @@

#include "./simp_cmd.hpp"

#include <fmt/core.h>

#include <cstddef>
#include <string>
#include <vector>

#include "./simplify.hpp"
#include "argparse/arg_parser.hpp"
Expand Down Expand Up @@ -177,4 +180,42 @@ Command zxgraph_rule_cmd(zx::ZXGraphMgr &zxgraph_mgr) {
}};
}

// REVIEW - Logic of check function is not completed
Command zxgraph_manual_apply_cmd(zx::ZXGraphMgr &zxgraph_mgr) {
return Command{
"manual",
[&](ArgumentParser &parser) {
parser.description("apply simplification rules on specific candidates");

auto mutex = parser.add_mutually_exclusive_group().required(true);
mutex.add_argument<bool>("--pivot")
.action(store_true)
.help("applies pivot rules to vertex pairs with phase 0 or π");
mutex.add_argument<bool>("--pivot-boundary")
.action(store_true)
.help("applies pivot rules to vertex pairs connected to the boundary");
mutex.add_argument<bool>("--pivot-gadget")
.action(store_true)
.help("unfuses the phase and applies pivot rules to form gadgets");

parser.add_argument<size_t>("vertices")
.nargs(2)
.constraint(valid_zxvertex_id(zxgraph_mgr))
.help("the vertices on which the rule applies");
},
[&](ArgumentParser const &parser) {
if (!dvlab::utils::mgr_has_data(zxgraph_mgr)) return dvlab::CmdExecResult::error;
auto vertices = parser.get<std::vector<size_t>>("vertices");
ZXVertex *bound = zxgraph_mgr.get()->find_vertex_by_id(vertices[0]);
ZXVertex *vert = zxgraph_mgr.get()->find_vertex_by_id(vertices[1]);

const bool is_cand = PivotBoundaryRule().is_candidate(*zxgraph_mgr.get(), bound, vert);
if (!is_cand) return CmdExecResult::error;

std::vector<std::pair<ZXVertex *, ZXVertex *>> match;
match.emplace_back(bound, vert);
PivotBoundaryRule().apply(*zxgraph_mgr.get(), match);
return CmdExecResult::done;
}};
}
} // namespace qsyn::zx
1 change: 1 addition & 0 deletions src/zx/simplifier/simp_cmd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,6 @@ namespace qsyn::zx {

dvlab::Command zxgraph_optimize_cmd(ZXGraphMgr &zxgraph_mgr);
dvlab::Command zxgraph_rule_cmd(ZXGraphMgr &zxgraph_mgr);
dvlab::Command zxgraph_manual_apply_cmd(ZXGraphMgr &zxgraph_mgr);

} // namespace qsyn::zx
3 changes: 2 additions & 1 deletion src/zx/simplifier/simplify.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ class Simplifier {
hadamard_rule_simp();
}
~Simplifier() {
_simp_graph->adjustVertexCoordinates();
// REVIEW - Whether to adjust
// _simp_graph->adjust_vertex_coordinates();
}
Simplifier(Simplifier const& other) = default;
Simplifier(Simplifier&& other) = default;
Expand Down
16 changes: 15 additions & 1 deletion src/zx/zx_cmd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,18 @@ namespace qsyn::zx {

std::function<bool(size_t const&)> valid_zxvertex_id(ZXGraphMgr const& zxgraph_mgr) {
return [&](size_t const& id) {
if (zxgraph_mgr.get()->is_v_id(id)) return true;
if (zxgraph_mgr.get() && zxgraph_mgr.get()->is_v_id(id)) return true;
spdlog::error("Cannot find vertex with ID {} in the ZXGraph!!", id);
return false;
};
}

std::function<bool(size_t const&)> zxgraph_id_not_exist(ZXGraphMgr const& zxgraph_mgr) {
return [&](size_t const& id) {
if (!zxgraph_mgr.get()) {
spdlog::error("ZXGraphMgr does not exist!!");
return true;
}
if (!zxgraph_mgr.is_id(id)) return true;
spdlog::error("ZXGraph {} already exists!!", id);
spdlog::info("Use `-Replace` if you want to overwrite it.");
Expand All @@ -43,6 +47,10 @@ std::function<bool(size_t const&)> zxgraph_id_not_exist(ZXGraphMgr const& zxgrap

std::function<bool(int const&)> zxgraph_input_qubit_not_exist(ZXGraphMgr const& zxgraph_mgr) {
return [&](int const& qid) {
if (!zxgraph_mgr.get()) {
spdlog::error("ZXGraphMgr does not exist!!");
return true;
}
if (!zxgraph_mgr.get()->is_input_qubit(qid)) return true;
spdlog::error("This qubit's input already exists!!");
return false;
Expand All @@ -51,6 +59,10 @@ std::function<bool(int const&)> zxgraph_input_qubit_not_exist(ZXGraphMgr const&

std::function<bool(int const&)> zxgraph_output_qubit_not_exist(ZXGraphMgr const& zxgraph_mgr) {
return [&](int const& qid) {
if (!zxgraph_mgr.get()) {
spdlog::error("ZXGraphMgr does not exist!!");
return true;
}
if (!zxgraph_mgr.get()->is_output_qubit(qid)) return true;
spdlog::error("This qubit's output already exists!!");
return false;
Expand Down Expand Up @@ -242,6 +254,7 @@ Command zxgraph_draw_cmd(ZXGraphMgr const& zxgraph_mgr) {
[&](ArgumentParser const& parser) {
if (!dvlab::utils::mgr_has_data(zxgraph_mgr)) return CmdExecResult::error;
if (parser.parsed("filepath")) {
zxgraph_mgr.get()->adjust_vertex_coordinates();
if (!zxgraph_mgr.get()->write_pdf(parser.get<std::string>("filepath"))) return CmdExecResult::error;
}
if (parser.parsed("-cli")) {
Expand Down Expand Up @@ -661,6 +674,7 @@ Command zxgraph_cmd(ZXGraphMgr& zxgraph_mgr) {
cmd.add_subcommand(zxgraph_gflow_cmd(zxgraph_mgr));
cmd.add_subcommand(zxgraph_optimize_cmd(zxgraph_mgr));
cmd.add_subcommand(zxgraph_rule_cmd(zxgraph_mgr));
cmd.add_subcommand(zxgraph_manual_apply_cmd(zxgraph_mgr));
cmd.add_subcommand(zxgraph_vertex_cmd(zxgraph_mgr));
cmd.add_subcommand(zxgraph_edge_cmd(zxgraph_mgr));
return cmd;
Expand Down
4 changes: 4 additions & 0 deletions src/zx/zx_cmd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@

namespace qsyn::zx {

std::function<bool(size_t const&)> valid_zxvertex_id(ZXGraphMgr const& zxgraph_mgr);
std::function<bool(size_t const&)> zxgraph_id_not_exist(ZXGraphMgr const& zxgraph_mgr);
std::function<bool(int const&)> zxgraph_input_qubit_not_exist(ZXGraphMgr const& zxgraph_mgr);
std::function<bool(int const&)> zxgraph_output_qubit_not_exist(ZXGraphMgr const& zxgraph_mgr);
bool add_zx_cmds(dvlab::CommandLineInterface& cli, qsyn::zx::ZXGraphMgr& zxgraph_mgr);

} // namespace qsyn::zx
2 changes: 1 addition & 1 deletion src/zx/zxgraph.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ class ZXGraph { // NOLINT(cppcoreguidelines-special-member-functions) : copy-sw
void add_gadget(Phase p, std::vector<ZXVertex*> const& vertices);
void remove_gadget(ZXVertex* v);
std::unordered_map<size_t, ZXVertex*> create_id_to_vertex_map() const;
void adjustVertexCoordinates();
void adjust_vertex_coordinates();

// Print functions (zxGraphPrint.cpp)
void print_graph(spdlog::level::level_enum lvl = spdlog::level::level_enum::off) const;
Expand Down
61 changes: 41 additions & 20 deletions src/zx/zxgraph_action.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
Copyright [ Copyright(c) 2023 DVLab, GIEE, NTU, Taiwan ]
****************************************************************************/

#include <fmt/core.h>

#include <cstddef>
#include <gsl/narrow>
#include <queue>
Expand Down Expand Up @@ -237,46 +239,65 @@ std::unordered_map<size_t, ZXVertex*> ZXGraph::create_id_to_vertex_map() const {
* @brief Rearrange vertices on each qubit so that each vertex can be seperated in the printed graph.
*
*/
void ZXGraph::adjustVertexCoordinates() {
void ZXGraph::adjust_vertex_coordinates() {
// FIXME - QubitId -> RowId
std::unordered_map<QubitIdType, std::vector<ZXVertex*>> qubit_id_to_vertices_map;
std::unordered_set<QubitIdType> visited_qubit_ids;
std::queue<ZXVertex*> vertex_queue;
std::vector<ZXVertex*> vertex_queue;
// NOTE - Check Gadgets
// FIXME - When replacing QubitId with RowId, add 0.5 on it
for (auto const& i : _vertices) {
if (i->get_qubit() == -2 && get_num_neighbors(i) > 1) {
std::unordered_map<QubitIdType, size_t> num_neighbor_qubits;
for (auto const& [nb, _] : get_neighbors(i)) {
if (num_neighbor_qubits.contains(nb->get_qubit())) {
num_neighbor_qubits[nb->get_qubit()]++;
fmt::println("add qb: {}", nb->get_qubit());
} else
num_neighbor_qubits[nb->get_qubit()] = 1;
}
fmt::println("move to {}", (*max_element(num_neighbor_qubits.begin(), num_neighbor_qubits.end(), [](const std::pair<QubitIdType, size_t>& p1, const std::pair<QubitIdType, size_t>& p2) { return p1.second < p2.second; })).first);
i->set_qubit((*max_element(num_neighbor_qubits.begin(), num_neighbor_qubits.end(), [](const std::pair<QubitIdType, size_t>& p1, const std::pair<QubitIdType, size_t>& p2) { return p1.second < p2.second; })).first);
}
}

// REVIEW - Whether to move the vertex from row -2 when it is no longer a gadget
// for (auto const& i : _vertices) {
// if (i->get_qubit() == -2 && get_num_neighbors(i) > 1) {
// std::unordered_map<QubitIdType, size_t> num_neighbor_qubits;
// for (auto const& [nb, _] : get_neighbors(i)) {
// if (num_neighbor_qubits.contains(nb->get_qubit())) {
// num_neighbor_qubits[nb->get_qubit()]++;
// } else
// num_neighbor_qubits[nb->get_qubit()] = 1;
// }
// // fmt::println("move to {}", (*max_element(num_neighbor_qubits.begin(), num_neighbor_qubits.end(), [](const std::pair<QubitIdType, size_t>& p1, const std::pair<QubitIdType, size_t>& p2) { return p1.second < p2.second; })).first);
// i->set_qubit((*max_element(num_neighbor_qubits.begin(), num_neighbor_qubits.end(), [](const std::pair<QubitIdType, size_t>& p1, const std::pair<QubitIdType, size_t>& p2) { return p1.second < p2.second; })).first);
// }
// }

for (auto const& i : _inputs) {
vertex_queue.push(i);
vertex_queue.emplace_back(i);
visited_qubit_ids.insert(gsl::narrow<QubitIdType>(i->get_id()));
}
while (!vertex_queue.empty()) {
ZXVertex* v = vertex_queue.front();
vertex_queue.pop();
vertex_queue.erase(vertex_queue.begin());
qubit_id_to_vertices_map[v->get_qubit()].emplace_back(v);
for (auto const& nb : get_neighbors(v) | std::views::keys) {
if (visited_qubit_ids.find(gsl::narrow<QubitIdType>(nb->get_id())) == visited_qubit_ids.end()) {
vertex_queue.push(nb);
vertex_queue.emplace_back(nb);
visited_qubit_ids.insert(gsl::narrow<QubitIdType>(nb->get_id()));
}
}
}
std::vector<ZXVertex*> gadgets;
double non_gadget = 0;
for (size_t i = 0; i < qubit_id_to_vertices_map[-2].size(); i++) {
if (get_num_neighbors(qubit_id_to_vertices_map[-2][i]) == 1) { // Not Gadgets
gadgets.emplace_back(qubit_id_to_vertices_map[-2][i]);
} else
non_gadget++;
}
auto end_it = std::remove_if(
qubit_id_to_vertices_map[-2].begin(),
qubit_id_to_vertices_map[-2].end(),
[this](ZXVertex* v) {
return this->get_num_neighbors(v) == 1;
});
qubit_id_to_vertices_map[-2].erase(end_it, qubit_id_to_vertices_map[-2].end());

qubit_id_to_vertices_map[-2].insert(qubit_id_to_vertices_map[-2].end(), gadgets.begin(), gadgets.end());
double max_col = 0.0;
for (auto& i : qubit_id_to_vertices_map) {
double col = i.first < 0 ? 0.5 : 0.0;
double col = i.first == -2 ? 0.5 : i.first == -1 ? 0.5 + non_gadget
: 0.0;
for (auto& v : i.second) {
v->set_col(col);
col++;
Expand Down
Loading