From be7893ff8a214085a76c84a1f99e1bf8e218f063 Mon Sep 17 00:00:00 2001 From: David Wallin Date: Sat, 1 Apr 2017 02:42:00 +0200 Subject: [PATCH 1/2] Initial support for SQLAlchemy This is cleaned up code developed at tobii.com (in collaboration with knowit.se), which we use to connect from Superset (http://airbnb.io/superset/, https://github.com/airbnb/superset/pull/2531) to an Athena DB. Usage: athena://<access_key>:<secret_key>@athena.us-east-1.amazonaws.com/?region_name=<region_name>&s3_staging_dir=s3%3A//<bucket>/<path> --- pyathenajdbc/sqlalchemy_athena.py | 231 ++++++++++++++++++++++++++++++ setup.py | 13 +- 2 files changed, 243 insertions(+), 1 deletion(-) create mode 100644 pyathenajdbc/sqlalchemy_athena.py diff --git a/pyathenajdbc/sqlalchemy_athena.py b/pyathenajdbc/sqlalchemy_athena.py new file mode 100644 index 0000000..406c4bc --- /dev/null +++ b/pyathenajdbc/sqlalchemy_athena.py @@ -0,0 +1,231 @@ +"""Integration between SQLAlchemy and Athena. + +Some code based on +https://github.com/zzzeek/sqlalchemy/blob/rel_0_5/lib/sqlalchemy/databases/sqlite.py +which is released under the MIT license. 
+""" + +from __future__ import absolute_import +from __future__ import unicode_literals +import re +#from distutils.version import StrictVersion +#from pyhive import presto +#from pyhive.common import UniversalSet +from sqlalchemy import exc +from sqlalchemy import types +#from sqlalchemy import util +from sqlalchemy.engine.default import DefaultDialect +from sqlalchemy.sql.compiler import IdentifierPreparer +from pyathenajdbc import error +from pyathenajdbc.converter import JDBCTypeConverter + +# try: +# from sqlalchemy.sql.compiler import SQLCompiler +# except ImportError: +# from sqlalchemy.sql.compiler import DefaultCompiler as SQLCompiler + + +class UniversalSet(object): + def __contains__(self, item): + return True + + +class AthenaIdentifierPreparer(IdentifierPreparer): + # Just quote everything to make things simpler / easier to upgrade + reserved_words = UniversalSet() + +_type_map = { + 'NULL': types.NullType, + 'BOOLEAN': types.Boolean, + 'TINYINT': types.Integer, + 'SMALLINT': types.Integer, + 'BIGINT': types.BigInteger, + 'INTEGER': types.Integer, + 'REAL': types.Float, + 'DOUBLE': types.Float, + 'FLOAT': types.Float, + 'CHAR': types.String, + 'NCHAR': types.String, + 'VARCHAR': types.String, + 'NVARCHAR': types.String, + 'LONGVARCHAR': types.String, + 'LONGNVARCHAR': types.String, + 'DATE': types.DATE, + 'TIMESTAMP': types.TIMESTAMP, + 'TIMESTAMP_WITH_TIMEZONE': types.TIMESTAMP, + 'ARRAY': types.ARRAY, + 'DECIMAL': types.DECIMAL, + 'NUMERIC': types.Numeric, + 'BINARY': types.Binary, + 'VARBINARY': types.Binary, + 'LONGVARBINARY': types.Binary, + # TODO Converter impl + # 'TIME': ???, + # 'BIT': ???, + # 'CLOB': ???, + 'BLOB': types.BLOB, + # 'NCLOB': ???, + # 'STRUCT': ???, + 'JAVA_OBJECT': types.BLOB, + # 'REF_CURSOR': ???, + # 'REF': ???, + # 'DISTINCT': ???, + # 'DATALINK': ???, + # 'SQLXML': ???, + # 'OTHER': ???, + # 'ROWID': ???, +} + + +# class AthenaCompiler(SQLCompiler): +# def visit_char_length_func(self, fn, **kw): +# return 
'length{}'.format(self.function_argspec(fn, **kw)) + + +class AthenaDialect(DefaultDialect): + name = 'athena' + driver = 'athena' + preparer = AthenaIdentifierPreparer + # statement_compiler = AthenaCompiler + supports_alter = False + supports_pk_autoincrement = False + supports_default_values = False + supports_empty_insert = False + supports_unicode_statements = True + supports_unicode_binds = True + returns_unicode_strings = True + description_encoding = None + supports_native_boolean = True + + jdbctypeconverter = None + jdbc_type_map = None + + + @classmethod + def dbapi(cls): + import pyathenajdbc + import pyathenajdbc.error + + pyathenajdbc.Error = pyathenajdbc.error.Error + pyathenajdbc.Warning = pyathenajdbc.error.Warning + pyathenajdbc.InterfaceError = pyathenajdbc.error.InterfaceError + pyathenajdbc.DatabaseError = pyathenajdbc.error.DatabaseError + pyathenajdbc.InternalError = pyathenajdbc.error.InternalError + pyathenajdbc.OperationalError = pyathenajdbc.error.OperationalError + pyathenajdbc.ProgrammingError = pyathenajdbc.error.ProgrammingError + pyathenajdbc.IntegrityError = pyathenajdbc.error.IntegrityError + pyathenajdbc.DataError = pyathenajdbc.error.DataError + pyathenajdbc.NotSupportedError = pyathenajdbc.error.NotSupportedError + + return pyathenajdbc + + def create_connect_args(self, url): + db_parts = (url.database or 'hive').split('/') + + # TODO: + # - schema_name='default' + # - profile_name=None + # - credential_file=None + kwargs = { + 'host': url.host, + 'access_key': url.username, + 'secret_key': url.password, + 'region_name': url.query['region_name'], + 's3_staging_dir': url.query['s3_staging_dir'] + } + kwargs.update(url.query) + if len(db_parts) == 1: + kwargs['catalog'] = db_parts[0] + elif len(db_parts) == 2: + kwargs['catalog'] = db_parts[0] + kwargs['schema'] = db_parts[1] + else: + raise ValueError("Unexpected database format {}".format(url.database)) + return ([], kwargs) + + def get_schema_names(self, connection, **kw): + 
return [schema for (schema,) in connection.execute('SHOW SCHEMAS')] + + def _get_table_columns(self, connection, table_name, schema): + name = table_name + if schema is not None: + name = '%s.%s' % (schema, name) + try: + return connection.execute('SHOW COLUMNS IN {}'.format(name)) + except (error.DatabaseError, exc.DatabaseError) as e: + # Normally SQLAlchemy should wrap this exception in sqlalchemy.exc.DatabaseError, which + # it successfully does in the Hive version. The difference with Athena is that this + # error is raised when fetching the cursor's description rather than the initial execute + # call. SQLAlchemy doesn't handle this. Thus, we catch the unwrapped + # presto.DatabaseError here. + # Does the table exist? + msg = ( + e.args[0].get('message') if e.args and isinstance(e.args[0], dict) + else e.args[0] if e.args and isinstance(e.args[0], str) + else None + ) + regex = r"Table\ \'.*{}\'\ does\ not\ exist".format(re.escape(table_name)) + if msg and re.search(regex, msg): + raise exc.NoSuchTableError(table_name) + else: + raise + + def has_table(self, connection, table_name, schema=None): + try: + self._get_table_columns(connection, table_name, schema) + return True + except exc.NoSuchTableError: + return False + + def get_columns(self, connection, table_name, schema=None, **kwargs): + + if self.jdbctypeconverter is None: + self.jdbctypeconverter = JDBCTypeConverter() + self.jdbc_type_map = {v: k for (k, v) in + self.jdbctypeconverter.jdbc_type_mappings.items()} + + # pylint: disable=unused-argument + name = table_name + if schema is not None: + name = '%s.%s' % (schema, name) + query = 'SELECT * FROM %s LIMIT 0' % name + cursor = connection.execute(query) + schema = cursor.cursor.description + # We need to fetch the empty results otherwise these queries remain in + # flight + cursor.fetchall() + column_info = [] + for col in schema: + column_info.append({ + 'name': col[0], + 'type': _type_map[self.jdbc_type_map[col[1]]], + 'nullable': True, + 
'autoincrement': False}) + return column_info + + def get_foreign_keys(self, connection, table_name, schema=None, **kw): + return [] + + def get_pk_constraint(self, connection, table_name, schema=None, **kw): + return [] + + def get_indexes(self, connection, table_name, schema=None, **kw): + return [] + + def get_table_names(self, connection, schema=None, **kw): + query = 'SHOW TABLES' + if schema: + query += ' IN {}'.format(schema) + return [tbl for (tbl,) in connection.execute(query).fetchall()] + + def do_rollback(self, dbapi_connection): + # No transactions for Athena + pass + + def _check_unicode_returns(self, connection, additional_tests=None): + # requests gives back Unicode strings + return True + + def _check_unicode_description(self, connection): + # requests gives back Unicode strings + return True diff --git a/setup.py b/setup.py index 75ed17a..d7f934c 100755 --- a/setup.py +++ b/setup.py @@ -113,7 +113,8 @@ def run(self): 'botocore>=1.0.0' ], extras_require={ - 'Pandas': ['pandas>=0.19.0'] + 'Pandas': ['pandas>=0.19.0'], + 'SQLAlchemy': ['sqlalchemy>=1.0.0'], }, tests_require=[ 'futures', @@ -138,4 +139,14 @@ def run(self): 'Programming Language :: Python :: 3.4', 'Programming Language :: Python :: 3.5', ], + entry_points={ + # New versions + 'sqlalchemy.dialects': [ + 'athena = pyathenajdbc.sqlalchemy_athena:AthenaDialect', + ], + # Version 0.5 + 'sqlalchemy.databases': [ + 'athena = pyathenajdbc.sqlalchemy_athena:AthenaDialect', + ], + } ) From 0735388f47e1ed9a2d03418f02bae27d55e3325a Mon Sep 17 00:00:00 2001 From: David Wallin Date: Sat, 1 Apr 2017 02:55:38 +0200 Subject: [PATCH 2/2] remove some commented out code --- pyathenajdbc/sqlalchemy_athena.py | 16 +--------------- 1 file changed, 1 insertion(+), 15 deletions(-) diff --git a/pyathenajdbc/sqlalchemy_athena.py b/pyathenajdbc/sqlalchemy_athena.py index 406c4bc..c366789 100644 --- a/pyathenajdbc/sqlalchemy_athena.py +++ b/pyathenajdbc/sqlalchemy_athena.py @@ -8,22 +8,13 @@ from __future__ import 
absolute_import from __future__ import unicode_literals import re -#from distutils.version import StrictVersion -#from pyhive import presto -#from pyhive.common import UniversalSet from sqlalchemy import exc from sqlalchemy import types -#from sqlalchemy import util from sqlalchemy.engine.default import DefaultDialect from sqlalchemy.sql.compiler import IdentifierPreparer from pyathenajdbc import error from pyathenajdbc.converter import JDBCTypeConverter -# try: -# from sqlalchemy.sql.compiler import SQLCompiler -# except ImportError: -# from sqlalchemy.sql.compiler import DefaultCompiler as SQLCompiler - class UniversalSet(object): def __contains__(self, item): @@ -66,7 +57,7 @@ class AthenaIdentifierPreparer(IdentifierPreparer): 'BLOB': types.BLOB, # 'NCLOB': ???, # 'STRUCT': ???, - 'JAVA_OBJECT': types.BLOB, + 'JAVA_OBJECT': types.BLOB, # 'REF_CURSOR': ???, # 'REF': ???, # 'DISTINCT': ???, @@ -77,11 +68,6 @@ class AthenaIdentifierPreparer(IdentifierPreparer): } -# class AthenaCompiler(SQLCompiler): -# def visit_char_length_func(self, fn, **kw): -# return 'length{}'.format(self.function_argspec(fn, **kw)) - - class AthenaDialect(DefaultDialect): name = 'athena' driver = 'athena'