From 033a21cd58e4e797ada5644acc9c26f543e4ed66 Mon Sep 17 00:00:00 2001 From: Lunderberg Date: Thu, 24 Jun 2021 10:47:52 -0700 Subject: [PATCH] [UnitTests] Automatic parametrization over targets, with explicit opt-out (#8010) * [UnitTests] Explicitly list tests that were enabled by TVM_TEST_TARGETS but were skipped Previously, these were removed by a filter in tvm.testing._get_targets(), and weren't listed at all. With this change, they are instead removed by pytest.skipif, and show up as explicitly skipped tests in pytest's summary when using tvm.testing.parametrize_targets. * [UnitTests] Automatic parametrize_targets for tests that use (target,dev) Should make it easier to convert tests from using tvm.testing.enabled_targets to use pytest's parametrized tests instead. * [UnitTests] Added ability to explicitly exclude a target from a particular test Uses tvm_exclude_targets variable, which can be set (1) in the conftest.py to apply to a test directory, (2) in a test script to apply to that module, or (3) on an individual test function to apply to it. The @tvm.testing.exclude_targets decorator is provided for readability in case #3. * [UnitTests] Refactored test_topi_relu.py to use pytest.mark.parametrize * [UnitTests] Added tvm_known_failing_targets option for the unittests. Intended to mark tests that fail for a particular target, and are intended to be fixed in the future. Typically, these would result either from implementing a new test, or from an in-progress implementation of a new target. * [UnitTests] Known failing targets now marked with xfail instead of skipif * [UnitTests] Removed tvm_excluded_targets and tvm_known_failing_targets These were implemented to exclude or mark as failing an entire file or directory of tests. In https://discuss.tvm.apache.org/t/rfc-parametrized-unit-tests/9946/4, it was pointed out that the global variables would be vulnerable to typos in the names, resulting in the option being silently ignored. The decorators `@tvm.testing.exclude_targets` and `@tvm.testing.known_failing_targets` do not have this failure mode, and are the preferred version. * [UnitTests] Added helper functions to tvm.testing. - tvm.testing.parameter() defines a parameter that can be passed to tests. Tests that accept more than one parameter are run for all combinations of parameter values. - tvm.testing.parameters() defines multiple sets of parameter values. Tests that accept more than one parameter are run once for each set of parameter values. - tvm.testing.fixture() is a decorator that defines setup code. The `cache=True` argument can be passed to avoid repeating expensive setup across multiple tests. * [UnitTests] Bugfix for auto parametrizing of "target" Previously, if the @parametrize_targets were present, but had other @pytest.mark.parametrize after it, "target" would get parametrized a second time. Now, it checks more than just the closest "parametrize" marker. * [UnitTests] Renamed "cache" argument of tvm.testing.fixture to "cache_return_value" * [UnitTests] Minor updates to parametrized test implementation. As recommended by @tkonolige: - Avoid infinite loop if LLVM target isn't enabled - Update documentation for preferred use cases of tvm.testing.parametrize_targets, and recommended alternatives. * [UnitTests] Minor updates to parametrized test implementation - Documentation, removed previous example usage of tvm.testing.parametrize_targets * [UnitTests] Changed accidental use of pytest fixtures to a NameError. 
- Previously, a fixture function defined in a module was accessible through the global scope, and the function definition is accessible if a test function uses that name but fails to declare the fixture as a parameter. Now, it will result in a NameError instead. * [UnitTests] More careful removal of fixture functions from module global scope. - Initial implementation only checked hasattr(obj, "_pytestfixturefunction") before removing obj, which gave false positives for objects that implement __getattr__, such as caffe.layers. Now, also check that the value contained is a FixtureFunctionMarker. * [UnitTests] Copy cached values when using tvm.testing.fixture(cache_return_value=True) To avoid unit tests being able to influence each other through a shared cache, all cached fixtures are passed through copy.deepcopy prior to use. * [UnitTests] Added meta-tests for tvm.testing functionality Co-authored-by: Eric Lunderberg --- conftest.py | 20 +- python/tvm/testing.py | 600 ++++++++++++++++-- tests/python/topi/python/test_topi_relu.py | 77 +-- .../unittest/test_tvm_testing_features.py | 149 +++++ 4 files changed, 760 insertions(+), 86 deletions(-) create mode 100644 tests/python/unittest/test_tvm_testing_features.py diff --git a/conftest.py b/conftest.py index 133a8322ea5d..f591fe970de8 100644 --- a/conftest.py +++ b/conftest.py @@ -14,15 +14,33 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import tvm.testing +import pytest from pytest import ExitCode +import tvm +import tvm.testing + def pytest_configure(config): print("enabled targets:", "; ".join(map(lambda x: x[0], tvm.testing.enabled_targets()))) print("pytest marker:", config.option.markexpr) +@pytest.fixture +def dev(target): + return tvm.device(target) + + +def pytest_generate_tests(metafunc): + tvm.testing._auto_parametrize_target(metafunc) + tvm.testing._parametrize_correlated_parameters(metafunc) + + +def pytest_collection_modifyitems(config, items): + tvm.testing._count_num_fixture_uses(items) + tvm.testing._remove_global_fixture_definitions(items) + + def pytest_sessionfinish(session, exitstatus): # Don't exit with an error if we select a subset of tests that doesn't # include anything diff --git a/python/tvm/testing.py b/python/tvm/testing.py index 1c3346169c6b..8178b0a14b29 100644 --- a/python/tvm/testing.py +++ b/python/tvm/testing.py @@ -54,11 +54,16 @@ def test_something(): function in this module. Then targets using this node should be added to the `TVM_TEST_TARGETS` environment variable in the CI. 
""" +import collections +import copy +import functools import logging import os import sys import time +import pickle import pytest +import _pytest import numpy as np import tvm import tvm.arith @@ -66,6 +71,7 @@ def test_something(): import tvm.te import tvm._ffi from tvm.contrib import nvcc +from tvm.error import TVMError def assert_allclose(actual, desired, rtol=1e-7, atol=1e-7): @@ -366,24 +372,44 @@ def _check_forward(constraints1, constraints2, varmap, backvarmap): ) -def _get_targets(): - target_str = os.environ.get("TVM_TEST_TARGETS", "") +def _get_targets(target_str=None): + if target_str is None: + target_str = os.environ.get("TVM_TEST_TARGETS", "") + if len(target_str) == 0: target_str = DEFAULT_TEST_TARGETS - targets = set() - for dev in target_str.split(";"): - if len(dev) == 0: - continue - target_kind = dev.split()[0] - if tvm.runtime.enabled(target_kind) and tvm.device(target_kind, 0).exist: - targets.add(dev) - if len(targets) == 0: - logging.warning( + + target_names = set(t.strip() for t in target_str.split(";") if t.strip()) + + targets = [] + for target in target_names: + target_kind = target.split()[0] + is_enabled = tvm.runtime.enabled(target_kind) + is_runnable = is_enabled and tvm.device(target_kind).exist + targets.append( + { + "target": target, + "target_kind": target_kind, + "is_enabled": is_enabled, + "is_runnable": is_runnable, + } + ) + + if all(not t["is_runnable"] for t in targets): + if tvm.runtime.enabled("llvm"): + logging.warning( + "None of the following targets are supported by this build of TVM: %s." + " Try setting TVM_TEST_TARGETS to a supported target. Defaulting to llvm.", + target_str, + ) + return _get_targets("llvm") + + raise TVMError( "None of the following targets are supported by this build of TVM: %s." - " Try setting TVM_TEST_TARGETS to a supported target. Defaulting to llvm.", - target_str, + " Try setting TVM_TEST_TARGETS to a supported target." + " Cannot default to llvm, as it is not enabled." % target_str ) - return {"llvm"} + return targets @@ -425,22 +451,22 @@ def device_enabled(target): nodes and `target="llvm"` on cpu test nodes. """ assert isinstance(target, str), "device_enabled requires a target as a string" - target_kind = target.split(" ")[ - 0 - ] # only check if device name is found, sometime there are extra flags - return any([target_kind in test_target for test_target in _get_targets()]) + # only check if device name is found, sometime there are extra flags + target_kind = target.split(" ")[0] + return any(target_kind == t["target_kind"] for t in _get_targets() if t["is_runnable"]) def enabled_targets(): - """Get all enabled targets with associated contexts. + """Get all enabled targets with associated devices. In most cases, you should use :py:func:`tvm.testing.parametrize_targets` instead of this function. - In this context, enabled means that TVM was built with support for this - target and the target name appears in the TVM_TEST_TARGETS environment - variable. If TVM_TEST_TARGETS is not set, it defaults to variable - DEFAULT_TEST_TARGETS in this module. + In this context, enabled means that TVM was built with support for + this target, the target name appears in the TVM_TEST_TARGETS + environment variable, and a suitable device for running this + target exists. If TVM_TEST_TARGETS is not set, it defaults to + variable DEFAULT_TEST_TARGETS in this module. If you use this function in a test, you **must** decorate the test with :py:func:`tvm.testing.uses_gpu` (otherwise it will never be run on the gpu). 
@@ -449,8 +475,9 @@ def enabled_targets(): ------- targets: list A list of pairs of all enabled devices and the associated context + """ - return [(tgt, tvm.device(tgt)) for tgt in _get_targets()] + return [(t["target"], tvm.device(t["target"])) for t in _get_targets() if t["is_runnable"]] def _compose(args, decs): @@ -521,6 +548,26 @@ def requires_cuda(*args): return _compose(args, _requires_cuda) +def requires_nvptx(*args): + """Mark a test as requiring the NVPTX compilation on the CUDA runtime + + This also marks the test as requiring a cuda gpu, and requiring + LLVM support. + + Parameters + ---------- + f : function + Function to mark + + """ + _requires_nvptx = [ + pytest.mark.skipif(not device_enabled("nvptx"), reason="NVPTX support not enabled"), + *requires_llvm(), + *requires_gpu(), + ] + return _compose(args, _requires_nvptx) + + def requires_cudagraph(*args): """Mark a test as requiring the CUDA Graph Feature @@ -691,7 +738,7 @@ def _target_to_requirement(target): if target.startswith("vulkan"): return requires_vulkan() if target.startswith("nvptx"): - return [*requires_llvm(), *requires_gpu()] + return requires_nvptx() if target.startswith("metal"): return requires_metal() if target.startswith("opencl"): @@ -701,11 +748,90 @@ def _target_to_requirement(target): return [] -def parametrize_targets(*args): - """Parametrize a test over all enabled targets. +def _pytest_target_params(targets, excluded_targets=None, xfail_targets=None): + # Include unrunnable targets here. They get skipped by the + # pytest.mark.skipif in _target_to_requirement(), showing up as + # skipped tests instead of being hidden entirely. + if targets is None: + if excluded_targets is None: + excluded_targets = set() + + if xfail_targets is None: + xfail_targets = set() + + target_marks = [] + for t in _get_targets(): + # Excluded targets aren't included in the params at all. + if t["target_kind"] not in excluded_targets: + + # Known failing targets are included, but are marked + # as expected to fail. + extra_marks = [] + if t["target_kind"] in xfail_targets: + extra_marks.append( + pytest.mark.xfail( + reason='Known failing test for target "{}"'.format(t["target_kind"]) + ) + ) + + target_marks.append((t["target"], extra_marks)) + + else: + target_marks = [(target, []) for target in targets] + + return [ + pytest.param(target, marks=_target_to_requirement(target) + extra_marks) + for target, extra_marks in target_marks + ] + + +def _auto_parametrize_target(metafunc): + """Automatically applies parametrize_targets + + Used if a test function uses the "target" fixture, but isn't + already marked with @tvm.testing.parametrize_targets. Intended + for use in the pytest_generate_tests() handler of a conftest.py + file. + + """ + if "target" in metafunc.fixturenames: + parametrized_args = [ + arg.strip() + for mark in metafunc.definition.iter_markers("parametrize") + for arg in mark.args[0].split(",") + ] + + if "target" not in parametrized_args: + # Check if the function is marked with either excluded or + # known failing targets. + excluded_targets = getattr(metafunc.function, "tvm_excluded_targets", []) + xfail_targets = getattr(metafunc.function, "tvm_known_failing_targets", []) + metafunc.parametrize( + "target", + _pytest_target_params(None, excluded_targets, xfail_targets), + scope="session", + ) + - Use this decorator when you want your test to be run over a variety of - targets and devices (including cpu and gpu devices). 
+def parametrize_targets(*args): + """Parametrize a test over a specific set of targets. + + Use this decorator when you want your test to be run over a + specific set of targets and devices. It is intended for use where + a test is applicable only to a specific target, and is + inapplicable to any others (e.g. verifying target-specific + assembly code matches known assembly code). In most + circumstances, :py:func:`tvm.testing.exclude_targets` or + :py:func:`tvm.testing.known_failing_targets` should be used + instead. + + If used as a decorator without arguments, the test will be + parametrized over all targets in + :py:func:`tvm.testing.enabled_targets`. This behavior is + automatically enabled for any target that accepts arguments of + ``target`` or ``dev``, so the explicit use of the bare decorator + is no longer needed, and is maintained for backwards + compatibility. Parameters ---------- @@ -718,31 +844,421 @@ def parametrize_targets(*args): Example ------- - >>> @tvm.testing.parametrize + >>> @tvm.testing.parametrize_targets("llvm", "cuda") + >>> def test_mytest(target, dev): + >>> ... # do something + """ + + def wrap(targets): + def func(f): + return pytest.mark.parametrize( + "target", _pytest_target_params(targets), scope="session" + )(f) + + return func + + if len(args) == 1 and callable(args[0]): + return wrap(None)(args[0]) + return wrap(args) + + +def exclude_targets(*args): + """Exclude a test from running on a particular target. + + Use this decorator when you want your test to be run over a + variety of targets and devices (including cpu and gpu devices), + but want to exclude some particular target or targets. For + example, a test may wish to be run against all targets in + tvm.testing.enabled_targets(), except for a particular target that + does not support the capabilities. + + Applies pytest.mark.skipif to the targets given. + + Parameters + ---------- + f : function + Function to parametrize. Must be of the form `def test_xxxxxxxxx(target, dev)`:, + where `xxxxxxxxx` is any name. + targets : list[str] + Set of targets to exclude. + + Example + ------- + >>> @tvm.testing.exclude_targets("cuda") >>> def test_mytest(target, dev): >>> ... # do something Or - >>> @tvm.testing.parametrize("llvm", "cuda") + >>> @tvm.testing.exclude_targets("llvm", "cuda") >>> def test_mytest(target, dev): >>> ... # do something + """ - def wrap(targets): - def func(f): - params = [ - pytest.param(target, tvm.device(target, 0), marks=_target_to_requirement(target)) - for target in targets - ] - return pytest.mark.parametrize("target,dev", params)(f) + def wraps(func): + func.tvm_excluded_targets = args + return func + + return wraps + + +def known_failing_targets(*args): + """Skip a test that is known to fail on a particular target. + + Use this decorator when you want your test to be run over a + variety of targets and devices (including cpu and gpu devices), + but know that it fails for some targets. For example, a newly + implemented runtime may not support all features being tested, and + should be excluded. + + Applies pytest.mark.xfail to the targets given. + Parameters + ---------- + f : function + Function to parametrize. Must be of the form `def test_xxxxxxxxx(target, dev)`:, + where `xxxxxxxxx` is any name. + targets : list[str] + Set of targets to skip. + + Example + ------- + >>> @tvm.testing.known_failing_targets("cuda") + >>> def test_mytest(target, dev): + >>> ... 
# do something + + Or + + >>> @tvm.testing.known_failing_targets("llvm", "cuda") + >>> def test_mytest(target, dev): + >>> ... # do something + + """ + + def wraps(func): + func.tvm_known_failing_targets = args return func - if len(args) == 1 and callable(args[0]): - targets = [t for t, _ in enabled_targets()] - return wrap(targets)(args[0]) - return wrap(args) + return wraps + + +def parameter(*values, ids=None): + """Convenience function to define pytest parametrized fixtures. + + Declaring a variable using ``tvm.testing.parameter`` will define a + parametrized pytest fixture that can be used by test + functions. This is intended for cases that have no setup cost, + such as strings, integers, tuples, etc. For cases that have a + significant setup cost, please use :py:func:`tvm.testing.fixture` + instead. + + If a test function accepts multiple parameters defined using + ``tvm.testing.parameter``, then the test will be run using every + combination of those parameters. + + The parameter definition applies to all tests in a module. If a + specific test should have different values for the parameter, that + test should be marked with ``@pytest.mark.parametrize``. + + Parameters + ---------- + values + A list of parameter values. A unit test that accepts this + parameter as an argument will be run once for each parameter + given. + + ids : List[str], optional + A list of names for the parameters. If None, pytest will + generate a name from the value. These generated names may not + be readable/useful for composite types such as tuples. + + Returns + ------- + function + A function output from pytest.fixture. + + Example + ------- + >>> size = tvm.testing.parameter(1, 10, 100) + >>> def test_using_size(size): + >>> ... # Test code here + + Or + + >>> shape = tvm.testing.parameter((5,10), (512,1024), ids=['small','large']) + >>> def test_using_size(shape): + >>> ... # Test code here + + """ + + # Optional cls parameter in case a parameter is defined inside a + # class scope. + @pytest.fixture(params=values, ids=ids) + def as_fixture(*_cls, request): + return request.param + + return as_fixture + + +_parametrize_group = 0 + + +def parameters(*value_sets): + """Convenience function to define pytest parametrized fixtures. + + Declaring a variable using tvm.testing.parameters will define a + parametrized pytest fixture that can be used by test + functions. Like :py:func:`tvm.testing.parameter`, this is intended + for cases that have no setup cost, such as strings, integers, + tuples, etc. For cases that have a significant setup cost, please + use :py:func:`tvm.testing.fixture` instead. + + Unlike :py:func:`tvm.testing.parameter`, if a test function + accepts multiple parameters defined using a single call to + ``tvm.testing.parameters``, then the test will only be run once + for each set of parameters, not for all combinations of + parameters. + + These parameter definitions apply to all tests in a module. If a + specific test should have different values for some parameters, + that test should be marked with ``@pytest.mark.parametrize``. + + Parameters + ---------- + values : List[tuple] + A list of parameter value sets. Each set of values represents + a single combination of values to be tested. A unit test that + accepts parameters defined will be run once for every set of + parameters in the list. + + Returns + ------- + List[function] + Function outputs from pytest.fixture. These should be unpacked + into individual named parameters. 
+ + Example + ------- + >>> size, dtype = tvm.testing.parameters( (16,'float32'), (512,'float16') ) + >>> def test_feature_x(size, dtype): + >>> # Test code here + >>> assert( (size,dtype) in [(16,'float32'), (512,'float16')]) + + """ + global _parametrize_group + parametrize_group = _parametrize_group + _parametrize_group += 1 + + outputs = [] + for param_values in zip(*value_sets): + + # Optional cls parameter in case a parameter is defined inside a + # class scope. + def fixture_func(*_cls, request): + return request.param + + fixture_func.parametrize_group = parametrize_group + fixture_func.parametrize_values = param_values + outputs.append(pytest.fixture(fixture_func)) + + return outputs + + +def _parametrize_correlated_parameters(metafunc): + parametrize_needed = collections.defaultdict(list) + + for name, fixturedefs in metafunc.definition._fixtureinfo.name2fixturedefs.items(): + fixturedef = fixturedefs[-1] + if hasattr(fixturedef.func, "parametrize_group") and hasattr( + fixturedef.func, "parametrize_values" + ): + group = fixturedef.func.parametrize_group + values = fixturedef.func.parametrize_values + parametrize_needed[group].append((name, values)) + + for parametrize_group in parametrize_needed.values(): + if len(parametrize_group) == 1: + name, values = parametrize_group[0] + metafunc.parametrize(name, values, indirect=True) + else: + names = ",".join(name for name, values in parametrize_group) + value_sets = zip(*[values for name, values in parametrize_group]) + metafunc.parametrize(names, value_sets, indirect=True) + + +def fixture(func=None, *, cache_return_value=False): + """Convenience function to define pytest fixtures. + + This should be used as a decorator to mark functions that set up + state before a function. The return value of that fixture + function is then accessible by test functions as that accept it as + a parameter. + + Fixture functions can accept parameters defined with + :py:func:`tvm.testing.parameter`. + + By default, the setup will be performed once for each unit test + that uses a fixture, to ensure that unit tests are independent. + If the setup is expensive to perform, then the + cache_return_value=True argument can be passed to cache the setup. + The fixture function will be run only once (or once per parameter, + if used with tvm.testing.parameter), and the same return value + will be passed to all tests that use it. If the environment + variable TVM_TEST_DISABLE_CACHE is set to a non-zero value, it + will disable this feature and no caching will be performed. + + Example + ------- + >>> @tvm.testing.fixture + >>> def cheap_setup(): + >>> return 5 # Setup code here. + >>> + >>> def test_feature_x(target, dev, cheap_setup) + >>> assert(cheap_setup == 5) # Run test here + + Or + + >>> size = tvm.testing.parameter(1, 10, 100) + >>> + >>> @tvm.testing.fixture + >>> def cheap_setup(size): + >>> return 5*size # Setup code here, based on size. + >>> + >>> def test_feature_x(cheap_setup): + >>> assert(cheap_setup in [5, 50, 500]) + + Or + + >>> @tvm.testing.fixture(cache_return_value=True) + >>> def expensive_setup(): + >>> time.sleep(10) # Setup code here + >>> return 5 + >>> + >>> def test_feature_x(target, dev, expensive_setup): + >>> assert(expensive_setup == 5) + + """ + + force_disable_cache = bool(int(os.environ.get("TVM_TEST_DISABLE_CACHE", "0"))) + cache_return_value = cache_return_value and not force_disable_cache + + # Deliberately at function scope, so that caching can track how + # many times the fixture has been used. 
If used, the cache gets + # cleared after the fixture is no longer needed. + scope = "function" + + def wraps(func): + if cache_return_value: + func = _fixture_cache(func) + func = pytest.fixture(func, scope=scope) + return func + + if func is None: + return wraps + + return wraps(func) + + +def _fixture_cache(func): + cache = {} + + # Can't use += on a bound method's property. Therefore, this is a + # list rather than a variable so that it can be accessed from the + # pytest_collection_modifyitems(). + num_uses_remaining = [0] + + # Using functools.lru_cache would require the function arguments + # to be hashable, which wouldn't allow caching fixtures that + # depend on numpy arrays. For example, a fixture that takes a + # numpy array as input, then uses a slow method to + # compute a known correct output for that input. Therefore, + # including a fallback for serializable types. + def get_cache_key(*args, **kwargs): + try: + hash((args, kwargs)) + return (args, kwargs) + except TypeError as e: + pass + + try: + return pickle.dumps((args, kwargs)) + except TypeError as e: + raise TypeError( + "TVM caching of fixtures requires arguments to the fixture " + "to be either hashable or serializable" + ) from e + + @functools.wraps(func) + def wrapper(*args, **kwargs): + try: + cache_key = get_cache_key(*args, **kwargs) + + try: + cached_value = cache[cache_key] + except KeyError: + cached_value = cache[cache_key] = func(*args, **kwargs) + + try: + yield copy.deepcopy(cached_value) + except TypeError as e: + rfc_url = ( + "https://github.com/apache/tvm-rfcs/blob/main/rfcs/" + "0007-parametrized-unit-tests.md#unresolved-questions" + ) + message = ( + "TVM caching of fixtures can only be used on serializable data types, not {}.\n" + "Please see {} for details/discussion." + ).format(type(cached_value), rfc_url) + raise TypeError(message) from e + + finally: + # Clear the cache once all tests that use a particular fixture + # have completed. + num_uses_remaining[0] -= 1 + if not num_uses_remaining[0]: + cache.clear() + + # Set in the pytest_collection_modifyitems() + wrapper.num_uses_remaining = num_uses_remaining + + return wrapper + + +def _count_num_fixture_uses(items): + # Helper function, counts the number of tests that use each cached + # fixture. Should be called from pytest_collection_modifyitems(). + for item in items: + is_skipped = item.get_closest_marker("skip") or any( + mark.args[0] for mark in item.iter_markers("skipif") + ) + if is_skipped: + continue + + for fixturedefs in item._fixtureinfo.name2fixturedefs.values(): + # Only increment the active fixturedef, in case a name has been overridden. + fixturedef = fixturedefs[-1] + if hasattr(fixturedef.func, "num_uses_remaining"): + fixturedef.func.num_uses_remaining[0] += 1 + + +def _remove_global_fixture_definitions(items): + # Helper function, removes fixture definitions from the global + # variables of the modules they were defined in. This is intended + # to improve readability of error messages by giving a NameError + # if a test function accesses a pytest fixture but doesn't include + # it as an argument. Should be called from + # pytest_collection_modifyitems().
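+ # Note that an object is removed only if its "_pytestfixturefunction" attribute + # is an actual FixtureFunctionMarker. Checking hasattr alone gave false + # positives for objects that define __getattr__, such as caffe.layers.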
+ + modules = set(item.module for item in items) + + for module in modules: + for name in dir(module): + obj = getattr(module, name) + if hasattr(obj, "_pytestfixturefunction") and isinstance( + obj._pytestfixturefunction, _pytest.fixtures.FixtureFunctionMarker + ): + delattr(module, name) def identity_after(x, sleep): diff --git a/tests/python/topi/python/test_topi_relu.py b/tests/python/topi/python/test_topi_relu.py index 3dc6e7de8069..83007e16f81d 100644 --- a/tests/python/topi/python/test_topi_relu.py +++ b/tests/python/topi/python/test_topi_relu.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. """Test code for relu activation""" +import sys import os import numpy as np import tvm @@ -24,36 +25,43 @@ from tvm.topi.utils import get_const_tuple from tvm.contrib.nvcc import have_fp16 +import pytest import tvm.testing -def verify_relu(m, n, dtype="float32"): +m, n, dtype = tvm.testing.parameters( + (10, 128, "float32"), + (128, 64, "float16"), + (1024 * 100, 512, "float32"), +) + + +def test_relu(target, dev, m, n, dtype): A = te.placeholder((m, n), name="A", dtype=dtype) B = topi.nn.relu(A) a_np = np.random.uniform(low=-1.0, high=1.0, size=get_const_tuple(A.shape)).astype(A.dtype) b_np = a_np * (a_np > 0) - def check_target(target, dev): - if dtype == "float16" and target == "cuda" and not have_fp16(tvm.cuda(0).compute_version): - print("Skip because %s does not have fp16 support" % target) - return - print("Running on target: %s" % target) - with tvm.target.Target(target): - s = tvm.topi.testing.get_elemwise_schedule(target)(B) + if dtype == "float16" and target == "cuda" and not have_fp16(tvm.cuda(0).compute_version): + pytest.skip("Skip because %s does not have fp16 support" % target) - a = tvm.nd.array(a_np, dev) - b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), dev) - foo = tvm.build(s, [A, B], target, name="relu") - foo(a, b) - tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-5) + print("Running on target: %s" % target) + with tvm.target.Target(target): + s = tvm.topi.testing.get_elemwise_schedule(target)(B) - for target, dev in tvm.testing.enabled_targets(): - check_target(target, dev) + a = tvm.nd.array(a_np, dev) + b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), dev) + foo = tvm.build(s, [A, B], target, name="relu") + foo(a, b) + tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) -def verify_leaky_relu(m, alpha): - A = te.placeholder((m,), name="A") +size, alpha = tvm.testing.parameters((100, 0.1)) + + +def test_leaky_relu(size, alpha): + A = te.placeholder((size,), name="A") B = topi.nn.leaky_relu(A, alpha) s = te.create_schedule([B.op]) @@ -67,7 +75,14 @@ def verify_leaky_relu(m, alpha): tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-5) -def verify_prelu(x, w, axis, weight_reshape): +x, w, axis, weight_reshape = tvm.testing.parameters( + ((1, 3, 2, 2), (3,), 1, (3, 1, 1)), + ((1, 3, 2, 2), (2,), 2, (2, 1)), + ((1, 3), (3,), 1, (3,)), +) + + +def test_prelu(x, w, axis, weight_reshape): X = te.placeholder((x), name="X") W = te.placeholder((w), name="W") x_np = np.random.uniform(low=-1.0, high=1.0, size=get_const_tuple(X.shape)).astype(X.dtype) @@ -90,29 +105,5 @@ def _prelu_numpy(x, W): tvm.testing.assert_allclose(b.numpy(), out_np, rtol=1e-5) -@tvm.testing.uses_gpu -def test_relu(): - verify_relu(10, 128, "float32") - verify_relu(128, 64, "float16") - - -@tvm.testing.uses_gpu -def test_schedule_big_array(): - verify_relu(1024 * 100, 512) - - -def test_leaky_relu(): - 
verify_leaky_relu(100, 0.1) - - -def test_prelu(): - verify_prelu((1, 3, 2, 2), (3,), 1, (3, 1, 1)) - verify_prelu((1, 3, 2, 2), (2,), 2, (2, 1)) - verify_prelu((1, 3), (3,), 1, (3,)) - - if __name__ == "__main__": - test_schedule_big_array() - test_relu() - test_leaky_relu() - test_prelu() + sys.exit(pytest.main(sys.argv)) diff --git a/tests/python/unittest/test_tvm_testing_features.py b/tests/python/unittest/test_tvm_testing_features.py new file mode 100644 index 000000000000..1a7595aac5c7 --- /dev/null +++ b/tests/python/unittest/test_tvm_testing_features.py @@ -0,0 +1,149 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import os +import sys + +import pytest + +import tvm.testing + +# This file tests features in tvm.testing, such as verifying that +# cached fixtures are run an appropriate number of times. As a +# result, the order of the tests is important. Use of --last-failed +# or --failed-first while debugging this file is not advised. + + +class TestTargetAutoParametrization: + targets_used = [] + devices_used = [] + enabled_targets = [target for target, dev in tvm.testing.enabled_targets()] + enabled_devices = [dev for target, dev in tvm.testing.enabled_targets()] + + def test_target_parametrization(self, target): + assert target in self.enabled_targets + self.targets_used.append(target) + + def test_device_parametrization(self, dev): + assert dev in self.enabled_devices + self.devices_used.append(dev) + + def test_all_targets_used(self): + assert self.targets_used == self.enabled_targets + assert self.devices_used == self.enabled_devices + + targets_with_explicit_list = [] + + @tvm.testing.parametrize_targets("llvm") + def test_explicit_list(self, target): + assert target == "llvm" + self.targets_with_explicit_list.append(target) + + def test_no_repeats_in_explicit_list(self): + assert self.targets_with_explicit_list == ["llvm"] + + targets_with_exclusion = [] + + @tvm.testing.exclude_targets("llvm") + def test_exclude_target(self, target): + assert "llvm" not in target + self.targets_with_exclusion.append(target) + + def test_all_nonexcluded_targets_ran(self): + assert self.targets_with_exclusion == [ + target for target in self.enabled_targets if not target.startswith("llvm") + ] + + run_targets_with_known_failure = [] + + @tvm.testing.known_failing_targets("llvm") + def test_known_failing_target(self, target): + # This test runs for all targets, but intentionally fails for + # llvm. The behavior is working correctly if this test shows + # up as an expected failure, xfail. 
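+ # The target is recorded before the failing assertion so that + # test_all_targets_ran below can verify that the xfail targets were + # still executed.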
+ self.run_targets_with_known_failure.append(target) + assert "llvm" not in target + + def test_all_targets_ran(self): + assert self.run_targets_with_known_failure == self.enabled_targets + + +class TestJointParameter: + param1_vals = [1, 2, 3] + param2_vals = ["a", "b", "c"] + + independent_usages = 0 + param1 = tvm.testing.parameter(*param1_vals) + param2 = tvm.testing.parameter(*param2_vals) + + joint_usages = 0 + joint_param_vals = list(zip(param1_vals, param2_vals)) + joint_param1, joint_param2 = tvm.testing.parameters(*joint_param_vals) + + def test_using_independent(self, param1, param2): + type(self).independent_usages += 1 + + def test_independent(self): + assert self.independent_usages == len(self.param1_vals) * len(self.param2_vals) + + def test_using_joint(self, joint_param1, joint_param2): + type(self).joint_usages += 1 + assert (joint_param1, joint_param2) in self.joint_param_vals + + def test_joint(self): + assert self.joint_usages == len(self.joint_param_vals) + + +class TestFixtureCaching: + param1_vals = [1, 2, 3] + param2_vals = ["a", "b", "c"] + + param1 = tvm.testing.parameter(*param1_vals) + param2 = tvm.testing.parameter(*param2_vals) + + uncached_calls = 0 + cached_calls = 0 + + @tvm.testing.fixture + def uncached_fixture(self, param1): + type(self).uncached_calls += 1 + return 2 * param1 + + def test_use_uncached(self, param1, param2, uncached_fixture): + assert 2 * param1 == uncached_fixture + + def test_uncached_count(self): + assert self.uncached_calls == len(self.param1_vals) * len(self.param2_vals) + + @tvm.testing.fixture(cache_return_value=True) + def cached_fixture(self, param1): + type(self).cached_calls += 1 + return 3 * param1 + + def test_use_cached(self, param1, param2, cached_fixture): + assert 3 * param1 == cached_fixture + + def test_cached_count(self): + cache_disabled = bool(int(os.environ.get("TVM_TEST_DISABLE_CACHE", "0"))) + if cache_disabled: + assert self.cached_calls == len(self.param1_vals) * len(self.param2_vals) + else: + assert self.cached_calls == len(self.param1_vals) + + +if __name__ == "__main__": + sys.exit(pytest.main(sys.argv))
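For reference, a minimal sketch (not part of the patch) of how the helpers introduced here combine in a test module, assuming the conftest.py hooks above are installed; the module contents, fixture name, and test name below are hypothetical:

import numpy as np

import tvm
import tvm.testing

# Parametrized value available to every test in this module.
size = tvm.testing.parameter(16, 256)


@tvm.testing.fixture(cache_return_value=True)
def input_data(size):
    # Expensive setup runs once per "size"; each test receives a deep copy
    # of the cached value.
    return np.random.uniform(low=-1.0, high=1.0, size=(size,)).astype("float32")


# "target" and "dev" are parametrized automatically over the enabled targets;
# the decorators opt cuda out entirely and mark vulkan as a known failure.
@tvm.testing.exclude_targets("cuda")
@tvm.testing.known_failing_targets("vulkan")
def test_feature(target, dev, size, input_data):
    assert input_data.shape == (size,)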