diff --git a/airflow/config_templates/config.yml b/airflow/config_templates/config.yml index c3970dc018108..a1807e5bb123f 100644 --- a/airflow/config_templates/config.yml +++ b/airflow/config_templates/config.yml @@ -386,6 +386,18 @@ type: string default: "0o077" example: ~ + - name: dataset_event_manager_class + description: Class to use as dataset event manager. + version_added: 2.4.0 + type: string + default: ~ + example: 'airflow.datasets.manager.DatasetEventManager' + - name: dataset_event_manager_kwargs + description: Kwargs to supply to dataset event manager. + version_added: 2.4.0 + type: string + default: ~ + example: '{"some_param": "some_value"}' - name: database description: ~ diff --git a/airflow/config_templates/default_airflow.cfg b/airflow/config_templates/default_airflow.cfg index 90fa6e0df32c8..4944d32f0e727 100644 --- a/airflow/config_templates/default_airflow.cfg +++ b/airflow/config_templates/default_airflow.cfg @@ -220,6 +220,14 @@ max_map_length = 1024 # This value is treated as an octal-integer. daemon_umask = 0o077 +# Class to use as dataset event manager. +# Example: dataset_event_manager_class = airflow.datasets.manager.DatasetEventManager +# dataset_event_manager_class = + +# Kwargs to supply to dataset event manager. +# Example: dataset_event_manager_kwargs = {{"some_param": "some_value"}} +# dataset_event_manager_kwargs = + [database] # The SqlAlchemy connection string to the metadata database. # SqlAlchemy supports many different database engines. diff --git a/airflow/datasets/manager.py b/airflow/datasets/manager.py index a8be553052110..e16e3925157a1 100644 --- a/airflow/datasets/manager.py +++ b/airflow/datasets/manager.py @@ -15,13 +15,18 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from typing import TYPE_CHECKING + from sqlalchemy.orm.session import Session +from airflow.configuration import conf from airflow.datasets import Dataset from airflow.models.dataset import DatasetDagRunQueue, DatasetEvent, DatasetModel -from airflow.models.taskinstance import TaskInstance from airflow.utils.log.logging_mixin import LoggingMixin +if TYPE_CHECKING: + from airflow.models.taskinstance import TaskInstance + class DatasetEventManager(LoggingMixin): """ @@ -31,8 +36,11 @@ class DatasetEventManager(LoggingMixin): Airflow deployments can use plugins that broadcast dataset events to each other. """ + def __init__(self, **kwargs): + super().__init__(**kwargs) + def register_dataset_change( - self, *, task_instance: TaskInstance, dataset: Dataset, extra=None, session: Session, **kwargs + self, *, task_instance: "TaskInstance", dataset: Dataset, extra=None, session: Session, **kwargs ) -> None: """ For local datasets, look them up, record the dataset event, queue dagruns, and broadcast @@ -59,3 +67,20 @@ def _queue_dagruns(self, dataset: DatasetModel, session: Session) -> None: self.log.debug("consuming dag ids %s", consuming_dag_ids) for dag_id in consuming_dag_ids: session.merge(DatasetDagRunQueue(dataset_id=dataset.id, target_dag_id=dag_id)) + + +def resolve_dataset_event_manager(): + _dataset_event_manager_class = conf.getimport( + section='core', + key='dataset_event_manager_class', + fallback='airflow.datasets.manager.DatasetEventManager', + ) + _dataset_event_manager_kwargs = conf.getjson( + section='core', + key='dataset_event_manager_kwargs', + fallback={}, + ) + return _dataset_event_manager_class(**_dataset_event_manager_kwargs) + + +dataset_event_manager = resolve_dataset_event_manager() diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 5912094c25513..c810f5e5ff13a 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -81,6 +81,7 @@ from airflow.compat.functools import cache from airflow.configuration import conf from airflow.datasets import Dataset +from airflow.datasets.manager import dataset_event_manager from airflow.exceptions import ( AirflowException, AirflowFailException, @@ -585,10 +586,6 @@ def __init__( # can be changed when calling 'run' self.test_mode = False - self.dataset_event_manager = conf.getimport( - 'core', 'dataset_event_manager_class', fallback='airflow.datasets.manager.DatasetEventManager' - )() - @staticmethod def insert_mapping(run_id: str, task: "Operator", map_index: int) -> dict: """:meta private:""" @@ -1538,7 +1535,7 @@ def _register_dataset_changes(self, *, session: Session) -> None: self.log.debug("outlet obj %s", obj) # Lineage can have other types of objects besides datasets if isinstance(obj, Dataset): - self.dataset_event_manager.register_dataset_change( + dataset_event_manager.register_dataset_change( task_instance=self, dataset=obj, session=session,