Skip to content

Commit

Permalink
Uncertainty models using circuits (qiskit-community/qiskit-aqua#908)
Browse files Browse the repository at this point in the history
* make variational UMs accept circuits

* test circuit variants

* add deprecation warnings if varform is passed

* fix typo Uni -> Multi

* fix warning filters

* filter deprecation warnings on univ. var. dist

* only filter warnings locally
  • Loading branch information
Cryoris authored Apr 22, 2020
1 parent 2e18866 commit c564fd5
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 2 deletions.
2 changes: 1 addition & 1 deletion test/finance/test_data_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def setUp(self):

def tearDown(self):
super().tearDown()
warnings.filterwarnings(action="ignore", message="unclosed", category=ResourceWarning)
warnings.filterwarnings(action="always", message="unclosed", category=ResourceWarning)

def test_wrong_use(self):
""" wrong use test """
Expand Down
26 changes: 25 additions & 1 deletion test/finance/test_european_call_expected_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,13 @@
""" Test European Call Expected Value uncertainty problem """

from test.finance import QiskitFinanceTestCase
import warnings
from ddt import ddt, data

import numpy as np

from qiskit import BasicAer
from qiskit.circuit import ParameterVector
from qiskit.aqua import aqua_globals, QuantumInstance
from qiskit.aqua.algorithms import AmplitudeEstimation
from qiskit.aqua.components.initial_states import Custom
Expand All @@ -28,6 +31,7 @@
from qiskit.finance.components.uncertainty_problems import EuropeanCallExpectedValue


@ddt
class TestEuropeanCallExpectedValue(QiskitFinanceTestCase):
"""Tests European Call Expected Value uncertainty problem """

Expand All @@ -36,8 +40,18 @@ def setUp(self):
self.seed = 457
aqua_globals.random_seed = self.seed

def test_ecev(self):
def tearDown(self):
super().tearDown()
warnings.filterwarnings(action="always", category=DeprecationWarning)

@data(False, True)
def test_ecev(self, use_circuits):
""" European Call Expected Value test """
if not use_circuits:
# ignore deprecation warnings from the deprecation of VariationalForm as input for
# the univariate variational distribution
warnings.filterwarnings("ignore", category=DeprecationWarning)

bounds = np.array([0., 7.])
num_qubits = [3]
entangler_map = []
Expand All @@ -54,10 +68,17 @@ def test_ecev(self):
var_form = RY(int(np.sum(num_qubits)), depth=1,
initial_state=init_distribution,
entangler_map=entangler_map, entanglement_gate='cz')
if use_circuits:
theta = ParameterVector('θ', var_form.num_parameters)
var_form = var_form.construct_circuit(theta)

uncertainty_model = UnivariateVariationalDistribution(
int(sum(num_qubits)), var_form, g_params,
low=bounds[0], high=bounds[1])

if use_circuits:
uncertainty_model._var_form_params = theta

strike_price = 2
c_approx = 0.25
european_call = EuropeanCallExpectedValue(uncertainty_model,
Expand All @@ -71,3 +92,6 @@ def test_ecev(self):
result = algo.run(quantum_instance=BasicAer.get_backend('statevector_simulator'))
self.assertAlmostEqual(result['estimation'], 1.2580, places=4)
self.assertAlmostEqual(result['max_probability'], 0.8785, places=4)

if not use_circuits:
warnings.filterwarnings(action="always", category=DeprecationWarning)

0 comments on commit c564fd5

Please sign in to comment.