diff --git a/pyathenajdbc/sqlalchemy_athena.py b/pyathenajdbc/sqlalchemy_athena.py new file mode 100644 index 0000000..c366789 --- /dev/null +++ b/pyathenajdbc/sqlalchemy_athena.py @@ -0,0 +1,217 @@ +"""Integration between SQLAlchemy and Athena. + +Some code based on +https://github.com/zzzeek/sqlalchemy/blob/rel_0_5/lib/sqlalchemy/databases/sqlite.py +which is released under the MIT license. +""" + +from __future__ import absolute_import +from __future__ import unicode_literals +import re +from sqlalchemy import exc +from sqlalchemy import types +from sqlalchemy.engine.default import DefaultDialect +from sqlalchemy.sql.compiler import IdentifierPreparer +from pyathenajdbc import error +from pyathenajdbc.converter import JDBCTypeConverter + + +class UniversalSet(object): + def __contains__(self, item): + return True + + +class AthenaIdentifierPreparer(IdentifierPreparer): + # Just quote everything to make things simpler / easier to upgrade + reserved_words = UniversalSet() + +_type_map = { + 'NULL': types.NullType, + 'BOOLEAN': types.Boolean, + 'TINYINT': types.Integer, + 'SMALLINT': types.Integer, + 'BIGINT': types.BigInteger, + 'INTEGER': types.Integer, + 'REAL': types.Float, + 'DOUBLE': types.Float, + 'FLOAT': types.Float, + 'CHAR': types.String, + 'NCHAR': types.String, + 'VARCHAR': types.String, + 'NVARCHAR': types.String, + 'LONGVARCHAR': types.String, + 'LONGNVARCHAR': types.String, + 'DATE': types.DATE, + 'TIMESTAMP': types.TIMESTAMP, + 'TIMESTAMP_WITH_TIMEZONE': types.TIMESTAMP, + 'ARRAY': types.ARRAY, + 'DECIMAL': types.DECIMAL, + 'NUMERIC': types.Numeric, + 'BINARY': types.Binary, + 'VARBINARY': types.Binary, + 'LONGVARBINARY': types.Binary, + # TODO Converter impl + # 'TIME': ???, + # 'BIT': ???, + # 'CLOB': ???, + 'BLOB': types.BLOB, + # 'NCLOB': ???, + # 'STRUCT': ???, + 'JAVA_OBJECT': types.BLOB, + # 'REF_CURSOR': ???, + # 'REF': ???, + # 'DISTINCT': ???, + # 'DATALINK': ???, + # 'SQLXML': ???, + # 'OTHER': ???, + # 'ROWID': ???, +} + + +class AthenaDialect(DefaultDialect): + name = 'athena' + driver = 'athena' + preparer = AthenaIdentifierPreparer + # statement_compiler = AthenaCompiler + supports_alter = False + supports_pk_autoincrement = False + supports_default_values = False + supports_empty_insert = False + supports_unicode_statements = True + supports_unicode_binds = True + returns_unicode_strings = True + description_encoding = None + supports_native_boolean = True + + jdbctypeconverter = None + jdbc_type_map = None + + + @classmethod + def dbapi(cls): + import pyathenajdbc + import pyathenajdbc.error + + pyathenajdbc.Error = pyathenajdbc.error.Error + pyathenajdbc.Warning = pyathenajdbc.error.Warning + pyathenajdbc.InterfaceError = pyathenajdbc.error.InterfaceError + pyathenajdbc.DatabaseError = pyathenajdbc.error.DatabaseError + pyathenajdbc.InternalError = pyathenajdbc.error.InternalError + pyathenajdbc.OperationalError = pyathenajdbc.error.OperationalError + pyathenajdbc.ProgrammingError = pyathenajdbc.error.ProgrammingError + pyathenajdbc.IntegrityError = pyathenajdbc.error.IntegrityError + pyathenajdbc.DataError = pyathenajdbc.error.DataError + pyathenajdbc.NotSupportedError = pyathenajdbc.error.NotSupportedError + + return pyathenajdbc + + def create_connect_args(self, url): + db_parts = (url.database or 'hive').split('/') + + # TODO: + # - schema_name='default' + # - profile_name=None + # - credential_file=None + kwargs = { + 'host': url.host, + 'access_key': url.username, + 'secret_key': url.password, + 'region_name': url.query['region_name'], + 's3_staging_dir': url.query['s3_staging_dir'] + } + kwargs.update(url.query) + if len(db_parts) == 1: + kwargs['catalog'] = db_parts[0] + elif len(db_parts) == 2: + kwargs['catalog'] = db_parts[0] + kwargs['schema'] = db_parts[1] + else: + raise ValueError("Unexpected database format {}".format(url.database)) + return ([], kwargs) + + def get_schema_names(self, connection, **kw): + return [schema for (schema,) in connection.execute('SHOW SCHEMAS')] + + def _get_table_columns(self, connection, table_name, schema): + name = table_name + if schema is not None: + name = '%s.%s' % (schema, name) + try: + return connection.execute('SHOW COLUMNS IN {}'.format(name)) + except (error.DatabaseError, exc.DatabaseError) as e: + # Normally SQLAlchemy should wrap this exception in sqlalchemy.exc.DatabaseError, which + # it successfully does in the Hive version. The difference with Athena is that this + # error is raised when fetching the cursor's description rather than the initial execute + # call. SQLAlchemy doesn't handle this. Thus, we catch the unwrapped + # presto.DatabaseError here. + # Does the table exist? + msg = ( + e.args[0].get('message') if e.args and isinstance(e.args[0], dict) + else e.args[0] if e.args and isinstance(e.args[0], str) + else None + ) + regex = r"Table\ \'.*{}\'\ does\ not\ exist".format(re.escape(table_name)) + if msg and re.search(regex, msg): + raise exc.NoSuchTableError(table_name) + else: + raise + + def has_table(self, connection, table_name, schema=None): + try: + self._get_table_columns(connection, table_name, schema) + return True + except exc.NoSuchTableError: + return False + + def get_columns(self, connection, table_name, schema=None, **kwargs): + + if self.jdbctypeconverter is None: + self.jdbctypeconverter = JDBCTypeConverter() + self.jdbc_type_map = {v: k for (k, v) in + self.jdbctypeconverter.jdbc_type_mappings.items()} + + # pylint: disable=unused-argument + name = table_name + if schema is not None: + name = '%s.%s' % (schema, name) + query = 'SELECT * FROM %s LIMIT 0' % name + cursor = connection.execute(query) + schema = cursor.cursor.description + # We need to fetch the empty results otherwise these queries remain in + # flight + cursor.fetchall() + column_info = [] + for col in schema: + column_info.append({ + 'name': col[0], + 'type': _type_map[self.jdbc_type_map[col[1]]], + 'nullable': True, + 'autoincrement': False}) + return column_info + + def get_foreign_keys(self, connection, table_name, schema=None, **kw): + return [] + + def get_pk_constraint(self, connection, table_name, schema=None, **kw): + return [] + + def get_indexes(self, connection, table_name, schema=None, **kw): + return [] + + def get_table_names(self, connection, schema=None, **kw): + query = 'SHOW TABLES' + if schema: + query += ' IN {}'.format(schema) + return [tbl for (tbl,) in connection.execute(query).fetchall()] + + def do_rollback(self, dbapi_connection): + # No transactions for Athena + pass + + def _check_unicode_returns(self, connection, additional_tests=None): + # requests gives back Unicode strings + return True + + def _check_unicode_description(self, connection): + # requests gives back Unicode strings + return True diff --git a/setup.py b/setup.py index 75ed17a..d7f934c 100755 --- a/setup.py +++ b/setup.py @@ -113,7 +113,8 @@ def run(self): 'botocore>=1.0.0' ], extras_require={ - 'Pandas': ['pandas>=0.19.0'] + 'Pandas': ['pandas>=0.19.0'], + 'SQLAlchemy': ['sqlalchemy>=1.0.0'], }, tests_require=[ 'futures', @@ -138,4 +139,14 @@ def run(self): 'Programming Language :: Python :: 3.4', 'Programming Language :: Python :: 3.5', ], + entry_points={ + # New versions + 'sqlalchemy.dialects': [ + 'athena = pyathenajdbc.sqlalchemy_athena:AthenaDialect', + ], + # Version 0.5 + 'sqlalchemy.databases': [ + 'athena = pyathenajdbc.sqlalchemy_athena:AthenaDialect', + ], + } )