Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Snowflake Profile mapping when using AWS default region #1406

Merged
merged 5 commits into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions cosmos/profiles/snowflake/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from __future__ import annotations

from typing import Any

from cosmos.profiles.base import BaseProfileMapping

DEFAULT_AWS_REGION = "us-west-2"


class SnowflakeBaseProfileMapping(BaseProfileMapping):

@property
def profile(self) -> dict[str, Any | None]:
"""Gets profile."""
profile_vars = {
**self.mapped_params,
**self.profile_args,
}
return profile_vars

def transform_account(self, account: str) -> str:
"""Transform the account to the format <account>.<region> if it's not already."""
region = self.conn.extra_dejson.get("region")
#
if region and region != DEFAULT_AWS_REGION and region not in account:
account = f"{account}.{region}"

return str(account)
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
import json
from typing import TYPE_CHECKING, Any

from ..base import BaseProfileMapping
from cosmos.profiles.snowflake.base import SnowflakeBaseProfileMapping

if TYPE_CHECKING:
from airflow.models import Connection


class SnowflakeEncryptedPrivateKeyPemProfileMapping(BaseProfileMapping):
class SnowflakeEncryptedPrivateKeyPemProfileMapping(SnowflakeBaseProfileMapping):
"""
Maps Airflow Snowflake connections to dbt profiles if they use a user/private key.
https://docs.getdbt.com/docs/core/connect-data-platform/snowflake-setup#key-pair-authentication
Expand Down Expand Up @@ -75,20 +75,7 @@ def conn(self) -> Connection:
@property
def profile(self) -> dict[str, Any | None]:
"""Gets profile."""
profile_vars = {
**self.mapped_params,
**self.profile_args,
"private_key": self.get_env_var_format("private_key"),
"private_key_passphrase": self.get_env_var_format("private_key_passphrase"),
}

# remove any null values
profile_vars = super().profile
profile_vars["private_key"] = self.get_env_var_format("private_key")
profile_vars["private_key_passphrase"] = self.get_env_var_format("private_key_passphrase")
return self.filter_null(profile_vars)

def transform_account(self, account: str) -> str:
"""Transform the account to the format <account>.<region> if it's not already."""
region = self.conn.extra_dejson.get("region")
if region and region not in account:
account = f"{account}.{region}"

return str(account)
22 changes: 4 additions & 18 deletions cosmos/profiles/snowflake/user_encrypted_privatekey_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
import json
from typing import TYPE_CHECKING, Any

from ..base import BaseProfileMapping
from cosmos.profiles.snowflake.base import SnowflakeBaseProfileMapping

if TYPE_CHECKING:
from airflow.models import Connection


class SnowflakeEncryptedPrivateKeyFilePemProfileMapping(BaseProfileMapping):
class SnowflakeEncryptedPrivateKeyFilePemProfileMapping(SnowflakeBaseProfileMapping):
"""
Maps Airflow Snowflake connections to dbt profiles if they use a user/private key path.
https://docs.getdbt.com/docs/core/connect-data-platform/snowflake-setup#key-pair-authentication
Expand Down Expand Up @@ -74,20 +74,6 @@ def conn(self) -> Connection:
@property
def profile(self) -> dict[str, Any | None]:
"""Gets profile."""
profile_vars = {
**self.mapped_params,
**self.profile_args,
# private_key_passphrase should always get set as env var
"private_key_passphrase": self.get_env_var_format("private_key_passphrase"),
}

# remove any null values
profile_vars = super().profile
profile_vars["private_key_passphrase"] = self.get_env_var_format("private_key_passphrase")
return self.filter_null(profile_vars)

def transform_account(self, account: str) -> str:
"""Transform the account to the format <account>.<region> if it's not already."""
region = self.conn.extra_dejson.get("region")
if region and region not in account:
account = f"{account}.{region}"

return str(account)
23 changes: 5 additions & 18 deletions cosmos/profiles/snowflake/user_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
import json
from typing import TYPE_CHECKING, Any

from ..base import BaseProfileMapping
from cosmos.profiles.snowflake.base import SnowflakeBaseProfileMapping

if TYPE_CHECKING:
from airflow.models import Connection


class SnowflakeUserPasswordProfileMapping(BaseProfileMapping):
class SnowflakeUserPasswordProfileMapping(SnowflakeBaseProfileMapping):
"""
Maps Airflow Snowflake connections to dbt profiles if they use a user/password.
https://docs.getdbt.com/reference/warehouse-setups/snowflake-setup
Expand Down Expand Up @@ -76,20 +76,7 @@ def conn(self) -> Connection:
@property
def profile(self) -> dict[str, Any | None]:
"""Gets profile."""
profile_vars = {
**self.mapped_params,
**self.profile_args,
# password should always get set as env var
"password": self.get_env_var_format("password"),
}

# remove any null values
profile_vars = super().profile
# password should always get set as env var
profile_vars["password"] = self.get_env_var_format("password")
return self.filter_null(profile_vars)

def transform_account(self, account: str) -> str:
"""Transform the account to the format <account>.<region> if it's not already."""
region = self.conn.extra_dejson.get("region")
if region and region not in account:
account = f"{account}.{region}"

return str(account)
23 changes: 5 additions & 18 deletions cosmos/profiles/snowflake/user_privatekey.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
import json
from typing import TYPE_CHECKING, Any

from ..base import BaseProfileMapping
from cosmos.profiles.snowflake.base import SnowflakeBaseProfileMapping

if TYPE_CHECKING:
from airflow.models import Connection


class SnowflakePrivateKeyPemProfileMapping(BaseProfileMapping):
class SnowflakePrivateKeyPemProfileMapping(SnowflakeBaseProfileMapping):
"""
Maps Airflow Snowflake connections to dbt profiles if they use a user/private key.
https://docs.getdbt.com/docs/core/connect-data-platform/snowflake-setup#key-pair-authentication
Expand Down Expand Up @@ -65,20 +65,7 @@ def conn(self) -> Connection:
@property
def profile(self) -> dict[str, Any | None]:
"""Gets profile."""
profile_vars = {
**self.mapped_params,
**self.profile_args,
# private_key should always get set as env var
"private_key": self.get_env_var_format("private_key"),
}

# remove any null values
profile_vars = super().profile
# private_key should always get set as env var
profile_vars["private_key"] = self.get_env_var_format("private_key")
return self.filter_null(profile_vars)

def transform_account(self, account: str) -> str:
tatiana marked this conversation as resolved.
Show resolved Hide resolved
"""Transform the account to the format <account>.<region> if it's not already."""
region = self.conn.extra_dejson.get("region")
if region and region not in account:
account = f"{account}.{region}"

return str(account)
19 changes: 19 additions & 0 deletions tests/profiles/snowflake/test_snowflake_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from unittest.mock import patch

from cosmos.profiles.snowflake.base import SnowflakeBaseProfileMapping


@patch("cosmos.profiles.snowflake.base.SnowflakeBaseProfileMapping.conn.extra_dejson", {"region": "us-west-2"})
@patch("cosmos.profiles.snowflake.base.SnowflakeBaseProfileMapping.conn")
def test_default_region(mock_conn):
profile_mapping = SnowflakeBaseProfileMapping(conn_id="fake-conn")
response = profile_mapping.transform_account("myaccount")
assert response == "myaccount"


@patch("cosmos.profiles.snowflake.base.SnowflakeBaseProfileMapping.conn.extra_dejson", {"region": "us-east-1"})
@patch("cosmos.profiles.snowflake.base.SnowflakeBaseProfileMapping.conn")
def test_non_default_region(mock_conn):
profile_mapping = SnowflakeBaseProfileMapping(conn_id="fake-conn")
response = profile_mapping.transform_account("myaccount")
assert response == "myaccount.us-east-1"
Loading