# unload.py
'''
Created on May 11, 2017
@author: Devin
'''
import json
import os
import psycopg2
import argparse
def _simple_sanitize(s):
return s.split(';')[0]
def run(config, tablename, file_path, schema_name=None, sql_file=None, range_col=None, range_start=None, range_end=None):
    """Build and execute an UNLOAD statement that exports a table with a header row.

    The column list is read from ``information_schema.columns``. The generated
    statement selects a literal header row (``rn = 1``) UNION ALL the table
    rows cast to text (``rn = 2``), ordered by ``rn`` so the header comes
    first.

    Args:
        config: Mapping with ``'db'`` (keyword arguments for
            ``psycopg2.connect``), ``'aws_access_key_id'``,
            ``'aws_secret_access_key'`` and optionally ``'unload_options'``
            (list of strings appended, one per line, to the statement).
        tablename: Table to unload.
        file_path: Target placed in the ``TO`` clause; falls back to
            ``tablename`` when empty.
        schema_name: Optional schema used both to filter the column lookup
            and to qualify the table.
        sql_file: Text placed verbatim after the table name (despite the
            parameter name, this is the clause text itself, not a path).
            Only used when no complete range is supplied.
        range_col: Column for a ``BETWEEN`` filter.
        range_start: Lower bound of the ``BETWEEN`` filter.
        range_end: Upper bound of the ``BETWEEN`` filter.

    Raises:
        ValueError: If no columns are found for the requested table.
    """
    if not file_path:
        file_path = tablename
    unload_options = '\n'.join(config.get('unload_options', []))

    conn = psycopg2.connect(**config['db'])
    try:
        cursor = conn.cursor()

        # Identifiers are passed as bound parameters rather than formatted
        # into the statement, so quotes in the names cannot alter the query.
        column_query = ("SELECT column_name, data_type FROM information_schema.columns "
                        "WHERE table_name = %s")
        column_params = [tablename]
        if schema_name:
            column_query += " AND table_schema = %s"
            column_params.append(schema_name)
        column_query += " ORDER BY ordinal_position"
        cursor.execute(column_query, column_params)
        res = cursor.fetchall()
        if not res:
            # Without columns the UNLOAD text would be malformed SQL; fail
            # before anything carrying credentials is sent.
            raise ValueError('No columns found for table {}{}'.format(
                '{}.'.format(schema_name) if schema_name else '', tablename))

        columns = [row[0] for row in res]
        cast_columns = []
        for name, data_type in res:
            if 'boolean' in data_type:
                # Boolean is a special case; cannot be casted to text so it
                # needs to be handled differently.
                # NOTE(review): NULL booleans fall into the ELSE branch and
                # are exported as 'false' - confirm this is intended.
                cast_columns.append("CASE {} WHEN 1 THEN \\\'true\\\' ELSE \\\'false\\\'::text END".format(name))
            else:
                cast_columns.append("{}::text".format(name))

        # Quotes are backslash-escaped because this text ends up inside the
        # single-quoted argument of UNLOAD ('...').
        header_str = ', '.join(
            "\\\'" + name + "\\\' as " + name.split(':')[0] for name in columns)
        column_str = ", ".join(columns)
        cast_columns_str = ", ".join(cast_columns)

        # Compare bounds against None/'' instead of truthiness so that a
        # numeric bound of 0 is still honoured.
        has_range = bool(range_col) and all(
            bound is not None and bound != '' for bound in (range_start, range_end))
        where_clause = ""
        if has_range:
            # NOTE(review): range values are interpolated as text; callers
            # must not pass untrusted input here.
            where_clause = "WHERE {} BETWEEN \\\'{}\\\' AND \\\'{}\\\'".format(
                range_col, range_start, range_end)
        elif sql_file:
            where_clause = sql_file

        query_template = """
        UNLOAD (\'SELECT {columns} FROM (
        SELECT 1 as rn, {header}
        UNION ALL
        (SELECT 2 as rn, {casts}
        FROM {schema}{table} {where})) ORDER BY rn\')
        TO \'{path}\'
        CREDENTIALS 'aws_access_key_id={access_key};aws_secret_access_key={secret_key}'
        {options}
        """
        fields = dict(
            columns=column_str,
            header=header_str,
            casts=cast_columns_str,
            schema='{}.'.format(schema_name) if schema_name else '',
            table=tablename,
            where=where_clause,
            path=file_path,
            options=unload_options,
        )
        query = query_template.format(
            access_key=config['aws_access_key_id'],
            secret_key=config['aws_secret_access_key'],
            **fields)
        # Log a copy with the credentials masked so secrets never reach
        # stdout or captured logs.
        printable_query = query_template.format(
            access_key='****', secret_key='****', **fields)

        print("The following UNLOAD query is being run: \n" + printable_query)
        cursor.execute(query)
        print('Completed write to {}'.format(file_path))
    finally:
        # Always release the connection, including when a query fails.
        conn.close()
if __name__ == '__main__':
    # Command-line entry point: load config.json from the script directory,
    # parse the options, and hand them to run().
    script_dir = os.path.dirname(os.path.realpath(__file__))
    with open(os.path.join(script_dir, 'config.json'), 'r') as f:
        config = json.load(f)

    parser = argparse.ArgumentParser()
    # The table name is mandatory: without it run() would look up and
    # unload a table literally named "None".
    parser.add_argument('-t', required=True, help='Table name')
    parser.add_argument('-c', help='Schema name')
    parser.add_argument('-f', help='Desired S3 file path')
    # The value is opened as a file below, so the help text says so.
    parser.add_argument('-s', help='File (relative to this script) containing the SQL WHERE clause')
    parser.add_argument('-r', help='Range column')
    parser.add_argument('-r1', help='Range start')
    parser.add_argument('-r2', help='Range end')
    raw_args = parser.parse_args()

    if raw_args.s:
        # Replace the file name with the file's contents.
        with open(os.path.join(script_dir, raw_args.s), 'r') as f:
            raw_args.s = f.read()

    # Empty or missing options become None; everything else is cut at the
    # first ';'.
    args = {
        key: (_simple_sanitize(value) if value else None)
        for key, value in vars(raw_args).items()
    }
    run(config, args['t'], args['f'], args['c'], args['s'], args['r'], args['r1'], args['r2'])