Skip to content

Commit

Permalink
fix_query_parse_bugs
Browse files Browse the repository at this point in the history
Added unit tests to handle wildcard query use case
  • Loading branch information
Bhargava Vadlamani committed May 24, 2019
1 parent b440ceb commit 9a36709
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 30 deletions.
7 changes: 4 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,12 @@
install_requires=[
"requests",
"numpy",
"pandas"
"pandas",
"sqlalchemy"
],
keywords='SQLAlchemy Apache Drill',
author='John Omernik, Charles Givre, Davide Miceli, Massimo Martiradonna',
author_email='john@omernik.com, cgivre@thedataist.com, davide.miceli.dap@gmail.com, massimo.martiradonna.dap@gmail.com',
author='John Omernik, Charles Givre, Davide Miceli, Massimo Martiradonna, Bhargava Vadlamani',
author_email='john@omernik.com, cgivre@thedataist.com, davide.miceli.dap@gmail.com, massimo.martiradonna.dap@gmail.com,vadlamani1729@gmail.com',
license='Apache',
packages=find_packages(),
include_package_data=True,
Expand Down
43 changes: 18 additions & 25 deletions sqlalchemy_drill/sadrill.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from sqlalchemy.engine import default
from sqlalchemy.sql import compiler
from sqlalchemy import inspect
import pathlib
import requests
from pprint import pprint

Expand Down Expand Up @@ -55,40 +56,34 @@ class DrillIdentifierPreparer(compiler.IdentifierPreparer):
]
)

supported_file_extensions = [".csv",".txt",".avro",".json",".parquet",".tsv",".psv"]

def __init__(self, dialect):
super(DrillIdentifierPreparer, self).__init__(dialect, initial_quote='`', final_quote='`')

def format_drill_table(self, schema, isFile=True):
formatted_schema = ""

num_dots = schema.count(".")
schema = schema.replace('`', '')

# For a file, the last section will be the file extension
schema_parts = schema.split('.')

if isFile and num_dots == 3:
# Case for File + Workspace
plugin = schema_parts[0]
workspace = schema_parts[1]
table = schema_parts[2] + "." + schema_parts[3]
formatted_schema = plugin + ".`" + workspace + "`.`" + table + "`"
elif isFile and num_dots == 2:
# Case for file and no workspace
plugin = schema_parts[0]
table = schema_parts[1] + "." + schema_parts[2]
formatted_schema = plugin + "`.`" + table + "`"
else:
# Case for non-file plugins or incomplete schema parts
for part in schema_parts:
quoted_part = "`" + part + "`"
if len(formatted_schema) > 0:
formatted_schema += "." + quoted_part
else:
formatted_schema = quoted_part
if isFile:

return formatted_schema
extension = pathlib.Path(schema).suffix

if not extension.lower() in self.supported_file_extensions:
print("file system based schema encountered (perhaps a query on a directory/table ? ) ")
return ".".join(["`" + x + "`" for x in schema_parts])
else:
print("file extension exists in supported types ")
plugin = schema_parts[0]
workspace = schema_parts[1]
table = ".".join(schema_parts[2:])
return "`" + plugin + "`" + ".`" + workspace + "`.`" + table + "`"

else:
print("not a file based schema ")
return ".".join(["`" + part + "`" for part in schema_parts])


try:
Expand Down Expand Up @@ -345,8 +340,6 @@ def object_as_dict(obj):

def get_columns(self, connection, table_name, schema=None, **kw):

if "@@@" in table_name:
table_name = table_name.replace("@@@", ".")
result = []

plugin_type = self.get_plugin_type(connection, schema)
Expand Down
22 changes: 22 additions & 0 deletions test/FormatConvertorTest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from sqlalchemy_drill.sadrill import DrillIdentifierPreparer,DrillDialect_sadrill
import unittest

#
class FormatConvertorTest(unittest.TestCase):

def test_file_without_extension(self):
expected_result = DrillIdentifierPreparer.format_drill_table("s3.root.test_table_csv")
self.assertEqual("""`s3`.`root`.`test_table_csv`""" , expected_result)

def test_file_with_extension(self):
expected_result = DrillIdentifierPreparer.format_drill_table("s3.root.test_table.csv")
self.assertEqual("""`s3`.`root`.`test_table.csv`""" , expected_result)

def test_non_file_type_extension(self):
expected_result = DrillIdentifierPreparer.format_drill_table("customers.orders",False)
self.assertEqual("""`customers`.`orders`""" , expected_result)

def test_wildcard_file_extension(self):
expected_result = DrillIdentifierPreparer.format_drill_table("select * from namespace.schema.folder/*.csv",True)
self.assertEqual("""`select * from namespace`.`schema`.`folder/*.csv`""",expected_result)

3 changes: 1 addition & 2 deletions test/test_suite.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from sqlalchemy.testing.suite import *

# from sqlalchemy.testing.suite import *

from sqlalchemy.testing.suite import ComponentReflectionTest as _ComponentReflectionTest

Expand Down

0 comments on commit 9a36709

Please sign in to comment.