diff --git a/setup.py b/setup.py index 8167361..31a608f 100644 --- a/setup.py +++ b/setup.py @@ -26,11 +26,12 @@ install_requires=[ "requests", "numpy", - "pandas" + "pandas", + "sqlalchemy" ], keywords='SQLAlchemy Apache Drill', - author='John Omernik, Charles Givre, Davide Miceli, Massimo Martiradonna', - author_email='john@omernik.com, cgivre@thedataist.com, davide.miceli.dap@gmail.com, massimo.martiradonna.dap@gmail.com', + author='John Omernik, Charles Givre, Davide Miceli, Massimo Martiradonna, Bhargava Vadlamani', + author_email='john@omernik.com, cgivre@thedataist.com, davide.miceli.dap@gmail.com, massimo.martiradonna.dap@gmail.com,vadlamani1729@gmail.com', license='Apache', packages=find_packages(), include_package_data=True, diff --git a/sqlalchemy_drill/sadrill.py b/sqlalchemy_drill/sadrill.py index 54f8f23..d499c09 100755 --- a/sqlalchemy_drill/sadrill.py +++ b/sqlalchemy_drill/sadrill.py @@ -5,6 +5,7 @@ from sqlalchemy.engine import default from sqlalchemy.sql import compiler from sqlalchemy import inspect +import pathlib import requests from pprint import pprint @@ -55,40 +56,34 @@ class DrillIdentifierPreparer(compiler.IdentifierPreparer): ] ) + supported_file_extensions = [".csv",".txt",".avro",".json",".parquet",".tsv",".psv"] + def __init__(self, dialect): super(DrillIdentifierPreparer, self).__init__(dialect, initial_quote='`', final_quote='`') def format_drill_table(self, schema, isFile=True): - formatted_schema = "" - num_dots = schema.count(".") schema = schema.replace('`', '') - # For a file, the last section will be the file extension schema_parts = schema.split('.') - if isFile and num_dots == 3: - # Case for File + Workspace - plugin = schema_parts[0] - workspace = schema_parts[1] - table = schema_parts[2] + "." + schema_parts[3] - formatted_schema = plugin + ".`" + workspace + "`.`" + table + "`" - elif isFile and num_dots == 2: - # Case for file and no workspace - plugin = schema_parts[0] - table = schema_parts[1] + "." + schema_parts[2] - formatted_schema = plugin + "`.`" + table + "`" - else: - # Case for non-file plugins or incomplete schema parts - for part in schema_parts: - quoted_part = "`" + part + "`" - if len(formatted_schema) > 0: - formatted_schema += "." + quoted_part - else: - formatted_schema = quoted_part + if isFile: - return formatted_schema + extension = pathlib.Path(schema).suffix + + if not extension.lower() in self.supported_file_extensions: + print("file system based schema encountered (perhaps a query on a directory/table ? ) ") + return ".".join(["`" + x + "`" for x in schema_parts]) + else: + print("file extension exists in supported types ") + plugin = schema_parts[0] + workspace = schema_parts[1] + table = ".".join(schema_parts[2:]) + return "`" + plugin + "`" + ".`" + workspace + "`.`" + table + "`" + else: + print("not a file based schema ") + return ".".join(["`" + part + "`" for part in schema_parts]) try: @@ -345,8 +340,6 @@ def object_as_dict(obj): def get_columns(self, connection, table_name, schema=None, **kw): - if "@@@" in table_name: - table_name = table_name.replace("@@@", ".") result = [] plugin_type = self.get_plugin_type(connection, schema) diff --git a/test/FormatConvertorTest.py b/test/FormatConvertorTest.py new file mode 100644 index 0000000..1bf4a70 --- /dev/null +++ b/test/FormatConvertorTest.py @@ -0,0 +1,22 @@ +from sqlalchemy_drill.sadrill import DrillIdentifierPreparer,DrillDialect_sadrill +import unittest + +# +class FormatConvertorTest(unittest.TestCase): + + def test_file_without_extension(self): + expected_result = DrillIdentifierPreparer.format_drill_table("s3.root.test_table_csv") + self.assertEqual("""`s3`.`root`.`test_table_csv`""" , expected_result) + + def test_file_with_extension(self): + expected_result = DrillIdentifierPreparer.format_drill_table("s3.root.test_table.csv") + self.assertEqual("""`s3`.`root`.`test_table.csv`""" , expected_result) + + def test_non_file_type_extension(self): + expected_result = DrillIdentifierPreparer.format_drill_table("customers.orders",False) + self.assertEqual("""`customers`.`orders`""" , expected_result) + + def test_wildcard_file_extension(self): + expected_result = DrillIdentifierPreparer.format_drill_table("select * from namespace.schema.folder/*.csv",True) + self.assertEqual("""`select * from namespace`.`schema`.`folder/*.csv`""",expected_result) + diff --git a/test/test_suite.py b/test/test_suite.py index 8f3e01c..88894fd 100644 --- a/test/test_suite.py +++ b/test/test_suite.py @@ -1,5 +1,4 @@ -from sqlalchemy.testing.suite import * - +# from sqlalchemy.testing.suite import * from sqlalchemy.testing.suite import ComponentReflectionTest as _ComponentReflectionTest