diff --git a/TESTING.rst b/TESTING.rst index 6836ec701..26b95eec7 100644 --- a/TESTING.rst +++ b/TESTING.rst @@ -255,7 +255,7 @@ add them in ``tests/charts``. .. code-block:: python - class TestBaseChartTest(unittest.TestCase): + class TestBaseChartTest: ... To render the chart create a YAML string with the nested dictionary of options you wish to test. You can then @@ -277,7 +277,7 @@ Example test here: """ - class TestGitSyncScheduler(unittest.TestCase): + class TestGitSyncScheduler: def test_basic(self): helm_settings = yaml.safe_load(git_sync_basic) res = render_chart( diff --git a/airflow/www/forms.py b/airflow/www/forms.py index c46b0c85c..8b5872095 100644 --- a/airflow/www/forms.py +++ b/airflow/www/forms.py @@ -180,6 +180,7 @@ class ConnectionForm(DynamicForm): conn_id = StringField( lazy_gettext('Connection Id'), validators=[InputRequired()], widget=BS3TextFieldWidget() ) + # conn_type is added later via lazy_add_provider_discovered_options_to_connection_form description = StringField(lazy_gettext('Description'), widget=BS3TextAreaFieldWidget()) host = StringField(lazy_gettext('Host'), widget=BS3TextFieldWidget()) schema = StringField(lazy_gettext('Schema'), widget=BS3TextFieldWidget()) diff --git a/airflow/www/package.json b/airflow/www/package.json index 2a24dda20..b4e42e004 100644 --- a/airflow/www/package.json +++ b/airflow/www/package.json @@ -117,5 +117,8 @@ "redoc": "^2.0.0-rc.72", "type-fest": "^2.17.0", "url-search-params-polyfill": "^8.1.0" + }, + "resolutions": { + "d3-color": "^3.1.0" } } diff --git a/airflow/www/views.py b/airflow/www/views.py index cca0da624..5203594dc 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -34,7 +34,7 @@ from json import JSONDecodeError from operator import itemgetter from typing import Any, Callable -from urllib.parse import parse_qsl, unquote, urlencode, urlparse +from urllib.parse import unquote, urljoin, urlsplit import configupdater import flask.json @@ -155,27 +155,21 @@ def truncate_task_duration(task_duration): def get_safe_url(url): """Given a user-supplied URL, ensure it points to our web server""" - valid_schemes = ['http', 'https', ''] - valid_netlocs = [request.host, ''] - if not url: return url_for('Airflow.index') - parsed = urlparse(url) - # If the url contains semicolon, redirect it to homepage to avoid # potential XSS. (Similar to https://github.com/python/cpython/pull/24297/files (bpo-42967)) if ';' in unquote(url): return url_for('Airflow.index') - query = parse_qsl(parsed.query, keep_blank_values=True) - - url = parsed._replace(query=urlencode(query)).geturl() - - if parsed.scheme in valid_schemes and parsed.netloc in valid_netlocs: - return url + host_url = urlsplit(request.host_url) + redirect_url = urlsplit(urljoin(request.host_url, url)) + if not (redirect_url.scheme in ("http", "https") and host_url.netloc == redirect_url.netloc): + return url_for('Airflow.index') - return url_for('Airflow.index') + # This will ensure we only redirect to the right scheme/netloc + return redirect_url.geturl() def get_date_time_num_runs_dag_runs_form_data(www_request, session, dag): @@ -4229,14 +4223,15 @@ def process_form(self, form, is_created): flash( Markup( "

The Extra connection field contained an invalid value for Conn ID: " - f"{conn_id}.

" + "{conn_id}.

" "

If connection parameters need to be added to Extra, " "please make sure they are in the form of a single, valid JSON object.


" "The following Extra parameters were not added to the connection:
" - f"{extra_json}", - ), + "{extra_json}" + ).format(conn_id=conn_id, extra_json=extra_json), category="error", ) + del form.extra del extra_json for key in self.extra_fields: diff --git a/airflow/www/yarn.lock b/airflow/www/yarn.lock index 755113d08..e2d8bd629 100644 --- a/airflow/www/yarn.lock +++ b/airflow/www/yarn.lock @@ -4483,15 +4483,10 @@ d3-collection@1, d3-collection@^1.0.4: resolved "https://registry.yarnpkg.com/d3-collection/-/d3-collection-1.0.7.tgz#349bd2aa9977db071091c13144d5e4f16b5b310e" integrity sha512-ii0/r5f4sjKNTfh84Di+DpztYwqKhEyUlKoPrzUFfeSkWxjW49xU2QzO9qrPrNkpdI0XJkfzvmTu8V2Zylln6A== -d3-color@1: - version "1.4.1" - resolved "https://registry.yarnpkg.com/d3-color/-/d3-color-1.4.1.tgz#c52002bf8846ada4424d55d97982fef26eb3bc8a" - integrity sha512-p2sTHSLCJI2QKunbGb7ocOh7DgTAn8IrLx21QRc/BSnodXM4sv6aLQlnfpvehFMLZEfBc6g9pH9SWQccFYfJ9Q== - -"d3-color@1 - 2": - version "2.0.0" - resolved "https://registry.yarnpkg.com/d3-color/-/d3-color-2.0.0.tgz#8d625cab42ed9b8f601a1760a389f7ea9189d62e" - integrity sha512-SPXi0TSKPD4g9tw0NMZFnR95XVgUZiBH+uUTqQuDu1OsE2zomHU7ho0FISciaPvosimixwHFl3WHLGabv6dDgQ== +d3-color@1, "d3-color@1 - 2", d3-color@^3.1.0: + version "3.1.0" + resolved "https://registry.yarnpkg.com/d3-color/-/d3-color-3.1.0.tgz#395b2833dfac71507f12ac2f7af23bf819de24e2" + integrity sha512-zg/chbXyeBtMQ1LbD/WSoW2DpC3I0mpmPdW+ynRTj/x2DAWYrIY7qeZIHidozwV24m4iavr15lNwIwLxRmOxhA== d3-contour@1: version "1.3.2" diff --git a/dev/breeze/src/airflow_breeze/utils/selective_checks.py b/dev/breeze/src/airflow_breeze/utils/selective_checks.py index 343c07926..97468feca 100644 --- a/dev/breeze/src/airflow_breeze/utils/selective_checks.py +++ b/dev/breeze/src/airflow_breeze/utils/selective_checks.py @@ -106,13 +106,14 @@ def __hash__(self): r"^airflow/api", ], FileGroupForCi.API_CODEGEN_FILES: [ - "^airflow/api_connexion/openapi/v1.yaml", - "^clients/gen", + r"^airflow/api_connexion/openapi/v1\.yaml", + r"^clients/gen", ], FileGroupForCi.HELM_FILES: [ - "^chart", - "^airflow/kubernetes", - "^tests/kubernetes", + r"^chart", + r"^airflow/kubernetes", + r"^tests/kubernetes", + r"^tests/charts", ], FileGroupForCi.SETUP_FILES: [ r"^pyproject.toml", diff --git a/tests/charts/test_airflow_common.py b/tests/charts/test_airflow_common.py index a8e566531..ce8bc7f88 100644 --- a/tests/charts/test_airflow_common.py +++ b/tests/charts/test_airflow_common.py @@ -18,7 +18,6 @@ import jmespath import pytest -from parameterized import parameterized from tests.charts.helm_template_generator import render_chart @@ -32,7 +31,8 @@ class TestAirflowCommon: as it requires extra test setup. """ - @parameterized.expand( + @pytest.mark.parametrize( + "dag_values, expected_mount", [ ( {"gitSync": {"enabled": True}}, @@ -70,7 +70,7 @@ class TestAirflowCommon: "readOnly": False, }, ), - ] + ], ) def test_dags_mount(self, dag_values, expected_mount): docs = render_chart( diff --git a/tests/charts/test_basic_helm_chart.py b/tests/charts/test_basic_helm_chart.py index c6800810b..2ec7997fa 100644 --- a/tests/charts/test_basic_helm_chart.py +++ b/tests/charts/test_basic_helm_chart.py @@ -16,21 +16,20 @@ # under the License. from __future__ import annotations -import unittest import warnings from subprocess import CalledProcessError from typing import Any from unittest import mock import jmespath -from parameterized import parameterized +import pytest from tests.charts.helm_template_generator import render_chart OBJECT_COUNT_IN_BASIC_DEPLOYMENT = 35 -class TestBaseChartTest(unittest.TestCase): +class TestBaseChartTest: def _get_values_with_version(self, values, version): if version != "default": values["airflowVersion"] = version @@ -41,7 +40,7 @@ def _get_object_count(self, version): return OBJECT_COUNT_IN_BASIC_DEPLOYMENT + 1 return OBJECT_COUNT_IN_BASIC_DEPLOYMENT - @parameterized.expand(["2.3.2", "2.4.0", "default"]) + @pytest.mark.parametrize("version", ["2.3.2", "2.4.0", "default"]) def test_basic_deployments(self, version): expected_object_count_in_basic_deployment = self._get_object_count(version) k8s_objects = render_chart( @@ -114,7 +113,7 @@ def test_basic_deployments(self, version): "test-label" ), f"Missing label test-label on {k8s_name}. Current labels: {labels}" - @parameterized.expand(["2.3.2", "2.4.0", "default"]) + @pytest.mark.parametrize("version", ["2.3.2", "2.4.0", "default"]) def test_basic_deployment_with_standalone_dag_processor(self, version): # Dag Processor creates two extra objects compared to the basic deployment object_count_in_basic_deployment = self._get_object_count(version) @@ -192,7 +191,7 @@ def test_basic_deployment_with_standalone_dag_processor(self, version): "test-label" ), f"Missing label test-label on {k8s_name}. Current labels: {labels}" - @parameterized.expand(["2.3.2", "2.4.0", "default"]) + @pytest.mark.parametrize("version", ["2.3.2", "2.4.0", "default"]) def test_basic_deployment_without_default_users(self, version): expected_object_count_in_basic_deployment = self._get_object_count(version) k8s_objects = render_chart( @@ -207,7 +206,7 @@ def test_basic_deployment_without_default_users(self, version): assert ('Job', 'test-basic-create-user') not in list_of_kind_names_tuples assert expected_object_count_in_basic_deployment - 2 == len(k8s_objects) - @parameterized.expand(["2.3.2", "2.4.0", "default"]) + @pytest.mark.parametrize("version", ["2.3.2", "2.4.0", "default"]) def test_basic_deployment_without_statsd(self, version): expected_object_count_in_basic_deployment = self._get_object_count(version) k8s_objects = render_chart( @@ -462,7 +461,7 @@ def get_k8s_objs_with_image(obj: list[Any] | dict[str, Any]) -> list[dict[str, A assert "command" not in obj def test_unsupported_executor(self): - with self.assertRaises(CalledProcessError) as ex_ctx: + with pytest.raises(CalledProcessError) as ex_ctx: render_chart( "test-basic", { @@ -472,23 +471,15 @@ def test_unsupported_executor(self): assert ( 'executor must be one of the following: "LocalExecutor", ' '"LocalKubernetesExecutor", "CeleryExecutor", ' - '"KubernetesExecutor", "CeleryKubernetesExecutor"' in ex_ctx.exception.stderr.decode() + '"KubernetesExecutor", "CeleryKubernetesExecutor"' in ex_ctx.value.stderr.decode() ) - @parameterized.expand( - [ - ("airflow",), - ("pod_template",), - ("flower",), - ("statsd",), - ("redis",), - ("pgbouncer",), - ("pgbouncerExporter",), - ("gitSync",), - ] + @pytest.mark.parametrize( + "image", + ["airflow", "pod_template", "flower", "statsd", "redis", "pgbouncer", "pgbouncerExporter", "gitSync"], ) def test_invalid_pull_policy(self, image): - with self.assertRaises(CalledProcessError) as ex_ctx: + with pytest.raises(CalledProcessError) as ex_ctx: render_chart( "test-basic", { @@ -497,11 +488,11 @@ def test_invalid_pull_policy(self, image): ) assert ( 'pullPolicy must be one of the following: "Always", "Never", "IfNotPresent"' - in ex_ctx.exception.stderr.decode() + in ex_ctx.value.stderr.decode() ) def test_invalid_dags_access_mode(self): - with self.assertRaises(CalledProcessError) as ex_ctx: + with pytest.raises(CalledProcessError) as ex_ctx: render_chart( "test-basic", { @@ -510,10 +501,10 @@ def test_invalid_dags_access_mode(self): ) assert ( 'accessMode must be one of the following: "ReadWriteOnce", "ReadOnlyMany", "ReadWriteMany"' - in ex_ctx.exception.stderr.decode() + in ex_ctx.value.stderr.decode() ) - @parameterized.expand(["abc", "123", "123abc", "123-abc"]) + @pytest.mark.parametrize("namespace", ["abc", "123", "123abc", "123-abc"]) def test_namespace_names(self, namespace): """Test various namespace names to make sure they render correctly in templates""" render_chart(namespace=namespace) diff --git a/tests/charts/test_celery_kubernetes_executor.py b/tests/charts/test_celery_kubernetes_executor.py index 1e628deda..382ada608 100644 --- a/tests/charts/test_celery_kubernetes_executor.py +++ b/tests/charts/test_celery_kubernetes_executor.py @@ -16,14 +16,12 @@ # under the License. from __future__ import annotations -import unittest - import jmespath from tests.charts.helm_template_generator import render_chart -class CeleryKubernetesExecutorTest(unittest.TestCase): +class TestCeleryKubernetesExecutor: def test_should_create_a_worker_deployment_with_the_celery_executor(self): docs = render_chart( values={ diff --git a/tests/charts/test_chart_quality.py b/tests/charts/test_chart_quality.py index a4ee2c269..f18561fe6 100644 --- a/tests/charts/test_chart_quality.py +++ b/tests/charts/test_chart_quality.py @@ -17,7 +17,6 @@ from __future__ import annotations import json -import unittest from pathlib import Path import yaml @@ -26,7 +25,7 @@ CHART_DIR = Path(__file__).parent / ".." / ".." / "chart" -class ChartQualityTest(unittest.TestCase): +class TestChartQuality: def test_values_validate_schema(self): values = yaml.safe_load((CHART_DIR / "values.yaml").read_text()) schema = json.loads((CHART_DIR / "values.schema.json").read_text()) diff --git a/tests/charts/test_cleanup_pods.py b/tests/charts/test_cleanup_pods.py index 5c3e79b4b..e25920ee7 100644 --- a/tests/charts/test_cleanup_pods.py +++ b/tests/charts/test_cleanup_pods.py @@ -16,15 +16,13 @@ # under the License. from __future__ import annotations -import unittest - import jmespath -from parameterized import parameterized +import pytest from tests.charts.helm_template_generator import render_chart -class CleanupPodsTest(unittest.TestCase): +class TestCleanupPods: def test_should_create_cronjob_for_enabled_cleanup(self): docs = render_chart( values={ @@ -139,14 +137,8 @@ def test_should_add_extraEnvs(self): "spec.jobTemplate.spec.template.spec.containers[0].env", docs[0] ) - @parameterized.expand( - [ - (None, None), - (None, ["custom", "args"]), - (["custom", "command"], None), - (["custom", "command"], ["custom", "args"]), - ] - ) + @pytest.mark.parametrize("command", [None, ["custom", "command"]]) + @pytest.mark.parametrize("args", [None, ["custom", "args"]]) def test_command_and_args_overrides(self, command, args): docs = render_chart( values={"cleanup": {"enabled": True, "command": command, "args": args}}, @@ -248,7 +240,7 @@ def test_should_set_job_history_limits(self): assert 4 == jmespath.search("spec.successfulJobsHistoryLimit", docs[0]) -class CleanupServiceAccountTest(unittest.TestCase): +class TestCleanupServiceAccount: def test_should_add_component_specific_labels(self): docs = render_chart( values={ diff --git a/tests/charts/test_configmap.py b/tests/charts/test_configmap.py index 0cf326b93..4211f2c3b 100644 --- a/tests/charts/test_configmap.py +++ b/tests/charts/test_configmap.py @@ -16,15 +16,13 @@ # under the License. from __future__ import annotations -import unittest - import jmespath -from parameterized import parameterized +import pytest from tests.charts.helm_template_generator import render_chart -class ConfigmapTest(unittest.TestCase): +class TestConfigmap: def test_single_annotation(self): docs = render_chart( values={ @@ -48,14 +46,15 @@ def test_multiple_annotations(self): assert "value" == annotations.get("key") assert "value-two" == annotations.get("key-two") - @parameterized.expand( + @pytest.mark.parametrize( + "af_version, secret_key, secret_key_name, expected", [ ('2.2.0', None, None, True), ('2.2.0', "foo", None, False), ('2.2.0', None, "foo", False), ('2.1.3', None, None, False), ('2.1.3', "foo", None, False), - ] + ], ) def test_default_airflow_local_settings(self, af_version, secret_key, secret_key_name, expected): docs = render_chart( diff --git a/tests/charts/test_create_user_job.py b/tests/charts/test_create_user_job.py index 1ede8bdd8..eb4b5cd1c 100644 --- a/tests/charts/test_create_user_job.py +++ b/tests/charts/test_create_user_job.py @@ -16,15 +16,13 @@ # under the License. from __future__ import annotations -import unittest - import jmespath -from parameterized import parameterized +import pytest from tests.charts.helm_template_generator import render_chart -class CreateUserJobTest(unittest.TestCase): +class TestCreateUserJob: def test_should_run_by_default(self): docs = render_chart(show_only=["templates/jobs/create-user-job.yaml"]) assert "Job" == docs[0]["kind"] @@ -197,7 +195,8 @@ def test_should_add_extraEnvs(self): "spec.template.spec.containers[0].env", docs[0] ) - @parameterized.expand( + @pytest.mark.parametrize( + "airflow_version, expected_arg", [ ("1.10.14", "airflow create_user"), ("2.0.2", "airflow users create"), @@ -231,14 +230,8 @@ def test_default_command_and_args_airflow_version(self, airflow_version, expecte "admin", ] == jmespath.search("spec.template.spec.containers[0].args", docs[0]) - @parameterized.expand( - [ - (None, None), - (None, ["custom", "args"]), - (["custom", "command"], None), - (["custom", "command"], ["custom", "args"]), - ] - ) + @pytest.mark.parametrize("command", [None, ["custom", "command"]]) + @pytest.mark.parametrize("args", [None, ["custom", "args"]]) def test_command_and_args_overrides(self, command, args): docs = render_chart( values={"createUserJob": {"command": command, "args": args}}, @@ -297,7 +290,7 @@ def test_default_user_overrides(self): ] == jmespath.search("spec.template.spec.containers[0].args", docs[0]) -class CreateUserJobServiceAccountTest(unittest.TestCase): +class TestCreateUserJobServiceAccount: def test_should_add_component_specific_labels(self): docs = render_chart( values={ diff --git a/tests/charts/test_dag_processor.py b/tests/charts/test_dag_processor.py index 3e66c8366..aaa4ec20d 100644 --- a/tests/charts/test_dag_processor.py +++ b/tests/charts/test_dag_processor.py @@ -16,20 +16,19 @@ # under the License. from __future__ import annotations -import unittest - import jmespath -from parameterized import parameterized +import pytest from tests.charts.helm_template_generator import render_chart -class DagProcessorTest(unittest.TestCase): - @parameterized.expand( +class TestDagProcessor: + @pytest.mark.parametrize( + "airflow_version, num_docs", [ ("2.2.0", 0), ("2.3.0", 1), - ] + ], ) def test_only_exists_on_new_airflow_versions(self, airflow_version, num_docs): """Standalone Dag Processor was only added from Airflow 2.3 onwards""" @@ -313,7 +312,8 @@ def test_livenessprobe_values_are_configurable(self): "spec.template.spec.containers[0].livenessProbe.exec.command", docs[0] ) - @parameterized.expand( + @pytest.mark.parametrize( + "log_persistence_values, expected_volume", [ ({"enabled": False}, {"emptyDir": {}}), ({"enabled": True}, {"persistentVolumeClaim": {"claimName": "release-name-logs"}}), @@ -321,7 +321,7 @@ def test_livenessprobe_values_are_configurable(self): {"enabled": True, "existingClaim": "test-claim"}, {"persistentVolumeClaim": {"claimName": "test-claim"}}, ), - ] + ], ) def test_logs_persistence_changes_volume(self, log_persistence_values, expected_volume): docs = render_chart( @@ -374,14 +374,15 @@ def test_resources_are_not_added_by_default(self): ) assert jmespath.search("spec.template.spec.containers[0].resources", docs[0]) == {} - @parameterized.expand( + @pytest.mark.parametrize( + "strategy, expected_strategy", [ (None, None), ( {"rollingUpdate": {"maxSurge": "100%", "maxUnavailable": "50%"}}, {"rollingUpdate": {"maxSurge": "100%", "maxUnavailable": "50%"}}, ), - ] + ], ) def test_strategy(self, strategy, expected_strategy): """strategy should be used when we aren't using both LocalExecutor and workers.persistence""" @@ -405,7 +406,10 @@ def test_default_command_and_args(self): "spec.template.spec.containers[0].args", docs[0] ) - @parameterized.expand([(8, 10), (10, 8), (8, None), (None, 10), (None, None)]) + @pytest.mark.parametrize( + "revision_history_limit, global_revision_history_limit", + [(8, 10), (10, 8), (8, None), (None, 10), (None, None)], + ) def test_revision_history_limit(self, revision_history_limit, global_revision_history_limit): values = { "dagProcessor": { @@ -423,14 +427,8 @@ def test_revision_history_limit(self, revision_history_limit, global_revision_hi expected_result = revision_history_limit if revision_history_limit else global_revision_history_limit assert jmespath.search("spec.revisionHistoryLimit", docs[0]) == expected_result - @parameterized.expand( - [ - (None, None), - (None, ["custom", "args"]), - (["custom", "command"], None), - (["custom", "command"], ["custom", "args"]), - ] - ) + @pytest.mark.parametrize("command", [None, ["custom", "command"]]) + @pytest.mark.parametrize("args", [None, ["custom", "args"]]) def test_command_and_args_overrides(self, command, args): docs = render_chart( values={ diff --git a/tests/charts/test_dags_persistent_volume_claim.py b/tests/charts/test_dags_persistent_volume_claim.py index 404cdb67c..38237c7fc 100644 --- a/tests/charts/test_dags_persistent_volume_claim.py +++ b/tests/charts/test_dags_persistent_volume_claim.py @@ -16,14 +16,12 @@ # under the License. from __future__ import annotations -import unittest - import jmespath from tests.charts.helm_template_generator import render_chart -class DagsPersistentVolumeClaimTest(unittest.TestCase): +class TestDagsPersistentVolumeClaim: def test_should_not_generate_a_document_if_persistence_is_disabled(self): docs = render_chart( values={"dags": {"persistence": {"enabled": False}}}, diff --git a/tests/charts/test_elasticsearch_secret.py b/tests/charts/test_elasticsearch_secret.py index 41beeaa48..468921290 100644 --- a/tests/charts/test_elasticsearch_secret.py +++ b/tests/charts/test_elasticsearch_secret.py @@ -17,16 +17,15 @@ from __future__ import annotations import base64 -import unittest from subprocess import CalledProcessError import jmespath -from parameterized import parameterized +import pytest from tests.charts.helm_template_generator import render_chart -class ElasticsearchSecretTest(unittest.TestCase): +class TestElasticsearchSecret: def test_should_not_generate_a_document_if_elasticsearch_disabled(self): docs = render_chart( @@ -37,7 +36,7 @@ def test_should_not_generate_a_document_if_elasticsearch_disabled(self): assert 0 == len(docs) def test_should_raise_error_when_connection_not_provided(self): - with self.assertRaises(CalledProcessError) as ex_ctx: + with pytest.raises(CalledProcessError) as ex_ctx: render_chart( values={ "elasticsearch": { @@ -48,11 +47,11 @@ def test_should_raise_error_when_connection_not_provided(self): ) assert ( "You must set one of the values elasticsearch.secretName or elasticsearch.connection " - "when using a Elasticsearch" in ex_ctx.exception.stderr.decode() + "when using a Elasticsearch" in ex_ctx.value.stderr.decode() ) def test_should_raise_error_when_conflicting_options(self): - with self.assertRaises(CalledProcessError) as ex_ctx: + with pytest.raises(CalledProcessError) as ex_ctx: render_chart( values={ "elasticsearch": { @@ -69,7 +68,7 @@ def test_should_raise_error_when_conflicting_options(self): ) assert ( "You must not set both values elasticsearch.secretName and elasticsearch.connection" - in ex_ctx.exception.stderr.decode() + in ex_ctx.value.stderr.decode() ) def _get_connection(self, values: dict) -> str: @@ -116,7 +115,7 @@ def test_should_generate_secret_with_specified_port(self): assert "http://username:password@elastichostname:2222" == connection - @parameterized.expand(["http", "https"]) + @pytest.mark.parametrize("scheme", ["http", "https"]) def test_should_generate_secret_with_specified_schemes(self, scheme): connection = self._get_connection( { @@ -134,7 +133,8 @@ def test_should_generate_secret_with_specified_schemes(self, scheme): assert f"{scheme}://username:password@elastichostname:9200" == connection - @parameterized.expand( + @pytest.mark.parametrize( + "extra_conn_kwargs, expected_user_info", [ # When both user and password are empty. ({}, ""), diff --git a/tests/charts/test_extra_configmaps_secrets.py b/tests/charts/test_extra_configmaps_secrets.py index a7e2bbe44..8bd8d17c8 100644 --- a/tests/charts/test_extra_configmaps_secrets.py +++ b/tests/charts/test_extra_configmaps_secrets.py @@ -17,19 +17,18 @@ from __future__ import annotations import textwrap -import unittest from base64 import b64encode from unittest import mock +import pytest import yaml -from parameterized import parameterized from tests.charts.helm_template_generator import prepare_k8s_lookup_dict, render_chart RELEASE_NAME = "test-extra-configmaps-secrets" -class ExtraConfigMapsSecretsTest(unittest.TestCase): +class TestExtraConfigMapsSecrets: def test_extra_configmaps(self): values_str = textwrap.dedent( """ @@ -151,12 +150,13 @@ def test_extra_configmaps_secrets_labels(self): for k8s_object in k8s_objects: assert k8s_object['metadata']['labels'] == expected_labels - @parameterized.expand( + @pytest.mark.parametrize( + "chart_labels, local_labels", [ ({}, {"label3": "value3", "label4": "value4"}), ({"label1": "value1", "label2": "value2"}, {}), ({"label1": "value1", "label2": "value2"}, {"label3": "value3", "label4": "value4"}), - ] + ], ) def test_extra_configmaps_secrets_additional_labels(self, chart_labels, local_labels): k8s_objects = render_chart( diff --git a/tests/charts/test_extra_env_env_from.py b/tests/charts/test_extra_env_env_from.py index f1a051780..b32b6fae5 100644 --- a/tests/charts/test_extra_env_env_from.py +++ b/tests/charts/test_extra_env_env_from.py @@ -17,12 +17,11 @@ from __future__ import annotations import textwrap -import unittest from typing import Any import jmespath +import pytest import yaml -from parameterized import parameterized from tests.charts.helm_template_generator import prepare_k8s_lookup_dict, render_chart @@ -73,12 +72,12 @@ ] -class ExtraEnvEnvFromTest(unittest.TestCase): +class TestExtraEnvEnvFrom: k8s_objects: list[dict[str, Any]] k8s_objects_by_key: dict[tuple[str, str], dict[str, Any]] @classmethod - def setUpClass(cls) -> None: + def setup_class(cls) -> None: values_str = textwrap.dedent( """ flower: @@ -102,7 +101,7 @@ def setUpClass(cls) -> None: cls.k8s_objects = render_chart(RELEASE_NAME, values=values) cls.k8s_objects_by_key = prepare_k8s_lookup_dict(cls.k8s_objects) - @parameterized.expand(PARAMS) + @pytest.mark.parametrize("k8s_obj_key, env_paths", PARAMS) def test_extra_env(self, k8s_obj_key, env_paths): expected_env_as_str = textwrap.dedent( f""" @@ -120,7 +119,7 @@ def test_extra_env(self, k8s_obj_key, env_paths): env = jmespath.search(f"{path}.env", k8s_object) assert expected_env_as_str in yaml.dump(env) - @parameterized.expand(PARAMS) + @pytest.mark.parametrize("k8s_obj_key, env_from_paths", PARAMS) def test_extra_env_from(self, k8s_obj_key, env_from_paths): expected_env_from_as_str = textwrap.dedent( f""" diff --git a/tests/charts/test_flower.py b/tests/charts/test_flower.py index 5ee65b54f..5b820f697 100644 --- a/tests/charts/test_flower.py +++ b/tests/charts/test_flower.py @@ -18,7 +18,6 @@ import jmespath import pytest -from parameterized import parameterized from tests.charts.helm_template_generator import render_chart @@ -46,7 +45,10 @@ def test_create_flower(self, executor, flower_enabled, created): assert "release-name-flower" == jmespath.search("metadata.name", docs[0]) assert "flower" == jmespath.search("spec.template.spec.containers[0].name", docs[0]) - @parameterized.expand([(8, 10), (10, 8), (8, None), (None, 10), (None, None)]) + @pytest.mark.parametrize( + "revision_history_limit, global_revision_history_limit", + [(8, 10), (10, 8), (8, None), (None, 10), (None, None)], + ) def test_revision_history_limit(self, revision_history_limit, global_revision_history_limit): values = { "flower": { diff --git a/tests/charts/test_git_sync_scheduler.py b/tests/charts/test_git_sync_scheduler.py index f3f14b91a..1d112ba36 100644 --- a/tests/charts/test_git_sync_scheduler.py +++ b/tests/charts/test_git_sync_scheduler.py @@ -16,14 +16,12 @@ # under the License. from __future__ import annotations -import unittest - import jmespath from tests.charts.helm_template_generator import render_chart -class GitSyncSchedulerTest(unittest.TestCase): +class TestGitSyncSchedulerTest: def test_should_add_dags_volume(self): docs = render_chart( values={"dags": {"gitSync": {"enabled": True}}}, diff --git a/tests/charts/test_git_sync_triggerer.py b/tests/charts/test_git_sync_triggerer.py index 484f77756..d64f3eec6 100644 --- a/tests/charts/test_git_sync_triggerer.py +++ b/tests/charts/test_git_sync_triggerer.py @@ -16,14 +16,12 @@ # under the License. from __future__ import annotations -import unittest - import jmespath from tests.charts.helm_template_generator import render_chart -class GitSyncTriggererTest(unittest.TestCase): +class TestGitSyncTriggerer: def test_validate_sshkeysecret_not_added_when_persistence_is_enabled(self): docs = render_chart( values={ diff --git a/tests/charts/test_git_sync_webserver.py b/tests/charts/test_git_sync_webserver.py index 412a52794..61b65a548 100644 --- a/tests/charts/test_git_sync_webserver.py +++ b/tests/charts/test_git_sync_webserver.py @@ -16,15 +16,13 @@ # under the License. from __future__ import annotations -import unittest - import jmespath -from parameterized import parameterized +import pytest from tests.charts.helm_template_generator import render_chart -class GitSyncWebserverTest(unittest.TestCase): +class TestGitSyncWebserver: def test_should_add_dags_volume_to_the_webserver_if_git_sync_and_persistence_is_enabled(self): docs = render_chart( values={ @@ -71,28 +69,14 @@ def test_should_have_service_account_defined(self): "spec.template.spec.serviceAccountName", docs[0] ) - @parameterized.expand( + @pytest.mark.parametrize( + "airflow_version, exclude_webserver", [ - ( - "2.0.0", - True, - ), - ( - "2.0.2", - True, - ), - ( - "1.10.14", - False, - ), - ( - "1.9.0", - False, - ), - ( - "2.1.0", - True, - ), + ("2.0.0", True), + ("2.0.2", True), + ("1.10.14", False), + ("1.9.0", False), + ("2.1.0", True), ], ) def test_git_sync_with_different_airflow_versions(self, airflow_version, exclude_webserver): diff --git a/tests/charts/test_git_sync_worker.py b/tests/charts/test_git_sync_worker.py index cf810be13..40486f1d5 100644 --- a/tests/charts/test_git_sync_worker.py +++ b/tests/charts/test_git_sync_worker.py @@ -16,14 +16,12 @@ # under the License. from __future__ import annotations -import unittest - import jmespath from tests.charts.helm_template_generator import render_chart -class GitSyncWorkerTest(unittest.TestCase): +class TestGitSyncWorker: def test_should_add_dags_volume_to_the_worker_if_git_sync_and_persistence_is_enabled(self): docs = render_chart( values={ diff --git a/tests/charts/test_ingress_flower.py b/tests/charts/test_ingress_flower.py index 9cb496016..aaf14e793 100644 --- a/tests/charts/test_ingress_flower.py +++ b/tests/charts/test_ingress_flower.py @@ -17,15 +17,14 @@ from __future__ import annotations import itertools -import unittest import jmespath -from parameterized import parameterized +import pytest from tests.charts.helm_template_generator import render_chart -class IngressFlowerTest(unittest.TestCase): +class TestIngressFlower: def test_should_pass_validation_with_just_ingress_enabled_v1(self): render_chart( values={"flower": {"enabled": True}, "ingress": {"flower": {"enabled": True}}}, @@ -137,7 +136,8 @@ def test_should_ingress_host_entry_not_exist(self): ) assert not jmespath.search("spec.rules[*].host", docs[0]) - @parameterized.expand( + @pytest.mark.parametrize( + "global_value, flower_value, expected", [ (None, None, False), (None, False, False), @@ -146,7 +146,7 @@ def test_should_ingress_host_entry_not_exist(self): (True, None, True), (False, True, True), # We will deploy it if _either_ are true (True, False, True), - ] + ], ) def test_ingress_created(self, global_value, flower_value, expected): values = {"flower": {"enabled": True}, "ingress": {}} diff --git a/tests/charts/test_ingress_web.py b/tests/charts/test_ingress_web.py index 134fbd1cf..51f1d002c 100644 --- a/tests/charts/test_ingress_web.py +++ b/tests/charts/test_ingress_web.py @@ -16,15 +16,13 @@ # under the License. from __future__ import annotations -import unittest - import jmespath -from parameterized import parameterized +import pytest from tests.charts.helm_template_generator import render_chart -class IngressWebTest(unittest.TestCase): +class TestIngressWeb: def test_should_pass_validation_with_just_ingress_enabled_v1(self): render_chart( values={"ingress": {"web": {"enabled": True}}}, @@ -132,7 +130,8 @@ def test_should_ingress_host_entry_not_exist(self): ) assert not jmespath.search("spec.rules[*].host", docs[0]) - @parameterized.expand( + @pytest.mark.parametrize( + "global_value, web_value, expected", [ (None, None, False), (None, False, False), @@ -141,7 +140,7 @@ def test_should_ingress_host_entry_not_exist(self): (True, None, True), (False, True, True), # We will deploy it if _either_ are true (True, False, True), - ] + ], ) def test_ingress_created(self, global_value, web_value, expected): values = {"ingress": {}} diff --git a/tests/charts/test_keda.py b/tests/charts/test_keda.py index a44d4cc22..af4ca6690 100644 --- a/tests/charts/test_keda.py +++ b/tests/charts/test_keda.py @@ -18,7 +18,6 @@ import jmespath import pytest -from parameterized import parameterized from tests.charts.helm_template_generator import render_chart @@ -32,11 +31,12 @@ def test_keda_disabled_by_default(self): ) assert docs == [] - @parameterized.expand( + @pytest.mark.parametrize( + "executor, is_created", [ ('CeleryExecutor', True), ('CeleryKubernetesExecutor', True), - ] + ], ) def test_keda_enabled(self, executor, is_created): """ @@ -54,12 +54,7 @@ def test_keda_enabled(self, executor, is_created): else: assert docs == [] - @parameterized.expand( - [ - ('CeleryExecutor'), - ('CeleryKubernetesExecutor'), - ] - ) + @pytest.mark.parametrize("executor", ['CeleryExecutor', 'CeleryKubernetesExecutor']) def test_keda_advanced(self, executor): """ Verify keda advanced config. @@ -151,11 +146,12 @@ def test_keda_query_kubernetes_queue(self, executor, queue, should_filter): expected_query = self.build_query(executor=executor, queue=queue) assert jmespath.search("spec.triggers[0].metadata.query", docs[0]) == expected_query - @parameterized.expand( + @pytest.mark.parametrize( + "enabled, kind", [ ('enabled', 'StatefulSet'), ('not_enabled', 'Deployment'), - ] + ], ) def test_persistence(self, enabled, kind): """ diff --git a/tests/charts/test_kerberos.py b/tests/charts/test_kerberos.py index e1a1d85bb..767541cbb 100644 --- a/tests/charts/test_kerberos.py +++ b/tests/charts/test_kerberos.py @@ -17,14 +17,13 @@ from __future__ import annotations import json -import unittest import jmespath from tests.charts.helm_template_generator import render_chart -class KerberosTest(unittest.TestCase): +class TestKerberos: def test_kerberos_not_mentioned_in_render_if_disabled(self): # the name is deliberately shorter as we look for "kerberos" in the rendered chart k8s_objects = render_chart(name="no-krbros", values={"kerberos": {'enabled': False}}) diff --git a/tests/charts/test_limit_ranges.py b/tests/charts/test_limit_ranges.py index 382835142..722574501 100644 --- a/tests/charts/test_limit_ranges.py +++ b/tests/charts/test_limit_ranges.py @@ -16,14 +16,12 @@ # under the License. from __future__ import annotations -import unittest - import jmespath from tests.charts.helm_template_generator import render_chart -class LimitRangesTest(unittest.TestCase): +class TestLimitRanges: def test_limit_ranges_template(self): docs = render_chart( values={"limits": [{"max": {"cpu": "500m"}, "min": {"min": "200m"}, "type": "Container"}]}, diff --git a/tests/charts/test_logs_persistent_volume_claim.py b/tests/charts/test_logs_persistent_volume_claim.py index 8b6c5a40d..8ca356116 100644 --- a/tests/charts/test_logs_persistent_volume_claim.py +++ b/tests/charts/test_logs_persistent_volume_claim.py @@ -16,14 +16,12 @@ # under the License. from __future__ import annotations -import unittest - import jmespath from tests.charts.helm_template_generator import render_chart -class LogsPersistentVolumeClaimTest(unittest.TestCase): +class TestLogsPersistentVolumeClaim: def test_should_not_generate_a_document_if_persistence_is_disabled(self): docs = render_chart( values={"logs": {"persistence": {"enabled": False}}}, diff --git a/tests/charts/test_metadata_connection_secret.py b/tests/charts/test_metadata_connection_secret.py index 49ad8e403..1c910a977 100644 --- a/tests/charts/test_metadata_connection_secret.py +++ b/tests/charts/test_metadata_connection_secret.py @@ -17,14 +17,13 @@ from __future__ import annotations import base64 -import unittest import jmespath from tests.charts.helm_template_generator import render_chart -class MetadataConnectionSecretTest(unittest.TestCase): +class TestMetadataConnectionSecret: non_chart_database_values = { "user": "someuser", diff --git a/tests/charts/test_migrate_database_job.py b/tests/charts/test_migrate_database_job.py index 5fe6c63c6..69ecb858e 100644 --- a/tests/charts/test_migrate_database_job.py +++ b/tests/charts/test_migrate_database_job.py @@ -18,7 +18,6 @@ import jmespath import pytest -from parameterized import parameterized from tests.charts.helm_template_generator import render_chart @@ -220,7 +219,8 @@ def test_should_add_extra_volume_mounts(self): "spec.template.spec.containers[0].volumeMounts[-1]", docs[0] ) - @parameterized.expand( + @pytest.mark.parametrize( + "airflow_version, expected_arg", [ ("1.10.14", "airflow upgradedb"), ("2.0.2", "airflow db upgrade"), @@ -241,14 +241,8 @@ def test_default_command_and_args_airflow_version(self, airflow_version, expecte f"exec \\\n{expected_arg}", ] == jmespath.search("spec.template.spec.containers[0].args", docs[0]) - @parameterized.expand( - [ - (None, None), - (None, ["custom", "args"]), - (["custom", "command"], None), - (["custom", "command"], ["custom", "args"]), - ] - ) + @pytest.mark.parametrize("command", [None, ["custom", "command"]]) + @pytest.mark.parametrize("args", [None, ["custom", "args"]]) def test_command_and_args_overrides(self, command, args): docs = render_chart( values={"migrateDatabaseJob": {"command": command, "args": args}}, diff --git a/tests/charts/test_pdb_pgbouncer.py b/tests/charts/test_pdb_pgbouncer.py index 737fa7315..6964fed23 100644 --- a/tests/charts/test_pdb_pgbouncer.py +++ b/tests/charts/test_pdb_pgbouncer.py @@ -16,12 +16,10 @@ # under the License. from __future__ import annotations -import unittest - from tests.charts.helm_template_generator import render_chart -class PgbouncerPdbTest(unittest.TestCase): +class TestPgbouncerPdb: def test_should_pass_validation_with_just_pdb_enabled_v1(self): render_chart( values={"pgbouncer": {"enabled": True, "podDisruptionBudget": {"enabled": True}}}, diff --git a/tests/charts/test_pdb_scheduler.py b/tests/charts/test_pdb_scheduler.py index e0d35f937..ba23e2a5e 100644 --- a/tests/charts/test_pdb_scheduler.py +++ b/tests/charts/test_pdb_scheduler.py @@ -16,14 +16,12 @@ # under the License. from __future__ import annotations -import unittest - import jmespath from tests.charts.helm_template_generator import render_chart -class SchedulerPdbTest(unittest.TestCase): +class TestSchedulerPdb: def test_should_pass_validation_with_just_pdb_enabled_v1(self): render_chart( values={"scheduler": {"podDisruptionBudget": {"enabled": True}}}, diff --git a/tests/charts/test_pdb_webserver.py b/tests/charts/test_pdb_webserver.py index 2d7cedae8..8c10f2d0f 100644 --- a/tests/charts/test_pdb_webserver.py +++ b/tests/charts/test_pdb_webserver.py @@ -16,14 +16,12 @@ # under the License. from __future__ import annotations -import unittest - import jmespath from tests.charts.helm_template_generator import render_chart -class WebserverPdbTest(unittest.TestCase): +class TestWebserverPdb: def test_should_pass_validation_with_just_pdb_enabled_v1(self): render_chart( values={"webserver": {"podDisruptionBudget": {"enabled": True}}}, diff --git a/tests/charts/test_pgbouncer.py b/tests/charts/test_pgbouncer.py index 93131c038..c3a8aa0d2 100644 --- a/tests/charts/test_pgbouncer.py +++ b/tests/charts/test_pgbouncer.py @@ -17,16 +17,15 @@ from __future__ import annotations import base64 -import unittest import jmespath -from parameterized import parameterized +import pytest from tests.charts.helm_template_generator import render_chart -class PgbouncerTest(unittest.TestCase): - @parameterized.expand(["pgbouncer-deployment", "pgbouncer-service"]) +class TestPgbouncer: + @pytest.mark.parametrize("yaml_filename", ["pgbouncer-deployment", "pgbouncer-service"]) def test_pgbouncer_resources_not_created_by_default(self, yaml_filename): docs = render_chart( show_only=[f"templates/pgbouncer/{yaml_filename}.yaml"], @@ -97,7 +96,10 @@ def test_pgbouncer_service_extra_annotations(self): "foo": "bar", } == jmespath.search("metadata.annotations", docs[0]) - @parameterized.expand([(8, 10), (10, 8), (8, None), (None, 10), (None, None)]) + @pytest.mark.parametrize( + "revision_history_limit, global_revision_history_limit", + [(8, 10), (10, 8), (8, None), (None, 10), (None, None)], + ) def test_revision_history_limit(self, revision_history_limit, global_revision_history_limit): values = { "pgbouncer": { @@ -253,14 +255,8 @@ def test_default_command_and_args(self): ) assert jmespath.search("spec.template.spec.containers[0].args", docs[0]) is None - @parameterized.expand( - [ - (None, None), - (None, ["custom", "args"]), - (["custom", "command"], None), - (["custom", "command"], ["custom", "args"]), - ] - ) + @pytest.mark.parametrize("command", [None, ["custom", "command"]]) + @pytest.mark.parametrize("args", [None, ["custom", "args"]]) def test_command_and_args_overrides(self, command, args): docs = render_chart( values={"pgbouncer": {"enabled": True, "command": command, "args": args}}, @@ -312,7 +308,7 @@ def test_should_add_extra_volume_and_extra_volume_mount(self): ) -class PgbouncerConfigTest(unittest.TestCase): +class TestPgbouncerConfig: def test_config_not_created_by_default(self): docs = render_chart( show_only=["templates/secrets/pgbouncer-config-secret.yaml"], @@ -472,7 +468,7 @@ def test_extra_ini_configs(self): assert "stats_period = 30" in ini -class PgbouncerExporterTest(unittest.TestCase): +class TestPgbouncerExporter: def test_secret_not_created_by_default(self): docs = render_chart( show_only=["templates/secrets/pgbouncer-stats-secret.yaml"], diff --git a/tests/charts/test_pod_launcher_role.py b/tests/charts/test_pod_launcher_role.py index 682de8be1..3dec7a993 100644 --- a/tests/charts/test_pod_launcher_role.py +++ b/tests/charts/test_pod_launcher_role.py @@ -16,23 +16,22 @@ # under the License. from __future__ import annotations -import unittest - import jmespath -from parameterized import parameterized +import pytest from tests.charts.helm_template_generator import render_chart -class PodLauncherTest(unittest.TestCase): - @parameterized.expand( +class TestPodLauncher: + @pytest.mark.parametrize( + "executor, rbac, allow, expected_accounts", [ ("CeleryKubernetesExecutor", True, True, ['scheduler', 'worker']), ("KubernetesExecutor", True, True, ['scheduler', 'worker']), ("CeleryExecutor", True, True, ['worker']), ("LocalExecutor", True, True, ['scheduler']), ("LocalExecutor", False, False, []), - ] + ], ) def test_pod_launcher_role(self, executor, rbac, allow, expected_accounts): docs = render_chart( diff --git a/tests/charts/test_pod_template_file.py b/tests/charts/test_pod_template_file.py index c44efbd70..fb04d95b4 100644 --- a/tests/charts/test_pod_template_file.py +++ b/tests/charts/test_pod_template_file.py @@ -17,33 +17,31 @@ from __future__ import annotations import re -import unittest from pathlib import Path from shutil import copyfile, copytree from tempfile import TemporaryDirectory import jmespath import pytest -from parameterized import parameterized from tests.charts.helm_template_generator import render_chart -CHART_DIR = Path(__file__).parent / ".." / ".." / "chart" +@pytest.fixture(scope="class", autouse=True) +def isolate_chart(request): + chart_dir = Path(__file__).parent / ".." / ".." / "chart" + with TemporaryDirectory(prefix=request.cls.__name__) as tmp_dir: + temp_chart_dir = Path(tmp_dir) / "chart" + copytree(chart_dir, temp_chart_dir) + copyfile( + temp_chart_dir / "files/pod-template-file.kubernetes-helm-yaml", + temp_chart_dir / "templates/pod-template-file.yaml", + ) + request.cls.temp_chart_dir = str(temp_chart_dir) + yield -class PodTemplateFileTest(unittest.TestCase): - @classmethod - @pytest.fixture(autouse=True, scope="class") - def isolate_chart(cls): - with TemporaryDirectory() as tmp_dir: - cls.temp_chart_dir = tmp_dir + "/chart" - copytree(CHART_DIR, cls.temp_chart_dir) - copyfile( - cls.temp_chart_dir + "/files/pod-template-file.kubernetes-helm-yaml", - cls.temp_chart_dir + "/templates/pod-template-file.yaml", - ) - yield +class TestPodTemplateFile: def test_should_work(self): docs = render_chart( values={}, @@ -122,7 +120,8 @@ def test_should_not_add_init_container_if_dag_persistence_is_true(self): assert jmespath.search("spec.initContainers", docs[0]) is None - @parameterized.expand( + @pytest.mark.parametrize( + "dag_values, expected_read_only", [ ({"gitSync": {"enabled": True}}, True), ({"persistence": {"enabled": True}}, False), @@ -133,7 +132,7 @@ def test_should_not_add_init_container_if_dag_persistence_is_true(self): }, True, ), - ] + ], ) def test_dags_mount(self, dag_values, expected_read_only): docs = render_chart( @@ -259,7 +258,8 @@ def test_should_use_empty_dir_for_gitsync_without_persistence(self): assert {"name": "dags", "emptyDir": {}} in jmespath.search("spec.volumes", docs[0]) - @parameterized.expand( + @pytest.mark.parametrize( + "log_persistence_values, expected", [ ({"enabled": False}, {"emptyDir": {}}), ({"enabled": True}, {"persistentVolumeClaim": {"claimName": "release-name-logs"}}), @@ -267,7 +267,7 @@ def test_should_use_empty_dir_for_gitsync_without_persistence(self): {"enabled": True, "existingClaim": "test-claim"}, {"persistentVolumeClaim": {"claimName": "test-claim"}}, ), - ] + ], ) def test_logs_persistence_changes_volume(self, log_persistence_values, expected): docs = render_chart( @@ -500,7 +500,7 @@ def test_should_add_fsgroup_to_the_pod_template(self): chart_dir=self.temp_chart_dir, ) - self.assertEqual(5000, jmespath.search("spec.securityContext.fsGroup", docs[0])) + assert jmespath.search("spec.securityContext.fsGroup", docs[0]) == 5000 def test_should_create_valid_volume_mount_and_volume(self): docs = render_chart( diff --git a/tests/charts/test_rbac.py b/tests/charts/test_rbac.py index 08a62781e..28007d5d4 100644 --- a/tests/charts/test_rbac.py +++ b/tests/charts/test_rbac.py @@ -16,10 +16,8 @@ # under the License. from __future__ import annotations -import unittest - import jmespath -from parameterized import parameterized +import pytest from tests.charts.helm_template_generator import render_chart @@ -106,7 +104,7 @@ ) -class RBACTest(unittest.TestCase): +class TestRBAC: def _get_values_with_version(self, values, version): if version != "default": values["airflowVersion"] = version @@ -119,7 +117,7 @@ def _get_object_count(self, version): ] + DEPLOYMENT_NO_RBAC_NO_SA_KIND_NAME_TUPLES return DEPLOYMENT_NO_RBAC_NO_SA_KIND_NAME_TUPLES - @parameterized.expand(["2.3.2", "2.4.0", "default"]) + @pytest.mark.parametrize("version", ["2.3.2", "2.4.0", "default"]) def test_deployments_no_rbac_no_sa(self, version): k8s_objects = render_chart( "test-rbac", @@ -155,13 +153,9 @@ def test_deployments_no_rbac_no_sa(self, version): list_of_kind_names_tuples = [ (k8s_object['kind'], k8s_object['metadata']['name']) for k8s_object in k8s_objects ] + assert sorted(list_of_kind_names_tuples) == sorted(self._get_object_count(version)) - self.assertCountEqual( - list_of_kind_names_tuples, - self._get_object_count(version), - ) - - @parameterized.expand(["2.3.2", "2.4.0", "default"]) + @pytest.mark.parametrize("version", ["2.3.2", "2.4.0", "default"]) def test_deployments_no_rbac_with_sa(self, version): k8s_objects = render_chart( "test-rbac", @@ -180,12 +174,9 @@ def test_deployments_no_rbac_with_sa(self, version): (k8s_object['kind'], k8s_object['metadata']['name']) for k8s_object in k8s_objects ] real_list_of_kind_names = self._get_object_count(version) + SERVICE_ACCOUNT_NAME_TUPLES - self.assertCountEqual( - list_of_kind_names_tuples, - real_list_of_kind_names, - ) + assert sorted(list_of_kind_names_tuples) == sorted(real_list_of_kind_names) - @parameterized.expand(["2.3.2", "2.4.0", "default"]) + @pytest.mark.parametrize("version", ["2.3.2", "2.4.0", "default"]) def test_deployments_with_rbac_no_sa(self, version): k8s_objects = render_chart( "test-rbac", @@ -221,12 +212,9 @@ def test_deployments_with_rbac_no_sa(self, version): (k8s_object['kind'], k8s_object['metadata']['name']) for k8s_object in k8s_objects ] real_list_of_kind_names = self._get_object_count(version) + RBAC_ENABLED_KIND_NAME_TUPLES - self.assertCountEqual( - list_of_kind_names_tuples, - real_list_of_kind_names, - ) + assert sorted(list_of_kind_names_tuples) == sorted(real_list_of_kind_names) - @parameterized.expand(["2.3.2", "2.4.0", "default"]) + @pytest.mark.parametrize("version", ["2.3.2", "2.4.0", "default"]) def test_deployments_with_rbac_with_sa(self, version): k8s_objects = render_chart( "test-rbac", @@ -246,10 +234,7 @@ def test_deployments_with_rbac_with_sa(self, version): real_list_of_kind_names = ( self._get_object_count(version) + SERVICE_ACCOUNT_NAME_TUPLES + RBAC_ENABLED_KIND_NAME_TUPLES ) - self.assertCountEqual( - list_of_kind_names_tuples, - real_list_of_kind_names, - ) + assert sorted(list_of_kind_names_tuples) == sorted(real_list_of_kind_names) def test_service_account_custom_names(self): k8s_objects = render_chart( @@ -284,10 +269,7 @@ def test_service_account_custom_names(self): for k8s_object in k8s_objects if k8s_object['kind'] == "ServiceAccount" ] - self.assertCountEqual( - list_of_sa_names, - CUSTOM_SERVICE_ACCOUNT_NAMES, - ) + assert sorted(list_of_sa_names) == sorted(CUSTOM_SERVICE_ACCOUNT_NAMES) def test_service_account_custom_names_in_objects(self): k8s_objects = render_chart( @@ -330,10 +312,7 @@ def test_service_account_custom_names_in_objects(self): if name and name not in list_of_sa_names_in_objects: list_of_sa_names_in_objects.append(name) - self.assertCountEqual( - list_of_sa_names_in_objects, - CUSTOM_SERVICE_ACCOUNT_NAMES, - ) + assert sorted(list_of_sa_names_in_objects) == sorted(CUSTOM_SERVICE_ACCOUNT_NAMES) def test_service_account_without_resource(self): k8s_objects = render_chart( @@ -360,4 +339,4 @@ def test_service_account_without_resource(self): 'test-rbac-triggerer', 'test-rbac-migrate-database-job', ] - self.assertCountEqual(list_of_sa_names, service_account_names) + assert sorted(list_of_sa_names) == sorted(service_account_names) diff --git a/tests/charts/test_redis.py b/tests/charts/test_redis.py index ac74ac13e..1560e99ad 100644 --- a/tests/charts/test_redis.py +++ b/tests/charts/test_redis.py @@ -17,13 +17,11 @@ from __future__ import annotations import re -import unittest from base64 import b64decode from subprocess import CalledProcessError import jmespath import pytest -from parameterized import parameterized from tests.charts.helm_template_generator import prepare_k8s_lookup_dict, render_chart @@ -38,10 +36,10 @@ } SET_POSSIBLE_REDIS_OBJECT_KEYS = set(REDIS_OBJECTS.values()) -CELERY_EXECUTORS_PARAMS = [("CeleryExecutor",), ("CeleryKubernetesExecutor",)] +CELERY_EXECUTORS_PARAMS = ["CeleryExecutor", "CeleryKubernetesExecutor"] -class RedisTest(unittest.TestCase): +class TestRedis: @staticmethod def get_broker_url_in_broker_url_secret(k8s_obj_by_key): broker_url_in_obj = b64decode( @@ -94,7 +92,7 @@ def assert_broker_url_env( ) assert broker_url_secret_in_worker == expected_broker_url_secret_name - @parameterized.expand(CELERY_EXECUTORS_PARAMS) + @pytest.mark.parametrize("executor", CELERY_EXECUTORS_PARAMS) def test_redis_by_chart_default(self, executor): k8s_objects = render_chart( RELEASE_NAME_REDIS, @@ -117,7 +115,7 @@ def test_redis_by_chart_default(self, executor): self.assert_broker_url_env(k8s_obj_by_key) - @parameterized.expand(CELERY_EXECUTORS_PARAMS) + @pytest.mark.parametrize("executor", CELERY_EXECUTORS_PARAMS) def test_redis_by_chart_password(self, executor): k8s_objects = render_chart( RELEASE_NAME_REDIS, @@ -142,7 +140,7 @@ def test_redis_by_chart_password(self, executor): self.assert_broker_url_env(k8s_obj_by_key) - @parameterized.expand(CELERY_EXECUTORS_PARAMS) + @pytest.mark.parametrize("executor", CELERY_EXECUTORS_PARAMS) def test_redis_by_chart_password_secret_name_missing_broker_url_secret_name(self, executor): with pytest.raises(CalledProcessError): render_chart( @@ -156,7 +154,7 @@ def test_redis_by_chart_password_secret_name_missing_broker_url_secret_name(self }, ) - @parameterized.expand(CELERY_EXECUTORS_PARAMS) + @pytest.mark.parametrize("executor", CELERY_EXECUTORS_PARAMS) def test_redis_by_chart_password_secret_name(self, executor): expected_broker_url_secret_name = "test-redis-broker-url-secret-name" k8s_objects = render_chart( @@ -185,7 +183,7 @@ def test_redis_by_chart_password_secret_name(self, executor): self.assert_broker_url_env(k8s_obj_by_key, expected_broker_url_secret_name) - @parameterized.expand(CELERY_EXECUTORS_PARAMS) + @pytest.mark.parametrize("executor", CELERY_EXECUTORS_PARAMS) def test_external_redis_broker_url(self, executor): k8s_objects = render_chart( RELEASE_NAME_REDIS, @@ -211,7 +209,7 @@ def test_external_redis_broker_url(self, executor): self.assert_broker_url_env(k8s_obj_by_key) - @parameterized.expand(CELERY_EXECUTORS_PARAMS) + @pytest.mark.parametrize("executor", CELERY_EXECUTORS_PARAMS) def test_external_redis_broker_url_secret_name(self, executor): expected_broker_url_secret_name = "redis-broker-url-secret-name" k8s_objects = render_chart( diff --git a/tests/charts/test_resource_quota.py b/tests/charts/test_resource_quota.py index a5c3222bd..1775479f7 100644 --- a/tests/charts/test_resource_quota.py +++ b/tests/charts/test_resource_quota.py @@ -16,14 +16,12 @@ # under the License. from __future__ import annotations -import unittest - import jmespath from tests.charts.helm_template_generator import render_chart -class ResourceQuotaTest(unittest.TestCase): +class TestResourceQuota: def test_resource_quota_template(self): docs = render_chart( values={ diff --git a/tests/charts/test_result_backend_connection_secret.py b/tests/charts/test_result_backend_connection_secret.py index a6bd2e255..70a27baa5 100644 --- a/tests/charts/test_result_backend_connection_secret.py +++ b/tests/charts/test_result_backend_connection_secret.py @@ -17,15 +17,14 @@ from __future__ import annotations import base64 -import unittest import jmespath -from parameterized import parameterized +import pytest from tests.charts.helm_template_generator import render_chart -class ResultBackendConnectionSecretTest(unittest.TestCase): +class TestResultBackendConnectionSecret: def _get_values_with_version(self, values, version): if version != "default": values["airflowVersion"] = version @@ -55,12 +54,13 @@ def test_should_not_generate_a_document_if_using_existing_secret(self): assert 0 == len(docs) - @parameterized.expand( + @pytest.mark.parametrize( + "executor, expected_doc_count", [ ("CeleryExecutor", 1), ("CeleryKubernetesExecutor", 1), ("LocalExecutor", 0), - ] + ], ) def test_should_a_document_be_generated_for_executor(self, executor, expected_doc_count): docs = render_chart( @@ -90,7 +90,7 @@ def _get_connection(self, values: dict) -> str | None: encoded_connection = jmespath.search("data.connection", docs[0]) return base64.b64decode(encoded_connection).decode() - @parameterized.expand(["2.3.2", "2.4.0", "default"]) + @pytest.mark.parametrize("version", ["2.3.2", "2.4.0", "default"]) def test_default_connection_old_version(self, version): connection = self._get_connection(self._get_values_with_version(version=version, values={})) self._assert_for_old_version( @@ -100,7 +100,7 @@ def test_default_connection_old_version(self, version): "-postgresql:5432/postgres?sslmode=disable", ) - @parameterized.expand(["2.3.2", "2.4.0", "default"]) + @pytest.mark.parametrize("version", ["2.3.2", "2.4.0", "default"]) def test_should_default_to_custom_metadata_db_connection_with_pgbouncer_overrides(self, version): values = { "pgbouncer": {"enabled": True}, @@ -116,7 +116,7 @@ def test_should_default_to_custom_metadata_db_connection_with_pgbouncer_override ":6543/release-name-result-backend?sslmode=allow", ) - @parameterized.expand(["2.3.2", "2.4.0", "default"]) + @pytest.mark.parametrize("version", ["2.3.2", "2.4.0", "default"]) def test_should_set_pgbouncer_overrides_when_enabled(self, version): values = {"pgbouncer": {"enabled": True}} connection = self._get_connection(self._get_values_with_version(values=values, version=version)) @@ -142,7 +142,7 @@ def test_should_set_pgbouncer_overrides_with_non_chart_database_when_enabled(sel "/release-name-result-backend?sslmode=allow" == connection ) - @parameterized.expand(["2.3.2", "2.4.0", "default"]) + @pytest.mark.parametrize("version", ["2.3.2", "2.4.0", "default"]) def test_should_default_to_custom_metadata_db_connection_in_old_version(self, version): values = { "data": {"metadataConnection": {**self.non_chart_database_values}}, diff --git a/tests/charts/test_scheduler.py b/tests/charts/test_scheduler.py index 1cfd5659e..0edd0bc12 100644 --- a/tests/charts/test_scheduler.py +++ b/tests/charts/test_scheduler.py @@ -16,16 +16,15 @@ # under the License. from __future__ import annotations -import unittest - import jmespath -from parameterized import parameterized +import pytest from tests.charts.helm_template_generator import render_chart -class SchedulerTest(unittest.TestCase): - @parameterized.expand( +class TestScheduler: + @pytest.mark.parametrize( + "executor, persistence, kind", [ ("CeleryExecutor", False, "Deployment"), ("CeleryExecutor", True, "Deployment"), @@ -35,7 +34,7 @@ class SchedulerTest(unittest.TestCase): ("LocalKubernetesExecutor", True, "StatefulSet"), ("LocalExecutor", True, "StatefulSet"), ("LocalExecutor", False, "Deployment"), - ] + ], ) def test_scheduler_kind(self, executor, persistence, kind): """ @@ -160,7 +159,10 @@ def test_should_add_component_specific_labels(self): assert "test_label" in jmespath.search("spec.template.metadata.labels", docs[0]) assert jmespath.search("spec.template.metadata.labels", docs[0])["test_label"] == "test_label_value" - @parameterized.expand([(8, 10), (10, 8), (8, None), (None, 10), (None, None)]) + @pytest.mark.parametrize( + "revision_history_limit, global_revision_history_limit", + [(8, 10), (10, 8), (8, None), (None, 10), (None, None)], + ) def test_revision_history_limit(self, revision_history_limit, global_revision_history_limit): values = {"scheduler": {}} if revision_history_limit: @@ -331,7 +333,8 @@ def test_livenessprobe_values_are_configurable(self): "spec.template.spec.containers[0].livenessProbe.exec.command", docs[0] ) - @parameterized.expand( + @pytest.mark.parametrize( + "log_persistence_values, expected_volume", [ ({"enabled": False}, {"emptyDir": {}}), ({"enabled": True}, {"persistentVolumeClaim": {"claimName": "release-name-logs"}}), @@ -339,7 +342,7 @@ def test_livenessprobe_values_are_configurable(self): {"enabled": True, "existingClaim": "test-claim"}, {"persistentVolumeClaim": {"claimName": "test-claim"}}, ), - ] + ], ) def test_logs_persistence_changes_volume(self, log_persistence_values, expected_volume): docs = render_chart( @@ -404,7 +407,8 @@ def test_airflow_local_settings(self): "readOnly": True, } in jmespath.search("spec.template.spec.containers[0].volumeMounts", docs[0]) - @parameterized.expand( + @pytest.mark.parametrize( + "executor, persistence, update_strategy, expected_update_strategy", [ ("CeleryExecutor", False, {"rollingUpdate": {"partition": 0}}, None), ("CeleryExecutor", True, {"rollingUpdate": {"partition": 0}}, None), @@ -418,7 +422,7 @@ def test_airflow_local_settings(self): ("LocalExecutor", False, {"rollingUpdate": {"partition": 0}}, None), ("LocalExecutor", True, {"rollingUpdate": {"partition": 0}}, {"rollingUpdate": {"partition": 0}}), ("LocalExecutor", True, None, None), - ] + ], ) def test_scheduler_update_strategy( self, executor, persistence, update_strategy, expected_update_strategy @@ -435,7 +439,8 @@ def test_scheduler_update_strategy( assert expected_update_strategy == jmespath.search("spec.updateStrategy", docs[0]) - @parameterized.expand( + @pytest.mark.parametrize( + "executor, persistence, strategy, expected_strategy", [ ("LocalExecutor", False, None, None), ("LocalExecutor", False, {"type": "Recreate"}, {"type": "Recreate"}), @@ -451,7 +456,7 @@ def test_scheduler_update_strategy( {"rollingUpdate": {"maxSurge": "100%", "maxUnavailable": "50%"}}, {"rollingUpdate": {"maxSurge": "100%", "maxUnavailable": "50%"}}, ), - ] + ], ) def test_scheduler_strategy(self, executor, persistence, strategy, expected_strategy): """strategy should be used when we aren't using both a local executor and workers.persistence""" @@ -474,14 +479,8 @@ def test_default_command_and_args(self): "spec.template.spec.containers[0].args", docs[0] ) - @parameterized.expand( - [ - (None, None), - (None, ["custom", "args"]), - (["custom", "command"], None), - (["custom", "command"], ["custom", "args"]), - ] - ) + @pytest.mark.parametrize("command", [None, ["custom", "command"]]) + @pytest.mark.parametrize("args", [None, ["custom", "args"]]) def test_command_and_args_overrides(self, command, args): docs = render_chart( values={"scheduler": {"command": command, "args": args}}, @@ -521,14 +520,8 @@ def test_log_groomer_collector_default_retention_days(self): ) assert "15" == jmespath.search("spec.template.spec.containers[1].env[0].value", docs[0]) - @parameterized.expand( - [ - (None, None), - (None, ["custom", "args"]), - (["custom", "command"], None), - (["custom", "command"], ["custom", "args"]), - ] - ) + @pytest.mark.parametrize("command", [None, ["custom", "command"]]) + @pytest.mark.parametrize("args", [None, ["custom", "args"]]) def test_log_groomer_command_and_args_overrides(self, command, args): docs = render_chart( values={"scheduler": {"logGroomerSidecar": {"command": command, "args": args}}}, @@ -554,11 +547,12 @@ def test_log_groomer_command_and_args_overrides_are_templated(self): assert ["release-name"] == jmespath.search("spec.template.spec.containers[1].command", docs[0]) assert ["Helm"] == jmespath.search("spec.template.spec.containers[1].args", docs[0]) - @parameterized.expand( + @pytest.mark.parametrize( + "retention_days, retention_result", [ (None, None), (30, "30"), - ] + ], ) def test_log_groomer_retention_days_overrides(self, retention_days, retention_result): docs = render_chart( @@ -576,11 +570,12 @@ def test_log_groomer_retention_days_overrides(self, retention_days, retention_re else: assert jmespath.search("spec.template.spec.containers[1].env", docs[0]) is None - @parameterized.expand( + @pytest.mark.parametrize( + "dags_values", [ - ({"gitSync": {"enabled": True}},), - ({"gitSync": {"enabled": True}, "persistence": {"enabled": True}},), - ] + {"gitSync": {"enabled": True}}, + {"gitSync": {"enabled": True}, "persistence": {"enabled": True}}, + ], ) def test_dags_gitsync_sidecar_and_init_container(self, dags_values): docs = render_chart( @@ -593,7 +588,8 @@ def test_dags_gitsync_sidecar_and_init_container(self, dags_values): c["name"] for c in jmespath.search("spec.template.spec.initContainers", docs[0]) ] - @parameterized.expand( + @pytest.mark.parametrize( + "dag_processor, executor, skip_dags_mount", [ (True, "LocalExecutor", False), (True, "CeleryExecutor", True), @@ -603,7 +599,7 @@ def test_dags_gitsync_sidecar_and_init_container(self, dags_values): (False, "CeleryExecutor", False), (False, "KubernetesExecutor", False), (False, "LocalKubernetesExecutor", False), - ] + ], ) def test_dags_mount_and_gitsync_expected_with_dag_processor( self, dag_processor, executor, skip_dags_mount @@ -674,14 +670,15 @@ def test_persistence_volume_annotations(self): ) assert {"foo": "bar"} == jmespath.search("spec.volumeClaimTemplates[0].metadata.annotations", docs[0]) - @parameterized.expand( + @pytest.mark.parametrize( + "executor", [ "LocalExecutor", "LocalKubernetesExecutor", "CeleryExecutor", "KubernetesExecutor", "CeleryKubernetesExecutor", - ] + ], ) def test_scheduler_deployment_has_executor_label(self, executor): docs = render_chart( @@ -693,7 +690,7 @@ def test_scheduler_deployment_has_executor_label(self, executor): assert executor == docs[0]['metadata']['labels'].get('executor') -class SchedulerNetworkPolicyTest(unittest.TestCase): +class TestSchedulerNetworkPolicy: def test_should_add_component_specific_labels(self): docs = render_chart( values={ @@ -709,7 +706,7 @@ def test_should_add_component_specific_labels(self): assert jmespath.search("metadata.labels", docs[0])["test_label"] == "test_label_value" -class SchedulerServiceTest(unittest.TestCase): +class TestSchedulerService: def test_should_add_component_specific_labels(self): docs = render_chart( values={ @@ -725,7 +722,7 @@ def test_should_add_component_specific_labels(self): assert jmespath.search("metadata.labels", docs[0])["test_label"] == "test_label_value" -class SchedulerServiceAccountTest(unittest.TestCase): +class TestSchedulerServiceAccount: def test_should_add_component_specific_labels(self): docs = render_chart( values={ diff --git a/tests/charts/test_statsd.py b/tests/charts/test_statsd.py index 91b51a4a9..032183d99 100644 --- a/tests/charts/test_statsd.py +++ b/tests/charts/test_statsd.py @@ -16,16 +16,14 @@ # under the License. from __future__ import annotations -import unittest - import jmespath +import pytest import yaml -from parameterized import parameterized from tests.charts.helm_template_generator import render_chart -class StatsdTest(unittest.TestCase): +class TestStatsd: def test_should_create_statsd_default(self): docs = render_chart(show_only=["templates/statsd/statsd-deployment.yaml"]) @@ -85,7 +83,10 @@ def test_should_add_volume_and_volume_mount_when_exist_override_mappings(self): "subPath": "mappings.yml", } in jmespath.search("spec.template.spec.containers[0].volumeMounts", docs[0]) - @parameterized.expand([(8, 10), (10, 8), (8, None), (None, 10), (None, None)]) + @pytest.mark.parametrize( + "revision_history_limit, global_revision_history_limit", + [(8, 10), (10, 8), (8, None), (None, 10), (None, None)], + ) def test_revision_history_limit(self, revision_history_limit, global_revision_history_limit): values = {"statsd": {"enabled": True}} if revision_history_limit: diff --git a/tests/charts/test_triggerer.py b/tests/charts/test_triggerer.py index 076a74d32..4c7e355f2 100644 --- a/tests/charts/test_triggerer.py +++ b/tests/charts/test_triggerer.py @@ -16,20 +16,19 @@ # under the License. from __future__ import annotations -import unittest - import jmespath -from parameterized import parameterized +import pytest from tests.charts.helm_template_generator import render_chart -class TriggererTest(unittest.TestCase): - @parameterized.expand( +class TestTriggerer: + @pytest.mark.parametrize( + "airflow_version, num_docs", [ ("2.1.0", 0), ("2.2.0", 1), - ] + ], ) def test_only_exists_on_new_airflow_versions(self, airflow_version, num_docs): """Trigger was only added from Airflow 2.2 onwards""" @@ -52,7 +51,10 @@ def test_can_be_disabled(self): assert 0 == len(docs) - @parameterized.expand([(8, 10), (10, 8), (8, None), (None, 10), (None, None)]) + @pytest.mark.parametrize( + "revision_history_limit, global_revision_history_limit", + [(8, 10), (10, 8), (8, None), (None, 10), (None, None)], + ) def test_revision_history_limit(self, revision_history_limit, global_revision_history_limit): values = { "triggerer": { @@ -333,7 +335,8 @@ def test_livenessprobe_values_are_configurable(self): "spec.template.spec.containers[0].livenessProbe.exec.command", docs[0] ) - @parameterized.expand( + @pytest.mark.parametrize( + "log_persistence_values, expected_volume", [ ({"enabled": False}, {"emptyDir": {}}), ({"enabled": True}, {"persistentVolumeClaim": {"claimName": "release-name-logs"}}), @@ -341,7 +344,7 @@ def test_livenessprobe_values_are_configurable(self): {"enabled": True, "existingClaim": "test-claim"}, {"persistentVolumeClaim": {"claimName": "test-claim"}}, ), - ] + ], ) def test_logs_persistence_changes_volume(self, log_persistence_values, expected_volume): docs = render_chart( @@ -389,14 +392,15 @@ def test_resources_are_not_added_by_default(self): ) assert jmespath.search("spec.template.spec.containers[0].resources", docs[0]) == {} - @parameterized.expand( + @pytest.mark.parametrize( + "strategy, expected_strategy", [ (None, None), ( {"rollingUpdate": {"maxSurge": "100%", "maxUnavailable": "50%"}}, {"rollingUpdate": {"maxSurge": "100%", "maxUnavailable": "50%"}}, ), - ] + ], ) def test_strategy(self, strategy, expected_strategy): """strategy should be used when we aren't using both LocalExecutor and workers.persistence""" @@ -419,14 +423,8 @@ def test_default_command_and_args(self): "spec.template.spec.containers[0].args", docs[0] ) - @parameterized.expand( - [ - (None, None), - (None, ["custom", "args"]), - (["custom", "command"], None), - (["custom", "command"], ["custom", "args"]), - ] - ) + @pytest.mark.parametrize("command", [None, ["custom", "command"]]) + @pytest.mark.parametrize("args", [None, ["custom", "args"]]) def test_command_and_args_overrides(self, command, args): docs = render_chart( values={"triggerer": {"command": command, "args": args}}, @@ -473,7 +471,7 @@ def test_dags_gitsync_with_persistence_no_sidecar_or_init_container(self): ] -class TriggererServiceAccountTest(unittest.TestCase): +class TestTriggererServiceAccount: def test_should_add_component_specific_labels(self): docs = render_chart( values={ diff --git a/tests/charts/test_webserver.py b/tests/charts/test_webserver.py index a6ccb4b4b..9aa122e33 100644 --- a/tests/charts/test_webserver.py +++ b/tests/charts/test_webserver.py @@ -16,15 +16,13 @@ # under the License. from __future__ import annotations -import unittest - import jmespath -from parameterized import parameterized +import pytest from tests.charts.helm_template_generator import render_chart -class WebserverDeploymentTest(unittest.TestCase): +class TestWebserverDeployment: def test_should_add_host_header_to_liveness_and_readiness_probes(self): docs = render_chart( values={ @@ -61,7 +59,10 @@ def test_should_add_path_to_liveness_and_readiness_probes(self): == "/mypath/path/health" ) - @parameterized.expand([(8, 10), (10, 8), (8, None), (None, 10), (None, None)]) + @pytest.mark.parametrize( + "revision_history_limit, global_revision_history_limit", + [(8, 10), (10, 8), (8, None), (None, 10), (None, None)], + ) def test_revision_history_limit(self, revision_history_limit, global_revision_history_limit): values = {"webserver": {}} if revision_history_limit: @@ -75,12 +76,7 @@ def test_revision_history_limit(self, revision_history_limit, global_revision_hi expected_result = revision_history_limit if revision_history_limit else global_revision_history_limit assert jmespath.search("spec.revisionHistoryLimit", docs[0]) == expected_result - @parameterized.expand( - [ - ({"config": {"webserver": {"base_url": ""}}},), - ({},), - ] - ) + @pytest.mark.parametrize("values", [{"config": {"webserver": {"base_url": ""}}}, {}]) def test_should_not_contain_host_header(self, values): print(values) docs = render_chart(values=values, show_only=["templates/webserver/webserver-deployment.yaml"]) @@ -200,7 +196,8 @@ def test_should_add_extraEnvs_to_wait_for_migration_container(self): "spec.template.spec.initContainers[0].env", docs[0] ) - @parameterized.expand( + @pytest.mark.parametrize( + "airflow_version, expected_arg", [ ("2.0.0", ["airflow", "db", "check-migrations", "--migration-wait-timeout=60"]), ("2.1.0", ["airflow", "db", "check-migrations", "--migration-wait-timeout=60"]), @@ -387,12 +384,13 @@ def test_affinity_tolerations_topology_spread_constraints_and_node_selector_prec "spec.template.spec.topologySpreadConstraints[0]", docs[0] ) - @parameterized.expand( + @pytest.mark.parametrize( + "log_persistence_values, expected_claim_name", [ ({"enabled": False}, None), ({"enabled": True}, "release-name-logs"), ({"enabled": True, "existingClaim": "test-claim"}, "test-claim"), - ] + ], ) def test_logs_persistence_adds_volume_and_mount(self, log_persistence_values, expected_claim_name): docs = render_chart( @@ -415,12 +413,13 @@ def test_logs_persistence_adds_volume_and_mount(self, log_persistence_values, ex v["name"] for v in jmespath.search("spec.template.spec.containers[0].volumeMounts", docs[0]) ] - @parameterized.expand( + @pytest.mark.parametrize( + "af_version, pod_template_file_expected", [ ("1.10.10", False), ("1.10.12", True), ("2.1.0", True), - ] + ], ) def test_config_volumes_and_mounts(self, af_version, pod_template_file_expected): # setup @@ -488,7 +487,8 @@ def test_webserver_resources_are_not_added_by_default(self): assert jmespath.search("spec.template.spec.containers[0].resources", docs[0]) == {} assert jmespath.search("spec.template.spec.initContainers[0].resources", docs[0]) == {} - @parameterized.expand( + @pytest.mark.parametrize( + "airflow_version, expected_strategy", [ ("2.0.2", {"type": "RollingUpdate", "rollingUpdate": {"maxSurge": 1, "maxUnavailable": 0}}), ("1.10.14", {"type": "Recreate"}), @@ -540,14 +540,8 @@ def test_default_command_and_args(self): "spec.template.spec.containers[0].args", docs[0] ) - @parameterized.expand( - [ - (None, None), - (None, ["custom", "args"]), - (["custom", "command"], None), - (["custom", "command"], ["custom", "args"]), - ] - ) + @pytest.mark.parametrize("command", [None, ["custom", "command"]]) + @pytest.mark.parametrize("args", [None, ["custom", "args"]]) def test_command_and_args_overrides(self, command, args): docs = render_chart( values={"webserver": {"command": command, "args": args}}, @@ -566,7 +560,8 @@ def test_command_and_args_overrides_are_templated(self): assert ["release-name"] == jmespath.search("spec.template.spec.containers[0].command", docs[0]) assert ["Helm"] == jmespath.search("spec.template.spec.containers[0].args", docs[0]) - @parameterized.expand( + @pytest.mark.parametrize( + "airflow_version, dag_values", [ ("1.10.15", {"gitSync": {"enabled": False}}), ("1.10.15", {"persistence": {"enabled": False}}), @@ -576,7 +571,7 @@ def test_command_and_args_overrides_are_templated(self): ("2.0.0", {"persistence": {"enabled": True}}), ("2.0.0", {"persistence": {"enabled": False}}), ("2.0.0", {"gitSync": {"enabled": True}, "persistence": {"enabled": True}}), - ] + ], ) def test_no_dags_mount_or_volume_or_gitsync_sidecar_expected(self, airflow_version, dag_values): docs = render_chart( @@ -590,12 +585,13 @@ def test_no_dags_mount_or_volume_or_gitsync_sidecar_expected(self, airflow_versi assert "dags" not in [vm["name"] for vm in jmespath.search("spec.template.spec.volumes", docs[0])] assert 1 == len(jmespath.search("spec.template.spec.containers", docs[0])) - @parameterized.expand( + @pytest.mark.parametrize( + "airflow_version, dag_values, expected_read_only", [ ("1.10.15", {"gitSync": {"enabled": True}}, True), ("1.10.15", {"persistence": {"enabled": True}}, False), ("1.10.15", {"gitSync": {"enabled": True}, "persistence": {"enabled": True}}, True), - ] + ], ) def test_dags_mount(self, airflow_version, dag_values, expected_read_only): docs = render_chart( @@ -621,12 +617,13 @@ def test_dags_gitsync_volume_and_sidecar_and_init_container(self): c["name"] for c in jmespath.search("spec.template.spec.initContainers", docs[0]) ] - @parameterized.expand( + @pytest.mark.parametrize( + "dags_values, expected_claim_name", [ ({"persistence": {"enabled": True}}, "release-name-dags"), ({"persistence": {"enabled": True, "existingClaim": "test-claim"}}, "test-claim"), ({"persistence": {"enabled": True}, "gitSync": {"enabled": True}}, "release-name-dags"), - ] + ], ) def test_dags_persistence_volume_no_sidecar(self, dags_values, expected_claim_name): docs = render_chart( @@ -643,7 +640,7 @@ def test_dags_persistence_volume_no_sidecar(self, dags_values, expected_claim_na assert 1 == len(jmespath.search("spec.template.spec.initContainers", docs[0])) -class WebserverServiceTest(unittest.TestCase): +class TestWebserverService: def test_default_service(self): docs = render_chart( show_only=["templates/webserver/webserver-service.yaml"], @@ -679,7 +676,8 @@ def test_overrides(self): assert "127.0.0.1" == jmespath.search("spec.loadBalancerIP", docs[0]) assert ["10.123.0.0/16"] == jmespath.search("spec.loadBalancerSourceRanges", docs[0]) - @parameterized.expand( + @pytest.mark.parametrize( + "ports, expected_ports", [ ([{"port": 8888}], [{"port": 8888}]), # name is optional with a single port ( @@ -697,7 +695,7 @@ def test_overrides(self): {"name": "sidecar", "port": 80, "targetPort": "sidecar"}, ], ), - ] + ], ) def test_ports_overrides(self, ports, expected_ports): docs = render_chart( @@ -722,7 +720,7 @@ def test_should_add_component_specific_labels(self): assert jmespath.search("metadata.labels", docs[0])["test_label"] == "test_label_value" -class WebserverConfigmapTest(unittest.TestCase): +class TestWebserverConfigmap: def test_no_webserver_config_configmap_by_default(self): docs = render_chart(show_only=["templates/configmaps/webserver-configmap.yaml"]) assert 0 == len(docs) @@ -741,7 +739,7 @@ def test_webserver_config_configmap(self): ) -class WebserverNetworkPolicyTest(unittest.TestCase): +class TestWebserverNetworkPolicy: def test_off_by_default(self): docs = render_chart( show_only=["templates/webserver/webserver-networkpolicy.yaml"], @@ -770,7 +768,8 @@ def test_defaults(self): ) assert [{"port": 8080}] == jmespath.search("spec.ingress[0].ports", docs[0]) - @parameterized.expand( + @pytest.mark.parametrize( + "ports, expected_ports", [ ([{"port": "sidecar"}], [{"port": "sidecar"}]), ( @@ -832,7 +831,7 @@ def test_should_add_component_specific_labels(self): assert jmespath.search("metadata.labels", docs[0])["test_label"] == "test_label_value" -class WebserverServiceAccountTest(unittest.TestCase): +class TestWebserverServiceAccount: def test_should_add_component_specific_labels(self): docs = render_chart( values={ diff --git a/tests/charts/test_worker.py b/tests/charts/test_worker.py index 2d6361007..89a755284 100644 --- a/tests/charts/test_worker.py +++ b/tests/charts/test_worker.py @@ -16,22 +16,21 @@ # under the License. from __future__ import annotations -import unittest - import jmespath -from parameterized import parameterized +import pytest from tests.charts.helm_template_generator import render_chart -class WorkerTest(unittest.TestCase): - @parameterized.expand( +class TestWorker: + @pytest.mark.parametrize( + "executor, persistence, kind", [ ("CeleryExecutor", False, "Deployment"), ("CeleryExecutor", True, "StatefulSet"), ("CeleryKubernetesExecutor", False, "Deployment"), ("CeleryKubernetesExecutor", True, "StatefulSet"), - ] + ], ) def test_worker_kind(self, executor, persistence, kind): """ @@ -47,7 +46,10 @@ def test_worker_kind(self, executor, persistence, kind): assert kind == jmespath.search("kind", docs[0]) - @parameterized.expand([(8, 10), (10, 8), (8, None), (None, 10), (None, None)]) + @pytest.mark.parametrize( + "revision_history_limit, global_revision_history_limit", + [(8, 10), (10, 8), (8, None), (None, 10), (None, None)], + ) def test_revision_history_limit(self, revision_history_limit, global_revision_history_limit): values = {"workers": {}} if revision_history_limit: @@ -169,12 +171,13 @@ def test_workers_host_aliases(self): assert "127.0.0.2" == jmespath.search("spec.template.spec.hostAliases[0].ip", docs[0]) assert "test.hostname" == jmespath.search("spec.template.spec.hostAliases[0].hostnames[0]", docs[0]) - @parameterized.expand( + @pytest.mark.parametrize( + "persistence, update_strategy, expected_update_strategy", [ (False, None, None), (True, {"rollingUpdate": {"partition": 0}}, {"rollingUpdate": {"partition": 0}}), (True, None, None), - ] + ], ) def test_workers_update_strategy(self, persistence, update_strategy, expected_update_strategy): docs = render_chart( @@ -190,7 +193,8 @@ def test_workers_update_strategy(self, persistence, update_strategy, expected_up assert expected_update_strategy == jmespath.search("spec.updateStrategy", docs[0]) - @parameterized.expand( + @pytest.mark.parametrize( + "persistence, strategy, expected_strategy", [ (True, None, None), ( @@ -199,7 +203,7 @@ def test_workers_update_strategy(self, persistence, update_strategy, expected_up {"rollingUpdate": {"maxSurge": "100%", "maxUnavailable": "50%"}}, ), (False, None, None), - ] + ], ) def test_workers_strategy(self, persistence, strategy, expected_strategy): docs = render_chart( @@ -378,7 +382,8 @@ def test_disable_livenessprobe(self): livenessprobe = jmespath.search("spec.template.spec.containers[0].livenessProbe", docs[0]) assert livenessprobe is None - @parameterized.expand( + @pytest.mark.parametrize( + "log_persistence_values, expected_volume", [ ({"enabled": False}, {"emptyDir": {}}), ({"enabled": True}, {"persistentVolumeClaim": {"claimName": "release-name-logs"}}), @@ -386,7 +391,7 @@ def test_disable_livenessprobe(self): {"enabled": True, "existingClaim": "test-claim"}, {"persistentVolumeClaim": {"claimName": "test-claim"}}, ), - ] + ], ) def test_logs_persistence_changes_volume(self, log_persistence_values, expected_volume): docs = render_chart( @@ -474,7 +479,8 @@ def test_airflow_local_settings_kerberos_sidecar(self): "readOnly": True, } in jmespath.search("spec.template.spec.containers[2].volumeMounts", docs[0]) - @parameterized.expand( + @pytest.mark.parametrize( + "airflow_version, expected_arg", [ ("1.9.0", "airflow worker"), ("1.10.14", "airflow worker"), @@ -497,14 +503,8 @@ def test_default_command_and_args_airflow_version(self, airflow_version, expecte f"exec \\\n{expected_arg}", ] == jmespath.search("spec.template.spec.containers[0].args", docs[0]) - @parameterized.expand( - [ - (None, None), - (None, ["custom", "args"]), - (["custom", "command"], None), - (["custom", "command"], ["custom", "args"]), - ] - ) + @pytest.mark.parametrize("command", [None, ["custom", "command"]]) + @pytest.mark.parametrize("args", [None, ["custom", "args"]]) def test_command_and_args_overrides(self, command, args): docs = render_chart( values={"workers": {"command": command, "args": args}}, @@ -537,14 +537,8 @@ def test_log_groomer_collector_default_retention_days(self): ) assert "15" == jmespath.search("spec.template.spec.containers[1].env[0].value", docs[0]) - @parameterized.expand( - [ - (None, None), - (None, ["custom", "args"]), - (["custom", "command"], None), - (["custom", "command"], ["custom", "args"]), - ] - ) + @pytest.mark.parametrize("command", [None, ["custom", "command"]]) + @pytest.mark.parametrize("args", [None, ["custom", "args"]]) def test_log_groomer_command_and_args_overrides(self, command, args): docs = render_chart( values={"workers": {"logGroomerSidecar": {"command": command, "args": args}}}, @@ -570,11 +564,12 @@ def test_log_groomer_command_and_args_overrides_are_templated(self): assert ["release-name"] == jmespath.search("spec.template.spec.containers[1].command", docs[0]) assert ["Helm"] == jmespath.search("spec.template.spec.containers[1].args", docs[0]) - @parameterized.expand( + @pytest.mark.parametrize( + "retention_days, retention_result", [ (None, None), (30, "30"), - ] + ], ) def test_log_groomer_retention_days_overrides(self, retention_days, retention_result): docs = render_chart( @@ -651,7 +646,7 @@ def test_persistence_volume_annotations(self): assert {"foo": "bar"} == jmespath.search("spec.volumeClaimTemplates[0].metadata.annotations", docs[0]) -class WorkerKedaAutoScalerTest(unittest.TestCase): +class TestWorkerKedaAutoScaler: def test_should_add_component_specific_labels(self): docs = render_chart( values={ @@ -668,7 +663,7 @@ def test_should_add_component_specific_labels(self): assert jmespath.search("metadata.labels", docs[0])["test_label"] == "test_label_value" -class WorkerNetworkPolicyTest(unittest.TestCase): +class TestWorkerNetworkPolicy: def test_should_add_component_specific_labels(self): docs = render_chart( values={ @@ -685,7 +680,7 @@ def test_should_add_component_specific_labels(self): assert jmespath.search("metadata.labels", docs[0])["test_label"] == "test_label_value" -class WorkerServiceTest(unittest.TestCase): +class TestWorkerService: def test_should_add_component_specific_labels(self): docs = render_chart( values={ @@ -701,7 +696,7 @@ def test_should_add_component_specific_labels(self): assert jmespath.search("metadata.labels", docs[0])["test_label"] == "test_label_value" -class WorkerServiceAccountTest(unittest.TestCase): +class TestWorkerServiceAccount: def test_should_add_component_specific_labels(self): docs = render_chart( values={ diff --git a/tests/www/views/test_views.py b/tests/www/views/test_views.py index 0899badfc..79d42b869 100644 --- a/tests/www/views/test_views.py +++ b/tests/www/views/test_views.py @@ -167,7 +167,13 @@ def test_task_dag_id_equals_filter(admin_client, url, content): "test_url, expected_url", [ ("", "/home"), + ("javascript:alert(1)", "/home"), + (" javascript:alert(1)", "http://localhost:8080/ javascript:alert(1)"), ("http://google.com", "/home"), + ("google.com", "http://localhost:8080/google.com"), + ("\\/google.com", "http://localhost:8080/\\/google.com"), + ("//google.com", "/home"), + ("\\/\\/google.com", "http://localhost:8080/\\/\\/google.com"), ("36539'%3balert(1)%2f%2f166", "/home"), ( "http://localhost:8080/trigger?dag_id=test&origin=36539%27%3balert(1)%2f%2f166&abc=2", diff --git a/tests/www/views/test_views_connection.py b/tests/www/views/test_views_connection.py index bddc026b6..731dfcb48 100644 --- a/tests/www/views/test_views_connection.py +++ b/tests/www/views/test_views_connection.py @@ -324,3 +324,29 @@ def test_connection_form_widgets_testable_types(mock_pm_hooks, admin_client): } assert ["first"] == ConnectionFormWidget().testable_connection_types + + +def test_process_form_invalid_extra_removed(admin_client): + """ + Test that when an invalid json `extra` is passed in the form, it is removed and _not_ + saved over the existing extras. + """ + from airflow.www.views import lazy_add_provider_discovered_options_to_connection_form + + lazy_add_provider_discovered_options_to_connection_form() + + conn_details = {"conn_id": "test_conn", "conn_type": "http"} + conn = Connection(**conn_details, extra='{"foo": "bar"}') + conn.id = 1 + + with create_session() as session: + session.add(conn) + + data = {**conn_details, "extra": "Invalid"} + resp = admin_client.post('/connection/edit/1', data=data, follow_redirects=True) + + assert resp.status_code == 200 + with create_session() as session: + conn = session.query(Connection).get(1) + + assert conn.extra == '{"foo": "bar"}' diff --git a/tests/www/views/test_views_trigger_dag.py b/tests/www/views/test_views_trigger_dag.py index bc578f8ea..e443d5b50 100644 --- a/tests/www/views/test_views_trigger_dag.py +++ b/tests/www/views/test_views_trigger_dag.py @@ -149,14 +149,14 @@ def test_trigger_dag_form(admin_client): ("36539'%3balert(1)%2f%2f166", "/home"), ( '">