Skip to content

Commit

Permalink
Added 'workgroup' as optional argument to athena component (kubeflow#…
Browse files Browse the repository at this point in the history
…3254)

* added optional workgroup arg

* Added a slightly more elegant way of handling query exec params

* added args to yml and main()

* added args to yml and main()
  • Loading branch information
Leonard Aukea authored and Jeffwan committed Dec 9, 2020
1 parent 9664584 commit 8eb18b8
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 69 deletions.
4 changes: 3 additions & 1 deletion components/aws/athena/query/component.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ inputs:
- {name: database, description: 'The name of the database.'}
- {name: query, description: 'The SQL query statements to be executed in Athena.'}
- {name: output_path, description: 'The path to the Amazon S3 location where the query results are stored.'}
- {name: workgroup, description: 'Optional argument to provide Athena workgroup'}
outputs:
- {name: output_path, description: 'The path to the S3 bucket containing the query output in CSV format.'}
implementation:
Expand All @@ -30,7 +31,8 @@ implementation:
--region, {inputValue: region},
--database, {inputValue: database},
--query, {inputValue: query},
--output, {inputValue: output_path}
--output, {inputValue: output_path},
--workgroup, {inputValue: workgroup}
]
fileOutputs:
output_path: /output.txt
164 changes: 96 additions & 68 deletions components/aws/athena/query/src/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,73 +20,101 @@


def get_client(region=None):
    """Create a boto3 client for the AWS Athena service.

    Args:
        region: Region name forwarded to boto3 as ``region_name``. ``None``
            (the default) is passed through unchanged.

    Returns:
        The client object produced by ``boto3.client``.
    """
    return boto3.client('athena', region_name=region)

def query(client, query, database, output, *, max_execution=5, poll_interval=5):
    """Run an Athena query and wait for it to reach a terminal state.

    Args:
        client: Athena client exposing ``start_query_execution`` and
            ``get_query_execution``.
        query: The SQL statement(s) to execute.
        database: Name of the database used as the query execution context.
        output: S3 location where the query results are written. When it is
            empty or ``None`` no ``ResultConfiguration`` is sent, instead of
            sending an invalid ``None`` output location.
        max_execution: Maximum number of status polls before giving up.
        poll_interval: Seconds to sleep between two status polls.

    Returns:
        A dict with ``total_bytes_processed`` (the ``DataScannedInBytes``
        statistic) and ``filename`` (last path component of the result
        object's S3 location).

    Raises:
        Exception: If the query fails, is cancelled, or has not finished
            after ``max_execution`` status polls.
    """
    params = {
        'QueryString': query,
        'QueryExecutionContext': {'Database': database},
    }
    if output:
        params['ResultConfiguration'] = {'OutputLocation': output}

    response = client.start_query_execution(**params)

    execution_id = response['QueryExecutionId']
    logging.info('Execution ID: %s', execution_id)

    # Athena queries are asynchronous: poll until a terminal state is seen.
    # Non-terminal states (QUEUED, RUNNING, or a response without a state)
    # simply trigger another poll.
    state = None
    for attempt in range(max_execution):
        response = client.get_query_execution(QueryExecutionId=execution_id)
        state = response.get('QueryExecution', {}).get('Status', {}).get('State')

        if state == 'SUCCEEDED':
            execution = response['QueryExecution']
            s3_path = execution['ResultConfiguration']['OutputLocation']
            # Keep only the object name, i.e. everything after the last '/'.
            filename = s3_path.rsplit('/', 1)[-1]
            logging.info("S3 output file name %s", filename)
            # TODO:(@Jeffwan) Add more details.
            return {
                'total_bytes_processed': execution['Statistics']['DataScannedInBytes'],
                'filename': filename,
            }
        if state == 'FAILED':
            raise Exception('Athena Query Failed')
        if state == 'CANCELLED':
            raise Exception('Athena Query Cancelled')

        # No point in sleeping after the final poll.
        if attempt < max_execution - 1:
            time.sleep(poll_interval)

    raise Exception(
        'Athena query %s did not finish after %d status checks (last state: %s)'
        % (execution_id, max_execution, state)
    )
"""Builds a client to the AWS Athena API."""
client = boto3.client("athena", region_name=region)
return client


def query(
    client,
    query,
    database,
    output,
    workgroup=None,
    *,
    max_execution=5,
    poll_interval=5,
):
    """Execute an AWS Athena query and wait for it to reach a terminal state.

    Args:
        client: Athena client exposing ``start_query_execution`` and
            ``get_query_execution``.
        query: The SQL statement(s) to execute.
        database: Name of the database used as the query execution context.
        output: S3 location where the query results are written. When it is
            empty or ``None`` no ``ResultConfiguration`` is sent, instead of
            sending an invalid ``None`` output location.
        workgroup: Optional Athena workgroup; only sent when truthy.
        max_execution: Maximum number of status polls before giving up.
        poll_interval: Seconds to sleep between two status polls.

    Returns:
        A dict with ``total_bytes_processed`` (the ``DataScannedInBytes``
        statistic) and ``filename`` (last path component of the result
        object's S3 location).

    Raises:
        Exception: If the query fails, is cancelled, or has not finished
            after ``max_execution`` status polls.
    """
    params = {
        "QueryString": query,
        "QueryExecutionContext": {"Database": database},
    }
    if output:
        params["ResultConfiguration"] = {"OutputLocation": output}
    if workgroup:
        params["WorkGroup"] = workgroup

    response = client.start_query_execution(**params)

    execution_id = response["QueryExecutionId"]
    logging.info("Execution ID: %s", execution_id)

    # Athena queries are asynchronous: poll until a terminal state is seen.
    # Non-terminal states (QUEUED, RUNNING, or a response without a state)
    # simply trigger another poll.
    state = None
    for attempt in range(max_execution):
        response = client.get_query_execution(QueryExecutionId=execution_id)
        state = response.get("QueryExecution", {}).get("Status", {}).get("State")

        if state == "SUCCEEDED":
            execution = response["QueryExecution"]
            s3_path = execution["ResultConfiguration"]["OutputLocation"]
            # Keep only the object name, i.e. everything after the last '/'.
            filename = s3_path.rsplit("/", 1)[-1]
            logging.info("S3 output file name %s", filename)
            # TODO:(@Jeffwan) Add more details.
            return {
                "total_bytes_processed": execution["Statistics"]["DataScannedInBytes"],
                "filename": filename,
            }
        if state == "FAILED":
            raise Exception("Athena Query Failed")
        if state == "CANCELLED":
            raise Exception("Athena Query Cancelled")

        # No point in sleeping after the final poll.
        if attempt < max_execution - 1:
            time.sleep(poll_interval)

    raise Exception(
        "Athena query %s did not finish after %d status checks (last state: %s)"
        % (execution_id, max_execution, state)
    )

def main():
    """Command-line entry point: run one Athena query and record the result.

    Reads ``--region``, ``--database``, ``--query`` and ``--output`` from the
    command line, executes the query through ``query`` using a client from
    ``get_client``, adds the output location to the result dict, logs it and
    dumps it as JSON to ``/output.txt``.
    """
    logging.getLogger().setLevel(logging.INFO)

    # Every option is a string; only database and query are mandatory.
    cli_options = (
        ('--region', {'help': 'Athena region.'}),
        ('--database', {'required': True, 'help': 'The name of the database.'}),
        ('--query', {'required': True,
                     'help': 'The SQL query statements to be executed in Athena.'}),
        ('--output', {'required': False,
                      'help': 'The location in Amazon S3 where your query results are stored, such as s3://path/to/query/bucket/'}),
    )
    parser = argparse.ArgumentParser()
    for flag, settings in cli_options:
        parser.add_argument(flag, type=str, **settings)
    args = parser.parse_args()

    athena = get_client(args.region)
    results = query(athena, args.query, args.database, args.output)
    results['output'] = args.output
    logging.info('Athena results: %s', results)

    with open('/output.txt', 'w+') as out_file:
        json.dump(results, out_file)


if __name__ == '__main__':
    main()
def main():
    """Command-line entry point: run one Athena query and record the result.

    Reads ``--region``, ``--database``, ``--query``, ``--output`` and
    ``--workgroup`` from the command line, executes the query through
    ``query`` using a client from ``get_client``, adds the output location to
    the result dict, logs it and dumps it as JSON to ``/output.txt``.
    """
    logging.getLogger().setLevel(logging.INFO)

    # Every option is a string; only database and query are mandatory.
    cli_options = (
        ("--region", {"help": "Athena region."}),
        ("--database", {"required": True, "help": "The name of the database."}),
        (
            "--query",
            {
                "required": True,
                "help": "The SQL query statements to be executed in Athena.",
            },
        ),
        (
            "--output",
            {
                "required": False,
                "help": "The location in Amazon S3 where your query results are stored, such as s3://path/to/query/bucket/",
            },
        ),
        (
            "--workgroup",
            {
                "required": False,
                "help": "Optional argument to provide Athena workgroup",
            },
        ),
    )
    parser = argparse.ArgumentParser()
    for flag, settings in cli_options:
        parser.add_argument(flag, type=str, **settings)
    args = parser.parse_args()

    athena = get_client(args.region)
    results = query(athena, args.query, args.database, args.output, args.workgroup)
    results["output"] = args.output
    logging.info("Athena results: %s", results)

    with open("/output.txt", "w+") as out_file:
        json.dump(results, out_file)


if __name__ == "__main__":
    main()

0 comments on commit 8eb18b8

Please sign in to comment.