Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unify setup and teardown methods of TestCases #6753

Merged
merged 7 commits into from
Aug 3, 2021
Merged
229 changes: 101 additions & 128 deletions qiskit/test/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
HAS_FIXTURES = False

from qiskit.exceptions import MissingOptionalLibraryError
from .decorators import enforce_subclasses_call
from .runtest import RunTest, MultipleExceptions
from .utils import Path, setup_test_logging

Expand Down Expand Up @@ -88,9 +89,106 @@ def gather_details(source_dict, target_dict):
target_dict[name] = _copy_content(content_object)


class BaseQiskitTestCase(unittest.TestCase):
@enforce_subclasses_call(["setUp", "setUpClass", "tearDown", "tearDownClass"])
class QiskitTestCase(unittest.TestCase):
"""Common extra functionality on top of unittest."""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.__setup_called = False
self.__teardown_called = False

def setUp(self):
super().setUp()
if self.__setup_called:
raise ValueError(
"In File: %s\n"
"TestCase.setUp was already called. Do not explicitly call "
"setUp from your tests. In your own setUp, use super to call "
"the base setUp." % (sys.modules[self.__class__.__module__].__file__,)
)
self.__setup_called = True

def tearDown(self):
super().tearDown()
if self.__teardown_called:
raise ValueError(
"In File: %s\n"
"TestCase.tearDown was already called. Do not explicitly call "
"tearDown from your tests. In your own tearDown, use super to "
"call the base tearDown." % (sys.modules[self.__class__.__module__].__file__,)
)
self.__teardown_called = True
# Reset the default providers, as in practice they acts as a singleton
# due to importing the instances from the top-level qiskit namespace.
from qiskit.providers.basicaer import BasicAer

BasicAer._backends = BasicAer._verify_backends()

@classmethod
def setUpClass(cls):
super().setUpClass()
# Determines if the TestCase is using IBMQ credentials.
cls.using_ibmq_credentials = False
# Set logging to file and stdout if the LOG_LEVEL envar is set.
cls.log = logging.getLogger(cls.__name__)
if os.getenv("LOG_LEVEL"):
filename = "%s.log" % os.path.splitext(inspect.getfile(cls))[0]
setup_test_logging(cls.log, os.getenv("LOG_LEVEL"), filename)

warnings.filterwarnings("error", category=DeprecationWarning)
allow_DeprecationWarning_modules = [
"test.python.pulse.test_parameters",
"test.python.pulse.test_transforms",
"test.python.circuit.test_gate_power",
"test.python.pulse.test_builder",
"test.python.pulse.test_block",
"test.python.quantum_info.operators.symplectic.test_legacy_pauli",
"qiskit.quantum_info.operators.pauli",
"pybobyqa",
"numba",
"qiskit.utils.measurement_error_mitigation",
"qiskit.circuit.library.standard_gates.x",
"qiskit.pulse.schedule",
"qiskit.pulse.instructions.instruction",
"qiskit.pulse.instructions.play",
"qiskit.pulse.library.parametric_pulses",
"qiskit.quantum_info.operators.symplectic.pauli",
"test.python.dagcircuit.test_dagcircuit",
"test.python.quantum_info.operators.test_operator",
"test.python.quantum_info.operators.test_scalar_op",
"test.python.quantum_info.operators.test_superop",
"test.python.quantum_info.operators.channel.test_kraus",
"test.python.quantum_info.operators.channel.test_choi",
"test.python.quantum_info.operators.channel.test_chi",
"test.python.quantum_info.operators.channel.test_superop",
"test.python.quantum_info.operators.channel.test_stinespring",
"test.python.quantum_info.operators.symplectic.test_sparse_pauli_op",
"test.python.quantum_info.operators.channel.test_ptm",
]
for mod in allow_DeprecationWarning_modules:
warnings.filterwarnings("default", category=DeprecationWarning, module=mod)
allow_DeprecationWarning_message = [
r".*LogNormalDistribution.*",
r".*NormalDistribution.*",
r".*UniformDistribution.*",
r".*QuantumCircuit\.combine.*",
r".*QuantumCircuit\.__add__.*",
r".*QuantumCircuit\.__iadd__.*",
r".*QuantumCircuit\.extend.*",
r".*psi @ U.*",
r".*qiskit\.circuit\.library\.standard_gates\.ms import.*",
r"elementwise comparison failed.*",
r"The jsonschema validation included in qiskit-terra.*",
r"The DerivativeBase.parameter_expression_grad method.*",
r"Back-references to from Bit instances.*",
r"The QuantumCircuit.u. method.*",
r"The QuantumCircuit.cu.",
r"The CXDirection pass has been deprecated",
]
for msg in allow_DeprecationWarning_message:
warnings.filterwarnings("default", category=DeprecationWarning, message=msg)

@staticmethod
def _get_resource_path(filename, path=Path.TEST):
"""Get the absolute path to a resource.
Expand Down Expand Up @@ -138,29 +236,7 @@ def assertDictAlmostEqual(
raise self.failureException(msg)


class BasicQiskitTestCase(BaseQiskitTestCase):
"""Helper class that contains common functionality."""

@classmethod
def setUpClass(cls):
# Determines if the TestCase is using IBMQ credentials.
cls.using_ibmq_credentials = False

# Set logging to file and stdout if the LOG_LEVEL envar is set.
cls.log = logging.getLogger(cls.__name__)
if os.getenv("LOG_LEVEL"):
filename = "%s.log" % os.path.splitext(inspect.getfile(cls))[0]
setup_test_logging(cls.log, os.getenv("LOG_LEVEL"), filename)

def tearDown(self):
# Reset the default providers, as in practice they acts as a singleton
# due to importing the instances from the top-level qiskit namespace.
from qiskit.providers.basicaer import BasicAer

BasicAer._backends = BasicAer._verify_backends()


class FullQiskitTestCase(BaseQiskitTestCase):
class FullQiskitTestCase(QiskitTestCase):
"""Helper class that contains common functionality that captures streams."""

run_tests_with = RunTest
Expand Down Expand Up @@ -191,8 +267,6 @@ def _reset(self):
# Generators to ensure unique traceback ids. Maps traceback label to
# iterators.
self._traceback_id_gens = {}
self.__setup_called = False
self.__teardown_called = False
self.__details = None

def onException(self, exc_info, tb_label="traceback"):
Expand All @@ -208,14 +282,6 @@ def onException(self, exc_info, tb_label="traceback"):
def _run_teardown(self, result):
"""Run the tearDown function for this test."""
self.tearDown()
if not self.__teardown_called:
raise ValueError(
"In File: %s\n"
"TestCase.tearDown was not called. Have you upcalled all the "
"way up the hierarchy from your tearDown? e.g. Call "
"super(%s, self).tearDown() from your tearDown()."
% (sys.modules[self.__class__.__module__].__file__, self.__class__.__name__)
)

def _get_test_method(self):
method_name = getattr(self, "_testMethodName")
Expand Down Expand Up @@ -278,14 +344,6 @@ def reraise(exc_class, exc_obj, exc_tb, _marker=object()):
def _run_setup(self, result):
"""Run the setUp function for this test."""
self.setUp()
if not self.__setup_called:
raise ValueError(
"In File: %s\n"
"TestCase.setUp was not called. Have you upcalled all the "
"way up the hierarchy from your setUp? e.g. Call "
"super(%s, self).setUp() from your setUp()."
% (sys.modules[self.__class__.__module__].__file__, self.__class__.__name__)
)

def _add_reason(self, reason):
self.addDetail("reason", content.text_content(reason))
Expand Down Expand Up @@ -342,37 +400,13 @@ def run(self, result=None):

def setUp(self):
super().setUp()
if self.__setup_called:
raise ValueError(
"In File: %s\n"
"TestCase.setUp was already called. Do not explicitly call "
"setUp from your tests. In your own setUp, use super to call "
"the base setUp." % (sys.modules[self.__class__.__module__].__file__,)
)
self.__setup_called = True
if os.environ.get("QISKIT_TEST_CAPTURE_STREAMS"):
stdout = self.useFixture(fixtures.StringStream("stdout")).stream
self.useFixture(fixtures.MonkeyPatch("sys.stdout", stdout))
stderr = self.useFixture(fixtures.StringStream("stderr")).stream
self.useFixture(fixtures.MonkeyPatch("sys.stderr", stderr))
self.useFixture(fixtures.LoggerFixture(nuke_handlers=False, level=None))

def tearDown(self):
super().tearDown()
if self.__teardown_called:
raise ValueError(
"In File: %s\n"
"TestCase.tearDown was already called. Do not explicitly call "
"tearDown from your tests. In your own tearDown, use super to "
"call the base tearDown." % (sys.modules[self.__class__.__module__].__file__,)
)
self.__teardown_called = True
# Reset the default providers, as in practice they acts as a singleton
# due to importing the instances from the top-level qiskit namespace.
from qiskit.providers.basicaer import BasicAer

BasicAer._backends = BasicAer._verify_backends()

def addDetail(self, name, content_object):
"""Add a detail to be reported with this test's outcome.

Expand Down Expand Up @@ -409,65 +443,6 @@ def getDetails(self):
self.__details = {}
return self.__details

@classmethod
def setUpClass(cls):
# Determines if the TestCase is using IBMQ credentials.
cls.using_ibmq_credentials = False
cls.log = logging.getLogger(cls.__name__)

warnings.filterwarnings("error", category=DeprecationWarning)
allow_DeprecationWarning_modules = [
"test.python.pulse.test_parameters",
"test.python.pulse.test_transforms",
"test.python.circuit.test_gate_power",
"test.python.pulse.test_builder",
"test.python.pulse.test_block",
"test.python.quantum_info.operators.symplectic.test_legacy_pauli",
"qiskit.quantum_info.operators.pauli",
"pybobyqa",
"numba",
"qiskit.utils.measurement_error_mitigation",
"qiskit.circuit.library.standard_gates.x",
"qiskit.pulse.schedule",
"qiskit.pulse.instructions.instruction",
"qiskit.pulse.instructions.play",
"qiskit.pulse.library.parametric_pulses",
"qiskit.quantum_info.operators.symplectic.pauli",
"test.python.dagcircuit.test_dagcircuit",
"test.python.quantum_info.operators.test_operator",
"test.python.quantum_info.operators.test_scalar_op",
"test.python.quantum_info.operators.test_superop",
"test.python.quantum_info.operators.channel.test_kraus",
"test.python.quantum_info.operators.channel.test_choi",
"test.python.quantum_info.operators.channel.test_chi",
"test.python.quantum_info.operators.channel.test_superop",
"test.python.quantum_info.operators.channel.test_stinespring",
"test.python.quantum_info.operators.symplectic.test_sparse_pauli_op",
"test.python.quantum_info.operators.channel.test_ptm",
]
for mod in allow_DeprecationWarning_modules:
warnings.filterwarnings("default", category=DeprecationWarning, module=mod)
allow_DeprecationWarning_message = [
r".*LogNormalDistribution.*",
r".*NormalDistribution.*",
r".*UniformDistribution.*",
r".*QuantumCircuit\.combine.*",
r".*QuantumCircuit\.__add__.*",
r".*QuantumCircuit\.__iadd__.*",
r".*QuantumCircuit\.extend.*",
r".*psi @ U.*",
r".*qiskit\.circuit\.library\.standard_gates\.ms import.*",
r"elementwise comparison failed.*",
r"The jsonschema validation included in qiskit-terra.*",
r"The DerivativeBase.parameter_expression_grad method.*",
r"Back-references to from Bit instances.*",
r"The QuantumCircuit.u. method.*",
r"The QuantumCircuit.cu.",
r"The CXDirection pass has been deprecated",
]
for msg in allow_DeprecationWarning_message:
warnings.filterwarnings("default", category=DeprecationWarning, message=msg)


def dicts_almost_equal(dict1, dict2, delta=None, places=None, default_value=0):
"""Test if two dictionaries with numeric values are almost equal.
Expand Down Expand Up @@ -528,7 +503,5 @@ def valid_comparison(value):
return ""


if not HAS_FIXTURES or not os.environ.get("QISKIT_TEST_CAPTURE_STREAMS"):
QiskitTestCase = BasicQiskitTestCase
else:
if HAS_FIXTURES:
QiskitTestCase = FullQiskitTestCase
Loading