From 957946f7e91f3f70409fd28eba5b92293764cb27 Mon Sep 17 00:00:00 2001 From: Pankaj Date: Wed, 22 May 2024 16:10:50 +0530 Subject: [PATCH] Apply review suggetions --- .../elasticsearch/log/es_task_handler.py | 6 ++--- airflow/providers/elasticsearch/provider.yaml | 2 +- .../elasticsearch/log/test_es_task_handler.py | 23 ++++++++----------- 3 files changed, 13 insertions(+), 18 deletions(-) diff --git a/airflow/providers/elasticsearch/log/es_task_handler.py b/airflow/providers/elasticsearch/log/es_task_handler.py index 31306e5f24c5c..c397e2b3585ff 100644 --- a/airflow/providers/elasticsearch/log/es_task_handler.py +++ b/airflow/providers/elasticsearch/log/es_task_handler.py @@ -18,7 +18,6 @@ from __future__ import annotations import contextlib -import importlib import inspect import logging import sys @@ -42,6 +41,7 @@ from airflow.utils import timezone from airflow.utils.log.file_task_handler import FileTaskHandler from airflow.utils.log.logging_mixin import ExternalLoggingMixin, LoggingMixin +from airflow.utils.module_loading import import_string from airflow.utils.session import create_session if TYPE_CHECKING: @@ -224,9 +224,7 @@ def _get_index_patterns(self, ti: TaskInstance | None) -> str: """ if self.index_patterns_callable: self.log.debug("Using index_patterns_callable: %s", self.index_patterns_callable) - module_path, index_pattern_function = self.index_patterns_callable.rsplit(".", 1) - module = importlib.import_module(module_path) - index_pattern_callable_obj = getattr(module, index_pattern_function) + index_pattern_callable_obj = import_string(self.index_patterns_callable) return index_pattern_callable_obj(ti) self.log.debug("Using index_patterns: %s", self.index_patterns) return self.index_patterns diff --git a/airflow/providers/elasticsearch/provider.yaml b/airflow/providers/elasticsearch/provider.yaml index db252ef1888cc..46ae0530673dd 100644 --- a/airflow/providers/elasticsearch/provider.yaml +++ b/airflow/providers/elasticsearch/provider.yaml @@ -169,7 +169,7 @@ config: description: | A string representing the full path to the Python callable path which accept TI object and return comma separated list of index patterns. This will takes precedence over index_patterns. - version_added: 5.4.0 + version_added: 5.5.0 type: string example: module.callable default: "" diff --git a/tests/providers/elasticsearch/log/test_es_task_handler.py b/tests/providers/elasticsearch/log/test_es_task_handler.py index 1a75c3464c5b8..c920f0d466871 100644 --- a/tests/providers/elasticsearch/log/test_es_task_handler.py +++ b/tests/providers/elasticsearch/log/test_es_task_handler.py @@ -17,7 +17,6 @@ # under the License. from __future__ import annotations -import importlib import json import logging import os @@ -26,7 +25,7 @@ from io import StringIO from pathlib import Path from unittest import mock -from unittest.mock import MagicMock +from unittest.mock import Mock, patch from urllib.parse import quote import elasticsearch @@ -34,7 +33,6 @@ import pytest from airflow.configuration import conf -from airflow.models import TaskInstance from airflow.providers.elasticsearch.log.es_response import ElasticSearchResponse from airflow.providers.elasticsearch.log.es_task_handler import ( VALID_ES_CONFIG_KEYS, @@ -52,7 +50,6 @@ pytestmark = pytest.mark.db_test - AIRFLOW_SOURCES_ROOT_DIR = Path(__file__).parents[4].resolve() ES_PROVIDER_YAML_FILE = AIRFLOW_SOURCES_ROOT_DIR / "airflow" / "providers" / "elasticsearch" / "provider.yaml" @@ -646,17 +643,17 @@ def test_dynamic_offset(self, stdout_mock, ti, time_machine): assert second_log["asctime"] == t2.format("YYYY-MM-DDTHH:mm:ss.SSSZZ") assert third_log["asctime"] == t3.format("YYYY-MM-DDTHH:mm:ss.SSSZZ") - def test_index_patterns_callable(self): - ti = MagicMock(spec=TaskInstance) + def test_get_index_patterns_with_callable(self): + with patch("airflow.providers.elasticsearch.log.es_task_handler.import_string") as mock_import_string: + mock_callable = Mock(return_value="callable_index_pattern") + mock_import_string.return_value = mock_callable - def mock_callable(ti): - return "mocked_index_patterns" + self.es_task_handler.index_patterns_callable = "path.to.index_pattern_callable" + result = self.es_task_handler._get_index_patterns({}) - importlib.import_module = MagicMock() - importlib.import_module.return_value = MagicMock(**{"mock_callable": mock_callable}) - self.es_task_handler.index_patterns_callable = "module_path.mock_callable" - result = self.es_task_handler._get_index_patterns(ti) - assert result == "mocked_index_patterns" + mock_import_string.assert_called_once_with("path.to.index_pattern_callable") + mock_callable.assert_called_once_with({}) + assert result == "callable_index_pattern" def test_safe_attrgetter():