diff --git a/airflow/secrets/metastore.py b/airflow/secrets/metastore.py index ba595751320b3..c043be413cdf9 100644 --- a/airflow/secrets/metastore.py +++ b/airflow/secrets/metastore.py @@ -23,49 +23,39 @@ from sqlalchemy import select -from airflow.api_internal.internal_api_call import internal_api_call +from airflow.models import Connection, Variable from airflow.secrets import BaseSecretsBackend from airflow.utils.session import NEW_SESSION, provide_session if TYPE_CHECKING: from sqlalchemy.orm import Session - from airflow.models.connection import Connection - class MetastoreBackend(BaseSecretsBackend): """Retrieves Connection object and Variable from airflow metastore database.""" @provide_session def get_connection(self, conn_id: str, session: Session = NEW_SESSION) -> Connection | None: - return MetastoreBackend._fetch_connection(conn_id, session=session) - - @provide_session - def get_variable(self, key: str, session: Session = NEW_SESSION) -> str | None: """ - Get Airflow Variable from Metadata DB. + Get Connection from Metadata DB. - :param key: Variable Key - :return: Variable Value + :param conn_id: Connection ID + :param session: SQLAlchemy Session + :return: Connection Object """ - return MetastoreBackend._fetch_variable(key=key, session=session) - - @staticmethod - @internal_api_call - @provide_session - def _fetch_connection(conn_id: str, session: Session = NEW_SESSION) -> Connection | None: - from airflow.models.connection import Connection - conn = session.scalar(select(Connection).where(Connection.conn_id == conn_id).limit(1)) session.expunge_all() return conn - @staticmethod - @internal_api_call @provide_session - def _fetch_variable(key: str, session: Session = NEW_SESSION) -> str | None: - from airflow.models.variable import Variable + def get_variable(self, key: str, session: Session = NEW_SESSION) -> str | None: + """ + Get Airflow Variable from Metadata DB. + :param key: Variable Key + :param session: SQLAlchemy Session + :return: Variable Value + """ var_value = session.scalar(select(Variable).where(Variable.key == key).limit(1)) session.expunge_all() if var_value: