Skip to content

Commit

Permalink
Uncertainty models using circuits (qiskit-community/qiskit-aqua#908)
Browse files Browse the repository at this point in the history
* make variational UMs accept circuits

* test circuit variants

* add deprecation warnings if varform is passed

* fix typo Uni -> Multi

* fix warning filters

* filter deprecation warnings on univ. var. dist

* only filter warnings locally
  • Loading branch information
Cryoris authored Apr 22, 2020
1 parent a5df837 commit 7aa44a1
Showing 1 changed file with 47 additions and 9 deletions.
56 changes: 47 additions & 9 deletions test/aqua/test_qgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,16 @@
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.
# =============================================================================

""" Test QGAN """

import unittest
from test.aqua import QiskitAquaTestCase

import warnings
import unittest
from ddt import ddt, data
from qiskit import QuantumCircuit, QuantumRegister
from qiskit.circuit import ParameterVector
from qiskit.aqua.components.uncertainty_models import (UniformDistribution,
UnivariateVariationalDistribution)
from qiskit.aqua.components.variational_forms import RY
Expand All @@ -28,10 +31,13 @@
from qiskit import BasicAer


@ddt
class TestQGAN(QiskitAquaTestCase):
""" Test QGAN """

def setUp(self):
super().setUp()

self.seed = 7
aqua_globals.random_seed = self.seed
# Number training data samples
Expand Down Expand Up @@ -86,27 +92,59 @@ def setUp(self):
# Set generator's initial parameters
init_params = aqua_globals.random.rand(var_form._num_parameters) * 2 * 1e-2
# Set generator circuit
g_circuit = UnivariateVariationalDistribution(sum(num_qubits), var_form, init_params,
low=self._bounds[0],
high=self._bounds[1])
# Set quantum generator
self.qgan.set_generator(generator_circuit=g_circuit)
self.g_var_form = UnivariateVariationalDistribution(sum(num_qubits), var_form, init_params,
low=self._bounds[0],
high=self._bounds[1])

theta = ParameterVector('θ', var_form.num_parameters)
var_form = var_form.construct_circuit(theta)
self.g_circuit = UnivariateVariationalDistribution(sum(num_qubits), var_form, init_params,
low=self._bounds[0],
high=self._bounds[1])

def test_sample_generation(self):
def tearDown(self):
super().tearDown()
warnings.filterwarnings(action="always", category=DeprecationWarning)

@data(False, True)
def test_sample_generation(self, use_circuits):
""" sample generation test """
if use_circuits:
self.qgan.set_generator(generator_circuit=self.g_circuit)
else:
# ignore deprecation warnings from the deprecation of VariationalForm as input for
# the univariate variational distribution
warnings.filterwarnings("ignore", category=DeprecationWarning)
self.qgan.set_generator(generator_circuit=self.g_var_form)

_, weights_statevector = \
self.qgan._generator.get_output(self.qi_statevector, shots=100)
samples_qasm, weights_qasm = self.qgan._generator.get_output(self.qi_qasm, shots=100)
samples_qasm, weights_qasm = zip(*sorted(zip(samples_qasm, weights_qasm)))
for i, weight_q in enumerate(weights_qasm):
self.assertAlmostEqual(weight_q, weights_statevector[i], delta=0.1)

def test_qgan_training(self):
if not use_circuits:
warnings.filterwarnings(action="always", category=DeprecationWarning)

@data(False, True)
def test_qgan_training(self, use_circuits):
""" qgan training test """
if use_circuits:
self.qgan.set_generator(generator_circuit=self.g_circuit)
else:
# ignore deprecation warnings from the deprecation of VariationalForm as input for
# the univariate variational distribution
warnings.filterwarnings("ignore", category=DeprecationWarning)
self.qgan.set_generator(generator_circuit=self.g_var_form)

trained_statevector = self.qgan.run(self.qi_statevector)
trained_qasm = self.qgan.run(self.qi_qasm)
self.assertAlmostEqual(trained_qasm['rel_entr'], trained_statevector['rel_entr'], delta=0.1)

if not use_circuits:
warnings.filterwarnings(action="always", category=DeprecationWarning)

def test_qgan_training_run_algo_torch(self):
""" qgan training run algo torch test """
try:
Expand Down

0 comments on commit 7aa44a1

Please sign in to comment.