Skip to content

Commit

Permalink
Added function in AWSAthenaHook to get s3 output query results file U…
Browse files Browse the repository at this point in the history
…RI (#20124)
  • Loading branch information
subkanthi authored Dec 8, 2021
1 parent 7081831 commit 0e2a0cc
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 0 deletions.
24 changes: 24 additions & 0 deletions airflow/providers/amazon/aws/hooks/athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,30 @@ def poll_query_status(self, query_execution_id: str, max_tries: Optional[int] =
sleep(self.sleep_time)
return final_query_state

def get_output_location(self, query_execution_id: str) -> str:
"""
Function to get the output location of the query results
in s3 uri format.
:param query_execution_id: Id of submitted athena query
:type query_execution_id: str
:return: str
"""
output_location = None
if query_execution_id:
response = self.get_conn().get_query_execution(QueryExecutionId=query_execution_id)

if response:
try:
output_location = response['QueryExecution']['ResultConfiguration']['OutputLocation']
except KeyError:
self.log.error("Error retrieving OutputLocation")
raise
else:
raise ValueError("Invalid Query execution id")

return output_location

def stop_query(self, query_execution_id: str) -> Dict:
"""
Cancel the submitted athena query
Expand Down
13 changes: 13 additions & 0 deletions tests/providers/amazon/aws/hooks/test_athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,13 @@

MOCK_QUERY_EXECUTION = {'QueryExecutionId': MOCK_DATA['query_execution_id']}

MOCK_QUERY_EXECUTION_OUTPUT = {
'QueryExecution': {
'QueryExecutionId': MOCK_DATA['query_execution_id'],
'ResultConfiguration': {'OutputLocation': 's3://test_bucket/test.csv'},
}
}


class TestAWSAthenaHook(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -161,6 +168,12 @@ def test_hook_poll_query_with_timeout(self, mock_conn):
mock_conn.return_value.get_query_execution.assert_called_once()
assert result == 'RUNNING'

@mock.patch.object(AWSAthenaHook, 'get_conn')
def test_hook_get_output_location(self, mock_conn):
mock_conn.return_value.get_query_execution.return_value = MOCK_QUERY_EXECUTION_OUTPUT
result = self.athena.get_output_location(query_execution_id=MOCK_DATA['query_execution_id'])
assert result == 's3://test_bucket/test.csv'


if __name__ == '__main__':
unittest.main()

0 comments on commit 0e2a0cc

Please sign in to comment.