Fixed getting values from `extra__` keys in airflow connection (astronomer#923)
Fixes issue astronomer#913, where in Airflow 2.5.3 (and possibly other versions)
`keyfile_json` could not be parsed from the Airflow connection, because the
profile mapping looked up the wrong key in the connection's `extra` dict.
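
For context, a minimal sketch of the failure mode. The connection below is a hypothetical stand-in, and the `extra__google_cloud_platform__keyfile_dict` key name merely illustrates the prefixed convention some Airflow versions use when storing provider fields:

```python
import json

from airflow.models import Connection

# Some Airflow versions store provider fields in `extra` under an
# "extra__<conn_type>__<field>" key instead of a bare field name.
conn = Connection(
    conn_id="my_bigquery_connection",  # hypothetical connection id
    conn_type="google_cloud_platform",
    extra=json.dumps({"extra__google_cloud_platform__keyfile_dict": '{"type": "service_account"}'}),
)

# Mapping entries carrying an "extra__" prefix used to fall through to
# getattr(conn, ...), which has no such attribute and returns None;
# the value is only reachable through extra_dejson:
print(getattr(conn, "extra__google_cloud_platform__keyfile_dict", None))  # -> None
print(conn.extra_dejson["extra__google_cloud_platform__keyfile_dict"])    # -> '{"type": "service_account"}'
```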

I also wrote unit tests. Logically, it might be better to place them in the
tests for the base class, but I wanted to adhere to the existing structure.


Pre-commit fails with `get_dbt_value` being too complex. Personally, I would
ignore that, since I think branching for the different `extra` key styles
improves readability in this case, but I am open to discussion.


Closes astronomer#913
glebkrapivin authored and arojasb3 committed Jul 14, 2024
1 parent fad82eb commit 3522022
Showing 2 changed files with 31 additions and 11 deletions.
21 changes: 13 additions & 8 deletions cosmos/profiles/base.py
100644 → 100755
@@ -241,6 +241,18 @@ def get_profile_file_contents(

        return str(yaml.dump(profile_contents, indent=4))

    def _get_airflow_conn_field(self, airflow_field: str) -> Any:
        # make sure there's no "extra." prefix
        if airflow_field.startswith("extra."):
            airflow_field = airflow_field.replace("extra.", "", 1)
            value = self.conn.extra_dejson.get(airflow_field)
        elif airflow_field.startswith("extra__"):
            value = self.conn.extra_dejson.get(airflow_field)
        else:
            value = getattr(self.conn, airflow_field, None)

        return value

    def get_dbt_value(self, name: str) -> Any:
        """
        Gets values for the dbt profile based on the required_by_dbt and required_in_profile_args lists.
@@ -260,16 +272,9 @@ def get_dbt_value(self, name: str) -> Any:
            airflow_fields = [airflow_fields]

        for airflow_field in airflow_fields:
            # make sure there's no "extra." prefix
            if airflow_field.startswith("extra."):
                airflow_field = airflow_field.replace("extra.", "", 1)
                value = self.conn.extra_dejson.get(airflow_field)
            else:
                value = getattr(self.conn, airflow_field, None)

            value = self._get_airflow_conn_field(airflow_field)
            if not value:
                continue

            # if there's a transform method, use it
            if hasattr(self, f"transform_{name}"):
                return getattr(self, f"transform_{name}")(value)
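
As a usage note, here is a standalone sketch of the lookup order the new helper implements, re-implemented against a bare `Connection` so it runs outside Cosmos; the connection and key names are illustrative, not taken from the real param mappings:

```python
import json
from typing import Any

from airflow.models import Connection


def get_conn_field(conn: Connection, airflow_field: str) -> Any:
    """Mirror of _get_airflow_conn_field's branching, for illustration only."""
    if airflow_field.startswith("extra."):
        # "extra.foo" refers to extra_dejson["foo"]
        return conn.extra_dejson.get(airflow_field.replace("extra.", "", 1))
    if airflow_field.startswith("extra__"):
        # "extra__<conn_type>__foo" keys are stored verbatim in extra_dejson
        return conn.extra_dejson.get(airflow_field)
    # anything else is a plain Connection attribute (host, login, schema, ...)
    return getattr(conn, airflow_field, None)


conn = Connection(
    conn_id="demo",  # hypothetical connection
    conn_type="google_cloud_platform",
    host="example.com",
    extra=json.dumps({
        "keyfile_dict": "{...}",
        "extra__google_cloud_platform__keyfile_dict": "{...}",
    }),
)
assert get_conn_field(conn, "extra.keyfile_dict") == "{...}"
assert get_conn_field(conn, "extra__google_cloud_platform__keyfile_dict") == "{...}"
assert get_conn_field(conn, "host") == "example.com"
```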
21 changes: 18 additions & 3 deletions tests/profiles/bigquery/test_bq_service_account_keyfile_dict.py
100644 → 100755
@@ -1,4 +1,5 @@
import json
from collections import namedtuple
from unittest.mock import patch

import pytest
@@ -8,29 +9,43 @@
from cosmos.profiles import get_automatic_profile_mapping
from cosmos.profiles.bigquery.service_account_keyfile_dict import GoogleCloudServiceAccountDictProfileMapping

ConnExtraParams = namedtuple("ConnExtraParams", ["keyfile_dict", "keyfile_json_extra_key"])
sample_keyfile_dict = {
    "type": "service_account",
    "private_key_id": "my_private_key_id",
    "private_key": "my_private_key",
}


@pytest.fixture(params=[sample_keyfile_dict, json.dumps(sample_keyfile_dict)])
def get_fixture_params():
    """
    Make a matrix of params for the fixture that mocks the connection, as there are multiple fields in
    the airflow param mapping for the "keyfile_json" in GoogleCloudServiceAccountDictProfileMapping
    """
    fixture_params = []
    for d in (sample_keyfile_dict, json.dumps(sample_keyfile_dict)):
        for key in GoogleCloudServiceAccountDictProfileMapping.airflow_param_mapping.get("keyfile_json"):
            if key.startswith("extra."):
                key = key.replace("extra.", "")
            fixture_params.append(ConnExtraParams(keyfile_dict=d, keyfile_json_extra_key=key))
    return fixture_params


@pytest.fixture(params=get_fixture_params())
def mock_bigquery_conn_with_dict(request):  # type: ignore
    """
    Mocks and returns an Airflow BigQuery connection.
    """
    extra = {
        "project": "my_project",
        "dataset": "my_dataset",
        "keyfile_dict": request.param,
        request.param.keyfile_json_extra_key: request.param.keyfile_dict,
    }
    conn = Connection(
        conn_id="my_bigquery_connection",
        conn_type="google_cloud_platform",
        extra=json.dumps(extra),
    )

    with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
        yield conn