Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added function in AWSAthenaHook to get s3 output query results file URI #20124

Merged
merged 10 commits into from
Dec 8, 2021
24 changes: 24 additions & 0 deletions airflow/providers/amazon/aws/hooks/athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,30 @@ def poll_query_status(self, query_execution_id: str, max_tries: Optional[int] =
sleep(self.sleep_time)
return final_query_state

def get_output_location(self, query_execution_id: str) -> str:
"""
Function to get the output location of the query results
in s3 uri format.

:param query_execution_id: Id of submitted athena query
:type query_execution_id: str
:return: str
"""
output_location = None
if query_execution_id:
response = self.get_conn().get_query_execution(QueryExecutionId=query_execution_id)

if response:
try:
output_location = response['QueryExecution']['ResultConfiguration']['OutputLocation']
except KeyError:
self.log.error("Error retrieving OutputLocation")
subkanthi marked this conversation as resolved.
Show resolved Hide resolved
raise
else:
raise ValueError("Invalid Query execution id")

return output_location

def stop_query(self, query_execution_id: str) -> Dict:
"""
Cancel the submitted athena query
Expand Down
13 changes: 13 additions & 0 deletions tests/providers/amazon/aws/hooks/test_athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,13 @@

MOCK_QUERY_EXECUTION = {'QueryExecutionId': MOCK_DATA['query_execution_id']}

MOCK_QUERY_EXECUTION_OUTPUT = {
'QueryExecution': {
'QueryExecutionId': MOCK_DATA['query_execution_id'],
'ResultConfiguration': {'OutputLocation': 's3://test_bucket/test.csv'},
}
}


class TestAWSAthenaHook(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -161,6 +168,12 @@ def test_hook_poll_query_with_timeout(self, mock_conn):
mock_conn.return_value.get_query_execution.assert_called_once()
assert result == 'RUNNING'

@mock.patch.object(AWSAthenaHook, 'get_conn')
def test_hook_get_output_location(self, mock_conn):
mock_conn.return_value.get_query_execution.return_value = MOCK_QUERY_EXECUTION_OUTPUT
result = self.athena.get_output_location(query_execution_id=MOCK_DATA['query_execution_id'])
assert result == 's3://test_bucket/test.csv'


if __name__ == '__main__':
unittest.main()