Skip to content

Commit

Permalink
Fix init checks for AWS Redshift to S3 operator (#37861)
Browse files: browse the repository at this point in the history
* remove commented sections

* Update airflow/providers/amazon/aws/transfers/redshift_to_s3.py

Co-authored-by: Andrey Anshin <Andrey.Anshin@taragol.is>

* add checks in execute

* ruff format

---------

Co-authored-by: Andrey Anshin <Andrey.Anshin@taragol.is>
  • Loading branch information
okirialbert and Taragolis authored Mar 4, 2024
1 parent 1726b93 commit ce00420
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 28 deletions.
8 changes: 0 additions & 8 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -330,14 +330,6 @@ repos:
pass_filenames: true
files: ^airflow/providers/.*/(operators|transfers|sensors)/.*\.py$
additional_dependencies: [ 'rich>=12.4.4' ]
# TODO: Handle the provider-specific exclusions and remove them from the list, see:
# https://github.com/apache/airflow/issues/36484
exclude: |
(?x)^(
^.*__init__\.py$|
^airflow\/providers\/amazon\/aws\/transfers\/redshift_to_s3\.py$|
^airflow\/providers\/amazon\/aws\/operators\/emr\.py$|
)$
- id: ruff
name: Run 'ruff' for extremely fast Python linting
description: "Run 'ruff' for extremely fast Python linting"
Expand Down
39 changes: 20 additions & 19 deletions airflow/providers/amazon/aws/transfers/redshift_to_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,35 +109,19 @@ def __init__(
) -> None:
super().__init__(**kwargs)
self.s3_bucket = s3_bucket
self.s3_key = f"{s3_key}/{table}_" if (table and table_as_file_name) else s3_key
self.s3_key = s3_key
self.schema = schema
self.table = table
self.redshift_conn_id = redshift_conn_id
self.aws_conn_id = aws_conn_id
self.verify = verify
self.unload_options: list = unload_options or []
self.unload_options = unload_options or []
self.autocommit = autocommit
self.include_header = include_header
self.parameters = parameters
self.table_as_file_name = table_as_file_name
self.redshift_data_api_kwargs = redshift_data_api_kwargs or {}

if select_query:
self.select_query = select_query
elif self.schema and self.table:
self.select_query = f"SELECT * FROM {self.schema}.{self.table}"
else:
raise ValueError(
"Please provide both `schema` and `table` params or `select_query` to fetch the data."
)

if self.include_header and "HEADER" not in [uo.upper().strip() for uo in self.unload_options]:
self.unload_options = [*self.unload_options, "HEADER"]

if self.redshift_data_api_kwargs:
for arg in ["sql", "parameters"]:
if arg in self.redshift_data_api_kwargs:
raise AirflowException(f"Cannot include param '{arg}' in Redshift Data API kwargs")
self.select_query = select_query

def _build_unload_query(
self, credentials_block: str, select_query: str, s3_key: str, unload_options: str
Expand All @@ -153,9 +137,26 @@ def _build_unload_query(
"""

def execute(self, context: Context) -> None:
if self.table and self.table_as_file_name:
self.s3_key = f"{self.s3_key}/{self.table}_"

if self.schema and self.table:
self.select_query = f"SELECT * FROM {self.schema}.{self.table}"

if self.select_query is None:
raise ValueError(
"Please provide both `schema` and `table` params or `select_query` to fetch the data."
)

if self.include_header and "HEADER" not in [uo.upper().strip() for uo in self.unload_options]:
self.unload_options = [*self.unload_options, "HEADER"]

redshift_hook: RedshiftDataHook | RedshiftSQLHook
if self.redshift_data_api_kwargs:
redshift_hook = RedshiftDataHook(aws_conn_id=self.redshift_conn_id)
for arg in ["sql", "parameters"]:
if arg in self.redshift_data_api_kwargs:
raise AirflowException(f"Cannot include param '{arg}' in Redshift Data API kwargs")
else:
redshift_hook = RedshiftSQLHook(redshift_conn_id=self.redshift_conn_id)
conn = S3Hook.get_connection(conn_id=self.aws_conn_id) if self.aws_conn_id else None
Expand Down
3 changes: 2 additions & 1 deletion tests/providers/amazon/aws/transfers/test_redshift_to_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,14 +375,15 @@ def test_invalid_param_in_redshift_data_api_kwargs(self, param):
Test passing invalid param in RS Data API kwargs raises an error
"""
with pytest.raises(AirflowException):
RedshiftToS3Operator(
redshift_operator = RedshiftToS3Operator(
s3_bucket="s3_bucket",
s3_key="s3_key",
select_query="select_query",
task_id="task_id",
dag=None,
redshift_data_api_kwargs={param: "param"},
)
redshift_operator.execute(None)

@pytest.mark.parametrize("table_as_file_name, expected_s3_key", [[True, "key/table_"], [False, "key"]])
@mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection")
Expand Down

0 comments on commit ce00420

Please sign in to comment.