Skip to content

Commit

Permalink
Add non_device_test pytest mark to filter out non-device tests (#29213
Browse files Browse the repository at this point in the history
)

* add conftest

* fix

* remove deselected
fxmarty authored Feb 26, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent 93f8617 commit 2a7746c
Showing 1 changed file with 46 additions and 0 deletions.
46 changes: 46 additions & 0 deletions conftest.py
Original file line number Diff line number Diff line change
@@ -21,10 +21,49 @@
from os.path import abspath, dirname, join

import _pytest
import pytest

from transformers.testing_utils import HfDoctestModule, HfDocTestParser


NOT_DEVICE_TESTS = {
"test_tokenization",
"test_processor",
"test_processing",
"test_feature_extraction",
"test_image_processing",
"test_image_processor",
"test_retrieval",
"test_config",
"test_from_pretrained_no_checkpoint",
"test_keep_in_fp32_modules",
"test_gradient_checkpointing_backward_compatibility",
"test_gradient_checkpointing_enable_disable",
"test_save_load_fast_init_from_base",
"test_fast_init_context_manager",
"test_fast_init_tied_embeddings",
"test_save_load_fast_init_to_base",
"test_torch_save_load",
"test_initialization",
"test_forward_signature",
"test_model_common_attributes",
"test_model_main_input_name",
"test_correct_missing_keys",
"test_tie_model_weights",
"test_can_use_safetensors",
"test_load_save_without_tied_weights",
"test_tied_weights_keys",
"test_model_weights_reload_no_missing_tied_weights",
"test_pt_tf_model_equivalence",
"test_mismatched_shapes_have_properly_initialized_weights",
"test_matched_shapes_have_loaded_weights_when_some_mismatched_shapes_exist",
"test_model_is_small",
"test_tf_from_pt_safetensors",
"test_flax_from_pt_safetensors",
"ModelTest::test_pipeline_", # None of the pipeline tests from PipelineTesterMixin (of which XxxModelTest inherits from) are running on device
"ModelTester::test_pipeline_",
}

# allow having multiple repository checkouts and not needing to remember to rerun
# `pip install -e '.[dev]'` when switching between checkouts and running tests.
git_repo_path = abspath(join(dirname(__file__), "src"))
@@ -46,6 +85,13 @@ def pytest_configure(config):
config.addinivalue_line("markers", "is_staging_test: mark test to run only in the staging environment")
config.addinivalue_line("markers", "accelerate_tests: mark test that require accelerate")
config.addinivalue_line("markers", "tool_tests: mark the tool tests that are run on their specific schedule")
config.addinivalue_line("markers", "not_device_test: mark the tests always running on cpu")


def pytest_collection_modifyitems(items):
for item in items:
if any(test_name in item.nodeid for test_name in NOT_DEVICE_TESTS):
item.add_marker(pytest.mark.not_device_test)


def pytest_addoption(parser):

0 comments on commit 2a7746c

Please sign in to comment.