Skip to content

Commit

Permalink
Add partial zx optimization
Browse files Browse the repository at this point in the history
  • Loading branch information
lklxx committed Feb 29, 2024
1 parent 99dd6c0 commit 915db10
Show file tree
Hide file tree
Showing 2 changed files with 167 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/qcir/optimizer/optimizer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ class Optimizer {
void _cancel_double_gate(QCir& qcir, QCirGate* prev_gate, QCirGate* gate);
void _fuse_z_phase(QCir& qcir, QCirGate* prev_gate, QCirGate* gate);
void _fuse_x_phase(QCir& qcir, QCirGate* prev_gate, QCirGate* gate);
void _partial_zx_optimization(QCir& qcir);
};

} // namespace qsyn::qcir
166 changes: 166 additions & 0 deletions src/qcir/optimizer/trivial_optimization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@

#include <cassert>

#include "../../convert/qcir_to_zxgraph.hpp"
#include "../qcir.hpp"
#include "../qcir_gate.hpp"
#include "./optimizer.hpp"
#include "extractor/extract.hpp"

extern bool stop_requested();

Expand All @@ -29,6 +31,7 @@ std::optional<QCir> Optimizer::trivial_optimization(QCir const& qcir) {
QCir result{qcir.get_num_qubits()};
result.set_filename(qcir.get_filename());
result.add_procedures(qcir.get_procedures());
result.set_gate_set(qcir.get_gate_set());

auto const gate_list = qcir.get_topologically_ordered_gates();
for (auto gate : gate_list) {
Expand Down Expand Up @@ -61,6 +64,11 @@ std::optional<QCir> Optimizer::trivial_optimization(QCir const& qcir) {
Optimizer::_add_gate_to_circuit(result, gate, false);
}
}

if (!result.get_gate_set().empty()) {
_partial_zx_optimization(result);
}

spdlog::info("Finished trivial optimization");
return result;
}
Expand Down Expand Up @@ -176,4 +184,162 @@ void Optimizer::_cancel_double_gate(QCir& qcir, QCirGate* prev_gate, QCirGate* g
Optimizer::_add_gate_to_circuit(qcir, gate, false);
}

static size_t _match_gate_sequence(std::vector<std::string> const& type_seq,
std::vector<std::string> const& target_seq) {
if (type_seq.size() < target_seq.size()) {
return type_seq.size();
}

for (size_t i = 0; i < type_seq.size() - target_seq.size() + 1; i++) {
bool match = true;
for (size_t j = 0; j < target_seq.size(); j++) {
if (type_seq[i + j] != target_seq[j]) {
match = false;
break;
}
}
if (match) {
return i;
}
}
return type_seq.size();
}

static QCir _replace_gate_sequence(QCir& qcir, QubitIdType qubit, size_t gate_num,
size_t seq_len, std::vector<std::string> const& seq) {
QCir replaced;
replaced.add_procedures(qcir.get_procedures());
replaced.add_qubits(qcir.get_num_qubits());
replaced.set_gate_set(qcir.get_gate_set());

qcir.update_topological_order();
auto const gate_list = qcir.get_topologically_ordered_gates();
size_t replace_count = 0;

if (gate_num == 0) {
replace_count = seq_len;
}

for (auto gate : gate_list) {
auto bit_range = gate->get_qubits() |
std::views::transform([](QubitInfo const& qb) { return qb._qubit; });
if (gate->get_targets()._qubit != qubit) {
replaced.add_gate(gate->get_type_str(), {bit_range.begin(), bit_range.end()}, gate->get_phase(), true);
continue;
}

if (gate_num != 0) {
replaced.add_gate(gate->get_type_str(), {bit_range.begin(), bit_range.end()}, gate->get_phase(), true);
gate_num--;
if (gate_num == 0) {
replace_count = seq_len;
}
continue;
}

if (replace_count == 0) {
replaced.add_gate(gate->get_type_str(), {bit_range.begin(), bit_range.end()}, gate->get_phase(), true);
continue;
}

if (replace_count == seq_len) {
for (auto& type : seq) {
replaced.add_gate(type, {bit_range.begin(), bit_range.end()}, dvlab::Phase(0), true);
}
replace_count--;
continue;
}
replace_count--;
}

return replaced;
}

static std::vector<std::string> _zx_optimize(std::vector<std::string> partial) {
QCir qcir;
qcir.add_qubits(1);

for (std::string type : partial) {
auto gate_type = str_to_gate_type(type);
auto const& [category, num_qubits, gate_phase] = gate_type.value();
if (gate_phase.has_value())
qcir.add_gate(type, QubitIdList{0}, gate_phase.value(), true);
else
qcir.add_gate(type, QubitIdList{0}, dvlab::Phase(0), true);
}

auto zx = to_zxgraph(qcir, 3).value();
zx.add_procedure("QC2ZX");

extractor::Extractor ext(&zx, nullptr, std::nullopt);
QCir* result = ext.extract();

result->update_topological_order();
auto const gate_list = result->get_topologically_ordered_gates();
std::vector<std::string> opt_partial;
for (auto gate : gate_list) {
opt_partial.emplace_back(gate->get_type_str());
}

return opt_partial;
}

void Optimizer::_partial_zx_optimization(QCir& qcir) {
for (size_t i = 0; i < qcir.get_num_qubits(); i++) {
qcir.update_topological_order();
auto const gate_list = qcir.get_topologically_ordered_gates();
std::vector<std::string> type_seq;
for (auto gate : gate_list) {
if ((size_t)gate->get_targets()._qubit == i || (size_t)gate->get_control()._qubit == i) {
type_seq.emplace_back(gate->get_type_str());
}
}

std::vector<std::pair<std::vector<std::string>, std::vector<std::string>>> replace_rules;
while (type_seq.size()) {
std::vector<std::string> partial;
while (type_seq.size()) {
std::string type = type_seq[0];
type_seq.erase(type_seq.begin());
if (type == "cx" || type == "cz" || type == "ecr") {
break;
}
partial.emplace_back(type);
}

if (partial.size() >= 3) {
auto opt_partial = _zx_optimize(partial);
std::vector<std::string> replaced_h_opt_partial;
for (auto g : opt_partial) {
if (g == "h") {
replaced_h_opt_partial.emplace_back("s");
replaced_h_opt_partial.emplace_back("sx");
replaced_h_opt_partial.emplace_back("s");
} else {
replaced_h_opt_partial.emplace_back(g);
}
}
if (replaced_h_opt_partial.size() < partial.size()) {
replace_rules.emplace_back(std::make_pair(partial, replaced_h_opt_partial));
}
}
}

for (auto const& [lhs, rhs] : replace_rules) {
qcir.update_topological_order();
auto const updated_gate_list = qcir.get_topologically_ordered_gates();
std::vector<std::string> updated_type_seq;
for (auto gate : updated_gate_list) {
if ((size_t)gate->get_targets()._qubit == i || (size_t)gate->get_control()._qubit == i) {
updated_type_seq.emplace_back(gate->get_type_str());
}
}

size_t g = _match_gate_sequence(updated_type_seq, lhs);
QCir replaced = _replace_gate_sequence(qcir, i, g, lhs.size(), rhs);
qcir = replaced;
}
}
}

} // namespace qsyn::qcir

0 comments on commit 915db10

Please sign in to comment.