From 0e2a0ccd3087f53222e7859f414daf0ffa50dfbb Mon Sep 17 00:00:00 2001 From: Kanthi Date: Wed, 8 Dec 2021 15:38:27 -0500 Subject: [PATCH] Added function in AWSAthenaHook to get s3 output query results file URI (#20124) --- airflow/providers/amazon/aws/hooks/athena.py | 24 +++++++++++++++++++ .../providers/amazon/aws/hooks/test_athena.py | 13 ++++++++++ 2 files changed, 37 insertions(+) diff --git a/airflow/providers/amazon/aws/hooks/athena.py b/airflow/providers/amazon/aws/hooks/athena.py index 2b30fd8b1fddd..9bb58fd7345c5 100644 --- a/airflow/providers/amazon/aws/hooks/athena.py +++ b/airflow/providers/amazon/aws/hooks/athena.py @@ -227,6 +227,30 @@ def poll_query_status(self, query_execution_id: str, max_tries: Optional[int] = sleep(self.sleep_time) return final_query_state + def get_output_location(self, query_execution_id: str) -> str: + """ + Function to get the output location of the query results + in s3 uri format. + + :param query_execution_id: Id of submitted athena query + :type query_execution_id: str + :return: str + """ + output_location = None + if query_execution_id: + response = self.get_conn().get_query_execution(QueryExecutionId=query_execution_id) + + if response: + try: + output_location = response['QueryExecution']['ResultConfiguration']['OutputLocation'] + except KeyError: + self.log.error("Error retrieving OutputLocation") + raise + else: + raise ValueError("Invalid Query execution id") + + return output_location + def stop_query(self, query_execution_id: str) -> Dict: """ Cancel the submitted athena query diff --git a/tests/providers/amazon/aws/hooks/test_athena.py b/tests/providers/amazon/aws/hooks/test_athena.py index 5289dd940e49f..a3cf521383f2c 100644 --- a/tests/providers/amazon/aws/hooks/test_athena.py +++ b/tests/providers/amazon/aws/hooks/test_athena.py @@ -39,6 +39,13 @@ MOCK_QUERY_EXECUTION = {'QueryExecutionId': MOCK_DATA['query_execution_id']} +MOCK_QUERY_EXECUTION_OUTPUT = { + 'QueryExecution': { + 'QueryExecutionId': MOCK_DATA['query_execution_id'], + 'ResultConfiguration': {'OutputLocation': 's3://test_bucket/test.csv'}, + } +} + class TestAWSAthenaHook(unittest.TestCase): def setUp(self): @@ -161,6 +168,12 @@ def test_hook_poll_query_with_timeout(self, mock_conn): mock_conn.return_value.get_query_execution.assert_called_once() assert result == 'RUNNING' + @mock.patch.object(AWSAthenaHook, 'get_conn') + def test_hook_get_output_location(self, mock_conn): + mock_conn.return_value.get_query_execution.return_value = MOCK_QUERY_EXECUTION_OUTPUT + result = self.athena.get_output_location(query_execution_id=MOCK_DATA['query_execution_id']) + assert result == 's3://test_bucket/test.csv' + if __name__ == '__main__': unittest.main()