diff --git a/docs/changelog.rst b/docs/changelog.rst index 990ec51c..5e31413a 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -9,6 +9,10 @@ Compatibility * Added official support for Python 3.13. +* Added ``using`` argument to :fixture:`django_assert_num_queries` and + :fixture:`django_assert_max_num_queries` to easily specify the database + alias to use. + Bugfixes ^^^^^^^^ diff --git a/docs/helpers.rst b/docs/helpers.rst index e5e4ed36..c9e189dd 100644 --- a/docs/helpers.rst +++ b/docs/helpers.rst @@ -423,11 +423,12 @@ Example ``django_assert_num_queries`` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. py:function:: django_assert_num_queries(num, connection=None, info=None) +.. py:function:: django_assert_num_queries(num, connection=None, info=None, *, using=None) :param num: expected number of queries - :param connection: optional non-default DB connection + :param connection: optional database connection :param str info: optional info message to display on failure + :param str using: optional database alias This fixture allows to check for an expected number of DB queries. @@ -462,11 +463,12 @@ If you use type annotations, you can annotate the fixture like this:: ``django_assert_max_num_queries`` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. py:function:: django_assert_max_num_queries(num, connection=None, info=None) +.. py:function:: django_assert_max_num_queries(num, connection=None, info=None, *, using=None) :param num: expected maximum number of queries - :param connection: optional non-default DB connection + :param connection: optional database connection :param str info: optional info message to display on failure + :param str using: optional database alias This fixture allows to check for an expected maximum number of DB queries. diff --git a/pytest_django/fixtures.py b/pytest_django/fixtures.py index 1daab118..39c84b60 100644 --- a/pytest_django/fixtures.py +++ b/pytest_django/fixtures.py @@ -606,6 +606,8 @@ def __call__( num: int, connection: Any | None = ..., info: str | None = ..., + *, + using: str | None = ..., ) -> django.test.utils.CaptureQueriesContext: pass # pragma: no cover @@ -617,13 +619,21 @@ def _assert_num_queries( exact: bool = True, connection: Any | None = None, info: str | None = None, + *, + using: str | None = None, ) -> Generator[django.test.utils.CaptureQueriesContext, None, None]: + from django.db import connection as default_conn, connections from django.test.utils import CaptureQueriesContext - if connection is None: - from django.db import connection as conn - else: + if connection and using: + raise ValueError('The "connection" and "using" parameter cannot be used together') + + if connection is not None: conn = connection + elif using is not None: + conn = connections[using] + else: + conn = default_conn verbose = config.getoption("verbose") > 0 with CaptureQueriesContext(conn) as context: diff --git a/tests/test_fixtures.py b/tests/test_fixtures.py index dd695136..39c6666f 100644 --- a/tests/test_fixtures.py +++ b/tests/test_fixtures.py @@ -15,6 +15,7 @@ from django.core import mail from django.db import connection, transaction from django.test import AsyncClient, AsyncRequestFactory, Client, RequestFactory +from django.utils.connection import ConnectionDoesNotExist from django.utils.encoding import force_str from .helpers import DjangoPytester @@ -206,6 +207,28 @@ def test_django_assert_num_queries_db_connection( pass +@pytest.mark.django_db +def test_django_assert_num_queries_db_using( + django_assert_num_queries: DjangoAssertNumQueries, +) -> None: + from django.db import connection + + with django_assert_num_queries(1, using="default"): + Item.objects.create(name="foo") + + error_message = 'The "connection" and "using" parameter cannot be used together' + with pytest.raises(ValueError, match=error_message): + with django_assert_num_queries(1, connection=connection, using="default"): + Item.objects.create(name="foo") + + with django_assert_num_queries(1, using=None): + Item.objects.create(name="foo") + + with pytest.raises(ConnectionDoesNotExist): + with django_assert_num_queries(1, using="bad_db_name"): + pass + + @pytest.mark.django_db def test_django_assert_num_queries_output_info(django_pytester: DjangoPytester) -> None: django_pytester.create_test_module(