diff --git a/airflow/providers/snowflake/hooks/snowflake.py b/airflow/providers/snowflake/hooks/snowflake.py index e2a4a453fbb07..78d6ae484cd3c 100644 --- a/airflow/providers/snowflake/hooks/snowflake.py +++ b/airflow/providers/snowflake/hooks/snowflake.py @@ -128,6 +128,7 @@ def get_ui_field_behaviour(cls) -> dict[str, Any]: "authenticator": "snowflake oauth", "private_key_file": "private key", "session_parameters": "session parameters", + "client_request_mfa_token": "client request mfa token", }, indent=1, ), @@ -155,6 +156,7 @@ def __init__(self, *args, **kwargs) -> None: self.schema = kwargs.pop("schema", None) self.authenticator = kwargs.pop("authenticator", None) self.session_parameters = kwargs.pop("session_parameters", None) + self.client_request_mfa_token = kwargs.pop("client_request_mfa_token", None) self.query_ids: list[str] = [] def _get_field(self, extra_dict, field_name): @@ -194,6 +196,7 @@ def _get_conn_params(self) -> dict[str, str | None]: role = self._get_field(extra_dict, "role") or "" insecure_mode = _try_to_boolean(self._get_field(extra_dict, "insecure_mode")) schema = conn.schema or "" + client_request_mfa_token = _try_to_boolean(self._get_field(extra_dict, "client_request_mfa_token")) # authenticator and session_parameters never supported long name so we don't use _get_field authenticator = extra_dict.get("authenticator", "snowflake") @@ -216,6 +219,9 @@ def _get_conn_params(self) -> dict[str, str | None]: if insecure_mode: conn_config["insecure_mode"] = insecure_mode + if client_request_mfa_token: + conn_config["client_request_mfa_token"] = client_request_mfa_token + # If private_key_file is specified in the extra json, load the contents of the file as a private key. # If private_key_content is specified in the extra json, use it as a private key. # As a next step, specify this private key in the connection configuration. @@ -280,7 +286,9 @@ def _conn_params_to_sqlalchemy_uri(self, conn_params: dict) -> str: **{ k: v for k, v in conn_params.items() - if v and k not in ["session_parameters", "insecure_mode", "private_key"] + if v + and k + not in ["session_parameters", "insecure_mode", "private_key", "client_request_mfa_token"] } ) diff --git a/tests/providers/snowflake/hooks/test_snowflake.py b/tests/providers/snowflake/hooks/test_snowflake.py index 54a18eeca7336..16e10db048479 100644 --- a/tests/providers/snowflake/hooks/test_snowflake.py +++ b/tests/providers/snowflake/hooks/test_snowflake.py @@ -137,6 +137,7 @@ class TestPytestSnowflakeHook: "extra__snowflake__region": "af_region", "extra__snowflake__role": "af_role", "extra__snowflake__insecure_mode": "True", + "extra__snowflake__client_request_mfa_token": "True", }, }, ( @@ -156,6 +157,7 @@ class TestPytestSnowflakeHook: "user": "user", "warehouse": "af_wh", "insecure_mode": True, + "client_request_mfa_token": True, }, ), ( @@ -168,6 +170,7 @@ class TestPytestSnowflakeHook: "extra__snowflake__region": "af_region", "extra__snowflake__role": "af_role", "extra__snowflake__insecure_mode": "False", + "extra__snowflake__client_request_mfa_token": "False", }, }, ( @@ -243,6 +246,7 @@ class TestPytestSnowflakeHook: "extra": { **BASE_CONNECTION_KWARGS["extra"], "extra__snowflake__insecure_mode": False, + "extra__snowflake__client_request_mfa_token": False, }, }, (