From 7cb3bea2c8787ddac06d8376a754167c0196f38c Mon Sep 17 00:00:00 2001 From: Tatiana Al-Chueyr Date: Wed, 18 Dec 2024 18:18:59 +0000 Subject: [PATCH 1/5] Refactor ssnowflake profile mapping Introduce the SnowflakeBaseProfileMapping class with behaviour that should be shared between multiple Snowflake profile mapping classes --- cosmos/profiles/snowflake/base.py | 12 ++++++++++++ .../user_encrypted_privatekey_env_variable.py | 4 ++-- .../snowflake/user_encrypted_privatekey_file.py | 4 ++-- cosmos/profiles/snowflake/user_pass.py | 4 ++-- cosmos/profiles/snowflake/user_privatekey.py | 12 ++---------- 5 files changed, 20 insertions(+), 16 deletions(-) create mode 100644 cosmos/profiles/snowflake/base.py diff --git a/cosmos/profiles/snowflake/base.py b/cosmos/profiles/snowflake/base.py new file mode 100644 index 000000000..9e37249da --- /dev/null +++ b/cosmos/profiles/snowflake/base.py @@ -0,0 +1,12 @@ +from ..base import BaseProfileMapping + + +class SnowflakeBaseProfileMapping(BaseProfileMapping): + + def transform_account(self, account: str) -> str: + """Transform the account to the format . if it's not already.""" + region = self.conn.extra_dejson.get("region") + if region and region not in account: + account = f"{account}.{region}" + + return str(account) diff --git a/cosmos/profiles/snowflake/user_encrypted_privatekey_env_variable.py b/cosmos/profiles/snowflake/user_encrypted_privatekey_env_variable.py index 70722dd59..e594abe6a 100644 --- a/cosmos/profiles/snowflake/user_encrypted_privatekey_env_variable.py +++ b/cosmos/profiles/snowflake/user_encrypted_privatekey_env_variable.py @@ -5,13 +5,13 @@ import json from typing import TYPE_CHECKING, Any -from ..base import BaseProfileMapping +from cosmos.profiles.snowflake.base import SnowflakeBaseProfileMapping if TYPE_CHECKING: from airflow.models import Connection -class SnowflakeEncryptedPrivateKeyPemProfileMapping(BaseProfileMapping): +class SnowflakeEncryptedPrivateKeyPemProfileMapping(SnowflakeBaseProfileMapping): """ Maps Airflow Snowflake connections to dbt profiles if they use a user/private key. https://docs.getdbt.com/docs/core/connect-data-platform/snowflake-setup#key-pair-authentication diff --git a/cosmos/profiles/snowflake/user_encrypted_privatekey_file.py b/cosmos/profiles/snowflake/user_encrypted_privatekey_file.py index e217a6c22..bb8d29890 100644 --- a/cosmos/profiles/snowflake/user_encrypted_privatekey_file.py +++ b/cosmos/profiles/snowflake/user_encrypted_privatekey_file.py @@ -5,13 +5,13 @@ import json from typing import TYPE_CHECKING, Any -from ..base import BaseProfileMapping +from cosmos.profiles.snowflake.base import SnowflakeBaseProfileMapping if TYPE_CHECKING: from airflow.models import Connection -class SnowflakeEncryptedPrivateKeyFilePemProfileMapping(BaseProfileMapping): +class SnowflakeEncryptedPrivateKeyFilePemProfileMapping(SnowflakeBaseProfileMapping): """ Maps Airflow Snowflake connections to dbt profiles if they use a user/private key path. https://docs.getdbt.com/docs/core/connect-data-platform/snowflake-setup#key-pair-authentication diff --git a/cosmos/profiles/snowflake/user_pass.py b/cosmos/profiles/snowflake/user_pass.py index 3fc6595c9..ecdc2d804 100644 --- a/cosmos/profiles/snowflake/user_pass.py +++ b/cosmos/profiles/snowflake/user_pass.py @@ -5,13 +5,13 @@ import json from typing import TYPE_CHECKING, Any -from ..base import BaseProfileMapping +from cosmos.profiles.snowflake.base import SnowflakeBaseProfileMapping if TYPE_CHECKING: from airflow.models import Connection -class SnowflakeUserPasswordProfileMapping(BaseProfileMapping): +class SnowflakeUserPasswordProfileMapping(SnowflakeBaseProfileMapping): """ Maps Airflow Snowflake connections to dbt profiles if they use a user/password. https://docs.getdbt.com/reference/warehouse-setups/snowflake-setup diff --git a/cosmos/profiles/snowflake/user_privatekey.py b/cosmos/profiles/snowflake/user_privatekey.py index c74194b7a..b4b02a4f4 100644 --- a/cosmos/profiles/snowflake/user_privatekey.py +++ b/cosmos/profiles/snowflake/user_privatekey.py @@ -5,13 +5,13 @@ import json from typing import TYPE_CHECKING, Any -from ..base import BaseProfileMapping +from cosmos.profiles.snowflake.base import SnowflakeBaseProfileMapping if TYPE_CHECKING: from airflow.models import Connection -class SnowflakePrivateKeyPemProfileMapping(BaseProfileMapping): +class SnowflakePrivateKeyPemProfileMapping(SnowflakeBaseProfileMapping): """ Maps Airflow Snowflake connections to dbt profiles if they use a user/private key. https://docs.getdbt.com/docs/core/connect-data-platform/snowflake-setup#key-pair-authentication @@ -74,11 +74,3 @@ def profile(self) -> dict[str, Any | None]: # remove any null values return self.filter_null(profile_vars) - - def transform_account(self, account: str) -> str: - """Transform the account to the format . if it's not already.""" - region = self.conn.extra_dejson.get("region") - if region and region not in account: - account = f"{account}.{region}" - - return str(account) From ce2df184bc02f247a59c50871f2d108348ad025d Mon Sep 17 00:00:00 2001 From: Tatiana Al-Chueyr Date: Wed, 18 Dec 2024 18:43:35 +0000 Subject: [PATCH 2/5] Fix Snowflake Profile mapping when using AWS default region The dbt docs state: For AWS accounts in the US West default region, you can use abc123 (without any other segments). For some AWS accounts you will have to append the region and/or cloud platform. For example, abc123.eu-west-1 or abc123.eu-west-2.aws. Howevver, a Cosmos user reported that they were facing 404 and seeing a dbt error message when attempting to use SnowflakeUserPasswordProfileMapping with an Airflow Snowflake connection that defined the region us-west-2. By removing the region us-west-2 from the connection, we solved the issue. --- cosmos/profiles/snowflake/base.py | 19 +++++++++++++++++-- .../profiles/snowflake/test_snowflake_base.py | 19 +++++++++++++++++++ 2 files changed, 36 insertions(+), 2 deletions(-) create mode 100644 tests/profiles/snowflake/test_snowflake_base.py diff --git a/cosmos/profiles/snowflake/base.py b/cosmos/profiles/snowflake/base.py index 9e37249da..0fcaa1d59 100644 --- a/cosmos/profiles/snowflake/base.py +++ b/cosmos/profiles/snowflake/base.py @@ -1,12 +1,27 @@ -from ..base import BaseProfileMapping +from typing import Any + +from cosmos.profiles.base import BaseProfileMapping + +DEFAULT_AWS_REGION = "us-west-2" class SnowflakeBaseProfileMapping(BaseProfileMapping): + @property + def profile(self) -> dict[str, Any | None]: + """Gets profile.""" + profile_vars = { + **self.mapped_params, + **self.profile_args, + } + + return self.filter_null(profile_vars) + def transform_account(self, account: str) -> str: """Transform the account to the format . if it's not already.""" region = self.conn.extra_dejson.get("region") - if region and region not in account: + # + if region and region != DEFAULT_AWS_REGION and region not in account: account = f"{account}.{region}" return str(account) diff --git a/tests/profiles/snowflake/test_snowflake_base.py b/tests/profiles/snowflake/test_snowflake_base.py new file mode 100644 index 000000000..ee8f6c6b3 --- /dev/null +++ b/tests/profiles/snowflake/test_snowflake_base.py @@ -0,0 +1,19 @@ +from unittest.mock import patch + +from cosmos.profiles.snowflake.base import SnowflakeBaseProfileMapping + + +@patch("cosmos.profiles.snowflake.base.SnowflakeBaseProfileMapping.conn.extra_dejson", {"region": "us-west-2"}) +@patch("cosmos.profiles.snowflake.base.SnowflakeBaseProfileMapping.conn") +def test_default_region(mock_conn): + profile_mapping = SnowflakeBaseProfileMapping(conn_id="fake-conn") + response = profile_mapping.transform_account("myaccount") + assert response == "myaccount" + + +@patch("cosmos.profiles.snowflake.base.SnowflakeBaseProfileMapping.conn.extra_dejson", {"region": "us-east-1"}) +@patch("cosmos.profiles.snowflake.base.SnowflakeBaseProfileMapping.conn") +def test_non_default_region(mock_conn): + profile_mapping = SnowflakeBaseProfileMapping(conn_id="fake-conn") + response = profile_mapping.transform_account("myaccount") + assert response == "myaccount.us-east-1" From 47c9277cef2ead0955f5c7c7d249a29c62d11dcf Mon Sep 17 00:00:00 2001 From: Tatiana Al-Chueyr Date: Wed, 18 Dec 2024 18:49:42 +0000 Subject: [PATCH 3/5] Fix static check --- cosmos/profiles/snowflake/base.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cosmos/profiles/snowflake/base.py b/cosmos/profiles/snowflake/base.py index 0fcaa1d59..66ebffb48 100644 --- a/cosmos/profiles/snowflake/base.py +++ b/cosmos/profiles/snowflake/base.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import Any from cosmos.profiles.base import BaseProfileMapping From 7e380471b0f6b658543edb693d00ca96d2ea0831 Mon Sep 17 00:00:00 2001 From: Tatiana Al-Chueyr Date: Thu, 19 Dec 2024 07:18:55 +0000 Subject: [PATCH 4/5] Refactor to use base snowflake profile --- cosmos/profiles/snowflake/base.py | 3 +-- .../user_encrypted_privatekey_env_variable.py | 11 +++-------- .../snowflake/user_encrypted_privatekey_file.py | 10 ++-------- cosmos/profiles/snowflake/user_pass.py | 11 +++-------- cosmos/profiles/snowflake/user_privatekey.py | 11 +++-------- 5 files changed, 12 insertions(+), 34 deletions(-) diff --git a/cosmos/profiles/snowflake/base.py b/cosmos/profiles/snowflake/base.py index 66ebffb48..599a9c8e5 100644 --- a/cosmos/profiles/snowflake/base.py +++ b/cosmos/profiles/snowflake/base.py @@ -16,8 +16,7 @@ def profile(self) -> dict[str, Any | None]: **self.mapped_params, **self.profile_args, } - - return self.filter_null(profile_vars) + return profile_vars def transform_account(self, account: str) -> str: """Transform the account to the format . if it's not already.""" diff --git a/cosmos/profiles/snowflake/user_encrypted_privatekey_env_variable.py b/cosmos/profiles/snowflake/user_encrypted_privatekey_env_variable.py index e594abe6a..1e64e5a46 100644 --- a/cosmos/profiles/snowflake/user_encrypted_privatekey_env_variable.py +++ b/cosmos/profiles/snowflake/user_encrypted_privatekey_env_variable.py @@ -75,14 +75,9 @@ def conn(self) -> Connection: @property def profile(self) -> dict[str, Any | None]: """Gets profile.""" - profile_vars = { - **self.mapped_params, - **self.profile_args, - "private_key": self.get_env_var_format("private_key"), - "private_key_passphrase": self.get_env_var_format("private_key_passphrase"), - } - - # remove any null values + profile_vars = super().profile + profile_vars["private_key"] = self.get_env_var_format("private_key") + profile_vars["private_key_passphrase"] = self.get_env_var_format("private_key_passphrase") return self.filter_null(profile_vars) def transform_account(self, account: str) -> str: diff --git a/cosmos/profiles/snowflake/user_encrypted_privatekey_file.py b/cosmos/profiles/snowflake/user_encrypted_privatekey_file.py index bb8d29890..704108a4e 100644 --- a/cosmos/profiles/snowflake/user_encrypted_privatekey_file.py +++ b/cosmos/profiles/snowflake/user_encrypted_privatekey_file.py @@ -74,14 +74,8 @@ def conn(self) -> Connection: @property def profile(self) -> dict[str, Any | None]: """Gets profile.""" - profile_vars = { - **self.mapped_params, - **self.profile_args, - # private_key_passphrase should always get set as env var - "private_key_passphrase": self.get_env_var_format("private_key_passphrase"), - } - - # remove any null values + profile_vars = super().profile + profile_vars["private_key_passphrase"] = self.get_env_var_format("private_key_passphrase") return self.filter_null(profile_vars) def transform_account(self, account: str) -> str: diff --git a/cosmos/profiles/snowflake/user_pass.py b/cosmos/profiles/snowflake/user_pass.py index ecdc2d804..9c8f2e2f0 100644 --- a/cosmos/profiles/snowflake/user_pass.py +++ b/cosmos/profiles/snowflake/user_pass.py @@ -76,14 +76,9 @@ def conn(self) -> Connection: @property def profile(self) -> dict[str, Any | None]: """Gets profile.""" - profile_vars = { - **self.mapped_params, - **self.profile_args, - # password should always get set as env var - "password": self.get_env_var_format("password"), - } - - # remove any null values + profile_vars = super().profile + # password should always get set as env var + profile_vars["password"] = self.get_env_var_format("password") return self.filter_null(profile_vars) def transform_account(self, account: str) -> str: diff --git a/cosmos/profiles/snowflake/user_privatekey.py b/cosmos/profiles/snowflake/user_privatekey.py index b4b02a4f4..40a016af7 100644 --- a/cosmos/profiles/snowflake/user_privatekey.py +++ b/cosmos/profiles/snowflake/user_privatekey.py @@ -65,12 +65,7 @@ def conn(self) -> Connection: @property def profile(self) -> dict[str, Any | None]: """Gets profile.""" - profile_vars = { - **self.mapped_params, - **self.profile_args, - # private_key should always get set as env var - "private_key": self.get_env_var_format("private_key"), - } - - # remove any null values + profile_vars = super().profile + # private_key should always get set as env var + profile_vars["private_key"] = self.get_env_var_format("private_key") return self.filter_null(profile_vars) From 8abaee64c08be1d6fa98ea5fd430170f6daa886b Mon Sep 17 00:00:00 2001 From: Tatiana Al-Chueyr Date: Thu, 19 Dec 2024 07:20:16 +0000 Subject: [PATCH 5/5] Remove transform_account from remaining sf profiles --- .../snowflake/user_encrypted_privatekey_env_variable.py | 8 -------- .../profiles/snowflake/user_encrypted_privatekey_file.py | 8 -------- cosmos/profiles/snowflake/user_pass.py | 8 -------- 3 files changed, 24 deletions(-) diff --git a/cosmos/profiles/snowflake/user_encrypted_privatekey_env_variable.py b/cosmos/profiles/snowflake/user_encrypted_privatekey_env_variable.py index 1e64e5a46..63a6c68d3 100644 --- a/cosmos/profiles/snowflake/user_encrypted_privatekey_env_variable.py +++ b/cosmos/profiles/snowflake/user_encrypted_privatekey_env_variable.py @@ -79,11 +79,3 @@ def profile(self) -> dict[str, Any | None]: profile_vars["private_key"] = self.get_env_var_format("private_key") profile_vars["private_key_passphrase"] = self.get_env_var_format("private_key_passphrase") return self.filter_null(profile_vars) - - def transform_account(self, account: str) -> str: - """Transform the account to the format . if it's not already.""" - region = self.conn.extra_dejson.get("region") - if region and region not in account: - account = f"{account}.{region}" - - return str(account) diff --git a/cosmos/profiles/snowflake/user_encrypted_privatekey_file.py b/cosmos/profiles/snowflake/user_encrypted_privatekey_file.py index 704108a4e..6f35dad45 100644 --- a/cosmos/profiles/snowflake/user_encrypted_privatekey_file.py +++ b/cosmos/profiles/snowflake/user_encrypted_privatekey_file.py @@ -77,11 +77,3 @@ def profile(self) -> dict[str, Any | None]: profile_vars = super().profile profile_vars["private_key_passphrase"] = self.get_env_var_format("private_key_passphrase") return self.filter_null(profile_vars) - - def transform_account(self, account: str) -> str: - """Transform the account to the format . if it's not already.""" - region = self.conn.extra_dejson.get("region") - if region and region not in account: - account = f"{account}.{region}" - - return str(account) diff --git a/cosmos/profiles/snowflake/user_pass.py b/cosmos/profiles/snowflake/user_pass.py index 9c8f2e2f0..93c29793b 100644 --- a/cosmos/profiles/snowflake/user_pass.py +++ b/cosmos/profiles/snowflake/user_pass.py @@ -80,11 +80,3 @@ def profile(self) -> dict[str, Any | None]: # password should always get set as env var profile_vars["password"] = self.get_env_var_format("password") return self.filter_null(profile_vars) - - def transform_account(self, account: str) -> str: - """Transform the account to the format . if it's not already.""" - region = self.conn.extra_dejson.get("region") - if region and region not in account: - account = f"{account}.{region}" - - return str(account)