base.py
import re
from sqlalchemy import types as sqltypes
from sqlalchemy.connectors.pyodbc import PyODBCConnector
from sqlalchemy.dialects.postgresql.base import PGDialect
from sqlalchemy.engine import reflection


class VerticaDialect(PyODBCConnector, PGDialect):
    """ Vertica Dialect using a pyodbc connection and PGDialect """

    ischema_names = {
        'BINARY': sqltypes.BLOB,
        'VARBINARY': sqltypes.BLOB,
        'BYTEA': sqltypes.BLOB,
        'RAW': sqltypes.BLOB,

        'BOOLEAN': sqltypes.BOOLEAN,

        'CHAR': sqltypes.CHAR,
        'VARCHAR': sqltypes.VARCHAR,
        'VARCHAR2': sqltypes.VARCHAR,

        'DATE': sqltypes.DATE,
        'DATETIME': sqltypes.DATETIME,
        'SMALLDATETIME': sqltypes.DATETIME,
        'TIME': sqltypes.TIME,
        'TIME WITH TIMEZONE': sqltypes.TIME(timezone=True),
        'TIMESTAMP': sqltypes.TIMESTAMP,
        'TIMESTAMP WITH TIMEZONE': sqltypes.TIMESTAMP(timezone=True),

        # Not supported yet:
        # INTERVAL

        # All the same internal representation
        'FLOAT': sqltypes.FLOAT,
        'FLOAT8': sqltypes.FLOAT,
        'DOUBLE': sqltypes.FLOAT,
        'REAL': sqltypes.FLOAT,

        'INT': sqltypes.INTEGER,
        'INTEGER': sqltypes.INTEGER,
        'INT8': sqltypes.INTEGER,
        'BIGINT': sqltypes.INTEGER,
        'SMALLINT': sqltypes.INTEGER,
        'TINYINT': sqltypes.INTEGER,

        'NUMERIC': sqltypes.NUMERIC,
        'DECIMAL': sqltypes.NUMERIC,
        'NUMBER': sqltypes.NUMERIC,
        'MONEY': sqltypes.NUMERIC,
    }

    name = 'vertica'
    pyodbc_driver_name = 'Vertica'

    def has_schema(self, connection, schema):
        query = ("SELECT EXISTS (SELECT schema_name FROM v_catalog.schemata "
                 "WHERE schema_name='%s')") % (schema,)
        rs = connection.execute(query)
        return bool(rs.scalar())

    def has_table(self, connection, table_name, schema=None):
        if schema is None:
            schema = self._get_default_schema_name(connection)
        query = ("SELECT EXISTS ("
                 "SELECT table_name FROM v_catalog.all_tables "
                 "WHERE schema_name='%s' AND "
                 "table_name='%s'"
                 ")") % (schema, table_name)
        rs = connection.execute(query)
        return bool(rs.scalar())

    def has_sequence(self, connection, sequence_name, schema=None):
        if schema is None:
            schema = self._get_default_schema_name(connection)
        query = ("SELECT EXISTS ("
                 "SELECT sequence_name FROM v_catalog.sequences "
                 "WHERE sequence_schema='%s' AND "
                 "sequence_name='%s'"
                 ")") % (schema, sequence_name)
        rs = connection.execute(query)
        return bool(rs.scalar())

    def has_type(self, connection, type_name, schema=None):
        query = ("SELECT EXISTS ("
                 "SELECT type_name FROM v_catalog.types "
                 "WHERE type_name='%s'"
                 ")") % (type_name,)
        rs = connection.execute(query)
        return bool(rs.scalar())

    def _get_server_version_info(self, connection):
        v = connection.scalar("select version()")
        m = re.match(
            r'.*Vertica Analytic Database '
            r'v(\d+)\.(\d+)\.(\d+).*',
            v)
        if not m:
            raise AssertionError(
                "Could not determine version from string '%s'" % v)
        return tuple([int(x) for x in m.group(1, 2, 3) if x is not None])

    def _get_default_schema_name(self, connection):
        return connection.scalar("select current_schema()")

    @reflection.cache
    def get_schema_names(self, connection, **kw):
        query = "SELECT schema_name FROM v_catalog.schemata"
        rs = connection.execute(query)
        return [row[0] for row in rs if not row[0].startswith('v_')]

    @reflection.cache
    def get_table_names(self, connection, schema=None, **kw):
        s = ["SELECT table_name FROM v_catalog.tables"]
        if schema is not None:
            s.append("WHERE table_schema = '%s'" % (schema,))
        s.append("ORDER BY table_schema, table_name")
        rs = connection.execute(' '.join(s))
        return [row[0] for row in rs]

    @reflection.cache
    def get_view_names(self, connection, schema=None, **kw):
        s = ["SELECT table_name FROM v_catalog.views"]
        if schema is not None:
            s.append("WHERE table_schema = '%s'" % (schema,))
        s.append("ORDER BY table_schema, table_name")
        rs = connection.execute(' '.join(s))
        return [row[0] for row in rs]

    @reflection.cache
    def get_columns(self, connection, table_name, schema=None, **kw):
        s = ("SELECT * FROM v_catalog.columns "
             "WHERE table_name = '%s' ") % (table_name,)
        spk = ("SELECT column_name FROM v_catalog.primary_keys "
               "WHERE table_name = '%s' "
               "AND constraint_type = 'p'") % (table_name,)
        if schema is not None:
            _pred = lambda p: ("%s AND table_schema = '%s'" % (p, schema))
            s = _pred(s)
            spk = _pred(spk)

        pk_columns = [x[0] for x in connection.execute(spk)]

        columns = []
        for row in connection.execute(s):
            name = row.column_name
            dtype = row.data_type.upper()
            # Strip any length/precision suffix, e.g. VARCHAR(80) -> VARCHAR.
            if '(' in dtype:
                dtype = dtype.split('(')[0]
            coltype = self.ischema_names[dtype]
            primary_key = name in pk_columns
            default = row.column_default
            nullable = row.is_nullable

            columns.append({
                'name': name,
                'type': coltype,
                'nullable': nullable,
                'default': default,
                'primary_key': primary_key
            })
        return columns

    # Constraints are enforced on selects, but returning nothing for these
    # methods allows table introspection to work.
    def get_pk_constraint(self, bind, table_name, schema, **kw):
        return {'constrained_columns': [], 'name': 'undefined'}

    def get_foreign_keys(self, connection, table_name, schema, **kw):
        return []

    def get_indexes(self, connection, table_name, schema, **kw):
        return []

    # Disable index creation, since that's not a thing in Vertica.
    def visit_create_index(self, create):
        return None
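

# Example usage: a minimal sketch of wiring this dialect into SQLAlchemy and
# reflecting some metadata through it. The import path "vertica_dialect.base"
# and the ODBC DSN "VerticaDSN" are assumptions; substitute the module path
# this file actually lives at and a DSN configured for the Vertica ODBC driver.
if __name__ == '__main__':
    from sqlalchemy import create_engine, inspect
    from sqlalchemy.dialects import registry

    # Make "vertica://" URLs resolve to this dialect class.
    registry.register('vertica', 'vertica_dialect.base', 'VerticaDialect')

    # With the pyodbc connector, a URL of the form user:password@DSN (with no
    # database name) is turned into an ODBC "dsn=..." connection string.
    engine = create_engine('vertica://dbadmin:secret@VerticaDSN')

    # The Inspector drives the reflection methods defined above
    # (get_schema_names, get_table_names, get_columns, ...).
    inspector = inspect(engine)
    print(inspector.get_schema_names())
    print(inspector.get_table_names(schema='public'))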