-
Notifications
You must be signed in to change notification settings - Fork 34
/
Copy pathdb.py
424 lines (347 loc) · 12.8 KB
/
db.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import functools
import json
import logging
from collections import namedtuple, OrderedDict
from enum import Enum
from pprint import pformat

import requests
from six import string_types
from six.moves.urllib import parse

from pinotdb import exceptions
# Module-level logger; Cursor.execute writes its debug and trace output here.
logger = logging.getLogger(__name__)
class Type(Enum):
    """Type codes reported as the ``type_code`` (second item) of each
    ``Cursor.description`` entry."""

    STRING = 1
    NUMBER = 2
    BOOLEAN = 3
def connect(*args, **kwargs):
    """Create a :class:`Connection` to a Pinot database.

    The positional and keyword arguments are stored on the connection and
    forwarded unchanged to every :class:`Cursor` it creates.

    >>> conn = connect('localhost', 8099)
    >>> curs = conn.cursor()
    """
    connection = Connection(*args, **kwargs)
    return connection
def check_closed(f):
    """Decorator that checks if connection/cursor is closed.

    The wrapped method raises ``exceptions.Error`` instead of running when
    ``self.closed`` is truthy.
    """

    # functools.wraps keeps the decorated method's __name__, __doc__, etc.,
    # so help()/introspection on fetchone, execute, ... stay meaningful.
    @functools.wraps(f)
    def g(self, *args, **kwargs):
        if self.closed:
            raise exceptions.Error(f"{self.__class__.__name__} already closed")
        return f(self, *args, **kwargs)

    return g
def check_result(f):
    """Decorator that checks if the cursor has results from `execute`.

    The wrapped method raises ``exceptions.Error`` while ``self._results`` is
    still ``None``, i.e. before any successful ``execute`` has filled it.
    """

    # functools.wraps keeps the decorated method's name and docstring.
    @functools.wraps(f)
    def g(self, *args, **kwargs):
        if self._results is None:
            raise exceptions.Error("Called before `execute`")
        return f(self, *args, **kwargs)

    return g
def get_description_from_types(column_names, types):
    """Build a DB-API ``Cursor.description`` list.

    :param column_names: column names, in result order.
    :param types: objects with a ``code`` attribute, one per column.
    :returns: one 7-tuple per column pair taken by ``zip``; only ``name``
        and ``type_code`` are filled in, the rest are ``None``.
    """
    # Trailing slots: display_size, internal_size, precision, scale, null_ok.
    unknown_fields = (None,) * 5
    description = []
    for column_name, type_info in zip(column_names, types):
        description.append((column_name, type_info.code) + unknown_fields)
    return description
# Per-column type info derived from a Pinot column data type string:
#   code             -- the Type member reported in Cursor.description
#   is_iterable      -- True when the data type name contains "_ARRAY"
#   coerce_to_string -- True when the column's values are json.dumps'ed
TypeCodeAndValue = namedtuple("TypeCodeAndValue", "code is_iterable coerce_to_string")
def get_types_from_column_data_types(column_data_types):
    """Map Pinot column data type names to ``TypeCodeAndValue`` entries.

    The base type is the text before the first ``_`` (so ``INT_ARRAY`` has
    base ``INT``); a name containing ``_ARRAY`` is flagged as iterable.

    - INT / LONG / FLOAT / DOUBLE -> ``Type.NUMBER``, no coercion
    - STRING / BYTES               -> ``Type.STRING``, no coercion
    - any other base type          -> ``Type.STRING`` with coerce_to_string
    """
    numeric_bases = ("INT", "LONG", "FLOAT", "DOUBLE")
    text_bases = ("STRING", "BYTES")

    def classify(data_type_name):
        base = data_type_name.split("_")[0]
        iterable = "_ARRAY" in data_type_name
        if base in numeric_bases:
            return TypeCodeAndValue(Type.NUMBER, iterable, False)
        if base in text_bases:
            return TypeCodeAndValue(Type.STRING, iterable, False)
        # Unrecognised base types are reported as strings and later
        # serialised with json.dumps by convert_result_if_required.
        return TypeCodeAndValue(Type.STRING, iterable, True)

    return [classify(name) for name in column_data_types]
def get_group_by_column_names(aggregation_results):
    """Return the group-by column names shared by all aggregation results.

    Each entry is a mapping that may carry a ``"groupByColumns"`` list
    (missing means ``[]``) and a ``"function"`` name used only in the error
    message. The first entry fixes the expected columns; every later entry
    must match it exactly. Previously an empty first entry let a later
    non-empty one through unchecked, while the reverse order raised.

    :returns: a new list of column names (``[]`` for no entries).
    :raises exceptions.DatabaseError: if two entries disagree.
    """
    expected = None  # sentinel: no entry seen yet
    for metric in aggregation_results:
        cols = metric.get("groupByColumns", [])
        if expected is None:
            expected = cols[:]
        elif cols != expected:
            metric_name = metric.get("function", "noname")
            raise exceptions.DatabaseError(
                f"Cols for metric {metric_name}: {cols} differ from other columns {expected}"
            )
    return expected if expected is not None else []
def is_iterable(value):
    """Return True if ``iter(value)`` succeeds, False on ``TypeError``."""
    try:
        iter(value)
    except TypeError:
        return False
    return True
class Connection(object):
    """Connection to a Pinot database.

    The connection only records the arguments it was created with. Each
    :meth:`cursor` call builds a new :class:`Cursor` from those arguments,
    and :meth:`close` closes every cursor handed out.
    """

    def __init__(self, *args, **kwargs):
        # NOTE(review): _debug is stored here but not read by this class.
        self._debug = kwargs.get("debug", False)
        self._args = args
        self._kwargs = kwargs
        self.closed = False
        self.cursors = []

    @check_closed
    def close(self):
        """Close the connection and every cursor created from it."""
        self.closed = True
        for cursor in self.cursors:
            try:
                cursor.close()
            except exceptions.Error:
                pass  # cursor was already closed by its user

    @check_closed
    def commit(self):
        """
        Commit any pending transaction to the database.

        Not supported; this is a no-op.
        """
        pass

    @check_closed
    def cursor(self):
        """Return a new Cursor Object using the connection."""
        cursor = Cursor(*self._args, **self._kwargs)
        self.cursors.append(cursor)
        return cursor

    @check_closed
    def execute(self, operation, parameters=None):
        """Run ``operation`` on a fresh cursor and return what
        ``Cursor.execute`` returns (the cursor itself)."""
        cursor = self.cursor()
        return cursor.execute(operation, parameters)

    def __enter__(self):
        # Yields a new cursor, not the connection:
        # ``with connect(...) as curs:`` binds a Cursor.
        return self.cursor()

    def __exit__(self, *exc):
        # close() raises on an already-closed connection; without this guard,
        # closing inside the with-body would make __exit__ raise and replace
        # any exception already propagating out of the block.
        if not self.closed:
            self.close()
def convert_result_if_required(types, rows):
    """Serialise, in place, the columns whose type asks for string coercion.

    For every column index whose entry in ``types`` has a truthy
    ``coerce_to_string``, each row's value at that index is replaced by its
    ``json.dumps`` form. ``rows`` is mutated and also returned.
    """
    coerced_columns = [idx for idx, col_type in enumerate(types) if col_type.coerce_to_string]
    # Column-major order, matching the original pass over the data.
    for idx in coerced_columns:
        for row in rows:
            row[idx] = json.dumps(row[idx])
    return rows
class Cursor(object):
    """Connection cursor.

    Each cursor owns a ``requests.Session`` configured from the constructor
    arguments and POSTs every query as JSON (``{"sql": ...}``) to
    ``self.url``. Result rows are buffered in memory and handed out by the
    ``fetch*`` methods and by iteration.
    """

    def __init__(
        self,
        host,
        port=8099,
        scheme="http",
        path="/query/sql",
        username=None,
        password=None,
        verify_ssl=True,
        extra_request_headers="",
        debug=False,
        preserve_types=False,
        ignore_exception_error_codes="",
        acceptable_respond_fraction=-1,
    ):
        """
        :param username: HTTP basic-auth user; basic auth is only enabled
            when ``username`` or ``password`` is given.
        :param extra_request_headers: comma-separated ``name=value`` pairs
            added to every request. Only the first ``=`` of a pair separates
            name from value, so values may contain ``=``; they may not
            contain ``,``.
        :param preserve_types: append ``OPTION(preserveType='true')`` to
            every query.
        :param ignore_exception_error_codes: comma-separated integers;
            entries of the response's ``exceptions`` list whose
            ``errorCode`` is one of these are not raised.
        :param acceptable_respond_fraction: see
            :meth:`check_sufficient_responded`.
        """
        if path == "query":
            path = "query/sql"
        self.url = parse.urlunparse((scheme, f"{host}:{port}", path, None, None, None))

        # This read/write attribute specifies the number of rows to fetch at a
        # time with .fetchmany(). It defaults to 1 meaning to fetch a single
        # row at a time.
        self.arraysize = 1

        self.closed = False

        # these are updated only after a query
        self.description = None
        self.rowcount = -1
        self._results = None
        self._debug = debug
        self._preserve_types = preserve_types
        self.acceptable_respond_fraction = acceptable_respond_fraction
        if ignore_exception_error_codes:
            self._ignore_exception_error_codes = {
                int(x) for x in ignore_exception_error_codes.split(",")
            }
        else:
            self._ignore_exception_error_codes = set()

        self.session = requests.Session()
        # Only enable basic auth when credentials were supplied. A
        # (None, None) tuple is truthy, so requests would send
        # "Basic None:None" and overwrite any Authorization header passed
        # through extra_request_headers.
        if username is not None or password is not None:
            self.session.auth = (username, password)
        self.session.verify = verify_ssl
        self.session.headers.update({"Content-Type": "application/json"})
        self.session.headers.update(self._parse_extra_headers(extra_request_headers))

    @staticmethod
    def _parse_extra_headers(extra_request_headers):
        """Parse ``"k1=v1,k2=v2"`` into a header dict; empty input -> ``{}``."""
        headers = {}
        if extra_request_headers:
            for header in extra_request_headers.split(","):
                # Split on the first "=" only: values such as base64 tokens
                # ("Bearer abc==") contain "=" themselves.
                name, value = header.split("=", 1)
                headers[name.strip()] = value.strip()
        return headers

    @check_closed
    def close(self):
        """Close the cursor and its HTTP session."""
        self.session.close()
        self.closed = True

    def is_valid_exception(self, e):
        """Return True if the response exception ``e`` (a dict) should be
        raised, i.e. it has no ``errorCode`` or its code is not ignored."""
        if "errorCode" not in e:
            return True
        else:
            return e["errorCode"] not in self._ignore_exception_error_codes

    def check_sufficient_responded(self, query, queried, responded):
        """Raise if too few servers answered, per ``acceptable_respond_fraction``.

        - ``0``: the check is disabled.
        - a negative ``queried`` or ``responded`` (missing from the payload)
          always fails when the check is enabled.
        - ``<= -1``: every queried server must respond.
        - strictly between 0 and 1: at least ``int(fraction * queried)``.
        - anything else: the value itself is the required server count.

        :raises exceptions.DatabaseError: when the requirement is not met.
        """
        fraction = self.acceptable_respond_fraction
        if fraction == 0:
            return
        if queried < 0 or responded < 0:
            responded = -1
            needed = -1
        elif fraction <= -1:
            needed = queried
        elif fraction > 0 and fraction < 1:
            needed = int(fraction * queried)
        else:
            needed = fraction
        if responded < 0 or responded < needed:
            raise exceptions.DatabaseError(
                f"Query\n\n{query} timed out: Out of {queried}, only"
                f" {responded} responded, while needed was {needed}"
            )

    @check_closed
    def execute(self, operation, parameters=None):
        """Run ``operation`` and buffer its rows; return the cursor itself.

        ``parameters`` are interpolated with ``apply_parameters``. On success
        ``description`` and the row buffer are replaced.

        :raises exceptions.DatabaseError: if the response is not JSON, too few
            servers responded, the response lists non-ignored exceptions, or
            the result table has no column names.
        :raises exceptions.ProgrammingError: on a non-200 HTTP status.
        """
        query = apply_parameters(operation, parameters or {})
        if self._preserve_types:
            query += " OPTION(preserveType='true')"

        request_body = {"sql": query}
        if self._debug:
            logger.info(
                f"Submitting the pinot query to {self.url}:\n{query}\n{pformat(request_body)}, with {self.session.headers}"
            )

        r = self.session.post(
            self.url,
            json=request_body,
            verify=self.session.verify,
            auth=self.session.auth,
        )
        if r.encoding is None:
            r.encoding = "utf-8"

        try:
            payload = r.json()
        except Exception as e:
            raise exceptions.DatabaseError(
                f"Error when querying {query} from {self.url}, raw response is:\n{r.text}"
            ) from e

        if self._debug:
            # Log r.status_code directly: `not r` tests Response.ok, so the
            # old expression logged 0 for every error response.
            logger.info(
                f"Got the payload of type {type(payload)} with the status code {r.status_code}:\n{payload}"
            )

        num_servers_responded = payload.get("numServersResponded", -1)
        num_servers_queried = payload.get("numServersQueried", -1)
        self.check_sufficient_responded(
            query, num_servers_queried, num_servers_responded
        )

        # raise any error messages
        if r.status_code != 200:
            msg = f"Query\n\n{query}\n\nreturned an error: {r.status_code}\nFull response is {pformat(payload)}"
            raise exceptions.ProgrammingError(msg)

        query_exceptions = [
            e for e in payload.get("exceptions", []) if self.is_valid_exception(e)
        ]
        if query_exceptions:
            msg = "\n".join(pformat(exception) for exception in query_exceptions)
            raise exceptions.DatabaseError(msg)

        rows = []  # array of array, where inner array is array of column values
        column_names = []  # column names, such that len(column_names) == len(rows[0])
        column_data_types = []  # column data types 1:1 mapping to column_names

        if "resultTable" in payload:
            results = payload["resultTable"]
            # A missing dataSchema now yields the DatabaseError below instead
            # of an AttributeError.
            data_schema = results.get("dataSchema") or {}
            column_names = data_schema.get("columnNames")
            column_data_types = data_schema.get("columnDataTypes")
            if column_names:
                # Treat a missing or null "rows" entry as zero rows.
                rows = results.get("rows") or []
            else:
                raise exceptions.DatabaseError(
                    f"Expected columns and results in resultTable, but got {pformat(results)} instead"
                )

        logger.debug("Got the rows as a type %s of size %d", type(rows), len(rows))
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(pformat(rows))

        self.description = None
        self._results = []
        if column_data_types:
            types = get_types_from_column_data_types(column_data_types)
            if self._debug:
                logger.info(
                    f"Column_names are {pformat(column_names)}, Column_data_types are {pformat(column_data_types)}, Types are {pformat(types)}"
                )
            self._results = convert_result_if_required(types, rows)
            self.description = get_description_from_types(column_names, types)

        return self

    @check_closed
    def executemany(self, operation, seq_of_parameters=None):
        raise exceptions.NotSupportedError(
            "`executemany` is not supported, use `execute` instead"
        )

    @check_result
    @check_closed
    def fetchone(self):
        """
        Fetch the next row of a query result set, returning a single sequence,
        or `None` when no more data is available.
        """
        try:
            return self._results.pop(0)
        except IndexError:
            return None

    @check_result
    @check_closed
    def fetchmany(self, size=None):
        """
        Fetch the next set of rows of a query result, returning a sequence of
        sequences (e.g. a list of tuples). An empty sequence is returned when
        no more rows are available.
        """
        size = size or self.arraysize
        output, self._results = self._results[:size], self._results[size:]
        return output

    @check_result
    @check_closed
    def fetchall(self):
        """
        Fetch all (remaining) rows of a query result, returning them as a
        sequence of sequences (e.g. a list of tuples). Note that the cursor's
        arraysize attribute can affect the performance of this operation.
        """
        # Hand over the whole remaining buffer at once. Draining it through
        # fetchone() costs O(n) per row (list.pop(0)), i.e. O(n^2) overall.
        output, self._results = self._results, []
        return output

    @check_closed
    def setinputsizes(self, sizes):
        # not supported
        pass

    @check_closed
    def setoutputsizes(self, sizes):
        # not supported
        pass

    @check_closed
    def setoutputsize(self, size, column=None):
        # DB-API 2.0 spelling of setoutputsizes; not supported either
        pass

    @check_closed
    def __iter__(self):
        return self

    @check_closed
    def __next__(self):
        output = self.fetchone()
        if output is None:
            raise StopIteration
        return output

    next = __next__
def apply_parameters(operation, parameters):
    """Interpolate ``pyformat``-style parameters into ``operation``.

    Each value in the ``parameters`` mapping is passed through ``escape`` and
    substituted with the ``%`` operator. Placeholders are therefore written
    ``%(name)s``, and a literal percent sign must be written ``%%``.
    """
    escaped = {}
    for name, raw_value in parameters.items():
        escaped[name] = escape(raw_value)
    return operation % escaped
def escape(value):
    """Render a Python value as a Pinot SQL literal for ``apply_parameters``.

    - ``"*"`` is passed through unchanged.
    - Strings are single-quoted, with embedded single quotes doubled.
    - Booleans become ``TRUE`` / ``FALSE``.
    - Ints and floats are returned as-is; ``%`` formatting stringifies them.
    - Lists/tuples become a comma-separated list of escaped elements
      (e.g. for ``IN (%(ids)s)``).

    :raises exceptions.ProgrammingError: for any other type. Previously such
        values (including ``None``) silently became the text ``None`` in
        the query.
    """
    if value == "*":
        return value
    # `str` rather than six.string_types: the module already requires
    # Python 3 (f-strings), where string_types is just (str,).
    if isinstance(value, str):
        return "'{}'".format(value.replace("'", "''"))
    # bool must be tested before int: bool subclasses int, so the numeric
    # branch would otherwise emit Python's "True"/"False".
    if isinstance(value, bool):
        return "TRUE" if value else "FALSE"
    if isinstance(value, (int, float)):
        return value
    if isinstance(value, (list, tuple)):
        # str() so numeric elements (returned unchanged) can be joined.
        return ", ".join(str(escape(element)) for element in value)
    raise exceptions.ProgrammingError(
        f"Unable to escape value of type {type(value).__name__}: {value!r}"
    )